ezmsg-sigproc 1.2.2__py3-none-any.whl → 1.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__init__.py +1 -1
- ezmsg/sigproc/__version__.py +16 -1
- ezmsg/sigproc/activation.py +75 -0
- ezmsg/sigproc/affinetransform.py +234 -0
- ezmsg/sigproc/aggregate.py +158 -0
- ezmsg/sigproc/bandpower.py +74 -0
- ezmsg/sigproc/base.py +38 -0
- ezmsg/sigproc/butterworthfilter.py +102 -11
- ezmsg/sigproc/decimate.py +7 -4
- ezmsg/sigproc/downsample.py +95 -51
- ezmsg/sigproc/ewmfilter.py +38 -16
- ezmsg/sigproc/filter.py +108 -20
- ezmsg/sigproc/filterbank.py +278 -0
- ezmsg/sigproc/math/__init__.py +0 -0
- ezmsg/sigproc/math/abs.py +28 -0
- ezmsg/sigproc/math/clip.py +30 -0
- ezmsg/sigproc/math/difference.py +60 -0
- ezmsg/sigproc/math/invert.py +29 -0
- ezmsg/sigproc/math/log.py +32 -0
- ezmsg/sigproc/math/scale.py +31 -0
- ezmsg/sigproc/messages.py +2 -3
- ezmsg/sigproc/sampler.py +259 -224
- ezmsg/sigproc/scaler.py +173 -0
- ezmsg/sigproc/signalinjector.py +64 -0
- ezmsg/sigproc/slicer.py +133 -0
- ezmsg/sigproc/spectral.py +6 -132
- ezmsg/sigproc/spectrogram.py +86 -0
- ezmsg/sigproc/spectrum.py +259 -0
- ezmsg/sigproc/synth.py +299 -105
- ezmsg/sigproc/wavelets.py +167 -0
- ezmsg/sigproc/window.py +254 -116
- ezmsg_sigproc-1.3.1.dist-info/METADATA +59 -0
- ezmsg_sigproc-1.3.1.dist-info/RECORD +35 -0
- {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-1.3.1.dist-info}/WHEEL +1 -2
- ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
- ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
- ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
- {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-1.3.1.dist-info/licenses}/LICENSE.txt +0 -0
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
from dataclasses import replace
|
|
2
|
+
import typing
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import numpy.typing as npt
|
|
6
|
+
import pywt
|
|
7
|
+
import ezmsg.core as ez
|
|
8
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
9
|
+
from ezmsg.util.generator import consumer
|
|
10
|
+
|
|
11
|
+
from .base import GenAxisArray
|
|
12
|
+
from .filterbank import filterbank, FilterbankMode, MinPhaseMode
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _cwt_working_dtype(dtype: npt.DTypeLike) -> np.dtype:
    """
    Map an input data dtype to the dtype used for the CWT computation.

    Mirrors the promotion applied by `pywt.cwt` before transforming:
    float32/float64/complex64/complex128 are kept as-is, float16 is promoted to
    float32, clongdouble is reduced to complex128, and anything else (integers,
    bools, longdouble, ...) is computed in float64.
    """
    dt = np.dtype(dtype)
    if dt in (np.float64, np.float32, np.complex64, np.complex128):
        return dt
    if dt == np.float16:
        return np.dtype(np.float32)
    if dt == np.clongdouble:
        return np.dtype(np.complex128)
    return np.dtype(np.float64)


@consumer
def cwt(
    scales: typing.Union[list, tuple, npt.NDArray],
    wavelet: typing.Union[str, pywt.ContinuousWavelet, pywt.Wavelet],
    min_phase: MinPhaseMode = MinPhaseMode.NONE,
    axis: str = "time",
) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Build a generator to perform a continuous wavelet transform on sent AxisArray messages.
    The function is equivalent to the `pywt.cwt` function, but is designed to work with streaming data.

    Args:
        scales: The wavelet scales to use. Must be a 1D list, tuple, or array of positive values.
        wavelet: Wavelet object or name of wavelet to use.
        min_phase: See filterbank MinPhaseMode for details.
        axis: The target axis for operation. Note that this will be moved to the -1th dimension
            because fft and matrix multiplication is much faster on the last axis.

    Returns:
        A Generator object that expects `.send(axis_array)` of continuous data. Each output has
        the input's dims with `axis` removed, followed by a new "freq" dim (one entry per scale)
        and then `axis`. Integer (and other non-float) input is processed in float64, as in `pywt.cwt`.
    """
    msg_out: typing.Optional[AxisArray] = None

    # Check parameters
    scales = np.array(scales)
    assert np.all(scales > 0), "Scales must be positive."
    assert scales.ndim == 1, "Scales must be a 1D list, tuple, or array."
    if not isinstance(wavelet, (pywt.ContinuousWavelet, pywt.Wavelet)):
        wavelet = pywt.DiscreteContinuousWavelet(wavelet)
    precision = 10

    # State variables
    neg_rt_scales = -np.sqrt(scales)[:, None]
    # Keep the full-precision integrated wavelet. Per-dtype copies are derived from these on
    # every reset so a low-precision stream cannot permanently degrade the kernels.
    base_int_psi, base_wave_xvec = pywt.integrate_wavelet(wavelet, precision=precision)
    base_int_psi = np.conj(base_int_psi) if wavelet.complex_cwt else base_int_psi
    template: typing.Optional[AxisArray] = None
    fbgen: typing.Optional[typing.Generator[AxisArray, AxisArray, None]] = None
    last_conv_samp: typing.Optional[npt.NDArray] = None

    # Reset if input changed
    check_input = {
        "kind": None,  # Need to recalc kernels at same complexity as input
        "gain": None,  # Need to recalc freqs
        "shape": None,  # Need to recalc template and buffer
        "key": None,  # Buffer obsolete
    }

    while True:
        msg_in: AxisArray = yield msg_out
        ax_idx = msg_in.get_axis_idx(axis)
        in_shape = msg_in.data.shape[:ax_idx] + msg_in.data.shape[ax_idx + 1 :]

        b_reset = msg_in.data.dtype.kind != check_input["kind"]
        b_reset = b_reset or msg_in.axes[axis].gain != check_input["gain"]
        b_reset = b_reset or in_shape != check_input["shape"]
        b_reset = b_reset or msg_in.key != check_input["key"]
        # Empty messages never trigger a rebuild.
        b_reset = b_reset and msg_in.data.size > 0
        if b_reset:
            check_input["kind"] = msg_in.data.dtype.kind
            check_input["gain"] = msg_in.axes[axis].gain
            check_input["shape"] = in_shape
            check_input["key"] = msg_in.key

            # Working precision follows pywt.cwt: integer/other input is computed in float64.
            dt_data = _cwt_working_dtype(msg_in.data.dtype)
            dt_cplx = np.result_type(dt_data, np.complex64)
            dt_psi = dt_cplx if base_int_psi.dtype.kind == "c" else dt_data
            int_psi = np.asarray(base_int_psi, dtype=dt_psi)
            # np.finfo of a complex dtype describes its real counterpart (e.g. complex64 -> float32).
            wave_xvec = np.asarray(base_wave_xvec, dtype=np.finfo(dt_data).dtype)

            # Calculate waves for each scale (same resampling as pywt.cwt).
            wave_range = wave_xvec[-1] - wave_xvec[0]
            step = wave_xvec[1] - wave_xvec[0]
            int_psi_scales = []
            for scale in scales:
                reix = (np.arange(scale * wave_range + 1) / (scale * step)).astype(int)
                if reix[-1] >= int_psi.size:
                    reix = np.extract(reix < int_psi.size, reix)
                int_psi_scales.append(int_psi[reix][::-1])

            # CONV is probably best because we often get huge kernels.
            fbgen = filterbank(
                int_psi_scales, mode=FilterbankMode.CONV, min_phase=min_phase, axis=axis
            )

            freqs = (
                pywt.scale2frequency(wavelet, scales, precision)
                / msg_in.axes[axis].gain
            )
            fstep = (freqs[1] - freqs[0]) if len(freqs) > 1 else 1.0
            # Create output template
            dummy_shape = in_shape + (len(scales), 0)
            template = AxisArray(
                np.zeros(
                    dummy_shape, dtype=dt_cplx if wavelet.complex_cwt else dt_data
                ),
                dims=msg_in.dims[:ax_idx] + msg_in.dims[ax_idx + 1 :] + ["freq", axis],
                axes={
                    **msg_in.axes,
                    "freq": AxisArray.Axis("Hz", offset=freqs[0], gain=fstep),
                },
            )
            # Last convolution sample of the previous chunk, so the diff below is continuous.
            last_conv_samp = np.zeros(
                dummy_shape[:-1] + (1,), dtype=template.data.dtype
            )

        conv_msg = fbgen.send(msg_in)

        # Prepend with last_conv_samp before doing diff
        dat = np.concatenate((last_conv_samp, conv_msg.data), axis=-1)
        coef = neg_rt_scales * np.diff(dat, axis=-1)
        # Store last_conv_samp for next iteration.
        last_conv_samp = conv_msg.data[..., -1:]

        if template.data.dtype.kind != "c":
            coef = coef.real

        # pywt.cwt slices off the beginning and end of the result where the convolution
        # overran. We don't have that luxury when streaming, so no trimming is done here.
        msg_out = replace(
            template, data=coef, axes={**template.axes, axis: msg_in.axes[axis]}
        )
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
class CWTSettings(ez.Settings):
    """
    Configuration for the :obj:`CWT` unit.

    Every field is passed straight through to :obj:`cwt`; see that function
    for the meaning of each argument.
    """

    # Wavelet scales (1D list, tuple, or array of positive values).
    scales: typing.Union[list, tuple, npt.NDArray]
    # A pywt wavelet object, or the name of a wavelet.
    wavelet: typing.Union[str, pywt.ContinuousWavelet, pywt.Wavelet]
    # Minimum-phase option forwarded to the underlying filterbank.
    min_phase: MinPhaseMode = MinPhaseMode.NONE
    # Name of the axis to transform.
    axis: str = "time"
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class CWT(GenAxisArray):
    """
    :obj:`Unit` for :obj:`cwt`.

    Applies a streaming continuous wavelet transform to incoming :obj:`AxisArray`
    messages, configured by :obj:`CWTSettings`.
    """

    SETTINGS = CWTSettings

    def construct_generator(self):
        # Build the cwt generator from the unit's settings.
        self.STATE.gen = cwt(
            scales=self.SETTINGS.scales,
            wavelet=self.SETTINGS.wavelet,
            min_phase=self.SETTINGS.min_phase,
            axis=self.SETTINGS.axis,
        )
|
ezmsg/sigproc/window.py
CHANGED
|
@@ -1,144 +1,282 @@
|
|
|
1
1
|
from dataclasses import replace
|
|
2
|
+
import traceback
|
|
3
|
+
import typing
|
|
2
4
|
|
|
3
5
|
import ezmsg.core as ez
|
|
4
6
|
import numpy as np
|
|
5
7
|
import numpy.typing as npt
|
|
8
|
+
from ezmsg.util.messages.axisarray import (
|
|
9
|
+
AxisArray,
|
|
10
|
+
slice_along_axis,
|
|
11
|
+
sliding_win_oneaxis,
|
|
12
|
+
)
|
|
13
|
+
from ezmsg.util.generator import consumer
|
|
6
14
|
|
|
7
|
-
from
|
|
15
|
+
from .base import GenAxisArray
|
|
8
16
|
|
|
9
|
-
from typing import AsyncGenerator, Optional, Tuple, List
|
|
10
17
|
|
|
18
|
+
@consumer
def windowing(
    axis: typing.Optional[str] = None,
    newaxis: str = "win",
    window_dur: typing.Optional[float] = None,
    window_shift: typing.Optional[float] = None,
    zero_pad_until: str = "input",
) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Construct a generator that yields windows of data from an input :obj:`AxisArray`.

    Args:
        axis: The axis along which to segment windows.
            If None, defaults to the first dimension of the first seen AxisArray.
        newaxis: New axis on which windows are delimited, immediately
            preceding the target windowed axis. The data length along newaxis may be 0 if
            this most recent push did not provide enough data for a new window.
            If window_shift is None then the newaxis length will always be 1.
            If None is passed, a warning is logged and "win" is used.
        window_dur: The duration of the window in seconds.
            If None, the function acts as a passthrough and all other parameters are ignored.
        window_shift: The shift of the window in seconds.
            If None (default), windowing operates in "1:1 mode", where each input yields exactly one most-recent window.
        zero_pad_until: Determines how the function initializes the buffer.
            Can be one of "input" (default), "full", "shift", or "none". If `window_shift` is None then this field is
            ignored and "input" is always used.

            - "input" (default) initializes the buffer with the input then prepends with zeros to the window size.
              The first input will always yield at least one output.
            - "shift" fills the buffer until `window_shift`.
              No outputs will be yielded until at least `window_shift` data has been seen.
            - "none" does not pad the buffer. No outputs will be yielded until at least `window_dur` data has been seen.
            - Any other value (including "full") is currently handled the same way as "input".

    Returns:
        A (primed) generator that accepts .send(an AxisArray object) and yields a single windowed
        AxisArray in which `newaxis` is inserted immediately before the windowed axis. The windowed
        axis keeps its gain but has offset 0; window start times are carried by the `newaxis`
        axis (offset = first window start, gain = shift duration, or 0 in 1:1 mode).
        If the input already has a dim named `newaxis` (and `window_shift` is not None),
        `f"{newaxis}_win"` is used instead.
    """
    # Check arguments
    if newaxis is None:
        ez.logger.warning("`newaxis` must not be None. Setting to 'win'.")
        newaxis = "win"
    if window_shift is None and zero_pad_until != "input":
        ez.logger.warning(
            "`zero_pad_until` must be 'input' if `window_shift` is None. "
            f"Ignoring received argument value: {zero_pad_until}"
        )
        zero_pad_until = "input"
    elif window_shift is not None and zero_pad_until == "input":
        ez.logger.warning(
            "windowing is non-deterministic with `zero_pad_until='input'` as it depends on the size "
            "of the first input. We recommend using 'shift' when `window_shift` is float-valued."
        )
    msg_out = AxisArray(np.array([]), dims=[""])

    # State variables
    buffer: typing.Optional[npt.NDArray] = None
    window_samples: typing.Optional[int] = None
    window_shift_samples: typing.Optional[int] = None
    # Number of incoming samples to ignore. Only relevant when shift > window.
    shift_deficit: int = 0
    b_1to1 = window_shift is None
    newaxis_warned: bool = b_1to1
    out_newaxis: typing.Optional[AxisArray.Axis] = None
    out_dims: typing.Optional[typing.List[str]] = None

    check_inputs = {"samp_shape": None, "fs": None, "key": None}

    while True:
        msg_in: AxisArray = yield msg_out

        if window_dur is None:
            msg_out = msg_in
            continue

        axis = axis or msg_in.dims[0]
        axis_idx = msg_in.get_axis_idx(axis)
        axis_info = msg_in.get_axis(axis)
        fs = 1.0 / axis_info.gain

        if not newaxis_warned and newaxis in msg_in.dims:
            ez.logger.warning(
                f"newaxis {newaxis} present in input dims. Using {newaxis}_win instead"
            )
            newaxis_warned = True
            newaxis = f"{newaxis}_win"

        n_new = msg_in.data.shape[axis_idx]
        samp_shape = msg_in.data.shape[:axis_idx] + msg_in.data.shape[axis_idx + 1 :]

        # If buffer unset or input stats changed, create a new buffer
        b_reset = buffer is None
        b_reset = b_reset or samp_shape != check_inputs["samp_shape"]
        b_reset = b_reset or fs != check_inputs["fs"]
        b_reset = b_reset or msg_in.key != check_inputs["key"]
        if b_reset:
            # Update check variables
            check_inputs["samp_shape"] = samp_shape
            check_inputs["fs"] = fs
            check_inputs["key"] = msg_in.key

            window_samples = int(window_dur * fs)
            if not b_1to1:
                window_shift_samples = int(window_shift * fs)
            if zero_pad_until == "none":
                req_samples = window_samples
            elif zero_pad_until == "shift" and not b_1to1:
                req_samples = window_shift_samples
            else:  # "input", and any unrecognized value such as "full"
                req_samples = n_new
            n_zero = max(0, window_samples - req_samples)
            # Use the input dtype so windows keep the input's dtype (matching the empty-output
            # branch below) rather than being silently promoted to float64.
            buffer = np.zeros(
                msg_in.data.shape[:axis_idx]
                + (n_zero,)
                + msg_in.data.shape[axis_idx + 1 :],
                dtype=msg_in.data.dtype,
            )
            # State derived from the previous stream is stale: the cached output axis (its gain
            # depends on fs and on the shift) and any samples still pending to be skipped.
            shift_deficit = 0
            out_newaxis = None

        # Add new data to buffer.
        # Currently, we concatenate the new time samples and clip the output.
        # np.roll is not preferred as it returns a copy, and there's no way to construct a
        # rolling view of the data. In current numpy implementations, np.concatenate
        # is generally faster than np.roll and slicing anyway, but this could still
        # be a performance bottleneck for large memory arrays.
        # A circular buffer might be faster.
        buffer = np.concatenate((buffer, msg_in.data), axis=axis_idx)

        # Timestamps of every buffered sample, used to track axis `offset` in output(s).
        # The first _new_ sample sits at index 0 (i.e. at axis_info.offset); older buffered or
        # zero-padded samples get negative indices. Computed arithmetically so that an empty
        # input (n_new == 0), possibly with an empty buffer, does not index out of range.
        n_buf = buffer.shape[axis_idx]
        buffer_offset = (np.arange(n_buf) - (n_buf - n_new)).astype(float)
        # Convert from indices to 'units' (probably seconds).
        buffer_offset *= axis_info.gain
        buffer_offset += axis_info.offset

        if not b_1to1 and shift_deficit > 0:
            n_skip = min(buffer.shape[axis_idx], shift_deficit)
            if n_skip > 0:
                buffer = slice_along_axis(buffer, slice(n_skip, None), axis_idx)
                buffer_offset = buffer_offset[n_skip:]
                shift_deficit -= n_skip

        # Prepare reusable parts of output
        if out_newaxis is None:
            out_dims = msg_in.dims[:axis_idx] + [newaxis] + msg_in.dims[axis_idx:]
            out_newaxis = replace(
                axis_info,
                gain=0.0 if b_1to1 else axis_info.gain * window_shift_samples,
                offset=0.0,  # offset modified per-msg below
            )

        # Generate outputs.
        # Preliminary copy of axes without the axes that we are modifying.
        out_axes = {k: v for k, v in msg_in.axes.items() if k not in [newaxis, axis]}

        # Update targeted (windowed) axis so that its offset is relative to the new axis
        # TODO: If we have `anchor_newest=True` then offset should be -win_dur
        out_axes[axis] = replace(axis_info, offset=0.0)

        # How we update .data and .axes[newaxis] depends on the windowing mode.
        if b_1to1:
            # one-to-one mode -- Each send yields exactly one window containing only the most recent samples.
            buffer = slice_along_axis(buffer, slice(-window_samples, None), axis_idx)
            out_dat = np.expand_dims(buffer, axis=axis_idx)
            out_newaxis = replace(out_newaxis, offset=buffer_offset[-window_samples])
        elif buffer.shape[axis_idx] >= window_samples:
            # Deterministic window shifts.
            out_dat = sliding_win_oneaxis(buffer, window_samples, axis_idx)
            out_dat = slice_along_axis(
                out_dat, slice(None, None, window_shift_samples), axis_idx
            )
            offset_view = sliding_win_oneaxis(buffer_offset, window_samples, 0)[
                ::window_shift_samples
            ]
            out_newaxis = replace(out_newaxis, offset=offset_view[0, 0])

            # Drop expired beginning of buffer and update shift_deficit
            multi_shift = window_shift_samples * out_dat.shape[axis_idx]
            shift_deficit = max(0, multi_shift - buffer.shape[axis_idx])
            buffer = slice_along_axis(buffer, slice(multi_shift, None), axis_idx)
        else:
            # Not enough data to make a new window. Return empty data.
            empty_data_shape = (
                msg_in.data.shape[:axis_idx]
                + (0, window_samples)
                + msg_in.data.shape[axis_idx + 1 :]
            )
            out_dat = np.zeros(empty_data_shape, dtype=msg_in.data.dtype)
            # out_newaxis will have first timestamp in input... but mostly meaningless because output is size-zero.
            out_newaxis = replace(out_newaxis, offset=axis_info.offset)

        msg_out = replace(
            msg_in, data=out_dat, dims=out_dims, axes={**out_axes, newaxis: out_newaxis}
        )
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
class WindowSettings(ez.Settings):
    """
    Configuration for the :obj:`Window` unit; fields are forwarded to :obj:`windowing`.
    """

    # Axis to segment; None means the first dim of the first message.
    axis: typing.Optional[str] = None
    # Name of the window axis in the output. When None, the Window unit emits one
    # message per window and the window axis is dropped.
    newaxis: typing.Optional[str] = None
    # Window length in seconds; None turns the unit into a passthrough.
    window_dur: typing.Optional[float] = None
    # Window step in seconds; None selects "1:1 mode" (one window per input).
    window_shift: typing.Optional[float] = None
    # Buffer padding policy: "full", "shift", "input", or "none".
    zero_pad_until: str = "full"
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
class WindowState(ez.State):
    """State container holding settings and a windowing generator."""

    # Settings snapshot (NOTE(review): not referenced in the visible code — confirm usage).
    cur_settings: WindowSettings
    # The generator instance that performs the windowing.
    gen: typing.Generator
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
class Window(GenAxisArray):
    """
    :obj:`Unit` for :obj:`windowing`.

    When `newaxis` is set (or `window_dur` is None, i.e. passthrough), each non-empty
    generator output is published as-is. Otherwise every window is published as its own
    message, with the "win" axis removed and the windowed axis' offset set to that
    window's start time.
    """

    SETTINGS = WindowSettings

    INPUT_SIGNAL = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    def construct_generator(self):
        self.STATE.gen = windowing(
            axis=self.SETTINGS.axis,
            newaxis=self.SETTINGS.newaxis,
            window_dur=self.SETTINGS.window_dur,
            window_shift=self.SETTINGS.window_shift,
            zero_pad_until=self.SETTINGS.zero_pad_until,
        )

    @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(OUTPUT_SIGNAL)
    async def on_signal(self, msg: AxisArray) -> typing.AsyncGenerator:
        try:
            out_msg = self.STATE.gen.send(msg)
            if out_msg.data.size > 0:
                if (
                    self.SETTINGS.newaxis is not None
                    or self.SETTINGS.window_dur is None
                ):
                    # Multi-win mode or pass-through mode.
                    yield self.OUTPUT_SIGNAL, out_msg
                else:
                    # We need to split out_msg into multiple yields, dropping newaxis.
                    win_idx = out_msg.get_axis_idx("win")
                    # windowing() inserts the window axis immediately before the windowed
                    # axis, so the windowed axis name is recoverable even when
                    # SETTINGS.axis is None (windowing then picks the first input dim).
                    targ_axis = self.SETTINGS.axis or out_msg.dims[win_idx + 1]
                    win_axis = out_msg.axes["win"]
                    n_win = out_msg.data.shape[win_idx]
                    offsets = np.arange(n_win) * win_axis.gain + win_axis.offset
                    for msg_ix in range(n_win):
                        # Drop 'win' and re-anchor the windowed axis at this window's start.
                        _out_axes = {
                            **{
                                k: v
                                for k, v in out_msg.axes.items()
                                if k not in ["win", targ_axis]
                            },
                            targ_axis: replace(
                                out_msg.axes[targ_axis], offset=offsets[msg_ix]
                            ),
                        }
                        _out_msg = replace(
                            out_msg,
                            data=slice_along_axis(out_msg.data, msg_ix, win_idx),
                            dims=out_msg.dims[:win_idx] + out_msg.dims[win_idx + 1 :],
                            axes=_out_axes,
                        )
                        yield self.OUTPUT_SIGNAL, _out_msg
        except (StopIteration, GeneratorExit):
            ez.logger.debug(f"Window closed in {self.address}")
        except Exception:
            ez.logger.info(traceback.format_exc())
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: ezmsg-sigproc
|
|
3
|
+
Version: 1.3.1
|
|
4
|
+
Summary: Timeseries signal processing implementations in ezmsg
|
|
5
|
+
Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
License-File: LICENSE.txt
|
|
8
|
+
Requires-Python: >=3.9
|
|
9
|
+
Requires-Dist: ezmsg>=3.5.0
|
|
10
|
+
Requires-Dist: numpy>=2.0.2
|
|
11
|
+
Requires-Dist: pywavelets>=1.6.0
|
|
12
|
+
Requires-Dist: scipy>=1.13.1
|
|
13
|
+
Provides-Extra: test
|
|
14
|
+
Requires-Dist: flake8>=7.1.1; extra == 'test'
|
|
15
|
+
Requires-Dist: frozendict>=2.4.4; extra == 'test'
|
|
16
|
+
Requires-Dist: pytest-asyncio>=0.24.0; extra == 'test'
|
|
17
|
+
Requires-Dist: pytest-cov>=5.0.0; extra == 'test'
|
|
18
|
+
Requires-Dist: pytest>=8.3.3; extra == 'test'
|
|
19
|
+
Description-Content-Type: text/markdown
|
|
20
|
+
|
|
21
|
+
# ezmsg.sigproc
|
|
22
|
+
|
|
23
|
+
Timeseries signal processing implementations for ezmsg
|
|
24
|
+
|
|
25
|
+
## Dependencies
|
|
26
|
+
|
|
27
|
+
* `ezmsg`
|
|
28
|
+
* `numpy`
|
|
29
|
+
* `scipy`
|
|
30
|
+
* `pywavelets`
|
|
31
|
+
|
|
32
|
+
## Installation
|
|
33
|
+
|
|
34
|
+
### Release
|
|
35
|
+
|
|
36
|
+
Install the latest release from pypi with: `pip install ezmsg-sigproc` (or `uv add ...` or `poetry add ...`).
|
|
37
|
+
|
|
38
|
+
### Development Version
|
|
39
|
+
|
|
40
|
+
You can add the development version of `ezmsg-sigproc` to your project's dependencies in one of several ways.
|
|
41
|
+
|
|
42
|
+
You can clone it and add its path to your project dependencies. You may wish to do this if you intend to edit `ezmsg-sigproc`. If so, please refer to the [Developers](#developers) section below.
|
|
43
|
+
|
|
44
|
+
You can also add it directly from GitHub:
|
|
45
|
+
|
|
46
|
+
* Using `pip`: `pip install git+https://github.com/ezmsg-org/ezmsg-sigproc.git@dev`
|
|
47
|
+
* Using `poetry`: `poetry add "git+https://github.com/ezmsg-org/ezmsg-sigproc.git@dev"`
|
|
48
|
+
* Using `uv`: `uv add git+https://github.com/ezmsg-org/ezmsg-sigproc --branch dev`
|
|
49
|
+
|
|
50
|
+
## Developers
|
|
51
|
+
|
|
52
|
+
We use [`uv`](https://docs.astral.sh/uv/getting-started/installation/) for development. It is not strictly required, but if you intend to contribute to ezmsg-sigproc then using `uv` will lead to the smoothest collaboration.
|
|
53
|
+
|
|
54
|
+
1. Install [`uv`](https://docs.astral.sh/uv/getting-started/installation/) if not already installed.
|
|
55
|
+
2. Fork ezmsg-sigproc and clone your fork to your local computer.
|
|
56
|
+
3. Open a terminal and `cd` to the cloned folder.
|
|
57
|
+
4. `uv sync` to create a .venv and install dependencies.
|
|
58
|
+
5. `uv run pre-commit install` to install pre-commit hooks to do linting and formatting.
|
|
59
|
+
6. After editing code and making commits, run the test suite before making a PR: `uv run pytest tests`
|