ezmsg-sigproc 2.4.1__py3-none-any.whl → 2.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/activation.py +5 -11
- ezmsg/sigproc/adaptive_lattice_notch.py +11 -29
- ezmsg/sigproc/affinetransform.py +13 -38
- ezmsg/sigproc/aggregate.py +13 -30
- ezmsg/sigproc/bandpower.py +7 -15
- ezmsg/sigproc/base.py +141 -1276
- ezmsg/sigproc/butterworthfilter.py +8 -16
- ezmsg/sigproc/butterworthzerophase.py +123 -0
- ezmsg/sigproc/cheby.py +4 -10
- ezmsg/sigproc/combfilter.py +5 -8
- ezmsg/sigproc/decimate.py +2 -6
- ezmsg/sigproc/denormalize.py +6 -11
- ezmsg/sigproc/detrend.py +3 -4
- ezmsg/sigproc/diff.py +8 -17
- ezmsg/sigproc/downsample.py +6 -14
- ezmsg/sigproc/ewma.py +11 -27
- ezmsg/sigproc/ewmfilter.py +1 -1
- ezmsg/sigproc/extract_axis.py +3 -4
- ezmsg/sigproc/fbcca.py +31 -56
- ezmsg/sigproc/filter.py +19 -45
- ezmsg/sigproc/filterbank.py +33 -70
- ezmsg/sigproc/filterbankdesign.py +5 -12
- ezmsg/sigproc/fir_hilbert.py +336 -0
- ezmsg/sigproc/fir_pmc.py +209 -0
- ezmsg/sigproc/firfilter.py +12 -14
- ezmsg/sigproc/gaussiansmoothing.py +5 -9
- ezmsg/sigproc/kaiser.py +11 -15
- ezmsg/sigproc/math/abs.py +1 -3
- ezmsg/sigproc/math/add.py +121 -0
- ezmsg/sigproc/math/clip.py +1 -1
- ezmsg/sigproc/math/difference.py +98 -36
- ezmsg/sigproc/math/invert.py +1 -3
- ezmsg/sigproc/math/log.py +2 -6
- ezmsg/sigproc/messages.py +1 -2
- ezmsg/sigproc/quantize.py +2 -4
- ezmsg/sigproc/resample.py +13 -34
- ezmsg/sigproc/rollingscaler.py +232 -0
- ezmsg/sigproc/sampler.py +17 -35
- ezmsg/sigproc/scaler.py +8 -18
- ezmsg/sigproc/signalinjector.py +6 -16
- ezmsg/sigproc/slicer.py +9 -28
- ezmsg/sigproc/spectral.py +3 -3
- ezmsg/sigproc/spectrogram.py +12 -19
- ezmsg/sigproc/spectrum.py +12 -32
- ezmsg/sigproc/transpose.py +7 -18
- ezmsg/sigproc/util/asio.py +25 -156
- ezmsg/sigproc/util/axisarray_buffer.py +10 -26
- ezmsg/sigproc/util/buffer.py +18 -43
- ezmsg/sigproc/util/message.py +17 -31
- ezmsg/sigproc/util/profile.py +23 -174
- ezmsg/sigproc/util/sparse.py +5 -15
- ezmsg/sigproc/util/typeresolution.py +17 -83
- ezmsg/sigproc/wavelets.py +6 -15
- ezmsg/sigproc/window.py +24 -78
- {ezmsg_sigproc-2.4.1.dist-info → ezmsg_sigproc-2.6.0.dist-info}/METADATA +4 -3
- ezmsg_sigproc-2.6.0.dist-info/RECORD +63 -0
- ezmsg/sigproc/synth.py +0 -774
- ezmsg_sigproc-2.4.1.dist-info/RECORD +0 -59
- {ezmsg_sigproc-2.4.1.dist-info → ezmsg_sigproc-2.6.0.dist-info}/WHEEL +0 -0
- /ezmsg_sigproc-2.4.1.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.6.0.dist-info/licenses/LICENSE +0 -0
|
@@ -5,11 +5,11 @@ import scipy.signal
|
|
|
5
5
|
from scipy.signal import normalize
|
|
6
6
|
|
|
7
7
|
from .filter import (
|
|
8
|
-
FilterBaseSettings,
|
|
9
8
|
BACoeffs,
|
|
10
|
-
SOSCoeffs,
|
|
11
|
-
FilterByDesignTransformer,
|
|
12
9
|
BaseFilterByDesignTransformerUnit,
|
|
10
|
+
FilterBaseSettings,
|
|
11
|
+
FilterByDesignTransformer,
|
|
12
|
+
SOSCoeffs,
|
|
13
13
|
)
|
|
14
14
|
|
|
15
15
|
|
|
@@ -27,14 +27,14 @@ class ButterworthFilterSettings(FilterBaseSettings):
|
|
|
27
27
|
"""
|
|
28
28
|
Cuton frequency (Hz). If `cutoff` is not specified then this is the highpass corner. Otherwise,
|
|
29
29
|
if this is lower than `cutoff` then this is the beginning of the bandpass
|
|
30
|
-
or if this is greater than `cutoff` then this is the end of the bandstop.
|
|
30
|
+
or if this is greater than `cutoff` then this is the end of the bandstop.
|
|
31
31
|
"""
|
|
32
32
|
|
|
33
33
|
cutoff: float | None = None
|
|
34
34
|
"""
|
|
35
35
|
Cutoff frequency (Hz). If `cuton` is not specified then this is the lowpass corner. Otherwise,
|
|
36
36
|
if this is greater than `cuton` then this is the end of the bandpass,
|
|
37
|
-
or if this is less than `cuton` then this is the beginning of the bandstop.
|
|
37
|
+
or if this is less than `cuton` then this is the beginning of the bandstop.
|
|
38
38
|
"""
|
|
39
39
|
|
|
40
40
|
wn_hz: bool = True
|
|
@@ -96,9 +96,7 @@ def butter_design_fun(
|
|
|
96
96
|
"""
|
|
97
97
|
coefs = None
|
|
98
98
|
if order > 0:
|
|
99
|
-
btype, cutoffs = ButterworthFilterSettings(
|
|
100
|
-
order=order, cuton=cuton, cutoff=cutoff
|
|
101
|
-
).filter_specs()
|
|
99
|
+
btype, cutoffs = ButterworthFilterSettings(order=order, cuton=cuton, cutoff=cutoff).filter_specs()
|
|
102
100
|
coefs = scipy.signal.butter(
|
|
103
101
|
order,
|
|
104
102
|
Wn=cutoffs,
|
|
@@ -111,9 +109,7 @@ def butter_design_fun(
|
|
|
111
109
|
return coefs
|
|
112
110
|
|
|
113
111
|
|
|
114
|
-
class ButterworthFilterTransformer(
|
|
115
|
-
FilterByDesignTransformer[ButterworthFilterSettings, BACoeffs | SOSCoeffs]
|
|
116
|
-
):
|
|
112
|
+
class ButterworthFilterTransformer(FilterByDesignTransformer[ButterworthFilterSettings, BACoeffs | SOSCoeffs]):
|
|
117
113
|
def get_design_function(
|
|
118
114
|
self,
|
|
119
115
|
) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
|
|
@@ -127,11 +123,7 @@ class ButterworthFilterTransformer(
|
|
|
127
123
|
)
|
|
128
124
|
|
|
129
125
|
|
|
130
|
-
class ButterworthFilter(
|
|
131
|
-
BaseFilterByDesignTransformerUnit[
|
|
132
|
-
ButterworthFilterSettings, ButterworthFilterTransformer
|
|
133
|
-
]
|
|
134
|
-
):
|
|
126
|
+
class ButterworthFilter(BaseFilterByDesignTransformerUnit[ButterworthFilterSettings, ButterworthFilterTransformer]):
|
|
135
127
|
SETTINGS = ButterworthFilterSettings
|
|
136
128
|
|
|
137
129
|
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import typing
|
|
3
|
+
|
|
4
|
+
import ezmsg.core as ez
|
|
5
|
+
import numpy as np
|
|
6
|
+
import scipy.signal
|
|
7
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
8
|
+
from ezmsg.util.messages.util import replace
|
|
9
|
+
|
|
10
|
+
from ezmsg.sigproc.base import SettingsType
|
|
11
|
+
from ezmsg.sigproc.butterworthfilter import ButterworthFilterSettings, butter_design_fun
|
|
12
|
+
from ezmsg.sigproc.filter import (
|
|
13
|
+
BACoeffs,
|
|
14
|
+
BaseFilterByDesignTransformerUnit,
|
|
15
|
+
FilterByDesignTransformer,
|
|
16
|
+
SOSCoeffs,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ButterworthZeroPhaseSettings(ButterworthFilterSettings):
|
|
21
|
+
"""Settings for :obj:`ButterworthZeroPhase`."""
|
|
22
|
+
|
|
23
|
+
# axis, coef_type, order, cuton, cutoff, wn_hz are inherited from ButterworthFilterSettings
|
|
24
|
+
padtype: str | None = None
|
|
25
|
+
"""
|
|
26
|
+
Padding type to use in `scipy.signal.filtfilt`.
|
|
27
|
+
Must be one of {'odd', 'even', 'constant', None}.
|
|
28
|
+
Default is None for no padding.
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
padlen: int | None = 0
|
|
32
|
+
"""
|
|
33
|
+
Length of the padding to use in `scipy.signal.filtfilt`.
|
|
34
|
+
If None, SciPy's default padding is used.
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class ButterworthZeroPhaseTransformer(FilterByDesignTransformer[ButterworthZeroPhaseSettings, BACoeffs | SOSCoeffs]):
|
|
39
|
+
"""Zero-phase (filtfilt) Butterworth using your design function."""
|
|
40
|
+
|
|
41
|
+
def get_design_function(
|
|
42
|
+
self,
|
|
43
|
+
) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
|
|
44
|
+
return functools.partial(
|
|
45
|
+
butter_design_fun,
|
|
46
|
+
order=self.settings.order,
|
|
47
|
+
cuton=self.settings.cuton,
|
|
48
|
+
cutoff=self.settings.cutoff,
|
|
49
|
+
coef_type=self.settings.coef_type,
|
|
50
|
+
wn_hz=self.settings.wn_hz,
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
def update_settings(self, new_settings: typing.Optional[SettingsType] = None, **kwargs) -> None:
|
|
54
|
+
"""
|
|
55
|
+
Update settings and mark that filter coefficients need to be recalculated.
|
|
56
|
+
|
|
57
|
+
Args:
|
|
58
|
+
new_settings: Complete new settings object to replace current settings
|
|
59
|
+
**kwargs: Individual settings to update
|
|
60
|
+
"""
|
|
61
|
+
# Update settings
|
|
62
|
+
if new_settings is not None:
|
|
63
|
+
self.settings = new_settings
|
|
64
|
+
else:
|
|
65
|
+
self.settings = replace(self.settings, **kwargs)
|
|
66
|
+
|
|
67
|
+
# Set flag to trigger recalculation on next message
|
|
68
|
+
self._coefs_cache = None
|
|
69
|
+
self._fs_cache = None
|
|
70
|
+
self.state.needs_redesign = True
|
|
71
|
+
|
|
72
|
+
def _reset_state(self, message: AxisArray) -> None:
|
|
73
|
+
self._coefs_cache = None
|
|
74
|
+
self._fs_cache = None
|
|
75
|
+
self.state.needs_redesign = True
|
|
76
|
+
|
|
77
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
78
|
+
axis = message.dims[0] if self.settings.axis is None else self.settings.axis
|
|
79
|
+
ax_idx = message.get_axis_idx(axis)
|
|
80
|
+
fs = 1 / message.axes[axis].gain
|
|
81
|
+
|
|
82
|
+
if (
|
|
83
|
+
self._coefs_cache is None
|
|
84
|
+
or self.state.needs_redesign
|
|
85
|
+
or (self._fs_cache is None or not np.isclose(self._fs_cache, fs))
|
|
86
|
+
):
|
|
87
|
+
self._coefs_cache = self.get_design_function()(fs)
|
|
88
|
+
self._fs_cache = fs
|
|
89
|
+
self.state.needs_redesign = False
|
|
90
|
+
|
|
91
|
+
if self._coefs_cache is None or self.settings.order <= 0 or message.data.size <= 0:
|
|
92
|
+
return message
|
|
93
|
+
|
|
94
|
+
x = message.data
|
|
95
|
+
if self.settings.coef_type == "sos":
|
|
96
|
+
y = scipy.signal.sosfiltfilt(
|
|
97
|
+
self._coefs_cache,
|
|
98
|
+
x,
|
|
99
|
+
axis=ax_idx,
|
|
100
|
+
padtype=self.settings.padtype,
|
|
101
|
+
padlen=self.settings.padlen,
|
|
102
|
+
)
|
|
103
|
+
elif self.settings.coef_type == "ba":
|
|
104
|
+
b, a = self._coefs_cache
|
|
105
|
+
y = scipy.signal.filtfilt(
|
|
106
|
+
b,
|
|
107
|
+
a,
|
|
108
|
+
x,
|
|
109
|
+
axis=ax_idx,
|
|
110
|
+
padtype=self.settings.padtype,
|
|
111
|
+
padlen=self.settings.padlen,
|
|
112
|
+
)
|
|
113
|
+
else:
|
|
114
|
+
ez.logger.error("coef_type must be 'sos' or 'ba'.")
|
|
115
|
+
raise ValueError("coef_type must be 'sos' or 'ba'.")
|
|
116
|
+
|
|
117
|
+
return replace(message, data=y)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class ButterworthZeroPhase(
|
|
121
|
+
BaseFilterByDesignTransformerUnit[ButterworthZeroPhaseSettings, ButterworthZeroPhaseTransformer]
|
|
122
|
+
):
|
|
123
|
+
SETTINGS = ButterworthZeroPhaseSettings
|
ezmsg/sigproc/cheby.py
CHANGED
|
@@ -5,11 +5,11 @@ import scipy.signal
|
|
|
5
5
|
from scipy.signal import normalize
|
|
6
6
|
|
|
7
7
|
from .filter import (
|
|
8
|
+
BACoeffs,
|
|
9
|
+
BaseFilterByDesignTransformerUnit,
|
|
8
10
|
FilterBaseSettings,
|
|
9
11
|
FilterByDesignTransformer,
|
|
10
|
-
BACoeffs,
|
|
11
12
|
SOSCoeffs,
|
|
12
|
-
BaseFilterByDesignTransformerUnit,
|
|
13
13
|
)
|
|
14
14
|
|
|
15
15
|
|
|
@@ -104,9 +104,7 @@ def cheby_design_fun(
|
|
|
104
104
|
return coefs
|
|
105
105
|
|
|
106
106
|
|
|
107
|
-
class ChebyshevFilterTransformer(
|
|
108
|
-
FilterByDesignTransformer[ChebyshevFilterSettings, BACoeffs | SOSCoeffs]
|
|
109
|
-
):
|
|
107
|
+
class ChebyshevFilterTransformer(FilterByDesignTransformer[ChebyshevFilterSettings, BACoeffs | SOSCoeffs]):
|
|
110
108
|
def get_design_function(
|
|
111
109
|
self,
|
|
112
110
|
) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
|
|
@@ -123,9 +121,5 @@ class ChebyshevFilterTransformer(
|
|
|
123
121
|
)
|
|
124
122
|
|
|
125
123
|
|
|
126
|
-
class ChebyshevFilter(
|
|
127
|
-
BaseFilterByDesignTransformerUnit[
|
|
128
|
-
ChebyshevFilterSettings, ChebyshevFilterTransformer
|
|
129
|
-
]
|
|
130
|
-
):
|
|
124
|
+
class ChebyshevFilter(BaseFilterByDesignTransformerUnit[ChebyshevFilterSettings, ChebyshevFilterTransformer]):
|
|
131
125
|
SETTINGS = ChebyshevFilterSettings
|
ezmsg/sigproc/combfilter.py
CHANGED
|
@@ -1,15 +1,16 @@
|
|
|
1
1
|
import functools
|
|
2
2
|
import typing
|
|
3
|
+
|
|
3
4
|
import numpy as np
|
|
4
5
|
import scipy.signal
|
|
5
6
|
from scipy.signal import normalize
|
|
6
7
|
|
|
7
8
|
from .filter import (
|
|
9
|
+
BACoeffs,
|
|
10
|
+
BaseFilterByDesignTransformerUnit,
|
|
8
11
|
FilterBaseSettings,
|
|
9
12
|
FilterByDesignTransformer,
|
|
10
|
-
BACoeffs,
|
|
11
13
|
SOSCoeffs,
|
|
12
|
-
BaseFilterByDesignTransformerUnit,
|
|
13
14
|
)
|
|
14
15
|
|
|
15
16
|
|
|
@@ -103,9 +104,7 @@ def comb_design_fun(
|
|
|
103
104
|
return combined_sos
|
|
104
105
|
|
|
105
106
|
|
|
106
|
-
class CombFilterTransformer(
|
|
107
|
-
FilterByDesignTransformer[CombFilterSettings, BACoeffs | SOSCoeffs]
|
|
108
|
-
):
|
|
107
|
+
class CombFilterTransformer(FilterByDesignTransformer[CombFilterSettings, BACoeffs | SOSCoeffs]):
|
|
109
108
|
def get_design_function(
|
|
110
109
|
self,
|
|
111
110
|
) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
|
|
@@ -120,9 +119,7 @@ class CombFilterTransformer(
|
|
|
120
119
|
)
|
|
121
120
|
|
|
122
121
|
|
|
123
|
-
class CombFilterUnit(
|
|
124
|
-
BaseFilterByDesignTransformerUnit[CombFilterSettings, CombFilterTransformer]
|
|
125
|
-
):
|
|
122
|
+
class CombFilterUnit(BaseFilterByDesignTransformerUnit[CombFilterSettings, CombFilterTransformer]):
|
|
126
123
|
SETTINGS = CombFilterSettings
|
|
127
124
|
|
|
128
125
|
|
ezmsg/sigproc/decimate.py
CHANGED
|
@@ -4,7 +4,7 @@ import ezmsg.core as ez
|
|
|
4
4
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
5
5
|
|
|
6
6
|
from .base import BaseTransformerUnit
|
|
7
|
-
from .cheby import
|
|
7
|
+
from .cheby import ChebyshevFilterSettings, ChebyshevFilterTransformer
|
|
8
8
|
from .downsample import Downsample, DownsampleSettings
|
|
9
9
|
from .filter import BACoeffs, SOSCoeffs
|
|
10
10
|
|
|
@@ -30,11 +30,7 @@ class ChebyForDecimateTransformer(ChebyshevFilterTransformer[BACoeffs | SOSCoeff
|
|
|
30
30
|
return cheby_opt_design_fun
|
|
31
31
|
|
|
32
32
|
|
|
33
|
-
class ChebyForDecimate(
|
|
34
|
-
BaseTransformerUnit[
|
|
35
|
-
ChebyshevFilterSettings, AxisArray, AxisArray, ChebyForDecimateTransformer
|
|
36
|
-
]
|
|
37
|
-
):
|
|
33
|
+
class ChebyForDecimate(BaseTransformerUnit[ChebyshevFilterSettings, AxisArray, AxisArray, ChebyForDecimateTransformer]):
|
|
38
34
|
SETTINGS = ChebyshevFilterSettings
|
|
39
35
|
|
|
40
36
|
|
ezmsg/sigproc/denormalize.py
CHANGED
|
@@ -1,13 +1,14 @@
|
|
|
1
1
|
import ezmsg.core as ez
|
|
2
2
|
import numpy as np
|
|
3
3
|
import numpy.typing as npt
|
|
4
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
5
|
+
from ezmsg.util.messages.util import replace
|
|
6
|
+
|
|
4
7
|
from ezmsg.sigproc.base import (
|
|
5
|
-
BaseTransformerUnit,
|
|
6
8
|
BaseStatefulTransformer,
|
|
9
|
+
BaseTransformerUnit,
|
|
7
10
|
processor_state,
|
|
8
11
|
)
|
|
9
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
10
|
-
from ezmsg.util.messages.util import replace
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
class DenormalizeSettings(ez.Settings):
|
|
@@ -27,9 +28,7 @@ class DenormalizeState:
|
|
|
27
28
|
offsets: npt.NDArray | None = None
|
|
28
29
|
|
|
29
30
|
|
|
30
|
-
class DenormalizeTransformer(
|
|
31
|
-
BaseStatefulTransformer[DenormalizeSettings, AxisArray, AxisArray, DenormalizeState]
|
|
32
|
-
):
|
|
31
|
+
class DenormalizeTransformer(BaseStatefulTransformer[DenormalizeSettings, AxisArray, AxisArray, DenormalizeState]):
|
|
33
32
|
"""
|
|
34
33
|
Scales data from a normalized distribution (mean=0, std=1) to a denormalized
|
|
35
34
|
distribution using random per-channel offsets and gains designed to keep the
|
|
@@ -76,9 +75,5 @@ class DenormalizeTransformer(
|
|
|
76
75
|
)
|
|
77
76
|
|
|
78
77
|
|
|
79
|
-
class DenormalizeUnit(
|
|
80
|
-
BaseTransformerUnit[
|
|
81
|
-
DenormalizeSettings, AxisArray, AxisArray, DenormalizeTransformer
|
|
82
|
-
]
|
|
83
|
-
):
|
|
78
|
+
class DenormalizeUnit(BaseTransformerUnit[DenormalizeSettings, AxisArray, AxisArray, DenormalizeTransformer]):
|
|
84
79
|
SETTINGS = DenormalizeSettings
|
ezmsg/sigproc/detrend.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
import scipy.signal as sps
|
|
2
2
|
from ezmsg.util.messages.axisarray import AxisArray, replace
|
|
3
|
-
|
|
3
|
+
|
|
4
4
|
from ezmsg.sigproc.base import BaseTransformerUnit
|
|
5
|
+
from ezmsg.sigproc.ewma import EWMASettings, EWMATransformer
|
|
5
6
|
|
|
6
7
|
|
|
7
8
|
class DetrendTransformer(EWMATransformer):
|
|
@@ -23,7 +24,5 @@ class DetrendTransformer(EWMATransformer):
|
|
|
23
24
|
return replace(message, data=message.data - means)
|
|
24
25
|
|
|
25
26
|
|
|
26
|
-
class DetrendUnit(
|
|
27
|
-
BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, DetrendTransformer]
|
|
28
|
-
):
|
|
27
|
+
class DetrendUnit(BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, DetrendTransformer]):
|
|
29
28
|
SETTINGS = EWMASettings
|
ezmsg/sigproc/diff.py
CHANGED
|
@@ -1,13 +1,14 @@
|
|
|
1
1
|
import ezmsg.core as ez
|
|
2
2
|
import numpy as np
|
|
3
3
|
import numpy.typing as npt
|
|
4
|
+
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
|
|
5
|
+
from ezmsg.util.messages.util import replace
|
|
6
|
+
|
|
4
7
|
from ezmsg.sigproc.base import (
|
|
8
|
+
BaseStatefulTransformer,
|
|
5
9
|
BaseTransformerUnit,
|
|
6
10
|
processor_state,
|
|
7
|
-
BaseStatefulTransformer,
|
|
8
11
|
)
|
|
9
|
-
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
|
|
10
|
-
from ezmsg.util.messages.util import replace
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
class DiffSettings(ez.Settings):
|
|
@@ -21,9 +22,7 @@ class DiffState:
|
|
|
21
22
|
last_time: float | None = None
|
|
22
23
|
|
|
23
24
|
|
|
24
|
-
class DiffTransformer(
|
|
25
|
-
BaseStatefulTransformer[DiffSettings, AxisArray, AxisArray, DiffState]
|
|
26
|
-
):
|
|
25
|
+
class DiffTransformer(BaseStatefulTransformer[DiffSettings, AxisArray, AxisArray, DiffState]):
|
|
27
26
|
def _hash_message(self, message: AxisArray) -> int:
|
|
28
27
|
ax_idx = message.get_axis_idx(self.settings.axis)
|
|
29
28
|
sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
|
|
@@ -49,20 +48,14 @@ class DiffTransformer(
|
|
|
49
48
|
axis=ax_idx,
|
|
50
49
|
)
|
|
51
50
|
# Prepare last_dat for next iteration
|
|
52
|
-
self.state.last_dat = slice_along_axis(
|
|
53
|
-
message.data, slice(-1, None), axis=ax_idx
|
|
54
|
-
)
|
|
51
|
+
self.state.last_dat = slice_along_axis(message.data, slice(-1, None), axis=ax_idx)
|
|
55
52
|
# Scale by fs if requested. This convers the diff to a derivative. e.g., diff of position becomes velocity.
|
|
56
53
|
if self.settings.scale_by_fs:
|
|
57
54
|
ax_info = message.get_axis(axis)
|
|
58
55
|
if hasattr(ax_info, "data"):
|
|
59
56
|
dt = np.diff(np.concatenate(([self.state.last_time], ax_info.data)))
|
|
60
57
|
# Expand dt dims to match diffs
|
|
61
|
-
exp_sl = (
|
|
62
|
-
(None,) * ax_idx
|
|
63
|
-
+ (Ellipsis,)
|
|
64
|
-
+ (None,) * (message.data.ndim - ax_idx - 1)
|
|
65
|
-
)
|
|
58
|
+
exp_sl = (None,) * ax_idx + (Ellipsis,) + (None,) * (message.data.ndim - ax_idx - 1)
|
|
66
59
|
diffs /= dt[exp_sl]
|
|
67
60
|
self.state.last_time = ax_info.data[-1] # For next iteration
|
|
68
61
|
else:
|
|
@@ -71,9 +64,7 @@ class DiffTransformer(
|
|
|
71
64
|
return replace(message, data=diffs)
|
|
72
65
|
|
|
73
66
|
|
|
74
|
-
class DiffUnit(
|
|
75
|
-
BaseTransformerUnit[DiffSettings, AxisArray, AxisArray, DiffTransformer]
|
|
76
|
-
):
|
|
67
|
+
class DiffUnit(BaseTransformerUnit[DiffSettings, AxisArray, AxisArray, DiffTransformer]):
|
|
77
68
|
SETTINGS = DiffSettings
|
|
78
69
|
|
|
79
70
|
|
ezmsg/sigproc/downsample.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
|
+
import ezmsg.core as ez
|
|
1
2
|
import numpy as np
|
|
2
3
|
from ezmsg.util.messages.axisarray import (
|
|
3
4
|
AxisArray,
|
|
4
|
-
slice_along_axis,
|
|
5
5
|
replace,
|
|
6
|
+
slice_along_axis,
|
|
6
7
|
)
|
|
7
|
-
import ezmsg.core as ez
|
|
8
8
|
|
|
9
9
|
from .base import (
|
|
10
10
|
BaseStatefulTransformer,
|
|
@@ -38,9 +38,7 @@ class DownsampleState:
|
|
|
38
38
|
"""Index of the next msg's first sample into the virtual rotating ds_factor counter."""
|
|
39
39
|
|
|
40
40
|
|
|
41
|
-
class DownsampleTransformer(
|
|
42
|
-
BaseStatefulTransformer[DownsampleSettings, AxisArray, AxisArray, DownsampleState]
|
|
43
|
-
):
|
|
41
|
+
class DownsampleTransformer(BaseStatefulTransformer[DownsampleSettings, AxisArray, AxisArray, DownsampleState]):
|
|
44
42
|
"""
|
|
45
43
|
Downsampled data simply comprise every `factor`th sample.
|
|
46
44
|
This should only be used following appropriate lowpass filtering.
|
|
@@ -75,9 +73,7 @@ class DownsampleTransformer(
|
|
|
75
73
|
axis_idx = message.get_axis_idx(axis)
|
|
76
74
|
|
|
77
75
|
n_samples = message.data.shape[axis_idx]
|
|
78
|
-
samples = (
|
|
79
|
-
np.arange(self.state.s_idx, self.state.s_idx + n_samples) % self._state.q
|
|
80
|
-
)
|
|
76
|
+
samples = np.arange(self.state.s_idx, self.state.s_idx + n_samples) % self._state.q
|
|
81
77
|
if n_samples > 0:
|
|
82
78
|
# Update state for next iteration.
|
|
83
79
|
self._state.s_idx = samples[-1] + 1
|
|
@@ -104,9 +100,7 @@ class DownsampleTransformer(
|
|
|
104
100
|
return msg_out
|
|
105
101
|
|
|
106
102
|
|
|
107
|
-
class Downsample(
|
|
108
|
-
BaseTransformerUnit[DownsampleSettings, AxisArray, AxisArray, DownsampleTransformer]
|
|
109
|
-
):
|
|
103
|
+
class Downsample(BaseTransformerUnit[DownsampleSettings, AxisArray, AxisArray, DownsampleTransformer]):
|
|
110
104
|
SETTINGS = DownsampleSettings
|
|
111
105
|
|
|
112
106
|
|
|
@@ -115,6 +109,4 @@ def downsample(
|
|
|
115
109
|
target_rate: float | None = None,
|
|
116
110
|
factor: int | None = None,
|
|
117
111
|
) -> DownsampleTransformer:
|
|
118
|
-
return DownsampleTransformer(
|
|
119
|
-
DownsampleSettings(axis=axis, target_rate=target_rate, factor=factor)
|
|
120
|
-
)
|
|
112
|
+
return DownsampleTransformer(DownsampleSettings(axis=axis, target_rate=target_rate, factor=factor))
|
ezmsg/sigproc/ewma.py
CHANGED
|
@@ -1,14 +1,14 @@
|
|
|
1
|
-
from dataclasses import field
|
|
2
1
|
import functools
|
|
2
|
+
from dataclasses import field
|
|
3
3
|
|
|
4
|
+
import ezmsg.core as ez
|
|
4
5
|
import numpy as np
|
|
5
6
|
import numpy.typing as npt
|
|
6
7
|
import scipy.signal as sps
|
|
7
|
-
import ezmsg.core as ez
|
|
8
8
|
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
|
|
9
9
|
from ezmsg.util.messages.util import replace
|
|
10
10
|
|
|
11
|
-
from .base import BaseStatefulTransformer,
|
|
11
|
+
from .base import BaseStatefulTransformer, BaseTransformerUnit, processor_state
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
def _tau_from_alpha(alpha: float, dt: float) -> float:
|
|
@@ -29,9 +29,7 @@ def _alpha_from_tau(tau: float, dt: float) -> float:
|
|
|
29
29
|
return 1 - np.exp(-dt / tau)
|
|
30
30
|
|
|
31
31
|
|
|
32
|
-
def ewma_step(
|
|
33
|
-
sample: npt.NDArray, zi: npt.NDArray, alpha: float, beta: float | None = None
|
|
34
|
-
):
|
|
32
|
+
def ewma_step(sample: npt.NDArray, zi: npt.NDArray, alpha: float, beta: float | None = None):
|
|
35
33
|
"""
|
|
36
34
|
Do an exponentially weighted moving average step.
|
|
37
35
|
|
|
@@ -97,9 +95,7 @@ class EWMA_Deprecated:
|
|
|
97
95
|
if self.prev is None:
|
|
98
96
|
self.prev = arr[:1]
|
|
99
97
|
|
|
100
|
-
out += self.prev * np.expand_dims(
|
|
101
|
-
self.weights[1 : n + 1], list(range(1, arr.ndim))
|
|
102
|
-
)
|
|
98
|
+
out += self.prev * np.expand_dims(self.weights[1 : n + 1], list(range(1, arr.ndim)))
|
|
103
99
|
|
|
104
100
|
self.prev = out[-1:]
|
|
105
101
|
|
|
@@ -128,9 +124,7 @@ class EWMA_Deprecated:
|
|
|
128
124
|
if self.prev is None:
|
|
129
125
|
self.prev = arr[:1]
|
|
130
126
|
|
|
131
|
-
result += self.prev * np.expand_dims(
|
|
132
|
-
self.weights[1 : n + 1], list(range(1, arr.ndim))
|
|
133
|
-
)
|
|
127
|
+
result += self.prev * np.expand_dims(self.weights[1 : n + 1], list(range(1, arr.ndim)))
|
|
134
128
|
|
|
135
129
|
# Store the result back into prev
|
|
136
130
|
self.prev = result[-1]
|
|
@@ -155,25 +149,17 @@ class EWMAState:
|
|
|
155
149
|
zi: npt.NDArray | None = None
|
|
156
150
|
|
|
157
151
|
|
|
158
|
-
class EWMATransformer(
|
|
159
|
-
BaseStatefulTransformer[EWMASettings, AxisArray, AxisArray, EWMAState]
|
|
160
|
-
):
|
|
152
|
+
class EWMATransformer(BaseStatefulTransformer[EWMASettings, AxisArray, AxisArray, EWMAState]):
|
|
161
153
|
def _hash_message(self, message: AxisArray) -> int:
|
|
162
154
|
axis = self.settings.axis or message.dims[0]
|
|
163
155
|
axis_idx = message.get_axis_idx(axis)
|
|
164
|
-
sample_shape =
|
|
165
|
-
message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
|
|
166
|
-
)
|
|
156
|
+
sample_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
|
|
167
157
|
return hash((sample_shape, message.axes[axis].gain, message.key))
|
|
168
158
|
|
|
169
159
|
def _reset_state(self, message: AxisArray) -> None:
|
|
170
160
|
axis = self.settings.axis or message.dims[0]
|
|
171
|
-
self._state.alpha = _alpha_from_tau(
|
|
172
|
-
|
|
173
|
-
)
|
|
174
|
-
sub_dat = slice_along_axis(
|
|
175
|
-
message.data, slice(None, 1, None), axis=message.get_axis_idx(axis)
|
|
176
|
-
)
|
|
161
|
+
self._state.alpha = _alpha_from_tau(self.settings.time_constant, message.axes[axis].gain)
|
|
162
|
+
sub_dat = slice_along_axis(message.data, slice(None, 1, None), axis=message.get_axis_idx(axis))
|
|
177
163
|
self._state.zi = (1 - self._state.alpha) * sub_dat
|
|
178
164
|
|
|
179
165
|
def _process(self, message: AxisArray) -> AxisArray:
|
|
@@ -191,7 +177,5 @@ class EWMATransformer(
|
|
|
191
177
|
return replace(message, data=expected)
|
|
192
178
|
|
|
193
179
|
|
|
194
|
-
class EWMAUnit(
|
|
195
|
-
BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, EWMATransformer]
|
|
196
|
-
):
|
|
180
|
+
class EWMAUnit(BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, EWMATransformer]):
|
|
197
181
|
SETTINGS = EWMASettings
|
ezmsg/sigproc/ewmfilter.py
CHANGED
ezmsg/sigproc/extract_axis.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
|
-
import numpy as np
|
|
2
1
|
import ezmsg.core as ez
|
|
2
|
+
import numpy as np
|
|
3
3
|
from ezmsg.util.messages.axisarray import AxisArray, replace
|
|
4
|
+
|
|
4
5
|
from ezmsg.sigproc.base import BaseTransformer, BaseTransformerUnit
|
|
5
6
|
|
|
6
7
|
|
|
@@ -35,7 +36,5 @@ class ExtractAxisData(BaseTransformer[ExtractAxisSettings, AxisArray, AxisArray]
|
|
|
35
36
|
)
|
|
36
37
|
|
|
37
38
|
|
|
38
|
-
class ExtractAxisDataUnit(
|
|
39
|
-
BaseTransformerUnit[ExtractAxisSettings, AxisArray, AxisArray, ExtractAxisData]
|
|
40
|
-
):
|
|
39
|
+
class ExtractAxisDataUnit(BaseTransformerUnit[ExtractAxisSettings, AxisArray, AxisArray, ExtractAxisData]):
|
|
41
40
|
SETTINGS = ExtractAxisSettings
|