ezmsg-sigproc 1.2.2__py3-none-any.whl → 2.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. ezmsg/sigproc/__init__.py +1 -1
  2. ezmsg/sigproc/__version__.py +34 -1
  3. ezmsg/sigproc/activation.py +78 -0
  4. ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
  5. ezmsg/sigproc/affinetransform.py +235 -0
  6. ezmsg/sigproc/aggregate.py +276 -0
  7. ezmsg/sigproc/bandpower.py +80 -0
  8. ezmsg/sigproc/base.py +149 -0
  9. ezmsg/sigproc/butterworthfilter.py +129 -39
  10. ezmsg/sigproc/butterworthzerophase.py +305 -0
  11. ezmsg/sigproc/cheby.py +125 -0
  12. ezmsg/sigproc/combfilter.py +160 -0
  13. ezmsg/sigproc/coordinatespaces.py +159 -0
  14. ezmsg/sigproc/decimate.py +46 -18
  15. ezmsg/sigproc/denormalize.py +78 -0
  16. ezmsg/sigproc/detrend.py +28 -0
  17. ezmsg/sigproc/diff.py +82 -0
  18. ezmsg/sigproc/downsample.py +97 -49
  19. ezmsg/sigproc/ewma.py +217 -0
  20. ezmsg/sigproc/ewmfilter.py +45 -19
  21. ezmsg/sigproc/extract_axis.py +39 -0
  22. ezmsg/sigproc/fbcca.py +307 -0
  23. ezmsg/sigproc/filter.py +282 -117
  24. ezmsg/sigproc/filterbank.py +292 -0
  25. ezmsg/sigproc/filterbankdesign.py +129 -0
  26. ezmsg/sigproc/fir_hilbert.py +336 -0
  27. ezmsg/sigproc/fir_pmc.py +209 -0
  28. ezmsg/sigproc/firfilter.py +117 -0
  29. ezmsg/sigproc/gaussiansmoothing.py +89 -0
  30. ezmsg/sigproc/kaiser.py +106 -0
  31. ezmsg/sigproc/linear.py +120 -0
  32. ezmsg/sigproc/math/__init__.py +0 -0
  33. ezmsg/sigproc/math/abs.py +35 -0
  34. ezmsg/sigproc/math/add.py +120 -0
  35. ezmsg/sigproc/math/clip.py +48 -0
  36. ezmsg/sigproc/math/difference.py +143 -0
  37. ezmsg/sigproc/math/invert.py +28 -0
  38. ezmsg/sigproc/math/log.py +57 -0
  39. ezmsg/sigproc/math/scale.py +39 -0
  40. ezmsg/sigproc/messages.py +3 -6
  41. ezmsg/sigproc/quantize.py +68 -0
  42. ezmsg/sigproc/resample.py +278 -0
  43. ezmsg/sigproc/rollingscaler.py +232 -0
  44. ezmsg/sigproc/sampler.py +232 -241
  45. ezmsg/sigproc/scaler.py +165 -0
  46. ezmsg/sigproc/signalinjector.py +70 -0
  47. ezmsg/sigproc/slicer.py +138 -0
  48. ezmsg/sigproc/spectral.py +6 -132
  49. ezmsg/sigproc/spectrogram.py +90 -0
  50. ezmsg/sigproc/spectrum.py +277 -0
  51. ezmsg/sigproc/transpose.py +134 -0
  52. ezmsg/sigproc/util/__init__.py +0 -0
  53. ezmsg/sigproc/util/asio.py +25 -0
  54. ezmsg/sigproc/util/axisarray_buffer.py +365 -0
  55. ezmsg/sigproc/util/buffer.py +449 -0
  56. ezmsg/sigproc/util/message.py +17 -0
  57. ezmsg/sigproc/util/profile.py +23 -0
  58. ezmsg/sigproc/util/sparse.py +115 -0
  59. ezmsg/sigproc/util/typeresolution.py +17 -0
  60. ezmsg/sigproc/wavelets.py +187 -0
  61. ezmsg/sigproc/window.py +301 -117
  62. ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
  63. ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
  64. {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -2
  65. ezmsg/sigproc/synth.py +0 -411
  66. ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
  67. ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
  68. ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
  69. /ezmsg_sigproc-1.2.2.dist-info/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
@@ -0,0 +1,232 @@
1
+ from collections import deque
2
+
3
+ import ezmsg.core as ez
4
+ import numpy as np
5
+ import numpy.typing as npt
6
+ from ezmsg.baseproc import (
7
+ BaseAdaptiveTransformer,
8
+ BaseAdaptiveTransformerUnit,
9
+ processor_state,
10
+ )
11
+ from ezmsg.util.messages.axisarray import AxisArray
12
+ from ezmsg.util.messages.util import replace
13
+
14
+ from ezmsg.sigproc.sampler import SampleMessage
15
+
16
+
17
class RollingScalerSettings(ez.Settings):
    """Configuration for rolling z-score scaling (see ``RollingScalerProcessor``)."""

    axis: str = "time"
    """
    Name of the axis along which samples are arranged.
    """

    k_samples: int | None = 20
    """
    Rolling window size, expressed as a number of samples.
    """

    window_size: float | None = None
    """
    Rolling window size in seconds. When set, this takes precedence over
    `k_samples`.
    `update_with_signal` likely should be True if using this option.
    """

    update_with_signal: bool = False
    """
    When True, the incoming process stream itself is used to update the
    rolling statistics.
    """

    min_samples: int = 1
    """
    Fewest samples that must be accumulated before statistics are computed.
    Applies when `window_size` is not set.
    """

    min_seconds: float = 1.0
    """
    Shortest duration in seconds that must be accumulated before statistics
    are computed. Applies when `window_size` is set.
    """

    artifact_z_thresh: float | None = None
    """
    Z-score threshold for artifact rejection. When set, any sample whose
    z-score exceeds this value on any channel is excluded from updating the
    rolling statistics.
    """

    clip: float | None = 10.0
    """
    When set, output values are clipped to the range [-clip, clip].
    """
63
+
64
+
65
@processor_state
class RollingScalerState:
    # Running per-channel mean over the current window.
    mean: npt.NDArray | None = None
    # Total number of samples currently contributing to the statistics.
    N: int = 0
    # Running per-channel sum of squared deviations from the mean ("M2").
    M2: npt.NDArray | None = None
    # Per-batch (n, mean, M2) summaries, oldest first; bounded by k_samples.
    samples: deque | None = None
    # Resolved window length after converting window_size via the axis gain;
    # None means the window is unbounded.
    k_samples: int | None = None
    # Resolved minimum sample count before scaling begins.
    min_samples: int | None = None
73
+
74
+
75
class RollingScalerProcessor(BaseAdaptiveTransformer[RollingScalerSettings, AxisArray, AxisArray, RollingScalerState]):
    """
    Processor for rolling z-score normalization of input `AxisArray` messages.

    The processor maintains rolling statistics (mean and variance) over the last `k_samples`
    samples received via the `partial_fit()` method. When processing an `AxisArray` message,
    it normalizes the data using the current rolling statistics.

    The input `AxisArray` messages are expected to have shape `(time, ch)`, where `ch` is the
    channel axis. The processor computes the z-score for each channel independently.

    Note: You should consider instead using the AdaptiveStandardScalerTransformer which
    is computationally more efficient and uses less memory. This RollingScalerProcessor
    is primarily provided to reproduce processing in the literature.

    Settings:
    ---------
    k_samples: int
        Number of previous samples to use for rolling statistics.

    Example:
    -----------------------------
    ```python
    processor = RollingScalerProcessor(
        settings=RollingScalerSettings(
            k_samples=20  # Number of previous samples to use for rolling statistics
        )
    )
    ```
    """

    def _hash_message(self, message: AxisArray) -> int:
        # Identify the stream by its key, its per-sample shape (all dims except
        # the scaling axis), and the axis gain. A change in any of these should
        # trigger a state reset via the adaptive-transformer base class.
        axis = message.dims[0] if self.settings.axis is None else self.settings.axis
        # Axes without a `gain` attribute fall back to 1 so the hash stays stable.
        gain = message.axes[axis].gain if hasattr(message.axes[axis], "gain") else 1
        axis_idx = message.get_axis_idx(axis)
        samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
        return hash((message.key, samp_shape, gain))

    def _reset_state(self, message: AxisArray) -> None:
        # Re-initialize the running statistics for a (new) stream.
        # NOTE(review): channel count is taken from the *last* dim, i.e. this
        # assumes (time, ch) layout regardless of settings.axis — confirm.
        ch = message.data.shape[-1]
        self._state.mean = np.zeros(ch)
        self._state.N = 0
        self._state.M2 = np.zeros(ch)
        # window_size (seconds) overrides k_samples; convert seconds to a
        # sample count using the axis gain (seconds per sample).
        self._state.k_samples = (
            int(np.ceil(self.settings.window_size / message.axes[self.settings.axis].gain))
            if self.settings.window_size is not None
            else self.settings.k_samples
        )
        if self._state.k_samples is not None and self._state.k_samples < 1:
            ez.logger.warning("window_size smaller than sample gain; setting k_samples to 1.")
            self._state.k_samples = 1
        elif self._state.k_samples is None:
            # deque(maxlen=None) never evicts, so memory grows without bound.
            ez.logger.warning("k_samples is None; z-score accumulation will be unbounded.")
        self._state.samples = deque(maxlen=self._state.k_samples)
        # min_samples follows the same seconds-vs-count duality as the window:
        # min_seconds applies only when window_size (seconds) is in use.
        self._state.min_samples = (
            int(np.ceil(self.settings.min_seconds / message.axes[self.settings.axis].gain))
            if self.settings.window_size is not None
            else self.settings.min_samples
        )
        if self._state.k_samples is not None and self._state.min_samples > self._state.k_samples:
            ez.logger.warning("min_samples is greater than k_samples; adjusting min_samples to k_samples.")
            self._state.min_samples = self._state.k_samples

    def _add_batch_stats(self, x: npt.NDArray) -> None:
        # Fold one batch of samples (rows of x) into the rolling mean/M2 totals.
        # Per-batch summaries are retained in a deque so the oldest batch's
        # contribution can later be subtracted back out of the running totals.
        # NOTE(review): the deque holds one entry per *batch*, so k_samples
        # bounds the number of batches, not individual samples — confirm
        # against how partial_fit is fed.
        x = np.asarray(x, dtype=np.float64)
        n_b = x.shape[0]
        mean_b = np.mean(x, axis=0)
        M2_b = np.sum((x - mean_b) ** 2, axis=0)

        # Window full: evict the oldest batch before merging the new one.
        if self._state.k_samples is not None and len(self._state.samples) == self._state.k_samples:
            n_old, mean_old, M2_old = self._state.samples.popleft()
            N_T = self._state.N
            N_new = N_T - n_old

            if N_new <= 0:
                # Removing the old batch empties the window; restart from zero.
                self._state.N = 0
                self._state.mean = np.zeros_like(self._state.mean)
                self._state.M2 = np.zeros_like(self._state.M2)
            else:
                # Inverse of the parallel-variance (Chan et al.) merge:
                # subtract the old batch's mean and M2 from the totals.
                delta = mean_old - self._state.mean
                self._state.N = N_new
                self._state.mean = (N_T * self._state.mean - n_old * mean_old) / N_new
                self._state.M2 = self._state.M2 - M2_old - (delta * delta) * (N_T * n_old / N_new)

        # Parallel-variance (Chan et al.) merge of the running totals with the
        # new batch's (n_b, mean_b, M2_b).
        N_A = self._state.N
        N = N_A + n_b
        delta = mean_b - self._state.mean
        self._state.mean = self._state.mean + delta * (n_b / N)
        self._state.M2 = self._state.M2 + M2_b + (delta * delta) * (N_A * n_b / N)
        self._state.N = N

        self._state.samples.append((n_b, mean_b, M2_b))

    def partial_fit(self, message: SampleMessage) -> None:
        # Update rolling statistics from a sample/trial message without
        # producing any scaled output.
        x = message.sample.data
        self._add_batch_stats(x)

    def _process(self, message: AxisArray) -> AxisArray:
        # Warm-up: until min_samples have accumulated, pass the message through
        # unscaled (optionally still feeding it into the statistics).
        if self._state.N == 0 or self._state.N < self._state.min_samples:
            if self.settings.update_with_signal:
                x = message.data
                if self.settings.artifact_z_thresh is not None and self._state.N > 0:
                    varis = self._state.M2 / self._state.N
                    # Floor std to avoid division by ~0 on constant channels.
                    std = np.maximum(np.sqrt(varis), 1e-8)
                    z = np.abs((x - self._state.mean) / std)
                    # Drop any row where *any* channel exceeds the threshold.
                    # NOTE(review): axis=1 assumes 2D (time, ch) data — confirm.
                    mask = np.any(z > self.settings.artifact_z_thresh, axis=1)
                    x = x[~mask]
                if x.size > 0:
                    self._add_batch_stats(x)
            return message

        varis = self._state.M2 / self._state.N
        std = np.maximum(np.sqrt(varis), 1e-8)
        with np.errstate(divide="ignore", invalid="ignore"):
            result = (message.data - self._state.mean) / std
            # Any residual non-finite z-scores are coerced to 0.
            result = np.nan_to_num(result, nan=0.0, posinf=0.0, neginf=0.0)
        if self.settings.clip is not None:
            result = np.clip(result, -self.settings.clip, self.settings.clip)

        # Statistics are updated *after* scaling, using the pre-update std,
        # so the current output reflects the window before this message.
        if self.settings.update_with_signal:
            x = message.data
            if self.settings.artifact_z_thresh is not None:
                z_scores = np.abs((x - self._state.mean) / std)
                mask = np.any(z_scores > self.settings.artifact_z_thresh, axis=1)
                x = x[~mask]
            if x.size > 0:
                self._add_batch_stats(x)

        return replace(message, data=result)
204
+
205
+
206
class RollingScalerUnit(
    BaseAdaptiveTransformerUnit[
        RollingScalerSettings,
        AxisArray,
        AxisArray,
        RollingScalerProcessor,
    ]
):
    """
    ezmsg unit wrapping :obj:`RollingScalerProcessor`.

    Applies rolling z-score normalization to incoming `AxisArray` messages:
    rolling mean and variance are maintained over the most recent `k_samples`
    samples, and each incoming message is normalized against those statistics.

    Example:
    -----------------------------
    ```python
    unit = RollingScalerUnit(
        settings=RollingScalerSettings(
            k_samples=20  # Number of previous samples to use for rolling statistics
        )
    )
    ```
    """

    SETTINGS = RollingScalerSettings