ezmsg-sigproc 1.7.0__py3-none-any.whl → 2.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. ezmsg/sigproc/__version__.py +22 -4
  2. ezmsg/sigproc/activation.py +31 -40
  3. ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
  4. ezmsg/sigproc/affinetransform.py +171 -169
  5. ezmsg/sigproc/aggregate.py +190 -97
  6. ezmsg/sigproc/bandpower.py +60 -55
  7. ezmsg/sigproc/base.py +143 -33
  8. ezmsg/sigproc/butterworthfilter.py +34 -38
  9. ezmsg/sigproc/butterworthzerophase.py +305 -0
  10. ezmsg/sigproc/cheby.py +23 -17
  11. ezmsg/sigproc/combfilter.py +160 -0
  12. ezmsg/sigproc/coordinatespaces.py +159 -0
  13. ezmsg/sigproc/decimate.py +15 -10
  14. ezmsg/sigproc/denormalize.py +78 -0
  15. ezmsg/sigproc/detrend.py +28 -0
  16. ezmsg/sigproc/diff.py +82 -0
  17. ezmsg/sigproc/downsample.py +72 -81
  18. ezmsg/sigproc/ewma.py +217 -0
  19. ezmsg/sigproc/ewmfilter.py +1 -1
  20. ezmsg/sigproc/extract_axis.py +39 -0
  21. ezmsg/sigproc/fbcca.py +307 -0
  22. ezmsg/sigproc/filter.py +254 -148
  23. ezmsg/sigproc/filterbank.py +226 -214
  24. ezmsg/sigproc/filterbankdesign.py +129 -0
  25. ezmsg/sigproc/fir_hilbert.py +336 -0
  26. ezmsg/sigproc/fir_pmc.py +209 -0
  27. ezmsg/sigproc/firfilter.py +117 -0
  28. ezmsg/sigproc/gaussiansmoothing.py +89 -0
  29. ezmsg/sigproc/kaiser.py +106 -0
  30. ezmsg/sigproc/linear.py +120 -0
  31. ezmsg/sigproc/math/abs.py +23 -22
  32. ezmsg/sigproc/math/add.py +120 -0
  33. ezmsg/sigproc/math/clip.py +33 -25
  34. ezmsg/sigproc/math/difference.py +117 -43
  35. ezmsg/sigproc/math/invert.py +18 -25
  36. ezmsg/sigproc/math/log.py +38 -33
  37. ezmsg/sigproc/math/scale.py +24 -25
  38. ezmsg/sigproc/messages.py +1 -2
  39. ezmsg/sigproc/quantize.py +68 -0
  40. ezmsg/sigproc/resample.py +278 -0
  41. ezmsg/sigproc/rollingscaler.py +232 -0
  42. ezmsg/sigproc/sampler.py +209 -254
  43. ezmsg/sigproc/scaler.py +93 -218
  44. ezmsg/sigproc/signalinjector.py +44 -43
  45. ezmsg/sigproc/slicer.py +74 -102
  46. ezmsg/sigproc/spectral.py +3 -3
  47. ezmsg/sigproc/spectrogram.py +70 -70
  48. ezmsg/sigproc/spectrum.py +187 -173
  49. ezmsg/sigproc/transpose.py +134 -0
  50. ezmsg/sigproc/util/__init__.py +0 -0
  51. ezmsg/sigproc/util/asio.py +25 -0
  52. ezmsg/sigproc/util/axisarray_buffer.py +365 -0
  53. ezmsg/sigproc/util/buffer.py +449 -0
  54. ezmsg/sigproc/util/message.py +17 -0
  55. ezmsg/sigproc/util/profile.py +23 -0
  56. ezmsg/sigproc/util/sparse.py +115 -0
  57. ezmsg/sigproc/util/typeresolution.py +17 -0
  58. ezmsg/sigproc/wavelets.py +147 -154
  59. ezmsg/sigproc/window.py +248 -210
  60. ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
  61. ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
  62. {ezmsg_sigproc-1.7.0.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -1
  63. ezmsg/sigproc/synth.py +0 -621
  64. ezmsg_sigproc-1.7.0.dist-info/METADATA +0 -58
  65. ezmsg_sigproc-1.7.0.dist-info/RECORD +0 -36
  66. {ezmsg_sigproc-1.7.0.dist-info → ezmsg_sigproc-2.10.0.dist-info}/licenses/{LICENSE.txt → LICENSE} +0 -0
ezmsg/sigproc/scaler.py CHANGED
@@ -1,169 +1,25 @@
1
- import functools
2
1
  import typing
3
2
 
4
- import numpy as np
5
- import numpy.typing as npt
6
- import scipy.signal
7
3
  import ezmsg.core as ez
4
+ import numpy as np
5
+ from ezmsg.baseproc import (
6
+ BaseStatefulTransformer,
7
+ BaseTransformerUnit,
8
+ processor_state,
9
+ )
10
+ from ezmsg.util.generator import consumer
8
11
  from ezmsg.util.messages.axisarray import AxisArray
9
12
  from ezmsg.util.messages.util import replace
10
- from ezmsg.util.generator import consumer
11
-
12
- from .base import GenAxisArray
13
-
14
-
15
- def _tau_from_alpha(alpha: float, dt: float) -> float:
16
- """
17
- Inverse of _alpha_from_tau. See that function for explanation.
18
- """
19
- return -dt / np.log(1 - alpha)
20
-
21
-
22
- def _alpha_from_tau(tau: float, dt: float) -> float:
23
- """
24
- # https://en.wikipedia.org/wiki/Exponential_smoothing#Time_constant
25
- :param tau: The amount of time for the smoothed response of a unit step function to reach
26
- 1 - 1/e approx-eq 63.2%.
27
- :param dt: sampling period, or 1 / sampling_rate.
28
- :return: alpha, the "fading factor" in exponential smoothing.
29
- """
30
- return 1 - np.exp(-dt / tau)
31
-
32
-
33
- def ewma_step(
34
- sample: npt.NDArray, zi: npt.NDArray, alpha: float, beta: float | None = None
35
- ):
36
- """
37
- Do an exponentially weighted moving average step.
38
-
39
- Args:
40
- sample: The new sample.
41
- zi: The output of the previous step.
42
- alpha: Fading factor.
43
- beta: Persisting factor. If None, it is calculated as 1-alpha.
44
-
45
- Returns:
46
- alpha * sample + beta * zi
47
-
48
- """
49
- # Potential micro-optimization:
50
- # Current: scalar-arr multiplication, scalar-arr multiplication, arr-arr addition
51
- # Alternative: arr-arr subtraction, arr-arr multiplication, arr-arr addition
52
- # return zi + alpha * (new_sample - zi)
53
- beta = beta or (1 - alpha)
54
- return alpha * sample + beta * zi
55
-
56
-
57
- class EWMA:
58
- def __init__(self, alpha: float):
59
- self.beta = 1 - alpha
60
- self._filt_func = functools.partial(
61
- scipy.signal.lfilter, [alpha], [1.0, alpha - 1.0], axis=0
62
- )
63
- self.prev = None
64
-
65
- def compute(self, arr: npt.NDArray) -> npt.NDArray:
66
- if self.prev is None:
67
- self.prev = self.beta * arr[:1]
68
- expected, self.prev = self._filt_func(arr, zi=self.prev)
69
- return expected
70
-
71
-
72
- class EWMA_Deprecated:
73
- """
74
- Grabbed these methods from https://stackoverflow.com/a/70998068 and other answers in that topic,
75
- but they ended up being slower than the scipy.signal.lfilter method.
76
- Additionally, `compute` and `compute2` suffer from potential errors as the vector length increases
77
- and beta**n approaches zero.
78
- """
79
-
80
- def __init__(self, alpha: float, max_len: int):
81
- self.alpha = alpha
82
- self.beta = 1 - alpha
83
- self.prev: npt.NDArray | None = None
84
- self.weights = np.empty((max_len + 1,), float)
85
- self._precalc_weights(max_len)
86
- self._step_func = functools.partial(ewma_step, alpha=self.alpha, beta=self.beta)
87
-
88
- def _precalc_weights(self, n: int):
89
- # (1-α)^0, (1-α)^1, (1-α)^2, ..., (1-α)^n
90
- np.power(self.beta, np.arange(n + 1), out=self.weights)
91
-
92
- def compute(self, arr: npt.NDArray, out: npt.NDArray | None = None) -> npt.NDArray:
93
- if out is None:
94
- out = np.empty(arr.shape, arr.dtype)
95
-
96
- n = arr.shape[0]
97
- weights = self.weights[:n]
98
- weights = np.expand_dims(weights, list(range(1, arr.ndim)))
99
-
100
- # α*P0, α*P1, α*P2, ..., α*Pn
101
- np.multiply(self.alpha, arr, out)
102
-
103
- # α*P0/(1-α)^0, α*P1/(1-α)^1, α*P2/(1-α)^2, ..., α*Pn/(1-α)^n
104
- np.divide(out, weights, out)
105
-
106
- # α*P0/(1-α)^0, α*P0/(1-α)^0 + α*P1/(1-α)^1, ...
107
- np.cumsum(out, axis=0, out=out)
108
-
109
- # (α*P0/(1-α)^0)*(1-α)^0, (α*P0/(1-α)^0 + α*P1/(1-α)^1)*(1-α)^1, ...
110
- np.multiply(out, weights, out)
111
-
112
- # Add the previous output
113
- if self.prev is None:
114
- self.prev = arr[:1]
115
-
116
- out += self.prev * np.expand_dims(
117
- self.weights[1 : n + 1], list(range(1, arr.ndim))
118
- )
119
-
120
- self.prev = out[-1:]
121
13
 
122
- return out
123
-
124
- def compute2(self, arr: npt.NDArray) -> npt.NDArray:
125
- """
126
- Compute the Exponentially Weighted Moving Average (EWMA) of the input array.
127
-
128
- Args:
129
- arr: The input array to be smoothed.
130
-
131
- Returns:
132
- The smoothed array.
133
- """
134
- n = arr.shape[0]
135
- if n > len(self.weights):
136
- self._precalc_weights(n)
137
- weights = self.weights[:n][::-1]
138
- weights = np.expand_dims(weights, list(range(1, arr.ndim)))
139
-
140
- result = np.cumsum(self.alpha * weights * arr, axis=0)
141
- result = result / weights
142
-
143
- # Handle the first call when prev is unset
144
- if self.prev is None:
145
- self.prev = arr[:1]
146
-
147
- result += self.prev * np.expand_dims(
148
- self.weights[1 : n + 1], list(range(1, arr.ndim))
149
- )
150
-
151
- # Store the result back into prev
152
- self.prev = result[-1]
153
-
154
- return result
155
-
156
- def compute_sample(self, new_sample: npt.NDArray) -> npt.NDArray:
157
- if self.prev is None:
158
- self.prev = new_sample
159
- self.prev = self._step_func(new_sample, self.prev)
160
- return self.prev
14
+ # Imports for backwards compatibility with previous module location
15
+ from .ewma import EWMA_Deprecated as EWMA_Deprecated
16
+ from .ewma import EWMASettings, EWMATransformer, _alpha_from_tau
17
+ from .ewma import _tau_from_alpha as _tau_from_alpha
18
+ from .ewma import ewma_step as ewma_step
161
19
 
162
20
 
163
21
  @consumer
164
- def scaler(
165
- time_constant: float = 1.0, axis: str | None = None
166
- ) -> typing.Generator[AxisArray, AxisArray, None]:
22
+ def scaler(time_constant: float = 1.0, axis: str | None = None) -> typing.Generator[AxisArray, AxisArray, None]:
167
23
  """
168
24
  Apply the adaptive standard scaler from https://riverml.xyz/latest/api/preprocessing/AdaptiveStandardScaler/
169
25
  This is faster than :obj:`scaler_np` for single-channel data.
@@ -208,83 +64,102 @@ def scaler(
208
64
  msg_out = replace(msg_in, data=result)
209
65
 
210
66
 
211
- @consumer
212
- def scaler_np(
213
- time_constant: float = 1.0, axis: str | None = None
214
- ) -> typing.Generator[AxisArray, AxisArray, None]:
215
- """
216
- Create a generator function that applies an adaptive standard scaler.
217
- This is faster than :obj:`scaler` for multichannel data.
67
+ class AdaptiveStandardScalerSettings(EWMASettings): ...
218
68
 
219
- Args:
220
- time_constant: Decay constant `tau` in seconds.
221
- axis: The name of the axis to accumulate statistics over.
222
- Note: The axis must be in the msg.axes and be of type AxisArray.LinearAxis.
223
69
 
224
- Returns:
225
- A primed generator object that expects to be sent a :obj:`AxisArray` via `.send(axis_array)`
226
- and yields an :obj:`AxisArray` with its data being a standardized, or "Z-scored" version of the input data.
227
- """
228
- msg_out = AxisArray(np.array([]), dims=[""])
70
+ @processor_state
71
+ class AdaptiveStandardScalerState:
72
+ samps_ewma: EWMATransformer | None = None
73
+ vars_sq_ewma: EWMATransformer | None = None
74
+ alpha: float | None = None
229
75
 
230
- # State variables
231
- samps_ewma: EWMA | None = None
232
- vars_sq_ewma: EWMA | None = None
233
76
 
234
- # Reset if input changes
235
- check_input = {
236
- "gain": None, # Resets alpha
237
- "shape": None,
238
- "key": None, # Key change implies buffered means/vars are invalid.
239
- }
77
+ class AdaptiveStandardScalerTransformer(
78
+ BaseStatefulTransformer[
79
+ AdaptiveStandardScalerSettings,
80
+ AxisArray,
81
+ AxisArray,
82
+ AdaptiveStandardScalerState,
83
+ ]
84
+ ):
85
+ def _reset_state(self, message: AxisArray) -> None:
86
+ self._state.samps_ewma = EWMATransformer(
87
+ time_constant=self.settings.time_constant,
88
+ axis=self.settings.axis,
89
+ accumulate=self.settings.accumulate,
90
+ )
91
+ self._state.vars_sq_ewma = EWMATransformer(
92
+ time_constant=self.settings.time_constant,
93
+ axis=self.settings.axis,
94
+ accumulate=self.settings.accumulate,
95
+ )
240
96
 
241
- while True:
242
- msg_in: AxisArray = yield msg_out
97
+ @property
98
+ def accumulate(self) -> bool:
99
+ """Whether to accumulate statistics from incoming samples."""
100
+ return self.settings.accumulate
243
101
 
244
- axis = axis or msg_in.dims[0]
245
- axis_idx = msg_in.get_axis_idx(axis)
102
+ @accumulate.setter
103
+ def accumulate(self, value: bool) -> None:
104
+ """
105
+ Set the accumulate mode and propagate to child EWMA transformers.
246
106
 
247
- data: npt.NDArray = np.moveaxis(msg_in.data, axis_idx, 0)
248
- b_reset = data.shape[1:] != check_input["shape"]
249
- b_reset = b_reset or msg_in.axes[axis].gain != check_input["gain"]
250
- b_reset = b_reset or msg_in.key != check_input["key"]
251
- if b_reset:
252
- check_input["shape"] = data.shape[1:]
253
- check_input["gain"] = msg_in.axes[axis].gain
254
- check_input["key"] = msg_in.key
255
- alpha = _alpha_from_tau(time_constant, msg_in.axes[axis].gain)
256
- samps_ewma = EWMA(alpha=alpha)
257
- vars_sq_ewma = EWMA(alpha=alpha)
107
+ Args:
108
+ value: If True, update statistics with each sample.
109
+ If False, only apply current statistics without updating.
110
+ """
111
+ if self._state.samps_ewma is not None:
112
+ self._state.samps_ewma.settings = replace(self._state.samps_ewma.settings, accumulate=value)
113
+ if self._state.vars_sq_ewma is not None:
114
+ self._state.vars_sq_ewma.settings = replace(self._state.vars_sq_ewma.settings, accumulate=value)
258
115
 
259
- # Update step
260
- means = samps_ewma.compute(data)
261
- vars_sq_means = vars_sq_ewma.compute(data**2)
116
+ def _process(self, message: AxisArray) -> AxisArray:
117
+ # Update step (respects accumulate setting via child EWMAs)
118
+ mean_message = self._state.samps_ewma(message)
119
+ var_sq_message = self._state.vars_sq_ewma(replace(message, data=message.data**2))
262
120
 
263
121
  # Get step
264
- varis = vars_sq_means - means**2
122
+ varis = var_sq_message.data - mean_message.data**2
265
123
  with np.errstate(divide="ignore", invalid="ignore"):
266
- result = (data - means) / (varis**0.5)
124
+ result = (message.data - mean_message.data) / (varis**0.5)
267
125
  result[np.isnan(result)] = 0.0
268
- result = np.moveaxis(result, 0, axis_idx)
269
- msg_out = replace(msg_in, data=result)
126
+ return replace(message, data=result)
270
127
 
271
128
 
272
- class AdaptiveStandardScalerSettings(ez.Settings):
273
- """
274
- Settings for :obj:`AdaptiveStandardScaler`.
275
- See :obj:`scaler_np` for a description of the parameters.
276
- """
129
+ class AdaptiveStandardScaler(
130
+ BaseTransformerUnit[
131
+ AdaptiveStandardScalerSettings,
132
+ AxisArray,
133
+ AxisArray,
134
+ AdaptiveStandardScalerTransformer,
135
+ ]
136
+ ):
137
+ SETTINGS = AdaptiveStandardScalerSettings
277
138
 
278
- time_constant: float = 1.0
279
- axis: str | None = None
139
+ @ez.subscriber(BaseTransformerUnit.INPUT_SETTINGS)
140
+ async def on_settings(self, msg: AdaptiveStandardScalerSettings) -> None:
141
+ """
142
+ Handle settings updates with smart reset behavior.
280
143
 
144
+ Only resets state if `axis` changes (structural change).
145
+ Changes to `time_constant` or `accumulate` are applied without
146
+ resetting accumulated statistics.
147
+ """
148
+ old_axis = self.SETTINGS.axis
149
+ self.apply_settings(msg)
281
150
 
282
- class AdaptiveStandardScaler(GenAxisArray):
283
- """Unit for :obj:`scaler_np`"""
151
+ if msg.axis != old_axis:
152
+ # Axis changed - need full reset
153
+ self.create_processor()
154
+ else:
155
+ # Update accumulate on processor (propagates to child EWMAs)
156
+ self.processor.accumulate = msg.accumulate
157
+ # Also update own settings reference
158
+ self.processor.settings = msg
284
159
 
285
- SETTINGS = AdaptiveStandardScalerSettings
286
160
 
287
- def construct_generator(self):
288
- self.STATE.gen = scaler_np(
289
- time_constant=self.SETTINGS.time_constant, axis=self.SETTINGS.axis
290
- )
161
+ # Backwards compatibility...
162
+ def scaler_np(time_constant: float = 1.0, axis: str | None = None) -> AdaptiveStandardScalerTransformer:
163
+ return AdaptiveStandardScalerTransformer(
164
+ settings=AdaptiveStandardScalerSettings(time_constant=time_constant, axis=axis)
165
+ )
ezmsg/sigproc/signalinjector.py CHANGED
@@ -1,10 +1,13 @@
1
- import typing
2
-
3
1
  import ezmsg.core as ez
4
- from ezmsg.util.messages.axisarray import AxisArray
5
- from ezmsg.util.messages.util import replace
6
2
  import numpy as np
7
3
  import numpy.typing as npt
4
+ from ezmsg.baseproc import (
5
+ BaseAsyncTransformer,
6
+ BaseTransformerUnit,
7
+ processor_state,
8
+ )
9
+ from ezmsg.util.messages.axisarray import AxisArray
10
+ from ezmsg.util.messages.util import replace
8
11
 
9
12
 
10
13
  class SignalInjectorSettings(ez.Settings):
@@ -14,56 +17,54 @@ class SignalInjectorSettings(ez.Settings):
14
17
  mixing_seed: int | None = None
15
18
 
16
19
 
17
- class SignalInjectorState(ez.State):
20
+ @processor_state
21
+ class SignalInjectorState:
18
22
  cur_shape: tuple[int, ...] | None = None
19
23
  cur_frequency: float | None = None
20
- cur_amplitude: float
21
- mixing: npt.NDArray
24
+ cur_amplitude: float | None = None
25
+ mixing: npt.NDArray | None = None
22
26
 
23
27
 
24
- class SignalInjector(ez.Unit):
25
- """
26
- Add a sinusoidal signal to the input signal. Each feature gets a different amplitude of the sinusoid.
27
- All features get the same frequency sinusoid. The frequency and base amplitude can be changed while running.
28
- """
28
+ class SignalInjectorTransformer(
29
+ BaseAsyncTransformer[SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorState]
30
+ ):
31
+ def _hash_message(self, message: AxisArray) -> int:
32
+ time_ax_idx = message.get_axis_idx(self.settings.time_dim)
33
+ sample_shape = message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
34
+ return hash((message.key,) + sample_shape)
29
35
 
30
- SETTINGS = SignalInjectorSettings
31
- STATE = SignalInjectorState
36
+ def _reset_state(self, message: AxisArray) -> None:
37
+ if self._state.cur_frequency is None:
38
+ self._state.cur_frequency = self.settings.frequency
39
+ if self._state.cur_amplitude is None:
40
+ self._state.cur_amplitude = self.settings.amplitude
41
+ time_ax_idx = message.get_axis_idx(self.settings.time_dim)
42
+ self._state.cur_shape = message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
43
+ rng = np.random.default_rng(self.settings.mixing_seed)
44
+ self._state.mixing = rng.random((1, message.shape2d(self.settings.time_dim)[1]))
45
+ self._state.mixing = (self._state.mixing * 2.0) - 1.0
32
46
 
47
+ async def _aprocess(self, message: AxisArray) -> AxisArray:
48
+ if self._state.cur_frequency is None:
49
+ return message
50
+ out_msg = replace(message, data=message.data.copy())
51
+ t = out_msg.ax(self.settings.time_dim).values[..., np.newaxis]
52
+ signal = np.sin(2 * np.pi * self._state.cur_frequency * t)
53
+ mixed_signal = signal * self._state.mixing * self._state.cur_amplitude
54
+ with out_msg.view2d(self.settings.time_dim) as view:
55
+ view[...] = view + mixed_signal.astype(view.dtype)
56
+ return out_msg
57
+
58
+
59
+ class SignalInjector(BaseTransformerUnit[SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorTransformer]):
60
+ SETTINGS = SignalInjectorSettings
33
61
  INPUT_FREQUENCY = ez.InputStream(float | None)
34
62
  INPUT_AMPLITUDE = ez.InputStream(float)
35
- INPUT_SIGNAL = ez.InputStream(AxisArray)
36
- OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
37
-
38
- async def initialize(self) -> None:
39
- self.STATE.cur_frequency = self.SETTINGS.frequency
40
- self.STATE.cur_amplitude = self.SETTINGS.amplitude
41
- self.STATE.mixing = np.array([])
42
63
 
43
64
  @ez.subscriber(INPUT_FREQUENCY)
44
65
  async def on_frequency(self, msg: float | None) -> None:
45
- self.STATE.cur_frequency = msg
66
+ self.processor.state.cur_frequency = msg
46
67
 
47
68
  @ez.subscriber(INPUT_AMPLITUDE)
48
69
  async def on_amplitude(self, msg: float) -> None:
49
- self.STATE.cur_amplitude = msg
50
-
51
- @ez.subscriber(INPUT_SIGNAL)
52
- @ez.publisher(OUTPUT_SIGNAL)
53
- async def inject(self, msg: AxisArray) -> typing.AsyncGenerator:
54
- if self.STATE.cur_shape != msg.shape:
55
- self.STATE.cur_shape = msg.shape
56
- rng = np.random.default_rng(self.SETTINGS.mixing_seed)
57
- self.STATE.mixing = rng.random((1, msg.shape2d(self.SETTINGS.time_dim)[1]))
58
- self.STATE.mixing = (self.STATE.mixing * 2.0) - 1.0
59
-
60
- if self.STATE.cur_frequency is None:
61
- yield self.OUTPUT_SIGNAL, msg
62
- else:
63
- out_msg = replace(msg, data=msg.data.copy())
64
- t = out_msg.ax(self.SETTINGS.time_dim).values[..., np.newaxis]
65
- signal = np.sin(2 * np.pi * self.STATE.cur_frequency * t)
66
- mixed_signal = signal * self.STATE.mixing * self.STATE.cur_amplitude
67
- with out_msg.view2d(self.SETTINGS.time_dim) as view:
68
- view[...] = view + mixed_signal.astype(view.dtype)
69
- yield self.OUTPUT_SIGNAL, out_msg
70
+ self.processor.state.cur_amplitude = msg
ezmsg/sigproc/slicer.py CHANGED
@@ -1,18 +1,17 @@
1
- import typing
2
-
1
+ import ezmsg.core as ez
3
2
  import numpy as np
4
3
  import numpy.typing as npt
5
- import ezmsg.core as ez
4
+ from ezmsg.baseproc import (
5
+ BaseStatefulTransformer,
6
+ BaseTransformerUnit,
7
+ processor_state,
8
+ )
6
9
  from ezmsg.util.messages.axisarray import (
7
10
  AxisArray,
8
- slice_along_axis,
9
11
  AxisBase,
10
12
  replace,
13
+ slice_along_axis,
11
14
  )
12
- from ezmsg.util.generator import consumer
13
-
14
- from .base import GenAxisArray
15
-
16
15
 
17
16
  """
18
17
  Slicer:Select a subset of data along a particular axis.
@@ -49,11 +48,7 @@ def parse_slice(
49
48
  if "," not in s:
50
49
  parts = [part.strip() for part in s.split(":")]
51
50
  if len(parts) == 1:
52
- if (
53
- axinfo is not None
54
- and hasattr(axinfo, "data")
55
- and parts[0] in axinfo.data
56
- ):
51
+ if axinfo is not None and hasattr(axinfo, "data") and parts[0] in axinfo.data:
57
52
  return tuple(np.where(axinfo.data == parts[0])[0])
58
53
  return (int(parts[0]),)
59
54
  return (slice(*(int(part.strip()) if part else None for part in parts)),)
@@ -61,106 +56,83 @@ def parse_slice(
61
56
  return tuple([item for sublist in suplist for item in sublist])
62
57
 
63
58
 
64
- @consumer
65
- def slicer(
66
- selection: str = "", axis: str | None = None
67
- ) -> typing.Generator[AxisArray, AxisArray, None]:
68
- """
69
- Slice along a particular axis.
70
-
71
- Args:
72
- selection: See :obj:`ezmsg.sigproc.slicer.parse_slice` for details.
73
- axis: The name of the axis to slice along. If None, the last axis is used.
59
+ class SlicerSettings(ez.Settings):
60
+ selection: str = ""
61
+ """selection: See :obj:`ezmsg.sigproc.slicer.parse_slice` for details."""
74
62
 
75
- Returns:
76
- A primed generator object ready to yield an :obj:`AxisArray` for each .send(axis_array)
77
- with the data payload containing a sliced view of the input data.
63
+ axis: str | None = None
64
+ """The name of the axis to slice along. If None, the last axis is used."""
78
65
 
79
- """
80
- msg_out = AxisArray(np.array([]), dims=[""])
81
66
 
82
- # State variables
83
- _slice: slice | npt.NDArray | None = None
67
+ @processor_state
68
+ class SlicerState:
69
+ slice_: slice | int | npt.NDArray | None = None
84
70
  new_axis: AxisBase | None = None
85
- b_change_dims: bool = False # If number of dimensions changes when slicing
86
-
87
- # Reset if input changes
88
- check_input = {
89
- "key": None, # key change used as proxy for label change, which we don't check explicitly
90
- "len": None,
91
- }
92
-
93
- while True:
94
- msg_in: AxisArray = yield msg_out
95
-
96
- axis = axis or msg_in.dims[-1]
97
- axis_idx = msg_in.get_axis_idx(axis)
98
-
99
- b_reset = _slice is None # or new_axis is None
100
- b_reset = b_reset or msg_in.key != check_input["key"]
101
- b_reset = b_reset or (
102
- (msg_in.data.shape[axis_idx] != check_input["len"])
103
- and (type(_slice) is np.ndarray)
104
- )
105
- if b_reset:
106
- check_input["key"] = msg_in.key
107
- check_input["len"] = msg_in.data.shape[axis_idx]
108
- new_axis = None # Will hold updated metadata
109
- b_change_dims = False
110
-
111
- # Calculate the slice
112
- _slices = parse_slice(selection, msg_in.axes.get(axis, None))
113
- if len(_slices) == 1:
114
- _slice = _slices[0]
115
- # Do we drop the sliced dimension?
116
- b_change_dims = isinstance(_slice, int)
71
+ b_change_dims: bool = False
72
+
73
+
74
+ class SlicerTransformer(BaseStatefulTransformer[SlicerSettings, AxisArray, AxisArray, SlicerState]):
75
+ def _hash_message(self, message: AxisArray) -> int:
76
+ axis = self.settings.axis or message.dims[-1]
77
+ axis_idx = message.get_axis_idx(axis)
78
+ return hash((message.key, message.data.shape[axis_idx]))
79
+
80
+ def _reset_state(self, message: AxisArray) -> None:
81
+ axis = self.settings.axis or message.dims[-1]
82
+ axis_idx = message.get_axis_idx(axis)
83
+ self._state.new_axis = None
84
+ self._state.b_change_dims = False
85
+
86
+ # Calculate the slice
87
+ _slices = parse_slice(self.settings.selection, message.axes.get(axis, None))
88
+ if len(_slices) == 1:
89
+ self._state.slice_ = _slices[0]
90
+ self._state.b_change_dims = isinstance(self._state.slice_, int)
91
+ else:
92
+ indices = np.arange(message.data.shape[axis_idx])
93
+ indices = np.hstack([indices[_] for _ in _slices])
94
+ self._state.slice_ = np.s_[indices]
95
+
96
+ # Create the output axis
97
+ if axis in message.axes and hasattr(message.axes[axis], "data") and len(message.axes[axis].data) > 0:
98
+ in_data = np.array(message.axes[axis].data)
99
+ if self._state.b_change_dims:
100
+ out_data = in_data[self._state.slice_ : self._state.slice_ + 1]
117
101
  else:
118
- # Multiple slices, but this cannot be done in a single step, so we convert the slices
119
- # to a discontinuous set of integer indexes.
120
- indices = np.arange(msg_in.data.shape[axis_idx])
121
- indices = np.hstack([indices[_] for _ in _slices])
122
- _slice = np.s_[indices] # Integer scalar array
123
-
124
- # Create the output axis.
125
- if (
126
- axis in msg_in.axes
127
- and hasattr(msg_in.axes[axis], "data")
128
- and len(msg_in.axes[axis].data) > 0
129
- ):
130
- in_data = np.array(msg_in.axes[axis].data)
131
- if b_change_dims:
132
- out_data = in_data[_slice : _slice + 1]
133
- else:
134
- out_data = in_data[_slice]
135
- new_axis = replace(msg_in.axes[axis], data=out_data)
102
+ out_data = in_data[self._state.slice_]
103
+ self._state.new_axis = replace(message.axes[axis], data=out_data)
104
+
105
+ def _process(self, message: AxisArray) -> AxisArray:
106
+ axis = self.settings.axis or message.dims[-1]
107
+ axis_idx = message.get_axis_idx(axis)
136
108
 
137
109
  replace_kwargs = {}
138
- if b_change_dims:
139
- # Dropping the target axis
140
- replace_kwargs["dims"] = [
141
- _ for dim_ix, _ in enumerate(msg_in.dims) if dim_ix != axis_idx
142
- ]
143
- replace_kwargs["axes"] = {k: v for k, v in msg_in.axes.items() if k != axis}
144
- elif new_axis is not None:
145
- replace_kwargs["axes"] = {
146
- k: (v if k != axis else new_axis) for k, v in msg_in.axes.items()
147
- }
148
- msg_out = replace(
149
- msg_in,
150
- data=slice_along_axis(msg_in.data, _slice, axis_idx),
110
+ if self._state.b_change_dims:
111
+ replace_kwargs["dims"] = [_ for dim_ix, _ in enumerate(message.dims) if dim_ix != axis_idx]
112
+ replace_kwargs["axes"] = {k: v for k, v in message.axes.items() if k != axis}
113
+ elif self._state.new_axis is not None:
114
+ replace_kwargs["axes"] = {k: (v if k != axis else self._state.new_axis) for k, v in message.axes.items()}
115
+
116
+ return replace(
117
+ message,
118
+ data=slice_along_axis(message.data, self._state.slice_, axis_idx),
151
119
  **replace_kwargs,
152
120
  )
153
121
 
154
122
 
155
- class SlicerSettings(ez.Settings):
156
- selection: str = ""
157
- axis: str | None = None
123
+ class Slicer(BaseTransformerUnit[SlicerSettings, AxisArray, AxisArray, SlicerTransformer]):
124
+ SETTINGS = SlicerSettings
158
125
 
159
126
 
160
- class Slicer(GenAxisArray):
161
- SETTINGS = SlicerSettings
127
+ def slicer(selection: str = "", axis: str | None = None) -> SlicerTransformer:
128
+ """
129
+ Slice along a particular axis.
162
130
 
163
- def construct_generator(self):
164
- self.STATE.gen = slicer(
165
- selection=self.SETTINGS.selection, axis=self.SETTINGS.axis
166
- )
131
+ Args:
132
+ selection: See :obj:`ezmsg.sigproc.slicer.parse_slice` for details.
133
+ axis: The name of the axis to slice along. If None, the last axis is used.
134
+
135
+ Returns:
136
+ :obj:`SlicerTransformer`
137
+ """
138
+ return SlicerTransformer(SlicerSettings(selection=selection, axis=axis))