ezmsg-sigproc 1.2.1__py3-none-any.whl → 1.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ezmsg/sigproc/__init__.py CHANGED
@@ -1 +1,4 @@
1
- from .__version__ import __version__
1
+ import importlib.metadata
2
+
3
+
4
+ __version__ = importlib.metadata.version("ezmsg-sigproc")  # read from installed distribution metadata; raises importlib.metadata.PackageNotFoundError if the distribution is not installed
@@ -0,0 +1,124 @@
1
+ from dataclasses import replace
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Generator, Optional, Union
5
+
6
+ import numpy as np
7
+ import ezmsg.core as ez
8
+ from ezmsg.util.messages.axisarray import AxisArray
9
+ from ezmsg.util.generator import consumer, GenAxisArray
10
+
11
+
12
@consumer
def affine_transform(
    weights: Union[np.ndarray, str, Path],
    axis: Optional[str] = None,
    right_multiply: bool = True,
) -> Generator[AxisArray, AxisArray, None]:
    """
    Apply a linear or affine transform to the data along one axis.

    Args:
        weights: Weight matrix, or a path (``str`` or ``Path``) to a comma-delimited
            text file containing it. ``~`` in string paths is expanded.
        axis: Name of the axis to transform. If None, the last dimension of the first
            received message is used (and kept for subsequent messages).
        right_multiply: If True (default) compute ``data @ weights``; if False the
            weights are transposed first, i.e. ``data @ weights.T``.

    If the input has exactly one fewer element along ``axis`` than ``weights`` has
    rows, a slice of ones is appended to the data before multiplying. The last row
    of ``weights`` then acts as the offset ``B`` in ``y = Ax + B``.

    Returns:
        A primed generator that receives AxisArray messages via ``send`` and yields
        AxisArray messages carrying the transformed data.
    """
    axis_arr_in = AxisArray(np.array([]), dims=[""])
    axis_arr_out = AxisArray(np.array([]), dims=[""])

    if isinstance(weights, str):
        weights = Path(os.path.abspath(os.path.expanduser(weights)))
    if isinstance(weights, Path):
        weights = np.loadtxt(weights, delimiter=",")
    if not right_multiply:
        weights = weights.T
    weights = np.ascontiguousarray(weights)

    while True:
        axis_arr_in = yield axis_arr_out

        data = axis_arr_in.data

        if axis is None:
            axis = axis_arr_in.dims[-1]
            # Must be non-negative: the slicing below uses `axis_idx + 1`, which would
            # evaluate to 0 (i.e. the whole shape) for an index of -1.
            axis_idx = data.ndim - 1
        else:
            axis_idx = axis_arr_in.get_axis_idx(axis)

        if data.shape[axis_idx] == (weights.shape[0] - 1):
            # The weights are stacked A|B where A is the transform and B is a single row
            # in the equation y = Ax + B. This supports NeuroKey's weights matrices.
            sample_shape = data.shape[:axis_idx] + (1,) + data.shape[axis_idx + 1:]
            data = np.concatenate(
                (data, np.ones(sample_shape, dtype=data.dtype)), axis=axis_idx
            )

        if axis_idx == data.ndim - 1:
            data = np.matmul(data, weights)
        else:
            # matmul contracts over the last axis, so move the target axis there and back.
            data = np.moveaxis(data, axis_idx, -1)
            data = np.matmul(data, weights)
            data = np.moveaxis(data, -1, axis_idx)
        axis_arr_out = replace(axis_arr_in, data=data)
53
+
54
+
55
class AffineTransformSettings(ez.Settings):
    """Settings for :obj:`AffineTransform`; each field is forwarded to :obj:`affine_transform`."""

    # Weight matrix, or a path (str / Path) to a comma-delimited file holding it.
    weights: Union[np.ndarray, str, Path]
    # Axis to transform; None selects the last dimension of the first message.
    axis: Union[str, None] = None
    # When False, the weights are transposed before being applied.
    right_multiply: bool = True
59
+
60
+
61
class AffineTransform(GenAxisArray):
    """ezmsg Unit that runs :obj:`affine_transform` on incoming AxisArray messages."""

    SETTINGS: AffineTransformSettings

    def construct_generator(self):
        # Build a fresh generator from the unit's current settings.
        cfg = self.SETTINGS
        self.STATE.gen = affine_transform(
            weights=cfg.weights,
            axis=cfg.axis,
            right_multiply=cfg.right_multiply,
        )
70
+
71
+
72
@consumer
def common_rereference(
    mode: str = "mean", axis: Optional[str] = None, include_current: bool = True
) -> Generator[AxisArray, AxisArray, None]:
    """
    Subtract a common reference (mean or median across ``axis``) from the data.

    Args:
        mode: ``"mean"`` or ``"median"``: the statistic used to build the reference.
            Any other value raises KeyError when the generator is created.
        axis: Name of the axis the reference is computed across. If None, the last
            dimension of the first received message is used.
        include_current: If False, each element's reference is computed only from
            the *other* elements along ``axis`` (leave-one-out).

    Returns:
        A primed generator that receives AxisArray messages via ``send`` and yields
        re-referenced AxisArray messages.
    """
    axis_arr_in = AxisArray(np.array([]), dims=[""])
    axis_arr_out = AxisArray(np.array([]), dims=[""])

    func = {"mean": np.mean, "median": np.median}[mode]

    while True:
        axis_arr_in = yield axis_arr_out

        if axis is None:
            axis = axis_arr_in.dims[-1]
            axis_idx = -1
        else:
            axis_idx = axis_arr_in.get_axis_idx(axis)

        data = axis_arr_in.data

        if include_current:
            ref_data = func(data, axis=axis_idx, keepdims=True)
        elif mode == "mean":
            # Typical `CAR = x[0]/N + x[1]/N + ... x[i-1]/N + x[i]/N + x[i+1]/N + ... + x[N-1]/N`
            # and is the same for all i, so it is calculated only once.
            # Excluding the current channel omits its contribution:
            # `CAR[i] = x[0]/(N-1) + ... x[i-1]/(N-1) + x[i+1]/(N-1) + ... + x[N-1]/(N-1)`
            # i.e., `CAR[i] = (N / (N-1)) * common_CAR - x[i]/(N-1)`, computed with
            # broadcasting instead of looping over channels.
            # Side note: profiling showed affine_transform is about 30x slower than this.
            N = data.shape[axis_idx]
            ref_data = func(data, axis=axis_idx, keepdims=True)
            ref_data = (N / (N - 1)) * ref_data - data / (N - 1)
        else:
            # The rescaling identity above holds only for the mean. For the median,
            # recompute the statistic with each element left out in turn.
            ref_data = np.stack(
                [
                    np.median(np.delete(data, i, axis=axis_idx), axis=axis_idx)
                    for i in range(data.shape[axis_idx])
                ],
                axis=axis_idx,
            )

        axis_arr_out = replace(axis_arr_in, data=data - ref_data)
108
+
109
+
110
class CommonRereferenceSettings(ez.Settings):
    """Settings for :obj:`CommonRereference`; each field is forwarded to :obj:`common_rereference`."""

    # Reference statistic: "mean" or "median".
    mode: str = "mean"
    # Axis to reference across; None selects the last dimension of the first message.
    axis: Union[str, None] = None
    # When False, each element is referenced only against the other elements.
    include_current: bool = True
114
+
115
+
116
class CommonRereference(GenAxisArray):
    """ezmsg Unit that runs :obj:`common_rereference` on incoming AxisArray messages."""

    SETTINGS: CommonRereferenceSettings

    def construct_generator(self):
        # Build a fresh generator from the unit's current settings.
        opts = self.SETTINGS
        self.STATE.gen = common_rereference(
            mode=opts.mode,
            axis=opts.axis,
            include_current=opts.include_current,
        )
@@ -0,0 +1,103 @@
1
+ from dataclasses import replace
2
+ import typing
3
+
4
+ import numpy as np
5
+ import ezmsg.core as ez
6
+ from ezmsg.util.generator import consumer, GenAxisArray
7
+ from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
8
+ from ezmsg.sigproc.spectral import OptionsEnum
9
+
10
+
11
class AggregationFunction(OptionsEnum):
    """Reductions selectable for :obj:`ranged_aggregate`.

    Member values are human-readable labels; the callable applied for each member
    is looked up in ``AGGREGATORS``.
    """

    NONE = "None (all)"
    # Plain reductions.
    MAX = "max"
    MIN = "min"
    MEAN = "mean"
    MEDIAN = "median"
    STD = "std"
    # NaN-ignoring counterparts.
    NANMAX = "nanmax"
    NANMIN = "nanmin"
    NANMEAN = "nanmean"
    NANMEDIAN = "nanmedian"
    NANSTD = "nanstd"
23
+
24
+
25
# Maps each AggregationFunction member to the NumPy reduction applied by
# ranged_aggregate, which calls it as `func(data_slice, axis=ax_idx)`.
AGGREGATORS = {
    # NOTE(review): NONE maps to np.all, a logical-AND reduction that returns booleans
    # rather than leaving the data un-aggregated -- confirm this is intended.
    AggregationFunction.NONE: np.all,
    # NaN-propagating reductions.
    AggregationFunction.MAX: np.max,
    AggregationFunction.MIN: np.min,
    AggregationFunction.MEAN: np.mean,
    AggregationFunction.MEDIAN: np.median,
    AggregationFunction.STD: np.std,
    # NaN-ignoring reductions.
    AggregationFunction.NANMAX: np.nanmax,
    AggregationFunction.NANMIN: np.nanmin,
    AggregationFunction.NANMEAN: np.nanmean,
    AggregationFunction.NANMEDIAN: np.nanmedian,
    AggregationFunction.NANSTD: np.nanstd,
}
38
+
39
+
40
@consumer
def ranged_aggregate(
    axis: typing.Optional[str] = None,
    bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = None,
    operation: AggregationFunction = AggregationFunction.MEAN
):
    """
    Reduce the data within each of several coordinate ranges ("bands") of one axis.

    Args:
        axis: Name of the axis whose coordinates (``offset + index * gain``) are
            compared against ``bands``. Defaults to the first dimension of the input.
        bands: Inclusive ``(start, stop)`` coordinate ranges; the output has one
            element per band along ``axis``. If None, messages pass through unchanged.
        operation: Reduction applied within each band; the callable is looked up in
            ``AGGREGATORS``.

    Returns:
        A primed generator that receives AxisArray messages via ``send`` and yields
        AxisArray messages in which ``axis`` has one entry per band. The new axis
        uses the first band's center as its offset and the spacing between the first
        two band centers as its gain (1.0 for a single band), so it is exact only for
        evenly spaced bands.

    Raises:
        ValueError: If a band contains no sample of the input axis.
    """
    axis_arr_in = AxisArray(np.array([]), dims=[""])
    axis_arr_out = AxisArray(np.array([]), dims=[""])

    target_axis: typing.Optional[AxisArray.Axis] = None
    out_axis = AxisArray.Axis()
    slices: typing.Optional[typing.List[typing.Tuple[typing.Any, ...]]] = None
    axis_name = ""
    ax_idx = 0

    while True:
        axis_arr_in = yield axis_arr_out

        if bands is None:
            axis_arr_out = axis_arr_in
            continue

        if slices is None or target_axis != axis_arr_in.get_axis(axis_name):
            # (Re)compute the per-band index slices on the first message and whenever
            # the target axis' description (e.g. offset or gain) changes.
            axis_name = axis or axis_arr_in.dims[0]
            ax_idx = axis_arr_in.get_axis_idx(axis_name)
            target_axis = axis_arr_in.axes[axis_name]

            ax_vec = target_axis.offset + np.arange(axis_arr_in.data.shape[ax_idx]) * target_axis.gain
            slices = []
            mids = []
            for start, stop in bands:
                inds = np.where(np.logical_and(ax_vec >= start, ax_vec <= stop))[0]
                if inds.size == 0:
                    # Fail with a clear message instead of an IndexError on inds[0].
                    raise ValueError(
                        f"Band ({start}, {stop}) contains no samples of axis '{axis_name}'."
                    )
                mids.append(np.mean(inds) * target_axis.gain + target_axis.offset)
                slices.append(np.s_[inds[0]:inds[-1] + 1])
            out_axis = AxisArray.Axis(
                unit=target_axis.unit,
                offset=mids[0],
                gain=(mids[1] - mids[0]) if len(mids) > 1 else 1.0,
            )

        agg_func = AGGREGATORS[operation]
        out_data = [
            agg_func(slice_along_axis(axis_arr_in.data, sl, axis=ax_idx), axis=ax_idx)
            for sl in slices
        ]
        new_axes = {**axis_arr_in.axes, axis_name: out_axis}
        axis_arr_out = replace(
            axis_arr_in,
            data=np.stack(out_data, axis=ax_idx),
            axes=new_axes,
        )
87
+
88
+
89
class RangedAggregateSettings(ez.Settings):
    """Settings for :obj:`RangedAggregate`; each field is forwarded to :obj:`ranged_aggregate`."""

    # Axis whose coordinates are compared to `bands`; None selects the first dimension.
    axis: typing.Union[str, None] = None
    # Inclusive (start, stop) ranges; None passes messages through unchanged.
    bands: typing.Union[typing.List[typing.Tuple[float, float]], None] = None
    # Reduction applied within each band.
    operation: AggregationFunction = AggregationFunction.MEAN
93
+
94
+
95
class RangedAggregate(GenAxisArray):
    """ezmsg Unit that runs :obj:`ranged_aggregate` on incoming AxisArray messages."""

    SETTINGS: RangedAggregateSettings

    def construct_generator(self):
        # Build a fresh generator from the unit's current settings.
        cfg = self.SETTINGS
        self.STATE.gen = ranged_aggregate(
            axis=cfg.axis,
            bands=cfg.bands,
            operation=cfg.operation,
        )
@@ -0,0 +1,53 @@
1
+ from dataclasses import field
2
+ import typing
3
+
4
+ import numpy as np
5
+ import ezmsg.core as ez
6
+ from ezmsg.util.messages.axisarray import AxisArray
7
+ from ezmsg.util.generator import consumer, compose, GenAxisArray
8
+
9
+ from .spectrogram import spectrogram, SpectrogramSettings
10
+ from .aggregate import ranged_aggregate, AggregationFunction
11
+
12
+
13
@consumer
def bandpower(
    spectrogram_settings: SpectrogramSettings,
    bands: typing.Optional[typing.Sequence[typing.Tuple[float, float]]] = ((17, 30), (70, 170)),
) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Compute the mean spectrogram value within each frequency band.

    Chains :obj:`spectrogram` (configured from ``spectrogram_settings``) into
    :obj:`ranged_aggregate` over the ``"freq"`` axis with a mean reduction.

    Args:
        spectrogram_settings: Parameters for the spectrogram stage (window_dur,
            window_shift, window, transform, output).
        bands: Inclusive ``(start, stop)`` ranges along the ``"freq"`` axis. Defaults
            to ``(17, 30)`` and ``(70, 170)``. If None, the spectrogram output is
            passed through without aggregation.

    Returns:
        A primed generator that receives AxisArray messages via ``send`` and yields
        the band-aggregated AxisArray messages.
    """
    axis_arr_in = AxisArray(np.array([]), dims=[""])
    axis_arr_out = AxisArray(np.array([]), dims=[""])

    # Copy into a fresh list: the default is an immutable tuple (no shared mutable
    # default), and later changes to a caller-supplied list cannot leak in.
    if bands is not None:
        bands = list(bands)

    f_spec = spectrogram(
        window_dur=spectrogram_settings.window_dur,
        window_shift=spectrogram_settings.window_shift,
        window=spectrogram_settings.window,
        transform=spectrogram_settings.transform,
        output=spectrogram_settings.output,
    )
    f_agg = ranged_aggregate(
        axis="freq",
        bands=bands,
        operation=AggregationFunction.MEAN,
    )
    pipeline = compose(f_spec, f_agg)

    while True:
        axis_arr_in = yield axis_arr_out
        axis_arr_out = pipeline(axis_arr_in)
38
+
39
+
40
class BandPowerSettings(ez.Settings):
    """Settings for :obj:`BandPower`; each field is forwarded to :obj:`bandpower`."""

    # Parameters for the spectrogram stage.
    spectrogram_settings: SpectrogramSettings = field(
        default_factory=SpectrogramSettings
    )
    # Inclusive (start, stop) ranges along the "freq" axis, one output entry per band.
    bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = field(
        default_factory=lambda: [(17, 30), (70, 170)]
    )
44
+
45
+
46
class BandPower(GenAxisArray):
    """ezmsg Unit that runs :obj:`bandpower` on incoming AxisArray messages."""

    SETTINGS: BandPowerSettings

    def construct_generator(self):
        # Build a fresh generator from the unit's current settings.
        cfg = self.SETTINGS
        self.STATE.gen = bandpower(
            spectrogram_settings=cfg.spectrogram_settings,
            bands=cfg.bands,
        )
@@ -1,18 +1,21 @@
1
+ import typing
2
+
1
3
  import ezmsg.core as ez
2
4
  import scipy.signal
3
5
  import numpy as np
4
6
 
5
- from .filter import Filter, FilterState, FilterSettingsBase
7
+ from .filter import filtergen, Filter, FilterState, FilterSettingsBase
6
8
 
7
- from typing import Optional, Tuple, Union
9
+ from ezmsg.util.messages.axisarray import AxisArray
10
+ from ezmsg.util.generator import consumer
8
11
 
9
12
 
10
13
  class ButterworthFilterSettings(FilterSettingsBase):
11
14
  order: int = 0
12
- cuton: Optional[float] = None # Hz
13
- cutoff: Optional[float] = None # Hz
15
+ cuton: typing.Optional[float] = None # Hz
16
+ cutoff: typing.Optional[float] = None # Hz
14
17
 
15
- def filter_specs(self) -> Optional[Tuple[str, Union[float, Tuple[float, float]]]]:
18
+ def filter_specs(self) -> typing.Optional[typing.Tuple[str, typing.Union[float, typing.Tuple[float, float]]]]:
16
19
  if self.cuton is None and self.cutoff is None:
17
20
  return None
18
21
  elif self.cuton is None and self.cutoff is not None:
@@ -26,6 +29,38 @@ class ButterworthFilterSettings(FilterSettingsBase):
26
29
  return "bandstop", (self.cutoff, self.cuton)
27
30
 
28
31
 
32
@consumer
def butter(
    axis: typing.Optional[str],
    order: int = 0,
    cuton: typing.Optional[float] = None,
    cutoff: typing.Optional[float] = None,
    coef_type: str = "ba",
) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Butterworth-filter AxisArray messages along ``axis``.

    The filter is designed lazily on the first message, using the sample rate
    ``1 / gain`` of the filtered axis. The band type and corner frequencies come from
    ``ButterworthFilterSettings.filter_specs`` applied to ``cuton`` / ``cutoff``.

    Args:
        axis: Name of the axis to filter along; None means the first dimension.
        order: Filter order. ``0`` disables filtering (messages pass through).
        cuton: Lower corner frequency, or None.
        cutoff: Upper corner frequency, or None. If both ``cuton`` and ``cutoff``
            are None there is no filter to design and messages pass through.
        coef_type: ``"ba"`` or ``"sos"``; passed to ``scipy.signal.butter`` as
            ``output`` and to :obj:`filtergen`.

    Returns:
        A primed generator that receives AxisArray messages via ``send`` and yields
        filtered AxisArray messages.
    """
    # IO
    axis_arr_in = AxisArray(np.array([]), dims=[""])
    axis_arr_out = AxisArray(np.array([]), dims=[""])

    # filter_specs() returns None when neither cuton nor cutoff is set; unpacking it
    # unconditionally would raise TypeError for the default arguments.
    specs = ButterworthFilterSettings(
        order=order, cuton=cuton, cutoff=cutoff
    ).filter_specs()
    needs_design = order > 0 and specs is not None

    # We cannot calculate coefs yet because we do not know input sample rate
    coefs = None
    filter_gen = filtergen(axis, coefs, coef_type)  # Passthrough.

    while True:
        axis_arr_in = yield axis_arr_out
        if coefs is None and needs_design:
            btype, cutoffs = specs
            fs = 1 / axis_arr_in.axes[axis or axis_arr_in.dims[0]].gain
            coefs = scipy.signal.butter(
                order, Wn=cutoffs, btype=btype, fs=fs, output=coef_type
            )
            filter_gen = filtergen(axis, coefs, coef_type)

        axis_arr_out = filter_gen.send(axis_arr_in)
62
+
63
+
29
64
  class ButterworthFilterState(FilterState):
30
65
  design: ButterworthFilterSettings
31
66
 
@@ -41,7 +76,7 @@ class ButterworthFilter(Filter):
41
76
  self.STATE.filt_designed = True
42
77
  super().initialize()
43
78
 
44
- def design_filter(self) -> Optional[Tuple[np.ndarray, np.ndarray]]:
79
+ def design_filter(self) -> typing.Optional[typing.Tuple[np.ndarray, np.ndarray]]:
45
80
  specs = self.STATE.design.filter_specs()
46
81
  if self.STATE.design.order > 0 and specs is not None:
47
82
  btype, cut = specs
@@ -57,6 +92,9 @@ class ButterworthFilter(Filter):
57
92
 
58
93
  @ez.subscriber(INPUT_FILTER)
59
94
  async def redesign(self, message: ButterworthFilterSettings) -> None:
95
+ if type(message) is not ButterworthFilterSettings:
96
+ return
97
+
60
98
  if self.STATE.design.order != message.order:
61
99
  self.STATE.zi = None
62
100
  self.STATE.design = message
@@ -1,14 +1,46 @@
1
1
  from dataclasses import replace
2
+ import traceback
3
+ from typing import AsyncGenerator, Optional, Generator
2
4
 
3
- from ezmsg.util.messages.axisarray import AxisArray
5
+ import numpy as np
4
6
 
7
+ from ezmsg.util.messages.axisarray import AxisArray
8
+ from ezmsg.util.generator import consumer
5
9
  import ezmsg.core as ez
6
- import numpy as np
7
10
 
8
- from typing import (
9
- AsyncGenerator,
10
- Optional,
11
- )
11
+
12
@consumer
def downsample(
    axis: Optional[str] = None, factor: int = 1
) -> Generator[AxisArray, AxisArray, None]:
    """
    Keep every ``factor``-th sample along ``axis``, preserving phase across messages.

    Args:
        axis: Name of the axis to decimate. If None, the first dimension of the first
            received message is used.
        factor: Keep one sample out of every ``factor``. Validation (``factor >= 1``)
            is left to the caller.

    Returns:
        A primed generator that receives AxisArray messages via ``send``. It yields
        the decimated AxisArray, with the axis offset and gain updated, or None when a
        message contributes no output samples.
    """
    axis_arr_in = AxisArray(np.array([]), dims=[""])
    axis_arr_out = AxisArray(np.array([]), dims=[""])

    # state variables
    s_idx = 0  # phase of the next incoming sample within the decimation period

    while True:
        axis_arr_in = yield axis_arr_out

        if axis is None:
            axis = axis_arr_in.dims[0]
        axis_info = axis_arr_in.get_axis(axis)
        axis_idx = axis_arr_in.get_axis_idx(axis)

        n_samples = axis_arr_in.data.shape[axis_idx]
        if n_samples == 0:
            # Nothing to publish and the phase is unchanged. Without this guard,
            # `samples[-1]` raises IndexError, which would terminate the generator.
            axis_arr_out = None
            continue

        samples = (np.arange(n_samples) + s_idx) % factor
        s_idx = samples[-1] + 1

        pub_samples = np.where(samples == 0)[0]
        if len(pub_samples) > 0:
            new_axes = {ax_name: axis_arr_in.get_axis(ax_name) for ax_name in axis_arr_in.dims}
            new_offset = axis_info.offset + (axis_info.gain * pub_samples[0].item())
            new_gain = axis_info.gain * factor
            new_axes[axis] = replace(axis_info, gain=new_gain, offset=new_offset)
            down_data = np.take(axis_arr_in.data, pub_samples, axis=axis_idx)
            axis_arr_out = replace(axis_arr_in, data=down_data, dims=axis_arr_in.dims, axes=new_axes)
        else:
            axis_arr_out = None
12
44
 
13
45
 
14
46
  class DownsampleSettings(ez.Settings):
@@ -18,7 +50,7 @@ class DownsampleSettings(ez.Settings):
18
50
 
19
51
class DownsampleState(ez.State):
    """Runtime state for :obj:`Downsample`."""

    # Settings in effect: SETTINGS at initialize(), then the latest INPUT_SETTINGS message.
    cur_settings: DownsampleSettings
    # Primed `downsample` generator built from `cur_settings`.
    gen: Generator
22
54
 
23
55
 
24
56
  class Downsample(ez.Unit):
@@ -29,12 +61,17 @@ class Downsample(ez.Unit):
29
61
  INPUT_SIGNAL = ez.InputStream(AxisArray)
30
62
  OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
31
63
 
64
+ def construct_generator(self):  # (re)build STATE.gen from the settings currently in effect
65
+ self.STATE.gen = downsample(axis=self.STATE.cur_settings.axis, factor=self.STATE.cur_settings.factor)  # replaces any previous generator, so its carried sample phase restarts at 0
66
+
32
67
  def initialize(self) -> None:  # start from the static SETTINGS and build the first generator
33
68
  self.STATE.cur_settings = self.SETTINGS
69
+ self.construct_generator()
34
70
 
35
71
  @ez.subscriber(INPUT_SETTINGS)
36
72
  async def on_settings(self, msg: DownsampleSettings) -> None:  # adopt settings received at runtime
37
73
  self.STATE.cur_settings = msg
74
+ self.construct_generator()  # rebuilding discards the old generator's sample-phase state
38
75
 
39
76
  @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
40
77
  @ez.publisher(OUTPUT_SIGNAL)
@@ -42,22 +79,11 @@ class Downsample(ez.Unit):
42
79
  if self.STATE.cur_settings.factor < 1:
43
80
  raise ValueError("Downsample factor must be at least 1 (no downsampling)")
44
81
 
45
- axis_name = self.STATE.cur_settings.axis
46
- if axis_name is None:
47
- axis_name = msg.dims[0]
48
- axis = msg.get_axis(axis_name)
49
- axis_idx = msg.get_axis_idx(axis_name)
50
-
51
- samples = np.arange(msg.data.shape[axis_idx]) + self.STATE.s_idx
52
- samples = samples % self.STATE.cur_settings.factor
53
- self.STATE.s_idx = samples[-1] + 1
54
-
55
- pub_samples = np.where(samples == 0)[0]
56
- if len(pub_samples) != 0:
57
- new_axes = {ax_name: msg.get_axis(ax_name) for ax_name in msg.dims}
58
- new_offset = axis.offset + (axis.gain * pub_samples[0].item())
59
- new_gain = axis.gain * self.STATE.cur_settings.factor
60
- new_axes[axis_name] = replace(axis, gain=new_gain, offset=new_offset)
61
- down_data = np.take(msg.data, pub_samples, axis_idx)
62
- out_msg = replace(msg, data=down_data, dims=msg.dims, axes=new_axes)
63
- yield self.OUTPUT_SIGNAL, out_msg
82
+ try:
83
+ out_msg = self.STATE.gen.send(msg)
84
+ if out_msg is not None:
85
+ yield self.OUTPUT_SIGNAL, out_msg
86
+ except (StopIteration, GeneratorExit):
87
+ ez.logger.debug(f"Downsample closed in {self.address}")
88
+ except Exception:
89
+ ez.logger.info(traceback.format_exc())
@@ -73,9 +73,12 @@ class EWM(ez.Unit):
73
73
  buffer_data = buffer.data
74
74
  buffer_data = np.moveaxis(buffer_data, axis_idx, 0)
75
75
 
76
+ while scale_arr.ndim < buffer_data.ndim:
77
+ scale_arr = scale_arr[..., None]
78
+
76
79
  def ewma(data: np.ndarray) -> np.ndarray:
77
- mult = scale_arr[:, np.newaxis] * data * pw0
78
- out = scale_arr[::-1, np.newaxis] * mult.cumsum(axis=0)
80
+ mult = scale_arr * data * pw0
81
+ out = scale_arr[::-1] * mult.cumsum(axis=0)
79
82
 
80
83
  if not self.SETTINGS.zero_offset:
81
84
  out = (data[0, :, np.newaxis] * pows[1:]).T + out
@@ -108,7 +111,12 @@ class EWMFilter(ez.Collection):
108
111
  EWM = EWM()
109
112
 
110
113
  def configure(self) -> None:
111
- self.EWM.apply_settings(EWMSettings(axis=self.SETTINGS.axis, zero_offset=True))
114
+ self.EWM.apply_settings(
115
+ EWMSettings(
116
+ axis=self.SETTINGS.axis,
117
+ zero_offset=self.SETTINGS.zero_offset,
118
+ )
119
+ )
112
120
 
113
121
  self.WINDOW.apply_settings(
114
122
  WindowSettings(
ezmsg/sigproc/filter.py CHANGED
@@ -1,39 +1,107 @@
1
+ import asyncio
2
+ import typing
3
+
1
4
  from dataclasses import dataclass, replace, field
2
5
 
3
6
  import ezmsg.core as ez
4
7
  import scipy.signal
8
+
5
9
  import numpy as np
6
- import asyncio
10
+ import numpy.typing as npt
7
11
 
8
12
  from ezmsg.util.messages.axisarray import AxisArray
9
-
10
- from typing import AsyncGenerator, Optional, Tuple
11
-
13
+ from ezmsg.util.generator import consumer
12
14
 
13
15
def _passthrough_coefs() -> np.ndarray:
    """Return a fresh ``[1.0, 0.0]`` coefficient array (b == a gives a unity transfer function)."""
    return np.array([1.0, 0.0])


@dataclass
class FilterCoefficients:
    """Transfer-function ("ba") filter coefficients.

    Each instance gets its own default arrays, so instances never share state.
    """

    # Numerator coefficients.
    b: np.ndarray = field(default_factory=_passthrough_coefs)
    # Denominator coefficients.
    a: np.ndarray = field(default_factory=_passthrough_coefs)
17
19
 
20
def _normalize_coefs(
    coefs: typing.Union[
        FilterCoefficients, typing.Tuple[npt.NDArray, npt.NDArray], npt.NDArray, None
    ]
) -> typing.Tuple[str, typing.Optional[typing.Tuple[npt.NDArray, ...]]]:
    """
    Infer the coefficient type and coerce coefficients into a tuple.

    scipy.signal filter functions are called as ``func(*coefs, ...)``, so the
    coefficients must be a tuple:

    * a bare ndarray is treated as second-order sections -> ``("sos", (coefs,))``
    * a :obj:`FilterCoefficients` -> ``("ba", (coefs.b, coefs.a))``
    * a tuple (assumed ``(b, a)``) or None is returned unchanged with ``"ba"``.

    Returns:
        ``(coef_type, coefs)``.
    """
    coef_type = "ba"
    if coefs is not None:
        # Use the concrete ndarray class: `npt.NDArray` is a parameterized generic
        # alias, and isinstance() raises TypeError when given one.
        if isinstance(coefs, np.ndarray):
            coef_type = "sos"
            coefs = (coefs,)  # sos funcs just want a single ndarray.
        elif isinstance(coefs, FilterCoefficients):
            # Read the instance's fields; dataclass default_factory fields do not
            # exist as class attributes.
            coefs = (coefs.b, coefs.a)
    return coef_type, coefs
33
+
34
@consumer
def filtergen(
    axis: typing.Optional[str],
    coefs: typing.Optional[
        typing.Union[typing.Tuple[np.ndarray, ...], np.ndarray, FilterCoefficients]
    ],
    coef_type: str,
) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Filter AxisArray messages along one axis, carrying filter state between messages.

    Args:
        axis: Name of the axis to filter along; None means the first dimension of
            the first received message.
        coefs: Filter coefficients. For ``coef_type="ba"``: a ``(b, a)`` tuple or a
            :obj:`FilterCoefficients`. For ``coef_type="sos"``: the sos ndarray. If
            None, messages are passed through unchanged.
        coef_type: ``"ba"`` (uses ``scipy.signal.lfilter``) or ``"sos"`` (uses
            ``scipy.signal.sosfilt``).

    The initial conditions come from ``lfilter_zi`` / ``sosfilt_zi``, tiled across
    the non-filtered dimensions. They are rebuilt whenever those dimensions' shape
    changes.

    Returns:
        A primed generator that receives AxisArray messages via ``send`` and yields
        filtered AxisArray messages.
    """
    # Massage inputs: scipy.signal functions are called with first arg `*coefs`.
    if isinstance(coefs, FilterCoefficients):
        coefs = (coefs.b, coefs.a)
    elif coefs is not None and not isinstance(coefs, tuple):
        # sos coefs are a single ndarray.
        coefs = (coefs,)

    # Init IO
    axis_arr_in = AxisArray(np.array([]), dims=[""])
    axis_arr_out = AxisArray(np.array([]), dims=[""])

    filt_func = {"ba": scipy.signal.lfilter, "sos": scipy.signal.sosfilt}[coef_type]
    zi_func = {"ba": scipy.signal.lfilter_zi, "sos": scipy.signal.sosfilt_zi}[coef_type]

    # State variables
    axis_idx = None
    zi = None
    expected_shape = None

    while True:
        axis_arr_in = yield axis_arr_out

        if coefs is None:
            # passthrough if we do not have a filter design.
            axis_arr_out = axis_arr_in
            continue

        if axis_idx is None:
            axis_name = axis_arr_in.dims[0] if axis is None else axis
            axis_idx = axis_arr_in.get_axis_idx(axis_name)

        dat_in = axis_arr_in.data

        # Re-calculate/reset zi if necessary
        samp_shape = dat_in.shape[:axis_idx] + dat_in.shape[axis_idx + 1:]
        if zi is None or samp_shape != expected_shape:
            expected_shape = samp_shape
            n_tail = dat_in.ndim - axis_idx - 1
            zi = zi_func(*coefs)
            # Place zi's state dimension at the filtered axis, then tile across the rest.
            zi_expand = (None,) * axis_idx + (slice(None),) + (None,) * n_tail
            n_tile = dat_in.shape[:axis_idx] + (1,) + dat_in.shape[axis_idx + 1:]
            if coef_type == "sos":
                # sos zi must keep its leading dimension (`order / 2` for low|high; `order` for bpass|bstop)
                zi_expand = (slice(None),) + zi_expand
                n_tile = (1,) + n_tile
            zi = np.tile(zi[zi_expand], n_tile)

        dat_out, zi = filt_func(*coefs, dat_in, axis=axis_idx, zi=zi)
        axis_arr_out = replace(axis_arr_in, data=dat_out)
85
+
18
86
 
19
87
class FilterSettingsBase(ez.Settings):
    """Settings shared by all filter units."""

    # Axis to filter along; presumably None means the message's first dimension
    # (as in Filter.apply_filter) -- confirm.
    axis: typing.Union[str, None] = None
    # Sampling rate; presumably Hz, matching FilterState.fs -- confirm.
    fs: typing.Union[float, None] = None
22
90
 
23
91
 
24
92
class FilterSettings(FilterSettingsBase):
    """Settings for a statically designed filter."""

    # If you'd like to statically design a filter, define it in settings.
    filt: typing.Union[FilterCoefficients, None] = None
27
95
 
28
96
 
29
97
class FilterState(ez.State):
    """Runtime state for :obj:`Filter` units."""

    # Axis being filtered (None -> first dimension of each message).
    axis: typing.Union[str, None] = None
    # Filter conditions carried between messages -- presumably scipy-style zi.
    zi: typing.Union[np.ndarray, None] = None
    # Whether a filter design is currently available.
    filt_designed: bool = False
    # Current filter coefficients.
    filt: typing.Union[FilterCoefficients, None] = None
    # Event presumably set once `filt` is available -- confirm against Filter.
    filt_set: asyncio.Event = field(default_factory=asyncio.Event)
    # Shape of the non-filtered dimensions that `zi` was built for (presumed).
    samp_shape: typing.Union[typing.Tuple[int, ...], None] = None
    fs: typing.Union[float, None] = None  # Hz
37
105
 
38
106
 
39
107
  class Filter(ez.Unit):
@@ -44,7 +112,7 @@ class Filter(ez.Unit):
44
112
  INPUT_SIGNAL = ez.InputStream(AxisArray)
45
113
  OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
46
114
 
47
- def design_filter(self) -> Optional[Tuple[np.ndarray, np.ndarray]]:
115
+ def design_filter(self) -> typing.Optional[typing.Tuple[np.ndarray, np.ndarray]]:  # abstract hook; subclasses such as ButterworthFilter override it
48
116
  raise NotImplementedError("Must implement 'design_filter' in Unit subclass!")  # the base Filter has no design of its own
49
117
 
50
118
  # Set up filter with static initialization if specified
@@ -84,7 +152,7 @@ class Filter(ez.Unit):
84
152
 
85
153
  @ez.subscriber(INPUT_SIGNAL)
86
154
  @ez.publisher(OUTPUT_SIGNAL)
87
- async def apply_filter(self, msg: AxisArray) -> AsyncGenerator:
155
+ async def apply_filter(self, msg: AxisArray) -> typing.AsyncGenerator:
88
156
  axis_name = msg.dims[0] if self.STATE.axis is None else self.STATE.axis
89
157
  axis_idx = msg.get_axis_idx(axis_name)
90
158
  axis = msg.get_axis(axis_name)