ezmsg-sigproc 1.2.2__py3-none-any.whl → 2.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. ezmsg/sigproc/__init__.py +1 -1
  2. ezmsg/sigproc/__version__.py +34 -1
  3. ezmsg/sigproc/activation.py +78 -0
  4. ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
  5. ezmsg/sigproc/affinetransform.py +235 -0
  6. ezmsg/sigproc/aggregate.py +276 -0
  7. ezmsg/sigproc/bandpower.py +80 -0
  8. ezmsg/sigproc/base.py +149 -0
  9. ezmsg/sigproc/butterworthfilter.py +129 -39
  10. ezmsg/sigproc/butterworthzerophase.py +305 -0
  11. ezmsg/sigproc/cheby.py +125 -0
  12. ezmsg/sigproc/combfilter.py +160 -0
  13. ezmsg/sigproc/coordinatespaces.py +159 -0
  14. ezmsg/sigproc/decimate.py +46 -18
  15. ezmsg/sigproc/denormalize.py +78 -0
  16. ezmsg/sigproc/detrend.py +28 -0
  17. ezmsg/sigproc/diff.py +82 -0
  18. ezmsg/sigproc/downsample.py +97 -49
  19. ezmsg/sigproc/ewma.py +217 -0
  20. ezmsg/sigproc/ewmfilter.py +45 -19
  21. ezmsg/sigproc/extract_axis.py +39 -0
  22. ezmsg/sigproc/fbcca.py +307 -0
  23. ezmsg/sigproc/filter.py +282 -117
  24. ezmsg/sigproc/filterbank.py +292 -0
  25. ezmsg/sigproc/filterbankdesign.py +129 -0
  26. ezmsg/sigproc/fir_hilbert.py +336 -0
  27. ezmsg/sigproc/fir_pmc.py +209 -0
  28. ezmsg/sigproc/firfilter.py +117 -0
  29. ezmsg/sigproc/gaussiansmoothing.py +89 -0
  30. ezmsg/sigproc/kaiser.py +106 -0
  31. ezmsg/sigproc/linear.py +120 -0
  32. ezmsg/sigproc/math/__init__.py +0 -0
  33. ezmsg/sigproc/math/abs.py +35 -0
  34. ezmsg/sigproc/math/add.py +120 -0
  35. ezmsg/sigproc/math/clip.py +48 -0
  36. ezmsg/sigproc/math/difference.py +143 -0
  37. ezmsg/sigproc/math/invert.py +28 -0
  38. ezmsg/sigproc/math/log.py +57 -0
  39. ezmsg/sigproc/math/scale.py +39 -0
  40. ezmsg/sigproc/messages.py +3 -6
  41. ezmsg/sigproc/quantize.py +68 -0
  42. ezmsg/sigproc/resample.py +278 -0
  43. ezmsg/sigproc/rollingscaler.py +232 -0
  44. ezmsg/sigproc/sampler.py +232 -241
  45. ezmsg/sigproc/scaler.py +165 -0
  46. ezmsg/sigproc/signalinjector.py +70 -0
  47. ezmsg/sigproc/slicer.py +138 -0
  48. ezmsg/sigproc/spectral.py +6 -132
  49. ezmsg/sigproc/spectrogram.py +90 -0
  50. ezmsg/sigproc/spectrum.py +277 -0
  51. ezmsg/sigproc/transpose.py +134 -0
  52. ezmsg/sigproc/util/__init__.py +0 -0
  53. ezmsg/sigproc/util/asio.py +25 -0
  54. ezmsg/sigproc/util/axisarray_buffer.py +365 -0
  55. ezmsg/sigproc/util/buffer.py +449 -0
  56. ezmsg/sigproc/util/message.py +17 -0
  57. ezmsg/sigproc/util/profile.py +23 -0
  58. ezmsg/sigproc/util/sparse.py +115 -0
  59. ezmsg/sigproc/util/typeresolution.py +17 -0
  60. ezmsg/sigproc/wavelets.py +187 -0
  61. ezmsg/sigproc/window.py +301 -117
  62. ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
  63. ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
  64. {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -2
  65. ezmsg/sigproc/synth.py +0 -411
  66. ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
  67. ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
  68. ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
  69. /ezmsg_sigproc-1.2.2.dist-info/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/ewma.py ADDED
@@ -0,0 +1,217 @@
+ import functools
+ from dataclasses import field
+
+ import ezmsg.core as ez
+ import numpy as np
+ import numpy.typing as npt
+ import scipy.signal as sps
+ from ezmsg.baseproc import BaseStatefulTransformer, BaseTransformerUnit, processor_state
+ from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
+ from ezmsg.util.messages.util import replace
+
+
+ def _tau_from_alpha(alpha: float, dt: float) -> float:
+     """
+     Inverse of _alpha_from_tau. See that function for explanation.
+     """
+     return -dt / np.log(1 - alpha)
+
+
+ def _alpha_from_tau(tau: float, dt: float) -> float:
+     """
+     # https://en.wikipedia.org/wiki/Exponential_smoothing#Time_constant
+     :param tau: The amount of time for the smoothed response of a unit step function to reach
+       1 - 1/e approx-eq 63.2%.
+     :param dt: sampling period, or 1 / sampling_rate.
+     :return: alpha, the "fading factor" in exponential smoothing.
+     """
+     return 1 - np.exp(-dt / tau)
+
+
+ def ewma_step(sample: npt.NDArray, zi: npt.NDArray, alpha: float, beta: float | None = None):
+     """
+     Do an exponentially weighted moving average step.
+
+     Args:
+         sample: The new sample.
+         zi: The output of the previous step.
+         alpha: Fading factor.
+         beta: Persisting factor. If None, it is calculated as 1-alpha.
+
+     Returns:
+         alpha * sample + beta * zi
+
+     """
+     # Potential micro-optimization:
+     # Current: scalar-arr multiplication, scalar-arr multiplication, arr-arr addition
+     # Alternative: arr-arr subtraction, arr-arr multiplication, arr-arr addition
+     # return zi + alpha * (new_sample - zi)
+     beta = beta or (1 - alpha)
+     return alpha * sample + beta * zi
+
+
+ class EWMA_Deprecated:
+     """
+     Grabbed these methods from https://stackoverflow.com/a/70998068 and other answers in that topic,
+     but they ended up being slower than the scipy.signal.lfilter method.
+     Additionally, `compute` and `compute2` suffer from potential errors as the vector length increases
+     and beta**n approaches zero.
+     """
+
+     def __init__(self, alpha: float, max_len: int):
+         self.alpha = alpha
+         self.beta = 1 - alpha
+         self.prev: npt.NDArray | None = None
+         self.weights = np.empty((max_len + 1,), float)
+         self._precalc_weights(max_len)
+         self._step_func = functools.partial(ewma_step, alpha=self.alpha, beta=self.beta)
+
+     def _precalc_weights(self, n: int):
+         # (1-α)^0, (1-α)^1, (1-α)^2, ..., (1-α)^n
+         np.power(self.beta, np.arange(n + 1), out=self.weights)
+
+     def compute(self, arr: npt.NDArray, out: npt.NDArray | None = None) -> npt.NDArray:
+         if out is None:
+             out = np.empty(arr.shape, arr.dtype)
+
+         n = arr.shape[0]
+         weights = self.weights[:n]
+         weights = np.expand_dims(weights, list(range(1, arr.ndim)))
+
+         # α*P0, α*P1, α*P2, ..., α*Pn
+         np.multiply(self.alpha, arr, out)
+
+         # α*P0/(1-α)^0, α*P1/(1-α)^1, α*P2/(1-α)^2, ..., α*Pn/(1-α)^n
+         np.divide(out, weights, out)
+
+         # α*P0/(1-α)^0, α*P0/(1-α)^0 + α*P1/(1-α)^1, ...
+         np.cumsum(out, axis=0, out=out)
+
+         # (α*P0/(1-α)^0)*(1-α)^0, (α*P0/(1-α)^0 + α*P1/(1-α)^1)*(1-α)^1, ...
+         np.multiply(out, weights, out)
+
+         # Add the previous output
+         if self.prev is None:
+             self.prev = arr[:1]
+
+         out += self.prev * np.expand_dims(self.weights[1 : n + 1], list(range(1, arr.ndim)))
+
+         self.prev = out[-1:]
+
+         return out
+
+     def compute2(self, arr: npt.NDArray) -> npt.NDArray:
+         """
+         Compute the Exponentially Weighted Moving Average (EWMA) of the input array.
+
+         Args:
+             arr: The input array to be smoothed.
+
+         Returns:
+             The smoothed array.
+         """
+         n = arr.shape[0]
+         if n > len(self.weights):
+             self._precalc_weights(n)
+         weights = self.weights[:n][::-1]
+         weights = np.expand_dims(weights, list(range(1, arr.ndim)))
+
+         result = np.cumsum(self.alpha * weights * arr, axis=0)
+         result = result / weights
+
+         # Handle the first call when prev is unset
+         if self.prev is None:
+             self.prev = arr[:1]
+
+         result += self.prev * np.expand_dims(self.weights[1 : n + 1], list(range(1, arr.ndim)))
+
+         # Store the result back into prev
+         self.prev = result[-1]
+
+         return result
+
+     def compute_sample(self, new_sample: npt.NDArray) -> npt.NDArray:
+         if self.prev is None:
+             self.prev = new_sample
+         self.prev = self._step_func(new_sample, self.prev)
+         return self.prev
+
+
+ class EWMASettings(ez.Settings):
+     time_constant: float = 1.0
+     """The amount of time for the smoothed response of a unit step function to reach 1 - 1/e approx-eq 63.2%."""
+
+     axis: str | None = None
+
+     accumulate: bool = True
+     """If True, update the EWMA state with each sample. If False, only apply
+     the current EWMA estimate without updating state (useful for inference
+     periods where you don't want to adapt statistics)."""
+
+
+ @processor_state
+ class EWMAState:
+     alpha: float = field(default_factory=lambda: _alpha_from_tau(1.0, 1000.0))
+     zi: npt.NDArray | None = None
+
+
+ class EWMATransformer(BaseStatefulTransformer[EWMASettings, AxisArray, AxisArray, EWMAState]):
+     def _hash_message(self, message: AxisArray) -> int:
+         axis = self.settings.axis or message.dims[0]
+         axis_idx = message.get_axis_idx(axis)
+         sample_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
+         return hash((sample_shape, message.axes[axis].gain, message.key))
+
+     def _reset_state(self, message: AxisArray) -> None:
+         axis = self.settings.axis or message.dims[0]
+         self._state.alpha = _alpha_from_tau(self.settings.time_constant, message.axes[axis].gain)
+         sub_dat = slice_along_axis(message.data, slice(None, 1, None), axis=message.get_axis_idx(axis))
+         self._state.zi = (1 - self._state.alpha) * sub_dat
+
+     def _process(self, message: AxisArray) -> AxisArray:
+         if np.prod(message.data.shape) == 0:
+             return message
+         axis = self.settings.axis or message.dims[0]
+         axis_idx = message.get_axis_idx(axis)
+         if self.settings.accumulate:
+             # Normal behavior: update state with new samples
+             expected, self._state.zi = sps.lfilter(
+                 [self._state.alpha],
+                 [1.0, self._state.alpha - 1.0],
+                 message.data,
+                 axis=axis_idx,
+                 zi=self._state.zi,
+             )
+         else:
+             # Process-only: compute output without updating state
+             expected, _ = sps.lfilter(
+                 [self._state.alpha],
+                 [1.0, self._state.alpha - 1.0],
+                 message.data,
+                 axis=axis_idx,
+                 zi=self._state.zi,
+             )
+         return replace(message, data=expected)
+
+
+ class EWMAUnit(BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, EWMATransformer]):
+     SETTINGS = EWMASettings
+
+     @ez.subscriber(BaseTransformerUnit.INPUT_SETTINGS)
+     async def on_settings(self, msg: EWMASettings) -> None:
+         """
+         Handle settings updates with smart reset behavior.
+
+         Only resets state if `axis` changes (structural change).
+         Changes to `time_constant` or `accumulate` are applied without
+         resetting accumulated state.
+         """
+         old_axis = self.SETTINGS.axis
+         self.apply_settings(msg)
+
+         if msg.axis != old_axis:
+             # Axis changed - need full reset
+             self.create_processor()
+         else:
+             # Only accumulate or time_constant changed - keep state
+             self.processor.settings = msg
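Note on the above: `EWMATransformer` implements the single-pole recursion y[n] = alpha*x[n] + (1 - alpha)*y[n-1] via `scipy.signal.lfilter`, with alpha derived from `time_constant` and the time-axis gain. A minimal standalone sketch of the same math using only numpy/scipy (the sampling rate, time constant, and array shapes below are illustrative assumptions, not values from the package):

import numpy as np
import scipy.signal as sps

fs = 500.0                               # assumed sampling rate (Hz)
tau = 0.25                               # assumed time constant (s)
alpha = 1 - np.exp(-(1.0 / fs) / tau)    # same relationship as _alpha_from_tau(tau, dt=1/fs)

x = np.random.default_rng(0).normal(size=(1000, 4))  # one (time, ch) chunk
zi = (1 - alpha) * x[:1]                 # initial condition, mirroring _reset_state
# y[n] = alpha * x[n] + (1 - alpha) * y[n-1], carried across chunks via zf
y, zf = sps.lfilter([alpha], [1.0, alpha - 1.0], x, axis=0, zi=zi)

Feeding successive chunks with `zi=zf` reproduces the streaming behavior when `accumulate=True`; reusing the same `zi` without updating it mirrors `accumulate=False`.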
ezmsg/sigproc/ewmfilter.py CHANGED
@@ -1,19 +1,20 @@
  import asyncio
- from dataclasses import replace
+ import typing

  import ezmsg.core as ez
- from ezmsg.util.messages.axisarray import AxisArray
-
  import numpy as np
+ from ezmsg.util.messages.axisarray import AxisArray
+ from ezmsg.util.messages.util import replace

  from .window import Window, WindowSettings

- from typing import AsyncGenerator, Optional
-

  class EWMSettings(ez.Settings):
-     axis: Optional[str] = None
-     zero_offset: bool = True  # If true, we assume zero DC offset
+     axis: str | None = None
+     """Name of the axis to accumulate."""
+
+     zero_offset: bool = True
+     """If true, we assume zero DC offset for input data."""


  class EWMState(ez.State):
@@ -23,19 +24,23 @@ class EWMState(ez.State):

  class EWM(ez.Unit):
      """
-     Exponentially Weighted Moving Average Standardization
+     Exponentially Weighted Moving Average Standardization.
+     This is deprecated. Please use :obj:`ezmsg.sigproc.scaler.AdaptiveStandardScaler` instead.

      References https://stackoverflow.com/a/42926270
      """

-     SETTINGS: EWMSettings
-     STATE: EWMState
+     SETTINGS = EWMSettings
+     STATE = EWMState

      INPUT_SIGNAL = ez.InputStream(AxisArray)
      INPUT_BUFFER = ez.InputStream(AxisArray)
      OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

-     def initialize(self) -> None:
+     async def initialize(self) -> None:
+         ez.logger.warning(
+             "EWM/EWMFilter is deprecated and will be removed in a future version. Use AdaptiveStandardScaler instead."
+         )
          self.STATE.signal_queue = asyncio.Queue()
          self.STATE.buffer_queue = asyncio.Queue()

@@ -48,7 +53,7 @@ class EWM(ez.Unit):
          self.STATE.buffer_queue.put_nowait(message)

      @ez.publisher(OUTPUT_SIGNAL)
-     async def sync_output(self) -> AsyncGenerator:
+     async def sync_output(self) -> typing.AsyncGenerator:
          while True:
              signal = await self.STATE.signal_queue.get()
              buffer = await self.STATE.buffer_queue.get()  # includes signal
@@ -73,9 +78,12 @@ class EWM(ez.Unit):
              buffer_data = buffer.data
              buffer_data = np.moveaxis(buffer_data, axis_idx, 0)

+             while scale_arr.ndim < buffer_data.ndim:
+                 scale_arr = scale_arr[..., None]
+
              def ewma(data: np.ndarray) -> np.ndarray:
-                 mult = scale_arr[:, np.newaxis] * data * pw0
-                 out = scale_arr[::-1, np.newaxis] * mult.cumsum(axis=0)
+                 mult = scale_arr * data * pw0
+                 out = scale_arr[::-1] * mult.cumsum(axis=0)

                  if not self.SETTINGS.zero_offset:
                      out = (data[0, :, np.newaxis] * pows[1:]).T + out
@@ -93,13 +101,26 @@ class EWM(ez.Unit):


  class EWMFilterSettings(ez.Settings):
-     history_dur: float  # previous data to accumulate for standardization
-     axis: Optional[str] = None
-     zero_offset: bool = True  # If true, we assume zero DC offset for input data
+     history_dur: float
+     """Previous data to accumulate for standardization."""
+
+     axis: str | None = None
+     """Name of the axis to accumulate."""
+
+     zero_offset: bool = True
+     """If true, we assume zero DC offset for input data."""


  class EWMFilter(ez.Collection):
-     SETTINGS: EWMFilterSettings
+     """
+     A :obj:`Collection` that splits the input into a branch that
+     leads to :obj:`Window` which then feeds into :obj:`EWM` 's INPUT_BUFFER
+     and another branch that feeds directly into :obj:`EWM` 's INPUT_SIGNAL.
+
+     This is deprecated. Please use :obj:`ezmsg.sigproc.scaler.AdaptiveStandardScaler` instead.
+     """
+
+     SETTINGS = EWMFilterSettings

      INPUT_SIGNAL = ez.InputStream(AxisArray)
      OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
@@ -108,7 +129,12 @@ class EWMFilter(ez.Collection):
      EWM = EWM()

      def configure(self) -> None:
-         self.EWM.apply_settings(EWMSettings(axis=self.SETTINGS.axis, zero_offset=True))
+         self.EWM.apply_settings(
+             EWMSettings(
+                 axis=self.SETTINGS.axis,
+                 zero_offset=self.SETTINGS.zero_offset,
+             )
+         )

          self.WINDOW.apply_settings(
              WindowSettings(
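Note on the above: `EWM`/`EWMFilter` are deprecated in favor of `ezmsg.sigproc.scaler.AdaptiveStandardScaler`. As a rough illustration of the kind of exponentially weighted standardization these units perform, here is a hedged numpy/scipy sketch; the EWMA-of-mean/variance formulation and all constants are illustrative assumptions, not the package's exact implementation:

import numpy as np
import scipy.signal as sps

fs, tau = 500.0, 1.0                                  # assumed sampling rate and time constant
alpha = 1 - np.exp(-(1.0 / fs) / tau)
x = np.random.default_rng(1).normal(loc=2.0, scale=3.0, size=(5000, 2))  # (time, ch)

b, a = [alpha], [1.0, alpha - 1.0]
mean = sps.lfilter(b, a, x, axis=0)                   # EWMA estimate of the mean
var = sps.lfilter(b, a, (x - mean) ** 2, axis=0)      # EWMA estimate of the variance
z = (x - mean) / np.sqrt(np.clip(var, 1e-12, None))   # standardized output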
ezmsg/sigproc/extract_axis.py ADDED
@@ -0,0 +1,39 @@
+ import ezmsg.core as ez
+ import numpy as np
+ from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
+ from ezmsg.util.messages.axisarray import AxisArray, replace
+
+
+ class ExtractAxisSettings(ez.Settings):
+     axis: str = "freq"
+     reference: str = "time"
+
+
+ class ExtractAxisData(BaseTransformer[ExtractAxisSettings, AxisArray, AxisArray]):
+     def _process(self, message: AxisArray) -> AxisArray:
+         targ_ax = message.axes[self.settings.axis]
+         if hasattr(targ_ax, "data"):
+             # Extracted axis is of type CoordinateAxis
+             return replace(
+                 message,
+                 data=targ_ax.data,
+                 dims=targ_ax.dims,
+                 axes={k: v for k, v in message.axes.items() if k in targ_ax.dims},
+             )
+             # Note: So far we don't have any transformers where the coordinate axis has its own axes,
+             # but if that happens in the future, we'd need to consider how to handle that.
+
+         else:
+             # Extracted axis is of type LinearAxis
+             # LinearAxis can only yield 1d array data, which simplifies dims and axes.
+             n = message.data.shape[message.get_axis_idx(self.settings.reference)]
+             return replace(
+                 message,
+                 data=targ_ax.value(np.arange(n)),
+                 dims=[self.settings.reference],
+                 axes={self.settings.reference: message.axes[self.settings.reference]},
+             )
+
+
+ class ExtractAxisDataUnit(BaseTransformerUnit[ExtractAxisSettings, AxisArray, AxisArray, ExtractAxisData]):
+     SETTINGS = ExtractAxisSettings
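Note on the above: in the LinearAxis branch, `targ_ax.value(np.arange(n))` materializes the axis as data. Assuming a linear axis defined by an offset and a gain (sampling period), that amounts to the following plain-numpy computation; the numbers are hypothetical:

import numpy as np

gain, offset, n = 1.0 / 500.0, 12.0, 8        # hypothetical sampling period, start time, sample count
time_values = offset + gain * np.arange(n)    # what ExtractAxisData would return as `data`
# The result is 1-D, so dims collapse to just the `reference` axis ("time" by default).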
ezmsg/sigproc/fbcca.py ADDED
@@ -0,0 +1,307 @@
+ import math
+ import typing
+ from dataclasses import field
+
+ import ezmsg.core as ez
+ import numpy as np
+ from ezmsg.baseproc import (
+     BaseProcessor,
+     BaseStatefulProcessor,
+     BaseTransformer,
+     BaseTransformerUnit,
+     CompositeProcessor,
+ )
+ from ezmsg.util.messages.axisarray import AxisArray
+ from ezmsg.util.messages.util import replace
+
+ from .filterbankdesign import (
+     FilterbankDesignSettings,
+     FilterbankDesignTransformer,
+ )
+ from .kaiser import KaiserFilterSettings
+ from .sampler import SampleTriggerMessage
+ from .window import WindowSettings, WindowTransformer
+
+
+ class FBCCASettings(ez.Settings):
+     """
+     Settings for :obj:`FBCCATransformer`
+     """
+
+     time_dim: str
+     """
+     The time dim in the data array.
+     """
+
+     ch_dim: str
+     """
+     The channels dim in the data array.
+     """
+
+     filterbank_dim: str | None = None
+     """
+     The filter bank subband dim in the data array. If unspecified, the method falls back to plain CCA.
+     None (default): the input has no subbands; just use CCA.
+     """
+
+     harmonics: int = 5
+     """
+     The number of additional harmonics beyond the fundamental to use for the 'design' matrix.
+     5 (default): Evaluate 5 harmonics of the base frequency.
+     Many periodic signals are not pure sinusoids, and inclusion of higher harmonics can help evaluate the
+     presence of signals with higher-frequency harmonic content.
+     """
+
+     freqs: typing.List[float] = field(default_factory=list)
+     """
+     Frequencies (in Hz) to evaluate the presence of within the input signal.
+     [] (default): an empty list; frequencies will be found within the input SampleMessages.
+     AxisArrays have no good place to put this metadata, so specify frequencies here if only AxisArrays
+     will be passed as input to the generator. If the input has a `trigger` attr of type :obj:`SampleTriggerMessage`,
+     the processor looks for the `freqs` attribute within that trigger for a list of frequencies to evaluate.
+     This field is present in the :obj:`SSVEPSampleTriggerMessage` defined in ezmsg.tasks.ssvep from
+     the ezmsg-tasks package.
+     NOTE: Avoid frequencies that have line noise (60 Hz / 50 Hz) as a harmonic.
+     """
+
+     softmax_beta: float = 1.0
+     """
+     Beta parameter for softmax on output --> "probabilities".
+     1.0 (default): Use the shifted softmax transformation to output 0-1 probabilities.
+     If 0.0, the maximum singular value of the SVD for each design matrix is output.
+     """
+
+     target_freq_dim: str = "target_freq"
+     """
+     Name for the dim to put target frequency outputs on.
+     'target_freq' (default)
+     """
+
+     max_int_time: float = 0.0
+     """
+     Maximum integration time (in seconds) to use for the calculation.
+     0 (default): Use all time provided for the calculation.
+     Useful for artificially limiting the amount of data used by the CCA method, e.g. to evaluate
+     the integration time necessary for good decoding performance.
+     """
+
+
+ class FBCCATransformer(BaseTransformer[FBCCASettings, AxisArray, AxisArray]):
+     """
+     A canonical-correlation (CCA) signal decoder for detection of periodic activity in multi-channel timeseries
+     recordings. It is particularly useful for detecting the presence of steady-state evoked responses in multi-channel
+     EEG data. Please see Lin et al. 2007 for a description of the use of CCA to detect the presence of SSVEP in EEG
+     data.
+     This implementation also includes the "Filterbank" extension of the CCA decoding approach, which uses a
+     filterbank to decompose input multi-channel EEG data into several frequency sub-bands, each of which is analyzed
+     with CCA and then combined using a weighted sum, allowing CCA to more readily identify harmonic content in EEG
+     data. Read more about this approach in Chen et al. 2015.
+
+     ## Further reading:
+     * [Lin et al. 2007](https://ieeexplore.ieee.org/document/4015614)
+     * [Nakanishi et al. 2015](https://doi.org/10.1371%2Fjournal.pone.0140703)
+     * [Chen et al. 2015](http://dx.doi.org/10.1088/1741-2560/12/4/046008)
+     """
+
+     def _process(self, message: AxisArray) -> AxisArray:
+         """
+         Input: AxisArray with at least a time_dim and ch_dim.
+         Output: AxisArray with time_dim, ch_dim (and filterbank_dim if specified)
+         collapsed, with a new 'target_freq' dim of length len(freqs).
+         """
+
+         test_freqs: list[float] = self.settings.freqs
+         trigger = message.attrs.get("trigger", None)
+         if isinstance(trigger, SampleTriggerMessage):
+             if len(test_freqs) == 0:
+                 test_freqs = getattr(trigger, "freqs", [])
+
+         if len(test_freqs) == 0:
+             raise ValueError("no frequencies to test")
+
+         time_dim_idx = message.get_axis_idx(self.settings.time_dim)
+         ch_dim_idx = message.get_axis_idx(self.settings.ch_dim)
+
+         filterbank_dim_idx = None
+         if self.settings.filterbank_dim is not None:
+             filterbank_dim_idx = message.get_axis_idx(self.settings.filterbank_dim)
+
+         # Move (filterbank_dim), time, ch to end of array
+         rm_dims = [self.settings.time_dim, self.settings.ch_dim]
+         if self.settings.filterbank_dim is not None:
+             rm_dims = [self.settings.filterbank_dim] + rm_dims
+         new_order = [i for i, dim in enumerate(message.dims) if dim not in rm_dims]
+         if filterbank_dim_idx is not None:
+             new_order.append(filterbank_dim_idx)
+         new_order.extend([time_dim_idx, ch_dim_idx])
+         out_dims = [message.dims[i] for i in new_order if message.dims[i] not in rm_dims]
+         data_arr = message.data.transpose(new_order)
+
+         # Add a singleton dim for filterbank dim if we don't have one
+         if filterbank_dim_idx is None:
+             data_arr = data_arr[..., None, :, :]
+             filterbank_dim_idx = data_arr.ndim - 3
+
+         # data_arr is now (..., filterbank, time, ch)
+         # Get output shape for remaining dims and reshape data_arr for iterative processing
+         out_shape = list(data_arr.shape[:-3])
+         data_arr = data_arr.reshape([math.prod(out_shape), *data_arr.shape[-3:]])
+
+         # Create output dims and axes with added target_freq_dim
+         out_shape.append(len(test_freqs))
+         out_dims.append(self.settings.target_freq_dim)
+         out_axes = {
+             axis_name: axis
+             for axis_name, axis in message.axes.items()
+             if axis_name not in rm_dims
+             and not (isinstance(axis, AxisArray.CoordinateAxis) and any(d in rm_dims for d in axis.dims))
+         }
+         out_axes[self.settings.target_freq_dim] = AxisArray.CoordinateAxis(
+             np.array(test_freqs), [self.settings.target_freq_dim]
+         )
+
+         if message.data.size == 0:
+             out_data = message.data.reshape(out_shape)
+             output = replace(message, data=out_data, dims=out_dims, axes=out_axes)
+             return output
+
+         # Get time axis
+         t_ax_info = message.ax(self.settings.time_dim)
+         t = t_ax_info.values
+         t -= t[0]
+         max_samp = len(t)
+         if self.settings.max_int_time > 0:
+             max_samp = int(abs(t_ax_info.values - self.settings.max_int_time).argmin())
+             t = t[:max_samp]
+
+         calc_output = np.zeros((*data_arr.shape[:-2], len(test_freqs)))
+
+         for test_freq_idx, test_freq in enumerate(test_freqs):
+             # Create the design matrix of base frequency and requested harmonics
+             Y = np.column_stack(
+                 [
+                     fn(2.0 * np.pi * k * test_freq * t)
+                     for k in range(1, self.settings.harmonics + 1)
+                     for fn in (np.sin, np.cos)
+                 ]
+             )
+
+             for test_idx, arr in enumerate(data_arr):  # iterate over first dim; arr is (filterbank x time x ch)
+                 for band_idx, band in enumerate(arr):  # iterate over second dim; band is (time x ch)
+                     calc_output[test_idx, band_idx, test_freq_idx] = cca_rho_max(band[:max_samp, ...], Y)
+
+         # Combine per-subband canonical correlations using a weighted sum
+         # https://iopscience.iop.org/article/10.1088/1741-2560/12/4/046008
+         freq_weights = (np.arange(1, calc_output.shape[1] + 1) ** -1.25) + 0.25
+         calc_output = ((calc_output**2) * freq_weights[None, :, None]).sum(axis=1)
+
+         if self.settings.softmax_beta != 0:
+             calc_output = calc_softmax(calc_output, axis=-1, beta=self.settings.softmax_beta)
+
+         output = replace(
+             message,
+             data=calc_output.reshape(out_shape),
+             dims=out_dims,
+             axes=out_axes,
+         )
+
+         return output
+
+
+ class FBCCA(BaseTransformerUnit[FBCCASettings, AxisArray, AxisArray, FBCCATransformer]):
+     SETTINGS = FBCCASettings
+
+
+ class StreamingFBCCASettings(FBCCASettings):
+     """
+     Perform rolling/streaming FBCCA on incoming EEG.
+     Decomposes the input multi-channel timeseries data into multiple sub-bands using a FilterbankDesign Transformer,
+     then accumulates data using Window into short-time observations for analysis using an FBCCA Transformer.
+     """
+
+     window_dur: float = 4.0  # sec
+     window_shift: float = 0.5  # sec
+     window_dim: str = "fbcca_window"
+     filter_bw: float = 7.0  # Hz
+     filter_low: float = 7.0  # Hz
+     trans_bw: float = 2.0  # Hz
+     ripple_db: float = 20.0  # dB
+     subbands: int = 12
+
+
+ class StreamingFBCCATransformer(CompositeProcessor[StreamingFBCCASettings, AxisArray, AxisArray]):
+     @staticmethod
+     def _initialize_processors(
+         settings: StreamingFBCCASettings,
+     ) -> dict[str, BaseProcessor | BaseStatefulProcessor]:
+         pipeline = {}
+
+         if settings.filterbank_dim is not None:
+             cut_freqs = (np.arange(settings.subbands + 1) * settings.filter_bw) + settings.filter_low
+             filters = [
+                 KaiserFilterSettings(
+                     axis=settings.time_dim,
+                     cutoff=(c - settings.trans_bw, cut_freqs[-1]),
+                     ripple=settings.ripple_db,
+                     width=settings.trans_bw,
+                     pass_zero=False,
+                 )
+                 for c in cut_freqs[:-1]
+             ]
+
+             pipeline["filterbank"] = FilterbankDesignTransformer(
+                 FilterbankDesignSettings(filters=filters, new_axis=settings.filterbank_dim)
+             )
+
+         pipeline["window"] = WindowTransformer(
+             WindowSettings(
+                 axis=settings.time_dim,
+                 newaxis=settings.window_dim,
+                 window_dur=settings.window_dur,
+                 window_shift=settings.window_shift,
+                 zero_pad_until="shift",
+             )
+         )
+
+         pipeline["fbcca"] = FBCCATransformer(settings)
+
+         return pipeline
+
+
+ class StreamingFBCCA(BaseTransformerUnit[StreamingFBCCASettings, AxisArray, AxisArray, StreamingFBCCATransformer]):
+     SETTINGS = StreamingFBCCASettings
+
+
+ def cca_rho_max(X: np.ndarray, Y: np.ndarray) -> float:
+     """
+     X: (n_time, n_ch)
+     Y: (n_time, n_ref)  # design matrix for one frequency
+     returns: largest canonical correlation in [0, 1]
+     """
+     # Center columns
+     Xc = X - X.mean(axis=0, keepdims=True)
+     Yc = Y - Y.mean(axis=0, keepdims=True)
+
+     # Drop any zero-variance columns to avoid rank issues
+     Xc = Xc[:, Xc.std(axis=0) > 1e-12]
+     Yc = Yc[:, Yc.std(axis=0) > 1e-12]
+     if Xc.size == 0 or Yc.size == 0:
+         return 0.0
+
+     # Orthonormal bases
+     Qx, _ = np.linalg.qr(Xc, mode="reduced")  # (n_time, r_x)
+     Qy, _ = np.linalg.qr(Yc, mode="reduced")  # (n_time, r_y)
+
+     # Canonical correlations are the singular values of Qx^T Qy
+     with np.errstate(divide="ignore", over="ignore", invalid="ignore"):
+         s = np.linalg.svd(Qx.T @ Qy, compute_uv=False)
+     return float(s[0]) if s.size else 0.0
+
+
+ def calc_softmax(cv: np.ndarray, axis: int, beta: float = 1.0):
+     # Calculate softmax with shifting to avoid overflow
+     # (https://doi.org/10.1093/imanum/draa038)
+     cv = cv - cv.max(axis=axis, keepdims=True)
+     cv = np.exp(beta * cv)
+     cv = cv / np.sum(cv, axis=axis, keepdims=True)
+     return cv
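Note on the above: because `cca_rho_max` and `calc_softmax` operate on plain arrays, the core of the decoder can be exercised without the streaming machinery. A small synthetic check; the signal parameters and channel count are arbitrary choices for illustration:

import numpy as np
from ezmsg.sigproc.fbcca import calc_softmax, cca_rho_max

fs, dur, f_true = 250.0, 4.0, 12.0
t = np.arange(int(fs * dur)) / fs
rng = np.random.default_rng(0)
# 8-channel recording with a 12 Hz component buried in noise
X = 0.5 * np.sin(2 * np.pi * f_true * t)[:, None] + rng.normal(size=(t.size, 8))

def design(freq: float, harmonics: int = 5) -> np.ndarray:
    # sin/cos pairs for the fundamental and its harmonics, as in FBCCATransformer._process
    return np.column_stack(
        [fn(2 * np.pi * k * freq * t) for k in range(1, harmonics + 1) for fn in (np.sin, np.cos)]
    )

freqs = [10.0, 12.0, 15.0]
rhos = np.array([cca_rho_max(X, design(f)) for f in freqs])
probs = calc_softmax(rhos, axis=-1, beta=1.0)
# Expect the largest correlation and probability at 12 Hz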