ezmsg-learn 1.3.0__tar.gz → 1.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75) hide show
  1. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/PKG-INFO +4 -3
  2. ezmsg_learn-1.4.0/docs/source/guides/array_api.rst +246 -0
  3. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/docs/source/guides/classification.rst +5 -4
  4. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/docs/source/index.rst +1 -0
  5. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/pyproject.toml +4 -4
  6. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/__version__.py +2 -2
  7. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/dim_reduce/adaptive_decomp.py +28 -8
  8. ezmsg_learn-1.4.0/src/ezmsg/learn/model/cca.py +163 -0
  9. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/model/refit_kalman.py +90 -69
  10. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/process/adaptive_linear_regressor.py +32 -9
  11. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/process/refit_kalman.py +56 -26
  12. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/process/slda.py +22 -6
  13. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/process/ssr.py +0 -5
  14. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/tests/unit/test_adaptive_linear_regressor.py +2 -2
  15. ezmsg_learn-1.4.0/tests/unit/test_cca.py +125 -0
  16. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/tests/unit/test_linear_regressor.py +3 -3
  17. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/tests/unit/test_mlp_old.py +1 -2
  18. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/tests/unit/test_rnn.py +1 -1
  19. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/tests/unit/test_sgd.py +2 -2
  20. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/tests/unit/test_ssr.py +4 -4
  21. ezmsg_learn-1.3.0/src/ezmsg/learn/model/cca.py +0 -122
  22. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/.github/workflows/docs.yml +0 -0
  23. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/.github/workflows/python-publish.yml +0 -0
  24. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/.github/workflows/python-tests.yml +0 -0
  25. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/.gitignore +0 -0
  26. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/.pre-commit-config.yaml +0 -0
  27. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/LICENSE +0 -0
  28. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/README.md +0 -0
  29. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/docs/Makefile +0 -0
  30. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/docs/make.bat +0 -0
  31. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/docs/source/_templates/autosummary/module.rst +0 -0
  32. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/docs/source/api/index.rst +0 -0
  33. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/docs/source/conf.py +0 -0
  34. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/__init__.py +0 -0
  35. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/dim_reduce/__init__.py +0 -0
  36. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/dim_reduce/incremental_decomp.py +0 -0
  37. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/linear_model/__init__.py +0 -0
  38. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/linear_model/adaptive_linear_regressor.py +0 -0
  39. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/linear_model/cca.py +0 -0
  40. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/linear_model/linear_regressor.py +0 -0
  41. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/linear_model/sgd.py +0 -0
  42. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/linear_model/slda.py +0 -0
  43. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/model/__init__.py +0 -0
  44. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/model/mlp.py +0 -0
  45. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/model/mlp_old.py +0 -0
  46. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/model/rnn.py +0 -0
  47. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/model/transformer.py +0 -0
  48. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/nlin_model/__init__.py +0 -0
  49. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/nlin_model/mlp.py +0 -0
  50. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/process/__init__.py +0 -0
  51. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/process/base.py +0 -0
  52. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/process/linear_regressor.py +0 -0
  53. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/process/mlp_old.py +0 -0
  54. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/process/rnn.py +0 -0
  55. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/process/sgd.py +0 -0
  56. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/process/sklearn.py +0 -0
  57. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/process/torch.py +0 -0
  58. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/process/transformer.py +0 -0
  59. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/src/ezmsg/learn/util.py +0 -0
  60. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/tests/benchmark/bench_lrr.py +0 -0
  61. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/tests/dim_reduce/test_adaptive_decomp.py +0 -0
  62. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/tests/dim_reduce/test_incremental_decomp.py +0 -0
  63. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/tests/integration/conftest.py +0 -0
  64. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/tests/integration/test_mlp_system.py +0 -0
  65. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/tests/integration/test_refit_kalman_system.py +0 -0
  66. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/tests/integration/test_rnn_system.py +0 -0
  67. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/tests/integration/test_sklearn_system.py +0 -0
  68. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/tests/integration/test_torch_system.py +0 -0
  69. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/tests/integration/test_transformer_system.py +0 -0
  70. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/tests/unit/test_mlp.py +0 -0
  71. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/tests/unit/test_refit_kalman.py +0 -0
  72. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/tests/unit/test_sklearn.py +0 -0
  73. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/tests/unit/test_slda.py +0 -0
  74. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/tests/unit/test_torch.py +0 -0
  75. {ezmsg_learn-1.3.0 → ezmsg_learn-1.4.0}/tests/unit/test_transformer.py +0 -0
@@ -1,13 +1,14 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ezmsg-learn
3
- Version: 1.3.0
3
+ Version: 1.4.0
4
4
  Summary: ezmsg namespace package for machine learning
5
5
  Author-email: Chadwick Boulay <chadwick.boulay@gmail.com>
6
6
  License-Expression: MIT
7
7
  License-File: LICENSE
8
8
  Requires-Python: >=3.10.15
9
- Requires-Dist: ezmsg-baseproc>=1.3.0
10
- Requires-Dist: ezmsg-sigproc>=2.15.0
9
+ Requires-Dist: ezmsg-baseproc>=1.5.1
10
+ Requires-Dist: ezmsg-sigproc>=2.17.0
11
+ Requires-Dist: ezmsg>=3.7.3
11
12
  Requires-Dist: river>=0.22.0
12
13
  Requires-Dist: scikit-learn>=1.6.0
13
14
  Requires-Dist: torch>=2.6.0
@@ -0,0 +1,246 @@
1
+ Array API Compatibility
2
+ =======================
3
+
4
+ ezmsg-learn uses the `Array API standard <https://data-apis.org/array-api/latest/>`_
5
+ to allow processors to operate on arrays from different backends — NumPy, CuPy,
6
+ PyTorch, and others — without code changes.
7
+
8
+ .. contents:: On this page
9
+ :local:
10
+ :depth: 2
11
+
12
+
13
+ How It Works
14
+ ------------
15
+
16
+ Modules that support the Array API derive the array namespace from their input
17
+ data using ``array_api_compat.get_namespace()``:
18
+
19
+ .. code-block:: python
20
+
21
+ from array_api_compat import get_namespace
22
+
23
+ def process(self, data):
24
+ xp = get_namespace(data) # numpy, cupy, torch, etc.
25
+ result = xp.linalg.inv(data) # dispatches to the right backend
26
+ return result
27
+
28
+ This means that if you pass a CuPy array, all computation stays on the GPU.
29
+ If you pass a NumPy array, it behaves exactly as before.
30
+
31
+ Helper utilities from ``ezmsg.sigproc.util.array`` handle device placement
32
+ and creation functions portably:
33
+
34
+ - ``array_device(x)`` — returns the device of an array, or ``None``
35
+ - ``xp_create(fn, *args, dtype=None, device=None)`` — calls creation
36
+ functions (``zeros``, ``eye``) with optional device
37
+ - ``xp_asarray(xp, obj, dtype=None, device=None)`` — portable ``asarray``
38
+
39
+
40
+ Module Compatibility
41
+ --------------------
42
+
43
+ The tables below summarise the Array API status of each module.
44
+
45
+ Fully compatible
46
+ ^^^^^^^^^^^^^^^^
47
+
48
+ These modules perform all computation in the source array namespace.
49
+
50
+ .. list-table::
51
+ :header-rows: 1
52
+ :widths: 35 65
53
+
54
+ * - Module
55
+ - Notes
56
+ * - ``process.ssr``
57
+ - LRR / self-supervised regression. Full Array API.
58
+ * - ``model.cca``
59
+ - Incremental CCA. Replaced ``scipy.linalg.sqrtm`` with an
60
+ eigendecomposition-based inverse square root using only Array API ops.
61
+ * - ``process.rnn``
62
+ - PyTorch-native; operates on ``torch.Tensor`` throughout.
63
+
64
+ Mostly compatible (with NumPy boundaries)
65
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
66
+
67
+ These modules use the Array API for data manipulation but fall back to NumPy
68
+ at specific points where a dependency requires it.
69
+
70
+ .. list-table::
71
+ :header-rows: 1
72
+ :widths: 25 35 40
73
+
74
+ * - Module
75
+ - NumPy boundary
76
+ - Reason
77
+ * - ``model.refit_kalman``
78
+ - ``_compute_gain()``
79
+ - ``scipy.linalg.solve_discrete_are`` has no Array API equivalent.
80
+ Matrices are converted to NumPy for the DARE solver, then converted back.
81
+ * - ``model.refit_kalman``
82
+ - ``refit()`` mutation loop
83
+ - Per-sample velocity remapping uses ``np.linalg.norm`` on small vectors
84
+ and scalar element assignment.
85
+ * - ``process.refit_kalman``
86
+ - Inherits boundaries from model
87
+ - State init and output arrays use the source namespace.
88
+ * - ``process.slda``
89
+ - ``predict_proba``
90
+ - sklearn ``LinearDiscriminantAnalysis`` requires NumPy input.
91
+ * - ``process.adaptive_linear_regressor``
92
+ - ``partial_fit`` / ``predict``
93
+ - sklearn and river models require NumPy / pandas input.
94
+ * - ``dim_reduce.adaptive_decomp``
95
+ - ``partial_fit`` / ``transform``
96
+ - sklearn ``IncrementalPCA`` and ``MiniBatchNMF`` require NumPy input.
97
+
98
+ Not converted
99
+ ^^^^^^^^^^^^^
100
+
101
+ These modules use NumPy directly. Conversion would provide little benefit
102
+ because the underlying estimator is the bottleneck.
103
+
104
+ .. list-table::
105
+ :header-rows: 1
106
+ :widths: 25 75
107
+
108
+ * - Module
109
+ - Reason
110
+ * - ``process.linear_regressor``
111
+ - Thin wrapper around sklearn ``LinearModel.predict``.
112
+ Could be made compatible if sklearn's ``array_api_dispatch`` is enabled
113
+ (see below).
114
+ * - ``process.sgd``
115
+ - sklearn ``SGDClassifier`` has no Array API support.
116
+ * - ``process.sklearn``
117
+ - Generic wrapper for arbitrary models; cannot assume Array API support.
118
+ * - ``dim_reduce.incremental_decomp``
119
+ - Delegates to ``adaptive_decomp``; trivial numpy usage (``np.prod`` on
120
+ Python tuples).
121
+
122
+
123
+ sklearn Array API Dispatch
124
+ --------------------------
125
+
126
+ scikit-learn 1.8+ has experimental support for Array API dispatch on a subset
127
+ of estimators. Two estimators used in ezmsg-learn are on the supported list:
128
+
129
+ .. list-table::
130
+ :header-rows: 1
131
+ :widths: 30 30 40
132
+
133
+ * - Estimator
134
+ - Used in
135
+ - Constraint
136
+ * - ``LinearDiscriminantAnalysis``
137
+ - ``process.slda``
138
+ - Requires ``solver="svd"`` (the ``"lsqr"`` solver with ``shrinkage``
139
+ is not supported)
140
+ * - ``Ridge``
141
+ - ``process.linear_regressor``
142
+ - Requires ``solver="svd"``
143
+
144
+ To use dispatch, enable it before creating the estimator:
145
+
146
+ .. code-block:: python
147
+
148
+ from sklearn import set_config
149
+ set_config(array_api_dispatch=True)
150
+
151
+ .. warning::
152
+
153
+ - ``array_api_dispatch`` is marked **experimental** in sklearn.
154
+ - Solver constraints (``solver="svd"``) may produce slightly different
155
+ numerical results compared to other solvers.
156
+ - Enabling dispatch globally may affect other sklearn estimators in the
157
+ same process.
158
+ - ezmsg-learn does **not** enable dispatch by default.
159
+
160
+ Estimators that do **not** support Array API dispatch:
161
+
162
+ - ``IncrementalPCA``, ``MiniBatchNMF`` — only batch ``PCA`` is supported
163
+ - ``SGDClassifier``, ``SGDRegressor``, ``PassiveAggressiveRegressor``
164
+ - All river models
165
+
166
+
167
+ Writing Array API Compatible Code
168
+ ----------------------------------
169
+
170
+ When adding or modifying processors in ezmsg-learn, follow these patterns.
171
+
172
+ Deriving the namespace
173
+ ^^^^^^^^^^^^^^^^^^^^^^
174
+
175
+ Always derive ``xp`` from the input data, not from a hardcoded ``numpy``:
176
+
177
+ .. code-block:: python
178
+
179
+ from array_api_compat import get_namespace
180
+ from ezmsg.sigproc.util.array import array_device, xp_create
181
+
182
+ def _process(self, message):
183
+ xp = get_namespace(message.data)
184
+ dev = array_device(message.data)
185
+
186
+ Transposing matrices
187
+ ^^^^^^^^^^^^^^^^^^^^
188
+
189
+ The Array API only defines ``.T`` for two-dimensional arrays. Use ``xp.linalg.matrix_transpose()``, which also handles stacked (batched) matrices:
190
+
191
+ .. code-block:: python
192
+
193
+ # Before (numpy-only)
194
+ result = A.T @ B
195
+
196
+ # After (Array API)
197
+ _mT = xp.linalg.matrix_transpose
198
+ result = _mT(A) @ B
199
+
200
+ Creating arrays
201
+ ^^^^^^^^^^^^^^^
202
+
203
+ Use ``xp_create`` to handle device placement portably:
204
+
205
+ .. code-block:: python
206
+
207
+ # Before
208
+ I = np.eye(n)
209
+ z = np.zeros((m, n), dtype=np.float64)
210
+
211
+ # After
212
+ I = xp_create(xp.eye, n, device=dev)
213
+ z = xp_create(xp.zeros, (m, n), dtype=xp.float64, device=dev)
214
+
215
+ Handling sklearn boundaries
216
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^
217
+
218
+ When calling into sklearn (or other NumPy-only libraries), convert at the
219
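+ boundary and convert back (arrays that live on a device, e.g. CuPy arrays or CUDA tensors, must be copied to host memory explicitly first, because ``np.asarray`` does not do so implicitly):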
220
+
221
+ .. code-block:: python
222
+
223
+ from array_api_compat import is_numpy_array
224
+
225
+ # Convert to numpy for sklearn
226
+ X_np = np.asarray(X) if not is_numpy_array(X) else X
227
+ result_np = estimator.predict(X_np)
228
+
229
+ # Convert back to source namespace
230
+ result = xp.asarray(result_np) if not is_numpy_array(X) else result_np
231
+
232
+ Checking for NaN
233
+ ^^^^^^^^^^^^^^^^
234
+
235
+ Use ``xp.isnan`` instead of ``np.isnan``:
236
+
237
+ .. code-block:: python
238
+
239
+ if xp.any(xp.isnan(message.data)):
240
+ return
241
+
242
+ Norms
243
+ ^^^^^
244
+
245
+ Use ``xp.linalg.matrix_norm`` (Frobenius by default) instead of
246
+ ``np.linalg.norm`` for matrices. For vectors, use ``xp.linalg.vector_norm``.
@@ -125,7 +125,8 @@ For models that support ``partial_fit``, you can update them during streaming:
125
125
  .. code-block:: python
126
126
 
127
127
  from ezmsg.learn.process.sklearn import SklearnModelProcessor, SklearnModelSettings
128
- from ezmsg.sigproc.sampler import SampleMessage
128
+ from ezmsg.baseproc import SampleTriggerMessage
129
+ from ezmsg.util.messages.util import replace
129
130
 
130
131
  # Create processor with online learning support
131
132
  processor = SklearnModelProcessor(
@@ -137,9 +138,9 @@ For models that support ``partial_fit``, you can update them during streaming:
137
138
  )
138
139
 
139
140
  # Training with labeled samples
140
- sample_msg = SampleMessage(
141
- sample=feature_array, # AxisArray with features
142
- trigger=label_value, # The class label
141
+ sample_msg = replace(
142
+ feature_array, # AxisArray with features
143
+ attrs={"trigger": SampleTriggerMessage(value=label_value)}
143
144
  )
144
145
  processor.partial_fit(sample_msg)
145
146
 
@@ -54,6 +54,7 @@ For general ezmsg tutorials and guides, visit `ezmsg.org <https://www.ezmsg.org>
54
54
  :caption: Contents:
55
55
 
56
56
  guides/classification
57
+ guides/array_api
57
58
  api/index
58
59
 
59
60
 
@@ -9,8 +9,9 @@ license = "MIT"
9
9
  requires-python = ">=3.10.15"
10
10
  dynamic = ["version"]
11
11
  dependencies = [
12
- "ezmsg-baseproc>=1.3.0",
13
- "ezmsg-sigproc>=2.15.0",
12
+ "ezmsg>=3.7.3",
13
+ "ezmsg-baseproc>=1.5.1",
14
+ "ezmsg-sigproc>=2.17.0",
14
15
  "river>=0.22.0",
15
16
  "scikit-learn>=1.6.0",
16
17
  "torch>=2.6.0",
@@ -73,5 +74,4 @@ known-third-party = ["ezmsg", "ezmsg.baseproc", "ezmsg.sigproc"]
73
74
 
74
75
  [tool.uv.sources]
75
76
  # Uncomment to use development version of ezmsg from git
76
- #ezmsg = { git = "https://github.com/ezmsg-org/ezmsg.git", branch = "feature/profiling" }
77
- #ezmsg-sigproc = { path = "../ezmsg-sigproc", editable = true }
77
+ #ezmsg = { git = "https://github.com/ezmsg-org/ezmsg.git", branch = "feature/profiling" }
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '1.3.0'
32
- __version_tuple__ = version_tuple = (1, 3, 0)
31
+ __version__ = version = '1.4.0'
32
+ __version_tuple__ = version_tuple = (1, 4, 0)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -1,7 +1,18 @@
1
+ """Adaptive decomposition transformers (PCA, NMF).
2
+
3
+ .. note::
4
+ This module supports the Array API standard via
5
+ ``array_api_compat.get_namespace()``. Reshaping and output allocation
6
+ use Array API operations; a NumPy boundary is applied before sklearn
7
+ ``partial_fit``/``transform`` calls.
8
+ """
9
+
10
+ import math
1
11
  import typing
2
12
 
3
13
  import ezmsg.core as ez
4
14
  import numpy as np
15
+ from array_api_compat import get_namespace, is_numpy_array
5
16
  from ezmsg.baseproc import (
6
17
  BaseAdaptiveTransformer,
7
18
  BaseAdaptiveTransformerUnit,
@@ -128,6 +139,8 @@ class AdaptiveDecompTransformer(
128
139
  if in_dat.shape[ax_idx] == 0:
129
140
  return self._state.template
130
141
 
142
+ xp = get_namespace(in_dat)
143
+
131
144
  # Re-order axes
132
145
  sorted_dims_exp = [iter_axis] + off_targ_axes + targ_axes
133
146
  if message.dims != sorted_dims_exp:
@@ -137,16 +150,20 @@ class AdaptiveDecompTransformer(
137
150
  pass
138
151
 
139
152
  # fold [iter_axis] + off_targ_axes together and fold targ_axes together
140
- d2 = np.prod(in_dat.shape[len(off_targ_axes) + 1 :])
141
- in_dat = in_dat.reshape((-1, d2))
153
+ d2 = math.prod(in_dat.shape[len(off_targ_axes) + 1 :])
154
+ in_dat = xp.reshape(in_dat, (-1, d2))
142
155
 
143
156
  replace_kwargs = {
144
157
  "axes": {**self._state.template.axes, iter_axis: message.axes[iter_axis]},
145
158
  }
146
159
 
147
- # Transform data
160
+ # Transform data — sklearn needs numpy
148
161
  if hasattr(self._state.estimator, "components_"):
149
- decomp_dat = self._state.estimator.transform(in_dat).reshape((-1,) + self._state.template.data.shape[1:])
162
+ in_np = np.asarray(in_dat) if not is_numpy_array(in_dat) else in_dat
163
+ decomp_dat = self._state.estimator.transform(in_np)
164
+ # Convert back to source namespace
165
+ decomp_dat = xp.asarray(decomp_dat) if not is_numpy_array(in_dat) else decomp_dat
166
+ decomp_dat = xp.reshape(decomp_dat, (-1,) + self._state.template.data.shape[1:])
150
167
  replace_kwargs["data"] = decomp_dat
151
168
 
152
169
  return replace(self._state.template, **replace_kwargs)
@@ -165,6 +182,8 @@ class AdaptiveDecompTransformer(
165
182
  if in_dat.shape[ax_idx] == 0:
166
183
  return
167
184
 
185
+ xp = get_namespace(in_dat)
186
+
168
187
  # Re-order axes if needed
169
188
  sorted_dims_exp = [iter_axis] + off_targ_axes + targ_axes
170
189
  if message.dims != sorted_dims_exp:
@@ -172,11 +191,12 @@ class AdaptiveDecompTransformer(
172
191
  pass
173
192
 
174
193
  # fold [iter_axis] + off_targ_axes together and fold targ_axes together
175
- d2 = np.prod(in_dat.shape[len(off_targ_axes) + 1 :])
176
- in_dat = in_dat.reshape((-1, d2))
194
+ d2 = math.prod(in_dat.shape[len(off_targ_axes) + 1 :])
195
+ in_dat = xp.reshape(in_dat, (-1, d2))
177
196
 
178
- # Fit the estimator
179
- self._state.estimator.partial_fit(in_dat)
197
+ # Fit the estimator — sklearn needs numpy
198
+ in_np = np.asarray(in_dat) if not is_numpy_array(in_dat) else in_dat
199
+ self._state.estimator.partial_fit(in_np)
180
200
 
181
201
 
182
202
  class IncrementalPCASettings(AdaptiveDecompSettings):
@@ -0,0 +1,163 @@
1
+ """Incremental Canonical Correlation Analysis (CCA).
2
+
3
+ .. note::
4
+ This module supports the Array API standard via
5
+ ``array_api_compat.get_namespace()``. All linear algebra uses Array API
6
+ operations; ``scipy.linalg.sqrtm`` is replaced by an eigendecomposition-
7
+ based inverse square root (:func:`_inv_sqrtm_spd`).
8
+ """
9
+
10
+ import numpy as np
11
+ from array_api_compat import get_namespace
12
+ from ezmsg.sigproc.util.array import array_device, xp_create
13
+
14
+
15
def _inv_sqrtm_spd(xp, A, eps=1e-12):
    """Return the inverse matrix square root of a symmetric matrix.

    Computes ``inv(sqrtm(A)) = Q @ diag(1 / sqrt(lambda)) @ Q^T`` from the
    eigendecomposition ``A = Q @ diag(lambda) @ Q^T`` in a single step,
    instead of forming ``sqrtm(A)`` and inverting it afterwards. Only Array
    API operations are used, so the result stays in the namespace ``xp``.

    Args:
        xp: Array namespace providing ``linalg.eigh``,
            ``linalg.matrix_transpose``, ``clip`` and ``sqrt``.
        A: Symmetric (ideally positive-definite) square matrix.
        eps: Lower bound applied to the eigenvalues before inversion.
            The default matches the previously hard-coded floor.

    Returns:
        Array with the same shape as ``A`` holding ``A ** (-1/2)``.
    """
    eigenvalues, eigenvectors = xp.linalg.eigh(A)
    # Floor the spectrum so zero (or slightly negative, from round-off)
    # eigenvalues cannot cause a division by zero or a NaN from sqrt.
    eigenvalues = xp.clip(eigenvalues, eps, None)
    inv_sqrt_eig = 1.0 / xp.sqrt(eigenvalues)
    # Q @ diag(v) == Q * v (column-wise broadcasting), then @ Q^T.
    return (eigenvectors * inv_sqrt_eig) @ xp.linalg.matrix_transpose(eigenvectors)
27
+
28
+
29
class IncrementalCCA:
    """Incremental canonical correlation analysis with adaptive smoothing.

    Exponentially smoothed estimates of the covariance matrices ``C11``,
    ``C22`` and the cross-covariance ``C12`` are updated from every batch
    given to :meth:`partial_fit`. The smoothing factor itself is adapted
    according to how much the batch statistics differ from the running
    estimates. Canonical weights (``x_weights_``, ``y_weights_``) and
    canonical correlations (``correlations_``) are derived from the running
    estimates in :meth:`_update_projections`.
    """

    def __init__(
        self,
        n_components=2,
        base_smoothing=0.95,
        min_smoothing=0.5,
        max_smoothing=0.99,
        adaptation_rate=0.1,
    ):
        """
        Parameters:
        -----------
        n_components : int
            Number of canonical components to compute
        base_smoothing : float
            Base smoothing factor (will be adapted)
        min_smoothing : float
            Minimum allowed smoothing factor
        max_smoothing : float
            Maximum allowed smoothing factor
        adaptation_rate : float
            How quickly to adjust smoothing factor (between 0 and 1)
        """
        self.n_components = n_components
        self.base_smoothing = base_smoothing
        self.current_smoothing = base_smoothing
        self.min_smoothing = min_smoothing
        self.max_smoothing = max_smoothing
        self.adaptation_rate = adaptation_rate
        self.initialized = False

    def initialize(self, d1, d2, *, ref_array=None):
        """Initialize the necessary matrices.

        Args:
            d1: Dimensionality of the first dataset.
            d2: Dimensionality of the second dataset.
            ref_array: Optional reference array to derive array namespace
                and device from. If ``None``, defaults to NumPy.
        """
        self.d1 = d1
        self.d2 = d2

        if ref_array is not None:
            xp = get_namespace(ref_array)
            dev = array_device(ref_array)
        else:
            xp, dev = np, None

        # Running covariance estimates start at zero and are always float64,
        # regardless of the dtype of the data passed to partial_fit.
        self.C11 = xp_create(xp.zeros, (d1, d1), dtype=xp.float64, device=dev)
        self.C22 = xp_create(xp.zeros, (d2, d2), dtype=xp.float64, device=dev)
        self.C12 = xp_create(xp.zeros, (d1, d2), dtype=xp.float64, device=dev)

        self.initialized = True

    def _compute_change_magnitude(self, C11_new, C22_new, C12_new):
        """Compute magnitude of change in correlation structure.

        Returns the mean, as a Python float, of the three Frobenius norms of
        (batch estimate - running estimate), each divided by the number of
        elements in the corresponding matrix.
        """
        xp = get_namespace(self.C11)

        # Frobenius norm of differences
        diff11 = xp.linalg.matrix_norm(C11_new - self.C11)
        diff22 = xp.linalg.matrix_norm(C22_new - self.C22)
        diff12 = xp.linalg.matrix_norm(C12_new - self.C12)

        # Normalize by matrix sizes
        diff11 = diff11 / (self.d1 * self.d1)
        diff22 = diff22 / (self.d2 * self.d2)
        diff12 = diff12 / (self.d1 * self.d2)

        return float((diff11 + diff22 + diff12) / 3)

    def _adapt_smoothing(self, change_magnitude):
        """Adapt smoothing factor based on detected changes.

        A larger ``change_magnitude`` lowers the target smoothing factor. The
        target is clamped to ``[min_smoothing, max_smoothing]`` and
        ``current_smoothing`` moves towards it at ``adaptation_rate``.
        """
        # If change is large, decrease smoothing factor
        target_smoothing = self.base_smoothing * (1.0 - change_magnitude)
        target_smoothing = max(self.min_smoothing, min(target_smoothing, self.max_smoothing))

        # Smooth the adaptation itself
        self.current_smoothing = (
            1 - self.adaptation_rate
        ) * self.current_smoothing + self.adaptation_rate * target_smoothing

    def partial_fit(self, X1, X2, update_projections=True):
        """Update the model with new samples using adaptive smoothing.

        Assumes X1 and X2 are already centered and scaled.

        Args:
            X1: Batch for the first dataset, shape ``(n_samples, d1)``.
            X2: Batch for the second dataset, shape ``(n_samples, d2)``.
            update_projections: If true, recompute the canonical weights and
                correlations after the covariance update.
        """
        # An empty batch carries no information. Processing it would divide
        # by zero samples and write NaN into the running covariances, which
        # would never recover, so leave the model untouched instead.
        if X1.shape[0] == 0 and X2.shape[0] == 0:
            return

        xp = get_namespace(X1, X2)
        _mT = xp.linalg.matrix_transpose

        if not self.initialized:
            self.initialize(X1.shape[1], X2.shape[1], ref_array=X1)

        # Compute new correlation matrices from current batch
        C11_new = _mT(X1) @ X1 / X1.shape[0]
        C22_new = _mT(X2) @ X2 / X2.shape[0]
        C12_new = _mT(X1) @ X2 / X1.shape[0]

        # Detect changes and adapt smoothing factor
        if bool(xp.any(self.C11 != 0)):  # Skip first update
            change_magnitude = self._compute_change_magnitude(C11_new, C22_new, C12_new)
            self._adapt_smoothing(change_magnitude)

        # Update with current smoothing factor
        alpha = self.current_smoothing
        self.C11 = alpha * self.C11 + (1 - alpha) * C11_new
        self.C22 = alpha * self.C22 + (1 - alpha) * C22_new
        self.C12 = alpha * self.C12 + (1 - alpha) * C12_new

        if update_projections:
            self._update_projections()

    def _update_projections(self):
        """Update canonical vectors and correlations.

        Whitens the cross-covariance with the inverse square roots of the
        (slightly regularized) covariances and takes its SVD. The singular
        values are stored as ``correlations_`` and the leading
        ``n_components`` singular vectors, mapped back through the whitening
        matrices, as ``x_weights_`` and ``y_weights_``.
        """
        xp = get_namespace(self.C11)
        dev = array_device(self.C11)
        _mT = xp.linalg.matrix_transpose

        # Small ridge on the diagonal keeps the covariances invertible.
        eps = 1e-8
        C11_reg = self.C11 + eps * xp_create(xp.eye, self.d1, dtype=self.C11.dtype, device=dev)
        C22_reg = self.C22 + eps * xp_create(xp.eye, self.d2, dtype=self.C22.dtype, device=dev)

        inv_sqrt_C11 = _inv_sqrtm_spd(xp, C11_reg)
        inv_sqrt_C22 = _inv_sqrtm_spd(xp, C22_reg)

        K = inv_sqrt_C11 @ self.C12 @ inv_sqrt_C22
        U, self.correlations_, Vh = xp.linalg.svd(K, full_matrices=False)

        self.x_weights_ = inv_sqrt_C11 @ U[:, : self.n_components]
        self.y_weights_ = inv_sqrt_C22 @ _mT(Vh)[:, : self.n_components]

    def transform(self, X1, X2):
        """Project data onto canonical components.

        Args:
            X1: Data for the first dataset, shape ``(n_samples, d1)``.
            X2: Data for the second dataset, shape ``(n_samples, d2)``.

        Returns:
            Tuple ``(X1_proj, X2_proj)`` of the projected datasets.

        Raises:
            AttributeError: If no projections have been computed yet, i.e.
                ``partial_fit`` has not run with ``update_projections=True``.
        """
        if not hasattr(self, "x_weights_") or not hasattr(self, "y_weights_"):
            # Same exception type as the bare attribute lookup would raise,
            # but with a message that tells the caller what to do.
            raise AttributeError(
                "IncrementalCCA has no projections yet; call partial_fit() "
                "with update_projections=True before transform()."
            )
        X1_proj = X1 @ self.x_weights_
        X2_proj = X2 @ self.y_weights_
        return X1_proj, X2_proj