ezmsg-learn 1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,9 +2,8 @@ import typing
2
2
 
3
3
  import ezmsg.core as ez
4
4
  import torch
5
- from ezmsg.sigproc.base import BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit
6
- from ezmsg.sigproc.sampler import SampleMessage
7
- from ezmsg.sigproc.util.profile import profile_subpub
5
+ from ezmsg.baseproc import BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit
6
+ from ezmsg.baseproc.util.profile import profile_subpub
8
7
  from ezmsg.util.messages.axisarray import AxisArray
9
8
  from ezmsg.util.messages.util import replace
10
9
 
@@ -51,9 +50,7 @@ class TransformerState(TorchModelState):
51
50
 
52
51
 
53
52
  class TransformerProcessor(
54
- BaseAdaptiveTransformer[
55
- TransformerSettings, AxisArray, AxisArray, TransformerState
56
- ],
53
+ BaseAdaptiveTransformer[TransformerSettings, AxisArray, AxisArray, TransformerState],
57
54
  TorchProcessorMixin,
58
55
  ModelInitMixin,
59
56
  ):
@@ -76,7 +73,8 @@ class TransformerProcessor(
76
73
  and self.settings.autoregressive_head not in self._state.chan_ax
77
74
  ):
78
75
  raise ValueError(
79
- f"Autoregressive head '{self.settings.autoregressive_head}' not found in target dictionary keys: {list(self._state.chan_ax.keys())}"
76
+ f"Autoregressive head '{self.settings.autoregressive_head}' not found in target"
77
+ f"dictionary keys: {list(self._state.chan_ax.keys())}"
80
78
  )
81
79
  self._state.ar_head = (
82
80
  self.settings.autoregressive_head
@@ -101,15 +99,11 @@ class TransformerProcessor(
101
99
  if self._state.tgt_cache is None:
102
100
  self._state.tgt_cache = pred[:, -1:, :]
103
101
  else:
104
- self._state.tgt_cache = torch.cat(
105
- [self._state.tgt_cache, pred[:, -1:, :]], dim=1
106
- )
102
+ self._state.tgt_cache = torch.cat([self._state.tgt_cache, pred[:, -1:, :]], dim=1)
107
103
  if self.settings.max_cache_len is not None:
108
104
  if self._state.tgt_cache.shape[1] > self.settings.max_cache_len:
109
105
  # Trim the cache to the maximum length
110
- self._state.tgt_cache = self._state.tgt_cache[
111
- :, -self.settings.max_cache_len :, :
112
- ]
106
+ self._state.tgt_cache = self._state.tgt_cache[:, -self.settings.max_cache_len :, :]
113
107
 
114
108
  if isinstance(y_pred, dict):
115
109
  return [
@@ -130,13 +124,13 @@ class TransformerProcessor(
130
124
  )
131
125
  ]
132
126
 
133
- def partial_fit(self, message: SampleMessage) -> None:
127
+ def partial_fit(self, message: AxisArray) -> None:
134
128
  self._state.model.train()
135
129
 
136
- X = self._to_tensor(message.sample.data)
130
+ X = self._to_tensor(message.data)
137
131
  X, batched = self._ensure_batched(X)
138
132
 
139
- y_targ = message.trigger.value
133
+ y_targ = message.attrs["trigger"].value
140
134
  if not isinstance(y_targ, dict):
141
135
  y_targ = {"output": y_targ}
142
136
  y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
@@ -186,9 +180,7 @@ class TransformerProcessor(
186
180
  for key in y_targ.keys():
187
181
  loss_fn = loss_fns.get(key)
188
182
  if loss_fn is None:
189
- raise ValueError(
190
- f"Loss function for key '{key}' is not defined in settings."
191
- )
183
+ raise ValueError(f"Loss function for key '{key}' is not defined in settings.")
192
184
  loss = loss_fn(y_pred[key], y_targ[key])
193
185
  weight = weights.get(key, 1.0)
194
186
  losses.append(loss * weight)
ezmsg/learn/util.py CHANGED
@@ -1,10 +1,11 @@
1
- from enum import Enum
2
- from dataclasses import dataclass, field
3
1
  import typing
2
+ from dataclasses import dataclass, field
3
+ from enum import Enum
4
4
 
5
- from ezmsg.util.messages.axisarray import AxisArray
6
- import sklearn.linear_model
7
5
  import river.linear_model
6
+ import sklearn.linear_model
7
+ from ezmsg.util.messages.axisarray import AxisArray
8
+
8
9
  # from sklearn.neural_network import MLPClassifier
9
10
 
10
11
 
@@ -1,11 +1,13 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ezmsg-learn
3
- Version: 1.0
3
+ Version: 1.2.0
4
4
  Summary: ezmsg namespace package for machine learning
5
5
  Author-email: Chadwick Boulay <chadwick.boulay@gmail.com>
6
6
  License-Expression: MIT
7
+ License-File: LICENSE
7
8
  Requires-Python: >=3.10.15
8
- Requires-Dist: ezmsg-sigproc
9
+ Requires-Dist: ezmsg-baseproc>=1.3.0
10
+ Requires-Dist: ezmsg-sigproc>=2.15.0
9
11
  Requires-Dist: river>=0.22.0
10
12
  Requires-Dist: scikit-learn>=1.6.0
11
13
  Requires-Dist: torch>=2.6.0
@@ -24,11 +26,5 @@ Processing units include dimensionality reduction, linear regression, and classi
24
26
  This ezmsg namespace package is still highly experimental and under active development. It is not yet available on PyPI, so you will need to install it from source. The easiest way to do this is to use the `pip` command to install the package directly from GitHub:
25
27
 
26
28
  ```bash
27
- pip install git+ssh://git@github.com/ezmsg-org/ezmsg-learn
28
- ```
29
-
30
- Note that this package depends on a specific version of `ezmsg-sigproc` (specifically, [this branch]("70-use-protocols-for-axisarray-transformers")) that has yet to be merged and released. This may conflict with your project's separate dependency on ezmsg-sigproc. However, this specific version of ezmsg-sigproc should be backwards compatible with its main branch, so in your project you can modify the dependency on ezmsg-sigproc to point to the new branch. e.g.,
31
-
32
- ```bash
33
- pip install git+ssh://git@github.com/ezmsg-org/ezmsg-sigproc@70-use-protocols-for-axisarray-transformers
29
+ pip install git+https://github.com/ezmsg-org/ezmsg-learn
34
30
  ```
@@ -0,0 +1,38 @@
1
+ ezmsg/learn/__init__.py,sha256=9vTW4C2EQCHgDAo8gIlGNDfOxcQpNGV3Cct9-HsBJKY,57
2
+ ezmsg/learn/__version__.py,sha256=-uLONazCO1SzFfcY-K6A1keL--LIVfTYccGX6ciADac,704
3
+ ezmsg/learn/util.py,sha256=cJPu07aWnsh_cIUMuVb0byXqm1CvLv9QO925U1t6oYs,2015
4
+ ezmsg/learn/dim_reduce/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
+ ezmsg/learn/dim_reduce/adaptive_decomp.py,sha256=SmxBuTATuo3DBrGkHiwC2BTg0rmiAF7kRMX_J8iObLM,9240
6
+ ezmsg/learn/dim_reduce/incremental_decomp.py,sha256=LxBbFf16cEW6YvZ73I6tnbiamJxRfGHCHLXnH_tIC88,6754
7
+ ezmsg/learn/linear_model/__init__.py,sha256=7_bcxc40W6UN2IfnJfjuVHe5mZ0BSPdHLKqCXpHPMwQ,78
8
+ ezmsg/learn/linear_model/adaptive_linear_regressor.py,sha256=AV0MaJ4jASfRMRDAaiKNY0I6a88si1JNhFPuyVpP__k,488
9
+ ezmsg/learn/linear_model/cca.py,sha256=H-NnK9QH5vI0OnGezf674lV-X15xGzgar6d6G-67fZU,57
10
+ ezmsg/learn/linear_model/linear_regressor.py,sha256=3Gi6DfvXIPf1J6wdeyO9nvm5AmSfK2J2CMMDDo2QLRg,297
11
+ ezmsg/learn/linear_model/sgd.py,sha256=xdef6lQ25m9kyqNueDBNtx7haNwrg-ljuyUPfimyxjA,198
12
+ ezmsg/learn/linear_model/slda.py,sha256=qJocbqbB8lgvl62VP0DCXQlF1jDc87CTBBFh0B2Bil8,244
13
+ ezmsg/learn/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
+ ezmsg/learn/model/cca.py,sha256=wf6vsfGHWiPt4FgZEG6veudKi1-GJDVp0bxaabxjJvE,4262
15
+ ezmsg/learn/model/mlp.py,sha256=zcuw2mNt1LGZxN0XOAxdG52JFWcscjO-JiIFfPNACnc,5476
16
+ ezmsg/learn/model/mlp_old.py,sha256=ODpeoU-6DFG_yZLF9jdBOeRsw1ptyMyWCKM-WGJBeRs,2140
17
+ ezmsg/learn/model/refit_kalman.py,sha256=bi3zoCnhJRwUbqrb4WpCNVSNLpeVUeMQChqcfpFXHJU,15747
18
+ ezmsg/learn/model/rnn.py,sha256=VmcRG2UcxROyrY10AJ-jwompOThtuj6SYjPzHWxAohw,6654
19
+ ezmsg/learn/model/transformer.py,sha256=Vfo_CATOmoNtaqkE4BRWJy6QcX-1J_Pz3rMG2ErQSLM,7535
20
+ ezmsg/learn/nlin_model/__init__.py,sha256=Ap62_lD7Dc3D-15ebhhBzT23PMrdyiV8fWV6eWy6wnE,101
21
+ ezmsg/learn/nlin_model/mlp.py,sha256=l3KNCS7w9KlMiNVRxooqBb7Agl8A4OwnbZiJkWD2lJU,233
22
+ ezmsg/learn/process/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
+ ezmsg/learn/process/adaptive_linear_regressor.py,sha256=1xq4KN_UbTdrsbZ4PqSRs47aAVWMRLZNWJ5ZaRDtJ78,5179
24
+ ezmsg/learn/process/base.py,sha256=1gfCMRh5JGRrKU-Obeh2abH1UW0dxHuc7kgxpBln76A,6655
25
+ ezmsg/learn/process/linear_regressor.py,sha256=8gacrdJJ78U3gBgC62qxvDpZOamZW0kn9NJoRgoLJHc,3032
26
+ ezmsg/learn/process/mlp_old.py,sha256=a2iS45ZlIfonPanerpKROSugYKPghas4JTErO7TM4RU,6700
27
+ ezmsg/learn/process/refit_kalman.py,sha256=wrVHb2ZHIDLT0LVlZsNfsD5rdl8dBECxZnqvxDiJuyE,15065
28
+ ezmsg/learn/process/rnn.py,sha256=HDc4PgZLTEtUVWGFUDqX9qA31Fid37Qj4eYR7jX7aUU,9418
29
+ ezmsg/learn/process/sgd.py,sha256=vkPWDiBKXNhF0EjHbpJt9SskSKw4TRAbhE7uHOYh1Ik,3704
30
+ ezmsg/learn/process/sklearn.py,sha256=Smpo2YITfjs8thVjSQ7kkqRjjSMa2FOL2ArW-fKMn38,9451
31
+ ezmsg/learn/process/slda.py,sha256=BtVKBYkggvlC8_rLkrWfqthOFkKKFv-r9CdX_jWRn2o,4315
32
+ ezmsg/learn/process/ssr.py,sha256=XiHcUSJ3tY3_HvATcQp427KhQIAZ9aVNvewCp-ODO8c,13939
33
+ ezmsg/learn/process/torch.py,sha256=HnFaUhimXM_ki2clcYDPN8wmwA1fpz_HEdcd_XhM4YM,13076
34
+ ezmsg/learn/process/transformer.py,sha256=bnVi87xv95O3plNJVBN5W6pukuS0lKkTvJ7lqkpED9o,7710
35
+ ezmsg_learn-1.2.0.dist-info/METADATA,sha256=v-okrlHbna_orkyBQ7xzzL1jpFgxH-AdhLf4FZ8GNhc,1494
36
+ ezmsg_learn-1.2.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
37
+ ezmsg_learn-1.2.0.dist-info/licenses/LICENSE,sha256=BDD8rfac1Ur7mp0_3izEdr6fHgSA3Or6U1Kb0ZAWsow,1066
38
+ ezmsg_learn-1.2.0.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: hatchling 1.27.0
2
+ Generator: hatchling 1.28.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 ezmsg-org
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -1,36 +0,0 @@
1
- ezmsg/learn/__init__.py,sha256=9vTW4C2EQCHgDAo8gIlGNDfOxcQpNGV3Cct9-HsBJKY,57
2
- ezmsg/learn/__version__.py,sha256=T-YAefOAMONzdzJN9AfYa3q6PjJ-HRflYoFg45W1xFU,699
3
- ezmsg/learn/util.py,sha256=-WZ3k0sWSIJ1Z9aNiNFFYGa1-8oSbwQc-wI2i86w_C4,2014
4
- ezmsg/learn/dim_reduce/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- ezmsg/learn/dim_reduce/adaptive_decomp.py,sha256=HmjantyGCIt7gP1d0iJ9mfoEtvR3-FjVrpv9JBT6mdc,9332
6
- ezmsg/learn/dim_reduce/incremental_decomp.py,sha256=FRx0Rhn3q8yHe64e4jaHSwWNVJ9eJT1ltuZWZc-C1R8,6830
7
- ezmsg/learn/linear_model/__init__.py,sha256=7_bcxc40W6UN2IfnJfjuVHe5mZ0BSPdHLKqCXpHPMwQ,78
8
- ezmsg/learn/linear_model/adaptive_linear_regressor.py,sha256=zfbwjTaBGdlBzzDBU6Nu2nuUjmhFVYs7Co84VWyRqIE,332
9
- ezmsg/learn/linear_model/cca.py,sha256=H-NnK9QH5vI0OnGezf674lV-X15xGzgar6d6G-67fZU,57
10
- ezmsg/learn/linear_model/linear_regressor.py,sha256=qTkIV2FRMC-Lhg7RTxmX1pWWsov84gV-47x-YSdNGdM,211
11
- ezmsg/learn/linear_model/sgd.py,sha256=6glInxmhapMVlBSZZB01w3vCqkNyGYOHoV8nfzrYCNI,138
12
- ezmsg/learn/linear_model/slda.py,sha256=MmmDbfm5y-8XpLThgJlHVV3f2kqZ-mrpnNkkIUHrwWg,151
13
- ezmsg/learn/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
- ezmsg/learn/model/cca.py,sha256=wf6vsfGHWiPt4FgZEG6veudKi1-GJDVp0bxaabxjJvE,4262
15
- ezmsg/learn/model/mlp.py,sha256=bMm4JX53sRdCLgMeh8gAvJnd2qruiTsyuv4MiwLpvCU,5574
16
- ezmsg/learn/model/mlp_old.py,sha256=ODpeoU-6DFG_yZLF9jdBOeRsw1ptyMyWCKM-WGJBeRs,2140
17
- ezmsg/learn/model/refit_kalman.py,sha256=PgjpMpIWkwHX7baSRLcT5loBA6h8uCf1L1spPfM4Nq8,16313
18
- ezmsg/learn/model/rnn.py,sha256=VmcRG2UcxROyrY10AJ-jwompOThtuj6SYjPzHWxAohw,6654
19
- ezmsg/learn/model/transformer.py,sha256=Vfo_CATOmoNtaqkE4BRWJy6QcX-1J_Pz3rMG2ErQSLM,7535
20
- ezmsg/learn/nlin_model/__init__.py,sha256=Ap62_lD7Dc3D-15ebhhBzT23PMrdyiV8fWV6eWy6wnE,101
21
- ezmsg/learn/nlin_model/mlp.py,sha256=zWTtI1JBP4KDcZCAqecHNn4Y59egEE9Bg2vi8-9la7k,165
22
- ezmsg/learn/process/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
- ezmsg/learn/process/adaptive_linear_regressor.py,sha256=0v6xRgWQWub0KlxOlb8ktYuohcwm4gvjHwOtv1qnNHg,5485
24
- ezmsg/learn/process/base.py,sha256=MhJstPfoTBNqLMDBAZqwbKueYyr9HJhvLzeU3k-Wl8E,7023
25
- ezmsg/learn/process/linear_regressor.py,sha256=2uKxvMraWITU0xZwbbI3O_mDD4cysl8RzHgl2XrPOPA,3119
26
- ezmsg/learn/process/mlp_old.py,sha256=b86ee-l0RxUSwCuTGe8MdAPbyJLOToIehAKKvf7bais,6958
27
- ezmsg/learn/process/refit_kalman.py,sha256=PWva0m33cPTjS8GX1aIK1tT5rgEHRub6Pp7GIgIhxsc,15220
28
- ezmsg/learn/process/rnn.py,sha256=tF3ejKqCResqBBWmSZoa18KpgB36KHhkYrTohWbs8w4,9862
29
- ezmsg/learn/process/sgd.py,sha256=DTg1MIA9M_K0qNTWF0oHhnMnn11jyL_MjvdAYKg8Gwk,4894
30
- ezmsg/learn/process/sklearn.py,sha256=5nyMoJffqpLin86h-tZicVIZUmKYxCt1xT2mr5vGkc4,9988
31
- ezmsg/learn/process/slda.py,sha256=M-zEiySPM4ovn0Os-ZaLRT8tyTTropsxoOj-veajSYg,4417
32
- ezmsg/learn/process/torch.py,sha256=JMHSoVbuRCpYdg7JuCOeauhQkDHmnjiZ-XfuHp2TZjw,13353
33
- ezmsg/learn/process/transformer.py,sha256=yhSDrREQy2PaiLVUzhwDppzg9g-NOox1gxatBMxj3BE,7872
34
- ezmsg_learn-1.0.dist-info/METADATA,sha256=1qQqXG-QgCb9yXuqg6gQi-re5mmPt6gh98RbwfJlJ94,2012
35
- ezmsg_learn-1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
36
- ezmsg_learn-1.0.dist-info/RECORD,,