ezmsg-learn 1.2.0__tar.gz → 1.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/PKG-INFO +4 -3
- ezmsg_learn-1.4.0/docs/source/guides/array_api.rst +246 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/docs/source/guides/classification.rst +5 -4
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/docs/source/index.rst +1 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/pyproject.toml +4 -4
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/__version__.py +2 -2
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/dim_reduce/adaptive_decomp.py +28 -8
- ezmsg_learn-1.4.0/src/ezmsg/learn/model/cca.py +163 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/model/refit_kalman.py +90 -69
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/process/adaptive_linear_regressor.py +32 -9
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/process/refit_kalman.py +56 -26
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/process/slda.py +22 -6
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/process/ssr.py +0 -5
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/tests/unit/test_adaptive_linear_regressor.py +2 -2
- ezmsg_learn-1.4.0/tests/unit/test_cca.py +125 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/tests/unit/test_linear_regressor.py +3 -3
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/tests/unit/test_mlp_old.py +1 -2
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/tests/unit/test_rnn.py +1 -1
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/tests/unit/test_sgd.py +2 -2
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/tests/unit/test_ssr.py +4 -4
- ezmsg_learn-1.2.0/src/ezmsg/learn/model/cca.py +0 -122
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/.github/workflows/docs.yml +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/.github/workflows/python-publish.yml +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/.github/workflows/python-tests.yml +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/.gitignore +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/.pre-commit-config.yaml +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/LICENSE +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/README.md +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/docs/Makefile +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/docs/make.bat +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/docs/source/_templates/autosummary/module.rst +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/docs/source/api/index.rst +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/docs/source/conf.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/__init__.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/dim_reduce/__init__.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/dim_reduce/incremental_decomp.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/linear_model/__init__.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/linear_model/adaptive_linear_regressor.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/linear_model/cca.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/linear_model/linear_regressor.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/linear_model/sgd.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/linear_model/slda.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/model/__init__.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/model/mlp.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/model/mlp_old.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/model/rnn.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/model/transformer.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/nlin_model/__init__.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/nlin_model/mlp.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/process/__init__.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/process/base.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/process/linear_regressor.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/process/mlp_old.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/process/rnn.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/process/sgd.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/process/sklearn.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/process/torch.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/process/transformer.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/util.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/tests/benchmark/bench_lrr.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/tests/dim_reduce/test_adaptive_decomp.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/tests/dim_reduce/test_incremental_decomp.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/tests/integration/conftest.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/tests/integration/test_mlp_system.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/tests/integration/test_refit_kalman_system.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/tests/integration/test_rnn_system.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/tests/integration/test_sklearn_system.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/tests/integration/test_torch_system.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/tests/integration/test_transformer_system.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/tests/unit/test_mlp.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/tests/unit/test_refit_kalman.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/tests/unit/test_sklearn.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/tests/unit/test_slda.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/tests/unit/test_torch.py +0 -0
- {ezmsg_learn-1.2.0 → ezmsg_learn-1.4.0}/tests/unit/test_transformer.py +0 -0
|
@@ -1,13 +1,14 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ezmsg-learn
|
|
3
|
-
Version: 1.2.0
|
|
3
|
+
Version: 1.4.0
|
|
4
4
|
Summary: ezmsg namespace package for machine learning
|
|
5
5
|
Author-email: Chadwick Boulay <chadwick.boulay@gmail.com>
|
|
6
6
|
License-Expression: MIT
|
|
7
7
|
License-File: LICENSE
|
|
8
8
|
Requires-Python: >=3.10.15
|
|
9
|
-
Requires-Dist: ezmsg-baseproc>=1.
|
|
10
|
-
Requires-Dist: ezmsg-sigproc>=2.
|
|
9
|
+
Requires-Dist: ezmsg-baseproc>=1.5.1
|
|
10
|
+
Requires-Dist: ezmsg-sigproc>=2.17.0
|
|
11
|
+
Requires-Dist: ezmsg>=3.7.3
|
|
11
12
|
Requires-Dist: river>=0.22.0
|
|
12
13
|
Requires-Dist: scikit-learn>=1.6.0
|
|
13
14
|
Requires-Dist: torch>=2.6.0
|
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
Array API Compatibility
|
|
2
|
+
=======================
|
|
3
|
+
|
|
4
|
+
ezmsg-learn uses the `Array API standard <https://data-apis.org/array-api/latest/>`_
|
|
5
|
+
to allow processors to operate on arrays from different backends — NumPy, CuPy,
|
|
6
|
+
PyTorch, and others — without code changes.
|
|
7
|
+
|
|
8
|
+
.. contents:: On this page
|
|
9
|
+
:local:
|
|
10
|
+
:depth: 2
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
How It Works
|
|
14
|
+
------------
|
|
15
|
+
|
|
16
|
+
Modules that support the Array API derive the array namespace from their input
|
|
17
|
+
data using ``array_api_compat.get_namespace()``:
|
|
18
|
+
|
|
19
|
+
.. code-block:: python
|
|
20
|
+
|
|
21
|
+
from array_api_compat import get_namespace
|
|
22
|
+
|
|
23
|
+
def process(self, data):
|
|
24
|
+
xp = get_namespace(data) # numpy, cupy, torch, etc.
|
|
25
|
+
result = xp.linalg.inv(data) # dispatches to the right backend
|
|
26
|
+
return result
|
|
27
|
+
|
|
28
|
+
This means that if you pass a CuPy array, all computation stays on the GPU.
|
|
29
|
+
If you pass a NumPy array, it behaves exactly as before.
|
|
30
|
+
|
|
31
|
+
Helper utilities from ``ezmsg.sigproc.util.array`` handle device placement
|
|
32
|
+
and creation functions portably:
|
|
33
|
+
|
|
34
|
+
- ``array_device(x)`` — returns the device of an array, or ``None``
|
|
35
|
+
- ``xp_create(fn, *args, dtype=None, device=None)`` — calls creation
|
|
36
|
+
functions (``zeros``, ``eye``) with optional device
|
|
37
|
+
- ``xp_asarray(xp, obj, dtype=None, device=None)`` — portable ``asarray``
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
Module Compatibility
|
|
41
|
+
--------------------
|
|
42
|
+
|
|
43
|
+
The tables below summarise the Array API status of each module.
|
|
44
|
+
|
|
45
|
+
Fully compatible
|
|
46
|
+
^^^^^^^^^^^^^^^^
|
|
47
|
+
|
|
48
|
+
These modules perform all computation in the source array namespace.
|
|
49
|
+
|
|
50
|
+
.. list-table::
|
|
51
|
+
:header-rows: 1
|
|
52
|
+
:widths: 35 65
|
|
53
|
+
|
|
54
|
+
* - Module
|
|
55
|
+
- Notes
|
|
56
|
+
* - ``process.ssr``
|
|
57
|
+
- LRR / self-supervised regression. Full Array API.
|
|
58
|
+
* - ``model.cca``
|
|
59
|
+
- Incremental CCA. Replaced ``scipy.linalg.sqrtm`` with an
|
|
60
|
+
eigendecomposition-based inverse square root using only Array API ops.
|
|
61
|
+
* - ``process.rnn``
|
|
62
|
+
- PyTorch-native; operates on ``torch.Tensor`` throughout.
|
|
63
|
+
|
|
64
|
+
Mostly compatible (with NumPy boundaries)
|
|
65
|
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
|
66
|
+
|
|
67
|
+
These modules use the Array API for data manipulation but fall back to NumPy
|
|
68
|
+
at specific points where a dependency requires it.
|
|
69
|
+
|
|
70
|
+
.. list-table::
|
|
71
|
+
:header-rows: 1
|
|
72
|
+
:widths: 25 35 40
|
|
73
|
+
|
|
74
|
+
* - Module
|
|
75
|
+
- NumPy boundary
|
|
76
|
+
- Reason
|
|
77
|
+
* - ``model.refit_kalman``
|
|
78
|
+
- ``_compute_gain()``
|
|
79
|
+
- ``scipy.linalg.solve_discrete_are`` has no Array API equivalent.
|
|
80
|
+
Matrices are converted to NumPy for the DARE solver, then converted back.
|
|
81
|
+
* - ``model.refit_kalman``
|
|
82
|
+
- ``refit()`` mutation loop
|
|
83
|
+
- Per-sample velocity remapping uses ``np.linalg.norm`` on small vectors
|
|
84
|
+
and scalar element assignment.
|
|
85
|
+
* - ``process.refit_kalman``
|
|
86
|
+
- Inherits boundaries from model
|
|
87
|
+
- State init and output arrays use the source namespace.
|
|
88
|
+
* - ``process.slda``
|
|
89
|
+
- ``predict_proba``
|
|
90
|
+
- sklearn ``LinearDiscriminantAnalysis`` requires NumPy input.
|
|
91
|
+
* - ``process.adaptive_linear_regressor``
|
|
92
|
+
- ``partial_fit`` / ``predict``
|
|
93
|
+
- sklearn and river models require NumPy / pandas input.
|
|
94
|
+
* - ``dim_reduce.adaptive_decomp``
|
|
95
|
+
- ``partial_fit`` / ``transform``
|
|
96
|
+
- sklearn ``IncrementalPCA`` and ``MiniBatchNMF`` require NumPy input.
|
|
97
|
+
|
|
98
|
+
Not converted
|
|
99
|
+
^^^^^^^^^^^^^
|
|
100
|
+
|
|
101
|
+
These modules use NumPy directly. Conversion would provide little benefit
|
|
102
|
+
because the underlying estimator is the bottleneck.
|
|
103
|
+
|
|
104
|
+
.. list-table::
|
|
105
|
+
:header-rows: 1
|
|
106
|
+
:widths: 25 75
|
|
107
|
+
|
|
108
|
+
* - Module
|
|
109
|
+
- Reason
|
|
110
|
+
* - ``process.linear_regressor``
|
|
111
|
+
- Thin wrapper around sklearn ``LinearModel.predict``.
|
|
112
|
+
Could be made compatible if sklearn's ``array_api_dispatch`` is enabled
|
|
113
|
+
(see below).
|
|
114
|
+
* - ``process.sgd``
|
|
115
|
+
- sklearn ``SGDClassifier`` has no Array API support.
|
|
116
|
+
* - ``process.sklearn``
|
|
117
|
+
- Generic wrapper for arbitrary models; cannot assume Array API support.
|
|
118
|
+
* - ``dim_reduce.incremental_decomp``
|
|
119
|
+
- Delegates to ``adaptive_decomp``; trivial numpy usage (``np.prod`` on
|
|
120
|
+
Python tuples).
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
sklearn Array API Dispatch
|
|
124
|
+
--------------------------
|
|
125
|
+
|
|
126
|
+
scikit-learn 1.8+ has experimental support for Array API dispatch on a subset
|
|
127
|
+
of estimators. Two estimators used in ezmsg-learn are on the supported list:
|
|
128
|
+
|
|
129
|
+
.. list-table::
|
|
130
|
+
:header-rows: 1
|
|
131
|
+
:widths: 30 30 40
|
|
132
|
+
|
|
133
|
+
* - Estimator
|
|
134
|
+
- Used in
|
|
135
|
+
- Constraint
|
|
136
|
+
* - ``LinearDiscriminantAnalysis``
|
|
137
|
+
- ``process.slda``
|
|
138
|
+
- Requires ``solver="svd"`` (the ``"lsqr"`` solver with ``shrinkage``
|
|
139
|
+
is not supported)
|
|
140
|
+
* - ``Ridge``
|
|
141
|
+
- ``process.linear_regressor``
|
|
142
|
+
- Requires ``solver="svd"``
|
|
143
|
+
|
|
144
|
+
To use dispatch, enable it before creating the estimator:
|
|
145
|
+
|
|
146
|
+
.. code-block:: python
|
|
147
|
+
|
|
148
|
+
from sklearn import set_config
|
|
149
|
+
set_config(array_api_dispatch=True)
|
|
150
|
+
|
|
151
|
+
.. warning::
|
|
152
|
+
|
|
153
|
+
- ``array_api_dispatch`` is marked **experimental** in sklearn.
|
|
154
|
+
- Solver constraints (``solver="svd"``) may produce slightly different
|
|
155
|
+
numerical results compared to other solvers.
|
|
156
|
+
- Enabling dispatch globally may affect other sklearn estimators in the
|
|
157
|
+
same process.
|
|
158
|
+
- ezmsg-learn does **not** enable dispatch by default.
|
|
159
|
+
|
|
160
|
+
Estimators that do **not** support Array API dispatch:
|
|
161
|
+
|
|
162
|
+
- ``IncrementalPCA``, ``MiniBatchNMF`` — only batch ``PCA`` is supported
|
|
163
|
+
- ``SGDClassifier``, ``SGDRegressor``, ``PassiveAggressiveRegressor``
|
|
164
|
+
- All river models
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
Writing Array API Compatible Code
|
|
168
|
+
----------------------------------
|
|
169
|
+
|
|
170
|
+
When adding or modifying processors in ezmsg-learn, follow these patterns.
|
|
171
|
+
|
|
172
|
+
Deriving the namespace
|
|
173
|
+
^^^^^^^^^^^^^^^^^^^^^^
|
|
174
|
+
|
|
175
|
+
Always derive ``xp`` from the input data, not from a hardcoded ``numpy``:
|
|
176
|
+
|
|
177
|
+
.. code-block:: python
|
|
178
|
+
|
|
179
|
+
from array_api_compat import get_namespace
|
|
180
|
+
from ezmsg.sigproc.util.array import array_device, xp_create
|
|
181
|
+
|
|
182
|
+
def _process(self, message):
|
|
183
|
+
xp = get_namespace(message.data)
|
|
184
|
+
dev = array_device(message.data)
|
|
185
|
+
|
|
186
|
+
Transposing matrices
|
|
187
|
+
^^^^^^^^^^^^^^^^^^^^
|
|
188
|
+
|
|
189
|
+
The Array API standard defines ``.T`` only for two-dimensional arrays. For portable code, use ``xp.linalg.matrix_transpose()``, which transposes the last two axes:
|
|
190
|
+
|
|
191
|
+
.. code-block:: python
|
|
192
|
+
|
|
193
|
+
# Before (numpy-only)
|
|
194
|
+
result = A.T @ B
|
|
195
|
+
|
|
196
|
+
# After (Array API)
|
|
197
|
+
_mT = xp.linalg.matrix_transpose
|
|
198
|
+
result = _mT(A) @ B
|
|
199
|
+
|
|
200
|
+
Creating arrays
|
|
201
|
+
^^^^^^^^^^^^^^^
|
|
202
|
+
|
|
203
|
+
Use ``xp_create`` to handle device placement portably:
|
|
204
|
+
|
|
205
|
+
.. code-block:: python
|
|
206
|
+
|
|
207
|
+
# Before
|
|
208
|
+
I = np.eye(n)
|
|
209
|
+
z = np.zeros((m, n), dtype=np.float64)
|
|
210
|
+
|
|
211
|
+
# After
|
|
212
|
+
I = xp_create(xp.eye, n, device=dev)
|
|
213
|
+
z = xp_create(xp.zeros, (m, n), dtype=xp.float64, device=dev)
|
|
214
|
+
|
|
215
|
+
Handling sklearn boundaries
|
|
216
|
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
|
217
|
+
|
|
218
|
+
When calling into sklearn (or other NumPy-only libraries), convert at the
|
|
219
|
+
boundary and convert back:
|
|
220
|
+
|
|
221
|
+
.. code-block:: python
|
|
222
|
+
|
|
223
|
+
from array_api_compat import is_numpy_array
|
|
224
|
+
|
|
225
|
+
# Convert to numpy for sklearn
|
|
226
|
+
X_np = np.asarray(X) if not is_numpy_array(X) else X
|
|
227
|
+
result_np = estimator.predict(X_np)
|
|
228
|
+
|
|
229
|
+
# Convert back to source namespace
|
|
230
|
+
result = xp.asarray(result_np) if not is_numpy_array(X) else result_np
|
|
231
|
+
|
|
232
|
+
Checking for NaN
|
|
233
|
+
^^^^^^^^^^^^^^^^
|
|
234
|
+
|
|
235
|
+
Use ``xp.isnan`` instead of ``np.isnan``:
|
|
236
|
+
|
|
237
|
+
.. code-block:: python
|
|
238
|
+
|
|
239
|
+
if xp.any(xp.isnan(message.data)):
|
|
240
|
+
return
|
|
241
|
+
|
|
242
|
+
Norms
|
|
243
|
+
^^^^^
|
|
244
|
+
|
|
245
|
+
Use ``xp.linalg.matrix_norm`` (Frobenius by default) instead of
|
|
246
|
+
``np.linalg.norm`` for matrices. For vectors, use ``xp.linalg.vector_norm``.
|
|
@@ -125,7 +125,8 @@ For models that support ``partial_fit``, you can update them during streaming:
|
|
|
125
125
|
.. code-block:: python
|
|
126
126
|
|
|
127
127
|
from ezmsg.learn.process.sklearn import SklearnModelProcessor, SklearnModelSettings
|
|
128
|
-
from ezmsg.
|
|
128
|
+
from ezmsg.baseproc import SampleTriggerMessage
|
|
129
|
+
from ezmsg.util.messages.util import replace
|
|
129
130
|
|
|
130
131
|
# Create processor with online learning support
|
|
131
132
|
processor = SklearnModelProcessor(
|
|
@@ -137,9 +138,9 @@ For models that support ``partial_fit``, you can update them during streaming:
|
|
|
137
138
|
)
|
|
138
139
|
|
|
139
140
|
# Training with labeled samples
|
|
140
|
-
sample_msg =
|
|
141
|
-
|
|
142
|
-
trigger=label_value
|
|
141
|
+
sample_msg = replace(
|
|
142
|
+
feature_array, # AxisArray with features
|
|
143
|
+
attrs={"trigger": SampleTriggerMessage(value=label_value)}
|
|
143
144
|
)
|
|
144
145
|
processor.partial_fit(sample_msg)
|
|
145
146
|
|
|
@@ -9,8 +9,9 @@ license = "MIT"
|
|
|
9
9
|
requires-python = ">=3.10.15"
|
|
10
10
|
dynamic = ["version"]
|
|
11
11
|
dependencies = [
|
|
12
|
-
"ezmsg
|
|
13
|
-
"ezmsg-
|
|
12
|
+
"ezmsg>=3.7.3",
|
|
13
|
+
"ezmsg-baseproc>=1.5.1",
|
|
14
|
+
"ezmsg-sigproc>=2.17.0",
|
|
14
15
|
"river>=0.22.0",
|
|
15
16
|
"scikit-learn>=1.6.0",
|
|
16
17
|
"torch>=2.6.0",
|
|
@@ -73,5 +74,4 @@ known-third-party = ["ezmsg", "ezmsg.baseproc", "ezmsg.sigproc"]
|
|
|
73
74
|
|
|
74
75
|
[tool.uv.sources]
|
|
75
76
|
# Uncomment to use development version of ezmsg from git
|
|
76
|
-
#ezmsg = { git = "https://github.com/ezmsg-org/ezmsg.git", branch = "feature/profiling" }
|
|
77
|
-
#ezmsg-sigproc = { path = "../ezmsg-sigproc", editable = true }
|
|
77
|
+
#ezmsg = { git = "https://github.com/ezmsg-org/ezmsg.git", branch = "feature/profiling" }
|
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '1.
|
|
32
|
-
__version_tuple__ = version_tuple = (1,
|
|
31
|
+
__version__ = version = '1.4.0'
|
|
32
|
+
__version_tuple__ = version_tuple = (1, 4, 0)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
|
@@ -1,7 +1,18 @@
|
|
|
1
|
+
"""Adaptive decomposition transformers (PCA, NMF).
|
|
2
|
+
|
|
3
|
+
.. note::
|
|
4
|
+
This module supports the Array API standard via
|
|
5
|
+
``array_api_compat.get_namespace()``. Reshaping and output allocation
|
|
6
|
+
use Array API operations; a NumPy boundary is applied before sklearn
|
|
7
|
+
``partial_fit``/``transform`` calls.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import math
|
|
1
11
|
import typing
|
|
2
12
|
|
|
3
13
|
import ezmsg.core as ez
|
|
4
14
|
import numpy as np
|
|
15
|
+
from array_api_compat import get_namespace, is_numpy_array
|
|
5
16
|
from ezmsg.baseproc import (
|
|
6
17
|
BaseAdaptiveTransformer,
|
|
7
18
|
BaseAdaptiveTransformerUnit,
|
|
@@ -128,6 +139,8 @@ class AdaptiveDecompTransformer(
|
|
|
128
139
|
if in_dat.shape[ax_idx] == 0:
|
|
129
140
|
return self._state.template
|
|
130
141
|
|
|
142
|
+
xp = get_namespace(in_dat)
|
|
143
|
+
|
|
131
144
|
# Re-order axes
|
|
132
145
|
sorted_dims_exp = [iter_axis] + off_targ_axes + targ_axes
|
|
133
146
|
if message.dims != sorted_dims_exp:
|
|
@@ -137,16 +150,20 @@ class AdaptiveDecompTransformer(
|
|
|
137
150
|
pass
|
|
138
151
|
|
|
139
152
|
# fold [iter_axis] + off_targ_axes together and fold targ_axes together
|
|
140
|
-
d2 =
|
|
141
|
-
in_dat =
|
|
153
|
+
d2 = math.prod(in_dat.shape[len(off_targ_axes) + 1 :])
|
|
154
|
+
in_dat = xp.reshape(in_dat, (-1, d2))
|
|
142
155
|
|
|
143
156
|
replace_kwargs = {
|
|
144
157
|
"axes": {**self._state.template.axes, iter_axis: message.axes[iter_axis]},
|
|
145
158
|
}
|
|
146
159
|
|
|
147
|
-
# Transform data
|
|
160
|
+
# Transform data — sklearn needs numpy
|
|
148
161
|
if hasattr(self._state.estimator, "components_"):
|
|
149
|
-
|
|
162
|
+
in_np = np.asarray(in_dat) if not is_numpy_array(in_dat) else in_dat
|
|
163
|
+
decomp_dat = self._state.estimator.transform(in_np)
|
|
164
|
+
# Convert back to source namespace
|
|
165
|
+
decomp_dat = xp.asarray(decomp_dat) if not is_numpy_array(in_dat) else decomp_dat
|
|
166
|
+
decomp_dat = xp.reshape(decomp_dat, (-1,) + self._state.template.data.shape[1:])
|
|
150
167
|
replace_kwargs["data"] = decomp_dat
|
|
151
168
|
|
|
152
169
|
return replace(self._state.template, **replace_kwargs)
|
|
@@ -165,6 +182,8 @@ class AdaptiveDecompTransformer(
|
|
|
165
182
|
if in_dat.shape[ax_idx] == 0:
|
|
166
183
|
return
|
|
167
184
|
|
|
185
|
+
xp = get_namespace(in_dat)
|
|
186
|
+
|
|
168
187
|
# Re-order axes if needed
|
|
169
188
|
sorted_dims_exp = [iter_axis] + off_targ_axes + targ_axes
|
|
170
189
|
if message.dims != sorted_dims_exp:
|
|
@@ -172,11 +191,12 @@ class AdaptiveDecompTransformer(
|
|
|
172
191
|
pass
|
|
173
192
|
|
|
174
193
|
# fold [iter_axis] + off_targ_axes together and fold targ_axes together
|
|
175
|
-
d2 =
|
|
176
|
-
in_dat =
|
|
194
|
+
d2 = math.prod(in_dat.shape[len(off_targ_axes) + 1 :])
|
|
195
|
+
in_dat = xp.reshape(in_dat, (-1, d2))
|
|
177
196
|
|
|
178
|
-
# Fit the estimator
|
|
179
|
-
|
|
197
|
+
# Fit the estimator — sklearn needs numpy
|
|
198
|
+
in_np = np.asarray(in_dat) if not is_numpy_array(in_dat) else in_dat
|
|
199
|
+
self._state.estimator.partial_fit(in_np)
|
|
180
200
|
|
|
181
201
|
|
|
182
202
|
class IncrementalPCASettings(AdaptiveDecompSettings):
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
"""Incremental Canonical Correlation Analysis (CCA).
|
|
2
|
+
|
|
3
|
+
.. note::
|
|
4
|
+
This module supports the Array API standard via
|
|
5
|
+
``array_api_compat.get_namespace()``. All linear algebra uses Array API
|
|
6
|
+
operations; ``scipy.linalg.sqrtm`` is replaced by an eigendecomposition-
|
|
7
|
+
based inverse square root (:func:`_inv_sqrtm_spd`).
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
from array_api_compat import get_namespace
|
|
12
|
+
from ezmsg.sigproc.util.array import array_device, xp_create
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _inv_sqrtm_spd(xp, A):
|
|
16
|
+
"""Inverse matrix square root for symmetric positive-definite matrices.
|
|
17
|
+
|
|
18
|
+
Computes ``inv(sqrtm(A)) = Q @ diag(1/sqrt(lambda)) @ Q^T`` using the
|
|
19
|
+
eigendecomposition. This is more numerically stable than computing
|
|
20
|
+
``inv(sqrtm(...))`` separately and uses only Array API operations.
|
|
21
|
+
"""
|
|
22
|
+
eigenvalues, eigenvectors = xp.linalg.eigh(A)
|
|
23
|
+
eigenvalues = xp.clip(eigenvalues, 1e-12, None) # avoid div-by-zero
|
|
24
|
+
inv_sqrt_eig = 1.0 / xp.sqrt(eigenvalues)
|
|
25
|
+
# Q @ diag(v) == Q * v (broadcasting), then @ Q^T
|
|
26
|
+
return (eigenvectors * inv_sqrt_eig) @ xp.linalg.matrix_transpose(eigenvectors)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class IncrementalCCA:
|
|
30
|
+
def __init__(
|
|
31
|
+
self,
|
|
32
|
+
n_components=2,
|
|
33
|
+
base_smoothing=0.95,
|
|
34
|
+
min_smoothing=0.5,
|
|
35
|
+
max_smoothing=0.99,
|
|
36
|
+
adaptation_rate=0.1,
|
|
37
|
+
):
|
|
38
|
+
"""
|
|
39
|
+
Parameters:
|
|
40
|
+
-----------
|
|
41
|
+
n_components : int
|
|
42
|
+
Number of canonical components to compute
|
|
43
|
+
base_smoothing : float
|
|
44
|
+
Base smoothing factor (will be adapted)
|
|
45
|
+
min_smoothing : float
|
|
46
|
+
Minimum allowed smoothing factor
|
|
47
|
+
max_smoothing : float
|
|
48
|
+
Maximum allowed smoothing factor
|
|
49
|
+
adaptation_rate : float
|
|
50
|
+
How quickly to adjust smoothing factor (between 0 and 1)
|
|
51
|
+
"""
|
|
52
|
+
self.n_components = n_components
|
|
53
|
+
self.base_smoothing = base_smoothing
|
|
54
|
+
self.current_smoothing = base_smoothing
|
|
55
|
+
self.min_smoothing = min_smoothing
|
|
56
|
+
self.max_smoothing = max_smoothing
|
|
57
|
+
self.adaptation_rate = adaptation_rate
|
|
58
|
+
self.initialized = False
|
|
59
|
+
|
|
60
|
+
def initialize(self, d1, d2, *, ref_array=None):
|
|
61
|
+
"""Initialize the necessary matrices.
|
|
62
|
+
|
|
63
|
+
Args:
|
|
64
|
+
d1: Dimensionality of the first dataset.
|
|
65
|
+
d2: Dimensionality of the second dataset.
|
|
66
|
+
ref_array: Optional reference array to derive array namespace
|
|
67
|
+
and device from. If ``None``, defaults to NumPy.
|
|
68
|
+
"""
|
|
69
|
+
self.d1 = d1
|
|
70
|
+
self.d2 = d2
|
|
71
|
+
|
|
72
|
+
if ref_array is not None:
|
|
73
|
+
xp = get_namespace(ref_array)
|
|
74
|
+
dev = array_device(ref_array)
|
|
75
|
+
else:
|
|
76
|
+
xp, dev = np, None
|
|
77
|
+
|
|
78
|
+
# Initialize correlation matrices
|
|
79
|
+
self.C11 = xp_create(xp.zeros, (d1, d1), dtype=xp.float64, device=dev)
|
|
80
|
+
self.C22 = xp_create(xp.zeros, (d2, d2), dtype=xp.float64, device=dev)
|
|
81
|
+
self.C12 = xp_create(xp.zeros, (d1, d2), dtype=xp.float64, device=dev)
|
|
82
|
+
|
|
83
|
+
self.initialized = True
|
|
84
|
+
|
|
85
|
+
def _compute_change_magnitude(self, C11_new, C22_new, C12_new):
|
|
86
|
+
"""Compute magnitude of change in correlation structure."""
|
|
87
|
+
xp = get_namespace(self.C11)
|
|
88
|
+
|
|
89
|
+
# Frobenius norm of differences
|
|
90
|
+
diff11 = xp.linalg.matrix_norm(C11_new - self.C11)
|
|
91
|
+
diff22 = xp.linalg.matrix_norm(C22_new - self.C22)
|
|
92
|
+
diff12 = xp.linalg.matrix_norm(C12_new - self.C12)
|
|
93
|
+
|
|
94
|
+
# Normalize by matrix sizes
|
|
95
|
+
diff11 = diff11 / (self.d1 * self.d1)
|
|
96
|
+
diff22 = diff22 / (self.d2 * self.d2)
|
|
97
|
+
diff12 = diff12 / (self.d1 * self.d2)
|
|
98
|
+
|
|
99
|
+
return float((diff11 + diff22 + diff12) / 3)
|
|
100
|
+
|
|
101
|
+
def _adapt_smoothing(self, change_magnitude):
|
|
102
|
+
"""Adapt smoothing factor based on detected changes."""
|
|
103
|
+
# If change is large, decrease smoothing factor
|
|
104
|
+
target_smoothing = self.base_smoothing * (1.0 - change_magnitude)
|
|
105
|
+
target_smoothing = max(self.min_smoothing, min(target_smoothing, self.max_smoothing))
|
|
106
|
+
|
|
107
|
+
# Smooth the adaptation itself
|
|
108
|
+
self.current_smoothing = (
|
|
109
|
+
1 - self.adaptation_rate
|
|
110
|
+
) * self.current_smoothing + self.adaptation_rate * target_smoothing
|
|
111
|
+
|
|
112
|
+
def partial_fit(self, X1, X2, update_projections=True):
|
|
113
|
+
"""Update the model with new samples using adaptive smoothing.
|
|
114
|
+
Assumes X1 and X2 are already centered and scaled."""
|
|
115
|
+
xp = get_namespace(X1, X2)
|
|
116
|
+
_mT = xp.linalg.matrix_transpose
|
|
117
|
+
|
|
118
|
+
if not self.initialized:
|
|
119
|
+
self.initialize(X1.shape[1], X2.shape[1], ref_array=X1)
|
|
120
|
+
|
|
121
|
+
# Compute new correlation matrices from current batch
|
|
122
|
+
C11_new = _mT(X1) @ X1 / X1.shape[0]
|
|
123
|
+
C22_new = _mT(X2) @ X2 / X2.shape[0]
|
|
124
|
+
C12_new = _mT(X1) @ X2 / X1.shape[0]
|
|
125
|
+
|
|
126
|
+
# Detect changes and adapt smoothing factor
|
|
127
|
+
if bool(xp.any(self.C11 != 0)): # Skip first update
|
|
128
|
+
change_magnitude = self._compute_change_magnitude(C11_new, C22_new, C12_new)
|
|
129
|
+
self._adapt_smoothing(change_magnitude)
|
|
130
|
+
|
|
131
|
+
# Update with current smoothing factor
|
|
132
|
+
alpha = self.current_smoothing
|
|
133
|
+
self.C11 = alpha * self.C11 + (1 - alpha) * C11_new
|
|
134
|
+
self.C22 = alpha * self.C22 + (1 - alpha) * C22_new
|
|
135
|
+
self.C12 = alpha * self.C12 + (1 - alpha) * C12_new
|
|
136
|
+
|
|
137
|
+
if update_projections:
|
|
138
|
+
self._update_projections()
|
|
139
|
+
|
|
140
|
+
def _update_projections(self):
|
|
141
|
+
"""Update canonical vectors and correlations."""
|
|
142
|
+
xp = get_namespace(self.C11)
|
|
143
|
+
dev = array_device(self.C11)
|
|
144
|
+
_mT = xp.linalg.matrix_transpose
|
|
145
|
+
|
|
146
|
+
eps = 1e-8
|
|
147
|
+
C11_reg = self.C11 + eps * xp_create(xp.eye, self.d1, dtype=self.C11.dtype, device=dev)
|
|
148
|
+
C22_reg = self.C22 + eps * xp_create(xp.eye, self.d2, dtype=self.C22.dtype, device=dev)
|
|
149
|
+
|
|
150
|
+
inv_sqrt_C11 = _inv_sqrtm_spd(xp, C11_reg)
|
|
151
|
+
inv_sqrt_C22 = _inv_sqrtm_spd(xp, C22_reg)
|
|
152
|
+
|
|
153
|
+
K = inv_sqrt_C11 @ self.C12 @ inv_sqrt_C22
|
|
154
|
+
U, self.correlations_, Vh = xp.linalg.svd(K, full_matrices=False)
|
|
155
|
+
|
|
156
|
+
self.x_weights_ = inv_sqrt_C11 @ U[:, : self.n_components]
|
|
157
|
+
self.y_weights_ = inv_sqrt_C22 @ _mT(Vh)[:, : self.n_components]
|
|
158
|
+
|
|
159
|
+
def transform(self, X1, X2):
|
|
160
|
+
"""Project data onto canonical components."""
|
|
161
|
+
X1_proj = X1 @ self.x_weights_
|
|
162
|
+
X2_proj = X2 @ self.y_weights_
|
|
163
|
+
return X1_proj, X2_proj
|