ezmsg-sigproc 1.4.2__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/activation.py +2 -2
- ezmsg/sigproc/affinetransform.py +13 -13
- ezmsg/sigproc/aggregate.py +49 -28
- ezmsg/sigproc/bandpower.py +2 -2
- ezmsg/sigproc/butterworthfilter.py +89 -90
- ezmsg/sigproc/cheby.py +119 -0
- ezmsg/sigproc/decimate.py +11 -15
- ezmsg/sigproc/downsample.py +8 -4
- ezmsg/sigproc/ewmfilter.py +9 -5
- ezmsg/sigproc/filter.py +82 -115
- ezmsg/sigproc/filterbank.py +5 -5
- ezmsg/sigproc/math/abs.py +1 -1
- ezmsg/sigproc/math/clip.py +1 -1
- ezmsg/sigproc/math/difference.py +1 -1
- ezmsg/sigproc/math/invert.py +1 -1
- ezmsg/sigproc/math/log.py +1 -1
- ezmsg/sigproc/math/scale.py +1 -1
- ezmsg/sigproc/messages.py +2 -3
- ezmsg/sigproc/sampler.py +16 -15
- ezmsg/sigproc/scaler.py +153 -35
- ezmsg/sigproc/signalinjector.py +7 -7
- ezmsg/sigproc/slicer.py +34 -14
- ezmsg/sigproc/spectrogram.py +6 -6
- ezmsg/sigproc/spectrum.py +18 -14
- ezmsg/sigproc/synth.py +43 -27
- ezmsg/sigproc/wavelets.py +42 -17
- ezmsg/sigproc/window.py +14 -13
- {ezmsg_sigproc-1.4.2.dist-info → ezmsg_sigproc-1.6.0.dist-info}/METADATA +4 -5
- ezmsg_sigproc-1.6.0.dist-info/RECORD +36 -0
- {ezmsg_sigproc-1.4.2.dist-info → ezmsg_sigproc-1.6.0.dist-info}/WHEEL +1 -1
- ezmsg_sigproc-1.4.2.dist-info/RECORD +0 -35
- {ezmsg_sigproc-1.4.2.dist-info → ezmsg_sigproc-1.6.0.dist-info}/licenses/LICENSE.txt +0 -0
ezmsg/sigproc/downsample.py
CHANGED
|
@@ -1,8 +1,11 @@
|
|
|
1
|
-
from dataclasses import replace
|
|
2
1
|
import typing
|
|
3
2
|
|
|
4
3
|
import numpy as np
|
|
5
|
-
from ezmsg.util.messages.axisarray import
|
|
4
|
+
from ezmsg.util.messages.axisarray import (
|
|
5
|
+
AxisArray,
|
|
6
|
+
slice_along_axis,
|
|
7
|
+
replace,
|
|
8
|
+
)
|
|
6
9
|
from ezmsg.util.generator import consumer
|
|
7
10
|
import ezmsg.core as ez
|
|
8
11
|
|
|
@@ -11,7 +14,7 @@ from .base import GenAxisArray
|
|
|
11
14
|
|
|
12
15
|
@consumer
|
|
13
16
|
def downsample(
|
|
14
|
-
axis:
|
|
17
|
+
axis: str | None = None, factor: int = 1
|
|
15
18
|
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
16
19
|
"""
|
|
17
20
|
Construct a generator that yields a downsampled version of the data .send() to it.
|
|
@@ -22,6 +25,7 @@ def downsample(
|
|
|
22
25
|
|
|
23
26
|
Args:
|
|
24
27
|
axis: The name of the axis along which to downsample.
|
|
28
|
+
Note: The axis must exist in the message .axes and be of type AxisArray.LinearAxis.
|
|
25
29
|
factor: Downsampling factor.
|
|
26
30
|
|
|
27
31
|
Returns:
|
|
@@ -92,7 +96,7 @@ class DownsampleSettings(ez.Settings):
|
|
|
92
96
|
See :obj:`downsample` documentation for a description of the parameters.
|
|
93
97
|
"""
|
|
94
98
|
|
|
95
|
-
axis:
|
|
99
|
+
axis: str | None = None
|
|
96
100
|
factor: int = 1
|
|
97
101
|
|
|
98
102
|
|
ezmsg/sigproc/ewmfilter.py
CHANGED
|
@@ -1,16 +1,16 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
-
from dataclasses import replace
|
|
3
2
|
import typing
|
|
4
3
|
|
|
5
4
|
import ezmsg.core as ez
|
|
6
5
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
6
|
+
from ezmsg.util.messages.util import replace
|
|
7
7
|
import numpy as np
|
|
8
8
|
|
|
9
9
|
from .window import Window, WindowSettings
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
class EWMSettings(ez.Settings):
|
|
13
|
-
axis:
|
|
13
|
+
axis: str | None = None
|
|
14
14
|
"""Name of the axis to accumulate."""
|
|
15
15
|
|
|
16
16
|
zero_offset: bool = True
|
|
@@ -24,7 +24,8 @@ class EWMState(ez.State):
|
|
|
24
24
|
|
|
25
25
|
class EWM(ez.Unit):
|
|
26
26
|
"""
|
|
27
|
-
Exponentially Weighted Moving Average Standardization
|
|
27
|
+
Exponentially Weighted Moving Average Standardization.
|
|
28
|
+
This is deprecated. Please use :obj:`ezmsg.sigproc.scaler.AdaptiveStandardScaler` instead.
|
|
28
29
|
|
|
29
30
|
References https://stackoverflow.com/a/42926270
|
|
30
31
|
"""
|
|
@@ -37,6 +38,9 @@ class EWM(ez.Unit):
|
|
|
37
38
|
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
38
39
|
|
|
39
40
|
async def initialize(self) -> None:
|
|
41
|
+
ez.logger.warning(
|
|
42
|
+
"EWM/EWMFilter is deprecated and will be removed in a future version. Use AdaptiveStandardScaler instead."
|
|
43
|
+
)
|
|
40
44
|
self.STATE.signal_queue = asyncio.Queue()
|
|
41
45
|
self.STATE.buffer_queue = asyncio.Queue()
|
|
42
46
|
|
|
@@ -100,7 +104,7 @@ class EWMFilterSettings(ez.Settings):
|
|
|
100
104
|
history_dur: float
|
|
101
105
|
"""Previous data to accumulate for standardization."""
|
|
102
106
|
|
|
103
|
-
axis:
|
|
107
|
+
axis: str | None = None
|
|
104
108
|
"""Name of the axis to accumulate."""
|
|
105
109
|
|
|
106
110
|
zero_offset: bool = True
|
|
@@ -113,7 +117,7 @@ class EWMFilter(ez.Collection):
|
|
|
113
117
|
leads to :obj:`Window` which then feeds into :obj:`EWM` 's INPUT_BUFFER
|
|
114
118
|
and another branch that feeds directly into :obj:`EWM` 's INPUT_SIGNAL.
|
|
115
119
|
|
|
116
|
-
|
|
120
|
+
This is deprecated. Please use :obj:`ezmsg.sigproc.scaler.AdaptiveStandardScaler` instead.
|
|
117
121
|
"""
|
|
118
122
|
|
|
119
123
|
SETTINGS = EWMFilterSettings
|
ezmsg/sigproc/filter.py
CHANGED
|
@@ -1,14 +1,16 @@
|
|
|
1
|
-
import
|
|
2
|
-
from dataclasses import dataclass, replace, field
|
|
1
|
+
from dataclasses import dataclass, field
|
|
3
2
|
import typing
|
|
4
3
|
|
|
5
4
|
import ezmsg.core as ez
|
|
6
5
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
6
|
+
from ezmsg.util.messages.util import replace
|
|
7
7
|
from ezmsg.util.generator import consumer
|
|
8
8
|
import numpy as np
|
|
9
9
|
import numpy.typing as npt
|
|
10
10
|
import scipy.signal
|
|
11
11
|
|
|
12
|
+
from ezmsg.sigproc.base import GenAxisArray
|
|
13
|
+
|
|
12
14
|
|
|
13
15
|
@dataclass
|
|
14
16
|
class FilterCoefficients:
|
|
@@ -17,10 +19,8 @@ class FilterCoefficients:
|
|
|
17
19
|
|
|
18
20
|
|
|
19
21
|
def _normalize_coefs(
|
|
20
|
-
coefs:
|
|
21
|
-
|
|
22
|
-
],
|
|
23
|
-
) -> typing.Tuple[str, typing.Tuple[npt.NDArray, ...]]:
|
|
22
|
+
coefs: FilterCoefficients | tuple[npt.NDArray, npt.NDArray] | npt.NDArray,
|
|
23
|
+
) -> tuple[str, tuple[npt.NDArray, ...]]:
|
|
24
24
|
coef_type = "ba"
|
|
25
25
|
if coefs is not None:
|
|
26
26
|
# scipy.signal functions called with first arg `*coefs`.
|
|
@@ -35,7 +35,7 @@ def _normalize_coefs(
|
|
|
35
35
|
|
|
36
36
|
@consumer
|
|
37
37
|
def filtergen(
|
|
38
|
-
axis: str, coefs:
|
|
38
|
+
axis: str, coefs: npt.NDArray | tuple[npt.NDArray] | None, coef_type: str
|
|
39
39
|
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
40
40
|
"""
|
|
41
41
|
Filter data using the provided coefficients.
|
|
@@ -61,7 +61,7 @@ def filtergen(
|
|
|
61
61
|
zi_func = {"ba": scipy.signal.lfilter_zi, "sos": scipy.signal.sosfilt_zi}[coef_type]
|
|
62
62
|
|
|
63
63
|
# State variables
|
|
64
|
-
zi:
|
|
64
|
+
zi: npt.NDArray | None = None
|
|
65
65
|
|
|
66
66
|
# Reset if these change.
|
|
67
67
|
check_input = {"key": None, "shape": None}
|
|
@@ -105,128 +105,95 @@ def filtergen(
|
|
|
105
105
|
msg_out = replace(msg_in, data=dat_out)
|
|
106
106
|
|
|
107
107
|
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
108
|
+
# Type aliases
|
|
109
|
+
BACoeffs = tuple[npt.NDArray, npt.NDArray]
|
|
110
|
+
SOSCoeffs = npt.NDArray
|
|
111
|
+
FilterCoefsMultiType = BACoeffs | SOSCoeffs
|
|
111
112
|
|
|
112
113
|
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
114
|
+
@consumer
|
|
115
|
+
def filter_gen_by_design(
|
|
116
|
+
axis: str,
|
|
117
|
+
coef_type: str,
|
|
118
|
+
design_fun: typing.Callable[[float], FilterCoefsMultiType | None],
|
|
119
|
+
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
120
|
+
"""
|
|
121
|
+
Filter data using a filter whose coefficients are calculated using the provided design function.
|
|
116
122
|
|
|
123
|
+
Args:
|
|
124
|
+
axis: The name of the axis to filter.
|
|
125
|
+
Note: The axis must be represented in the message .axes and be of type AxisArray.LinearAxis.
|
|
126
|
+
coef_type: "ba" or "sos"
|
|
127
|
+
design_fun: A callable that takes "fs" as its only argument and returns a tuple of filter coefficients.
|
|
128
|
+
If the design_fun returns None then the filter will act as a passthrough.
|
|
129
|
+
Hint: To make a design function that only requires fs, use functools.partial to set other parameters.
|
|
130
|
+
See butterworthfilter for an example.
|
|
117
131
|
|
|
118
|
-
|
|
119
|
-
axis: typing.Optional[str] = None
|
|
120
|
-
zi: typing.Optional[np.ndarray] = None
|
|
121
|
-
filt_designed: bool = False
|
|
122
|
-
filt: typing.Optional[FilterCoefficients] = None
|
|
123
|
-
filt_set: asyncio.Event = field(default_factory=asyncio.Event)
|
|
124
|
-
samp_shape: typing.Optional[typing.Tuple[int, ...]] = None
|
|
125
|
-
fs: typing.Optional[float] = None # Hz
|
|
132
|
+
Returns:
|
|
126
133
|
|
|
134
|
+
"""
|
|
135
|
+
msg_out = AxisArray(np.array([]), dims=[""])
|
|
127
136
|
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
137
|
+
# State variables
|
|
138
|
+
# Initialize filtergen as passthrough until we receive a message that allows us to design the filter.
|
|
139
|
+
filter_gen = filtergen(axis, None, coef_type)
|
|
131
140
|
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
141
|
+
# Reset if these change.
|
|
142
|
+
check_input = {"gain": None}
|
|
143
|
+
# No need to check parameters that don't affect the design; filter_gen should check most of its parameters.
|
|
135
144
|
|
|
136
|
-
|
|
137
|
-
|
|
145
|
+
while True:
|
|
146
|
+
msg_in: AxisArray = yield msg_out
|
|
147
|
+
axis = axis or msg_in.dims[0]
|
|
148
|
+
b_reset = msg_in.axes[axis].gain != check_input["gain"]
|
|
149
|
+
if b_reset:
|
|
150
|
+
check_input["gain"] = msg_in.axes[axis].gain
|
|
151
|
+
coefs = design_fun(1 / msg_in.axes[axis].gain)
|
|
152
|
+
filter_gen = filtergen(axis, coefs, coef_type)
|
|
138
153
|
|
|
139
|
-
|
|
140
|
-
async def initialize(self) -> None:
|
|
141
|
-
if self.SETTINGS.axis is not None:
|
|
142
|
-
self.STATE.axis = self.SETTINGS.axis
|
|
154
|
+
msg_out = filter_gen.send(msg_in)
|
|
143
155
|
|
|
144
|
-
if isinstance(self.SETTINGS, FilterSettings):
|
|
145
|
-
if self.SETTINGS.filt is not None:
|
|
146
|
-
self.STATE.filt = self.SETTINGS.filt
|
|
147
|
-
self.STATE.filt_set.set()
|
|
148
|
-
else:
|
|
149
|
-
self.STATE.filt_set.clear()
|
|
150
156
|
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
except NotImplementedError:
|
|
155
|
-
ez.logger.debug("Using filter coefficients.")
|
|
157
|
+
class FilterBaseSettings(ez.Settings):
|
|
158
|
+
axis: str | None = None
|
|
159
|
+
coef_type: str = "ba"
|
|
156
160
|
|
|
157
|
-
@ez.subscriber(INPUT_FILTER)
|
|
158
|
-
async def redesign(self, message: FilterCoefficients):
|
|
159
|
-
self.STATE.filt = message
|
|
160
|
-
|
|
161
|
-
def update_filter(self):
|
|
162
|
-
try:
|
|
163
|
-
coefs = self.design_filter()
|
|
164
|
-
self.STATE.filt = (
|
|
165
|
-
FilterCoefficients() if coefs is None else FilterCoefficients(*coefs)
|
|
166
|
-
)
|
|
167
|
-
self.STATE.filt_set.set()
|
|
168
|
-
self.STATE.filt_designed = True
|
|
169
|
-
except NotImplementedError as e:
|
|
170
|
-
raise e
|
|
171
|
-
except Exception as e:
|
|
172
|
-
ez.logger.warning(f"Error when designing filter: {e}")
|
|
173
|
-
|
|
174
|
-
@ez.subscriber(INPUT_SIGNAL)
|
|
175
|
-
@ez.publisher(OUTPUT_SIGNAL)
|
|
176
|
-
async def apply_filter(self, msg: AxisArray) -> typing.AsyncGenerator:
|
|
177
|
-
axis_name = msg.dims[0] if self.STATE.axis is None else self.STATE.axis
|
|
178
|
-
axis_idx = msg.get_axis_idx(axis_name)
|
|
179
|
-
axis = msg.get_axis(axis_name)
|
|
180
|
-
fs = 1.0 / axis.gain
|
|
181
|
-
|
|
182
|
-
if self.STATE.fs != fs and self.STATE.filt_designed is True:
|
|
183
|
-
self.STATE.fs = fs
|
|
184
|
-
self.update_filter()
|
|
185
|
-
|
|
186
|
-
# Ensure filter is defined
|
|
187
|
-
# TODO: Maybe have me be a passthrough filter until coefficients are received
|
|
188
|
-
if self.STATE.filt is None:
|
|
189
|
-
self.STATE.filt_set.clear()
|
|
190
|
-
ez.logger.info("Awaiting filter coefficients...")
|
|
191
|
-
await self.STATE.filt_set.wait()
|
|
192
|
-
ez.logger.info("Filter coefficients received.")
|
|
193
|
-
|
|
194
|
-
assert self.STATE.filt is not None
|
|
195
|
-
|
|
196
|
-
arr_in = msg.data
|
|
197
|
-
|
|
198
|
-
# If the array is one dimensional, add a temporary second dimension so that the math works out
|
|
199
|
-
one_dimensional = False
|
|
200
|
-
if arr_in.ndim == 1:
|
|
201
|
-
arr_in = np.expand_dims(arr_in, axis=1)
|
|
202
|
-
one_dimensional = True
|
|
203
|
-
|
|
204
|
-
# We will perform filter with time dimension as last axis
|
|
205
|
-
arr_in = np.moveaxis(arr_in, axis_idx, -1)
|
|
206
|
-
samp_shape = arr_in[..., 0].shape
|
|
207
161
|
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
162
|
+
class FilterBase(GenAxisArray):
|
|
163
|
+
SETTINGS = FilterBaseSettings
|
|
164
|
+
|
|
165
|
+
# Backwards-compatible with `Filter` unit
|
|
166
|
+
INPUT_FILTER = ez.InputStream(FilterCoefsMultiType)
|
|
167
|
+
|
|
168
|
+
def design_filter(
|
|
169
|
+
self,
|
|
170
|
+
) -> typing.Callable[[float], FilterCoefsMultiType | None]:
|
|
171
|
+
raise NotImplementedError("Must implement 'design_filter' in Unit subclass!")
|
|
218
172
|
|
|
219
|
-
|
|
220
|
-
|
|
173
|
+
def construct_generator(self):
|
|
174
|
+
design_fun = self.design_filter()
|
|
175
|
+
self.STATE.gen = filter_gen_by_design(
|
|
176
|
+
self.SETTINGS.axis, self.SETTINGS.coef_type, design_fun
|
|
221
177
|
)
|
|
222
178
|
|
|
223
|
-
|
|
179
|
+
@ez.subscriber(INPUT_FILTER)
|
|
180
|
+
async def redesign(self, message: FilterBaseSettings) -> None:
|
|
181
|
+
self.apply_settings(message)
|
|
182
|
+
self.construct_generator()
|
|
224
183
|
|
|
225
|
-
# Remove temporary first dimension if necessary
|
|
226
|
-
if one_dimensional:
|
|
227
|
-
arr_out = np.squeeze(arr_out, axis=1)
|
|
228
184
|
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
185
|
+
class FilterSettings(FilterBaseSettings):
|
|
186
|
+
# If you'd like to statically design a filter, define it in settings
|
|
187
|
+
coefs: FilterCoefficients | None = None
|
|
188
|
+
# Note: coef_type = "ba" is assumed for this class.
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
class Filter(FilterBase):
|
|
192
|
+
SETTINGS = FilterSettings
|
|
193
|
+
|
|
194
|
+
INPUT_FILTER = ez.InputStream(FilterCoefficients)
|
|
195
|
+
|
|
196
|
+
def design_filter(self) -> typing.Callable[[float], BACoeffs | None]:
|
|
197
|
+
if self.SETTINGS.coefs is None:
|
|
198
|
+
return lambda fs: None
|
|
199
|
+
return lambda fs: (self.SETTINGS.coefs.b, self.SETTINGS.coefs.a)
|
ezmsg/sigproc/filterbank.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
from dataclasses import replace
|
|
2
1
|
import functools
|
|
3
2
|
import math
|
|
4
3
|
import typing
|
|
@@ -10,6 +9,7 @@ from scipy.special import lambertw
|
|
|
10
9
|
import numpy.typing as npt
|
|
11
10
|
import ezmsg.core as ez
|
|
12
11
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
12
|
+
from ezmsg.util.messages.util import replace
|
|
13
13
|
from ezmsg.util.generator import consumer
|
|
14
14
|
|
|
15
15
|
from .base import GenAxisArray
|
|
@@ -36,7 +36,7 @@ class MinPhaseMode(OptionsEnum):
|
|
|
36
36
|
|
|
37
37
|
@consumer
|
|
38
38
|
def filterbank(
|
|
39
|
-
kernels:
|
|
39
|
+
kernels: list[npt.NDArray] | tuple[npt.NDArray, ...],
|
|
40
40
|
mode: FilterbankMode = FilterbankMode.CONV,
|
|
41
41
|
min_phase: MinPhaseMode = MinPhaseMode.NONE,
|
|
42
42
|
axis: str = "time",
|
|
@@ -63,10 +63,10 @@ def filterbank(
|
|
|
63
63
|
with the data payload containing the absolute value of the input :obj:`AxisArray` data.
|
|
64
64
|
|
|
65
65
|
"""
|
|
66
|
-
msg_out:
|
|
66
|
+
msg_out: AxisArray | None = None
|
|
67
67
|
|
|
68
68
|
# State variables
|
|
69
|
-
template:
|
|
69
|
+
template: AxisArray | None = None
|
|
70
70
|
|
|
71
71
|
# Reset if these change
|
|
72
72
|
check_input = {
|
|
@@ -258,7 +258,7 @@ def filterbank(
|
|
|
258
258
|
|
|
259
259
|
|
|
260
260
|
class FilterbankSettings(ez.Settings):
|
|
261
|
-
kernels:
|
|
261
|
+
kernels: list[npt.NDArray] | tuple[npt.NDArray, ...]
|
|
262
262
|
mode: FilterbankMode = FilterbankMode.CONV
|
|
263
263
|
min_phase: MinPhaseMode = MinPhaseMode.NONE
|
|
264
264
|
axis: str = "time"
|
ezmsg/sigproc/math/abs.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
from dataclasses import replace
|
|
2
1
|
import typing
|
|
3
2
|
|
|
4
3
|
import numpy as np
|
|
5
4
|
import ezmsg.core as ez
|
|
6
5
|
from ezmsg.util.generator import consumer
|
|
7
6
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
|
+
from ezmsg.util.messages.util import replace
|
|
8
8
|
|
|
9
9
|
from ..base import GenAxisArray
|
|
10
10
|
|
ezmsg/sigproc/math/clip.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
from dataclasses import replace
|
|
2
1
|
import typing
|
|
3
2
|
|
|
4
3
|
import numpy as np
|
|
5
4
|
import ezmsg.core as ez
|
|
6
5
|
from ezmsg.util.generator import consumer
|
|
7
6
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
|
+
from ezmsg.util.messages.util import replace
|
|
8
8
|
|
|
9
9
|
from ..base import GenAxisArray
|
|
10
10
|
|
ezmsg/sigproc/math/difference.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
from dataclasses import replace
|
|
2
1
|
import typing
|
|
3
2
|
|
|
4
3
|
import numpy as np
|
|
5
4
|
import ezmsg.core as ez
|
|
6
5
|
from ezmsg.util.generator import consumer
|
|
7
6
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
|
+
from ezmsg.util.messages.util import replace
|
|
8
8
|
|
|
9
9
|
from ..base import GenAxisArray
|
|
10
10
|
|
ezmsg/sigproc/math/invert.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
from dataclasses import replace
|
|
2
1
|
import typing
|
|
3
2
|
|
|
4
3
|
import numpy as np
|
|
5
4
|
import ezmsg.core as ez
|
|
6
5
|
from ezmsg.util.generator import consumer
|
|
7
6
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
|
+
from ezmsg.util.messages.util import replace
|
|
8
8
|
|
|
9
9
|
from ..base import GenAxisArray
|
|
10
10
|
|
ezmsg/sigproc/math/log.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
from dataclasses import replace
|
|
2
1
|
import typing
|
|
3
2
|
|
|
4
3
|
import numpy as np
|
|
5
4
|
import ezmsg.core as ez
|
|
6
5
|
from ezmsg.util.generator import consumer
|
|
7
6
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
|
+
from ezmsg.util.messages.util import replace
|
|
8
8
|
|
|
9
9
|
from ..base import GenAxisArray
|
|
10
10
|
|
ezmsg/sigproc/math/scale.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
from dataclasses import replace
|
|
2
1
|
import typing
|
|
3
2
|
|
|
4
3
|
import numpy as np
|
|
5
4
|
import ezmsg.core as ez
|
|
6
5
|
from ezmsg.util.generator import consumer
|
|
7
6
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
|
+
from ezmsg.util.messages.util import replace
|
|
8
8
|
|
|
9
9
|
from ..base import GenAxisArray
|
|
10
10
|
|
ezmsg/sigproc/messages.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import warnings
|
|
2
2
|
import time
|
|
3
|
-
import typing
|
|
4
3
|
|
|
5
4
|
import numpy.typing as npt
|
|
6
5
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
@@ -20,11 +19,11 @@ def TSMessage(
|
|
|
20
19
|
data: npt.NDArray,
|
|
21
20
|
fs: float = 1.0,
|
|
22
21
|
time_dim: int = 0,
|
|
23
|
-
timestamp:
|
|
22
|
+
timestamp: float | None = None,
|
|
24
23
|
) -> AxisArray:
|
|
25
24
|
dims = [f"dim_{i}" for i in range(data.ndim)]
|
|
26
25
|
dims[time_dim] = "time"
|
|
27
26
|
offset = time.time() if timestamp is None else timestamp
|
|
28
27
|
offset_adj = data.shape[time_dim] / fs # offset corresponds to idx[0] on time_dim
|
|
29
|
-
axis = AxisArray.
|
|
28
|
+
axis = AxisArray.TimeAxis(fs, offset=offset - offset_adj)
|
|
30
29
|
return AxisArray(data, dims=dims, axes=dict(time=axis))
|
ezmsg/sigproc/sampler.py
CHANGED
|
@@ -1,13 +1,17 @@
|
|
|
1
1
|
import asyncio # Dev/test apparatus
|
|
2
2
|
from collections import deque
|
|
3
|
-
from dataclasses import dataclass,
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
4
|
import time
|
|
5
5
|
import typing
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
import numpy.typing as npt
|
|
9
9
|
import ezmsg.core as ez
|
|
10
|
-
from ezmsg.util.messages.axisarray import
|
|
10
|
+
from ezmsg.util.messages.axisarray import (
|
|
11
|
+
AxisArray,
|
|
12
|
+
slice_along_axis,
|
|
13
|
+
)
|
|
14
|
+
from ezmsg.util.messages.util import replace
|
|
11
15
|
from ezmsg.util.generator import consumer
|
|
12
16
|
|
|
13
17
|
|
|
@@ -16,7 +20,7 @@ class SampleTriggerMessage:
|
|
|
16
20
|
timestamp: float = field(default_factory=time.time)
|
|
17
21
|
"""Time of the trigger, in seconds. The Clock depends on the input but defaults to time.time"""
|
|
18
22
|
|
|
19
|
-
period:
|
|
23
|
+
period: tuple[float, float] | None = None
|
|
20
24
|
"""The period around the timestamp, in seconds"""
|
|
21
25
|
|
|
22
26
|
value: typing.Any = None
|
|
@@ -35,13 +39,11 @@ class SampleMessage:
|
|
|
35
39
|
@consumer
|
|
36
40
|
def sampler(
|
|
37
41
|
buffer_dur: float,
|
|
38
|
-
axis:
|
|
39
|
-
period:
|
|
42
|
+
axis: str | None = None,
|
|
43
|
+
period: tuple[float, float] | None = None,
|
|
40
44
|
value: typing.Any = None,
|
|
41
45
|
estimate_alignment: bool = True,
|
|
42
|
-
) -> typing.Generator[
|
|
43
|
-
typing.List[SampleMessage], typing.Union[AxisArray, SampleTriggerMessage], None
|
|
44
|
-
]:
|
|
46
|
+
) -> typing.Generator[list[SampleMessage], AxisArray | SampleTriggerMessage, None]:
|
|
45
47
|
"""
|
|
46
48
|
Sample data into a buffer, accept triggers, and return slices of sampled
|
|
47
49
|
data around the trigger time.
|
|
@@ -52,6 +54,7 @@ def sampler(
|
|
|
52
54
|
need a buffer of 0.5 + (1.5 - -1.0) = 3.0 seconds. It is best to at least double your estimate if memory allows.
|
|
53
55
|
axis: The axis along which to sample the data.
|
|
54
56
|
None (default) will choose the first axis in the first input.
|
|
57
|
+
Note: (for now) the axis must exist in the msg .axes and be of type AxisArray.LinearAxis
|
|
55
58
|
period: The period in seconds during which to sample the data.
|
|
56
59
|
Defaults to None. Only used if not None and the trigger message does not define its own period.
|
|
57
60
|
value: The value to sample. Defaults to None.
|
|
@@ -69,7 +72,7 @@ def sampler(
|
|
|
69
72
|
|
|
70
73
|
# State variables (most shared between trigger- and data-processing.
|
|
71
74
|
triggers: deque[SampleTriggerMessage] = deque()
|
|
72
|
-
buffer:
|
|
75
|
+
buffer: npt.NDArray | None = None
|
|
73
76
|
n_samples: int = 0
|
|
74
77
|
offset: float = 0.0
|
|
75
78
|
|
|
@@ -225,8 +228,8 @@ class SamplerSettings(ez.Settings):
|
|
|
225
228
|
"""
|
|
226
229
|
|
|
227
230
|
buffer_dur: float
|
|
228
|
-
axis:
|
|
229
|
-
period:
|
|
231
|
+
axis: str | None = None
|
|
232
|
+
period: tuple[float, float] | None = None
|
|
230
233
|
"""Optional default period if unspecified in SampleTriggerMessage"""
|
|
231
234
|
|
|
232
235
|
value: typing.Any = None
|
|
@@ -243,9 +246,7 @@ class SamplerSettings(ez.Settings):
|
|
|
243
246
|
|
|
244
247
|
class SamplerState(ez.State):
|
|
245
248
|
cur_settings: SamplerSettings
|
|
246
|
-
gen: typing.Generator[
|
|
247
|
-
typing.Union[AxisArray, SampleTriggerMessage], typing.List[SampleMessage], None
|
|
248
|
-
]
|
|
249
|
+
gen: typing.Generator[AxisArray | SampleTriggerMessage, list[SampleMessage], None]
|
|
249
250
|
|
|
250
251
|
|
|
251
252
|
class Sampler(ez.Unit):
|
|
@@ -290,7 +291,7 @@ class Sampler(ez.Unit):
|
|
|
290
291
|
|
|
291
292
|
|
|
292
293
|
class TriggerGeneratorSettings(ez.Settings):
|
|
293
|
-
period:
|
|
294
|
+
period: tuple[float, float]
|
|
294
295
|
"""The period around the trigger event."""
|
|
295
296
|
|
|
296
297
|
prewait: float = 0.5
|