ezmsg-sigproc 1.7.0__py3-none-any.whl → 2.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +22 -4
- ezmsg/sigproc/activation.py +31 -40
- ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
- ezmsg/sigproc/affinetransform.py +171 -169
- ezmsg/sigproc/aggregate.py +190 -97
- ezmsg/sigproc/bandpower.py +60 -55
- ezmsg/sigproc/base.py +143 -33
- ezmsg/sigproc/butterworthfilter.py +34 -38
- ezmsg/sigproc/butterworthzerophase.py +305 -0
- ezmsg/sigproc/cheby.py +23 -17
- ezmsg/sigproc/combfilter.py +160 -0
- ezmsg/sigproc/coordinatespaces.py +159 -0
- ezmsg/sigproc/decimate.py +15 -10
- ezmsg/sigproc/denormalize.py +78 -0
- ezmsg/sigproc/detrend.py +28 -0
- ezmsg/sigproc/diff.py +82 -0
- ezmsg/sigproc/downsample.py +72 -81
- ezmsg/sigproc/ewma.py +217 -0
- ezmsg/sigproc/ewmfilter.py +1 -1
- ezmsg/sigproc/extract_axis.py +39 -0
- ezmsg/sigproc/fbcca.py +307 -0
- ezmsg/sigproc/filter.py +254 -148
- ezmsg/sigproc/filterbank.py +226 -214
- ezmsg/sigproc/filterbankdesign.py +129 -0
- ezmsg/sigproc/fir_hilbert.py +336 -0
- ezmsg/sigproc/fir_pmc.py +209 -0
- ezmsg/sigproc/firfilter.py +117 -0
- ezmsg/sigproc/gaussiansmoothing.py +89 -0
- ezmsg/sigproc/kaiser.py +106 -0
- ezmsg/sigproc/linear.py +120 -0
- ezmsg/sigproc/math/abs.py +23 -22
- ezmsg/sigproc/math/add.py +120 -0
- ezmsg/sigproc/math/clip.py +33 -25
- ezmsg/sigproc/math/difference.py +117 -43
- ezmsg/sigproc/math/invert.py +18 -25
- ezmsg/sigproc/math/log.py +38 -33
- ezmsg/sigproc/math/scale.py +24 -25
- ezmsg/sigproc/messages.py +1 -2
- ezmsg/sigproc/quantize.py +68 -0
- ezmsg/sigproc/resample.py +278 -0
- ezmsg/sigproc/rollingscaler.py +232 -0
- ezmsg/sigproc/sampler.py +209 -254
- ezmsg/sigproc/scaler.py +93 -218
- ezmsg/sigproc/signalinjector.py +44 -43
- ezmsg/sigproc/slicer.py +74 -102
- ezmsg/sigproc/spectral.py +3 -3
- ezmsg/sigproc/spectrogram.py +70 -70
- ezmsg/sigproc/spectrum.py +187 -173
- ezmsg/sigproc/transpose.py +134 -0
- ezmsg/sigproc/util/__init__.py +0 -0
- ezmsg/sigproc/util/asio.py +25 -0
- ezmsg/sigproc/util/axisarray_buffer.py +365 -0
- ezmsg/sigproc/util/buffer.py +449 -0
- ezmsg/sigproc/util/message.py +17 -0
- ezmsg/sigproc/util/profile.py +23 -0
- ezmsg/sigproc/util/sparse.py +115 -0
- ezmsg/sigproc/util/typeresolution.py +17 -0
- ezmsg/sigproc/wavelets.py +147 -154
- ezmsg/sigproc/window.py +248 -210
- ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
- {ezmsg_sigproc-1.7.0.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -1
- ezmsg/sigproc/synth.py +0 -621
- ezmsg_sigproc-1.7.0.dist-info/METADATA +0 -58
- ezmsg_sigproc-1.7.0.dist-info/RECORD +0 -36
- /ezmsg_sigproc-1.7.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/wavelets.py
CHANGED
|
@@ -1,194 +1,187 @@
|
|
|
1
1
|
import typing
|
|
2
2
|
|
|
3
|
+
import ezmsg.core as ez
|
|
3
4
|
import numpy as np
|
|
4
5
|
import numpy.typing as npt
|
|
5
6
|
import pywt
|
|
6
|
-
|
|
7
|
+
from ezmsg.baseproc import (
|
|
8
|
+
BaseStatefulTransformer,
|
|
9
|
+
BaseTransformerUnit,
|
|
10
|
+
processor_state,
|
|
11
|
+
)
|
|
7
12
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
8
13
|
from ezmsg.util.messages.util import replace
|
|
9
|
-
from ezmsg.util.generator import consumer
|
|
10
14
|
|
|
11
|
-
from .
|
|
12
|
-
from .filterbank import filterbank, FilterbankMode, MinPhaseMode
|
|
15
|
+
from .filterbank import FilterbankMode, MinPhaseMode, filterbank
|
|
13
16
|
|
|
14
17
|
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
min_phase: MinPhaseMode = MinPhaseMode.NONE,
|
|
20
|
-
axis: str = "time",
|
|
21
|
-
scales: list | tuple | npt.NDArray | None = None,
|
|
22
|
-
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
18
|
+
class CWTSettings(ez.Settings):
|
|
19
|
+
"""
|
|
20
|
+
Settings for :obj:`CWT`
|
|
21
|
+
See :obj:`cwt` for argument details.
|
|
23
22
|
"""
|
|
24
|
-
Perform a continuous wavelet transform.
|
|
25
|
-
The function is equivalent to the :obj:`pywt.cwt` function, but is designed to work with streaming data.
|
|
26
23
|
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
axis: The target axis for operation. Note that this will be moved to the -1th dimension
|
|
33
|
-
because fft and matrix multiplication is much faster on the last axis.
|
|
34
|
-
This axis must be in the msg.axes and it must be of type AxisArray.LinearAxis.
|
|
35
|
-
scales: The scales to use. If None, the scales will be calculated from the frequencies.
|
|
36
|
-
Note: Scales will be sorted from largest to smallest.
|
|
37
|
-
Note: Use of scales is deprecated in favor of frequencies. Convert scales to frequencies using
|
|
38
|
-
`pywt.scale2frequency(wavelet, scales, precision=10) * fs` where fs is the sampling frequency.
|
|
24
|
+
frequencies: list | tuple | npt.NDArray | None
|
|
25
|
+
wavelet: str | pywt.ContinuousWavelet | pywt.Wavelet
|
|
26
|
+
min_phase: MinPhaseMode = MinPhaseMode.NONE
|
|
27
|
+
axis: str = "time"
|
|
28
|
+
scales: list | tuple | npt.NDArray | None = None
|
|
39
29
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
"""
|
|
44
|
-
precision = 10
|
|
45
|
-
msg_out: AxisArray | None = None
|
|
46
|
-
|
|
47
|
-
# Check parameters
|
|
48
|
-
if frequencies is None and scales is None:
|
|
49
|
-
raise ValueError("Either frequencies or scales must be provided.")
|
|
50
|
-
if frequencies is not None and scales is not None:
|
|
51
|
-
raise ValueError("Only one of frequencies or scales can be provided.")
|
|
52
|
-
if scales is not None:
|
|
53
|
-
scales = np.sort(scales)[::-1]
|
|
54
|
-
assert np.all(scales > 0), "scales must be positive."
|
|
55
|
-
assert scales.ndim == 1, "scales must be a 1D list, tuple, or array."
|
|
56
|
-
|
|
57
|
-
if not isinstance(wavelet, (pywt.ContinuousWavelet, pywt.Wavelet)):
|
|
58
|
-
wavelet = pywt.DiscreteContinuousWavelet(wavelet)
|
|
59
|
-
|
|
60
|
-
if frequencies is not None:
|
|
61
|
-
frequencies = np.sort(frequencies)
|
|
62
|
-
assert np.all(frequencies > 0), "frequencies must be positive."
|
|
63
|
-
assert frequencies.ndim == 1, "frequencies must be a 1D list, tuple, or array."
|
|
64
|
-
|
|
65
|
-
# State variables
|
|
30
|
+
|
|
31
|
+
@processor_state
|
|
32
|
+
class CWTState:
|
|
66
33
|
neg_rt_scales: npt.NDArray | None = None
|
|
67
|
-
|
|
68
|
-
int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi
|
|
34
|
+
int_psi_scales: list[npt.NDArray] | None = None
|
|
69
35
|
template: AxisArray | None = None
|
|
70
36
|
fbgen: typing.Generator[AxisArray, AxisArray, None] | None = None
|
|
71
37
|
last_conv_samp: npt.NDArray | None = None
|
|
72
38
|
|
|
73
|
-
# Reset if input changed
|
|
74
|
-
check_input = {
|
|
75
|
-
"kind": None, # Need to recalc kernels at same complexity as input
|
|
76
|
-
"gain": None, # Need to recalc freqs
|
|
77
|
-
"shape": None, # Need to recalc template and buffer
|
|
78
|
-
"key": None, # Buffer obsolete
|
|
79
|
-
}
|
|
80
|
-
|
|
81
|
-
while True:
|
|
82
|
-
msg_in: AxisArray = yield msg_out
|
|
83
|
-
ax_idx = msg_in.get_axis_idx(axis)
|
|
84
|
-
in_shape = msg_in.data.shape[:ax_idx] + msg_in.data.shape[ax_idx + 1 :]
|
|
85
|
-
|
|
86
|
-
b_reset = msg_in.data.dtype.kind != check_input["kind"]
|
|
87
|
-
b_reset = b_reset or msg_in.axes[axis].gain != check_input["gain"]
|
|
88
|
-
b_reset = b_reset or in_shape != check_input["shape"]
|
|
89
|
-
b_reset = b_reset or msg_in.key != check_input["key"]
|
|
90
|
-
b_reset = b_reset and msg_in.data.size > 0
|
|
91
|
-
if b_reset:
|
|
92
|
-
check_input["kind"] = msg_in.data.dtype.kind
|
|
93
|
-
check_input["gain"] = msg_in.axes[axis].gain
|
|
94
|
-
check_input["shape"] = in_shape
|
|
95
|
-
check_input["key"] = msg_in.key
|
|
96
|
-
|
|
97
|
-
if frequencies is not None:
|
|
98
|
-
scales = pywt.frequency2scale(
|
|
99
|
-
wavelet, frequencies * msg_in.axes[axis].gain, precision=precision
|
|
100
|
-
)
|
|
101
|
-
neg_rt_scales = -np.sqrt(scales)[:, None]
|
|
102
|
-
|
|
103
|
-
# convert int_psi, wave_xvec to the same precision as the data
|
|
104
|
-
dt_data = msg_in.data.dtype # _check_dtype(msg_in.data)
|
|
105
|
-
dt_cplx = np.result_type(dt_data, np.complex64)
|
|
106
|
-
dt_psi = dt_cplx if int_psi.dtype.kind == "c" else dt_data
|
|
107
|
-
int_psi = np.asarray(int_psi, dtype=dt_psi)
|
|
108
|
-
# TODO: Currently int_psi cannot be made non-complex once it is complex.
|
|
109
|
-
|
|
110
|
-
# Calculate waves for each scale
|
|
111
|
-
wave_xvec = np.asarray(wave_xvec, dtype=msg_in.data.real.dtype)
|
|
112
|
-
wave_range = wave_xvec[-1] - wave_xvec[0]
|
|
113
|
-
step = wave_xvec[1] - wave_xvec[0]
|
|
114
|
-
int_psi_scales = []
|
|
115
|
-
for scale in scales:
|
|
116
|
-
reix = (np.arange(scale * wave_range + 1) / (scale * step)).astype(int)
|
|
117
|
-
if reix[-1] >= int_psi.size:
|
|
118
|
-
reix = np.extract(reix < int_psi.size, reix)
|
|
119
|
-
int_psi_scales.append(int_psi[reix][::-1])
|
|
120
|
-
|
|
121
|
-
# CONV is probably best because we often get huge kernels.
|
|
122
|
-
fbgen = filterbank(
|
|
123
|
-
int_psi_scales, mode=FilterbankMode.CONV, min_phase=min_phase, axis=axis
|
|
124
|
-
)
|
|
125
39
|
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
40
|
+
class CWTTransformer(BaseStatefulTransformer[CWTSettings, AxisArray, AxisArray, CWTState]):
|
|
41
|
+
def _hash_message(self, message: AxisArray) -> int:
|
|
42
|
+
ax_idx = message.get_axis_idx(self.settings.axis)
|
|
43
|
+
in_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
|
|
44
|
+
return hash(
|
|
45
|
+
(
|
|
46
|
+
message.data.dtype.kind,
|
|
47
|
+
message.axes[self.settings.axis].gain,
|
|
48
|
+
in_shape,
|
|
49
|
+
message.key,
|
|
129
50
|
)
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
def _reset_state(self, message: AxisArray) -> None:
|
|
54
|
+
precision = 10
|
|
55
|
+
|
|
56
|
+
# Process wavelet
|
|
57
|
+
wavelet = (
|
|
58
|
+
self.settings.wavelet
|
|
59
|
+
if isinstance(self.settings.wavelet, (pywt.ContinuousWavelet, pywt.Wavelet))
|
|
60
|
+
else pywt.DiscreteContinuousWavelet(self.settings.wavelet)
|
|
61
|
+
)
|
|
62
|
+
# Process wavelet integration
|
|
63
|
+
int_psi, wave_xvec = pywt.integrate_wavelet(wavelet, precision=precision)
|
|
64
|
+
int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi
|
|
65
|
+
|
|
66
|
+
# Calculate scales and frequencies
|
|
67
|
+
if self.settings.frequencies is not None:
|
|
68
|
+
frequencies = np.sort(np.array(self.settings.frequencies))
|
|
69
|
+
scales = pywt.frequency2scale(
|
|
70
|
+
wavelet,
|
|
71
|
+
frequencies * message.axes[self.settings.axis].gain,
|
|
72
|
+
precision=precision,
|
|
147
73
|
)
|
|
74
|
+
else:
|
|
75
|
+
scales = np.sort(self.settings.scales)[::-1]
|
|
76
|
+
|
|
77
|
+
self._state.neg_rt_scales = -np.sqrt(scales)[:, None]
|
|
78
|
+
|
|
79
|
+
# Convert to appropriate dtype
|
|
80
|
+
dt_data = message.data.dtype
|
|
81
|
+
dt_cplx = np.result_type(dt_data, np.complex64)
|
|
82
|
+
dt_psi = dt_cplx if int_psi.dtype.kind == "c" else dt_data
|
|
83
|
+
int_psi = np.asarray(int_psi, dtype=dt_psi)
|
|
84
|
+
# Note: Currently int_psi cannot be made non-complex once it is complex.
|
|
85
|
+
|
|
86
|
+
# Calculate waves for each scale
|
|
87
|
+
wave_xvec = np.asarray(wave_xvec, dtype=message.data.real.dtype)
|
|
88
|
+
wave_range = wave_xvec[-1] - wave_xvec[0]
|
|
89
|
+
step = wave_xvec[1] - wave_xvec[0]
|
|
90
|
+
self._state.int_psi_scales = []
|
|
91
|
+
for scale in scales:
|
|
92
|
+
reix = (np.arange(scale * wave_range + 1) / (scale * step)).astype(int)
|
|
93
|
+
if reix[-1] >= int_psi.size:
|
|
94
|
+
reix = np.extract(reix < int_psi.size, reix)
|
|
95
|
+
self._state.int_psi_scales.append(int_psi[reix][::-1])
|
|
96
|
+
|
|
97
|
+
# Setup filterbank generator
|
|
98
|
+
self._state.fbgen = filterbank(
|
|
99
|
+
self._state.int_psi_scales,
|
|
100
|
+
mode=FilterbankMode.CONV,
|
|
101
|
+
min_phase=self.settings.min_phase,
|
|
102
|
+
axis=self.settings.axis,
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
# Create output template
|
|
106
|
+
ax_idx = message.get_axis_idx(self.settings.axis)
|
|
107
|
+
in_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
|
|
108
|
+
freqs = pywt.scale2frequency(wavelet, scales, precision) / message.axes[self.settings.axis].gain
|
|
109
|
+
dummy_shape = in_shape + (len(scales), 0)
|
|
110
|
+
self._state.template = AxisArray(
|
|
111
|
+
np.zeros(dummy_shape, dtype=dt_cplx if wavelet.complex_cwt else dt_data),
|
|
112
|
+
dims=message.dims[:ax_idx] + message.dims[ax_idx + 1 :] + ["freq", self.settings.axis],
|
|
113
|
+
axes={
|
|
114
|
+
**message.axes,
|
|
115
|
+
"freq": AxisArray.CoordinateAxis(unit="Hz", data=freqs, dims=["freq"]),
|
|
116
|
+
},
|
|
117
|
+
key=message.key,
|
|
118
|
+
)
|
|
119
|
+
self._state.last_conv_samp = np.zeros(dummy_shape[:-1] + (1,), dtype=self._state.template.data.dtype)
|
|
148
120
|
|
|
149
|
-
|
|
121
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
122
|
+
conv_msg = self._state.fbgen.send(message)
|
|
150
123
|
|
|
151
124
|
# Prepend with last_conv_samp before doing diff
|
|
152
|
-
dat = np.concatenate((last_conv_samp, conv_msg.data), axis=-1)
|
|
153
|
-
coef = neg_rt_scales * np.diff(dat, axis=-1)
|
|
154
|
-
# Store last_conv_samp for next iteration
|
|
155
|
-
last_conv_samp = conv_msg.data[..., -1:]
|
|
125
|
+
dat = np.concatenate((self._state.last_conv_samp, conv_msg.data), axis=-1)
|
|
126
|
+
coef = self._state.neg_rt_scales * np.diff(dat, axis=-1)
|
|
127
|
+
# Store last_conv_samp for next iteration
|
|
128
|
+
self._state.last_conv_samp = conv_msg.data[..., -1:]
|
|
156
129
|
|
|
157
|
-
if template.data.dtype.kind != "c":
|
|
130
|
+
if self._state.template.data.dtype.kind != "c":
|
|
158
131
|
coef = coef.real
|
|
159
132
|
|
|
160
133
|
# pywt.cwt slices off the beginning and end of the result where the convolution overran. We don't have
|
|
161
134
|
# that luxury when streaming.
|
|
162
135
|
# d = (coef.shape[-1] - msg_in.data.shape[ax_idx]) / 2.
|
|
163
136
|
# coef = coef[..., math.floor(d):-math.ceil(d)]
|
|
164
|
-
|
|
165
|
-
template,
|
|
137
|
+
return replace(
|
|
138
|
+
self._state.template,
|
|
139
|
+
data=coef,
|
|
140
|
+
axes={
|
|
141
|
+
**self._state.template.axes,
|
|
142
|
+
self.settings.axis: message.axes[self.settings.axis],
|
|
143
|
+
},
|
|
166
144
|
)
|
|
167
145
|
|
|
168
146
|
|
|
169
|
-
class CWTSettings
|
|
170
|
-
|
|
171
|
-
Settings for :obj:`CWT`
|
|
172
|
-
See :obj:`cwt` for argument details.
|
|
173
|
-
"""
|
|
174
|
-
|
|
175
|
-
scales: list | tuple | npt.NDArray
|
|
176
|
-
wavelet: str | pywt.ContinuousWavelet | pywt.Wavelet
|
|
177
|
-
min_phase: MinPhaseMode = MinPhaseMode.NONE
|
|
178
|
-
axis: str = "time"
|
|
147
|
+
class CWT(BaseTransformerUnit[CWTSettings, AxisArray, AxisArray, CWTTransformer]):
|
|
148
|
+
SETTINGS = CWTSettings
|
|
179
149
|
|
|
180
150
|
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
:
|
|
151
|
+
def cwt(
|
|
152
|
+
frequencies: list | tuple | npt.NDArray | None,
|
|
153
|
+
wavelet: str | pywt.ContinuousWavelet | pywt.Wavelet,
|
|
154
|
+
min_phase: MinPhaseMode = MinPhaseMode.NONE,
|
|
155
|
+
axis: str = "time",
|
|
156
|
+
scales: list | tuple | npt.NDArray | None = None,
|
|
157
|
+
) -> CWTTransformer:
|
|
184
158
|
"""
|
|
159
|
+
Perform a continuous wavelet transform.
|
|
160
|
+
The function is equivalent to the :obj:`pywt.cwt` function, but is designed to work with streaming data.
|
|
185
161
|
|
|
186
|
-
|
|
162
|
+
Args:
|
|
163
|
+
frequencies: The wavelet frequencies to use in Hz. If `None` provided then the scales will be used.
|
|
164
|
+
Note: frequencies will be sorted from smallest to largest.
|
|
165
|
+
wavelet: Wavelet object or name of wavelet to use.
|
|
166
|
+
min_phase: See filterbank MinPhaseMode for details.
|
|
167
|
+
axis: The target axis for operation. Note that this will be moved to the -1th dimension
|
|
168
|
+
because fft and matrix multiplication is much faster on the last axis.
|
|
169
|
+
This axis must be in the msg.axes and it must be of type AxisArray.LinearAxis.
|
|
170
|
+
scales: The scales to use. If None, the scales will be calculated from the frequencies.
|
|
171
|
+
Note: Scales will be sorted from largest to smallest.
|
|
172
|
+
Note: Use of scales is deprecated in favor of frequencies. Convert scales to frequencies using
|
|
173
|
+
`pywt.scale2frequency(wavelet, scales, precision=10) * fs` where fs is the sampling frequency.
|
|
187
174
|
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
175
|
+
Returns:
|
|
176
|
+
A primed Generator object that expects an :obj:`AxisArray` via `.send(axis_array)` of continuous data
|
|
177
|
+
and yields an :obj:`AxisArray` with a continuous wavelet transform in its data.
|
|
178
|
+
"""
|
|
179
|
+
return CWTTransformer(
|
|
180
|
+
CWTSettings(
|
|
181
|
+
frequencies=frequencies,
|
|
182
|
+
wavelet=wavelet,
|
|
183
|
+
min_phase=min_phase,
|
|
184
|
+
axis=axis,
|
|
185
|
+
scales=scales,
|
|
194
186
|
)
|
|
187
|
+
)
|