ezmsg-sigproc 2.5.0__py3-none-any.whl → 2.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. ezmsg/sigproc/__version__.py +2 -2
  2. ezmsg/sigproc/activation.py +5 -11
  3. ezmsg/sigproc/adaptive_lattice_notch.py +11 -30
  4. ezmsg/sigproc/affinetransform.py +16 -42
  5. ezmsg/sigproc/aggregate.py +17 -34
  6. ezmsg/sigproc/bandpower.py +12 -20
  7. ezmsg/sigproc/base.py +141 -1276
  8. ezmsg/sigproc/butterworthfilter.py +8 -16
  9. ezmsg/sigproc/butterworthzerophase.py +7 -16
  10. ezmsg/sigproc/cheby.py +4 -10
  11. ezmsg/sigproc/combfilter.py +5 -8
  12. ezmsg/sigproc/coordinatespaces.py +142 -0
  13. ezmsg/sigproc/decimate.py +3 -7
  14. ezmsg/sigproc/denormalize.py +6 -11
  15. ezmsg/sigproc/detrend.py +3 -4
  16. ezmsg/sigproc/diff.py +8 -17
  17. ezmsg/sigproc/downsample.py +11 -20
  18. ezmsg/sigproc/ewma.py +11 -28
  19. ezmsg/sigproc/ewmfilter.py +1 -1
  20. ezmsg/sigproc/extract_axis.py +3 -4
  21. ezmsg/sigproc/fbcca.py +34 -59
  22. ezmsg/sigproc/filter.py +19 -45
  23. ezmsg/sigproc/filterbank.py +37 -74
  24. ezmsg/sigproc/filterbankdesign.py +7 -14
  25. ezmsg/sigproc/fir_hilbert.py +13 -30
  26. ezmsg/sigproc/fir_pmc.py +5 -10
  27. ezmsg/sigproc/firfilter.py +12 -14
  28. ezmsg/sigproc/gaussiansmoothing.py +5 -9
  29. ezmsg/sigproc/kaiser.py +11 -15
  30. ezmsg/sigproc/math/abs.py +4 -3
  31. ezmsg/sigproc/math/add.py +121 -0
  32. ezmsg/sigproc/math/clip.py +4 -1
  33. ezmsg/sigproc/math/difference.py +100 -36
  34. ezmsg/sigproc/math/invert.py +3 -3
  35. ezmsg/sigproc/math/log.py +5 -6
  36. ezmsg/sigproc/math/scale.py +2 -0
  37. ezmsg/sigproc/messages.py +1 -2
  38. ezmsg/sigproc/quantize.py +3 -6
  39. ezmsg/sigproc/resample.py +17 -38
  40. ezmsg/sigproc/rollingscaler.py +12 -37
  41. ezmsg/sigproc/sampler.py +19 -37
  42. ezmsg/sigproc/scaler.py +11 -22
  43. ezmsg/sigproc/signalinjector.py +7 -18
  44. ezmsg/sigproc/slicer.py +14 -34
  45. ezmsg/sigproc/spectral.py +3 -3
  46. ezmsg/sigproc/spectrogram.py +12 -19
  47. ezmsg/sigproc/spectrum.py +17 -38
  48. ezmsg/sigproc/transpose.py +12 -24
  49. ezmsg/sigproc/util/asio.py +25 -156
  50. ezmsg/sigproc/util/axisarray_buffer.py +12 -26
  51. ezmsg/sigproc/util/buffer.py +22 -43
  52. ezmsg/sigproc/util/message.py +17 -31
  53. ezmsg/sigproc/util/profile.py +23 -174
  54. ezmsg/sigproc/util/sparse.py +7 -15
  55. ezmsg/sigproc/util/typeresolution.py +17 -83
  56. ezmsg/sigproc/wavelets.py +10 -19
  57. ezmsg/sigproc/window.py +29 -83
  58. ezmsg_sigproc-2.7.0.dist-info/METADATA +60 -0
  59. ezmsg_sigproc-2.7.0.dist-info/RECORD +64 -0
  60. ezmsg/sigproc/synth.py +0 -774
  61. ezmsg_sigproc-2.5.0.dist-info/METADATA +0 -72
  62. ezmsg_sigproc-2.5.0.dist-info/RECORD +0 -63
  63. {ezmsg_sigproc-2.5.0.dist-info → ezmsg_sigproc-2.7.0.dist-info}/WHEEL +0 -0
  64. /ezmsg_sigproc-2.5.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.7.0.dist-info/licenses/LICENSE +0 -0
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '2.5.0'
32
- __version_tuple__ = version_tuple = (2, 5, 0)
31
+ __version__ = version = '2.7.0'
32
+ __version_tuple__ = version_tuple = (2, 7, 0)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -1,10 +1,10 @@
1
- import scipy.special
2
1
  import ezmsg.core as ez
2
+ import scipy.special
3
+ from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
3
4
  from ezmsg.util.messages.axisarray import AxisArray
4
5
  from ezmsg.util.messages.util import replace
5
6
 
6
7
  from .spectral import OptionsEnum
7
- from .base import BaseTransformer, BaseTransformerUnit
8
8
 
9
9
 
10
10
  class ActivationFunction(OptionsEnum):
@@ -50,20 +50,14 @@ class ActivationTransformer(BaseTransformer[ActivationSettings, AxisArray, AxisA
50
50
  # str type handling
51
51
  function = self.settings.function.lower()
52
52
  if function not in ActivationFunction.options():
53
- raise ValueError(
54
- f"Unrecognized activation function {function}. Must be one of {ACTIVATIONS.keys()}"
55
- )
56
- function = list(ACTIVATIONS.keys())[
57
- ActivationFunction.options().index(function)
58
- ]
53
+ raise ValueError(f"Unrecognized activation function {function}. Must be one of {ACTIVATIONS.keys()}")
54
+ function = list(ACTIVATIONS.keys())[ActivationFunction.options().index(function)]
59
55
  func = ACTIVATIONS[function]
60
56
 
61
57
  return replace(message, data=func(message.data))
62
58
 
63
59
 
64
- class Activation(
65
- BaseTransformerUnit[ActivationSettings, AxisArray, AxisArray, ActivationTransformer]
66
- ):
60
+ class Activation(BaseTransformerUnit[ActivationSettings, AxisArray, AxisArray, ActivationTransformer]):
67
61
  SETTINGS = ActivationSettings
68
62
 
69
63
 
@@ -1,12 +1,11 @@
1
+ import ezmsg.core as ez
1
2
  import numpy as np
2
3
  import numpy.typing as npt
3
4
  import scipy.signal
4
- import ezmsg.core as ez
5
+ from ezmsg.baseproc import BaseStatefulTransformer, processor_state
5
6
  from ezmsg.util.messages.axisarray import AxisArray, CoordinateAxis
6
7
  from ezmsg.util.messages.util import replace
7
8
 
8
- from .base import processor_state, BaseStatefulTransformer
9
-
10
9
 
11
10
  class AdaptiveLatticeNotchFilterSettings(ez.Settings):
12
11
  """Settings for the Adaptive Lattice Notch Filter."""
@@ -76,9 +75,7 @@ class AdaptiveLatticeNotchFilterTransformer(
76
75
 
77
76
  fs = 1 / message.axes[self.settings.axis].gain
78
77
  init_f = (
79
- self.settings.init_notch_freq
80
- if self.settings.init_notch_freq is not None
81
- else 0.07178314656435313 * fs
78
+ self.settings.init_notch_freq if self.settings.init_notch_freq is not None else 0.07178314656435313 * fs
82
79
  )
83
80
  init_omega = init_f * (2 * np.pi) / fs
84
81
  init_k1 = -np.cos(init_omega)
@@ -91,9 +88,7 @@ class AdaptiveLatticeNotchFilterTransformer(
91
88
  self._state.k1 = init_k1 + np.zeros(sample_shape, dtype=float)
92
89
  self._state.freq_template = CoordinateAxis(
93
90
  data=np.zeros((0,) + sample_shape, dtype=float),
94
- dims=[self.settings.axis]
95
- + message.dims[:ax_idx]
96
- + message.dims[ax_idx + 1 :],
91
+ dims=[self.settings.axis] + message.dims[:ax_idx] + message.dims[ax_idx + 1 :],
97
92
  unit="Hz",
98
93
  )
99
94
 
@@ -147,9 +142,7 @@ class AdaptiveLatticeNotchFilterTransformer(
147
142
  for ix, k in enumerate(self._state.k1.flatten()):
148
143
  # Filter to get s_n (notch filter state)
149
144
  a_s = [1, k * gamma_plus_1, gamma]
150
- s_n[:, ix], self._state.zi[:, ix] = scipy.signal.lfilter(
151
- [1], a_s, _x[:, ix], zi=self._state.zi[:, ix]
152
- )
145
+ s_n[:, ix], self._state.zi[:, ix] = scipy.signal.lfilter([1], a_s, _x[:, ix], zi=self._state.zi[:, ix])
153
146
 
154
147
  # Apply output filter to get y_out
155
148
  b_y = [1, 2 * k, 1]
@@ -159,17 +152,11 @@ class AdaptiveLatticeNotchFilterTransformer(
159
152
  s_n_reshaped = s_n.reshape((s_n.shape[0],) + x_data.shape[1:])
160
153
  s_final = s_n_reshaped[-1] # Current s_n
161
154
  s_final_1 = s_n_reshaped[-2] # s_n_1
162
- s_final_2 = (
163
- s_n_reshaped[-3] if len(s_n_reshaped) > 2 else self._state.s_history[0]
164
- ) # s_n_2
155
+ s_final_2 = s_n_reshaped[-3] if len(s_n_reshaped) > 2 else self._state.s_history[0] # s_n_2
165
156
 
166
157
  # Update p and q using final values
167
- self._state.p = eta * self._state.p + one_minus_eta * (
168
- s_final_1 * (s_final + s_final_2)
169
- )
170
- self._state.q = eta * self._state.q + one_minus_eta * (
171
- 2 * (s_final_1 * s_final_1)
172
- )
158
+ self._state.p = eta * self._state.p + one_minus_eta * (s_final_1 * (s_final + s_final_2))
159
+ self._state.q = eta * self._state.q + one_minus_eta * (2 * (s_final_1 * s_final_1))
173
160
 
174
161
  # Update reflection coefficient
175
162
  new_k1 = -self._state.p / (self._state.q + 1e-8) # Avoid division by zero
@@ -199,17 +186,11 @@ class AdaptiveLatticeNotchFilterTransformer(
199
186
  y_out[sample_ix] = s_n + 2 * self._state.k1 * s_n_1 + s_n_2
200
187
 
201
188
  # Update filter parameters
202
- self._state.p = eta * self._state.p + one_minus_eta * (
203
- s_n_1 * (s_n + s_n_2)
204
- )
205
- self._state.q = eta * self._state.q + one_minus_eta * (
206
- 2 * (s_n_1 * s_n_1)
207
- )
189
+ self._state.p = eta * self._state.p + one_minus_eta * (s_n_1 * (s_n + s_n_2))
190
+ self._state.q = eta * self._state.q + one_minus_eta * (2 * (s_n_1 * s_n_1))
208
191
 
209
192
  # Update reflection coefficient
210
- new_k1 = -self._state.p / (
211
- self._state.q + 1e-8
212
- ) # Avoid division by zero
193
+ new_k1 = -self._state.p / (self._state.q + 1e-8) # Avoid division by zero
213
194
  new_k1 = np.clip(new_k1, -1, 1) # Clip to prevent instability
214
195
  self._state.k1 = mu * self._state.k1 + one_minus_mu * new_k1 # Smoothed
215
196
 
@@ -1,18 +1,17 @@
1
1
  import os
2
2
  from pathlib import Path
3
3
 
4
+ import ezmsg.core as ez
4
5
  import numpy as np
5
6
  import numpy.typing as npt
6
- import ezmsg.core as ez
7
- from ezmsg.util.messages.axisarray import AxisArray, AxisBase
8
- from ezmsg.util.messages.util import replace
9
-
10
- from .base import (
7
+ from ezmsg.baseproc import (
11
8
  BaseStatefulTransformer,
12
- BaseTransformerUnit,
13
9
  BaseTransformer,
10
+ BaseTransformerUnit,
14
11
  processor_state,
15
12
  )
13
+ from ezmsg.util.messages.axisarray import AxisArray, AxisBase
14
+ from ezmsg.util.messages.util import replace
16
15
 
17
16
 
18
17
  class AffineTransformSettings(ez.Settings):
@@ -38,15 +37,12 @@ class AffineTransformState:
38
37
 
39
38
 
40
39
  class AffineTransformTransformer(
41
- BaseStatefulTransformer[
42
- AffineTransformSettings, AxisArray, AxisArray, AffineTransformState
43
- ]
40
+ BaseStatefulTransformer[AffineTransformSettings, AxisArray, AxisArray, AffineTransformState]
44
41
  ):
45
42
  def __call__(self, message: AxisArray) -> AxisArray:
46
43
  # Override __call__ so we can shortcut if weights are None.
47
44
  if self.settings.weights is None or (
48
- isinstance(self.settings.weights, str)
49
- and self.settings.weights == "passthrough"
45
+ isinstance(self.settings.weights, str) and self.settings.weights == "passthrough"
50
46
  ):
51
47
  return message
52
48
  return super().__call__(message)
@@ -68,18 +64,12 @@ class AffineTransformTransformer(
68
64
  self._state.weights = weights
69
65
 
70
66
  axis = self.settings.axis or message.dims[-1]
71
- if (
72
- axis in message.axes
73
- and hasattr(message.axes[axis], "data")
74
- and weights.shape[0] != weights.shape[1]
75
- ):
67
+ if axis in message.axes and hasattr(message.axes[axis], "data") and weights.shape[0] != weights.shape[1]:
76
68
  in_labels = message.axes[axis].data
77
69
  new_labels = []
78
70
  n_in, n_out = weights.shape
79
71
  if len(in_labels) != n_in:
80
- ez.logger.warning(
81
- f"Received {len(in_labels)} for {n_in} inputs. Check upstream labels."
82
- )
72
+ ez.logger.warning(f"Received {len(in_labels)} for {n_in} inputs. Check upstream labels.")
83
73
  else:
84
74
  b_filled_outputs = np.any(weights, axis=0)
85
75
  b_used_inputs = np.any(weights, axis=1)
@@ -97,9 +87,7 @@ class AffineTransformTransformer(
97
87
  elif np.all(b_filled_outputs):
98
88
  new_labels = np.array(in_labels)[b_used_inputs]
99
89
 
100
- self._state.new_axis = replace(
101
- message.axes[axis], data=np.array(new_labels)
102
- )
90
+ self._state.new_axis = replace(message.axes[axis], data=np.array(new_labels))
103
91
 
104
92
  def _process(self, message: AxisArray) -> AxisArray:
105
93
  axis = self.settings.axis or message.dims[-1]
@@ -110,9 +98,7 @@ class AffineTransformTransformer(
110
98
  # The weights are stacked A|B where A is the transform and B is a single row
111
99
  # in the equation y = Ax + B. This supports NeuroKey's weights matrices.
112
100
  sample_shape = data.shape[:axis_idx] + (1,) + data.shape[axis_idx + 1 :]
113
- data = np.concatenate(
114
- (data, np.ones(sample_shape).astype(data.dtype)), axis=axis_idx
115
- )
101
+ data = np.concatenate((data, np.ones(sample_shape).astype(data.dtype)), axis=axis_idx)
116
102
 
117
103
  if axis_idx in [-1, len(message.dims) - 1]:
118
104
  data = np.matmul(data, self._state.weights)
@@ -128,11 +114,7 @@ class AffineTransformTransformer(
128
114
  return replace(message, **replace_kwargs)
129
115
 
130
116
 
131
- class AffineTransform(
132
- BaseTransformerUnit[
133
- AffineTransformSettings, AxisArray, AxisArray, AffineTransformTransformer
134
- ]
135
- ):
117
+ class AffineTransform(BaseTransformerUnit[AffineTransformSettings, AxisArray, AxisArray, AffineTransformTransformer]):
136
118
  SETTINGS = AffineTransformSettings
137
119
 
138
120
 
@@ -153,9 +135,7 @@ def affine_transform(
153
135
  :obj:`AffineTransformTransformer`.
154
136
  """
155
137
  return AffineTransformTransformer(
156
- AffineTransformSettings(
157
- weights=weights, axis=axis, right_multiply=right_multiply
158
- )
138
+ AffineTransformSettings(weights=weights, axis=axis, right_multiply=right_multiply)
159
139
  )
160
140
 
161
141
 
@@ -178,9 +158,7 @@ class CommonRereferenceSettings(ez.Settings):
178
158
  """Set False to exclude each channel from participating in the calculation of its reference."""
179
159
 
180
160
 
181
- class CommonRereferenceTransformer(
182
- BaseTransformer[CommonRereferenceSettings, AxisArray, AxisArray]
183
- ):
161
+ class CommonRereferenceTransformer(BaseTransformer[CommonRereferenceSettings, AxisArray, AxisArray]):
184
162
  def _process(self, message: AxisArray) -> AxisArray:
185
163
  if self.settings.mode == "passthrough":
186
164
  return message
@@ -188,9 +166,7 @@ class CommonRereferenceTransformer(
188
166
  axis = self.settings.axis or message.dims[-1]
189
167
  axis_idx = message.get_axis_idx(axis)
190
168
 
191
- func = {"mean": np.mean, "median": np.median, "passthrough": zeros_for_noop}[
192
- self.settings.mode
193
- ]
169
+ func = {"mean": np.mean, "median": np.median, "passthrough": zeros_for_noop}[self.settings.mode]
194
170
 
195
171
  ref_data = func(message.data, axis=axis_idx, keepdims=True)
196
172
 
@@ -213,9 +189,7 @@ class CommonRereferenceTransformer(
213
189
 
214
190
 
215
191
  class CommonRereference(
216
- BaseTransformerUnit[
217
- CommonRereferenceSettings, AxisArray, AxisArray, CommonRereferenceTransformer
218
- ]
192
+ BaseTransformerUnit[CommonRereferenceSettings, AxisArray, AxisArray, CommonRereferenceTransformer]
219
193
  ):
220
194
  SETTINGS = CommonRereferenceSettings
221
195
 
@@ -1,23 +1,23 @@
1
- from array_api_compat import get_namespace
2
1
  import typing
3
2
 
3
+ import ezmsg.core as ez
4
4
  import numpy as np
5
5
  import numpy.typing as npt
6
- import ezmsg.core as ez
6
+ from array_api_compat import get_namespace
7
+ from ezmsg.baseproc import (
8
+ BaseStatefulTransformer,
9
+ BaseTransformer,
10
+ BaseTransformerUnit,
11
+ processor_state,
12
+ )
7
13
  from ezmsg.util.messages.axisarray import (
8
14
  AxisArray,
9
- slice_along_axis,
10
15
  AxisBase,
11
16
  replace,
17
+ slice_along_axis,
12
18
  )
13
19
 
14
20
  from .spectral import OptionsEnum
15
- from .base import (
16
- BaseTransformer,
17
- BaseStatefulTransformer,
18
- BaseTransformerUnit,
19
- processor_state,
20
- )
21
21
 
22
22
 
23
23
  class AggregationFunction(OptionsEnum):
@@ -89,9 +89,7 @@ class RangedAggregateState:
89
89
 
90
90
 
91
91
  class RangedAggregateTransformer(
92
- BaseStatefulTransformer[
93
- RangedAggregateSettings, AxisArray, AxisArray, RangedAggregateState
94
- ]
92
+ BaseStatefulTransformer[RangedAggregateSettings, AxisArray, AxisArray, RangedAggregateState]
95
93
  ):
96
94
  def __call__(self, message: AxisArray) -> AxisArray:
97
95
  # Override for shortcut passthrough mode.
@@ -118,16 +116,12 @@ class RangedAggregateTransformer(
118
116
  if hasattr(target_axis, "data"):
119
117
  self._state.ax_vec = target_axis.data
120
118
  else:
121
- self._state.ax_vec = target_axis.value(
122
- np.arange(message.data.shape[ax_idx])
123
- )
119
+ self._state.ax_vec = target_axis.value(np.arange(message.data.shape[ax_idx]))
124
120
 
125
121
  ax_dat = []
126
122
  slices = []
127
123
  for start, stop in self.settings.bands:
128
- inds = np.where(
129
- np.logical_and(self._state.ax_vec >= start, self._state.ax_vec <= stop)
130
- )[0]
124
+ inds = np.where(np.logical_and(self._state.ax_vec >= start, self._state.ax_vec <= stop))[0]
131
125
  slices.append(np.s_[inds[0] : inds[-1] + 1])
132
126
  if hasattr(target_axis, "data"):
133
127
  if self._state.ax_vec.dtype.type is np.str_:
@@ -164,8 +158,7 @@ class RangedAggregateTransformer(
164
158
  ]
165
159
  else:
166
160
  out_data = [
167
- agg_func(slice_along_axis(message.data, sl, axis=ax_idx), axis=ax_idx)
168
- for sl in self._state.slices
161
+ agg_func(slice_along_axis(message.data, sl, axis=ax_idx), axis=ax_idx) for sl in self._state.slices
169
162
  ]
170
163
 
171
164
  msg_out = replace(
@@ -187,11 +180,7 @@ class RangedAggregateTransformer(
187
180
  return msg_out
188
181
 
189
182
 
190
- class RangedAggregate(
191
- BaseTransformerUnit[
192
- RangedAggregateSettings, AxisArray, AxisArray, RangedAggregateTransformer
193
- ]
194
- ):
183
+ class RangedAggregate(BaseTransformerUnit[RangedAggregateSettings, AxisArray, AxisArray, RangedAggregateTransformer]):
195
184
  SETTINGS = RangedAggregateSettings
196
185
 
197
186
 
@@ -212,9 +201,7 @@ def ranged_aggregate(
212
201
  Returns:
213
202
  :obj:`RangedAggregateTransformer`
214
203
  """
215
- return RangedAggregateTransformer(
216
- RangedAggregateSettings(axis=axis, bands=bands, operation=operation)
217
- )
204
+ return RangedAggregateTransformer(RangedAggregateSettings(axis=axis, bands=bands, operation=operation))
218
205
 
219
206
 
220
207
  class AggregateSettings(ez.Settings):
@@ -242,9 +229,7 @@ class AggregateTransformer(BaseTransformer[AggregateSettings, AxisArray, AxisArr
242
229
  op = self.settings.operation
243
230
 
244
231
  if op == AggregationFunction.NONE:
245
- raise ValueError(
246
- "AggregationFunction.NONE is not supported for full-axis aggregation"
247
- )
232
+ raise ValueError("AggregationFunction.NONE is not supported for full-axis aggregation")
248
233
 
249
234
  if op == AggregationFunction.TRAPEZOID:
250
235
  # Trapezoid integration requires x-coordinates
@@ -276,9 +261,7 @@ class AggregateTransformer(BaseTransformer[AggregateSettings, AxisArray, AxisArr
276
261
  )
277
262
 
278
263
 
279
- class AggregateUnit(
280
- BaseTransformerUnit[AggregateSettings, AxisArray, AxisArray, AggregateTransformer]
281
- ):
264
+ class AggregateUnit(BaseTransformerUnit[AggregateSettings, AxisArray, AxisArray, AggregateTransformer]):
282
265
  """Unit that aggregates an entire axis using a specified operation."""
283
266
 
284
267
  SETTINGS = AggregateSettings
@@ -1,20 +1,20 @@
1
1
  from dataclasses import field
2
2
 
3
3
  import ezmsg.core as ez
4
+ from ezmsg.baseproc import (
5
+ BaseProcessor,
6
+ BaseStatefulProcessor,
7
+ BaseTransformerUnit,
8
+ CompositeProcessor,
9
+ )
4
10
  from ezmsg.util.messages.axisarray import AxisArray
5
11
 
6
- from .spectrogram import SpectrogramSettings, SpectrogramTransformer
7
12
  from .aggregate import (
8
13
  AggregationFunction,
9
- RangedAggregateTransformer,
10
14
  RangedAggregateSettings,
15
+ RangedAggregateTransformer,
11
16
  )
12
- from .base import (
13
- BaseProcessor,
14
- CompositeProcessor,
15
- BaseStatefulProcessor,
16
- BaseTransformerUnit,
17
- )
17
+ from .spectrogram import SpectrogramSettings, SpectrogramTransformer
18
18
 
19
19
 
20
20
  class BandPowerSettings(ez.Settings):
@@ -22,16 +22,12 @@ class BandPowerSettings(ez.Settings):
22
22
  Settings for ``BandPower``.
23
23
  """
24
24
 
25
- spectrogram_settings: SpectrogramSettings = field(
26
- default_factory=SpectrogramSettings
27
- )
25
+ spectrogram_settings: SpectrogramSettings = field(default_factory=SpectrogramSettings)
28
26
  """
29
27
  Settings for spectrogram calculation.
30
28
  """
31
29
 
32
- bands: list[tuple[float, float]] | None = field(
33
- default_factory=lambda: [(17, 30), (70, 170)]
34
- )
30
+ bands: list[tuple[float, float]] | None = field(default_factory=lambda: [(17, 30), (70, 170)])
35
31
  """
36
32
  (min, max) tuples of band limits in Hz.
37
33
  """
@@ -46,9 +42,7 @@ class BandPowerTransformer(CompositeProcessor[BandPowerSettings, AxisArray, Axis
46
42
  settings: BandPowerSettings,
47
43
  ) -> dict[str, BaseProcessor | BaseStatefulProcessor]:
48
44
  return {
49
- "spectrogram": SpectrogramTransformer(
50
- settings=settings.spectrogram_settings
51
- ),
45
+ "spectrogram": SpectrogramTransformer(settings=settings.spectrogram_settings),
52
46
  "aggregate": RangedAggregateTransformer(
53
47
  settings=RangedAggregateSettings(
54
48
  axis="freq",
@@ -59,9 +53,7 @@ class BandPowerTransformer(CompositeProcessor[BandPowerSettings, AxisArray, Axis
59
53
  }
60
54
 
61
55
 
62
- class BandPower(
63
- BaseTransformerUnit[BandPowerSettings, AxisArray, AxisArray, BandPowerTransformer]
64
- ):
56
+ class BandPower(BaseTransformerUnit[BandPowerSettings, AxisArray, AxisArray, BandPowerTransformer]):
65
57
  SETTINGS = BandPowerSettings
66
58
 
67
59