ezmsg-sigproc 1.2.2__py3-none-any.whl → 1.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__init__.py +1 -1
- ezmsg/sigproc/__version__.py +16 -1
- ezmsg/sigproc/activation.py +75 -0
- ezmsg/sigproc/affinetransform.py +234 -0
- ezmsg/sigproc/aggregate.py +158 -0
- ezmsg/sigproc/bandpower.py +74 -0
- ezmsg/sigproc/base.py +38 -0
- ezmsg/sigproc/butterworthfilter.py +102 -11
- ezmsg/sigproc/decimate.py +7 -4
- ezmsg/sigproc/downsample.py +95 -51
- ezmsg/sigproc/ewmfilter.py +38 -16
- ezmsg/sigproc/filter.py +108 -20
- ezmsg/sigproc/filterbank.py +278 -0
- ezmsg/sigproc/math/__init__.py +0 -0
- ezmsg/sigproc/math/abs.py +28 -0
- ezmsg/sigproc/math/clip.py +30 -0
- ezmsg/sigproc/math/difference.py +60 -0
- ezmsg/sigproc/math/invert.py +29 -0
- ezmsg/sigproc/math/log.py +32 -0
- ezmsg/sigproc/math/scale.py +31 -0
- ezmsg/sigproc/messages.py +2 -3
- ezmsg/sigproc/sampler.py +259 -224
- ezmsg/sigproc/scaler.py +173 -0
- ezmsg/sigproc/signalinjector.py +64 -0
- ezmsg/sigproc/slicer.py +133 -0
- ezmsg/sigproc/spectral.py +6 -132
- ezmsg/sigproc/spectrogram.py +86 -0
- ezmsg/sigproc/spectrum.py +259 -0
- ezmsg/sigproc/synth.py +299 -105
- ezmsg/sigproc/wavelets.py +167 -0
- ezmsg/sigproc/window.py +254 -116
- ezmsg_sigproc-1.3.1.dist-info/METADATA +59 -0
- ezmsg_sigproc-1.3.1.dist-info/RECORD +35 -0
- {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-1.3.1.dist-info}/WHEEL +1 -2
- ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
- ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
- ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
- {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-1.3.1.dist-info/licenses}/LICENSE.txt +0 -0
|
@@ -1,18 +1,47 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
|
|
1
3
|
import ezmsg.core as ez
|
|
2
4
|
import scipy.signal
|
|
3
5
|
import numpy as np
|
|
6
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
|
+
from ezmsg.util.generator import consumer
|
|
4
8
|
|
|
5
|
-
from .filter import Filter, FilterState, FilterSettingsBase
|
|
6
|
-
|
|
7
|
-
from typing import Optional, Tuple, Union
|
|
9
|
+
from .filter import filtergen, Filter, FilterState, FilterSettingsBase
|
|
8
10
|
|
|
9
11
|
|
|
10
12
|
class ButterworthFilterSettings(FilterSettingsBase):
|
|
13
|
+
"""Settings for :obj:`ButterworthFilter`."""
|
|
14
|
+
|
|
11
15
|
order: int = 0
|
|
12
|
-
cuton: Optional[float] = None # Hz
|
|
13
|
-
cutoff: Optional[float] = None # Hz
|
|
14
16
|
|
|
15
|
-
|
|
17
|
+
cuton: typing.Optional[float] = None
|
|
18
|
+
"""
|
|
19
|
+
Cuton frequency (Hz). If cutoff is not specified then this is the highpass corner, otherwise
|
|
20
|
+
if it is lower than cutoff then this is the beginning of the bandpass
|
|
21
|
+
or if it is greater than cuton then it is the end of the bandstop.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
cutoff: typing.Optional[float] = None
|
|
25
|
+
"""
|
|
26
|
+
Cutoff frequency (Hz). If cuton is not specified then this is the lowpass corner, otherwise
|
|
27
|
+
if it is greater than cuton then this is the end of the bandpass,
|
|
28
|
+
or if it is less than cuton then it is the beginning of the bandstop.
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
def filter_specs(
|
|
32
|
+
self,
|
|
33
|
+
) -> typing.Optional[
|
|
34
|
+
typing.Tuple[str, typing.Union[float, typing.Tuple[float, float]]]
|
|
35
|
+
]:
|
|
36
|
+
"""
|
|
37
|
+
Determine the filter type given the corner frequencies.
|
|
38
|
+
|
|
39
|
+
Returns:
|
|
40
|
+
A tuple with the first element being a string indicating the filter type
|
|
41
|
+
(one of "lowpass", "highpass", "bandpass", "bandstop")
|
|
42
|
+
and the second element being the corner frequency or frequencies.
|
|
43
|
+
|
|
44
|
+
"""
|
|
16
45
|
if self.cuton is None and self.cutoff is None:
|
|
17
46
|
return None
|
|
18
47
|
elif self.cuton is None and self.cutoff is not None:
|
|
@@ -26,22 +55,84 @@ class ButterworthFilterSettings(FilterSettingsBase):
|
|
|
26
55
|
return "bandstop", (self.cutoff, self.cuton)
|
|
27
56
|
|
|
28
57
|
|
|
58
|
+
@consumer
|
|
59
|
+
def butter(
|
|
60
|
+
axis: typing.Optional[str],
|
|
61
|
+
order: int = 0,
|
|
62
|
+
cuton: typing.Optional[float] = None,
|
|
63
|
+
cutoff: typing.Optional[float] = None,
|
|
64
|
+
coef_type: str = "ba",
|
|
65
|
+
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
66
|
+
"""
|
|
67
|
+
Apply Butterworth filter to streaming data. Uses :obj:`scipy.signal.butter` to design the filter.
|
|
68
|
+
See :obj:`ButterworthFilterSettings.filter_specs` for an explanation of specifying different
|
|
69
|
+
filter types (lowpass, highpass, bandpass, bandstop) from the parameters.
|
|
70
|
+
|
|
71
|
+
Args:
|
|
72
|
+
axis: The name of the axis to filter.
|
|
73
|
+
order: Filter order.
|
|
74
|
+
cuton: Corner frequency of the filter in Hz.
|
|
75
|
+
cutoff: Corner frequency of the filter in Hz.
|
|
76
|
+
coef_type: "ba" or "sos"
|
|
77
|
+
|
|
78
|
+
Returns:
|
|
79
|
+
A primed generator object which accepts .send(axis_array) and yields filtered axis array.
|
|
80
|
+
|
|
81
|
+
"""
|
|
82
|
+
# IO
|
|
83
|
+
msg_out = AxisArray(np.array([]), dims=[""])
|
|
84
|
+
|
|
85
|
+
# Check parameters
|
|
86
|
+
btype, cutoffs = ButterworthFilterSettings(
|
|
87
|
+
order=order, cuton=cuton, cutoff=cutoff
|
|
88
|
+
).filter_specs()
|
|
89
|
+
|
|
90
|
+
# State variables
|
|
91
|
+
# Initialize filtergen as passthrough until we can calculate coefs.
|
|
92
|
+
filter_gen = filtergen(axis, None, coef_type)
|
|
93
|
+
|
|
94
|
+
# Reset if these change.
|
|
95
|
+
check_input = {"gain": None}
|
|
96
|
+
# Key not checked because filter_gen will handle resetting if .key changes.
|
|
97
|
+
|
|
98
|
+
while True:
|
|
99
|
+
msg_in: AxisArray = yield msg_out
|
|
100
|
+
axis = axis or msg_in.dims[0]
|
|
101
|
+
|
|
102
|
+
b_reset = msg_in.axes[axis].gain != check_input["gain"]
|
|
103
|
+
b_reset = b_reset and order > 0 # Not passthrough
|
|
104
|
+
if b_reset:
|
|
105
|
+
check_input["gain"] = msg_in.axes[axis].gain
|
|
106
|
+
coefs = scipy.signal.butter(
|
|
107
|
+
order,
|
|
108
|
+
Wn=cutoffs,
|
|
109
|
+
btype=btype,
|
|
110
|
+
fs=1 / msg_in.axes[axis].gain,
|
|
111
|
+
output=coef_type,
|
|
112
|
+
)
|
|
113
|
+
filter_gen = filtergen(axis, coefs, coef_type)
|
|
114
|
+
|
|
115
|
+
msg_out = filter_gen.send(msg_in)
|
|
116
|
+
|
|
117
|
+
|
|
29
118
|
class ButterworthFilterState(FilterState):
|
|
30
119
|
design: ButterworthFilterSettings
|
|
31
120
|
|
|
32
121
|
|
|
33
122
|
class ButterworthFilter(Filter):
|
|
34
|
-
|
|
35
|
-
|
|
123
|
+
""":obj:`Unit` for :obj:`butterworth`"""
|
|
124
|
+
|
|
125
|
+
SETTINGS = ButterworthFilterSettings
|
|
126
|
+
STATE = ButterworthFilterState
|
|
36
127
|
|
|
37
128
|
INPUT_FILTER = ez.InputStream(ButterworthFilterSettings)
|
|
38
129
|
|
|
39
|
-
def initialize(self) -> None:
|
|
130
|
+
async def initialize(self) -> None:
|
|
40
131
|
self.STATE.design = self.SETTINGS
|
|
41
132
|
self.STATE.filt_designed = True
|
|
42
|
-
super().initialize()
|
|
133
|
+
await super().initialize()
|
|
43
134
|
|
|
44
|
-
def design_filter(self) -> Optional[Tuple[np.ndarray, np.ndarray]]:
|
|
135
|
+
def design_filter(self) -> typing.Optional[typing.Tuple[np.ndarray, np.ndarray]]:
|
|
45
136
|
specs = self.STATE.design.filter_specs()
|
|
46
137
|
if self.STATE.design.order > 0 and specs is not None:
|
|
47
138
|
btype, cut = specs
|
ezmsg/sigproc/decimate.py
CHANGED
|
@@ -1,7 +1,5 @@
|
|
|
1
|
-
import ezmsg.core as ez
|
|
2
|
-
|
|
3
1
|
import scipy.signal
|
|
4
|
-
|
|
2
|
+
import ezmsg.core as ez
|
|
5
3
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
6
4
|
|
|
7
5
|
from .downsample import Downsample, DownsampleSettings
|
|
@@ -9,7 +7,12 @@ from .filter import Filter, FilterCoefficients, FilterSettings
|
|
|
9
7
|
|
|
10
8
|
|
|
11
9
|
class Decimate(ez.Collection):
|
|
12
|
-
|
|
10
|
+
"""
|
|
11
|
+
A :obj:`Collection` chaining a :obj:`Filter` node configured as a lowpass Chebyshev filter
|
|
12
|
+
and a :obj:`Downsample` node.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
SETTINGS = DownsampleSettings
|
|
13
16
|
|
|
14
17
|
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
15
18
|
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
ezmsg/sigproc/downsample.py
CHANGED
|
@@ -1,63 +1,107 @@
|
|
|
1
1
|
from dataclasses import replace
|
|
2
|
+
import typing
|
|
2
3
|
|
|
3
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
4
|
-
|
|
5
|
-
import ezmsg.core as ez
|
|
6
4
|
import numpy as np
|
|
5
|
+
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
|
|
6
|
+
from ezmsg.util.generator import consumer
|
|
7
|
+
import ezmsg.core as ez
|
|
7
8
|
|
|
8
|
-
from
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
9
|
+
from .base import GenAxisArray
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@consumer
|
|
13
|
+
def downsample(
|
|
14
|
+
axis: typing.Optional[str] = None, factor: int = 1
|
|
15
|
+
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
16
|
+
"""
|
|
17
|
+
Construct a generator that yields a downsampled version of the data .send() to it.
|
|
18
|
+
Downsampled data simply comprise every `factor`th sample.
|
|
19
|
+
This should only be used following appropriate lowpass filtering.
|
|
20
|
+
If your pipeline does not already have lowpass filtering then consider
|
|
21
|
+
using the :obj:`Decimate` collection instead.
|
|
22
|
+
|
|
23
|
+
Args:
|
|
24
|
+
axis: The name of the axis along which to downsample.
|
|
25
|
+
factor: Downsampling factor.
|
|
26
|
+
|
|
27
|
+
Returns:
|
|
28
|
+
A primed generator object ready to receive a `.send(axis_array)`
|
|
29
|
+
and yields the downsampled data.
|
|
30
|
+
Note that if a send chunk does not have sufficient samples to reach the
|
|
31
|
+
next downsample interval then `None` is yielded.
|
|
32
|
+
|
|
33
|
+
"""
|
|
34
|
+
msg_out = AxisArray(np.array([]), dims=[""])
|
|
35
|
+
|
|
36
|
+
if factor < 1:
|
|
37
|
+
raise ValueError("Downsample factor must be at least 1 (no downsampling)")
|
|
38
|
+
|
|
39
|
+
# state variables
|
|
40
|
+
s_idx: int = 0 # Index of the next msg's first sample into the virtual rotating ds_factor counter.
|
|
41
|
+
|
|
42
|
+
check_input = {"gain": None, "key": None}
|
|
43
|
+
|
|
44
|
+
while True:
|
|
45
|
+
msg_in: AxisArray = yield msg_out
|
|
46
|
+
|
|
47
|
+
if axis is None:
|
|
48
|
+
axis = msg_in.dims[0]
|
|
49
|
+
axis_info = msg_in.get_axis(axis)
|
|
50
|
+
axis_idx = msg_in.get_axis_idx(axis)
|
|
51
|
+
|
|
52
|
+
b_reset = (
|
|
53
|
+
msg_in.axes[axis].gain != check_input["gain"]
|
|
54
|
+
or msg_in.key != check_input["key"]
|
|
55
|
+
)
|
|
56
|
+
if b_reset:
|
|
57
|
+
check_input["gain"] = axis_info.gain
|
|
58
|
+
check_input["key"] = msg_in.key
|
|
59
|
+
# Reset state variables
|
|
60
|
+
s_idx = 0
|
|
61
|
+
|
|
62
|
+
n_samples = msg_in.data.shape[axis_idx]
|
|
63
|
+
samples = np.arange(s_idx, s_idx + n_samples) % factor
|
|
64
|
+
if n_samples > 0:
|
|
65
|
+
# Update state for next iteration.
|
|
66
|
+
s_idx = samples[-1] + 1
|
|
23
67
|
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
68
|
+
pub_samples = np.where(samples == 0)[0]
|
|
69
|
+
if len(pub_samples) > 0:
|
|
70
|
+
n_step = pub_samples[0].item()
|
|
71
|
+
data_slice = pub_samples
|
|
72
|
+
else:
|
|
73
|
+
n_step = 0
|
|
74
|
+
data_slice = slice(None, 0, None)
|
|
75
|
+
msg_out = replace(
|
|
76
|
+
msg_in,
|
|
77
|
+
data=slice_along_axis(msg_in.data, data_slice, axis=axis_idx),
|
|
78
|
+
axes={
|
|
79
|
+
**msg_in.axes,
|
|
80
|
+
axis: replace(
|
|
81
|
+
axis_info,
|
|
82
|
+
gain=axis_info.gain * factor,
|
|
83
|
+
offset=axis_info.offset + axis_info.gain * n_step,
|
|
84
|
+
),
|
|
85
|
+
},
|
|
86
|
+
)
|
|
27
87
|
|
|
28
|
-
INPUT_SETTINGS = ez.InputStream(DownsampleSettings)
|
|
29
|
-
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
30
|
-
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
31
88
|
|
|
32
|
-
|
|
33
|
-
|
|
89
|
+
class DownsampleSettings(ez.Settings):
|
|
90
|
+
"""
|
|
91
|
+
Settings for :obj:`Downsample` node.
|
|
92
|
+
See :obj:`downsample` documentation for a description of the parameters.
|
|
93
|
+
"""
|
|
34
94
|
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
self.STATE.cur_settings = msg
|
|
95
|
+
axis: typing.Optional[str] = None
|
|
96
|
+
factor: int = 1
|
|
38
97
|
|
|
39
|
-
@ez.subscriber(INPUT_SIGNAL, zero_copy=True)
|
|
40
|
-
@ez.publisher(OUTPUT_SIGNAL)
|
|
41
|
-
async def on_signal(self, msg: AxisArray) -> AsyncGenerator:
|
|
42
|
-
if self.STATE.cur_settings.factor < 1:
|
|
43
|
-
raise ValueError("Downsample factor must be at least 1 (no downsampling)")
|
|
44
98
|
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
axis_name = msg.dims[0]
|
|
48
|
-
axis = msg.get_axis(axis_name)
|
|
49
|
-
axis_idx = msg.get_axis_idx(axis_name)
|
|
99
|
+
class Downsample(GenAxisArray):
|
|
100
|
+
""":obj:`Unit` for :obj:`bandpower`."""
|
|
50
101
|
|
|
51
|
-
|
|
52
|
-
samples = samples % self.STATE.cur_settings.factor
|
|
53
|
-
self.STATE.s_idx = samples[-1] + 1
|
|
102
|
+
SETTINGS = DownsampleSettings
|
|
54
103
|
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
new_gain = axis.gain * self.STATE.cur_settings.factor
|
|
60
|
-
new_axes[axis_name] = replace(axis, gain=new_gain, offset=new_offset)
|
|
61
|
-
down_data = np.take(msg.data, pub_samples, axis_idx)
|
|
62
|
-
out_msg = replace(msg, data=down_data, dims=msg.dims, axes=new_axes)
|
|
63
|
-
yield self.OUTPUT_SIGNAL, out_msg
|
|
104
|
+
def construct_generator(self):
|
|
105
|
+
self.STATE.gen = downsample(
|
|
106
|
+
axis=self.SETTINGS.axis, factor=self.SETTINGS.factor
|
|
107
|
+
)
|
ezmsg/sigproc/ewmfilter.py
CHANGED
|
@@ -1,19 +1,20 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
from dataclasses import replace
|
|
3
|
+
import typing
|
|
3
4
|
|
|
4
5
|
import ezmsg.core as ez
|
|
5
6
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
6
|
-
|
|
7
7
|
import numpy as np
|
|
8
8
|
|
|
9
9
|
from .window import Window, WindowSettings
|
|
10
10
|
|
|
11
|
-
from typing import AsyncGenerator, Optional
|
|
12
|
-
|
|
13
11
|
|
|
14
12
|
class EWMSettings(ez.Settings):
|
|
15
|
-
axis: Optional[str] = None
|
|
16
|
-
|
|
13
|
+
axis: typing.Optional[str] = None
|
|
14
|
+
"""Name of the axis to accumulate."""
|
|
15
|
+
|
|
16
|
+
zero_offset: bool = True
|
|
17
|
+
"""If true, we assume zero DC offset for input data."""
|
|
17
18
|
|
|
18
19
|
|
|
19
20
|
class EWMState(ez.State):
|
|
@@ -28,14 +29,14 @@ class EWM(ez.Unit):
|
|
|
28
29
|
References https://stackoverflow.com/a/42926270
|
|
29
30
|
"""
|
|
30
31
|
|
|
31
|
-
SETTINGS
|
|
32
|
-
STATE
|
|
32
|
+
SETTINGS = EWMSettings
|
|
33
|
+
STATE = EWMState
|
|
33
34
|
|
|
34
35
|
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
35
36
|
INPUT_BUFFER = ez.InputStream(AxisArray)
|
|
36
37
|
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
37
38
|
|
|
38
|
-
def initialize(self) -> None:
|
|
39
|
+
async def initialize(self) -> None:
|
|
39
40
|
self.STATE.signal_queue = asyncio.Queue()
|
|
40
41
|
self.STATE.buffer_queue = asyncio.Queue()
|
|
41
42
|
|
|
@@ -48,7 +49,7 @@ class EWM(ez.Unit):
|
|
|
48
49
|
self.STATE.buffer_queue.put_nowait(message)
|
|
49
50
|
|
|
50
51
|
@ez.publisher(OUTPUT_SIGNAL)
|
|
51
|
-
async def sync_output(self) -> AsyncGenerator:
|
|
52
|
+
async def sync_output(self) -> typing.AsyncGenerator:
|
|
52
53
|
while True:
|
|
53
54
|
signal = await self.STATE.signal_queue.get()
|
|
54
55
|
buffer = await self.STATE.buffer_queue.get() # includes signal
|
|
@@ -73,9 +74,12 @@ class EWM(ez.Unit):
|
|
|
73
74
|
buffer_data = buffer.data
|
|
74
75
|
buffer_data = np.moveaxis(buffer_data, axis_idx, 0)
|
|
75
76
|
|
|
77
|
+
while scale_arr.ndim < buffer_data.ndim:
|
|
78
|
+
scale_arr = scale_arr[..., None]
|
|
79
|
+
|
|
76
80
|
def ewma(data: np.ndarray) -> np.ndarray:
|
|
77
|
-
mult = scale_arr
|
|
78
|
-
out = scale_arr[::-1
|
|
81
|
+
mult = scale_arr * data * pw0
|
|
82
|
+
out = scale_arr[::-1] * mult.cumsum(axis=0)
|
|
79
83
|
|
|
80
84
|
if not self.SETTINGS.zero_offset:
|
|
81
85
|
out = (data[0, :, np.newaxis] * pows[1:]).T + out
|
|
@@ -93,13 +97,26 @@ class EWM(ez.Unit):
|
|
|
93
97
|
|
|
94
98
|
|
|
95
99
|
class EWMFilterSettings(ez.Settings):
|
|
96
|
-
history_dur: float
|
|
97
|
-
|
|
98
|
-
|
|
100
|
+
history_dur: float
|
|
101
|
+
"""Previous data to accumulate for standardization."""
|
|
102
|
+
|
|
103
|
+
axis: typing.Optional[str] = None
|
|
104
|
+
"""Name of the axis to accumulate."""
|
|
105
|
+
|
|
106
|
+
zero_offset: bool = True
|
|
107
|
+
"""If true, we assume zero DC offset for input data."""
|
|
99
108
|
|
|
100
109
|
|
|
101
110
|
class EWMFilter(ez.Collection):
|
|
102
|
-
|
|
111
|
+
"""
|
|
112
|
+
A :obj:`Collection` that splits the input into a branch that
|
|
113
|
+
leads to :obj:`Window` which then feeds into :obj:`EWM` 's INPUT_BUFFER
|
|
114
|
+
and another branch that feeds directly into :obj:`EWM` 's INPUT_SIGNAL.
|
|
115
|
+
|
|
116
|
+
Consider :obj:`scaler` for a more efficient alternative.
|
|
117
|
+
"""
|
|
118
|
+
|
|
119
|
+
SETTINGS = EWMFilterSettings
|
|
103
120
|
|
|
104
121
|
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
105
122
|
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
@@ -108,7 +125,12 @@ class EWMFilter(ez.Collection):
|
|
|
108
125
|
EWM = EWM()
|
|
109
126
|
|
|
110
127
|
def configure(self) -> None:
|
|
111
|
-
self.EWM.apply_settings(
|
|
128
|
+
self.EWM.apply_settings(
|
|
129
|
+
EWMSettings(
|
|
130
|
+
axis=self.SETTINGS.axis,
|
|
131
|
+
zero_offset=self.SETTINGS.zero_offset,
|
|
132
|
+
)
|
|
133
|
+
)
|
|
112
134
|
|
|
113
135
|
self.WINDOW.apply_settings(
|
|
114
136
|
WindowSettings(
|
ezmsg/sigproc/filter.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
|
1
|
+
import asyncio
|
|
1
2
|
from dataclasses import dataclass, replace, field
|
|
3
|
+
import typing
|
|
2
4
|
|
|
3
5
|
import ezmsg.core as ez
|
|
4
|
-
import scipy.signal
|
|
5
|
-
import numpy as np
|
|
6
|
-
import asyncio
|
|
7
|
-
|
|
8
6
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
9
|
-
|
|
10
|
-
|
|
7
|
+
from ezmsg.util.generator import consumer
|
|
8
|
+
import numpy as np
|
|
9
|
+
import numpy.typing as npt
|
|
10
|
+
import scipy.signal
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
@dataclass
|
|
@@ -16,39 +16,124 @@ class FilterCoefficients:
|
|
|
16
16
|
a: np.ndarray = field(default_factory=lambda: np.array([1.0, 0.0]))
|
|
17
17
|
|
|
18
18
|
|
|
19
|
+
def _normalize_coefs(
|
|
20
|
+
coefs: typing.Union[
|
|
21
|
+
FilterCoefficients, typing.Tuple[npt.NDArray, npt.NDArray], npt.NDArray
|
|
22
|
+
],
|
|
23
|
+
) -> typing.Tuple[str, typing.Tuple[npt.NDArray, ...]]:
|
|
24
|
+
coef_type = "ba"
|
|
25
|
+
if coefs is not None:
|
|
26
|
+
# scipy.signal functions called with first arg `*coefs`.
|
|
27
|
+
# Make sure we have a tuple of coefficients.
|
|
28
|
+
if isinstance(coefs, npt.NDArray):
|
|
29
|
+
coef_type = "sos"
|
|
30
|
+
coefs = (coefs,) # sos funcs just want a single ndarray.
|
|
31
|
+
elif isinstance(coefs, FilterCoefficients):
|
|
32
|
+
coefs = (FilterCoefficients.b, FilterCoefficients.a)
|
|
33
|
+
return coef_type, coefs
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@consumer
|
|
37
|
+
def filtergen(
|
|
38
|
+
axis: str, coefs: typing.Optional[typing.Tuple[np.ndarray]], coef_type: str
|
|
39
|
+
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
40
|
+
"""
|
|
41
|
+
Construct a generic filter generator function.
|
|
42
|
+
|
|
43
|
+
Args:
|
|
44
|
+
axis: The name of the axis to operate on.
|
|
45
|
+
coefs: The pre-calculated filter coefficients.
|
|
46
|
+
coef_type: The type of filter coefficients. One of "ba" or "sos".
|
|
47
|
+
|
|
48
|
+
Returns:
|
|
49
|
+
A generator that expects .send(axis_array) and yields the filtered :obj:`AxisArray`.
|
|
50
|
+
"""
|
|
51
|
+
# Massage inputs
|
|
52
|
+
if coefs is not None and not isinstance(coefs, tuple):
|
|
53
|
+
# scipy.signal functions called with first arg `*coefs`, but sos coefs are a single ndarray.
|
|
54
|
+
coefs = (coefs,)
|
|
55
|
+
|
|
56
|
+
# Init IO
|
|
57
|
+
msg_out = AxisArray(np.array([]), dims=[""])
|
|
58
|
+
|
|
59
|
+
filt_func = {"ba": scipy.signal.lfilter, "sos": scipy.signal.sosfilt}[coef_type]
|
|
60
|
+
zi_func = {"ba": scipy.signal.lfilter_zi, "sos": scipy.signal.sosfilt_zi}[coef_type]
|
|
61
|
+
|
|
62
|
+
# State variables
|
|
63
|
+
zi: typing.Optional[npt.NDArray] = None
|
|
64
|
+
|
|
65
|
+
# Reset if these change.
|
|
66
|
+
check_input = {"key": None, "shape": None}
|
|
67
|
+
# fs changing will be handled by caller that creates coefficients.
|
|
68
|
+
|
|
69
|
+
while True:
|
|
70
|
+
msg_in: AxisArray = yield msg_out
|
|
71
|
+
|
|
72
|
+
if coefs is None:
|
|
73
|
+
# passthrough if we do not have a filter design.
|
|
74
|
+
msg_out = msg_in
|
|
75
|
+
continue
|
|
76
|
+
|
|
77
|
+
axis = msg_in.dims[0] if axis is None else axis
|
|
78
|
+
axis_idx = msg_in.get_axis_idx(axis)
|
|
79
|
+
|
|
80
|
+
# Re-calculate/reset zi if necessary
|
|
81
|
+
samp_shape = msg_in.data.shape[:axis_idx] + msg_in.data.shape[axis_idx + 1 :]
|
|
82
|
+
b_reset = samp_shape != check_input["shape"]
|
|
83
|
+
b_reset = b_reset or msg_in.key != check_input["key"]
|
|
84
|
+
if b_reset:
|
|
85
|
+
check_input["shape"] = samp_shape
|
|
86
|
+
check_input["key"] = msg_in.key
|
|
87
|
+
|
|
88
|
+
n_tail = msg_in.data.ndim - axis_idx - 1
|
|
89
|
+
zi = zi_func(*coefs)
|
|
90
|
+
zi_expand = (None,) * axis_idx + (slice(None),) + (None,) * n_tail
|
|
91
|
+
n_tile = (
|
|
92
|
+
msg_in.data.shape[:axis_idx] + (1,) + msg_in.data.shape[axis_idx + 1 :]
|
|
93
|
+
)
|
|
94
|
+
if coef_type == "sos":
|
|
95
|
+
# sos zi must keep its leading dimension (`order / 2` for low|high; `order` for bpass|bstop)
|
|
96
|
+
zi_expand = (slice(None),) + zi_expand
|
|
97
|
+
n_tile = (1,) + n_tile
|
|
98
|
+
zi = np.tile(zi[zi_expand], n_tile)
|
|
99
|
+
|
|
100
|
+
dat_out, zi = filt_func(*coefs, msg_in.data, axis=axis_idx, zi=zi)
|
|
101
|
+
msg_out = replace(msg_in, data=dat_out)
|
|
102
|
+
|
|
103
|
+
|
|
19
104
|
class FilterSettingsBase(ez.Settings):
|
|
20
|
-
axis: Optional[str] = None
|
|
21
|
-
fs: Optional[float] = None
|
|
105
|
+
axis: typing.Optional[str] = None
|
|
106
|
+
fs: typing.Optional[float] = None
|
|
22
107
|
|
|
23
108
|
|
|
24
109
|
class FilterSettings(FilterSettingsBase):
|
|
25
110
|
# If you'd like to statically design a filter, define it in settings
|
|
26
|
-
filt: Optional[FilterCoefficients] = None
|
|
111
|
+
filt: typing.Optional[FilterCoefficients] = None
|
|
27
112
|
|
|
28
113
|
|
|
29
114
|
class FilterState(ez.State):
|
|
30
|
-
axis: Optional[str] = None
|
|
31
|
-
zi: Optional[np.ndarray] = None
|
|
115
|
+
axis: typing.Optional[str] = None
|
|
116
|
+
zi: typing.Optional[np.ndarray] = None
|
|
32
117
|
filt_designed: bool = False
|
|
33
|
-
filt: Optional[FilterCoefficients] = None
|
|
118
|
+
filt: typing.Optional[FilterCoefficients] = None
|
|
34
119
|
filt_set: asyncio.Event = field(default_factory=asyncio.Event)
|
|
35
|
-
samp_shape: Optional[Tuple[int, ...]] = None
|
|
36
|
-
fs: Optional[float] = None # Hz
|
|
120
|
+
samp_shape: typing.Optional[typing.Tuple[int, ...]] = None
|
|
121
|
+
fs: typing.Optional[float] = None # Hz
|
|
37
122
|
|
|
38
123
|
|
|
39
124
|
class Filter(ez.Unit):
|
|
40
|
-
SETTINGS
|
|
41
|
-
STATE
|
|
125
|
+
SETTINGS = FilterSettingsBase
|
|
126
|
+
STATE = FilterState
|
|
42
127
|
|
|
43
128
|
INPUT_FILTER = ez.InputStream(FilterCoefficients)
|
|
44
129
|
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
45
130
|
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
46
131
|
|
|
47
|
-
def design_filter(self) -> Optional[Tuple[np.ndarray, np.ndarray]]:
|
|
132
|
+
def design_filter(self) -> typing.Optional[typing.Tuple[np.ndarray, np.ndarray]]:
|
|
48
133
|
raise NotImplementedError("Must implement 'design_filter' in Unit subclass!")
|
|
49
134
|
|
|
50
135
|
# Set up filter with static initialization if specified
|
|
51
|
-
def initialize(self) -> None:
|
|
136
|
+
async def initialize(self) -> None:
|
|
52
137
|
if self.SETTINGS.axis is not None:
|
|
53
138
|
self.STATE.axis = self.SETTINGS.axis
|
|
54
139
|
|
|
@@ -84,7 +169,7 @@ class Filter(ez.Unit):
|
|
|
84
169
|
|
|
85
170
|
@ez.subscriber(INPUT_SIGNAL)
|
|
86
171
|
@ez.publisher(OUTPUT_SIGNAL)
|
|
87
|
-
async def apply_filter(self, msg: AxisArray) -> AsyncGenerator:
|
|
172
|
+
async def apply_filter(self, msg: AxisArray) -> typing.AsyncGenerator:
|
|
88
173
|
axis_name = msg.dims[0] if self.STATE.axis is None else self.STATE.axis
|
|
89
174
|
axis_idx = msg.get_axis_idx(axis_name)
|
|
90
175
|
axis = msg.get_axis(axis_name)
|
|
@@ -137,4 +222,7 @@ class Filter(ez.Unit):
|
|
|
137
222
|
if one_dimensional:
|
|
138
223
|
arr_out = np.squeeze(arr_out, axis=1)
|
|
139
224
|
|
|
140
|
-
yield
|
|
225
|
+
yield (
|
|
226
|
+
self.OUTPUT_SIGNAL,
|
|
227
|
+
replace(msg, data=arr_out),
|
|
228
|
+
)
|