ezmsg-sigproc 1.7.0__py3-none-any.whl → 2.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. ezmsg/sigproc/__version__.py +22 -4
  2. ezmsg/sigproc/activation.py +31 -40
  3. ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
  4. ezmsg/sigproc/affinetransform.py +171 -169
  5. ezmsg/sigproc/aggregate.py +190 -97
  6. ezmsg/sigproc/bandpower.py +60 -55
  7. ezmsg/sigproc/base.py +143 -33
  8. ezmsg/sigproc/butterworthfilter.py +34 -38
  9. ezmsg/sigproc/butterworthzerophase.py +305 -0
  10. ezmsg/sigproc/cheby.py +23 -17
  11. ezmsg/sigproc/combfilter.py +160 -0
  12. ezmsg/sigproc/coordinatespaces.py +159 -0
  13. ezmsg/sigproc/decimate.py +15 -10
  14. ezmsg/sigproc/denormalize.py +78 -0
  15. ezmsg/sigproc/detrend.py +28 -0
  16. ezmsg/sigproc/diff.py +82 -0
  17. ezmsg/sigproc/downsample.py +72 -81
  18. ezmsg/sigproc/ewma.py +217 -0
  19. ezmsg/sigproc/ewmfilter.py +1 -1
  20. ezmsg/sigproc/extract_axis.py +39 -0
  21. ezmsg/sigproc/fbcca.py +307 -0
  22. ezmsg/sigproc/filter.py +254 -148
  23. ezmsg/sigproc/filterbank.py +226 -214
  24. ezmsg/sigproc/filterbankdesign.py +129 -0
  25. ezmsg/sigproc/fir_hilbert.py +336 -0
  26. ezmsg/sigproc/fir_pmc.py +209 -0
  27. ezmsg/sigproc/firfilter.py +117 -0
  28. ezmsg/sigproc/gaussiansmoothing.py +89 -0
  29. ezmsg/sigproc/kaiser.py +106 -0
  30. ezmsg/sigproc/linear.py +120 -0
  31. ezmsg/sigproc/math/abs.py +23 -22
  32. ezmsg/sigproc/math/add.py +120 -0
  33. ezmsg/sigproc/math/clip.py +33 -25
  34. ezmsg/sigproc/math/difference.py +117 -43
  35. ezmsg/sigproc/math/invert.py +18 -25
  36. ezmsg/sigproc/math/log.py +38 -33
  37. ezmsg/sigproc/math/scale.py +24 -25
  38. ezmsg/sigproc/messages.py +1 -2
  39. ezmsg/sigproc/quantize.py +68 -0
  40. ezmsg/sigproc/resample.py +278 -0
  41. ezmsg/sigproc/rollingscaler.py +232 -0
  42. ezmsg/sigproc/sampler.py +209 -254
  43. ezmsg/sigproc/scaler.py +93 -218
  44. ezmsg/sigproc/signalinjector.py +44 -43
  45. ezmsg/sigproc/slicer.py +74 -102
  46. ezmsg/sigproc/spectral.py +3 -3
  47. ezmsg/sigproc/spectrogram.py +70 -70
  48. ezmsg/sigproc/spectrum.py +187 -173
  49. ezmsg/sigproc/transpose.py +134 -0
  50. ezmsg/sigproc/util/__init__.py +0 -0
  51. ezmsg/sigproc/util/asio.py +25 -0
  52. ezmsg/sigproc/util/axisarray_buffer.py +365 -0
  53. ezmsg/sigproc/util/buffer.py +449 -0
  54. ezmsg/sigproc/util/message.py +17 -0
  55. ezmsg/sigproc/util/profile.py +23 -0
  56. ezmsg/sigproc/util/sparse.py +115 -0
  57. ezmsg/sigproc/util/typeresolution.py +17 -0
  58. ezmsg/sigproc/wavelets.py +147 -154
  59. ezmsg/sigproc/window.py +248 -210
  60. ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
  61. ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
  62. {ezmsg_sigproc-1.7.0.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -1
  63. ezmsg/sigproc/synth.py +0 -621
  64. ezmsg_sigproc-1.7.0.dist-info/METADATA +0 -58
  65. ezmsg_sigproc-1.7.0.dist-info/RECORD +0 -36
  66. /ezmsg_sigproc-1.7.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
@@ -1,16 +1,34 @@
# file generated by setuptools-scm
# don't change, don't track in version control
# NOTE(review): this module is generated; change the version through the project's VCS tags /
# setuptools-scm configuration rather than by editing this file.

# Public names exported by this module.
__all__ = [
    "__version__",
    "__version_tuple__",
    "version",
    "version_tuple",
    "__commit_id__",
    "commit_id",
]

# False at runtime, so the `typing` imports and precise aliases in the first branch are never
# executed on import; at runtime the aliases are plain `object`.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple
    from typing import Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
    COMMIT_ID = Union[str, None]
else:
    VERSION_TUPLE = object
    COMMIT_ID = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID

__version__ = version = '2.10.0'
__version_tuple__ = version_tuple = (2, 10, 0)

# No commit id is embedded in this build.
__commit_id__ = commit_id = None
@@ -1,14 +1,10 @@
1
- import typing
2
-
3
- import numpy as np
4
- import scipy.special
5
1
  import ezmsg.core as ez
2
+ import scipy.special
3
+ from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
6
4
  from ezmsg.util.messages.axisarray import AxisArray
7
5
  from ezmsg.util.messages.util import replace
8
- from ezmsg.util.generator import consumer
9
6
 
10
7
  from .spectral import OptionsEnum
11
- from .base import GenAxisArray
12
8
 
13
9
 
14
10
  class ActivationFunction(OptionsEnum):
@@ -39,10 +35,35 @@ ACTIVATIONS = {
39
35
  }
40
36
 
41
37
 
class ActivationSettings(ez.Settings):
    """
    Settings for :obj:`ActivationTransformer`.

    Attributes:
        function: An enum value from ActivationFunction or a string naming the activation function.
            Possible values are: SIGMOID, EXPIT, LOGIT, LOGEXPIT, "sigmoid", "expit", "logit", "log_expit".
            SIGMOID and EXPIT are equivalent. See :obj:`scipy.special.expit` for more details.
            Strings are matched case-insensitively.
    """

    function: str | ActivationFunction = ActivationFunction.NONE
+
44
+
class ActivationTransformer(BaseTransformer[ActivationSettings, AxisArray, AxisArray]):
    """Apply the element-wise activation function selected by ``settings.function`` to message data."""

    def _process(self, message: AxisArray) -> AxisArray:
        """
        Return a copy of ``message`` whose data has been passed through the activation function.

        Raises:
            ValueError: If ``settings.function`` is a string that is not one of
                ``ActivationFunction.options()`` (compared case-insensitively).
        """
        if isinstance(self.settings.function, ActivationFunction):
            func = ACTIVATIONS[self.settings.function]
        else:
            # str type handling: the position of the name in options() selects the
            # ACTIVATIONS entry at the same position.
            function = self.settings.function.lower()
            options = ActivationFunction.options()
            if function not in options:
                # List the accepted strings; ACTIVATIONS' keys are enum members, not names a
                # caller could pass here.
                raise ValueError(f"Unrecognized activation function {function}. Must be one of {options}")
            func = ACTIVATIONS[list(ACTIVATIONS.keys())[options.index(function)]]

        return replace(message, data=func(message.data))
+
59
+
class Activation(BaseTransformerUnit[ActivationSettings, AxisArray, AxisArray, ActivationTransformer]):
    """Unit that applies :obj:`ActivationTransformer` to incoming messages, configured by :obj:`ActivationSettings`."""

    SETTINGS = ActivationSettings
62
+
63
+
def activation(
    function: str | ActivationFunction,
) -> ActivationTransformer:
    """
    Transform the data with a simple activation function.

    Args:
        function: An enum value from ActivationFunction or a string naming the activation function.
            Possible values are: SIGMOID, EXPIT, LOGIT, LOGEXPIT, "sigmoid", "expit", "logit", "log_expit".
            SIGMOID and EXPIT are equivalent. See :obj:`scipy.special.expit` for more details.

    Returns: :obj:`ActivationTransformer` configured with ``function``.
    """
    settings = ActivationSettings(function=function)
    return ActivationTransformer(settings)
@@ -0,0 +1,212 @@
1
+ import ezmsg.core as ez
2
+ import numpy as np
3
+ import numpy.typing as npt
4
+ import scipy.signal
5
+ from ezmsg.baseproc import BaseStatefulTransformer, processor_state
6
+ from ezmsg.util.messages.axisarray import AxisArray, CoordinateAxis
7
+ from ezmsg.util.messages.util import replace
8
+
9
+
class AdaptiveLatticeNotchFilterSettings(ez.Settings):
    """
    Settings for the Adaptive Lattice Notch Filter.

    Attributes:
        gamma: Pole-zero contraction factor.
        mu: Smoothing factor; weight kept on the previous reflection coefficient on each update.
        eta: Forgetting factor; weight kept on the previous accumulated products on each update.
        axis: Axis to apply filter to.
        init_notch_freq: Initial notch frequency in Hz. Should be < nyquist. When None, a fixed
            fraction (0.07178314656435313) of the sampling rate is used.
        chunkwise: Speed up processing by updating the target freq once per chunk only.
    """

    gamma: float = 0.995
    mu: float = 0.99
    eta: float = 0.99
    axis: str = "time"
    init_notch_freq: float | None = None
    chunkwise: bool = False
25
+
26
+
@processor_state
class AdaptiveLatticeNotchFilterState:
    """
    State for the Adaptive Lattice Notch Filter.

    Every field defaults to ``None`` until the transformer resets its state from the first message.
    """

    # Last two lattice `s` values, oldest first (rows: s_{n-2}, s_{n-1}).
    s_history: npt.NDArray | None = None

    # Exponentially weighted accumulated products; the raw reflection-coefficient estimate is -p / q.
    p: npt.NDArray | None = None
    q: npt.NDArray | None = None

    # Smoothed reflection coefficient; the tracked frequency is arccos(-k1) in rad/sample.
    k1: npt.NDArray | None = None

    # Empty-data template for the "freq" coordinate axis attached to each output.
    freq_template: CoordinateAxis | None = None

    # lfilter conditions for the chunkwise path, one column per flattened feature.
    zi: npt.NDArray | None = None
48
+
49
+
class AdaptiveLatticeNotchFilterTransformer(
    BaseStatefulTransformer[
        AdaptiveLatticeNotchFilterSettings,
        AxisArray,
        AxisArray,
        AdaptiveLatticeNotchFilterState,
    ]
):
    """
    Adaptive Lattice Notch Filter implementation as a stateful transformer.

    https://biomedical-engineering-online.biomedcentral.com/articles/10.1186/1475-925X-13-170

    The filter automatically tracks and removes a frequency component from the input signal.
    The output message's data is the filtered signal (same shape and axis order as the input).
    A ``"freq"`` coordinate axis holds the estimated notch frequency (Hz) for every sample, with
    the filter axis first.

    With reflection coefficient ``k1`` the lattice recursion is::

        s_n = x_n - k1 * (1 + gamma) * s_{n-1} - gamma * s_{n-2}
        y_n = s_n + 2 * k1 * s_{n-1} + s_{n-2}

    ``k1`` is re-estimated from exponentially weighted products of ``s`` (``eta``), clipped to
    [-1, 1] and smoothed (``mu``). The frequency estimate is ``arccos(-k1) * fs / (2 * pi)``.
    """

    def _hash_message(self, message: AxisArray) -> int:
        # Summarizes what the state depends on (stream key, sample period, per-sample shape);
        # presumably compared by the base class to decide when _reset_state must run.
        ax_idx = message.get_axis_idx(self.settings.axis)
        sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
        return hash((message.key, message.axes[self.settings.axis].gain, sample_shape))

    def _reset_state(self, message: AxisArray) -> None:
        """Reset filter state to initial values for the shape and sample rate of ``message``."""
        ax_idx = message.get_axis_idx(self.settings.axis)
        sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]

        fs = 1 / message.axes[self.settings.axis].gain
        init_f = (
            self.settings.init_notch_freq if self.settings.init_notch_freq is not None else 0.07178314656435313 * fs
        )
        init_omega = init_f * (2 * np.pi) / fs
        init_k1 = -np.cos(init_omega)

        self._state = AdaptiveLatticeNotchFilterState()
        # Rows are [s_{n-2}, s_{n-1}]; the most recent value is always last.
        self._state.s_history = np.zeros((2,) + sample_shape, dtype=float)
        self._state.p = np.zeros(sample_shape, dtype=float)
        self._state.q = np.zeros(sample_shape, dtype=float)
        self._state.k1 = init_k1 + np.zeros(sample_shape, dtype=float)
        self._state.freq_template = CoordinateAxis(
            data=np.zeros((0,) + sample_shape, dtype=float),
            dims=[self.settings.axis] + message.dims[:ax_idx] + message.dims[ax_idx + 1 :],
            unit="Hz",
        )

        # lfilter conditions for the chunkwise path (all zero while s_history is zero).
        # np.prod(()) is the float 1.0, which np.zeros rejects as a dimension, so cast to int
        # to support 1-D (single-channel) input.
        self._state.zi = np.zeros((2, int(np.prod(sample_shape))), dtype=float)

    def _process(self, message: AxisArray) -> AxisArray:
        ax_idx = message.get_axis_idx(self.settings.axis)

        # Work with the filter axis first; the output data is moved back before returning.
        # TODO: Time should be moved to -1th axis, not the 0th axis
        x_data = np.moveaxis(message.data, ax_idx, 0) if ax_idx != 0 else message.data

        # Integer input would otherwise truncate the filtered signal and the frequency estimates.
        out_dtype = x_data.dtype if np.issubdtype(x_data.dtype, np.inexact) else np.float64

        fs = 1 / message.axes[self.settings.axis].gain
        omega_scale = fs / (2 * np.pi)  # rad/sample -> Hz

        if self.settings.chunkwise:
            y_out, freq_out = self._filter_chunkwise(x_data, out_dtype, omega_scale)
        else:
            y_out, freq_out = self._filter_samplewise(x_data, out_dtype, omega_scale)

        # Restore the caller's axis order for the data. freq_out keeps the filter axis first,
        # matching the dims of freq_template.
        if ax_idx != 0:
            y_out = np.moveaxis(y_out, 0, ax_idx)

        return replace(
            message,
            data=y_out,
            axes={
                **message.axes,
                "freq": replace(self._state.freq_template, data=freq_out),
            },
        )

    def _filter_chunkwise(
        self, x_data: npt.NDArray, out_dtype: npt.DTypeLike, omega_scale: float
    ) -> tuple[npt.NDArray, npt.NDArray]:
        """
        Filter a chunk with ``k1`` held fixed, then update ``k1`` once from the chunk's last samples.

        Args:
            x_data: Input with the filter axis first, shape ``(n_samples, *sample_shape)``.
            out_dtype: dtype of the returned arrays.
            omega_scale: Factor converting rad/sample to Hz.

        Returns:
            ``(y_out, freq_out)``, both shaped like ``x_data``. ``freq_out`` repeats the single
            post-chunk frequency estimate for every sample.
        """
        gamma = self.settings.gamma
        eta = self.settings.eta
        mu = self.settings.mu
        n_samples = x_data.shape[0]

        if n_samples == 0:
            # Nothing to filter; leave the state untouched.
            return np.zeros(x_data.shape, dtype=out_dtype), np.zeros(x_data.shape, dtype=out_dtype)

        k1 = self._state.k1
        a1 = np.asarray(k1, dtype=float).reshape(-1) * (1 + gamma)

        x_flat = x_data.reshape((n_samples, -1))
        s_dtype = np.result_type(x_flat.dtype, np.float64)
        s_hist_flat = self._state.s_history.reshape((2, -1))  # rows: [s_{n-2}, s_{n-1}]
        s_prev_1, s_prev_2 = s_hist_flat[1], s_hist_flat[0]

        # s_n = x_n - a1 * s_{n-1} - gamma * s_{n-2} is an all-pole IIR with a = [1, a1, gamma].
        # Rebuild its transposed-direct-form-II initial conditions from the stored history with
        # the *current* k1, so the recursion is continuous across chunks even though k1 changed
        # at the end of the previous chunk.
        zi = np.stack([-a1 * s_prev_1 - gamma * s_prev_2, -gamma * s_prev_1]).astype(s_dtype)
        s_flat = np.zeros(x_flat.shape, dtype=s_dtype)
        for ix in range(x_flat.shape[1]):
            s_flat[:, ix], zi[:, ix] = scipy.signal.lfilter([1.0], [1.0, a1[ix], gamma], x_flat[:, ix], zi=zi[:, ix])
        self._state.zi = zi

        # Prepend the two history samples so the output FIR y_n = s_n + 2*k1*s_{n-1} + s_{n-2}
        # and the k1 update below have their look-back even for 1- or 2-sample chunks.
        s_ext = np.concatenate([self._state.s_history, s_flat.reshape(x_data.shape)], axis=0)
        y_out = s_ext[2:] + 2 * k1 * s_ext[1:-1] + s_ext[:-2]

        # Update the adaptation state once, from the final s_n, s_{n-1}, s_{n-2}.
        s_last, s_last_1, s_last_2 = s_ext[-1], s_ext[-2], s_ext[-3]
        self._state.p = eta * self._state.p + (1 - eta) * (s_last_1 * (s_last + s_last_2))
        self._state.q = eta * self._state.q + (1 - eta) * (2 * (s_last_1 * s_last_1))
        new_k1 = np.clip(-self._state.p / (self._state.q + 1e-8), -1, 1)  # 1e-8 avoids division by zero
        self._state.k1 = mu * self._state.k1 + (1 - mu) * new_k1

        freq = np.arccos(-self._state.k1) * omega_scale
        self._state.s_history = s_ext[-2:].copy()

        return y_out.astype(out_dtype, copy=False), np.full(x_data.shape, freq, dtype=out_dtype)

    def _filter_samplewise(
        self, x_data: npt.NDArray, out_dtype: npt.DTypeLike, omega_scale: float
    ) -> tuple[npt.NDArray, npt.NDArray]:
        """
        Filter ``x_data`` one sample at a time, updating ``k1`` after every sample.

        Args:
            x_data: Input with the filter axis first, shape ``(n_samples, *sample_shape)``.
            out_dtype: dtype of the returned arrays.
            omega_scale: Factor converting rad/sample to Hz.

        Returns:
            ``(y_out, freq_out)``, both shaped like ``x_data``.
        """
        gamma = self.settings.gamma
        eta = self.settings.eta
        mu = self.settings.mu
        one_minus_eta = 1 - eta
        one_minus_mu = 1 - mu
        gamma_plus_1 = 1 + gamma

        y_out = np.zeros(x_data.shape, dtype=out_dtype)
        freq_out = np.zeros(x_data.shape, dtype=out_dtype)
        for sample_ix, x_n in enumerate(x_data):
            s_n_1 = self._state.s_history[-1]
            s_n_2 = self._state.s_history[-2]

            s_n = x_n - self._state.k1 * gamma_plus_1 * s_n_1 - gamma * s_n_2
            y_out[sample_ix] = s_n + 2 * self._state.k1 * s_n_1 + s_n_2

            # Update filter parameters
            self._state.p = eta * self._state.p + one_minus_eta * (s_n_1 * (s_n + s_n_2))
            self._state.q = eta * self._state.q + one_minus_eta * (2 * (s_n_1 * s_n_1))

            # Update reflection coefficient
            new_k1 = -self._state.p / (self._state.q + 1e-8)  # Avoid division by zero
            new_k1 = np.clip(new_k1, -1, 1)  # Clip to prevent instability
            self._state.k1 = mu * self._state.k1 + one_minus_mu * new_k1  # Smoothed

            # Normalized angular frequency (equation 13 of the paper), converted to Hz.
            freq_out[sample_ix] = np.arccos(-self._state.k1) * omega_scale

            # Shift history. s_n_1 may be a view into s_history; it is copied into row -2
            # before row -1 is overwritten.
            self._state.s_history[-2] = s_n_1
            self._state.s_history[-1] = s_n

        return y_out, freq_out