ezmsg-sigproc 1.2.2__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (38)
  1. ezmsg/sigproc/__init__.py +1 -1
  2. ezmsg/sigproc/__version__.py +16 -1
  3. ezmsg/sigproc/activation.py +75 -0
  4. ezmsg/sigproc/affinetransform.py +234 -0
  5. ezmsg/sigproc/aggregate.py +158 -0
  6. ezmsg/sigproc/bandpower.py +74 -0
  7. ezmsg/sigproc/base.py +38 -0
  8. ezmsg/sigproc/butterworthfilter.py +102 -11
  9. ezmsg/sigproc/decimate.py +7 -4
  10. ezmsg/sigproc/downsample.py +95 -51
  11. ezmsg/sigproc/ewmfilter.py +38 -16
  12. ezmsg/sigproc/filter.py +108 -20
  13. ezmsg/sigproc/filterbank.py +278 -0
  14. ezmsg/sigproc/math/__init__.py +0 -0
  15. ezmsg/sigproc/math/abs.py +28 -0
  16. ezmsg/sigproc/math/clip.py +30 -0
  17. ezmsg/sigproc/math/difference.py +60 -0
  18. ezmsg/sigproc/math/invert.py +29 -0
  19. ezmsg/sigproc/math/log.py +32 -0
  20. ezmsg/sigproc/math/scale.py +31 -0
  21. ezmsg/sigproc/messages.py +2 -3
  22. ezmsg/sigproc/sampler.py +259 -224
  23. ezmsg/sigproc/scaler.py +173 -0
  24. ezmsg/sigproc/signalinjector.py +64 -0
  25. ezmsg/sigproc/slicer.py +133 -0
  26. ezmsg/sigproc/spectral.py +6 -132
  27. ezmsg/sigproc/spectrogram.py +86 -0
  28. ezmsg/sigproc/spectrum.py +259 -0
  29. ezmsg/sigproc/synth.py +299 -105
  30. ezmsg/sigproc/wavelets.py +167 -0
  31. ezmsg/sigproc/window.py +254 -116
  32. ezmsg_sigproc-1.3.1.dist-info/METADATA +59 -0
  33. ezmsg_sigproc-1.3.1.dist-info/RECORD +35 -0
  34. {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-1.3.1.dist-info}/WHEEL +1 -2
  35. ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
  36. ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
  37. ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
  38. {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-1.3.1.dist-info/licenses}/LICENSE.txt +0 -0
ezmsg/sigproc/__init__.py CHANGED
@@ -1 +1 @@
- from .__version__ import __version__
+ from .__version__ import __version__ as __version__
ezmsg/sigproc/__version__.py CHANGED
@@ -1 +1,16 @@
- __version__ = "1.2.2"
+ # file generated by setuptools_scm
+ # don't change, don't track in version control
+ TYPE_CHECKING = False
+ if TYPE_CHECKING:
+     from typing import Tuple, Union
+     VERSION_TUPLE = Tuple[Union[int, str], ...]
+ else:
+     VERSION_TUPLE = object
+
+ version: str
+ __version__: str
+ __version_tuple__: VERSION_TUPLE
+ version_tuple: VERSION_TUPLE
+
+ __version__ = version = '1.3.1'
+ __version_tuple__ = version_tuple = (1, 3, 1)
ezmsg/sigproc/activation.py ADDED
@@ -0,0 +1,75 @@
+ from dataclasses import replace
+ import typing
+
+ import numpy as np
+ import scipy.special
+ import ezmsg.core as ez
+ from ezmsg.util.messages.axisarray import AxisArray
+ from ezmsg.util.generator import consumer
+
+ from .spectral import OptionsEnum
+ from .base import GenAxisArray
+
+
+ class ActivationFunction(OptionsEnum):
+     """Activation (transformation) function."""
+
+     NONE = "none"
+     """None."""
+
+     SIGMOID = "sigmoid"
+     """:obj:`scipy.special.expit`"""
+
+     EXPIT = "expit"
+     """:obj:`scipy.special.expit`"""
+
+     LOGIT = "logit"
+     """:obj:`scipy.special.logit`"""
+
+     LOGEXPIT = "log_expit"
+     """:obj:`scipy.special.log_expit`"""
+
+
+ ACTIVATIONS = {
+     ActivationFunction.NONE: lambda x: x,
+     ActivationFunction.SIGMOID: scipy.special.expit,
+     ActivationFunction.EXPIT: scipy.special.expit,
+     ActivationFunction.LOGIT: scipy.special.logit,
+     ActivationFunction.LOGEXPIT: scipy.special.log_expit,
+ }
+
+
+ @consumer
+ def activation(
+     function: typing.Union[str, ActivationFunction],
+ ) -> typing.Generator[AxisArray, AxisArray, None]:
+     if type(function) is ActivationFunction:
+         func = ACTIVATIONS[function]
+     else:
+         # str type. There's probably an easier way to support either enum or str argument. Oh well this works.
+         function: str = function.lower()
+         if function not in ActivationFunction.options():
+             raise ValueError(
+                 f"Unrecognized activation function {function}. Must be one of {ACTIVATIONS.keys()}"
+             )
+         function = list(ACTIVATIONS.keys())[
+             ActivationFunction.options().index(function)
+         ]
+         func = ACTIVATIONS[function]
+
+     msg_out = AxisArray(np.array([]), dims=[""])
+
+     while True:
+         msg_in: AxisArray = yield msg_out
+         msg_out = replace(msg_in, data=func(msg_in.data))
+
+
+ class ActivationSettings(ez.Settings):
+     function: str = ActivationFunction.NONE
+
+
+ class Activation(GenAxisArray):
+     SETTINGS = ActivationSettings
+
+     def construct_generator(self):
+         self.STATE.gen = activation(function=self.SETTINGS.function)
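
The new `activation` generator can also be driven directly with `.send()`, outside of an ezmsg pipeline. A minimal sketch (not part of the diff; it assumes only the `AxisArray(data, dims=...)` constructor already used above, and the random data is illustrative):

import numpy as np
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.sigproc.activation import activation

# Build a primed generator; `function` accepts a string or an ActivationFunction member.
gen = activation(function="sigmoid")
msg_in = AxisArray(np.random.randn(10, 4), dims=["time", "ch"])
msg_out = gen.send(msg_in)  # data mapped through scipy.special.expit, shape unchanged
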
ezmsg/sigproc/affinetransform.py ADDED
@@ -0,0 +1,234 @@
+ from dataclasses import replace
+ import os
+ from pathlib import Path
+ import typing
+
+ import numpy as np
+ import numpy.typing as npt
+ import ezmsg.core as ez
+ from ezmsg.util.messages.axisarray import AxisArray
+ from ezmsg.util.generator import consumer
+
+ from .base import GenAxisArray
+
+
+ @consumer
+ def affine_transform(
+     weights: typing.Union[np.ndarray, str, Path],
+     axis: typing.Optional[str] = None,
+     right_multiply: bool = True,
+ ) -> typing.Generator[AxisArray, AxisArray, None]:
+     """
+     Perform affine transformations on streaming data.
+
+     Args:
+         weights: An array of weights or a path to a file with weights compatible with np.loadtxt.
+         axis: The name of the axis to apply the transformation to. Defaults to the leading (0th) axis in the array.
+         right_multiply: Set False to transpose the weights before applying.
+
+     Returns:
+         A primed generator object that yields an :obj:`AxisArray` object for every
+         :obj:`AxisArray` it receives via `send`.
+     """
+     msg_out = AxisArray(np.array([]), dims=[""])
+
+     # Check parameters
+     if isinstance(weights, str):
+         if weights == "passthrough":
+             weights = None
+         else:
+             weights = Path(os.path.abspath(os.path.expanduser(weights)))
+     if isinstance(weights, Path):
+         weights = np.loadtxt(weights, delimiter=",")
+     if not right_multiply:
+         weights = weights.T
+     if weights is not None:
+         weights = np.ascontiguousarray(weights)
+
+     # State variables
+     # New axis with transformed labels, if required
+     new_axis: typing.Optional[AxisArray.Axis] = None
+
+     # Reset if any of these change.
+     check_input = {"key": None}
+     # We assume key change catches labels change; we don't want to check labels every message
+     # We don't need to check if input size has changed because weights multiplication will fail if so.
+
+     while True:
+         msg_in: AxisArray = yield msg_out
+
+         if weights is None:
+             msg_out = msg_in
+             continue
+
+         axis = axis or msg_in.dims[-1]  # Note: Most nodes default to dim[0]
+         axis_idx = msg_in.get_axis_idx(axis)
+
+         b_reset = msg_in.key != check_input["key"]
+         if b_reset:
+             # First sample or key has changed. Reset the state.
+             check_input["key"] = msg_in.key
+             # Determine if we need to modify the transformed axis.
+             if (
+                 axis in msg_in.axes
+                 and hasattr(msg_in.axes[axis], "labels")
+                 and weights.shape[0] != weights.shape[1]
+             ):
+                 in_labels = msg_in.axes[axis].labels
+                 new_labels = []
+                 n_in = weights.shape[1 if right_multiply else 0]
+                 n_out = weights.shape[0 if right_multiply else 1]
+                 if len(in_labels) != n_in:
+                     # Something upstream did something it wasn't supposed to. We will drop the labels.
+                     ez.logger.warning(
+                         f"Received {len(in_labels)} for {n_in} inputs. Check upstream labels."
+                     )
+                 else:
+                     b_used_inputs = np.any(weights, axis=0 if right_multiply else 1)
+                     b_filled_outputs = np.any(weights, axis=1 if right_multiply else 0)
+                     if np.all(b_used_inputs) and np.all(b_filled_outputs):
+                         # All inputs are used and all outputs are used, but n_in != n_out.
+                         # Mapping cannot be determined.
+                         new_labels = []
+                     elif np.all(b_used_inputs):
+                         # Strange scenario: New outputs are filled with empty data.
+                         in_ix = 0
+                         new_labels = []
+                         for out_ix in range(n_out):
+                             if b_filled_outputs[out_ix]:
+                                 new_labels.append(in_labels[in_ix])
+                                 in_ix += 1
+                             else:
+                                 new_labels.append("")
+                     elif np.all(b_filled_outputs):
+                         # Transform is dropping some of the inputs.
+                         new_labels = np.array(in_labels)[b_used_inputs].tolist()
+                 new_axis = replace(msg_in.axes[axis], labels=new_labels)
+
+         data = msg_in.data
+
+         if data.shape[axis_idx] == (weights.shape[0] - 1):
+             # The weights are stacked A|B where A is the transform and B is a single row
+             # in the equation y = Ax + B. This supports NeuroKey's weights matrices.
+             sample_shape = data.shape[:axis_idx] + (1,) + data.shape[axis_idx + 1 :]
+             data = np.concatenate(
+                 (data, np.ones(sample_shape).astype(data.dtype)), axis=axis_idx
+             )
+
+         if axis_idx in [-1, len(msg_in.dims) - 1]:
+             data = np.matmul(data, weights)
+         else:
+             data = np.moveaxis(data, axis_idx, -1)
+             data = np.matmul(data, weights)
+             data = np.moveaxis(data, -1, axis_idx)
+
+         replace_kwargs = {"data": data}
+         if new_axis is not None:
+             replace_kwargs["axes"] = {**msg_in.axes, axis: new_axis}
+         msg_out = replace(msg_in, **replace_kwargs)
+
+
+ class AffineTransformSettings(ez.Settings):
+     """
+     Settings for :obj:`AffineTransform`.
+     See :obj:`affine_transform` for argument details.
+     """
+
+     weights: typing.Union[np.ndarray, str, Path]
+     axis: typing.Optional[str] = None
+     right_multiply: bool = True
+
+
+ class AffineTransform(GenAxisArray):
+     """:obj:`Unit` for :obj:`affine_transform`"""
+
+     SETTINGS = AffineTransformSettings
+
+     def construct_generator(self):
+         self.STATE.gen = affine_transform(
+             weights=self.SETTINGS.weights,
+             axis=self.SETTINGS.axis,
+             right_multiply=self.SETTINGS.right_multiply,
+         )
+
+
+ def zeros_for_noop(data: npt.NDArray, **ignore_kwargs) -> npt.NDArray:
+     return np.zeros_like(data)
+
+
+ @consumer
+ def common_rereference(
+     mode: str = "mean", axis: typing.Optional[str] = None, include_current: bool = True
+ ) -> typing.Generator[AxisArray, AxisArray, None]:
+     """
+     Perform common average referencing (CAR) on streaming data.
+
+     Args:
+         mode: The statistical mode to apply -- either "mean" or "median"
+         axis: The name of the axis to apply the transformation to.
+         include_current: Set False to exclude each channel from participating in the calculation of its reference.
+
+     Returns:
+         A primed generator object that yields an :obj:`AxisArray` object
+         for every :obj:`AxisArray` it receives via `send`.
+     """
+     msg_out = AxisArray(np.array([]), dims=[""])
+
+     if mode == "passthrough":
+         include_current = True
+
+     func = {"mean": np.mean, "median": np.median, "passthrough": zeros_for_noop}[mode]
+
+     while True:
+         msg_in: AxisArray = yield msg_out
+
+         if axis is None:
+             axis = msg_in.dims[-1]
+             axis_idx = -1
+         else:
+             axis_idx = msg_in.get_axis_idx(axis)
+
+         ref_data = func(msg_in.data, axis=axis_idx, keepdims=True)
+
+         if not include_current:
+             # Typical `CAR = x[0]/N + x[1]/N + ... x[i-1]/N + x[i]/N + x[i+1]/N + ... + x[N-1]/N`
+             # and is the same for all i, so it is calculated only once in `ref_data`.
+             # However, if we had excluded the current channel,
+             # then we would have omitted the contribution of the current channel:
+             # `CAR[i] = x[0]/(N-1) + x[1]/(N-1) + ... x[i-1]/(N-1) + x[i+1]/(N-1) + ... + x[N-1]/(N-1)`
+             # The majority of the calculation is the same as when the current channel is included;
+             # we need only rescale CAR so the divisor is `N-1` instead of `N`, then subtract the contribution
+             # from the current channel (i.e., `x[i] / (N-1)`)
+             # i.e., `CAR[i] = (N / (N-1)) * common_CAR - x[i]/(N-1)`
+             # We can use broadcasting subtraction instead of looping over channels.
+             N = msg_in.data.shape[axis_idx]
+             ref_data = (N / (N - 1)) * ref_data - msg_in.data / (N - 1)
+             # Side note: I profiled using affine_transform and it's about 30x slower than this implementation.
+
+         msg_out = replace(msg_in, data=msg_in.data - ref_data)
+
+
+ class CommonRereferenceSettings(ez.Settings):
+     """
+     Settings for :obj:`CommonRereference`
+     See :obj:`common_rereference` for argument details.
+     """
+
+     mode: str = "mean"
+     axis: typing.Optional[str] = None
+     include_current: bool = True
+
+
+ class CommonRereference(GenAxisArray):
+     """
+     :obj:`Unit` for :obj:`common_rereference`.
+     """
+
+     SETTINGS = CommonRereferenceSettings
+
+     def construct_generator(self):
+         self.STATE.gen = common_rereference(
+             mode=self.SETTINGS.mode,
+             axis=self.SETTINGS.axis,
+             include_current=self.SETTINGS.include_current,
+         )
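
A minimal sketch of the two new generators in this file chained by hand (not part of the diff; the weights matrix, channel count, and random data are illustrative only):

import numpy as np
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.sigproc.affinetransform import affine_transform, common_rereference

# Mix 4 input channels down to 2 outputs; right-multiply means data @ weights.
weights = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])  # (n_in, n_out)
xform = affine_transform(weights=weights, axis="ch")
# CAR where each channel is excluded from its own reference (the N/(N-1) rescaling above).
car = common_rereference(mode="mean", axis="ch", include_current=False)

msg = AxisArray(np.random.randn(100, 4), dims=["time", "ch"])
mixed = xform.send(msg)       # data shape (100, 4) -> (100, 2)
referenced = car.send(mixed)  # each channel minus the mean of the other channels
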
ezmsg/sigproc/aggregate.py ADDED
@@ -0,0 +1,158 @@
+ from dataclasses import replace
+ import typing
+
+ import numpy as np
+ import numpy.typing as npt
+ import ezmsg.core as ez
+ from ezmsg.util.generator import consumer
+ from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
+
+ from .spectral import OptionsEnum
+ from .base import GenAxisArray
+
+
+ class AggregationFunction(OptionsEnum):
+     """Enum for aggregation functions available to be used in :obj:`ranged_aggregate` operation."""
+
+     NONE = "None (all)"
+     MAX = "max"
+     MIN = "min"
+     MEAN = "mean"
+     MEDIAN = "median"
+     STD = "std"
+     NANMAX = "nanmax"
+     NANMIN = "nanmin"
+     NANMEAN = "nanmean"
+     NANMEDIAN = "nanmedian"
+     NANSTD = "nanstd"
+     ARGMIN = "argmin"
+     ARGMAX = "argmax"
+
+
+ AGGREGATORS = {
+     AggregationFunction.NONE: np.all,
+     AggregationFunction.MAX: np.max,
+     AggregationFunction.MIN: np.min,
+     AggregationFunction.MEAN: np.mean,
+     AggregationFunction.MEDIAN: np.median,
+     AggregationFunction.STD: np.std,
+     AggregationFunction.NANMAX: np.nanmax,
+     AggregationFunction.NANMIN: np.nanmin,
+     AggregationFunction.NANMEAN: np.nanmean,
+     AggregationFunction.NANMEDIAN: np.nanmedian,
+     AggregationFunction.NANSTD: np.nanstd,
+     AggregationFunction.ARGMIN: np.argmin,
+     AggregationFunction.ARGMAX: np.argmax,
+ }
+
+
+ @consumer
+ def ranged_aggregate(
+     axis: typing.Optional[str] = None,
+     bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = None,
+     operation: AggregationFunction = AggregationFunction.MEAN,
+ ):
+     """
+     Apply an aggregation operation over one or more bands.
+
+     Args:
+         axis: The name of the axis along which to apply the bands.
+         bands: [(band1_min, band1_max), (band2_min, band2_max), ...]
+             If not set then this acts as a passthrough node.
+         operation: :obj:`AggregationFunction` to apply to each band.
+
+     Returns:
+         A primed generator object ready to yield an AxisArray for each .send(axis_array)
+     """
+     msg_out = AxisArray(np.array([]), dims=[""])
+
+     # State variables
+     slices: typing.Optional[typing.List[typing.Tuple[typing.Any, ...]]] = None
+     out_axis: typing.Optional[AxisArray.Axis] = None
+     ax_vec: typing.Optional[npt.NDArray] = None
+
+     # Reset if any of these changes. Key not checked because continuity between chunks not required.
+     check_inputs = {"gain": None, "offset": None}
+
+     while True:
+         msg_in: AxisArray = yield msg_out
+         if bands is None:
+             msg_out = msg_in
+         else:
+             axis = axis or msg_in.dims[0]
+             target_axis = msg_in.get_axis(axis)
+
+             b_reset = target_axis.gain != check_inputs["gain"]
+             b_reset = b_reset or target_axis.offset != check_inputs["offset"]
+             if b_reset:
+                 check_inputs["gain"] = target_axis.gain
+                 check_inputs["offset"] = target_axis.offset
+
+                 # If the axis we are operating on has changed (e.g., "time" or "win" axis always changes),
+                 # or the key has changed, then recalculate slices.
+
+                 ax_idx = msg_in.get_axis_idx(axis)
+
+                 ax_vec = (
+                     target_axis.offset
+                     + np.arange(msg_in.data.shape[ax_idx]) * target_axis.gain
+                 )
+                 slices = []
+                 mids = []
+                 for start, stop in bands:
+                     inds = np.where(np.logical_and(ax_vec >= start, ax_vec <= stop))[0]
+                     mids.append(np.mean(inds) * target_axis.gain + target_axis.offset)
+                     slices.append(np.s_[inds[0] : inds[-1] + 1])
+                 out_ax_kwargs = {
+                     "unit": target_axis.unit,
+                     "offset": mids[0],
+                     "gain": (mids[1] - mids[0]) if len(mids) > 1 else 1.0,
+                 }
+                 if hasattr(target_axis, "labels"):
+                     out_ax_kwargs["labels"] = [f"{_[0]} - {_[1]}" for _ in bands]
+                 out_axis = replace(target_axis, **out_ax_kwargs)
+
+             agg_func = AGGREGATORS[operation]
+             out_data = [
+                 agg_func(slice_along_axis(msg_in.data, sl, axis=ax_idx), axis=ax_idx)
+                 for sl in slices
+             ]
+
+             msg_out = replace(
+                 msg_in,
+                 data=np.stack(out_data, axis=ax_idx),
+                 axes={**msg_in.axes, axis: out_axis},
+             )
+             if operation in [AggregationFunction.ARGMIN, AggregationFunction.ARGMAX]:
+                 # Convert indices returned by argmin/argmax into the value along the axis.
+                 out_data = []
+                 for sl_ix, sl in enumerate(slices):
+                     offsets = np.take(msg_out.data, [sl_ix], axis=ax_idx)
+                     out_data.append(ax_vec[sl][offsets])
+                 msg_out.data = np.concatenate(out_data, axis=ax_idx)
+
+
+ class RangedAggregateSettings(ez.Settings):
+     """
+     Settings for ``RangedAggregate``.
+     See :obj:`ranged_aggregate` for details.
+     """
+
+     axis: typing.Optional[str] = None
+     bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = None
+     operation: AggregationFunction = AggregationFunction.MEAN
+
+
+ class RangedAggregate(GenAxisArray):
+     """
+     Unit for :obj:`ranged_aggregate`
+     """
+
+     SETTINGS = RangedAggregateSettings
+
+     def construct_generator(self):
+         self.STATE.gen = ranged_aggregate(
+             axis=self.SETTINGS.axis,
+             bands=self.SETTINGS.bands,
+             operation=self.SETTINGS.operation,
+         )
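
A minimal sketch of `ranged_aggregate` applied to a toy spectrum (not part of the diff; it assumes the `AxisArray.Axis` unit/gain/offset fields referenced above, with 1 Hz frequency bins):

import numpy as np
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.sigproc.aggregate import ranged_aggregate, AggregationFunction

gen = ranged_aggregate(
    axis="freq",
    bands=[(8, 12), (18, 30)],
    operation=AggregationFunction.MEAN,
)
spectrum = AxisArray(
    np.random.rand(50, 4),  # 50 frequency bins x 4 channels
    dims=["freq", "ch"],
    axes={"freq": AxisArray.Axis(unit="Hz", gain=1.0, offset=0.0)},
)
banded = gen.send(spectrum)  # data shape (2, 4): one aggregate per band per channel
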
ezmsg/sigproc/bandpower.py ADDED
@@ -0,0 +1,74 @@
+ from dataclasses import field
+ import typing
+
+ import numpy as np
+ import ezmsg.core as ez
+ from ezmsg.util.messages.axisarray import AxisArray
+ from ezmsg.util.generator import consumer, compose
+
+ from .spectrogram import spectrogram, SpectrogramSettings
+ from .aggregate import ranged_aggregate, AggregationFunction
+ from .base import GenAxisArray
+
+
+ @consumer
+ def bandpower(
+     spectrogram_settings: SpectrogramSettings,
+     bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = [
+         (17, 30),
+         (70, 170),
+     ],
+ ) -> typing.Generator[AxisArray, AxisArray, None]:
+     """
+     Calculate the average spectral power in each band.
+
+     Args:
+         spectrogram_settings: Settings for spectrogram calculation.
+         bands: (min, max) tuples of band limits in Hz.
+
+     Returns:
+         A primed generator object ready to yield an AxisArray for each .send(axis_array)
+     """
+     msg_out = AxisArray(np.array([]), dims=[""])
+
+     f_spec = spectrogram(
+         window_dur=spectrogram_settings.window_dur,
+         window_shift=spectrogram_settings.window_shift,
+         window=spectrogram_settings.window,
+         transform=spectrogram_settings.transform,
+         output=spectrogram_settings.output,
+     )
+     f_agg = ranged_aggregate(
+         axis="freq", bands=bands, operation=AggregationFunction.MEAN
+     )
+     pipeline = compose(f_spec, f_agg)
+
+     while True:
+         msg_in: AxisArray = yield msg_out
+         msg_out = pipeline(msg_in)
+
+
+ class BandPowerSettings(ez.Settings):
+     """
+     Settings for ``BandPower``.
+     See :obj:`bandpower` for details.
+     """
+
+     spectrogram_settings: SpectrogramSettings = field(
+         default_factory=SpectrogramSettings
+     )
+     bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = field(
+         default_factory=lambda: [(17, 30), (70, 170)]
+     )
+
+
+ class BandPower(GenAxisArray):
+     """:obj:`Unit` for :obj:`bandpower`."""
+
+     SETTINGS = BandPowerSettings
+
+     def construct_generator(self):
+         self.STATE.gen = bandpower(
+             spectrogram_settings=self.SETTINGS.spectrogram_settings,
+             bands=self.SETTINGS.bands,
+         )
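
A sketch of `bandpower` end to end (not part of the diff). The `SpectrogramSettings` keyword arguments and the `AxisArray.Axis.TimeAxis` helper are assumptions based on the attributes referenced above, and depending on the windowing settings the first few sends may yield an empty result:

import numpy as np
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.sigproc.bandpower import bandpower
from ezmsg.sigproc.spectrogram import SpectrogramSettings

gen = bandpower(
    spectrogram_settings=SpectrogramSettings(window_dur=1.0, window_shift=0.5),
    bands=[(17, 30), (70, 170)],
)
msg = AxisArray(
    np.random.randn(2000, 8),  # 2 s at 1 kHz, 8 channels (illustrative)
    dims=["time", "ch"],
    axes={"time": AxisArray.Axis.TimeAxis(fs=1000.0)},
)
out = gen.send(msg)  # "freq" axis reduced to one mean-power value per band
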
ezmsg/sigproc/base.py ADDED
@@ -0,0 +1,38 @@
+ import traceback
+ import typing
+
+ import ezmsg.core as ez
+ from ezmsg.util.messages.axisarray import AxisArray
+ from ezmsg.util.generator import GenState
+
+
+ class GenAxisArray(ez.Unit):
+     STATE = GenState
+
+     INPUT_SIGNAL = ez.InputStream(AxisArray)
+     OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
+     INPUT_SETTINGS = ez.InputStream(ez.Settings)
+
+     async def initialize(self) -> None:
+         self.construct_generator()
+
+     # Method to be implemented by subclasses to construct the specific generator
+     def construct_generator(self):
+         raise NotImplementedError
+
+     @ez.subscriber(INPUT_SETTINGS)
+     async def on_settings(self, msg: ez.Settings) -> None:
+         self.apply_settings(msg)
+         self.construct_generator()
+
+     @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
+     @ez.publisher(OUTPUT_SIGNAL)
+     async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
+         try:
+             ret = self.STATE.gen.send(message)
+             if ret.data.size > 0:
+                 yield self.OUTPUT_SIGNAL, ret
+         except (StopIteration, GeneratorExit):
+             ez.logger.debug(f"Generator closed in {self.address}")
+         except Exception:
+             ez.logger.info(traceback.format_exc())
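
`GenAxisArray.on_signal` simply pushes each message through `self.STATE.gen.send(...)`, so the same generators wrapped by these units can also be chained offline with `compose` (the pattern `bandpower` uses internally). A minimal sketch, not part of the diff; the channel layout is illustrative:

import numpy as np
from ezmsg.util.generator import compose
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.sigproc.activation import activation
from ezmsg.sigproc.affinetransform import common_rereference

# Equivalent to routing a message through CommonRereference then Activation units.
pipeline = compose(
    common_rereference(mode="mean", axis="ch"),
    activation(function="sigmoid"),
)
msg = AxisArray(np.random.randn(100, 8), dims=["time", "ch"])
out = pipeline(msg)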