ezmsg-sigproc 1.2.2__py3-none-any.whl → 2.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__init__.py +1 -1
- ezmsg/sigproc/__version__.py +34 -1
- ezmsg/sigproc/activation.py +78 -0
- ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
- ezmsg/sigproc/affinetransform.py +235 -0
- ezmsg/sigproc/aggregate.py +276 -0
- ezmsg/sigproc/bandpower.py +80 -0
- ezmsg/sigproc/base.py +149 -0
- ezmsg/sigproc/butterworthfilter.py +129 -39
- ezmsg/sigproc/butterworthzerophase.py +305 -0
- ezmsg/sigproc/cheby.py +125 -0
- ezmsg/sigproc/combfilter.py +160 -0
- ezmsg/sigproc/coordinatespaces.py +159 -0
- ezmsg/sigproc/decimate.py +46 -18
- ezmsg/sigproc/denormalize.py +78 -0
- ezmsg/sigproc/detrend.py +28 -0
- ezmsg/sigproc/diff.py +82 -0
- ezmsg/sigproc/downsample.py +97 -49
- ezmsg/sigproc/ewma.py +217 -0
- ezmsg/sigproc/ewmfilter.py +45 -19
- ezmsg/sigproc/extract_axis.py +39 -0
- ezmsg/sigproc/fbcca.py +307 -0
- ezmsg/sigproc/filter.py +282 -117
- ezmsg/sigproc/filterbank.py +292 -0
- ezmsg/sigproc/filterbankdesign.py +129 -0
- ezmsg/sigproc/fir_hilbert.py +336 -0
- ezmsg/sigproc/fir_pmc.py +209 -0
- ezmsg/sigproc/firfilter.py +117 -0
- ezmsg/sigproc/gaussiansmoothing.py +89 -0
- ezmsg/sigproc/kaiser.py +106 -0
- ezmsg/sigproc/linear.py +120 -0
- ezmsg/sigproc/math/__init__.py +0 -0
- ezmsg/sigproc/math/abs.py +35 -0
- ezmsg/sigproc/math/add.py +120 -0
- ezmsg/sigproc/math/clip.py +48 -0
- ezmsg/sigproc/math/difference.py +143 -0
- ezmsg/sigproc/math/invert.py +28 -0
- ezmsg/sigproc/math/log.py +57 -0
- ezmsg/sigproc/math/scale.py +39 -0
- ezmsg/sigproc/messages.py +3 -6
- ezmsg/sigproc/quantize.py +68 -0
- ezmsg/sigproc/resample.py +278 -0
- ezmsg/sigproc/rollingscaler.py +232 -0
- ezmsg/sigproc/sampler.py +232 -241
- ezmsg/sigproc/scaler.py +165 -0
- ezmsg/sigproc/signalinjector.py +70 -0
- ezmsg/sigproc/slicer.py +138 -0
- ezmsg/sigproc/spectral.py +6 -132
- ezmsg/sigproc/spectrogram.py +90 -0
- ezmsg/sigproc/spectrum.py +277 -0
- ezmsg/sigproc/transpose.py +134 -0
- ezmsg/sigproc/util/__init__.py +0 -0
- ezmsg/sigproc/util/asio.py +25 -0
- ezmsg/sigproc/util/axisarray_buffer.py +365 -0
- ezmsg/sigproc/util/buffer.py +449 -0
- ezmsg/sigproc/util/message.py +17 -0
- ezmsg/sigproc/util/profile.py +23 -0
- ezmsg/sigproc/util/sparse.py +115 -0
- ezmsg/sigproc/util/typeresolution.py +17 -0
- ezmsg/sigproc/wavelets.py +187 -0
- ezmsg/sigproc/window.py +301 -117
- ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
- {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -2
- ezmsg/sigproc/synth.py +0 -411
- ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
- ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
- ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
- /ezmsg_sigproc-1.2.2.dist-info/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Coordinate space transformations for streaming data.
|
|
3
|
+
|
|
4
|
+
This module provides utilities and ezmsg nodes for transforming between
|
|
5
|
+
Cartesian (x, y) and polar (r, theta) coordinate systems.
|
|
6
|
+
|
|
7
|
+
.. note::
|
|
8
|
+
This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
|
|
9
|
+
enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from enum import Enum
|
|
13
|
+
from typing import Tuple
|
|
14
|
+
|
|
15
|
+
import ezmsg.core as ez
|
|
16
|
+
import numpy as np
|
|
17
|
+
import numpy.typing as npt
|
|
18
|
+
from array_api_compat import get_namespace, is_array_api_obj
|
|
19
|
+
from ezmsg.baseproc import (
|
|
20
|
+
BaseTransformer,
|
|
21
|
+
BaseTransformerUnit,
|
|
22
|
+
)
|
|
23
|
+
from ezmsg.util.messages.axisarray import AxisArray, replace
|
|
24
|
+
|
|
25
|
+
# -- Utility functions for coordinate transformations --
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _get_namespace_or_numpy(*args: npt.ArrayLike):
    """
    Pick the array namespace to compute with.

    Returns the array-API namespace of the first argument that is an array-API object;
    if none of the arguments are (e.g. plain Python scalars), NumPy is used.
    """
    return next((get_namespace(candidate) for candidate in args if is_array_api_obj(candidate)), np)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def polar2z(r: npt.ArrayLike, theta: npt.ArrayLike) -> npt.ArrayLike:
    """
    Combine a magnitude and an angle (radians) into the complex number ``r * exp(1j * theta)``.

    The array namespace is taken from the inputs, falling back to NumPy for scalars.
    """
    namespace = _get_namespace_or_numpy(r, theta)
    unit_phasor = namespace.exp(1j * theta)
    return r * unit_phasor
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def z2polar(z: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
    """
    Split a complex number into magnitude and angle.

    Returns:
        ``(r, theta)`` with ``r = |z|`` and ``theta = atan2(Im z, Re z)`` in radians.
    """
    namespace = _get_namespace_or_numpy(z)
    magnitude = namespace.abs(z)
    angle = namespace.atan2(namespace.imag(z), namespace.real(z))
    return magnitude, angle
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def cart2z(x: npt.ArrayLike, y: npt.ArrayLike) -> npt.ArrayLike:
    """Pack Cartesian components into the complex number ``x + 1j * y``."""
    imaginary_part = 1j * y
    return x + imaginary_part
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def z2cart(z: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
    """
    Split a complex number into its Cartesian components.

    Returns:
        ``(x, y)`` — the real and imaginary parts of ``z``.
    """
    namespace = _get_namespace_or_numpy(z)
    real_part = namespace.real(z)
    imag_part = namespace.imag(z)
    return real_part, imag_part
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def cart2pol(x: npt.ArrayLike, y: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
    """
    Convert Cartesian coordinates to polar coordinates.

    Returns:
        ``(r, theta)`` with theta in radians, computed via the complex representation.
    """
    z = cart2z(x, y)
    return z2polar(z)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def pol2cart(r: npt.ArrayLike, theta: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
    """
    Convert polar coordinates (theta in radians) to Cartesian coordinates.

    Returns:
        ``(x, y)``, computed via the complex representation.
    """
    z = polar2z(r, theta)
    return z2cart(z)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# -- ezmsg transformer classes --
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class CoordinateMode(str, Enum):
    """
    Direction of the conversion performed by :obj:`CoordinateSpacesTransformer`.

    Being a ``str`` enum, each member also compares equal to its plain string value.
    """

    CART2POL = "cart2pol"
    """Cartesian (x, y) in, polar (r, theta) out."""

    POL2CART = "pol2cart"
    """Polar (r, theta) in, Cartesian (x, y) out."""
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class CoordinateSpacesSettings(ez.Settings):
    """
    Settings for :obj:`CoordinateSpaces`.

    See :obj:`CoordinateSpacesTransformer` for a description of the conversion and the
    expected input layout.
    """

    mode: CoordinateMode = CoordinateMode.CART2POL
    """The transformation mode: 'cart2pol' or 'pol2cart'."""

    axis: str | None = None
    """
    The name of the axis containing the coordinate components.
    Defaults to the last axis. Must have exactly 2 elements (x,y or r,theta).
    """
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class CoordinateSpacesTransformer(BaseTransformer[CoordinateSpacesSettings, AxisArray, AxisArray]):
    """
    Transform between Cartesian and polar coordinate systems.

    The input must have exactly 2 elements along the specified axis:
    - For cart2pol: expects (x, y), outputs (r, theta)
    - For pol2cart: expects (r, theta), outputs (x, y)

    Angles are in radians. If the coordinate axis carries labels (it has a ``data``
    attribute), they are replaced by ``["r", "theta"]`` or ``["x", "y"]`` to match the output.

    Raises:
        ValueError: If ``settings.mode`` is not a valid :class:`CoordinateMode` value, or if
            the coordinate axis does not have exactly 2 elements.
    """

    def _process(self, message: AxisArray) -> AxisArray:
        # Normalise the mode up front so an unknown value (a typo, or "CART2POL") raises
        # instead of silently being treated as pol2cart.
        mode = CoordinateMode(self.settings.mode)

        xp = get_namespace(message.data)
        axis = self.settings.axis or message.dims[-1]
        axis_idx = message.get_axis_idx(axis)

        n_components = message.data.shape[axis_idx]
        if n_components != 2:
            raise ValueError(
                f"Coordinate transformation requires exactly 2 elements along axis '{axis}', "
                f"got {n_components}."
            )

        # Extract the two components by integer-indexing positions 0 and 1 along the
        # coordinate axis (which drops that axis from each component).
        slices_a = [slice(None)] * message.data.ndim
        slices_b = [slice(None)] * message.data.ndim
        slices_a[axis_idx] = 0
        slices_b[axis_idx] = 1
        component_a = message.data[tuple(slices_a)]
        component_b = message.data[tuple(slices_b)]

        if mode == CoordinateMode.CART2POL:
            out_a, out_b = cart2pol(component_a, component_b)  # (x, y) -> (r, theta)
            new_labels = ["r", "theta"]
        else:
            out_a, out_b = pol2cart(component_a, component_b)  # (r, theta) -> (x, y)
            new_labels = ["x", "y"]

        # Re-insert the coordinate axis at its original position.
        result = xp.stack([out_a, out_b], axis=axis_idx)

        # Labels are strings, so they stay NumPy arrays regardless of the data's namespace.
        axes = message.axes
        if axis in axes and hasattr(axes[axis], "data"):
            axes = {**axes, axis: replace(axes[axis], data=np.array(new_labels))}

        return replace(message, data=result, axes=axes)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class CoordinateSpaces(BaseTransformerUnit[CoordinateSpacesSettings, AxisArray, AxisArray, CoordinateSpacesTransformer]):
    """
    ezmsg Unit that runs :obj:`CoordinateSpacesTransformer` (Cartesian <-> polar conversion).

    Configure it with :obj:`CoordinateSpacesSettings` (``mode`` and ``axis``).
    """

    SETTINGS = CoordinateSpacesSettings
|
ezmsg/sigproc/decimate.py
CHANGED
|
@@ -1,37 +1,65 @@
|
|
|
1
|
-
import
|
|
2
|
-
|
|
3
|
-
import scipy.signal
|
|
1
|
+
import typing
|
|
4
2
|
|
|
3
|
+
import ezmsg.core as ez
|
|
4
|
+
from ezmsg.baseproc import BaseTransformerUnit
|
|
5
5
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
6
6
|
|
|
7
|
+
from .cheby import ChebyshevFilterSettings, ChebyshevFilterTransformer
|
|
7
8
|
from .downsample import Downsample, DownsampleSettings
|
|
8
|
-
from .filter import
|
|
9
|
+
from .filter import BACoeffs, SOSCoeffs
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ChebyForDecimateTransformer(ChebyshevFilterTransformer[BACoeffs | SOSCoeffs]):
    """
    Chebyshev filter transformer whose design step is aware of the decimation target.

    The design function it returns takes the input sampling rate ``fs``. It yields ``None``
    (no filter) when ``fs`` is unknown or when the implied downsampling factor
    ``int(fs / (2.5 * Wn))`` is below 2; otherwise it delegates to the parent class's design.
    """

    def get_design_function(
        self,
    ) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
        def cheby_opt_design_fun(fs: float | None) -> BACoeffs | SOSCoeffs | None:
            # Skip the design when the rate is unknown or there is nothing to decimate.
            # (Decimate.configure sets Wn = 0.4 * target_rate, making 2.5 * Wn the target rate.)
            if fs is None or int(fs / (2.5 * self.settings.Wn)) < 2:
                return None
            # Explicit two-argument super(): zero-arg super() is unavailable in this closure.
            parent_design = super(ChebyForDecimateTransformer, self).get_design_function()
            return parent_design(fs)

        return cheby_opt_design_fun
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class ChebyForDecimate(
    BaseTransformerUnit[ChebyshevFilterSettings, AxisArray, AxisArray, ChebyForDecimateTransformer]
):
    """ezmsg Unit running :obj:`ChebyForDecimateTransformer`, configured with :obj:`ChebyshevFilterSettings`."""

    SETTINGS = ChebyshevFilterSettings
|
|
9
35
|
|
|
10
36
|
|
|
11
37
|
class Decimate(ez.Collection):
|
|
12
|
-
|
|
38
|
+
"""
|
|
39
|
+
A :obj:`Collection` chaining a :obj:`Filter` node configured as a lowpass Chebyshev filter
|
|
40
|
+
and a :obj:`Downsample` node.
|
|
41
|
+
"""
|
|
42
|
+
|
|
43
|
+
SETTINGS = DownsampleSettings
|
|
13
44
|
|
|
14
45
|
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
15
46
|
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
16
47
|
|
|
17
|
-
FILTER =
|
|
48
|
+
FILTER = ChebyForDecimate()
|
|
18
49
|
DOWNSAMPLE = Downsample()
|
|
19
50
|
|
|
20
51
|
def configure(self) -> None:
|
|
52
|
+
cheby_settings = ChebyshevFilterSettings(
|
|
53
|
+
order=8,
|
|
54
|
+
ripple_tol=0.05,
|
|
55
|
+
Wn=0.4 * self.SETTINGS.target_rate,
|
|
56
|
+
btype="lowpass",
|
|
57
|
+
axis=self.SETTINGS.axis,
|
|
58
|
+
wn_hz=True,
|
|
59
|
+
)
|
|
60
|
+
self.FILTER.apply_settings(cheby_settings)
|
|
21
61
|
self.DOWNSAMPLE.apply_settings(self.SETTINGS)
|
|
22
62
|
|
|
23
|
-
if self.SETTINGS.factor < 1:
|
|
24
|
-
raise ValueError("Decimation factor must be >= 1 (no decimation")
|
|
25
|
-
elif self.SETTINGS.factor == 1:
|
|
26
|
-
filt = FilterCoefficients()
|
|
27
|
-
else:
|
|
28
|
-
# See scipy.signal.decimate for IIR Filter Condition
|
|
29
|
-
b, a = scipy.signal.cheby1(8, 0.05, 0.8 / self.SETTINGS.factor)
|
|
30
|
-
system = scipy.signal.dlti(b, a)
|
|
31
|
-
filt = FilterCoefficients(b=system.num, a=system.den) # type: ignore
|
|
32
|
-
|
|
33
|
-
self.FILTER.apply_settings(FilterSettings(filt=filt))
|
|
34
|
-
|
|
35
63
|
def network(self) -> ez.NetworkDefinition:
|
|
36
64
|
return (
|
|
37
65
|
(self.INPUT_SIGNAL, self.FILTER.INPUT_SIGNAL),
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
import ezmsg.core as ez
|
|
2
|
+
import numpy as np
|
|
3
|
+
import numpy.typing as npt
|
|
4
|
+
from ezmsg.baseproc import (
|
|
5
|
+
BaseStatefulTransformer,
|
|
6
|
+
BaseTransformerUnit,
|
|
7
|
+
processor_state,
|
|
8
|
+
)
|
|
9
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
10
|
+
from ezmsg.util.messages.util import replace
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class DenormalizeSettings(ez.Settings):
    """Settings for :obj:`DenormalizeTransformer` / :obj:`DenormalizeUnit`."""

    low_rate: float = 2.0
    """Low end (Hz) of the probable per-channel rate after denormalization."""

    high_rate: float = 40.0
    """High end (Hz) of the probable per-channel rate after denormalization."""

    distribution: str = "uniform"
    """How per-channel offsets are drawn: 'uniform', 'normal', or 'constant'."""
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@processor_state
class DenormalizeState:
    """Per-channel scaling values, drawn when the transformer state is reset."""

    gains: npt.NDArray | None = None
    """Multiplicative factor per channel (offset / 3.29)."""

    offsets: npt.NDArray | None = None
    """Additive offset per channel."""
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class DenormalizeTransformer(BaseStatefulTransformer[DenormalizeSettings, AxisArray, AxisArray, DenormalizeState]):
    """
    Scales data from a normalized distribution (mean=0, std=1) to a denormalized
    distribution using random per-channel offsets and gains designed to keep the
    99.9% CIs between 0 and 2x the offset.

    Offsets are drawn once per state reset from ``settings.distribution``, bounded by
    ``settings.low_rate`` and ``settings.high_rate``. Output values are clipped at 0.

    This is useful for simulating realistic firing rates from normalized data.

    Raises:
        ValueError: On reset, if ``settings.distribution`` is not one of
            'uniform', 'normal', or 'constant'.
    """

    def _reset_state(self, message: AxisArray) -> None:
        ax_ix = message.get_axis_idx("ch")
        nch = message.data.shape[ax_ix]
        # One value per channel, size 1 on every other axis, so it broadcasts against
        # message.data for any number of dimensions (not only 2-D).
        arr_size = tuple(nch if i == ax_ix else 1 for i in range(message.data.ndim))

        low = self.settings.low_rate
        high = self.settings.high_rate
        distribution = self.settings.distribution
        if distribution == "uniform":
            # Honour the configured bounds (previously hard-coded to 2.0 and 40.0).
            offsets = np.random.uniform(low, high, size=arr_size)
        elif distribution == "normal":
            # Mean at the midpoint; +/-3 std spans [low, high]; clip the tails to the bounds.
            offsets = np.random.normal(
                loc=(low + high) / 2.0,
                scale=(high - low) / 6.0,
                size=arr_size,
            )
            offsets = np.clip(offsets, a_min=low, a_max=high)
        elif distribution == "constant":
            offsets = np.full(shape=arr_size, fill_value=(low + high) / 2.0)
        else:
            raise ValueError(f"Invalid distribution: {self.settings.distribution}")

        # Input has std == 1. For a standard normal distribution, 99.9% of data lies within
        # +/- 3.29 std devs, so gain = offset / 3.29 keeps that interval within [0, 2 * offset].
        self.state.offsets = offsets
        self.state.gains = offsets / 3.29

    def _process(self, message: AxisArray) -> AxisArray:
        denorm = message.data * self.state.gains + self.state.offsets
        # Rates cannot be negative; clip the lower tail.
        return replace(
            message,
            data=np.clip(denorm, a_min=0.0, a_max=None),
        )
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class DenormalizeUnit(BaseTransformerUnit[DenormalizeSettings, AxisArray, AxisArray, DenormalizeTransformer]):
    """ezmsg Unit running :obj:`DenormalizeTransformer`, configured with :obj:`DenormalizeSettings`."""

    SETTINGS = DenormalizeSettings
|
ezmsg/sigproc/detrend.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
import scipy.signal as sps
|
|
2
|
+
from ezmsg.baseproc import BaseTransformerUnit
|
|
3
|
+
from ezmsg.util.messages.axisarray import AxisArray, replace
|
|
4
|
+
|
|
5
|
+
from ezmsg.sigproc.ewma import EWMASettings, EWMATransformer
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class DetrendTransformer(EWMATransformer):
    """
    Remove a slowly varying trend by subtracting an exponentially weighted moving
    average (EWMA) estimate of the mean.

    The running mean is computed with ``scipy.signal.lfilter`` using the smoothing factor
    ``alpha`` and filter state ``zi`` held in the transformer state; ``zi`` is written back
    after every message, so the estimate is continuous across messages.
    """

    def _process(self, message):
        axis_idx = message.get_axis_idx(self.settings.axis or message.dims[0])
        alpha = self._state.alpha
        # First-order IIR: mean[n] = alpha * x[n] + (1 - alpha) * mean[n - 1]
        numerator = [alpha]
        denominator = [1.0, alpha - 1.0]
        means, self._state.zi = sps.lfilter(
            numerator,
            denominator,
            message.data,
            axis=axis_idx,
            zi=self._state.zi,
        )
        detrended = message.data - means
        return replace(message, data=detrended)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class DetrendUnit(BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, DetrendTransformer]):
    """ezmsg Unit running :obj:`DetrendTransformer`; it reuses :obj:`EWMASettings` for configuration."""

    SETTINGS = EWMASettings
|
ezmsg/sigproc/diff.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Compute differences along an axis.
|
|
3
|
+
|
|
4
|
+
.. note::
|
|
5
|
+
This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
|
|
6
|
+
enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import ezmsg.core as ez
|
|
10
|
+
import numpy as np
|
|
11
|
+
import numpy.typing as npt
|
|
12
|
+
from array_api_compat import get_namespace
|
|
13
|
+
from ezmsg.baseproc import (
|
|
14
|
+
BaseStatefulTransformer,
|
|
15
|
+
BaseTransformerUnit,
|
|
16
|
+
processor_state,
|
|
17
|
+
)
|
|
18
|
+
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
|
|
19
|
+
from ezmsg.util.messages.util import replace
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class DiffSettings(ez.Settings):
    """Settings for :obj:`DiffTransformer` / :obj:`DiffUnit`."""

    axis: str | None = None
    """Axis to difference along. When None, ``DiffTransformer._process`` uses the message's first dimension."""

    scale_by_fs: bool = False
    """If True, divide differences by the sample spacing so the output is a derivative."""
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@processor_state
class DiffState:
    """Values carried from one message to the next by :obj:`DiffTransformer`."""

    last_dat: npt.NDArray | None = None
    """Length-1 slice (along the diff axis) holding the previous message's final sample."""

    last_time: float | None = None
    """Axis coordinate of the previous final sample; used only with scale_by_fs on a coordinate axis."""
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class DiffTransformer(BaseStatefulTransformer[DiffSettings, AxisArray, AxisArray, DiffState]):
    """
    First-order difference along an axis, continuous across messages.

    The final sample of each message is retained so that the first output sample of the
    next message is the difference across the message boundary; the output therefore has
    the same length along the axis as the input. With ``scale_by_fs`` the differences are
    divided by the sample spacing (the axis gain, or per-sample deltas of a coordinate
    axis), converting them into a derivative.
    """

    def _resolve_axis(self, message: AxisArray) -> str:
        """Return the configured axis name, defaulting to the message's first dimension."""
        return self.settings.axis or message.dims[0]

    def _hash_message(self, message: AxisArray) -> int:
        # Use the same axis fallback as _process; settings.axis may be None.
        ax_idx = message.get_axis_idx(self._resolve_axis(message))
        sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
        return hash((sample_shape, message.key))

    def _reset_state(self, message) -> None:
        axis = self._resolve_axis(message)
        ax_idx = message.get_axis_idx(axis)
        # Seed with the first sample so the very first output difference is 0.
        self.state.last_dat = slice_along_axis(message.data, slice(0, 1), axis=ax_idx)
        if self.settings.scale_by_fs:
            ax_info = message.get_axis(axis)
            if hasattr(ax_info, "data"):
                if len(ax_info.data) > 1:
                    # Extrapolate one sample back so the first dt equals the first spacing.
                    self.state.last_time = 2 * ax_info.data[0] - ax_info.data[1]
                else:
                    self.state.last_time = ax_info.data[0] - 0.001

    def _process(self, message: AxisArray) -> AxisArray:
        xp = get_namespace(message.data)
        axis = self._resolve_axis(message)
        ax_idx = message.get_axis_idx(axis)
        n_samples = message.data.shape[ax_idx]

        diffs = xp.diff(
            xp.concat((self.state.last_dat, message.data), axis=ax_idx),
            axis=ax_idx,
        )
        # Prepare last_dat for the next iteration. An empty message must not replace the
        # carried sample, otherwise the next output would be one sample short.
        if n_samples > 0:
            self.state.last_dat = slice_along_axis(message.data, slice(-1, None), axis=ax_idx)

        # Scale by fs if requested. This converts the diff to a derivative, e.g. diff of position becomes velocity.
        # Division is out-of-place so integer input is promoted instead of raising on in-place float division.
        if self.settings.scale_by_fs:
            ax_info = message.get_axis(axis)
            if hasattr(ax_info, "data"):
                # ax_info.data is typically numpy for metadata, so use np.diff here
                dt = np.diff(np.concatenate(([self.state.last_time], ax_info.data)))
                # Expand dt dims to match diffs
                exp_sl = (None,) * ax_idx + (Ellipsis,) + (None,) * (message.data.ndim - ax_idx - 1)
                diffs = diffs / xp.asarray(dt[exp_sl])
                if n_samples > 0:
                    self.state.last_time = ax_info.data[-1]  # For next iteration
            else:
                diffs = diffs / ax_info.gain

        return replace(message, data=diffs)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class DiffUnit(BaseTransformerUnit[DiffSettings, AxisArray, AxisArray, DiffTransformer]):
    """ezmsg Unit running :obj:`DiffTransformer`, configured with :obj:`DiffSettings`."""

    SETTINGS = DiffSettings
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def diff(axis: str = "time", scale_by_fs: bool = False) -> DiffTransformer:
    """
    Convenience constructor for a :obj:`DiffTransformer`.

    Args:
        axis: Name of the axis to difference along.
        scale_by_fs: If True, divide differences by the sample spacing.

    Returns:
        A new :obj:`DiffTransformer` with the given settings.
    """
    settings = DiffSettings(axis=axis, scale_by_fs=scale_by_fs)
    return DiffTransformer(settings)
|
ezmsg/sigproc/downsample.py
CHANGED
|
@@ -1,63 +1,111 @@
|
|
|
1
|
-
from dataclasses import replace
|
|
2
|
-
|
|
3
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
4
|
-
|
|
5
1
|
import ezmsg.core as ez
|
|
6
2
|
import numpy as np
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
3
|
+
from ezmsg.baseproc import (
|
|
4
|
+
BaseStatefulTransformer,
|
|
5
|
+
BaseTransformerUnit,
|
|
6
|
+
processor_state,
|
|
7
|
+
)
|
|
8
|
+
from ezmsg.util.messages.axisarray import (
|
|
9
|
+
AxisArray,
|
|
10
|
+
replace,
|
|
11
|
+
slice_along_axis,
|
|
11
12
|
)
|
|
12
13
|
|
|
13
14
|
|
|
14
15
|
class DownsampleSettings(ez.Settings):
    """
    Settings for :obj:`Downsample` node.

    Give either ``factor`` (takes precedence) or ``target_rate``; if neither is set the
    data pass through with a factor of 1.
    """

    axis: str = "time"
    """Name of the axis to downsample along."""

    target_rate: float | None = None
    """Requested output rate. The actual rate comes from the largest integer factor of the
    input rate whose result is at or above this target."""

    factor: int | None = None
    """Explicit integer downsampling factor; when set, ``target_rate`` is ignored."""
|
|
34
29
|
|
|
35
|
-
@ez.subscriber(INPUT_SETTINGS)
|
|
36
|
-
async def on_settings(self, msg: DownsampleSettings) -> None:
|
|
37
|
-
self.STATE.cur_settings = msg
|
|
38
30
|
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
raise ValueError("Downsample factor must be at least 1 (no downsampling)")
|
|
31
|
+
@processor_state
class DownsampleState:
    """Running state for :obj:`DownsampleTransformer`."""

    q: int = 0
    """Integer downsampling factor, resolved from the settings when the state is reset."""

    s_idx: int = 0
    """Position of the next message's first sample within the repeating 0..q-1 counter."""
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class DownsampleTransformer(BaseStatefulTransformer[DownsampleSettings, AxisArray, AxisArray, DownsampleState]):
    """
    Downsampled data simply comprise every `factor`th sample.
    This should only be used following appropriate lowpass filtering.
    If your pipeline does not already have lowpass filtering then consider
    using the :obj:`Decimate` collection instead.

    Sample selection is continuous across messages: a rotating counter (``s_idx``) records
    where each message starts relative to the kept samples.

    Raises:
        ValueError: On reset, if an explicit ``factor`` smaller than 1 is configured.
    """

    def _hash_message(self, message: AxisArray) -> int:
        # State is keyed on the sample period and stream key
        # (presumably the base class calls _reset_state when this hash changes).
        return hash((message.axes[self.settings.axis].gain, message.key))

    def _reset_state(self, message: AxisArray) -> None:
        axis_info = message.get_axis(self.settings.axis)

        if self.settings.factor is not None:
            q = self.settings.factor
            if q < 1:
                # q == 0 would divide by zero in the modulo counter and q < 0 would give a
                # non-positive output gain; reject both explicitly.
                raise ValueError("Downsample factor must be at least 1 (no downsampling)")
        elif self.settings.target_rate is None:
            q = 1
        else:
            # Largest integer factor whose output rate is still >= target_rate.
            q = int(1 / (axis_info.gain * self.settings.target_rate))
            if q < 1:
                ez.logger.warning(
                    f"Target rate {self.settings.target_rate} cannot be achieved with input rate of {1 / axis_info.gain}. "
                    "Setting factor to 1."
                )
                q = 1
        self._state.q = q
        self._state.s_idx = 0

    def _process(self, message: AxisArray) -> AxisArray:
        axis = self.settings.axis
        axis_info = message.get_axis(axis)
        axis_idx = message.get_axis_idx(axis)

        n_samples = message.data.shape[axis_idx]
        # Position of every input sample within the rotating 0..q-1 counter.
        samples = np.arange(self.state.s_idx, self.state.s_idx + n_samples) % self._state.q
        if n_samples > 0:
            # Update state for next iteration.
            self._state.s_idx = samples[-1] + 1

        # Keep only the samples that land on counter position 0.
        pub_samples = np.where(samples == 0)[0]
        if len(pub_samples) > 0:
            # Offset of the first kept sample, used to shift the axis offset.
            n_step = pub_samples[0].item()
            data_slice = pub_samples
        else:
            n_step = 0
            data_slice = slice(None, 0, None)
        msg_out = replace(
            message,
            data=slice_along_axis(message.data, data_slice, axis=axis_idx),
            axes={
                **message.axes,
                axis: replace(
                    axis_info,
                    gain=axis_info.gain * self._state.q,
                    offset=axis_info.offset + axis_info.gain * n_step,
                ),
            },
        )
        return msg_out
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class Downsample(BaseTransformerUnit[DownsampleSettings, AxisArray, AxisArray, DownsampleTransformer]):
    """ezmsg Unit running :obj:`DownsampleTransformer`, configured with :obj:`DownsampleSettings`."""

    SETTINGS = DownsampleSettings
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def downsample(
    axis: str = "time",
    target_rate: float | None = None,
    factor: int | None = None,
) -> DownsampleTransformer:
    """
    Convenience constructor for a :obj:`DownsampleTransformer`.

    Args:
        axis: Name of the axis to downsample along.
        target_rate: Requested output rate; ignored when ``factor`` is given.
        factor: Explicit integer downsampling factor.

    Returns:
        A new :obj:`DownsampleTransformer` with the given settings.
    """
    settings = DownsampleSettings(axis=axis, target_rate=target_rate, factor=factor)
    return DownsampleTransformer(settings)
|