ezmsg-sigproc 1.7.0__py3-none-any.whl → 2.10.0__py3-none-any.whl

This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
Files changed (66)
  1. ezmsg/sigproc/__version__.py +22 -4
  2. ezmsg/sigproc/activation.py +31 -40
  3. ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
  4. ezmsg/sigproc/affinetransform.py +171 -169
  5. ezmsg/sigproc/aggregate.py +190 -97
  6. ezmsg/sigproc/bandpower.py +60 -55
  7. ezmsg/sigproc/base.py +143 -33
  8. ezmsg/sigproc/butterworthfilter.py +34 -38
  9. ezmsg/sigproc/butterworthzerophase.py +305 -0
  10. ezmsg/sigproc/cheby.py +23 -17
  11. ezmsg/sigproc/combfilter.py +160 -0
  12. ezmsg/sigproc/coordinatespaces.py +159 -0
  13. ezmsg/sigproc/decimate.py +15 -10
  14. ezmsg/sigproc/denormalize.py +78 -0
  15. ezmsg/sigproc/detrend.py +28 -0
  16. ezmsg/sigproc/diff.py +82 -0
  17. ezmsg/sigproc/downsample.py +72 -81
  18. ezmsg/sigproc/ewma.py +217 -0
  19. ezmsg/sigproc/ewmfilter.py +1 -1
  20. ezmsg/sigproc/extract_axis.py +39 -0
  21. ezmsg/sigproc/fbcca.py +307 -0
  22. ezmsg/sigproc/filter.py +254 -148
  23. ezmsg/sigproc/filterbank.py +226 -214
  24. ezmsg/sigproc/filterbankdesign.py +129 -0
  25. ezmsg/sigproc/fir_hilbert.py +336 -0
  26. ezmsg/sigproc/fir_pmc.py +209 -0
  27. ezmsg/sigproc/firfilter.py +117 -0
  28. ezmsg/sigproc/gaussiansmoothing.py +89 -0
  29. ezmsg/sigproc/kaiser.py +106 -0
  30. ezmsg/sigproc/linear.py +120 -0
  31. ezmsg/sigproc/math/abs.py +23 -22
  32. ezmsg/sigproc/math/add.py +120 -0
  33. ezmsg/sigproc/math/clip.py +33 -25
  34. ezmsg/sigproc/math/difference.py +117 -43
  35. ezmsg/sigproc/math/invert.py +18 -25
  36. ezmsg/sigproc/math/log.py +38 -33
  37. ezmsg/sigproc/math/scale.py +24 -25
  38. ezmsg/sigproc/messages.py +1 -2
  39. ezmsg/sigproc/quantize.py +68 -0
  40. ezmsg/sigproc/resample.py +278 -0
  41. ezmsg/sigproc/rollingscaler.py +232 -0
  42. ezmsg/sigproc/sampler.py +209 -254
  43. ezmsg/sigproc/scaler.py +93 -218
  44. ezmsg/sigproc/signalinjector.py +44 -43
  45. ezmsg/sigproc/slicer.py +74 -102
  46. ezmsg/sigproc/spectral.py +3 -3
  47. ezmsg/sigproc/spectrogram.py +70 -70
  48. ezmsg/sigproc/spectrum.py +187 -173
  49. ezmsg/sigproc/transpose.py +134 -0
  50. ezmsg/sigproc/util/__init__.py +0 -0
  51. ezmsg/sigproc/util/asio.py +25 -0
  52. ezmsg/sigproc/util/axisarray_buffer.py +365 -0
  53. ezmsg/sigproc/util/buffer.py +449 -0
  54. ezmsg/sigproc/util/message.py +17 -0
  55. ezmsg/sigproc/util/profile.py +23 -0
  56. ezmsg/sigproc/util/sparse.py +115 -0
  57. ezmsg/sigproc/util/typeresolution.py +17 -0
  58. ezmsg/sigproc/wavelets.py +147 -154
  59. ezmsg/sigproc/window.py +248 -210
  60. ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
  61. ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
  62. {ezmsg_sigproc-1.7.0.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -1
  63. ezmsg/sigproc/synth.py +0 -621
  64. ezmsg_sigproc-1.7.0.dist-info/METADATA +0 -58
  65. ezmsg_sigproc-1.7.0.dist-info/RECORD +0 -36
  66. /ezmsg_sigproc-1.7.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/aggregate.py CHANGED
@@ -1,18 +1,32 @@
+ """
+ Aggregation operations over arrays.
+
+ .. note::
+     :obj:`AggregateTransformer` supports the :doc:`Array API standard </guides/explanations/array_api>`,
+     enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
+     :obj:`RangedAggregateTransformer` currently requires NumPy arrays.
+ """
+
  import typing

+ import ezmsg.core as ez
  import numpy as np
  import numpy.typing as npt
- import ezmsg.core as ez
- from ezmsg.util.generator import consumer
+ from array_api_compat import get_namespace
+ from ezmsg.baseproc import (
+     BaseStatefulTransformer,
+     BaseTransformer,
+     BaseTransformerUnit,
+     processor_state,
+ )
  from ezmsg.util.messages.axisarray import (
      AxisArray,
-     slice_along_axis,
      AxisBase,
      replace,
+     slice_along_axis,
  )

  from .spectral import OptionsEnum
- from .base import GenAxisArray


  class AggregationFunction(OptionsEnum):
@@ -33,6 +47,7 @@ class AggregationFunction(OptionsEnum):
      NANSUM = "nansum"
      ARGMIN = "argmin"
      ARGMAX = "argmax"
+     TRAPEZOID = "trapezoid"


  AGGREGATORS = {
@@ -51,133 +66,211 @@ AGGREGATORS = {
      AggregationFunction.NANSUM: np.nansum,
      AggregationFunction.ARGMIN: np.argmin,
      AggregationFunction.ARGMAX: np.argmax,
+     # Note: Some methods require x-coordinates and
+     # are handled specially in `_process`.
+     AggregationFunction.TRAPEZOID: np.trapezoid,
  }


- @consumer
- def ranged_aggregate(
-     axis: str | None = None,
-     bands: list[tuple[float, float]] | None = None,
-     operation: AggregationFunction = AggregationFunction.MEAN,
- ):
+ class RangedAggregateSettings(ez.Settings):
+     """
+     Settings for ``RangedAggregate``.
      """
-     Apply an aggregation operation over one or more bands.

-     Args:
-         axis: The name of the axis along which to apply the bands.
-         bands: [(band1_min, band1_max), (band2_min, band2_max), ...]
-             If not set then this acts as a passthrough node.
-         operation: :obj:`AggregationFunction` to apply to each band.
+     axis: str | None = None
+     """The name of the axis along which to apply the bands."""

-     Returns:
-         A primed generator object ready to yield an :obj:`AxisArray` for each .send(axis_array)
+     bands: list[tuple[float, float]] | None = None
      """
-     msg_out = AxisArray(np.array([]), dims=[""])
+     [(band1_min, band1_max), (band2_min, band2_max), ...]
+     If not set then this acts as a passthrough node.
+     """
+
+     operation: AggregationFunction = AggregationFunction.MEAN
+     """:obj:`AggregationFunction` to apply to each band."""

-     # State variables
+
+ @processor_state
+ class RangedAggregateState:
      slices: list[tuple[typing.Any, ...]] | None = None
      out_axis: AxisBase | None = None
      ax_vec: npt.NDArray | None = None

-     # Reset if any of these changes. Key not checked because continuity between chunks not required.
-     check_inputs = {"gain": None, "offset": None, "len": None, "key": None}

-     while True:
-         msg_in: AxisArray = yield msg_out
-         if bands is None:
-             msg_out = msg_in
+ class RangedAggregateTransformer(
+     BaseStatefulTransformer[RangedAggregateSettings, AxisArray, AxisArray, RangedAggregateState]
+ ):
+     def __call__(self, message: AxisArray) -> AxisArray:
+         # Override for shortcut passthrough mode.
+         if self.settings.bands is None:
+             return message
+         return super().__call__(message)
+
+     def _hash_message(self, message: AxisArray) -> int:
+         axis = self.settings.axis or message.dims[0]
+         target_axis = message.get_axis(axis)
+
+         hash_components = (message.key,)
+         if hasattr(target_axis, "data"):
+             hash_components += (len(target_axis.data),)
+         elif isinstance(target_axis, AxisArray.LinearAxis):
+             hash_components += (target_axis.gain, target_axis.offset)
+         return hash(hash_components)
+
+     def _reset_state(self, message: AxisArray) -> None:
+         axis = self.settings.axis or message.dims[0]
+         target_axis = message.get_axis(axis)
+         ax_idx = message.get_axis_idx(axis)
+
+         if hasattr(target_axis, "data"):
+             self._state.ax_vec = target_axis.data
          else:
-             axis = axis or msg_in.dims[0]
-             target_axis = msg_in.get_axis(axis)
+             self._state.ax_vec = target_axis.value(np.arange(message.data.shape[ax_idx]))

-             # Check if we need to reset state
-             b_reset = msg_in.key != check_inputs["key"]
+         ax_dat = []
+         slices = []
+         for start, stop in self.settings.bands:
+             inds = np.where(np.logical_and(self._state.ax_vec >= start, self._state.ax_vec <= stop))[0]
+             slices.append(np.s_[inds[0] : inds[-1] + 1])
              if hasattr(target_axis, "data"):
-                 b_reset = b_reset or len(target_axis.data) != check_inputs["len"]
-             elif isinstance(target_axis, AxisArray.LinearAxis):
-                 b_reset = b_reset or target_axis.gain != check_inputs["gain"]
-                 b_reset = b_reset or target_axis.offset != check_inputs["offset"]
-
-             if b_reset:
-                 # Update check variables
-                 check_inputs["key"] = msg_in.key
-                 if hasattr(target_axis, "data"):
-                     check_inputs["len"] = len(target_axis.data)
+                 if self._state.ax_vec.dtype.type is np.str_:
+                     sl_dat = f"{self._state.ax_vec[start]} - {self._state.ax_vec[stop]}"
                  else:
-                     check_inputs["gain"] = target_axis.gain
-                     check_inputs["offset"] = target_axis.offset
+                     ax_dat.append(np.mean(self._state.ax_vec[inds]))
+             else:
+                 sl_dat = target_axis.value(np.mean(inds))
+             ax_dat.append(sl_dat)

-                 # If the axis we are operating on has changed (e.g., "time" or "win" axis always changes),
-                 # or the key has changed, then recalculate slices.
+         self._state.slices = slices
+         self._state.out_axis = AxisArray.CoordinateAxis(
+             data=np.array(ax_dat),
+             dims=[axis],
+             unit=target_axis.unit,
+         )

-                 ax_idx = msg_in.get_axis_idx(axis)
+     def _process(self, message: AxisArray) -> AxisArray:
+         axis = self.settings.axis or message.dims[0]
+         ax_idx = message.get_axis_idx(axis)
+         agg_func = AGGREGATORS[self.settings.operation]

-                 if hasattr(target_axis, "data"):
-                     ax_vec = target_axis.data
-                 else:
-                     ax_vec = target_axis.value(np.arange(msg_in.data.shape[ax_idx]))
-
-                 slices = []
-                 ax_dat = []
-                 for start, stop in bands:
-                     inds = np.where(np.logical_and(ax_vec >= start, ax_vec <= stop))[0]
-                     slices.append(np.s_[inds[0] : inds[-1] + 1])
-                     if hasattr(target_axis, "data"):
-                         if ax_vec.dtype.type is np.str_:
-                             sl_dat = f"{ax_vec[start]} - {ax_vec[stop]}"
-                         else:
-                             sl_dat = ax_dat.append(np.mean(ax_vec[inds]))
-                     else:
-                         sl_dat = target_axis.value(np.mean(inds))
-                     ax_dat.append(sl_dat)
-
-                 out_axis = AxisArray.CoordinateAxis(
-                     data=np.array(ax_dat),
-                     dims=[axis],
-                     unit=target_axis.unit,
+         if self.settings.operation in [
+             AggregationFunction.TRAPEZOID,
+         ]:
+             # Special handling for methods that require x-coordinates.
+             out_data = [
+                 agg_func(
+                     slice_along_axis(message.data, sl, axis=ax_idx),
+                     x=self._state.ax_vec[sl],
+                     axis=ax_idx,
                  )
-
-             agg_func = AGGREGATORS[operation]
+                 for sl in self._state.slices
+             ]
+         else:
              out_data = [
-                 agg_func(slice_along_axis(msg_in.data, sl, axis=ax_idx), axis=ax_idx)
-                 for sl in slices
+                 agg_func(slice_along_axis(message.data, sl, axis=ax_idx), axis=ax_idx) for sl in self._state.slices
              ]

-             msg_out = replace(
-                 msg_in,
-                 data=np.stack(out_data, axis=ax_idx),
-                 axes={**msg_in.axes, axis: out_axis},
-             )
-             if operation in [AggregationFunction.ARGMIN, AggregationFunction.ARGMAX]:
-                 # Convert indices returned by argmin/argmax into the value along the axis.
-                 out_data = []
-                 for sl_ix, sl in enumerate(slices):
-                     offsets = np.take(msg_out.data, [sl_ix], axis=ax_idx)
-                     out_data.append(ax_vec[sl][offsets])
-                 msg_out.data = np.concatenate(out_data, axis=ax_idx)
+         msg_out = replace(
+             message,
+             data=np.stack(out_data, axis=ax_idx),
+             axes={**message.axes, axis: self._state.out_axis},
+         )

+         if self.settings.operation in [
+             AggregationFunction.ARGMIN,
+             AggregationFunction.ARGMAX,
+         ]:
+             out_data = []
+             for sl_ix, sl in enumerate(self._state.slices):
+                 offsets = np.take(msg_out.data, [sl_ix], axis=ax_idx)
+                 out_data.append(self._state.ax_vec[sl][offsets])
+             msg_out.data = np.concatenate(out_data, axis=ax_idx)

- class RangedAggregateSettings(ez.Settings):
+         return msg_out
+
+
+ class RangedAggregate(BaseTransformerUnit[RangedAggregateSettings, AxisArray, AxisArray, RangedAggregateTransformer]):
+     SETTINGS = RangedAggregateSettings
+
+
+ def ranged_aggregate(
+     axis: str | None = None,
+     bands: list[tuple[float, float]] | None = None,
+     operation: AggregationFunction = AggregationFunction.MEAN,
+ ) -> RangedAggregateTransformer:
      """
-     Settings for ``RangedAggregate``.
-     See :obj:`ranged_aggregate` for details.
+     Apply an aggregation operation over one or more bands.
+
+     Args:
+         axis: The name of the axis along which to apply the bands.
+         bands: [(band1_min, band1_max), (band2_min, band2_max), ...]
+             If not set then this acts as a passthrough node.
+         operation: :obj:`AggregationFunction` to apply to each band.
+
+     Returns:
+         :obj:`RangedAggregateTransformer`
      """
+     return RangedAggregateTransformer(RangedAggregateSettings(axis=axis, bands=bands, operation=operation))
+
+
+ class AggregateSettings(ez.Settings):
+     """Settings for :obj:`Aggregate`."""
+
+     axis: str
+     """The name of the axis to aggregate over. This axis will be removed from the output."""

-     axis: str | None = None
-     bands: list[tuple[float, float]] | None = None
      operation: AggregationFunction = AggregationFunction.MEAN
+     """:obj:`AggregationFunction` to apply."""


- class RangedAggregate(GenAxisArray):
+ class AggregateTransformer(BaseTransformer[AggregateSettings, AxisArray, AxisArray]):
      """
-     Unit for :obj:`ranged_aggregate`
+     Transformer that aggregates an entire axis using a specified operation.
+
+     Unlike :obj:`RangedAggregateTransformer` which aggregates over specific ranges/bands
+     and preserves the axis (with one value per band), this transformer aggregates the
+     entire axis and removes it from the output, reducing dimensionality by one.
      """

-     SETTINGS = RangedAggregateSettings
+     def _process(self, message: AxisArray) -> AxisArray:
+         xp = get_namespace(message.data)
+         axis_idx = message.get_axis_idx(self.settings.axis)
+         op = self.settings.operation

-     def construct_generator(self):
-         self.STATE.gen = ranged_aggregate(
-             axis=self.SETTINGS.axis,
-             bands=self.SETTINGS.bands,
-             operation=self.SETTINGS.operation,
+         if op == AggregationFunction.NONE:
+             raise ValueError("AggregationFunction.NONE is not supported for full-axis aggregation")
+
+         if op == AggregationFunction.TRAPEZOID:
+             # Trapezoid integration requires x-coordinates
+             target_axis = message.get_axis(self.settings.axis)
+             if hasattr(target_axis, "data"):
+                 x = target_axis.data
+             else:
+                 x = target_axis.value(np.arange(message.data.shape[axis_idx]))
+             agg_data = np.trapezoid(np.asarray(message.data), x=x, axis=axis_idx)
+         else:
+             # Try array-API compatible function first, fall back to numpy
+             func_name = op.value
+             if hasattr(xp, func_name):
+                 agg_data = getattr(xp, func_name)(message.data, axis=axis_idx)
+             else:
+                 agg_data = AGGREGATORS[op](message.data, axis=axis_idx)
+
+         new_dims = list(message.dims)
+         new_dims.pop(axis_idx)
+
+         new_axes = dict(message.axes)
+         new_axes.pop(self.settings.axis, None)
+
+         return replace(
+             message,
+             data=agg_data,
+             dims=new_dims,
+             axes=new_axes,
          )
+
+
+ class AggregateUnit(BaseTransformerUnit[AggregateSettings, AxisArray, AxisArray, AggregateTransformer]):
+     """Unit that aggregates an entire axis using a specified operation."""
+
+     SETTINGS = AggregateSettings
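
For downstream users, the practical change in this file is that ranged_aggregate() now returns a callable RangedAggregateTransformer instead of a primed generator, and a new full-axis AggregateTransformer/AggregateUnit pair is added. The sketch below is a minimal, hedged usage example: the input data, band edges, and the LinearAxis(unit=..., gain=..., offset=...) constructor arguments are illustrative assumptions inferred from the attribute names used in the diff, not taken from package documentation.

import numpy as np
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.sigproc.aggregate import AggregationFunction, ranged_aggregate

# Hypothetical spectral message: 50 frequency bins x 4 channels, 2 Hz per bin.
msg = AxisArray(
    data=np.random.rand(50, 4),
    dims=["freq", "ch"],
    axes={"freq": AxisArray.LinearAxis(unit="Hz", gain=2.0, offset=0.0)},
)

# ranged_aggregate() now returns a RangedAggregateTransformer; like the other
# processors in this release it is called directly rather than via .send().
proc = ranged_aggregate(
    axis="freq",
    bands=[(8.0, 12.0), (18.0, 30.0)],
    operation=AggregationFunction.MEAN,
)
out = proc(msg)  # one value per band along "freq": data shape (2, 4)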
ezmsg/sigproc/bandpower.py CHANGED
@@ -1,75 +1,80 @@
  from dataclasses import field
- import typing

- import numpy as np
  import ezmsg.core as ez
+ from ezmsg.baseproc import (
+     BaseProcessor,
+     BaseStatefulProcessor,
+     BaseTransformerUnit,
+     CompositeProcessor,
+ )
  from ezmsg.util.messages.axisarray import AxisArray
- from ezmsg.util.generator import consumer, compose

- from .spectrogram import spectrogram, SpectrogramSettings
- from .aggregate import ranged_aggregate, AggregationFunction
- from .base import GenAxisArray
+ from .aggregate import (
+     AggregationFunction,
+     RangedAggregateSettings,
+     RangedAggregateTransformer,
+ )
+ from .spectrogram import SpectrogramSettings, SpectrogramTransformer


- @consumer
- def bandpower(
-     spectrogram_settings: SpectrogramSettings,
-     bands: list[tuple[float, float]] | None = [
-         (17, 30),
-         (70, 170),
-     ],
- ) -> typing.Generator[AxisArray, AxisArray, None]:
+ class BandPowerSettings(ez.Settings):
      """
-     Calculate the average spectral power in each band.
-
-     Args:
-         spectrogram_settings: Settings for spectrogram calculation.
-         bands: (min, max) tuples of band limits in Hz.
-
-     Returns:
-         A primed generator object ready to yield an :obj:`AxisArray` for each .send(axis_array)
-         with the data payload being the average spectral power in each band of the input data.
+     Settings for ``BandPower``.
      """
-     msg_out = AxisArray(np.array([]), dims=[""])
-
-     f_spec = spectrogram(
-         window_dur=spectrogram_settings.window_dur,
-         window_shift=spectrogram_settings.window_shift,
-         window=spectrogram_settings.window,
-         transform=spectrogram_settings.transform,
-         output=spectrogram_settings.output,
-     )
-     f_agg = ranged_aggregate(
-         axis="freq", bands=bands, operation=AggregationFunction.MEAN
-     )
-     pipeline = compose(f_spec, f_agg)
-
-     while True:
-         msg_in: AxisArray = yield msg_out
-         msg_out = pipeline(msg_in)

+     spectrogram_settings: SpectrogramSettings = field(default_factory=SpectrogramSettings)
+     """
+     Settings for spectrogram calculation.
+     """

- class BandPowerSettings(ez.Settings):
+     bands: list[tuple[float, float]] | None = field(default_factory=lambda: [(17, 30), (70, 170)])
      """
-     Settings for ``BandPower``.
-     See :obj:`bandpower` for details.
+     (min, max) tuples of band limits in Hz.
      """

-     spectrogram_settings: SpectrogramSettings = field(
-         default_factory=SpectrogramSettings
-     )
-     bands: list[tuple[float, float]] | None = field(
-         default_factory=lambda: [(17, 30), (70, 170)]
-     )
+     aggregation: AggregationFunction = AggregationFunction.MEAN
+     """:obj:`AggregationFunction` to apply to each band."""
+

+ class BandPowerTransformer(CompositeProcessor[BandPowerSettings, AxisArray, AxisArray]):
+     @staticmethod
+     def _initialize_processors(
+         settings: BandPowerSettings,
+     ) -> dict[str, BaseProcessor | BaseStatefulProcessor]:
+         return {
+             "spectrogram": SpectrogramTransformer(settings=settings.spectrogram_settings),
+             "aggregate": RangedAggregateTransformer(
+                 settings=RangedAggregateSettings(
+                     axis="freq",
+                     bands=settings.bands,
+                     operation=settings.aggregation,
+                 )
+             ),
+         }

- class BandPower(GenAxisArray):
-     """:obj:`Unit` for :obj:`bandpower`."""

+ class BandPower(BaseTransformerUnit[BandPowerSettings, AxisArray, AxisArray, BandPowerTransformer]):
      SETTINGS = BandPowerSettings

-     def construct_generator(self):
-         self.STATE.gen = bandpower(
-             spectrogram_settings=self.SETTINGS.spectrogram_settings,
-             bands=self.SETTINGS.bands,
+
+ def bandpower(
+     spectrogram_settings: SpectrogramSettings,
+     bands: list[tuple[float, float]] | None = [
+         (17, 30),
+         (70, 170),
+     ],
+     aggregation: AggregationFunction = AggregationFunction.MEAN,
+ ) -> BandPowerTransformer:
+     """
+     Calculate the average spectral power in each band.
+
+     Returns:
+         :obj:`BandPowerTransformer`
+     """
+     return BandPowerTransformer(
+         settings=BandPowerSettings(
+             spectrogram_settings=spectrogram_settings,
+             bands=bands,
+             aggregation=aggregation,
          )
+     )
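
bandpower() keeps its old signature (plus the new aggregation argument) but now returns a BandPowerTransformer, a CompositeProcessor that chains SpectrogramTransformer into RangedAggregateTransformer over the "freq" axis. A hedged construction sketch follows, assuming default SpectrogramSettings are adequate (the settings class is default-constructible per the default_factory above); the actual call on an input message is shown commented because the input construction is omitted here.

from ezmsg.sigproc.aggregate import AggregationFunction
from ezmsg.sigproc.bandpower import bandpower
from ezmsg.sigproc.spectrogram import SpectrogramSettings

# Build the composite transformer: spectrogram -> per-band aggregation.
proc = bandpower(
    spectrogram_settings=SpectrogramSettings(),  # defaults assumed adequate here
    bands=[(17, 30), (70, 170)],
    aggregation=AggregationFunction.MEAN,
)

# Time-domain AxisArray in, per-band power AxisArray out (sketch only):
# out_msg = proc(time_domain_msg)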
ezmsg/sigproc/base.py CHANGED
@@ -1,39 +1,149 @@
- import math
- import traceback
- import typing
+ """
+ Backwards-compatible re-exports from ezmsg.baseproc.

- import ezmsg.core as ez
- from ezmsg.util.messages.axisarray import AxisArray
- from ezmsg.util.generator import GenState
+ This module re-exports all symbols from ezmsg.baseproc to maintain backwards
+ compatibility for code that imports from ezmsg.sigproc.base.

+ New code should import directly from ezmsg.baseproc instead.
+ """

- class GenAxisArray(ez.Unit):
-     STATE = GenState
+ import warnings

-     INPUT_SIGNAL = ez.InputStream(AxisArray)
-     OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
-     INPUT_SETTINGS = ez.InputStream(ez.Settings)
+ warnings.warn(
+     "Importing from 'ezmsg.sigproc.base' is deprecated. Please import from 'ezmsg.baseproc' instead.",
+     DeprecationWarning,
+     stacklevel=2,
+ )

-     async def initialize(self) -> None:
-         self.construct_generator()
+ # Re-export everything from ezmsg.baseproc for backwards compatibility
+ from ezmsg.baseproc import (  # noqa: E402
+     # Protocols
+     AdaptiveTransformer,
+     # Type variables
+     AdaptiveTransformerType,
+     # Stateful classes
+     BaseAdaptiveTransformer,
+     # Unit classes
+     BaseAdaptiveTransformerUnit,
+     BaseAsyncTransformer,
+     # Base processor classes
+     BaseConsumer,
+     BaseConsumerUnit,
+     BaseProcessor,
+     BaseProcessorUnit,
+     BaseProducer,
+     BaseProducerUnit,
+     BaseStatefulConsumer,
+     BaseStatefulProcessor,
+     BaseStatefulProducer,
+     BaseStatefulTransformer,
+     BaseTransformer,
+     BaseTransformerUnit,
+     # Composite classes
+     CompositeProcessor,
+     CompositeProducer,
+     CompositeStateful,
+     Consumer,
+     ConsumerType,
+     GenAxisArray,
+     MessageInType,
+     MessageOutType,
+     Processor,
+     Producer,
+     ProducerType,
+     # Message types
+     SampleMessage,
+     SettingsType,
+     Stateful,
+     StatefulConsumer,
+     StatefulProcessor,
+     StatefulProducer,
+     StatefulTransformer,
+     StateType,
+     Transformer,
+     TransformerType,
+     # Type resolution helpers
+     _get_base_processor_message_in_type,
+     _get_base_processor_message_out_type,
+     _get_base_processor_settings_type,
+     _get_base_processor_state_type,
+     _get_processor_message_type,
+     # Type utilities
+     check_message_type_compatibility,
+     get_base_adaptive_transformer_type,
+     get_base_consumer_type,
+     get_base_producer_type,
+     get_base_transformer_type,
+     is_sample_message,
+     # Decorators
+     processor_state,
+     # Profiling
+     profile_subpub,
+     resolve_typevar,
+ )

-     # Method to be implemented by subclasses to construct the specific generator
-     def construct_generator(self):
-         raise NotImplementedError
-
-     @ez.subscriber(INPUT_SETTINGS)
-     async def on_settings(self, msg: ez.Settings) -> None:
-         self.apply_settings(msg)
-         self.construct_generator()
-
-     @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
-     @ez.publisher(OUTPUT_SIGNAL)
-     async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
-         try:
-             ret = self.STATE.gen.send(message)
-             if math.prod(ret.data.shape) > 0:
-                 yield self.OUTPUT_SIGNAL, ret
-         except (StopIteration, GeneratorExit):
-             ez.logger.debug(f"Generator closed in {self.address}")
-         except Exception:
-             ez.logger.info(traceback.format_exc())
+ __all__ = [
+     # Protocols
+     "Processor",
+     "Producer",
+     "Consumer",
+     "Transformer",
+     "StatefulProcessor",
+     "StatefulProducer",
+     "StatefulConsumer",
+     "StatefulTransformer",
+     "AdaptiveTransformer",
+     # Type variables
+     "MessageInType",
+     "MessageOutType",
+     "SettingsType",
+     "StateType",
+     "ProducerType",
+     "ConsumerType",
+     "TransformerType",
+     "AdaptiveTransformerType",
+     # Decorators
+     "processor_state",
+     # Base processor classes
+     "BaseProcessor",
+     "BaseProducer",
+     "BaseConsumer",
+     "BaseTransformer",
+     # Stateful classes
+     "Stateful",
+     "BaseStatefulProcessor",
+     "BaseStatefulProducer",
+     "BaseStatefulConsumer",
+     "BaseStatefulTransformer",
+     "BaseAdaptiveTransformer",
+     "BaseAsyncTransformer",
+     # Composite classes
+     "CompositeStateful",
+     "CompositeProcessor",
+     "CompositeProducer",
+     # Unit classes
+     "BaseProducerUnit",
+     "BaseProcessorUnit",
+     "BaseConsumerUnit",
+     "BaseTransformerUnit",
+     "BaseAdaptiveTransformerUnit",
+     "GenAxisArray",
+     # Type resolution helpers
+     "get_base_producer_type",
+     "get_base_consumer_type",
+     "get_base_transformer_type",
+     "get_base_adaptive_transformer_type",
+     "_get_base_processor_settings_type",
+     "_get_base_processor_message_in_type",
+     "_get_base_processor_message_out_type",
+     "_get_base_processor_state_type",
+     "_get_processor_message_type",
+     # Message types
+     "SampleMessage",
+     "is_sample_message",
+     # Profiling
+     "profile_subpub",
+     # Type utilities
+     "check_message_type_compatibility",
+     "resolve_typevar",
+ ]
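
For code that imported the processing base classes from ezmsg.sigproc.base, the only change needed is the import path; both spellings below resolve to the same objects per the re-export list above. A minimal before/after sketch:

# Old path: still works in 2.10.0, but the module now emits a DeprecationWarning
# the first time it is imported.
from ezmsg.sigproc.base import BaseStatefulTransformer, processor_state  # noqa: F401

# New path: import the same names directly from ezmsg.baseproc.
from ezmsg.baseproc import BaseStatefulTransformer, processor_state  # noqa: F401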