ezmsg-sigproc 1.4.2__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/activation.py +2 -2
- ezmsg/sigproc/affinetransform.py +13 -13
- ezmsg/sigproc/aggregate.py +49 -28
- ezmsg/sigproc/bandpower.py +2 -2
- ezmsg/sigproc/butterworthfilter.py +89 -90
- ezmsg/sigproc/cheby.py +119 -0
- ezmsg/sigproc/decimate.py +11 -15
- ezmsg/sigproc/downsample.py +8 -4
- ezmsg/sigproc/ewmfilter.py +9 -5
- ezmsg/sigproc/filter.py +82 -115
- ezmsg/sigproc/filterbank.py +5 -5
- ezmsg/sigproc/math/abs.py +1 -1
- ezmsg/sigproc/math/clip.py +1 -1
- ezmsg/sigproc/math/difference.py +1 -1
- ezmsg/sigproc/math/invert.py +1 -1
- ezmsg/sigproc/math/log.py +1 -1
- ezmsg/sigproc/math/scale.py +1 -1
- ezmsg/sigproc/messages.py +2 -3
- ezmsg/sigproc/sampler.py +16 -15
- ezmsg/sigproc/scaler.py +153 -35
- ezmsg/sigproc/signalinjector.py +7 -7
- ezmsg/sigproc/slicer.py +34 -14
- ezmsg/sigproc/spectrogram.py +6 -6
- ezmsg/sigproc/spectrum.py +18 -14
- ezmsg/sigproc/synth.py +43 -27
- ezmsg/sigproc/wavelets.py +42 -17
- ezmsg/sigproc/window.py +14 -13
- {ezmsg_sigproc-1.4.2.dist-info → ezmsg_sigproc-1.6.0.dist-info}/METADATA +4 -5
- ezmsg_sigproc-1.6.0.dist-info/RECORD +36 -0
- {ezmsg_sigproc-1.4.2.dist-info → ezmsg_sigproc-1.6.0.dist-info}/WHEEL +1 -1
- ezmsg_sigproc-1.4.2.dist-info/RECORD +0 -35
- {ezmsg_sigproc-1.4.2.dist-info → ezmsg_sigproc-1.6.0.dist-info}/licenses/LICENSE.txt +0 -0
ezmsg/sigproc/__version__.py
CHANGED
ezmsg/sigproc/activation.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
from dataclasses import replace
|
|
2
1
|
import typing
|
|
3
2
|
|
|
4
3
|
import numpy as np
|
|
5
4
|
import scipy.special
|
|
6
5
|
import ezmsg.core as ez
|
|
7
6
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
|
+
from ezmsg.util.messages.util import replace
|
|
8
8
|
from ezmsg.util.generator import consumer
|
|
9
9
|
|
|
10
10
|
from .spectral import OptionsEnum
|
|
@@ -41,7 +41,7 @@ ACTIVATIONS = {
|
|
|
41
41
|
|
|
42
42
|
@consumer
|
|
43
43
|
def activation(
|
|
44
|
-
function:
|
|
44
|
+
function: str | ActivationFunction,
|
|
45
45
|
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
46
46
|
"""
|
|
47
47
|
Transform the data with a simple activation function.
|
ezmsg/sigproc/affinetransform.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
from dataclasses import replace
|
|
2
1
|
import os
|
|
3
2
|
from pathlib import Path
|
|
4
3
|
import typing
|
|
@@ -6,7 +5,8 @@ import typing
|
|
|
6
5
|
import numpy as np
|
|
7
6
|
import numpy.typing as npt
|
|
8
7
|
import ezmsg.core as ez
|
|
9
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
8
|
+
from ezmsg.util.messages.axisarray import AxisArray, AxisBase
|
|
9
|
+
from ezmsg.util.messages.util import replace
|
|
10
10
|
from ezmsg.util.generator import consumer
|
|
11
11
|
|
|
12
12
|
from .base import GenAxisArray
|
|
@@ -14,8 +14,8 @@ from .base import GenAxisArray
|
|
|
14
14
|
|
|
15
15
|
@consumer
|
|
16
16
|
def affine_transform(
|
|
17
|
-
weights:
|
|
18
|
-
axis:
|
|
17
|
+
weights: np.ndarray | str | Path,
|
|
18
|
+
axis: str | None = None,
|
|
19
19
|
right_multiply: bool = True,
|
|
20
20
|
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
21
21
|
"""
|
|
@@ -47,7 +47,7 @@ def affine_transform(
|
|
|
47
47
|
|
|
48
48
|
# State variables
|
|
49
49
|
# New axis with transformed labels, if required
|
|
50
|
-
new_axis:
|
|
50
|
+
new_axis: AxisBase | None = None
|
|
51
51
|
|
|
52
52
|
# Reset if any of these change.
|
|
53
53
|
check_input = {"key": None}
|
|
@@ -71,10 +71,10 @@ def affine_transform(
|
|
|
71
71
|
# Determine if we need to modify the transformed axis.
|
|
72
72
|
if (
|
|
73
73
|
axis in msg_in.axes
|
|
74
|
-
and hasattr(msg_in.axes[axis], "
|
|
74
|
+
and hasattr(msg_in.axes[axis], "data")
|
|
75
75
|
and weights.shape[0] != weights.shape[1]
|
|
76
76
|
):
|
|
77
|
-
in_labels = msg_in.axes[axis].
|
|
77
|
+
in_labels = msg_in.axes[axis].data
|
|
78
78
|
new_labels = []
|
|
79
79
|
n_in, n_out = weights.shape
|
|
80
80
|
if len(in_labels) != n_in:
|
|
@@ -101,8 +101,8 @@ def affine_transform(
|
|
|
101
101
|
new_labels.append("")
|
|
102
102
|
elif np.all(b_filled_outputs):
|
|
103
103
|
# Transform is dropping some of the inputs.
|
|
104
|
-
new_labels = np.array(in_labels)[b_used_inputs]
|
|
105
|
-
new_axis = replace(msg_in.axes[axis],
|
|
104
|
+
new_labels = np.array(in_labels)[b_used_inputs]
|
|
105
|
+
new_axis = replace(msg_in.axes[axis], data=np.array(new_labels))
|
|
106
106
|
|
|
107
107
|
data = msg_in.data
|
|
108
108
|
|
|
@@ -133,8 +133,8 @@ class AffineTransformSettings(ez.Settings):
|
|
|
133
133
|
See :obj:`affine_transform` for argument details.
|
|
134
134
|
"""
|
|
135
135
|
|
|
136
|
-
weights:
|
|
137
|
-
axis:
|
|
136
|
+
weights: np.ndarray | str | Path
|
|
137
|
+
axis: str | None = None
|
|
138
138
|
right_multiply: bool = True
|
|
139
139
|
|
|
140
140
|
|
|
@@ -157,7 +157,7 @@ def zeros_for_noop(data: npt.NDArray, **ignore_kwargs) -> npt.NDArray:
|
|
|
157
157
|
|
|
158
158
|
@consumer
|
|
159
159
|
def common_rereference(
|
|
160
|
-
mode: str = "mean", axis:
|
|
160
|
+
mode: str = "mean", axis: str | None = None, include_current: bool = True
|
|
161
161
|
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
162
162
|
"""
|
|
163
163
|
Perform common average referencing (CAR) on streaming data.
|
|
@@ -214,7 +214,7 @@ class CommonRereferenceSettings(ez.Settings):
|
|
|
214
214
|
"""
|
|
215
215
|
|
|
216
216
|
mode: str = "mean"
|
|
217
|
-
axis:
|
|
217
|
+
axis: str | None = None
|
|
218
218
|
include_current: bool = True
|
|
219
219
|
|
|
220
220
|
|
ezmsg/sigproc/aggregate.py
CHANGED
|
@@ -1,11 +1,15 @@
|
|
|
1
|
-
from dataclasses import replace
|
|
2
1
|
import typing
|
|
3
2
|
|
|
4
3
|
import numpy as np
|
|
5
4
|
import numpy.typing as npt
|
|
6
5
|
import ezmsg.core as ez
|
|
7
6
|
from ezmsg.util.generator import consumer
|
|
8
|
-
from ezmsg.util.messages.axisarray import
|
|
7
|
+
from ezmsg.util.messages.axisarray import (
|
|
8
|
+
AxisArray,
|
|
9
|
+
slice_along_axis,
|
|
10
|
+
AxisBase,
|
|
11
|
+
replace,
|
|
12
|
+
)
|
|
9
13
|
|
|
10
14
|
from .spectral import OptionsEnum
|
|
11
15
|
from .base import GenAxisArray
|
|
@@ -52,8 +56,8 @@ AGGREGATORS = {
|
|
|
52
56
|
|
|
53
57
|
@consumer
|
|
54
58
|
def ranged_aggregate(
|
|
55
|
-
axis:
|
|
56
|
-
bands:
|
|
59
|
+
axis: str | None = None,
|
|
60
|
+
bands: list[tuple[float, float]] | None = None,
|
|
57
61
|
operation: AggregationFunction = AggregationFunction.MEAN,
|
|
58
62
|
):
|
|
59
63
|
"""
|
|
@@ -71,12 +75,12 @@ def ranged_aggregate(
|
|
|
71
75
|
msg_out = AxisArray(np.array([]), dims=[""])
|
|
72
76
|
|
|
73
77
|
# State variables
|
|
74
|
-
slices:
|
|
75
|
-
out_axis:
|
|
76
|
-
ax_vec:
|
|
78
|
+
slices: list[tuple[typing.Any, ...]] | None = None
|
|
79
|
+
out_axis: AxisBase | None = None
|
|
80
|
+
ax_vec: npt.NDArray | None = None
|
|
77
81
|
|
|
78
82
|
# Reset if any of these changes. Key not checked because continuity between chunks not required.
|
|
79
|
-
check_inputs = {"gain": None, "offset": None}
|
|
83
|
+
check_inputs = {"gain": None, "offset": None, "len": None, "key": None}
|
|
80
84
|
|
|
81
85
|
while True:
|
|
82
86
|
msg_in: AxisArray = yield msg_out
|
|
@@ -86,35 +90,52 @@ def ranged_aggregate(
|
|
|
86
90
|
axis = axis or msg_in.dims[0]
|
|
87
91
|
target_axis = msg_in.get_axis(axis)
|
|
88
92
|
|
|
89
|
-
|
|
90
|
-
b_reset =
|
|
93
|
+
# Check if we need to reset state
|
|
94
|
+
b_reset = msg_in.key != check_inputs["key"]
|
|
95
|
+
if hasattr(target_axis, "data"):
|
|
96
|
+
b_reset = b_reset or len(target_axis.data) != check_inputs["len"]
|
|
97
|
+
elif isinstance(target_axis, AxisArray.LinearAxis):
|
|
98
|
+
b_reset = b_reset or target_axis.gain != check_inputs["gain"]
|
|
99
|
+
b_reset = b_reset or target_axis.offset != check_inputs["offset"]
|
|
100
|
+
|
|
91
101
|
if b_reset:
|
|
92
|
-
|
|
93
|
-
check_inputs["
|
|
102
|
+
# Update check variables
|
|
103
|
+
check_inputs["key"] = msg_in.key
|
|
104
|
+
if hasattr(target_axis, "data"):
|
|
105
|
+
check_inputs["len"] = len(target_axis.data)
|
|
106
|
+
else:
|
|
107
|
+
check_inputs["gain"] = target_axis.gain
|
|
108
|
+
check_inputs["offset"] = target_axis.offset
|
|
94
109
|
|
|
95
110
|
# If the axis we are operating on has changed (e.g., "time" or "win" axis always changes),
|
|
96
111
|
# or the key has changed, then recalculate slices.
|
|
97
112
|
|
|
98
113
|
ax_idx = msg_in.get_axis_idx(axis)
|
|
99
114
|
|
|
100
|
-
|
|
101
|
-
target_axis.
|
|
102
|
-
|
|
103
|
-
|
|
115
|
+
if hasattr(target_axis, "data"):
|
|
116
|
+
ax_vec = target_axis.data
|
|
117
|
+
else:
|
|
118
|
+
ax_vec = target_axis.value(np.arange(msg_in.data.shape[ax_idx]))
|
|
119
|
+
|
|
104
120
|
slices = []
|
|
105
|
-
|
|
121
|
+
ax_dat = []
|
|
106
122
|
for start, stop in bands:
|
|
107
123
|
inds = np.where(np.logical_and(ax_vec >= start, ax_vec <= stop))[0]
|
|
108
|
-
mids.append(np.mean(inds) * target_axis.gain + target_axis.offset)
|
|
109
124
|
slices.append(np.s_[inds[0] : inds[-1] + 1])
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
125
|
+
if hasattr(target_axis, "data"):
|
|
126
|
+
if ax_vec.dtype.type is np.str_:
|
|
127
|
+
sl_dat = f"{ax_vec[start]} - {ax_vec[stop]}"
|
|
128
|
+
else:
|
|
129
|
+
sl_dat = ax_dat.append(np.mean(ax_vec[inds]))
|
|
130
|
+
else:
|
|
131
|
+
sl_dat = target_axis.value(np.mean(inds))
|
|
132
|
+
ax_dat.append(sl_dat)
|
|
133
|
+
|
|
134
|
+
out_axis = AxisArray.CoordinateAxis(
|
|
135
|
+
data=np.array(ax_dat),
|
|
136
|
+
dims=[axis],
|
|
137
|
+
unit=target_axis.unit,
|
|
138
|
+
)
|
|
118
139
|
|
|
119
140
|
agg_func = AGGREGATORS[operation]
|
|
120
141
|
out_data = [
|
|
@@ -142,8 +163,8 @@ class RangedAggregateSettings(ez.Settings):
|
|
|
142
163
|
See :obj:`ranged_aggregate` for details.
|
|
143
164
|
"""
|
|
144
165
|
|
|
145
|
-
axis:
|
|
146
|
-
bands:
|
|
166
|
+
axis: str | None = None
|
|
167
|
+
bands: list[tuple[float, float]] | None = None
|
|
147
168
|
operation: AggregationFunction = AggregationFunction.MEAN
|
|
148
169
|
|
|
149
170
|
|
ezmsg/sigproc/bandpower.py
CHANGED
|
@@ -14,7 +14,7 @@ from .base import GenAxisArray
|
|
|
14
14
|
@consumer
|
|
15
15
|
def bandpower(
|
|
16
16
|
spectrogram_settings: SpectrogramSettings,
|
|
17
|
-
bands:
|
|
17
|
+
bands: list[tuple[float, float]] | None = [
|
|
18
18
|
(17, 30),
|
|
19
19
|
(70, 170),
|
|
20
20
|
],
|
|
@@ -58,7 +58,7 @@ class BandPowerSettings(ez.Settings):
|
|
|
58
58
|
spectrogram_settings: SpectrogramSettings = field(
|
|
59
59
|
default_factory=SpectrogramSettings
|
|
60
60
|
)
|
|
61
|
-
bands:
|
|
61
|
+
bands: list[tuple[float, float]] | None = field(
|
|
62
62
|
default_factory=lambda: [(17, 30), (70, 170)]
|
|
63
63
|
)
|
|
64
64
|
|
|
@@ -1,15 +1,19 @@
|
|
|
1
|
+
import functools
|
|
1
2
|
import typing
|
|
2
3
|
|
|
3
|
-
import ezmsg.core as ez
|
|
4
4
|
import scipy.signal
|
|
5
|
-
import numpy as np
|
|
6
5
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
|
-
from
|
|
6
|
+
from scipy.signal import normalize
|
|
8
7
|
|
|
9
|
-
from .filter import
|
|
8
|
+
from .filter import (
|
|
9
|
+
FilterBaseSettings,
|
|
10
|
+
FilterCoefsMultiType,
|
|
11
|
+
FilterBase,
|
|
12
|
+
filter_gen_by_design,
|
|
13
|
+
)
|
|
10
14
|
|
|
11
15
|
|
|
12
|
-
class ButterworthFilterSettings(
|
|
16
|
+
class ButterworthFilterSettings(FilterBaseSettings):
|
|
13
17
|
"""Settings for :obj:`ButterworthFilter`."""
|
|
14
18
|
|
|
15
19
|
order: int = 0
|
|
@@ -17,25 +21,28 @@ class ButterworthFilterSettings(FilterSettingsBase):
|
|
|
17
21
|
Filter order
|
|
18
22
|
"""
|
|
19
23
|
|
|
20
|
-
cuton:
|
|
24
|
+
cuton: float | None = None
|
|
21
25
|
"""
|
|
22
26
|
Cuton frequency (Hz). If `cutoff` is not specified then this is the highpass corner. Otherwise,
|
|
23
27
|
if this is lower than `cutoff` then this is the beginning of the bandpass
|
|
24
28
|
or if this is greater than `cutoff` then this is the end of the bandstop.
|
|
25
29
|
"""
|
|
26
30
|
|
|
27
|
-
cutoff:
|
|
31
|
+
cutoff: float | None = None
|
|
28
32
|
"""
|
|
29
33
|
Cutoff frequency (Hz). If `cuton` is not specified then this is the lowpass corner. Otherwise,
|
|
30
34
|
if this is greater than `cuton` then this is the end of the bandpass,
|
|
31
35
|
or if this is less than `cuton` then this is the beginning of the bandstop.
|
|
32
36
|
"""
|
|
33
37
|
|
|
38
|
+
wn_hz: bool = True
|
|
39
|
+
"""
|
|
40
|
+
Set False if provided Wn are normalized from 0 to 1, where 1 is the Nyquist frequency
|
|
41
|
+
"""
|
|
42
|
+
|
|
34
43
|
def filter_specs(
|
|
35
44
|
self,
|
|
36
|
-
) ->
|
|
37
|
-
typing.Tuple[str, typing.Union[float, typing.Tuple[float, float]]]
|
|
38
|
-
]:
|
|
45
|
+
) -> tuple[str, float | tuple[float, float]] | None:
|
|
39
46
|
"""
|
|
40
47
|
Determine the filter type given the corner frequencies.
|
|
41
48
|
|
|
@@ -58,21 +65,81 @@ class ButterworthFilterSettings(FilterSettingsBase):
|
|
|
58
65
|
return "bandstop", (self.cutoff, self.cuton)
|
|
59
66
|
|
|
60
67
|
|
|
61
|
-
|
|
68
|
+
def butter_design_fun(
|
|
69
|
+
fs: float,
|
|
70
|
+
order: int = 0,
|
|
71
|
+
cuton: float | None = None,
|
|
72
|
+
cutoff: float | None = None,
|
|
73
|
+
coef_type: str = "ba",
|
|
74
|
+
wn_hz: bool = True,
|
|
75
|
+
) -> FilterCoefsMultiType | None:
|
|
76
|
+
"""
|
|
77
|
+
See :obj:`ButterworthFilterSettings.filter_specs` for an explanation of specifying different
|
|
78
|
+
filter types (lowpass, highpass, bandpass, bandstop) from the parameters.
|
|
79
|
+
You are likely to want to use this function with :obj:`filter_by_design`, which only passes `fs` to the design
|
|
80
|
+
function (this), meaning that you should wrap this function with a lambda or prepare with functools.partial.
|
|
81
|
+
|
|
82
|
+
Args:
|
|
83
|
+
fs: The sampling frequency of the data in Hz.
|
|
84
|
+
order: Filter order.
|
|
85
|
+
cuton: Corner frequency of the filter in Hz.
|
|
86
|
+
cutoff: Corner frequency of the filter in Hz.
|
|
87
|
+
coef_type: "ba", "sos", or "zpk"
|
|
88
|
+
wn_hz: Set False if provided Wn are normalized from 0 to 1, where 1 is the Nyquist frequency
|
|
89
|
+
|
|
90
|
+
Returns:
|
|
91
|
+
The filter coefficients as a tuple of (b, a) for coef_type "ba", or as a single ndarray for "sos",
|
|
92
|
+
or (z, p, k) for "zpk".
|
|
93
|
+
|
|
94
|
+
"""
|
|
95
|
+
coefs = None
|
|
96
|
+
if order > 0:
|
|
97
|
+
btype, cutoffs = ButterworthFilterSettings(
|
|
98
|
+
order=order, cuton=cuton, cutoff=cutoff
|
|
99
|
+
).filter_specs()
|
|
100
|
+
coefs = scipy.signal.butter(
|
|
101
|
+
order,
|
|
102
|
+
Wn=cutoffs,
|
|
103
|
+
btype=btype,
|
|
104
|
+
fs=fs if wn_hz else None,
|
|
105
|
+
output=coef_type,
|
|
106
|
+
)
|
|
107
|
+
if coefs is not None and coef_type == "ba":
|
|
108
|
+
coefs = normalize(*coefs)
|
|
109
|
+
return coefs
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class ButterworthFilter(FilterBase):
|
|
113
|
+
SETTINGS = ButterworthFilterSettings
|
|
114
|
+
|
|
115
|
+
def design_filter(
|
|
116
|
+
self,
|
|
117
|
+
) -> typing.Callable[[float], FilterCoefsMultiType | None]:
|
|
118
|
+
return functools.partial(
|
|
119
|
+
butter_design_fun,
|
|
120
|
+
order=self.SETTINGS.order,
|
|
121
|
+
cuton=self.SETTINGS.cuton,
|
|
122
|
+
cutoff=self.SETTINGS.cutoff,
|
|
123
|
+
coef_type=self.SETTINGS.coef_type,
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
|
|
62
127
|
def butter(
|
|
63
|
-
axis:
|
|
128
|
+
axis: str | None,
|
|
64
129
|
order: int = 0,
|
|
65
|
-
cuton:
|
|
66
|
-
cutoff:
|
|
130
|
+
cuton: float | None = None,
|
|
131
|
+
cutoff: float | None = None,
|
|
67
132
|
coef_type: str = "ba",
|
|
68
133
|
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
69
134
|
"""
|
|
135
|
+
Convenience generator wrapping filter_gen_by_design for Butterworth filters.
|
|
70
136
|
Apply Butterworth filter to streaming data. Uses :obj:`scipy.signal.butter` to design the filter.
|
|
71
137
|
See :obj:`ButterworthFilterSettings.filter_specs` for an explanation of specifying different
|
|
72
138
|
filter types (lowpass, highpass, bandpass, bandstop) from the parameters.
|
|
73
139
|
|
|
74
140
|
Args:
|
|
75
141
|
axis: The name of the axis to filter.
|
|
142
|
+
Note: The axis must be represented in the message .axes and be of type AxisArray.LinearAxis.
|
|
76
143
|
order: Filter order.
|
|
77
144
|
cuton: Corner frequency of the filter in Hz.
|
|
78
145
|
cutoff: Corner frequency of the filter in Hz.
|
|
@@ -83,79 +150,11 @@ def butter(
|
|
|
83
150
|
and yields an :obj:`AxisArray` with filtered data.
|
|
84
151
|
|
|
85
152
|
"""
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
)
|
|
93
|
-
|
|
94
|
-
# State variables
|
|
95
|
-
# Initialize filtergen as passthrough until we can calculate coefs.
|
|
96
|
-
filter_gen = filtergen(axis, None, coef_type)
|
|
97
|
-
|
|
98
|
-
# Reset if these change.
|
|
99
|
-
check_input = {"gain": None}
|
|
100
|
-
# Key not checked because filter_gen will handle resetting if .key changes.
|
|
101
|
-
|
|
102
|
-
while True:
|
|
103
|
-
msg_in: AxisArray = yield msg_out
|
|
104
|
-
axis = axis or msg_in.dims[0]
|
|
105
|
-
|
|
106
|
-
b_reset = msg_in.axes[axis].gain != check_input["gain"]
|
|
107
|
-
b_reset = b_reset and order > 0 # Not passthrough
|
|
108
|
-
if b_reset:
|
|
109
|
-
check_input["gain"] = msg_in.axes[axis].gain
|
|
110
|
-
coefs = scipy.signal.butter(
|
|
111
|
-
order,
|
|
112
|
-
Wn=cutoffs,
|
|
113
|
-
btype=btype,
|
|
114
|
-
fs=1 / msg_in.axes[axis].gain,
|
|
115
|
-
output=coef_type,
|
|
116
|
-
)
|
|
117
|
-
filter_gen = filtergen(axis, coefs, coef_type)
|
|
118
|
-
|
|
119
|
-
msg_out = filter_gen.send(msg_in)
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
class ButterworthFilterState(FilterState):
|
|
123
|
-
design: ButterworthFilterSettings
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
class ButterworthFilter(Filter):
|
|
127
|
-
""":obj:`Unit` for :obj:`butterworth`"""
|
|
128
|
-
|
|
129
|
-
SETTINGS = ButterworthFilterSettings
|
|
130
|
-
STATE = ButterworthFilterState
|
|
131
|
-
|
|
132
|
-
INPUT_FILTER = ez.InputStream(ButterworthFilterSettings)
|
|
133
|
-
|
|
134
|
-
async def initialize(self) -> None:
|
|
135
|
-
self.STATE.design = self.SETTINGS
|
|
136
|
-
self.STATE.filt_designed = True
|
|
137
|
-
await super().initialize()
|
|
138
|
-
|
|
139
|
-
def design_filter(self) -> typing.Optional[typing.Tuple[np.ndarray, np.ndarray]]:
|
|
140
|
-
specs = self.STATE.design.filter_specs()
|
|
141
|
-
if self.STATE.design.order > 0 and specs is not None:
|
|
142
|
-
btype, cut = specs
|
|
143
|
-
return scipy.signal.butter(
|
|
144
|
-
self.STATE.design.order,
|
|
145
|
-
Wn=cut,
|
|
146
|
-
btype=btype,
|
|
147
|
-
fs=self.STATE.fs,
|
|
148
|
-
output="ba",
|
|
149
|
-
)
|
|
150
|
-
else:
|
|
151
|
-
return None
|
|
152
|
-
|
|
153
|
-
@ez.subscriber(INPUT_FILTER)
|
|
154
|
-
async def redesign(self, message: ButterworthFilterSettings) -> None:
|
|
155
|
-
if type(message) is not ButterworthFilterSettings:
|
|
156
|
-
return
|
|
157
|
-
|
|
158
|
-
if self.STATE.design.order != message.order:
|
|
159
|
-
self.STATE.zi = None
|
|
160
|
-
self.STATE.design = message
|
|
161
|
-
self.update_filter()
|
|
153
|
+
design_fun = functools.partial(
|
|
154
|
+
butter_design_fun,
|
|
155
|
+
order=order,
|
|
156
|
+
cuton=cuton,
|
|
157
|
+
cutoff=cutoff,
|
|
158
|
+
coef_type=coef_type,
|
|
159
|
+
)
|
|
160
|
+
return filter_gen_by_design(axis, coef_type, design_fun)
|
ezmsg/sigproc/cheby.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import typing
|
|
3
|
+
|
|
4
|
+
import scipy.signal
|
|
5
|
+
from scipy.signal import normalize
|
|
6
|
+
|
|
7
|
+
from .filter import (
|
|
8
|
+
FilterBaseSettings,
|
|
9
|
+
FilterCoefsMultiType,
|
|
10
|
+
FilterBase,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ChebyshevFilterSettings(FilterBaseSettings):
|
|
15
|
+
"""Settings for :obj:`ButterworthFilter`."""
|
|
16
|
+
|
|
17
|
+
order: int = 0
|
|
18
|
+
"""
|
|
19
|
+
Filter order
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
ripple_tol: float | None = None
|
|
23
|
+
"""
|
|
24
|
+
The maximum ripple allowed below unity gain in the passband. Specified in decibels, as a positive number.
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
Wn: float | tuple[float, float] | None = None
|
|
28
|
+
"""
|
|
29
|
+
A scalar or length-2 sequence giving the critical frequencies.
|
|
30
|
+
For Type I filters, this is the point in the transition band at which the gain first drops below -rp.
|
|
31
|
+
For digital filters, Wn are in the same units as fs unless wn_hz is False.
|
|
32
|
+
For analog filters, Wn is an angular frequency (e.g., rad/s).
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
btype: str = "lowpass"
|
|
36
|
+
"""
|
|
37
|
+
{‘lowpass’, ‘highpass’, ‘bandpass’, ‘bandstop’}
|
|
38
|
+
"""
|
|
39
|
+
|
|
40
|
+
analog: bool = False
|
|
41
|
+
"""
|
|
42
|
+
When True, return an analog filter, otherwise a digital filter is returned.
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
cheby_type: str = "cheby1"
|
|
46
|
+
"""
|
|
47
|
+
Which type of Chebyshev filter to design. Either "cheby1" or "cheby2".
|
|
48
|
+
"""
|
|
49
|
+
|
|
50
|
+
wn_hz: bool = True
|
|
51
|
+
"""
|
|
52
|
+
Set False if provided Wn are normalized from 0 to 1, where 1 is the Nyquist frequency
|
|
53
|
+
"""
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def cheby_design_fun(
|
|
57
|
+
fs: float,
|
|
58
|
+
order: int = 0,
|
|
59
|
+
ripple_tol: float | None = None,
|
|
60
|
+
Wn: float | tuple[float, float] | None = None,
|
|
61
|
+
btype: str = "lowpass",
|
|
62
|
+
analog: bool = False,
|
|
63
|
+
coef_type: str = "ba",
|
|
64
|
+
cheby_type: str = "cheby1",
|
|
65
|
+
wn_hz: bool = True,
|
|
66
|
+
) -> FilterCoefsMultiType:
|
|
67
|
+
"""
|
|
68
|
+
Chebyshev type I and type II digital and analog filter design.
|
|
69
|
+
Design an `order`th-order digital or analog Chebyshev type I or type II filter and return the filter coefficients.
|
|
70
|
+
See :obj:`ChebyFilterSettings` for argument description.
|
|
71
|
+
|
|
72
|
+
Returns:
|
|
73
|
+
The filter coefficients as a tuple of (b, a) for coef_type "ba", or as a single ndarray for "sos",
|
|
74
|
+
or (z, p, k) for "zpk".
|
|
75
|
+
"""
|
|
76
|
+
coefs = None
|
|
77
|
+
if order > 0:
|
|
78
|
+
if cheby_type == "cheby1":
|
|
79
|
+
coefs = scipy.signal.cheby1(
|
|
80
|
+
order,
|
|
81
|
+
ripple_tol,
|
|
82
|
+
Wn,
|
|
83
|
+
btype=btype,
|
|
84
|
+
analog=analog,
|
|
85
|
+
output=coef_type,
|
|
86
|
+
fs=fs if wn_hz else None,
|
|
87
|
+
)
|
|
88
|
+
elif cheby_type == "cheby2":
|
|
89
|
+
coefs = scipy.signal.cheby2(
|
|
90
|
+
order,
|
|
91
|
+
ripple_tol,
|
|
92
|
+
Wn,
|
|
93
|
+
btype=btype,
|
|
94
|
+
analog=analog,
|
|
95
|
+
output=coef_type,
|
|
96
|
+
fs=fs,
|
|
97
|
+
)
|
|
98
|
+
if coefs is not None and coef_type == "ba":
|
|
99
|
+
coefs = normalize(*coefs)
|
|
100
|
+
return coefs
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class ChebyshevFilter(FilterBase):
|
|
104
|
+
SETTINGS = ChebyshevFilterSettings
|
|
105
|
+
|
|
106
|
+
def design_filter(
|
|
107
|
+
self,
|
|
108
|
+
) -> typing.Callable[[float], FilterCoefsMultiType | None]:
|
|
109
|
+
return functools.partial(
|
|
110
|
+
cheby_design_fun,
|
|
111
|
+
order=self.SETTINGS.order,
|
|
112
|
+
ripple_tol=self.SETTINGS.ripple_tol,
|
|
113
|
+
Wn=self.SETTINGS.Wn,
|
|
114
|
+
btype=self.SETTINGS.btype,
|
|
115
|
+
analog=self.SETTINGS.analog,
|
|
116
|
+
coef_type=self.SETTINGS.coef_type,
|
|
117
|
+
cheby_type=self.SETTINGS.cheby_type,
|
|
118
|
+
wn_hz=self.SETTINGS.wn_hz,
|
|
119
|
+
)
|
ezmsg/sigproc/decimate.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
|
1
|
-
import scipy.signal
|
|
2
1
|
import ezmsg.core as ez
|
|
3
2
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
4
3
|
|
|
4
|
+
from .cheby import ChebyshevFilter, ChebyshevFilterSettings
|
|
5
5
|
from .downsample import Downsample, DownsampleSettings
|
|
6
|
-
from .filter import Filter, FilterCoefficients, FilterSettings
|
|
7
6
|
|
|
8
7
|
|
|
9
8
|
class Decimate(ez.Collection):
|
|
@@ -17,24 +16,21 @@ class Decimate(ez.Collection):
|
|
|
17
16
|
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
18
17
|
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
19
18
|
|
|
20
|
-
FILTER =
|
|
19
|
+
FILTER = ChebyshevFilter()
|
|
21
20
|
DOWNSAMPLE = Downsample()
|
|
22
21
|
|
|
23
22
|
def configure(self) -> None:
|
|
23
|
+
cheby_settings = ChebyshevFilterSettings(
|
|
24
|
+
order=8 if self.SETTINGS.factor > 1 else 0,
|
|
25
|
+
ripple_tol=0.05,
|
|
26
|
+
Wn=0.8 / self.SETTINGS.factor if self.SETTINGS.factor > 1 else None,
|
|
27
|
+
btype="lowpass",
|
|
28
|
+
axis=self.SETTINGS.axis,
|
|
29
|
+
wn_hz=False,
|
|
30
|
+
)
|
|
31
|
+
self.FILTER.apply_settings(cheby_settings)
|
|
24
32
|
self.DOWNSAMPLE.apply_settings(self.SETTINGS)
|
|
25
33
|
|
|
26
|
-
if self.SETTINGS.factor < 1:
|
|
27
|
-
raise ValueError("Decimation factor must be >= 1 (no decimation")
|
|
28
|
-
elif self.SETTINGS.factor == 1:
|
|
29
|
-
filt = FilterCoefficients()
|
|
30
|
-
else:
|
|
31
|
-
# See scipy.signal.decimate for IIR Filter Condition
|
|
32
|
-
b, a = scipy.signal.cheby1(8, 0.05, 0.8 / self.SETTINGS.factor)
|
|
33
|
-
system = scipy.signal.dlti(b, a)
|
|
34
|
-
filt = FilterCoefficients(b=system.num, a=system.den) # type: ignore
|
|
35
|
-
|
|
36
|
-
self.FILTER.apply_settings(FilterSettings(filt=filt))
|
|
37
|
-
|
|
38
34
|
def network(self) -> ez.NetworkDefinition:
|
|
39
35
|
return (
|
|
40
36
|
(self.INPUT_SIGNAL, self.FILTER.INPUT_SIGNAL),
|