ezmsg-sigproc 2.5.0__py3-none-any.whl → 2.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/activation.py +5 -11
- ezmsg/sigproc/adaptive_lattice_notch.py +11 -30
- ezmsg/sigproc/affinetransform.py +16 -42
- ezmsg/sigproc/aggregate.py +17 -34
- ezmsg/sigproc/bandpower.py +12 -20
- ezmsg/sigproc/base.py +141 -1276
- ezmsg/sigproc/butterworthfilter.py +8 -16
- ezmsg/sigproc/butterworthzerophase.py +7 -16
- ezmsg/sigproc/cheby.py +4 -10
- ezmsg/sigproc/combfilter.py +5 -8
- ezmsg/sigproc/coordinatespaces.py +142 -0
- ezmsg/sigproc/decimate.py +3 -7
- ezmsg/sigproc/denormalize.py +6 -11
- ezmsg/sigproc/detrend.py +3 -4
- ezmsg/sigproc/diff.py +8 -17
- ezmsg/sigproc/downsample.py +11 -20
- ezmsg/sigproc/ewma.py +11 -28
- ezmsg/sigproc/ewmfilter.py +1 -1
- ezmsg/sigproc/extract_axis.py +3 -4
- ezmsg/sigproc/fbcca.py +34 -59
- ezmsg/sigproc/filter.py +19 -45
- ezmsg/sigproc/filterbank.py +37 -74
- ezmsg/sigproc/filterbankdesign.py +7 -14
- ezmsg/sigproc/fir_hilbert.py +13 -30
- ezmsg/sigproc/fir_pmc.py +5 -10
- ezmsg/sigproc/firfilter.py +12 -14
- ezmsg/sigproc/gaussiansmoothing.py +5 -9
- ezmsg/sigproc/kaiser.py +11 -15
- ezmsg/sigproc/math/abs.py +4 -3
- ezmsg/sigproc/math/add.py +121 -0
- ezmsg/sigproc/math/clip.py +4 -1
- ezmsg/sigproc/math/difference.py +100 -36
- ezmsg/sigproc/math/invert.py +3 -3
- ezmsg/sigproc/math/log.py +5 -6
- ezmsg/sigproc/math/scale.py +2 -0
- ezmsg/sigproc/messages.py +1 -2
- ezmsg/sigproc/quantize.py +3 -6
- ezmsg/sigproc/resample.py +17 -38
- ezmsg/sigproc/rollingscaler.py +12 -37
- ezmsg/sigproc/sampler.py +19 -37
- ezmsg/sigproc/scaler.py +11 -22
- ezmsg/sigproc/signalinjector.py +7 -18
- ezmsg/sigproc/slicer.py +14 -34
- ezmsg/sigproc/spectral.py +3 -3
- ezmsg/sigproc/spectrogram.py +12 -19
- ezmsg/sigproc/spectrum.py +17 -38
- ezmsg/sigproc/transpose.py +12 -24
- ezmsg/sigproc/util/asio.py +25 -156
- ezmsg/sigproc/util/axisarray_buffer.py +12 -26
- ezmsg/sigproc/util/buffer.py +22 -43
- ezmsg/sigproc/util/message.py +17 -31
- ezmsg/sigproc/util/profile.py +23 -174
- ezmsg/sigproc/util/sparse.py +7 -15
- ezmsg/sigproc/util/typeresolution.py +17 -83
- ezmsg/sigproc/wavelets.py +10 -19
- ezmsg/sigproc/window.py +29 -83
- ezmsg_sigproc-2.7.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.7.0.dist-info/RECORD +64 -0
- ezmsg/sigproc/synth.py +0 -774
- ezmsg_sigproc-2.5.0.dist-info/METADATA +0 -72
- ezmsg_sigproc-2.5.0.dist-info/RECORD +0 -63
- {ezmsg_sigproc-2.5.0.dist-info → ezmsg_sigproc-2.7.0.dist-info}/WHEEL +0 -0
- /ezmsg_sigproc-2.5.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.7.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/__version__.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '2.5.0'
|
|
32
|
-
__version_tuple__ = version_tuple = (2, 5, 0)
|
|
31
|
+
__version__ = version = '2.7.0'
|
|
32
|
+
__version_tuple__ = version_tuple = (2, 7, 0)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
ezmsg/sigproc/activation.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
import scipy.special
|
|
2
1
|
import ezmsg.core as ez
|
|
2
|
+
import scipy.special
|
|
3
|
+
from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
|
|
3
4
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
4
5
|
from ezmsg.util.messages.util import replace
|
|
5
6
|
|
|
6
7
|
from .spectral import OptionsEnum
|
|
7
|
-
from .base import BaseTransformer, BaseTransformerUnit
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
class ActivationFunction(OptionsEnum):
|
|
@@ -50,20 +50,14 @@ class ActivationTransformer(BaseTransformer[ActivationSettings, AxisArray, AxisA
|
|
|
50
50
|
# str type handling
|
|
51
51
|
function = self.settings.function.lower()
|
|
52
52
|
if function not in ActivationFunction.options():
|
|
53
|
-
raise ValueError(
|
|
54
|
-
f"Unrecognized activation function {function}. Must be one of {ACTIVATIONS.keys()}"
|
|
55
|
-
)
|
|
56
|
-
function = list(ACTIVATIONS.keys())[
|
|
57
|
-
ActivationFunction.options().index(function)
|
|
58
|
-
]
|
|
53
|
+
raise ValueError(f"Unrecognized activation function {function}. Must be one of {ACTIVATIONS.keys()}")
|
|
54
|
+
function = list(ACTIVATIONS.keys())[ActivationFunction.options().index(function)]
|
|
59
55
|
func = ACTIVATIONS[function]
|
|
60
56
|
|
|
61
57
|
return replace(message, data=func(message.data))
|
|
62
58
|
|
|
63
59
|
|
|
64
|
-
class Activation(
|
|
65
|
-
BaseTransformerUnit[ActivationSettings, AxisArray, AxisArray, ActivationTransformer]
|
|
66
|
-
):
|
|
60
|
+
class Activation(BaseTransformerUnit[ActivationSettings, AxisArray, AxisArray, ActivationTransformer]):
|
|
67
61
|
SETTINGS = ActivationSettings
|
|
68
62
|
|
|
69
63
|
|
|
@@ -1,12 +1,11 @@
|
|
|
1
|
+
import ezmsg.core as ez
|
|
1
2
|
import numpy as np
|
|
2
3
|
import numpy.typing as npt
|
|
3
4
|
import scipy.signal
|
|
4
|
-
|
|
5
|
+
from ezmsg.baseproc import BaseStatefulTransformer, processor_state
|
|
5
6
|
from ezmsg.util.messages.axisarray import AxisArray, CoordinateAxis
|
|
6
7
|
from ezmsg.util.messages.util import replace
|
|
7
8
|
|
|
8
|
-
from .base import processor_state, BaseStatefulTransformer
|
|
9
|
-
|
|
10
9
|
|
|
11
10
|
class AdaptiveLatticeNotchFilterSettings(ez.Settings):
|
|
12
11
|
"""Settings for the Adaptive Lattice Notch Filter."""
|
|
@@ -76,9 +75,7 @@ class AdaptiveLatticeNotchFilterTransformer(
|
|
|
76
75
|
|
|
77
76
|
fs = 1 / message.axes[self.settings.axis].gain
|
|
78
77
|
init_f = (
|
|
79
|
-
self.settings.init_notch_freq
|
|
80
|
-
if self.settings.init_notch_freq is not None
|
|
81
|
-
else 0.07178314656435313 * fs
|
|
78
|
+
self.settings.init_notch_freq if self.settings.init_notch_freq is not None else 0.07178314656435313 * fs
|
|
82
79
|
)
|
|
83
80
|
init_omega = init_f * (2 * np.pi) / fs
|
|
84
81
|
init_k1 = -np.cos(init_omega)
|
|
@@ -91,9 +88,7 @@ class AdaptiveLatticeNotchFilterTransformer(
|
|
|
91
88
|
self._state.k1 = init_k1 + np.zeros(sample_shape, dtype=float)
|
|
92
89
|
self._state.freq_template = CoordinateAxis(
|
|
93
90
|
data=np.zeros((0,) + sample_shape, dtype=float),
|
|
94
|
-
dims=[self.settings.axis]
|
|
95
|
-
+ message.dims[:ax_idx]
|
|
96
|
-
+ message.dims[ax_idx + 1 :],
|
|
91
|
+
dims=[self.settings.axis] + message.dims[:ax_idx] + message.dims[ax_idx + 1 :],
|
|
97
92
|
unit="Hz",
|
|
98
93
|
)
|
|
99
94
|
|
|
@@ -147,9 +142,7 @@ class AdaptiveLatticeNotchFilterTransformer(
|
|
|
147
142
|
for ix, k in enumerate(self._state.k1.flatten()):
|
|
148
143
|
# Filter to get s_n (notch filter state)
|
|
149
144
|
a_s = [1, k * gamma_plus_1, gamma]
|
|
150
|
-
s_n[:, ix], self._state.zi[:, ix] = scipy.signal.lfilter(
|
|
151
|
-
[1], a_s, _x[:, ix], zi=self._state.zi[:, ix]
|
|
152
|
-
)
|
|
145
|
+
s_n[:, ix], self._state.zi[:, ix] = scipy.signal.lfilter([1], a_s, _x[:, ix], zi=self._state.zi[:, ix])
|
|
153
146
|
|
|
154
147
|
# Apply output filter to get y_out
|
|
155
148
|
b_y = [1, 2 * k, 1]
|
|
@@ -159,17 +152,11 @@ class AdaptiveLatticeNotchFilterTransformer(
|
|
|
159
152
|
s_n_reshaped = s_n.reshape((s_n.shape[0],) + x_data.shape[1:])
|
|
160
153
|
s_final = s_n_reshaped[-1] # Current s_n
|
|
161
154
|
s_final_1 = s_n_reshaped[-2] # s_n_1
|
|
162
|
-
s_final_2 = (
|
|
163
|
-
s_n_reshaped[-3] if len(s_n_reshaped) > 2 else self._state.s_history[0]
|
|
164
|
-
) # s_n_2
|
|
155
|
+
s_final_2 = s_n_reshaped[-3] if len(s_n_reshaped) > 2 else self._state.s_history[0] # s_n_2
|
|
165
156
|
|
|
166
157
|
# Update p and q using final values
|
|
167
|
-
self._state.p = eta * self._state.p + one_minus_eta * (
|
|
168
|
-
s_final_1 * (s_final + s_final_2)
|
|
169
|
-
)
|
|
170
|
-
self._state.q = eta * self._state.q + one_minus_eta * (
|
|
171
|
-
2 * (s_final_1 * s_final_1)
|
|
172
|
-
)
|
|
158
|
+
self._state.p = eta * self._state.p + one_minus_eta * (s_final_1 * (s_final + s_final_2))
|
|
159
|
+
self._state.q = eta * self._state.q + one_minus_eta * (2 * (s_final_1 * s_final_1))
|
|
173
160
|
|
|
174
161
|
# Update reflection coefficient
|
|
175
162
|
new_k1 = -self._state.p / (self._state.q + 1e-8) # Avoid division by zero
|
|
@@ -199,17 +186,11 @@ class AdaptiveLatticeNotchFilterTransformer(
|
|
|
199
186
|
y_out[sample_ix] = s_n + 2 * self._state.k1 * s_n_1 + s_n_2
|
|
200
187
|
|
|
201
188
|
# Update filter parameters
|
|
202
|
-
self._state.p = eta * self._state.p + one_minus_eta * (
|
|
203
|
-
s_n_1 * (s_n + s_n_2)
|
|
204
|
-
)
|
|
205
|
-
self._state.q = eta * self._state.q + one_minus_eta * (
|
|
206
|
-
2 * (s_n_1 * s_n_1)
|
|
207
|
-
)
|
|
189
|
+
self._state.p = eta * self._state.p + one_minus_eta * (s_n_1 * (s_n + s_n_2))
|
|
190
|
+
self._state.q = eta * self._state.q + one_minus_eta * (2 * (s_n_1 * s_n_1))
|
|
208
191
|
|
|
209
192
|
# Update reflection coefficient
|
|
210
|
-
new_k1 = -self._state.p / (
|
|
211
|
-
self._state.q + 1e-8
|
|
212
|
-
) # Avoid division by zero
|
|
193
|
+
new_k1 = -self._state.p / (self._state.q + 1e-8) # Avoid division by zero
|
|
213
194
|
new_k1 = np.clip(new_k1, -1, 1) # Clip to prevent instability
|
|
214
195
|
self._state.k1 = mu * self._state.k1 + one_minus_mu * new_k1 # Smoothed
|
|
215
196
|
|
ezmsg/sigproc/affinetransform.py
CHANGED
|
@@ -1,18 +1,17 @@
|
|
|
1
1
|
import os
|
|
2
2
|
from pathlib import Path
|
|
3
3
|
|
|
4
|
+
import ezmsg.core as ez
|
|
4
5
|
import numpy as np
|
|
5
6
|
import numpy.typing as npt
|
|
6
|
-
|
|
7
|
-
from ezmsg.util.messages.axisarray import AxisArray, AxisBase
|
|
8
|
-
from ezmsg.util.messages.util import replace
|
|
9
|
-
|
|
10
|
-
from .base import (
|
|
7
|
+
from ezmsg.baseproc import (
|
|
11
8
|
BaseStatefulTransformer,
|
|
12
|
-
BaseTransformerUnit,
|
|
13
9
|
BaseTransformer,
|
|
10
|
+
BaseTransformerUnit,
|
|
14
11
|
processor_state,
|
|
15
12
|
)
|
|
13
|
+
from ezmsg.util.messages.axisarray import AxisArray, AxisBase
|
|
14
|
+
from ezmsg.util.messages.util import replace
|
|
16
15
|
|
|
17
16
|
|
|
18
17
|
class AffineTransformSettings(ez.Settings):
|
|
@@ -38,15 +37,12 @@ class AffineTransformState:
|
|
|
38
37
|
|
|
39
38
|
|
|
40
39
|
class AffineTransformTransformer(
|
|
41
|
-
BaseStatefulTransformer[
|
|
42
|
-
AffineTransformSettings, AxisArray, AxisArray, AffineTransformState
|
|
43
|
-
]
|
|
40
|
+
BaseStatefulTransformer[AffineTransformSettings, AxisArray, AxisArray, AffineTransformState]
|
|
44
41
|
):
|
|
45
42
|
def __call__(self, message: AxisArray) -> AxisArray:
|
|
46
43
|
# Override __call__ so we can shortcut if weights are None.
|
|
47
44
|
if self.settings.weights is None or (
|
|
48
|
-
isinstance(self.settings.weights, str)
|
|
49
|
-
and self.settings.weights == "passthrough"
|
|
45
|
+
isinstance(self.settings.weights, str) and self.settings.weights == "passthrough"
|
|
50
46
|
):
|
|
51
47
|
return message
|
|
52
48
|
return super().__call__(message)
|
|
@@ -68,18 +64,12 @@ class AffineTransformTransformer(
|
|
|
68
64
|
self._state.weights = weights
|
|
69
65
|
|
|
70
66
|
axis = self.settings.axis or message.dims[-1]
|
|
71
|
-
if (
|
|
72
|
-
axis in message.axes
|
|
73
|
-
and hasattr(message.axes[axis], "data")
|
|
74
|
-
and weights.shape[0] != weights.shape[1]
|
|
75
|
-
):
|
|
67
|
+
if axis in message.axes and hasattr(message.axes[axis], "data") and weights.shape[0] != weights.shape[1]:
|
|
76
68
|
in_labels = message.axes[axis].data
|
|
77
69
|
new_labels = []
|
|
78
70
|
n_in, n_out = weights.shape
|
|
79
71
|
if len(in_labels) != n_in:
|
|
80
|
-
ez.logger.warning(
|
|
81
|
-
f"Received {len(in_labels)} for {n_in} inputs. Check upstream labels."
|
|
82
|
-
)
|
|
72
|
+
ez.logger.warning(f"Received {len(in_labels)} for {n_in} inputs. Check upstream labels.")
|
|
83
73
|
else:
|
|
84
74
|
b_filled_outputs = np.any(weights, axis=0)
|
|
85
75
|
b_used_inputs = np.any(weights, axis=1)
|
|
@@ -97,9 +87,7 @@ class AffineTransformTransformer(
|
|
|
97
87
|
elif np.all(b_filled_outputs):
|
|
98
88
|
new_labels = np.array(in_labels)[b_used_inputs]
|
|
99
89
|
|
|
100
|
-
self._state.new_axis = replace(
|
|
101
|
-
message.axes[axis], data=np.array(new_labels)
|
|
102
|
-
)
|
|
90
|
+
self._state.new_axis = replace(message.axes[axis], data=np.array(new_labels))
|
|
103
91
|
|
|
104
92
|
def _process(self, message: AxisArray) -> AxisArray:
|
|
105
93
|
axis = self.settings.axis or message.dims[-1]
|
|
@@ -110,9 +98,7 @@ class AffineTransformTransformer(
|
|
|
110
98
|
# The weights are stacked A|B where A is the transform and B is a single row
|
|
111
99
|
# in the equation y = Ax + B. This supports NeuroKey's weights matrices.
|
|
112
100
|
sample_shape = data.shape[:axis_idx] + (1,) + data.shape[axis_idx + 1 :]
|
|
113
|
-
data = np.concatenate(
|
|
114
|
-
(data, np.ones(sample_shape).astype(data.dtype)), axis=axis_idx
|
|
115
|
-
)
|
|
101
|
+
data = np.concatenate((data, np.ones(sample_shape).astype(data.dtype)), axis=axis_idx)
|
|
116
102
|
|
|
117
103
|
if axis_idx in [-1, len(message.dims) - 1]:
|
|
118
104
|
data = np.matmul(data, self._state.weights)
|
|
@@ -128,11 +114,7 @@ class AffineTransformTransformer(
|
|
|
128
114
|
return replace(message, **replace_kwargs)
|
|
129
115
|
|
|
130
116
|
|
|
131
|
-
class AffineTransform(
|
|
132
|
-
BaseTransformerUnit[
|
|
133
|
-
AffineTransformSettings, AxisArray, AxisArray, AffineTransformTransformer
|
|
134
|
-
]
|
|
135
|
-
):
|
|
117
|
+
class AffineTransform(BaseTransformerUnit[AffineTransformSettings, AxisArray, AxisArray, AffineTransformTransformer]):
|
|
136
118
|
SETTINGS = AffineTransformSettings
|
|
137
119
|
|
|
138
120
|
|
|
@@ -153,9 +135,7 @@ def affine_transform(
|
|
|
153
135
|
:obj:`AffineTransformTransformer`.
|
|
154
136
|
"""
|
|
155
137
|
return AffineTransformTransformer(
|
|
156
|
-
AffineTransformSettings(
|
|
157
|
-
weights=weights, axis=axis, right_multiply=right_multiply
|
|
158
|
-
)
|
|
138
|
+
AffineTransformSettings(weights=weights, axis=axis, right_multiply=right_multiply)
|
|
159
139
|
)
|
|
160
140
|
|
|
161
141
|
|
|
@@ -178,9 +158,7 @@ class CommonRereferenceSettings(ez.Settings):
|
|
|
178
158
|
"""Set False to exclude each channel from participating in the calculation of its reference."""
|
|
179
159
|
|
|
180
160
|
|
|
181
|
-
class CommonRereferenceTransformer(
|
|
182
|
-
BaseTransformer[CommonRereferenceSettings, AxisArray, AxisArray]
|
|
183
|
-
):
|
|
161
|
+
class CommonRereferenceTransformer(BaseTransformer[CommonRereferenceSettings, AxisArray, AxisArray]):
|
|
184
162
|
def _process(self, message: AxisArray) -> AxisArray:
|
|
185
163
|
if self.settings.mode == "passthrough":
|
|
186
164
|
return message
|
|
@@ -188,9 +166,7 @@ class CommonRereferenceTransformer(
|
|
|
188
166
|
axis = self.settings.axis or message.dims[-1]
|
|
189
167
|
axis_idx = message.get_axis_idx(axis)
|
|
190
168
|
|
|
191
|
-
func = {"mean": np.mean, "median": np.median, "passthrough": zeros_for_noop}[
|
|
192
|
-
self.settings.mode
|
|
193
|
-
]
|
|
169
|
+
func = {"mean": np.mean, "median": np.median, "passthrough": zeros_for_noop}[self.settings.mode]
|
|
194
170
|
|
|
195
171
|
ref_data = func(message.data, axis=axis_idx, keepdims=True)
|
|
196
172
|
|
|
@@ -213,9 +189,7 @@ class CommonRereferenceTransformer(
|
|
|
213
189
|
|
|
214
190
|
|
|
215
191
|
class CommonRereference(
|
|
216
|
-
BaseTransformerUnit[
|
|
217
|
-
CommonRereferenceSettings, AxisArray, AxisArray, CommonRereferenceTransformer
|
|
218
|
-
]
|
|
192
|
+
BaseTransformerUnit[CommonRereferenceSettings, AxisArray, AxisArray, CommonRereferenceTransformer]
|
|
219
193
|
):
|
|
220
194
|
SETTINGS = CommonRereferenceSettings
|
|
221
195
|
|
ezmsg/sigproc/aggregate.py
CHANGED
|
@@ -1,23 +1,23 @@
|
|
|
1
|
-
from array_api_compat import get_namespace
|
|
2
1
|
import typing
|
|
3
2
|
|
|
3
|
+
import ezmsg.core as ez
|
|
4
4
|
import numpy as np
|
|
5
5
|
import numpy.typing as npt
|
|
6
|
-
|
|
6
|
+
from array_api_compat import get_namespace
|
|
7
|
+
from ezmsg.baseproc import (
|
|
8
|
+
BaseStatefulTransformer,
|
|
9
|
+
BaseTransformer,
|
|
10
|
+
BaseTransformerUnit,
|
|
11
|
+
processor_state,
|
|
12
|
+
)
|
|
7
13
|
from ezmsg.util.messages.axisarray import (
|
|
8
14
|
AxisArray,
|
|
9
|
-
slice_along_axis,
|
|
10
15
|
AxisBase,
|
|
11
16
|
replace,
|
|
17
|
+
slice_along_axis,
|
|
12
18
|
)
|
|
13
19
|
|
|
14
20
|
from .spectral import OptionsEnum
|
|
15
|
-
from .base import (
|
|
16
|
-
BaseTransformer,
|
|
17
|
-
BaseStatefulTransformer,
|
|
18
|
-
BaseTransformerUnit,
|
|
19
|
-
processor_state,
|
|
20
|
-
)
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
class AggregationFunction(OptionsEnum):
|
|
@@ -89,9 +89,7 @@ class RangedAggregateState:
|
|
|
89
89
|
|
|
90
90
|
|
|
91
91
|
class RangedAggregateTransformer(
|
|
92
|
-
BaseStatefulTransformer[
|
|
93
|
-
RangedAggregateSettings, AxisArray, AxisArray, RangedAggregateState
|
|
94
|
-
]
|
|
92
|
+
BaseStatefulTransformer[RangedAggregateSettings, AxisArray, AxisArray, RangedAggregateState]
|
|
95
93
|
):
|
|
96
94
|
def __call__(self, message: AxisArray) -> AxisArray:
|
|
97
95
|
# Override for shortcut passthrough mode.
|
|
@@ -118,16 +116,12 @@ class RangedAggregateTransformer(
|
|
|
118
116
|
if hasattr(target_axis, "data"):
|
|
119
117
|
self._state.ax_vec = target_axis.data
|
|
120
118
|
else:
|
|
121
|
-
self._state.ax_vec = target_axis.value(
|
|
122
|
-
np.arange(message.data.shape[ax_idx])
|
|
123
|
-
)
|
|
119
|
+
self._state.ax_vec = target_axis.value(np.arange(message.data.shape[ax_idx]))
|
|
124
120
|
|
|
125
121
|
ax_dat = []
|
|
126
122
|
slices = []
|
|
127
123
|
for start, stop in self.settings.bands:
|
|
128
|
-
inds = np.where(
|
|
129
|
-
np.logical_and(self._state.ax_vec >= start, self._state.ax_vec <= stop)
|
|
130
|
-
)[0]
|
|
124
|
+
inds = np.where(np.logical_and(self._state.ax_vec >= start, self._state.ax_vec <= stop))[0]
|
|
131
125
|
slices.append(np.s_[inds[0] : inds[-1] + 1])
|
|
132
126
|
if hasattr(target_axis, "data"):
|
|
133
127
|
if self._state.ax_vec.dtype.type is np.str_:
|
|
@@ -164,8 +158,7 @@ class RangedAggregateTransformer(
|
|
|
164
158
|
]
|
|
165
159
|
else:
|
|
166
160
|
out_data = [
|
|
167
|
-
agg_func(slice_along_axis(message.data, sl, axis=ax_idx), axis=ax_idx)
|
|
168
|
-
for sl in self._state.slices
|
|
161
|
+
agg_func(slice_along_axis(message.data, sl, axis=ax_idx), axis=ax_idx) for sl in self._state.slices
|
|
169
162
|
]
|
|
170
163
|
|
|
171
164
|
msg_out = replace(
|
|
@@ -187,11 +180,7 @@ class RangedAggregateTransformer(
|
|
|
187
180
|
return msg_out
|
|
188
181
|
|
|
189
182
|
|
|
190
|
-
class RangedAggregate(
|
|
191
|
-
BaseTransformerUnit[
|
|
192
|
-
RangedAggregateSettings, AxisArray, AxisArray, RangedAggregateTransformer
|
|
193
|
-
]
|
|
194
|
-
):
|
|
183
|
+
class RangedAggregate(BaseTransformerUnit[RangedAggregateSettings, AxisArray, AxisArray, RangedAggregateTransformer]):
|
|
195
184
|
SETTINGS = RangedAggregateSettings
|
|
196
185
|
|
|
197
186
|
|
|
@@ -212,9 +201,7 @@ def ranged_aggregate(
|
|
|
212
201
|
Returns:
|
|
213
202
|
:obj:`RangedAggregateTransformer`
|
|
214
203
|
"""
|
|
215
|
-
return RangedAggregateTransformer(
|
|
216
|
-
RangedAggregateSettings(axis=axis, bands=bands, operation=operation)
|
|
217
|
-
)
|
|
204
|
+
return RangedAggregateTransformer(RangedAggregateSettings(axis=axis, bands=bands, operation=operation))
|
|
218
205
|
|
|
219
206
|
|
|
220
207
|
class AggregateSettings(ez.Settings):
|
|
@@ -242,9 +229,7 @@ class AggregateTransformer(BaseTransformer[AggregateSettings, AxisArray, AxisArr
|
|
|
242
229
|
op = self.settings.operation
|
|
243
230
|
|
|
244
231
|
if op == AggregationFunction.NONE:
|
|
245
|
-
raise ValueError(
|
|
246
|
-
"AggregationFunction.NONE is not supported for full-axis aggregation"
|
|
247
|
-
)
|
|
232
|
+
raise ValueError("AggregationFunction.NONE is not supported for full-axis aggregation")
|
|
248
233
|
|
|
249
234
|
if op == AggregationFunction.TRAPEZOID:
|
|
250
235
|
# Trapezoid integration requires x-coordinates
|
|
@@ -276,9 +261,7 @@ class AggregateTransformer(BaseTransformer[AggregateSettings, AxisArray, AxisArr
|
|
|
276
261
|
)
|
|
277
262
|
|
|
278
263
|
|
|
279
|
-
class AggregateUnit(
|
|
280
|
-
BaseTransformerUnit[AggregateSettings, AxisArray, AxisArray, AggregateTransformer]
|
|
281
|
-
):
|
|
264
|
+
class AggregateUnit(BaseTransformerUnit[AggregateSettings, AxisArray, AxisArray, AggregateTransformer]):
|
|
282
265
|
"""Unit that aggregates an entire axis using a specified operation."""
|
|
283
266
|
|
|
284
267
|
SETTINGS = AggregateSettings
|
ezmsg/sigproc/bandpower.py
CHANGED
|
@@ -1,20 +1,20 @@
|
|
|
1
1
|
from dataclasses import field
|
|
2
2
|
|
|
3
3
|
import ezmsg.core as ez
|
|
4
|
+
from ezmsg.baseproc import (
|
|
5
|
+
BaseProcessor,
|
|
6
|
+
BaseStatefulProcessor,
|
|
7
|
+
BaseTransformerUnit,
|
|
8
|
+
CompositeProcessor,
|
|
9
|
+
)
|
|
4
10
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
5
11
|
|
|
6
|
-
from .spectrogram import SpectrogramSettings, SpectrogramTransformer
|
|
7
12
|
from .aggregate import (
|
|
8
13
|
AggregationFunction,
|
|
9
|
-
RangedAggregateTransformer,
|
|
10
14
|
RangedAggregateSettings,
|
|
15
|
+
RangedAggregateTransformer,
|
|
11
16
|
)
|
|
12
|
-
from .
|
|
13
|
-
BaseProcessor,
|
|
14
|
-
CompositeProcessor,
|
|
15
|
-
BaseStatefulProcessor,
|
|
16
|
-
BaseTransformerUnit,
|
|
17
|
-
)
|
|
17
|
+
from .spectrogram import SpectrogramSettings, SpectrogramTransformer
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
class BandPowerSettings(ez.Settings):
|
|
@@ -22,16 +22,12 @@ class BandPowerSettings(ez.Settings):
|
|
|
22
22
|
Settings for ``BandPower``.
|
|
23
23
|
"""
|
|
24
24
|
|
|
25
|
-
spectrogram_settings: SpectrogramSettings = field(
|
|
26
|
-
default_factory=SpectrogramSettings
|
|
27
|
-
)
|
|
25
|
+
spectrogram_settings: SpectrogramSettings = field(default_factory=SpectrogramSettings)
|
|
28
26
|
"""
|
|
29
27
|
Settings for spectrogram calculation.
|
|
30
28
|
"""
|
|
31
29
|
|
|
32
|
-
bands: list[tuple[float, float]] | None = field(
|
|
33
|
-
default_factory=lambda: [(17, 30), (70, 170)]
|
|
34
|
-
)
|
|
30
|
+
bands: list[tuple[float, float]] | None = field(default_factory=lambda: [(17, 30), (70, 170)])
|
|
35
31
|
"""
|
|
36
32
|
(min, max) tuples of band limits in Hz.
|
|
37
33
|
"""
|
|
@@ -46,9 +42,7 @@ class BandPowerTransformer(CompositeProcessor[BandPowerSettings, AxisArray, Axis
|
|
|
46
42
|
settings: BandPowerSettings,
|
|
47
43
|
) -> dict[str, BaseProcessor | BaseStatefulProcessor]:
|
|
48
44
|
return {
|
|
49
|
-
"spectrogram": SpectrogramTransformer(
|
|
50
|
-
settings=settings.spectrogram_settings
|
|
51
|
-
),
|
|
45
|
+
"spectrogram": SpectrogramTransformer(settings=settings.spectrogram_settings),
|
|
52
46
|
"aggregate": RangedAggregateTransformer(
|
|
53
47
|
settings=RangedAggregateSettings(
|
|
54
48
|
axis="freq",
|
|
@@ -59,9 +53,7 @@ class BandPowerTransformer(CompositeProcessor[BandPowerSettings, AxisArray, Axis
|
|
|
59
53
|
}
|
|
60
54
|
|
|
61
55
|
|
|
62
|
-
class BandPower(
|
|
63
|
-
BaseTransformerUnit[BandPowerSettings, AxisArray, AxisArray, BandPowerTransformer]
|
|
64
|
-
):
|
|
56
|
+
class BandPower(BaseTransformerUnit[BandPowerSettings, AxisArray, AxisArray, BandPowerTransformer]):
|
|
65
57
|
SETTINGS = BandPowerSettings
|
|
66
58
|
|
|
67
59
|
|