ezmsg-sigproc 1.8.2__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/activation.py +36 -39
- ezmsg/sigproc/adaptive_lattice_notch.py +231 -0
- ezmsg/sigproc/affinetransform.py +169 -163
- ezmsg/sigproc/aggregate.py +119 -104
- ezmsg/sigproc/bandpower.py +58 -52
- ezmsg/sigproc/base.py +1242 -0
- ezmsg/sigproc/butterworthfilter.py +37 -33
- ezmsg/sigproc/cheby.py +29 -17
- ezmsg/sigproc/combfilter.py +163 -0
- ezmsg/sigproc/decimate.py +19 -10
- ezmsg/sigproc/detrend.py +29 -0
- ezmsg/sigproc/diff.py +81 -0
- ezmsg/sigproc/downsample.py +78 -84
- ezmsg/sigproc/ewma.py +197 -0
- ezmsg/sigproc/extract_axis.py +41 -0
- ezmsg/sigproc/filter.py +257 -141
- ezmsg/sigproc/filterbank.py +247 -199
- ezmsg/sigproc/math/abs.py +17 -22
- ezmsg/sigproc/math/clip.py +24 -24
- ezmsg/sigproc/math/difference.py +34 -30
- ezmsg/sigproc/math/invert.py +13 -25
- ezmsg/sigproc/math/log.py +28 -33
- ezmsg/sigproc/math/scale.py +18 -26
- ezmsg/sigproc/quantize.py +71 -0
- ezmsg/sigproc/resample.py +298 -0
- ezmsg/sigproc/sampler.py +241 -259
- ezmsg/sigproc/scaler.py +55 -218
- ezmsg/sigproc/signalinjector.py +52 -43
- ezmsg/sigproc/slicer.py +81 -89
- ezmsg/sigproc/spectrogram.py +77 -75
- ezmsg/sigproc/spectrum.py +203 -168
- ezmsg/sigproc/synth.py +546 -393
- ezmsg/sigproc/transpose.py +131 -0
- ezmsg/sigproc/util/asio.py +156 -0
- ezmsg/sigproc/util/message.py +31 -0
- ezmsg/sigproc/util/profile.py +55 -12
- ezmsg/sigproc/util/typeresolution.py +83 -0
- ezmsg/sigproc/wavelets.py +154 -153
- ezmsg/sigproc/window.py +269 -211
- {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.0.0.dist-info}/METADATA +2 -1
- ezmsg_sigproc-2.0.0.dist-info/RECORD +51 -0
- ezmsg_sigproc-1.8.2.dist-info/RECORD +0 -39
- {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.0.0.dist-info}/WHEEL +0 -0
- {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.0.0.dist-info}/licenses/LICENSE.txt +0 -0
ezmsg/sigproc/downsample.py
CHANGED
|
@@ -1,85 +1,86 @@
|
|
|
1
|
-
import typing
|
|
2
|
-
|
|
3
1
|
import numpy as np
|
|
4
2
|
from ezmsg.util.messages.axisarray import (
|
|
5
3
|
AxisArray,
|
|
6
4
|
slice_along_axis,
|
|
7
5
|
replace,
|
|
8
6
|
)
|
|
9
|
-
from ezmsg.util.generator import consumer
|
|
10
7
|
import ezmsg.core as ez
|
|
11
8
|
|
|
12
|
-
from .base import
|
|
9
|
+
from .base import (
|
|
10
|
+
BaseStatefulTransformer,
|
|
11
|
+
BaseTransformerUnit,
|
|
12
|
+
processor_state,
|
|
13
|
+
)
|
|
13
14
|
|
|
14
15
|
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
16
|
+
class DownsampleSettings(ez.Settings):
    """
    Configuration for the :obj:`Downsample` unit and :obj:`DownsampleTransformer`.

    If neither ``factor`` nor ``target_rate`` is given, the factor is 1 (no downsampling).
    """

    #: Name of the axis along which samples are dropped.
    axis: str = "time"

    #: Desired output rate. The factor actually used is the largest integer that keeps
    #: the output rate at or above this value (i.e. the nearest achievable rate that is
    #: the same or higher than the target).
    target_rate: float | None = None

    #: Explicit integer downsampling factor. When set, ``target_rate`` is ignored.
    factor: int | None = None
|
|
38
30
|
|
|
39
|
-
"""
|
|
40
|
-
msg_out = AxisArray(np.array([]), dims=[""])
|
|
41
31
|
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
32
|
+
@processor_state
class DownsampleState:
    #: Integer downsampling factor: one sample out of every ``q`` is kept.
    #: Recomputed whenever the transformer resets its state.
    q: int = 0

    #: Position of the next message's first sample within the repeating
    #: 0 .. q-1 counter; samples whose counter value is 0 are the ones kept.
    s_idx: int = 0
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class DownsampleTransformer(
|
|
42
|
+
BaseStatefulTransformer[DownsampleSettings, AxisArray, AxisArray, DownsampleState]
|
|
43
|
+
):
|
|
44
|
+
"""
|
|
45
|
+
Downsampled data simply comprise every `factor`th sample.
|
|
46
|
+
This should only be used following appropriate lowpass filtering.
|
|
47
|
+
If your pipeline does not already have lowpass filtering then consider
|
|
48
|
+
using the :obj:`Decimate` collection instead.
|
|
49
|
+
"""
|
|
47
50
|
|
|
48
|
-
|
|
49
|
-
|
|
51
|
+
def _hash_message(self, message: AxisArray) -> int:
    """
    Fingerprint the message properties that invalidate the downsampling state.

    The gain (sample period) of the downsampling axis and the message key both
    feed into the hash, so a change in either triggers a state reset.
    """
    axis_gain = message.axes[self.settings.axis].gain
    fingerprint = (axis_gain, message.key)
    return hash(fingerprint)
|
|
50
53
|
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
axis_info = msg_in.get_axis(axis)
|
|
54
|
-
axis_idx = msg_in.get_axis_idx(axis)
|
|
54
|
+
def _reset_state(self, message: AxisArray) -> None:
    """
    Recompute the integer downsampling factor ``q`` and restart the sample counter.

    ``q`` is taken from ``settings.factor`` when given. Otherwise it is derived from
    ``settings.target_rate`` relative to the input rate (1 / gain of the downsampling
    axis). If neither is set, ``q`` is 1. A factor below 1 cannot be realized, so it
    is replaced by 1 with a warning, whichever way it was obtained.
    """
    axis_info = message.get_axis(self.settings.axis)

    if self.settings.factor is not None:
        q = self.settings.factor
        if q < 1:
            # Same policy as the target_rate branch below: a factor < 1 would make the
            # `% q` counter in _process divide by zero or yield a negative output gain.
            ez.logger.warning(
                f"Downsample factor {q} is invalid (must be >= 1). "
                "Setting factor to 1."
            )
            q = 1
    elif self.settings.target_rate is None:
        q = 1
    else:
        # Largest integer q such that input_rate / q >= target_rate.
        q = int(1 / (axis_info.gain * self.settings.target_rate))
        if q < 1:
            ez.logger.warning(
                f"Target rate {self.settings.target_rate} cannot be achieved with input rate of {1 / axis_info.gain}. "
                "Setting factor to 1."
            )
            q = 1
    self._state.q = q
    self._state.s_idx = 0
|
|
71
|
+
|
|
72
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
73
|
+
axis = self.settings.axis
|
|
74
|
+
axis_info = message.get_axis(axis)
|
|
75
|
+
axis_idx = message.get_axis_idx(axis)
|
|
76
|
+
|
|
77
|
+
n_samples = message.data.shape[axis_idx]
|
|
78
|
+
samples = (
|
|
79
|
+
np.arange(self.state.s_idx, self.state.s_idx + n_samples) % self._state.q
|
|
59
80
|
)
|
|
60
|
-
if b_reset:
|
|
61
|
-
check_input["gain"] = axis_info.gain
|
|
62
|
-
check_input["key"] = msg_in.key
|
|
63
|
-
# Reset state variables
|
|
64
|
-
s_idx = 0
|
|
65
|
-
if factor is not None:
|
|
66
|
-
q = factor
|
|
67
|
-
elif target_rate is None:
|
|
68
|
-
q = 1
|
|
69
|
-
else:
|
|
70
|
-
q = int(1 / (axis_info.gain * target_rate))
|
|
71
|
-
if q < 1:
|
|
72
|
-
ez.logger.warning(
|
|
73
|
-
f"Target rate {target_rate} cannot be achieved with input rate of {1/axis_info.gain}."
|
|
74
|
-
"Setting factor to 1."
|
|
75
|
-
)
|
|
76
|
-
q = 1
|
|
77
|
-
|
|
78
|
-
n_samples = msg_in.data.shape[axis_idx]
|
|
79
|
-
samples = np.arange(s_idx, s_idx + n_samples) % q
|
|
80
81
|
if n_samples > 0:
|
|
81
82
|
# Update state for next iteration.
|
|
82
|
-
s_idx = samples[-1] + 1
|
|
83
|
+
self._state.s_idx = samples[-1] + 1
|
|
83
84
|
|
|
84
85
|
pub_samples = np.where(samples == 0)[0]
|
|
85
86
|
if len(pub_samples) > 0:
|
|
@@ -89,38 +90,31 @@ def downsample(
|
|
|
89
90
|
n_step = 0
|
|
90
91
|
data_slice = slice(None, 0, None)
|
|
91
92
|
msg_out = replace(
|
|
92
|
-
|
|
93
|
-
data=slice_along_axis(
|
|
93
|
+
message,
|
|
94
|
+
data=slice_along_axis(message.data, data_slice, axis=axis_idx),
|
|
94
95
|
axes={
|
|
95
|
-
**
|
|
96
|
+
**message.axes,
|
|
96
97
|
axis: replace(
|
|
97
98
|
axis_info,
|
|
98
|
-
gain=axis_info.gain * q,
|
|
99
|
+
gain=axis_info.gain * self._state.q,
|
|
99
100
|
offset=axis_info.offset + axis_info.gain * n_step,
|
|
100
101
|
),
|
|
101
102
|
},
|
|
102
103
|
)
|
|
104
|
+
return msg_out
|
|
103
105
|
|
|
104
106
|
|
|
105
|
-
class
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
See :obj:`downsample` documentation for a description of the parameters.
|
|
109
|
-
"""
|
|
110
|
-
|
|
111
|
-
axis: str | None = None
|
|
112
|
-
target_rate: float | None = None
|
|
113
|
-
factor: int | None = None
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
class Downsample(GenAxisArray):
|
|
117
|
-
""":obj:`Unit` for :obj:`bandpower`."""
|
|
118
|
-
|
|
107
|
+
class Downsample(
    BaseTransformerUnit[DownsampleSettings, AxisArray, AxisArray, DownsampleTransformer]
):
    """ezmsg unit that runs a :obj:`DownsampleTransformer` on incoming :obj:`AxisArray` messages."""

    SETTINGS = DownsampleSettings
|
|
120
111
|
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
112
|
+
|
|
113
|
+
def downsample(
    axis: str = "time",
    target_rate: float | None = None,
    factor: int | None = None,
) -> DownsampleTransformer:
    """
    Build a :obj:`DownsampleTransformer` from keyword arguments.

    Each argument maps one-to-one onto the field of the same name in
    :obj:`DownsampleSettings`.
    """
    settings = DownsampleSettings(
        axis=axis,
        target_rate=target_rate,
        factor=factor,
    )
    return DownsampleTransformer(settings)
|
ezmsg/sigproc/ewma.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
from dataclasses import field
|
|
2
|
+
import functools
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import numpy.typing as npt
|
|
6
|
+
import scipy.signal as sps
|
|
7
|
+
import ezmsg.core as ez
|
|
8
|
+
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
|
|
9
|
+
from ezmsg.util.messages.util import replace
|
|
10
|
+
|
|
11
|
+
from .base import BaseStatefulTransformer, processor_state, BaseTransformerUnit
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _tau_from_alpha(alpha: float, dt: float) -> float:
|
|
15
|
+
"""
|
|
16
|
+
Inverse of _alpha_from_tau. See that function for explanation.
|
|
17
|
+
"""
|
|
18
|
+
return -dt / np.log(1 - alpha)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _alpha_from_tau(tau: float, dt: float) -> float:
|
|
22
|
+
"""
|
|
23
|
+
# https://en.wikipedia.org/wiki/Exponential_smoothing#Time_constant
|
|
24
|
+
:param tau: The amount of time for the smoothed response of a unit step function to reach
|
|
25
|
+
1 - 1/e approx-eq 63.2%.
|
|
26
|
+
:param dt: sampling period, or 1 / sampling_rate.
|
|
27
|
+
:return: alpha, the "fading factor" in exponential smoothing.
|
|
28
|
+
"""
|
|
29
|
+
return 1 - np.exp(-dt / tau)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def ewma_step(
|
|
33
|
+
sample: npt.NDArray, zi: npt.NDArray, alpha: float, beta: float | None = None
|
|
34
|
+
):
|
|
35
|
+
"""
|
|
36
|
+
Do an exponentially weighted moving average step.
|
|
37
|
+
|
|
38
|
+
Args:
|
|
39
|
+
sample: The new sample.
|
|
40
|
+
zi: The output of the previous step.
|
|
41
|
+
alpha: Fading factor.
|
|
42
|
+
beta: Persisting factor. If None, it is calculated as 1-alpha.
|
|
43
|
+
|
|
44
|
+
Returns:
|
|
45
|
+
alpha * sample + beta * zi
|
|
46
|
+
|
|
47
|
+
"""
|
|
48
|
+
# Potential micro-optimization:
|
|
49
|
+
# Current: scalar-arr multiplication, scalar-arr multiplication, arr-arr addition
|
|
50
|
+
# Alternative: arr-arr subtraction, arr-arr multiplication, arr-arr addition
|
|
51
|
+
# return zi + alpha * (new_sample - zi)
|
|
52
|
+
beta = beta or (1 - alpha)
|
|
53
|
+
return alpha * sample + beta * zi
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class EWMA_Deprecated:
|
|
57
|
+
"""
|
|
58
|
+
Grabbed these methods from https://stackoverflow.com/a/70998068 and other answers in that topic,
|
|
59
|
+
but they ended up being slower than the scipy.signal.lfilter method.
|
|
60
|
+
Additionally, `compute` and `compute2` suffer from potential errors as the vector length increases
|
|
61
|
+
and beta**n approaches zero.
|
|
62
|
+
"""
|
|
63
|
+
|
|
64
|
+
def __init__(self, alpha: float, max_len: int):
|
|
65
|
+
self.alpha = alpha
|
|
66
|
+
self.beta = 1 - alpha
|
|
67
|
+
self.prev: npt.NDArray | None = None
|
|
68
|
+
self.weights = np.empty((max_len + 1,), float)
|
|
69
|
+
self._precalc_weights(max_len)
|
|
70
|
+
self._step_func = functools.partial(ewma_step, alpha=self.alpha, beta=self.beta)
|
|
71
|
+
|
|
72
|
+
def _precalc_weights(self, n: int):
|
|
73
|
+
# (1-α)^0, (1-α)^1, (1-α)^2, ..., (1-α)^n
|
|
74
|
+
np.power(self.beta, np.arange(n + 1), out=self.weights)
|
|
75
|
+
|
|
76
|
+
def compute(self, arr: npt.NDArray, out: npt.NDArray | None = None) -> npt.NDArray:
|
|
77
|
+
if out is None:
|
|
78
|
+
out = np.empty(arr.shape, arr.dtype)
|
|
79
|
+
|
|
80
|
+
n = arr.shape[0]
|
|
81
|
+
weights = self.weights[:n]
|
|
82
|
+
weights = np.expand_dims(weights, list(range(1, arr.ndim)))
|
|
83
|
+
|
|
84
|
+
# α*P0, α*P1, α*P2, ..., α*Pn
|
|
85
|
+
np.multiply(self.alpha, arr, out)
|
|
86
|
+
|
|
87
|
+
# α*P0/(1-α)^0, α*P1/(1-α)^1, α*P2/(1-α)^2, ..., α*Pn/(1-α)^n
|
|
88
|
+
np.divide(out, weights, out)
|
|
89
|
+
|
|
90
|
+
# α*P0/(1-α)^0, α*P0/(1-α)^0 + α*P1/(1-α)^1, ...
|
|
91
|
+
np.cumsum(out, axis=0, out=out)
|
|
92
|
+
|
|
93
|
+
# (α*P0/(1-α)^0)*(1-α)^0, (α*P0/(1-α)^0 + α*P1/(1-α)^1)*(1-α)^1, ...
|
|
94
|
+
np.multiply(out, weights, out)
|
|
95
|
+
|
|
96
|
+
# Add the previous output
|
|
97
|
+
if self.prev is None:
|
|
98
|
+
self.prev = arr[:1]
|
|
99
|
+
|
|
100
|
+
out += self.prev * np.expand_dims(
|
|
101
|
+
self.weights[1 : n + 1], list(range(1, arr.ndim))
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
self.prev = out[-1:]
|
|
105
|
+
|
|
106
|
+
return out
|
|
107
|
+
|
|
108
|
+
def compute2(self, arr: npt.NDArray) -> npt.NDArray:
|
|
109
|
+
"""
|
|
110
|
+
Compute the Exponentially Weighted Moving Average (EWMA) of the input array.
|
|
111
|
+
|
|
112
|
+
Args:
|
|
113
|
+
arr: The input array to be smoothed.
|
|
114
|
+
|
|
115
|
+
Returns:
|
|
116
|
+
The smoothed array.
|
|
117
|
+
"""
|
|
118
|
+
n = arr.shape[0]
|
|
119
|
+
if n > len(self.weights):
|
|
120
|
+
self._precalc_weights(n)
|
|
121
|
+
weights = self.weights[:n][::-1]
|
|
122
|
+
weights = np.expand_dims(weights, list(range(1, arr.ndim)))
|
|
123
|
+
|
|
124
|
+
result = np.cumsum(self.alpha * weights * arr, axis=0)
|
|
125
|
+
result = result / weights
|
|
126
|
+
|
|
127
|
+
# Handle the first call when prev is unset
|
|
128
|
+
if self.prev is None:
|
|
129
|
+
self.prev = arr[:1]
|
|
130
|
+
|
|
131
|
+
result += self.prev * np.expand_dims(
|
|
132
|
+
self.weights[1 : n + 1], list(range(1, arr.ndim))
|
|
133
|
+
)
|
|
134
|
+
|
|
135
|
+
# Store the result back into prev
|
|
136
|
+
self.prev = result[-1]
|
|
137
|
+
|
|
138
|
+
return result
|
|
139
|
+
|
|
140
|
+
def compute_sample(self, new_sample: npt.NDArray) -> npt.NDArray:
|
|
141
|
+
if self.prev is None:
|
|
142
|
+
self.prev = new_sample
|
|
143
|
+
self.prev = self._step_func(new_sample, self.prev)
|
|
144
|
+
return self.prev
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
class EWMASettings(ez.Settings):
    """Configuration for :obj:`EWMATransformer` / :obj:`EWMAUnit`."""

    #: Smoothing time constant, expressed in the units of the smoothing axis' gain
    #: (e.g. seconds when the axis gain is a sample period in seconds).
    time_constant: float = 1.0

    #: Name of the axis to smooth along; ``None`` selects the message's first dimension.
    axis: str | None = None
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
@processor_state
class EWMAState:
    # Fading factor. EWMATransformer._reset_state overwrites it from the configured
    # time constant and the axis gain before any filtering happens.
    # The default is tau = 1.0 at a sampling period of 1/1000 (alpha ~= 0.001).
    # `_alpha_from_tau(tau, dt)` was previously called with dt = 1000.0, which gives
    # alpha ~= 1.0, i.e. no smoothing at all.
    # NOTE(review): presumably "1 s time constant at 1 kHz" was intended; confirm.
    alpha: float = field(default_factory=lambda: _alpha_from_tau(1.0, 1 / 1000.0))
    # scipy.signal.lfilter state carried between messages; None until seeded.
    zi: npt.NDArray | None = None
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class EWMATransformer(
    BaseStatefulTransformer[EWMASettings, AxisArray, AxisArray, EWMAState]
):
    """
    Exponentially weighted moving average along one axis of an :obj:`AxisArray`.

    Implemented as the first-order IIR filter
    ``y[n] = alpha * x[n] + (1 - alpha) * y[n-1]`` using :func:`scipy.signal.lfilter`.
    The filter state is carried across messages in ``self._state.zi``.
    """

    def _hash_message(self, message: AxisArray) -> int:
        # The state is rebuilt when any of these change: the shape of a single sample
        # (every dimension except the smoothing axis), the smoothing axis' gain,
        # or the message key.
        axis = self.settings.axis or message.dims[0]
        axis_idx = message.get_axis_idx(axis)
        sample_shape = (
            message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
        )
        return hash((sample_shape, message.axes[axis].gain, message.key))

    def _reset_state(self, message: AxisArray) -> None:
        axis = self.settings.axis or message.dims[0]
        self._state.alpha = _alpha_from_tau(
            self.settings.time_constant, message.axes[axis].gain
        )
        # None when this message has no samples along the axis; _process then seeds
        # the state from the first message that does.
        self._state.zi = self._initial_zi(message, message.get_axis_idx(axis))

    def _initial_zi(self, message: AxisArray, axis_idx: int) -> npt.NDArray | None:
        """
        Seed the lfilter state so that the output starts at the first sample's value.

        With b=[alpha] and a=[1, alpha - 1], the single state value equals
        (1 - alpha) * y[-1]. Setting y[-1] to the first sample avoids a ramp-up from 0.

        Returns None when ``message`` has no samples along ``axis_idx``. Slicing it
        would produce a zero-length state that lfilter rejects on later messages.
        """
        if message.data.shape[axis_idx] == 0:
            return None
        first = slice_along_axis(message.data, slice(None, 1, None), axis=axis_idx)
        return (1 - self._state.alpha) * first

    def _process(self, message: AxisArray) -> AxisArray:
        if np.prod(message.data.shape) == 0:
            # Nothing to filter: pass through without touching the state.
            return message
        axis = self.settings.axis or message.dims[0]
        axis_idx = message.get_axis_idx(axis)
        if self._state.zi is None:
            # The state was reset on an empty message; seed it from this one.
            self._state.zi = self._initial_zi(message, axis_idx)
        expected, self._state.zi = sps.lfilter(
            [self._state.alpha],
            [1.0, self._state.alpha - 1.0],
            message.data,
            axis=axis_idx,
            zi=self._state.zi,
        )
        return replace(message, data=expected)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
class EWMAUnit(
    BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, EWMATransformer]
):
    """ezmsg unit that applies an :obj:`EWMATransformer` to incoming :obj:`AxisArray` messages."""

    SETTINGS = EWMASettings
|
|
ezmsg/sigproc/extract_axis.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import ezmsg.core as ez
|
|
3
|
+
from ezmsg.util.messages.axisarray import AxisArray, replace
|
|
4
|
+
from ezmsg.sigproc.base import BaseTransformer, BaseTransformerUnit
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ExtractAxisSettings(ez.Settings):
    """Configuration for :obj:`ExtractAxisData`."""

    #: Key (in ``message.axes``) of the axis whose values become the output data.
    axis: str = "freq"

    #: Dimension that sizes and labels the output when the extracted axis has no
    #: stored ``data`` (a linear axis).
    reference: str = "time"
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ExtractAxisData(BaseTransformer[ExtractAxisSettings, AxisArray, AxisArray]):
    """
    Replace a message's payload with the values of one of its axes.

    - Coordinate axis (has ``data``): the output data is the axis' own ``data``.
      The output dims are the axis' ``dims``, and the only axes kept are those
      named in the axis' dims.
    - Linear axis: the output data is the axis evaluated at the integer indices
      ``0 .. n-1``, where ``n`` is the message's length along
      ``settings.reference``. The output is 1-D over that reference dimension.
    """

    def _process(self, message: AxisArray) -> AxisArray:
        target = message.axes[self.settings.axis]

        if not hasattr(target, "data"):
            # Linear axis: it only yields a 1-D array of values, which keeps the
            # resulting dims and axes simple.
            ref = self.settings.reference
            n_ref = message.data.shape[message.get_axis_idx(ref)]
            return replace(
                message,
                data=target.value(np.arange(n_ref)),
                dims=[ref],
                axes={ref: message.axes[ref]},
            )

        # Coordinate axis.
        # Note: no transformer so far produces a coordinate axis that carries its own
        # axes; if one ever does, the handling below will need revisiting.
        kept_axes = {
            name: ax for name, ax in message.axes.items() if name in target.dims
        }
        return replace(message, data=target.data, dims=target.dims, axes=kept_axes)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class ExtractAxisDataUnit(
    BaseTransformerUnit[ExtractAxisSettings, AxisArray, AxisArray, ExtractAxisData]
):
    """ezmsg unit that runs :obj:`ExtractAxisData` on incoming :obj:`AxisArray` messages."""

    SETTINGS = ExtractAxisSettings
|