ezmsg-sigproc 2.9.0__py3-none-any.whl → 2.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '2.9.0'
32
- __version_tuple__ = version_tuple = (2, 9, 0)
31
+ __version__ = version = '2.11.0'
32
+ __version_tuple__ = version_tuple = (2, 11, 0)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -1,42 +1,80 @@
1
+ """
2
+ Streaming zero-phase Butterworth filter implemented as a two-stage composite processor.
3
+
4
+ Stage 1: Forward causal Butterworth filter (from ezmsg.sigproc.butterworthfilter)
5
+ Stage 2: Backward acausal filter with buffering (ButterworthBackwardFilterTransformer)
6
+
7
+ The output is delayed by `pad_length` samples to ensure the backward pass has sufficient
8
+ future context. The pad_length is computed from the filter's impulse-response settling time (see `settle_cutoff`).
9
+ """
10
+
1
11
  import functools
2
12
  import typing
3
13
 
4
- import ezmsg.core as ez
5
14
  import numpy as np
6
15
  import scipy.signal
7
- from ezmsg.baseproc import SettingsType
16
+ from ezmsg.baseproc import BaseTransformerUnit
17
+ from ezmsg.baseproc.composite import CompositeProcessor
8
18
  from ezmsg.util.messages.axisarray import AxisArray
9
19
  from ezmsg.util.messages.util import replace
10
20
 
11
- from ezmsg.sigproc.butterworthfilter import ButterworthFilterSettings, butter_design_fun
12
- from ezmsg.sigproc.filter import (
13
- BACoeffs,
14
- BaseFilterByDesignTransformerUnit,
15
- FilterByDesignTransformer,
16
- SOSCoeffs,
21
+ from .butterworthfilter import (
22
+ ButterworthFilterSettings,
23
+ ButterworthFilterTransformer,
24
+ butter_design_fun,
17
25
  )
26
+ from .filter import BACoeffs, FilterByDesignTransformer, SOSCoeffs
27
+ from .util.axisarray_buffer import HybridAxisArrayBuffer
18
28
 
19
29
 
20
30
  class ButterworthZeroPhaseSettings(ButterworthFilterSettings):
21
- """Settings for :obj:`ButterworthZeroPhase`."""
31
+ """
32
+ Settings for :obj:`ButterworthZeroPhase`.
33
+
34
+ This implements a streaming zero-phase Butterworth filter using forward-backward
35
+ filtering. The output is delayed by `pad_length` samples to ensure the backward
36
+ pass has sufficient future context.
37
+
38
+ The pad_length is computed by finding where the filter's impulse response decays
39
+ to `settle_cutoff` fraction of its peak value. This accounts for the filter's
40
+ actual time constant rather than just its order.
41
+ """
42
+
43
+ # Inherits from ButterworthFilterSettings:
44
+ # axis, coef_type, order, cuton, cutoff, wn_hz
22
45
 
23
- # axis, coef_type, order, cuton, cutoff, wn_hz are inherited from ButterworthFilterSettings
24
- padtype: str | None = None
46
+ settle_cutoff: float = 0.01
25
47
  """
26
- Padding type to use in `scipy.signal.filtfilt`.
27
- Must be one of {'odd', 'even', 'constant', None}.
28
- Default is None for no padding.
48
+ Fraction of peak impulse response used to determine settling time.
49
+ The pad_length is set to the number of samples until the impulse response
50
+ decays to this fraction of its peak. Default is 0.01 (1% of peak).
29
51
  """
30
52
 
31
- padlen: int | None = 0
53
+ max_pad_duration: float | None = None
32
54
  """
33
- Length of the padding to use in `scipy.signal.filtfilt`.
34
- If None, SciPy's default padding is used.
55
+ Maximum pad duration in seconds. If set, the pad_length will be capped
56
+ at this value times the sampling rate. Use this to limit latency for
57
+ filters with very long impulse responses. Default is None (no limit).
35
58
  """
36
59
 
37
60
 
38
- class ButterworthZeroPhaseTransformer(FilterByDesignTransformer[ButterworthZeroPhaseSettings, BACoeffs | SOSCoeffs]):
39
- """Zero-phase (filtfilt) Butterworth using your design function."""
61
+ class ButterworthBackwardFilterTransformer(FilterByDesignTransformer[ButterworthFilterSettings, BACoeffs | SOSCoeffs]):
62
+ """
63
+ Backward (acausal) Butterworth filter with buffering.
64
+
65
+ This transformer buffers its input and applies the filter in reverse,
66
+ outputting only the "settled" portion where transients have decayed.
67
+ This introduces a lag of ``pad_length`` samples.
68
+
69
+ Intended to be used as stage 2 in a zero-phase filter pipeline, receiving
70
+ forward-filtered data from a ButterworthFilterTransformer.
71
+ """
72
+
73
+ # Instance attributes (initialized in _reset_state)
74
+ _buffer: HybridAxisArrayBuffer | None
75
+ _coefs_cache: BACoeffs | SOSCoeffs | None
76
+ _zi_tiled: np.ndarray | None
77
+ _pad_length: int
40
78
 
41
79
  def get_design_function(
42
80
  self,
@@ -50,74 +88,218 @@ class ButterworthZeroPhaseTransformer(FilterByDesignTransformer[ButterworthZeroP
50
88
  wn_hz=self.settings.wn_hz,
51
89
  )
52
90
 
53
- def update_settings(self, new_settings: typing.Optional[SettingsType] = None, **kwargs) -> None:
91
+ def _compute_pad_length(self, fs: float) -> int:
54
92
  """
55
- Update settings and mark that filter coefficients need to be recalculated.
93
+ Compute pad length based on the filter's impulse response settling time.
94
+
95
+ The pad_length is determined by finding where the impulse response decays
96
+ to `settle_cutoff` fraction of its peak value. This is then optionally
97
+ capped by `max_pad_duration`.
56
98
 
57
99
  Args:
58
- new_settings: Complete new settings object to replace current settings
59
- **kwargs: Individual settings to update
100
+ fs: Sampling frequency in Hz.
101
+
102
+ Returns:
103
+ Number of samples for the pad length.
60
104
  """
61
- # Update settings
62
- if new_settings is not None:
63
- self.settings = new_settings
105
+ # Design the filter to compute impulse response
106
+ coefs = self.get_design_function()(fs)
107
+ if coefs is None:
108
+ # Filter design failed or is disabled
109
+ return 0
110
+
111
+ # Generate impulse response - use a generous length initially
112
+ # Start with scipy's heuristic as the minimum pad length; the impulse length chosen below is fixed
113
+ if self.settings.coef_type == "ba":
114
+ min_length = 3 * (self.settings.order + 1)
64
115
  else:
65
- self.settings = replace(self.settings, **kwargs)
116
+ n_sections = (self.settings.order + 1) // 2
117
+ min_length = 3 * n_sections * 2
66
118
 
67
- # Set flag to trigger recalculation on next message
68
- self._coefs_cache = None
69
- self._fs_cache = None
70
- self.state.needs_redesign = True
119
+ # Use 10x the minimum as initial impulse length, or at least 10000 samples
120
+ # (10000 samples allows for ~333ms at 30kHz, covering most practical cases)
121
+ impulse_length = max(min_length * 10, 10000)
122
+
123
+ # Cap impulse length computation if max_pad_duration is set
124
+ if self.settings.max_pad_duration is not None:
125
+ max_samples = int(self.settings.max_pad_duration * fs)
126
+ impulse_length = min(impulse_length, max_samples + 1)
127
+
128
+ impulse = np.zeros(impulse_length)
129
+ impulse[0] = 1.0
130
+
131
+ if self.settings.coef_type == "ba":
132
+ b, a = coefs
133
+ h = scipy.signal.lfilter(b, a, impulse)
134
+ else:
135
+ h = scipy.signal.sosfilt(coefs, impulse)
136
+
137
+ # Find where impulse response settles to settle_cutoff of peak
138
+ abs_h = np.abs(h)
139
+ peak = abs_h.max()
140
+ if peak == 0:
141
+ return min_length
142
+
143
+ threshold = self.settings.settle_cutoff * peak
144
+ above_threshold = np.where(abs_h > threshold)[0]
145
+
146
+ if len(above_threshold) == 0:
147
+ pad_length = min_length
148
+ else:
149
+ pad_length = above_threshold[-1] + 1
150
+
151
+ # Ensure at least the scipy heuristic minimum
152
+ pad_length = max(pad_length, min_length)
153
+
154
+ # Apply max_pad_duration cap if set
155
+ if self.settings.max_pad_duration is not None:
156
+ max_samples = int(self.settings.max_pad_duration * fs)
157
+ pad_length = min(pad_length, max_samples)
158
+
159
+ return pad_length
71
160
 
72
161
  def _reset_state(self, message: AxisArray) -> None:
162
+ """Reset filter state when stream changes."""
73
163
  self._coefs_cache = None
74
- self._fs_cache = None
164
+ self._zi_tiled = None
165
+ self._buffer = None
166
+ # Compute pad_length based on the message's sampling rate
167
+ axis = message.dims[0] if self.settings.axis is None else self.settings.axis
168
+ fs = 1 / message.axes[axis].gain
169
+ self._pad_length = self._compute_pad_length(fs)
75
170
  self.state.needs_redesign = True
76
171
 
172
+ def _compute_zi_tiled(self, data: np.ndarray, ax_idx: int) -> None:
173
+ """Compute and cache the tiled zi for the given data shape.
174
+
175
+ Called once per stream (or after filter redesign). The result is
176
+ broadcast-ready for multiplication by the edge sample on each chunk.
177
+ """
178
+ if self.settings.coef_type == "ba":
179
+ b, a = self._coefs_cache
180
+ zi_base = scipy.signal.lfilter_zi(b, a)
181
+ else: # sos
182
+ zi_base = scipy.signal.sosfilt_zi(self._coefs_cache)
183
+
184
+ n_tail = data.ndim - ax_idx - 1
185
+
186
+ if self.settings.coef_type == "ba":
187
+ zi_expand = (None,) * ax_idx + (slice(None),) + (None,) * n_tail
188
+ n_tile = data.shape[:ax_idx] + (1,) + data.shape[ax_idx + 1 :]
189
+ else: # sos
190
+ zi_expand = (slice(None),) + (None,) * ax_idx + (slice(None),) + (None,) * n_tail
191
+ n_tile = (1,) + data.shape[:ax_idx] + (1,) + data.shape[ax_idx + 1 :]
192
+
193
+ self._zi_tiled = np.tile(zi_base[zi_expand], n_tile)
194
+
195
+ def _initialize_zi(self, data: np.ndarray, ax_idx: int) -> np.ndarray:
196
+ """Initialize filter state (zi) scaled by edge value."""
197
+ if self._zi_tiled is None:
198
+ self._compute_zi_tiled(data, ax_idx)
199
+ first_sample = np.take(data, [0], axis=ax_idx)
200
+ return self._zi_tiled * first_sample
201
+
77
202
  def _process(self, message: AxisArray) -> AxisArray:
78
203
  axis = message.dims[0] if self.settings.axis is None else self.settings.axis
79
204
  ax_idx = message.get_axis_idx(axis)
80
205
  fs = 1 / message.axes[axis].gain
81
206
 
82
- if (
83
- self._coefs_cache is None
84
- or self.state.needs_redesign
85
- or (self._fs_cache is None or not np.isclose(self._fs_cache, fs))
86
- ):
207
+ # Check if we need to redesign filter
208
+ if self._coefs_cache is None or self.state.needs_redesign:
87
209
  self._coefs_cache = self.get_design_function()(fs)
88
- self._fs_cache = fs
210
+ self._pad_length = self._compute_pad_length(fs)
211
+ self._zi_tiled = None # Invalidate; recomputed on next use.
89
212
  self.state.needs_redesign = False
90
213
 
214
+ # Initialize buffer with duration based on pad_length
215
+ # Add some margin to handle variable chunk sizes
216
+ buffer_duration = (self._pad_length + 1) / fs
217
+ self._buffer = HybridAxisArrayBuffer(duration=buffer_duration, axis=axis)
218
+
219
+ # Early exit if filter is effectively disabled
91
220
  if self._coefs_cache is None or self.settings.order <= 0 or message.data.size <= 0:
92
221
  return message
93
222
 
94
- x = message.data
95
- if self.settings.coef_type == "sos":
96
- y = scipy.signal.sosfiltfilt(
97
- self._coefs_cache,
98
- x,
99
- axis=ax_idx,
100
- padtype=self.settings.padtype,
101
- padlen=self.settings.padlen,
102
- )
103
- elif self.settings.coef_type == "ba":
223
+ # Write new data to buffer
224
+ self._buffer.write(message)
225
+ n_available = self._buffer.available()
226
+ n_output = n_available - self._pad_length
227
+
228
+ # If we don't have enough data yet, return empty
229
+ if n_output <= 0:
230
+ new_shape = list(message.data.shape)
231
+ new_shape[ax_idx] = 0
232
+ empty_data = np.empty(new_shape, dtype=message.data.dtype)
233
+ return replace(message, data=empty_data)
234
+
235
+ # Peek all available data from buffer
236
+ # Note: HybridAxisArrayBuffer moves the target axis to position 0
237
+ buffered = self._buffer.peek(n_available)
238
+ combined = buffered.data
239
+ buffer_ax_idx = 0 # Buffer always puts time axis at position 0
240
+
241
+ # Backward filter on reversed data
242
+ combined_rev = np.flip(combined, axis=buffer_ax_idx)
243
+ backward_zi = self._initialize_zi(combined_rev, buffer_ax_idx)
244
+
245
+ if self.settings.coef_type == "ba":
104
246
  b, a = self._coefs_cache
105
- y = scipy.signal.filtfilt(
106
- b,
107
- a,
108
- x,
109
- axis=ax_idx,
110
- padtype=self.settings.padtype,
111
- padlen=self.settings.padlen,
112
- )
113
- else:
114
- ez.logger.error("coef_type must be 'sos' or 'ba'.")
115
- raise ValueError("coef_type must be 'sos' or 'ba'.")
247
+ y_bwd_rev, _ = scipy.signal.lfilter(b, a, combined_rev, axis=buffer_ax_idx, zi=backward_zi)
248
+ else: # sos
249
+ y_bwd_rev, _ = scipy.signal.sosfilt(self._coefs_cache, combined_rev, axis=buffer_ax_idx, zi=backward_zi)
250
+
251
+ # Reverse back to get output in correct time order
252
+ y_bwd = np.flip(y_bwd_rev, axis=buffer_ax_idx)
253
+
254
+ # Output the settled portion (first n_output samples)
255
+ y = y_bwd[:n_output]
256
+
257
+ # Advance buffer read head to discard output samples, keep pad_length
258
+ self._buffer.seek(n_output)
259
+
260
+ # Build output with adjusted time axis
261
+ # LinearAxis offset is already correct from the buffer
262
+ out_axis = buffered.axes[axis]
263
+
264
+ # Move axis back to original position if needed
265
+ if ax_idx != 0:
266
+ y = np.moveaxis(y, 0, ax_idx)
267
+
268
+ return replace(
269
+ message,
270
+ data=y,
271
+ axes={**message.axes, axis: out_axis},
272
+ )
273
+
274
+
275
+ class ButterworthZeroPhaseTransformer(CompositeProcessor[ButterworthZeroPhaseSettings, AxisArray, AxisArray]):
276
+ """
277
+ Streaming zero-phase Butterworth filter as a composite of two stages.
278
+
279
+ Stage 1 (forward): Standard causal Butterworth filter with state
280
+ Stage 2 (backward): Acausal Butterworth filter with buffering
281
+
282
+ The output is delayed by ``pad_length`` samples.
283
+ """
284
+
285
+ @staticmethod
286
+ def _initialize_processors(
287
+ settings: ButterworthZeroPhaseSettings,
288
+ ) -> dict[str, typing.Any]:
289
+ # Both stages use the same filter design settings
290
+ return {
291
+ "forward": ButterworthFilterTransformer(settings),
292
+ "backward": ButterworthBackwardFilterTransformer(settings),
293
+ }
116
294
 
117
- return replace(message, data=y)
295
+ @classmethod
296
+ def get_message_type(cls, dir: str) -> type[AxisArray]:
297
+ if dir in ("in", "out"):
298
+ return AxisArray
299
+ raise ValueError(f"Invalid direction: {dir}. Must be 'in' or 'out'.")
118
300
 
119
301
 
120
302
  class ButterworthZeroPhase(
121
- BaseFilterByDesignTransformerUnit[ButterworthZeroPhaseSettings, ButterworthZeroPhaseTransformer]
303
+ BaseTransformerUnit[ButterworthZeroPhaseSettings, AxisArray, AxisArray, ButterworthZeroPhaseTransformer]
122
304
  ):
123
305
  SETTINGS = ButterworthZeroPhaseSettings
@@ -0,0 +1,43 @@
1
+ """
2
+ Element-wise power of the data.
3
+
4
+ .. note::
5
+ This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
6
+ enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
7
+ """
8
+
9
+ import ezmsg.core as ez
10
+ from array_api_compat import get_namespace
11
+ from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
12
+ from ezmsg.util.messages.axisarray import AxisArray
13
+ from ezmsg.util.messages.util import replace
14
+
15
+
16
+ class PowSettings(ez.Settings):
17
+ exponent: float = 2.0
18
+ """The exponent to raise the data to. Default is 2.0 (squaring)."""
19
+
20
+
21
+ class PowTransformer(BaseTransformer[PowSettings, AxisArray, AxisArray]):
22
+ def _process(self, message: AxisArray) -> AxisArray:
23
+ xp = get_namespace(message.data)
24
+ return replace(message, data=xp.pow(message.data, self.settings.exponent))
25
+
26
+
27
+ class Pow(BaseTransformerUnit[PowSettings, AxisArray, AxisArray, PowTransformer]):
28
+ SETTINGS = PowSettings
29
+
30
+
31
+ def pow(
32
+ exponent: float = 2.0,
33
+ ) -> PowTransformer:
34
+ """
35
+ Raise the data to an element-wise power. See :obj:`xp.pow` for more details.
36
+
37
+ Args:
38
+ exponent: The exponent to raise the data to. Default is 2.0.
39
+
40
+ Returns: :obj:`PowTransformer`.
41
+
42
+ """
43
+ return PowTransformer(PowSettings(exponent=exponent))
ezmsg/sigproc/scaler.py CHANGED
@@ -7,7 +7,6 @@ from ezmsg.baseproc import (
7
7
  BaseTransformerUnit,
8
8
  processor_state,
9
9
  )
10
- from ezmsg.util.generator import consumer
11
10
  from ezmsg.util.messages.axisarray import AxisArray
12
11
  from ezmsg.util.messages.util import replace
13
12
 
@@ -18,50 +17,69 @@ from .ewma import _tau_from_alpha as _tau_from_alpha
18
17
  from .ewma import ewma_step as ewma_step
19
18
 
20
19
 
21
- @consumer
22
- def scaler(time_constant: float = 1.0, axis: str | None = None) -> typing.Generator[AxisArray, AxisArray, None]:
23
- """
24
- Apply the adaptive standard scaler from https://riverml.xyz/latest/api/preprocessing/AdaptiveStandardScaler/
25
- This is faster than :obj:`scaler_np` for single-channel data.
20
+ class RiverAdaptiveStandardScalerSettings(ez.Settings):
21
+ time_constant: float = 1.0
22
+ """Decay constant ``tau`` in seconds."""
23
+
24
+ axis: str | None = None
25
+ """The name of the axis to accumulate statistics over."""
26
+
27
+
28
+ @processor_state
29
+ class RiverAdaptiveStandardScalerState:
30
+ scaler: typing.Any = None
31
+ axis: str | None = None
32
+ axis_idx: int = 0
26
33
 
27
- Args:
28
- time_constant: Decay constant `tau` in seconds.
29
- axis: The name of the axis to accumulate statistics over.
30
34
 
31
- Returns:
32
- A primed generator object that expects to be sent a :obj:`AxisArray` via `.send(axis_array)`
33
- and yields an :obj:`AxisArray` with its data being a standardized, or "Z-scored" version of the input data.
35
+ class RiverAdaptiveStandardScalerTransformer(
36
+ BaseStatefulTransformer[
37
+ RiverAdaptiveStandardScalerSettings,
38
+ AxisArray,
39
+ AxisArray,
40
+ RiverAdaptiveStandardScalerState,
41
+ ]
42
+ ):
43
+ """
44
+ Apply the adaptive standard scaler from
45
+ `river <https://riverml.xyz/latest/api/preprocessing/AdaptiveStandardScaler/>`_.
46
+
47
+ This processes data sample-by-sample using River's online learning
48
+ implementation. For a vectorized EWMA-based alternative, see
49
+ :class:`AdaptiveStandardScalerTransformer`.
34
50
  """
35
- from river import preprocessing
36
51
 
37
- msg_out = AxisArray(np.array([]), dims=[""])
38
- _scaler = None
39
- while True:
40
- msg_in: AxisArray = yield msg_out
41
- data = msg_in.data
52
+ def _reset_state(self, message: AxisArray) -> None:
53
+ from river import preprocessing
54
+
55
+ axis = self.settings.axis
42
56
  if axis is None:
43
- axis = msg_in.dims[0]
44
- axis_idx = 0
57
+ axis = message.dims[0]
58
+ self._state.axis_idx = 0
45
59
  else:
46
- axis_idx = msg_in.get_axis_idx(axis)
47
- if axis_idx != 0:
48
- data = np.moveaxis(data, axis_idx, 0)
60
+ self._state.axis_idx = message.get_axis_idx(axis)
61
+ self._state.axis = axis
62
+
63
+ alpha = _alpha_from_tau(self.settings.time_constant, message.axes[axis].gain)
64
+ self._state.scaler = preprocessing.AdaptiveStandardScaler(fading_factor=alpha)
49
65
 
50
- if _scaler is None:
51
- alpha = _alpha_from_tau(time_constant, msg_in.axes[axis].gain)
52
- _scaler = preprocessing.AdaptiveStandardScaler(fading_factor=alpha)
66
+ def _process(self, message: AxisArray) -> AxisArray:
67
+ data = message.data
68
+ axis_idx = self._state.axis_idx
69
+ if axis_idx != 0:
70
+ data = np.moveaxis(data, axis_idx, 0)
53
71
 
54
72
  result = []
55
73
  for sample in data:
56
74
  x = {k: v for k, v in enumerate(sample.flatten().tolist())}
57
- _scaler.learn_one(x)
58
- y = _scaler.transform_one(x)
75
+ self._state.scaler.learn_one(x)
76
+ y = self._state.scaler.transform_one(x)
59
77
  k = sorted(y.keys())
60
78
  result.append(np.array([y[_] for _ in k]).reshape(sample.shape))
61
79
 
62
80
  result = np.stack(result)
63
81
  result = np.moveaxis(result, 0, axis_idx)
64
- msg_out = replace(msg_in, data=result)
82
+ return replace(message, data=result)
65
83
 
66
84
 
67
85
  class AdaptiveStandardScalerSettings(EWMASettings): ...
@@ -158,7 +176,14 @@ class AdaptiveStandardScaler(
158
176
  self.processor.settings = msg
159
177
 
160
178
 
161
- # Backwards compatibility...
179
+ # Convenience functions to support deprecated generator API
180
+ def scaler(time_constant: float = 1.0, axis: str | None = None) -> RiverAdaptiveStandardScalerTransformer:
181
+ """Create a :class:`RiverAdaptiveStandardScalerTransformer` with the given parameters."""
182
+ return RiverAdaptiveStandardScalerTransformer(
183
+ settings=RiverAdaptiveStandardScalerSettings(time_constant=time_constant, axis=axis)
184
+ )
185
+
186
+
162
187
  def scaler_np(time_constant: float = 1.0, axis: str | None = None) -> AdaptiveStandardScalerTransformer:
163
188
  return AdaptiveStandardScalerTransformer(
164
189
  settings=AdaptiveStandardScalerSettings(time_constant=time_constant, axis=axis)
@@ -0,0 +1,116 @@
1
+ """
2
+ Time-domain single-band power estimation.
3
+
4
+ Two methods are provided:
5
+
6
+ 1. **RMS Band Power** — Bandpass filter, square, window into bins, take the mean, optionally take the square root.
7
+ 2. **Square-Law + LPF Band Power** — Bandpass filter, square, lowpass filter (smoothing), downsample.
8
+ """
9
+
10
+ from dataclasses import field
11
+
12
+ import ezmsg.core as ez
13
+ from ezmsg.baseproc import (
14
+ BaseProcessor,
15
+ BaseStatefulProcessor,
16
+ BaseTransformerUnit,
17
+ CompositeProcessor,
18
+ )
19
+ from ezmsg.util.messages.axisarray import AxisArray
20
+ from ezmsg.util.messages.modify import modify_axis
21
+
22
+ from .aggregate import AggregateSettings, AggregateTransformer, AggregationFunction
23
+ from .butterworthfilter import ButterworthFilterSettings, ButterworthFilterTransformer
24
+ from .downsample import DownsampleSettings, DownsampleTransformer
25
+ from .math.pow import PowSettings, PowTransformer
26
+ from .window import WindowTransformer
27
+
28
+
29
+ class RMSBandPowerSettings(ez.Settings):
30
+ """Settings for :obj:`RMSBandPowerTransformer`."""
31
+
32
+ bandpass: ButterworthFilterSettings = field(
33
+ default_factory=lambda: ButterworthFilterSettings(order=4, coef_type="sos")
34
+ )
35
+ """Butterworth bandpass filter settings. Set ``cuton`` and ``cutoff`` to define the band."""
36
+
37
+ bin_duration: float = 0.05
38
+ """Duration of each non-overlapping bin in seconds."""
39
+
40
+ apply_sqrt: bool = True
41
+ """If True, output is RMS (root-mean-square). If False, output is mean-square power."""
42
+
43
+
44
+ class RMSBandPowerTransformer(CompositeProcessor[RMSBandPowerSettings, AxisArray, AxisArray]):
45
+ """
46
+ RMS band power estimation.
47
+
48
+ Pipeline: bandpass -> square -> window(bins) -> mean(time) -> rename bin->time -> [sqrt]
49
+ """
50
+
51
+ @staticmethod
52
+ def _initialize_processors(
53
+ settings: RMSBandPowerSettings,
54
+ ) -> dict[str, BaseProcessor | BaseStatefulProcessor]:
55
+ procs: dict[str, BaseProcessor | BaseStatefulProcessor] = {
56
+ "bandpass": ButterworthFilterTransformer(settings.bandpass),
57
+ "square": PowTransformer(PowSettings(exponent=2.0)),
58
+ "window": WindowTransformer(
59
+ axis="time",
60
+ newaxis="bin",
61
+ window_dur=settings.bin_duration,
62
+ window_shift=settings.bin_duration,
63
+ zero_pad_until="none",
64
+ ),
65
+ "aggregate": AggregateTransformer(AggregateSettings(axis="time", operation=AggregationFunction.MEAN)),
66
+ "rename": modify_axis(name_map={"bin": "time"}),
67
+ }
68
+ if settings.apply_sqrt:
69
+ procs["sqrt"] = PowTransformer(PowSettings(exponent=0.5))
70
+ return procs
71
+
72
+
73
+ class RMSBandPower(BaseTransformerUnit[RMSBandPowerSettings, AxisArray, AxisArray, RMSBandPowerTransformer]):
74
+ SETTINGS = RMSBandPowerSettings
75
+
76
+
77
+ class SquareLawBandPowerSettings(ez.Settings):
78
+ """Settings for :obj:`SquareLawBandPowerTransformer`."""
79
+
80
+ bandpass: ButterworthFilterSettings = field(
81
+ default_factory=lambda: ButterworthFilterSettings(order=4, coef_type="sos")
82
+ )
83
+ """Butterworth bandpass filter settings. Set ``cuton`` and ``cutoff`` to define the band."""
84
+
85
+ lowpass: ButterworthFilterSettings = field(
86
+ default_factory=lambda: ButterworthFilterSettings(order=4, coef_type="sos")
87
+ )
88
+ """Butterworth lowpass filter settings for smoothing the squared signal."""
89
+
90
+ downsample: DownsampleSettings = field(default_factory=DownsampleSettings)
91
+ """Downsample settings for rate reduction after lowpass smoothing."""
92
+
93
+
94
+ class SquareLawBandPowerTransformer(CompositeProcessor[SquareLawBandPowerSettings, AxisArray, AxisArray]):
95
+ """
96
+ Square-law + LPF band power estimation.
97
+
98
+ Pipeline: bandpass -> square -> lowpass -> downsample
99
+ """
100
+
101
+ @staticmethod
102
+ def _initialize_processors(
103
+ settings: SquareLawBandPowerSettings,
104
+ ) -> dict[str, BaseProcessor | BaseStatefulProcessor]:
105
+ return {
106
+ "bandpass": ButterworthFilterTransformer(settings.bandpass),
107
+ "square": PowTransformer(PowSettings(exponent=2.0)),
108
+ "lowpass": ButterworthFilterTransformer(settings.lowpass),
109
+ "downsample": DownsampleTransformer(settings.downsample),
110
+ }
111
+
112
+
113
+ class SquareLawBandPower(
114
+ BaseTransformerUnit[SquareLawBandPowerSettings, AxisArray, AxisArray, SquareLawBandPowerTransformer]
115
+ ):
116
+ SETTINGS = SquareLawBandPowerSettings
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ezmsg-sigproc
3
- Version: 2.9.0
3
+ Version: 2.11.0
4
4
  Summary: Timeseries signal processing implementations in ezmsg
5
5
  Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>, Kyle McGraw <kmcgraw@blackrockneuro.com>
6
6
  License-Expression: MIT
@@ -1,5 +1,5 @@
1
1
  ezmsg/sigproc/__init__.py,sha256=8K4IcOA3-pfzadoM6s2Sfg5460KlJUocGgyTJTJl96U,52
2
- ezmsg/sigproc/__version__.py,sha256=24D27WSWcO2HaMUGCrVCG6Zbd76BrLDh2qQ1KbGX6Hc,704
2
+ ezmsg/sigproc/__version__.py,sha256=eqKbWb9LnxuZWE9-pafopBz45ugg0beSlKLIOIjeSzc,706
3
3
  ezmsg/sigproc/activation.py,sha256=83vnTa3ZcC4Q3VSWcGfaqhCEqYRNySUOyVpMHZXfz-c,2755
4
4
  ezmsg/sigproc/adaptive_lattice_notch.py,sha256=ThUR48mbSHuThkimtD0j4IXNMrOVcpZgGhE7PCYfXhU,8818
5
5
  ezmsg/sigproc/affinetransform.py,sha256=jl7DiSa5Yb0qsmFJbfSiSeGmvK1SGoBgycFC5JU5DVY,9434
@@ -7,7 +7,7 @@ ezmsg/sigproc/aggregate.py,sha256=7Hdz1m-S6Cl9h0oRQHeS_UTGBemhOB4XdFyX6cGcdHo,93
7
7
  ezmsg/sigproc/bandpower.py,sha256=dAhH56sUrXNhcRFymTTwjdM_KcU5OxFzrR_sxIPAxyw,2264
8
8
  ezmsg/sigproc/base.py,sha256=SJvKEb8gw6mUMwlV5sH0iPG0bXrgS8tvkPwhI-j89MQ,3672
9
9
  ezmsg/sigproc/butterworthfilter.py,sha256=NKTGkgjvlmC1Dc9gD2Z6UBzUq12KicfnczrzM5ZTosk,5255
10
- ezmsg/sigproc/butterworthzerophase.py,sha256=Df3F1QBBE39FBjNi67wvTsb1bSdTRTSTZXbZiKFlxC4,4105
10
+ ezmsg/sigproc/butterworthzerophase.py,sha256=CU6cXkI6j1LQCEz0sr2IthAPCq_TEtbvSb7h2Nw1w74,11820
11
11
  ezmsg/sigproc/cheby.py,sha256=B8pGt5_pOBpNZCmaibNl_NKkyuasd8ZEJXeTDCTaino,3711
12
12
  ezmsg/sigproc/combfilter.py,sha256=MSxr1I-jBePW_9AuCiv3RQ1HUNxIsNhLk0q1Iu8ikAw,4766
13
13
  ezmsg/sigproc/coordinatespaces.py,sha256=bp_0fTS9b27OQqLoFzgE3f9rb287P8y0S1dWWGrS08o,5298
@@ -34,8 +34,9 @@ ezmsg/sigproc/quantize.py,sha256=uSM2z2xXwL0dgSltyzLEmlKjaJZ2meA3PDWX8_bM0Hs,219
34
34
  ezmsg/sigproc/resample.py,sha256=3mm9pvxryNVhQuTCIMW3ToUkUfbVOCsIgvXUiurit1Y,11389
35
35
  ezmsg/sigproc/rollingscaler.py,sha256=e-smSKDhmDD2nWIf6I77CtRxQp_7sHS268SGPi7aXp8,8499
36
36
  ezmsg/sigproc/sampler.py,sha256=iOk2YoUX22u9iTjFKimzP5V074RDBVcmswgfyxvZRZo,10761
37
- ezmsg/sigproc/scaler.py,sha256=oBZa6uzyftChvk6aqBD5clil6pedx3IF-dptrb74EA0,5888
37
+ ezmsg/sigproc/scaler.py,sha256=nCgShZufPId_b-Sbsc8Si31lbtOb3nPImNcnksd774w,6578
38
38
  ezmsg/sigproc/signalinjector.py,sha256=mB62H2b-ScgPtH1jajEpxgDHqdb-RKekQfgyNncsE8Y,2874
39
+ ezmsg/sigproc/singlebandpow.py,sha256=BVlWhFI6zU3ME3EVdZbwf-FMz1d2sfuNFDKXs1hn5HM,4353
39
40
  ezmsg/sigproc/slicer.py,sha256=xLXxWf722V08ytVwvPimYjDKKj0pkC2HjdgCVaoaOvs,5195
40
41
  ezmsg/sigproc/spectral.py,sha256=wFzuihS7qJZMQcp0ds_qCG-zCbvh5DyhFRjn2wA9TWQ,322
41
42
  ezmsg/sigproc/spectrogram.py,sha256=g8xYWENzle6O5uEF-vfjsF5gOSDnJTwiu3ZudicO470,2893
@@ -50,6 +51,7 @@ ezmsg/sigproc/math/clip.py,sha256=1D6mUlOzBB7L35G_KKYZmfg7nYlbuDdITV4EH0R-yUo,15
50
51
  ezmsg/sigproc/math/difference.py,sha256=uUYZgbLe-GrFSN6EOFjs9fQZllp827IluxL6m8TJuH8,4791
51
52
  ezmsg/sigproc/math/invert.py,sha256=nz8jbfvDoez6s9NmAprBtTAI5oSDj0wNUPk8j13XiVk,855
52
53
  ezmsg/sigproc/math/log.py,sha256=JhjSqLnQnvx_3F4txRYHuUPSJ12Yj2HvRTsCMNvlxpo,2022
54
+ ezmsg/sigproc/math/pow.py,sha256=0sdlXFUEBXmpEV_i75oshGRjMguv8L13nLt7hlvdX3E,1284
53
55
  ezmsg/sigproc/math/scale.py,sha256=4_xHcHNuf13E1fxIF5vbkPfkN4En6zkfPIKID7lCERk,1133
54
56
  ezmsg/sigproc/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
55
57
  ezmsg/sigproc/util/asio.py,sha256=aAj0e7OoBvkRy28k05HL2s9YPCTxOddc05xMN-qd4lQ,577
@@ -59,7 +61,7 @@ ezmsg/sigproc/util/message.py,sha256=ppN3IYtIAwrxWG9JOvgWFn1wDdIumkEzYFfqpH9VQkY
59
61
  ezmsg/sigproc/util/profile.py,sha256=eVOo9pXgusrnH1yfRdd2RsM7Dbe2UpyC0LJ9MfGpB08,416
60
62
  ezmsg/sigproc/util/sparse.py,sha256=NjbJitCtO0B6CENTlyd9c-lHEJwoCan-T3DIgPyeShw,4834
61
63
  ezmsg/sigproc/util/typeresolution.py,sha256=fMFzLi63dqCIclGFLcMdM870OYxJnkeWw6aWKNMk718,362
62
- ezmsg_sigproc-2.9.0.dist-info/METADATA,sha256=H7aJdYvhCekfdihzdsPbvT1a942faGRIlcoeDd_23HI,1908
63
- ezmsg_sigproc-2.9.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
64
- ezmsg_sigproc-2.9.0.dist-info/licenses/LICENSE,sha256=seu0tKhhAMPCUgc1XpXGGaCxY1YaYvFJwqFuQZAl2go,1100
65
- ezmsg_sigproc-2.9.0.dist-info/RECORD,,
64
+ ezmsg_sigproc-2.11.0.dist-info/METADATA,sha256=8XB8fu3sNqsrwV-ff8xtlWUKsFdERMSqqkotMhfNtu0,1909
65
+ ezmsg_sigproc-2.11.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
66
+ ezmsg_sigproc-2.11.0.dist-info/licenses/LICENSE,sha256=seu0tKhhAMPCUgc1XpXGGaCxY1YaYvFJwqFuQZAl2go,1100
67
+ ezmsg_sigproc-2.11.0.dist-info/RECORD,,