ezmsg-sigproc 2.5.0__py3-none-any.whl → 2.7.0__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64) hide show
  1. ezmsg/sigproc/__version__.py +2 -2
  2. ezmsg/sigproc/activation.py +5 -11
  3. ezmsg/sigproc/adaptive_lattice_notch.py +11 -30
  4. ezmsg/sigproc/affinetransform.py +16 -42
  5. ezmsg/sigproc/aggregate.py +17 -34
  6. ezmsg/sigproc/bandpower.py +12 -20
  7. ezmsg/sigproc/base.py +141 -1276
  8. ezmsg/sigproc/butterworthfilter.py +8 -16
  9. ezmsg/sigproc/butterworthzerophase.py +7 -16
  10. ezmsg/sigproc/cheby.py +4 -10
  11. ezmsg/sigproc/combfilter.py +5 -8
  12. ezmsg/sigproc/coordinatespaces.py +142 -0
  13. ezmsg/sigproc/decimate.py +3 -7
  14. ezmsg/sigproc/denormalize.py +6 -11
  15. ezmsg/sigproc/detrend.py +3 -4
  16. ezmsg/sigproc/diff.py +8 -17
  17. ezmsg/sigproc/downsample.py +11 -20
  18. ezmsg/sigproc/ewma.py +11 -28
  19. ezmsg/sigproc/ewmfilter.py +1 -1
  20. ezmsg/sigproc/extract_axis.py +3 -4
  21. ezmsg/sigproc/fbcca.py +34 -59
  22. ezmsg/sigproc/filter.py +19 -45
  23. ezmsg/sigproc/filterbank.py +37 -74
  24. ezmsg/sigproc/filterbankdesign.py +7 -14
  25. ezmsg/sigproc/fir_hilbert.py +13 -30
  26. ezmsg/sigproc/fir_pmc.py +5 -10
  27. ezmsg/sigproc/firfilter.py +12 -14
  28. ezmsg/sigproc/gaussiansmoothing.py +5 -9
  29. ezmsg/sigproc/kaiser.py +11 -15
  30. ezmsg/sigproc/math/abs.py +4 -3
  31. ezmsg/sigproc/math/add.py +121 -0
  32. ezmsg/sigproc/math/clip.py +4 -1
  33. ezmsg/sigproc/math/difference.py +100 -36
  34. ezmsg/sigproc/math/invert.py +3 -3
  35. ezmsg/sigproc/math/log.py +5 -6
  36. ezmsg/sigproc/math/scale.py +2 -0
  37. ezmsg/sigproc/messages.py +1 -2
  38. ezmsg/sigproc/quantize.py +3 -6
  39. ezmsg/sigproc/resample.py +17 -38
  40. ezmsg/sigproc/rollingscaler.py +12 -37
  41. ezmsg/sigproc/sampler.py +19 -37
  42. ezmsg/sigproc/scaler.py +11 -22
  43. ezmsg/sigproc/signalinjector.py +7 -18
  44. ezmsg/sigproc/slicer.py +14 -34
  45. ezmsg/sigproc/spectral.py +3 -3
  46. ezmsg/sigproc/spectrogram.py +12 -19
  47. ezmsg/sigproc/spectrum.py +17 -38
  48. ezmsg/sigproc/transpose.py +12 -24
  49. ezmsg/sigproc/util/asio.py +25 -156
  50. ezmsg/sigproc/util/axisarray_buffer.py +12 -26
  51. ezmsg/sigproc/util/buffer.py +22 -43
  52. ezmsg/sigproc/util/message.py +17 -31
  53. ezmsg/sigproc/util/profile.py +23 -174
  54. ezmsg/sigproc/util/sparse.py +7 -15
  55. ezmsg/sigproc/util/typeresolution.py +17 -83
  56. ezmsg/sigproc/wavelets.py +10 -19
  57. ezmsg/sigproc/window.py +29 -83
  58. ezmsg_sigproc-2.7.0.dist-info/METADATA +60 -0
  59. ezmsg_sigproc-2.7.0.dist-info/RECORD +64 -0
  60. ezmsg/sigproc/synth.py +0 -774
  61. ezmsg_sigproc-2.5.0.dist-info/METADATA +0 -72
  62. ezmsg_sigproc-2.5.0.dist-info/RECORD +0 -63
  63. {ezmsg_sigproc-2.5.0.dist-info → ezmsg_sigproc-2.7.0.dist-info}/WHEEL +0 -0
  64. /ezmsg_sigproc-2.5.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.7.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/math/log.py CHANGED
@@ -1,5 +1,8 @@
1
- import numpy as np
1
+ """Take the logarithm of the data."""
2
+
3
+ # TODO: Array API
2
4
  import ezmsg.core as ez
5
+ import numpy as np
3
6
  from ezmsg.util.messages.axisarray import AxisArray
4
7
  from ezmsg.util.messages.util import replace
5
8
 
@@ -17,11 +20,7 @@ class LogSettings(ez.Settings):
17
20
  class LogTransformer(BaseTransformer[LogSettings, AxisArray, AxisArray]):
18
21
  def _process(self, message: AxisArray) -> AxisArray:
19
22
  data = message.data
20
- if (
21
- self.settings.clip_zero
22
- and np.any(data <= 0)
23
- and np.issubdtype(data.dtype, np.floating)
24
- ):
23
+ if self.settings.clip_zero and np.any(data <= 0) and np.issubdtype(data.dtype, np.floating):
25
24
  data = np.clip(data, a_min=np.finfo(data.dtype).tiny, a_max=None)
26
25
  return replace(message, data=np.log(data) / np.log(self.settings.base))
27
26
 
@@ -1,3 +1,5 @@
1
+ """Scale the data by a constant factor."""
2
+
1
3
  import ezmsg.core as ez
2
4
  from ezmsg.util.messages.axisarray import AxisArray
3
5
  from ezmsg.util.messages.util import replace
ezmsg/sigproc/messages.py CHANGED
@@ -1,10 +1,9 @@
1
- import warnings
2
1
  import time
2
+ import warnings
3
3
 
4
4
  import numpy.typing as npt
5
5
  from ezmsg.util.messages.axisarray import AxisArray
6
6
 
7
-
8
7
  # UPCOMING: TSMessage Deprecation
9
8
  # TSMessage is deprecated because it doesn't handle multiple time axes well.
10
9
  # AxisArray has an incompatible API but supports a superset of functionality.
ezmsg/sigproc/quantize.py CHANGED
@@ -1,9 +1,8 @@
1
- import numpy as np
2
1
  import ezmsg.core as ez
2
+ import numpy as np
3
+ from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
3
4
  from ezmsg.util.messages.axisarray import AxisArray, replace
4
5
 
5
- from .base import BaseTransformer, BaseTransformerUnit
6
-
7
6
 
8
7
  class QuantizeSettings(ez.Settings):
9
8
  """
@@ -65,7 +64,5 @@ class QuantizeTransformer(BaseTransformer[QuantizeSettings, AxisArray, AxisArray
65
64
  return replace(message, data=data)
66
65
 
67
66
 
68
- class QuantizerUnit(
69
- BaseTransformerUnit[QuantizeSettings, AxisArray, AxisArray, QuantizeTransformer]
70
- ):
67
+ class QuantizerUnit(BaseTransformerUnit[QuantizeSettings, AxisArray, AxisArray, QuantizeTransformer]):
71
68
  SETTINGS = QuantizeSettings
ezmsg/sigproc/resample.py CHANGED
@@ -2,17 +2,17 @@ import asyncio
2
2
  import math
3
3
  import time
4
4
 
5
+ import ezmsg.core as ez
5
6
  import numpy as np
6
7
  import scipy.interpolate
7
- import ezmsg.core as ez
8
- from ezmsg.util.messages.axisarray import AxisArray, LinearAxis
9
- from ezmsg.util.messages.util import replace
10
-
11
- from .base import (
12
- BaseStatefulProcessor,
8
+ from ezmsg.baseproc import (
13
9
  BaseConsumerUnit,
10
+ BaseStatefulProcessor,
14
11
  processor_state,
15
12
  )
13
+ from ezmsg.util.messages.axisarray import AxisArray, LinearAxis
14
+ from ezmsg.util.messages.util import replace
15
+
16
16
  from .util.axisarray_buffer import HybridAxisArrayBuffer, HybridAxisBuffer
17
17
  from .util.buffer import UpdateStrategy
18
18
 
@@ -29,7 +29,7 @@ class ResampleSettings(ez.Settings):
29
29
  fill_value: str = "extrapolate"
30
30
  """
31
31
  Value to use for out-of-bounds samples.
32
- If 'extrapolate', the transformer will extrapolate.
32
+ If 'extrapolate', the transformer will extrapolate.
33
33
  If 'last', the transformer will use the last sample.
34
34
  See scipy.interpolate.interp1d for more options.
35
35
  """
@@ -57,9 +57,9 @@ class ResampleState:
57
57
  """
58
58
  The buffer for the reference axis (usually a time axis). The interpolation function
59
59
  will be evaluated at the reference axis values.
60
- When resample_rate is None, this buffer will be filled with the axis from incoming
60
+ When resample_rate is None, this buffer will be filled with the axis from incoming
61
61
  _reference_ messages.
62
- When resample_rate is not None (i.e., prescribed float resample_rate), this buffer
62
+ When resample_rate is not None (i.e., prescribed float resample_rate), this buffer
63
63
  is filled with a synthetic axis that is generated from the incoming signal messages.
64
64
  """
65
65
 
@@ -67,7 +67,7 @@ class ResampleState:
67
67
  """
68
68
  The last value of the reference axis that was returned. This helps us to know
69
69
  what the _next_ returned value should be, and to avoid returning the same value.
70
- TODO: We can eliminate this variable if we maintain "by convention" that the
70
+ TODO: We can eliminate this variable if we maintain "by convention" that the
71
71
  reference axis always has 1 value at its start that we exclude from the resampling.
72
72
  """
73
73
 
@@ -79,9 +79,7 @@ class ResampleState:
79
79
  """
80
80
 
81
81
 
82
- class ResampleProcessor(
83
- BaseStatefulProcessor[ResampleSettings, AxisArray, AxisArray, ResampleState]
84
- ):
82
+ class ResampleProcessor(BaseStatefulProcessor[ResampleSettings, AxisArray, AxisArray, ResampleState]):
85
83
  def _hash_message(self, message: AxisArray) -> int:
86
84
  ax_idx: int = message.get_axis_idx(self.settings.axis)
87
85
  sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
@@ -135,17 +133,11 @@ class ResampleProcessor(
135
133
  ax_idx = message.get_axis_idx(self.settings.axis)
136
134
  if self.settings.resample_rate is not None and message.data.shape[ax_idx] > 0:
137
135
  in_ax = message.axes[self.settings.axis]
138
- in_t_end = (
139
- in_ax.data[-1]
140
- if hasattr(in_ax, "data")
141
- else in_ax.value(message.data.shape[ax_idx] - 1)
142
- )
136
+ in_t_end = in_ax.data[-1] if hasattr(in_ax, "data") else in_ax.value(message.data.shape[ax_idx] - 1)
143
137
  out_gain = 1 / self.settings.resample_rate
144
138
  prev_t_end = self.state.last_ref_ax_val
145
139
  n_synth = math.ceil((in_t_end - prev_t_end) * self.settings.resample_rate)
146
- synth_ref_axis = LinearAxis(
147
- unit="s", gain=out_gain, offset=prev_t_end + out_gain
148
- )
140
+ synth_ref_axis = LinearAxis(unit="s", gain=out_gain, offset=prev_t_end + out_gain)
149
141
  self.state.ref_axis_buffer.write(synth_ref_axis, n_samples=n_synth)
150
142
 
151
143
  self.state.last_write_time = time.time()
@@ -193,11 +185,7 @@ class ResampleProcessor(
193
185
  # Get source to train interpolation
194
186
  src_axarr = src.peek()
195
187
  src_axis = src_axarr.axes[self.settings.axis]
196
- x = (
197
- src_axis.data
198
- if hasattr(src_axis, "data")
199
- else src_axis.value(np.arange(src_axarr.data.shape[0]))
200
- )
188
+ x = src_axis.data if hasattr(src_axis, "data") else src_axis.value(np.arange(src_axarr.data.shape[0]))
201
189
 
202
190
  # Only resample at reference values that have not been interpolated over previously.
203
191
  b_ref = ref_xvec > self.state.last_ref_ax_val
@@ -208,11 +196,7 @@ class ResampleProcessor(
208
196
 
209
197
  if len(ref_idx) == 0:
210
198
  # Nothing to interpolate over; return empty data
211
- null_ref = (
212
- replace(ref_ax, data=ref_ax.data[:0])
213
- if hasattr(ref_ax, "data")
214
- else ref_ax
215
- )
199
+ null_ref = replace(ref_ax, data=ref_ax.data[:0]) if hasattr(ref_ax, "data") else ref_ax
216
200
  return replace(
217
201
  src_axarr,
218
202
  data=src_axarr.data[:0, ...],
@@ -222,17 +206,12 @@ class ResampleProcessor(
222
206
  xnew = ref_xvec[ref_idx]
223
207
 
224
208
  # Identify source data indices around ref tvec with some padding for better interpolation.
225
- src_start_ix = max(
226
- 0, np.where(x > xnew[0])[0][0] - 2 if np.any(x > xnew[0]) else 0
227
- )
209
+ src_start_ix = max(0, np.where(x > xnew[0])[0][0] - 2 if np.any(x > xnew[0]) else 0)
228
210
 
229
211
  x = x[src_start_ix:]
230
212
  y = src_axarr.data[src_start_ix:]
231
213
 
232
- if (
233
- isinstance(self.settings.fill_value, str)
234
- and self.settings.fill_value == "last"
235
- ):
214
+ if isinstance(self.settings.fill_value, str) and self.settings.fill_value == "last":
236
215
  fill_value = (y[0], y[-1])
237
216
  else:
238
217
  fill_value = self.settings.fill_value
@@ -3,14 +3,15 @@ from collections import deque
3
3
  import ezmsg.core as ez
4
4
  import numpy as np
5
5
  import numpy.typing as npt
6
+ from ezmsg.util.messages.axisarray import AxisArray
7
+ from ezmsg.util.messages.util import replace
8
+
6
9
  from ezmsg.sigproc.base import (
7
10
  BaseAdaptiveTransformer,
8
11
  BaseAdaptiveTransformerUnit,
9
12
  processor_state,
10
13
  )
11
14
  from ezmsg.sigproc.sampler import SampleMessage
12
- from ezmsg.util.messages.axisarray import AxisArray
13
- from ezmsg.util.messages.util import replace
14
15
 
15
16
 
16
17
  class RollingScalerSettings(ez.Settings):
@@ -71,11 +72,7 @@ class RollingScalerState:
71
72
  min_samples: int | None = None
72
73
 
73
74
 
74
- class RollingScalerProcessor(
75
- BaseAdaptiveTransformer[
76
- RollingScalerSettings, AxisArray, AxisArray, RollingScalerState
77
- ]
78
- ):
75
+ class RollingScalerProcessor(BaseAdaptiveTransformer[RollingScalerSettings, AxisArray, AxisArray, RollingScalerState]):
79
76
  """
80
77
  Processor for rolling z-score normalization of input `AxisArray` messages.
81
78
 
@@ -119,40 +116,23 @@ class RollingScalerProcessor(
119
116
  self._state.N = 0
120
117
  self._state.M2 = np.zeros(ch)
121
118
  self._state.k_samples = (
122
- int(
123
- np.ceil(
124
- self.settings.window_size / message.axes[self.settings.axis].gain
125
- )
126
- )
119
+ int(np.ceil(self.settings.window_size / message.axes[self.settings.axis].gain))
127
120
  if self.settings.window_size is not None
128
121
  else self.settings.k_samples
129
122
  )
130
123
  if self._state.k_samples is not None and self._state.k_samples < 1:
131
- ez.logger.warning(
132
- "window_size smaller than sample gain; setting k_samples to 1."
133
- )
124
+ ez.logger.warning("window_size smaller than sample gain; setting k_samples to 1.")
134
125
  self._state.k_samples = 1
135
126
  elif self._state.k_samples is None:
136
- ez.logger.warning(
137
- "k_samples is None; z-score accumulation will be unbounded."
138
- )
127
+ ez.logger.warning("k_samples is None; z-score accumulation will be unbounded.")
139
128
  self._state.samples = deque(maxlen=self._state.k_samples)
140
129
  self._state.min_samples = (
141
- int(
142
- np.ceil(
143
- self.settings.min_seconds / message.axes[self.settings.axis].gain
144
- )
145
- )
130
+ int(np.ceil(self.settings.min_seconds / message.axes[self.settings.axis].gain))
146
131
  if self.settings.window_size is not None
147
132
  else self.settings.min_samples
148
133
  )
149
- if (
150
- self._state.k_samples is not None
151
- and self._state.min_samples > self._state.k_samples
152
- ):
153
- ez.logger.warning(
154
- "min_samples is greater than k_samples; adjusting min_samples to k_samples."
155
- )
134
+ if self._state.k_samples is not None and self._state.min_samples > self._state.k_samples:
135
+ ez.logger.warning("min_samples is greater than k_samples; adjusting min_samples to k_samples.")
156
136
  self._state.min_samples = self._state.k_samples
157
137
 
158
138
  def _add_batch_stats(self, x: npt.NDArray) -> None:
@@ -161,10 +141,7 @@ class RollingScalerProcessor(
161
141
  mean_b = np.mean(x, axis=0)
162
142
  M2_b = np.sum((x - mean_b) ** 2, axis=0)
163
143
 
164
- if (
165
- self._state.k_samples is not None
166
- and len(self._state.samples) == self._state.k_samples
167
- ):
144
+ if self._state.k_samples is not None and len(self._state.samples) == self._state.k_samples:
168
145
  n_old, mean_old, M2_old = self._state.samples.popleft()
169
146
  N_T = self._state.N
170
147
  N_new = N_T - n_old
@@ -177,9 +154,7 @@ class RollingScalerProcessor(
177
154
  delta = mean_old - self._state.mean
178
155
  self._state.N = N_new
179
156
  self._state.mean = (N_T * self._state.mean - n_old * mean_old) / N_new
180
- self._state.M2 = (
181
- self._state.M2 - M2_old - (delta * delta) * (N_T * n_old / N_new)
182
- )
157
+ self._state.M2 = self._state.M2 - M2_old - (delta * delta) * (N_T * n_old / N_new)
183
158
 
184
159
  N_A = self._state.N
185
160
  N = N_A + n_b
ezmsg/sigproc/sampler.py CHANGED
@@ -1,28 +1,28 @@
1
1
  import asyncio
2
- from collections import deque
3
2
  import copy
4
3
  import traceback
5
4
  import typing
5
+ from collections import deque
6
6
 
7
- import numpy as np
8
7
  import ezmsg.core as ez
8
+ import numpy as np
9
+ from ezmsg.baseproc import (
10
+ BaseConsumerUnit,
11
+ BaseProducerUnit,
12
+ BaseStatefulProducer,
13
+ BaseStatefulTransformer,
14
+ BaseTransformerUnit,
15
+ processor_state,
16
+ )
9
17
  from ezmsg.util.messages.axisarray import (
10
18
  AxisArray,
11
19
  )
12
20
  from ezmsg.util.messages.util import replace
13
21
 
14
- from .util.profile import profile_subpub
15
22
  from .util.axisarray_buffer import HybridAxisArrayBuffer
16
23
  from .util.buffer import UpdateStrategy
17
24
  from .util.message import SampleMessage, SampleTriggerMessage
18
- from .base import (
19
- BaseStatefulTransformer,
20
- BaseConsumerUnit,
21
- BaseTransformerUnit,
22
- BaseStatefulProducer,
23
- BaseProducerUnit,
24
- processor_state,
25
- )
25
+ from .util.profile import profile_subpub
26
26
 
27
27
 
28
28
  class SamplerSettings(ez.Settings):
@@ -74,12 +74,8 @@ class SamplerState:
74
74
  triggers: deque[SampleTriggerMessage] | None = None
75
75
 
76
76
 
77
- class SamplerTransformer(
78
- BaseStatefulTransformer[SamplerSettings, AxisArray, AxisArray, SamplerState]
79
- ):
80
- def __call__(
81
- self, message: AxisArray | SampleTriggerMessage
82
- ) -> list[SampleMessage]:
77
+ class SamplerTransformer(BaseStatefulTransformer[SamplerSettings, AxisArray, AxisArray, SamplerState]):
78
+ def __call__(self, message: AxisArray | SampleTriggerMessage) -> list[SampleMessage]:
83
79
  # TODO: Currently we have a single entry point that accepts both
84
80
  # data and trigger messages and we choose a code path based on
85
81
  # the message type. However, in the future we will likely replace
@@ -99,9 +95,7 @@ class SamplerTransformer(
99
95
  # Compute hash based on message properties that require state reset
100
96
  axis = self.settings.axis or message.dims[0]
101
97
  axis_idx = message.get_axis_idx(axis)
102
- sample_shape = (
103
- message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
104
- )
98
+ sample_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
105
99
  return hash((sample_shape, message.key))
106
100
 
107
101
  def _reset_state(self, message: AxisArray) -> None:
@@ -193,20 +187,14 @@ class SamplerTransformer(
193
187
  trigger_ts: float = message.timestamp
194
188
  if not self.settings.estimate_alignment:
195
189
  # Override the trigger timestamp with the next sample's likely timestamp.
196
- trigger_ts = (
197
- self._state.buffer.axis_final_value + self._state.buffer.axis_gain
198
- )
190
+ trigger_ts = self._state.buffer.axis_final_value + self._state.buffer.axis_gain
199
191
 
200
- new_trig_msg = replace(
201
- message, timestamp=trigger_ts, period=_period, value=_value
202
- )
192
+ new_trig_msg = replace(message, timestamp=trigger_ts, period=_period, value=_value)
203
193
  self._state.triggers.append(new_trig_msg)
204
194
  return []
205
195
 
206
196
 
207
- class Sampler(
208
- BaseTransformerUnit[SamplerSettings, AxisArray, AxisArray, SamplerTransformer]
209
- ):
197
+ class Sampler(BaseTransformerUnit[SamplerSettings, AxisArray, AxisArray, SamplerTransformer]):
210
198
  SETTINGS = SamplerSettings
211
199
 
212
200
  INPUT_TRIGGER = ez.InputStream(SampleTriggerMessage)
@@ -269,19 +257,13 @@ class TriggerGeneratorState:
269
257
  output: int = 0
270
258
 
271
259
 
272
- class TriggerProducer(
273
- BaseStatefulProducer[
274
- TriggerGeneratorSettings, SampleTriggerMessage, TriggerGeneratorState
275
- ]
276
- ):
260
+ class TriggerProducer(BaseStatefulProducer[TriggerGeneratorSettings, SampleTriggerMessage, TriggerGeneratorState]):
277
261
  def _reset_state(self) -> None:
278
262
  self._state.output = 0
279
263
 
280
264
  async def _produce(self) -> SampleTriggerMessage:
281
265
  await asyncio.sleep(self.settings.publish_period)
282
- out_msg = SampleTriggerMessage(
283
- period=self.settings.period, value=self._state.output
284
- )
266
+ out_msg = SampleTriggerMessage(period=self.settings.period, value=self._state.output)
285
267
  self._state.output += 1
286
268
  return out_msg
287
269
 
ezmsg/sigproc/scaler.py CHANGED
@@ -1,27 +1,24 @@
1
1
  import typing
2
2
 
3
3
  import numpy as np
4
- from ezmsg.util.messages.axisarray import AxisArray
5
- from ezmsg.util.messages.util import replace
6
- from ezmsg.util.generator import consumer
7
-
8
- from .base import (
4
+ from ezmsg.baseproc import (
9
5
  BaseStatefulTransformer,
10
6
  BaseTransformerUnit,
11
7
  processor_state,
12
8
  )
13
- from .ewma import EWMATransformer, EWMASettings, _alpha_from_tau
9
+ from ezmsg.util.generator import consumer
10
+ from ezmsg.util.messages.axisarray import AxisArray
11
+ from ezmsg.util.messages.util import replace
14
12
 
15
13
  # Imports for backwards compatibility with previous module location
16
14
  from .ewma import EWMA_Deprecated as EWMA_Deprecated
17
- from .ewma import ewma_step as ewma_step
15
+ from .ewma import EWMASettings, EWMATransformer, _alpha_from_tau
18
16
  from .ewma import _tau_from_alpha as _tau_from_alpha
17
+ from .ewma import ewma_step as ewma_step
19
18
 
20
19
 
21
20
  @consumer
22
- def scaler(
23
- time_constant: float = 1.0, axis: str | None = None
24
- ) -> typing.Generator[AxisArray, AxisArray, None]:
21
+ def scaler(time_constant: float = 1.0, axis: str | None = None) -> typing.Generator[AxisArray, AxisArray, None]:
25
22
  """
26
23
  Apply the adaptive standard scaler from https://riverml.xyz/latest/api/preprocessing/AdaptiveStandardScaler/
27
24
  This is faster than :obj:`scaler_np` for single-channel data.
@@ -85,19 +82,13 @@ class AdaptiveStandardScalerTransformer(
85
82
  ]
86
83
  ):
87
84
  def _reset_state(self, message: AxisArray) -> None:
88
- self._state.samps_ewma = EWMATransformer(
89
- time_constant=self.settings.time_constant, axis=self.settings.axis
90
- )
91
- self._state.vars_sq_ewma = EWMATransformer(
92
- time_constant=self.settings.time_constant, axis=self.settings.axis
93
- )
85
+ self._state.samps_ewma = EWMATransformer(time_constant=self.settings.time_constant, axis=self.settings.axis)
86
+ self._state.vars_sq_ewma = EWMATransformer(time_constant=self.settings.time_constant, axis=self.settings.axis)
94
87
 
95
88
  def _process(self, message: AxisArray) -> AxisArray:
96
89
  # Update step
97
90
  mean_message = self._state.samps_ewma(message)
98
- var_sq_message = self._state.vars_sq_ewma(
99
- replace(message, data=message.data**2)
100
- )
91
+ var_sq_message = self._state.vars_sq_ewma(replace(message, data=message.data**2))
101
92
 
102
93
  # Get step
103
94
  varis = var_sq_message.data - mean_message.data**2
@@ -119,9 +110,7 @@ class AdaptiveStandardScaler(
119
110
 
120
111
 
121
112
  # Backwards compatibility...
122
- def scaler_np(
123
- time_constant: float = 1.0, axis: str | None = None
124
- ) -> AdaptiveStandardScalerTransformer:
113
+ def scaler_np(time_constant: float = 1.0, axis: str | None = None) -> AdaptiveStandardScalerTransformer:
125
114
  return AdaptiveStandardScalerTransformer(
126
115
  settings=AdaptiveStandardScalerSettings(time_constant=time_constant, axis=axis)
127
116
  )
@@ -1,14 +1,13 @@
1
1
  import ezmsg.core as ez
2
- from ezmsg.util.messages.axisarray import AxisArray
3
- from ezmsg.util.messages.util import replace
4
2
  import numpy as np
5
3
  import numpy.typing as npt
6
-
7
- from .base import (
4
+ from ezmsg.baseproc import (
8
5
  BaseAsyncTransformer,
9
6
  BaseTransformerUnit,
10
7
  processor_state,
11
8
  )
9
+ from ezmsg.util.messages.axisarray import AxisArray
10
+ from ezmsg.util.messages.util import replace
12
11
 
13
12
 
14
13
  class SignalInjectorSettings(ez.Settings):
@@ -27,15 +26,11 @@ class SignalInjectorState:
27
26
 
28
27
 
29
28
  class SignalInjectorTransformer(
30
- BaseAsyncTransformer[
31
- SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorState
32
- ]
29
+ BaseAsyncTransformer[SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorState]
33
30
  ):
34
31
  def _hash_message(self, message: AxisArray) -> int:
35
32
  time_ax_idx = message.get_axis_idx(self.settings.time_dim)
36
- sample_shape = (
37
- message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
38
- )
33
+ sample_shape = message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
39
34
  return hash((message.key,) + sample_shape)
40
35
 
41
36
  def _reset_state(self, message: AxisArray) -> None:
@@ -44,9 +39,7 @@ class SignalInjectorTransformer(
44
39
  if self._state.cur_amplitude is None:
45
40
  self._state.cur_amplitude = self.settings.amplitude
46
41
  time_ax_idx = message.get_axis_idx(self.settings.time_dim)
47
- self._state.cur_shape = (
48
- message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
49
- )
42
+ self._state.cur_shape = message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
50
43
  rng = np.random.default_rng(self.settings.mixing_seed)
51
44
  self._state.mixing = rng.random((1, message.shape2d(self.settings.time_dim)[1]))
52
45
  self._state.mixing = (self._state.mixing * 2.0) - 1.0
@@ -63,11 +56,7 @@ class SignalInjectorTransformer(
63
56
  return out_msg
64
57
 
65
58
 
66
- class SignalInjector(
67
- BaseTransformerUnit[
68
- SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorTransformer
69
- ]
70
- ):
59
+ class SignalInjector(BaseTransformerUnit[SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorTransformer]):
71
60
  SETTINGS = SignalInjectorSettings
72
61
  INPUT_FREQUENCY = ez.InputStream(float | None)
73
62
  INPUT_AMPLITUDE = ez.InputStream(float)
ezmsg/sigproc/slicer.py CHANGED
@@ -1,17 +1,16 @@
1
+ import ezmsg.core as ez
1
2
  import numpy as np
2
3
  import numpy.typing as npt
3
- import ezmsg.core as ez
4
+ from ezmsg.baseproc import (
5
+ BaseStatefulTransformer,
6
+ BaseTransformerUnit,
7
+ processor_state,
8
+ )
4
9
  from ezmsg.util.messages.axisarray import (
5
10
  AxisArray,
6
- slice_along_axis,
7
11
  AxisBase,
8
12
  replace,
9
- )
10
-
11
- from .base import (
12
- BaseStatefulTransformer,
13
- BaseTransformerUnit,
14
- processor_state,
13
+ slice_along_axis,
15
14
  )
16
15
 
17
16
  """
@@ -49,11 +48,7 @@ def parse_slice(
49
48
  if "," not in s:
50
49
  parts = [part.strip() for part in s.split(":")]
51
50
  if len(parts) == 1:
52
- if (
53
- axinfo is not None
54
- and hasattr(axinfo, "data")
55
- and parts[0] in axinfo.data
56
- ):
51
+ if axinfo is not None and hasattr(axinfo, "data") and parts[0] in axinfo.data:
57
52
  return tuple(np.where(axinfo.data == parts[0])[0])
58
53
  return (int(parts[0]),)
59
54
  return (slice(*(int(part.strip()) if part else None for part in parts)),)
@@ -76,9 +71,7 @@ class SlicerState:
76
71
  b_change_dims: bool = False
77
72
 
78
73
 
79
- class SlicerTransformer(
80
- BaseStatefulTransformer[SlicerSettings, AxisArray, AxisArray, SlicerState]
81
- ):
74
+ class SlicerTransformer(BaseStatefulTransformer[SlicerSettings, AxisArray, AxisArray, SlicerState]):
82
75
  def _hash_message(self, message: AxisArray) -> int:
83
76
  axis = self.settings.axis or message.dims[-1]
84
77
  axis_idx = message.get_axis_idx(axis)
@@ -101,11 +94,7 @@ class SlicerTransformer(
101
94
  self._state.slice_ = np.s_[indices]
102
95
 
103
96
  # Create the output axis
104
- if (
105
- axis in message.axes
106
- and hasattr(message.axes[axis], "data")
107
- and len(message.axes[axis].data) > 0
108
- ):
97
+ if axis in message.axes and hasattr(message.axes[axis], "data") and len(message.axes[axis].data) > 0:
109
98
  in_data = np.array(message.axes[axis].data)
110
99
  if self._state.b_change_dims:
111
100
  out_data = in_data[self._state.slice_ : self._state.slice_ + 1]
@@ -119,17 +108,10 @@ class SlicerTransformer(
119
108
 
120
109
  replace_kwargs = {}
121
110
  if self._state.b_change_dims:
122
- replace_kwargs["dims"] = [
123
- _ for dim_ix, _ in enumerate(message.dims) if dim_ix != axis_idx
124
- ]
125
- replace_kwargs["axes"] = {
126
- k: v for k, v in message.axes.items() if k != axis
127
- }
111
+ replace_kwargs["dims"] = [_ for dim_ix, _ in enumerate(message.dims) if dim_ix != axis_idx]
112
+ replace_kwargs["axes"] = {k: v for k, v in message.axes.items() if k != axis}
128
113
  elif self._state.new_axis is not None:
129
- replace_kwargs["axes"] = {
130
- k: (v if k != axis else self._state.new_axis)
131
- for k, v in message.axes.items()
132
- }
114
+ replace_kwargs["axes"] = {k: (v if k != axis else self._state.new_axis) for k, v in message.axes.items()}
133
115
 
134
116
  return replace(
135
117
  message,
@@ -138,9 +120,7 @@ class SlicerTransformer(
138
120
  )
139
121
 
140
122
 
141
- class Slicer(
142
- BaseTransformerUnit[SlicerSettings, AxisArray, AxisArray, SlicerTransformer]
143
- ):
123
+ class Slicer(BaseTransformerUnit[SlicerSettings, AxisArray, AxisArray, SlicerTransformer]):
144
124
  SETTINGS = SlicerSettings
145
125
 
146
126
 
ezmsg/sigproc/spectral.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from .spectrum import OptionsEnum as OptionsEnum
2
- from .spectrum import WindowFunction as WindowFunction
3
- from .spectrum import SpectralTransform as SpectralTransform
4
2
  from .spectrum import SpectralOutput as SpectralOutput
5
- from .spectrum import SpectrumSettings as SpectrumSettings
3
+ from .spectrum import SpectralTransform as SpectralTransform
6
4
  from .spectrum import Spectrum as Spectrum
5
+ from .spectrum import SpectrumSettings as SpectrumSettings
6
+ from .spectrum import WindowFunction as WindowFunction