zea 0.0.7__py3-none-any.whl → 0.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. zea/__init__.py +3 -3
  2. zea/agent/masks.py +2 -2
  3. zea/agent/selection.py +3 -3
  4. zea/backend/__init__.py +1 -1
  5. zea/backend/tensorflow/dataloader.py +1 -5
  6. zea/beamform/beamformer.py +4 -2
  7. zea/beamform/pfield.py +2 -2
  8. zea/beamform/pixelgrid.py +1 -1
  9. zea/data/__init__.py +0 -9
  10. zea/data/augmentations.py +222 -29
  11. zea/data/convert/__init__.py +1 -6
  12. zea/data/convert/__main__.py +164 -0
  13. zea/data/convert/camus.py +106 -40
  14. zea/data/convert/echonet.py +184 -83
  15. zea/data/convert/echonetlvh/README.md +2 -3
  16. zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +174 -103
  17. zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
  18. zea/data/convert/echonetlvh/precompute_crop.py +43 -64
  19. zea/data/convert/picmus.py +37 -40
  20. zea/data/convert/utils.py +86 -0
  21. zea/data/convert/verasonics.py +1247 -0
  22. zea/data/data_format.py +124 -6
  23. zea/data/dataloader.py +12 -7
  24. zea/data/datasets.py +109 -70
  25. zea/data/file.py +119 -82
  26. zea/data/file_operations.py +496 -0
  27. zea/data/preset_utils.py +2 -2
  28. zea/display.py +8 -9
  29. zea/doppler.py +5 -5
  30. zea/func/__init__.py +109 -0
  31. zea/{tensor_ops.py → func/tensor.py} +113 -69
  32. zea/func/ultrasound.py +500 -0
  33. zea/internal/_generate_keras_ops.py +5 -5
  34. zea/internal/checks.py +6 -12
  35. zea/internal/operators.py +4 -0
  36. zea/io_lib.py +108 -160
  37. zea/metrics.py +6 -5
  38. zea/models/__init__.py +1 -1
  39. zea/models/diffusion.py +63 -12
  40. zea/models/echonetlvh.py +1 -1
  41. zea/models/gmm.py +1 -1
  42. zea/models/lv_segmentation.py +2 -0
  43. zea/ops/__init__.py +188 -0
  44. zea/ops/base.py +442 -0
  45. zea/{keras_ops.py → ops/keras_ops.py} +2 -2
  46. zea/ops/pipeline.py +1472 -0
  47. zea/ops/tensor.py +356 -0
  48. zea/ops/ultrasound.py +890 -0
  49. zea/probes.py +2 -10
  50. zea/scan.py +35 -28
  51. zea/tools/fit_scan_cone.py +90 -160
  52. zea/tools/selection_tool.py +1 -1
  53. zea/tracking/__init__.py +16 -0
  54. zea/tracking/base.py +94 -0
  55. zea/tracking/lucas_kanade.py +474 -0
  56. zea/tracking/segmentation.py +110 -0
  57. zea/utils.py +11 -2
  58. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/METADATA +5 -1
  59. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/RECORD +62 -48
  60. zea/data/convert/matlab.py +0 -1237
  61. zea/ops.py +0 -3294
  62. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/WHEEL +0 -0
  63. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/entry_points.txt +0 -0
  64. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/licenses/LICENSE +0 -0
zea/__init__.py CHANGED
@@ -7,7 +7,7 @@ from . import log
7
7
 
8
8
  # dynamically add __version__ attribute (see pyproject.toml)
9
9
  # __version__ = __import__("importlib.metadata").metadata.version(__package__)
10
- __version__ = "0.0.7"
10
+ __version__ = "0.0.9"
11
11
 
12
12
 
13
13
  def _bootstrap_backend():
@@ -89,12 +89,12 @@ from . import (
89
89
  beamform,
90
90
  data,
91
91
  display,
92
+ func,
92
93
  io_lib,
93
- keras_ops,
94
94
  metrics,
95
95
  models,
96
+ ops,
96
97
  simulator,
97
- tensor_ops,
98
98
  utils,
99
99
  visualize,
100
100
  )
zea/agent/masks.py CHANGED
@@ -9,8 +9,8 @@ from typing import List
9
9
  import keras
10
10
  from keras import ops
11
11
 
12
- from zea import tensor_ops
13
12
  from zea.agent.gumbel import hard_straight_through
13
+ from zea.func.tensor import nonzero
14
14
 
15
15
  _DEFAULT_DTYPE = "bool"
16
16
 
@@ -56,7 +56,7 @@ def k_hot_to_indices(selected_lines, n_actions: int, fill_value=-1):
56
56
 
57
57
  # Find nonzero indices for each frame
58
58
  def get_nonzero(row):
59
- return tensor_ops.nonzero(row > 0, size=n_actions, fill_value=fill_value)[0]
59
+ return nonzero(row > 0, size=n_actions, fill_value=fill_value)[0]
60
60
 
61
61
  indices = ops.vectorized_map(get_nonzero, selected_lines)
62
62
  return indices
zea/agent/selection.py CHANGED
@@ -16,9 +16,9 @@ from typing import Callable
16
16
  import keras
17
17
  from keras import ops
18
18
 
19
- from zea import tensor_ops
20
19
  from zea.agent import masks
21
20
  from zea.backend.autograd import AutoGrad
21
+ from zea.func import tensor
22
22
  from zea.internal.registry import action_selection_registry
23
23
 
24
24
 
@@ -462,7 +462,7 @@ class CovarianceSamplingLines(LinesActionModel):
462
462
  particles = ops.reshape(particles, shape)
463
463
 
464
464
  # [batch_size, rows * stack_n_cols, n_possible_actions, n_possible_actions]
465
- cov_matrix = tensor_ops.batch_cov(particles)
465
+ cov_matrix = tensor.batch_cov(particles)
466
466
 
467
467
  # Sum over the row dimension [batch_size, n_possible_actions, n_possible_actions]
468
468
  cov_matrix = ops.sum(cov_matrix, axis=1)
@@ -477,7 +477,7 @@ class CovarianceSamplingLines(LinesActionModel):
477
477
  # Subsample the covariance matrix with random lines
478
478
  def subsample_with_mask(mask):
479
479
  """Subsample the covariance matrix with a single mask."""
480
- subsampled_cov_matrix = tensor_ops.boolean_mask(
480
+ subsampled_cov_matrix = tensor.boolean_mask(
481
481
  cov_matrix, mask, size=batch_size * self.n_actions**2
482
482
  )
483
483
  return ops.reshape(subsampled_cov_matrix, [batch_size, self.n_actions, self.n_actions])
zea/backend/__init__.py CHANGED
@@ -131,7 +131,7 @@ class on_device:
131
131
  .. code-block:: python
132
132
 
133
133
  with zea.backend.on_device("gpu:3"):
134
- pipeline = zea.Pipeline([zea.keras_ops.Abs()])
134
+ pipeline = zea.Pipeline([zea.ops.Abs()])
135
135
  output = pipeline(data=keras.random.normal((10, 10))) # output is on "cuda:3"
136
136
  """
137
137
 
zea/backend/tensorflow/dataloader.py CHANGED
@@ -12,8 +12,8 @@ from keras.src.trainers.data_adapters import TFDatasetAdapter
12
12
 
13
13
  from zea.data.dataloader import H5Generator
14
14
  from zea.data.layers import Resizer
15
+ from zea.func.tensor import translate
15
16
  from zea.internal.utils import find_methods_with_return_type
16
- from zea.tensor_ops import translate
17
17
 
18
18
  METHODS_THAT_RETURN_DATASET = find_methods_with_return_type(tf.data.Dataset, "DatasetV2")
19
19
 
@@ -155,10 +155,6 @@ def make_dataloader(
155
155
  Mimics the native TF function ``tf.keras.utils.image_dataset_from_directory``
156
156
  but for .hdf5 files.
157
157
 
158
- Saves a dataset_info.yaml file in the directory with information about the dataset.
159
- This file is used to load the dataset later on, which speeds up the initial loading
160
- of the dataset for very large datasets.
161
-
162
158
  Does the following in order to load a dataset:
163
159
 
164
160
  - Find all .hdf5 files in the director(ies)
zea/beamform/beamformer.py CHANGED
@@ -5,7 +5,7 @@ import numpy as np
5
5
  from keras import ops
6
6
 
7
7
  from zea.beamform.lens_correction import calculate_lens_corrected_delays
8
- from zea.tensor_ops import vmap
8
+ from zea.func.tensor import vmap
9
9
 
10
10
 
11
11
  def fnum_window_fn_rect(normalized_angle):
@@ -379,7 +379,7 @@ def complex_rotate(iq, theta):
379
379
 
380
380
  .. math::
381
381
 
382
- x(t + \\Delta t) &= I'(t) \\cos(\\omega_c (t + \\Delta t))
382
+ x(t + \\Delta t) &= I'(t) \\cos(\\omega_c (t + \\Delta t))
383
383
  - Q'(t) \\sin(\\omega_c (t + \\Delta t))\\\\
384
384
  &= \\overbrace{(I'(t)\\cos(\\theta)
385
385
  - Q'(t)\\sin(\\theta) )}^{I_\\Delta(t)} \\cos(\\omega_c t)\\\\
@@ -452,6 +452,8 @@ def distance_Tx_generic(
452
452
  `(n_el,)`.
453
453
  probe_geometry (ops.Tensor): The positions of the transducer elements of shape
454
454
  `(n_el, 3)`.
455
+ focus_distance (float): The focus distance in meters.
456
+ polar_angle (float): The polar angle in radians.
455
457
  sound_speed (float): The speed of sound in m/s. Defaults to 1540.
456
458
 
457
459
  Returns:
zea/beamform/pfield.py CHANGED
@@ -24,8 +24,8 @@ import numpy as np
24
24
  from keras import ops
25
25
 
26
26
  from zea import log
27
+ from zea.func.tensor import sinc
27
28
  from zea.internal.cache import cache_output
28
- from zea.tensor_ops import sinc
29
29
 
30
30
 
31
31
  def _abs_sinc(x):
@@ -101,7 +101,7 @@ def compute_pfield(
101
101
  # array params
102
102
  probe_geometry = ops.convert_to_tensor(probe_geometry, dtype="float32")
103
103
 
104
- pitch = probe_geometry[1, 0] - probe_geometry[0, 0] # element pitch
104
+ pitch = ops.abs(probe_geometry[1, 0] - probe_geometry[0, 0]) # element pitch
105
105
 
106
106
  kerf = 0.1 * pitch # for now this is hardcoded
107
107
  element_width = pitch - kerf
zea/beamform/pixelgrid.py CHANGED
@@ -45,7 +45,7 @@ def cartesian_pixel_grid(xlims, zlims, grid_size_x=None, grid_size_z=None, dx=No
45
45
  ValueError: Either grid_size_x and grid_size_z or dx and dz must be defined.
46
46
 
47
47
  Returns:
48
- grid (np.ndarray): Pixel grid of size (grid_size_z, nx, 3) in
48
+ grid (np.ndarray): Pixel grid of size (grid_size_z, grid_size_x, 3) in
49
49
  Cartesian coordinates (x, y, z)
50
50
  """
51
51
  assert (bool(grid_size_x) and bool(grid_size_z)) ^ (bool(dx) and bool(dz)), (
zea/data/__init__.py CHANGED
@@ -38,15 +38,6 @@ Examples usage
38
38
  ... files.append(file) # process each file as needed
39
39
  >>> dataset.close()
40
40
 
41
- Subpackage layout
42
- -----------------
43
-
44
- - ``file.py``: Implements :class:`zea.File` and related file utilities.
45
- - ``datasets.py``: Implements :class:`zea.Dataset` and folder management.
46
- - ``dataloader.py``: Data loading utilities for batching and shuffling.
47
- - ``data_format.py``: Data validation and example dataset generation.
48
- - ``convert/``: Data conversion tools (e.g., from external formats).
49
-
50
41
  """ # noqa: E501
51
42
 
52
43
  from .convert.camus import sitk_load
zea/data/augmentations.py CHANGED
@@ -4,7 +4,7 @@ import keras
4
4
  import numpy as np
5
5
  from keras import layers, ops
6
6
 
7
- from zea.tensor_ops import is_jax_prng_key, split_seed
7
+ from zea.func.tensor import is_jax_prng_key, split_seed
8
8
 
9
9
 
10
10
  class RandomCircleInclusion(layers.Layer):
@@ -30,7 +30,7 @@ class RandomCircleInclusion(layers.Layer):
30
30
 
31
31
  def __init__(
32
32
  self,
33
- radius: int,
33
+ radius: int | tuple[int, int],
34
34
  fill_value: float = 1.0,
35
35
  circle_axes: tuple[int, int] = (1, 2),
36
36
  with_batch_dim=True,
@@ -38,25 +38,70 @@ class RandomCircleInclusion(layers.Layer):
38
38
  recovery_threshold=0.1,
39
39
  randomize_location_across_batch=True,
40
40
  seed=None,
41
+ width_range: tuple[int, int] = None,
42
+ height_range: tuple[int, int] = None,
41
43
  **kwargs,
42
44
  ):
43
45
  """
44
46
  Initialize RandomCircleInclusion.
45
47
 
46
48
  Args:
47
- radius (int): Radius of the circle to include.
49
+ radius (int or tuple[int, int]): Radius of the circle/ellipse to include.
48
50
  fill_value (float): Value to fill inside the circle.
49
- circle_axes (tuple[int, int]): Axes along which to draw the circle (height, width).
51
+ circle_axes (tuple[int, int]): Axes along which to draw the circle
52
+ (height, width).
50
53
  with_batch_dim (bool): Whether input has a batch dimension.
51
54
  return_centers (bool): Whether to return circle centers along with images.
52
55
  recovery_threshold (float): Threshold for considering a pixel as recovered.
53
- randomize_location_across_batch (bool): If True, randomize circle location
54
- per batch element.
56
+ randomize_location_across_batch (bool): If True (and with_batch_dim=True),
57
+ each batch element gets a different random center. If False, all batch
58
+ elements share the same center.
55
59
  seed (Any): Optional random seed for reproducibility.
60
+ width_range (tuple[int, int], optional): Range (min, max) for circle
61
+ center x (width axis).
62
+ height_range (tuple[int, int], optional): Range (min, max) for circle
63
+ center y (height axis).
56
64
  **kwargs: Additional keyword arguments for the parent Layer.
65
+
66
+ Example:
67
+ .. doctest::
68
+
69
+ >>> from zea.data.augmentations import RandomCircleInclusion
70
+ >>> from keras import ops
71
+
72
+ >>> layer = RandomCircleInclusion(
73
+ ... radius=5,
74
+ ... circle_axes=(1, 2),
75
+ ... with_batch_dim=True,
76
+ ... )
77
+ >>> image = ops.zeros((1, 28, 28), dtype="float32")
78
+ >>> out = layer(image) # doctest: +SKIP
79
+
57
80
  """
58
81
  super().__init__(**kwargs)
59
- self.radius = radius
82
+
83
+ # Validate randomize_location_across_batch only makes sense with batch dim
84
+ if not with_batch_dim and not randomize_location_across_batch:
85
+ raise ValueError(
86
+ "randomize_location_across_batch=False is only applicable when "
87
+ "with_batch_dim=True. When with_batch_dim=False, there is no batch "
88
+ "to randomize across."
89
+ )
90
+ # Convert radius to tuple if int, else validate tuple
91
+ if isinstance(radius, int):
92
+ if radius <= 0:
93
+ raise ValueError(f"radius must be a positive integer, got {radius}.")
94
+ self.radius = (radius, radius)
95
+ elif isinstance(radius, tuple) and len(radius) == 2:
96
+ rx, ry = radius
97
+ if not all(isinstance(r, int) for r in (rx, ry)):
98
+ raise TypeError(f"radius tuple must contain two integers, got {radius!r}.")
99
+ if rx <= 0 or ry <= 0:
100
+ raise ValueError(f"radius components must be positive, got {radius!r}.")
101
+ self.radius = (rx, ry)
102
+ else:
103
+ raise TypeError("radius must be an int or a tuple of two ints")
104
+
60
105
  self.fill_value = fill_value
61
106
  self.circle_axes = circle_axes
62
107
  self.with_batch_dim = with_batch_dim
@@ -64,6 +109,8 @@ class RandomCircleInclusion(layers.Layer):
64
109
  self.recovery_threshold = recovery_threshold
65
110
  self.randomize_location_across_batch = randomize_location_across_batch
66
111
  self.seed = seed
112
+ self.width_range = width_range
113
+ self.height_range = height_range
67
114
  self._axis1 = None
68
115
  self._axis2 = None
69
116
  self._perm = None
@@ -116,6 +163,43 @@ class RandomCircleInclusion(layers.Layer):
116
163
  self._static_w = int(permuted_shape[-1])
117
164
  self._static_shape = tuple(permuted_shape)
118
165
 
166
+ # Validate that ellipse can fit within image bounds
167
+ rx, ry = self.radius
168
+ min_required_width = 2 * rx + 1
169
+ min_required_height = 2 * ry + 1
170
+
171
+ if self._static_w < min_required_width:
172
+ raise ValueError(
173
+ f"Image width ({self._static_w}) is too small for radius {rx}. "
174
+ f"Minimum required width: {min_required_width}"
175
+ )
176
+ if self._static_h < min_required_height:
177
+ raise ValueError(
178
+ f"Image height ({self._static_h}) is too small for radius {ry}. "
179
+ f"Minimum required height: {min_required_height}"
180
+ )
181
+
182
+ # Validate width_range and height_range if provided
183
+ if self.width_range is not None:
184
+ min_x, max_x = self.width_range
185
+ if min_x >= max_x:
186
+ raise ValueError(f"width_range must have min < max, got {self.width_range}")
187
+ if min_x < rx or max_x > self._static_w - rx:
188
+ raise ValueError(
189
+ f"width_range {self.width_range} would place circle outside image bounds. "
190
+ f"Valid range: [{rx}, {self._static_w - rx})"
191
+ )
192
+
193
+ if self.height_range is not None:
194
+ min_y, max_y = self.height_range
195
+ if min_y >= max_y:
196
+ raise ValueError(f"height_range must have min < max, got {self.height_range}")
197
+ if min_y < ry or max_y > self._static_h - ry:
198
+ raise ValueError(
199
+ f"height_range {self.height_range} would place circle outside image bounds. "
200
+ f"Valid range: [{ry}, {self._static_h - ry})"
201
+ )
202
+
119
203
  super().build(input_shape)
120
204
 
121
205
  def compute_output_shape(self, input_shape):
@@ -165,7 +249,7 @@ class RandomCircleInclusion(layers.Layer):
165
249
  centers (Tensor): Tensor of shape (batch, 2) with circle centers.
166
250
  h (int): Height of the image.
167
251
  w (int): Width of the image.
168
- radius (int): Radius of the circle.
252
+ radius (tuple[int, int]): Radii of the ellipse (rx, ry).
169
253
  dtype (str or dtype): Data type for the mask.
170
254
 
171
255
  Returns:
@@ -176,12 +260,12 @@ class RandomCircleInclusion(layers.Layer):
176
260
  Y, X = ops.meshgrid(Y, X, indexing="ij")
177
261
  Y = ops.expand_dims(Y, 0) # (1, h, w)
178
262
  X = ops.expand_dims(X, 0) # (1, h, w)
179
- # cx = ops.cast(centers[:, 0], "float32")[:, None, None]
180
- # cy = ops.cast(centers[:, 1], "float32")[:, None, None]
181
263
  cx = centers[:, 0][:, None, None]
182
264
  cy = centers[:, 1][:, None, None]
183
- dist2 = (X - cx) ** 2 + (Y - cy) ** 2
184
- mask = ops.cast(dist2 <= radius**2, dtype)
265
+ rx, ry = radius
266
+ # Ellipse equation: ((X-cx)/rx)^2 + ((Y-cy)/ry)^2 <= 1
267
+ dist = ((X - cx) / rx) ** 2 + ((Y - cy) / ry) ** 2
268
+ mask = ops.cast(dist <= 1, dtype)
185
269
  return mask
186
270
 
187
271
  def call(self, x, seed=None):
@@ -197,9 +281,17 @@ class RandomCircleInclusion(layers.Layer):
197
281
  centers if return_centers is True.
198
282
  """
199
283
  if keras.backend.backend() == "jax" and not is_jax_prng_key(seed):
200
- raise NotImplementedError(
201
- "jax.random.key() is not supported, please use jax.random.PRNGKey()"
202
- )
284
+ if isinstance(seed, keras.random.SeedGenerator):
285
+ raise ValueError(
286
+ "When using JAX backend, please provide a jax.random.PRNGKey as seed, "
287
+ "instead of keras.random.SeedGenerator."
288
+ )
289
+ else:
290
+ raise TypeError(
291
+ f"When using JAX backend, seed must be a JAX PRNG key (created with "
292
+ f"jax.random.PRNGKey()), but got {type(seed)}. Note: jax.random.key() "
293
+ f"keys are not currently supported."
294
+ )
203
295
  seed = seed if seed is not None else self.seed
204
296
 
205
297
  if self.with_batch_dim:
@@ -209,22 +301,33 @@ class RandomCircleInclusion(layers.Layer):
209
301
  imgs, centers = ops.map(lambda arg: self._call(arg, seed), x)
210
302
  else:
211
303
  raise NotImplementedError(
212
- "You cannot fix circle locations across while using"
304
+ "You cannot fix circle locations across batch while using "
213
305
  + "RandomCircleInclusion as a dataset augmentation, "
214
306
  + "since samples in a batch are handled independently."
215
307
  )
216
308
  else:
309
+ batch_size = ops.shape(x)[0]
217
310
  if self.randomize_location_across_batch:
218
- batch_size = ops.shape(x)[0]
219
311
  seeds = split_seed(seed, batch_size)
220
- if all(seed is seeds[0] for seed in seeds):
312
+ if all(s is seeds[0] for s in seeds):
221
313
  imgs, centers = ops.map(lambda arg: self._call(arg, seeds[0]), x)
222
314
  else:
223
315
  imgs, centers = ops.map(
224
316
  lambda args: self._call(args[0], args[1]), (x, seeds)
225
317
  )
226
318
  else:
227
- imgs, centers = ops.map(lambda arg: self._call(arg, seed), x)
319
+ # Generate one random center that will be used for all batch elements
320
+ img0, center0 = self._call(x[0], seed)
321
+
322
+ # Apply the same center to all batch elements
323
+ imgs_list, centers_list = [img0], [center0]
324
+ for i in range(1, batch_size):
325
+ img_aug, center_out = self._call_with_fixed_center(x[i], center0)
326
+ imgs_list.append(img_aug)
327
+ centers_list.append(center_out)
328
+
329
+ imgs = ops.stack(imgs_list, axis=0)
330
+ centers = ops.stack(centers_list, axis=0)
228
331
  else:
229
332
  imgs, centers = self._call(x, seed)
230
333
 
@@ -248,17 +351,28 @@ class RandomCircleInclusion(layers.Layer):
248
351
  flat, flat_batch_size, h, w = self._flatten_batch_and_other_dims(x)
249
352
 
250
353
  def _draw_circle_2d(img2d):
354
+ rx, ry = self.radius
355
+ # Determine allowed ranges for center
356
+ if self.width_range is not None:
357
+ min_x, max_x = self.width_range
358
+ else:
359
+ min_x, max_x = rx, w - rx
360
+ if self.height_range is not None:
361
+ min_y, max_y = self.height_range
362
+ else:
363
+ min_y, max_y = ry, h - ry
364
+ # Ensure the ellipse fits within the allowed region
251
365
  cx = ops.cast(
252
- keras.random.uniform((), self.radius, w - self.radius, seed=seed),
366
+ keras.random.uniform((), min_x, max_x, seed=seed),
253
367
  "int32",
254
368
  )
255
369
  new_seed, _ = split_seed(seed, 2) # ensure that cx and cy are independent
256
370
  cy = ops.cast(
257
- keras.random.uniform((), self.radius, h - self.radius, seed=new_seed),
371
+ keras.random.uniform((), min_y, max_y, seed=new_seed),
258
372
  "int32",
259
373
  )
260
374
  mask = self._make_circle_mask(
261
- ops.stack([cx, cy])[None, :], h, w, self.radius, img2d.dtype
375
+ ops.stack([cx, cy])[None, :], h, w, (rx, ry), img2d.dtype
262
376
  )[0]
263
377
  img_aug = img2d * (1 - mask) + self.fill_value * mask
264
378
  center = ops.stack([cx, cy])
@@ -271,6 +385,67 @@ class RandomCircleInclusion(layers.Layer):
271
385
  centers = ops.reshape(centers, centers_shape)
272
386
  return (aug_imgs, centers)
273
387
 
388
+ def _apply_circle_mask(self, flat, center, h, w):
389
+ """Apply circle mask to flattened image data.
390
+
391
+ Args:
392
+ flat (Tensor): Flattened image data of shape (flat_batch, h, w).
393
+ center (Tensor): Center coordinates, either (2,) or (flat_batch, 2).
394
+ h (int): Height of images.
395
+ w (int): Width of images.
396
+
397
+ Returns:
398
+ Tensor: Augmented images with circle applied.
399
+ """
400
+ rx, ry = self.radius
401
+
402
+ # Ensure center has batch dimension for broadcasting
403
+ if len(center.shape) == 1:
404
+ # Single center (2,) -> broadcast to all slices
405
+ center_batched = ops.tile(ops.reshape(center, [1, 2]), [flat.shape[0], 1])
406
+ else:
407
+ # Already batched (flat_batch, 2)
408
+ center_batched = center
409
+
410
+ # Create masks for all slices using vectorized_map or broadcasting
411
+ masks = self._make_circle_mask(center_batched, h, w, (rx, ry), flat.dtype)
412
+
413
+ # Apply masks
414
+ aug_imgs = flat * (1 - masks) + self.fill_value * masks
415
+ return aug_imgs
416
+
417
+ def _call_with_fixed_center(self, x, fixed_center):
418
+ """Apply augmentation using a pre-determined center.
419
+
420
+ Args:
421
+ x (Tensor): Input image tensor.
422
+ fixed_center (Tensor): Pre-determined center coordinates, either (2,)
423
+ for a single center or (flat_batch, 2) for per-slice centers.
424
+
425
+ Returns:
426
+ tuple: (augmented image, center coordinates).
427
+ """
428
+ x = self._permute_axes_to_circle_last(x)
429
+ flat, flat_batch_size, h, w = self._flatten_batch_and_other_dims(x)
430
+
431
+ # Apply circle mask with fixed center (handles both single and batched centers)
432
+ aug_imgs = self._apply_circle_mask(flat, fixed_center, h, w)
433
+ aug_imgs = ops.reshape(aug_imgs, x.shape)
434
+ aug_imgs = ops.transpose(aug_imgs, axes=self._inv_perm)
435
+
436
+ # Return centers matching the expected shape
437
+ if len(fixed_center.shape) == 1:
438
+ # Single center (2,) -> broadcast to match flat_batch_size
439
+ if flat_batch_size == 1:
440
+ centers = fixed_center
441
+ else:
442
+ centers = ops.tile(ops.reshape(fixed_center, [1, 2]), [flat_batch_size, 1])
443
+ else:
444
+ # Already batched centers (flat_batch, 2)
445
+ centers = fixed_center
446
+
447
+ return (aug_imgs, centers)
448
+
274
449
  def get_config(self):
275
450
  """
276
451
  Get layer configuration for serialization.
@@ -285,6 +460,8 @@ class RandomCircleInclusion(layers.Layer):
285
460
  "fill_value": self.fill_value,
286
461
  "circle_axes": self.circle_axes,
287
462
  "return_centers": self.return_centers,
463
+ "width_range": self.width_range,
464
+ "height_range": self.height_range,
288
465
  }
289
466
  )
290
467
  return cfg
@@ -293,7 +470,8 @@ class RandomCircleInclusion(layers.Layer):
293
470
  self, images, centers, recovery_threshold, fill_value=None
294
471
  ):
295
472
  """
296
- Evaluate the percentage of the true circle that has been recovered in the images.
473
+ Evaluate the percentage of the true circle that has been recovered in the images,
474
+ and return a mask of the detected part of the circle.
297
475
 
298
476
  Args:
299
477
  images (Tensor): Tensor of images (any shape, with circle axes as specified).
@@ -302,8 +480,12 @@ class RandomCircleInclusion(layers.Layer):
302
480
  fill_value (float, optional): Optionally override fill_value for cases
303
481
  where image range has changed.
304
482
 
305
- Returns:
306
- Tensor: Percentage recovered for each circle (shape: [num_circles]).
483
+ Returns:
484
+ Tuple[Tensor, Tensor]:
485
+ - percent_recovered: [batch] - average recovery percentage per batch element,
486
+ averaged across all non-batch dimensions (e.g., frames, slices)
487
+ - recovered_masks: [batch, flat_batch, h, w] or [batch, h, w] or [flat_batch, h, w]
488
+ depending on input shape - binary mask of detected circle regions
307
489
  """
308
490
  fill_value = fill_value or self.fill_value
309
491
 
@@ -318,12 +500,23 @@ class RandomCircleInclusion(layers.Layer):
318
500
  recovered_sum = ops.sum(recovered, axis=[1, 2])
319
501
  mask_sum = ops.sum(mask, axis=[1, 2])
320
502
  percent_recovered = recovered_sum / (mask_sum + 1e-8)
321
- return percent_recovered
503
+ # recovered_mask: binary mask of detected part of the circle
504
+ recovered_mask = ops.cast(recovered > 0, flat_image.dtype)
505
+ return percent_recovered, recovered_mask
322
506
 
323
507
  if self.with_batch_dim:
324
- return ops.vectorized_map(
508
+ results = ops.vectorized_map(
325
509
  lambda args: _evaluate_recovered_circle_accuracy(args[0], args[1]),
326
510
  (images, centers),
327
- )[..., 0]
511
+ )
512
+ percent_recovered, recovered_masks = results
513
+ # If there are multiple circles per batch element (e.g., multiple frames/slices),
514
+ # take the mean across all non-batch dimensions to get one value per batch element
515
+ if len(percent_recovered.shape) > 1:
516
+ # Average over all axes except the batch dimension (axis 0)
517
+ axes_to_reduce = tuple(range(1, len(percent_recovered.shape)))
518
+ percent_recovered = ops.mean(percent_recovered, axis=axes_to_reduce)
519
+ return percent_recovered, recovered_masks
328
520
  else:
329
- return _evaluate_recovered_circle_accuracy(images, centers)
521
+ percent_recovered, recovered_mask = _evaluate_recovered_circle_accuracy(images, centers)
522
+ return percent_recovered, recovered_mask
zea/data/convert/__init__.py CHANGED
@@ -1,6 +1 @@
1
- """All functions in this module are used to convert different datasets to the zea format."""
2
-
3
- from .camus import convert_camus
4
- from .images import convert_image_dataset
5
- from .matlab import zea_from_matlab_raw
6
- from .picmus import convert_picmus
1
+ """Data conversion of datasets to the ``zea`` data format."""