ezmsg-sigproc 1.5.0__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ezmsg/sigproc/scaler.py CHANGED
@@ -1,9 +1,12 @@
1
+ import functools
1
2
  import typing
2
3
 
3
4
  import numpy as np
4
5
  import numpy.typing as npt
6
+ import scipy.signal
5
7
  import ezmsg.core as ez
6
- from ezmsg.util.messages.axisarray import AxisArray, replace
8
+ from ezmsg.util.messages.axisarray import AxisArray
9
+ from ezmsg.util.messages.util import replace
7
10
  from ezmsg.util.generator import consumer
8
11
 
9
12
  from .base import GenAxisArray
@@ -27,9 +30,139 @@ def _alpha_from_tau(tau: float, dt: float) -> float:
27
30
  return 1 - np.exp(-dt / tau)
28
31
 
29
32
 
33
+ def ewma_step(
34
+ sample: npt.NDArray, zi: npt.NDArray, alpha: float, beta: float | None = None
35
+ ):
36
+ """
37
+ Do an exponentially weighted moving average step.
38
+
39
+ Args:
40
+ sample: The new sample.
41
+ zi: The output of the previous step.
42
+ alpha: Fading factor.
43
+ beta: Persisting factor. If None, it is calculated as 1-alpha.
44
+
45
+ Returns:
46
+ alpha * sample + beta * zi
47
+
48
+ """
49
+ # Potential micro-optimization:
50
+ # Current: scalar-arr multiplication, scalar-arr multiplication, arr-arr addition
51
+ # Alternative: arr-arr subtraction, arr-arr multiplication, arr-arr addition
52
+ # return zi + alpha * (new_sample - zi)
53
+ beta = beta or (1 - alpha)
54
+ return alpha * sample + beta * zi
55
+
56
+
57
class EWMA:
    """
    Streaming exponentially weighted moving average along axis 0.

    Implements y[k] = alpha * x[k] + (1 - alpha) * y[k-1] with
    :func:`scipy.signal.lfilter`, carrying the filter state between calls.
    The state is seeded on the first non-empty chunk so that the first output
    equals the first input sample.

    Args:
        alpha: Fading factor (weight given to the newest sample).
    """

    def __init__(self, alpha: float):
        self.beta = 1 - alpha
        # b = [alpha], a = [1, alpha - 1]  ->  y[k] = alpha*x[k] + beta*y[k-1]
        self._filt_func = functools.partial(
            scipy.signal.lfilter, [alpha], [1.0, alpha - 1.0], axis=0
        )
        # lfilter `zi` state with length 1 along axis 0; None until the first non-empty chunk.
        self.prev = None

    def compute(self, arr: npt.NDArray) -> npt.NDArray:
        """
        Filter a chunk of samples and update the internal state.

        Args:
            arr: Input chunk; axis 0 is the sample axis.

        Returns:
            The smoothed chunk, same shape as ``arr``.
        """
        if arr.shape[0] == 0:
            # Nothing to filter. Leave the state alone: seeding it from an empty
            # slice would give lfilter a zi of the wrong length on later calls.
            return np.empty(arr.shape, dtype=np.result_type(arr.dtype, np.float64))
        if self.prev is None:
            # zi = beta * x[0] makes y[0] = alpha*x[0] + beta*x[0] = x[0].
            self.prev = self.beta * arr[:1]
        expected, self.prev = self._filt_func(arr, zi=self.prev)
        return expected
70
+
71
+
72
+ class EWMA_Deprecated:
73
+ """
74
+ Grabbed these methods from https://stackoverflow.com/a/70998068 and other answers in that topic,
75
+ but they ended up being slower than the scipy.signal.lfilter method.
76
+ Additionally, `compute` and `compute2` suffer from potential errors as the vector length increases
77
+ and beta**n approaches zero.
78
+ """
79
+
80
+ def __init__(self, alpha: float, max_len: int):
81
+ self.alpha = alpha
82
+ self.beta = 1 - alpha
83
+ self.prev: npt.NDArray | None = None
84
+ self.weights = np.empty((max_len + 1,), float)
85
+ self._precalc_weights(max_len)
86
+ self._step_func = functools.partial(ewma_step, alpha=self.alpha, beta=self.beta)
87
+
88
+ def _precalc_weights(self, n: int):
89
+ # (1-α)^0, (1-α)^1, (1-α)^2, ..., (1-α)^n
90
+ np.power(self.beta, np.arange(n + 1), out=self.weights)
91
+
92
+ def compute(self, arr: npt.NDArray, out: npt.NDArray | None = None) -> npt.NDArray:
93
+ if out is None:
94
+ out = np.empty(arr.shape, arr.dtype)
95
+
96
+ n = arr.shape[0]
97
+ weights = self.weights[:n]
98
+ weights = np.expand_dims(weights, list(range(1, arr.ndim)))
99
+
100
+ # α*P0, α*P1, α*P2, ..., α*Pn
101
+ np.multiply(self.alpha, arr, out)
102
+
103
+ # α*P0/(1-α)^0, α*P1/(1-α)^1, α*P2/(1-α)^2, ..., α*Pn/(1-α)^n
104
+ np.divide(out, weights, out)
105
+
106
+ # α*P0/(1-α)^0, α*P0/(1-α)^0 + α*P1/(1-α)^1, ...
107
+ np.cumsum(out, axis=0, out=out)
108
+
109
+ # (α*P0/(1-α)^0)*(1-α)^0, (α*P0/(1-α)^0 + α*P1/(1-α)^1)*(1-α)^1, ...
110
+ np.multiply(out, weights, out)
111
+
112
+ # Add the previous output
113
+ if self.prev is None:
114
+ self.prev = arr[:1]
115
+
116
+ out += self.prev * np.expand_dims(
117
+ self.weights[1 : n + 1], list(range(1, arr.ndim))
118
+ )
119
+
120
+ self.prev = out[-1:]
121
+
122
+ return out
123
+
124
+ def compute2(self, arr: npt.NDArray) -> npt.NDArray:
125
+ """
126
+ Compute the Exponentially Weighted Moving Average (EWMA) of the input array.
127
+
128
+ Args:
129
+ arr: The input array to be smoothed.
130
+
131
+ Returns:
132
+ The smoothed array.
133
+ """
134
+ n = arr.shape[0]
135
+ if n > len(self.weights):
136
+ self._precalc_weights(n)
137
+ weights = self.weights[:n][::-1]
138
+ weights = np.expand_dims(weights, list(range(1, arr.ndim)))
139
+
140
+ result = np.cumsum(self.alpha * weights * arr, axis=0)
141
+ result = result / weights
142
+
143
+ # Handle the first call when prev is unset
144
+ if self.prev is None:
145
+ self.prev = arr[:1]
146
+
147
+ result += self.prev * np.expand_dims(
148
+ self.weights[1 : n + 1], list(range(1, arr.ndim))
149
+ )
150
+
151
+ # Store the result back into prev
152
+ self.prev = result[-1]
153
+
154
+ return result
155
+
156
+ def compute_sample(self, new_sample: npt.NDArray) -> npt.NDArray:
157
+ if self.prev is None:
158
+ self.prev = new_sample
159
+ self.prev = self._step_func(new_sample, self.prev)
160
+ return self.prev
161
+
162
+
30
163
  @consumer
31
164
  def scaler(
32
- time_constant: float = 1.0, axis: typing.Optional[str] = None
165
+ time_constant: float = 1.0, axis: str | None = None
33
166
  ) -> typing.Generator[AxisArray, AxisArray, None]:
34
167
  """
35
168
  Apply the adaptive standard scaler from https://riverml.xyz/latest/api/preprocessing/AdaptiveStandardScaler/
@@ -77,7 +210,7 @@ def scaler(
77
210
 
78
211
  @consumer
79
212
  def scaler_np(
80
- time_constant: float = 1.0, axis: typing.Optional[str] = None
213
+ time_constant: float = 1.0, axis: str | None = None
81
214
  ) -> typing.Generator[AxisArray, AxisArray, None]:
82
215
  """
83
216
  Create a generator function that applies an adaptive standard scaler.
@@ -95,10 +228,8 @@ def scaler_np(
95
228
  msg_out = AxisArray(np.array([]), dims=[""])
96
229
 
97
230
  # State variables
98
- alpha: float = 0.0
99
- means: typing.Optional[npt.NDArray] = None
100
- vars_means: typing.Optional[npt.NDArray] = None
101
- vars_sq_means: typing.Optional[npt.NDArray] = None
231
+ samps_ewma: EWMA | None = None
232
+ vars_sq_ewma: EWMA | None = None
102
233
 
103
234
  # Reset if input changes
104
235
  check_input = {
@@ -107,45 +238,32 @@ def scaler_np(
107
238
  "key": None, # Key change implies buffered means/vars are invalid.
108
239
  }
109
240
 
110
- def _ew_update(arr, prev, _alpha):
111
- if np.all(prev == 0):
112
- return arr
113
- # return _alpha * arr + (1 - _alpha) * prev
114
- # Micro-optimization: sub, mult, add (below) is faster than sub, mult, mult, add (above)
115
- return prev + _alpha * (arr - prev)
116
-
117
241
  while True:
118
242
  msg_in: AxisArray = yield msg_out
119
243
 
120
244
  axis = axis or msg_in.dims[0]
121
245
  axis_idx = msg_in.get_axis_idx(axis)
122
246
 
123
- if msg_in.axes[axis].gain != check_input["gain"]:
124
- alpha = _alpha_from_tau(time_constant, msg_in.axes[axis].gain)
125
- check_input["gain"] = msg_in.axes[axis].gain
126
-
127
247
  data: npt.NDArray = np.moveaxis(msg_in.data, axis_idx, 0)
128
248
  b_reset = data.shape[1:] != check_input["shape"]
129
- b_reset |= msg_in.key != check_input["key"]
249
+ b_reset = b_reset or msg_in.axes[axis].gain != check_input["gain"]
250
+ b_reset = b_reset or msg_in.key != check_input["key"]
130
251
  if b_reset:
131
252
  check_input["shape"] = data.shape[1:]
253
+ check_input["gain"] = msg_in.axes[axis].gain
132
254
  check_input["key"] = msg_in.key
133
- vars_sq_means = np.zeros_like(data[0], dtype=float)
134
- vars_means = np.zeros_like(data[0], dtype=float)
135
- means = np.zeros_like(data[0], dtype=float)
136
-
137
- result = np.zeros_like(data)
138
- for sample_ix in range(data.shape[0]):
139
- sample = data[sample_ix]
140
- # Update step
141
- vars_means = _ew_update(sample, vars_means, alpha)
142
- vars_sq_means = _ew_update(sample**2, vars_sq_means, alpha)
143
- means = _ew_update(sample, means, alpha)
144
- # Get step
145
- varis = vars_sq_means - vars_means**2
146
- y = (sample - means) / (varis**0.5)
147
- result[sample_ix] = y
255
+ alpha = _alpha_from_tau(time_constant, msg_in.axes[axis].gain)
256
+ samps_ewma = EWMA(alpha=alpha)
257
+ vars_sq_ewma = EWMA(alpha=alpha)
258
+
259
+ # Update step
260
+ means = samps_ewma.compute(data)
261
+ vars_sq_means = vars_sq_ewma.compute(data**2)
148
262
 
263
+ # Get step
264
+ varis = vars_sq_means - means**2
265
+ with np.errstate(divide="ignore", invalid="ignore"):
266
+ result = (data - means) / (varis**0.5)
149
267
  result[np.isnan(result)] = 0.0
150
268
  result = np.moveaxis(result, 0, axis_idx)
151
269
  msg_out = replace(msg_in, data=result)
@@ -158,7 +276,7 @@ class AdaptiveStandardScalerSettings(ez.Settings):
158
276
  """
159
277
 
160
278
  time_constant: float = 1.0
161
- axis: typing.Optional[str] = None
279
+ axis: str | None = None
162
280
 
163
281
 
164
282
  class AdaptiveStandardScaler(GenAxisArray):
@@ -1,21 +1,22 @@
1
1
  import typing
2
2
 
3
3
  import ezmsg.core as ez
4
- from ezmsg.util.messages.axisarray import AxisArray, replace
4
+ from ezmsg.util.messages.axisarray import AxisArray
5
+ from ezmsg.util.messages.util import replace
5
6
  import numpy as np
6
7
  import numpy.typing as npt
7
8
 
8
9
 
9
10
  class SignalInjectorSettings(ez.Settings):
10
11
  time_dim: str = "time" # Input signal needs a time dimension with units in sec.
11
- frequency: typing.Optional[float] = None # Hz
12
+ frequency: float | None = None # Hz
12
13
  amplitude: float = 1.0
13
- mixing_seed: typing.Optional[int] = None
14
+ mixing_seed: int | None = None
14
15
 
15
16
 
16
17
  class SignalInjectorState(ez.State):
17
- cur_shape: typing.Optional[typing.Tuple[int, ...]] = None
18
- cur_frequency: typing.Optional[float] = None
18
+ cur_shape: tuple[int, ...] | None = None
19
+ cur_frequency: float | None = None
19
20
  cur_amplitude: float
20
21
  mixing: npt.NDArray
21
22
 
@@ -29,7 +30,7 @@ class SignalInjector(ez.Unit):
29
30
  SETTINGS = SignalInjectorSettings
30
31
  STATE = SignalInjectorState
31
32
 
32
- INPUT_FREQUENCY = ez.InputStream(typing.Optional[float])
33
+ INPUT_FREQUENCY = ez.InputStream(float | None)
33
34
  INPUT_AMPLITUDE = ez.InputStream(float)
34
35
  INPUT_SIGNAL = ez.InputStream(AxisArray)
35
36
  OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
@@ -40,7 +41,7 @@ class SignalInjector(ez.Unit):
40
41
  self.STATE.mixing = np.array([])
41
42
 
42
43
  @ez.subscriber(INPUT_FREQUENCY)
43
- async def on_frequency(self, msg: typing.Optional[float]) -> None:
44
+ async def on_frequency(self, msg: float | None) -> None:
44
45
  self.STATE.cur_frequency = msg
45
46
 
46
47
  @ez.subscriber(INPUT_AMPLITUDE)
ezmsg/sigproc/slicer.py CHANGED
@@ -21,8 +21,8 @@ Slicer:Select a subset of data along a particular axis.
21
21
 
22
22
  def parse_slice(
23
23
  s: str,
24
- axinfo: typing.Optional[AxisArray.CoordinateAxis] = None,
25
- ) -> typing.Tuple[typing.Union[slice, int], ...]:
24
+ axinfo: AxisArray.CoordinateAxis | None = None,
25
+ ) -> tuple[slice | int, ...]:
26
26
  """
27
27
  Parses a string representation of a slice and returns a tuple of slice objects.
28
28
 
@@ -63,7 +63,7 @@ def parse_slice(
63
63
 
64
64
  @consumer
65
65
  def slicer(
66
- selection: str = "", axis: typing.Optional[str] = None
66
+ selection: str = "", axis: str | None = None
67
67
  ) -> typing.Generator[AxisArray, AxisArray, None]:
68
68
  """
69
69
  Slice along a particular axis.
@@ -80,8 +80,8 @@ def slicer(
80
80
  msg_out = AxisArray(np.array([]), dims=[""])
81
81
 
82
82
  # State variables
83
- _slice: typing.Optional[typing.Union[slice, npt.NDArray]] = None
84
- new_axis: typing.Optional[AxisBase] = None
83
+ _slice: slice | npt.NDArray | None = None
84
+ new_axis: AxisBase | None = None
85
85
  b_change_dims: bool = False # If number of dimensions changes when slicing
86
86
 
87
87
  # Reset if input changes
@@ -154,7 +154,7 @@ def slicer(
154
154
 
155
155
  class SlicerSettings(ez.Settings):
156
156
  selection: str = ""
157
- axis: typing.Optional[str] = None
157
+ axis: str | None = None
158
158
 
159
159
 
160
160
  class Slicer(GenAxisArray):
@@ -12,12 +12,12 @@ from .base import GenAxisArray
12
12
 
13
13
  @consumer
14
14
  def spectrogram(
15
- window_dur: typing.Optional[float] = None,
16
- window_shift: typing.Optional[float] = None,
15
+ window_dur: float | None = None,
16
+ window_shift: float | None = None,
17
17
  window: WindowFunction = WindowFunction.HANNING,
18
18
  transform: SpectralTransform = SpectralTransform.REL_DB,
19
19
  output: SpectralOutput = SpectralOutput.POSITIVE,
20
- ) -> typing.Generator[typing.Optional[AxisArray], AxisArray, None]:
20
+ ) -> typing.Generator[AxisArray | None, AxisArray, None]:
21
21
  """
22
22
  Calculate a spectrogram on streaming data.
23
23
 
@@ -50,7 +50,7 @@ def spectrogram(
50
50
  )
51
51
 
52
52
  # State variables
53
- msg_out: typing.Optional[AxisArray] = None
53
+ msg_out: AxisArray | None = None
54
54
 
55
55
  while True:
56
56
  msg_in: AxisArray = yield msg_out
@@ -63,8 +63,8 @@ class SpectrogramSettings(ez.Settings):
63
63
  See :obj:`spectrogram` for a description of the parameters.
64
64
  """
65
65
 
66
- window_dur: typing.Optional[float] = None # window duration in seconds
67
- window_shift: typing.Optional[float] = None
66
+ window_dur: float | None = None # window duration in seconds
67
+ window_shift: float | None = None
68
68
  """"window step in seconds. If None, window_shift == window_dur"""
69
69
 
70
70
  # See SpectrumSettings for details of following settings:
ezmsg/sigproc/spectrum.py CHANGED
@@ -68,14 +68,14 @@ class SpectralOutput(OptionsEnum):
68
68
 
69
69
  @consumer
70
70
  def spectrum(
71
- axis: typing.Optional[str] = None,
72
- out_axis: typing.Optional[str] = "freq",
71
+ axis: str | None = None,
72
+ out_axis: str | None = "freq",
73
73
  window: WindowFunction = WindowFunction.HANNING,
74
74
  transform: SpectralTransform = SpectralTransform.REL_DB,
75
75
  output: SpectralOutput = SpectralOutput.POSITIVE,
76
- norm: typing.Optional[str] = "forward",
76
+ norm: str | None = "forward",
77
77
  do_fftshift: bool = True,
78
- nfft: typing.Optional[int] = None,
78
+ nfft: int | None = None,
79
79
  ) -> typing.Generator[AxisArray, AxisArray, None]:
80
80
  """
81
81
  Calculate a spectrum on a data slice.
@@ -105,10 +105,10 @@ def spectrum(
105
105
  apply_window = window != WindowFunction.NONE
106
106
  do_fftshift &= output == SpectralOutput.FULL
107
107
  f_sl = slice(None)
108
- freq_axis: typing.Optional[AxisArray.LinearAxis] = None
109
- fftfun: typing.Optional[typing.Callable] = None
110
- f_transform: typing.Optional[typing.Callable] = None
111
- new_dims: typing.Optional[typing.List[str]] = None
108
+ freq_axis: AxisArray.LinearAxis | None = None
109
+ fftfun: typing.Callable | None = None
110
+ f_transform: typing.Callable | None = None
111
+ new_dims: list[str] | None = None
112
112
 
113
113
  # Reset if input changes substantially
114
114
  check_input = {
@@ -238,9 +238,9 @@ class SpectrumSettings(ez.Settings):
238
238
  See :obj:`spectrum` for a description of the parameters.
239
239
  """
240
240
 
241
- axis: typing.Optional[str] = None
242
- # n: typing.Optional[int] = None # n parameter for fft
243
- out_axis: typing.Optional[str] = "freq" # If none; don't change dim name
241
+ axis: str | None = None
242
+ # n: int | None = None # n parameter for fft
243
+ out_axis: str | None = "freq" # If none; don't change dim name
244
244
  window: WindowFunction = WindowFunction.HAMMING
245
245
  transform: SpectralTransform = SpectralTransform.REL_DB
246
246
  output: SpectralOutput = SpectralOutput.POSITIVE
ezmsg/sigproc/synth.py CHANGED
@@ -1,18 +1,19 @@
1
1
  import asyncio
2
2
  from dataclasses import field
3
3
  import time
4
- from typing import Optional, Generator, AsyncGenerator, Union
4
+ import typing
5
5
 
6
6
  import numpy as np
7
7
  import ezmsg.core as ez
8
8
  from ezmsg.util.generator import consumer
9
- from ezmsg.util.messages.axisarray import AxisArray, replace
9
+ from ezmsg.util.messages.axisarray import AxisArray
10
+ from ezmsg.util.messages.util import replace
10
11
 
11
12
  from .butterworthfilter import ButterworthFilter, ButterworthFilterSettings
12
13
  from .base import GenAxisArray
13
14
 
14
15
 
15
- def clock(dispatch_rate: Optional[float]) -> Generator[ez.Flag, None, None]:
16
+ def clock(dispatch_rate: float | None) -> typing.Generator[ez.Flag, None, None]:
16
17
  """
17
18
  Construct a generator that yields events at a specified rate.
18
19
 
@@ -32,7 +33,7 @@ def clock(dispatch_rate: Optional[float]) -> Generator[ez.Flag, None, None]:
32
33
  yield ez.Flag()
33
34
 
34
35
 
35
- async def aclock(dispatch_rate: Optional[float]) -> AsyncGenerator[ez.Flag, None]:
36
+ async def aclock(dispatch_rate: float | None) -> typing.AsyncGenerator[ez.Flag, None]:
36
37
  """
37
38
  ``asyncio`` version of :obj:`clock`.
38
39
 
@@ -53,12 +54,12 @@ class ClockSettings(ez.Settings):
53
54
  """Settings for :obj:`Clock`. See :obj:`clock` for parameter description."""
54
55
 
55
56
  # Message dispatch rate (Hz), or None (fast as possible)
56
- dispatch_rate: Optional[float]
57
+ dispatch_rate: float | None
57
58
 
58
59
 
59
60
  class ClockState(ez.State):
60
61
  cur_settings: ClockSettings
61
- gen: AsyncGenerator
62
+ gen: typing.AsyncGenerator
62
63
 
63
64
 
64
65
  class Clock(ez.Unit):
@@ -83,7 +84,7 @@ class Clock(ez.Unit):
83
84
  self.construct_generator()
84
85
 
85
86
  @ez.publisher(OUTPUT_CLOCK)
86
- async def generate(self) -> AsyncGenerator:
87
+ async def generate(self) -> typing.AsyncGenerator:
87
88
  while True:
88
89
  out = await self.STATE.gen.__anext__()
89
90
  if out:
@@ -93,11 +94,11 @@ class Clock(ez.Unit):
93
94
  # COUNTER - Generate incrementing integer. fs and dispatch_rate parameters combine to give many options. #
94
95
  async def acounter(
95
96
  n_time: int,
96
- fs: Optional[float],
97
+ fs: float | None,
97
98
  n_ch: int = 1,
98
- dispatch_rate: Optional[Union[float, str]] = None,
99
- mod: Optional[int] = None,
100
- ) -> AsyncGenerator[AxisArray, None]:
99
+ dispatch_rate: float | str | None = None,
100
+ mod: int | None = None,
101
+ ) -> typing.AsyncGenerator[AxisArray, None]:
101
102
  """
102
103
  Construct an asynchronous generator to generate AxisArray objects at a specified rate
103
104
  and with the specified sampling rate.
@@ -206,14 +207,14 @@ class CounterSettings(ez.Settings):
206
207
  # Message dispatch rate (Hz), 'realtime', 'ext_clock', or None (fast as possible)
207
208
  # Note: if dispatch_rate is a float then time offsets will be synthetic and the
208
209
  # system will run faster or slower than wall clock time.
209
- dispatch_rate: Optional[Union[float, str]] = None
210
+ dispatch_rate: float | str | None = None
210
211
 
211
212
  # If set to an integer, counter will rollover
212
- mod: Optional[int] = None
213
+ mod: int | None = None
213
214
 
214
215
 
215
216
  class CounterState(ez.State):
216
- gen: AsyncGenerator[AxisArray, Optional[ez.Flag]]
217
+ gen: typing.AsyncGenerator[AxisArray, ez.Flag | None]
217
218
  cur_settings: CounterSettings
218
219
  new_generator: asyncio.Event
219
220
 
@@ -262,7 +263,7 @@ class Counter(ez.Unit):
262
263
  yield self.OUTPUT_SIGNAL, out
263
264
 
264
265
  @ez.publisher(OUTPUT_SIGNAL)
265
- async def run_generator(self) -> AsyncGenerator:
266
+ async def run_generator(self) -> typing.AsyncGenerator:
266
267
  while True:
267
268
  await self.STATE.new_generator.wait()
268
269
  self.STATE.new_generator.clear()
@@ -277,11 +278,11 @@ class Counter(ez.Unit):
277
278
 
278
279
  @consumer
279
280
  def sin(
280
- axis: Optional[str] = "time",
281
+ axis: str | None = "time",
281
282
  freq: float = 1.0,
282
283
  amp: float = 1.0,
283
284
  phase: float = 0.0,
284
- ) -> Generator[AxisArray, AxisArray, None]:
285
+ ) -> typing.Generator[AxisArray, AxisArray, None]:
285
286
  """
286
287
  Construct a generator of sinusoidal waveforms in AxisArray objects.
287
288
 
@@ -319,7 +320,7 @@ class SinGeneratorSettings(ez.Settings):
319
320
  See :obj:`sin` for parameter descriptions.
320
321
  """
321
322
 
322
- time_axis: Optional[str] = "time"
323
+ time_axis: str | None = "time"
323
324
  freq: float = 1.0 # Oscillation frequency in Hz
324
325
  amp: float = 1.0 # Amplitude
325
326
  phase: float = 0.0 # Phase offset (in radians)
@@ -353,7 +354,7 @@ class OscillatorSettings(ez.Settings):
353
354
  n_ch: int = 1
354
355
  """Number of channels to output per block"""
355
356
 
356
- dispatch_rate: Optional[Union[float, str]] = None
357
+ dispatch_rate: float | str | None = None
357
358
  """(Hz) | 'realtime' | 'ext_clock'"""
358
359
 
359
360
  freq: float = 1.0
@@ -435,7 +436,7 @@ class RandomGenerator(ez.Unit):
435
436
 
436
437
  @ez.subscriber(INPUT_SIGNAL)
437
438
  @ez.publisher(OUTPUT_SIGNAL)
438
- async def generate(self, msg: AxisArray) -> AsyncGenerator:
439
+ async def generate(self, msg: AxisArray) -> typing.AsyncGenerator:
439
440
  random_data = np.random.normal(
440
441
  size=msg.shape, loc=self.SETTINGS.loc, scale=self.SETTINGS.scale
441
442
  )
@@ -451,7 +452,7 @@ class NoiseSettings(ez.Settings):
451
452
  n_time: int # Number of samples to output per block
452
453
  fs: float # Sampling rate of signal output in Hz
453
454
  n_ch: int = 1 # Number of channels to output
454
- dispatch_rate: Optional[Union[float, str]] = None
455
+ dispatch_rate: float | str | None = None
455
456
  """(Hz), 'realtime', or 'ext_clock'"""
456
457
  loc: float = 0.0 # DC offset
457
458
  scale: float = 1.0 # Scale (in standard deviations)
@@ -553,12 +554,12 @@ class Add(ez.Unit):
553
554
  self.STATE.queue_b.put_nowait(msg)
554
555
 
555
556
  @ez.publisher(OUTPUT_SIGNAL)
556
- async def output(self) -> AsyncGenerator:
557
+ async def output(self) -> typing.AsyncGenerator:
557
558
  while True:
558
559
  a = await self.STATE.queue_a.get()
559
560
  b = await self.STATE.queue_b.get()
560
561
 
561
- yield (self.OUTPUT_SIGNAL, replace(a, data=a.data + b.data))
562
+ yield self.OUTPUT_SIGNAL, replace(a, data=a.data + b.data)
562
563
 
563
564
 
564
565
  class EEGSynthSettings(ez.Settings):
ezmsg/sigproc/wavelets.py CHANGED
@@ -4,7 +4,8 @@ import numpy as np
4
4
  import numpy.typing as npt
5
5
  import pywt
6
6
  import ezmsg.core as ez
7
- from ezmsg.util.messages.axisarray import AxisArray, replace
7
+ from ezmsg.util.messages.axisarray import AxisArray
8
+ from ezmsg.util.messages.util import replace
8
9
  from ezmsg.util.generator import consumer
9
10
 
10
11
  from .base import GenAxisArray
@@ -13,44 +14,61 @@ from .filterbank import filterbank, FilterbankMode, MinPhaseMode
13
14
 
14
15
  @consumer
15
16
  def cwt(
16
- scales: typing.Union[list, tuple, npt.NDArray],
17
- wavelet: typing.Union[str, pywt.ContinuousWavelet, pywt.Wavelet],
17
+ frequencies: list | tuple | npt.NDArray | None,
18
+ wavelet: str | pywt.ContinuousWavelet | pywt.Wavelet,
18
19
  min_phase: MinPhaseMode = MinPhaseMode.NONE,
19
20
  axis: str = "time",
21
+ scales: list | tuple | npt.NDArray | None = None,
20
22
  ) -> typing.Generator[AxisArray, AxisArray, None]:
21
23
  """
22
24
  Perform a continuous wavelet transform.
23
25
  The function is equivalent to the :obj:`pywt.cwt` function, but is designed to work with streaming data.
24
26
 
25
27
  Args:
26
- scales: The wavelet scales to use. Note: Scales will be sorted from largest to smallest.
28
+ frequencies: The wavelet frequencies to use in Hz. If `None` provided then the scales will be used.
29
+ Note: frequencies will be sorted from smallest to largest.
27
30
  wavelet: Wavelet object or name of wavelet to use.
28
31
  min_phase: See filterbank MinPhaseMode for details.
29
32
  axis: The target axis for operation. Note that this will be moved to the -1th dimension
30
33
  because fft and matrix multiplication is much faster on the last axis.
31
34
  This axis must be in the msg.axes and it must be of type AxisArray.LinearAxis.
35
+ scales: The scales to use. If None, the scales will be calculated from the frequencies.
36
+ Note: Scales will be sorted from largest to smallest.
37
+ Note: Use of scales is deprecated in favor of frequencies. Convert scales to frequencies using
38
+ `pywt.scale2frequency(wavelet, scales, precision=10) * fs` where fs is the sampling frequency.
32
39
 
33
40
  Returns:
34
41
  A primed Generator object that expects an :obj:`AxisArray` via `.send(axis_array)` of continuous data
35
42
  and yields an :obj:`AxisArray` with a continuous wavelet transform in its data.
36
43
  """
37
- msg_out: typing.Optional[AxisArray] = None
44
+ precision = 10
45
+ msg_out: AxisArray | None = None
38
46
 
39
47
  # Check parameters
40
- scales = np.sort(scales)[::-1]
41
- assert np.all(scales > 0), "Scales must be positive."
42
- assert scales.ndim == 1, "Scales must be a 1D list, tuple, or array."
48
+ if frequencies is None and scales is None:
49
+ raise ValueError("Either frequencies or scales must be provided.")
50
+ if frequencies is not None and scales is not None:
51
+ raise ValueError("Only one of frequencies or scales can be provided.")
52
+ if scales is not None:
53
+ scales = np.sort(scales)[::-1]
54
+ assert np.all(scales > 0), "scales must be positive."
55
+ assert scales.ndim == 1, "scales must be a 1D list, tuple, or array."
56
+
43
57
  if not isinstance(wavelet, (pywt.ContinuousWavelet, pywt.Wavelet)):
44
58
  wavelet = pywt.DiscreteContinuousWavelet(wavelet)
45
- precision = 10
59
+
60
+ if frequencies is not None:
61
+ frequencies = np.sort(frequencies)
62
+ assert np.all(frequencies > 0), "frequencies must be positive."
63
+ assert frequencies.ndim == 1, "frequencies must be a 1D list, tuple, or array."
46
64
 
47
65
  # State variables
48
- neg_rt_scales = -np.sqrt(scales)[:, None]
66
+ neg_rt_scales: npt.NDArray | None = None
49
67
  int_psi, wave_xvec = pywt.integrate_wavelet(wavelet, precision=precision)
50
68
  int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi
51
- template: typing.Optional[AxisArray] = None
52
- fbgen: typing.Optional[typing.Generator[AxisArray, AxisArray, None]] = None
53
- last_conv_samp: typing.Optional[npt.NDArray] = None
69
+ template: AxisArray | None = None
70
+ fbgen: typing.Generator[AxisArray, AxisArray, None] | None = None
71
+ last_conv_samp: npt.NDArray | None = None
54
72
 
55
73
  # Reset if input changed
56
74
  check_input = {
@@ -76,6 +94,12 @@ def cwt(
76
94
  check_input["shape"] = in_shape
77
95
  check_input["key"] = msg_in.key
78
96
 
97
+ if frequencies is not None:
98
+ scales = pywt.frequency2scale(
99
+ wavelet, frequencies * msg_in.axes[axis].gain, precision=precision
100
+ )
101
+ neg_rt_scales = -np.sqrt(scales)[:, None]
102
+
79
103
  # convert int_psi, wave_xvec to the same precision as the data
80
104
  dt_data = msg_in.data.dtype # _check_dtype(msg_in.data)
81
105
  dt_cplx = np.result_type(dt_data, np.complex64)
@@ -148,8 +172,8 @@ class CWTSettings(ez.Settings):
148
172
  See :obj:`cwt` for argument details.
149
173
  """
150
174
 
151
- scales: typing.Union[list, tuple, npt.NDArray]
152
- wavelet: typing.Union[str, pywt.ContinuousWavelet, pywt.Wavelet]
175
+ scales: list | tuple | npt.NDArray
176
+ wavelet: str | pywt.ContinuousWavelet | pywt.Wavelet
153
177
  min_phase: MinPhaseMode = MinPhaseMode.NONE
154
178
  axis: str = "time"
155
179