zea 0.0.7__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. zea/__init__.py +1 -1
  2. zea/backend/tensorflow/dataloader.py +0 -4
  3. zea/beamform/pixelgrid.py +1 -1
  4. zea/data/__init__.py +0 -9
  5. zea/data/augmentations.py +221 -28
  6. zea/data/convert/__init__.py +1 -6
  7. zea/data/convert/__main__.py +123 -0
  8. zea/data/convert/camus.py +99 -39
  9. zea/data/convert/echonet.py +183 -82
  10. zea/data/convert/echonetlvh/README.md +2 -3
  11. zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +173 -102
  12. zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
  13. zea/data/convert/echonetlvh/precompute_crop.py +43 -64
  14. zea/data/convert/picmus.py +37 -40
  15. zea/data/convert/utils.py +86 -0
  16. zea/data/convert/{matlab.py → verasonics.py} +33 -61
  17. zea/data/data_format.py +124 -4
  18. zea/data/dataloader.py +12 -7
  19. zea/data/datasets.py +109 -70
  20. zea/data/file.py +91 -82
  21. zea/data/file_operations.py +496 -0
  22. zea/data/preset_utils.py +1 -1
  23. zea/display.py +7 -8
  24. zea/internal/checks.py +6 -12
  25. zea/internal/operators.py +4 -0
  26. zea/io_lib.py +108 -160
  27. zea/models/__init__.py +1 -1
  28. zea/models/diffusion.py +62 -11
  29. zea/models/lv_segmentation.py +2 -0
  30. zea/ops.py +398 -158
  31. zea/scan.py +18 -8
  32. zea/tensor_ops.py +82 -62
  33. zea/tools/fit_scan_cone.py +90 -160
  34. zea/tracking/__init__.py +16 -0
  35. zea/tracking/base.py +94 -0
  36. zea/tracking/lucas_kanade.py +474 -0
  37. zea/tracking/segmentation.py +110 -0
  38. zea/utils.py +11 -2
  39. {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/METADATA +3 -1
  40. {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/RECORD +43 -35
  41. {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/WHEEL +0 -0
  42. {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/entry_points.txt +0 -0
  43. {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/licenses/LICENSE +0 -0
zea/scan.py CHANGED
@@ -153,7 +153,6 @@ class Scan(Parameters):
153
153
  apply_lens_correction (bool, optional): Whether to apply lens correction to
154
154
  delays. Defaults to False.
155
155
  lens_thickness (float, optional): Thickness of the lens in meters.
156
- Defaults to None.
157
156
  f_number (float, optional): F-number of the transducer. Defaults to 1.0.
158
157
  theta_range (tuple, optional): Range of theta angles for 3D imaging.
159
158
  phi_range (tuple, optional): Range of phi angles for 3D imaging.
@@ -215,8 +214,8 @@ class Scan(Parameters):
215
214
  "initial_times": {"type": np.ndarray},
216
215
  "time_to_next_transmit": {"type": np.ndarray},
217
216
  "tgc_gain_curve": {"type": np.ndarray},
218
- "waveforms_one_way": {"type": np.ndarray},
219
- "waveforms_two_way": {"type": np.ndarray},
217
+ "waveforms_one_way": {"type": np.ndarray, "default": None},
218
+ "waveforms_two_way": {"type": np.ndarray, "default": None},
220
219
  "tx_waveform_indices": {"type": np.ndarray},
221
220
  "t_peak": {"type": np.ndarray},
222
221
  # scan conversion parameters
@@ -508,7 +507,7 @@ class Scan(Parameters):
508
507
  value = self._params.get("azimuth_angles")
509
508
  if value is None:
510
509
  log.warning("No azimuth angles provided, using zeros")
511
- value = np.zeros(self.n_tx_selected)
510
+ return np.zeros(self.n_tx_selected)
512
511
 
513
512
  return value[self.selected_transmits]
514
513
 
@@ -529,7 +528,7 @@ class Scan(Parameters):
529
528
  value = self._params.get("tx_apodizations")
530
529
  if value is None:
531
530
  log.warning("No transmit apodizations provided, using ones")
532
- value = np.ones((self.n_tx_selected, self.n_el))
531
+ return np.ones((self.n_tx_selected, self.n_el))
533
532
 
534
533
  return value[self.selected_transmits]
535
534
 
@@ -539,7 +538,7 @@ class Scan(Parameters):
539
538
  value = self._params.get("focus_distances")
540
539
  if value is None:
541
540
  log.warning("No focus distances provided, using zeros")
542
- value = np.zeros(self.n_tx_selected)
541
+ return np.zeros(self.n_tx_selected)
543
542
 
544
543
  return value[self.selected_transmits]
545
544
 
@@ -549,7 +548,7 @@ class Scan(Parameters):
549
548
  value = self._params.get("initial_times")
550
549
  if value is None:
551
550
  log.warning("No initial times provided, using zeros")
552
- value = np.zeros(self.n_tx_selected)
551
+ return np.zeros(self.n_tx_selected)
553
552
 
554
553
  return value[self.selected_transmits]
555
554
 
@@ -617,7 +616,7 @@ class Scan(Parameters):
617
616
 
618
617
  @cache_with_dependencies("pfield")
619
618
  def flat_pfield(self):
620
- """Flattened pfield for weighting."""
619
+ """Flattened pfield for weighting of shape (n_pix, n_tx)."""
621
620
  return self.pfield.reshape(self.n_tx, -1).swapaxes(0, 1)
622
621
 
623
622
  @cache_with_dependencies("zlims")
@@ -674,6 +673,17 @@ class Scan(Parameters):
674
673
  otherwise 2D."""
675
674
  return self.coordinates_3d if getattr(self, "phi_range", None) else self.coordinates_2d
676
675
 
676
+ @property
677
+ def pulse_repetition_frequency(self):
678
+ """The pulse repetition frequency (PRF) [Hz]. Assumes a constant PRF."""
679
+ if self.time_to_next_transmit is None:
680
+ log.warning("Time to next transmit is not set, cannot compute PRF")
681
+ return None
682
+
683
+ pulse_repetition_interval = np.mean(self.time_to_next_transmit)
684
+
685
+ return 1 / pulse_repetition_interval
686
+
677
687
  @cache_with_dependencies("time_to_next_transmit")
678
688
  def frames_per_second(self):
679
689
  """The number of frames per second [Hz]. Assumes a constant frame rate.
zea/tensor_ops.py CHANGED
@@ -1,6 +1,6 @@
1
1
  """Basic tensor operations implemented with the multi-backend ``keras.ops``."""
2
2
 
3
- from typing import Tuple, Union
3
+ from typing import List, Tuple, Union
4
4
 
5
5
  import keras
6
6
  import numpy as np
@@ -9,7 +9,7 @@ from scipy.ndimage import _ni_support
9
9
  from scipy.ndimage._filters import _gaussian_kernel1d
10
10
 
11
11
  from zea import log
12
- from zea.utils import map_negative_indices
12
+ from zea.utils import canonicalize_axis
13
13
 
14
14
 
15
15
  def split_seed(seed, n):
@@ -431,19 +431,20 @@ def _map(fun, in_axes=0, out_axes=0, map_fn=None, _use_torch_vmap=False):
431
431
 
432
432
 
433
433
  def vmap(
434
- fun,
435
- in_axes=0,
436
- out_axes=0,
437
- batch_size=None,
438
- chunks=None,
439
- fn_supports_batch=False,
440
- disable_jit=False,
441
- _use_torch_vmap=False,
434
+ fun: callable,
435
+ in_axes: List[Union[int, None]] | int = 0,
436
+ out_axes: List[Union[int, None]] | int = 0,
437
+ chunks: int | None = None,
438
+ batch_size: int | None = None,
439
+ fn_supports_batch: bool = False,
440
+ disable_jit: bool = False,
441
+ _use_torch_vmap: bool = False,
442
442
  ):
443
443
  """`vmap` with batching or chunking support to avoid memory issues.
444
444
 
445
445
  Basically a wrapper around `vmap` that splits the input into batches or chunks
446
- to avoid memory issues with large inputs.
446
+ to avoid memory issues with large inputs. Choose the `batch_size` or `chunks` wisely, because
447
+ it pads the input to make it divisible, and then crop the output back to the original size.
447
448
 
448
449
  Args:
449
450
  fun: Function to be mapped.
@@ -648,7 +649,7 @@ def pad_array_to_divisible(arr, N, axis=0, mode="constant", pad_value=None):
648
649
  return padded_array
649
650
 
650
651
 
651
- def interpolate_data(subsampled_data, mask, order=1, axis=-1):
652
+ def interpolate_data(subsampled_data, mask, order=1, axis=-1, fill_mode="nearest", fill_value=0):
652
653
  """Interpolate subsampled data along a specified axis using `map_coordinates`.
653
654
 
654
655
  Args:
@@ -658,6 +659,12 @@ def interpolate_data(subsampled_data, mask, order=1, axis=-1):
658
659
  `True` where data is known.
659
660
  order (int, optional): The order of the spline interpolation. Default is `1`.
660
661
  axis (int, optional): The axis along which the data is subsampled. Default is `-1`.
662
+ fill_mode (str, optional): Points outside the boundaries of the input are filled
663
+ according to the given mode. Default is 'nearest'. For more info see
664
+ `keras.ops.image.map_coordinates`.
665
+ fill_value (float, optional): Value to use for points outside the boundaries
666
+ of the input if `fill_mode` is 'constant'. Default is `0`. For more info see
667
+ `keras.ops.image.map_coordinates`.
661
668
 
662
669
  Returns:
663
670
  ndarray: The data interpolated back to the original grid.
@@ -722,6 +729,8 @@ def interpolate_data(subsampled_data, mask, order=1, axis=-1):
722
729
  subsampled_data,
723
730
  interp_coords,
724
731
  order=order,
732
+ fill_mode=fill_mode,
733
+ fill_value=fill_value,
725
734
  )
726
735
 
727
736
  interpolated_data = ops.reshape(interpolated_data, -1)
@@ -1253,7 +1262,7 @@ def reshape_axis(data, newshape: tuple, axis: int):
1253
1262
  >>> reshaped_data.shape
1254
1263
  (3, 2, 2, 5)
1255
1264
  """
1256
- axis = map_negative_indices([axis], data.ndim)[0]
1265
+ axis = canonicalize_axis(axis, data.ndim)
1257
1266
  shape = list(ops.shape(data)) # list
1258
1267
  shape = shape[:axis] + list(newshape) + shape[axis + 1 :]
1259
1268
  return ops.reshape(data, shape)
@@ -1512,7 +1521,8 @@ def sinc(x, eps=keras.config.epsilon()):
1512
1521
  def apply_along_axis(func1d, axis, arr, *args, **kwargs):
1513
1522
  """Apply a function to 1D array slices along an axis.
1514
1523
 
1515
- Keras implementation of numpy.apply_along_axis using keras.ops.vectorized_map.
1524
+ Keras implementation of ``numpy.apply_along_axis``. Copies the ``jax`` implementation, which
1525
+ uses ``vmap`` to vectorize the function application along the specified axis.
1516
1526
 
1517
1527
  Args:
1518
1528
  func1d: A callable function with signature ``func1d(arr, /, *args, **kwargs)``
@@ -1529,56 +1539,13 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
1529
1539
  # Convert to keras tensor
1530
1540
  arr = ops.convert_to_tensor(arr)
1531
1541
 
1532
- # Get array dimensions
1533
- num_dims = len(arr.shape)
1534
-
1535
- # Canonicalize axis (handle negative indices)
1536
- if axis < 0:
1537
- axis = num_dims + axis
1538
-
1539
- if axis < 0 or axis >= num_dims:
1540
- raise ValueError(f"axis {axis} is out of bounds for array of dimension {num_dims}")
1541
-
1542
- # Create a wrapper function that applies func1d with the additional arguments
1543
- def func(slice_arr):
1544
- return func1d(slice_arr, *args, **kwargs)
1545
-
1546
- # Recursively build up vectorized maps following the JAX pattern
1547
- # For dimensions after the target axis (right side)
1542
+ num_dims = ops.ndim(arr)
1543
+ axis = canonicalize_axis(axis, num_dims)
1544
+ func = lambda arr: func1d(arr, *args, **kwargs)
1548
1545
  for i in range(1, num_dims - axis):
1549
- prev_func = func
1550
-
1551
- def make_func(f, dim_offset):
1552
- def vectorized_func(x):
1553
- # Move the dimension we want to map over to the front
1554
- perm = list(range(len(x.shape)))
1555
- perm[0], perm[dim_offset] = perm[dim_offset], perm[0]
1556
- x_moved = ops.transpose(x, perm)
1557
- result = vectorized_map(f, x_moved)
1558
- # Move the result dimension back if needed
1559
- if len(result.shape) > 0:
1560
- result_perm = list(range(len(result.shape)))
1561
- if len(result_perm) > dim_offset:
1562
- result_perm[0], result_perm[dim_offset] = (
1563
- result_perm[dim_offset],
1564
- result_perm[0],
1565
- )
1566
- result = ops.transpose(result, result_perm)
1567
- return result
1568
-
1569
- return vectorized_func
1570
-
1571
- func = make_func(prev_func, i)
1572
-
1573
- # For dimensions before the target axis (left side)
1546
+ func = vmap(func, in_axes=i, out_axes=-1, _use_torch_vmap=True)
1574
1547
  for i in range(axis):
1575
- prev_func = func
1576
-
1577
- def make_func(f):
1578
- return lambda x: vectorized_map(f, x)
1579
-
1580
- func = make_func(prev_func)
1581
-
1548
+ func = vmap(func, in_axes=0, out_axes=0, _use_torch_vmap=True)
1582
1549
  return func(arr)
1583
1550
 
1584
1551
 
@@ -1595,6 +1562,8 @@ def correlate(x, y, mode="full"):
1595
1562
  y: np.ndarray (complex or real)
1596
1563
  mode: "full", "valid", or "same"
1597
1564
  """
1565
+ if keras.backend.backend() == "jax":
1566
+ return ops.correlate(x, y, mode=mode)
1598
1567
  x = ops.convert_to_tensor(x)
1599
1568
  y = ops.convert_to_tensor(y)
1600
1569
 
@@ -1656,6 +1625,57 @@ def correlate(x, y, mode="full"):
1656
1625
  return ops.real(complex_tensor)
1657
1626
 
1658
1627
 
1628
+ def find_contour(binary_mask):
1629
+ """Extract contour/boundary points from a binary mask using edge detection.
1630
+
1631
+ This function finds the boundary pixels of objects in a binary mask by detecting
1632
+ pixels that have at least one neighbor with a different value (using 4-connectivity).
1633
+
1634
+ Args:
1635
+ binary_mask: Binary mask tensor of shape (H, W) with values 0 or 1.
1636
+
1637
+ Returns:
1638
+ Boundary points as tensor of shape (N, 2) in (row, col) format.
1639
+ Returns empty tensor of shape (0, 2) if no boundaries are found.
1640
+
1641
+ Example:
1642
+ .. doctest::
1643
+
1644
+ >>> from zea.tensor_ops import find_contour
1645
+ >>> import keras
1646
+ >>> mask = keras.ops.zeros((10, 10))
1647
+ >>> mask = keras.ops.scatter_update(
1648
+ ... mask, [[3, 3], [3, 4], [4, 3], [4, 4]], [1, 1, 1, 1]
1649
+ ... )
1650
+ >>> contour = find_contour(mask)
1651
+ >>> contour.shape
1652
+ (4, 2)
1653
+ """
1654
+ # Pad the mask to handle edges
1655
+ padded = ops.pad(binary_mask, [[1, 1], [1, 1]], mode="constant", constant_values=0.0)
1656
+
1657
+ # Check 4-connectivity (up, down, left, right)
1658
+ is_edge = (
1659
+ (binary_mask != padded[:-2, 1:-1]) # top neighbor different
1660
+ | (binary_mask != padded[2:, 1:-1]) # bottom neighbor different
1661
+ | (binary_mask != padded[1:-1, :-2]) # left neighbor different
1662
+ | (binary_mask != padded[1:-1, 2:]) # right neighbor different
1663
+ )
1664
+
1665
+ # Only keep edges that are part of the foreground (binary_mask == 1)
1666
+ is_boundary = is_edge & ops.cast(binary_mask, "bool")
1667
+
1668
+ boundary_indices = ops.where(is_boundary)
1669
+
1670
+ if ops.shape(boundary_indices[0])[0] > 0:
1671
+ boundary_points = ops.stack(boundary_indices, axis=1)
1672
+ boundary_points = ops.cast(boundary_points, "float32")
1673
+ else:
1674
+ boundary_points = ops.zeros((0, 2), dtype="float32")
1675
+
1676
+ return boundary_points
1677
+
1678
+
1659
1679
  def translate(array, range_from=None, range_to=(0, 255)):
1660
1680
  """Map values in array from one range to other.
1661
1681