ezmsg-sigproc 1.4.2__tar.gz → 1.6.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/.github/workflows/python-tests.yml +1 -1
  2. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/PKG-INFO +4 -5
  3. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/pyproject.toml +2 -2
  4. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/__version__.py +2 -2
  5. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/activation.py +2 -2
  6. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/affinetransform.py +13 -13
  7. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/aggregate.py +49 -28
  8. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/bandpower.py +2 -2
  9. ezmsg_sigproc-1.6.0/src/ezmsg/sigproc/butterworthfilter.py +160 -0
  10. ezmsg_sigproc-1.6.0/src/ezmsg/sigproc/cheby.py +119 -0
  11. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/decimate.py +11 -15
  12. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/downsample.py +8 -4
  13. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/ewmfilter.py +9 -5
  14. ezmsg_sigproc-1.6.0/src/ezmsg/sigproc/filter.py +199 -0
  15. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/filterbank.py +5 -5
  16. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/math/abs.py +1 -1
  17. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/math/clip.py +1 -1
  18. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/math/difference.py +1 -1
  19. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/math/invert.py +1 -1
  20. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/math/log.py +1 -1
  21. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/math/scale.py +1 -1
  22. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/messages.py +2 -3
  23. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/sampler.py +16 -15
  24. ezmsg_sigproc-1.6.0/src/ezmsg/sigproc/scaler.py +290 -0
  25. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/signalinjector.py +7 -7
  26. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/slicer.py +34 -14
  27. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/spectrogram.py +6 -6
  28. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/spectrum.py +18 -14
  29. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/synth.py +43 -27
  30. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/wavelets.py +42 -17
  31. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/window.py +14 -13
  32. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/tests/helpers/util.py +15 -8
  33. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/tests/test_activation.py +1 -1
  34. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/tests/test_affine_transform.py +8 -24
  35. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/tests/test_aggregate.py +4 -6
  36. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/tests/test_bandpower.py +2 -8
  37. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/tests/test_butter.py +19 -10
  38. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/tests/test_butterworth.py +2 -5
  39. ezmsg_sigproc-1.6.0/tests/test_decimate.py +59 -0
  40. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/tests/test_downsample.py +12 -6
  41. ezmsg_sigproc-1.6.0/tests/test_filter_system.py +117 -0
  42. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/tests/test_filterbank.py +1 -1
  43. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/tests/test_sampler.py +7 -6
  44. ezmsg_sigproc-1.6.0/tests/test_scaler.py +150 -0
  45. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/tests/test_slicer.py +23 -23
  46. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/tests/test_spectrogram.py +1 -4
  47. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/tests/test_spectrum.py +7 -12
  48. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/tests/test_synth.py +19 -16
  49. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/tests/test_wavelets.py +48 -28
  50. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/tests/test_window.py +33 -16
  51. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/uv.lock +112 -164
  52. ezmsg_sigproc-1.4.2/src/ezmsg/sigproc/butterworthfilter.py +0 -161
  53. ezmsg_sigproc-1.4.2/src/ezmsg/sigproc/filter.py +0 -232
  54. ezmsg_sigproc-1.4.2/src/ezmsg/sigproc/scaler.py +0 -172
  55. ezmsg_sigproc-1.4.2/tests/test_scaler.py +0 -132
  56. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/.github/workflows/python-publish-ezmsg-sigproc.yml +0 -0
  57. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/.gitignore +0 -0
  58. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/.pre-commit-config.yaml +0 -0
  59. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/LICENSE.txt +0 -0
  60. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/README.md +0 -0
  61. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/__init__.py +0 -0
  62. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/base.py +0 -0
  63. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/math/__init__.py +0 -0
  64. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/src/ezmsg/sigproc/spectral.py +0 -0
  65. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/tests/conftest.py +0 -0
  66. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/tests/helpers/__init__.py +0 -0
  67. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/tests/resources/xform.csv +0 -0
  68. {ezmsg_sigproc-1.4.2 → ezmsg_sigproc-1.6.0}/tests/test_math.py +0 -0
@@ -13,7 +13,7 @@ jobs:
13
13
  build:
14
14
  strategy:
15
15
  matrix:
16
- python-version: [3.9, "3.10", "3.11", "3.12"]
16
+ python-version: ["3.10", "3.11", "3.12"]
17
17
  os:
18
18
  - "ubuntu-latest"
19
19
  - "windows-latest"
@@ -1,12 +1,11 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: ezmsg-sigproc
3
- Version: 1.4.2
3
+ Version: 1.6.0
4
4
  Summary: Timeseries signal processing implementations in ezmsg
5
5
  Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>
6
- License-Expression: MIT
7
- License-File: LICENSE.txt
8
- Requires-Python: >=3.9
9
- Requires-Dist: ezmsg>=3.5.0
6
+ License: MIT
7
+ Requires-Python: >=3.10.15
8
+ Requires-Dist: ezmsg>=3.6.0
10
9
  Requires-Dist: numpy>=1.26.0
11
10
  Requires-Dist: pywavelets>=1.6.0
12
11
  Requires-Dist: scipy>=1.13.1
@@ -8,10 +8,10 @@ authors = [
8
8
  ]
9
9
  license = "MIT"
10
10
  readme = "README.md"
11
- requires-python = ">=3.9"
11
+ requires-python = ">=3.10.15"
12
12
  dynamic = ["version"]
13
13
  dependencies = [
14
- "ezmsg>=3.5.0",
14
+ "ezmsg>=3.6.0",
15
15
  "numpy>=1.26.0",
16
16
  "pywavelets>=1.6.0",
17
17
  "scipy>=1.13.1",
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '1.4.2'
16
- __version_tuple__ = version_tuple = (1, 4, 2)
15
+ __version__ = version = '1.6.0'
16
+ __version_tuple__ = version_tuple = (1, 6, 0)
@@ -1,10 +1,10 @@
1
- from dataclasses import replace
2
1
  import typing
3
2
 
4
3
  import numpy as np
5
4
  import scipy.special
6
5
  import ezmsg.core as ez
7
6
  from ezmsg.util.messages.axisarray import AxisArray
7
+ from ezmsg.util.messages.util import replace
8
8
  from ezmsg.util.generator import consumer
9
9
 
10
10
  from .spectral import OptionsEnum
@@ -41,7 +41,7 @@ ACTIVATIONS = {
41
41
 
42
42
  @consumer
43
43
  def activation(
44
- function: typing.Union[str, ActivationFunction],
44
+ function: str | ActivationFunction,
45
45
  ) -> typing.Generator[AxisArray, AxisArray, None]:
46
46
  """
47
47
  Transform the data with a simple activation function.
@@ -1,4 +1,3 @@
1
- from dataclasses import replace
2
1
  import os
3
2
  from pathlib import Path
4
3
  import typing
@@ -6,7 +5,8 @@ import typing
6
5
  import numpy as np
7
6
  import numpy.typing as npt
8
7
  import ezmsg.core as ez
9
- from ezmsg.util.messages.axisarray import AxisArray
8
+ from ezmsg.util.messages.axisarray import AxisArray, AxisBase
9
+ from ezmsg.util.messages.util import replace
10
10
  from ezmsg.util.generator import consumer
11
11
 
12
12
  from .base import GenAxisArray
@@ -14,8 +14,8 @@ from .base import GenAxisArray
14
14
 
15
15
  @consumer
16
16
  def affine_transform(
17
- weights: typing.Union[np.ndarray, str, Path],
18
- axis: typing.Optional[str] = None,
17
+ weights: np.ndarray | str | Path,
18
+ axis: str | None = None,
19
19
  right_multiply: bool = True,
20
20
  ) -> typing.Generator[AxisArray, AxisArray, None]:
21
21
  """
@@ -47,7 +47,7 @@ def affine_transform(
47
47
 
48
48
  # State variables
49
49
  # New axis with transformed labels, if required
50
- new_axis: typing.Optional[AxisArray.Axis] = None
50
+ new_axis: AxisBase | None = None
51
51
 
52
52
  # Reset if any of these change.
53
53
  check_input = {"key": None}
@@ -71,10 +71,10 @@ def affine_transform(
71
71
  # Determine if we need to modify the transformed axis.
72
72
  if (
73
73
  axis in msg_in.axes
74
- and hasattr(msg_in.axes[axis], "labels")
74
+ and hasattr(msg_in.axes[axis], "data")
75
75
  and weights.shape[0] != weights.shape[1]
76
76
  ):
77
- in_labels = msg_in.axes[axis].labels
77
+ in_labels = msg_in.axes[axis].data
78
78
  new_labels = []
79
79
  n_in, n_out = weights.shape
80
80
  if len(in_labels) != n_in:
@@ -101,8 +101,8 @@ def affine_transform(
101
101
  new_labels.append("")
102
102
  elif np.all(b_filled_outputs):
103
103
  # Transform is dropping some of the inputs.
104
- new_labels = np.array(in_labels)[b_used_inputs].tolist()
105
- new_axis = replace(msg_in.axes[axis], labels=new_labels)
104
+ new_labels = np.array(in_labels)[b_used_inputs]
105
+ new_axis = replace(msg_in.axes[axis], data=np.array(new_labels))
106
106
 
107
107
  data = msg_in.data
108
108
 
@@ -133,8 +133,8 @@ class AffineTransformSettings(ez.Settings):
133
133
  See :obj:`affine_transform` for argument details.
134
134
  """
135
135
 
136
- weights: typing.Union[np.ndarray, str, Path]
137
- axis: typing.Optional[str] = None
136
+ weights: np.ndarray | str | Path
137
+ axis: str | None = None
138
138
  right_multiply: bool = True
139
139
 
140
140
 
@@ -157,7 +157,7 @@ def zeros_for_noop(data: npt.NDArray, **ignore_kwargs) -> npt.NDArray:
157
157
 
158
158
  @consumer
159
159
  def common_rereference(
160
- mode: str = "mean", axis: typing.Optional[str] = None, include_current: bool = True
160
+ mode: str = "mean", axis: str | None = None, include_current: bool = True
161
161
  ) -> typing.Generator[AxisArray, AxisArray, None]:
162
162
  """
163
163
  Perform common average referencing (CAR) on streaming data.
@@ -214,7 +214,7 @@ class CommonRereferenceSettings(ez.Settings):
214
214
  """
215
215
 
216
216
  mode: str = "mean"
217
- axis: typing.Optional[str] = None
217
+ axis: str | None = None
218
218
  include_current: bool = True
219
219
 
220
220
 
@@ -1,11 +1,15 @@
1
- from dataclasses import replace
2
1
  import typing
3
2
 
4
3
  import numpy as np
5
4
  import numpy.typing as npt
6
5
  import ezmsg.core as ez
7
6
  from ezmsg.util.generator import consumer
8
- from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
7
+ from ezmsg.util.messages.axisarray import (
8
+ AxisArray,
9
+ slice_along_axis,
10
+ AxisBase,
11
+ replace,
12
+ )
9
13
 
10
14
  from .spectral import OptionsEnum
11
15
  from .base import GenAxisArray
@@ -52,8 +56,8 @@ AGGREGATORS = {
52
56
 
53
57
  @consumer
54
58
  def ranged_aggregate(
55
- axis: typing.Optional[str] = None,
56
- bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = None,
59
+ axis: str | None = None,
60
+ bands: list[tuple[float, float]] | None = None,
57
61
  operation: AggregationFunction = AggregationFunction.MEAN,
58
62
  ):
59
63
  """
@@ -71,12 +75,12 @@ def ranged_aggregate(
71
75
  msg_out = AxisArray(np.array([]), dims=[""])
72
76
 
73
77
  # State variables
74
- slices: typing.Optional[typing.List[typing.Tuple[typing.Any, ...]]] = None
75
- out_axis: typing.Optional[AxisArray.Axis] = None
76
- ax_vec: typing.Optional[npt.NDArray] = None
78
+ slices: list[tuple[typing.Any, ...]] | None = None
79
+ out_axis: AxisBase | None = None
80
+ ax_vec: npt.NDArray | None = None
77
81
 
78
82
  # Reset if any of these changes. Key not checked because continuity between chunks not required.
79
- check_inputs = {"gain": None, "offset": None}
83
+ check_inputs = {"gain": None, "offset": None, "len": None, "key": None}
80
84
 
81
85
  while True:
82
86
  msg_in: AxisArray = yield msg_out
@@ -86,35 +90,52 @@ def ranged_aggregate(
86
90
  axis = axis or msg_in.dims[0]
87
91
  target_axis = msg_in.get_axis(axis)
88
92
 
89
- b_reset = target_axis.gain != check_inputs["gain"]
90
- b_reset = b_reset or target_axis.offset != check_inputs["offset"]
93
+ # Check if we need to reset state
94
+ b_reset = msg_in.key != check_inputs["key"]
95
+ if hasattr(target_axis, "data"):
96
+ b_reset = b_reset or len(target_axis.data) != check_inputs["len"]
97
+ elif isinstance(target_axis, AxisArray.LinearAxis):
98
+ b_reset = b_reset or target_axis.gain != check_inputs["gain"]
99
+ b_reset = b_reset or target_axis.offset != check_inputs["offset"]
100
+
91
101
  if b_reset:
92
- check_inputs["gain"] = target_axis.gain
93
- check_inputs["offset"] = target_axis.offset
102
+ # Update check variables
103
+ check_inputs["key"] = msg_in.key
104
+ if hasattr(target_axis, "data"):
105
+ check_inputs["len"] = len(target_axis.data)
106
+ else:
107
+ check_inputs["gain"] = target_axis.gain
108
+ check_inputs["offset"] = target_axis.offset
94
109
 
95
110
  # If the axis we are operating on has changed (e.g., "time" or "win" axis always changes),
96
111
  # or the key has changed, then recalculate slices.
97
112
 
98
113
  ax_idx = msg_in.get_axis_idx(axis)
99
114
 
100
- ax_vec = (
101
- target_axis.offset
102
- + np.arange(msg_in.data.shape[ax_idx]) * target_axis.gain
103
- )
115
+ if hasattr(target_axis, "data"):
116
+ ax_vec = target_axis.data
117
+ else:
118
+ ax_vec = target_axis.value(np.arange(msg_in.data.shape[ax_idx]))
119
+
104
120
  slices = []
105
- mids = []
121
+ ax_dat = []
106
122
  for start, stop in bands:
107
123
  inds = np.where(np.logical_and(ax_vec >= start, ax_vec <= stop))[0]
108
- mids.append(np.mean(inds) * target_axis.gain + target_axis.offset)
109
124
  slices.append(np.s_[inds[0] : inds[-1] + 1])
110
- out_ax_kwargs = {
111
- "unit": target_axis.unit,
112
- "offset": mids[0],
113
- "gain": (mids[1] - mids[0]) if len(mids) > 1 else 1.0,
114
- }
115
- if hasattr(target_axis, "labels"):
116
- out_ax_kwargs["labels"] = [f"{_[0]} - {_[1]}" for _ in bands]
117
- out_axis = replace(target_axis, **out_ax_kwargs)
125
+ if hasattr(target_axis, "data"):
126
+ if ax_vec.dtype.type is np.str_:
127
+ sl_dat = f"{ax_vec[start]} - {ax_vec[stop]}"
128
+ else:
129
+ sl_dat = ax_dat.append(np.mean(ax_vec[inds]))
130
+ else:
131
+ sl_dat = target_axis.value(np.mean(inds))
132
+ ax_dat.append(sl_dat)
133
+
134
+ out_axis = AxisArray.CoordinateAxis(
135
+ data=np.array(ax_dat),
136
+ dims=[axis],
137
+ unit=target_axis.unit,
138
+ )
118
139
 
119
140
  agg_func = AGGREGATORS[operation]
120
141
  out_data = [
@@ -142,8 +163,8 @@ class RangedAggregateSettings(ez.Settings):
142
163
  See :obj:`ranged_aggregate` for details.
143
164
  """
144
165
 
145
- axis: typing.Optional[str] = None
146
- bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = None
166
+ axis: str | None = None
167
+ bands: list[tuple[float, float]] | None = None
147
168
  operation: AggregationFunction = AggregationFunction.MEAN
148
169
 
149
170
 
@@ -14,7 +14,7 @@ from .base import GenAxisArray
14
14
  @consumer
15
15
  def bandpower(
16
16
  spectrogram_settings: SpectrogramSettings,
17
- bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = [
17
+ bands: list[tuple[float, float]] | None = [
18
18
  (17, 30),
19
19
  (70, 170),
20
20
  ],
@@ -58,7 +58,7 @@ class BandPowerSettings(ez.Settings):
58
58
  spectrogram_settings: SpectrogramSettings = field(
59
59
  default_factory=SpectrogramSettings
60
60
  )
61
- bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = field(
61
+ bands: list[tuple[float, float]] | None = field(
62
62
  default_factory=lambda: [(17, 30), (70, 170)]
63
63
  )
64
64
 
@@ -0,0 +1,160 @@
1
+ import functools
2
+ import typing
3
+
4
+ import scipy.signal
5
+ from ezmsg.util.messages.axisarray import AxisArray
6
+ from scipy.signal import normalize
7
+
8
+ from .filter import (
9
+ FilterBaseSettings,
10
+ FilterCoefsMultiType,
11
+ FilterBase,
12
+ filter_gen_by_design,
13
+ )
14
+
15
+
16
+ class ButterworthFilterSettings(FilterBaseSettings):
17
+ """Settings for :obj:`ButterworthFilter`."""
18
+
19
+ order: int = 0
20
+ """
21
+ Filter order
22
+ """
23
+
24
+ cuton: float | None = None
25
+ """
26
+ Cuton frequency (Hz). If `cutoff` is not specified then this is the highpass corner. Otherwise,
27
+ if this is lower than `cutoff` then this is the beginning of the bandpass
28
+ or if this is greater than `cutoff` then this is the end of the bandstop.
29
+ """
30
+
31
+ cutoff: float | None = None
32
+ """
33
+ Cutoff frequency (Hz). If `cuton` is not specified then this is the lowpass corner. Otherwise,
34
+ if this is greater than `cuton` then this is the end of the bandpass,
35
+ or if this is less than `cuton` then this is the beginning of the bandstop.
36
+ """
37
+
38
+ wn_hz: bool = True
39
+ """
40
+ Set False if provided Wn are normalized from 0 to 1, where 1 is the Nyquist frequency
41
+ """
42
+
43
+ def filter_specs(
44
+ self,
45
+ ) -> tuple[str, float | tuple[float, float]] | None:
46
+ """
47
+ Determine the filter type given the corner frequencies.
48
+
49
+ Returns:
50
+ A tuple with the first element being a string indicating the filter type
51
+ (one of "lowpass", "highpass", "bandpass", "bandstop")
52
+ and the second element being the corner frequency or frequencies.
53
+
54
+ """
55
+ if self.cuton is None and self.cutoff is None:
56
+ return None
57
+ elif self.cuton is None and self.cutoff is not None:
58
+ return "lowpass", self.cutoff
59
+ elif self.cuton is not None and self.cutoff is None:
60
+ return "highpass", self.cuton
61
+ elif self.cuton is not None and self.cutoff is not None:
62
+ if self.cuton <= self.cutoff:
63
+ return "bandpass", (self.cuton, self.cutoff)
64
+ else:
65
+ return "bandstop", (self.cutoff, self.cuton)
66
+
67
+
68
+ def butter_design_fun(
69
+ fs: float,
70
+ order: int = 0,
71
+ cuton: float | None = None,
72
+ cutoff: float | None = None,
73
+ coef_type: str = "ba",
74
+ wn_hz: bool = True,
75
+ ) -> FilterCoefsMultiType | None:
76
+ """
77
+ See :obj:`ButterworthFilterSettings.filter_specs` for an explanation of specifying different
78
+ filter types (lowpass, highpass, bandpass, bandstop) from the parameters.
79
+ You are likely to want to use this function with :obj:`filter_by_design`, which only passes `fs` to the design
80
+ function (this), meaning that you should wrap this function with a lambda or prepare with functools.partial.
81
+
82
+ Args:
83
+ fs: The sampling frequency of the data in Hz.
84
+ order: Filter order.
85
+ cuton: Corner frequency of the filter in Hz.
86
+ cutoff: Corner frequency of the filter in Hz.
87
+ coef_type: "ba", "sos", or "zpk"
88
+ wn_hz: Set False if provided Wn are normalized from 0 to 1, where 1 is the Nyquist frequency
89
+
90
+ Returns:
91
+ The filter coefficients as a tuple of (b, a) for coef_type "ba", or as a single ndarray for "sos",
92
+ or (z, p, k) for "zpk".
93
+
94
+ """
95
+ coefs = None
96
+ if order > 0:
97
+ btype, cutoffs = ButterworthFilterSettings(
98
+ order=order, cuton=cuton, cutoff=cutoff
99
+ ).filter_specs()
100
+ coefs = scipy.signal.butter(
101
+ order,
102
+ Wn=cutoffs,
103
+ btype=btype,
104
+ fs=fs if wn_hz else None,
105
+ output=coef_type,
106
+ )
107
+ if coefs is not None and coef_type == "ba":
108
+ coefs = normalize(*coefs)
109
+ return coefs
110
+
111
+
112
+ class ButterworthFilter(FilterBase):
113
+ SETTINGS = ButterworthFilterSettings
114
+
115
+ def design_filter(
116
+ self,
117
+ ) -> typing.Callable[[float], FilterCoefsMultiType | None]:
118
+ return functools.partial(
119
+ butter_design_fun,
120
+ order=self.SETTINGS.order,
121
+ cuton=self.SETTINGS.cuton,
122
+ cutoff=self.SETTINGS.cutoff,
123
+ coef_type=self.SETTINGS.coef_type,
124
+ )
125
+
126
+
127
+ def butter(
128
+ axis: str | None,
129
+ order: int = 0,
130
+ cuton: float | None = None,
131
+ cutoff: float | None = None,
132
+ coef_type: str = "ba",
133
+ ) -> typing.Generator[AxisArray, AxisArray, None]:
134
+ """
135
+ Convenience generator wrapping filter_gen_by_design for Butterworth filters.
136
+ Apply Butterworth filter to streaming data. Uses :obj:`scipy.signal.butter` to design the filter.
137
+ See :obj:`ButterworthFilterSettings.filter_specs` for an explanation of specifying different
138
+ filter types (lowpass, highpass, bandpass, bandstop) from the parameters.
139
+
140
+ Args:
141
+ axis: The name of the axis to filter.
142
+ Note: The axis must be represented in the message .axes and be of type AxisArray.LinearAxis.
143
+ order: Filter order.
144
+ cuton: Corner frequency of the filter in Hz.
145
+ cutoff: Corner frequency of the filter in Hz.
146
+ coef_type: "ba" or "sos"
147
+
148
+ Returns:
149
+ A primed generator object which accepts an :obj:`AxisArray` via .send(axis_array)
150
+ and yields an :obj:`AxisArray` with filtered data.
151
+
152
+ """
153
+ design_fun = functools.partial(
154
+ butter_design_fun,
155
+ order=order,
156
+ cuton=cuton,
157
+ cutoff=cutoff,
158
+ coef_type=coef_type,
159
+ )
160
+ return filter_gen_by_design(axis, coef_type, design_fun)
@@ -0,0 +1,119 @@
1
+ import functools
2
+ import typing
3
+
4
+ import scipy.signal
5
+ from scipy.signal import normalize
6
+
7
+ from .filter import (
8
+ FilterBaseSettings,
9
+ FilterCoefsMultiType,
10
+ FilterBase,
11
+ )
12
+
13
+
14
+ class ChebyshevFilterSettings(FilterBaseSettings):
15
+ """Settings for :obj:`ButterworthFilter`."""
16
+
17
+ order: int = 0
18
+ """
19
+ Filter order
20
+ """
21
+
22
+ ripple_tol: float | None = None
23
+ """
24
+ The maximum ripple allowed below unity gain in the passband. Specified in decibels, as a positive number.
25
+ """
26
+
27
+ Wn: float | tuple[float, float] | None = None
28
+ """
29
+ A scalar or length-2 sequence giving the critical frequencies.
30
+ For Type I filters, this is the point in the transition band at which the gain first drops below -rp.
31
+ For digital filters, Wn are in the same units as fs unless wn_hz is False.
32
+ For analog filters, Wn is an angular frequency (e.g., rad/s).
33
+ """
34
+
35
+ btype: str = "lowpass"
36
+ """
37
+ {‘lowpass’, ‘highpass’, ‘bandpass’, ‘bandstop’}
38
+ """
39
+
40
+ analog: bool = False
41
+ """
42
+ When True, return an analog filter, otherwise a digital filter is returned.
43
+ """
44
+
45
+ cheby_type: str = "cheby1"
46
+ """
47
+ Which type of Chebyshev filter to design. Either "cheby1" or "cheby2".
48
+ """
49
+
50
+ wn_hz: bool = True
51
+ """
52
+ Set False if provided Wn are normalized from 0 to 1, where 1 is the Nyquist frequency
53
+ """
54
+
55
+
56
+ def cheby_design_fun(
57
+ fs: float,
58
+ order: int = 0,
59
+ ripple_tol: float | None = None,
60
+ Wn: float | tuple[float, float] | None = None,
61
+ btype: str = "lowpass",
62
+ analog: bool = False,
63
+ coef_type: str = "ba",
64
+ cheby_type: str = "cheby1",
65
+ wn_hz: bool = True,
66
+ ) -> FilterCoefsMultiType:
67
+ """
68
+ Chebyshev type I and type II digital and analog filter design.
69
+ Design an `order`th-order digital or analog Chebyshev type I or type II filter and return the filter coefficients.
70
+ See :obj:`ChebyFilterSettings` for argument description.
71
+
72
+ Returns:
73
+ The filter coefficients as a tuple of (b, a) for coef_type "ba", or as a single ndarray for "sos",
74
+ or (z, p, k) for "zpk".
75
+ """
76
+ coefs = None
77
+ if order > 0:
78
+ if cheby_type == "cheby1":
79
+ coefs = scipy.signal.cheby1(
80
+ order,
81
+ ripple_tol,
82
+ Wn,
83
+ btype=btype,
84
+ analog=analog,
85
+ output=coef_type,
86
+ fs=fs if wn_hz else None,
87
+ )
88
+ elif cheby_type == "cheby2":
89
+ coefs = scipy.signal.cheby2(
90
+ order,
91
+ ripple_tol,
92
+ Wn,
93
+ btype=btype,
94
+ analog=analog,
95
+ output=coef_type,
96
+ fs=fs,
97
+ )
98
+ if coefs is not None and coef_type == "ba":
99
+ coefs = normalize(*coefs)
100
+ return coefs
101
+
102
+
103
+ class ChebyshevFilter(FilterBase):
104
+ SETTINGS = ChebyshevFilterSettings
105
+
106
+ def design_filter(
107
+ self,
108
+ ) -> typing.Callable[[float], FilterCoefsMultiType | None]:
109
+ return functools.partial(
110
+ cheby_design_fun,
111
+ order=self.SETTINGS.order,
112
+ ripple_tol=self.SETTINGS.ripple_tol,
113
+ Wn=self.SETTINGS.Wn,
114
+ btype=self.SETTINGS.btype,
115
+ analog=self.SETTINGS.analog,
116
+ coef_type=self.SETTINGS.coef_type,
117
+ cheby_type=self.SETTINGS.cheby_type,
118
+ wn_hz=self.SETTINGS.wn_hz,
119
+ )
@@ -1,9 +1,8 @@
1
- import scipy.signal
2
1
  import ezmsg.core as ez
3
2
  from ezmsg.util.messages.axisarray import AxisArray
4
3
 
4
+ from .cheby import ChebyshevFilter, ChebyshevFilterSettings
5
5
  from .downsample import Downsample, DownsampleSettings
6
- from .filter import Filter, FilterCoefficients, FilterSettings
7
6
 
8
7
 
9
8
  class Decimate(ez.Collection):
@@ -17,24 +16,21 @@ class Decimate(ez.Collection):
17
16
  INPUT_SIGNAL = ez.InputStream(AxisArray)
18
17
  OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
19
18
 
20
- FILTER = Filter()
19
+ FILTER = ChebyshevFilter()
21
20
  DOWNSAMPLE = Downsample()
22
21
 
23
22
  def configure(self) -> None:
23
+ cheby_settings = ChebyshevFilterSettings(
24
+ order=8 if self.SETTINGS.factor > 1 else 0,
25
+ ripple_tol=0.05,
26
+ Wn=0.8 / self.SETTINGS.factor if self.SETTINGS.factor > 1 else None,
27
+ btype="lowpass",
28
+ axis=self.SETTINGS.axis,
29
+ wn_hz=False,
30
+ )
31
+ self.FILTER.apply_settings(cheby_settings)
24
32
  self.DOWNSAMPLE.apply_settings(self.SETTINGS)
25
33
 
26
- if self.SETTINGS.factor < 1:
27
- raise ValueError("Decimation factor must be >= 1 (no decimation")
28
- elif self.SETTINGS.factor == 1:
29
- filt = FilterCoefficients()
30
- else:
31
- # See scipy.signal.decimate for IIR Filter Condition
32
- b, a = scipy.signal.cheby1(8, 0.05, 0.8 / self.SETTINGS.factor)
33
- system = scipy.signal.dlti(b, a)
34
- filt = FilterCoefficients(b=system.num, a=system.den) # type: ignore
35
-
36
- self.FILTER.apply_settings(FilterSettings(filt=filt))
37
-
38
34
  def network(self) -> ez.NetworkDefinition:
39
35
  return (
40
36
  (self.INPUT_SIGNAL, self.FILTER.INPUT_SIGNAL),
@@ -1,8 +1,11 @@
1
- from dataclasses import replace
2
1
  import typing
3
2
 
4
3
  import numpy as np
5
- from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
4
+ from ezmsg.util.messages.axisarray import (
5
+ AxisArray,
6
+ slice_along_axis,
7
+ replace,
8
+ )
6
9
  from ezmsg.util.generator import consumer
7
10
  import ezmsg.core as ez
8
11
 
@@ -11,7 +14,7 @@ from .base import GenAxisArray
11
14
 
12
15
  @consumer
13
16
  def downsample(
14
- axis: typing.Optional[str] = None, factor: int = 1
17
+ axis: str | None = None, factor: int = 1
15
18
  ) -> typing.Generator[AxisArray, AxisArray, None]:
16
19
  """
17
20
  Construct a generator that yields a downsampled version of the data .send() to it.
@@ -22,6 +25,7 @@ def downsample(
22
25
 
23
26
  Args:
24
27
  axis: The name of the axis along which to downsample.
28
+ Note: The axis must exist in the message .axes and be of type AxisArray.LinearAxis.
25
29
  factor: Downsampling factor.
26
30
 
27
31
  Returns:
@@ -92,7 +96,7 @@ class DownsampleSettings(ez.Settings):
92
96
  See :obj:`downsample` documentation for a description of the parameters.
93
97
  """
94
98
 
95
- axis: typing.Optional[str] = None
99
+ axis: str | None = None
96
100
  factor: int = 1
97
101
 
98
102