zea 0.0.8__py3-none-any.whl → 0.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. zea/__init__.py +3 -3
  2. zea/agent/masks.py +2 -2
  3. zea/agent/selection.py +3 -3
  4. zea/backend/__init__.py +1 -1
  5. zea/backend/tensorflow/dataloader.py +1 -1
  6. zea/beamform/beamformer.py +4 -2
  7. zea/beamform/pfield.py +2 -2
  8. zea/data/augmentations.py +1 -1
  9. zea/data/convert/__main__.py +93 -52
  10. zea/data/convert/camus.py +8 -2
  11. zea/data/convert/echonet.py +1 -1
  12. zea/data/convert/echonetlvh/__init__.py +1 -1
  13. zea/data/convert/verasonics.py +810 -772
  14. zea/data/data_format.py +0 -2
  15. zea/data/file.py +28 -0
  16. zea/data/preset_utils.py +1 -1
  17. zea/display.py +1 -1
  18. zea/doppler.py +5 -5
  19. zea/func/__init__.py +109 -0
  20. zea/{tensor_ops.py → func/tensor.py} +32 -8
  21. zea/func/ultrasound.py +500 -0
  22. zea/internal/_generate_keras_ops.py +5 -5
  23. zea/metrics.py +6 -5
  24. zea/models/diffusion.py +1 -1
  25. zea/models/echonetlvh.py +1 -1
  26. zea/models/gmm.py +1 -1
  27. zea/ops/__init__.py +188 -0
  28. zea/ops/base.py +442 -0
  29. zea/{keras_ops.py → ops/keras_ops.py} +2 -2
  30. zea/ops/pipeline.py +1472 -0
  31. zea/ops/tensor.py +356 -0
  32. zea/ops/ultrasound.py +890 -0
  33. zea/probes.py +2 -10
  34. zea/scan.py +17 -20
  35. zea/tools/fit_scan_cone.py +1 -1
  36. zea/tools/selection_tool.py +1 -1
  37. zea/tracking/lucas_kanade.py +1 -1
  38. zea/tracking/segmentation.py +1 -1
  39. {zea-0.0.8.dist-info → zea-0.0.9.dist-info}/METADATA +3 -1
  40. {zea-0.0.8.dist-info → zea-0.0.9.dist-info}/RECORD +43 -37
  41. zea/ops.py +0 -3534
  42. {zea-0.0.8.dist-info → zea-0.0.9.dist-info}/WHEEL +0 -0
  43. {zea-0.0.8.dist-info → zea-0.0.9.dist-info}/entry_points.txt +0 -0
  44. {zea-0.0.8.dist-info → zea-0.0.9.dist-info}/licenses/LICENSE +0 -0
zea/func/ultrasound.py ADDED
@@ -0,0 +1,500 @@
1
+ import numpy as np
2
+ import scipy
3
+ from keras import ops
4
+
5
+ from zea import log
6
+ from zea.func.tensor import (
7
+ resample,
8
+ )
9
+
10
+
11
def demodulate_not_jitable(
    rf_data,
    sampling_frequency=None,
    center_frequency=None,
    bandwidth=None,
    filter_coeff=None,
):
    """Demodulates an RF signal to complex base-band (IQ).

    Demodulates the radiofrequency (RF) bandpass signals and returns the
    In-phase/Quadrature (I/Q) components. IQ is a complex array whose real
    (imaginary) part contains the in-phase (quadrature) component.

    This function operates (i.e. demodulates) on the RF signal over the
    (fast-) time axis, which is assumed to be the second-to-last axis.
    The computation is done in NumPy/SciPy and is therefore not jittable.

    Args:
        rf_data (ndarray): real valued input array of size [..., n_ax, n_el].
            The second-to-last axis is the fast-time axis.
        sampling_frequency (float): the sampling frequency of the RF signals (in Hz).
            Required unless filter_coeff is provided.
        center_frequency (float, optional): the center frequency (in Hz).
            Defaults to None, in which case it is estimated as the power-weighted
            mean frequency of (at most 100 randomly selected) scanlines.
        bandwidth (float, optional): Bandwidth of RF signal in % of center
            frequency. Defaults to None.
            The bandwidth in % is defined by:
            B = Bandwidth_in_% = Bandwidth_in_Hz*(100/center_frequency).
            The cutoff frequency:
            Wn = Bandwidth_in_Hz/sampling_frequency, i.e:
            Wn = B*(center_frequency/100)/sampling_frequency.
        filter_coeff (list, optional): (b, a), numerator and denominator coefficients
            of a (quadrature) band-pass filter. If provided, all other parameters are
            ignored: the filter is applied directly to rf_data with
            `scipy.signal.lfilter` along the fast-time axis (no down-mixing).
            If not provided, a low-pass filter is derived from the other params
            (sampling_frequency, center_frequency, bandwidth).
            see https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.lfilter.html

    Returns:
        iq_data (ndarray): complex valued base-band signal with the same shape as
            rf_data (real valued if a real-valued filter_coeff is given).

    Raises:
        AssertionError: if rf_data is not real, or the frequency/bandwidth
            parameters are missing or out of range.
    """
    rf_data = ops.convert_to_numpy(rf_data)
    assert np.isreal(rf_data).all(), f"RF must contain real RF signals, got {rf_data.dtype}"

    input_shape = rf_data.shape
    n_dim = len(input_shape)
    if n_dim > 2:
        *_, n_ax, n_el = input_shape
    else:
        n_ax, n_el = input_shape

    if filter_coeff is not None:
        b, a = filter_coeff
        # factor 2: to preserve the envelope amplitude
        return scipy.signal.lfilter(b, a, rf_data, axis=-2) * 2

    assert sampling_frequency is not None, "provide sampling_frequency when no filter is given."
    # Time vector along fast time (starting at t = 0)
    t = np.arange(n_ax) / sampling_frequency

    # Estimate center frequency
    if center_frequency is None:
        # Keep a maximum of 100 randomly selected scanlines
        idx = np.arange(n_el)
        if n_el > 100:
            idx = np.random.permutation(idx)[:100]
        spectrum = np.abs(np.fft.fft(np.take(rf_data, idx, axis=-1), axis=-2)) ** 2
        # Power spectrum: sum over every axis except fast time, so that P is 1D
        # (length n_ax) for inputs of any rank, not only 2D.
        fast_time_axis = n_dim - 2
        P = np.sum(spectrum, axis=tuple(ax for ax in range(n_dim) if ax != fast_time_axis))
        P = P[: n_ax // 2]
        # Carrier frequency: power-weighted mean bin of the positive half-spectrum
        idx = np.sum(np.arange(n_ax // 2) * P) / np.sum(P)
        center_frequency = idx * sampling_frequency / n_ax

    # Normalized cut-off frequency
    if bandwidth is None:
        Wn = min(2 * center_frequency / sampling_frequency, 0.5)
        bandwidth = center_frequency * Wn
    else:
        assert np.isscalar(bandwidth), "The signal bandwidth (in %) must be a scalar."
        assert (bandwidth > 0) & (bandwidth <= 200), (
            "The signal bandwidth (in %) must be within the interval of ]0,200]."
        )
        # bandwidth in Hz
        bandwidth = center_frequency * bandwidth / 100
        Wn = bandwidth / sampling_frequency
    assert (Wn > 0) & (Wn <= 1), (
        "The normalized cutoff frequency is not within the interval of (0,1). "
        "Check the input parameters!"
    )

    # Down-mixing of the RF signals
    carrier = np.exp(-1j * 2 * np.pi * center_frequency * t)
    # add the singleton dimensions so the carrier broadcasts over [..., n_ax, n_el]
    carrier = np.reshape(carrier, (*[1] * (n_dim - 2), n_ax, 1))
    iq_data = rf_data * carrier

    # Low-pass filter: 5th order Butterworth, applied forward-backward (filtfilt)
    N = 5
    b, a = scipy.signal.butter(N, Wn, "low")

    # factor 2: to preserve the envelope amplitude
    iq_data = scipy.signal.filtfilt(b, a, iq_data, axis=-2) * 2

    # Display a warning message if harmful aliasing is suspected
    if sampling_frequency < (2 * center_frequency + bandwidth):
        # the RF signal is undersampled
        # lower and higher frequencies of the bandpass signal
        fL = center_frequency - bandwidth / 2
        fH = center_frequency + bandwidth / 2
        # Bandpass sampling: the aliases do not overlap iff, for some integer
        # 1 <= k <= floor(fH / (fH - fL)), 2 * fH / k <= fs <= 2 * fL / (k - 1).
        # k = 1 means fs >= 2 * fH, which this branch already excludes, so only
        # k >= 2 is checked (which also avoids a division by zero).
        n = int(fH // (fH - fL))
        k = np.arange(2, n + 1)
        harmless_aliasing = np.any(
            (2 * fH / k <= sampling_frequency) & (sampling_frequency <= 2 * fL / (k - 1))
        )
        if not harmless_aliasing:
            log.warning(
                "rf2iq:harmful_aliasing Harmful aliasing is present: the aliases"
                " are not mutually exclusive!"
            )

    return iq_data
136
+
137
+
138
def upmix(iq_data, sampling_frequency, center_frequency, upsampling_rate=6):
    """Upsamples and upmixes complex base-band signals (IQ) to RF.

    The IQ data are upsampled along the fast-time axis with ``resample``
    (``order=1``), multiplied by the carrier ``exp(+j 2 pi f_c t)``, and the
    real part, scaled by sqrt(2), is returned as float32.

    Args:
        iq_data (ndarray): complex valued input array of size [..., n_ax, n_el].
            The second-to-last axis is the fast-time axis.
        sampling_frequency (float): the sampling frequency of the input IQ signal (in Hz).
            The resulting RF data is sampled upsampling_rate times faster.
        center_frequency (float): the carrier (center) frequency (in Hz).
        upsampling_rate (int, optional): upsampling factor along fast time; the
            output has n_ax * upsampling_rate fast-time samples. Defaults to 6.

    Returns:
        rf_data (Tensor): real valued (float32) RF data.
    """
    assert iq_data.dtype in [
        "complex64",
        "complex128",
    ], "IQ must contain all complex signals."

    # Star-unpacking handles both 2D and higher-rank input.
    *_, n_ax, _ = iq_data.shape
    n_dim = len(iq_data.shape)

    n_ax_up = n_ax * upsampling_rate
    fs_up = sampling_frequency * upsampling_rate

    upsampled = resample(
        iq_data,
        n_samples=n_ax_up,
        axis=-2,
        order=1,
    )

    # Fast-time vector of the upsampled signal in seconds, starting at t = 0.
    t = ops.cast(ops.arange(n_ax_up, dtype="float32") / fs_up, dtype="complex64")
    f_c = ops.cast(center_frequency, dtype="complex64")
    carrier = ops.exp(1j * 2 * np.pi * f_c * t)
    # Singleton axes let the carrier broadcast over leading axes and elements.
    carrier = ops.reshape(carrier, (*[1] * (n_dim - 2), n_ax_up, 1))

    rf_data = ops.real(upsampled * carrier) * ops.sqrt(2)
    return ops.cast(rf_data, "float32")
188
+
189
+
190
def get_band_pass_filter(num_taps, sampling_frequency, f1, f2):
    """Design a real-valued FIR band-pass filter with ``scipy.signal.firwin``.

    Args:
        num_taps (int): number of taps in filter.
        sampling_frequency (float): sample frequency in Hz.
        f1 (float): cutoff frequency in Hz of left band edge.
        f2 (float): cutoff frequency in Hz of right band edge.

    Returns:
        ndarray: band pass filter coefficients of length ``num_taps``.
    """
    band_edges = [f1, f2]
    # pass_zero=False: the band between the two edges passes, DC is stopped.
    return scipy.signal.firwin(num_taps, band_edges, pass_zero=False, fs=sampling_frequency)
204
+
205
+
206
def get_low_pass_iq_filter(num_taps, sampling_frequency, center_frequency, bandwidth):
    """Design a complex filter: a low-pass FIR prototype shifted to center_frequency.

    A real low-pass FIR filter with cutoff ``bandwidth / 2`` is designed with
    ``scipy.signal.firwin`` and then multiplied by ``exp(+j 2 pi f_c n / f_s)``,
    which moves its pass band to ``center_frequency``.

    Args:
        num_taps (int): number of taps in filter.
        sampling_frequency (float): sample frequency.
        center_frequency (float): center frequency.
        bandwidth (float): bandwidth in Hz.

    Raises:
        ValueError: if cutoff frequency (bandwidth / 2) is not within (0, sampling_frequency / 2)

    Returns:
        ndarray: Complex-valued filter coefficients of length ``num_taps``.
    """
    cutoff = bandwidth / 2
    cutoff_is_valid = 0 < cutoff < sampling_frequency / 2
    if not cutoff_is_valid:
        raise ValueError(
            f"Cutoff frequency must be within (0, sampling_frequency / 2), "
            f"got {cutoff} Hz, must be within (0, {sampling_frequency / 2}) Hz"
        )

    prototype = scipy.signal.firwin(num_taps, cutoff, pass_zero=True, fs=sampling_frequency)
    # Complex exponential evaluated at the tap instants shifts the pass band.
    tap_times = np.arange(num_taps) / sampling_frequency
    shift = np.exp(2j * np.pi * center_frequency * tap_times)
    return prototype * shift
235
+
236
+
237
def complex_to_channels(complex_data, axis=-1):
    """Split complex data into a real array with a new real/imaginary channel axis.

    Args:
        complex_data (complex ndarray): complex input data.
        axis (int, optional): position of the new channel axis. Defaults to -1.

    Returns:
        ndarray: real array holding the real part at index 0 and the imaginary
            part at index 1 of the new axis ``axis``.
    """
    real_part = ops.real(complex_data)
    imag_part = ops.imag(complex_data)
    return ops.stack([real_part, imag_part], axis=axis)
257
+
258
+
259
def channels_to_complex(data):
    """Merge a trailing (real, imaginary) channel axis into one complex array.

    Args:
        data (ndarray): input data whose last axis has size 2, holding the real
            component at index 0 and the imaginary component at index 1.

    Returns:
        ndarray: complex64 array with the last axis removed.
    """
    assert data.shape[-1] == 2, "Data must have two channels."
    as_complex = ops.cast(data, "complex64")
    real_part = as_complex[..., 0]
    imag_part = as_complex[..., 1]
    return real_part + 1j * imag_part
273
+
274
+
275
def hilbert(x, N: int = None, axis=-1):
    """Compute the analytic signal of ``x`` along ``axis`` in the Fourier domain.

    The FFT of ``x`` is taken and then weighted with a filter ``h``. Negative
    frequencies are zeroed and positive frequencies are doubled. DC, and for
    even ``N`` also the Nyquist bin, keep weight 1. The result is transformed
    back. The inverse FFT is implemented with a forward FFT using
    ``ifft(X) = conj(fft(conj(X))) / N``.

    Note:
        This is NOT the mathematical Hilbert transform as you will find it on
        Wikipedia. It computes the analytic signal instead. The implementation
        reproduces the behavior of the `scipy.signal.hilbert` function.

    Args:
        x (ndarray): real-valued input data of any shape. It is passed to
            ``ops.fft`` as the real part, with a zero imaginary part.
        N (int, optional): number of points in the FFT. If given, ``x`` is
            zero-padded along ``axis`` to length ``N``. The output then keeps
            length ``N`` along ``axis``. Defaults to None (``N`` = length of ``axis``).
        axis (int, optional): axis to operate on. Defaults to -1.

    Returns:
        x (Tensor): complex64 analytic signal of the same shape as the input
            (with length ``N`` along ``axis`` when padding is used).

    Raises:
        ValueError: if ``N`` is smaller than the length of ``axis``.
    """
    input_shape = x.shape
    n_dim = len(input_shape)

    n_ax = input_shape[axis]

    # normalize a negative axis so it can be used for shape slicing / transposes
    if axis < 0:
        axis = n_dim + axis

    if N is not None:
        if N < n_ax:
            raise ValueError("N must be greater or equal to n_ax.")
        # only pad along the axis, use manual padding
        # NOTE(review): ops.zeros uses its default dtype here, which may differ
        # from x's dtype -- confirm concatenate handles the promotion.
        pad = N - n_ax
        zeros = ops.zeros(
            input_shape[:axis] + (pad,) + input_shape[axis + 1 :],
        )

        x = ops.concatenate((x, zeros), axis=axis)
    else:
        N = n_ax

    # Create filter to zero out negative frequencies
    # (same weighting as scipy.signal.hilbert: 1 for DC/Nyquist, 2 for positive bins)
    h = np.zeros(N)
    if N % 2 == 0:
        h[0] = h[N // 2] = 1
        h[1 : N // 2] = 2
    else:
        h[0] = 1
        h[1 : (N + 1) // 2] = 2

    idx = list(range(n_dim))
    # make sure axis gets to the end for fft (operates on last axis)
    idx.remove(axis)
    idx.append(axis)
    x = ops.transpose(x, idx)

    # reshape h to [1, ..., 1, N] so it broadcasts against the transposed x
    if x.ndim > 1:
        ind = [np.newaxis] * x.ndim
        ind[-1] = slice(None)
        h = h[tuple(ind)]

    h = ops.convert_to_tensor(h)
    h = ops.cast(h, "complex64")
    # adds a zero imaginary part; h is already complex64, so its value is unchanged
    h = h + 1j * ops.zeros_like(h)

    # ops.fft takes and returns (real, imag) pairs; x is the real part
    Xf_r, Xf_i = ops.fft((x, ops.zeros_like(x)))

    Xf_r = ops.cast(Xf_r, "complex64")
    Xf_i = ops.cast(Xf_i, "complex64")

    Xf = Xf_r + 1j * Xf_i
    Xf = Xf * h

    # Inverse FFT without an ifft op: ifft(X) = conj(fft(conj(X))) / N.
    # fft of conj(Xf) is computed here; the conjugation and 1/N follow below.
    Xf_r = ops.real(Xf)
    Xf_i = ops.imag(Xf)
    Xf_r_inv, Xf_i_inv = ops.fft((Xf_r, -Xf_i))

    Xf_i_inv = ops.cast(Xf_i_inv, "complex64")
    Xf_r_inv = ops.cast(Xf_r_inv, "complex64")

    # conjugate (negate the imaginary part) and scale by 1/N
    x = Xf_r_inv / N
    x = x + 1j * (-Xf_i_inv / N)

    # switch back to original shape (move the last axis back to position `axis`)
    idx = list(range(n_dim))
    idx.insert(axis, idx.pop(-1))
    x = ops.transpose(x, idx)
    return x
364
+
365
+
366
def demodulate(data, center_frequency, sampling_frequency, axis=-3):
    """Shift real input data to complex baseband and return it as two I/Q channels.

    First the analytic signal is computed with ``hilbert`` along ``axis``, which
    removes the negative frequencies. It is then multiplied by
    ``exp(-j 2 pi f_c n / f_s)``, where ``n`` is the sample index along ``axis``.
    A spectrum that was centered around ``center_frequency`` is then centered
    around 0 Hz. The complex baseband result is split into real and imaginary
    channels.

    Args:
        data (ops.Tensor): The input data to demodulate of shape `(..., axis, ..., 1)`.
            The trailing singleton channel axis is squeezed away.
        center_frequency (float): The center frequency of the signal.
        sampling_frequency (float): The sampling frequency of the signal.
        axis (int, optional): The axis along which to demodulate. Defaults to -3.

    Returns:
        ops.Tensor: The demodulated IQ data of shape `(..., axis, ..., 2)`.
    """
    analytic = hilbert(data, axis=axis)

    # Sample indices along `axis`, expanded with singleton axes elsewhere so
    # they broadcast against the data.
    expand = [None] * data.ndim
    expand[axis] = slice(None)
    sample_index = ops.arange(analytic.shape[axis])[tuple(expand)]

    # Everything entering the phasor is complex64 to match the analytic signal.
    f_c = ops.cast(center_frequency, dtype="complex64")
    f_s = ops.cast(sampling_frequency, dtype="complex64")
    sample_index = ops.cast(sample_index, dtype="complex64")

    baseband = analytic * ops.exp(-1j * 2 * np.pi * f_c * sample_index / f_s)

    # Drop the trailing singleton channel, then split into I and Q channels.
    return complex_to_channels(ops.squeeze(baseband, axis=-1))
410
+
411
+
412
def compute_time_to_peak_stack(waveforms, center_frequencies, waveform_sampling_frequency=250e6):
    """Compute the time to peak of every waveform in a stack.

    Args:
        waveforms (ndarray): The waveforms of shape (n_waveforms, n_samples).
        center_frequencies (ndarray): The center frequencies of the waveforms in Hz of shape
            (n_waveforms,) or a scalar if all waveforms have the same center frequency.
        waveform_sampling_frequency (float): The sampling frequency of the waveforms in Hz.

    Returns:
        ndarray: The time to peak for each waveform in seconds, shape (n_waveforms,).
    """
    # Multiplying by a ones vector turns a scalar into one value per waveform.
    per_waveform_fc = center_frequencies * ops.ones((waveforms.shape[0],))
    peak_times = [
        compute_time_to_peak(waveform, fc, waveform_sampling_frequency)
        for waveform, fc in zip(waveforms, per_waveform_fc)
    ]
    return ops.stack(peak_times)
429
+
430
+
431
def compute_time_to_peak(waveform, center_frequency, waveform_sampling_frequency=250e6):
    """Compute the time of the envelope peak of the waveform.

    The waveform is demodulated along its sample axis. The envelope is the
    magnitude of the resulting complex signal. The index of its maximum is
    converted to seconds.

    Args:
        waveform (ndarray): The waveform of shape (n_samples,).
        center_frequency (float): The center frequency of the waveform in Hz.
        waveform_sampling_frequency (float): The sampling frequency of the waveform in Hz.

    Returns:
        float: The time to peak for the waveform in seconds.

    Raises:
        ValueError: if the waveform has zero samples.
    """
    n_samples = waveform.shape[0]
    if n_samples == 0:
        raise ValueError("Waveform has zero samples.")

    # `demodulate` expects a trailing singleton channel, so the waveform becomes
    # (n_samples, 1). Demodulate along the sample axis (-2). Using the singleton
    # axis (-1) would make the Hilbert transform a no-op, and the "envelope"
    # would then just be the rectified waveform |waveform|.
    waveforms_iq_complex_channels = demodulate(
        waveform[..., None], center_frequency, waveform_sampling_frequency, axis=-2
    )
    waveforms_iq_complex = channels_to_complex(waveforms_iq_complex_channels)
    envelope = ops.abs(waveforms_iq_complex)
    peak_idx = ops.argmax(envelope, axis=-1)
    t_peak = ops.cast(peak_idx, dtype="float32") / waveform_sampling_frequency
    return t_peak
454
+
455
+
456
def envelope_detect(data, axis=-3):
    """Envelope detection of RF or two-channel IQ data.

    If the last axis has size 2, it is interpreted as (real, imaginary) IQ
    channels and the magnitude is computed directly. Otherwise the data are
    treated as real RF data with a trailing singleton channel. The analytic
    signal is computed with ``hilbert`` along ``axis``, using an FFT length of
    the next power of two, and then cropped back to the original length. The
    channel axis is squeezed and the magnitude is taken.

    Args:
        - data (Tensor): The beamformed data of shape (..., grid_size_z, grid_size_x, n_ch),
          with n_ch == 2 (IQ) or n_ch == 1 (RF).
        - axis (int): Axis along which to apply the Hilbert transform. Defaults to -3.

    Returns:
        - envelope_data (Tensor): The float32 envelope detected data
          of shape (..., grid_size_z, grid_size_x).
    """
    if data.shape[-1] != 2:
        num_samples = ops.shape(data)[axis]

        # FFT length: next power of 2, M = 2^ceil(log2(n_ax))
        # see https://github.com/tue-bmd/zea/discussions/147
        exponent = ops.ceil(ops.log2(ops.cast(num_samples, "float32")))
        fft_length = ops.cast(2**exponent, "int32")

        analytic = hilbert(data, N=fft_length, axis=axis)
        # Crop the zero padding away again.
        analytic = ops.take(analytic, ops.arange(num_samples), axis=axis)
        data = ops.squeeze(analytic, axis=-1)
    else:
        data = channels_to_complex(data)

    # Magnitude written out explicitly instead of ops.abs.
    magnitude = ops.sqrt(ops.real(data) ** 2 + ops.imag(data) ** 2)
    return ops.cast(magnitude, "float32")
494
+
495
+
496
def log_compress(data, eps=1e-16):
    """Apply logarithmic compression, ``20 * log10(data)``.

    Exact zeros are replaced by ``eps`` before the logarithm. Negative values
    are not treated specially.

    Args:
        data (Tensor): input data.
        eps (float, optional): value substituted for zeros. Defaults to 1e-16.

    Returns:
        Tensor: ``20 * log10(data)``, with zeros mapped to ``20 * log10(eps)``.
    """
    floor = ops.convert_to_tensor(eps, dtype=data.dtype)
    nonzero = ops.where(data == 0, floor, data)  # avoid log10(0) = -inf
    return 20 * ops.log10(nonzero)
@@ -5,7 +5,7 @@ They can be used in zea pipelines like any other :class:`zea.Operation`, for exa
5
5
 
6
6
  .. doctest::
7
7
 
8
- >>> from zea.keras_ops import Squeeze
8
+ >>> from zea.ops.keras_ops import Squeeze
9
9
  >>> op = Squeeze(axis=1)
10
10
  """
11
11
 
@@ -78,7 +78,7 @@ They can be used in zea pipelines like any other :class:`zea.Operation`, for exa
78
78
 
79
79
  .. doctest::
80
80
 
81
- >>> from zea.keras_ops import Squeeze
81
+ >>> from zea.ops.keras_ops import Squeeze
82
82
 
83
83
  >>> op = Squeeze(axis=1)
84
84
 
@@ -89,7 +89,7 @@ Generated with Keras {keras.__version__}
89
89
  import keras
90
90
 
91
91
  from zea.internal.registry import ops_registry
92
- from zea.ops import Lambda
92
+ from zea.ops.base import Lambda
93
93
 
94
94
  class MissingKerasOps(ValueError):
95
95
  def __init__(self, class_name: str, func: str):
@@ -109,7 +109,7 @@ class MissingKerasOps(ValueError):
109
109
  content += _generate_operation_class_code(name, keras.ops.image)
110
110
 
111
111
  # Write to a temporary file first, then move to final location
112
- target_path = Path(__file__).parent.parent / "keras_ops.py"
112
+ target_path = Path(__file__).parent.parent / "ops/keras_ops.py"
113
113
  with tempfile.NamedTemporaryFile("w", delete=False, encoding="utf-8") as tmp_file:
114
114
  tmp_file.write(content)
115
115
  temp_path = Path(tmp_file.name)
@@ -117,7 +117,7 @@ class MissingKerasOps(ValueError):
117
117
  # Atomic move to avoid partial writes
118
118
  shutil.move(temp_path, target_path)
119
119
 
120
- print("Done generating `keras_ops.py`.")
120
+ print("Done generating `ops/keras_ops.py`.")
121
121
 
122
122
 
123
123
  if __name__ == "__main__":
zea/metrics.py CHANGED
@@ -7,12 +7,13 @@ import keras
7
7
  import numpy as np
8
8
  from keras import ops
9
9
 
10
- from zea import log, tensor_ops
10
+ from zea import log
11
11
  from zea.backend import func_on_device
12
+ from zea.func import tensor
13
+ from zea.func.tensor import translate
12
14
  from zea.internal.registry import metrics_registry
13
15
  from zea.internal.utils import reduce_to_signature
14
16
  from zea.models.lpips import LPIPS
15
- from zea.tensor_ops import translate
16
17
 
17
18
 
18
19
  def get_metric(name, **kwargs):
@@ -197,7 +198,7 @@ def ssim(
197
198
 
198
199
  # Construct a 1D convolution.
199
200
  def filter_fn_1(z):
200
- return tensor_ops.correlate(z, ops.flip(filt), mode="valid")
201
+ return tensor.correlate(z, ops.flip(filt), mode="valid")
201
202
 
202
203
  # Apply the vectorized filter along the y axis.
203
204
  def filter_fn_y(z):
@@ -300,7 +301,7 @@ def get_lpips(image_range, batch_size=None, clip=False):
300
301
 
301
302
  imgs = ops.stack([img1, img2], axis=-1)
302
303
  n_batch_dims = ops.ndim(img1) - 3
303
- return tensor_ops.func_with_one_batch_dim(
304
+ return tensor.func_with_one_batch_dim(
304
305
  unstack_lpips, imgs, n_batch_dims, batch_size=batch_size
305
306
  )
306
307
 
@@ -372,7 +373,7 @@ class Metrics:
372
373
  # Because most metric functions do not support batching, we vmap over the batch axes.
373
374
  metric_fn = fun
374
375
  for ax in reversed(batch_axes):
375
- metric_fn = tensor_ops.vmap(metric_fn, in_axes=ax, _use_torch_vmap=True)
376
+ metric_fn = tensor.vmap(metric_fn, in_axes=ax, _use_torch_vmap=True)
376
377
 
377
378
  out = func_on_device(metric_fn, device, y_true, y_pred)
378
379
 
zea/models/diffusion.py CHANGED
@@ -23,6 +23,7 @@ from keras import ops
23
23
 
24
24
  from zea.backend import _import_tf, jit
25
25
  from zea.backend.autograd import AutoGrad
26
+ from zea.func.tensor import L2, fori_loop, split_seed
26
27
  from zea.internal.core import Object
27
28
  from zea.internal.operators import Operator
28
29
  from zea.internal.registry import diffusion_guidance_registry, model_registry, operator_registry
@@ -33,7 +34,6 @@ from zea.models.preset_utils import register_presets
33
34
  from zea.models.presets import diffusion_model_presets
34
35
  from zea.models.unet import get_time_conditional_unetwork
35
36
  from zea.models.utils import LossTrackerWrapper
36
- from zea.tensor_ops import L2, fori_loop, split_seed
37
37
 
38
38
  tf = _import_tf()
39
39
 
zea/models/echonetlvh.py CHANGED
@@ -25,12 +25,12 @@ To try this model, simply load one of the available presets:
25
25
  import numpy as np
26
26
  from keras import ops
27
27
 
28
+ from zea.func.tensor import translate
28
29
  from zea.internal.registry import model_registry
29
30
  from zea.models.base import BaseModel
30
31
  from zea.models.deeplabv3 import DeeplabV3Plus
31
32
  from zea.models.preset_utils import register_presets
32
33
  from zea.models.presets import echonet_lvh_presets
33
- from zea.tensor_ops import translate
34
34
 
35
35
 
36
36
  @model_registry(name="echonetlvh")
zea/models/gmm.py CHANGED
@@ -4,8 +4,8 @@ import keras
4
4
  import numpy as np
5
5
  from keras import ops
6
6
 
7
+ from zea.func.tensor import linear_sum_assignment
7
8
  from zea.models.generative import GenerativeModel
8
- from zea.tensor_ops import linear_sum_assignment
9
9
 
10
10
 
11
11
  class GaussianMixtureModel(GenerativeModel):