ezmsg-sigproc 1.7.0__py3-none-any.whl → 2.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +22 -4
- ezmsg/sigproc/activation.py +31 -40
- ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
- ezmsg/sigproc/affinetransform.py +171 -169
- ezmsg/sigproc/aggregate.py +190 -97
- ezmsg/sigproc/bandpower.py +60 -55
- ezmsg/sigproc/base.py +143 -33
- ezmsg/sigproc/butterworthfilter.py +34 -38
- ezmsg/sigproc/butterworthzerophase.py +305 -0
- ezmsg/sigproc/cheby.py +23 -17
- ezmsg/sigproc/combfilter.py +160 -0
- ezmsg/sigproc/coordinatespaces.py +159 -0
- ezmsg/sigproc/decimate.py +15 -10
- ezmsg/sigproc/denormalize.py +78 -0
- ezmsg/sigproc/detrend.py +28 -0
- ezmsg/sigproc/diff.py +82 -0
- ezmsg/sigproc/downsample.py +72 -81
- ezmsg/sigproc/ewma.py +217 -0
- ezmsg/sigproc/ewmfilter.py +1 -1
- ezmsg/sigproc/extract_axis.py +39 -0
- ezmsg/sigproc/fbcca.py +307 -0
- ezmsg/sigproc/filter.py +254 -148
- ezmsg/sigproc/filterbank.py +226 -214
- ezmsg/sigproc/filterbankdesign.py +129 -0
- ezmsg/sigproc/fir_hilbert.py +336 -0
- ezmsg/sigproc/fir_pmc.py +209 -0
- ezmsg/sigproc/firfilter.py +117 -0
- ezmsg/sigproc/gaussiansmoothing.py +89 -0
- ezmsg/sigproc/kaiser.py +106 -0
- ezmsg/sigproc/linear.py +120 -0
- ezmsg/sigproc/math/abs.py +23 -22
- ezmsg/sigproc/math/add.py +120 -0
- ezmsg/sigproc/math/clip.py +33 -25
- ezmsg/sigproc/math/difference.py +117 -43
- ezmsg/sigproc/math/invert.py +18 -25
- ezmsg/sigproc/math/log.py +38 -33
- ezmsg/sigproc/math/scale.py +24 -25
- ezmsg/sigproc/messages.py +1 -2
- ezmsg/sigproc/quantize.py +68 -0
- ezmsg/sigproc/resample.py +278 -0
- ezmsg/sigproc/rollingscaler.py +232 -0
- ezmsg/sigproc/sampler.py +209 -254
- ezmsg/sigproc/scaler.py +93 -218
- ezmsg/sigproc/signalinjector.py +44 -43
- ezmsg/sigproc/slicer.py +74 -102
- ezmsg/sigproc/spectral.py +3 -3
- ezmsg/sigproc/spectrogram.py +70 -70
- ezmsg/sigproc/spectrum.py +187 -173
- ezmsg/sigproc/transpose.py +134 -0
- ezmsg/sigproc/util/__init__.py +0 -0
- ezmsg/sigproc/util/asio.py +25 -0
- ezmsg/sigproc/util/axisarray_buffer.py +365 -0
- ezmsg/sigproc/util/buffer.py +449 -0
- ezmsg/sigproc/util/message.py +17 -0
- ezmsg/sigproc/util/profile.py +23 -0
- ezmsg/sigproc/util/sparse.py +115 -0
- ezmsg/sigproc/util/typeresolution.py +17 -0
- ezmsg/sigproc/wavelets.py +147 -154
- ezmsg/sigproc/window.py +248 -210
- ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
- {ezmsg_sigproc-1.7.0.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -1
- ezmsg/sigproc/synth.py +0 -621
- ezmsg_sigproc-1.7.0.dist-info/METADATA +0 -58
- ezmsg_sigproc-1.7.0.dist-info/RECORD +0 -36
- /ezmsg_sigproc-1.7.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/math/difference.py
CHANGED
|
@@ -1,18 +1,47 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Take the difference between 2 signals or between a signal and a constant value.
|
|
3
|
+
|
|
4
|
+
.. note::
|
|
5
|
+
:obj:`ConstDifferenceTransformer` supports the :doc:`Array API standard </guides/explanations/array_api>`,
|
|
6
|
+
enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
|
|
7
|
+
:obj:`DifferenceProcessor` (two-input difference) currently requires NumPy arrays.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import asyncio
|
|
1
11
|
import typing
|
|
12
|
+
from dataclasses import dataclass, field
|
|
2
13
|
|
|
3
|
-
import numpy as np
|
|
4
14
|
import ezmsg.core as ez
|
|
5
|
-
from ezmsg.
|
|
15
|
+
from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
|
|
16
|
+
from ezmsg.baseproc.util.asio import run_coroutine_sync
|
|
6
17
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
18
|
from ezmsg.util.messages.util import replace
|
|
8
19
|
|
|
9
|
-
|
|
20
|
+
|
|
21
|
+
class ConstDifferenceSettings(ez.Settings):
|
|
22
|
+
value: float = 0.0
|
|
23
|
+
"""number to subtract or be subtracted from the input data"""
|
|
24
|
+
|
|
25
|
+
subtrahend: bool = True
|
|
26
|
+
"""If True (default) then value is subtracted from the input data. If False, the input data
|
|
27
|
+
is subtracted from value."""
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ConstDifferenceTransformer(BaseTransformer[ConstDifferenceSettings, AxisArray, AxisArray]):
|
|
31
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
32
|
+
return replace(
|
|
33
|
+
message,
|
|
34
|
+
data=(message.data - self.settings.value)
|
|
35
|
+
if self.settings.subtrahend
|
|
36
|
+
else (self.settings.value - message.data),
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class ConstDifference(BaseTransformerUnit[ConstDifferenceSettings, AxisArray, AxisArray, ConstDifferenceTransformer]):
|
|
41
|
+
SETTINGS = ConstDifferenceSettings
|
|
10
42
|
|
|
11
43
|
|
|
12
|
-
|
|
13
|
-
def const_difference(
|
|
14
|
-
value: float = 0.0, subtrahend: bool = True
|
|
15
|
-
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
44
|
+
def const_difference(value: float = 0.0, subtrahend: bool = True) -> ConstDifferenceTransformer:
|
|
16
45
|
"""
|
|
17
46
|
result = (in_data - value) if subtrahend else (value - in_data)
|
|
18
47
|
https://en.wikipedia.org/wiki/Template:Arithmetic_operations
|
|
@@ -22,48 +51,93 @@ def const_difference(
|
|
|
22
51
|
subtrahend: If True (default) then value is subtracted from the input data.
|
|
23
52
|
If False, the input data is subtracted from value.
|
|
24
53
|
|
|
25
|
-
Returns:
|
|
26
|
-
|
|
54
|
+
Returns: :obj:`ConstDifferenceTransformer`.
|
|
55
|
+
"""
|
|
56
|
+
return ConstDifferenceTransformer(ConstDifferenceSettings(value=value, subtrahend=subtrahend))
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# --- Two-input Difference ---
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@dataclass
|
|
63
|
+
class DifferenceState:
|
|
64
|
+
"""State for Difference processor with two input queues."""
|
|
65
|
+
|
|
66
|
+
queue_a: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
|
|
67
|
+
queue_b: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
|
|
68
|
+
|
|
27
69
|
|
|
70
|
+
class DifferenceProcessor:
|
|
71
|
+
"""Processor that subtracts two AxisArray signals (A - B).
|
|
72
|
+
|
|
73
|
+
This processor maintains separate queues for two input streams and
|
|
74
|
+
subtracts corresponding messages element-wise. It assumes both inputs
|
|
75
|
+
have compatible shapes and aligned time spans.
|
|
28
76
|
"""
|
|
29
|
-
msg_out = AxisArray(np.array([]), dims=[""])
|
|
30
|
-
while True:
|
|
31
|
-
msg_in: AxisArray = yield msg_out
|
|
32
|
-
msg_out = replace(
|
|
33
|
-
msg_in, data=(msg_in.data - value) if subtrahend else (value - msg_in.data)
|
|
34
|
-
)
|
|
35
77
|
|
|
78
|
+
def __init__(self):
|
|
79
|
+
self._state = DifferenceState()
|
|
36
80
|
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
81
|
+
@property
|
|
82
|
+
def state(self) -> DifferenceState:
|
|
83
|
+
return self._state
|
|
40
84
|
|
|
85
|
+
@state.setter
|
|
86
|
+
def state(self, state: DifferenceState | bytes | None) -> None:
|
|
87
|
+
if state is not None:
|
|
88
|
+
self._state = state
|
|
41
89
|
|
|
42
|
-
|
|
43
|
-
|
|
90
|
+
def push_a(self, msg: AxisArray) -> None:
|
|
91
|
+
"""Push a message to queue A (minuend)."""
|
|
92
|
+
self._state.queue_a.put_nowait(msg)
|
|
44
93
|
|
|
45
|
-
def
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
94
|
+
def push_b(self, msg: AxisArray) -> None:
|
|
95
|
+
"""Push a message to queue B (subtrahend)."""
|
|
96
|
+
self._state.queue_b.put_nowait(msg)
|
|
97
|
+
|
|
98
|
+
async def __acall__(self) -> AxisArray:
|
|
99
|
+
"""Await and subtract the next messages (A - B)."""
|
|
100
|
+
a = await self._state.queue_a.get()
|
|
101
|
+
b = await self._state.queue_b.get()
|
|
102
|
+
return replace(a, data=a.data - b.data)
|
|
103
|
+
|
|
104
|
+
def __call__(self) -> AxisArray:
|
|
105
|
+
"""Synchronously get and subtract the next messages."""
|
|
106
|
+
return run_coroutine_sync(self.__acall__())
|
|
107
|
+
|
|
108
|
+
# Aliases for legacy interface
|
|
109
|
+
async def __anext__(self) -> AxisArray:
|
|
110
|
+
return await self.__acall__()
|
|
111
|
+
|
|
112
|
+
def __next__(self) -> AxisArray:
|
|
113
|
+
return self.__call__()
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class Difference(ez.Unit):
|
|
117
|
+
"""Subtract two signals (A - B).
|
|
118
|
+
|
|
119
|
+
Assumes compatible/similar axes/dimensions and aligned time spans.
|
|
120
|
+
Messages are paired by arrival order (oldest from each queue).
|
|
121
|
+
|
|
122
|
+
OUTPUT = INPUT_SIGNAL_A - INPUT_SIGNAL_B
|
|
123
|
+
"""
|
|
124
|
+
|
|
125
|
+
INPUT_SIGNAL_A = ez.InputStream(AxisArray)
|
|
126
|
+
INPUT_SIGNAL_B = ez.InputStream(AxisArray)
|
|
127
|
+
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
128
|
+
|
|
129
|
+
async def initialize(self) -> None:
|
|
130
|
+
self.processor = DifferenceProcessor()
|
|
131
|
+
|
|
132
|
+
@ez.subscriber(INPUT_SIGNAL_A)
|
|
133
|
+
async def on_a(self, msg: AxisArray) -> None:
|
|
134
|
+
self.processor.push_a(msg)
|
|
49
135
|
|
|
136
|
+
@ez.subscriber(INPUT_SIGNAL_B)
|
|
137
|
+
async def on_b(self, msg: AxisArray) -> None:
|
|
138
|
+
self.processor.push_b(msg)
|
|
50
139
|
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
# class Difference(ez.Unit):
|
|
56
|
-
# SETTINGS = DifferenceSettings
|
|
57
|
-
#
|
|
58
|
-
# INPUT_SIGNAL_1 = ez.InputStream(AxisArray)
|
|
59
|
-
# INPUT_SIGNAL_2 = ez.InputStream(AxisArray)
|
|
60
|
-
# OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
61
|
-
#
|
|
62
|
-
# @ez.subscriber(INPUT_SIGNAL_2, zero_copy=True)
|
|
63
|
-
# @ez.publisher(OUTPUT_SIGNAL)
|
|
64
|
-
# async def on_input_2(self, message: AxisArray) -> typing.AsyncGenerator:
|
|
65
|
-
# # TODO: buffer_2
|
|
66
|
-
# # TODO: take buffer_1 - buffer_2 for ranges that align
|
|
67
|
-
# # TODO: Drop samples from buffer_1 and buffer_2
|
|
68
|
-
# if ret is not None:
|
|
69
|
-
# yield self.OUTPUT_SIGNAL, ret
|
|
140
|
+
@ez.publisher(OUTPUT_SIGNAL)
|
|
141
|
+
async def output(self) -> typing.AsyncGenerator:
|
|
142
|
+
while True:
|
|
143
|
+
yield self.OUTPUT_SIGNAL, await self.processor.__acall__()
|
ezmsg/sigproc/math/invert.py
CHANGED
|
@@ -1,35 +1,28 @@
|
|
|
1
|
-
|
|
1
|
+
"""
|
|
2
|
+
Compute the multiplicative inverse (1/x) of the data.
|
|
2
3
|
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
4
|
+
.. note::
|
|
5
|
+
This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
|
|
6
|
+
enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
|
|
6
10
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
11
|
from ezmsg.util.messages.util import replace
|
|
8
12
|
|
|
9
|
-
from ..base import GenAxisArray
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
@consumer
|
|
13
|
-
def invert() -> typing.Generator[AxisArray, AxisArray, None]:
|
|
14
|
-
"""
|
|
15
|
-
Take the inverse of the data.
|
|
16
|
-
|
|
17
|
-
Returns: A primed generator that, when passed an input message via `.send(msg)`, yields an :obj:`AxisArray`
|
|
18
|
-
with the data payload containing the inversion of the input :obj:`AxisArray` data.
|
|
19
13
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
msg_in: AxisArray = yield msg_out
|
|
24
|
-
msg_out = replace(msg_in, data=1 / msg_in.data)
|
|
14
|
+
class InvertTransformer(BaseTransformer[None, AxisArray, AxisArray]):
|
|
15
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
16
|
+
return replace(message, data=1 / message.data)
|
|
25
17
|
|
|
26
18
|
|
|
27
|
-
class
|
|
28
|
-
pass
|
|
19
|
+
class Invert(BaseTransformerUnit[None, AxisArray, AxisArray, InvertTransformer]): ... # SETTINGS = None
|
|
29
20
|
|
|
30
21
|
|
|
31
|
-
|
|
32
|
-
|
|
22
|
+
def invert() -> InvertTransformer:
|
|
23
|
+
"""
|
|
24
|
+
Take the inverse of the data.
|
|
33
25
|
|
|
34
|
-
|
|
35
|
-
|
|
26
|
+
Returns: :obj:`InvertTransformer`.
|
|
27
|
+
"""
|
|
28
|
+
return InvertTransformer()
|
ezmsg/sigproc/math/log.py
CHANGED
|
@@ -1,19 +1,49 @@
|
|
|
1
|
-
|
|
1
|
+
"""
|
|
2
|
+
Take the logarithm of the data.
|
|
3
|
+
|
|
4
|
+
.. note::
|
|
5
|
+
This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
|
|
6
|
+
enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
|
|
7
|
+
"""
|
|
2
8
|
|
|
3
|
-
import numpy as np
|
|
4
9
|
import ezmsg.core as ez
|
|
5
|
-
from
|
|
10
|
+
from array_api_compat import get_namespace
|
|
11
|
+
from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
|
|
6
12
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
13
|
from ezmsg.util.messages.util import replace
|
|
8
14
|
|
|
9
|
-
|
|
15
|
+
|
|
16
|
+
class LogSettings(ez.Settings):
|
|
17
|
+
base: float = 10.0
|
|
18
|
+
"""The base of the logarithm. Default is 10."""
|
|
19
|
+
|
|
20
|
+
clip_zero: bool = False
|
|
21
|
+
"""If True, clip the data to the minimum positive value of the data type before taking the log."""
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class LogTransformer(BaseTransformer[LogSettings, AxisArray, AxisArray]):
|
|
25
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
26
|
+
xp = get_namespace(message.data)
|
|
27
|
+
data = message.data
|
|
28
|
+
if self.settings.clip_zero:
|
|
29
|
+
# Check if any values are <= 0 and dtype is floating point
|
|
30
|
+
has_non_positive = bool(xp.any(data <= 0))
|
|
31
|
+
is_floating = xp.isdtype(data.dtype, "real floating")
|
|
32
|
+
if has_non_positive and is_floating:
|
|
33
|
+
# Use smallest_normal (Array API equivalent of numpy's finfo.tiny)
|
|
34
|
+
min_val = xp.finfo(data.dtype).smallest_normal
|
|
35
|
+
data = xp.clip(data, min_val, None)
|
|
36
|
+
return replace(message, data=xp.log(data) / xp.log(self.settings.base))
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class Log(BaseTransformerUnit[LogSettings, AxisArray, AxisArray, LogTransformer]):
|
|
40
|
+
SETTINGS = LogSettings
|
|
10
41
|
|
|
11
42
|
|
|
12
|
-
@consumer
|
|
13
43
|
def log(
|
|
14
44
|
base: float = 10.0,
|
|
15
45
|
clip_zero: bool = False,
|
|
16
|
-
) ->
|
|
46
|
+
) -> LogTransformer:
|
|
17
47
|
"""
|
|
18
48
|
Take the logarithm of the data. See :obj:`np.log` for more details.
|
|
19
49
|
|
|
@@ -21,32 +51,7 @@ def log(
|
|
|
21
51
|
base: The base of the logarithm. Default is 10.
|
|
22
52
|
clip_zero: If True, clip the data to the minimum positive value of the data type before taking the log.
|
|
23
53
|
|
|
24
|
-
Returns:
|
|
25
|
-
with the data payload containing the logarithm of the input :obj:`AxisArray` data.
|
|
54
|
+
Returns: :obj:`LogTransformer`.
|
|
26
55
|
|
|
27
56
|
"""
|
|
28
|
-
|
|
29
|
-
log_base = np.log(base)
|
|
30
|
-
while True:
|
|
31
|
-
msg_in: AxisArray = yield msg_out
|
|
32
|
-
if (
|
|
33
|
-
clip_zero
|
|
34
|
-
and np.any(msg_in.data <= 0)
|
|
35
|
-
and np.issubdtype(msg_in.data.dtype, np.floating)
|
|
36
|
-
):
|
|
37
|
-
msg_in.data = np.clip(
|
|
38
|
-
msg_in.data, a_min=np.finfo(msg_in.data.dtype).tiny, a_max=None
|
|
39
|
-
)
|
|
40
|
-
msg_out = replace(msg_in, data=np.log(msg_in.data) / log_base)
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
class LogSettings(ez.Settings):
|
|
44
|
-
base: float = 10.0
|
|
45
|
-
clip_zero: bool = False
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
class Log(GenAxisArray):
|
|
49
|
-
SETTINGS = LogSettings
|
|
50
|
-
|
|
51
|
-
def construct_generator(self):
|
|
52
|
-
self.STATE.gen = log(base=self.SETTINGS.base, clip_zero=self.SETTINGS.clip_zero)
|
|
57
|
+
return LogTransformer(LogSettings(base=base, clip_zero=clip_zero))
|
ezmsg/sigproc/math/scale.py
CHANGED
|
@@ -1,40 +1,39 @@
|
|
|
1
|
-
|
|
1
|
+
"""
|
|
2
|
+
Scale the data by a constant factor.
|
|
3
|
+
|
|
4
|
+
.. note::
|
|
5
|
+
This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
|
|
6
|
+
enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
|
|
7
|
+
"""
|
|
2
8
|
|
|
3
|
-
import numpy as np
|
|
4
9
|
import ezmsg.core as ez
|
|
5
|
-
from ezmsg.
|
|
10
|
+
from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
|
|
6
11
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
12
|
from ezmsg.util.messages.util import replace
|
|
8
13
|
|
|
9
|
-
from ..base import GenAxisArray
|
|
10
14
|
|
|
15
|
+
class ScaleSettings(ez.Settings):
|
|
16
|
+
scale: float = 1.0
|
|
17
|
+
"""Factor by which to scale the data magnitude."""
|
|
11
18
|
|
|
12
|
-
@consumer
|
|
13
|
-
def scale(scale: float = 1.0) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
14
|
-
"""
|
|
15
|
-
Scale the data by a constant factor.
|
|
16
19
|
|
|
17
|
-
|
|
18
|
-
|
|
20
|
+
class ScaleTransformer(BaseTransformer[ScaleSettings, AxisArray, AxisArray]):
|
|
21
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
22
|
+
return replace(message, data=self.settings.scale * message.data)
|
|
19
23
|
|
|
20
|
-
Returns: A primed generator that, when passed an input message via `.send(msg)`, yields an :obj:`AxisArray`
|
|
21
|
-
with the data payload containing the input :obj:`AxisArray` data scaled by a constant factor.
|
|
22
24
|
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
while True:
|
|
26
|
-
msg_in: AxisArray = yield msg_out
|
|
27
|
-
msg_out = replace(msg_in, data=scale * msg_in.data)
|
|
25
|
+
class Scale(BaseTransformerUnit[ScaleSettings, AxisArray, AxisArray, ScaleTransformer]):
|
|
26
|
+
SETTINGS = ScaleSettings
|
|
28
27
|
|
|
29
28
|
|
|
30
|
-
|
|
31
|
-
|
|
29
|
+
def scale(scale: float = 1.0) -> ScaleTransformer:
|
|
30
|
+
"""
|
|
31
|
+
Scale the data by a constant factor.
|
|
32
32
|
|
|
33
|
+
Args:
|
|
34
|
+
scale: Factor by which to scale the data magnitude.
|
|
33
35
|
|
|
34
|
-
|
|
35
|
-
SETTINGS = ScaleSettings
|
|
36
|
+
Returns: :obj:`ScaleTransformer`
|
|
36
37
|
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
scale=self.SETTINGS.scale,
|
|
40
|
-
)
|
|
38
|
+
"""
|
|
39
|
+
return ScaleTransformer(ScaleSettings(scale=scale))
|
ezmsg/sigproc/messages.py
CHANGED
|
@@ -1,10 +1,9 @@
|
|
|
1
|
-
import warnings
|
|
2
1
|
import time
|
|
2
|
+
import warnings
|
|
3
3
|
|
|
4
4
|
import numpy.typing as npt
|
|
5
5
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
6
6
|
|
|
7
|
-
|
|
8
7
|
# UPCOMING: TSMessage Deprecation
|
|
9
8
|
# TSMessage is deprecated because it doesn't handle multiple time axes well.
|
|
10
9
|
# AxisArray has an incompatible API but supports a superset of functionality.
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
import ezmsg.core as ez
|
|
2
|
+
import numpy as np
|
|
3
|
+
from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
|
|
4
|
+
from ezmsg.util.messages.axisarray import AxisArray, replace
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class QuantizeSettings(ez.Settings):
|
|
8
|
+
"""
|
|
9
|
+
Settings for the Quantizer.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
max_val: float
|
|
13
|
+
"""
|
|
14
|
+
Clip the data to this maximum value before quantization and map the [min_val max_val] range to the quantized range.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
min_val: float = 0.0
|
|
18
|
+
"""
|
|
19
|
+
Clip the data to this minimum value before quantization and map the [min_val max_val] range to the quantized range.
|
|
20
|
+
Default: 0
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
bits: int = 8
|
|
24
|
+
"""
|
|
25
|
+
Number of bits for quantization.
|
|
26
|
+
Note: The data type will be integer of the next power of 2 greater than or equal to this value.
|
|
27
|
+
Default: 8
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class QuantizeTransformer(BaseTransformer[QuantizeSettings, AxisArray, AxisArray]):
|
|
32
|
+
def _process(
|
|
33
|
+
self,
|
|
34
|
+
message: AxisArray,
|
|
35
|
+
) -> AxisArray:
|
|
36
|
+
expected_range = self.settings.max_val - self.settings.min_val
|
|
37
|
+
scale_factor = 2**self.settings.bits - 1
|
|
38
|
+
clip_max = self.settings.max_val
|
|
39
|
+
|
|
40
|
+
# Determine appropriate integer type based on bits
|
|
41
|
+
if self.settings.bits <= 1:
|
|
42
|
+
dtype = bool
|
|
43
|
+
elif self.settings.bits <= 8:
|
|
44
|
+
dtype = np.uint8
|
|
45
|
+
elif self.settings.bits <= 16:
|
|
46
|
+
dtype = np.uint16
|
|
47
|
+
elif self.settings.bits <= 32:
|
|
48
|
+
dtype = np.uint32
|
|
49
|
+
else:
|
|
50
|
+
dtype = np.uint64
|
|
51
|
+
if self.settings.bits == 64:
|
|
52
|
+
# The practical upper bound before converting to int is: 2**64 - 1025
|
|
53
|
+
# Anything larger will wrap around to 0.
|
|
54
|
+
#
|
|
55
|
+
clip_max *= 1 - 2e-16
|
|
56
|
+
|
|
57
|
+
data = message.data.clip(self.settings.min_val, clip_max)
|
|
58
|
+
data = (data - self.settings.min_val) / expected_range
|
|
59
|
+
|
|
60
|
+
# Scale to the quantized range [0, 2^bits - 1]
|
|
61
|
+
data = np.rint(scale_factor * data).astype(dtype)
|
|
62
|
+
|
|
63
|
+
# Create a new AxisArray with the quantized data
|
|
64
|
+
return replace(message, data=data)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class QuantizerUnit(BaseTransformerUnit[QuantizeSettings, AxisArray, AxisArray, QuantizeTransformer]):
|
|
68
|
+
SETTINGS = QuantizeSettings
|