ezmsg-sigproc 1.8.1__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/activation.py +36 -39
- ezmsg/sigproc/adaptive_lattice_notch.py +231 -0
- ezmsg/sigproc/affinetransform.py +169 -163
- ezmsg/sigproc/aggregate.py +119 -104
- ezmsg/sigproc/bandpower.py +58 -52
- ezmsg/sigproc/base.py +1242 -0
- ezmsg/sigproc/butterworthfilter.py +37 -33
- ezmsg/sigproc/cheby.py +29 -17
- ezmsg/sigproc/combfilter.py +163 -0
- ezmsg/sigproc/decimate.py +19 -10
- ezmsg/sigproc/detrend.py +29 -0
- ezmsg/sigproc/diff.py +81 -0
- ezmsg/sigproc/downsample.py +78 -78
- ezmsg/sigproc/ewma.py +197 -0
- ezmsg/sigproc/extract_axis.py +41 -0
- ezmsg/sigproc/filter.py +257 -141
- ezmsg/sigproc/filterbank.py +247 -199
- ezmsg/sigproc/math/abs.py +17 -22
- ezmsg/sigproc/math/clip.py +24 -24
- ezmsg/sigproc/math/difference.py +34 -30
- ezmsg/sigproc/math/invert.py +13 -25
- ezmsg/sigproc/math/log.py +28 -33
- ezmsg/sigproc/math/scale.py +18 -26
- ezmsg/sigproc/quantize.py +71 -0
- ezmsg/sigproc/resample.py +298 -0
- ezmsg/sigproc/sampler.py +241 -259
- ezmsg/sigproc/scaler.py +55 -218
- ezmsg/sigproc/signalinjector.py +52 -43
- ezmsg/sigproc/slicer.py +81 -89
- ezmsg/sigproc/spectrogram.py +77 -75
- ezmsg/sigproc/spectrum.py +203 -168
- ezmsg/sigproc/synth.py +546 -393
- ezmsg/sigproc/transpose.py +131 -0
- ezmsg/sigproc/util/asio.py +156 -0
- ezmsg/sigproc/util/message.py +31 -0
- ezmsg/sigproc/util/profile.py +55 -12
- ezmsg/sigproc/util/typeresolution.py +83 -0
- ezmsg/sigproc/wavelets.py +154 -153
- ezmsg/sigproc/window.py +269 -211
- {ezmsg_sigproc-1.8.1.dist-info → ezmsg_sigproc-2.0.0.dist-info}/METADATA +2 -1
- ezmsg_sigproc-2.0.0.dist-info/RECORD +51 -0
- ezmsg_sigproc-1.8.1.dist-info/RECORD +0 -39
- {ezmsg_sigproc-1.8.1.dist-info → ezmsg_sigproc-2.0.0.dist-info}/WHEEL +0 -0
- {ezmsg_sigproc-1.8.1.dist-info → ezmsg_sigproc-2.0.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -2,20 +2,22 @@ import functools
|
|
|
2
2
|
import typing
|
|
3
3
|
|
|
4
4
|
import scipy.signal
|
|
5
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
6
5
|
from scipy.signal import normalize
|
|
7
6
|
|
|
8
7
|
from .filter import (
|
|
9
8
|
FilterBaseSettings,
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
9
|
+
BACoeffs,
|
|
10
|
+
SOSCoeffs,
|
|
11
|
+
FilterByDesignTransformer,
|
|
12
|
+
BaseFilterByDesignTransformerUnit,
|
|
13
13
|
)
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
class ButterworthFilterSettings(FilterBaseSettings):
|
|
17
17
|
"""Settings for :obj:`ButterworthFilter`."""
|
|
18
18
|
|
|
19
|
+
# axis and coef_type are inherited from FilterBaseSettings
|
|
20
|
+
|
|
19
21
|
order: int = 0
|
|
20
22
|
"""
|
|
21
23
|
Filter order
|
|
@@ -72,7 +74,7 @@ def butter_design_fun(
|
|
|
72
74
|
cutoff: float | None = None,
|
|
73
75
|
coef_type: str = "ba",
|
|
74
76
|
wn_hz: bool = True,
|
|
75
|
-
) ->
|
|
77
|
+
) -> BACoeffs | SOSCoeffs | None:
|
|
76
78
|
"""
|
|
77
79
|
See :obj:`ButterworthFilterSettings.filter_specs` for an explanation of specifying different
|
|
78
80
|
filter types (lowpass, highpass, bandpass, bandstop) from the parameters.
|
|
@@ -109,52 +111,54 @@ def butter_design_fun(
|
|
|
109
111
|
return coefs
|
|
110
112
|
|
|
111
113
|
|
|
112
|
-
class ButterworthFilterTransformer(
    FilterByDesignTransformer[ButterworthFilterSettings, BACoeffs | SOSCoeffs]
):
    """Streaming Butterworth filter whose coefficients come from :func:`butter_design_fun`."""

    def get_design_function(
        self,
    ) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
        """
        Bind the current settings to :func:`butter_design_fun`.

        Returns:
            A callable of one float argument (presumably the sampling rate, as with
            the other design functions in this package -- TODO confirm) that returns
            whatever :func:`butter_design_fun` designs for it.
        """
        cfg = self.settings
        design_kwargs = dict(
            order=cfg.order,
            cuton=cfg.cuton,
            cutoff=cfg.cutoff,
            coef_type=cfg.coef_type,
            wn_hz=cfg.wn_hz,
        )
        return functools.partial(butter_design_fun, **design_kwargs)
|
|
125
128
|
|
|
126
129
|
|
|
130
|
+
class ButterworthFilter(
    BaseFilterByDesignTransformerUnit[
        ButterworthFilterSettings, ButterworthFilterTransformer
    ]
):
    """ezmsg Unit wrapping :obj:`ButterworthFilterTransformer`, configured by :obj:`ButterworthFilterSettings`."""

    SETTINGS = ButterworthFilterSettings
|
|
136
|
+
|
|
137
|
+
|
|
127
138
|
def butter(
    axis: str | None,
    order: int = 0,
    cuton: float | None = None,
    cutoff: float | None = None,
    coef_type: str = "ba",
    wn_hz: bool = True,
) -> ButterworthFilterTransformer:
    """
    Convenience function that builds a :obj:`ButterworthFilterTransformer`.

    Apply Butterworth filter to streaming data. Uses :obj:`scipy.signal.butter` to design the filter.
    See :obj:`ButterworthFilterSettings.filter_specs` for an explanation of specifying different
    filter types (lowpass, highpass, bandpass, bandstop) from the parameters.

    Args:
        axis: The name of the axis to filter. It must be present in the message's ``.axes``
            (1.x required an ``AxisArray.LinearAxis``; TODO confirm whether that still holds).
        order: Filter order.
        cuton: Corner frequency of the filter in Hz.
        cutoff: Corner frequency of the filter in Hz.
        coef_type: "ba" or "sos".
        wn_hz: Forwarded to :func:`butter_design_fun`. Presumably True means ``cuton``/``cutoff``
            are given in Hz rather than normalized frequency (TODO confirm).

    Returns:
        :obj:`ButterworthFilterTransformer`
    """
    return ButterworthFilterTransformer(
        ButterworthFilterSettings(
            axis=axis,
            order=order,
            cuton=cuton,
            cutoff=cutoff,
            coef_type=coef_type,
            wn_hz=wn_hz,
        )
    )
|
ezmsg/sigproc/cheby.py
CHANGED
|
@@ -6,13 +6,17 @@ from scipy.signal import normalize
|
|
|
6
6
|
|
|
7
7
|
from .filter import (
|
|
8
8
|
FilterBaseSettings,
|
|
9
|
-
|
|
10
|
-
|
|
9
|
+
FilterByDesignTransformer,
|
|
10
|
+
BACoeffs,
|
|
11
|
+
SOSCoeffs,
|
|
12
|
+
BaseFilterByDesignTransformerUnit,
|
|
11
13
|
)
|
|
12
14
|
|
|
13
15
|
|
|
14
16
|
class ChebyshevFilterSettings(FilterBaseSettings):
|
|
15
|
-
"""Settings for :obj:`
|
|
17
|
+
"""Settings for :obj:`ChebyshevFilter`."""
|
|
18
|
+
|
|
19
|
+
# axis and coef_type are inherited from FilterBaseSettings
|
|
16
20
|
|
|
17
21
|
order: int = 0
|
|
18
22
|
"""
|
|
@@ -63,7 +67,7 @@ def cheby_design_fun(
|
|
|
63
67
|
coef_type: str = "ba",
|
|
64
68
|
cheby_type: str = "cheby1",
|
|
65
69
|
wn_hz: bool = True,
|
|
66
|
-
) ->
|
|
70
|
+
) -> BACoeffs | SOSCoeffs | None:
|
|
67
71
|
"""
|
|
68
72
|
Chebyshev type I and type II digital and analog filter design.
|
|
69
73
|
Design an `order`th-order digital or analog Chebyshev type I or type II filter and return the filter coefficients.
|
|
@@ -100,20 +104,28 @@ def cheby_design_fun(
|
|
|
100
104
|
return coefs
|
|
101
105
|
|
|
102
106
|
|
|
103
|
-
class ChebyshevFilterTransformer(
    FilterByDesignTransformer[ChebyshevFilterSettings, BACoeffs | SOSCoeffs]
):
    """Streaming Chebyshev (type I or type II) filter designed by :func:`cheby_design_fun`."""

    def get_design_function(
        self,
    ) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
        """
        Bind every setting except the sampling rate to :func:`cheby_design_fun`.

        Returns:
            A callable that takes the sampling rate and returns the designed coefficients
            (or None when :func:`cheby_design_fun` designs nothing).
        """
        s = self.settings
        return functools.partial(
            cheby_design_fun,
            **{
                "order": s.order,
                "ripple_tol": s.ripple_tol,
                "Wn": s.Wn,
                "btype": s.btype,
                "analog": s.analog,
                "coef_type": s.coef_type,
                "cheby_type": s.cheby_type,
                "wn_hz": s.wn_hz,
            },
        )
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class ChebyshevFilter(
    BaseFilterByDesignTransformerUnit[
        ChebyshevFilterSettings, ChebyshevFilterTransformer
    ]
):
    """ezmsg Unit wrapping :obj:`ChebyshevFilterTransformer`, configured by :obj:`ChebyshevFilterSettings`."""

    SETTINGS = ChebyshevFilterSettings
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import typing
|
|
3
|
+
import numpy as np
|
|
4
|
+
import scipy.signal
|
|
5
|
+
from scipy.signal import normalize
|
|
6
|
+
|
|
7
|
+
from .filter import (
|
|
8
|
+
FilterBaseSettings,
|
|
9
|
+
FilterByDesignTransformer,
|
|
10
|
+
BACoeffs,
|
|
11
|
+
SOSCoeffs,
|
|
12
|
+
BaseFilterByDesignTransformerUnit,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class CombFilterSettings(FilterBaseSettings):
    """Settings for :obj:`CombFilterUnit` / :obj:`CombFilterTransformer`."""

    # axis and coef_type are inherited from FilterBaseSettings

    # Base frequency of the comb; harmonic k sits at k * fundamental_freq.
    fundamental_freq: float = 60.0
    """
    Fundamental frequency in Hz
    """

    # Harmonics at or above the Nyquist frequency are skipped by comb_design_fun.
    num_harmonics: int = 3
    """
    Number of harmonics to include (including fundamental)
    """

    # Passed as Q to scipy.signal.iirnotch / iirpeak (possibly scaled, see quality_scaling).
    q_factor: float = 35.0
    """
    Quality factor (Q) for each peak/notch
    """

    # "notch" -> scipy.signal.iirnotch sections; "peak" -> scipy.signal.iirpeak sections.
    filter_type: str = "notch"
    """
    Type of comb filter: 'notch' removes harmonics, 'peak' passes harmonics at the expense of others.
    """

    # "proportional" uses q_factor * k for harmonic k; anything else uses q_factor for every harmonic.
    quality_scaling: str = "constant"
    """
    'constant': the same Q for all harmonics, which gives wider bands at higher frequencies.
    'proportional': Q proportional to the harmonic number, which gives constant bandwidths.
    """
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def comb_design_fun(
    fs: float,
    fundamental_freq: float = 60.0,
    num_harmonics: int = 3,
    q_factor: float = 35.0,
    filter_type: str = "notch",
    coef_type: str = "sos",
    quality_scaling: str = "constant",
) -> "BACoeffs | SOSCoeffs | None":
    """
    Design a comb filter as cascaded second-order sections targeting a fundamental frequency and its harmonics.

    One second-order section is designed for each harmonic ``k * fundamental_freq``
    (``k = 1 .. num_harmonics``). Notch sections use :func:`scipy.signal.iirnotch` and
    peak sections use :func:`scipy.signal.iirpeak`. Harmonics at or above the Nyquist
    frequency (``fs / 2``) are skipped.

    Args:
        fs: Sampling rate in Hz.
        fundamental_freq: Fundamental frequency in Hz.
        num_harmonics: Number of harmonics to include (including the fundamental).
        q_factor: Quality factor (Q) of each section.
        filter_type: "notch" to remove the harmonics, "peak" to emphasise them.
        coef_type: "sos" (recommended for finite-precision stability) or "ba".
        quality_scaling: "constant" uses ``q_factor`` for every harmonic. "proportional" uses
            ``q_factor * k`` for harmonic ``k``, which keeps the bandwidth constant.

    Returns:
        An SOS array of shape ``(n_sections, 6)`` when ``coef_type == "sos"``, or a normalized
        ``(b, a)`` tuple when ``coef_type == "ba"``. Returns None if no harmonic lies below Nyquist.

    Raises:
        ValueError: If ``coef_type``, ``filter_type`` or ``quality_scaling`` is not a supported value.
    """
    if coef_type != "sos" and coef_type != "ba":
        raise ValueError("Comb filter only supports 'sos' or 'ba' coefficient types")
    # Reject unknown options explicitly. Otherwise a typo such as "Notch" would fall
    # through to the peak branch and silently invert the filter.
    if filter_type not in ("notch", "peak"):
        raise ValueError(
            f"Comb filter_type must be 'notch' or 'peak', got {filter_type!r}"
        )
    if quality_scaling not in ("constant", "proportional"):
        raise ValueError(
            "Comb quality_scaling must be 'constant' or 'proportional', "
            f"got {quality_scaling!r}"
        )

    # NOTE(review): sections are cascaded (their responses multiply). This is right for
    # notches. For peaks, each section's stopband attenuates the other harmonics, so with
    # more than one in-band harmonic the "peak" comb likely does not pass the harmonics
    # as documented. A parallel (summed) design may be intended -- confirm.
    design_section = (
        scipy.signal.iirnotch if filter_type == "notch" else scipy.signal.iirpeak
    )
    nyquist = fs / 2

    all_sos = []
    for k in range(1, num_harmonics + 1):
        freq = fundamental_freq * k
        if freq >= nyquist:
            continue
        current_q = q_factor * k if quality_scaling == "proportional" else q_factor
        b, a = design_section(w0=freq, Q=current_q, fs=fs)
        # iirnotch / iirpeak return a second-order (b, a) pair, so [b, a] is exactly one
        # SOS row [b0, b1, b2, a0, a1, a2] (equivalent to tf2sos(b, a)).
        all_sos.append(np.hstack((b, a)))

    if not all_sos:
        return None

    combined_sos = np.vstack(all_sos)

    if coef_type == "ba":
        b, a = scipy.signal.sos2tf(combined_sos)
        return normalize(b, a)

    return combined_sos
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class CombFilterTransformer(
    FilterByDesignTransformer[CombFilterSettings, BACoeffs | SOSCoeffs]
):
    """Streaming comb (multi-notch or multi-peak) filter designed by :func:`comb_design_fun`."""

    def get_design_function(
        self,
    ) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
        """Return :func:`comb_design_fun` with every argument except ``fs`` bound from the settings."""
        cfg = self.settings
        bound = {
            "fundamental_freq": cfg.fundamental_freq,
            "num_harmonics": cfg.num_harmonics,
            "q_factor": cfg.q_factor,
            "filter_type": cfg.filter_type,
            "coef_type": cfg.coef_type,
            "quality_scaling": cfg.quality_scaling,
        }
        return functools.partial(comb_design_fun, **bound)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class CombFilterUnit(
    BaseFilterByDesignTransformerUnit[
        CombFilterSettings,
        CombFilterTransformer,
    ]
):
    """ezmsg Unit wrapping :obj:`CombFilterTransformer`, configured by :obj:`CombFilterSettings`."""

    SETTINGS = CombFilterSettings
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def comb(
    axis: str | None,
    fundamental_freq: float = 50.0,
    num_harmonics: int = 3,
    q_factor: float = 35.0,
    filter_type: str = "notch",
    coef_type: str = "sos",
    quality_scaling: str = "constant",
) -> CombFilterTransformer:
    """
    Create a comb filter for enhancing or removing a fundamental frequency and its harmonics.

    Note: this function defaults ``fundamental_freq`` to 50 Hz, whereas
    :obj:`CombFilterSettings` and :func:`comb_design_fun` default to 60 Hz.

    Args:
        axis: Axis to filter along.
        fundamental_freq: Base frequency in Hz.
        num_harmonics: Number of harmonic peaks/notches (including the fundamental).
        q_factor: Quality factor for the peak/notch width.
        filter_type: 'notch' to remove or 'peak' to enhance harmonics.
        coef_type: Coefficient type ('sos' recommended for stability).
        quality_scaling: How to handle bandwidths across harmonics ('constant' or 'proportional').

    Returns:
        CombFilterTransformer
    """
    settings = CombFilterSettings(
        axis=axis,
        fundamental_freq=fundamental_freq,
        num_harmonics=num_harmonics,
        q_factor=q_factor,
        filter_type=filter_type,
        coef_type=coef_type,
        quality_scaling=quality_scaling,
    )
    return CombFilterTransformer(settings)
|
ezmsg/sigproc/decimate.py
CHANGED
|
@@ -3,31 +3,41 @@ import typing
|
|
|
3
3
|
import ezmsg.core as ez
|
|
4
4
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
5
5
|
|
|
6
|
-
from .
|
|
6
|
+
from .base import BaseTransformerUnit
|
|
7
|
+
from .cheby import ChebyshevFilterTransformer, ChebyshevFilterSettings
|
|
7
8
|
from .downsample import Downsample, DownsampleSettings
|
|
8
|
-
from .filter import
|
|
9
|
+
from .filter import BACoeffs, SOSCoeffs
|
|
9
10
|
|
|
10
11
|
|
|
11
|
-
class ChebyForDecimateTransformer(ChebyshevFilterTransformer[BACoeffs | SOSCoeffs]):
    """
    A :obj:`ChebyshevFilterTransformer` whose design function also checks the achievable
    downsampling factor. If the target rate cannot be achieved it returns None; otherwise
    it returns the filter coefficients.
    """

    def get_design_function(
        self,
    ) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
        def cheby_opt_design_fun(fs: float) -> BACoeffs | SOSCoeffs | None:
            # int(fs / (2.5 * Wn)) is the largest integer factor d with fs / d >= 2.5 * Wn.
            # With no sampling rate, or no factor of at least 2, there is nothing to design.
            if fs is None or int(fs / (2.5 * self.settings.Wn)) < 2:
                return None
            # Zero-argument super() cannot be used here: inside this nested function the
            # first argument is fs, not self, so the class and instance are named explicitly.
            parent_design = super(ChebyForDecimateTransformer, self).get_design_function()
            return parent_design(fs)

        return cheby_opt_design_fun
|
|
29
31
|
|
|
30
32
|
|
|
33
|
+
class ChebyForDecimate(
    BaseTransformerUnit[
        ChebyshevFilterSettings,
        AxisArray,
        AxisArray,
        ChebyForDecimateTransformer,
    ]
):
    """ezmsg Unit wrapping :obj:`ChebyForDecimateTransformer` (AxisArray in, AxisArray out)."""

    SETTINGS = ChebyshevFilterSettings
|
|
39
|
+
|
|
40
|
+
|
|
31
41
|
class Decimate(ez.Collection):
|
|
32
42
|
"""
|
|
33
43
|
A :obj:`Collection` chaining a :obj:`Filter` node configured as a lowpass Chebyshev filter
|
|
@@ -43,7 +53,6 @@ class Decimate(ez.Collection):
|
|
|
43
53
|
DOWNSAMPLE = Downsample()
|
|
44
54
|
|
|
45
55
|
def configure(self) -> None:
|
|
46
|
-
|
|
47
56
|
cheby_settings = ChebyshevFilterSettings(
|
|
48
57
|
order=8,
|
|
49
58
|
ripple_tol=0.05,
|
ezmsg/sigproc/detrend.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
import scipy.signal as sps
|
|
2
|
+
from ezmsg.util.messages.axisarray import AxisArray, replace
|
|
3
|
+
from ezmsg.sigproc.ewma import EWMATransformer, EWMASettings
|
|
4
|
+
from ezmsg.sigproc.base import BaseTransformerUnit
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class DetrendTransformer(EWMATransformer):
    """
    Remove a running mean from the data.

    The mean is an exponentially weighted moving average (EWMA), computed as
    ``y[n] = alpha * x[n] + (1 - alpha) * y[n-1]``. The filter state ``zi``
    is carried between messages. The output is the input minus that estimate.
    """

    def _process(self, message):
        state = self._state
        axis_name = self.settings.axis or message.dims[0]
        # b = [alpha], a = [1, alpha - 1] implements the one-pole EWMA recursion above.
        mean_estimate, state.zi = sps.lfilter(
            [state.alpha],
            [1.0, state.alpha - 1.0],
            message.data,
            axis=message.get_axis_idx(axis_name),
            zi=state.zi,
        )
        return replace(message, data=message.data - mean_estimate)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class DetrendUnit(
    BaseTransformerUnit[
        EWMASettings,
        AxisArray,
        AxisArray,
        DetrendTransformer,
    ]
):
    """ezmsg Unit wrapping :obj:`DetrendTransformer`; it reuses :obj:`EWMASettings`."""

    SETTINGS = EWMASettings
|
ezmsg/sigproc/diff.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
import ezmsg.core as ez
|
|
2
|
+
import numpy as np
|
|
3
|
+
import numpy.typing as npt
|
|
4
|
+
from ezmsg.sigproc.base import (
|
|
5
|
+
BaseTransformerUnit,
|
|
6
|
+
processor_state,
|
|
7
|
+
BaseStatefulTransformer,
|
|
8
|
+
)
|
|
9
|
+
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
|
|
10
|
+
from ezmsg.util.messages.util import replace
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class DiffSettings(ez.Settings):
    """Settings for :obj:`DiffTransformer` / :obj:`DiffUnit`."""

    # Name of the axis to difference along. None is intended to mean the
    # message's first dimension (the fallback DiffTransformer._process uses).
    axis: str | None = None
    # If True, divide each difference by its sample interval (per-sample
    # coordinate deltas for axes with `.data`, else the axis `gain`), turning
    # the difference into a derivative estimate.
    scale_by_fs: bool = False
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@processor_state
class DiffState:
    """Values carried between consecutive messages by :obj:`DiffTransformer`."""

    # Last sample of the previous message (length 1 along the diff axis). It is
    # prepended to the next message so its first sample also gets a difference.
    last_dat: npt.NDArray | None = None
    # Axis coordinate of that last sample. Only used when scale_by_fs is set
    # and the axis carries explicit `.data` coordinates.
    last_time: float | None = None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class DiffTransformer(
    BaseStatefulTransformer[DiffSettings, AxisArray, AxisArray, DiffState]
):
    """
    First-order difference along one axis, continuous across message chunks.

    The last sample of each message is kept in state and prepended to the next
    message, so the output has the same length along the axis as the input.
    Right after ``_reset_state`` the first output sample is 0, because the state
    is seeded with the message's own first sample.

    If ``settings.scale_by_fs`` is True, each difference is divided by its
    sample interval, converting e.g. position into velocity. For axes that carry
    ``.data`` coordinates the interval is the per-sample coordinate delta;
    otherwise it is the axis ``gain``.
    """

    def _axis_name(self, message: AxisArray) -> str:
        """Return the configured axis, or the message's first dimension when none is set."""
        return self.settings.axis or message.dims[0]

    def _hash_message(self, message: AxisArray) -> int:
        # Resolve the axis exactly as _process does, so the default axis=None works.
        ax_idx = message.get_axis_idx(self._axis_name(message))
        sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
        return hash((sample_shape, message.key))

    def _reset_state(self, message) -> None:
        axis = self._axis_name(message)
        ax_idx = message.get_axis_idx(axis)
        self.state.last_dat = slice_along_axis(message.data, slice(0, 1), axis=ax_idx)
        if self.settings.scale_by_fs:
            ax_info = message.get_axis(axis)
            if hasattr(ax_info, "data"):
                if len(ax_info.data) > 1:
                    # Extrapolate one step back so the first dt equals the first observed interval.
                    self.state.last_time = 2 * ax_info.data[0] - ax_info.data[1]
                else:
                    # Only one coordinate, so there is no interval to extrapolate from.
                    # Fall back to an arbitrary 0.001 step (1 ms if the axis is in seconds).
                    self.state.last_time = ax_info.data[0] - 0.001

    def _process(self, message: AxisArray) -> AxisArray:
        axis = self._axis_name(message)
        ax_idx = message.get_axis_idx(axis)

        diffs = np.diff(
            np.concatenate((self.state.last_dat, message.data), axis=ax_idx),
            axis=ax_idx,
        )
        # Prepare last_dat for the next message.
        self.state.last_dat = slice_along_axis(
            message.data, slice(-1, None), axis=ax_idx
        )
        # Scale by fs if requested. This converts the diff to a derivative,
        # e.g. the diff of position becomes velocity.
        if self.settings.scale_by_fs:
            if not np.issubdtype(diffs.dtype, np.inexact):
                # np.diff keeps integer dtypes, and in-place true division into an
                # integer array raises. Float dtypes are left unchanged.
                diffs = diffs.astype(np.float64)
            ax_info = message.get_axis(axis)
            if hasattr(ax_info, "data"):
                dt = np.diff(np.concatenate(([self.state.last_time], ax_info.data)))
                # Expand the 1-D dt so it broadcasts along ax_idx only.
                exp_sl = (
                    (None,) * ax_idx
                    + (Ellipsis,)
                    + (None,) * (message.data.ndim - ax_idx - 1)
                )
                diffs /= dt[exp_sl]
                self.state.last_time = ax_info.data[-1]  # For next iteration
            else:
                diffs /= ax_info.gain

        return replace(message, data=diffs)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class DiffUnit(
    BaseTransformerUnit[
        DiffSettings,
        AxisArray,
        AxisArray,
        DiffTransformer,
    ]
):
    """ezmsg Unit wrapping :obj:`DiffTransformer`, configured by :obj:`DiffSettings`."""

    SETTINGS = DiffSettings
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def diff(axis: str = "time", scale_by_fs: bool = False) -> DiffTransformer:
    """
    Build a :obj:`DiffTransformer`.

    Args:
        axis: Axis to difference along (defaults to "time" here, unlike DiffSettings' None).
        scale_by_fs: Divide the differences by the sample interval to estimate a derivative.

    Returns:
        A configured :obj:`DiffTransformer`.
    """
    settings = DiffSettings(axis=axis, scale_by_fs=scale_by_fs)
    return DiffTransformer(settings)
|