ezmsg-learn 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. ezmsg/learn/__init__.py +2 -0
  2. ezmsg/learn/__version__.py +34 -0
  3. ezmsg/learn/dim_reduce/__init__.py +0 -0
  4. ezmsg/learn/dim_reduce/adaptive_decomp.py +274 -0
  5. ezmsg/learn/dim_reduce/incremental_decomp.py +173 -0
  6. ezmsg/learn/linear_model/__init__.py +1 -0
  7. ezmsg/learn/linear_model/adaptive_linear_regressor.py +12 -0
  8. ezmsg/learn/linear_model/cca.py +1 -0
  9. ezmsg/learn/linear_model/linear_regressor.py +9 -0
  10. ezmsg/learn/linear_model/sgd.py +9 -0
  11. ezmsg/learn/linear_model/slda.py +12 -0
  12. ezmsg/learn/model/__init__.py +0 -0
  13. ezmsg/learn/model/cca.py +122 -0
  14. ezmsg/learn/model/mlp.py +127 -0
  15. ezmsg/learn/model/mlp_old.py +49 -0
  16. ezmsg/learn/model/refit_kalman.py +369 -0
  17. ezmsg/learn/model/rnn.py +160 -0
  18. ezmsg/learn/model/transformer.py +175 -0
  19. ezmsg/learn/nlin_model/__init__.py +1 -0
  20. ezmsg/learn/nlin_model/mlp.py +10 -0
  21. ezmsg/learn/process/__init__.py +0 -0
  22. ezmsg/learn/process/adaptive_linear_regressor.py +142 -0
  23. ezmsg/learn/process/base.py +154 -0
  24. ezmsg/learn/process/linear_regressor.py +95 -0
  25. ezmsg/learn/process/mlp_old.py +188 -0
  26. ezmsg/learn/process/refit_kalman.py +403 -0
  27. ezmsg/learn/process/rnn.py +245 -0
  28. ezmsg/learn/process/sgd.py +117 -0
  29. ezmsg/learn/process/sklearn.py +241 -0
  30. ezmsg/learn/process/slda.py +110 -0
  31. ezmsg/learn/process/ssr.py +374 -0
  32. ezmsg/learn/process/torch.py +362 -0
  33. ezmsg/learn/process/transformer.py +215 -0
  34. ezmsg/learn/util.py +67 -0
  35. ezmsg_learn-1.1.0.dist-info/METADATA +30 -0
  36. ezmsg_learn-1.1.0.dist-info/RECORD +38 -0
  37. ezmsg_learn-1.1.0.dist-info/WHEEL +4 -0
  38. ezmsg_learn-1.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,2 @@
1
def hello() -> str:
    """Return the canonical greeting string for the ezmsg-learn package."""
    greeting = "Hello from ezmsg-learn!"
    return greeting
@@ -0,0 +1,34 @@
1
# file generated by setuptools-scm
# don't change, don't track in version control

__all__ = [
    "__version__",
    "__version_tuple__",
    "version",
    "version_tuple",
    "__commit_id__",
    "commit_id",
]

# Local stand-in for typing.TYPE_CHECKING: the Tuple/Union aliases below are
# only meaningful to static type checkers; at runtime the names are plain objects.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple
    from typing import Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
    COMMIT_ID = Union[str, None]
else:
    VERSION_TUPLE = object
    COMMIT_ID = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID

__version__ = version = '1.1.0'
__version_tuple__ = version_tuple = (1, 1, 0)

__commit_id__ = commit_id = None
File without changes
@@ -0,0 +1,274 @@
1
+ import typing
2
+
3
+ import ezmsg.core as ez
4
+ import numpy as np
5
+ from ezmsg.baseproc import (
6
+ BaseAdaptiveTransformer,
7
+ BaseAdaptiveTransformerUnit,
8
+ processor_state,
9
+ )
10
+ from ezmsg.util.messages.axisarray import AxisArray, replace
11
+ from sklearn.decomposition import IncrementalPCA, MiniBatchNMF
12
+
13
+
14
class AdaptiveDecompSettings(ez.Settings):
    """Settings shared by the adaptive decomposition transformers."""

    # A leading "!" means: iterate over this axis and decompose across all other axes
    # (see AdaptiveDecompTransformer._calculate_axis_groups).
    axis: str = "!time"
    # Number of components retained by the decomposition.
    n_components: int = 2
17
+
18
+
19
@processor_state
class AdaptiveDecompState:
    """Mutable runtime state for AdaptiveDecompTransformer."""

    # Zero-length output message reused as the template for every result.
    template: AxisArray | None = None
    # (iter_axis, targ_axes, off_targ_axes) as computed by _calculate_axis_groups.
    axis_groups: tuple[str, list[str], list[str]] | None = None
    # The underlying sklearn estimator (IncrementalPCA or MiniBatchNMF).
    estimator: typing.Any = None
24
+
25
+
26
# The concrete sklearn estimator class bound by each transformer subclass.
EstimatorType = typing.TypeVar("EstimatorType", bound=typing.Union[IncrementalPCA, MiniBatchNMF])
27
+
28
+
29
class AdaptiveDecompTransformer(
    BaseAdaptiveTransformer[AdaptiveDecompSettings, AxisArray, AxisArray, AdaptiveDecompState],
    typing.Generic[EstimatorType],
):
    """
    Base class for adaptive decomposition transformers. See IncrementalPCATransformer and MiniBatchNMFTransformer
    for concrete implementations.

    Note that for these classes, adaptation is not automatic. The user must call partial_fit on the transformer.
    For automated adaptation, see IncrementalDecompTransformer.
    """

    def __init__(self, *args, **kwargs):
        # Eagerly instantiate the estimator so partial_fit can be called before
        # the first message is transformed.
        super().__init__(*args, **kwargs)
        self._state.estimator = self._create_estimator()

    @classmethod
    def get_message_type(cls, dir: str) -> typing.Type[AxisArray]:
        # Override because we don't reuse the generic types.
        # NOTE: `dir` shadows the builtin but is kept to preserve the base-class interface.
        return AxisArray

    @classmethod
    def get_estimator_type(cls) -> typing.Type[EstimatorType]:
        """Return the estimator class bound to this subclass' generic parameter.

        Assumes the subclass parameterizes AdaptiveDecompTransformer directly as
        its first base, e.g. ``AdaptiveDecompTransformer[IncrementalPCA]``.
        """
        return typing.get_args(cls.__orig_bases__[0])[0]

    def _create_estimator(self) -> EstimatorType:
        """Instantiate the estimator; every settings field except `axis` is forwarded as a kwarg."""
        estimator_klass = self.get_estimator_type()
        estimator_settings = self.settings.__dict__.copy()
        estimator_settings.pop("axis")
        return estimator_klass(**estimator_settings)

    def _calculate_axis_groups(self, message: AxisArray) -> None:
        """Partition message dims into (iter_axis, targ_axes, off_targ_axes) and store in state.

        A "!" prefix on settings.axis means "iterate that axis, decompose all
        remaining axes"; otherwise settings.axis is the single decomposed axis
        and iteration happens over "win" (if present) or "time".
        """
        if self.settings.axis.startswith("!"):
            # Iterate over the !axis and collapse all other axes
            iter_axis = self.settings.axis[1:]
            it_ax_ix = message.get_axis_idx(iter_axis)
            targ_axes = message.dims[:it_ax_ix] + message.dims[it_ax_ix + 1 :]
            off_targ_axes = []
        else:
            # Do PCA on the parameterized axis
            targ_axes = [self.settings.axis]
            # Iterate over streaming axis
            iter_axis = "win" if "win" in message.dims else "time"
            if iter_axis == self.settings.axis:
                raise ValueError(
                    f"Iterating axis ({iter_axis}) cannot be the same as the target axis ({self.settings.axis})"
                )
            it_ax_ix = message.get_axis_idx(iter_axis)
            # Remaining axes are to be treated independently
            off_targ_axes = [
                _ for _ in (message.dims[:it_ax_ix] + message.dims[it_ax_ix + 1 :]) if _ != self.settings.axis
            ]
        self._state.axis_groups = iter_axis, targ_axes, off_targ_axes

    def _hash_message(self, message: AxisArray) -> int:
        """Hash the per-sample shape and stream key; a change triggers a state reset."""
        iter_axis = (
            self.settings.axis[1:]
            if self.settings.axis.startswith("!")
            else ("win" if "win" in message.dims else "time")
        )
        ax_idx = message.get_axis_idx(iter_axis)
        # Shape with the iteration axis removed: the shape of one sample.
        sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
        return hash((sample_shape, message.key))

    def _reset_state(self, message: AxisArray) -> None:
        """Recompute axis groups and build a zero-length output template message."""
        self._calculate_axis_groups(message)
        iter_axis, targ_axes, off_targ_axes = self._state.axis_groups

        # Template
        out_dims = [iter_axis] + off_targ_axes
        out_axes = {
            iter_axis: message.axes[iter_axis],
            **{k: message.axes[k] for k in off_targ_axes},
        }
        # If several axes are decomposed together they collapse into one "components" axis.
        if len(targ_axes) == 1:
            targ_ax_name = targ_axes[0]
        else:
            targ_ax_name = "components"
        out_dims += [targ_ax_name]
        out_axes[targ_ax_name] = AxisArray.CoordinateAxis(
            data=np.arange(self.settings.n_components).astype(str),
            dims=[targ_ax_name],
            unit="component",
        )
        out_shape = [message.data.shape[message.get_axis_idx(_)] for _ in off_targ_axes]
        # Zero length along the iteration axis: the template carries structure, not data.
        out_shape = (0,) + tuple(out_shape) + (self.settings.n_components,)
        self._state.template = replace(
            message,
            data=np.zeros(out_shape, dtype=float),
            dims=out_dims,
            axes=out_axes,
        )

    def _process(self, message: AxisArray) -> AxisArray:
        """Transform one message; returns the (empty) template if the estimator is not yet fit."""
        iter_axis, targ_axes, off_targ_axes = self._state.axis_groups
        ax_idx = message.get_axis_idx(iter_axis)
        in_dat = message.data

        if in_dat.shape[ax_idx] == 0:
            return self._state.template

        # Re-order axes
        sorted_dims_exp = [iter_axis] + off_targ_axes + targ_axes
        if message.dims != sorted_dims_exp:
            # TODO: Implement axes transposition if needed
            # re_order = [ax_idx] + off_targ_inds + targ_inds
            # np.transpose(in_dat, re_order)
            pass

        # fold [iter_axis] + off_targ_axes together and fold targ_axes together
        d2 = np.prod(in_dat.shape[len(off_targ_axes) + 1 :])
        in_dat = in_dat.reshape((-1, d2))

        replace_kwargs = {
            "axes": {**self._state.template.axes, iter_axis: message.axes[iter_axis]},
        }

        # Transform data
        # `components_` only exists after the first (partial_)fit; until then the
        # template's zero-length data is returned unchanged.
        if hasattr(self._state.estimator, "components_"):
            decomp_dat = self._state.estimator.transform(in_dat).reshape((-1,) + self._state.template.data.shape[1:])
            replace_kwargs["data"] = decomp_dat

        return replace(self._state.template, **replace_kwargs)

    def partial_fit(self, message: AxisArray) -> None:
        """Incrementally fit the estimator on `message`, resetting state first if the stream changed."""
        # Check if we need to reset state
        # NOTE(review): `_hash` appears to be maintained by the base processor — confirm.
        msg_hash = self._hash_message(message)
        if self._hash != msg_hash:
            self._reset_state(message)
            self._hash = msg_hash

        iter_axis, targ_axes, off_targ_axes = self._state.axis_groups
        ax_idx = message.get_axis_idx(iter_axis)
        in_dat = message.data

        if in_dat.shape[ax_idx] == 0:
            return

        # Re-order axes if needed
        sorted_dims_exp = [iter_axis] + off_targ_axes + targ_axes
        if message.dims != sorted_dims_exp:
            # TODO: Implement axes transposition if needed
            pass

        # fold [iter_axis] + off_targ_axes together and fold targ_axes together
        d2 = np.prod(in_dat.shape[len(off_targ_axes) + 1 :])
        in_dat = in_dat.reshape((-1, d2))

        # Fit the estimator
        self._state.estimator.partial_fit(in_dat)
180
+
181
+
182
class IncrementalPCASettings(AdaptiveDecompSettings):
    """Settings for IncrementalPCATransformer; extra fields are forwarded to sklearn's IncrementalPCA."""

    # Additional settings specific to PCA
    whiten: bool = False
    batch_size: typing.Optional[int] = None
186
+
187
+
188
class IncrementalPCATransformer(AdaptiveDecompTransformer[IncrementalPCA]):
    """Adaptive decomposition transformer backed by sklearn's IncrementalPCA."""

    pass
190
+
191
+
192
class MiniBatchNMFSettings(AdaptiveDecompSettings):
    """Settings for MiniBatchNMFTransformer; extra fields are forwarded to sklearn's MiniBatchNMF."""

    # Additional settings specific to NMF
    init: typing.Optional[str] = "random"
    """
    'random', 'nndsvd', 'nndsvda', 'nndsvdar', 'custom', or None
    """

    batch_size: int = 1024
    """
    batch_size is used only when doing a full fit (i.e., a reset),
    or as the exponent to forget_factor, where a very small batch_size
    will cause the model to update more slowly.
    It is better to set batch_size to a larger number than the expected
    chunk size and instead use forget_factor to control the learning rate.
    """

    beta_loss: typing.Union[str, float] = "frobenius"
    """
    'frobenius', 'kullback-leibler', 'itakura-saito'
    Note that values different from 'frobenius'
    (or 2) and 'kullback-leibler' (or 1) lead to significantly slower
    fits. Note that for `beta_loss <= 0` (or 'itakura-saito'), the input
    matrix `X` cannot contain zeros.
    """

    tol: float = 1e-4

    max_no_improvement: typing.Optional[int] = None

    max_iter: int = 200

    alpha_W: float = 0.0

    alpha_H: typing.Union[float, str] = "same"

    l1_ratio: float = 0.0

    forget_factor: float = 0.7
230
+
231
+
232
class MiniBatchNMFTransformer(AdaptiveDecompTransformer[MiniBatchNMF]):
    """Adaptive decomposition transformer backed by sklearn's MiniBatchNMF."""

    pass
234
+
235
+
236
# Settings / transformer pairings accepted by BaseAdaptiveDecompUnit below.
SettingsType = typing.TypeVar("SettingsType", bound=typing.Union[IncrementalPCASettings, MiniBatchNMFSettings])
TransformerType = typing.TypeVar(
    "TransformerType",
    bound=typing.Union[IncrementalPCATransformer, MiniBatchNMFTransformer],
)
241
+
242
+
243
class BaseAdaptiveDecompUnit(
    BaseAdaptiveTransformerUnit[
        SettingsType,
        AxisArray,
        AxisArray,
        TransformerType,
    ],
    typing.Generic[SettingsType, TransformerType],
):
    """Base ezmsg Unit for adaptive decomposition.

    Adds an INPUT_SAMPLE stream: messages arriving there are used to update
    (partial_fit) the estimator rather than being transformed.
    """

    INPUT_SAMPLE = ez.InputStream(AxisArray)

    @ez.subscriber(INPUT_SAMPLE)
    async def on_sample(self, msg: AxisArray) -> None:
        # Route training messages to the processor's async partial_fit.
        await self.processor.apartial_fit(msg)
257
+
258
+
259
class IncrementalPCAUnit(
    BaseAdaptiveDecompUnit[
        IncrementalPCASettings,
        IncrementalPCATransformer,
    ]
):
    """ezmsg Unit wrapping IncrementalPCATransformer."""

    SETTINGS = IncrementalPCASettings
266
+
267
+
268
class MiniBatchNMFUnit(
    BaseAdaptiveDecompUnit[
        MiniBatchNMFSettings,
        MiniBatchNMFTransformer,
    ]
):
    """ezmsg Unit wrapping MiniBatchNMFTransformer."""

    SETTINGS = MiniBatchNMFSettings
@@ -0,0 +1,173 @@
1
+ import typing
2
+
3
+ import ezmsg.core as ez
4
+ import numpy as np
5
+ from ezmsg.baseproc import (
6
+ BaseStatefulProcessor,
7
+ BaseTransformerUnit,
8
+ CompositeProcessor,
9
+ )
10
+ from ezmsg.sigproc.window import WindowTransformer
11
+ from ezmsg.util.messages.axisarray import AxisArray, replace
12
+
13
+ from .adaptive_decomp import (
14
+ IncrementalPCASettings,
15
+ IncrementalPCATransformer,
16
+ MiniBatchNMFSettings,
17
+ MiniBatchNMFTransformer,
18
+ )
19
+
20
+
21
class IncrementalDecompSettings(ez.Settings):
    """Settings for IncrementalDecompTransformer."""

    # A leading "!" means: iterate over this axis and decompose across all other axes.
    axis: str = "!time"
    # Number of components kept by the decomposition.
    n_components: int = 2
    # Window duration/shift for automatic partial_fit updates (in the iteration
    # axis' units — presumably seconds for "time"; confirm). <= 0 disables windowing.
    update_interval: float = 0.0
    # "pca" selects IncrementalPCA; any other value selects MiniBatchNMF.
    method: str = "pca"
    batch_size: typing.Optional[int] = None
    # PCA specific settings
    whiten: bool = False
    # NMF specific settings
    init: str = "random"
    beta_loss: str = "frobenius"
    tol: float = 1e-3
    alpha_W: float = 0.0
    alpha_H: typing.Union[float, str] = "same"
    l1_ratio: float = 0.0
    forget_factor: float = 0.7
37
+
38
+
39
class IncrementalDecompTransformer(CompositeProcessor[IncrementalDecompSettings, AxisArray, AxisArray]):
    """
    Automates usage of IncrementalPCATransformer and MiniBatchNMFTransformer by using a WindowTransformer
    to extract training samples then calls partial_fit on the decomposition transformer.
    """

    @staticmethod
    def _initialize_processors(
        settings: IncrementalDecompSettings,
    ) -> dict[str, BaseStatefulProcessor]:
        """Build the child processors: always "decomp", plus "windowing" when update_interval > 0."""
        # Create the appropriate decomposition transformer
        if settings.method == "pca":
            decomp_settings = IncrementalPCASettings(
                axis=settings.axis,
                n_components=settings.n_components,
                batch_size=settings.batch_size,
                whiten=settings.whiten,
            )
            decomp = IncrementalPCATransformer(settings=decomp_settings)
        else:  # nmf
            decomp_settings = MiniBatchNMFSettings(
                axis=settings.axis,
                n_components=settings.n_components,
                batch_size=settings.batch_size if settings.batch_size else 1024,
                init=settings.init,
                beta_loss=settings.beta_loss,
                tol=settings.tol,
                alpha_W=settings.alpha_W,
                alpha_H=settings.alpha_H,
                l1_ratio=settings.l1_ratio,
                forget_factor=settings.forget_factor,
            )
            decomp = MiniBatchNMFTransformer(settings=decomp_settings)

        # Create windowing processor if update_interval is specified
        if settings.update_interval > 0:
            # TODO: This `iter_axis` is likely incorrect.
            iter_axis = settings.axis[1:] if settings.axis.startswith("!") else "time"
            windowing = WindowTransformer(
                axis=iter_axis,
                window_dur=settings.update_interval,
                window_shift=settings.update_interval,
                zero_pad_until="none",
            )

            return {
                "decomp": decomp,
                "windowing": windowing,
            }

        return {"decomp": decomp}

    def _partial_fit_windowed(self, train_msg: AxisArray) -> None:
        """
        Helper function to do the partial_fit on the windowed message.
        """
        if np.prod(train_msg.data.shape) > 0:
            # Windowing created a new "win" axis, but we don't actually want to use that
            # in the message we send to the decomp processor.
            axis_idx = train_msg.get_axis_idx("win")
            win_axis = train_msg.axes["win"]
            # Per-window offsets, used to restore absolute time to each unwrapped window.
            offsets = win_axis.value(np.asarray(range(train_msg.data.shape[axis_idx])))
            for ix, _msg in enumerate(train_msg.iter_over_axis("win")):
                _msg = replace(
                    _msg,
                    axes={
                        **_msg.axes,
                        "time": replace(
                            _msg.axes["time"],
                            offset=_msg.axes["time"].offset + offsets[ix],
                        ),
                    },
                )
                self._procs["decomp"].partial_fit(_msg)

    def stateful_op(
        self,
        state: dict[str, tuple[typing.Any, int]] | None,
        message: AxisArray,
    ) -> tuple[dict[str, tuple[typing.Any, int]], AxisArray]:
        """Stateful entry point: fit on first use (or on windowed training chunks), then transform."""
        state = state or {}

        estim = self._procs["decomp"]._state.estimator
        if not hasattr(estim, "components_") or estim.components_ is None:
            # If the estimator has not been trained once, train it with the first message
            self._procs["decomp"].partial_fit(message)
        elif "windowing" in self._procs:
            state["windowing"], train_msg = self._procs["windowing"].stateful_op(state.get("windowing", None), message)
            self._partial_fit_windowed(train_msg)

        # Process the incoming message
        state["decomp"], result = self._procs["decomp"].stateful_op(state.get("decomp", None), message)

        return state, result

    async def _aprocess(self, message: AxisArray) -> AxisArray:
        """
        Asynchronously process the incoming message.
        This is nearly identical to the _process method, but the processors
        are called asynchronously.
        """
        estim = self._procs["decomp"]._state.estimator
        if not hasattr(estim, "components_") or estim.components_ is None:
            # If the estimator has not been trained once, train it with the first message
            self._procs["decomp"].partial_fit(message)
        elif "windowing" in self._procs:
            # If windowing is enabled, extract training samples and perform partial_fit
            train_msg = await self._procs["windowing"].__acall__(message)
            self._partial_fit_windowed(train_msg)  # Non async

        # Process the incoming message
        decomp_result = await self._procs["decomp"].__acall__(message)

        return decomp_result

    def _process(self, message: AxisArray) -> AxisArray:
        """Synchronous counterpart of _aprocess: fit first (or on windowed chunks), then transform."""
        estim = self._procs["decomp"]._state.estimator
        if not hasattr(estim, "components_") or estim.components_ is None:
            # If the estimator has not been trained once, train it with the first message
            self._procs["decomp"].partial_fit(message)
        elif "windowing" in self._procs:
            # If windowing is enabled, extract training samples and perform partial_fit
            train_msg = self._procs["windowing"](message)
            self._partial_fit_windowed(train_msg)

        # Process the incoming message
        decomp_result = self._procs["decomp"](message)

        return decomp_result
168
+
169
+
170
class IncrementalDecompUnit(
    BaseTransformerUnit[IncrementalDecompSettings, AxisArray, AxisArray, IncrementalDecompTransformer]
):
    """ezmsg Unit wrapping IncrementalDecompTransformer."""

    SETTINGS = IncrementalDecompSettings
@@ -0,0 +1 @@
1
+ # Use of this module is deprecated. Please use `ezmsg.learn.process` instead.
@@ -0,0 +1,12 @@
1
# Backwards-compatible re-export shim: the implementation lives in
# ezmsg.learn.process.adaptive_linear_regressor.
from ..process.adaptive_linear_regressor import (
    AdaptiveLinearRegressorSettings as AdaptiveLinearRegressorSettings,
    AdaptiveLinearRegressorState as AdaptiveLinearRegressorState,
    AdaptiveLinearRegressorTransformer as AdaptiveLinearRegressorTransformer,
    AdaptiveLinearRegressorUnit as AdaptiveLinearRegressorUnit,
)
@@ -0,0 +1 @@
1
# Backwards-compatible re-export: the implementation lives in ezmsg.learn.model.cca.
from ..model.cca import IncrementalCCA as IncrementalCCA
@@ -0,0 +1,9 @@
1
# Backwards-compatible re-export shim: the implementation lives in
# ezmsg.learn.process.linear_regressor.
from ..process.linear_regressor import (
    LinearRegressorSettings as LinearRegressorSettings,
    LinearRegressorState as LinearRegressorState,
    LinearRegressorTransformer as LinearRegressorTransformer,
)
@@ -0,0 +1,9 @@
1
# Backwards-compatible re-export shim: the implementation lives in
# ezmsg.learn.process.sgd.
from ..process.sgd import (
    SGDDecoder as SGDDecoder,
    SGDDecoderSettings as SGDDecoderSettings,
    sgd_decoder as sgd_decoder,
)
@@ -0,0 +1,12 @@
1
# Backwards-compatible re-export shim: the implementation lives in
# ezmsg.learn.process.slda.
from ..process.slda import (
    SLDA as SLDA,
    SLDASettings as SLDASettings,
    SLDAState as SLDAState,
    SLDATransformer as SLDATransformer,
)
File without changes
@@ -0,0 +1,122 @@
1
+ import numpy as np
2
+ from scipy import linalg
3
+
4
+
5
class IncrementalCCA:
    """Incremental canonical correlation analysis with adaptive exponential smoothing.

    The (cross-)correlation matrices are maintained as exponentially-weighted
    moving averages of per-batch estimates; the smoothing factor itself adapts
    to the magnitude of change in the correlation structure between batches.
    """

    def __init__(
        self,
        n_components=2,
        base_smoothing=0.95,
        min_smoothing=0.5,
        max_smoothing=0.99,
        adaptation_rate=0.1,
    ):
        """
        Parameters:
        -----------
        n_components : int
            Number of canonical components to compute
        base_smoothing : float
            Base smoothing factor (will be adapted)
        min_smoothing : float
            Minimum allowed smoothing factor
        max_smoothing : float
            Maximum allowed smoothing factor
        adaptation_rate : float
            How quickly to adjust smoothing factor (between 0 and 1)
        """
        self.n_components = n_components
        self.base_smoothing = base_smoothing
        self.current_smoothing = base_smoothing
        self.min_smoothing = min_smoothing
        self.max_smoothing = max_smoothing
        self.adaptation_rate = adaptation_rate
        self.initialized = False

    def initialize(self, d1, d2):
        """Allocate the running correlation matrices for inputs of width d1 and d2."""
        self.d1 = d1
        self.d2 = d2

        # Initialize correlation matrices
        self.C11 = np.zeros((d1, d1))
        self.C22 = np.zeros((d2, d2))
        self.C12 = np.zeros((d1, d2))

        self.initialized = True

    def _compute_change_magnitude(self, C11_new, C22_new, C12_new):
        """Return the mean size-normalized Frobenius distance between stored and new matrices."""
        diff11 = np.linalg.norm(C11_new - self.C11) / (self.d1 * self.d1)
        diff22 = np.linalg.norm(C22_new - self.C22) / (self.d2 * self.d2)
        diff12 = np.linalg.norm(C12_new - self.C12) / (self.d1 * self.d2)
        return (diff11 + diff22 + diff12) / 3

    def _adapt_smoothing(self, change_magnitude):
        """Adapt the smoothing factor based on detected changes."""
        # Large structural change -> smaller smoothing factor -> faster tracking.
        target_smoothing = np.clip(
            self.base_smoothing * (1.0 - change_magnitude),
            self.min_smoothing,
            self.max_smoothing,
        )

        # Smooth the adaptation itself so the factor doesn't jump batch-to-batch.
        self.current_smoothing = (
            1 - self.adaptation_rate
        ) * self.current_smoothing + self.adaptation_rate * target_smoothing

    def partial_fit(self, X1, X2, update_projections=True):
        """Update the model with new samples using adaptive smoothing.

        Assumes X1 and X2 are already centered and scaled, with samples in rows.

        Parameters:
        -----------
        X1 : array of shape (n_samples, d1)
        X2 : array of shape (n_samples, d2)
        update_projections : bool
            If True, recompute the canonical weights after the update.
        """
        if not self.initialized:
            self.initialize(X1.shape[1], X2.shape[1])

        # Compute new correlation matrices from current batch
        C11_new = X1.T @ X1 / X1.shape[0]
        C22_new = X2.T @ X2 / X2.shape[0]
        C12_new = X1.T @ X2 / X1.shape[0]

        # Detect changes and adapt smoothing factor (skip the very first update,
        # when the stored matrices are still all-zero).
        if self.C11.any():
            change_magnitude = self._compute_change_magnitude(C11_new, C22_new, C12_new)
            self._adapt_smoothing(change_magnitude)

        # Exponentially-weighted moving-average update with the current smoothing factor.
        alpha = self.current_smoothing
        self.C11 = alpha * self.C11 + (1 - alpha) * C11_new
        self.C22 = alpha * self.C22 + (1 - alpha) * C22_new
        self.C12 = alpha * self.C12 + (1 - alpha) * C12_new

        if update_projections:
            self._update_projections()

    def _update_projections(self):
        """Update canonical vectors and correlations from the stored matrices."""
        eps = 1e-8
        # Tikhonov regularization keeps the matrix square roots invertible.
        C11_reg = self.C11 + eps * np.eye(self.d1)
        C22_reg = self.C22 + eps * np.eye(self.d2)

        # Whitening matrices, each computed ONCE. (The previous implementation
        # recomputed inv(sqrtm(...)) twice per matrix; both inv and sqrtm are O(d^3).)
        # NOTE(review): sqrtm may return a complex result for ill-conditioned
        # input; confirm downstream consumers expect real-valued weights.
        W1 = linalg.inv(linalg.sqrtm(C11_reg))
        W2 = linalg.inv(linalg.sqrtm(C22_reg))

        K = W1 @ self.C12 @ W2
        # scipy's svd returns (U, s, Vh) with K = U @ diag(s) @ Vh.
        U, self.correlations_, Vh = linalg.svd(K)

        self.x_weights_ = W1 @ U[:, : self.n_components]
        self.y_weights_ = W2 @ Vh.T[:, : self.n_components]

    def transform(self, X1, X2):
        """Project data onto canonical components.

        Requires partial_fit to have been called at least once with
        update_projections=True (otherwise the weight attributes do not exist).
        """
        X1_proj = X1 @ self.x_weights_
        X2_proj = X2 @ self.y_weights_
        return X1_proj, X2_proj