ezmsg-sigproc 2.4.1__py3-none-any.whl → 2.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. ezmsg/sigproc/__version__.py +2 -2
  2. ezmsg/sigproc/activation.py +5 -11
  3. ezmsg/sigproc/adaptive_lattice_notch.py +11 -29
  4. ezmsg/sigproc/affinetransform.py +13 -38
  5. ezmsg/sigproc/aggregate.py +13 -30
  6. ezmsg/sigproc/bandpower.py +7 -15
  7. ezmsg/sigproc/base.py +141 -1276
  8. ezmsg/sigproc/butterworthfilter.py +8 -16
  9. ezmsg/sigproc/butterworthzerophase.py +123 -0
  10. ezmsg/sigproc/cheby.py +4 -10
  11. ezmsg/sigproc/combfilter.py +5 -8
  12. ezmsg/sigproc/decimate.py +2 -6
  13. ezmsg/sigproc/denormalize.py +6 -11
  14. ezmsg/sigproc/detrend.py +3 -4
  15. ezmsg/sigproc/diff.py +8 -17
  16. ezmsg/sigproc/downsample.py +6 -14
  17. ezmsg/sigproc/ewma.py +11 -27
  18. ezmsg/sigproc/ewmfilter.py +1 -1
  19. ezmsg/sigproc/extract_axis.py +3 -4
  20. ezmsg/sigproc/fbcca.py +31 -56
  21. ezmsg/sigproc/filter.py +19 -45
  22. ezmsg/sigproc/filterbank.py +33 -70
  23. ezmsg/sigproc/filterbankdesign.py +5 -12
  24. ezmsg/sigproc/fir_hilbert.py +336 -0
  25. ezmsg/sigproc/fir_pmc.py +209 -0
  26. ezmsg/sigproc/firfilter.py +12 -14
  27. ezmsg/sigproc/gaussiansmoothing.py +5 -9
  28. ezmsg/sigproc/kaiser.py +11 -15
  29. ezmsg/sigproc/math/abs.py +1 -3
  30. ezmsg/sigproc/math/add.py +121 -0
  31. ezmsg/sigproc/math/clip.py +1 -1
  32. ezmsg/sigproc/math/difference.py +98 -36
  33. ezmsg/sigproc/math/invert.py +1 -3
  34. ezmsg/sigproc/math/log.py +2 -6
  35. ezmsg/sigproc/messages.py +1 -2
  36. ezmsg/sigproc/quantize.py +2 -4
  37. ezmsg/sigproc/resample.py +13 -34
  38. ezmsg/sigproc/rollingscaler.py +232 -0
  39. ezmsg/sigproc/sampler.py +17 -35
  40. ezmsg/sigproc/scaler.py +8 -18
  41. ezmsg/sigproc/signalinjector.py +6 -16
  42. ezmsg/sigproc/slicer.py +9 -28
  43. ezmsg/sigproc/spectral.py +3 -3
  44. ezmsg/sigproc/spectrogram.py +12 -19
  45. ezmsg/sigproc/spectrum.py +12 -32
  46. ezmsg/sigproc/transpose.py +7 -18
  47. ezmsg/sigproc/util/asio.py +25 -156
  48. ezmsg/sigproc/util/axisarray_buffer.py +10 -26
  49. ezmsg/sigproc/util/buffer.py +18 -43
  50. ezmsg/sigproc/util/message.py +17 -31
  51. ezmsg/sigproc/util/profile.py +23 -174
  52. ezmsg/sigproc/util/sparse.py +5 -15
  53. ezmsg/sigproc/util/typeresolution.py +17 -83
  54. ezmsg/sigproc/wavelets.py +6 -15
  55. ezmsg/sigproc/window.py +24 -78
  56. {ezmsg_sigproc-2.4.1.dist-info → ezmsg_sigproc-2.6.0.dist-info}/METADATA +4 -3
  57. ezmsg_sigproc-2.6.0.dist-info/RECORD +63 -0
  58. ezmsg/sigproc/synth.py +0 -774
  59. ezmsg_sigproc-2.4.1.dist-info/RECORD +0 -59
  60. {ezmsg_sigproc-2.4.1.dist-info → ezmsg_sigproc-2.6.0.dist-info}/WHEEL +0 -0
  61. /ezmsg_sigproc-2.4.1.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.6.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/kaiser.py CHANGED
@@ -6,10 +6,10 @@ import numpy.typing as npt
6
6
  import scipy.signal
7
7
 
8
8
  from .filter import (
9
- FilterBaseSettings,
10
- FilterByDesignTransformer,
11
9
  BACoeffs,
12
10
  BaseFilterByDesignTransformerUnit,
11
+ FilterBaseSettings,
12
+ FilterByDesignTransformer,
13
13
  )
14
14
 
15
15
 
@@ -20,30 +20,30 @@ class KaiserFilterSettings(FilterBaseSettings):
20
20
 
21
21
  cutoff: float | npt.ArrayLike | None = None
22
22
  """
23
- Cutoff frequency of filter (expressed in the same units as fs) OR an array of cutoff frequencies
24
- (that is, band edges). In the former case, as a float, the cutoff frequency should correspond with
25
- the half-amplitude point, where the attenuation will be -6dB. In the latter case, the frequencies in
26
- cutoff should be positive and monotonically increasing between 0 and fs/2. The values 0 and fs/2 must
23
+ Cutoff frequency of filter (expressed in the same units as fs) OR an array of cutoff frequencies
24
+ (that is, band edges). In the former case, as a float, the cutoff frequency should correspond with
25
+ the half-amplitude point, where the attenuation will be -6dB. In the latter case, the frequencies in
26
+ cutoff should be positive and monotonically increasing between 0 and fs/2. The values 0 and fs/2 must
27
27
  not be included in cutoff.
28
28
  """
29
29
 
30
30
  ripple: float | None = None
31
31
  """
32
- Upper bound for the deviation (in dB) of the magnitude of the filter's frequency response from that of
32
+ Upper bound for the deviation (in dB) of the magnitude of the filter's frequency response from that of
33
33
  the desired filter (not including frequencies in any transition intervals).
34
34
  See scipy.signal.kaiserord for more information.
35
35
  """
36
36
 
37
37
  width: float | None = None
38
38
  """
39
- If width is not None, then assume it is the approximate width of the transition region (expressed in
39
+ If width is not None, then assume it is the approximate width of the transition region (expressed in
40
40
  the same units as fs) for use in Kaiser FIR filter design.
41
41
  See scipy.signal.kaiserord for more information.
42
42
  """
43
43
 
44
44
  pass_zero: bool | str = True
45
45
  """
46
- If True, the gain at the frequency 0 (i.e., the “DC gain”) is 1. If False, the DC gain is 0. Can also
46
+ If True, the gain at the frequency 0 (i.e., the “DC gain”) is 1. If False, the DC gain is 0. Can also
47
47
  be a string argument for the desired filter type (equivalent to btype in IIR design functions).
48
48
  {‘lowpass’, ‘highpass’, ‘bandpass’, ‘bandstop’}
49
49
  """
@@ -88,9 +88,7 @@ def kaiser_design_fun(
88
88
  return (taps, np.array([1.0]))
89
89
 
90
90
 
91
- class KaiserFilterTransformer(
92
- FilterByDesignTransformer[KaiserFilterSettings, BACoeffs]
93
- ):
91
+ class KaiserFilterTransformer(FilterByDesignTransformer[KaiserFilterSettings, BACoeffs]):
94
92
  def get_design_function(
95
93
  self,
96
94
  ) -> typing.Callable[[float], BACoeffs | None]:
@@ -104,7 +102,5 @@ class KaiserFilterTransformer(
104
102
  )
105
103
 
106
104
 
107
- class KaiserFilter(
108
- BaseFilterByDesignTransformerUnit[KaiserFilterSettings, KaiserFilterTransformer]
109
- ):
105
+ class KaiserFilter(BaseFilterByDesignTransformerUnit[KaiserFilterSettings, KaiserFilterTransformer]):
110
106
  SETTINGS = KaiserFilterSettings
ezmsg/sigproc/math/abs.py CHANGED
@@ -14,9 +14,7 @@ class AbsTransformer(BaseTransformer[None, AxisArray, AxisArray]):
14
14
  return replace(message, data=np.abs(message.data))
15
15
 
16
16
 
17
- class Abs(
18
- BaseTransformerUnit[None, AxisArray, AxisArray, AbsTransformer]
19
- ): ... # SETTINGS = None
17
+ class Abs(BaseTransformerUnit[None, AxisArray, AxisArray, AbsTransformer]): ... # SETTINGS = None
20
18
 
21
19
 
22
20
  def abs() -> AbsTransformer:
@@ -0,0 +1,121 @@
1
+ """Signal addition utilities."""
2
+
3
+ import asyncio
4
+ import typing
5
+ from dataclasses import dataclass, field
6
+
7
+ import ezmsg.core as ez
8
+ from ezmsg.baseproc.util.asio import run_coroutine_sync
9
+ from ezmsg.util.messages.axisarray import AxisArray
10
+ from ezmsg.util.messages.util import replace
11
+
12
+ from ..base import BaseTransformer, BaseTransformerUnit
13
+
14
+ # --- Constant Addition (single input) ---
15
+
16
+
17
+ class ConstAddSettings(ez.Settings):
18
+ value: float = 0.0
19
+ """Number to add to the input data."""
20
+
21
+
22
+ class ConstAddTransformer(BaseTransformer[ConstAddSettings, AxisArray, AxisArray]):
23
+ """Add a constant value to input data."""
24
+
25
+ def _process(self, message: AxisArray) -> AxisArray:
26
+ return replace(message, data=message.data + self.settings.value)
27
+
28
+
29
+ class ConstAdd(BaseTransformerUnit[ConstAddSettings, AxisArray, AxisArray, ConstAddTransformer]):
30
+ """Unit wrapper for ConstAddTransformer."""
31
+
32
+ SETTINGS = ConstAddSettings
33
+
34
+
35
+ # --- Two-input Addition ---
36
+
37
+
38
+ @dataclass
39
+ class AddState:
40
+ """State for Add processor with two input queues."""
41
+
42
+ queue_a: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
43
+ queue_b: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
44
+
45
+
46
+ class AddProcessor:
47
+ """Processor that adds two AxisArray signals together.
48
+
49
+ This processor maintains separate queues for two input streams and
50
+ adds corresponding messages element-wise. It assumes both inputs
51
+ have compatible shapes and aligned time spans.
52
+ """
53
+
54
+ def __init__(self):
55
+ self._state = AddState()
56
+
57
+ @property
58
+ def state(self) -> AddState:
59
+ return self._state
60
+
61
+ @state.setter
62
+ def state(self, state: AddState | bytes | None) -> None:
63
+ if state is not None:
64
+ # TODO: Support hydrating state from bytes
65
+ # if isinstance(state, bytes):
66
+ # self._state = pickle.loads(state)
67
+ # else:
68
+ self._state = state
69
+
70
+ def push_a(self, msg: AxisArray) -> None:
71
+ """Push a message to queue A."""
72
+ self._state.queue_a.put_nowait(msg)
73
+
74
+ def push_b(self, msg: AxisArray) -> None:
75
+ """Push a message to queue B."""
76
+ self._state.queue_b.put_nowait(msg)
77
+
78
+ async def __acall__(self) -> AxisArray:
79
+ """Await and add the next messages from both queues."""
80
+ a = await self._state.queue_a.get()
81
+ b = await self._state.queue_b.get()
82
+ return replace(a, data=a.data + b.data)
83
+
84
+ def __call__(self) -> AxisArray:
85
+ """Synchronously get and add the next messages from both queues."""
86
+ return run_coroutine_sync(self.__acall__())
87
+
88
+ # Aliases for legacy interface
89
+ async def __anext__(self) -> AxisArray:
90
+ return await self.__acall__()
91
+
92
+ def __next__(self) -> AxisArray:
93
+ return self.__call__()
94
+
95
+
96
+ class Add(ez.Unit):
97
+ """Add two signals together.
98
+
99
+ Assumes compatible/similar axes/dimensions and aligned time spans.
100
+ Messages are paired by arrival order (oldest from each queue).
101
+ """
102
+
103
+ INPUT_SIGNAL_A = ez.InputStream(AxisArray)
104
+ INPUT_SIGNAL_B = ez.InputStream(AxisArray)
105
+ OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
106
+
107
+ async def initialize(self) -> None:
108
+ self.processor = AddProcessor()
109
+
110
+ @ez.subscriber(INPUT_SIGNAL_A)
111
+ async def on_a(self, msg: AxisArray) -> None:
112
+ self.processor.push_a(msg)
113
+
114
+ @ez.subscriber(INPUT_SIGNAL_B)
115
+ async def on_b(self, msg: AxisArray) -> None:
116
+ self.processor.push_b(msg)
117
+
118
+ @ez.publisher(OUTPUT_SIGNAL)
119
+ async def output(self) -> typing.AsyncGenerator:
120
+ while True:
121
+ yield self.OUTPUT_SIGNAL, await self.processor.__acall__()
@@ -1,5 +1,5 @@
1
- import numpy as np
2
1
  import ezmsg.core as ez
2
+ import numpy as np
3
3
  from ezmsg.util.messages.axisarray import AxisArray
4
4
  from ezmsg.util.messages.util import replace
5
5
 
@@ -1,4 +1,9 @@
1
+ import asyncio
2
+ import typing
3
+ from dataclasses import dataclass, field
4
+
1
5
  import ezmsg.core as ez
6
+ from ezmsg.baseproc.util.asio import run_coroutine_sync
2
7
  from ezmsg.util.messages.axisarray import AxisArray
3
8
  from ezmsg.util.messages.util import replace
4
9
 
@@ -10,12 +15,11 @@ class ConstDifferenceSettings(ez.Settings):
10
15
  """number to subtract or be subtracted from the input data"""
11
16
 
12
17
  subtrahend: bool = True
13
- """If True (default) then value is subtracted from the input data. If False, the input data is subtracted from value."""
18
+ """If True (default) then value is subtracted from the input data. If False, the input data
19
+ is subtracted from value."""
14
20
 
15
21
 
16
- class ConstDifferenceTransformer(
17
- BaseTransformer[ConstDifferenceSettings, AxisArray, AxisArray]
18
- ):
22
+ class ConstDifferenceTransformer(BaseTransformer[ConstDifferenceSettings, AxisArray, AxisArray]):
19
23
  def _process(self, message: AxisArray) -> AxisArray:
20
24
  return replace(
21
25
  message,
@@ -25,17 +29,11 @@ class ConstDifferenceTransformer(
25
29
  )
26
30
 
27
31
 
28
- class ConstDifference(
29
- BaseTransformerUnit[
30
- ConstDifferenceSettings, AxisArray, AxisArray, ConstDifferenceTransformer
31
- ]
32
- ):
32
+ class ConstDifference(BaseTransformerUnit[ConstDifferenceSettings, AxisArray, AxisArray, ConstDifferenceTransformer]):
33
33
  SETTINGS = ConstDifferenceSettings
34
34
 
35
35
 
36
- def const_difference(
37
- value: float = 0.0, subtrahend: bool = True
38
- ) -> ConstDifferenceTransformer:
36
+ def const_difference(value: float = 0.0, subtrahend: bool = True) -> ConstDifferenceTransformer:
39
37
  """
40
38
  result = (in_data - value) if subtrahend else (value - in_data)
41
39
  https://en.wikipedia.org/wiki/Template:Arithmetic_operations
@@ -47,27 +45,91 @@ def const_difference(
47
45
 
48
46
  Returns: :obj:`ConstDifferenceTransformer`.
49
47
  """
50
- return ConstDifferenceTransformer(
51
- ConstDifferenceSettings(value=value, subtrahend=subtrahend)
52
- )
53
-
54
-
55
- # class DifferenceSettings(ez.Settings):
56
- # pass
57
- #
58
- #
59
- # class Difference(ez.Unit):
60
- # SETTINGS = DifferenceSettings
61
- #
62
- # INPUT_SIGNAL_1 = ez.InputStream(AxisArray)
63
- # INPUT_SIGNAL_2 = ez.InputStream(AxisArray)
64
- # OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
65
- #
66
- # @ez.subscriber(INPUT_SIGNAL_2, zero_copy=True)
67
- # @ez.publisher(OUTPUT_SIGNAL)
68
- # async def on_input_2(self, message: AxisArray) -> typing.AsyncGenerator:
69
- # # TODO: buffer_2
70
- # # TODO: take buffer_1 - buffer_2 for ranges that align
71
- # # TODO: Drop samples from buffer_1 and buffer_2
72
- # if ret is not None:
73
- # yield self.OUTPUT_SIGNAL, ret
48
+ return ConstDifferenceTransformer(ConstDifferenceSettings(value=value, subtrahend=subtrahend))
49
+
50
+
51
+ # --- Two-input Difference ---
52
+
53
+
54
+ @dataclass
55
+ class DifferenceState:
56
+ """State for Difference processor with two input queues."""
57
+
58
+ queue_a: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
59
+ queue_b: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
60
+
61
+
62
+ class DifferenceProcessor:
63
+ """Processor that subtracts two AxisArray signals (A - B).
64
+
65
+ This processor maintains separate queues for two input streams and
66
+ subtracts corresponding messages element-wise. It assumes both inputs
67
+ have compatible shapes and aligned time spans.
68
+ """
69
+
70
+ def __init__(self):
71
+ self._state = DifferenceState()
72
+
73
+ @property
74
+ def state(self) -> DifferenceState:
75
+ return self._state
76
+
77
+ @state.setter
78
+ def state(self, state: DifferenceState | bytes | None) -> None:
79
+ if state is not None:
80
+ self._state = state
81
+
82
+ def push_a(self, msg: AxisArray) -> None:
83
+ """Push a message to queue A (minuend)."""
84
+ self._state.queue_a.put_nowait(msg)
85
+
86
+ def push_b(self, msg: AxisArray) -> None:
87
+ """Push a message to queue B (subtrahend)."""
88
+ self._state.queue_b.put_nowait(msg)
89
+
90
+ async def __acall__(self) -> AxisArray:
91
+ """Await and subtract the next messages (A - B)."""
92
+ a = await self._state.queue_a.get()
93
+ b = await self._state.queue_b.get()
94
+ return replace(a, data=a.data - b.data)
95
+
96
+ def __call__(self) -> AxisArray:
97
+ """Synchronously get and subtract the next messages."""
98
+ return run_coroutine_sync(self.__acall__())
99
+
100
+ # Aliases for legacy interface
101
+ async def __anext__(self) -> AxisArray:
102
+ return await self.__acall__()
103
+
104
+ def __next__(self) -> AxisArray:
105
+ return self.__call__()
106
+
107
+
108
+ class Difference(ez.Unit):
109
+ """Subtract two signals (A - B).
110
+
111
+ Assumes compatible/similar axes/dimensions and aligned time spans.
112
+ Messages are paired by arrival order (oldest from each queue).
113
+
114
+ OUTPUT = INPUT_SIGNAL_A - INPUT_SIGNAL_B
115
+ """
116
+
117
+ INPUT_SIGNAL_A = ez.InputStream(AxisArray)
118
+ INPUT_SIGNAL_B = ez.InputStream(AxisArray)
119
+ OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
120
+
121
+ async def initialize(self) -> None:
122
+ self.processor = DifferenceProcessor()
123
+
124
+ @ez.subscriber(INPUT_SIGNAL_A)
125
+ async def on_a(self, msg: AxisArray) -> None:
126
+ self.processor.push_a(msg)
127
+
128
+ @ez.subscriber(INPUT_SIGNAL_B)
129
+ async def on_b(self, msg: AxisArray) -> None:
130
+ self.processor.push_b(msg)
131
+
132
+ @ez.publisher(OUTPUT_SIGNAL)
133
+ async def output(self) -> typing.AsyncGenerator:
134
+ while True:
135
+ yield self.OUTPUT_SIGNAL, await self.processor.__acall__()
@@ -9,9 +9,7 @@ class InvertTransformer(BaseTransformer[None, AxisArray, AxisArray]):
9
9
  return replace(message, data=1 / message.data)
10
10
 
11
11
 
12
- class Invert(
13
- BaseTransformerUnit[None, AxisArray, AxisArray, InvertTransformer]
14
- ): ... # SETTINGS = None
12
+ class Invert(BaseTransformerUnit[None, AxisArray, AxisArray, InvertTransformer]): ... # SETTINGS = None
15
13
 
16
14
 
17
15
  def invert() -> InvertTransformer:
ezmsg/sigproc/math/log.py CHANGED
@@ -1,5 +1,5 @@
1
- import numpy as np
2
1
  import ezmsg.core as ez
2
+ import numpy as np
3
3
  from ezmsg.util.messages.axisarray import AxisArray
4
4
  from ezmsg.util.messages.util import replace
5
5
 
@@ -17,11 +17,7 @@ class LogSettings(ez.Settings):
17
17
  class LogTransformer(BaseTransformer[LogSettings, AxisArray, AxisArray]):
18
18
  def _process(self, message: AxisArray) -> AxisArray:
19
19
  data = message.data
20
- if (
21
- self.settings.clip_zero
22
- and np.any(data <= 0)
23
- and np.issubdtype(data.dtype, np.floating)
24
- ):
20
+ if self.settings.clip_zero and np.any(data <= 0) and np.issubdtype(data.dtype, np.floating):
25
21
  data = np.clip(data, a_min=np.finfo(data.dtype).tiny, a_max=None)
26
22
  return replace(message, data=np.log(data) / np.log(self.settings.base))
27
23
 
ezmsg/sigproc/messages.py CHANGED
@@ -1,10 +1,9 @@
1
- import warnings
2
1
  import time
2
+ import warnings
3
3
 
4
4
  import numpy.typing as npt
5
5
  from ezmsg.util.messages.axisarray import AxisArray
6
6
 
7
-
8
7
  # UPCOMING: TSMessage Deprecation
9
8
  # TSMessage is deprecated because it doesn't handle multiple time axes well.
10
9
  # AxisArray has an incompatible API but supports a superset of functionality.
ezmsg/sigproc/quantize.py CHANGED
@@ -1,5 +1,5 @@
1
- import numpy as np
2
1
  import ezmsg.core as ez
2
+ import numpy as np
3
3
  from ezmsg.util.messages.axisarray import AxisArray, replace
4
4
 
5
5
  from .base import BaseTransformer, BaseTransformerUnit
@@ -65,7 +65,5 @@ class QuantizeTransformer(BaseTransformer[QuantizeSettings, AxisArray, AxisArray
65
65
  return replace(message, data=data)
66
66
 
67
67
 
68
- class QuantizerUnit(
69
- BaseTransformerUnit[QuantizeSettings, AxisArray, AxisArray, QuantizeTransformer]
70
- ):
68
+ class QuantizerUnit(BaseTransformerUnit[QuantizeSettings, AxisArray, AxisArray, QuantizeTransformer]):
71
69
  SETTINGS = QuantizeSettings
ezmsg/sigproc/resample.py CHANGED
@@ -2,15 +2,15 @@ import asyncio
2
2
  import math
3
3
  import time
4
4
 
5
+ import ezmsg.core as ez
5
6
  import numpy as np
6
7
  import scipy.interpolate
7
- import ezmsg.core as ez
8
8
  from ezmsg.util.messages.axisarray import AxisArray, LinearAxis
9
9
  from ezmsg.util.messages.util import replace
10
10
 
11
11
  from .base import (
12
- BaseStatefulProcessor,
13
12
  BaseConsumerUnit,
13
+ BaseStatefulProcessor,
14
14
  processor_state,
15
15
  )
16
16
  from .util.axisarray_buffer import HybridAxisArrayBuffer, HybridAxisBuffer
@@ -29,7 +29,7 @@ class ResampleSettings(ez.Settings):
29
29
  fill_value: str = "extrapolate"
30
30
  """
31
31
  Value to use for out-of-bounds samples.
32
- If 'extrapolate', the transformer will extrapolate.
32
+ If 'extrapolate', the transformer will extrapolate.
33
33
  If 'last', the transformer will use the last sample.
34
34
  See scipy.interpolate.interp1d for more options.
35
35
  """
@@ -57,9 +57,9 @@ class ResampleState:
57
57
  """
58
58
  The buffer for the reference axis (usually a time axis). The interpolation function
59
59
  will be evaluated at the reference axis values.
60
- When resample_rate is None, this buffer will be filled with the axis from incoming
60
+ When resample_rate is None, this buffer will be filled with the axis from incoming
61
61
  _reference_ messages.
62
- When resample_rate is not None (i.e., prescribed float resample_rate), this buffer
62
+ When resample_rate is not None (i.e., prescribed float resample_rate), this buffer
63
63
  is filled with a synthetic axis that is generated from the incoming signal messages.
64
64
  """
65
65
 
@@ -67,7 +67,7 @@ class ResampleState:
67
67
  """
68
68
  The last value of the reference axis that was returned. This helps us to know
69
69
  what the _next_ returned value should be, and to avoid returning the same value.
70
- TODO: We can eliminate this variable if we maintain "by convention" that the
70
+ TODO: We can eliminate this variable if we maintain "by convention" that the
71
71
  reference axis always has 1 value at its start that we exclude from the resampling.
72
72
  """
73
73
 
@@ -79,9 +79,7 @@ class ResampleState:
79
79
  """
80
80
 
81
81
 
82
- class ResampleProcessor(
83
- BaseStatefulProcessor[ResampleSettings, AxisArray, AxisArray, ResampleState]
84
- ):
82
+ class ResampleProcessor(BaseStatefulProcessor[ResampleSettings, AxisArray, AxisArray, ResampleState]):
85
83
  def _hash_message(self, message: AxisArray) -> int:
86
84
  ax_idx: int = message.get_axis_idx(self.settings.axis)
87
85
  sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
@@ -135,17 +133,11 @@ class ResampleProcessor(
135
133
  ax_idx = message.get_axis_idx(self.settings.axis)
136
134
  if self.settings.resample_rate is not None and message.data.shape[ax_idx] > 0:
137
135
  in_ax = message.axes[self.settings.axis]
138
- in_t_end = (
139
- in_ax.data[-1]
140
- if hasattr(in_ax, "data")
141
- else in_ax.value(message.data.shape[ax_idx] - 1)
142
- )
136
+ in_t_end = in_ax.data[-1] if hasattr(in_ax, "data") else in_ax.value(message.data.shape[ax_idx] - 1)
143
137
  out_gain = 1 / self.settings.resample_rate
144
138
  prev_t_end = self.state.last_ref_ax_val
145
139
  n_synth = math.ceil((in_t_end - prev_t_end) * self.settings.resample_rate)
146
- synth_ref_axis = LinearAxis(
147
- unit="s", gain=out_gain, offset=prev_t_end + out_gain
148
- )
140
+ synth_ref_axis = LinearAxis(unit="s", gain=out_gain, offset=prev_t_end + out_gain)
149
141
  self.state.ref_axis_buffer.write(synth_ref_axis, n_samples=n_synth)
150
142
 
151
143
  self.state.last_write_time = time.time()
@@ -193,11 +185,7 @@ class ResampleProcessor(
193
185
  # Get source to train interpolation
194
186
  src_axarr = src.peek()
195
187
  src_axis = src_axarr.axes[self.settings.axis]
196
- x = (
197
- src_axis.data
198
- if hasattr(src_axis, "data")
199
- else src_axis.value(np.arange(src_axarr.data.shape[0]))
200
- )
188
+ x = src_axis.data if hasattr(src_axis, "data") else src_axis.value(np.arange(src_axarr.data.shape[0]))
201
189
 
202
190
  # Only resample at reference values that have not been interpolated over previously.
203
191
  b_ref = ref_xvec > self.state.last_ref_ax_val
@@ -208,11 +196,7 @@ class ResampleProcessor(
208
196
 
209
197
  if len(ref_idx) == 0:
210
198
  # Nothing to interpolate over; return empty data
211
- null_ref = (
212
- replace(ref_ax, data=ref_ax.data[:0])
213
- if hasattr(ref_ax, "data")
214
- else ref_ax
215
- )
199
+ null_ref = replace(ref_ax, data=ref_ax.data[:0]) if hasattr(ref_ax, "data") else ref_ax
216
200
  return replace(
217
201
  src_axarr,
218
202
  data=src_axarr.data[:0, ...],
@@ -222,17 +206,12 @@ class ResampleProcessor(
222
206
  xnew = ref_xvec[ref_idx]
223
207
 
224
208
  # Identify source data indices around ref tvec with some padding for better interpolation.
225
- src_start_ix = max(
226
- 0, np.where(x > xnew[0])[0][0] - 2 if np.any(x > xnew[0]) else 0
227
- )
209
+ src_start_ix = max(0, np.where(x > xnew[0])[0][0] - 2 if np.any(x > xnew[0]) else 0)
228
210
 
229
211
  x = x[src_start_ix:]
230
212
  y = src_axarr.data[src_start_ix:]
231
213
 
232
- if (
233
- isinstance(self.settings.fill_value, str)
234
- and self.settings.fill_value == "last"
235
- ):
214
+ if isinstance(self.settings.fill_value, str) and self.settings.fill_value == "last":
236
215
  fill_value = (y[0], y[-1])
237
216
  else:
238
217
  fill_value = self.settings.fill_value