ezmsg-sigproc 1.8.2__py3-none-any.whl → 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/activation.py +36 -39
- ezmsg/sigproc/adaptive_lattice_notch.py +231 -0
- ezmsg/sigproc/affinetransform.py +169 -163
- ezmsg/sigproc/aggregate.py +133 -101
- ezmsg/sigproc/bandpower.py +64 -52
- ezmsg/sigproc/base.py +1242 -0
- ezmsg/sigproc/butterworthfilter.py +37 -33
- ezmsg/sigproc/cheby.py +29 -17
- ezmsg/sigproc/combfilter.py +163 -0
- ezmsg/sigproc/decimate.py +19 -10
- ezmsg/sigproc/detrend.py +29 -0
- ezmsg/sigproc/diff.py +81 -0
- ezmsg/sigproc/downsample.py +78 -84
- ezmsg/sigproc/ewma.py +197 -0
- ezmsg/sigproc/extract_axis.py +41 -0
- ezmsg/sigproc/filter.py +257 -141
- ezmsg/sigproc/filterbank.py +247 -199
- ezmsg/sigproc/math/abs.py +17 -22
- ezmsg/sigproc/math/clip.py +24 -24
- ezmsg/sigproc/math/difference.py +34 -30
- ezmsg/sigproc/math/invert.py +13 -25
- ezmsg/sigproc/math/log.py +28 -33
- ezmsg/sigproc/math/scale.py +18 -26
- ezmsg/sigproc/quantize.py +71 -0
- ezmsg/sigproc/resample.py +298 -0
- ezmsg/sigproc/sampler.py +241 -259
- ezmsg/sigproc/scaler.py +55 -218
- ezmsg/sigproc/signalinjector.py +52 -43
- ezmsg/sigproc/slicer.py +81 -89
- ezmsg/sigproc/spectrogram.py +77 -75
- ezmsg/sigproc/spectrum.py +203 -168
- ezmsg/sigproc/synth.py +546 -393
- ezmsg/sigproc/transpose.py +131 -0
- ezmsg/sigproc/util/asio.py +156 -0
- ezmsg/sigproc/util/message.py +31 -0
- ezmsg/sigproc/util/profile.py +55 -12
- ezmsg/sigproc/util/typeresolution.py +83 -0
- ezmsg/sigproc/wavelets.py +154 -153
- ezmsg/sigproc/window.py +269 -211
- {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.1.0.dist-info}/METADATA +2 -1
- ezmsg_sigproc-2.1.0.dist-info/RECORD +51 -0
- ezmsg_sigproc-1.8.2.dist-info/RECORD +0 -39
- {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.1.0.dist-info}/WHEEL +0 -0
- {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.1.0.dist-info}/licenses/LICENSE.txt +0 -0
ezmsg/sigproc/wavelets.py
CHANGED
|
@@ -6,190 +6,191 @@ import pywt
|
|
|
6
6
|
import ezmsg.core as ez
|
|
7
7
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
8
8
|
from ezmsg.util.messages.util import replace
|
|
9
|
-
from ezmsg.util.generator import consumer
|
|
10
9
|
|
|
11
|
-
from .base import
|
|
10
|
+
from .base import (
|
|
11
|
+
BaseStatefulTransformer,
|
|
12
|
+
BaseTransformerUnit,
|
|
13
|
+
processor_state,
|
|
14
|
+
)
|
|
12
15
|
from .filterbank import filterbank, FilterbankMode, MinPhaseMode
|
|
13
16
|
|
|
14
17
|
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
min_phase: MinPhaseMode = MinPhaseMode.NONE,
|
|
20
|
-
axis: str = "time",
|
|
21
|
-
scales: list | tuple | npt.NDArray | None = None,
|
|
22
|
-
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
18
|
+
class CWTSettings(ez.Settings):
|
|
19
|
+
"""
|
|
20
|
+
Settings for :obj:`CWT`
|
|
21
|
+
See :obj:`cwt` for argument details.
|
|
23
22
|
"""
|
|
24
|
-
Perform a continuous wavelet transform.
|
|
25
|
-
The function is equivalent to the :obj:`pywt.cwt` function, but is designed to work with streaming data.
|
|
26
23
|
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
axis: The target axis for operation. Note that this will be moved to the -1th dimension
|
|
33
|
-
because fft and matrix multiplication is much faster on the last axis.
|
|
34
|
-
This axis must be in the msg.axes and it must be of type AxisArray.LinearAxis.
|
|
35
|
-
scales: The scales to use. If None, the scales will be calculated from the frequencies.
|
|
36
|
-
Note: Scales will be sorted from largest to smallest.
|
|
37
|
-
Note: Use of scales is deprecated in favor of frequencies. Convert scales to frequencies using
|
|
38
|
-
`pywt.scale2frequency(wavelet, scales, precision=10) * fs` where fs is the sampling frequency.
|
|
24
|
+
frequencies: list | tuple | npt.NDArray | None
|
|
25
|
+
wavelet: str | pywt.ContinuousWavelet | pywt.Wavelet
|
|
26
|
+
min_phase: MinPhaseMode = MinPhaseMode.NONE
|
|
27
|
+
axis: str = "time"
|
|
28
|
+
scales: list | tuple | npt.NDArray | None = None
|
|
39
29
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
"""
|
|
44
|
-
precision = 10
|
|
45
|
-
msg_out: AxisArray | None = None
|
|
46
|
-
|
|
47
|
-
# Check parameters
|
|
48
|
-
if frequencies is None and scales is None:
|
|
49
|
-
raise ValueError("Either frequencies or scales must be provided.")
|
|
50
|
-
if frequencies is not None and scales is not None:
|
|
51
|
-
raise ValueError("Only one of frequencies or scales can be provided.")
|
|
52
|
-
if scales is not None:
|
|
53
|
-
scales = np.sort(scales)[::-1]
|
|
54
|
-
assert np.all(scales > 0), "scales must be positive."
|
|
55
|
-
assert scales.ndim == 1, "scales must be a 1D list, tuple, or array."
|
|
56
|
-
|
|
57
|
-
if not isinstance(wavelet, (pywt.ContinuousWavelet, pywt.Wavelet)):
|
|
58
|
-
wavelet = pywt.DiscreteContinuousWavelet(wavelet)
|
|
59
|
-
|
|
60
|
-
if frequencies is not None:
|
|
61
|
-
frequencies = np.sort(frequencies)
|
|
62
|
-
assert np.all(frequencies > 0), "frequencies must be positive."
|
|
63
|
-
assert frequencies.ndim == 1, "frequencies must be a 1D list, tuple, or array."
|
|
64
|
-
|
|
65
|
-
# State variables
|
|
30
|
+
|
|
31
|
+
@processor_state
|
|
32
|
+
class CWTState:
|
|
66
33
|
neg_rt_scales: npt.NDArray | None = None
|
|
67
|
-
|
|
68
|
-
int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi
|
|
34
|
+
int_psi_scales: list[npt.NDArray] | None = None
|
|
69
35
|
template: AxisArray | None = None
|
|
70
36
|
fbgen: typing.Generator[AxisArray, AxisArray, None] | None = None
|
|
71
37
|
last_conv_samp: npt.NDArray | None = None
|
|
72
38
|
|
|
73
|
-
# Reset if input changed
|
|
74
|
-
check_input = {
|
|
75
|
-
"kind": None, # Need to recalc kernels at same complexity as input
|
|
76
|
-
"gain": None, # Need to recalc freqs
|
|
77
|
-
"shape": None, # Need to recalc template and buffer
|
|
78
|
-
"key": None, # Buffer obsolete
|
|
79
|
-
}
|
|
80
|
-
|
|
81
|
-
while True:
|
|
82
|
-
msg_in: AxisArray = yield msg_out
|
|
83
|
-
ax_idx = msg_in.get_axis_idx(axis)
|
|
84
|
-
in_shape = msg_in.data.shape[:ax_idx] + msg_in.data.shape[ax_idx + 1 :]
|
|
85
|
-
|
|
86
|
-
b_reset = msg_in.data.dtype.kind != check_input["kind"]
|
|
87
|
-
b_reset = b_reset or msg_in.axes[axis].gain != check_input["gain"]
|
|
88
|
-
b_reset = b_reset or in_shape != check_input["shape"]
|
|
89
|
-
b_reset = b_reset or msg_in.key != check_input["key"]
|
|
90
|
-
b_reset = b_reset and msg_in.data.size > 0
|
|
91
|
-
if b_reset:
|
|
92
|
-
check_input["kind"] = msg_in.data.dtype.kind
|
|
93
|
-
check_input["gain"] = msg_in.axes[axis].gain
|
|
94
|
-
check_input["shape"] = in_shape
|
|
95
|
-
check_input["key"] = msg_in.key
|
|
96
|
-
|
|
97
|
-
if frequencies is not None:
|
|
98
|
-
scales = pywt.frequency2scale(
|
|
99
|
-
wavelet, frequencies * msg_in.axes[axis].gain, precision=precision
|
|
100
|
-
)
|
|
101
|
-
neg_rt_scales = -np.sqrt(scales)[:, None]
|
|
102
|
-
|
|
103
|
-
# convert int_psi, wave_xvec to the same precision as the data
|
|
104
|
-
dt_data = msg_in.data.dtype # _check_dtype(msg_in.data)
|
|
105
|
-
dt_cplx = np.result_type(dt_data, np.complex64)
|
|
106
|
-
dt_psi = dt_cplx if int_psi.dtype.kind == "c" else dt_data
|
|
107
|
-
int_psi = np.asarray(int_psi, dtype=dt_psi)
|
|
108
|
-
# TODO: Currently int_psi cannot be made non-complex once it is complex.
|
|
109
|
-
|
|
110
|
-
# Calculate waves for each scale
|
|
111
|
-
wave_xvec = np.asarray(wave_xvec, dtype=msg_in.data.real.dtype)
|
|
112
|
-
wave_range = wave_xvec[-1] - wave_xvec[0]
|
|
113
|
-
step = wave_xvec[1] - wave_xvec[0]
|
|
114
|
-
int_psi_scales = []
|
|
115
|
-
for scale in scales:
|
|
116
|
-
reix = (np.arange(scale * wave_range + 1) / (scale * step)).astype(int)
|
|
117
|
-
if reix[-1] >= int_psi.size:
|
|
118
|
-
reix = np.extract(reix < int_psi.size, reix)
|
|
119
|
-
int_psi_scales.append(int_psi[reix][::-1])
|
|
120
|
-
|
|
121
|
-
# CONV is probably best because we often get huge kernels.
|
|
122
|
-
fbgen = filterbank(
|
|
123
|
-
int_psi_scales, mode=FilterbankMode.CONV, min_phase=min_phase, axis=axis
|
|
124
|
-
)
|
|
125
39
|
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
**msg_in.axes,
|
|
139
|
-
"freq": AxisArray.CoordinateAxis(
|
|
140
|
-
unit="Hz", data=freqs, dims=["freq"]
|
|
141
|
-
),
|
|
142
|
-
},
|
|
143
|
-
key=msg_in.key,
|
|
40
|
+
class CWTTransformer(
|
|
41
|
+
BaseStatefulTransformer[CWTSettings, AxisArray, AxisArray, CWTState]
|
|
42
|
+
):
|
|
43
|
+
def _hash_message(self, message: AxisArray) -> int:
|
|
44
|
+
ax_idx = message.get_axis_idx(self.settings.axis)
|
|
45
|
+
in_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
|
|
46
|
+
return hash(
|
|
47
|
+
(
|
|
48
|
+
message.data.dtype.kind,
|
|
49
|
+
message.axes[self.settings.axis].gain,
|
|
50
|
+
in_shape,
|
|
51
|
+
message.key,
|
|
144
52
|
)
|
|
145
|
-
|
|
146
|
-
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
def _reset_state(self, message: AxisArray) -> None:
|
|
56
|
+
precision = 10
|
|
57
|
+
|
|
58
|
+
# Process wavelet
|
|
59
|
+
wavelet = (
|
|
60
|
+
self.settings.wavelet
|
|
61
|
+
if isinstance(self.settings.wavelet, (pywt.ContinuousWavelet, pywt.Wavelet))
|
|
62
|
+
else pywt.DiscreteContinuousWavelet(self.settings.wavelet)
|
|
63
|
+
)
|
|
64
|
+
# Process wavelet integration
|
|
65
|
+
int_psi, wave_xvec = pywt.integrate_wavelet(wavelet, precision=precision)
|
|
66
|
+
int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi
|
|
67
|
+
|
|
68
|
+
# Calculate scales and frequencies
|
|
69
|
+
if self.settings.frequencies is not None:
|
|
70
|
+
frequencies = np.sort(np.array(self.settings.frequencies))
|
|
71
|
+
scales = pywt.frequency2scale(
|
|
72
|
+
wavelet,
|
|
73
|
+
frequencies * message.axes[self.settings.axis].gain,
|
|
74
|
+
precision=precision,
|
|
147
75
|
)
|
|
76
|
+
else:
|
|
77
|
+
scales = np.sort(self.settings.scales)[::-1]
|
|
78
|
+
|
|
79
|
+
self._state.neg_rt_scales = -np.sqrt(scales)[:, None]
|
|
80
|
+
|
|
81
|
+
# Convert to appropriate dtype
|
|
82
|
+
dt_data = message.data.dtype
|
|
83
|
+
dt_cplx = np.result_type(dt_data, np.complex64)
|
|
84
|
+
dt_psi = dt_cplx if int_psi.dtype.kind == "c" else dt_data
|
|
85
|
+
int_psi = np.asarray(int_psi, dtype=dt_psi)
|
|
86
|
+
# Note: Currently int_psi cannot be made non-complex once it is complex.
|
|
87
|
+
|
|
88
|
+
# Calculate waves for each scale
|
|
89
|
+
wave_xvec = np.asarray(wave_xvec, dtype=message.data.real.dtype)
|
|
90
|
+
wave_range = wave_xvec[-1] - wave_xvec[0]
|
|
91
|
+
step = wave_xvec[1] - wave_xvec[0]
|
|
92
|
+
self._state.int_psi_scales = []
|
|
93
|
+
for scale in scales:
|
|
94
|
+
reix = (np.arange(scale * wave_range + 1) / (scale * step)).astype(int)
|
|
95
|
+
if reix[-1] >= int_psi.size:
|
|
96
|
+
reix = np.extract(reix < int_psi.size, reix)
|
|
97
|
+
self._state.int_psi_scales.append(int_psi[reix][::-1])
|
|
98
|
+
|
|
99
|
+
# Setup filterbank generator
|
|
100
|
+
self._state.fbgen = filterbank(
|
|
101
|
+
self._state.int_psi_scales,
|
|
102
|
+
mode=FilterbankMode.CONV,
|
|
103
|
+
min_phase=self.settings.min_phase,
|
|
104
|
+
axis=self.settings.axis,
|
|
105
|
+
)
|
|
148
106
|
|
|
149
|
-
|
|
107
|
+
# Create output template
|
|
108
|
+
ax_idx = message.get_axis_idx(self.settings.axis)
|
|
109
|
+
in_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
|
|
110
|
+
freqs = (
|
|
111
|
+
pywt.scale2frequency(wavelet, scales, precision)
|
|
112
|
+
/ message.axes[self.settings.axis].gain
|
|
113
|
+
)
|
|
114
|
+
dummy_shape = in_shape + (len(scales), 0)
|
|
115
|
+
self._state.template = AxisArray(
|
|
116
|
+
np.zeros(dummy_shape, dtype=dt_cplx if wavelet.complex_cwt else dt_data),
|
|
117
|
+
dims=message.dims[:ax_idx]
|
|
118
|
+
+ message.dims[ax_idx + 1 :]
|
|
119
|
+
+ ["freq", self.settings.axis],
|
|
120
|
+
axes={
|
|
121
|
+
**message.axes,
|
|
122
|
+
"freq": AxisArray.CoordinateAxis(unit="Hz", data=freqs, dims=["freq"]),
|
|
123
|
+
},
|
|
124
|
+
key=message.key,
|
|
125
|
+
)
|
|
126
|
+
self._state.last_conv_samp = np.zeros(
|
|
127
|
+
dummy_shape[:-1] + (1,), dtype=self._state.template.data.dtype
|
|
128
|
+
)
|
|
129
|
+
|
|
130
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
131
|
+
conv_msg = self._state.fbgen.send(message)
|
|
150
132
|
|
|
151
133
|
# Prepend with last_conv_samp before doing diff
|
|
152
|
-
dat = np.concatenate((last_conv_samp, conv_msg.data), axis=-1)
|
|
153
|
-
coef = neg_rt_scales * np.diff(dat, axis=-1)
|
|
154
|
-
# Store last_conv_samp for next iteration
|
|
155
|
-
last_conv_samp = conv_msg.data[..., -1:]
|
|
134
|
+
dat = np.concatenate((self._state.last_conv_samp, conv_msg.data), axis=-1)
|
|
135
|
+
coef = self._state.neg_rt_scales * np.diff(dat, axis=-1)
|
|
136
|
+
# Store last_conv_samp for next iteration
|
|
137
|
+
self._state.last_conv_samp = conv_msg.data[..., -1:]
|
|
156
138
|
|
|
157
|
-
if template.data.dtype.kind != "c":
|
|
139
|
+
if self._state.template.data.dtype.kind != "c":
|
|
158
140
|
coef = coef.real
|
|
159
141
|
|
|
160
142
|
# pywt.cwt slices off the beginning and end of the result where the convolution overran. We don't have
|
|
161
143
|
# that luxury when streaming.
|
|
162
144
|
# d = (coef.shape[-1] - msg_in.data.shape[ax_idx]) / 2.
|
|
163
145
|
# coef = coef[..., math.floor(d):-math.ceil(d)]
|
|
164
|
-
|
|
165
|
-
template,
|
|
146
|
+
return replace(
|
|
147
|
+
self._state.template,
|
|
148
|
+
data=coef,
|
|
149
|
+
axes={
|
|
150
|
+
**self._state.template.axes,
|
|
151
|
+
self.settings.axis: message.axes[self.settings.axis],
|
|
152
|
+
},
|
|
166
153
|
)
|
|
167
154
|
|
|
168
155
|
|
|
169
|
-
class CWTSettings
|
|
170
|
-
|
|
171
|
-
Settings for :obj:`CWT`
|
|
172
|
-
See :obj:`cwt` for argument details.
|
|
173
|
-
"""
|
|
174
|
-
frequencies: list | tuple | npt.NDArray | None
|
|
175
|
-
wavelet: str | pywt.ContinuousWavelet | pywt.Wavelet
|
|
176
|
-
min_phase: MinPhaseMode = MinPhaseMode.NONE
|
|
177
|
-
axis: str = "time"
|
|
178
|
-
scales: list | tuple | npt.NDArray | None = None
|
|
156
|
+
class CWT(BaseTransformerUnit[CWTSettings, AxisArray, AxisArray, CWTTransformer]):
|
|
157
|
+
SETTINGS = CWTSettings
|
|
179
158
|
|
|
180
159
|
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
:
|
|
160
|
+
def cwt(
|
|
161
|
+
frequencies: list | tuple | npt.NDArray | None,
|
|
162
|
+
wavelet: str | pywt.ContinuousWavelet | pywt.Wavelet,
|
|
163
|
+
min_phase: MinPhaseMode = MinPhaseMode.NONE,
|
|
164
|
+
axis: str = "time",
|
|
165
|
+
scales: list | tuple | npt.NDArray | None = None,
|
|
166
|
+
) -> CWTTransformer:
|
|
184
167
|
"""
|
|
168
|
+
Perform a continuous wavelet transform.
|
|
169
|
+
The function is equivalent to the :obj:`pywt.cwt` function, but is designed to work with streaming data.
|
|
185
170
|
|
|
186
|
-
|
|
171
|
+
Args:
|
|
172
|
+
frequencies: The wavelet frequencies to use in Hz. If `None` provided then the scales will be used.
|
|
173
|
+
Note: frequencies will be sorted from smallest to largest.
|
|
174
|
+
wavelet: Wavelet object or name of wavelet to use.
|
|
175
|
+
min_phase: See filterbank MinPhaseMode for details.
|
|
176
|
+
axis: The target axis for operation. Note that this will be moved to the -1th dimension
|
|
177
|
+
because fft and matrix multiplication is much faster on the last axis.
|
|
178
|
+
This axis must be in the msg.axes and it must be of type AxisArray.LinearAxis.
|
|
179
|
+
scales: The scales to use. If None, the scales will be calculated from the frequencies.
|
|
180
|
+
Note: Scales will be sorted from largest to smallest.
|
|
181
|
+
Note: Use of scales is deprecated in favor of frequencies. Convert scales to frequencies using
|
|
182
|
+
`pywt.scale2frequency(wavelet, scales, precision=10) * fs` where fs is the sampling frequency.
|
|
187
183
|
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
184
|
+
Returns:
|
|
185
|
+
A primed Generator object that expects an :obj:`AxisArray` via `.send(axis_array)` of continuous data
|
|
186
|
+
and yields an :obj:`AxisArray` with a continuous wavelet transform in its data.
|
|
187
|
+
"""
|
|
188
|
+
return CWTTransformer(
|
|
189
|
+
CWTSettings(
|
|
190
|
+
frequencies=frequencies,
|
|
191
|
+
wavelet=wavelet,
|
|
192
|
+
min_phase=min_phase,
|
|
193
|
+
axis=axis,
|
|
194
|
+
scales=scales,
|
|
195
195
|
)
|
|
196
|
+
)
|