ezmsg-sigproc 1.2.3__py3-none-any.whl → 1.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,19 +3,45 @@ import typing
3
3
  import ezmsg.core as ez
4
4
  import scipy.signal
5
5
  import numpy as np
6
-
7
- from .filter import filtergen, Filter, FilterState, FilterSettingsBase
8
-
9
6
  from ezmsg.util.messages.axisarray import AxisArray
10
7
  from ezmsg.util.generator import consumer
11
8
 
9
+ from .filter import filtergen, Filter, FilterState, FilterSettingsBase
10
+
12
11
 
13
12
  class ButterworthFilterSettings(FilterSettingsBase):
13
+ """Settings for :obj:`ButterworthFilter`."""
14
+
14
15
  order: int = 0
15
- cuton: typing.Optional[float] = None # Hz
16
- cutoff: typing.Optional[float] = None # Hz
17
16
 
18
- def filter_specs(self) -> typing.Optional[typing.Tuple[str, typing.Union[float, typing.Tuple[float, float]]]]:
17
+ cuton: typing.Optional[float] = None
18
+ """
19
+ Cuton frequency (Hz). If cutoff is not specified then this is the highpass corner, otherwise
20
+ if it is lower than cutoff then this is the beginning of the bandpass
21
+ or if it is greater than cuton then it is the end of the bandstop.
22
+ """
23
+
24
+ cutoff: typing.Optional[float] = None
25
+ """
26
+ Cutoff frequency (Hz). If cuton is not specified then this is the lowpass corner, otherwise
27
+ if it is greater than cuton then this is the end of the bandpass,
28
+ or if it is less than cuton then it is the beginning of the bandstop.
29
+ """
30
+
31
+ def filter_specs(
32
+ self,
33
+ ) -> typing.Optional[
34
+ typing.Tuple[str, typing.Union[float, typing.Tuple[float, float]]]
35
+ ]:
36
+ """
37
+ Determine the filter type given the corner frequencies.
38
+
39
+ Returns:
40
+ A tuple with the first element being a string indicating the filter type
41
+ (one of "lowpass", "highpass", "bandpass", "bandstop")
42
+ and the second element being the corner frequency or frequencies.
43
+
44
+ """
19
45
  if self.cuton is None and self.cutoff is None:
20
46
  return None
21
47
  elif self.cuton is None and self.cutoff is not None:
@@ -37,28 +63,56 @@ def butter(
37
63
  cutoff: typing.Optional[float] = None,
38
64
  coef_type: str = "ba",
39
65
  ) -> typing.Generator[AxisArray, AxisArray, None]:
66
+ """
67
+ Apply Butterworth filter to streaming data. Uses :obj:`scipy.signal.butter` to design the filter.
68
+ See :obj:`ButterworthFilterSettings.filter_specs` for an explanation of specifying different
69
+ filter types (lowpass, highpass, bandpass, bandstop) from the parameters.
70
+
71
+ Args:
72
+ axis: The name of the axis to filter.
73
+ order: Filter order.
74
+ cuton: Corner frequency of the filter in Hz.
75
+ cutoff: Corner frequency of the filter in Hz.
76
+ coef_type: "ba" or "sos"
77
+
78
+ Returns:
79
+ A primed generator object which accepts .send(axis_array) and yields filtered axis array.
80
+
81
+ """
40
82
  # IO
41
- axis_arr_in = AxisArray(np.array([]), dims=[""])
42
- axis_arr_out = AxisArray(np.array([]), dims=[""])
83
+ msg_out = AxisArray(np.array([]), dims=[""])
43
84
 
85
+ # Check parameters
44
86
  btype, cutoffs = ButterworthFilterSettings(
45
87
  order=order, cuton=cuton, cutoff=cutoff
46
88
  ).filter_specs()
47
89
 
48
- # We cannot calculate coefs yet because we do not know input sample rate
49
- coefs = None
50
- filter_gen = filtergen(axis, coefs, coef_type) # Passthrough.
90
+ # State variables
91
+ # Initialize filtergen as passthrough until we can calculate coefs.
92
+ filter_gen = filtergen(axis, None, coef_type)
93
+
94
+ # Reset if these change.
95
+ check_input = {"gain": None}
96
+ # Key not checked because filter_gen will handle resetting if .key changes.
51
97
 
52
98
  while True:
53
- axis_arr_in = yield axis_arr_out
54
- if coefs is None and order > 0:
55
- fs = 1 / axis_arr_in.axes[axis or axis_arr_in.dims[0]].gain
99
+ msg_in: AxisArray = yield msg_out
100
+ axis = axis or msg_in.dims[0]
101
+
102
+ b_reset = msg_in.axes[axis].gain != check_input["gain"]
103
+ b_reset = b_reset and order > 0 # Not passthrough
104
+ if b_reset:
105
+ check_input["gain"] = msg_in.axes[axis].gain
56
106
  coefs = scipy.signal.butter(
57
- order, Wn=cutoffs, btype=btype, fs=fs, output=coef_type
107
+ order,
108
+ Wn=cutoffs,
109
+ btype=btype,
110
+ fs=1 / msg_in.axes[axis].gain,
111
+ output=coef_type,
58
112
  )
59
113
  filter_gen = filtergen(axis, coefs, coef_type)
60
114
 
61
- axis_arr_out = filter_gen.send(axis_arr_in)
115
+ msg_out = filter_gen.send(msg_in)
62
116
 
63
117
 
64
118
  class ButterworthFilterState(FilterState):
@@ -66,15 +120,17 @@ class ButterworthFilterState(FilterState):
66
120
 
67
121
 
68
122
  class ButterworthFilter(Filter):
69
- SETTINGS: ButterworthFilterSettings
70
- STATE: ButterworthFilterState
123
+ """:obj:`Unit` for :obj:`butterworth`"""
124
+
125
+ SETTINGS = ButterworthFilterSettings
126
+ STATE = ButterworthFilterState
71
127
 
72
128
  INPUT_FILTER = ez.InputStream(ButterworthFilterSettings)
73
129
 
74
- def initialize(self) -> None:
130
+ async def initialize(self) -> None:
75
131
  self.STATE.design = self.SETTINGS
76
132
  self.STATE.filt_designed = True
77
- super().initialize()
133
+ await super().initialize()
78
134
 
79
135
  def design_filter(self) -> typing.Optional[typing.Tuple[np.ndarray, np.ndarray]]:
80
136
  specs = self.STATE.design.filter_specs()
ezmsg/sigproc/decimate.py CHANGED
@@ -1,7 +1,5 @@
1
- import ezmsg.core as ez
2
-
3
1
  import scipy.signal
4
-
2
+ import ezmsg.core as ez
5
3
  from ezmsg.util.messages.axisarray import AxisArray
6
4
 
7
5
  from .downsample import Downsample, DownsampleSettings
@@ -9,7 +7,12 @@ from .filter import Filter, FilterCoefficients, FilterSettings
9
7
 
10
8
 
11
9
  class Decimate(ez.Collection):
12
- SETTINGS: DownsampleSettings
10
+ """
11
+ A :obj:`Collection` chaining a :obj:`Filter` node configured as a lowpass Chebyshev filter
12
+ and a :obj:`Downsample` node.
13
+ """
14
+
15
+ SETTINGS = DownsampleSettings
13
16
 
14
17
  INPUT_SIGNAL = ez.InputStream(AxisArray)
15
18
  OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
@@ -1,89 +1,107 @@
1
1
  from dataclasses import replace
2
- import traceback
3
- from typing import AsyncGenerator, Optional, Generator
2
+ import typing
4
3
 
5
4
  import numpy as np
6
-
7
- from ezmsg.util.messages.axisarray import AxisArray
5
+ from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
8
6
  from ezmsg.util.generator import consumer
9
7
  import ezmsg.core as ez
10
8
 
9
+ from .base import GenAxisArray
10
+
11
11
 
12
12
  @consumer
13
13
  def downsample(
14
- axis: Optional[str] = None, factor: int = 1
15
- ) -> Generator[AxisArray, AxisArray, None]:
16
- axis_arr_in = AxisArray(np.array([]), dims=[""])
17
- axis_arr_out = AxisArray(np.array([]), dims=[""])
14
+ axis: typing.Optional[str] = None, factor: int = 1
15
+ ) -> typing.Generator[AxisArray, AxisArray, None]:
16
+ """
17
+ Construct a generator that yields a downsampled version of the data .send() to it.
18
+ Downsampled data simply comprise every `factor`th sample.
19
+ This should only be used following appropriate lowpass filtering.
20
+ If your pipeline does not already have lowpass filtering then consider
21
+ using the :obj:`Decimate` collection instead.
22
+
23
+ Args:
24
+ axis: The name of the axis along which to downsample.
25
+ factor: Downsampling factor.
26
+
27
+ Returns:
28
+ A primed generator object ready to receive a `.send(axis_array)`
29
+ and yields the downsampled data.
30
+ Note that if a send chunk does not have sufficient samples to reach the
31
+ next downsample interval then `None` is yielded.
32
+
33
+ """
34
+ msg_out = AxisArray(np.array([]), dims=[""])
35
+
36
+ if factor < 1:
37
+ raise ValueError("Downsample factor must be at least 1 (no downsampling)")
18
38
 
19
39
  # state variables
20
- s_idx = 0
40
+ s_idx: int = 0 # Index of the next msg's first sample into the virtual rotating ds_factor counter.
41
+
42
+ check_input = {"gain": None, "key": None}
21
43
 
22
44
  while True:
23
- axis_arr_in = yield axis_arr_out
45
+ msg_in: AxisArray = yield msg_out
24
46
 
25
47
  if axis is None:
26
- axis = axis_arr_in.dims[0]
27
- axis_info = axis_arr_in.get_axis(axis)
28
- axis_idx = axis_arr_in.get_axis_idx(axis)
29
-
30
- samples = np.arange(axis_arr_in.data.shape[axis_idx]) + s_idx
31
- samples = samples % factor
32
- s_idx = samples[-1] + 1
48
+ axis = msg_in.dims[0]
49
+ axis_info = msg_in.get_axis(axis)
50
+ axis_idx = msg_in.get_axis_idx(axis)
51
+
52
+ b_reset = (
53
+ msg_in.axes[axis].gain != check_input["gain"]
54
+ or msg_in.key != check_input["key"]
55
+ )
56
+ if b_reset:
57
+ check_input["gain"] = axis_info.gain
58
+ check_input["key"] = msg_in.key
59
+ # Reset state variables
60
+ s_idx = 0
61
+
62
+ n_samples = msg_in.data.shape[axis_idx]
63
+ samples = np.arange(s_idx, s_idx + n_samples) % factor
64
+ if n_samples > 0:
65
+ # Update state for next iteration.
66
+ s_idx = samples[-1] + 1
33
67
 
34
68
  pub_samples = np.where(samples == 0)[0]
35
69
  if len(pub_samples) > 0:
36
- new_axes = {ax_name: axis_arr_in.get_axis(ax_name) for ax_name in axis_arr_in.dims}
37
- new_offset = axis_info.offset + (axis_info.gain * pub_samples[0].item())
38
- new_gain = axis_info.gain * factor
39
- new_axes[axis] = replace(axis_info, gain=new_gain, offset=new_offset)
40
- down_data = np.take(axis_arr_in.data, pub_samples, axis=axis_idx)
41
- axis_arr_out = replace(axis_arr_in, data=down_data, dims=axis_arr_in.dims, axes=new_axes)
70
+ n_step = pub_samples[0].item()
71
+ data_slice = pub_samples
42
72
  else:
43
- axis_arr_out = None
73
+ n_step = 0
74
+ data_slice = slice(None, 0, None)
75
+ msg_out = replace(
76
+ msg_in,
77
+ data=slice_along_axis(msg_in.data, data_slice, axis=axis_idx),
78
+ axes={
79
+ **msg_in.axes,
80
+ axis: replace(
81
+ axis_info,
82
+ gain=axis_info.gain * factor,
83
+ offset=axis_info.offset + axis_info.gain * n_step,
84
+ ),
85
+ },
86
+ )
44
87
 
45
88
 
46
89
  class DownsampleSettings(ez.Settings):
47
- axis: Optional[str] = None
48
- factor: int = 1
49
-
90
+ """
91
+ Settings for :obj:`Downsample` node.
92
+ See :obj:`downsample` documentation for a description of the parameters.
93
+ """
50
94
 
51
- class DownsampleState(ez.State):
52
- cur_settings: DownsampleSettings
53
- gen: Generator
95
+ axis: typing.Optional[str] = None
96
+ factor: int = 1
54
97
 
55
98
 
56
- class Downsample(ez.Unit):
57
- SETTINGS: DownsampleSettings
58
- STATE: DownsampleState
99
+ class Downsample(GenAxisArray):
100
+ """:obj:`Unit` for :obj:`bandpower`."""
59
101
 
60
- INPUT_SETTINGS = ez.InputStream(DownsampleSettings)
61
- INPUT_SIGNAL = ez.InputStream(AxisArray)
62
- OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
102
+ SETTINGS = DownsampleSettings
63
103
 
64
104
  def construct_generator(self):
65
- self.STATE.gen = downsample(axis=self.STATE.cur_settings.axis, factor=self.STATE.cur_settings.factor)
66
-
67
- def initialize(self) -> None:
68
- self.STATE.cur_settings = self.SETTINGS
69
- self.construct_generator()
70
-
71
- @ez.subscriber(INPUT_SETTINGS)
72
- async def on_settings(self, msg: DownsampleSettings) -> None:
73
- self.STATE.cur_settings = msg
74
- self.construct_generator()
75
-
76
- @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
77
- @ez.publisher(OUTPUT_SIGNAL)
78
- async def on_signal(self, msg: AxisArray) -> AsyncGenerator:
79
- if self.STATE.cur_settings.factor < 1:
80
- raise ValueError("Downsample factor must be at least 1 (no downsampling)")
81
-
82
- try:
83
- out_msg = self.STATE.gen.send(msg)
84
- if out_msg is not None:
85
- yield self.OUTPUT_SIGNAL, out_msg
86
- except (StopIteration, GeneratorExit):
87
- ez.logger.debug(f"Downsample closed in {self.address}")
88
- except Exception:
89
- ez.logger.info(traceback.format_exc())
105
+ self.STATE.gen = downsample(
106
+ axis=self.SETTINGS.axis, factor=self.SETTINGS.factor
107
+ )
@@ -1,19 +1,20 @@
1
1
  import asyncio
2
2
  from dataclasses import replace
3
+ import typing
3
4
 
4
5
  import ezmsg.core as ez
5
6
  from ezmsg.util.messages.axisarray import AxisArray
6
-
7
7
  import numpy as np
8
8
 
9
9
  from .window import Window, WindowSettings
10
10
 
11
- from typing import AsyncGenerator, Optional
12
-
13
11
 
14
12
  class EWMSettings(ez.Settings):
15
- axis: Optional[str] = None
16
- zero_offset: bool = True # If true, we assume zero DC offset
13
+ axis: typing.Optional[str] = None
14
+ """Name of the axis to accumulate."""
15
+
16
+ zero_offset: bool = True
17
+ """If true, we assume zero DC offset for input data."""
17
18
 
18
19
 
19
20
  class EWMState(ez.State):
@@ -28,14 +29,14 @@ class EWM(ez.Unit):
28
29
  References https://stackoverflow.com/a/42926270
29
30
  """
30
31
 
31
- SETTINGS: EWMSettings
32
- STATE: EWMState
32
+ SETTINGS = EWMSettings
33
+ STATE = EWMState
33
34
 
34
35
  INPUT_SIGNAL = ez.InputStream(AxisArray)
35
36
  INPUT_BUFFER = ez.InputStream(AxisArray)
36
37
  OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
37
38
 
38
- def initialize(self) -> None:
39
+ async def initialize(self) -> None:
39
40
  self.STATE.signal_queue = asyncio.Queue()
40
41
  self.STATE.buffer_queue = asyncio.Queue()
41
42
 
@@ -48,7 +49,7 @@ class EWM(ez.Unit):
48
49
  self.STATE.buffer_queue.put_nowait(message)
49
50
 
50
51
  @ez.publisher(OUTPUT_SIGNAL)
51
- async def sync_output(self) -> AsyncGenerator:
52
+ async def sync_output(self) -> typing.AsyncGenerator:
52
53
  while True:
53
54
  signal = await self.STATE.signal_queue.get()
54
55
  buffer = await self.STATE.buffer_queue.get() # includes signal
@@ -96,13 +97,26 @@ class EWM(ez.Unit):
96
97
 
97
98
 
98
99
  class EWMFilterSettings(ez.Settings):
99
- history_dur: float # previous data to accumulate for standardization
100
- axis: Optional[str] = None
101
- zero_offset: bool = True # If true, we assume zero DC offset for input data
100
+ history_dur: float
101
+ """Previous data to accumulate for standardization."""
102
+
103
+ axis: typing.Optional[str] = None
104
+ """Name of the axis to accumulate."""
105
+
106
+ zero_offset: bool = True
107
+ """If true, we assume zero DC offset for input data."""
102
108
 
103
109
 
104
110
  class EWMFilter(ez.Collection):
105
- SETTINGS: EWMFilterSettings
111
+ """
112
+ A :obj:`Collection` that splits the input into a branch that
113
+ leads to :obj:`Window` which then feeds into :obj:`EWM` 's INPUT_BUFFER
114
+ and another branch that feeds directly into :obj:`EWM` 's INPUT_SIGNAL.
115
+
116
+ Consider :obj:`scaler` for a more efficient alternative.
117
+ """
118
+
119
+ SETTINGS = EWMFilterSettings
106
120
 
107
121
  INPUT_SIGNAL = ez.InputStream(AxisArray)
108
122
  OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
@@ -113,7 +127,7 @@ class EWMFilter(ez.Collection):
113
127
  def configure(self) -> None:
114
128
  self.EWM.apply_settings(
115
129
  EWMSettings(
116
- axis=self.SETTINGS.axis,
130
+ axis=self.SETTINGS.axis,
117
131
  zero_offset=self.SETTINGS.zero_offset,
118
132
  )
119
133
  )
ezmsg/sigproc/filter.py CHANGED
@@ -1,25 +1,26 @@
1
1
  import asyncio
2
- import typing
3
-
4
2
  from dataclasses import dataclass, replace, field
3
+ import typing
5
4
 
6
5
  import ezmsg.core as ez
7
- import scipy.signal
8
-
6
+ from ezmsg.util.messages.axisarray import AxisArray
7
+ from ezmsg.util.generator import consumer
9
8
  import numpy as np
10
9
  import numpy.typing as npt
10
+ import scipy.signal
11
11
 
12
- from ezmsg.util.messages.axisarray import AxisArray
13
- from ezmsg.util.generator import consumer
14
12
 
15
13
  @dataclass
16
14
  class FilterCoefficients:
17
15
  b: np.ndarray = field(default_factory=lambda: np.array([1.0, 0.0]))
18
16
  a: np.ndarray = field(default_factory=lambda: np.array([1.0, 0.0]))
19
17
 
18
+
20
19
  def _normalize_coefs(
21
- coefs: typing.Union[FilterCoefficients, typing.Tuple[npt.NDArray, npt.NDArray],npt.NDArray]
22
- ) -> typing.Tuple[str, typing.Tuple[npt.NDArray,...]]:
20
+ coefs: typing.Union[
21
+ FilterCoefficients, typing.Tuple[npt.NDArray, npt.NDArray], npt.NDArray
22
+ ],
23
+ ) -> typing.Tuple[str, typing.Tuple[npt.NDArray, ...]]:
23
24
  coef_type = "ba"
24
25
  if coefs is not None:
25
26
  # scipy.signal functions called with first arg `*coefs`.
@@ -31,57 +32,73 @@ def _normalize_coefs(
31
32
  coefs = (FilterCoefficients.b, FilterCoefficients.a)
32
33
  return coef_type, coefs
33
34
 
35
+
34
36
  @consumer
35
37
  def filtergen(
36
38
  axis: str, coefs: typing.Optional[typing.Tuple[np.ndarray]], coef_type: str
37
39
  ) -> typing.Generator[AxisArray, AxisArray, None]:
40
+ """
41
+ Construct a generic filter generator function.
42
+
43
+ Args:
44
+ axis: The name of the axis to operate on.
45
+ coefs: The pre-calculated filter coefficients.
46
+ coef_type: The type of filter coefficients. One of "ba" or "sos".
47
+
48
+ Returns:
49
+ A generator that expects .send(axis_array) and yields the filtered :obj:`AxisArray`.
50
+ """
38
51
  # Massage inputs
39
52
  if coefs is not None and not isinstance(coefs, tuple):
40
53
  # scipy.signal functions called with first arg `*coefs`, but sos coefs are a single ndarray.
41
54
  coefs = (coefs,)
42
55
 
43
56
  # Init IO
44
- axis_arr_in = AxisArray(np.array([]), dims=[""])
45
- axis_arr_out = AxisArray(np.array([]), dims=[""])
57
+ msg_out = AxisArray(np.array([]), dims=[""])
46
58
 
47
59
  filt_func = {"ba": scipy.signal.lfilter, "sos": scipy.signal.sosfilt}[coef_type]
48
60
  zi_func = {"ba": scipy.signal.lfilter_zi, "sos": scipy.signal.sosfilt_zi}[coef_type]
49
61
 
50
62
  # State variables
51
- axis_idx = None
52
- zi = None
53
- expected_shape = None
63
+ zi: typing.Optional[npt.NDArray] = None
64
+
65
+ # Reset if these change.
66
+ check_input = {"key": None, "shape": None}
67
+ # fs changing will be handled by caller that creates coefficients.
54
68
 
55
69
  while True:
56
- axis_arr_in = yield axis_arr_out
70
+ msg_in: AxisArray = yield msg_out
57
71
 
58
72
  if coefs is None:
59
73
  # passthrough if we do not have a filter design.
60
- axis_arr_out = axis_arr_in
74
+ msg_out = msg_in
61
75
  continue
62
76
 
63
- if axis_idx is None:
64
- axis_name = axis_arr_in.dims[0] if axis is None else axis
65
- axis_idx = axis_arr_in.get_axis_idx(axis_name)
66
-
67
- dat_in = axis_arr_in.data
77
+ axis = msg_in.dims[0] if axis is None else axis
78
+ axis_idx = msg_in.get_axis_idx(axis)
68
79
 
69
80
  # Re-calculate/reset zi if necessary
70
- samp_shape = dat_in.shape[:axis_idx] + dat_in.shape[axis_idx + 1 :]
71
- if zi is None or samp_shape != expected_shape:
72
- expected_shape = samp_shape
73
- n_tail = dat_in.ndim - axis_idx - 1
81
+ samp_shape = msg_in.data.shape[:axis_idx] + msg_in.data.shape[axis_idx + 1 :]
82
+ b_reset = samp_shape != check_input["shape"]
83
+ b_reset = b_reset or msg_in.key != check_input["key"]
84
+ if b_reset:
85
+ check_input["shape"] = samp_shape
86
+ check_input["key"] = msg_in.key
87
+
88
+ n_tail = msg_in.data.ndim - axis_idx - 1
74
89
  zi = zi_func(*coefs)
75
90
  zi_expand = (None,) * axis_idx + (slice(None),) + (None,) * n_tail
76
- n_tile = dat_in.shape[:axis_idx] + (1,) + dat_in.shape[axis_idx + 1 :]
91
+ n_tile = (
92
+ msg_in.data.shape[:axis_idx] + (1,) + msg_in.data.shape[axis_idx + 1 :]
93
+ )
77
94
  if coef_type == "sos":
78
95
  # sos zi must keep its leading dimension (`order / 2` for low|high; `order` for bpass|bstop)
79
96
  zi_expand = (slice(None),) + zi_expand
80
97
  n_tile = (1,) + n_tile
81
98
  zi = np.tile(zi[zi_expand], n_tile)
82
99
 
83
- dat_out, zi = filt_func(*coefs, dat_in, axis=axis_idx, zi=zi)
84
- axis_arr_out = replace(axis_arr_in, data=dat_out)
100
+ dat_out, zi = filt_func(*coefs, msg_in.data, axis=axis_idx, zi=zi)
101
+ msg_out = replace(msg_in, data=dat_out)
85
102
 
86
103
 
87
104
  class FilterSettingsBase(ez.Settings):
@@ -105,8 +122,8 @@ class FilterState(ez.State):
105
122
 
106
123
 
107
124
  class Filter(ez.Unit):
108
- SETTINGS: FilterSettingsBase
109
- STATE: FilterState
125
+ SETTINGS = FilterSettingsBase
126
+ STATE = FilterState
110
127
 
111
128
  INPUT_FILTER = ez.InputStream(FilterCoefficients)
112
129
  INPUT_SIGNAL = ez.InputStream(AxisArray)
@@ -116,7 +133,7 @@ class Filter(ez.Unit):
116
133
  raise NotImplementedError("Must implement 'design_filter' in Unit subclass!")
117
134
 
118
135
  # Set up filter with static initialization if specified
119
- def initialize(self) -> None:
136
+ async def initialize(self) -> None:
120
137
  if self.SETTINGS.axis is not None:
121
138
  self.STATE.axis = self.SETTINGS.axis
122
139
 
@@ -205,4 +222,7 @@ class Filter(ez.Unit):
205
222
  if one_dimensional:
206
223
  arr_out = np.squeeze(arr_out, axis=1)
207
224
 
208
- yield self.OUTPUT_SIGNAL, replace(msg, data=arr_out),
225
+ yield (
226
+ self.OUTPUT_SIGNAL,
227
+ replace(msg, data=arr_out),
228
+ )