ezmsg-sigproc 1.2.3__py3-none-any.whl → 1.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ezmsg/sigproc/__init__.py CHANGED
@@ -1,4 +1 @@
1
- import importlib.metadata
2
-
3
-
4
- __version__ = importlib.metadata.version("ezmsg-sigproc")
1
+ from .__version__ import __version__ as __version__
ezmsg/sigproc/__version__.py ADDED
@@ -0,0 +1,16 @@
1
# file generated by setuptools_scm
# don't change, don't track in version control
# Holds the package version as a string and as a tuple; ezmsg/sigproc/__init__.py
# re-exports `__version__` from here.

# Module-local flag that is always False at runtime, so the `typing` import below is
# never executed when the module is imported.
# NOTE(review): presumably static type checkers treat this name like typing.TYPE_CHECKING
# and analyse the first branch to get the precise alias — confirm for the checker in use.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple, Union
    VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
    VERSION_TUPLE = object

# Annotation-only declarations; the values are bound below.
version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

# Each pair of names is bound to the same object.
__version__ = version = '1.3.2'
__version_tuple__ = version_tuple = (1, 3, 2)
@@ -0,0 +1,75 @@
1
+ from dataclasses import replace
2
+ import typing
3
+
4
+ import numpy as np
5
+ import scipy.special
6
+ import ezmsg.core as ez
7
+ from ezmsg.util.messages.axisarray import AxisArray
8
+ from ezmsg.util.generator import consumer
9
+
10
+ from .spectral import OptionsEnum
11
+ from .base import GenAxisArray
12
+
13
+
14
class ActivationFunction(OptionsEnum):
    """
    Activation (transformation) functions that can be selected by name.

    The lower-case string values are what string arguments to :obj:`activation`
    are matched against.
    """

    NONE = "none"
    """Identity: the data is returned unchanged."""

    SIGMOID = "sigmoid"
    """Logistic sigmoid, :obj:`scipy.special.expit`."""

    EXPIT = "expit"
    """Same function as ``SIGMOID``: :obj:`scipy.special.expit`."""

    LOGIT = "logit"
    """Inverse of the logistic sigmoid, :obj:`scipy.special.logit`."""

    LOGEXPIT = "log_expit"
    """Log of the logistic sigmoid, :obj:`scipy.special.log_expit`."""
31
+
32
+
33
# Callable that implements each ActivationFunction member.
# SIGMOID and EXPIT map to the same scipy function.
ACTIVATIONS = dict(
    [
        (ActivationFunction.NONE, lambda x: x),
        (ActivationFunction.SIGMOID, scipy.special.expit),
        (ActivationFunction.EXPIT, scipy.special.expit),
        (ActivationFunction.LOGIT, scipy.special.logit),
        (ActivationFunction.LOGEXPIT, scipy.special.log_expit),
    ]
)
40
+
41
+
42
@consumer
def activation(
    function: typing.Union[str, ActivationFunction],
) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Apply an activation (transformation) function to streaming data.

    Args:
        function: An :obj:`ActivationFunction` member, or the string value of one
            (matched case-insensitively), e.g. ``"sigmoid"`` or ``"log_expit"``.

    Returns:
        A primed generator object that yields an :obj:`AxisArray` object for every
        :obj:`AxisArray` it receives via `send`, with ``data`` replaced by the
        function applied to the incoming data.

    Raises:
        ValueError: If ``function`` is a string that is not one of
            ``ActivationFunction.options()``.
    """
    if isinstance(function, ActivationFunction):
        func_enum = function
    else:
        func_name = function.lower()
        if func_name not in ActivationFunction.options():
            raise ValueError(
                f"Unrecognized activation function {func_name}. Must be one of {ActivationFunction.options()}"
            )
        # Look the member up by its value; this does not depend on ACTIVATIONS'
        # insertion order matching the enum's declaration order.
        func_enum = ActivationFunction(func_name)
    func = ACTIVATIONS[func_enum]

    msg_out = AxisArray(np.array([]), dims=[""])

    while True:
        msg_in: AxisArray = yield msg_out
        msg_out = replace(msg_in, data=func(msg_in.data))
65
+
66
+
67
class ActivationSettings(ez.Settings):
    """
    Settings for :obj:`Activation`.
    See :obj:`activation` for argument details.
    """

    # Either an ActivationFunction member or its string value. The default is a
    # member, so the annotation admits both types.
    function: typing.Union[str, ActivationFunction] = ActivationFunction.NONE
69
+
70
+
71
class Activation(GenAxisArray):
    """:obj:`Unit` wrapping the :obj:`activation` generator."""

    SETTINGS = ActivationSettings

    def construct_generator(self):
        # (Re)build the generator from the current settings.
        selected = self.SETTINGS.function
        self.STATE.gen = activation(function=selected)
@@ -1,65 +1,148 @@
1
1
  from dataclasses import replace
2
2
  import os
3
3
  from pathlib import Path
4
- from typing import Generator, Optional, Union
4
+ import typing
5
5
 
6
6
  import numpy as np
7
+ import numpy.typing as npt
7
8
  import ezmsg.core as ez
8
9
  from ezmsg.util.messages.axisarray import AxisArray
9
- from ezmsg.util.generator import consumer, GenAxisArray
10
+ from ezmsg.util.generator import consumer
11
+
12
+ from .base import GenAxisArray
10
13
 
11
14
 
12
15
def _transformed_labels(
    in_labels: typing.Sequence[typing.Any], weights: npt.NDArray
) -> typing.List[typing.Any]:
    """
    Infer output-channel labels for the product ``data @ weights``.

    Args:
        in_labels: One label per input channel.
        weights: Right-multiply weights of shape (n_in, n_out), excluding any bias row.

    Returns:
        One label per output channel, using "" for outputs that receive no input.
        Returns an empty list (i.e. labels are dropped) when the input-to-output
        mapping cannot be inferred.
    """
    n_in, n_out = weights.shape
    if len(in_labels) != n_in:
        # Something upstream did something it wasn't supposed to. We will drop the labels.
        ez.logger.warning(
            f"Received {len(in_labels)} labels for {n_in} inputs. Check upstream labels."
        )
        return []

    # For data @ weights, row i holds input i's contributions and column j feeds output j.
    b_used_inputs = np.any(weights, axis=1)
    b_filled_outputs = np.any(weights, axis=0)

    if np.all(b_used_inputs) and np.all(b_filled_outputs):
        # All inputs and all outputs are used, but n_in != n_out.
        # Mapping cannot be determined.
        return []
    if np.all(b_used_inputs) and np.count_nonzero(b_filled_outputs) == n_in:
        # Strange scenario: new outputs are filled with empty data. Hand the input
        # labels, in order, to the filled outputs and "" to the empty ones.
        in_iter = iter(in_labels)
        return [next(in_iter) if filled else "" for filled in b_filled_outputs]
    if np.all(b_filled_outputs) and np.count_nonzero(b_used_inputs) == n_out:
        # Transform is dropping some of the inputs; keep the labels of those still used.
        return [lbl for lbl, used in zip(in_labels, b_used_inputs) if used]
    return []


@consumer
def affine_transform(
    weights: typing.Union[np.ndarray, str, Path],
    axis: typing.Optional[str] = None,
    right_multiply: bool = True,
) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Perform affine transformations on streaming data.

    Args:
        weights: An array of weights, a path to a comma-delimited weights file readable
            by np.loadtxt, or the string "passthrough" to forward messages unchanged.
        axis: The name of the axis to apply the transformation to. Defaults to the last
            dimension of the first message processed.
        right_multiply: Set False to transpose the weights before applying.

    Returns:
        A primed generator object that yields an :obj:`AxisArray` object for every
        :obj:`AxisArray` it receives via `send`.
    """
    msg_out = AxisArray(np.array([]), dims=[""])

    # Check parameters
    if isinstance(weights, str):
        if weights == "passthrough":
            weights = None
        else:
            weights = Path(os.path.abspath(os.path.expanduser(weights)))
    if isinstance(weights, Path):
        weights = np.loadtxt(weights, delimiter=",")
    if weights is not None:
        # "passthrough" (None) has nothing to transpose, so only real weights get here.
        if not right_multiply:
            weights = weights.T
        # From here on the weights are in right-multiply form: out = data @ weights.
        weights = np.ascontiguousarray(weights)

    # State variables
    # New axis with transformed labels, if required
    new_axis: typing.Optional[AxisArray.Axis] = None

    # Reset if any of these change.
    check_input = {"key": None}
    # We assume key change catches labels change; we don't want to check labels every message
    # We don't need to check if input size has changed because weights multiplication will fail if so.

    while True:
        msg_in: AxisArray = yield msg_out

        if weights is None:
            msg_out = msg_in
            continue

        axis = axis or msg_in.dims[-1]  # Note: Most nodes default to dim[0]
        axis_idx = msg_in.get_axis_idx(axis)
        data = msg_in.data

        # The weights are stacked A|B where A is the transform and B is a single row
        # in the equation y = Ax + B. This supports NeuroKey's weights matrices.
        b_augmented = data.shape[axis_idx] == (weights.shape[0] - 1)

        if msg_in.key != check_input["key"]:
            # First sample or key has changed. Reset the state.
            check_input["key"] = msg_in.key
            # Forget any relabelled axis computed for the previous key.
            new_axis = None
            # The bias row is not an input channel, so exclude it when mapping labels.
            lbl_weights = weights[:-1] if b_augmented else weights
            if (
                axis in msg_in.axes
                and hasattr(msg_in.axes[axis], "labels")
                and lbl_weights.ndim == 2
                and lbl_weights.shape[0] != lbl_weights.shape[1]
            ):
                new_axis = replace(
                    msg_in.axes[axis],
                    labels=_transformed_labels(msg_in.axes[axis].labels, lbl_weights),
                )

        if b_augmented:
            # Append a slab of ones so the bias row B is added to every sample.
            sample_shape = data.shape[:axis_idx] + (1,) + data.shape[axis_idx + 1 :]
            data = np.concatenate(
                (data, np.ones(sample_shape).astype(data.dtype)), axis=axis_idx
            )

        if axis_idx in [-1, len(msg_in.dims) - 1]:
            data = np.matmul(data, weights)
        else:
            # matmul contracts over the last dimension: move the target axis there and back.
            data = np.moveaxis(data, axis_idx, -1)
            data = np.matmul(data, weights)
            data = np.moveaxis(data, -1, axis_idx)

        replace_kwargs = {"data": data}
        if new_axis is not None:
            replace_kwargs["axes"] = {**msg_in.axes, axis: new_axis}
        msg_out = replace(msg_in, **replace_kwargs)
53
129
 
54
130
 
55
131
class AffineTransformSettings(ez.Settings):
    """
    Configuration for :obj:`AffineTransform`.
    The fields mirror the arguments of :obj:`affine_transform`; see it for details.
    """

    # Weight matrix, path to a comma-delimited weights file, or "passthrough".
    weights: typing.Union[np.ndarray, str, Path]
    # Axis to transform; None lets the generator pick the message's last dimension.
    axis: typing.Optional[str] = None
    # False transposes the weights before they are applied.
    right_multiply: bool = True
59
140
 
60
141
 
61
142
  class AffineTransform(GenAxisArray):
62
- SETTINGS: AffineTransformSettings
143
+ """:obj:`Unit` for :obj:`affine_transform`"""
144
+
145
+ SETTINGS = AffineTransformSettings
63
146
 
64
147
  def construct_generator(self):
65
148
  self.STATE.gen = affine_transform(
@@ -69,25 +152,43 @@ class AffineTransform(GenAxisArray):
69
152
  )
70
153
 
71
154
 
155
def zeros_for_noop(data: npt.NDArray, **ignore_kwargs) -> npt.NDArray:
    """
    Stand-in reference function for the "passthrough" re-referencing mode.

    Returns an all-zero array with the same shape and dtype as ``data``, so that
    subtracting it leaves the input unchanged. Keyword arguments such as ``axis``
    and ``keepdims`` are accepted only to match the call signature used for
    ``np.mean``/``np.median``, and are ignored.
    """
    reference = np.zeros_like(data)
    return reference
157
+
158
+
72
159
  @consumer
73
160
  def common_rereference(
74
- mode: str = "mean", axis: Optional[str] = None, include_current: bool = True
75
- ) -> Generator[AxisArray, AxisArray, None]:
76
- axis_arr_in = AxisArray(np.array([]), dims=[""])
77
- axis_arr_out = AxisArray(np.array([]), dims=[""])
161
+ mode: str = "mean", axis: typing.Optional[str] = None, include_current: bool = True
162
+ ) -> typing.Generator[AxisArray, AxisArray, None]:
163
+ """
164
+ Perform common average referencing (CAR) on streaming data.
165
+
166
+ Args:
167
+ mode: The statistical mode to apply -- either "mean" or "median"
168
+ axis: The name of hte axis to apply the transformation to.
169
+ include_current: Set False to exclude each channel from participating in the calculation of its reference.
78
170
 
79
- func = {"mean": np.mean, "median": np.median}[mode]
171
+ Returns:
172
+ A primed generator object that yields an :obj:`AxisArray` object
173
+ for every :obj:`AxisArray` it receives via `send`.
174
+ """
175
+ msg_out = AxisArray(np.array([]), dims=[""])
176
+
177
+ if mode == "passthrough":
178
+ include_current = True
179
+
180
+ func = {"mean": np.mean, "median": np.median, "passthrough": zeros_for_noop}[mode]
80
181
 
81
182
  while True:
82
- axis_arr_in = yield axis_arr_out
183
+ msg_in: AxisArray = yield msg_out
83
184
 
84
185
  if axis is None:
85
- axis = axis_arr_in.dims[-1]
186
+ axis = msg_in.dims[-1]
86
187
  axis_idx = -1
87
188
  else:
88
- axis_idx = axis_arr_in.get_axis_idx(axis)
189
+ axis_idx = msg_in.get_axis_idx(axis)
89
190
 
90
- ref_data = func(axis_arr_in.data, axis=axis_idx, keepdims=True)
191
+ ref_data = func(msg_in.data, axis=axis_idx, keepdims=True)
91
192
 
92
193
  if not include_current:
93
194
  # Typical `CAR = x[0]/N + x[1]/N + ... x[i-1]/N + x[i]/N + x[i+1]/N + ... + x[N-1]/N`
@@ -100,21 +201,30 @@ def common_rereference(
100
201
  # from the current channel (i.e., `x[i] / (N-1)`)
101
202
  # i.e., `CAR[i] = (N / (N-1)) * common_CAR - x[i]/(N-1)`
102
203
  # We can use broadcasting subtraction instead of looping over channels.
103
- N = axis_arr_in.data.shape[axis_idx]
104
- ref_data = (N / (N - 1)) * ref_data - axis_arr_in.data / (N - 1)
204
+ N = msg_in.data.shape[axis_idx]
205
+ ref_data = (N / (N - 1)) * ref_data - msg_in.data / (N - 1)
105
206
  # Side note: I profiled using affine_transform and it's about 30x slower than this implementation.
106
207
 
107
- axis_arr_out = replace(axis_arr_in, data=axis_arr_in.data - ref_data)
208
+ msg_out = replace(msg_in, data=msg_in.data - ref_data)
108
209
 
109
210
 
110
211
class CommonRereferenceSettings(ez.Settings):
    """
    Configuration for :obj:`CommonRereference`.
    The fields mirror the arguments of :obj:`common_rereference`; see it for details.
    """

    # Reference statistic: "mean", "median", or "passthrough".
    mode: str = "mean"
    # Axis to re-reference across; None selects the message's last dimension.
    axis: typing.Optional[str] = None
    # False excludes each channel from the computation of its own reference.
    include_current: bool = True
114
220
 
115
221
 
116
222
  class CommonRereference(GenAxisArray):
117
- SETTINGS: CommonRereferenceSettings
223
+ """
224
+ :obj:`Unit` for :obj:`common_rereference`.
225
+ """
226
+
227
+ SETTINGS = CommonRereferenceSettings
118
228
 
119
229
  def construct_generator(self):
120
230
  self.STATE.gen = common_rereference(
@@ -2,13 +2,18 @@ from dataclasses import replace
2
2
  import typing
3
3
 
4
4
  import numpy as np
5
+ import numpy.typing as npt
5
6
  import ezmsg.core as ez
6
- from ezmsg.util.generator import consumer, GenAxisArray
7
+ from ezmsg.util.generator import consumer
7
8
  from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
8
- from ezmsg.sigproc.spectral import OptionsEnum
9
+
10
+ from .spectral import OptionsEnum
11
+ from .base import GenAxisArray
9
12
 
10
13
 
11
14
  class AggregationFunction(OptionsEnum):
15
+ """Enum for aggregation functions available to be used in :obj:`ranged_aggregate` operation."""
16
+
12
17
  NONE = "None (all)"
13
18
  MAX = "max"
14
19
  MIN = "min"
@@ -20,6 +25,8 @@ class AggregationFunction(OptionsEnum):
20
25
  NANMEAN = "nanmean"
21
26
  NANMEDIAN = "nanmedian"
22
27
  NANSTD = "nanstd"
28
+ ARGMIN = "argmin"
29
+ ARGMAX = "argmax"
23
30
 
24
31
 
25
32
  AGGREGATORS = {
@@ -33,7 +40,9 @@ AGGREGATORS = {
33
40
  AggregationFunction.NANMIN: np.nanmin,
34
41
  AggregationFunction.NANMEAN: np.nanmean,
35
42
  AggregationFunction.NANMEDIAN: np.nanmedian,
36
- AggregationFunction.NANSTD: np.nanstd
43
+ AggregationFunction.NANSTD: np.nanstd,
44
+ AggregationFunction.ARGMIN: np.argmin,
45
+ AggregationFunction.ARGMAX: np.argmax,
37
46
  }
38
47
 
39
48
 
@@ -41,63 +50,109 @@ AGGREGATORS = {
41
50
def ranged_aggregate(
    axis: typing.Optional[str] = None,
    bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = None,
    operation: AggregationFunction = AggregationFunction.MEAN,
):
    """
    Apply an aggregation operation over one or more bands.

    Args:
        axis: The name of the axis along which to apply the bands. Defaults to the
            first dimension of the first message received.
        bands: [(band1_min, band1_max), (band2_min, band2_max), ...]
            Limits are inclusive and expressed in axis units (offset + index * gain).
            If not set then this acts as a passthrough node.
        operation: :obj:`AggregationFunction` to apply to each band. For ARGMIN and
            ARGMAX the output is the axis value at the extremum, not its index.

    Returns:
        A primed generator object ready to yield an AxisArray for each .send(axis_array)
    """
    msg_out = AxisArray(np.array([]), dims=[""])

    # State variables
    slices: typing.Optional[typing.List[typing.Tuple[typing.Any, ...]]] = None
    out_axis: typing.Optional[AxisArray.Axis] = None
    ax_vec: typing.Optional[npt.NDArray] = None

    # Reset if any of these change. Key not checked because continuity between chunks not required.
    # The axis length is tracked as well: ax_vec, and therefore the band slices, depend on it.
    check_inputs = {"gain": None, "offset": None, "len": None}

    while True:
        msg_in: AxisArray = yield msg_out

        if bands is None:
            msg_out = msg_in
            continue

        axis = axis or msg_in.dims[0]
        target_axis = msg_in.get_axis(axis)
        ax_idx = msg_in.get_axis_idx(axis)
        ax_len = msg_in.data.shape[ax_idx]

        b_reset = (
            target_axis.gain != check_inputs["gain"]
            or target_axis.offset != check_inputs["offset"]
            or ax_len != check_inputs["len"]
        )
        if b_reset:
            check_inputs["gain"] = target_axis.gain
            check_inputs["offset"] = target_axis.offset
            check_inputs["len"] = ax_len

            # Axis value at every sample position along the target axis.
            ax_vec = target_axis.offset + np.arange(ax_len) * target_axis.gain
            slices = []
            mids = []
            for start, stop in bands:
                inds = np.where(np.logical_and(ax_vec >= start, ax_vec <= stop))[0]
                mids.append(np.mean(inds) * target_axis.gain + target_axis.offset)
                slices.append(np.s_[inds[0] : inds[-1] + 1])
            # The output axis is linear. Its gain is the spacing of the first two band
            # centres, which only describes later bands exactly if they are evenly spaced.
            out_ax_kwargs = {
                "unit": target_axis.unit,
                "offset": mids[0],
                "gain": (mids[1] - mids[0]) if len(mids) > 1 else 1.0,
            }
            if hasattr(target_axis, "labels"):
                out_ax_kwargs["labels"] = [f"{lo} - {hi}" for lo, hi in bands]
            out_axis = replace(target_axis, **out_ax_kwargs)

        agg_func = AGGREGATORS[operation]
        out_data = [
            agg_func(slice_along_axis(msg_in.data, sl, axis=ax_idx), axis=ax_idx)
            for sl in slices
        ]
        if operation in (AggregationFunction.ARGMIN, AggregationFunction.ARGMAX):
            # argmin/argmax return indices relative to each band's slice; convert
            # them to the corresponding values along the axis.
            out_data = [ax_vec[sl][ix] for sl, ix in zip(slices, out_data)]

        msg_out = replace(
            msg_in,
            data=np.stack(out_data, axis=ax_idx),
            axes={**msg_in.axes, axis: out_axis},
        )
87
133
 
88
134
 
89
135
class RangedAggregateSettings(ez.Settings):
    """
    Configuration for ``RangedAggregate``.
    The fields mirror the arguments of :obj:`ranged_aggregate`; see it for details.
    """

    # Axis to aggregate along; None selects the incoming message's first dimension.
    axis: typing.Optional[str] = None
    # (min, max) limits, one pair per output band; None makes the unit a passthrough.
    bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = None
    # Aggregation applied within each band.
    operation: AggregationFunction = AggregationFunction.MEAN
93
144
 
94
145
 
95
146
class RangedAggregate(GenAxisArray):
    """:obj:`Unit` wrapping the :obj:`ranged_aggregate` generator."""

    SETTINGS = RangedAggregateSettings

    def construct_generator(self):
        # (Re)build the generator from the current settings.
        cfg = self.SETTINGS
        self.STATE.gen = ranged_aggregate(
            axis=cfg.axis,
            bands=cfg.bands,
            operation=cfg.operation,
        )
@@ -4,50 +4,71 @@ import typing
4
4
  import numpy as np
5
5
  import ezmsg.core as ez
6
6
  from ezmsg.util.messages.axisarray import AxisArray
7
- from ezmsg.util.generator import consumer, compose, GenAxisArray
7
+ from ezmsg.util.generator import consumer, compose
8
8
 
9
9
  from .spectrogram import spectrogram, SpectrogramSettings
10
10
  from .aggregate import ranged_aggregate, AggregationFunction
11
+ from .base import GenAxisArray
11
12
 
12
13
 
13
14
@consumer
def bandpower(
    spectrogram_settings: SpectrogramSettings,
    bands: typing.Optional[typing.Sequence[typing.Tuple[float, float]]] = (
        (17, 30),
        (70, 170),
    ),
) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Calculate the average spectral power in each band.

    Args:
        spectrogram_settings: Settings for spectrogram calculation.
        bands: (min, max) tuples of band limits in Hz. Defaults to (17, 30) and
            (70, 170); an immutable default avoids sharing one list across calls.
            ``None`` disables band aggregation, so the spectrogram is passed through.

    Returns:
        A primed generator object ready to yield an AxisArray for each .send(axis_array)
    """
    msg_out = AxisArray(np.array([]), dims=[""])

    f_spec = spectrogram(
        window_dur=spectrogram_settings.window_dur,
        window_shift=spectrogram_settings.window_shift,
        window=spectrogram_settings.window,
        transform=spectrogram_settings.transform,
        output=spectrogram_settings.output,
    )
    # Average the spectrogram within each band along its "freq" axis.
    f_agg = ranged_aggregate(
        axis="freq", bands=bands, operation=AggregationFunction.MEAN
    )
    pipeline = compose(f_spec, f_agg)

    while True:
        msg_in: AxisArray = yield msg_out
        msg_out = pipeline(msg_in)
38
49
 
39
50
 
40
51
class BandPowerSettings(ez.Settings):
    """
    Configuration for ``BandPower``.
    The fields mirror the arguments of :obj:`bandpower`; see it for details.
    """

    # Window / transform parameters for the underlying spectrogram.
    spectrogram_settings: SpectrogramSettings = field(default_factory=SpectrogramSettings)
    # (min, max) band limits; the factory builds a fresh list for every instance.
    bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = field(
        default_factory=lambda: [(17, 30), (70, 170)]
    )
44
63
 
45
64
 
46
65
class BandPower(GenAxisArray):
    """:obj:`Unit` wrapping the :obj:`bandpower` generator."""

    SETTINGS = BandPowerSettings

    def construct_generator(self):
        # (Re)build the generator from the current settings.
        cfg = self.SETTINGS
        self.STATE.gen = bandpower(
            spectrogram_settings=cfg.spectrogram_settings,
            bands=cfg.bands,
        )
ezmsg/sigproc/base.py ADDED
@@ -0,0 +1,38 @@
1
+ import traceback
2
+ import typing
3
+
4
+ import ezmsg.core as ez
5
+ from ezmsg.util.messages.axisarray import AxisArray
6
+ from ezmsg.util.generator import GenState
7
+
8
+
9
class GenAxisArray(ez.Unit):
    """
    Base :obj:`Unit` that pushes each incoming :obj:`AxisArray` through a generator.

    Subclasses implement :meth:`construct_generator`, which must assign to
    ``self.STATE.gen`` a generator that accepts :obj:`AxisArray` messages via ``send``
    and returns :obj:`AxisArray` results. The generator is built in :meth:`initialize`
    and rebuilt whenever new settings arrive on ``INPUT_SETTINGS``.
    """

    STATE = GenState

    INPUT_SIGNAL = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
    INPUT_SETTINGS = ez.InputStream(ez.Settings)

    async def initialize(self) -> None:
        self.construct_generator()

    # Method to be implemented by subclasses to construct the specific generator
    def construct_generator(self):
        raise NotImplementedError

    @ez.subscriber(INPUT_SETTINGS)
    async def on_settings(self, msg: ez.Settings) -> None:
        # Apply the new settings, then rebuild the generator so they take effect.
        self.apply_settings(msg)
        self.construct_generator()

    @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(OUTPUT_SIGNAL)
    async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
        # NOTE(review): with zero_copy=True `message` may be shared with other
        # subscribers, so generators presumably must not mutate it in place — confirm.
        try:
            ret = self.STATE.gen.send(message)
            # Results with empty data are not published.
            if ret.data.size > 0:
                yield self.OUTPUT_SIGNAL, ret
        except (StopIteration, GeneratorExit):
            ez.logger.debug(f"Generator closed in {self.address}")
        except Exception:
            # Keep the unit alive after a bad message, but report the failure at ERROR
            # level: at INFO it is usually filtered out and processing errors go unseen.
            ez.logger.error(traceback.format_exc())