ezmsg-sigproc 1.8.2__py3-none-any.whl → 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/activation.py +36 -39
- ezmsg/sigproc/adaptive_lattice_notch.py +231 -0
- ezmsg/sigproc/affinetransform.py +169 -163
- ezmsg/sigproc/aggregate.py +133 -101
- ezmsg/sigproc/bandpower.py +64 -52
- ezmsg/sigproc/base.py +1242 -0
- ezmsg/sigproc/butterworthfilter.py +37 -33
- ezmsg/sigproc/cheby.py +29 -17
- ezmsg/sigproc/combfilter.py +163 -0
- ezmsg/sigproc/decimate.py +19 -10
- ezmsg/sigproc/detrend.py +29 -0
- ezmsg/sigproc/diff.py +81 -0
- ezmsg/sigproc/downsample.py +78 -84
- ezmsg/sigproc/ewma.py +197 -0
- ezmsg/sigproc/extract_axis.py +41 -0
- ezmsg/sigproc/filter.py +257 -141
- ezmsg/sigproc/filterbank.py +247 -199
- ezmsg/sigproc/math/abs.py +17 -22
- ezmsg/sigproc/math/clip.py +24 -24
- ezmsg/sigproc/math/difference.py +34 -30
- ezmsg/sigproc/math/invert.py +13 -25
- ezmsg/sigproc/math/log.py +28 -33
- ezmsg/sigproc/math/scale.py +18 -26
- ezmsg/sigproc/quantize.py +71 -0
- ezmsg/sigproc/resample.py +298 -0
- ezmsg/sigproc/sampler.py +241 -259
- ezmsg/sigproc/scaler.py +55 -218
- ezmsg/sigproc/signalinjector.py +52 -43
- ezmsg/sigproc/slicer.py +81 -89
- ezmsg/sigproc/spectrogram.py +77 -75
- ezmsg/sigproc/spectrum.py +203 -168
- ezmsg/sigproc/synth.py +546 -393
- ezmsg/sigproc/transpose.py +131 -0
- ezmsg/sigproc/util/asio.py +156 -0
- ezmsg/sigproc/util/message.py +31 -0
- ezmsg/sigproc/util/profile.py +55 -12
- ezmsg/sigproc/util/typeresolution.py +83 -0
- ezmsg/sigproc/wavelets.py +154 -153
- ezmsg/sigproc/window.py +269 -211
- {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.1.0.dist-info}/METADATA +2 -1
- ezmsg_sigproc-2.1.0.dist-info/RECORD +51 -0
- ezmsg_sigproc-1.8.2.dist-info/RECORD +0 -39
- {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.1.0.dist-info}/WHEEL +0 -0
- {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.1.0.dist-info}/licenses/LICENSE.txt +0 -0
ezmsg/sigproc/aggregate.py
CHANGED
|
@@ -3,7 +3,6 @@ import typing
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
import numpy.typing as npt
|
|
5
5
|
import ezmsg.core as ez
|
|
6
|
-
from ezmsg.util.generator import consumer
|
|
7
6
|
from ezmsg.util.messages.axisarray import (
|
|
8
7
|
AxisArray,
|
|
9
8
|
slice_along_axis,
|
|
@@ -12,7 +11,11 @@ from ezmsg.util.messages.axisarray import (
|
|
|
12
11
|
)
|
|
13
12
|
|
|
14
13
|
from .spectral import OptionsEnum
|
|
15
|
-
from .base import
|
|
14
|
+
from .base import (
|
|
15
|
+
BaseStatefulTransformer,
|
|
16
|
+
BaseTransformerUnit,
|
|
17
|
+
processor_state,
|
|
18
|
+
)
|
|
16
19
|
|
|
17
20
|
|
|
18
21
|
class AggregationFunction(OptionsEnum):
|
|
@@ -33,6 +36,7 @@ class AggregationFunction(OptionsEnum):
|
|
|
33
36
|
NANSUM = "nansum"
|
|
34
37
|
ARGMIN = "argmin"
|
|
35
38
|
ARGMAX = "argmax"
|
|
39
|
+
TRAPEZOID = "trapezoid"
|
|
36
40
|
|
|
37
41
|
|
|
38
42
|
AGGREGATORS = {
|
|
@@ -51,133 +55,161 @@ AGGREGATORS = {
|
|
|
51
55
|
AggregationFunction.NANSUM: np.nansum,
|
|
52
56
|
AggregationFunction.ARGMIN: np.argmin,
|
|
53
57
|
AggregationFunction.ARGMAX: np.argmax,
|
|
58
|
+
# Note: Some methods require x-coordinates and
|
|
59
|
+
# are handled specially in `_process`.
|
|
60
|
+
AggregationFunction.TRAPEZOID: np.trapezoid,
|
|
54
61
|
}
|
|
55
62
|
|
|
56
63
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
bands: list[tuple[float, float]] | None = None,
|
|
61
|
-
operation: AggregationFunction = AggregationFunction.MEAN,
|
|
62
|
-
):
|
|
64
|
+
class RangedAggregateSettings(ez.Settings):
|
|
65
|
+
"""
|
|
66
|
+
Settings for ``RangedAggregate``.
|
|
63
67
|
"""
|
|
64
|
-
Apply an aggregation operation over one or more bands.
|
|
65
68
|
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
bands: [(band1_min, band1_max), (band2_min, band2_max), ...]
|
|
69
|
-
If not set then this acts as a passthrough node.
|
|
70
|
-
operation: :obj:`AggregationFunction` to apply to each band.
|
|
69
|
+
axis: str | None = None
|
|
70
|
+
"""The name of the axis along which to apply the bands."""
|
|
71
71
|
|
|
72
|
-
|
|
73
|
-
|
|
72
|
+
bands: list[tuple[float, float]] | None = None
|
|
73
|
+
"""
|
|
74
|
+
[(band1_min, band1_max), (band2_min, band2_max), ...]
|
|
75
|
+
If not set then this acts as a passthrough node.
|
|
74
76
|
"""
|
|
75
|
-
msg_out = AxisArray(np.array([]), dims=[""])
|
|
76
77
|
|
|
77
|
-
|
|
78
|
+
operation: AggregationFunction = AggregationFunction.MEAN
|
|
79
|
+
""":obj:`AggregationFunction` to apply to each band."""
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@processor_state
|
|
83
|
+
class RangedAggregateState:
|
|
78
84
|
slices: list[tuple[typing.Any, ...]] | None = None
|
|
79
85
|
out_axis: AxisBase | None = None
|
|
80
86
|
ax_vec: npt.NDArray | None = None
|
|
81
87
|
|
|
82
|
-
# Reset if any of these changes. Key not checked because continuity between chunks not required.
|
|
83
|
-
check_inputs = {"gain": None, "offset": None, "len": None, "key": None}
|
|
84
88
|
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
+
class RangedAggregateTransformer(
|
|
90
|
+
BaseStatefulTransformer[
|
|
91
|
+
RangedAggregateSettings, AxisArray, AxisArray, RangedAggregateState
|
|
92
|
+
]
|
|
93
|
+
):
|
|
94
|
+
def __call__(self, message: AxisArray) -> AxisArray:
|
|
95
|
+
# Override for shortcut passthrough mode.
|
|
96
|
+
if self.settings.bands is None:
|
|
97
|
+
return message
|
|
98
|
+
return super().__call__(message)
|
|
99
|
+
|
|
100
|
+
def _hash_message(self, message: AxisArray) -> int:
|
|
101
|
+
axis = self.settings.axis or message.dims[0]
|
|
102
|
+
target_axis = message.get_axis(axis)
|
|
103
|
+
|
|
104
|
+
hash_components = (message.key,)
|
|
105
|
+
if hasattr(target_axis, "data"):
|
|
106
|
+
hash_components += (len(target_axis.data),)
|
|
107
|
+
elif isinstance(target_axis, AxisArray.LinearAxis):
|
|
108
|
+
hash_components += (target_axis.gain, target_axis.offset)
|
|
109
|
+
return hash(hash_components)
|
|
110
|
+
|
|
111
|
+
def _reset_state(self, message: AxisArray) -> None:
|
|
112
|
+
axis = self.settings.axis or message.dims[0]
|
|
113
|
+
target_axis = message.get_axis(axis)
|
|
114
|
+
ax_idx = message.get_axis_idx(axis)
|
|
115
|
+
|
|
116
|
+
if hasattr(target_axis, "data"):
|
|
117
|
+
self._state.ax_vec = target_axis.data
|
|
89
118
|
else:
|
|
90
|
-
|
|
91
|
-
|
|
119
|
+
self._state.ax_vec = target_axis.value(
|
|
120
|
+
np.arange(message.data.shape[ax_idx])
|
|
121
|
+
)
|
|
92
122
|
|
|
93
|
-
|
|
94
|
-
|
|
123
|
+
ax_dat = []
|
|
124
|
+
slices = []
|
|
125
|
+
for start, stop in self.settings.bands:
|
|
126
|
+
inds = np.where(
|
|
127
|
+
np.logical_and(self._state.ax_vec >= start, self._state.ax_vec <= stop)
|
|
128
|
+
)[0]
|
|
129
|
+
slices.append(np.s_[inds[0] : inds[-1] + 1])
|
|
95
130
|
if hasattr(target_axis, "data"):
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
b_reset = b_reset or target_axis.gain != check_inputs["gain"]
|
|
99
|
-
b_reset = b_reset or target_axis.offset != check_inputs["offset"]
|
|
100
|
-
|
|
101
|
-
if b_reset:
|
|
102
|
-
# Update check variables
|
|
103
|
-
check_inputs["key"] = msg_in.key
|
|
104
|
-
if hasattr(target_axis, "data"):
|
|
105
|
-
check_inputs["len"] = len(target_axis.data)
|
|
131
|
+
if self._state.ax_vec.dtype.type is np.str_:
|
|
132
|
+
sl_dat = f"{self._state.ax_vec[start]} - {self._state.ax_vec[stop]}"
|
|
106
133
|
else:
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
134
|
+
ax_dat.append(np.mean(self._state.ax_vec[inds]))
|
|
135
|
+
else:
|
|
136
|
+
sl_dat = target_axis.value(np.mean(inds))
|
|
137
|
+
ax_dat.append(sl_dat)
|
|
138
|
+
|
|
139
|
+
self._state.slices = slices
|
|
140
|
+
self._state.out_axis = AxisArray.CoordinateAxis(
|
|
141
|
+
data=np.array(ax_dat),
|
|
142
|
+
dims=[axis],
|
|
143
|
+
unit=target_axis.unit,
|
|
144
|
+
)
|
|
112
145
|
|
|
113
|
-
|
|
146
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
147
|
+
axis = self.settings.axis or message.dims[0]
|
|
148
|
+
ax_idx = message.get_axis_idx(axis)
|
|
149
|
+
agg_func = AGGREGATORS[self.settings.operation]
|
|
114
150
|
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
slices.append(np.s_[inds[0] : inds[-1] + 1])
|
|
125
|
-
if hasattr(target_axis, "data"):
|
|
126
|
-
if ax_vec.dtype.type is np.str_:
|
|
127
|
-
sl_dat = f"{ax_vec[start]} - {ax_vec[stop]}"
|
|
128
|
-
else:
|
|
129
|
-
sl_dat = ax_dat.append(np.mean(ax_vec[inds]))
|
|
130
|
-
else:
|
|
131
|
-
sl_dat = target_axis.value(np.mean(inds))
|
|
132
|
-
ax_dat.append(sl_dat)
|
|
133
|
-
|
|
134
|
-
out_axis = AxisArray.CoordinateAxis(
|
|
135
|
-
data=np.array(ax_dat),
|
|
136
|
-
dims=[axis],
|
|
137
|
-
unit=target_axis.unit,
|
|
151
|
+
if self.settings.operation in [
|
|
152
|
+
AggregationFunction.TRAPEZOID,
|
|
153
|
+
]:
|
|
154
|
+
# Special handling for methods that require x-coordinates.
|
|
155
|
+
out_data = [
|
|
156
|
+
agg_func(
|
|
157
|
+
slice_along_axis(message.data, sl, axis=ax_idx),
|
|
158
|
+
x=self._state.ax_vec[sl],
|
|
159
|
+
axis=ax_idx,
|
|
138
160
|
)
|
|
139
|
-
|
|
140
|
-
|
|
161
|
+
for sl in self._state.slices
|
|
162
|
+
]
|
|
163
|
+
else:
|
|
141
164
|
out_data = [
|
|
142
|
-
agg_func(slice_along_axis(
|
|
143
|
-
for sl in slices
|
|
165
|
+
agg_func(slice_along_axis(message.data, sl, axis=ax_idx), axis=ax_idx)
|
|
166
|
+
for sl in self._state.slices
|
|
144
167
|
]
|
|
145
168
|
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
if operation in [AggregationFunction.ARGMIN, AggregationFunction.ARGMAX]:
|
|
152
|
-
# Convert indices returned by argmin/argmax into the value along the axis.
|
|
153
|
-
out_data = []
|
|
154
|
-
for sl_ix, sl in enumerate(slices):
|
|
155
|
-
offsets = np.take(msg_out.data, [sl_ix], axis=ax_idx)
|
|
156
|
-
out_data.append(ax_vec[sl][offsets])
|
|
157
|
-
msg_out.data = np.concatenate(out_data, axis=ax_idx)
|
|
169
|
+
msg_out = replace(
|
|
170
|
+
message,
|
|
171
|
+
data=np.stack(out_data, axis=ax_idx),
|
|
172
|
+
axes={**message.axes, axis: self._state.out_axis},
|
|
173
|
+
)
|
|
158
174
|
|
|
175
|
+
if self.settings.operation in [
|
|
176
|
+
AggregationFunction.ARGMIN,
|
|
177
|
+
AggregationFunction.ARGMAX,
|
|
178
|
+
]:
|
|
179
|
+
out_data = []
|
|
180
|
+
for sl_ix, sl in enumerate(self._state.slices):
|
|
181
|
+
offsets = np.take(msg_out.data, [sl_ix], axis=ax_idx)
|
|
182
|
+
out_data.append(self._state.ax_vec[sl][offsets])
|
|
183
|
+
msg_out.data = np.concatenate(out_data, axis=ax_idx)
|
|
159
184
|
|
|
160
|
-
|
|
161
|
-
"""
|
|
162
|
-
Settings for ``RangedAggregate``.
|
|
163
|
-
See :obj:`ranged_aggregate` for details.
|
|
164
|
-
"""
|
|
185
|
+
return msg_out
|
|
165
186
|
|
|
166
|
-
axis: str | None = None
|
|
167
|
-
bands: list[tuple[float, float]] | None = None
|
|
168
|
-
operation: AggregationFunction = AggregationFunction.MEAN
|
|
169
187
|
|
|
188
|
+
class RangedAggregate(
|
|
189
|
+
BaseTransformerUnit[
|
|
190
|
+
RangedAggregateSettings, AxisArray, AxisArray, RangedAggregateTransformer
|
|
191
|
+
]
|
|
192
|
+
):
|
|
193
|
+
SETTINGS = RangedAggregateSettings
|
|
170
194
|
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
195
|
+
|
|
196
|
+
def ranged_aggregate(
|
|
197
|
+
axis: str | None = None,
|
|
198
|
+
bands: list[tuple[float, float]] | None = None,
|
|
199
|
+
operation: AggregationFunction = AggregationFunction.MEAN,
|
|
200
|
+
) -> RangedAggregateTransformer:
|
|
174
201
|
"""
|
|
202
|
+
Apply an aggregation operation over one or more bands.
|
|
175
203
|
|
|
176
|
-
|
|
204
|
+
Args:
|
|
205
|
+
axis: The name of the axis along which to apply the bands.
|
|
206
|
+
bands: [(band1_min, band1_max), (band2_min, band2_max), ...]
|
|
207
|
+
If not set then this acts as a passthrough node.
|
|
208
|
+
operation: :obj:`AggregationFunction` to apply to each band.
|
|
177
209
|
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
210
|
+
Returns:
|
|
211
|
+
:obj:`RangedAggregateTransformer`
|
|
212
|
+
"""
|
|
213
|
+
return RangedAggregateTransformer(
|
|
214
|
+
RangedAggregateSettings(axis=axis, bands=bands, operation=operation)
|
|
215
|
+
)
|
ezmsg/sigproc/bandpower.py
CHANGED
|
@@ -1,76 +1,88 @@
|
|
|
1
1
|
from dataclasses import field
|
|
2
|
-
import typing
|
|
3
2
|
|
|
4
|
-
import numpy as np
|
|
5
3
|
import ezmsg.core as ez
|
|
6
4
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
|
-
from ezmsg.util.generator import consumer, compose
|
|
8
5
|
|
|
9
|
-
from .spectrogram import
|
|
10
|
-
from .aggregate import
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
22
|
-
"""
|
|
23
|
-
Calculate the average spectral power in each band.
|
|
24
|
-
|
|
25
|
-
Args:
|
|
26
|
-
spectrogram_settings: Settings for spectrogram calculation.
|
|
27
|
-
bands: (min, max) tuples of band limits in Hz.
|
|
28
|
-
|
|
29
|
-
Returns:
|
|
30
|
-
A primed generator object ready to yield an :obj:`AxisArray` for each .send(axis_array)
|
|
31
|
-
with the data payload being the average spectral power in each band of the input data.
|
|
32
|
-
"""
|
|
33
|
-
msg_out = AxisArray(np.array([]), dims=[""])
|
|
34
|
-
|
|
35
|
-
f_spec = spectrogram(
|
|
36
|
-
window_dur=spectrogram_settings.window_dur,
|
|
37
|
-
window_shift=spectrogram_settings.window_shift,
|
|
38
|
-
window_anchor=spectrogram_settings.window_anchor,
|
|
39
|
-
window=spectrogram_settings.window,
|
|
40
|
-
transform=spectrogram_settings.transform,
|
|
41
|
-
output=spectrogram_settings.output,
|
|
42
|
-
)
|
|
43
|
-
f_agg = ranged_aggregate(
|
|
44
|
-
axis="freq", bands=bands, operation=AggregationFunction.MEAN
|
|
45
|
-
)
|
|
46
|
-
pipeline = compose(f_spec, f_agg)
|
|
47
|
-
|
|
48
|
-
while True:
|
|
49
|
-
msg_in: AxisArray = yield msg_out
|
|
50
|
-
msg_out = pipeline(msg_in)
|
|
6
|
+
from .spectrogram import SpectrogramSettings, SpectrogramTransformer
|
|
7
|
+
from .aggregate import (
|
|
8
|
+
AggregationFunction,
|
|
9
|
+
RangedAggregateTransformer,
|
|
10
|
+
RangedAggregateSettings,
|
|
11
|
+
)
|
|
12
|
+
from .base import (
|
|
13
|
+
BaseProcessor,
|
|
14
|
+
CompositeProcessor,
|
|
15
|
+
BaseStatefulProcessor,
|
|
16
|
+
BaseTransformerUnit,
|
|
17
|
+
)
|
|
51
18
|
|
|
52
19
|
|
|
53
20
|
class BandPowerSettings(ez.Settings):
|
|
54
21
|
"""
|
|
55
22
|
Settings for ``BandPower``.
|
|
56
|
-
See :obj:`bandpower` for details.
|
|
57
23
|
"""
|
|
58
24
|
|
|
59
25
|
spectrogram_settings: SpectrogramSettings = field(
|
|
60
26
|
default_factory=SpectrogramSettings
|
|
61
27
|
)
|
|
28
|
+
"""
|
|
29
|
+
Settings for spectrogram calculation.
|
|
30
|
+
"""
|
|
31
|
+
|
|
62
32
|
bands: list[tuple[float, float]] | None = field(
|
|
63
33
|
default_factory=lambda: [(17, 30), (70, 170)]
|
|
64
34
|
)
|
|
35
|
+
"""
|
|
36
|
+
(min, max) tuples of band limits in Hz.
|
|
37
|
+
"""
|
|
38
|
+
|
|
39
|
+
aggregation: AggregationFunction = AggregationFunction.MEAN
|
|
40
|
+
""":obj:`AggregationFunction` to apply to each band."""
|
|
65
41
|
|
|
66
42
|
|
|
67
|
-
class
|
|
68
|
-
|
|
43
|
+
class BandPowerTransformer(CompositeProcessor[BandPowerSettings, AxisArray, AxisArray]):
|
|
44
|
+
@staticmethod
|
|
45
|
+
def _initialize_processors(
|
|
46
|
+
settings: BandPowerSettings,
|
|
47
|
+
) -> dict[str, BaseProcessor | BaseStatefulProcessor]:
|
|
48
|
+
return {
|
|
49
|
+
"spectrogram": SpectrogramTransformer(
|
|
50
|
+
settings=settings.spectrogram_settings
|
|
51
|
+
),
|
|
52
|
+
"aggregate": RangedAggregateTransformer(
|
|
53
|
+
settings=RangedAggregateSettings(
|
|
54
|
+
axis="freq",
|
|
55
|
+
bands=settings.bands,
|
|
56
|
+
operation=settings.aggregation,
|
|
57
|
+
)
|
|
58
|
+
),
|
|
59
|
+
}
|
|
69
60
|
|
|
61
|
+
|
|
62
|
+
class BandPower(
|
|
63
|
+
BaseTransformerUnit[BandPowerSettings, AxisArray, AxisArray, BandPowerTransformer]
|
|
64
|
+
):
|
|
70
65
|
SETTINGS = BandPowerSettings
|
|
71
66
|
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
67
|
+
|
|
68
|
+
def bandpower(
|
|
69
|
+
spectrogram_settings: SpectrogramSettings,
|
|
70
|
+
bands: list[tuple[float, float]] | None = [
|
|
71
|
+
(17, 30),
|
|
72
|
+
(70, 170),
|
|
73
|
+
],
|
|
74
|
+
aggregation: AggregationFunction = AggregationFunction.MEAN,
|
|
75
|
+
) -> BandPowerTransformer:
|
|
76
|
+
"""
|
|
77
|
+
Calculate the average spectral power in each band.
|
|
78
|
+
|
|
79
|
+
Returns:
|
|
80
|
+
:obj:`BandPowerTransformer`
|
|
81
|
+
"""
|
|
82
|
+
return BandPowerTransformer(
|
|
83
|
+
settings=BandPowerSettings(
|
|
84
|
+
spectrogram_settings=spectrogram_settings,
|
|
85
|
+
bands=bands,
|
|
86
|
+
aggregation=aggregation,
|
|
76
87
|
)
|
|
88
|
+
)
|