ezmsg-sigproc 1.8.2__py3-none-any.whl → 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/activation.py +36 -39
- ezmsg/sigproc/adaptive_lattice_notch.py +231 -0
- ezmsg/sigproc/affinetransform.py +169 -163
- ezmsg/sigproc/aggregate.py +133 -101
- ezmsg/sigproc/bandpower.py +64 -52
- ezmsg/sigproc/base.py +1242 -0
- ezmsg/sigproc/butterworthfilter.py +37 -33
- ezmsg/sigproc/cheby.py +29 -17
- ezmsg/sigproc/combfilter.py +163 -0
- ezmsg/sigproc/decimate.py +19 -10
- ezmsg/sigproc/detrend.py +29 -0
- ezmsg/sigproc/diff.py +81 -0
- ezmsg/sigproc/downsample.py +78 -84
- ezmsg/sigproc/ewma.py +197 -0
- ezmsg/sigproc/extract_axis.py +41 -0
- ezmsg/sigproc/filter.py +257 -141
- ezmsg/sigproc/filterbank.py +247 -199
- ezmsg/sigproc/math/abs.py +17 -22
- ezmsg/sigproc/math/clip.py +24 -24
- ezmsg/sigproc/math/difference.py +34 -30
- ezmsg/sigproc/math/invert.py +13 -25
- ezmsg/sigproc/math/log.py +28 -33
- ezmsg/sigproc/math/scale.py +18 -26
- ezmsg/sigproc/quantize.py +71 -0
- ezmsg/sigproc/resample.py +298 -0
- ezmsg/sigproc/sampler.py +241 -259
- ezmsg/sigproc/scaler.py +55 -218
- ezmsg/sigproc/signalinjector.py +52 -43
- ezmsg/sigproc/slicer.py +81 -89
- ezmsg/sigproc/spectrogram.py +77 -75
- ezmsg/sigproc/spectrum.py +203 -168
- ezmsg/sigproc/synth.py +546 -393
- ezmsg/sigproc/transpose.py +131 -0
- ezmsg/sigproc/util/asio.py +156 -0
- ezmsg/sigproc/util/message.py +31 -0
- ezmsg/sigproc/util/profile.py +55 -12
- ezmsg/sigproc/util/typeresolution.py +83 -0
- ezmsg/sigproc/wavelets.py +154 -153
- ezmsg/sigproc/window.py +269 -211
- {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.1.0.dist-info}/METADATA +2 -1
- ezmsg_sigproc-2.1.0.dist-info/RECORD +51 -0
- ezmsg_sigproc-1.8.2.dist-info/RECORD +0 -39
- {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.1.0.dist-info}/WHEEL +0 -0
- {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.1.0.dist-info}/licenses/LICENSE.txt +0 -0
ezmsg/sigproc/filterbank.py
CHANGED
|
@@ -10,11 +10,14 @@ import numpy.typing as npt
|
|
|
10
10
|
import ezmsg.core as ez
|
|
11
11
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
12
12
|
from ezmsg.util.messages.util import replace
|
|
13
|
-
from ezmsg.util.generator import consumer
|
|
14
13
|
|
|
15
|
-
from .base import
|
|
14
|
+
from .base import (
|
|
15
|
+
BaseStatefulTransformer,
|
|
16
|
+
BaseTransformerUnit,
|
|
17
|
+
processor_state,
|
|
18
|
+
)
|
|
16
19
|
from .spectrum import OptionsEnum
|
|
17
|
-
from .window import
|
|
20
|
+
from .window import WindowTransformer
|
|
18
21
|
|
|
19
22
|
|
|
20
23
|
class FilterbankMode(OptionsEnum):
|
|
@@ -34,248 +37,293 @@ class MinPhaseMode(OptionsEnum):
|
|
|
34
37
|
# HOMOMORPHICFULL = "Like HOMOMORPHIC, but uses the full number of taps and same magnitude"
|
|
35
38
|
|
|
36
39
|
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
mode: FilterbankMode = FilterbankMode.CONV
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
40
|
+
class FilterbankSettings(ez.Settings):
|
|
41
|
+
kernels: list[npt.NDArray] | tuple[npt.NDArray, ...]
|
|
42
|
+
|
|
43
|
+
mode: FilterbankMode = FilterbankMode.CONV
|
|
44
|
+
"""
|
|
45
|
+
"conv", "fft", or "auto". If "auto", the mode is determined by the size of the input data.
|
|
46
|
+
fft mode is more efficient for long kernels. However, fft mode uses non-overlapping windows and will
|
|
47
|
+
incur a delay equal to the window length, which is larger than the largest kernel.
|
|
48
|
+
conv mode is less efficient but will return data for every incoming chunk regardless of how small it is
|
|
49
|
+
and thus can provide shorter latency updates.
|
|
45
50
|
"""
|
|
46
|
-
Perform multiple (direct or fft) convolutions on a signal using a bank of kernels.
|
|
47
|
-
This is intended to be used during online processing, therefore both direct and fft convolutions
|
|
48
|
-
use the overlap-add method.
|
|
49
|
-
Args:
|
|
50
|
-
kernels:
|
|
51
|
-
mode: "conv", "fft", or "auto". If "auto", the mode is determined by the size of the input data.
|
|
52
|
-
fft mode is more efficient for long kernels. However, fft mode uses non-overlapping windows and will
|
|
53
|
-
incur a delay equal to the window length, which is larger than the largest kernel.
|
|
54
|
-
conv mode is less efficient but will return data for every incoming chunk regardless of how small it is
|
|
55
|
-
and thus can provide shorter latency updates.
|
|
56
|
-
min_phase: If not None, convert the kernels to minimum-phase equivalents. Valid options are
|
|
57
|
-
'hilbert', 'homomorphic', and 'homomorphic-full'. Complex filters not supported.
|
|
58
|
-
See `scipy.signal.minimum_phase` for details.
|
|
59
|
-
axis: The name of the axis to operate on. This should usually be "time".
|
|
60
|
-
new_axis: The name of the new axis corresponding to the kernel index.
|
|
61
|
-
|
|
62
|
-
Returns: A primed generator that, when passed an input message via `.send(msg)`, yields an :obj:`AxisArray`
|
|
63
|
-
with the data payload containing the absolute value of the input :obj:`AxisArray` data.
|
|
64
51
|
|
|
52
|
+
min_phase: MinPhaseMode = MinPhaseMode.NONE
|
|
53
|
+
"""
|
|
54
|
+
If not None, convert the kernels to minimum-phase equivalents. Valid options are
|
|
55
|
+
'hilbert', 'homomorphic', and 'homomorphic-full'. Complex filters not supported.
|
|
56
|
+
See `scipy.signal.minimum_phase` for details.
|
|
65
57
|
"""
|
|
66
|
-
msg_out: AxisArray | None = None
|
|
67
58
|
|
|
68
|
-
|
|
59
|
+
axis: str = "time"
|
|
60
|
+
"""The name of the axis to operate on. This should usually be "time"."""
|
|
61
|
+
|
|
62
|
+
new_axis: str = "kernel"
|
|
63
|
+
"""The name of the new axis corresponding to the kernel index."""
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@processor_state
|
|
67
|
+
class FilterbankState:
|
|
68
|
+
tail: npt.NDArray | None = None
|
|
69
69
|
template: AxisArray | None = None
|
|
70
|
+
dest_arr: npt.NDArray | None = None
|
|
71
|
+
prep_kerns: npt.NDArray | list[npt.NDArray] | None = None
|
|
72
|
+
windower: WindowTransformer | None = None
|
|
73
|
+
fft: typing.Callable | None = None
|
|
74
|
+
ifft: typing.Callable | None = None
|
|
75
|
+
nfft: int | None = None
|
|
76
|
+
infft: int | None = None
|
|
77
|
+
overlap: int | None = None
|
|
78
|
+
mode: FilterbankMode | None = None
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class FilterbankTransformer(
|
|
82
|
+
BaseStatefulTransformer[FilterbankSettings, AxisArray, AxisArray, FilterbankState]
|
|
83
|
+
):
|
|
84
|
+
def _hash_message(self, message: AxisArray) -> int:
|
|
85
|
+
axis = self.settings.axis or message.dims[0]
|
|
86
|
+
gain = message.axes[axis].gain if axis in message.axes else 1.0
|
|
87
|
+
targ_ax_ix = message.get_axis_idx(axis)
|
|
88
|
+
in_shape = (
|
|
89
|
+
message.data.shape[:targ_ax_ix] + message.data.shape[targ_ax_ix + 1 :]
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
return hash(
|
|
93
|
+
(
|
|
94
|
+
message.key,
|
|
95
|
+
gain
|
|
96
|
+
if self.settings.mode in [FilterbankMode.FFT, FilterbankMode.AUTO]
|
|
97
|
+
else None,
|
|
98
|
+
message.data.dtype.kind,
|
|
99
|
+
in_shape,
|
|
100
|
+
)
|
|
101
|
+
)
|
|
70
102
|
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
"shape": None,
|
|
78
|
-
}
|
|
79
|
-
|
|
80
|
-
while True:
|
|
81
|
-
msg_in: AxisArray = yield msg_out
|
|
82
|
-
|
|
83
|
-
axis = axis or msg_in.dims[0]
|
|
84
|
-
gain = msg_in.axes[axis].gain if axis in msg_in.axes else 1.0
|
|
85
|
-
targ_ax_ix = msg_in.get_axis_idx(axis)
|
|
86
|
-
in_shape = msg_in.data.shape[:targ_ax_ix] + msg_in.data.shape[targ_ax_ix + 1 :]
|
|
87
|
-
|
|
88
|
-
b_reset = msg_in.key != check_input["key"]
|
|
89
|
-
b_reset = b_reset or (
|
|
90
|
-
gain != check_input["gain"]
|
|
91
|
-
and mode in [FilterbankMode.FFT, FilterbankMode.AUTO]
|
|
103
|
+
def _reset_state(self, message: AxisArray) -> None:
|
|
104
|
+
axis = self.settings.axis or message.dims[0]
|
|
105
|
+
gain = message.axes[axis].gain if axis in message.axes else 1.0
|
|
106
|
+
targ_ax_ix = message.get_axis_idx(axis)
|
|
107
|
+
in_shape = (
|
|
108
|
+
message.data.shape[:targ_ax_ix] + message.data.shape[targ_ax_ix + 1 :]
|
|
92
109
|
)
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
if
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
110
|
+
|
|
111
|
+
kernels = self.settings.kernels
|
|
112
|
+
if self.settings.min_phase != MinPhaseMode.NONE:
|
|
113
|
+
method, half = {
|
|
114
|
+
MinPhaseMode.HILBERT: ("hilbert", False),
|
|
115
|
+
MinPhaseMode.HOMOMORPHIC: ("homomorphic", False),
|
|
116
|
+
# MinPhaseMode.HOMOMORPHICFULL: ("homomorphic", True),
|
|
117
|
+
}[self.settings.min_phase]
|
|
118
|
+
kernels = [sps.minimum_phase(k, method=method) for k in kernels]
|
|
119
|
+
|
|
120
|
+
# Determine if this will be operating with complex data.
|
|
121
|
+
b_complex = message.data.dtype.kind == "c" or any(
|
|
122
|
+
[_.dtype.kind == "c" for _ in kernels]
|
|
123
|
+
)
|
|
124
|
+
|
|
125
|
+
# Calculate window_dur, window_shift, nfft
|
|
126
|
+
max_kernel_len = max([_.size for _ in kernels])
|
|
127
|
+
# From sps._calc_oa_lens, where s2=max_kernel_len,:
|
|
128
|
+
# fallback_nfft = n_input + max_kernel_len - 1, but n_input is unbound.
|
|
129
|
+
self._state.overlap = max_kernel_len - 1
|
|
130
|
+
|
|
131
|
+
# Prepare previous iteration's overlap tail to add to input -- all zeros.
|
|
132
|
+
tail_shape = in_shape + (len(kernels), self._state.overlap)
|
|
133
|
+
self._state.tail = np.zeros(
|
|
134
|
+
tail_shape, dtype="complex" if b_complex else "float"
|
|
135
|
+
)
|
|
136
|
+
|
|
137
|
+
# Prepare output template -- kernels axis immediately before the target axis
|
|
138
|
+
dummy_shape = in_shape + (len(kernels), 0)
|
|
139
|
+
self._state.template = AxisArray(
|
|
140
|
+
data=np.zeros(dummy_shape, dtype="complex" if b_complex else "float"),
|
|
141
|
+
dims=message.dims[:targ_ax_ix]
|
|
142
|
+
+ message.dims[targ_ax_ix + 1 :]
|
|
143
|
+
+ [self.settings.new_axis, axis],
|
|
144
|
+
axes=message.axes.copy(),
|
|
145
|
+
key=message.key,
|
|
146
|
+
)
|
|
147
|
+
|
|
148
|
+
# Determine optimal mode. Assumes 100 msec chunks.
|
|
149
|
+
self._state.mode = self.settings.mode
|
|
150
|
+
if self._state.mode == FilterbankMode.AUTO:
|
|
151
|
+
# concatenate kernels into 1 mega kernel then check what's faster.
|
|
152
|
+
# Will typically return fft when combined kernel length is > 1500.
|
|
153
|
+
concat_kernel = np.concatenate(kernels)
|
|
154
|
+
n_dummy = max(2 * len(concat_kernel), int(0.1 / gain))
|
|
155
|
+
dummy_arr = np.zeros(n_dummy)
|
|
156
|
+
self._state.mode = (
|
|
157
|
+
FilterbankMode.CONV
|
|
158
|
+
if sps.choose_conv_method(dummy_arr, concat_kernel, mode="full")
|
|
159
|
+
== "direct"
|
|
160
|
+
else FilterbankMode.FFT
|
|
117
161
|
)
|
|
118
162
|
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
# Prepare previous iteration's overlap tail to add to input -- all zeros.
|
|
126
|
-
tail_shape = in_shape + (len(kernels), overlap)
|
|
127
|
-
tail = np.zeros(tail_shape, dtype="complex" if b_complex else "float")
|
|
128
|
-
|
|
129
|
-
# Prepare output template -- kernels axis immediately before the target axis
|
|
130
|
-
dummy_shape = in_shape + (len(kernels), 0)
|
|
131
|
-
template = AxisArray(
|
|
132
|
-
data=np.zeros(dummy_shape, dtype="complex" if b_complex else "float"),
|
|
133
|
-
dims=msg_in.dims[:targ_ax_ix]
|
|
134
|
-
+ msg_in.dims[targ_ax_ix + 1 :]
|
|
135
|
-
+ [new_axis, axis],
|
|
136
|
-
axes=msg_in.axes.copy(), # We do not have info for kernel/filter axis :(.
|
|
137
|
-
key=msg_in.key,
|
|
163
|
+
if self._state.mode == FilterbankMode.CONV:
|
|
164
|
+
# Preallocate memory for convolution result and overlap-add
|
|
165
|
+
dest_shape = in_shape + (
|
|
166
|
+
len(kernels),
|
|
167
|
+
self._state.overlap + message.data.shape[targ_ax_ix],
|
|
138
168
|
)
|
|
169
|
+
self._state.dest_arr = np.zeros(
|
|
170
|
+
dest_shape, dtype="complex" if b_complex else "float"
|
|
171
|
+
)
|
|
172
|
+
self._state.prep_kerns = kernels
|
|
173
|
+
else: # FFT mode
|
|
174
|
+
# Calculate optimal nfft and windowing size.
|
|
175
|
+
opt_size = (
|
|
176
|
+
-self._state.overlap
|
|
177
|
+
* lambertw(-1 / (2 * math.e * self._state.overlap), k=-1).real
|
|
178
|
+
)
|
|
179
|
+
self._state.nfft = sp_fft.next_fast_len(math.ceil(opt_size))
|
|
180
|
+
win_len = self._state.nfft - self._state.overlap
|
|
181
|
+
# infft same as nfft. Keeping as separate variable because I might need it again.
|
|
182
|
+
self._state.infft = win_len + self._state.overlap
|
|
183
|
+
|
|
184
|
+
# Create windowing node.
|
|
185
|
+
# Note: We could do windowing manually to avoid the overhead of the message structure,
|
|
186
|
+
# but windowing is difficult to do correctly, so we lean on the heavily-tested `windowing` generator.
|
|
187
|
+
win_dur = win_len * gain
|
|
188
|
+
self._state.windower = WindowTransformer(
|
|
189
|
+
axis=axis,
|
|
190
|
+
newaxis="win",
|
|
191
|
+
window_dur=win_dur,
|
|
192
|
+
window_shift=win_dur,
|
|
193
|
+
zero_pad_until="none",
|
|
194
|
+
)
|
|
195
|
+
|
|
196
|
+
# Windowing output has an extra "win" dimension, so we need our tail to match.
|
|
197
|
+
self._state.tail = np.expand_dims(self._state.tail, -2)
|
|
139
198
|
|
|
140
|
-
#
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
mode = sps.choose_conv_method(dummy_arr, concat_kernel, mode="full")
|
|
148
|
-
mode = FilterbankMode.CONV if mode == "direct" else FilterbankMode.FFT
|
|
149
|
-
|
|
150
|
-
if mode == FilterbankMode.CONV:
|
|
151
|
-
# Preallocate memory for convolution result and overlap-add
|
|
152
|
-
dest_shape = in_shape + (
|
|
153
|
-
len(kernels),
|
|
154
|
-
overlap + msg_in.data.shape[targ_ax_ix],
|
|
199
|
+
# Prepare fft functions
|
|
200
|
+
# Note: We could instead use `spectrum` but this adds overhead in creating the message structure
|
|
201
|
+
# for a rather simple calculation. We may revisit if `spectrum` gets additional features, such as
|
|
202
|
+
# more fft backends.
|
|
203
|
+
if b_complex:
|
|
204
|
+
self._state.fft = functools.partial(
|
|
205
|
+
sp_fft.fft, n=self._state.nfft, norm="backward"
|
|
155
206
|
)
|
|
156
|
-
|
|
157
|
-
|
|
207
|
+
self._state.ifft = functools.partial(
|
|
208
|
+
sp_fft.ifft, n=self._state.infft, norm="backward"
|
|
158
209
|
)
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
# infft same as nfft. Keeping as separate variable because I might need it again.
|
|
166
|
-
infft = win_len + overlap
|
|
167
|
-
|
|
168
|
-
# Create windowing node.
|
|
169
|
-
# Note: We could do windowing manually to avoid the overhead of the message structure,
|
|
170
|
-
# but windowing is difficult to do correctly, so we lean on the heavily-tested `windowing` generator.
|
|
171
|
-
win_dur = win_len * gain
|
|
172
|
-
wingen = windowing(
|
|
173
|
-
axis=axis,
|
|
174
|
-
newaxis="win", # Big data chunks might yield more than 1 window.
|
|
175
|
-
window_dur=win_dur,
|
|
176
|
-
window_shift=win_dur, # Tumbling (not sliding) windows expected!
|
|
177
|
-
zero_pad_until="none",
|
|
210
|
+
else:
|
|
211
|
+
self._state.fft = functools.partial(
|
|
212
|
+
sp_fft.rfft, n=self._state.nfft, norm="backward"
|
|
213
|
+
)
|
|
214
|
+
self._state.ifft = functools.partial(
|
|
215
|
+
sp_fft.irfft, n=self._state.infft, norm="backward"
|
|
178
216
|
)
|
|
179
217
|
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
fft = functools.partial(sp_fft.fft, n=nfft, norm="backward")
|
|
189
|
-
ifft = functools.partial(sp_fft.ifft, n=infft, norm="backward")
|
|
190
|
-
else:
|
|
191
|
-
fft = functools.partial(sp_fft.rfft, n=nfft, norm="backward")
|
|
192
|
-
ifft = functools.partial(sp_fft.irfft, n=infft, norm="backward")
|
|
193
|
-
|
|
194
|
-
# Calculate fft of kernels
|
|
195
|
-
prep_kerns = np.array([fft(_) for _ in kernels])
|
|
196
|
-
prep_kerns = np.expand_dims(prep_kerns, -2)
|
|
197
|
-
# TODO: If fft_kernels have significant stretches of zeros, convert to sparse array.
|
|
218
|
+
# Calculate fft of kernels
|
|
219
|
+
self._state.prep_kerns = np.array([self._state.fft(_) for _ in kernels])
|
|
220
|
+
self._state.prep_kerns = np.expand_dims(self._state.prep_kerns, -2)
|
|
221
|
+
# TODO: If fft_kernels have significant stretches of zeros, convert to sparse array.
|
|
222
|
+
|
|
223
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
224
|
+
axis = self.settings.axis or message.dims[0]
|
|
225
|
+
targ_ax_ix = message.get_axis_idx(axis)
|
|
198
226
|
|
|
199
227
|
# Make sure target axis is in -1th position.
|
|
200
|
-
if targ_ax_ix != (
|
|
201
|
-
in_dat = np.moveaxis(
|
|
202
|
-
if mode == FilterbankMode.FFT:
|
|
203
|
-
# Fix
|
|
228
|
+
if targ_ax_ix != (message.data.ndim - 1):
|
|
229
|
+
in_dat = np.moveaxis(message.data, targ_ax_ix, -1)
|
|
230
|
+
if self._state.mode == FilterbankMode.FFT:
|
|
231
|
+
# Fix message.dims because we will pass it to windower
|
|
204
232
|
move_dims = (
|
|
205
|
-
|
|
233
|
+
message.dims[:targ_ax_ix] + message.dims[targ_ax_ix + 1 :] + [axis]
|
|
206
234
|
)
|
|
207
|
-
|
|
235
|
+
message = replace(message, data=in_dat, dims=move_dims)
|
|
208
236
|
else:
|
|
209
|
-
in_dat =
|
|
210
|
-
|
|
211
|
-
if mode == FilterbankMode.CONV:
|
|
212
|
-
n_dest = in_dat.shape[-1] + overlap
|
|
213
|
-
if dest_arr.shape[-1] < n_dest:
|
|
214
|
-
pad = np.zeros(
|
|
215
|
-
|
|
216
|
-
|
|
237
|
+
in_dat = message.data
|
|
238
|
+
|
|
239
|
+
if self._state.mode == FilterbankMode.CONV:
|
|
240
|
+
n_dest = in_dat.shape[-1] + self._state.overlap
|
|
241
|
+
if self._state.dest_arr.shape[-1] < n_dest:
|
|
242
|
+
pad = np.zeros(
|
|
243
|
+
self._state.dest_arr.shape[:-1]
|
|
244
|
+
+ (n_dest - self._state.dest_arr.shape[-1],)
|
|
245
|
+
)
|
|
246
|
+
self._state.dest_arr = np.concatenate(
|
|
247
|
+
[self._state.dest_arr, pad], axis=-1
|
|
248
|
+
)
|
|
249
|
+
self._state.dest_arr.fill(0)
|
|
250
|
+
|
|
217
251
|
# Note: I tried several alternatives to this loop; all were slower than this.
|
|
218
252
|
# numba.jit; stride_tricks + np.einsum; threading. Latter might be better with Python 3.13.
|
|
219
|
-
for k_ix, k in enumerate(
|
|
253
|
+
for k_ix, k in enumerate(self._state.prep_kerns):
|
|
220
254
|
n_out = in_dat.shape[-1] + k.shape[-1] - 1
|
|
221
|
-
dest_arr[..., k_ix, :n_out] = np.apply_along_axis(
|
|
255
|
+
self._state.dest_arr[..., k_ix, :n_out] = np.apply_along_axis(
|
|
222
256
|
np.convolve, -1, in_dat, k, mode="full"
|
|
223
257
|
)
|
|
224
|
-
dest_arr[..., :overlap] += tail
|
|
225
|
-
new_tail = dest_arr[..., in_dat.shape[-1] : n_dest]
|
|
258
|
+
self._state.dest_arr[..., : self._state.overlap] += self._state.tail
|
|
259
|
+
new_tail = self._state.dest_arr[..., in_dat.shape[-1] : n_dest]
|
|
226
260
|
if new_tail.size > 0:
|
|
227
261
|
# COPY overlap for next iteration
|
|
228
|
-
tail = new_tail.copy()
|
|
229
|
-
res = dest_arr[..., : in_dat.shape[-1]].copy()
|
|
230
|
-
|
|
262
|
+
self._state.tail = new_tail.copy()
|
|
263
|
+
res = self._state.dest_arr[..., : in_dat.shape[-1]].copy()
|
|
264
|
+
else: # FFT mode
|
|
231
265
|
# Slice into non-overlapping windows
|
|
232
|
-
win_msg =
|
|
233
|
-
# Calculate
|
|
234
|
-
spec_dat = fft(win_msg.data, axis=-1)
|
|
266
|
+
win_msg = self._state.windower.send(message)
|
|
267
|
+
# Calculate spectrum of each window
|
|
268
|
+
spec_dat = self._state.fft(win_msg.data, axis=-1)
|
|
235
269
|
# Insert axis for filters
|
|
236
270
|
spec_dat = np.expand_dims(spec_dat, -3)
|
|
237
271
|
|
|
238
272
|
# Do the FFT convolution
|
|
239
273
|
# TODO: handle fft_kernels being sparse. Maybe need np.dot.
|
|
240
|
-
conv_spec = spec_dat * prep_kerns
|
|
241
|
-
overlapped = ifft(conv_spec, axis=-1)
|
|
274
|
+
conv_spec = spec_dat * self._state.prep_kerns
|
|
275
|
+
overlapped = self._state.ifft(conv_spec, axis=-1)
|
|
242
276
|
|
|
243
277
|
# Do the overlap-add on the `axis` axis
|
|
244
278
|
# Previous iteration's tail:
|
|
245
|
-
overlapped[..., :1, :overlap] += tail
|
|
279
|
+
overlapped[..., :1, : self._state.overlap] += self._state.tail
|
|
246
280
|
# window-to-window:
|
|
247
|
-
overlapped[..., 1:, :overlap] += overlapped[
|
|
281
|
+
overlapped[..., 1:, : self._state.overlap] += overlapped[
|
|
282
|
+
..., :-1, -self._state.overlap :
|
|
283
|
+
]
|
|
248
284
|
# Save tail:
|
|
249
|
-
new_tail = overlapped[..., -1:, -overlap:]
|
|
285
|
+
new_tail = overlapped[..., -1:, -self._state.overlap :]
|
|
250
286
|
if new_tail.size > 0:
|
|
251
287
|
# All of the above code works if input is size-zero, but we don't want to save a zero-size tail.
|
|
252
|
-
tail = new_tail
|
|
288
|
+
self._state.tail = new_tail
|
|
253
289
|
# Concat over win axis, without overlap.
|
|
254
|
-
res = overlapped[...,
|
|
290
|
+
res = overlapped[..., : -self._state.overlap].reshape(
|
|
291
|
+
overlapped.shape[:-2] + (-1,)
|
|
292
|
+
)
|
|
255
293
|
|
|
256
|
-
|
|
257
|
-
template,
|
|
294
|
+
return replace(
|
|
295
|
+
self._state.template,
|
|
296
|
+
data=res,
|
|
297
|
+
axes={**self._state.template.axes, axis: message.axes[axis]},
|
|
258
298
|
)
|
|
259
299
|
|
|
260
300
|
|
|
261
|
-
class
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
min_phase: MinPhaseMode = MinPhaseMode.NONE
|
|
265
|
-
axis: str = "time"
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
class Filterbank(GenAxisArray):
|
|
269
|
-
"""Unit for :obj:`spectrum`"""
|
|
270
|
-
|
|
301
|
+
class Filterbank(
|
|
302
|
+
BaseTransformerUnit[FilterbankSettings, AxisArray, AxisArray, FilterbankTransformer]
|
|
303
|
+
):
|
|
271
304
|
SETTINGS = FilterbankSettings
|
|
272
305
|
|
|
273
|
-
INPUT_SETTINGS = ez.InputStream(FilterbankSettings)
|
|
274
306
|
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
307
|
+
def filterbank(
|
|
308
|
+
kernels: list[npt.NDArray] | tuple[npt.NDArray, ...],
|
|
309
|
+
mode: FilterbankMode = FilterbankMode.CONV,
|
|
310
|
+
min_phase: MinPhaseMode = MinPhaseMode.NONE,
|
|
311
|
+
axis: str = "time",
|
|
312
|
+
new_axis: str = "kernel",
|
|
313
|
+
) -> FilterbankTransformer:
|
|
314
|
+
"""
|
|
315
|
+
Perform multiple (direct or fft) convolutions on a signal using a bank of kernels.
|
|
316
|
+
This is intended to be used during online processing, therefore both direct and fft convolutions
|
|
317
|
+
use the overlap-add method.
|
|
318
|
+
|
|
319
|
+
Returns: :obj:`FilterbankTransformer`.
|
|
320
|
+
"""
|
|
321
|
+
return FilterbankTransformer(
|
|
322
|
+
settings=FilterbankSettings(
|
|
323
|
+
kernels=kernels,
|
|
324
|
+
mode=mode,
|
|
325
|
+
min_phase=min_phase,
|
|
326
|
+
axis=axis,
|
|
327
|
+
new_axis=new_axis,
|
|
281
328
|
)
|
|
329
|
+
)
|
ezmsg/sigproc/math/abs.py
CHANGED
|
@@ -1,34 +1,29 @@
|
|
|
1
|
-
import typing
|
|
2
|
-
|
|
3
1
|
import numpy as np
|
|
4
|
-
import ezmsg.core as ez
|
|
5
|
-
from ezmsg.util.generator import consumer
|
|
6
2
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
3
|
from ezmsg.util.messages.util import replace
|
|
8
4
|
|
|
9
|
-
from ..base import
|
|
5
|
+
from ..base import BaseTransformer, BaseTransformerUnit
|
|
10
6
|
|
|
11
7
|
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
"""
|
|
15
|
-
Take the absolute value of the data. See :obj:`np.abs` for more details.
|
|
8
|
+
class AbsSettings:
|
|
9
|
+
pass
|
|
16
10
|
|
|
17
|
-
Returns: A primed generator that, when passed an input message via `.send(msg)`, yields an :obj:`AxisArray`
|
|
18
|
-
with the data payload containing the absolute value of the input :obj:`AxisArray` data.
|
|
19
|
-
"""
|
|
20
|
-
msg_out = AxisArray(np.array([]), dims=[""])
|
|
21
|
-
while True:
|
|
22
|
-
msg_in: AxisArray = yield msg_out
|
|
23
|
-
msg_out = replace(msg_in, data=np.abs(msg_in.data))
|
|
24
11
|
|
|
12
|
+
class AbsTransformer(BaseTransformer[None, AxisArray, AxisArray]):
|
|
13
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
14
|
+
return replace(message, data=np.abs(message.data))
|
|
25
15
|
|
|
26
|
-
class AbsSettings(ez.Settings):
|
|
27
|
-
pass
|
|
28
16
|
|
|
17
|
+
class Abs(
|
|
18
|
+
BaseTransformerUnit[None, AxisArray, AxisArray, AbsTransformer]
|
|
19
|
+
): ... # SETTINGS = None
|
|
29
20
|
|
|
30
|
-
class Abs(GenAxisArray):
|
|
31
|
-
SETTINGS = AbsSettings
|
|
32
21
|
|
|
33
|
-
|
|
34
|
-
|
|
22
|
+
def abs() -> AbsTransformer:
|
|
23
|
+
"""
|
|
24
|
+
Take the absolute value of the data. See :obj:`np.abs` for more details.
|
|
25
|
+
|
|
26
|
+
Returns: :obj:`AbsTransformer`.
|
|
27
|
+
|
|
28
|
+
"""
|
|
29
|
+
return AbsTransformer()
|
ezmsg/sigproc/math/clip.py
CHANGED
|
@@ -1,16 +1,32 @@
|
|
|
1
|
-
import typing
|
|
2
|
-
|
|
3
1
|
import numpy as np
|
|
4
2
|
import ezmsg.core as ez
|
|
5
|
-
from ezmsg.util.generator import consumer
|
|
6
3
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
4
|
from ezmsg.util.messages.util import replace
|
|
8
5
|
|
|
9
|
-
from ..base import
|
|
6
|
+
from ..base import BaseTransformer, BaseTransformerUnit
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ClipSettings(ez.Settings):
|
|
10
|
+
a_min: float
|
|
11
|
+
"""Lower clip bound."""
|
|
10
12
|
|
|
13
|
+
a_max: float
|
|
14
|
+
"""Upper clip bound."""
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ClipTransformer(BaseTransformer[ClipSettings, AxisArray, AxisArray]):
|
|
18
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
19
|
+
return replace(
|
|
20
|
+
message,
|
|
21
|
+
data=np.clip(message.data, self.settings.a_min, self.settings.a_max),
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class Clip(BaseTransformerUnit[ClipSettings, AxisArray, AxisArray, ClipTransformer]):
|
|
26
|
+
SETTINGS = ClipSettings
|
|
11
27
|
|
|
12
|
-
|
|
13
|
-
def clip(a_min: float, a_max: float) ->
|
|
28
|
+
|
|
29
|
+
def clip(a_min: float, a_max: float) -> ClipTransformer:
|
|
14
30
|
"""
|
|
15
31
|
Clips the data to be within the specified range. See :obj:`np.clip` for more details.
|
|
16
32
|
|
|
@@ -18,23 +34,7 @@ def clip(a_min: float, a_max: float) -> typing.Generator[AxisArray, AxisArray, N
|
|
|
18
34
|
a_min: Lower clip bound
|
|
19
35
|
a_max: Upper clip bound
|
|
20
36
|
|
|
21
|
-
Returns:
|
|
22
|
-
with the data payload containing the clipped version of the input :obj:`AxisArray` data.
|
|
37
|
+
Returns: :obj:`ClipTransformer`.
|
|
23
38
|
|
|
24
39
|
"""
|
|
25
|
-
|
|
26
|
-
while True:
|
|
27
|
-
msg_in: AxisArray = yield msg_out
|
|
28
|
-
msg_out = replace(msg_in, data=np.clip(msg_in.data, a_min, a_max))
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
class ClipSettings(ez.Settings):
|
|
32
|
-
a_min: float
|
|
33
|
-
a_max: float
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
class Clip(GenAxisArray):
|
|
37
|
-
SETTINGS = ClipSettings
|
|
38
|
-
|
|
39
|
-
def construct_generator(self):
|
|
40
|
-
self.STATE.gen = clip(a_min=self.SETTINGS.a_min, a_max=self.SETTINGS.a_max)
|
|
40
|
+
return ClipTransformer(ClipSettings(a_min=a_min, a_max=a_max))
|