ezmsg-sigproc 1.4.2__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/activation.py +2 -2
- ezmsg/sigproc/affinetransform.py +13 -13
- ezmsg/sigproc/aggregate.py +49 -28
- ezmsg/sigproc/bandpower.py +2 -2
- ezmsg/sigproc/butterworthfilter.py +89 -90
- ezmsg/sigproc/cheby.py +119 -0
- ezmsg/sigproc/decimate.py +11 -15
- ezmsg/sigproc/downsample.py +8 -4
- ezmsg/sigproc/ewmfilter.py +9 -5
- ezmsg/sigproc/filter.py +82 -115
- ezmsg/sigproc/filterbank.py +5 -5
- ezmsg/sigproc/math/abs.py +1 -1
- ezmsg/sigproc/math/clip.py +1 -1
- ezmsg/sigproc/math/difference.py +1 -1
- ezmsg/sigproc/math/invert.py +1 -1
- ezmsg/sigproc/math/log.py +1 -1
- ezmsg/sigproc/math/scale.py +1 -1
- ezmsg/sigproc/messages.py +2 -3
- ezmsg/sigproc/sampler.py +16 -15
- ezmsg/sigproc/scaler.py +153 -35
- ezmsg/sigproc/signalinjector.py +7 -7
- ezmsg/sigproc/slicer.py +34 -14
- ezmsg/sigproc/spectrogram.py +6 -6
- ezmsg/sigproc/spectrum.py +18 -14
- ezmsg/sigproc/synth.py +43 -27
- ezmsg/sigproc/wavelets.py +42 -17
- ezmsg/sigproc/window.py +14 -13
- {ezmsg_sigproc-1.4.2.dist-info → ezmsg_sigproc-1.6.0.dist-info}/METADATA +4 -5
- ezmsg_sigproc-1.6.0.dist-info/RECORD +36 -0
- {ezmsg_sigproc-1.4.2.dist-info → ezmsg_sigproc-1.6.0.dist-info}/WHEEL +1 -1
- ezmsg_sigproc-1.4.2.dist-info/RECORD +0 -35
- {ezmsg_sigproc-1.4.2.dist-info → ezmsg_sigproc-1.6.0.dist-info}/licenses/LICENSE.txt +0 -0
ezmsg/sigproc/scaler.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
|
1
|
-
|
|
1
|
+
import functools
|
|
2
2
|
import typing
|
|
3
3
|
|
|
4
4
|
import numpy as np
|
|
5
5
|
import numpy.typing as npt
|
|
6
|
+
import scipy.signal
|
|
6
7
|
import ezmsg.core as ez
|
|
7
8
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
9
|
+
from ezmsg.util.messages.util import replace
|
|
8
10
|
from ezmsg.util.generator import consumer
|
|
9
11
|
|
|
10
12
|
from .base import GenAxisArray
|
|
@@ -28,9 +30,139 @@ def _alpha_from_tau(tau: float, dt: float) -> float:
|
|
|
28
30
|
return 1 - np.exp(-dt / tau)
|
|
29
31
|
|
|
30
32
|
|
|
33
|
+
def ewma_step(
|
|
34
|
+
sample: npt.NDArray, zi: npt.NDArray, alpha: float, beta: float | None = None
|
|
35
|
+
):
|
|
36
|
+
"""
|
|
37
|
+
Do an exponentially weighted moving average step.
|
|
38
|
+
|
|
39
|
+
Args:
|
|
40
|
+
sample: The new sample.
|
|
41
|
+
zi: The output of the previous step.
|
|
42
|
+
alpha: Fading factor.
|
|
43
|
+
beta: Persisting factor. If None, it is calculated as 1-alpha.
|
|
44
|
+
|
|
45
|
+
Returns:
|
|
46
|
+
alpha * sample + beta * zi
|
|
47
|
+
|
|
48
|
+
"""
|
|
49
|
+
# Potential micro-optimization:
|
|
50
|
+
# Current: scalar-arr multiplication, scalar-arr multiplication, arr-arr addition
|
|
51
|
+
# Alternative: arr-arr subtraction, arr-arr multiplication, arr-arr addition
|
|
52
|
+
# return zi + alpha * (new_sample - zi)
|
|
53
|
+
beta = beta or (1 - alpha)
|
|
54
|
+
return alpha * sample + beta * zi
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class EWMA:
    """
    Streaming exponentially weighted moving average along axis 0.

    Each call to :meth:`compute` filters a chunk of samples with
    ``y[n] = alpha * x[n] + (1 - alpha) * y[n-1]`` using
    :func:`scipy.signal.lfilter`. The filter state is carried over between
    calls, so consecutive chunks are smoothed as one continuous stream.
    """

    def __init__(self, alpha: float):
        """
        Args:
            alpha: Fading factor, i.e. the weight given to each new sample.
        """
        self.alpha = alpha
        self.beta = 1 - alpha
        # lfilter coefficients: b = [alpha], a = [1, -(1 - alpha)].
        self._filt_func = functools.partial(
            scipy.signal.lfilter, [alpha], [1.0, alpha - 1.0], axis=0
        )
        # lfilter state (zi in / zf out) carried between calls.
        # None until the first non-empty chunk has been seen.
        self.prev = None

    def compute(self, arr: npt.NDArray) -> npt.NDArray:
        """
        Smooth a chunk of samples and update the internal filter state.

        Args:
            arr: Samples with time along axis 0.

        Returns:
            The smoothed samples, with the same shape as ``arr``.
        """
        if arr.shape[0] == 0:
            # Nothing to filter. On the first call, seeding the state from
            # arr[:1] would give a zero-length zi, which lfilter rejects.
            # Leave the state untouched and wait for real samples.
            return np.empty(arr.shape, dtype=np.result_type(arr.dtype, np.float64))
        if self.prev is None:
            # Seed the state as if y[-1] == x[0]; the first output then equals
            # the first sample.
            self.prev = self.beta * arr[:1]
        expected, self.prev = self._filt_func(arr, zi=self.prev)
        return expected
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class EWMA_Deprecated:
|
|
73
|
+
"""
|
|
74
|
+
Grabbed these methods from https://stackoverflow.com/a/70998068 and other answers in that topic,
|
|
75
|
+
but they ended up being slower than the scipy.signal.lfilter method.
|
|
76
|
+
Additionally, `compute` and `compute2` suffer from potential errors as the vector length increases
|
|
77
|
+
and beta**n approaches zero.
|
|
78
|
+
"""
|
|
79
|
+
|
|
80
|
+
def __init__(self, alpha: float, max_len: int):
|
|
81
|
+
self.alpha = alpha
|
|
82
|
+
self.beta = 1 - alpha
|
|
83
|
+
self.prev: npt.NDArray | None = None
|
|
84
|
+
self.weights = np.empty((max_len + 1,), float)
|
|
85
|
+
self._precalc_weights(max_len)
|
|
86
|
+
self._step_func = functools.partial(ewma_step, alpha=self.alpha, beta=self.beta)
|
|
87
|
+
|
|
88
|
+
def _precalc_weights(self, n: int):
|
|
89
|
+
# (1-α)^0, (1-α)^1, (1-α)^2, ..., (1-α)^n
|
|
90
|
+
np.power(self.beta, np.arange(n + 1), out=self.weights)
|
|
91
|
+
|
|
92
|
+
def compute(self, arr: npt.NDArray, out: npt.NDArray | None = None) -> npt.NDArray:
|
|
93
|
+
if out is None:
|
|
94
|
+
out = np.empty(arr.shape, arr.dtype)
|
|
95
|
+
|
|
96
|
+
n = arr.shape[0]
|
|
97
|
+
weights = self.weights[:n]
|
|
98
|
+
weights = np.expand_dims(weights, list(range(1, arr.ndim)))
|
|
99
|
+
|
|
100
|
+
# α*P0, α*P1, α*P2, ..., α*Pn
|
|
101
|
+
np.multiply(self.alpha, arr, out)
|
|
102
|
+
|
|
103
|
+
# α*P0/(1-α)^0, α*P1/(1-α)^1, α*P2/(1-α)^2, ..., α*Pn/(1-α)^n
|
|
104
|
+
np.divide(out, weights, out)
|
|
105
|
+
|
|
106
|
+
# α*P0/(1-α)^0, α*P0/(1-α)^0 + α*P1/(1-α)^1, ...
|
|
107
|
+
np.cumsum(out, axis=0, out=out)
|
|
108
|
+
|
|
109
|
+
# (α*P0/(1-α)^0)*(1-α)^0, (α*P0/(1-α)^0 + α*P1/(1-α)^1)*(1-α)^1, ...
|
|
110
|
+
np.multiply(out, weights, out)
|
|
111
|
+
|
|
112
|
+
# Add the previous output
|
|
113
|
+
if self.prev is None:
|
|
114
|
+
self.prev = arr[:1]
|
|
115
|
+
|
|
116
|
+
out += self.prev * np.expand_dims(
|
|
117
|
+
self.weights[1 : n + 1], list(range(1, arr.ndim))
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
self.prev = out[-1:]
|
|
121
|
+
|
|
122
|
+
return out
|
|
123
|
+
|
|
124
|
+
def compute2(self, arr: npt.NDArray) -> npt.NDArray:
|
|
125
|
+
"""
|
|
126
|
+
Compute the Exponentially Weighted Moving Average (EWMA) of the input array.
|
|
127
|
+
|
|
128
|
+
Args:
|
|
129
|
+
arr: The input array to be smoothed.
|
|
130
|
+
|
|
131
|
+
Returns:
|
|
132
|
+
The smoothed array.
|
|
133
|
+
"""
|
|
134
|
+
n = arr.shape[0]
|
|
135
|
+
if n > len(self.weights):
|
|
136
|
+
self._precalc_weights(n)
|
|
137
|
+
weights = self.weights[:n][::-1]
|
|
138
|
+
weights = np.expand_dims(weights, list(range(1, arr.ndim)))
|
|
139
|
+
|
|
140
|
+
result = np.cumsum(self.alpha * weights * arr, axis=0)
|
|
141
|
+
result = result / weights
|
|
142
|
+
|
|
143
|
+
# Handle the first call when prev is unset
|
|
144
|
+
if self.prev is None:
|
|
145
|
+
self.prev = arr[:1]
|
|
146
|
+
|
|
147
|
+
result += self.prev * np.expand_dims(
|
|
148
|
+
self.weights[1 : n + 1], list(range(1, arr.ndim))
|
|
149
|
+
)
|
|
150
|
+
|
|
151
|
+
# Store the result back into prev
|
|
152
|
+
self.prev = result[-1]
|
|
153
|
+
|
|
154
|
+
return result
|
|
155
|
+
|
|
156
|
+
def compute_sample(self, new_sample: npt.NDArray) -> npt.NDArray:
|
|
157
|
+
if self.prev is None:
|
|
158
|
+
self.prev = new_sample
|
|
159
|
+
self.prev = self._step_func(new_sample, self.prev)
|
|
160
|
+
return self.prev
|
|
161
|
+
|
|
162
|
+
|
|
31
163
|
@consumer
|
|
32
164
|
def scaler(
|
|
33
|
-
time_constant: float = 1.0, axis:
|
|
165
|
+
time_constant: float = 1.0, axis: str | None = None
|
|
34
166
|
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
35
167
|
"""
|
|
36
168
|
Apply the adaptive standard scaler from https://riverml.xyz/latest/api/preprocessing/AdaptiveStandardScaler/
|
|
@@ -78,7 +210,7 @@ def scaler(
|
|
|
78
210
|
|
|
79
211
|
@consumer
|
|
80
212
|
def scaler_np(
|
|
81
|
-
time_constant: float = 1.0, axis:
|
|
213
|
+
time_constant: float = 1.0, axis: str | None = None
|
|
82
214
|
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
83
215
|
"""
|
|
84
216
|
Create a generator function that applies an adaptive standard scaler.
|
|
@@ -87,6 +219,7 @@ def scaler_np(
|
|
|
87
219
|
Args:
|
|
88
220
|
time_constant: Decay constant `tau` in seconds.
|
|
89
221
|
axis: The name of the axis to accumulate statistics over.
|
|
222
|
+
Note: The axis must be in the msg.axes and be of type AxisArray.LinearAxis.
|
|
90
223
|
|
|
91
224
|
Returns:
|
|
92
225
|
A primed generator object that expects to be sent a :obj:`AxisArray` via `.send(axis_array)`
|
|
@@ -95,10 +228,8 @@ def scaler_np(
|
|
|
95
228
|
msg_out = AxisArray(np.array([]), dims=[""])
|
|
96
229
|
|
|
97
230
|
# State variables
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
vars_means: typing.Optional[npt.NDArray] = None
|
|
101
|
-
vars_sq_means: typing.Optional[npt.NDArray] = None
|
|
231
|
+
samps_ewma: EWMA | None = None
|
|
232
|
+
vars_sq_ewma: EWMA | None = None
|
|
102
233
|
|
|
103
234
|
# Reset if input changes
|
|
104
235
|
check_input = {
|
|
@@ -107,45 +238,32 @@ def scaler_np(
|
|
|
107
238
|
"key": None, # Key change implies buffered means/vars are invalid.
|
|
108
239
|
}
|
|
109
240
|
|
|
110
|
-
def _ew_update(arr, prev, _alpha):
|
|
111
|
-
if np.all(prev == 0):
|
|
112
|
-
return arr
|
|
113
|
-
# return _alpha * arr + (1 - _alpha) * prev
|
|
114
|
-
# Micro-optimization: sub, mult, add (below) is faster than sub, mult, mult, add (above)
|
|
115
|
-
return prev + _alpha * (arr - prev)
|
|
116
|
-
|
|
117
241
|
while True:
|
|
118
242
|
msg_in: AxisArray = yield msg_out
|
|
119
243
|
|
|
120
244
|
axis = axis or msg_in.dims[0]
|
|
121
245
|
axis_idx = msg_in.get_axis_idx(axis)
|
|
122
246
|
|
|
123
|
-
if msg_in.axes[axis].gain != check_input["gain"]:
|
|
124
|
-
alpha = _alpha_from_tau(time_constant, msg_in.axes[axis].gain)
|
|
125
|
-
check_input["gain"] = msg_in.axes[axis].gain
|
|
126
|
-
|
|
127
247
|
data: npt.NDArray = np.moveaxis(msg_in.data, axis_idx, 0)
|
|
128
248
|
b_reset = data.shape[1:] != check_input["shape"]
|
|
129
|
-
b_reset
|
|
249
|
+
b_reset = b_reset or msg_in.axes[axis].gain != check_input["gain"]
|
|
250
|
+
b_reset = b_reset or msg_in.key != check_input["key"]
|
|
130
251
|
if b_reset:
|
|
131
252
|
check_input["shape"] = data.shape[1:]
|
|
253
|
+
check_input["gain"] = msg_in.axes[axis].gain
|
|
132
254
|
check_input["key"] = msg_in.key
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
# Update step
|
|
141
|
-
vars_means = _ew_update(sample, vars_means, alpha)
|
|
142
|
-
vars_sq_means = _ew_update(sample**2, vars_sq_means, alpha)
|
|
143
|
-
means = _ew_update(sample, means, alpha)
|
|
144
|
-
# Get step
|
|
145
|
-
varis = vars_sq_means - vars_means**2
|
|
146
|
-
y = (sample - means) / (varis**0.5)
|
|
147
|
-
result[sample_ix] = y
|
|
255
|
+
alpha = _alpha_from_tau(time_constant, msg_in.axes[axis].gain)
|
|
256
|
+
samps_ewma = EWMA(alpha=alpha)
|
|
257
|
+
vars_sq_ewma = EWMA(alpha=alpha)
|
|
258
|
+
|
|
259
|
+
# Update step
|
|
260
|
+
means = samps_ewma.compute(data)
|
|
261
|
+
vars_sq_means = vars_sq_ewma.compute(data**2)
|
|
148
262
|
|
|
263
|
+
# Get step
|
|
264
|
+
varis = vars_sq_means - means**2
|
|
265
|
+
with np.errstate(divide="ignore", invalid="ignore"):
|
|
266
|
+
result = (data - means) / (varis**0.5)
|
|
149
267
|
result[np.isnan(result)] = 0.0
|
|
150
268
|
result = np.moveaxis(result, 0, axis_idx)
|
|
151
269
|
msg_out = replace(msg_in, data=result)
|
|
@@ -158,7 +276,7 @@ class AdaptiveStandardScalerSettings(ez.Settings):
|
|
|
158
276
|
"""
|
|
159
277
|
|
|
160
278
|
time_constant: float = 1.0
|
|
161
|
-
axis:
|
|
279
|
+
axis: str | None = None
|
|
162
280
|
|
|
163
281
|
|
|
164
282
|
class AdaptiveStandardScaler(GenAxisArray):
|
ezmsg/sigproc/signalinjector.py
CHANGED
|
@@ -1,22 +1,22 @@
|
|
|
1
|
-
from dataclasses import replace
|
|
2
1
|
import typing
|
|
3
2
|
|
|
4
3
|
import ezmsg.core as ez
|
|
5
4
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
5
|
+
from ezmsg.util.messages.util import replace
|
|
6
6
|
import numpy as np
|
|
7
7
|
import numpy.typing as npt
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
class SignalInjectorSettings(ez.Settings):
|
|
11
11
|
time_dim: str = "time" # Input signal needs a time dimension with units in sec.
|
|
12
|
-
frequency:
|
|
12
|
+
frequency: float | None = None # Hz
|
|
13
13
|
amplitude: float = 1.0
|
|
14
|
-
mixing_seed:
|
|
14
|
+
mixing_seed: int | None = None
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
class SignalInjectorState(ez.State):
|
|
18
|
-
cur_shape:
|
|
19
|
-
cur_frequency:
|
|
18
|
+
cur_shape: tuple[int, ...] | None = None
|
|
19
|
+
cur_frequency: float | None = None
|
|
20
20
|
cur_amplitude: float
|
|
21
21
|
mixing: npt.NDArray
|
|
22
22
|
|
|
@@ -30,7 +30,7 @@ class SignalInjector(ez.Unit):
|
|
|
30
30
|
SETTINGS = SignalInjectorSettings
|
|
31
31
|
STATE = SignalInjectorState
|
|
32
32
|
|
|
33
|
-
INPUT_FREQUENCY = ez.InputStream(
|
|
33
|
+
INPUT_FREQUENCY = ez.InputStream(float | None)
|
|
34
34
|
INPUT_AMPLITUDE = ez.InputStream(float)
|
|
35
35
|
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
36
36
|
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
@@ -41,7 +41,7 @@ class SignalInjector(ez.Unit):
|
|
|
41
41
|
self.STATE.mixing = np.array([])
|
|
42
42
|
|
|
43
43
|
@ez.subscriber(INPUT_FREQUENCY)
|
|
44
|
-
async def on_frequency(self, msg:
|
|
44
|
+
async def on_frequency(self, msg: float | None) -> None:
|
|
45
45
|
self.STATE.cur_frequency = msg
|
|
46
46
|
|
|
47
47
|
@ez.subscriber(INPUT_AMPLITUDE)
|
ezmsg/sigproc/slicer.py
CHANGED
|
@@ -1,10 +1,14 @@
|
|
|
1
|
-
from dataclasses import replace
|
|
2
1
|
import typing
|
|
3
2
|
|
|
4
3
|
import numpy as np
|
|
5
4
|
import numpy.typing as npt
|
|
6
5
|
import ezmsg.core as ez
|
|
7
|
-
from ezmsg.util.messages.axisarray import
|
|
6
|
+
from ezmsg.util.messages.axisarray import (
|
|
7
|
+
AxisArray,
|
|
8
|
+
slice_along_axis,
|
|
9
|
+
AxisBase,
|
|
10
|
+
replace,
|
|
11
|
+
)
|
|
8
12
|
from ezmsg.util.generator import consumer
|
|
9
13
|
|
|
10
14
|
from .base import GenAxisArray
|
|
@@ -15,7 +19,10 @@ Slicer:Select a subset of data along a particular axis.
|
|
|
15
19
|
"""
|
|
16
20
|
|
|
17
21
|
|
|
18
|
-
def parse_slice(
|
|
22
|
+
def parse_slice(
|
|
23
|
+
s: str,
|
|
24
|
+
axinfo: AxisArray.CoordinateAxis | None = None,
|
|
25
|
+
) -> tuple[slice | int, ...]:
|
|
19
26
|
"""
|
|
20
27
|
Parses a string representation of a slice and returns a tuple of slice objects.
|
|
21
28
|
|
|
@@ -26,9 +33,13 @@ def parse_slice(s: str) -> typing.Tuple[typing.Union[slice, int], ...]:
|
|
|
26
33
|
- "5" (or any integer) -> (5,). Take only that item.
|
|
27
34
|
applying this to a ndarray or AxisArray will drop the dimension.
|
|
28
35
|
- A comma-separated list of the above -> a tuple of slices | ints
|
|
36
|
+
- A comma-separated list of values and axinfo is provided and is a CoordinateAxis -> a tuple of ints
|
|
29
37
|
|
|
30
38
|
Args:
|
|
31
39
|
s: The string representation of the slice.
|
|
40
|
+
axinfo: (Optional) If provided, and of type CoordinateAxis,
|
|
41
|
+
and `s` is a comma-separated list of values, then the values
|
|
42
|
+
in s will be checked against the values in axinfo.data.
|
|
32
43
|
|
|
33
44
|
Returns:
|
|
34
45
|
A tuple of slice objects and/or ints.
|
|
@@ -38,15 +49,21 @@ def parse_slice(s: str) -> typing.Tuple[typing.Union[slice, int], ...]:
|
|
|
38
49
|
if "," not in s:
|
|
39
50
|
parts = [part.strip() for part in s.split(":")]
|
|
40
51
|
if len(parts) == 1:
|
|
52
|
+
if (
|
|
53
|
+
axinfo is not None
|
|
54
|
+
and hasattr(axinfo, "data")
|
|
55
|
+
and parts[0] in axinfo.data
|
|
56
|
+
):
|
|
57
|
+
return tuple(np.where(axinfo.data == parts[0])[0])
|
|
41
58
|
return (int(parts[0]),)
|
|
42
59
|
return (slice(*(int(part.strip()) if part else None for part in parts)),)
|
|
43
|
-
suplist = [parse_slice(_) for _ in s.split(",")]
|
|
60
|
+
suplist = [parse_slice(_, axinfo=axinfo) for _ in s.split(",")]
|
|
44
61
|
return tuple([item for sublist in suplist for item in sublist])
|
|
45
62
|
|
|
46
63
|
|
|
47
64
|
@consumer
|
|
48
65
|
def slicer(
|
|
49
|
-
selection: str = "", axis:
|
|
66
|
+
selection: str = "", axis: str | None = None
|
|
50
67
|
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
51
68
|
"""
|
|
52
69
|
Slice along a particular axis.
|
|
@@ -63,8 +80,8 @@ def slicer(
|
|
|
63
80
|
msg_out = AxisArray(np.array([]), dims=[""])
|
|
64
81
|
|
|
65
82
|
# State variables
|
|
66
|
-
_slice:
|
|
67
|
-
new_axis:
|
|
83
|
+
_slice: slice | npt.NDArray | None = None
|
|
84
|
+
new_axis: AxisBase | None = None
|
|
68
85
|
b_change_dims: bool = False # If number of dimensions changes when slicing
|
|
69
86
|
|
|
70
87
|
# Reset if input changes
|
|
@@ -92,7 +109,7 @@ def slicer(
|
|
|
92
109
|
b_change_dims = False
|
|
93
110
|
|
|
94
111
|
# Calculate the slice
|
|
95
|
-
_slices = parse_slice(selection)
|
|
112
|
+
_slices = parse_slice(selection, msg_in.axes.get(axis, None))
|
|
96
113
|
if len(_slices) == 1:
|
|
97
114
|
_slice = _slices[0]
|
|
98
115
|
# Do we drop the sliced dimension?
|
|
@@ -107,12 +124,15 @@ def slicer(
|
|
|
107
124
|
# Create the output axis.
|
|
108
125
|
if (
|
|
109
126
|
axis in msg_in.axes
|
|
110
|
-
and hasattr(msg_in.axes[axis], "
|
|
111
|
-
and len(msg_in.axes[axis].
|
|
127
|
+
and hasattr(msg_in.axes[axis], "data")
|
|
128
|
+
and len(msg_in.axes[axis].data) > 0
|
|
112
129
|
):
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
130
|
+
in_data = np.array(msg_in.axes[axis].data)
|
|
131
|
+
if b_change_dims:
|
|
132
|
+
out_data = in_data[_slice : _slice + 1]
|
|
133
|
+
else:
|
|
134
|
+
out_data = in_data[_slice]
|
|
135
|
+
new_axis = replace(msg_in.axes[axis], data=out_data)
|
|
116
136
|
|
|
117
137
|
replace_kwargs = {}
|
|
118
138
|
if b_change_dims:
|
|
@@ -134,7 +154,7 @@ def slicer(
|
|
|
134
154
|
|
|
135
155
|
class SlicerSettings(ez.Settings):
|
|
136
156
|
selection: str = ""
|
|
137
|
-
axis:
|
|
157
|
+
axis: str | None = None
|
|
138
158
|
|
|
139
159
|
|
|
140
160
|
class Slicer(GenAxisArray):
|
ezmsg/sigproc/spectrogram.py
CHANGED
|
@@ -12,12 +12,12 @@ from .base import GenAxisArray
|
|
|
12
12
|
|
|
13
13
|
@consumer
|
|
14
14
|
def spectrogram(
|
|
15
|
-
window_dur:
|
|
16
|
-
window_shift:
|
|
15
|
+
window_dur: float | None = None,
|
|
16
|
+
window_shift: float | None = None,
|
|
17
17
|
window: WindowFunction = WindowFunction.HANNING,
|
|
18
18
|
transform: SpectralTransform = SpectralTransform.REL_DB,
|
|
19
19
|
output: SpectralOutput = SpectralOutput.POSITIVE,
|
|
20
|
-
) -> typing.Generator[
|
|
20
|
+
) -> typing.Generator[AxisArray | None, AxisArray, None]:
|
|
21
21
|
"""
|
|
22
22
|
Calculate a spectrogram on streaming data.
|
|
23
23
|
|
|
@@ -50,7 +50,7 @@ def spectrogram(
|
|
|
50
50
|
)
|
|
51
51
|
|
|
52
52
|
# State variables
|
|
53
|
-
msg_out:
|
|
53
|
+
msg_out: AxisArray | None = None
|
|
54
54
|
|
|
55
55
|
while True:
|
|
56
56
|
msg_in: AxisArray = yield msg_out
|
|
@@ -63,8 +63,8 @@ class SpectrogramSettings(ez.Settings):
|
|
|
63
63
|
See :obj:`spectrogram` for a description of the parameters.
|
|
64
64
|
"""
|
|
65
65
|
|
|
66
|
-
window_dur:
|
|
67
|
-
window_shift:
|
|
66
|
+
window_dur: float | None = None # window duration in seconds
|
|
67
|
+
window_shift: float | None = None
|
|
68
68
|
""""window step in seconds. If None, window_shift == window_dur"""
|
|
69
69
|
|
|
70
70
|
# See SpectrumSettings for details of following settings:
|
ezmsg/sigproc/spectrum.py
CHANGED
|
@@ -1,11 +1,14 @@
|
|
|
1
|
-
from dataclasses import replace
|
|
2
1
|
import enum
|
|
3
2
|
from functools import partial
|
|
4
3
|
import typing
|
|
5
4
|
|
|
6
5
|
import numpy as np
|
|
7
6
|
import ezmsg.core as ez
|
|
8
|
-
from ezmsg.util.messages.axisarray import
|
|
7
|
+
from ezmsg.util.messages.axisarray import (
|
|
8
|
+
AxisArray,
|
|
9
|
+
slice_along_axis,
|
|
10
|
+
replace,
|
|
11
|
+
)
|
|
9
12
|
from ezmsg.util.generator import consumer
|
|
10
13
|
|
|
11
14
|
from .base import GenAxisArray
|
|
@@ -65,20 +68,21 @@ class SpectralOutput(OptionsEnum):
|
|
|
65
68
|
|
|
66
69
|
@consumer
|
|
67
70
|
def spectrum(
|
|
68
|
-
axis:
|
|
69
|
-
out_axis:
|
|
71
|
+
axis: str | None = None,
|
|
72
|
+
out_axis: str | None = "freq",
|
|
70
73
|
window: WindowFunction = WindowFunction.HANNING,
|
|
71
74
|
transform: SpectralTransform = SpectralTransform.REL_DB,
|
|
72
75
|
output: SpectralOutput = SpectralOutput.POSITIVE,
|
|
73
|
-
norm:
|
|
76
|
+
norm: str | None = "forward",
|
|
74
77
|
do_fftshift: bool = True,
|
|
75
|
-
nfft:
|
|
78
|
+
nfft: int | None = None,
|
|
76
79
|
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
77
80
|
"""
|
|
78
81
|
Calculate a spectrum on a data slice.
|
|
79
82
|
|
|
80
83
|
Args:
|
|
81
84
|
axis: The name of the axis on which to calculate the spectrum.
|
|
85
|
+
Note: The axis must have an .axes entry of type LinearAxis, not CoordinateAxis.
|
|
82
86
|
out_axis: The name of the new axis. Defaults to "freq".
|
|
83
87
|
window: The :obj:`WindowFunction` to apply to the data slice prior to calculating the spectrum.
|
|
84
88
|
transform: The :obj:`SpectralTransform` to apply to the spectral magnitude.
|
|
@@ -101,10 +105,10 @@ def spectrum(
|
|
|
101
105
|
apply_window = window != WindowFunction.NONE
|
|
102
106
|
do_fftshift &= output == SpectralOutput.FULL
|
|
103
107
|
f_sl = slice(None)
|
|
104
|
-
freq_axis:
|
|
105
|
-
fftfun: typing.
|
|
106
|
-
f_transform: typing.
|
|
107
|
-
new_dims:
|
|
108
|
+
freq_axis: AxisArray.LinearAxis | None = None
|
|
109
|
+
fftfun: typing.Callable | None = None
|
|
110
|
+
f_transform: typing.Callable | None = None
|
|
111
|
+
new_dims: list[str] | None = None
|
|
108
112
|
|
|
109
113
|
# Reset if input changes substantially
|
|
110
114
|
check_input = {
|
|
@@ -174,7 +178,7 @@ def spectrum(
|
|
|
174
178
|
freqs = np.fft.fftshift(freqs, axes=-1)
|
|
175
179
|
freqs = freqs[f_sl]
|
|
176
180
|
freqs = freqs.tolist() # To please type checking
|
|
177
|
-
freq_axis = AxisArray.
|
|
181
|
+
freq_axis = AxisArray.LinearAxis(
|
|
178
182
|
unit="Hz", gain=freqs[1] - freqs[0], offset=freqs[0]
|
|
179
183
|
)
|
|
180
184
|
if out_axis is None:
|
|
@@ -234,9 +238,9 @@ class SpectrumSettings(ez.Settings):
|
|
|
234
238
|
See :obj:`spectrum` for a description of the parameters.
|
|
235
239
|
"""
|
|
236
240
|
|
|
237
|
-
axis:
|
|
238
|
-
# n:
|
|
239
|
-
out_axis:
|
|
241
|
+
axis: str | None = None
|
|
242
|
+
# n: int | None = None # n parameter for fft
|
|
243
|
+
out_axis: str | None = "freq" # If none; don't change dim name
|
|
240
244
|
window: WindowFunction = WindowFunction.HAMMING
|
|
241
245
|
transform: SpectralTransform = SpectralTransform.REL_DB
|
|
242
246
|
output: SpectralOutput = SpectralOutput.POSITIVE
|