ezmsg-sigproc 1.8.1__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. ezmsg/sigproc/__version__.py +2 -2
  2. ezmsg/sigproc/activation.py +36 -39
  3. ezmsg/sigproc/adaptive_lattice_notch.py +231 -0
  4. ezmsg/sigproc/affinetransform.py +169 -163
  5. ezmsg/sigproc/aggregate.py +119 -104
  6. ezmsg/sigproc/bandpower.py +58 -52
  7. ezmsg/sigproc/base.py +1242 -0
  8. ezmsg/sigproc/butterworthfilter.py +37 -33
  9. ezmsg/sigproc/cheby.py +29 -17
  10. ezmsg/sigproc/combfilter.py +163 -0
  11. ezmsg/sigproc/decimate.py +19 -10
  12. ezmsg/sigproc/detrend.py +29 -0
  13. ezmsg/sigproc/diff.py +81 -0
  14. ezmsg/sigproc/downsample.py +78 -78
  15. ezmsg/sigproc/ewma.py +197 -0
  16. ezmsg/sigproc/extract_axis.py +41 -0
  17. ezmsg/sigproc/filter.py +257 -141
  18. ezmsg/sigproc/filterbank.py +247 -199
  19. ezmsg/sigproc/math/abs.py +17 -22
  20. ezmsg/sigproc/math/clip.py +24 -24
  21. ezmsg/sigproc/math/difference.py +34 -30
  22. ezmsg/sigproc/math/invert.py +13 -25
  23. ezmsg/sigproc/math/log.py +28 -33
  24. ezmsg/sigproc/math/scale.py +18 -26
  25. ezmsg/sigproc/quantize.py +71 -0
  26. ezmsg/sigproc/resample.py +298 -0
  27. ezmsg/sigproc/sampler.py +241 -259
  28. ezmsg/sigproc/scaler.py +55 -218
  29. ezmsg/sigproc/signalinjector.py +52 -43
  30. ezmsg/sigproc/slicer.py +81 -89
  31. ezmsg/sigproc/spectrogram.py +77 -75
  32. ezmsg/sigproc/spectrum.py +203 -168
  33. ezmsg/sigproc/synth.py +546 -393
  34. ezmsg/sigproc/transpose.py +131 -0
  35. ezmsg/sigproc/util/asio.py +156 -0
  36. ezmsg/sigproc/util/message.py +31 -0
  37. ezmsg/sigproc/util/profile.py +55 -12
  38. ezmsg/sigproc/util/typeresolution.py +83 -0
  39. ezmsg/sigproc/wavelets.py +154 -153
  40. ezmsg/sigproc/window.py +269 -211
  41. {ezmsg_sigproc-1.8.1.dist-info → ezmsg_sigproc-2.0.0.dist-info}/METADATA +2 -1
  42. ezmsg_sigproc-2.0.0.dist-info/RECORD +51 -0
  43. ezmsg_sigproc-1.8.1.dist-info/RECORD +0 -39
  44. {ezmsg_sigproc-1.8.1.dist-info → ezmsg_sigproc-2.0.0.dist-info}/WHEEL +0 -0
  45. {ezmsg_sigproc-1.8.1.dist-info → ezmsg_sigproc-2.0.0.dist-info}/licenses/LICENSE.txt +0 -0
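The dominant change in this release, visible in the per-module rewrites below and in the new ezmsg/sigproc/base.py (+1242 lines), is architectural: the 1.x `@consumer` generator functions and their `GenAxisArray` unit wrappers are replaced by transformer classes with explicit `ez.Settings`, an optional `@processor_state` state container, and `_hash_message` / `_reset_state` / `_process` hooks, wrapped by `BaseTransformerUnit` subclasses. The sketch below illustrates that pattern under the assumptions one can read off the hunks that follow; the `Offset*` processor itself is hypothetical and not part of the package, and the exact base-class generics may differ from what is shown.

# Hypothetical processor following the 2.0.0 pattern shown in the hunks below.
# The Offset* names are illustrative only; base-class generics are assumed
# from this diff and may not match ezmsg.sigproc.base exactly.
import ezmsg.core as ez
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from ezmsg.sigproc.base import (
    BaseStatefulTransformer,
    BaseTransformerUnit,
    processor_state,
)


class OffsetSettings(ez.Settings):
    offset: float = 0.0


@processor_state
class OffsetState:
    n_ch: int | None = None  # anything that must be rebuilt when the stream changes


class OffsetTransformer(
    BaseStatefulTransformer[OffsetSettings, AxisArray, AxisArray, OffsetState]
):
    def _hash_message(self, message: AxisArray) -> int:
        # State is reset whenever this hash changes (here: per-stream key).
        return hash(message.key)

    def _reset_state(self, message: AxisArray) -> None:
        self._state.n_ch = message.data.shape[-1]

    def _process(self, message: AxisArray) -> AxisArray:
        return replace(message, data=message.data + self.settings.offset)


class Offset(
    BaseTransformerUnit[OffsetSettings, AxisArray, AxisArray, OffsetTransformer]
):
    SETTINGS = OffsetSettings

Transformer instances are plain callables (e.g. `OffsetTransformer(OffsetSettings(offset=1.0))(msg)`), which is why the module-level helpers such as `affine_transform(...)` now simply return a configured transformer instead of a primed generator.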
ezmsg/sigproc/affinetransform.py
@@ -1,112 +1,112 @@
 import os
 from pathlib import Path
-import typing
 
 import numpy as np
 import numpy.typing as npt
 import ezmsg.core as ez
 from ezmsg.util.messages.axisarray import AxisArray, AxisBase
 from ezmsg.util.messages.util import replace
-from ezmsg.util.generator import consumer
 
-from .base import GenAxisArray
+from .base import (
+    BaseStatefulTransformer,
+    BaseTransformerUnit,
+    BaseTransformer,
+    processor_state,
+)
 
 
-@consumer
-def affine_transform(
-    weights: np.ndarray | str | Path,
-    axis: str | None = None,
-    right_multiply: bool = True,
-) -> typing.Generator[AxisArray, AxisArray, None]:
+class AffineTransformSettings(ez.Settings):
+    """
+    Settings for :obj:`AffineTransform`.
+    See :obj:`affine_transform` for argument details.
     """
-    Perform affine transformations on streaming data.
 
-    Args:
-        weights: An array of weights or a path to a file with weights compatible with np.loadtxt.
-        axis: The name of the axis to apply the transformation to. Defaults to the leading (0th) axis in the array.
-        right_multiply: Set False to transpose the weights before applying.
+    weights: np.ndarray | str | Path
+    """An array of weights or a path to a file with weights compatible with np.loadtxt."""
 
-    Returns:
-        A primed generator object that yields an :obj:`AxisArray` object for every
-        :obj:`AxisArray` it receives via `send`.
-    """
-    msg_out = AxisArray(np.array([]), dims=[""])
+    axis: str | None = None
+    """The name of the axis to apply the transformation to. Defaults to the leading (0th) axis in the array."""
 
-    # Check parameters
-    if isinstance(weights, str):
-        if weights == "passthrough":
-            weights = None
-        else:
-            weights = Path(os.path.abspath(os.path.expanduser(weights)))
-    if isinstance(weights, Path):
-        weights = np.loadtxt(weights, delimiter=",")
-    if not right_multiply:
-        weights = weights.T
-    if weights is not None:
-        weights = np.ascontiguousarray(weights)
-
-    # State variables
-    # New axis with transformed labels, if required
+    right_multiply: bool = True
+    """Set False to transpose the weights before applying."""
+
+
+@processor_state
+class AffineTransformState:
+    weights: npt.NDArray | None = None
     new_axis: AxisBase | None = None
 
-    # Reset if any of these change.
-    check_input = {"key": None}
-    # We assume key change catches labels change; we don't want to check labels every message
-    # We don't need to check if input size has changed because weights multiplication will fail if so.
-
-    while True:
-        msg_in: AxisArray = yield msg_out
-
-        if weights is None:
-            msg_out = msg_in
-            continue
-
-        axis = axis or msg_in.dims[-1]  # Note: Most nodes default do dim[0]
-        axis_idx = msg_in.get_axis_idx(axis)
-
-        b_reset = msg_in.key != check_input["key"]
-        if b_reset:
-            # First sample or key has changed. Reset the state.
-            check_input["key"] = msg_in.key
-            # Determine if we need to modify the transformed axis.
-            if (
-                axis in msg_in.axes
-                and hasattr(msg_in.axes[axis], "data")
-                and weights.shape[0] != weights.shape[1]
-            ):
-                in_labels = msg_in.axes[axis].data
-                new_labels = []
-                n_in, n_out = weights.shape
-                if len(in_labels) != n_in:
-                    # Something upstream did something it wasn't supposed to. We will drop the labels.
-                    ez.logger.warning(
-                        f"Received {len(in_labels)} for {n_in} inputs. Check upstream labels."
-                    )
-                else:
-                    b_filled_outputs = np.any(weights, axis=0)
-                    b_used_inputs = np.any(weights, axis=1)
-                    if np.all(b_used_inputs) and np.all(b_filled_outputs):
-                        # All inputs are used and all outputs are used, but n_in != n_out.
-                        # Mapping cannot be determined.
-                        new_labels = []
-                    elif np.all(b_used_inputs):
-                        # Strange scenario: New outputs are filled with empty data.
-                        in_ix = 0
-                        new_labels = []
-                        for out_ix in range(n_out):
-                            if b_filled_outputs[out_ix]:
-                                new_labels.append(in_labels[in_ix])
-                                in_ix += 1
-                            else:
-                                new_labels.append("")
-                    elif np.all(b_filled_outputs):
-                        # Transform is dropping some of the inputs.
-                        new_labels = np.array(in_labels)[b_used_inputs]
-                new_axis = replace(msg_in.axes[axis], data=np.array(new_labels))
-
-        data = msg_in.data
-
-        if data.shape[axis_idx] == (weights.shape[0] - 1):
+
+class AffineTransformTransformer(
+    BaseStatefulTransformer[
+        AffineTransformSettings, AxisArray, AxisArray, AffineTransformState
+    ]
+):
+    def __call__(self, message: AxisArray) -> AxisArray:
+        # Override __call__ so we can shortcut if weights are None.
+        if self.settings.weights is None or (
+            isinstance(self.settings.weights, str)
+            and self.settings.weights == "passthrough"
+        ):
+            return message
+        return super().__call__(message)
+
+    def _hash_message(self, message: AxisArray) -> int:
+        return hash(message.key)
+
+    def _reset_state(self, message: AxisArray) -> None:
+        weights = self.settings.weights
+        if isinstance(weights, str):
+            weights = Path(os.path.abspath(os.path.expanduser(weights)))
+        if isinstance(weights, Path):
+            weights = np.loadtxt(weights, delimiter=",")
+        if not self.settings.right_multiply:
+            weights = weights.T
+        if weights is not None:
+            weights = np.ascontiguousarray(weights)
+
+        self._state.weights = weights
+
+        axis = self.settings.axis or message.dims[-1]
+        if (
+            axis in message.axes
+            and hasattr(message.axes[axis], "data")
+            and weights.shape[0] != weights.shape[1]
+        ):
+            in_labels = message.axes[axis].data
+            new_labels = []
+            n_in, n_out = weights.shape
+            if len(in_labels) != n_in:
+                ez.logger.warning(
+                    f"Received {len(in_labels)} for {n_in} inputs. Check upstream labels."
+                )
+            else:
+                b_filled_outputs = np.any(weights, axis=0)
+                b_used_inputs = np.any(weights, axis=1)
+                if np.all(b_used_inputs) and np.all(b_filled_outputs):
+                    new_labels = []
+                elif np.all(b_used_inputs):
+                    in_ix = 0
+                    new_labels = []
+                    for out_ix in range(n_out):
+                        if b_filled_outputs[out_ix]:
+                            new_labels.append(in_labels[in_ix])
+                            in_ix += 1
+                        else:
+                            new_labels.append("")
+                elif np.all(b_filled_outputs):
+                    new_labels = np.array(in_labels)[b_used_inputs]
+
+            self._state.new_axis = replace(
+                message.axes[axis], data=np.array(new_labels)
+            )
+
+    def _process(self, message: AxisArray) -> AxisArray:
+        axis = self.settings.axis or message.dims[-1]
+        axis_idx = message.get_axis_idx(axis)
+        data = message.data
+
+        if data.shape[axis_idx] == (self._state.weights.shape[0] - 1):
             # The weights are stacked A|B where A is the transform and B is a single row
             # in the equation y = Ax + B. This supports NeuroKey's weights matrices.
             sample_shape = data.shape[:axis_idx] + (1,) + data.shape[axis_idx + 1 :]
@@ -114,82 +114,87 @@ def affine_transform(
             (data, np.ones(sample_shape).astype(data.dtype)), axis=axis_idx
         )
 
-        if axis_idx in [-1, len(msg_in.dims) - 1]:
-            data = np.matmul(data, weights)
+        if axis_idx in [-1, len(message.dims) - 1]:
+            data = np.matmul(data, self._state.weights)
         else:
             data = np.moveaxis(data, axis_idx, -1)
-            data = np.matmul(data, weights)
+            data = np.matmul(data, self._state.weights)
             data = np.moveaxis(data, -1, axis_idx)
 
         replace_kwargs = {"data": data}
-        if new_axis is not None:
-            replace_kwargs["axes"] = {**msg_in.axes, axis: new_axis}
-        msg_out = replace(msg_in, **replace_kwargs)
+        if self._state.new_axis is not None:
+            replace_kwargs["axes"] = {**message.axes, axis: self._state.new_axis}
 
+        return replace(message, **replace_kwargs)
 
-class AffineTransformSettings(ez.Settings):
-    """
-    Settings for :obj:`AffineTransform`.
-    See :obj:`affine_transform` for argument details.
-    """
 
-    weights: np.ndarray | str | Path
-    axis: str | None = None
-    right_multiply: bool = True
+class AffineTransform(
+    BaseTransformerUnit[
+        AffineTransformSettings, AxisArray, AxisArray, AffineTransformTransformer
+    ]
+):
+    SETTINGS = AffineTransformSettings
 
 
-class AffineTransform(GenAxisArray):
-    """:obj:`Unit` for :obj:`affine_transform`"""
+def affine_transform(
+    weights: np.ndarray | str | Path,
+    axis: str | None = None,
+    right_multiply: bool = True,
+) -> AffineTransformTransformer:
+    """
+    Perform affine transformations on streaming data.
 
-    SETTINGS = AffineTransformSettings
+    Args:
+        weights: An array of weights or a path to a file with weights compatible with np.loadtxt.
+        axis: The name of the axis to apply the transformation to. Defaults to the leading (0th) axis in the array.
+        right_multiply: Set False to transpose the weights before applying.
 
-    def construct_generator(self):
-        self.STATE.gen = affine_transform(
-            weights=self.SETTINGS.weights,
-            axis=self.SETTINGS.axis,
-            right_multiply=self.SETTINGS.right_multiply,
+    Returns:
+        :obj:`AffineTransformTransformer`.
+    """
+    return AffineTransformTransformer(
+        AffineTransformSettings(
+            weights=weights, axis=axis, right_multiply=right_multiply
        )
+    )
 
 
 def zeros_for_noop(data: npt.NDArray, **ignore_kwargs) -> npt.NDArray:
     return np.zeros_like(data)
 
 
-@consumer
-def common_rereference(
-    mode: str = "mean", axis: str | None = None, include_current: bool = True
-) -> typing.Generator[AxisArray, AxisArray, None]:
+class CommonRereferenceSettings(ez.Settings):
+    """
+    Settings for :obj:`CommonRereference`
     """
-    Perform common average referencing (CAR) on streaming data.
 
-    Args:
-        mode: The statistical mode to apply -- either "mean" or "median"
-        axis: The name of hte axis to apply the transformation to.
-        include_current: Set False to exclude each channel from participating in the calculation of its reference.
+    mode: str = "mean"
+    """The statistical mode to apply -- either "mean" or "median"."""
 
-    Returns:
-        A primed generator object that yields an :obj:`AxisArray` object
-        for every :obj:`AxisArray` it receives via `send`.
-    """
-    msg_out = AxisArray(np.array([]), dims=[""])
+    axis: str | None = None
+    """The name of the axis to apply the transformation to."""
 
-    if mode == "passthrough":
-        include_current = True
+    include_current: bool = True
+    """Set False to exclude each channel from participating in the calculation of its reference."""
 
-    func = {"mean": np.mean, "median": np.median, "passthrough": zeros_for_noop}[mode]
 
-    while True:
-        msg_in: AxisArray = yield msg_out
+class CommonRereferenceTransformer(
+    BaseTransformer[CommonRereferenceSettings, AxisArray, AxisArray]
+):
+    def _process(self, message: AxisArray) -> AxisArray:
+        if self.settings.mode == "passthrough":
+            return message
 
-        if axis is None:
-            axis = msg_in.dims[-1]
-            axis_idx = -1
-        else:
-            axis_idx = msg_in.get_axis_idx(axis)
+        axis = self.settings.axis or message.dims[-1]
+        axis_idx = message.get_axis_idx(axis)
+
+        func = {"mean": np.mean, "median": np.median, "passthrough": zeros_for_noop}[
+            self.settings.mode
+        ]
 
-        ref_data = func(msg_in.data, axis=axis_idx, keepdims=True)
+        ref_data = func(message.data, axis=axis_idx, keepdims=True)
 
-        if not include_current:
+        if not self.settings.include_current:
             # Typical `CAR = x[0]/N + x[1]/N + ... x[i-1]/N + x[i]/N + x[i+1]/N + ... + x[N-1]/N`
             # and is the same for all i, so it is calculated only once in `ref_data`.
             # However, if we had excluded the current channel,
@@ -200,34 +205,35 @@ def common_rereference(
             # from the current channel (i.e., `x[i] / (N-1)`)
             # i.e., `CAR[i] = (N / (N-1)) * common_CAR - x[i]/(N-1)`
             # We can use broadcasting subtraction instead of looping over channels.
-            N = msg_in.data.shape[axis_idx]
-            ref_data = (N / (N - 1)) * ref_data - msg_in.data / (N - 1)
-            # Side note: I profiled using affine_transform and it's about 30x slower than this implementation.
+            N = message.data.shape[axis_idx]
+            ref_data = (N / (N - 1)) * ref_data - message.data / (N - 1)
+            # Note: I profiled using AffineTransformTransformer; it's ~30x slower than this implementation.
 
-        msg_out = replace(msg_in, data=msg_in.data - ref_data)
+        return replace(message, data=message.data - ref_data)
 
 
-class CommonRereferenceSettings(ez.Settings):
-    """
-    Settings for :obj:`CommonRereference`
-    See :obj:`common_rereference` for argument details.
-    """
-
-    mode: str = "mean"
-    axis: str | None = None
-    include_current: bool = True
+class CommonRereference(
+    BaseTransformerUnit[
+        CommonRereferenceSettings, AxisArray, AxisArray, CommonRereferenceTransformer
+    ]
+):
+    SETTINGS = CommonRereferenceSettings
 
 
-class CommonRereference(GenAxisArray):
-    """
-    :obj:`Unit` for :obj:`common_rereference`.
+def common_rereference(
+    mode: str = "mean", axis: str | None = None, include_current: bool = True
+) -> CommonRereferenceTransformer:
     """
+    Perform common average referencing (CAR) on streaming data.
 
-    SETTINGS = CommonRereferenceSettings
+    Args:
+        mode: The statistical mode to apply -- either "mean" or "median"
+        axis: The name of hte axis to apply the transformation to.
+        include_current: Set False to exclude each channel from participating in the calculation of its reference.
 
-    def construct_generator(self):
-        self.STATE.gen = common_rereference(
-            mode=self.SETTINGS.mode,
-            axis=self.SETTINGS.axis,
-            include_current=self.SETTINGS.include_current,
-        )
+    Returns:
+        :obj:`CommonRereferenceTransformer`
+    """
+    return CommonRereferenceTransformer(
+        CommonRereferenceSettings(mode=mode, axis=axis, include_current=include_current)
+    )
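The inline comments in the `CommonRereferenceTransformer._process` hunk above derive the exclude-current reference from the plain channel mean: with m = (1/N)·Σ x[j] over all N channels, the mean over the other N−1 channels is (N·m − x[i])/(N−1) = (N/(N−1))·m − x[i]/(N−1), which is exactly the broadcast expression in the diff. A quick standalone numpy check of that identity (array shapes here are arbitrary):

# Numeric check of the exclude-current CAR identity used in CommonRereferenceTransformer.
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))  # (time, channels)
N = x.shape[-1]

common = np.mean(x, axis=-1, keepdims=True)   # reference including every channel
excl = (N / (N - 1)) * common - x / (N - 1)   # broadcast form from the diff

# Brute force: for each channel, average all *other* channels.
brute = np.stack(
    [np.mean(np.delete(x, i, axis=-1), axis=-1) for i in range(N)],
    axis=-1,
)
assert np.allclose(excl, brute)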
ezmsg/sigproc/aggregate.py
@@ -3,7 +3,6 @@ import typing
 import numpy as np
 import numpy.typing as npt
 import ezmsg.core as ez
-from ezmsg.util.generator import consumer
 from ezmsg.util.messages.axisarray import (
     AxisArray,
     slice_along_axis,
@@ -12,7 +11,11 @@ from ezmsg.util.messages.axisarray import (
 )
 
 from .spectral import OptionsEnum
-from .base import GenAxisArray
+from .base import (
+    BaseStatefulTransformer,
+    BaseTransformerUnit,
+    processor_state,
+)
 
 
 class AggregationFunction(OptionsEnum):
@@ -54,130 +57,142 @@ AGGREGATORS = {
 }
 
 
-@consumer
-def ranged_aggregate(
-    axis: str | None = None,
-    bands: list[tuple[float, float]] | None = None,
-    operation: AggregationFunction = AggregationFunction.MEAN,
-):
+class RangedAggregateSettings(ez.Settings):
+    """
+    Settings for ``RangedAggregate``.
     """
-    Apply an aggregation operation over one or more bands.
 
-    Args:
-        axis: The name of the axis along which to apply the bands.
-        bands: [(band1_min, band1_max), (band2_min, band2_max), ...]
-            If not set then this acts as a passthrough node.
-        operation: :obj:`AggregationFunction` to apply to each band.
+    axis: str | None = None
+    """The name of the axis along which to apply the bands."""
 
-    Returns:
-        A primed generator object ready to yield an :obj:`AxisArray` for each .send(axis_array)
+    bands: list[tuple[float, float]] | None = None
     """
-    msg_out = AxisArray(np.array([]), dims=[""])
+    [(band1_min, band1_max), (band2_min, band2_max), ...]
+    If not set then this acts as a passthrough node.
+    """
+
+    operation: AggregationFunction = AggregationFunction.MEAN
+    """:obj:`AggregationFunction` to apply to each band."""
 
-    # State variables
+
+@processor_state
+class RangedAggregateState:
     slices: list[tuple[typing.Any, ...]] | None = None
     out_axis: AxisBase | None = None
     ax_vec: npt.NDArray | None = None
 
-    # Reset if any of these changes. Key not checked because continuity between chunks not required.
-    check_inputs = {"gain": None, "offset": None, "len": None, "key": None}
 
-    while True:
-        msg_in: AxisArray = yield msg_out
-        if bands is None:
-            msg_out = msg_in
+class RangedAggregateTransformer(
+    BaseStatefulTransformer[
+        RangedAggregateSettings, AxisArray, AxisArray, RangedAggregateState
+    ]
+):
+    def __call__(self, message: AxisArray) -> AxisArray:
+        # Override for shortcut passthrough mode.
+        if self.settings.bands is None:
+            return message
+        return super().__call__(message)
+
+    def _hash_message(self, message: AxisArray) -> int:
+        axis = self.settings.axis or message.dims[0]
+        target_axis = message.get_axis(axis)
+
+        hash_components = (message.key,)
+        if hasattr(target_axis, "data"):
+            hash_components += (len(target_axis.data),)
+        elif isinstance(target_axis, AxisArray.LinearAxis):
+            hash_components += (target_axis.gain, target_axis.offset)
+        return hash(hash_components)
+
+    def _reset_state(self, message: AxisArray) -> None:
+        axis = self.settings.axis or message.dims[0]
+        target_axis = message.get_axis(axis)
+        ax_idx = message.get_axis_idx(axis)
+
+        if hasattr(target_axis, "data"):
+            self._state.ax_vec = target_axis.data
         else:
-            axis = axis or msg_in.dims[0]
-            target_axis = msg_in.get_axis(axis)
+            self._state.ax_vec = target_axis.value(
+                np.arange(message.data.shape[ax_idx])
+            )
 
-            # Check if we need to reset state
-            b_reset = msg_in.key != check_inputs["key"]
+        ax_dat = []
+        slices = []
+        for start, stop in self.settings.bands:
+            inds = np.where(
+                np.logical_and(self._state.ax_vec >= start, self._state.ax_vec <= stop)
+            )[0]
+            slices.append(np.s_[inds[0] : inds[-1] + 1])
             if hasattr(target_axis, "data"):
-                b_reset = b_reset or len(target_axis.data) != check_inputs["len"]
-            elif isinstance(target_axis, AxisArray.LinearAxis):
-                b_reset = b_reset or target_axis.gain != check_inputs["gain"]
-                b_reset = b_reset or target_axis.offset != check_inputs["offset"]
-
-            if b_reset:
-                # Update check variables
-                check_inputs["key"] = msg_in.key
-                if hasattr(target_axis, "data"):
-                    check_inputs["len"] = len(target_axis.data)
+                if self._state.ax_vec.dtype.type is np.str_:
+                    sl_dat = f"{self._state.ax_vec[start]} - {self._state.ax_vec[stop]}"
                 else:
-                    check_inputs["gain"] = target_axis.gain
-                    check_inputs["offset"] = target_axis.offset
+                    ax_dat.append(np.mean(self._state.ax_vec[inds]))
+            else:
+                sl_dat = target_axis.value(np.mean(inds))
+                ax_dat.append(sl_dat)
+
+        self._state.slices = slices
+        self._state.out_axis = AxisArray.CoordinateAxis(
+            data=np.array(ax_dat),
+            dims=[axis],
+            unit=target_axis.unit,
+        )
 
-            # If the axis we are operating on has changed (e.g., "time" or "win" axis always changes),
-            # or the key has changed, then recalculate slices.
+    def _process(self, message: AxisArray) -> AxisArray:
+        axis = self.settings.axis or message.dims[0]
+        ax_idx = message.get_axis_idx(axis)
+        agg_func = AGGREGATORS[self.settings.operation]
 
-            ax_idx = msg_in.get_axis_idx(axis)
+        out_data = [
+            agg_func(slice_along_axis(message.data, sl, axis=ax_idx), axis=ax_idx)
+            for sl in self._state.slices
+        ]
 
-            if hasattr(target_axis, "data"):
-                ax_vec = target_axis.data
-            else:
-                ax_vec = target_axis.value(np.arange(msg_in.data.shape[ax_idx]))
-
-            slices = []
-            ax_dat = []
-            for start, stop in bands:
-                inds = np.where(np.logical_and(ax_vec >= start, ax_vec <= stop))[0]
-                slices.append(np.s_[inds[0] : inds[-1] + 1])
-                if hasattr(target_axis, "data"):
-                    if ax_vec.dtype.type is np.str_:
-                        sl_dat = f"{ax_vec[start]} - {ax_vec[stop]}"
-                    else:
-                        sl_dat = ax_dat.append(np.mean(ax_vec[inds]))
-                else:
-                    sl_dat = target_axis.value(np.mean(inds))
-                ax_dat.append(sl_dat)
-
-            out_axis = AxisArray.CoordinateAxis(
-                data=np.array(ax_dat),
-                dims=[axis],
-                unit=target_axis.unit,
-            )
-
-            agg_func = AGGREGATORS[operation]
-            out_data = [
-                agg_func(slice_along_axis(msg_in.data, sl, axis=ax_idx), axis=ax_idx)
-                for sl in slices
-            ]
-
-            msg_out = replace(
-                msg_in,
-                data=np.stack(out_data, axis=ax_idx),
-                axes={**msg_in.axes, axis: out_axis},
-            )
-            if operation in [AggregationFunction.ARGMIN, AggregationFunction.ARGMAX]:
-                # Convert indices returned by argmin/argmax into the value along the axis.
-                out_data = []
-                for sl_ix, sl in enumerate(slices):
-                    offsets = np.take(msg_out.data, [sl_ix], axis=ax_idx)
-                    out_data.append(ax_vec[sl][offsets])
-                msg_out.data = np.concatenate(out_data, axis=ax_idx)
+        msg_out = replace(
+            message,
+            data=np.stack(out_data, axis=ax_idx),
+            axes={**message.axes, axis: self._state.out_axis},
+        )
 
+        if self.settings.operation in [
+            AggregationFunction.ARGMIN,
+            AggregationFunction.ARGMAX,
+        ]:
+            out_data = []
+            for sl_ix, sl in enumerate(self._state.slices):
+                offsets = np.take(msg_out.data, [sl_ix], axis=ax_idx)
+                out_data.append(self._state.ax_vec[sl][offsets])
+            msg_out.data = np.concatenate(out_data, axis=ax_idx)
 
-class RangedAggregateSettings(ez.Settings):
-    """
-    Settings for ``RangedAggregate``.
-    See :obj:`ranged_aggregate` for details.
-    """
+        return msg_out
 
-    axis: str | None = None
-    bands: list[tuple[float, float]] | None = None
-    operation: AggregationFunction = AggregationFunction.MEAN
+
+class RangedAggregate(
+    BaseTransformerUnit[
+        RangedAggregateSettings, AxisArray, AxisArray, RangedAggregateTransformer
+    ]
+):
+    SETTINGS = RangedAggregateSettings
 
 
-class RangedAggregate(GenAxisArray):
-    """
-    Unit for :obj:`ranged_aggregate`
+def ranged_aggregate(
+    axis: str | None = None,
+    bands: list[tuple[float, float]] | None = None,
+    operation: AggregationFunction = AggregationFunction.MEAN,
+) -> RangedAggregateTransformer:
     """
+    Apply an aggregation operation over one or more bands.
 
-    SETTINGS = RangedAggregateSettings
+    Args:
+        axis: The name of the axis along which to apply the bands.
+        bands: [(band1_min, band1_max), (band2_min, band2_max), ...]
+            If not set then this acts as a passthrough node.
+        operation: :obj:`AggregationFunction` to apply to each band.
 
-    def construct_generator(self):
-        self.STATE.gen = ranged_aggregate(
-            axis=self.SETTINGS.axis,
-            bands=self.SETTINGS.bands,
-            operation=self.SETTINGS.operation,
-        )
+    Returns:
+        :obj:`RangedAggregateTransformer`
+    """
+    return RangedAggregateTransformer(
+        RangedAggregateSettings(axis=axis, bands=bands, operation=operation)
+    )
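For reference, a hedged sketch of how the new functional constructor might be used to reduce a spectrum to per-band values. The `AxisArray` / `LinearAxis` construction follows the attribute names visible in this diff (`gain`, `offset`, `unit`, `key`), but the exact constructor signatures live in ezmsg-util and are an assumption here; the band edges and array shapes are arbitrary example values.

# Hypothetical usage of ranged_aggregate from ezmsg-sigproc 2.0.0.
import numpy as np
from ezmsg.util.messages.axisarray import AxisArray

from ezmsg.sigproc.aggregate import AggregationFunction, ranged_aggregate

proc = ranged_aggregate(
    axis="freq",
    bands=[(1.0, 4.0), (4.0, 8.0), (8.0, 13.0), (13.0, 30.0)],
    operation=AggregationFunction.MEAN,
)

spectrum = AxisArray(
    data=np.random.rand(32, 128),  # (channels, frequency bins)
    dims=["ch", "freq"],
    axes={"freq": AxisArray.LinearAxis(unit="Hz", gain=0.5, offset=0.0)},
    key="example",
)

banded = proc(spectrum)   # transformer instances are callable
print(banded.data.shape)  # one aggregated value per band along "freq"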