gouda-cheese 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gouda_cheese-0.1.0/.github/workflows/CI.yml +184 -0
- gouda_cheese-0.1.0/Cargo.lock +230 -0
- gouda_cheese-0.1.0/Cargo.toml +14 -0
- gouda_cheese-0.1.0/PKG-INFO +7 -0
- gouda_cheese-0.1.0/environment.yml +11 -0
- gouda_cheese-0.1.0/pyproject.toml +13 -0
- gouda_cheese-0.1.0/readme.md +4 -0
- gouda_cheese-0.1.0/src/deeplearning/gain/gain.py +67 -0
- gouda_cheese-0.1.0/src/knn.rs +307 -0
- gouda_cheese-0.1.0/src/lib.rs +15 -0
- gouda_cheese-0.1.0/src/utils.rs +45 -0
- gouda_cheese-0.1.0/test_all.sh +5 -0
- gouda_cheese-0.1.0/tests/test_performance.py +73 -0
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
# This file is autogenerated by maturin v1.12.6
|
|
2
|
+
# To update, run
|
|
3
|
+
#
|
|
4
|
+
# maturin generate-ci github
|
|
5
|
+
#
|
|
6
|
+
name: CI
|
|
7
|
+
|
|
8
|
+
on:
|
|
9
|
+
push:
|
|
10
|
+
branches:
|
|
11
|
+
- main
|
|
12
|
+
- master
|
|
13
|
+
tags:
|
|
14
|
+
- '*'
|
|
15
|
+
pull_request:
|
|
16
|
+
workflow_dispatch:
|
|
17
|
+
|
|
18
|
+
permissions:
|
|
19
|
+
contents: read
|
|
20
|
+
|
|
21
|
+
jobs:
|
|
22
|
+
linux:
|
|
23
|
+
runs-on: ${{ matrix.platform.runner }}
|
|
24
|
+
strategy:
|
|
25
|
+
matrix:
|
|
26
|
+
platform:
|
|
27
|
+
- runner: ubuntu-22.04
|
|
28
|
+
target: x86_64
|
|
29
|
+
- runner: ubuntu-22.04
|
|
30
|
+
target: x86
|
|
31
|
+
- runner: ubuntu-22.04
|
|
32
|
+
target: aarch64
|
|
33
|
+
- runner: ubuntu-22.04
|
|
34
|
+
target: armv7
|
|
35
|
+
- runner: ubuntu-22.04
|
|
36
|
+
target: s390x
|
|
37
|
+
- runner: ubuntu-22.04
|
|
38
|
+
target: ppc64le
|
|
39
|
+
steps:
|
|
40
|
+
- uses: actions/checkout@v6
|
|
41
|
+
- uses: actions/setup-python@v6
|
|
42
|
+
with:
|
|
43
|
+
python-version: 3.x
|
|
44
|
+
- name: Build wheels
|
|
45
|
+
uses: PyO3/maturin-action@v1
|
|
46
|
+
with:
|
|
47
|
+
target: ${{ matrix.platform.target }}
|
|
48
|
+
args: --release --out dist --find-interpreter
|
|
49
|
+
sccache: ${{ !startsWith(github.ref, 'refs/tags/') }}
|
|
50
|
+
manylinux: auto
|
|
51
|
+
- name: Upload wheels
|
|
52
|
+
uses: actions/upload-artifact@v6
|
|
53
|
+
with:
|
|
54
|
+
name: wheels-linux-${{ matrix.platform.target }}
|
|
55
|
+
path: dist
|
|
56
|
+
|
|
57
|
+
musllinux:
|
|
58
|
+
runs-on: ${{ matrix.platform.runner }}
|
|
59
|
+
strategy:
|
|
60
|
+
matrix:
|
|
61
|
+
platform:
|
|
62
|
+
- runner: ubuntu-22.04
|
|
63
|
+
target: x86_64
|
|
64
|
+
- runner: ubuntu-22.04
|
|
65
|
+
target: x86
|
|
66
|
+
- runner: ubuntu-22.04
|
|
67
|
+
target: aarch64
|
|
68
|
+
- runner: ubuntu-22.04
|
|
69
|
+
target: armv7
|
|
70
|
+
steps:
|
|
71
|
+
- uses: actions/checkout@v6
|
|
72
|
+
- uses: actions/setup-python@v6
|
|
73
|
+
with:
|
|
74
|
+
python-version: 3.x
|
|
75
|
+
- name: Build wheels
|
|
76
|
+
uses: PyO3/maturin-action@v1
|
|
77
|
+
with:
|
|
78
|
+
target: ${{ matrix.platform.target }}
|
|
79
|
+
args: --release --out dist --find-interpreter
|
|
80
|
+
sccache: ${{ !startsWith(github.ref, 'refs/tags/') }}
|
|
81
|
+
manylinux: musllinux_1_2
|
|
82
|
+
- name: Upload wheels
|
|
83
|
+
uses: actions/upload-artifact@v6
|
|
84
|
+
with:
|
|
85
|
+
name: wheels-musllinux-${{ matrix.platform.target }}
|
|
86
|
+
path: dist
|
|
87
|
+
|
|
88
|
+
windows:
|
|
89
|
+
runs-on: ${{ matrix.platform.runner }}
|
|
90
|
+
strategy:
|
|
91
|
+
matrix:
|
|
92
|
+
platform:
|
|
93
|
+
- runner: windows-latest
|
|
94
|
+
target: x64
|
|
95
|
+
python_arch: x64
|
|
96
|
+
- runner: windows-latest
|
|
97
|
+
target: x86
|
|
98
|
+
python_arch: x86
|
|
99
|
+
- runner: windows-11-arm
|
|
100
|
+
target: aarch64
|
|
101
|
+
python_arch: arm64
|
|
102
|
+
steps:
|
|
103
|
+
- uses: actions/checkout@v6
|
|
104
|
+
- uses: actions/setup-python@v6
|
|
105
|
+
with:
|
|
106
|
+
python-version: 3.13
|
|
107
|
+
architecture: ${{ matrix.platform.python_arch }}
|
|
108
|
+
- name: Build wheels
|
|
109
|
+
uses: PyO3/maturin-action@v1
|
|
110
|
+
with:
|
|
111
|
+
target: ${{ matrix.platform.target }}
|
|
112
|
+
args: --release --out dist --find-interpreter
|
|
113
|
+
sccache: ${{ !startsWith(github.ref, 'refs/tags/') }}
|
|
114
|
+
- name: Upload wheels
|
|
115
|
+
uses: actions/upload-artifact@v6
|
|
116
|
+
with:
|
|
117
|
+
name: wheels-windows-${{ matrix.platform.target }}
|
|
118
|
+
path: dist
|
|
119
|
+
|
|
120
|
+
macos:
|
|
121
|
+
runs-on: ${{ matrix.platform.runner }}
|
|
122
|
+
strategy:
|
|
123
|
+
matrix:
|
|
124
|
+
platform:
|
|
125
|
+
- runner: macos-15-intel
|
|
126
|
+
target: x86_64
|
|
127
|
+
- runner: macos-latest
|
|
128
|
+
target: aarch64
|
|
129
|
+
steps:
|
|
130
|
+
- uses: actions/checkout@v6
|
|
131
|
+
- uses: actions/setup-python@v6
|
|
132
|
+
with:
|
|
133
|
+
python-version: 3.x
|
|
134
|
+
- name: Build wheels
|
|
135
|
+
uses: PyO3/maturin-action@v1
|
|
136
|
+
with:
|
|
137
|
+
target: ${{ matrix.platform.target }}
|
|
138
|
+
args: --release --out dist --find-interpreter
|
|
139
|
+
sccache: ${{ !startsWith(github.ref, 'refs/tags/') }}
|
|
140
|
+
- name: Upload wheels
|
|
141
|
+
uses: actions/upload-artifact@v6
|
|
142
|
+
with:
|
|
143
|
+
name: wheels-macos-${{ matrix.platform.target }}
|
|
144
|
+
path: dist
|
|
145
|
+
|
|
146
|
+
sdist:
|
|
147
|
+
runs-on: ubuntu-latest
|
|
148
|
+
steps:
|
|
149
|
+
- uses: actions/checkout@v6
|
|
150
|
+
- name: Build sdist
|
|
151
|
+
uses: PyO3/maturin-action@v1
|
|
152
|
+
with:
|
|
153
|
+
command: sdist
|
|
154
|
+
args: --out dist
|
|
155
|
+
- name: Upload sdist
|
|
156
|
+
uses: actions/upload-artifact@v6
|
|
157
|
+
with:
|
|
158
|
+
name: wheels-sdist
|
|
159
|
+
path: dist
|
|
160
|
+
|
|
161
|
+
release:
|
|
162
|
+
name: Release
|
|
163
|
+
runs-on: ubuntu-latest
|
|
164
|
+
if: ${{ startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' }}
|
|
165
|
+
needs: [linux, musllinux, windows, macos, sdist]
|
|
166
|
+
permissions:
|
|
167
|
+
# Used to sign the release artifacts
|
|
168
|
+
id-token: write
|
|
169
|
+
# Used to upload release artifacts
|
|
170
|
+
contents: write
|
|
171
|
+
# Used to generate artifact attestation
|
|
172
|
+
attestations: write
|
|
173
|
+
steps:
|
|
174
|
+
- uses: actions/download-artifact@v7
|
|
175
|
+
- name: Generate artifact attestation
|
|
176
|
+
uses: actions/attest-build-provenance@v3
|
|
177
|
+
with:
|
|
178
|
+
subject-path: 'wheels-*/*'
|
|
179
|
+
- name: Install uv
|
|
180
|
+
if: ${{ startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' }}
|
|
181
|
+
uses: astral-sh/setup-uv@v7
|
|
182
|
+
- name: Publish to PyPI
|
|
183
|
+
if: ${{ startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' }}
|
|
184
|
+
run: uv publish 'wheels-*/*'
|
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
# This file is automatically @generated by Cargo.
|
|
2
|
+
# It is not intended for manual editing.
|
|
3
|
+
version = 4
|
|
4
|
+
|
|
5
|
+
[[package]]
|
|
6
|
+
name = "autocfg"
|
|
7
|
+
version = "1.5.0"
|
|
8
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
9
|
+
checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
|
|
10
|
+
|
|
11
|
+
[[package]]
|
|
12
|
+
name = "gouda-cheese"
|
|
13
|
+
version = "0.1.0"
|
|
14
|
+
dependencies = [
|
|
15
|
+
"ndarray",
|
|
16
|
+
"numpy",
|
|
17
|
+
"pyo3",
|
|
18
|
+
]
|
|
19
|
+
|
|
20
|
+
[[package]]
|
|
21
|
+
name = "heck"
|
|
22
|
+
version = "0.5.0"
|
|
23
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
24
|
+
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
|
|
25
|
+
|
|
26
|
+
[[package]]
|
|
27
|
+
name = "libc"
|
|
28
|
+
version = "0.2.184"
|
|
29
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
30
|
+
checksum = "48f5d2a454e16a5ea0f4ced81bd44e4cfc7bd3a507b61887c99fd3538b28e4af"
|
|
31
|
+
|
|
32
|
+
[[package]]
|
|
33
|
+
name = "matrixmultiply"
|
|
34
|
+
version = "0.3.10"
|
|
35
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
36
|
+
checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08"
|
|
37
|
+
dependencies = [
|
|
38
|
+
"autocfg",
|
|
39
|
+
"rawpointer",
|
|
40
|
+
]
|
|
41
|
+
|
|
42
|
+
[[package]]
|
|
43
|
+
name = "ndarray"
|
|
44
|
+
version = "0.17.2"
|
|
45
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
46
|
+
checksum = "520080814a7a6b4a6e9070823bb24b4531daac8c4627e08ba5de8c5ef2f2752d"
|
|
47
|
+
dependencies = [
|
|
48
|
+
"matrixmultiply",
|
|
49
|
+
"num-complex",
|
|
50
|
+
"num-integer",
|
|
51
|
+
"num-traits",
|
|
52
|
+
"portable-atomic",
|
|
53
|
+
"portable-atomic-util",
|
|
54
|
+
"rawpointer",
|
|
55
|
+
]
|
|
56
|
+
|
|
57
|
+
[[package]]
|
|
58
|
+
name = "num-complex"
|
|
59
|
+
version = "0.4.6"
|
|
60
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
61
|
+
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
|
|
62
|
+
dependencies = [
|
|
63
|
+
"num-traits",
|
|
64
|
+
]
|
|
65
|
+
|
|
66
|
+
[[package]]
|
|
67
|
+
name = "num-integer"
|
|
68
|
+
version = "0.1.46"
|
|
69
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
70
|
+
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
|
|
71
|
+
dependencies = [
|
|
72
|
+
"num-traits",
|
|
73
|
+
]
|
|
74
|
+
|
|
75
|
+
[[package]]
|
|
76
|
+
name = "num-traits"
|
|
77
|
+
version = "0.2.19"
|
|
78
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
79
|
+
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
|
|
80
|
+
dependencies = [
|
|
81
|
+
"autocfg",
|
|
82
|
+
]
|
|
83
|
+
|
|
84
|
+
[[package]]
|
|
85
|
+
name = "numpy"
|
|
86
|
+
version = "0.28.0"
|
|
87
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
88
|
+
checksum = "778da78c64ddc928ebf5ad9df5edf0789410ff3bdbf3619aed51cd789a6af1e2"
|
|
89
|
+
dependencies = [
|
|
90
|
+
"libc",
|
|
91
|
+
"ndarray",
|
|
92
|
+
"num-complex",
|
|
93
|
+
"num-integer",
|
|
94
|
+
"num-traits",
|
|
95
|
+
"pyo3",
|
|
96
|
+
"pyo3-build-config",
|
|
97
|
+
"rustc-hash",
|
|
98
|
+
]
|
|
99
|
+
|
|
100
|
+
[[package]]
|
|
101
|
+
name = "once_cell"
|
|
102
|
+
version = "1.21.4"
|
|
103
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
104
|
+
checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50"
|
|
105
|
+
|
|
106
|
+
[[package]]
|
|
107
|
+
name = "portable-atomic"
|
|
108
|
+
version = "1.13.1"
|
|
109
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
110
|
+
checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49"
|
|
111
|
+
|
|
112
|
+
[[package]]
|
|
113
|
+
name = "portable-atomic-util"
|
|
114
|
+
version = "0.2.6"
|
|
115
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
116
|
+
checksum = "091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3"
|
|
117
|
+
dependencies = [
|
|
118
|
+
"portable-atomic",
|
|
119
|
+
]
|
|
120
|
+
|
|
121
|
+
[[package]]
|
|
122
|
+
name = "proc-macro2"
|
|
123
|
+
version = "1.0.106"
|
|
124
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
125
|
+
checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934"
|
|
126
|
+
dependencies = [
|
|
127
|
+
"unicode-ident",
|
|
128
|
+
]
|
|
129
|
+
|
|
130
|
+
[[package]]
|
|
131
|
+
name = "pyo3"
|
|
132
|
+
version = "0.28.3"
|
|
133
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
134
|
+
checksum = "91fd8e38a3b50ed1167fb981cd6fd60147e091784c427b8f7183a7ee32c31c12"
|
|
135
|
+
dependencies = [
|
|
136
|
+
"libc",
|
|
137
|
+
"once_cell",
|
|
138
|
+
"portable-atomic",
|
|
139
|
+
"pyo3-build-config",
|
|
140
|
+
"pyo3-ffi",
|
|
141
|
+
"pyo3-macros",
|
|
142
|
+
]
|
|
143
|
+
|
|
144
|
+
[[package]]
|
|
145
|
+
name = "pyo3-build-config"
|
|
146
|
+
version = "0.28.3"
|
|
147
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
148
|
+
checksum = "e368e7ddfdeb98c9bca7f8383be1648fd84ab466bf2bc015e94008db6d35611e"
|
|
149
|
+
dependencies = [
|
|
150
|
+
"target-lexicon",
|
|
151
|
+
]
|
|
152
|
+
|
|
153
|
+
[[package]]
|
|
154
|
+
name = "pyo3-ffi"
|
|
155
|
+
version = "0.28.3"
|
|
156
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
157
|
+
checksum = "7f29e10af80b1f7ccaf7f69eace800a03ecd13e883acfacc1e5d0988605f651e"
|
|
158
|
+
dependencies = [
|
|
159
|
+
"libc",
|
|
160
|
+
"pyo3-build-config",
|
|
161
|
+
]
|
|
162
|
+
|
|
163
|
+
[[package]]
|
|
164
|
+
name = "pyo3-macros"
|
|
165
|
+
version = "0.28.3"
|
|
166
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
167
|
+
checksum = "df6e520eff47c45997d2fc7dd8214b25dd1310918bbb2642156ef66a67f29813"
|
|
168
|
+
dependencies = [
|
|
169
|
+
"proc-macro2",
|
|
170
|
+
"pyo3-macros-backend",
|
|
171
|
+
"quote",
|
|
172
|
+
"syn",
|
|
173
|
+
]
|
|
174
|
+
|
|
175
|
+
[[package]]
|
|
176
|
+
name = "pyo3-macros-backend"
|
|
177
|
+
version = "0.28.3"
|
|
178
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
179
|
+
checksum = "c4cdc218d835738f81c2338f822078af45b4afdf8b2e33cbb5916f108b813acb"
|
|
180
|
+
dependencies = [
|
|
181
|
+
"heck",
|
|
182
|
+
"proc-macro2",
|
|
183
|
+
"pyo3-build-config",
|
|
184
|
+
"quote",
|
|
185
|
+
"syn",
|
|
186
|
+
]
|
|
187
|
+
|
|
188
|
+
[[package]]
|
|
189
|
+
name = "quote"
|
|
190
|
+
version = "1.0.45"
|
|
191
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
192
|
+
checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924"
|
|
193
|
+
dependencies = [
|
|
194
|
+
"proc-macro2",
|
|
195
|
+
]
|
|
196
|
+
|
|
197
|
+
[[package]]
|
|
198
|
+
name = "rawpointer"
|
|
199
|
+
version = "0.2.1"
|
|
200
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
201
|
+
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
|
|
202
|
+
|
|
203
|
+
[[package]]
|
|
204
|
+
name = "rustc-hash"
|
|
205
|
+
version = "2.1.2"
|
|
206
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
207
|
+
checksum = "94300abf3f1ae2e2b8ffb7b58043de3d399c73fa6f4b73826402a5c457614dbe"
|
|
208
|
+
|
|
209
|
+
[[package]]
|
|
210
|
+
name = "syn"
|
|
211
|
+
version = "2.0.117"
|
|
212
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
213
|
+
checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99"
|
|
214
|
+
dependencies = [
|
|
215
|
+
"proc-macro2",
|
|
216
|
+
"quote",
|
|
217
|
+
"unicode-ident",
|
|
218
|
+
]
|
|
219
|
+
|
|
220
|
+
[[package]]
|
|
221
|
+
name = "target-lexicon"
|
|
222
|
+
version = "0.13.5"
|
|
223
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
224
|
+
checksum = "adb6935a6f5c20170eeceb1a3835a49e12e19d792f6dd344ccc76a985ca5a6ca"
|
|
225
|
+
|
|
226
|
+
[[package]]
|
|
227
|
+
name = "unicode-ident"
|
|
228
|
+
version = "1.0.24"
|
|
229
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
230
|
+
checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75"
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
[package]
|
|
2
|
+
name = "gouda-cheese"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
edition = "2024"
|
|
5
|
+
|
|
6
|
+
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
|
7
|
+
[lib]
|
|
8
|
+
name = "gouda"
|
|
9
|
+
crate-type = ["cdylib"]
|
|
10
|
+
|
|
11
|
+
[dependencies]
|
|
12
|
+
ndarray = "0.17.2"
|
|
13
|
+
numpy = "0.28.0"
|
|
14
|
+
pyo3 = { version = "0.28", features = ["extension-module"] }
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: gouda-cheese
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Classifier: Programming Language :: Rust
|
|
5
|
+
Classifier: Programming Language :: Python :: Implementation :: CPython
|
|
6
|
+
Classifier: Programming Language :: Python :: Implementation :: PyPy
|
|
7
|
+
Requires-Python: >=3.8
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["maturin>=1.12,<2.0"]
|
|
3
|
+
build-backend = "maturin"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "gouda-cheese"
|
|
7
|
+
requires-python = ">=3.8"
|
|
8
|
+
classifiers = [
|
|
9
|
+
"Programming Language :: Rust",
|
|
10
|
+
"Programming Language :: Python :: Implementation :: CPython",
|
|
11
|
+
"Programming Language :: Python :: Implementation :: PyPy",
|
|
12
|
+
]
|
|
13
|
+
dynamic = ["version"]
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import pandas as pd
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Generator(nn.Module):
    """Placeholder generator network for the imputation model.

    The model itself is not defined yet: ``forward`` hands back ``x``
    untouched and ignores the other two arguments.

    Input:
        x: Tabular data
        random_value: randomly generated values
        missing_mask: binary mask
    Out:
        x: Imputed values (currently the input itself)
    """

    # TODO:
    # - Define model for imputation
    #   Maybe something simple?
    def __init__(self):
        super().__init__()

    def forward(
        self,
        x: torch.Tensor,
        random_value: torch.Tensor,
        missing_mask: torch.Tensor,
    ):
        # Identity mapping until a real network is defined.
        return x
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class Discriminator(nn.Module):
    """Placeholder discriminator network.

    Input:
        x: Tabular data
    Out:
        m: Boolean mask, True where the entry of ``x`` is strictly positive.
           Intended to flag generated vs. real values once a real model
           is defined.
    """

    # TODO:
    # - Define the model that judges
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor):
        # Element-wise "greater than zero" stands in for the real judgement.
        return torch.gt(x, 0)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass
class Gain:
    """Imputer skeleton wiring a generator and a discriminator together.

    Attributes:
        seed: Seed for the private random number generator.
        rng: ``torch.Generator`` seeded with ``seed``.
        generator: Network that produces the imputed table.
        discriminator: Network meant to judge generated vs. real values.
    """

    seed: int = 42
    rng: torch.Generator = field(init=False)
    generator: Generator = field(init=False)
    discriminator: Discriminator = field(init=False)

    def __post_init__(self):
        self.rng = torch.Generator().manual_seed(self.seed)
        # Both fields are declared with ``init=False`` and no default, so they
        # have to be assigned here; otherwise ``transform`` fails with
        # AttributeError on ``self.generator``.
        self.generator = Generator()
        self.discriminator = Discriminator()

    def fit(self, x: pd.DataFrame):
        """Training hook; no training is implemented yet.

        Args:
            x: Two-dimensional table (unpacking ``x.shape`` enforces this).

        Returns:
            ``self``, so that ``Gain().fit(x).transform(x)`` can be chained.
        """
        _nrows, _ncols = x.shape
        return self

    def transform(self, x: pd.DataFrame):
        """Run the generator on ``x`` with missing entries replaced by 0.

        Args:
            x: Table that may contain missing values.

        Returns:
            Whatever the generator returns for the zero-filled data, a
            tensor of uniform random values of the same shape, and the
            boolean missing-value mask.
        """
        nrows, ncols = x.shape
        missing_mask = x.isna()
        random_values = torch.rand(nrows, ncols, generator=self.rng)
        filled = x.fillna(0)
        return self.generator(
            torch.tensor(filled.values),
            random_values,
            torch.tensor(missing_mask.values),
        )
|
|
@@ -0,0 +1,307 @@
|
|
|
1
|
+
use crate::utils::pyany_to_vec;
|
|
2
|
+
use core::f64;
|
|
3
|
+
use numpy::{IntoPyArray, PyArray2};
|
|
4
|
+
use pyo3::prelude::*;
|
|
5
|
+
use pyo3::types::PyAny;
|
|
6
|
+
use std::ops::Index;
|
|
7
|
+
|
|
8
|
+
#[pyclass]
|
|
9
|
+
pub struct KnnImputer {
|
|
10
|
+
#[pyo3(get, set)]
|
|
11
|
+
k: usize,
|
|
12
|
+
nrows: usize,
|
|
13
|
+
ncols: usize,
|
|
14
|
+
data: Option<Vec<f64>>,
|
|
15
|
+
is_fitted: bool,
|
|
16
|
+
metric: String,
|
|
17
|
+
weights: String,
|
|
18
|
+
}
|
|
19
|
+
|
|
20
|
+
const ALLOWED_WEIGHTS: [&str; 2] = ["uniform", "distance"];
|
|
21
|
+
|
|
22
|
+
#[pymethods]
|
|
23
|
+
impl KnnImputer {
|
|
24
|
+
#[new]
|
|
25
|
+
#[pyo3(signature = (k=5, metric="nan_euclid", weights="uniform"))]
|
|
26
|
+
pub fn new(k: usize, metric: &str, weights: &str) -> KnnImputer {
|
|
27
|
+
assert!(ALLOWED_WEIGHTS.contains(&weights));
|
|
28
|
+
KnnImputer {
|
|
29
|
+
k,
|
|
30
|
+
nrows: 0,
|
|
31
|
+
ncols: 0,
|
|
32
|
+
data: None,
|
|
33
|
+
is_fitted: false,
|
|
34
|
+
metric: metric.to_owned(),
|
|
35
|
+
weights: weights.to_owned(),
|
|
36
|
+
}
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
pub fn fit(slf: Py<Self>, py: Python<'_>, data: &Bound<'_, PyAny>) -> PyResult<Py<Self>> {
|
|
40
|
+
let (vec, nrows, ncols) = pyany_to_vec(py, data)?;
|
|
41
|
+
{
|
|
42
|
+
let mut inner = slf.borrow_mut(py);
|
|
43
|
+
inner.data = Some(vec);
|
|
44
|
+
inner.nrows = nrows;
|
|
45
|
+
inner.ncols = ncols;
|
|
46
|
+
inner.is_fitted = true;
|
|
47
|
+
} // dropping inner here (releasing the mutex)
|
|
48
|
+
Ok(slf)
|
|
49
|
+
}
|
|
50
|
+
|
|
51
|
+
pub fn transform<'py>(
|
|
52
|
+
&self,
|
|
53
|
+
py: Python<'py>,
|
|
54
|
+
data: &Bound<'_, PyAny>,
|
|
55
|
+
) -> PyResult<Bound<'py, PyArray2<f64>>> {
|
|
56
|
+
// check if fitted
|
|
57
|
+
if !self.is_fitted {
|
|
58
|
+
return Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(format!(
|
|
59
|
+
"Imputer is not fitted",
|
|
60
|
+
)));
|
|
61
|
+
}
|
|
62
|
+
let (data, nrows, _) = pyany_to_vec(py, data)?;
|
|
63
|
+
// actual method
|
|
64
|
+
let dist = match self.metric.as_str() {
|
|
65
|
+
"nan_euclid" => Self::nan_euclid,
|
|
66
|
+
"expected_distance" => Self::expected_distance,
|
|
67
|
+
m => {
|
|
68
|
+
return Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(format!(
|
|
69
|
+
"{} is a unknown metric",
|
|
70
|
+
m
|
|
71
|
+
)));
|
|
72
|
+
}
|
|
73
|
+
};
|
|
74
|
+
let imputed = self.brute_force(&data, nrows, dist);
|
|
75
|
+
// return python object
|
|
76
|
+
let array = ndarray::Array2::from_shape_vec((self.nrows, self.ncols), imputed)
|
|
77
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
|
|
78
|
+
Ok(array.into_pyarray(py))
|
|
79
|
+
}
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
impl KnnImputer {
|
|
83
|
+
fn brute_force(
|
|
84
|
+
&self,
|
|
85
|
+
data: &[f64],
|
|
86
|
+
nrows: usize,
|
|
87
|
+
dist: fn(&KnnImputer, &[f64], &[f64]) -> f64,
|
|
88
|
+
) -> Vec<f64> {
|
|
89
|
+
let mut imputed = data.to_vec();
|
|
90
|
+
let mut i = 0;
|
|
91
|
+
while i < data.len() {
|
|
92
|
+
if data[i].is_nan() {
|
|
93
|
+
// figure out point and impute
|
|
94
|
+
let row = i / self.ncols;
|
|
95
|
+
let col = i % self.ncols;
|
|
96
|
+
let mut cols = Vec::with_capacity(20);
|
|
97
|
+
for j in col..self.ncols {
|
|
98
|
+
let l = i + j - col;
|
|
99
|
+
if l >= data.len() {
|
|
100
|
+
break;
|
|
101
|
+
}
|
|
102
|
+
if data[l].is_nan() {
|
|
103
|
+
cols.push(j);
|
|
104
|
+
}
|
|
105
|
+
}
|
|
106
|
+
let mut distances = Vec::with_capacity(nrows);
|
|
107
|
+
let p = &self[row];
|
|
108
|
+
for r in 0..self.nrows {
|
|
109
|
+
if r == row {
|
|
110
|
+
distances.push(f64::MAX); // dont consider distance to self
|
|
111
|
+
} else {
|
|
112
|
+
distances.push(dist(&self, p, &self[r]));
|
|
113
|
+
}
|
|
114
|
+
}
|
|
115
|
+
let mut indices: Vec<usize> = (0..nrows).collect();
|
|
116
|
+
// indices.sort_by(|&a, &b| distances[a].total_cmp(&distances[b]));
|
|
117
|
+
indices.sort_unstable_by(|&a, &b| distances[a].total_cmp(&distances[b]));
|
|
118
|
+
let avgs = self.average(&indices, &cols, &self.get_weights(&distances));
|
|
119
|
+
for (avg, c) in avgs.into_iter().zip(&cols) {
|
|
120
|
+
imputed[i + c - col] = avg;
|
|
121
|
+
}
|
|
122
|
+
i += self.ncols - col;
|
|
123
|
+
} else {
|
|
124
|
+
i += 1;
|
|
125
|
+
}
|
|
126
|
+
}
|
|
127
|
+
imputed
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
fn average(&self, indices: &[usize], cols: &[usize], weights: &[f64]) -> Vec<f64> {
|
|
131
|
+
let mut avg: Vec<f64> = vec![0.0; cols.len()];
|
|
132
|
+
for (j, c) in cols.iter().enumerate() {
|
|
133
|
+
let mut count = 0;
|
|
134
|
+
for i in indices {
|
|
135
|
+
let val = self[*i][*c];
|
|
136
|
+
if val.is_nan() {
|
|
137
|
+
continue;
|
|
138
|
+
}
|
|
139
|
+
avg[j] += val * weights[*i];
|
|
140
|
+
count += 1;
|
|
141
|
+
if count >= self.k {
|
|
142
|
+
break;
|
|
143
|
+
}
|
|
144
|
+
}
|
|
145
|
+
// avg[j] *= 1.0 / self.k as f64;
|
|
146
|
+
}
|
|
147
|
+
avg
|
|
148
|
+
}
|
|
149
|
+
|
|
150
|
+
fn get_weights(&self, distances: &[f64]) -> Vec<f64> {
|
|
151
|
+
match self.weights.as_str() {
|
|
152
|
+
"uniform" => distances.iter().map(|_| 1.0 / self.k as f64).collect(),
|
|
153
|
+
"distances" => distances.iter().map(|d| 1.0 / d).collect(),
|
|
154
|
+
w => panic!("Unknown weight {}", w),
|
|
155
|
+
}
|
|
156
|
+
}
|
|
157
|
+
}
|
|
158
|
+
|
|
159
|
+
// Distance Functions
|
|
160
|
+
impl KnnImputer {
|
|
161
|
+
// TODO:
|
|
162
|
+
// write tests
|
|
163
|
+
fn nan_euclid(&self, a: &[f64], b: &[f64]) -> f64 {
|
|
164
|
+
let mut total = 0.0;
|
|
165
|
+
let mut valid = 0;
|
|
166
|
+
let ncols = a.len();
|
|
167
|
+
for i in 0..ncols {
|
|
168
|
+
let (x, y) = (a[i], b[i]);
|
|
169
|
+
if !(x.is_nan() || y.is_nan()) {
|
|
170
|
+
let d = x - y;
|
|
171
|
+
total += d * d;
|
|
172
|
+
valid += 1;
|
|
173
|
+
}
|
|
174
|
+
}
|
|
175
|
+
if valid == 0 {
|
|
176
|
+
return f64::INFINITY;
|
|
177
|
+
}
|
|
178
|
+
total * (ncols as f64 / valid as f64)
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
fn expected_distance(&self, a: &[f64], b: &[f64]) -> f64 {
|
|
182
|
+
let mut total = 0.0;
|
|
183
|
+
let mut total_obs = 0.0;
|
|
184
|
+
let ncols = a.len();
|
|
185
|
+
for i in 0..ncols {
|
|
186
|
+
let (x, y) = (a[i], b[i]);
|
|
187
|
+
match (x.is_nan(), y.is_nan()) {
|
|
188
|
+
(true, true) => total += 0.333,
|
|
189
|
+
(true, false) => total += y.max(1.0 - y),
|
|
190
|
+
(false, true) => total += x.max(1.0 - x),
|
|
191
|
+
(false, false) => {
|
|
192
|
+
let d = x - y;
|
|
193
|
+
total_obs += d * d
|
|
194
|
+
}
|
|
195
|
+
}
|
|
196
|
+
}
|
|
197
|
+
total + total_obs.sqrt()
|
|
198
|
+
}
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
impl Index<usize> for KnnImputer {
|
|
202
|
+
type Output = [f64];
|
|
203
|
+
|
|
204
|
+
fn index(&self, row: usize) -> &[f64] {
|
|
205
|
+
let offset = row * self.ncols;
|
|
206
|
+
&self.data.as_ref().unwrap()[offset..offset + self.ncols]
|
|
207
|
+
}
|
|
208
|
+
}
|
|
209
|
+
|
|
210
|
+
#[cfg(test)]
mod tests {
    use super::*; // private items of the parent module are visible here

    // Fail unless `actual` lies within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "Expected: {}; Actual: {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_nan_euclid() {
        let knn = KnnImputer::new(5, "nan_euclid", "uniform");
        let p = [f64::NAN, 0.22129885, 0.8863533, 0.50595314, 0.5011135];
        // (point, expected square root of `nan_euclid(p, point)`)
        let cases = [
            (
                [0.76052103, f64::NAN, 0.4094729, 0.9573324, f64::NAN],
                1.0382174099275148,
            ),
            (
                [0.27839605, 0.7338148, 0.98359227, 0.98189233, 0.45384631],
                0.7912650658744038,
            ),
            (
                [f64::NAN, 0.22129885, 0.8863533, 0.50595314, 0.5011135],
                0.0,
            ),
            (
                [f64::NAN, 0.32309935, 0.64573872, f64::NAN, f64::NAN],
                0.41309417813332494,
            ),
            (
                [0.9317995, 0.51597243, 0.38054457, 0.62366235, 0.12229672],
                0.7905951937189456,
            ),
            (
                [0.90547984, f64::NAN, 0.68424979, 0.55400964, 0.55284803],
                0.2763805321428371,
            ),
            (
                [0.68846839, 0.53889275, 0.44453843, 0.43416536, 0.18575075],
                0.7077017509263522,
            ),
            (
                [0.13333331, 0.8772666, 0.64398646, f64::NAN, 0.90529859],
                1.042753531574897,
            ),
            (
                [0.69819416, 0.65251852, 0.39663618, 0.65702538, f64::NAN],
                0.8646734303095986,
            ),
        ];

        for (q, e) in cases.iter() {
            assert_close(knn.nan_euclid(&p, q).sqrt(), *e, 1e-7);
        }
    }

    #[test]
    fn test_expected_distance() {
        let p = [f64::NAN, 0.555556, f64::NAN, 0.555556];
        // (point, expected `expected_distance(p, point)`)
        let cases = [
            ([0.0, 0.777778, 0.0, 0.777778], 2.314269366257674),
            ([f64::NAN, 0.333333, 0.666667, 0.333333], 1.3139377804712364),
            ([0.0, 1.0, 0.0, 1.0], 2.6285387325153478),
            ([0.0, 0.88889, 0.0, 0.88889], 2.4714054636000733),
            ([0.0, 0.44444, 0.0, 0.44444], 2.157141754196649),
            ([0.666667, f64::NAN, 0.666667, f64::NAN], 2.444446),
            ([f64::NAN, f64::NAN, f64::NAN, f64::NAN], 1.777112),
            ([f64::NAN, 0.555556, f64::NAN, 0.555556], 0.666),
        ];
        let knn = KnnImputer::new(5, "expected_distance", "uniform");
        for (q, e) in cases.iter() {
            assert_close(knn.expected_distance(&p, q), *e, 1e-9);
        }
    }

    #[test]
    fn test_compare() {
        let knn = KnnImputer::new(5, "nan_euclid", "uniform");

        // Without missing values both metrics must agree.
        let a = [1.0, 2.0];
        let b = [3.0, 4.0];
        let diff = knn.nan_euclid(&a, &b).sqrt() - knn.expected_distance(&a, &b);
        assert!((diff).abs() < 1e-10, "Expected: 0.0; Actual {}", diff);

        let a = [1.0, f64::NAN];
        let b = [3.0, f64::NAN];
        // 2.8284271247461903
        let euclid = knn.nan_euclid(&a, &b).sqrt();
        // 2 + 0.333
        let ed = knn.expected_distance(&a, &b);
        let diff = euclid - ed;
        assert!(
            (diff - 0.4954271247461901).abs() < 1e-10,
            "Expected: 0.4954271247461901 ; Actual {}\nEuclid: {}; ED: {}",
            diff,
            euclid,
            ed
        );
    }
}
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
use pyo3::prelude::*;
|
|
2
|
+
mod knn;
|
|
3
|
+
mod utils;
|
|
4
|
+
|
|
5
|
+
/// A Python module implemented in Rust.
|
|
6
|
+
#[pymodule]
|
|
7
|
+
mod gouda {
|
|
8
|
+
use super::*;
|
|
9
|
+
|
|
10
|
+
#[pymodule_init]
|
|
11
|
+
fn init(module: &Bound<'_, PyModule>) -> PyResult<()> {
|
|
12
|
+
module.add_class::<knn::KnnImputer>()?;
|
|
13
|
+
Ok(())
|
|
14
|
+
}
|
|
15
|
+
}
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
use numpy::{PyReadonlyArray2, PyUntypedArrayMethods};
|
|
2
|
+
use pyo3::prelude::*;
|
|
3
|
+
use pyo3::types::PyAny;
|
|
4
|
+
|
|
5
|
+
const SUPPORTED_TYPES: &str = "numpy.ndarray, pandas.DataFrame";
|
|
6
|
+
|
|
7
|
+
pub fn pyany_to_vec(py: Python<'_>, obj: &Bound<'_, PyAny>) -> PyResult<(Vec<f64>, usize, usize)> {
|
|
8
|
+
// 1. numpy
|
|
9
|
+
if let Ok(arr) = obj.extract::<PyReadonlyArray2<f64>>() {
|
|
10
|
+
let shape = arr.shape();
|
|
11
|
+
let (nrows, ncols) = (shape[0], shape[1]);
|
|
12
|
+
// as_slice() gives the flat buffer directly if C-contiguous
|
|
13
|
+
let data = match arr.as_slice() {
|
|
14
|
+
Ok(s) => s.to_vec(),
|
|
15
|
+
Err(_) => arr.as_array().iter().copied().collect(), // non-contiguous fallback
|
|
16
|
+
};
|
|
17
|
+
return Ok((data, nrows, ncols));
|
|
18
|
+
}
|
|
19
|
+
|
|
20
|
+
// 2. pandas
|
|
21
|
+
let pandas = py.import("pandas")?;
|
|
22
|
+
if obj.is_instance(&pandas.getattr("DataFrame")?)? {
|
|
23
|
+
let np_module = py.import("numpy")?;
|
|
24
|
+
let np = np_module.call_method1("ascontiguousarray", (obj.call_method0("to_numpy")?,))?;
|
|
25
|
+
let arr = np.extract::<PyReadonlyArray2<f64>>()?;
|
|
26
|
+
let shape = arr.shape();
|
|
27
|
+
let (nrows, ncols) = (shape[0], shape[1]);
|
|
28
|
+
let data = match arr.as_slice() {
|
|
29
|
+
Ok(s) => s.to_vec(),
|
|
30
|
+
Err(_) => arr.as_array().iter().copied().collect(),
|
|
31
|
+
};
|
|
32
|
+
return Ok((data, nrows, ncols));
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
let type_name = obj
|
|
36
|
+
.get_type()
|
|
37
|
+
.qualname()
|
|
38
|
+
.map(|s| s.to_string())
|
|
39
|
+
.unwrap_or_else(|_| "unknown".to_string());
|
|
40
|
+
|
|
41
|
+
Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(format!(
|
|
42
|
+
"Unsupported type: '{}'. Supported types are: {}",
|
|
43
|
+
type_name, SUPPORTED_TYPES
|
|
44
|
+
)))
|
|
45
|
+
}
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import time
|
|
3
|
+
from sklearn.impute import KNNImputer as SKKNN
|
|
4
|
+
from gouda import KnnImputer as RSKNN
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def test_time():
    """Check that the Rust imputer's median transform time is below half of
    scikit-learn's on the same random data (about 78% missing values)."""
    data = np.random.rand(500, 50)
    data[data < 0.78] = np.nan

    N = 5
    # Warmup
    for _ in range(3):
        RSKNN().fit(data).transform(data)
        SKKNN().fit(data).transform(data)

    def median_transform_ns(make_imputer):
        # Time only ``transform``; fitting happens outside the timed region.
        samples = []
        for _ in range(N):
            imputer = make_imputer()
            imputer.fit(data)
            start = time.perf_counter_ns()
            imputer.transform(data)
            samples.append(time.perf_counter_ns() - start)
        return sorted(samples)[N // 2]

    elapsed_rs = median_transform_ns(RSKNN)
    elapsed_sk = median_transform_ns(SKKNN)

    # The message stays on one line: a replacement field split across lines
    # is a SyntaxError before Python 3.12.
    assert elapsed_rs < elapsed_sk * 0.5, (
        f"Rust: {elapsed_rs}ns sklearn: {elapsed_sk}ns"
    )
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def test_nans():
    """Both metrics must leave no missing value in the imputed output."""
    data = np.random.rand(500, 5)
    missing = data < 0.48
    data[missing] = np.nan

    result = RSKNN().fit(data).transform(data)
    print("data:\n", data)
    print("imputed:\n", result)
    assert not np.isnan(result).any(), "Imputed still has missing values"

    result = RSKNN(metric="expected_distance").fit(data).transform(data)
    assert not np.isnan(result).any(), "Imputed still has missing values"
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def test_isclose():
    """The default Rust imputer must match scikit-learn within 0.1 on every
    entry of the same random data."""
    data = np.random.rand(500, 50)
    data[data < 0.38] = np.nan
    imputed_rs = RSKNN().fit(data).transform(data)
    imputed_sk = SKKNN().fit(data).transform(data)
    print(data, "\n")
    print(imputed_sk, "\n")
    print(imputed_rs, "\n")
    # Mean absolute difference for the failure message; signed differences
    # would cancel each other out and under-report the error.
    avg_dist = np.abs(imputed_sk - imputed_rs).sum() / data.size
    assert np.isclose(imputed_rs, imputed_sk, atol=0.1).all(), (
        f"wrong distances {avg_dist}"
    )
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def test_ed():
    """The expected-distance metric must not reproduce scikit-learn's
    result exactly."""
    data = np.random.rand(500, 50)
    # Sanity-check the generated values lie in [0, 1).
    assert (data < 1.0).all()
    assert (data >= 0.0).all()
    missing = data < 0.38
    data[missing] = np.nan

    rust_result = RSKNN(metric="expected_distance").fit(data).transform(data)
    sklearn_result = SKKNN().fit(data).transform(data)
    all_close = np.isclose(rust_result, sklearn_result).all()
    assert not all_close, "wrong distances"
|