ezmsg-sigproc 1.8.1__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. ezmsg/sigproc/__version__.py +2 -2
  2. ezmsg/sigproc/activation.py +36 -39
  3. ezmsg/sigproc/adaptive_lattice_notch.py +231 -0
  4. ezmsg/sigproc/affinetransform.py +169 -163
  5. ezmsg/sigproc/aggregate.py +119 -104
  6. ezmsg/sigproc/bandpower.py +58 -52
  7. ezmsg/sigproc/base.py +1242 -0
  8. ezmsg/sigproc/butterworthfilter.py +37 -33
  9. ezmsg/sigproc/cheby.py +29 -17
  10. ezmsg/sigproc/combfilter.py +163 -0
  11. ezmsg/sigproc/decimate.py +19 -10
  12. ezmsg/sigproc/detrend.py +29 -0
  13. ezmsg/sigproc/diff.py +81 -0
  14. ezmsg/sigproc/downsample.py +78 -78
  15. ezmsg/sigproc/ewma.py +197 -0
  16. ezmsg/sigproc/extract_axis.py +41 -0
  17. ezmsg/sigproc/filter.py +257 -141
  18. ezmsg/sigproc/filterbank.py +247 -199
  19. ezmsg/sigproc/math/abs.py +17 -22
  20. ezmsg/sigproc/math/clip.py +24 -24
  21. ezmsg/sigproc/math/difference.py +34 -30
  22. ezmsg/sigproc/math/invert.py +13 -25
  23. ezmsg/sigproc/math/log.py +28 -33
  24. ezmsg/sigproc/math/scale.py +18 -26
  25. ezmsg/sigproc/quantize.py +71 -0
  26. ezmsg/sigproc/resample.py +298 -0
  27. ezmsg/sigproc/sampler.py +241 -259
  28. ezmsg/sigproc/scaler.py +55 -218
  29. ezmsg/sigproc/signalinjector.py +52 -43
  30. ezmsg/sigproc/slicer.py +81 -89
  31. ezmsg/sigproc/spectrogram.py +77 -75
  32. ezmsg/sigproc/spectrum.py +203 -168
  33. ezmsg/sigproc/synth.py +546 -393
  34. ezmsg/sigproc/transpose.py +131 -0
  35. ezmsg/sigproc/util/asio.py +156 -0
  36. ezmsg/sigproc/util/message.py +31 -0
  37. ezmsg/sigproc/util/profile.py +55 -12
  38. ezmsg/sigproc/util/typeresolution.py +83 -0
  39. ezmsg/sigproc/wavelets.py +154 -153
  40. ezmsg/sigproc/window.py +269 -211
  41. {ezmsg_sigproc-1.8.1.dist-info → ezmsg_sigproc-2.0.0.dist-info}/METADATA +2 -1
  42. ezmsg_sigproc-2.0.0.dist-info/RECORD +51 -0
  43. ezmsg_sigproc-1.8.1.dist-info/RECORD +0 -39
  44. {ezmsg_sigproc-1.8.1.dist-info → ezmsg_sigproc-2.0.0.dist-info}/WHEEL +0 -0
  45. {ezmsg_sigproc-1.8.1.dist-info → ezmsg_sigproc-2.0.0.dist-info}/licenses/LICENSE.txt +0 -0
ezmsg/sigproc/scaler.py CHANGED
@@ -1,163 +1,21 @@
1
- import functools
2
1
  import typing
3
2
 
4
3
  import numpy as np
5
- import numpy.typing as npt
6
- import scipy.signal
7
- import ezmsg.core as ez
8
4
  from ezmsg.util.messages.axisarray import AxisArray
9
5
  from ezmsg.util.messages.util import replace
10
6
  from ezmsg.util.generator import consumer
11
7
 
12
- from .base import GenAxisArray
8
+ from .base import (
9
+ BaseStatefulTransformer,
10
+ BaseTransformerUnit,
11
+ processor_state,
12
+ )
13
+ from .ewma import EWMATransformer, EWMASettings, _alpha_from_tau
13
14
 
14
-
15
- def _tau_from_alpha(alpha: float, dt: float) -> float:
16
- """
17
- Inverse of _alpha_from_tau. See that function for explanation.
18
- """
19
- return -dt / np.log(1 - alpha)
20
-
21
-
22
- def _alpha_from_tau(tau: float, dt: float) -> float:
23
- """
24
- # https://en.wikipedia.org/wiki/Exponential_smoothing#Time_constant
25
- :param tau: The amount of time for the smoothed response of a unit step function to reach
26
- 1 - 1/e approx-eq 63.2%.
27
- :param dt: sampling period, or 1 / sampling_rate.
28
- :return: alpha, the "fading factor" in exponential smoothing.
29
- """
30
- return 1 - np.exp(-dt / tau)
31
-
32
-
33
- def ewma_step(
34
- sample: npt.NDArray, zi: npt.NDArray, alpha: float, beta: float | None = None
35
- ):
36
- """
37
- Do an exponentially weighted moving average step.
38
-
39
- Args:
40
- sample: The new sample.
41
- zi: The output of the previous step.
42
- alpha: Fading factor.
43
- beta: Persisting factor. If None, it is calculated as 1-alpha.
44
-
45
- Returns:
46
- alpha * sample + beta * zi
47
-
48
- """
49
- # Potential micro-optimization:
50
- # Current: scalar-arr multiplication, scalar-arr multiplication, arr-arr addition
51
- # Alternative: arr-arr subtraction, arr-arr multiplication, arr-arr addition
52
- # return zi + alpha * (new_sample - zi)
53
- beta = beta or (1 - alpha)
54
- return alpha * sample + beta * zi
55
-
56
-
57
- class EWMA:
58
- def __init__(self, alpha: float):
59
- self.beta = 1 - alpha
60
- self._filt_func = functools.partial(
61
- scipy.signal.lfilter, [alpha], [1.0, alpha - 1.0], axis=0
62
- )
63
- self.prev = None
64
-
65
- def compute(self, arr: npt.NDArray) -> npt.NDArray:
66
- if self.prev is None:
67
- self.prev = self.beta * arr[:1]
68
- expected, self.prev = self._filt_func(arr, zi=self.prev)
69
- return expected
70
-
71
-
72
- class EWMA_Deprecated:
73
- """
74
- Grabbed these methods from https://stackoverflow.com/a/70998068 and other answers in that topic,
75
- but they ended up being slower than the scipy.signal.lfilter method.
76
- Additionally, `compute` and `compute2` suffer from potential errors as the vector length increases
77
- and beta**n approaches zero.
78
- """
79
-
80
- def __init__(self, alpha: float, max_len: int):
81
- self.alpha = alpha
82
- self.beta = 1 - alpha
83
- self.prev: npt.NDArray | None = None
84
- self.weights = np.empty((max_len + 1,), float)
85
- self._precalc_weights(max_len)
86
- self._step_func = functools.partial(ewma_step, alpha=self.alpha, beta=self.beta)
87
-
88
- def _precalc_weights(self, n: int):
89
- # (1-α)^0, (1-α)^1, (1-α)^2, ..., (1-α)^n
90
- np.power(self.beta, np.arange(n + 1), out=self.weights)
91
-
92
- def compute(self, arr: npt.NDArray, out: npt.NDArray | None = None) -> npt.NDArray:
93
- if out is None:
94
- out = np.empty(arr.shape, arr.dtype)
95
-
96
- n = arr.shape[0]
97
- weights = self.weights[:n]
98
- weights = np.expand_dims(weights, list(range(1, arr.ndim)))
99
-
100
- # α*P0, α*P1, α*P2, ..., α*Pn
101
- np.multiply(self.alpha, arr, out)
102
-
103
- # α*P0/(1-α)^0, α*P1/(1-α)^1, α*P2/(1-α)^2, ..., α*Pn/(1-α)^n
104
- np.divide(out, weights, out)
105
-
106
- # α*P0/(1-α)^0, α*P0/(1-α)^0 + α*P1/(1-α)^1, ...
107
- np.cumsum(out, axis=0, out=out)
108
-
109
- # (α*P0/(1-α)^0)*(1-α)^0, (α*P0/(1-α)^0 + α*P1/(1-α)^1)*(1-α)^1, ...
110
- np.multiply(out, weights, out)
111
-
112
- # Add the previous output
113
- if self.prev is None:
114
- self.prev = arr[:1]
115
-
116
- out += self.prev * np.expand_dims(
117
- self.weights[1 : n + 1], list(range(1, arr.ndim))
118
- )
119
-
120
- self.prev = out[-1:]
121
-
122
- return out
123
-
124
- def compute2(self, arr: npt.NDArray) -> npt.NDArray:
125
- """
126
- Compute the Exponentially Weighted Moving Average (EWMA) of the input array.
127
-
128
- Args:
129
- arr: The input array to be smoothed.
130
-
131
- Returns:
132
- The smoothed array.
133
- """
134
- n = arr.shape[0]
135
- if n > len(self.weights):
136
- self._precalc_weights(n)
137
- weights = self.weights[:n][::-1]
138
- weights = np.expand_dims(weights, list(range(1, arr.ndim)))
139
-
140
- result = np.cumsum(self.alpha * weights * arr, axis=0)
141
- result = result / weights
142
-
143
- # Handle the first call when prev is unset
144
- if self.prev is None:
145
- self.prev = arr[:1]
146
-
147
- result += self.prev * np.expand_dims(
148
- self.weights[1 : n + 1], list(range(1, arr.ndim))
149
- )
150
-
151
- # Store the result back into prev
152
- self.prev = result[-1]
153
-
154
- return result
155
-
156
- def compute_sample(self, new_sample: npt.NDArray) -> npt.NDArray:
157
- if self.prev is None:
158
- self.prev = new_sample
159
- self.prev = self._step_func(new_sample, self.prev)
160
- return self.prev
15
+ # Imports for backwards compatibility with previous module location
16
+ from .ewma import EWMA_Deprecated as EWMA_Deprecated
17
+ from .ewma import ewma_step as ewma_step
18
+ from .ewma import _tau_from_alpha as _tau_from_alpha
161
19
 
162
20
 
163
21
  @consumer
@@ -208,83 +66,62 @@ def scaler(
208
66
  msg_out = replace(msg_in, data=result)
209
67
 
210
68
 
211
- @consumer
212
- def scaler_np(
213
- time_constant: float = 1.0, axis: str | None = None
214
- ) -> typing.Generator[AxisArray, AxisArray, None]:
215
- """
216
- Create a generator function that applies an adaptive standard scaler.
217
- This is faster than :obj:`scaler` for multichannel data.
218
-
219
- Args:
220
- time_constant: Decay constant `tau` in seconds.
221
- axis: The name of the axis to accumulate statistics over.
222
- Note: The axis must be in the msg.axes and be of type AxisArray.LinearAxis.
223
-
224
- Returns:
225
- A primed generator object that expects to be sent a :obj:`AxisArray` via `.send(axis_array)`
226
- and yields an :obj:`AxisArray` with its data being a standardized, or "Z-scored" version of the input data.
227
- """
228
- msg_out = AxisArray(np.array([]), dims=[""])
229
-
230
- # State variables
231
- samps_ewma: EWMA | None = None
232
- vars_sq_ewma: EWMA | None = None
69
+ class AdaptiveStandardScalerSettings(EWMASettings): ...
233
70
 
234
- # Reset if input changes
235
- check_input = {
236
- "gain": None, # Resets alpha
237
- "shape": None,
238
- "key": None, # Key change implies buffered means/vars are invalid.
239
- }
240
71
 
241
- while True:
242
- msg_in: AxisArray = yield msg_out
72
+ @processor_state
73
+ class AdaptiveStandardScalerState:
74
+ samps_ewma: EWMATransformer | None = None
75
+ vars_sq_ewma: EWMATransformer | None = None
76
+ alpha: float | None = None
243
77
 
244
- axis = axis or msg_in.dims[0]
245
- axis_idx = msg_in.get_axis_idx(axis)
246
78
 
247
- data: npt.NDArray = np.moveaxis(msg_in.data, axis_idx, 0)
248
- b_reset = data.shape[1:] != check_input["shape"]
249
- b_reset = b_reset or msg_in.axes[axis].gain != check_input["gain"]
250
- b_reset = b_reset or msg_in.key != check_input["key"]
251
- if b_reset:
252
- check_input["shape"] = data.shape[1:]
253
- check_input["gain"] = msg_in.axes[axis].gain
254
- check_input["key"] = msg_in.key
255
- alpha = _alpha_from_tau(time_constant, msg_in.axes[axis].gain)
256
- samps_ewma = EWMA(alpha=alpha)
257
- vars_sq_ewma = EWMA(alpha=alpha)
79
+ class AdaptiveStandardScalerTransformer(
80
+ BaseStatefulTransformer[
81
+ AdaptiveStandardScalerSettings,
82
+ AxisArray,
83
+ AxisArray,
84
+ AdaptiveStandardScalerState,
85
+ ]
86
+ ):
87
+ def _reset_state(self, message: AxisArray) -> None:
88
+ self._state.samps_ewma = EWMATransformer(
89
+ time_constant=self.settings.time_constant, axis=self.settings.axis
90
+ )
91
+ self._state.vars_sq_ewma = EWMATransformer(
92
+ time_constant=self.settings.time_constant, axis=self.settings.axis
93
+ )
258
94
 
95
+ def _process(self, message: AxisArray) -> AxisArray:
259
96
  # Update step
260
- means = samps_ewma.compute(data)
261
- vars_sq_means = vars_sq_ewma.compute(data**2)
97
+ mean_message = self._state.samps_ewma(message)
98
+ var_sq_message = self._state.vars_sq_ewma(
99
+ replace(message, data=message.data**2)
100
+ )
262
101
 
263
102
  # Get step
264
- varis = vars_sq_means - means**2
103
+ varis = var_sq_message.data - mean_message.data**2
265
104
  with np.errstate(divide="ignore", invalid="ignore"):
266
- result = (data - means) / (varis**0.5)
105
+ result = (message.data - mean_message.data) / (varis**0.5)
267
106
  result[np.isnan(result)] = 0.0
268
- result = np.moveaxis(result, 0, axis_idx)
269
- msg_out = replace(msg_in, data=result)
107
+ return replace(message, data=result)
270
108
 
271
109
 
272
- class AdaptiveStandardScalerSettings(ez.Settings):
273
- """
274
- Settings for :obj:`AdaptiveStandardScaler`.
275
- See :obj:`scaler_np` for a description of the parameters.
276
- """
277
-
278
- time_constant: float = 1.0
279
- axis: str | None = None
280
-
281
-
282
- class AdaptiveStandardScaler(GenAxisArray):
283
- """Unit for :obj:`scaler_np`"""
284
-
110
+ class AdaptiveStandardScaler(
111
+ BaseTransformerUnit[
112
+ AdaptiveStandardScalerSettings,
113
+ AxisArray,
114
+ AxisArray,
115
+ AdaptiveStandardScalerTransformer,
116
+ ]
117
+ ):
285
118
  SETTINGS = AdaptiveStandardScalerSettings
286
119
 
287
- def construct_generator(self):
288
- self.STATE.gen = scaler_np(
289
- time_constant=self.SETTINGS.time_constant, axis=self.SETTINGS.axis
290
- )
120
+
121
+ # Backwards compatibility...
122
+ def scaler_np(
123
+ time_constant: float = 1.0, axis: str | None = None
124
+ ) -> AdaptiveStandardScalerTransformer:
125
+ return AdaptiveStandardScalerTransformer(
126
+ settings=AdaptiveStandardScalerSettings(time_constant=time_constant, axis=axis)
127
+ )
@@ -1,12 +1,14 @@
1
- import typing
2
-
3
1
  import ezmsg.core as ez
4
2
  from ezmsg.util.messages.axisarray import AxisArray
5
3
  from ezmsg.util.messages.util import replace
6
4
  import numpy as np
7
5
  import numpy.typing as npt
8
6
 
9
- from .util.profile import profile_subpub
7
+ from .base import (
8
+ BaseAsyncTransformer,
9
+ BaseTransformerUnit,
10
+ processor_state,
11
+ )
10
12
 
11
13
 
12
14
  class SignalInjectorSettings(ez.Settings):
@@ -16,57 +18,64 @@ class SignalInjectorSettings(ez.Settings):
16
18
  mixing_seed: int | None = None
17
19
 
18
20
 
19
- class SignalInjectorState(ez.State):
21
+ @processor_state
22
+ class SignalInjectorState:
20
23
  cur_shape: tuple[int, ...] | None = None
21
24
  cur_frequency: float | None = None
22
- cur_amplitude: float
23
- mixing: npt.NDArray
25
+ cur_amplitude: float | None = None
26
+ mixing: npt.NDArray | None = None
24
27
 
25
28
 
26
- class SignalInjector(ez.Unit):
27
- """
28
- Add a sinusoidal signal to the input signal. Each feature gets a different amplitude of the sinusoid.
29
- All features get the same frequency sinusoid. The frequency and base amplitude can be changed while running.
30
- """
29
+ class SignalInjectorTransformer(
30
+ BaseAsyncTransformer[
31
+ SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorState
32
+ ]
33
+ ):
34
+ def _hash_message(self, message: AxisArray) -> int:
35
+ time_ax_idx = message.get_axis_idx(self.settings.time_dim)
36
+ sample_shape = (
37
+ message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
38
+ )
39
+ return hash((message.key,) + sample_shape)
40
+
41
+ def _reset_state(self, message: AxisArray) -> None:
42
+ if self._state.cur_frequency is None:
43
+ self._state.cur_frequency = self.settings.frequency
44
+ if self._state.cur_amplitude is None:
45
+ self._state.cur_amplitude = self.settings.amplitude
46
+ time_ax_idx = message.get_axis_idx(self.settings.time_dim)
47
+ self._state.cur_shape = (
48
+ message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
49
+ )
50
+ rng = np.random.default_rng(self.settings.mixing_seed)
51
+ self._state.mixing = rng.random((1, message.shape2d(self.settings.time_dim)[1]))
52
+ self._state.mixing = (self._state.mixing * 2.0) - 1.0
53
+
54
+ async def _aprocess(self, message: AxisArray) -> AxisArray:
55
+ if self._state.cur_frequency is None:
56
+ return message
57
+ out_msg = replace(message, data=message.data.copy())
58
+ t = out_msg.ax(self.settings.time_dim).values[..., np.newaxis]
59
+ signal = np.sin(2 * np.pi * self._state.cur_frequency * t)
60
+ mixed_signal = signal * self._state.mixing * self._state.cur_amplitude
61
+ with out_msg.view2d(self.settings.time_dim) as view:
62
+ view[...] = view + mixed_signal.astype(view.dtype)
63
+ return out_msg
31
64
 
32
- SETTINGS = SignalInjectorSettings
33
- STATE = SignalInjectorState
34
65
 
66
+ class SignalInjector(
67
+ BaseTransformerUnit[
68
+ SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorTransformer
69
+ ]
70
+ ):
71
+ SETTINGS = SignalInjectorSettings
35
72
  INPUT_FREQUENCY = ez.InputStream(float | None)
36
73
  INPUT_AMPLITUDE = ez.InputStream(float)
37
- INPUT_SIGNAL = ez.InputStream(AxisArray)
38
- OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
39
-
40
- async def initialize(self) -> None:
41
- self.STATE.cur_frequency = self.SETTINGS.frequency
42
- self.STATE.cur_amplitude = self.SETTINGS.amplitude
43
- self.STATE.mixing = np.array([])
44
74
 
45
75
  @ez.subscriber(INPUT_FREQUENCY)
46
76
  async def on_frequency(self, msg: float | None) -> None:
47
- self.STATE.cur_frequency = msg
77
+ self.processor.state.cur_frequency = msg
48
78
 
49
79
  @ez.subscriber(INPUT_AMPLITUDE)
50
80
  async def on_amplitude(self, msg: float) -> None:
51
- self.STATE.cur_amplitude = msg
52
-
53
- @ez.subscriber(INPUT_SIGNAL)
54
- @ez.publisher(OUTPUT_SIGNAL)
55
- @profile_subpub(trace_oldest=False)
56
- async def inject(self, msg: AxisArray) -> typing.AsyncGenerator:
57
- if self.STATE.cur_shape != msg.shape:
58
- self.STATE.cur_shape = msg.shape
59
- rng = np.random.default_rng(self.SETTINGS.mixing_seed)
60
- self.STATE.mixing = rng.random((1, msg.shape2d(self.SETTINGS.time_dim)[1]))
61
- self.STATE.mixing = (self.STATE.mixing * 2.0) - 1.0
62
-
63
- if self.STATE.cur_frequency is None:
64
- yield self.OUTPUT_SIGNAL, msg
65
- else:
66
- out_msg = replace(msg, data=msg.data.copy())
67
- t = out_msg.ax(self.SETTINGS.time_dim).values[..., np.newaxis]
68
- signal = np.sin(2 * np.pi * self.STATE.cur_frequency * t)
69
- mixed_signal = signal * self.STATE.mixing * self.STATE.cur_amplitude
70
- with out_msg.view2d(self.SETTINGS.time_dim) as view:
71
- view[...] = view + mixed_signal.astype(view.dtype)
72
- yield self.OUTPUT_SIGNAL, out_msg
81
+ self.processor.state.cur_amplitude = msg
ezmsg/sigproc/slicer.py CHANGED
@@ -1,5 +1,3 @@
1
- import typing
2
-
3
1
  import numpy as np
4
2
  import numpy.typing as npt
5
3
  import ezmsg.core as ez
@@ -9,10 +7,12 @@ from ezmsg.util.messages.axisarray import (
9
7
  AxisBase,
10
8
  replace,
11
9
  )
12
- from ezmsg.util.generator import consumer
13
-
14
- from .base import GenAxisArray
15
10
 
11
+ from .base import (
12
+ BaseStatefulTransformer,
13
+ BaseTransformerUnit,
14
+ processor_state,
15
+ )
16
16
 
17
17
  """
18
18
  Slicer:Select a subset of data along a particular axis.
@@ -61,106 +61,98 @@ def parse_slice(
61
61
  return tuple([item for sublist in suplist for item in sublist])
62
62
 
63
63
 
64
- @consumer
65
- def slicer(
66
- selection: str = "", axis: str | None = None
67
- ) -> typing.Generator[AxisArray, AxisArray, None]:
68
- """
69
- Slice along a particular axis.
70
-
71
- Args:
72
- selection: See :obj:`ezmsg.sigproc.slicer.parse_slice` for details.
73
- axis: The name of the axis to slice along. If None, the last axis is used.
64
+ class SlicerSettings(ez.Settings):
65
+ selection: str = ""
66
+ """selection: See :obj:`ezmsg.sigproc.slicer.parse_slice` for details."""
74
67
 
75
- Returns:
76
- A primed generator object ready to yield an :obj:`AxisArray` for each .send(axis_array)
77
- with the data payload containing a sliced view of the input data.
68
+ axis: str | None = None
69
+ """The name of the axis to slice along. If None, the last axis is used."""
78
70
 
79
- """
80
- msg_out = AxisArray(np.array([]), dims=[""])
81
71
 
82
- # State variables
83
- _slice: slice | npt.NDArray | None = None
72
+ @processor_state
73
+ class SlicerState:
74
+ slice_: slice | int | npt.NDArray | None = None
84
75
  new_axis: AxisBase | None = None
85
- b_change_dims: bool = False # If number of dimensions changes when slicing
86
-
87
- # Reset if input changes
88
- check_input = {
89
- "key": None, # key change used as proxy for label change, which we don't check explicitly
90
- "len": None,
91
- }
92
-
93
- while True:
94
- msg_in: AxisArray = yield msg_out
95
-
96
- axis = axis or msg_in.dims[-1]
97
- axis_idx = msg_in.get_axis_idx(axis)
98
-
99
- b_reset = _slice is None # or new_axis is None
100
- b_reset = b_reset or msg_in.key != check_input["key"]
101
- b_reset = b_reset or (
102
- (msg_in.data.shape[axis_idx] != check_input["len"])
103
- and (type(_slice) is np.ndarray)
104
- )
105
- if b_reset:
106
- check_input["key"] = msg_in.key
107
- check_input["len"] = msg_in.data.shape[axis_idx]
108
- new_axis = None # Will hold updated metadata
109
- b_change_dims = False
110
-
111
- # Calculate the slice
112
- _slices = parse_slice(selection, msg_in.axes.get(axis, None))
113
- if len(_slices) == 1:
114
- _slice = _slices[0]
115
- # Do we drop the sliced dimension?
116
- b_change_dims = isinstance(_slice, int)
76
+ b_change_dims: bool = False
77
+
78
+
79
+ class SlicerTransformer(
80
+ BaseStatefulTransformer[SlicerSettings, AxisArray, AxisArray, SlicerState]
81
+ ):
82
+ def _hash_message(self, message: AxisArray) -> int:
83
+ axis = self.settings.axis or message.dims[-1]
84
+ axis_idx = message.get_axis_idx(axis)
85
+ return hash((message.key, message.data.shape[axis_idx]))
86
+
87
+ def _reset_state(self, message: AxisArray) -> None:
88
+ axis = self.settings.axis or message.dims[-1]
89
+ axis_idx = message.get_axis_idx(axis)
90
+ self._state.new_axis = None
91
+ self._state.b_change_dims = False
92
+
93
+ # Calculate the slice
94
+ _slices = parse_slice(self.settings.selection, message.axes.get(axis, None))
95
+ if len(_slices) == 1:
96
+ self._state.slice_ = _slices[0]
97
+ self._state.b_change_dims = isinstance(self._state.slice_, int)
98
+ else:
99
+ indices = np.arange(message.data.shape[axis_idx])
100
+ indices = np.hstack([indices[_] for _ in _slices])
101
+ self._state.slice_ = np.s_[indices]
102
+
103
+ # Create the output axis
104
+ if (
105
+ axis in message.axes
106
+ and hasattr(message.axes[axis], "data")
107
+ and len(message.axes[axis].data) > 0
108
+ ):
109
+ in_data = np.array(message.axes[axis].data)
110
+ if self._state.b_change_dims:
111
+ out_data = in_data[self._state.slice_ : self._state.slice_ + 1]
117
112
  else:
118
- # Multiple slices, but this cannot be done in a single step, so we convert the slices
119
- # to a discontinuous set of integer indexes.
120
- indices = np.arange(msg_in.data.shape[axis_idx])
121
- indices = np.hstack([indices[_] for _ in _slices])
122
- _slice = np.s_[indices] # Integer scalar array
113
+ out_data = in_data[self._state.slice_]
114
+ self._state.new_axis = replace(message.axes[axis], data=out_data)
123
115
 
124
- # Create the output axis.
125
- if (
126
- axis in msg_in.axes
127
- and hasattr(msg_in.axes[axis], "data")
128
- and len(msg_in.axes[axis].data) > 0
129
- ):
130
- in_data = np.array(msg_in.axes[axis].data)
131
- if b_change_dims:
132
- out_data = in_data[_slice : _slice + 1]
133
- else:
134
- out_data = in_data[_slice]
135
- new_axis = replace(msg_in.axes[axis], data=out_data)
116
+ def _process(self, message: AxisArray) -> AxisArray:
117
+ axis = self.settings.axis or message.dims[-1]
118
+ axis_idx = message.get_axis_idx(axis)
136
119
 
137
120
  replace_kwargs = {}
138
- if b_change_dims:
139
- # Dropping the target axis
121
+ if self._state.b_change_dims:
140
122
  replace_kwargs["dims"] = [
141
- _ for dim_ix, _ in enumerate(msg_in.dims) if dim_ix != axis_idx
123
+ _ for dim_ix, _ in enumerate(message.dims) if dim_ix != axis_idx
142
124
  ]
143
- replace_kwargs["axes"] = {k: v for k, v in msg_in.axes.items() if k != axis}
144
- elif new_axis is not None:
145
125
  replace_kwargs["axes"] = {
146
- k: (v if k != axis else new_axis) for k, v in msg_in.axes.items()
126
+ k: v for k, v in message.axes.items() if k != axis
147
127
  }
148
- msg_out = replace(
149
- msg_in,
150
- data=slice_along_axis(msg_in.data, _slice, axis_idx),
128
+ elif self._state.new_axis is not None:
129
+ replace_kwargs["axes"] = {
130
+ k: (v if k != axis else self._state.new_axis)
131
+ for k, v in message.axes.items()
132
+ }
133
+
134
+ return replace(
135
+ message,
136
+ data=slice_along_axis(message.data, self._state.slice_, axis_idx),
151
137
  **replace_kwargs,
152
138
  )
153
139
 
154
140
 
155
- class SlicerSettings(ez.Settings):
156
- selection: str = ""
157
- axis: str | None = None
141
+ class Slicer(
142
+ BaseTransformerUnit[SlicerSettings, AxisArray, AxisArray, SlicerTransformer]
143
+ ):
144
+ SETTINGS = SlicerSettings
158
145
 
159
146
 
160
- class Slicer(GenAxisArray):
161
- SETTINGS = SlicerSettings
147
+ def slicer(selection: str = "", axis: str | None = None) -> SlicerTransformer:
148
+ """
149
+ Slice along a particular axis.
162
150
 
163
- def construct_generator(self):
164
- self.STATE.gen = slicer(
165
- selection=self.SETTINGS.selection, axis=self.SETTINGS.axis
166
- )
151
+ Args:
152
+ selection: See :obj:`ezmsg.sigproc.slicer.parse_slice` for details.
153
+ axis: The name of the axis to slice along. If None, the last axis is used.
154
+
155
+ Returns:
156
+ :obj:`SlicerTransformer`
157
+ """
158
+ return SlicerTransformer(SlicerSettings(selection=selection, axis=axis))