zea 0.0.7__py3-none-any.whl → 0.0.9__py3-none-any.whl
- zea/__init__.py +3 -3
- zea/agent/masks.py +2 -2
- zea/agent/selection.py +3 -3
- zea/backend/__init__.py +1 -1
- zea/backend/tensorflow/dataloader.py +1 -5
- zea/beamform/beamformer.py +4 -2
- zea/beamform/pfield.py +2 -2
- zea/beamform/pixelgrid.py +1 -1
- zea/data/__init__.py +0 -9
- zea/data/augmentations.py +222 -29
- zea/data/convert/__init__.py +1 -6
- zea/data/convert/__main__.py +164 -0
- zea/data/convert/camus.py +106 -40
- zea/data/convert/echonet.py +184 -83
- zea/data/convert/echonetlvh/README.md +2 -3
- zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +174 -103
- zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
- zea/data/convert/echonetlvh/precompute_crop.py +43 -64
- zea/data/convert/picmus.py +37 -40
- zea/data/convert/utils.py +86 -0
- zea/data/convert/verasonics.py +1247 -0
- zea/data/data_format.py +124 -6
- zea/data/dataloader.py +12 -7
- zea/data/datasets.py +109 -70
- zea/data/file.py +119 -82
- zea/data/file_operations.py +496 -0
- zea/data/preset_utils.py +2 -2
- zea/display.py +8 -9
- zea/doppler.py +5 -5
- zea/func/__init__.py +109 -0
- zea/{tensor_ops.py → func/tensor.py} +113 -69
- zea/func/ultrasound.py +500 -0
- zea/internal/_generate_keras_ops.py +5 -5
- zea/internal/checks.py +6 -12
- zea/internal/operators.py +4 -0
- zea/io_lib.py +108 -160
- zea/metrics.py +6 -5
- zea/models/__init__.py +1 -1
- zea/models/diffusion.py +63 -12
- zea/models/echonetlvh.py +1 -1
- zea/models/gmm.py +1 -1
- zea/models/lv_segmentation.py +2 -0
- zea/ops/__init__.py +188 -0
- zea/ops/base.py +442 -0
- zea/{keras_ops.py → ops/keras_ops.py} +2 -2
- zea/ops/pipeline.py +1472 -0
- zea/ops/tensor.py +356 -0
- zea/ops/ultrasound.py +890 -0
- zea/probes.py +2 -10
- zea/scan.py +35 -28
- zea/tools/fit_scan_cone.py +90 -160
- zea/tools/selection_tool.py +1 -1
- zea/tracking/__init__.py +16 -0
- zea/tracking/base.py +94 -0
- zea/tracking/lucas_kanade.py +474 -0
- zea/tracking/segmentation.py +110 -0
- zea/utils.py +11 -2
- {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/METADATA +5 -1
- {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/RECORD +62 -48
- zea/data/convert/matlab.py +0 -1237
- zea/ops.py +0 -3294
- {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/WHEEL +0 -0
- {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/entry_points.txt +0 -0
- {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/licenses/LICENSE +0 -0
zea/__init__.py
CHANGED
```diff
@@ -7,7 +7,7 @@ from . import log

 # dynamically add __version__ attribute (see pyproject.toml)
 # __version__ = __import__("importlib.metadata").metadata.version(__package__)
-__version__ = "0.0.7"
+__version__ = "0.0.9"


 def _bootstrap_backend():
@@ -89,12 +89,12 @@ from . import (
     beamform,
     data,
     display,
+    func,
     io_lib,
-    keras_ops,
     metrics,
     models,
+    ops,
     simulator,
-    tensor_ops,
     utils,
     visualize,
 )
```
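The removed `zea.tensor_ops` and `zea.keras_ops` modules are superseded by the new `zea.func` and `zea.ops` packages (see the renames `zea/tensor_ops.py → zea/func/tensor.py` and `zea/keras_ops.py → zea/ops/keras_ops.py` in the file list above). A minimal migration sketch for downstream code; only the module paths visible in this diff are certain, and the specific helper names are taken from the hunks below, otherwise assumed unchanged:

```python
# Sketch of the import-path migration from zea 0.0.7 to 0.0.9.
# Old (0.0.7):
#   from zea import tensor_ops, keras_ops
#   y = tensor_ops.translate(x, range_from, range_to)  # hypothetical call
# New (0.0.9):
from zea.func import tensor  # tensor helpers: translate, sinc, vmap, nonzero, ...
from zea import ops          # pipeline operations, e.g. zea.ops.Abs

# y = tensor.translate(x, range_from, range_to)        # same helper, new module path
```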
zea/agent/masks.py
CHANGED
```diff
@@ -9,8 +9,8 @@ from typing import List
 import keras
 from keras import ops

-from zea import tensor_ops
 from zea.agent.gumbel import hard_straight_through
+from zea.func.tensor import nonzero

 _DEFAULT_DTYPE = "bool"

@@ -56,7 +56,7 @@ def k_hot_to_indices(selected_lines, n_actions: int, fill_value=-1):

     # Find nonzero indices for each frame
     def get_nonzero(row):
-        return tensor_ops.nonzero(row > 0, size=n_actions, fill_value=fill_value)[0]
+        return nonzero(row > 0, size=n_actions, fill_value=fill_value)[0]

     indices = ops.vectorized_map(get_nonzero, selected_lines)
     return indices
```
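For context, `k_hot_to_indices` converts a k-hot line-selection mask into a fixed-size, padded index tensor via the `nonzero` helper shown above (the static `size` argument keeps shapes traceable). A small eager sketch of the same idea; the padding logic here is illustrative and not the `zea.func.tensor.nonzero` implementation:

```python
import numpy as np
from keras import ops

def k_hot_to_indices_sketch(selected_lines, n_actions, fill_value=-1):
    """Illustrative stand-in: turn a k-hot mask (batch, n_lines) into
    fixed-size index tensors (batch, n_actions), padded with fill_value."""
    out = []
    for row in ops.convert_to_numpy(selected_lines):
        idx = np.flatnonzero(row > 0)[:n_actions]
        padded = np.full(n_actions, fill_value, dtype=np.int64)
        padded[: idx.size] = idx
        out.append(padded)
    return ops.convert_to_tensor(np.stack(out))

mask = ops.convert_to_tensor([[0, 1, 1, 0, 1]], dtype="bool")
print(k_hot_to_indices_sketch(mask, n_actions=4))  # [[ 1  2  4 -1]]
```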
zea/agent/selection.py
CHANGED
```diff
@@ -16,9 +16,9 @@ from typing import Callable
 import keras
 from keras import ops

-from zea import tensor_ops
 from zea.agent import masks
 from zea.backend.autograd import AutoGrad
+from zea.func import tensor
 from zea.internal.registry import action_selection_registry


@@ -462,7 +462,7 @@ class CovarianceSamplingLines(LinesActionModel):
         particles = ops.reshape(particles, shape)

         # [batch_size, rows * stack_n_cols, n_possible_actions, n_possible_actions]
-        cov_matrix = tensor_ops.batch_cov(particles)
+        cov_matrix = tensor.batch_cov(particles)

         # Sum over the row dimension [batch_size, n_possible_actions, n_possible_actions]
         cov_matrix = ops.sum(cov_matrix, axis=1)
@@ -477,7 +477,7 @@ class CovarianceSamplingLines(LinesActionModel):
         # Subsample the covariance matrix with random lines
         def subsample_with_mask(mask):
             """Subsample the covariance matrix with a single mask."""
-            subsampled_cov_matrix = tensor_ops.boolean_mask(
+            subsampled_cov_matrix = tensor.boolean_mask(
                 cov_matrix, mask, size=batch_size * self.n_actions**2
             )
             return ops.reshape(subsampled_cov_matrix, [batch_size, self.n_actions, self.n_actions])
```
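Two helpers used above moved to `zea.func.tensor`: `batch_cov`, which computes one covariance matrix per batch element, and `boolean_mask` with a static `size` argument so the result keeps a fixed shape under tracing. A self-contained NumPy sketch of the batched-covariance idea (not the zea implementation):

```python
import numpy as np

def batch_cov_sketch(particles):
    """Covariance per batch element.

    particles: (batch, n_samples, n_features) -> (batch, n_features, n_features)
    """
    mean = particles.mean(axis=1, keepdims=True)
    centered = particles - mean
    n = particles.shape[1]
    # Batched X^T X / (n - 1), summing over the sample axis
    return np.einsum("bsf,bsg->bfg", centered, centered) / (n - 1)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 100, 4))
cov = batch_cov_sketch(x)
print(cov.shape)                                         # (2, 4, 4)
print(np.allclose(cov[0], np.cov(x[0], rowvar=False)))   # True
```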
zea/backend/__init__.py
CHANGED
```diff
@@ -131,7 +131,7 @@ class on_device:
     .. code-block:: python

         with zea.backend.on_device("gpu:3"):
-            pipeline = zea.Pipeline([zea.
+            pipeline = zea.Pipeline([zea.ops.Abs()])
             output = pipeline(data=keras.random.normal((10, 10)))  # output is on "cuda:3"
     """

```
zea/backend/tensorflow/dataloader.py
CHANGED
```diff
@@ -12,8 +12,8 @@ from keras.src.trainers.data_adapters import TFDatasetAdapter

 from zea.data.dataloader import H5Generator
 from zea.data.layers import Resizer
+from zea.func.tensor import translate
 from zea.internal.utils import find_methods_with_return_type
-from zea.tensor_ops import translate

 METHODS_THAT_RETURN_DATASET = find_methods_with_return_type(tf.data.Dataset, "DatasetV2")

@@ -155,10 +155,6 @@ def make_dataloader(
     Mimics the native TF function ``tf.keras.utils.image_dataset_from_directory``
     but for .hdf5 files.

-    Saves a dataset_info.yaml file in the directory with information about the dataset.
-    This file is used to load the dataset later on, which speeds up the initial loading
-    of the dataset for very large datasets.
-
     Does the following in order to load a dataset:

     - Find all .hdf5 files in the director(ies)
```
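The updated `on_device` docstring shows the intended pattern: construct the pipeline inside the context manager so its tensors are allocated on the requested device, now using an op from the new `zea.ops` namespace. A usage example mirroring that docstring (device string and shapes are illustrative):

```python
import keras
import zea

# Build and run a pipeline on a specific device, as in the on_device docstring.
with zea.backend.on_device("gpu:3"):
    pipeline = zea.Pipeline([zea.ops.Abs()])
    output = pipeline(data=keras.random.normal((10, 10)))  # output lives on "cuda:3"
```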
zea/beamform/beamformer.py
CHANGED
```diff
@@ -5,7 +5,7 @@ import numpy as np
 from keras import ops

 from zea.beamform.lens_correction import calculate_lens_corrected_delays
-from zea.tensor_ops import vmap
+from zea.func.tensor import vmap


 def fnum_window_fn_rect(normalized_angle):
@@ -379,7 +379,7 @@ def complex_rotate(iq, theta):

     .. math::

-        x(t + \\Delta t) &= I'(t) \\cos(\\omega_c (t + \\Delta t))
+        x(t + \\Delta t) &= I'(t) \\cos(\\omega_c (t + \\Delta t))
         - Q'(t) \\sin(\\omega_c (t + \\Delta t))\\\\
         &= \\overbrace{(I'(t)\\cos(\\theta)
         - Q'(t)\\sin(\\theta) )}^{I_\\Delta(t)} \\cos(\\omega_c t)\\\\
@@ -452,6 +452,8 @@ def distance_Tx_generic(
             `(n_el,)`.
         probe_geometry (ops.Tensor): The positions of the transducer elements of shape
             `(n_el, 3)`.
+        focus_distance (float): The focus distance in meters.
+        polar_angle (float): The polar angle in radians.
         sound_speed (float): The speed of sound in m/s. Defaults to 1540.

     Returns:
```
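The `complex_rotate` docstring above encodes the usual IQ phase-rotation identity: delaying a narrowband signal by Δt is equivalent to multiplying its complex baseband representation by exp(jω_c Δt). A short numerical check of that identity with NumPy, independent of zea's implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
fc = 5e6                        # carrier (center) frequency [Hz]
theta = 2 * np.pi * fc * 1e-8   # phase shift omega_c * dt for dt = 10 ns

# Complex baseband sample I' + jQ' and its phase-rotated version
iq = rng.normal() + 1j * rng.normal()
iq_rot = iq * np.exp(1j * theta)

# The rotated I/Q components match the trigonometric expansion in the docstring:
i_delta = iq.real * np.cos(theta) - iq.imag * np.sin(theta)
q_delta = iq.real * np.sin(theta) + iq.imag * np.cos(theta)
print(np.allclose([iq_rot.real, iq_rot.imag], [i_delta, q_delta]))  # True
```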
zea/beamform/pfield.py
CHANGED
```diff
@@ -24,8 +24,8 @@ import numpy as np
 from keras import ops

 from zea import log
+from zea.func.tensor import sinc
 from zea.internal.cache import cache_output
-from zea.tensor_ops import sinc


 def _abs_sinc(x):
@@ -101,7 +101,7 @@ def compute_pfield(
     # array params
     probe_geometry = ops.convert_to_tensor(probe_geometry, dtype="float32")

-    pitch = probe_geometry[1, 0] - probe_geometry[0, 0]  # element pitch
+    pitch = ops.abs(probe_geometry[1, 0] - probe_geometry[0, 0])  # element pitch

     kerf = 0.1 * pitch  # for now this is hardcoded
     element_width = pitch - kerf
```
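Wrapping the pitch estimate in `ops.abs` makes it robust to probes whose element x-coordinates are ordered right-to-left, where the raw difference would be negative and would propagate into a negative kerf and element width. A minimal NumPy illustration of the failure mode being fixed (the geometry values are made up):

```python
import numpy as np

# Hypothetical probe geometry with x-coordinates ordered right-to-left (meters)
probe_geometry = np.array([[0.3e-3, 0.0, 0.0],
                           [0.0e-3, 0.0, 0.0]])

raw_pitch = probe_geometry[1, 0] - probe_geometry[0, 0]      # negative difference
pitch = np.abs(probe_geometry[1, 0] - probe_geometry[0, 0])  # magnitude of the spacing

kerf = 0.1 * pitch
element_width = pitch - kerf
print(raw_pitch, pitch, element_width)  # ≈ -0.0003, 0.0003, 0.00027
```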
zea/beamform/pixelgrid.py
CHANGED
```diff
@@ -45,7 +45,7 @@ def cartesian_pixel_grid(xlims, zlims, grid_size_x=None, grid_size_z=None, dx=No
     ValueError: Either grid_size_x and grid_size_z or dx and dz must be defined.

     Returns:
-        grid (np.ndarray): Pixel grid of size (grid_size_z,
+        grid (np.ndarray): Pixel grid of size (grid_size_z, grid_size_x, 3) in
             Cartesian coordinates (x, y, z)
     """
     assert (bool(grid_size_x) and bool(grid_size_z)) ^ (bool(dx) and bool(dz)), (
```
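The corrected docstring matches how such a grid is typically laid out: a `(grid_size_z, grid_size_x, 3)` array of (x, y, z) pixel positions. A small NumPy sketch of an equivalent grid (not zea's implementation; the argument handling is simplified):

```python
import numpy as np

def cartesian_pixel_grid_sketch(xlims, zlims, grid_size_x, grid_size_z):
    """Return a (grid_size_z, grid_size_x, 3) array of (x, y, z) pixel positions."""
    x = np.linspace(xlims[0], xlims[1], grid_size_x)
    z = np.linspace(zlims[0], zlims[1], grid_size_z)
    zz, xx = np.meshgrid(z, x, indexing="ij")  # rows are depth (z), columns lateral (x)
    yy = np.zeros_like(xx)                     # imaging plane at y = 0
    return np.stack([xx, yy, zz], axis=-1)

grid = cartesian_pixel_grid_sketch((-0.02, 0.02), (0.0, 0.04), grid_size_x=128, grid_size_z=256)
print(grid.shape)  # (256, 128, 3)
```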
zea/data/__init__.py
CHANGED
```diff
@@ -38,15 +38,6 @@ Examples usage
     ...     files.append(file)  # process each file as needed
     >>> dataset.close()

-Subpackage layout
------------------
-
-- ``file.py``: Implements :class:`zea.File` and related file utilities.
-- ``datasets.py``: Implements :class:`zea.Dataset` and folder management.
-- ``dataloader.py``: Data loading utilities for batching and shuffling.
-- ``data_format.py``: Data validation and example dataset generation.
-- ``convert/``: Data conversion tools (e.g., from external formats).
-
 """  # noqa: E501

 from .convert.camus import sitk_load
```
zea/data/augmentations.py
CHANGED
```diff
@@ -4,7 +4,7 @@ import keras
 import numpy as np
 from keras import layers, ops

-from zea.tensor_ops import is_jax_prng_key, split_seed
+from zea.func.tensor import is_jax_prng_key, split_seed


 class RandomCircleInclusion(layers.Layer):
@@ -30,7 +30,7 @@ class RandomCircleInclusion(layers.Layer):

     def __init__(
         self,
-        radius: int,
+        radius: int | tuple[int, int],
         fill_value: float = 1.0,
         circle_axes: tuple[int, int] = (1, 2),
         with_batch_dim=True,
@@ -38,25 +38,70 @@ class RandomCircleInclusion(layers.Layer):
         recovery_threshold=0.1,
         randomize_location_across_batch=True,
         seed=None,
+        width_range: tuple[int, int] = None,
+        height_range: tuple[int, int] = None,
         **kwargs,
     ):
         """
         Initialize RandomCircleInclusion.

         Args:
-            radius (int): Radius of the circle to include.
+            radius (int or tuple[int, int]): Radius of the circle/ellipse to include.
             fill_value (float): Value to fill inside the circle.
-            circle_axes (tuple[int, int]): Axes along which to draw the circle
+            circle_axes (tuple[int, int]): Axes along which to draw the circle
+                (height, width).
             with_batch_dim (bool): Whether input has a batch dimension.
             return_centers (bool): Whether to return circle centers along with images.
             recovery_threshold (float): Threshold for considering a pixel as recovered.
-            randomize_location_across_batch (bool): If True
-
+            randomize_location_across_batch (bool): If True (and with_batch_dim=True),
+                each batch element gets a different random center. If False, all batch
+                elements share the same center.
             seed (Any): Optional random seed for reproducibility.
+            width_range (tuple[int, int], optional): Range (min, max) for circle
+                center x (width axis).
+            height_range (tuple[int, int], optional): Range (min, max) for circle
+                center y (height axis).
             **kwargs: Additional keyword arguments for the parent Layer.
+
+        Example:
+            .. doctest::
+
+                >>> from zea.data.augmentations import RandomCircleInclusion
+                >>> from keras import ops
+
+                >>> layer = RandomCircleInclusion(
+                ...     radius=5,
+                ...     circle_axes=(1, 2),
+                ...     with_batch_dim=True,
+                ... )
+                >>> image = ops.zeros((1, 28, 28), dtype="float32")
+                >>> out = layer(image)  # doctest: +SKIP
+
         """
         super().__init__(**kwargs)
-
+
+        # Validate randomize_location_across_batch only makes sense with batch dim
+        if not with_batch_dim and not randomize_location_across_batch:
+            raise ValueError(
+                "randomize_location_across_batch=False is only applicable when "
+                "with_batch_dim=True. When with_batch_dim=False, there is no batch "
+                "to randomize across."
+            )
+        # Convert radius to tuple if int, else validate tuple
+        if isinstance(radius, int):
+            if radius <= 0:
+                raise ValueError(f"radius must be a positive integer, got {radius}.")
+            self.radius = (radius, radius)
+        elif isinstance(radius, tuple) and len(radius) == 2:
+            rx, ry = radius
+            if not all(isinstance(r, int) for r in (rx, ry)):
+                raise TypeError(f"radius tuple must contain two integers, got {radius!r}.")
+            if rx <= 0 or ry <= 0:
+                raise ValueError(f"radius components must be positive, got {radius!r}.")
+            self.radius = (rx, ry)
+        else:
+            raise TypeError("radius must be an int or a tuple of two ints")
+
         self.fill_value = fill_value
         self.circle_axes = circle_axes
         self.with_batch_dim = with_batch_dim
@@ -64,6 +109,8 @@ class RandomCircleInclusion(layers.Layer):
         self.recovery_threshold = recovery_threshold
         self.randomize_location_across_batch = randomize_location_across_batch
         self.seed = seed
+        self.width_range = width_range
+        self.height_range = height_range
         self._axis1 = None
         self._axis2 = None
         self._perm = None
@@ -116,6 +163,43 @@ class RandomCircleInclusion(layers.Layer):
         self._static_w = int(permuted_shape[-1])
         self._static_shape = tuple(permuted_shape)

+        # Validate that ellipse can fit within image bounds
+        rx, ry = self.radius
+        min_required_width = 2 * rx + 1
+        min_required_height = 2 * ry + 1
+
+        if self._static_w < min_required_width:
+            raise ValueError(
+                f"Image width ({self._static_w}) is too small for radius {rx}. "
+                f"Minimum required width: {min_required_width}"
+            )
+        if self._static_h < min_required_height:
+            raise ValueError(
+                f"Image height ({self._static_h}) is too small for radius {ry}. "
+                f"Minimum required height: {min_required_height}"
+            )
+
+        # Validate width_range and height_range if provided
+        if self.width_range is not None:
+            min_x, max_x = self.width_range
+            if min_x >= max_x:
+                raise ValueError(f"width_range must have min < max, got {self.width_range}")
+            if min_x < rx or max_x > self._static_w - rx:
+                raise ValueError(
+                    f"width_range {self.width_range} would place circle outside image bounds. "
+                    f"Valid range: [{rx}, {self._static_w - rx})"
+                )
+
+        if self.height_range is not None:
+            min_y, max_y = self.height_range
+            if min_y >= max_y:
+                raise ValueError(f"height_range must have min < max, got {self.height_range}")
+            if min_y < ry or max_y > self._static_h - ry:
+                raise ValueError(
+                    f"height_range {self.height_range} would place circle outside image bounds. "
+                    f"Valid range: [{ry}, {self._static_h - ry})"
+                )
+
         super().build(input_shape)

     def compute_output_shape(self, input_shape):
@@ -165,7 +249,7 @@ class RandomCircleInclusion(layers.Layer):
             centers (Tensor): Tensor of shape (batch, 2) with circle centers.
             h (int): Height of the image.
             w (int): Width of the image.
-            radius (int):
+            radius (tuple[int, int]): Radii of the ellipse (rx, ry).
             dtype (str or dtype): Data type for the mask.

         Returns:
@@ -176,12 +260,12 @@ class RandomCircleInclusion(layers.Layer):
         Y, X = ops.meshgrid(Y, X, indexing="ij")
         Y = ops.expand_dims(Y, 0)  # (1, h, w)
         X = ops.expand_dims(X, 0)  # (1, h, w)
-        # cx = ops.cast(centers[:, 0], "float32")[:, None, None]
-        # cy = ops.cast(centers[:, 1], "float32")[:, None, None]
         cx = centers[:, 0][:, None, None]
         cy = centers[:, 1][:, None, None]
-
-
+        rx, ry = radius
+        # Ellipse equation: ((X-cx)/rx)^2 + ((Y-cy)/ry)^2 <= 1
+        dist = ((X - cx) / rx) ** 2 + ((Y - cy) / ry) ** 2
+        mask = ops.cast(dist <= 1, dtype)
         return mask

     def call(self, x, seed=None):
@@ -197,9 +281,17 @@ class RandomCircleInclusion(layers.Layer):
                 centers if return_centers is True.
         """
         if keras.backend.backend() == "jax" and not is_jax_prng_key(seed):
-
-
-
+            if isinstance(seed, keras.random.SeedGenerator):
+                raise ValueError(
+                    "When using JAX backend, please provide a jax.random.PRNGKey as seed, "
+                    "instead of keras.random.SeedGenerator."
+                )
+            else:
+                raise TypeError(
+                    f"When using JAX backend, seed must be a JAX PRNG key (created with "
+                    f"jax.random.PRNGKey()), but got {type(seed)}. Note: jax.random.key() "
+                    f"keys are not currently supported."
+                )
         seed = seed if seed is not None else self.seed

         if self.with_batch_dim:
@@ -209,22 +301,33 @@ class RandomCircleInclusion(layers.Layer):
                     imgs, centers = ops.map(lambda arg: self._call(arg, seed), x)
                 else:
                     raise NotImplementedError(
-                        "You cannot fix circle locations across while using"
+                        "You cannot fix circle locations across batch while using "
                         + "RandomCircleInclusion as a dataset augmentation, "
                         + "since samples in a batch are handled independently."
                     )
             else:
+                batch_size = ops.shape(x)[0]
                 if self.randomize_location_across_batch:
-                    batch_size = ops.shape(x)[0]
                     seeds = split_seed(seed, batch_size)
-                    if all(
+                    if all(s is seeds[0] for s in seeds):
                         imgs, centers = ops.map(lambda arg: self._call(arg, seeds[0]), x)
                     else:
                         imgs, centers = ops.map(
                             lambda args: self._call(args[0], args[1]), (x, seeds)
                         )
                 else:
-
+                    # Generate one random center that will be used for all batch elements
+                    img0, center0 = self._call(x[0], seed)
+
+                    # Apply the same center to all batch elements
+                    imgs_list, centers_list = [img0], [center0]
+                    for i in range(1, batch_size):
+                        img_aug, center_out = self._call_with_fixed_center(x[i], center0)
+                        imgs_list.append(img_aug)
+                        centers_list.append(center_out)
+
+                    imgs = ops.stack(imgs_list, axis=0)
+                    centers = ops.stack(centers_list, axis=0)
         else:
             imgs, centers = self._call(x, seed)

@@ -248,17 +351,28 @@ class RandomCircleInclusion(layers.Layer):
         flat, flat_batch_size, h, w = self._flatten_batch_and_other_dims(x)

         def _draw_circle_2d(img2d):
+            rx, ry = self.radius
+            # Determine allowed ranges for center
+            if self.width_range is not None:
+                min_x, max_x = self.width_range
+            else:
+                min_x, max_x = rx, w - rx
+            if self.height_range is not None:
+                min_y, max_y = self.height_range
+            else:
+                min_y, max_y = ry, h - ry
+            # Ensure the ellipse fits within the allowed region
             cx = ops.cast(
-                keras.random.uniform((),
+                keras.random.uniform((), min_x, max_x, seed=seed),
                 "int32",
             )
             new_seed, _ = split_seed(seed, 2)  # ensure that cx and cy are independent
             cy = ops.cast(
-                keras.random.uniform((),
+                keras.random.uniform((), min_y, max_y, seed=new_seed),
                 "int32",
            )
            mask = self._make_circle_mask(
-                ops.stack([cx, cy])[None, :], h, w,
+                ops.stack([cx, cy])[None, :], h, w, (rx, ry), img2d.dtype
             )[0]
             img_aug = img2d * (1 - mask) + self.fill_value * mask
             center = ops.stack([cx, cy])
@@ -271,6 +385,67 @@ class RandomCircleInclusion(layers.Layer):
         centers = ops.reshape(centers, centers_shape)
         return (aug_imgs, centers)

+    def _apply_circle_mask(self, flat, center, h, w):
+        """Apply circle mask to flattened image data.
+
+        Args:
+            flat (Tensor): Flattened image data of shape (flat_batch, h, w).
+            center (Tensor): Center coordinates, either (2,) or (flat_batch, 2).
+            h (int): Height of images.
+            w (int): Width of images.
+
+        Returns:
+            Tensor: Augmented images with circle applied.
+        """
+        rx, ry = self.radius
+
+        # Ensure center has batch dimension for broadcasting
+        if len(center.shape) == 1:
+            # Single center (2,) -> broadcast to all slices
+            center_batched = ops.tile(ops.reshape(center, [1, 2]), [flat.shape[0], 1])
+        else:
+            # Already batched (flat_batch, 2)
+            center_batched = center
+
+        # Create masks for all slices using vectorized_map or broadcasting
+        masks = self._make_circle_mask(center_batched, h, w, (rx, ry), flat.dtype)
+
+        # Apply masks
+        aug_imgs = flat * (1 - masks) + self.fill_value * masks
+        return aug_imgs
+
+    def _call_with_fixed_center(self, x, fixed_center):
+        """Apply augmentation using a pre-determined center.
+
+        Args:
+            x (Tensor): Input image tensor.
+            fixed_center (Tensor): Pre-determined center coordinates, either (2,)
+                for a single center or (flat_batch, 2) for per-slice centers.
+
+        Returns:
+            tuple: (augmented image, center coordinates).
+        """
+        x = self._permute_axes_to_circle_last(x)
+        flat, flat_batch_size, h, w = self._flatten_batch_and_other_dims(x)
+
+        # Apply circle mask with fixed center (handles both single and batched centers)
+        aug_imgs = self._apply_circle_mask(flat, fixed_center, h, w)
+        aug_imgs = ops.reshape(aug_imgs, x.shape)
+        aug_imgs = ops.transpose(aug_imgs, axes=self._inv_perm)
+
+        # Return centers matching the expected shape
+        if len(fixed_center.shape) == 1:
+            # Single center (2,) -> broadcast to match flat_batch_size
+            if flat_batch_size == 1:
+                centers = fixed_center
+            else:
+                centers = ops.tile(ops.reshape(fixed_center, [1, 2]), [flat_batch_size, 1])
+        else:
+            # Already batched centers (flat_batch, 2)
+            centers = fixed_center
+
+        return (aug_imgs, centers)
+
     def get_config(self):
         """
         Get layer configuration for serialization.
@@ -285,6 +460,8 @@ class RandomCircleInclusion(layers.Layer):
                 "fill_value": self.fill_value,
                 "circle_axes": self.circle_axes,
                 "return_centers": self.return_centers,
+                "width_range": self.width_range,
+                "height_range": self.height_range,
             }
         )
         return cfg
@@ -293,7 +470,8 @@ class RandomCircleInclusion(layers.Layer):
         self, images, centers, recovery_threshold, fill_value=None
     ):
         """
-        Evaluate the percentage of the true circle that has been recovered in the images
+        Evaluate the percentage of the true circle that has been recovered in the images,
+        and return a mask of the detected part of the circle.

         Args:
             images (Tensor): Tensor of images (any shape, with circle axes as specified).
@@ -302,8 +480,12 @@ class RandomCircleInclusion(layers.Layer):
            fill_value (float, optional): Optionally override fill_value for cases
                 where image range has changed.

-
-        Tensor
+        Returns:
+            Tuple[Tensor, Tensor]:
+                - percent_recovered: [batch] - average recovery percentage per batch element,
+                  averaged across all non-batch dimensions (e.g., frames, slices)
+                - recovered_masks: [batch, flat_batch, h, w] or [batch, h, w] or [flat_batch, h, w]
+                  depending on input shape - binary mask of detected circle regions
         """
         fill_value = fill_value or self.fill_value

@@ -318,12 +500,23 @@ class RandomCircleInclusion(layers.Layer):
             recovered_sum = ops.sum(recovered, axis=[1, 2])
             mask_sum = ops.sum(mask, axis=[1, 2])
             percent_recovered = recovered_sum / (mask_sum + 1e-8)
-
+            # recovered_mask: binary mask of detected part of the circle
+            recovered_mask = ops.cast(recovered > 0, flat_image.dtype)
+            return percent_recovered, recovered_mask

         if self.with_batch_dim:
-
+            results = ops.vectorized_map(
                 lambda args: _evaluate_recovered_circle_accuracy(args[0], args[1]),
                 (images, centers),
-            )
+            )
+            percent_recovered, recovered_masks = results
+            # If there are multiple circles per batch element (e.g., multiple frames/slices),
+            # take the mean across all non-batch dimensions to get one value per batch element
+            if len(percent_recovered.shape) > 1:
+                # Average over all axes except the batch dimension (axis 0)
+                axes_to_reduce = tuple(range(1, len(percent_recovered.shape)))
+                percent_recovered = ops.mean(percent_recovered, axis=axes_to_reduce)
+            return percent_recovered, recovered_masks
         else:
-
+            percent_recovered, recovered_mask = _evaluate_recovered_circle_accuracy(images, centers)
+            return percent_recovered, recovered_mask
```
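Taken together, these changes generalize `RandomCircleInclusion` from circles to ellipses (tuple `radius`), add `width_range`/`height_range` constraints on the sampled center, and allow one shared center across a batch. A usage sketch based on the constructor arguments visible in this diff (the exact return structure depends on `return_centers` and the backend, so treat the call as illustrative):

```python
from keras import ops
from zea.data.augmentations import RandomCircleInclusion

# Elliptical inclusion (rx=4, ry=8) whose center is constrained to the
# central region of a 64x64 image; one shared center for the whole batch.
layer = RandomCircleInclusion(
    radius=(4, 8),
    fill_value=1.0,
    circle_axes=(1, 2),
    with_batch_dim=True,
    randomize_location_across_batch=False,
    width_range=(16, 48),
    height_range=(16, 48),
)
images = ops.zeros((2, 64, 64), dtype="float32")
# On the JAX backend, pass seed=jax.random.PRNGKey(0) per the error message above.
augmented = layer(images)  # fills the same ellipse into both batch elements
```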
zea/data/convert/__init__.py
CHANGED
```diff
@@ -1,6 +1 @@
-"""
-
-from .camus import convert_camus
-from .images import convert_image_dataset
-from .matlab import zea_from_matlab_raw
-from .picmus import convert_picmus
+"""Data conversion of datasets to the ``zea`` data format."""
```