ezmsg-sigproc 1.4.2__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ezmsg/sigproc/scaler.py CHANGED
@@ -1,10 +1,12 @@
1
- from dataclasses import replace
1
+ import functools
2
2
  import typing
3
3
 
4
4
  import numpy as np
5
5
  import numpy.typing as npt
6
+ import scipy.signal
6
7
  import ezmsg.core as ez
7
8
  from ezmsg.util.messages.axisarray import AxisArray
9
+ from ezmsg.util.messages.util import replace
8
10
  from ezmsg.util.generator import consumer
9
11
 
10
12
  from .base import GenAxisArray
@@ -28,9 +30,139 @@ def _alpha_from_tau(tau: float, dt: float) -> float:
28
30
  return 1 - np.exp(-dt / tau)
29
31
 
30
32
 
33
+ def ewma_step(
34
+ sample: npt.NDArray, zi: npt.NDArray, alpha: float, beta: float | None = None
35
+ ):
36
+ """
37
+ Do an exponentially weighted moving average step.
38
+
39
+ Args:
40
+ sample: The new sample.
41
+ zi: The output of the previous step.
42
+ alpha: Fading factor.
43
+ beta: Persisting factor. If None, it is calculated as 1-alpha.
44
+
45
+ Returns:
46
+ alpha * sample + beta * zi
47
+
48
+ """
49
+ # Potential micro-optimization:
50
+ # Current: scalar-arr multiplication, scalar-arr multiplication, arr-arr addition
51
+ # Alternative: arr-arr subtraction, arr-arr multiplication, arr-arr addition
52
+ # return zi + alpha * (new_sample - zi)
53
+ beta = beta or (1 - alpha)
54
+ return alpha * sample + beta * zi
55
+
56
+
57
class EWMA:
    """
    Streaming exponentially weighted moving average along axis 0.

    Implemented as the first-order IIR filter
    ``y[k] = alpha * x[k] + (1 - alpha) * y[k-1]`` using
    :func:`scipy.signal.lfilter`. The filter state is carried between calls,
    so feeding consecutive chunks gives the same result as one long array.

    Args:
        alpha: Fading factor (weight given to each new sample).
    """

    def __init__(self, alpha: float):
        self.beta = 1 - alpha
        # b = [alpha], a = [1, alpha - 1]  ->  y[k] = alpha*x[k] + beta*y[k-1]
        self._filt_func = functools.partial(
            scipy.signal.lfilter, [alpha], [1.0, alpha - 1.0], axis=0
        )
        # lfilter state `zi` (equal to beta * last output for this filter).
        # It has shape (1, *arr.shape[1:]) and is None until the first non-empty chunk.
        self.prev = None

    def compute(self, arr: npt.NDArray) -> npt.NDArray:
        """
        Smooth a chunk of samples, continuing from the previous call.

        Args:
            arr: Input chunk; axis 0 is the sample axis.

        Returns:
            The smoothed chunk, same shape as `arr`.
        """
        if arr.shape[0] == 0:
            # Nothing to filter. Leave the state untouched: seeding it from an
            # empty slice would give it a zero-length leading axis and make
            # every subsequent call fail inside lfilter.
            return np.zeros(arr.shape, dtype=np.result_type(arr.dtype, np.float64))
        if self.prev is None:
            # Seed the state as if the previous output equalled the first
            # sample, so the first output equals the first input.
            self.prev = self.beta * arr[:1]
        expected, self.prev = self._filt_func(arr, zi=self.prev)
        return expected
70
+
71
+
72
class EWMA_Deprecated:
    """
    Grabbed these methods from https://stackoverflow.com/a/70998068 and other answers in that topic,
    but they ended up being slower than the scipy.signal.lfilter method.
    Additionally, `compute` and `compute2` suffer from potential errors as the vector length increases
    and beta**n approaches zero.

    Args:
        alpha: Fading factor (weight given to each new sample).
        max_len: Chunk length to pre-compute weights for. Longer chunks are
            accepted; the weight table is grown on demand.
    """

    def __init__(self, alpha: float, max_len: int):
        self.alpha = alpha
        self.beta = 1 - alpha
        self.prev: npt.NDArray | None = None
        self.weights = np.empty((max_len + 1,), float)
        self._precalc_weights(max_len)
        self._step_func = functools.partial(ewma_step, alpha=self.alpha, beta=self.beta)

    def _precalc_weights(self, n: int):
        """Fill `self.weights` with (1-α)^0, (1-α)^1, ..., (1-α)^n."""
        if self.weights.shape != (n + 1,):
            # The in-place `out=` write below requires an exactly-sized buffer.
            self.weights = np.empty((n + 1,), float)
        np.power(self.beta, np.arange(n + 1), out=self.weights)

    def _ensure_weights(self, n: int):
        """Grow the weight table so that `self.weights[: n + 1]` is fully populated."""
        if len(self.weights) < n + 1:
            self._precalc_weights(n)

    def compute(self, arr: npt.NDArray, out: npt.NDArray | None = None) -> npt.NDArray:
        """
        Vectorized EWMA of `arr` along axis 0, continuing from the previous call.

        Args:
            arr: Input chunk; axis 0 is the sample axis.
            out: Optional output buffer with the same shape as `arr`.

        Returns:
            The smoothed chunk (`out` itself when it was provided).
        """
        if out is None:
            # Integer/bool input is promoted to float64: the in-place float
            # ufuncs below cannot cast their result into an integer buffer.
            dtype = arr.dtype if np.issubdtype(arr.dtype, np.inexact) else np.float64
            out = np.empty(arr.shape, dtype)

        n = arr.shape[0]
        if n == 0:
            # Nothing to smooth; leave `prev` alone (an empty slice would poison it).
            return out

        self._ensure_weights(n)
        trailing_axes = list(range(1, arr.ndim))
        weights = np.expand_dims(self.weights[:n], trailing_axes)

        # α*P0, α*P1, α*P2, ..., α*Pn
        np.multiply(self.alpha, arr, out)

        # α*P0/(1-α)^0, α*P1/(1-α)^1, α*P2/(1-α)^2, ..., α*Pn/(1-α)^n
        np.divide(out, weights, out)

        # α*P0/(1-α)^0, α*P0/(1-α)^0 + α*P1/(1-α)^1, ...
        np.cumsum(out, axis=0, out=out)

        # (α*P0/(1-α)^0)*(1-α)^0, (α*P0/(1-α)^0 + α*P1/(1-α)^1)*(1-α)^1, ...
        np.multiply(out, weights, out)

        # Add the decayed contribution of the previous output.
        if self.prev is None:
            self.prev = arr[:1]

        out += self.prev * np.expand_dims(self.weights[1 : n + 1], trailing_axes)

        # Copy: `out` may be a caller-owned buffer that is reused (and
        # overwritten) on the next call, which would silently corrupt `prev`.
        self.prev = out[-1:].copy()

        return out

    def compute2(self, arr: npt.NDArray) -> npt.NDArray:
        """
        Compute the Exponentially Weighted Moving Average (EWMA) of the input array.

        Args:
            arr: The input array to be smoothed; axis 0 is the sample axis.

        Returns:
            The smoothed array.
        """
        n = arr.shape[0]
        if n == 0:
            # Nothing to smooth; leave `prev` alone (result[-1] would not exist).
            return np.empty(arr.shape, np.result_type(self.weights.dtype, arr.dtype))

        # weights[1 : n + 1] is read below, so n + 1 entries are required.
        self._ensure_weights(n)
        trailing_axes = list(range(1, arr.ndim))
        # (1-α)^(n-1), ..., (1-α)^1, (1-α)^0
        weights = np.expand_dims(self.weights[:n][::-1], trailing_axes)

        result = np.cumsum(self.alpha * weights * arr, axis=0)
        result = result / weights

        # Handle the first call when prev is unset
        if self.prev is None:
            self.prev = arr[:1]

        result += self.prev * np.expand_dims(self.weights[1 : n + 1], trailing_axes)

        # Store the last output as state. The copy keeps later in-place edits
        # of the returned array from altering it.
        self.prev = result[-1].copy()

        return result

    def compute_sample(self, new_sample: npt.NDArray) -> npt.NDArray:
        """Advance the average by a single sample and return the new output."""
        if self.prev is None:
            self.prev = new_sample
        self.prev = self._step_func(new_sample, self.prev)
        return self.prev
161
+
162
+
31
163
  @consumer
32
164
  def scaler(
33
- time_constant: float = 1.0, axis: typing.Optional[str] = None
165
+ time_constant: float = 1.0, axis: str | None = None
34
166
  ) -> typing.Generator[AxisArray, AxisArray, None]:
35
167
  """
36
168
  Apply the adaptive standard scaler from https://riverml.xyz/latest/api/preprocessing/AdaptiveStandardScaler/
@@ -78,7 +210,7 @@ def scaler(
78
210
 
79
211
  @consumer
80
212
  def scaler_np(
81
- time_constant: float = 1.0, axis: typing.Optional[str] = None
213
+ time_constant: float = 1.0, axis: str | None = None
82
214
  ) -> typing.Generator[AxisArray, AxisArray, None]:
83
215
  """
84
216
  Create a generator function that applies an adaptive standard scaler.
@@ -87,6 +219,7 @@ def scaler_np(
87
219
  Args:
88
220
  time_constant: Decay constant `tau` in seconds.
89
221
  axis: The name of the axis to accumulate statistics over.
222
+ Note: The axis must be in the msg.axes and be of type AxisArray.LinearAxis.
90
223
 
91
224
  Returns:
92
225
  A primed generator object that expects to be sent a :obj:`AxisArray` via `.send(axis_array)`
@@ -95,10 +228,8 @@ def scaler_np(
95
228
  msg_out = AxisArray(np.array([]), dims=[""])
96
229
 
97
230
  # State variables
98
- alpha: float = 0.0
99
- means: typing.Optional[npt.NDArray] = None
100
- vars_means: typing.Optional[npt.NDArray] = None
101
- vars_sq_means: typing.Optional[npt.NDArray] = None
231
+ samps_ewma: EWMA | None = None
232
+ vars_sq_ewma: EWMA | None = None
102
233
 
103
234
  # Reset if input changes
104
235
  check_input = {
@@ -107,45 +238,32 @@ def scaler_np(
107
238
  "key": None, # Key change implies buffered means/vars are invalid.
108
239
  }
109
240
 
110
- def _ew_update(arr, prev, _alpha):
111
- if np.all(prev == 0):
112
- return arr
113
- # return _alpha * arr + (1 - _alpha) * prev
114
- # Micro-optimization: sub, mult, add (below) is faster than sub, mult, mult, add (above)
115
- return prev + _alpha * (arr - prev)
116
-
117
241
  while True:
118
242
  msg_in: AxisArray = yield msg_out
119
243
 
120
244
  axis = axis or msg_in.dims[0]
121
245
  axis_idx = msg_in.get_axis_idx(axis)
122
246
 
123
- if msg_in.axes[axis].gain != check_input["gain"]:
124
- alpha = _alpha_from_tau(time_constant, msg_in.axes[axis].gain)
125
- check_input["gain"] = msg_in.axes[axis].gain
126
-
127
247
  data: npt.NDArray = np.moveaxis(msg_in.data, axis_idx, 0)
128
248
  b_reset = data.shape[1:] != check_input["shape"]
129
- b_reset |= msg_in.key != check_input["key"]
249
+ b_reset = b_reset or msg_in.axes[axis].gain != check_input["gain"]
250
+ b_reset = b_reset or msg_in.key != check_input["key"]
130
251
  if b_reset:
131
252
  check_input["shape"] = data.shape[1:]
253
+ check_input["gain"] = msg_in.axes[axis].gain
132
254
  check_input["key"] = msg_in.key
133
- vars_sq_means = np.zeros_like(data[0], dtype=float)
134
- vars_means = np.zeros_like(data[0], dtype=float)
135
- means = np.zeros_like(data[0], dtype=float)
136
-
137
- result = np.zeros_like(data)
138
- for sample_ix in range(data.shape[0]):
139
- sample = data[sample_ix]
140
- # Update step
141
- vars_means = _ew_update(sample, vars_means, alpha)
142
- vars_sq_means = _ew_update(sample**2, vars_sq_means, alpha)
143
- means = _ew_update(sample, means, alpha)
144
- # Get step
145
- varis = vars_sq_means - vars_means**2
146
- y = (sample - means) / (varis**0.5)
147
- result[sample_ix] = y
255
+ alpha = _alpha_from_tau(time_constant, msg_in.axes[axis].gain)
256
+ samps_ewma = EWMA(alpha=alpha)
257
+ vars_sq_ewma = EWMA(alpha=alpha)
258
+
259
+ # Update step
260
+ means = samps_ewma.compute(data)
261
+ vars_sq_means = vars_sq_ewma.compute(data**2)
148
262
 
263
+ # Get step
264
+ varis = vars_sq_means - means**2
265
+ with np.errstate(divide="ignore", invalid="ignore"):
266
+ result = (data - means) / (varis**0.5)
149
267
  result[np.isnan(result)] = 0.0
150
268
  result = np.moveaxis(result, 0, axis_idx)
151
269
  msg_out = replace(msg_in, data=result)
@@ -158,7 +276,7 @@ class AdaptiveStandardScalerSettings(ez.Settings):
158
276
  """
159
277
 
160
278
  time_constant: float = 1.0
161
- axis: typing.Optional[str] = None
279
+ axis: str | None = None
162
280
 
163
281
 
164
282
  class AdaptiveStandardScaler(GenAxisArray):
@@ -1,22 +1,22 @@
1
- from dataclasses import replace
2
1
  import typing
3
2
 
4
3
  import ezmsg.core as ez
5
4
  from ezmsg.util.messages.axisarray import AxisArray
5
+ from ezmsg.util.messages.util import replace
6
6
  import numpy as np
7
7
  import numpy.typing as npt
8
8
 
9
9
 
10
10
  class SignalInjectorSettings(ez.Settings):
11
11
  time_dim: str = "time" # Input signal needs a time dimension with units in sec.
12
- frequency: typing.Optional[float] = None # Hz
12
+ frequency: float | None = None # Hz
13
13
  amplitude: float = 1.0
14
- mixing_seed: typing.Optional[int] = None
14
+ mixing_seed: int | None = None
15
15
 
16
16
 
17
17
  class SignalInjectorState(ez.State):
18
- cur_shape: typing.Optional[typing.Tuple[int, ...]] = None
19
- cur_frequency: typing.Optional[float] = None
18
+ cur_shape: tuple[int, ...] | None = None
19
+ cur_frequency: float | None = None
20
20
  cur_amplitude: float
21
21
  mixing: npt.NDArray
22
22
 
@@ -30,7 +30,7 @@ class SignalInjector(ez.Unit):
30
30
  SETTINGS = SignalInjectorSettings
31
31
  STATE = SignalInjectorState
32
32
 
33
- INPUT_FREQUENCY = ez.InputStream(typing.Optional[float])
33
+ INPUT_FREQUENCY = ez.InputStream(float | None)
34
34
  INPUT_AMPLITUDE = ez.InputStream(float)
35
35
  INPUT_SIGNAL = ez.InputStream(AxisArray)
36
36
  OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
@@ -41,7 +41,7 @@ class SignalInjector(ez.Unit):
41
41
  self.STATE.mixing = np.array([])
42
42
 
43
43
  @ez.subscriber(INPUT_FREQUENCY)
44
- async def on_frequency(self, msg: typing.Optional[float]) -> None:
44
+ async def on_frequency(self, msg: float | None) -> None:
45
45
  self.STATE.cur_frequency = msg
46
46
 
47
47
  @ez.subscriber(INPUT_AMPLITUDE)
ezmsg/sigproc/slicer.py CHANGED
@@ -1,10 +1,14 @@
1
- from dataclasses import replace
2
1
  import typing
3
2
 
4
3
  import numpy as np
5
4
  import numpy.typing as npt
6
5
  import ezmsg.core as ez
7
- from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
6
+ from ezmsg.util.messages.axisarray import (
7
+ AxisArray,
8
+ slice_along_axis,
9
+ AxisBase,
10
+ replace,
11
+ )
8
12
  from ezmsg.util.generator import consumer
9
13
 
10
14
  from .base import GenAxisArray
@@ -15,7 +19,10 @@ Slicer:Select a subset of data along a particular axis.
15
19
  """
16
20
 
17
21
 
18
- def parse_slice(s: str) -> typing.Tuple[typing.Union[slice, int], ...]:
22
+ def parse_slice(
23
+ s: str,
24
+ axinfo: AxisArray.CoordinateAxis | None = None,
25
+ ) -> tuple[slice | int, ...]:
19
26
  """
20
27
  Parses a string representation of a slice and returns a tuple of slice objects.
21
28
 
@@ -26,9 +33,13 @@ def parse_slice(s: str) -> typing.Tuple[typing.Union[slice, int], ...]:
26
33
  - "5" (or any integer) -> (5,). Take only that item.
27
34
  applying this to a ndarray or AxisArray will drop the dimension.
28
35
  - A comma-separated list of the above -> a tuple of slices | ints
36
+ - A comma-separated list of values and axinfo is provided and is a CoordinateAxis -> a tuple of ints
29
37
 
30
38
  Args:
31
39
  s: The string representation of the slice.
40
+ axinfo: (Optional) If provided, and of type CoordinateAxis,
41
+ and `s` is a comma-separated list of values, then the values
42
+ in s will be checked against the values in axinfo.data.
32
43
 
33
44
  Returns:
34
45
  A tuple of slice objects and/or ints.
@@ -38,15 +49,21 @@ def parse_slice(s: str) -> typing.Tuple[typing.Union[slice, int], ...]:
38
49
  if "," not in s:
39
50
  parts = [part.strip() for part in s.split(":")]
40
51
  if len(parts) == 1:
52
+ if (
53
+ axinfo is not None
54
+ and hasattr(axinfo, "data")
55
+ and parts[0] in axinfo.data
56
+ ):
57
+ return tuple(np.where(axinfo.data == parts[0])[0])
41
58
  return (int(parts[0]),)
42
59
  return (slice(*(int(part.strip()) if part else None for part in parts)),)
43
- suplist = [parse_slice(_) for _ in s.split(",")]
60
+ suplist = [parse_slice(_, axinfo=axinfo) for _ in s.split(",")]
44
61
  return tuple([item for sublist in suplist for item in sublist])
45
62
 
46
63
 
47
64
  @consumer
48
65
  def slicer(
49
- selection: str = "", axis: typing.Optional[str] = None
66
+ selection: str = "", axis: str | None = None
50
67
  ) -> typing.Generator[AxisArray, AxisArray, None]:
51
68
  """
52
69
  Slice along a particular axis.
@@ -63,8 +80,8 @@ def slicer(
63
80
  msg_out = AxisArray(np.array([]), dims=[""])
64
81
 
65
82
  # State variables
66
- _slice: typing.Optional[typing.Union[slice, npt.NDArray]] = None
67
- new_axis: typing.Optional[AxisArray.Axis] = None
83
+ _slice: slice | npt.NDArray | None = None
84
+ new_axis: AxisBase | None = None
68
85
  b_change_dims: bool = False # If number of dimensions changes when slicing
69
86
 
70
87
  # Reset if input changes
@@ -92,7 +109,7 @@ def slicer(
92
109
  b_change_dims = False
93
110
 
94
111
  # Calculate the slice
95
- _slices = parse_slice(selection)
112
+ _slices = parse_slice(selection, msg_in.axes.get(axis, None))
96
113
  if len(_slices) == 1:
97
114
  _slice = _slices[0]
98
115
  # Do we drop the sliced dimension?
@@ -107,12 +124,15 @@ def slicer(
107
124
  # Create the output axis.
108
125
  if (
109
126
  axis in msg_in.axes
110
- and hasattr(msg_in.axes[axis], "labels")
111
- and len(msg_in.axes[axis].labels) > 0
127
+ and hasattr(msg_in.axes[axis], "data")
128
+ and len(msg_in.axes[axis].data) > 0
112
129
  ):
113
- in_labels = np.array(msg_in.axes[axis].labels)
114
- new_labels = in_labels[_slice].tolist()
115
- new_axis = replace(msg_in.axes[axis], labels=new_labels)
130
+ in_data = np.array(msg_in.axes[axis].data)
131
+ if b_change_dims:
132
+ out_data = in_data[_slice : _slice + 1]
133
+ else:
134
+ out_data = in_data[_slice]
135
+ new_axis = replace(msg_in.axes[axis], data=out_data)
116
136
 
117
137
  replace_kwargs = {}
118
138
  if b_change_dims:
@@ -134,7 +154,7 @@ def slicer(
134
154
 
135
155
  class SlicerSettings(ez.Settings):
136
156
  selection: str = ""
137
- axis: typing.Optional[str] = None
157
+ axis: str | None = None
138
158
 
139
159
 
140
160
  class Slicer(GenAxisArray):
@@ -12,12 +12,12 @@ from .base import GenAxisArray
12
12
 
13
13
  @consumer
14
14
  def spectrogram(
15
- window_dur: typing.Optional[float] = None,
16
- window_shift: typing.Optional[float] = None,
15
+ window_dur: float | None = None,
16
+ window_shift: float | None = None,
17
17
  window: WindowFunction = WindowFunction.HANNING,
18
18
  transform: SpectralTransform = SpectralTransform.REL_DB,
19
19
  output: SpectralOutput = SpectralOutput.POSITIVE,
20
- ) -> typing.Generator[typing.Optional[AxisArray], AxisArray, None]:
20
+ ) -> typing.Generator[AxisArray | None, AxisArray, None]:
21
21
  """
22
22
  Calculate a spectrogram on streaming data.
23
23
 
@@ -50,7 +50,7 @@ def spectrogram(
50
50
  )
51
51
 
52
52
  # State variables
53
- msg_out: typing.Optional[AxisArray] = None
53
+ msg_out: AxisArray | None = None
54
54
 
55
55
  while True:
56
56
  msg_in: AxisArray = yield msg_out
@@ -63,8 +63,8 @@ class SpectrogramSettings(ez.Settings):
63
63
  See :obj:`spectrogram` for a description of the parameters.
64
64
  """
65
65
 
66
- window_dur: typing.Optional[float] = None # window duration in seconds
67
- window_shift: typing.Optional[float] = None
66
+ window_dur: float | None = None # window duration in seconds
67
+ window_shift: float | None = None
68
68
  """"window step in seconds. If None, window_shift == window_dur"""
69
69
 
70
70
  # See SpectrumSettings for details of following settings:
ezmsg/sigproc/spectrum.py CHANGED
@@ -1,11 +1,14 @@
1
- from dataclasses import replace
2
1
  import enum
3
2
  from functools import partial
4
3
  import typing
5
4
 
6
5
  import numpy as np
7
6
  import ezmsg.core as ez
8
- from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
7
+ from ezmsg.util.messages.axisarray import (
8
+ AxisArray,
9
+ slice_along_axis,
10
+ replace,
11
+ )
9
12
  from ezmsg.util.generator import consumer
10
13
 
11
14
  from .base import GenAxisArray
@@ -65,20 +68,21 @@ class SpectralOutput(OptionsEnum):
65
68
 
66
69
  @consumer
67
70
  def spectrum(
68
- axis: typing.Optional[str] = None,
69
- out_axis: typing.Optional[str] = "freq",
71
+ axis: str | None = None,
72
+ out_axis: str | None = "freq",
70
73
  window: WindowFunction = WindowFunction.HANNING,
71
74
  transform: SpectralTransform = SpectralTransform.REL_DB,
72
75
  output: SpectralOutput = SpectralOutput.POSITIVE,
73
- norm: typing.Optional[str] = "forward",
76
+ norm: str | None = "forward",
74
77
  do_fftshift: bool = True,
75
- nfft: typing.Optional[int] = None,
78
+ nfft: int | None = None,
76
79
  ) -> typing.Generator[AxisArray, AxisArray, None]:
77
80
  """
78
81
  Calculate a spectrum on a data slice.
79
82
 
80
83
  Args:
81
84
  axis: The name of the axis on which to calculate the spectrum.
85
+ Note: The axis must have an .axes entry of type LinearAxis, not CoordinateAxis.
82
86
  out_axis: The name of the new axis. Defaults to "freq".
83
87
  window: The :obj:`WindowFunction` to apply to the data slice prior to calculating the spectrum.
84
88
  transform: The :obj:`SpectralTransform` to apply to the spectral magnitude.
@@ -101,10 +105,10 @@ def spectrum(
101
105
  apply_window = window != WindowFunction.NONE
102
106
  do_fftshift &= output == SpectralOutput.FULL
103
107
  f_sl = slice(None)
104
- freq_axis: typing.Optional[AxisArray.Axis] = None
105
- fftfun: typing.Optional[typing.Callable] = None
106
- f_transform: typing.Optional[typing.Callable] = None
107
- new_dims: typing.Optional[typing.List[str]] = None
108
+ freq_axis: AxisArray.LinearAxis | None = None
109
+ fftfun: typing.Callable | None = None
110
+ f_transform: typing.Callable | None = None
111
+ new_dims: list[str] | None = None
108
112
 
109
113
  # Reset if input changes substantially
110
114
  check_input = {
@@ -174,7 +178,7 @@ def spectrum(
174
178
  freqs = np.fft.fftshift(freqs, axes=-1)
175
179
  freqs = freqs[f_sl]
176
180
  freqs = freqs.tolist() # To please type checking
177
- freq_axis = AxisArray.Axis(
181
+ freq_axis = AxisArray.LinearAxis(
178
182
  unit="Hz", gain=freqs[1] - freqs[0], offset=freqs[0]
179
183
  )
180
184
  if out_axis is None:
@@ -234,9 +238,9 @@ class SpectrumSettings(ez.Settings):
234
238
  See :obj:`spectrum` for a description of the parameters.
235
239
  """
236
240
 
237
- axis: typing.Optional[str] = None
238
- # n: typing.Optional[int] = None # n parameter for fft
239
- out_axis: typing.Optional[str] = "freq" # If none; don't change dim name
241
+ axis: str | None = None
242
+ # n: int | None = None # n parameter for fft
243
+ out_axis: str | None = "freq" # If none; don't change dim name
240
244
  window: WindowFunction = WindowFunction.HAMMING
241
245
  transform: SpectralTransform = SpectralTransform.REL_DB
242
246
  output: SpectralOutput = SpectralOutput.POSITIVE