pyfru 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyfru-0.1.0/Cargo.toml +4 -0
- pyfru-0.1.0/PKG-INFO +39 -0
- pyfru-0.1.0/README.md +9 -0
- pyfru-0.1.0/fru-arrow/Cargo.toml +26 -0
- pyfru-0.1.0/fru-arrow/README.md +9 -0
- pyfru-0.1.0/fru-arrow/src/attribute.rs +191 -0
- pyfru-0.1.0/fru-arrow/src/classification/da.rs +59 -0
- pyfru-0.1.0/fru-arrow/src/classification/impurity.rs +187 -0
- pyfru-0.1.0/fru-arrow/src/classification/votes.rs +45 -0
- pyfru-0.1.0/fru-arrow/src/classification.rs +137 -0
- pyfru-0.1.0/fru-arrow/src/lib.rs +738 -0
- pyfru-0.1.0/fru-arrow/src/regression/da.rs +46 -0
- pyfru-0.1.0/fru-arrow/src/regression/impurity.rs +143 -0
- pyfru-0.1.0/fru-arrow/src/regression/votes.rs +38 -0
- pyfru-0.1.0/fru-arrow/src/regression.rs +159 -0
- pyfru-0.1.0/fru-arrow/src/serialize.rs +223 -0
- pyfru-0.1.0/fru-arrow/src/tools.rs +27 -0
- pyfru-0.1.0/fru-arrow/tests/random_forest.rs +518 -0
- pyfru-0.1.0/pyfru/Cargo.lock +335 -0
- pyfru-0.1.0/pyfru/Cargo.toml +20 -0
- pyfru-0.1.0/pyfru/README.md +9 -0
- pyfru-0.1.0/pyfru/python/docs/Makefile +16 -0
- pyfru-0.1.0/pyfru/python/docs/api.rst +8 -0
- pyfru-0.1.0/pyfru/python/docs/conf.py +43 -0
- pyfru-0.1.0/pyfru/python/docs/getting_started.rst +217 -0
- pyfru-0.1.0/pyfru/python/docs/index.rst +14 -0
- pyfru-0.1.0/pyfru/python/tests/test_pyfru.py +356 -0
- pyfru-0.1.0/pyfru/src/lib.rs +144 -0
- pyfru-0.1.0/pyfru/uv.lock +808 -0
- pyfru-0.1.0/pyproject.toml +49 -0
- pyfru-0.1.0/python/pyfru/__init__.py +3 -0
- pyfru-0.1.0/python/pyfru/data_structures.py +74 -0
- pyfru-0.1.0/python/pyfru/random_forest.py +411 -0
pyfru-0.1.0/Cargo.toml
ADDED
pyfru-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: pyfru
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Classifier: Development Status :: 4 - Beta
|
|
5
|
+
Classifier: License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)
|
|
6
|
+
Classifier: Operating System :: OS Independent
|
|
7
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
8
|
+
Classifier: Programming Language :: Python
|
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
13
|
+
Requires-Dist: numpy>=2.0.0
|
|
14
|
+
Requires-Dist: pyarrow>=23.0.1
|
|
15
|
+
Requires-Dist: pytest>=9 ; extra == 'dev'
|
|
16
|
+
Requires-Dist: pandas>=3 ; extra == 'dev'
|
|
17
|
+
Requires-Dist: furo>=2025.12.19 ; extra == 'dev'
|
|
18
|
+
Requires-Dist: myst-parser>=5.1.0 ; extra == 'dev'
|
|
19
|
+
Requires-Dist: ruff>=0.15.12 ; extra == 'dev'
|
|
20
|
+
Requires-Dist: sphinx>=9.1.0 ; extra == 'dev'
|
|
21
|
+
Provides-Extra: dev
|
|
22
|
+
Summary: Blazingly fast Random Forest with arrow support (pandas, polars, pyarrow etc.)
|
|
23
|
+
Author-email: Chris Piwonski <fruarrow@kpiwonski.com>
|
|
24
|
+
License-Expression: GPL-3.0-or-later
|
|
25
|
+
Requires-Python: >=3.12
|
|
26
|
+
Description-Content-Type: text/markdown; charset=UTF-8; variant=GFM
|
|
27
|
+
Project-URL: documentation, https://kpiwonski.github.io/fru-arrow/
|
|
28
|
+
Project-URL: repository, https://github.com/kpiwonski/fru-arrow
|
|
29
|
+
|
|
30
|
+
# Fru-arrow
|
|
31
|
+
|
|
32
|
+
[R version](https://cran.r-project.org/web/packages/fru/index.html) |
|
|
33
|
+
[Pyfru docs](https://kpiwonski.github.io/fru-arrow/)
|
|
34
|
+
|
|
35
|
+
Fru-arrow is a highly performant implementation of the **Random Forest** model. It uses Arrow PyCapsule underneath,
|
|
36
|
+
making integration straightforward with any library that supports it, such as ``polars``, ``pandas``, or ``pyarrow``.
|
|
37
|
+
Moreover, it features permutation importance with a novel, highly optimized algorithm.
|
|
38
|
+
It can be used for both **classification** and **regression**, as well as out-of-bag predictions.
|
|
39
|
+
|
pyfru-0.1.0/README.md
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
# Fru-arrow
|
|
2
|
+
|
|
3
|
+
[R version](https://cran.r-project.org/web/packages/fru/index.html) |
|
|
4
|
+
[Pyfru docs](https://kpiwonski.github.io/fru-arrow/)
|
|
5
|
+
|
|
6
|
+
Fru-arrow is a highly performant implementation of the **Random Forest** model. It uses Arrow PyCapsule underneath,
|
|
7
|
+
making integration straightforward with any library that supports it, such as ``polars``, ``pandas``, or ``pyarrow``.
|
|
8
|
+
Moreover, it features permutation importance with a novel, highly optimized algorithm.
|
|
9
|
+
It can be used for both **classification** and **regression**, as well as out-of-bag predictions.
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
[package]
|
|
2
|
+
name = "fru-arrow"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
authors = ["Krzysztof Piwonski <fruarrow@kpiwonski.com>", "Miron B. Kursa <m@mbq.me>"]
|
|
5
|
+
edition="2024"
|
|
6
|
+
license = "GPL-3.0-or-later"
|
|
7
|
+
description = "Blazingly fast implementation of Random Forest with apache arrow support"
|
|
8
|
+
readme = "README.md"
|
|
9
|
+
keywords = ["machine-learning", "ml", "random-forest", "decision-tree"]
|
|
10
|
+
categories = ["artificial-intelligence"]
|
|
11
|
+
repository = "https://github.com/kpiwonski/fru-arrow"
|
|
12
|
+
|
|
13
|
+
[dependencies]
|
|
14
|
+
xrf = {version="0.1.7"}
|
|
15
|
+
minarrow = {version="0.11.0", features=["extended_numeric_types", "extended_categorical"]}
|
|
16
|
+
serde = "1.0.228"
|
|
17
|
+
postcard = {version = "1.1.3", features = ["use-std"]}
|
|
18
|
+
thiserror = "2.0.18"
|
|
19
|
+
|
|
20
|
+
[dev-dependencies]
|
|
21
|
+
rand = "0.9"
|
|
22
|
+
minarrow = {version="0.11.0", features=["extended_numeric_types", "extended_categorical", "views"]}
|
|
23
|
+
|
|
24
|
+
[profile.release]
|
|
25
|
+
codegen-units = 1
|
|
26
|
+
lto = "fat"
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
# Fru-arrow
|
|
2
|
+
|
|
3
|
+
[R version](https://cran.r-project.org/web/packages/fru/index.html) |
|
|
4
|
+
[Pyfru docs](https://kpiwonski.github.io/fru-arrow/)
|
|
5
|
+
|
|
6
|
+
Fru-arrow is a highly performant implementation of the **Random Forest** model. It uses Arrow PyCapsule underneath,
|
|
7
|
+
making integration straightforward with any library that supports it, such as ``polars``, ``pandas``, or ``pyarrow``.
|
|
8
|
+
Moreover, it features permutation importance with a novel, highly optimized algorithm.
|
|
9
|
+
It can be used for both **classification** and **regression**, as well as out-of-bag predictions.
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
use std::marker::PhantomData;
|
|
2
|
+
|
|
3
|
+
use minarrow::{
|
|
4
|
+
Array, BooleanArray, CategoricalArray, FloatArray, IntegerArray, NumericArray, TextArray,
|
|
5
|
+
};
|
|
6
|
+
use serde::{Deserialize, Serialize};
|
|
7
|
+
use xrf::{FeatureSampler, RfInput, RfRng};
|
|
8
|
+
|
|
9
|
+
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
10
|
+
pub enum DfPivot {
|
|
11
|
+
Logical,
|
|
12
|
+
Real(f64),
|
|
13
|
+
Integer(i64),
|
|
14
|
+
UInteger(u64),
|
|
15
|
+
Subset(u64),
|
|
16
|
+
}
|
|
17
|
+
|
|
18
|
+
pub struct SplittingIterator<'a, M> {
|
|
19
|
+
pair: DfSplittingPair<'a>,
|
|
20
|
+
mask_iter: M,
|
|
21
|
+
}
|
|
22
|
+
|
|
23
|
+
enum DfSplittingPair<'a> {
|
|
24
|
+
Boolean(&'a BooleanArray<()>),
|
|
25
|
+
Float32(&'a FloatArray<f32>, f64),
|
|
26
|
+
Float64(&'a FloatArray<f64>, f64),
|
|
27
|
+
Integer8(&'a IntegerArray<i8>, i64),
|
|
28
|
+
Integer16(&'a IntegerArray<i16>, i64),
|
|
29
|
+
Integer32(&'a IntegerArray<i32>, i64),
|
|
30
|
+
Integer64(&'a IntegerArray<i64>, i64),
|
|
31
|
+
UInteger8(&'a IntegerArray<u8>, u64),
|
|
32
|
+
UInteger16(&'a IntegerArray<u16>, u64),
|
|
33
|
+
UInteger32(&'a IntegerArray<u32>, u64),
|
|
34
|
+
UInteger64(&'a IntegerArray<u64>, u64),
|
|
35
|
+
Categorical8(&'a CategoricalArray<u8>, u64),
|
|
36
|
+
Categorical16(&'a CategoricalArray<u16>, u64),
|
|
37
|
+
Categorical32(&'a CategoricalArray<u32>, u64),
|
|
38
|
+
Categorical64(&'a CategoricalArray<u64>, u64),
|
|
39
|
+
}
|
|
40
|
+
impl<'a, M> SplittingIterator<'a, M> {
|
|
41
|
+
pub fn new(x: &'a Array, pivot: &DfPivot, mask_iter: M) -> Self {
|
|
42
|
+
let pair = match x {
|
|
43
|
+
Array::NumericArray(num) => match (num, pivot) {
|
|
44
|
+
(NumericArray::Float32(arr), &DfPivot::Real(xt)) => {
|
|
45
|
+
DfSplittingPair::Float32(arr, xt)
|
|
46
|
+
}
|
|
47
|
+
(NumericArray::Float64(arr), &DfPivot::Real(xt)) => {
|
|
48
|
+
DfSplittingPair::Float64(arr, xt)
|
|
49
|
+
}
|
|
50
|
+
(NumericArray::Int8(arr), &DfPivot::Integer(xt)) => {
|
|
51
|
+
DfSplittingPair::Integer8(arr, xt)
|
|
52
|
+
}
|
|
53
|
+
(NumericArray::Int16(arr), &DfPivot::Integer(xt)) => {
|
|
54
|
+
DfSplittingPair::Integer16(arr, xt)
|
|
55
|
+
}
|
|
56
|
+
(NumericArray::Int32(arr), &DfPivot::Integer(xt)) => {
|
|
57
|
+
DfSplittingPair::Integer32(arr, xt)
|
|
58
|
+
}
|
|
59
|
+
(NumericArray::Int64(arr), &DfPivot::Integer(xt)) => {
|
|
60
|
+
DfSplittingPair::Integer64(arr, xt)
|
|
61
|
+
}
|
|
62
|
+
(NumericArray::UInt8(arr), &DfPivot::UInteger(xt)) => {
|
|
63
|
+
DfSplittingPair::UInteger8(arr, xt)
|
|
64
|
+
}
|
|
65
|
+
(NumericArray::UInt16(arr), &DfPivot::UInteger(xt)) => {
|
|
66
|
+
DfSplittingPair::UInteger16(arr, xt)
|
|
67
|
+
}
|
|
68
|
+
(NumericArray::UInt32(arr), &DfPivot::UInteger(xt)) => {
|
|
69
|
+
DfSplittingPair::UInteger32(arr, xt)
|
|
70
|
+
}
|
|
71
|
+
(NumericArray::UInt64(arr), &DfPivot::UInteger(xt)) => {
|
|
72
|
+
DfSplittingPair::UInteger64(arr, xt)
|
|
73
|
+
}
|
|
74
|
+
_ => panic!("Unsupported array type!"),
|
|
75
|
+
},
|
|
76
|
+
Array::TextArray(cat) => match (cat, pivot) {
|
|
77
|
+
(TextArray::Categorical8(arr), &DfPivot::Subset(sub)) => {
|
|
78
|
+
DfSplittingPair::Categorical8(arr, sub)
|
|
79
|
+
}
|
|
80
|
+
(TextArray::Categorical16(arr), &DfPivot::Subset(sub)) => {
|
|
81
|
+
DfSplittingPair::Categorical16(arr, sub)
|
|
82
|
+
}
|
|
83
|
+
(TextArray::Categorical32(arr), &DfPivot::Subset(sub)) => {
|
|
84
|
+
DfSplittingPair::Categorical32(arr, sub)
|
|
85
|
+
}
|
|
86
|
+
(TextArray::Categorical64(arr), &DfPivot::Subset(sub)) => {
|
|
87
|
+
DfSplittingPair::Categorical64(arr, sub)
|
|
88
|
+
}
|
|
89
|
+
_ => panic!("Unsupported array type!"),
|
|
90
|
+
},
|
|
91
|
+
Array::BooleanArray(arr) => match pivot {
|
|
92
|
+
&DfPivot::Logical => DfSplittingPair::Boolean(arr),
|
|
93
|
+
_ => panic!("Unsupported array type!"),
|
|
94
|
+
},
|
|
95
|
+
_ => panic!("Unsupported array type!"),
|
|
96
|
+
};
|
|
97
|
+
|
|
98
|
+
Self { pair, mask_iter }
|
|
99
|
+
}
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
impl<'a, 'b, M> Iterator for SplittingIterator<'a, M>
|
|
103
|
+
where
|
|
104
|
+
M: Iterator<Item = &'b usize>,
|
|
105
|
+
{
|
|
106
|
+
type Item = bool;
|
|
107
|
+
fn next(&mut self) -> Option<bool> {
|
|
108
|
+
if let Some(&e) = self.mask_iter.next() {
|
|
109
|
+
let ans = match self.pair {
|
|
110
|
+
DfSplittingPair::Boolean(x) => x[e],
|
|
111
|
+
DfSplittingPair::Float32(x, xt) => x[e] as f64 > xt,
|
|
112
|
+
DfSplittingPair::Float64(x, xt) => x[e] > xt,
|
|
113
|
+
DfSplittingPair::Integer8(x, xt) => x[e] as i64 > xt,
|
|
114
|
+
DfSplittingPair::Integer16(x, xt) => x[e] as i64 > xt,
|
|
115
|
+
DfSplittingPair::Integer32(x, xt) => x[e] as i64 > xt,
|
|
116
|
+
DfSplittingPair::Integer64(x, xt) => x[e] > xt,
|
|
117
|
+
DfSplittingPair::UInteger8(x, xt) => x[e] as u64 > xt,
|
|
118
|
+
DfSplittingPair::UInteger16(x, xt) => x[e] as u64 > xt,
|
|
119
|
+
DfSplittingPair::UInteger32(x, xt) => x[e] as u64 > xt,
|
|
120
|
+
DfSplittingPair::UInteger64(x, xt) => x[e] > xt,
|
|
121
|
+
DfSplittingPair::Categorical8(x, split) => split & (1 << x[e] as u64) != 0,
|
|
122
|
+
DfSplittingPair::Categorical16(x, split) => split & (1 << x[e] as u64) != 0,
|
|
123
|
+
DfSplittingPair::Categorical32(x, split) => split & (1 << x[e] as u64) != 0,
|
|
124
|
+
DfSplittingPair::Categorical64(x, split) => split & (1 << x[e]) != 0,
|
|
125
|
+
};
|
|
126
|
+
Some(ans)
|
|
127
|
+
} else {
|
|
128
|
+
None
|
|
129
|
+
}
|
|
130
|
+
}
|
|
131
|
+
}
|
|
132
|
+
|
|
133
|
+
pub struct FYSampler<I> {
|
|
134
|
+
mixed: Vec<usize>,
|
|
135
|
+
left: usize,
|
|
136
|
+
marker: PhantomData<I>,
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
impl<I: RfInput<FeatureId = usize>> FYSampler<I> {
|
|
140
|
+
pub fn new(input: &I) -> Self {
|
|
141
|
+
Self {
|
|
142
|
+
mixed: (0..input.feature_count()).collect(),
|
|
143
|
+
left: input.feature_count(),
|
|
144
|
+
marker: PhantomData,
|
|
145
|
+
}
|
|
146
|
+
}
|
|
147
|
+
}
|
|
148
|
+
|
|
149
|
+
impl<I: RfInput<FeatureId = usize>> FeatureSampler<I> for FYSampler<I> {
|
|
150
|
+
fn random_feature(&mut self, rng: &mut RfRng) -> I::FeatureId {
|
|
151
|
+
let sel = rng.up_to(self.left);
|
|
152
|
+
let ans = self.mixed[sel];
|
|
153
|
+
self.left = self.left.checked_sub(1).unwrap();
|
|
154
|
+
self.mixed.swap(sel, self.left);
|
|
155
|
+
ans
|
|
156
|
+
}
|
|
157
|
+
fn reload(&mut self) {
|
|
158
|
+
self.left = self.mixed.len();
|
|
159
|
+
}
|
|
160
|
+
fn reset(&mut self) {
|
|
161
|
+
self.mixed = (0..self.mixed.len()).collect();
|
|
162
|
+
self.left = self.mixed.len();
|
|
163
|
+
}
|
|
164
|
+
}
|
|
165
|
+
|
|
166
|
+
macro_rules! impl_from_uint_for_dfpivot {
|
|
167
|
+
($($t:ty),* $(,)?) => {
|
|
168
|
+
$(
|
|
169
|
+
impl From<$t> for DfPivot {
|
|
170
|
+
fn from(value: $t) -> Self {
|
|
171
|
+
DfPivot::UInteger(value as u64)
|
|
172
|
+
}
|
|
173
|
+
}
|
|
174
|
+
)*
|
|
175
|
+
};
|
|
176
|
+
}
|
|
177
|
+
|
|
178
|
+
macro_rules! impl_from_int_for_dfpivot {
|
|
179
|
+
($($t:ty),* $(,)?) => {
|
|
180
|
+
$(
|
|
181
|
+
impl From<$t> for DfPivot {
|
|
182
|
+
fn from(value: $t) -> Self {
|
|
183
|
+
DfPivot::Integer(value as i64)
|
|
184
|
+
}
|
|
185
|
+
}
|
|
186
|
+
)*
|
|
187
|
+
};
|
|
188
|
+
}
|
|
189
|
+
|
|
190
|
+
impl_from_int_for_dfpivot!(i8, i16, i32, i64);
|
|
191
|
+
impl_from_uint_for_dfpivot!(u8, u16, u32, u64);
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
use crate::classification::ClsDecisionBasicType;
|
|
2
|
+
|
|
3
|
+
use super::DataFrame;
|
|
4
|
+
use minarrow::CategoricalArray;
|
|
5
|
+
use std::collections::HashMap;
|
|
6
|
+
use xrf::{AccuracyDecreaseAggregator, Mask, RfInput};
|
|
7
|
+
|
|
8
|
+
pub struct ClsDaAggregator {
|
|
9
|
+
direct: Vec<Option<ClsDecisionBasicType>>,
|
|
10
|
+
drops: HashMap<usize, isize>,
|
|
11
|
+
n: usize,
|
|
12
|
+
true_decision: CategoricalArray<ClsDecisionBasicType>,
|
|
13
|
+
}
|
|
14
|
+
impl AccuracyDecreaseAggregator<DataFrame> for ClsDaAggregator {
|
|
15
|
+
fn new(input: &DataFrame, on: &Mask, n: usize) -> Self {
|
|
16
|
+
Self {
|
|
17
|
+
direct: vec![None; n],
|
|
18
|
+
drops: HashMap::new(),
|
|
19
|
+
n: on.len(),
|
|
20
|
+
//TODO: Reference here
|
|
21
|
+
true_decision: input.decision.clone(),
|
|
22
|
+
}
|
|
23
|
+
}
|
|
24
|
+
fn ingest(&mut self, permutted: Option<usize>, mask: &Mask, vote: &ClsDecisionBasicType) {
|
|
25
|
+
if let Some(permutted) = permutted {
|
|
26
|
+
let diff: isize = mask
|
|
27
|
+
.iter()
|
|
28
|
+
.map(|&e| {
|
|
29
|
+
let oob_vote = self.direct.get(e).unwrap().unwrap();
|
|
30
|
+
if !oob_vote.eq(vote) {
|
|
31
|
+
let tru = self.true_decision[e];
|
|
32
|
+
match (tru.eq(vote), tru.eq(&oob_vote)) {
|
|
33
|
+
(true, true) => unreachable!("Logic error"),
|
|
34
|
+
(true, false) => -1,
|
|
35
|
+
(false, true) => 1,
|
|
36
|
+
|
|
37
|
+
(false, false) => 0,
|
|
38
|
+
}
|
|
39
|
+
} else {
|
|
40
|
+
0
|
|
41
|
+
}
|
|
42
|
+
})
|
|
43
|
+
.sum();
|
|
44
|
+
*self.drops.entry(permutted).or_insert(0) += diff;
|
|
45
|
+
} else {
|
|
46
|
+
for &e in mask.iter() {
|
|
47
|
+
self.direct[e] = Some(*vote);
|
|
48
|
+
}
|
|
49
|
+
}
|
|
50
|
+
}
|
|
51
|
+
fn get_direct_vote(&self, e: usize) -> ClsDecisionBasicType {
|
|
52
|
+
self.direct.get(e).unwrap().unwrap()
|
|
53
|
+
}
|
|
54
|
+
fn mda_iter(&self) -> impl Iterator<Item = (<DataFrame as RfInput>::FeatureId, f64)> {
|
|
55
|
+
self.drops
|
|
56
|
+
.iter()
|
|
57
|
+
.map(|(a, b)| (*a, (*b as f64) / (self.n as f64)))
|
|
58
|
+
}
|
|
59
|
+
}
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
use std::ops::Add;
|
|
2
|
+
|
|
3
|
+
use super::{DecisionSlice, Votes};
|
|
4
|
+
use crate::{attribute::DfPivot, classification::ClsDecisionBasicType, tools::MidpointThreshold};
|
|
5
|
+
use minarrow::BooleanArray;
|
|
6
|
+
use xrf::{Mask, RfRng, VoteAggregator};
|
|
7
|
+
|
|
8
|
+
pub fn scan_bin(x: &BooleanArray<()>, ys: &DecisionSlice, mask: &Mask) -> Option<(DfPivot, f64)> {
|
|
9
|
+
let mut left = Votes::new(ys.ncat);
|
|
10
|
+
let mut xt = 0_usize;
|
|
11
|
+
let n = mask.len();
|
|
12
|
+
mask.iter()
|
|
13
|
+
.map(|&e| x[e])
|
|
14
|
+
.zip(ys.values.iter())
|
|
15
|
+
.for_each(|(x, &y)| {
|
|
16
|
+
if x {
|
|
17
|
+
left.ingest_vote(y);
|
|
18
|
+
xt += 1;
|
|
19
|
+
}
|
|
20
|
+
});
|
|
21
|
+
let score: f64 = ys
|
|
22
|
+
.summary
|
|
23
|
+
.0
|
|
24
|
+
.iter()
|
|
25
|
+
.zip(left.0.iter())
|
|
26
|
+
.map(|(total, &left)| {
|
|
27
|
+
let for_false = (n - xt) as f64;
|
|
28
|
+
let for_true = xt as f64;
|
|
29
|
+
let n = n as f64;
|
|
30
|
+
let right = (total - left) as f64;
|
|
31
|
+
let left = left as f64;
|
|
32
|
+
(left / for_true) * (left / n) + (right / for_false) * (right / n)
|
|
33
|
+
})
|
|
34
|
+
.sum();
|
|
35
|
+
Some((DfPivot::Logical, score))
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
pub fn scan_categorical<T: Copy + Ord + TryInto<usize> + Into<DfPivot> + MidpointThreshold>(
|
|
39
|
+
x: &[T],
|
|
40
|
+
xc: usize,
|
|
41
|
+
ys: &DecisionSlice,
|
|
42
|
+
mask: &Mask,
|
|
43
|
+
_rng: &mut RfRng,
|
|
44
|
+
) -> Option<(DfPivot, f64)> {
|
|
45
|
+
if xc > 10 {
|
|
46
|
+
//When there is too many combinations, just treat it as ordered
|
|
47
|
+
return scan_integer(x, ys, mask);
|
|
48
|
+
}
|
|
49
|
+
if xc < 2 {
|
|
50
|
+
return None;
|
|
51
|
+
}
|
|
52
|
+
let n = mask.len();
|
|
53
|
+
let mut va: Vec<Votes> = std::iter::repeat_with(|| Votes::new(ys.ncat))
|
|
54
|
+
.take(xc)
|
|
55
|
+
.collect();
|
|
56
|
+
mask.iter()
|
|
57
|
+
.map(|&e| x[e])
|
|
58
|
+
.zip(ys.values.iter())
|
|
59
|
+
.for_each(|(x, &y)| va[x.try_into().ok().unwrap()].ingest_vote(y));
|
|
60
|
+
let sub_max: u64 = (1 << (xc - 1)) - 1;
|
|
61
|
+
|
|
62
|
+
(0..sub_max)
|
|
63
|
+
.map(|bitmask_id| bitmask_id + (1 << (xc - 1)))
|
|
64
|
+
.fold(None, |acc: Option<(u64, f64)>, bitmask| {
|
|
65
|
+
let left = va
|
|
66
|
+
.iter()
|
|
67
|
+
.enumerate()
|
|
68
|
+
.filter(|(e, _)| bitmask & (1 << e) != 0)
|
|
69
|
+
.fold(Votes::new(ys.ncat), |mut acc, (_, v)| {
|
|
70
|
+
acc.merge(v);
|
|
71
|
+
acc
|
|
72
|
+
});
|
|
73
|
+
let in_left: usize = left.0.iter().sum();
|
|
74
|
+
|
|
75
|
+
let score: f64 = ys
|
|
76
|
+
.summary
|
|
77
|
+
.0
|
|
78
|
+
.iter()
|
|
79
|
+
.zip(left.0.iter())
|
|
80
|
+
.map(|(&all, &left)| {
|
|
81
|
+
let ahead = (n - in_left) as f64;
|
|
82
|
+
let scanned = in_left as f64;
|
|
83
|
+
let n = n as f64;
|
|
84
|
+
let right = (all - left) as f64;
|
|
85
|
+
let left = left as f64;
|
|
86
|
+
(left / scanned) * (left / n) + (right / ahead) * (right / n)
|
|
87
|
+
})
|
|
88
|
+
.sum();
|
|
89
|
+
if score > acc.map(|x| x.1).unwrap_or(f64::NEG_INFINITY) {
|
|
90
|
+
return Some((bitmask, score));
|
|
91
|
+
}
|
|
92
|
+
acc
|
|
93
|
+
})
|
|
94
|
+
.map(|(bitmask, score)| (DfPivot::Subset(bitmask), score))
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
pub fn scan_float<T: Copy + PartialOrd + Add<T, Output = T> + Into<f64>>(
|
|
98
|
+
x: &[T],
|
|
99
|
+
ys: &DecisionSlice,
|
|
100
|
+
mask: &Mask,
|
|
101
|
+
) -> Option<(DfPivot, f64)> {
|
|
102
|
+
let mut bound: Vec<(T, ClsDecisionBasicType)> = mask
|
|
103
|
+
.iter()
|
|
104
|
+
.zip(ys.values.iter())
|
|
105
|
+
.map(|(&xe, &y)| (x[xe], y))
|
|
106
|
+
.collect();
|
|
107
|
+
bound.sort_unstable_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
|
|
108
|
+
|
|
109
|
+
let n = bound.len();
|
|
110
|
+
let mut left = Votes::new(ys.ncat);
|
|
111
|
+
let mut scanned = 0_usize;
|
|
112
|
+
|
|
113
|
+
bound
|
|
114
|
+
.windows(2)
|
|
115
|
+
.map(|x| (x[0].0, x[1].0, x[0].1))
|
|
116
|
+
.fold(None, |acc: Option<(f64, f64)>, (x, next_x, y)| {
|
|
117
|
+
scanned += 1;
|
|
118
|
+
left.ingest_vote(y);
|
|
119
|
+
if x.partial_cmp(&next_x).unwrap().is_ne() {
|
|
120
|
+
let score: f64 = ys
|
|
121
|
+
.summary
|
|
122
|
+
.0
|
|
123
|
+
.iter()
|
|
124
|
+
.zip(left.0.iter())
|
|
125
|
+
.map(|(&all, &left)| {
|
|
126
|
+
let ahead = (n - scanned) as f64;
|
|
127
|
+
let scanned = scanned as f64;
|
|
128
|
+
let n = n as f64;
|
|
129
|
+
let right = (all - left) as f64;
|
|
130
|
+
let left = left as f64;
|
|
131
|
+
(left / scanned) * (left / n) + (right / ahead) * (right / n)
|
|
132
|
+
})
|
|
133
|
+
.sum();
|
|
134
|
+
if score > acc.map(|x| x.1).unwrap_or(f64::NEG_INFINITY) {
|
|
135
|
+
return Some((Into::<f64>::into(x + next_x) * 0.5, score));
|
|
136
|
+
}
|
|
137
|
+
}
|
|
138
|
+
acc
|
|
139
|
+
})
|
|
140
|
+
.map(|(thresh, score)| (DfPivot::Real(thresh), score))
|
|
141
|
+
}
|
|
142
|
+
|
|
143
|
+
pub fn scan_integer<T: Copy + Ord + Into<DfPivot> + MidpointThreshold>(
|
|
144
|
+
x: &[T],
|
|
145
|
+
ys: &DecisionSlice,
|
|
146
|
+
mask: &Mask,
|
|
147
|
+
) -> Option<(DfPivot, f64)> {
|
|
148
|
+
let mut bound: Vec<(T, ClsDecisionBasicType)> = mask
|
|
149
|
+
.iter()
|
|
150
|
+
.zip(ys.values.iter())
|
|
151
|
+
.map(|(&xe, &y)| (x[xe], y))
|
|
152
|
+
.collect();
|
|
153
|
+
bound.sort_unstable_by_key(|a| a.0);
|
|
154
|
+
|
|
155
|
+
let n = bound.len();
|
|
156
|
+
let mut left = Votes::new(ys.ncat);
|
|
157
|
+
let mut scanned = 0_usize;
|
|
158
|
+
|
|
159
|
+
bound
|
|
160
|
+
.windows(2)
|
|
161
|
+
.map(|x| (x[0].0, x[1].0, x[0].1))
|
|
162
|
+
.fold(None, |acc: Option<(T, f64)>, (x, next_x, y)| {
|
|
163
|
+
scanned += 1;
|
|
164
|
+
left.ingest_vote(y);
|
|
165
|
+
if x.cmp(&next_x).is_ne() {
|
|
166
|
+
let score: f64 = ys
|
|
167
|
+
.summary
|
|
168
|
+
.0
|
|
169
|
+
.iter()
|
|
170
|
+
.zip(left.0.iter())
|
|
171
|
+
.map(|(&all, &left)| {
|
|
172
|
+
let ahead = (n - scanned) as f64;
|
|
173
|
+
let scanned = scanned as f64;
|
|
174
|
+
let n = n as f64;
|
|
175
|
+
let right = (all - left) as f64;
|
|
176
|
+
let left = left as f64;
|
|
177
|
+
(left / scanned) * (left / n) + (right / ahead) * (right / n)
|
|
178
|
+
})
|
|
179
|
+
.sum();
|
|
180
|
+
if score > acc.map(|x| x.1).unwrap_or(f64::NEG_INFINITY) {
|
|
181
|
+
return Some((x.midpoint_threshold(next_x), score));
|
|
182
|
+
}
|
|
183
|
+
}
|
|
184
|
+
acc
|
|
185
|
+
})
|
|
186
|
+
.map(|(thresh, score)| (thresh.into(), score))
|
|
187
|
+
}
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
use crate::classification::ClsDecisionBasicType;
|
|
2
|
+
|
|
3
|
+
use super::DataFrame;
|
|
4
|
+
use serde::{Deserialize, Serialize};
|
|
5
|
+
use xrf::{FairBest, RfRng, VoteAggregator};
|
|
6
|
+
|
|
7
|
+
#[derive(Clone, Serialize, Deserialize)]
|
|
8
|
+
pub struct Votes(pub Vec<usize>); //TODO: Fix impurity to make it private
|
|
9
|
+
|
|
10
|
+
impl Votes {
|
|
11
|
+
pub fn is_pure(&self) -> bool {
|
|
12
|
+
self.0.iter().filter(|&&x| x > 0).count() <= 1
|
|
13
|
+
}
|
|
14
|
+
pub fn new(ncat: ClsDecisionBasicType) -> Self {
|
|
15
|
+
Self(std::iter::repeat_n(0, ncat as usize).collect())
|
|
16
|
+
}
|
|
17
|
+
|
|
18
|
+
pub fn collapse_empty_random(&self, rng: &mut RfRng) -> ClsDecisionBasicType {
|
|
19
|
+
self.0
|
|
20
|
+
.iter()
|
|
21
|
+
.enumerate()
|
|
22
|
+
.fold(FairBest::new(), |mut best, (cls, count)| {
|
|
23
|
+
best.ingest(count, cls, rng);
|
|
24
|
+
best
|
|
25
|
+
})
|
|
26
|
+
.consume()
|
|
27
|
+
.map(|(_score, class)| class as ClsDecisionBasicType)
|
|
28
|
+
.unwrap()
|
|
29
|
+
}
|
|
30
|
+
}
|
|
31
|
+
|
|
32
|
+
impl VoteAggregator<DataFrame> for Votes {
|
|
33
|
+
fn new(input: &DataFrame) -> Self {
|
|
34
|
+
Votes::new(input.ncat)
|
|
35
|
+
}
|
|
36
|
+
fn ingest_vote(&mut self, v: ClsDecisionBasicType) {
|
|
37
|
+
self.0[v as usize] += 1;
|
|
38
|
+
}
|
|
39
|
+
fn merge(&mut self, other: &Self) {
|
|
40
|
+
self.0
|
|
41
|
+
.iter_mut()
|
|
42
|
+
.zip(other.0.iter())
|
|
43
|
+
.for_each(|(t, s)| *t += s);
|
|
44
|
+
}
|
|
45
|
+
}
|