zea 0.0.8__py3-none-any.whl → 0.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +3 -3
- zea/agent/masks.py +2 -2
- zea/agent/selection.py +3 -3
- zea/backend/__init__.py +1 -1
- zea/backend/tensorflow/dataloader.py +1 -1
- zea/beamform/beamformer.py +4 -2
- zea/beamform/pfield.py +2 -2
- zea/data/augmentations.py +1 -1
- zea/data/convert/__main__.py +93 -52
- zea/data/convert/camus.py +8 -2
- zea/data/convert/echonet.py +1 -1
- zea/data/convert/echonetlvh/__init__.py +1 -1
- zea/data/convert/verasonics.py +810 -772
- zea/data/data_format.py +0 -2
- zea/data/file.py +28 -0
- zea/data/preset_utils.py +1 -1
- zea/display.py +1 -1
- zea/doppler.py +5 -5
- zea/func/__init__.py +109 -0
- zea/{tensor_ops.py → func/tensor.py} +32 -8
- zea/func/ultrasound.py +500 -0
- zea/internal/_generate_keras_ops.py +5 -5
- zea/metrics.py +6 -5
- zea/models/diffusion.py +1 -1
- zea/models/echonetlvh.py +1 -1
- zea/models/gmm.py +1 -1
- zea/ops/__init__.py +188 -0
- zea/ops/base.py +442 -0
- zea/{keras_ops.py → ops/keras_ops.py} +2 -2
- zea/ops/pipeline.py +1472 -0
- zea/ops/tensor.py +356 -0
- zea/ops/ultrasound.py +890 -0
- zea/probes.py +2 -10
- zea/scan.py +17 -20
- zea/tools/fit_scan_cone.py +1 -1
- zea/tools/selection_tool.py +1 -1
- zea/tracking/lucas_kanade.py +1 -1
- zea/tracking/segmentation.py +1 -1
- {zea-0.0.8.dist-info → zea-0.0.9.dist-info}/METADATA +3 -1
- {zea-0.0.8.dist-info → zea-0.0.9.dist-info}/RECORD +43 -37
- zea/ops.py +0 -3534
- {zea-0.0.8.dist-info → zea-0.0.9.dist-info}/WHEEL +0 -0
- {zea-0.0.8.dist-info → zea-0.0.9.dist-info}/entry_points.txt +0 -0
- {zea-0.0.8.dist-info → zea-0.0.9.dist-info}/licenses/LICENSE +0 -0
zea/func/ultrasound.py
ADDED
|
@@ -0,0 +1,500 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import scipy
|
|
3
|
+
from keras import ops
|
|
4
|
+
|
|
5
|
+
from zea import log
|
|
6
|
+
from zea.func.tensor import (
|
|
7
|
+
resample,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def demodulate_not_jitable(
    rf_data,
    sampling_frequency=None,
    center_frequency=None,
    bandwidth=None,
    filter_coeff=None,
):
    """Demodulates an RF signal to complex base-band (IQ).

    Demodulates the radiofrequency (RF) bandpass signals and returns the
    in-phase/quadrature (I/Q) components. IQ is a complex array whose real
    (imaginary) part contains the in-phase (quadrature) component.

    This function operates (i.e. demodulates) on the RF signal over the
    (fast-) time axis, which is assumed to be the second-to-last axis. The input
    is converted to NumPy and processed with SciPy, so this function is not
    jit-compatible.

    Args:
        rf_data (ndarray): real valued input array of size [..., n_ax, n_el].
            The second-to-last axis is the fast-time axis.
        sampling_frequency (float): the sampling frequency of the RF signals (in Hz).
            May only be omitted when filter_coeff is provided.
        center_frequency (float, optional): the center frequency (in Hz). If None,
            it is estimated as the spectral centroid of the power spectrum of at
            most 100 randomly selected columns. Defaults to None.
        bandwidth (float, optional): Bandwidth of the RF signal in % of the center
            frequency. Defaults to None.
            The bandwidth in % is defined by:
                B = Bandwidth_in_% = Bandwidth_in_Hz*(100/center_frequency).
            The cutoff frequency:
                Wn = Bandwidth_in_Hz/sampling_frequency, i.e:
                Wn = B*(center_frequency/100)/sampling_frequency.
        filter_coeff (tuple, optional): (b, a), numerator and denominator
            coefficients passed to ``scipy.signal.lfilter``. If provided, all other
            parameters are ignored: no down-mixing is performed and the given filter
            is applied directly to ``rf_data`` along the fast-time axis (the result
            is scaled by 2). If not provided, a 5th-order Butterworth low-pass filter
            is derived from the other params (sampling_frequency, center_frequency,
            bandwidth).
            see https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.lfilter.html

    Returns:
        iq_data (ndarray): base-band signal; complex valued in the default
            (down-mixing) path.

    Raises:
        AssertionError: if rf_data is not real, sampling_frequency is missing while
            no filter_coeff is given, bandwidth is not a scalar in ]0, 200], or the
            resulting normalized cutoff frequency is outside ]0, 1].
    """
    rf_data = ops.convert_to_numpy(rf_data)
    assert np.isreal(rf_data).all(), f"RF must contain real RF signals, got {rf_data.dtype}"

    input_shape = rf_data.shape
    n_dim = len(input_shape)
    if n_dim > 2:
        *_, n_ax, n_el = input_shape
    else:
        n_ax, n_el = input_shape

    if filter_coeff is not None:
        # User-supplied filter: applied directly to the RF data, no down-mixing.
        b, a = filter_coeff
        return scipy.signal.lfilter(b, a, rf_data, axis=-2) * 2

    assert sampling_frequency is not None, "provide sampling_frequency when no filter is given."
    # Time vector (starting at t = 0)
    t = np.arange(n_ax) / sampling_frequency

    # Estimate center frequency
    if center_frequency is None:
        # Keep a maximum of 100 randomly selected scanlines
        idx = np.arange(n_el)
        if n_el > 100:
            idx = np.random.permutation(idx)[:100]
        # Power spectrum along fast time, summed over the selected scanlines
        P = np.sum(
            np.abs(np.fft.fft(np.take(rf_data, idx, axis=-1), axis=-2)) ** 2,
            axis=-1,
        )
        P = P[: n_ax // 2]
        # Carrier frequency = spectral centroid of the positive frequencies
        idx = np.sum(np.arange(n_ax // 2) * P) / np.sum(P)
        center_frequency = idx * sampling_frequency / n_ax

    # Normalized cut-off frequency (relative to Nyquist, as used by scipy.signal.butter)
    if bandwidth is None:
        Wn = min(2 * center_frequency / sampling_frequency, 0.5)
        # approximate bandwidth in Hz, only used for the aliasing check below
        bandwidth = center_frequency * Wn
    else:
        assert np.isscalar(bandwidth), "The signal bandwidth (in %) must be a scalar."
        assert (bandwidth > 0) & (bandwidth <= 200), (
            "The signal bandwidth (in %) must be within the interval of ]0,200]."
        )
        # bandwidth in Hz
        bandwidth = center_frequency * bandwidth / 100
        Wn = bandwidth / sampling_frequency
    assert (Wn > 0) & (Wn <= 1), (
        "The normalized cutoff frequency is not within the interval of (0,1). "
        "Check the input parameters!"
    )

    # Down-mixing of the RF signals
    carrier = np.exp(-1j * 2 * np.pi * center_frequency * t)
    # add the singleton dimensions so the carrier broadcasts along fast time
    carrier = np.reshape(carrier, (*[1] * (n_dim - 2), n_ax, 1))
    iq_data = rf_data * carrier

    # Low-pass filter (zero-phase)
    N = 5
    b, a = scipy.signal.butter(N, Wn, "low")

    # factor 2: to preserve the envelope amplitude
    iq_data = scipy.signal.filtfilt(b, a, iq_data, axis=-2) * 2

    # Display a warning message if harmful aliasing is suspected
    # (the RF signal is undersampled).
    if sampling_frequency < (2 * center_frequency + bandwidth):
        # lower and higher frequencies of the bandpass signal
        fL = center_frequency - bandwidth / 2
        fH = center_frequency + bandwidth / 2
        # Bandpass sampling: aliasing is harmless if, for some integer
        # 1 <= k <= floor(fH / (fH - fL)), 2*fH/k <= fs <= 2*fL/(k-1),
        # where the upper bound is +inf for k == 1.
        n = int(fH // (fH - fL))
        k = np.arange(1, n + 1)
        upper = np.full(k.shape, np.inf)
        upper[1:] = 2 * fL / (k[1:] - 1)
        harmless_aliasing = bool(
            np.any((2 * fH / k <= sampling_frequency) & (sampling_frequency <= upper))
        )
        if not harmless_aliasing:
            log.warning(
                "rf2iq:harmful_aliasing Harmful aliasing is present: the aliases"
                " are not mutually exclusive!"
            )

    return iq_data
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def upmix(iq_data, sampling_frequency, center_frequency, upsampling_rate=6):
    """Upsamples complex base-band (IQ) signals and mixes them back up to RF.

    Args:
        iq_data (ndarray): complex valued input array of size [..., n_ax, n_el].
            The second-to-last axis is the fast-time axis.
        sampling_frequency (float): the sampling frequency of the input IQ signal
            (in Hz). The returned RF data is sampled ``upsampling_rate`` times faster.
        center_frequency (float): carrier (center) frequency in Hz.
        upsampling_rate (int, optional): factor by which the fast-time axis is
            upsampled (linear interpolation) before up-mixing. Defaults to 6.

    Returns:
        rf_data (ndarray): real valued float32 RF data of size
            [..., n_ax * upsampling_rate, n_el].
    """
    assert iq_data.dtype in [
        "complex64",
        "complex128",
    ], "IQ must contain all complex signals."

    *_, n_ax, _ = iq_data.shape
    n_dim = len(iq_data.shape)

    n_ax_up = n_ax * upsampling_rate
    fs_up = sampling_frequency * upsampling_rate

    # Fast-time sample instants of the upsampled signal, starting at t = 0.
    time_up = ops.cast(ops.arange(n_ax_up, dtype="float32") / fs_up, dtype="complex64")

    # First-order (linear) resampling along the fast-time axis.
    iq_up = resample(iq_data, n_samples=n_ax_up, axis=-2, order=1)

    # Complex carrier exp(+j 2 pi fc t), shaped to broadcast over [..., n_ax_up, n_el].
    fc = ops.cast(center_frequency, dtype="complex64")
    carrier = ops.reshape(
        ops.exp(1j * 2 * np.pi * fc * time_up),
        (*[1] * (n_dim - 2), n_ax_up, 1),
    )

    rf_data = ops.real(iq_up * carrier) * ops.sqrt(2)
    return ops.cast(rf_data, "float32")
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def get_band_pass_filter(num_taps, sampling_frequency, f1, f2):
    """Design a linear-phase FIR band-pass filter (window method).

    Args:
        num_taps (int): number of taps in the filter.
        sampling_frequency (float): sample frequency in Hz.
        f1 (float): cutoff frequency in Hz of the left band edge.
        f2 (float): cutoff frequency in Hz of the right band edge.

    Returns:
        ndarray: band-pass filter coefficients of length ``num_taps``.
    """
    band_edges = [f1, f2]
    # pass_zero=False -> the first band (starting at DC) is a stop band.
    return scipy.signal.firwin(
        num_taps,
        band_edges,
        pass_zero=False,
        fs=sampling_frequency,
    )
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def get_low_pass_iq_filter(num_taps, sampling_frequency, center_frequency, bandwidth):
    """Design a complex filter from a low-pass FIR prototype shifted to the center frequency.

    A real low-pass filter with cutoff ``bandwidth / 2`` is designed and then
    multiplied by ``exp(j 2 pi center_frequency n / sampling_frequency)``. This moves
    its passband to ``center_frequency``, which yields a complex-valued filter.

    Args:
        num_taps (int): number of taps in the filter.
        sampling_frequency (float): sample frequency.
        center_frequency (float): center frequency.
        bandwidth (float): bandwidth in Hz.

    Raises:
        ValueError: if cutoff frequency (bandwidth / 2) is not within (0, sampling_frequency / 2)

    Returns:
        ndarray: Complex-valued filter coefficients of length ``num_taps``.
    """
    cutoff = bandwidth / 2
    nyquist = sampling_frequency / 2
    if not (0 < cutoff < nyquist):
        raise ValueError(
            f"Cutoff frequency must be within (0, sampling_frequency / 2), "
            f"got {cutoff} Hz, must be within (0, {sampling_frequency / 2}) Hz"
        )

    # Real-valued low-pass prototype.
    prototype = scipy.signal.firwin(num_taps, cutoff, pass_zero=True, fs=sampling_frequency)

    # Complex exponential at the center frequency, sampled at the tap instants.
    tap_times = np.arange(num_taps) / sampling_frequency
    modulation = np.exp(1j * 2 * np.pi * center_frequency * tap_times)
    return prototype * modulation
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def complex_to_channels(complex_data, axis=-1):
    """Split complex data into a real-valued array with I and Q channels.

    Args:
        complex_data (complex ndarray): complex input data.
        axis (int, optional): position of the new channel axis in the output.
            Defaults to -1.

    Returns:
        ndarray: real array where index 0 along ``axis`` holds the real (in-phase)
        part and index 1 holds the imaginary (quadrature) part.
    """
    parts = (ops.real(complex_data), ops.imag(complex_data))
    channels = tuple(ops.expand_dims(part, axis=axis) for part in parts)
    return ops.concatenate(channels, axis=axis)
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def channels_to_complex(data):
    """Merge a trailing (real, imag) channel pair into a complex array.

    Args:
        data (ndarray): input data whose last axis has size 2: index 0 holds the
            real component and index 1 the imaginary component.

    Returns:
        ndarray: complex64 array with the last axis removed.
    """
    assert data.shape[-1] == 2, "Data must have two channels."
    as_complex = ops.cast(data, "complex64")
    real_part = as_complex[..., 0]
    imag_part = as_complex[..., 1]
    return real_part + 1j * imag_part
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def hilbert(x, N: int = None, axis=-1):
    """Manual implementation of the Hilbert transform function. The function
    returns the analytic signal.

    Operated in the Fourier domain: the FFT of ``x`` along ``axis`` is multiplied
    by a filter ``h`` that keeps DC (and Nyquist for even N) with weight 1, doubles
    the positive frequencies and zeroes the negative frequencies. The inverse FFT
    of the result is then taken.

    Note:
        THIS IS NOT THE MATHEMATICAL HILBERT TRANSFORM as you will find it on
        Wikipedia, but computes the analytic signal. The implementation reproduces
        the behavior of the `scipy.signal.hilbert` function.

    Args:
        x (ndarray): input data of any shape. It is passed to the FFT as the real
            part with a zero imaginary part, i.e. it is treated as real-valued.
        N (int, optional): number of points in the FFT. If given, ``x`` is
            zero-padded along ``axis`` to length N. Defaults to None (no padding).
        axis (int, optional): axis to operate on. Defaults to -1.

    Returns:
        x (ndarray): complex64 analytic signal, with the same shape as the input
            except that ``axis`` has length N when N is given.

    Raises:
        ValueError: if ``N`` is smaller than the length of ``axis``.
    """
    input_shape = x.shape
    n_dim = len(input_shape)

    n_ax = input_shape[axis]

    # normalize a negative axis so it can be used for slicing/transposes below
    if axis < 0:
        axis = n_dim + axis

    if N is not None:
        if N < n_ax:
            raise ValueError("N must be greater or equal to n_ax.")
        # only pad along the axis, use manual padding
        pad = N - n_ax
        zeros = ops.zeros(
            input_shape[:axis] + (pad,) + input_shape[axis + 1 :],
        )

        x = ops.concatenate((x, zeros), axis=axis)
    else:
        N = n_ax

    # Create filter to zero out negative frequencies:
    # weight 1 at DC (and at the Nyquist bin N // 2 for even N),
    # weight 2 for positive frequencies, 0 for negative frequencies.
    h = np.zeros(N)
    if N % 2 == 0:
        h[0] = h[N // 2] = 1
        h[1 : N // 2] = 2
    else:
        h[0] = 1
        h[1 : (N + 1) // 2] = 2

    idx = list(range(n_dim))
    # make sure axis gets to the end for fft (operates on last axis)
    idx.remove(axis)
    idx.append(axis)
    x = ops.transpose(x, idx)

    # reshape h to (1, ..., 1, N) so it broadcasts over the leading axes
    if x.ndim > 1:
        ind = [np.newaxis] * x.ndim
        ind[-1] = slice(None)
        h = h[tuple(ind)]

    h = ops.convert_to_tensor(h)
    h = ops.cast(h, "complex64")
    h = h + 1j * ops.zeros_like(h)

    # forward FFT; ops.fft takes and returns (real, imag) pairs
    Xf_r, Xf_i = ops.fft((x, ops.zeros_like(x)))

    Xf_r = ops.cast(Xf_r, "complex64")
    Xf_i = ops.cast(Xf_i, "complex64")

    Xf = Xf_r + 1j * Xf_i
    Xf = Xf * h

    # Equivalent of np.fft.ifft(Xf), done manually with the forward FFT:
    # ifft(X) = conj(fft(conj(X))) / N
    Xf_r = ops.real(Xf)
    Xf_i = ops.imag(Xf)
    Xf_r_inv, Xf_i_inv = ops.fft((Xf_r, -Xf_i))

    Xf_i_inv = ops.cast(Xf_i_inv, "complex64")
    Xf_r_inv = ops.cast(Xf_r_inv, "complex64")

    # conjugate the result and scale by 1/N
    x = Xf_r_inv / N
    x = x + 1j * (-Xf_i_inv / N)

    # switch back to original shape (move the last axis back to `axis`)
    idx = list(range(n_dim))
    idx.insert(axis, idx.pop(-1))
    x = ops.transpose(x, idx)
    return x
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
def demodulate(data, center_frequency, sampling_frequency, axis=-3):
    """Demodulate real RF data to base-band IQ data stored as two channels.

    The analytic signal (negative frequencies removed) is computed along ``axis``.
    It is then multiplied by ``exp(-j 2 pi f_c n / f_s)``, with ``n`` the sample
    index along ``axis``, which shifts the spectrum from ``center_frequency`` down
    to 0 Hz. The resulting complex IQ data is split into two real-valued channels
    (real, imag).

    Args:
        data (ops.Tensor): The input data to demodulate of shape `(..., axis, ..., 1)`.
            The trailing singleton channel axis is removed before splitting.
        center_frequency (float): The center frequency of the signal.
        sampling_frequency (float): The sampling frequency of the signal.
        axis (int, optional): The (fast-time) axis along which to demodulate.
            Defaults to -3.

    Returns:
        ops.Tensor: The demodulated IQ data of shape `(..., axis, ..., 2)`.
    """
    analytic = hilbert(data, axis=axis)

    # Sample indices 0..n-1 along `axis`, reshaped to broadcast against `data`.
    broadcast_index = [None] * data.ndim
    broadcast_index[axis] = slice(None)
    sample_indices = ops.arange(analytic.shape[axis])[tuple(broadcast_index)]

    fc = ops.cast(center_frequency, dtype="complex64")
    fs = ops.cast(sampling_frequency, dtype="complex64")
    sample_indices = ops.cast(sample_indices, dtype="complex64")

    # Mix down to base-band.
    mixer = ops.exp(-1j * 2 * np.pi * fc * sample_indices / fs)
    baseband = analytic * mixer

    return complex_to_channels(ops.squeeze(baseband, axis=-1))
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
def compute_time_to_peak_stack(waveforms, center_frequencies, waveform_sampling_frequency=250e6):
    """Compute the envelope peak time of every waveform in a stack.

    Args:
        waveforms (ndarray): The waveforms of shape (n_waveforms, n_samples).
        center_frequencies (ndarray): The center frequencies of the waveforms in Hz of shape
            (n_waveforms,) or a scalar if all waveforms have the same center frequency.
        waveform_sampling_frequency (float): The sampling frequency of the waveforms in Hz.

    Returns:
        ndarray: The time to peak for each waveform in seconds, shape (n_waveforms,).
    """
    # Broadcast (a possibly scalar) center frequency to one value per waveform.
    per_waveform_fc = center_frequencies * ops.ones((waveforms.shape[0],))
    peak_times = [
        compute_time_to_peak(single_waveform, fc, waveform_sampling_frequency)
        for single_waveform, fc in zip(waveforms, per_waveform_fc)
    ]
    return ops.stack(peak_times)
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
def compute_time_to_peak(waveform, center_frequency, waveform_sampling_frequency=250e6):
    """Compute the time of the peak of the waveform's envelope.

    The waveform is demodulated to base-band along its sample axis. The index of
    the maximum of the resulting envelope, converted to seconds, is returned.

    Args:
        waveform (ndarray): The waveform of shape (n_samples,).
        center_frequency (float): The center frequency of the waveform in Hz.
        waveform_sampling_frequency (float): The sampling frequency of the waveform in Hz.

    Returns:
        float: The time to peak for the waveform in seconds.

    Raises:
        ValueError: if the waveform has zero samples.
    """
    n_samples = waveform.shape[0]
    if n_samples == 0:
        raise ValueError("Waveform has zero samples.")

    # `demodulate` expects a trailing singleton channel axis, so the input becomes
    # (n_samples, 1) and the time axis is -2. Demodulating along axis -1 would
    # operate on the length-1 channel axis (identity Hilbert transform, unit phasor),
    # so the "envelope" would just be |waveform|.
    waveforms_iq_complex_channels = demodulate(
        waveform[..., None], center_frequency, waveform_sampling_frequency, axis=-2
    )
    waveforms_iq_complex = channels_to_complex(waveforms_iq_complex_channels)
    envelope = ops.abs(waveforms_iq_complex)
    peak_idx = ops.argmax(envelope, axis=-1)
    t_peak = ops.cast(peak_idx, dtype="float32") / waveform_sampling_frequency
    return t_peak
|
|
454
|
+
|
|
455
|
+
|
|
456
|
+
def envelope_detect(data, axis=-3):
    """Envelope detection of RF or IQ data.

    If the last axis has size 2, the data is interpreted as IQ data with
    (real, imag) channels and its magnitude is returned directly. Otherwise the
    data is treated as RF data with a singleton channel axis:
    - the analytic signal is computed along ``axis`` with the Hilbert transform,
      zero-padded to the next power of two;
    - the result is cropped back to the original length;
    - the channel axis is squeezed, and the magnitude is returned.

    Note that the check is on the channel count, not on the dtype. Real data with
    two channels is therefore also interpreted as (real, imag).

    Args:
        - data (Tensor): The beamformed data of shape (..., grid_size_z, grid_size_x, n_ch)
          with n_ch either 1 (RF) or 2 (IQ).
        - axis (int): Axis along which to apply the Hilbert transform. Defaults to -3.

    Returns:
        - envelope_data (Tensor): The float32 envelope detected data
          of shape (..., grid_size_z, grid_size_x).
    """
    if data.shape[-1] == 2:
        analytic = channels_to_complex(data)
    else:
        n_ax = ops.shape(data)[axis]

        # FFT length M = 2^ceil(log2(n_ax)), the next power of 2.
        # see https://github.com/tue-bmd/zea/discussions/147
        exponent = ops.ceil(ops.log2(ops.cast(n_ax, "float32")))
        n_fft = ops.cast(2**exponent, "int32")

        analytic = hilbert(data, N=n_fft, axis=axis)
        # Crop the zero-padded tail back to the original n_ax samples.
        analytic = ops.take(analytic, ops.arange(n_ax), axis=axis)
        analytic = ops.squeeze(analytic, axis=-1)

    re_part = ops.real(analytic)
    im_part = ops.imag(analytic)
    magnitude = ops.sqrt(re_part**2 + im_part**2)
    return ops.cast(magnitude, "float32")
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
def log_compress(data, eps=1e-16):
    """Apply logarithmic compression to data, i.e. ``20 * log10(data)``.

    Exact zeros are replaced by ``eps`` (cast to ``data.dtype``) before taking the
    logarithm, so they map to a finite value instead of ``-inf``. All other values
    are used as-is.

    Args:
        data (Tensor): input data; presumably non-negative magnitudes (e.g. an
            envelope) — confirm with callers.
        eps (float, optional): replacement value for zeros. Defaults to 1e-16.

    Returns:
        Tensor: log-compressed data in dB.
    """
    floor_value = ops.convert_to_tensor(eps, dtype=data.dtype)
    safe_data = ops.where(data == 0, floor_value, data)  # Avoid log(0)
    return 20 * ops.log10(safe_data)
|
|
@@ -5,7 +5,7 @@ They can be used in zea pipelines like any other :class:`zea.Operation`, for exa
|
|
|
5
5
|
|
|
6
6
|
.. doctest::
|
|
7
7
|
|
|
8
|
-
>>> from zea.keras_ops import Squeeze
|
|
8
|
+
>>> from zea.ops.keras_ops import Squeeze
|
|
9
9
|
>>> op = Squeeze(axis=1)
|
|
10
10
|
"""
|
|
11
11
|
|
|
@@ -78,7 +78,7 @@ They can be used in zea pipelines like any other :class:`zea.Operation`, for exa
|
|
|
78
78
|
|
|
79
79
|
.. doctest::
|
|
80
80
|
|
|
81
|
-
>>> from zea.keras_ops import Squeeze
|
|
81
|
+
>>> from zea.ops.keras_ops import Squeeze
|
|
82
82
|
|
|
83
83
|
>>> op = Squeeze(axis=1)
|
|
84
84
|
|
|
@@ -89,7 +89,7 @@ Generated with Keras {keras.__version__}
|
|
|
89
89
|
import keras
|
|
90
90
|
|
|
91
91
|
from zea.internal.registry import ops_registry
|
|
92
|
-
from zea.ops import Lambda
|
|
92
|
+
from zea.ops.base import Lambda
|
|
93
93
|
|
|
94
94
|
class MissingKerasOps(ValueError):
|
|
95
95
|
def __init__(self, class_name: str, func: str):
|
|
@@ -109,7 +109,7 @@ class MissingKerasOps(ValueError):
|
|
|
109
109
|
content += _generate_operation_class_code(name, keras.ops.image)
|
|
110
110
|
|
|
111
111
|
# Write to a temporary file first, then move to final location
|
|
112
|
-
target_path = Path(__file__).parent.parent / "keras_ops.py"
|
|
112
|
+
target_path = Path(__file__).parent.parent / "ops/keras_ops.py"
|
|
113
113
|
with tempfile.NamedTemporaryFile("w", delete=False, encoding="utf-8") as tmp_file:
|
|
114
114
|
tmp_file.write(content)
|
|
115
115
|
temp_path = Path(tmp_file.name)
|
|
@@ -117,7 +117,7 @@ class MissingKerasOps(ValueError):
|
|
|
117
117
|
# Atomic move to avoid partial writes
|
|
118
118
|
shutil.move(temp_path, target_path)
|
|
119
119
|
|
|
120
|
-
print("Done generating `keras_ops.py`.")
|
|
120
|
+
print("Done generating `ops/keras_ops.py`.")
|
|
121
121
|
|
|
122
122
|
|
|
123
123
|
if __name__ == "__main__":
|
zea/metrics.py
CHANGED
|
@@ -7,12 +7,13 @@ import keras
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
from keras import ops
|
|
9
9
|
|
|
10
|
-
from zea import log
|
|
10
|
+
from zea import log
|
|
11
11
|
from zea.backend import func_on_device
|
|
12
|
+
from zea.func import tensor
|
|
13
|
+
from zea.func.tensor import translate
|
|
12
14
|
from zea.internal.registry import metrics_registry
|
|
13
15
|
from zea.internal.utils import reduce_to_signature
|
|
14
16
|
from zea.models.lpips import LPIPS
|
|
15
|
-
from zea.tensor_ops import translate
|
|
16
17
|
|
|
17
18
|
|
|
18
19
|
def get_metric(name, **kwargs):
|
|
@@ -197,7 +198,7 @@ def ssim(
|
|
|
197
198
|
|
|
198
199
|
# Construct a 1D convolution.
|
|
199
200
|
def filter_fn_1(z):
|
|
200
|
-
return
|
|
201
|
+
return tensor.correlate(z, ops.flip(filt), mode="valid")
|
|
201
202
|
|
|
202
203
|
# Apply the vectorized filter along the y axis.
|
|
203
204
|
def filter_fn_y(z):
|
|
@@ -300,7 +301,7 @@ def get_lpips(image_range, batch_size=None, clip=False):
|
|
|
300
301
|
|
|
301
302
|
imgs = ops.stack([img1, img2], axis=-1)
|
|
302
303
|
n_batch_dims = ops.ndim(img1) - 3
|
|
303
|
-
return
|
|
304
|
+
return tensor.func_with_one_batch_dim(
|
|
304
305
|
unstack_lpips, imgs, n_batch_dims, batch_size=batch_size
|
|
305
306
|
)
|
|
306
307
|
|
|
@@ -372,7 +373,7 @@ class Metrics:
|
|
|
372
373
|
# Because most metric functions do not support batching, we vmap over the batch axes.
|
|
373
374
|
metric_fn = fun
|
|
374
375
|
for ax in reversed(batch_axes):
|
|
375
|
-
metric_fn =
|
|
376
|
+
metric_fn = tensor.vmap(metric_fn, in_axes=ax, _use_torch_vmap=True)
|
|
376
377
|
|
|
377
378
|
out = func_on_device(metric_fn, device, y_true, y_pred)
|
|
378
379
|
|
zea/models/diffusion.py
CHANGED
|
@@ -23,6 +23,7 @@ from keras import ops
|
|
|
23
23
|
|
|
24
24
|
from zea.backend import _import_tf, jit
|
|
25
25
|
from zea.backend.autograd import AutoGrad
|
|
26
|
+
from zea.func.tensor import L2, fori_loop, split_seed
|
|
26
27
|
from zea.internal.core import Object
|
|
27
28
|
from zea.internal.operators import Operator
|
|
28
29
|
from zea.internal.registry import diffusion_guidance_registry, model_registry, operator_registry
|
|
@@ -33,7 +34,6 @@ from zea.models.preset_utils import register_presets
|
|
|
33
34
|
from zea.models.presets import diffusion_model_presets
|
|
34
35
|
from zea.models.unet import get_time_conditional_unetwork
|
|
35
36
|
from zea.models.utils import LossTrackerWrapper
|
|
36
|
-
from zea.tensor_ops import L2, fori_loop, split_seed
|
|
37
37
|
|
|
38
38
|
tf = _import_tf()
|
|
39
39
|
|
zea/models/echonetlvh.py
CHANGED
|
@@ -25,12 +25,12 @@ To try this model, simply load one of the available presets:
|
|
|
25
25
|
import numpy as np
|
|
26
26
|
from keras import ops
|
|
27
27
|
|
|
28
|
+
from zea.func.tensor import translate
|
|
28
29
|
from zea.internal.registry import model_registry
|
|
29
30
|
from zea.models.base import BaseModel
|
|
30
31
|
from zea.models.deeplabv3 import DeeplabV3Plus
|
|
31
32
|
from zea.models.preset_utils import register_presets
|
|
32
33
|
from zea.models.presets import echonet_lvh_presets
|
|
33
|
-
from zea.tensor_ops import translate
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
@model_registry(name="echonetlvh")
|
zea/models/gmm.py
CHANGED
|
@@ -4,8 +4,8 @@ import keras
|
|
|
4
4
|
import numpy as np
|
|
5
5
|
from keras import ops
|
|
6
6
|
|
|
7
|
+
from zea.func.tensor import linear_sum_assignment
|
|
7
8
|
from zea.models.generative import GenerativeModel
|
|
8
|
-
from zea.tensor_ops import linear_sum_assignment
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
class GaussianMixtureModel(GenerativeModel):
|