ezmsg-sigproc 1.8.2__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. ezmsg/sigproc/__version__.py +2 -2
  2. ezmsg/sigproc/activation.py +36 -39
  3. ezmsg/sigproc/adaptive_lattice_notch.py +231 -0
  4. ezmsg/sigproc/affinetransform.py +169 -163
  5. ezmsg/sigproc/aggregate.py +133 -101
  6. ezmsg/sigproc/bandpower.py +64 -52
  7. ezmsg/sigproc/base.py +1242 -0
  8. ezmsg/sigproc/butterworthfilter.py +37 -33
  9. ezmsg/sigproc/cheby.py +29 -17
  10. ezmsg/sigproc/combfilter.py +163 -0
  11. ezmsg/sigproc/decimate.py +19 -10
  12. ezmsg/sigproc/detrend.py +29 -0
  13. ezmsg/sigproc/diff.py +81 -0
  14. ezmsg/sigproc/downsample.py +78 -84
  15. ezmsg/sigproc/ewma.py +197 -0
  16. ezmsg/sigproc/extract_axis.py +41 -0
  17. ezmsg/sigproc/filter.py +257 -141
  18. ezmsg/sigproc/filterbank.py +247 -199
  19. ezmsg/sigproc/math/abs.py +17 -22
  20. ezmsg/sigproc/math/clip.py +24 -24
  21. ezmsg/sigproc/math/difference.py +34 -30
  22. ezmsg/sigproc/math/invert.py +13 -25
  23. ezmsg/sigproc/math/log.py +28 -33
  24. ezmsg/sigproc/math/scale.py +18 -26
  25. ezmsg/sigproc/quantize.py +71 -0
  26. ezmsg/sigproc/resample.py +298 -0
  27. ezmsg/sigproc/sampler.py +241 -259
  28. ezmsg/sigproc/scaler.py +55 -218
  29. ezmsg/sigproc/signalinjector.py +52 -43
  30. ezmsg/sigproc/slicer.py +81 -89
  31. ezmsg/sigproc/spectrogram.py +77 -75
  32. ezmsg/sigproc/spectrum.py +203 -168
  33. ezmsg/sigproc/synth.py +546 -393
  34. ezmsg/sigproc/transpose.py +131 -0
  35. ezmsg/sigproc/util/asio.py +156 -0
  36. ezmsg/sigproc/util/message.py +31 -0
  37. ezmsg/sigproc/util/profile.py +55 -12
  38. ezmsg/sigproc/util/typeresolution.py +83 -0
  39. ezmsg/sigproc/wavelets.py +154 -153
  40. ezmsg/sigproc/window.py +269 -211
  41. {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.1.0.dist-info}/METADATA +2 -1
  42. ezmsg_sigproc-2.1.0.dist-info/RECORD +51 -0
  43. ezmsg_sigproc-1.8.2.dist-info/RECORD +0 -39
  44. {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.1.0.dist-info}/WHEEL +0 -0
  45. {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.1.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '1.8.2'
21
- __version_tuple__ = version_tuple = (1, 8, 2)
20
+ __version__ = version = '2.1.0'
21
+ __version_tuple__ = version_tuple = (2, 1, 0)
@@ -1,14 +1,10 @@
1
- import typing
2
-
3
- import numpy as np
4
1
  import scipy.special
5
2
  import ezmsg.core as ez
6
3
  from ezmsg.util.messages.axisarray import AxisArray
7
4
  from ezmsg.util.messages.util import replace
8
- from ezmsg.util.generator import consumer
9
5
 
10
6
  from .spectral import OptionsEnum
11
- from .base import GenAxisArray
7
+ from .base import BaseTransformer, BaseTransformerUnit
12
8
 
13
9
 
14
10
  class ActivationFunction(OptionsEnum):
@@ -39,10 +35,41 @@ ACTIVATIONS = {
39
35
  }
40
36
 
41
37
 
42
- @consumer
38
+ class ActivationSettings(ez.Settings):
39
+ function: str | ActivationFunction = ActivationFunction.NONE
40
+ """An enum value from ActivationFunction or a string representing the activation function.
41
+ Possible values are: SIGMOID, EXPIT, LOGIT, LOGEXPIT, "sigmoid", "expit", "logit", "log_expit".
42
+ SIGMOID and EXPIT are equivalent. See :obj:`scipy.special.expit` for more details."""
43
+
44
+
45
+ class ActivationTransformer(BaseTransformer[ActivationSettings, AxisArray, AxisArray]):
46
+ def _process(self, message: AxisArray) -> AxisArray:
47
+ if type(self.settings.function) is ActivationFunction:
48
+ func = ACTIVATIONS[self.settings.function]
49
+ else:
50
+ # str type handling
51
+ function = self.settings.function.lower()
52
+ if function not in ActivationFunction.options():
53
+ raise ValueError(
54
+ f"Unrecognized activation function {function}. Must be one of {ACTIVATIONS.keys()}"
55
+ )
56
+ function = list(ACTIVATIONS.keys())[
57
+ ActivationFunction.options().index(function)
58
+ ]
59
+ func = ACTIVATIONS[function]
60
+
61
+ return replace(message, data=func(message.data))
62
+
63
+
64
+ class Activation(
65
+ BaseTransformerUnit[ActivationSettings, AxisArray, AxisArray, ActivationTransformer]
66
+ ):
67
+ SETTINGS = ActivationSettings
68
+
69
+
43
70
  def activation(
44
71
  function: str | ActivationFunction,
45
- ) -> typing.Generator[AxisArray, AxisArray, None]:
72
+ ) -> ActivationTransformer:
46
73
  """
47
74
  Transform the data with a simple activation function.
48
75
 
@@ -51,37 +78,7 @@ def activation(
51
78
  Possible values are: SIGMOID, EXPIT, LOGIT, LOGEXPIT, "sigmoid", "expit", "logit", "log_expit".
52
79
  SIGMOID and EXPIT are equivalent. See :obj:`scipy.special.expit` for more details.
53
80
 
54
- Returns: A primed generator that, when passed an input message via `.send(msg)`, yields an AxisArray
55
- with the data payload containing a transformed version of the input data.
81
+ Returns: :obj:`ActivationTransformer`
56
82
 
57
83
  """
58
- if type(function) is ActivationFunction:
59
- func = ACTIVATIONS[function]
60
- else:
61
- # str type. There's probably an easier way to support either enum or str argument. Oh well this works.
62
- function: str = function.lower()
63
- if function not in ActivationFunction.options():
64
- raise ValueError(
65
- f"Unrecognized activation function {function}. Must be one of {ACTIVATIONS.keys()}"
66
- )
67
- function = list(ACTIVATIONS.keys())[
68
- ActivationFunction.options().index(function)
69
- ]
70
- func = ACTIVATIONS[function]
71
-
72
- msg_out = AxisArray(np.array([]), dims=[""])
73
-
74
- while True:
75
- msg_in: AxisArray = yield msg_out
76
- msg_out = replace(msg_in, data=func(msg_in.data))
77
-
78
-
79
- class ActivationSettings(ez.Settings):
80
- function: str = ActivationFunction.NONE
81
-
82
-
83
- class Activation(GenAxisArray):
84
- SETTINGS = ActivationSettings
85
-
86
- def construct_generator(self):
87
- self.STATE.gen = activation(function=self.SETTINGS.function)
84
+ return ActivationTransformer(ActivationSettings(function=function))
@@ -0,0 +1,231 @@
1
+ import numpy as np
2
+ import numpy.typing as npt
3
+ import scipy.signal
4
+ import ezmsg.core as ez
5
+ from ezmsg.util.messages.axisarray import AxisArray, CoordinateAxis
6
+ from ezmsg.util.messages.util import replace
7
+
8
+ from .base import processor_state, BaseStatefulTransformer
9
+
10
+
11
+ class AdaptiveLatticeNotchFilterSettings(ez.Settings):
12
+ """Settings for the Adaptive Lattice Notch Filter."""
13
+
14
+ gamma: float = 0.995
15
+ """Pole-zero contraction factor"""
16
+ mu: float = 0.99
17
+ """Smoothing factor"""
18
+ eta: float = 0.99
19
+ """Forgetting factor"""
20
+ axis: str = "time"
21
+ """Axis to apply filter to"""
22
+ init_notch_freq: float | None = None
23
+ """Initial notch frequency. Should be < nyquist."""
24
+ chunkwise: bool = False
25
+ """Speed up processing by updating the target freq once per chunk only."""
26
+
27
+
28
+ @processor_state
29
+ class AdaptiveLatticeNotchFilterState:
30
+ """State for the Adaptive Lattice Notch Filter."""
31
+
32
+ s_history: npt.NDArray | None = None
33
+ """Historical `s` values for the adaptive filter."""
34
+
35
+ p: npt.NDArray | None = None
36
+ """Accumulated product for reflection coefficient update"""
37
+
38
+ q: npt.NDArray | None = None
39
+ """Accumulated product for reflection coefficient update"""
40
+
41
+ k1: npt.NDArray | None = None
42
+ """Reflection coefficient"""
43
+
44
+ freq_template: CoordinateAxis | None = None
45
+ """Template for the frequency axis on the output"""
46
+
47
+ zi: npt.NDArray | None = None
48
+ """Initial conditions for the filter, updated after every chunk"""
49
+
50
+
51
+ class AdaptiveLatticeNotchFilterTransformer(
52
+ BaseStatefulTransformer[
53
+ AdaptiveLatticeNotchFilterSettings,
54
+ AxisArray,
55
+ AxisArray,
56
+ AdaptiveLatticeNotchFilterState,
57
+ ]
58
+ ):
59
+ """
60
+ Adaptive Lattice Notch Filter implementation as a stateful transformer.
61
+
62
+ https://biomedical-engineering-online.biomedcentral.com/articles/10.1186/1475-925X-13-170
63
+
64
+ The filter automatically tracks and removes frequency components from the input signal.
65
+ It outputs the estimated frequency (in Hz) and the filtered sample.
66
+ """
67
+
68
+ def _hash_message(self, message: AxisArray) -> int:
69
+ ax_idx = message.get_axis_idx(self.settings.axis)
70
+ sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
71
+ return hash((message.key, message.axes[self.settings.axis].gain, sample_shape))
72
+
73
+ def _reset_state(self, message: AxisArray) -> None:
74
+ ax_idx = message.get_axis_idx(self.settings.axis)
75
+ sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
76
+
77
+ fs = 1 / message.axes[self.settings.axis].gain
78
+ init_f = (
79
+ self.settings.init_notch_freq
80
+ if self.settings.init_notch_freq is not None
81
+ else 0.07178314656435313 * fs
82
+ )
83
+ init_omega = init_f * (2 * np.pi) / fs
84
+ init_k1 = -np.cos(init_omega)
85
+
86
+ """Reset filter state to initial values."""
87
+ self._state = AdaptiveLatticeNotchFilterState()
88
+ self._state.s_history = np.zeros((2,) + sample_shape, dtype=float)
89
+ self._state.p = np.zeros(sample_shape, dtype=float)
90
+ self._state.q = np.zeros(sample_shape, dtype=float)
91
+ self._state.k1 = init_k1 + np.zeros(sample_shape, dtype=float)
92
+ self._state.freq_template = CoordinateAxis(
93
+ data=np.zeros((0,) + sample_shape, dtype=float),
94
+ dims=[self.settings.axis]
95
+ + message.dims[:ax_idx]
96
+ + message.dims[ax_idx + 1 :],
97
+ unit="Hz",
98
+ )
99
+
100
+ # Initialize the initial conditions for the filter
101
+ self._state.zi = np.zeros((2, np.prod(sample_shape)), dtype=float)
102
+ # Note: we could calculate it properly, but as long as we are initializing s_history with zeros,
103
+ # it will always be zero.
104
+ # a = [1, init_k1 * (1 + self.settings.gamma), self.settings.gamma]
105
+ # b = [1]
106
+ # s = np.reshape(self._state.s_history, (2, -1))
107
+ # for feat_ix in range(np.prod(sample_shape)):
108
+ # self._state.zi[:, feat_ix] = scipy.signal.lfiltic(b, a, s[::-1, feat_ix], x=None)
109
+
110
+ def _process(self, message: AxisArray) -> AxisArray:
111
+ x_data = message.data
112
+ ax_idx = message.get_axis_idx(self.settings.axis)
113
+
114
+ # TODO: Time should be moved to -1th axis, not the 0th axis
115
+ if message.dims[0] != self.settings.axis:
116
+ x_data = np.moveaxis(x_data, ax_idx, 0)
117
+
118
+ # Access settings once
119
+ gamma = self.settings.gamma
120
+ eta = self.settings.eta
121
+ mu = self.settings.mu
122
+ fs = 1 / message.axes[self.settings.axis].gain
123
+
124
+ # Pre-compute constants
125
+ one_minus_eta = 1 - eta
126
+ one_minus_mu = 1 - mu
127
+ gamma_plus_1 = 1 + gamma
128
+ omega_scale = fs / (2 * np.pi)
129
+
130
+ # For the lattice filter with constant k1:
131
+ # s_n = x_n - k1*(1+gamma)*s_n_1 - gamma*s_n_2
132
+ # This is equivalent to an IIR filter with b=1, a=[1, k1*(1+gamma), gamma]
133
+
134
+ # For the output filter:
135
+ # y_n = s_n + 2*k1*s_n_1 + s_n_2
136
+ # We can treat this as a direct-form FIR filter applied to s_out
137
+
138
+ if self.settings.chunkwise:
139
+ # Process each chunk using current filter parameters
140
+ # Reshape input and prepare output arrays
141
+ _s = self._state.s_history.reshape((2, -1))
142
+ _x = x_data.reshape((x_data.shape[0], -1))
143
+ s_n = np.zeros_like(_x)
144
+ y_out = np.zeros_like(_x)
145
+
146
+ # Apply static filter for each feature dimension
147
+ for ix, k in enumerate(self._state.k1.flatten()):
148
+ # Filter to get s_n (notch filter state)
149
+ a_s = [1, k * gamma_plus_1, gamma]
150
+ s_n[:, ix], self._state.zi[:, ix] = scipy.signal.lfilter(
151
+ [1], a_s, _x[:, ix], zi=self._state.zi[:, ix]
152
+ )
153
+
154
+ # Apply output filter to get y_out
155
+ b_y = [1, 2 * k, 1]
156
+ y_out[:, ix] = scipy.signal.lfilter(b_y, [1], s_n[:, ix])
157
+
158
+ # Update filter parameters using final values from the chunk
159
+ s_n_reshaped = s_n.reshape((s_n.shape[0],) + x_data.shape[1:])
160
+ s_final = s_n_reshaped[-1] # Current s_n
161
+ s_final_1 = s_n_reshaped[-2] # s_n_1
162
+ s_final_2 = (
163
+ s_n_reshaped[-3] if len(s_n_reshaped) > 2 else self._state.s_history[0]
164
+ ) # s_n_2
165
+
166
+ # Update p and q using final values
167
+ self._state.p = eta * self._state.p + one_minus_eta * (
168
+ s_final_1 * (s_final + s_final_2)
169
+ )
170
+ self._state.q = eta * self._state.q + one_minus_eta * (
171
+ 2 * (s_final_1 * s_final_1)
172
+ )
173
+
174
+ # Update reflection coefficient
175
+ new_k1 = -self._state.p / (self._state.q + 1e-8) # Avoid division by zero
176
+ new_k1 = np.clip(new_k1, -1, 1) # Clip to prevent instability
177
+ self._state.k1 = mu * self._state.k1 + one_minus_mu * new_k1 # Smoothed
178
+
179
+ # Calculate frequency from updated k1 value
180
+ omega_n = np.arccos(-self._state.k1)
181
+ freq = omega_n * omega_scale
182
+ freq_out = np.full_like(x_data.reshape(x_data.shape), freq)
183
+
184
+ # Update s_history for next chunk
185
+ self._state.s_history = s_n_reshaped[-2:].reshape((2,) + x_data.shape[1:])
186
+
187
+ # Reshape y_out back to original dimensions
188
+ y_out = y_out.reshape(x_data.shape)
189
+
190
+ else:
191
+ # Perform filtering, sample-by-sample
192
+ y_out = np.zeros_like(x_data)
193
+ freq_out = np.zeros_like(x_data)
194
+ for sample_ix, x_n in enumerate(x_data):
195
+ s_n_1 = self._state.s_history[-1]
196
+ s_n_2 = self._state.s_history[-2]
197
+
198
+ s_n = x_n - self._state.k1 * gamma_plus_1 * s_n_1 - gamma * s_n_2
199
+ y_out[sample_ix] = s_n + 2 * self._state.k1 * s_n_1 + s_n_2
200
+
201
+ # Update filter parameters
202
+ self._state.p = eta * self._state.p + one_minus_eta * (
203
+ s_n_1 * (s_n + s_n_2)
204
+ )
205
+ self._state.q = eta * self._state.q + one_minus_eta * (
206
+ 2 * (s_n_1 * s_n_1)
207
+ )
208
+
209
+ # Update reflection coefficient
210
+ new_k1 = -self._state.p / (
211
+ self._state.q + 1e-8
212
+ ) # Avoid division by zero
213
+ new_k1 = np.clip(new_k1, -1, 1) # Clip to prevent instability
214
+ self._state.k1 = mu * self._state.k1 + one_minus_mu * new_k1 # Smoothed
215
+
216
+ # Compute normalized angular frequency using equation 13 from the paper
217
+ omega_n = np.arccos(-self._state.k1)
218
+ freq_out[sample_ix] = omega_n * omega_scale # As Hz
219
+
220
+ # Update for next iteration
221
+ self._state.s_history[-2] = s_n_1
222
+ self._state.s_history[-1] = s_n
223
+
224
+ return replace(
225
+ message,
226
+ data=y_out,
227
+ axes={
228
+ **message.axes,
229
+ "freq": replace(self._state.freq_template, data=freq_out),
230
+ },
231
+ )