ezmsg_learn-1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ezmsg/learn/process/transformer.py ADDED
@@ -0,0 +1,222 @@
+ import typing
+
+ import ezmsg.core as ez
+ import torch
+ from ezmsg.sigproc.base import BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit
+ from ezmsg.sigproc.sampler import SampleMessage
+ from ezmsg.sigproc.util.profile import profile_subpub
+ from ezmsg.util.messages.axisarray import AxisArray
+ from ezmsg.util.messages.util import replace
+
+ from .base import ModelInitMixin
+ from .torch import (
+     TorchModelSettings,
+     TorchModelState,
+     TorchProcessorMixin,
+ )
+
+
+ class TransformerSettings(TorchModelSettings):
+     model_class: str = "ezmsg.learn.model.transformer.TransformerModel"
+     """
+     Fully qualified class path of the model to use.
+     For this processor it should remain "ezmsg.learn.model.transformer.TransformerModel".
+     """
+     autoregressive_head: str | None = None
+     """
+     The name of the output head used for autoregressive decoding.
+     This should match one of the keys in the model's output dictionary.
+     If None, the first output head will be used.
+     """
+     max_cache_len: int | None = 128
+     """
+     Maximum length of the target sequence cache for autoregressive decoding.
+     This limits the context length during decoding to prevent excessive memory usage.
+     If set to None, the cache will grow indefinitely.
+     """
+
+
+ class TransformerState(TorchModelState):
+     ar_head: str | None = None
+     """
+     The name of the autoregressive head used for decoding.
+     This is set based on the `autoregressive_head` setting.
+     If None, the first output head will be used.
+     """
+     tgt_cache: torch.Tensor | None = None
+     """
+     Cache for the target sequence used in autoregressive decoding.
+     This is updated with each processed message to maintain context.
+     """
+
+
+ class TransformerProcessor(
+     BaseAdaptiveTransformer[
+         TransformerSettings, AxisArray, AxisArray, TransformerState
+     ],
+     TorchProcessorMixin,
+     ModelInitMixin,
+ ):
+     @property
+     def has_decoder(self) -> bool:
+         return self.settings.model_kwargs.get("decoder_layers", 0) > 0
+
+     def reset_cache(self) -> None:
+         self._state.tgt_cache = None
+
+     def _reset_state(self, message: AxisArray) -> None:
+         model_kwargs = dict(self.settings.model_kwargs or {})
+         self._common_reset_state(message, model_kwargs)
+         self._init_optimizer()
+         self._validate_loss_keys(list(self._state.chan_ax.keys()))
+
+         self._state.tgt_cache = None
+         if (
+             self.settings.autoregressive_head is not None
+             and self.settings.autoregressive_head not in self._state.chan_ax
+         ):
+             raise ValueError(
+                 f"Autoregressive head '{self.settings.autoregressive_head}' not found in target dictionary keys: {list(self._state.chan_ax.keys())}"
+             )
+         self._state.ar_head = (
+             self.settings.autoregressive_head
+             if self.settings.autoregressive_head is not None
+             else list(self._state.chan_ax.keys())[0]
+         )
+
+     def _process(self, message: AxisArray) -> list[AxisArray]:
+         # If there is no decoder, fall back to regular (non-autoregressive) processing
+         if not self.has_decoder:
+             return self._common_process(message)
+
+         x = self._to_tensor(message.data)
+         x, _ = self._ensure_batched(x)
+         if x.shape[0] > 1:
+             raise ValueError("Autoregressive decoding only supports batch size 1.")
+
+         with torch.no_grad():
+             y_pred = self._state.model(x, tgt=self._state.tgt_cache)
+
+         pred = y_pred[self._state.ar_head]
+         if self._state.tgt_cache is None:
+             self._state.tgt_cache = pred[:, -1:, :]
+         else:
+             self._state.tgt_cache = torch.cat(
+                 [self._state.tgt_cache, pred[:, -1:, :]], dim=1
+             )
+         if self.settings.max_cache_len is not None:
+             if self._state.tgt_cache.shape[1] > self.settings.max_cache_len:
+                 # Trim the cache to the maximum length
+                 self._state.tgt_cache = self._state.tgt_cache[
+                     :, -self.settings.max_cache_len :, :
+                 ]
+
+         if isinstance(y_pred, dict):
+             return [
+                 replace(
+                     message,
+                     data=out.squeeze(0).cpu().numpy(),
+                     axes={**message.axes, "ch": self._state.chan_ax[key]},
+                     key=key,
+                 )
+                 for key, out in y_pred.items()
+             ]
+         else:
+             return [
+                 replace(
+                     message,
+                     data=y_pred.squeeze(0).cpu().numpy(),
+                     axes={**message.axes, "ch": self._state.chan_ax["output"]},
+                 )
+             ]
+
+     def partial_fit(self, message: SampleMessage) -> None:
+         self._state.model.train()
+
+         X = self._to_tensor(message.sample.data)
+         X, batched = self._ensure_batched(X)
+
+         y_targ = message.trigger.value
+         if not isinstance(y_targ, dict):
+             y_targ = {"output": y_targ}
+         y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
+         # Add batch dimension to y_targ values if missing
+         if batched:
+             for key in y_targ:
+                 y_targ[key] = y_targ[key].unsqueeze(0)
+
+         loss_fns = self.settings.loss_fn
+         if loss_fns is None:
+             raise ValueError("loss_fn must be provided in settings to use partial_fit")
+         if not isinstance(loss_fns, dict):
+             loss_fns = {k: loss_fns for k in y_targ.keys()}
+
+         weights = self.settings.loss_weights or {}
+
+         if self.has_decoder:
+             if X.shape[0] != 1:
+                 raise ValueError("Autoregressive decoding only supports batch size 1.")
+
+             # Create shifted target for autoregressive head
+             tgt_tensor = y_targ[self._state.ar_head]
+             tgt = torch.cat(
+                 [
+                     torch.zeros(
+                         (1, 1, tgt_tensor.shape[-1]),
+                         dtype=tgt_tensor.dtype,
+                         device=tgt_tensor.device,
+                     ),
+                     tgt_tensor[:, :-1, :],
+                 ],
+                 dim=1,
+             )
+
+             # Reset tgt_cache before the teacher-forced forward pass to avoid stale context
+             self.reset_cache()
+             y_pred = self._state.model(X, tgt=tgt)
+         else:
+             # For non-autoregressive models, use the model directly
+             y_pred = self._state.model(X)
+
+         if not isinstance(y_pred, dict):
+             y_pred = {"output": y_pred}
+
+         with torch.set_grad_enabled(True):
+             losses = []
+             for key in y_targ.keys():
+                 loss_fn = loss_fns.get(key)
+                 if loss_fn is None:
+                     raise ValueError(
+                         f"Loss function for key '{key}' is not defined in settings."
+                     )
+                 loss = loss_fn(y_pred[key], y_targ[key])
+                 weight = weights.get(key, 1.0)
+                 losses.append(loss * weight)
+             total_loss = sum(losses)
+
+         self._state.optimizer.zero_grad()
+         total_loss.backward()
+         self._state.optimizer.step()
+         if self._state.scheduler is not None:
+             self._state.scheduler.step()
+
+         self._state.model.eval()
+
+
+ class TransformerUnit(
+     BaseAdaptiveTransformerUnit[
+         TransformerSettings,
+         AxisArray,
+         AxisArray,
+         TransformerProcessor,
+     ]
+ ):
+     SETTINGS = TransformerSettings
+
+     @ez.subscriber(BaseAdaptiveTransformerUnit.INPUT_SIGNAL, zero_copy=True)
+     @ez.publisher(BaseAdaptiveTransformerUnit.OUTPUT_SIGNAL)
+     @profile_subpub(trace_oldest=False)
+     async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
+         results = await self.processor.__acall__(message)
+         for result in results:
+             yield self.OUTPUT_SIGNAL, result
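
The settings above drive both inference (`_process`) and online training (`partial_fit`). The following is a minimal configuration sketch, not taken from the package: it assumes that `TorchModelSettings` exposes `model_kwargs`, `loss_fn`, and `loss_weights` as fields (as the processor code above implies) and that the unit is constructed with its settings instance in the usual ezmsg fashion. The `TransformerModel` kwargs are illustrative only.

```python
import torch

from ezmsg.learn.process.transformer import TransformerSettings, TransformerUnit

settings = TransformerSettings(
    # Sketch: field names beyond those defined above are assumed to be
    # inherited from TorchModelSettings, as used in partial_fit/has_decoder.
    # "decoder_layers" > 0 makes has_decoder True, enabling autoregressive
    # decoding with the rolling tgt_cache (bounded by max_cache_len).
    model_kwargs={"decoder_layers": 2},
    autoregressive_head="output",  # must match a key in the model's output dict
    max_cache_len=128,
    loss_fn={"output": torch.nn.MSELoss()},  # required for partial_fit
    loss_weights={"output": 1.0},
)
unit = TransformerUnit(settings)
```
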
ezmsg/learn/util.py ADDED
@@ -0,0 +1,66 @@
+ from enum import Enum
+ from dataclasses import dataclass, field
+ import typing
+
+ from ezmsg.util.messages.axisarray import AxisArray
+ import sklearn.linear_model
+ import river.linear_model
+ # from sklearn.neural_network import MLPClassifier
+
+
+ class RegressorType(str, Enum):
+     ADAPTIVE = "adaptive"
+     STATIC = "static"
+
+
+ class AdaptiveLinearRegressor(str, Enum):
+     LINEAR = "linear"
+     LOGISTIC = "logistic"
+     SGD = "sgd"
+     PAR = "par"  # passive-aggressive
+     # MLP = "mlp"
+
+
+ class StaticLinearRegressor(str, Enum):
+     LINEAR = "linear"
+     RIDGE = "ridge"
+
+
+ ADAPTIVE_REGRESSORS = {
+     AdaptiveLinearRegressor.LINEAR: river.linear_model.LinearRegression,
+     AdaptiveLinearRegressor.LOGISTIC: river.linear_model.LogisticRegression,
+     AdaptiveLinearRegressor.SGD: sklearn.linear_model.SGDRegressor,
+     AdaptiveLinearRegressor.PAR: sklearn.linear_model.PassiveAggressiveRegressor,
+     # AdaptiveLinearRegressor.MLP: MLPClassifier,
+ }
+
+
+ STATIC_REGRESSORS = {
+     StaticLinearRegressor.LINEAR: sklearn.linear_model.LinearRegression,
+     StaticLinearRegressor.RIDGE: sklearn.linear_model.Ridge,
+ }
+
+
+ def get_regressor(
+     regressor_type: typing.Union[RegressorType, str],
+     regressor_name: typing.Union[AdaptiveLinearRegressor, StaticLinearRegressor, str],
+ ):
+     """Return the regressor class for the given type and name."""
+     if isinstance(regressor_type, str):
+         regressor_type = RegressorType(regressor_type)
+
+     if regressor_type == RegressorType.ADAPTIVE:
+         if isinstance(regressor_name, str):
+             regressor_name = AdaptiveLinearRegressor(regressor_name)
+         return ADAPTIVE_REGRESSORS[regressor_name]
+     elif regressor_type == RegressorType.STATIC:
+         if isinstance(regressor_name, str):
+             regressor_name = StaticLinearRegressor(regressor_name)
+         return STATIC_REGRESSORS[regressor_name]
+     else:
+         raise ValueError(f"Unknown regressor type: {regressor_type}")
+
+
+ @dataclass
+ class ClassifierMessage(AxisArray):
+     labels: list[str] = field(default_factory=list)
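
For reference, a short usage sketch of `get_regressor`: either the enum members or their string values are accepted, and the function returns the regressor class (not an instance), so you instantiate it yourself.

```python
from ezmsg.learn.util import AdaptiveLinearRegressor, RegressorType, get_regressor

RidgeCls = get_regressor("static", "ridge")  # -> sklearn.linear_model.Ridge
ridge = RidgeCls(alpha=1.0)

SGDCls = get_regressor(RegressorType.ADAPTIVE, AdaptiveLinearRegressor.SGD)
sgd = SGDCls()  # -> sklearn.linear_model.SGDRegressor, which supports partial_fit
```
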
ezmsg_learn-1.0.dist-info/METADATA ADDED
@@ -0,0 +1,34 @@
+ Metadata-Version: 2.4
+ Name: ezmsg-learn
+ Version: 1.0
+ Summary: ezmsg namespace package for machine learning
+ Author-email: Chadwick Boulay <chadwick.boulay@gmail.com>
+ License-Expression: MIT
+ Requires-Python: >=3.10.15
+ Requires-Dist: ezmsg-sigproc
+ Requires-Dist: river>=0.22.0
+ Requires-Dist: scikit-learn>=1.6.0
+ Requires-Dist: torch>=2.6.0
+ Description-Content-Type: text/markdown
+
+ # ezmsg-learn
+
+ This repository contains a Python package with modules for machine learning (ML)-related processing in the [`ezmsg`](https://www.ezmsg.org) framework. ezmsg is intended primarily for processing unbounded streaming signals, and so are the modules in this repo.
+
+ > If you are only interested in offline analysis, without concern for reproducibility in online applications, then you should probably look elsewhere.
+
+ Processing units include dimensionality reduction, linear regression, and classification; these can be initialized with known weights or adapted on the fly with incoming (labeled) data. The machine-learning code depends on `river`, `scikit-learn`, `numpy`, and `torch`.
+
+ ## Getting Started
+
+ This ezmsg namespace package is still highly experimental and under active development. It is not yet available on PyPI, so you will need to install it from source. The easiest way to do this is to use `pip` to install the package directly from GitHub:
+
+ ```bash
+ pip install git+ssh://git@github.com/ezmsg-org/ezmsg-learn
+ ```
+
+ Note that this package depends on a specific version of `ezmsg-sigproc` (specifically, the [70-use-protocols-for-axisarray-transformers](https://github.com/ezmsg-org/ezmsg-sigproc/tree/70-use-protocols-for-axisarray-transformers) branch) that has yet to be merged and released. This may conflict with your project's own dependency on ezmsg-sigproc. However, this branch should be backwards compatible with the main branch, so you can point your project's ezmsg-sigproc dependency at it instead, e.g.:
+
+ ```bash
+ pip install git+ssh://git@github.com/ezmsg-org/ezmsg-sigproc@70-use-protocols-for-axisarray-transformers
+ ```
ezmsg_learn-1.0.dist-info/RECORD ADDED
@@ -0,0 +1,36 @@
+ ezmsg/learn/__init__.py,sha256=9vTW4C2EQCHgDAo8gIlGNDfOxcQpNGV3Cct9-HsBJKY,57
+ ezmsg/learn/__version__.py,sha256=T-YAefOAMONzdzJN9AfYa3q6PjJ-HRflYoFg45W1xFU,699
+ ezmsg/learn/util.py,sha256=-WZ3k0sWSIJ1Z9aNiNFFYGa1-8oSbwQc-wI2i86w_C4,2014
+ ezmsg/learn/dim_reduce/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ ezmsg/learn/dim_reduce/adaptive_decomp.py,sha256=HmjantyGCIt7gP1d0iJ9mfoEtvR3-FjVrpv9JBT6mdc,9332
+ ezmsg/learn/dim_reduce/incremental_decomp.py,sha256=FRx0Rhn3q8yHe64e4jaHSwWNVJ9eJT1ltuZWZc-C1R8,6830
+ ezmsg/learn/linear_model/__init__.py,sha256=7_bcxc40W6UN2IfnJfjuVHe5mZ0BSPdHLKqCXpHPMwQ,78
+ ezmsg/learn/linear_model/adaptive_linear_regressor.py,sha256=zfbwjTaBGdlBzzDBU6Nu2nuUjmhFVYs7Co84VWyRqIE,332
+ ezmsg/learn/linear_model/cca.py,sha256=H-NnK9QH5vI0OnGezf674lV-X15xGzgar6d6G-67fZU,57
+ ezmsg/learn/linear_model/linear_regressor.py,sha256=qTkIV2FRMC-Lhg7RTxmX1pWWsov84gV-47x-YSdNGdM,211
+ ezmsg/learn/linear_model/sgd.py,sha256=6glInxmhapMVlBSZZB01w3vCqkNyGYOHoV8nfzrYCNI,138
+ ezmsg/learn/linear_model/slda.py,sha256=MmmDbfm5y-8XpLThgJlHVV3f2kqZ-mrpnNkkIUHrwWg,151
+ ezmsg/learn/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ ezmsg/learn/model/cca.py,sha256=wf6vsfGHWiPt4FgZEG6veudKi1-GJDVp0bxaabxjJvE,4262
+ ezmsg/learn/model/mlp.py,sha256=bMm4JX53sRdCLgMeh8gAvJnd2qruiTsyuv4MiwLpvCU,5574
+ ezmsg/learn/model/mlp_old.py,sha256=ODpeoU-6DFG_yZLF9jdBOeRsw1ptyMyWCKM-WGJBeRs,2140
+ ezmsg/learn/model/refit_kalman.py,sha256=PgjpMpIWkwHX7baSRLcT5loBA6h8uCf1L1spPfM4Nq8,16313
+ ezmsg/learn/model/rnn.py,sha256=VmcRG2UcxROyrY10AJ-jwompOThtuj6SYjPzHWxAohw,6654
+ ezmsg/learn/model/transformer.py,sha256=Vfo_CATOmoNtaqkE4BRWJy6QcX-1J_Pz3rMG2ErQSLM,7535
+ ezmsg/learn/nlin_model/__init__.py,sha256=Ap62_lD7Dc3D-15ebhhBzT23PMrdyiV8fWV6eWy6wnE,101
+ ezmsg/learn/nlin_model/mlp.py,sha256=zWTtI1JBP4KDcZCAqecHNn4Y59egEE9Bg2vi8-9la7k,165
+ ezmsg/learn/process/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ ezmsg/learn/process/adaptive_linear_regressor.py,sha256=0v6xRgWQWub0KlxOlb8ktYuohcwm4gvjHwOtv1qnNHg,5485
+ ezmsg/learn/process/base.py,sha256=MhJstPfoTBNqLMDBAZqwbKueYyr9HJhvLzeU3k-Wl8E,7023
+ ezmsg/learn/process/linear_regressor.py,sha256=2uKxvMraWITU0xZwbbI3O_mDD4cysl8RzHgl2XrPOPA,3119
+ ezmsg/learn/process/mlp_old.py,sha256=b86ee-l0RxUSwCuTGe8MdAPbyJLOToIehAKKvf7bais,6958
+ ezmsg/learn/process/refit_kalman.py,sha256=PWva0m33cPTjS8GX1aIK1tT5rgEHRub6Pp7GIgIhxsc,15220
+ ezmsg/learn/process/rnn.py,sha256=tF3ejKqCResqBBWmSZoa18KpgB36KHhkYrTohWbs8w4,9862
+ ezmsg/learn/process/sgd.py,sha256=DTg1MIA9M_K0qNTWF0oHhnMnn11jyL_MjvdAYKg8Gwk,4894
+ ezmsg/learn/process/sklearn.py,sha256=5nyMoJffqpLin86h-tZicVIZUmKYxCt1xT2mr5vGkc4,9988
+ ezmsg/learn/process/slda.py,sha256=M-zEiySPM4ovn0Os-ZaLRT8tyTTropsxoOj-veajSYg,4417
+ ezmsg/learn/process/torch.py,sha256=JMHSoVbuRCpYdg7JuCOeauhQkDHmnjiZ-XfuHp2TZjw,13353
+ ezmsg/learn/process/transformer.py,sha256=yhSDrREQy2PaiLVUzhwDppzg9g-NOox1gxatBMxj3BE,7872
+ ezmsg_learn-1.0.dist-info/METADATA,sha256=1qQqXG-QgCb9yXuqg6gQi-re5mmPt6gh98RbwfJlJ94,2012
+ ezmsg_learn-1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ ezmsg_learn-1.0.dist-info/RECORD,,
ezmsg_learn-1.0.dist-info/WHEEL ADDED
@@ -0,0 +1,4 @@
+ Wheel-Version: 1.0
+ Generator: hatchling 1.27.0
+ Root-Is-Purelib: true
+ Tag: py3-none-any