zea 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
zea/ops.py CHANGED
@@ -380,7 +380,8 @@ class Pipeline:
380
380
  validate=True,
381
381
  timed: bool = False,
382
382
  ):
383
- """Initialize a pipeline
383
+ """
384
+ Initialize a pipeline.
384
385
 
385
386
  Args:
386
387
  operations (list): A list of Operation instances representing the operations
@@ -388,15 +389,22 @@ class Pipeline:
388
389
  with_batch_dim (bool, optional): Whether operations should expect a batch dimension.
389
390
  Defaults to True.
390
391
  jit_options (str, optional): The JIT options to use. Must be "pipeline", "ops", or None.
391
- - "pipeline" compiles the entire pipeline as a single function.
392
- This may be faster but, does not preserve python control flow, such as caching.
393
- - "ops" compiles each operation separately. This preserves python control flow and
394
- caching functionality, but speeds up the operations.
395
- - None disables JIT compilation.
392
+
393
+ - "pipeline": compiles the entire pipeline as a single function.
394
+ This may be faster but does not preserve python control flow, such as caching.
395
+
396
+ - "ops": compiles each operation separately. This preserves python control flow and
397
+ caching functionality, but speeds up the operations.
398
+
399
+ - None: disables JIT compilation.
400
+
396
401
  Defaults to "ops".
402
+
397
403
  jit_kwargs (dict, optional): Additional keyword arguments for the JIT compiler.
398
404
  name (str, optional): The name of the pipeline. Defaults to "pipeline".
399
405
  validate (bool, optional): Whether to validate the pipeline. Defaults to True.
406
+ timed (bool, optional): Whether to time each operation. Defaults to False.
407
+
400
408
  """
401
409
  self._call_pipeline = self.call
402
410
  self.name = name
@@ -1162,11 +1170,16 @@ class PatchedGrid(Pipeline):
1162
1170
  With this class you can form a pipeline that will be applied to patches of the grid.
1163
1171
  This is useful to avoid OOM errors when processing large grids.
1164
1172
 
1165
- Somethings to NOTE about this class:
1166
- - The ops have to use flatgrid and flat_pfield as inputs, these will be patched.
1167
- - Changing anything other than `self.output_data_type` in the dict will not be propagated!
1168
- - Will be jitted as a single operation, not the individual operations.
1169
- - This class handles the batching.
1173
+ Some things to NOTE about this class:
1174
+
1175
+ - The ops have to use flatgrid and flat_pfield as inputs, these will be patched.
1176
+
1177
+ - Changing anything other than `self.output_data_type` in the dict will not be propagated!
1178
+
1179
+ - Will be jitted as a single operation, not the individual operations.
1180
+
1181
+ - This class handles the batching.
1182
+
1170
1183
  """
1171
1184
 
1172
1185
  def __init__(self, *args, num_patches=10, **kwargs):
@@ -1376,20 +1389,6 @@ class Mean(Operation):
1376
1389
  return kwargs
1377
1390
 
1378
1391
 
1379
- @ops_registry("transpose")
1380
- class Transpose(Operation):
1381
- """Transpose the input data along the specified axes."""
1382
-
1383
- def __init__(self, axes, **kwargs):
1384
- super().__init__(**kwargs)
1385
- self.axes = axes
1386
-
1387
- def call(self, **kwargs):
1388
- data = kwargs[self.key]
1389
- transposed_data = ops.transpose(data, axes=self.axes)
1390
- return {self.output_key: transposed_data}
1391
-
1392
-
1393
1392
  @ops_registry("simulate_rf")
1394
1393
  class Simulate(Operation):
1395
1394
  """Simulate RF data."""
@@ -1578,19 +1577,6 @@ class PfieldWeighting(Operation):
1578
1577
  return {self.output_key: weighted_data}
1579
1578
 
1580
1579
 
1581
- @ops_registry("sum")
1582
- class Sum(Operation):
1583
- """Sum data along a specific axis."""
1584
-
1585
- def __init__(self, axis, **kwargs):
1586
- super().__init__(**kwargs)
1587
- self.axis = axis
1588
-
1589
- def call(self, **kwargs):
1590
- data = kwargs[self.key]
1591
- return {self.output_key: ops.sum(data, axis=self.axis)}
1592
-
1593
-
1594
1580
  @ops_registry("delay_and_sum")
1595
1581
  class DelayAndSum(Operation):
1596
1582
  """Sums time-delayed signals along channels and transmits."""
@@ -2124,29 +2110,37 @@ class Demodulate(Operation):
2124
2110
  class Lambda(Operation):
2125
2111
  """Use any function as an operation."""
2126
2112
 
2127
- def __init__(self, func, func_kwargs=None, **kwargs):
2128
- super().__init__(**kwargs)
2129
- func_kwargs = func_kwargs or {}
2130
- self.func = partial(func, **func_kwargs)
2113
+ def __init__(self, func, **kwargs):
2114
+ # Split kwargs into kwargs for partial and __init__
2115
+ op_kwargs = {k: v for k, v in kwargs.items() if k not in func.__code__.co_varnames}
2116
+ func_kwargs = {k: v for k, v in kwargs.items() if k in func.__code__.co_varnames}
2117
+ Lambda._check_if_unary(func, **func_kwargs)
2131
2118
 
2132
- def call(self, **kwargs):
2133
- data = kwargs[self.key]
2134
- data = self.func(data)
2135
- return {self.output_key: data}
2136
-
2137
-
2138
- @ops_registry("clip")
2139
- class Clip(Operation):
2140
- """Clip the input data to a given range."""
2119
+ super().__init__(**op_kwargs)
2120
+ self.func = partial(func, **func_kwargs)
2141
2121
 
2142
- def __init__(self, min_value=None, max_value=None, **kwargs):
2143
- super().__init__(**kwargs)
2144
- self.min_value = min_value
2145
- self.max_value = max_value
2122
+ @staticmethod
2123
+ def _check_if_unary(func, **kwargs):
2124
+ """Checks if the kwargs are sufficient to call the function as a unary operation."""
2125
+ sig = inspect.signature(func)
2126
+ # Remove arguments that are already provided in func_kwargs
2127
+ params = list(sig.parameters.values())
2128
+ remaining = [p for p in params if p.name not in kwargs]
2129
+ # Count required positional arguments (excluding self/cls)
2130
+ required_positional = [
2131
+ p
2132
+ for p in remaining
2133
+ if p.default is p.empty and p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
2134
+ ]
2135
+ if len(required_positional) != 1:
2136
+ raise ValueError(
2137
+ f"Partial of {func.__name__} must be callable with exactly one required "
2138
+ f"positional argument, we still need: {required_positional}."
2139
+ )
2146
2140
 
2147
2141
  def call(self, **kwargs):
2148
2142
  data = kwargs[self.key]
2149
- data = ops.clip(data, self.min_value, self.max_value)
2143
+ data = self.func(data)
2150
2144
  return {self.output_key: data}
2151
2145
 
2152
2146
 
@@ -2685,6 +2679,7 @@ class AnisotropicDiffusion(Operation):
2685
2679
  return result
2686
2680
 
2687
2681
 
2682
+ @ops_registry("channels_to_complex")
2688
2683
  class ChannelsToComplex(Operation):
2689
2684
  def call(self, **kwargs):
2690
2685
  data = kwargs[self.key]
@@ -2692,6 +2687,7 @@ class ChannelsToComplex(Operation):
2692
2687
  return {self.output_key: output}
2693
2688
 
2694
2689
 
2690
+ @ops_registry("complex_to_channels")
2695
2691
  class ComplexToChannels(Operation):
2696
2692
  def __init__(self, axis=-1, **kwargs):
2697
2693
  super().__init__(**kwargs)
zea/scan.py CHANGED
@@ -149,12 +149,13 @@ class Scan(Parameters):
149
149
  Defaults to 0.0.
150
150
  attenuation_coef (float, optional): Attenuation coefficient in dB/(MHz*cm).
151
151
  Defaults to 0.0.
152
- selected_transmits (None, str, int, list, or np.ndarray, optional):
152
+ selected_transmits (None, str, int, list, slice, or np.ndarray, optional):
153
153
  Specifies which transmit events to select.
154
154
  - None or "all": Use all transmits.
155
155
  - "center": Use only the center transmit.
156
156
  - int: Select this many evenly spaced transmits.
157
157
  - list/array: Use these specific transmit indices.
158
+ - slice: Use transmits specified by the slice (e.g., slice(0, 10, 2)).
158
159
  grid_type (str, optional): Type of grid to use for beamforming.
159
160
  Can be "cartesian" or "polar". Defaults to "cartesian".
160
161
  dynamic_range (tuple, optional): Dynamic range for image display.
@@ -171,13 +172,14 @@ class Scan(Parameters):
171
172
  "pixels_per_wavelength": {"type": int, "default": 4},
172
173
  "pfield_kwargs": {"type": dict, "default": {}},
173
174
  "apply_lens_correction": {"type": bool, "default": False},
174
- "lens_sound_speed": {"type": (float, int)},
175
+ "lens_sound_speed": {"type": float},
175
176
  "lens_thickness": {"type": float},
176
177
  "grid_type": {"type": str, "default": "cartesian"},
177
178
  "polar_limits": {"type": (tuple, list)},
178
179
  "dynamic_range": {"type": (tuple, list), "default": DEFAULT_DYNAMIC_RANGE},
180
+ "selected_transmits": {"type": (type(None), str, int, list, slice, np.ndarray)},
179
181
  # acquisition parameters
180
- "sound_speed": {"type": (float, int), "default": 1540.0},
182
+ "sound_speed": {"type": float, "default": 1540.0},
181
183
  "sampling_frequency": {"type": float},
182
184
  "center_frequency": {"type": float},
183
185
  "n_el": {"type": int},
@@ -359,6 +361,7 @@ class Scan(Parameters):
359
361
  - "center": Use only the center transmit
360
362
  - int: Select this many evenly spaced transmits
361
363
  - list/array: Use these specific transmit indices
364
+ - slice: Use transmits specified by the slice (e.g., slice(0, 10, 2))
362
365
 
363
366
  Returns:
364
367
  The current instance for method chaining.
@@ -416,6 +419,10 @@ class Scan(Parameters):
416
419
  self._invalidate_dependents("selected_transmits")
417
420
  return self
418
421
 
422
+ # Handle slice - convert to list of indices
423
+ if isinstance(selection, slice):
424
+ selection = list(range(n_tx_total))[selection]
425
+
419
426
  # Handle list of indices
420
427
  if isinstance(selection, list):
421
428
  # Validate indices
zea/tensor_ops.py CHANGED
@@ -130,6 +130,102 @@ def extend_n_dims(arr, axis, n_dims):
130
130
  return ops.reshape(arr, new_shape)
131
131
 
132
132
 
133
+ def vmap(fun, in_axes=0, out_axes=0):
134
+ """Vectorized map.
135
+
136
+ For torch and jax backends, this uses the native vmap implementation.
137
+ For other backends, this a wrapper that uses `ops.vectorized_map` under the hood.
138
+
139
+ Args:
140
+ fun: The function to be mapped.
141
+ in_axes: The axis or axes to be mapped over in the input.
142
+ Can be an integer, a tuple of integers, or None.
143
+ If None, the corresponding argument is not mapped over.
144
+ Defaults to 0.
145
+ out_axes: The axis or axes to be mapped over in the output.
146
+ Can be an integer, a tuple of integers, or None.
147
+ If None, the corresponding output is not mapped over.
148
+ Defaults to 0.
149
+
150
+ Returns:
151
+ A function that applies `fun` in a vectorized manner over the specified axes.
152
+
153
+ Raises:
154
+ ValueError: If the backend does not support vmap.
155
+ """
156
+ if keras.backend.backend() == "jax":
157
+ import jax
158
+
159
+ return jax.vmap(fun, in_axes=in_axes, out_axes=out_axes)
160
+ elif keras.backend.backend() == "torch":
161
+ import torch
162
+
163
+ return torch.vmap(fun, in_dims=in_axes, out_dims=out_axes)
164
+ else:
165
+ return manual_vmap(fun, in_axes=in_axes, out_axes=out_axes)
166
+
167
+
168
+ def manual_vmap(fun, in_axes=0, out_axes=0):
169
+ """Manual vectorized map for backends that do not support vmap."""
170
+
171
+ def find_map_length(args, in_axes):
172
+ """Find the length of the axis to map over."""
173
+ # NOTE: only needed for numpy, the other backends can handle a singleton dimension
174
+ for arg, axis in zip(args, in_axes):
175
+ if axis is None:
176
+ continue
177
+
178
+ return ops.shape(arg)[axis]
179
+ return 1
180
+
181
+ def _moveaxes(args, in_axes, out_axes):
182
+ """Move axes of the input arguments."""
183
+ args = list(args)
184
+ for i, (arg, in_axis, out_axis) in enumerate(zip(args, in_axes, out_axes)):
185
+ if in_axis is not None:
186
+ args[i] = ops.moveaxis(arg, in_axis, out_axis)
187
+ else:
188
+ args[i] = ops.repeat(arg[None], find_map_length(args, in_axes), axis=out_axis)
189
+ return tuple(args)
190
+
191
+ def _fun(args):
192
+ return fun(*args)
193
+
194
+ def wrapper(*args):
195
+ # If in_axes or out_axes is an int, convert to tuple
196
+ if isinstance(in_axes, int):
197
+ _in_axes = (in_axes,) * len(args)
198
+ else:
199
+ _in_axes = in_axes
200
+ if isinstance(out_axes, int):
201
+ _out_axes = (out_axes,) * len(args)
202
+ else:
203
+ _out_axes = out_axes
204
+ zeros = (0,) * len(args)
205
+
206
+ # Check that in_axes and out_axes are tuples
207
+ if not isinstance(_in_axes, tuple):
208
+ raise ValueError("in_axes must be an int or a tuple of ints.")
209
+ if not isinstance(_out_axes, tuple):
210
+ raise ValueError("out_axes must be an int or a tuple of ints.")
211
+
212
+ args = _moveaxes(args, _in_axes, zeros)
213
+ outputs = ops.vectorized_map(_fun, tuple(args))
214
+
215
+ tuple_output = isinstance(outputs, (tuple, list))
216
+ if not tuple_output:
217
+ outputs = (outputs,)
218
+
219
+ outputs = _moveaxes(outputs, zeros, _out_axes)
220
+
221
+ if not tuple_output:
222
+ outputs = outputs[0]
223
+
224
+ return outputs
225
+
226
+ return wrapper
227
+
228
+
133
229
  def func_with_one_batch_dim(
134
230
  func,
135
231
  tensor,
@@ -1270,6 +1366,14 @@ def L2(x):
1270
1366
  return ops.sqrt(ops.sum(x**2))
1271
1367
 
1272
1368
 
1369
+ def L1(x):
1370
+ """L1 norm of a tensor.
1371
+
1372
+ Implementation of L1 norm: https://mathworld.wolfram.com/L1-Norm.html
1373
+ """
1374
+ return ops.sum(ops.abs(x))
1375
+
1376
+
1273
1377
  def linear_sum_assignment(cost):
1274
1378
  """Greedy linear sum assignment.
1275
1379
 
@@ -1325,3 +1429,150 @@ else:
1325
1429
  def safe_vectorize(pyfunc, excluded=None, signature=None):
1326
1430
  """Just a wrapper around ops.vectorize."""
1327
1431
  return ops.vectorize(pyfunc, excluded=excluded, signature=signature)
1432
+
1433
+
1434
+ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
1435
+ """Apply a function to 1D array slices along an axis.
1436
+
1437
+ Keras implementation of numpy.apply_along_axis using keras.ops.vectorized_map.
1438
+
1439
+ Args:
1440
+ func1d: A callable function with signature ``func1d(arr, /, *args, **kwargs)``
1441
+ where ``*args`` and ``**kwargs`` are the additional positional and keyword
1442
+ arguments passed to apply_along_axis.
1443
+ axis: Integer axis along which to apply the function.
1444
+ arr: The array over which to apply the function.
1445
+ *args: Additional positional arguments passed through to func1d.
1446
+ **kwargs: Additional keyword arguments passed through to func1d.
1447
+
1448
+ Returns:
1449
+ The result of func1d applied along the specified axis.
1450
+ """
1451
+ # Convert to keras tensor
1452
+ arr = ops.convert_to_tensor(arr)
1453
+
1454
+ # Get array dimensions
1455
+ num_dims = len(arr.shape)
1456
+
1457
+ # Canonicalize axis (handle negative indices)
1458
+ if axis < 0:
1459
+ axis = num_dims + axis
1460
+
1461
+ if axis < 0 or axis >= num_dims:
1462
+ raise ValueError(f"axis {axis} is out of bounds for array of dimension {num_dims}")
1463
+
1464
+ # Create a wrapper function that applies func1d with the additional arguments
1465
+ def func(slice_arr):
1466
+ return func1d(slice_arr, *args, **kwargs)
1467
+
1468
+ # Recursively build up vectorized maps following the JAX pattern
1469
+ # For dimensions after the target axis (right side)
1470
+ for i in range(1, num_dims - axis):
1471
+ prev_func = func
1472
+
1473
+ def make_func(f, dim_offset):
1474
+ def vectorized_func(x):
1475
+ # Move the dimension we want to map over to the front
1476
+ perm = list(range(len(x.shape)))
1477
+ perm[0], perm[dim_offset] = perm[dim_offset], perm[0]
1478
+ x_moved = ops.transpose(x, perm)
1479
+ result = ops.vectorized_map(f, x_moved)
1480
+ # Move the result dimension back if needed
1481
+ if len(result.shape) > 0:
1482
+ result_perm = list(range(len(result.shape)))
1483
+ if len(result_perm) > dim_offset:
1484
+ result_perm[0], result_perm[dim_offset] = (
1485
+ result_perm[dim_offset],
1486
+ result_perm[0],
1487
+ )
1488
+ result = ops.transpose(result, result_perm)
1489
+ return result
1490
+
1491
+ return vectorized_func
1492
+
1493
+ func = make_func(prev_func, i)
1494
+
1495
+ # For dimensions before the target axis (left side)
1496
+ for i in range(axis):
1497
+ prev_func = func
1498
+
1499
+ def make_func(f):
1500
+ return lambda x: ops.vectorized_map(f, x)
1501
+
1502
+ func = make_func(prev_func)
1503
+
1504
+ return func(arr)
1505
+
1506
+
1507
+ def correlate(x, y, mode="full"):
1508
+ """
1509
+ Complex correlation via splitting real and imaginary parts.
1510
+ Equivalent to np.correlate(x, y, mode).
1511
+
1512
+ NOTE: this function exists because tensorflow does not support complex correlation.
1513
+ NOTE: tensorflow also handles padding differently than numpy, so we manually pad the input.
1514
+
1515
+ Args:
1516
+ x: np.ndarray (complex or real)
1517
+ y: np.ndarray (complex or real)
1518
+ mode: "full", "valid", or "same"
1519
+ """
1520
+ x = ops.convert_to_tensor(x)
1521
+ y = ops.convert_to_tensor(y)
1522
+
1523
+ is_complex = "complex" in ops.dtype(x) or "complex" in ops.dtype(y)
1524
+
1525
+ # Cast to complex64 if real
1526
+ if not is_complex:
1527
+ x = ops.cast(x, "complex64")
1528
+ y = ops.cast(y, "complex64")
1529
+
1530
+ # Split into real and imaginary
1531
+ xr, xi = ops.real(x), ops.imag(x)
1532
+ yr, yi = ops.real(y), ops.imag(y)
1533
+
1534
+ # Pad to do full correlation
1535
+ pad_left = ops.shape(y)[0] - 1
1536
+ pad_right = ops.shape(y)[0] - 1
1537
+ xr = ops.pad(xr, [[pad_left, pad_right]])
1538
+ xi = ops.pad(xi, [[pad_left, pad_right]])
1539
+
1540
+ # Correlation: sum over x[n] * conj(y[n+k])
1541
+ rr = ops.correlate(xr, yr, mode="valid")
1542
+ ii = ops.correlate(xi, yi, mode="valid")
1543
+ ri = ops.correlate(xr, yi, mode="valid")
1544
+ ir = ops.correlate(xi, yr, mode="valid")
1545
+
1546
+ real_part = rr + ii
1547
+ imag_part = ir - ri
1548
+
1549
+ real_part = ops.cast(real_part, "complex64")
1550
+ imag_part = ops.cast(imag_part, "complex64")
1551
+
1552
+ complex_tensor = real_part + 1j * imag_part
1553
+
1554
+ # Extract relevant part based on mode
1555
+ full_length = ops.shape(real_part)[0]
1556
+ x_len = ops.shape(x)[0]
1557
+ y_len = ops.shape(y)[0]
1558
+
1559
+ if mode == "same":
1560
+ # Return output of length max(M, N)
1561
+ target_len = ops.maximum(x_len, y_len)
1562
+ start = ops.floor((full_length - target_len) / 2)
1563
+ start = ops.cast(start, "int32")
1564
+ end = start + target_len
1565
+ complex_tensor = complex_tensor[start:end]
1566
+ elif mode == "valid":
1567
+ # Return output of length max(M, N) - min(M, N) + 1
1568
+ target_len = ops.maximum(x_len, y_len) - ops.minimum(x_len, y_len) + 1
1569
+ start = ops.ceil((full_length - target_len) / 2)
1570
+ start = ops.cast(start, "int32")
1571
+ end = start + target_len
1572
+ complex_tensor = complex_tensor[start:end]
1573
+ # For "full" mode, use the entire result (no slicing needed)
1574
+
1575
+ if is_complex:
1576
+ return complex_tensor
1577
+ else:
1578
+ return ops.real(complex_tensor)
@@ -83,8 +83,7 @@ def filter_edge_points_by_boundary(edge_points, is_left=True, min_cone_half_angl
83
83
 
84
84
 
85
85
  def detect_cone_parameters(image, min_cone_half_angle_deg=20, threshold=15):
86
- """
87
- Detect the ultrasound cone parameters from a grayscale image.
86
+ """Detect the ultrasound cone parameters from a grayscale image.
88
87
 
89
88
  This function performs the following steps:
90
89
  1. Thresholds the image to create a binary mask
@@ -209,6 +208,7 @@ def detect_cone_parameters(image, min_cone_half_angle_deg=20, threshold=15):
209
208
  apex_x = left_a + left_b * apex_y
210
209
 
211
210
  # Calculate cone height
211
+ max_y = ops.cast(max_y, apex_y.dtype)
212
212
  cone_height = max_y - apex_y
213
213
 
214
214
  # Calculate opening angle from the line slopes
@@ -250,10 +250,13 @@ def add_shape_from_mask(ax, mask, **kwargs):
250
250
  Returns:
251
251
  plt.ax: matplotlib axis with shape added
252
252
  """
253
- # Create a Path patch
254
- contours = measure.find_contours(mask, 0.5)
253
+ # Pad mask to ensure edge contours are found
254
+ padded_mask = np.pad(mask, pad_width=1, mode="constant", constant_values=0)
255
+ contours = measure.find_contours(padded_mask, 0.5)
255
256
  patches = []
256
257
  for contour in contours:
258
+ # Remove padding offset
259
+ contour -= 1
257
260
  path = pltPath(contour[:, ::-1])
258
261
  patch = PathPatch(path, **kwargs)
259
262
  patches.append(ax.add_patch(patch))
@@ -593,10 +596,11 @@ def interpolate_masks(
593
596
  assert all(mask.shape == mask_shape for mask in masks), "All masks must have the same shape."
594
597
 
595
598
  # distribute number of frames over number of masks
596
- num_frames_per_segment = [num_frames // (number_of_masks - 1)] * (number_of_masks - 1)
597
- if num_frames % num_frames_per_segment[0] != 0:
598
- # make sure that number of frames per mask adds up to total number of frames
599
- num_frames_per_segment[-1] += num_frames - sum(num_frames_per_segment)
599
+ base_frames = num_frames // (number_of_masks - 1)
600
+ remainder = num_frames % (number_of_masks - 1)
601
+ num_frames_per_segment = [base_frames] * (number_of_masks - 1)
602
+ for i in range(remainder):
603
+ num_frames_per_segment[i] += 1
600
604
 
601
605
  if rectangle:
602
606
  # get the rectangles
@@ -615,7 +619,6 @@ def interpolate_masks(
615
619
  for _rectangle in rectangles:
616
620
  interpolated_masks.append(reconstruct_mask_from_rectangle(_rectangle, mask_shape))
617
621
  return interpolated_masks
618
-
619
622
  # get the contours
620
623
  polygons = []
621
624
  for mask in masks:
@@ -726,6 +729,17 @@ def update_imshow_with_mask(
726
729
  return imshow_obj, mask_obj
727
730
 
728
731
 
732
+ def ask_for_title():
733
+ print("What are you selecting?")
734
+ title = input("Enter a title for the selection: ")
735
+ if not title:
736
+ raise ValueError("Title cannot be empty.")
737
+ # Convert title to snake_case
738
+ title = title.strip().replace(" ", "_").lower()
739
+ print(f"Title set to: {title}")
740
+ return title
741
+
742
+
729
743
  def main():
730
744
  """Main function for interactive selector on multiple images."""
731
745
  print(
@@ -751,6 +765,7 @@ def main():
751
765
  raise e
752
766
  print("No more images selected. Continuing...")
753
767
 
768
+ title = ask_for_title()
754
769
  selector = ask_for_selection_tool()
755
770
 
756
771
  if same_images is True:
@@ -829,6 +844,10 @@ def main():
829
844
  else:
830
845
  add_shape_from_mask(axs, interpolated_masks[0], alpha=0.5)
831
846
 
847
+ filestem = Path(file.parent / f"{file.stem}_{title}_annotations.gif")
848
+ np.save(filestem.with_suffix(".npy"), interpolated_masks)
849
+ print(f"Succesfully saved interpolated masks to {log.yellow(filestem.with_suffix('.npy'))}")
850
+
832
851
  fps = ask_save_animation_with_fps()
833
852
 
834
853
  ani = FuncAnimation(
@@ -838,9 +857,9 @@ def main():
838
857
  fargs=(axs, imshow_obj, images, interpolated_masks, selector),
839
858
  interval=1000 / fps,
840
859
  )
841
- filename = Path(file.parent.stem + "_" + f"{file.stem}_interpolated_masks.gif")
860
+ filename = filestem.with_suffix(".gif")
842
861
  ani.save(filename, writer="pillow")
843
- print(f"Succesfully saved animation as {filename}")
862
+ print(f"Succesfully saved animation as {log.yellow(filename)}")
844
863
 
845
864
 
846
865
  if __name__ == "__main__":
@@ -1,7 +1,8 @@
1
- Metadata-Version: 2.3
1
+ Metadata-Version: 2.4
2
2
  Name: zea
3
- Version: 0.0.4
3
+ Version: 0.0.6
4
4
  Summary: A Toolbox for Cognitive Ultrasound Imaging. Provides a set of tools for processing of ultrasound data, all built in your favorite machine learning framework.
5
+ License-File: LICENSE
5
6
  Keywords: ultrasound,machine learning,beamforming
6
7
  Author: Tristan Stevens
7
8
  Author-email: t.s.w.stevens@tue.nl
@@ -21,6 +22,7 @@ Provides-Extra: display
21
22
  Provides-Extra: display-headless
22
23
  Provides-Extra: docs
23
24
  Provides-Extra: jax
25
+ Provides-Extra: models
24
26
  Provides-Extra: tests
25
27
  Requires-Dist: IPython ; extra == "dev"
26
28
  Requires-Dist: IPython ; extra == "docs"
@@ -33,13 +35,14 @@ Requires-Dist: furo ; extra == "dev"
33
35
  Requires-Dist: furo ; extra == "docs"
34
36
  Requires-Dist: h5py (>=3.11)
35
37
  Requires-Dist: huggingface_hub (>=0.26)
38
+ Requires-Dist: imageio[ffmpeg] (>=2.0)
36
39
  Requires-Dist: ipykernel (>=6.29.5) ; extra == "dev"
37
40
  Requires-Dist: ipykernel (>=6.29.5) ; extra == "tests"
38
41
  Requires-Dist: ipywidgets ; extra == "dev"
39
42
  Requires-Dist: ipywidgets ; extra == "tests"
40
43
  Requires-Dist: jax ; extra == "backends"
41
44
  Requires-Dist: jax[cuda12-pip] (>=0.4.26) ; extra == "jax"
42
- Requires-Dist: keras (>=3.9)
45
+ Requires-Dist: keras (>=3.11)
43
46
  Requires-Dist: matplotlib (>=3.8)
44
47
  Requires-Dist: mock ; extra == "dev"
45
48
  Requires-Dist: mock ; extra == "docs"
@@ -48,6 +51,8 @@ Requires-Dist: myst-parser ; extra == "docs"
48
51
  Requires-Dist: nbsphinx ; extra == "dev"
49
52
  Requires-Dist: nbsphinx ; extra == "docs"
50
53
  Requires-Dist: numpy (>=1.24)
54
+ Requires-Dist: onnxruntime (>=1.15) ; extra == "dev"
55
+ Requires-Dist: onnxruntime (>=1.15) ; extra == "models"
51
56
  Requires-Dist: opencv-python (>=4) ; extra == "display"
52
57
  Requires-Dist: opencv-python-headless (>=4) ; extra == "dev"
53
58
  Requires-Dist: opencv-python-headless (>=4) ; extra == "display-headless"
@@ -69,6 +74,8 @@ Requires-Dist: scikit-learn (>=1.4)
69
74
  Requires-Dist: scipy (>=1.13)
70
75
  Requires-Dist: sphinx ; extra == "dev"
71
76
  Requires-Dist: sphinx ; extra == "docs"
77
+ Requires-Dist: sphinx-argparse ; extra == "dev"
78
+ Requires-Dist: sphinx-argparse ; extra == "docs"
72
79
  Requires-Dist: sphinx-autobuild ; extra == "dev"
73
80
  Requires-Dist: sphinx-autobuild ; extra == "docs"
74
81
  Requires-Dist: sphinx-autodoc-typehints ; extra == "dev"