ezmsg-sigproc 1.7.0__py3-none-any.whl → 2.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +22 -4
- ezmsg/sigproc/activation.py +31 -40
- ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
- ezmsg/sigproc/affinetransform.py +171 -169
- ezmsg/sigproc/aggregate.py +190 -97
- ezmsg/sigproc/bandpower.py +60 -55
- ezmsg/sigproc/base.py +143 -33
- ezmsg/sigproc/butterworthfilter.py +34 -38
- ezmsg/sigproc/butterworthzerophase.py +305 -0
- ezmsg/sigproc/cheby.py +23 -17
- ezmsg/sigproc/combfilter.py +160 -0
- ezmsg/sigproc/coordinatespaces.py +159 -0
- ezmsg/sigproc/decimate.py +15 -10
- ezmsg/sigproc/denormalize.py +78 -0
- ezmsg/sigproc/detrend.py +28 -0
- ezmsg/sigproc/diff.py +82 -0
- ezmsg/sigproc/downsample.py +72 -81
- ezmsg/sigproc/ewma.py +217 -0
- ezmsg/sigproc/ewmfilter.py +1 -1
- ezmsg/sigproc/extract_axis.py +39 -0
- ezmsg/sigproc/fbcca.py +307 -0
- ezmsg/sigproc/filter.py +254 -148
- ezmsg/sigproc/filterbank.py +226 -214
- ezmsg/sigproc/filterbankdesign.py +129 -0
- ezmsg/sigproc/fir_hilbert.py +336 -0
- ezmsg/sigproc/fir_pmc.py +209 -0
- ezmsg/sigproc/firfilter.py +117 -0
- ezmsg/sigproc/gaussiansmoothing.py +89 -0
- ezmsg/sigproc/kaiser.py +106 -0
- ezmsg/sigproc/linear.py +120 -0
- ezmsg/sigproc/math/abs.py +23 -22
- ezmsg/sigproc/math/add.py +120 -0
- ezmsg/sigproc/math/clip.py +33 -25
- ezmsg/sigproc/math/difference.py +117 -43
- ezmsg/sigproc/math/invert.py +18 -25
- ezmsg/sigproc/math/log.py +38 -33
- ezmsg/sigproc/math/scale.py +24 -25
- ezmsg/sigproc/messages.py +1 -2
- ezmsg/sigproc/quantize.py +68 -0
- ezmsg/sigproc/resample.py +278 -0
- ezmsg/sigproc/rollingscaler.py +232 -0
- ezmsg/sigproc/sampler.py +209 -254
- ezmsg/sigproc/scaler.py +93 -218
- ezmsg/sigproc/signalinjector.py +44 -43
- ezmsg/sigproc/slicer.py +74 -102
- ezmsg/sigproc/spectral.py +3 -3
- ezmsg/sigproc/spectrogram.py +70 -70
- ezmsg/sigproc/spectrum.py +187 -173
- ezmsg/sigproc/transpose.py +134 -0
- ezmsg/sigproc/util/__init__.py +0 -0
- ezmsg/sigproc/util/asio.py +25 -0
- ezmsg/sigproc/util/axisarray_buffer.py +365 -0
- ezmsg/sigproc/util/buffer.py +449 -0
- ezmsg/sigproc/util/message.py +17 -0
- ezmsg/sigproc/util/profile.py +23 -0
- ezmsg/sigproc/util/sparse.py +115 -0
- ezmsg/sigproc/util/typeresolution.py +17 -0
- ezmsg/sigproc/wavelets.py +147 -154
- ezmsg/sigproc/window.py +248 -210
- ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
- {ezmsg_sigproc-1.7.0.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -1
- ezmsg/sigproc/synth.py +0 -621
- ezmsg_sigproc-1.7.0.dist-info/METADATA +0 -58
- ezmsg_sigproc-1.7.0.dist-info/RECORD +0 -36
- /ezmsg_sigproc-1.7.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Coordinate space transformations for streaming data.
|
|
3
|
+
|
|
4
|
+
This module provides utilities and ezmsg nodes for transforming between
|
|
5
|
+
Cartesian (x, y) and polar (r, theta) coordinate systems.
|
|
6
|
+
|
|
7
|
+
.. note::
|
|
8
|
+
This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
|
|
9
|
+
enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from enum import Enum
|
|
13
|
+
from typing import Tuple
|
|
14
|
+
|
|
15
|
+
import ezmsg.core as ez
|
|
16
|
+
import numpy as np
|
|
17
|
+
import numpy.typing as npt
|
|
18
|
+
from array_api_compat import get_namespace, is_array_api_obj
|
|
19
|
+
from ezmsg.baseproc import (
|
|
20
|
+
BaseTransformer,
|
|
21
|
+
BaseTransformerUnit,
|
|
22
|
+
)
|
|
23
|
+
from ezmsg.util.messages.axisarray import AxisArray, replace
|
|
24
|
+
|
|
25
|
+
# -- Utility functions for coordinate transformations --
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _get_namespace_or_numpy(*args: npt.ArrayLike):
    """
    Return the array namespace of the first array-API object among ``args``.

    When none of the arguments is an array-API object (e.g. plain Python
    scalars), :mod:`numpy` is returned instead.
    """
    first_array = next((candidate for candidate in args if is_array_api_obj(candidate)), None)
    if first_array is None:
        return np
    return get_namespace(first_array)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def polar2z(r: npt.ArrayLike, theta: npt.ArrayLike) -> npt.ArrayLike:
    """
    Build complex value(s) ``r * exp(1j * theta)`` from polar components.

    The exponential is evaluated in the namespace of the first array argument
    (falling back to numpy).
    """
    namespace = _get_namespace_or_numpy(r, theta)
    phase = namespace.exp(1j * theta)
    return r * phase
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def z2polar(z: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
    """
    Split complex value(s) into polar components.

    Returns:
        A tuple ``(r, theta)`` where ``r = |z|`` and ``theta = atan2(imag(z), real(z))``.
    """
    namespace = _get_namespace_or_numpy(z)
    magnitude = namespace.abs(z)
    real_part = namespace.real(z)
    imag_part = namespace.imag(z)
    angle = namespace.atan2(imag_part, real_part)
    return magnitude, angle
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def cart2z(x: npt.ArrayLike, y: npt.ArrayLike) -> npt.ArrayLike:
    """Combine Cartesian components into complex value(s) ``x + 1j * y``."""
    imaginary = 1j * y
    return x + imaginary
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def z2cart(z: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
    """
    Split complex value(s) into Cartesian components.

    Returns:
        A tuple ``(x, y)`` holding the real and imaginary parts of ``z``.
    """
    namespace = _get_namespace_or_numpy(z)
    x = namespace.real(z)
    y = namespace.imag(z)
    return x, y
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def cart2pol(x: npt.ArrayLike, y: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
    """
    Map Cartesian ``(x, y)`` to polar ``(r, theta)``.

    Implemented by routing through the complex representation ``x + 1j * y``.
    """
    z = cart2z(x, y)
    r, theta = z2polar(z)
    return r, theta
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def pol2cart(r: npt.ArrayLike, theta: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
    """
    Map polar ``(r, theta)`` to Cartesian ``(x, y)``.

    Implemented by routing through the complex representation ``r * exp(1j * theta)``.
    """
    z = polar2z(r, theta)
    x, y = z2cart(z)
    return x, y
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# -- ezmsg transformer classes --
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class CoordinateMode(str, Enum):
    """
    Direction of a coordinate-space conversion.

    Because members subclass ``str``, each compares equal to its string value
    (e.g. ``CoordinateMode.CART2POL == "cart2pol"``).
    """

    CART2POL = "cart2pol"
    """Convert Cartesian (x, y) to polar (r, theta)."""

    POL2CART = "pol2cart"
    """Convert polar (r, theta) to Cartesian (x, y)."""
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class CoordinateSpacesSettings(ez.Settings):
    """
    Settings for :obj:`CoordinateSpaces` and :obj:`CoordinateSpacesTransformer`.
    """

    mode: CoordinateMode = CoordinateMode.CART2POL
    """The transformation mode: 'cart2pol' or 'pol2cart' (a :obj:`CoordinateMode` or its string value)."""

    axis: str | None = None
    """
    The name of the axis containing the coordinate components.
    Defaults to the last axis. Must have exactly 2 elements (x,y or r,theta).
    """
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class CoordinateSpacesTransformer(BaseTransformer[CoordinateSpacesSettings, AxisArray, AxisArray]):
    """
    Transform between Cartesian and polar coordinate systems.

    The input must have exactly 2 elements along the specified axis:
    - For cart2pol: expects (x, y), outputs (r, theta)
    - For pol2cart: expects (r, theta), outputs (x, y)

    If the coordinate axis carries labels (has a ``data`` attribute), they are
    replaced with ``["r", "theta"]`` or ``["x", "y"]`` to match the output.
    """

    def _process(self, message: AxisArray) -> AxisArray:
        """
        Convert the two coordinate components of ``message``.

        Raises:
            ValueError: If ``settings.mode`` is not a valid :obj:`CoordinateMode`
                value, or if the coordinate axis does not have exactly 2 elements.
        """
        # Normalize the mode first: plain strings like "cart2pol" are accepted, and an
        # unknown value raises instead of silently falling through to pol2cart.
        mode = CoordinateMode(self.settings.mode)

        xp = get_namespace(message.data)
        axis = self.settings.axis or message.dims[-1]
        axis_idx = message.get_axis_idx(axis)

        n_components = message.data.shape[axis_idx]
        if n_components != 2:
            raise ValueError(
                f"Coordinate transformation requires exactly 2 elements along axis '{axis}', "
                f"got {n_components}."
            )

        # Extract the two components along the coordinate axis (that axis is dropped).
        index_a = [slice(None)] * message.data.ndim
        index_b = list(index_a)
        index_a[axis_idx] = 0
        index_b[axis_idx] = 1
        component_a = message.data[tuple(index_a)]
        component_b = message.data[tuple(index_b)]

        # Labels are plain strings, so they are built with numpy regardless of xp.
        if mode == CoordinateMode.CART2POL:
            out_a, out_b = cart2pol(component_a, component_b)
            new_labels = np.array(["r", "theta"])
        else:
            out_a, out_b = pol2cart(component_a, component_b)
            new_labels = np.array(["x", "y"])

        # Re-insert the coordinate axis at its original position.
        result = xp.stack([out_a, out_b], axis=axis_idx)

        axes = message.axes
        if axis in axes and hasattr(axes[axis], "data"):
            axes = {**axes, axis: replace(axes[axis], data=new_labels)}

        return replace(message, data=result, axes=axes)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class CoordinateSpaces(
    BaseTransformerUnit[CoordinateSpacesSettings, AxisArray, AxisArray, CoordinateSpacesTransformer]
):
    """
    ezmsg Unit that applies :obj:`CoordinateSpacesTransformer` to incoming messages.

    Configure it with :obj:`CoordinateSpacesSettings` (conversion mode and coordinate axis).
    """

    SETTINGS = CoordinateSpacesSettings
|
ezmsg/sigproc/decimate.py
CHANGED
|
@@ -1,33 +1,39 @@
|
|
|
1
1
|
import typing
|
|
2
2
|
|
|
3
3
|
import ezmsg.core as ez
|
|
4
|
+
from ezmsg.baseproc import BaseTransformerUnit
|
|
4
5
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
5
6
|
|
|
6
|
-
from .cheby import
|
|
7
|
+
from .cheby import ChebyshevFilterSettings, ChebyshevFilterTransformer
|
|
7
8
|
from .downsample import Downsample, DownsampleSettings
|
|
8
|
-
from .filter import
|
|
9
|
+
from .filter import BACoeffs, SOSCoeffs
|
|
9
10
|
|
|
10
11
|
|
|
11
|
-
class ChebyForDecimateTransformer(ChebyshevFilterTransformer[BACoeffs | SOSCoeffs]):
    """
    A :obj:`ChebyshevFilterTransformer` whose design function may decline to build a filter.

    The returned design function yields None when the sampling rate is None or when
    ``int(fs / (2.5 * settings.Wn))`` is below 2. Otherwise it returns the parent
    class's filter coefficients for ``fs``.
    """

    def get_design_function(
        self,
    ) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
        """Return a callable mapping a sampling rate to coefficients, or None to skip filtering."""

        def cheby_opt_design_fun(fs: float) -> BACoeffs | SOSCoeffs | None:
            # Short-circuit keeps the division from running when fs is None.
            if fs is None or int(fs / (2.5 * self.settings.Wn)) < 2:
                return None
            # Explicit two-argument super(): zero-argument super() is unavailable in a nested function.
            parent_design = super(ChebyForDecimateTransformer, self).get_design_function()
            return parent_design(fs)

        return cheby_opt_design_fun
|
|
29
31
|
|
|
30
32
|
|
|
33
|
+
class ChebyForDecimate(BaseTransformerUnit[ChebyshevFilterSettings, AxisArray, AxisArray, ChebyForDecimateTransformer]):
    """ezmsg Unit wrapping :obj:`ChebyForDecimateTransformer`, configured by :obj:`ChebyshevFilterSettings`."""

    SETTINGS = ChebyshevFilterSettings
|
|
35
|
+
|
|
36
|
+
|
|
31
37
|
class Decimate(ez.Collection):
|
|
32
38
|
"""
|
|
33
39
|
A :obj:`Collection` chaining a :obj:`Filter` node configured as a lowpass Chebyshev filter
|
|
@@ -43,7 +49,6 @@ class Decimate(ez.Collection):
|
|
|
43
49
|
DOWNSAMPLE = Downsample()
|
|
44
50
|
|
|
45
51
|
def configure(self) -> None:
|
|
46
|
-
|
|
47
52
|
cheby_settings = ChebyshevFilterSettings(
|
|
48
53
|
order=8,
|
|
49
54
|
ripple_tol=0.05,
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
import ezmsg.core as ez
|
|
2
|
+
import numpy as np
|
|
3
|
+
import numpy.typing as npt
|
|
4
|
+
from ezmsg.baseproc import (
|
|
5
|
+
BaseStatefulTransformer,
|
|
6
|
+
BaseTransformerUnit,
|
|
7
|
+
processor_state,
|
|
8
|
+
)
|
|
9
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
10
|
+
from ezmsg.util.messages.util import replace
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class DenormalizeSettings(ez.Settings):
    """Settings for :obj:`DenormalizeTransformer` / :obj:`DenormalizeUnit`."""

    low_rate: float = 2.0
    """Low end of probable rate after denormalization (Hz)."""

    high_rate: float = 40.0
    """High end of probable rate after denormalization (Hz)."""

    distribution: str = "uniform"
    """Distribution to sample rates from. Options are 'uniform', 'normal', or 'constant'."""
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@processor_state
class DenormalizeState:
    """Per-channel parameters drawn by :obj:`DenormalizeTransformer` when its state is reset."""

    # Multiplicative factor applied to the normalized input (None until reset).
    gains: npt.NDArray | None = None
    # Additive per-channel offset applied after scaling (None until reset).
    offsets: npt.NDArray | None = None
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class DenormalizeTransformer(BaseStatefulTransformer[DenormalizeSettings, AxisArray, AxisArray, DenormalizeState]):
    """
    Scales data from a normalized distribution (mean=0, std=1) to a denormalized
    distribution using random per-channel offsets and gains designed to keep the
    99.9% CIs between 0 and 2x the offset.

    This is useful for simulating realistic firing rates from normalized data.
    """

    def _reset_state(self, message: AxisArray) -> None:
        """
        Draw per-channel offsets in [low_rate, high_rate] and derive the matching gains.

        The parameters are shaped with size 1 on every axis except ``"ch"``, so they
        broadcast against ``message.data`` of any rank.

        Raises:
            ValueError: If ``settings.distribution`` is not 'uniform', 'normal' or 'constant'.
        """
        ax_ix = message.get_axis_idx("ch")
        nch = message.data.shape[ax_ix]
        # Equivalent to (nch, 1) / (1, nch) for 2-D data, but also correct for 1-D
        # input and for "ch" at any position in higher-rank input.
        arr_size = tuple(nch if i == ax_ix else 1 for i in range(message.data.ndim))
        low = self.settings.low_rate
        high = self.settings.high_rate
        distribution = self.settings.distribution
        if distribution == "uniform":
            # Honor the configured range (previously hard-coded to 2.0..40.0).
            self.state.offsets = np.random.uniform(low, high, size=arr_size)
        elif distribution == "normal":
            # +/- 3 std devs spans [low, high]; clip the rare outliers back into range.
            offsets = np.random.normal(
                loc=(low + high) / 2.0,
                scale=(high - low) / 6.0,
                size=arr_size,
            )
            self.state.offsets = np.clip(offsets, a_min=low, a_max=high)
        elif distribution == "constant":
            self.state.offsets = np.full(shape=arr_size, fill_value=(low + high) / 2.0)
        else:
            raise ValueError(f"Invalid distribution: {self.settings.distribution}")
        # Input has std == 1
        # Desired output has range from 0 to 2*self.state.offsets within 99.9% confidence interval
        # For a standard normal distribution, 99.9% of data is within +/- 3.29 std devs.
        # So, gain = offset / 3.29 to scale the std dev appropriately.
        self.state.gains = self.state.offsets / 3.29

    def _process(self, message: AxisArray) -> AxisArray:
        """Apply ``data * gains + offsets`` and clip negative results to 0."""
        denorm = message.data * self.state.gains + self.state.offsets
        return replace(
            message,
            data=np.clip(denorm, a_min=0.0, a_max=None),
        )
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class DenormalizeUnit(BaseTransformerUnit[DenormalizeSettings, AxisArray, AxisArray, DenormalizeTransformer]):
    """ezmsg Unit wrapping :obj:`DenormalizeTransformer`, configured by :obj:`DenormalizeSettings`."""

    SETTINGS = DenormalizeSettings
|
ezmsg/sigproc/detrend.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
import scipy.signal as sps
|
|
2
|
+
from ezmsg.baseproc import BaseTransformerUnit
|
|
3
|
+
from ezmsg.util.messages.axisarray import AxisArray, replace
|
|
4
|
+
|
|
5
|
+
from ezmsg.sigproc.ewma import EWMASettings, EWMATransformer
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class DetrendTransformer(EWMATransformer):
    """
    Remove a slowly varying trend by subtracting an EWMA estimate of the mean.

    The running mean is produced with ``scipy.signal.lfilter``; its filter state
    (``zi``) is carried in the inherited EWMA state so the estimate is continuous
    across messages.
    """

    def _process(self, message):
        axis_name = self.settings.axis or message.dims[0]
        ax = message.get_axis_idx(axis_name)
        alpha = self._state.alpha
        # One-pole IIR: m[n] = alpha * x[n] + (1 - alpha) * m[n - 1]
        numerator = [alpha]
        denominator = [1.0, alpha - 1.0]
        trend, self._state.zi = sps.lfilter(
            numerator,
            denominator,
            message.data,
            axis=ax,
            zi=self._state.zi,
        )
        return replace(message, data=message.data - trend)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class DetrendUnit(BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, DetrendTransformer]):
    """ezmsg Unit wrapping :obj:`DetrendTransformer`; it reuses :obj:`EWMASettings`."""

    SETTINGS = EWMASettings
|
ezmsg/sigproc/diff.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Compute differences along an axis.
|
|
3
|
+
|
|
4
|
+
.. note::
|
|
5
|
+
This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
|
|
6
|
+
enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import ezmsg.core as ez
|
|
10
|
+
import numpy as np
|
|
11
|
+
import numpy.typing as npt
|
|
12
|
+
from array_api_compat import get_namespace
|
|
13
|
+
from ezmsg.baseproc import (
|
|
14
|
+
BaseStatefulTransformer,
|
|
15
|
+
BaseTransformerUnit,
|
|
16
|
+
processor_state,
|
|
17
|
+
)
|
|
18
|
+
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
|
|
19
|
+
from ezmsg.util.messages.util import replace
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class DiffSettings(ez.Settings):
    """Settings for :obj:`DiffTransformer` / :obj:`DiffUnit`."""

    axis: str | None = None
    """Axis to difference along. When None, processing falls back to ``message.dims[0]``."""

    scale_by_fs: bool = False
    """If True, divide differences by the sample interval (e.g. position becomes velocity)."""
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@processor_state
class DiffState:
    """State carried between chunks by :obj:`DiffTransformer`."""

    # Last sample of the previous chunk (kept with length 1 along the diff axis).
    last_dat: npt.NDArray | None = None
    # Axis value of the previous chunk's last sample; used only when the axis carries ``data``.
    last_time: float | None = None
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class DiffTransformer(BaseStatefulTransformer[DiffSettings, AxisArray, AxisArray, DiffState]):
    """
    First-order difference along an axis, continuous across message boundaries.

    The last sample of each chunk is kept in state, so the first output of the next
    chunk is ``x[0] - previous_last``. After a state reset, the first sample is
    differenced against itself (output 0). With ``settings.scale_by_fs`` the
    differences are divided by the sample interval: per-sample intervals from the
    axis ``data`` when present, otherwise the axis ``gain``.
    """

    def _resolve_axis(self, message: AxisArray) -> str:
        """Return ``settings.axis``, or the message's first dimension when it is unset."""
        return self.settings.axis or message.dims[0]

    def _hash_message(self, message: AxisArray) -> int:
        # State is reset whenever the per-sample shape or the stream key changes.
        ax_idx = message.get_axis_idx(self._resolve_axis(message))
        sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
        return hash((sample_shape, message.key))

    def _seed_state(self, message: AxisArray, axis: str, ax_idx: int) -> None:
        """Seed ``last_dat``/``last_time`` from the first sample of a non-empty ``message``."""
        self.state.last_dat = slice_along_axis(message.data, slice(0, 1), axis=ax_idx)
        if self.settings.scale_by_fs:
            ax_info = message.get_axis(axis)
            if hasattr(ax_info, "data"):
                if len(ax_info.data) > 1:
                    # Extrapolate one interval backwards so the first dt equals t[1] - t[0].
                    self.state.last_time = 2 * ax_info.data[0] - ax_info.data[1]
                else:
                    # Single axis value: no interval available, assume 0.001 (axis units).
                    self.state.last_time = ax_info.data[0] - 0.001

    def _reset_state(self, message) -> None:
        axis = self._resolve_axis(message)
        ax_idx = message.get_axis_idx(axis)
        self.state.last_dat = None
        self.state.last_time = None
        # An empty chunk has no first sample; seeding is deferred to the next non-empty one.
        if message.data.shape[ax_idx] > 0:
            self._seed_state(message, axis, ax_idx)

    def _process(self, message: AxisArray) -> AxisArray:
        xp = get_namespace(message.data)
        axis = self._resolve_axis(message)
        ax_idx = message.get_axis_idx(axis)

        if message.data.shape[ax_idx] == 0:
            # Nothing to difference. Keep the carried-over sample/time untouched so the
            # next non-empty chunk still yields one output per input sample.
            return replace(message, data=message.data)

        if self.state.last_dat is None:
            # State was reset on an empty chunk; seed it from this chunk instead.
            self._seed_state(message, axis, ax_idx)

        diffs = xp.diff(
            xp.concat((self.state.last_dat, message.data), axis=ax_idx),
            axis=ax_idx,
        )
        # Prepare last_dat for next iteration
        self.state.last_dat = slice_along_axis(message.data, slice(-1, None), axis=ax_idx)
        # Scale by fs if requested. This converts the diff to a derivative. e.g., diff of position becomes velocity.
        if self.settings.scale_by_fs:
            if xp.isdtype(diffs.dtype, "integral"):
                # Integer diffs cannot be divided in place by a float interval.
                diffs = xp.astype(diffs, xp.float64)
            ax_info = message.get_axis(axis)
            if hasattr(ax_info, "data"):
                # ax_info.data is typically numpy for metadata, so use np.diff here
                dt = np.diff(np.concatenate(([self.state.last_time], ax_info.data)))
                # Expand dt dims so it broadcasts along ax_idx only.
                exp_sl = (None,) * ax_idx + (Ellipsis,) + (None,) * (message.data.ndim - ax_idx - 1)
                diffs /= xp.asarray(dt[exp_sl])
                self.state.last_time = ax_info.data[-1]  # For next iteration
            else:
                diffs /= ax_info.gain

        return replace(message, data=diffs)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class DiffUnit(BaseTransformerUnit[DiffSettings, AxisArray, AxisArray, DiffTransformer]):
    """ezmsg Unit wrapping :obj:`DiffTransformer`, configured by :obj:`DiffSettings`."""

    SETTINGS = DiffSettings
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def diff(axis: str = "time", scale_by_fs: bool = False) -> DiffTransformer:
    """
    Convenience constructor for :obj:`DiffTransformer`.

    Args:
        axis: Axis to difference along.
        scale_by_fs: If True, divide differences by the sample interval.
    """
    settings = DiffSettings(axis=axis, scale_by_fs=scale_by_fs)
    return DiffTransformer(settings)
|
ezmsg/sigproc/downsample.py
CHANGED
|
@@ -1,82 +1,81 @@
|
|
|
1
|
-
import
|
|
2
|
-
|
|
1
|
+
import ezmsg.core as ez
|
|
3
2
|
import numpy as np
|
|
3
|
+
from ezmsg.baseproc import (
|
|
4
|
+
BaseStatefulTransformer,
|
|
5
|
+
BaseTransformerUnit,
|
|
6
|
+
processor_state,
|
|
7
|
+
)
|
|
4
8
|
from ezmsg.util.messages.axisarray import (
|
|
5
9
|
AxisArray,
|
|
6
|
-
slice_along_axis,
|
|
7
10
|
replace,
|
|
11
|
+
slice_along_axis,
|
|
8
12
|
)
|
|
9
|
-
from ezmsg.util.generator import consumer
|
|
10
|
-
import ezmsg.core as ez
|
|
11
13
|
|
|
12
|
-
from .base import GenAxisArray
|
|
13
14
|
|
|
15
|
+
class DownsampleSettings(ez.Settings):
    """
    Configuration for the :obj:`Downsample` unit and :obj:`DownsampleTransformer`.
    """

    axis: str = "time"
    """The name of the axis along which to downsample."""

    target_rate: float | None = None
    """Desired rate after downsampling. The actual rate will be the nearest integer factor of the
    input rate that is the same or higher than the target rate."""

    factor: int | None = None
    """Explicitly specify downsample factor. If specified, target_rate is ignored."""
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@processor_state
class DownsampleState:
    """Counters carried between messages by :obj:`DownsampleTransformer`."""

    q: int = 0
    """The integer downsampling factor, resolved from the settings when state is reset."""

    s_idx: int = 0
    """Position of the next message's first sample within the rotating 0..q-1 counter."""
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class DownsampleTransformer(BaseStatefulTransformer[DownsampleSettings, AxisArray, AxisArray, DownsampleState]):
|
|
19
41
|
"""
|
|
20
|
-
Construct a generator that yields a downsampled version of the data .send() to it.
|
|
21
42
|
Downsampled data simply comprise every `factor`th sample.
|
|
22
43
|
This should only be used following appropriate lowpass filtering.
|
|
23
44
|
If your pipeline does not already have lowpass filtering then consider
|
|
24
45
|
using the :obj:`Decimate` collection instead.
|
|
25
|
-
|
|
26
|
-
Args:
|
|
27
|
-
axis: The name of the axis along which to downsample.
|
|
28
|
-
Note: The axis must exist in the message .axes and be of type AxisArray.LinearAxis.
|
|
29
|
-
target_rate: Desired rate after downsampling. The actual rate will be the nearest integer factor of the
|
|
30
|
-
input rate that is the same or higher than the target rate.
|
|
31
|
-
|
|
32
|
-
Returns:
|
|
33
|
-
A primed generator object ready to receive an :obj:`AxisArray` via `.send(axis_array)`
|
|
34
|
-
and yields an :obj:`AxisArray` with its data downsampled.
|
|
35
|
-
Note that if a send chunk does not have sufficient samples to reach the
|
|
36
|
-
next downsample interval then an :obj:`AxisArray` with size-zero data is yielded.
|
|
37
|
-
|
|
38
46
|
"""
|
|
39
|
-
msg_out = AxisArray(np.array([]), dims=[""])
|
|
40
|
-
|
|
41
|
-
# state variables
|
|
42
|
-
factor: int = 0 # The integer downsampling factor. It will be determined based on the target rate.
|
|
43
|
-
s_idx: int = 0 # Index of the next msg's first sample into the virtual rotating ds_factor counter.
|
|
44
|
-
|
|
45
|
-
check_input = {"gain": None, "key": None}
|
|
46
47
|
|
|
47
|
-
|
|
48
|
-
|
|
48
|
+
def _hash_message(self, message: AxisArray) -> int:
|
|
49
|
+
return hash((message.axes[self.settings.axis].gain, message.key))
|
|
49
50
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
axis_info = msg_in.get_axis(axis)
|
|
53
|
-
axis_idx = msg_in.get_axis_idx(axis)
|
|
51
|
+
def _reset_state(self, message: AxisArray) -> None:
|
|
52
|
+
axis_info = message.get_axis(self.settings.axis)
|
|
54
53
|
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
n_samples =
|
|
76
|
-
samples = np.arange(s_idx, s_idx + n_samples) %
|
|
54
|
+
if self.settings.factor is not None:
|
|
55
|
+
q = self.settings.factor
|
|
56
|
+
elif self.settings.target_rate is None:
|
|
57
|
+
q = 1
|
|
58
|
+
else:
|
|
59
|
+
q = int(1 / (axis_info.gain * self.settings.target_rate))
|
|
60
|
+
if q < 1:
|
|
61
|
+
ez.logger.warning(
|
|
62
|
+
f"Target rate {self.settings.target_rate} cannot be achieved with input rate of {1 / axis_info.gain}."
|
|
63
|
+
"Setting factor to 1."
|
|
64
|
+
)
|
|
65
|
+
q = 1
|
|
66
|
+
self._state.q = q
|
|
67
|
+
self._state.s_idx = 0
|
|
68
|
+
|
|
69
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
70
|
+
axis = self.settings.axis
|
|
71
|
+
axis_info = message.get_axis(axis)
|
|
72
|
+
axis_idx = message.get_axis_idx(axis)
|
|
73
|
+
|
|
74
|
+
n_samples = message.data.shape[axis_idx]
|
|
75
|
+
samples = np.arange(self.state.s_idx, self.state.s_idx + n_samples) % self._state.q
|
|
77
76
|
if n_samples > 0:
|
|
78
77
|
# Update state for next iteration.
|
|
79
|
-
s_idx = samples[-1] + 1
|
|
78
|
+
self._state.s_idx = samples[-1] + 1
|
|
80
79
|
|
|
81
80
|
pub_samples = np.where(samples == 0)[0]
|
|
82
81
|
if len(pub_samples) > 0:
|
|
@@ -86,35 +85,27 @@ def downsample(
|
|
|
86
85
|
n_step = 0
|
|
87
86
|
data_slice = slice(None, 0, None)
|
|
88
87
|
msg_out = replace(
|
|
89
|
-
|
|
90
|
-
data=slice_along_axis(
|
|
88
|
+
message,
|
|
89
|
+
data=slice_along_axis(message.data, data_slice, axis=axis_idx),
|
|
91
90
|
axes={
|
|
92
|
-
**
|
|
91
|
+
**message.axes,
|
|
93
92
|
axis: replace(
|
|
94
93
|
axis_info,
|
|
95
|
-
gain=axis_info.gain *
|
|
94
|
+
gain=axis_info.gain * self._state.q,
|
|
96
95
|
offset=axis_info.offset + axis_info.gain * n_step,
|
|
97
96
|
),
|
|
98
97
|
},
|
|
99
98
|
)
|
|
99
|
+
return msg_out
|
|
100
100
|
|
|
101
101
|
|
|
102
|
-
class DownsampleSettings
|
|
103
|
-
"""
|
|
104
|
-
Settings for :obj:`Downsample` node.
|
|
105
|
-
See :obj:`downsample` documentation for a description of the parameters.
|
|
106
|
-
"""
|
|
107
|
-
|
|
108
|
-
axis: str | None = None
|
|
109
|
-
target_rate: float | None = None
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
class Downsample(GenAxisArray):
|
|
113
|
-
""":obj:`Unit` for :obj:`bandpower`."""
|
|
114
|
-
|
|
102
|
+
class Downsample(BaseTransformerUnit[DownsampleSettings, AxisArray, AxisArray, DownsampleTransformer]):
    """ezmsg Unit wrapping :obj:`DownsampleTransformer`, configured by :obj:`DownsampleSettings`."""

    SETTINGS = DownsampleSettings
|
|
116
104
|
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
105
|
+
|
|
106
|
+
def downsample(
    axis: str = "time",
    target_rate: float | None = None,
    factor: int | None = None,
) -> DownsampleTransformer:
    """
    Convenience constructor for :obj:`DownsampleTransformer`.

    Args:
        axis: Axis along which to downsample.
        target_rate: Desired output rate; ignored when ``factor`` is given.
        factor: Explicit integer downsample factor.
    """
    settings = DownsampleSettings(axis=axis, target_rate=target_rate, factor=factor)
    return DownsampleTransformer(settings)
|