ezmsg-sigproc 1.2.3__py3-none-any.whl → 1.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,278 @@
1
+ from dataclasses import replace
2
+ import functools
3
+ import math
4
+ import typing
5
+
6
+ import numpy as np
7
+ import scipy.signal as sps
8
+ import scipy.fft as sp_fft
9
+ from scipy.special import lambertw
10
+ import numpy.typing as npt
11
+ import ezmsg.core as ez
12
+ from ezmsg.util.messages.axisarray import AxisArray
13
+ from ezmsg.util.generator import consumer
14
+
15
+ from .base import GenAxisArray
16
+ from .spectrum import OptionsEnum
17
+ from .window import windowing
18
+
19
+
20
class FilterbankMode(OptionsEnum):
    """
    Convolution strategy for :obj:`filterbank`.

    :obj:`filterbank` compares against the members themselves, not their string values.
    """

    # Time-domain convolution with overlap-add; produces output for every chunk.
    CONV = "Direct Convolution"
    # Windowed FFT convolution with overlap-add; efficient for long kernels but adds window latency.
    FFT = "FFT Convolution"
    # Let :obj:`filterbank` pick CONV or FFT via `scipy.signal.choose_conv_method`.
    AUTO = "Automatic"
26
+
27
+
28
class MinPhaseMode(OptionsEnum):
    """
    Optional minimum-phase conversion applied to the :obj:`filterbank` kernels.

    Any member other than NONE makes :obj:`filterbank` pass each kernel through
    `scipy.signal.minimum_phase` with the matching `method` ("hilbert" or "homomorphic").
    """

    NONE = "No kernel modification"
    HILBERT = "Hilbert Method; designed to be used with equiripple filters (e.g., from remez) with unity or zero gain regions"
    HOMOMORPHIC = "Works best with filters with an odd number of taps, and the resulting minimum phase filter will have a magnitude response that approximates the square root of the original filter’s magnitude response using half the number of taps"
    # HOMOMORPHICFULL = "Like HOMOMORPHIC, but uses the full number of taps and same magnitude"
    # (would map to minimum_phase(..., method="homomorphic", half=False), which needs scipy >= 1.14)
35
+
36
+
37
@consumer
def filterbank(
    kernels: typing.Union[list[npt.NDArray], tuple[npt.NDArray, ...]],
    mode: FilterbankMode = FilterbankMode.CONV,
    min_phase: MinPhaseMode = MinPhaseMode.NONE,
    axis: str = "time",
    new_axis: str = "kernel",
) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Returns a generator that performs multiple (direct or FFT) convolutions on a signal using a bank of kernels.

    This generator is intended to be used during online processing. Therefore both the direct and the FFT
    paths use the overlap-add method, carrying each chunk's convolution tail into the next chunk.

    Args:
        kernels: Sequence of 1-D filter kernels. Each kernel produces one slice of the output along
            `new_axis`. Kernels may differ in length.
        mode: `FilterbankMode.CONV`, `FilterbankMode.FFT`, or `FilterbankMode.AUTO`.
            If AUTO, `scipy.signal.choose_conv_method` decides, using a dummy signal sized for ~100 msec
            chunks at the incoming sample period; the decision is re-made whenever the state resets.
            FFT mode is more efficient for long kernels. However, FFT mode uses non-overlapping windows and
            will incur a delay equal to the window length, which is larger than the largest kernel.
            CONV mode is less efficient but will return data for every incoming chunk regardless of how
            small it is, and thus can provide shorter latency updates.
        min_phase: If not `MinPhaseMode.NONE`, convert the kernels to minimum-phase equivalents using
            `scipy.signal.minimum_phase` with method "hilbert" (HILBERT) or "homomorphic" (HOMOMORPHIC).
            The conversion is done once and reused across state resets. Complex filters not supported.
        axis: The name of the axis to operate on. This should usually be "time". If empty, the first
            dimension of the first message is used.
        new_axis: The name of the new axis corresponding to the kernel index.

    Returns:
        A generator that accepts AxisArray messages via `send` and yields AxisArray messages whose dims
        are the input's non-target dims followed by `[new_axis, axis]`.
    """
    if min_phase != MinPhaseMode.NONE:
        # Convert once, up front: the result does not depend on the input. Converting inside the
        # reset branch would re-convert already-converted kernels on every reset.
        method = {
            MinPhaseMode.HILBERT: "hilbert",
            MinPhaseMode.HOMOMORPHIC: "homomorphic",
            # MinPhaseMode.HOMOMORPHICFULL would need `half=False`, which requires scipy >= 1.14.
        }[min_phase]
        kernels = [sps.minimum_phase(k, method=method) for k in kernels]

    msg_out: typing.Optional[AxisArray] = None

    # State variables -- (re)initialized whenever `b_reset` is True.
    template: typing.Optional[AxisArray] = None
    tail: typing.Optional[npt.NDArray] = None  # overlap-add carry from the previous chunk
    run_mode: FilterbankMode = mode  # `mode`, with AUTO resolved to CONV or FFT
    overlap = 0  # max_kernel_len - 1
    dest_arr: typing.Optional[npt.NDArray] = None  # CONV: reusable output buffer
    wingen = None  # FFT: tumbling-window generator
    fft = ifft = None  # FFT: transforms bound to the chosen lengths
    prep_kerns: typing.Optional[npt.NDArray] = None  # FFT: kernel spectra
    win_len = 0  # FFT: new (non-overlap) samples per window

    # Reset if these change
    check_input = {
        "key": None,
        "gain": None,
        "kind": None,
        "shape": None,
    }

    while True:
        msg_in: AxisArray = yield msg_out

        axis = axis or msg_in.dims[0]
        gain = msg_in.axes[axis].gain if axis in msg_in.axes else 1.0
        targ_ax_ix = msg_in.get_axis_idx(axis)
        in_shape = msg_in.data.shape[:targ_ax_ix] + msg_in.data.shape[targ_ax_ix + 1 :]

        b_reset = msg_in.key != check_input["key"]
        # The sample period feeds FFT window sizing and the AUTO decision. Check the *requested*
        # mode so that AUTO keeps re-evaluating after it has been resolved once.
        b_reset = b_reset or (
            gain != check_input["gain"]
            and mode in [FilterbankMode.FFT, FilterbankMode.AUTO]
        )
        b_reset = b_reset or msg_in.data.dtype.kind != check_input["kind"]
        b_reset = b_reset or in_shape != check_input["shape"]
        if b_reset:
            check_input["key"] = msg_in.key
            check_input["gain"] = gain
            check_input["kind"] = msg_in.data.dtype.kind
            check_input["shape"] = in_shape

            # Determine if this will be operating with complex data.
            b_complex = msg_in.data.dtype.kind == "c" or any(
                k.dtype.kind == "c" for k in kernels
            )
            out_dtype = "complex" if b_complex else "float"

            # From sps._calc_oa_lens, where s2=max_kernel_len:
            # fallback_nfft = n_input + max_kernel_len - 1, but n_input is unbound.
            max_kernel_len = max(k.size for k in kernels)
            overlap = max_kernel_len - 1

            # Prepare previous iteration's overlap tail to add to input -- all zeros.
            tail = np.zeros(in_shape + (len(kernels), overlap), dtype=out_dtype)

            # Prepare output template -- kernels axis immediately before the target axis.
            template = AxisArray(
                data=np.zeros(in_shape + (len(kernels), 0), dtype=out_dtype),
                dims=msg_in.dims[:targ_ax_ix]
                + msg_in.dims[targ_ax_ix + 1 :]
                + [new_axis, axis],
                axes=msg_in.axes.copy(),  # We do not have info for kernel/filter axis :(.
            )

            run_mode = mode
            if run_mode == FilterbankMode.AUTO:
                # Concatenate kernels into 1 mega kernel then check what's faster on a dummy
                # signal sized for 100 msec chunks.
                # Will typically return fft when combined kernel length is > 1500.
                concat_kernel = np.concatenate(kernels)
                n_dummy = max(2 * len(concat_kernel), int(0.1 / gain))
                conv_method = sps.choose_conv_method(
                    np.zeros(n_dummy), concat_kernel, mode="full"
                )
                run_mode = (
                    FilterbankMode.CONV if conv_method == "direct" else FilterbankMode.FFT
                )

            if run_mode == FilterbankMode.CONV:
                # Preallocate memory for convolution result and overlap-add.
                dest_arr = np.zeros(
                    in_shape + (len(kernels), overlap + msg_in.data.shape[targ_ax_ix]),
                    dtype=out_dtype,
                )

            elif run_mode == FilterbankMode.FFT:
                # Calculate optimal nfft and windowing size (cf. sps._calc_oa_lens).
                if overlap > 0:
                    opt_size = -overlap * lambertw(-1 / (2 * math.e * overlap), k=-1).real
                else:
                    # Single-tap kernels: nothing carries between windows, and the formula
                    # above would divide by zero. Any positive window length is valid.
                    opt_size = 1
                nfft = sp_fft.next_fast_len(math.ceil(opt_size))
                win_len = nfft - overlap
                # infft same as nfft. Keeping as separate variable because it might be needed again.
                infft = win_len + overlap

                # Create windowing node.
                # Note: We could do windowing manually to avoid the overhead of the message structure,
                # but windowing is difficult to do correctly, so we lean on the heavily-tested
                # `windowing` generator.
                win_dur = win_len * gain
                wingen = windowing(
                    axis=axis,
                    newaxis="win",  # Big data chunks might yield more than 1 window.
                    window_dur=win_dur,
                    window_shift=win_dur,  # Tumbling (not sliding) windows expected!
                    zero_pad_until="none",
                )

                # Windowing output has an extra "win" dimension, so we need our tail to match.
                tail = np.expand_dims(tail, -2)

                # Prepare fft functions.
                # Note: We could instead use `spectrum` but this adds overhead in creating the
                # message structure for a rather simple calculation.
                if b_complex:
                    fft = functools.partial(sp_fft.fft, n=nfft, norm="backward")
                    ifft = functools.partial(sp_fft.ifft, n=infft, norm="backward")
                else:
                    fft = functools.partial(sp_fft.rfft, n=nfft, norm="backward")
                    ifft = functools.partial(sp_fft.irfft, n=infft, norm="backward")

                # Kernel spectra, with a singleton axis to broadcast against the "win" axis.
                prep_kerns = np.expand_dims(np.array([fft(k) for k in kernels]), -2)
                # TODO: If fft_kernels have significant stretches of zeros, convert to sparse array.

        # Make sure target axis is in -1th position.
        if targ_ax_ix != (msg_in.data.ndim - 1):
            in_dat = np.moveaxis(msg_in.data, targ_ax_ix, -1)
            if run_mode == FilterbankMode.FFT:
                # Fix msg_in .dims because we will pass it to wingen.
                move_dims = (
                    msg_in.dims[:targ_ax_ix] + msg_in.dims[targ_ax_ix + 1 :] + [axis]
                )
                msg_in = replace(msg_in, data=in_dat, dims=move_dims)
        else:
            in_dat = msg_in.data
        n_in = in_dat.shape[-1]

        if run_mode == FilterbankMode.CONV:
            n_dest = n_in + overlap
            if dest_arr.shape[-1] < n_dest:
                # Grow the reusable buffer to fit this (larger) chunk.
                dest_arr = np.zeros(dest_arr.shape[:-1] + (n_dest,), dtype=dest_arr.dtype)
            dest_arr.fill(0)
            # np.convolve rejects empty inputs; for an empty chunk only the tail is carried.
            if in_dat.size > 0:
                # TODO: Parallelize this loop.
                for k_ix, k in enumerate(kernels):
                    n_out = n_in + k.shape[-1] - 1
                    dest_arr[..., k_ix, :n_out] = np.apply_along_axis(
                        np.convolve, -1, in_dat, k, mode="full"
                    )
            dest_arr[..., :overlap] += tail  # Add previous overlap
            new_tail = dest_arr[..., n_in:n_dest]
            if new_tail.size > 0:
                # COPY overlap for next iteration -- dest_arr is reused.
                tail = new_tail.copy()
            res = dest_arr[..., :n_in].copy()

        elif run_mode == FilterbankMode.FFT:
            # Slice into non-overlapping windows.
            win_msg = wingen.send(msg_in)
            # Calculate spectra of each window, then insert axis for filters.
            spec_dat = np.expand_dims(fft(win_msg.data, axis=-1), -3)

            # Do the FFT convolution.
            # TODO: handle fft_kernels being sparse. Maybe need np.dot.
            overlapped = ifft(spec_dat * prep_kerns, axis=-1)

            # Overlap-add along the last axis. Within each window, samples [:win_len] are output
            # and samples [win_len:] (length `overlap`) spill into the next window. Slicing by
            # win_len rather than `-overlap` stays correct when overlap == 0.
            overlapped[..., :1, :overlap] += tail  # previous iteration's tail
            overlapped[..., 1:, :overlap] += overlapped[..., :-1, win_len:]  # window-to-window
            new_tail = overlapped[..., -1:, win_len:]
            if new_tail.size > 0:
                # Size-zero when no window was produced; then keep the previous tail.
                tail = new_tail.copy()
            # Concat over win axis, without overlap.
            res = overlapped[..., :win_len].reshape(overlapped.shape[:-2] + (-1,))

        # NOTE(review): in FFT mode `res` lags msg_in by the windowing delay, but the target-axis
        # metadata is still taken from msg_in -- confirm this is intended.
        out_axes = {**template.axes}
        if axis in msg_in.axes:
            out_axes[axis] = msg_in.axes[axis]
        msg_out = replace(template, data=res, axes=out_axes)
256
+
257
+
258
class FilterbankSettings(ez.Settings):
    """Settings for :obj:`Filterbank`; each field is forwarded to :obj:`filterbank`."""

    kernels: typing.Union[list[npt.NDArray], tuple[npt.NDArray, ...]]
    mode: FilterbankMode = FilterbankMode.CONV
    min_phase: MinPhaseMode = MinPhaseMode.NONE
    axis: str = "time"
    # Name of the output dimension that indexes the kernels. The default matches the
    # `filterbank` default, so existing configurations behave as before.
    new_axis: str = "kernel"


class Filterbank(GenAxisArray):
    """Unit for :obj:`filterbank`."""

    SETTINGS = FilterbankSettings

    INPUT_SETTINGS = ez.InputStream(FilterbankSettings)

    def construct_generator(self):
        self.STATE.gen = filterbank(
            kernels=self.SETTINGS.kernels,
            mode=self.SETTINGS.mode,
            min_phase=self.SETTINGS.min_phase,
            axis=self.SETTINGS.axis,
            new_axis=self.SETTINGS.new_axis,
        )
File without changes
@@ -0,0 +1,28 @@
1
+ from dataclasses import replace
2
+ import typing
3
+
4
+ import numpy as np
5
+ import ezmsg.core as ez
6
+ from ezmsg.util.generator import consumer
7
+ from ezmsg.util.messages.axisarray import AxisArray
8
+
9
+ from ..base import GenAxisArray
10
+
11
+
12
@consumer
def abs() -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Element-wise absolute value of every incoming message's data.

    Each AxisArray sent in is answered with a copy (via `dataclasses.replace`) whose data is
    `np.abs(data)`. The first yielded value is an empty placeholder AxisArray.
    Note: this name shadows the builtin `abs` within this module.
    """
    result = AxisArray(np.array([]), dims=[""])
    while True:
        incoming = yield result
        magnitude = np.abs(incoming.data)
        result = replace(incoming, data=magnitude)
18
+
19
+
20
class AbsSettings(ez.Settings):
    """Settings for :obj:`Abs`. The absolute value takes no parameters, so this is empty."""
22
+
23
+
24
class Abs(GenAxisArray):
    """Unit that applies the :obj:`abs` generator (element-wise absolute value)."""

    SETTINGS = AbsSettings

    def construct_generator(self):
        # `abs` is this module's generator factory (it shadows the builtin).
        self.STATE.gen = abs()
@@ -0,0 +1,30 @@
1
+ from dataclasses import replace
2
+ import typing
3
+
4
+ import numpy as np
5
+ import ezmsg.core as ez
6
+ from ezmsg.util.generator import consumer
7
+ from ezmsg.util.messages.axisarray import AxisArray
8
+
9
+ from ..base import GenAxisArray
10
+
11
+
12
@consumer
def clip(a_min: float, a_max: float) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Clamp every incoming message's data to the interval [a_min, a_max].

    Args:
        a_min: Lower bound passed to `np.clip`.
        a_max: Upper bound passed to `np.clip`.

    Returns:
        A generator that answers each AxisArray sent in with a copy (via `dataclasses.replace`)
        whose data is `np.clip(data, a_min, a_max)`. Its first yielded value is an empty placeholder.
    """
    msg_out = AxisArray(np.array([]), dims=[""])
    while True:
        msg_in: AxisArray = yield msg_out
        msg_out = replace(msg_in, data=np.clip(msg_in.data, a_min, a_max))
19
+
20
+
21
class ClipSettings(ez.Settings):
    """Settings for :obj:`Clip`: the bounds handed to :obj:`clip`."""

    # Lower bound of the clipping interval (required).
    a_min: float
    # Upper bound of the clipping interval (required).
    a_max: float
24
+
25
+
26
class Clip(GenAxisArray):
    """Unit that clamps incoming AxisArray data using the :obj:`clip` generator."""

    SETTINGS = ClipSettings

    def construct_generator(self):
        cfg = self.SETTINGS
        self.STATE.gen = clip(a_min=cfg.a_min, a_max=cfg.a_max)
@@ -0,0 +1,60 @@
1
+ from dataclasses import replace
2
+ import typing
3
+
4
+ import numpy as np
5
+ import ezmsg.core as ez
6
+ from ezmsg.util.generator import consumer
7
+ from ezmsg.util.messages.axisarray import AxisArray
8
+
9
+ from ..base import GenAxisArray
10
+
11
+
12
@consumer
def const_difference(
    value: float = 0.0, subtrahend: bool = True
) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Difference between each message's data and a constant.

    Args:
        value: The constant operand.
        subtrahend: If True, `value` is the subtrahend and the result is `data - value`;
            otherwise the result is `value - data`.
            See https://en.wikipedia.org/wiki/Template:Arithmetic_operations

    Returns:
        A generator that answers each AxisArray sent in with a copy (via `dataclasses.replace`)
        carrying the difference as its data. Its first yielded value is an empty placeholder.
    """
    # Both parameters are fixed for the generator's lifetime, so choose the operation once.
    if subtrahend:
        def _apply(data):
            return data - value
    else:
        def _apply(data):
            return value - data

    out = AxisArray(np.array([]), dims=[""])
    while True:
        incoming: AxisArray = yield out
        out = replace(incoming, data=_apply(incoming.data))
26
+
27
+
28
class ConstDifferenceSettings(ez.Settings):
    """Settings for :obj:`ConstDifference`."""

    value: float = 0.0  # the constant operand
    subtrahend: bool = True  # True -> data - value; False -> value - data
31
+
32
+
33
class ConstDifference(GenAxisArray):
    """Unit that applies the :obj:`const_difference` generator to incoming AxisArrays."""

    SETTINGS = ConstDifferenceSettings

    def construct_generator(self):
        cfg = self.SETTINGS
        self.STATE.gen = const_difference(value=cfg.value, subtrahend=cfg.subtrahend)
40
+
41
+
42
+ # class DifferenceSettings(ez.Settings):
43
+ # pass
44
+ #
45
+ #
46
+ # class Difference(ez.Unit):
47
+ # SETTINGS = DifferenceSettings
48
+ #
49
+ # INPUT_SIGNAL_1 = ez.InputStream(AxisArray)
50
+ # INPUT_SIGNAL_2 = ez.InputStream(AxisArray)
51
+ # OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
52
+ #
53
+ # @ez.subscriber(INPUT_SIGNAL_2, zero_copy=True)
54
+ # @ez.publisher(OUTPUT_SIGNAL)
55
+ # async def on_input_2(self, message: AxisArray) -> typing.AsyncGenerator:
56
+ # # TODO: buffer_2
57
+ # # TODO: take buffer_1 - buffer_2 for ranges that align
58
+ # # TODO: Drop samples from buffer_1 and buffer_2
59
+ # if ret is not None:
60
+ # yield self.OUTPUT_SIGNAL, ret
@@ -0,0 +1,29 @@
1
+ from dataclasses import replace
2
+ import typing
3
+
4
+ import numpy as np
5
+ import ezmsg.core as ez
6
+ from ezmsg.util.generator import consumer
7
+ from ezmsg.util.messages.axisarray import AxisArray
8
+
9
+ from ..base import GenAxisArray
10
+
11
+
12
@consumer
def invert() -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Element-wise reciprocal (1 / data) of every incoming message's data.

    Division follows NumPy semantics (e.g. zeros map to inf with a RuntimeWarning).

    Returns:
        A generator that answers each AxisArray sent in with a copy (via `dataclasses.replace`)
        whose data is `1 / data`. Its first yielded value is an empty placeholder.
    """
    msg_out = AxisArray(np.array([]), dims=[""])
    while True:
        msg_in: AxisArray = yield msg_out
        msg_out = replace(msg_in, data=1 / msg_in.data)
19
+
20
+
21
class InvertSettings(ez.Settings):
    """Settings for :obj:`Invert`. The reciprocal takes no parameters, so this is empty."""
23
+
24
+
25
class Invert(GenAxisArray):
    """Unit that applies the :obj:`invert` generator (element-wise reciprocal)."""

    SETTINGS = InvertSettings

    def construct_generator(self):
        generator = invert()
        self.STATE.gen = generator
@@ -0,0 +1,32 @@
1
+ from dataclasses import replace
2
+ import typing
3
+
4
+ import numpy as np
5
+ import ezmsg.core as ez
6
+ from ezmsg.util.generator import consumer
7
+ from ezmsg.util.messages.axisarray import AxisArray
8
+
9
+ from ..base import GenAxisArray
10
+
11
+
12
@consumer
def log(
    base: float = 10.0,
) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Element-wise logarithm in an arbitrary base of every incoming message's data.

    Computed as `np.log(data) / np.log(base)`. `np.log(base)` is evaluated once up front.
    `base` should be positive and not 1; otherwise the divisor is 0 or nan.

    Args:
        base: Logarithm base.

    Returns:
        A generator that answers each AxisArray sent in with a copy (via `dataclasses.replace`)
        carrying the logarithm as its data. Its first yielded value is an empty placeholder.
    """
    msg_out = AxisArray(np.array([]), dims=[""])
    log_base = np.log(base)
    while True:
        msg_in: AxisArray = yield msg_out
        msg_out = replace(msg_in, data=np.log(msg_in.data) / log_base)
22
+
23
+
24
class LogSettings(ez.Settings):
    """Settings for :obj:`Log`."""

    base: float = 10.0  # logarithm base handed to :obj:`log`
26
+
27
+
28
class Log(GenAxisArray):
    """Unit that applies the :obj:`log` generator to incoming AxisArrays."""

    SETTINGS = LogSettings

    def construct_generator(self):
        chosen_base = self.SETTINGS.base
        self.STATE.gen = log(base=chosen_base)
@@ -0,0 +1,31 @@
1
+ from dataclasses import replace
2
+ import typing
3
+
4
+ import numpy as np
5
+ import ezmsg.core as ez
6
+ from ezmsg.util.generator import consumer
7
+ from ezmsg.util.messages.axisarray import AxisArray
8
+
9
+ from ..base import GenAxisArray
10
+
11
+
12
@consumer
def scale(scale: float = 1.0) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Multiply every incoming message's data by a constant factor.

    Args:
        scale: Multiplicative factor.

    Returns:
        A generator that answers each AxisArray sent in with a copy (via `dataclasses.replace`)
        whose data is `scale * data`. Its first yielded value is an empty placeholder.
    """
    msg_out = AxisArray(np.array([]), dims=[""])
    while True:
        msg_in: AxisArray = yield msg_out
        msg_out = replace(msg_in, data=scale * msg_in.data)
19
+
20
+
21
class ScaleSettings(ez.Settings):
    """Settings for :obj:`Scale`."""

    scale: float = 1.0  # multiplicative factor handed to :obj:`scale`
23
+
24
+
25
class Scale(GenAxisArray):
    """Unit that applies the :obj:`scale` generator to incoming AxisArrays."""

    SETTINGS = ScaleSettings

    def construct_generator(self):
        factor = self.SETTINGS.scale
        self.STATE.gen = scale(scale=factor)
ezmsg/sigproc/messages.py CHANGED
@@ -1,11 +1,10 @@
1
1
  import warnings
2
2
  import time
3
+ import typing
3
4
 
4
5
  import numpy.typing as npt
5
-
6
6
  from ezmsg.util.messages.axisarray import AxisArray
7
7
 
8
- from typing import Optional
9
8
 
10
9
  # UPCOMING: TSMessage Deprecation
11
10
  # TSMessage is deprecated because it doesn't handle multiple time axes well.
@@ -21,7 +20,7 @@ def TSMessage(
21
20
  data: npt.NDArray,
22
21
  fs: float = 1.0,
23
22
  time_dim: int = 0,
24
- timestamp: Optional[float] = None,
23
+ timestamp: typing.Optional[float] = None,
25
24
  ) -> AxisArray:
26
25
  dims = [f"dim_{i}" for i in range(data.ndim)]
27
26
  dims[time_dim] = "time"