ezmsg-sigproc 2.3.0__tar.gz → 2.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (142)
  1. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/PKG-INFO +1 -1
  2. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/__version__.py +2 -2
  3. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/aggregate.py +69 -0
  4. ezmsg_sigproc-2.4.0/tests/unit/test_aggregate.py +411 -0
  5. ezmsg_sigproc-2.3.0/tests/unit/test_aggregate.py +0 -161
  6. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/.github/workflows/docs.yml +0 -0
  7. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/.github/workflows/python-publish-ezmsg-sigproc.yml +0 -0
  8. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/.github/workflows/python-tests.yml +0 -0
  9. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/.gitignore +0 -0
  10. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/.pre-commit-config.yaml +0 -0
  11. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/LICENSE.txt +0 -0
  12. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/README.md +0 -0
  13. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/Makefile +0 -0
  14. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/make.bat +0 -0
  15. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/_templates/autosummary/module.rst +0 -0
  16. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/api/index.rst +0 -0
  17. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/conf.py +0 -0
  18. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/HybridBuffer.md +0 -0
  19. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/ProcessorsBase.md +0 -0
  20. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/explanations/sigproc.rst +0 -0
  21. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/how-tos/signalprocessing/adaptive.rst +0 -0
  22. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/how-tos/signalprocessing/checkpoint.rst +0 -0
  23. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/how-tos/signalprocessing/composite.rst +0 -0
  24. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/how-tos/signalprocessing/content-signalprocessing.rst +0 -0
  25. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/how-tos/signalprocessing/processor.rst +0 -0
  26. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/how-tos/signalprocessing/standalone.rst +0 -0
  27. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/how-tos/signalprocessing/stateful.rst +0 -0
  28. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/how-tos/signalprocessing/unit.rst +0 -0
  29. {ezmsg_sigproc-2.3.0/docs → ezmsg_sigproc-2.4.0/docs/source/guides}/img/HybridBufferBasic.svg +0 -0
  30. {ezmsg_sigproc-2.3.0/docs → ezmsg_sigproc-2.4.0/docs/source/guides}/img/HybridBufferOverflow.svg +0 -0
  31. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/sigproc/base.rst +0 -0
  32. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/sigproc/content-sigproc.rst +0 -0
  33. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/sigproc/processors.rst +0 -0
  34. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/sigproc/units.rst +0 -0
  35. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/guides/tutorials/signalprocessing.rst +0 -0
  36. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/docs/source/index.rst +0 -0
  37. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/pyproject.toml +0 -0
  38. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/__init__.py +0 -0
  39. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/activation.py +0 -0
  40. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/adaptive_lattice_notch.py +0 -0
  41. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/affinetransform.py +0 -0
  42. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/bandpower.py +0 -0
  43. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/base.py +0 -0
  44. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/butterworthfilter.py +0 -0
  45. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/cheby.py +0 -0
  46. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/combfilter.py +0 -0
  47. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/decimate.py +0 -0
  48. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/denormalize.py +0 -0
  49. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/detrend.py +0 -0
  50. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/diff.py +0 -0
  51. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/downsample.py +0 -0
  52. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/ewma.py +0 -0
  53. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/ewmfilter.py +0 -0
  54. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/extract_axis.py +0 -0
  55. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/fbcca.py +0 -0
  56. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/filter.py +0 -0
  57. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/filterbank.py +0 -0
  58. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/filterbankdesign.py +0 -0
  59. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/firfilter.py +0 -0
  60. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/gaussiansmoothing.py +0 -0
  61. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/kaiser.py +0 -0
  62. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/math/__init__.py +0 -0
  63. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/math/abs.py +0 -0
  64. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/math/clip.py +0 -0
  65. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/math/difference.py +0 -0
  66. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/math/invert.py +0 -0
  67. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/math/log.py +0 -0
  68. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/math/scale.py +0 -0
  69. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/messages.py +0 -0
  70. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/quantize.py +0 -0
  71. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/resample.py +0 -0
  72. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/sampler.py +0 -0
  73. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/scaler.py +0 -0
  74. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/signalinjector.py +0 -0
  75. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/slicer.py +0 -0
  76. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/spectral.py +0 -0
  77. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/spectrogram.py +0 -0
  78. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/spectrum.py +0 -0
  79. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/synth.py +0 -0
  80. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/transpose.py +0 -0
  81. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/util/__init__.py +0 -0
  82. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/util/asio.py +0 -0
  83. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/util/axisarray_buffer.py +0 -0
  84. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/util/buffer.py +0 -0
  85. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/util/message.py +0 -0
  86. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/util/profile.py +0 -0
  87. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/util/sparse.py +0 -0
  88. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/util/typeresolution.py +0 -0
  89. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/wavelets.py +0 -0
  90. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/src/ezmsg/sigproc/window.py +0 -0
  91. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/__init__.py +0 -0
  92. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/conftest.py +0 -0
  93. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/helpers/__init__.py +0 -0
  94. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/helpers/util.py +0 -0
  95. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/integration/bytewax/test_spectrum_bytewax.py +0 -0
  96. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/integration/bytewax/test_window_bytewax.py +0 -0
  97. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/integration/ezmsg/test_butterworth_system.py +0 -0
  98. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/integration/ezmsg/test_decimate_system.py +0 -0
  99. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/integration/ezmsg/test_downsample_system.py +0 -0
  100. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/integration/ezmsg/test_filter_system.py +0 -0
  101. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/integration/ezmsg/test_sampler_system.py +0 -0
  102. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/integration/ezmsg/test_scaler_system.py +0 -0
  103. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/integration/ezmsg/test_spectrum_system.py +0 -0
  104. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/integration/ezmsg/test_synth_system.py +0 -0
  105. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/integration/ezmsg/test_window_system.py +0 -0
  106. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/resources/xform.csv +0 -0
  107. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/test_profile.py +0 -0
  108. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/buffer/test_axisarray_buffer.py +0 -0
  109. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/buffer/test_buffer.py +0 -0
  110. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/buffer/test_buffer_overflow.py +0 -0
  111. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_activation.py +0 -0
  112. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_adaptive_lattice_notch.py +0 -0
  113. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_affine_transform.py +0 -0
  114. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_bandpower.py +0 -0
  115. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_base.py +0 -0
  116. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_butter.py +0 -0
  117. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_combfilter.py +0 -0
  118. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_denormalize.py +0 -0
  119. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_diff.py +0 -0
  120. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_downsample.py +0 -0
  121. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_ewma.py +0 -0
  122. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_extract_axis.py +0 -0
  123. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_fbcca.py +0 -0
  124. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_filter.py +0 -0
  125. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_filterbank.py +0 -0
  126. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_filterbankdesign.py +0 -0
  127. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_firfilter.py +0 -0
  128. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_gaussian_smoothing_filter.py +0 -0
  129. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_kaiser.py +0 -0
  130. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_math.py +0 -0
  131. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_quantize.py +0 -0
  132. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_resample.py +0 -0
  133. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_sampler.py +0 -0
  134. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_scaler.py +0 -0
  135. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_slicer.py +0 -0
  136. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_spectrogram.py +0 -0
  137. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_spectrum.py +0 -0
  138. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_synth.py +0 -0
  139. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_transpose.py +0 -0
  140. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_util.py +0 -0
  141. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_wavelets.py +0 -0
  142. {ezmsg_sigproc-2.3.0 → ezmsg_sigproc-2.4.0}/tests/unit/test_window.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ezmsg-sigproc
3
- Version: 2.3.0
3
+ Version: 2.4.0
4
4
  Summary: Timeseries signal processing implementations in ezmsg
5
5
  Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>
6
6
  License-Expression: MIT
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '2.3.0'
32
- __version_tuple__ = version_tuple = (2, 3, 0)
31
+ __version__ = version = '2.4.0'
32
+ __version_tuple__ = version_tuple = (2, 4, 0)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -1,3 +1,4 @@
1
+ from array_api_compat import get_namespace
1
2
  import typing
2
3
 
3
4
  import numpy as np
@@ -12,6 +13,7 @@ from ezmsg.util.messages.axisarray import (
12
13
 
13
14
  from .spectral import OptionsEnum
14
15
  from .base import (
16
+ BaseTransformer,
15
17
  BaseStatefulTransformer,
16
18
  BaseTransformerUnit,
17
19
  processor_state,
@@ -213,3 +215,70 @@ def ranged_aggregate(
213
215
  return RangedAggregateTransformer(
214
216
  RangedAggregateSettings(axis=axis, bands=bands, operation=operation)
215
217
  )
218
+
219
+
220
class AggregateSettings(ez.Settings):
    """
    Settings for :obj:`AggregateTransformer` and :obj:`AggregateUnit`.

    Selects the axis to collapse and the reduction used to collapse it.
    """

    axis: str
    """The name of the axis to aggregate over. This axis will be removed from the output."""

    operation: AggregationFunction = AggregationFunction.MEAN
    """:obj:`AggregationFunction` to apply across the whole axis. ``NONE`` is rejected at process time."""
228
+
229
+
230
class AggregateTransformer(BaseTransformer[AggregateSettings, AxisArray, AxisArray]):
    """
    Reduce one named axis of an :obj:`AxisArray` in its entirety.

    :obj:`RangedAggregateTransformer` reduces within each configured band and keeps the
    axis (one entry per band). This transformer instead reduces across the whole axis
    and drops it, so every output message has one dimension fewer than its input.

    Raises:
        ValueError: when the configured operation is ``AggregationFunction.NONE``.
    """

    def _process(self, message: AxisArray) -> AxisArray:
        axis_name = self.settings.axis
        namespace = get_namespace(message.data)
        dim_index = message.get_axis_idx(axis_name)
        operation = self.settings.operation

        if operation == AggregationFunction.NONE:
            raise ValueError(
                "AggregationFunction.NONE is not supported for full-axis aggregation"
            )

        if operation == AggregationFunction.TRAPEZOID:
            # Integration needs sample positions: use the axis' explicit coordinates
            # when it carries them, otherwise map sample indices through axis.value().
            axis_info = message.get_axis(axis_name)
            if hasattr(axis_info, "data"):
                positions = axis_info.data
            else:
                positions = axis_info.value(np.arange(message.data.shape[dim_index]))
            reduced = np.trapezoid(np.asarray(message.data), x=positions, axis=dim_index)
        elif hasattr(namespace, operation.value):
            # The data's array namespace offers a same-named reduction; use it.
            reduced = getattr(namespace, operation.value)(message.data, axis=dim_index)
        else:
            # No such function in the namespace: fall back to the AGGREGATORS table.
            reduced = AGGREGATORS[operation](message.data, axis=dim_index)

        # Drop the reduced dimension and its axis metadata; everything else is kept.
        remaining_dims = [dim for i, dim in enumerate(message.dims) if i != dim_index]
        remaining_axes = {
            name: ax for name, ax in message.axes.items() if name != axis_name
        }

        return replace(
            message,
            data=reduced,
            dims=remaining_dims,
            axes=remaining_axes,
        )
277
+
278
+
279
class AggregateUnit(
    BaseTransformerUnit[AggregateSettings, AxisArray, AxisArray, AggregateTransformer]
):
    """
    ezmsg Unit form of :obj:`AggregateTransformer`.

    Collapses one entire axis of each incoming message with the configured operation.
    """

    SETTINGS = AggregateSettings
@@ -0,0 +1,411 @@
1
+ import copy
2
+ from functools import partial
3
+
4
+ import numpy as np
5
+ import pytest
6
+ from frozendict import frozendict
7
+ from ezmsg.util.messages.axisarray import AxisArray
8
+
9
+ from ezmsg.sigproc.aggregate import (
10
+ ranged_aggregate,
11
+ AggregationFunction,
12
+ AggregateTransformer,
13
+ AggregateSettings,
14
+ )
15
+
16
+ from tests.helpers.util import assert_messages_equal
17
+
18
+
19
def get_msg_gen(n_chans=20, n_freqs=100, data_dur=30.0, fs=1024.0, key=""):
    """
    Return a generator of consecutive (time, ch, freq) AxisArray chunks.

    The full dataset is an integer ramp with ``int(data_dur * fs)`` time samples. It is
    split along time into ``n_samples // int(data_dur / 2)`` nearly equal chunks. Each
    chunk gets a TimeAxis whose offset continues from the end of the previous chunk.
    """
    total_samples = int(data_dur * fs)
    ramp = np.arange(total_samples * n_chans * n_freqs).reshape(
        total_samples, n_chans, n_freqs
    )
    # Historically named ``n_msgs``; it is the divisor that fixes the chunk count.
    divisor = int(data_dur / 2)

    def msg_generator():
        t_offset = 0
        for chunk in np.array_split(ramp, total_samples // divisor):
            axes = frozendict(
                {
                    "time": AxisArray.TimeAxis(fs=fs, offset=t_offset),
                    "freq": AxisArray.LinearAxis(gain=1.0, offset=0.0, unit="Hz"),
                }
            )
            msg = AxisArray(data=chunk, dims=["time", "ch", "freq"], axes=axes, key=key)
            t_offset += chunk.shape[0] / fs
            yield msg

    return msg_generator()
42
+
43
+
44
@pytest.mark.parametrize(
    "agg_func",
    [
        AggregationFunction.MEAN,
        AggregationFunction.MEDIAN,
        AggregationFunction.STD,
        AggregationFunction.SUM,
    ],
)
def test_aggregate(agg_func: AggregationFunction):
    """ranged_aggregate reduces each band, leaves inputs untouched, and labels bands by their centres."""
    bands = [(5.0, 20.0), (30.0, 50.0)]
    axis_name = "freq"

    in_msgs = list(get_msg_gen())

    # Deepcopy the inputs so we can verify that processing did not mutate them.
    # (``copy`` is imported at module level.)
    backup = [copy.deepcopy(msg) for msg in in_msgs]

    gen = ranged_aggregate(axis=axis_name, bands=bands, operation=agg_func)
    out_msgs = [gen.send(msg) for msg in in_msgs]

    assert_messages_equal(in_msgs, backup)

    assert all(type(msg) is AxisArray for msg in out_msgs)

    # Output axis: one entry per band at the band centre, unit carried over.
    band_centres = np.array([np.mean(band) for band in bands])
    in_unit = in_msgs[0].axes[axis_name].unit
    for out_msg in out_msgs:
        ax = out_msg.axes[axis_name]
        assert np.array_equal(ax.data, band_centres)
        assert ax.unit == in_unit

    # Data: compare against a direct NumPy reduction over each band.
    data = AxisArray.concatenate(*in_msgs, dim="time").data
    in_axis = in_msgs[0].axes[axis_name]
    axis_vec = in_axis.value(np.arange(data.shape[-1]))
    reference_func = {
        AggregationFunction.MEAN: partial(np.mean, axis=-1, keepdims=True),
        AggregationFunction.MEDIAN: partial(np.median, axis=-1, keepdims=True),
        AggregationFunction.STD: partial(np.std, axis=-1, keepdims=True),
        AggregationFunction.SUM: partial(np.sum, axis=-1, keepdims=True),
    }[agg_func]
    expected_data = np.concatenate(
        [
            reference_func(
                data[..., np.logical_and(axis_vec >= start, axis_vec <= stop)]
            )
            for (start, stop) in bands
        ],
        axis=-1,
    )
    received_data = AxisArray.concatenate(*out_msgs, dim="time").data
    assert np.allclose(received_data, expected_data)
99
+
100
+
101
@pytest.mark.parametrize(
    "agg_func", [AggregationFunction.ARGMIN, AggregationFunction.ARGMAX]
)
def test_arg_aggregate(agg_func: AggregationFunction):
    """
    argmin/argmax through ranged_aggregate return axis coordinates, not indices.

    The input is a monotonically increasing ramp, so the arg-extreme of each band is
    that band's lower (ARGMIN) or upper (ARGMAX) edge.
    """
    bands = [(5.0, 20.0), (30.0, 50.0)]
    inputs = list(get_msg_gen())
    proc = ranged_aggregate(axis="freq", bands=bands, operation=agg_func)
    outputs = [proc.send(msg) for msg in inputs]

    edge = np.min if agg_func == AggregationFunction.ARGMIN else np.max
    band_edges = np.array([edge(band) for band in bands])

    out_dat = AxisArray.concatenate(*outputs, dim="time").data
    expected_dat = np.zeros(out_dat.shape[:-1] + (1,)) + band_edges[None, None, :]
    assert np.array_equal(out_dat, expected_dat)
117
+
118
+
119
def test_trapezoid():
    """Band-wise trapezoid integration via ranged_aggregate matches np.trapezoid on each band."""
    bands = [(5.0, 20.0), (30.0, 50.0)]
    inputs = list(get_msg_gen())
    proc = ranged_aggregate(
        axis="freq", bands=bands, operation=AggregationFunction.TRAPEZOID
    )
    outputs = [proc.send(msg) for msg in inputs]

    out_dat = AxisArray.concatenate(*outputs, dim="time").data

    # Reference: integrate each band directly over its axis coordinates.
    stacked = AxisArray.concatenate(*inputs, dim="time").data
    coords = inputs[0].axes["freq"].value(np.arange(stacked.shape[-1]))
    masks = [np.logical_and(coords >= lo, coords <= hi) for lo, hi in bands]
    expected = np.stack(
        [np.trapezoid(stacked[..., m], x=coords[m], axis=-1) for m in masks],
        axis=-1,
    )

    assert out_dat.shape == expected.shape
    assert np.allclose(out_dat, expected)
141
+
142
+
143
@pytest.mark.parametrize("change_ax", ["ch", "freq"])
def test_aggregate_handle_change(change_ax: str):
    """
    ranged_aggregate must keep producing correct output when the input shape changes.

    A change in the non-aggregated 'ch' axis is the easy case. A change in the
    aggregated 'freq' axis only works if the transformer re-derives its band
    selection for the new input. Both cases are checked with real assertions
    (shape, output axis and values) rather than just running without error.
    """
    bands = [(5.0, 20.0), (30.0, 50.0)]
    n_chans2 = 17 if change_ax == "ch" else 20
    n_freqs2 = 70 if change_ax == "freq" else 100

    in_msgs1 = list(get_msg_gen(n_chans=20, n_freqs=100))
    in_msgs2 = list(get_msg_gen(n_chans=n_chans2, n_freqs=n_freqs2))

    gen = ranged_aggregate(
        axis="freq",
        bands=bands,
        operation=AggregationFunction.MEAN,
    )

    out_msgs1 = [gen.send(msg) for msg in in_msgs1]
    out_msgs2 = [gen.send(msg) for msg in in_msgs2]

    # Output layout is (time, ch, band) before and after the change.
    assert all(msg.data.shape[1:] == (20, len(bands)) for msg in out_msgs1)
    assert all(msg.data.shape[1:] == (n_chans2, len(bands)) for msg in out_msgs2)

    band_centres = np.array([np.mean(band) for band in bands])
    assert all(np.array_equal(msg.axes["freq"].data, band_centres) for msg in out_msgs2)

    # Values right after the change (and at the end) match a fresh band-mean computed
    # on the new layout. Checking individual messages avoids concatenating the whole
    # (large) stream again.
    freq_vec = in_msgs2[0].axes["freq"].value(np.arange(n_freqs2))
    masks = [np.logical_and(freq_vec >= lo, freq_vec <= hi) for lo, hi in bands]
    for msg_in, msg_out in ((in_msgs2[0], out_msgs2[0]), (in_msgs2[-1], out_msgs2[-1])):
        expected = np.concatenate(
            [np.mean(msg_in.data[..., m], axis=-1, keepdims=True) for m in masks],
            axis=-1,
        )
        assert msg_out.data.shape == expected.shape
        assert np.allclose(msg_out.data, expected)
168
+
169
+
170
+ # ============== Tests for AggregateTransformer ==============
171
+
172
+
173
def get_simple_msg(n_times=10, n_chans=5, n_freqs=8, fs=100.0):
    """
    Build a small, deterministic (time, ch, freq) AxisArray for AggregateTransformer tests.

    The data is a float ramp ``0 .. n_times*n_chans*n_freqs - 1``. The 'ch' axis
    carries string labels ``ch0, ch1, ...``. The 'freq' axis is a LinearAxis with
    gain 2.0, offset 1.0 and unit "Hz".
    """
    ramp = np.arange(n_times * n_chans * n_freqs, dtype=float).reshape(
        n_times, n_chans, n_freqs
    )
    channel_labels = np.array([f"ch{i}" for i in range(n_chans)])
    axes = frozendict(
        {
            "time": AxisArray.TimeAxis(fs=fs, offset=0.0),
            "ch": AxisArray.CoordinateAxis(data=channel_labels, dims=["ch"]),
            "freq": AxisArray.LinearAxis(gain=2.0, offset=1.0, unit="Hz"),
        }
    )
    return AxisArray(data=ramp, dims=["time", "ch", "freq"], axes=axes)
192
+
193
+
194
@pytest.mark.parametrize(
    "operation",
    [
        AggregationFunction.MEAN,
        AggregationFunction.SUM,
        AggregationFunction.MAX,
        AggregationFunction.MIN,
        AggregationFunction.STD,
        AggregationFunction.MEDIAN,
    ],
)
def test_aggregate_transformer_basic(operation: AggregationFunction):
    """Each plain reduction drops 'freq' and matches the NumPy function of the same name."""
    msg_in = get_simple_msg()
    pristine = copy.deepcopy(msg_in)

    settings = AggregateSettings(axis="freq", operation=operation)
    msg_out = AggregateTransformer(settings)(msg_in)

    # The input message must not be modified in place.
    assert_messages_equal([msg_in], [pristine])

    assert isinstance(msg_out, AxisArray)

    # 'freq' is gone from dims and axes; the remaining order is preserved.
    assert "freq" not in msg_out.dims
    assert "freq" not in msg_out.axes
    assert msg_out.dims == ["time", "ch"]
    assert msg_out.data.shape == (10, 5)

    reference = getattr(np, operation.value)(msg_in.data, axis=2)
    assert np.allclose(msg_out.data, reference)
233
+
234
+
235
@pytest.mark.parametrize("axis", ["time", "ch", "freq"])
def test_aggregate_transformer_different_axes(axis: str):
    """Any named axis can be collapsed; the others keep their order and sizes."""
    msg_in = get_simple_msg(n_times=10, n_chans=5, n_freqs=8)

    settings = AggregateSettings(axis=axis, operation=AggregationFunction.MEAN)
    msg_out = AggregateTransformer(settings)(msg_in)

    assert axis not in msg_out.dims
    assert axis not in msg_out.axes
    assert msg_out.dims == [d for d in ("time", "ch", "freq") if d != axis]

    idx = msg_in.get_axis_idx(axis)
    in_shape = msg_in.data.shape
    assert msg_out.data.shape == in_shape[:idx] + in_shape[idx + 1 :]

    assert np.allclose(msg_out.data, msg_in.data.mean(axis=idx))
262
+
263
+
264
def test_aggregate_transformer_none_raises():
    """AggregationFunction.NONE is rejected with a ValueError when a message is processed."""
    msg_in = get_simple_msg()
    settings = AggregateSettings(axis="freq", operation=AggregationFunction.NONE)
    transformer = AggregateTransformer(settings)

    with pytest.raises(ValueError, match="NONE is not supported"):
        transformer(msg_in)
274
+
275
+
276
@pytest.mark.parametrize(
    "operation",
    [
        AggregationFunction.NANMEAN,
        AggregationFunction.NANSUM,
        AggregationFunction.NANMAX,
        AggregationFunction.NANMIN,
        AggregationFunction.NANSTD,
        AggregationFunction.NANMEDIAN,
    ],
)
def test_aggregate_transformer_nan_operations(operation: AggregationFunction):
    """NaN-aware reductions ignore injected NaNs and match NumPy's nan-functions."""
    msg_in = get_simple_msg()

    # Inject one NaN into each of two different (time, ch) rows along 'freq'.
    nan_locations = [(0, 0, 0), (5, 2, 3)]
    for loc in nan_locations:
        msg_in.data[loc] = np.nan
    # Guard the premise: the NaNs really are present in the input.
    assert np.isnan(msg_in.data).sum() == len(nan_locations)

    transformer = AggregateTransformer(
        AggregateSettings(axis="freq", operation=operation)
    )
    msg_out = transformer(msg_in)

    assert msg_out.data.shape == (10, 5)

    # Every reduced row still has finite entries, so a NaN-aware reduction must not
    # emit NaN anywhere. This makes the intent explicit instead of implicit in allclose.
    assert np.isfinite(msg_out.data).all()

    expected = getattr(np, operation.value)(msg_in.data, axis=2)
    assert np.allclose(msg_out.data, expected, equal_nan=True)
303
+
304
+
305
@pytest.mark.parametrize(
    "operation", [AggregationFunction.ARGMIN, AggregationFunction.ARGMAX]
)
def test_aggregate_transformer_argminmax(operation: AggregationFunction):
    """argmin/argmax drop 'freq' and yield the same positional indices as NumPy."""
    msg_in = get_simple_msg()

    settings = AggregateSettings(axis="freq", operation=operation)
    msg_out = AggregateTransformer(settings)(msg_in)

    # The reduced axis is removed.
    assert msg_out.data.shape == (10, 5)
    assert "freq" not in msg_out.dims

    # Compared against raw indices, not axis coordinates.
    expected_indices = getattr(np, operation.value)(msg_in.data, axis=2)
    assert np.array_equal(msg_out.data, expected_indices)
325
+
326
+
327
def test_aggregate_transformer_trapezoid():
    """TRAPEZOID on a LinearAxis integrates over the axis' mapped coordinates."""
    msg_in = get_simple_msg(n_times=5, n_chans=3, n_freqs=10)

    settings = AggregateSettings(axis="freq", operation=AggregationFunction.TRAPEZOID)
    msg_out = AggregateTransformer(settings)(msg_in)

    assert msg_out.data.shape == (5, 3)
    assert "freq" not in msg_out.dims

    # Reference: integrate against the LinearAxis values at each sample index.
    freq_coords = msg_in.axes["freq"].value(np.arange(msg_in.data.shape[2]))
    reference = np.trapezoid(msg_in.data, x=freq_coords, axis=2)
    assert np.allclose(msg_out.data, reference)
346
+
347
+
348
def test_aggregate_transformer_trapezoid_coordinate_axis():
    """TRAPEZOID on a CoordinateAxis integrates over its explicit (non-uniform) coordinates."""
    shape = (5, 3, 10)
    ramp = np.arange(shape[0] * shape[1] * shape[2], dtype=float).reshape(shape)
    # Deliberately non-uniform spacing so a uniform-x integration would disagree.
    freq_values = np.array([1.0, 2.0, 4.0, 7.0, 11.0, 16.0, 22.0, 29.0, 37.0, 46.0])
    axes = frozendict(
        {
            "time": AxisArray.TimeAxis(fs=100.0, offset=0.0),
            "freq": AxisArray.CoordinateAxis(data=freq_values, dims=["freq"], unit="Hz"),
        }
    )
    msg_in = AxisArray(data=ramp, dims=["time", "ch", "freq"], axes=axes)

    settings = AggregateSettings(axis="freq", operation=AggregationFunction.TRAPEZOID)
    msg_out = AggregateTransformer(settings)(msg_in)

    reference = np.trapezoid(msg_in.data, x=freq_values, axis=2)
    assert np.allclose(msg_out.data, reference)
376
+
377
+
378
def test_aggregate_transformer_preserves_other_axes():
    """Axes that are not aggregated are carried over to the output unchanged."""
    msg_in = get_simple_msg()

    settings = AggregateSettings(axis="freq", operation=AggregationFunction.MEAN)
    msg_out = AggregateTransformer(settings)(msg_in)

    # 'time' is copied over as-is.
    assert "time" in msg_out.axes
    assert msg_out.axes["time"] == msg_in.axes["time"]

    # 'ch' keeps its coordinate labels.
    assert "ch" in msg_out.axes
    assert np.array_equal(msg_out.axes["ch"].data, msg_in.axes["ch"].data)
396
+
397
+
398
def test_aggregate_transformer_multiple_calls():
    """A single transformer instance gives correct results across repeated calls."""
    transformer = AggregateTransformer(
        AggregateSettings(axis="freq", operation=AggregationFunction.SUM)
    )

    for shift in (0, 1000, 2000):
        msg_in = get_simple_msg()
        # Shift the data so each call sees different values.
        msg_in.data = msg_in.data + shift

        msg_out = transformer(msg_in)

        assert np.allclose(msg_out.data, msg_in.data.sum(axis=2))
@@ -1,161 +0,0 @@
1
- from functools import partial
2
-
3
- import numpy as np
4
- import pytest
5
- from frozendict import frozendict
6
- from ezmsg.util.messages.axisarray import AxisArray
7
-
8
- from ezmsg.sigproc.aggregate import ranged_aggregate, AggregationFunction
9
-
10
- from tests.helpers.util import assert_messages_equal
11
-
12
-
13
def get_msg_gen(n_chans=20, n_freqs=100, data_dur=30.0, fs=1024.0, key=""):
    """
    Return a generator of consecutive (time, ch, freq) AxisArray chunks.

    The full dataset is an integer ramp with ``int(data_dur * fs)`` time samples. It is
    split along time into ``n_samples // int(data_dur / 2)`` nearly equal chunks. Each
    chunk gets a TimeAxis whose offset continues from the end of the previous chunk.
    """
    total_samples = int(data_dur * fs)
    ramp = np.arange(total_samples * n_chans * n_freqs).reshape(
        total_samples, n_chans, n_freqs
    )
    # Historically named ``n_msgs``; it is the divisor that fixes the chunk count.
    divisor = int(data_dur / 2)

    def msg_generator():
        t_offset = 0
        for chunk in np.array_split(ramp, total_samples // divisor):
            axes = frozendict(
                {
                    "time": AxisArray.TimeAxis(fs=fs, offset=t_offset),
                    "freq": AxisArray.LinearAxis(gain=1.0, offset=0.0, unit="Hz"),
                }
            )
            msg = AxisArray(data=chunk, dims=["time", "ch", "freq"], axes=axes, key=key)
            t_offset += chunk.shape[0] / fs
            yield msg

    return msg_generator()
36
-
37
-
38
@pytest.mark.parametrize(
    "agg_func",
    [
        AggregationFunction.MEAN,
        AggregationFunction.MEDIAN,
        AggregationFunction.STD,
        AggregationFunction.SUM,
    ],
)
def test_aggregate(agg_func: AggregationFunction):
    """ranged_aggregate reduces each band, leaves inputs untouched, and labels bands by their centres."""
    # This module has no top-level ``copy`` import, so import it locally.
    import copy

    bands = [(5.0, 20.0), (30.0, 50.0)]
    axis_name = "freq"

    inputs = list(get_msg_gen())
    # Snapshot the inputs so we can prove processing left them untouched.
    snapshots = [copy.deepcopy(msg) for msg in inputs]

    proc = ranged_aggregate(axis=axis_name, bands=bands, operation=agg_func)
    outputs = [proc.send(msg) for msg in inputs]

    assert_messages_equal(inputs, snapshots)
    assert all(type(msg) is AxisArray for msg in outputs)

    band_centres = np.array([np.mean(band) for band in bands])
    for msg in outputs:
        out_axis = msg.axes[axis_name]
        assert np.array_equal(out_axis.data, band_centres)
        assert out_axis.unit == inputs[0].axes[axis_name].unit

    stacked = AxisArray.concatenate(*inputs, dim="time").data
    coords = inputs[0].axes[axis_name].value(np.arange(stacked.shape[-1]))
    reducer = {
        AggregationFunction.MEAN: partial(np.mean, axis=-1, keepdims=True),
        AggregationFunction.MEDIAN: partial(np.median, axis=-1, keepdims=True),
        AggregationFunction.STD: partial(np.std, axis=-1, keepdims=True),
        AggregationFunction.SUM: partial(np.sum, axis=-1, keepdims=True),
    }[agg_func]
    expected = np.concatenate(
        [reducer(stacked[..., np.logical_and(coords >= lo, coords <= hi)]) for lo, hi in bands],
        axis=-1,
    )
    received = AxisArray.concatenate(*outputs, dim="time").data
    assert np.allclose(received, expected)
93
-
94
-
95
@pytest.mark.parametrize(
    "agg_func", [AggregationFunction.ARGMIN, AggregationFunction.ARGMAX]
)
def test_arg_aggregate(agg_func: AggregationFunction):
    """
    argmin/argmax through ranged_aggregate return axis coordinates, not indices.

    The input is a monotonically increasing ramp, so the arg-extreme of each band is
    that band's lower (ARGMIN) or upper (ARGMAX) edge.
    """
    bands = [(5.0, 20.0), (30.0, 50.0)]
    inputs = list(get_msg_gen())
    proc = ranged_aggregate(axis="freq", bands=bands, operation=agg_func)
    outputs = [proc.send(msg) for msg in inputs]

    edge = np.min if agg_func == AggregationFunction.ARGMIN else np.max
    band_edges = np.array([edge(band) for band in bands])

    out_dat = AxisArray.concatenate(*outputs, dim="time").data
    expected_dat = np.zeros(out_dat.shape[:-1] + (1,)) + band_edges[None, None, :]
    assert np.array_equal(out_dat, expected_dat)
111
-
112
-
113
def test_trapezoid():
    """Band-wise trapezoid integration via ranged_aggregate matches np.trapezoid on each band."""
    bands = [(5.0, 20.0), (30.0, 50.0)]
    inputs = list(get_msg_gen())
    proc = ranged_aggregate(
        axis="freq", bands=bands, operation=AggregationFunction.TRAPEZOID
    )
    outputs = [proc.send(msg) for msg in inputs]

    out_dat = AxisArray.concatenate(*outputs, dim="time").data

    # Reference: integrate each band directly over its axis coordinates.
    stacked = AxisArray.concatenate(*inputs, dim="time").data
    coords = inputs[0].axes["freq"].value(np.arange(stacked.shape[-1]))
    masks = [np.logical_and(coords >= lo, coords <= hi) for lo, hi in bands]
    expected = np.stack(
        [np.trapezoid(stacked[..., m], x=coords[m], axis=-1) for m in masks],
        axis=-1,
    )

    assert out_dat.shape == expected.shape
    assert np.allclose(out_dat, expected)
135
-
136
-
137
@pytest.mark.parametrize("change_ax", ["ch", "freq"])
def test_aggregate_handle_change(change_ax: str):
    """
    ranged_aggregate must keep producing correct output when the input shape changes.

    A change in the non-aggregated 'ch' axis is the easy case. A change in the
    aggregated 'freq' axis only works if the transformer re-derives its band
    selection for the new input. Both cases are checked with real assertions
    (shape, output axis and values) rather than just running without error.
    """
    bands = [(5.0, 20.0), (30.0, 50.0)]
    n_chans2 = 17 if change_ax == "ch" else 20
    n_freqs2 = 70 if change_ax == "freq" else 100

    in_msgs1 = list(get_msg_gen(n_chans=20, n_freqs=100))
    in_msgs2 = list(get_msg_gen(n_chans=n_chans2, n_freqs=n_freqs2))

    gen = ranged_aggregate(
        axis="freq",
        bands=bands,
        operation=AggregationFunction.MEAN,
    )

    out_msgs1 = [gen.send(msg) for msg in in_msgs1]
    out_msgs2 = [gen.send(msg) for msg in in_msgs2]

    # Output layout is (time, ch, band) before and after the change.
    assert all(msg.data.shape[1:] == (20, len(bands)) for msg in out_msgs1)
    assert all(msg.data.shape[1:] == (n_chans2, len(bands)) for msg in out_msgs2)

    band_centres = np.array([np.mean(band) for band in bands])
    assert all(np.array_equal(msg.axes["freq"].data, band_centres) for msg in out_msgs2)

    # Values right after the change (and at the end) match a fresh band-mean computed
    # on the new layout. Checking individual messages avoids concatenating the whole
    # (large) stream again.
    freq_vec = in_msgs2[0].axes["freq"].value(np.arange(n_freqs2))
    masks = [np.logical_and(freq_vec >= lo, freq_vec <= hi) for lo, hi in bands]
    for msg_in, msg_out in ((in_msgs2[0], out_msgs2[0]), (in_msgs2[-1], out_msgs2[-1])):
        expected = np.concatenate(
            [np.mean(msg_in.data[..., m], axis=-1, keepdims=True) for m in masks],
            axis=-1,
        )
        assert msg_out.data.shape == expected.shape
        assert np.allclose(msg_out.data, expected)
File without changes
File without changes
File without changes