ezmsg-learn 1.0 (ezmsg_learn-1.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2 @@
+def hello() -> str:
+    return "Hello from ezmsg-learn!"
@@ -0,0 +1,34 @@
+# file generated by setuptools-scm
+# don't change, don't track in version control
+
+__all__ = [
+    "__version__",
+    "__version_tuple__",
+    "version",
+    "version_tuple",
+    "__commit_id__",
+    "commit_id",
+]
+
+TYPE_CHECKING = False
+if TYPE_CHECKING:
+    from typing import Tuple
+    from typing import Union
+
+    VERSION_TUPLE = Tuple[Union[int, str], ...]
+    COMMIT_ID = Union[str, None]
+else:
+    VERSION_TUPLE = object
+    COMMIT_ID = object
+
+version: str
+__version__: str
+__version_tuple__: VERSION_TUPLE
+version_tuple: VERSION_TUPLE
+commit_id: COMMIT_ID
+__commit_id__: COMMIT_ID
+
+__version__ = version = '1.0'
+__version_tuple__ = version_tuple = (1, 0)
+
+__commit_id__ = commit_id = None
@@ -0,0 +1,284 @@
+import typing
+
+from sklearn.decomposition import IncrementalPCA, MiniBatchNMF
+import numpy as np
+import ezmsg.core as ez
+from ezmsg.sigproc.base import (
+    processor_state,
+    BaseAdaptiveTransformer,
+    BaseAdaptiveTransformerUnit,
+)
+from ezmsg.util.messages.axisarray import AxisArray, replace
+
+
+class AdaptiveDecompSettings(ez.Settings):
+    axis: str = "!time"
+    n_components: int = 2
+
+
+@processor_state
+class AdaptiveDecompState:
+    template: AxisArray | None = None
+    axis_groups: tuple[str, list[str], list[str]] | None = None
+    estimator: typing.Any = None
+
+
+EstimatorType = typing.TypeVar(
+    "EstimatorType", bound=typing.Union[IncrementalPCA, MiniBatchNMF]
+)
+
+
+class AdaptiveDecompTransformer(
+    BaseAdaptiveTransformer[
+        AdaptiveDecompSettings, AxisArray, AxisArray, AdaptiveDecompState
+    ],
+    typing.Generic[EstimatorType],
+):
+    """
+    Base class for adaptive decomposition transformers. See IncrementalPCATransformer and MiniBatchNMFTransformer
+    for concrete implementations.
+
+    Note that for these classes, adaptation is not automatic. The user must call partial_fit on the transformer.
+    For automated adaptation, see IncrementalDecompTransformer.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._state.estimator = self._create_estimator()
+
+    @classmethod
+    def get_message_type(cls, dir: str) -> typing.Type[AxisArray]:
+        # Override because we don't reuse the generic types.
+        return AxisArray
+
+    @classmethod
+    def get_estimator_type(cls) -> typing.Type[EstimatorType]:
+        return typing.get_args(cls.__orig_bases__[0])[0]
+
+    def _create_estimator(self) -> EstimatorType:
+        estimator_klass = self.get_estimator_type()
+        estimator_settings = self.settings.__dict__.copy()
+        estimator_settings.pop("axis")
+        return estimator_klass(**estimator_settings)
+
+    def _calculate_axis_groups(self, message: AxisArray):
+        if self.settings.axis.startswith("!"):
+            # Iterate over the !axis and collapse all other axes
+            iter_axis = self.settings.axis[1:]
+            it_ax_ix = message.get_axis_idx(iter_axis)
+            targ_axes = message.dims[:it_ax_ix] + message.dims[it_ax_ix + 1 :]
+            off_targ_axes = []
+        else:
+            # Decompose along the parameterized axis
+            targ_axes = [self.settings.axis]
+            # Iterate over streaming axis
+            iter_axis = "win" if "win" in message.dims else "time"
+            if iter_axis == self.settings.axis:
+                raise ValueError(
+                    f"Iterating axis ({iter_axis}) cannot be the same as the target axis ({self.settings.axis})"
+                )
+            it_ax_ix = message.get_axis_idx(iter_axis)
+            # Remaining axes are to be treated independently
+            off_targ_axes = [
+                _
+                for _ in (message.dims[:it_ax_ix] + message.dims[it_ax_ix + 1 :])
+                if _ != self.settings.axis
+            ]
+        self._state.axis_groups = iter_axis, targ_axes, off_targ_axes
+
+    def _hash_message(self, message: AxisArray) -> int:
+        iter_axis = (
+            self.settings.axis[1:]
+            if self.settings.axis.startswith("!")
+            else ("win" if "win" in message.dims else "time")
+        )
+        ax_idx = message.get_axis_idx(iter_axis)
+        sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
+        return hash((sample_shape, message.key))
+
+    def _reset_state(self, message: AxisArray) -> None:
+        """Reset state"""
+        self._calculate_axis_groups(message)
+        iter_axis, targ_axes, off_targ_axes = self._state.axis_groups
+
+        # Template
+        out_dims = [iter_axis] + off_targ_axes
+        out_axes = {
+            iter_axis: message.axes[iter_axis],
+            **{k: message.axes[k] for k in off_targ_axes},
+        }
+        if len(targ_axes) == 1:
+            targ_ax_name = targ_axes[0]
+        else:
+            targ_ax_name = "components"
+        out_dims += [targ_ax_name]
+        out_axes[targ_ax_name] = AxisArray.CoordinateAxis(
+            data=np.arange(self.settings.n_components).astype(str),
+            dims=[targ_ax_name],
+            unit="component",
+        )
+        out_shape = [message.data.shape[message.get_axis_idx(_)] for _ in off_targ_axes]
+        out_shape = (0,) + tuple(out_shape) + (self.settings.n_components,)
+        self._state.template = replace(
+            message,
+            data=np.zeros(out_shape, dtype=float),
+            dims=out_dims,
+            axes=out_axes,
+        )
+
+    def _process(self, message: AxisArray) -> AxisArray:
+        iter_axis, targ_axes, off_targ_axes = self._state.axis_groups
+        ax_idx = message.get_axis_idx(iter_axis)
+        in_dat = message.data
+
+        if in_dat.shape[ax_idx] == 0:
+            return self._state.template
+
+        # Re-order axes
+        sorted_dims_exp = [iter_axis] + off_targ_axes + targ_axes
+        if message.dims != sorted_dims_exp:
+            # TODO: Implement axes transposition if needed
+            # re_order = [ax_idx] + off_targ_inds + targ_inds
+            # np.transpose(in_dat, re_order)
+            pass
+
+        # fold [iter_axis] + off_targ_axes together and fold targ_axes together
+        d2 = np.prod(in_dat.shape[len(off_targ_axes) + 1 :])
+        in_dat = in_dat.reshape((-1, d2))
+
+        replace_kwargs = {
+            "axes": {**self._state.template.axes, iter_axis: message.axes[iter_axis]},
+        }
+
+        # Transform data
+        if hasattr(self._state.estimator, "components_"):
+            decomp_dat = self._state.estimator.transform(in_dat).reshape(
+                (-1,) + self._state.template.data.shape[1:]
+            )
+            replace_kwargs["data"] = decomp_dat
+
+        return replace(self._state.template, **replace_kwargs)
+
+    def partial_fit(self, message: AxisArray) -> None:
+        # Check if we need to reset state
+        msg_hash = self._hash_message(message)
+        if self._hash != msg_hash:
+            self._reset_state(message)
+            self._hash = msg_hash
+
+        iter_axis, targ_axes, off_targ_axes = self._state.axis_groups
+        ax_idx = message.get_axis_idx(iter_axis)
+        in_dat = message.data
+
+        if in_dat.shape[ax_idx] == 0:
+            return
+
+        # Re-order axes if needed
+        sorted_dims_exp = [iter_axis] + off_targ_axes + targ_axes
+        if message.dims != sorted_dims_exp:
+            # TODO: Implement axes transposition if needed
+            pass
+
+        # fold [iter_axis] + off_targ_axes together and fold targ_axes together
+        d2 = np.prod(in_dat.shape[len(off_targ_axes) + 1 :])
+        in_dat = in_dat.reshape((-1, d2))
+
+        # Fit the estimator
+        self._state.estimator.partial_fit(in_dat)
+
+
+class IncrementalPCASettings(AdaptiveDecompSettings):
+    # Additional settings specific to PCA
+    whiten: bool = False
+    batch_size: typing.Optional[int] = None
+
+
+class IncrementalPCATransformer(AdaptiveDecompTransformer[IncrementalPCA]):
+    pass
+
+
+class MiniBatchNMFSettings(AdaptiveDecompSettings):
+    # Additional settings specific to NMF
+    init: typing.Optional[str] = "random"
+    """
+    'random', 'nndsvd', 'nndsvda', 'nndsvdar', 'custom', or None
+    """
+
+    batch_size: int = 1024
+    """
+    batch_size is used only when doing a full fit (i.e., a reset),
+    or as the exponent to forget_factor, where a very small batch_size
+    will cause the model to update more slowly.
+    It is better to set batch_size to a larger number than the expected
+    chunk size and instead use forget_factor to control the learning rate.
+    """
+
+    beta_loss: typing.Union[str, float] = "frobenius"
+    """
+    'frobenius', 'kullback-leibler', or 'itakura-saito'.
+    Note that values different from 'frobenius'
+    (or 2) and 'kullback-leibler' (or 1) lead to significantly slower
+    fits. Note that for `beta_loss <= 0` (or 'itakura-saito'), the input
+    matrix `X` cannot contain zeros.
+    """
+
+    tol: float = 1e-4
+
+    max_no_improvement: typing.Optional[int] = None
+
+    max_iter: int = 200
+
+    alpha_W: float = 0.0
+
+    alpha_H: typing.Union[float, str] = "same"
+
+    l1_ratio: float = 0.0
+
+    forget_factor: float = 0.7
+
+
+class MiniBatchNMFTransformer(AdaptiveDecompTransformer[MiniBatchNMF]):
+    pass
+
+
+SettingsType = typing.TypeVar(
+    "SettingsType", bound=typing.Union[IncrementalPCASettings, MiniBatchNMFSettings]
+)
+TransformerType = typing.TypeVar(
+    "TransformerType",
+    bound=typing.Union[IncrementalPCATransformer, MiniBatchNMFTransformer],
+)
+
+
+class BaseAdaptiveDecompUnit(
+    BaseAdaptiveTransformerUnit[
+        SettingsType,
+        AxisArray,
+        AxisArray,
+        TransformerType,
+    ],
+    typing.Generic[SettingsType, TransformerType],
+):
+    INPUT_SAMPLE = ez.InputStream(AxisArray)
+
+    @ez.subscriber(INPUT_SAMPLE)
+    async def on_sample(self, msg: AxisArray) -> None:
+        await self.processor.apartial_fit(msg)
+
+
+class IncrementalPCAUnit(
+    BaseAdaptiveDecompUnit[
+        IncrementalPCASettings,
+        IncrementalPCATransformer,
+    ]
+):
+    SETTINGS = IncrementalPCASettings
+
+
+class MiniBatchNMFUnit(
+    BaseAdaptiveDecompUnit[
+        MiniBatchNMFSettings,
+        MiniBatchNMFTransformer,
+    ]
+):
+    SETTINGS = MiniBatchNMFSettings
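
For orientation, here is a minimal usage sketch of the PCA variant above. It is not part of the package: the sample rate, shapes, and the `AxisArray.TimeAxis` helper are illustrative assumptions, and it relies on the transformer being callable, as the `ezmsg-sigproc` base classes provide.

```python
import numpy as np
from ezmsg.util.messages.axisarray import AxisArray

# Hypothetical usage sketch -- not part of the released package.
xf = IncrementalPCATransformer(
    settings=IncrementalPCASettings(axis="!time", n_components=2)
)

# 50 samples x 8 channels. axis="!time" means: iterate over "time" and
# decompose across all remaining axes (here just "ch").
msg = AxisArray(
    data=np.random.randn(50, 8),
    dims=["time", "ch"],
    axes={"time": AxisArray.TimeAxis(fs=100.0)},  # assumed axis helper
    key="example",
)

xf.partial_fit(msg)  # adaptation is explicit for this class
out = xf(msg)        # project onto the two fitted components
```

`MiniBatchNMFTransformer` is driven the same way; per its settings docstrings, prefer a `batch_size` larger than the expected chunk and tune `forget_factor` to control how quickly the model adapts.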
@@ -0,0 +1,181 @@
+import typing
+
+import numpy as np
+import ezmsg.core as ez
+from ezmsg.util.messages.axisarray import AxisArray, replace
+from ezmsg.sigproc.base import (
+    CompositeProcessor,
+    BaseStatefulProcessor,
+    BaseTransformerUnit,
+)
+from ezmsg.sigproc.window import WindowTransformer
+
+from .adaptive_decomp import (
+    IncrementalPCASettings,
+    IncrementalPCATransformer,
+    MiniBatchNMFSettings,
+    MiniBatchNMFTransformer,
+)
+
+
+class IncrementalDecompSettings(ez.Settings):
+    axis: str = "!time"
+    n_components: int = 2
+    update_interval: float = 0.0
+    method: str = "pca"
+    batch_size: typing.Optional[int] = None
+    # PCA-specific settings
+    whiten: bool = False
+    # NMF-specific settings
+    init: str = "random"
+    beta_loss: str = "frobenius"
+    tol: float = 1e-3
+    alpha_W: float = 0.0
+    alpha_H: typing.Union[float, str] = "same"
+    l1_ratio: float = 0.0
+    forget_factor: float = 0.7
+
+
+class IncrementalDecompTransformer(
+    CompositeProcessor[IncrementalDecompSettings, AxisArray, AxisArray]
+):
+    """
+    Automates usage of IncrementalPCATransformer and MiniBatchNMFTransformer by using a WindowTransformer
+    to extract training samples, then calling partial_fit on the decomposition transformer.
+    """
+
+    @staticmethod
+    def _initialize_processors(
+        settings: IncrementalDecompSettings,
+    ) -> dict[str, BaseStatefulProcessor]:
+        # Create the appropriate decomposition transformer
+        if settings.method == "pca":
+            decomp_settings = IncrementalPCASettings(
+                axis=settings.axis,
+                n_components=settings.n_components,
+                batch_size=settings.batch_size,
+                whiten=settings.whiten,
+            )
+            decomp = IncrementalPCATransformer(settings=decomp_settings)
+        else:  # nmf
+            decomp_settings = MiniBatchNMFSettings(
+                axis=settings.axis,
+                n_components=settings.n_components,
+                batch_size=settings.batch_size if settings.batch_size else 1024,
+                init=settings.init,
+                beta_loss=settings.beta_loss,
+                tol=settings.tol,
+                alpha_W=settings.alpha_W,
+                alpha_H=settings.alpha_H,
+                l1_ratio=settings.l1_ratio,
+                forget_factor=settings.forget_factor,
+            )
+            decomp = MiniBatchNMFTransformer(settings=decomp_settings)
+
+        # Create windowing processor if update_interval is specified
+        if settings.update_interval > 0:
+            # TODO: This `iter_axis` is likely incorrect.
+            iter_axis = settings.axis[1:] if settings.axis.startswith("!") else "time"
+            windowing = WindowTransformer(
+                axis=iter_axis,
+                window_dur=settings.update_interval,
+                window_shift=settings.update_interval,
+                zero_pad_until="none",
+            )
+
+            return {
+                "decomp": decomp,
+                "windowing": windowing,
+            }
+
+        return {"decomp": decomp}
+
+    def _partial_fit_windowed(self, train_msg: AxisArray) -> None:
+        """
+        Helper function to do the partial_fit on the windowed message.
+        """
+        if np.prod(train_msg.data.shape) > 0:
+            # Windowing created a new "win" axis, but we don't actually want to use that
+            # in the message we send to the decomp processor.
+            axis_idx = train_msg.get_axis_idx("win")
+            win_axis = train_msg.axes["win"]
+            offsets = win_axis.value(np.asarray(range(train_msg.data.shape[axis_idx])))
+            for ix, _msg in enumerate(train_msg.iter_over_axis("win")):
+                _msg = replace(
+                    _msg,
+                    axes={
+                        **_msg.axes,
+                        "time": replace(
+                            _msg.axes["time"],
+                            offset=_msg.axes["time"].offset + offsets[ix],
+                        ),
+                    },
+                )
+                self._procs["decomp"].partial_fit(_msg)
+
+    def stateful_op(
+        self,
+        state: dict[str, tuple[typing.Any, int]] | None,
+        message: AxisArray,
+    ) -> tuple[dict[str, tuple[typing.Any, int]], AxisArray]:
+        state = state or {}
+
+        estim = self._procs["decomp"]._state.estimator
+        if not hasattr(estim, "components_") or estim.components_ is None:
+            # If the estimator has not been trained yet, train it with the first message
+            self._procs["decomp"].partial_fit(message)
+        elif "windowing" in self._procs:
+            state["windowing"], train_msg = self._procs["windowing"].stateful_op(
+                state.get("windowing", None), message
+            )
+            self._partial_fit_windowed(train_msg)
+
+        # Process the incoming message
+        state["decomp"], result = self._procs["decomp"].stateful_op(
+            state.get("decomp", None), message
+        )
+
+        return state, result
+
+    async def _aprocess(self, message: AxisArray) -> AxisArray:
+        """
+        Asynchronously process the incoming message.
+        This is nearly identical to the _process method, but the processors
+        are called asynchronously.
+        """
+        estim = self._procs["decomp"]._state.estimator
+        if not hasattr(estim, "components_") or estim.components_ is None:
+            # If the estimator has not been trained yet, train it with the first message
+            self._procs["decomp"].partial_fit(message)
+        elif "windowing" in self._procs:
+            # If windowing is enabled, extract training samples and perform partial_fit
+            train_msg = await self._procs["windowing"].__acall__(message)
+            self._partial_fit_windowed(train_msg)  # Not async
+
+        # Process the incoming message
+        decomp_result = await self._procs["decomp"].__acall__(message)
+
+        return decomp_result
+
+    def _process(self, message: AxisArray) -> AxisArray:
+        estim = self._procs["decomp"]._state.estimator
+        if not hasattr(estim, "components_") or estim.components_ is None:
+            # If the estimator has not been trained yet, train it with the first message
+            self._procs["decomp"].partial_fit(message)
+        elif "windowing" in self._procs:
+            # If windowing is enabled, extract training samples and perform partial_fit
+            train_msg = self._procs["windowing"](message)
+            self._partial_fit_windowed(train_msg)
+
+        # Process the incoming message
+        decomp_result = self._procs["decomp"](message)
+
+        return decomp_result


+class IncrementalDecompUnit(
+    BaseTransformerUnit[
+        IncrementalDecompSettings, AxisArray, AxisArray, IncrementalDecompTransformer
+    ]
+):
+    SETTINGS = IncrementalDecompSettings
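
A sketch of the automated variant, continuing the previous example. Values are illustrative; it assumes the composite transformer is callable like the processors it wraps, which is the path `_process` above implements.

```python
# Hypothetical usage sketch -- not part of the released package.
xf = IncrementalDecompTransformer(
    settings=IncrementalDecompSettings(
        axis="!time",
        n_components=2,
        method="pca",
        update_interval=2.0,  # gather ~2 s windows between partial_fit updates
    )
)

# The first message trains the estimator; subsequent messages are windowed
# into training batches for partial_fit and transformed on the way through.
out = xf(msg)  # `msg` as in the previous sketch
```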
@@ -0,0 +1 @@
+# Use of this module is deprecated. Please use `ezmsg.learn.process` instead.
@@ -0,0 +1,6 @@
+from ..process.adaptive_linear_regressor import (
+    AdaptiveLinearRegressorSettings as AdaptiveLinearRegressorSettings,
+    AdaptiveLinearRegressorState as AdaptiveLinearRegressorState,
+    AdaptiveLinearRegressorTransformer as AdaptiveLinearRegressorTransformer,
+    AdaptiveLinearRegressorUnit as AdaptiveLinearRegressorUnit,
+)
@@ -0,0 +1 @@
+from ..model.cca import IncrementalCCA as IncrementalCCA
@@ -0,0 +1,5 @@
+from ..process.linear_regressor import (
+    LinearRegressorSettings as LinearRegressorSettings,
+    LinearRegressorState as LinearRegressorState,
+    LinearRegressorTransformer as LinearRegressorTransformer,
+)
@@ -0,0 +1,5 @@
+from ..process.sgd import (
+    sgd_decoder as sgd_decoder,
+    SGDDecoderSettings as SGDDecoderSettings,
+    SGDDecoder as SGDDecoder,
+)
@@ -0,0 +1,6 @@
+from ..process.slda import (
+    SLDASettings as SLDASettings,
+    SLDAState as SLDAState,
+    SLDATransformer as SLDATransformer,
+    SLDA as SLDA,
+)
@@ -0,0 +1,122 @@
+import numpy as np
+from scipy import linalg
+
+
+class IncrementalCCA:
+    def __init__(
+        self,
+        n_components=2,
+        base_smoothing=0.95,
+        min_smoothing=0.5,
+        max_smoothing=0.99,
+        adaptation_rate=0.1,
+    ):
+        """
+        Parameters
+        ----------
+        n_components : int
+            Number of canonical components to compute
+        base_smoothing : float
+            Base smoothing factor (will be adapted)
+        min_smoothing : float
+            Minimum allowed smoothing factor
+        max_smoothing : float
+            Maximum allowed smoothing factor
+        adaptation_rate : float
+            How quickly to adjust the smoothing factor (between 0 and 1)
+        """
+        self.n_components = n_components
+        self.base_smoothing = base_smoothing
+        self.current_smoothing = base_smoothing
+        self.min_smoothing = min_smoothing
+        self.max_smoothing = max_smoothing
+        self.adaptation_rate = adaptation_rate
+        self.initialized = False
+
+    def initialize(self, d1, d2):
+        """Initialize the necessary matrices"""
+        self.d1 = d1
+        self.d2 = d2
+
+        # Initialize correlation matrices
+        self.C11 = np.zeros((d1, d1))
+        self.C22 = np.zeros((d2, d2))
+        self.C12 = np.zeros((d1, d2))
+
+        self.initialized = True
+
+    def _compute_change_magnitude(self, C11_new, C22_new, C12_new):
+        """Compute magnitude of change in correlation structure"""
+        # Frobenius norm of differences
+        diff11 = np.linalg.norm(C11_new - self.C11)
+        diff22 = np.linalg.norm(C22_new - self.C22)
+        diff12 = np.linalg.norm(C12_new - self.C12)
+
+        # Normalize by matrix sizes
+        diff11 /= self.d1 * self.d1
+        diff22 /= self.d2 * self.d2
+        diff12 /= self.d1 * self.d2
+
+        return (diff11 + diff22 + diff12) / 3
+
+    def _adapt_smoothing(self, change_magnitude):
+        """Adapt the smoothing factor based on detected changes"""
+        # If the change is large, decrease the smoothing factor
+        target_smoothing = self.base_smoothing * (1.0 - change_magnitude)
+        target_smoothing = np.clip(
+            target_smoothing, self.min_smoothing, self.max_smoothing
+        )
+
+        # Smooth the adaptation itself
+        self.current_smoothing = (
+            1 - self.adaptation_rate
+        ) * self.current_smoothing + self.adaptation_rate * target_smoothing
+
+    def partial_fit(self, X1, X2, update_projections=True):
+        """Update the model with new samples using adaptive smoothing.
+        Assumes X1 and X2 are already centered and scaled."""
+        if not self.initialized:
+            self.initialize(X1.shape[1], X2.shape[1])
+
+        # Compute new correlation matrices from the current batch
+        C11_new = X1.T @ X1 / X1.shape[0]
+        C22_new = X2.T @ X2 / X2.shape[0]
+        C12_new = X1.T @ X2 / X1.shape[0]
+
+        # Detect changes and adapt the smoothing factor
+        if self.C11.any():  # Skip first update
+            change_magnitude = self._compute_change_magnitude(C11_new, C22_new, C12_new)
+            self._adapt_smoothing(change_magnitude)
+
+        # Update with the current smoothing factor
+        alpha = self.current_smoothing
+        self.C11 = alpha * self.C11 + (1 - alpha) * C11_new
+        self.C22 = alpha * self.C22 + (1 - alpha) * C22_new
+        self.C12 = alpha * self.C12 + (1 - alpha) * C12_new
+
+        if update_projections:
+            self._update_projections()
+
+    def _update_projections(self):
+        """Update canonical vectors and correlations"""
+        eps = 1e-8
+        C11_reg = self.C11 + eps * np.eye(self.d1)
+        C22_reg = self.C22 + eps * np.eye(self.d2)
+
+        K = (
+            linalg.inv(linalg.sqrtm(C11_reg))
+            @ self.C12
+            @ linalg.inv(linalg.sqrtm(C22_reg))
+        )
+        U, self.correlations_, V = linalg.svd(K)
+
+        self.x_weights_ = linalg.inv(linalg.sqrtm(C11_reg)) @ U[:, : self.n_components]
+        self.y_weights_ = (
+            linalg.inv(linalg.sqrtm(C22_reg)) @ V.T[:, : self.n_components]
+        )
+
+    def transform(self, X1, X2):
+        """Project data onto canonical components"""
+        X1_proj = X1 @ self.x_weights_
+        X2_proj = X2 @ self.y_weights_
+        return X1_proj, X2_proj
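
A self-contained sketch of `IncrementalCCA` on synthetic two-view data. All names and shapes are illustrative; the inputs are zero-mean by construction, matching the `partial_fit` assumption that data arrive centered and scaled.

```python
import numpy as np

rng = np.random.default_rng(0)
mix1 = rng.standard_normal((2, 8))  # latent -> view-1 mixing (illustrative)
mix2 = rng.standard_normal((2, 6))  # latent -> view-2 mixing (illustrative)

cca = IncrementalCCA(n_components=2)
for _ in range(20):
    z = rng.standard_normal((256, 2))  # shared latent sources
    X1 = z @ mix1 + 0.1 * rng.standard_normal((256, 8))
    X2 = z @ mix2 + 0.1 * rng.standard_normal((256, 6))
    cca.partial_fit(X1, X2)  # smoothed updates of C11, C22, C12

X1_c, X2_c = cca.transform(X1, X2)  # canonical projections
print(cca.correlations_[:2])        # leading canonical correlations
```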