ezmsg-sigproc 2.3.0__tar.gz → 2.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/PKG-INFO +1 -1
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/__version__.py +2 -2
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/aggregate.py +69 -0
- ezmsg_sigproc-2.4.0/tests/unit/test_aggregate.py +411 -0
- ezmsg_sigproc-2.3.0/tests/unit/test_aggregate.py +0 -161
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/.github/workflows/docs.yml +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/.github/workflows/python-publish-ezmsg-sigproc.yml +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/.github/workflows/python-tests.yml +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/.gitignore +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/.pre-commit-config.yaml +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/LICENSE.txt +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/README.md +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/Makefile +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/make.bat +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/_templates/autosummary/module.rst +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/api/index.rst +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/conf.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/HybridBuffer.md +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/ProcessorsBase.md +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/explanations/sigproc.rst +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/how-tos/signalprocessing/adaptive.rst +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/how-tos/signalprocessing/checkpoint.rst +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/how-tos/signalprocessing/composite.rst +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/how-tos/signalprocessing/content-signalprocessing.rst +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/how-tos/signalprocessing/processor.rst +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/how-tos/signalprocessing/standalone.rst +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/how-tos/signalprocessing/stateful.rst +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/how-tos/signalprocessing/unit.rst +0 -0
- {ezmsg_sigproc-2.3.0/docs → ezmsg_sigproc-2.4.0/docs/source/guides}/img/HybridBufferBasic.svg +0 -0
- {ezmsg_sigproc-2.3.0/docs → ezmsg_sigproc-2.4.0/docs/source/guides}/img/HybridBufferOverflow.svg +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/sigproc/base.rst +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/sigproc/content-sigproc.rst +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/sigproc/processors.rst +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/sigproc/units.rst +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/tutorials/signalprocessing.rst +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/index.rst +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/pyproject.toml +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/__init__.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/activation.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/adaptive_lattice_notch.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/affinetransform.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/bandpower.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/base.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/butterworthfilter.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/cheby.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/combfilter.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/decimate.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/denormalize.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/detrend.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/diff.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/downsample.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/ewma.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/ewmfilter.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/extract_axis.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/fbcca.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/filter.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/filterbank.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/filterbankdesign.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/firfilter.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/gaussiansmoothing.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/kaiser.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/math/__init__.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/math/abs.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/math/clip.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/math/difference.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/math/invert.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/math/log.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/math/scale.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/messages.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/quantize.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/resample.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/sampler.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/scaler.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/signalinjector.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/slicer.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/spectral.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/spectrogram.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/spectrum.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/synth.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/transpose.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/util/__init__.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/util/asio.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/util/axisarray_buffer.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/util/buffer.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/util/message.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/util/profile.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/util/sparse.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/util/typeresolution.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/wavelets.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/window.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/__init__.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/conftest.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/helpers/__init__.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/helpers/util.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/integration/bytewax/test_spectrum_bytewax.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/integration/bytewax/test_window_bytewax.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/integration/ezmsg/test_butterworth_system.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/integration/ezmsg/test_decimate_system.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/integration/ezmsg/test_downsample_system.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/integration/ezmsg/test_filter_system.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/integration/ezmsg/test_sampler_system.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/integration/ezmsg/test_scaler_system.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/integration/ezmsg/test_spectrum_system.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/integration/ezmsg/test_synth_system.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/integration/ezmsg/test_window_system.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/resources/xform.csv +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/test_profile.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/buffer/test_axisarray_buffer.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/buffer/test_buffer.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/buffer/test_buffer_overflow.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_activation.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_adaptive_lattice_notch.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_affine_transform.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_bandpower.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_base.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_butter.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_combfilter.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_denormalize.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_diff.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_downsample.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_ewma.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_extract_axis.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_fbcca.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_filter.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_filterbank.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_filterbankdesign.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_firfilter.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_gaussian_smoothing_filter.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_kaiser.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_math.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_quantize.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_resample.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_sampler.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_scaler.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_slicer.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_spectrogram.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_spectrum.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_synth.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_transpose.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_util.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_wavelets.py +0 -0
- {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_window.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ezmsg-sigproc
|
|
3
|
-
Version: 2.3.0
|
|
3
|
+
Version: 2.4.0
|
|
4
4
|
Summary: Timeseries signal processing implementations in ezmsg
|
|
5
5
|
Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '2.
|
|
32
|
-
__version_tuple__ = version_tuple = (2,
|
|
31
|
+
__version__ = version = '2.4.0'
|
|
32
|
+
__version_tuple__ = version_tuple = (2, 4, 0)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from array_api_compat import get_namespace
|
|
1
2
|
import typing
|
|
2
3
|
|
|
3
4
|
import numpy as np
|
|
@@ -12,6 +13,7 @@ from ezmsg.util.messages.axisarray import (
|
|
|
12
13
|
|
|
13
14
|
from .spectral import OptionsEnum
|
|
14
15
|
from .base import (
|
|
16
|
+
BaseTransformer,
|
|
15
17
|
BaseStatefulTransformer,
|
|
16
18
|
BaseTransformerUnit,
|
|
17
19
|
processor_state,
|
|
@@ -213,3 +215,70 @@ def ranged_aggregate(
|
|
|
213
215
|
return RangedAggregateTransformer(
|
|
214
216
|
RangedAggregateSettings(axis=axis, bands=bands, operation=operation)
|
|
215
217
|
)
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
class AggregateSettings(ez.Settings):
    """Settings for :obj:`AggregateTransformer` and :obj:`AggregateUnit`."""

    axis: str
    """The name of the axis to aggregate over. This axis will be removed from the output."""

    operation: AggregationFunction = AggregationFunction.MEAN
    """:obj:`AggregationFunction` to apply. ``AggregationFunction.NONE`` is rejected when a message is processed."""
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
class AggregateTransformer(BaseTransformer[AggregateSettings, AxisArray, AxisArray]):
    """
    Transformer that aggregates an entire axis using a specified operation.

    Unlike :obj:`RangedAggregateTransformer` which aggregates over specific ranges/bands
    and preserves the axis (with one value per band), this transformer aggregates the
    entire axis and removes it from the output, reducing dimensionality by one.

    Besides the axis keyed by the aggregated dimension's name, any other entry in
    ``axes`` whose own ``dims`` include that dimension (e.g. a multi-dimensional
    coordinate axis) is dropped as well, because its coordinates no longer line up
    with the reduced data.
    """

    def _process(self, message: AxisArray) -> AxisArray:
        """
        Reduce ``message.data`` over ``settings.axis`` and remove that dimension.

        Args:
            message: Input message; it is not modified.

        Returns:
            A new :obj:`AxisArray` holding the aggregated data, with the target
            dimension removed from ``dims`` and every axis tied to it removed
            from ``axes``.

        Raises:
            ValueError: If ``settings.operation`` is ``AggregationFunction.NONE``.
        """
        xp = get_namespace(message.data)
        dim = self.settings.axis
        axis_idx = message.get_axis_idx(dim)
        op = self.settings.operation

        if op == AggregationFunction.NONE:
            raise ValueError(
                "AggregationFunction.NONE is not supported for full-axis aggregation"
            )

        if op == AggregationFunction.TRAPEZOID:
            # Integration needs sample positions along the axis, not just values.
            x = self._trapezoid_coords(message, dim, axis_idx)
            agg_data = np.trapezoid(np.asarray(message.data), x=x, axis=axis_idx)
        elif hasattr(xp, op.value):
            # Prefer the array-API namespace function of the same name.
            agg_data = getattr(xp, op.value)(message.data, axis=axis_idx)
        else:
            # Otherwise fall back to this module's AGGREGATORS table.
            agg_data = AGGREGATORS[op](message.data, axis=axis_idx)

        new_dims = [d for i, d in enumerate(message.dims) if i != axis_idx]
        # Drop the aggregated axis and any axis whose coordinates span it.
        new_axes = {
            name: ax
            for name, ax in message.axes.items()
            if name != dim and not self._axis_spans_dim(ax, dim)
        }

        return replace(
            message,
            data=agg_data,
            dims=new_dims,
            axes=new_axes,
        )

    @staticmethod
    def _trapezoid_coords(message: AxisArray, dim: str, axis_idx: int):
        """Return x-coordinates along ``dim`` for ``np.trapezoid`` (``None`` means unit spacing)."""
        if dim not in message.axes:
            # No axis metadata for this dimension: integrate with unit sample spacing.
            return None
        target_axis = message.get_axis(dim)
        if hasattr(target_axis, "data"):
            # Coordinate-style axis: explicit (possibly non-uniform) positions.
            return target_axis.data
        # Otherwise derive positions from sample indices via the axis' value().
        return target_axis.value(np.arange(message.data.shape[axis_idx]))

    @staticmethod
    def _axis_spans_dim(ax, dim: str) -> bool:
        """Return True when ``ax`` declares a ``dims`` attribute that includes ``dim``."""
        ax_dims = getattr(ax, "dims", None)
        return ax_dims is not None and dim in ax_dims
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
class AggregateUnit(
    BaseTransformerUnit[AggregateSettings, AxisArray, AxisArray, AggregateTransformer]
):
    """
    ezmsg Unit wrapper around :obj:`AggregateTransformer`.

    Collapses one whole axis of its input with the :obj:`AggregationFunction`
    chosen in :obj:`AggregateSettings`.
    """

    SETTINGS = AggregateSettings
|
|
@@ -0,0 +1,411 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
from functools import partial
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import pytest
|
|
6
|
+
from frozendict import frozendict
|
|
7
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
8
|
+
|
|
9
|
+
from ezmsg.sigproc.aggregate import (
|
|
10
|
+
ranged_aggregate,
|
|
11
|
+
AggregationFunction,
|
|
12
|
+
AggregateTransformer,
|
|
13
|
+
AggregateSettings,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
from tests.helpers.util import assert_messages_equal
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_msg_gen(n_chans=20, n_freqs=100, data_dur=30.0, fs=1024.0, key=""):
    """
    Build a generator of consecutive ("time", "ch", "freq") AxisArray chunks.

    The full dataset is an integer ramp (``np.arange``) of shape
    ``(int(data_dur * fs), n_chans, n_freqs)``. It is split along time into
    ``n_samples // int(data_dur / 2)`` chunks. Each chunk carries a TimeAxis whose
    offset continues where the previous chunk ended, plus a freq LinearAxis
    (gain 1.0, offset 0.0, unit "Hz"). No "ch" axis is attached.
    """
    n_samples = int(data_dur * fs)
    full = np.arange(n_samples * n_chans * n_freqs).reshape(n_samples, n_chans, n_freqs)
    divisor = int(data_dur / 2)

    def msg_generator():
        t0 = 0
        for chunk in np.array_split(full, n_samples // divisor):
            chunk_axes = frozendict(
                {
                    "time": AxisArray.TimeAxis(fs=fs, offset=t0),
                    "freq": AxisArray.LinearAxis(gain=1.0, offset=0.0, unit="Hz"),
                }
            )
            # Advance the running offset for the next chunk before handing this one out.
            t0 += chunk.shape[0] / fs
            yield AxisArray(
                data=chunk,
                dims=["time", "ch", "freq"],
                axes=chunk_axes,
                key=key,
            )

    return msg_generator()
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@pytest.mark.parametrize(
    "agg_func",
    [
        AggregationFunction.MEAN,
        AggregationFunction.MEDIAN,
        AggregationFunction.STD,
        AggregationFunction.SUM,
    ],
)
def test_aggregate(agg_func: AggregationFunction):
    """ranged_aggregate reduces each band to one value labelled by the band center."""
    bands = [(5.0, 20.0), (30.0, 50.0)]
    targ_ax_name = "freq"

    in_msgs = list(get_msg_gen())

    # Grab a deepcopy backup of the inputs so we can check the inputs didn't change
    # while being processed.
    backup = [copy.deepcopy(_) for _ in in_msgs]

    gen = ranged_aggregate(axis=targ_ax_name, bands=bands, operation=agg_func)
    out_msgs = [gen.send(_) for _ in in_msgs]

    assert_messages_equal(in_msgs, backup)

    assert all(type(_) is AxisArray for _ in out_msgs)

    # Check output axis: one coordinate per band, at the band center, same unit.
    in_targ_ax = in_msgs[0].axes[targ_ax_name]
    band_centers = np.array([np.mean(band) for band in bands])
    for out_msg in out_msgs:
        ax = out_msg.axes[targ_ax_name]
        assert np.array_equal(ax.data, band_centers)
        assert ax.unit == in_targ_ax.unit

    # Check data against numpy applied to each band (inclusive bounds).
    # A separate name keeps the `agg_func` parameter from being shadowed.
    reduce_band = {
        AggregationFunction.MEAN: partial(np.mean, axis=-1, keepdims=True),
        AggregationFunction.MEDIAN: partial(np.median, axis=-1, keepdims=True),
        AggregationFunction.STD: partial(np.std, axis=-1, keepdims=True),
        AggregationFunction.SUM: partial(np.sum, axis=-1, keepdims=True),
    }[agg_func]
    data = AxisArray.concatenate(*in_msgs, dim="time").data
    targ_ax_vec = in_targ_ax.value(np.arange(data.shape[-1]))
    expected_data = np.concatenate(
        [
            reduce_band(
                data[..., np.logical_and(targ_ax_vec >= start, targ_ax_vec <= stop)]
            )
            for (start, stop) in bands
        ],
        axis=-1,
    )
    received_data = AxisArray.concatenate(*out_msgs, dim="time").data
    assert np.allclose(received_data, expected_data)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@pytest.mark.parametrize(
    "agg_func", [AggregationFunction.ARGMIN, AggregationFunction.ARGMAX]
)
def test_arg_aggregate(agg_func: AggregationFunction):
    """
    Arg-reductions over each band land on the band edges.

    The test data increase monotonically along freq, so ARGMIN picks each band's
    lower bound and ARGMAX its upper bound.
    """
    bands = [(5.0, 20.0), (30.0, 50.0)]
    in_msgs = list(get_msg_gen())
    gen = ranged_aggregate(axis="freq", bands=bands, operation=agg_func)
    out_msgs = [gen.send(msg) for msg in in_msgs]

    pick_edge = np.min if agg_func == AggregationFunction.ARGMIN else np.max
    expected_vals = np.array([pick_edge(band) for band in bands])

    out_dat = AxisArray.concatenate(*out_msgs, dim="time").data
    expected_dat = np.zeros(out_dat.shape[:-1] + (1,)) + expected_vals[None, None, :]
    assert np.array_equal(out_dat, expected_dat)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def test_trapezoid():
    """ranged_aggregate with TRAPEZOID integrates each band over its freq coordinates."""
    bands = [(5.0, 20.0), (30.0, 50.0)]
    in_msgs = list(get_msg_gen())
    gen = ranged_aggregate(
        axis="freq", bands=bands, operation=AggregationFunction.TRAPEZOID
    )
    out_dat = AxisArray.concatenate(*[gen.send(msg) for msg in in_msgs], dim="time").data

    # Reference: integrate the concatenated input band-by-band with numpy.
    in_data = AxisArray.concatenate(*in_msgs, dim="time").data
    freqs = in_msgs[0].axes["freq"].value(np.arange(in_data.shape[-1]))
    masks = [(freqs >= lo) & (freqs <= hi) for lo, hi in bands]
    expected = np.stack(
        [np.trapezoid(in_data[..., m], x=freqs[m], axis=-1) for m in masks],
        axis=-1,
    )

    assert out_dat.shape == expected.shape
    assert np.allclose(out_dat, expected)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
@pytest.mark.parametrize("change_ax", ["ch", "freq"])
def test_aggregate_handle_change(change_ax: str):
    """
    ranged_aggregate must adapt when the input layout changes mid-stream.

    The first stream has 20 channels x 100 freqs. The second changes either the
    channel count (to 17) or the freq count (to 70). A transformer that kept stale
    band indices across the change would be expected to fail, or to give wrong
    values, for 'freq'.
    """
    bands = [(5.0, 20.0), (30.0, 50.0)]
    n_chans2 = 17 if change_ax == "ch" else 20
    n_freqs2 = 70 if change_ax == "freq" else 100

    in_msgs1 = list(get_msg_gen(n_chans=20, n_freqs=100))
    in_msgs2 = list(get_msg_gen(n_chans=n_chans2, n_freqs=n_freqs2))

    gen = ranged_aggregate(
        axis="freq",
        bands=bands,
        operation=AggregationFunction.MEAN,
    )

    out_msgs1 = [gen.send(_) for _ in in_msgs1]
    out_msgs2 = [gen.send(_) for _ in in_msgs2]

    # One output per input, before and after the change.
    assert len(out_msgs1) == len(in_msgs1)
    assert len(out_msgs2) == len(in_msgs2)

    # After the change, every output must reflect the new layout: (time, ch, band).
    # Check per message to avoid concatenating the large test arrays.
    freqs2 = in_msgs2[0].axes["freq"].value(np.arange(n_freqs2))
    masks = [(freqs2 >= lo) & (freqs2 <= hi) for lo, hi in bands]
    for msg_in, msg_out in zip(in_msgs2, out_msgs2):
        expected = np.stack(
            [msg_in.data[..., m].mean(axis=-1) for m in masks], axis=-1
        )
        assert msg_out.data.shape == expected.shape
        assert np.allclose(msg_out.data, expected)
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
# ============== Tests for AggregateTransformer ==============
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def get_simple_msg(n_times=10, n_chans=5, n_freqs=8, fs=100.0):
    """
    Return a small float-valued ("time", "ch", "freq") AxisArray for AggregateTransformer tests.

    The data are a C-ordered ramp 0, 1, 2, ... The axes are:
    - a TimeAxis at ``fs`` starting at 0;
    - string channel labels ``ch0 .. ch{n_chans-1}``;
    - a freq LinearAxis with gain 2.0 and offset 1.0 (unit "Hz").
    """
    ramp = np.arange(n_times * n_chans * n_freqs, dtype=float).reshape(
        n_times, n_chans, n_freqs
    )
    labels = np.array([f"ch{i}" for i in range(n_chans)])
    axes = frozendict(
        {
            "time": AxisArray.TimeAxis(fs=fs, offset=0.0),
            "ch": AxisArray.CoordinateAxis(data=labels, dims=["ch"]),
            "freq": AxisArray.LinearAxis(gain=2.0, offset=1.0, unit="Hz"),
        }
    )
    return AxisArray(data=ramp, dims=["time", "ch", "freq"], axes=axes)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
@pytest.mark.parametrize(
    "operation",
    [
        AggregationFunction.MEAN,
        AggregationFunction.SUM,
        AggregationFunction.MAX,
        AggregationFunction.MIN,
        AggregationFunction.STD,
        AggregationFunction.MEDIAN,
    ],
)
def test_aggregate_transformer_basic(operation: AggregationFunction):
    """Each basic reduction over 'freq' matches numpy and removes the freq axis."""
    msg_in = get_simple_msg()
    snapshot = copy.deepcopy(msg_in)

    msg_out = AggregateTransformer(
        AggregateSettings(axis="freq", operation=operation)
    )(msg_in)

    # The transformer must leave its input untouched.
    assert_messages_equal([msg_in], [snapshot])

    assert isinstance(msg_out, AxisArray)

    # 'freq' is gone from both dims and axes; the remaining dims keep their order.
    assert "freq" not in msg_out.dims
    assert "freq" not in msg_out.axes
    assert msg_out.dims == ["time", "ch"]
    assert msg_out.data.shape == (10, 5)

    # Values agree with the numpy function of the same name.
    reference = getattr(np, operation.value)(msg_in.data, axis=2)
    assert np.allclose(msg_out.data, reference)
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
@pytest.mark.parametrize("axis", ["time", "ch", "freq"])
def test_aggregate_transformer_different_axes(axis: str):
    """MEAN can collapse any one of the three dims, leaving the other two in order."""
    msg_in = get_simple_msg(n_times=10, n_chans=5, n_freqs=8)

    msg_out = AggregateTransformer(
        AggregateSettings(axis=axis, operation=AggregationFunction.MEAN)
    )(msg_in)

    assert axis not in msg_out.dims
    assert axis not in msg_out.axes
    assert msg_out.dims == [d for d in ("time", "ch", "freq") if d != axis]

    idx = msg_in.get_axis_idx(axis)
    in_shape = msg_in.data.shape
    assert msg_out.data.shape == in_shape[:idx] + in_shape[idx + 1 :]
    assert np.allclose(msg_out.data, np.mean(msg_in.data, axis=idx))
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def test_aggregate_transformer_none_raises():
    """Requesting AggregationFunction.NONE is rejected with a ValueError."""
    settings = AggregateSettings(axis="freq", operation=AggregationFunction.NONE)
    transformer = AggregateTransformer(settings)
    msg_in = get_simple_msg()

    with pytest.raises(ValueError, match="NONE is not supported"):
        transformer(msg_in)
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
@pytest.mark.parametrize(
    "operation",
    [
        AggregationFunction.NANMEAN,
        AggregationFunction.NANSUM,
        AggregationFunction.NANMAX,
        AggregationFunction.NANMIN,
        AggregationFunction.NANSTD,
        AggregationFunction.NANMEDIAN,
    ],
)
def test_aggregate_transformer_nan_operations(operation: AggregationFunction):
    """NaN-aware reductions skip NaNs and match the numpy nan-function of the same name."""
    msg_in = get_simple_msg()
    # Introduce some NaN values; every (time, ch) row still has finite entries.
    msg_in.data[0, 0, 0] = np.nan
    msg_in.data[5, 2, 3] = np.nan

    transformer = AggregateTransformer(
        AggregateSettings(axis="freq", operation=operation)
    )
    msg_out = transformer(msg_in)

    # No row is entirely NaN, so a NaN-aware reduction must produce finite output.
    assert not np.isnan(msg_out.data).any()

    expected = getattr(np, operation.value)(msg_in.data, axis=2)
    assert np.allclose(msg_out.data, expected)
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
@pytest.mark.parametrize(
    "operation", [AggregationFunction.ARGMIN, AggregationFunction.ARGMAX]
)
def test_aggregate_transformer_argminmax(operation: AggregationFunction):
    """argmin/argmax over 'freq' return per-row indices that match numpy exactly."""
    msg_in = get_simple_msg()
    settings = AggregateSettings(axis="freq", operation=operation)
    msg_out = AggregateTransformer(settings)(msg_in)

    # The freq dim is collapsed: (time, ch, freq) -> (time, ch).
    assert "freq" not in msg_out.dims
    assert msg_out.data.shape == (10, 5)

    np_arg = getattr(np, operation.value)
    assert np.array_equal(msg_out.data, np_arg(msg_in.data, axis=2))
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
def test_aggregate_transformer_trapezoid():
    """TRAPEZOID integrates over 'freq' using the LinearAxis coordinates as x."""
    msg_in = get_simple_msg(n_times=5, n_chans=3, n_freqs=10)
    settings = AggregateSettings(axis="freq", operation=AggregationFunction.TRAPEZOID)
    msg_out = AggregateTransformer(settings)(msg_in)

    assert msg_out.data.shape == (5, 3)
    assert "freq" not in msg_out.dims

    # Reference integral: x is each bin's freq value from the axis (gain 2.0, offset 1.0).
    n_bins = msg_in.data.shape[2]
    freq_coords = msg_in.axes["freq"].value(np.arange(n_bins))
    reference = np.trapezoid(msg_in.data, x=freq_coords, axis=2)
    assert np.allclose(msg_out.data, reference)
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
def test_aggregate_transformer_trapezoid_coordinate_axis():
    """TRAPEZOID uses the explicit, non-uniform values of a CoordinateAxis as x."""
    shape = (5, 3, 10)
    # Irregularly spaced bin positions (successive gaps 1, 2, ..., 9).
    freq_values = np.array([1.0, 2.0, 4.0, 7.0, 11.0, 16.0, 22.0, 29.0, 37.0, 46.0])
    axes = frozendict(
        {
            "time": AxisArray.TimeAxis(fs=100.0, offset=0.0),
            "freq": AxisArray.CoordinateAxis(
                data=freq_values, dims=["freq"], unit="Hz"
            ),
        }
    )
    msg_in = AxisArray(
        data=np.arange(shape[0] * shape[1] * shape[2], dtype=float).reshape(shape),
        dims=["time", "ch", "freq"],
        axes=axes,
    )

    msg_out = AggregateTransformer(
        AggregateSettings(axis="freq", operation=AggregationFunction.TRAPEZOID)
    )(msg_in)

    reference = np.trapezoid(msg_in.data, x=freq_values, axis=2)
    assert np.allclose(msg_out.data, reference)
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
def test_aggregate_transformer_preserves_other_axes():
    """Collapsing 'freq' keeps the 'time' and 'ch' axis metadata intact."""
    msg_in = get_simple_msg()
    msg_out = AggregateTransformer(
        AggregateSettings(axis="freq", operation=AggregationFunction.MEAN)
    )(msg_in)

    for name in ("time", "ch"):
        assert name in msg_out.axes

    assert msg_out.axes["time"] == msg_in.axes["time"]
    assert np.array_equal(msg_out.axes["ch"].data, msg_in.axes["ch"].data)
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
def test_aggregate_transformer_multiple_calls():
    """A single transformer instance stays correct across repeated calls."""
    proc = AggregateTransformer(
        AggregateSettings(axis="freq", operation=AggregationFunction.SUM)
    )

    for shift in (0, 1000, 2000):
        msg_in = get_simple_msg()
        msg_in.data = msg_in.data + shift  # distinct data on every call

        msg_out = proc(msg_in)

        # axis=2 is the aggregated dimension (assumed to be freq in get_simple_msg()).
        assert np.allclose(msg_out.data, np.sum(msg_in.data, axis=2))
|
|
@@ -1,161 +0,0 @@
|
|
|
1
|
-
from functools import partial
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
|
-
import pytest
|
|
5
|
-
from frozendict import frozendict
|
|
6
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
|
-
|
|
8
|
-
from ezmsg.sigproc.aggregate import ranged_aggregate, AggregationFunction
|
|
9
|
-
|
|
10
|
-
from tests.helpers.util import assert_messages_equal
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
def get_msg_gen(n_chans=20, n_freqs=100, data_dur=30.0, fs=1024.0, key=""):
    """Return a generator of consecutive (time, ch, freq) AxisArray chunks.

    The full recording is an integer ramp (``np.arange``) of
    ``int(data_dur * fs)`` samples. It is cut with ``np.array_split`` into
    ``n_samples // int(data_dur / 2)`` pieces, so each piece holds roughly
    ``int(data_dur / 2)`` samples. Each chunk's time-axis offset continues
    where the previous chunk ended.
    """
    n_samples = int(data_dur * fs)
    ramp = np.arange(n_samples * n_chans * n_freqs).reshape(
        n_samples, n_chans, n_freqs
    )
    approx_chunk_len = int(data_dur / 2)

    def _chunks():
        t_start = 0
        # The split count is computed lazily, on the first next(), as before.
        for piece in np.array_split(ramp, n_samples // approx_chunk_len):
            axes = frozendict(
                {
                    "time": AxisArray.TimeAxis(fs=fs, offset=t_start),
                    "freq": AxisArray.LinearAxis(gain=1.0, offset=0.0, unit="Hz"),
                }
            )
            chunk_msg = AxisArray(
                data=piece, dims=["time", "ch", "freq"], axes=axes, key=key
            )
            t_start += piece.shape[0] / fs
            yield chunk_msg

    return _chunks()
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
@pytest.mark.parametrize(
    "agg_func",
    [
        AggregationFunction.MEAN,
        AggregationFunction.MEDIAN,
        AggregationFunction.STD,
        AggregationFunction.SUM,
    ],
)
def test_aggregate(agg_func: AggregationFunction):
    """Compare ranged_aggregate against a numpy reference for reducing operations.

    Also checks that the inputs are not modified, that every output is an
    AxisArray, and that the output band axis holds the band centres with the
    input axis unit.
    """
    import copy

    bands = [(5.0, 20.0), (30.0, 50.0)]
    axis_name = "freq"

    in_msgs = list(get_msg_gen())
    # Deep-copy snapshot so we can detect in-place modification of the inputs.
    snapshots = [copy.deepcopy(m) for m in in_msgs]

    gen = ranged_aggregate(axis=axis_name, bands=bands, operation=agg_func)
    out_msgs = [gen.send(m) for m in in_msgs]

    assert_messages_equal(in_msgs, snapshots)
    assert all(type(m) is AxisArray for m in out_msgs)

    # Output axis: one entry per band, located at the band centre.
    band_centres = np.array([np.mean(b) for b in bands])
    in_unit = in_msgs[0].axes[axis_name].unit
    for out_msg in out_msgs:
        out_ax = out_msg.axes[axis_name]
        assert np.array_equal(out_ax.data, band_centres)
        assert out_ax.unit == in_unit

    # Output data: reduce each (inclusive) band of the concatenated input.
    all_in = AxisArray.concatenate(*in_msgs, dim="time").data
    in_ax = in_msgs[0].axes[axis_name]
    ax_vals = in_ax.value(np.arange(all_in.shape[-1]))
    reducers = {
        AggregationFunction.MEAN: np.mean,
        AggregationFunction.MEDIAN: np.median,
        AggregationFunction.STD: np.std,
        AggregationFunction.SUM: np.sum,
    }
    reduce_band = partial(reducers[agg_func], axis=-1, keepdims=True)
    expected_data = np.concatenate(
        [
            reduce_band(all_in[..., np.logical_and(ax_vals >= lo, ax_vals <= hi)])
            for lo, hi in bands
        ],
        axis=-1,
    )
    received_data = AxisArray.concatenate(*out_msgs, dim="time").data
    assert np.allclose(received_data, expected_data)
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
@pytest.mark.parametrize(
    "agg_func", [AggregationFunction.ARGMIN, AggregationFunction.ARGMAX]
)
def test_arg_aggregate(agg_func: AggregationFunction):
    """ARGMIN/ARGMAX report the axis value at each band's extreme.

    The generated data increases along the freq axis, so the expected value is
    the band's lower edge for ARGMIN and its upper edge for ARGMAX.
    """
    bands = [(5.0, 20.0), (30.0, 50.0)]
    in_msgs = list(get_msg_gen())
    gen = ranged_aggregate(axis="freq", bands=bands, operation=agg_func)
    out_msgs = [gen.send(m) for m in in_msgs]

    edge = np.min if agg_func == AggregationFunction.ARGMIN else np.max
    expected_vals = np.array([edge(band) for band in bands])

    out_dat = AxisArray.concatenate(*out_msgs, dim="time").data
    # Broadcast the per-band values across every (time, ch) position.
    expected_dat = np.zeros(out_dat.shape[:-1] + (1,)) + expected_vals[None, None, :]
    assert np.array_equal(out_dat, expected_dat)
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
def test_trapezoid():
    """TRAPEZOID integrates each (inclusive) band over the axis coordinate values."""
    bands = [(5.0, 20.0), (30.0, 50.0)]
    in_msgs = list(get_msg_gen())
    gen = ranged_aggregate(
        axis="freq", bands=bands, operation=AggregationFunction.TRAPEZOID
    )
    out_msgs = [gen.send(m) for m in in_msgs]
    out_dat = AxisArray.concatenate(*out_msgs, dim="time").data

    # Reference: trapezoidal integration of the concatenated input per band.
    in_data = AxisArray.concatenate(*in_msgs, dim="time").data
    freq_ax = in_msgs[0].axes["freq"]
    freqs = freq_ax.value(np.arange(in_data.shape[-1]))

    def _integrate_band(lo, hi):
        mask = np.logical_and(freqs >= lo, freqs <= hi)
        return np.trapezoid(in_data[..., mask], x=freqs[mask], axis=-1)

    expected = np.stack([_integrate_band(lo, hi) for lo, hi in bands], axis=-1)

    assert out_dat.shape == expected.shape
    assert np.allclose(out_dat, expected)
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
@pytest.mark.parametrize("change_ax", ["ch", "freq"])
def test_aggregate_handle_change(change_ax: str):
    """ranged_aggregate must keep producing correct output when the input shape changes.

    A first stream (20 channels, 100 freqs) is followed by a second stream in
    which either the channel count (``change_ax == "ch"``) or the frequency count
    (``change_ax == "freq"``) differs. The outputs of both streams are checked
    for shape and for values against a numpy reference (inclusive band edges,
    matching test_aggregate).
    """
    bands = [(5.0, 20.0), (30.0, 50.0)]
    n_chans2 = 17 if change_ax == "ch" else 20
    n_freqs2 = 70 if change_ax == "freq" else 100

    in_msgs1 = list(get_msg_gen(n_chans=20, n_freqs=100))
    in_msgs2 = list(get_msg_gen(n_chans=n_chans2, n_freqs=n_freqs2))

    gen = ranged_aggregate(
        axis="freq",
        bands=bands,
        operation=AggregationFunction.MEAN,
    )

    out_msgs1 = [gen.send(m) for m in in_msgs1]
    out_msgs2 = [gen.send(m) for m in in_msgs2]

    # Previously the outputs were only printed, so wrong results could not fail
    # the test. Verify every output message against a reference instead.
    for in_msgs, out_msgs, n_chans in (
        (in_msgs1, out_msgs1, 20),
        (in_msgs2, out_msgs2, n_chans2),
    ):
        for msg_in, msg_out in zip(in_msgs, out_msgs):
            # One output column per band; time and ch sizes pass through.
            assert msg_out.data.shape == (msg_in.data.shape[0], n_chans, len(bands))

            freq_ax = msg_in.axes["freq"]
            freqs = freq_ax.value(np.arange(msg_in.data.shape[-1]))
            expected = np.stack(
                [
                    np.mean(
                        msg_in.data[..., np.logical_and(freqs >= lo, freqs <= hi)],
                        axis=-1,
                    )
                    for lo, hi in bands
                ],
                axis=-1,
            )
            assert np.allclose(msg_out.data, expected)
|
|
File without changes
|
{ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/.github/workflows/python-publish-ezmsg-sigproc.yml
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/how-tos/signalprocessing/adaptive.rst
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/how-tos/signalprocessing/stateful.rst
RENAMED
|
File without changes
|
{ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/how-tos/signalprocessing/unit.rst
RENAMED
|
File without changes
|
{ezmsg_sigproc-2.3.0/docs → ezmsg_sigproc-2.4.0/docs/source/guides}/img/HybridBufferBasic.svg
RENAMED
|
File without changes
|
{ezmsg_sigproc-2.3.0/docs → ezmsg_sigproc-2.4.0/docs/source/guides}/img/HybridBufferOverflow.svg
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/tutorials/signalprocessing.rst
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/integration/bytewax/test_spectrum_bytewax.py
RENAMED
|
File without changes
|
{ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/integration/bytewax/test_window_bytewax.py
RENAMED
|
File without changes
|
{ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/integration/ezmsg/test_butterworth_system.py
RENAMED
|
File without changes
|
|
File without changes
|
{ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/integration/ezmsg/test_downsample_system.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|