itertools/
combinations_with_replacement.rs

1use alloc::boxed::Box;
2use alloc::vec::Vec;
3use std::fmt;
4use std::iter::FusedIterator;
5
6use super::lazy_buffer::LazyBuffer;
7use crate::adaptors::checked_binomial;
8
9/// An iterator to iterate through all the `n`-length combinations in an iterator, with replacement.
10///
11/// See [`.combinations_with_replacement()`](crate::Itertools::combinations_with_replacement)
12/// for more information.
13#[derive(Clone)]
14#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
15pub struct CombinationsWithReplacement<I>
16where
17    I: Iterator,
18    I::Item: Clone,
19{
20    indices: Box<[usize]>,
21    pool: LazyBuffer<I>,
22    first: bool,
23}
24
25impl<I> fmt::Debug for CombinationsWithReplacement<I>
26where
27    I: Iterator + fmt::Debug,
28    I::Item: fmt::Debug + Clone,
29{
30    debug_fmt_fields!(CombinationsWithReplacement, indices, pool, first);
31}
32
33impl<I> CombinationsWithReplacement<I>
34where
35    I: Iterator,
36    I::Item: Clone,
37{
38    /// Map the current mask over the pool to get an output combination
39    fn current(&self) -> Vec<I::Item> {
40        self.indices.iter().map(|i| self.pool[*i].clone()).collect()
41    }
42}
43
44/// Create a new `CombinationsWithReplacement` from a clonable iterator.
45pub fn combinations_with_replacement<I>(iter: I, k: usize) -> CombinationsWithReplacement<I>
46where
47    I: Iterator,
48    I::Item: Clone,
49{
50    let indices = alloc::vec![0; k].into_boxed_slice();
51    let pool: LazyBuffer<I> = LazyBuffer::new(iter);
52
53    CombinationsWithReplacement {
54        indices,
55        pool,
56        first: true,
57    }
58}
59
60impl<I> Iterator for CombinationsWithReplacement<I>
61where
62    I: Iterator,
63    I::Item: Clone,
64{
65    type Item = Vec<I::Item>;
66    fn next(&mut self) -> Option<Self::Item> {
67        // If this is the first iteration, return early
68        if self.first {
69            // In empty edge cases, stop iterating immediately
70            return if !(self.indices.is_empty() || self.pool.get_next()) {
71                None
72            // Otherwise, yield the initial state
73            } else {
74                self.first = false;
75                Some(self.current())
76            };
77        }
78
79        // Check if we need to consume more from the iterator
80        // This will run while we increment our first index digit
81        self.pool.get_next();
82
83        // Work out where we need to update our indices
84        let mut increment: Option<(usize, usize)> = None;
85        for (i, indices_int) in self.indices.iter().enumerate().rev() {
86            if *indices_int < self.pool.len() - 1 {
87                increment = Some((i, indices_int + 1));
88                break;
89            }
90        }
91
92        match increment {
93            // If we can update the indices further
94            Some((increment_from, increment_value)) => {
95                // We need to update the rightmost non-max value
96                // and all those to the right
97                for indices_index in increment_from..self.indices.len() {
98                    self.indices[indices_index] = increment_value;
99                }
100                Some(self.current())
101            }
102            // Otherwise, we're done
103            None => None,
104        }
105    }
106
107    fn size_hint(&self) -> (usize, Option<usize>) {
108        let (mut low, mut upp) = self.pool.size_hint();
109        low = remaining_for(low, self.first, &self.indices).unwrap_or(usize::MAX);
110        upp = upp.and_then(|upp| remaining_for(upp, self.first, &self.indices));
111        (low, upp)
112    }
113
114    fn count(self) -> usize {
115        let Self {
116            indices,
117            pool,
118            first,
119        } = self;
120        let n = pool.count();
121        remaining_for(n, first, &indices).unwrap()
122    }
123}
124
125impl<I> FusedIterator for CombinationsWithReplacement<I>
126where
127    I: Iterator,
128    I::Item: Clone,
129{
130}
131
132/// For a given size `n`, return the count of remaining combinations with replacement or None if it would overflow.
133fn remaining_for(n: usize, first: bool, indices: &[usize]) -> Option<usize> {
134    // With a "stars and bars" representation, choose k values with replacement from n values is
135    // like choosing k out of k + n − 1 positions (hence binomial(k + n - 1, k) possibilities)
136    // to place k stars and therefore n - 1 bars.
137    // Example (n=4, k=6): ***|*||** represents [0,0,0,1,3,3].
138    let count = |n: usize, k: usize| {
139        let positions = if n == 0 {
140            k.saturating_sub(1)
141        } else {
142            (n - 1).checked_add(k)?
143        };
144        checked_binomial(positions, k)
145    };
146    let k = indices.len();
147    if first {
148        count(n, k)
149    } else {
150        // The algorithm is similar to the one for combinations *without replacement*,
151        // except we choose values *with replacement* and indices are *non-strictly* monotonically sorted.
152
153        // The combinations generated after the current one can be counted by counting as follows:
154        // - The subsequent combinations that differ in indices[0]:
155        //   If subsequent combinations differ in indices[0], then their value for indices[0]
156        //   must be at least 1 greater than the current indices[0].
157        //   As indices is monotonically sorted, this means we can effectively choose k values with
158        //   replacement from (n - 1 - indices[0]), leading to count(n - 1 - indices[0], k) possibilities.
159        // - The subsequent combinations with same indices[0], but differing indices[1]:
160        //   Here we can choose k - 1 values with replacement from (n - 1 - indices[1]) values,
161        //   leading to count(n - 1 - indices[1], k - 1) possibilities.
162        // - (...)
163        // - The subsequent combinations with same indices[0..=i], but differing indices[i]:
164        //   Here we can choose k - i values with replacement from (n - 1 - indices[i]) values: count(n - 1 - indices[i], k - i).
165        //   Since subsequent combinations can in any index, we must sum up the aforementioned binomial coefficients.
166
167        // Below, `n0` resembles indices[i].
168        indices.iter().enumerate().try_fold(0usize, |sum, (i, n0)| {
169            sum.checked_add(count(n - 1 - *n0, k - i)?)
170        })
171    }
172}