ezmsg-sigproc 2.4.1__py3-none-any.whl → 2.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. ezmsg/sigproc/__version__.py +2 -2
  2. ezmsg/sigproc/activation.py +5 -11
  3. ezmsg/sigproc/adaptive_lattice_notch.py +11 -29
  4. ezmsg/sigproc/affinetransform.py +13 -38
  5. ezmsg/sigproc/aggregate.py +13 -30
  6. ezmsg/sigproc/bandpower.py +7 -15
  7. ezmsg/sigproc/base.py +141 -1276
  8. ezmsg/sigproc/butterworthfilter.py +8 -16
  9. ezmsg/sigproc/butterworthzerophase.py +123 -0
  10. ezmsg/sigproc/cheby.py +4 -10
  11. ezmsg/sigproc/combfilter.py +5 -8
  12. ezmsg/sigproc/decimate.py +2 -6
  13. ezmsg/sigproc/denormalize.py +6 -11
  14. ezmsg/sigproc/detrend.py +3 -4
  15. ezmsg/sigproc/diff.py +8 -17
  16. ezmsg/sigproc/downsample.py +6 -14
  17. ezmsg/sigproc/ewma.py +11 -27
  18. ezmsg/sigproc/ewmfilter.py +1 -1
  19. ezmsg/sigproc/extract_axis.py +3 -4
  20. ezmsg/sigproc/fbcca.py +31 -56
  21. ezmsg/sigproc/filter.py +19 -45
  22. ezmsg/sigproc/filterbank.py +33 -70
  23. ezmsg/sigproc/filterbankdesign.py +5 -12
  24. ezmsg/sigproc/fir_hilbert.py +336 -0
  25. ezmsg/sigproc/fir_pmc.py +209 -0
  26. ezmsg/sigproc/firfilter.py +12 -14
  27. ezmsg/sigproc/gaussiansmoothing.py +5 -9
  28. ezmsg/sigproc/kaiser.py +11 -15
  29. ezmsg/sigproc/math/abs.py +1 -3
  30. ezmsg/sigproc/math/add.py +121 -0
  31. ezmsg/sigproc/math/clip.py +1 -1
  32. ezmsg/sigproc/math/difference.py +98 -36
  33. ezmsg/sigproc/math/invert.py +1 -3
  34. ezmsg/sigproc/math/log.py +2 -6
  35. ezmsg/sigproc/messages.py +1 -2
  36. ezmsg/sigproc/quantize.py +2 -4
  37. ezmsg/sigproc/resample.py +13 -34
  38. ezmsg/sigproc/rollingscaler.py +232 -0
  39. ezmsg/sigproc/sampler.py +17 -35
  40. ezmsg/sigproc/scaler.py +8 -18
  41. ezmsg/sigproc/signalinjector.py +6 -16
  42. ezmsg/sigproc/slicer.py +9 -28
  43. ezmsg/sigproc/spectral.py +3 -3
  44. ezmsg/sigproc/spectrogram.py +12 -19
  45. ezmsg/sigproc/spectrum.py +12 -32
  46. ezmsg/sigproc/transpose.py +7 -18
  47. ezmsg/sigproc/util/asio.py +25 -156
  48. ezmsg/sigproc/util/axisarray_buffer.py +10 -26
  49. ezmsg/sigproc/util/buffer.py +18 -43
  50. ezmsg/sigproc/util/message.py +17 -31
  51. ezmsg/sigproc/util/profile.py +23 -174
  52. ezmsg/sigproc/util/sparse.py +5 -15
  53. ezmsg/sigproc/util/typeresolution.py +17 -83
  54. ezmsg/sigproc/wavelets.py +6 -15
  55. ezmsg/sigproc/window.py +24 -78
  56. {ezmsg_sigproc-2.4.1.dist-info → ezmsg_sigproc-2.6.0.dist-info}/METADATA +4 -3
  57. ezmsg_sigproc-2.6.0.dist-info/RECORD +63 -0
  58. ezmsg/sigproc/synth.py +0 -774
  59. ezmsg_sigproc-2.4.1.dist-info/RECORD +0 -59
  60. {ezmsg_sigproc-2.4.1.dist-info → ezmsg_sigproc-2.6.0.dist-info}/WHEEL +0 -0
  61. /ezmsg_sigproc-2.4.1.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.6.0.dist-info/licenses/LICENSE +0 -0
@@ -0,0 +1,232 @@
1
+ from collections import deque
2
+
3
+ import ezmsg.core as ez
4
+ import numpy as np
5
+ import numpy.typing as npt
6
+ from ezmsg.util.messages.axisarray import AxisArray
7
+ from ezmsg.util.messages.util import replace
8
+
9
+ from ezmsg.sigproc.base import (
10
+ BaseAdaptiveTransformer,
11
+ BaseAdaptiveTransformerUnit,
12
+ processor_state,
13
+ )
14
+ from ezmsg.sigproc.sampler import SampleMessage
15
+
16
+
17
class RollingScalerSettings(ez.Settings):
    """Configuration for the rolling z-score scaler."""

    axis: str = "time"
    """
    Name of the axis along which successive samples are laid out.
    """

    k_samples: int | None = 20
    """
    Length of the rolling window, expressed as a number of samples.
    """

    window_size: float | None = None
    """
    Length of the rolling window, expressed in seconds.
    When given, it takes precedence over `k_samples`.
    `update_with_signal` likely should be True if using this option.
    """

    update_with_signal: bool = False
    """
    When True, the incoming process stream itself is used to update the rolling statistics.
    """

    min_samples: int = 1
    """
    Number of samples that must have been accumulated before statistics are computed.
    Only consulted when `window_size` is not set.
    """

    min_seconds: float = 1.0
    """
    Duration, in seconds, that must have been accumulated before statistics are computed.
    Only consulted when `window_size` is set.
    """

    artifact_z_thresh: float | None = None
    """
    Z-score threshold for artifact detection.
    When given, any sample in which at least one channel exceeds this z-score is
    left out of the rolling-statistics update.
    """

    clip: float | None = 10.0
    """
    When given, output values are clipped to the range [-clip, clip].
    """
63
+
64
+
65
@processor_state
class RollingScalerState:
    # Per-channel running mean of everything currently held in the window.
    mean: npt.NDArray | None = None
    # Total number of rows (summed over batches) currently held in the window.
    N: int = 0
    # Per-channel running sum of squared deviations from `mean`.
    M2: npt.NDArray | None = None
    # One `(n, mean, M2)` summary tuple per accumulated batch, oldest first.
    samples: deque | None = None
    # Resolved window length; used as the `maxlen` of `samples`.
    k_samples: int | None = None
    # Resolved count that `N` must reach before normalization is applied.
    min_samples: int | None = None
73
+
74
+
75
class RollingScalerProcessor(BaseAdaptiveTransformer[RollingScalerSettings, AxisArray, AxisArray, RollingScalerState]):
    """
    Processor for rolling z-score normalization of input `AxisArray` messages.

    The processor maintains rolling statistics (mean and variance) over the most recent
    batches received via the `partial_fit()` method (and, when `update_with_signal` is
    enabled, via the processed stream itself). When processing an `AxisArray` message,
    it normalizes the data using the current rolling statistics.

    The input `AxisArray` messages are expected to have shape `(time, ch)`, where `ch` is the
    channel axis. The processor computes the z-score for each channel independently.

    Note: The window is tracked per batch. Every call that adds data contributes exactly one
    `(n, mean, M2)` entry regardless of how many rows it holds, and at most `k_samples`
    entries are retained.

    Note: You should consider instead using the AdaptiveStandardScalerTransformer which
    is computationally more efficient and uses less memory. This RollingScalerProcessor
    is primarily provided to reproduce processing in the literature.

    Settings:
    ---------
    k_samples: int
        Number of previous samples to use for rolling statistics.

    Example:
    -----------------------------
    ```python
    processor = RollingScalerProcessor(
        settings=RollingScalerSettings(
            k_samples=20  # Number of previous samples to use for rolling statistics
        )
    )
    ```
    """

    def _hash_message(self, message: AxisArray) -> int:
        """Return a hash over the message key, per-sample shape and axis gain."""
        axis = message.dims[0] if self.settings.axis is None else self.settings.axis
        # Axes without a `gain` attribute are hashed as if their gain were 1.
        gain = message.axes[axis].gain if hasattr(message.axes[axis], "gain") else 1
        axis_idx = message.get_axis_idx(axis)
        samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
        return hash((message.key, samp_shape, gain))

    def _reset_state(self, message: AxisArray) -> None:
        """Zero the running statistics and resolve the window / warm-up lengths."""
        # The channel count is taken from the last dimension of the data.
        n_ch = message.data.shape[-1]
        self._state.mean = np.zeros(n_ch)
        self._state.N = 0
        self._state.M2 = np.zeros(n_ch)

        if self.settings.window_size is not None:
            # Durations in seconds are converted to counts using the axis gain.
            gain = message.axes[self.settings.axis].gain
            k_samples = int(np.ceil(self.settings.window_size / gain))
            min_samples = int(np.ceil(self.settings.min_seconds / gain))
        else:
            k_samples = self.settings.k_samples
            min_samples = self.settings.min_samples

        if k_samples is not None and k_samples < 1:
            ez.logger.warning("window_size smaller than sample gain; setting k_samples to 1.")
            k_samples = 1
        elif k_samples is None:
            ez.logger.warning("k_samples is None; z-score accumulation will be unbounded.")

        if k_samples is not None and min_samples > k_samples:
            ez.logger.warning("min_samples is greater than k_samples; adjusting min_samples to k_samples.")
            min_samples = k_samples

        self._state.k_samples = k_samples
        self._state.samples = deque(maxlen=k_samples)
        self._state.min_samples = min_samples

    def _add_batch_stats(self, x: npt.NDArray) -> None:
        """
        Merge one batch into the running statistics, evicting the oldest batch if the window is full.

        Args:
            x: Batch of data; rows (axis 0) are reduced, so statistics are kept per column.
        """
        x = np.asarray(x, dtype=np.float64)
        n_b = x.shape[0]
        if n_b == 0:
            # An empty batch has no mean; merging it would inject NaN into the running
            # statistics (or divide by zero when nothing has been accumulated yet).
            return
        mean_b = np.mean(x, axis=0)
        M2_b = np.sum((x - mean_b) ** 2, axis=0)

        if self._state.k_samples is not None and len(self._state.samples) == self._state.k_samples:
            # Window is full: subtract the oldest batch's contribution before adding the new one.
            n_old, mean_old, M2_old = self._state.samples.popleft()
            N_T = self._state.N
            N_new = N_T - n_old

            if N_new <= 0:
                self._state.N = 0
                self._state.mean = np.zeros_like(self._state.mean)
                self._state.M2 = np.zeros_like(self._state.M2)
            else:
                delta = mean_old - self._state.mean
                self._state.N = N_new
                self._state.mean = (N_T * self._state.mean - n_old * mean_old) / N_new
                M2_new = self._state.M2 - M2_old - (delta * delta) * (N_T * n_old / N_new)
                # Floating-point cancellation can push the result slightly below zero;
                # a negative M2 would turn into a NaN standard deviation downstream.
                self._state.M2 = np.maximum(M2_new, 0.0)

        # Pairwise merge of the running statistics with the new batch's statistics.
        N_A = self._state.N
        N = N_A + n_b
        delta = mean_b - self._state.mean
        self._state.mean = self._state.mean + delta * (n_b / N)
        self._state.M2 = self._state.M2 + M2_b + (delta * delta) * (N_A * n_b / N)
        self._state.N = N

        self._state.samples.append((n_b, mean_b, M2_b))

    def partial_fit(self, message: SampleMessage) -> None:
        """Update the rolling statistics with the data carried by a `SampleMessage`."""
        # NOTE(review): relies on `_reset_state` having populated the state already
        # (mean / M2 / samples start out as None) -- confirm against the base class.
        x = message.sample.data
        self._add_batch_stats(x)

    def _update_from_signal(self, x: npt.NDArray) -> None:
        """Fold processed-stream data into the rolling statistics, skipping artifact rows."""
        thresh = self.settings.artifact_z_thresh
        if thresh is not None and self._state.N > 0:
            # Rows are scored against the statistics as they stand before this update.
            std = np.maximum(np.sqrt(self._state.M2 / self._state.N), 1e-8)
            z = np.abs((x - self._state.mean) / std)
            # Drop every row in which at least one channel exceeds the threshold.
            x = x[~np.any(z > thresh, axis=1)]
        if x.size > 0:
            self._add_batch_stats(x)

    def _process(self, message: AxisArray) -> AxisArray:
        """
        Z-score `message` with the current rolling statistics.

        Returns the message unchanged while fewer than `min_samples` rows have been
        accumulated; otherwise returns a copy whose data is the (optionally clipped) z-score.
        """
        if self._state.N == 0 or self._state.N < self._state.min_samples:
            # Warm-up: not enough data to normalize yet, but the stream may still feed the stats.
            if self.settings.update_with_signal:
                self._update_from_signal(message.data)
            return message

        varis = self._state.M2 / self._state.N
        std = np.maximum(np.sqrt(varis), 1e-8)
        with np.errstate(divide="ignore", invalid="ignore"):
            result = (message.data - self._state.mean) / std
        result = np.nan_to_num(result, nan=0.0, posinf=0.0, neginf=0.0)
        if self.settings.clip is not None:
            result = np.clip(result, -self.settings.clip, self.settings.clip)

        # The output above was computed first, so it reflects only the pre-update statistics.
        if self.settings.update_with_signal:
            self._update_from_signal(message.data)

        return replace(message, data=result)
204
+
205
+
206
class RollingScalerUnit(BaseAdaptiveTransformerUnit[RollingScalerSettings, AxisArray, AxisArray, RollingScalerProcessor]):
    """
    ezmsg unit that wraps :obj:`RollingScalerProcessor`.

    Incoming `AxisArray` messages are z-scored against rolling statistics (mean and variance)
    that the wrapped processor keeps over the last `k_samples` samples it has received.

    Example:
    -----------------------------
    ```python
    unit = RollingScalerUnit(
        settings=RollingScalerSettings(
            k_samples=20  # Number of previous samples to use for rolling statistics
        )
    )
    ```
    """

    SETTINGS = RollingScalerSettings
ezmsg/sigproc/sampler.py CHANGED
@@ -1,28 +1,28 @@
1
1
  import asyncio
2
- from collections import deque
3
2
  import copy
4
3
  import traceback
5
4
  import typing
5
+ from collections import deque
6
6
 
7
- import numpy as np
8
7
  import ezmsg.core as ez
8
+ import numpy as np
9
9
  from ezmsg.util.messages.axisarray import (
10
10
  AxisArray,
11
11
  )
12
12
  from ezmsg.util.messages.util import replace
13
13
 
14
- from .util.profile import profile_subpub
15
- from .util.axisarray_buffer import HybridAxisArrayBuffer
16
- from .util.buffer import UpdateStrategy
17
- from .util.message import SampleMessage, SampleTriggerMessage
18
14
  from .base import (
19
- BaseStatefulTransformer,
20
15
  BaseConsumerUnit,
21
- BaseTransformerUnit,
22
- BaseStatefulProducer,
23
16
  BaseProducerUnit,
17
+ BaseStatefulProducer,
18
+ BaseStatefulTransformer,
19
+ BaseTransformerUnit,
24
20
  processor_state,
25
21
  )
22
+ from .util.axisarray_buffer import HybridAxisArrayBuffer
23
+ from .util.buffer import UpdateStrategy
24
+ from .util.message import SampleMessage, SampleTriggerMessage
25
+ from .util.profile import profile_subpub
26
26
 
27
27
 
28
28
  class SamplerSettings(ez.Settings):
@@ -74,12 +74,8 @@ class SamplerState:
74
74
  triggers: deque[SampleTriggerMessage] | None = None
75
75
 
76
76
 
77
- class SamplerTransformer(
78
- BaseStatefulTransformer[SamplerSettings, AxisArray, AxisArray, SamplerState]
79
- ):
80
- def __call__(
81
- self, message: AxisArray | SampleTriggerMessage
82
- ) -> list[SampleMessage]:
77
+ class SamplerTransformer(BaseStatefulTransformer[SamplerSettings, AxisArray, AxisArray, SamplerState]):
78
+ def __call__(self, message: AxisArray | SampleTriggerMessage) -> list[SampleMessage]:
83
79
  # TODO: Currently we have a single entry point that accepts both
84
80
  # data and trigger messages and we choose a code path based on
85
81
  # the message type. However, in the future we will likely replace
@@ -99,9 +95,7 @@ class SamplerTransformer(
99
95
  # Compute hash based on message properties that require state reset
100
96
  axis = self.settings.axis or message.dims[0]
101
97
  axis_idx = message.get_axis_idx(axis)
102
- sample_shape = (
103
- message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
104
- )
98
+ sample_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
105
99
  return hash((sample_shape, message.key))
106
100
 
107
101
  def _reset_state(self, message: AxisArray) -> None:
@@ -193,20 +187,14 @@ class SamplerTransformer(
193
187
  trigger_ts: float = message.timestamp
194
188
  if not self.settings.estimate_alignment:
195
189
  # Override the trigger timestamp with the next sample's likely timestamp.
196
- trigger_ts = (
197
- self._state.buffer.axis_final_value + self._state.buffer.axis_gain
198
- )
190
+ trigger_ts = self._state.buffer.axis_final_value + self._state.buffer.axis_gain
199
191
 
200
- new_trig_msg = replace(
201
- message, timestamp=trigger_ts, period=_period, value=_value
202
- )
192
+ new_trig_msg = replace(message, timestamp=trigger_ts, period=_period, value=_value)
203
193
  self._state.triggers.append(new_trig_msg)
204
194
  return []
205
195
 
206
196
 
207
- class Sampler(
208
- BaseTransformerUnit[SamplerSettings, AxisArray, AxisArray, SamplerTransformer]
209
- ):
197
+ class Sampler(BaseTransformerUnit[SamplerSettings, AxisArray, AxisArray, SamplerTransformer]):
210
198
  SETTINGS = SamplerSettings
211
199
 
212
200
  INPUT_TRIGGER = ez.InputStream(SampleTriggerMessage)
@@ -269,19 +257,13 @@ class TriggerGeneratorState:
269
257
  output: int = 0
270
258
 
271
259
 
272
- class TriggerProducer(
273
- BaseStatefulProducer[
274
- TriggerGeneratorSettings, SampleTriggerMessage, TriggerGeneratorState
275
- ]
276
- ):
260
+ class TriggerProducer(BaseStatefulProducer[TriggerGeneratorSettings, SampleTriggerMessage, TriggerGeneratorState]):
277
261
  def _reset_state(self) -> None:
278
262
  self._state.output = 0
279
263
 
280
264
  async def _produce(self) -> SampleTriggerMessage:
281
265
  await asyncio.sleep(self.settings.publish_period)
282
- out_msg = SampleTriggerMessage(
283
- period=self.settings.period, value=self._state.output
284
- )
266
+ out_msg = SampleTriggerMessage(period=self.settings.period, value=self._state.output)
285
267
  self._state.output += 1
286
268
  return out_msg
287
269
 
ezmsg/sigproc/scaler.py CHANGED
@@ -1,27 +1,25 @@
1
1
  import typing
2
2
 
3
3
  import numpy as np
4
+ from ezmsg.util.generator import consumer
4
5
  from ezmsg.util.messages.axisarray import AxisArray
5
6
  from ezmsg.util.messages.util import replace
6
- from ezmsg.util.generator import consumer
7
7
 
8
8
  from .base import (
9
9
  BaseStatefulTransformer,
10
10
  BaseTransformerUnit,
11
11
  processor_state,
12
12
  )
13
- from .ewma import EWMATransformer, EWMASettings, _alpha_from_tau
14
13
 
15
14
  # Imports for backwards compatibility with previous module location
16
15
  from .ewma import EWMA_Deprecated as EWMA_Deprecated
17
- from .ewma import ewma_step as ewma_step
16
+ from .ewma import EWMASettings, EWMATransformer, _alpha_from_tau
18
17
  from .ewma import _tau_from_alpha as _tau_from_alpha
18
+ from .ewma import ewma_step as ewma_step
19
19
 
20
20
 
21
21
  @consumer
22
- def scaler(
23
- time_constant: float = 1.0, axis: str | None = None
24
- ) -> typing.Generator[AxisArray, AxisArray, None]:
22
+ def scaler(time_constant: float = 1.0, axis: str | None = None) -> typing.Generator[AxisArray, AxisArray, None]:
25
23
  """
26
24
  Apply the adaptive standard scaler from https://riverml.xyz/latest/api/preprocessing/AdaptiveStandardScaler/
27
25
  This is faster than :obj:`scaler_np` for single-channel data.
@@ -85,19 +83,13 @@ class AdaptiveStandardScalerTransformer(
85
83
  ]
86
84
  ):
87
85
  def _reset_state(self, message: AxisArray) -> None:
88
- self._state.samps_ewma = EWMATransformer(
89
- time_constant=self.settings.time_constant, axis=self.settings.axis
90
- )
91
- self._state.vars_sq_ewma = EWMATransformer(
92
- time_constant=self.settings.time_constant, axis=self.settings.axis
93
- )
86
+ self._state.samps_ewma = EWMATransformer(time_constant=self.settings.time_constant, axis=self.settings.axis)
87
+ self._state.vars_sq_ewma = EWMATransformer(time_constant=self.settings.time_constant, axis=self.settings.axis)
94
88
 
95
89
  def _process(self, message: AxisArray) -> AxisArray:
96
90
  # Update step
97
91
  mean_message = self._state.samps_ewma(message)
98
- var_sq_message = self._state.vars_sq_ewma(
99
- replace(message, data=message.data**2)
100
- )
92
+ var_sq_message = self._state.vars_sq_ewma(replace(message, data=message.data**2))
101
93
 
102
94
  # Get step
103
95
  varis = var_sq_message.data - mean_message.data**2
@@ -119,9 +111,7 @@ class AdaptiveStandardScaler(
119
111
 
120
112
 
121
113
  # Backwards compatibility...
122
- def scaler_np(
123
- time_constant: float = 1.0, axis: str | None = None
124
- ) -> AdaptiveStandardScalerTransformer:
114
+ def scaler_np(time_constant: float = 1.0, axis: str | None = None) -> AdaptiveStandardScalerTransformer:
125
115
  return AdaptiveStandardScalerTransformer(
126
116
  settings=AdaptiveStandardScalerSettings(time_constant=time_constant, axis=axis)
127
117
  )
@@ -1,8 +1,8 @@
1
1
  import ezmsg.core as ez
2
- from ezmsg.util.messages.axisarray import AxisArray
3
- from ezmsg.util.messages.util import replace
4
2
  import numpy as np
5
3
  import numpy.typing as npt
4
+ from ezmsg.util.messages.axisarray import AxisArray
5
+ from ezmsg.util.messages.util import replace
6
6
 
7
7
  from .base import (
8
8
  BaseAsyncTransformer,
@@ -27,15 +27,11 @@ class SignalInjectorState:
27
27
 
28
28
 
29
29
  class SignalInjectorTransformer(
30
- BaseAsyncTransformer[
31
- SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorState
32
- ]
30
+ BaseAsyncTransformer[SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorState]
33
31
  ):
34
32
  def _hash_message(self, message: AxisArray) -> int:
35
33
  time_ax_idx = message.get_axis_idx(self.settings.time_dim)
36
- sample_shape = (
37
- message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
38
- )
34
+ sample_shape = message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
39
35
  return hash((message.key,) + sample_shape)
40
36
 
41
37
  def _reset_state(self, message: AxisArray) -> None:
@@ -44,9 +40,7 @@ class SignalInjectorTransformer(
44
40
  if self._state.cur_amplitude is None:
45
41
  self._state.cur_amplitude = self.settings.amplitude
46
42
  time_ax_idx = message.get_axis_idx(self.settings.time_dim)
47
- self._state.cur_shape = (
48
- message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
49
- )
43
+ self._state.cur_shape = message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
50
44
  rng = np.random.default_rng(self.settings.mixing_seed)
51
45
  self._state.mixing = rng.random((1, message.shape2d(self.settings.time_dim)[1]))
52
46
  self._state.mixing = (self._state.mixing * 2.0) - 1.0
@@ -63,11 +57,7 @@ class SignalInjectorTransformer(
63
57
  return out_msg
64
58
 
65
59
 
66
- class SignalInjector(
67
- BaseTransformerUnit[
68
- SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorTransformer
69
- ]
70
- ):
60
+ class SignalInjector(BaseTransformerUnit[SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorTransformer]):
71
61
  SETTINGS = SignalInjectorSettings
72
62
  INPUT_FREQUENCY = ez.InputStream(float | None)
73
63
  INPUT_AMPLITUDE = ez.InputStream(float)
ezmsg/sigproc/slicer.py CHANGED
@@ -1,11 +1,11 @@
1
+ import ezmsg.core as ez
1
2
  import numpy as np
2
3
  import numpy.typing as npt
3
- import ezmsg.core as ez
4
4
  from ezmsg.util.messages.axisarray import (
5
5
  AxisArray,
6
- slice_along_axis,
7
6
  AxisBase,
8
7
  replace,
8
+ slice_along_axis,
9
9
  )
10
10
 
11
11
  from .base import (
@@ -49,11 +49,7 @@ def parse_slice(
49
49
  if "," not in s:
50
50
  parts = [part.strip() for part in s.split(":")]
51
51
  if len(parts) == 1:
52
- if (
53
- axinfo is not None
54
- and hasattr(axinfo, "data")
55
- and parts[0] in axinfo.data
56
- ):
52
+ if axinfo is not None and hasattr(axinfo, "data") and parts[0] in axinfo.data:
57
53
  return tuple(np.where(axinfo.data == parts[0])[0])
58
54
  return (int(parts[0]),)
59
55
  return (slice(*(int(part.strip()) if part else None for part in parts)),)
@@ -76,9 +72,7 @@ class SlicerState:
76
72
  b_change_dims: bool = False
77
73
 
78
74
 
79
- class SlicerTransformer(
80
- BaseStatefulTransformer[SlicerSettings, AxisArray, AxisArray, SlicerState]
81
- ):
75
+ class SlicerTransformer(BaseStatefulTransformer[SlicerSettings, AxisArray, AxisArray, SlicerState]):
82
76
  def _hash_message(self, message: AxisArray) -> int:
83
77
  axis = self.settings.axis or message.dims[-1]
84
78
  axis_idx = message.get_axis_idx(axis)
@@ -101,11 +95,7 @@ class SlicerTransformer(
101
95
  self._state.slice_ = np.s_[indices]
102
96
 
103
97
  # Create the output axis
104
- if (
105
- axis in message.axes
106
- and hasattr(message.axes[axis], "data")
107
- and len(message.axes[axis].data) > 0
108
- ):
98
+ if axis in message.axes and hasattr(message.axes[axis], "data") and len(message.axes[axis].data) > 0:
109
99
  in_data = np.array(message.axes[axis].data)
110
100
  if self._state.b_change_dims:
111
101
  out_data = in_data[self._state.slice_ : self._state.slice_ + 1]
@@ -119,17 +109,10 @@ class SlicerTransformer(
119
109
 
120
110
  replace_kwargs = {}
121
111
  if self._state.b_change_dims:
122
- replace_kwargs["dims"] = [
123
- _ for dim_ix, _ in enumerate(message.dims) if dim_ix != axis_idx
124
- ]
125
- replace_kwargs["axes"] = {
126
- k: v for k, v in message.axes.items() if k != axis
127
- }
112
+ replace_kwargs["dims"] = [_ for dim_ix, _ in enumerate(message.dims) if dim_ix != axis_idx]
113
+ replace_kwargs["axes"] = {k: v for k, v in message.axes.items() if k != axis}
128
114
  elif self._state.new_axis is not None:
129
- replace_kwargs["axes"] = {
130
- k: (v if k != axis else self._state.new_axis)
131
- for k, v in message.axes.items()
132
- }
115
+ replace_kwargs["axes"] = {k: (v if k != axis else self._state.new_axis) for k, v in message.axes.items()}
133
116
 
134
117
  return replace(
135
118
  message,
@@ -138,9 +121,7 @@ class SlicerTransformer(
138
121
  )
139
122
 
140
123
 
141
- class Slicer(
142
- BaseTransformerUnit[SlicerSettings, AxisArray, AxisArray, SlicerTransformer]
143
- ):
124
+ class Slicer(BaseTransformerUnit[SlicerSettings, AxisArray, AxisArray, SlicerTransformer]):
144
125
  SETTINGS = SlicerSettings
145
126
 
146
127
 
ezmsg/sigproc/spectral.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from .spectrum import OptionsEnum as OptionsEnum
2
- from .spectrum import WindowFunction as WindowFunction
3
- from .spectrum import SpectralTransform as SpectralTransform
4
2
  from .spectrum import SpectralOutput as SpectralOutput
5
- from .spectrum import SpectrumSettings as SpectrumSettings
3
+ from .spectrum import SpectralTransform as SpectralTransform
6
4
  from .spectrum import Spectrum as Spectrum
5
+ from .spectrum import SpectrumSettings as SpectrumSettings
6
+ from .spectrum import WindowFunction as WindowFunction
@@ -1,20 +1,21 @@
1
1
  from typing import Generator
2
+
2
3
  import ezmsg.core as ez
3
4
  from ezmsg.util.messages.axisarray import AxisArray
4
5
  from ezmsg.util.messages.modify import modify_axis
5
6
 
6
- from .window import Anchor, WindowTransformer
7
- from .spectrum import (
8
- WindowFunction,
9
- SpectralTransform,
10
- SpectralOutput,
11
- SpectrumTransformer,
12
- )
13
7
  from .base import (
14
- CompositeProcessor,
15
8
  BaseStatefulProcessor,
16
9
  BaseTransformerUnit,
10
+ CompositeProcessor,
11
+ )
12
+ from .spectrum import (
13
+ SpectralOutput,
14
+ SpectralTransform,
15
+ SpectrumTransformer,
16
+ WindowFunction,
17
17
  )
18
+ from .window import Anchor, WindowTransformer
18
19
 
19
20
 
20
21
  class SpectrogramSettings(ez.Settings):
@@ -41,9 +42,7 @@ class SpectrogramSettings(ez.Settings):
41
42
  """The :obj:`SpectralOutput` format."""
42
43
 
43
44
 
44
- class SpectrogramTransformer(
45
- CompositeProcessor[SpectrogramSettings, AxisArray, AxisArray]
46
- ):
45
+ class SpectrogramTransformer(CompositeProcessor[SpectrogramSettings, AxisArray, AxisArray]):
47
46
  @staticmethod
48
47
  def _initialize_processors(
49
48
  settings: SpectrogramSettings,
@@ -54,9 +53,7 @@ class SpectrogramTransformer(
54
53
  newaxis="win",
55
54
  window_dur=settings.window_dur,
56
55
  window_shift=settings.window_shift,
57
- zero_pad_until="shift"
58
- if settings.window_shift is not None
59
- else "input",
56
+ zero_pad_until="shift" if settings.window_shift is not None else "input",
60
57
  anchor=settings.window_anchor,
61
58
  ),
62
59
  "spectrum": SpectrumTransformer(
@@ -69,11 +66,7 @@ class SpectrogramTransformer(
69
66
  }
70
67
 
71
68
 
72
- class Spectrogram(
73
- BaseTransformerUnit[
74
- SpectrogramSettings, AxisArray, AxisArray, SpectrogramTransformer
75
- ]
76
- ):
69
+ class Spectrogram(BaseTransformerUnit[SpectrogramSettings, AxisArray, AxisArray, SpectrogramTransformer]):
77
70
  SETTINGS = SpectrogramSettings
78
71
 
79
72