ezmsg-learn 1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/learn/__init__.py +2 -0
- ezmsg/learn/__version__.py +34 -0
- ezmsg/learn/dim_reduce/__init__.py +0 -0
- ezmsg/learn/dim_reduce/adaptive_decomp.py +284 -0
- ezmsg/learn/dim_reduce/incremental_decomp.py +181 -0
- ezmsg/learn/linear_model/__init__.py +1 -0
- ezmsg/learn/linear_model/adaptive_linear_regressor.py +6 -0
- ezmsg/learn/linear_model/cca.py +1 -0
- ezmsg/learn/linear_model/linear_regressor.py +5 -0
- ezmsg/learn/linear_model/sgd.py +5 -0
- ezmsg/learn/linear_model/slda.py +6 -0
- ezmsg/learn/model/__init__.py +0 -0
- ezmsg/learn/model/cca.py +122 -0
- ezmsg/learn/model/mlp.py +133 -0
- ezmsg/learn/model/mlp_old.py +49 -0
- ezmsg/learn/model/refit_kalman.py +401 -0
- ezmsg/learn/model/rnn.py +160 -0
- ezmsg/learn/model/transformer.py +175 -0
- ezmsg/learn/nlin_model/__init__.py +1 -0
- ezmsg/learn/nlin_model/mlp.py +6 -0
- ezmsg/learn/process/__init__.py +0 -0
- ezmsg/learn/process/adaptive_linear_regressor.py +157 -0
- ezmsg/learn/process/base.py +173 -0
- ezmsg/learn/process/linear_regressor.py +99 -0
- ezmsg/learn/process/mlp_old.py +200 -0
- ezmsg/learn/process/refit_kalman.py +407 -0
- ezmsg/learn/process/rnn.py +266 -0
- ezmsg/learn/process/sgd.py +131 -0
- ezmsg/learn/process/sklearn.py +274 -0
- ezmsg/learn/process/slda.py +119 -0
- ezmsg/learn/process/torch.py +378 -0
- ezmsg/learn/process/transformer.py +222 -0
- ezmsg/learn/util.py +66 -0
- ezmsg_learn-1.0.dist-info/METADATA +34 -0
- ezmsg_learn-1.0.dist-info/RECORD +36 -0
- ezmsg_learn-1.0.dist-info/WHEEL +4 -0
ezmsg/learn/__init__.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
# file generated by setuptools-scm
# don't change, don't track in version control

__all__ = [
    "__version__",
    "__version_tuple__",
    "version",
    "version_tuple",
    "__commit_id__",
    "commit_id",
]

# setuptools-scm idiom: a plain False constant so the typing import is never
# executed at runtime, while static type checkers still see the annotations.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple
    from typing import Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
    COMMIT_ID = Union[str, None]
else:
    VERSION_TUPLE = object
    COMMIT_ID = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID

__version__ = version = '1.0'
__version_tuple__ = version_tuple = (1, 0)

__commit_id__ = commit_id = None
|
|
File without changes
|
|
@@ -0,0 +1,284 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
|
|
3
|
+
from sklearn.decomposition import IncrementalPCA, MiniBatchNMF
|
|
4
|
+
import numpy as np
|
|
5
|
+
import ezmsg.core as ez
|
|
6
|
+
from ezmsg.sigproc.base import (
|
|
7
|
+
processor_state,
|
|
8
|
+
BaseAdaptiveTransformer,
|
|
9
|
+
BaseAdaptiveTransformerUnit,
|
|
10
|
+
)
|
|
11
|
+
from ezmsg.util.messages.axisarray import AxisArray, replace
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class AdaptiveDecompSettings(ez.Settings):
    """
    Settings shared by the adaptive decomposition transformers.

    ``axis`` selects the decomposition axis. A leading ``"!"`` (e.g. the
    default ``"!time"``) inverts the selection: iterate over that axis and
    decompose across all remaining axes combined.
    """

    # Target axis; "!name" means iterate over `name` and decompose the rest.
    axis: str = "!time"
    # Number of components the underlying sklearn estimator should produce.
    n_components: int = 2
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@processor_state
class AdaptiveDecompState:
    """Mutable per-stream state for :class:`AdaptiveDecompTransformer`."""

    # Pre-built zero-length output AxisArray reused for every output message.
    template: AxisArray | None = None
    # (iter_axis, target_axes, off_target_axes) as computed from the input dims.
    axis_groups: tuple[str, list[str], list[str]] | None = None
    # The sklearn estimator instance (IncrementalPCA or MiniBatchNMF).
    estimator: typing.Any = None
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# TypeVar binding the transformer's generic parameter to one of the two
# supported sklearn incremental estimators.
EstimatorType = typing.TypeVar(
    "EstimatorType", bound=typing.Union[IncrementalPCA, MiniBatchNMF]
)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class AdaptiveDecompTransformer(
    BaseAdaptiveTransformer[
        AdaptiveDecompSettings, AxisArray, AxisArray, AdaptiveDecompState
    ],
    typing.Generic[EstimatorType],
):
    """
    Base class for adaptive decomposition transformers. See IncrementalPCATransformer and MiniBatchNMFTransformer
    for concrete implementations.

    Note that for these classes, adaptation is not automatic. The user must call partial_fit on the transformer.
    For automated adaptation, see IncrementalDecompTransformer.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Build the sklearn estimator eagerly so partial_fit can be called
        # before the first _process.
        self._state.estimator = self._create_estimator()

    @classmethod
    def get_message_type(cls, dir: str) -> typing.Type[AxisArray]:
        # Override because we don't reuse the generic types.
        return AxisArray

    @classmethod
    def get_estimator_type(cls) -> typing.Type[EstimatorType]:
        """Return the concrete estimator class bound by the subclass's generic
        parameter (e.g. IncrementalPCA for AdaptiveDecompTransformer[IncrementalPCA])."""
        return typing.get_args(cls.__orig_bases__[0])[0]

    def _create_estimator(self) -> EstimatorType:
        """Instantiate the estimator, forwarding all settings except `axis`.

        NOTE(review): this assumes every non-`axis` settings field is a valid
        estimator constructor kwarg — keep settings classes in sync with sklearn.
        """
        estimator_klass = self.get_estimator_type()
        estimator_settings = self.settings.__dict__.copy()
        estimator_settings.pop("axis")
        return estimator_klass(**estimator_settings)

    def _calculate_axis_groups(self, message: AxisArray):
        """Partition `message.dims` into (iter_axis, targ_axes, off_targ_axes)
        and store the result in state."""
        if self.settings.axis.startswith("!"):
            # Iterate over the !axis and collapse all other axes
            iter_axis = self.settings.axis[1:]
            it_ax_ix = message.get_axis_idx(iter_axis)
            targ_axes = message.dims[:it_ax_ix] + message.dims[it_ax_ix + 1 :]
            off_targ_axes = []
        else:
            # Do PCA on the parameterized axis
            targ_axes = [self.settings.axis]
            # Iterate over streaming axis
            iter_axis = "win" if "win" in message.dims else "time"
            if iter_axis == self.settings.axis:
                raise ValueError(
                    f"Iterating axis ({iter_axis}) cannot be the same as the target axis ({self.settings.axis})"
                )
            it_ax_ix = message.get_axis_idx(iter_axis)
            # Remaining axes are to be treated independently
            off_targ_axes = [
                _
                for _ in (message.dims[:it_ax_ix] + message.dims[it_ax_ix + 1 :])
                if _ != self.settings.axis
            ]
        self._state.axis_groups = iter_axis, targ_axes, off_targ_axes

    def _hash_message(self, message: AxisArray) -> int:
        """Hash the per-sample shape (all dims except the iteration axis) and the
        stream key; a change triggers a state reset in partial_fit."""
        iter_axis = (
            self.settings.axis[1:]
            if self.settings.axis.startswith("!")
            else ("win" if "win" in message.dims else "time")
        )
        ax_idx = message.get_axis_idx(iter_axis)
        sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
        return hash((sample_shape, message.key))

    def _reset_state(self, message: AxisArray) -> None:
        """Reset state"""
        self._calculate_axis_groups(message)
        iter_axis, targ_axes, off_targ_axes = self._state.axis_groups

        # Template: a zero-length (along iter_axis) AxisArray with the output
        # dims/axes, reused by _process for every emitted message.
        out_dims = [iter_axis] + off_targ_axes
        out_axes = {
            iter_axis: message.axes[iter_axis],
            **{k: message.axes[k] for k in off_targ_axes},
        }
        # If a single target axis is being decomposed, keep its name; otherwise
        # the collapsed axes are replaced by a synthetic "components" axis.
        if len(targ_axes) == 1:
            targ_ax_name = targ_axes[0]
        else:
            targ_ax_name = "components"
        out_dims += [targ_ax_name]
        out_axes[targ_ax_name] = AxisArray.CoordinateAxis(
            data=np.arange(self.settings.n_components).astype(str),
            dims=[targ_ax_name],
            unit="component",
        )
        out_shape = [message.data.shape[message.get_axis_idx(_)] for _ in off_targ_axes]
        out_shape = (0,) + tuple(out_shape) + (self.settings.n_components,)
        self._state.template = replace(
            message,
            data=np.zeros(out_shape, dtype=float),
            dims=out_dims,
            axes=out_axes,
        )

    def _process(self, message: AxisArray) -> AxisArray:
        """Project the message onto the fitted components.

        Returns the zero-length template unchanged when the input chunk is empty
        or (data-wise) when the estimator has not been fitted yet.
        """
        iter_axis, targ_axes, off_targ_axes = self._state.axis_groups
        ax_idx = message.get_axis_idx(iter_axis)
        in_dat = message.data

        if in_dat.shape[ax_idx] == 0:
            return self._state.template

        # Re-order axes
        # NOTE(review): assumes dims already arrive ordered as
        # [iter_axis] + off_targ_axes + targ_axes; transposition is unimplemented.
        sorted_dims_exp = [iter_axis] + off_targ_axes + targ_axes
        if message.dims != sorted_dims_exp:
            # TODO: Implement axes transposition if needed
            # re_order = [ax_idx] + off_targ_inds + targ_inds
            # np.transpose(in_dat, re_order)
            pass

        # fold [iter_axis] + off_targ_axes together and fold targ_axes together
        d2 = np.prod(in_dat.shape[len(off_targ_axes) + 1 :])
        in_dat = in_dat.reshape((-1, d2))

        replace_kwargs = {
            "axes": {**self._state.template.axes, iter_axis: message.axes[iter_axis]},
        }

        # Transform data. sklearn estimators expose `components_` only after the
        # first (partial_)fit; before that only the axes are refreshed.
        if hasattr(self._state.estimator, "components_"):
            decomp_dat = self._state.estimator.transform(in_dat).reshape(
                (-1,) + self._state.template.data.shape[1:]
            )
            replace_kwargs["data"] = decomp_dat

        return replace(self._state.template, **replace_kwargs)

    def partial_fit(self, message: AxisArray) -> None:
        """Incrementally update the estimator from `message` (manual adaptation)."""
        # Check if we need to reset state
        msg_hash = self._hash_message(message)
        if self._hash != msg_hash:
            self._reset_state(message)
            self._hash = msg_hash

        iter_axis, targ_axes, off_targ_axes = self._state.axis_groups
        ax_idx = message.get_axis_idx(iter_axis)
        in_dat = message.data

        # Nothing to learn from an empty chunk.
        if in_dat.shape[ax_idx] == 0:
            return

        # Re-order axes if needed
        sorted_dims_exp = [iter_axis] + off_targ_axes + targ_axes
        if message.dims != sorted_dims_exp:
            # TODO: Implement axes transposition if needed
            pass

        # fold [iter_axis] + off_targ_axes together and fold targ_axes together
        d2 = np.prod(in_dat.shape[len(off_targ_axes) + 1 :])
        in_dat = in_dat.reshape((-1, d2))

        # Fit the estimator
        self._state.estimator.partial_fit(in_dat)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class IncrementalPCASettings(AdaptiveDecompSettings):
    """Settings forwarded to :class:`sklearn.decomposition.IncrementalPCA`."""

    # Additional settings specific to PCA
    whiten: bool = False
    batch_size: typing.Optional[int] = None
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class IncrementalPCATransformer(AdaptiveDecompTransformer[IncrementalPCA]):
    """Adaptive decomposition backed by sklearn's IncrementalPCA."""

    pass
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
class MiniBatchNMFSettings(AdaptiveDecompSettings):
    """Settings forwarded to :class:`sklearn.decomposition.MiniBatchNMF`."""

    # Additional settings specific to NMF
    init: typing.Optional[str] = "random"
    """
    'random', 'nndsvd', 'nndsvda', 'nndsvdar', 'custom', or None
    """

    batch_size: int = 1024
    """
    batch_size is used only when doing a full fit (i.e., a reset),
    or as the exponent to forget_factor, where a very small batch_size
    will cause the model to update more slowly.
    It is better to set batch_size to a larger number than the expected
    chunk size and instead use forget_factor to control the learning rate.
    """

    beta_loss: typing.Union[str, float] = "frobenius"
    """
    'frobenius', 'kullback-leibler', 'itakura-saito'
    Note that values different from 'frobenius'
    (or 2) and 'kullback-leibler' (or 1) lead to significantly slower
    fits. Note that for `beta_loss <= 0` (or 'itakura-saito'), the input
    matrix `X` cannot contain zeros.
    """

    # Remaining fields mirror sklearn MiniBatchNMF constructor parameters.
    tol: float = 1e-4

    max_no_improvement: typing.Optional[int] = None

    max_iter: int = 200

    alpha_W: float = 0.0

    alpha_H: typing.Union[float, str] = "same"

    l1_ratio: float = 0.0

    forget_factor: float = 0.7
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
class MiniBatchNMFTransformer(AdaptiveDecompTransformer[MiniBatchNMF]):
    """Adaptive decomposition backed by sklearn's MiniBatchNMF."""

    pass
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
# TypeVars pairing each concrete settings class with its transformer for the
# generic unit below.
SettingsType = typing.TypeVar(
    "SettingsType", bound=typing.Union[IncrementalPCASettings, MiniBatchNMFSettings]
)
TransformerType = typing.TypeVar(
    "TransformerType",
    bound=typing.Union[IncrementalPCATransformer, MiniBatchNMFTransformer],
)
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
class BaseAdaptiveDecompUnit(
    BaseAdaptiveTransformerUnit[
        SettingsType,
        AxisArray,
        AxisArray,
        TransformerType,
    ],
    typing.Generic[SettingsType, TransformerType],
):
    """ezmsg unit wrapping an adaptive decomposition transformer.

    In addition to the inherited signal input/output, it exposes
    ``INPUT_SAMPLE``: messages received there are used for (asynchronous)
    partial fitting only and are not forwarded downstream.
    """

    # Training-data stream; each message drives one partial_fit update.
    INPUT_SAMPLE = ez.InputStream(AxisArray)

    @ez.subscriber(INPUT_SAMPLE)
    async def on_sample(self, msg: AxisArray) -> None:
        # Delegate to the processor's async partial-fit entry point.
        await self.processor.apartial_fit(msg)
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
class IncrementalPCAUnit(
    BaseAdaptiveDecompUnit[
        IncrementalPCASettings,
        IncrementalPCATransformer,
    ]
):
    """ezmsg unit for incremental PCA decomposition."""

    SETTINGS = IncrementalPCASettings
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
class MiniBatchNMFUnit(
    BaseAdaptiveDecompUnit[
        MiniBatchNMFSettings,
        MiniBatchNMFTransformer,
    ]
):
    """ezmsg unit for mini-batch NMF decomposition."""

    SETTINGS = MiniBatchNMFSettings
|
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import ezmsg.core as ez
|
|
5
|
+
from ezmsg.util.messages.axisarray import AxisArray, replace
|
|
6
|
+
from ezmsg.sigproc.base import (
|
|
7
|
+
CompositeProcessor,
|
|
8
|
+
BaseStatefulProcessor,
|
|
9
|
+
BaseTransformerUnit,
|
|
10
|
+
)
|
|
11
|
+
from ezmsg.sigproc.window import WindowTransformer
|
|
12
|
+
|
|
13
|
+
from .adaptive_decomp import (
|
|
14
|
+
IncrementalPCASettings,
|
|
15
|
+
IncrementalPCATransformer,
|
|
16
|
+
MiniBatchNMFSettings,
|
|
17
|
+
MiniBatchNMFTransformer,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class IncrementalDecompSettings(ez.Settings):
    """Settings for :class:`IncrementalDecompTransformer`.

    ``method`` selects the estimator ("pca" or anything else for NMF);
    ``update_interval`` (seconds, 0 disables) controls how often training
    windows are extracted for automatic partial fitting.
    """

    axis: str = "!time"
    n_components: int = 2
    # Window duration/shift for automatic partial_fit; 0.0 disables windowing.
    update_interval: float = 0.0
    # "pca" -> IncrementalPCA; any other value -> MiniBatchNMF.
    method: str = "pca"
    batch_size: typing.Optional[int] = None
    # PCA specific settings
    whiten: bool = False
    # NMF specific settings
    init: str = "random"
    beta_loss: str = "frobenius"
    tol: float = 1e-3
    alpha_W: float = 0.0
    alpha_H: typing.Union[float, str] = "same"
    l1_ratio: float = 0.0
    forget_factor: float = 0.7
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class IncrementalDecompTransformer(
    CompositeProcessor[IncrementalDecompSettings, AxisArray, AxisArray]
):
    """
    Automates usage of IncrementalPCATransformer and MiniBatchNMFTransformer by using a WindowTransformer
    to extract training samples then calls partial_fit on the decomposition transformer.
    """

    @staticmethod
    def _initialize_processors(
        settings: IncrementalDecompSettings,
    ) -> dict[str, BaseStatefulProcessor]:
        """Build the sub-processors: always "decomp", plus "windowing" when
        `update_interval` > 0."""
        # Create the appropriate decomposition transformer
        if settings.method == "pca":
            decomp_settings = IncrementalPCASettings(
                axis=settings.axis,
                n_components=settings.n_components,
                batch_size=settings.batch_size,
                whiten=settings.whiten,
            )
            decomp = IncrementalPCATransformer(settings=decomp_settings)
        else:  # nmf
            decomp_settings = MiniBatchNMFSettings(
                axis=settings.axis,
                n_components=settings.n_components,
                # MiniBatchNMFSettings requires an int; fall back to its default.
                batch_size=settings.batch_size if settings.batch_size else 1024,
                init=settings.init,
                beta_loss=settings.beta_loss,
                tol=settings.tol,
                alpha_W=settings.alpha_W,
                alpha_H=settings.alpha_H,
                l1_ratio=settings.l1_ratio,
                forget_factor=settings.forget_factor,
            )
            decomp = MiniBatchNMFTransformer(settings=decomp_settings)

        # Create windowing processor if update_interval is specified
        if settings.update_interval > 0:
            # TODO: This `iter_axis` is likely incorrect.
            iter_axis = settings.axis[1:] if settings.axis.startswith("!") else "time"
            windowing = WindowTransformer(
                axis=iter_axis,
                window_dur=settings.update_interval,
                window_shift=settings.update_interval,
                zero_pad_until="none",
            )

            return {
                "decomp": decomp,
                "windowing": windowing,
            }

        return {"decomp": decomp}

    def _partial_fit_windowed(self, train_msg: AxisArray) -> None:
        """
        Helper function to do the partial_fit on the windowed message.
        """
        # Skip entirely empty windows (no samples to learn from).
        if np.prod(train_msg.data.shape) > 0:
            # Windowing created a new "win" axis, but we don't actually want to use that
            # in the message we send to the decomp processor.
            axis_idx = train_msg.get_axis_idx("win")
            win_axis = train_msg.axes["win"]
            # Per-window time offsets so each un-windowed message keeps a
            # correct absolute "time" offset.
            offsets = win_axis.value(np.asarray(range(train_msg.data.shape[axis_idx])))
            for ix, _msg in enumerate(train_msg.iter_over_axis("win")):
                _msg = replace(
                    _msg,
                    axes={
                        **_msg.axes,
                        "time": replace(
                            _msg.axes["time"],
                            offset=_msg.axes["time"].offset + offsets[ix],
                        ),
                    },
                )
                self._procs["decomp"].partial_fit(_msg)

    def stateful_op(
        self,
        state: dict[str, tuple[typing.Any, int]] | None,
        message: AxisArray,
    ) -> tuple[dict[str, tuple[typing.Any, int]], AxisArray]:
        """Functional-style entry point: thread explicit sub-processor state
        through windowed training and the decomposition step."""
        state = state or {}

        estim = self._procs["decomp"]._state.estimator
        if not hasattr(estim, "components_") or estim.components_ is None:
            # If the estimator has not been trained once, train it with the first message
            self._procs["decomp"].partial_fit(message)
        elif "windowing" in self._procs:
            state["windowing"], train_msg = self._procs["windowing"].stateful_op(
                state.get("windowing", None), message
            )
            self._partial_fit_windowed(train_msg)

        # Process the incoming message
        state["decomp"], result = self._procs["decomp"].stateful_op(
            state.get("decomp", None), message
        )

        return state, result

    async def _aprocess(self, message: AxisArray) -> AxisArray:
        """
        Asynchronously process the incoming message.
        This is nearly identical to the _process method, but the processors
        are called asynchronously.
        """
        estim = self._procs["decomp"]._state.estimator
        if not hasattr(estim, "components_") or estim.components_ is None:
            # If the estimator has not been trained once, train it with the first message
            self._procs["decomp"].partial_fit(message)
        elif "windowing" in self._procs:
            # If windowing is enabled, extract training samples and perform partial_fit
            train_msg = await self._procs["windowing"].__acall__(message)
            self._partial_fit_windowed(train_msg)  # Non async

        # Process the incoming message
        decomp_result = await self._procs["decomp"].__acall__(message)

        return decomp_result

    def _process(self, message: AxisArray) -> AxisArray:
        """Synchronous processing: bootstrap-fit (first message) or windowed
        partial_fit, then transform the incoming message."""
        estim = self._procs["decomp"]._state.estimator
        if not hasattr(estim, "components_") or estim.components_ is None:
            # If the estimator has not been trained once, train it with the first message
            self._procs["decomp"].partial_fit(message)
        elif "windowing" in self._procs:
            # If windowing is enabled, extract training samples and perform partial_fit
            train_msg = self._procs["windowing"](message)
            self._partial_fit_windowed(train_msg)

        # Process the incoming message
        decomp_result = self._procs["decomp"](message)

        return decomp_result
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class IncrementalDecompUnit(
    BaseTransformerUnit[
        IncrementalDecompSettings, AxisArray, AxisArray, IncrementalDecompTransformer
    ]
):
    """ezmsg unit for self-adapting incremental decomposition (PCA or NMF)."""

    SETTINGS = IncrementalDecompSettings
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Use of this module is deprecated. Please use `ezmsg.learn.process` instead.
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
from ..process.adaptive_linear_regressor import (
|
|
2
|
+
AdaptiveLinearRegressorSettings as AdaptiveLinearRegressorSettings,
|
|
3
|
+
AdaptiveLinearRegressorState as AdaptiveLinearRegressorState,
|
|
4
|
+
AdaptiveLinearRegressorTransformer as AdaptiveLinearRegressorTransformer,
|
|
5
|
+
AdaptiveLinearRegressorUnit as AdaptiveLinearRegressorUnit,
|
|
6
|
+
)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from ..model.cca import IncrementalCCA as IncrementalCCA
|
|
File without changes
|
ezmsg/learn/model/cca.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from scipy import linalg
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class IncrementalCCA:
    """Incremental canonical correlation analysis (CCA) with adaptive smoothing.

    Cross- and auto-correlation matrices are updated with an exponential
    moving average whose smoothing factor adapts to the magnitude of change
    in the correlation structure: large changes shrink the factor (faster
    tracking), small changes let it relax toward ``base_smoothing``.
    """

    def __init__(
        self,
        n_components=2,
        base_smoothing=0.95,
        min_smoothing=0.5,
        max_smoothing=0.99,
        adaptation_rate=0.1,
    ):
        """
        Parameters:
        -----------
        n_components : int
            Number of canonical components to compute
        base_smoothing : float
            Base smoothing factor (will be adapted)
        min_smoothing : float
            Minimum allowed smoothing factor
        max_smoothing : float
            Maximum allowed smoothing factor
        adaptation_rate : float
            How quickly to adjust smoothing factor (between 0 and 1)
        """
        self.n_components = n_components
        self.base_smoothing = base_smoothing
        self.current_smoothing = base_smoothing
        self.min_smoothing = min_smoothing
        self.max_smoothing = max_smoothing
        self.adaptation_rate = adaptation_rate
        # Correlation matrices are lazily allocated on first partial_fit.
        self.initialized = False

    def initialize(self, d1, d2):
        """Initialize the necessary matrices.

        Parameters
        ----------
        d1, d2 : int
            Feature dimensionality of the two views X1 and X2.
        """
        self.d1 = d1
        self.d2 = d2

        # Initialize correlation matrices
        self.C11 = np.zeros((d1, d1))
        self.C22 = np.zeros((d2, d2))
        self.C12 = np.zeros((d1, d2))

        self.initialized = True

    def _compute_change_magnitude(self, C11_new, C22_new, C12_new):
        """Compute magnitude of change in correlation structure."""
        # Frobenius norm of differences
        diff11 = np.linalg.norm(C11_new - self.C11)
        diff22 = np.linalg.norm(C22_new - self.C22)
        diff12 = np.linalg.norm(C12_new - self.C12)

        # Normalize by matrix sizes
        diff11 /= self.d1 * self.d1
        diff22 /= self.d2 * self.d2
        diff12 /= self.d1 * self.d2

        return (diff11 + diff22 + diff12) / 3

    def _adapt_smoothing(self, change_magnitude):
        """Adapt smoothing factor based on detected changes."""
        # If change is large, decrease smoothing factor
        target_smoothing = self.base_smoothing * (1.0 - change_magnitude)
        target_smoothing = np.clip(
            target_smoothing, self.min_smoothing, self.max_smoothing
        )

        # Smooth the adaptation itself
        self.current_smoothing = (
            1 - self.adaptation_rate
        ) * self.current_smoothing + self.adaptation_rate * target_smoothing

    def partial_fit(self, X1, X2, update_projections=True):
        """Update the model with new samples using adaptive smoothing

        Assumes X1 and X2 are already centered and scaled"""
        if not self.initialized:
            self.initialize(X1.shape[1], X2.shape[1])

        # Compute new correlation matrices from current batch
        C11_new = X1.T @ X1 / X1.shape[0]
        C22_new = X2.T @ X2 / X2.shape[0]
        C12_new = X1.T @ X2 / X1.shape[0]

        # Detect changes and adapt smoothing factor
        if self.C11.any():  # Skip first update
            change_magnitude = self._compute_change_magnitude(C11_new, C22_new, C12_new)
            self._adapt_smoothing(change_magnitude)

        # Update with current smoothing factor
        alpha = self.current_smoothing
        self.C11 = alpha * self.C11 + (1 - alpha) * C11_new
        self.C22 = alpha * self.C22 + (1 - alpha) * C22_new
        self.C12 = alpha * self.C12 + (1 - alpha) * C12_new

        if update_projections:
            self._update_projections()

    def _update_projections(self):
        """Update canonical vectors and correlations.

        Solves CCA via SVD of K = C11^{-1/2} C12 C22^{-1/2}; the singular
        values are the canonical correlations (stored unclipped in
        ``correlations_``), and the weight matrices map each view onto its
        top ``n_components`` canonical directions.
        """
        # Small ridge keeps the matrix square roots invertible.
        eps = 1e-8
        C11_reg = self.C11 + eps * np.eye(self.d1)
        C22_reg = self.C22 + eps * np.eye(self.d2)

        # Compute each inverse matrix square root once; previously both were
        # evaluated twice per update (sqrtm + inv dominate the cost here).
        inv_sqrt_C11 = linalg.inv(linalg.sqrtm(C11_reg))
        inv_sqrt_C22 = linalg.inv(linalg.sqrtm(C22_reg))

        K = inv_sqrt_C11 @ self.C12 @ inv_sqrt_C22
        # linalg.svd returns V transposed (Vh).
        U, self.correlations_, Vh = linalg.svd(K)

        self.x_weights_ = inv_sqrt_C11 @ U[:, : self.n_components]
        self.y_weights_ = inv_sqrt_C22 @ Vh.T[:, : self.n_components]

    def transform(self, X1, X2):
        """Project data onto canonical components.

        Returns a tuple ``(X1_proj, X2_proj)`` of shape
        (n_samples, n_components) each. Requires a prior ``partial_fit`` with
        ``update_projections=True``.
        """
        X1_proj = X1 @ self.x_weights_
        X2_proj = X2 @ self.y_weights_
        return X1_proj, X2_proj
|