ezmsg-sigproc 1.2.2__py3-none-any.whl → 2.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69) hide show
  1. ezmsg/sigproc/__init__.py +1 -1
  2. ezmsg/sigproc/__version__.py +34 -1
  3. ezmsg/sigproc/activation.py +78 -0
  4. ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
  5. ezmsg/sigproc/affinetransform.py +235 -0
  6. ezmsg/sigproc/aggregate.py +276 -0
  7. ezmsg/sigproc/bandpower.py +80 -0
  8. ezmsg/sigproc/base.py +149 -0
  9. ezmsg/sigproc/butterworthfilter.py +129 -39
  10. ezmsg/sigproc/butterworthzerophase.py +305 -0
  11. ezmsg/sigproc/cheby.py +125 -0
  12. ezmsg/sigproc/combfilter.py +160 -0
  13. ezmsg/sigproc/coordinatespaces.py +159 -0
  14. ezmsg/sigproc/decimate.py +46 -18
  15. ezmsg/sigproc/denormalize.py +78 -0
  16. ezmsg/sigproc/detrend.py +28 -0
  17. ezmsg/sigproc/diff.py +82 -0
  18. ezmsg/sigproc/downsample.py +97 -49
  19. ezmsg/sigproc/ewma.py +217 -0
  20. ezmsg/sigproc/ewmfilter.py +45 -19
  21. ezmsg/sigproc/extract_axis.py +39 -0
  22. ezmsg/sigproc/fbcca.py +307 -0
  23. ezmsg/sigproc/filter.py +282 -117
  24. ezmsg/sigproc/filterbank.py +292 -0
  25. ezmsg/sigproc/filterbankdesign.py +129 -0
  26. ezmsg/sigproc/fir_hilbert.py +336 -0
  27. ezmsg/sigproc/fir_pmc.py +209 -0
  28. ezmsg/sigproc/firfilter.py +117 -0
  29. ezmsg/sigproc/gaussiansmoothing.py +89 -0
  30. ezmsg/sigproc/kaiser.py +106 -0
  31. ezmsg/sigproc/linear.py +120 -0
  32. ezmsg/sigproc/math/__init__.py +0 -0
  33. ezmsg/sigproc/math/abs.py +35 -0
  34. ezmsg/sigproc/math/add.py +120 -0
  35. ezmsg/sigproc/math/clip.py +48 -0
  36. ezmsg/sigproc/math/difference.py +143 -0
  37. ezmsg/sigproc/math/invert.py +28 -0
  38. ezmsg/sigproc/math/log.py +57 -0
  39. ezmsg/sigproc/math/scale.py +39 -0
  40. ezmsg/sigproc/messages.py +3 -6
  41. ezmsg/sigproc/quantize.py +68 -0
  42. ezmsg/sigproc/resample.py +278 -0
  43. ezmsg/sigproc/rollingscaler.py +232 -0
  44. ezmsg/sigproc/sampler.py +232 -241
  45. ezmsg/sigproc/scaler.py +165 -0
  46. ezmsg/sigproc/signalinjector.py +70 -0
  47. ezmsg/sigproc/slicer.py +138 -0
  48. ezmsg/sigproc/spectral.py +6 -132
  49. ezmsg/sigproc/spectrogram.py +90 -0
  50. ezmsg/sigproc/spectrum.py +277 -0
  51. ezmsg/sigproc/transpose.py +134 -0
  52. ezmsg/sigproc/util/__init__.py +0 -0
  53. ezmsg/sigproc/util/asio.py +25 -0
  54. ezmsg/sigproc/util/axisarray_buffer.py +365 -0
  55. ezmsg/sigproc/util/buffer.py +449 -0
  56. ezmsg/sigproc/util/message.py +17 -0
  57. ezmsg/sigproc/util/profile.py +23 -0
  58. ezmsg/sigproc/util/sparse.py +115 -0
  59. ezmsg/sigproc/util/typeresolution.py +17 -0
  60. ezmsg/sigproc/wavelets.py +187 -0
  61. ezmsg/sigproc/window.py +301 -117
  62. ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
  63. ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
  64. {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -2
  65. ezmsg/sigproc/synth.py +0 -411
  66. ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
  67. ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
  68. ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
  69. /ezmsg_sigproc-1.2.2.dist-info/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
@@ -0,0 +1,187 @@
1
+ import typing
2
+
3
+ import ezmsg.core as ez
4
+ import numpy as np
5
+ import numpy.typing as npt
6
+ import pywt
7
+ from ezmsg.baseproc import (
8
+ BaseStatefulTransformer,
9
+ BaseTransformerUnit,
10
+ processor_state,
11
+ )
12
+ from ezmsg.util.messages.axisarray import AxisArray
13
+ from ezmsg.util.messages.util import replace
14
+
15
+ from .filterbank import FilterbankMode, MinPhaseMode, filterbank
16
+
17
+
18
class CWTSettings(ez.Settings):
    """
    Configuration for the :obj:`CWT` unit and :obj:`CWTTransformer`.

    The :obj:`cwt` factory function takes the same arguments in the same order.
    """

    # Wavelet frequencies. When not None these take precedence over `scales`.
    frequencies: list | tuple | npt.NDArray | None

    # Wavelet object, or a wavelet name that is resolved through pywt.
    wavelet: str | pywt.ContinuousWavelet | pywt.Wavelet

    # Forwarded unchanged to the filterbank that performs the convolutions.
    min_phase: MinPhaseMode = MinPhaseMode.NONE

    # Name of the axis along which the transform is computed.
    axis: str = "time"

    # Wavelet scales; only consulted when `frequencies` is None.
    scales: list | tuple | npt.NDArray | None = None
29
+
30
+
31
@processor_state
class CWTState:
    # Column vector holding -sqrt(scale) for each scale; multiplies the differenced convolution output.
    neg_rt_scales: npt.NDArray | None = None

    # One time-reversed, resampled integrated-wavelet kernel per scale.
    int_psi_scales: list[npt.NDArray] | None = None

    # Zero-length output message that is copied (with new data) for every result.
    template: AxisArray | None = None

    # Filterbank generator; each input message is passed to it with `.send`.
    fbgen: typing.Generator[AxisArray, AxisArray, None] | None = None

    # Final sample of the previous convolution output, prepended before differencing the next chunk.
    last_conv_samp: npt.NDArray | None = None
38
+
39
+
40
class CWTTransformer(BaseStatefulTransformer[CWTSettings, AxisArray, AxisArray, CWTState]):
    """
    Streaming continuous wavelet transform.

    One kernel is built per wavelet scale from the integrated wavelet. Incoming messages are
    convolved with all kernels through a :obj:`filterbank` generator, and the coefficients are
    obtained by differencing the convolution output along the last axis and scaling it by
    ``-sqrt(scale)``.
    """

    def _hash_message(self, message: AxisArray) -> int:
        """
        Hash the message properties that the prepared state was derived from.

        The hash covers the dtype kind, the gain of the target axis, the shape of the data with the
        target axis removed, and the message key.
        """
        ax_idx = message.get_axis_idx(self.settings.axis)
        in_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
        return hash(
            (
                message.data.dtype.kind,
                message.axes[self.settings.axis].gain,
                in_shape,
                message.key,
            )
        )

    def _reset_state(self, message: AxisArray) -> None:
        """
        Build the per-scale kernels, the filterbank generator and the output template.

        Raises:
            ValueError: If both `frequencies` and `scales` are None in the settings.
        """
        if self.settings.frequencies is None and self.settings.scales is None:
            # Without this guard the failure surfaces later as an obscure numpy error from np.sort(None).
            raise ValueError("CWT requires either `frequencies` or `scales`, but both are None.")

        precision = 10

        # Process wavelet
        wavelet = (
            self.settings.wavelet
            if isinstance(self.settings.wavelet, (pywt.ContinuousWavelet, pywt.Wavelet))
            else pywt.DiscreteContinuousWavelet(self.settings.wavelet)
        )
        # Process wavelet integration
        int_psi, wave_xvec = pywt.integrate_wavelet(wavelet, precision=precision)
        int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi

        # Calculate scales and frequencies
        axis_gain = message.axes[self.settings.axis].gain
        if self.settings.frequencies is not None:
            frequencies = np.sort(np.array(self.settings.frequencies))
            scales = pywt.frequency2scale(
                wavelet,
                frequencies * axis_gain,
                precision=precision,
            )
        else:
            scales = np.sort(self.settings.scales)[::-1]

        self._state.neg_rt_scales = -np.sqrt(scales)[:, None]

        # Convert to appropriate dtype.
        # Non-floating input (e.g. integer samples) is promoted to float64: casting the wavelet and
        # its sample grid to an integer dtype would truncate them and corrupt every kernel.
        dt_data = message.data.dtype
        if dt_data.kind not in "fc":
            dt_data = np.dtype(np.float64)
        dt_real = np.empty(0, dtype=dt_data).real.dtype
        dt_cplx = np.result_type(dt_data, np.complex64)
        dt_psi = dt_cplx if int_psi.dtype.kind == "c" else dt_data
        int_psi = np.asarray(int_psi, dtype=dt_psi)
        # Note: Currently int_psi cannot be made non-complex once it is complex.

        # Calculate waves for each scale
        wave_xvec = np.asarray(wave_xvec, dtype=dt_real)
        wave_range = wave_xvec[-1] - wave_xvec[0]
        step = wave_xvec[1] - wave_xvec[0]
        self._state.int_psi_scales = []
        for scale in scales:
            # Indices into int_psi that resample the integrated wavelet at this scale.
            reix = (np.arange(scale * wave_range + 1) / (scale * step)).astype(int)
            if reix[-1] >= int_psi.size:
                reix = np.extract(reix < int_psi.size, reix)
            # Kernels are stored time-reversed.
            self._state.int_psi_scales.append(int_psi[reix][::-1])

        # Setup filterbank generator
        self._state.fbgen = filterbank(
            self._state.int_psi_scales,
            mode=FilterbankMode.CONV,
            min_phase=self.settings.min_phase,
            axis=self.settings.axis,
        )

        # Create output template
        ax_idx = message.get_axis_idx(self.settings.axis)
        in_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
        freqs = pywt.scale2frequency(wavelet, scales, precision) / axis_gain
        # Output layout: remaining input dims, then "freq", then the target axis (length 0 in the template).
        dummy_shape = in_shape + (len(scales), 0)
        self._state.template = AxisArray(
            np.zeros(dummy_shape, dtype=dt_cplx if wavelet.complex_cwt else dt_data),
            dims=message.dims[:ax_idx] + message.dims[ax_idx + 1 :] + ["freq", self.settings.axis],
            axes={
                **message.axes,
                "freq": AxisArray.CoordinateAxis(unit="Hz", data=freqs, dims=["freq"]),
            },
            key=message.key,
        )
        self._state.last_conv_samp = np.zeros(dummy_shape[:-1] + (1,), dtype=self._state.template.data.dtype)

    def _process(self, message: AxisArray) -> AxisArray:
        """
        Transform one chunk of data.

        Args:
            message: Input data containing the configured target axis.

        Returns:
            A copy of the output template carrying the wavelet coefficients, with the target axis
            taken from `message`.
        """
        conv_msg = self._state.fbgen.send(message)

        # Prepend with last_conv_samp before doing diff
        dat = np.concatenate((self._state.last_conv_samp, conv_msg.data), axis=-1)
        coef = self._state.neg_rt_scales * np.diff(dat, axis=-1)
        # Store last_conv_samp for next iteration
        self._state.last_conv_samp = conv_msg.data[..., -1:]

        if self._state.template.data.dtype.kind != "c":
            coef = coef.real

        # pywt.cwt slices off the beginning and end of the result where the convolution overran. We don't have
        # that luxury when streaming.
        return replace(
            self._state.template,
            data=coef,
            axes={
                **self._state.template.axes,
                self.settings.axis: message.axes[self.settings.axis],
            },
        )
145
+
146
+
147
class CWT(BaseTransformerUnit[CWTSettings, AxisArray, AxisArray, CWTTransformer]):
    """Unit wrapper around :obj:`CWTTransformer`, configured through :obj:`CWTSettings`."""

    SETTINGS = CWTSettings
149
+
150
+
151
def cwt(
    frequencies: list | tuple | npt.NDArray | None,
    wavelet: str | pywt.ContinuousWavelet | pywt.Wavelet,
    min_phase: MinPhaseMode = MinPhaseMode.NONE,
    axis: str = "time",
    scales: list | tuple | npt.NDArray | None = None,
) -> CWTTransformer:
    """
    Create a transformer that performs a continuous wavelet transform.
    The transform is modelled on :obj:`pywt.cwt`, but is designed to work with streaming data.

    Args:
        frequencies: The wavelet frequencies to use in Hz. If `None` provided then the scales will be used.
            Note: frequencies will be sorted from smallest to largest.
        wavelet: Wavelet object or name of wavelet to use.
        min_phase: See filterbank MinPhaseMode for details.
        axis: The target axis for operation. Note that this will be moved to the -1th dimension
            because fft and matrix multiplication is much faster on the last axis.
            This axis must be in the msg.axes and it must be of type AxisArray.LinearAxis.
        scales: The scales to use. Only consulted when `frequencies` is None.
            Note: Scales will be sorted from largest to smallest.
            Note: Use of scales is deprecated in favor of frequencies. Convert scales to frequencies using
            `pywt.scale2frequency(wavelet, scales, precision=10) * fs` where fs is the sampling frequency.

    Returns:
        A :obj:`CWTTransformer` constructed from a :obj:`CWTSettings` holding these arguments.
        It consumes :obj:`AxisArray` messages of continuous data and produces :obj:`AxisArray`
        messages with a continuous wavelet transform in their data.
    """
    return CWTTransformer(
        CWTSettings(
            frequencies=frequencies,
            wavelet=wavelet,
            min_phase=min_phase,
            axis=axis,
            scales=scales,
        )
    )
ezmsg/sigproc/window.py CHANGED
@@ -1,144 +1,328 @@
1
- from dataclasses import replace
1
+ import enum
2
+ import traceback
3
+ import typing
2
4
 
3
5
  import ezmsg.core as ez
4
- import numpy as np
5
6
  import numpy.typing as npt
7
+ import sparse
8
+ from array_api_compat import get_namespace, is_pydata_sparse_namespace
9
+ from ezmsg.baseproc import (
10
+ BaseStatefulTransformer,
11
+ BaseTransformerUnit,
12
+ processor_state,
13
+ )
14
+ from ezmsg.util.messages.axisarray import (
15
+ AxisArray,
16
+ replace,
17
+ slice_along_axis,
18
+ sliding_win_oneaxis,
19
+ )
6
20
 
7
- from ezmsg.util.messages.axisarray import AxisArray
21
+ from .util.profile import profile_subpub
22
+ from .util.sparse import sliding_win_oneaxis as sparse_sliding_win_oneaxis
8
23
 
9
- from typing import AsyncGenerator, Optional, Tuple, List
24
+
25
class Anchor(enum.Enum):
    """Names the position within a window that is used as its reference point."""

    # Members can be looked up from their string value, e.g. Anchor("end").
    BEGINNING = "beginning"
    END = "end"
    MIDDLE = "middle"
10
29
 
11
30
 
12
31
class WindowSettings(ez.Settings):
    """Configuration for :obj:`Window` and :obj:`WindowTransformer`."""

    # Axis to window along. None selects the first dimension of the incoming message.
    axis: str | None = None

    # Name of the axis that enumerates windows. The transformer falls back to "win" when None.
    newaxis: str | None = None

    # Window length in seconds. None turns the transformer into a passthrough.
    window_dur: float | None = None

    # Step between windows in seconds. None selects "1:1 mode".
    window_shift: float | None = None

    # Buffer initialisation strategy: "full", "shift", "input" or "none".
    zero_pad_until: str = "full"

    # Reference point of each window; an Anchor member or its string value.
    anchor: str | Anchor = Anchor.BEGINNING
21
38
 
22
39
 
23
- class WindowState(ez.State):
24
- cur_settings: WindowSettings
40
@processor_state
class WindowState:
    # Samples carried over between calls; new data is concatenated along the windowed axis.
    buffer: npt.NDArray | sparse.SparseArray | None = None

    # Window length in samples.
    window_samples: int | None = None

    # Step between windows in samples. Left as None in "1:1 mode".
    window_shift_samples: int | None = None

    shift_deficit: int = 0
    """ Number of incoming samples to ignore. Only relevant when shift > window."""

    # Set once the warning about `newaxis` clashing with an input dim has been logged.
    newaxis_warned: bool = False

    # Reusable description of the window axis; its offset is refreshed for every output.
    out_newaxis: AxisArray.LinearAxis | None = None

    # Output dims: the input dims with the window axis inserted before the windowed axis.
    out_dims: list[str] | None = None
48
56
 
49
- @ez.subscriber(INPUT_SIGNAL)
50
- @ez.publisher(OUTPUT_SIGNAL)
51
- async def on_signal(self, msg: AxisArray) -> AsyncGenerator:
52
- if self.STATE.cur_settings.window_dur is None:
53
- yield self.OUTPUT_SIGNAL, msg
54
- return
55
-
56
- axis_name = self.STATE.cur_settings.axis
57
- if axis_name is None:
58
- axis_name = msg.dims[0]
59
- axis_idx = msg.get_axis_idx(axis_name)
60
- axis = msg.get_axis(axis_name)
61
- fs = 1.0 / axis.gain
62
-
63
- # Create a view of data with time axis at dim 0
64
- time_view = np.moveaxis(msg.data, axis_idx, 0)
65
- samp_shape = time_view.shape[1:]
66
-
67
- # Pre(re?)allocate buffer
68
- window_samples = int(self.STATE.cur_settings.window_dur * fs)
69
- if (
70
- (self.STATE.samp_shape != samp_shape)
71
- or (self.STATE.out_fs != fs)
72
- or self.STATE.buffer is None
73
- ):
74
- self.STATE.buffer = np.zeros(tuple([window_samples] + list(samp_shape)))
75
-
76
- self.STATE.window_samples = window_samples
77
- self.STATE.samp_shape = samp_shape
78
- self.STATE.out_fs = fs
79
-
80
- self.STATE.window_shift_samples = None
81
- if self.STATE.cur_settings.window_shift is not None:
82
- self.STATE.window_shift_samples = int(
83
- fs * self.STATE.cur_settings.window_shift
57
+
58
class WindowTransformer(BaseStatefulTransformer[WindowSettings, AxisArray, AxisArray, WindowState]):
    """
    Apply a sliding window along the specified axis to input streaming data.
    The `windowing` method is perhaps the most useful and versatile method in ezmsg.sigproc, but its parameterization
    can be difficult. Please read the argument descriptions carefully.
    """

    def __init__(self, *args, **kwargs) -> None:
        """

        Args:
            axis: The axis along which to segment windows.
                If None, defaults to the first dimension of the first seen AxisArray.
                Note: The windowed axis must be an AxisArray.LinearAxis, not an AxisArray.CoordinateAxis.
            newaxis: New axis on which windows are delimited, immediately
                preceding the target windowed axis. The data length along newaxis may be 0 if
                this most recent push did not provide enough data for a new window.
                If window_shift is None then the newaxis length will always be 1.
            window_dur: The duration of the window in seconds.
                If None, the function acts as a passthrough and all other parameters are ignored.
            window_shift: The shift of the window in seconds.
                If None (default), windowing operates in "1:1 mode",
                where each input yields exactly one most-recent window.
            zero_pad_until: Determines how the function initializes the buffer.
                Can be one of "full", "shift", "input", or "none".
                If `window_shift` is None then this field is ignored and "input" is always used
                (a warning is logged when a different value was provided).

                - "input" initializes the buffer with the input then prepends with zeros to the window size.
                  The first input will always yield at least one output.
                - "shift" fills the buffer until `window_shift`.
                  No outputs will be yielded until at least `window_shift` data has been seen.
                - "none" does not pad the buffer. No outputs will be yielded until
                  at least `window_dur` data has been seen.
                - "full" has no dedicated branch in `_reset_state`; the buffer is currently sized the
                  same way as for "input". NOTE(review): confirm this is the intended meaning of "full".
            anchor: Determines the entry in `axis` that gets assigned `0`, which references the
                value in `newaxis`. Can be of class :obj:`Anchor` or a string representation of an :obj:`Anchor`.

        Raises:
            ValueError: If `anchor` is not a valid :obj:`Anchor` or the string value of one.
        """
        super().__init__(*args, **kwargs)

        # Sanity-check settings
        # if self.settings.newaxis is None:
        #     ez.logger.warning("`newaxis=None` will be replaced with `newaxis='win'`.")
        #     object.__setattr__(self.settings, "newaxis", "win")
        if self.settings.window_shift is None and self.settings.zero_pad_until != "input":
            ez.logger.warning(
                "`zero_pad_until` must be 'input' if `window_shift` is None. "
                f"Ignoring received argument value: {self.settings.zero_pad_until}"
            )
            object.__setattr__(self.settings, "zero_pad_until", "input")
        elif self.settings.window_shift is not None and self.settings.zero_pad_until == "input":
            ez.logger.warning(
                "windowing is non-deterministic with `zero_pad_until='input'` as it depends on the size "
                "of the first input. We recommend using `zero_pad_until='shift'` when `window_shift` is float-valued."
            )
        try:
            # Normalise a string anchor into the enum so later comparisons against Anchor members work.
            object.__setattr__(self.settings, "anchor", Anchor(self.settings.anchor))
        except ValueError:
            raise ValueError(
                f"Invalid anchor: {self.settings.anchor}. Valid anchor are: {', '.join([e.value for e in Anchor])}"
            )

    def _hash_message(self, message: AxisArray) -> int:
        """Hash the per-sample shape, the sampling rate of the windowed axis and the message key."""
        axis = self.settings.axis or message.dims[0]
        axis_idx = message.get_axis_idx(axis)
        axis_info = message.get_axis(axis)
        fs = 1.0 / axis_info.gain
        samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]

        return hash(samp_shape + (fs, message.key))

    def _reset_state(self, message: AxisArray) -> None:
        """Size the buffer from the sampling rate and settings, and prepare the reusable output parts."""
        _newaxis = self.settings.newaxis or "win"
        if not self._state.newaxis_warned and _newaxis in message.dims:
            ez.logger.warning(f"newaxis {_newaxis} present in input dims. Using {_newaxis}_win instead")
            self._state.newaxis_warned = True
            # Keep the local name in sync so that `out_dims` below uses the renamed axis, and write the
            # setting with object.__setattr__ as `__init__` does (plain assignment fails on a frozen object).
            _newaxis = f"{_newaxis}_win"
            object.__setattr__(self.settings, "newaxis", _newaxis)

        axis = self.settings.axis or message.dims[0]
        axis_idx = message.get_axis_idx(axis)
        axis_info = message.get_axis(axis)
        fs = 1.0 / axis_info.gain

        xp = get_namespace(message.data)

        self._state.window_samples = int(self.settings.window_dur * fs)
        if self.settings.window_shift is not None:
            # If window_shift is None, we are in "1:1 mode" and window_shift_samples is not used.
            self._state.window_shift_samples = int(self.settings.window_shift * fs)
        if self.settings.zero_pad_until == "none":
            req_samples = self._state.window_samples
        elif self.settings.zero_pad_until == "shift" and self.settings.window_shift is not None:
            req_samples = self._state.window_shift_samples
        else:  # i.e. zero_pad_until == "input"
            req_samples = message.data.shape[axis_idx]
        # The buffer starts as zeros that make up the difference between the window and req_samples.
        n_zero = max(0, self._state.window_samples - req_samples)
        init_buffer_shape = message.data.shape[:axis_idx] + (n_zero,) + message.data.shape[axis_idx + 1 :]
        self._state.buffer = xp.zeros(init_buffer_shape, dtype=message.data.dtype)

        # Prepare reusable parts of output
        if self._state.out_newaxis is None:
            self._state.out_dims = list(message.dims[:axis_idx]) + [_newaxis] + list(message.dims[axis_idx:])
            self._state.out_newaxis = replace(
                axis_info,
                gain=0.0 if self.settings.window_shift is None else axis_info.gain * self._state.window_shift_samples,
                offset=0.0,  # offset modified per-msg below
            )

    def __call__(self, message: AxisArray) -> AxisArray:
        """Return `message` untouched when `window_dur` is None; otherwise defer to the base class."""
        if self.settings.window_dur is None:
            # Shortcut for no windowing
            return message
        return super().__call__(message)

    def _process(self, message: AxisArray) -> AxisArray:
        """
        Append `message` to the buffer and return every window that is now available.

        The returned message has the window axis inserted before the windowed axis. Its length along
        the window axis is 0 when a window shift is configured and the buffer does not yet hold a
        full window.
        """
        axis = self.settings.axis or message.dims[0]
        axis_idx = message.get_axis_idx(axis)
        axis_info = message.get_axis(axis)

        xp = get_namespace(message.data)

        # Add new data to buffer.
        # Currently, we concatenate the new time samples and clip the output.
        # np.roll is not preferred as it returns a copy, and there's no way to construct a
        # rolling view of the data. In current numpy implementations, np.concatenate
        # is generally faster than np.roll and slicing anyway, but this could still
        # be a performance bottleneck for large memory arrays.
        # A circular buffer might be faster.
        self._state.buffer = xp.concatenate((self._state.buffer, message.data), axis=axis_idx)

        # Track the coordinate of the first buffer sample so the output axis `offset` can be derived.
        buffer_t0 = 0.0
        buffer_tlen = self._state.buffer.shape[axis_idx]

        # Adjust so first _new_ sample at index 0.
        buffer_t0 -= self._state.buffer.shape[axis_idx] - message.data.shape[axis_idx]

        # Convert from indices to 'units' (probably seconds).
        buffer_t0 *= axis_info.gain
        buffer_t0 += axis_info.offset

        if self.settings.window_shift is not None and self._state.shift_deficit > 0:
            # Drop samples still owed from a previous shift that ran past the end of the buffer.
            n_skip = min(self._state.buffer.shape[axis_idx], self._state.shift_deficit)
            if n_skip > 0:
                self._state.buffer = slice_along_axis(self._state.buffer, slice(n_skip, None), axis_idx)
                buffer_t0 += n_skip * axis_info.gain
                buffer_tlen -= n_skip
                self._state.shift_deficit -= n_skip

        # Generate outputs.
        # Preliminary copy of axes without the axes that we are modifying.
        _newaxis = self.settings.newaxis or "win"
        out_axes = {k: v for k, v in message.axes.items() if k not in [_newaxis, axis]}

        # Update targeted (windowed) axis so that its offset is relative to the new axis
        if self.settings.anchor == Anchor.BEGINNING:
            out_axes[axis] = replace(axis_info, offset=0.0)
        elif self.settings.anchor == Anchor.END:
            out_axes[axis] = replace(axis_info, offset=-self.settings.window_dur)
        elif self.settings.anchor == Anchor.MIDDLE:
            out_axes[axis] = replace(axis_info, offset=-self.settings.window_dur / 2)

        # How we update .data and .axes[newaxis] depends on the windowing mode.
        if self.settings.window_shift is None:
            # one-to-one mode -- Each send yields exactly one window containing only the most recent samples.
            self._state.buffer = slice_along_axis(
                self._state.buffer, slice(-self._state.window_samples, None), axis_idx
            )
            out_dat = self._state.buffer.reshape(
                self._state.buffer.shape[:axis_idx] + (1,) + self._state.buffer.shape[axis_idx:]
            )
            win_offset = buffer_t0 + axis_info.gain * (buffer_tlen - self._state.window_samples)
        elif self._state.buffer.shape[axis_idx] >= self._state.window_samples:
            # Deterministic window shifts.
            sliding_win_fun = sparse_sliding_win_oneaxis if is_pydata_sparse_namespace(xp) else sliding_win_oneaxis
            out_dat = sliding_win_fun(
                self._state.buffer,
                self._state.window_samples,
                axis_idx,
                step=self._state.window_shift_samples,
            )
            win_offset = buffer_t0

            # Drop expired beginning of buffer and update shift_deficit
            multi_shift = self._state.window_shift_samples * out_dat.shape[axis_idx]
            self._state.shift_deficit = max(0, multi_shift - self._state.buffer.shape[axis_idx])
            self._state.buffer = slice_along_axis(self._state.buffer, slice(multi_shift, None), axis_idx)
        else:
            # Not enough data to make a new window. Return empty data.
            empty_data_shape = (
                message.data.shape[:axis_idx] + (0, self._state.window_samples) + message.data.shape[axis_idx + 1 :]
            )
            out_dat = xp.zeros(empty_data_shape, dtype=message.data.dtype)
            # out_newaxis will have first timestamp in input... but mostly meaningless because output is size-zero.
            win_offset = axis_info.offset

        # Shift the window coordinate to the anchor point; the windowed axis was offset the other way above.
        if self.settings.anchor == Anchor.END:
            win_offset += self.settings.window_dur
        elif self.settings.anchor == Anchor.MIDDLE:
            win_offset += self.settings.window_dur / 2
        self._state.out_newaxis = replace(self._state.out_newaxis, offset=win_offset)

        msg_out = replace(
            message,
            data=out_dat,
            dims=self._state.out_dims,
            axes={**out_axes, _newaxis: self._state.out_newaxis},
        )
        return msg_out
265
+
266
+
267
class Window(BaseTransformerUnit[WindowSettings, AxisArray, AxisArray, WindowTransformer]):
    """
    Unit wrapper around :obj:`WindowTransformer`.

    When `newaxis` is None in the settings, each window is published as its own message with the
    "win" axis removed; otherwise the transformer output is published as-is.
    """

    SETTINGS = WindowSettings
    INPUT_SIGNAL = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(OUTPUT_SIGNAL)
    @profile_subpub(trace_oldest=False)
    async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
        """
        override superclass on_signal so we can opt to yield once or multiple times after dropping the win axis.

        Nothing is published when the transformer output is empty. Exceptions raised while
        processing are logged and not propagated.
        """
        # TODO: The transformer overwrites settings.newaxis from None to "win",
        #  then we no longer know if the user wants to trim out the newaxis from the unit.
        xp = get_namespace(message.data)
        try:
            ret = self.processor(message)
            if ret.data.size > 0:
                if self.SETTINGS.newaxis is not None or self.SETTINGS.window_dur is None:
                    # Multi-win mode or pass-through mode.
                    yield self.OUTPUT_SIGNAL, ret
                else:
                    # We need to split out_msg into multiple yields, dropping newaxis.
                    # Resolve the windowed axis the same way the transformer does; using
                    # SETTINGS.axis directly would look up the key None when no axis was configured.
                    axis = self.SETTINGS.axis or message.dims[0]
                    axis_idx = ret.get_axis_idx("win")
                    win_axis = ret.axes["win"]
                    offsets = win_axis.value(xp.asarray(range(ret.data.shape[axis_idx])))
                    for msg_ix in range(ret.data.shape[axis_idx]):
                        # Need to drop 'win' and replace the windowed axis in axes.
                        _out_axes = {
                            **{k: v for k, v in ret.axes.items() if k not in ["win", axis]},
                            axis: replace(ret.axes[axis], offset=offsets[msg_ix]),
                        }
                        _ret = replace(
                            ret,
                            data=slice_along_axis(ret.data, msg_ix, axis_idx),
                            dims=ret.dims[:axis_idx] + ret.dims[axis_idx + 1 :],
                            axes=_out_axes,
                        )
                        yield self.OUTPUT_SIGNAL, _ret

        except Exception:
            # Still best-effort (the exception is not re-raised), but logged at error level so that
            # processing failures are not hidden among informational messages.
            ez.logger.error(traceback.format_exc())
309
+
310
+
311
def windowing(
    axis: str | None = None,
    newaxis: str | None = None,
    window_dur: float | None = None,
    window_shift: float | None = None,
    zero_pad_until: str = "full",
    anchor: str | Anchor = Anchor.BEGINNING,
) -> WindowTransformer:
    """
    Build a :obj:`WindowTransformer` from individual arguments.

    Every argument is stored in the :obj:`WindowSettings` field of the same name;
    see :obj:`WindowTransformer` for their meaning.

    Returns:
        A :obj:`WindowTransformer` configured with the given settings.
    """
    settings = WindowSettings(
        axis=axis,
        newaxis=newaxis,
        window_dur=window_dur,
        window_shift=window_shift,
        zero_pad_until=zero_pad_until,
        anchor=anchor,
    )
    return WindowTransformer(settings)