ezmsg-sigproc 1.2.2__py3-none-any.whl → 2.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__init__.py +1 -1
- ezmsg/sigproc/__version__.py +34 -1
- ezmsg/sigproc/activation.py +78 -0
- ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
- ezmsg/sigproc/affinetransform.py +235 -0
- ezmsg/sigproc/aggregate.py +276 -0
- ezmsg/sigproc/bandpower.py +80 -0
- ezmsg/sigproc/base.py +149 -0
- ezmsg/sigproc/butterworthfilter.py +129 -39
- ezmsg/sigproc/butterworthzerophase.py +305 -0
- ezmsg/sigproc/cheby.py +125 -0
- ezmsg/sigproc/combfilter.py +160 -0
- ezmsg/sigproc/coordinatespaces.py +159 -0
- ezmsg/sigproc/decimate.py +46 -18
- ezmsg/sigproc/denormalize.py +78 -0
- ezmsg/sigproc/detrend.py +28 -0
- ezmsg/sigproc/diff.py +82 -0
- ezmsg/sigproc/downsample.py +97 -49
- ezmsg/sigproc/ewma.py +217 -0
- ezmsg/sigproc/ewmfilter.py +45 -19
- ezmsg/sigproc/extract_axis.py +39 -0
- ezmsg/sigproc/fbcca.py +307 -0
- ezmsg/sigproc/filter.py +282 -117
- ezmsg/sigproc/filterbank.py +292 -0
- ezmsg/sigproc/filterbankdesign.py +129 -0
- ezmsg/sigproc/fir_hilbert.py +336 -0
- ezmsg/sigproc/fir_pmc.py +209 -0
- ezmsg/sigproc/firfilter.py +117 -0
- ezmsg/sigproc/gaussiansmoothing.py +89 -0
- ezmsg/sigproc/kaiser.py +106 -0
- ezmsg/sigproc/linear.py +120 -0
- ezmsg/sigproc/math/__init__.py +0 -0
- ezmsg/sigproc/math/abs.py +35 -0
- ezmsg/sigproc/math/add.py +120 -0
- ezmsg/sigproc/math/clip.py +48 -0
- ezmsg/sigproc/math/difference.py +143 -0
- ezmsg/sigproc/math/invert.py +28 -0
- ezmsg/sigproc/math/log.py +57 -0
- ezmsg/sigproc/math/scale.py +39 -0
- ezmsg/sigproc/messages.py +3 -6
- ezmsg/sigproc/quantize.py +68 -0
- ezmsg/sigproc/resample.py +278 -0
- ezmsg/sigproc/rollingscaler.py +232 -0
- ezmsg/sigproc/sampler.py +232 -241
- ezmsg/sigproc/scaler.py +165 -0
- ezmsg/sigproc/signalinjector.py +70 -0
- ezmsg/sigproc/slicer.py +138 -0
- ezmsg/sigproc/spectral.py +6 -132
- ezmsg/sigproc/spectrogram.py +90 -0
- ezmsg/sigproc/spectrum.py +277 -0
- ezmsg/sigproc/transpose.py +134 -0
- ezmsg/sigproc/util/__init__.py +0 -0
- ezmsg/sigproc/util/asio.py +25 -0
- ezmsg/sigproc/util/axisarray_buffer.py +365 -0
- ezmsg/sigproc/util/buffer.py +449 -0
- ezmsg/sigproc/util/message.py +17 -0
- ezmsg/sigproc/util/profile.py +23 -0
- ezmsg/sigproc/util/sparse.py +115 -0
- ezmsg/sigproc/util/typeresolution.py +17 -0
- ezmsg/sigproc/wavelets.py +187 -0
- ezmsg/sigproc/window.py +301 -117
- ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
- {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -2
- ezmsg/sigproc/synth.py +0 -411
- ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
- ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
- ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
- /ezmsg_sigproc-1.2.2.dist-info/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/__init__.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
from .__version__ import __version__
|
|
1
|
+
from .__version__ import __version__ as __version__
|
ezmsg/sigproc/__version__.py
CHANGED
|
@@ -1 +1,34 @@
|
|
|
1
|
-
|
|
1
|
+
# file generated by setuptools-scm
# don't change, don't track in version control

# Public names re-exported by ``from ezmsg.sigproc.__version__ import *``.
__all__ = [
    "__version__",
    "__version_tuple__",
    "version",
    "version_tuple",
    "__commit_id__",
    "commit_id",
]

# Evaluated only by static type checkers; at runtime the aliases are plain ``object``.
TYPE_CHECKING = False
if not TYPE_CHECKING:
    VERSION_TUPLE = object
    COMMIT_ID = object
else:
    from typing import Tuple, Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
    COMMIT_ID = Union[str, None]

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID

__version__ = version = '2.10.0'
__version_tuple__ = version_tuple = (2, 10, 0)

__commit_id__ = commit_id = None
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
import ezmsg.core as ez
|
|
2
|
+
import scipy.special
|
|
3
|
+
from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
|
|
4
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
5
|
+
from ezmsg.util.messages.util import replace
|
|
6
|
+
|
|
7
|
+
from .spectral import OptionsEnum
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ActivationFunction(OptionsEnum):
    """Element-wise activation (transformation) functions available to :obj:`ActivationTransformer`.

    Member values are the lowercase strings accepted in :obj:`ActivationSettings.function`.
    """

    NONE = "none"
    """Identity: data pass through unchanged."""

    SIGMOID = "sigmoid"
    """:obj:`scipy.special.expit` (same function as EXPIT)."""

    EXPIT = "expit"
    """:obj:`scipy.special.expit`"""

    LOGIT = "logit"
    """:obj:`scipy.special.logit`"""

    LOGEXPIT = "log_expit"
    """:obj:`scipy.special.log_expit`"""
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Dispatch table: ActivationFunction member -> element-wise implementation.
# Entries are listed in the same order as the ActivationFunction members; code that maps
# string options to keys by position depends on that ordering, so keep the two aligned.
ACTIVATIONS = {
    ActivationFunction.NONE: (lambda x: x),  # identity
    ActivationFunction.SIGMOID: scipy.special.expit,
    ActivationFunction.EXPIT: scipy.special.expit,  # alias of SIGMOID
    ActivationFunction.LOGIT: scipy.special.logit,
    ActivationFunction.LOGEXPIT: scipy.special.log_expit,
}
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class ActivationSettings(ez.Settings):
    """Settings for :obj:`ActivationTransformer` / :obj:`Activation`."""

    function: str | ActivationFunction = ActivationFunction.NONE
    """Which activation to apply: an :obj:`ActivationFunction` member or its string value.

    Accepted values: SIGMOID, EXPIT, LOGIT, LOGEXPIT, "sigmoid", "expit", "logit", "log_expit".
    SIGMOID and EXPIT are the same function; see :obj:`scipy.special.expit` for details."""
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class ActivationTransformer(BaseTransformer[ActivationSettings, AxisArray, AxisArray]):
    """Apply the configured element-wise activation function to ``message.data``."""

    def _process(self, message: AxisArray) -> AxisArray:
        func = self._resolve_function(self.settings.function)
        return replace(message, data=func(message.data))

    @staticmethod
    def _resolve_function(function):
        """Return the implementation for an ActivationFunction member or its string value.

        Strings are matched case-insensitively against the enum *values* (e.g. "log_expit").
        This avoids relying on ACTIVATIONS' insertion order lining up with the enum order.

        Raises:
            ValueError: if a string does not name a known activation function.
        """
        if isinstance(function, ActivationFunction):
            return ACTIVATIONS[function]
        # str type handling
        by_value = {member.value: impl for member, impl in ACTIVATIONS.items()}
        name = function.lower()
        if name not in by_value:
            raise ValueError(f"Unrecognized activation function {name}. Must be one of {list(by_value)}")
        return by_value[name]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class Activation(BaseTransformerUnit[ActivationSettings, AxisArray, AxisArray, ActivationTransformer]):
    """Unit wrapper around :obj:`ActivationTransformer`, configured by :obj:`ActivationSettings`."""

    SETTINGS = ActivationSettings
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def activation(
    function: str | ActivationFunction,
) -> ActivationTransformer:
    """
    Build a transformer that maps data through a simple element-wise activation function.

    Args:
        function: An :obj:`ActivationFunction` member or its string value.
            Accepted values: SIGMOID, EXPIT, LOGIT, LOGEXPIT, "sigmoid", "expit", "logit", "log_expit".
            SIGMOID and EXPIT are the same function; see :obj:`scipy.special.expit` for details.

    Returns: :obj:`ActivationTransformer`

    """
    settings = ActivationSettings(function=function)
    return ActivationTransformer(settings)
|
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
import ezmsg.core as ez
|
|
2
|
+
import numpy as np
|
|
3
|
+
import numpy.typing as npt
|
|
4
|
+
import scipy.signal
|
|
5
|
+
from ezmsg.baseproc import BaseStatefulTransformer, processor_state
|
|
6
|
+
from ezmsg.util.messages.axisarray import AxisArray, CoordinateAxis
|
|
7
|
+
from ezmsg.util.messages.util import replace
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class AdaptiveLatticeNotchFilterSettings(ez.Settings):
    """Configuration for :obj:`AdaptiveLatticeNotchFilterTransformer`."""

    gamma: float = 0.995
    """Pole-zero contraction factor of the lattice notch."""

    mu: float = 0.99
    """Smoothing factor applied when blending the new reflection-coefficient estimate into k1."""

    eta: float = 0.99
    """Forgetting factor of the running p/q accumulators."""

    axis: str = "time"
    """Name of the axis along which the filter runs."""

    init_notch_freq: float | None = None
    """Starting notch frequency in Hz; should be below Nyquist. None selects a default derived from the sampling rate."""

    chunkwise: bool = False
    """If True, update the tracked frequency once per chunk instead of once per sample (faster)."""
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@processor_state
class AdaptiveLatticeNotchFilterState:
    """Mutable state carried between chunks by the Adaptive Lattice Notch Filter."""

    s_history: npt.NDArray | None = None
    """The two most recent lattice `s` values (oldest first), one pair per feature."""

    p: npt.NDArray | None = None
    """Running (forgetting-factor weighted) accumulator used to update the reflection coefficient."""

    q: npt.NDArray | None = None
    """Running (forgetting-factor weighted) accumulator used to update the reflection coefficient."""

    k1: npt.NDArray | None = None
    """Current smoothed reflection coefficient, one per feature."""

    freq_template: CoordinateAxis | None = None
    """Template for the "freq" coordinate axis attached to each output message."""

    zi: npt.NDArray | None = None
    """Filter initial conditions, updated after every chunk."""
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class AdaptiveLatticeNotchFilterTransformer(
    BaseStatefulTransformer[
        AdaptiveLatticeNotchFilterSettings,
        AxisArray,
        AxisArray,
        AdaptiveLatticeNotchFilterState,
    ]
):
    """
    Adaptive Lattice Notch Filter implementation as a stateful transformer.

    https://biomedical-engineering-online.biomedcentral.com/articles/10.1186/1475-925X-13-170

    The filter automatically tracks and removes frequency components from the input signal.
    The output message's ``data`` holds the filtered signal, with the same dims as the input.
    It also gets a ``"freq"`` coordinate axis with the estimated notch frequency (in Hz) per
    sample; that axis' leading dimension is ``settings.axis``.
    """

    def _hash_message(self, message: AxisArray) -> int:
        # State is valid for one stream key, sampling period and per-sample shape.
        ax_idx = message.get_axis_idx(self.settings.axis)
        sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
        return hash((message.key, message.axes[self.settings.axis].gain, sample_shape))

    def _reset_state(self, message: AxisArray) -> None:
        """Reset filter state to initial values."""
        ax_idx = message.get_axis_idx(self.settings.axis)
        sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]

        fs = 1 / message.axes[self.settings.axis].gain
        init_f = (
            self.settings.init_notch_freq if self.settings.init_notch_freq is not None else 0.07178314656435313 * fs
        )
        init_omega = init_f * (2 * np.pi) / fs
        init_k1 = -np.cos(init_omega)

        self._state = AdaptiveLatticeNotchFilterState()
        self._state.s_history = np.zeros((2,) + sample_shape, dtype=float)
        self._state.p = np.zeros(sample_shape, dtype=float)
        self._state.q = np.zeros(sample_shape, dtype=float)
        self._state.k1 = init_k1 + np.zeros(sample_shape, dtype=float)
        self._state.freq_template = CoordinateAxis(
            data=np.zeros((0,) + sample_shape, dtype=float),
            dims=[self.settings.axis] + message.dims[:ax_idx] + message.dims[ax_idx + 1 :],
            unit="Hz",
        )

        # One column of IIR filter state per feature. int(...) is required: np.prod(()) is the
        # float 1.0 for 1-D input, which np.zeros rejects as a shape entry.
        # With s_history initialized to zeros, the initial conditions are also zero.
        self._state.zi = np.zeros((2, int(np.prod(sample_shape, dtype=int))), dtype=float)

    def _process(self, message: AxisArray) -> AxisArray:
        ax_idx = message.get_axis_idx(self.settings.axis)

        # Work with the filter axis leading; the output data is moved back before returning.
        x_data = message.data if ax_idx == 0 else np.moveaxis(message.data, ax_idx, 0)
        fs = 1 / message.axes[self.settings.axis].gain
        # Floating output even for integer input (avoids truncating the filtered signal).
        out_dtype = np.result_type(x_data.dtype, np.float32)

        if self.settings.chunkwise:
            y_out, freq_out = self._filter_chunkwise(x_data, fs, out_dtype)
        else:
            y_out, freq_out = self._filter_samplewise(x_data, fs, out_dtype)

        if ax_idx != 0:
            y_out = np.moveaxis(y_out, 0, ax_idx)

        return replace(
            message,
            data=y_out,
            axes={
                **message.axes,
                # freq_out keeps the filter axis leading, matching freq_template's dims.
                "freq": replace(self._state.freq_template, data=freq_out),
            },
        )

    def _update_reflection(self, s_n, s_n_1, s_n_2, eta: float, mu: float) -> None:
        """Update p, q and the smoothed reflection coefficient k1 from three consecutive `s` values."""
        self._state.p = eta * self._state.p + (1 - eta) * (s_n_1 * (s_n + s_n_2))
        self._state.q = eta * self._state.q + (1 - eta) * (2 * (s_n_1 * s_n_1))
        new_k1 = -self._state.p / (self._state.q + 1e-8)  # Avoid division by zero
        new_k1 = np.clip(new_k1, -1, 1)  # Keeps arccos defined and the filter stable
        self._state.k1 = mu * self._state.k1 + (1 - mu) * new_k1  # Smoothed

    def _filter_samplewise(self, x_data: npt.NDArray, fs: float, out_dtype) -> tuple[npt.NDArray, npt.NDArray]:
        """Filter sample-by-sample (axis 0 of `x_data`), adapting k1 after every sample."""
        gamma = self.settings.gamma
        gamma_plus_1 = 1 + gamma
        omega_scale = fs / (2 * np.pi)

        y_out = np.zeros(x_data.shape, dtype=out_dtype)
        freq_out = np.zeros(x_data.shape, dtype=out_dtype)
        s_history = self._state.s_history  # updated in place below
        for sample_ix, x_n in enumerate(x_data):
            s_n_1 = s_history[-1].copy()
            s_n_2 = s_history[-2].copy()

            # Lattice recursion: s_n = x_n - k1*(1+gamma)*s_n_1 - gamma*s_n_2
            s_n = x_n - self._state.k1 * gamma_plus_1 * s_n_1 - gamma * s_n_2
            # Notch output: y_n = s_n + 2*k1*s_n_1 + s_n_2
            y_out[sample_ix] = s_n + 2 * self._state.k1 * s_n_1 + s_n_2

            self._update_reflection(s_n, s_n_1, s_n_2, self.settings.eta, self.settings.mu)

            # Normalized angular frequency (equation 13 from the paper), expressed in Hz.
            freq_out[sample_ix] = np.arccos(-self._state.k1) * omega_scale

            s_history[-2] = s_n_1
            s_history[-1] = s_n

        return y_out, freq_out

    def _filter_chunkwise(self, x_data: npt.NDArray, fs: float, out_dtype) -> tuple[npt.NDArray, npt.NDArray]:
        """Filter a whole chunk with fixed k1, then adapt k1 once from the chunk's final `s` values."""
        gamma = self.settings.gamma
        gamma_plus_1 = 1 + gamma
        n_samples = x_data.shape[0]
        sample_shape = x_data.shape[1:]
        n_features = int(np.prod(sample_shape, dtype=int))

        x_flat = x_data.reshape((n_samples, n_features))
        s_hist_flat = self._state.s_history.reshape((2, n_features))  # rows: [s_n_2, s_n_1]
        k_flat = self._state.k1.reshape(n_features)
        s_flat = np.zeros((n_samples, n_features), dtype=float)

        if n_samples > 0:
            for ix in range(n_features):
                # With constant k1 the lattice recursion is an IIR filter: b=[1], a=[1, k1*(1+gamma), gamma].
                a_s = [1.0, k_flat[ix] * gamma_plus_1, gamma]
                # Build initial conditions from the true previous `s` values for the *current*
                # coefficients (k1 may have changed since the last chunk).
                zi = scipy.signal.lfiltic([1.0], a_s, y=[s_hist_flat[1, ix], s_hist_flat[0, ix]])
                s_flat[:, ix], self._state.zi[:, ix] = scipy.signal.lfilter([1.0], a_s, x_flat[:, ix], zi=zi)

        # Prepend the previous two `s` values so the output FIR and the parameter update
        # see a continuous sequence across chunk boundaries (and work for 1-sample chunks).
        s_ext = np.concatenate((s_hist_flat, s_flat), axis=0)
        # Output FIR: y_n = s_n + 2*k1*s_n_1 + s_n_2
        y_flat = s_ext[2:] + 2 * k_flat * s_ext[1:-1] + s_ext[:-2]
        y_out = y_flat.reshape(x_data.shape).astype(out_dtype, copy=False)

        if n_samples > 0:
            self._update_reflection(
                s_ext[-1].reshape(sample_shape),
                s_ext[-2].reshape(sample_shape),
                s_ext[-3].reshape(sample_shape),
                self.settings.eta,
                self.settings.mu,
            )
            self._state.s_history = s_ext[-2:].reshape((2,) + sample_shape)

        # One frequency estimate per feature for the whole chunk, broadcast over its samples.
        freq = np.arccos(-self._state.k1) * (fs / (2 * np.pi))
        freq_out = np.broadcast_to(freq, x_data.shape).astype(out_dtype)
        return y_out, freq_out
|
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
"""Affine transformations via matrix multiplication: y = Ax or y = Ax + B.
|
|
2
|
+
|
|
3
|
+
For full matrix transformations where channels are mixed (off-diagonal weights),
|
|
4
|
+
use :obj:`AffineTransformTransformer` or the `AffineTransform` unit.
|
|
5
|
+
|
|
6
|
+
For simple per-channel scaling and offset (diagonal weights only), use
|
|
7
|
+
:obj:`LinearTransformTransformer` from :mod:`ezmsg.sigproc.linear` instead,
|
|
8
|
+
which is more efficient as it avoids matrix multiplication.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import os
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
|
|
14
|
+
import ezmsg.core as ez
|
|
15
|
+
import numpy as np
|
|
16
|
+
import numpy.typing as npt
|
|
17
|
+
from ezmsg.baseproc import (
|
|
18
|
+
BaseStatefulTransformer,
|
|
19
|
+
BaseTransformer,
|
|
20
|
+
BaseTransformerUnit,
|
|
21
|
+
processor_state,
|
|
22
|
+
)
|
|
23
|
+
from ezmsg.util.messages.axisarray import AxisArray, AxisBase
|
|
24
|
+
from ezmsg.util.messages.util import replace
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class AffineTransformSettings(ez.Settings):
    """
    Settings for :obj:`AffineTransform`.
    """

    weights: np.ndarray | str | Path
    """An array of weights, or a path to a comma-delimited text file of weights (read with
    ``np.loadtxt(..., delimiter=",")``). The string ``"passthrough"`` disables the transform.

    Rows correspond to input channels and columns to output channels. An optional extra final
    row holds the offset B in y = Ax + B."""

    axis: str | None = None
    """The name of the axis to apply the transformation to. Defaults to the last axis in the array."""

    right_multiply: bool = True
    """Set False to transpose the weights before applying."""
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@processor_state
class AffineTransformState:
    """Cached state for :obj:`AffineTransformTransformer`."""

    # Loaded (and optionally transposed) weight matrix, C-contiguous.
    weights: npt.NDArray | None = None
    # Replacement for the transformed axis (e.g. relabeled channels); None keeps the input axis.
    new_axis: AxisBase | None = None
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class AffineTransformTransformer(
    BaseStatefulTransformer[AffineTransformSettings, AxisArray, AxisArray, AffineTransformState]
):
    """Apply affine transformation via matrix multiplication: y = Ax or y = Ax + B.

    Use this transformer when you need full matrix transformations that mix
    channels (off-diagonal weights), such as spatial filters or projections.

    For simple per-channel scaling and offset where each output channel depends
    only on its corresponding input channel (diagonal weight matrix), use
    :obj:`LinearTransformTransformer` instead, which is more efficient.

    The weights matrix can include an offset row (stacked as [A|B]) where the
    input is automatically augmented with a column of ones to compute y = Ax + B.
    """

    def __call__(self, message: AxisArray) -> AxisArray:
        # Override __call__ so we can shortcut if weights are None (or the "passthrough" sentinel).
        if self.settings.weights is None or (
            isinstance(self.settings.weights, str) and self.settings.weights == "passthrough"
        ):
            return message
        return super().__call__(message)

    def _hash_message(self, message: AxisArray) -> int:
        return hash(message.key)

    def _reset_state(self, message: AxisArray) -> None:
        """Load the weights and, when the channel count changes, derive labels for the output axis."""
        weights = self.settings.weights
        if isinstance(weights, str):
            weights = Path(os.path.abspath(os.path.expanduser(weights)))
        if isinstance(weights, Path):
            weights = np.loadtxt(weights, delimiter=",")
        if not self.settings.right_multiply:
            weights = weights.T
        if weights is not None:
            weights = np.ascontiguousarray(weights)

        self._state.weights = weights
        # Clear any axis derived for a previous stream so it cannot leak into this one.
        self._state.new_axis = None

        axis = self.settings.axis or message.dims[-1]
        if axis not in message.axes or not hasattr(message.axes[axis], "data"):
            return  # No labels to transform.

        n_rows, n_out = weights.shape
        if n_rows == n_out:
            return  # Square weights: incoming labels are kept as-is.

        in_labels = message.axes[axis].data
        n_labels = len(in_labels)
        if n_labels == n_rows:
            mix = weights
        elif n_labels == n_rows - 1:
            # Stacked [A|B] weights (see _process): the final offset row has no input label.
            mix = weights[:-1]
            if mix.shape[0] == mix.shape[1]:
                return  # A is square: labels pass through, as for square weights.
        else:
            ez.logger.warning(f"Received {n_labels} labels for {n_rows} inputs. Check upstream labels.")
            mix = None

        new_labels = [] if mix is None else self._map_labels(in_labels, mix)
        self._state.new_axis = replace(message.axes[axis], data=np.array(new_labels))

    @staticmethod
    def _map_labels(in_labels, mix: npt.NDArray):
        """Derive output labels from the sparsity pattern of a non-square (n_in, n_out) matrix.

        Returns an empty list when no simple mapping is apparent.
        """
        b_filled_outputs = np.any(mix, axis=0)
        b_used_inputs = np.any(mix, axis=1)
        if np.all(b_used_inputs) and np.all(b_filled_outputs):
            return []
        if np.all(b_used_inputs):
            # Every input is used but some outputs are empty: hand the input labels, in order,
            # to the non-empty outputs; empty outputs get "".
            new_labels = []
            in_ix = 0
            for is_filled in b_filled_outputs:
                if is_filled:
                    new_labels.append(in_labels[in_ix])
                    in_ix += 1
                else:
                    new_labels.append("")
            return new_labels
        if np.all(b_filled_outputs):
            # Some inputs are dropped: keep the labels of the inputs that contribute.
            return np.array(in_labels)[b_used_inputs]
        return []

    def _process(self, message: AxisArray) -> AxisArray:
        axis = self.settings.axis or message.dims[-1]
        axis_idx = message.get_axis_idx(axis)
        data = message.data

        if data.shape[axis_idx] == (self._state.weights.shape[0] - 1):
            # The weights are stacked A|B where A is the transform and B is a single row
            # in the equation y = Ax + B. This supports NeuroKey's weights matrices.
            sample_shape = data.shape[:axis_idx] + (1,) + data.shape[axis_idx + 1 :]
            data = np.concatenate((data, np.ones(sample_shape).astype(data.dtype)), axis=axis_idx)

        if axis_idx in [-1, len(message.dims) - 1]:
            data = np.matmul(data, self._state.weights)
        else:
            # matmul contracts the last axis, so temporarily move the target axis there.
            data = np.moveaxis(data, axis_idx, -1)
            data = np.matmul(data, self._state.weights)
            data = np.moveaxis(data, -1, axis_idx)

        replace_kwargs = {"data": data}
        if self._state.new_axis is not None:
            replace_kwargs["axes"] = {**message.axes, axis: self._state.new_axis}

        return replace(message, **replace_kwargs)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
class AffineTransform(BaseTransformerUnit[AffineTransformSettings, AxisArray, AxisArray, AffineTransformTransformer]):
    """Unit wrapper around :obj:`AffineTransformTransformer`, configured by :obj:`AffineTransformSettings`."""

    SETTINGS = AffineTransformSettings
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def affine_transform(
    weights: np.ndarray | str | Path,
    axis: str | None = None,
    right_multiply: bool = True,
) -> AffineTransformTransformer:
    """
    Perform affine transformations on streaming data.

    Args:
        weights: An array of weights, or a path to a comma-delimited text file of weights
            (read with ``np.loadtxt(..., delimiter=",")``). The string ``"passthrough"``
            disables the transform.
        axis: The name of the axis to apply the transformation to. Defaults to the last axis in the array.
        right_multiply: Set False to transpose the weights before applying.

    Returns:
        :obj:`AffineTransformTransformer`.
    """
    return AffineTransformTransformer(
        AffineTransformSettings(weights=weights, axis=axis, right_multiply=right_multiply)
    )
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def zeros_for_noop(data: npt.NDArray, **ignore_kwargs) -> npt.NDArray:
    """Return zeros with the same shape and dtype as *data*.

    Keyword arguments (e.g. ``axis``, ``keepdims``) are accepted and ignored. This lets the
    function stand in for a reduction such as ``np.mean`` when the reference should be zero.
    """
    zero_reference = np.zeros_like(data)
    return zero_reference
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
class CommonRereferenceSettings(ez.Settings):
    """
    Configuration for :obj:`CommonRereference` / :obj:`CommonRereferenceTransformer`.
    """

    mode: str = "mean"
    """The statistical mode to apply -- either "mean" or "median"."""

    axis: str | None = None
    """The name of the axis to apply the transformation to."""

    include_current: bool = True
    """Set False to exclude each channel from participating in the calculation of its reference."""
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
class CommonRereferenceTransformer(BaseTransformer[CommonRereferenceSettings, AxisArray, AxisArray]):
    """Subtract a common reference (mean or median across ``settings.axis``) from every channel."""

    def _process(self, message: AxisArray) -> AxisArray:
        """Re-reference ``message.data`` along the configured axis (default: last axis).

        Raises:
            KeyError: if ``settings.mode`` is not "mean", "median", or "passthrough".
            ValueError: if ``include_current`` is False and the axis holds a single channel
                (the reference would be undefined).
        """
        if self.settings.mode == "passthrough":
            return message

        axis = self.settings.axis or message.dims[-1]
        axis_idx = message.get_axis_idx(axis)
        data = message.data

        func = {"mean": np.mean, "median": np.median, "passthrough": zeros_for_noop}[self.settings.mode]

        if self.settings.include_current:
            ref_data = func(data, axis=axis_idx, keepdims=True)
        else:
            n_channels = data.shape[axis_idx]
            if n_channels == 1:
                raise ValueError(
                    "include_current=False requires at least 2 channels along the reference axis."
                )
            if self.settings.mode == "mean":
                # Typical `CAR = x[0]/N + ... + x[N-1]/N` is the same for all i, so it is computed once.
                # Excluding channel i gives `CAR[i] = sum_{j != i} x[j] / (N-1)`, i.e.
                # `CAR[i] = (N / (N-1)) * common_CAR - x[i]/(N-1)`, computed for all i by broadcasting.
                # Note: this is ~30x faster than an equivalent AffineTransformTransformer.
                common = func(data, axis=axis_idx, keepdims=True)
                ref_data = (n_channels / (n_channels - 1)) * common - data / (n_channels - 1)
            else:
                # The rescaling identity above holds only for the mean; other statistics
                # (e.g. the median) must be recomputed with each channel removed.
                per_channel = [
                    func(np.delete(data, ch, axis=axis_idx), axis=axis_idx, keepdims=True)
                    for ch in range(n_channels)
                ]
                ref_data = np.concatenate(per_channel, axis=axis_idx) if per_channel else np.zeros_like(data)

        return replace(message, data=data - ref_data)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
class CommonRereference(
    BaseTransformerUnit[CommonRereferenceSettings, AxisArray, AxisArray, CommonRereferenceTransformer]
):
    """Unit wrapper around :obj:`CommonRereferenceTransformer`, configured by :obj:`CommonRereferenceSettings`."""

    SETTINGS = CommonRereferenceSettings
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def common_rereference(
    mode: str = "mean", axis: str | None = None, include_current: bool = True
) -> CommonRereferenceTransformer:
    """
    Perform common average referencing (CAR) on streaming data.

    Args:
        mode: The statistical mode to apply -- "mean" or "median" ("passthrough" leaves data unchanged).
        axis: The name of the axis to apply the transformation to. Defaults to the last axis in the array.
        include_current: Set False to exclude each channel from participating in the calculation of its reference.

    Returns:
        :obj:`CommonRereferenceTransformer`
    """
    return CommonRereferenceTransformer(
        CommonRereferenceSettings(mode=mode, axis=axis, include_current=include_current)
    )
|