ezmsg-sigproc 1.7.0__py3-none-any.whl → 2.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +22 -4
- ezmsg/sigproc/activation.py +31 -40
- ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
- ezmsg/sigproc/affinetransform.py +171 -169
- ezmsg/sigproc/aggregate.py +190 -97
- ezmsg/sigproc/bandpower.py +60 -55
- ezmsg/sigproc/base.py +143 -33
- ezmsg/sigproc/butterworthfilter.py +34 -38
- ezmsg/sigproc/butterworthzerophase.py +305 -0
- ezmsg/sigproc/cheby.py +23 -17
- ezmsg/sigproc/combfilter.py +160 -0
- ezmsg/sigproc/coordinatespaces.py +159 -0
- ezmsg/sigproc/decimate.py +15 -10
- ezmsg/sigproc/denormalize.py +78 -0
- ezmsg/sigproc/detrend.py +28 -0
- ezmsg/sigproc/diff.py +82 -0
- ezmsg/sigproc/downsample.py +72 -81
- ezmsg/sigproc/ewma.py +217 -0
- ezmsg/sigproc/ewmfilter.py +1 -1
- ezmsg/sigproc/extract_axis.py +39 -0
- ezmsg/sigproc/fbcca.py +307 -0
- ezmsg/sigproc/filter.py +254 -148
- ezmsg/sigproc/filterbank.py +226 -214
- ezmsg/sigproc/filterbankdesign.py +129 -0
- ezmsg/sigproc/fir_hilbert.py +336 -0
- ezmsg/sigproc/fir_pmc.py +209 -0
- ezmsg/sigproc/firfilter.py +117 -0
- ezmsg/sigproc/gaussiansmoothing.py +89 -0
- ezmsg/sigproc/kaiser.py +106 -0
- ezmsg/sigproc/linear.py +120 -0
- ezmsg/sigproc/math/abs.py +23 -22
- ezmsg/sigproc/math/add.py +120 -0
- ezmsg/sigproc/math/clip.py +33 -25
- ezmsg/sigproc/math/difference.py +117 -43
- ezmsg/sigproc/math/invert.py +18 -25
- ezmsg/sigproc/math/log.py +38 -33
- ezmsg/sigproc/math/scale.py +24 -25
- ezmsg/sigproc/messages.py +1 -2
- ezmsg/sigproc/quantize.py +68 -0
- ezmsg/sigproc/resample.py +278 -0
- ezmsg/sigproc/rollingscaler.py +232 -0
- ezmsg/sigproc/sampler.py +209 -254
- ezmsg/sigproc/scaler.py +93 -218
- ezmsg/sigproc/signalinjector.py +44 -43
- ezmsg/sigproc/slicer.py +74 -102
- ezmsg/sigproc/spectral.py +3 -3
- ezmsg/sigproc/spectrogram.py +70 -70
- ezmsg/sigproc/spectrum.py +187 -173
- ezmsg/sigproc/transpose.py +134 -0
- ezmsg/sigproc/util/__init__.py +0 -0
- ezmsg/sigproc/util/asio.py +25 -0
- ezmsg/sigproc/util/axisarray_buffer.py +365 -0
- ezmsg/sigproc/util/buffer.py +449 -0
- ezmsg/sigproc/util/message.py +17 -0
- ezmsg/sigproc/util/profile.py +23 -0
- ezmsg/sigproc/util/sparse.py +115 -0
- ezmsg/sigproc/util/typeresolution.py +17 -0
- ezmsg/sigproc/wavelets.py +147 -154
- ezmsg/sigproc/window.py +248 -210
- ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
- {ezmsg_sigproc-1.7.0.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -1
- ezmsg/sigproc/synth.py +0 -621
- ezmsg_sigproc-1.7.0.dist-info/METADATA +0 -58
- ezmsg_sigproc-1.7.0.dist-info/RECORD +0 -36
- /ezmsg_sigproc-1.7.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/aggregate.py
CHANGED
|
@@ -1,18 +1,32 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Aggregation operations over arrays.
|
|
3
|
+
|
|
4
|
+
.. note::
|
|
5
|
+
:obj:`AggregateTransformer` supports the :doc:`Array API standard </guides/explanations/array_api>`,
|
|
6
|
+
enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
|
|
7
|
+
:obj:`RangedAggregateTransformer` currently requires NumPy arrays.
|
|
8
|
+
"""
|
|
9
|
+
|
|
1
10
|
import typing
|
|
2
11
|
|
|
12
|
+
import ezmsg.core as ez
|
|
3
13
|
import numpy as np
|
|
4
14
|
import numpy.typing as npt
|
|
5
|
-
|
|
6
|
-
from ezmsg.
|
|
15
|
+
from array_api_compat import get_namespace
|
|
16
|
+
from ezmsg.baseproc import (
|
|
17
|
+
BaseStatefulTransformer,
|
|
18
|
+
BaseTransformer,
|
|
19
|
+
BaseTransformerUnit,
|
|
20
|
+
processor_state,
|
|
21
|
+
)
|
|
7
22
|
from ezmsg.util.messages.axisarray import (
|
|
8
23
|
AxisArray,
|
|
9
|
-
slice_along_axis,
|
|
10
24
|
AxisBase,
|
|
11
25
|
replace,
|
|
26
|
+
slice_along_axis,
|
|
12
27
|
)
|
|
13
28
|
|
|
14
29
|
from .spectral import OptionsEnum
|
|
15
|
-
from .base import GenAxisArray
|
|
16
30
|
|
|
17
31
|
|
|
18
32
|
class AggregationFunction(OptionsEnum):
|
|
@@ -33,6 +47,7 @@ class AggregationFunction(OptionsEnum):
|
|
|
33
47
|
NANSUM = "nansum"
|
|
34
48
|
ARGMIN = "argmin"
|
|
35
49
|
ARGMAX = "argmax"
|
|
50
|
+
TRAPEZOID = "trapezoid"
|
|
36
51
|
|
|
37
52
|
|
|
38
53
|
AGGREGATORS = {
|
|
@@ -51,133 +66,211 @@ AGGREGATORS = {
|
|
|
51
66
|
AggregationFunction.NANSUM: np.nansum,
|
|
52
67
|
AggregationFunction.ARGMIN: np.argmin,
|
|
53
68
|
AggregationFunction.ARGMAX: np.argmax,
|
|
69
|
+
# Note: Some methods require x-coordinates and
|
|
70
|
+
# are handled specially in `_process`.
|
|
71
|
+
AggregationFunction.TRAPEZOID: np.trapezoid,
|
|
54
72
|
}
|
|
55
73
|
|
|
56
74
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
bands: list[tuple[float, float]] | None = None,
|
|
61
|
-
operation: AggregationFunction = AggregationFunction.MEAN,
|
|
62
|
-
):
|
|
75
|
+
class RangedAggregateSettings(ez.Settings):
|
|
76
|
+
"""
|
|
77
|
+
Settings for ``RangedAggregate``.
|
|
63
78
|
"""
|
|
64
|
-
Apply an aggregation operation over one or more bands.
|
|
65
79
|
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
bands: [(band1_min, band1_max), (band2_min, band2_max), ...]
|
|
69
|
-
If not set then this acts as a passthrough node.
|
|
70
|
-
operation: :obj:`AggregationFunction` to apply to each band.
|
|
80
|
+
axis: str | None = None
|
|
81
|
+
"""The name of the axis along which to apply the bands."""
|
|
71
82
|
|
|
72
|
-
|
|
73
|
-
A primed generator object ready to yield an :obj:`AxisArray` for each .send(axis_array)
|
|
83
|
+
bands: list[tuple[float, float]] | None = None
|
|
74
84
|
"""
|
|
75
|
-
|
|
85
|
+
[(band1_min, band1_max), (band2_min, band2_max), ...]
|
|
86
|
+
If not set then this acts as a passthrough node.
|
|
87
|
+
"""
|
|
88
|
+
|
|
89
|
+
operation: AggregationFunction = AggregationFunction.MEAN
|
|
90
|
+
""":obj:`AggregationFunction` to apply to each band."""
|
|
76
91
|
|
|
77
|
-
|
|
92
|
+
|
|
93
|
+
@processor_state
|
|
94
|
+
class RangedAggregateState:
|
|
78
95
|
slices: list[tuple[typing.Any, ...]] | None = None
|
|
79
96
|
out_axis: AxisBase | None = None
|
|
80
97
|
ax_vec: npt.NDArray | None = None
|
|
81
98
|
|
|
82
|
-
# Reset if any of these changes. Key not checked because continuity between chunks not required.
|
|
83
|
-
check_inputs = {"gain": None, "offset": None, "len": None, "key": None}
|
|
84
99
|
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
100
|
+
class RangedAggregateTransformer(
|
|
101
|
+
BaseStatefulTransformer[RangedAggregateSettings, AxisArray, AxisArray, RangedAggregateState]
|
|
102
|
+
):
|
|
103
|
+
def __call__(self, message: AxisArray) -> AxisArray:
|
|
104
|
+
# Override for shortcut passthrough mode.
|
|
105
|
+
if self.settings.bands is None:
|
|
106
|
+
return message
|
|
107
|
+
return super().__call__(message)
|
|
108
|
+
|
|
109
|
+
def _hash_message(self, message: AxisArray) -> int:
|
|
110
|
+
axis = self.settings.axis or message.dims[0]
|
|
111
|
+
target_axis = message.get_axis(axis)
|
|
112
|
+
|
|
113
|
+
hash_components = (message.key,)
|
|
114
|
+
if hasattr(target_axis, "data"):
|
|
115
|
+
hash_components += (len(target_axis.data),)
|
|
116
|
+
elif isinstance(target_axis, AxisArray.LinearAxis):
|
|
117
|
+
hash_components += (target_axis.gain, target_axis.offset)
|
|
118
|
+
return hash(hash_components)
|
|
119
|
+
|
|
120
|
+
def _reset_state(self, message: AxisArray) -> None:
|
|
121
|
+
axis = self.settings.axis or message.dims[0]
|
|
122
|
+
target_axis = message.get_axis(axis)
|
|
123
|
+
ax_idx = message.get_axis_idx(axis)
|
|
124
|
+
|
|
125
|
+
if hasattr(target_axis, "data"):
|
|
126
|
+
self._state.ax_vec = target_axis.data
|
|
89
127
|
else:
|
|
90
|
-
|
|
91
|
-
target_axis = msg_in.get_axis(axis)
|
|
128
|
+
self._state.ax_vec = target_axis.value(np.arange(message.data.shape[ax_idx]))
|
|
92
129
|
|
|
93
|
-
|
|
94
|
-
|
|
130
|
+
ax_dat = []
|
|
131
|
+
slices = []
|
|
132
|
+
for start, stop in self.settings.bands:
|
|
133
|
+
inds = np.where(np.logical_and(self._state.ax_vec >= start, self._state.ax_vec <= stop))[0]
|
|
134
|
+
slices.append(np.s_[inds[0] : inds[-1] + 1])
|
|
95
135
|
if hasattr(target_axis, "data"):
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
b_reset = b_reset or target_axis.gain != check_inputs["gain"]
|
|
99
|
-
b_reset = b_reset or target_axis.offset != check_inputs["offset"]
|
|
100
|
-
|
|
101
|
-
if b_reset:
|
|
102
|
-
# Update check variables
|
|
103
|
-
check_inputs["key"] = msg_in.key
|
|
104
|
-
if hasattr(target_axis, "data"):
|
|
105
|
-
check_inputs["len"] = len(target_axis.data)
|
|
136
|
+
if self._state.ax_vec.dtype.type is np.str_:
|
|
137
|
+
sl_dat = f"{self._state.ax_vec[start]} - {self._state.ax_vec[stop]}"
|
|
106
138
|
else:
|
|
107
|
-
|
|
108
|
-
|
|
139
|
+
ax_dat.append(np.mean(self._state.ax_vec[inds]))
|
|
140
|
+
else:
|
|
141
|
+
sl_dat = target_axis.value(np.mean(inds))
|
|
142
|
+
ax_dat.append(sl_dat)
|
|
109
143
|
|
|
110
|
-
|
|
111
|
-
|
|
144
|
+
self._state.slices = slices
|
|
145
|
+
self._state.out_axis = AxisArray.CoordinateAxis(
|
|
146
|
+
data=np.array(ax_dat),
|
|
147
|
+
dims=[axis],
|
|
148
|
+
unit=target_axis.unit,
|
|
149
|
+
)
|
|
112
150
|
|
|
113
|
-
|
|
151
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
152
|
+
axis = self.settings.axis or message.dims[0]
|
|
153
|
+
ax_idx = message.get_axis_idx(axis)
|
|
154
|
+
agg_func = AGGREGATORS[self.settings.operation]
|
|
114
155
|
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
slices.append(np.s_[inds[0] : inds[-1] + 1])
|
|
125
|
-
if hasattr(target_axis, "data"):
|
|
126
|
-
if ax_vec.dtype.type is np.str_:
|
|
127
|
-
sl_dat = f"{ax_vec[start]} - {ax_vec[stop]}"
|
|
128
|
-
else:
|
|
129
|
-
sl_dat = ax_dat.append(np.mean(ax_vec[inds]))
|
|
130
|
-
else:
|
|
131
|
-
sl_dat = target_axis.value(np.mean(inds))
|
|
132
|
-
ax_dat.append(sl_dat)
|
|
133
|
-
|
|
134
|
-
out_axis = AxisArray.CoordinateAxis(
|
|
135
|
-
data=np.array(ax_dat),
|
|
136
|
-
dims=[axis],
|
|
137
|
-
unit=target_axis.unit,
|
|
156
|
+
if self.settings.operation in [
|
|
157
|
+
AggregationFunction.TRAPEZOID,
|
|
158
|
+
]:
|
|
159
|
+
# Special handling for methods that require x-coordinates.
|
|
160
|
+
out_data = [
|
|
161
|
+
agg_func(
|
|
162
|
+
slice_along_axis(message.data, sl, axis=ax_idx),
|
|
163
|
+
x=self._state.ax_vec[sl],
|
|
164
|
+
axis=ax_idx,
|
|
138
165
|
)
|
|
139
|
-
|
|
140
|
-
|
|
166
|
+
for sl in self._state.slices
|
|
167
|
+
]
|
|
168
|
+
else:
|
|
141
169
|
out_data = [
|
|
142
|
-
agg_func(slice_along_axis(
|
|
143
|
-
for sl in slices
|
|
170
|
+
agg_func(slice_along_axis(message.data, sl, axis=ax_idx), axis=ax_idx) for sl in self._state.slices
|
|
144
171
|
]
|
|
145
172
|
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
if operation in [AggregationFunction.ARGMIN, AggregationFunction.ARGMAX]:
|
|
152
|
-
# Convert indices returned by argmin/argmax into the value along the axis.
|
|
153
|
-
out_data = []
|
|
154
|
-
for sl_ix, sl in enumerate(slices):
|
|
155
|
-
offsets = np.take(msg_out.data, [sl_ix], axis=ax_idx)
|
|
156
|
-
out_data.append(ax_vec[sl][offsets])
|
|
157
|
-
msg_out.data = np.concatenate(out_data, axis=ax_idx)
|
|
173
|
+
msg_out = replace(
|
|
174
|
+
message,
|
|
175
|
+
data=np.stack(out_data, axis=ax_idx),
|
|
176
|
+
axes={**message.axes, axis: self._state.out_axis},
|
|
177
|
+
)
|
|
158
178
|
|
|
179
|
+
if self.settings.operation in [
|
|
180
|
+
AggregationFunction.ARGMIN,
|
|
181
|
+
AggregationFunction.ARGMAX,
|
|
182
|
+
]:
|
|
183
|
+
out_data = []
|
|
184
|
+
for sl_ix, sl in enumerate(self._state.slices):
|
|
185
|
+
offsets = np.take(msg_out.data, [sl_ix], axis=ax_idx)
|
|
186
|
+
out_data.append(self._state.ax_vec[sl][offsets])
|
|
187
|
+
msg_out.data = np.concatenate(out_data, axis=ax_idx)
|
|
159
188
|
|
|
160
|
-
|
|
189
|
+
return msg_out
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
class RangedAggregate(BaseTransformerUnit[RangedAggregateSettings, AxisArray, AxisArray, RangedAggregateTransformer]):
|
|
193
|
+
SETTINGS = RangedAggregateSettings
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def ranged_aggregate(
|
|
197
|
+
axis: str | None = None,
|
|
198
|
+
bands: list[tuple[float, float]] | None = None,
|
|
199
|
+
operation: AggregationFunction = AggregationFunction.MEAN,
|
|
200
|
+
) -> RangedAggregateTransformer:
|
|
161
201
|
"""
|
|
162
|
-
|
|
163
|
-
|
|
202
|
+
Apply an aggregation operation over one or more bands.
|
|
203
|
+
|
|
204
|
+
Args:
|
|
205
|
+
axis: The name of the axis along which to apply the bands.
|
|
206
|
+
bands: [(band1_min, band1_max), (band2_min, band2_max), ...]
|
|
207
|
+
If not set then this acts as a passthrough node.
|
|
208
|
+
operation: :obj:`AggregationFunction` to apply to each band.
|
|
209
|
+
|
|
210
|
+
Returns:
|
|
211
|
+
:obj:`RangedAggregateTransformer`
|
|
164
212
|
"""
|
|
213
|
+
return RangedAggregateTransformer(RangedAggregateSettings(axis=axis, bands=bands, operation=operation))
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
class AggregateSettings(ez.Settings):
|
|
217
|
+
"""Settings for :obj:`Aggregate`."""
|
|
218
|
+
|
|
219
|
+
axis: str
|
|
220
|
+
"""The name of the axis to aggregate over. This axis will be removed from the output."""
|
|
165
221
|
|
|
166
|
-
axis: str | None = None
|
|
167
|
-
bands: list[tuple[float, float]] | None = None
|
|
168
222
|
operation: AggregationFunction = AggregationFunction.MEAN
|
|
223
|
+
""":obj:`AggregationFunction` to apply."""
|
|
169
224
|
|
|
170
225
|
|
|
171
|
-
class
|
|
226
|
+
class AggregateTransformer(BaseTransformer[AggregateSettings, AxisArray, AxisArray]):
|
|
172
227
|
"""
|
|
173
|
-
|
|
228
|
+
Transformer that aggregates an entire axis using a specified operation.
|
|
229
|
+
|
|
230
|
+
Unlike :obj:`RangedAggregateTransformer` which aggregates over specific ranges/bands
|
|
231
|
+
and preserves the axis (with one value per band), this transformer aggregates the
|
|
232
|
+
entire axis and removes it from the output, reducing dimensionality by one.
|
|
174
233
|
"""
|
|
175
234
|
|
|
176
|
-
|
|
235
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
236
|
+
xp = get_namespace(message.data)
|
|
237
|
+
axis_idx = message.get_axis_idx(self.settings.axis)
|
|
238
|
+
op = self.settings.operation
|
|
177
239
|
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
240
|
+
if op == AggregationFunction.NONE:
|
|
241
|
+
raise ValueError("AggregationFunction.NONE is not supported for full-axis aggregation")
|
|
242
|
+
|
|
243
|
+
if op == AggregationFunction.TRAPEZOID:
|
|
244
|
+
# Trapezoid integration requires x-coordinates
|
|
245
|
+
target_axis = message.get_axis(self.settings.axis)
|
|
246
|
+
if hasattr(target_axis, "data"):
|
|
247
|
+
x = target_axis.data
|
|
248
|
+
else:
|
|
249
|
+
x = target_axis.value(np.arange(message.data.shape[axis_idx]))
|
|
250
|
+
agg_data = np.trapezoid(np.asarray(message.data), x=x, axis=axis_idx)
|
|
251
|
+
else:
|
|
252
|
+
# Try array-API compatible function first, fall back to numpy
|
|
253
|
+
func_name = op.value
|
|
254
|
+
if hasattr(xp, func_name):
|
|
255
|
+
agg_data = getattr(xp, func_name)(message.data, axis=axis_idx)
|
|
256
|
+
else:
|
|
257
|
+
agg_data = AGGREGATORS[op](message.data, axis=axis_idx)
|
|
258
|
+
|
|
259
|
+
new_dims = list(message.dims)
|
|
260
|
+
new_dims.pop(axis_idx)
|
|
261
|
+
|
|
262
|
+
new_axes = dict(message.axes)
|
|
263
|
+
new_axes.pop(self.settings.axis, None)
|
|
264
|
+
|
|
265
|
+
return replace(
|
|
266
|
+
message,
|
|
267
|
+
data=agg_data,
|
|
268
|
+
dims=new_dims,
|
|
269
|
+
axes=new_axes,
|
|
183
270
|
)
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
class AggregateUnit(BaseTransformerUnit[AggregateSettings, AxisArray, AxisArray, AggregateTransformer]):
|
|
274
|
+
"""Unit that aggregates an entire axis using a specified operation."""
|
|
275
|
+
|
|
276
|
+
SETTINGS = AggregateSettings
|
ezmsg/sigproc/bandpower.py
CHANGED
|
@@ -1,75 +1,80 @@
|
|
|
1
1
|
from dataclasses import field
|
|
2
|
-
import typing
|
|
3
2
|
|
|
4
|
-
import numpy as np
|
|
5
3
|
import ezmsg.core as ez
|
|
4
|
+
from ezmsg.baseproc import (
|
|
5
|
+
BaseProcessor,
|
|
6
|
+
BaseStatefulProcessor,
|
|
7
|
+
BaseTransformerUnit,
|
|
8
|
+
CompositeProcessor,
|
|
9
|
+
)
|
|
6
10
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
|
-
from ezmsg.util.generator import consumer, compose
|
|
8
11
|
|
|
9
|
-
from .
|
|
10
|
-
|
|
11
|
-
|
|
12
|
+
from .aggregate import (
|
|
13
|
+
AggregationFunction,
|
|
14
|
+
RangedAggregateSettings,
|
|
15
|
+
RangedAggregateTransformer,
|
|
16
|
+
)
|
|
17
|
+
from .spectrogram import SpectrogramSettings, SpectrogramTransformer
|
|
12
18
|
|
|
13
19
|
|
|
14
|
-
|
|
15
|
-
def bandpower(
|
|
16
|
-
spectrogram_settings: SpectrogramSettings,
|
|
17
|
-
bands: list[tuple[float, float]] | None = [
|
|
18
|
-
(17, 30),
|
|
19
|
-
(70, 170),
|
|
20
|
-
],
|
|
21
|
-
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
20
|
+
class BandPowerSettings(ez.Settings):
|
|
22
21
|
"""
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
Args:
|
|
26
|
-
spectrogram_settings: Settings for spectrogram calculation.
|
|
27
|
-
bands: (min, max) tuples of band limits in Hz.
|
|
28
|
-
|
|
29
|
-
Returns:
|
|
30
|
-
A primed generator object ready to yield an :obj:`AxisArray` for each .send(axis_array)
|
|
31
|
-
with the data payload being the average spectral power in each band of the input data.
|
|
22
|
+
Settings for ``BandPower``.
|
|
32
23
|
"""
|
|
33
|
-
msg_out = AxisArray(np.array([]), dims=[""])
|
|
34
|
-
|
|
35
|
-
f_spec = spectrogram(
|
|
36
|
-
window_dur=spectrogram_settings.window_dur,
|
|
37
|
-
window_shift=spectrogram_settings.window_shift,
|
|
38
|
-
window=spectrogram_settings.window,
|
|
39
|
-
transform=spectrogram_settings.transform,
|
|
40
|
-
output=spectrogram_settings.output,
|
|
41
|
-
)
|
|
42
|
-
f_agg = ranged_aggregate(
|
|
43
|
-
axis="freq", bands=bands, operation=AggregationFunction.MEAN
|
|
44
|
-
)
|
|
45
|
-
pipeline = compose(f_spec, f_agg)
|
|
46
|
-
|
|
47
|
-
while True:
|
|
48
|
-
msg_in: AxisArray = yield msg_out
|
|
49
|
-
msg_out = pipeline(msg_in)
|
|
50
24
|
|
|
25
|
+
spectrogram_settings: SpectrogramSettings = field(default_factory=SpectrogramSettings)
|
|
26
|
+
"""
|
|
27
|
+
Settings for spectrogram calculation.
|
|
28
|
+
"""
|
|
51
29
|
|
|
52
|
-
|
|
30
|
+
bands: list[tuple[float, float]] | None = field(default_factory=lambda: [(17, 30), (70, 170)])
|
|
53
31
|
"""
|
|
54
|
-
|
|
55
|
-
See :obj:`bandpower` for details.
|
|
32
|
+
(min, max) tuples of band limits in Hz.
|
|
56
33
|
"""
|
|
57
34
|
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
bands: list[tuple[float, float]] | None = field(
|
|
62
|
-
default_factory=lambda: [(17, 30), (70, 170)]
|
|
63
|
-
)
|
|
35
|
+
aggregation: AggregationFunction = AggregationFunction.MEAN
|
|
36
|
+
""":obj:`AggregationFunction` to apply to each band."""
|
|
37
|
+
|
|
64
38
|
|
|
39
|
+
class BandPowerTransformer(CompositeProcessor[BandPowerSettings, AxisArray, AxisArray]):
|
|
40
|
+
@staticmethod
|
|
41
|
+
def _initialize_processors(
|
|
42
|
+
settings: BandPowerSettings,
|
|
43
|
+
) -> dict[str, BaseProcessor | BaseStatefulProcessor]:
|
|
44
|
+
return {
|
|
45
|
+
"spectrogram": SpectrogramTransformer(settings=settings.spectrogram_settings),
|
|
46
|
+
"aggregate": RangedAggregateTransformer(
|
|
47
|
+
settings=RangedAggregateSettings(
|
|
48
|
+
axis="freq",
|
|
49
|
+
bands=settings.bands,
|
|
50
|
+
operation=settings.aggregation,
|
|
51
|
+
)
|
|
52
|
+
),
|
|
53
|
+
}
|
|
65
54
|
|
|
66
|
-
class BandPower(GenAxisArray):
|
|
67
|
-
""":obj:`Unit` for :obj:`bandpower`."""
|
|
68
55
|
|
|
56
|
+
class BandPower(BaseTransformerUnit[BandPowerSettings, AxisArray, AxisArray, BandPowerTransformer]):
|
|
69
57
|
SETTINGS = BandPowerSettings
|
|
70
58
|
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
59
|
+
|
|
60
|
+
def bandpower(
|
|
61
|
+
spectrogram_settings: SpectrogramSettings,
|
|
62
|
+
bands: list[tuple[float, float]] | None = [
|
|
63
|
+
(17, 30),
|
|
64
|
+
(70, 170),
|
|
65
|
+
],
|
|
66
|
+
aggregation: AggregationFunction = AggregationFunction.MEAN,
|
|
67
|
+
) -> BandPowerTransformer:
|
|
68
|
+
"""
|
|
69
|
+
Calculate the average spectral power in each band.
|
|
70
|
+
|
|
71
|
+
Returns:
|
|
72
|
+
:obj:`BandPowerTransformer`
|
|
73
|
+
"""
|
|
74
|
+
return BandPowerTransformer(
|
|
75
|
+
settings=BandPowerSettings(
|
|
76
|
+
spectrogram_settings=spectrogram_settings,
|
|
77
|
+
bands=bands,
|
|
78
|
+
aggregation=aggregation,
|
|
75
79
|
)
|
|
80
|
+
)
|
ezmsg/sigproc/base.py
CHANGED
|
@@ -1,39 +1,149 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
import typing
|
|
1
|
+
"""
|
|
2
|
+
Backwards-compatible re-exports from ezmsg.baseproc.
|
|
4
3
|
|
|
5
|
-
|
|
6
|
-
from ezmsg.
|
|
7
|
-
from ezmsg.util.generator import GenState
|
|
4
|
+
This module re-exports all symbols from ezmsg.baseproc to maintain backwards
|
|
5
|
+
compatibility for code that imports from ezmsg.sigproc.base.
|
|
8
6
|
|
|
7
|
+
New code should import directly from ezmsg.baseproc instead.
|
|
8
|
+
"""
|
|
9
9
|
|
|
10
|
-
|
|
11
|
-
STATE = GenState
|
|
10
|
+
import warnings
|
|
12
11
|
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
12
|
+
warnings.warn(
|
|
13
|
+
"Importing from 'ezmsg.sigproc.base' is deprecated. Please import from 'ezmsg.baseproc' instead.",
|
|
14
|
+
DeprecationWarning,
|
|
15
|
+
stacklevel=2,
|
|
16
|
+
)
|
|
16
17
|
|
|
17
|
-
|
|
18
|
-
|
|
18
|
+
# Re-export everything from ezmsg.baseproc for backwards compatibility
|
|
19
|
+
from ezmsg.baseproc import ( # noqa: E402
|
|
20
|
+
# Protocols
|
|
21
|
+
AdaptiveTransformer,
|
|
22
|
+
# Type variables
|
|
23
|
+
AdaptiveTransformerType,
|
|
24
|
+
# Stateful classes
|
|
25
|
+
BaseAdaptiveTransformer,
|
|
26
|
+
# Unit classes
|
|
27
|
+
BaseAdaptiveTransformerUnit,
|
|
28
|
+
BaseAsyncTransformer,
|
|
29
|
+
# Base processor classes
|
|
30
|
+
BaseConsumer,
|
|
31
|
+
BaseConsumerUnit,
|
|
32
|
+
BaseProcessor,
|
|
33
|
+
BaseProcessorUnit,
|
|
34
|
+
BaseProducer,
|
|
35
|
+
BaseProducerUnit,
|
|
36
|
+
BaseStatefulConsumer,
|
|
37
|
+
BaseStatefulProcessor,
|
|
38
|
+
BaseStatefulProducer,
|
|
39
|
+
BaseStatefulTransformer,
|
|
40
|
+
BaseTransformer,
|
|
41
|
+
BaseTransformerUnit,
|
|
42
|
+
# Composite classes
|
|
43
|
+
CompositeProcessor,
|
|
44
|
+
CompositeProducer,
|
|
45
|
+
CompositeStateful,
|
|
46
|
+
Consumer,
|
|
47
|
+
ConsumerType,
|
|
48
|
+
GenAxisArray,
|
|
49
|
+
MessageInType,
|
|
50
|
+
MessageOutType,
|
|
51
|
+
Processor,
|
|
52
|
+
Producer,
|
|
53
|
+
ProducerType,
|
|
54
|
+
# Message types
|
|
55
|
+
SampleMessage,
|
|
56
|
+
SettingsType,
|
|
57
|
+
Stateful,
|
|
58
|
+
StatefulConsumer,
|
|
59
|
+
StatefulProcessor,
|
|
60
|
+
StatefulProducer,
|
|
61
|
+
StatefulTransformer,
|
|
62
|
+
StateType,
|
|
63
|
+
Transformer,
|
|
64
|
+
TransformerType,
|
|
65
|
+
# Type resolution helpers
|
|
66
|
+
_get_base_processor_message_in_type,
|
|
67
|
+
_get_base_processor_message_out_type,
|
|
68
|
+
_get_base_processor_settings_type,
|
|
69
|
+
_get_base_processor_state_type,
|
|
70
|
+
_get_processor_message_type,
|
|
71
|
+
# Type utilities
|
|
72
|
+
check_message_type_compatibility,
|
|
73
|
+
get_base_adaptive_transformer_type,
|
|
74
|
+
get_base_consumer_type,
|
|
75
|
+
get_base_producer_type,
|
|
76
|
+
get_base_transformer_type,
|
|
77
|
+
is_sample_message,
|
|
78
|
+
# Decorators
|
|
79
|
+
processor_state,
|
|
80
|
+
# Profiling
|
|
81
|
+
profile_subpub,
|
|
82
|
+
resolve_typevar,
|
|
83
|
+
)
|
|
19
84
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
85
|
+
__all__ = [
|
|
86
|
+
# Protocols
|
|
87
|
+
"Processor",
|
|
88
|
+
"Producer",
|
|
89
|
+
"Consumer",
|
|
90
|
+
"Transformer",
|
|
91
|
+
"StatefulProcessor",
|
|
92
|
+
"StatefulProducer",
|
|
93
|
+
"StatefulConsumer",
|
|
94
|
+
"StatefulTransformer",
|
|
95
|
+
"AdaptiveTransformer",
|
|
96
|
+
# Type variables
|
|
97
|
+
"MessageInType",
|
|
98
|
+
"MessageOutType",
|
|
99
|
+
"SettingsType",
|
|
100
|
+
"StateType",
|
|
101
|
+
"ProducerType",
|
|
102
|
+
"ConsumerType",
|
|
103
|
+
"TransformerType",
|
|
104
|
+
"AdaptiveTransformerType",
|
|
105
|
+
# Decorators
|
|
106
|
+
"processor_state",
|
|
107
|
+
# Base processor classes
|
|
108
|
+
"BaseProcessor",
|
|
109
|
+
"BaseProducer",
|
|
110
|
+
"BaseConsumer",
|
|
111
|
+
"BaseTransformer",
|
|
112
|
+
# Stateful classes
|
|
113
|
+
"Stateful",
|
|
114
|
+
"BaseStatefulProcessor",
|
|
115
|
+
"BaseStatefulProducer",
|
|
116
|
+
"BaseStatefulConsumer",
|
|
117
|
+
"BaseStatefulTransformer",
|
|
118
|
+
"BaseAdaptiveTransformer",
|
|
119
|
+
"BaseAsyncTransformer",
|
|
120
|
+
# Composite classes
|
|
121
|
+
"CompositeStateful",
|
|
122
|
+
"CompositeProcessor",
|
|
123
|
+
"CompositeProducer",
|
|
124
|
+
# Unit classes
|
|
125
|
+
"BaseProducerUnit",
|
|
126
|
+
"BaseProcessorUnit",
|
|
127
|
+
"BaseConsumerUnit",
|
|
128
|
+
"BaseTransformerUnit",
|
|
129
|
+
"BaseAdaptiveTransformerUnit",
|
|
130
|
+
"GenAxisArray",
|
|
131
|
+
# Type resolution helpers
|
|
132
|
+
"get_base_producer_type",
|
|
133
|
+
"get_base_consumer_type",
|
|
134
|
+
"get_base_transformer_type",
|
|
135
|
+
"get_base_adaptive_transformer_type",
|
|
136
|
+
"_get_base_processor_settings_type",
|
|
137
|
+
"_get_base_processor_message_in_type",
|
|
138
|
+
"_get_base_processor_message_out_type",
|
|
139
|
+
"_get_base_processor_state_type",
|
|
140
|
+
"_get_processor_message_type",
|
|
141
|
+
# Message types
|
|
142
|
+
"SampleMessage",
|
|
143
|
+
"is_sample_message",
|
|
144
|
+
# Profiling
|
|
145
|
+
"profile_subpub",
|
|
146
|
+
# Type utilities
|
|
147
|
+
"check_message_type_compatibility",
|
|
148
|
+
"resolve_typevar",
|
|
149
|
+
]
|