ezmsg-sigproc 1.7.0__py3-none-any.whl → 2.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +22 -4
- ezmsg/sigproc/activation.py +31 -40
- ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
- ezmsg/sigproc/affinetransform.py +171 -169
- ezmsg/sigproc/aggregate.py +190 -97
- ezmsg/sigproc/bandpower.py +60 -55
- ezmsg/sigproc/base.py +143 -33
- ezmsg/sigproc/butterworthfilter.py +34 -38
- ezmsg/sigproc/butterworthzerophase.py +305 -0
- ezmsg/sigproc/cheby.py +23 -17
- ezmsg/sigproc/combfilter.py +160 -0
- ezmsg/sigproc/coordinatespaces.py +159 -0
- ezmsg/sigproc/decimate.py +15 -10
- ezmsg/sigproc/denormalize.py +78 -0
- ezmsg/sigproc/detrend.py +28 -0
- ezmsg/sigproc/diff.py +82 -0
- ezmsg/sigproc/downsample.py +72 -81
- ezmsg/sigproc/ewma.py +217 -0
- ezmsg/sigproc/ewmfilter.py +1 -1
- ezmsg/sigproc/extract_axis.py +39 -0
- ezmsg/sigproc/fbcca.py +307 -0
- ezmsg/sigproc/filter.py +254 -148
- ezmsg/sigproc/filterbank.py +226 -214
- ezmsg/sigproc/filterbankdesign.py +129 -0
- ezmsg/sigproc/fir_hilbert.py +336 -0
- ezmsg/sigproc/fir_pmc.py +209 -0
- ezmsg/sigproc/firfilter.py +117 -0
- ezmsg/sigproc/gaussiansmoothing.py +89 -0
- ezmsg/sigproc/kaiser.py +106 -0
- ezmsg/sigproc/linear.py +120 -0
- ezmsg/sigproc/math/abs.py +23 -22
- ezmsg/sigproc/math/add.py +120 -0
- ezmsg/sigproc/math/clip.py +33 -25
- ezmsg/sigproc/math/difference.py +117 -43
- ezmsg/sigproc/math/invert.py +18 -25
- ezmsg/sigproc/math/log.py +38 -33
- ezmsg/sigproc/math/scale.py +24 -25
- ezmsg/sigproc/messages.py +1 -2
- ezmsg/sigproc/quantize.py +68 -0
- ezmsg/sigproc/resample.py +278 -0
- ezmsg/sigproc/rollingscaler.py +232 -0
- ezmsg/sigproc/sampler.py +209 -254
- ezmsg/sigproc/scaler.py +93 -218
- ezmsg/sigproc/signalinjector.py +44 -43
- ezmsg/sigproc/slicer.py +74 -102
- ezmsg/sigproc/spectral.py +3 -3
- ezmsg/sigproc/spectrogram.py +70 -70
- ezmsg/sigproc/spectrum.py +187 -173
- ezmsg/sigproc/transpose.py +134 -0
- ezmsg/sigproc/util/__init__.py +0 -0
- ezmsg/sigproc/util/asio.py +25 -0
- ezmsg/sigproc/util/axisarray_buffer.py +365 -0
- ezmsg/sigproc/util/buffer.py +449 -0
- ezmsg/sigproc/util/message.py +17 -0
- ezmsg/sigproc/util/profile.py +23 -0
- ezmsg/sigproc/util/sparse.py +115 -0
- ezmsg/sigproc/util/typeresolution.py +17 -0
- ezmsg/sigproc/wavelets.py +147 -154
- ezmsg/sigproc/window.py +248 -210
- ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
- {ezmsg_sigproc-1.7.0.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -1
- ezmsg/sigproc/synth.py +0 -621
- ezmsg_sigproc-1.7.0.dist-info/METADATA +0 -58
- ezmsg_sigproc-1.7.0.dist-info/RECORD +0 -36
- /ezmsg_sigproc-1.7.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/filterbank.py
CHANGED
|
@@ -2,19 +2,22 @@ import functools
|
|
|
2
2
|
import math
|
|
3
3
|
import typing
|
|
4
4
|
|
|
5
|
+
import ezmsg.core as ez
|
|
5
6
|
import numpy as np
|
|
6
|
-
import scipy.signal as sps
|
|
7
|
-
import scipy.fft as sp_fft
|
|
8
|
-
from scipy.special import lambertw
|
|
9
7
|
import numpy.typing as npt
|
|
10
|
-
import
|
|
8
|
+
import scipy.fft as sp_fft
|
|
9
|
+
import scipy.signal as sps
|
|
10
|
+
from ezmsg.baseproc import (
|
|
11
|
+
BaseStatefulTransformer,
|
|
12
|
+
BaseTransformerUnit,
|
|
13
|
+
processor_state,
|
|
14
|
+
)
|
|
11
15
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
12
16
|
from ezmsg.util.messages.util import replace
|
|
13
|
-
from
|
|
17
|
+
from scipy.special import lambertw
|
|
14
18
|
|
|
15
|
-
from .base import GenAxisArray
|
|
16
19
|
from .spectrum import OptionsEnum
|
|
17
|
-
from .window import
|
|
20
|
+
from .window import WindowTransformer
|
|
18
21
|
|
|
19
22
|
|
|
20
23
|
class FilterbankMode(OptionsEnum):
|
|
@@ -29,252 +32,261 @@ class MinPhaseMode(OptionsEnum):
|
|
|
29
32
|
"""The mode of operation for the filterbank."""
|
|
30
33
|
|
|
31
34
|
NONE = "No kernel modification"
|
|
32
|
-
HILBERT =
|
|
33
|
-
|
|
35
|
+
HILBERT = (
|
|
36
|
+
"Hilbert Method; designed to be used with equiripple filters (e.g., from remez) with unity or zero gain regions"
|
|
37
|
+
)
|
|
38
|
+
HOMOMORPHIC = (
|
|
39
|
+
"Works best with filters with an odd number of taps, and the resulting minimum phase filter "
|
|
40
|
+
"will have a magnitude response that approximates the square root of the original filter’s "
|
|
41
|
+
"magnitude response using half the number of taps"
|
|
42
|
+
)
|
|
34
43
|
# HOMOMORPHICFULL = "Like HOMOMORPHIC, but uses the full number of taps and same magnitude"
|
|
35
44
|
|
|
36
45
|
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
mode: FilterbankMode = FilterbankMode.CONV
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
46
|
+
class FilterbankSettings(ez.Settings):
|
|
47
|
+
kernels: list[npt.NDArray] | tuple[npt.NDArray, ...]
|
|
48
|
+
|
|
49
|
+
mode: FilterbankMode = FilterbankMode.CONV
|
|
50
|
+
"""
|
|
51
|
+
"conv", "fft", or "auto". If "auto", the mode is determined by the size of the input data.
|
|
52
|
+
fft mode is more efficient for long kernels. However, fft mode uses non-overlapping windows and will
|
|
53
|
+
incur a delay equal to the window length, which is larger than the largest kernel.
|
|
54
|
+
conv mode is less efficient but will return data for every incoming chunk regardless of how small it is
|
|
55
|
+
and thus can provide shorter latency updates.
|
|
45
56
|
"""
|
|
46
|
-
Perform multiple (direct or fft) convolutions on a signal using a bank of kernels.
|
|
47
|
-
This is intended to be used during online processing, therefore both direct and fft convolutions
|
|
48
|
-
use the overlap-add method.
|
|
49
|
-
Args:
|
|
50
|
-
kernels:
|
|
51
|
-
mode: "conv", "fft", or "auto". If "auto", the mode is determined by the size of the input data.
|
|
52
|
-
fft mode is more efficient for long kernels. However, fft mode uses non-overlapping windows and will
|
|
53
|
-
incur a delay equal to the window length, which is larger than the largest kernel.
|
|
54
|
-
conv mode is less efficient but will return data for every incoming chunk regardless of how small it is
|
|
55
|
-
and thus can provide shorter latency updates.
|
|
56
|
-
min_phase: If not None, convert the kernels to minimum-phase equivalents. Valid options are
|
|
57
|
-
'hilbert', 'homomorphic', and 'homomorphic-full'. Complex filters not supported.
|
|
58
|
-
See `scipy.signal.minimum_phase` for details.
|
|
59
|
-
axis: The name of the axis to operate on. This should usually be "time".
|
|
60
|
-
new_axis: The name of the new axis corresponding to the kernel index.
|
|
61
|
-
|
|
62
|
-
Returns: A primed generator that, when passed an input message via `.send(msg)`, yields an :obj:`AxisArray`
|
|
63
|
-
with the data payload containing the absolute value of the input :obj:`AxisArray` data.
|
|
64
57
|
|
|
58
|
+
min_phase: MinPhaseMode = MinPhaseMode.NONE
|
|
65
59
|
"""
|
|
66
|
-
|
|
60
|
+
If not None, convert the kernels to minimum-phase equivalents. Valid options are
|
|
61
|
+
'hilbert', 'homomorphic', and 'homomorphic-full'. Complex filters not supported.
|
|
62
|
+
See `scipy.signal.minimum_phase` for details.
|
|
63
|
+
"""
|
|
64
|
+
|
|
65
|
+
axis: str = "time"
|
|
66
|
+
"""The name of the axis to operate on. This should usually be "time"."""
|
|
67
67
|
|
|
68
|
-
|
|
68
|
+
new_axis: str = "kernel"
|
|
69
|
+
"""The name of the new axis corresponding to the kernel index."""
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@processor_state
|
|
73
|
+
class FilterbankState:
|
|
74
|
+
tail: npt.NDArray | None = None
|
|
69
75
|
template: AxisArray | None = None
|
|
76
|
+
dest_arr: npt.NDArray | None = None
|
|
77
|
+
prep_kerns: npt.NDArray | list[npt.NDArray] | None = None
|
|
78
|
+
windower: WindowTransformer | None = None
|
|
79
|
+
fft: typing.Callable | None = None
|
|
80
|
+
ifft: typing.Callable | None = None
|
|
81
|
+
nfft: int | None = None
|
|
82
|
+
infft: int | None = None
|
|
83
|
+
overlap: int | None = None
|
|
84
|
+
mode: FilterbankMode | None = None
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class FilterbankTransformer(BaseStatefulTransformer[FilterbankSettings, AxisArray, AxisArray, FilterbankState]):
|
|
88
|
+
def _hash_message(self, message: AxisArray) -> int:
|
|
89
|
+
axis = self.settings.axis or message.dims[0]
|
|
90
|
+
gain = message.axes[axis].gain if axis in message.axes else 1.0
|
|
91
|
+
targ_ax_ix = message.get_axis_idx(axis)
|
|
92
|
+
in_shape = message.data.shape[:targ_ax_ix] + message.data.shape[targ_ax_ix + 1 :]
|
|
93
|
+
|
|
94
|
+
return hash(
|
|
95
|
+
(
|
|
96
|
+
message.key,
|
|
97
|
+
gain if self.settings.mode in [FilterbankMode.FFT, FilterbankMode.AUTO] else None,
|
|
98
|
+
message.data.dtype.kind,
|
|
99
|
+
in_shape,
|
|
100
|
+
)
|
|
101
|
+
)
|
|
70
102
|
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
103
|
+
def _reset_state(self, message: AxisArray) -> None:
|
|
104
|
+
axis = self.settings.axis or message.dims[0]
|
|
105
|
+
gain = message.axes[axis].gain if axis in message.axes else 1.0
|
|
106
|
+
targ_ax_ix = message.get_axis_idx(axis)
|
|
107
|
+
in_shape = message.data.shape[:targ_ax_ix] + message.data.shape[targ_ax_ix + 1 :]
|
|
108
|
+
|
|
109
|
+
kernels = self.settings.kernels
|
|
110
|
+
if self.settings.min_phase != MinPhaseMode.NONE:
|
|
111
|
+
method, half = {
|
|
112
|
+
MinPhaseMode.HILBERT: ("hilbert", False),
|
|
113
|
+
MinPhaseMode.HOMOMORPHIC: ("homomorphic", False),
|
|
114
|
+
# MinPhaseMode.HOMOMORPHICFULL: ("homomorphic", True),
|
|
115
|
+
}[self.settings.min_phase]
|
|
116
|
+
kernels = [sps.minimum_phase(k, method=method) for k in kernels]
|
|
117
|
+
|
|
118
|
+
# Determine if this will be operating with complex data.
|
|
119
|
+
b_complex = message.data.dtype.kind == "c" or any([_.dtype.kind == "c" for _ in kernels])
|
|
120
|
+
|
|
121
|
+
# Calculate window_dur, window_shift, nfft
|
|
122
|
+
max_kernel_len = max([_.size for _ in kernels])
|
|
123
|
+
# From sps._calc_oa_lens, where s2=max_kernel_len,:
|
|
124
|
+
# fallback_nfft = n_input + max_kernel_len - 1, but n_input is unbound.
|
|
125
|
+
self._state.overlap = max_kernel_len - 1
|
|
126
|
+
|
|
127
|
+
# Prepare previous iteration's overlap tail to add to input -- all zeros.
|
|
128
|
+
tail_shape = in_shape + (len(kernels), self._state.overlap)
|
|
129
|
+
self._state.tail = np.zeros(tail_shape, dtype="complex" if b_complex else "float")
|
|
130
|
+
|
|
131
|
+
# Prepare output template -- kernels axis immediately before the target axis
|
|
132
|
+
dummy_shape = in_shape + (len(kernels), 0)
|
|
133
|
+
self._state.template = AxisArray(
|
|
134
|
+
data=np.zeros(dummy_shape, dtype="complex" if b_complex else "float"),
|
|
135
|
+
dims=message.dims[:targ_ax_ix] + message.dims[targ_ax_ix + 1 :] + [self.settings.new_axis, axis],
|
|
136
|
+
axes=message.axes.copy(),
|
|
137
|
+
key=message.key,
|
|
92
138
|
)
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
}[min_phase]
|
|
107
|
-
kernels = [
|
|
108
|
-
sps.minimum_phase(
|
|
109
|
-
k, method=method
|
|
110
|
-
) # , half=half) -- half requires later scipy >= 1.14
|
|
111
|
-
for k in kernels
|
|
112
|
-
]
|
|
113
|
-
|
|
114
|
-
# Determine if this will be operating with complex data.
|
|
115
|
-
b_complex = msg_in.data.dtype.kind == "c" or any(
|
|
116
|
-
[_.dtype.kind == "c" for _ in kernels]
|
|
139
|
+
|
|
140
|
+
# Determine optimal mode. Assumes 100 msec chunks.
|
|
141
|
+
self._state.mode = self.settings.mode
|
|
142
|
+
if self._state.mode == FilterbankMode.AUTO:
|
|
143
|
+
# concatenate kernels into 1 mega kernel then check what's faster.
|
|
144
|
+
# Will typically return fft when combined kernel length is > 1500.
|
|
145
|
+
concat_kernel = np.concatenate(kernels)
|
|
146
|
+
n_dummy = max(2 * len(concat_kernel), int(0.1 / gain))
|
|
147
|
+
dummy_arr = np.zeros(n_dummy)
|
|
148
|
+
self._state.mode = (
|
|
149
|
+
FilterbankMode.CONV
|
|
150
|
+
if sps.choose_conv_method(dummy_arr, concat_kernel, mode="full") == "direct"
|
|
151
|
+
else FilterbankMode.FFT
|
|
117
152
|
)
|
|
118
153
|
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
154
|
+
if self._state.mode == FilterbankMode.CONV:
|
|
155
|
+
# Preallocate memory for convolution result and overlap-add
|
|
156
|
+
dest_shape = in_shape + (
|
|
157
|
+
len(kernels),
|
|
158
|
+
self._state.overlap + message.data.shape[targ_ax_ix],
|
|
159
|
+
)
|
|
160
|
+
self._state.dest_arr = np.zeros(dest_shape, dtype="complex" if b_complex else "float")
|
|
161
|
+
self._state.prep_kerns = kernels
|
|
162
|
+
else: # FFT mode
|
|
163
|
+
# Calculate optimal nfft and windowing size.
|
|
164
|
+
opt_size = -self._state.overlap * lambertw(-1 / (2 * math.e * self._state.overlap), k=-1).real
|
|
165
|
+
self._state.nfft = sp_fft.next_fast_len(math.ceil(opt_size))
|
|
166
|
+
win_len = self._state.nfft - self._state.overlap
|
|
167
|
+
# infft same as nfft. Keeping as separate variable because I might need it again.
|
|
168
|
+
self._state.infft = win_len + self._state.overlap
|
|
169
|
+
|
|
170
|
+
# Create windowing node.
|
|
171
|
+
# Note: We could do windowing manually to avoid the overhead of the message structure,
|
|
172
|
+
# but windowing is difficult to do correctly, so we lean on the heavily-tested `windowing` generator.
|
|
173
|
+
win_dur = win_len * gain
|
|
174
|
+
self._state.windower = WindowTransformer(
|
|
175
|
+
axis=axis,
|
|
176
|
+
newaxis="win",
|
|
177
|
+
window_dur=win_dur,
|
|
178
|
+
window_shift=win_dur,
|
|
179
|
+
zero_pad_until="none",
|
|
138
180
|
)
|
|
139
181
|
|
|
140
|
-
#
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
opt_size = -overlap * lambertw(-1 / (2 * math.e * overlap), k=-1).real
|
|
163
|
-
nfft = sp_fft.next_fast_len(math.ceil(opt_size))
|
|
164
|
-
win_len = nfft - overlap
|
|
165
|
-
# infft same as nfft. Keeping as separate variable because I might need it again.
|
|
166
|
-
infft = win_len + overlap
|
|
167
|
-
|
|
168
|
-
# Create windowing node.
|
|
169
|
-
# Note: We could do windowing manually to avoid the overhead of the message structure,
|
|
170
|
-
# but windowing is difficult to do correctly, so we lean on the heavily-tested `windowing` generator.
|
|
171
|
-
win_dur = win_len * gain
|
|
172
|
-
wingen = windowing(
|
|
173
|
-
axis=axis,
|
|
174
|
-
newaxis="win", # Big data chunks might yield more than 1 window.
|
|
175
|
-
window_dur=win_dur,
|
|
176
|
-
window_shift=win_dur, # Tumbling (not sliding) windows expected!
|
|
177
|
-
zero_pad_until="none",
|
|
178
|
-
)
|
|
179
|
-
|
|
180
|
-
# Windowing output has an extra "win" dimension, so we need our tail to match.
|
|
181
|
-
tail = np.expand_dims(tail, -2)
|
|
182
|
-
|
|
183
|
-
# Prepare fft functions
|
|
184
|
-
# Note: We could instead use `spectrum` but this adds overhead in creating the message structure
|
|
185
|
-
# for a rather simple calculation. We may revisit if `spectrum` gets additional features, such as
|
|
186
|
-
# more fft backends.
|
|
187
|
-
if b_complex:
|
|
188
|
-
fft = functools.partial(sp_fft.fft, n=nfft, norm="backward")
|
|
189
|
-
ifft = functools.partial(sp_fft.ifft, n=infft, norm="backward")
|
|
190
|
-
else:
|
|
191
|
-
fft = functools.partial(sp_fft.rfft, n=nfft, norm="backward")
|
|
192
|
-
ifft = functools.partial(sp_fft.irfft, n=infft, norm="backward")
|
|
193
|
-
|
|
194
|
-
# Calculate fft of kernels
|
|
195
|
-
prep_kerns = np.array([fft(_) for _ in kernels])
|
|
196
|
-
prep_kerns = np.expand_dims(prep_kerns, -2)
|
|
197
|
-
# TODO: If fft_kernels have significant stretches of zeros, convert to sparse array.
|
|
182
|
+
# Windowing output has an extra "win" dimension, so we need our tail to match.
|
|
183
|
+
self._state.tail = np.expand_dims(self._state.tail, -2)
|
|
184
|
+
|
|
185
|
+
# Prepare fft functions
|
|
186
|
+
# Note: We could instead use `spectrum` but this adds overhead in creating the message structure
|
|
187
|
+
# for a rather simple calculation. We may revisit if `spectrum` gets additional features, such as
|
|
188
|
+
# more fft backends.
|
|
189
|
+
if b_complex:
|
|
190
|
+
self._state.fft = functools.partial(sp_fft.fft, n=self._state.nfft, norm="backward")
|
|
191
|
+
self._state.ifft = functools.partial(sp_fft.ifft, n=self._state.infft, norm="backward")
|
|
192
|
+
else:
|
|
193
|
+
self._state.fft = functools.partial(sp_fft.rfft, n=self._state.nfft, norm="backward")
|
|
194
|
+
self._state.ifft = functools.partial(sp_fft.irfft, n=self._state.infft, norm="backward")
|
|
195
|
+
|
|
196
|
+
# Calculate fft of kernels
|
|
197
|
+
self._state.prep_kerns = np.array([self._state.fft(_) for _ in kernels])
|
|
198
|
+
self._state.prep_kerns = np.expand_dims(self._state.prep_kerns, -2)
|
|
199
|
+
# TODO: If fft_kernels have significant stretches of zeros, convert to sparse array.
|
|
200
|
+
|
|
201
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
202
|
+
axis = self.settings.axis or message.dims[0]
|
|
203
|
+
targ_ax_ix = message.get_axis_idx(axis)
|
|
198
204
|
|
|
199
205
|
# Make sure target axis is in -1th position.
|
|
200
|
-
if targ_ax_ix != (
|
|
201
|
-
in_dat = np.moveaxis(
|
|
202
|
-
if mode == FilterbankMode.FFT:
|
|
203
|
-
# Fix
|
|
204
|
-
move_dims =
|
|
205
|
-
|
|
206
|
-
)
|
|
207
|
-
msg_in = replace(msg_in, data=in_dat, dims=move_dims)
|
|
206
|
+
if targ_ax_ix != (message.data.ndim - 1):
|
|
207
|
+
in_dat = np.moveaxis(message.data, targ_ax_ix, -1)
|
|
208
|
+
if self._state.mode == FilterbankMode.FFT:
|
|
209
|
+
# Fix message.dims because we will pass it to windower
|
|
210
|
+
move_dims = message.dims[:targ_ax_ix] + message.dims[targ_ax_ix + 1 :] + [axis]
|
|
211
|
+
message = replace(message, data=in_dat, dims=move_dims)
|
|
208
212
|
else:
|
|
209
|
-
in_dat =
|
|
210
|
-
|
|
211
|
-
if mode == FilterbankMode.CONV:
|
|
212
|
-
n_dest = in_dat.shape[-1] + overlap
|
|
213
|
-
if dest_arr.shape[-1] < n_dest:
|
|
214
|
-
pad = np.zeros(dest_arr.shape[:-1] + (n_dest - dest_arr.shape[-1],))
|
|
215
|
-
dest_arr = np.concatenate(dest_arr, pad, axis=-1)
|
|
216
|
-
dest_arr.fill(0)
|
|
217
|
-
|
|
218
|
-
|
|
213
|
+
in_dat = message.data
|
|
214
|
+
|
|
215
|
+
if self._state.mode == FilterbankMode.CONV:
|
|
216
|
+
n_dest = in_dat.shape[-1] + self._state.overlap
|
|
217
|
+
if self._state.dest_arr.shape[-1] < n_dest:
|
|
218
|
+
pad = np.zeros(self._state.dest_arr.shape[:-1] + (n_dest - self._state.dest_arr.shape[-1],))
|
|
219
|
+
self._state.dest_arr = np.concatenate([self._state.dest_arr, pad], axis=-1)
|
|
220
|
+
self._state.dest_arr.fill(0)
|
|
221
|
+
|
|
222
|
+
# Note: I tried several alternatives to this loop; all were slower than this.
|
|
223
|
+
# numba.jit; stride_tricks + np.einsum; threading. Latter might be better with Python 3.13.
|
|
224
|
+
for k_ix, k in enumerate(self._state.prep_kerns):
|
|
219
225
|
n_out = in_dat.shape[-1] + k.shape[-1] - 1
|
|
220
|
-
dest_arr[..., k_ix, :n_out] = np.apply_along_axis(
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
dest_arr[..., :overlap] += tail # Add previous overlap
|
|
224
|
-
new_tail = dest_arr[..., in_dat.shape[-1] : n_dest]
|
|
226
|
+
self._state.dest_arr[..., k_ix, :n_out] = np.apply_along_axis(np.convolve, -1, in_dat, k, mode="full")
|
|
227
|
+
self._state.dest_arr[..., : self._state.overlap] += self._state.tail
|
|
228
|
+
new_tail = self._state.dest_arr[..., in_dat.shape[-1] : n_dest]
|
|
225
229
|
if new_tail.size > 0:
|
|
226
230
|
# COPY overlap for next iteration
|
|
227
|
-
tail = new_tail.copy()
|
|
228
|
-
res = dest_arr[..., : in_dat.shape[-1]].copy()
|
|
229
|
-
|
|
231
|
+
self._state.tail = new_tail.copy()
|
|
232
|
+
res = self._state.dest_arr[..., : in_dat.shape[-1]].copy()
|
|
233
|
+
else: # FFT mode
|
|
230
234
|
# Slice into non-overlapping windows
|
|
231
|
-
win_msg =
|
|
232
|
-
# Calculate
|
|
233
|
-
spec_dat = fft(win_msg.data, axis=-1)
|
|
235
|
+
win_msg = self._state.windower.send(message)
|
|
236
|
+
# Calculate spectrum of each window
|
|
237
|
+
spec_dat = self._state.fft(win_msg.data, axis=-1)
|
|
234
238
|
# Insert axis for filters
|
|
235
239
|
spec_dat = np.expand_dims(spec_dat, -3)
|
|
236
240
|
|
|
237
241
|
# Do the FFT convolution
|
|
238
242
|
# TODO: handle fft_kernels being sparse. Maybe need np.dot.
|
|
239
|
-
conv_spec = spec_dat * prep_kerns
|
|
240
|
-
overlapped = ifft(conv_spec, axis=-1)
|
|
243
|
+
conv_spec = spec_dat * self._state.prep_kerns
|
|
244
|
+
overlapped = self._state.ifft(conv_spec, axis=-1)
|
|
241
245
|
|
|
242
246
|
# Do the overlap-add on the `axis` axis
|
|
243
247
|
# Previous iteration's tail:
|
|
244
|
-
overlapped[..., :1, :overlap] += tail
|
|
248
|
+
overlapped[..., :1, : self._state.overlap] += self._state.tail
|
|
245
249
|
# window-to-window:
|
|
246
|
-
overlapped[..., 1:, :overlap] += overlapped[..., :-1, -overlap:]
|
|
250
|
+
overlapped[..., 1:, : self._state.overlap] += overlapped[..., :-1, -self._state.overlap :]
|
|
247
251
|
# Save tail:
|
|
248
|
-
new_tail = overlapped[..., -1:, -overlap:]
|
|
252
|
+
new_tail = overlapped[..., -1:, -self._state.overlap :]
|
|
249
253
|
if new_tail.size > 0:
|
|
250
254
|
# All of the above code works if input is size-zero, but we don't want to save a zero-size tail.
|
|
251
|
-
tail = new_tail
|
|
255
|
+
self._state.tail = new_tail
|
|
252
256
|
# Concat over win axis, without overlap.
|
|
253
|
-
res = overlapped[...,
|
|
257
|
+
res = overlapped[..., : -self._state.overlap].reshape(overlapped.shape[:-2] + (-1,))
|
|
254
258
|
|
|
255
|
-
|
|
256
|
-
template,
|
|
259
|
+
return replace(
|
|
260
|
+
self._state.template,
|
|
261
|
+
data=res,
|
|
262
|
+
axes={**self._state.template.axes, axis: message.axes[axis]},
|
|
257
263
|
)
|
|
258
264
|
|
|
259
265
|
|
|
260
|
-
class FilterbankSettings
|
|
261
|
-
kernels: list[npt.NDArray] | tuple[npt.NDArray, ...]
|
|
262
|
-
mode: FilterbankMode = FilterbankMode.CONV
|
|
263
|
-
min_phase: MinPhaseMode = MinPhaseMode.NONE
|
|
264
|
-
axis: str = "time"
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
class Filterbank(GenAxisArray):
|
|
268
|
-
"""Unit for :obj:`spectrum`"""
|
|
269
|
-
|
|
266
|
+
class Filterbank(BaseTransformerUnit[FilterbankSettings, AxisArray, AxisArray, FilterbankTransformer]):
|
|
270
267
|
SETTINGS = FilterbankSettings
|
|
271
268
|
|
|
272
|
-
INPUT_SETTINGS = ez.InputStream(FilterbankSettings)
|
|
273
269
|
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
270
|
+
def filterbank(
|
|
271
|
+
kernels: list[npt.NDArray] | tuple[npt.NDArray, ...],
|
|
272
|
+
mode: FilterbankMode = FilterbankMode.CONV,
|
|
273
|
+
min_phase: MinPhaseMode = MinPhaseMode.NONE,
|
|
274
|
+
axis: str = "time",
|
|
275
|
+
new_axis: str = "kernel",
|
|
276
|
+
) -> FilterbankTransformer:
|
|
277
|
+
"""
|
|
278
|
+
Perform multiple (direct or fft) convolutions on a signal using a bank of kernels.
|
|
279
|
+
This is intended to be used during online processing, therefore both direct and fft convolutions
|
|
280
|
+
use the overlap-add method.
|
|
281
|
+
|
|
282
|
+
Returns: :obj:`FilterbankTransformer`.
|
|
283
|
+
"""
|
|
284
|
+
return FilterbankTransformer(
|
|
285
|
+
settings=FilterbankSettings(
|
|
286
|
+
kernels=kernels,
|
|
287
|
+
mode=mode,
|
|
288
|
+
min_phase=min_phase,
|
|
289
|
+
axis=axis,
|
|
290
|
+
new_axis=new_axis,
|
|
280
291
|
)
|
|
292
|
+
)
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
|
|
3
|
+
import ezmsg.core as ez
|
|
4
|
+
import numpy as np
|
|
5
|
+
import numpy.typing as npt
|
|
6
|
+
from ezmsg.baseproc import (
|
|
7
|
+
BaseStatefulTransformer,
|
|
8
|
+
processor_state,
|
|
9
|
+
)
|
|
10
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
11
|
+
from ezmsg.util.messages.util import replace
|
|
12
|
+
|
|
13
|
+
from .filterbank import (
|
|
14
|
+
FilterbankMode,
|
|
15
|
+
FilterbankSettings,
|
|
16
|
+
FilterbankTransformer,
|
|
17
|
+
MinPhaseMode,
|
|
18
|
+
)
|
|
19
|
+
from .kaiser import KaiserFilterSettings, kaiser_design_fun
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class FilterbankDesignSettings(ez.Settings):
|
|
23
|
+
filters: typing.Iterable[KaiserFilterSettings]
|
|
24
|
+
|
|
25
|
+
mode: FilterbankMode = FilterbankMode.CONV
|
|
26
|
+
"""
|
|
27
|
+
"conv", "fft", or "auto". If "auto", the mode is determined by the size of the input data.
|
|
28
|
+
fft mode is more efficient for long kernels. However, fft mode uses non-overlapping windows and will
|
|
29
|
+
incur a delay equal to the window length, which is larger than the largest kernel.
|
|
30
|
+
conv mode is less efficient but will return data for every incoming chunk regardless of how small it is
|
|
31
|
+
and thus can provide shorter latency updates.
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
min_phase: MinPhaseMode = MinPhaseMode.NONE
|
|
35
|
+
"""
|
|
36
|
+
If not None, convert the kernels to minimum-phase equivalents. Valid options are
|
|
37
|
+
'hilbert', 'homomorphic', and 'homomorphic-full'. Complex filters not supported.
|
|
38
|
+
See `scipy.signal.minimum_phase` for details.
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
axis: str = "time"
|
|
42
|
+
"""The name of the axis to operate on. This should usually be "time"."""
|
|
43
|
+
|
|
44
|
+
new_axis: str = "kernel"
|
|
45
|
+
"""The name of the new axis corresponding to the kernel index."""
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@processor_state
|
|
49
|
+
class FilterbankDesignState:
|
|
50
|
+
filterbank: FilterbankTransformer | None = None
|
|
51
|
+
needs_redesign: bool = False
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class FilterbankDesignTransformer(
|
|
55
|
+
BaseStatefulTransformer[FilterbankDesignSettings, AxisArray, AxisArray, FilterbankDesignState],
|
|
56
|
+
):
|
|
57
|
+
"""
|
|
58
|
+
Transformer that designs and applies a filterbank based on Kaiser windowed FIR filters.
|
|
59
|
+
"""
|
|
60
|
+
|
|
61
|
+
@classmethod
|
|
62
|
+
def get_message_type(cls, dir: str) -> type[AxisArray]:
|
|
63
|
+
if dir in ("in", "out"):
|
|
64
|
+
return AxisArray
|
|
65
|
+
else:
|
|
66
|
+
raise ValueError(f"Invalid direction: {dir}. Must be 'in' or 'out'.")
|
|
67
|
+
|
|
68
|
+
def update_settings(self, new_settings: typing.Optional[FilterbankDesignSettings] = None, **kwargs) -> None:
|
|
69
|
+
"""
|
|
70
|
+
Update settings and mark that filter coefficients need to be recalculated.
|
|
71
|
+
|
|
72
|
+
Args:
|
|
73
|
+
new_settings: Complete new settings object to replace current settings
|
|
74
|
+
**kwargs: Individual settings to update
|
|
75
|
+
"""
|
|
76
|
+
# Update settings
|
|
77
|
+
if new_settings is not None:
|
|
78
|
+
self.settings = new_settings
|
|
79
|
+
else:
|
|
80
|
+
self.settings = replace(self.settings, **kwargs)
|
|
81
|
+
|
|
82
|
+
# Set flag to trigger recalculation on next message
|
|
83
|
+
if self.state.filterbank is not None:
|
|
84
|
+
self.state.needs_redesign = True
|
|
85
|
+
|
|
86
|
+
def _calculate_kernels(self, fs: float) -> list[npt.NDArray]:
|
|
87
|
+
kernels = []
|
|
88
|
+
for filter in self.settings.filters:
|
|
89
|
+
output = kaiser_design_fun(
|
|
90
|
+
fs,
|
|
91
|
+
cutoff=filter.cutoff,
|
|
92
|
+
ripple=filter.ripple,
|
|
93
|
+
width=filter.width,
|
|
94
|
+
pass_zero=filter.pass_zero,
|
|
95
|
+
wn_hz=filter.wn_hz,
|
|
96
|
+
)
|
|
97
|
+
|
|
98
|
+
kernels.append(np.array([1.0]) if output is None else output[0])
|
|
99
|
+
return kernels
|
|
100
|
+
|
|
101
|
+
def __call__(self, message: AxisArray) -> AxisArray:
|
|
102
|
+
if self.state.filterbank is not None and self.state.needs_redesign:
|
|
103
|
+
self._reset_state(message)
|
|
104
|
+
self.state.needs_redesign = False
|
|
105
|
+
return super().__call__(message)
|
|
106
|
+
|
|
107
|
+
def _hash_message(self, message: AxisArray) -> int:
|
|
108
|
+
axis = message.dims[0] if self.settings.axis is None else self.settings.axis
|
|
109
|
+
gain = message.axes[axis].gain if hasattr(message.axes[axis], "gain") else 1
|
|
110
|
+
axis_idx = message.get_axis_idx(axis)
|
|
111
|
+
samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
|
|
112
|
+
return hash((message.key, samp_shape, gain))
|
|
113
|
+
|
|
114
|
+
def _reset_state(self, message: AxisArray) -> None:
|
|
115
|
+
axis_obj = message.axes[self.settings.axis]
|
|
116
|
+
assert isinstance(axis_obj, AxisArray.LinearAxis)
|
|
117
|
+
fs = 1 / axis_obj.gain
|
|
118
|
+
kernels = self._calculate_kernels(fs)
|
|
119
|
+
new_settings = FilterbankSettings(
|
|
120
|
+
kernels=kernels,
|
|
121
|
+
mode=self.settings.mode,
|
|
122
|
+
min_phase=self.settings.min_phase,
|
|
123
|
+
axis=self.settings.axis,
|
|
124
|
+
new_axis=self.settings.new_axis,
|
|
125
|
+
)
|
|
126
|
+
self.state.filterbank = FilterbankTransformer(settings=new_settings)
|
|
127
|
+
|
|
128
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
129
|
+
return self.state.filterbank(message)
|