zea 0.0.7__py3-none-any.whl → 0.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. zea/__init__.py +3 -3
  2. zea/agent/masks.py +2 -2
  3. zea/agent/selection.py +3 -3
  4. zea/backend/__init__.py +1 -1
  5. zea/backend/tensorflow/dataloader.py +1 -5
  6. zea/beamform/beamformer.py +4 -2
  7. zea/beamform/pfield.py +2 -2
  8. zea/beamform/pixelgrid.py +1 -1
  9. zea/data/__init__.py +0 -9
  10. zea/data/augmentations.py +222 -29
  11. zea/data/convert/__init__.py +1 -6
  12. zea/data/convert/__main__.py +164 -0
  13. zea/data/convert/camus.py +106 -40
  14. zea/data/convert/echonet.py +184 -83
  15. zea/data/convert/echonetlvh/README.md +2 -3
  16. zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +174 -103
  17. zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
  18. zea/data/convert/echonetlvh/precompute_crop.py +43 -64
  19. zea/data/convert/picmus.py +37 -40
  20. zea/data/convert/utils.py +86 -0
  21. zea/data/convert/verasonics.py +1247 -0
  22. zea/data/data_format.py +124 -6
  23. zea/data/dataloader.py +12 -7
  24. zea/data/datasets.py +109 -70
  25. zea/data/file.py +119 -82
  26. zea/data/file_operations.py +496 -0
  27. zea/data/preset_utils.py +2 -2
  28. zea/display.py +8 -9
  29. zea/doppler.py +5 -5
  30. zea/func/__init__.py +109 -0
  31. zea/{tensor_ops.py → func/tensor.py} +113 -69
  32. zea/func/ultrasound.py +500 -0
  33. zea/internal/_generate_keras_ops.py +5 -5
  34. zea/internal/checks.py +6 -12
  35. zea/internal/operators.py +4 -0
  36. zea/io_lib.py +108 -160
  37. zea/metrics.py +6 -5
  38. zea/models/__init__.py +1 -1
  39. zea/models/diffusion.py +63 -12
  40. zea/models/echonetlvh.py +1 -1
  41. zea/models/gmm.py +1 -1
  42. zea/models/lv_segmentation.py +2 -0
  43. zea/ops/__init__.py +188 -0
  44. zea/ops/base.py +442 -0
  45. zea/{keras_ops.py → ops/keras_ops.py} +2 -2
  46. zea/ops/pipeline.py +1472 -0
  47. zea/ops/tensor.py +356 -0
  48. zea/ops/ultrasound.py +890 -0
  49. zea/probes.py +2 -10
  50. zea/scan.py +35 -28
  51. zea/tools/fit_scan_cone.py +90 -160
  52. zea/tools/selection_tool.py +1 -1
  53. zea/tracking/__init__.py +16 -0
  54. zea/tracking/base.py +94 -0
  55. zea/tracking/lucas_kanade.py +474 -0
  56. zea/tracking/segmentation.py +110 -0
  57. zea/utils.py +11 -2
  58. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/METADATA +5 -1
  59. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/RECORD +62 -48
  60. zea/data/convert/matlab.py +0 -1237
  61. zea/ops.py +0 -3294
  62. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/WHEEL +0 -0
  63. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/entry_points.txt +0 -0
  64. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/licenses/LICENSE +0 -0
zea/{tensor_ops.py → func/tensor.py} +113 -69
@@ -1,6 +1,6 @@
 """Basic tensor operations implemented with the multi-backend ``keras.ops``."""
 
-from typing import Tuple, Union
+from typing import List, Tuple, Union
 
 import keras
 import numpy as np
@@ -9,7 +9,7 @@ from scipy.ndimage import _ni_support
 from scipy.ndimage._filters import _gaussian_kernel1d
 
 from zea import log
-from zea.utils import map_negative_indices
+from zea.utils import canonicalize_axis
 
 
 def split_seed(seed, n):
@@ -329,7 +329,7 @@ def _map(fun, in_axes=0, out_axes=0, map_fn=None, _use_torch_vmap=False):
     For jax, this uses the native vmap implementation.
     For other backends, this a wrapper that uses `ops.vectorized_map` under the hood.
 
-    Probably you want to use `zea.tensor_ops.vmap` instead, which uses this function
+    Probably you want to use `zea.func.vmap` instead, which uses this function
     with additional batching/chunking support.
 
     Args:
@@ -431,19 +431,20 @@ def _map(fun, in_axes=0, out_axes=0, map_fn=None, _use_torch_vmap=False):
 
 
 def vmap(
-    fun,
-    in_axes=0,
-    out_axes=0,
-    batch_size=None,
-    chunks=None,
-    fn_supports_batch=False,
-    disable_jit=False,
-    _use_torch_vmap=False,
+    fun: callable,
+    in_axes: List[Union[int, None]] | int = 0,
+    out_axes: List[Union[int, None]] | int = 0,
+    chunks: int | None = None,
+    batch_size: int | None = None,
+    fn_supports_batch: bool = False,
+    disable_jit: bool = False,
+    _use_torch_vmap: bool = False,
 ):
     """`vmap` with batching or chunking support to avoid memory issues.
 
     Basically a wrapper around `vmap` that splits the input into batches or chunks
-    to avoid memory issues with large inputs.
+    to avoid memory issues with large inputs. Choose the `batch_size` or `chunks` wisely, because
+    it pads the input to make it divisible, and then crop the output back to the original size.
 
     Args:
         fun: Function to be mapped.
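For readers skimming the new signature above, here is a minimal usage sketch of the batched `vmap` (assuming `zea.func` re-exports it, as the updated doctest imports elsewhere in this file suggest; `normalize_frame` is a hypothetical helper):

```python
import keras
from zea.func import vmap  # assumed re-export; mirrors the updated doctest imports


def normalize_frame(frame):
    """Zero-mean, unit-std normalization of a single (H, W) frame (hypothetical helper)."""
    return (frame - keras.ops.mean(frame)) / (keras.ops.std(frame) + 1e-8)


frames = keras.random.uniform((37, 128, 128))  # 37 is deliberately not divisible by 8

# Map over axis 0 in batches of 8 frames; per the docstring, the input is padded to a
# multiple of the batch size and the output is cropped back to 37 frames afterwards.
batched_normalize = vmap(normalize_frame, in_axes=0, out_axes=0, batch_size=8)
out = batched_normalize(frames)
print(out.shape)  # (37, 128, 128)
```

Per the updated docstring, `batch_size`/`chunks` mainly trade peak memory for padding overhead, so pick values that divide the mapped axis cleanly when possible.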
@@ -648,7 +649,7 @@ def pad_array_to_divisible(arr, N, axis=0, mode="constant", pad_value=None):
     return padded_array
 
 
-def interpolate_data(subsampled_data, mask, order=1, axis=-1):
+def interpolate_data(subsampled_data, mask, order=1, axis=-1, fill_mode="nearest", fill_value=0):
     """Interpolate subsampled data along a specified axis using `map_coordinates`.
 
     Args:
@@ -658,6 +659,12 @@ def interpolate_data(subsampled_data, mask, order=1, axis=-1):
             `True` where data is known.
         order (int, optional): The order of the spline interpolation. Default is `1`.
         axis (int, optional): The axis along which the data is subsampled. Default is `-1`.
+        fill_mode (str, optional): Points outside the boundaries of the input are filled
+            according to the given mode. Default is 'nearest'. For more info see
+            `keras.ops.image.map_coordinates`.
+        fill_value (float, optional): Value to use for points outside the boundaries
+            of the input if `fill_mode` is 'constant'. Default is `0`. For more info see
+            `keras.ops.image.map_coordinates`.
 
     Returns:
         ndarray: The data interpolated back to the original grid.
@@ -722,6 +729,8 @@ def interpolate_data(subsampled_data, mask, order=1, axis=-1):
         subsampled_data,
         interp_coords,
         order=order,
+        fill_mode=fill_mode,
+        fill_value=fill_value,
     )
 
     interpolated_data = ops.reshape(interpolated_data, -1)
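To see what the newly threaded `fill_mode` and `fill_value` arguments control, here is a small sketch against `keras.ops.image.map_coordinates`, the function `interpolate_data` delegates to (sample values in the comments are approximate):

```python
from keras import ops

image = ops.convert_to_tensor([[0.0, 1.0, 2.0],
                               [3.0, 4.0, 5.0]])

# Three sample points as (row, col) coordinates; the last one lies outside the image.
coords = ops.convert_to_tensor([[0.0, 1.0, 1.0],   # rows
                                [0.0, 1.5, 4.0]])  # cols (4.0 is out of bounds)

# 'nearest' (the new default) clamps out-of-bounds samples to the nearest edge value...
nearest = ops.image.map_coordinates(image, coords, order=1, fill_mode="nearest")
# ...while 'constant' substitutes `fill_value` for them.
constant = ops.image.map_coordinates(image, coords, order=1, fill_mode="constant", fill_value=0.0)
print(nearest)   # approx. [0.0, 4.5, 5.0]
print(constant)  # approx. [0.0, 4.5, 0.0]
```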
@@ -836,7 +845,7 @@ def stack_volume_data_along_axis(data, batch_axis: int, stack_axis: int, number:
         .. doctest::
 
             >>> import keras
-            >>> from zea.tensor_ops import stack_volume_data_along_axis
+            >>> from zea.func import stack_volume_data_along_axis
 
             >>> data = keras.random.uniform((10, 20, 30))
             >>> # stacking along 1st axis with 2 frames per block
@@ -880,7 +889,7 @@ def split_volume_data_from_axis(data, batch_axis: int, stack_axis: int, number:
         .. doctest::
 
             >>> import keras
-            >>> from zea.tensor_ops import split_volume_data_from_axis
+            >>> from zea.func import split_volume_data_from_axis
 
             >>> data = keras.random.uniform((20, 10, 30))
             >>> split_data = split_volume_data_from_axis(data, 0, 1, 2, 2)
@@ -1003,7 +1012,7 @@ def check_patches_fit(
     Example:
         .. doctest::
 
-            >>> from zea.tensor_ops import check_patches_fit
+            >>> from zea.func import check_patches_fit
             >>> image_shape = (10, 10)
             >>> patch_shape = (4, 4)
             >>> overlap = (2, 2)
@@ -1071,7 +1080,7 @@ def images_to_patches(
         .. doctest::
 
             >>> import keras
-            >>> from zea.tensor_ops import images_to_patches
+            >>> from zea.func import images_to_patches
 
             >>> images = keras.random.uniform((2, 8, 8, 3))
             >>> patches = images_to_patches(images, patch_shape=(4, 4), overlap=(2, 2))
@@ -1157,7 +1166,7 @@ def patches_to_images(
         .. doctest::
 
             >>> import keras
-            >>> from zea.tensor_ops import patches_to_images
+            >>> from zea.func import patches_to_images
 
             >>> patches = keras.random.uniform((2, 3, 3, 4, 4, 3))
             >>> images = patches_to_images(patches, image_shape=(8, 8, 3), overlap=(2, 2))
@@ -1245,7 +1254,7 @@ def reshape_axis(data, newshape: tuple, axis: int):
         .. doctest::
 
             >>> import keras
-            >>> from zea.tensor_ops import reshape_axis
+            >>> from zea.func import reshape_axis
 
             >>> data = keras.random.uniform((3, 4, 5))
             >>> newshape = (2, 2)
@@ -1253,7 +1262,7 @@ def reshape_axis(data, newshape: tuple, axis: int):
             >>> reshaped_data.shape
             (3, 2, 2, 5)
     """
-    axis = map_negative_indices([axis], data.ndim)[0]
+    axis = canonicalize_axis(axis, data.ndim)
     shape = list(ops.shape(data)) # list
     shape = shape[:axis] + list(newshape) + shape[axis + 1 :]
     return ops.reshape(data, shape)
@@ -1512,7 +1521,8 @@ def sinc(x, eps=keras.config.epsilon()):
 def apply_along_axis(func1d, axis, arr, *args, **kwargs):
     """Apply a function to 1D array slices along an axis.
 
-    Keras implementation of numpy.apply_along_axis using keras.ops.vectorized_map.
+    Keras implementation of ``numpy.apply_along_axis``. Copies the ``jax`` implementation, which
+    uses ``vmap`` to vectorize the function application along the specified axis.
 
     Args:
         func1d: A callable function with signature ``func1d(arr, /, *args, **kwargs)``
@@ -1529,56 +1539,13 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
     # Convert to keras tensor
     arr = ops.convert_to_tensor(arr)
 
-    # Get array dimensions
-    num_dims = len(arr.shape)
-
-    # Canonicalize axis (handle negative indices)
-    if axis < 0:
-        axis = num_dims + axis
-
-    if axis < 0 or axis >= num_dims:
-        raise ValueError(f"axis {axis} is out of bounds for array of dimension {num_dims}")
-
-    # Create a wrapper function that applies func1d with the additional arguments
-    def func(slice_arr):
-        return func1d(slice_arr, *args, **kwargs)
-
-    # Recursively build up vectorized maps following the JAX pattern
-    # For dimensions after the target axis (right side)
+    num_dims = ops.ndim(arr)
+    axis = canonicalize_axis(axis, num_dims)
+    func = lambda arr: func1d(arr, *args, **kwargs)
     for i in range(1, num_dims - axis):
-        prev_func = func
-
-        def make_func(f, dim_offset):
-            def vectorized_func(x):
-                # Move the dimension we want to map over to the front
-                perm = list(range(len(x.shape)))
-                perm[0], perm[dim_offset] = perm[dim_offset], perm[0]
-                x_moved = ops.transpose(x, perm)
-                result = vectorized_map(f, x_moved)
-                # Move the result dimension back if needed
-                if len(result.shape) > 0:
-                    result_perm = list(range(len(result.shape)))
-                    if len(result_perm) > dim_offset:
-                        result_perm[0], result_perm[dim_offset] = (
-                            result_perm[dim_offset],
-                            result_perm[0],
-                        )
-                    result = ops.transpose(result, result_perm)
-                return result
-
-            return vectorized_func
-
-        func = make_func(prev_func, i)
-
-    # For dimensions before the target axis (left side)
+        func = vmap(func, in_axes=i, out_axes=-1, _use_torch_vmap=True)
     for i in range(axis):
-        prev_func = func
-
-        def make_func(f):
-            return lambda x: vectorized_map(f, x)
-
-        func = make_func(prev_func)
-
+        func = vmap(func, in_axes=0, out_axes=0, _use_torch_vmap=True)
     return func(arr)
 
 
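A short usage sketch of `apply_along_axis` after the `vmap`-based rewrite above (assuming it is importable from `zea.func` like the other helpers in this module; `peak_to_peak` is a hypothetical 1-D reduction):

```python
import keras
from keras import ops
from zea.func import apply_along_axis  # assumed re-export, mirroring the doctests above


def peak_to_peak(x):
    """Range (max - min) of a 1-D slice (hypothetical example reduction)."""
    return ops.max(x) - ops.min(x)


data = keras.random.uniform((4, 5, 6))

# Apply the 1-D reduction along axis=1; as with numpy, a scalar-returning func1d
# drops the mapped axis from the output shape.
result = apply_along_axis(peak_to_peak, 1, data)
print(result.shape)  # (4, 6)
```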
@@ -1595,6 +1562,8 @@ def correlate(x, y, mode="full"):
         y: np.ndarray (complex or real)
         mode: "full", "valid", or "same"
     """
+    if keras.backend.backend() == "jax":
+        return ops.correlate(x, y, mode=mode)
     x = ops.convert_to_tensor(x)
     y = ops.convert_to_tensor(y)
 
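The new `jax` branch above simply forwards to `keras.ops.correlate`; this sketch shows what the three `mode` options return for small 1-D signals:

```python
from keras import ops

x = ops.convert_to_tensor([1.0, 2.0, 3.0, 4.0])
y = ops.convert_to_tensor([1.0, 0.0, -1.0])

full = ops.correlate(x, y, mode="full")    # length len(x) + len(y) - 1 = 6
same = ops.correlate(x, y, mode="same")    # length max(len(x), len(y)) = 4
valid = ops.correlate(x, y, mode="valid")  # length len(x) - len(y) + 1 = 2
print(ops.shape(full), ops.shape(same), ops.shape(valid))
```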
@@ -1656,6 +1625,57 @@ def correlate(x, y, mode="full"):
     return ops.real(complex_tensor)
 
 
+def find_contour(binary_mask):
+    """Extract contour/boundary points from a binary mask using edge detection.
+
+    This function finds the boundary pixels of objects in a binary mask by detecting
+    pixels that have at least one neighbor with a different value (using 4-connectivity).
+
+    Args:
+        binary_mask: Binary mask tensor of shape (H, W) with values 0 or 1.
+
+    Returns:
+        Boundary points as tensor of shape (N, 2) in (row, col) format.
+        Returns empty tensor of shape (0, 2) if no boundaries are found.
+
+    Example:
+        .. doctest::
+
+            >>> from zea.func import find_contour
+            >>> import keras
+            >>> mask = keras.ops.zeros((10, 10))
+            >>> mask = keras.ops.scatter_update(
+            ...     mask, [[3, 3], [3, 4], [4, 3], [4, 4]], [1, 1, 1, 1]
+            ... )
+            >>> contour = find_contour(mask)
+            >>> contour.shape
+            (4, 2)
+    """
+    # Pad the mask to handle edges
+    padded = ops.pad(binary_mask, [[1, 1], [1, 1]], mode="constant", constant_values=0.0)
+
+    # Check 4-connectivity (up, down, left, right)
+    is_edge = (
+        (binary_mask != padded[:-2, 1:-1])  # top neighbor different
+        | (binary_mask != padded[2:, 1:-1])  # bottom neighbor different
+        | (binary_mask != padded[1:-1, :-2])  # left neighbor different
+        | (binary_mask != padded[1:-1, 2:])  # right neighbor different
+    )
+
+    # Only keep edges that are part of the foreground (binary_mask == 1)
+    is_boundary = is_edge & ops.cast(binary_mask, "bool")
+
+    boundary_indices = ops.where(is_boundary)
+
+    if ops.shape(boundary_indices[0])[0] > 0:
+        boundary_points = ops.stack(boundary_indices, axis=1)
+        boundary_points = ops.cast(boundary_points, "float32")
+    else:
+        boundary_points = ops.zeros((0, 2), dtype="float32")
+
+    return boundary_points
+
+
 def translate(array, range_from=None, range_to=(0, 255)):
     """Map values in array from one range to other.
 
@@ -1680,3 +1700,27 @@ def translate(array, range_from=None, range_to=(0, 255)):
 
     # Convert the 0-1 range into a value in the right range.
     return right_min + (value_scaled * (right_max - right_min))
+
+
+def normalize(data, output_range, input_range=None):
+    """Normalize data to a given range.
+
+    Equivalent to `translate` with clipping.
+
+    Args:
+        data (ops.Tensor): Input data to normalize.
+        output_range (tuple): Range to which data should be mapped, e.g., (0, 1).
+        input_range (tuple, optional): Range of input data.
+            If None, the range will be computed from the data.
+            Defaults to None.
+    """
+    if input_range is None:
+        input_range = (None, None)
+    minval, maxval = input_range
+    if minval is None:
+        minval = ops.min(data)
+    if maxval is None:
+        maxval = ops.max(data)
+    data = ops.clip(data, minval, maxval)
+    normalized_data = translate(data, (minval, maxval), output_range)
+    return normalized_data
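Finally, a minimal usage sketch of the new `normalize` helper (assuming it is re-exported from `zea.func` like the other utilities in this file):

```python
import keras
from keras import ops
from zea.func import normalize  # assumed re-export, as with the other helpers here

data = keras.random.uniform((8, 8), minval=-60.0, maxval=0.0)  # e.g. values in dB

# Map a fixed dB window to [0, 1]; values outside input_range are clipped first.
img = normalize(data, output_range=(0.0, 1.0), input_range=(-40.0, 0.0))
print(float(ops.min(img)), float(ops.max(img)))  # both within [0.0, 1.0]

# With input_range=None (the default) the min/max are taken from the data itself.
img_auto = normalize(data, output_range=(0.0, 255.0))
```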