ezmsg-sigproc 1.1.1__tar.gz → 1.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ezmsg-sigproc-1.1.1/ezmsg_sigproc.egg-info → ezmsg-sigproc-1.2.0}/PKG-INFO +1 -1
- ezmsg-sigproc-1.2.0/ezmsg/sigproc/__init__.py +1 -0
- ezmsg-sigproc-1.2.0/ezmsg/sigproc/__version__.py +1 -0
- {ezmsg-sigproc-1.1.1 → ezmsg-sigproc-1.2.0}/ezmsg/sigproc/butterworthfilter.py +17 -27
- {ezmsg-sigproc-1.1.1 → ezmsg-sigproc-1.2.0}/ezmsg/sigproc/decimate.py +7 -10
- ezmsg-sigproc-1.2.0/ezmsg/sigproc/downsample.py +63 -0
- ezmsg-sigproc-1.2.0/ezmsg/sigproc/ewmfilter.py +127 -0
- {ezmsg-sigproc-1.1.1 → ezmsg-sigproc-1.2.0}/ezmsg/sigproc/filter.py +40 -24
- ezmsg-sigproc-1.2.0/ezmsg/sigproc/messages.py +31 -0
- ezmsg-sigproc-1.2.0/ezmsg/sigproc/sampler.py +287 -0
- ezmsg-sigproc-1.2.0/ezmsg/sigproc/spectral.py +132 -0
- ezmsg-sigproc-1.2.0/ezmsg/sigproc/synth.py +411 -0
- ezmsg-sigproc-1.2.0/ezmsg/sigproc/window.py +144 -0
- {ezmsg-sigproc-1.1.1 → ezmsg-sigproc-1.2.0/ezmsg_sigproc.egg-info}/PKG-INFO +1 -1
- {ezmsg-sigproc-1.1.1 → ezmsg-sigproc-1.2.0}/ezmsg_sigproc.egg-info/SOURCES.txt +5 -2
- {ezmsg-sigproc-1.1.1 → ezmsg-sigproc-1.2.0}/ezmsg_sigproc.egg-info/requires.txt +1 -1
- {ezmsg-sigproc-1.1.1 → ezmsg-sigproc-1.2.0}/setup.cfg +1 -1
- ezmsg-sigproc-1.2.0/setup.py +7 -0
- ezmsg-sigproc-1.2.0/tests/test_butterworth.py +143 -0
- ezmsg-sigproc-1.2.0/tests/test_downsample.py +133 -0
- ezmsg-sigproc-1.2.0/tests/test_window.py +140 -0
- ezmsg-sigproc-1.1.1/ezmsg/sigproc/__init__.py +0 -1
- ezmsg-sigproc-1.1.1/ezmsg/sigproc/__version__.py +0 -1
- ezmsg-sigproc-1.1.1/ezmsg/sigproc/downsample.py +0 -69
- ezmsg-sigproc-1.1.1/ezmsg/sigproc/ewmfilter.py +0 -121
- ezmsg-sigproc-1.1.1/ezmsg/sigproc/messages.py +0 -51
- ezmsg-sigproc-1.1.1/ezmsg/sigproc/sampler.py +0 -251
- ezmsg-sigproc-1.1.1/ezmsg/sigproc/synth.py +0 -236
- ezmsg-sigproc-1.1.1/ezmsg/sigproc/timeseriesmessage.py +0 -1
- ezmsg-sigproc-1.1.1/ezmsg/sigproc/window.py +0 -112
- ezmsg-sigproc-1.1.1/setup.py +0 -7
- {ezmsg-sigproc-1.1.1 → ezmsg-sigproc-1.2.0}/LICENSE.txt +0 -0
- {ezmsg-sigproc-1.1.1 → ezmsg-sigproc-1.2.0}/README.md +0 -0
- {ezmsg-sigproc-1.1.1 → ezmsg-sigproc-1.2.0}/ezmsg_sigproc.egg-info/dependency_links.txt +0 -0
- {ezmsg-sigproc-1.1.1 → ezmsg-sigproc-1.2.0}/ezmsg_sigproc.egg-info/not-zip-safe +0 -0
- {ezmsg-sigproc-1.1.1 → ezmsg-sigproc-1.2.0}/ezmsg_sigproc.egg-info/top_level.txt +0 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .__version__ import __version__
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "1.2.0"
|
|
@@ -1,49 +1,40 @@
|
|
|
1
|
-
from dataclasses import dataclass, field
|
|
2
|
-
import logging
|
|
3
|
-
|
|
4
1
|
import ezmsg.core as ez
|
|
5
2
|
import scipy.signal
|
|
6
3
|
import numpy as np
|
|
7
4
|
|
|
8
|
-
from .filter import Filter, FilterState,
|
|
5
|
+
from .filter import Filter, FilterState, FilterSettingsBase
|
|
9
6
|
|
|
10
7
|
from typing import Optional, Tuple, Union
|
|
11
8
|
|
|
12
|
-
logger = logging.getLogger('ezmsg')
|
|
13
|
-
|
|
14
9
|
|
|
15
|
-
|
|
16
|
-
class ButterworthFilterDesign:
|
|
10
|
+
class ButterworthFilterSettings(FilterSettingsBase):
|
|
17
11
|
order: int = 0
|
|
18
12
|
cuton: Optional[float] = None # Hz
|
|
19
13
|
cutoff: Optional[float] = None # Hz
|
|
20
14
|
|
|
21
|
-
def filter_specs(
|
|
15
|
+
def filter_specs(self) -> Optional[Tuple[str, Union[float, Tuple[float, float]]]]:
|
|
22
16
|
if self.cuton is None and self.cutoff is None:
|
|
23
17
|
return None
|
|
24
18
|
elif self.cuton is None and self.cutoff is not None:
|
|
25
|
-
return
|
|
19
|
+
return "lowpass", self.cutoff
|
|
26
20
|
elif self.cuton is not None and self.cutoff is None:
|
|
27
|
-
return
|
|
21
|
+
return "highpass", self.cuton
|
|
28
22
|
elif self.cuton is not None and self.cutoff is not None:
|
|
29
|
-
if self.cuton <= self.cutoff:
|
|
30
|
-
return
|
|
31
|
-
else:
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
class ButterworthFilterSettings(ButterworthFilterDesign, FilterSettings):
|
|
35
|
-
...
|
|
23
|
+
if self.cuton <= self.cutoff:
|
|
24
|
+
return "bandpass", (self.cuton, self.cutoff)
|
|
25
|
+
else:
|
|
26
|
+
return "bandstop", (self.cutoff, self.cuton)
|
|
36
27
|
|
|
37
28
|
|
|
38
29
|
class ButterworthFilterState(FilterState):
|
|
39
|
-
design:
|
|
30
|
+
design: ButterworthFilterSettings
|
|
40
31
|
|
|
41
32
|
|
|
42
33
|
class ButterworthFilter(Filter):
|
|
43
34
|
SETTINGS: ButterworthFilterSettings
|
|
44
35
|
STATE: ButterworthFilterState
|
|
45
36
|
|
|
46
|
-
INPUT_FILTER = ez.InputStream(
|
|
37
|
+
INPUT_FILTER = ez.InputStream(ButterworthFilterSettings)
|
|
47
38
|
|
|
48
39
|
def initialize(self) -> None:
|
|
49
40
|
self.STATE.design = self.SETTINGS
|
|
@@ -55,18 +46,17 @@ class ButterworthFilter(Filter):
|
|
|
55
46
|
if self.STATE.design.order > 0 and specs is not None:
|
|
56
47
|
btype, cut = specs
|
|
57
48
|
return scipy.signal.butter(
|
|
58
|
-
self.STATE.design.order,
|
|
59
|
-
Wn=cut,
|
|
60
|
-
btype=btype,
|
|
61
|
-
fs=self.STATE.fs,
|
|
62
|
-
output="ba"
|
|
49
|
+
self.STATE.design.order,
|
|
50
|
+
Wn=cut,
|
|
51
|
+
btype=btype,
|
|
52
|
+
fs=self.STATE.fs,
|
|
53
|
+
output="ba",
|
|
63
54
|
)
|
|
64
55
|
else:
|
|
65
56
|
return None
|
|
66
57
|
|
|
67
|
-
|
|
68
58
|
@ez.subscriber(INPUT_FILTER)
|
|
69
|
-
async def redesign(self, message:
|
|
59
|
+
async def redesign(self, message: ButterworthFilterSettings) -> None:
|
|
70
60
|
if self.STATE.design.order != message.order:
|
|
71
61
|
self.STATE.zi = None
|
|
72
62
|
self.STATE.design = message
|
|
@@ -2,23 +2,22 @@ import ezmsg.core as ez
|
|
|
2
2
|
|
|
3
3
|
import scipy.signal
|
|
4
4
|
|
|
5
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
6
|
+
|
|
5
7
|
from .downsample import Downsample, DownsampleSettings
|
|
6
8
|
from .filter import Filter, FilterCoefficients, FilterSettings
|
|
7
|
-
from .messages import TSMessage as TimeSeriesMessage
|
|
8
9
|
|
|
9
10
|
|
|
10
11
|
class Decimate(ez.Collection):
|
|
11
|
-
|
|
12
12
|
SETTINGS: DownsampleSettings
|
|
13
13
|
|
|
14
|
-
INPUT_SIGNAL = ez.InputStream(
|
|
15
|
-
OUTPUT_SIGNAL = ez.OutputStream(
|
|
14
|
+
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
15
|
+
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
16
16
|
|
|
17
17
|
FILTER = Filter()
|
|
18
18
|
DOWNSAMPLE = Downsample()
|
|
19
19
|
|
|
20
20
|
def configure(self) -> None:
|
|
21
|
-
|
|
22
21
|
self.DOWNSAMPLE.apply_settings(self.SETTINGS)
|
|
23
22
|
|
|
24
23
|
if self.SETTINGS.factor < 1:
|
|
@@ -27,11 +26,9 @@ class Decimate(ez.Collection):
|
|
|
27
26
|
filt = FilterCoefficients()
|
|
28
27
|
else:
|
|
29
28
|
# See scipy.signal.decimate for IIR Filter Condition
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
)
|
|
33
|
-
|
|
34
|
-
filt = FilterCoefficients(b=system.num, a=system.den)
|
|
29
|
+
b, a = scipy.signal.cheby1(8, 0.05, 0.8 / self.SETTINGS.factor)
|
|
30
|
+
system = scipy.signal.dlti(b, a)
|
|
31
|
+
filt = FilterCoefficients(b=system.num, a=system.den) # type: ignore
|
|
35
32
|
|
|
36
33
|
self.FILTER.apply_settings(FilterSettings(filt=filt))
|
|
37
34
|
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
from dataclasses import replace
|
|
2
|
+
|
|
3
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
4
|
+
|
|
5
|
+
import ezmsg.core as ez
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from typing import (
|
|
9
|
+
AsyncGenerator,
|
|
10
|
+
Optional,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class DownsampleSettings(ez.Settings):
|
|
15
|
+
axis: Optional[str] = None
|
|
16
|
+
factor: int = 1
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class DownsampleState(ez.State):
|
|
20
|
+
cur_settings: DownsampleSettings
|
|
21
|
+
s_idx: int = 0
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class Downsample(ez.Unit):
|
|
25
|
+
SETTINGS: DownsampleSettings
|
|
26
|
+
STATE: DownsampleState
|
|
27
|
+
|
|
28
|
+
INPUT_SETTINGS = ez.InputStream(DownsampleSettings)
|
|
29
|
+
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
30
|
+
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
31
|
+
|
|
32
|
+
def initialize(self) -> None:
|
|
33
|
+
self.STATE.cur_settings = self.SETTINGS
|
|
34
|
+
|
|
35
|
+
@ez.subscriber(INPUT_SETTINGS)
|
|
36
|
+
async def on_settings(self, msg: DownsampleSettings) -> None:
|
|
37
|
+
self.STATE.cur_settings = msg
|
|
38
|
+
|
|
39
|
+
@ez.subscriber(INPUT_SIGNAL, zero_copy=True)
|
|
40
|
+
@ez.publisher(OUTPUT_SIGNAL)
|
|
41
|
+
async def on_signal(self, msg: AxisArray) -> AsyncGenerator:
|
|
42
|
+
if self.STATE.cur_settings.factor < 1:
|
|
43
|
+
raise ValueError("Downsample factor must be at least 1 (no downsampling)")
|
|
44
|
+
|
|
45
|
+
axis_name = self.STATE.cur_settings.axis
|
|
46
|
+
if axis_name is None:
|
|
47
|
+
axis_name = msg.dims[0]
|
|
48
|
+
axis = msg.get_axis(axis_name)
|
|
49
|
+
axis_idx = msg.get_axis_idx(axis_name)
|
|
50
|
+
|
|
51
|
+
samples = np.arange(msg.data.shape[axis_idx]) + self.STATE.s_idx
|
|
52
|
+
samples = samples % self.STATE.cur_settings.factor
|
|
53
|
+
self.STATE.s_idx = samples[-1] + 1
|
|
54
|
+
|
|
55
|
+
pub_samples = np.where(samples == 0)[0]
|
|
56
|
+
if len(pub_samples) != 0:
|
|
57
|
+
new_axes = {ax_name: msg.get_axis(ax_name) for ax_name in msg.dims}
|
|
58
|
+
new_offset = axis.offset + (axis.gain * pub_samples[0].item())
|
|
59
|
+
new_gain = axis.gain * self.STATE.cur_settings.factor
|
|
60
|
+
new_axes[axis_name] = replace(axis, gain=new_gain, offset=new_offset)
|
|
61
|
+
down_data = np.take(msg.data, pub_samples, axis_idx)
|
|
62
|
+
out_msg = replace(msg, data=down_data, dims=msg.dims, axes=new_axes)
|
|
63
|
+
yield self.OUTPUT_SIGNAL, out_msg
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from dataclasses import replace
|
|
3
|
+
|
|
4
|
+
import ezmsg.core as ez
|
|
5
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from .window import Window, WindowSettings
|
|
10
|
+
|
|
11
|
+
from typing import AsyncGenerator, Optional
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class EWMSettings(ez.Settings):
|
|
15
|
+
axis: Optional[str] = None
|
|
16
|
+
zero_offset: bool = True # If true, we assume zero DC offset
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class EWMState(ez.State):
|
|
20
|
+
buffer_queue: "asyncio.Queue[AxisArray]"
|
|
21
|
+
signal_queue: "asyncio.Queue[AxisArray]"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class EWM(ez.Unit):
|
|
25
|
+
"""
|
|
26
|
+
Exponentially Weighted Moving Average Standardization
|
|
27
|
+
|
|
28
|
+
References https://stackoverflow.com/a/42926270
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
SETTINGS: EWMSettings
|
|
32
|
+
STATE: EWMState
|
|
33
|
+
|
|
34
|
+
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
35
|
+
INPUT_BUFFER = ez.InputStream(AxisArray)
|
|
36
|
+
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
37
|
+
|
|
38
|
+
def initialize(self) -> None:
|
|
39
|
+
self.STATE.signal_queue = asyncio.Queue()
|
|
40
|
+
self.STATE.buffer_queue = asyncio.Queue()
|
|
41
|
+
|
|
42
|
+
@ez.subscriber(INPUT_SIGNAL)
|
|
43
|
+
async def on_signal(self, message: AxisArray) -> None:
|
|
44
|
+
self.STATE.signal_queue.put_nowait(message)
|
|
45
|
+
|
|
46
|
+
@ez.subscriber(INPUT_BUFFER)
|
|
47
|
+
async def on_buffer(self, message: AxisArray) -> None:
|
|
48
|
+
self.STATE.buffer_queue.put_nowait(message)
|
|
49
|
+
|
|
50
|
+
@ez.publisher(OUTPUT_SIGNAL)
|
|
51
|
+
async def sync_output(self) -> AsyncGenerator:
|
|
52
|
+
while True:
|
|
53
|
+
signal = await self.STATE.signal_queue.get()
|
|
54
|
+
buffer = await self.STATE.buffer_queue.get() # includes signal
|
|
55
|
+
|
|
56
|
+
axis_name = self.SETTINGS.axis
|
|
57
|
+
if axis_name is None:
|
|
58
|
+
axis_name = signal.dims[0]
|
|
59
|
+
|
|
60
|
+
axis_idx = signal.get_axis_idx(axis_name)
|
|
61
|
+
|
|
62
|
+
buffer_len = buffer.shape[axis_idx]
|
|
63
|
+
block_len = signal.shape[axis_idx]
|
|
64
|
+
window = buffer_len - block_len
|
|
65
|
+
|
|
66
|
+
alpha = 2 / (window + 1.0)
|
|
67
|
+
alpha_rev = 1 - alpha
|
|
68
|
+
|
|
69
|
+
pows = alpha_rev ** (np.arange(buffer_len + 1))
|
|
70
|
+
scale_arr = 1 / pows[:-1]
|
|
71
|
+
pw0 = alpha * alpha_rev ** (buffer_len - 1)
|
|
72
|
+
|
|
73
|
+
buffer_data = buffer.data
|
|
74
|
+
buffer_data = np.moveaxis(buffer_data, axis_idx, 0)
|
|
75
|
+
|
|
76
|
+
def ewma(data: np.ndarray) -> np.ndarray:
|
|
77
|
+
mult = scale_arr[:, np.newaxis] * data * pw0
|
|
78
|
+
out = scale_arr[::-1, np.newaxis] * mult.cumsum(axis=0)
|
|
79
|
+
|
|
80
|
+
if not self.SETTINGS.zero_offset:
|
|
81
|
+
out = (data[0, :, np.newaxis] * pows[1:]).T + out
|
|
82
|
+
|
|
83
|
+
return out
|
|
84
|
+
|
|
85
|
+
mean = ewma(buffer_data)
|
|
86
|
+
std = ewma((buffer_data - mean) ** 2.0)
|
|
87
|
+
|
|
88
|
+
standardized = (buffer_data - mean) / np.sqrt(std).clip(1e-4)
|
|
89
|
+
standardized = standardized[-signal.shape[axis_idx] :, ...]
|
|
90
|
+
standardized = np.moveaxis(standardized, axis_idx, 0)
|
|
91
|
+
|
|
92
|
+
yield self.OUTPUT_SIGNAL, replace(signal, data=standardized)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class EWMFilterSettings(ez.Settings):
|
|
96
|
+
history_dur: float # previous data to accumulate for standardization
|
|
97
|
+
axis: Optional[str] = None
|
|
98
|
+
zero_offset: bool = True # If true, we assume zero DC offset for input data
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class EWMFilter(ez.Collection):
|
|
102
|
+
SETTINGS: EWMFilterSettings
|
|
103
|
+
|
|
104
|
+
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
105
|
+
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
106
|
+
|
|
107
|
+
WINDOW = Window()
|
|
108
|
+
EWM = EWM()
|
|
109
|
+
|
|
110
|
+
def configure(self) -> None:
|
|
111
|
+
self.EWM.apply_settings(EWMSettings(axis=self.SETTINGS.axis, zero_offset=True))
|
|
112
|
+
|
|
113
|
+
self.WINDOW.apply_settings(
|
|
114
|
+
WindowSettings(
|
|
115
|
+
axis=self.SETTINGS.axis,
|
|
116
|
+
window_dur=self.SETTINGS.history_dur,
|
|
117
|
+
window_shift=None, # 1:1 mode
|
|
118
|
+
)
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
def network(self) -> ez.NetworkDefinition:
|
|
122
|
+
return (
|
|
123
|
+
(self.INPUT_SIGNAL, self.WINDOW.INPUT_SIGNAL),
|
|
124
|
+
(self.WINDOW.OUTPUT_SIGNAL, self.EWM.INPUT_BUFFER),
|
|
125
|
+
(self.INPUT_SIGNAL, self.EWM.INPUT_SIGNAL),
|
|
126
|
+
(self.EWM.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL),
|
|
127
|
+
)
|
|
@@ -4,26 +4,30 @@ import ezmsg.core as ez
|
|
|
4
4
|
import scipy.signal
|
|
5
5
|
import numpy as np
|
|
6
6
|
import asyncio
|
|
7
|
-
import logging
|
|
8
7
|
|
|
9
|
-
from .messages import
|
|
8
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
10
9
|
|
|
11
10
|
from typing import AsyncGenerator, Optional, Tuple
|
|
12
11
|
|
|
13
12
|
|
|
14
13
|
@dataclass
|
|
15
14
|
class FilterCoefficients:
|
|
16
|
-
b: np.ndarray = field(default_factory
|
|
17
|
-
a: np.ndarray = field(default_factory
|
|
15
|
+
b: np.ndarray = field(default_factory=lambda: np.array([1.0, 0.0]))
|
|
16
|
+
a: np.ndarray = field(default_factory=lambda: np.array([1.0, 0.0]))
|
|
18
17
|
|
|
19
18
|
|
|
20
|
-
class
|
|
19
|
+
class FilterSettingsBase(ez.Settings):
|
|
20
|
+
axis: Optional[str] = None
|
|
21
|
+
fs: Optional[float] = None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class FilterSettings(FilterSettingsBase):
|
|
21
25
|
# If you'd like to statically design a filter, define it in settings
|
|
22
26
|
filt: Optional[FilterCoefficients] = None
|
|
23
|
-
fs: Optional[float] = None
|
|
24
27
|
|
|
25
28
|
|
|
26
29
|
class FilterState(ez.State):
|
|
30
|
+
axis: Optional[str] = None
|
|
27
31
|
zi: Optional[np.ndarray] = None
|
|
28
32
|
filt_designed: bool = False
|
|
29
33
|
filt: Optional[FilterCoefficients] = None
|
|
@@ -33,21 +37,25 @@ class FilterState(ez.State):
|
|
|
33
37
|
|
|
34
38
|
|
|
35
39
|
class Filter(ez.Unit):
|
|
36
|
-
SETTINGS:
|
|
40
|
+
SETTINGS: FilterSettingsBase
|
|
37
41
|
STATE: FilterState
|
|
38
42
|
|
|
39
43
|
INPUT_FILTER = ez.InputStream(FilterCoefficients)
|
|
40
|
-
INPUT_SIGNAL = ez.InputStream(
|
|
41
|
-
OUTPUT_SIGNAL = ez.OutputStream(
|
|
44
|
+
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
45
|
+
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
42
46
|
|
|
43
47
|
def design_filter(self) -> Optional[Tuple[np.ndarray, np.ndarray]]:
|
|
44
48
|
raise NotImplementedError("Must implement 'design_filter' in Unit subclass!")
|
|
45
49
|
|
|
46
50
|
# Set up filter with static initialization if specified
|
|
47
51
|
def initialize(self) -> None:
|
|
48
|
-
if self.SETTINGS.
|
|
49
|
-
self.STATE.
|
|
50
|
-
|
|
52
|
+
if self.SETTINGS.axis is not None:
|
|
53
|
+
self.STATE.axis = self.SETTINGS.axis
|
|
54
|
+
|
|
55
|
+
if isinstance(self.SETTINGS, FilterSettings):
|
|
56
|
+
if self.SETTINGS.filt is not None:
|
|
57
|
+
self.STATE.filt = self.SETTINGS.filt
|
|
58
|
+
self.STATE.filt_set.set()
|
|
51
59
|
else:
|
|
52
60
|
self.STATE.filt_set.clear()
|
|
53
61
|
|
|
@@ -64,7 +72,9 @@ class Filter(ez.Unit):
|
|
|
64
72
|
def update_filter(self):
|
|
65
73
|
try:
|
|
66
74
|
coefs = self.design_filter()
|
|
67
|
-
self.STATE.filt =
|
|
75
|
+
self.STATE.filt = (
|
|
76
|
+
FilterCoefficients() if coefs is None else FilterCoefficients(*coefs)
|
|
77
|
+
)
|
|
68
78
|
self.STATE.filt_set.set()
|
|
69
79
|
self.STATE.filt_designed = True
|
|
70
80
|
except NotImplementedError as e:
|
|
@@ -74,30 +84,36 @@ class Filter(ez.Unit):
|
|
|
74
84
|
|
|
75
85
|
@ez.subscriber(INPUT_SIGNAL)
|
|
76
86
|
@ez.publisher(OUTPUT_SIGNAL)
|
|
77
|
-
async def apply_filter(self,
|
|
78
|
-
if self.STATE.
|
|
79
|
-
|
|
87
|
+
async def apply_filter(self, msg: AxisArray) -> AsyncGenerator:
|
|
88
|
+
axis_name = msg.dims[0] if self.STATE.axis is None else self.STATE.axis
|
|
89
|
+
axis_idx = msg.get_axis_idx(axis_name)
|
|
90
|
+
axis = msg.get_axis(axis_name)
|
|
91
|
+
fs = 1.0 / axis.gain
|
|
92
|
+
|
|
93
|
+
if self.STATE.fs != fs and self.STATE.filt_designed is True:
|
|
94
|
+
self.STATE.fs = fs
|
|
80
95
|
self.update_filter()
|
|
81
96
|
|
|
82
97
|
# Ensure filter is defined
|
|
98
|
+
# TODO: Maybe have me be a passthrough filter until coefficients are received
|
|
83
99
|
if self.STATE.filt is None:
|
|
84
100
|
self.STATE.filt_set.clear()
|
|
85
101
|
ez.logger.info("Awaiting filter coefficients...")
|
|
86
102
|
await self.STATE.filt_set.wait()
|
|
87
103
|
ez.logger.info("Filter coefficients received.")
|
|
88
104
|
|
|
89
|
-
|
|
105
|
+
assert self.STATE.filt is not None
|
|
106
|
+
|
|
107
|
+
arr_in = msg.data
|
|
90
108
|
|
|
91
109
|
# If the array is one dimensional, add a temporary second dimension so that the math works out
|
|
92
110
|
one_dimensional = False
|
|
93
|
-
if
|
|
94
|
-
arr_in = np.expand_dims(
|
|
111
|
+
if arr_in.ndim == 1:
|
|
112
|
+
arr_in = np.expand_dims(arr_in, axis=1)
|
|
95
113
|
one_dimensional = True
|
|
96
|
-
else:
|
|
97
|
-
arr_in = message.data
|
|
98
114
|
|
|
99
115
|
# We will perform filter with time dimension as last axis
|
|
100
|
-
arr_in = np.moveaxis(arr_in,
|
|
116
|
+
arr_in = np.moveaxis(arr_in, axis_idx, -1)
|
|
101
117
|
samp_shape = arr_in[..., 0].shape
|
|
102
118
|
|
|
103
119
|
# Re-calculate/reset zi if necessary
|
|
@@ -115,10 +131,10 @@ class Filter(ez.Unit):
|
|
|
115
131
|
self.STATE.filt.b, self.STATE.filt.a, arr_in, zi=self.STATE.zi
|
|
116
132
|
)
|
|
117
133
|
|
|
118
|
-
arr_out = np.moveaxis(arr_out, -1,
|
|
134
|
+
arr_out = np.moveaxis(arr_out, -1, axis_idx)
|
|
119
135
|
|
|
120
136
|
# Remove temporary first dimension if necessary
|
|
121
137
|
if one_dimensional:
|
|
122
138
|
arr_out = np.squeeze(arr_out, axis=1)
|
|
123
139
|
|
|
124
|
-
yield
|
|
140
|
+
yield self.OUTPUT_SIGNAL, replace(msg, data=arr_out),
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
import time
|
|
3
|
+
|
|
4
|
+
import numpy.typing as npt
|
|
5
|
+
|
|
6
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
|
+
|
|
8
|
+
from typing import Optional
|
|
9
|
+
|
|
10
|
+
# UPCOMING: TSMessage Deprecation
|
|
11
|
+
# TSMessage is deprecated because it doesn't handle multiple time axes well.
|
|
12
|
+
# AxisArray has an incompatible API but supports a superset of functionality.
|
|
13
|
+
warnings.warn(
|
|
14
|
+
"TimeSeriesMessage/TSMessage is deprecated. Please use ezmsg.utils.AxisArray",
|
|
15
|
+
DeprecationWarning,
|
|
16
|
+
stacklevel=2,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def TSMessage(
|
|
21
|
+
data: npt.NDArray,
|
|
22
|
+
fs: float = 1.0,
|
|
23
|
+
time_dim: int = 0,
|
|
24
|
+
timestamp: Optional[float] = None,
|
|
25
|
+
) -> AxisArray:
|
|
26
|
+
dims = [f"dim_{i}" for i in range(data.ndim)]
|
|
27
|
+
dims[time_dim] = "time"
|
|
28
|
+
offset = time.time() if timestamp is None else timestamp
|
|
29
|
+
offset_adj = data.shape[time_dim] / fs # offset corresponds to idx[0] on time_dim
|
|
30
|
+
axis = AxisArray.Axis.TimeAxis(fs, offset=offset - offset_adj)
|
|
31
|
+
return AxisArray(data, dims=dims, axes=dict(time=axis))
|