ezmsg-sigproc 1.2.2__py3-none-any.whl → 2.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__init__.py +1 -1
- ezmsg/sigproc/__version__.py +34 -1
- ezmsg/sigproc/activation.py +78 -0
- ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
- ezmsg/sigproc/affinetransform.py +235 -0
- ezmsg/sigproc/aggregate.py +276 -0
- ezmsg/sigproc/bandpower.py +80 -0
- ezmsg/sigproc/base.py +149 -0
- ezmsg/sigproc/butterworthfilter.py +129 -39
- ezmsg/sigproc/butterworthzerophase.py +305 -0
- ezmsg/sigproc/cheby.py +125 -0
- ezmsg/sigproc/combfilter.py +160 -0
- ezmsg/sigproc/coordinatespaces.py +159 -0
- ezmsg/sigproc/decimate.py +46 -18
- ezmsg/sigproc/denormalize.py +78 -0
- ezmsg/sigproc/detrend.py +28 -0
- ezmsg/sigproc/diff.py +82 -0
- ezmsg/sigproc/downsample.py +97 -49
- ezmsg/sigproc/ewma.py +217 -0
- ezmsg/sigproc/ewmfilter.py +45 -19
- ezmsg/sigproc/extract_axis.py +39 -0
- ezmsg/sigproc/fbcca.py +307 -0
- ezmsg/sigproc/filter.py +282 -117
- ezmsg/sigproc/filterbank.py +292 -0
- ezmsg/sigproc/filterbankdesign.py +129 -0
- ezmsg/sigproc/fir_hilbert.py +336 -0
- ezmsg/sigproc/fir_pmc.py +209 -0
- ezmsg/sigproc/firfilter.py +117 -0
- ezmsg/sigproc/gaussiansmoothing.py +89 -0
- ezmsg/sigproc/kaiser.py +106 -0
- ezmsg/sigproc/linear.py +120 -0
- ezmsg/sigproc/math/__init__.py +0 -0
- ezmsg/sigproc/math/abs.py +35 -0
- ezmsg/sigproc/math/add.py +120 -0
- ezmsg/sigproc/math/clip.py +48 -0
- ezmsg/sigproc/math/difference.py +143 -0
- ezmsg/sigproc/math/invert.py +28 -0
- ezmsg/sigproc/math/log.py +57 -0
- ezmsg/sigproc/math/scale.py +39 -0
- ezmsg/sigproc/messages.py +3 -6
- ezmsg/sigproc/quantize.py +68 -0
- ezmsg/sigproc/resample.py +278 -0
- ezmsg/sigproc/rollingscaler.py +232 -0
- ezmsg/sigproc/sampler.py +232 -241
- ezmsg/sigproc/scaler.py +165 -0
- ezmsg/sigproc/signalinjector.py +70 -0
- ezmsg/sigproc/slicer.py +138 -0
- ezmsg/sigproc/spectral.py +6 -132
- ezmsg/sigproc/spectrogram.py +90 -0
- ezmsg/sigproc/spectrum.py +277 -0
- ezmsg/sigproc/transpose.py +134 -0
- ezmsg/sigproc/util/__init__.py +0 -0
- ezmsg/sigproc/util/asio.py +25 -0
- ezmsg/sigproc/util/axisarray_buffer.py +365 -0
- ezmsg/sigproc/util/buffer.py +449 -0
- ezmsg/sigproc/util/message.py +17 -0
- ezmsg/sigproc/util/profile.py +23 -0
- ezmsg/sigproc/util/sparse.py +115 -0
- ezmsg/sigproc/util/typeresolution.py +17 -0
- ezmsg/sigproc/wavelets.py +187 -0
- ezmsg/sigproc/window.py +301 -117
- ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
- {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -2
- ezmsg/sigproc/synth.py +0 -411
- ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
- ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
- ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
- /ezmsg_sigproc-1.2.2.dist-info/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
|
|
3
|
+
import ezmsg.core as ez
|
|
4
|
+
import numpy as np
|
|
5
|
+
import numpy.typing as npt
|
|
6
|
+
import pywt
|
|
7
|
+
from ezmsg.baseproc import (
|
|
8
|
+
BaseStatefulTransformer,
|
|
9
|
+
BaseTransformerUnit,
|
|
10
|
+
processor_state,
|
|
11
|
+
)
|
|
12
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
13
|
+
from ezmsg.util.messages.util import replace
|
|
14
|
+
|
|
15
|
+
from .filterbank import FilterbankMode, MinPhaseMode, filterbank
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class CWTSettings(ez.Settings):
    """
    Settings for :obj:`CWT`
    See :obj:`cwt` for argument details.
    """

    # Wavelet frequencies in Hz. When None, `scales` is used instead.
    frequencies: list | tuple | npt.NDArray | None
    # Wavelet object, or a name resolved via pywt.DiscreteContinuousWavelet.
    wavelet: str | pywt.ContinuousWavelet | pywt.Wavelet
    # Forwarded to the filterbank as its `min_phase` argument.
    min_phase: MinPhaseMode = MinPhaseMode.NONE
    # Name of the axis along which the transform is computed.
    axis: str = "time"
    # Wavelet scales; only consulted when `frequencies` is None.
    scales: list | tuple | npt.NDArray | None = None
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@processor_state
class CWTState:
    # -sqrt(scales) as a column vector (n_scales, 1); multiplies the time-diff of the convolution.
    neg_rt_scales: npt.NDArray | None = None
    # One kernel per scale: the integrated wavelet resampled for that scale, then reversed.
    int_psi_scales: list[npt.NDArray] | None = None
    # Output message template (zero-length along the target axis) reused via `replace`.
    template: AxisArray | None = None
    # Filterbank generator that convolves incoming chunks with `int_psi_scales`.
    fbgen: typing.Generator[AxisArray, AxisArray, None] | None = None
    # Last convolution sample of the previous chunk, prepended before np.diff for continuity.
    last_conv_samp: npt.NDArray | None = None
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class CWTTransformer(BaseStatefulTransformer[CWTSettings, AxisArray, AxisArray, CWTState]):
    """
    Streaming continuous wavelet transform along ``settings.axis``.

    Follows the algorithm of :func:`pywt.cwt`. The integrated wavelet is resampled for each scale
    and convolved with the input through :func:`filterbank`. The result is differentiated along
    time and multiplied by ``-sqrt(scale)``. Output dims are the input's non-target dims, then
    ``"freq"``, then the target axis.
    """

    def _hash_message(self, message: AxisArray) -> int:
        """Hash the message properties whose change requires rebuilding kernels and template."""
        ax_idx = message.get_axis_idx(self.settings.axis)
        in_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
        return hash(
            (
                message.data.dtype.kind,
                message.axes[self.settings.axis].gain,
                in_shape,
                message.key,
            )
        )

    def _reset_state(self, message: AxisArray) -> None:
        """
        Build the per-scale wavelet kernels, the filterbank generator and the output template.

        Raises:
            ValueError: If neither ``frequencies`` nor ``scales`` is configured.
        """
        if self.settings.frequencies is None and self.settings.scales is None:
            raise ValueError("CWT requires either `frequencies` or `scales` to be set.")

        precision = 10

        # Process wavelet
        wavelet = (
            self.settings.wavelet
            if isinstance(self.settings.wavelet, (pywt.ContinuousWavelet, pywt.Wavelet))
            else pywt.DiscreteContinuousWavelet(self.settings.wavelet)
        )
        # Integrate the wavelet once; it is resampled per scale below (as pywt.cwt does).
        int_psi, wave_xvec = pywt.integrate_wavelet(wavelet, precision=precision)
        int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi

        # Calculate scales. Frequencies are sorted ascending; scales are sorted descending.
        if self.settings.frequencies is not None:
            frequencies = np.sort(np.array(self.settings.frequencies))
            scales = pywt.frequency2scale(
                wavelet,
                # Hz * (seconds per sample) -> normalized frequency expected by pywt.
                frequencies * message.axes[self.settings.axis].gain,
                precision=precision,
            )
        else:
            scales = np.sort(self.settings.scales)[::-1]

        # Column vector, so it broadcasts against a convolution output whose last two
        # axes are (scale, time) -- matching the template dims built below.
        self._state.neg_rt_scales = -np.sqrt(scales)[:, None]

        # Convert to appropriate dtype
        dt_data = message.data.dtype
        dt_cplx = np.result_type(dt_data, np.complex64)
        dt_psi = dt_cplx if int_psi.dtype.kind == "c" else dt_data
        int_psi = np.asarray(int_psi, dtype=dt_psi)
        # Note: Currently int_psi cannot be made non-complex once it is complex.

        # Resample the integrated wavelet for each scale and reverse it, as in pywt.cwt.
        wave_xvec = np.asarray(wave_xvec, dtype=message.data.real.dtype)
        wave_range = wave_xvec[-1] - wave_xvec[0]
        step = wave_xvec[1] - wave_xvec[0]
        self._state.int_psi_scales = []
        for scale in scales:
            reix = (np.arange(scale * wave_range + 1) / (scale * step)).astype(int)
            if reix[-1] >= int_psi.size:
                reix = np.extract(reix < int_psi.size, reix)
            self._state.int_psi_scales.append(int_psi[reix][::-1])

        # Setup filterbank generator
        self._state.fbgen = filterbank(
            self._state.int_psi_scales,
            mode=FilterbankMode.CONV,
            min_phase=self.settings.min_phase,
            axis=self.settings.axis,
        )

        # Create output template: non-target dims, then "freq", then the (empty) target axis.
        ax_idx = message.get_axis_idx(self.settings.axis)
        in_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
        freqs = pywt.scale2frequency(wavelet, scales, precision) / message.axes[self.settings.axis].gain
        dummy_shape = in_shape + (len(scales), 0)
        self._state.template = AxisArray(
            np.zeros(dummy_shape, dtype=dt_cplx if wavelet.complex_cwt else dt_data),
            dims=message.dims[:ax_idx] + message.dims[ax_idx + 1 :] + ["freq", self.settings.axis],
            axes={
                **message.axes,
                "freq": AxisArray.CoordinateAxis(unit="Hz", data=freqs, dims=["freq"]),
            },
            key=message.key,
        )
        # A zero "previous sample" gives the first chunk one coefficient per convolution sample.
        self._state.last_conv_samp = np.zeros(dummy_shape[:-1] + (1,), dtype=self._state.template.data.dtype)

    def _process(self, message: AxisArray) -> AxisArray:
        """Convolve the chunk with every scale's kernel and return the CWT coefficients."""
        conv_msg = self._state.fbgen.send(message)

        # Prepend the previous chunk's last convolution sample so np.diff yields exactly one
        # coefficient per convolution sample and stays continuous across chunk boundaries.
        dat = np.concatenate((self._state.last_conv_samp, conv_msg.data), axis=-1)
        coef = self._state.neg_rt_scales * np.diff(dat, axis=-1)

        # Only advance the carried sample when this chunk produced output. Slicing an empty
        # chunk would store a zero-length array and drop a coefficient from the next chunk.
        # Copy so the stored sample does not keep the whole convolution output alive.
        if conv_msg.data.shape[-1] > 0:
            self._state.last_conv_samp = conv_msg.data[..., -1:].copy()

        if self._state.template.data.dtype.kind != "c":
            coef = coef.real

        # pywt.cwt trims the start and end of its result where the convolution overran the
        # input. That is not possible when streaming, so the full result is emitted.
        return replace(
            self._state.template,
            data=coef,
            axes={
                **self._state.template.axes,
                self.settings.axis: message.axes[self.settings.axis],
            },
        )
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
class CWT(BaseTransformerUnit[CWTSettings, AxisArray, AxisArray, CWTTransformer]):
    """ezmsg Unit wrapper around :obj:`CWTTransformer`, configured by :obj:`CWTSettings`."""

    SETTINGS = CWTSettings
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def cwt(
    frequencies: list | tuple | npt.NDArray | None,
    wavelet: str | pywt.ContinuousWavelet | pywt.Wavelet,
    min_phase: MinPhaseMode = MinPhaseMode.NONE,
    axis: str = "time",
    scales: list | tuple | npt.NDArray | None = None,
) -> CWTTransformer:
    """
    Perform a continuous wavelet transform.
    The function is equivalent to the :obj:`pywt.cwt` function, but is designed to work with streaming data.

    Args:
        frequencies: The wavelet frequencies to use in Hz. If `None` provided then the scales will be used.
            Note: frequencies will be sorted from smallest to largest.
        wavelet: Wavelet object or name of wavelet to use.
        min_phase: See filterbank MinPhaseMode for details.
        axis: The target axis for operation. Note that this will be moved to the -1th dimension
            because fft and matrix multiplication is much faster on the last axis.
            This axis must be in the msg.axes and it must be of type AxisArray.LinearAxis.
        scales: The scales to use. If None, the scales will be calculated from the frequencies.
            Note: Scales will be sorted from largest to smallest.
            Note: Use of scales is deprecated in favor of frequencies. Convert scales to frequencies using
            `pywt.scale2frequency(wavelet, scales, precision=10) * fs` where fs is the sampling frequency.

    Returns:
        A :obj:`CWTTransformer`. Call it with successive :obj:`AxisArray` chunks of continuous data.
        Each call returns an :obj:`AxisArray` holding the continuous wavelet transform of that chunk.
        Its dims are the input's other dims, then "freq", then `axis`.
    """
    return CWTTransformer(
        CWTSettings(
            frequencies=frequencies,
            wavelet=wavelet,
            min_phase=min_phase,
            axis=axis,
            scales=scales,
        )
    )
|
ezmsg/sigproc/window.py
CHANGED
|
@@ -1,144 +1,328 @@
|
|
|
1
|
-
|
|
1
|
+
import enum
|
|
2
|
+
import traceback
|
|
3
|
+
import typing
|
|
2
4
|
|
|
3
5
|
import ezmsg.core as ez
|
|
4
|
-
import numpy as np
|
|
5
6
|
import numpy.typing as npt
|
|
7
|
+
import sparse
|
|
8
|
+
from array_api_compat import get_namespace, is_pydata_sparse_namespace
|
|
9
|
+
from ezmsg.baseproc import (
|
|
10
|
+
BaseStatefulTransformer,
|
|
11
|
+
BaseTransformerUnit,
|
|
12
|
+
processor_state,
|
|
13
|
+
)
|
|
14
|
+
from ezmsg.util.messages.axisarray import (
|
|
15
|
+
AxisArray,
|
|
16
|
+
replace,
|
|
17
|
+
slice_along_axis,
|
|
18
|
+
sliding_win_oneaxis,
|
|
19
|
+
)
|
|
6
20
|
|
|
7
|
-
from
|
|
21
|
+
from .util.profile import profile_subpub
|
|
22
|
+
from .util.sparse import sliding_win_oneaxis as sparse_sliding_win_oneaxis
|
|
8
23
|
|
|
9
|
-
|
|
24
|
+
|
|
25
|
+
class Anchor(enum.Enum):
    """
    Which point of each window is reported as that window's position on the new axis.

    WindowTransformer accepts either a member or its string value; strings are converted
    with ``Anchor(value)``.
    """

    BEGINNING = "beginning"  # the window's first sample
    END = "end"  # window start + window_dur
    MIDDLE = "middle"  # window start + window_dur / 2
|
|
10
29
|
|
|
11
30
|
|
|
12
31
|
class WindowSettings(ez.Settings):
    """
    Settings for :obj:`WindowTransformer` and the :obj:`Window` unit.
    See :obj:`WindowTransformer` for detailed argument descriptions.
    """

    # Axis to window. If None, the first dim of the incoming message is used.
    axis: str | None = None
    # Name of the window axis in the output. If None, the transformer uses "win";
    # the Window unit then emits one message per window with that axis dropped.
    newaxis: str | None = None
    window_dur: float | None = None  # Sec. passthrough if None
    window_shift: float | None = None  # Sec. Use "1:1 mode" if None
    # One of "full", "shift", "input", "none". "full" is currently handled like "input".
    zero_pad_until: str = "full"
    # String values are converted to Anchor by WindowTransformer.
    anchor: str | Anchor = Anchor.BEGINNING
|
|
21
38
|
|
|
22
39
|
|
|
23
|
-
|
|
24
|
-
|
|
40
|
+
@processor_state
class WindowState:
    # Samples retained between calls; windows are cut from the front of this buffer.
    buffer: npt.NDArray | sparse.SparseArray | None = None
    # Window length in samples: int(window_dur * fs).
    window_samples: int | None = None
    # Window step in samples: int(window_shift * fs). Not set in 1:1 mode (window_shift None).
    window_shift_samples: int | None = None
    shift_deficit: int = 0
    """ Number of incoming samples to ignore. Only relevant when shift > window."""
    # Set once the "newaxis already present in input dims" warning has been logged.
    newaxis_warned: bool = False
    # Template for the output window axis; its offset is replaced for every message.
    out_newaxis: AxisArray.LinearAxis | None = None
    # Output dims: input dims with the window axis inserted just before the windowed axis.
    out_dims: list[str] | None = None
|
|
48
56
|
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
57
|
+
|
|
58
|
+
class WindowTransformer(BaseStatefulTransformer[WindowSettings, AxisArray, AxisArray, WindowState]):
    """
    Apply a sliding window along the specified axis to input streaming data.
    The `windowing` method is perhaps the most useful and versatile method in ezmsg.sigproc, but its parameterization
    can be difficult. Please read the argument descriptions carefully.
    """

    def __init__(self, *args, **kwargs) -> None:
        """

        Args:
            axis: The axis along which to segment windows.
                If None, defaults to the first dimension of the first seen AxisArray.
                Note: The windowed axis must be an AxisArray.LinearAxis, not an AxisArray.CoordinateAxis.
            newaxis: New axis on which windows are delimited, immediately
                preceding the target windowed axis. The data length along newaxis may be 0 if
                this most recent push did not provide enough data for a new window.
                If window_shift is None then the newaxis length will always be 1.
                If None, "win" is used. If the chosen name already exists in the input dims,
                "<name>_win" is used instead and a warning is logged.
            window_dur: The duration of the window in seconds.
                If None, the function acts as a passthrough and all other parameters are ignored.
                Otherwise it must span at least one sample at the input sampling rate.
            window_shift: The shift of the window in seconds.
                If None (default), windowing operates in "1:1 mode",
                where each input yields exactly one most-recent window.
                Otherwise it must span at least one sample at the input sampling rate.
            zero_pad_until: Determines how the function initializes the buffer.
                Can be one of "full" (default), "input", "shift", or "none".
                If `window_shift` is None then this field is ignored and "input" is always used.

                - "input" initializes the buffer with the input then prepends with zeros to the window size.
                  The first input will always yield at least one output.
                - "full" (default) is currently handled exactly like "input".
                - "shift" fills the buffer until `window_shift`.
                  No outputs will be yielded until at least `window_shift` data has been seen.
                - "none" does not pad the buffer. No outputs will be yielded until
                  at least `window_dur` data has been seen.
            anchor: Determines the entry in `axis` that gets assigned `0`, which references the
                value in `newaxis`. Can be of class :obj:`Anchor` or a string representation of an :obj:`Anchor`.

        Raises:
            ValueError: If `anchor` is not a valid :obj:`Anchor`.
        """
        super().__init__(*args, **kwargs)

        # Sanity-check settings. Settings are frozen dataclasses, hence object.__setattr__.
        if self.settings.window_shift is None and self.settings.zero_pad_until != "input":
            ez.logger.warning(
                "`zero_pad_until` must be 'input' if `window_shift` is None. "
                f"Ignoring received argument value: {self.settings.zero_pad_until}"
            )
            object.__setattr__(self.settings, "zero_pad_until", "input")
        elif self.settings.window_shift is not None and self.settings.zero_pad_until == "input":
            ez.logger.warning(
                "windowing is non-deterministic with `zero_pad_until='input'` as it depends on the size "
                "of the first input. We recommend using `zero_pad_until='shift'` when `window_shift` is float-valued."
            )
        try:
            object.__setattr__(self.settings, "anchor", Anchor(self.settings.anchor))
        except ValueError:
            raise ValueError(
                f"Invalid anchor: {self.settings.anchor}. Valid anchor are: {', '.join([e.value for e in Anchor])}"
            ) from None

    def _hash_message(self, message: AxisArray) -> int:
        """Hash the per-sample shape, sampling rate and key; a change triggers `_reset_state`."""
        axis = self.settings.axis or message.dims[0]
        axis_idx = message.get_axis_idx(axis)
        axis_info = message.get_axis(axis)
        fs = 1.0 / axis_info.gain
        samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]

        return hash(samp_shape + (fs, message.key))

    def _reset_state(self, message: AxisArray) -> None:
        """
        (Re)initialize the buffer, sample counts and the reusable output axis/dims.

        Raises:
            ValueError: If `window_dur` or `window_shift` is shorter than one sample.
        """
        _newaxis = self.settings.newaxis or "win"
        if not self._state.newaxis_warned and _newaxis in message.dims:
            ez.logger.warning(f"newaxis {_newaxis} present in input dims. Using {_newaxis}_win instead")
            self._state.newaxis_warned = True
            _newaxis = f"{_newaxis}_win"
            # Settings are frozen; plain attribute assignment would raise FrozenInstanceError.
            object.__setattr__(self.settings, "newaxis", _newaxis)

        axis = self.settings.axis or message.dims[0]
        axis_idx = message.get_axis_idx(axis)
        axis_info = message.get_axis(axis)
        fs = 1.0 / axis_info.gain

        xp = get_namespace(message.data)

        self._state.window_samples = int(self.settings.window_dur * fs)
        if self._state.window_samples < 1:
            # A 0-sample window would make `slice(-0, None)` keep the whole, ever-growing buffer.
            raise ValueError(f"window_dur={self.settings.window_dur} is shorter than one sample at fs={fs}.")
        if self.settings.window_shift is not None:
            # If window_shift is None, we are in "1:1 mode" and window_shift_samples is not used.
            self._state.window_shift_samples = int(self.settings.window_shift * fs)
            if self._state.window_shift_samples < 1:
                raise ValueError(
                    f"window_shift={self.settings.window_shift} is shorter than one sample at fs={fs}."
                )
        if self.settings.zero_pad_until == "none":
            req_samples = self._state.window_samples
        elif self.settings.zero_pad_until == "shift" and self.settings.window_shift is not None:
            req_samples = self._state.window_shift_samples
        else:  # i.e. zero_pad_until == "input" (or "full", which behaves the same)
            req_samples = message.data.shape[axis_idx]
        n_zero = max(0, self._state.window_samples - req_samples)
        init_buffer_shape = message.data.shape[:axis_idx] + (n_zero,) + message.data.shape[axis_idx + 1 :]
        self._state.buffer = xp.zeros(init_buffer_shape, dtype=message.data.dtype)
        # The buffer was just re-initialized, so no previously pending skip applies to it.
        self._state.shift_deficit = 0

        # Prepare reusable parts of output. Rebuilt on every reset so the window-axis gain and
        # the output dims follow changes in sampling rate or input dims.
        self._state.out_dims = list(message.dims[:axis_idx]) + [_newaxis] + list(message.dims[axis_idx:])
        self._state.out_newaxis = replace(
            axis_info,
            gain=0.0 if self.settings.window_shift is None else axis_info.gain * self._state.window_shift_samples,
            offset=0.0,  # offset modified per-msg below
        )

    def __call__(self, message: AxisArray) -> AxisArray:
        if self.settings.window_dur is None:
            # Shortcut for no windowing
            return message
        return super().__call__(message)

    def _process(self, message: AxisArray) -> AxisArray:
        """Append the message to the buffer and emit all complete windows (possibly none)."""
        axis = self.settings.axis or message.dims[0]
        axis_idx = message.get_axis_idx(axis)
        axis_info = message.get_axis(axis)

        xp = get_namespace(message.data)

        # Add new data to buffer.
        # Currently, we concatenate the new time samples and clip the output.
        # np.roll is not preferred as it returns a copy, and there's no way to construct a
        # rolling view of the data. In current numpy implementations, np.concatenate
        # is generally faster than np.roll and slicing anyway, but this could still
        # be a performance bottleneck for large memory arrays.
        # A circular buffer might be faster.
        self._state.buffer = xp.concatenate((self._state.buffer, message.data), axis=axis_idx)

        # Create a vector of buffer timestamps to track axis `offset` in output(s)
        buffer_t0 = 0.0
        buffer_tlen = self._state.buffer.shape[axis_idx]

        # Adjust so first _new_ sample at index 0.
        buffer_t0 -= self._state.buffer.shape[axis_idx] - message.data.shape[axis_idx]

        # Convert from indices to 'units' (probably seconds).
        buffer_t0 *= axis_info.gain
        buffer_t0 += axis_info.offset

        if self.settings.window_shift is not None and self._state.shift_deficit > 0:
            n_skip = min(self._state.buffer.shape[axis_idx], self._state.shift_deficit)
            if n_skip > 0:
                self._state.buffer = slice_along_axis(self._state.buffer, slice(n_skip, None), axis_idx)
                buffer_t0 += n_skip * axis_info.gain
                buffer_tlen -= n_skip
                self._state.shift_deficit -= n_skip

        # Generate outputs.
        # Preliminary copy of axes without the axes that we are modifying.
        _newaxis = self.settings.newaxis or "win"
        out_axes = {k: v for k, v in message.axes.items() if k not in [_newaxis, axis]}

        # Update targeted (windowed) axis so that its offset is relative to the new axis
        if self.settings.anchor == Anchor.BEGINNING:
            out_axes[axis] = replace(axis_info, offset=0.0)
        elif self.settings.anchor == Anchor.END:
            out_axes[axis] = replace(axis_info, offset=-self.settings.window_dur)
        elif self.settings.anchor == Anchor.MIDDLE:
            out_axes[axis] = replace(axis_info, offset=-self.settings.window_dur / 2)

        # How we update .data and .axes[newaxis] depends on the windowing mode.
        if self.settings.window_shift is None:
            # one-to-one mode -- Each send yields exactly one window containing only the most recent samples.
            self._state.buffer = slice_along_axis(
                self._state.buffer, slice(-self._state.window_samples, None), axis_idx
            )
            out_dat = self._state.buffer.reshape(
                self._state.buffer.shape[:axis_idx] + (1,) + self._state.buffer.shape[axis_idx:]
            )
            win_offset = buffer_t0 + axis_info.gain * (buffer_tlen - self._state.window_samples)
        elif self._state.buffer.shape[axis_idx] >= self._state.window_samples:
            # Deterministic window shifts.
            sliding_win_fun = sparse_sliding_win_oneaxis if is_pydata_sparse_namespace(xp) else sliding_win_oneaxis
            out_dat = sliding_win_fun(
                self._state.buffer,
                self._state.window_samples,
                axis_idx,
                step=self._state.window_shift_samples,
            )
            win_offset = buffer_t0

            # Drop expired beginning of buffer and update shift_deficit
            multi_shift = self._state.window_shift_samples * out_dat.shape[axis_idx]
            self._state.shift_deficit = max(0, multi_shift - self._state.buffer.shape[axis_idx])
            self._state.buffer = slice_along_axis(self._state.buffer, slice(multi_shift, None), axis_idx)
        else:
            # Not enough data to make a new window. Return empty data.
            empty_data_shape = (
                message.data.shape[:axis_idx] + (0, self._state.window_samples) + message.data.shape[axis_idx + 1 :]
            )
            out_dat = xp.zeros(empty_data_shape, dtype=message.data.dtype)
            # out_newaxis will have first timestamp in input... but mostly meaningless because output is size-zero.
            win_offset = axis_info.offset

        if self.settings.anchor == Anchor.END:
            win_offset += self.settings.window_dur
        elif self.settings.anchor == Anchor.MIDDLE:
            win_offset += self.settings.window_dur / 2
        self._state.out_newaxis = replace(self._state.out_newaxis, offset=win_offset)

        msg_out = replace(
            message,
            data=out_dat,
            dims=self._state.out_dims,
            axes={**out_axes, _newaxis: self._state.out_newaxis},
        )
        return msg_out
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
class Window(BaseTransformerUnit[WindowSettings, AxisArray, AxisArray, WindowTransformer]):
    """
    ezmsg Unit wrapping :obj:`WindowTransformer`.

    If `newaxis` is set, or windowing is disabled (`window_dur` is None), the transformer output
    is published as-is. Otherwise each window is published as its own message with the "win"
    axis dropped.
    """

    SETTINGS = WindowSettings
    INPUT_SIGNAL = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(OUTPUT_SIGNAL)
    @profile_subpub(trace_oldest=False)
    async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
        """
        override superclass on_signal so we can opt to yield once or multiple times after dropping the win axis.
        """
        # TODO: The transfomer overwrites settings.newaxis from None to "win",
        #  then we no longer know if the user wants to trim out the newaxis from the unit.
        xp = get_namespace(message.data)
        try:
            ret = self.processor(message)
            if ret.data.size > 0:
                if self.SETTINGS.newaxis is not None or self.SETTINGS.window_dur is None:
                    # Multi-win mode or pass-through mode.
                    yield self.OUTPUT_SIGNAL, ret
                else:
                    # We need to split out_msg into multiple yields, dropping newaxis.
                    win_idx = ret.get_axis_idx("win")
                    # The transformer inserts the window axis immediately before the windowed
                    # axis, so that axis is the next dim. This also covers `axis=None`, where
                    # the transformer fell back to the input's first dim.
                    target_axis = self.SETTINGS.axis or ret.dims[win_idx + 1]
                    n_win = ret.data.shape[win_idx]
                    offsets = ret.axes["win"].value(xp.asarray(range(n_win)))
                    for msg_ix in range(n_win):
                        # Drop 'win' and give the windowed axis this window's offset.
                        _out_axes = {
                            **{k: v for k, v in ret.axes.items() if k not in ["win", target_axis]},
                            target_axis: replace(ret.axes[target_axis], offset=offsets[msg_ix]),
                        }
                        _ret = replace(
                            ret,
                            data=slice_along_axis(ret.data, msg_ix, win_idx),
                            dims=ret.dims[:win_idx] + ret.dims[win_idx + 1 :],
                            axes=_out_axes,
                        )
                        yield self.OUTPUT_SIGNAL, _ret

        except Exception:
            # Keep the unit alive, but make the dropped message visible at default log levels.
            ez.logger.error(traceback.format_exc())
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
def windowing(
    axis: str | None = None,
    newaxis: str | None = None,
    window_dur: float | None = None,
    window_shift: float | None = None,
    zero_pad_until: str = "full",
    anchor: str | Anchor = Anchor.BEGINNING,
) -> WindowTransformer:
    """
    Convenience factory for a :obj:`WindowTransformer`.

    Every argument is forwarded unchanged to :obj:`WindowSettings`. See
    :obj:`WindowTransformer` for what each one means.

    Returns:
        A :obj:`WindowTransformer` configured with the given settings.
    """
    settings = WindowSettings(
        axis=axis,
        newaxis=newaxis,
        window_dur=window_dur,
        window_shift=window_shift,
        zero_pad_until=zero_pad_until,
        anchor=anchor,
    )
    return WindowTransformer(settings)
|