ezmsg-sigproc 1.2.2__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38):
  1. ezmsg/sigproc/__init__.py +1 -1
  2. ezmsg/sigproc/__version__.py +16 -1
  3. ezmsg/sigproc/activation.py +75 -0
  4. ezmsg/sigproc/affinetransform.py +234 -0
  5. ezmsg/sigproc/aggregate.py +158 -0
  6. ezmsg/sigproc/bandpower.py +74 -0
  7. ezmsg/sigproc/base.py +38 -0
  8. ezmsg/sigproc/butterworthfilter.py +102 -11
  9. ezmsg/sigproc/decimate.py +7 -4
  10. ezmsg/sigproc/downsample.py +95 -51
  11. ezmsg/sigproc/ewmfilter.py +38 -16
  12. ezmsg/sigproc/filter.py +108 -20
  13. ezmsg/sigproc/filterbank.py +278 -0
  14. ezmsg/sigproc/math/__init__.py +0 -0
  15. ezmsg/sigproc/math/abs.py +28 -0
  16. ezmsg/sigproc/math/clip.py +30 -0
  17. ezmsg/sigproc/math/difference.py +60 -0
  18. ezmsg/sigproc/math/invert.py +29 -0
  19. ezmsg/sigproc/math/log.py +32 -0
  20. ezmsg/sigproc/math/scale.py +31 -0
  21. ezmsg/sigproc/messages.py +2 -3
  22. ezmsg/sigproc/sampler.py +259 -224
  23. ezmsg/sigproc/scaler.py +173 -0
  24. ezmsg/sigproc/signalinjector.py +64 -0
  25. ezmsg/sigproc/slicer.py +133 -0
  26. ezmsg/sigproc/spectral.py +6 -132
  27. ezmsg/sigproc/spectrogram.py +86 -0
  28. ezmsg/sigproc/spectrum.py +259 -0
  29. ezmsg/sigproc/synth.py +299 -105
  30. ezmsg/sigproc/wavelets.py +167 -0
  31. ezmsg/sigproc/window.py +254 -116
  32. ezmsg_sigproc-1.3.1.dist-info/METADATA +59 -0
  33. ezmsg_sigproc-1.3.1.dist-info/RECORD +35 -0
  34. {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-1.3.1.dist-info}/WHEEL +1 -2
  35. ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
  36. ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
  37. ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
  38. {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-1.3.1.dist-info/licenses}/LICENSE.txt +0 -0
ezmsg/sigproc/butterworthfilter.py CHANGED
@@ -1,18 +1,47 @@
1
+ import typing
2
+
1
3
  import ezmsg.core as ez
2
4
  import scipy.signal
3
5
  import numpy as np
6
+ from ezmsg.util.messages.axisarray import AxisArray
7
+ from ezmsg.util.generator import consumer
4
8
 
5
- from .filter import Filter, FilterState, FilterSettingsBase
6
-
7
- from typing import Optional, Tuple, Union
9
+ from .filter import filtergen, Filter, FilterState, FilterSettingsBase
8
10
 
9
11
 
10
12
  class ButterworthFilterSettings(FilterSettingsBase):
13
+ """Settings for :obj:`ButterworthFilter`."""
14
+
11
15
  order: int = 0
12
- cuton: Optional[float] = None # Hz
13
- cutoff: Optional[float] = None # Hz
14
16
 
15
- def filter_specs(self) -> Optional[Tuple[str, Union[float, Tuple[float, float]]]]:
17
+ cuton: typing.Optional[float] = None
18
+ """
19
+ Cuton frequency (Hz). If cutoff is not specified then this is the highpass corner, otherwise
20
+ if it is lower than cutoff then this is the beginning of the bandpass
21
+ or if it is greater than cuton then it is the end of the bandstop.
22
+ """
23
+
24
+ cutoff: typing.Optional[float] = None
25
+ """
26
+ Cutoff frequency (Hz). If cuton is not specified then this is the lowpass corner, otherwise
27
+ if it is greater than cuton then this is the end of the bandpass,
28
+ or if it is less than cuton then it is the beginning of the bandstop.
29
+ """
30
+
31
+ def filter_specs(
32
+ self,
33
+ ) -> typing.Optional[
34
+ typing.Tuple[str, typing.Union[float, typing.Tuple[float, float]]]
35
+ ]:
36
+ """
37
+ Determine the filter type given the corner frequencies.
38
+
39
+ Returns:
40
+ A tuple with the first element being a string indicating the filter type
41
+ (one of "lowpass", "highpass", "bandpass", "bandstop")
42
+ and the second element being the corner frequency or frequencies.
43
+
44
+ """
16
45
  if self.cuton is None and self.cutoff is None:
17
46
  return None
18
47
  elif self.cuton is None and self.cutoff is not None:
@@ -26,22 +55,84 @@ class ButterworthFilterSettings(FilterSettingsBase):
26
55
  return "bandstop", (self.cutoff, self.cuton)
27
56
 
28
57
 
58
+ @consumer
59
+ def butter(
60
+ axis: typing.Optional[str],
61
+ order: int = 0,
62
+ cuton: typing.Optional[float] = None,
63
+ cutoff: typing.Optional[float] = None,
64
+ coef_type: str = "ba",
65
+ ) -> typing.Generator[AxisArray, AxisArray, None]:
66
+ """
67
+ Apply Butterworth filter to streaming data. Uses :obj:`scipy.signal.butter` to design the filter.
68
+ See :obj:`ButterworthFilterSettings.filter_specs` for an explanation of specifying different
69
+ filter types (lowpass, highpass, bandpass, bandstop) from the parameters.
70
+
71
+ Args:
72
+ axis: The name of the axis to filter.
73
+ order: Filter order.
74
+ cuton: Corner frequency of the filter in Hz.
75
+ cutoff: Corner frequency of the filter in Hz.
76
+ coef_type: "ba" or "sos"
77
+
78
+ Returns:
79
+ A primed generator object which accepts .send(axis_array) and yields filtered axis array.
80
+
81
+ """
82
+ # IO
83
+ msg_out = AxisArray(np.array([]), dims=[""])
84
+
85
+ # Check parameters
86
+ btype, cutoffs = ButterworthFilterSettings(
87
+ order=order, cuton=cuton, cutoff=cutoff
88
+ ).filter_specs()
89
+
90
+ # State variables
91
+ # Initialize filtergen as passthrough until we can calculate coefs.
92
+ filter_gen = filtergen(axis, None, coef_type)
93
+
94
+ # Reset if these change.
95
+ check_input = {"gain": None}
96
+ # Key not checked because filter_gen will handle resetting if .key changes.
97
+
98
+ while True:
99
+ msg_in: AxisArray = yield msg_out
100
+ axis = axis or msg_in.dims[0]
101
+
102
+ b_reset = msg_in.axes[axis].gain != check_input["gain"]
103
+ b_reset = b_reset and order > 0 # Not passthrough
104
+ if b_reset:
105
+ check_input["gain"] = msg_in.axes[axis].gain
106
+ coefs = scipy.signal.butter(
107
+ order,
108
+ Wn=cutoffs,
109
+ btype=btype,
110
+ fs=1 / msg_in.axes[axis].gain,
111
+ output=coef_type,
112
+ )
113
+ filter_gen = filtergen(axis, coefs, coef_type)
114
+
115
+ msg_out = filter_gen.send(msg_in)
116
+
117
+
29
118
  class ButterworthFilterState(FilterState):
30
119
  design: ButterworthFilterSettings
31
120
 
32
121
 
33
122
  class ButterworthFilter(Filter):
34
- SETTINGS: ButterworthFilterSettings
35
- STATE: ButterworthFilterState
123
+ """:obj:`Unit` for :obj:`butterworth`"""
124
+
125
+ SETTINGS = ButterworthFilterSettings
126
+ STATE = ButterworthFilterState
36
127
 
37
128
  INPUT_FILTER = ez.InputStream(ButterworthFilterSettings)
38
129
 
39
- def initialize(self) -> None:
130
+ async def initialize(self) -> None:
40
131
  self.STATE.design = self.SETTINGS
41
132
  self.STATE.filt_designed = True
42
- super().initialize()
133
+ await super().initialize()
43
134
 
44
- def design_filter(self) -> Optional[Tuple[np.ndarray, np.ndarray]]:
135
+ def design_filter(self) -> typing.Optional[typing.Tuple[np.ndarray, np.ndarray]]:
45
136
  specs = self.STATE.design.filter_specs()
46
137
  if self.STATE.design.order > 0 and specs is not None:
47
138
  btype, cut = specs
ezmsg/sigproc/decimate.py CHANGED
@@ -1,7 +1,5 @@
1
- import ezmsg.core as ez
2
-
3
1
  import scipy.signal
4
-
2
+ import ezmsg.core as ez
5
3
  from ezmsg.util.messages.axisarray import AxisArray
6
4
 
7
5
  from .downsample import Downsample, DownsampleSettings
@@ -9,7 +7,12 @@ from .filter import Filter, FilterCoefficients, FilterSettings
9
7
 
10
8
 
11
9
  class Decimate(ez.Collection):
12
- SETTINGS: DownsampleSettings
10
+ """
11
+ A :obj:`Collection` chaining a :obj:`Filter` node configured as a lowpass Chebyshev filter
12
+ and a :obj:`Downsample` node.
13
+ """
14
+
15
+ SETTINGS = DownsampleSettings
13
16
 
14
17
  INPUT_SIGNAL = ez.InputStream(AxisArray)
15
18
  OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
ezmsg/sigproc/downsample.py CHANGED
@@ -1,63 +1,107 @@
1
1
  from dataclasses import replace
2
+ import typing
2
3
 
3
- from ezmsg.util.messages.axisarray import AxisArray
4
-
5
- import ezmsg.core as ez
6
4
  import numpy as np
5
+ from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
6
+ from ezmsg.util.generator import consumer
7
+ import ezmsg.core as ez
7
8
 
8
- from typing import (
9
- AsyncGenerator,
10
- Optional,
11
- )
12
-
13
-
14
- class DownsampleSettings(ez.Settings):
15
- axis: Optional[str] = None
16
- factor: int = 1
17
-
18
-
19
- class DownsampleState(ez.State):
20
- cur_settings: DownsampleSettings
21
- s_idx: int = 0
22
-
9
+ from .base import GenAxisArray
10
+
11
+
12
+ @consumer
13
+ def downsample(
14
+ axis: typing.Optional[str] = None, factor: int = 1
15
+ ) -> typing.Generator[AxisArray, AxisArray, None]:
16
+ """
17
+ Construct a generator that yields a downsampled version of the data .send() to it.
18
+ Downsampled data simply comprise every `factor`th sample.
19
+ This should only be used following appropriate lowpass filtering.
20
+ If your pipeline does not already have lowpass filtering then consider
21
+ using the :obj:`Decimate` collection instead.
22
+
23
+ Args:
24
+ axis: The name of the axis along which to downsample.
25
+ factor: Downsampling factor.
26
+
27
+ Returns:
28
+ A primed generator object ready to receive a `.send(axis_array)`
29
+ and yields the downsampled data.
30
+ Note that if a send chunk does not have sufficient samples to reach the
31
+ next downsample interval then `None` is yielded.
32
+
33
+ """
34
+ msg_out = AxisArray(np.array([]), dims=[""])
35
+
36
+ if factor < 1:
37
+ raise ValueError("Downsample factor must be at least 1 (no downsampling)")
38
+
39
+ # state variables
40
+ s_idx: int = 0 # Index of the next msg's first sample into the virtual rotating ds_factor counter.
41
+
42
+ check_input = {"gain": None, "key": None}
43
+
44
+ while True:
45
+ msg_in: AxisArray = yield msg_out
46
+
47
+ if axis is None:
48
+ axis = msg_in.dims[0]
49
+ axis_info = msg_in.get_axis(axis)
50
+ axis_idx = msg_in.get_axis_idx(axis)
51
+
52
+ b_reset = (
53
+ msg_in.axes[axis].gain != check_input["gain"]
54
+ or msg_in.key != check_input["key"]
55
+ )
56
+ if b_reset:
57
+ check_input["gain"] = axis_info.gain
58
+ check_input["key"] = msg_in.key
59
+ # Reset state variables
60
+ s_idx = 0
61
+
62
+ n_samples = msg_in.data.shape[axis_idx]
63
+ samples = np.arange(s_idx, s_idx + n_samples) % factor
64
+ if n_samples > 0:
65
+ # Update state for next iteration.
66
+ s_idx = samples[-1] + 1
23
67
 
24
- class Downsample(ez.Unit):
25
- SETTINGS: DownsampleSettings
26
- STATE: DownsampleState
68
+ pub_samples = np.where(samples == 0)[0]
69
+ if len(pub_samples) > 0:
70
+ n_step = pub_samples[0].item()
71
+ data_slice = pub_samples
72
+ else:
73
+ n_step = 0
74
+ data_slice = slice(None, 0, None)
75
+ msg_out = replace(
76
+ msg_in,
77
+ data=slice_along_axis(msg_in.data, data_slice, axis=axis_idx),
78
+ axes={
79
+ **msg_in.axes,
80
+ axis: replace(
81
+ axis_info,
82
+ gain=axis_info.gain * factor,
83
+ offset=axis_info.offset + axis_info.gain * n_step,
84
+ ),
85
+ },
86
+ )
27
87
 
28
- INPUT_SETTINGS = ez.InputStream(DownsampleSettings)
29
- INPUT_SIGNAL = ez.InputStream(AxisArray)
30
- OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
31
88
 
32
- def initialize(self) -> None:
33
- self.STATE.cur_settings = self.SETTINGS
89
+ class DownsampleSettings(ez.Settings):
90
+ """
91
+ Settings for :obj:`Downsample` node.
92
+ See :obj:`downsample` documentation for a description of the parameters.
93
+ """
34
94
 
35
- @ez.subscriber(INPUT_SETTINGS)
36
- async def on_settings(self, msg: DownsampleSettings) -> None:
37
- self.STATE.cur_settings = msg
95
+ axis: typing.Optional[str] = None
96
+ factor: int = 1
38
97
 
39
- @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
40
- @ez.publisher(OUTPUT_SIGNAL)
41
- async def on_signal(self, msg: AxisArray) -> AsyncGenerator:
42
- if self.STATE.cur_settings.factor < 1:
43
- raise ValueError("Downsample factor must be at least 1 (no downsampling)")
44
98
 
45
- axis_name = self.STATE.cur_settings.axis
46
- if axis_name is None:
47
- axis_name = msg.dims[0]
48
- axis = msg.get_axis(axis_name)
49
- axis_idx = msg.get_axis_idx(axis_name)
99
+ class Downsample(GenAxisArray):
100
+ """:obj:`Unit` for :obj:`bandpower`."""
50
101
 
51
- samples = np.arange(msg.data.shape[axis_idx]) + self.STATE.s_idx
52
- samples = samples % self.STATE.cur_settings.factor
53
- self.STATE.s_idx = samples[-1] + 1
102
+ SETTINGS = DownsampleSettings
54
103
 
55
- pub_samples = np.where(samples == 0)[0]
56
- if len(pub_samples) != 0:
57
- new_axes = {ax_name: msg.get_axis(ax_name) for ax_name in msg.dims}
58
- new_offset = axis.offset + (axis.gain * pub_samples[0].item())
59
- new_gain = axis.gain * self.STATE.cur_settings.factor
60
- new_axes[axis_name] = replace(axis, gain=new_gain, offset=new_offset)
61
- down_data = np.take(msg.data, pub_samples, axis_idx)
62
- out_msg = replace(msg, data=down_data, dims=msg.dims, axes=new_axes)
63
- yield self.OUTPUT_SIGNAL, out_msg
104
+ def construct_generator(self):
105
+ self.STATE.gen = downsample(
106
+ axis=self.SETTINGS.axis, factor=self.SETTINGS.factor
107
+ )
ezmsg/sigproc/ewmfilter.py CHANGED
@@ -1,19 +1,20 @@
1
1
  import asyncio
2
2
  from dataclasses import replace
3
+ import typing
3
4
 
4
5
  import ezmsg.core as ez
5
6
  from ezmsg.util.messages.axisarray import AxisArray
6
-
7
7
  import numpy as np
8
8
 
9
9
  from .window import Window, WindowSettings
10
10
 
11
- from typing import AsyncGenerator, Optional
12
-
13
11
 
14
12
  class EWMSettings(ez.Settings):
15
- axis: Optional[str] = None
16
- zero_offset: bool = True # If true, we assume zero DC offset
13
+ axis: typing.Optional[str] = None
14
+ """Name of the axis to accumulate."""
15
+
16
+ zero_offset: bool = True
17
+ """If true, we assume zero DC offset for input data."""
17
18
 
18
19
 
19
20
  class EWMState(ez.State):
@@ -28,14 +29,14 @@ class EWM(ez.Unit):
28
29
  References https://stackoverflow.com/a/42926270
29
30
  """
30
31
 
31
- SETTINGS: EWMSettings
32
- STATE: EWMState
32
+ SETTINGS = EWMSettings
33
+ STATE = EWMState
33
34
 
34
35
  INPUT_SIGNAL = ez.InputStream(AxisArray)
35
36
  INPUT_BUFFER = ez.InputStream(AxisArray)
36
37
  OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
37
38
 
38
- def initialize(self) -> None:
39
+ async def initialize(self) -> None:
39
40
  self.STATE.signal_queue = asyncio.Queue()
40
41
  self.STATE.buffer_queue = asyncio.Queue()
41
42
 
@@ -48,7 +49,7 @@ class EWM(ez.Unit):
48
49
  self.STATE.buffer_queue.put_nowait(message)
49
50
 
50
51
  @ez.publisher(OUTPUT_SIGNAL)
51
- async def sync_output(self) -> AsyncGenerator:
52
+ async def sync_output(self) -> typing.AsyncGenerator:
52
53
  while True:
53
54
  signal = await self.STATE.signal_queue.get()
54
55
  buffer = await self.STATE.buffer_queue.get() # includes signal
@@ -73,9 +74,12 @@ class EWM(ez.Unit):
73
74
  buffer_data = buffer.data
74
75
  buffer_data = np.moveaxis(buffer_data, axis_idx, 0)
75
76
 
77
+ while scale_arr.ndim < buffer_data.ndim:
78
+ scale_arr = scale_arr[..., None]
79
+
76
80
  def ewma(data: np.ndarray) -> np.ndarray:
77
- mult = scale_arr[:, np.newaxis] * data * pw0
78
- out = scale_arr[::-1, np.newaxis] * mult.cumsum(axis=0)
81
+ mult = scale_arr * data * pw0
82
+ out = scale_arr[::-1] * mult.cumsum(axis=0)
79
83
 
80
84
  if not self.SETTINGS.zero_offset:
81
85
  out = (data[0, :, np.newaxis] * pows[1:]).T + out
@@ -93,13 +97,26 @@ class EWM(ez.Unit):
93
97
 
94
98
 
95
99
  class EWMFilterSettings(ez.Settings):
96
- history_dur: float # previous data to accumulate for standardization
97
- axis: Optional[str] = None
98
- zero_offset: bool = True # If true, we assume zero DC offset for input data
100
+ history_dur: float
101
+ """Previous data to accumulate for standardization."""
102
+
103
+ axis: typing.Optional[str] = None
104
+ """Name of the axis to accumulate."""
105
+
106
+ zero_offset: bool = True
107
+ """If true, we assume zero DC offset for input data."""
99
108
 
100
109
 
101
110
  class EWMFilter(ez.Collection):
102
- SETTINGS: EWMFilterSettings
111
+ """
112
+ A :obj:`Collection` that splits the input into a branch that
113
+ leads to :obj:`Window` which then feeds into :obj:`EWM` 's INPUT_BUFFER
114
+ and another branch that feeds directly into :obj:`EWM` 's INPUT_SIGNAL.
115
+
116
+ Consider :obj:`scaler` for a more efficient alternative.
117
+ """
118
+
119
+ SETTINGS = EWMFilterSettings
103
120
 
104
121
  INPUT_SIGNAL = ez.InputStream(AxisArray)
105
122
  OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
@@ -108,7 +125,12 @@ class EWMFilter(ez.Collection):
108
125
  EWM = EWM()
109
126
 
110
127
  def configure(self) -> None:
111
- self.EWM.apply_settings(EWMSettings(axis=self.SETTINGS.axis, zero_offset=True))
128
+ self.EWM.apply_settings(
129
+ EWMSettings(
130
+ axis=self.SETTINGS.axis,
131
+ zero_offset=self.SETTINGS.zero_offset,
132
+ )
133
+ )
112
134
 
113
135
  self.WINDOW.apply_settings(
114
136
  WindowSettings(
ezmsg/sigproc/filter.py CHANGED
@@ -1,13 +1,13 @@
1
+ import asyncio
1
2
  from dataclasses import dataclass, replace, field
3
+ import typing
2
4
 
3
5
  import ezmsg.core as ez
4
- import scipy.signal
5
- import numpy as np
6
- import asyncio
7
-
8
6
  from ezmsg.util.messages.axisarray import AxisArray
9
-
10
- from typing import AsyncGenerator, Optional, Tuple
7
+ from ezmsg.util.generator import consumer
8
+ import numpy as np
9
+ import numpy.typing as npt
10
+ import scipy.signal
11
11
 
12
12
 
13
13
  @dataclass
@@ -16,39 +16,124 @@ class FilterCoefficients:
16
16
  a: np.ndarray = field(default_factory=lambda: np.array([1.0, 0.0]))
17
17
 
18
18
 
19
+ def _normalize_coefs(
20
+ coefs: typing.Union[
21
+ FilterCoefficients, typing.Tuple[npt.NDArray, npt.NDArray], npt.NDArray
22
+ ],
23
+ ) -> typing.Tuple[str, typing.Tuple[npt.NDArray, ...]]:
24
+ coef_type = "ba"
25
+ if coefs is not None:
26
+ # scipy.signal functions called with first arg `*coefs`.
27
+ # Make sure we have a tuple of coefficients.
28
+ if isinstance(coefs, npt.NDArray):
29
+ coef_type = "sos"
30
+ coefs = (coefs,) # sos funcs just want a single ndarray.
31
+ elif isinstance(coefs, FilterCoefficients):
32
+ coefs = (FilterCoefficients.b, FilterCoefficients.a)
33
+ return coef_type, coefs
34
+
35
+
36
+ @consumer
37
+ def filtergen(
38
+ axis: str, coefs: typing.Optional[typing.Tuple[np.ndarray]], coef_type: str
39
+ ) -> typing.Generator[AxisArray, AxisArray, None]:
40
+ """
41
+ Construct a generic filter generator function.
42
+
43
+ Args:
44
+ axis: The name of the axis to operate on.
45
+ coefs: The pre-calculated filter coefficients.
46
+ coef_type: The type of filter coefficients. One of "ba" or "sos".
47
+
48
+ Returns:
49
+ A generator that expects .send(axis_array) and yields the filtered :obj:`AxisArray`.
50
+ """
51
+ # Massage inputs
52
+ if coefs is not None and not isinstance(coefs, tuple):
53
+ # scipy.signal functions called with first arg `*coefs`, but sos coefs are a single ndarray.
54
+ coefs = (coefs,)
55
+
56
+ # Init IO
57
+ msg_out = AxisArray(np.array([]), dims=[""])
58
+
59
+ filt_func = {"ba": scipy.signal.lfilter, "sos": scipy.signal.sosfilt}[coef_type]
60
+ zi_func = {"ba": scipy.signal.lfilter_zi, "sos": scipy.signal.sosfilt_zi}[coef_type]
61
+
62
+ # State variables
63
+ zi: typing.Optional[npt.NDArray] = None
64
+
65
+ # Reset if these change.
66
+ check_input = {"key": None, "shape": None}
67
+ # fs changing will be handled by caller that creates coefficients.
68
+
69
+ while True:
70
+ msg_in: AxisArray = yield msg_out
71
+
72
+ if coefs is None:
73
+ # passthrough if we do not have a filter design.
74
+ msg_out = msg_in
75
+ continue
76
+
77
+ axis = msg_in.dims[0] if axis is None else axis
78
+ axis_idx = msg_in.get_axis_idx(axis)
79
+
80
+ # Re-calculate/reset zi if necessary
81
+ samp_shape = msg_in.data.shape[:axis_idx] + msg_in.data.shape[axis_idx + 1 :]
82
+ b_reset = samp_shape != check_input["shape"]
83
+ b_reset = b_reset or msg_in.key != check_input["key"]
84
+ if b_reset:
85
+ check_input["shape"] = samp_shape
86
+ check_input["key"] = msg_in.key
87
+
88
+ n_tail = msg_in.data.ndim - axis_idx - 1
89
+ zi = zi_func(*coefs)
90
+ zi_expand = (None,) * axis_idx + (slice(None),) + (None,) * n_tail
91
+ n_tile = (
92
+ msg_in.data.shape[:axis_idx] + (1,) + msg_in.data.shape[axis_idx + 1 :]
93
+ )
94
+ if coef_type == "sos":
95
+ # sos zi must keep its leading dimension (`order / 2` for low|high; `order` for bpass|bstop)
96
+ zi_expand = (slice(None),) + zi_expand
97
+ n_tile = (1,) + n_tile
98
+ zi = np.tile(zi[zi_expand], n_tile)
99
+
100
+ dat_out, zi = filt_func(*coefs, msg_in.data, axis=axis_idx, zi=zi)
101
+ msg_out = replace(msg_in, data=dat_out)
102
+
103
+
19
104
  class FilterSettingsBase(ez.Settings):
20
- axis: Optional[str] = None
21
- fs: Optional[float] = None
105
+ axis: typing.Optional[str] = None
106
+ fs: typing.Optional[float] = None
22
107
 
23
108
 
24
109
  class FilterSettings(FilterSettingsBase):
25
110
  # If you'd like to statically design a filter, define it in settings
26
- filt: Optional[FilterCoefficients] = None
111
+ filt: typing.Optional[FilterCoefficients] = None
27
112
 
28
113
 
29
114
  class FilterState(ez.State):
30
- axis: Optional[str] = None
31
- zi: Optional[np.ndarray] = None
115
+ axis: typing.Optional[str] = None
116
+ zi: typing.Optional[np.ndarray] = None
32
117
  filt_designed: bool = False
33
- filt: Optional[FilterCoefficients] = None
118
+ filt: typing.Optional[FilterCoefficients] = None
34
119
  filt_set: asyncio.Event = field(default_factory=asyncio.Event)
35
- samp_shape: Optional[Tuple[int, ...]] = None
36
- fs: Optional[float] = None # Hz
120
+ samp_shape: typing.Optional[typing.Tuple[int, ...]] = None
121
+ fs: typing.Optional[float] = None # Hz
37
122
 
38
123
 
39
124
  class Filter(ez.Unit):
40
- SETTINGS: FilterSettingsBase
41
- STATE: FilterState
125
+ SETTINGS = FilterSettingsBase
126
+ STATE = FilterState
42
127
 
43
128
  INPUT_FILTER = ez.InputStream(FilterCoefficients)
44
129
  INPUT_SIGNAL = ez.InputStream(AxisArray)
45
130
  OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
46
131
 
47
- def design_filter(self) -> Optional[Tuple[np.ndarray, np.ndarray]]:
132
+ def design_filter(self) -> typing.Optional[typing.Tuple[np.ndarray, np.ndarray]]:
48
133
  raise NotImplementedError("Must implement 'design_filter' in Unit subclass!")
49
134
 
50
135
  # Set up filter with static initialization if specified
51
- def initialize(self) -> None:
136
+ async def initialize(self) -> None:
52
137
  if self.SETTINGS.axis is not None:
53
138
  self.STATE.axis = self.SETTINGS.axis
54
139
 
@@ -84,7 +169,7 @@ class Filter(ez.Unit):
84
169
 
85
170
  @ez.subscriber(INPUT_SIGNAL)
86
171
  @ez.publisher(OUTPUT_SIGNAL)
87
- async def apply_filter(self, msg: AxisArray) -> AsyncGenerator:
172
+ async def apply_filter(self, msg: AxisArray) -> typing.AsyncGenerator:
88
173
  axis_name = msg.dims[0] if self.STATE.axis is None else self.STATE.axis
89
174
  axis_idx = msg.get_axis_idx(axis_name)
90
175
  axis = msg.get_axis(axis_name)
@@ -137,4 +222,7 @@ class Filter(ez.Unit):
137
222
  if one_dimensional:
138
223
  arr_out = np.squeeze(arr_out, axis=1)
139
224
 
140
- yield self.OUTPUT_SIGNAL, replace(msg, data=arr_out),
225
+ yield (
226
+ self.OUTPUT_SIGNAL,
227
+ replace(msg, data=arr_out),
228
+ )