ezmsg-sigproc 1.7.0__py3-none-any.whl → 2.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. ezmsg/sigproc/__version__.py +22 -4
  2. ezmsg/sigproc/activation.py +31 -40
  3. ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
  4. ezmsg/sigproc/affinetransform.py +171 -169
  5. ezmsg/sigproc/aggregate.py +190 -97
  6. ezmsg/sigproc/bandpower.py +60 -55
  7. ezmsg/sigproc/base.py +143 -33
  8. ezmsg/sigproc/butterworthfilter.py +34 -38
  9. ezmsg/sigproc/butterworthzerophase.py +305 -0
  10. ezmsg/sigproc/cheby.py +23 -17
  11. ezmsg/sigproc/combfilter.py +160 -0
  12. ezmsg/sigproc/coordinatespaces.py +159 -0
  13. ezmsg/sigproc/decimate.py +15 -10
  14. ezmsg/sigproc/denormalize.py +78 -0
  15. ezmsg/sigproc/detrend.py +28 -0
  16. ezmsg/sigproc/diff.py +82 -0
  17. ezmsg/sigproc/downsample.py +72 -81
  18. ezmsg/sigproc/ewma.py +217 -0
  19. ezmsg/sigproc/ewmfilter.py +1 -1
  20. ezmsg/sigproc/extract_axis.py +39 -0
  21. ezmsg/sigproc/fbcca.py +307 -0
  22. ezmsg/sigproc/filter.py +254 -148
  23. ezmsg/sigproc/filterbank.py +226 -214
  24. ezmsg/sigproc/filterbankdesign.py +129 -0
  25. ezmsg/sigproc/fir_hilbert.py +336 -0
  26. ezmsg/sigproc/fir_pmc.py +209 -0
  27. ezmsg/sigproc/firfilter.py +117 -0
  28. ezmsg/sigproc/gaussiansmoothing.py +89 -0
  29. ezmsg/sigproc/kaiser.py +106 -0
  30. ezmsg/sigproc/linear.py +120 -0
  31. ezmsg/sigproc/math/abs.py +23 -22
  32. ezmsg/sigproc/math/add.py +120 -0
  33. ezmsg/sigproc/math/clip.py +33 -25
  34. ezmsg/sigproc/math/difference.py +117 -43
  35. ezmsg/sigproc/math/invert.py +18 -25
  36. ezmsg/sigproc/math/log.py +38 -33
  37. ezmsg/sigproc/math/scale.py +24 -25
  38. ezmsg/sigproc/messages.py +1 -2
  39. ezmsg/sigproc/quantize.py +68 -0
  40. ezmsg/sigproc/resample.py +278 -0
  41. ezmsg/sigproc/rollingscaler.py +232 -0
  42. ezmsg/sigproc/sampler.py +209 -254
  43. ezmsg/sigproc/scaler.py +93 -218
  44. ezmsg/sigproc/signalinjector.py +44 -43
  45. ezmsg/sigproc/slicer.py +74 -102
  46. ezmsg/sigproc/spectral.py +3 -3
  47. ezmsg/sigproc/spectrogram.py +70 -70
  48. ezmsg/sigproc/spectrum.py +187 -173
  49. ezmsg/sigproc/transpose.py +134 -0
  50. ezmsg/sigproc/util/__init__.py +0 -0
  51. ezmsg/sigproc/util/asio.py +25 -0
  52. ezmsg/sigproc/util/axisarray_buffer.py +365 -0
  53. ezmsg/sigproc/util/buffer.py +449 -0
  54. ezmsg/sigproc/util/message.py +17 -0
  55. ezmsg/sigproc/util/profile.py +23 -0
  56. ezmsg/sigproc/util/sparse.py +115 -0
  57. ezmsg/sigproc/util/typeresolution.py +17 -0
  58. ezmsg/sigproc/wavelets.py +147 -154
  59. ezmsg/sigproc/window.py +248 -210
  60. ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
  61. ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
  62. {ezmsg_sigproc-1.7.0.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -1
  63. ezmsg/sigproc/synth.py +0 -621
  64. ezmsg_sigproc-1.7.0.dist-info/METADATA +0 -58
  65. ezmsg_sigproc-1.7.0.dist-info/RECORD +0 -36
  66. /ezmsg_sigproc-1.7.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/ewma.py ADDED
@@ -0,0 +1,217 @@
+ import functools
+ from dataclasses import field
+
+ import ezmsg.core as ez
+ import numpy as np
+ import numpy.typing as npt
+ import scipy.signal as sps
+ from ezmsg.baseproc import BaseStatefulTransformer, BaseTransformerUnit, processor_state
+ from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
+ from ezmsg.util.messages.util import replace
+
+
+ def _tau_from_alpha(alpha: float, dt: float) -> float:
+     """
+     Inverse of _alpha_from_tau. See that function for explanation.
+     """
+     return -dt / np.log(1 - alpha)
+
+
+ def _alpha_from_tau(tau: float, dt: float) -> float:
+     """
+     # https://en.wikipedia.org/wiki/Exponential_smoothing#Time_constant
+     :param tau: The amount of time for the smoothed response of a unit step function to reach
+         1 - 1/e approx-eq 63.2%.
+     :param dt: sampling period, or 1 / sampling_rate.
+     :return: alpha, the "fading factor" in exponential smoothing.
+     """
+     return 1 - np.exp(-dt / tau)
+
+
+ def ewma_step(sample: npt.NDArray, zi: npt.NDArray, alpha: float, beta: float | None = None):
+     """
+     Do an exponentially weighted moving average step.
+
+     Args:
+         sample: The new sample.
+         zi: The output of the previous step.
+         alpha: Fading factor.
+         beta: Persisting factor. If None, it is calculated as 1-alpha.
+
+     Returns:
+         alpha * sample + beta * zi
+
+     """
+     # Potential micro-optimization:
+     # Current: scalar-arr multiplication, scalar-arr multiplication, arr-arr addition
+     # Alternative: arr-arr subtraction, arr-arr multiplication, arr-arr addition
+     # return zi + alpha * (new_sample - zi)
+     beta = beta or (1 - alpha)
+     return alpha * sample + beta * zi
+
+
+ class EWMA_Deprecated:
+     """
+     Grabbed these methods from https://stackoverflow.com/a/70998068 and other answers in that topic,
+     but they ended up being slower than the scipy.signal.lfilter method.
+     Additionally, `compute` and `compute2` suffer from potential errors as the vector length increases
+     and beta**n approaches zero.
+     """
+
+     def __init__(self, alpha: float, max_len: int):
+         self.alpha = alpha
+         self.beta = 1 - alpha
+         self.prev: npt.NDArray | None = None
+         self.weights = np.empty((max_len + 1,), float)
+         self._precalc_weights(max_len)
+         self._step_func = functools.partial(ewma_step, alpha=self.alpha, beta=self.beta)
+
+     def _precalc_weights(self, n: int):
+         # (1-α)^0, (1-α)^1, (1-α)^2, ..., (1-α)^n
+         np.power(self.beta, np.arange(n + 1), out=self.weights)
+
+     def compute(self, arr: npt.NDArray, out: npt.NDArray | None = None) -> npt.NDArray:
+         if out is None:
+             out = np.empty(arr.shape, arr.dtype)
+
+         n = arr.shape[0]
+         weights = self.weights[:n]
+         weights = np.expand_dims(weights, list(range(1, arr.ndim)))
+
+         # α*P0, α*P1, α*P2, ..., α*Pn
+         np.multiply(self.alpha, arr, out)
+
+         # α*P0/(1-α)^0, α*P1/(1-α)^1, α*P2/(1-α)^2, ..., α*Pn/(1-α)^n
+         np.divide(out, weights, out)
+
+         # α*P0/(1-α)^0, α*P0/(1-α)^0 + α*P1/(1-α)^1, ...
+         np.cumsum(out, axis=0, out=out)
+
+         # (α*P0/(1-α)^0)*(1-α)^0, (α*P0/(1-α)^0 + α*P1/(1-α)^1)*(1-α)^1, ...
+         np.multiply(out, weights, out)
+
+         # Add the previous output
+         if self.prev is None:
+             self.prev = arr[:1]
+
+         out += self.prev * np.expand_dims(self.weights[1 : n + 1], list(range(1, arr.ndim)))
+
+         self.prev = out[-1:]
+
+         return out
+
+     def compute2(self, arr: npt.NDArray) -> npt.NDArray:
+         """
+         Compute the Exponentially Weighted Moving Average (EWMA) of the input array.
+
+         Args:
+             arr: The input array to be smoothed.
+
+         Returns:
+             The smoothed array.
+         """
+         n = arr.shape[0]
+         if n > len(self.weights):
+             self._precalc_weights(n)
+         weights = self.weights[:n][::-1]
+         weights = np.expand_dims(weights, list(range(1, arr.ndim)))
+
+         result = np.cumsum(self.alpha * weights * arr, axis=0)
+         result = result / weights
+
+         # Handle the first call when prev is unset
+         if self.prev is None:
+             self.prev = arr[:1]
+
+         result += self.prev * np.expand_dims(self.weights[1 : n + 1], list(range(1, arr.ndim)))
+
+         # Store the result back into prev
+         self.prev = result[-1]
+
+         return result
+
+     def compute_sample(self, new_sample: npt.NDArray) -> npt.NDArray:
+         if self.prev is None:
+             self.prev = new_sample
+         self.prev = self._step_func(new_sample, self.prev)
+         return self.prev
+
+
+ class EWMASettings(ez.Settings):
+     time_constant: float = 1.0
+     """The amount of time for the smoothed response of a unit step function to reach 1 - 1/e approx-eq 63.2%."""
+
+     axis: str | None = None
+
+     accumulate: bool = True
+     """If True, update the EWMA state with each sample. If False, only apply
+     the current EWMA estimate without updating state (useful for inference
+     periods where you don't want to adapt statistics)."""
+
+
+ @processor_state
+ class EWMAState:
+     alpha: float = field(default_factory=lambda: _alpha_from_tau(1.0, 1000.0))
+     zi: npt.NDArray | None = None
+
+
+ class EWMATransformer(BaseStatefulTransformer[EWMASettings, AxisArray, AxisArray, EWMAState]):
+     def _hash_message(self, message: AxisArray) -> int:
+         axis = self.settings.axis or message.dims[0]
+         axis_idx = message.get_axis_idx(axis)
+         sample_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
+         return hash((sample_shape, message.axes[axis].gain, message.key))
+
+     def _reset_state(self, message: AxisArray) -> None:
+         axis = self.settings.axis or message.dims[0]
+         self._state.alpha = _alpha_from_tau(self.settings.time_constant, message.axes[axis].gain)
+         sub_dat = slice_along_axis(message.data, slice(None, 1, None), axis=message.get_axis_idx(axis))
+         self._state.zi = (1 - self._state.alpha) * sub_dat
+
+     def _process(self, message: AxisArray) -> AxisArray:
+         if np.prod(message.data.shape) == 0:
+             return message
+         axis = self.settings.axis or message.dims[0]
+         axis_idx = message.get_axis_idx(axis)
+         if self.settings.accumulate:
+             # Normal behavior: update state with new samples
+             expected, self._state.zi = sps.lfilter(
+                 [self._state.alpha],
+                 [1.0, self._state.alpha - 1.0],
+                 message.data,
+                 axis=axis_idx,
+                 zi=self._state.zi,
+             )
+         else:
+             # Process-only: compute output without updating state
+             expected, _ = sps.lfilter(
+                 [self._state.alpha],
+                 [1.0, self._state.alpha - 1.0],
+                 message.data,
+                 axis=axis_idx,
+                 zi=self._state.zi,
+             )
+         return replace(message, data=expected)
+
+
+ class EWMAUnit(BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, EWMATransformer]):
+     SETTINGS = EWMASettings
+
+     @ez.subscriber(BaseTransformerUnit.INPUT_SETTINGS)
+     async def on_settings(self, msg: EWMASettings) -> None:
+         """
+         Handle settings updates with smart reset behavior.
+
+         Only resets state if `axis` changes (structural change).
+         Changes to `time_constant` or `accumulate` are applied without
+         resetting accumulated state.
+         """
+         old_axis = self.SETTINGS.axis
+         self.apply_settings(msg)
+
+         if msg.axis != old_axis:
+             # Axis changed - need full reset
+             self.create_processor()
+         else:
+             # Only accumulate or time_constant changed - keep state
+             self.processor.settings = msg
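Note on the new EWMA implementation: the lfilter call in EWMATransformer._process, with coefficients [alpha] and [1.0, alpha - 1.0], computes the same recursion as ewma_step, and the zi seeding in _reset_state makes the first output equal the first input. Below is a minimal standalone sketch (numpy/scipy only; the sampling rate, time constant, and variable names are illustrative, not package API) verifying that equivalence:

    import numpy as np
    import scipy.signal as sps

    fs = 1000.0          # assumed sampling rate (Hz)
    tau = 1.0            # time constant, as in EWMASettings.time_constant
    alpha = 1 - np.exp(-(1 / fs) / tau)   # same mapping as _alpha_from_tau(tau, dt)

    x = np.random.randn(5000)

    # lfilter form used by EWMATransformer: y[n] = alpha*x[n] + (1 - alpha)*y[n-1]
    zi = np.array([(1 - alpha) * x[0]])   # same seed as _reset_state, so y[0] == x[0]
    y_lfilter, _ = sps.lfilter([alpha], [1.0, alpha - 1.0], x, zi=zi)

    # Explicit per-sample recursion, equivalent to repeated ewma_step calls
    y_loop = np.empty_like(x)
    prev = x[0]
    for n, sample in enumerate(x):
        prev = alpha * sample + (1 - alpha) * prev
        y_loop[n] = prev

    assert np.allclose(y_lfilter, y_loop)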
ezmsg/sigproc/ewmfilter.py CHANGED
@@ -2,9 +2,9 @@ import asyncio
  import typing

  import ezmsg.core as ez
+ import numpy as np
  from ezmsg.util.messages.axisarray import AxisArray
  from ezmsg.util.messages.util import replace
- import numpy as np

  from .window import Window, WindowSettings

ezmsg/sigproc/extract_axis.py ADDED
@@ -0,0 +1,39 @@
+ import ezmsg.core as ez
+ import numpy as np
+ from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
+ from ezmsg.util.messages.axisarray import AxisArray, replace
+
+
+ class ExtractAxisSettings(ez.Settings):
+     axis: str = "freq"
+     reference: str = "time"
+
+
+ class ExtractAxisData(BaseTransformer[ExtractAxisSettings, AxisArray, AxisArray]):
+     def _process(self, message: AxisArray) -> AxisArray:
+         targ_ax = message.axes[self.settings.axis]
+         if hasattr(targ_ax, "data"):
+             # Extracted axis is of type CoordinateAxis
+             return replace(
+                 message,
+                 data=targ_ax.data,
+                 dims=targ_ax.dims,
+                 axes={k: v for k, v in message.axes.items() if k in targ_ax.dims},
+             )
+             # Note: So far we don't have any transformers where the coordinate axis has its own axes,
+             # but if that happens in the future, we'd need to consider how to handle that.
+
+         else:
+             # Extracted axis is of type LinearAxis
+             # LinearAxis can only yield a 1d array data which simplifies dims and axes.
+             n = message.data.shape[message.get_axis_idx(self.settings.reference)]
+             return replace(
+                 message,
+                 data=targ_ax.value(np.arange(n)),
+                 dims=[self.settings.reference],
+                 axes={self.settings.reference: message.axes[self.settings.reference]},
+             )
+
+
+ class ExtractAxisDataUnit(BaseTransformerUnit[ExtractAxisSettings, AxisArray, AxisArray, ExtractAxisData]):
+     SETTINGS = ExtractAxisSettings
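Note on the two branches above: a coordinate-style axis already stores its values in `.data` and is returned as-is, while a linear-style axis is evaluated at each index along the reference dim. A numpy-only sketch of the distinction follows; the offset + gain * index arithmetic is assumed to mirror what the axis' value() does, and the names below are illustrative, not package API:

    import numpy as np

    n = 5                                  # length along the reference ("time") dim

    # CoordinateAxis-like branch: the axis already stores explicit per-entry values.
    coord_values = np.array([2.0, 4.0, 8.0, 16.0, 32.0])
    extracted_coord = coord_values         # returned directly as the output data

    # LinearAxis-like branch: values are generated from gain/offset at each index.
    gain, offset = 0.001, 10.0
    extracted_linear = offset + gain * np.arange(n)   # analogous to targ_ax.value(np.arange(n))
    print(extracted_coord, extracted_linear)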
ezmsg/sigproc/fbcca.py ADDED
@@ -0,0 +1,307 @@
+ import math
+ import typing
+ from dataclasses import field
+
+ import ezmsg.core as ez
+ import numpy as np
+ from ezmsg.baseproc import (
+     BaseProcessor,
+     BaseStatefulProcessor,
+     BaseTransformer,
+     BaseTransformerUnit,
+     CompositeProcessor,
+ )
+ from ezmsg.util.messages.axisarray import AxisArray
+ from ezmsg.util.messages.util import replace
+
+ from .filterbankdesign import (
+     FilterbankDesignSettings,
+     FilterbankDesignTransformer,
+ )
+ from .kaiser import KaiserFilterSettings
+ from .sampler import SampleTriggerMessage
+ from .window import WindowSettings, WindowTransformer
+
+
+ class FBCCASettings(ez.Settings):
+     """
+     Settings for :obj:`FBCCATransformer`
+     """
+
+     time_dim: str
+     """
+     The time dim in the data array.
+     """
+
+     ch_dim: str
+     """
+     The channels dim in the data array.
+     """
+
+     filterbank_dim: str | None = None
+     """
+     The filter bank subband dim in the data array. If unspecified, method falls back to CCA
+     None (default): the input has no subbands; just use CCA
+     """
+
+     harmonics: int = 5
+     """
+     The number of additional harmonics beyond the fundamental to use for the 'design' matrix.
+     5 (default): Evaluate 5 harmonics of the base frequency.
+     Many periodic signals are not pure sinusoids, and inclusion of higher harmonics can help evaluate the
+     presence of signals with higher frequency harmonic content
+     """
+
+     freqs: typing.List[float] = field(default_factory=list)
+     """
+     Frequencies (in hz) to evaluate the presence of within the input signal.
+     [] (default): an empty list; frequencies will be found within the input SampleMessages.
+     AxisArrays have no good place to put this metadata, so specify frequencies here if only AxisArrays
+     will be passed as input to the generator. If the input has a `trigger` attr of type :obj:`SampleTriggerMessage`,
+     the processor looks for the `freqs` attribute within that trigger for a list of frequencies to evaluate.
+     This field is present in the :obj:`SSVEPSampleTriggerMessage` defined in ezmsg.tasks.ssvep from
+     the ezmsg-tasks package.
+     NOTE: Avoid frequencies that have line-noise (60 Hz/50 Hz) as a harmonic.
+     """
+
+     softmax_beta: float = 1.0
+     """
+     Beta parameter for softmax on output --> "probabilities".
+     1.0 (default): Use the shifted softmax transformation to output 0-1 probabilities.
+     If 0.0, the maximum singular value of the SVD for each design matrix is output
+     """
+
+     target_freq_dim: str = "target_freq"
+     """
+     Name for dim to put target frequency outputs on.
+     'target_freq' (default)
+     """
+
+     max_int_time: float = 0.0
+     """
+     Maximum integration time (in seconds) to use for calculation.
+     0 (default): Use all time provided for the calculation.
+     Useful for artificially limiting the amount of data used for the CCA method to evaluate
+     the necessary integration time for good decoding performance
+     """
+
+
+ class FBCCATransformer(BaseTransformer[FBCCASettings, AxisArray, AxisArray]):
+     """
+     A canonical-correlation (CCA) signal decoder for detection of periodic activity in multi-channel timeseries
+     recordings. It is particularly useful for detecting the presence of steady-state evoked responses in multi-channel
+     EEG data. Please see Lin et. al. 2007 for a description on the use of CCA to detect the presence of SSVEP in EEG
+     data.
+     This implementation also includes the "Filterbank" extension of the CCA decoding approach which utilizes a
+     filterbank to decompose input multi-channel EEG data into several frequency sub-bands; each of which is analyzed
+     with CCA, then combined using a weighted sum; allowing CCA to more readily identify harmonic content in EEG data.
+     Read more about this approach in Chen et. al. 2015.
+
+     ## Further reading:
+     * [Lin et. al. 2007](https://ieeexplore.ieee.org/document/4015614)
+     * [Nakanishi et. al. 2015](https://doi.org/10.1371%2Fjournal.pone.0140703)
+     * [Chen et. al. 2015](http://dx.doi.org/10.1088/1741-2560/12/4/046008)
+     """
+
+     def _process(self, message: AxisArray) -> AxisArray:
+         """
+         Input: AxisArray with at least a time_dim, and ch_dim
+         Output: AxisArray with time_dim, ch_dim, (and filterbank_dim if specified)
+         collapsed, with a new 'target_freq' dim of length 'freqs'
+         """
+
+         test_freqs: list[float] = self.settings.freqs
+         trigger = message.attrs.get("trigger", None)
+         if isinstance(trigger, SampleTriggerMessage):
+             if len(test_freqs) == 0:
+                 test_freqs = getattr(trigger, "freqs", [])
+
+         if len(test_freqs) == 0:
+             raise ValueError("no frequencies to test")
+
+         time_dim_idx = message.get_axis_idx(self.settings.time_dim)
+         ch_dim_idx = message.get_axis_idx(self.settings.ch_dim)
+
+         filterbank_dim_idx = None
+         if self.settings.filterbank_dim is not None:
+             filterbank_dim_idx = message.get_axis_idx(self.settings.filterbank_dim)
+
+         # Move (filterbank_dim), time, ch to end of array
+         rm_dims = [self.settings.time_dim, self.settings.ch_dim]
+         if self.settings.filterbank_dim is not None:
+             rm_dims = [self.settings.filterbank_dim] + rm_dims
+         new_order = [i for i, dim in enumerate(message.dims) if dim not in rm_dims]
+         if filterbank_dim_idx is not None:
+             new_order.append(filterbank_dim_idx)
+         new_order.extend([time_dim_idx, ch_dim_idx])
+         out_dims = [message.dims[i] for i in new_order if message.dims[i] not in rm_dims]
+         data_arr = message.data.transpose(new_order)
+
+         # Add a singleton dim for filterbank dim if we don't have one
+         if filterbank_dim_idx is None:
+             data_arr = data_arr[..., None, :, :]
+             filterbank_dim_idx = data_arr.ndim - 3
+
+         # data_arr is now (..., filterbank, time, ch)
+         # Get output shape for remaining dims and reshape data_arr for iterative processing
+         out_shape = list(data_arr.shape[:-3])
+         data_arr = data_arr.reshape([math.prod(out_shape), *data_arr.shape[-3:]])
+
+         # Create output dims and axes with added target_freq_dim
+         out_shape.append(len(test_freqs))
+         out_dims.append(self.settings.target_freq_dim)
+         out_axes = {
+             axis_name: axis
+             for axis_name, axis in message.axes.items()
+             if axis_name not in rm_dims
+             and not (isinstance(axis, AxisArray.CoordinateAxis) and any(d in rm_dims for d in axis.dims))
+         }
+         out_axes[self.settings.target_freq_dim] = AxisArray.CoordinateAxis(
+             np.array(test_freqs), [self.settings.target_freq_dim]
+         )
+
+         if message.data.size == 0:
+             out_data = message.data.reshape(out_shape)
+             output = replace(message, data=out_data, dims=out_dims, axes=out_axes)
+             return output
+
+         # Get time axis
+         t_ax_info = message.ax(self.settings.time_dim)
+         t = t_ax_info.values
+         t -= t[0]
+         max_samp = len(t)
+         if self.settings.max_int_time > 0:
+             max_samp = int(abs(t_ax_info.values - self.settings.max_int_time).argmin())
+             t = t[:max_samp]
+
+         calc_output = np.zeros((*data_arr.shape[:-2], len(test_freqs)))
+
+         for test_freq_idx, test_freq in enumerate(test_freqs):
+             # Create the design matrix of base frequency and requested harmonics
+             Y = np.column_stack(
+                 [
+                     fn(2.0 * np.pi * k * test_freq * t)
+                     for k in range(1, self.settings.harmonics + 1)
+                     for fn in (np.sin, np.cos)
+                 ]
+             )
+
+             for test_idx, arr in enumerate(data_arr):  # iterate over first dim; arr is (filterbank x time x ch)
+                 for band_idx, band in enumerate(arr):  # iterate over second dim: arr is (time x ch)
+                     calc_output[test_idx, band_idx, test_freq_idx] = cca_rho_max(band[:max_samp, ...], Y)
+
+         # Combine per-subband canonical correlations using a weighted sum
+         # https://iopscience.iop.org/article/10.1088/1741-2560/12/4/046008
+         freq_weights = (np.arange(1, calc_output.shape[1] + 1) ** -1.25) + 0.25
+         calc_output = ((calc_output**2) * freq_weights[None, :, None]).sum(axis=1)
+
+         if self.settings.softmax_beta != 0:
+             calc_output = calc_softmax(calc_output, axis=-1, beta=self.settings.softmax_beta)
+
+         output = replace(
+             message,
+             data=calc_output.reshape(out_shape),
+             dims=out_dims,
+             axes=out_axes,
+         )
+
+         return output
+
+
+ class FBCCA(BaseTransformerUnit[FBCCASettings, AxisArray, AxisArray, FBCCATransformer]):
+     SETTINGS = FBCCASettings
+
+
+ class StreamingFBCCASettings(FBCCASettings):
+     """
+     Perform rolling/streaming FBCCA on incoming EEG.
+     Decomposes the input multi-channel timeseries data into multiple sub-bands using a FilterbankDesign Transformer,
+     then accumulates data using Window into short-time observations for analysis using an FBCCA Transformer.
+     """
+
+     window_dur: float = 4.0  # sec
+     window_shift: float = 0.5  # sec
+     window_dim: str = "fbcca_window"
+     filter_bw: float = 7.0  # Hz
+     filter_low: float = 7.0  # Hz
+     trans_bw: float = 2.0  # Hz
+     ripple_db: float = 20.0  # dB
+     subbands: int = 12
+
+
+ class StreamingFBCCATransformer(CompositeProcessor[StreamingFBCCASettings, AxisArray, AxisArray]):
+     @staticmethod
+     def _initialize_processors(
+         settings: StreamingFBCCASettings,
+     ) -> dict[str, BaseProcessor | BaseStatefulProcessor]:
+         pipeline = {}
+
+         if settings.filterbank_dim is not None:
+             cut_freqs = (np.arange(settings.subbands + 1) * settings.filter_bw) + settings.filter_low
+             filters = [
+                 KaiserFilterSettings(
+                     axis=settings.time_dim,
+                     cutoff=(c - settings.trans_bw, cut_freqs[-1]),
+                     ripple=settings.ripple_db,
+                     width=settings.trans_bw,
+                     pass_zero=False,
+                 )
+                 for c in cut_freqs[:-1]
+             ]
+
+             pipeline["filterbank"] = FilterbankDesignTransformer(
+                 FilterbankDesignSettings(filters=filters, new_axis=settings.filterbank_dim)
+             )
+
+         pipeline["window"] = WindowTransformer(
+             WindowSettings(
+                 axis=settings.time_dim,
+                 newaxis=settings.window_dim,
+                 window_dur=settings.window_dur,
+                 window_shift=settings.window_shift,
+                 zero_pad_until="shift",
+             )
+         )
+
+         pipeline["fbcca"] = FBCCATransformer(settings)
+
+         return pipeline
+
+
+ class StreamingFBCCA(BaseTransformerUnit[StreamingFBCCASettings, AxisArray, AxisArray, StreamingFBCCATransformer]):
+     SETTINGS = StreamingFBCCASettings
+
+
+ def cca_rho_max(X: np.ndarray, Y: np.ndarray) -> float:
+     """
+     X: (n_time, n_ch)
+     Y: (n_time, n_ref)  # design matrix for one frequency
+     returns: largest canonical correlation in [0,1]
+     """
+     # Center columns
+     Xc = X - X.mean(axis=0, keepdims=True)
+     Yc = Y - Y.mean(axis=0, keepdims=True)
+
+     # Drop any zero-variance columns to avoid rank issues
+     Xc = Xc[:, Xc.std(axis=0) > 1e-12]
+     Yc = Yc[:, Yc.std(axis=0) > 1e-12]
+     if Xc.size == 0 or Yc.size == 0:
+         return 0.0
+
+     # Orthonormal bases
+     Qx, _ = np.linalg.qr(Xc, mode="reduced")  # (n_time, r_x)
+     Qy, _ = np.linalg.qr(Yc, mode="reduced")  # (n_time, r_y)
+
+     # Canonical correlations are the singular values of Qx^T Qy
+     with np.errstate(divide="ignore", over="ignore", invalid="ignore"):
+         s = np.linalg.svd(Qx.T @ Qy, compute_uv=False)
+     return float(s[0]) if s.size else 0.0
+
+
+ def calc_softmax(cv: np.ndarray, axis: int, beta: float = 1.0):
+     # Calculate softmax with shifting to avoid overflow
+     # (https://doi.org/10.1093/imanum/draa038)
+     cv = cv - cv.max(axis=axis, keepdims=True)
+     cv = np.exp(beta * cv)
+     cv = cv / np.sum(cv, axis=axis, keepdims=True)
+     return cv
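Note on the new FBCCA module: the scoring step pairs a sin/cos harmonic design matrix (as built in FBCCATransformer._process) with the QR/SVD canonical-correlation computation in cca_rho_max. Below is a standalone numpy sketch of that step, with no ezmsg imports; the sampling rate, duration, channel count, and candidate frequencies are arbitrary illustration values, not package defaults:

    import numpy as np

    rng = np.random.default_rng(0)
    fs, dur, true_freq, harmonics = 500.0, 2.0, 10.0, 5
    t = np.arange(int(fs * dur)) / fs

    # Synthetic "EEG": a 10 Hz component mixed into 8 noisy channels.
    source = np.sin(2 * np.pi * true_freq * t)
    X = np.outer(source, rng.uniform(0.5, 1.0, 8)) + 0.5 * rng.standard_normal((t.size, 8))

    def rho_max(X, Y):
        # Same idea as cca_rho_max: orthonormalize both column spaces, then the
        # canonical correlations are the singular values of Qx^T @ Qy.
        Xc = X - X.mean(axis=0, keepdims=True)
        Yc = Y - Y.mean(axis=0, keepdims=True)
        Qx, _ = np.linalg.qr(Xc)
        Qy, _ = np.linalg.qr(Yc)
        return np.linalg.svd(Qx.T @ Qy, compute_uv=False)[0]

    scores = {}
    for f in (8.0, 10.0, 12.0, 15.0):
        # Design matrix: sin/cos at the fundamental and its harmonics.
        Y = np.column_stack(
            [fn(2 * np.pi * k * f * t) for k in range(1, harmonics + 1) for fn in (np.sin, np.cos)]
        )
        scores[f] = rho_max(X, Y)

    print(scores)                       # the 10.0 Hz candidate should score highest
    assert max(scores, key=scores.get) == true_freq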