zea 0.0.7__py3-none-any.whl → 0.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +3 -3
- zea/agent/masks.py +2 -2
- zea/agent/selection.py +3 -3
- zea/backend/__init__.py +1 -1
- zea/backend/tensorflow/dataloader.py +1 -5
- zea/beamform/beamformer.py +4 -2
- zea/beamform/pfield.py +2 -2
- zea/beamform/pixelgrid.py +1 -1
- zea/data/__init__.py +0 -9
- zea/data/augmentations.py +222 -29
- zea/data/convert/__init__.py +1 -6
- zea/data/convert/__main__.py +164 -0
- zea/data/convert/camus.py +106 -40
- zea/data/convert/echonet.py +184 -83
- zea/data/convert/echonetlvh/README.md +2 -3
- zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +174 -103
- zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
- zea/data/convert/echonetlvh/precompute_crop.py +43 -64
- zea/data/convert/picmus.py +37 -40
- zea/data/convert/utils.py +86 -0
- zea/data/convert/verasonics.py +1247 -0
- zea/data/data_format.py +124 -6
- zea/data/dataloader.py +12 -7
- zea/data/datasets.py +109 -70
- zea/data/file.py +119 -82
- zea/data/file_operations.py +496 -0
- zea/data/preset_utils.py +2 -2
- zea/display.py +8 -9
- zea/doppler.py +5 -5
- zea/func/__init__.py +109 -0
- zea/{tensor_ops.py → func/tensor.py} +113 -69
- zea/func/ultrasound.py +500 -0
- zea/internal/_generate_keras_ops.py +5 -5
- zea/internal/checks.py +6 -12
- zea/internal/operators.py +4 -0
- zea/io_lib.py +108 -160
- zea/metrics.py +6 -5
- zea/models/__init__.py +1 -1
- zea/models/diffusion.py +63 -12
- zea/models/echonetlvh.py +1 -1
- zea/models/gmm.py +1 -1
- zea/models/lv_segmentation.py +2 -0
- zea/ops/__init__.py +188 -0
- zea/ops/base.py +442 -0
- zea/{keras_ops.py → ops/keras_ops.py} +2 -2
- zea/ops/pipeline.py +1472 -0
- zea/ops/tensor.py +356 -0
- zea/ops/ultrasound.py +890 -0
- zea/probes.py +2 -10
- zea/scan.py +35 -28
- zea/tools/fit_scan_cone.py +90 -160
- zea/tools/selection_tool.py +1 -1
- zea/tracking/__init__.py +16 -0
- zea/tracking/base.py +94 -0
- zea/tracking/lucas_kanade.py +474 -0
- zea/tracking/segmentation.py +110 -0
- zea/utils.py +11 -2
- {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/METADATA +5 -1
- {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/RECORD +62 -48
- zea/data/convert/matlab.py +0 -1237
- zea/ops.py +0 -3294
- {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/WHEEL +0 -0
- {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/entry_points.txt +0 -0
- {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Basic tensor operations implemented with the multi-backend ``keras.ops``."""
|
|
2
2
|
|
|
3
|
-
from typing import Tuple, Union
|
|
3
|
+
from typing import List, Tuple, Union
|
|
4
4
|
|
|
5
5
|
import keras
|
|
6
6
|
import numpy as np
|
|
@@ -9,7 +9,7 @@ from scipy.ndimage import _ni_support
|
|
|
9
9
|
from scipy.ndimage._filters import _gaussian_kernel1d
|
|
10
10
|
|
|
11
11
|
from zea import log
|
|
12
|
-
from zea.utils import
|
|
12
|
+
from zea.utils import canonicalize_axis
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
def split_seed(seed, n):
|
|
@@ -329,7 +329,7 @@ def _map(fun, in_axes=0, out_axes=0, map_fn=None, _use_torch_vmap=False):
|
|
|
329
329
|
For jax, this uses the native vmap implementation.
|
|
330
330
|
For other backends, this is a wrapper that uses `ops.vectorized_map` under the hood.
|
|
331
331
|
|
|
332
|
-
Probably you want to use `zea.
|
|
332
|
+
Probably you want to use `zea.func.vmap` instead, which uses this function
|
|
333
333
|
with additional batching/chunking support.
|
|
334
334
|
|
|
335
335
|
Args:
|
|
@@ -431,19 +431,20 @@ def _map(fun, in_axes=0, out_axes=0, map_fn=None, _use_torch_vmap=False):
|
|
|
431
431
|
|
|
432
432
|
|
|
433
433
|
def vmap(
|
|
434
|
-
fun,
|
|
435
|
-
in_axes=0,
|
|
436
|
-
out_axes=0,
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
fn_supports_batch=False,
|
|
440
|
-
disable_jit=False,
|
|
441
|
-
_use_torch_vmap=False,
|
|
434
|
+
fun: callable,
|
|
435
|
+
in_axes: List[Union[int, None]] | int = 0,
|
|
436
|
+
out_axes: List[Union[int, None]] | int = 0,
|
|
437
|
+
chunks: int | None = None,
|
|
438
|
+
batch_size: int | None = None,
|
|
439
|
+
fn_supports_batch: bool = False,
|
|
440
|
+
disable_jit: bool = False,
|
|
441
|
+
_use_torch_vmap: bool = False,
|
|
442
442
|
):
|
|
443
443
|
"""`vmap` with batching or chunking support to avoid memory issues.
|
|
444
444
|
|
|
445
445
|
Basically a wrapper around `vmap` that splits the input into batches or chunks
|
|
446
|
-
to avoid memory issues with large inputs.
|
|
446
|
+
to avoid memory issues with large inputs. Choose the `batch_size` or `chunks` wisely, because
|
|
447
|
+
it pads the input to make it divisible and then crops the output back to the original size.
|
|
447
448
|
|
|
448
449
|
Args:
|
|
449
450
|
fun: Function to be mapped.
|
|
@@ -648,7 +649,7 @@ def pad_array_to_divisible(arr, N, axis=0, mode="constant", pad_value=None):
|
|
|
648
649
|
return padded_array
|
|
649
650
|
|
|
650
651
|
|
|
651
|
-
def interpolate_data(subsampled_data, mask, order=1, axis=-1):
|
|
652
|
+
def interpolate_data(subsampled_data, mask, order=1, axis=-1, fill_mode="nearest", fill_value=0):
|
|
652
653
|
"""Interpolate subsampled data along a specified axis using `map_coordinates`.
|
|
653
654
|
|
|
654
655
|
Args:
|
|
@@ -658,6 +659,12 @@ def interpolate_data(subsampled_data, mask, order=1, axis=-1):
|
|
|
658
659
|
`True` where data is known.
|
|
659
660
|
order (int, optional): The order of the spline interpolation. Default is `1`.
|
|
660
661
|
axis (int, optional): The axis along which the data is subsampled. Default is `-1`.
|
|
662
|
+
fill_mode (str, optional): Points outside the boundaries of the input are filled
|
|
663
|
+
according to the given mode. Default is 'nearest'. For more info see
|
|
664
|
+
`keras.ops.image.map_coordinates`.
|
|
665
|
+
fill_value (float, optional): Value to use for points outside the boundaries
|
|
666
|
+
of the input if `fill_mode` is 'constant'. Default is `0`. For more info see
|
|
667
|
+
`keras.ops.image.map_coordinates`.
|
|
661
668
|
|
|
662
669
|
Returns:
|
|
663
670
|
ndarray: The data interpolated back to the original grid.
|
|
@@ -722,6 +729,8 @@ def interpolate_data(subsampled_data, mask, order=1, axis=-1):
|
|
|
722
729
|
subsampled_data,
|
|
723
730
|
interp_coords,
|
|
724
731
|
order=order,
|
|
732
|
+
fill_mode=fill_mode,
|
|
733
|
+
fill_value=fill_value,
|
|
725
734
|
)
|
|
726
735
|
|
|
727
736
|
interpolated_data = ops.reshape(interpolated_data, -1)
|
|
@@ -836,7 +845,7 @@ def stack_volume_data_along_axis(data, batch_axis: int, stack_axis: int, number:
|
|
|
836
845
|
.. doctest::
|
|
837
846
|
|
|
838
847
|
>>> import keras
|
|
839
|
-
>>> from zea.
|
|
848
|
+
>>> from zea.func import stack_volume_data_along_axis
|
|
840
849
|
|
|
841
850
|
>>> data = keras.random.uniform((10, 20, 30))
|
|
842
851
|
>>> # stacking along 1st axis with 2 frames per block
|
|
@@ -880,7 +889,7 @@ def split_volume_data_from_axis(data, batch_axis: int, stack_axis: int, number:
|
|
|
880
889
|
.. doctest::
|
|
881
890
|
|
|
882
891
|
>>> import keras
|
|
883
|
-
>>> from zea.
|
|
892
|
+
>>> from zea.func import split_volume_data_from_axis
|
|
884
893
|
|
|
885
894
|
>>> data = keras.random.uniform((20, 10, 30))
|
|
886
895
|
>>> split_data = split_volume_data_from_axis(data, 0, 1, 2, 2)
|
|
@@ -1003,7 +1012,7 @@ def check_patches_fit(
|
|
|
1003
1012
|
Example:
|
|
1004
1013
|
.. doctest::
|
|
1005
1014
|
|
|
1006
|
-
>>> from zea.
|
|
1015
|
+
>>> from zea.func import check_patches_fit
|
|
1007
1016
|
>>> image_shape = (10, 10)
|
|
1008
1017
|
>>> patch_shape = (4, 4)
|
|
1009
1018
|
>>> overlap = (2, 2)
|
|
@@ -1071,7 +1080,7 @@ def images_to_patches(
|
|
|
1071
1080
|
.. doctest::
|
|
1072
1081
|
|
|
1073
1082
|
>>> import keras
|
|
1074
|
-
>>> from zea.
|
|
1083
|
+
>>> from zea.func import images_to_patches
|
|
1075
1084
|
|
|
1076
1085
|
>>> images = keras.random.uniform((2, 8, 8, 3))
|
|
1077
1086
|
>>> patches = images_to_patches(images, patch_shape=(4, 4), overlap=(2, 2))
|
|
@@ -1157,7 +1166,7 @@ def patches_to_images(
|
|
|
1157
1166
|
.. doctest::
|
|
1158
1167
|
|
|
1159
1168
|
>>> import keras
|
|
1160
|
-
>>> from zea.
|
|
1169
|
+
>>> from zea.func import patches_to_images
|
|
1161
1170
|
|
|
1162
1171
|
>>> patches = keras.random.uniform((2, 3, 3, 4, 4, 3))
|
|
1163
1172
|
>>> images = patches_to_images(patches, image_shape=(8, 8, 3), overlap=(2, 2))
|
|
@@ -1245,7 +1254,7 @@ def reshape_axis(data, newshape: tuple, axis: int):
|
|
|
1245
1254
|
.. doctest::
|
|
1246
1255
|
|
|
1247
1256
|
>>> import keras
|
|
1248
|
-
>>> from zea.
|
|
1257
|
+
>>> from zea.func import reshape_axis
|
|
1249
1258
|
|
|
1250
1259
|
>>> data = keras.random.uniform((3, 4, 5))
|
|
1251
1260
|
>>> newshape = (2, 2)
|
|
@@ -1253,7 +1262,7 @@ def reshape_axis(data, newshape: tuple, axis: int):
|
|
|
1253
1262
|
>>> reshaped_data.shape
|
|
1254
1263
|
(3, 2, 2, 5)
|
|
1255
1264
|
"""
|
|
1256
|
-
axis =
|
|
1265
|
+
axis = canonicalize_axis(axis, data.ndim)
|
|
1257
1266
|
shape = list(ops.shape(data)) # list
|
|
1258
1267
|
shape = shape[:axis] + list(newshape) + shape[axis + 1 :]
|
|
1259
1268
|
return ops.reshape(data, shape)
|
|
@@ -1512,7 +1521,8 @@ def sinc(x, eps=keras.config.epsilon()):
|
|
|
1512
1521
|
def apply_along_axis(func1d, axis, arr, *args, **kwargs):
|
|
1513
1522
|
"""Apply a function to 1D array slices along an axis.
|
|
1514
1523
|
|
|
1515
|
-
Keras implementation of numpy.apply_along_axis
|
|
1524
|
+
Keras implementation of ``numpy.apply_along_axis``. Copies the ``jax`` implementation, which
|
|
1525
|
+
uses ``vmap`` to vectorize the function application along the specified axis.
|
|
1516
1526
|
|
|
1517
1527
|
Args:
|
|
1518
1528
|
func1d: A callable function with signature ``func1d(arr, /, *args, **kwargs)``
|
|
@@ -1529,56 +1539,13 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
|
|
|
1529
1539
|
# Convert to keras tensor
|
|
1530
1540
|
arr = ops.convert_to_tensor(arr)
|
|
1531
1541
|
|
|
1532
|
-
|
|
1533
|
-
|
|
1534
|
-
|
|
1535
|
-
# Canonicalize axis (handle negative indices)
|
|
1536
|
-
if axis < 0:
|
|
1537
|
-
axis = num_dims + axis
|
|
1538
|
-
|
|
1539
|
-
if axis < 0 or axis >= num_dims:
|
|
1540
|
-
raise ValueError(f"axis {axis} is out of bounds for array of dimension {num_dims}")
|
|
1541
|
-
|
|
1542
|
-
# Create a wrapper function that applies func1d with the additional arguments
|
|
1543
|
-
def func(slice_arr):
|
|
1544
|
-
return func1d(slice_arr, *args, **kwargs)
|
|
1545
|
-
|
|
1546
|
-
# Recursively build up vectorized maps following the JAX pattern
|
|
1547
|
-
# For dimensions after the target axis (right side)
|
|
1542
|
+
num_dims = ops.ndim(arr)
|
|
1543
|
+
axis = canonicalize_axis(axis, num_dims)
|
|
1544
|
+
func = lambda arr: func1d(arr, *args, **kwargs)
|
|
1548
1545
|
for i in range(1, num_dims - axis):
|
|
1549
|
-
|
|
1550
|
-
|
|
1551
|
-
def make_func(f, dim_offset):
|
|
1552
|
-
def vectorized_func(x):
|
|
1553
|
-
# Move the dimension we want to map over to the front
|
|
1554
|
-
perm = list(range(len(x.shape)))
|
|
1555
|
-
perm[0], perm[dim_offset] = perm[dim_offset], perm[0]
|
|
1556
|
-
x_moved = ops.transpose(x, perm)
|
|
1557
|
-
result = vectorized_map(f, x_moved)
|
|
1558
|
-
# Move the result dimension back if needed
|
|
1559
|
-
if len(result.shape) > 0:
|
|
1560
|
-
result_perm = list(range(len(result.shape)))
|
|
1561
|
-
if len(result_perm) > dim_offset:
|
|
1562
|
-
result_perm[0], result_perm[dim_offset] = (
|
|
1563
|
-
result_perm[dim_offset],
|
|
1564
|
-
result_perm[0],
|
|
1565
|
-
)
|
|
1566
|
-
result = ops.transpose(result, result_perm)
|
|
1567
|
-
return result
|
|
1568
|
-
|
|
1569
|
-
return vectorized_func
|
|
1570
|
-
|
|
1571
|
-
func = make_func(prev_func, i)
|
|
1572
|
-
|
|
1573
|
-
# For dimensions before the target axis (left side)
|
|
1546
|
+
func = vmap(func, in_axes=i, out_axes=-1, _use_torch_vmap=True)
|
|
1574
1547
|
for i in range(axis):
|
|
1575
|
-
|
|
1576
|
-
|
|
1577
|
-
def make_func(f):
|
|
1578
|
-
return lambda x: vectorized_map(f, x)
|
|
1579
|
-
|
|
1580
|
-
func = make_func(prev_func)
|
|
1581
|
-
|
|
1548
|
+
func = vmap(func, in_axes=0, out_axes=0, _use_torch_vmap=True)
|
|
1582
1549
|
return func(arr)
|
|
1583
1550
|
|
|
1584
1551
|
|
|
@@ -1595,6 +1562,8 @@ def correlate(x, y, mode="full"):
|
|
|
1595
1562
|
y: np.ndarray (complex or real)
|
|
1596
1563
|
mode: "full", "valid", or "same"
|
|
1597
1564
|
"""
|
|
1565
|
+
if keras.backend.backend() == "jax":
|
|
1566
|
+
return ops.correlate(x, y, mode=mode)
|
|
1598
1567
|
x = ops.convert_to_tensor(x)
|
|
1599
1568
|
y = ops.convert_to_tensor(y)
|
|
1600
1569
|
|
|
@@ -1656,6 +1625,57 @@ def correlate(x, y, mode="full"):
|
|
|
1656
1625
|
return ops.real(complex_tensor)
|
|
1657
1626
|
|
|
1658
1627
|
|
|
1628
|
+
def find_contour(binary_mask):
    """Extract contour/boundary points from a binary mask using edge detection.

    A foreground pixel is a boundary pixel when at least one of its four direct
    neighbors (up, down, left, right; 4-connectivity) has a different value.
    The mask is zero-padded, so foreground pixels on the image border are
    always reported as boundary pixels.

    Args:
        binary_mask: Binary mask tensor of shape (H, W). Any nonzero value is
            treated as foreground, so 0/1, 0/255 and boolean masks all work.

    Returns:
        Boundary points as a float32 tensor of shape (N, 2) in (row, col) format.
        Returns an empty tensor of shape (0, 2) if no boundaries are found.

    Raises:
        ValueError: If ``binary_mask`` is not two-dimensional.

    Example:
        .. doctest::

            >>> from zea.func import find_contour
            >>> import keras
            >>> mask = keras.ops.zeros((10, 10))
            >>> mask = keras.ops.scatter_update(
            ...     mask, [[3, 3], [3, 4], [4, 3], [4, 4]], [1, 1, 1, 1]
            ... )
            >>> contour = find_contour(mask)
            >>> contour.shape
            (4, 2)
    """
    mask = ops.convert_to_tensor(binary_mask)
    num_dims = len(ops.shape(mask))
    if num_dims != 2:
        raise ValueError(
            f"find_contour expects a 2D mask of shape (H, W), got a tensor with {num_dims} dims."
        )

    # Normalize to a float 0/1 mask. Otherwise non-0/1 foreground values (e.g. 0/255
    # or label maps) would create spurious "edges", and bool masks could not be
    # padded with a float constant.
    mask = ops.cast(ops.not_equal(mask, 0), "float32")

    # Zero-pad by one pixel so every pixel has four neighbors to compare against.
    padded = ops.pad(mask, [[1, 1], [1, 1]], mode="constant", constant_values=0.0)

    # Check 4-connectivity (up, down, left, right).
    is_edge = (
        (mask != padded[:-2, 1:-1])  # top neighbor different
        | (mask != padded[2:, 1:-1])  # bottom neighbor different
        | (mask != padded[1:-1, :-2])  # left neighbor different
        | (mask != padded[1:-1, 2:])  # right neighbor different
    )

    # Only keep edges that are part of the foreground.
    is_boundary = is_edge & ops.cast(mask, "bool")

    # `ops.where` with a single argument returns one index vector per dimension
    # (row indices, col indices). Stacking them along axis 1 gives (N, 2). When
    # nothing is found, both vectors are empty and the stack is already (0, 2),
    # so no data-dependent Python branch is needed.
    boundary_indices = ops.where(is_boundary)
    boundary_points = ops.stack(boundary_indices, axis=1)
    return ops.cast(boundary_points, "float32")
|
|
1677
|
+
|
|
1678
|
+
|
|
1659
1679
|
def translate(array, range_from=None, range_to=(0, 255)):
|
|
1660
1680
|
"""Map values in array from one range to other.
|
|
1661
1681
|
|
|
@@ -1680,3 +1700,27 @@ def translate(array, range_from=None, range_to=(0, 255)):
|
|
|
1680
1700
|
|
|
1681
1701
|
# Convert the 0-1 range into a value in the right range.
|
|
1682
1702
|
return right_min + (value_scaled * (right_max - right_min))
|
|
1703
|
+
|
|
1704
|
+
|
|
1705
|
+
def normalize(data, output_range, input_range=None):
    """Rescale data into a target range, clipping anything outside the input range.

    This is ``translate`` preceded by a clip to the input range. Missing bounds
    are taken from the data itself.

    Args:
        data (ops.Tensor): Input data to normalize.
        output_range (tuple): Range to which data should be mapped, e.g., (0, 1).
        input_range (tuple, optional): ``(min, max)`` of the input data. Either
            entry (or the whole tuple) may be None, in which case the missing
            bound is computed from ``data``. Defaults to None.

    Returns:
        The clipped data, linearly mapped from the input range onto ``output_range``.
    """
    lower, upper = (None, None) if input_range is None else input_range

    # Fill in any bound the caller did not supply from the data itself.
    if lower is None:
        lower = ops.min(data)
    if upper is None:
        upper = ops.max(data)

    clipped = ops.clip(data, lower, upper)
    return translate(clipped, (lower, upper), output_range)
|