pyfru 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. pyfru-0.1.0/Cargo.toml +4 -0
  2. pyfru-0.1.0/PKG-INFO +39 -0
  3. pyfru-0.1.0/README.md +9 -0
  4. pyfru-0.1.0/fru-arrow/Cargo.toml +26 -0
  5. pyfru-0.1.0/fru-arrow/README.md +9 -0
  6. pyfru-0.1.0/fru-arrow/src/attribute.rs +191 -0
  7. pyfru-0.1.0/fru-arrow/src/classification/da.rs +59 -0
  8. pyfru-0.1.0/fru-arrow/src/classification/impurity.rs +187 -0
  9. pyfru-0.1.0/fru-arrow/src/classification/votes.rs +45 -0
  10. pyfru-0.1.0/fru-arrow/src/classification.rs +137 -0
  11. pyfru-0.1.0/fru-arrow/src/lib.rs +738 -0
  12. pyfru-0.1.0/fru-arrow/src/regression/da.rs +46 -0
  13. pyfru-0.1.0/fru-arrow/src/regression/impurity.rs +143 -0
  14. pyfru-0.1.0/fru-arrow/src/regression/votes.rs +38 -0
  15. pyfru-0.1.0/fru-arrow/src/regression.rs +159 -0
  16. pyfru-0.1.0/fru-arrow/src/serialize.rs +223 -0
  17. pyfru-0.1.0/fru-arrow/src/tools.rs +27 -0
  18. pyfru-0.1.0/fru-arrow/tests/random_forest.rs +518 -0
  19. pyfru-0.1.0/pyfru/Cargo.lock +335 -0
  20. pyfru-0.1.0/pyfru/Cargo.toml +20 -0
  21. pyfru-0.1.0/pyfru/README.md +9 -0
  22. pyfru-0.1.0/pyfru/python/docs/Makefile +16 -0
  23. pyfru-0.1.0/pyfru/python/docs/api.rst +8 -0
  24. pyfru-0.1.0/pyfru/python/docs/conf.py +43 -0
  25. pyfru-0.1.0/pyfru/python/docs/getting_started.rst +217 -0
  26. pyfru-0.1.0/pyfru/python/docs/index.rst +14 -0
  27. pyfru-0.1.0/pyfru/python/tests/test_pyfru.py +356 -0
  28. pyfru-0.1.0/pyfru/src/lib.rs +144 -0
  29. pyfru-0.1.0/pyfru/uv.lock +808 -0
  30. pyfru-0.1.0/pyproject.toml +49 -0
  31. pyfru-0.1.0/python/pyfru/__init__.py +3 -0
  32. pyfru-0.1.0/python/pyfru/data_structures.py +74 -0
  33. pyfru-0.1.0/python/pyfru/random_forest.py +411 -0
pyfru-0.1.0/Cargo.toml ADDED
@@ -0,0 +1,4 @@
1
+ [workspace]
2
+ members = ["fru-arrow", "pyfru"]
3
+
4
+ resolver = "3"
pyfru-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,39 @@
1
+ Metadata-Version: 2.4
2
+ Name: pyfru
3
+ Version: 0.1.0
4
+ Classifier: Development Status :: 4 - Beta
5
+ Classifier: License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)
6
+ Classifier: Operating System :: OS Independent
7
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
8
+ Classifier: Programming Language :: Python
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: Programming Language :: Python :: 3.12
11
+ Classifier: Programming Language :: Python :: 3.13
12
+ Classifier: Programming Language :: Python :: 3.14
13
+ Requires-Dist: numpy>=2.0.0
14
+ Requires-Dist: pyarrow>=23.0.1
15
+ Requires-Dist: pytest>=9 ; extra == 'dev'
16
+ Requires-Dist: pandas>=3 ; extra == 'dev'
17
+ Requires-Dist: furo>=2025.12.19 ; extra == 'dev'
18
+ Requires-Dist: myst-parser>=5.1.0 ; extra == 'dev'
19
+ Requires-Dist: ruff>=0.15.12 ; extra == 'dev'
20
+ Requires-Dist: sphinx>=9.1.0 ; extra == 'dev'
21
+ Provides-Extra: dev
22
+ Summary: Blazingly fast Random Forest with arrow support (pandas, polars, pyarrow etc.)
23
+ Author-email: Chris Piwonski <fruarrow@kpiwonski.com>
24
+ License-Expression: GPL-3.0-or-later
25
+ Requires-Python: >=3.12
26
+ Description-Content-Type: text/markdown; charset=UTF-8; variant=GFM
27
+ Project-URL: documentation, https://kpiwonski.github.io/fru-arrow/
28
+ Project-URL: repository, https://github.com/kpiwonski/fru-arrow
29
+
30
+ # Fru-arrow
31
+
32
+ [R version](https://cran.r-project.org/web/packages/fru/index.html) |
33
+ [Pyfru docs](https://kpiwonski.github.io/fru-arrow/)
34
+
35
+ Fru-arrow is a highly performant implementation of the **Random Forest** model. It uses Arrow PyCapsule underneath,
36
+ making integration with any library that supports it (such as ``polars``, ``pandas``, and ``pyarrow``) straightforward.
37
+ Moreover, it provides permutation importance computed by a novel, highly optimized algorithm.
38
+ It can be used for both **classification** and **regression**, and it supports out-of-bag predictions.
39
+
pyfru-0.1.0/README.md ADDED
@@ -0,0 +1,9 @@
1
+ # Fru-arrow
2
+
3
+ [R version](https://cran.r-project.org/web/packages/fru/index.html) |
4
+ [Pyfru docs](https://kpiwonski.github.io/fru-arrow/)
5
+
6
+ Fru-arrow is a highly performant implementation of the **Random Forest** model. It uses Arrow PyCapsule underneath,
7
+ making integration with any library that supports it (such as ``polars``, ``pandas``, and ``pyarrow``) straightforward.
8
+ Moreover, it provides permutation importance computed by a novel, highly optimized algorithm.
9
+ It can be used for both **classification** and **regression**, and it supports out-of-bag predictions.
@@ -0,0 +1,26 @@
1
+ [package]
2
+ name = "fru-arrow"
3
+ version = "0.1.0"
4
+ authors = ["Krzysztof Piwonski <fruarrow@kpiwonski.com>", "Miron B. Kursa <m@mbq.me>"]
5
+ edition="2024"
6
+ license = "GPL-3.0-or-later"
7
+ description = "Blazingly fast implementation of Random Forest with Apache Arrow support"
8
+ readme = "README.md"
9
+ keywords = ["machine-learning", "ml", "random-forest", "decision-tree"]
10
+ categories = ["artificial-intelligence"]
11
+ repository = "https://github.com/kpiwonski/fru-arrow"
12
+
13
+ [dependencies]
14
+ xrf = {version="0.1.7"}
15
+ minarrow = {version="0.11.0", features=["extended_numeric_types", "extended_categorical"]}
16
+ serde = "1.0.228"
17
+ postcard = {version = "1.1.3", features = ["use-std"]}
18
+ thiserror = "2.0.18"
19
+
20
+ [dev-dependencies]
21
+ rand = "0.9"
22
+ minarrow = {version="0.11.0", features=["extended_numeric_types", "extended_categorical", "views"]}
23
+
24
+ [profile.release]
25
+ codegen-units = 1
26
+ lto = "fat"
@@ -0,0 +1,9 @@
1
+ # Fru-arrow
2
+
3
+ [R version](https://cran.r-project.org/web/packages/fru/index.html) |
4
+ [Pyfru docs](https://kpiwonski.github.io/fru-arrow/)
5
+
6
+ Fru-arrow is a highly performant implementation of the **Random Forest** model. It uses Arrow PyCapsule underneath,
7
+ making integration with any library that supports it (such as ``polars``, ``pandas``, and ``pyarrow``) straightforward.
8
+ Moreover, it provides permutation importance computed by a novel, highly optimized algorithm.
9
+ It can be used for both **classification** and **regression**, and it supports out-of-bag predictions.
@@ -0,0 +1,191 @@
1
+ use std::marker::PhantomData;
2
+
3
+ use minarrow::{
4
+ Array, BooleanArray, CategoricalArray, FloatArray, IntegerArray, NumericArray, TextArray,
5
+ };
6
+ use serde::{Deserialize, Serialize};
7
+ use xrf::{FeatureSampler, RfInput, RfRng};
8
+
9
/// Split pivot produced by the impurity scanners and consumed by `SplittingIterator`.
///
/// The variant must match the type of the column being split; `SplittingIterator::new` panics
/// on a mismatch.
///
/// NOTE: derives serde `Serialize`/`Deserialize`. With index-based formats such as `postcard`
/// (a dependency of this crate) the variant order is part of the encoding — do not reorder.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum DfPivot {
    /// Logical column: rows whose value is `true` satisfy the pivot.
    Logical,
    /// Float column threshold: rows with `value > threshold` satisfy the pivot.
    Real(f64),
    /// Signed integer column threshold: rows with `value > threshold` satisfy the pivot.
    Integer(i64),
    /// Unsigned integer column threshold: rows with `value > threshold` satisfy the pivot.
    /// Also produced for categorical codes when the scanner treats them as ordered.
    UInteger(u64),
    /// Categorical column: bitmask of category codes; a row satisfies the pivot when the bit
    /// for its code is set.
    Subset(u64),
}
17
+
18
+ pub struct SplittingIterator<'a, M> {
19
+ pair: DfSplittingPair<'a>,
20
+ mask_iter: M,
21
+ }
22
+
23
+ enum DfSplittingPair<'a> {
24
+ Boolean(&'a BooleanArray<()>),
25
+ Float32(&'a FloatArray<f32>, f64),
26
+ Float64(&'a FloatArray<f64>, f64),
27
+ Integer8(&'a IntegerArray<i8>, i64),
28
+ Integer16(&'a IntegerArray<i16>, i64),
29
+ Integer32(&'a IntegerArray<i32>, i64),
30
+ Integer64(&'a IntegerArray<i64>, i64),
31
+ UInteger8(&'a IntegerArray<u8>, u64),
32
+ UInteger16(&'a IntegerArray<u16>, u64),
33
+ UInteger32(&'a IntegerArray<u32>, u64),
34
+ UInteger64(&'a IntegerArray<u64>, u64),
35
+ Categorical8(&'a CategoricalArray<u8>, u64),
36
+ Categorical16(&'a CategoricalArray<u16>, u64),
37
+ Categorical32(&'a CategoricalArray<u32>, u64),
38
+ Categorical64(&'a CategoricalArray<u64>, u64),
39
+ }
40
+ impl<'a, M> SplittingIterator<'a, M> {
41
+ pub fn new(x: &'a Array, pivot: &DfPivot, mask_iter: M) -> Self {
42
+ let pair = match x {
43
+ Array::NumericArray(num) => match (num, pivot) {
44
+ (NumericArray::Float32(arr), &DfPivot::Real(xt)) => {
45
+ DfSplittingPair::Float32(arr, xt)
46
+ }
47
+ (NumericArray::Float64(arr), &DfPivot::Real(xt)) => {
48
+ DfSplittingPair::Float64(arr, xt)
49
+ }
50
+ (NumericArray::Int8(arr), &DfPivot::Integer(xt)) => {
51
+ DfSplittingPair::Integer8(arr, xt)
52
+ }
53
+ (NumericArray::Int16(arr), &DfPivot::Integer(xt)) => {
54
+ DfSplittingPair::Integer16(arr, xt)
55
+ }
56
+ (NumericArray::Int32(arr), &DfPivot::Integer(xt)) => {
57
+ DfSplittingPair::Integer32(arr, xt)
58
+ }
59
+ (NumericArray::Int64(arr), &DfPivot::Integer(xt)) => {
60
+ DfSplittingPair::Integer64(arr, xt)
61
+ }
62
+ (NumericArray::UInt8(arr), &DfPivot::UInteger(xt)) => {
63
+ DfSplittingPair::UInteger8(arr, xt)
64
+ }
65
+ (NumericArray::UInt16(arr), &DfPivot::UInteger(xt)) => {
66
+ DfSplittingPair::UInteger16(arr, xt)
67
+ }
68
+ (NumericArray::UInt32(arr), &DfPivot::UInteger(xt)) => {
69
+ DfSplittingPair::UInteger32(arr, xt)
70
+ }
71
+ (NumericArray::UInt64(arr), &DfPivot::UInteger(xt)) => {
72
+ DfSplittingPair::UInteger64(arr, xt)
73
+ }
74
+ _ => panic!("Unsupported array type!"),
75
+ },
76
+ Array::TextArray(cat) => match (cat, pivot) {
77
+ (TextArray::Categorical8(arr), &DfPivot::Subset(sub)) => {
78
+ DfSplittingPair::Categorical8(arr, sub)
79
+ }
80
+ (TextArray::Categorical16(arr), &DfPivot::Subset(sub)) => {
81
+ DfSplittingPair::Categorical16(arr, sub)
82
+ }
83
+ (TextArray::Categorical32(arr), &DfPivot::Subset(sub)) => {
84
+ DfSplittingPair::Categorical32(arr, sub)
85
+ }
86
+ (TextArray::Categorical64(arr), &DfPivot::Subset(sub)) => {
87
+ DfSplittingPair::Categorical64(arr, sub)
88
+ }
89
+ _ => panic!("Unsupported array type!"),
90
+ },
91
+ Array::BooleanArray(arr) => match pivot {
92
+ &DfPivot::Logical => DfSplittingPair::Boolean(arr),
93
+ _ => panic!("Unsupported array type!"),
94
+ },
95
+ _ => panic!("Unsupported array type!"),
96
+ };
97
+
98
+ Self { pair, mask_iter }
99
+ }
100
+ }
101
+
102
+ impl<'a, 'b, M> Iterator for SplittingIterator<'a, M>
103
+ where
104
+ M: Iterator<Item = &'b usize>,
105
+ {
106
+ type Item = bool;
107
+ fn next(&mut self) -> Option<bool> {
108
+ if let Some(&e) = self.mask_iter.next() {
109
+ let ans = match self.pair {
110
+ DfSplittingPair::Boolean(x) => x[e],
111
+ DfSplittingPair::Float32(x, xt) => x[e] as f64 > xt,
112
+ DfSplittingPair::Float64(x, xt) => x[e] > xt,
113
+ DfSplittingPair::Integer8(x, xt) => x[e] as i64 > xt,
114
+ DfSplittingPair::Integer16(x, xt) => x[e] as i64 > xt,
115
+ DfSplittingPair::Integer32(x, xt) => x[e] as i64 > xt,
116
+ DfSplittingPair::Integer64(x, xt) => x[e] > xt,
117
+ DfSplittingPair::UInteger8(x, xt) => x[e] as u64 > xt,
118
+ DfSplittingPair::UInteger16(x, xt) => x[e] as u64 > xt,
119
+ DfSplittingPair::UInteger32(x, xt) => x[e] as u64 > xt,
120
+ DfSplittingPair::UInteger64(x, xt) => x[e] > xt,
121
+ DfSplittingPair::Categorical8(x, split) => split & (1 << x[e] as u64) != 0,
122
+ DfSplittingPair::Categorical16(x, split) => split & (1 << x[e] as u64) != 0,
123
+ DfSplittingPair::Categorical32(x, split) => split & (1 << x[e] as u64) != 0,
124
+ DfSplittingPair::Categorical64(x, split) => split & (1 << x[e]) != 0,
125
+ };
126
+ Some(ans)
127
+ } else {
128
+ None
129
+ }
130
+ }
131
+ }
132
+
133
+ pub struct FYSampler<I> {
134
+ mixed: Vec<usize>,
135
+ left: usize,
136
+ marker: PhantomData<I>,
137
+ }
138
+
139
+ impl<I: RfInput<FeatureId = usize>> FYSampler<I> {
140
+ pub fn new(input: &I) -> Self {
141
+ Self {
142
+ mixed: (0..input.feature_count()).collect(),
143
+ left: input.feature_count(),
144
+ marker: PhantomData,
145
+ }
146
+ }
147
+ }
148
+
149
+ impl<I: RfInput<FeatureId = usize>> FeatureSampler<I> for FYSampler<I> {
150
+ fn random_feature(&mut self, rng: &mut RfRng) -> I::FeatureId {
151
+ let sel = rng.up_to(self.left);
152
+ let ans = self.mixed[sel];
153
+ self.left = self.left.checked_sub(1).unwrap();
154
+ self.mixed.swap(sel, self.left);
155
+ ans
156
+ }
157
+ fn reload(&mut self) {
158
+ self.left = self.mixed.len();
159
+ }
160
+ fn reset(&mut self) {
161
+ self.mixed = (0..self.mixed.len()).collect();
162
+ self.left = self.mixed.len();
163
+ }
164
+ }
165
+
166
+ macro_rules! impl_from_uint_for_dfpivot {
167
+ ($($t:ty),* $(,)?) => {
168
+ $(
169
+ impl From<$t> for DfPivot {
170
+ fn from(value: $t) -> Self {
171
+ DfPivot::UInteger(value as u64)
172
+ }
173
+ }
174
+ )*
175
+ };
176
+ }
177
+
178
+ macro_rules! impl_from_int_for_dfpivot {
179
+ ($($t:ty),* $(,)?) => {
180
+ $(
181
+ impl From<$t> for DfPivot {
182
+ fn from(value: $t) -> Self {
183
+ DfPivot::Integer(value as i64)
184
+ }
185
+ }
186
+ )*
187
+ };
188
+ }
189
+
190
+ impl_from_int_for_dfpivot!(i8, i16, i32, i64);
191
+ impl_from_uint_for_dfpivot!(u8, u16, u32, u64);
@@ -0,0 +1,59 @@
1
+ use crate::classification::ClsDecisionBasicType;
2
+
3
+ use super::DataFrame;
4
+ use minarrow::CategoricalArray;
5
+ use std::collections::HashMap;
6
+ use xrf::{AccuracyDecreaseAggregator, Mask, RfInput};
7
+
8
+ pub struct ClsDaAggregator {
9
+ direct: Vec<Option<ClsDecisionBasicType>>,
10
+ drops: HashMap<usize, isize>,
11
+ n: usize,
12
+ true_decision: CategoricalArray<ClsDecisionBasicType>,
13
+ }
14
+ impl AccuracyDecreaseAggregator<DataFrame> for ClsDaAggregator {
15
+ fn new(input: &DataFrame, on: &Mask, n: usize) -> Self {
16
+ Self {
17
+ direct: vec![None; n],
18
+ drops: HashMap::new(),
19
+ n: on.len(),
20
+ //TODO: Reference here
21
+ true_decision: input.decision.clone(),
22
+ }
23
+ }
24
+ fn ingest(&mut self, permutted: Option<usize>, mask: &Mask, vote: &ClsDecisionBasicType) {
25
+ if let Some(permutted) = permutted {
26
+ let diff: isize = mask
27
+ .iter()
28
+ .map(|&e| {
29
+ let oob_vote = self.direct.get(e).unwrap().unwrap();
30
+ if !oob_vote.eq(vote) {
31
+ let tru = self.true_decision[e];
32
+ match (tru.eq(vote), tru.eq(&oob_vote)) {
33
+ (true, true) => unreachable!("Logic error"),
34
+ (true, false) => -1,
35
+ (false, true) => 1,
36
+
37
+ (false, false) => 0,
38
+ }
39
+ } else {
40
+ 0
41
+ }
42
+ })
43
+ .sum();
44
+ *self.drops.entry(permutted).or_insert(0) += diff;
45
+ } else {
46
+ for &e in mask.iter() {
47
+ self.direct[e] = Some(*vote);
48
+ }
49
+ }
50
+ }
51
+ fn get_direct_vote(&self, e: usize) -> ClsDecisionBasicType {
52
+ self.direct.get(e).unwrap().unwrap()
53
+ }
54
+ fn mda_iter(&self) -> impl Iterator<Item = (<DataFrame as RfInput>::FeatureId, f64)> {
55
+ self.drops
56
+ .iter()
57
+ .map(|(a, b)| (*a, (*b as f64) / (self.n as f64)))
58
+ }
59
+ }
@@ -0,0 +1,187 @@
1
+ use std::ops::Add;
2
+
3
+ use super::{DecisionSlice, Votes};
4
+ use crate::{attribute::DfPivot, classification::ClsDecisionBasicType, tools::MidpointThreshold};
5
+ use minarrow::BooleanArray;
6
+ use xrf::{Mask, RfRng, VoteAggregator};
7
+
8
+ pub fn scan_bin(x: &BooleanArray<()>, ys: &DecisionSlice, mask: &Mask) -> Option<(DfPivot, f64)> {
9
+ let mut left = Votes::new(ys.ncat);
10
+ let mut xt = 0_usize;
11
+ let n = mask.len();
12
+ mask.iter()
13
+ .map(|&e| x[e])
14
+ .zip(ys.values.iter())
15
+ .for_each(|(x, &y)| {
16
+ if x {
17
+ left.ingest_vote(y);
18
+ xt += 1;
19
+ }
20
+ });
21
+ let score: f64 = ys
22
+ .summary
23
+ .0
24
+ .iter()
25
+ .zip(left.0.iter())
26
+ .map(|(total, &left)| {
27
+ let for_false = (n - xt) as f64;
28
+ let for_true = xt as f64;
29
+ let n = n as f64;
30
+ let right = (total - left) as f64;
31
+ let left = left as f64;
32
+ (left / for_true) * (left / n) + (right / for_false) * (right / n)
33
+ })
34
+ .sum();
35
+ Some((DfPivot::Logical, score))
36
+ }
37
+
38
+ pub fn scan_categorical<T: Copy + Ord + TryInto<usize> + Into<DfPivot> + MidpointThreshold>(
39
+ x: &[T],
40
+ xc: usize,
41
+ ys: &DecisionSlice,
42
+ mask: &Mask,
43
+ _rng: &mut RfRng,
44
+ ) -> Option<(DfPivot, f64)> {
45
+ if xc > 10 {
46
+ //When there is too many combinations, just treat it as ordered
47
+ return scan_integer(x, ys, mask);
48
+ }
49
+ if xc < 2 {
50
+ return None;
51
+ }
52
+ let n = mask.len();
53
+ let mut va: Vec<Votes> = std::iter::repeat_with(|| Votes::new(ys.ncat))
54
+ .take(xc)
55
+ .collect();
56
+ mask.iter()
57
+ .map(|&e| x[e])
58
+ .zip(ys.values.iter())
59
+ .for_each(|(x, &y)| va[x.try_into().ok().unwrap()].ingest_vote(y));
60
+ let sub_max: u64 = (1 << (xc - 1)) - 1;
61
+
62
+ (0..sub_max)
63
+ .map(|bitmask_id| bitmask_id + (1 << (xc - 1)))
64
+ .fold(None, |acc: Option<(u64, f64)>, bitmask| {
65
+ let left = va
66
+ .iter()
67
+ .enumerate()
68
+ .filter(|(e, _)| bitmask & (1 << e) != 0)
69
+ .fold(Votes::new(ys.ncat), |mut acc, (_, v)| {
70
+ acc.merge(v);
71
+ acc
72
+ });
73
+ let in_left: usize = left.0.iter().sum();
74
+
75
+ let score: f64 = ys
76
+ .summary
77
+ .0
78
+ .iter()
79
+ .zip(left.0.iter())
80
+ .map(|(&all, &left)| {
81
+ let ahead = (n - in_left) as f64;
82
+ let scanned = in_left as f64;
83
+ let n = n as f64;
84
+ let right = (all - left) as f64;
85
+ let left = left as f64;
86
+ (left / scanned) * (left / n) + (right / ahead) * (right / n)
87
+ })
88
+ .sum();
89
+ if score > acc.map(|x| x.1).unwrap_or(f64::NEG_INFINITY) {
90
+ return Some((bitmask, score));
91
+ }
92
+ acc
93
+ })
94
+ .map(|(bitmask, score)| (DfPivot::Subset(bitmask), score))
95
+ }
96
+
97
+ pub fn scan_float<T: Copy + PartialOrd + Add<T, Output = T> + Into<f64>>(
98
+ x: &[T],
99
+ ys: &DecisionSlice,
100
+ mask: &Mask,
101
+ ) -> Option<(DfPivot, f64)> {
102
+ let mut bound: Vec<(T, ClsDecisionBasicType)> = mask
103
+ .iter()
104
+ .zip(ys.values.iter())
105
+ .map(|(&xe, &y)| (x[xe], y))
106
+ .collect();
107
+ bound.sort_unstable_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
108
+
109
+ let n = bound.len();
110
+ let mut left = Votes::new(ys.ncat);
111
+ let mut scanned = 0_usize;
112
+
113
+ bound
114
+ .windows(2)
115
+ .map(|x| (x[0].0, x[1].0, x[0].1))
116
+ .fold(None, |acc: Option<(f64, f64)>, (x, next_x, y)| {
117
+ scanned += 1;
118
+ left.ingest_vote(y);
119
+ if x.partial_cmp(&next_x).unwrap().is_ne() {
120
+ let score: f64 = ys
121
+ .summary
122
+ .0
123
+ .iter()
124
+ .zip(left.0.iter())
125
+ .map(|(&all, &left)| {
126
+ let ahead = (n - scanned) as f64;
127
+ let scanned = scanned as f64;
128
+ let n = n as f64;
129
+ let right = (all - left) as f64;
130
+ let left = left as f64;
131
+ (left / scanned) * (left / n) + (right / ahead) * (right / n)
132
+ })
133
+ .sum();
134
+ if score > acc.map(|x| x.1).unwrap_or(f64::NEG_INFINITY) {
135
+ return Some((Into::<f64>::into(x + next_x) * 0.5, score));
136
+ }
137
+ }
138
+ acc
139
+ })
140
+ .map(|(thresh, score)| (DfPivot::Real(thresh), score))
141
+ }
142
+
143
+ pub fn scan_integer<T: Copy + Ord + Into<DfPivot> + MidpointThreshold>(
144
+ x: &[T],
145
+ ys: &DecisionSlice,
146
+ mask: &Mask,
147
+ ) -> Option<(DfPivot, f64)> {
148
+ let mut bound: Vec<(T, ClsDecisionBasicType)> = mask
149
+ .iter()
150
+ .zip(ys.values.iter())
151
+ .map(|(&xe, &y)| (x[xe], y))
152
+ .collect();
153
+ bound.sort_unstable_by_key(|a| a.0);
154
+
155
+ let n = bound.len();
156
+ let mut left = Votes::new(ys.ncat);
157
+ let mut scanned = 0_usize;
158
+
159
+ bound
160
+ .windows(2)
161
+ .map(|x| (x[0].0, x[1].0, x[0].1))
162
+ .fold(None, |acc: Option<(T, f64)>, (x, next_x, y)| {
163
+ scanned += 1;
164
+ left.ingest_vote(y);
165
+ if x.cmp(&next_x).is_ne() {
166
+ let score: f64 = ys
167
+ .summary
168
+ .0
169
+ .iter()
170
+ .zip(left.0.iter())
171
+ .map(|(&all, &left)| {
172
+ let ahead = (n - scanned) as f64;
173
+ let scanned = scanned as f64;
174
+ let n = n as f64;
175
+ let right = (all - left) as f64;
176
+ let left = left as f64;
177
+ (left / scanned) * (left / n) + (right / ahead) * (right / n)
178
+ })
179
+ .sum();
180
+ if score > acc.map(|x| x.1).unwrap_or(f64::NEG_INFINITY) {
181
+ return Some((x.midpoint_threshold(next_x), score));
182
+ }
183
+ }
184
+ acc
185
+ })
186
+ .map(|(thresh, score)| (thresh.into(), score))
187
+ }
@@ -0,0 +1,45 @@
1
+ use crate::classification::ClsDecisionBasicType;
2
+
3
+ use super::DataFrame;
4
+ use serde::{Deserialize, Serialize};
5
+ use xrf::{FairBest, RfRng, VoteAggregator};
6
+
7
+ #[derive(Clone, Serialize, Deserialize)]
8
+ pub struct Votes(pub Vec<usize>); //TODO: Fix impurity to make it private
9
+
10
+ impl Votes {
11
+ pub fn is_pure(&self) -> bool {
12
+ self.0.iter().filter(|&&x| x > 0).count() <= 1
13
+ }
14
+ pub fn new(ncat: ClsDecisionBasicType) -> Self {
15
+ Self(std::iter::repeat_n(0, ncat as usize).collect())
16
+ }
17
+
18
+ pub fn collapse_empty_random(&self, rng: &mut RfRng) -> ClsDecisionBasicType {
19
+ self.0
20
+ .iter()
21
+ .enumerate()
22
+ .fold(FairBest::new(), |mut best, (cls, count)| {
23
+ best.ingest(count, cls, rng);
24
+ best
25
+ })
26
+ .consume()
27
+ .map(|(_score, class)| class as ClsDecisionBasicType)
28
+ .unwrap()
29
+ }
30
+ }
31
+
32
+ impl VoteAggregator<DataFrame> for Votes {
33
+ fn new(input: &DataFrame) -> Self {
34
+ Votes::new(input.ncat)
35
+ }
36
+ fn ingest_vote(&mut self, v: ClsDecisionBasicType) {
37
+ self.0[v as usize] += 1;
38
+ }
39
+ fn merge(&mut self, other: &Self) {
40
+ self.0
41
+ .iter_mut()
42
+ .zip(other.0.iter())
43
+ .for_each(|(t, s)| *t += s);
44
+ }
45
+ }