zea 0.0.7__py3-none-any.whl → 0.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. zea/__init__.py +3 -3
  2. zea/agent/masks.py +2 -2
  3. zea/agent/selection.py +3 -3
  4. zea/backend/__init__.py +1 -1
  5. zea/backend/tensorflow/dataloader.py +1 -5
  6. zea/beamform/beamformer.py +4 -2
  7. zea/beamform/pfield.py +2 -2
  8. zea/beamform/pixelgrid.py +1 -1
  9. zea/data/__init__.py +0 -9
  10. zea/data/augmentations.py +222 -29
  11. zea/data/convert/__init__.py +1 -6
  12. zea/data/convert/__main__.py +164 -0
  13. zea/data/convert/camus.py +106 -40
  14. zea/data/convert/echonet.py +184 -83
  15. zea/data/convert/echonetlvh/README.md +2 -3
  16. zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +174 -103
  17. zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
  18. zea/data/convert/echonetlvh/precompute_crop.py +43 -64
  19. zea/data/convert/picmus.py +37 -40
  20. zea/data/convert/utils.py +86 -0
  21. zea/data/convert/verasonics.py +1247 -0
  22. zea/data/data_format.py +124 -6
  23. zea/data/dataloader.py +12 -7
  24. zea/data/datasets.py +109 -70
  25. zea/data/file.py +119 -82
  26. zea/data/file_operations.py +496 -0
  27. zea/data/preset_utils.py +2 -2
  28. zea/display.py +8 -9
  29. zea/doppler.py +5 -5
  30. zea/func/__init__.py +109 -0
  31. zea/{tensor_ops.py → func/tensor.py} +113 -69
  32. zea/func/ultrasound.py +500 -0
  33. zea/internal/_generate_keras_ops.py +5 -5
  34. zea/internal/checks.py +6 -12
  35. zea/internal/operators.py +4 -0
  36. zea/io_lib.py +108 -160
  37. zea/metrics.py +6 -5
  38. zea/models/__init__.py +1 -1
  39. zea/models/diffusion.py +63 -12
  40. zea/models/echonetlvh.py +1 -1
  41. zea/models/gmm.py +1 -1
  42. zea/models/lv_segmentation.py +2 -0
  43. zea/ops/__init__.py +188 -0
  44. zea/ops/base.py +442 -0
  45. zea/{keras_ops.py → ops/keras_ops.py} +2 -2
  46. zea/ops/pipeline.py +1472 -0
  47. zea/ops/tensor.py +356 -0
  48. zea/ops/ultrasound.py +890 -0
  49. zea/probes.py +2 -10
  50. zea/scan.py +35 -28
  51. zea/tools/fit_scan_cone.py +90 -160
  52. zea/tools/selection_tool.py +1 -1
  53. zea/tracking/__init__.py +16 -0
  54. zea/tracking/base.py +94 -0
  55. zea/tracking/lucas_kanade.py +474 -0
  56. zea/tracking/segmentation.py +110 -0
  57. zea/utils.py +11 -2
  58. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/METADATA +5 -1
  59. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/RECORD +62 -48
  60. zea/data/convert/matlab.py +0 -1237
  61. zea/ops.py +0 -3294
  62. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/WHEEL +0 -0
  63. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/entry_points.txt +0 -0
  64. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/licenses/LICENSE +0 -0
zea/ops/ultrasound.py ADDED
@@ -0,0 +1,890 @@
+ import uuid
+
+ import keras
+ import numpy as np
+ from keras import ops
+
+ from zea import log
+ from zea.beamform.beamformer import tof_correction
+ from zea.display import scan_convert
+ from zea.func.tensor import (
+     apply_along_axis,
+     correlate,
+     extend_n_dims,
+     reshape_axis,
+ )
+ from zea.func.ultrasound import (
+     channels_to_complex,
+     complex_to_channels,
+     demodulate,
+     envelope_detect,
+     get_low_pass_iq_filter,
+     log_compress,
+     upmix,
+ )
+ from zea.internal.core import (
+     DEFAULT_DYNAMIC_RANGE,
+     DataTypes,
+ )
+ from zea.internal.registry import ops_registry
+ from zea.ops.base import (
+     ImageOperation,
+     Operation,
+ )
+ from zea.ops.tensor import (
+     GaussianBlur,
+ )
+ from zea.simulator import simulate_rf
+
+
+ @ops_registry("simulate_rf")
+ class Simulate(Operation):
+     """Simulate RF data."""
+
+     # Define operation-specific static parameters
+     STATIC_PARAMS = ["n_ax", "apply_lens_correction"]
+
+     def __init__(self, **kwargs):
+         super().__init__(
+             output_data_type=DataTypes.RAW_DATA,
+             additional_output_keys=["n_ch"],
+             **kwargs,
+         )
+
+     def call(
+         self,
+         scatterer_positions,
+         scatterer_magnitudes,
+         probe_geometry,
+         apply_lens_correction,
+         lens_thickness,
+         lens_sound_speed,
+         sound_speed,
+         n_ax,
+         center_frequency,
+         sampling_frequency,
+         t0_delays,
+         initial_times,
+         element_width,
+         attenuation_coef,
+         tx_apodizations,
+         **kwargs,
+     ):
+         return {
+             self.output_key: simulate_rf(
+                 ops.convert_to_tensor(scatterer_positions),
+                 ops.convert_to_tensor(scatterer_magnitudes),
+                 probe_geometry=probe_geometry,
+                 apply_lens_correction=apply_lens_correction,
+                 lens_thickness=lens_thickness,
+                 lens_sound_speed=lens_sound_speed,
+                 sound_speed=sound_speed,
+                 n_ax=n_ax,
+                 center_frequency=center_frequency,
+                 sampling_frequency=sampling_frequency,
+                 t0_delays=t0_delays,
+                 initial_times=initial_times,
+                 element_width=element_width,
+                 attenuation_coef=attenuation_coef,
+                 tx_apodizations=tx_apodizations,
+             ),
+             "n_ch": 1,  # Simulate always returns RF data (so single channel)
+         }
+
+
+ @ops_registry("tof_correction")
+ class TOFCorrection(Operation):
+     """Time-of-flight correction operation for ultrasound data."""
+
+     # Define operation-specific static parameters
+     STATIC_PARAMS = ["f_number", "apply_lens_correction"]
+
+     def __init__(self, **kwargs):
+         super().__init__(
+             input_data_type=DataTypes.RAW_DATA,
+             output_data_type=DataTypes.ALIGNED_DATA,
+             **kwargs,
+         )
+
+     def call(
+         self,
+         flatgrid,
+         sound_speed,
+         polar_angles,
+         focus_distances,
+         sampling_frequency,
+         f_number,
+         demodulation_frequency,
+         t0_delays,
+         tx_apodizations,
+         initial_times,
+         probe_geometry,
+         t_peak,
+         tx_waveform_indices,
+         apply_lens_correction=None,
+         lens_thickness=None,
+         lens_sound_speed=None,
+         **kwargs,
+     ):
+         """Perform time-of-flight correction on raw RF data.
+
+         Args:
+             raw_data (ops.Tensor): Raw RF data to correct
+             flatgrid (ops.Tensor): Grid points at which to evaluate the time-of-flight
+             sound_speed (float): Sound speed in the medium
+             polar_angles (ops.Tensor): Polar angles for scan lines
+             focus_distances (ops.Tensor): Focus distances for scan lines
+             sampling_frequency (float): Sampling frequency
+             f_number (float): F-number for apodization
+             demodulation_frequency (float): Demodulation frequency
+             t0_delays (ops.Tensor): T0 delays
+             tx_apodizations (ops.Tensor): Transmit apodizations
+             initial_times (ops.Tensor): Initial times
+             probe_geometry (ops.Tensor): Probe element positions
+             t_peak (float): Time to peak of the transmit pulse
+             tx_waveform_indices (ops.Tensor): Index of the transmit waveform for each
+                 transmit. (All zero if there is only one waveform)
+             apply_lens_correction (bool): Whether to apply lens correction
+             lens_thickness (float): Lens thickness
+             lens_sound_speed (float): Sound speed in the lens
+
+         Returns:
+             dict: Dictionary containing tof_corrected_data
+         """
+
+         raw_data = kwargs[self.key]
+
+         tof_kwargs = {
+             "flatgrid": flatgrid,
+             "t0_delays": t0_delays,
+             "tx_apodizations": tx_apodizations,
+             "sound_speed": sound_speed,
+             "probe_geometry": probe_geometry,
+             "initial_times": initial_times,
+             "sampling_frequency": sampling_frequency,
+             "demodulation_frequency": demodulation_frequency,
+             "f_number": f_number,
+             "polar_angles": polar_angles,
+             "focus_distances": focus_distances,
+             "t_peak": t_peak,
+             "tx_waveform_indices": tx_waveform_indices,
+             "apply_lens_correction": apply_lens_correction,
+             "lens_thickness": lens_thickness,
+             "lens_sound_speed": lens_sound_speed,
+         }
+
+         if not self.with_batch_dim:
+             tof_corrected = tof_correction(raw_data, **tof_kwargs)
+         else:
+             tof_corrected = ops.map(
+                 lambda data: tof_correction(data, **tof_kwargs),
+                 raw_data,
+             )
+
+         return {self.output_key: tof_corrected}
+
+
+ @ops_registry("pfield_weighting")
+ class PfieldWeighting(Operation):
+     """Weighting aligned data with the pressure field."""
+
+     def __init__(self, **kwargs):
+         super().__init__(
+             input_data_type=DataTypes.ALIGNED_DATA,
+             output_data_type=DataTypes.ALIGNED_DATA,
+             **kwargs,
+         )
+
+     def call(self, flat_pfield=None, **kwargs):
+         """Weight data with pressure field.
+
+         Args:
+             flat_pfield (ops.Tensor): Pressure field weight mask of shape (n_pix, n_tx)
+
+         Returns:
+             dict: Dictionary containing weighted data
+         """
+         data = kwargs[self.key]  # must start with ((batch_size,) n_tx, n_pix, ...)
+
+         if flat_pfield is None:
+             return {self.output_key: data}
+
+         # Swap (n_pix, n_tx) to (n_tx, n_pix)
+         flat_pfield = ops.swapaxes(flat_pfield, 0, 1)
+
+         # Add batch dimension if needed
+         if self.with_batch_dim:
+             pfield_expanded = ops.expand_dims(flat_pfield, axis=0)
+         else:
+             pfield_expanded = flat_pfield
+
+         append_n_dims = ops.ndim(data) - ops.ndim(pfield_expanded)
+         pfield_expanded = extend_n_dims(pfield_expanded, axis=-1, n_dims=append_n_dims)
+
+         # Perform element-wise multiplication with the pressure weight mask
+         weighted_data = data * pfield_expanded
+
+         return {self.output_key: weighted_data}
+
+
+ @ops_registry("scan_convert")
+ class ScanConvert(Operation):
+     """Scan convert images to cartesian coordinates."""
+
+     STATIC_PARAMS = ["fill_value"]
+
+     def __init__(self, order=1, **kwargs):
+         """Initialize the ScanConvert operation.
+
+         Args:
+             order (int, optional): Interpolation order. Defaults to 1. GPU support is
+                 currently only available for order=1.
+         """
+         if order > 1:
+             jittable = False
+             log.warning(
+                 "GPU support for order > 1 is not available. Disabling jit for ScanConvert."
+             )
+         else:
+             jittable = True
+
+         super().__init__(
+             input_data_type=DataTypes.IMAGE,
+             output_data_type=DataTypes.IMAGE_SC,
+             jittable=jittable,
+             additional_output_keys=[
+                 "resolution",
+                 "x_lim",
+                 "y_lim",
+                 "z_lim",
+                 "rho_range",
+                 "theta_range",
+                 "phi_range",
+                 "d_rho",
+                 "d_theta",
+                 "d_phi",
+             ],
+             **kwargs,
+         )
+         self.order = order
+
+     def call(
+         self,
+         rho_range=None,
+         theta_range=None,
+         phi_range=None,
+         resolution=None,
+         coordinates=None,
+         fill_value=None,
+         **kwargs,
+     ):
+         """Scan convert images to cartesian coordinates.
+
+         Args:
+             rho_range (Tuple): Range of the rho axis in the polar coordinate system.
+                 Defined in meters.
+             theta_range (Tuple): Range of the theta axis in the polar coordinate system.
+                 Defined in radians.
+             phi_range (Tuple): Range of the phi axis in the polar coordinate system.
+                 Defined in radians.
+             resolution (float): Resolution of the output image in meters per pixel.
+                 If None, the resolution is computed based on the input data.
+             coordinates (Tensor): Coordinates for scan conversion. If None, will be computed
+                 based on rho_range, theta_range, phi_range and resolution. If provided, this
+                 operation can be jitted.
+             fill_value (float): Value to fill the image with outside the defined region.
+
+         """
+         if fill_value is None:
+             fill_value = np.nan
+
+         data = kwargs[self.key]
+
+         if self._jit_compile and self.jittable:
+             assert coordinates is not None, (
+                 "coordinates must be provided to jit scan conversion. "
+                 "You can set ScanConvert(jit_compile=False) to disable jitting."
+             )
+
+         data_out, parameters = scan_convert(
+             data,
+             rho_range,
+             theta_range,
+             phi_range,
+             resolution,
+             coordinates,
+             fill_value,
+             self.order,
+             with_batch_dim=self.with_batch_dim,
+         )
+
+         return {self.output_key: data_out, **parameters}
+
+
+ @ops_registry("demodulate")
+ class Demodulate(Operation):
+     """Demodulates the input data to baseband. After this operation, the carrier frequency
+     is removed (0 Hz) and the data is in IQ format stored in two real-valued channels."""
+
+     def __init__(self, axis=-3, **kwargs):
+         super().__init__(
+             input_data_type=DataTypes.RAW_DATA,
+             output_data_type=DataTypes.RAW_DATA,
+             jittable=True,
+             additional_output_keys=[
+                 "demodulation_frequency",
+                 "center_frequency",
+                 "n_ch",
+             ],
+             **kwargs,
+         )
+         self.axis = axis
+
+     def call(self, center_frequency=None, sampling_frequency=None, **kwargs):
+         data = kwargs[self.key]
+
+         demodulation_frequency = center_frequency
+
+         # Split the complex signal into two channels
+         iq_data_two_channel = demodulate(
+             data=data,
+             center_frequency=center_frequency,
+             sampling_frequency=sampling_frequency,
+             axis=self.axis,
+         )
+
+         return {
+             self.output_key: iq_data_two_channel,
+             "demodulation_frequency": demodulation_frequency,
+             "center_frequency": 0.0,
+             "n_ch": 2,
+         }
+
+
+ @ops_registry("fir_filter")
+ class FirFilter(Operation):
+     """Apply a FIR filter to the input signal using convolution.
+
+     Looks for the filter taps in the input dictionary using the specified ``filter_key``.
+     """
+
+     def __init__(
+         self,
+         axis: int,
+         complex_channels: bool = False,
+         filter_key: str = "fir_filter_taps",
+         **kwargs,
+     ):
+         """
+         Args:
+             axis (int): Axis along which to apply the filter. Cannot be the batch dimension.
+                 When using ``complex_channels=True``, the complex channels are removed to convert
+                 to complex numbers before filtering, so adjust the ``axis`` accordingly!
+             complex_channels (bool): Whether the last dimension of the input signal represents
+                 complex channels (real and imaginary parts). When True, it will convert the signal
+                 to ``complex`` dtype before filtering and convert it back to two channels
+                 after filtering.
+             filter_key (str): Key in the input dictionary where the FIR filter taps are stored.
+                 Default is "fir_filter_taps".
+         """
+         super().__init__(**kwargs)
+         self._check_axis(axis)
+
+         self.axis = axis
+         self.complex_channels = complex_channels
+         self.filter_key = filter_key
+
+     def _check_axis(self, axis, ndim=None):
+         """Check if the axis is valid."""
+         if ndim is not None:
+             if axis < -ndim or axis >= ndim:
+                 raise ValueError(f"Axis {axis} is out of bounds for array of dimension {ndim}.")
+
+         if self.with_batch_dim and (axis == 0 or (ndim is not None and axis == -ndim)):
+             raise ValueError("Cannot apply FIR filter along batch dimension.")
+
+     @property
+     def valid_keys(self):
+         """Get the valid keys for the `call` method."""
+         return self._valid_keys.union({self.filter_key})
+
+     def call(self, **kwargs):
+         signal = kwargs[self.key]
+         fir_filter_taps = kwargs[self.filter_key]
+
+         if self.complex_channels:
+             signal = channels_to_complex(signal)
+
+         self._check_axis(self.axis, ndim=ops.ndim(signal))
+
+         def _convolve(signal):
+             """Apply the filter to the signal using correlation."""
+             return correlate(signal, fir_filter_taps[::-1], mode="same")
+
+         filtered_signal = apply_along_axis(_convolve, self.axis, signal)
+
+         if self.complex_channels:
+             filtered_signal = complex_to_channels(filtered_signal)
+
+         return {self.output_key: filtered_signal}
+
+
+ @ops_registry("low_pass_filter")
+ class LowPassFilter(FirFilter):
+     """Apply a low-pass FIR filter to the input signal using convolution.
+
+     It is recommended to use :class:`FirFilter` with pre-computed filter taps for jittable
+     operations. The :class:`LowPassFilter` operation itself is not jittable and is provided
+     for convenience only.
+
+     Uses :func:`get_low_pass_iq_filter` to compute the filter taps.
+     """
+
+     def __init__(self, axis: int, complex_channels: bool = False, num_taps: int = 128, **kwargs):
+         """Initialize the LowPassFilter operation.
+
+         Args:
+             axis (int): Axis along which to apply the filter. Cannot be the batch dimension.
+                 When using ``complex_channels=True``, the complex channels are removed to convert
+                 to complex numbers before filtering, so adjust the ``axis`` accordingly.
+             complex_channels (bool): Whether the last dimension of the input signal represents
+                 complex channels (real and imaginary parts). When True, it will convert the signal
+                 to ``complex`` dtype before filtering and convert it back to two channels
+                 after filtering.
+             num_taps (int): Number of taps in the FIR filter. Default is 128.
+         """
+         self._random_suffix = str(uuid.uuid4())
+         kwargs.pop("filter_key", None)
+         kwargs.pop("jittable", None)
+         super().__init__(
+             axis=axis,
+             complex_channels=complex_channels,
+             filter_key=f"low_pass_{self._random_suffix}",
+             jittable=False,
+             **kwargs,
+         )
+         self.num_taps = num_taps
+
+     def call(self, bandwidth, sampling_frequency, center_frequency, **kwargs):
+         lpf = get_low_pass_iq_filter(
+             self.num_taps,
+             ops.convert_to_numpy(sampling_frequency).item(),
+             ops.convert_to_numpy(center_frequency).item(),
+             ops.convert_to_numpy(bandwidth).item(),
+         )
+         kwargs[self.filter_key] = lpf
+         return super().call(**kwargs)
+
+
+ @ops_registry("channels_to_complex")
+ class ChannelsToComplex(Operation):
+     def call(self, **kwargs):
+         data = kwargs[self.key]
+         output = channels_to_complex(data)
+         return {self.output_key: output}
+
+
+ @ops_registry("complex_to_channels")
+ class ComplexToChannels(Operation):
+     def __init__(self, axis=-1, **kwargs):
+         super().__init__(**kwargs)
+         self.axis = axis
+
+     def call(self, **kwargs):
+         data = kwargs[self.key]
+         output = complex_to_channels(data, axis=self.axis)
+         return {self.output_key: output}
+
+
+ @ops_registry("lee_filter")
+ class LeeFilter(ImageOperation):
+     """
+     The Lee filter is a speckle reduction filter commonly used in synthetic aperture radar (SAR)
+     and ultrasound image processing. It smooths the image while preserving edges and details.
+     This implementation uses a Gaussian filter for local statistics and treats channels independently.
+
+     Lee, J.S. (1980). Digital image enhancement and noise filtering by use of local statistics.
+     IEEE Transactions on Pattern Analysis and Machine Intelligence, (2), 165-168.
+     """
+
+     def __init__(self, sigma=3, kernel_size=None, pad_mode="symmetric", **kwargs):
+         """
+         Args:
+             sigma (float): Standard deviation for Gaussian kernel. Default is 3.
+             kernel_size (int, optional): Size of the Gaussian kernel. If None,
+                 it will be calculated based on sigma.
+             pad_mode (str): Padding mode to be used for Gaussian blur. Default is "symmetric".
+         """
+         super().__init__(**kwargs)
+         self.sigma = sigma
+         self.kernel_size = kernel_size
+         self.pad_mode = pad_mode
+
+         # Create a GaussianBlur instance for computing local statistics
+         self.gaussian_blur = GaussianBlur(
+             sigma=self.sigma,
+             kernel_size=self.kernel_size,
+             pad_mode=self.pad_mode,
+             with_batch_dim=self.with_batch_dim,
+             jittable=self._jittable,
+             key="data",
+         )
+
+     @property
+     def with_batch_dim(self):
+         """Get the with_batch_dim property of the LeeFilter operation."""
+         return self._with_batch_dim
+
+     @with_batch_dim.setter
+     def with_batch_dim(self, value):
+         """Set the with_batch_dim property of the LeeFilter operation."""
+         self._with_batch_dim = value
+         if hasattr(self, "gaussian_blur"):
+             self.gaussian_blur.with_batch_dim = value
+
+     def call(self, **kwargs):
+         """Apply the Lee filter to the input data.
+
+         Args:
+             data (ops.Tensor): Input image data of shape (height, width, channels) with
+                 optional batch dimension if ``self.with_batch_dim``.
+         """
+         super().call(**kwargs)
+         data = kwargs.pop(self.key)
+
+         # Apply Gaussian blur to get local mean
+         img_mean = self.gaussian_blur.call(data=data, **kwargs)[self.gaussian_blur.output_key]
+
+         # Apply Gaussian blur to squared data to get local squared mean
+         img_sqr_mean = self.gaussian_blur.call(
+             data=data**2,
+             **kwargs,
+         )[self.gaussian_blur.output_key]
+
+         # Calculate local variance
+         img_variance = img_sqr_mean - img_mean**2
+
+         # Calculate global variance (per channel)
+         overall_variance = ops.var(data, axis=(-3, -2), keepdims=True)
+
+         # Calculate adaptive weights
+         eps = keras.config.epsilon()
+         img_weights = img_variance / (img_variance + overall_variance + eps)
+
+         # Apply Lee filter formula
+         img_output = img_mean + img_weights * (data - img_mean)
+
+         return {self.output_key: img_output}
+
+
+ @ops_registry("companding")
+ class Companding(Operation):
+     """Companding according to the A- or μ-law algorithm.
+
+     Invertible compression operation. Used to compress the
+     dynamic range of input data (and subsequently expand it).
+
+     μ-law companding:
+         https://en.wikipedia.org/wiki/%CE%9C-law_algorithm
+     A-law companding:
+         https://en.wikipedia.org/wiki/A-law_algorithm
+
+     Args:
+         expand (bool, optional): If set to False (default),
+             data is compressed, else expanded.
+         comp_type (str): either `a` or `mu`.
+         mu (float, optional): compression parameter. Defaults to 255.
+         A (float, optional): compression parameter. Defaults to 87.6.
+     """
+
+     def __init__(self, expand=False, comp_type="mu", **kwargs):
+         super().__init__(**kwargs)
+         self.expand = expand
+         self.comp_type = comp_type.lower()
+         if self.comp_type not in ["mu", "a"]:
+             raise ValueError("comp_type must be 'mu' or 'a'.")
+
+         if self.comp_type == "mu":
+             self._compand_func = self._mu_law_expand if self.expand else self._mu_law_compress
+         else:
+             self._compand_func = self._a_law_expand if self.expand else self._a_law_compress
+
+     @staticmethod
+     def _mu_law_compress(x, mu=255, **kwargs):
+         x = ops.clip(x, -1, 1)
+         return ops.sign(x) * ops.log(1.0 + mu * ops.abs(x)) / ops.log(1.0 + mu)
+
+     @staticmethod
+     def _mu_law_expand(y, mu=255, **kwargs):
+         y = ops.clip(y, -1, 1)
+         return ops.sign(y) * ((1.0 + mu) ** ops.abs(y) - 1.0) / mu
+
+     @staticmethod
+     def _a_law_compress(x, A=87.6, **kwargs):
+         x = ops.clip(x, -1, 1)
+         x_sign = ops.sign(x)
+         x_abs = ops.abs(x)
+         A_log = ops.log(A)
+         val1 = x_sign * A * x_abs / (1.0 + A_log)
+         val2 = x_sign * (1.0 + ops.log(A * x_abs)) / (1.0 + A_log)
+         y = ops.where((x_abs >= 0) & (x_abs < (1.0 / A)), val1, val2)
+         return y
+
+     @staticmethod
+     def _a_law_expand(y, A=87.6, **kwargs):
+         y = ops.clip(y, -1, 1)
+         y_sign = ops.sign(y)
+         y_abs = ops.abs(y)
+         A_log = ops.log(A)
+         val1 = y_sign * y_abs * (1.0 + A_log) / A
+         val2 = y_sign * ops.exp(y_abs * (1.0 + A_log) - 1.0) / A
+         x = ops.where((y_abs >= 0) & (y_abs < (1.0 / (1.0 + A_log))), val1, val2)
+         return x
+
+     def call(self, mu=255, A=87.6, **kwargs):
+         data = kwargs[self.key]
+
+         mu = ops.cast(mu, data.dtype)
+         A = ops.cast(A, data.dtype)
+
+         data_out = self._compand_func(data, mu=mu, A=A)
+         return {self.output_key: data_out}
+
+
+ @ops_registry("downsample")
+ class Downsample(Operation):
+     """Downsample data along a specific axis."""
+
+     def __init__(self, factor: int = 1, phase: int = 0, axis: int = -3, **kwargs):
+         super().__init__(
+             additional_output_keys=["sampling_frequency", "n_ax"],
+             **kwargs,
+         )
+         if factor < 1:
+             raise ValueError("Downsample factor must be >= 1.")
+         if phase < 0 or phase >= factor:
+             raise ValueError("phase must satisfy 0 <= phase < factor.")
+         self.factor = factor
+         self.phase = phase
+         self.axis = axis
+
+     def call(self, sampling_frequency=None, n_ax=None, **kwargs):
+         data = kwargs[self.key]
+         length = ops.shape(data)[self.axis]
+         sample_idx = ops.arange(self.phase, length, self.factor)
+         data_downsampled = ops.take(data, sample_idx, axis=self.axis)
+
+         output = {self.output_key: data_downsampled}
+         # downsampling also affects the sampling frequency
+         if sampling_frequency is not None:
+             sampling_frequency = sampling_frequency / self.factor
+             output["sampling_frequency"] = sampling_frequency
+         if n_ax is not None:
+             n_ax = n_ax // self.factor
+             output["n_ax"] = n_ax
+         return output
+
+
+ @ops_registry("anisotropic_diffusion")
+ class AnisotropicDiffusion(Operation):
+     """Speckle Reducing Anisotropic Diffusion (SRAD) filter.
+
+     Reference:
+         - https://www.researchgate.net/publication/5602035_Speckle_reducing_anisotropic_diffusion
+         - https://nl.mathworks.com/matlabcentral/fileexchange/54044-image-despeckle-filtering-toolbox
+     """
+
+     def call(self, niter=100, lmbda=0.1, rect=None, eps=1e-6, **kwargs):
+         """Anisotropic diffusion filter.
+
+         Assumes input data is non-negative.
+
+         Args:
+             niter: Number of iterations.
+             lmbda: Lambda parameter.
+             rect: Rectangle [x1, y1, x2, y2] for homogeneous noise (optional).
+             eps: Small epsilon for stability.
+         Returns:
+             Filtered image (2D tensor or batch of images).
+         """
+         data = kwargs[self.key]
+
+         if not self.with_batch_dim:
+             data = ops.expand_dims(data, axis=0)
+
+         batch_size = ops.shape(data)[0]
+
+         results = []
+         for i in range(batch_size):
+             image = data[i]
+             image_out = self._anisotropic_diffusion_single(image, niter, lmbda, rect, eps)
+             results.append(image_out)
+
+         result = ops.stack(results, axis=0)
+
+         if not self.with_batch_dim:
+             result = ops.squeeze(result, axis=0)
+
+         return {self.output_key: result}
+
+     def _anisotropic_diffusion_single(self, image, niter, lmbda, rect, eps):
+         """Apply anisotropic diffusion to a single image (2D)."""
+         image = ops.exp(image)
+         M, N = image.shape
+
+         for _ in range(niter):
+             iN = ops.concatenate([image[1:], ops.zeros((1, N), dtype=image.dtype)], axis=0)
+             iS = ops.concatenate([ops.zeros((1, N), dtype=image.dtype), image[:-1]], axis=0)
+             jW = ops.concatenate([image[:, 1:], ops.zeros((M, 1), dtype=image.dtype)], axis=1)
+             jE = ops.concatenate([ops.zeros((M, 1), dtype=image.dtype), image[:, :-1]], axis=1)
+
+             if rect is not None:
+                 x1, y1, x2, y2 = rect
+                 imageuniform = image[x1:x2, y1:y2]
+                 q0_squared = (ops.std(imageuniform) / (ops.mean(imageuniform) + eps)) ** 2
+
+             dN = iN - image
+             dS = iS - image
+             dW = jW - image
+             dE = jE - image
+
+             G2 = (dN**2 + dS**2 + dW**2 + dE**2) / (image**2 + eps)
+             L = (dN + dS + dW + dE) / (image + eps)
+             num = (0.5 * G2) - ((1 / 16) * (L**2))
+             den = (1 + ((1 / 4) * L)) ** 2
+             q_squared = num / (den + eps)
+
+             if rect is not None:
+                 den = (q_squared - q0_squared) / (q0_squared * (1 + q0_squared) + eps)
+                 c = 1.0 / (1 + den)
+                 cS = ops.concatenate([ops.zeros((1, N), dtype=image.dtype), c[:-1]], axis=0)
+                 cE = ops.concatenate([ops.zeros((M, 1), dtype=image.dtype), c[:, :-1]], axis=1)
+
+             D = (cS * dS) + (c * dN) + (cE * dE) + (c * dW)
+             image = image + (lmbda / 4) * D
+
+         result = ops.log(image)
+         return result
+
+
+ @ops_registry("envelope_detect")
+ class EnvelopeDetect(Operation):
+     """Envelope detection of RF signals."""
+
+     def __init__(
+         self,
+         axis=-3,
+         **kwargs,
+     ):
+         super().__init__(
+             input_data_type=DataTypes.BEAMFORMED_DATA,
+             output_data_type=DataTypes.ENVELOPE_DATA,
+             **kwargs,
+         )
+         self.axis = axis
+
+     def call(self, **kwargs):
+         """
+         Args:
+             - data (Tensor): The beamformed data of shape (..., grid_size_z, grid_size_x, n_ch).
+         Returns:
+             - envelope_data (Tensor): The envelope detected data
+                 of shape (..., grid_size_z, grid_size_x).
+         """
+         data = kwargs[self.key]
+
+         data = envelope_detect(data, axis=self.axis)
+
+         return {self.output_key: data}
+
+
+ @ops_registry("upmix")
+ class UpMix(Operation):
+     """Upmix IQ data to RF data."""
+
+     def __init__(
+         self,
+         upsampling_rate=1,
+         **kwargs,
+     ):
+         super().__init__(
+             **kwargs,
+         )
+         self.upsampling_rate = upsampling_rate
+
+     def call(
+         self,
+         sampling_frequency=None,
+         center_frequency=None,
+         **kwargs,
+     ):
+         data = kwargs[self.key]
+
+         if data.shape[-1] == 1:
+             log.warning("Upmixing is not applicable to RF data.")
+             return {self.output_key: data}
+         elif data.shape[-1] == 2:
+             data = channels_to_complex(data)
+
+         data = upmix(data, sampling_frequency, center_frequency, self.upsampling_rate)
+         data = ops.expand_dims(data, axis=-1)
+         return {self.output_key: data}
+
+
+ @ops_registry("log_compress")
+ class LogCompress(Operation):
+     """Logarithmic compression of data."""
+
+     def __init__(self, clip: bool = True, **kwargs):
+         """Initialize the LogCompress operation.
+
+         Args:
+             clip (bool): Whether to clip the output to a dynamic range. Defaults to True.
+         """
+         super().__init__(
+             input_data_type=DataTypes.ENVELOPE_DATA,
+             output_data_type=DataTypes.IMAGE,
+             **kwargs,
+         )
+         self.clip = clip
+
+     def call(self, dynamic_range=None, **kwargs):
+         """Apply logarithmic compression to data.
+
+         Args:
+             dynamic_range (tuple, optional): Dynamic range in dB. Defaults to (-60, 0).
+
+         Returns:
+             dict: Dictionary containing log-compressed data
+         """
+         data = kwargs[self.key]
+
+         if dynamic_range is None:
+             dynamic_range = ops.array(DEFAULT_DYNAMIC_RANGE)
+         dynamic_range = ops.cast(dynamic_range, data.dtype)
+
+         compressed_data = log_compress(data)
+         if self.clip:
+             compressed_data = ops.clip(compressed_data, dynamic_range[0], dynamic_range[1])
+
+         return {self.output_key: compressed_data}
+
+
+ @ops_registry("reshape_grid")
+ class ReshapeGrid(Operation):
+     """Reshape flat grid data to grid shape."""
+
+     def __init__(self, axis=0, **kwargs):
+         super().__init__(**kwargs)
+         self.axis = axis
+
+     def call(self, grid, **kwargs):
+         """
+         Args:
+             - data (Tensor): The flat grid data of shape (..., n_pix, ...).
+         Returns:
+             - reshaped_data (Tensor): The reshaped data of shape (..., grid.shape, ...).
+         """
+         data = kwargs[self.key]
+         reshaped_data = reshape_axis(data, grid.shape[:-1], self.axis + int(self.with_batch_dim))
+         return {self.output_key: reshaped_data}
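
For orientation, a minimal sketch of how two of the operations added above might be chained by hand. It assumes the base Operation class (defined in zea/ops/base.py, which is not shown in this diff) accepts `key`, `output_key`, and `with_batch_dim` keyword arguments and routes data to `call()` under the configured key; the input array is invented for illustration. Treat it as a sketch under those assumptions, not the package's documented API.

import numpy as np

from zea.ops.ultrasound import EnvelopeDetect, LogCompress

# Hypothetical beamformed input of shape (grid_size_z, grid_size_x, n_ch), no batch dimension.
beamformed = np.random.randn(256, 128, 1).astype("float32")

# Assumed base-class kwargs: data travels under the "data" key, with_batch_dim disabled.
envelope = EnvelopeDetect(key="data", output_key="data", with_batch_dim=False)
log_compress = LogCompress(key="data", output_key="data", with_batch_dim=False)

outputs = envelope.call(data=beamformed)  # {"data": envelope-detected data}
outputs = log_compress.call(dynamic_range=(-60.0, 0.0), **outputs)  # {"data": image in dB}
image = outputs["data"]

In practice these operations are presumably composed through the pipeline machinery in zea/ops/pipeline.py (also added in this release), but that wiring lives outside this file.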