ezmsg-sigproc 1.2.1__tar.gz → 1.2.3__tar.gz

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. {ezmsg-sigproc-1.2.1 → ezmsg_sigproc-1.2.3}/PKG-INFO +15 -8
  2. ezmsg_sigproc-1.2.3/pyproject.toml +30 -0
  3. ezmsg_sigproc-1.2.3/src/ezmsg/sigproc/__init__.py +4 -0
  4. ezmsg_sigproc-1.2.3/src/ezmsg/sigproc/affinetransform.py +124 -0
  5. ezmsg_sigproc-1.2.3/src/ezmsg/sigproc/aggregate.py +103 -0
  6. ezmsg_sigproc-1.2.3/src/ezmsg/sigproc/bandpower.py +53 -0
  7. {ezmsg-sigproc-1.2.1 → ezmsg_sigproc-1.2.3/src}/ezmsg/sigproc/butterworthfilter.py +44 -6
  8. ezmsg_sigproc-1.2.3/src/ezmsg/sigproc/downsample.py +89 -0
  9. {ezmsg-sigproc-1.2.1 → ezmsg_sigproc-1.2.3/src}/ezmsg/sigproc/ewmfilter.py +11 -3
  10. {ezmsg-sigproc-1.2.1 → ezmsg_sigproc-1.2.3/src}/ezmsg/sigproc/filter.py +82 -14
  11. ezmsg_sigproc-1.2.3/src/ezmsg/sigproc/sampler.py +260 -0
  12. ezmsg_sigproc-1.2.3/src/ezmsg/sigproc/scaler.py +127 -0
  13. ezmsg_sigproc-1.2.3/src/ezmsg/sigproc/signalinjector.py +67 -0
  14. ezmsg_sigproc-1.2.3/src/ezmsg/sigproc/slicer.py +98 -0
  15. ezmsg_sigproc-1.2.3/src/ezmsg/sigproc/spectral.py +9 -0
  16. ezmsg_sigproc-1.2.3/src/ezmsg/sigproc/spectrogram.py +68 -0
  17. ezmsg_sigproc-1.2.3/src/ezmsg/sigproc/spectrum.py +158 -0
  18. {ezmsg-sigproc-1.2.1 → ezmsg_sigproc-1.2.3/src}/ezmsg/sigproc/synth.py +179 -80
  19. ezmsg_sigproc-1.2.3/src/ezmsg/sigproc/window.py +246 -0
  20. ezmsg-sigproc-1.2.1/ezmsg/sigproc/__init__.py +0 -1
  21. ezmsg-sigproc-1.2.1/ezmsg/sigproc/__version__.py +0 -1
  22. ezmsg-sigproc-1.2.1/ezmsg/sigproc/downsample.py +0 -63
  23. ezmsg-sigproc-1.2.1/ezmsg/sigproc/sampler.py +0 -287
  24. ezmsg-sigproc-1.2.1/ezmsg/sigproc/spectral.py +0 -132
  25. ezmsg-sigproc-1.2.1/ezmsg/sigproc/window.py +0 -144
  26. ezmsg-sigproc-1.2.1/ezmsg_sigproc.egg-info/PKG-INFO +0 -31
  27. ezmsg-sigproc-1.2.1/ezmsg_sigproc.egg-info/SOURCES.txt +0 -25
  28. ezmsg-sigproc-1.2.1/ezmsg_sigproc.egg-info/dependency_links.txt +0 -1
  29. ezmsg-sigproc-1.2.1/ezmsg_sigproc.egg-info/not-zip-safe +0 -1
  30. ezmsg-sigproc-1.2.1/ezmsg_sigproc.egg-info/requires.txt +0 -7
  31. ezmsg-sigproc-1.2.1/ezmsg_sigproc.egg-info/top_level.txt +0 -1
  32. ezmsg-sigproc-1.2.1/setup.cfg +0 -31
  33. ezmsg-sigproc-1.2.1/setup.py +0 -7
  34. ezmsg-sigproc-1.2.1/tests/test_butterworth.py +0 -142
  35. ezmsg-sigproc-1.2.1/tests/test_downsample.py +0 -132
  36. ezmsg-sigproc-1.2.1/tests/test_window.py +0 -139
  37. {ezmsg-sigproc-1.2.1 → ezmsg_sigproc-1.2.3}/LICENSE.txt +0 -0
  38. {ezmsg-sigproc-1.2.1 → ezmsg_sigproc-1.2.3}/README.md +0 -0
  39. {ezmsg-sigproc-1.2.1 → ezmsg_sigproc-1.2.3/src}/ezmsg/sigproc/decimate.py +0 -0
  40. {ezmsg-sigproc-1.2.1 → ezmsg_sigproc-1.2.3/src}/ezmsg/sigproc/messages.py +0 -0
@@ -1,16 +1,22 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ezmsg-sigproc
3
- Version: 1.2.1
3
+ Version: 1.2.3
4
4
  Summary: Timeseries signal processing implementations in ezmsg
5
- Home-page: https://github.com/iscoe/ezmsg
6
- Author: Griffin Milsap
7
- Author-email: griffin.milsap@jhuapl.edu
5
+ License: MIT
6
+ Author: Milsap, Griffin
7
+ Author-email: griffin.milsap@gmail.com
8
+ Requires-Python: >=3.8,<4.0
9
+ Classifier: License :: OSI Approved :: MIT License
8
10
  Classifier: Programming Language :: Python :: 3
9
- Classifier: Operating System :: OS Independent
10
- Requires-Python: >=3.8
11
+ Classifier: Programming Language :: Python :: 3.8
12
+ Classifier: Programming Language :: Python :: 3.9
13
+ Classifier: Programming Language :: Python :: 3.10
14
+ Classifier: Programming Language :: Python :: 3.11
15
+ Classifier: Programming Language :: Python :: 3.12
16
+ Requires-Dist: ezmsg (>=3.3.0,<4.0.0)
17
+ Requires-Dist: numpy (>=1.19.5,<2.0.0)
18
+ Requires-Dist: scipy (>=1.6.3,<2.0.0)
11
19
  Description-Content-Type: text/markdown
12
- Provides-Extra: test
13
- License-File: LICENSE.txt
14
20
 
15
21
  # ezmsg.sigproc
16
22
 
@@ -29,3 +35,4 @@ Timeseries signal processing implementations for ezmsg
29
35
  2. `cd` to this directory (`ezmsg-sigproc`) and run `pip install -e .`
30
36
  3. Signal processing modules are available under `import ezmsg.sigproc`
31
37
 
38
+
@@ -0,0 +1,30 @@
1
+ [tool.poetry]
2
+ name = "ezmsg-sigproc"
3
+ version = "1.2.3"
4
+ description = "Timeseries signal processing implementations in ezmsg"
5
+ authors = [
6
+ "Milsap, Griffin <griffin.milsap@gmail.com>",
7
+ "Peranich, Preston <pperanich@gmail.com>",
8
+ ]
9
+ license = "MIT"
10
+ readme = "README.md"
11
+ packages = [{ include = "ezmsg", from = "src" }]
12
+
13
+ [tool.poetry.dependencies]
14
+ python = "^3.8"
15
+ ezmsg = "^3.3.0"
16
+ numpy = "^1.19.5"
17
+ scipy = "^1.6.3"
18
+
19
+ [tool.poetry.group.test.dependencies]
20
+ pytest = "^7.0.0"
21
+ pytest-cov = "*"
22
+
23
+ [tool.pytest.ini_options]
24
+ addopts = ["--import-mode=importlib"]
25
+ pythonpath = ["src", "tests"]
26
+ testpaths = "tests"
27
+
28
+ [build-system]
29
+ requires = ["poetry-core"]
30
+ build-backend = "poetry.core.masonry.api"
@@ -0,0 +1,4 @@
1
+ import importlib.metadata
2
+
3
+
4
+ __version__ = importlib.metadata.version("ezmsg-sigproc")
@@ -0,0 +1,124 @@
1
+ from dataclasses import replace
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Generator, Optional, Union
5
+
6
+ import numpy as np
7
+ import ezmsg.core as ez
8
+ from ezmsg.util.messages.axisarray import AxisArray
9
+ from ezmsg.util.generator import consumer, GenAxisArray
10
+
11
+
12
+ @consumer
13
+ def affine_transform(
14
+ weights: Union[np.ndarray, str, Path],
15
+ axis: Optional[str] = None,
16
+ right_multiply: bool = True,
17
+ ) -> Generator[AxisArray, AxisArray, None]:
18
+ axis_arr_in = AxisArray(np.array([]), dims=[""])
19
+ axis_arr_out = AxisArray(np.array([]), dims=[""])
20
+
21
+ if isinstance(weights, str):
22
+ weights = Path(os.path.abspath(os.path.expanduser(weights)))
23
+ if isinstance(weights, Path):
24
+ weights = np.loadtxt(weights, delimiter=",")
25
+ if not right_multiply:
26
+ weights = weights.T
27
+ weights = np.ascontiguousarray(weights)
28
+
29
+ while True:
30
+ axis_arr_in = yield axis_arr_out
31
+
32
+ if axis is None:
33
+ axis = axis_arr_in.dims[-1]
34
+ axis_idx = -1
35
+ else:
36
+ axis_idx = axis_arr_in.get_axis_idx(axis)
37
+
38
+ data = axis_arr_in.data
39
+
40
+ if data.shape[axis_idx] == (weights.shape[0] - 1):
41
+ # The weights are stacked A|B where A is the transform and B is a single row
42
+ # in the equation y = Ax + B. This supports NeuroKey's weights matrices.
43
+ sample_shape = data.shape[:axis_idx] + (1,) + data.shape[axis_idx+1:]
44
+ data = np.concatenate((data, np.ones(sample_shape).astype(data.dtype)), axis=axis_idx)
45
+
46
+ if axis_idx in [-1, len(axis_arr_in.dims) - 1]:
47
+ data = np.matmul(data, weights)
48
+ else:
49
+ data = np.moveaxis(data, axis_idx, -1)
50
+ data = np.matmul(data, weights)
51
+ data = np.moveaxis(data, -1, axis_idx)
52
+ axis_arr_out = replace(axis_arr_in, data=data)
53
+
54
+
55
+ class AffineTransformSettings(ez.Settings):
56
+ weights: Union[np.ndarray, str, Path]
57
+ axis: Optional[str] = None
58
+ right_multiply: bool = True
59
+
60
+
61
+ class AffineTransform(GenAxisArray):
62
+ SETTINGS: AffineTransformSettings
63
+
64
+ def construct_generator(self):
65
+ self.STATE.gen = affine_transform(
66
+ weights=self.SETTINGS.weights,
67
+ axis=self.SETTINGS.axis,
68
+ right_multiply=self.SETTINGS.right_multiply,
69
+ )
70
+
71
+
72
+ @consumer
73
+ def common_rereference(
74
+ mode: str = "mean", axis: Optional[str] = None, include_current: bool = True
75
+ ) -> Generator[AxisArray, AxisArray, None]:
76
+ axis_arr_in = AxisArray(np.array([]), dims=[""])
77
+ axis_arr_out = AxisArray(np.array([]), dims=[""])
78
+
79
+ func = {"mean": np.mean, "median": np.median}[mode]
80
+
81
+ while True:
82
+ axis_arr_in = yield axis_arr_out
83
+
84
+ if axis is None:
85
+ axis = axis_arr_in.dims[-1]
86
+ axis_idx = -1
87
+ else:
88
+ axis_idx = axis_arr_in.get_axis_idx(axis)
89
+
90
+ ref_data = func(axis_arr_in.data, axis=axis_idx, keepdims=True)
91
+
92
+ if not include_current:
93
+ # Typical `CAR = x[0]/N + x[1]/N + ... x[i-1]/N + x[i]/N + x[i+1]/N + ... + x[N-1]/N`
94
+ # and is the same for all i, so it is calculated only once in `ref_data`.
95
+ # However, if we had excluded the current channel,
96
+ # then we would have omitted the contribution of the current channel:
97
+ # `CAR[i] = x[0]/(N-1) + x[1]/(N-1) + ... x[i-1]/(N-1) + x[i+1]/(N-1) + ... + x[N-1]/(N-1)`
98
+ # The majority of the calculation is the same as when the current channel is included;
99
+ # we need only rescale CAR so the divisor is `N-1` instead of `N`, then subtract the contribution
100
+ # from the current channel (i.e., `x[i] / (N-1)`)
101
+ # i.e., `CAR[i] = (N / (N-1)) * common_CAR - x[i]/(N-1)`
102
+ # We can use broadcasting subtraction instead of looping over channels.
103
+ N = axis_arr_in.data.shape[axis_idx]
104
+ ref_data = (N / (N - 1)) * ref_data - axis_arr_in.data / (N - 1)
105
+ # Side note: I profiled using affine_transform and it's about 30x slower than this implementation.
106
+
107
+ axis_arr_out = replace(axis_arr_in, data=axis_arr_in.data - ref_data)
108
+
109
+
110
+ class CommonRereferenceSettings(ez.Settings):
111
+ mode: str = "mean"
112
+ axis: Optional[str] = None
113
+ include_current: bool = True
114
+
115
+
116
+ class CommonRereference(GenAxisArray):
117
+ SETTINGS: CommonRereferenceSettings
118
+
119
+ def construct_generator(self):
120
+ self.STATE.gen = common_rereference(
121
+ mode=self.SETTINGS.mode,
122
+ axis=self.SETTINGS.axis,
123
+ include_current=self.SETTINGS.include_current,
124
+ )
@@ -0,0 +1,103 @@
1
+ from dataclasses import replace
2
+ import typing
3
+
4
+ import numpy as np
5
+ import ezmsg.core as ez
6
+ from ezmsg.util.generator import consumer, GenAxisArray
7
+ from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
8
+ from ezmsg.sigproc.spectral import OptionsEnum
9
+
10
+
11
+ class AggregationFunction(OptionsEnum):
12
+ NONE = "None (all)"
13
+ MAX = "max"
14
+ MIN = "min"
15
+ MEAN = "mean"
16
+ MEDIAN = "median"
17
+ STD = "std"
18
+ NANMAX = "nanmax"
19
+ NANMIN = "nanmin"
20
+ NANMEAN = "nanmean"
21
+ NANMEDIAN = "nanmedian"
22
+ NANSTD = "nanstd"
23
+
24
+
25
+ AGGREGATORS = {
26
+ AggregationFunction.NONE: np.all,
27
+ AggregationFunction.MAX: np.max,
28
+ AggregationFunction.MIN: np.min,
29
+ AggregationFunction.MEAN: np.mean,
30
+ AggregationFunction.MEDIAN: np.median,
31
+ AggregationFunction.STD: np.std,
32
+ AggregationFunction.NANMAX: np.nanmax,
33
+ AggregationFunction.NANMIN: np.nanmin,
34
+ AggregationFunction.NANMEAN: np.nanmean,
35
+ AggregationFunction.NANMEDIAN: np.nanmedian,
36
+ AggregationFunction.NANSTD: np.nanstd
37
+ }
38
+
39
+
40
+ @consumer
41
+ def ranged_aggregate(
42
+ axis: typing.Optional[str] = None,
43
+ bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = None,
44
+ operation: AggregationFunction = AggregationFunction.MEAN
45
+ ):
46
+ axis_arr_in = AxisArray(np.array([]), dims=[""])
47
+ axis_arr_out = AxisArray(np.array([]), dims=[""])
48
+
49
+ target_axis: typing.Optional[AxisArray.Axis] = None
50
+ out_axis = AxisArray.Axis()
51
+ slices: typing.Optional[typing.List[typing.Tuple[typing.Any, ...]]] = None
52
+ axis_name = ""
53
+
54
+ while True:
55
+ axis_arr_in = yield axis_arr_out
56
+ if bands is None:
57
+ axis_arr_out = axis_arr_in
58
+ else:
59
+ if slices is None or target_axis != axis_arr_in.get_axis(axis_name):
60
+ # Calculate the slices. If we are operating on time axis then
61
+ axis_name = axis or axis_arr_in.dims[0]
62
+ ax_idx = axis_arr_in.get_axis_idx(axis_name)
63
+ target_axis = axis_arr_in.axes[axis_name]
64
+
65
+ ax_vec = target_axis.offset + np.arange(axis_arr_in.data.shape[ax_idx]) * target_axis.gain
66
+ slices = []
67
+ mids = []
68
+ for (start, stop) in bands:
69
+ inds = np.where(np.logical_and(ax_vec >= start, ax_vec <= stop))[0]
70
+ mids.append(np.mean(inds) * target_axis.gain + target_axis.offset)
71
+ slices.append(np.s_[inds[0]:inds[-1] + 1])
72
+ out_axis = AxisArray.Axis(
73
+ unit=target_axis.unit, offset=mids[0], gain=(mids[1] - mids[0]) if len(mids) > 1 else 1.0
74
+ )
75
+
76
+ agg_func = AGGREGATORS[operation]
77
+ out_data = [
78
+ agg_func(slice_along_axis(axis_arr_in.data, sl, axis=ax_idx), axis=ax_idx)
79
+ for sl in slices
80
+ ]
81
+ new_axes = {**axis_arr_in.axes, axis_name: out_axis}
82
+ axis_arr_out = replace(
83
+ axis_arr_in,
84
+ data=np.stack(out_data, axis=ax_idx),
85
+ axes=new_axes
86
+ )
87
+
88
+
89
+ class RangedAggregateSettings(ez.Settings):
90
+ axis: typing.Optional[str] = None
91
+ bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = None
92
+ operation: AggregationFunction = AggregationFunction.MEAN
93
+
94
+
95
+ class RangedAggregate(GenAxisArray):
96
+ SETTINGS: RangedAggregateSettings
97
+
98
+ def construct_generator(self):
99
+ self.STATE.gen = ranged_aggregate(
100
+ axis=self.SETTINGS.axis,
101
+ bands=self.SETTINGS.bands,
102
+ operation=self.SETTINGS.operation
103
+ )
@@ -0,0 +1,53 @@
1
+ from dataclasses import field
2
+ import typing
3
+
4
+ import numpy as np
5
+ import ezmsg.core as ez
6
+ from ezmsg.util.messages.axisarray import AxisArray
7
+ from ezmsg.util.generator import consumer, compose, GenAxisArray
8
+
9
+ from .spectrogram import spectrogram, SpectrogramSettings
10
+ from .aggregate import ranged_aggregate, AggregationFunction
11
+
12
+
13
+ @consumer
14
+ def bandpower(
15
+ spectrogram_settings: SpectrogramSettings,
16
+ bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = [(17, 30), (70, 170)]
17
+ ) -> typing.Generator[AxisArray, AxisArray, None]:
18
+ axis_arr_in = AxisArray(np.array([]), dims=[""])
19
+ axis_arr_out = AxisArray(np.array([]), dims=[""])
20
+
21
+ f_spec = spectrogram(
22
+ window_dur=spectrogram_settings.window_dur,
23
+ window_shift=spectrogram_settings.window_shift,
24
+ window=spectrogram_settings.window,
25
+ transform=spectrogram_settings.transform,
26
+ output=spectrogram_settings.output
27
+ )
28
+ f_agg = ranged_aggregate(
29
+ axis="freq",
30
+ bands=bands,
31
+ operation=AggregationFunction.MEAN
32
+ )
33
+ pipeline = compose(f_spec, f_agg)
34
+
35
+ while True:
36
+ axis_arr_in = yield axis_arr_out
37
+ axis_arr_out = pipeline(axis_arr_in)
38
+
39
+
40
+ class BandPowerSettings(ez.Settings):
41
+ spectrogram_settings: SpectrogramSettings = field(default_factory=SpectrogramSettings)
42
+ bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = (
43
+ field(default_factory=lambda: [(17, 30), (70, 170)]))
44
+
45
+
46
+ class BandPower(GenAxisArray):
47
+ SETTINGS: BandPowerSettings
48
+
49
+ def construct_generator(self):
50
+ self.STATE.gen = bandpower(
51
+ spectrogram_settings=self.SETTINGS.spectrogram_settings,
52
+ bands=self.SETTINGS.bands
53
+ )
@@ -1,18 +1,21 @@
1
+ import typing
2
+
1
3
  import ezmsg.core as ez
2
4
  import scipy.signal
3
5
  import numpy as np
4
6
 
5
- from .filter import Filter, FilterState, FilterSettingsBase
7
+ from .filter import filtergen, Filter, FilterState, FilterSettingsBase
6
8
 
7
- from typing import Optional, Tuple, Union
9
+ from ezmsg.util.messages.axisarray import AxisArray
10
+ from ezmsg.util.generator import consumer
8
11
 
9
12
 
10
13
  class ButterworthFilterSettings(FilterSettingsBase):
11
14
  order: int = 0
12
- cuton: Optional[float] = None # Hz
13
- cutoff: Optional[float] = None # Hz
15
+ cuton: typing.Optional[float] = None # Hz
16
+ cutoff: typing.Optional[float] = None # Hz
14
17
 
15
- def filter_specs(self) -> Optional[Tuple[str, Union[float, Tuple[float, float]]]]:
18
+ def filter_specs(self) -> typing.Optional[typing.Tuple[str, typing.Union[float, typing.Tuple[float, float]]]]:
16
19
  if self.cuton is None and self.cutoff is None:
17
20
  return None
18
21
  elif self.cuton is None and self.cutoff is not None:
@@ -26,6 +29,38 @@ class ButterworthFilterSettings(FilterSettingsBase):
26
29
  return "bandstop", (self.cutoff, self.cuton)
27
30
 
28
31
 
32
+ @consumer
33
+ def butter(
34
+ axis: typing.Optional[str],
35
+ order: int = 0,
36
+ cuton: typing.Optional[float] = None,
37
+ cutoff: typing.Optional[float] = None,
38
+ coef_type: str = "ba",
39
+ ) -> typing.Generator[AxisArray, AxisArray, None]:
40
+ # IO
41
+ axis_arr_in = AxisArray(np.array([]), dims=[""])
42
+ axis_arr_out = AxisArray(np.array([]), dims=[""])
43
+
44
+ btype, cutoffs = ButterworthFilterSettings(
45
+ order=order, cuton=cuton, cutoff=cutoff
46
+ ).filter_specs()
47
+
48
+ # We cannot calculate coefs yet because we do not know input sample rate
49
+ coefs = None
50
+ filter_gen = filtergen(axis, coefs, coef_type) # Passthrough.
51
+
52
+ while True:
53
+ axis_arr_in = yield axis_arr_out
54
+ if coefs is None and order > 0:
55
+ fs = 1 / axis_arr_in.axes[axis or axis_arr_in.dims[0]].gain
56
+ coefs = scipy.signal.butter(
57
+ order, Wn=cutoffs, btype=btype, fs=fs, output=coef_type
58
+ )
59
+ filter_gen = filtergen(axis, coefs, coef_type)
60
+
61
+ axis_arr_out = filter_gen.send(axis_arr_in)
62
+
63
+
29
64
  class ButterworthFilterState(FilterState):
30
65
  design: ButterworthFilterSettings
31
66
 
@@ -41,7 +76,7 @@ class ButterworthFilter(Filter):
41
76
  self.STATE.filt_designed = True
42
77
  super().initialize()
43
78
 
44
- def design_filter(self) -> Optional[Tuple[np.ndarray, np.ndarray]]:
79
+ def design_filter(self) -> typing.Optional[typing.Tuple[np.ndarray, np.ndarray]]:
45
80
  specs = self.STATE.design.filter_specs()
46
81
  if self.STATE.design.order > 0 and specs is not None:
47
82
  btype, cut = specs
@@ -57,6 +92,9 @@ class ButterworthFilter(Filter):
57
92
 
58
93
  @ez.subscriber(INPUT_FILTER)
59
94
  async def redesign(self, message: ButterworthFilterSettings) -> None:
95
+ if type(message) is not ButterworthFilterSettings:
96
+ return
97
+
60
98
  if self.STATE.design.order != message.order:
61
99
  self.STATE.zi = None
62
100
  self.STATE.design = message
@@ -0,0 +1,89 @@
1
+ from dataclasses import replace
2
+ import traceback
3
+ from typing import AsyncGenerator, Optional, Generator
4
+
5
+ import numpy as np
6
+
7
+ from ezmsg.util.messages.axisarray import AxisArray
8
+ from ezmsg.util.generator import consumer
9
+ import ezmsg.core as ez
10
+
11
+
12
+ @consumer
13
+ def downsample(
14
+ axis: Optional[str] = None, factor: int = 1
15
+ ) -> Generator[AxisArray, AxisArray, None]:
16
+ axis_arr_in = AxisArray(np.array([]), dims=[""])
17
+ axis_arr_out = AxisArray(np.array([]), dims=[""])
18
+
19
+ # state variables
20
+ s_idx = 0
21
+
22
+ while True:
23
+ axis_arr_in = yield axis_arr_out
24
+
25
+ if axis is None:
26
+ axis = axis_arr_in.dims[0]
27
+ axis_info = axis_arr_in.get_axis(axis)
28
+ axis_idx = axis_arr_in.get_axis_idx(axis)
29
+
30
+ samples = np.arange(axis_arr_in.data.shape[axis_idx]) + s_idx
31
+ samples = samples % factor
32
+ s_idx = samples[-1] + 1
33
+
34
+ pub_samples = np.where(samples == 0)[0]
35
+ if len(pub_samples) > 0:
36
+ new_axes = {ax_name: axis_arr_in.get_axis(ax_name) for ax_name in axis_arr_in.dims}
37
+ new_offset = axis_info.offset + (axis_info.gain * pub_samples[0].item())
38
+ new_gain = axis_info.gain * factor
39
+ new_axes[axis] = replace(axis_info, gain=new_gain, offset=new_offset)
40
+ down_data = np.take(axis_arr_in.data, pub_samples, axis=axis_idx)
41
+ axis_arr_out = replace(axis_arr_in, data=down_data, dims=axis_arr_in.dims, axes=new_axes)
42
+ else:
43
+ axis_arr_out = None
44
+
45
+
46
+ class DownsampleSettings(ez.Settings):
47
+ axis: Optional[str] = None
48
+ factor: int = 1
49
+
50
+
51
+ class DownsampleState(ez.State):
52
+ cur_settings: DownsampleSettings
53
+ gen: Generator
54
+
55
+
56
+ class Downsample(ez.Unit):
57
+ SETTINGS: DownsampleSettings
58
+ STATE: DownsampleState
59
+
60
+ INPUT_SETTINGS = ez.InputStream(DownsampleSettings)
61
+ INPUT_SIGNAL = ez.InputStream(AxisArray)
62
+ OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
63
+
64
+ def construct_generator(self):
65
+ self.STATE.gen = downsample(axis=self.STATE.cur_settings.axis, factor=self.STATE.cur_settings.factor)
66
+
67
+ def initialize(self) -> None:
68
+ self.STATE.cur_settings = self.SETTINGS
69
+ self.construct_generator()
70
+
71
+ @ez.subscriber(INPUT_SETTINGS)
72
+ async def on_settings(self, msg: DownsampleSettings) -> None:
73
+ self.STATE.cur_settings = msg
74
+ self.construct_generator()
75
+
76
+ @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
77
+ @ez.publisher(OUTPUT_SIGNAL)
78
+ async def on_signal(self, msg: AxisArray) -> AsyncGenerator:
79
+ if self.STATE.cur_settings.factor < 1:
80
+ raise ValueError("Downsample factor must be at least 1 (no downsampling)")
81
+
82
+ try:
83
+ out_msg = self.STATE.gen.send(msg)
84
+ if out_msg is not None:
85
+ yield self.OUTPUT_SIGNAL, out_msg
86
+ except (StopIteration, GeneratorExit):
87
+ ez.logger.debug(f"Downsample closed in {self.address}")
88
+ except Exception:
89
+ ez.logger.info(traceback.format_exc())
@@ -73,9 +73,12 @@ class EWM(ez.Unit):
73
73
  buffer_data = buffer.data
74
74
  buffer_data = np.moveaxis(buffer_data, axis_idx, 0)
75
75
 
76
+ while scale_arr.ndim < buffer_data.ndim:
77
+ scale_arr = scale_arr[..., None]
78
+
76
79
  def ewma(data: np.ndarray) -> np.ndarray:
77
- mult = scale_arr[:, np.newaxis] * data * pw0
78
- out = scale_arr[::-1, np.newaxis] * mult.cumsum(axis=0)
80
+ mult = scale_arr * data * pw0
81
+ out = scale_arr[::-1] * mult.cumsum(axis=0)
79
82
 
80
83
  if not self.SETTINGS.zero_offset:
81
84
  out = (data[0, :, np.newaxis] * pows[1:]).T + out
@@ -108,7 +111,12 @@ class EWMFilter(ez.Collection):
108
111
  EWM = EWM()
109
112
 
110
113
  def configure(self) -> None:
111
- self.EWM.apply_settings(EWMSettings(axis=self.SETTINGS.axis, zero_offset=True))
114
+ self.EWM.apply_settings(
115
+ EWMSettings(
116
+ axis=self.SETTINGS.axis,
117
+ zero_offset=self.SETTINGS.zero_offset,
118
+ )
119
+ )
112
120
 
113
121
  self.WINDOW.apply_settings(
114
122
  WindowSettings(