ezmsg-sigproc 1.2.2__tar.gz → 1.2.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ezmsg-sigproc-1.2.2 → ezmsg_sigproc-1.2.3}/PKG-INFO +15 -13
- ezmsg_sigproc-1.2.3/pyproject.toml +30 -0
- ezmsg_sigproc-1.2.3/src/ezmsg/sigproc/__init__.py +4 -0
- ezmsg_sigproc-1.2.3/src/ezmsg/sigproc/affinetransform.py +124 -0
- ezmsg_sigproc-1.2.3/src/ezmsg/sigproc/aggregate.py +103 -0
- ezmsg_sigproc-1.2.3/src/ezmsg/sigproc/bandpower.py +53 -0
- {ezmsg-sigproc-1.2.2 → ezmsg_sigproc-1.2.3/src}/ezmsg/sigproc/butterworthfilter.py +41 -6
- ezmsg_sigproc-1.2.3/src/ezmsg/sigproc/downsample.py +89 -0
- {ezmsg-sigproc-1.2.2 → ezmsg_sigproc-1.2.3/src}/ezmsg/sigproc/ewmfilter.py +11 -3
- {ezmsg-sigproc-1.2.2 → ezmsg_sigproc-1.2.3/src}/ezmsg/sigproc/filter.py +82 -14
- ezmsg_sigproc-1.2.3/src/ezmsg/sigproc/sampler.py +260 -0
- ezmsg_sigproc-1.2.3/src/ezmsg/sigproc/scaler.py +127 -0
- ezmsg_sigproc-1.2.3/src/ezmsg/sigproc/signalinjector.py +67 -0
- ezmsg_sigproc-1.2.3/src/ezmsg/sigproc/slicer.py +98 -0
- ezmsg_sigproc-1.2.3/src/ezmsg/sigproc/spectral.py +9 -0
- ezmsg_sigproc-1.2.3/src/ezmsg/sigproc/spectrogram.py +68 -0
- ezmsg_sigproc-1.2.3/src/ezmsg/sigproc/spectrum.py +158 -0
- {ezmsg-sigproc-1.2.2 → ezmsg_sigproc-1.2.3/src}/ezmsg/sigproc/synth.py +179 -80
- ezmsg_sigproc-1.2.3/src/ezmsg/sigproc/window.py +246 -0
- ezmsg-sigproc-1.2.2/ezmsg/sigproc/__init__.py +0 -1
- ezmsg-sigproc-1.2.2/ezmsg/sigproc/__version__.py +0 -1
- ezmsg-sigproc-1.2.2/ezmsg/sigproc/downsample.py +0 -63
- ezmsg-sigproc-1.2.2/ezmsg/sigproc/sampler.py +0 -287
- ezmsg-sigproc-1.2.2/ezmsg/sigproc/spectral.py +0 -132
- ezmsg-sigproc-1.2.2/ezmsg/sigproc/window.py +0 -144
- ezmsg-sigproc-1.2.2/ezmsg_sigproc.egg-info/PKG-INFO +0 -36
- ezmsg-sigproc-1.2.2/ezmsg_sigproc.egg-info/SOURCES.txt +0 -25
- ezmsg-sigproc-1.2.2/ezmsg_sigproc.egg-info/dependency_links.txt +0 -1
- ezmsg-sigproc-1.2.2/ezmsg_sigproc.egg-info/not-zip-safe +0 -1
- ezmsg-sigproc-1.2.2/ezmsg_sigproc.egg-info/requires.txt +0 -7
- ezmsg-sigproc-1.2.2/ezmsg_sigproc.egg-info/top_level.txt +0 -1
- ezmsg-sigproc-1.2.2/setup.cfg +0 -31
- ezmsg-sigproc-1.2.2/setup.py +0 -7
- ezmsg-sigproc-1.2.2/tests/test_butterworth.py +0 -142
- ezmsg-sigproc-1.2.2/tests/test_downsample.py +0 -132
- ezmsg-sigproc-1.2.2/tests/test_window.py +0 -139
- {ezmsg-sigproc-1.2.2 → ezmsg_sigproc-1.2.3}/LICENSE.txt +0 -0
- {ezmsg-sigproc-1.2.2 → ezmsg_sigproc-1.2.3}/README.md +0 -0
- {ezmsg-sigproc-1.2.2 → ezmsg_sigproc-1.2.3/src}/ezmsg/sigproc/decimate.py +0 -0
- {ezmsg-sigproc-1.2.2 → ezmsg_sigproc-1.2.3/src}/ezmsg/sigproc/messages.py +0 -0
|
@@ -1,21 +1,22 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ezmsg-sigproc
|
|
3
|
-
Version: 1.2.
|
|
3
|
+
Version: 1.2.3
|
|
4
4
|
Summary: Timeseries signal processing implementations in ezmsg
|
|
5
|
-
|
|
6
|
-
Author: Griffin
|
|
7
|
-
Author-email: griffin.milsap@
|
|
5
|
+
License: MIT
|
|
6
|
+
Author: Milsap, Griffin
|
|
7
|
+
Author-email: griffin.milsap@gmail.com
|
|
8
|
+
Requires-Python: >=3.8,<4.0
|
|
9
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
8
10
|
Classifier: Programming Language :: Python :: 3
|
|
9
|
-
Classifier:
|
|
10
|
-
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.8
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
16
|
+
Requires-Dist: ezmsg (>=3.3.0,<4.0.0)
|
|
17
|
+
Requires-Dist: numpy (>=1.19.5,<2.0.0)
|
|
18
|
+
Requires-Dist: scipy (>=1.6.3,<2.0.0)
|
|
11
19
|
Description-Content-Type: text/markdown
|
|
12
|
-
License-File: LICENSE.txt
|
|
13
|
-
Requires-Dist: ezmsg>=3.3.0
|
|
14
|
-
Requires-Dist: numpy>=1.19.5
|
|
15
|
-
Requires-Dist: scipy>=1.6.3
|
|
16
|
-
Provides-Extra: test
|
|
17
|
-
Requires-Dist: pytest; extra == "test"
|
|
18
|
-
Requires-Dist: pytest-cov; extra == "test"
|
|
19
20
|
|
|
20
21
|
# ezmsg.sigproc
|
|
21
22
|
|
|
@@ -34,3 +35,4 @@ Timeseries signal processing implementations for ezmsg
|
|
|
34
35
|
2. `cd` to this directory (`ezmsg-sigproc`) and run `pip install -e .`
|
|
35
36
|
3. Signal processing modules are available under `import ezmsg.sigproc`
|
|
36
37
|
|
|
38
|
+
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
[tool.poetry]
|
|
2
|
+
name = "ezmsg-sigproc"
|
|
3
|
+
version = "1.2.3"
|
|
4
|
+
description = "Timeseries signal processing implementations in ezmsg"
|
|
5
|
+
authors = [
|
|
6
|
+
"Milsap, Griffin <griffin.milsap@gmail.com>",
|
|
7
|
+
"Peranich, Preston <pperanich@gmail.com>",
|
|
8
|
+
]
|
|
9
|
+
license = "MIT"
|
|
10
|
+
readme = "README.md"
|
|
11
|
+
packages = [{ include = "ezmsg", from = "src" }]
|
|
12
|
+
|
|
13
|
+
[tool.poetry.dependencies]
|
|
14
|
+
python = "^3.8"
|
|
15
|
+
ezmsg = "^3.3.0"
|
|
16
|
+
numpy = "^1.19.5"
|
|
17
|
+
scipy = "^1.6.3"
|
|
18
|
+
|
|
19
|
+
[tool.poetry.group.test.dependencies]
|
|
20
|
+
pytest = "^7.0.0"
|
|
21
|
+
pytest-cov = "*"
|
|
22
|
+
|
|
23
|
+
[tool.pytest.ini_options]
|
|
24
|
+
addopts = ["--import-mode=importlib"]
|
|
25
|
+
pythonpath = ["src", "tests"]
|
|
26
|
+
testpaths = "tests"
|
|
27
|
+
|
|
28
|
+
[build-system]
|
|
29
|
+
requires = ["poetry-core"]
|
|
30
|
+
build-backend = "poetry.core.masonry.api"
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
from dataclasses import replace
|
|
2
|
+
import os
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Generator, Optional, Union
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import ezmsg.core as ez
|
|
8
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
9
|
+
from ezmsg.util.generator import consumer, GenAxisArray
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@consumer
|
|
13
|
+
def affine_transform(
|
|
14
|
+
weights: Union[np.ndarray, str, Path],
|
|
15
|
+
axis: Optional[str] = None,
|
|
16
|
+
right_multiply: bool = True,
|
|
17
|
+
) -> Generator[AxisArray, AxisArray, None]:
|
|
18
|
+
axis_arr_in = AxisArray(np.array([]), dims=[""])
|
|
19
|
+
axis_arr_out = AxisArray(np.array([]), dims=[""])
|
|
20
|
+
|
|
21
|
+
if isinstance(weights, str):
|
|
22
|
+
weights = Path(os.path.abspath(os.path.expanduser(weights)))
|
|
23
|
+
if isinstance(weights, Path):
|
|
24
|
+
weights = np.loadtxt(weights, delimiter=",")
|
|
25
|
+
if not right_multiply:
|
|
26
|
+
weights = weights.T
|
|
27
|
+
weights = np.ascontiguousarray(weights)
|
|
28
|
+
|
|
29
|
+
while True:
|
|
30
|
+
axis_arr_in = yield axis_arr_out
|
|
31
|
+
|
|
32
|
+
if axis is None:
|
|
33
|
+
axis = axis_arr_in.dims[-1]
|
|
34
|
+
axis_idx = -1
|
|
35
|
+
else:
|
|
36
|
+
axis_idx = axis_arr_in.get_axis_idx(axis)
|
|
37
|
+
|
|
38
|
+
data = axis_arr_in.data
|
|
39
|
+
|
|
40
|
+
if data.shape[axis_idx] == (weights.shape[0] - 1):
|
|
41
|
+
# The weights are stacked A|B where A is the transform and B is a single row
|
|
42
|
+
# in the equation y = Ax + B. This supports NeuroKey's weights matrices.
|
|
43
|
+
sample_shape = data.shape[:axis_idx] + (1,) + data.shape[axis_idx+1:]
|
|
44
|
+
data = np.concatenate((data, np.ones(sample_shape).astype(data.dtype)), axis=axis_idx)
|
|
45
|
+
|
|
46
|
+
if axis_idx in [-1, len(axis_arr_in.dims) - 1]:
|
|
47
|
+
data = np.matmul(data, weights)
|
|
48
|
+
else:
|
|
49
|
+
data = np.moveaxis(data, axis_idx, -1)
|
|
50
|
+
data = np.matmul(data, weights)
|
|
51
|
+
data = np.moveaxis(data, -1, axis_idx)
|
|
52
|
+
axis_arr_out = replace(axis_arr_in, data=data)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class AffineTransformSettings(ez.Settings):
|
|
56
|
+
weights: Union[np.ndarray, str, Path]
|
|
57
|
+
axis: Optional[str] = None
|
|
58
|
+
right_multiply: bool = True
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class AffineTransform(GenAxisArray):
|
|
62
|
+
SETTINGS: AffineTransformSettings
|
|
63
|
+
|
|
64
|
+
def construct_generator(self):
|
|
65
|
+
self.STATE.gen = affine_transform(
|
|
66
|
+
weights=self.SETTINGS.weights,
|
|
67
|
+
axis=self.SETTINGS.axis,
|
|
68
|
+
right_multiply=self.SETTINGS.right_multiply,
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@consumer
|
|
73
|
+
def common_rereference(
|
|
74
|
+
mode: str = "mean", axis: Optional[str] = None, include_current: bool = True
|
|
75
|
+
) -> Generator[AxisArray, AxisArray, None]:
|
|
76
|
+
axis_arr_in = AxisArray(np.array([]), dims=[""])
|
|
77
|
+
axis_arr_out = AxisArray(np.array([]), dims=[""])
|
|
78
|
+
|
|
79
|
+
func = {"mean": np.mean, "median": np.median}[mode]
|
|
80
|
+
|
|
81
|
+
while True:
|
|
82
|
+
axis_arr_in = yield axis_arr_out
|
|
83
|
+
|
|
84
|
+
if axis is None:
|
|
85
|
+
axis = axis_arr_in.dims[-1]
|
|
86
|
+
axis_idx = -1
|
|
87
|
+
else:
|
|
88
|
+
axis_idx = axis_arr_in.get_axis_idx(axis)
|
|
89
|
+
|
|
90
|
+
ref_data = func(axis_arr_in.data, axis=axis_idx, keepdims=True)
|
|
91
|
+
|
|
92
|
+
if not include_current:
|
|
93
|
+
# Typical `CAR = x[0]/N + x[1]/N + ... x[i-1]/N + x[i]/N + x[i+1]/N + ... + x[N-1]/N`
|
|
94
|
+
# and is the same for all i, so it is calculated only once in `ref_data`.
|
|
95
|
+
# However, if we had excluded the current channel,
|
|
96
|
+
# then we would have omitted the contribution of the current channel:
|
|
97
|
+
# `CAR[i] = x[0]/(N-1) + x[1]/(N-1) + ... x[i-1]/(N-1) + x[i+1]/(N-1) + ... + x[N-1]/(N-1)`
|
|
98
|
+
# The majority of the calculation is the same as when the current channel is included;
|
|
99
|
+
# we need only rescale CAR so the divisor is `N-1` instead of `N`, then subtract the contribution
|
|
100
|
+
# from the current channel (i.e., `x[i] / (N-1)`)
|
|
101
|
+
# i.e., `CAR[i] = (N / (N-1)) * common_CAR - x[i]/(N-1)`
|
|
102
|
+
# We can use broadcasting subtraction instead of looping over channels.
|
|
103
|
+
N = axis_arr_in.data.shape[axis_idx]
|
|
104
|
+
ref_data = (N / (N - 1)) * ref_data - axis_arr_in.data / (N - 1)
|
|
105
|
+
# Side note: I profiled using affine_transform and it's about 30x slower than this implementation.
|
|
106
|
+
|
|
107
|
+
axis_arr_out = replace(axis_arr_in, data=axis_arr_in.data - ref_data)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class CommonRereferenceSettings(ez.Settings):
|
|
111
|
+
mode: str = "mean"
|
|
112
|
+
axis: Optional[str] = None
|
|
113
|
+
include_current: bool = True
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class CommonRereference(GenAxisArray):
|
|
117
|
+
SETTINGS: CommonRereferenceSettings
|
|
118
|
+
|
|
119
|
+
def construct_generator(self):
|
|
120
|
+
self.STATE.gen = common_rereference(
|
|
121
|
+
mode=self.SETTINGS.mode,
|
|
122
|
+
axis=self.SETTINGS.axis,
|
|
123
|
+
include_current=self.SETTINGS.include_current,
|
|
124
|
+
)
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
from dataclasses import replace
|
|
2
|
+
import typing
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import ezmsg.core as ez
|
|
6
|
+
from ezmsg.util.generator import consumer, GenAxisArray
|
|
7
|
+
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
|
|
8
|
+
from ezmsg.sigproc.spectral import OptionsEnum
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class AggregationFunction(OptionsEnum):
|
|
12
|
+
NONE = "None (all)"
|
|
13
|
+
MAX = "max"
|
|
14
|
+
MIN = "min"
|
|
15
|
+
MEAN = "mean"
|
|
16
|
+
MEDIAN = "median"
|
|
17
|
+
STD = "std"
|
|
18
|
+
NANMAX = "nanmax"
|
|
19
|
+
NANMIN = "nanmin"
|
|
20
|
+
NANMEAN = "nanmean"
|
|
21
|
+
NANMEDIAN = "nanmedian"
|
|
22
|
+
NANSTD = "nanstd"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
AGGREGATORS = {
|
|
26
|
+
AggregationFunction.NONE: np.all,
|
|
27
|
+
AggregationFunction.MAX: np.max,
|
|
28
|
+
AggregationFunction.MIN: np.min,
|
|
29
|
+
AggregationFunction.MEAN: np.mean,
|
|
30
|
+
AggregationFunction.MEDIAN: np.median,
|
|
31
|
+
AggregationFunction.STD: np.std,
|
|
32
|
+
AggregationFunction.NANMAX: np.nanmax,
|
|
33
|
+
AggregationFunction.NANMIN: np.nanmin,
|
|
34
|
+
AggregationFunction.NANMEAN: np.nanmean,
|
|
35
|
+
AggregationFunction.NANMEDIAN: np.nanmedian,
|
|
36
|
+
AggregationFunction.NANSTD: np.nanstd
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@consumer
|
|
41
|
+
def ranged_aggregate(
|
|
42
|
+
axis: typing.Optional[str] = None,
|
|
43
|
+
bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = None,
|
|
44
|
+
operation: AggregationFunction = AggregationFunction.MEAN
|
|
45
|
+
):
|
|
46
|
+
axis_arr_in = AxisArray(np.array([]), dims=[""])
|
|
47
|
+
axis_arr_out = AxisArray(np.array([]), dims=[""])
|
|
48
|
+
|
|
49
|
+
target_axis: typing.Optional[AxisArray.Axis] = None
|
|
50
|
+
out_axis = AxisArray.Axis()
|
|
51
|
+
slices: typing.Optional[typing.List[typing.Tuple[typing.Any, ...]]] = None
|
|
52
|
+
axis_name = ""
|
|
53
|
+
|
|
54
|
+
while True:
|
|
55
|
+
axis_arr_in = yield axis_arr_out
|
|
56
|
+
if bands is None:
|
|
57
|
+
axis_arr_out = axis_arr_in
|
|
58
|
+
else:
|
|
59
|
+
if slices is None or target_axis != axis_arr_in.get_axis(axis_name):
|
|
60
|
+
# Calculate the slices. If we are operating on time axis then
|
|
61
|
+
axis_name = axis or axis_arr_in.dims[0]
|
|
62
|
+
ax_idx = axis_arr_in.get_axis_idx(axis_name)
|
|
63
|
+
target_axis = axis_arr_in.axes[axis_name]
|
|
64
|
+
|
|
65
|
+
ax_vec = target_axis.offset + np.arange(axis_arr_in.data.shape[ax_idx]) * target_axis.gain
|
|
66
|
+
slices = []
|
|
67
|
+
mids = []
|
|
68
|
+
for (start, stop) in bands:
|
|
69
|
+
inds = np.where(np.logical_and(ax_vec >= start, ax_vec <= stop))[0]
|
|
70
|
+
mids.append(np.mean(inds) * target_axis.gain + target_axis.offset)
|
|
71
|
+
slices.append(np.s_[inds[0]:inds[-1] + 1])
|
|
72
|
+
out_axis = AxisArray.Axis(
|
|
73
|
+
unit=target_axis.unit, offset=mids[0], gain=(mids[1] - mids[0]) if len(mids) > 1 else 1.0
|
|
74
|
+
)
|
|
75
|
+
|
|
76
|
+
agg_func = AGGREGATORS[operation]
|
|
77
|
+
out_data = [
|
|
78
|
+
agg_func(slice_along_axis(axis_arr_in.data, sl, axis=ax_idx), axis=ax_idx)
|
|
79
|
+
for sl in slices
|
|
80
|
+
]
|
|
81
|
+
new_axes = {**axis_arr_in.axes, axis_name: out_axis}
|
|
82
|
+
axis_arr_out = replace(
|
|
83
|
+
axis_arr_in,
|
|
84
|
+
data=np.stack(out_data, axis=ax_idx),
|
|
85
|
+
axes=new_axes
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class RangedAggregateSettings(ez.Settings):
|
|
90
|
+
axis: typing.Optional[str] = None
|
|
91
|
+
bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = None
|
|
92
|
+
operation: AggregationFunction = AggregationFunction.MEAN
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class RangedAggregate(GenAxisArray):
|
|
96
|
+
SETTINGS: RangedAggregateSettings
|
|
97
|
+
|
|
98
|
+
def construct_generator(self):
|
|
99
|
+
self.STATE.gen = ranged_aggregate(
|
|
100
|
+
axis=self.SETTINGS.axis,
|
|
101
|
+
bands=self.SETTINGS.bands,
|
|
102
|
+
operation=self.SETTINGS.operation
|
|
103
|
+
)
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
from dataclasses import field
|
|
2
|
+
import typing
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import ezmsg.core as ez
|
|
6
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
|
+
from ezmsg.util.generator import consumer, compose, GenAxisArray
|
|
8
|
+
|
|
9
|
+
from .spectrogram import spectrogram, SpectrogramSettings
|
|
10
|
+
from .aggregate import ranged_aggregate, AggregationFunction
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@consumer
|
|
14
|
+
def bandpower(
|
|
15
|
+
spectrogram_settings: SpectrogramSettings,
|
|
16
|
+
bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = [(17, 30), (70, 170)]
|
|
17
|
+
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
18
|
+
axis_arr_in = AxisArray(np.array([]), dims=[""])
|
|
19
|
+
axis_arr_out = AxisArray(np.array([]), dims=[""])
|
|
20
|
+
|
|
21
|
+
f_spec = spectrogram(
|
|
22
|
+
window_dur=spectrogram_settings.window_dur,
|
|
23
|
+
window_shift=spectrogram_settings.window_shift,
|
|
24
|
+
window=spectrogram_settings.window,
|
|
25
|
+
transform=spectrogram_settings.transform,
|
|
26
|
+
output=spectrogram_settings.output
|
|
27
|
+
)
|
|
28
|
+
f_agg = ranged_aggregate(
|
|
29
|
+
axis="freq",
|
|
30
|
+
bands=bands,
|
|
31
|
+
operation=AggregationFunction.MEAN
|
|
32
|
+
)
|
|
33
|
+
pipeline = compose(f_spec, f_agg)
|
|
34
|
+
|
|
35
|
+
while True:
|
|
36
|
+
axis_arr_in = yield axis_arr_out
|
|
37
|
+
axis_arr_out = pipeline(axis_arr_in)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class BandPowerSettings(ez.Settings):
|
|
41
|
+
spectrogram_settings: SpectrogramSettings = field(default_factory=SpectrogramSettings)
|
|
42
|
+
bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = (
|
|
43
|
+
field(default_factory=lambda: [(17, 30), (70, 170)]))
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class BandPower(GenAxisArray):
|
|
47
|
+
SETTINGS: BandPowerSettings
|
|
48
|
+
|
|
49
|
+
def construct_generator(self):
|
|
50
|
+
self.STATE.gen = bandpower(
|
|
51
|
+
spectrogram_settings=self.SETTINGS.spectrogram_settings,
|
|
52
|
+
bands=self.SETTINGS.bands
|
|
53
|
+
)
|
|
@@ -1,18 +1,21 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
|
|
1
3
|
import ezmsg.core as ez
|
|
2
4
|
import scipy.signal
|
|
3
5
|
import numpy as np
|
|
4
6
|
|
|
5
|
-
from .filter import Filter, FilterState, FilterSettingsBase
|
|
7
|
+
from .filter import filtergen, Filter, FilterState, FilterSettingsBase
|
|
6
8
|
|
|
7
|
-
from
|
|
9
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
10
|
+
from ezmsg.util.generator import consumer
|
|
8
11
|
|
|
9
12
|
|
|
10
13
|
class ButterworthFilterSettings(FilterSettingsBase):
|
|
11
14
|
order: int = 0
|
|
12
|
-
cuton: Optional[float] = None # Hz
|
|
13
|
-
cutoff: Optional[float] = None # Hz
|
|
15
|
+
cuton: typing.Optional[float] = None # Hz
|
|
16
|
+
cutoff: typing.Optional[float] = None # Hz
|
|
14
17
|
|
|
15
|
-
def filter_specs(self) -> Optional[Tuple[str, Union[float, Tuple[float, float]]]]:
|
|
18
|
+
def filter_specs(self) -> typing.Optional[typing.Tuple[str, typing.Union[float, typing.Tuple[float, float]]]]:
|
|
16
19
|
if self.cuton is None and self.cutoff is None:
|
|
17
20
|
return None
|
|
18
21
|
elif self.cuton is None and self.cutoff is not None:
|
|
@@ -26,6 +29,38 @@ class ButterworthFilterSettings(FilterSettingsBase):
|
|
|
26
29
|
return "bandstop", (self.cutoff, self.cuton)
|
|
27
30
|
|
|
28
31
|
|
|
32
|
+
@consumer
|
|
33
|
+
def butter(
|
|
34
|
+
axis: typing.Optional[str],
|
|
35
|
+
order: int = 0,
|
|
36
|
+
cuton: typing.Optional[float] = None,
|
|
37
|
+
cutoff: typing.Optional[float] = None,
|
|
38
|
+
coef_type: str = "ba",
|
|
39
|
+
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
40
|
+
# IO
|
|
41
|
+
axis_arr_in = AxisArray(np.array([]), dims=[""])
|
|
42
|
+
axis_arr_out = AxisArray(np.array([]), dims=[""])
|
|
43
|
+
|
|
44
|
+
btype, cutoffs = ButterworthFilterSettings(
|
|
45
|
+
order=order, cuton=cuton, cutoff=cutoff
|
|
46
|
+
).filter_specs()
|
|
47
|
+
|
|
48
|
+
# We cannot calculate coefs yet because we do not know input sample rate
|
|
49
|
+
coefs = None
|
|
50
|
+
filter_gen = filtergen(axis, coefs, coef_type) # Passthrough.
|
|
51
|
+
|
|
52
|
+
while True:
|
|
53
|
+
axis_arr_in = yield axis_arr_out
|
|
54
|
+
if coefs is None and order > 0:
|
|
55
|
+
fs = 1 / axis_arr_in.axes[axis or axis_arr_in.dims[0]].gain
|
|
56
|
+
coefs = scipy.signal.butter(
|
|
57
|
+
order, Wn=cutoffs, btype=btype, fs=fs, output=coef_type
|
|
58
|
+
)
|
|
59
|
+
filter_gen = filtergen(axis, coefs, coef_type)
|
|
60
|
+
|
|
61
|
+
axis_arr_out = filter_gen.send(axis_arr_in)
|
|
62
|
+
|
|
63
|
+
|
|
29
64
|
class ButterworthFilterState(FilterState):
|
|
30
65
|
design: ButterworthFilterSettings
|
|
31
66
|
|
|
@@ -41,7 +76,7 @@ class ButterworthFilter(Filter):
|
|
|
41
76
|
self.STATE.filt_designed = True
|
|
42
77
|
super().initialize()
|
|
43
78
|
|
|
44
|
-
def design_filter(self) -> Optional[Tuple[np.ndarray, np.ndarray]]:
|
|
79
|
+
def design_filter(self) -> typing.Optional[typing.Tuple[np.ndarray, np.ndarray]]:
|
|
45
80
|
specs = self.STATE.design.filter_specs()
|
|
46
81
|
if self.STATE.design.order > 0 and specs is not None:
|
|
47
82
|
btype, cut = specs
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
from dataclasses import replace
|
|
2
|
+
import traceback
|
|
3
|
+
from typing import AsyncGenerator, Optional, Generator
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
8
|
+
from ezmsg.util.generator import consumer
|
|
9
|
+
import ezmsg.core as ez
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@consumer
|
|
13
|
+
def downsample(
|
|
14
|
+
axis: Optional[str] = None, factor: int = 1
|
|
15
|
+
) -> Generator[AxisArray, AxisArray, None]:
|
|
16
|
+
axis_arr_in = AxisArray(np.array([]), dims=[""])
|
|
17
|
+
axis_arr_out = AxisArray(np.array([]), dims=[""])
|
|
18
|
+
|
|
19
|
+
# state variables
|
|
20
|
+
s_idx = 0
|
|
21
|
+
|
|
22
|
+
while True:
|
|
23
|
+
axis_arr_in = yield axis_arr_out
|
|
24
|
+
|
|
25
|
+
if axis is None:
|
|
26
|
+
axis = axis_arr_in.dims[0]
|
|
27
|
+
axis_info = axis_arr_in.get_axis(axis)
|
|
28
|
+
axis_idx = axis_arr_in.get_axis_idx(axis)
|
|
29
|
+
|
|
30
|
+
samples = np.arange(axis_arr_in.data.shape[axis_idx]) + s_idx
|
|
31
|
+
samples = samples % factor
|
|
32
|
+
s_idx = samples[-1] + 1
|
|
33
|
+
|
|
34
|
+
pub_samples = np.where(samples == 0)[0]
|
|
35
|
+
if len(pub_samples) > 0:
|
|
36
|
+
new_axes = {ax_name: axis_arr_in.get_axis(ax_name) for ax_name in axis_arr_in.dims}
|
|
37
|
+
new_offset = axis_info.offset + (axis_info.gain * pub_samples[0].item())
|
|
38
|
+
new_gain = axis_info.gain * factor
|
|
39
|
+
new_axes[axis] = replace(axis_info, gain=new_gain, offset=new_offset)
|
|
40
|
+
down_data = np.take(axis_arr_in.data, pub_samples, axis=axis_idx)
|
|
41
|
+
axis_arr_out = replace(axis_arr_in, data=down_data, dims=axis_arr_in.dims, axes=new_axes)
|
|
42
|
+
else:
|
|
43
|
+
axis_arr_out = None
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class DownsampleSettings(ez.Settings):
|
|
47
|
+
axis: Optional[str] = None
|
|
48
|
+
factor: int = 1
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class DownsampleState(ez.State):
|
|
52
|
+
cur_settings: DownsampleSettings
|
|
53
|
+
gen: Generator
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class Downsample(ez.Unit):
|
|
57
|
+
SETTINGS: DownsampleSettings
|
|
58
|
+
STATE: DownsampleState
|
|
59
|
+
|
|
60
|
+
INPUT_SETTINGS = ez.InputStream(DownsampleSettings)
|
|
61
|
+
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
62
|
+
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
63
|
+
|
|
64
|
+
def construct_generator(self):
|
|
65
|
+
self.STATE.gen = downsample(axis=self.STATE.cur_settings.axis, factor=self.STATE.cur_settings.factor)
|
|
66
|
+
|
|
67
|
+
def initialize(self) -> None:
|
|
68
|
+
self.STATE.cur_settings = self.SETTINGS
|
|
69
|
+
self.construct_generator()
|
|
70
|
+
|
|
71
|
+
@ez.subscriber(INPUT_SETTINGS)
|
|
72
|
+
async def on_settings(self, msg: DownsampleSettings) -> None:
|
|
73
|
+
self.STATE.cur_settings = msg
|
|
74
|
+
self.construct_generator()
|
|
75
|
+
|
|
76
|
+
@ez.subscriber(INPUT_SIGNAL, zero_copy=True)
|
|
77
|
+
@ez.publisher(OUTPUT_SIGNAL)
|
|
78
|
+
async def on_signal(self, msg: AxisArray) -> AsyncGenerator:
|
|
79
|
+
if self.STATE.cur_settings.factor < 1:
|
|
80
|
+
raise ValueError("Downsample factor must be at least 1 (no downsampling)")
|
|
81
|
+
|
|
82
|
+
try:
|
|
83
|
+
out_msg = self.STATE.gen.send(msg)
|
|
84
|
+
if out_msg is not None:
|
|
85
|
+
yield self.OUTPUT_SIGNAL, out_msg
|
|
86
|
+
except (StopIteration, GeneratorExit):
|
|
87
|
+
ez.logger.debug(f"Downsample closed in {self.address}")
|
|
88
|
+
except Exception:
|
|
89
|
+
ez.logger.info(traceback.format_exc())
|
|
@@ -73,9 +73,12 @@ class EWM(ez.Unit):
|
|
|
73
73
|
buffer_data = buffer.data
|
|
74
74
|
buffer_data = np.moveaxis(buffer_data, axis_idx, 0)
|
|
75
75
|
|
|
76
|
+
while scale_arr.ndim < buffer_data.ndim:
|
|
77
|
+
scale_arr = scale_arr[..., None]
|
|
78
|
+
|
|
76
79
|
def ewma(data: np.ndarray) -> np.ndarray:
|
|
77
|
-
mult = scale_arr
|
|
78
|
-
out = scale_arr[::-1
|
|
80
|
+
mult = scale_arr * data * pw0
|
|
81
|
+
out = scale_arr[::-1] * mult.cumsum(axis=0)
|
|
79
82
|
|
|
80
83
|
if not self.SETTINGS.zero_offset:
|
|
81
84
|
out = (data[0, :, np.newaxis] * pows[1:]).T + out
|
|
@@ -108,7 +111,12 @@ class EWMFilter(ez.Collection):
|
|
|
108
111
|
EWM = EWM()
|
|
109
112
|
|
|
110
113
|
def configure(self) -> None:
|
|
111
|
-
self.EWM.apply_settings(
|
|
114
|
+
self.EWM.apply_settings(
|
|
115
|
+
EWMSettings(
|
|
116
|
+
axis=self.SETTINGS.axis,
|
|
117
|
+
zero_offset=self.SETTINGS.zero_offset,
|
|
118
|
+
)
|
|
119
|
+
)
|
|
112
120
|
|
|
113
121
|
self.WINDOW.apply_settings(
|
|
114
122
|
WindowSettings(
|
|
@@ -1,39 +1,107 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import typing
|
|
3
|
+
|
|
1
4
|
from dataclasses import dataclass, replace, field
|
|
2
5
|
|
|
3
6
|
import ezmsg.core as ez
|
|
4
7
|
import scipy.signal
|
|
8
|
+
|
|
5
9
|
import numpy as np
|
|
6
|
-
import
|
|
10
|
+
import numpy.typing as npt
|
|
7
11
|
|
|
8
12
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
9
|
-
|
|
10
|
-
from typing import AsyncGenerator, Optional, Tuple
|
|
11
|
-
|
|
13
|
+
from ezmsg.util.generator import consumer
|
|
12
14
|
|
|
13
15
|
@dataclass
|
|
14
16
|
class FilterCoefficients:
|
|
15
17
|
b: np.ndarray = field(default_factory=lambda: np.array([1.0, 0.0]))
|
|
16
18
|
a: np.ndarray = field(default_factory=lambda: np.array([1.0, 0.0]))
|
|
17
19
|
|
|
20
|
+
def _normalize_coefs(
|
|
21
|
+
coefs: typing.Union[FilterCoefficients, typing.Tuple[npt.NDArray, npt.NDArray],npt.NDArray]
|
|
22
|
+
) -> typing.Tuple[str, typing.Tuple[npt.NDArray,...]]:
|
|
23
|
+
coef_type = "ba"
|
|
24
|
+
if coefs is not None:
|
|
25
|
+
# scipy.signal functions called with first arg `*coefs`.
|
|
26
|
+
# Make sure we have a tuple of coefficients.
|
|
27
|
+
if isinstance(coefs, npt.NDArray):
|
|
28
|
+
coef_type = "sos"
|
|
29
|
+
coefs = (coefs,) # sos funcs just want a single ndarray.
|
|
30
|
+
elif isinstance(coefs, FilterCoefficients):
|
|
31
|
+
coefs = (FilterCoefficients.b, FilterCoefficients.a)
|
|
32
|
+
return coef_type, coefs
|
|
33
|
+
|
|
34
|
+
@consumer
|
|
35
|
+
def filtergen(
|
|
36
|
+
axis: str, coefs: typing.Optional[typing.Tuple[np.ndarray]], coef_type: str
|
|
37
|
+
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
38
|
+
# Massage inputs
|
|
39
|
+
if coefs is not None and not isinstance(coefs, tuple):
|
|
40
|
+
# scipy.signal functions called with first arg `*coefs`, but sos coefs are a single ndarray.
|
|
41
|
+
coefs = (coefs,)
|
|
42
|
+
|
|
43
|
+
# Init IO
|
|
44
|
+
axis_arr_in = AxisArray(np.array([]), dims=[""])
|
|
45
|
+
axis_arr_out = AxisArray(np.array([]), dims=[""])
|
|
46
|
+
|
|
47
|
+
filt_func = {"ba": scipy.signal.lfilter, "sos": scipy.signal.sosfilt}[coef_type]
|
|
48
|
+
zi_func = {"ba": scipy.signal.lfilter_zi, "sos": scipy.signal.sosfilt_zi}[coef_type]
|
|
49
|
+
|
|
50
|
+
# State variables
|
|
51
|
+
axis_idx = None
|
|
52
|
+
zi = None
|
|
53
|
+
expected_shape = None
|
|
54
|
+
|
|
55
|
+
while True:
|
|
56
|
+
axis_arr_in = yield axis_arr_out
|
|
57
|
+
|
|
58
|
+
if coefs is None:
|
|
59
|
+
# passthrough if we do not have a filter design.
|
|
60
|
+
axis_arr_out = axis_arr_in
|
|
61
|
+
continue
|
|
62
|
+
|
|
63
|
+
if axis_idx is None:
|
|
64
|
+
axis_name = axis_arr_in.dims[0] if axis is None else axis
|
|
65
|
+
axis_idx = axis_arr_in.get_axis_idx(axis_name)
|
|
66
|
+
|
|
67
|
+
dat_in = axis_arr_in.data
|
|
68
|
+
|
|
69
|
+
# Re-calculate/reset zi if necessary
|
|
70
|
+
samp_shape = dat_in.shape[:axis_idx] + dat_in.shape[axis_idx + 1 :]
|
|
71
|
+
if zi is None or samp_shape != expected_shape:
|
|
72
|
+
expected_shape = samp_shape
|
|
73
|
+
n_tail = dat_in.ndim - axis_idx - 1
|
|
74
|
+
zi = zi_func(*coefs)
|
|
75
|
+
zi_expand = (None,) * axis_idx + (slice(None),) + (None,) * n_tail
|
|
76
|
+
n_tile = dat_in.shape[:axis_idx] + (1,) + dat_in.shape[axis_idx + 1 :]
|
|
77
|
+
if coef_type == "sos":
|
|
78
|
+
# sos zi must keep its leading dimension (`order / 2` for low|high; `order` for bpass|bstop)
|
|
79
|
+
zi_expand = (slice(None),) + zi_expand
|
|
80
|
+
n_tile = (1,) + n_tile
|
|
81
|
+
zi = np.tile(zi[zi_expand], n_tile)
|
|
82
|
+
|
|
83
|
+
dat_out, zi = filt_func(*coefs, dat_in, axis=axis_idx, zi=zi)
|
|
84
|
+
axis_arr_out = replace(axis_arr_in, data=dat_out)
|
|
85
|
+
|
|
18
86
|
|
|
19
87
|
class FilterSettingsBase(ez.Settings):
|
|
20
|
-
axis: Optional[str] = None
|
|
21
|
-
fs: Optional[float] = None
|
|
88
|
+
axis: typing.Optional[str] = None
|
|
89
|
+
fs: typing.Optional[float] = None
|
|
22
90
|
|
|
23
91
|
|
|
24
92
|
class FilterSettings(FilterSettingsBase):
|
|
25
93
|
# If you'd like to statically design a filter, define it in settings
|
|
26
|
-
filt: Optional[FilterCoefficients] = None
|
|
94
|
+
filt: typing.Optional[FilterCoefficients] = None
|
|
27
95
|
|
|
28
96
|
|
|
29
97
|
class FilterState(ez.State):
|
|
30
|
-
axis: Optional[str] = None
|
|
31
|
-
zi: Optional[np.ndarray] = None
|
|
98
|
+
axis: typing.Optional[str] = None
|
|
99
|
+
zi: typing.Optional[np.ndarray] = None
|
|
32
100
|
filt_designed: bool = False
|
|
33
|
-
filt: Optional[FilterCoefficients] = None
|
|
101
|
+
filt: typing.Optional[FilterCoefficients] = None
|
|
34
102
|
filt_set: asyncio.Event = field(default_factory=asyncio.Event)
|
|
35
|
-
samp_shape: Optional[Tuple[int, ...]] = None
|
|
36
|
-
fs: Optional[float] = None # Hz
|
|
103
|
+
samp_shape: typing.Optional[typing.Tuple[int, ...]] = None
|
|
104
|
+
fs: typing.Optional[float] = None # Hz
|
|
37
105
|
|
|
38
106
|
|
|
39
107
|
class Filter(ez.Unit):
|
|
@@ -44,7 +112,7 @@ class Filter(ez.Unit):
|
|
|
44
112
|
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
45
113
|
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
46
114
|
|
|
47
|
-
def design_filter(self) -> Optional[Tuple[np.ndarray, np.ndarray]]:
|
|
115
|
+
def design_filter(self) -> typing.Optional[typing.Tuple[np.ndarray, np.ndarray]]:
|
|
48
116
|
raise NotImplementedError("Must implement 'design_filter' in Unit subclass!")
|
|
49
117
|
|
|
50
118
|
# Set up filter with static initialization if specified
|
|
@@ -84,7 +152,7 @@ class Filter(ez.Unit):
|
|
|
84
152
|
|
|
85
153
|
@ez.subscriber(INPUT_SIGNAL)
|
|
86
154
|
@ez.publisher(OUTPUT_SIGNAL)
|
|
87
|
-
async def apply_filter(self, msg: AxisArray) -> AsyncGenerator:
|
|
155
|
+
async def apply_filter(self, msg: AxisArray) -> typing.AsyncGenerator:
|
|
88
156
|
axis_name = msg.dims[0] if self.STATE.axis is None else self.STATE.axis
|
|
89
157
|
axis_idx = msg.get_axis_idx(axis_name)
|
|
90
158
|
axis = msg.get_axis(axis_name)
|