ezmsg-sigproc 1.2.3__py3-none-any.whl → 1.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__init__.py +1 -4
- ezmsg/sigproc/__version__.py +16 -0
- ezmsg/sigproc/activation.py +75 -0
- ezmsg/sigproc/affinetransform.py +149 -39
- ezmsg/sigproc/aggregate.py +84 -29
- ezmsg/sigproc/bandpower.py +36 -15
- ezmsg/sigproc/base.py +38 -0
- ezmsg/sigproc/butterworthfilter.py +76 -20
- ezmsg/sigproc/decimate.py +7 -4
- ezmsg/sigproc/downsample.py +79 -61
- ezmsg/sigproc/ewmfilter.py +28 -14
- ezmsg/sigproc/filter.py +51 -31
- ezmsg/sigproc/filterbank.py +278 -0
- ezmsg/sigproc/math/__init__.py +0 -0
- ezmsg/sigproc/math/abs.py +28 -0
- ezmsg/sigproc/math/clip.py +30 -0
- ezmsg/sigproc/math/difference.py +60 -0
- ezmsg/sigproc/math/invert.py +29 -0
- ezmsg/sigproc/math/log.py +32 -0
- ezmsg/sigproc/math/scale.py +31 -0
- ezmsg/sigproc/messages.py +2 -3
- ezmsg/sigproc/sampler.py +152 -90
- ezmsg/sigproc/scaler.py +88 -42
- ezmsg/sigproc/signalinjector.py +7 -10
- ezmsg/sigproc/slicer.py +71 -36
- ezmsg/sigproc/spectral.py +6 -9
- ezmsg/sigproc/spectrogram.py +48 -30
- ezmsg/sigproc/spectrum.py +177 -76
- ezmsg/sigproc/synth.py +162 -67
- ezmsg/sigproc/wavelets.py +167 -0
- ezmsg/sigproc/window.py +193 -157
- ezmsg_sigproc-1.3.2.dist-info/METADATA +59 -0
- ezmsg_sigproc-1.3.2.dist-info/RECORD +35 -0
- {ezmsg_sigproc-1.2.3.dist-info → ezmsg_sigproc-1.3.2.dist-info}/WHEEL +1 -1
- ezmsg_sigproc-1.2.3.dist-info/METADATA +0 -38
- ezmsg_sigproc-1.2.3.dist-info/RECORD +0 -23
- {ezmsg_sigproc-1.2.3.dist-info → ezmsg_sigproc-1.3.2.dist-info/licenses}/LICENSE.txt +0 -0
|
@@ -3,19 +3,45 @@ import typing
|
|
|
3
3
|
import ezmsg.core as ez
|
|
4
4
|
import scipy.signal
|
|
5
5
|
import numpy as np
|
|
6
|
-
|
|
7
|
-
from .filter import filtergen, Filter, FilterState, FilterSettingsBase
|
|
8
|
-
|
|
9
6
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
10
7
|
from ezmsg.util.generator import consumer
|
|
11
8
|
|
|
9
|
+
from .filter import filtergen, Filter, FilterState, FilterSettingsBase
|
|
10
|
+
|
|
12
11
|
|
|
13
12
|
class ButterworthFilterSettings(FilterSettingsBase):
|
|
13
|
+
"""Settings for :obj:`ButterworthFilter`."""
|
|
14
|
+
|
|
14
15
|
order: int = 0
|
|
15
|
-
cuton: typing.Optional[float] = None # Hz
|
|
16
|
-
cutoff: typing.Optional[float] = None # Hz
|
|
17
16
|
|
|
18
|
-
|
|
17
|
+
cuton: typing.Optional[float] = None
|
|
18
|
+
"""
|
|
19
|
+
Cuton frequency (Hz). If cutoff is not specified then this is the highpass corner, otherwise
|
|
20
|
+
if it is lower than cutoff then this is the beginning of the bandpass
|
|
21
|
+
or if it is greater than cuton then it is the end of the bandstop.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
cutoff: typing.Optional[float] = None
|
|
25
|
+
"""
|
|
26
|
+
Cutoff frequency (Hz). If cuton is not specified then this is the lowpass corner, otherwise
|
|
27
|
+
if it is greater than cuton then this is the end of the bandpass,
|
|
28
|
+
or if it is less than cuton then it is the beginning of the bandstop.
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
def filter_specs(
|
|
32
|
+
self,
|
|
33
|
+
) -> typing.Optional[
|
|
34
|
+
typing.Tuple[str, typing.Union[float, typing.Tuple[float, float]]]
|
|
35
|
+
]:
|
|
36
|
+
"""
|
|
37
|
+
Determine the filter type given the corner frequencies.
|
|
38
|
+
|
|
39
|
+
Returns:
|
|
40
|
+
A tuple with the first element being a string indicating the filter type
|
|
41
|
+
(one of "lowpass", "highpass", "bandpass", "bandstop")
|
|
42
|
+
and the second element being the corner frequency or frequencies.
|
|
43
|
+
|
|
44
|
+
"""
|
|
19
45
|
if self.cuton is None and self.cutoff is None:
|
|
20
46
|
return None
|
|
21
47
|
elif self.cuton is None and self.cutoff is not None:
|
|
@@ -37,28 +63,56 @@ def butter(
|
|
|
37
63
|
cutoff: typing.Optional[float] = None,
|
|
38
64
|
coef_type: str = "ba",
|
|
39
65
|
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
66
|
+
"""
|
|
67
|
+
Apply Butterworth filter to streaming data. Uses :obj:`scipy.signal.butter` to design the filter.
|
|
68
|
+
See :obj:`ButterworthFilterSettings.filter_specs` for an explanation of specifying different
|
|
69
|
+
filter types (lowpass, highpass, bandpass, bandstop) from the parameters.
|
|
70
|
+
|
|
71
|
+
Args:
|
|
72
|
+
axis: The name of the axis to filter.
|
|
73
|
+
order: Filter order.
|
|
74
|
+
cuton: Corner frequency of the filter in Hz.
|
|
75
|
+
cutoff: Corner frequency of the filter in Hz.
|
|
76
|
+
coef_type: "ba" or "sos"
|
|
77
|
+
|
|
78
|
+
Returns:
|
|
79
|
+
A primed generator object which accepts .send(axis_array) and yields filtered axis array.
|
|
80
|
+
|
|
81
|
+
"""
|
|
40
82
|
# IO
|
|
41
|
-
|
|
42
|
-
axis_arr_out = AxisArray(np.array([]), dims=[""])
|
|
83
|
+
msg_out = AxisArray(np.array([]), dims=[""])
|
|
43
84
|
|
|
85
|
+
# Check parameters
|
|
44
86
|
btype, cutoffs = ButterworthFilterSettings(
|
|
45
87
|
order=order, cuton=cuton, cutoff=cutoff
|
|
46
88
|
).filter_specs()
|
|
47
89
|
|
|
48
|
-
#
|
|
49
|
-
|
|
50
|
-
filter_gen = filtergen(axis,
|
|
90
|
+
# State variables
|
|
91
|
+
# Initialize filtergen as passthrough until we can calculate coefs.
|
|
92
|
+
filter_gen = filtergen(axis, None, coef_type)
|
|
93
|
+
|
|
94
|
+
# Reset if these change.
|
|
95
|
+
check_input = {"gain": None}
|
|
96
|
+
# Key not checked because filter_gen will handle resetting if .key changes.
|
|
51
97
|
|
|
52
98
|
while True:
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
99
|
+
msg_in: AxisArray = yield msg_out
|
|
100
|
+
axis = axis or msg_in.dims[0]
|
|
101
|
+
|
|
102
|
+
b_reset = msg_in.axes[axis].gain != check_input["gain"]
|
|
103
|
+
b_reset = b_reset and order > 0 # Not passthrough
|
|
104
|
+
if b_reset:
|
|
105
|
+
check_input["gain"] = msg_in.axes[axis].gain
|
|
56
106
|
coefs = scipy.signal.butter(
|
|
57
|
-
order,
|
|
107
|
+
order,
|
|
108
|
+
Wn=cutoffs,
|
|
109
|
+
btype=btype,
|
|
110
|
+
fs=1 / msg_in.axes[axis].gain,
|
|
111
|
+
output=coef_type,
|
|
58
112
|
)
|
|
59
113
|
filter_gen = filtergen(axis, coefs, coef_type)
|
|
60
114
|
|
|
61
|
-
|
|
115
|
+
msg_out = filter_gen.send(msg_in)
|
|
62
116
|
|
|
63
117
|
|
|
64
118
|
class ButterworthFilterState(FilterState):
|
|
@@ -66,15 +120,17 @@ class ButterworthFilterState(FilterState):
|
|
|
66
120
|
|
|
67
121
|
|
|
68
122
|
class ButterworthFilter(Filter):
|
|
69
|
-
|
|
70
|
-
|
|
123
|
+
""":obj:`Unit` for :obj:`butterworth`"""
|
|
124
|
+
|
|
125
|
+
SETTINGS = ButterworthFilterSettings
|
|
126
|
+
STATE = ButterworthFilterState
|
|
71
127
|
|
|
72
128
|
INPUT_FILTER = ez.InputStream(ButterworthFilterSettings)
|
|
73
129
|
|
|
74
|
-
def initialize(self) -> None:
|
|
130
|
+
async def initialize(self) -> None:
|
|
75
131
|
self.STATE.design = self.SETTINGS
|
|
76
132
|
self.STATE.filt_designed = True
|
|
77
|
-
super().initialize()
|
|
133
|
+
await super().initialize()
|
|
78
134
|
|
|
79
135
|
def design_filter(self) -> typing.Optional[typing.Tuple[np.ndarray, np.ndarray]]:
|
|
80
136
|
specs = self.STATE.design.filter_specs()
|
ezmsg/sigproc/decimate.py
CHANGED
|
@@ -1,7 +1,5 @@
|
|
|
1
|
-
import ezmsg.core as ez
|
|
2
|
-
|
|
3
1
|
import scipy.signal
|
|
4
|
-
|
|
2
|
+
import ezmsg.core as ez
|
|
5
3
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
6
4
|
|
|
7
5
|
from .downsample import Downsample, DownsampleSettings
|
|
@@ -9,7 +7,12 @@ from .filter import Filter, FilterCoefficients, FilterSettings
|
|
|
9
7
|
|
|
10
8
|
|
|
11
9
|
class Decimate(ez.Collection):
|
|
12
|
-
|
|
10
|
+
"""
|
|
11
|
+
A :obj:`Collection` chaining a :obj:`Filter` node configured as a lowpass Chebyshev filter
|
|
12
|
+
and a :obj:`Downsample` node.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
SETTINGS = DownsampleSettings
|
|
13
16
|
|
|
14
17
|
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
15
18
|
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
ezmsg/sigproc/downsample.py
CHANGED
|
@@ -1,89 +1,107 @@
|
|
|
1
1
|
from dataclasses import replace
|
|
2
|
-
import
|
|
3
|
-
from typing import AsyncGenerator, Optional, Generator
|
|
2
|
+
import typing
|
|
4
3
|
|
|
5
4
|
import numpy as np
|
|
6
|
-
|
|
7
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
5
|
+
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
|
|
8
6
|
from ezmsg.util.generator import consumer
|
|
9
7
|
import ezmsg.core as ez
|
|
10
8
|
|
|
9
|
+
from .base import GenAxisArray
|
|
10
|
+
|
|
11
11
|
|
|
12
12
|
@consumer
|
|
13
13
|
def downsample(
|
|
14
|
-
|
|
15
|
-
) -> Generator[AxisArray, AxisArray, None]:
|
|
16
|
-
|
|
17
|
-
|
|
14
|
+
axis: typing.Optional[str] = None, factor: int = 1
|
|
15
|
+
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
16
|
+
"""
|
|
17
|
+
Construct a generator that yields a downsampled version of the data .send() to it.
|
|
18
|
+
Downsampled data simply comprise every `factor`th sample.
|
|
19
|
+
This should only be used following appropriate lowpass filtering.
|
|
20
|
+
If your pipeline does not already have lowpass filtering then consider
|
|
21
|
+
using the :obj:`Decimate` collection instead.
|
|
22
|
+
|
|
23
|
+
Args:
|
|
24
|
+
axis: The name of the axis along which to downsample.
|
|
25
|
+
factor: Downsampling factor.
|
|
26
|
+
|
|
27
|
+
Returns:
|
|
28
|
+
A primed generator object ready to receive a `.send(axis_array)`
|
|
29
|
+
and yields the downsampled data.
|
|
30
|
+
Note that if a send chunk does not have sufficient samples to reach the
|
|
31
|
+
next downsample interval then `None` is yielded.
|
|
32
|
+
|
|
33
|
+
"""
|
|
34
|
+
msg_out = AxisArray(np.array([]), dims=[""])
|
|
35
|
+
|
|
36
|
+
if factor < 1:
|
|
37
|
+
raise ValueError("Downsample factor must be at least 1 (no downsampling)")
|
|
18
38
|
|
|
19
39
|
# state variables
|
|
20
|
-
s_idx = 0
|
|
40
|
+
s_idx: int = 0 # Index of the next msg's first sample into the virtual rotating ds_factor counter.
|
|
41
|
+
|
|
42
|
+
check_input = {"gain": None, "key": None}
|
|
21
43
|
|
|
22
44
|
while True:
|
|
23
|
-
|
|
45
|
+
msg_in: AxisArray = yield msg_out
|
|
24
46
|
|
|
25
47
|
if axis is None:
|
|
26
|
-
axis =
|
|
27
|
-
axis_info =
|
|
28
|
-
axis_idx =
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
48
|
+
axis = msg_in.dims[0]
|
|
49
|
+
axis_info = msg_in.get_axis(axis)
|
|
50
|
+
axis_idx = msg_in.get_axis_idx(axis)
|
|
51
|
+
|
|
52
|
+
b_reset = (
|
|
53
|
+
msg_in.axes[axis].gain != check_input["gain"]
|
|
54
|
+
or msg_in.key != check_input["key"]
|
|
55
|
+
)
|
|
56
|
+
if b_reset:
|
|
57
|
+
check_input["gain"] = axis_info.gain
|
|
58
|
+
check_input["key"] = msg_in.key
|
|
59
|
+
# Reset state variables
|
|
60
|
+
s_idx = 0
|
|
61
|
+
|
|
62
|
+
n_samples = msg_in.data.shape[axis_idx]
|
|
63
|
+
samples = np.arange(s_idx, s_idx + n_samples) % factor
|
|
64
|
+
if n_samples > 0:
|
|
65
|
+
# Update state for next iteration.
|
|
66
|
+
s_idx = samples[-1] + 1
|
|
33
67
|
|
|
34
68
|
pub_samples = np.where(samples == 0)[0]
|
|
35
69
|
if len(pub_samples) > 0:
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
new_gain = axis_info.gain * factor
|
|
39
|
-
new_axes[axis] = replace(axis_info, gain=new_gain, offset=new_offset)
|
|
40
|
-
down_data = np.take(axis_arr_in.data, pub_samples, axis=axis_idx)
|
|
41
|
-
axis_arr_out = replace(axis_arr_in, data=down_data, dims=axis_arr_in.dims, axes=new_axes)
|
|
70
|
+
n_step = pub_samples[0].item()
|
|
71
|
+
data_slice = pub_samples
|
|
42
72
|
else:
|
|
43
|
-
|
|
73
|
+
n_step = 0
|
|
74
|
+
data_slice = slice(None, 0, None)
|
|
75
|
+
msg_out = replace(
|
|
76
|
+
msg_in,
|
|
77
|
+
data=slice_along_axis(msg_in.data, data_slice, axis=axis_idx),
|
|
78
|
+
axes={
|
|
79
|
+
**msg_in.axes,
|
|
80
|
+
axis: replace(
|
|
81
|
+
axis_info,
|
|
82
|
+
gain=axis_info.gain * factor,
|
|
83
|
+
offset=axis_info.offset + axis_info.gain * n_step,
|
|
84
|
+
),
|
|
85
|
+
},
|
|
86
|
+
)
|
|
44
87
|
|
|
45
88
|
|
|
46
89
|
class DownsampleSettings(ez.Settings):
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
90
|
+
"""
|
|
91
|
+
Settings for :obj:`Downsample` node.
|
|
92
|
+
See :obj:`downsample` documentation for a description of the parameters.
|
|
93
|
+
"""
|
|
50
94
|
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
gen: Generator
|
|
95
|
+
axis: typing.Optional[str] = None
|
|
96
|
+
factor: int = 1
|
|
54
97
|
|
|
55
98
|
|
|
56
|
-
class Downsample(
|
|
57
|
-
|
|
58
|
-
STATE: DownsampleState
|
|
99
|
+
class Downsample(GenAxisArray):
|
|
100
|
+
""":obj:`Unit` for :obj:`bandpower`."""
|
|
59
101
|
|
|
60
|
-
|
|
61
|
-
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
62
|
-
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
102
|
+
SETTINGS = DownsampleSettings
|
|
63
103
|
|
|
64
104
|
def construct_generator(self):
|
|
65
|
-
self.STATE.gen = downsample(
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
self.STATE.cur_settings = self.SETTINGS
|
|
69
|
-
self.construct_generator()
|
|
70
|
-
|
|
71
|
-
@ez.subscriber(INPUT_SETTINGS)
|
|
72
|
-
async def on_settings(self, msg: DownsampleSettings) -> None:
|
|
73
|
-
self.STATE.cur_settings = msg
|
|
74
|
-
self.construct_generator()
|
|
75
|
-
|
|
76
|
-
@ez.subscriber(INPUT_SIGNAL, zero_copy=True)
|
|
77
|
-
@ez.publisher(OUTPUT_SIGNAL)
|
|
78
|
-
async def on_signal(self, msg: AxisArray) -> AsyncGenerator:
|
|
79
|
-
if self.STATE.cur_settings.factor < 1:
|
|
80
|
-
raise ValueError("Downsample factor must be at least 1 (no downsampling)")
|
|
81
|
-
|
|
82
|
-
try:
|
|
83
|
-
out_msg = self.STATE.gen.send(msg)
|
|
84
|
-
if out_msg is not None:
|
|
85
|
-
yield self.OUTPUT_SIGNAL, out_msg
|
|
86
|
-
except (StopIteration, GeneratorExit):
|
|
87
|
-
ez.logger.debug(f"Downsample closed in {self.address}")
|
|
88
|
-
except Exception:
|
|
89
|
-
ez.logger.info(traceback.format_exc())
|
|
105
|
+
self.STATE.gen = downsample(
|
|
106
|
+
axis=self.SETTINGS.axis, factor=self.SETTINGS.factor
|
|
107
|
+
)
|
ezmsg/sigproc/ewmfilter.py
CHANGED
|
@@ -1,19 +1,20 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
from dataclasses import replace
|
|
3
|
+
import typing
|
|
3
4
|
|
|
4
5
|
import ezmsg.core as ez
|
|
5
6
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
6
|
-
|
|
7
7
|
import numpy as np
|
|
8
8
|
|
|
9
9
|
from .window import Window, WindowSettings
|
|
10
10
|
|
|
11
|
-
from typing import AsyncGenerator, Optional
|
|
12
|
-
|
|
13
11
|
|
|
14
12
|
class EWMSettings(ez.Settings):
|
|
15
|
-
axis: Optional[str] = None
|
|
16
|
-
|
|
13
|
+
axis: typing.Optional[str] = None
|
|
14
|
+
"""Name of the axis to accumulate."""
|
|
15
|
+
|
|
16
|
+
zero_offset: bool = True
|
|
17
|
+
"""If true, we assume zero DC offset for input data."""
|
|
17
18
|
|
|
18
19
|
|
|
19
20
|
class EWMState(ez.State):
|
|
@@ -28,14 +29,14 @@ class EWM(ez.Unit):
|
|
|
28
29
|
References https://stackoverflow.com/a/42926270
|
|
29
30
|
"""
|
|
30
31
|
|
|
31
|
-
SETTINGS
|
|
32
|
-
STATE
|
|
32
|
+
SETTINGS = EWMSettings
|
|
33
|
+
STATE = EWMState
|
|
33
34
|
|
|
34
35
|
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
35
36
|
INPUT_BUFFER = ez.InputStream(AxisArray)
|
|
36
37
|
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
37
38
|
|
|
38
|
-
def initialize(self) -> None:
|
|
39
|
+
async def initialize(self) -> None:
|
|
39
40
|
self.STATE.signal_queue = asyncio.Queue()
|
|
40
41
|
self.STATE.buffer_queue = asyncio.Queue()
|
|
41
42
|
|
|
@@ -48,7 +49,7 @@ class EWM(ez.Unit):
|
|
|
48
49
|
self.STATE.buffer_queue.put_nowait(message)
|
|
49
50
|
|
|
50
51
|
@ez.publisher(OUTPUT_SIGNAL)
|
|
51
|
-
async def sync_output(self) -> AsyncGenerator:
|
|
52
|
+
async def sync_output(self) -> typing.AsyncGenerator:
|
|
52
53
|
while True:
|
|
53
54
|
signal = await self.STATE.signal_queue.get()
|
|
54
55
|
buffer = await self.STATE.buffer_queue.get() # includes signal
|
|
@@ -96,13 +97,26 @@ class EWM(ez.Unit):
|
|
|
96
97
|
|
|
97
98
|
|
|
98
99
|
class EWMFilterSettings(ez.Settings):
|
|
99
|
-
history_dur: float
|
|
100
|
-
|
|
101
|
-
|
|
100
|
+
history_dur: float
|
|
101
|
+
"""Previous data to accumulate for standardization."""
|
|
102
|
+
|
|
103
|
+
axis: typing.Optional[str] = None
|
|
104
|
+
"""Name of the axis to accumulate."""
|
|
105
|
+
|
|
106
|
+
zero_offset: bool = True
|
|
107
|
+
"""If true, we assume zero DC offset for input data."""
|
|
102
108
|
|
|
103
109
|
|
|
104
110
|
class EWMFilter(ez.Collection):
|
|
105
|
-
|
|
111
|
+
"""
|
|
112
|
+
A :obj:`Collection` that splits the input into a branch that
|
|
113
|
+
leads to :obj:`Window` which then feeds into :obj:`EWM` 's INPUT_BUFFER
|
|
114
|
+
and another branch that feeds directly into :obj:`EWM` 's INPUT_SIGNAL.
|
|
115
|
+
|
|
116
|
+
Consider :obj:`scaler` for a more efficient alternative.
|
|
117
|
+
"""
|
|
118
|
+
|
|
119
|
+
SETTINGS = EWMFilterSettings
|
|
106
120
|
|
|
107
121
|
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
108
122
|
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
@@ -113,7 +127,7 @@ class EWMFilter(ez.Collection):
|
|
|
113
127
|
def configure(self) -> None:
|
|
114
128
|
self.EWM.apply_settings(
|
|
115
129
|
EWMSettings(
|
|
116
|
-
axis=self.SETTINGS.axis,
|
|
130
|
+
axis=self.SETTINGS.axis,
|
|
117
131
|
zero_offset=self.SETTINGS.zero_offset,
|
|
118
132
|
)
|
|
119
133
|
)
|
ezmsg/sigproc/filter.py
CHANGED
|
@@ -1,25 +1,26 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
-
import typing
|
|
3
|
-
|
|
4
2
|
from dataclasses import dataclass, replace, field
|
|
3
|
+
import typing
|
|
5
4
|
|
|
6
5
|
import ezmsg.core as ez
|
|
7
|
-
import
|
|
8
|
-
|
|
6
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
|
+
from ezmsg.util.generator import consumer
|
|
9
8
|
import numpy as np
|
|
10
9
|
import numpy.typing as npt
|
|
10
|
+
import scipy.signal
|
|
11
11
|
|
|
12
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
13
|
-
from ezmsg.util.generator import consumer
|
|
14
12
|
|
|
15
13
|
@dataclass
|
|
16
14
|
class FilterCoefficients:
|
|
17
15
|
b: np.ndarray = field(default_factory=lambda: np.array([1.0, 0.0]))
|
|
18
16
|
a: np.ndarray = field(default_factory=lambda: np.array([1.0, 0.0]))
|
|
19
17
|
|
|
18
|
+
|
|
20
19
|
def _normalize_coefs(
|
|
21
|
-
|
|
22
|
-
|
|
20
|
+
coefs: typing.Union[
|
|
21
|
+
FilterCoefficients, typing.Tuple[npt.NDArray, npt.NDArray], npt.NDArray
|
|
22
|
+
],
|
|
23
|
+
) -> typing.Tuple[str, typing.Tuple[npt.NDArray, ...]]:
|
|
23
24
|
coef_type = "ba"
|
|
24
25
|
if coefs is not None:
|
|
25
26
|
# scipy.signal functions called with first arg `*coefs`.
|
|
@@ -31,57 +32,73 @@ def _normalize_coefs(
|
|
|
31
32
|
coefs = (FilterCoefficients.b, FilterCoefficients.a)
|
|
32
33
|
return coef_type, coefs
|
|
33
34
|
|
|
35
|
+
|
|
34
36
|
@consumer
|
|
35
37
|
def filtergen(
|
|
36
38
|
axis: str, coefs: typing.Optional[typing.Tuple[np.ndarray]], coef_type: str
|
|
37
39
|
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
40
|
+
"""
|
|
41
|
+
Construct a generic filter generator function.
|
|
42
|
+
|
|
43
|
+
Args:
|
|
44
|
+
axis: The name of the axis to operate on.
|
|
45
|
+
coefs: The pre-calculated filter coefficients.
|
|
46
|
+
coef_type: The type of filter coefficients. One of "ba" or "sos".
|
|
47
|
+
|
|
48
|
+
Returns:
|
|
49
|
+
A generator that expects .send(axis_array) and yields the filtered :obj:`AxisArray`.
|
|
50
|
+
"""
|
|
38
51
|
# Massage inputs
|
|
39
52
|
if coefs is not None and not isinstance(coefs, tuple):
|
|
40
53
|
# scipy.signal functions called with first arg `*coefs`, but sos coefs are a single ndarray.
|
|
41
54
|
coefs = (coefs,)
|
|
42
55
|
|
|
43
56
|
# Init IO
|
|
44
|
-
|
|
45
|
-
axis_arr_out = AxisArray(np.array([]), dims=[""])
|
|
57
|
+
msg_out = AxisArray(np.array([]), dims=[""])
|
|
46
58
|
|
|
47
59
|
filt_func = {"ba": scipy.signal.lfilter, "sos": scipy.signal.sosfilt}[coef_type]
|
|
48
60
|
zi_func = {"ba": scipy.signal.lfilter_zi, "sos": scipy.signal.sosfilt_zi}[coef_type]
|
|
49
61
|
|
|
50
62
|
# State variables
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
63
|
+
zi: typing.Optional[npt.NDArray] = None
|
|
64
|
+
|
|
65
|
+
# Reset if these change.
|
|
66
|
+
check_input = {"key": None, "shape": None}
|
|
67
|
+
# fs changing will be handled by caller that creates coefficients.
|
|
54
68
|
|
|
55
69
|
while True:
|
|
56
|
-
|
|
70
|
+
msg_in: AxisArray = yield msg_out
|
|
57
71
|
|
|
58
72
|
if coefs is None:
|
|
59
73
|
# passthrough if we do not have a filter design.
|
|
60
|
-
|
|
74
|
+
msg_out = msg_in
|
|
61
75
|
continue
|
|
62
76
|
|
|
63
|
-
if
|
|
64
|
-
|
|
65
|
-
axis_idx = axis_arr_in.get_axis_idx(axis_name)
|
|
66
|
-
|
|
67
|
-
dat_in = axis_arr_in.data
|
|
77
|
+
axis = msg_in.dims[0] if axis is None else axis
|
|
78
|
+
axis_idx = msg_in.get_axis_idx(axis)
|
|
68
79
|
|
|
69
80
|
# Re-calculate/reset zi if necessary
|
|
70
|
-
samp_shape =
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
81
|
+
samp_shape = msg_in.data.shape[:axis_idx] + msg_in.data.shape[axis_idx + 1 :]
|
|
82
|
+
b_reset = samp_shape != check_input["shape"]
|
|
83
|
+
b_reset = b_reset or msg_in.key != check_input["key"]
|
|
84
|
+
if b_reset:
|
|
85
|
+
check_input["shape"] = samp_shape
|
|
86
|
+
check_input["key"] = msg_in.key
|
|
87
|
+
|
|
88
|
+
n_tail = msg_in.data.ndim - axis_idx - 1
|
|
74
89
|
zi = zi_func(*coefs)
|
|
75
90
|
zi_expand = (None,) * axis_idx + (slice(None),) + (None,) * n_tail
|
|
76
|
-
n_tile =
|
|
91
|
+
n_tile = (
|
|
92
|
+
msg_in.data.shape[:axis_idx] + (1,) + msg_in.data.shape[axis_idx + 1 :]
|
|
93
|
+
)
|
|
77
94
|
if coef_type == "sos":
|
|
78
95
|
# sos zi must keep its leading dimension (`order / 2` for low|high; `order` for bpass|bstop)
|
|
79
96
|
zi_expand = (slice(None),) + zi_expand
|
|
80
97
|
n_tile = (1,) + n_tile
|
|
81
98
|
zi = np.tile(zi[zi_expand], n_tile)
|
|
82
99
|
|
|
83
|
-
dat_out, zi = filt_func(*coefs,
|
|
84
|
-
|
|
100
|
+
dat_out, zi = filt_func(*coefs, msg_in.data, axis=axis_idx, zi=zi)
|
|
101
|
+
msg_out = replace(msg_in, data=dat_out)
|
|
85
102
|
|
|
86
103
|
|
|
87
104
|
class FilterSettingsBase(ez.Settings):
|
|
@@ -105,8 +122,8 @@ class FilterState(ez.State):
|
|
|
105
122
|
|
|
106
123
|
|
|
107
124
|
class Filter(ez.Unit):
|
|
108
|
-
SETTINGS
|
|
109
|
-
STATE
|
|
125
|
+
SETTINGS = FilterSettingsBase
|
|
126
|
+
STATE = FilterState
|
|
110
127
|
|
|
111
128
|
INPUT_FILTER = ez.InputStream(FilterCoefficients)
|
|
112
129
|
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
@@ -116,7 +133,7 @@ class Filter(ez.Unit):
|
|
|
116
133
|
raise NotImplementedError("Must implement 'design_filter' in Unit subclass!")
|
|
117
134
|
|
|
118
135
|
# Set up filter with static initialization if specified
|
|
119
|
-
def initialize(self) -> None:
|
|
136
|
+
async def initialize(self) -> None:
|
|
120
137
|
if self.SETTINGS.axis is not None:
|
|
121
138
|
self.STATE.axis = self.SETTINGS.axis
|
|
122
139
|
|
|
@@ -205,4 +222,7 @@ class Filter(ez.Unit):
|
|
|
205
222
|
if one_dimensional:
|
|
206
223
|
arr_out = np.squeeze(arr_out, axis=1)
|
|
207
224
|
|
|
208
|
-
yield
|
|
225
|
+
yield (
|
|
226
|
+
self.OUTPUT_SIGNAL,
|
|
227
|
+
replace(msg, data=arr_out),
|
|
228
|
+
)
|