gouda-cheese 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,184 @@
1
+ # This file is autogenerated by maturin v1.12.6
2
+ # To update, run
3
+ #
4
+ # maturin generate-ci github
5
+ #
6
+ name: CI
7
+
8
+ on:
9
+ push:
10
+ branches:
11
+ - main
12
+ - master
13
+ tags:
14
+ - '*'
15
+ pull_request:
16
+ workflow_dispatch:
17
+
18
+ permissions:
19
+ contents: read
20
+
21
+ jobs:
22
+ linux:
23
+ runs-on: ${{ matrix.platform.runner }}
24
+ strategy:
25
+ matrix:
26
+ platform:
27
+ - runner: ubuntu-22.04
28
+ target: x86_64
29
+ - runner: ubuntu-22.04
30
+ target: x86
31
+ - runner: ubuntu-22.04
32
+ target: aarch64
33
+ - runner: ubuntu-22.04
34
+ target: armv7
35
+ - runner: ubuntu-22.04
36
+ target: s390x
37
+ - runner: ubuntu-22.04
38
+ target: ppc64le
39
+ steps:
40
+ - uses: actions/checkout@v6
41
+ - uses: actions/setup-python@v6
42
+ with:
43
+ python-version: 3.x
44
+ - name: Build wheels
45
+ uses: PyO3/maturin-action@v1
46
+ with:
47
+ target: ${{ matrix.platform.target }}
48
+ args: --release --out dist --find-interpreter
49
+ sccache: ${{ !startsWith(github.ref, 'refs/tags/') }}
50
+ manylinux: auto
51
+ - name: Upload wheels
52
+ uses: actions/upload-artifact@v6
53
+ with:
54
+ name: wheels-linux-${{ matrix.platform.target }}
55
+ path: dist
56
+
57
+ musllinux:
58
+ runs-on: ${{ matrix.platform.runner }}
59
+ strategy:
60
+ matrix:
61
+ platform:
62
+ - runner: ubuntu-22.04
63
+ target: x86_64
64
+ - runner: ubuntu-22.04
65
+ target: x86
66
+ - runner: ubuntu-22.04
67
+ target: aarch64
68
+ - runner: ubuntu-22.04
69
+ target: armv7
70
+ steps:
71
+ - uses: actions/checkout@v6
72
+ - uses: actions/setup-python@v6
73
+ with:
74
+ python-version: 3.x
75
+ - name: Build wheels
76
+ uses: PyO3/maturin-action@v1
77
+ with:
78
+ target: ${{ matrix.platform.target }}
79
+ args: --release --out dist --find-interpreter
80
+ sccache: ${{ !startsWith(github.ref, 'refs/tags/') }}
81
+ manylinux: musllinux_1_2
82
+ - name: Upload wheels
83
+ uses: actions/upload-artifact@v6
84
+ with:
85
+ name: wheels-musllinux-${{ matrix.platform.target }}
86
+ path: dist
87
+
88
+ windows:
89
+ runs-on: ${{ matrix.platform.runner }}
90
+ strategy:
91
+ matrix:
92
+ platform:
93
+ - runner: windows-latest
94
+ target: x64
95
+ python_arch: x64
96
+ - runner: windows-latest
97
+ target: x86
98
+ python_arch: x86
99
+ - runner: windows-11-arm
100
+ target: aarch64
101
+ python_arch: arm64
102
+ steps:
103
+ - uses: actions/checkout@v6
104
+ - uses: actions/setup-python@v6
105
+ with:
106
+ python-version: 3.13
107
+ architecture: ${{ matrix.platform.python_arch }}
108
+ - name: Build wheels
109
+ uses: PyO3/maturin-action@v1
110
+ with:
111
+ target: ${{ matrix.platform.target }}
112
+ args: --release --out dist --find-interpreter
113
+ sccache: ${{ !startsWith(github.ref, 'refs/tags/') }}
114
+ - name: Upload wheels
115
+ uses: actions/upload-artifact@v6
116
+ with:
117
+ name: wheels-windows-${{ matrix.platform.target }}
118
+ path: dist
119
+
120
+ macos:
121
+ runs-on: ${{ matrix.platform.runner }}
122
+ strategy:
123
+ matrix:
124
+ platform:
125
+ - runner: macos-15-intel
126
+ target: x86_64
127
+ - runner: macos-latest
128
+ target: aarch64
129
+ steps:
130
+ - uses: actions/checkout@v6
131
+ - uses: actions/setup-python@v6
132
+ with:
133
+ python-version: 3.x
134
+ - name: Build wheels
135
+ uses: PyO3/maturin-action@v1
136
+ with:
137
+ target: ${{ matrix.platform.target }}
138
+ args: --release --out dist --find-interpreter
139
+ sccache: ${{ !startsWith(github.ref, 'refs/tags/') }}
140
+ - name: Upload wheels
141
+ uses: actions/upload-artifact@v6
142
+ with:
143
+ name: wheels-macos-${{ matrix.platform.target }}
144
+ path: dist
145
+
146
+ sdist:
147
+ runs-on: ubuntu-latest
148
+ steps:
149
+ - uses: actions/checkout@v6
150
+ - name: Build sdist
151
+ uses: PyO3/maturin-action@v1
152
+ with:
153
+ command: sdist
154
+ args: --out dist
155
+ - name: Upload sdist
156
+ uses: actions/upload-artifact@v6
157
+ with:
158
+ name: wheels-sdist
159
+ path: dist
160
+
161
+ release:
162
+ name: Release
163
+ runs-on: ubuntu-latest
164
+ if: ${{ startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' }}
165
+ needs: [linux, musllinux, windows, macos, sdist]
166
+ permissions:
167
+ # Use to sign the release artifacts
168
+ id-token: write
169
+ # Used to upload release artifacts
170
+ contents: write
171
+ # Used to generate artifact attestation
172
+ attestations: write
173
+ steps:
174
+ - uses: actions/download-artifact@v7
175
+ - name: Generate artifact attestation
176
+ uses: actions/attest-build-provenance@v3
177
+ with:
178
+ subject-path: 'wheels-*/*'
179
+ - name: Install uv
180
+ if: ${{ startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' }}
181
+ uses: astral-sh/setup-uv@v7
182
+ - name: Publish to PyPI
183
+ if: ${{ startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' }}
184
+ run: uv publish 'wheels-*/*'
@@ -0,0 +1,230 @@
1
+ # This file is automatically @generated by Cargo.
2
+ # It is not intended for manual editing.
3
+ version = 4
4
+
5
+ [[package]]
6
+ name = "autocfg"
7
+ version = "1.5.0"
8
+ source = "registry+https://github.com/rust-lang/crates.io-index"
9
+ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
10
+
11
+ [[package]]
12
+ name = "gouda-cheese"
13
+ version = "0.1.0"
14
+ dependencies = [
15
+ "ndarray",
16
+ "numpy",
17
+ "pyo3",
18
+ ]
19
+
20
+ [[package]]
21
+ name = "heck"
22
+ version = "0.5.0"
23
+ source = "registry+https://github.com/rust-lang/crates.io-index"
24
+ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
25
+
26
+ [[package]]
27
+ name = "libc"
28
+ version = "0.2.184"
29
+ source = "registry+https://github.com/rust-lang/crates.io-index"
30
+ checksum = "48f5d2a454e16a5ea0f4ced81bd44e4cfc7bd3a507b61887c99fd3538b28e4af"
31
+
32
+ [[package]]
33
+ name = "matrixmultiply"
34
+ version = "0.3.10"
35
+ source = "registry+https://github.com/rust-lang/crates.io-index"
36
+ checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08"
37
+ dependencies = [
38
+ "autocfg",
39
+ "rawpointer",
40
+ ]
41
+
42
+ [[package]]
43
+ name = "ndarray"
44
+ version = "0.17.2"
45
+ source = "registry+https://github.com/rust-lang/crates.io-index"
46
+ checksum = "520080814a7a6b4a6e9070823bb24b4531daac8c4627e08ba5de8c5ef2f2752d"
47
+ dependencies = [
48
+ "matrixmultiply",
49
+ "num-complex",
50
+ "num-integer",
51
+ "num-traits",
52
+ "portable-atomic",
53
+ "portable-atomic-util",
54
+ "rawpointer",
55
+ ]
56
+
57
+ [[package]]
58
+ name = "num-complex"
59
+ version = "0.4.6"
60
+ source = "registry+https://github.com/rust-lang/crates.io-index"
61
+ checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
62
+ dependencies = [
63
+ "num-traits",
64
+ ]
65
+
66
+ [[package]]
67
+ name = "num-integer"
68
+ version = "0.1.46"
69
+ source = "registry+https://github.com/rust-lang/crates.io-index"
70
+ checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
71
+ dependencies = [
72
+ "num-traits",
73
+ ]
74
+
75
+ [[package]]
76
+ name = "num-traits"
77
+ version = "0.2.19"
78
+ source = "registry+https://github.com/rust-lang/crates.io-index"
79
+ checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
80
+ dependencies = [
81
+ "autocfg",
82
+ ]
83
+
84
+ [[package]]
85
+ name = "numpy"
86
+ version = "0.28.0"
87
+ source = "registry+https://github.com/rust-lang/crates.io-index"
88
+ checksum = "778da78c64ddc928ebf5ad9df5edf0789410ff3bdbf3619aed51cd789a6af1e2"
89
+ dependencies = [
90
+ "libc",
91
+ "ndarray",
92
+ "num-complex",
93
+ "num-integer",
94
+ "num-traits",
95
+ "pyo3",
96
+ "pyo3-build-config",
97
+ "rustc-hash",
98
+ ]
99
+
100
+ [[package]]
101
+ name = "once_cell"
102
+ version = "1.21.4"
103
+ source = "registry+https://github.com/rust-lang/crates.io-index"
104
+ checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50"
105
+
106
+ [[package]]
107
+ name = "portable-atomic"
108
+ version = "1.13.1"
109
+ source = "registry+https://github.com/rust-lang/crates.io-index"
110
+ checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49"
111
+
112
+ [[package]]
113
+ name = "portable-atomic-util"
114
+ version = "0.2.6"
115
+ source = "registry+https://github.com/rust-lang/crates.io-index"
116
+ checksum = "091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3"
117
+ dependencies = [
118
+ "portable-atomic",
119
+ ]
120
+
121
+ [[package]]
122
+ name = "proc-macro2"
123
+ version = "1.0.106"
124
+ source = "registry+https://github.com/rust-lang/crates.io-index"
125
+ checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934"
126
+ dependencies = [
127
+ "unicode-ident",
128
+ ]
129
+
130
+ [[package]]
131
+ name = "pyo3"
132
+ version = "0.28.3"
133
+ source = "registry+https://github.com/rust-lang/crates.io-index"
134
+ checksum = "91fd8e38a3b50ed1167fb981cd6fd60147e091784c427b8f7183a7ee32c31c12"
135
+ dependencies = [
136
+ "libc",
137
+ "once_cell",
138
+ "portable-atomic",
139
+ "pyo3-build-config",
140
+ "pyo3-ffi",
141
+ "pyo3-macros",
142
+ ]
143
+
144
+ [[package]]
145
+ name = "pyo3-build-config"
146
+ version = "0.28.3"
147
+ source = "registry+https://github.com/rust-lang/crates.io-index"
148
+ checksum = "e368e7ddfdeb98c9bca7f8383be1648fd84ab466bf2bc015e94008db6d35611e"
149
+ dependencies = [
150
+ "target-lexicon",
151
+ ]
152
+
153
+ [[package]]
154
+ name = "pyo3-ffi"
155
+ version = "0.28.3"
156
+ source = "registry+https://github.com/rust-lang/crates.io-index"
157
+ checksum = "7f29e10af80b1f7ccaf7f69eace800a03ecd13e883acfacc1e5d0988605f651e"
158
+ dependencies = [
159
+ "libc",
160
+ "pyo3-build-config",
161
+ ]
162
+
163
+ [[package]]
164
+ name = "pyo3-macros"
165
+ version = "0.28.3"
166
+ source = "registry+https://github.com/rust-lang/crates.io-index"
167
+ checksum = "df6e520eff47c45997d2fc7dd8214b25dd1310918bbb2642156ef66a67f29813"
168
+ dependencies = [
169
+ "proc-macro2",
170
+ "pyo3-macros-backend",
171
+ "quote",
172
+ "syn",
173
+ ]
174
+
175
+ [[package]]
176
+ name = "pyo3-macros-backend"
177
+ version = "0.28.3"
178
+ source = "registry+https://github.com/rust-lang/crates.io-index"
179
+ checksum = "c4cdc218d835738f81c2338f822078af45b4afdf8b2e33cbb5916f108b813acb"
180
+ dependencies = [
181
+ "heck",
182
+ "proc-macro2",
183
+ "pyo3-build-config",
184
+ "quote",
185
+ "syn",
186
+ ]
187
+
188
+ [[package]]
189
+ name = "quote"
190
+ version = "1.0.45"
191
+ source = "registry+https://github.com/rust-lang/crates.io-index"
192
+ checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924"
193
+ dependencies = [
194
+ "proc-macro2",
195
+ ]
196
+
197
+ [[package]]
198
+ name = "rawpointer"
199
+ version = "0.2.1"
200
+ source = "registry+https://github.com/rust-lang/crates.io-index"
201
+ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
202
+
203
+ [[package]]
204
+ name = "rustc-hash"
205
+ version = "2.1.2"
206
+ source = "registry+https://github.com/rust-lang/crates.io-index"
207
+ checksum = "94300abf3f1ae2e2b8ffb7b58043de3d399c73fa6f4b73826402a5c457614dbe"
208
+
209
+ [[package]]
210
+ name = "syn"
211
+ version = "2.0.117"
212
+ source = "registry+https://github.com/rust-lang/crates.io-index"
213
+ checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99"
214
+ dependencies = [
215
+ "proc-macro2",
216
+ "quote",
217
+ "unicode-ident",
218
+ ]
219
+
220
+ [[package]]
221
+ name = "target-lexicon"
222
+ version = "0.13.5"
223
+ source = "registry+https://github.com/rust-lang/crates.io-index"
224
+ checksum = "adb6935a6f5c20170eeceb1a3835a49e12e19d792f6dd344ccc76a985ca5a6ca"
225
+
226
+ [[package]]
227
+ name = "unicode-ident"
228
+ version = "1.0.24"
229
+ source = "registry+https://github.com/rust-lang/crates.io-index"
230
+ checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75"
@@ -0,0 +1,14 @@
1
+ [package]
2
+ name = "gouda-cheese"
3
+ version = "0.1.0"
4
+ edition = "2024"
5
+
6
+ # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
7
+ [lib]
8
+ name = "gouda"
9
+ crate-type = ["cdylib"]
10
+
11
+ [dependencies]
12
+ ndarray = "0.17.2"
13
+ numpy = "0.28.0"
14
+ pyo3 = { version = "0.28", features = ["extension-module"] }
@@ -0,0 +1,7 @@
1
+ Metadata-Version: 2.4
2
+ Name: gouda-cheese
3
+ Version: 0.1.0
4
+ Classifier: Programming Language :: Rust
5
+ Classifier: Programming Language :: Python :: Implementation :: CPython
6
+ Classifier: Programming Language :: Python :: Implementation :: PyPy
7
+ Requires-Python: >=3.8
@@ -0,0 +1,11 @@
1
+ name: gouda
2
+ channels:
3
+ - defaults
4
+ dependencies:
5
+ - python=3.13
6
+ - numpy
7
+ - pandas
8
+ - scikit-learn
9
+ - maturin
10
+ - pytest
11
+ prefix: /home/tim/nas/conda_envs/gouda
@@ -0,0 +1,13 @@
1
+ [build-system]
2
+ requires = ["maturin>=1.12,<2.0"]
3
+ build-backend = "maturin"
4
+
5
+ [project]
6
+ name = "gouda-cheese"
7
+ requires-python = ">=3.8"
8
+ classifiers = [
9
+ "Programming Language :: Rust",
10
+ "Programming Language :: Python :: Implementation :: CPython",
11
+ "Programming Language :: Python :: Implementation :: PyPy",
12
+ ]
13
+ dynamic = ["version"]
@@ -0,0 +1,4 @@
1
+ # Imputation Library
2
+
3
+ Commonly available imputation tools lack support for categorical values, and sometimes do not work at all.
4
+ This project aims to provide robust imputation methods and benchmarks, with a significant performance boost over other available tools.
@@ -0,0 +1,67 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import pandas as pd
4
+ from dataclasses import dataclass, field
5
+
6
+
7
class Generator(nn.Module):
    """Imputation network of the GAIN pair (placeholder implementation).

    forward(x, random_value, missing_mask):
        x: tabular data as a tensor
        random_value: randomly generated values
        missing_mask: binary mask of the missing entries
    Returns the imputed values; the current stub hands ``x`` back unchanged.
    """

    # TODO: define the imputation model (maybe something simple?).
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, random_value: torch.Tensor, missing_mask: torch.Tensor):
        # Stub: random_value and missing_mask are accepted but not used yet.
        return x
26
+
27
+
28
class Discriminator(nn.Module):
    """Judging network of the GAIN pair (placeholder implementation).

    forward(x):
        x: tabular data as a tensor
    Returns a boolean mask meant to say whether each value is generated or real;
    the current stub simply marks the strictly positive entries of ``x``.
    """

    # TODO: define the model that judges real vs. generated values.
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor):
        return torch.gt(x, 0)
43
+
44
+
45
@dataclass
class Gain:
    """GAIN-style imputer wrapping a ``Generator`` and a ``Discriminator`` (work in progress).

    seed: seed of the torch random generator that draws the random input values.
    """

    seed: int = 42
    rng: torch.Generator = field(init=False)
    generator: Generator = field(init=False)
    discriminator: Discriminator = field(init=False)

    def __post_init__(self):
        self.rng = torch.Generator().manual_seed(self.seed)
        # These fields are init=False without a default, so they only exist once
        # assigned here; without this `transform` fails with AttributeError.
        self.generator = Generator()
        self.discriminator = Discriminator()

    def fit(self, x: pd.DataFrame) -> "Gain":
        """Fit the imputer on ``x`` and return ``self`` so calls can be chained.

        Training is not implemented yet, so ``x`` is currently unused.
        """
        # TODO: train self.generator / self.discriminator on x.
        return self

    def transform(self, x: pd.DataFrame) -> torch.Tensor:
        """Return the generator's imputation of ``x``.

        The generator receives ``x`` with NaNs replaced by 0, uniform random values
        of the same shape, and the boolean missing mask (``True`` where ``x`` is NaN).
        """
        nrows, ncols = x.shape
        missing_mask = torch.tensor(x.isna().to_numpy())
        # Same float dtype for data and noise so the generator can combine them.
        values = torch.tensor(x.fillna(0).to_numpy(dtype="float64"))
        random_values = torch.rand(nrows, ncols, generator=self.rng, dtype=values.dtype)
        return self.generator(values, random_values, missing_mask)
@@ -0,0 +1,307 @@
1
+ use crate::utils::pyany_to_vec;
2
+ use core::f64;
3
+ use numpy::{IntoPyArray, PyArray2};
4
+ use pyo3::prelude::*;
5
+ use pyo3::types::PyAny;
6
+ use std::ops::Index;
7
+
8
/// k-nearest-neighbour imputer exposed to Python.
///
/// `fit` stores a copy of the data as a flat row-major buffer (`nrows` x `ncols`);
/// `transform` imputes NaN entries using the k nearest stored rows.
#[pyclass]
pub struct KnnImputer {
    /// Number of neighbours used per missing value (readable and writable from Python).
    #[pyo3(get, set)]
    k: usize,
    /// Number of rows of the fitted data.
    nrows: usize,
    /// Number of columns of the fitted data (row stride of `data`).
    ncols: usize,
    /// Fitted data, row-major; `None` until `fit` is called.
    data: Option<Vec<f64>>,
    /// Set by `fit`; `transform` refuses to run while this is false.
    is_fitted: bool,
    /// Distance metric name, resolved when `transform` runs.
    metric: String,
    /// Donor weighting scheme; validated against `ALLOWED_WEIGHTS` in `new`.
    weights: String,
}

// Weighting schemes accepted by `KnnImputer::new`.
const ALLOWED_WEIGHTS: [&str; 2] = ["uniform", "distance"];
21
+
22
+ #[pymethods]
23
+ impl KnnImputer {
24
+ #[new]
25
+ #[pyo3(signature = (k=5, metric="nan_euclid", weights="uniform"))]
26
+ pub fn new(k: usize, metric: &str, weights: &str) -> KnnImputer {
27
+ assert!(ALLOWED_WEIGHTS.contains(&weights));
28
+ KnnImputer {
29
+ k,
30
+ nrows: 0,
31
+ ncols: 0,
32
+ data: None,
33
+ is_fitted: false,
34
+ metric: metric.to_owned(),
35
+ weights: weights.to_owned(),
36
+ }
37
+ }
38
+
39
+ pub fn fit(slf: Py<Self>, py: Python<'_>, data: &Bound<'_, PyAny>) -> PyResult<Py<Self>> {
40
+ let (vec, nrows, ncols) = pyany_to_vec(py, data)?;
41
+ {
42
+ let mut inner = slf.borrow_mut(py);
43
+ inner.data = Some(vec);
44
+ inner.nrows = nrows;
45
+ inner.ncols = ncols;
46
+ inner.is_fitted = true;
47
+ } // dropping inner here (releasing the mutex)
48
+ Ok(slf)
49
+ }
50
+
51
+ pub fn transform<'py>(
52
+ &self,
53
+ py: Python<'py>,
54
+ data: &Bound<'_, PyAny>,
55
+ ) -> PyResult<Bound<'py, PyArray2<f64>>> {
56
+ // check if fitted
57
+ if !self.is_fitted {
58
+ return Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(format!(
59
+ "Imputer is not fitted",
60
+ )));
61
+ }
62
+ let (data, nrows, _) = pyany_to_vec(py, data)?;
63
+ // actual method
64
+ let dist = match self.metric.as_str() {
65
+ "nan_euclid" => Self::nan_euclid,
66
+ "expected_distance" => Self::expected_distance,
67
+ m => {
68
+ return Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(format!(
69
+ "{} is a unknown metric",
70
+ m
71
+ )));
72
+ }
73
+ };
74
+ let imputed = self.brute_force(&data, nrows, dist);
75
+ // return python object
76
+ let array = ndarray::Array2::from_shape_vec((self.nrows, self.ncols), imputed)
77
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
78
+ Ok(array.into_pyarray(py))
79
+ }
80
+ }
81
+
82
+ impl KnnImputer {
83
+ fn brute_force(
84
+ &self,
85
+ data: &[f64],
86
+ nrows: usize,
87
+ dist: fn(&KnnImputer, &[f64], &[f64]) -> f64,
88
+ ) -> Vec<f64> {
89
+ let mut imputed = data.to_vec();
90
+ let mut i = 0;
91
+ while i < data.len() {
92
+ if data[i].is_nan() {
93
+ // figure out point and impute
94
+ let row = i / self.ncols;
95
+ let col = i % self.ncols;
96
+ let mut cols = Vec::with_capacity(20);
97
+ for j in col..self.ncols {
98
+ let l = i + j - col;
99
+ if l >= data.len() {
100
+ break;
101
+ }
102
+ if data[l].is_nan() {
103
+ cols.push(j);
104
+ }
105
+ }
106
+ let mut distances = Vec::with_capacity(nrows);
107
+ let p = &self[row];
108
+ for r in 0..self.nrows {
109
+ if r == row {
110
+ distances.push(f64::MAX); // dont consider distance to self
111
+ } else {
112
+ distances.push(dist(&self, p, &self[r]));
113
+ }
114
+ }
115
+ let mut indices: Vec<usize> = (0..nrows).collect();
116
+ // indices.sort_by(|&a, &b| distances[a].total_cmp(&distances[b]));
117
+ indices.sort_unstable_by(|&a, &b| distances[a].total_cmp(&distances[b]));
118
+ let avgs = self.average(&indices, &cols, &self.get_weights(&distances));
119
+ for (avg, c) in avgs.into_iter().zip(&cols) {
120
+ imputed[i + c - col] = avg;
121
+ }
122
+ i += self.ncols - col;
123
+ } else {
124
+ i += 1;
125
+ }
126
+ }
127
+ imputed
128
+ }
129
+
130
+ fn average(&self, indices: &[usize], cols: &[usize], weights: &[f64]) -> Vec<f64> {
131
+ let mut avg: Vec<f64> = vec![0.0; cols.len()];
132
+ for (j, c) in cols.iter().enumerate() {
133
+ let mut count = 0;
134
+ for i in indices {
135
+ let val = self[*i][*c];
136
+ if val.is_nan() {
137
+ continue;
138
+ }
139
+ avg[j] += val * weights[*i];
140
+ count += 1;
141
+ if count >= self.k {
142
+ break;
143
+ }
144
+ }
145
+ // avg[j] *= 1.0 / self.k as f64;
146
+ }
147
+ avg
148
+ }
149
+
150
+ fn get_weights(&self, distances: &[f64]) -> Vec<f64> {
151
+ match self.weights.as_str() {
152
+ "uniform" => distances.iter().map(|_| 1.0 / self.k as f64).collect(),
153
+ "distances" => distances.iter().map(|d| 1.0 / d).collect(),
154
+ w => panic!("Unknown weight {}", w),
155
+ }
156
+ }
157
+ }
158
+
159
+ // Distance Functions
160
+ impl KnnImputer {
161
+ // TODO:
162
+ // write tests
163
+ fn nan_euclid(&self, a: &[f64], b: &[f64]) -> f64 {
164
+ let mut total = 0.0;
165
+ let mut valid = 0;
166
+ let ncols = a.len();
167
+ for i in 0..ncols {
168
+ let (x, y) = (a[i], b[i]);
169
+ if !(x.is_nan() || y.is_nan()) {
170
+ let d = x - y;
171
+ total += d * d;
172
+ valid += 1;
173
+ }
174
+ }
175
+ if valid == 0 {
176
+ return f64::INFINITY;
177
+ }
178
+ total * (ncols as f64 / valid as f64)
179
+ }
180
+
181
+ fn expected_distance(&self, a: &[f64], b: &[f64]) -> f64 {
182
+ let mut total = 0.0;
183
+ let mut total_obs = 0.0;
184
+ let ncols = a.len();
185
+ for i in 0..ncols {
186
+ let (x, y) = (a[i], b[i]);
187
+ match (x.is_nan(), y.is_nan()) {
188
+ (true, true) => total += 0.333,
189
+ (true, false) => total += y.max(1.0 - y),
190
+ (false, true) => total += x.max(1.0 - x),
191
+ (false, false) => {
192
+ let d = x - y;
193
+ total_obs += d * d
194
+ }
195
+ }
196
+ }
197
+ total + total_obs.sqrt()
198
+ }
199
+ }
200
+
201
+ impl Index<usize> for KnnImputer {
202
+ type Output = [f64];
203
+
204
+ fn index(&self, row: usize) -> &[f64] {
205
+ let offset = row * self.ncols;
206
+ &self.data.as_ref().unwrap()[offset..offset + self.ncols]
207
+ }
208
+ }
209
+
210
#[cfg(test)]
mod tests {
    use super::*; // has access to everything, including private

    /// Assert that `actual` is within `tol` of `expected`, reporting both on failure.
    fn assert_close(expected: f64, actual: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "Expected: {}; Actual: {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_nan_euclid() {
        let knn = KnnImputer::new(5, "nan_euclid", "uniform");
        let query = [f64::NAN, 0.22129885, 0.8863533, 0.50595314, 0.5011135];
        // (candidate row, expected square root of nan_euclid(query, row))
        let cases = [
            ([0.76052103, f64::NAN, 0.4094729, 0.9573324, f64::NAN], 1.0382174099275148),
            ([0.27839605, 0.7338148, 0.98359227, 0.98189233, 0.45384631], 0.7912650658744038),
            ([f64::NAN, 0.22129885, 0.8863533, 0.50595314, 0.5011135], 0.0),
            ([f64::NAN, 0.32309935, 0.64573872, f64::NAN, f64::NAN], 0.41309417813332494),
            ([0.9317995, 0.51597243, 0.38054457, 0.62366235, 0.12229672], 0.7905951937189456),
            ([0.90547984, f64::NAN, 0.68424979, 0.55400964, 0.55284803], 0.2763805321428371),
            ([0.68846839, 0.53889275, 0.44453843, 0.43416536, 0.18575075], 0.7077017509263522),
            ([0.13333331, 0.8772666, 0.64398646, f64::NAN, 0.90529859], 1.042753531574897),
            ([0.69819416, 0.65251852, 0.39663618, 0.65702538, f64::NAN], 0.8646734303095986),
        ];
        for (row, expected) in cases {
            assert_close(expected, knn.nan_euclid(&query, &row).sqrt(), 1e-7);
        }
    }

    #[test]
    fn test_expected_distance() {
        let knn = KnnImputer::new(5, "expected_distance", "uniform");
        let query = [f64::NAN, 0.555556, f64::NAN, 0.555556];
        // (candidate row, expected expected_distance(query, row))
        let cases = [
            ([0.0, 0.777778, 0.0, 0.777778], 2.314269366257674),
            ([f64::NAN, 0.333333, 0.666667, 0.333333], 1.3139377804712364),
            ([0.0, 1.0, 0.0, 1.0], 2.6285387325153478),
            ([0.0, 0.88889, 0.0, 0.88889], 2.4714054636000733),
            ([0.0, 0.44444, 0.0, 0.44444], 2.157141754196649),
            ([0.666667, f64::NAN, 0.666667, f64::NAN], 2.444446),
            ([f64::NAN, f64::NAN, f64::NAN, f64::NAN], 1.777112),
            ([f64::NAN, 0.555556, f64::NAN, 0.555556], 0.666),
        ];
        for (row, expected) in cases {
            assert_close(expected, knn.expected_distance(&query, &row), 1e-9);
        }
    }

    #[test]
    fn test_compare() {
        let knn = KnnImputer::new(5, "nan_euclid", "uniform");

        // Without NaNs both metrics reduce to the plain Euclidean distance.
        let (a, b) = (&[1.0, 2.0], &[3.0, 4.0]);
        let diff = knn.nan_euclid(a, b).sqrt() - knn.expected_distance(a, b);
        assert!(diff.abs() < 1e-10, "Expected: 0.0; Actual {}", diff);

        // Column 1 missing on both sides: nan_euclid rescales the observed term,
        // expected_distance adds the fixed 0.333 penalty instead.
        let (a, b) = (&[1.0, f64::NAN], &[3.0, f64::NAN]);
        let euclid = knn.nan_euclid(a, b).sqrt(); // 2.8284271247461903
        let ed = knn.expected_distance(a, b); // 2 + 0.333
        let diff = euclid - ed;
        assert!(
            (diff - 0.4954271247461901).abs() < 1e-10,
            "Expected: 0.4954271247461901 ; Actual {}\nEuclid: {}; ED: {}",
            diff,
            euclid,
            ed
        );
    }
}
@@ -0,0 +1,15 @@
1
+ use pyo3::prelude::*;
2
+ mod knn;
3
+ mod utils;
4
+
5
+ /// A Python module implemented in Rust.
6
+ #[pymodule]
7
+ mod gouda {
8
+ use super::*;
9
+
10
+ #[pymodule_init]
11
+ fn init(module: &Bound<'_, PyModule>) -> PyResult<()> {
12
+ module.add_class::<knn::KnnImputer>()?;
13
+ Ok(())
14
+ }
15
+ }
@@ -0,0 +1,45 @@
1
+ use numpy::{PyReadonlyArray2, PyUntypedArrayMethods};
2
+ use pyo3::prelude::*;
3
+ use pyo3::types::PyAny;
4
+
5
+ const SUPPORTED_TYPES: &str = "numpy.ndarray, pandas.DataFrame";
6
+
7
+ pub fn pyany_to_vec(py: Python<'_>, obj: &Bound<'_, PyAny>) -> PyResult<(Vec<f64>, usize, usize)> {
8
+ // 1. numpy
9
+ if let Ok(arr) = obj.extract::<PyReadonlyArray2<f64>>() {
10
+ let shape = arr.shape();
11
+ let (nrows, ncols) = (shape[0], shape[1]);
12
+ // as_slice() gives the flat buffer directly if C-contiguous
13
+ let data = match arr.as_slice() {
14
+ Ok(s) => s.to_vec(),
15
+ Err(_) => arr.as_array().iter().copied().collect(), // non-contiguous fallback
16
+ };
17
+ return Ok((data, nrows, ncols));
18
+ }
19
+
20
+ // 2. pandas
21
+ let pandas = py.import("pandas")?;
22
+ if obj.is_instance(&pandas.getattr("DataFrame")?)? {
23
+ let np_module = py.import("numpy")?;
24
+ let np = np_module.call_method1("ascontiguousarray", (obj.call_method0("to_numpy")?,))?;
25
+ let arr = np.extract::<PyReadonlyArray2<f64>>()?;
26
+ let shape = arr.shape();
27
+ let (nrows, ncols) = (shape[0], shape[1]);
28
+ let data = match arr.as_slice() {
29
+ Ok(s) => s.to_vec(),
30
+ Err(_) => arr.as_array().iter().copied().collect(),
31
+ };
32
+ return Ok((data, nrows, ncols));
33
+ }
34
+
35
+ let type_name = obj
36
+ .get_type()
37
+ .qualname()
38
+ .map(|s| s.to_string())
39
+ .unwrap_or_else(|_| "unknown".to_string());
40
+
41
+ Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(format!(
42
+ "Unsupported type: '{}'. Supported types are: {}",
43
+ type_name, SUPPORTED_TYPES
44
+ )))
45
+ }
@@ -0,0 +1,5 @@
1
#!/usr/bin/env bash
# Build and test the gouda extension: Rust unit tests, a release build installed into
# the active Python environment, then the Python test suite.
# Abort on the first failing step so a broken build is never reported as passing.
set -euo pipefail

cargo test
maturin develop --release
python -m pytest tests
@@ -0,0 +1,73 @@
1
+ import numpy as np
2
+ import time
3
+ from sklearn.impute import KNNImputer as SKKNN
4
+ from gouda import KnnImputer as RSKNN
5
+
6
+
7
def test_time():
    """The median Rust ``transform`` time must be under half of sklearn's."""
    # Fixed seed so the benchmark input (and any failure) is reproducible.
    rng = np.random.default_rng(0)
    data = rng.random((500, 50))
    data[data < 0.78] = np.nan

    N = 5
    # Warmup
    for _ in range(3):
        RSKNN().fit(data).transform(data)
        SKKNN().fit(data).transform(data)

    def median_transform_ns(make_imputer):
        """Median wall time of ``transform`` over N freshly fitted imputers."""
        times = []
        for _ in range(N):
            imputer = make_imputer()
            imputer.fit(data)
            start = time.perf_counter_ns()
            _ = imputer.transform(data)
            times.append(time.perf_counter_ns() - start)
        return sorted(times)[N // 2]

    elapsed_rs = median_transform_ns(RSKNN)
    elapsed_sk = median_transform_ns(SKKNN)

    # Keep the f-string on one line: a line break inside ``{...}`` is a
    # SyntaxError before Python 3.12, and the package supports >=3.8.
    assert elapsed_rs < elapsed_sk * 0.5, f"Rust: {elapsed_rs}ns sklearn: {elapsed_sk}ns"
40
+
41
+
42
def test_nans():
    """Every NaN must be filled, for both distance metrics."""
    rng = np.random.default_rng(0)  # fixed seed: reproducible input
    data = rng.random((500, 5))
    data[data < 0.48] = np.nan
    for metric in ("nan_euclid", "expected_distance"):
        imputed = RSKNN(metric=metric).fit(data).transform(data)
        assert not np.isnan(imputed).any(), f"Imputed still has missing values ({metric})"
51
+
52
+
53
def test_isclose():
    """The default (nan_euclid) Rust imputation must match sklearn's within atol=0.1."""
    rng = np.random.default_rng(0)  # fixed seed: reproducible input
    data = rng.random((500, 50))
    data[data < 0.38] = np.nan
    imputed_rs = RSKNN().fit(data).transform(data)
    imputed_sk = SKKNN().fit(data).transform(data)
    print(data, "\n")
    print(imputed_sk, "\n")
    print(imputed_rs, "\n")
    # Mean *absolute* difference; a signed sum lets errors of opposite sign cancel out.
    avg_dist = np.abs(imputed_sk - imputed_rs).mean()
    assert np.isclose(imputed_rs, imputed_sk, atol=0.1).all(), f"wrong distances {avg_dist}"
64
+
65
+
66
def test_ed():
    """The expected_distance metric must not reproduce sklearn's nan-Euclidean imputation."""
    rng = np.random.default_rng(0)  # fixed seed: reproducible input
    data = rng.random((500, 50))
    # expected_distance's missing-value penalties appear to assume features in [0, 1].
    assert (data < 1.0).all()
    assert (data >= 0.0).all()
    data[data < 0.38] = np.nan
    imputed_rs = RSKNN(metric="expected_distance").fit(data).transform(data)
    imputed_sk = SKKNN().fit(data).transform(data)
    assert not np.isclose(imputed_rs, imputed_sk).all(), "wrong distances"
+ assert not np.isclose(imputed_rs, imputed_sk).all(), "wrong distances"