ezmsg-sigproc 1.2.2__py3-none-any.whl → 2.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__init__.py +1 -1
- ezmsg/sigproc/__version__.py +34 -1
- ezmsg/sigproc/activation.py +78 -0
- ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
- ezmsg/sigproc/affinetransform.py +235 -0
- ezmsg/sigproc/aggregate.py +276 -0
- ezmsg/sigproc/bandpower.py +80 -0
- ezmsg/sigproc/base.py +149 -0
- ezmsg/sigproc/butterworthfilter.py +129 -39
- ezmsg/sigproc/butterworthzerophase.py +305 -0
- ezmsg/sigproc/cheby.py +125 -0
- ezmsg/sigproc/combfilter.py +160 -0
- ezmsg/sigproc/coordinatespaces.py +159 -0
- ezmsg/sigproc/decimate.py +46 -18
- ezmsg/sigproc/denormalize.py +78 -0
- ezmsg/sigproc/detrend.py +28 -0
- ezmsg/sigproc/diff.py +82 -0
- ezmsg/sigproc/downsample.py +97 -49
- ezmsg/sigproc/ewma.py +217 -0
- ezmsg/sigproc/ewmfilter.py +45 -19
- ezmsg/sigproc/extract_axis.py +39 -0
- ezmsg/sigproc/fbcca.py +307 -0
- ezmsg/sigproc/filter.py +282 -117
- ezmsg/sigproc/filterbank.py +292 -0
- ezmsg/sigproc/filterbankdesign.py +129 -0
- ezmsg/sigproc/fir_hilbert.py +336 -0
- ezmsg/sigproc/fir_pmc.py +209 -0
- ezmsg/sigproc/firfilter.py +117 -0
- ezmsg/sigproc/gaussiansmoothing.py +89 -0
- ezmsg/sigproc/kaiser.py +106 -0
- ezmsg/sigproc/linear.py +120 -0
- ezmsg/sigproc/math/__init__.py +0 -0
- ezmsg/sigproc/math/abs.py +35 -0
- ezmsg/sigproc/math/add.py +120 -0
- ezmsg/sigproc/math/clip.py +48 -0
- ezmsg/sigproc/math/difference.py +143 -0
- ezmsg/sigproc/math/invert.py +28 -0
- ezmsg/sigproc/math/log.py +57 -0
- ezmsg/sigproc/math/scale.py +39 -0
- ezmsg/sigproc/messages.py +3 -6
- ezmsg/sigproc/quantize.py +68 -0
- ezmsg/sigproc/resample.py +278 -0
- ezmsg/sigproc/rollingscaler.py +232 -0
- ezmsg/sigproc/sampler.py +232 -241
- ezmsg/sigproc/scaler.py +165 -0
- ezmsg/sigproc/signalinjector.py +70 -0
- ezmsg/sigproc/slicer.py +138 -0
- ezmsg/sigproc/spectral.py +6 -132
- ezmsg/sigproc/spectrogram.py +90 -0
- ezmsg/sigproc/spectrum.py +277 -0
- ezmsg/sigproc/transpose.py +134 -0
- ezmsg/sigproc/util/__init__.py +0 -0
- ezmsg/sigproc/util/asio.py +25 -0
- ezmsg/sigproc/util/axisarray_buffer.py +365 -0
- ezmsg/sigproc/util/buffer.py +449 -0
- ezmsg/sigproc/util/message.py +17 -0
- ezmsg/sigproc/util/profile.py +23 -0
- ezmsg/sigproc/util/sparse.py +115 -0
- ezmsg/sigproc/util/typeresolution.py +17 -0
- ezmsg/sigproc/wavelets.py +187 -0
- ezmsg/sigproc/window.py +301 -117
- ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
- {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -2
- ezmsg/sigproc/synth.py +0 -411
- ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
- ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
- ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
- /ezmsg_sigproc-1.2.2.dist-info/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
|
@@ -0,0 +1,277 @@
|
|
|
1
|
+
import enum
|
|
2
|
+
import typing
|
|
3
|
+
from functools import partial
|
|
4
|
+
|
|
5
|
+
import ezmsg.core as ez
|
|
6
|
+
import numpy as np
|
|
7
|
+
import numpy.typing as npt
|
|
8
|
+
from ezmsg.baseproc import (
|
|
9
|
+
BaseStatefulTransformer,
|
|
10
|
+
BaseTransformerUnit,
|
|
11
|
+
processor_state,
|
|
12
|
+
)
|
|
13
|
+
from ezmsg.util.messages.axisarray import (
|
|
14
|
+
AxisArray,
|
|
15
|
+
replace,
|
|
16
|
+
slice_along_axis,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class OptionsEnum(enum.Enum):
    """Enum base class whose subclasses can list the raw values of their members."""

    @classmethod
    def options(cls):
        """Return every member's ``value`` as a list, in definition order."""
        return [member.value for member in cls]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class WindowFunction(OptionsEnum):
    """
    Tapering window multiplied into each data slice before the spectrum is computed.

    Member values are human-readable labels; the NumPy function that builds each
    window is looked up in the module-level ``WINDOWS`` mapping.
    """

    NONE = "None (Rectangular)"
    """No tapering: a rectangular window of ones (:obj:`numpy.ones`)."""

    HAMMING = "Hamming"
    """Hamming window, :obj:`numpy.hamming`."""

    HANNING = "Hanning"
    """Hann window, :obj:`numpy.hanning`."""

    BARTLETT = "Bartlett"
    """Bartlett (triangular) window, :obj:`numpy.bartlett`."""

    BLACKMAN = "Blackman"
    """Blackman window, :obj:`numpy.blackman`."""
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# Lookup from each WindowFunction to the NumPy callable that, given a window
# length, returns the 1-D array of window coefficients.
WINDOWS: dict[WindowFunction, typing.Callable[[int], npt.NDArray]] = {
    WindowFunction.NONE: np.ones,
    WindowFunction.HAMMING: np.hamming,
    WindowFunction.HANNING: np.hanning,
    WindowFunction.BARTLETT: np.bartlett,
    WindowFunction.BLACKMAN: np.blackman,
}
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class SpectralTransform(OptionsEnum):
    """Post-processing applied to the complex FFT result."""

    # Pass the complex FFT values through untouched.
    RAW_COMPLEX = "Complex FFT Output"
    # Keep only the real part of the FFT.
    REAL = "Real Component of FFT"
    # Keep only the imaginary part of the FFT.
    IMAG = "Imaginary Component of FFT"
    # |X|**2 divided by sum(window**2) * axis gain.
    REL_POWER = "Relative Power"
    # 10 * log10 of the relative power.
    REL_DB = "Log Power (Relative dB)"
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class SpectralOutput(OptionsEnum):
    """Which portion of the frequency axis is returned."""

    # Every FFT bin, negative and positive frequencies.
    FULL = "Full Spectrum"
    # Non-negative frequencies, starting at DC.
    POSITIVE = "Positive Frequencies"
    # Negative frequencies, ending at DC.
    NEGATIVE = "Negative Frequencies"
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class SpectrumSettings(ez.Settings):
    """
    Settings for :obj:`Spectrum` and :obj:`SpectrumTransformer`.

    Note: the ``window`` default here (``WindowFunction.HAMMING``) differs from the
    default of the :obj:`spectrum` convenience function (``WindowFunction.HANNING``).
    """

    axis: str | None = None
    """
    The name of the axis on which to calculate the spectrum. If None, the first dimension is used.
    Note: The axis must have an .axes entry of type LinearAxis, not CoordinateAxis.
    """

    out_axis: str | None = "freq"
    """The name of the new (frequency) axis. Defaults to "freq". If None, the input axis name is kept."""

    window: WindowFunction = WindowFunction.HAMMING
    """The :obj:`WindowFunction` to apply to the data slice prior to calculating the spectrum."""

    transform: SpectralTransform = SpectralTransform.REL_DB
    """The :obj:`SpectralTransform` to apply to the FFT result."""

    output: SpectralOutput = SpectralOutput.POSITIVE
    """The :obj:`SpectralOutput` format (which part of the frequency axis to return)."""

    norm: str | None = "forward"
    """
    Normalization mode. Default "forward" is best used when the inverse transform is not needed,
    for example when the goal is to get spectral power. Use "backward" (equivalent to None) to not
    scale the spectrum, which is useful when the spectra will be manipulated and possibly inverse-transformed.
    See numpy.fft.fft for details.
    """

    do_fftshift: bool = True
    """
    Whether to apply fftshift to the output. Default is True.
    This value is ignored unless output is SpectralOutput.FULL.
    """

    nfft: int | None = None
    """
    The number of points to use for the FFT (the data are zero-padded or truncated to this length).
    If None, the length of the input data along ``axis`` is used.
    """
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
@processor_state
class SpectrumState:
    """Values pre-computed by :obj:`SpectrumTransformer` and reused for every message."""

    # Slice applied along the spectral axis to keep the requested output band.
    # Left as None rather than slice(None): the original author notes slice(None)
    # is rejected as a default here (presumably by the dataclass machinery behind
    # @processor_state).
    f_sl: slice | None = None
    # Metadata for the output frequency axis.
    freq_axis: AxisArray.LinearAxis | None = None
    # FFT callable with n/axis/norm already bound.
    fftfun: typing.Callable | None = None
    # Post-FFT transform chosen from SpectralTransform.
    f_transform: typing.Callable | None = None
    # Output dims: input dims with the spectral axis renamed.
    new_dims: list[str] | None = None
    # Window reshaped for broadcasting, or None when no windowing is applied.
    window: npt.NDArray | None = None
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class SpectrumTransformer(BaseStatefulTransformer[SpectrumSettings, AxisArray, AxisArray, SpectrumState]):
    """
    Compute a windowed FFT along one axis of an :obj:`AxisArray`.

    :meth:`_reset_state` pre-computes everything that depends only on the message
    layout: the window, the FFT callable, the frequency axis, the output dims and
    the post-FFT transform. :meth:`_process` then applies them to each message.
    The base class presumably re-runs :meth:`_reset_state` whenever
    :meth:`_hash_message` changes (TODO confirm against BaseStatefulTransformer).
    """

    def _hash_message(self, message: AxisArray) -> int:
        """
        Return a key summarizing every message property the cached state depends on.

        The dim names are included because the cached ``new_dims`` is derived from them.
        """
        axis = self.settings.axis or message.dims[0]
        ax_idx = message.get_axis_idx(axis)
        ax_info = message.axes[axis]
        targ_len = message.data.shape[ax_idx]
        return hash(
            (
                targ_len,
                message.data.ndim,
                message.data.dtype.kind,
                ax_idx,
                ax_info.gain,
                tuple(message.dims),
            )
        )

    def _reset_state(self, message: AxisArray) -> None:
        """Pre-compute window, FFT function, frequency axis, output dims and transform for ``message``."""
        axis = self.settings.axis or message.dims[0]
        ax_idx = message.get_axis_idx(axis)
        ax_info = message.axes[axis]
        targ_len = message.data.shape[ax_idx]
        nfft = self.settings.nfft or targ_len

        # The window spans the un-padded data and is reshaped to broadcast along ax_idx only.
        window = WINDOWS[self.settings.window](targ_len)
        window = window.reshape([1] * ax_idx + [len(window)] + [1] * (message.data.ndim - 1 - ax_idx))
        # A rectangular window (all ones) would be a no-op multiply, so it is not stored.
        self.state.window = None if self.settings.window == WindowFunction.NONE else window

        # Pre-calculate frequencies and select our fft function.
        # The FFT operates on `nfft` samples spaced `ax_info.gain` apart (the data are
        # zero-padded or truncated to nfft), so the bin spacing is 1 / (gain * nfft).
        b_complex = message.data.dtype.kind == "c"
        self.state.f_sl = slice(None)
        if (not b_complex) and self.settings.output == SpectralOutput.POSITIVE:
            # Real input with positive-only output: rfft computes exactly the bins we need.
            self.state.fftfun = partial(np.fft.rfft, n=nfft, axis=ax_idx, norm=self.settings.norm)
            freqs = np.fft.rfftfreq(nfft, d=ax_info.gain)
        else:
            self.state.fftfun = partial(np.fft.fft, n=nfft, axis=ax_idx, norm=self.settings.norm)
            freqs = np.fft.fftfreq(nfft, d=ax_info.gain)
            if self.settings.output == SpectralOutput.POSITIVE:
                # DC through the highest positive bin: same bin count as rfft (nfft // 2 + 1).
                self.state.f_sl = slice(None, nfft // 2 + 1)
            elif self.settings.output == SpectralOutput.NEGATIVE:
                freqs = np.fft.fftshift(freqs, axes=-1)
                self.state.f_sl = slice(None, nfft // 2 + 1)
            elif self.settings.do_fftshift and self.settings.output == SpectralOutput.FULL:
                freqs = np.fft.fftshift(freqs, axes=-1)
        freqs = freqs[self.state.f_sl].tolist()  # plain floats to please type checking
        # A single bin has no neighbour to difference against; fall back to the analytic spacing.
        f_gain = freqs[1] - freqs[0] if len(freqs) > 1 else 1.0 / (ax_info.gain * nfft)
        self.state.freq_axis = AxisArray.LinearAxis(unit="Hz", gain=f_gain, offset=freqs[0])
        self.state.new_dims = message.dims[:ax_idx] + [self.settings.out_axis or axis] + message.dims[ax_idx + 1 :]

        transform = self.settings.transform
        if transform == SpectralTransform.RAW_COMPLEX:

            def f_transform(x):
                return x

        elif transform == SpectralTransform.REAL:

            def f_transform(x):
                return x.real

        elif transform == SpectralTransform.IMAG:

            def f_transform(x):
                return x.imag

        else:
            # Power, normalized by the window energy times the axis gain.
            # Uses the full window (ones for WindowFunction.NONE).
            scale = np.sum(window**2.0) * ax_info.gain

            def f_power(x):
                return (np.abs(x) ** 2.0) / scale

            if transform == SpectralTransform.REL_DB:

                def f_transform(x):
                    return 10 * np.log10(f_power(x))

            else:
                f_transform = f_power
        self.state.f_transform = f_transform

    def _process(self, message: AxisArray) -> AxisArray:
        """Window, FFT, optionally shift, transform and slice one message; return the spectral AxisArray."""
        axis = self.settings.axis or message.dims[0]
        ax_idx = message.get_axis_idx(axis)

        new_axes = {k: v for k, v in message.axes.items() if k not in [self.settings.out_axis, axis]}
        new_axes[self.settings.out_axis or axis] = self.state.freq_axis

        win_dat = message.data if self.state.window is None else message.data * self.state.window
        # n, axis and norm were bound into fftfun by _reset_state.
        # Note: norm="forward" equivalent to `/ nfft`
        spec = self.state.fftfun(win_dat)
        if self.settings.output == SpectralOutput.NEGATIVE or (
            self.settings.do_fftshift and self.settings.output == SpectralOutput.FULL
        ):
            spec = np.fft.fftshift(spec, axes=ax_idx)
        spec = self.state.f_transform(spec)
        spec = slice_along_axis(spec, self.state.f_sl, ax_idx)

        return replace(message, data=spec, dims=self.state.new_dims, axes=new_axes)
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
class Spectrum(BaseTransformerUnit[SpectrumSettings, AxisArray, AxisArray, SpectrumTransformer]):
    """ezmsg Unit that runs a :obj:`SpectrumTransformer` configured by :obj:`SpectrumSettings`."""

    SETTINGS = SpectrumSettings
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def spectrum(
    axis: str | None = None,
    out_axis: str | None = "freq",
    window: WindowFunction = WindowFunction.HANNING,
    transform: SpectralTransform = SpectralTransform.REL_DB,
    output: SpectralOutput = SpectralOutput.POSITIVE,
    norm: str | None = "forward",
    do_fftshift: bool = True,
    nfft: int | None = None,
) -> SpectrumTransformer:
    """
    Build a transformer that calculates a spectrum on a data slice.

    Every argument is forwarded unchanged to the field of the same name on
    :obj:`SpectrumSettings`. Note that ``window`` defaults to
    ``WindowFunction.HANNING`` here, whereas the settings field defaults to
    ``WindowFunction.HAMMING``.

    Returns:
        A :obj:`SpectrumTransformer` object that expects an :obj:`AxisArray` via `.(axis_array)` (__call__)
        containing continuous data and returns an :obj:`AxisArray` with data of spectral magnitudes or powers.
    """
    settings = SpectrumSettings(
        axis=axis,
        out_axis=out_axis,
        window=window,
        transform=transform,
        output=output,
        norm=norm,
        do_fftshift=do_fftshift,
        nfft=nfft,
    )
    return SpectrumTransformer(settings)
|
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Transpose or permute array dimensions.
|
|
3
|
+
|
|
4
|
+
.. note::
|
|
5
|
+
This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
|
|
6
|
+
enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
|
|
7
|
+
Memory layout optimization (C/F order) only applies to NumPy arrays.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from types import EllipsisType
|
|
11
|
+
|
|
12
|
+
import ezmsg.core as ez
|
|
13
|
+
import numpy as np
|
|
14
|
+
from array_api_compat import get_namespace, is_numpy_array
|
|
15
|
+
from ezmsg.baseproc import (
|
|
16
|
+
BaseStatefulTransformer,
|
|
17
|
+
BaseTransformerUnit,
|
|
18
|
+
processor_state,
|
|
19
|
+
)
|
|
20
|
+
from ezmsg.util.messages.axisarray import (
|
|
21
|
+
AxisArray,
|
|
22
|
+
replace,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class TransposeSettings(ez.Settings):
    """
    Settings for :obj:`Transpose` node.

    Fields:
        axes: Desired output order of the axes, or None to keep the current order.
            Entries may be integer axis indices or dimension names, plus at most one
            ``...`` (Ellipsis) standing in for every axis not listed explicitly.
            Without an Ellipsis, axes that are not listed keep their relative order
            and are placed after the listed ones.
        order: Optional memory layout forced on NumPy output data: "C" (row-major)
            or "F" (column-major). Only the first character is checked,
            case-insensitively. None leaves the memory layout untouched.
    """

    axes: tuple[int | str | EllipsisType, ...] | None = None
    order: str | None = None
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@processor_state
class TransposeState:
    """Cached axis permutation resolved by :obj:`TransposeTransformer`."""

    # Integer permutation to pass to permute_dims, or None when the data is
    # already in the requested axis order.
    axes_ints: tuple[int, ...] | None = None
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class TransposeTransformer(BaseStatefulTransformer[TransposeSettings, AxisArray, AxisArray, TransposeState]):
    """
    Permute the dimensions of an :obj:`AxisArray` and/or change its memory layout.

    ``settings.axes`` is resolved against the incoming message's dims into an
    integer permutation. ``settings.order`` ("C" or "F") additionally forces a
    memory layout on NumPy data. With neither setting, messages pass through unchanged.
    """

    def _hash_message(self, message: AxisArray) -> int:
        # The resolved permutation depends only on the dim names (and their count).
        return hash(tuple(message.dims))

    @staticmethod
    def _axis_to_index(ax, dims, ndim: int) -> int:
        """
        Convert an axis spec (int, possibly negative, or a dim name) into a non-negative index.

        Raises:
            ValueError: if an int is out of range or a name is not in ``dims``.
        """
        if isinstance(ax, int):
            if not -ndim <= ax < ndim:
                raise ValueError(f"Axis {ax} is out of range for data with {ndim} dimensions.")
            return ax % ndim
        if ax not in dims:
            raise ValueError(f"Axis {ax} not found in message dims.")
        return dims.index(ax)

    def _reset_state(self, message: AxisArray) -> None:
        """
        Resolve ``settings.axes`` into an integer permutation for this message and validate ``order``.

        Raises:
            ValueError: on more than one Ellipsis, unknown/out-of-range/duplicate axes,
                or an ``order`` that is not 'C' or 'F'.
        """
        axes = self.settings.axes
        if axes is None:
            self._state.axes_ints = None
        else:
            ndim = message.data.ndim
            ell_ix = [ix for ix, ax in enumerate(axes) if ax is Ellipsis]
            if len(ell_ix) > 1:
                raise ValueError("Only one Ellipsis is allowed in axes.")
            # Without an Ellipsis every listed axis goes first and unlisted axes follow.
            split = ell_ix[0] if ell_ix else len(axes)
            prefix = [self._axis_to_index(ax, message.dims, ndim) for ax in axes[:split]]
            suffix = [self._axis_to_index(ax, message.dims, ndim) for ax in axes[split + 1 :]]
            listed = set(prefix + suffix)
            if len(listed) != len(prefix) + len(suffix):
                raise ValueError(f"Duplicate axes in {axes}.")
            middle = [ix for ix in range(ndim) if ix not in listed]
            re_ix = tuple(prefix + middle + suffix)
            # An identity permutation needs no transpose at all.
            self._state.axes_ints = None if re_ix == tuple(range(ndim)) else re_ix
        order = self.settings.order
        # Slicing [:1] (rather than indexing [0]) turns an empty string into a clean ValueError.
        if order is not None and order[:1].upper() not in ("C", "F"):
            raise ValueError("order must be 'C' or 'F'.")

    def __call__(self, message: AxisArray) -> AxisArray:
        if self.settings.axes is None and self.settings.order is None:
            # Passthrough
            return message
        return super().__call__(message)

    def _process(self, message: AxisArray) -> AxisArray:
        """Apply the cached permutation and/or memory layout to ``message``."""
        # Memory layout requirement for np.require ("C" or "F"), or None to leave it alone.
        layout = None if self.settings.order is None else self.settings.order[:1].upper()

        if self.state.axes_ints is None:
            # No transpose required (axes is None or resolved to the identity permutation).
            # Memory layout optimization only applies to numpy arrays.
            if layout is None or not is_numpy_array(message.data):
                return message
            return replace(message, data=np.require(message.data, requirements=layout))

        xp = get_namespace(message.data)
        dims_out = [message.dims[ix] for ix in self.state.axes_ints]
        data_out = xp.permute_dims(message.data, axes=self.state.axes_ints)
        if layout is not None and is_numpy_array(data_out):
            # Memory layout optimization only applies to numpy arrays
            data_out = np.require(data_out, requirements=layout)
        return replace(message, data=data_out, dims=dims_out)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class Transpose(BaseTransformerUnit[TransposeSettings, AxisArray, AxisArray, TransposeTransformer]):
    """ezmsg Unit that runs a :obj:`TransposeTransformer` configured by :obj:`TransposeSettings`."""

    SETTINGS = TransposeSettings
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def transpose(
    axes: tuple[int | str | EllipsisType, ...] | None = None, order: str | None = None
) -> TransposeTransformer:
    """
    Convenience constructor for a :obj:`TransposeTransformer`.

    Args:
        axes: Forwarded to :obj:`TransposeSettings.axes`.
        order: Forwarded to :obj:`TransposeSettings.order`.

    Returns:
        A :obj:`TransposeTransformer` built from those settings.
    """
    settings = TransposeSettings(axes=axes, order=order)
    return TransposeTransformer(settings)
|
|
File without changes
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Backwards-compatible re-exports from ezmsg.baseproc.util.asio.
|
|
3
|
+
|
|
4
|
+
New code should import directly from ezmsg.baseproc instead.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import warnings
|
|
8
|
+
|
|
9
|
+
# Emitted once, when this backwards-compatibility shim module is first imported.
warnings.warn(
    message="Importing from 'ezmsg.sigproc.util.asio' is deprecated. Please import from 'ezmsg.baseproc.util.asio' instead.",
    category=DeprecationWarning,
    stacklevel=2,
)
|
|
14
|
+
|
|
15
|
+
from ezmsg.baseproc.util.asio import ( # noqa: E402
|
|
16
|
+
CoroutineExecutionError,
|
|
17
|
+
SyncToAsyncGeneratorWrapper,
|
|
18
|
+
run_coroutine_sync,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
# Names re-exported from ezmsg.baseproc.util.asio for backwards compatibility.
__all__ = ["CoroutineExecutionError", "SyncToAsyncGeneratorWrapper", "run_coroutine_sync"]
|