From caa698a5b13563f6dc67eecebc9b6f7de8399b92 Mon Sep 17 00:00:00 2001 From: Piotr Idzik <65706193+vil02@users.noreply.github.com> Date: Sat, 11 Nov 2023 22:07:22 +0100 Subject: [PATCH] Handle nonunique max frequency elements in `mode` (#603) --- src/math/average.rs | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/src/math/average.rs b/src/math/average.rs index 3eb27c0eddc..b4842def20e 100644 --- a/src/math/average.rs +++ b/src/math/average.rs @@ -10,6 +10,7 @@ This program approximates the mean, median and mode of a finite sequence. Note: `mean` function only limited to float 64 numbers. Floats sequences are not allowed for `median` & `mode` functions. "] use std::collections::HashMap; +use std::collections::HashSet; /// # Argument /// /// * `sequence` - A vector of float64 numbers. @@ -45,17 +46,24 @@ pub fn median(mut sequence: Vec) -> T { } } +fn histogram(sequence: Vec) -> HashMap { + sequence.into_iter().fold(HashMap::new(), |mut res, val| { + *res.entry(val).or_insert(0) += 1; + res + }) +} + /// # Argument /// /// * `sequence` - The input vector. /// Returns mode of `sequence`. -pub fn mode(sequence: Vec) -> T { - let mut hash = HashMap::new(); - for value in sequence { - let count = hash.entry(value).or_insert(0); - *count += 1; - } - *hash.iter().max_by_key(|entry| entry.1).unwrap().0 +pub fn mode(sequence: Vec) -> HashSet { + let hist = histogram(sequence); + let max_count = *hist.values().max().unwrap(); + hist.into_iter() + .filter(|(_, count)| *count == max_count) + .map(|(value, _)| value) + .collect() } #[cfg(test)] @@ -71,9 +79,15 @@ mod test { } #[test] fn mode_test() { - assert_eq!(mode(vec![4, 53, 2, 1, 9, 0, 2, 3, 6]), 2); - assert_eq!(mode(vec![-9, -8, 0, 1, 2, 2, 3, -1, -1, 9, -1, -9]), -1); - assert_eq!(mode(vec!["a", "b", "a"]), "a"); + assert_eq!(mode(vec![4, 53, 2, 1, 9, 0, 2, 3, 6]), HashSet::from([2])); + assert_eq!( + mode(vec![-9, -8, 0, 1, 2, 2, 3, -1, -1, 9, -1, -9]), + HashSet::from([-1]) + ); + assert_eq!(mode(vec!["a", "b", "a"]), HashSet::from(["a"])); + assert_eq!(mode(vec![1, 2, 2, 1]), HashSet::from([1, 2])); + assert_eq!(mode(vec![1, 2, 2, 1, 3]), HashSet::from([1, 2])); + assert_eq!(mode(vec![1]), HashSet::from([1])); } #[test] fn mean_test() {