ezmsg-sigproc: ezmsg_sigproc-1.8.2-py3-none-any.whl → ezmsg_sigproc-2.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. ezmsg/sigproc/__version__.py +2 -2
  2. ezmsg/sigproc/activation.py +36 -39
  3. ezmsg/sigproc/adaptive_lattice_notch.py +231 -0
  4. ezmsg/sigproc/affinetransform.py +169 -163
  5. ezmsg/sigproc/aggregate.py +133 -101
  6. ezmsg/sigproc/bandpower.py +64 -52
  7. ezmsg/sigproc/base.py +1242 -0
  8. ezmsg/sigproc/butterworthfilter.py +37 -33
  9. ezmsg/sigproc/cheby.py +29 -17
  10. ezmsg/sigproc/combfilter.py +163 -0
  11. ezmsg/sigproc/decimate.py +19 -10
  12. ezmsg/sigproc/detrend.py +29 -0
  13. ezmsg/sigproc/diff.py +81 -0
  14. ezmsg/sigproc/downsample.py +78 -84
  15. ezmsg/sigproc/ewma.py +197 -0
  16. ezmsg/sigproc/extract_axis.py +41 -0
  17. ezmsg/sigproc/filter.py +257 -141
  18. ezmsg/sigproc/filterbank.py +247 -199
  19. ezmsg/sigproc/math/abs.py +17 -22
  20. ezmsg/sigproc/math/clip.py +24 -24
  21. ezmsg/sigproc/math/difference.py +34 -30
  22. ezmsg/sigproc/math/invert.py +13 -25
  23. ezmsg/sigproc/math/log.py +28 -33
  24. ezmsg/sigproc/math/scale.py +18 -26
  25. ezmsg/sigproc/quantize.py +71 -0
  26. ezmsg/sigproc/resample.py +298 -0
  27. ezmsg/sigproc/sampler.py +241 -259
  28. ezmsg/sigproc/scaler.py +55 -218
  29. ezmsg/sigproc/signalinjector.py +52 -43
  30. ezmsg/sigproc/slicer.py +81 -89
  31. ezmsg/sigproc/spectrogram.py +77 -75
  32. ezmsg/sigproc/spectrum.py +203 -168
  33. ezmsg/sigproc/synth.py +546 -393
  34. ezmsg/sigproc/transpose.py +131 -0
  35. ezmsg/sigproc/util/asio.py +156 -0
  36. ezmsg/sigproc/util/message.py +31 -0
  37. ezmsg/sigproc/util/profile.py +55 -12
  38. ezmsg/sigproc/util/typeresolution.py +83 -0
  39. ezmsg/sigproc/wavelets.py +154 -153
  40. ezmsg/sigproc/window.py +269 -211
  41. {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.1.0.dist-info}/METADATA +2 -1
  42. ezmsg_sigproc-2.1.0.dist-info/RECORD +51 -0
  43. ezmsg_sigproc-1.8.2.dist-info/RECORD +0 -39
  44. {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.1.0.dist-info}/WHEEL +0 -0
  45. {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.1.0.dist-info}/licenses/LICENSE.txt +0 -0
ezmsg/sigproc/aggregate.py
@@ -3,7 +3,6 @@ import typing
3
3
  import numpy as np
4
4
  import numpy.typing as npt
5
5
  import ezmsg.core as ez
6
- from ezmsg.util.generator import consumer
7
6
  from ezmsg.util.messages.axisarray import (
8
7
  AxisArray,
9
8
  slice_along_axis,
@@ -12,7 +11,11 @@ from ezmsg.util.messages.axisarray import (
12
11
  )
13
12
 
14
13
  from .spectral import OptionsEnum
15
- from .base import GenAxisArray
14
+ from .base import (
15
+ BaseStatefulTransformer,
16
+ BaseTransformerUnit,
17
+ processor_state,
18
+ )
16
19
 
17
20
 
18
21
  class AggregationFunction(OptionsEnum):
@@ -33,6 +36,7 @@ class AggregationFunction(OptionsEnum):
33
36
  NANSUM = "nansum"
34
37
  ARGMIN = "argmin"
35
38
  ARGMAX = "argmax"
39
+ TRAPEZOID = "trapezoid"
36
40
 
37
41
 
38
42
  AGGREGATORS = {
@@ -51,133 +55,161 @@ AGGREGATORS = {
51
55
  AggregationFunction.NANSUM: np.nansum,
52
56
  AggregationFunction.ARGMIN: np.argmin,
53
57
  AggregationFunction.ARGMAX: np.argmax,
58
+ # Note: Some methods require x-coordinates and
59
+ # are handled specially in `_process`.
60
+ AggregationFunction.TRAPEZOID: np.trapezoid,
54
61
  }
55
62
 
56
63
 
57
- @consumer
58
- def ranged_aggregate(
59
- axis: str | None = None,
60
- bands: list[tuple[float, float]] | None = None,
61
- operation: AggregationFunction = AggregationFunction.MEAN,
62
- ):
64
+ class RangedAggregateSettings(ez.Settings):
65
+ """
66
+ Settings for ``RangedAggregate``.
63
67
  """
64
- Apply an aggregation operation over one or more bands.
65
68
 
66
- Args:
67
- axis: The name of the axis along which to apply the bands.
68
- bands: [(band1_min, band1_max), (band2_min, band2_max), ...]
69
- If not set then this acts as a passthrough node.
70
- operation: :obj:`AggregationFunction` to apply to each band.
69
+ axis: str | None = None
70
+ """The name of the axis along which to apply the bands."""
71
71
 
72
- Returns:
73
- A primed generator object ready to yield an :obj:`AxisArray` for each .send(axis_array)
72
+ bands: list[tuple[float, float]] | None = None
73
+ """
74
+ [(band1_min, band1_max), (band2_min, band2_max), ...]
75
+ If not set then this acts as a passthrough node.
74
76
  """
75
- msg_out = AxisArray(np.array([]), dims=[""])
76
77
 
77
- # State variables
78
+ operation: AggregationFunction = AggregationFunction.MEAN
79
+ """:obj:`AggregationFunction` to apply to each band."""
80
+
81
+
82
+ @processor_state
83
+ class RangedAggregateState:
78
84
  slices: list[tuple[typing.Any, ...]] | None = None
79
85
  out_axis: AxisBase | None = None
80
86
  ax_vec: npt.NDArray | None = None
81
87
 
82
- # Reset if any of these changes. Key not checked because continuity between chunks not required.
83
- check_inputs = {"gain": None, "offset": None, "len": None, "key": None}
84
88
 
85
- while True:
86
- msg_in: AxisArray = yield msg_out
87
- if bands is None:
88
- msg_out = msg_in
89
+ class RangedAggregateTransformer(
90
+ BaseStatefulTransformer[
91
+ RangedAggregateSettings, AxisArray, AxisArray, RangedAggregateState
92
+ ]
93
+ ):
94
+ def __call__(self, message: AxisArray) -> AxisArray:
95
+ # Override for shortcut passthrough mode.
96
+ if self.settings.bands is None:
97
+ return message
98
+ return super().__call__(message)
99
+
100
+ def _hash_message(self, message: AxisArray) -> int:
101
+ axis = self.settings.axis or message.dims[0]
102
+ target_axis = message.get_axis(axis)
103
+
104
+ hash_components = (message.key,)
105
+ if hasattr(target_axis, "data"):
106
+ hash_components += (len(target_axis.data),)
107
+ elif isinstance(target_axis, AxisArray.LinearAxis):
108
+ hash_components += (target_axis.gain, target_axis.offset)
109
+ return hash(hash_components)
110
+
111
+ def _reset_state(self, message: AxisArray) -> None:
112
+ axis = self.settings.axis or message.dims[0]
113
+ target_axis = message.get_axis(axis)
114
+ ax_idx = message.get_axis_idx(axis)
115
+
116
+ if hasattr(target_axis, "data"):
117
+ self._state.ax_vec = target_axis.data
89
118
  else:
90
- axis = axis or msg_in.dims[0]
91
- target_axis = msg_in.get_axis(axis)
119
+ self._state.ax_vec = target_axis.value(
120
+ np.arange(message.data.shape[ax_idx])
121
+ )
92
122
 
93
- # Check if we need to reset state
94
- b_reset = msg_in.key != check_inputs["key"]
123
+ ax_dat = []
124
+ slices = []
125
+ for start, stop in self.settings.bands:
126
+ inds = np.where(
127
+ np.logical_and(self._state.ax_vec >= start, self._state.ax_vec <= stop)
128
+ )[0]
129
+ slices.append(np.s_[inds[0] : inds[-1] + 1])
95
130
  if hasattr(target_axis, "data"):
96
- b_reset = b_reset or len(target_axis.data) != check_inputs["len"]
97
- elif isinstance(target_axis, AxisArray.LinearAxis):
98
- b_reset = b_reset or target_axis.gain != check_inputs["gain"]
99
- b_reset = b_reset or target_axis.offset != check_inputs["offset"]
100
-
101
- if b_reset:
102
- # Update check variables
103
- check_inputs["key"] = msg_in.key
104
- if hasattr(target_axis, "data"):
105
- check_inputs["len"] = len(target_axis.data)
131
+ if self._state.ax_vec.dtype.type is np.str_:
132
+ sl_dat = f"{self._state.ax_vec[start]} - {self._state.ax_vec[stop]}"
106
133
  else:
107
- check_inputs["gain"] = target_axis.gain
108
- check_inputs["offset"] = target_axis.offset
109
-
110
- # If the axis we are operating on has changed (e.g., "time" or "win" axis always changes),
111
- # or the key has changed, then recalculate slices.
134
+ ax_dat.append(np.mean(self._state.ax_vec[inds]))
135
+ else:
136
+ sl_dat = target_axis.value(np.mean(inds))
137
+ ax_dat.append(sl_dat)
138
+
139
+ self._state.slices = slices
140
+ self._state.out_axis = AxisArray.CoordinateAxis(
141
+ data=np.array(ax_dat),
142
+ dims=[axis],
143
+ unit=target_axis.unit,
144
+ )
112
145
 
113
- ax_idx = msg_in.get_axis_idx(axis)
146
+ def _process(self, message: AxisArray) -> AxisArray:
147
+ axis = self.settings.axis or message.dims[0]
148
+ ax_idx = message.get_axis_idx(axis)
149
+ agg_func = AGGREGATORS[self.settings.operation]
114
150
 
115
- if hasattr(target_axis, "data"):
116
- ax_vec = target_axis.data
117
- else:
118
- ax_vec = target_axis.value(np.arange(msg_in.data.shape[ax_idx]))
119
-
120
- slices = []
121
- ax_dat = []
122
- for start, stop in bands:
123
- inds = np.where(np.logical_and(ax_vec >= start, ax_vec <= stop))[0]
124
- slices.append(np.s_[inds[0] : inds[-1] + 1])
125
- if hasattr(target_axis, "data"):
126
- if ax_vec.dtype.type is np.str_:
127
- sl_dat = f"{ax_vec[start]} - {ax_vec[stop]}"
128
- else:
129
- sl_dat = ax_dat.append(np.mean(ax_vec[inds]))
130
- else:
131
- sl_dat = target_axis.value(np.mean(inds))
132
- ax_dat.append(sl_dat)
133
-
134
- out_axis = AxisArray.CoordinateAxis(
135
- data=np.array(ax_dat),
136
- dims=[axis],
137
- unit=target_axis.unit,
151
+ if self.settings.operation in [
152
+ AggregationFunction.TRAPEZOID,
153
+ ]:
154
+ # Special handling for methods that require x-coordinates.
155
+ out_data = [
156
+ agg_func(
157
+ slice_along_axis(message.data, sl, axis=ax_idx),
158
+ x=self._state.ax_vec[sl],
159
+ axis=ax_idx,
138
160
  )
139
-
140
- agg_func = AGGREGATORS[operation]
161
+ for sl in self._state.slices
162
+ ]
163
+ else:
141
164
  out_data = [
142
- agg_func(slice_along_axis(msg_in.data, sl, axis=ax_idx), axis=ax_idx)
143
- for sl in slices
165
+ agg_func(slice_along_axis(message.data, sl, axis=ax_idx), axis=ax_idx)
166
+ for sl in self._state.slices
144
167
  ]
145
168
 
146
- msg_out = replace(
147
- msg_in,
148
- data=np.stack(out_data, axis=ax_idx),
149
- axes={**msg_in.axes, axis: out_axis},
150
- )
151
- if operation in [AggregationFunction.ARGMIN, AggregationFunction.ARGMAX]:
152
- # Convert indices returned by argmin/argmax into the value along the axis.
153
- out_data = []
154
- for sl_ix, sl in enumerate(slices):
155
- offsets = np.take(msg_out.data, [sl_ix], axis=ax_idx)
156
- out_data.append(ax_vec[sl][offsets])
157
- msg_out.data = np.concatenate(out_data, axis=ax_idx)
169
+ msg_out = replace(
170
+ message,
171
+ data=np.stack(out_data, axis=ax_idx),
172
+ axes={**message.axes, axis: self._state.out_axis},
173
+ )
158
174
 
175
+ if self.settings.operation in [
176
+ AggregationFunction.ARGMIN,
177
+ AggregationFunction.ARGMAX,
178
+ ]:
179
+ out_data = []
180
+ for sl_ix, sl in enumerate(self._state.slices):
181
+ offsets = np.take(msg_out.data, [sl_ix], axis=ax_idx)
182
+ out_data.append(self._state.ax_vec[sl][offsets])
183
+ msg_out.data = np.concatenate(out_data, axis=ax_idx)
159
184
 
160
- class RangedAggregateSettings(ez.Settings):
161
- """
162
- Settings for ``RangedAggregate``.
163
- See :obj:`ranged_aggregate` for details.
164
- """
185
+ return msg_out
165
186
 
166
- axis: str | None = None
167
- bands: list[tuple[float, float]] | None = None
168
- operation: AggregationFunction = AggregationFunction.MEAN
169
187
 
188
+ class RangedAggregate(
189
+ BaseTransformerUnit[
190
+ RangedAggregateSettings, AxisArray, AxisArray, RangedAggregateTransformer
191
+ ]
192
+ ):
193
+ SETTINGS = RangedAggregateSettings
170
194
 
171
- class RangedAggregate(GenAxisArray):
172
- """
173
- Unit for :obj:`ranged_aggregate`
195
+
196
+ def ranged_aggregate(
197
+ axis: str | None = None,
198
+ bands: list[tuple[float, float]] | None = None,
199
+ operation: AggregationFunction = AggregationFunction.MEAN,
200
+ ) -> RangedAggregateTransformer:
174
201
  """
202
+ Apply an aggregation operation over one or more bands.
175
203
 
176
- SETTINGS = RangedAggregateSettings
204
+ Args:
205
+ axis: The name of the axis along which to apply the bands.
206
+ bands: [(band1_min, band1_max), (band2_min, band2_max), ...]
207
+ If not set then this acts as a passthrough node.
208
+ operation: :obj:`AggregationFunction` to apply to each band.
177
209
 
178
- def construct_generator(self):
179
- self.STATE.gen = ranged_aggregate(
180
- axis=self.SETTINGS.axis,
181
- bands=self.SETTINGS.bands,
182
- operation=self.SETTINGS.operation,
183
- )
210
+ Returns:
211
+ :obj:`RangedAggregateTransformer`
212
+ """
213
+ return RangedAggregateTransformer(
214
+ RangedAggregateSettings(axis=axis, bands=bands, operation=operation)
215
+ )
ezmsg/sigproc/bandpower.py
@@ -1,76 +1,88 @@
1
1
  from dataclasses import field
2
- import typing
3
2
 
4
- import numpy as np
5
3
  import ezmsg.core as ez
6
4
  from ezmsg.util.messages.axisarray import AxisArray
7
- from ezmsg.util.generator import consumer, compose
8
5
 
9
- from .spectrogram import spectrogram, SpectrogramSettings
10
- from .aggregate import ranged_aggregate, AggregationFunction
11
- from .base import GenAxisArray
12
-
13
-
14
- @consumer
15
- def bandpower(
16
- spectrogram_settings: SpectrogramSettings,
17
- bands: list[tuple[float, float]] | None = [
18
- (17, 30),
19
- (70, 170),
20
- ],
21
- ) -> typing.Generator[AxisArray, AxisArray, None]:
22
- """
23
- Calculate the average spectral power in each band.
24
-
25
- Args:
26
- spectrogram_settings: Settings for spectrogram calculation.
27
- bands: (min, max) tuples of band limits in Hz.
28
-
29
- Returns:
30
- A primed generator object ready to yield an :obj:`AxisArray` for each .send(axis_array)
31
- with the data payload being the average spectral power in each band of the input data.
32
- """
33
- msg_out = AxisArray(np.array([]), dims=[""])
34
-
35
- f_spec = spectrogram(
36
- window_dur=spectrogram_settings.window_dur,
37
- window_shift=spectrogram_settings.window_shift,
38
- window_anchor=spectrogram_settings.window_anchor,
39
- window=spectrogram_settings.window,
40
- transform=spectrogram_settings.transform,
41
- output=spectrogram_settings.output,
42
- )
43
- f_agg = ranged_aggregate(
44
- axis="freq", bands=bands, operation=AggregationFunction.MEAN
45
- )
46
- pipeline = compose(f_spec, f_agg)
47
-
48
- while True:
49
- msg_in: AxisArray = yield msg_out
50
- msg_out = pipeline(msg_in)
6
+ from .spectrogram import SpectrogramSettings, SpectrogramTransformer
7
+ from .aggregate import (
8
+ AggregationFunction,
9
+ RangedAggregateTransformer,
10
+ RangedAggregateSettings,
11
+ )
12
+ from .base import (
13
+ BaseProcessor,
14
+ CompositeProcessor,
15
+ BaseStatefulProcessor,
16
+ BaseTransformerUnit,
17
+ )
51
18
 
52
19
 
53
20
  class BandPowerSettings(ez.Settings):
54
21
  """
55
22
  Settings for ``BandPower``.
56
- See :obj:`bandpower` for details.
57
23
  """
58
24
 
59
25
  spectrogram_settings: SpectrogramSettings = field(
60
26
  default_factory=SpectrogramSettings
61
27
  )
28
+ """
29
+ Settings for spectrogram calculation.
30
+ """
31
+
62
32
  bands: list[tuple[float, float]] | None = field(
63
33
  default_factory=lambda: [(17, 30), (70, 170)]
64
34
  )
35
+ """
36
+ (min, max) tuples of band limits in Hz.
37
+ """
38
+
39
+ aggregation: AggregationFunction = AggregationFunction.MEAN
40
+ """:obj:`AggregationFunction` to apply to each band."""
65
41
 
66
42
 
67
- class BandPower(GenAxisArray):
68
- """:obj:`Unit` for :obj:`bandpower`."""
43
+ class BandPowerTransformer(CompositeProcessor[BandPowerSettings, AxisArray, AxisArray]):
44
+ @staticmethod
45
+ def _initialize_processors(
46
+ settings: BandPowerSettings,
47
+ ) -> dict[str, BaseProcessor | BaseStatefulProcessor]:
48
+ return {
49
+ "spectrogram": SpectrogramTransformer(
50
+ settings=settings.spectrogram_settings
51
+ ),
52
+ "aggregate": RangedAggregateTransformer(
53
+ settings=RangedAggregateSettings(
54
+ axis="freq",
55
+ bands=settings.bands,
56
+ operation=settings.aggregation,
57
+ )
58
+ ),
59
+ }
69
60
 
61
+
62
+ class BandPower(
63
+ BaseTransformerUnit[BandPowerSettings, AxisArray, AxisArray, BandPowerTransformer]
64
+ ):
70
65
  SETTINGS = BandPowerSettings
71
66
 
72
- def construct_generator(self):
73
- self.STATE.gen = bandpower(
74
- spectrogram_settings=self.SETTINGS.spectrogram_settings,
75
- bands=self.SETTINGS.bands,
67
+
68
+ def bandpower(
69
+ spectrogram_settings: SpectrogramSettings,
70
+ bands: list[tuple[float, float]] | None = [
71
+ (17, 30),
72
+ (70, 170),
73
+ ],
74
+ aggregation: AggregationFunction = AggregationFunction.MEAN,
75
+ ) -> BandPowerTransformer:
76
+ """
77
+ Calculate the average spectral power in each band.
78
+
79
+ Returns:
80
+ :obj:`BandPowerTransformer`
81
+ """
82
+ return BandPowerTransformer(
83
+ settings=BandPowerSettings(
84
+ spectrogram_settings=spectrogram_settings,
85
+ bands=bands,
86
+ aggregation=aggregation,
76
87
  )
88
+ )