ezmsg-sigproc 1.2.2__py3-none-any.whl → 2.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__init__.py +1 -1
- ezmsg/sigproc/__version__.py +34 -1
- ezmsg/sigproc/activation.py +78 -0
- ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
- ezmsg/sigproc/affinetransform.py +235 -0
- ezmsg/sigproc/aggregate.py +276 -0
- ezmsg/sigproc/bandpower.py +80 -0
- ezmsg/sigproc/base.py +149 -0
- ezmsg/sigproc/butterworthfilter.py +129 -39
- ezmsg/sigproc/butterworthzerophase.py +305 -0
- ezmsg/sigproc/cheby.py +125 -0
- ezmsg/sigproc/combfilter.py +160 -0
- ezmsg/sigproc/coordinatespaces.py +159 -0
- ezmsg/sigproc/decimate.py +46 -18
- ezmsg/sigproc/denormalize.py +78 -0
- ezmsg/sigproc/detrend.py +28 -0
- ezmsg/sigproc/diff.py +82 -0
- ezmsg/sigproc/downsample.py +97 -49
- ezmsg/sigproc/ewma.py +217 -0
- ezmsg/sigproc/ewmfilter.py +45 -19
- ezmsg/sigproc/extract_axis.py +39 -0
- ezmsg/sigproc/fbcca.py +307 -0
- ezmsg/sigproc/filter.py +282 -117
- ezmsg/sigproc/filterbank.py +292 -0
- ezmsg/sigproc/filterbankdesign.py +129 -0
- ezmsg/sigproc/fir_hilbert.py +336 -0
- ezmsg/sigproc/fir_pmc.py +209 -0
- ezmsg/sigproc/firfilter.py +117 -0
- ezmsg/sigproc/gaussiansmoothing.py +89 -0
- ezmsg/sigproc/kaiser.py +106 -0
- ezmsg/sigproc/linear.py +120 -0
- ezmsg/sigproc/math/__init__.py +0 -0
- ezmsg/sigproc/math/abs.py +35 -0
- ezmsg/sigproc/math/add.py +120 -0
- ezmsg/sigproc/math/clip.py +48 -0
- ezmsg/sigproc/math/difference.py +143 -0
- ezmsg/sigproc/math/invert.py +28 -0
- ezmsg/sigproc/math/log.py +57 -0
- ezmsg/sigproc/math/scale.py +39 -0
- ezmsg/sigproc/messages.py +3 -6
- ezmsg/sigproc/quantize.py +68 -0
- ezmsg/sigproc/resample.py +278 -0
- ezmsg/sigproc/rollingscaler.py +232 -0
- ezmsg/sigproc/sampler.py +232 -241
- ezmsg/sigproc/scaler.py +165 -0
- ezmsg/sigproc/signalinjector.py +70 -0
- ezmsg/sigproc/slicer.py +138 -0
- ezmsg/sigproc/spectral.py +6 -132
- ezmsg/sigproc/spectrogram.py +90 -0
- ezmsg/sigproc/spectrum.py +277 -0
- ezmsg/sigproc/transpose.py +134 -0
- ezmsg/sigproc/util/__init__.py +0 -0
- ezmsg/sigproc/util/asio.py +25 -0
- ezmsg/sigproc/util/axisarray_buffer.py +365 -0
- ezmsg/sigproc/util/buffer.py +449 -0
- ezmsg/sigproc/util/message.py +17 -0
- ezmsg/sigproc/util/profile.py +23 -0
- ezmsg/sigproc/util/sparse.py +115 -0
- ezmsg/sigproc/util/typeresolution.py +17 -0
- ezmsg/sigproc/wavelets.py +187 -0
- ezmsg/sigproc/window.py +301 -117
- ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
- {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -2
- ezmsg/sigproc/synth.py +0 -411
- ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
- ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
- ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
- /ezmsg_sigproc-1.2.2.dist-info/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/ewma.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
from dataclasses import field
|
|
3
|
+
|
|
4
|
+
import ezmsg.core as ez
|
|
5
|
+
import numpy as np
|
|
6
|
+
import numpy.typing as npt
|
|
7
|
+
import scipy.signal as sps
|
|
8
|
+
from ezmsg.baseproc import BaseStatefulTransformer, BaseTransformerUnit, processor_state
|
|
9
|
+
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
|
|
10
|
+
from ezmsg.util.messages.util import replace
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _tau_from_alpha(alpha: float, dt: float) -> float:
|
|
14
|
+
"""
|
|
15
|
+
Inverse of _alpha_from_tau. See that function for explanation.
|
|
16
|
+
"""
|
|
17
|
+
return -dt / np.log(1 - alpha)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _alpha_from_tau(tau: float, dt: float) -> float:
|
|
21
|
+
"""
|
|
22
|
+
# https://en.wikipedia.org/wiki/Exponential_smoothing#Time_constant
|
|
23
|
+
:param tau: The amount of time for the smoothed response of a unit step function to reach
|
|
24
|
+
1 - 1/e approx-eq 63.2%.
|
|
25
|
+
:param dt: sampling period, or 1 / sampling_rate.
|
|
26
|
+
:return: alpha, the "fading factor" in exponential smoothing.
|
|
27
|
+
"""
|
|
28
|
+
return 1 - np.exp(-dt / tau)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def ewma_step(sample: npt.NDArray, zi: npt.NDArray, alpha: float, beta: float | None = None):
|
|
32
|
+
"""
|
|
33
|
+
Do an exponentially weighted moving average step.
|
|
34
|
+
|
|
35
|
+
Args:
|
|
36
|
+
sample: The new sample.
|
|
37
|
+
zi: The output of the previous step.
|
|
38
|
+
alpha: Fading factor.
|
|
39
|
+
beta: Persisting factor. If None, it is calculated as 1-alpha.
|
|
40
|
+
|
|
41
|
+
Returns:
|
|
42
|
+
alpha * sample + beta * zi
|
|
43
|
+
|
|
44
|
+
"""
|
|
45
|
+
# Potential micro-optimization:
|
|
46
|
+
# Current: scalar-arr multiplication, scalar-arr multiplication, arr-arr addition
|
|
47
|
+
# Alternative: arr-arr subtraction, arr-arr multiplication, arr-arr addition
|
|
48
|
+
# return zi + alpha * (new_sample - zi)
|
|
49
|
+
beta = beta or (1 - alpha)
|
|
50
|
+
return alpha * sample + beta * zi
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class EWMA_Deprecated:
|
|
54
|
+
"""
|
|
55
|
+
Grabbed these methods from https://stackoverflow.com/a/70998068 and other answers in that topic,
|
|
56
|
+
but they ended up being slower than the scipy.signal.lfilter method.
|
|
57
|
+
Additionally, `compute` and `compute2` suffer from potential errors as the vector length increases
|
|
58
|
+
and beta**n approaches zero.
|
|
59
|
+
"""
|
|
60
|
+
|
|
61
|
+
def __init__(self, alpha: float, max_len: int):
|
|
62
|
+
self.alpha = alpha
|
|
63
|
+
self.beta = 1 - alpha
|
|
64
|
+
self.prev: npt.NDArray | None = None
|
|
65
|
+
self.weights = np.empty((max_len + 1,), float)
|
|
66
|
+
self._precalc_weights(max_len)
|
|
67
|
+
self._step_func = functools.partial(ewma_step, alpha=self.alpha, beta=self.beta)
|
|
68
|
+
|
|
69
|
+
def _precalc_weights(self, n: int):
|
|
70
|
+
# (1-α)^0, (1-α)^1, (1-α)^2, ..., (1-α)^n
|
|
71
|
+
np.power(self.beta, np.arange(n + 1), out=self.weights)
|
|
72
|
+
|
|
73
|
+
def compute(self, arr: npt.NDArray, out: npt.NDArray | None = None) -> npt.NDArray:
|
|
74
|
+
if out is None:
|
|
75
|
+
out = np.empty(arr.shape, arr.dtype)
|
|
76
|
+
|
|
77
|
+
n = arr.shape[0]
|
|
78
|
+
weights = self.weights[:n]
|
|
79
|
+
weights = np.expand_dims(weights, list(range(1, arr.ndim)))
|
|
80
|
+
|
|
81
|
+
# α*P0, α*P1, α*P2, ..., α*Pn
|
|
82
|
+
np.multiply(self.alpha, arr, out)
|
|
83
|
+
|
|
84
|
+
# α*P0/(1-α)^0, α*P1/(1-α)^1, α*P2/(1-α)^2, ..., α*Pn/(1-α)^n
|
|
85
|
+
np.divide(out, weights, out)
|
|
86
|
+
|
|
87
|
+
# α*P0/(1-α)^0, α*P0/(1-α)^0 + α*P1/(1-α)^1, ...
|
|
88
|
+
np.cumsum(out, axis=0, out=out)
|
|
89
|
+
|
|
90
|
+
# (α*P0/(1-α)^0)*(1-α)^0, (α*P0/(1-α)^0 + α*P1/(1-α)^1)*(1-α)^1, ...
|
|
91
|
+
np.multiply(out, weights, out)
|
|
92
|
+
|
|
93
|
+
# Add the previous output
|
|
94
|
+
if self.prev is None:
|
|
95
|
+
self.prev = arr[:1]
|
|
96
|
+
|
|
97
|
+
out += self.prev * np.expand_dims(self.weights[1 : n + 1], list(range(1, arr.ndim)))
|
|
98
|
+
|
|
99
|
+
self.prev = out[-1:]
|
|
100
|
+
|
|
101
|
+
return out
|
|
102
|
+
|
|
103
|
+
def compute2(self, arr: npt.NDArray) -> npt.NDArray:
|
|
104
|
+
"""
|
|
105
|
+
Compute the Exponentially Weighted Moving Average (EWMA) of the input array.
|
|
106
|
+
|
|
107
|
+
Args:
|
|
108
|
+
arr: The input array to be smoothed.
|
|
109
|
+
|
|
110
|
+
Returns:
|
|
111
|
+
The smoothed array.
|
|
112
|
+
"""
|
|
113
|
+
n = arr.shape[0]
|
|
114
|
+
if n > len(self.weights):
|
|
115
|
+
self._precalc_weights(n)
|
|
116
|
+
weights = self.weights[:n][::-1]
|
|
117
|
+
weights = np.expand_dims(weights, list(range(1, arr.ndim)))
|
|
118
|
+
|
|
119
|
+
result = np.cumsum(self.alpha * weights * arr, axis=0)
|
|
120
|
+
result = result / weights
|
|
121
|
+
|
|
122
|
+
# Handle the first call when prev is unset
|
|
123
|
+
if self.prev is None:
|
|
124
|
+
self.prev = arr[:1]
|
|
125
|
+
|
|
126
|
+
result += self.prev * np.expand_dims(self.weights[1 : n + 1], list(range(1, arr.ndim)))
|
|
127
|
+
|
|
128
|
+
# Store the result back into prev
|
|
129
|
+
self.prev = result[-1]
|
|
130
|
+
|
|
131
|
+
return result
|
|
132
|
+
|
|
133
|
+
def compute_sample(self, new_sample: npt.NDArray) -> npt.NDArray:
|
|
134
|
+
if self.prev is None:
|
|
135
|
+
self.prev = new_sample
|
|
136
|
+
self.prev = self._step_func(new_sample, self.prev)
|
|
137
|
+
return self.prev
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class EWMASettings(ez.Settings):
    """Settings for :obj:`EWMATransformer` / :obj:`EWMAUnit`."""

    time_constant: float = 1.0
    """Time, in the units of the smoothed axis' gain (e.g. seconds), for the smoothed
    response of a unit step function to reach 1 - 1/e approx-eq 63.2%."""

    axis: str | None = None
    """Name of the axis to smooth along. None (default) uses the first dim of each message."""

    accumulate: bool = True
    """When True (default), each processed sample updates the EWMA state. When False, the
    current estimate is applied to the data but the state is left untouched, e.g. during
    inference periods where the statistics should not adapt."""
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
@processor_state
class EWMAState:
    """Mutable state of :obj:`EWMATransformer`; both fields are set in its `_reset_state`."""

    # Placeholder smoothing factor for a 1 s time constant sampled at 1 kHz (dt = 1 ms).
    # The previous default passed dt=1000 s, which yields alpha ~= 1 (no smoothing at all).
    alpha: float = field(default_factory=lambda: _alpha_from_tau(1.0, 1.0 / 1000.0))
    # lfilter delay state (length 1 along the smoothed axis); None until the first message.
    zi: npt.NDArray | None = None
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class EWMATransformer(BaseStatefulTransformer[EWMASettings, AxisArray, AxisArray, EWMAState]):
    """
    Exponentially weighted moving average along one axis of an :obj:`AxisArray`.

    Implemented as the first-order IIR filter ``y[n] = alpha*x[n] + (1-alpha)*y[n-1]``
    with :func:`scipy.signal.lfilter`, carrying the filter state across messages.
    """

    def _hash_message(self, message: AxisArray) -> int:
        # State is rebuilt whenever the per-sample shape, sample period, or stream key changes.
        axis = self.settings.axis or message.dims[0]
        axis_idx = message.get_axis_idx(axis)
        per_sample_shape = tuple(size for i, size in enumerate(message.data.shape) if i != axis_idx)
        return hash((per_sample_shape, message.axes[axis].gain, message.key))

    def _reset_state(self, message: AxisArray) -> None:
        axis = self.settings.axis or message.dims[0]
        state = self._state
        state.alpha = _alpha_from_tau(self.settings.time_constant, message.axes[axis].gain)
        first = slice_along_axis(message.data, slice(None, 1, None), axis=message.get_axis_idx(axis))
        # Priming zi with (1 - alpha) * x[0] makes the filter start as if y[-1] == x[0].
        state.zi = (1 - state.alpha) * first

    def _process(self, message: AxisArray) -> AxisArray:
        if 0 in message.data.shape:
            return message

        axis = self.settings.axis or message.dims[0]
        alpha = self._state.alpha
        smoothed, final_zi = sps.lfilter(
            [alpha],
            [1.0, alpha - 1.0],
            message.data,
            axis=message.get_axis_idx(axis),
            zi=self._state.zi,
        )
        # Only persist the filter state when accumulating; otherwise apply-only.
        if self.settings.accumulate:
            self._state.zi = final_zi
        return replace(message, data=smoothed)
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
class EWMAUnit(BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, EWMATransformer]):
    """ezmsg Unit wrapping :obj:`EWMATransformer`."""

    SETTINGS = EWMASettings

    @ez.subscriber(BaseTransformerUnit.INPUT_SETTINGS)
    async def on_settings(self, msg: EWMASettings) -> None:
        """
        Apply new settings, rebuilding the processor only when needed.

        A change of `axis` is structural and recreates the processor (fresh state).
        Changes to `time_constant` or `accumulate` are handed to the existing
        processor so its accumulated state is kept.
        """
        axis_changed = msg.axis != self.SETTINGS.axis
        self.apply_settings(msg)

        if not axis_changed:
            self.processor.settings = msg
            return

        self.create_processor()
|
ezmsg/sigproc/ewmfilter.py
CHANGED
|
@@ -1,19 +1,20 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
-
|
|
2
|
+
import typing
|
|
3
3
|
|
|
4
4
|
import ezmsg.core as ez
|
|
5
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
6
|
-
|
|
7
5
|
import numpy as np
|
|
6
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
|
+
from ezmsg.util.messages.util import replace
|
|
8
8
|
|
|
9
9
|
from .window import Window, WindowSettings
|
|
10
10
|
|
|
11
|
-
from typing import AsyncGenerator, Optional
|
|
12
|
-
|
|
13
11
|
|
|
14
12
|
class EWMSettings(ez.Settings):
|
|
15
|
-
axis:
|
|
16
|
-
|
|
13
|
+
axis: str | None = None
|
|
14
|
+
"""Name of the axis to accumulate."""
|
|
15
|
+
|
|
16
|
+
zero_offset: bool = True
|
|
17
|
+
"""If true, we assume zero DC offset for input data."""
|
|
17
18
|
|
|
18
19
|
|
|
19
20
|
class EWMState(ez.State):
|
|
@@ -23,19 +24,23 @@ class EWMState(ez.State):
|
|
|
23
24
|
|
|
24
25
|
class EWM(ez.Unit):
|
|
25
26
|
"""
|
|
26
|
-
Exponentially Weighted Moving Average Standardization
|
|
27
|
+
Exponentially Weighted Moving Average Standardization.
|
|
28
|
+
This is deprecated. Please use :obj:`ezmsg.sigproc.scaler.AdaptiveStandardScaler` instead.
|
|
27
29
|
|
|
28
30
|
References https://stackoverflow.com/a/42926270
|
|
29
31
|
"""
|
|
30
32
|
|
|
31
|
-
SETTINGS
|
|
32
|
-
STATE
|
|
33
|
+
SETTINGS = EWMSettings
|
|
34
|
+
STATE = EWMState
|
|
33
35
|
|
|
34
36
|
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
35
37
|
INPUT_BUFFER = ez.InputStream(AxisArray)
|
|
36
38
|
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
37
39
|
|
|
38
|
-
def initialize(self) -> None:
|
|
40
|
+
async def initialize(self) -> None:
|
|
41
|
+
ez.logger.warning(
|
|
42
|
+
"EWM/EWMFilter is deprecated and will be removed in a future version. Use AdaptiveStandardScaler instead."
|
|
43
|
+
)
|
|
39
44
|
self.STATE.signal_queue = asyncio.Queue()
|
|
40
45
|
self.STATE.buffer_queue = asyncio.Queue()
|
|
41
46
|
|
|
@@ -48,7 +53,7 @@ class EWM(ez.Unit):
|
|
|
48
53
|
self.STATE.buffer_queue.put_nowait(message)
|
|
49
54
|
|
|
50
55
|
@ez.publisher(OUTPUT_SIGNAL)
|
|
51
|
-
async def sync_output(self) -> AsyncGenerator:
|
|
56
|
+
async def sync_output(self) -> typing.AsyncGenerator:
|
|
52
57
|
while True:
|
|
53
58
|
signal = await self.STATE.signal_queue.get()
|
|
54
59
|
buffer = await self.STATE.buffer_queue.get() # includes signal
|
|
@@ -73,9 +78,12 @@ class EWM(ez.Unit):
|
|
|
73
78
|
buffer_data = buffer.data
|
|
74
79
|
buffer_data = np.moveaxis(buffer_data, axis_idx, 0)
|
|
75
80
|
|
|
81
|
+
while scale_arr.ndim < buffer_data.ndim:
|
|
82
|
+
scale_arr = scale_arr[..., None]
|
|
83
|
+
|
|
76
84
|
def ewma(data: np.ndarray) -> np.ndarray:
|
|
77
|
-
mult = scale_arr
|
|
78
|
-
out = scale_arr[::-1
|
|
85
|
+
mult = scale_arr * data * pw0
|
|
86
|
+
out = scale_arr[::-1] * mult.cumsum(axis=0)
|
|
79
87
|
|
|
80
88
|
if not self.SETTINGS.zero_offset:
|
|
81
89
|
out = (data[0, :, np.newaxis] * pows[1:]).T + out
|
|
@@ -93,13 +101,26 @@ class EWM(ez.Unit):
|
|
|
93
101
|
|
|
94
102
|
|
|
95
103
|
class EWMFilterSettings(ez.Settings):
|
|
96
|
-
history_dur: float
|
|
97
|
-
|
|
98
|
-
|
|
104
|
+
history_dur: float
|
|
105
|
+
"""Previous data to accumulate for standardization."""
|
|
106
|
+
|
|
107
|
+
axis: str | None = None
|
|
108
|
+
"""Name of the axis to accumulate."""
|
|
109
|
+
|
|
110
|
+
zero_offset: bool = True
|
|
111
|
+
"""If true, we assume zero DC offset for input data."""
|
|
99
112
|
|
|
100
113
|
|
|
101
114
|
class EWMFilter(ez.Collection):
|
|
102
|
-
|
|
115
|
+
"""
|
|
116
|
+
A :obj:`Collection` that splits the input into a branch that
|
|
117
|
+
leads to :obj:`Window` which then feeds into :obj:`EWM` 's INPUT_BUFFER
|
|
118
|
+
and another branch that feeds directly into :obj:`EWM` 's INPUT_SIGNAL.
|
|
119
|
+
|
|
120
|
+
This is deprecated. Please use :obj:`ezmsg.sigproc.scaler.AdaptiveStandardScaler` instead.
|
|
121
|
+
"""
|
|
122
|
+
|
|
123
|
+
SETTINGS = EWMFilterSettings
|
|
103
124
|
|
|
104
125
|
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
105
126
|
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
@@ -108,7 +129,12 @@ class EWMFilter(ez.Collection):
|
|
|
108
129
|
EWM = EWM()
|
|
109
130
|
|
|
110
131
|
def configure(self) -> None:
|
|
111
|
-
self.EWM.apply_settings(
|
|
132
|
+
self.EWM.apply_settings(
|
|
133
|
+
EWMSettings(
|
|
134
|
+
axis=self.SETTINGS.axis,
|
|
135
|
+
zero_offset=self.SETTINGS.zero_offset,
|
|
136
|
+
)
|
|
137
|
+
)
|
|
112
138
|
|
|
113
139
|
self.WINDOW.apply_settings(
|
|
114
140
|
WindowSettings(
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
import ezmsg.core as ez
|
|
2
|
+
import numpy as np
|
|
3
|
+
from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
|
|
4
|
+
from ezmsg.util.messages.axisarray import AxisArray, replace
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ExtractAxisSettings(ez.Settings):
    """Settings for :obj:`ExtractAxisData`."""

    axis: str = "freq"
    """Name of the axis whose coordinate values become the output data."""

    reference: str = "time"
    """Dim that sizes and labels the output when `axis` is not a CoordinateAxis (LinearAxis)."""
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ExtractAxisData(BaseTransformer[ExtractAxisSettings, AxisArray, AxisArray]):
    """Replace a message's data with the coordinate values of one of its axes."""

    def _process(self, message: AxisArray) -> AxisArray:
        target = message.axes[self.settings.axis]
        ref_name = self.settings.reference

        if not hasattr(target, "data"):
            # LinearAxis: values are generated from the axis' parameters and are 1-D,
            # so the output is laid out along the reference dim only.
            # NOTE(review): the value count comes from the *reference* dim's length,
            # not from the target axis' own dim — confirm this is intended.
            n_ref = message.data.shape[message.get_axis_idx(ref_name)]
            return replace(
                message,
                data=target.value(np.arange(n_ref)),
                dims=[ref_name],
                axes={ref_name: message.axes[ref_name]},
            )

        # CoordinateAxis: its data already carries its own dims; keep only the
        # message axes that describe those dims.
        # Note: So far we don't have any transformers where the coordinate axis has its own axes,
        # but if that happens in the future, we'd need to consider how to handle that.
        kept_axes = {name: ax for name, ax in message.axes.items() if name in target.dims}
        return replace(message, data=target.data, dims=target.dims, axes=kept_axes)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class ExtractAxisDataUnit(BaseTransformerUnit[ExtractAxisSettings, AxisArray, AxisArray, ExtractAxisData]):
    """ezmsg Unit wrapping :obj:`ExtractAxisData`."""

    SETTINGS = ExtractAxisSettings
|
ezmsg/sigproc/fbcca.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import typing
|
|
3
|
+
from dataclasses import field
|
|
4
|
+
|
|
5
|
+
import ezmsg.core as ez
|
|
6
|
+
import numpy as np
|
|
7
|
+
from ezmsg.baseproc import (
|
|
8
|
+
BaseProcessor,
|
|
9
|
+
BaseStatefulProcessor,
|
|
10
|
+
BaseTransformer,
|
|
11
|
+
BaseTransformerUnit,
|
|
12
|
+
CompositeProcessor,
|
|
13
|
+
)
|
|
14
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
15
|
+
from ezmsg.util.messages.util import replace
|
|
16
|
+
|
|
17
|
+
from .filterbankdesign import (
|
|
18
|
+
FilterbankDesignSettings,
|
|
19
|
+
FilterbankDesignTransformer,
|
|
20
|
+
)
|
|
21
|
+
from .kaiser import KaiserFilterSettings
|
|
22
|
+
from .sampler import SampleTriggerMessage
|
|
23
|
+
from .window import WindowSettings, WindowTransformer
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class FBCCASettings(ez.Settings):
    """
    Settings for :obj:`FBCCATransformer`
    """

    time_dim: str
    """
    The time dim in the data array.
    """

    ch_dim: str
    """
    The channels dim in the data array.
    """

    filterbank_dim: str | None = None
    """
    The filter bank subband dim in the data array. If unspecified, method falls back to CCA
    None (default): the input has no subbands; just use CCA
    """

    harmonics: int = 5
    """
    The number of harmonics, counting the fundamental, used to build the sin/cos 'design' matrix.
    5 (default): Evaluate the fundamental and harmonics 2 through 5 of the base frequency.
    Many periodic signals are not pure sinusoids, and inclusion of higher harmonics can help evaluate the
    presence of signals with higher frequency harmonic content
    """

    freqs: typing.List[float] = field(default_factory=list)
    """
    Frequencies (in hz) to evaluate the presence of within the input signal.
    [] (default): an empty list; frequencies will be found within the input SampleMessages.
    AxisArrays have no good place to put this metadata, so specify frequencies here if only AxisArrays
    will be passed as input to the generator. If this list is empty and the input has a `trigger` attr of
    type :obj:`SampleTriggerMessage`, the processor looks for the `freqs` attribute within that trigger for
    a list of frequencies to evaluate. A non-empty list here takes precedence over the trigger.
    This field is present in the :obj:`SSVEPSampleTriggerMessage` defined in ezmsg.tasks.ssvep from
    the ezmsg-tasks package.
    NOTE: Avoid frequencies that have line-noise (60 Hz/50 Hz) as a harmonic.
    """

    softmax_beta: float = 1.0
    """
    Beta parameter for softmax on output --> "probabilities".
    1.0 (default): Use the shifted softmax transformation across target frequencies to output 0-1 probabilities.
    If 0.0, the softmax is skipped and the raw score is output: the weighted sum, over sub-bands, of the
    squared largest canonical correlation for each target frequency.
    """

    target_freq_dim: str = "target_freq"
    """
    Name for dim to put target frequency outputs on.
    'target_freq' (default)
    """

    max_int_time: float = 0.0
    """
    Maximum integration time (in seconds, measured from the first sample) to use for calculation.
    0 (default): Use all time provided for the calculation.
    Useful for artificially limiting the amount of data used for the CCA method to evaluate
    the necessary integration time for good decoding performance
    """
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class FBCCATransformer(BaseTransformer[FBCCASettings, AxisArray, AxisArray]):
    """
    A canonical-correlation (CCA) signal decoder for detection of periodic activity in multi-channel timeseries
    recordings. It is particularly useful for detecting the presence of steady-state evoked responses in multi-channel
    EEG data. Please see Lin et. al. 2007 for a description on the use of CCA to detect the presence of SSVEP in EEG
    data.
    This implementation also includes the "Filterbank" extension of the CCA decoding approach which utilizes a
    filterbank to decompose input multi-channel EEG data into several frequency sub-bands; each of which is analyzed
    with CCA, then combined using a weighted sum; allowing CCA to more readily identify harmonic content in EEG data.
    Read more about this approach in Chen et. al. 2015.

    ## Further reading:
    * [Lin et. al. 2007](https://ieeexplore.ieee.org/document/4015614)
    * [Nakanishi et. al. 2015](https://doi.org/10.1371%2Fjournal.pone.0140703)
    * [Chen et. al. 2015](http://dx.doi.org/10.1088/1741-2560/12/4/046008)
    """

    def _process(self, message: AxisArray) -> AxisArray:
        """
        Score each test frequency by (filterbank) canonical correlation.

        Input: AxisArray with at least a time_dim, and ch_dim
        Output: AxisArray with time_dim, ch_dim, (and filterbank_dim if specified)
            collapsed, with a new 'target_freq' dim of length 'freqs'

        Raises:
            ValueError: if no test frequencies are given in settings or in the trigger.
        """
        # Settings take precedence; otherwise fall back to the trigger's `freqs`.
        test_freqs: list[float] = self.settings.freqs
        trigger = message.attrs.get("trigger", None)
        if isinstance(trigger, SampleTriggerMessage) and len(test_freqs) == 0:
            test_freqs = getattr(trigger, "freqs", [])

        if len(test_freqs) == 0:
            raise ValueError("no frequencies to test")

        time_dim_idx = message.get_axis_idx(self.settings.time_dim)
        ch_dim_idx = message.get_axis_idx(self.settings.ch_dim)

        filterbank_dim_idx = None
        if self.settings.filterbank_dim is not None:
            filterbank_dim_idx = message.get_axis_idx(self.settings.filterbank_dim)

        # Move (filterbank_dim), time, ch to end of array
        rm_dims = [self.settings.time_dim, self.settings.ch_dim]
        if self.settings.filterbank_dim is not None:
            rm_dims = [self.settings.filterbank_dim] + rm_dims
        new_order = [i for i, dim in enumerate(message.dims) if dim not in rm_dims]
        if filterbank_dim_idx is not None:
            new_order.append(filterbank_dim_idx)
        new_order.extend([time_dim_idx, ch_dim_idx])
        out_dims = [message.dims[i] for i in new_order if message.dims[i] not in rm_dims]
        data_arr = message.data.transpose(new_order)

        # Add a singleton dim for filterbank dim if we don't have one
        if filterbank_dim_idx is None:
            data_arr = data_arr[..., None, :, :]

        # data_arr is now (..., filterbank, time, ch)
        # Get output shape for remaining dims and reshape data_arr for iterative processing
        out_shape = list(data_arr.shape[:-3])
        data_arr = data_arr.reshape([math.prod(out_shape), *data_arr.shape[-3:]])

        # Create output dims and axes with added target_freq_dim
        out_shape.append(len(test_freqs))
        out_dims.append(self.settings.target_freq_dim)
        out_axes = {
            axis_name: axis
            for axis_name, axis in message.axes.items()
            if axis_name not in rm_dims
            and not (isinstance(axis, AxisArray.CoordinateAxis) and any(d in rm_dims for d in axis.dims))
        }
        out_axes[self.settings.target_freq_dim] = AxisArray.CoordinateAxis(
            np.array(test_freqs), [self.settings.target_freq_dim]
        )

        if message.data.size == 0:
            # Nothing to correlate. np.zeros yields the requested shape even when only the
            # collapsed dims (time/ch/filterbank) are empty, where reshaping the empty
            # input to `out_shape` would raise.
            out_data = np.zeros(out_shape, dtype=message.data.dtype)
            return replace(message, data=out_data, dims=out_dims, axes=out_axes)

        # Time vector relative to the first sample. Build a new array instead of
        # subtracting in place: for a CoordinateAxis, `.values` may be the message's
        # own axis data, which must not be mutated.
        t_abs = message.ax(self.settings.time_dim).values
        t = t_abs - t_abs[0]
        max_samp = len(t)
        if self.settings.max_int_time > 0:
            # Search the relative times; absolute timestamps (e.g. a LinearAxis with a
            # nonzero offset) would make this nearest-sample search meaningless.
            max_samp = int(np.abs(t - self.settings.max_int_time).argmin())
            t = t[:max_samp]

        calc_output = np.zeros((*data_arr.shape[:-2], len(test_freqs)))

        for test_freq_idx, test_freq in enumerate(test_freqs):
            # Design matrix: sin/cos at k * test_freq for k = 1..harmonics (k = 1 is the fundamental)
            Y = np.column_stack(
                [
                    fn(2.0 * np.pi * k * test_freq * t)
                    for k in range(1, self.settings.harmonics + 1)
                    for fn in (np.sin, np.cos)
                ]
            )

            for test_idx, arr in enumerate(data_arr):  # arr is (filterbank x time x ch)
                for band_idx, band in enumerate(arr):  # band is (time x ch)
                    calc_output[test_idx, band_idx, test_freq_idx] = cca_rho_max(band[:max_samp, ...], Y)

        # Combine per-subband canonical correlations using a weighted sum
        # https://iopscience.iop.org/article/10.1088/1741-2560/12/4/046008
        freq_weights = (np.arange(1, calc_output.shape[1] + 1) ** -1.25) + 0.25
        calc_output = ((calc_output**2) * freq_weights[None, :, None]).sum(axis=1)

        if self.settings.softmax_beta != 0:
            calc_output = calc_softmax(calc_output, axis=-1, beta=self.settings.softmax_beta)

        return replace(
            message,
            data=calc_output.reshape(out_shape),
            dims=out_dims,
            axes=out_axes,
        )
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
class FBCCA(BaseTransformerUnit[FBCCASettings, AxisArray, AxisArray, FBCCATransformer]):
    """ezmsg Unit wrapping :obj:`FBCCATransformer`."""

    SETTINGS = FBCCASettings
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
class StreamingFBCCASettings(FBCCASettings):
    """
    Settings for rolling/streaming FBCCA on incoming EEG (:obj:`StreamingFBCCATransformer`).

    When `filterbank_dim` is set, the input is first decomposed into sub-bands by a
    FilterbankDesign transformer. Data is then cut into short windows by Window and each
    window is scored by an FBCCA transformer.
    """

    window_dur: float = 4.0  # sec
    """Length of each analysis window, in seconds."""

    window_shift: float = 0.5  # sec
    """Step between consecutive windows, in seconds."""

    window_dim: str = "fbcca_window"
    """Name of the new dim that indexes windows."""

    filter_bw: float = 7.0  # Hz
    """Spacing between the lower edges of consecutive sub-bands, in Hz."""

    filter_low: float = 7.0  # Hz
    """Lower edge of the first sub-band, in Hz (before subtracting `trans_bw`)."""

    trans_bw: float = 2.0  # Hz
    """Kaiser transition width, in Hz; also subtracted from each sub-band's lower cutoff."""

    ripple_db: float = 20.0  # dB
    """Kaiser filter ripple specification, in dB."""

    subbands: int = 12
    """Number of filterbank sub-bands (used only when `filterbank_dim` is set)."""
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
class StreamingFBCCATransformer(CompositeProcessor[StreamingFBCCASettings, AxisArray, AxisArray]):
    """Chain of (optional) filterbank -> window -> FBCCA processors."""

    @staticmethod
    def _initialize_processors(
        settings: StreamingFBCCASettings,
    ) -> dict[str, BaseProcessor | BaseStatefulProcessor]:
        stages: dict[str, BaseProcessor | BaseStatefulProcessor] = {}

        if settings.filterbank_dim is not None:
            # Sub-band k passes (edge_k - trans_bw) .. top edge; all bands share the top.
            edges = settings.filter_low + settings.filter_bw * np.arange(settings.subbands + 1)
            top = edges[-1]
            subband_specs = [
                KaiserFilterSettings(
                    axis=settings.time_dim,
                    cutoff=(lower - settings.trans_bw, top),
                    ripple=settings.ripple_db,
                    width=settings.trans_bw,
                    pass_zero=False,
                )
                for lower in edges[:-1]
            ]
            stages["filterbank"] = FilterbankDesignTransformer(
                FilterbankDesignSettings(filters=subband_specs, new_axis=settings.filterbank_dim)
            )

        window_settings = WindowSettings(
            axis=settings.time_dim,
            newaxis=settings.window_dim,
            window_dur=settings.window_dur,
            window_shift=settings.window_shift,
            zero_pad_until="shift",
        )
        stages["window"] = WindowTransformer(window_settings)
        stages["fbcca"] = FBCCATransformer(settings)

        return stages
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
class StreamingFBCCA(BaseTransformerUnit[StreamingFBCCASettings, AxisArray, AxisArray, StreamingFBCCATransformer]):
    """ezmsg Unit wrapping :obj:`StreamingFBCCATransformer`."""

    SETTINGS = StreamingFBCCASettings
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def _orthonormal_basis(A: np.ndarray) -> np.ndarray:
    """Orthonormal basis for the column space of `A`, dropping numerically null directions."""
    U, s, _ = np.linalg.svd(A, full_matrices=False)
    if s.size == 0:
        return U[:, :0]
    # Same rank tolerance as numpy.linalg.matrix_rank / scipy.linalg.orth.
    tol = s[0] * max(A.shape) * np.finfo(s.dtype).eps
    return U[:, s > tol]


def cca_rho_max(X: np.ndarray, Y: np.ndarray) -> float:
    """
    Largest canonical correlation between the columns of X and Y.

    X: (n_time, n_ch)
    Y: (n_time, n_ref) # design matrix for one frequency
    returns: largest canonical correlation in [0,1]; 0.0 if either side has no
        non-constant column.
    """
    # Center columns
    Xc = X - X.mean(axis=0, keepdims=True)
    Yc = Y - Y.mean(axis=0, keepdims=True)

    # Drop any zero-variance columns to avoid rank issues
    Xc = Xc[:, Xc.std(axis=0) > 1e-12]
    Yc = Yc[:, Yc.std(axis=0) > 1e-12]
    if Xc.size == 0 or Yc.size == 0:
        return 0.0

    # Orthonormal bases. SVD with a rank cutoff is used instead of QR because QR of
    # collinear columns returns arbitrary extra directions that can inflate the correlation.
    Qx = _orthonormal_basis(Xc)  # (n_time, r_x)
    Qy = _orthonormal_basis(Yc)  # (n_time, r_y)
    if Qx.shape[1] == 0 or Qy.shape[1] == 0:
        return 0.0

    # Canonical correlations are the singular values of Qx^T Qy
    with np.errstate(divide="ignore", over="ignore", invalid="ignore"):
        s = np.linalg.svd(Qx.T @ Qy, compute_uv=False)
    # Round-off can push the value marginally above 1.
    return float(np.clip(s[0], 0.0, 1.0)) if s.size else 0.0
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def calc_softmax(cv: np.ndarray, axis: int, beta: float = 1.0):
    """
    Softmax of ``beta * cv`` along ``axis``.

    The *scaled* scores are shifted by their maximum before exponentiating, so the
    largest exponent is 0 for any sign of ``beta`` and ``exp`` cannot overflow
    (https://doi.org/10.1093/imanum/draa038). For beta > 0 this matches shifting
    ``cv`` first; for beta < 0 the old order overflowed to inf/inf = nan.
    """
    z = beta * np.asarray(cv)
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)
|