ezmsg-sigproc 1.5.0__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ezmsg/sigproc/filter.py CHANGED
@@ -1,14 +1,16 @@
1
- import asyncio
2
1
  from dataclasses import dataclass, field
3
2
  import typing
4
3
 
5
4
  import ezmsg.core as ez
6
- from ezmsg.util.messages.axisarray import AxisArray, replace
5
+ from ezmsg.util.messages.axisarray import AxisArray
6
+ from ezmsg.util.messages.util import replace
7
7
  from ezmsg.util.generator import consumer
8
8
  import numpy as np
9
9
  import numpy.typing as npt
10
10
  import scipy.signal
11
11
 
12
+ from ezmsg.sigproc.base import GenAxisArray
13
+
12
14
 
13
15
  @dataclass
14
16
  class FilterCoefficients:
@@ -17,10 +19,8 @@ class FilterCoefficients:
17
19
 
18
20
 
19
21
  def _normalize_coefs(
20
- coefs: typing.Union[
21
- FilterCoefficients, typing.Tuple[npt.NDArray, npt.NDArray], npt.NDArray
22
- ],
23
- ) -> typing.Tuple[str, typing.Tuple[npt.NDArray, ...]]:
22
+ coefs: FilterCoefficients | tuple[npt.NDArray, npt.NDArray] | npt.NDArray,
23
+ ) -> tuple[str, tuple[npt.NDArray, ...]]:
24
24
  coef_type = "ba"
25
25
  if coefs is not None:
26
26
  # scipy.signal functions called with first arg `*coefs`.
@@ -35,7 +35,7 @@ def _normalize_coefs(
35
35
 
36
36
  @consumer
37
37
  def filtergen(
38
- axis: str, coefs: typing.Optional[typing.Tuple[np.ndarray]], coef_type: str
38
+ axis: str, coefs: npt.NDArray | tuple[npt.NDArray] | None, coef_type: str
39
39
  ) -> typing.Generator[AxisArray, AxisArray, None]:
40
40
  """
41
41
  Filter data using the provided coefficients.
@@ -61,7 +61,7 @@ def filtergen(
61
61
  zi_func = {"ba": scipy.signal.lfilter_zi, "sos": scipy.signal.sosfilt_zi}[coef_type]
62
62
 
63
63
  # State variables
64
- zi: typing.Optional[npt.NDArray] = None
64
+ zi: npt.NDArray | None = None
65
65
 
66
66
  # Reset if these change.
67
67
  check_input = {"key": None, "shape": None}
@@ -105,128 +105,95 @@ def filtergen(
105
105
  msg_out = replace(msg_in, data=dat_out)
106
106
 
107
107
 
108
- class FilterSettingsBase(ez.Settings):
109
- axis: typing.Optional[str] = None
110
- fs: typing.Optional[float] = None
108
+ # Type aliases
109
+ BACoeffs = tuple[npt.NDArray, npt.NDArray]
110
+ SOSCoeffs = npt.NDArray
111
+ FilterCoefsMultiType = BACoeffs | SOSCoeffs
111
112
 
112
113
 
113
- class FilterSettings(FilterSettingsBase):
114
- # If you'd like to statically design a filter, define it in settings
115
- filt: typing.Optional[FilterCoefficients] = None
114
+ @consumer
115
+ def filter_gen_by_design(
116
+ axis: str,
117
+ coef_type: str,
118
+ design_fun: typing.Callable[[float], FilterCoefsMultiType | None],
119
+ ) -> typing.Generator[AxisArray, AxisArray, None]:
120
+ """
121
+ Filter data using a filter whose coefficients are calculated using the provided design function.
116
122
 
123
+ Args:
124
+ axis: The name of the axis to filter.
125
+ Note: The axis must be represented in the message .axes and be of type AxisArray.LinearAxis.
126
+ coef_type: "ba" or "sos"
127
+ design_fun: A callable that takes "fs" as its only argument and returns a tuple of filter coefficients.
128
+ If the design_fun returns None then the filter will act as a passthrough.
129
+ Hint: To make a design function that only requires fs, use functools.partial to set other parameters.
130
+ See butterworthfilter for an example.
117
131
 
118
- class FilterState(ez.State):
119
- axis: typing.Optional[str] = None
120
- zi: typing.Optional[np.ndarray] = None
121
- filt_designed: bool = False
122
- filt: typing.Optional[FilterCoefficients] = None
123
- filt_set: asyncio.Event = field(default_factory=asyncio.Event)
124
- samp_shape: typing.Optional[typing.Tuple[int, ...]] = None
125
- fs: typing.Optional[float] = None # Hz
132
+ Returns:
126
133
 
134
+ """
135
+ msg_out = AxisArray(np.array([]), dims=[""])
127
136
 
128
- class Filter(ez.Unit):
129
- SETTINGS = FilterSettingsBase
130
- STATE = FilterState
137
+ # State variables
138
+ # Initialize filtergen as passthrough until we receive a message that allows us to design the filter.
139
+ filter_gen = filtergen(axis, None, coef_type)
131
140
 
132
- INPUT_FILTER = ez.InputStream(FilterCoefficients)
133
- INPUT_SIGNAL = ez.InputStream(AxisArray)
134
- OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
141
+ # Reset if these change.
142
+ check_input = {"gain": None}
143
+ # No need to check parameters that don't affect the design; filter_gen should check most of its parameters.
135
144
 
136
- def design_filter(self) -> typing.Optional[typing.Tuple[np.ndarray, np.ndarray]]:
137
- raise NotImplementedError("Must implement 'design_filter' in Unit subclass!")
145
+ while True:
146
+ msg_in: AxisArray = yield msg_out
147
+ axis = axis or msg_in.dims[0]
148
+ b_reset = msg_in.axes[axis].gain != check_input["gain"]
149
+ if b_reset:
150
+ check_input["gain"] = msg_in.axes[axis].gain
151
+ coefs = design_fun(1 / msg_in.axes[axis].gain)
152
+ filter_gen = filtergen(axis, coefs, coef_type)
138
153
 
139
- # Set up filter with static initialization if specified
140
- async def initialize(self) -> None:
141
- if self.SETTINGS.axis is not None:
142
- self.STATE.axis = self.SETTINGS.axis
154
+ msg_out = filter_gen.send(msg_in)
143
155
 
144
- if isinstance(self.SETTINGS, FilterSettings):
145
- if self.SETTINGS.filt is not None:
146
- self.STATE.filt = self.SETTINGS.filt
147
- self.STATE.filt_set.set()
148
- else:
149
- self.STATE.filt_set.clear()
150
156
 
151
- if self.SETTINGS.fs is not None:
152
- try:
153
- self.update_filter()
154
- except NotImplementedError:
155
- ez.logger.debug("Using filter coefficients.")
157
+ class FilterBaseSettings(ez.Settings):
158
+ axis: str | None = None
159
+ coef_type: str = "ba"
156
160
 
157
- @ez.subscriber(INPUT_FILTER)
158
- async def redesign(self, message: FilterCoefficients):
159
- self.STATE.filt = message
160
-
161
- def update_filter(self):
162
- try:
163
- coefs = self.design_filter()
164
- self.STATE.filt = (
165
- FilterCoefficients() if coefs is None else FilterCoefficients(*coefs)
166
- )
167
- self.STATE.filt_set.set()
168
- self.STATE.filt_designed = True
169
- except NotImplementedError as e:
170
- raise e
171
- except Exception as e:
172
- ez.logger.warning(f"Error when designing filter: {e}")
173
-
174
- @ez.subscriber(INPUT_SIGNAL)
175
- @ez.publisher(OUTPUT_SIGNAL)
176
- async def apply_filter(self, msg: AxisArray) -> typing.AsyncGenerator:
177
- axis_name = msg.dims[0] if self.STATE.axis is None else self.STATE.axis
178
- axis_idx = msg.get_axis_idx(axis_name)
179
- axis = msg.get_axis(axis_name)
180
- fs = 1.0 / axis.gain
181
-
182
- if self.STATE.fs != fs and self.STATE.filt_designed is True:
183
- self.STATE.fs = fs
184
- self.update_filter()
185
-
186
- # Ensure filter is defined
187
- # TODO: Maybe have me be a passthrough filter until coefficients are received
188
- if self.STATE.filt is None:
189
- self.STATE.filt_set.clear()
190
- ez.logger.info("Awaiting filter coefficients...")
191
- await self.STATE.filt_set.wait()
192
- ez.logger.info("Filter coefficients received.")
193
-
194
- assert self.STATE.filt is not None
195
-
196
- arr_in = msg.data
197
-
198
- # If the array is one dimensional, add a temporary second dimension so that the math works out
199
- one_dimensional = False
200
- if arr_in.ndim == 1:
201
- arr_in = np.expand_dims(arr_in, axis=1)
202
- one_dimensional = True
203
-
204
- # We will perform filter with time dimension as last axis
205
- arr_in = np.moveaxis(arr_in, axis_idx, -1)
206
- samp_shape = arr_in[..., 0].shape
207
161
 
208
- # Re-calculate/reset zi if necessary
209
- if self.STATE.zi is None or samp_shape != self.STATE.samp_shape:
210
- zi: np.ndarray = scipy.signal.lfilter_zi(
211
- self.STATE.filt.b, self.STATE.filt.a
212
- )
213
- self.STATE.samp_shape = samp_shape
214
- self.STATE.zi = np.array([zi] * np.prod(self.STATE.samp_shape))
215
- self.STATE.zi = self.STATE.zi.reshape(
216
- tuple(list(self.STATE.samp_shape) + [zi.shape[0]])
217
- )
162
+ class FilterBase(GenAxisArray):
163
+ SETTINGS = FilterBaseSettings
164
+
165
+ # Backwards-compatible with `Filter` unit
166
+ INPUT_FILTER = ez.InputStream(FilterCoefsMultiType)
167
+
168
+ def design_filter(
169
+ self,
170
+ ) -> typing.Callable[[float], FilterCoefsMultiType | None]:
171
+ raise NotImplementedError("Must implement 'design_filter' in Unit subclass!")
218
172
 
219
- arr_out, self.STATE.zi = scipy.signal.lfilter(
220
- self.STATE.filt.b, self.STATE.filt.a, arr_in, zi=self.STATE.zi
173
+ def construct_generator(self):
174
+ design_fun = self.design_filter()
175
+ self.STATE.gen = filter_gen_by_design(
176
+ self.SETTINGS.axis, self.SETTINGS.coef_type, design_fun
221
177
  )
222
178
 
223
- arr_out = np.moveaxis(arr_out, -1, axis_idx)
179
+ @ez.subscriber(INPUT_FILTER)
180
+ async def redesign(self, message: FilterBaseSettings) -> None:
181
+ self.apply_settings(message)
182
+ self.construct_generator()
224
183
 
225
- # Remove temporary first dimension if necessary
226
- if one_dimensional:
227
- arr_out = np.squeeze(arr_out, axis=1)
228
184
 
229
- yield (
230
- self.OUTPUT_SIGNAL,
231
- replace(msg, data=arr_out),
232
- )
185
+ class FilterSettings(FilterBaseSettings):
186
+ # If you'd like to statically design a filter, define it in settings
187
+ coefs: FilterCoefficients | None = None
188
+ # Note: coef_type = "ba" is assumed for this class.
189
+
190
+
191
+ class Filter(FilterBase):
192
+ SETTINGS = FilterSettings
193
+
194
+ INPUT_FILTER = ez.InputStream(FilterCoefficients)
195
+
196
+ def design_filter(self) -> typing.Callable[[float], BACoeffs | None]:
197
+ if self.SETTINGS.coefs is None:
198
+ return lambda fs: None
199
+ return lambda fs: (self.SETTINGS.coefs.b, self.SETTINGS.coefs.a)
@@ -8,7 +8,8 @@ import scipy.fft as sp_fft
8
8
  from scipy.special import lambertw
9
9
  import numpy.typing as npt
10
10
  import ezmsg.core as ez
11
- from ezmsg.util.messages.axisarray import AxisArray, replace
11
+ from ezmsg.util.messages.axisarray import AxisArray
12
+ from ezmsg.util.messages.util import replace
12
13
  from ezmsg.util.generator import consumer
13
14
 
14
15
  from .base import GenAxisArray
@@ -35,7 +36,7 @@ class MinPhaseMode(OptionsEnum):
35
36
 
36
37
  @consumer
37
38
  def filterbank(
38
- kernels: typing.Union[list[npt.NDArray], tuple[npt.NDArray, ...]],
39
+ kernels: list[npt.NDArray] | tuple[npt.NDArray, ...],
39
40
  mode: FilterbankMode = FilterbankMode.CONV,
40
41
  min_phase: MinPhaseMode = MinPhaseMode.NONE,
41
42
  axis: str = "time",
@@ -62,10 +63,10 @@ def filterbank(
62
63
  with the data payload containing the absolute value of the input :obj:`AxisArray` data.
63
64
 
64
65
  """
65
- msg_out: typing.Optional[AxisArray] = None
66
+ msg_out: AxisArray | None = None
66
67
 
67
68
  # State variables
68
- template: typing.Optional[AxisArray] = None
69
+ template: AxisArray | None = None
69
70
 
70
71
  # Reset if these change
71
72
  check_input = {
@@ -257,7 +258,7 @@ def filterbank(
257
258
 
258
259
 
259
260
  class FilterbankSettings(ez.Settings):
260
- kernels: typing.Union[list[npt.NDArray], tuple[npt.NDArray, ...]]
261
+ kernels: list[npt.NDArray] | tuple[npt.NDArray, ...]
261
262
  mode: FilterbankMode = FilterbankMode.CONV
262
263
  min_phase: MinPhaseMode = MinPhaseMode.NONE
263
264
  axis: str = "time"
ezmsg/sigproc/math/abs.py CHANGED
@@ -3,7 +3,8 @@ import typing
3
3
  import numpy as np
4
4
  import ezmsg.core as ez
5
5
  from ezmsg.util.generator import consumer
6
- from ezmsg.util.messages.axisarray import AxisArray, replace
6
+ from ezmsg.util.messages.axisarray import AxisArray
7
+ from ezmsg.util.messages.util import replace
7
8
 
8
9
  from ..base import GenAxisArray
9
10
 
@@ -3,7 +3,8 @@ import typing
3
3
  import numpy as np
4
4
  import ezmsg.core as ez
5
5
  from ezmsg.util.generator import consumer
6
- from ezmsg.util.messages.axisarray import AxisArray, replace
6
+ from ezmsg.util.messages.axisarray import AxisArray
7
+ from ezmsg.util.messages.util import replace
7
8
 
8
9
  from ..base import GenAxisArray
9
10
 
@@ -3,7 +3,8 @@ import typing
3
3
  import numpy as np
4
4
  import ezmsg.core as ez
5
5
  from ezmsg.util.generator import consumer
6
- from ezmsg.util.messages.axisarray import AxisArray, replace
6
+ from ezmsg.util.messages.axisarray import AxisArray
7
+ from ezmsg.util.messages.util import replace
7
8
 
8
9
  from ..base import GenAxisArray
9
10
 
@@ -3,7 +3,8 @@ import typing
3
3
  import numpy as np
4
4
  import ezmsg.core as ez
5
5
  from ezmsg.util.generator import consumer
6
- from ezmsg.util.messages.axisarray import AxisArray, replace
6
+ from ezmsg.util.messages.axisarray import AxisArray
7
+ from ezmsg.util.messages.util import replace
7
8
 
8
9
  from ..base import GenAxisArray
9
10
 
ezmsg/sigproc/math/log.py CHANGED
@@ -3,7 +3,8 @@ import typing
3
3
  import numpy as np
4
4
  import ezmsg.core as ez
5
5
  from ezmsg.util.generator import consumer
6
- from ezmsg.util.messages.axisarray import AxisArray, replace
6
+ from ezmsg.util.messages.axisarray import AxisArray
7
+ from ezmsg.util.messages.util import replace
7
8
 
8
9
  from ..base import GenAxisArray
9
10
 
@@ -3,7 +3,8 @@ import typing
3
3
  import numpy as np
4
4
  import ezmsg.core as ez
5
5
  from ezmsg.util.generator import consumer
6
- from ezmsg.util.messages.axisarray import AxisArray, replace
6
+ from ezmsg.util.messages.axisarray import AxisArray
7
+ from ezmsg.util.messages.util import replace
7
8
 
8
9
  from ..base import GenAxisArray
9
10
 
ezmsg/sigproc/messages.py CHANGED
@@ -1,6 +1,5 @@
1
1
  import warnings
2
2
  import time
3
- import typing
4
3
 
5
4
  import numpy.typing as npt
6
5
  from ezmsg.util.messages.axisarray import AxisArray
@@ -20,7 +19,7 @@ def TSMessage(
20
19
  data: npt.NDArray,
21
20
  fs: float = 1.0,
22
21
  time_dim: int = 0,
23
- timestamp: typing.Optional[float] = None,
22
+ timestamp: float | None = None,
24
23
  ) -> AxisArray:
25
24
  dims = [f"dim_{i}" for i in range(data.ndim)]
26
25
  dims[time_dim] = "time"
ezmsg/sigproc/sampler.py CHANGED
@@ -10,8 +10,8 @@ import ezmsg.core as ez
10
10
  from ezmsg.util.messages.axisarray import (
11
11
  AxisArray,
12
12
  slice_along_axis,
13
- replace,
14
13
  )
14
+ from ezmsg.util.messages.util import replace
15
15
  from ezmsg.util.generator import consumer
16
16
 
17
17
 
@@ -20,7 +20,7 @@ class SampleTriggerMessage:
20
20
  timestamp: float = field(default_factory=time.time)
21
21
  """Time of the trigger, in seconds. The Clock depends on the input but defaults to time.time"""
22
22
 
23
- period: typing.Optional[typing.Tuple[float, float]] = None
23
+ period: tuple[float, float] | None = None
24
24
  """The period around the timestamp, in seconds"""
25
25
 
26
26
  value: typing.Any = None
@@ -39,13 +39,11 @@ class SampleMessage:
39
39
  @consumer
40
40
  def sampler(
41
41
  buffer_dur: float,
42
- axis: typing.Optional[str] = None,
43
- period: typing.Optional[typing.Tuple[float, float]] = None,
42
+ axis: str | None = None,
43
+ period: tuple[float, float] | None = None,
44
44
  value: typing.Any = None,
45
45
  estimate_alignment: bool = True,
46
- ) -> typing.Generator[
47
- typing.List[SampleMessage], typing.Union[AxisArray, SampleTriggerMessage], None
48
- ]:
46
+ ) -> typing.Generator[list[SampleMessage], AxisArray | SampleTriggerMessage, None]:
49
47
  """
50
48
  Sample data into a buffer, accept triggers, and return slices of sampled
51
49
  data around the trigger time.
@@ -74,7 +72,7 @@ def sampler(
74
72
 
75
73
  # State variables (most shared between trigger- and data-processing.
76
74
  triggers: deque[SampleTriggerMessage] = deque()
77
- buffer: typing.Optional[npt.NDArray] = None
75
+ buffer: npt.NDArray | None = None
78
76
  n_samples: int = 0
79
77
  offset: float = 0.0
80
78
 
@@ -230,8 +228,8 @@ class SamplerSettings(ez.Settings):
230
228
  """
231
229
 
232
230
  buffer_dur: float
233
- axis: typing.Optional[str] = None
234
- period: typing.Optional[typing.Tuple[float, float]] = None
231
+ axis: str | None = None
232
+ period: tuple[float, float] | None = None
235
233
  """Optional default period if unspecified in SampleTriggerMessage"""
236
234
 
237
235
  value: typing.Any = None
@@ -248,9 +246,7 @@ class SamplerSettings(ez.Settings):
248
246
 
249
247
  class SamplerState(ez.State):
250
248
  cur_settings: SamplerSettings
251
- gen: typing.Generator[
252
- typing.Union[AxisArray, SampleTriggerMessage], typing.List[SampleMessage], None
253
- ]
249
+ gen: typing.Generator[AxisArray | SampleTriggerMessage, list[SampleMessage], None]
254
250
 
255
251
 
256
252
  class Sampler(ez.Unit):
@@ -295,7 +291,7 @@ class Sampler(ez.Unit):
295
291
 
296
292
 
297
293
  class TriggerGeneratorSettings(ez.Settings):
298
- period: typing.Tuple[float, float]
294
+ period: tuple[float, float]
299
295
  """The period around the trigger event."""
300
296
 
301
297
  prewait: float = 0.5