ezmsg-sigproc 1.2.2__py3-none-any.whl → 2.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. ezmsg/sigproc/__init__.py +1 -1
  2. ezmsg/sigproc/__version__.py +34 -1
  3. ezmsg/sigproc/activation.py +78 -0
  4. ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
  5. ezmsg/sigproc/affinetransform.py +235 -0
  6. ezmsg/sigproc/aggregate.py +276 -0
  7. ezmsg/sigproc/bandpower.py +80 -0
  8. ezmsg/sigproc/base.py +149 -0
  9. ezmsg/sigproc/butterworthfilter.py +129 -39
  10. ezmsg/sigproc/butterworthzerophase.py +305 -0
  11. ezmsg/sigproc/cheby.py +125 -0
  12. ezmsg/sigproc/combfilter.py +160 -0
  13. ezmsg/sigproc/coordinatespaces.py +159 -0
  14. ezmsg/sigproc/decimate.py +46 -18
  15. ezmsg/sigproc/denormalize.py +78 -0
  16. ezmsg/sigproc/detrend.py +28 -0
  17. ezmsg/sigproc/diff.py +82 -0
  18. ezmsg/sigproc/downsample.py +97 -49
  19. ezmsg/sigproc/ewma.py +217 -0
  20. ezmsg/sigproc/ewmfilter.py +45 -19
  21. ezmsg/sigproc/extract_axis.py +39 -0
  22. ezmsg/sigproc/fbcca.py +307 -0
  23. ezmsg/sigproc/filter.py +282 -117
  24. ezmsg/sigproc/filterbank.py +292 -0
  25. ezmsg/sigproc/filterbankdesign.py +129 -0
  26. ezmsg/sigproc/fir_hilbert.py +336 -0
  27. ezmsg/sigproc/fir_pmc.py +209 -0
  28. ezmsg/sigproc/firfilter.py +117 -0
  29. ezmsg/sigproc/gaussiansmoothing.py +89 -0
  30. ezmsg/sigproc/kaiser.py +106 -0
  31. ezmsg/sigproc/linear.py +120 -0
  32. ezmsg/sigproc/math/__init__.py +0 -0
  33. ezmsg/sigproc/math/abs.py +35 -0
  34. ezmsg/sigproc/math/add.py +120 -0
  35. ezmsg/sigproc/math/clip.py +48 -0
  36. ezmsg/sigproc/math/difference.py +143 -0
  37. ezmsg/sigproc/math/invert.py +28 -0
  38. ezmsg/sigproc/math/log.py +57 -0
  39. ezmsg/sigproc/math/scale.py +39 -0
  40. ezmsg/sigproc/messages.py +3 -6
  41. ezmsg/sigproc/quantize.py +68 -0
  42. ezmsg/sigproc/resample.py +278 -0
  43. ezmsg/sigproc/rollingscaler.py +232 -0
  44. ezmsg/sigproc/sampler.py +232 -241
  45. ezmsg/sigproc/scaler.py +165 -0
  46. ezmsg/sigproc/signalinjector.py +70 -0
  47. ezmsg/sigproc/slicer.py +138 -0
  48. ezmsg/sigproc/spectral.py +6 -132
  49. ezmsg/sigproc/spectrogram.py +90 -0
  50. ezmsg/sigproc/spectrum.py +277 -0
  51. ezmsg/sigproc/transpose.py +134 -0
  52. ezmsg/sigproc/util/__init__.py +0 -0
  53. ezmsg/sigproc/util/asio.py +25 -0
  54. ezmsg/sigproc/util/axisarray_buffer.py +365 -0
  55. ezmsg/sigproc/util/buffer.py +449 -0
  56. ezmsg/sigproc/util/message.py +17 -0
  57. ezmsg/sigproc/util/profile.py +23 -0
  58. ezmsg/sigproc/util/sparse.py +115 -0
  59. ezmsg/sigproc/util/typeresolution.py +17 -0
  60. ezmsg/sigproc/wavelets.py +187 -0
  61. ezmsg/sigproc/window.py +301 -117
  62. ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
  63. ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
  64. {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -2
  65. ezmsg/sigproc/synth.py +0 -411
  66. ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
  67. ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
  68. ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
  69. /ezmsg_sigproc-1.2.2.dist-info/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/__init__.py CHANGED
@@ -1 +1 @@
1
- from .__version__ import __version__
1
+ from .__version__ import __version__ as __version__
@@ -1 +1,34 @@
1
- __version__ = "1.2.2"
1
+ # file generated by setuptools-scm
2
+ # don't change, don't track in version control
3
+
4
+ __all__ = [
5
+ "__version__",
6
+ "__version_tuple__",
7
+ "version",
8
+ "version_tuple",
9
+ "__commit_id__",
10
+ "commit_id",
11
+ ]
12
+
13
+ TYPE_CHECKING = False
14
+ if TYPE_CHECKING:
15
+ from typing import Tuple
16
+ from typing import Union
17
+
18
+ VERSION_TUPLE = Tuple[Union[int, str], ...]
19
+ COMMIT_ID = Union[str, None]
20
+ else:
21
+ VERSION_TUPLE = object
22
+ COMMIT_ID = object
23
+
24
+ version: str
25
+ __version__: str
26
+ __version_tuple__: VERSION_TUPLE
27
+ version_tuple: VERSION_TUPLE
28
+ commit_id: COMMIT_ID
29
+ __commit_id__: COMMIT_ID
30
+
31
+ __version__ = version = '2.10.0'
32
+ __version_tuple__ = version_tuple = (2, 10, 0)
33
+
34
+ __commit_id__ = commit_id = None
@@ -0,0 +1,78 @@
1
+ import ezmsg.core as ez
2
+ import scipy.special
3
+ from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
4
+ from ezmsg.util.messages.axisarray import AxisArray
5
+ from ezmsg.util.messages.util import replace
6
+
7
+ from .spectral import OptionsEnum
8
+
9
+
10
class ActivationFunction(OptionsEnum):
    """Activation (transformation) function applied element-wise to message data.

    SIGMOID and EXPIT are aliases for the same function (:obj:`scipy.special.expit`).
    """

    NONE = "none"
    """None."""

    SIGMOID = "sigmoid"
    """:obj:`scipy.special.expit`"""

    EXPIT = "expit"
    """:obj:`scipy.special.expit`"""

    LOGIT = "logit"
    """:obj:`scipy.special.logit`"""

    LOGEXPIT = "log_expit"
    """:obj:`scipy.special.log_expit`"""
27
+
28
+
29
# Dispatch table mapping each ActivationFunction member to its callable.
# NONE is an identity function; SIGMOID and EXPIT intentionally share expit.
ACTIVATIONS = {
    ActivationFunction.NONE: lambda x: x,
    ActivationFunction.SIGMOID: scipy.special.expit,
    ActivationFunction.EXPIT: scipy.special.expit,
    ActivationFunction.LOGIT: scipy.special.logit,
    ActivationFunction.LOGEXPIT: scipy.special.log_expit,
}
36
+
37
+
38
class ActivationSettings(ez.Settings):
    """Settings for :obj:`Activation` / :obj:`ActivationTransformer`."""

    function: str | ActivationFunction = ActivationFunction.NONE
    """An enum value from ActivationFunction or a string representing the activation function.
    Possible values are: SIGMOID, EXPIT, LOGIT, LOGEXPIT, "sigmoid", "expit", "logit", "log_expit".
    SIGMOID and EXPIT are equivalent. See :obj:`scipy.special.expit` for more details."""
43
+
44
+
45
class ActivationTransformer(BaseTransformer[ActivationSettings, AxisArray, AxisArray]):
    """Applies the configured activation function element-wise to ``message.data``.

    Raises:
        ValueError: if the settings specify a string that is not a recognized
            activation function name.
    """

    def _process(self, message: AxisArray) -> AxisArray:
        if type(self.settings.function) is ActivationFunction:
            func = ACTIVATIONS[self.settings.function]
        else:
            # The setting was provided as a string; resolve it to an enum member.
            function = self.settings.function.lower()
            if function not in ActivationFunction.options():
                # Fixed: previously interpolated ACTIVATIONS.keys() (a dict_keys view of
                # enum members) rather than the valid string options checked above.
                raise ValueError(
                    f"Unrecognized activation function {function}. Must be one of {ActivationFunction.options()}"
                )
            # Map the option string back to its enum member (options() order matches
            # the enum definition order used to build ACTIVATIONS).
            function = list(ACTIVATIONS.keys())[ActivationFunction.options().index(function)]
            func = ACTIVATIONS[function]

        return replace(message, data=func(message.data))
58
+
59
+
60
class Activation(BaseTransformerUnit[ActivationSettings, AxisArray, AxisArray, ActivationTransformer]):
    """Unit wrapper around :obj:`ActivationTransformer`."""

    SETTINGS = ActivationSettings
62
+
63
+
64
def activation(
    function: str | ActivationFunction,
) -> ActivationTransformer:
    """
    Transform the data with a simple activation function.

    Args:
        function: An enum value from ActivationFunction or a string representing the activation function.
            Possible values are: SIGMOID, EXPIT, LOGIT, LOGEXPIT, "sigmoid", "expit", "logit", "log_expit".
            SIGMOID and EXPIT are equivalent. See :obj:`scipy.special.expit` for more details.

    Returns: :obj:`ActivationTransformer`

    """
    settings = ActivationSettings(function=function)
    return ActivationTransformer(settings)
@@ -0,0 +1,212 @@
1
+ import ezmsg.core as ez
2
+ import numpy as np
3
+ import numpy.typing as npt
4
+ import scipy.signal
5
+ from ezmsg.baseproc import BaseStatefulTransformer, processor_state
6
+ from ezmsg.util.messages.axisarray import AxisArray, CoordinateAxis
7
+ from ezmsg.util.messages.util import replace
8
+
9
+
10
class AdaptiveLatticeNotchFilterSettings(ez.Settings):
    """Settings for the Adaptive Lattice Notch Filter.

    See :obj:`AdaptiveLatticeNotchFilterTransformer` for the algorithm and a
    reference to the paper these coefficient names come from.
    """

    gamma: float = 0.995
    """Pole-zero contraction factor"""
    mu: float = 0.99
    """Smoothing factor"""
    eta: float = 0.99
    """Forgetting factor"""
    axis: str = "time"
    """Axis to apply filter to"""
    init_notch_freq: float | None = None
    """Initial notch frequency. Should be < nyquist."""
    chunkwise: bool = False
    """Speed up processing by updating the target freq once per chunk only."""
25
+
26
+
27
@processor_state
class AdaptiveLatticeNotchFilterState:
    """State for the Adaptive Lattice Notch Filter."""

    # Shape (2,) + sample_shape; row 0 holds s_{n-2}, row 1 holds s_{n-1}.
    s_history: npt.NDArray | None = None
    """Historical `s` values for the adaptive filter."""

    p: npt.NDArray | None = None
    """Accumulated product for reflection coefficient update"""

    q: npt.NDArray | None = None
    """Accumulated product for reflection coefficient update"""

    # k1 = -cos(omega_notch); one value per feature (sample_shape).
    k1: npt.NDArray | None = None
    """Reflection coefficient"""

    freq_template: CoordinateAxis | None = None
    """Template for the frequency axis on the output"""

    # Shape (2, prod(sample_shape)); consumed by scipy.signal.lfilter in the chunkwise path.
    zi: npt.NDArray | None = None
    """Initial conditions for the filter, updated after every chunk"""
48
+
49
+
50
class AdaptiveLatticeNotchFilterTransformer(
    BaseStatefulTransformer[
        AdaptiveLatticeNotchFilterSettings,
        AxisArray,
        AxisArray,
        AdaptiveLatticeNotchFilterState,
    ]
):
    """
    Adaptive Lattice Notch Filter implementation as a stateful transformer.

    https://biomedical-engineering-online.biomedcentral.com/articles/10.1186/1475-925X-13-170

    The filter automatically tracks and removes frequency components from the input signal.
    The filtered signal is returned as the message data and the estimated frequency (in Hz)
    is attached as a "freq" coordinate axis on the output message.
    """

    def _hash_message(self, message: AxisArray) -> int:
        # State must be rebuilt when the stream key, sampling rate, or feature shape changes.
        ax_idx = message.get_axis_idx(self.settings.axis)
        sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
        return hash((message.key, message.axes[self.settings.axis].gain, sample_shape))

    def _reset_state(self, message: AxisArray) -> None:
        """Reset filter state to initial values."""
        ax_idx = message.get_axis_idx(self.settings.axis)
        sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]

        fs = 1 / message.axes[self.settings.axis].gain
        # Default initial notch frequency is a fixed fraction of the sampling rate.
        init_f = (
            self.settings.init_notch_freq if self.settings.init_notch_freq is not None else 0.07178314656435313 * fs
        )
        init_omega = init_f * (2 * np.pi) / fs
        # For a notch at omega, the lattice reflection coefficient is k1 = -cos(omega).
        init_k1 = -np.cos(init_omega)

        self._state = AdaptiveLatticeNotchFilterState()
        self._state.s_history = np.zeros((2,) + sample_shape, dtype=float)
        self._state.p = np.zeros(sample_shape, dtype=float)
        self._state.q = np.zeros(sample_shape, dtype=float)
        self._state.k1 = init_k1 + np.zeros(sample_shape, dtype=float)
        self._state.freq_template = CoordinateAxis(
            data=np.zeros((0,) + sample_shape, dtype=float),
            dims=[self.settings.axis] + message.dims[:ax_idx] + message.dims[ax_idx + 1 :],
            unit="Hz",
        )

        # Initialize the initial conditions for the filter (used by the chunkwise path).
        self._state.zi = np.zeros((2, np.prod(sample_shape)), dtype=float)
        # Note: we could calculate it properly, but as long as we are initializing s_history with zeros,
        # it will always be zero.
        # a = [1, init_k1 * (1 + self.settings.gamma), self.settings.gamma]
        # b = [1]
        # s = np.reshape(self._state.s_history, (2, -1))
        # for feat_ix in range(np.prod(sample_shape)):
        #     self._state.zi[:, feat_ix] = scipy.signal.lfiltic(b, a, s[::-1, feat_ix], x=None)

    def _process(self, message: AxisArray) -> AxisArray:
        x_data = message.data
        ax_idx = message.get_axis_idx(self.settings.axis)

        # Work with the filter axis as the leading (0th) axis.
        # TODO: Time should be moved to -1th axis, not the 0th axis
        b_moved = message.dims[0] != self.settings.axis
        if b_moved:
            x_data = np.moveaxis(x_data, ax_idx, 0)

        # Access settings once
        gamma = self.settings.gamma
        eta = self.settings.eta
        mu = self.settings.mu
        fs = 1 / message.axes[self.settings.axis].gain

        # Pre-compute constants
        one_minus_eta = 1 - eta
        one_minus_mu = 1 - mu
        gamma_plus_1 = 1 + gamma
        omega_scale = fs / (2 * np.pi)

        # For the lattice filter with constant k1:
        # s_n = x_n - k1*(1+gamma)*s_n_1 - gamma*s_n_2
        # This is equivalent to an IIR filter with b=1, a=[1, k1*(1+gamma), gamma]

        # For the output filter:
        # y_n = s_n + 2*k1*s_n_1 + s_n_2
        # We can treat this as a direct-form FIR filter applied to s_out

        if self.settings.chunkwise:
            # Process the whole chunk using the current filter parameters, then
            # update the parameters once from the trailing samples of the chunk.
            _x = x_data.reshape((x_data.shape[0], -1))
            s_n = np.zeros_like(_x)
            y_out = np.zeros_like(_x)

            # Apply static filter for each feature dimension
            for ix, k in enumerate(self._state.k1.flatten()):
                # Filter to get s_n (notch filter state)
                a_s = [1, k * gamma_plus_1, gamma]
                s_n[:, ix], self._state.zi[:, ix] = scipy.signal.lfilter([1], a_s, _x[:, ix], zi=self._state.zi[:, ix])

                # Apply output filter to get y_out
                b_y = [1, 2 * k, 1]
                y_out[:, ix] = scipy.signal.lfilter(b_y, [1], s_n[:, ix])

            # Update filter parameters using the final values from the chunk.
            # Prepend the 2-sample history so the lags s_n, s_{n-1}, s_{n-2} are
            # well-defined even for chunks shorter than 3 samples. (The previous
            # implementation indexed s_n_reshaped[-2] directly and raised
            # IndexError for 1-sample chunks.)
            s_n_reshaped = s_n.reshape((s_n.shape[0],) + x_data.shape[1:])
            s_ext = np.concatenate((self._state.s_history, s_n_reshaped), axis=0)
            s_final = s_ext[-1]  # Current s_n
            s_final_1 = s_ext[-2]  # s_n_1
            s_final_2 = s_ext[-3]  # s_n_2

            # Update p and q using final values
            self._state.p = eta * self._state.p + one_minus_eta * (s_final_1 * (s_final + s_final_2))
            self._state.q = eta * self._state.q + one_minus_eta * (2 * (s_final_1 * s_final_1))

            # Update reflection coefficient
            new_k1 = -self._state.p / (self._state.q + 1e-8)  # Avoid division by zero
            new_k1 = np.clip(new_k1, -1, 1)  # Clip to prevent instability
            self._state.k1 = mu * self._state.k1 + one_minus_mu * new_k1  # Smoothed

            # Calculate frequency from updated k1 and broadcast it over the time axis.
            omega_n = np.arccos(-self._state.k1)
            freq = omega_n * omega_scale
            freq_out = np.full_like(x_data, freq)

            # Keep the last two lattice states for the next chunk
            self._state.s_history = s_ext[-2:]

            # Reshape y_out back to original dimensions
            y_out = y_out.reshape(x_data.shape)

        else:
            # Perform filtering, sample-by-sample
            y_out = np.zeros_like(x_data)
            freq_out = np.zeros_like(x_data)
            for sample_ix, x_n in enumerate(x_data):
                s_n_1 = self._state.s_history[-1]
                s_n_2 = self._state.s_history[-2]

                s_n = x_n - self._state.k1 * gamma_plus_1 * s_n_1 - gamma * s_n_2
                y_out[sample_ix] = s_n + 2 * self._state.k1 * s_n_1 + s_n_2

                # Update filter parameters
                self._state.p = eta * self._state.p + one_minus_eta * (s_n_1 * (s_n + s_n_2))
                self._state.q = eta * self._state.q + one_minus_eta * (2 * (s_n_1 * s_n_1))

                # Update reflection coefficient
                new_k1 = -self._state.p / (self._state.q + 1e-8)  # Avoid division by zero
                new_k1 = np.clip(new_k1, -1, 1)  # Clip to prevent instability
                self._state.k1 = mu * self._state.k1 + one_minus_mu * new_k1  # Smoothed

                # Compute normalized angular frequency using equation 13 from the paper
                omega_n = np.arccos(-self._state.k1)
                freq_out[sample_ix] = omega_n * omega_scale  # As Hz

                # Update for next iteration
                self._state.s_history[-2] = s_n_1
                self._state.s_history[-1] = s_n

        # Restore the original axis order of the filtered data. (Previously the
        # moved orientation leaked out, so the output shape no longer matched
        # message.dims when the filter axis was not the 0th axis.)
        # freq_out intentionally stays time-first to match freq_template.dims.
        if b_moved:
            y_out = np.moveaxis(y_out, 0, ax_idx)

        return replace(
            message,
            data=y_out,
            axes={
                **message.axes,
                "freq": replace(self._state.freq_template, data=freq_out),
            },
        )
@@ -0,0 +1,235 @@
1
+ """Affine transformations via matrix multiplication: y = Ax or y = Ax + B.
2
+
3
+ For full matrix transformations where channels are mixed (off-diagonal weights),
4
+ use :obj:`AffineTransformTransformer` or the `AffineTransform` unit.
5
+
6
+ For simple per-channel scaling and offset (diagonal weights only), use
7
+ :obj:`LinearTransformTransformer` from :mod:`ezmsg.sigproc.linear` instead,
8
+ which is more efficient as it avoids matrix multiplication.
9
+ """
10
+
11
+ import os
12
+ from pathlib import Path
13
+
14
+ import ezmsg.core as ez
15
+ import numpy as np
16
+ import numpy.typing as npt
17
+ from ezmsg.baseproc import (
18
+ BaseStatefulTransformer,
19
+ BaseTransformer,
20
+ BaseTransformerUnit,
21
+ processor_state,
22
+ )
23
+ from ezmsg.util.messages.axisarray import AxisArray, AxisBase
24
+ from ezmsg.util.messages.util import replace
25
+
26
+
27
class AffineTransformSettings(ez.Settings):
    """
    Settings for :obj:`AffineTransform`.
    """

    weights: np.ndarray | str | Path
    """An array of weights or a path to a file with weights compatible with np.loadtxt.
    May also be None or the string "passthrough", in which case the transform is a no-op."""

    axis: str | None = None
    """The name of the axis to apply the transformation to. Defaults to the last axis."""

    right_multiply: bool = True
    """Set False to transpose the weights before applying."""
40
+
41
+
42
@processor_state
class AffineTransformState:
    # Weights resolved to a contiguous ndarray (loaded from file if a path was given).
    weights: npt.NDArray | None = None
    # Replacement axis with remapped channel labels, when the transform changes the axis length.
    new_axis: AxisBase | None = None
46
+
47
+
48
class AffineTransformTransformer(
    BaseStatefulTransformer[AffineTransformSettings, AxisArray, AxisArray, AffineTransformState]
):
    """Apply affine transformation via matrix multiplication: y = Ax or y = Ax + B.

    Use this transformer when you need full matrix transformations that mix
    channels (off-diagonal weights), such as spatial filters or projections.

    For simple per-channel scaling and offset where each output channel depends
    only on its corresponding input channel (diagonal weight matrix), use
    :obj:`LinearTransformTransformer` instead, which is more efficient.

    The weights matrix can include an offset row (stacked as [A|B]) where the
    input is automatically augmented with a column of ones to compute y = Ax + B.
    """

    def __call__(self, message: AxisArray) -> AxisArray:
        # Override __call__ so we can shortcut if weights are None.
        if self.settings.weights is None or (
            isinstance(self.settings.weights, str) and self.settings.weights == "passthrough"
        ):
            return message
        return super().__call__(message)

    def _hash_message(self, message: AxisArray) -> int:
        # State depends only on stream identity; the weights come from settings.
        return hash(message.key)

    def _reset_state(self, message: AxisArray) -> None:
        # Resolve weights: a plain string is treated as a file path; files are read as CSV.
        weights = self.settings.weights
        if isinstance(weights, str):
            weights = Path(os.path.abspath(os.path.expanduser(weights)))
        if isinstance(weights, Path):
            weights = np.loadtxt(weights, delimiter=",")
        if not self.settings.right_multiply:
            weights = weights.T
        if weights is not None:
            weights = np.ascontiguousarray(weights)

        self._state.weights = weights

        # If the transformed axis carries labels (an axis with .data) and the weights are
        # non-square (channel count changes), try to carry labels through to the output.
        axis = self.settings.axis or message.dims[-1]
        if axis in message.axes and hasattr(message.axes[axis], "data") and weights.shape[0] != weights.shape[1]:
            in_labels = message.axes[axis].data
            new_labels = []
            n_in, n_out = weights.shape
            if len(in_labels) != n_in:
                ez.logger.warning(f"Received {len(in_labels)} for {n_in} inputs. Check upstream labels.")
            else:
                # Which outputs receive any weight, and which inputs contribute any weight.
                b_filled_outputs = np.any(weights, axis=0)
                b_used_inputs = np.any(weights, axis=1)
                if np.all(b_used_inputs) and np.all(b_filled_outputs):
                    # Full mixing: input labels cannot be mapped onto outputs.
                    # NOTE(review): new_axis is still set below with an empty label array;
                    # confirm clearing labels (vs. leaving new_axis = None) is intended.
                    new_labels = []
                elif np.all(b_used_inputs):
                    # Every input feeds exactly the filled outputs in order; pad the
                    # unfilled outputs with empty labels.
                    in_ix = 0
                    new_labels = []
                    for out_ix in range(n_out):
                        if b_filled_outputs[out_ix]:
                            new_labels.append(in_labels[in_ix])
                            in_ix += 1
                        else:
                            new_labels.append("")
                elif np.all(b_filled_outputs):
                    # Subselection of inputs: keep only the labels of used inputs.
                    new_labels = np.array(in_labels)[b_used_inputs]

            self._state.new_axis = replace(message.axes[axis], data=np.array(new_labels))

    def _process(self, message: AxisArray) -> AxisArray:
        axis = self.settings.axis or message.dims[-1]
        axis_idx = message.get_axis_idx(axis)
        data = message.data

        if data.shape[axis_idx] == (self._state.weights.shape[0] - 1):
            # The weights are stacked A|B where A is the transform and B is a single row
            # in the equation y = Ax + B. This supports NeuroKey's weights matrices.
            sample_shape = data.shape[:axis_idx] + (1,) + data.shape[axis_idx + 1 :]
            data = np.concatenate((data, np.ones(sample_shape).astype(data.dtype)), axis=axis_idx)

        if axis_idx in [-1, len(message.dims) - 1]:
            data = np.matmul(data, self._state.weights)
        else:
            # matmul contracts over the last axis; move the target axis there and back.
            data = np.moveaxis(data, axis_idx, -1)
            data = np.matmul(data, self._state.weights)
            data = np.moveaxis(data, -1, axis_idx)

        replace_kwargs = {"data": data}
        if self._state.new_axis is not None:
            replace_kwargs["axes"] = {**message.axes, axis: self._state.new_axis}

        return replace(message, **replace_kwargs)
137
+
138
+
139
class AffineTransform(BaseTransformerUnit[AffineTransformSettings, AxisArray, AxisArray, AffineTransformTransformer]):
    """Unit wrapper around :obj:`AffineTransformTransformer`."""

    SETTINGS = AffineTransformSettings
141
+
142
+
143
def affine_transform(
    weights: np.ndarray | str | Path,
    axis: str | None = None,
    right_multiply: bool = True,
) -> AffineTransformTransformer:
    """
    Perform affine transformations on streaming data.

    Args:
        weights: An array of weights or a path to a file with weights compatible with np.loadtxt.
        axis: The name of the axis to apply the transformation to. Defaults to the last axis.
        right_multiply: Set False to transpose the weights before applying.

    Returns:
        :obj:`AffineTransformTransformer`.
    """
    settings = AffineTransformSettings(
        weights=weights,
        axis=axis,
        right_multiply=right_multiply,
    )
    return AffineTransformTransformer(settings)
162
+
163
+
164
def zeros_for_noop(data: npt.NDArray, **ignore_kwargs) -> npt.NDArray:
    """Return an all-zeros array shaped and typed like ``data``; extra kwargs are ignored."""
    result = np.zeros_like(data)
    return result
166
+
167
+
168
class CommonRereferenceSettings(ez.Settings):
    """
    Settings for :obj:`CommonRereference`
    """

    mode: str = "mean"
    """The statistical mode to apply -- either "mean" or "median". "passthrough" disables re-referencing."""

    axis: str | None = None
    """The name of the axis to apply the transformation to. Defaults to the last axis."""

    include_current: bool = True
    """Set False to exclude each channel from participating in the calculation of its reference."""
181
+
182
+
183
class CommonRereferenceTransformer(BaseTransformer[CommonRereferenceSettings, AxisArray, AxisArray]):
    """Subtract a common (mean or median) reference from every channel along the chosen axis."""

    def _process(self, message: AxisArray) -> AxisArray:
        mode = self.settings.mode
        if mode == "passthrough":
            return message

        target_axis = self.settings.axis or message.dims[-1]
        target_idx = message.get_axis_idx(target_axis)

        # Dispatch table; an unrecognized mode raises KeyError here.
        stat_fn = {"mean": np.mean, "median": np.median, "passthrough": zeros_for_noop}[mode]

        reference = stat_fn(message.data, axis=target_idx, keepdims=True)

        if not self.settings.include_current:
            # `reference` above divides by N and includes every channel, including the
            # channel being re-referenced. Excluding channel i means dividing by (N-1)
            # and dropping x[i]'s own contribution:
            #   CAR[i] = x[0]/(N-1) + ... + x[i-1]/(N-1) + x[i+1]/(N-1) + ... + x[N-1]/(N-1)
            # Most of that sum is shared with the inclusive reference, so we rescale it
            # and subtract the current channel's term:
            #   CAR[i] = (N / (N-1)) * reference - x[i] / (N-1)
            # Broadcasting handles all channels at once; no per-channel loop is needed.
            n_ch = message.data.shape[target_idx]
            reference = (n_ch / (n_ch - 1)) * reference - message.data / (n_ch - 1)
            # Note: I profiled using AffineTransformTransformer; it's ~30x slower than this implementation.

        return replace(message, data=message.data - reference)
211
+
212
+
213
class CommonRereference(
    BaseTransformerUnit[CommonRereferenceSettings, AxisArray, AxisArray, CommonRereferenceTransformer]
):
    """Unit wrapper around :obj:`CommonRereferenceTransformer`."""

    SETTINGS = CommonRereferenceSettings
217
+
218
+
219
def common_rereference(
    mode: str = "mean", axis: str | None = None, include_current: bool = True
) -> CommonRereferenceTransformer:
    """
    Perform common average referencing (CAR) on streaming data.

    Args:
        mode: The statistical mode to apply -- either "mean" or "median"
        axis: The name of the axis to apply the transformation to.
        include_current: Set False to exclude each channel from participating in the calculation of its reference.

    Returns:
        :obj:`CommonRereferenceTransformer`
    """
    settings = CommonRereferenceSettings(
        mode=mode,
        axis=axis,
        include_current=include_current,
    )
    return CommonRereferenceTransformer(settings)