ezmsg-sigproc 1.5.0__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/activation.py +3 -2
- ezmsg/sigproc/affinetransform.py +9 -8
- ezmsg/sigproc/aggregate.py +7 -7
- ezmsg/sigproc/bandpower.py +2 -2
- ezmsg/sigproc/butterworthfilter.py +88 -90
- ezmsg/sigproc/cheby.py +119 -0
- ezmsg/sigproc/decimate.py +35 -15
- ezmsg/sigproc/downsample.py +17 -8
- ezmsg/sigproc/ewmfilter.py +10 -5
- ezmsg/sigproc/filter.py +82 -115
- ezmsg/sigproc/filterbank.py +6 -5
- ezmsg/sigproc/math/abs.py +2 -1
- ezmsg/sigproc/math/clip.py +2 -1
- ezmsg/sigproc/math/difference.py +2 -1
- ezmsg/sigproc/math/invert.py +2 -1
- ezmsg/sigproc/math/log.py +2 -1
- ezmsg/sigproc/math/scale.py +2 -1
- ezmsg/sigproc/messages.py +1 -2
- ezmsg/sigproc/sampler.py +10 -14
- ezmsg/sigproc/scaler.py +153 -35
- ezmsg/sigproc/signalinjector.py +8 -7
- ezmsg/sigproc/slicer.py +6 -6
- ezmsg/sigproc/spectrogram.py +6 -6
- ezmsg/sigproc/spectrum.py +11 -11
- ezmsg/sigproc/synth.py +24 -23
- ezmsg/sigproc/wavelets.py +39 -15
- ezmsg/sigproc/window.py +12 -12
- {ezmsg_sigproc-1.5.0.dist-info → ezmsg_sigproc-1.7.0.dist-info}/METADATA +2 -2
- ezmsg_sigproc-1.7.0.dist-info/RECORD +36 -0
- ezmsg_sigproc-1.5.0.dist-info/RECORD +0 -35
- {ezmsg_sigproc-1.5.0.dist-info → ezmsg_sigproc-1.7.0.dist-info}/WHEEL +0 -0
- {ezmsg_sigproc-1.5.0.dist-info → ezmsg_sigproc-1.7.0.dist-info}/licenses/LICENSE.txt +0 -0
ezmsg/sigproc/scaler.py
CHANGED
|
@@ -1,9 +1,12 @@
|
|
|
1
|
+
import functools
|
|
1
2
|
import typing
|
|
2
3
|
|
|
3
4
|
import numpy as np
|
|
4
5
|
import numpy.typing as npt
|
|
6
|
+
import scipy.signal
|
|
5
7
|
import ezmsg.core as ez
|
|
6
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
8
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
9
|
+
from ezmsg.util.messages.util import replace
|
|
7
10
|
from ezmsg.util.generator import consumer
|
|
8
11
|
|
|
9
12
|
from .base import GenAxisArray
|
|
@@ -27,9 +30,139 @@ def _alpha_from_tau(tau: float, dt: float) -> float:
|
|
|
27
30
|
return 1 - np.exp(-dt / tau)
|
|
28
31
|
|
|
29
32
|
|
|
33
|
+
def ewma_step(
|
|
34
|
+
sample: npt.NDArray, zi: npt.NDArray, alpha: float, beta: float | None = None
|
|
35
|
+
):
|
|
36
|
+
"""
|
|
37
|
+
Do an exponentially weighted moving average step.
|
|
38
|
+
|
|
39
|
+
Args:
|
|
40
|
+
sample: The new sample.
|
|
41
|
+
zi: The output of the previous step.
|
|
42
|
+
alpha: Fading factor.
|
|
43
|
+
beta: Persisting factor. If None, it is calculated as 1-alpha.
|
|
44
|
+
|
|
45
|
+
Returns:
|
|
46
|
+
alpha * sample + beta * zi
|
|
47
|
+
|
|
48
|
+
"""
|
|
49
|
+
# Potential micro-optimization:
|
|
50
|
+
# Current: scalar-arr multiplication, scalar-arr multiplication, arr-arr addition
|
|
51
|
+
# Alternative: arr-arr subtraction, arr-arr multiplication, arr-arr addition
|
|
52
|
+
# return zi + alpha * (new_sample - zi)
|
|
53
|
+
beta = beta or (1 - alpha)
|
|
54
|
+
return alpha * sample + beta * zi
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class EWMA:
    """
    Streaming exponentially weighted moving average along axis 0.

    The recursion ``y[n] = alpha * x[n] + (1 - alpha) * y[n - 1]`` is evaluated
    with :obj:`scipy.signal.lfilter` (``b=[alpha]``, ``a=[1, alpha - 1]``).
    The filter state is carried between calls in ``self.prev`` so that
    consecutive chunks are processed as one continuous signal.
    """

    def __init__(self, alpha: float):
        """
        Args:
            alpha: Fading factor applied to each new sample.
        """
        self.beta = 1 - alpha
        self._filt_func = functools.partial(
            scipy.signal.lfilter, [alpha], [1.0, alpha - 1.0], axis=0
        )
        # lfilter state (`zi`); None until the first non-empty chunk arrives.
        self.prev = None

    def compute(self, arr: npt.NDArray) -> npt.NDArray:
        """
        Filter one chunk and update the carried filter state.

        Args:
            arr: Input chunk; the filter runs along axis 0.

        Returns:
            The filtered chunk, same shape as `arr`.
        """
        if arr.shape[0] == 0:
            # Nothing to filter. Return early so that an empty chunk cannot
            # seed `self.prev` with a zero-length state, which lfilter would
            # reject on this and every subsequent call.
            return np.empty(arr.shape, dtype=np.result_type(arr.dtype, np.float64))
        if self.prev is None:
            # Seed the state so that the first output equals the first sample:
            # alpha * x[0] + beta * x[0] == x[0].
            self.prev = self.beta * arr[:1]
        expected, self.prev = self._filt_func(arr, zi=self.prev)
        return expected
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class EWMA_Deprecated:
|
|
73
|
+
"""
|
|
74
|
+
Grabbed these methods from https://stackoverflow.com/a/70998068 and other answers in that topic,
|
|
75
|
+
but they ended up being slower than the scipy.signal.lfilter method.
|
|
76
|
+
Additionally, `compute` and `compute2` suffer from potential errors as the vector length increases
|
|
77
|
+
and beta**n approaches zero.
|
|
78
|
+
"""
|
|
79
|
+
|
|
80
|
+
def __init__(self, alpha: float, max_len: int):
|
|
81
|
+
self.alpha = alpha
|
|
82
|
+
self.beta = 1 - alpha
|
|
83
|
+
self.prev: npt.NDArray | None = None
|
|
84
|
+
self.weights = np.empty((max_len + 1,), float)
|
|
85
|
+
self._precalc_weights(max_len)
|
|
86
|
+
self._step_func = functools.partial(ewma_step, alpha=self.alpha, beta=self.beta)
|
|
87
|
+
|
|
88
|
+
def _precalc_weights(self, n: int):
|
|
89
|
+
# (1-α)^0, (1-α)^1, (1-α)^2, ..., (1-α)^n
|
|
90
|
+
np.power(self.beta, np.arange(n + 1), out=self.weights)
|
|
91
|
+
|
|
92
|
+
def compute(self, arr: npt.NDArray, out: npt.NDArray | None = None) -> npt.NDArray:
|
|
93
|
+
if out is None:
|
|
94
|
+
out = np.empty(arr.shape, arr.dtype)
|
|
95
|
+
|
|
96
|
+
n = arr.shape[0]
|
|
97
|
+
weights = self.weights[:n]
|
|
98
|
+
weights = np.expand_dims(weights, list(range(1, arr.ndim)))
|
|
99
|
+
|
|
100
|
+
# α*P0, α*P1, α*P2, ..., α*Pn
|
|
101
|
+
np.multiply(self.alpha, arr, out)
|
|
102
|
+
|
|
103
|
+
# α*P0/(1-α)^0, α*P1/(1-α)^1, α*P2/(1-α)^2, ..., α*Pn/(1-α)^n
|
|
104
|
+
np.divide(out, weights, out)
|
|
105
|
+
|
|
106
|
+
# α*P0/(1-α)^0, α*P0/(1-α)^0 + α*P1/(1-α)^1, ...
|
|
107
|
+
np.cumsum(out, axis=0, out=out)
|
|
108
|
+
|
|
109
|
+
# (α*P0/(1-α)^0)*(1-α)^0, (α*P0/(1-α)^0 + α*P1/(1-α)^1)*(1-α)^1, ...
|
|
110
|
+
np.multiply(out, weights, out)
|
|
111
|
+
|
|
112
|
+
# Add the previous output
|
|
113
|
+
if self.prev is None:
|
|
114
|
+
self.prev = arr[:1]
|
|
115
|
+
|
|
116
|
+
out += self.prev * np.expand_dims(
|
|
117
|
+
self.weights[1 : n + 1], list(range(1, arr.ndim))
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
self.prev = out[-1:]
|
|
121
|
+
|
|
122
|
+
return out
|
|
123
|
+
|
|
124
|
+
def compute2(self, arr: npt.NDArray) -> npt.NDArray:
|
|
125
|
+
"""
|
|
126
|
+
Compute the Exponentially Weighted Moving Average (EWMA) of the input array.
|
|
127
|
+
|
|
128
|
+
Args:
|
|
129
|
+
arr: The input array to be smoothed.
|
|
130
|
+
|
|
131
|
+
Returns:
|
|
132
|
+
The smoothed array.
|
|
133
|
+
"""
|
|
134
|
+
n = arr.shape[0]
|
|
135
|
+
if n > len(self.weights):
|
|
136
|
+
self._precalc_weights(n)
|
|
137
|
+
weights = self.weights[:n][::-1]
|
|
138
|
+
weights = np.expand_dims(weights, list(range(1, arr.ndim)))
|
|
139
|
+
|
|
140
|
+
result = np.cumsum(self.alpha * weights * arr, axis=0)
|
|
141
|
+
result = result / weights
|
|
142
|
+
|
|
143
|
+
# Handle the first call when prev is unset
|
|
144
|
+
if self.prev is None:
|
|
145
|
+
self.prev = arr[:1]
|
|
146
|
+
|
|
147
|
+
result += self.prev * np.expand_dims(
|
|
148
|
+
self.weights[1 : n + 1], list(range(1, arr.ndim))
|
|
149
|
+
)
|
|
150
|
+
|
|
151
|
+
# Store the result back into prev
|
|
152
|
+
self.prev = result[-1]
|
|
153
|
+
|
|
154
|
+
return result
|
|
155
|
+
|
|
156
|
+
def compute_sample(self, new_sample: npt.NDArray) -> npt.NDArray:
|
|
157
|
+
if self.prev is None:
|
|
158
|
+
self.prev = new_sample
|
|
159
|
+
self.prev = self._step_func(new_sample, self.prev)
|
|
160
|
+
return self.prev
|
|
161
|
+
|
|
162
|
+
|
|
30
163
|
@consumer
|
|
31
164
|
def scaler(
|
|
32
|
-
time_constant: float = 1.0, axis:
|
|
165
|
+
time_constant: float = 1.0, axis: str | None = None
|
|
33
166
|
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
34
167
|
"""
|
|
35
168
|
Apply the adaptive standard scaler from https://riverml.xyz/latest/api/preprocessing/AdaptiveStandardScaler/
|
|
@@ -77,7 +210,7 @@ def scaler(
|
|
|
77
210
|
|
|
78
211
|
@consumer
|
|
79
212
|
def scaler_np(
|
|
80
|
-
time_constant: float = 1.0, axis:
|
|
213
|
+
time_constant: float = 1.0, axis: str | None = None
|
|
81
214
|
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
82
215
|
"""
|
|
83
216
|
Create a generator function that applies an adaptive standard scaler.
|
|
@@ -95,10 +228,8 @@ def scaler_np(
|
|
|
95
228
|
msg_out = AxisArray(np.array([]), dims=[""])
|
|
96
229
|
|
|
97
230
|
# State variables
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
vars_means: typing.Optional[npt.NDArray] = None
|
|
101
|
-
vars_sq_means: typing.Optional[npt.NDArray] = None
|
|
231
|
+
samps_ewma: EWMA | None = None
|
|
232
|
+
vars_sq_ewma: EWMA | None = None
|
|
102
233
|
|
|
103
234
|
# Reset if input changes
|
|
104
235
|
check_input = {
|
|
@@ -107,45 +238,32 @@ def scaler_np(
|
|
|
107
238
|
"key": None, # Key change implies buffered means/vars are invalid.
|
|
108
239
|
}
|
|
109
240
|
|
|
110
|
-
def _ew_update(arr, prev, _alpha):
|
|
111
|
-
if np.all(prev == 0):
|
|
112
|
-
return arr
|
|
113
|
-
# return _alpha * arr + (1 - _alpha) * prev
|
|
114
|
-
# Micro-optimization: sub, mult, add (below) is faster than sub, mult, mult, add (above)
|
|
115
|
-
return prev + _alpha * (arr - prev)
|
|
116
|
-
|
|
117
241
|
while True:
|
|
118
242
|
msg_in: AxisArray = yield msg_out
|
|
119
243
|
|
|
120
244
|
axis = axis or msg_in.dims[0]
|
|
121
245
|
axis_idx = msg_in.get_axis_idx(axis)
|
|
122
246
|
|
|
123
|
-
if msg_in.axes[axis].gain != check_input["gain"]:
|
|
124
|
-
alpha = _alpha_from_tau(time_constant, msg_in.axes[axis].gain)
|
|
125
|
-
check_input["gain"] = msg_in.axes[axis].gain
|
|
126
|
-
|
|
127
247
|
data: npt.NDArray = np.moveaxis(msg_in.data, axis_idx, 0)
|
|
128
248
|
b_reset = data.shape[1:] != check_input["shape"]
|
|
129
|
-
b_reset
|
|
249
|
+
b_reset = b_reset or msg_in.axes[axis].gain != check_input["gain"]
|
|
250
|
+
b_reset = b_reset or msg_in.key != check_input["key"]
|
|
130
251
|
if b_reset:
|
|
131
252
|
check_input["shape"] = data.shape[1:]
|
|
253
|
+
check_input["gain"] = msg_in.axes[axis].gain
|
|
132
254
|
check_input["key"] = msg_in.key
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
# Update step
|
|
141
|
-
vars_means = _ew_update(sample, vars_means, alpha)
|
|
142
|
-
vars_sq_means = _ew_update(sample**2, vars_sq_means, alpha)
|
|
143
|
-
means = _ew_update(sample, means, alpha)
|
|
144
|
-
# Get step
|
|
145
|
-
varis = vars_sq_means - vars_means**2
|
|
146
|
-
y = (sample - means) / (varis**0.5)
|
|
147
|
-
result[sample_ix] = y
|
|
255
|
+
alpha = _alpha_from_tau(time_constant, msg_in.axes[axis].gain)
|
|
256
|
+
samps_ewma = EWMA(alpha=alpha)
|
|
257
|
+
vars_sq_ewma = EWMA(alpha=alpha)
|
|
258
|
+
|
|
259
|
+
# Update step
|
|
260
|
+
means = samps_ewma.compute(data)
|
|
261
|
+
vars_sq_means = vars_sq_ewma.compute(data**2)
|
|
148
262
|
|
|
263
|
+
# Get step
|
|
264
|
+
varis = vars_sq_means - means**2
|
|
265
|
+
with np.errstate(divide="ignore", invalid="ignore"):
|
|
266
|
+
result = (data - means) / (varis**0.5)
|
|
149
267
|
result[np.isnan(result)] = 0.0
|
|
150
268
|
result = np.moveaxis(result, 0, axis_idx)
|
|
151
269
|
msg_out = replace(msg_in, data=result)
|
|
@@ -158,7 +276,7 @@ class AdaptiveStandardScalerSettings(ez.Settings):
|
|
|
158
276
|
"""
|
|
159
277
|
|
|
160
278
|
time_constant: float = 1.0
|
|
161
|
-
axis:
|
|
279
|
+
axis: str | None = None
|
|
162
280
|
|
|
163
281
|
|
|
164
282
|
class AdaptiveStandardScaler(GenAxisArray):
|
ezmsg/sigproc/signalinjector.py
CHANGED
|
@@ -1,21 +1,22 @@
|
|
|
1
1
|
import typing
|
|
2
2
|
|
|
3
3
|
import ezmsg.core as ez
|
|
4
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
4
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
5
|
+
from ezmsg.util.messages.util import replace
|
|
5
6
|
import numpy as np
|
|
6
7
|
import numpy.typing as npt
|
|
7
8
|
|
|
8
9
|
|
|
9
10
|
class SignalInjectorSettings(ez.Settings):
|
|
10
11
|
time_dim: str = "time" # Input signal needs a time dimension with units in sec.
|
|
11
|
-
frequency:
|
|
12
|
+
frequency: float | None = None # Hz
|
|
12
13
|
amplitude: float = 1.0
|
|
13
|
-
mixing_seed:
|
|
14
|
+
mixing_seed: int | None = None
|
|
14
15
|
|
|
15
16
|
|
|
16
17
|
class SignalInjectorState(ez.State):
|
|
17
|
-
cur_shape:
|
|
18
|
-
cur_frequency:
|
|
18
|
+
cur_shape: tuple[int, ...] | None = None
|
|
19
|
+
cur_frequency: float | None = None
|
|
19
20
|
cur_amplitude: float
|
|
20
21
|
mixing: npt.NDArray
|
|
21
22
|
|
|
@@ -29,7 +30,7 @@ class SignalInjector(ez.Unit):
|
|
|
29
30
|
SETTINGS = SignalInjectorSettings
|
|
30
31
|
STATE = SignalInjectorState
|
|
31
32
|
|
|
32
|
-
INPUT_FREQUENCY = ez.InputStream(
|
|
33
|
+
INPUT_FREQUENCY = ez.InputStream(float | None)
|
|
33
34
|
INPUT_AMPLITUDE = ez.InputStream(float)
|
|
34
35
|
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
35
36
|
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
@@ -40,7 +41,7 @@ class SignalInjector(ez.Unit):
|
|
|
40
41
|
self.STATE.mixing = np.array([])
|
|
41
42
|
|
|
42
43
|
@ez.subscriber(INPUT_FREQUENCY)
|
|
43
|
-
async def on_frequency(self, msg:
|
|
44
|
+
async def on_frequency(self, msg: float | None) -> None:
|
|
44
45
|
self.STATE.cur_frequency = msg
|
|
45
46
|
|
|
46
47
|
@ez.subscriber(INPUT_AMPLITUDE)
|
ezmsg/sigproc/slicer.py
CHANGED
|
@@ -21,8 +21,8 @@ Slicer:Select a subset of data along a particular axis.
|
|
|
21
21
|
|
|
22
22
|
def parse_slice(
|
|
23
23
|
s: str,
|
|
24
|
-
axinfo:
|
|
25
|
-
) ->
|
|
24
|
+
axinfo: AxisArray.CoordinateAxis | None = None,
|
|
25
|
+
) -> tuple[slice | int, ...]:
|
|
26
26
|
"""
|
|
27
27
|
Parses a string representation of a slice and returns a tuple of slice objects.
|
|
28
28
|
|
|
@@ -63,7 +63,7 @@ def parse_slice(
|
|
|
63
63
|
|
|
64
64
|
@consumer
|
|
65
65
|
def slicer(
|
|
66
|
-
selection: str = "", axis:
|
|
66
|
+
selection: str = "", axis: str | None = None
|
|
67
67
|
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
68
68
|
"""
|
|
69
69
|
Slice along a particular axis.
|
|
@@ -80,8 +80,8 @@ def slicer(
|
|
|
80
80
|
msg_out = AxisArray(np.array([]), dims=[""])
|
|
81
81
|
|
|
82
82
|
# State variables
|
|
83
|
-
_slice:
|
|
84
|
-
new_axis:
|
|
83
|
+
_slice: slice | npt.NDArray | None = None
|
|
84
|
+
new_axis: AxisBase | None = None
|
|
85
85
|
b_change_dims: bool = False # If number of dimensions changes when slicing
|
|
86
86
|
|
|
87
87
|
# Reset if input changes
|
|
@@ -154,7 +154,7 @@ def slicer(
|
|
|
154
154
|
|
|
155
155
|
class SlicerSettings(ez.Settings):
|
|
156
156
|
selection: str = ""
|
|
157
|
-
axis:
|
|
157
|
+
axis: str | None = None
|
|
158
158
|
|
|
159
159
|
|
|
160
160
|
class Slicer(GenAxisArray):
|
ezmsg/sigproc/spectrogram.py
CHANGED
|
@@ -12,12 +12,12 @@ from .base import GenAxisArray
|
|
|
12
12
|
|
|
13
13
|
@consumer
|
|
14
14
|
def spectrogram(
|
|
15
|
-
window_dur:
|
|
16
|
-
window_shift:
|
|
15
|
+
window_dur: float | None = None,
|
|
16
|
+
window_shift: float | None = None,
|
|
17
17
|
window: WindowFunction = WindowFunction.HANNING,
|
|
18
18
|
transform: SpectralTransform = SpectralTransform.REL_DB,
|
|
19
19
|
output: SpectralOutput = SpectralOutput.POSITIVE,
|
|
20
|
-
) -> typing.Generator[
|
|
20
|
+
) -> typing.Generator[AxisArray | None, AxisArray, None]:
|
|
21
21
|
"""
|
|
22
22
|
Calculate a spectrogram on streaming data.
|
|
23
23
|
|
|
@@ -50,7 +50,7 @@ def spectrogram(
|
|
|
50
50
|
)
|
|
51
51
|
|
|
52
52
|
# State variables
|
|
53
|
-
msg_out:
|
|
53
|
+
msg_out: AxisArray | None = None
|
|
54
54
|
|
|
55
55
|
while True:
|
|
56
56
|
msg_in: AxisArray = yield msg_out
|
|
@@ -63,8 +63,8 @@ class SpectrogramSettings(ez.Settings):
|
|
|
63
63
|
See :obj:`spectrogram` for a description of the parameters.
|
|
64
64
|
"""
|
|
65
65
|
|
|
66
|
-
window_dur:
|
|
67
|
-
window_shift:
|
|
66
|
+
window_dur: float | None = None # window duration in seconds
|
|
67
|
+
window_shift: float | None = None
|
|
68
68
|
""""window step in seconds. If None, window_shift == window_dur"""
|
|
69
69
|
|
|
70
70
|
# See SpectrumSettings for details of following settings:
|
ezmsg/sigproc/spectrum.py
CHANGED
|
@@ -68,14 +68,14 @@ class SpectralOutput(OptionsEnum):
|
|
|
68
68
|
|
|
69
69
|
@consumer
|
|
70
70
|
def spectrum(
|
|
71
|
-
axis:
|
|
72
|
-
out_axis:
|
|
71
|
+
axis: str | None = None,
|
|
72
|
+
out_axis: str | None = "freq",
|
|
73
73
|
window: WindowFunction = WindowFunction.HANNING,
|
|
74
74
|
transform: SpectralTransform = SpectralTransform.REL_DB,
|
|
75
75
|
output: SpectralOutput = SpectralOutput.POSITIVE,
|
|
76
|
-
norm:
|
|
76
|
+
norm: str | None = "forward",
|
|
77
77
|
do_fftshift: bool = True,
|
|
78
|
-
nfft:
|
|
78
|
+
nfft: int | None = None,
|
|
79
79
|
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
80
80
|
"""
|
|
81
81
|
Calculate a spectrum on a data slice.
|
|
@@ -105,10 +105,10 @@ def spectrum(
|
|
|
105
105
|
apply_window = window != WindowFunction.NONE
|
|
106
106
|
do_fftshift &= output == SpectralOutput.FULL
|
|
107
107
|
f_sl = slice(None)
|
|
108
|
-
freq_axis:
|
|
109
|
-
fftfun: typing.
|
|
110
|
-
f_transform: typing.
|
|
111
|
-
new_dims:
|
|
108
|
+
freq_axis: AxisArray.LinearAxis | None = None
|
|
109
|
+
fftfun: typing.Callable | None = None
|
|
110
|
+
f_transform: typing.Callable | None = None
|
|
111
|
+
new_dims: list[str] | None = None
|
|
112
112
|
|
|
113
113
|
# Reset if input changes substantially
|
|
114
114
|
check_input = {
|
|
@@ -238,9 +238,9 @@ class SpectrumSettings(ez.Settings):
|
|
|
238
238
|
See :obj:`spectrum` for a description of the parameters.
|
|
239
239
|
"""
|
|
240
240
|
|
|
241
|
-
axis:
|
|
242
|
-
# n:
|
|
243
|
-
out_axis:
|
|
241
|
+
axis: str | None = None
|
|
242
|
+
# n: int | None = None # n parameter for fft
|
|
243
|
+
out_axis: str | None = "freq" # If none; don't change dim name
|
|
244
244
|
window: WindowFunction = WindowFunction.HAMMING
|
|
245
245
|
transform: SpectralTransform = SpectralTransform.REL_DB
|
|
246
246
|
output: SpectralOutput = SpectralOutput.POSITIVE
|
ezmsg/sigproc/synth.py
CHANGED
|
@@ -1,18 +1,19 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
from dataclasses import field
|
|
3
3
|
import time
|
|
4
|
-
|
|
4
|
+
import typing
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
import ezmsg.core as ez
|
|
8
8
|
from ezmsg.util.generator import consumer
|
|
9
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
9
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
10
|
+
from ezmsg.util.messages.util import replace
|
|
10
11
|
|
|
11
12
|
from .butterworthfilter import ButterworthFilter, ButterworthFilterSettings
|
|
12
13
|
from .base import GenAxisArray
|
|
13
14
|
|
|
14
15
|
|
|
15
|
-
def clock(dispatch_rate:
|
|
16
|
+
def clock(dispatch_rate: float | None) -> typing.Generator[ez.Flag, None, None]:
|
|
16
17
|
"""
|
|
17
18
|
Construct a generator that yields events at a specified rate.
|
|
18
19
|
|
|
@@ -32,7 +33,7 @@ def clock(dispatch_rate: Optional[float]) -> Generator[ez.Flag, None, None]:
|
|
|
32
33
|
yield ez.Flag()
|
|
33
34
|
|
|
34
35
|
|
|
35
|
-
async def aclock(dispatch_rate:
|
|
36
|
+
async def aclock(dispatch_rate: float | None) -> typing.AsyncGenerator[ez.Flag, None]:
|
|
36
37
|
"""
|
|
37
38
|
``asyncio`` version of :obj:`clock`.
|
|
38
39
|
|
|
@@ -53,12 +54,12 @@ class ClockSettings(ez.Settings):
|
|
|
53
54
|
"""Settings for :obj:`Clock`. See :obj:`clock` for parameter description."""
|
|
54
55
|
|
|
55
56
|
# Message dispatch rate (Hz), or None (fast as possible)
|
|
56
|
-
dispatch_rate:
|
|
57
|
+
dispatch_rate: float | None
|
|
57
58
|
|
|
58
59
|
|
|
59
60
|
class ClockState(ez.State):
|
|
60
61
|
cur_settings: ClockSettings
|
|
61
|
-
gen: AsyncGenerator
|
|
62
|
+
gen: typing.AsyncGenerator
|
|
62
63
|
|
|
63
64
|
|
|
64
65
|
class Clock(ez.Unit):
|
|
@@ -83,7 +84,7 @@ class Clock(ez.Unit):
|
|
|
83
84
|
self.construct_generator()
|
|
84
85
|
|
|
85
86
|
@ez.publisher(OUTPUT_CLOCK)
|
|
86
|
-
async def generate(self) -> AsyncGenerator:
|
|
87
|
+
async def generate(self) -> typing.AsyncGenerator:
|
|
87
88
|
while True:
|
|
88
89
|
out = await self.STATE.gen.__anext__()
|
|
89
90
|
if out:
|
|
@@ -93,11 +94,11 @@ class Clock(ez.Unit):
|
|
|
93
94
|
# COUNTER - Generate incrementing integer. fs and dispatch_rate parameters combine to give many options. #
|
|
94
95
|
async def acounter(
|
|
95
96
|
n_time: int,
|
|
96
|
-
fs:
|
|
97
|
+
fs: float | None,
|
|
97
98
|
n_ch: int = 1,
|
|
98
|
-
dispatch_rate:
|
|
99
|
-
mod:
|
|
100
|
-
) -> AsyncGenerator[AxisArray, None]:
|
|
99
|
+
dispatch_rate: float | str | None = None,
|
|
100
|
+
mod: int | None = None,
|
|
101
|
+
) -> typing.AsyncGenerator[AxisArray, None]:
|
|
101
102
|
"""
|
|
102
103
|
Construct an asynchronous generator to generate AxisArray objects at a specified rate
|
|
103
104
|
and with the specified sampling rate.
|
|
@@ -206,14 +207,14 @@ class CounterSettings(ez.Settings):
|
|
|
206
207
|
# Message dispatch rate (Hz), 'realtime', 'ext_clock', or None (fast as possible)
|
|
207
208
|
# Note: if dispatch_rate is a float then time offsets will be synthetic and the
|
|
208
209
|
# system will run faster or slower than wall clock time.
|
|
209
|
-
dispatch_rate:
|
|
210
|
+
dispatch_rate: float | str | None = None
|
|
210
211
|
|
|
211
212
|
# If set to an integer, counter will rollover
|
|
212
|
-
mod:
|
|
213
|
+
mod: int | None = None
|
|
213
214
|
|
|
214
215
|
|
|
215
216
|
class CounterState(ez.State):
|
|
216
|
-
gen: AsyncGenerator[AxisArray,
|
|
217
|
+
gen: typing.AsyncGenerator[AxisArray, ez.Flag | None]
|
|
217
218
|
cur_settings: CounterSettings
|
|
218
219
|
new_generator: asyncio.Event
|
|
219
220
|
|
|
@@ -262,7 +263,7 @@ class Counter(ez.Unit):
|
|
|
262
263
|
yield self.OUTPUT_SIGNAL, out
|
|
263
264
|
|
|
264
265
|
@ez.publisher(OUTPUT_SIGNAL)
|
|
265
|
-
async def run_generator(self) -> AsyncGenerator:
|
|
266
|
+
async def run_generator(self) -> typing.AsyncGenerator:
|
|
266
267
|
while True:
|
|
267
268
|
await self.STATE.new_generator.wait()
|
|
268
269
|
self.STATE.new_generator.clear()
|
|
@@ -277,11 +278,11 @@ class Counter(ez.Unit):
|
|
|
277
278
|
|
|
278
279
|
@consumer
|
|
279
280
|
def sin(
|
|
280
|
-
axis:
|
|
281
|
+
axis: str | None = "time",
|
|
281
282
|
freq: float = 1.0,
|
|
282
283
|
amp: float = 1.0,
|
|
283
284
|
phase: float = 0.0,
|
|
284
|
-
) -> Generator[AxisArray, AxisArray, None]:
|
|
285
|
+
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
285
286
|
"""
|
|
286
287
|
Construct a generator of sinusoidal waveforms in AxisArray objects.
|
|
287
288
|
|
|
@@ -319,7 +320,7 @@ class SinGeneratorSettings(ez.Settings):
|
|
|
319
320
|
See :obj:`sin` for parameter descriptions.
|
|
320
321
|
"""
|
|
321
322
|
|
|
322
|
-
time_axis:
|
|
323
|
+
time_axis: str | None = "time"
|
|
323
324
|
freq: float = 1.0 # Oscillation frequency in Hz
|
|
324
325
|
amp: float = 1.0 # Amplitude
|
|
325
326
|
phase: float = 0.0 # Phase offset (in radians)
|
|
@@ -353,7 +354,7 @@ class OscillatorSettings(ez.Settings):
|
|
|
353
354
|
n_ch: int = 1
|
|
354
355
|
"""Number of channels to output per block"""
|
|
355
356
|
|
|
356
|
-
dispatch_rate:
|
|
357
|
+
dispatch_rate: float | str | None = None
|
|
357
358
|
"""(Hz) | 'realtime' | 'ext_clock'"""
|
|
358
359
|
|
|
359
360
|
freq: float = 1.0
|
|
@@ -435,7 +436,7 @@ class RandomGenerator(ez.Unit):
|
|
|
435
436
|
|
|
436
437
|
@ez.subscriber(INPUT_SIGNAL)
|
|
437
438
|
@ez.publisher(OUTPUT_SIGNAL)
|
|
438
|
-
async def generate(self, msg: AxisArray) -> AsyncGenerator:
|
|
439
|
+
async def generate(self, msg: AxisArray) -> typing.AsyncGenerator:
|
|
439
440
|
random_data = np.random.normal(
|
|
440
441
|
size=msg.shape, loc=self.SETTINGS.loc, scale=self.SETTINGS.scale
|
|
441
442
|
)
|
|
@@ -451,7 +452,7 @@ class NoiseSettings(ez.Settings):
|
|
|
451
452
|
n_time: int # Number of samples to output per block
|
|
452
453
|
fs: float # Sampling rate of signal output in Hz
|
|
453
454
|
n_ch: int = 1 # Number of channels to output
|
|
454
|
-
dispatch_rate:
|
|
455
|
+
dispatch_rate: float | str | None = None
|
|
455
456
|
"""(Hz), 'realtime', or 'ext_clock'"""
|
|
456
457
|
loc: float = 0.0 # DC offset
|
|
457
458
|
scale: float = 1.0 # Scale (in standard deviations)
|
|
@@ -553,12 +554,12 @@ class Add(ez.Unit):
|
|
|
553
554
|
self.STATE.queue_b.put_nowait(msg)
|
|
554
555
|
|
|
555
556
|
@ez.publisher(OUTPUT_SIGNAL)
|
|
556
|
-
async def output(self) -> AsyncGenerator:
|
|
557
|
+
async def output(self) -> typing.AsyncGenerator:
|
|
557
558
|
while True:
|
|
558
559
|
a = await self.STATE.queue_a.get()
|
|
559
560
|
b = await self.STATE.queue_b.get()
|
|
560
561
|
|
|
561
|
-
yield
|
|
562
|
+
yield self.OUTPUT_SIGNAL, replace(a, data=a.data + b.data)
|
|
562
563
|
|
|
563
564
|
|
|
564
565
|
class EEGSynthSettings(ez.Settings):
|
ezmsg/sigproc/wavelets.py
CHANGED
|
@@ -4,7 +4,8 @@ import numpy as np
|
|
|
4
4
|
import numpy.typing as npt
|
|
5
5
|
import pywt
|
|
6
6
|
import ezmsg.core as ez
|
|
7
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
8
|
+
from ezmsg.util.messages.util import replace
|
|
8
9
|
from ezmsg.util.generator import consumer
|
|
9
10
|
|
|
10
11
|
from .base import GenAxisArray
|
|
@@ -13,44 +14,61 @@ from .filterbank import filterbank, FilterbankMode, MinPhaseMode
|
|
|
13
14
|
|
|
14
15
|
@consumer
|
|
15
16
|
def cwt(
|
|
16
|
-
|
|
17
|
-
wavelet:
|
|
17
|
+
frequencies: list | tuple | npt.NDArray | None,
|
|
18
|
+
wavelet: str | pywt.ContinuousWavelet | pywt.Wavelet,
|
|
18
19
|
min_phase: MinPhaseMode = MinPhaseMode.NONE,
|
|
19
20
|
axis: str = "time",
|
|
21
|
+
scales: list | tuple | npt.NDArray | None = None,
|
|
20
22
|
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
21
23
|
"""
|
|
22
24
|
Perform a continuous wavelet transform.
|
|
23
25
|
The function is equivalent to the :obj:`pywt.cwt` function, but is designed to work with streaming data.
|
|
24
26
|
|
|
25
27
|
Args:
|
|
26
|
-
|
|
28
|
+
frequencies: The wavelet frequencies to use in Hz. If `None` provided then the scales will be used.
|
|
29
|
+
Note: frequencies will be sorted from smallest to largest.
|
|
27
30
|
wavelet: Wavelet object or name of wavelet to use.
|
|
28
31
|
min_phase: See filterbank MinPhaseMode for details.
|
|
29
32
|
axis: The target axis for operation. Note that this will be moved to the -1th dimension
|
|
30
33
|
because fft and matrix multiplication is much faster on the last axis.
|
|
31
34
|
This axis must be in the msg.axes and it must be of type AxisArray.LinearAxis.
|
|
35
|
+
scales: The scales to use. If None, the scales will be calculated from the frequencies.
|
|
36
|
+
Note: Scales will be sorted from largest to smallest.
|
|
37
|
+
Note: Use of scales is deprecated in favor of frequencies. Convert scales to frequencies using
|
|
38
|
+
`pywt.scale2frequency(wavelet, scales, precision=10) * fs` where fs is the sampling frequency.
|
|
32
39
|
|
|
33
40
|
Returns:
|
|
34
41
|
A primed Generator object that expects an :obj:`AxisArray` via `.send(axis_array)` of continuous data
|
|
35
42
|
and yields an :obj:`AxisArray` with a continuous wavelet transform in its data.
|
|
36
43
|
"""
|
|
37
|
-
|
|
44
|
+
precision = 10
|
|
45
|
+
msg_out: AxisArray | None = None
|
|
38
46
|
|
|
39
47
|
# Check parameters
|
|
40
|
-
scales
|
|
41
|
-
|
|
42
|
-
|
|
48
|
+
if frequencies is None and scales is None:
|
|
49
|
+
raise ValueError("Either frequencies or scales must be provided.")
|
|
50
|
+
if frequencies is not None and scales is not None:
|
|
51
|
+
raise ValueError("Only one of frequencies or scales can be provided.")
|
|
52
|
+
if scales is not None:
|
|
53
|
+
scales = np.sort(scales)[::-1]
|
|
54
|
+
assert np.all(scales > 0), "scales must be positive."
|
|
55
|
+
assert scales.ndim == 1, "scales must be a 1D list, tuple, or array."
|
|
56
|
+
|
|
43
57
|
if not isinstance(wavelet, (pywt.ContinuousWavelet, pywt.Wavelet)):
|
|
44
58
|
wavelet = pywt.DiscreteContinuousWavelet(wavelet)
|
|
45
|
-
|
|
59
|
+
|
|
60
|
+
if frequencies is not None:
|
|
61
|
+
frequencies = np.sort(frequencies)
|
|
62
|
+
assert np.all(frequencies > 0), "frequencies must be positive."
|
|
63
|
+
assert frequencies.ndim == 1, "frequencies must be a 1D list, tuple, or array."
|
|
46
64
|
|
|
47
65
|
# State variables
|
|
48
|
-
neg_rt_scales
|
|
66
|
+
neg_rt_scales: npt.NDArray | None = None
|
|
49
67
|
int_psi, wave_xvec = pywt.integrate_wavelet(wavelet, precision=precision)
|
|
50
68
|
int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi
|
|
51
|
-
template:
|
|
52
|
-
fbgen: typing.
|
|
53
|
-
last_conv_samp:
|
|
69
|
+
template: AxisArray | None = None
|
|
70
|
+
fbgen: typing.Generator[AxisArray, AxisArray, None] | None = None
|
|
71
|
+
last_conv_samp: npt.NDArray | None = None
|
|
54
72
|
|
|
55
73
|
# Reset if input changed
|
|
56
74
|
check_input = {
|
|
@@ -76,6 +94,12 @@ def cwt(
|
|
|
76
94
|
check_input["shape"] = in_shape
|
|
77
95
|
check_input["key"] = msg_in.key
|
|
78
96
|
|
|
97
|
+
if frequencies is not None:
|
|
98
|
+
scales = pywt.frequency2scale(
|
|
99
|
+
wavelet, frequencies * msg_in.axes[axis].gain, precision=precision
|
|
100
|
+
)
|
|
101
|
+
neg_rt_scales = -np.sqrt(scales)[:, None]
|
|
102
|
+
|
|
79
103
|
# convert int_psi, wave_xvec to the same precision as the data
|
|
80
104
|
dt_data = msg_in.data.dtype # _check_dtype(msg_in.data)
|
|
81
105
|
dt_cplx = np.result_type(dt_data, np.complex64)
|
|
@@ -148,8 +172,8 @@ class CWTSettings(ez.Settings):
|
|
|
148
172
|
See :obj:`cwt` for argument details.
|
|
149
173
|
"""
|
|
150
174
|
|
|
151
|
-
scales:
|
|
152
|
-
wavelet:
|
|
175
|
+
scales: list | tuple | npt.NDArray
|
|
176
|
+
wavelet: str | pywt.ContinuousWavelet | pywt.Wavelet
|
|
153
177
|
min_phase: MinPhaseMode = MinPhaseMode.NONE
|
|
154
178
|
axis: str = "time"
|
|
155
179
|
|