ezmsg-learn 1.1.0__tar.gz → 1.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/PKG-INFO +3 -3
  2. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/pyproject.toml +2 -2
  3. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/__version__.py +2 -2
  4. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/adaptive_linear_regressor.py +12 -13
  5. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/linear_regressor.py +6 -7
  6. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/mlp_old.py +3 -4
  7. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/refit_kalman.py +5 -6
  8. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/rnn.py +4 -5
  9. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/sgd.py +6 -7
  10. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/sklearn.py +8 -9
  11. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/torch.py +3 -4
  12. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/transformer.py +3 -4
  13. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_adaptive_linear_regressor.py +2 -2
  14. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_linear_regressor.py +2 -2
  15. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_mlp.py +9 -9
  16. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_mlp_old.py +10 -5
  17. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_refit_kalman.py +7 -8
  18. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_rnn.py +19 -25
  19. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_sgd.py +5 -4
  20. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_sklearn.py +12 -13
  21. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_torch.py +9 -14
  22. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_transformer.py +17 -19
  23. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/.github/workflows/docs.yml +0 -0
  24. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/.github/workflows/python-publish.yml +0 -0
  25. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/.github/workflows/python-tests.yml +0 -0
  26. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/.gitignore +0 -0
  27. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/.pre-commit-config.yaml +0 -0
  28. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/LICENSE +0 -0
  29. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/README.md +0 -0
  30. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/docs/Makefile +0 -0
  31. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/docs/make.bat +0 -0
  32. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/docs/source/_templates/autosummary/module.rst +0 -0
  33. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/docs/source/api/index.rst +0 -0
  34. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/docs/source/conf.py +0 -0
  35. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/docs/source/guides/classification.rst +0 -0
  36. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/docs/source/index.rst +0 -0
  37. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/__init__.py +0 -0
  38. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/dim_reduce/__init__.py +0 -0
  39. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/dim_reduce/adaptive_decomp.py +0 -0
  40. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/dim_reduce/incremental_decomp.py +0 -0
  41. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/linear_model/__init__.py +0 -0
  42. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/linear_model/adaptive_linear_regressor.py +0 -0
  43. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/linear_model/cca.py +0 -0
  44. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/linear_model/linear_regressor.py +0 -0
  45. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/linear_model/sgd.py +0 -0
  46. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/linear_model/slda.py +0 -0
  47. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/model/__init__.py +0 -0
  48. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/model/cca.py +0 -0
  49. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/model/mlp.py +0 -0
  50. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/model/mlp_old.py +0 -0
  51. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/model/refit_kalman.py +0 -0
  52. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/model/rnn.py +0 -0
  53. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/model/transformer.py +0 -0
  54. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/nlin_model/__init__.py +0 -0
  55. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/nlin_model/mlp.py +0 -0
  56. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/__init__.py +0 -0
  57. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/base.py +0 -0
  58. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/slda.py +0 -0
  59. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/ssr.py +0 -0
  60. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/util.py +0 -0
  61. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/benchmark/bench_lrr.py +0 -0
  62. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/dim_reduce/test_adaptive_decomp.py +0 -0
  63. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/dim_reduce/test_incremental_decomp.py +0 -0
  64. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/integration/conftest.py +0 -0
  65. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/integration/test_mlp_system.py +0 -0
  66. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/integration/test_refit_kalman_system.py +0 -0
  67. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/integration/test_rnn_system.py +0 -0
  68. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/integration/test_sklearn_system.py +0 -0
  69. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/integration/test_torch_system.py +0 -0
  70. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/integration/test_transformer_system.py +0 -0
  71. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_slda.py +0 -0
  72. {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_ssr.py +0 -0
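
The recurring change across the hunks below (inferred from the diff itself): every processor's `partial_fit` now accepts a plain `AxisArray` carrying its training trigger in `attrs["trigger"]`, replacing the old `ezmsg.sigproc.sampler.SampleMessage` wrapper, with the `ezmsg-baseproc` and `ezmsg-sigproc` dependency floors raised to match. A minimal sketch of the two calling conventions, using only the `AxisArray`, `replace`, and `SampleTriggerMessage` APIs exactly as the updated tests in this diff use them; the array shapes and values are illustrative, not taken from the package:

import numpy as np
from ezmsg.baseproc import SampleTriggerMessage
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

# A training window (10 time samples x 2 channels) and its target trigger.
sig = AxisArray(data=np.random.randn(10, 2), dims=["time", "ch"])
trig = SampleTriggerMessage(timestamp=0.0, value=np.ones((10, 1), dtype=np.float32))

# 1.1.0 convention (removed in 1.2.0):
#     from ezmsg.sigproc.sampler import SampleMessage
#     msg = SampleMessage(trigger=trig, sample=sig)

# 1.2.0 convention: the trigger rides inside the AxisArray's attrs, so
# partial_fit reads message.data and message.attrs["trigger"].value.
msg = replace(sig, attrs={**sig.attrs, "trigger": trig})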

{ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/PKG-INFO
@@ -1,13 +1,13 @@
 Metadata-Version: 2.4
 Name: ezmsg-learn
-Version: 1.1.0
+Version: 1.2.0
 Summary: ezmsg namespace package for machine learning
 Author-email: Chadwick Boulay <chadwick.boulay@gmail.com>
 License-Expression: MIT
 License-File: LICENSE
 Requires-Python: >=3.10.15
-Requires-Dist: ezmsg-baseproc>=1.0.2
-Requires-Dist: ezmsg-sigproc>=2.14.0
+Requires-Dist: ezmsg-baseproc>=1.3.0
+Requires-Dist: ezmsg-sigproc>=2.15.0
 Requires-Dist: river>=0.22.0
 Requires-Dist: scikit-learn>=1.6.0
 Requires-Dist: torch>=2.6.0

{ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/pyproject.toml
@@ -9,8 +9,8 @@ license = "MIT"
 requires-python = ">=3.10.15"
 dynamic = ["version"]
 dependencies = [
-    "ezmsg-baseproc>=1.0.2",
-    "ezmsg-sigproc>=2.14.0",
+    "ezmsg-baseproc>=1.3.0",
+    "ezmsg-sigproc>=2.15.0",
     "river>=0.22.0",
     "scikit-learn>=1.6.0",
     "torch>=2.6.0",

{ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/__version__.py
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '1.1.0'
-__version_tuple__ = version_tuple = (1, 1, 0)
+__version__ = version = '1.2.0'
+__version_tuple__ = version_tuple = (1, 2, 0)
 
 __commit_id__ = commit_id = None

{ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/adaptive_linear_regressor.py
@@ -11,7 +11,6 @@ from ezmsg.baseproc import (
     BaseAdaptiveTransformerUnit,
     processor_state,
 )
-from ezmsg.sigproc.sampler import SampleMessage
 from ezmsg.util.messages.axisarray import AxisArray, replace
 
 from ..util import AdaptiveLinearRegressor, RegressorType, get_regressor
@@ -78,30 +77,30 @@ class AdaptiveLinearRegressorTransformer(
         # .template is updated in partial_fit
         pass
 
-    def partial_fit(self, message: SampleMessage) -> None:
-        if np.any(np.isnan(message.sample.data)):
+    def partial_fit(self, message: AxisArray) -> None:
+        if np.any(np.isnan(message.data)):
             return
 
         if self.settings.model_type in [
            AdaptiveLinearRegressor.LINEAR,
            AdaptiveLinearRegressor.LOGISTIC,
         ]:
-            x = pd.DataFrame.from_dict({k: v for k, v in zip(message.sample.axes["ch"].data, message.sample.data.T)})
+            x = pd.DataFrame.from_dict({k: v for k, v in zip(message.axes["ch"].data, message.data.T)})
             y = pd.Series(
-                data=message.trigger.value.data[:, 0],
-                name=message.trigger.value.axes["ch"].data[0],
+                data=message.attrs["trigger"].value.data[:, 0],
+                name=message.attrs["trigger"].value.axes["ch"].data[0],
             )
             self.state.model.learn_many(x, y)
         else:
-            X = message.sample.data
-            if message.sample.get_axis_idx("time") != 0:
-                X = np.moveaxis(X, message.sample.get_axis_idx("time"), 0)
-            self.state.model.partial_fit(X, message.trigger.value.data)
+            X = message.data
+            if message.get_axis_idx("time") != 0:
+                X = np.moveaxis(X, message.get_axis_idx("time"), 0)
+            self.state.model.partial_fit(X, message.attrs["trigger"].value.data)
 
         self.state.template = replace(
-            message.trigger.value,
-            data=np.empty_like(message.trigger.value.data),
-            key=message.trigger.value.key + "_pred",
+            message.attrs["trigger"].value,
+            data=np.empty_like(message.attrs["trigger"].value.data),
+            key=message.attrs["trigger"].value.key + "_pred",
         )
 
     def _process(self, message: AxisArray) -> AxisArray | None:

{ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/linear_regressor.py
@@ -7,7 +7,6 @@ from ezmsg.baseproc import (
     BaseAdaptiveTransformerUnit,
     processor_state,
 )
-from ezmsg.sigproc.sampler import SampleMessage
 from ezmsg.util.messages.axisarray import AxisArray, replace
 from sklearn.linear_model._base import LinearModel
 
@@ -53,18 +52,18 @@ class LinearRegressorTransformer(
         # .model and .template are initialized in __init__
         pass
 
-    def partial_fit(self, message: SampleMessage) -> None:
-        if np.any(np.isnan(message.sample.data)):
+    def partial_fit(self, message: AxisArray) -> None:
+        if np.any(np.isnan(message.data)):
             return
 
-        X = message.sample.data
-        y = message.trigger.value.data
+        X = message.data
+        y = message.attrs["trigger"].value.data
         # TODO: Resample should provide identical durations.
         self.state.model = self.state.model.fit(X[: y.shape[0]], y[: X.shape[0]])
         self.state.template = replace(
-            message.trigger.value,
+            message.attrs["trigger"].value,
             data=np.array([[]]),
-            key=message.trigger.value.key + "_pred",
+            key=message.attrs["trigger"].value.key + "_pred",
         )
 
     def _process(self, message: AxisArray) -> AxisArray:

{ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/mlp_old.py
@@ -9,7 +9,6 @@ from ezmsg.baseproc import (
     BaseAdaptiveTransformerUnit,
     processor_state,
 )
-from ezmsg.sigproc.sampler import SampleMessage
 from ezmsg.util.messages.axisarray import AxisArray
 from ezmsg.util.messages.util import replace
 
@@ -134,14 +133,14 @@ class MLPProcessor(BaseAdaptiveTransformer[MLPSettings, AxisArray, AxisArray, ML
         dtype = torch.float32 if self.settings.single_precision else torch.float64
         return torch.tensor(data, dtype=dtype, device=self._state.device)
 
-    def partial_fit(self, message: SampleMessage) -> None:
+    def partial_fit(self, message: AxisArray) -> None:
         self._state.model.train()
 
         # TODO: loss_fn should be determined by setting
         loss_fn = torch.nn.functional.mse_loss
 
-        X = self._to_tensor(message.sample.data)
-        y_targ = self._to_tensor(message.trigger.value)
+        X = self._to_tensor(message.data)
+        y_targ = self._to_tensor(message.attrs["trigger"].value)
 
         with torch.set_grad_enabled(True):
             self._state.model.train()

{ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/refit_kalman.py
@@ -8,7 +8,6 @@ from ezmsg.baseproc import (
     BaseAdaptiveTransformerUnit,
     processor_state,
 )
-from ezmsg.sigproc.sampler import SampleMessage
 from ezmsg.util.messages.axisarray import AxisArray
 from ezmsg.util.messages.util import replace
 
@@ -284,22 +283,22 @@ class RefitKalmanFilterProcessor(
             key=f"{message.key}_filtered" if hasattr(message, "key") else "filtered",
         )
 
-    def partial_fit(self, message: SampleMessage) -> None:
+    def partial_fit(self, message: AxisArray) -> None:
         """
         Perform refitting using externally provided data.
 
-        Expects message.sample.data (neural input) and message.trigger.value as a dict with:
+        Expects message.data (neural input) and message.attrs["trigger"].value as a dict with:
         - Y_state: (n_samples, n_states) array
        - intention_velocity_indices: Optional[int]
        - target_positions: Optional[np.ndarray]
        - cursor_positions: Optional[np.ndarray]
        - hold_flags: Optional[list[bool]]
        """
-        if not hasattr(message, "sample") or not hasattr(message, "trigger"):
+        if "trigger" not in message.attrs:
            raise ValueError("Invalid message format for partial_fit.")

-        X = np.array(message.sample.data)
-        values = message.trigger.value
+        X = np.array(message.data)
+        values = message.attrs["trigger"].value

        if not isinstance(values, dict) or "Y_state" not in values:
            raise ValueError("partial_fit expects trigger.value to include at least 'Y_state'.")
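
The docstring in the hunk above lists the dict keys that the refit Kalman `partial_fit` expects in `attrs["trigger"].value`. A minimal sketch of a conforming message, mirroring the updated test_refit_kalman hunk later in this diff; the `Y_state` shape and zero values here are placeholders, not meaningful refit targets:

import numpy as np
from ezmsg.baseproc import SampleTriggerMessage
from ezmsg.util.messages.axisarray import AxisArray

neural_data = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])  # 3 samples, 2 channels
trigger_value = {
    "Y_state": np.zeros((3, 4)),  # (n_samples, n_states); the only required key
    "intention_velocity_indices": None,
    "target_positions": None,
    "cursor_positions": None,
    "hold_flags": [False, False, False],
}
msg = AxisArray(
    data=neural_data,
    dims=["time", "ch"],
    attrs={"trigger": SampleTriggerMessage(timestamp=0.0, value=trigger_value)},
)
# processor.partial_fit(msg) then refits the model's H and Q matrices.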

{ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/rnn.py
@@ -5,7 +5,6 @@ import numpy as np
 import torch
 from ezmsg.baseproc import BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit
 from ezmsg.baseproc.util.profile import profile_subpub
-from ezmsg.sigproc.sampler import SampleMessage
 from ezmsg.util.messages.axisarray import AxisArray
 from ezmsg.util.messages.util import replace
 
@@ -184,18 +183,18 @@ class RNNProcessor(
         if self._state.scheduler is not None:
             self._state.scheduler.step()
 
-    def partial_fit(self, message: SampleMessage) -> None:
+    def partial_fit(self, message: AxisArray) -> None:
         self._state.model.train()
 
-        X = self._to_tensor(message.sample.data)
+        X = self._to_tensor(message.data)
 
         # Add batch dimension if missing
         X, batched = self._ensure_batched(X)
 
         batch_size = X.shape[0]
-        preserve_state = self._maybe_reset_state(message.sample, batch_size)
+        preserve_state = self._maybe_reset_state(message, batch_size)
 
-        y_targ = message.trigger.value
+        y_targ = message.attrs["trigger"].value
         if not isinstance(y_targ, dict):
             y_targ = {"output": y_targ}
         y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}

{ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/sgd.py
@@ -5,7 +5,6 @@ import numpy as np
 from ezmsg.baseproc import (
     BaseAdaptiveTransformer,
     BaseAdaptiveTransformerUnit,
-    SampleMessage,
     processor_state,
 )
 from ezmsg.util.messages.axisarray import AxisArray
@@ -87,23 +86,23 @@ class SGDDecoderTransformer(BaseAdaptiveTransformer[SGDDecoderSettings, AxisArra
             key=message.key,
         )
 
-    def partial_fit(self, message: SampleMessage) -> None:
+    def partial_fit(self, message: AxisArray) -> None:
         if self._hash != 0:
-            self._reset_state(message.sample)
+            self._reset_state(message)
             self._hash = 0
 
-        if np.any(np.isnan(message.sample.data)):
+        if np.any(np.isnan(message.data)):
             return
-        train_sample = message.sample.data.reshape(1, -1)
+        train_sample = message.data.reshape(1, -1)
         if self._state.b_first_train:
             self._state.model.partial_fit(
                 train_sample,
-                [message.trigger.value],
+                [message.attrs["trigger"].value],
                 classes=list(self.settings.label_weights.keys()),
             )
             self._state.b_first_train = False
         else:
-            self._state.model.partial_fit(train_sample, [message.trigger.value])
+            self._state.model.partial_fit(train_sample, [message.attrs["trigger"].value])
 
 
 class SGDDecoder(

{ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/sklearn.py
@@ -10,7 +10,6 @@ from ezmsg.baseproc import (
     BaseAdaptiveTransformerUnit,
     processor_state,
 )
-from ezmsg.sigproc.sampler import SampleMessage
 from ezmsg.util.messages.axisarray import AxisArray
 from ezmsg.util.messages.util import replace
 
@@ -116,25 +115,25 @@ class SklearnModelProcessor(BaseAdaptiveTransformer[SklearnModelSettings, AxisAr
             # No checkpoint, initialize from scratch
             self._init_model()
 
-    def partial_fit(self, message: SampleMessage) -> None:
-        X = message.sample.data
-        y = message.trigger.value
+    def partial_fit(self, message: AxisArray) -> None:
+        X = message.data
+        y = message.attrs["trigger"].value
         if self._state.model is None:
-            self._reset_state(message.sample)
+            self._reset_state(message)
         if hasattr(self._state.model, "partial_fit"):
             kwargs = {}
             if self.settings.partial_fit_classes is not None:
                 kwargs["classes"] = self.settings.partial_fit_classes
             self._state.model.partial_fit(X, y, **kwargs)
         elif hasattr(self._state.model, "learn_many"):
-            df_X = pd.DataFrame({k: v for k, v in zip(message.sample.axes["ch"].data, message.sample.data.T)})
+            df_X = pd.DataFrame({k: v for k, v in zip(message.axes["ch"].data, message.data.T)})
             name = (
-                message.trigger.value.axes["ch"].data[0]
-                if hasattr(message.trigger.value, "axes") and "ch" in message.trigger.value.axes
+                message.attrs["trigger"].value.axes["ch"].data[0]
+                if hasattr(message.attrs["trigger"].value, "axes") and "ch" in message.attrs["trigger"].value.axes
                 else "target"
             )
             ser_y = pd.Series(
-                data=np.asarray(message.trigger.value.data).flatten(),
+                data=np.asarray(message.attrs["trigger"].value.data).flatten(),
                 name=name,
             )
             self._state.model.learn_many(df_X, ser_y)

{ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/torch.py
@@ -12,7 +12,6 @@ from ezmsg.baseproc import (
     processor_state,
 )
 from ezmsg.baseproc.util.profile import profile_subpub
-from ezmsg.sigproc.sampler import SampleMessage
 from ezmsg.util.messages.axisarray import AxisArray
 from ezmsg.util.messages.util import replace
 
@@ -294,13 +293,13 @@ class TorchModelProcessor(
     def _process(self, message: AxisArray) -> list[AxisArray]:
         return self._common_process(message)
 
-    def partial_fit(self, message: SampleMessage) -> None:
+    def partial_fit(self, message: AxisArray) -> None:
         self._state.model.train()
 
-        X = self._to_tensor(message.sample.data)
+        X = self._to_tensor(message.data)
         X, batched = self._ensure_batched(X)
 
-        y_targ = message.trigger.value
+        y_targ = message.attrs["trigger"].value
         if not isinstance(y_targ, dict):
             y_targ = {"output": y_targ}
         y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}

{ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/transformer.py
@@ -4,7 +4,6 @@ import ezmsg.core as ez
 import torch
 from ezmsg.baseproc import BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit
 from ezmsg.baseproc.util.profile import profile_subpub
-from ezmsg.sigproc.sampler import SampleMessage
 from ezmsg.util.messages.axisarray import AxisArray
 from ezmsg.util.messages.util import replace
 
@@ -125,13 +124,13 @@ class TransformerProcessor(
             )
         ]
 
-    def partial_fit(self, message: SampleMessage) -> None:
+    def partial_fit(self, message: AxisArray) -> None:
         self._state.model.train()
 
-        X = self._to_tensor(message.sample.data)
+        X = self._to_tensor(message.data)
         X, batched = self._ensure_batched(X)
 
-        y_targ = message.trigger.value
+        y_targ = message.attrs["trigger"].value
         if not isinstance(y_targ, dict):
             y_targ = {"output": y_targ}
         y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}

{ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_adaptive_linear_regressor.py
@@ -1,6 +1,6 @@
 import numpy as np
 import pytest
-from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage
+from ezmsg.baseproc import SampleTriggerMessage
 from ezmsg.util.messages.axisarray import AxisArray, replace
 
 from ezmsg.learn.process.adaptive_linear_regressor import (
@@ -42,7 +42,7 @@ def test_adaptive_linear_regressor(model_type: str):
         period=(0.0, dur),
         value=value_axarr,
     )
-    samp = SampleMessage(trigger=samp_trig, sample=sig_axarr)
+    samp = replace(sig_axarr, attrs={"trigger": samp_trig})
 
     proc = AdaptiveLinearRegressorTransformer(model_type=model_type)
     _ = proc.send(samp)

{ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_linear_regressor.py
@@ -1,6 +1,6 @@
 import numpy as np
 import pytest
-from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage
+from ezmsg.baseproc import SampleTriggerMessage
 from ezmsg.util.messages.axisarray import AxisArray, replace
 
 from ezmsg.learn.process.linear_regressor import LinearRegressorTransformer
@@ -40,7 +40,7 @@ def test_linear_regressor(model_type: str):
         period=(0.0, dur),
         value=value_axarr,
     )
-    samp = SampleMessage(trigger=samp_trig, sample=sig_axarr)
+    samp = replace(sig_axarr, attrs={"trigger": samp_trig})
 
     gen = LinearRegressorTransformer(model_type=model_type)
     _ = gen.send(samp)

{ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_mlp.py
@@ -77,7 +77,8 @@ def test_mlp_checkpoint_io(tmp_path, sample_input, mlp_settings):
 
 
 def test_mlp_partial_fit_learns(sample_input, mlp_settings):
-    from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage
+    from ezmsg.baseproc import SampleTriggerMessage
+    from ezmsg.util.messages.util import replace
 
     proc = TorchModelProcessor(
         model_class="ezmsg.learn.model.mlp.MLP",
@@ -88,13 +89,12 @@ def test_mlp_partial_fit_learns(sample_input, mlp_settings):
     )
     proc(sample_input)
 
-    sample = AxisArray(
-        data=sample_input.data[:1], dims=["time", "ch"], axes=sample_input.axes
-    )
+    sample = AxisArray(data=sample_input.data[:1], dims=["time", "ch"], axes=sample_input.axes)
     target = np.random.randn(1, 5)
 
-    msg = SampleMessage(
-        sample=sample, trigger=SampleTriggerMessage(timestamp=0.0, value=target)
+    msg = replace(
+        sample,
+        attrs={**sample.attrs, "trigger": SampleTriggerMessage(timestamp=0.0, value=target)},
     )
 
     before = [p.detach().clone() for p in proc.state.model.parameters()]
@@ -135,9 +135,9 @@ def test_mlp_hidden_size_integer(sample_input):
         device="cpu",
     )
     proc(sample_input)
-    hidden_layers = [
-        m for m in proc._state.model.modules() if isinstance(m, torch.nn.Linear)
-    ][:-1]  # Exclude the output head
+    hidden_layers = [m for m in proc._state.model.modules() if isinstance(m, torch.nn.Linear)][
+        :-1
+    ]  # Exclude the output head
     assert len(hidden_layers) == 3  # num_layers = 3
     assert hidden_layers[0].in_features == 8
     assert all(layer.out_features == 32 for layer in hidden_layers[:-1])

{ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_mlp_old.py
@@ -4,8 +4,9 @@ import numpy as np
 import pytest
 import torch
 import torch.nn
-from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage
+from ezmsg.baseproc import SampleTriggerMessage
 from ezmsg.util.messages.axisarray import AxisArray
+from ezmsg.util.messages.util import replace
 from sklearn.model_selection import train_test_split
 
 from ezmsg.learn.process.mlp_old import MLPProcessor
@@ -146,7 +147,10 @@ def test_mlp_process():
         template.data[:] = X  # This would fail if n_samps / batch_size had a remainder.
         template.axes["time"].offset = ts
         if set == 0:
-            yield SampleMessage(trigger=SampleTriggerMessage(timestamp=ts, value=y), sample=template)
+            yield replace(
+                template,
+                attrs={**template.attrs, "trigger": SampleTriggerMessage(timestamp=ts, value=y)},
+            )
         else:
             yield template, y
 
@@ -167,14 +171,15 @@ def test_mlp_process():
     result = []
     train_loss = []
     for sample_msg in xy_gen(set=0):
-        # Naive closed-loop inference
-        result.append(proc(sample_msg.sample))
+        # Naive closed-loop inference — strip trigger attrs before inference
+        plain_msg = replace(sample_msg, attrs={})
+        result.append(proc(plain_msg))
 
         # Collect the loss to see if it decreases with training.
         train_loss.append(
             torch.nn.MSELoss()(
                 torch.tensor(result[-1].data),
-                torch.tensor(sample_msg.trigger.value.reshape(-1, 1), dtype=torch.float32),
+                torch.tensor(sample_msg.attrs["trigger"].value.reshape(-1, 1), dtype=torch.float32),
             ).item()
         )
 

{ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_refit_kalman.py
@@ -4,6 +4,7 @@ from pathlib import Path
 
 import numpy as np
 import pytest
+from ezmsg.baseproc import SampleTriggerMessage
 from ezmsg.util.messages.axisarray import AxisArray
 
 from ezmsg.learn.process.refit_kalman import (
@@ -299,12 +300,6 @@ def test_partial_fit_functionality(create_test_message, checkpoint_file):
     H_initial = checkpoint_data["H_observation_matrix"]
     Q_initial = checkpoint_data["Q_measurement_noise_covariance"]
 
-    # Create a mock SampleMessage with the expected structure
-    class MockSampleMessage:
-        def __init__(self, neural_data, trigger_value):
-            self.sample = type("obj", (object,), {"data": neural_data})()
-            self.trigger = type("obj", (object,), {"value": trigger_value})()
-
     # Create test data
     neural_data = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])  # 3 samples, 2 channels
     trigger_value = {
@@ -315,8 +310,12 @@ def test_partial_fit_functionality(create_test_message, checkpoint_file):
         "hold_flags": [False, False, False],
     }
 
-    mock_message = MockSampleMessage(neural_data, trigger_value)
-    processor.partial_fit(mock_message)
+    sample_msg = AxisArray(
+        data=neural_data,
+        dims=["time", "ch"],
+        attrs={"trigger": SampleTriggerMessage(timestamp=0.0, value=trigger_value)},
+    )
+    processor.partial_fit(sample_msg)
 
     assert not np.allclose(H_initial, processor._state.model.H_observation_matrix)
     assert not np.allclose(Q_initial, processor._state.model.Q_measurement_noise_covariance)

{ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_rnn.py
@@ -5,8 +5,9 @@ import numpy as np
 import pytest
 import torch
 import torch.nn
-from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage
+from ezmsg.baseproc import SampleTriggerMessage
 from ezmsg.util.messages.axisarray import AxisArray
+from ezmsg.util.messages.util import replace
 
 from ezmsg.learn.process.rnn import RNNProcessor
 
@@ -107,9 +108,9 @@ def test_rnn_process(rnn_type, simple_message):
     # We don't pass in the hx state so it should be initialized to zeros, same as in the first call to proc.
     in_tensor = torch.tensor(simple_message.data[None, ...], dtype=torch.float32)
     with torch.no_grad():
-        expected_result = (
-            proc.state.model(in_tensor)[0]["output"].cpu().numpy().squeeze(0)
-        )
+        expected_result = proc.state.model(in_tensor)[0]["output"].cpu().numpy().squeeze(0)
     assert np.allclose(output.data, expected_result)
 
 
@@ -139,9 +138,9 @@ def test_rnn_partial_fit(simple_message):
 
     target_shape = (simple_message.data.shape[0], output_size)
     target_value = np.ones(target_shape, dtype=np.float32)
-    sample_message = SampleMessage(
-        trigger=SampleTriggerMessage(timestamp=0.0, value=target_value),
-        sample=simple_message,
+    sample_message = replace(
+        simple_message,
+        attrs={**simple_message.attrs, "trigger": SampleTriggerMessage(timestamp=0.0, value=target_value)},
     )
 
     proc(sample_message)
@@ -149,9 +148,7 @@ def test_rnn_partial_fit(simple_message):
     assert not proc.state.model.training
     updated_weights = [p.detach() for p in proc.state.model.parameters()]
 
-    assert any(
-        not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights)
-    )
+    assert any(not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights))
 
 
 def test_rnn_checkpoint_save_load(simple_message):
@@ -201,9 +198,7 @@ def test_rnn_checkpoint_save_load(simple_message):
 
         for key in state_dict1:
             assert key in state_dict2, f"Missing key {key} in loaded state_dict"
-            assert torch.equal(state_dict1[key], state_dict2[key]), (
-                f"Mismatch in parameter {key}"
-            )
+            assert torch.equal(state_dict1[key], state_dict2[key]), f"Mismatch in parameter {key}"
 
     finally:
         # Ensure the temporary file is deleted
@@ -244,20 +239,21 @@ def test_rnn_partial_fit_multiloss(simple_message):
         dtype=torch.long,
     )
 
-    sample_message = SampleMessage(
-        trigger=SampleTriggerMessage(
-            timestamp=0.0,
-            value={"traj": traj_target, "state": state_target},
-        ),
-        sample=simple_message,
+    sample_message = replace(
+        simple_message,
+        attrs={
+            **simple_message.attrs,
+            "trigger": SampleTriggerMessage(
+                timestamp=0.0,
+                value={"traj": traj_target, "state": state_target},
+            ),
+        },
     )
 
     proc.partial_fit(sample_message)
 
     updated_weights = [p.detach() for p in proc.state.model.parameters()]
-    assert any(
-        not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights)
-    )
+    assert any(not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights))
 
 
 @pytest.mark.parametrize(
@@ -269,9 +265,7 @@ def test_rnn_partial_fit_multiloss(simple_message):
         ("auto", 0.05, 0.1, False),  # overlapping → reset
     ],
 )
-def test_rnn_preserve_state(
-    preserve_state_across_windows, win_stride, win_len, should_preserve
-):
+def test_rnn_preserve_state(preserve_state_across_windows, win_stride, win_len, should_preserve):
     hidden_size = 16
     num_layers = 1
     output_size = 2

{ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_sgd.py
@@ -1,5 +1,5 @@
 import numpy as np
-from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage
+from ezmsg.baseproc import SampleTriggerMessage
 from ezmsg.util.messages.axisarray import AxisArray
 
 from ezmsg.learn.process.sgd import SGDDecoderSettings, SGDDecoderTransformer
@@ -13,9 +13,10 @@ def test_sgd():
         data = np.random.normal(scale=0.05, size=(3, 2, 1))
         data[time_idx[label] : time_idx[label] + 1, 0, 0] += 1.0
         samples.append(
-            SampleMessage(
-                trigger=SampleTriggerMessage(timestamp=len(samples), period=None, value=label),
-                sample=AxisArray(data=data, dims=["time", "ch", "freq"]),
+            AxisArray(
+                data=data,
+                dims=["time", "ch", "freq"],
+                attrs={"trigger": SampleTriggerMessage(timestamp=len(samples), period=None, value=label)},
             )
         )
 

{ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_sklearn.py
@@ -1,7 +1,8 @@
 import numpy as np
 import pytest
-from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage
+from ezmsg.baseproc import SampleTriggerMessage
 from ezmsg.util.messages.axisarray import AxisArray
+from ezmsg.util.messages.util import replace
 
 from ezmsg.learn.process.sklearn import SklearnModelProcessor
 
@@ -83,9 +84,9 @@ def test_partial_fit_supported_models(
     proc = SklearnModelProcessor(**settings_kwargs)
     proc._reset_state(input_axisarray)
 
-    sample_msg = SampleMessage(
-        sample=input_axisarray,
-        trigger=SampleTriggerMessage(timestamp=0.0, value=labels),
+    sample_msg = replace(
+        input_axisarray,
+        attrs={**input_axisarray.attrs, "trigger": SampleTriggerMessage(timestamp=0.0, value=labels)},
     )
 
     proc.partial_fit(sample_msg)
@@ -96,9 +97,9 @@
 def test_partial_fit_unsupported_model(input_axisarray, labels_regression):
     proc = SklearnModelProcessor(model_class="sklearn.linear_model.Ridge")
     proc._reset_state(input_axisarray)
-    sample_msg = SampleMessage(
-        sample=input_axisarray,
-        trigger=SampleTriggerMessage(timestamp=0.0, value=labels_regression),
+    sample_msg = replace(
+        input_axisarray,
+        attrs={**input_axisarray.attrs, "trigger": SampleTriggerMessage(timestamp=0.0, value=labels_regression)},
     )
     with pytest.raises(NotImplementedError, match="partial_fit"):
         proc.partial_fit(sample_msg)
@@ -108,9 +109,9 @@ def test_partial_fit_changes_model(input_axisarray, labels_regression):
     proc = SklearnModelProcessor(model_class="sklearn.linear_model.SGDRegressor")
     proc._reset_state(input_axisarray)
 
-    sample_msg = SampleMessage(
-        sample=input_axisarray,
-        trigger=SampleTriggerMessage(timestamp=0.0, value=labels_regression),
+    sample_msg = replace(
+        input_axisarray,
+        attrs={**input_axisarray.attrs, "trigger": SampleTriggerMessage(timestamp=0.0, value=labels_regression)},
    )
 
     proc.partial_fit(sample_msg)
@@ -127,9 +128,7 @@ def test_model_save_and_load(tmp_path, input_axisarray):
     checkpoint_path = tmp_path / "model_checkpoint.pkl"
     proc.save_checkpoint(str(checkpoint_path))
 
-    new_proc = SklearnModelProcessor(
-        model_class="sklearn.linear_model.Ridge", checkpoint_path=str(checkpoint_path)
-    )
+    new_proc = SklearnModelProcessor(model_class="sklearn.linear_model.Ridge", checkpoint_path=str(checkpoint_path))
     new_proc._reset_state(input_axisarray)
     assert new_proc._state.model is not None
 

{ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_torch.py
@@ -5,8 +5,9 @@ from pathlib import Path
 import numpy as np
 import pytest
 import torch
-from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage
+from ezmsg.baseproc import SampleTriggerMessage
 from ezmsg.util.messages.axisarray import AxisArray
+from ezmsg.util.messages.util import replace
 
 from ezmsg.learn.process.torch import TorchModelProcessor
 
@@ -185,9 +186,9 @@ def test_partial_fit_changes_weights(batch_message, device):
         },
     )
 
-    msg = SampleMessage(
-        sample=sample,
-        trigger=SampleTriggerMessage(timestamp=0.0, value=y),
+    msg = replace(
+        sample,
+        attrs={"trigger": SampleTriggerMessage(timestamp=0.0, value=y)},
     )
 
     proc(sample)  # run forward pass once to init model
@@ -318,14 +319,11 @@ def test_multihead_partial_fit_with_loss_dict(batch_message, device):
         "head_a": np.random.randn(1, 2),
         "head_b": np.random.randn(1, 3),
     }
-    sample = AxisArray(
+    msg = AxisArray(
         data=batch_message.data[:1],
         dims=["time", "ch"],
         axes=batch_message.axes,
-    )
-    msg = SampleMessage(
-        sample=sample,
-        trigger=SampleTriggerMessage(timestamp=0.0, value=y_targ),
+        attrs={"trigger": SampleTriggerMessage(timestamp=0.0, value=y_targ)},
     )
 
     before_a = proc._state.model.head_a.weight.clone()
@@ -360,14 +358,11 @@ def test_partial_fit_with_loss_weights(batch_message, device):
         "head_a": np.random.randn(1, 2),
         "head_b": np.random.randn(1, 3),
     }
-    sample = AxisArray(
+    msg = AxisArray(
         data=batch_message.data[:1],
         dims=["time", "ch"],
         axes=batch_message.axes,
-    )
-    msg = SampleMessage(
-        sample=sample,
-        trigger=SampleTriggerMessage(timestamp=0.0, value=y_targ),
+        attrs={"trigger": SampleTriggerMessage(timestamp=0.0, value=y_targ)},
     )
 
     # Expect no error, and just run once

{ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_transformer.py
@@ -5,8 +5,9 @@ import numpy as np
 import pytest
 import torch
 import torch.nn
-from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage
+from ezmsg.baseproc import SampleTriggerMessage
 from ezmsg.util.messages.axisarray import AxisArray
+from ezmsg.util.messages.util import replace
 
 from ezmsg.learn.process.transformer import TransformerProcessor
 
@@ -138,9 +139,9 @@ def test_transformer_partial_fit(simple_message, decoder_layers):
 
     target_shape = (simple_message.data.shape[0], output_size)
     target_value = np.ones(target_shape, dtype=np.float32)
-    sample_message = SampleMessage(
-        trigger=SampleTriggerMessage(timestamp=0.0, value=target_value),
-        sample=simple_message,
+    sample_message = replace(
+        simple_message,
+        attrs={**simple_message.attrs, "trigger": SampleTriggerMessage(timestamp=0.0, value=target_value)},
    )
 
     proc.partial_fit(sample_message)
@@ -149,9 +150,7 @@
     assert proc.state.tgt_cache is None
     updated_weights = [p.detach() for p in proc.state.model.parameters()]
 
-    assert any(
-        not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights)
-    )
+    assert any(not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights))
 
 
 def test_transformer_checkpoint_save_load(simple_message):
@@ -201,9 +200,7 @@ def test_transformer_checkpoint_save_load(simple_message):
 
         for key in state_dict1:
             assert key in state_dict2, f"Missing key {key} in loaded state_dict"
-            assert torch.equal(state_dict1[key], state_dict2[key]), (
-                f"Mismatch in parameter {key}"
-            )
+            assert torch.equal(state_dict1[key], state_dict2[key]), f"Mismatch in parameter {key}"
 
     finally:
         # Ensure the temporary file is deleted
@@ -244,20 +241,21 @@ def test_transformer_partial_fit_multiloss(simple_message):
         dtype=torch.long,
     )
 
-    sample_message = SampleMessage(
-        trigger=SampleTriggerMessage(
-            timestamp=0.0,
-            value={"traj": traj_target, "state": state_target},
-        ),
-        sample=simple_message,
+    sample_message = replace(
+        simple_message,
+        attrs={
+            **simple_message.attrs,
+            "trigger": SampleTriggerMessage(
+                timestamp=0.0,
+                value={"traj": traj_target, "state": state_target},
+            ),
+        },
    )
 
     proc.partial_fit(sample_message)
 
     updated_weights = [p.detach() for p in proc.state.model.parameters()]
-    assert any(
-        not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights)
-    )
+    assert any(not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights))
 
 
 def test_autoregressive_cache_behavior(simple_message):