ezmsg-sigproc 1.2.3__py3-none-any.whl → 1.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__init__.py +1 -4
- ezmsg/sigproc/__version__.py +16 -0
- ezmsg/sigproc/activation.py +75 -0
- ezmsg/sigproc/affinetransform.py +149 -39
- ezmsg/sigproc/aggregate.py +84 -29
- ezmsg/sigproc/bandpower.py +36 -15
- ezmsg/sigproc/base.py +38 -0
- ezmsg/sigproc/butterworthfilter.py +76 -20
- ezmsg/sigproc/decimate.py +7 -4
- ezmsg/sigproc/downsample.py +79 -61
- ezmsg/sigproc/ewmfilter.py +28 -14
- ezmsg/sigproc/filter.py +51 -31
- ezmsg/sigproc/filterbank.py +278 -0
- ezmsg/sigproc/math/__init__.py +0 -0
- ezmsg/sigproc/math/abs.py +28 -0
- ezmsg/sigproc/math/clip.py +30 -0
- ezmsg/sigproc/math/difference.py +60 -0
- ezmsg/sigproc/math/invert.py +29 -0
- ezmsg/sigproc/math/log.py +32 -0
- ezmsg/sigproc/math/scale.py +31 -0
- ezmsg/sigproc/messages.py +2 -3
- ezmsg/sigproc/sampler.py +152 -90
- ezmsg/sigproc/scaler.py +88 -42
- ezmsg/sigproc/signalinjector.py +7 -10
- ezmsg/sigproc/slicer.py +71 -36
- ezmsg/sigproc/spectral.py +6 -9
- ezmsg/sigproc/spectrogram.py +48 -30
- ezmsg/sigproc/spectrum.py +177 -76
- ezmsg/sigproc/synth.py +162 -67
- ezmsg/sigproc/wavelets.py +167 -0
- ezmsg/sigproc/window.py +193 -157
- ezmsg_sigproc-1.3.1.dist-info/METADATA +59 -0
- ezmsg_sigproc-1.3.1.dist-info/RECORD +35 -0
- {ezmsg_sigproc-1.2.3.dist-info → ezmsg_sigproc-1.3.1.dist-info}/WHEEL +1 -1
- ezmsg_sigproc-1.2.3.dist-info/METADATA +0 -38
- ezmsg_sigproc-1.2.3.dist-info/RECORD +0 -23
- {ezmsg_sigproc-1.2.3.dist-info → ezmsg_sigproc-1.3.1.dist-info/licenses}/LICENSE.txt +0 -0
ezmsg/sigproc/__init__.py
CHANGED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# file generated by setuptools_scm
# don't change, don't track in version control

# Cheap stand-in for typing.TYPE_CHECKING: keeps the module free of runtime
# imports while still giving static type checkers the real alias below.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple, Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
    VERSION_TUPLE = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

__version__ = version = '1.3.1'
__version_tuple__ = version_tuple = (1, 3, 1)
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
from dataclasses import replace
|
|
2
|
+
import typing
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import scipy.special
|
|
6
|
+
import ezmsg.core as ez
|
|
7
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
8
|
+
from ezmsg.util.generator import consumer
|
|
9
|
+
|
|
10
|
+
from .spectral import OptionsEnum
|
|
11
|
+
from .base import GenAxisArray
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ActivationFunction(OptionsEnum):
    """Activation (transformation) function applied elementwise by :obj:`activation`."""

    # Each member's value is the lowercase string accepted by activation(function=...).
    NONE = "none"
    """None."""

    SIGMOID = "sigmoid"
    """:obj:`scipy.special.expit`"""

    EXPIT = "expit"
    """:obj:`scipy.special.expit`"""

    LOGIT = "logit"
    """:obj:`scipy.special.logit`"""

    LOGEXPIT = "log_expit"
    """:obj:`scipy.special.log_expit`"""
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# Maps each ActivationFunction member to its elementwise callable.
# SIGMOID and EXPIT deliberately share scipy.special.expit (same function, two names).
ACTIVATIONS = {
    ActivationFunction.NONE: lambda x: x,
    ActivationFunction.SIGMOID: scipy.special.expit,
    ActivationFunction.EXPIT: scipy.special.expit,
    ActivationFunction.LOGIT: scipy.special.logit,
    ActivationFunction.LOGEXPIT: scipy.special.log_expit,
}
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@consumer
def activation(
    function: typing.Union[str, ActivationFunction],
) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Apply an activation (transformation) function elementwise to streaming data.

    Args:
        function: An :obj:`ActivationFunction` member, or its string value
            (case-insensitive), selecting the function to apply.

    Raises:
        ValueError: If ``function`` is a string that does not match any
            :obj:`ActivationFunction` option.

    Returns:
        A primed generator object that yields an :obj:`AxisArray` with
        transformed data for every :obj:`AxisArray` it receives via `send`.
    """
    if not isinstance(function, ActivationFunction):
        # str argument: map the (case-insensitive) option string back to its enum member.
        function = function.lower()
        if function not in ActivationFunction.options():
            # Bug fix: report the accepted option strings, not ACTIVATIONS.keys()
            # (which would print enum member reprs rather than valid inputs).
            raise ValueError(
                f"Unrecognized activation function {function}. "
                f"Must be one of {ActivationFunction.options()}"
            )
        function = list(ACTIVATIONS.keys())[
            ActivationFunction.options().index(function)
        ]
    func = ACTIVATIONS[function]

    # Placeholder primes the generator; replaced on the first `send`.
    msg_out = AxisArray(np.array([]), dims=[""])

    while True:
        msg_in: AxisArray = yield msg_out
        msg_out = replace(msg_in, data=func(msg_in.data))
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class ActivationSettings(ez.Settings):
    """Settings for :obj:`Activation`. See :obj:`activation` for argument details."""

    # Name of the activation function to apply.
    # NOTE(review): annotated `str` but the default is an ActivationFunction
    # member — activation() accepts both, but confirm the annotation is intended.
    function: str = ActivationFunction.NONE
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class Activation(GenAxisArray):
    """:obj:`Unit` wrapping the :obj:`activation` generator."""

    SETTINGS = ActivationSettings

    def construct_generator(self):
        # Build a freshly-primed generator from the current settings.
        self.STATE.gen = activation(function=self.SETTINGS.function)
|
ezmsg/sigproc/affinetransform.py
CHANGED
|
@@ -1,65 +1,148 @@
|
|
|
1
1
|
from dataclasses import replace
|
|
2
2
|
import os
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
|
|
4
|
+
import typing
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
|
+
import numpy.typing as npt
|
|
7
8
|
import ezmsg.core as ez
|
|
8
9
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
9
|
-
from ezmsg.util.generator import consumer
|
|
10
|
+
from ezmsg.util.generator import consumer
|
|
11
|
+
|
|
12
|
+
from .base import GenAxisArray
|
|
10
13
|
|
|
11
14
|
|
|
12
15
|
@consumer
def affine_transform(
    weights: typing.Union[np.ndarray, str, Path],
    axis: typing.Optional[str] = None,
    right_multiply: bool = True,
) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Perform affine transformations on streaming data.

    Args:
        weights: An array of weights, a path to a file with weights compatible with
            np.loadtxt, or the string "passthrough" to forward messages unchanged.
        axis: The name of the axis to apply the transformation to.
            Defaults to the trailing (last) axis in the array.
        right_multiply: Set False to transpose the weights before applying.

    Returns:
        A primed generator object that yields an :obj:`AxisArray` object for every
        :obj:`AxisArray` it receives via `send`.
    """
    msg_out = AxisArray(np.array([]), dims=[""])

    # Check parameters
    if isinstance(weights, str):
        if weights == "passthrough":
            weights = None
        else:
            weights = Path(os.path.abspath(os.path.expanduser(weights)))
    if isinstance(weights, Path):
        weights = np.loadtxt(weights, delimiter=",")
    # Bug fix: guard against weights being None ("passthrough") -- previously
    # `weights.T` raised AttributeError when combined with right_multiply=False.
    if weights is not None:
        if not right_multiply:
            weights = weights.T
        weights = np.ascontiguousarray(weights)

    # State variables
    # New axis with transformed labels, if required
    new_axis: typing.Optional[AxisArray.Axis] = None

    # Reset if any of these change.
    check_input = {"key": None}
    # We assume key change catches labels change; we don't want to check labels every message.
    # We don't need to check if input size has changed because weights multiplication will fail if so.

    while True:
        msg_in: AxisArray = yield msg_out

        if weights is None:
            # Passthrough mode: forward the message untouched.
            msg_out = msg_in
            continue

        axis = axis or msg_in.dims[-1]  # Note: Most nodes default to dim[0]
        axis_idx = msg_in.get_axis_idx(axis)

        b_reset = msg_in.key != check_input["key"]
        if b_reset:
            # First sample or key has changed. Reset the state.
            check_input["key"] = msg_in.key
            # Determine if we need to modify the transformed axis labels.
            if (
                axis in msg_in.axes
                and hasattr(msg_in.axes[axis], "labels")
                and weights.shape[0] != weights.shape[1]
            ):
                in_labels = msg_in.axes[axis].labels
                new_labels = []
                n_in = weights.shape[1 if right_multiply else 0]
                n_out = weights.shape[0 if right_multiply else 1]
                if len(in_labels) != n_in:
                    # Something upstream did something it wasn't supposed to. We will drop the labels.
                    ez.logger.warning(
                        f"Received {len(in_labels)} for {n_in} inputs. Check upstream labels."
                    )
                else:
                    b_used_inputs = np.any(weights, axis=0 if right_multiply else 1)
                    b_filled_outputs = np.any(weights, axis=1 if right_multiply else 0)
                    if np.all(b_used_inputs) and np.all(b_filled_outputs):
                        # All inputs are used and all outputs are used, but n_in != n_out.
                        # Mapping cannot be determined.
                        new_labels = []
                    elif np.all(b_used_inputs):
                        # Strange scenario: New outputs are filled with empty data.
                        in_ix = 0
                        new_labels = []
                        for out_ix in range(n_out):
                            if b_filled_outputs[out_ix]:
                                new_labels.append(in_labels[in_ix])
                                in_ix += 1
                            else:
                                new_labels.append("")
                    elif np.all(b_filled_outputs):
                        # Transform is dropping some of the inputs.
                        new_labels = np.array(in_labels)[b_used_inputs].tolist()
                new_axis = replace(msg_in.axes[axis], labels=new_labels)

        data = msg_in.data

        if data.shape[axis_idx] == (weights.shape[0] - 1):
            # The weights are stacked A|B where A is the transform and B is a single row
            # in the equation y = Ax + B. This supports NeuroKey's weights matrices.
            sample_shape = data.shape[:axis_idx] + (1,) + data.shape[axis_idx + 1 :]
            data = np.concatenate(
                (data, np.ones(sample_shape).astype(data.dtype)), axis=axis_idx
            )

        if axis_idx in [-1, len(msg_in.dims) - 1]:
            data = np.matmul(data, weights)
        else:
            # matmul operates on the trailing axis; move it there and back.
            data = np.moveaxis(data, axis_idx, -1)
            data = np.matmul(data, weights)
            data = np.moveaxis(data, -1, axis_idx)

        replace_kwargs = {"data": data}
        if new_axis is not None:
            replace_kwargs["axes"] = {**msg_in.axes, axis: new_axis}
        msg_out = replace(msg_in, **replace_kwargs)
|
|
53
129
|
|
|
54
130
|
|
|
55
131
|
class AffineTransformSettings(ez.Settings):
    """
    Settings for :obj:`AffineTransform`.
    See :obj:`affine_transform` for argument details.
    """

    # Weight matrix, path to an np.loadtxt-compatible file, or "passthrough".
    weights: typing.Union[np.ndarray, str, Path]
    # Axis to transform; None selects the generator's default.
    axis: typing.Optional[str] = None
    # False transposes the weights before applying.
    right_multiply: bool = True
|
|
59
140
|
|
|
60
141
|
|
|
61
142
|
class AffineTransform(GenAxisArray):
|
|
62
|
-
|
|
143
|
+
""":obj:`Unit` for :obj:`affine_transform`"""
|
|
144
|
+
|
|
145
|
+
SETTINGS = AffineTransformSettings
|
|
63
146
|
|
|
64
147
|
def construct_generator(self):
|
|
65
148
|
self.STATE.gen = affine_transform(
|
|
@@ -69,25 +152,43 @@ class AffineTransform(GenAxisArray):
|
|
|
69
152
|
)
|
|
70
153
|
|
|
71
154
|
|
|
155
|
+
def zeros_for_noop(data: npt.NDArray, **ignore_kwargs) -> npt.NDArray:
    """
    Return an all-zeros array with the same shape and dtype as ``data``.

    Drop-in stand-in for an aggregating function: extra keyword arguments
    (e.g. ``axis``, ``keepdims``) are accepted and ignored. Used for the
    "passthrough" mode of common re-referencing, where subtracting zeros
    leaves the data unchanged.
    """
    return np.zeros(data.shape, dtype=data.dtype)
|
|
157
|
+
|
|
158
|
+
|
|
72
159
|
@consumer
|
|
73
160
|
def common_rereference(
|
|
74
|
-
mode: str = "mean", axis: Optional[str] = None, include_current: bool = True
|
|
75
|
-
) -> Generator[AxisArray, AxisArray, None]:
|
|
76
|
-
|
|
77
|
-
|
|
161
|
+
mode: str = "mean", axis: typing.Optional[str] = None, include_current: bool = True
|
|
162
|
+
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
163
|
+
"""
|
|
164
|
+
Perform common average referencing (CAR) on streaming data.
|
|
165
|
+
|
|
166
|
+
Args:
|
|
167
|
+
mode: The statistical mode to apply -- either "mean" or "median"
|
|
168
|
+
axis: The name of hte axis to apply the transformation to.
|
|
169
|
+
include_current: Set False to exclude each channel from participating in the calculation of its reference.
|
|
78
170
|
|
|
79
|
-
|
|
171
|
+
Returns:
|
|
172
|
+
A primed generator object that yields an :obj:`AxisArray` object
|
|
173
|
+
for every :obj:`AxisArray` it receives via `send`.
|
|
174
|
+
"""
|
|
175
|
+
msg_out = AxisArray(np.array([]), dims=[""])
|
|
176
|
+
|
|
177
|
+
if mode == "passthrough":
|
|
178
|
+
include_current = True
|
|
179
|
+
|
|
180
|
+
func = {"mean": np.mean, "median": np.median, "passthrough": zeros_for_noop}[mode]
|
|
80
181
|
|
|
81
182
|
while True:
|
|
82
|
-
|
|
183
|
+
msg_in: AxisArray = yield msg_out
|
|
83
184
|
|
|
84
185
|
if axis is None:
|
|
85
|
-
axis =
|
|
186
|
+
axis = msg_in.dims[-1]
|
|
86
187
|
axis_idx = -1
|
|
87
188
|
else:
|
|
88
|
-
axis_idx =
|
|
189
|
+
axis_idx = msg_in.get_axis_idx(axis)
|
|
89
190
|
|
|
90
|
-
ref_data = func(
|
|
191
|
+
ref_data = func(msg_in.data, axis=axis_idx, keepdims=True)
|
|
91
192
|
|
|
92
193
|
if not include_current:
|
|
93
194
|
# Typical `CAR = x[0]/N + x[1]/N + ... x[i-1]/N + x[i]/N + x[i+1]/N + ... + x[N-1]/N`
|
|
@@ -100,21 +201,30 @@ def common_rereference(
|
|
|
100
201
|
# from the current channel (i.e., `x[i] / (N-1)`)
|
|
101
202
|
# i.e., `CAR[i] = (N / (N-1)) * common_CAR - x[i]/(N-1)`
|
|
102
203
|
# We can use broadcasting subtraction instead of looping over channels.
|
|
103
|
-
N =
|
|
104
|
-
ref_data = (N / (N - 1)) * ref_data -
|
|
204
|
+
N = msg_in.data.shape[axis_idx]
|
|
205
|
+
ref_data = (N / (N - 1)) * ref_data - msg_in.data / (N - 1)
|
|
105
206
|
# Side note: I profiled using affine_transform and it's about 30x slower than this implementation.
|
|
106
207
|
|
|
107
|
-
|
|
208
|
+
msg_out = replace(msg_in, data=msg_in.data - ref_data)
|
|
108
209
|
|
|
109
210
|
|
|
110
211
|
class CommonRereferenceSettings(ez.Settings):
    """
    Settings for :obj:`CommonRereference`
    See :obj:`common_rereference` for argument details.
    """

    # Statistical mode: "mean", "median", or "passthrough".
    mode: str = "mean"
    # Axis to re-reference across; None selects the generator's default.
    axis: typing.Optional[str] = None
    # False excludes each channel from its own reference estimate.
    include_current: bool = True
|
|
114
220
|
|
|
115
221
|
|
|
116
222
|
class CommonRereference(GenAxisArray):
|
|
117
|
-
|
|
223
|
+
"""
|
|
224
|
+
:obj:`Unit` for :obj:`common_rereference`.
|
|
225
|
+
"""
|
|
226
|
+
|
|
227
|
+
SETTINGS = CommonRereferenceSettings
|
|
118
228
|
|
|
119
229
|
def construct_generator(self):
|
|
120
230
|
self.STATE.gen = common_rereference(
|
ezmsg/sigproc/aggregate.py
CHANGED
|
@@ -2,13 +2,18 @@ from dataclasses import replace
|
|
|
2
2
|
import typing
|
|
3
3
|
|
|
4
4
|
import numpy as np
|
|
5
|
+
import numpy.typing as npt
|
|
5
6
|
import ezmsg.core as ez
|
|
6
|
-
from ezmsg.util.generator import consumer
|
|
7
|
+
from ezmsg.util.generator import consumer
|
|
7
8
|
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
|
|
8
|
-
|
|
9
|
+
|
|
10
|
+
from .spectral import OptionsEnum
|
|
11
|
+
from .base import GenAxisArray
|
|
9
12
|
|
|
10
13
|
|
|
11
14
|
class AggregationFunction(OptionsEnum):
|
|
15
|
+
"""Enum for aggregation functions available to be used in :obj:`ranged_aggregate` operation."""
|
|
16
|
+
|
|
12
17
|
NONE = "None (all)"
|
|
13
18
|
MAX = "max"
|
|
14
19
|
MIN = "min"
|
|
@@ -20,6 +25,8 @@ class AggregationFunction(OptionsEnum):
|
|
|
20
25
|
NANMEAN = "nanmean"
|
|
21
26
|
NANMEDIAN = "nanmedian"
|
|
22
27
|
NANSTD = "nanstd"
|
|
28
|
+
ARGMIN = "argmin"
|
|
29
|
+
ARGMAX = "argmax"
|
|
23
30
|
|
|
24
31
|
|
|
25
32
|
AGGREGATORS = {
|
|
@@ -33,7 +40,9 @@ AGGREGATORS = {
|
|
|
33
40
|
AggregationFunction.NANMIN: np.nanmin,
|
|
34
41
|
AggregationFunction.NANMEAN: np.nanmean,
|
|
35
42
|
AggregationFunction.NANMEDIAN: np.nanmedian,
|
|
36
|
-
AggregationFunction.NANSTD: np.nanstd
|
|
43
|
+
AggregationFunction.NANSTD: np.nanstd,
|
|
44
|
+
AggregationFunction.ARGMIN: np.argmin,
|
|
45
|
+
AggregationFunction.ARGMAX: np.argmax,
|
|
37
46
|
}
|
|
38
47
|
|
|
39
48
|
|
|
@@ -41,63 +50,109 @@ AGGREGATORS = {
|
|
|
41
50
|
def ranged_aggregate(
    axis: typing.Optional[str] = None,
    bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = None,
    operation: AggregationFunction = AggregationFunction.MEAN,
) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Apply an aggregation operation over one or more bands.

    Args:
        axis: The name of the axis along which to apply the bands.
        bands: [(band1_min, band1_max), (band2_min, band2_max), ...]
            If not set then this acts as a passthrough node.
        operation: :obj:`AggregationFunction` to apply to each band.

    Returns:
        A primed generator object ready to yield an AxisArray for each .send(axis_array)
    """
    msg_out = AxisArray(np.array([]), dims=[""])

    # State variables (persist across loop iterations):
    # per-band index slices into the target axis
    slices: typing.Optional[typing.List[typing.Tuple[typing.Any, ...]]] = None
    # replacement axis describing the band centers
    out_axis: typing.Optional[AxisArray.Axis] = None
    # coordinate vector (offset + i*gain) for the target axis
    ax_vec: typing.Optional[npt.NDArray] = None

    # Reset if any of these changes. Key not checked because continuity between chunks not required.
    check_inputs = {"gain": None, "offset": None}

    while True:
        msg_in: AxisArray = yield msg_out
        if bands is None:
            # Passthrough: no bands configured.
            msg_out = msg_in
        else:
            axis = axis or msg_in.dims[0]
            target_axis = msg_in.get_axis(axis)

            b_reset = target_axis.gain != check_inputs["gain"]
            b_reset = b_reset or target_axis.offset != check_inputs["offset"]
            if b_reset:
                check_inputs["gain"] = target_axis.gain
                check_inputs["offset"] = target_axis.offset

                # The target axis coordinates changed (e.g., "time" or "win" axis
                # offsets change every chunk), so recalculate slices and the output axis.
                # NOTE(review): ax_idx computed here is reused on non-reset iterations,
                # which assumes the axis position is stable between resets.

                ax_idx = msg_in.get_axis_idx(axis)

                ax_vec = (
                    target_axis.offset
                    + np.arange(msg_in.data.shape[ax_idx]) * target_axis.gain
                )
                slices = []
                mids = []
                for start, stop in bands:
                    # Inclusive selection of coordinates within [start, stop].
                    inds = np.where(np.logical_and(ax_vec >= start, ax_vec <= stop))[0]
                    mids.append(np.mean(inds) * target_axis.gain + target_axis.offset)
                    slices.append(np.s_[inds[0] : inds[-1] + 1])
                out_ax_kwargs = {
                    "unit": target_axis.unit,
                    "offset": mids[0],
                    # Band spacing as gain; degenerate single-band case uses 1.0.
                    "gain": (mids[1] - mids[0]) if len(mids) > 1 else 1.0,
                }
                if hasattr(target_axis, "labels"):
                    out_ax_kwargs["labels"] = [f"{_[0]} - {_[1]}" for _ in bands]
                out_axis = replace(target_axis, **out_ax_kwargs)

            agg_func = AGGREGATORS[operation]
            # One reduced slab per band; stacking restores the reduced axis.
            out_data = [
                agg_func(slice_along_axis(msg_in.data, sl, axis=ax_idx), axis=ax_idx)
                for sl in slices
            ]

            msg_out = replace(
                msg_in,
                data=np.stack(out_data, axis=ax_idx),
                axes={**msg_in.axes, axis: out_axis},
            )
            if operation in [AggregationFunction.ARGMIN, AggregationFunction.ARGMAX]:
                # Convert indices returned by argmin/argmax into the value along the axis.
                out_data = []
                for sl_ix, sl in enumerate(slices):
                    offsets = np.take(msg_out.data, [sl_ix], axis=ax_idx)
                    out_data.append(ax_vec[sl][offsets])
                # NOTE(review): mutates msg_out in place rather than using replace().
                msg_out.data = np.concatenate(out_data, axis=ax_idx)
|
|
87
133
|
|
|
88
134
|
|
|
89
135
|
class RangedAggregateSettings(ez.Settings):
    """
    Settings for ``RangedAggregate``.
    See :obj:`ranged_aggregate` for details.
    """

    # Name of the axis to aggregate along; None selects the generator's default.
    axis: typing.Optional[str] = None
    # (min, max) band limits; None makes the node a passthrough.
    bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = None
    # Aggregation applied within each band.
    operation: AggregationFunction = AggregationFunction.MEAN
|
|
93
144
|
|
|
94
145
|
|
|
95
146
|
class RangedAggregate(GenAxisArray):
    """:obj:`Unit` wrapping the :obj:`ranged_aggregate` generator."""

    SETTINGS = RangedAggregateSettings

    def construct_generator(self):
        # Build a freshly-primed generator from the current settings.
        self.STATE.gen = ranged_aggregate(
            axis=self.SETTINGS.axis,
            bands=self.SETTINGS.bands,
            operation=self.SETTINGS.operation,
        )
|
ezmsg/sigproc/bandpower.py
CHANGED
|
@@ -4,50 +4,71 @@ import typing
|
|
|
4
4
|
import numpy as np
|
|
5
5
|
import ezmsg.core as ez
|
|
6
6
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
|
-
from ezmsg.util.generator import consumer, compose
|
|
7
|
+
from ezmsg.util.generator import consumer, compose
|
|
8
8
|
|
|
9
9
|
from .spectrogram import spectrogram, SpectrogramSettings
|
|
10
10
|
from .aggregate import ranged_aggregate, AggregationFunction
|
|
11
|
+
from .base import GenAxisArray
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
@consumer
def bandpower(
    spectrogram_settings: SpectrogramSettings,
    bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = [
        (17, 30),
        (70, 170),
    ],
) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Calculate the average spectral power in each band.

    Args:
        spectrogram_settings: Settings for spectrogram calculation.
        bands: (min, max) tuples of band limits in Hz.
            NOTE(review): mutable default argument; it is never mutated here,
            but callers should not rely on identity.

    Returns:
        A primed generator object ready to yield an AxisArray for each .send(axis_array)
    """
    msg_out = AxisArray(np.array([]), dims=[""])

    # Chain a spectrogram stage into a per-band mean-aggregation stage.
    pipeline = compose(
        spectrogram(
            window_dur=spectrogram_settings.window_dur,
            window_shift=spectrogram_settings.window_shift,
            window=spectrogram_settings.window,
            transform=spectrogram_settings.transform,
            output=spectrogram_settings.output,
        ),
        ranged_aggregate(axis="freq", bands=bands, operation=AggregationFunction.MEAN),
    )

    while True:
        msg_in: AxisArray = yield msg_out
        msg_out = pipeline(msg_in)
|
|
38
49
|
|
|
39
50
|
|
|
40
51
|
class BandPowerSettings(ez.Settings):
    """
    Settings for ``BandPower``.
    See :obj:`bandpower` for details.
    """

    # Settings forwarded to the spectrogram stage.
    spectrogram_settings: SpectrogramSettings = field(
        default_factory=SpectrogramSettings
    )
    # (min, max) frequency band limits in Hz.
    bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = field(
        default_factory=lambda: [(17, 30), (70, 170)]
    )
|
|
44
63
|
|
|
45
64
|
|
|
46
65
|
class BandPower(GenAxisArray):
    """:obj:`Unit` wrapping the :obj:`bandpower` generator."""

    SETTINGS = BandPowerSettings

    def construct_generator(self):
        # Build a freshly-primed generator from the current settings.
        self.STATE.gen = bandpower(
            spectrogram_settings=self.SETTINGS.spectrogram_settings,
            bands=self.SETTINGS.bands,
        )
|
ezmsg/sigproc/base.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
import traceback
|
|
2
|
+
import typing
|
|
3
|
+
|
|
4
|
+
import ezmsg.core as ez
|
|
5
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
6
|
+
from ezmsg.util.generator import GenState
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class GenAxisArray(ez.Unit):
    """
    Base class for ezmsg Units that wrap a message-processing generator.

    Subclasses implement :obj:`construct_generator` to assign a primed
    generator to ``STATE.gen``; this base class feeds every incoming
    :obj:`AxisArray` through that generator and publishes non-empty results.
    """

    STATE = GenState

    INPUT_SIGNAL = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
    INPUT_SETTINGS = ez.InputStream(ez.Settings)

    async def initialize(self) -> None:
        """Create the wrapped generator before the unit starts processing."""
        self.construct_generator()

    # Method to be implemented by subclasses to construct the specific generator
    def construct_generator(self):
        raise NotImplementedError

    @ez.subscriber(INPUT_SETTINGS)
    async def on_settings(self, msg: ez.Settings) -> None:
        """Apply new settings at runtime and rebuild the generator."""
        self.apply_settings(msg)
        self.construct_generator()

    @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(OUTPUT_SIGNAL)
    async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
        """Send each incoming message through the generator; publish non-empty output."""
        try:
            ret = self.STATE.gen.send(message)
            # Empty placeholder outputs (e.g. the priming value) are suppressed.
            if ret.data.size > 0:
                yield self.OUTPUT_SIGNAL, ret
        except (StopIteration, GeneratorExit):
            ez.logger.debug(f"Generator closed in {self.address}")
        except Exception:
            # Log-and-continue so a processing error doesn't kill the unit.
            # NOTE(review): logged at info level — consider error level.
            ez.logger.info(traceback.format_exc())
|