ezmsg-sigproc 1.2.2__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. ezmsg/sigproc/__init__.py +1 -1
  2. ezmsg/sigproc/__version__.py +16 -1
  3. ezmsg/sigproc/activation.py +75 -0
  4. ezmsg/sigproc/affinetransform.py +234 -0
  5. ezmsg/sigproc/aggregate.py +158 -0
  6. ezmsg/sigproc/bandpower.py +74 -0
  7. ezmsg/sigproc/base.py +38 -0
  8. ezmsg/sigproc/butterworthfilter.py +102 -11
  9. ezmsg/sigproc/decimate.py +7 -4
  10. ezmsg/sigproc/downsample.py +95 -51
  11. ezmsg/sigproc/ewmfilter.py +38 -16
  12. ezmsg/sigproc/filter.py +108 -20
  13. ezmsg/sigproc/filterbank.py +278 -0
  14. ezmsg/sigproc/math/__init__.py +0 -0
  15. ezmsg/sigproc/math/abs.py +28 -0
  16. ezmsg/sigproc/math/clip.py +30 -0
  17. ezmsg/sigproc/math/difference.py +60 -0
  18. ezmsg/sigproc/math/invert.py +29 -0
  19. ezmsg/sigproc/math/log.py +32 -0
  20. ezmsg/sigproc/math/scale.py +31 -0
  21. ezmsg/sigproc/messages.py +2 -3
  22. ezmsg/sigproc/sampler.py +259 -224
  23. ezmsg/sigproc/scaler.py +173 -0
  24. ezmsg/sigproc/signalinjector.py +64 -0
  25. ezmsg/sigproc/slicer.py +133 -0
  26. ezmsg/sigproc/spectral.py +6 -132
  27. ezmsg/sigproc/spectrogram.py +86 -0
  28. ezmsg/sigproc/spectrum.py +259 -0
  29. ezmsg/sigproc/synth.py +299 -105
  30. ezmsg/sigproc/wavelets.py +167 -0
  31. ezmsg/sigproc/window.py +254 -116
  32. ezmsg_sigproc-1.3.1.dist-info/METADATA +59 -0
  33. ezmsg_sigproc-1.3.1.dist-info/RECORD +35 -0
  34. {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-1.3.1.dist-info}/WHEEL +1 -2
  35. ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
  36. ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
  37. ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
  38. {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-1.3.1.dist-info/licenses}/LICENSE.txt +0 -0
@@ -0,0 +1,167 @@
1
+ from dataclasses import replace
2
+ import typing
3
+
4
+ import numpy as np
5
+ import numpy.typing as npt
6
+ import pywt
7
+ import ezmsg.core as ez
8
+ from ezmsg.util.messages.axisarray import AxisArray
9
+ from ezmsg.util.generator import consumer
10
+
11
+ from .base import GenAxisArray
12
+ from .filterbank import filterbank, FilterbankMode, MinPhaseMode
13
+
14
+
15
@consumer
def cwt(
    scales: typing.Union[list, tuple, npt.NDArray],
    wavelet: typing.Union[str, pywt.ContinuousWavelet, pywt.Wavelet],
    min_phase: MinPhaseMode = MinPhaseMode.NONE,
    axis: str = "time",
) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Build a generator to perform a continuous wavelet transform on sent AxisArray messages.
    The function is equivalent to the `pywt.cwt` function, but is designed to work with streaming data.
    Unlike `pywt.cwt`, the edges of the result where the convolution overran are not trimmed,
    because that is not possible when streaming.

    Args:
        scales: The wavelet scales to use. Must be 1-D and strictly positive.
        wavelet: Wavelet object or name of wavelet to use.
        min_phase: See filterbank MinPhaseMode for details.
        axis: The target axis for operation. Note that this will be moved to the -1th dimension
            because fft and matrix multiplication is much faster on the last axis.

    Returns:
        A (primed) Generator object that expects `.send(axis_array)` of continuous data and yields
        AxisArray messages whose dims are the input dims without `axis`, followed by "freq" and `axis`.
        Non-floating-point (e.g. integer) input is processed in float64.
    """
    msg_out: typing.Optional[AxisArray] = None

    # Check parameters
    scales = np.array(scales)
    assert np.all(scales > 0), "Scales must be positive."
    assert scales.ndim == 1, "Scales must be a 1D list, tuple, or array."
    if not isinstance(wavelet, (pywt.ContinuousWavelet, pywt.Wavelet)):
        wavelet = pywt.DiscreteContinuousWavelet(wavelet)
    precision = 10

    # State variables
    neg_rt_scales = -np.sqrt(scales)[:, None]
    # Integrated wavelet at its native dtype. These are never overwritten: every reset derives
    # dtype-matched copies from them, so a later change of input dtype (e.g. complex -> real or
    # float32 -> float64) is honoured instead of inheriting the previous stream's precision.
    base_int_psi, base_wave_xvec = pywt.integrate_wavelet(wavelet, precision=precision)
    base_int_psi = np.conj(base_int_psi) if wavelet.complex_cwt else base_int_psi
    template: typing.Optional[AxisArray] = None
    fbgen: typing.Optional[typing.Generator[AxisArray, AxisArray, None]] = None
    last_conv_samp: typing.Optional[npt.NDArray] = None

    # Reset if input changed
    check_input = {
        "kind": None,  # Need to recalc kernels at same complexity as input
        "gain": None,  # Need to recalc freqs
        "shape": None,  # Need to recalc template and buffer
        "key": None,  # Buffer obsolete
    }

    while True:
        msg_in: AxisArray = yield msg_out
        ax_idx = msg_in.get_axis_idx(axis)
        in_shape = msg_in.data.shape[:ax_idx] + msg_in.data.shape[ax_idx + 1 :]

        b_reset = (
            msg_in.data.dtype.kind != check_input["kind"]
            or msg_in.axes[axis].gain != check_input["gain"]
            or in_shape != check_input["shape"]
            or msg_in.key != check_input["key"]
        )
        # Empty messages do not trigger re-initialisation -- unless nothing has been initialised
        # yet, in which case state must be built before any output can be produced.
        b_reset = (b_reset and msg_in.data.size > 0) or template is None
        if b_reset:
            check_input["kind"] = msg_in.data.dtype.kind
            check_input["gain"] = msg_in.axes[axis].gain
            check_input["shape"] = in_shape
            check_input["key"] = msg_in.key

            # Work in floating point at the data's precision. Non-float input (int/uint/bool) would
            # otherwise truncate the wavelet samples and its x-grid to integers.
            dt_data = (
                msg_in.data.dtype
                if msg_in.data.dtype.kind in "fc"
                else np.dtype(np.float64)
            )
            dt_cplx = np.result_type(dt_data, np.complex64)
            dt_psi = dt_cplx if base_int_psi.dtype.kind == "c" else dt_data
            int_psi = np.asarray(base_int_psi, dtype=dt_psi)

            # Resample the integrated wavelet for each scale.
            # np.finfo(...).dtype is the real counterpart of dt_data (float for complex input).
            wave_xvec = np.asarray(base_wave_xvec, dtype=np.finfo(dt_data).dtype)
            wave_range = wave_xvec[-1] - wave_xvec[0]
            step = wave_xvec[1] - wave_xvec[0]
            int_psi_scales = []
            for scale in scales:
                reix = (np.arange(scale * wave_range + 1) / (scale * step)).astype(int)
                if reix[-1] >= int_psi.size:
                    reix = np.extract(reix < int_psi.size, reix)
                int_psi_scales.append(int_psi[reix][::-1])

            # CONV is probably best because we often get huge kernels.
            fbgen = filterbank(
                int_psi_scales, mode=FilterbankMode.CONV, min_phase=min_phase, axis=axis
            )

            freqs = (
                pywt.scale2frequency(wavelet, scales, precision)
                / msg_in.axes[axis].gain
            )
            # NOTE(review): the "freq" axis is described by a single offset/gain pair taken from the
            # first two frequencies; if freqs are not evenly spaced this is only an approximation.
            fstep = (freqs[1] - freqs[0]) if len(freqs) > 1 else 1.0
            # Create output template
            dummy_shape = in_shape + (len(scales), 0)
            template = AxisArray(
                np.zeros(
                    dummy_shape, dtype=dt_cplx if wavelet.complex_cwt else dt_data
                ),
                dims=msg_in.dims[:ax_idx] + msg_in.dims[ax_idx + 1 :] + ["freq", axis],
                axes={
                    **msg_in.axes,
                    "freq": AxisArray.Axis("Hz", offset=freqs[0], gain=fstep),
                },
            )
            last_conv_samp = np.zeros(
                dummy_shape[:-1] + (1,), dtype=template.data.dtype
            )

        conv_msg = fbgen.send(msg_in)

        # Prepend with last_conv_samp before doing diff so the derivative is continuous across chunks.
        dat = np.concatenate((last_conv_samp, conv_msg.data), axis=-1)
        coef = neg_rt_scales * np.diff(dat, axis=-1)
        # Store last_conv_samp for the next iteration -- only if this chunk had samples. Otherwise an
        # empty slice would replace it and the next chunk would lose its first difference.
        if conv_msg.data.shape[-1] > 0:
            last_conv_samp = conv_msg.data[..., -1:]

        if template.data.dtype.kind != "c":
            coef = coef.real

        # pywt.cwt slices off the beginning and end of the result where the convolution overran.
        # We don't have that luxury when streaming.
        msg_out = replace(
            template, data=coef, axes={**template.axes, axis: msg_in.axes[axis]}
        )
140
+
141
+
142
class CWTSettings(ez.Settings):
    """
    Configuration for the :obj:`CWT` unit.

    Every field is handed to :obj:`cwt` unchanged; consult that function for full semantics.
    """

    # Wavelet scales; must form a 1-D collection of positive values.
    scales: typing.Union[list, tuple, npt.NDArray]
    # Either a wavelet instance or a name that pywt can resolve.
    wavelet: typing.Union[str, pywt.ContinuousWavelet, pywt.Wavelet]
    # Minimum-phase option forwarded to the underlying filterbank.
    min_phase: MinPhaseMode = MinPhaseMode.NONE
    # Name of the axis the transform runs along.
    axis: str = "time"
152
+
153
+
154
class CWT(GenAxisArray):
    """
    :obj:`Unit` for :obj:`cwt`.

    Applies a streaming continuous wavelet transform to incoming :obj:`AxisArray` messages.
    Parameters come from :obj:`CWTSettings`.
    """

    SETTINGS = CWTSettings

    def construct_generator(self):
        """Create the :obj:`cwt` generator from SETTINGS and store it in STATE.gen."""
        self.STATE.gen = cwt(
            scales=self.SETTINGS.scales,
            wavelet=self.SETTINGS.wavelet,
            min_phase=self.SETTINGS.min_phase,
            axis=self.SETTINGS.axis,
        )
ezmsg/sigproc/window.py CHANGED
@@ -1,144 +1,282 @@
1
1
  from dataclasses import replace
2
+ import traceback
3
+ import typing
2
4
 
3
5
  import ezmsg.core as ez
4
6
  import numpy as np
5
7
  import numpy.typing as npt
8
+ from ezmsg.util.messages.axisarray import (
9
+ AxisArray,
10
+ slice_along_axis,
11
+ sliding_win_oneaxis,
12
+ )
13
+ from ezmsg.util.generator import consumer
6
14
 
7
- from ezmsg.util.messages.axisarray import AxisArray
15
+ from .base import GenAxisArray
8
16
 
9
- from typing import AsyncGenerator, Optional, Tuple, List
10
17
 
18
@consumer
def windowing(
    axis: typing.Optional[str] = None,
    newaxis: str = "win",
    window_dur: typing.Optional[float] = None,
    window_shift: typing.Optional[float] = None,
    zero_pad_until: str = "input",
) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Construct a generator that yields windows of data from an input :obj:`AxisArray`.

    Args:
        axis: The axis along which to segment windows.
            If None, defaults to the first dimension of the first seen AxisArray.
        newaxis: New axis on which windows are delimited, immediately
            preceding the target windowed axis. The data length along newaxis may be 0 if
            this most recent push did not provide enough data for a new window.
            If window_shift is None then the newaxis length will always be 1.
            If the input already has a dimension with this name (and window_shift is not None),
            f"{newaxis}_win" is used instead.
        window_dur: The duration of the window in seconds.
            If None, the function acts as a passthrough and all other parameters are ignored.
        window_shift: The shift of the window in seconds.
            If None (default), windowing operates in "1:1 mode", where each input yields exactly one most-recent window.
        zero_pad_until: Determines how the function initializes the buffer.
            Can be one of "input" (default), "full", "shift", or "none". If `window_shift` is None then this field is
            ignored and "input" is always used.

            - "input" (default) initializes the buffer with the input then prepends with zeros to the window size.
              The first input will always yield at least one output.
            - "shift" fills the buffer until `window_shift`.
              No outputs will be yielded until at least `window_shift` data has been seen.
            - "none" does not pad the buffer. No outputs will be yielded until at least `window_dur` data has been seen.
            - Any other value (including "full") is currently handled the same as "input".

    Returns:
        A (primed) generator that accepts .send(an AxisArray object) and yields a single AxisArray
        in which `newaxis` is inserted immediately before the windowed axis. The windowed axis offset
        is relative (0.0); the absolute start time of each window is given by the `newaxis` axis
        (offset of the first window, gain = window shift; gain is 0.0 in 1:1 mode).
    """
    # Check arguments
    if newaxis is None:
        ez.logger.warning("`newaxis` must not be None. Setting to 'win'.")
        newaxis = "win"
    if window_shift is None and zero_pad_until != "input":
        ez.logger.warning(
            "`zero_pad_until` must be 'input' if `window_shift` is None. "
            f"Ignoring received argument value: {zero_pad_until}"
        )
        zero_pad_until = "input"
    elif window_shift is not None and zero_pad_until == "input":
        ez.logger.warning(
            "windowing is non-deterministic with `zero_pad_until='input'` as it depends on the size "
            "of the first input. We recommend using 'shift' when `window_shift` is float-valued."
        )
    msg_out = AxisArray(np.array([]), dims=[""])

    # State variables
    buffer: typing.Optional[npt.NDArray] = None
    window_samples: typing.Optional[int] = None
    window_shift_samples: typing.Optional[int] = None
    # Number of incoming samples to discard before buffering. Only non-zero when shift > window.
    shift_deficit: int = 0
    b_1to1 = window_shift is None
    newaxis_warned: bool = b_1to1

    check_inputs = {"samp_shape": None, "fs": None, "key": None}

    while True:
        msg_in: AxisArray = yield msg_out

        if window_dur is None:
            msg_out = msg_in
            continue

        axis = axis or msg_in.dims[0]
        axis_idx = msg_in.get_axis_idx(axis)
        axis_info = msg_in.get_axis(axis)
        fs = 1.0 / axis_info.gain

        if not newaxis_warned and newaxis in msg_in.dims:
            ez.logger.warning(
                f"newaxis {newaxis} present in input dims. Using {newaxis}_win instead"
            )
            newaxis_warned = True
            newaxis = f"{newaxis}_win"

        n_new = msg_in.data.shape[axis_idx]
        samp_shape = msg_in.data.shape[:axis_idx] + msg_in.data.shape[axis_idx + 1 :]

        # If buffer unset or input stats changed, create a new buffer
        b_reset = (
            buffer is None
            or samp_shape != check_inputs["samp_shape"]
            or fs != check_inputs["fs"]
            or msg_in.key != check_inputs["key"]
        )
        if b_reset:
            # Update check variables
            check_inputs["samp_shape"] = samp_shape
            check_inputs["fs"] = fs
            check_inputs["key"] = msg_in.key

            window_samples = int(window_dur * fs)
            if not b_1to1:
                window_shift_samples = int(window_shift * fs)
            if zero_pad_until == "none":
                req_samples = window_samples
            elif zero_pad_until == "shift" and not b_1to1:
                req_samples = window_shift_samples
            else:  # i.e. zero_pad_until == "input" (or any unrecognized value)
                req_samples = n_new
            n_zero = max(0, window_samples - req_samples)
            buffer = np.zeros(samp_shape[:axis_idx] + (n_zero,) + samp_shape[axis_idx:])
            # A deficit carried over from the previous buffer does not apply to the fresh one.
            shift_deficit = 0

        # Add new data to buffer.
        # Currently, we concatenate the new time samples and clip the output.
        # np.roll is not preferred as it returns a copy, and there's no way to construct a
        # rolling view of the data. In current numpy implementations, np.concatenate
        # is generally faster than np.roll and slicing anyway, but this could still
        # be a performance bottleneck for large memory arrays.
        # A circular buffer might be faster.
        buffer = np.concatenate((buffer, msg_in.data), axis=axis_idx)

        # Timestamps of every buffered sample, such that the first _new_ sample sits at
        # axis_info.offset. Index (n_buf - n_new) is used rather than [-n_new] so that an input
        # with zero new samples still references the end of the buffer, not its start.
        n_buf = buffer.shape[axis_idx]
        buffer_offset = np.arange(n_buf, dtype=float) - (n_buf - n_new)
        buffer_offset *= axis_info.gain
        buffer_offset += axis_info.offset

        if not b_1to1 and shift_deficit > 0:
            n_skip = min(buffer.shape[axis_idx], shift_deficit)
            if n_skip > 0:
                buffer = slice_along_axis(buffer, slice(n_skip, None), axis_idx)
                buffer_offset = buffer_offset[n_skip:]
                shift_deficit -= n_skip

        # Output dims / new-axis description. Built per message so they always reflect the current
        # `newaxis` name and the current fs (window_shift_samples is refreshed on reset).
        out_dims = [*msg_in.dims[:axis_idx], newaxis, *msg_in.dims[axis_idx:]]
        out_newaxis = replace(
            axis_info,
            gain=0.0 if b_1to1 else axis_info.gain * window_shift_samples,
            offset=0.0,  # offset set below depending on the windowing mode
        )

        # Copy of axes without the axes that we are modifying.
        out_axes = {k: v for k, v in msg_in.axes.items() if k not in (newaxis, axis)}
        # Update targeted (windowed) axis so that its offset is relative to the new axis
        # TODO: If we have `anchor_newest=True` then offset should be -win_dur
        out_axes[axis] = replace(axis_info, offset=0.0)

        # How we update .data and .axes[newaxis] depends on the windowing mode.
        if b_1to1:
            # one-to-one mode -- Each send yields exactly one window containing only the most recent samples.
            buffer = slice_along_axis(buffer, slice(-window_samples, None), axis_idx)
            out_dat = np.expand_dims(buffer, axis=axis_idx)
            out_newaxis = replace(out_newaxis, offset=buffer_offset[-window_samples])
        elif buffer.shape[axis_idx] >= window_samples:
            # Deterministic window shifts.
            out_dat = sliding_win_oneaxis(buffer, window_samples, axis_idx)
            out_dat = slice_along_axis(
                out_dat, slice(None, None, window_shift_samples), axis_idx
            )
            offset_view = sliding_win_oneaxis(buffer_offset, window_samples, 0)[
                ::window_shift_samples
            ]
            out_newaxis = replace(out_newaxis, offset=offset_view[0, 0])

            # Drop expired beginning of buffer and update shift_deficit
            multi_shift = window_shift_samples * out_dat.shape[axis_idx]
            shift_deficit = max(0, multi_shift - buffer.shape[axis_idx])
            buffer = slice_along_axis(buffer, slice(multi_shift, None), axis_idx)
        else:
            # Not enough data to make a new window. Return empty data.
            empty_data_shape = (
                msg_in.data.shape[:axis_idx]
                + (0, window_samples)
                + msg_in.data.shape[axis_idx + 1 :]
            )
            out_dat = np.zeros(empty_data_shape, dtype=msg_in.data.dtype)
            # out_newaxis gets the first timestamp in input... mostly meaningless because output is size-zero.
            out_newaxis = replace(out_newaxis, offset=axis_info.offset)

        msg_out = replace(
            msg_in, data=out_dat, dims=out_dims, axes={**out_axes, newaxis: out_newaxis}
        )
208
+
209
+
210
class WindowSettings(ez.Settings):
    """
    Settings for the :obj:`Window` unit; each field is forwarded to :obj:`windowing`.
    """

    # Axis to window. When None, the first dimension of the first message is used.
    axis: typing.Optional[str] = None
    # Name of the window axis on output. When None, the unit emits every window as its own
    # message and no extra axis appears on output.
    newaxis: typing.Optional[str] = None
    # Window length in seconds. When None, messages pass through untouched.
    window_dur: typing.Optional[float] = None
    # Hop between consecutive windows in seconds. When None, "1:1 mode" is used.
    window_shift: typing.Optional[float] = None
    # Buffer zero-padding policy: one of "full", "shift", "input", "none".
    zero_pad_until: str = "full"
216
+
217
+
218
class WindowState(ez.State):
    """State container declared alongside the :obj:`Window` unit."""

    # Settings snapshot.
    # NOTE(review): Window in this file does not assign STATE = WindowState nor read this field; confirm it is still needed.
    cur_settings: WindowSettings
    # Presumably the generator that Window.construct_generator assigns to STATE.gen.
    gen: typing.Generator
221
+
222
+
223
class Window(GenAxisArray):
    """
    :obj:`Unit` for :obj:`windowing`.

    If `newaxis` is set (or `window_dur` is None), the generator output is published as-is.
    Otherwise each window is published as its own message with the window axis removed and the
    windowed axis offset set to that window's absolute start.
    """

    SETTINGS = WindowSettings

    INPUT_SIGNAL = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    def construct_generator(self):
        """Create the :obj:`windowing` generator from SETTINGS and store it in STATE.gen."""
        self.STATE.gen = windowing(
            axis=self.SETTINGS.axis,
            newaxis=self.SETTINGS.newaxis,
            window_dur=self.SETTINGS.window_dur,
            window_shift=self.SETTINGS.window_shift,
            zero_pad_until=self.SETTINGS.zero_pad_until,
        )

    @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(OUTPUT_SIGNAL)
    async def on_signal(self, msg: AxisArray) -> typing.AsyncGenerator:
        try:
            out_msg = self.STATE.gen.send(msg)
            if out_msg.data.size > 0:
                if (
                    self.SETTINGS.newaxis is not None
                    or self.SETTINGS.window_dur is None
                ):
                    # Multi-win mode or pass-through mode.
                    yield self.OUTPUT_SIGNAL, out_msg
                else:
                    # We need to split out_msg into multiple yields, dropping newaxis.
                    # The generator adds exactly one dim; it is "win" unless it was renamed
                    # (e.g. "win_win") to avoid a collision with an input dim.
                    win_dim = next(
                        (d for d in out_msg.dims if d not in msg.dims), "win"
                    )
                    win_idx = out_msg.get_axis_idx(win_dim)
                    # windowing inserts the new axis immediately before the windowed axis, so the
                    # windowed axis name is recovered from dims even when SETTINGS.axis is None.
                    targ_axis = out_msg.dims[win_idx + 1]
                    win_axis = out_msg.axes[win_dim]
                    offsets = (
                        np.arange(out_msg.data.shape[win_idx]) * win_axis.gain
                        + win_axis.offset
                    )
                    out_dims = out_msg.dims[:win_idx] + out_msg.dims[win_idx + 1 :]
                    base_axes = {
                        k: v
                        for k, v in out_msg.axes.items()
                        if k not in (win_dim, targ_axis)
                    }
                    for msg_ix, win_offset in enumerate(offsets):
                        _out_axes = {
                            **base_axes,
                            targ_axis: replace(
                                out_msg.axes[targ_axis], offset=win_offset
                            ),
                        }
                        _out_msg = replace(
                            out_msg,
                            data=slice_along_axis(out_msg.data, msg_ix, win_idx),
                            dims=out_dims,
                            axes=_out_axes,
                        )
                        yield self.OUTPUT_SIGNAL, _out_msg
        except (StopIteration, GeneratorExit):
            ez.logger.debug(f"Window closed in {self.address}")
        except Exception:
            ez.logger.info(traceback.format_exc())
@@ -0,0 +1,59 @@
1
+ Metadata-Version: 2.3
2
+ Name: ezmsg-sigproc
3
+ Version: 1.3.1
4
+ Summary: Timeseries signal processing implementations in ezmsg
5
+ Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>
6
+ License-Expression: MIT
7
+ License-File: LICENSE.txt
8
+ Requires-Python: >=3.9
9
+ Requires-Dist: ezmsg>=3.5.0
10
+ Requires-Dist: numpy>=2.0.2
11
+ Requires-Dist: pywavelets>=1.6.0
12
+ Requires-Dist: scipy>=1.13.1
13
+ Provides-Extra: test
14
+ Requires-Dist: flake8>=7.1.1; extra == 'test'
15
+ Requires-Dist: frozendict>=2.4.4; extra == 'test'
16
+ Requires-Dist: pytest-asyncio>=0.24.0; extra == 'test'
17
+ Requires-Dist: pytest-cov>=5.0.0; extra == 'test'
18
+ Requires-Dist: pytest>=8.3.3; extra == 'test'
19
+ Description-Content-Type: text/markdown
20
+
21
+ # ezmsg.sigproc
22
+
23
+ Timeseries signal processing implementations for ezmsg
24
+
25
+ ## Dependencies
26
+
27
+ * `ezmsg`
28
+ * `numpy`
29
+ * `scipy`
30
+ * `pywavelets`
31
+
32
+ ## Installation
33
+
34
+ ### Release
35
+
36
+ Install the latest release from pypi with: `pip install ezmsg-sigproc` (or `uv add ...` or `poetry add ...`).
37
+
38
+ ### Development Version
39
+
40
+ You can add the development version of `ezmsg-sigproc` to your project's dependencies in one of several ways.
41
+
42
+ You can clone it and add its path to your project dependencies. You may wish to do this if you intend to edit `ezmsg-sigproc`. If so, please refer to the [Developers](#developers) section below.
43
+
44
+ You can also add it directly from GitHub:
45
+
46
+ * Using `pip`: `pip install git+https://github.com/ezmsg-org/ezmsg-sigproc.git@dev`
47
+ * Using `poetry`: `poetry add "git+https://github.com/ezmsg-org/ezmsg-sigproc.git@dev"`
48
+ * Using `uv`: `uv add git+https://github.com/ezmsg-org/ezmsg-sigproc --branch dev`
49
+
50
+ ## Developers
51
+
52
+ We use [`uv`](https://docs.astral.sh/uv/getting-started/installation/) for development. It is not strictly required, but if you intend to contribute to ezmsg-sigproc then using `uv` will lead to the smoothest collaboration.
53
+
54
+ 1. Install [`uv`](https://docs.astral.sh/uv/getting-started/installation/) if not already installed.
55
+ 2. Fork ezmsg-sigproc and clone your fork to your local computer.
56
+ 3. Open a terminal and `cd` to the cloned folder.
57
+ 4. `uv sync` to create a .venv and install dependencies.
58
+ 5. `uv run pre-commit install` to install pre-commit hooks to do linting and formatting.
59
+ 6. After editing code and making commits, run the test suite before making a PR: `uv run pytest tests`