ezmsg-learn 1.1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '1.1.0'
32
- __version_tuple__ = version_tuple = (1, 1, 0)
31
+ __version__ = version = '1.2.0'
32
+ __version_tuple__ = version_tuple = (1, 2, 0)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -11,7 +11,6 @@ from ezmsg.baseproc import (
11
11
  BaseAdaptiveTransformerUnit,
12
12
  processor_state,
13
13
  )
14
- from ezmsg.sigproc.sampler import SampleMessage
15
14
  from ezmsg.util.messages.axisarray import AxisArray, replace
16
15
 
17
16
  from ..util import AdaptiveLinearRegressor, RegressorType, get_regressor
@@ -78,30 +77,30 @@ class AdaptiveLinearRegressorTransformer(
78
77
  # .template is updated in partial_fit
79
78
  pass
80
79
 
81
- def partial_fit(self, message: SampleMessage) -> None:
82
- if np.any(np.isnan(message.sample.data)):
80
+ def partial_fit(self, message: AxisArray) -> None:
81
+ if np.any(np.isnan(message.data)):
83
82
  return
84
83
 
85
84
  if self.settings.model_type in [
86
85
  AdaptiveLinearRegressor.LINEAR,
87
86
  AdaptiveLinearRegressor.LOGISTIC,
88
87
  ]:
89
- x = pd.DataFrame.from_dict({k: v for k, v in zip(message.sample.axes["ch"].data, message.sample.data.T)})
88
+ x = pd.DataFrame.from_dict({k: v for k, v in zip(message.axes["ch"].data, message.data.T)})
90
89
  y = pd.Series(
91
- data=message.trigger.value.data[:, 0],
92
- name=message.trigger.value.axes["ch"].data[0],
90
+ data=message.attrs["trigger"].value.data[:, 0],
91
+ name=message.attrs["trigger"].value.axes["ch"].data[0],
93
92
  )
94
93
  self.state.model.learn_many(x, y)
95
94
  else:
96
- X = message.sample.data
97
- if message.sample.get_axis_idx("time") != 0:
98
- X = np.moveaxis(X, message.sample.get_axis_idx("time"), 0)
99
- self.state.model.partial_fit(X, message.trigger.value.data)
95
+ X = message.data
96
+ if message.get_axis_idx("time") != 0:
97
+ X = np.moveaxis(X, message.get_axis_idx("time"), 0)
98
+ self.state.model.partial_fit(X, message.attrs["trigger"].value.data)
100
99
 
101
100
  self.state.template = replace(
102
- message.trigger.value,
103
- data=np.empty_like(message.trigger.value.data),
104
- key=message.trigger.value.key + "_pred",
101
+ message.attrs["trigger"].value,
102
+ data=np.empty_like(message.attrs["trigger"].value.data),
103
+ key=message.attrs["trigger"].value.key + "_pred",
105
104
  )
106
105
 
107
106
  def _process(self, message: AxisArray) -> AxisArray | None:
@@ -7,7 +7,6 @@ from ezmsg.baseproc import (
7
7
  BaseAdaptiveTransformerUnit,
8
8
  processor_state,
9
9
  )
10
- from ezmsg.sigproc.sampler import SampleMessage
11
10
  from ezmsg.util.messages.axisarray import AxisArray, replace
12
11
  from sklearn.linear_model._base import LinearModel
13
12
 
@@ -53,18 +52,18 @@ class LinearRegressorTransformer(
53
52
  # .model and .template are initialized in __init__
54
53
  pass
55
54
 
56
- def partial_fit(self, message: SampleMessage) -> None:
57
- if np.any(np.isnan(message.sample.data)):
55
+ def partial_fit(self, message: AxisArray) -> None:
56
+ if np.any(np.isnan(message.data)):
58
57
  return
59
58
 
60
- X = message.sample.data
61
- y = message.trigger.value.data
59
+ X = message.data
60
+ y = message.attrs["trigger"].value.data
62
61
  # TODO: Resample should provide identical durations.
63
62
  self.state.model = self.state.model.fit(X[: y.shape[0]], y[: X.shape[0]])
64
63
  self.state.template = replace(
65
- message.trigger.value,
64
+ message.attrs["trigger"].value,
66
65
  data=np.array([[]]),
67
- key=message.trigger.value.key + "_pred",
66
+ key=message.attrs["trigger"].value.key + "_pred",
68
67
  )
69
68
 
70
69
  def _process(self, message: AxisArray) -> AxisArray:
@@ -9,7 +9,6 @@ from ezmsg.baseproc import (
9
9
  BaseAdaptiveTransformerUnit,
10
10
  processor_state,
11
11
  )
12
- from ezmsg.sigproc.sampler import SampleMessage
13
12
  from ezmsg.util.messages.axisarray import AxisArray
14
13
  from ezmsg.util.messages.util import replace
15
14
 
@@ -134,14 +133,14 @@ class MLPProcessor(BaseAdaptiveTransformer[MLPSettings, AxisArray, AxisArray, ML
134
133
  dtype = torch.float32 if self.settings.single_precision else torch.float64
135
134
  return torch.tensor(data, dtype=dtype, device=self._state.device)
136
135
 
137
- def partial_fit(self, message: SampleMessage) -> None:
136
+ def partial_fit(self, message: AxisArray) -> None:
138
137
  self._state.model.train()
139
138
 
140
139
  # TODO: loss_fn should be determined by setting
141
140
  loss_fn = torch.nn.functional.mse_loss
142
141
 
143
- X = self._to_tensor(message.sample.data)
144
- y_targ = self._to_tensor(message.trigger.value)
142
+ X = self._to_tensor(message.data)
143
+ y_targ = self._to_tensor(message.attrs["trigger"].value)
145
144
 
146
145
  with torch.set_grad_enabled(True):
147
146
  self._state.model.train()
@@ -8,7 +8,6 @@ from ezmsg.baseproc import (
8
8
  BaseAdaptiveTransformerUnit,
9
9
  processor_state,
10
10
  )
11
- from ezmsg.sigproc.sampler import SampleMessage
12
11
  from ezmsg.util.messages.axisarray import AxisArray
13
12
  from ezmsg.util.messages.util import replace
14
13
 
@@ -284,22 +283,22 @@ class RefitKalmanFilterProcessor(
284
283
  key=f"{message.key}_filtered" if hasattr(message, "key") else "filtered",
285
284
  )
286
285
 
287
- def partial_fit(self, message: SampleMessage) -> None:
286
+ def partial_fit(self, message: AxisArray) -> None:
288
287
  """
289
288
  Perform refitting using externally provided data.
290
289
 
291
- Expects message.sample.data (neural input) and message.trigger.value as a dict with:
290
+ Expects message.data (neural input) and message.attrs["trigger"].value as a dict with:
292
291
  - Y_state: (n_samples, n_states) array
293
292
  - intention_velocity_indices: Optional[int]
294
293
  - target_positions: Optional[np.ndarray]
295
294
  - cursor_positions: Optional[np.ndarray]
296
295
  - hold_flags: Optional[list[bool]]
297
296
  """
298
- if not hasattr(message, "sample") or not hasattr(message, "trigger"):
297
+ if "trigger" not in message.attrs:
299
298
  raise ValueError("Invalid message format for partial_fit.")
300
299
 
301
- X = np.array(message.sample.data)
302
- values = message.trigger.value
300
+ X = np.array(message.data)
301
+ values = message.attrs["trigger"].value
303
302
 
304
303
  if not isinstance(values, dict) or "Y_state" not in values:
305
304
  raise ValueError("partial_fit expects trigger.value to include at least 'Y_state'.")
@@ -5,7 +5,6 @@ import numpy as np
5
5
  import torch
6
6
  from ezmsg.baseproc import BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit
7
7
  from ezmsg.baseproc.util.profile import profile_subpub
8
- from ezmsg.sigproc.sampler import SampleMessage
9
8
  from ezmsg.util.messages.axisarray import AxisArray
10
9
  from ezmsg.util.messages.util import replace
11
10
 
@@ -184,18 +183,18 @@ class RNNProcessor(
184
183
  if self._state.scheduler is not None:
185
184
  self._state.scheduler.step()
186
185
 
187
- def partial_fit(self, message: SampleMessage) -> None:
186
+ def partial_fit(self, message: AxisArray) -> None:
188
187
  self._state.model.train()
189
188
 
190
- X = self._to_tensor(message.sample.data)
189
+ X = self._to_tensor(message.data)
191
190
 
192
191
  # Add batch dimension if missing
193
192
  X, batched = self._ensure_batched(X)
194
193
 
195
194
  batch_size = X.shape[0]
196
- preserve_state = self._maybe_reset_state(message.sample, batch_size)
195
+ preserve_state = self._maybe_reset_state(message, batch_size)
197
196
 
198
- y_targ = message.trigger.value
197
+ y_targ = message.attrs["trigger"].value
199
198
  if not isinstance(y_targ, dict):
200
199
  y_targ = {"output": y_targ}
201
200
  y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
@@ -5,7 +5,6 @@ import numpy as np
5
5
  from ezmsg.baseproc import (
6
6
  BaseAdaptiveTransformer,
7
7
  BaseAdaptiveTransformerUnit,
8
- SampleMessage,
9
8
  processor_state,
10
9
  )
11
10
  from ezmsg.util.messages.axisarray import AxisArray
@@ -87,23 +86,23 @@ class SGDDecoderTransformer(BaseAdaptiveTransformer[SGDDecoderSettings, AxisArra
87
86
  key=message.key,
88
87
  )
89
88
 
90
- def partial_fit(self, message: SampleMessage) -> None:
89
+ def partial_fit(self, message: AxisArray) -> None:
91
90
  if self._hash != 0:
92
- self._reset_state(message.sample)
91
+ self._reset_state(message)
93
92
  self._hash = 0
94
93
 
95
- if np.any(np.isnan(message.sample.data)):
94
+ if np.any(np.isnan(message.data)):
96
95
  return
97
- train_sample = message.sample.data.reshape(1, -1)
96
+ train_sample = message.data.reshape(1, -1)
98
97
  if self._state.b_first_train:
99
98
  self._state.model.partial_fit(
100
99
  train_sample,
101
- [message.trigger.value],
100
+ [message.attrs["trigger"].value],
102
101
  classes=list(self.settings.label_weights.keys()),
103
102
  )
104
103
  self._state.b_first_train = False
105
104
  else:
106
- self._state.model.partial_fit(train_sample, [message.trigger.value])
105
+ self._state.model.partial_fit(train_sample, [message.attrs["trigger"].value])
107
106
 
108
107
 
109
108
  class SGDDecoder(
@@ -10,7 +10,6 @@ from ezmsg.baseproc import (
10
10
  BaseAdaptiveTransformerUnit,
11
11
  processor_state,
12
12
  )
13
- from ezmsg.sigproc.sampler import SampleMessage
14
13
  from ezmsg.util.messages.axisarray import AxisArray
15
14
  from ezmsg.util.messages.util import replace
16
15
 
@@ -116,25 +115,25 @@ class SklearnModelProcessor(BaseAdaptiveTransformer[SklearnModelSettings, AxisAr
116
115
  # No checkpoint, initialize from scratch
117
116
  self._init_model()
118
117
 
119
- def partial_fit(self, message: SampleMessage) -> None:
120
- X = message.sample.data
121
- y = message.trigger.value
118
+ def partial_fit(self, message: AxisArray) -> None:
119
+ X = message.data
120
+ y = message.attrs["trigger"].value
122
121
  if self._state.model is None:
123
- self._reset_state(message.sample)
122
+ self._reset_state(message)
124
123
  if hasattr(self._state.model, "partial_fit"):
125
124
  kwargs = {}
126
125
  if self.settings.partial_fit_classes is not None:
127
126
  kwargs["classes"] = self.settings.partial_fit_classes
128
127
  self._state.model.partial_fit(X, y, **kwargs)
129
128
  elif hasattr(self._state.model, "learn_many"):
130
- df_X = pd.DataFrame({k: v for k, v in zip(message.sample.axes["ch"].data, message.sample.data.T)})
129
+ df_X = pd.DataFrame({k: v for k, v in zip(message.axes["ch"].data, message.data.T)})
131
130
  name = (
132
- message.trigger.value.axes["ch"].data[0]
133
- if hasattr(message.trigger.value, "axes") and "ch" in message.trigger.value.axes
131
+ message.attrs["trigger"].value.axes["ch"].data[0]
132
+ if hasattr(message.attrs["trigger"].value, "axes") and "ch" in message.attrs["trigger"].value.axes
134
133
  else "target"
135
134
  )
136
135
  ser_y = pd.Series(
137
- data=np.asarray(message.trigger.value.data).flatten(),
136
+ data=np.asarray(message.attrs["trigger"].value.data).flatten(),
138
137
  name=name,
139
138
  )
140
139
  self._state.model.learn_many(df_X, ser_y)
@@ -12,7 +12,6 @@ from ezmsg.baseproc import (
12
12
  processor_state,
13
13
  )
14
14
  from ezmsg.baseproc.util.profile import profile_subpub
15
- from ezmsg.sigproc.sampler import SampleMessage
16
15
  from ezmsg.util.messages.axisarray import AxisArray
17
16
  from ezmsg.util.messages.util import replace
18
17
 
@@ -294,13 +293,13 @@ class TorchModelProcessor(
294
293
  def _process(self, message: AxisArray) -> list[AxisArray]:
295
294
  return self._common_process(message)
296
295
 
297
- def partial_fit(self, message: SampleMessage) -> None:
296
+ def partial_fit(self, message: AxisArray) -> None:
298
297
  self._state.model.train()
299
298
 
300
- X = self._to_tensor(message.sample.data)
299
+ X = self._to_tensor(message.data)
301
300
  X, batched = self._ensure_batched(X)
302
301
 
303
- y_targ = message.trigger.value
302
+ y_targ = message.attrs["trigger"].value
304
303
  if not isinstance(y_targ, dict):
305
304
  y_targ = {"output": y_targ}
306
305
  y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
@@ -4,7 +4,6 @@ import ezmsg.core as ez
4
4
  import torch
5
5
  from ezmsg.baseproc import BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit
6
6
  from ezmsg.baseproc.util.profile import profile_subpub
7
- from ezmsg.sigproc.sampler import SampleMessage
8
7
  from ezmsg.util.messages.axisarray import AxisArray
9
8
  from ezmsg.util.messages.util import replace
10
9
 
@@ -125,13 +124,13 @@ class TransformerProcessor(
125
124
  )
126
125
  ]
127
126
 
128
- def partial_fit(self, message: SampleMessage) -> None:
127
+ def partial_fit(self, message: AxisArray) -> None:
129
128
  self._state.model.train()
130
129
 
131
- X = self._to_tensor(message.sample.data)
130
+ X = self._to_tensor(message.data)
132
131
  X, batched = self._ensure_batched(X)
133
132
 
134
- y_targ = message.trigger.value
133
+ y_targ = message.attrs["trigger"].value
135
134
  if not isinstance(y_targ, dict):
136
135
  y_targ = {"output": y_targ}
137
136
  y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
@@ -1,13 +1,13 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ezmsg-learn
3
- Version: 1.1.0
3
+ Version: 1.2.0
4
4
  Summary: ezmsg namespace package for machine learning
5
5
  Author-email: Chadwick Boulay <chadwick.boulay@gmail.com>
6
6
  License-Expression: MIT
7
7
  License-File: LICENSE
8
8
  Requires-Python: >=3.10.15
9
- Requires-Dist: ezmsg-baseproc>=1.0.2
10
- Requires-Dist: ezmsg-sigproc>=2.14.0
9
+ Requires-Dist: ezmsg-baseproc>=1.3.0
10
+ Requires-Dist: ezmsg-sigproc>=2.15.0
11
11
  Requires-Dist: river>=0.22.0
12
12
  Requires-Dist: scikit-learn>=1.6.0
13
13
  Requires-Dist: torch>=2.6.0
@@ -1,5 +1,5 @@
1
1
  ezmsg/learn/__init__.py,sha256=9vTW4C2EQCHgDAo8gIlGNDfOxcQpNGV3Cct9-HsBJKY,57
2
- ezmsg/learn/__version__.py,sha256=ePNVzJOkxR8FY5bezqKQ_fgBRbzH1G7QTaRDHvGQRAY,704
2
+ ezmsg/learn/__version__.py,sha256=-uLONazCO1SzFfcY-K6A1keL--LIVfTYccGX6ciADac,704
3
3
  ezmsg/learn/util.py,sha256=cJPu07aWnsh_cIUMuVb0byXqm1CvLv9QO925U1t6oYs,2015
4
4
  ezmsg/learn/dim_reduce/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
5
  ezmsg/learn/dim_reduce/adaptive_decomp.py,sha256=SmxBuTATuo3DBrGkHiwC2BTg0rmiAF7kRMX_J8iObLM,9240
@@ -20,19 +20,19 @@ ezmsg/learn/model/transformer.py,sha256=Vfo_CATOmoNtaqkE4BRWJy6QcX-1J_Pz3rMG2ErQ
20
20
  ezmsg/learn/nlin_model/__init__.py,sha256=Ap62_lD7Dc3D-15ebhhBzT23PMrdyiV8fWV6eWy6wnE,101
21
21
  ezmsg/learn/nlin_model/mlp.py,sha256=l3KNCS7w9KlMiNVRxooqBb7Agl8A4OwnbZiJkWD2lJU,233
22
22
  ezmsg/learn/process/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
- ezmsg/learn/process/adaptive_linear_regressor.py,sha256=-TJsMJlcfIvwflSg4KXursseFb064cX9d-R65CLomJ0,5219
23
+ ezmsg/learn/process/adaptive_linear_regressor.py,sha256=1xq4KN_UbTdrsbZ4PqSRs47aAVWMRLZNWJ5ZaRDtJ78,5179
24
24
  ezmsg/learn/process/base.py,sha256=1gfCMRh5JGRrKU-Obeh2abH1UW0dxHuc7kgxpBln76A,6655
25
- ezmsg/learn/process/linear_regressor.py,sha256=9cU25vXjVaClwBxR79M3qLi7sAKaEMAebwinITi5XYs,3071
26
- ezmsg/learn/process/mlp_old.py,sha256=5jw73VfUjHoY6kjVy22btyWeNhnYK19RpuzdXgxP324,6750
27
- ezmsg/learn/process/refit_kalman.py,sha256=nLQYWFbyP3BHrolvWbg-4HyRsaXs466nb9SU4-tMLEg,15148
28
- ezmsg/learn/process/rnn.py,sha256=5tYMe2EI71s95rTy-x-72s7tEKndcQsMvL-tf971HS8,9475
29
- ezmsg/learn/process/sgd.py,sha256=K3BgPLVszmTIeLiWi0uKq_loZCIy5XrSWKur3oRan98,3730
30
- ezmsg/learn/process/sklearn.py,sha256=Id8km2_8goHuErEtWLpFxqrrSNLpISZ6pGEPaNQ84yY,9486
25
+ ezmsg/learn/process/linear_regressor.py,sha256=8gacrdJJ78U3gBgC62qxvDpZOamZW0kn9NJoRgoLJHc,3032
26
+ ezmsg/learn/process/mlp_old.py,sha256=a2iS45ZlIfonPanerpKROSugYKPghas4JTErO7TM4RU,6700
27
+ ezmsg/learn/process/refit_kalman.py,sha256=wrVHb2ZHIDLT0LVlZsNfsD5rdl8dBECxZnqvxDiJuyE,15065
28
+ ezmsg/learn/process/rnn.py,sha256=HDc4PgZLTEtUVWGFUDqX9qA31Fid37Qj4eYR7jX7aUU,9418
29
+ ezmsg/learn/process/sgd.py,sha256=vkPWDiBKXNhF0EjHbpJt9SskSKw4TRAbhE7uHOYh1Ik,3704
30
+ ezmsg/learn/process/sklearn.py,sha256=Smpo2YITfjs8thVjSQ7kkqRjjSMa2FOL2ArW-fKMn38,9451
31
31
  ezmsg/learn/process/slda.py,sha256=BtVKBYkggvlC8_rLkrWfqthOFkKKFv-r9CdX_jWRn2o,4315
32
32
  ezmsg/learn/process/ssr.py,sha256=XiHcUSJ3tY3_HvATcQp427KhQIAZ9aVNvewCp-ODO8c,13939
33
- ezmsg/learn/process/torch.py,sha256=wvmeXmVzTPW5lEVf8bU4iu2gWg1mGuO0F6TIEEmNtP4,13126
34
- ezmsg/learn/process/transformer.py,sha256=e9XQFfpduDCUDJAwJ0ottQ9N2eZo-LC8firQlpuKheY,7760
35
- ezmsg_learn-1.1.0.dist-info/METADATA,sha256=dzHQzlVQahT_LEcZr3MZBcXBrVohxVSUJCG9F09ikHg,1494
36
- ezmsg_learn-1.1.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
37
- ezmsg_learn-1.1.0.dist-info/licenses/LICENSE,sha256=BDD8rfac1Ur7mp0_3izEdr6fHgSA3Or6U1Kb0ZAWsow,1066
38
- ezmsg_learn-1.1.0.dist-info/RECORD,,
33
+ ezmsg/learn/process/torch.py,sha256=HnFaUhimXM_ki2clcYDPN8wmwA1fpz_HEdcd_XhM4YM,13076
34
+ ezmsg/learn/process/transformer.py,sha256=bnVi87xv95O3plNJVBN5W6pukuS0lKkTvJ7lqkpED9o,7710
35
+ ezmsg_learn-1.2.0.dist-info/METADATA,sha256=v-okrlHbna_orkyBQ7xzzL1jpFgxH-AdhLf4FZ8GNhc,1494
36
+ ezmsg_learn-1.2.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
37
+ ezmsg_learn-1.2.0.dist-info/licenses/LICENSE,sha256=BDD8rfac1Ur7mp0_3izEdr6fHgSA3Or6U1Kb0ZAWsow,1066
38
+ ezmsg_learn-1.2.0.dist-info/RECORD,,