ezmsg-sigproc 1.8.1__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. ezmsg/sigproc/__version__.py +2 -2
  2. ezmsg/sigproc/activation.py +36 -39
  3. ezmsg/sigproc/adaptive_lattice_notch.py +231 -0
  4. ezmsg/sigproc/affinetransform.py +169 -163
  5. ezmsg/sigproc/aggregate.py +119 -104
  6. ezmsg/sigproc/bandpower.py +58 -52
  7. ezmsg/sigproc/base.py +1242 -0
  8. ezmsg/sigproc/butterworthfilter.py +37 -33
  9. ezmsg/sigproc/cheby.py +29 -17
  10. ezmsg/sigproc/combfilter.py +163 -0
  11. ezmsg/sigproc/decimate.py +19 -10
  12. ezmsg/sigproc/detrend.py +29 -0
  13. ezmsg/sigproc/diff.py +81 -0
  14. ezmsg/sigproc/downsample.py +78 -78
  15. ezmsg/sigproc/ewma.py +197 -0
  16. ezmsg/sigproc/extract_axis.py +41 -0
  17. ezmsg/sigproc/filter.py +257 -141
  18. ezmsg/sigproc/filterbank.py +247 -199
  19. ezmsg/sigproc/math/abs.py +17 -22
  20. ezmsg/sigproc/math/clip.py +24 -24
  21. ezmsg/sigproc/math/difference.py +34 -30
  22. ezmsg/sigproc/math/invert.py +13 -25
  23. ezmsg/sigproc/math/log.py +28 -33
  24. ezmsg/sigproc/math/scale.py +18 -26
  25. ezmsg/sigproc/quantize.py +71 -0
  26. ezmsg/sigproc/resample.py +298 -0
  27. ezmsg/sigproc/sampler.py +241 -259
  28. ezmsg/sigproc/scaler.py +55 -218
  29. ezmsg/sigproc/signalinjector.py +52 -43
  30. ezmsg/sigproc/slicer.py +81 -89
  31. ezmsg/sigproc/spectrogram.py +77 -75
  32. ezmsg/sigproc/spectrum.py +203 -168
  33. ezmsg/sigproc/synth.py +546 -393
  34. ezmsg/sigproc/transpose.py +131 -0
  35. ezmsg/sigproc/util/asio.py +156 -0
  36. ezmsg/sigproc/util/message.py +31 -0
  37. ezmsg/sigproc/util/profile.py +55 -12
  38. ezmsg/sigproc/util/typeresolution.py +83 -0
  39. ezmsg/sigproc/wavelets.py +154 -153
  40. ezmsg/sigproc/window.py +269 -211
  41. {ezmsg_sigproc-1.8.1.dist-info → ezmsg_sigproc-2.0.0.dist-info}/METADATA +2 -1
  42. ezmsg_sigproc-2.0.0.dist-info/RECORD +51 -0
  43. ezmsg_sigproc-1.8.1.dist-info/RECORD +0 -39
  44. {ezmsg_sigproc-1.8.1.dist-info → ezmsg_sigproc-2.0.0.dist-info}/WHEEL +0 -0
  45. {ezmsg_sigproc-1.8.1.dist-info → ezmsg_sigproc-2.0.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -1,82 +1,86 @@
1
- import typing
2
-
3
1
  import numpy as np
4
2
  from ezmsg.util.messages.axisarray import (
5
3
  AxisArray,
6
4
  slice_along_axis,
7
5
  replace,
8
6
  )
9
- from ezmsg.util.generator import consumer
10
7
  import ezmsg.core as ez
11
8
 
12
- from .base import GenAxisArray
9
+ from .base import (
10
+ BaseStatefulTransformer,
11
+ BaseTransformerUnit,
12
+ processor_state,
13
+ )
13
14
 
14
15
 
15
- @consumer
16
- def downsample(
17
- axis: str | None = None, target_rate: float | None = None
18
- ) -> typing.Generator[AxisArray, AxisArray, None]:
16
+ class DownsampleSettings(ez.Settings):
19
17
  """
20
- Construct a generator that yields a downsampled version of the data .send() to it.
21
- Downsampled data simply comprise every `factor`th sample.
22
- This should only be used following appropriate lowpass filtering.
23
- If your pipeline does not already have lowpass filtering then consider
24
- using the :obj:`Decimate` collection instead.
18
+ Settings for :obj:`Downsample` node.
19
+ """
20
+
21
+ axis: str = "time"
22
+ """The name of the axis along which to downsample."""
25
23
 
26
- Args:
27
- axis: The name of the axis along which to downsample.
28
- Note: The axis must exist in the message .axes and be of type AxisArray.LinearAxis.
29
- target_rate: Desired rate after downsampling. The actual rate will be the nearest integer factor of the
30
- input rate that is the same or higher than the target rate.
24
+ target_rate: float | None = None
25
+ """Desired rate after downsampling. The actual rate will be the nearest integer factor of the
26
+ input rate that is the same or higher than the target rate."""
31
27
 
32
- Returns:
33
- A primed generator object ready to receive an :obj:`AxisArray` via `.send(axis_array)`
34
- and yields an :obj:`AxisArray` with its data downsampled.
35
- Note that if a send chunk does not have sufficient samples to reach the
36
- next downsample interval then an :obj:`AxisArray` with size-zero data is yielded.
28
+ factor: int | None = None
29
+ """Explicitly specify downsample factor. If specified, target_rate is ignored."""
37
30
 
38
- """
39
- msg_out = AxisArray(np.array([]), dims=[""])
40
31
 
41
- # state variables
42
- factor: int = 0 # The integer downsampling factor. It will be determined based on the target rate.
43
- s_idx: int = 0 # Index of the next msg's first sample into the virtual rotating ds_factor counter.
32
+ @processor_state
33
+ class DownsampleState:
34
+ q: int = 0
35
+ """The integer downsampling factor. It will be determined based on the target rate."""
44
36
 
45
- check_input = {"gain": None, "key": None}
37
+ s_idx: int = 0
38
+ """Index of the next msg's first sample into the virtual rotating ds_factor counter."""
39
+
40
+
41
+ class DownsampleTransformer(
42
+ BaseStatefulTransformer[DownsampleSettings, AxisArray, AxisArray, DownsampleState]
43
+ ):
44
+ """
45
+ Downsampled data simply comprise every `factor`th sample.
46
+ This should only be used following appropriate lowpass filtering.
47
+ If your pipeline does not already have lowpass filtering then consider
48
+ using the :obj:`Decimate` collection instead.
49
+ """
46
50
 
47
- while True:
48
- msg_in: AxisArray = yield msg_out
51
+ def _hash_message(self, message: AxisArray) -> int:
52
+ return hash((message.axes[self.settings.axis].gain, message.key))
49
53
 
50
- if axis is None:
51
- axis = msg_in.dims[0]
52
- axis_info = msg_in.get_axis(axis)
53
- axis_idx = msg_in.get_axis_idx(axis)
54
+ def _reset_state(self, message: AxisArray) -> None:
55
+ axis_info = message.get_axis(self.settings.axis)
54
56
 
55
- b_reset = (
56
- msg_in.axes[axis].gain != check_input["gain"]
57
- or msg_in.key != check_input["key"]
57
+ if self.settings.factor is not None:
58
+ q = self.settings.factor
59
+ elif self.settings.target_rate is None:
60
+ q = 1
61
+ else:
62
+ q = int(1 / (axis_info.gain * self.settings.target_rate))
63
+ if q < 1:
64
+ ez.logger.warning(
65
+ f"Target rate {self.settings.target_rate} cannot be achieved with input rate of {1 / axis_info.gain}."
66
+ "Setting factor to 1."
67
+ )
68
+ q = 1
69
+ self._state.q = q
70
+ self._state.s_idx = 0
71
+
72
+ def _process(self, message: AxisArray) -> AxisArray:
73
+ axis = self.settings.axis
74
+ axis_info = message.get_axis(axis)
75
+ axis_idx = message.get_axis_idx(axis)
76
+
77
+ n_samples = message.data.shape[axis_idx]
78
+ samples = (
79
+ np.arange(self.state.s_idx, self.state.s_idx + n_samples) % self._state.q
58
80
  )
59
- if b_reset:
60
- check_input["gain"] = axis_info.gain
61
- check_input["key"] = msg_in.key
62
- # Reset state variables
63
- s_idx = 0
64
- if target_rate is None:
65
- factor = 1
66
- else:
67
- factor = int(1 / (axis_info.gain * target_rate))
68
- if factor < 1:
69
- ez.logger.warning(
70
- f"Target rate {target_rate} cannot be achieved with input rate of {1/axis_info.gain}."
71
- "Setting factor to 1."
72
- )
73
- factor = 1
74
-
75
- n_samples = msg_in.data.shape[axis_idx]
76
- samples = np.arange(s_idx, s_idx + n_samples) % factor
77
81
  if n_samples > 0:
78
82
  # Update state for next iteration.
79
- s_idx = samples[-1] + 1
83
+ self._state.s_idx = samples[-1] + 1
80
84
 
81
85
  pub_samples = np.where(samples == 0)[0]
82
86
  if len(pub_samples) > 0:
@@ -86,35 +90,31 @@ def downsample(
86
90
  n_step = 0
87
91
  data_slice = slice(None, 0, None)
88
92
  msg_out = replace(
89
- msg_in,
90
- data=slice_along_axis(msg_in.data, data_slice, axis=axis_idx),
93
+ message,
94
+ data=slice_along_axis(message.data, data_slice, axis=axis_idx),
91
95
  axes={
92
- **msg_in.axes,
96
+ **message.axes,
93
97
  axis: replace(
94
98
  axis_info,
95
- gain=axis_info.gain * factor,
99
+ gain=axis_info.gain * self._state.q,
96
100
  offset=axis_info.offset + axis_info.gain * n_step,
97
101
  ),
98
102
  },
99
103
  )
104
+ return msg_out
100
105
 
101
106
 
102
- class DownsampleSettings(ez.Settings):
103
- """
104
- Settings for :obj:`Downsample` node.
105
- See :obj:`downsample` documentation for a description of the parameters.
106
- """
107
-
108
- axis: str | None = None
109
- target_rate: float | None = None
110
-
111
-
112
- class Downsample(GenAxisArray):
113
- """:obj:`Unit` for :obj:`bandpower`."""
114
-
107
class Downsample(
    BaseTransformerUnit[DownsampleSettings, AxisArray, AxisArray, DownsampleTransformer]
):
    """ezmsg Unit that runs :obj:`DownsampleTransformer`, configured by :obj:`DownsampleSettings`."""

    SETTINGS = DownsampleSettings
116
111
 
117
- def construct_generator(self):
118
- self.STATE.gen = downsample(
119
- axis=self.SETTINGS.axis, target_rate=self.SETTINGS.target_rate
120
- )
112
+
113
def downsample(
    axis: str = "time",
    target_rate: float | None = None,
    factor: int | None = None,
) -> DownsampleTransformer:
    """
    Convenience constructor for a :obj:`DownsampleTransformer`.

    Args:
        axis: The name of the axis along which to downsample.
        target_rate: Desired rate after downsampling; the transformer picks the
            nearest integer factor of the input rate that is at or above it.
        factor: Explicit integer downsample factor. If given, `target_rate` is ignored.

    Returns:
        A new DownsampleTransformer built from these settings.
    """
    settings = DownsampleSettings(axis=axis, target_rate=target_rate, factor=factor)
    return DownsampleTransformer(settings)
ezmsg/sigproc/ewma.py ADDED
@@ -0,0 +1,197 @@
1
+ from dataclasses import field
2
+ import functools
3
+
4
+ import numpy as np
5
+ import numpy.typing as npt
6
+ import scipy.signal as sps
7
+ import ezmsg.core as ez
8
+ from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
9
+ from ezmsg.util.messages.util import replace
10
+
11
+ from .base import BaseStatefulTransformer, processor_state, BaseTransformerUnit
12
+
13
+
14
+ def _tau_from_alpha(alpha: float, dt: float) -> float:
15
+ """
16
+ Inverse of _alpha_from_tau. See that function for explanation.
17
+ """
18
+ return -dt / np.log(1 - alpha)
19
+
20
+
21
+ def _alpha_from_tau(tau: float, dt: float) -> float:
22
+ """
23
+ # https://en.wikipedia.org/wiki/Exponential_smoothing#Time_constant
24
+ :param tau: The amount of time for the smoothed response of a unit step function to reach
25
+ 1 - 1/e approx-eq 63.2%.
26
+ :param dt: sampling period, or 1 / sampling_rate.
27
+ :return: alpha, the "fading factor" in exponential smoothing.
28
+ """
29
+ return 1 - np.exp(-dt / tau)
30
+
31
+
32
+ def ewma_step(
33
+ sample: npt.NDArray, zi: npt.NDArray, alpha: float, beta: float | None = None
34
+ ):
35
+ """
36
+ Do an exponentially weighted moving average step.
37
+
38
+ Args:
39
+ sample: The new sample.
40
+ zi: The output of the previous step.
41
+ alpha: Fading factor.
42
+ beta: Persisting factor. If None, it is calculated as 1-alpha.
43
+
44
+ Returns:
45
+ alpha * sample + beta * zi
46
+
47
+ """
48
+ # Potential micro-optimization:
49
+ # Current: scalar-arr multiplication, scalar-arr multiplication, arr-arr addition
50
+ # Alternative: arr-arr subtraction, arr-arr multiplication, arr-arr addition
51
+ # return zi + alpha * (new_sample - zi)
52
+ beta = beta or (1 - alpha)
53
+ return alpha * sample + beta * zi
54
+
55
+
56
class EWMA_Deprecated:
    """
    Grabbed these methods from https://stackoverflow.com/a/70998068 and other answers in that topic,
    but they ended up being slower than the scipy.signal.lfilter method.
    Additionally, `compute` and `compute2` suffer from potential errors as the vector length increases
    and beta**n approaches zero.

    Every method applies ``y[i] = alpha * x[i] + (1 - alpha) * y[i - 1]`` along
    axis 0 and carries the last output across calls in ``self.prev``. On the very
    first call ``y[-1]`` is seeded with the first input sample, so the first
    output equals the first input.

    Args:
        alpha: Fading factor.
        max_len: Initial number of precomputed weights. Longer chunks grow the
            weight table on demand.
    """

    def __init__(self, alpha: float, max_len: int):
        self.alpha = alpha
        self.beta = 1 - alpha
        self.prev: npt.NDArray | None = None
        self.weights = np.empty((max_len + 1,), float)
        self._precalc_weights(max_len)
        self._step_func = functools.partial(ewma_step, alpha=self.alpha, beta=self.beta)

    def _precalc_weights(self, n: int):
        """Fill ``self.weights`` with (1-α)^0, (1-α)^1, (1-α)^2, ..., (1-α)^n."""
        if self.weights.shape != (n + 1,):
            # np.power(..., out=...) needs an exactly-sized buffer; reallocate
            # instead of raising a shape error when the table changes length.
            self.weights = np.empty((n + 1,), float)
        np.power(self.beta, np.arange(n + 1), out=self.weights)

    def _ensure_weights(self, n: int):
        """Grow the weight table so that ``self.weights[: n + 1]`` is fully populated."""
        # Both compute methods read self.weights[1 : n + 1], i.e. n + 1 entries.
        if len(self.weights) < n + 1:
            self._precalc_weights(n)

    def compute(self, arr: npt.NDArray, out: npt.NDArray | None = None) -> npt.NDArray:
        """
        Smooth ``arr`` along axis 0 using the cumulative-sum formulation.

        Args:
            arr: The input chunk; axis 0 is the smoothing axis.
            out: Optional output buffer with the same shape as ``arr``.

        Returns:
            The smoothed chunk (``out`` if it was supplied).
        """
        n = arr.shape[0]
        if out is None:
            # The in-place float ufuncs below cannot write into an integer buffer.
            dtype = arr.dtype if np.issubdtype(arr.dtype, np.inexact) else float
            out = np.empty(arr.shape, dtype)
        if n == 0:
            # Nothing to smooth; leave self.prev untouched so the next chunk
            # still continues from the last real output.
            return out

        self._ensure_weights(n)

        # Seed the recursion before writing to ``out``. If the caller passed
        # out=arr, the input would otherwise already be overwritten.
        if self.prev is None:
            self.prev = arr[:1].copy()

        expand_axes = list(range(1, arr.ndim))
        weights = np.expand_dims(self.weights[:n], expand_axes)

        # α*P0, α*P1, α*P2, ..., α*Pn
        np.multiply(self.alpha, arr, out)

        # α*P0/(1-α)^0, α*P1/(1-α)^1, α*P2/(1-α)^2, ..., α*Pn/(1-α)^n
        np.divide(out, weights, out)

        # α*P0/(1-α)^0, α*P0/(1-α)^0 + α*P1/(1-α)^1, ...
        np.cumsum(out, axis=0, out=out)

        # (α*P0/(1-α)^0)*(1-α)^0, (α*P0/(1-α)^0 + α*P1/(1-α)^1)*(1-α)^1, ...
        np.multiply(out, weights, out)

        # Add the decayed contribution of the previous output.
        out += self.prev * np.expand_dims(self.weights[1 : n + 1], expand_axes)

        # Copy: ``out`` may be a caller-owned buffer that is overwritten on the
        # next call, which would silently corrupt the carried state.
        self.prev = out[-1:].copy()

        return out

    def compute2(self, arr: npt.NDArray) -> npt.NDArray:
        """
        Compute the Exponentially Weighted Moving Average (EWMA) of the input array.

        Args:
            arr: The input array to be smoothed.

        Returns:
            The smoothed array.
        """
        n = arr.shape[0]
        if n == 0:
            # Empty chunk: nothing to smooth and no last sample to carry over.
            return self.alpha * arr
        # weights[1 : n + 1] is read below, so n + 1 entries are required.
        self._ensure_weights(n)
        weights = self.weights[:n][::-1]
        weights = np.expand_dims(weights, list(range(1, arr.ndim)))

        result = np.cumsum(self.alpha * weights * arr, axis=0)
        result = result / weights

        # Handle the first call when prev is unset
        if self.prev is None:
            self.prev = arr[:1]

        result += self.prev * np.expand_dims(
            self.weights[1 : n + 1], list(range(1, arr.ndim))
        )

        # Store the result back into prev
        self.prev = result[-1]

        return result

    def compute_sample(self, new_sample: npt.NDArray) -> npt.NDArray:
        """Advance the recursion by a single sample and return the new output."""
        if self.prev is None:
            self.prev = new_sample
        self.prev = self._step_func(new_sample, self.prev)
        return self.prev
145
+
146
+
147
class EWMASettings(ez.Settings):
    """Settings for :obj:`EWMATransformer` / :obj:`EWMAUnit`."""

    # Smoothing time constant tau. It is converted to a fading factor together
    # with the smoothed axis' gain (sample period), so it shares that axis'
    # units (presumably seconds for a "time" axis -- confirm).
    time_constant: float = 1.0
    # Name of the axis to smooth along. None means the message's first dim.
    axis: str | None = None
150
+
151
+
152
@processor_state
class EWMAState:
    """Mutable state carried by :obj:`EWMATransformer` between messages."""

    # Fading factor. Recomputed from settings.time_constant and the axis gain
    # whenever the transformer resets, so this default is only a placeholder:
    # tau = 1.0 with dt = 1 / 1000 (a 1 kHz sample rate). The previous
    # dt=1000.0 yielded alpha == 1.0, i.e. no smoothing at all.
    alpha: float = field(default_factory=lambda: _alpha_from_tau(1.0, 1.0 / 1000.0))
    # scipy.signal.lfilter state carried across chunks; None until initialized.
    zi: npt.NDArray | None = None
156
+
157
+
158
class EWMATransformer(
    BaseStatefulTransformer[EWMASettings, AxisArray, AxisArray, EWMAState]
):
    """
    Exponentially weighted moving average along ``settings.axis``.

    Implements ``y[n] = alpha * x[n] + (1 - alpha) * y[n - 1]`` with
    :func:`scipy.signal.lfilter`, carrying the filter state across chunks.
    ``alpha`` is derived from ``settings.time_constant`` and the axis gain.
    """

    def _hash_message(self, message: AxisArray) -> int:
        # State is rebuilt when the per-sample shape, the sample period (gain)
        # or the stream key changes.
        axis = self.settings.axis or message.dims[0]
        axis_idx = message.get_axis_idx(axis)
        sample_shape = (
            message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
        )
        return hash((sample_shape, message.axes[axis].gain, message.key))

    def _reset_state(self, message: AxisArray) -> None:
        axis = self.settings.axis or message.dims[0]
        self._state.alpha = _alpha_from_tau(
            self.settings.time_constant, message.axes[axis].gain
        )
        self._state.zi = self._initial_zi(message, message.get_axis_idx(axis))

    def _initial_zi(self, message: AxisArray, axis_idx: int) -> npt.NDArray | None:
        """
        Build the lfilter state that makes the first output equal the first input.

        Returns None when ``message`` has no samples along ``axis_idx``. This
        defers initialization to the first non-empty message; otherwise a
        zero-length ``zi`` would be stored, and because the next message hashes
        the same, it would never be replaced and lfilter would reject it.
        """
        if message.data.shape[axis_idx] == 0:
            return None
        first = slice_along_axis(message.data, slice(None, 1, None), axis=axis_idx)
        # Treat y[-1] as x[0]: y[0] = alpha*x[0] + (1 - alpha)*x[0] = x[0].
        return (1 - self._state.alpha) * first

    def _process(self, message: AxisArray) -> AxisArray:
        if np.prod(message.data.shape) == 0:
            return message
        axis = self.settings.axis or message.dims[0]
        axis_idx = message.get_axis_idx(axis)
        if self._state.zi is None:
            # The last reset saw an empty message; seed the state from this one.
            self._state.zi = self._initial_zi(message, axis_idx)
        expected, self._state.zi = sps.lfilter(
            [self._state.alpha],
            [1.0, self._state.alpha - 1.0],
            message.data,
            axis=axis_idx,
            zi=self._state.zi,
        )
        return replace(message, data=expected)
192
+
193
+
194
class EWMAUnit(
    BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, EWMATransformer]
):
    """ezmsg Unit that runs :obj:`EWMATransformer`, configured by :obj:`EWMASettings`."""

    SETTINGS = EWMASettings
ezmsg/sigproc/extract_axis.py ADDED
@@ -0,0 +1,41 @@
1
+ import numpy as np
2
+ import ezmsg.core as ez
3
+ from ezmsg.util.messages.axisarray import AxisArray, replace
4
+ from ezmsg.sigproc.base import BaseTransformer, BaseTransformerUnit
5
+
6
+
7
class ExtractAxisSettings(ez.Settings):
    """Settings for :obj:`ExtractAxisData`."""

    # Name of the axis whose coordinate values become the output data.
    axis: str = "freq"
    # Axis that sizes and labels the output when `axis` has no `data`
    # attribute (the LinearAxis case in ExtractAxisData).
    reference: str = "time"
10
+
11
+
12
class ExtractAxisData(BaseTransformer[ExtractAxisSettings, AxisArray, AxisArray]):
    """
    Replace a message's data with the coordinate values of one of its axes.

    - If the target axis carries explicit coordinates (it has a ``data``
      attribute, i.e. a CoordinateAxis), those coordinates become the data.
      Only the message axes named in that axis' ``dims`` are kept.
    - Otherwise (a LinearAxis), the axis values are evaluated at every index of
      ``settings.reference``, producing 1-D data along that axis.
    """

    def _process(self, message: AxisArray) -> AxisArray:
        target = message.axes[self.settings.axis]

        if not hasattr(target, "data"):
            # LinearAxis: can only produce 1-D data, so dims and axes collapse
            # to the reference axis alone.
            ref = self.settings.reference
            n_ref = message.data.shape[message.get_axis_idx(ref)]
            return replace(
                message,
                data=target.value(np.arange(n_ref)),
                dims=[ref],
                axes={ref: message.axes[ref]},
            )

        # CoordinateAxis: its own data and dims become the message's.
        # Note: no transformer so far yields a coordinate axis with its own
        # axes; if that ever happens, it would need to be handled here.
        kept_axes = {
            name: ax for name, ax in message.axes.items() if name in target.dims
        }
        return replace(message, data=target.data, dims=target.dims, axes=kept_axes)
36
+
37
+
38
class ExtractAxisDataUnit(
    BaseTransformerUnit[ExtractAxisSettings, AxisArray, AxisArray, ExtractAxisData]
):
    """ezmsg Unit that runs :obj:`ExtractAxisData`, configured by :obj:`ExtractAxisSettings`."""

    SETTINGS = ExtractAxisSettings