ezmsg-sigproc 2.5.0__py3-none-any.whl → 2.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. ezmsg/sigproc/__version__.py +2 -2
  2. ezmsg/sigproc/activation.py +5 -11
  3. ezmsg/sigproc/adaptive_lattice_notch.py +11 -29
  4. ezmsg/sigproc/affinetransform.py +13 -38
  5. ezmsg/sigproc/aggregate.py +13 -30
  6. ezmsg/sigproc/bandpower.py +7 -15
  7. ezmsg/sigproc/base.py +141 -1276
  8. ezmsg/sigproc/butterworthfilter.py +8 -16
  9. ezmsg/sigproc/butterworthzerophase.py +7 -16
  10. ezmsg/sigproc/cheby.py +4 -10
  11. ezmsg/sigproc/combfilter.py +5 -8
  12. ezmsg/sigproc/decimate.py +2 -6
  13. ezmsg/sigproc/denormalize.py +6 -11
  14. ezmsg/sigproc/detrend.py +3 -4
  15. ezmsg/sigproc/diff.py +8 -17
  16. ezmsg/sigproc/downsample.py +6 -14
  17. ezmsg/sigproc/ewma.py +11 -27
  18. ezmsg/sigproc/ewmfilter.py +1 -1
  19. ezmsg/sigproc/extract_axis.py +3 -4
  20. ezmsg/sigproc/fbcca.py +31 -56
  21. ezmsg/sigproc/filter.py +19 -45
  22. ezmsg/sigproc/filterbank.py +33 -70
  23. ezmsg/sigproc/filterbankdesign.py +5 -12
  24. ezmsg/sigproc/fir_hilbert.py +13 -30
  25. ezmsg/sigproc/fir_pmc.py +5 -10
  26. ezmsg/sigproc/firfilter.py +12 -14
  27. ezmsg/sigproc/gaussiansmoothing.py +5 -9
  28. ezmsg/sigproc/kaiser.py +11 -15
  29. ezmsg/sigproc/math/abs.py +1 -3
  30. ezmsg/sigproc/math/add.py +121 -0
  31. ezmsg/sigproc/math/clip.py +1 -1
  32. ezmsg/sigproc/math/difference.py +98 -36
  33. ezmsg/sigproc/math/invert.py +1 -3
  34. ezmsg/sigproc/math/log.py +2 -6
  35. ezmsg/sigproc/messages.py +1 -2
  36. ezmsg/sigproc/quantize.py +2 -4
  37. ezmsg/sigproc/resample.py +13 -34
  38. ezmsg/sigproc/rollingscaler.py +12 -37
  39. ezmsg/sigproc/sampler.py +17 -35
  40. ezmsg/sigproc/scaler.py +8 -18
  41. ezmsg/sigproc/signalinjector.py +6 -16
  42. ezmsg/sigproc/slicer.py +9 -28
  43. ezmsg/sigproc/spectral.py +3 -3
  44. ezmsg/sigproc/spectrogram.py +12 -19
  45. ezmsg/sigproc/spectrum.py +12 -32
  46. ezmsg/sigproc/transpose.py +7 -18
  47. ezmsg/sigproc/util/asio.py +25 -156
  48. ezmsg/sigproc/util/axisarray_buffer.py +10 -26
  49. ezmsg/sigproc/util/buffer.py +18 -43
  50. ezmsg/sigproc/util/message.py +17 -31
  51. ezmsg/sigproc/util/profile.py +23 -174
  52. ezmsg/sigproc/util/sparse.py +5 -15
  53. ezmsg/sigproc/util/typeresolution.py +17 -83
  54. ezmsg/sigproc/wavelets.py +6 -15
  55. ezmsg/sigproc/window.py +24 -78
  56. {ezmsg_sigproc-2.5.0.dist-info → ezmsg_sigproc-2.6.0.dist-info}/METADATA +3 -2
  57. ezmsg_sigproc-2.6.0.dist-info/RECORD +63 -0
  58. ezmsg/sigproc/synth.py +0 -774
  59. ezmsg_sigproc-2.5.0.dist-info/RECORD +0 -63
  60. {ezmsg_sigproc-2.5.0.dist-info → ezmsg_sigproc-2.6.0.dist-info}/WHEEL +0 -0
  61. /ezmsg_sigproc-2.5.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.6.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/messages.py CHANGED
@@ -1,10 +1,9 @@
1
- import warnings
2
1
  import time
2
+ import warnings
3
3
 
4
4
  import numpy.typing as npt
5
5
  from ezmsg.util.messages.axisarray import AxisArray
6
6
 
7
-
8
7
  # UPCOMING: TSMessage Deprecation
9
8
  # TSMessage is deprecated because it doesn't handle multiple time axes well.
10
9
  # AxisArray has an incompatible API but supports a superset of functionality.
ezmsg/sigproc/quantize.py CHANGED
@@ -1,5 +1,5 @@
1
- import numpy as np
2
1
  import ezmsg.core as ez
2
+ import numpy as np
3
3
  from ezmsg.util.messages.axisarray import AxisArray, replace
4
4
 
5
5
  from .base import BaseTransformer, BaseTransformerUnit
@@ -65,7 +65,5 @@ class QuantizeTransformer(BaseTransformer[QuantizeSettings, AxisArray, AxisArray
65
65
  return replace(message, data=data)
66
66
 
67
67
 
68
- class QuantizerUnit(
69
- BaseTransformerUnit[QuantizeSettings, AxisArray, AxisArray, QuantizeTransformer]
70
- ):
68
+ class QuantizerUnit(BaseTransformerUnit[QuantizeSettings, AxisArray, AxisArray, QuantizeTransformer]):
71
69
  SETTINGS = QuantizeSettings
ezmsg/sigproc/resample.py CHANGED
@@ -2,15 +2,15 @@ import asyncio
2
2
  import math
3
3
  import time
4
4
 
5
+ import ezmsg.core as ez
5
6
  import numpy as np
6
7
  import scipy.interpolate
7
- import ezmsg.core as ez
8
8
  from ezmsg.util.messages.axisarray import AxisArray, LinearAxis
9
9
  from ezmsg.util.messages.util import replace
10
10
 
11
11
  from .base import (
12
- BaseStatefulProcessor,
13
12
  BaseConsumerUnit,
13
+ BaseStatefulProcessor,
14
14
  processor_state,
15
15
  )
16
16
  from .util.axisarray_buffer import HybridAxisArrayBuffer, HybridAxisBuffer
@@ -29,7 +29,7 @@ class ResampleSettings(ez.Settings):
29
29
  fill_value: str = "extrapolate"
30
30
  """
31
31
  Value to use for out-of-bounds samples.
32
- If 'extrapolate', the transformer will extrapolate.
32
+ If 'extrapolate', the transformer will extrapolate.
33
33
  If 'last', the transformer will use the last sample.
34
34
  See scipy.interpolate.interp1d for more options.
35
35
  """
@@ -57,9 +57,9 @@ class ResampleState:
57
57
  """
58
58
  The buffer for the reference axis (usually a time axis). The interpolation function
59
59
  will be evaluated at the reference axis values.
60
- When resample_rate is None, this buffer will be filled with the axis from incoming
60
+ When resample_rate is None, this buffer will be filled with the axis from incoming
61
61
  _reference_ messages.
62
- When resample_rate is not None (i.e., prescribed float resample_rate), this buffer
62
+ When resample_rate is not None (i.e., prescribed float resample_rate), this buffer
63
63
  is filled with a synthetic axis that is generated from the incoming signal messages.
64
64
  """
65
65
 
@@ -67,7 +67,7 @@ class ResampleState:
67
67
  """
68
68
  The last value of the reference axis that was returned. This helps us to know
69
69
  what the _next_ returned value should be, and to avoid returning the same value.
70
- TODO: We can eliminate this variable if we maintain "by convention" that the
70
+ TODO: We can eliminate this variable if we maintain "by convention" that the
71
71
  reference axis always has 1 value at its start that we exclude from the resampling.
72
72
  """
73
73
 
@@ -79,9 +79,7 @@ class ResampleState:
79
79
  """
80
80
 
81
81
 
82
- class ResampleProcessor(
83
- BaseStatefulProcessor[ResampleSettings, AxisArray, AxisArray, ResampleState]
84
- ):
82
+ class ResampleProcessor(BaseStatefulProcessor[ResampleSettings, AxisArray, AxisArray, ResampleState]):
85
83
  def _hash_message(self, message: AxisArray) -> int:
86
84
  ax_idx: int = message.get_axis_idx(self.settings.axis)
87
85
  sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
@@ -135,17 +133,11 @@ class ResampleProcessor(
135
133
  ax_idx = message.get_axis_idx(self.settings.axis)
136
134
  if self.settings.resample_rate is not None and message.data.shape[ax_idx] > 0:
137
135
  in_ax = message.axes[self.settings.axis]
138
- in_t_end = (
139
- in_ax.data[-1]
140
- if hasattr(in_ax, "data")
141
- else in_ax.value(message.data.shape[ax_idx] - 1)
142
- )
136
+ in_t_end = in_ax.data[-1] if hasattr(in_ax, "data") else in_ax.value(message.data.shape[ax_idx] - 1)
143
137
  out_gain = 1 / self.settings.resample_rate
144
138
  prev_t_end = self.state.last_ref_ax_val
145
139
  n_synth = math.ceil((in_t_end - prev_t_end) * self.settings.resample_rate)
146
- synth_ref_axis = LinearAxis(
147
- unit="s", gain=out_gain, offset=prev_t_end + out_gain
148
- )
140
+ synth_ref_axis = LinearAxis(unit="s", gain=out_gain, offset=prev_t_end + out_gain)
149
141
  self.state.ref_axis_buffer.write(synth_ref_axis, n_samples=n_synth)
150
142
 
151
143
  self.state.last_write_time = time.time()
@@ -193,11 +185,7 @@ class ResampleProcessor(
193
185
  # Get source to train interpolation
194
186
  src_axarr = src.peek()
195
187
  src_axis = src_axarr.axes[self.settings.axis]
196
- x = (
197
- src_axis.data
198
- if hasattr(src_axis, "data")
199
- else src_axis.value(np.arange(src_axarr.data.shape[0]))
200
- )
188
+ x = src_axis.data if hasattr(src_axis, "data") else src_axis.value(np.arange(src_axarr.data.shape[0]))
201
189
 
202
190
  # Only resample at reference values that have not been interpolated over previously.
203
191
  b_ref = ref_xvec > self.state.last_ref_ax_val
@@ -208,11 +196,7 @@ class ResampleProcessor(
208
196
 
209
197
  if len(ref_idx) == 0:
210
198
  # Nothing to interpolate over; return empty data
211
- null_ref = (
212
- replace(ref_ax, data=ref_ax.data[:0])
213
- if hasattr(ref_ax, "data")
214
- else ref_ax
215
- )
199
+ null_ref = replace(ref_ax, data=ref_ax.data[:0]) if hasattr(ref_ax, "data") else ref_ax
216
200
  return replace(
217
201
  src_axarr,
218
202
  data=src_axarr.data[:0, ...],
@@ -222,17 +206,12 @@ class ResampleProcessor(
222
206
  xnew = ref_xvec[ref_idx]
223
207
 
224
208
  # Identify source data indices around ref tvec with some padding for better interpolation.
225
- src_start_ix = max(
226
- 0, np.where(x > xnew[0])[0][0] - 2 if np.any(x > xnew[0]) else 0
227
- )
209
+ src_start_ix = max(0, np.where(x > xnew[0])[0][0] - 2 if np.any(x > xnew[0]) else 0)
228
210
 
229
211
  x = x[src_start_ix:]
230
212
  y = src_axarr.data[src_start_ix:]
231
213
 
232
- if (
233
- isinstance(self.settings.fill_value, str)
234
- and self.settings.fill_value == "last"
235
- ):
214
+ if isinstance(self.settings.fill_value, str) and self.settings.fill_value == "last":
236
215
  fill_value = (y[0], y[-1])
237
216
  else:
238
217
  fill_value = self.settings.fill_value
ezmsg/sigproc/rollingscaler.py CHANGED
@@ -3,14 +3,15 @@ from collections import deque
3
3
  import ezmsg.core as ez
4
4
  import numpy as np
5
5
  import numpy.typing as npt
6
+ from ezmsg.util.messages.axisarray import AxisArray
7
+ from ezmsg.util.messages.util import replace
8
+
6
9
  from ezmsg.sigproc.base import (
7
10
  BaseAdaptiveTransformer,
8
11
  BaseAdaptiveTransformerUnit,
9
12
  processor_state,
10
13
  )
11
14
  from ezmsg.sigproc.sampler import SampleMessage
12
- from ezmsg.util.messages.axisarray import AxisArray
13
- from ezmsg.util.messages.util import replace
14
15
 
15
16
 
16
17
  class RollingScalerSettings(ez.Settings):
@@ -71,11 +72,7 @@ class RollingScalerState:
71
72
  min_samples: int | None = None
72
73
 
73
74
 
74
- class RollingScalerProcessor(
75
- BaseAdaptiveTransformer[
76
- RollingScalerSettings, AxisArray, AxisArray, RollingScalerState
77
- ]
78
- ):
75
+ class RollingScalerProcessor(BaseAdaptiveTransformer[RollingScalerSettings, AxisArray, AxisArray, RollingScalerState]):
79
76
  """
80
77
  Processor for rolling z-score normalization of input `AxisArray` messages.
81
78
 
@@ -119,40 +116,23 @@ class RollingScalerProcessor(
119
116
  self._state.N = 0
120
117
  self._state.M2 = np.zeros(ch)
121
118
  self._state.k_samples = (
122
- int(
123
- np.ceil(
124
- self.settings.window_size / message.axes[self.settings.axis].gain
125
- )
126
- )
119
+ int(np.ceil(self.settings.window_size / message.axes[self.settings.axis].gain))
127
120
  if self.settings.window_size is not None
128
121
  else self.settings.k_samples
129
122
  )
130
123
  if self._state.k_samples is not None and self._state.k_samples < 1:
131
- ez.logger.warning(
132
- "window_size smaller than sample gain; setting k_samples to 1."
133
- )
124
+ ez.logger.warning("window_size smaller than sample gain; setting k_samples to 1.")
134
125
  self._state.k_samples = 1
135
126
  elif self._state.k_samples is None:
136
- ez.logger.warning(
137
- "k_samples is None; z-score accumulation will be unbounded."
138
- )
127
+ ez.logger.warning("k_samples is None; z-score accumulation will be unbounded.")
139
128
  self._state.samples = deque(maxlen=self._state.k_samples)
140
129
  self._state.min_samples = (
141
- int(
142
- np.ceil(
143
- self.settings.min_seconds / message.axes[self.settings.axis].gain
144
- )
145
- )
130
+ int(np.ceil(self.settings.min_seconds / message.axes[self.settings.axis].gain))
146
131
  if self.settings.window_size is not None
147
132
  else self.settings.min_samples
148
133
  )
149
- if (
150
- self._state.k_samples is not None
151
- and self._state.min_samples > self._state.k_samples
152
- ):
153
- ez.logger.warning(
154
- "min_samples is greater than k_samples; adjusting min_samples to k_samples."
155
- )
134
+ if self._state.k_samples is not None and self._state.min_samples > self._state.k_samples:
135
+ ez.logger.warning("min_samples is greater than k_samples; adjusting min_samples to k_samples.")
156
136
  self._state.min_samples = self._state.k_samples
157
137
 
158
138
  def _add_batch_stats(self, x: npt.NDArray) -> None:
@@ -161,10 +141,7 @@ class RollingScalerProcessor(
161
141
  mean_b = np.mean(x, axis=0)
162
142
  M2_b = np.sum((x - mean_b) ** 2, axis=0)
163
143
 
164
- if (
165
- self._state.k_samples is not None
166
- and len(self._state.samples) == self._state.k_samples
167
- ):
144
+ if self._state.k_samples is not None and len(self._state.samples) == self._state.k_samples:
168
145
  n_old, mean_old, M2_old = self._state.samples.popleft()
169
146
  N_T = self._state.N
170
147
  N_new = N_T - n_old
@@ -177,9 +154,7 @@ class RollingScalerProcessor(
177
154
  delta = mean_old - self._state.mean
178
155
  self._state.N = N_new
179
156
  self._state.mean = (N_T * self._state.mean - n_old * mean_old) / N_new
180
- self._state.M2 = (
181
- self._state.M2 - M2_old - (delta * delta) * (N_T * n_old / N_new)
182
- )
157
+ self._state.M2 = self._state.M2 - M2_old - (delta * delta) * (N_T * n_old / N_new)
183
158
 
184
159
  N_A = self._state.N
185
160
  N = N_A + n_b
ezmsg/sigproc/sampler.py CHANGED
@@ -1,28 +1,28 @@
1
1
  import asyncio
2
- from collections import deque
3
2
  import copy
4
3
  import traceback
5
4
  import typing
5
+ from collections import deque
6
6
 
7
- import numpy as np
8
7
  import ezmsg.core as ez
8
+ import numpy as np
9
9
  from ezmsg.util.messages.axisarray import (
10
10
  AxisArray,
11
11
  )
12
12
  from ezmsg.util.messages.util import replace
13
13
 
14
- from .util.profile import profile_subpub
15
- from .util.axisarray_buffer import HybridAxisArrayBuffer
16
- from .util.buffer import UpdateStrategy
17
- from .util.message import SampleMessage, SampleTriggerMessage
18
14
  from .base import (
19
- BaseStatefulTransformer,
20
15
  BaseConsumerUnit,
21
- BaseTransformerUnit,
22
- BaseStatefulProducer,
23
16
  BaseProducerUnit,
17
+ BaseStatefulProducer,
18
+ BaseStatefulTransformer,
19
+ BaseTransformerUnit,
24
20
  processor_state,
25
21
  )
22
+ from .util.axisarray_buffer import HybridAxisArrayBuffer
23
+ from .util.buffer import UpdateStrategy
24
+ from .util.message import SampleMessage, SampleTriggerMessage
25
+ from .util.profile import profile_subpub
26
26
 
27
27
 
28
28
  class SamplerSettings(ez.Settings):
@@ -74,12 +74,8 @@ class SamplerState:
74
74
  triggers: deque[SampleTriggerMessage] | None = None
75
75
 
76
76
 
77
- class SamplerTransformer(
78
- BaseStatefulTransformer[SamplerSettings, AxisArray, AxisArray, SamplerState]
79
- ):
80
- def __call__(
81
- self, message: AxisArray | SampleTriggerMessage
82
- ) -> list[SampleMessage]:
77
+ class SamplerTransformer(BaseStatefulTransformer[SamplerSettings, AxisArray, AxisArray, SamplerState]):
78
+ def __call__(self, message: AxisArray | SampleTriggerMessage) -> list[SampleMessage]:
83
79
  # TODO: Currently we have a single entry point that accepts both
84
80
  # data and trigger messages and we choose a code path based on
85
81
  # the message type. However, in the future we will likely replace
@@ -99,9 +95,7 @@ class SamplerTransformer(
99
95
  # Compute hash based on message properties that require state reset
100
96
  axis = self.settings.axis or message.dims[0]
101
97
  axis_idx = message.get_axis_idx(axis)
102
- sample_shape = (
103
- message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
104
- )
98
+ sample_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
105
99
  return hash((sample_shape, message.key))
106
100
 
107
101
  def _reset_state(self, message: AxisArray) -> None:
@@ -193,20 +187,14 @@ class SamplerTransformer(
193
187
  trigger_ts: float = message.timestamp
194
188
  if not self.settings.estimate_alignment:
195
189
  # Override the trigger timestamp with the next sample's likely timestamp.
196
- trigger_ts = (
197
- self._state.buffer.axis_final_value + self._state.buffer.axis_gain
198
- )
190
+ trigger_ts = self._state.buffer.axis_final_value + self._state.buffer.axis_gain
199
191
 
200
- new_trig_msg = replace(
201
- message, timestamp=trigger_ts, period=_period, value=_value
202
- )
192
+ new_trig_msg = replace(message, timestamp=trigger_ts, period=_period, value=_value)
203
193
  self._state.triggers.append(new_trig_msg)
204
194
  return []
205
195
 
206
196
 
207
- class Sampler(
208
- BaseTransformerUnit[SamplerSettings, AxisArray, AxisArray, SamplerTransformer]
209
- ):
197
+ class Sampler(BaseTransformerUnit[SamplerSettings, AxisArray, AxisArray, SamplerTransformer]):
210
198
  SETTINGS = SamplerSettings
211
199
 
212
200
  INPUT_TRIGGER = ez.InputStream(SampleTriggerMessage)
@@ -269,19 +257,13 @@ class TriggerGeneratorState:
269
257
  output: int = 0
270
258
 
271
259
 
272
- class TriggerProducer(
273
- BaseStatefulProducer[
274
- TriggerGeneratorSettings, SampleTriggerMessage, TriggerGeneratorState
275
- ]
276
- ):
260
+ class TriggerProducer(BaseStatefulProducer[TriggerGeneratorSettings, SampleTriggerMessage, TriggerGeneratorState]):
277
261
  def _reset_state(self) -> None:
278
262
  self._state.output = 0
279
263
 
280
264
  async def _produce(self) -> SampleTriggerMessage:
281
265
  await asyncio.sleep(self.settings.publish_period)
282
- out_msg = SampleTriggerMessage(
283
- period=self.settings.period, value=self._state.output
284
- )
266
+ out_msg = SampleTriggerMessage(period=self.settings.period, value=self._state.output)
285
267
  self._state.output += 1
286
268
  return out_msg
287
269
 
ezmsg/sigproc/scaler.py CHANGED
@@ -1,27 +1,25 @@
1
1
  import typing
2
2
 
3
3
  import numpy as np
4
+ from ezmsg.util.generator import consumer
4
5
  from ezmsg.util.messages.axisarray import AxisArray
5
6
  from ezmsg.util.messages.util import replace
6
- from ezmsg.util.generator import consumer
7
7
 
8
8
  from .base import (
9
9
  BaseStatefulTransformer,
10
10
  BaseTransformerUnit,
11
11
  processor_state,
12
12
  )
13
- from .ewma import EWMATransformer, EWMASettings, _alpha_from_tau
14
13
 
15
14
  # Imports for backwards compatibility with previous module location
16
15
  from .ewma import EWMA_Deprecated as EWMA_Deprecated
17
- from .ewma import ewma_step as ewma_step
16
+ from .ewma import EWMASettings, EWMATransformer, _alpha_from_tau
18
17
  from .ewma import _tau_from_alpha as _tau_from_alpha
18
+ from .ewma import ewma_step as ewma_step
19
19
 
20
20
 
21
21
  @consumer
22
- def scaler(
23
- time_constant: float = 1.0, axis: str | None = None
24
- ) -> typing.Generator[AxisArray, AxisArray, None]:
22
+ def scaler(time_constant: float = 1.0, axis: str | None = None) -> typing.Generator[AxisArray, AxisArray, None]:
25
23
  """
26
24
  Apply the adaptive standard scaler from https://riverml.xyz/latest/api/preprocessing/AdaptiveStandardScaler/
27
25
  This is faster than :obj:`scaler_np` for single-channel data.
@@ -85,19 +83,13 @@ class AdaptiveStandardScalerTransformer(
85
83
  ]
86
84
  ):
87
85
  def _reset_state(self, message: AxisArray) -> None:
88
- self._state.samps_ewma = EWMATransformer(
89
- time_constant=self.settings.time_constant, axis=self.settings.axis
90
- )
91
- self._state.vars_sq_ewma = EWMATransformer(
92
- time_constant=self.settings.time_constant, axis=self.settings.axis
93
- )
86
+ self._state.samps_ewma = EWMATransformer(time_constant=self.settings.time_constant, axis=self.settings.axis)
87
+ self._state.vars_sq_ewma = EWMATransformer(time_constant=self.settings.time_constant, axis=self.settings.axis)
94
88
 
95
89
  def _process(self, message: AxisArray) -> AxisArray:
96
90
  # Update step
97
91
  mean_message = self._state.samps_ewma(message)
98
- var_sq_message = self._state.vars_sq_ewma(
99
- replace(message, data=message.data**2)
100
- )
92
+ var_sq_message = self._state.vars_sq_ewma(replace(message, data=message.data**2))
101
93
 
102
94
  # Get step
103
95
  varis = var_sq_message.data - mean_message.data**2
@@ -119,9 +111,7 @@ class AdaptiveStandardScaler(
119
111
 
120
112
 
121
113
  # Backwards compatibility...
122
- def scaler_np(
123
- time_constant: float = 1.0, axis: str | None = None
124
- ) -> AdaptiveStandardScalerTransformer:
114
+ def scaler_np(time_constant: float = 1.0, axis: str | None = None) -> AdaptiveStandardScalerTransformer:
125
115
  return AdaptiveStandardScalerTransformer(
126
116
  settings=AdaptiveStandardScalerSettings(time_constant=time_constant, axis=axis)
127
117
  )
ezmsg/sigproc/signalinjector.py CHANGED
@@ -1,8 +1,8 @@
1
1
  import ezmsg.core as ez
2
- from ezmsg.util.messages.axisarray import AxisArray
3
- from ezmsg.util.messages.util import replace
4
2
  import numpy as np
5
3
  import numpy.typing as npt
4
+ from ezmsg.util.messages.axisarray import AxisArray
5
+ from ezmsg.util.messages.util import replace
6
6
 
7
7
  from .base import (
8
8
  BaseAsyncTransformer,
@@ -27,15 +27,11 @@ class SignalInjectorState:
27
27
 
28
28
 
29
29
  class SignalInjectorTransformer(
30
- BaseAsyncTransformer[
31
- SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorState
32
- ]
30
+ BaseAsyncTransformer[SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorState]
33
31
  ):
34
32
  def _hash_message(self, message: AxisArray) -> int:
35
33
  time_ax_idx = message.get_axis_idx(self.settings.time_dim)
36
- sample_shape = (
37
- message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
38
- )
34
+ sample_shape = message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
39
35
  return hash((message.key,) + sample_shape)
40
36
 
41
37
  def _reset_state(self, message: AxisArray) -> None:
@@ -44,9 +40,7 @@ class SignalInjectorTransformer(
44
40
  if self._state.cur_amplitude is None:
45
41
  self._state.cur_amplitude = self.settings.amplitude
46
42
  time_ax_idx = message.get_axis_idx(self.settings.time_dim)
47
- self._state.cur_shape = (
48
- message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
49
- )
43
+ self._state.cur_shape = message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
50
44
  rng = np.random.default_rng(self.settings.mixing_seed)
51
45
  self._state.mixing = rng.random((1, message.shape2d(self.settings.time_dim)[1]))
52
46
  self._state.mixing = (self._state.mixing * 2.0) - 1.0
@@ -63,11 +57,7 @@ class SignalInjectorTransformer(
63
57
  return out_msg
64
58
 
65
59
 
66
- class SignalInjector(
67
- BaseTransformerUnit[
68
- SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorTransformer
69
- ]
70
- ):
60
+ class SignalInjector(BaseTransformerUnit[SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorTransformer]):
71
61
  SETTINGS = SignalInjectorSettings
72
62
  INPUT_FREQUENCY = ez.InputStream(float | None)
73
63
  INPUT_AMPLITUDE = ez.InputStream(float)
ezmsg/sigproc/slicer.py CHANGED
@@ -1,11 +1,11 @@
1
+ import ezmsg.core as ez
1
2
  import numpy as np
2
3
  import numpy.typing as npt
3
- import ezmsg.core as ez
4
4
  from ezmsg.util.messages.axisarray import (
5
5
  AxisArray,
6
- slice_along_axis,
7
6
  AxisBase,
8
7
  replace,
8
+ slice_along_axis,
9
9
  )
10
10
 
11
11
  from .base import (
@@ -49,11 +49,7 @@ def parse_slice(
49
49
  if "," not in s:
50
50
  parts = [part.strip() for part in s.split(":")]
51
51
  if len(parts) == 1:
52
- if (
53
- axinfo is not None
54
- and hasattr(axinfo, "data")
55
- and parts[0] in axinfo.data
56
- ):
52
+ if axinfo is not None and hasattr(axinfo, "data") and parts[0] in axinfo.data:
57
53
  return tuple(np.where(axinfo.data == parts[0])[0])
58
54
  return (int(parts[0]),)
59
55
  return (slice(*(int(part.strip()) if part else None for part in parts)),)
@@ -76,9 +72,7 @@ class SlicerState:
76
72
  b_change_dims: bool = False
77
73
 
78
74
 
79
- class SlicerTransformer(
80
- BaseStatefulTransformer[SlicerSettings, AxisArray, AxisArray, SlicerState]
81
- ):
75
+ class SlicerTransformer(BaseStatefulTransformer[SlicerSettings, AxisArray, AxisArray, SlicerState]):
82
76
  def _hash_message(self, message: AxisArray) -> int:
83
77
  axis = self.settings.axis or message.dims[-1]
84
78
  axis_idx = message.get_axis_idx(axis)
@@ -101,11 +95,7 @@ class SlicerTransformer(
101
95
  self._state.slice_ = np.s_[indices]
102
96
 
103
97
  # Create the output axis
104
- if (
105
- axis in message.axes
106
- and hasattr(message.axes[axis], "data")
107
- and len(message.axes[axis].data) > 0
108
- ):
98
+ if axis in message.axes and hasattr(message.axes[axis], "data") and len(message.axes[axis].data) > 0:
109
99
  in_data = np.array(message.axes[axis].data)
110
100
  if self._state.b_change_dims:
111
101
  out_data = in_data[self._state.slice_ : self._state.slice_ + 1]
@@ -119,17 +109,10 @@ class SlicerTransformer(
119
109
 
120
110
  replace_kwargs = {}
121
111
  if self._state.b_change_dims:
122
- replace_kwargs["dims"] = [
123
- _ for dim_ix, _ in enumerate(message.dims) if dim_ix != axis_idx
124
- ]
125
- replace_kwargs["axes"] = {
126
- k: v for k, v in message.axes.items() if k != axis
127
- }
112
+ replace_kwargs["dims"] = [_ for dim_ix, _ in enumerate(message.dims) if dim_ix != axis_idx]
113
+ replace_kwargs["axes"] = {k: v for k, v in message.axes.items() if k != axis}
128
114
  elif self._state.new_axis is not None:
129
- replace_kwargs["axes"] = {
130
- k: (v if k != axis else self._state.new_axis)
131
- for k, v in message.axes.items()
132
- }
115
+ replace_kwargs["axes"] = {k: (v if k != axis else self._state.new_axis) for k, v in message.axes.items()}
133
116
 
134
117
  return replace(
135
118
  message,
@@ -138,9 +121,7 @@ class SlicerTransformer(
138
121
  )
139
122
 
140
123
 
141
- class Slicer(
142
- BaseTransformerUnit[SlicerSettings, AxisArray, AxisArray, SlicerTransformer]
143
- ):
124
+ class Slicer(BaseTransformerUnit[SlicerSettings, AxisArray, AxisArray, SlicerTransformer]):
144
125
  SETTINGS = SlicerSettings
145
126
 
146
127
 
ezmsg/sigproc/spectral.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from .spectrum import OptionsEnum as OptionsEnum
2
- from .spectrum import WindowFunction as WindowFunction
3
- from .spectrum import SpectralTransform as SpectralTransform
4
2
  from .spectrum import SpectralOutput as SpectralOutput
5
- from .spectrum import SpectrumSettings as SpectrumSettings
3
+ from .spectrum import SpectralTransform as SpectralTransform
6
4
  from .spectrum import Spectrum as Spectrum
5
+ from .spectrum import SpectrumSettings as SpectrumSettings
6
+ from .spectrum import WindowFunction as WindowFunction
ezmsg/sigproc/spectrogram.py CHANGED
@@ -1,20 +1,21 @@
1
1
  from typing import Generator
2
+
2
3
  import ezmsg.core as ez
3
4
  from ezmsg.util.messages.axisarray import AxisArray
4
5
  from ezmsg.util.messages.modify import modify_axis
5
6
 
6
- from .window import Anchor, WindowTransformer
7
- from .spectrum import (
8
- WindowFunction,
9
- SpectralTransform,
10
- SpectralOutput,
11
- SpectrumTransformer,
12
- )
13
7
  from .base import (
14
- CompositeProcessor,
15
8
  BaseStatefulProcessor,
16
9
  BaseTransformerUnit,
10
+ CompositeProcessor,
11
+ )
12
+ from .spectrum import (
13
+ SpectralOutput,
14
+ SpectralTransform,
15
+ SpectrumTransformer,
16
+ WindowFunction,
17
17
  )
18
+ from .window import Anchor, WindowTransformer
18
19
 
19
20
 
20
21
  class SpectrogramSettings(ez.Settings):
@@ -41,9 +42,7 @@ class SpectrogramSettings(ez.Settings):
41
42
  """The :obj:`SpectralOutput` format."""
42
43
 
43
44
 
44
- class SpectrogramTransformer(
45
- CompositeProcessor[SpectrogramSettings, AxisArray, AxisArray]
46
- ):
45
+ class SpectrogramTransformer(CompositeProcessor[SpectrogramSettings, AxisArray, AxisArray]):
47
46
  @staticmethod
48
47
  def _initialize_processors(
49
48
  settings: SpectrogramSettings,
@@ -54,9 +53,7 @@ class SpectrogramTransformer(
54
53
  newaxis="win",
55
54
  window_dur=settings.window_dur,
56
55
  window_shift=settings.window_shift,
57
- zero_pad_until="shift"
58
- if settings.window_shift is not None
59
- else "input",
56
+ zero_pad_until="shift" if settings.window_shift is not None else "input",
60
57
  anchor=settings.window_anchor,
61
58
  ),
62
59
  "spectrum": SpectrumTransformer(
@@ -69,11 +66,7 @@ class SpectrogramTransformer(
69
66
  }
70
67
 
71
68
 
72
- class Spectrogram(
73
- BaseTransformerUnit[
74
- SpectrogramSettings, AxisArray, AxisArray, SpectrogramTransformer
75
- ]
76
- ):
69
+ class Spectrogram(BaseTransformerUnit[SpectrogramSettings, AxisArray, AxisArray, SpectrogramTransformer]):
77
70
  SETTINGS = SpectrogramSettings
78
71
 
79
72