ezmsg-sigproc 2.5.0__py3-none-any.whl → 2.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/activation.py +5 -11
- ezmsg/sigproc/adaptive_lattice_notch.py +11 -30
- ezmsg/sigproc/affinetransform.py +16 -42
- ezmsg/sigproc/aggregate.py +17 -34
- ezmsg/sigproc/bandpower.py +12 -20
- ezmsg/sigproc/base.py +141 -1276
- ezmsg/sigproc/butterworthfilter.py +8 -16
- ezmsg/sigproc/butterworthzerophase.py +7 -16
- ezmsg/sigproc/cheby.py +4 -10
- ezmsg/sigproc/combfilter.py +5 -8
- ezmsg/sigproc/coordinatespaces.py +142 -0
- ezmsg/sigproc/decimate.py +3 -7
- ezmsg/sigproc/denormalize.py +6 -11
- ezmsg/sigproc/detrend.py +3 -4
- ezmsg/sigproc/diff.py +8 -17
- ezmsg/sigproc/downsample.py +11 -20
- ezmsg/sigproc/ewma.py +11 -28
- ezmsg/sigproc/ewmfilter.py +1 -1
- ezmsg/sigproc/extract_axis.py +3 -4
- ezmsg/sigproc/fbcca.py +34 -59
- ezmsg/sigproc/filter.py +19 -45
- ezmsg/sigproc/filterbank.py +37 -74
- ezmsg/sigproc/filterbankdesign.py +7 -14
- ezmsg/sigproc/fir_hilbert.py +13 -30
- ezmsg/sigproc/fir_pmc.py +5 -10
- ezmsg/sigproc/firfilter.py +12 -14
- ezmsg/sigproc/gaussiansmoothing.py +5 -9
- ezmsg/sigproc/kaiser.py +11 -15
- ezmsg/sigproc/math/abs.py +4 -3
- ezmsg/sigproc/math/add.py +121 -0
- ezmsg/sigproc/math/clip.py +4 -1
- ezmsg/sigproc/math/difference.py +100 -36
- ezmsg/sigproc/math/invert.py +3 -3
- ezmsg/sigproc/math/log.py +5 -6
- ezmsg/sigproc/math/scale.py +2 -0
- ezmsg/sigproc/messages.py +1 -2
- ezmsg/sigproc/quantize.py +3 -6
- ezmsg/sigproc/resample.py +17 -38
- ezmsg/sigproc/rollingscaler.py +12 -37
- ezmsg/sigproc/sampler.py +19 -37
- ezmsg/sigproc/scaler.py +11 -22
- ezmsg/sigproc/signalinjector.py +7 -18
- ezmsg/sigproc/slicer.py +14 -34
- ezmsg/sigproc/spectral.py +3 -3
- ezmsg/sigproc/spectrogram.py +12 -19
- ezmsg/sigproc/spectrum.py +17 -38
- ezmsg/sigproc/transpose.py +12 -24
- ezmsg/sigproc/util/asio.py +25 -156
- ezmsg/sigproc/util/axisarray_buffer.py +12 -26
- ezmsg/sigproc/util/buffer.py +22 -43
- ezmsg/sigproc/util/message.py +17 -31
- ezmsg/sigproc/util/profile.py +23 -174
- ezmsg/sigproc/util/sparse.py +7 -15
- ezmsg/sigproc/util/typeresolution.py +17 -83
- ezmsg/sigproc/wavelets.py +10 -19
- ezmsg/sigproc/window.py +29 -83
- ezmsg_sigproc-2.7.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.7.0.dist-info/RECORD +64 -0
- ezmsg/sigproc/synth.py +0 -774
- ezmsg_sigproc-2.5.0.dist-info/METADATA +0 -72
- ezmsg_sigproc-2.5.0.dist-info/RECORD +0 -63
- {ezmsg_sigproc-2.5.0.dist-info → ezmsg_sigproc-2.7.0.dist-info}/WHEEL +0 -0
- /ezmsg_sigproc-2.5.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.7.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/butterworthfilter.py
CHANGED
|
@@ -5,11 +5,11 @@ import scipy.signal
|
|
|
5
5
|
from scipy.signal import normalize
|
|
6
6
|
|
|
7
7
|
from .filter import (
|
|
8
|
-
FilterBaseSettings,
|
|
9
8
|
BACoeffs,
|
|
10
|
-
SOSCoeffs,
|
|
11
|
-
FilterByDesignTransformer,
|
|
12
9
|
BaseFilterByDesignTransformerUnit,
|
|
10
|
+
FilterBaseSettings,
|
|
11
|
+
FilterByDesignTransformer,
|
|
12
|
+
SOSCoeffs,
|
|
13
13
|
)
|
|
14
14
|
|
|
15
15
|
|
|
@@ -27,14 +27,14 @@ class ButterworthFilterSettings(FilterBaseSettings):
|
|
|
27
27
|
"""
|
|
28
28
|
Cuton frequency (Hz). If `cutoff` is not specified then this is the highpass corner. Otherwise,
|
|
29
29
|
if this is lower than `cutoff` then this is the beginning of the bandpass
|
|
30
|
-
or if this is greater than `cutoff` then this is the end of the bandstop.
|
|
30
|
+
or if this is greater than `cutoff` then this is the end of the bandstop.
|
|
31
31
|
"""
|
|
32
32
|
|
|
33
33
|
cutoff: float | None = None
|
|
34
34
|
"""
|
|
35
35
|
Cutoff frequency (Hz). If `cuton` is not specified then this is the lowpass corner. Otherwise,
|
|
36
36
|
if this is greater than `cuton` then this is the end of the bandpass,
|
|
37
|
-
or if this is less than `cuton` then this is the beginning of the bandstop.
|
|
37
|
+
or if this is less than `cuton` then this is the beginning of the bandstop.
|
|
38
38
|
"""
|
|
39
39
|
|
|
40
40
|
wn_hz: bool = True
|
|
@@ -96,9 +96,7 @@ def butter_design_fun(
|
|
|
96
96
|
"""
|
|
97
97
|
coefs = None
|
|
98
98
|
if order > 0:
|
|
99
|
-
btype, cutoffs = ButterworthFilterSettings(
|
|
100
|
-
order=order, cuton=cuton, cutoff=cutoff
|
|
101
|
-
).filter_specs()
|
|
99
|
+
btype, cutoffs = ButterworthFilterSettings(order=order, cuton=cuton, cutoff=cutoff).filter_specs()
|
|
102
100
|
coefs = scipy.signal.butter(
|
|
103
101
|
order,
|
|
104
102
|
Wn=cutoffs,
|
|
@@ -111,9 +109,7 @@ def butter_design_fun(
|
|
|
111
109
|
return coefs
|
|
112
110
|
|
|
113
111
|
|
|
114
|
-
class ButterworthFilterTransformer(
|
|
115
|
-
FilterByDesignTransformer[ButterworthFilterSettings, BACoeffs | SOSCoeffs]
|
|
116
|
-
):
|
|
112
|
+
class ButterworthFilterTransformer(FilterByDesignTransformer[ButterworthFilterSettings, BACoeffs | SOSCoeffs]):
|
|
117
113
|
def get_design_function(
|
|
118
114
|
self,
|
|
119
115
|
) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
|
|
@@ -127,11 +123,7 @@ class ButterworthFilterTransformer(
|
|
|
127
123
|
)
|
|
128
124
|
|
|
129
125
|
|
|
130
|
-
class ButterworthFilter(
|
|
131
|
-
BaseFilterByDesignTransformerUnit[
|
|
132
|
-
ButterworthFilterSettings, ButterworthFilterTransformer
|
|
133
|
-
]
|
|
134
|
-
):
|
|
126
|
+
class ButterworthFilter(BaseFilterByDesignTransformerUnit[ButterworthFilterSettings, ButterworthFilterTransformer]):
|
|
135
127
|
SETTINGS = ButterworthFilterSettings
|
|
136
128
|
|
|
137
129
|
|
ezmsg/sigproc/butterworthzerophase.py
CHANGED
|
@@ -4,6 +4,9 @@ import typing
|
|
|
4
4
|
import ezmsg.core as ez
|
|
5
5
|
import numpy as np
|
|
6
6
|
import scipy.signal
|
|
7
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
8
|
+
from ezmsg.util.messages.util import replace
|
|
9
|
+
|
|
7
10
|
from ezmsg.sigproc.base import SettingsType
|
|
8
11
|
from ezmsg.sigproc.butterworthfilter import ButterworthFilterSettings, butter_design_fun
|
|
9
12
|
from ezmsg.sigproc.filter import (
|
|
@@ -12,8 +15,6 @@ from ezmsg.sigproc.filter import (
|
|
|
12
15
|
FilterByDesignTransformer,
|
|
13
16
|
SOSCoeffs,
|
|
14
17
|
)
|
|
15
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
16
|
-
from ezmsg.util.messages.util import replace
|
|
17
18
|
|
|
18
19
|
|
|
19
20
|
class ButterworthZeroPhaseSettings(ButterworthFilterSettings):
|
|
@@ -34,9 +35,7 @@ class ButterworthZeroPhaseSettings(ButterworthFilterSettings):
|
|
|
34
35
|
"""
|
|
35
36
|
|
|
36
37
|
|
|
37
|
-
class ButterworthZeroPhaseTransformer(
|
|
38
|
-
FilterByDesignTransformer[ButterworthZeroPhaseSettings, BACoeffs | SOSCoeffs]
|
|
39
|
-
):
|
|
38
|
+
class ButterworthZeroPhaseTransformer(FilterByDesignTransformer[ButterworthZeroPhaseSettings, BACoeffs | SOSCoeffs]):
|
|
40
39
|
"""Zero-phase (filtfilt) Butterworth using your design function."""
|
|
41
40
|
|
|
42
41
|
def get_design_function(
|
|
@@ -51,9 +50,7 @@ class ButterworthZeroPhaseTransformer(
|
|
|
51
50
|
wn_hz=self.settings.wn_hz,
|
|
52
51
|
)
|
|
53
52
|
|
|
54
|
-
def update_settings(
|
|
55
|
-
self, new_settings: typing.Optional[SettingsType] = None, **kwargs
|
|
56
|
-
) -> None:
|
|
53
|
+
def update_settings(self, new_settings: typing.Optional[SettingsType] = None, **kwargs) -> None:
|
|
57
54
|
"""
|
|
58
55
|
Update settings and mark that filter coefficients need to be recalculated.
|
|
59
56
|
|
|
@@ -91,11 +88,7 @@ class ButterworthZeroPhaseTransformer(
|
|
|
91
88
|
self._fs_cache = fs
|
|
92
89
|
self.state.needs_redesign = False
|
|
93
90
|
|
|
94
|
-
if (
|
|
95
|
-
self._coefs_cache is None
|
|
96
|
-
or self.settings.order <= 0
|
|
97
|
-
or message.data.size <= 0
|
|
98
|
-
):
|
|
91
|
+
if self._coefs_cache is None or self.settings.order <= 0 or message.data.size <= 0:
|
|
99
92
|
return message
|
|
100
93
|
|
|
101
94
|
x = message.data
|
|
@@ -125,8 +118,6 @@ class ButterworthZeroPhaseTransformer(
|
|
|
125
118
|
|
|
126
119
|
|
|
127
120
|
class ButterworthZeroPhase(
|
|
128
|
-
BaseFilterByDesignTransformerUnit[
|
|
129
|
-
ButterworthZeroPhaseSettings, ButterworthZeroPhaseTransformer
|
|
130
|
-
]
|
|
121
|
+
BaseFilterByDesignTransformerUnit[ButterworthZeroPhaseSettings, ButterworthZeroPhaseTransformer]
|
|
131
122
|
):
|
|
132
123
|
SETTINGS = ButterworthZeroPhaseSettings
|
ezmsg/sigproc/cheby.py
CHANGED
|
@@ -5,11 +5,11 @@ import scipy.signal
|
|
|
5
5
|
from scipy.signal import normalize
|
|
6
6
|
|
|
7
7
|
from .filter import (
|
|
8
|
+
BACoeffs,
|
|
9
|
+
BaseFilterByDesignTransformerUnit,
|
|
8
10
|
FilterBaseSettings,
|
|
9
11
|
FilterByDesignTransformer,
|
|
10
|
-
BACoeffs,
|
|
11
12
|
SOSCoeffs,
|
|
12
|
-
BaseFilterByDesignTransformerUnit,
|
|
13
13
|
)
|
|
14
14
|
|
|
15
15
|
|
|
@@ -104,9 +104,7 @@ def cheby_design_fun(
|
|
|
104
104
|
return coefs
|
|
105
105
|
|
|
106
106
|
|
|
107
|
-
class ChebyshevFilterTransformer(
|
|
108
|
-
FilterByDesignTransformer[ChebyshevFilterSettings, BACoeffs | SOSCoeffs]
|
|
109
|
-
):
|
|
107
|
+
class ChebyshevFilterTransformer(FilterByDesignTransformer[ChebyshevFilterSettings, BACoeffs | SOSCoeffs]):
|
|
110
108
|
def get_design_function(
|
|
111
109
|
self,
|
|
112
110
|
) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
|
|
@@ -123,9 +121,5 @@ class ChebyshevFilterTransformer(
|
|
|
123
121
|
)
|
|
124
122
|
|
|
125
123
|
|
|
126
|
-
class ChebyshevFilter(
|
|
127
|
-
BaseFilterByDesignTransformerUnit[
|
|
128
|
-
ChebyshevFilterSettings, ChebyshevFilterTransformer
|
|
129
|
-
]
|
|
130
|
-
):
|
|
124
|
+
class ChebyshevFilter(BaseFilterByDesignTransformerUnit[ChebyshevFilterSettings, ChebyshevFilterTransformer]):
|
|
131
125
|
SETTINGS = ChebyshevFilterSettings
|
ezmsg/sigproc/combfilter.py
CHANGED
|
@@ -1,15 +1,16 @@
|
|
|
1
1
|
import functools
|
|
2
2
|
import typing
|
|
3
|
+
|
|
3
4
|
import numpy as np
|
|
4
5
|
import scipy.signal
|
|
5
6
|
from scipy.signal import normalize
|
|
6
7
|
|
|
7
8
|
from .filter import (
|
|
9
|
+
BACoeffs,
|
|
10
|
+
BaseFilterByDesignTransformerUnit,
|
|
8
11
|
FilterBaseSettings,
|
|
9
12
|
FilterByDesignTransformer,
|
|
10
|
-
BACoeffs,
|
|
11
13
|
SOSCoeffs,
|
|
12
|
-
BaseFilterByDesignTransformerUnit,
|
|
13
14
|
)
|
|
14
15
|
|
|
15
16
|
|
|
@@ -103,9 +104,7 @@ def comb_design_fun(
|
|
|
103
104
|
return combined_sos
|
|
104
105
|
|
|
105
106
|
|
|
106
|
-
class CombFilterTransformer(
|
|
107
|
-
FilterByDesignTransformer[CombFilterSettings, BACoeffs | SOSCoeffs]
|
|
108
|
-
):
|
|
107
|
+
class CombFilterTransformer(FilterByDesignTransformer[CombFilterSettings, BACoeffs | SOSCoeffs]):
|
|
109
108
|
def get_design_function(
|
|
110
109
|
self,
|
|
111
110
|
) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
|
|
@@ -120,9 +119,7 @@ class CombFilterTransformer(
|
|
|
120
119
|
)
|
|
121
120
|
|
|
122
121
|
|
|
123
|
-
class CombFilterUnit(
|
|
124
|
-
BaseFilterByDesignTransformerUnit[CombFilterSettings, CombFilterTransformer]
|
|
125
|
-
):
|
|
122
|
+
class CombFilterUnit(BaseFilterByDesignTransformerUnit[CombFilterSettings, CombFilterTransformer]):
|
|
126
123
|
SETTINGS = CombFilterSettings
|
|
127
124
|
|
|
128
125
|
|
ezmsg/sigproc/coordinatespaces.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Coordinate space transformations for streaming data.
|
|
3
|
+
|
|
4
|
+
This module provides utilities and ezmsg nodes for transforming between
|
|
5
|
+
Cartesian (x, y) and polar (r, theta) coordinate systems.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from enum import Enum
|
|
9
|
+
from typing import Tuple
|
|
10
|
+
|
|
11
|
+
import ezmsg.core as ez
|
|
12
|
+
import numpy as np
|
|
13
|
+
import numpy.typing as npt
|
|
14
|
+
from ezmsg.baseproc import (
|
|
15
|
+
BaseTransformer,
|
|
16
|
+
BaseTransformerUnit,
|
|
17
|
+
)
|
|
18
|
+
from ezmsg.util.messages.axisarray import AxisArray, replace
|
|
19
|
+
|
|
20
|
+
# -- Utility functions for coordinate transformations --
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def polar2z(r: npt.ArrayLike, theta: npt.ArrayLike) -> npt.ArrayLike:
|
|
24
|
+
"""Convert polar coordinates to complex number representation."""
|
|
25
|
+
return r * np.exp(1j * theta)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def z2polar(z: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
|
|
29
|
+
"""Convert complex number to polar coordinates (r, theta)."""
|
|
30
|
+
return np.abs(z), np.angle(z)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def cart2z(x: npt.ArrayLike, y: npt.ArrayLike) -> npt.ArrayLike:
|
|
34
|
+
"""Convert Cartesian coordinates to complex number representation."""
|
|
35
|
+
return x + 1j * y
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def z2cart(z: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
|
|
39
|
+
"""Convert complex number to Cartesian coordinates (x, y)."""
|
|
40
|
+
return np.real(z), np.imag(z)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def cart2pol(x: npt.ArrayLike, y: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
|
|
44
|
+
"""Convert Cartesian coordinates (x, y) to polar coordinates (r, theta)."""
|
|
45
|
+
return z2polar(cart2z(x, y))
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def pol2cart(r: npt.ArrayLike, theta: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
|
|
49
|
+
"""Convert polar coordinates (r, theta) to Cartesian coordinates (x, y)."""
|
|
50
|
+
return z2cart(polar2z(r, theta))
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# -- ezmsg transformer classes --
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class CoordinateMode(str, Enum):
|
|
57
|
+
"""Transformation mode for coordinate conversion."""
|
|
58
|
+
|
|
59
|
+
CART2POL = "cart2pol"
|
|
60
|
+
"""Convert Cartesian (x, y) to polar (r, theta)."""
|
|
61
|
+
|
|
62
|
+
POL2CART = "pol2cart"
|
|
63
|
+
"""Convert polar (r, theta) to Cartesian (x, y)."""
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class CoordinateSpacesSettings(ez.Settings):
|
|
67
|
+
"""
|
|
68
|
+
Settings for :obj:`CoordinateSpaces`.
|
|
69
|
+
|
|
70
|
+
See :obj:`coordinate_spaces` for argument details.
|
|
71
|
+
"""
|
|
72
|
+
|
|
73
|
+
mode: CoordinateMode = CoordinateMode.CART2POL
|
|
74
|
+
"""The transformation mode: 'cart2pol' or 'pol2cart'."""
|
|
75
|
+
|
|
76
|
+
axis: str | None = None
|
|
77
|
+
"""
|
|
78
|
+
The name of the axis containing the coordinate components.
|
|
79
|
+
Defaults to the last axis. Must have exactly 2 elements (x,y or r,theta).
|
|
80
|
+
"""
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class CoordinateSpacesTransformer(BaseTransformer[CoordinateSpacesSettings, AxisArray, AxisArray]):
|
|
84
|
+
"""
|
|
85
|
+
Transform between Cartesian and polar coordinate systems.
|
|
86
|
+
|
|
87
|
+
The input must have exactly 2 elements along the specified axis:
|
|
88
|
+
- For cart2pol: expects (x, y), outputs (r, theta)
|
|
89
|
+
- For pol2cart: expects (r, theta), outputs (x, y)
|
|
90
|
+
"""
|
|
91
|
+
|
|
92
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
93
|
+
axis = self.settings.axis or message.dims[-1]
|
|
94
|
+
axis_idx = message.get_axis_idx(axis)
|
|
95
|
+
|
|
96
|
+
if message.data.shape[axis_idx] != 2:
|
|
97
|
+
raise ValueError(
|
|
98
|
+
f"Coordinate transformation requires exactly 2 elements along axis '{axis}', "
|
|
99
|
+
f"got {message.data.shape[axis_idx]}."
|
|
100
|
+
)
|
|
101
|
+
|
|
102
|
+
# Extract components along the specified axis
|
|
103
|
+
slices_a = [slice(None)] * message.data.ndim
|
|
104
|
+
slices_b = [slice(None)] * message.data.ndim
|
|
105
|
+
slices_a[axis_idx] = 0
|
|
106
|
+
slices_b[axis_idx] = 1
|
|
107
|
+
|
|
108
|
+
component_a = message.data[tuple(slices_a)]
|
|
109
|
+
component_b = message.data[tuple(slices_b)]
|
|
110
|
+
|
|
111
|
+
if self.settings.mode == CoordinateMode.CART2POL:
|
|
112
|
+
# Input: x, y -> Output: r, theta
|
|
113
|
+
out_a, out_b = cart2pol(component_a, component_b)
|
|
114
|
+
else:
|
|
115
|
+
# Input: r, theta -> Output: x, y
|
|
116
|
+
out_a, out_b = pol2cart(component_a, component_b)
|
|
117
|
+
|
|
118
|
+
# Stack results back along the same axis
|
|
119
|
+
result = np.stack([out_a, out_b], axis=axis_idx)
|
|
120
|
+
|
|
121
|
+
# Update axis labels if present
|
|
122
|
+
axes = message.axes
|
|
123
|
+
if axis in axes and hasattr(axes[axis], "data"):
|
|
124
|
+
if self.settings.mode == CoordinateMode.CART2POL:
|
|
125
|
+
new_labels = np.array(["r", "theta"])
|
|
126
|
+
else:
|
|
127
|
+
new_labels = np.array(["x", "y"])
|
|
128
|
+
axes = {**axes, axis: replace(axes[axis], data=new_labels)}
|
|
129
|
+
|
|
130
|
+
return replace(message, data=result, axes=axes)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class CoordinateSpaces(
|
|
134
|
+
BaseTransformerUnit[CoordinateSpacesSettings, AxisArray, AxisArray, CoordinateSpacesTransformer]
|
|
135
|
+
):
|
|
136
|
+
"""
|
|
137
|
+
Unit for transforming between Cartesian and polar coordinate systems.
|
|
138
|
+
|
|
139
|
+
See :obj:`CoordinateSpacesSettings` for configuration options.
|
|
140
|
+
"""
|
|
141
|
+
|
|
142
|
+
SETTINGS = CoordinateSpacesSettings
|
ezmsg/sigproc/decimate.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
import typing
|
|
2
2
|
|
|
3
3
|
import ezmsg.core as ez
|
|
4
|
+
from ezmsg.baseproc import BaseTransformerUnit
|
|
4
5
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
5
6
|
|
|
6
|
-
from .base import BaseTransformerUnit
|
|
7
|
-
from .cheby import ChebyshevFilterTransformer, ChebyshevFilterSettings
|
|
7
|
+
from .cheby import ChebyshevFilterSettings, ChebyshevFilterTransformer
|
|
8
8
|
from .downsample import Downsample, DownsampleSettings
|
|
9
9
|
from .filter import BACoeffs, SOSCoeffs
|
|
10
10
|
|
|
@@ -30,11 +30,7 @@ class ChebyForDecimateTransformer(ChebyshevFilterTransformer[BACoeffs | SOSCoeff
|
|
|
30
30
|
return cheby_opt_design_fun
|
|
31
31
|
|
|
32
32
|
|
|
33
|
-
class ChebyForDecimate(
|
|
34
|
-
BaseTransformerUnit[
|
|
35
|
-
ChebyshevFilterSettings, AxisArray, AxisArray, ChebyForDecimateTransformer
|
|
36
|
-
]
|
|
37
|
-
):
|
|
33
|
+
class ChebyForDecimate(BaseTransformerUnit[ChebyshevFilterSettings, AxisArray, AxisArray, ChebyForDecimateTransformer]):
|
|
38
34
|
SETTINGS = ChebyshevFilterSettings
|
|
39
35
|
|
|
40
36
|
|
ezmsg/sigproc/denormalize.py
CHANGED
|
@@ -1,13 +1,14 @@
|
|
|
1
1
|
import ezmsg.core as ez
|
|
2
2
|
import numpy as np
|
|
3
3
|
import numpy.typing as npt
|
|
4
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
5
|
+
from ezmsg.util.messages.util import replace
|
|
6
|
+
|
|
4
7
|
from ezmsg.sigproc.base import (
|
|
5
|
-
BaseTransformerUnit,
|
|
6
8
|
BaseStatefulTransformer,
|
|
9
|
+
BaseTransformerUnit,
|
|
7
10
|
processor_state,
|
|
8
11
|
)
|
|
9
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
10
|
-
from ezmsg.util.messages.util import replace
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
class DenormalizeSettings(ez.Settings):
|
|
@@ -27,9 +28,7 @@ class DenormalizeState:
|
|
|
27
28
|
offsets: npt.NDArray | None = None
|
|
28
29
|
|
|
29
30
|
|
|
30
|
-
class DenormalizeTransformer(
|
|
31
|
-
BaseStatefulTransformer[DenormalizeSettings, AxisArray, AxisArray, DenormalizeState]
|
|
32
|
-
):
|
|
31
|
+
class DenormalizeTransformer(BaseStatefulTransformer[DenormalizeSettings, AxisArray, AxisArray, DenormalizeState]):
|
|
33
32
|
"""
|
|
34
33
|
Scales data from a normalized distribution (mean=0, std=1) to a denormalized
|
|
35
34
|
distribution using random per-channel offsets and gains designed to keep the
|
|
@@ -76,9 +75,5 @@ class DenormalizeTransformer(
|
|
|
76
75
|
)
|
|
77
76
|
|
|
78
77
|
|
|
79
|
-
class DenormalizeUnit(
|
|
80
|
-
BaseTransformerUnit[
|
|
81
|
-
DenormalizeSettings, AxisArray, AxisArray, DenormalizeTransformer
|
|
82
|
-
]
|
|
83
|
-
):
|
|
78
|
+
class DenormalizeUnit(BaseTransformerUnit[DenormalizeSettings, AxisArray, AxisArray, DenormalizeTransformer]):
|
|
84
79
|
SETTINGS = DenormalizeSettings
|
ezmsg/sigproc/detrend.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
import scipy.signal as sps
|
|
2
2
|
from ezmsg.util.messages.axisarray import AxisArray, replace
|
|
3
|
-
|
|
3
|
+
|
|
4
4
|
from ezmsg.sigproc.base import BaseTransformerUnit
|
|
5
|
+
from ezmsg.sigproc.ewma import EWMASettings, EWMATransformer
|
|
5
6
|
|
|
6
7
|
|
|
7
8
|
class DetrendTransformer(EWMATransformer):
|
|
@@ -23,7 +24,5 @@ class DetrendTransformer(EWMATransformer):
|
|
|
23
24
|
return replace(message, data=message.data - means)
|
|
24
25
|
|
|
25
26
|
|
|
26
|
-
class DetrendUnit(
|
|
27
|
-
BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, DetrendTransformer]
|
|
28
|
-
):
|
|
27
|
+
class DetrendUnit(BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, DetrendTransformer]):
|
|
29
28
|
SETTINGS = EWMASettings
|
ezmsg/sigproc/diff.py
CHANGED
|
@@ -1,13 +1,14 @@
|
|
|
1
1
|
import ezmsg.core as ez
|
|
2
2
|
import numpy as np
|
|
3
3
|
import numpy.typing as npt
|
|
4
|
+
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
|
|
5
|
+
from ezmsg.util.messages.util import replace
|
|
6
|
+
|
|
4
7
|
from ezmsg.sigproc.base import (
|
|
8
|
+
BaseStatefulTransformer,
|
|
5
9
|
BaseTransformerUnit,
|
|
6
10
|
processor_state,
|
|
7
|
-
BaseStatefulTransformer,
|
|
8
11
|
)
|
|
9
|
-
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
|
|
10
|
-
from ezmsg.util.messages.util import replace
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
class DiffSettings(ez.Settings):
|
|
@@ -21,9 +22,7 @@ class DiffState:
|
|
|
21
22
|
last_time: float | None = None
|
|
22
23
|
|
|
23
24
|
|
|
24
|
-
class DiffTransformer(
|
|
25
|
-
BaseStatefulTransformer[DiffSettings, AxisArray, AxisArray, DiffState]
|
|
26
|
-
):
|
|
25
|
+
class DiffTransformer(BaseStatefulTransformer[DiffSettings, AxisArray, AxisArray, DiffState]):
|
|
27
26
|
def _hash_message(self, message: AxisArray) -> int:
|
|
28
27
|
ax_idx = message.get_axis_idx(self.settings.axis)
|
|
29
28
|
sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
|
|
@@ -49,20 +48,14 @@ class DiffTransformer(
|
|
|
49
48
|
axis=ax_idx,
|
|
50
49
|
)
|
|
51
50
|
# Prepare last_dat for next iteration
|
|
52
|
-
self.state.last_dat = slice_along_axis(
|
|
53
|
-
message.data, slice(-1, None), axis=ax_idx
|
|
54
|
-
)
|
|
51
|
+
self.state.last_dat = slice_along_axis(message.data, slice(-1, None), axis=ax_idx)
|
|
55
52
|
# Scale by fs if requested. This convers the diff to a derivative. e.g., diff of position becomes velocity.
|
|
56
53
|
if self.settings.scale_by_fs:
|
|
57
54
|
ax_info = message.get_axis(axis)
|
|
58
55
|
if hasattr(ax_info, "data"):
|
|
59
56
|
dt = np.diff(np.concatenate(([self.state.last_time], ax_info.data)))
|
|
60
57
|
# Expand dt dims to match diffs
|
|
61
|
-
exp_sl = (
|
|
62
|
-
(None,) * ax_idx
|
|
63
|
-
+ (Ellipsis,)
|
|
64
|
-
+ (None,) * (message.data.ndim - ax_idx - 1)
|
|
65
|
-
)
|
|
58
|
+
exp_sl = (None,) * ax_idx + (Ellipsis,) + (None,) * (message.data.ndim - ax_idx - 1)
|
|
66
59
|
diffs /= dt[exp_sl]
|
|
67
60
|
self.state.last_time = ax_info.data[-1] # For next iteration
|
|
68
61
|
else:
|
|
@@ -71,9 +64,7 @@ class DiffTransformer(
|
|
|
71
64
|
return replace(message, data=diffs)
|
|
72
65
|
|
|
73
66
|
|
|
74
|
-
class DiffUnit(
|
|
75
|
-
BaseTransformerUnit[DiffSettings, AxisArray, AxisArray, DiffTransformer]
|
|
76
|
-
):
|
|
67
|
+
class DiffUnit(BaseTransformerUnit[DiffSettings, AxisArray, AxisArray, DiffTransformer]):
|
|
77
68
|
SETTINGS = DiffSettings
|
|
78
69
|
|
|
79
70
|
|
ezmsg/sigproc/downsample.py
CHANGED
|
@@ -1,16 +1,15 @@
|
|
|
1
|
-
import numpy as np
|
|
2
|
-
from ezmsg.util.messages.axisarray import (
|
|
3
|
-
AxisArray,
|
|
4
|
-
slice_along_axis,
|
|
5
|
-
replace,
|
|
6
|
-
)
|
|
7
1
|
import ezmsg.core as ez
|
|
8
|
-
|
|
9
|
-
from .base import (
|
|
2
|
+
import numpy as np
|
|
3
|
+
from ezmsg.baseproc import (
|
|
10
4
|
BaseStatefulTransformer,
|
|
11
5
|
BaseTransformerUnit,
|
|
12
6
|
processor_state,
|
|
13
7
|
)
|
|
8
|
+
from ezmsg.util.messages.axisarray import (
|
|
9
|
+
AxisArray,
|
|
10
|
+
replace,
|
|
11
|
+
slice_along_axis,
|
|
12
|
+
)
|
|
14
13
|
|
|
15
14
|
|
|
16
15
|
class DownsampleSettings(ez.Settings):
|
|
@@ -38,9 +37,7 @@ class DownsampleState:
|
|
|
38
37
|
"""Index of the next msg's first sample into the virtual rotating ds_factor counter."""
|
|
39
38
|
|
|
40
39
|
|
|
41
|
-
class DownsampleTransformer(
|
|
42
|
-
BaseStatefulTransformer[DownsampleSettings, AxisArray, AxisArray, DownsampleState]
|
|
43
|
-
):
|
|
40
|
+
class DownsampleTransformer(BaseStatefulTransformer[DownsampleSettings, AxisArray, AxisArray, DownsampleState]):
|
|
44
41
|
"""
|
|
45
42
|
Downsampled data simply comprise every `factor`th sample.
|
|
46
43
|
This should only be used following appropriate lowpass filtering.
|
|
@@ -75,9 +72,7 @@ class DownsampleTransformer(
|
|
|
75
72
|
axis_idx = message.get_axis_idx(axis)
|
|
76
73
|
|
|
77
74
|
n_samples = message.data.shape[axis_idx]
|
|
78
|
-
samples = (
|
|
79
|
-
np.arange(self.state.s_idx, self.state.s_idx + n_samples) % self._state.q
|
|
80
|
-
)
|
|
75
|
+
samples = np.arange(self.state.s_idx, self.state.s_idx + n_samples) % self._state.q
|
|
81
76
|
if n_samples > 0:
|
|
82
77
|
# Update state for next iteration.
|
|
83
78
|
self._state.s_idx = samples[-1] + 1
|
|
@@ -104,9 +99,7 @@ class DownsampleTransformer(
|
|
|
104
99
|
return msg_out
|
|
105
100
|
|
|
106
101
|
|
|
107
|
-
class Downsample(
|
|
108
|
-
BaseTransformerUnit[DownsampleSettings, AxisArray, AxisArray, DownsampleTransformer]
|
|
109
|
-
):
|
|
102
|
+
class Downsample(BaseTransformerUnit[DownsampleSettings, AxisArray, AxisArray, DownsampleTransformer]):
|
|
110
103
|
SETTINGS = DownsampleSettings
|
|
111
104
|
|
|
112
105
|
|
|
@@ -115,6 +108,4 @@ def downsample(
|
|
|
115
108
|
target_rate: float | None = None,
|
|
116
109
|
factor: int | None = None,
|
|
117
110
|
) -> DownsampleTransformer:
|
|
118
|
-
return DownsampleTransformer(
|
|
119
|
-
DownsampleSettings(axis=axis, target_rate=target_rate, factor=factor)
|
|
120
|
-
)
|
|
111
|
+
return DownsampleTransformer(DownsampleSettings(axis=axis, target_rate=target_rate, factor=factor))
|
ezmsg/sigproc/ewma.py
CHANGED
|
@@ -1,15 +1,14 @@
|
|
|
1
|
-
from dataclasses import field
|
|
2
1
|
import functools
|
|
2
|
+
from dataclasses import field
|
|
3
3
|
|
|
4
|
+
import ezmsg.core as ez
|
|
4
5
|
import numpy as np
|
|
5
6
|
import numpy.typing as npt
|
|
6
7
|
import scipy.signal as sps
|
|
7
|
-
|
|
8
|
+
from ezmsg.baseproc import BaseStatefulTransformer, BaseTransformerUnit, processor_state
|
|
8
9
|
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
|
|
9
10
|
from ezmsg.util.messages.util import replace
|
|
10
11
|
|
|
11
|
-
from .base import BaseStatefulTransformer, processor_state, BaseTransformerUnit
|
|
12
|
-
|
|
13
12
|
|
|
14
13
|
def _tau_from_alpha(alpha: float, dt: float) -> float:
|
|
15
14
|
"""
|
|
@@ -29,9 +28,7 @@ def _alpha_from_tau(tau: float, dt: float) -> float:
|
|
|
29
28
|
return 1 - np.exp(-dt / tau)
|
|
30
29
|
|
|
31
30
|
|
|
32
|
-
def ewma_step(
|
|
33
|
-
sample: npt.NDArray, zi: npt.NDArray, alpha: float, beta: float | None = None
|
|
34
|
-
):
|
|
31
|
+
def ewma_step(sample: npt.NDArray, zi: npt.NDArray, alpha: float, beta: float | None = None):
|
|
35
32
|
"""
|
|
36
33
|
Do an exponentially weighted moving average step.
|
|
37
34
|
|
|
@@ -97,9 +94,7 @@ class EWMA_Deprecated:
|
|
|
97
94
|
if self.prev is None:
|
|
98
95
|
self.prev = arr[:1]
|
|
99
96
|
|
|
100
|
-
out += self.prev * np.expand_dims(
|
|
101
|
-
self.weights[1 : n + 1], list(range(1, arr.ndim))
|
|
102
|
-
)
|
|
97
|
+
out += self.prev * np.expand_dims(self.weights[1 : n + 1], list(range(1, arr.ndim)))
|
|
103
98
|
|
|
104
99
|
self.prev = out[-1:]
|
|
105
100
|
|
|
@@ -128,9 +123,7 @@ class EWMA_Deprecated:
|
|
|
128
123
|
if self.prev is None:
|
|
129
124
|
self.prev = arr[:1]
|
|
130
125
|
|
|
131
|
-
result += self.prev * np.expand_dims(
|
|
132
|
-
self.weights[1 : n + 1], list(range(1, arr.ndim))
|
|
133
|
-
)
|
|
126
|
+
result += self.prev * np.expand_dims(self.weights[1 : n + 1], list(range(1, arr.ndim)))
|
|
134
127
|
|
|
135
128
|
# Store the result back into prev
|
|
136
129
|
self.prev = result[-1]
|
|
@@ -155,25 +148,17 @@ class EWMAState:
|
|
|
155
148
|
zi: npt.NDArray | None = None
|
|
156
149
|
|
|
157
150
|
|
|
158
|
-
class EWMATransformer(
|
|
159
|
-
BaseStatefulTransformer[EWMASettings, AxisArray, AxisArray, EWMAState]
|
|
160
|
-
):
|
|
151
|
+
class EWMATransformer(BaseStatefulTransformer[EWMASettings, AxisArray, AxisArray, EWMAState]):
|
|
161
152
|
def _hash_message(self, message: AxisArray) -> int:
|
|
162
153
|
axis = self.settings.axis or message.dims[0]
|
|
163
154
|
axis_idx = message.get_axis_idx(axis)
|
|
164
|
-
sample_shape = (
|
|
165
|
-
message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
|
|
166
|
-
)
|
|
155
|
+
sample_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
|
|
167
156
|
return hash((sample_shape, message.axes[axis].gain, message.key))
|
|
168
157
|
|
|
169
158
|
def _reset_state(self, message: AxisArray) -> None:
|
|
170
159
|
axis = self.settings.axis or message.dims[0]
|
|
171
|
-
self._state.alpha = _alpha_from_tau(
|
|
172
|
-
self.settings.time_constant, message.axes[axis].gain
|
173
|
-
)
|
|
174
|
-
sub_dat = slice_along_axis(
|
|
175
|
-
message.data, slice(None, 1, None), axis=message.get_axis_idx(axis)
|
|
176
|
-
)
|
|
160
|
+
self._state.alpha = _alpha_from_tau(self.settings.time_constant, message.axes[axis].gain)
|
|
161
|
+
sub_dat = slice_along_axis(message.data, slice(None, 1, None), axis=message.get_axis_idx(axis))
|
|
177
162
|
self._state.zi = (1 - self._state.alpha) * sub_dat
|
|
178
163
|
|
|
179
164
|
def _process(self, message: AxisArray) -> AxisArray:
|
|
@@ -191,7 +176,5 @@ class EWMATransformer(
|
|
|
191
176
|
return replace(message, data=expected)
|
|
192
177
|
|
|
193
178
|
|
|
194
|
-
class EWMAUnit(
|
|
195
|
-
BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, EWMATransformer]
|
|
196
|
-
):
|
|
179
|
+
class EWMAUnit(BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, EWMATransformer]):
|
|
197
180
|
SETTINGS = EWMASettings
|
ezmsg/sigproc/ewmfilter.py
CHANGED
ezmsg/sigproc/extract_axis.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
|
-
import numpy as np
|
|
2
1
|
import ezmsg.core as ez
|
|
2
|
+
import numpy as np
|
|
3
3
|
from ezmsg.util.messages.axisarray import AxisArray, replace
|
|
4
|
+
|
|
4
5
|
from ezmsg.sigproc.base import BaseTransformer, BaseTransformerUnit
|
|
5
6
|
|
|
6
7
|
|
|
@@ -35,7 +36,5 @@ class ExtractAxisData(BaseTransformer[ExtractAxisSettings, AxisArray, AxisArray]
|
|
|
35
36
|
)
|
|
36
37
|
|
|
37
38
|
|
|
38
|
-
class ExtractAxisDataUnit(
|
|
39
|
-
BaseTransformerUnit[ExtractAxisSettings, AxisArray, AxisArray, ExtractAxisData]
|
|
40
|
-
):
|
|
39
|
+
class ExtractAxisDataUnit(BaseTransformerUnit[ExtractAxisSettings, AxisArray, AxisArray, ExtractAxisData]):
|
|
41
40
|
SETTINGS = ExtractAxisSettings
|