ezmsg-sigproc 1.8.2__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/activation.py +36 -39
- ezmsg/sigproc/adaptive_lattice_notch.py +231 -0
- ezmsg/sigproc/affinetransform.py +169 -163
- ezmsg/sigproc/aggregate.py +119 -104
- ezmsg/sigproc/bandpower.py +58 -52
- ezmsg/sigproc/base.py +1242 -0
- ezmsg/sigproc/butterworthfilter.py +37 -33
- ezmsg/sigproc/cheby.py +29 -17
- ezmsg/sigproc/combfilter.py +163 -0
- ezmsg/sigproc/decimate.py +19 -10
- ezmsg/sigproc/detrend.py +29 -0
- ezmsg/sigproc/diff.py +81 -0
- ezmsg/sigproc/downsample.py +78 -84
- ezmsg/sigproc/ewma.py +197 -0
- ezmsg/sigproc/extract_axis.py +41 -0
- ezmsg/sigproc/filter.py +257 -141
- ezmsg/sigproc/filterbank.py +247 -199
- ezmsg/sigproc/math/abs.py +17 -22
- ezmsg/sigproc/math/clip.py +24 -24
- ezmsg/sigproc/math/difference.py +34 -30
- ezmsg/sigproc/math/invert.py +13 -25
- ezmsg/sigproc/math/log.py +28 -33
- ezmsg/sigproc/math/scale.py +18 -26
- ezmsg/sigproc/quantize.py +71 -0
- ezmsg/sigproc/resample.py +298 -0
- ezmsg/sigproc/sampler.py +241 -259
- ezmsg/sigproc/scaler.py +55 -218
- ezmsg/sigproc/signalinjector.py +52 -43
- ezmsg/sigproc/slicer.py +81 -89
- ezmsg/sigproc/spectrogram.py +77 -75
- ezmsg/sigproc/spectrum.py +203 -168
- ezmsg/sigproc/synth.py +546 -393
- ezmsg/sigproc/transpose.py +131 -0
- ezmsg/sigproc/util/asio.py +156 -0
- ezmsg/sigproc/util/message.py +31 -0
- ezmsg/sigproc/util/profile.py +55 -12
- ezmsg/sigproc/util/typeresolution.py +83 -0
- ezmsg/sigproc/wavelets.py +154 -153
- ezmsg/sigproc/window.py +269 -211
- {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.0.0.dist-info}/METADATA +2 -1
- ezmsg_sigproc-2.0.0.dist-info/RECORD +51 -0
- ezmsg_sigproc-1.8.2.dist-info/RECORD +0 -39
- {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.0.0.dist-info}/WHEEL +0 -0
- {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.0.0.dist-info}/licenses/LICENSE.txt +0 -0
ezmsg/sigproc/math/difference.py
CHANGED
|
@@ -1,18 +1,41 @@
|
|
|
1
|
-
import typing
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
1
|
import ezmsg.core as ez
|
|
5
|
-
from ezmsg.util.generator import consumer
|
|
6
2
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
3
|
from ezmsg.util.messages.util import replace
|
|
8
4
|
|
|
9
|
-
from ..base import
|
|
5
|
+
from ..base import BaseTransformer, BaseTransformerUnit
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ConstDifferenceSettings(ez.Settings):
    """Parameters for :obj:`ConstDifferenceTransformer`."""

    value: float = 0.0
    """The constant operand of the subtraction."""

    subtrahend: bool = True
    """
    Operand order. True (default): output = data - value.
    False: output = value - data.
    """
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ConstDifferenceTransformer(
    BaseTransformer[ConstDifferenceSettings, AxisArray, AxisArray]
):
    """Subtract a constant from the data, or the data from a constant."""

    def _process(self, message: AxisArray) -> AxisArray:
        const = self.settings.value
        if self.settings.subtrahend:
            # data - value
            diff = message.data - const
        else:
            # value - data
            diff = const - message.data
        return replace(message, data=diff)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ConstDifference(
    BaseTransformerUnit[
        ConstDifferenceSettings, AxisArray, AxisArray, ConstDifferenceTransformer
    ]
):
    """Unit wrapper around :obj:`ConstDifferenceTransformer`."""

    SETTINGS = ConstDifferenceSettings
|
|
10
34
|
|
|
11
35
|
|
|
12
|
-
@consumer
|
|
13
36
|
def const_difference(
|
|
14
37
|
value: float = 0.0, subtrahend: bool = True
|
|
15
|
-
) ->
|
|
38
|
+
) -> ConstDifferenceTransformer:
|
|
16
39
|
"""
|
|
17
40
|
result = (in_data - value) if subtrahend else (value - in_data)
|
|
18
41
|
https://en.wikipedia.org/wiki/Template:Arithmetic_operations
|
|
@@ -22,30 +45,11 @@ def const_difference(
|
|
|
22
45
|
subtrahend: If True (default) then value is subtracted from the input data.
|
|
23
46
|
If False, the input data is subtracted from value.
|
|
24
47
|
|
|
25
|
-
Returns:
|
|
26
|
-
with the data payload containing the difference between the input :obj:`AxisArray` data and the value.
|
|
27
|
-
|
|
48
|
+
Returns: :obj:`ConstDifferenceTransformer`.
|
|
28
49
|
"""
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
msg_out = replace(
|
|
33
|
-
msg_in, data=(msg_in.data - value) if subtrahend else (value - msg_in.data)
|
|
34
|
-
)
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
class ConstDifferenceSettings(ez.Settings):
|
|
38
|
-
value: float = 0.0
|
|
39
|
-
subtrahend: bool = True
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
class ConstDifference(GenAxisArray):
|
|
43
|
-
SETTINGS = ConstDifferenceSettings
|
|
44
|
-
|
|
45
|
-
def construct_generator(self):
|
|
46
|
-
self.STATE.gen = const_difference(
|
|
47
|
-
value=self.SETTINGS.value, subtrahend=self.SETTINGS.subtrahend
|
|
48
|
-
)
|
|
50
|
+
return ConstDifferenceTransformer(
|
|
51
|
+
ConstDifferenceSettings(value=value, subtrahend=subtrahend)
|
|
52
|
+
)
|
|
49
53
|
|
|
50
54
|
|
|
51
55
|
# class DifferenceSettings(ez.Settings):
|
ezmsg/sigproc/math/invert.py
CHANGED
|
@@ -1,35 +1,23 @@
|
|
|
1
|
-
import typing
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
|
-
import ezmsg.core as ez
|
|
5
|
-
from ezmsg.util.generator import consumer
|
|
6
1
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
2
|
from ezmsg.util.messages.util import replace
|
|
8
3
|
|
|
9
|
-
from ..base import
|
|
4
|
+
from ..base import BaseTransformer, BaseTransformerUnit
|
|
10
5
|
|
|
11
6
|
|
|
12
|
-
|
|
13
|
-
def
|
|
14
|
-
|
|
15
|
-
Take the inverse of the data.
|
|
7
|
+
class InvertTransformer(BaseTransformer[None, AxisArray, AxisArray]):
    """Element-wise reciprocal (``1 / x``) of the message data; takes no settings."""

    def _process(self, message: AxisArray) -> AxisArray:
        reciprocal = 1 / message.data
        return replace(message, data=reciprocal)
|
|
16
10
|
|
|
17
|
-
Returns: A primed generator that, when passed an input message via `.send(msg)`, yields an :obj:`AxisArray`
|
|
18
|
-
with the data payload containing the inversion of the input :obj:`AxisArray` data.
|
|
19
|
-
|
|
20
|
-
"""
|
|
21
|
-
msg_out = AxisArray(np.array([]), dims=[""])
|
|
22
|
-
while True:
|
|
23
|
-
msg_in: AxisArray = yield msg_out
|
|
24
|
-
msg_out = replace(msg_in, data=1 / msg_in.data)
|
|
25
11
|
|
|
12
|
+
class Invert(
    BaseTransformerUnit[None, AxisArray, AxisArray, InvertTransformer]
):
    """Unit wrapper around :obj:`InvertTransformer`. It has no settings (SETTINGS is not set)."""
|
|
26
15
|
|
|
27
|
-
class InvertSettings(ez.Settings):
|
|
28
|
-
pass
|
|
29
16
|
|
|
17
|
+
def invert() -> InvertTransformer:
    """
    Build a transformer that takes the inverse (``1 / data``) of each message's data.

    Returns: :obj:`InvertTransformer`.
    """
    transformer = InvertTransformer()
    return transformer
|
ezmsg/sigproc/math/log.py
CHANGED
|
@@ -1,19 +1,39 @@
|
|
|
1
|
-
import typing
|
|
2
|
-
|
|
3
1
|
import numpy as np
|
|
4
2
|
import ezmsg.core as ez
|
|
5
|
-
from ezmsg.util.generator import consumer
|
|
6
3
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
4
|
from ezmsg.util.messages.util import replace
|
|
8
5
|
|
|
9
|
-
from ..base import
|
|
6
|
+
from ..base import BaseTransformer, BaseTransformerUnit
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class LogSettings(ez.Settings):
    """Parameters for :obj:`LogTransformer`."""

    base: float = 10.0
    """Base of the logarithm (default 10)."""

    clip_zero: bool = False
    """
    If True and the data are floating point, raise values below the smallest positive
    normal value of the dtype up to that value before taking the log.
    """
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class LogTransformer(BaseTransformer[LogSettings, AxisArray, AxisArray]):
    """Element-wise logarithm in an arbitrary base: ``np.log(data) / np.log(base)``."""

    def _process(self, message: AxisArray) -> AxisArray:
        values = message.data
        needs_clip = (
            self.settings.clip_zero
            and np.any(values <= 0)
            and np.issubdtype(values.dtype, np.floating)
        )
        if needs_clip:
            # Lift everything below finfo.tiny (the smallest positive normal) up to it,
            # so non-positive samples give a large negative log instead of -inf/nan.
            values = np.clip(values, a_min=np.finfo(values.dtype).tiny, a_max=None)
        log_base = np.log(self.settings.base)
        return replace(message, data=np.log(values) / log_base)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class Log(BaseTransformerUnit[LogSettings, AxisArray, AxisArray, LogTransformer]):
    """Unit wrapper around :obj:`LogTransformer`."""

    SETTINGS = LogSettings
|
|
10
31
|
|
|
11
32
|
|
|
12
|
-
@consumer
|
|
13
33
|
def log(
|
|
14
34
|
base: float = 10.0,
|
|
15
35
|
clip_zero: bool = False,
|
|
16
|
-
) ->
|
|
36
|
+
) -> LogTransformer:
|
|
17
37
|
"""
|
|
18
38
|
Take the logarithm of the data. See :obj:`np.log` for more details.
|
|
19
39
|
|
|
@@ -21,32 +41,7 @@ def log(
|
|
|
21
41
|
base: The base of the logarithm. Default is 10.
|
|
22
42
|
clip_zero: If True, clip the data to the minimum positive value of the data type before taking the log.
|
|
23
43
|
|
|
24
|
-
Returns:
|
|
25
|
-
with the data payload containing the logarithm of the input :obj:`AxisArray` data.
|
|
44
|
+
Returns: :obj:`LogTransformer`.
|
|
26
45
|
|
|
27
46
|
"""
|
|
28
|
-
|
|
29
|
-
log_base = np.log(base)
|
|
30
|
-
while True:
|
|
31
|
-
msg_in: AxisArray = yield msg_out
|
|
32
|
-
if (
|
|
33
|
-
clip_zero
|
|
34
|
-
and np.any(msg_in.data <= 0)
|
|
35
|
-
and np.issubdtype(msg_in.data.dtype, np.floating)
|
|
36
|
-
):
|
|
37
|
-
msg_in.data = np.clip(
|
|
38
|
-
msg_in.data, a_min=np.finfo(msg_in.data.dtype).tiny, a_max=None
|
|
39
|
-
)
|
|
40
|
-
msg_out = replace(msg_in, data=np.log(msg_in.data) / log_base)
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
class LogSettings(ez.Settings):
|
|
44
|
-
base: float = 10.0
|
|
45
|
-
clip_zero: bool = False
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
class Log(GenAxisArray):
|
|
49
|
-
SETTINGS = LogSettings
|
|
50
|
-
|
|
51
|
-
def construct_generator(self):
|
|
52
|
-
self.STATE.gen = log(base=self.SETTINGS.base, clip_zero=self.SETTINGS.clip_zero)
|
|
47
|
+
return LogTransformer(LogSettings(base=base, clip_zero=clip_zero))
|
ezmsg/sigproc/math/scale.py
CHANGED
|
@@ -1,40 +1,32 @@
|
|
|
1
|
-
import typing
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
1
|
import ezmsg.core as ez
|
|
5
|
-
from ezmsg.util.generator import consumer
|
|
6
2
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
3
|
from ezmsg.util.messages.util import replace
|
|
8
4
|
|
|
9
|
-
from ..base import
|
|
5
|
+
from ..base import BaseTransformer, BaseTransformerUnit
|
|
10
6
|
|
|
11
7
|
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
"""
|
|
15
|
-
Scale the data by a constant factor.
|
|
8
|
+
class ScaleSettings(ez.Settings):
    """Parameters for :obj:`ScaleTransformer`."""

    scale: float = 1.0
    """Multiplicative factor applied to the data (default 1.0, i.e. identity)."""
|
|
16
11
|
|
|
17
|
-
Args:
|
|
18
|
-
scale: Factor by which to scale the data magnitude.
|
|
19
12
|
|
|
20
|
-
|
|
21
|
-
|
|
13
|
+
class ScaleTransformer(BaseTransformer[ScaleSettings, AxisArray, AxisArray]):
    """Multiply the message data by the constant ``settings.scale``."""

    def _process(self, message: AxisArray) -> AxisArray:
        factor = self.settings.scale
        scaled = factor * message.data
        return replace(message, data=scaled)
|
|
22
16
|
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
msg_in: AxisArray = yield msg_out
|
|
27
|
-
msg_out = replace(msg_in, data=scale * msg_in.data)
|
|
17
|
+
|
|
18
|
+
class Scale(BaseTransformerUnit[ScaleSettings, AxisArray, AxisArray, ScaleTransformer]):
    """Unit wrapper around :obj:`ScaleTransformer`."""

    SETTINGS = ScaleSettings
|
|
28
20
|
|
|
29
21
|
|
|
30
|
-
|
|
31
|
-
|
|
22
|
+
def scale(scale: float = 1.0) -> ScaleTransformer:
    """
    Build a transformer that multiplies the data by a constant factor.

    Args:
        scale: Factor by which to scale the data magnitude.

    Returns: :obj:`ScaleTransformer`
    """
    settings = ScaleSettings(scale=scale)
    return ScaleTransformer(settings)
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import ezmsg.core as ez
|
|
3
|
+
from ezmsg.util.messages.axisarray import AxisArray, replace
|
|
4
|
+
|
|
5
|
+
from .base import BaseTransformer, BaseTransformerUnit
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class QuantizeSettings(ez.Settings):
    """
    Settings for the Quantizer: the input range ``[min_val, max_val]`` is mapped
    linearly onto the unsigned integers ``[0, 2**bits - 1]``.
    """

    max_val: float
    """
    Upper bound of the input range. Data above it are clipped to it, and it maps to
    the largest quantized value.
    """

    min_val: float = 0.0
    """
    Lower bound of the input range. Data below it are clipped to it, and it maps to 0.
    Default: 0
    """

    bits: int = 8
    """
    Number of bits for quantization.
    Output dtype: bool when bits <= 1, otherwise the smallest of uint8 / uint16 /
    uint32 / uint64 that has at least `bits` bits.
    Default: 8
    """
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class QuantizeTransformer(BaseTransformer[QuantizeSettings, AxisArray, AxisArray]):
    """
    Linearly map ``[min_val, max_val]`` onto the unsigned integers ``[0, 2**bits - 1]``.

    Samples outside ``[min_val, max_val]`` are clipped first, so they saturate at the
    ends of the quantized range.
    """

    def _process(
        self,
        message: AxisArray,
    ) -> AxisArray:
        """
        Quantize ``message.data``.

        Args:
            message: Input message. Its data are assumed to be real-valued
                (NOTE(review): integer/NaN inputs are not specially handled — confirm callers).

        Returns:
            A copy of ``message`` whose data hold the quantized values. The dtype is
            ``bool`` for ``bits == 1``, otherwise the smallest of uint8/uint16/uint32/uint64
            that has at least ``bits`` bits.

        Raises:
            ValueError: if ``max_val <= min_val`` or ``bits`` is outside ``[1, 64]``.
        """
        bits = self.settings.bits
        min_val = self.settings.min_val
        max_val = self.settings.max_val

        expected_range = max_val - min_val
        if not expected_range > 0:
            # A zero, negative or NaN range would divide by zero and cast NaN/inf to integers.
            raise ValueError(
                "QuantizeSettings requires max_val > min_val; "
                f"got min_val={min_val}, max_val={max_val}."
            )
        if not 1 <= bits <= 64:
            # uint64 is the widest output type; more bits would overflow on the final cast.
            raise ValueError(f"QuantizeSettings.bits must be in [1, 64]; got {bits}.")
        scale_factor = 2**bits - 1

        # Determine appropriate integer type based on bits
        if bits <= 1:
            dtype = bool
        elif bits <= 8:
            dtype = np.uint8
        elif bits <= 16:
            dtype = np.uint16
        elif bits <= 32:
            dtype = np.uint32
        else:
            dtype = np.uint64

        # Clip to the expected range and normalize to [0, 1].
        data = message.data.clip(min_val, max_val)
        data = (data - min_val) / expected_range

        # Scale to the quantized range [0, 2^bits - 1]
        data = np.rint(scale_factor * data)
        if bits == 64:
            # 2**64 - 1 is not representable in float64, so a full-scale sample rounds up
            # to 2**64, which does not fit in uint64. Cap the value at the largest float64
            # below 2**64 (2**64 - 2048) so the cast cannot wrap. This works regardless of
            # the sign of max_val or the value of min_val.
            data = np.minimum(data.astype(np.float64), np.nextafter(2.0**64, 0.0))

        # Create a new AxisArray with the quantized data
        return replace(message, data=data.astype(dtype))
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class QuantizerUnit(
    BaseTransformerUnit[QuantizeSettings, AxisArray, AxisArray, QuantizeTransformer]
):
    """Unit wrapper around :obj:`QuantizeTransformer`."""

    SETTINGS = QuantizeSettings
|
|
@@ -0,0 +1,298 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import dataclasses
|
|
3
|
+
import time
|
|
4
|
+
import typing
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import numpy.typing as npt
|
|
8
|
+
import scipy.interpolate
|
|
9
|
+
import ezmsg.core as ez
|
|
10
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
11
|
+
from ezmsg.util.messages.util import replace
|
|
12
|
+
|
|
13
|
+
from .base import (
|
|
14
|
+
BaseStatefulProcessor,
|
|
15
|
+
BaseConsumerUnit,
|
|
16
|
+
processor_state,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ResampleSettings(ez.Settings):
    """Settings for :obj:`ResampleProcessor` / :obj:`ResampleUnit`."""

    axis: str = "time"
    """Name of the axis along which the signal is resampled."""

    resample_rate: float | None = None
    """Target resample rate in Hz. If None, the output timestamps are taken from the reference signal."""

    max_chunk_delay: float = 0.0
    """
    Maximum delay between outputs in seconds; only used when ``resample_rate`` is set.
    If the buffer has not been updated (by new input or by a previous output) for longer
    than this, output is extrapolated up to ``max_chunk_delay`` seconds past the last
    buffered sample. 0 (default) disables this.
    """

    # Annotation widened from ``str``: non-string values are forwarded to interp1d as-is.
    fill_value: str | float | tuple[float, float] = "extrapolate"
    """
    Value to use for out-of-bounds samples.
    If 'extrapolate', the transformer will linearly extrapolate.
    If 'last', the edge samples of the buffered data are held (first sample below the
    range, last sample above it).
    Any other value (e.g. a float, or a ``(below, above)`` tuple) is passed to
    scipy.interpolate.interp1d as ``fill_value``; see its documentation for more options.
    """
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclasses.dataclass
class ResampleBuffer:
    """Input samples waiting to be resampled, plus the template used to build outputs."""

    data: npt.NDArray
    """Buffered samples, with the resampled axis moved to position 0."""

    tvec: npt.NDArray
    """Timestamps of the rows of `data`."""

    template: AxisArray
    """Empty message whose metadata/axes are copied into each output."""

    last_update: float
    """`time.time()` at the most recent change to the buffer."""
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@processor_state
class ResampleState:
    """Mutable state of :obj:`ResampleProcessor`."""

    signal_buffer: ResampleBuffer | None = None
    """Buffered input signal; None until `_reset_state` creates it."""

    ref_axis: tuple[typing.Union[AxisArray.TimeAxis, AxisArray.CoordinateAxis], int] = (
        AxisArray.TimeAxis(fs=1.0),
        0,
    )
    """
    (reference axis, number of reference samples not yet consumed).
    The default axis is only a placeholder until the first reference message arrives.
    """

    last_t_out: float | None = None
    """Timestamp of the most recently emitted sample, or of the sample just before the first output."""
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class ResampleProcessor(
    BaseStatefulProcessor[ResampleSettings, AxisArray, AxisArray, ResampleState]
):
    """
    Resample a signal by linear interpolation, at a fixed rate or at reference timestamps.

    :meth:`_process` appends incoming signal chunks to a buffer. Resampled output is
    pulled with ``next(processor)`` (:meth:`__next__`); :meth:`send` does both steps.
    When ``settings.resample_rate`` is None, output timestamps come from the messages
    given to :meth:`push_reference`.
    """

    def _hash_message(self, message: AxisArray) -> int:
        """
        Hash the message key, the input rate (if the axis has a gain) and the non-resampled shape.

        Presumably the base class resets state when this hash changes — TODO confirm.
        """
        ax_idx: int = message.get_axis_idx(self.settings.axis)
        sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
        ax = message.axes[self.settings.axis]
        in_fs = (1 / ax.gain) if hasattr(ax, "gain") else None
        return hash((message.key, in_fs) + sample_shape)

    def _reset_state(self, message: AxisArray) -> None:
        """
        Reset the internal state based on the incoming message.
        If resample_rate is None, the output is driven by the reference signal.
        The input will still determine the template (except the primary axis) and the buffer.
        """
        ax_idx: int = message.get_axis_idx(self.settings.axis)
        ax = message.axes[self.settings.axis]
        in_dat = message.data
        in_tvec = (
            ax.data
            if hasattr(ax, "data")
            else ax.value(np.arange(in_dat.shape[ax_idx]))
        )
        if ax_idx != 0:
            in_dat = np.moveaxis(in_dat, ax_idx, 0)

        if self.settings.resample_rate is None:
            # Output is driven by input.
            # We cannot include the resampled axis until we see reference data.
            out_axes = {
                k: v for k, v in message.axes.items() if k != self.settings.axis
            }
            # last_t_out also driven by reference data.
        else:
            out_axes = {
                **message.axes,
                self.settings.axis: AxisArray.TimeAxis(
                    fs=self.settings.resample_rate, offset=in_tvec[0]
                ),
            }
            self.state.last_t_out = in_tvec[0] - 1 / self.settings.resample_rate
        template = replace(message, data=in_dat[:0], axes=out_axes)
        self.state.signal_buffer = ResampleBuffer(
            data=in_dat[:0],
            tvec=in_tvec[:0],
            template=template,
            last_update=time.time(),
        )

    def _ensure_template_axis(self) -> None:
        """
        In reference-driven mode, add the reference axis to the output template once a
        reference message has been received.

        Installing it earlier (before any reference) would lock the template onto the
        placeholder ``ResampleState.ref_axis`` default, e.g. ``TimeAxis(fs=1.0)`` with
        offset None. Outputs would then carry the wrong gain whenever signal data arrives
        before the first reference message.
        """
        buf = self.state.signal_buffer
        if (
            self.settings.resample_rate is not None
            or buf is None
            or self.settings.axis in buf.template.axes
            # In reference-driven mode, last_t_out is first set by push_reference.
            or self.state.last_t_out is None
        ):
            return
        out_ax = self.state.ref_axis[0]
        if hasattr(out_ax, "gain"):
            out_ax = replace(out_ax, offset=self.state.last_t_out)
        buf.template = replace(
            buf.template,
            axes={**buf.template.axes, self.settings.axis: out_ax},
        )

    def _process(self, message: AxisArray) -> None:
        """Append the incoming signal chunk to the buffer. Output is produced by ``__next__``."""
        buf = self.state.signal_buffer

        # If outputs are driven by a reference signal that has already been seen,
        # give the template its output axis.
        self._ensure_template_axis()

        # Append the new data to the buffer
        ax_idx: int = message.get_axis_idx(self.settings.axis)
        in_dat: npt.NDArray = message.data
        if ax_idx != 0:
            in_dat = np.moveaxis(in_dat, ax_idx, 0)
        ax = message.axes[self.settings.axis]
        in_tvec = (
            ax.data if hasattr(ax, "data") else ax.value(np.arange(in_dat.shape[0]))
        )
        buf.data = np.concatenate((buf.data, in_dat), axis=0)
        buf.tvec = np.hstack((buf.tvec, in_tvec))
        buf.last_update = time.time()

    def push_reference(self, message: AxisArray) -> None:
        """
        Accumulate reference timestamps.

        The first reference (or the first one after all prior samples were consumed)
        replaces the stored axis. Later ones extend it: a gain-based axis only increases
        its sample count, while a coordinate axis has its data appended.
        """
        ax = message.axes[self.settings.axis]
        ax_idx = message.get_axis_idx(self.settings.axis)
        n_new = message.data.shape[ax_idx]
        if self.state.ref_axis[1] == 0:
            self.state.ref_axis = (ax, n_new)
        else:
            if hasattr(ax, "gain"):
                # Rate and offset don't need to change; we simply increment our sample counter.
                self.state.ref_axis = (
                    self.state.ref_axis[0],
                    self.state.ref_axis[1] + n_new,
                )
            else:
                # Extend our time axis with the new data.
                new_tvec = np.concatenate(
                    (self.state.ref_axis[0].data, ax.data), axis=0
                )
                self.state.ref_axis = (
                    replace(self.state.ref_axis[0], data=new_tvec),
                    self.state.ref_axis[1] + n_new,
                )

        if self.settings.resample_rate is None and self.state.last_t_out is None:
            # This reference axis will become THE output axis.
            # If last_t_out has not previously been set, we set it to the sample before this reference data.
            if hasattr(self.state.ref_axis[0], "gain"):
                ref_tvec = self.state.ref_axis[0].value(np.arange(2))
            else:
                ref_tvec = self.state.ref_axis[0].data[:2]
            self.state.last_t_out = 2 * ref_tvec[0] - ref_tvec[1]

    def __next__(self) -> AxisArray:
        """
        Produce the next resampled chunk from the buffered signal.

        Returns an empty ``AxisArray`` with key "null" if no signal has been seen. Returns
        the empty template while there is too little signal/reference data to produce at
        least 2 output samples.
        """
        buf = self.state.signal_buffer

        if buf is None:
            return AxisArray(data=np.array([]), dims=[""], axes={}, key="null")

        # buffer is empty or ref-driven && empty-reference; return the empty template
        if (buf.tvec.size == 0) or (
            self.settings.resample_rate is None and self.state.ref_axis[1] < 3
        ):
            # Note: empty template's primary axis' offset might be meaningless.
            return buf.template

        # Reference data may have arrived after the most recent signal chunk, so the
        # template may not have its output axis yet.
        self._ensure_template_axis()

        # Identify the output timestamps at which we will resample the buffer
        b_project = False
        if self.settings.resample_rate is None:
            # Rely on reference signal to determine output timestamps
            if hasattr(self.state.ref_axis[0], "data"):
                ref_tvec = self.state.ref_axis[0].data
            else:
                n_avail = self.state.ref_axis[1]
                ref_tvec = self.state.ref_axis[0].value(np.arange(n_avail))
        else:
            # Get output timestamps from resample_rate and what we've collected so far
            t_begin = self.state.last_t_out + 1 / self.settings.resample_rate
            t_end = buf.tvec[-1]
            if self.settings.max_chunk_delay > 0 and time.time() > (
                buf.last_update + self.settings.max_chunk_delay
            ):
                # We've waited too long between pushes. We will have to extrapolate.
                b_project = True
                t_end += self.settings.max_chunk_delay
            ref_tvec = np.arange(t_begin, t_end, 1 / self.settings.resample_rate)

        # Which samples can we resample?
        b_ref = ref_tvec > self.state.last_t_out
        if not b_project:
            b_ref = np.logical_and(b_ref, ref_tvec <= buf.tvec[-1])
        ref_idx = np.where(b_ref)[0]

        if len(ref_idx) < 2:
            # Not enough data to resample; return the empty template.
            return buf.template

        tnew = ref_tvec[ref_idx]
        # Slice buf to minimal range around tnew with some padding for better interpolation.
        buf_start_ix = max(0, np.searchsorted(buf.tvec, tnew[0]) - 2)
        buf_stop_ix = np.searchsorted(buf.tvec, tnew[-1], side="right") + 2
        x = buf.tvec[buf_start_ix:buf_stop_ix]
        y = buf.data[buf_start_ix:buf_stop_ix]
        if (
            isinstance(self.settings.fill_value, str)
            and self.settings.fill_value == "last"
        ):
            fill_value = (y[0], y[-1])
        else:
            fill_value = self.settings.fill_value
        f = scipy.interpolate.interp1d(
            x,
            y,
            kind="linear",
            axis=0,
            copy=False,
            bounds_error=False,
            fill_value=fill_value,
            assume_sorted=True,
        )
        resampled_data = f(tnew)
        if hasattr(buf.template.axes[self.settings.axis], "data"):
            repl_axis = replace(buf.template.axes[self.settings.axis], data=tnew)
        else:
            repl_axis = replace(buf.template.axes[self.settings.axis], offset=tnew[0])
        result = replace(
            buf.template,
            data=resampled_data,
            axes={
                **buf.template.axes,
                self.settings.axis: repl_axis,
            },
        )

        # Update state to move past samples that are no longer needed
        self.state.last_t_out = tnew[-1]
        buf.data = buf.data[max(0, buf_stop_ix - 3) :]
        buf.tvec = buf.tvec[max(0, buf_stop_ix - 3) :]
        buf.last_update = time.time()

        if self.settings.resample_rate is None:
            # Drop every reference sample up to and including the last one used. That is
            # ref_idx[-1] + 1 samples, not len(ref_idx): any leading samples at or before
            # last_t_out are discarded by the slice/offset update too, so the counter
            # must shrink by the same amount to stay consistent with the axis.
            n_used = int(ref_idx[-1]) + 1
            if hasattr(self.state.ref_axis[0], "data"):
                new_ref_ax = replace(
                    self.state.ref_axis[0],
                    data=self.state.ref_axis[0].data[n_used:],
                )
            else:
                next_offset = self.state.ref_axis[0].value(n_used)
                new_ref_ax = replace(self.state.ref_axis[0], offset=next_offset)
            self.state.ref_axis = (new_ref_ax, self.state.ref_axis[1] - n_used)

        return result

    def send(self, message: AxisArray) -> AxisArray:
        """Buffer ``message`` and immediately return whatever output is ready."""
        self(message)
        return next(self)
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
class ResampleUnit(BaseConsumerUnit[ResampleSettings, AxisArray, ResampleProcessor]):
    """
    Unit that feeds reference messages to a :obj:`ResampleProcessor` and publishes
    its non-empty resampled outputs on ``OUTPUT_SIGNAL``.
    """

    SETTINGS = ResampleSettings
    INPUT_REFERENCE = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    @ez.subscriber(INPUT_REFERENCE, zero_copy=True)
    async def on_reference(self, message: AxisArray):
        """Forward a reference message to the processor."""
        self.processor.push_reference(message)

    @ez.publisher(OUTPUT_SIGNAL)
    async def gen_resampled(self):
        """Poll the processor forever; publish non-empty chunks, otherwise sleep 1 ms."""
        while True:
            result: AxisArray = next(self.processor)
            if np.prod(result.data.shape) == 0:
                # Nothing ready yet: yield control to the event loop briefly.
                await asyncio.sleep(0.001)
                continue
            yield self.OUTPUT_SIGNAL, result
|