zea 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +8 -7
- zea/__main__.py +8 -26
- zea/agent/selection.py +166 -0
- zea/backend/__init__.py +89 -0
- zea/backend/jax/__init__.py +14 -51
- zea/backend/tensorflow/__init__.py +0 -49
- zea/backend/torch/__init__.py +27 -62
- zea/data/__main__.py +6 -3
- zea/data/file.py +19 -74
- zea/data/layers.py +2 -3
- zea/display.py +1 -5
- zea/doppler.py +75 -0
- zea/internal/_generate_keras_ops.py +125 -0
- zea/internal/core.py +10 -3
- zea/internal/device.py +33 -16
- zea/internal/notebooks.py +39 -0
- zea/internal/operators.py +10 -0
- zea/internal/parameters.py +75 -19
- zea/internal/registry.py +1 -1
- zea/internal/viewer.py +24 -24
- zea/io_lib.py +60 -62
- zea/keras_ops.py +1989 -0
- zea/metrics.py +357 -65
- zea/models/__init__.py +6 -3
- zea/models/deeplabv3.py +131 -0
- zea/models/diffusion.py +18 -18
- zea/models/echonetlvh.py +279 -0
- zea/models/lv_segmentation.py +79 -0
- zea/models/presets.py +50 -0
- zea/models/regional_quality.py +122 -0
- zea/ops.py +52 -56
- zea/scan.py +10 -3
- zea/tensor_ops.py +251 -0
- zea/tools/fit_scan_cone.py +2 -2
- zea/tools/selection_tool.py +28 -9
- {zea-0.0.4.dist-info → zea-0.0.6.dist-info}/METADATA +10 -3
- {zea-0.0.4.dist-info → zea-0.0.6.dist-info}/RECORD +40 -33
- {zea-0.0.4.dist-info → zea-0.0.6.dist-info}/WHEEL +1 -1
- zea/internal/convert.py +0 -150
- {zea-0.0.4.dist-info → zea-0.0.6.dist-info}/entry_points.txt +0 -0
- {zea-0.0.4.dist-info → zea-0.0.6.dist-info/licenses}/LICENSE +0 -0
zea/ops.py
CHANGED
|
@@ -380,7 +380,8 @@ class Pipeline:
|
|
|
380
380
|
validate=True,
|
|
381
381
|
timed: bool = False,
|
|
382
382
|
):
|
|
383
|
-
"""
|
|
383
|
+
"""
|
|
384
|
+
Initialize a pipeline.
|
|
384
385
|
|
|
385
386
|
Args:
|
|
386
387
|
operations (list): A list of Operation instances representing the operations
|
|
@@ -388,15 +389,22 @@ class Pipeline:
|
|
|
388
389
|
with_batch_dim (bool, optional): Whether operations should expect a batch dimension.
|
|
389
390
|
Defaults to True.
|
|
390
391
|
jit_options (str, optional): The JIT options to use. Must be "pipeline", "ops", or None.
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
-
|
|
392
|
+
|
|
393
|
+
- "pipeline": compiles the entire pipeline as a single function.
|
|
394
|
+
This may be faster but does not preserve python control flow, such as caching.
|
|
395
|
+
|
|
396
|
+
- "ops": compiles each operation separately. This preserves python control flow and
|
|
397
|
+
caching functionality, but speeds up the operations.
|
|
398
|
+
|
|
399
|
+
- None: disables JIT compilation.
|
|
400
|
+
|
|
396
401
|
Defaults to "ops".
|
|
402
|
+
|
|
397
403
|
jit_kwargs (dict, optional): Additional keyword arguments for the JIT compiler.
|
|
398
404
|
name (str, optional): The name of the pipeline. Defaults to "pipeline".
|
|
399
405
|
validate (bool, optional): Whether to validate the pipeline. Defaults to True.
|
|
406
|
+
timed (bool, optional): Whether to time each operation. Defaults to False.
|
|
407
|
+
|
|
400
408
|
"""
|
|
401
409
|
self._call_pipeline = self.call
|
|
402
410
|
self.name = name
|
|
@@ -1162,11 +1170,16 @@ class PatchedGrid(Pipeline):
|
|
|
1162
1170
|
With this class you can form a pipeline that will be applied to patches of the grid.
|
|
1163
1171
|
This is useful to avoid OOM errors when processing large grids.
|
|
1164
1172
|
|
|
1165
|
-
|
|
1166
|
-
|
|
1167
|
-
|
|
1168
|
-
|
|
1169
|
-
|
|
1173
|
+
Some things to NOTE about this class:
|
|
1174
|
+
|
|
1175
|
+
- The ops have to use flatgrid and flat_pfield as inputs, these will be patched.
|
|
1176
|
+
|
|
1177
|
+
- Changing anything other than `self.output_data_type` in the dict will not be propagated!
|
|
1178
|
+
|
|
1179
|
+
- Will be jitted as a single operation, not the individual operations.
|
|
1180
|
+
|
|
1181
|
+
- This class handles the batching.
|
|
1182
|
+
|
|
1170
1183
|
"""
|
|
1171
1184
|
|
|
1172
1185
|
def __init__(self, *args, num_patches=10, **kwargs):
|
|
@@ -1376,20 +1389,6 @@ class Mean(Operation):
|
|
|
1376
1389
|
return kwargs
|
|
1377
1390
|
|
|
1378
1391
|
|
|
1379
|
-
@ops_registry("transpose")
|
|
1380
|
-
class Transpose(Operation):
|
|
1381
|
-
"""Transpose the input data along the specified axes."""
|
|
1382
|
-
|
|
1383
|
-
def __init__(self, axes, **kwargs):
|
|
1384
|
-
super().__init__(**kwargs)
|
|
1385
|
-
self.axes = axes
|
|
1386
|
-
|
|
1387
|
-
def call(self, **kwargs):
|
|
1388
|
-
data = kwargs[self.key]
|
|
1389
|
-
transposed_data = ops.transpose(data, axes=self.axes)
|
|
1390
|
-
return {self.output_key: transposed_data}
|
|
1391
|
-
|
|
1392
|
-
|
|
1393
1392
|
@ops_registry("simulate_rf")
|
|
1394
1393
|
class Simulate(Operation):
|
|
1395
1394
|
"""Simulate RF data."""
|
|
@@ -1578,19 +1577,6 @@ class PfieldWeighting(Operation):
|
|
|
1578
1577
|
return {self.output_key: weighted_data}
|
|
1579
1578
|
|
|
1580
1579
|
|
|
1581
|
-
@ops_registry("sum")
|
|
1582
|
-
class Sum(Operation):
|
|
1583
|
-
"""Sum data along a specific axis."""
|
|
1584
|
-
|
|
1585
|
-
def __init__(self, axis, **kwargs):
|
|
1586
|
-
super().__init__(**kwargs)
|
|
1587
|
-
self.axis = axis
|
|
1588
|
-
|
|
1589
|
-
def call(self, **kwargs):
|
|
1590
|
-
data = kwargs[self.key]
|
|
1591
|
-
return {self.output_key: ops.sum(data, axis=self.axis)}
|
|
1592
|
-
|
|
1593
|
-
|
|
1594
1580
|
@ops_registry("delay_and_sum")
|
|
1595
1581
|
class DelayAndSum(Operation):
|
|
1596
1582
|
"""Sums time-delayed signals along channels and transmits."""
|
|
@@ -2124,29 +2110,37 @@ class Demodulate(Operation):
|
|
|
2124
2110
|
class Lambda(Operation):
|
|
2125
2111
|
"""Use any function as an operation."""
|
|
2126
2112
|
|
|
2127
|
-
def __init__(self, func,
|
|
2128
|
-
|
|
2129
|
-
|
|
2130
|
-
|
|
2113
|
+
def __init__(self, func, **kwargs):
|
|
2114
|
+
# Split kwargs into kwargs for partial and __init__
|
|
2115
|
+
op_kwargs = {k: v for k, v in kwargs.items() if k not in func.__code__.co_varnames}
|
|
2116
|
+
func_kwargs = {k: v for k, v in kwargs.items() if k in func.__code__.co_varnames}
|
|
2117
|
+
Lambda._check_if_unary(func, **func_kwargs)
|
|
2131
2118
|
|
|
2132
|
-
|
|
2133
|
-
|
|
2134
|
-
data = self.func(data)
|
|
2135
|
-
return {self.output_key: data}
|
|
2136
|
-
|
|
2137
|
-
|
|
2138
|
-
@ops_registry("clip")
|
|
2139
|
-
class Clip(Operation):
|
|
2140
|
-
"""Clip the input data to a given range."""
|
|
2119
|
+
super().__init__(**op_kwargs)
|
|
2120
|
+
self.func = partial(func, **func_kwargs)
|
|
2141
2121
|
|
|
2142
|
-
|
|
2143
|
-
|
|
2144
|
-
|
|
2145
|
-
|
|
2122
|
+
@staticmethod
|
|
2123
|
+
def _check_if_unary(func, **kwargs):
|
|
2124
|
+
"""Checks if the kwargs are sufficient to call the function as a unary operation."""
|
|
2125
|
+
sig = inspect.signature(func)
|
|
2126
|
+
# Remove arguments that are already provided in func_kwargs
|
|
2127
|
+
params = list(sig.parameters.values())
|
|
2128
|
+
remaining = [p for p in params if p.name not in kwargs]
|
|
2129
|
+
# Count required positional arguments (excluding self/cls)
|
|
2130
|
+
required_positional = [
|
|
2131
|
+
p
|
|
2132
|
+
for p in remaining
|
|
2133
|
+
if p.default is p.empty and p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
|
|
2134
|
+
]
|
|
2135
|
+
if len(required_positional) != 1:
|
|
2136
|
+
raise ValueError(
|
|
2137
|
+
f"Partial of {func.__name__} must be callable with exactly one required "
|
|
2138
|
+
f"positional argument, we still need: {required_positional}."
|
|
2139
|
+
)
|
|
2146
2140
|
|
|
2147
2141
|
def call(self, **kwargs):
|
|
2148
2142
|
data = kwargs[self.key]
|
|
2149
|
-
data =
|
|
2143
|
+
data = self.func(data)
|
|
2150
2144
|
return {self.output_key: data}
|
|
2151
2145
|
|
|
2152
2146
|
|
|
@@ -2685,6 +2679,7 @@ class AnisotropicDiffusion(Operation):
|
|
|
2685
2679
|
return result
|
|
2686
2680
|
|
|
2687
2681
|
|
|
2682
|
+
@ops_registry("channels_to_complex")
|
|
2688
2683
|
class ChannelsToComplex(Operation):
|
|
2689
2684
|
def call(self, **kwargs):
|
|
2690
2685
|
data = kwargs[self.key]
|
|
@@ -2692,6 +2687,7 @@ class ChannelsToComplex(Operation):
|
|
|
2692
2687
|
return {self.output_key: output}
|
|
2693
2688
|
|
|
2694
2689
|
|
|
2690
|
+
@ops_registry("complex_to_channels")
|
|
2695
2691
|
class ComplexToChannels(Operation):
|
|
2696
2692
|
def __init__(self, axis=-1, **kwargs):
|
|
2697
2693
|
super().__init__(**kwargs)
|
zea/scan.py
CHANGED
|
@@ -149,12 +149,13 @@ class Scan(Parameters):
|
|
|
149
149
|
Defaults to 0.0.
|
|
150
150
|
attenuation_coef (float, optional): Attenuation coefficient in dB/(MHz*cm).
|
|
151
151
|
Defaults to 0.0.
|
|
152
|
-
selected_transmits (None, str, int, list, or np.ndarray, optional):
|
|
152
|
+
selected_transmits (None, str, int, list, slice, or np.ndarray, optional):
|
|
153
153
|
Specifies which transmit events to select.
|
|
154
154
|
- None or "all": Use all transmits.
|
|
155
155
|
- "center": Use only the center transmit.
|
|
156
156
|
- int: Select this many evenly spaced transmits.
|
|
157
157
|
- list/array: Use these specific transmit indices.
|
|
158
|
+
- slice: Use transmits specified by the slice (e.g., slice(0, 10, 2)).
|
|
158
159
|
grid_type (str, optional): Type of grid to use for beamforming.
|
|
159
160
|
Can be "cartesian" or "polar". Defaults to "cartesian".
|
|
160
161
|
dynamic_range (tuple, optional): Dynamic range for image display.
|
|
@@ -171,13 +172,14 @@ class Scan(Parameters):
|
|
|
171
172
|
"pixels_per_wavelength": {"type": int, "default": 4},
|
|
172
173
|
"pfield_kwargs": {"type": dict, "default": {}},
|
|
173
174
|
"apply_lens_correction": {"type": bool, "default": False},
|
|
174
|
-
"lens_sound_speed": {"type":
|
|
175
|
+
"lens_sound_speed": {"type": float},
|
|
175
176
|
"lens_thickness": {"type": float},
|
|
176
177
|
"grid_type": {"type": str, "default": "cartesian"},
|
|
177
178
|
"polar_limits": {"type": (tuple, list)},
|
|
178
179
|
"dynamic_range": {"type": (tuple, list), "default": DEFAULT_DYNAMIC_RANGE},
|
|
180
|
+
"selected_transmits": {"type": (type(None), str, int, list, slice, np.ndarray)},
|
|
179
181
|
# acquisition parameters
|
|
180
|
-
"sound_speed": {"type":
|
|
182
|
+
"sound_speed": {"type": float, "default": 1540.0},
|
|
181
183
|
"sampling_frequency": {"type": float},
|
|
182
184
|
"center_frequency": {"type": float},
|
|
183
185
|
"n_el": {"type": int},
|
|
@@ -359,6 +361,7 @@ class Scan(Parameters):
|
|
|
359
361
|
- "center": Use only the center transmit
|
|
360
362
|
- int: Select this many evenly spaced transmits
|
|
361
363
|
- list/array: Use these specific transmit indices
|
|
364
|
+
- slice: Use transmits specified by the slice (e.g., slice(0, 10, 2))
|
|
362
365
|
|
|
363
366
|
Returns:
|
|
364
367
|
The current instance for method chaining.
|
|
@@ -416,6 +419,10 @@ class Scan(Parameters):
|
|
|
416
419
|
self._invalidate_dependents("selected_transmits")
|
|
417
420
|
return self
|
|
418
421
|
|
|
422
|
+
# Handle slice - convert to list of indices
|
|
423
|
+
if isinstance(selection, slice):
|
|
424
|
+
selection = list(range(n_tx_total))[selection]
|
|
425
|
+
|
|
419
426
|
# Handle list of indices
|
|
420
427
|
if isinstance(selection, list):
|
|
421
428
|
# Validate indices
|
zea/tensor_ops.py
CHANGED
|
@@ -130,6 +130,102 @@ def extend_n_dims(arr, axis, n_dims):
|
|
|
130
130
|
return ops.reshape(arr, new_shape)
|
|
131
131
|
|
|
132
132
|
|
|
133
|
+
def vmap(fun, in_axes=0, out_axes=0):
|
|
134
|
+
"""Vectorized map.
|
|
135
|
+
|
|
136
|
+
For torch and jax backends, this uses the native vmap implementation.
|
|
137
|
+
For other backends, this a wrapper that uses `ops.vectorized_map` under the hood.
|
|
138
|
+
|
|
139
|
+
Args:
|
|
140
|
+
fun: The function to be mapped.
|
|
141
|
+
in_axes: The axis or axes to be mapped over in the input.
|
|
142
|
+
Can be an integer, a tuple of integers, or None.
|
|
143
|
+
If None, the corresponding argument is not mapped over.
|
|
144
|
+
Defaults to 0.
|
|
145
|
+
out_axes: The axis or axes to be mapped over in the output.
|
|
146
|
+
Can be an integer, a tuple of integers, or None.
|
|
147
|
+
If None, the corresponding output is not mapped over.
|
|
148
|
+
Defaults to 0.
|
|
149
|
+
|
|
150
|
+
Returns:
|
|
151
|
+
A function that applies `fun` in a vectorized manner over the specified axes.
|
|
152
|
+
|
|
153
|
+
Raises:
|
|
154
|
+
ValueError: If the backend does not support vmap.
|
|
155
|
+
"""
|
|
156
|
+
if keras.backend.backend() == "jax":
|
|
157
|
+
import jax
|
|
158
|
+
|
|
159
|
+
return jax.vmap(fun, in_axes=in_axes, out_axes=out_axes)
|
|
160
|
+
elif keras.backend.backend() == "torch":
|
|
161
|
+
import torch
|
|
162
|
+
|
|
163
|
+
return torch.vmap(fun, in_dims=in_axes, out_dims=out_axes)
|
|
164
|
+
else:
|
|
165
|
+
return manual_vmap(fun, in_axes=in_axes, out_axes=out_axes)
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def manual_vmap(fun, in_axes=0, out_axes=0):
|
|
169
|
+
"""Manual vectorized map for backends that do not support vmap."""
|
|
170
|
+
|
|
171
|
+
def find_map_length(args, in_axes):
|
|
172
|
+
"""Find the length of the axis to map over."""
|
|
173
|
+
# NOTE: only needed for numpy, the other backends can handle a singleton dimension
|
|
174
|
+
for arg, axis in zip(args, in_axes):
|
|
175
|
+
if axis is None:
|
|
176
|
+
continue
|
|
177
|
+
|
|
178
|
+
return ops.shape(arg)[axis]
|
|
179
|
+
return 1
|
|
180
|
+
|
|
181
|
+
def _moveaxes(args, in_axes, out_axes):
|
|
182
|
+
"""Move axes of the input arguments."""
|
|
183
|
+
args = list(args)
|
|
184
|
+
for i, (arg, in_axis, out_axis) in enumerate(zip(args, in_axes, out_axes)):
|
|
185
|
+
if in_axis is not None:
|
|
186
|
+
args[i] = ops.moveaxis(arg, in_axis, out_axis)
|
|
187
|
+
else:
|
|
188
|
+
args[i] = ops.repeat(arg[None], find_map_length(args, in_axes), axis=out_axis)
|
|
189
|
+
return tuple(args)
|
|
190
|
+
|
|
191
|
+
def _fun(args):
|
|
192
|
+
return fun(*args)
|
|
193
|
+
|
|
194
|
+
def wrapper(*args):
|
|
195
|
+
# If in_axes or out_axes is an int, convert to tuple
|
|
196
|
+
if isinstance(in_axes, int):
|
|
197
|
+
_in_axes = (in_axes,) * len(args)
|
|
198
|
+
else:
|
|
199
|
+
_in_axes = in_axes
|
|
200
|
+
if isinstance(out_axes, int):
|
|
201
|
+
_out_axes = (out_axes,) * len(args)
|
|
202
|
+
else:
|
|
203
|
+
_out_axes = out_axes
|
|
204
|
+
zeros = (0,) * len(args)
|
|
205
|
+
|
|
206
|
+
# Check that in_axes and out_axes are tuples
|
|
207
|
+
if not isinstance(_in_axes, tuple):
|
|
208
|
+
raise ValueError("in_axes must be an int or a tuple of ints.")
|
|
209
|
+
if not isinstance(_out_axes, tuple):
|
|
210
|
+
raise ValueError("out_axes must be an int or a tuple of ints.")
|
|
211
|
+
|
|
212
|
+
args = _moveaxes(args, _in_axes, zeros)
|
|
213
|
+
outputs = ops.vectorized_map(_fun, tuple(args))
|
|
214
|
+
|
|
215
|
+
tuple_output = isinstance(outputs, (tuple, list))
|
|
216
|
+
if not tuple_output:
|
|
217
|
+
outputs = (outputs,)
|
|
218
|
+
|
|
219
|
+
outputs = _moveaxes(outputs, zeros, _out_axes)
|
|
220
|
+
|
|
221
|
+
if not tuple_output:
|
|
222
|
+
outputs = outputs[0]
|
|
223
|
+
|
|
224
|
+
return outputs
|
|
225
|
+
|
|
226
|
+
return wrapper
|
|
227
|
+
|
|
228
|
+
|
|
133
229
|
def func_with_one_batch_dim(
|
|
134
230
|
func,
|
|
135
231
|
tensor,
|
|
@@ -1270,6 +1366,14 @@ def L2(x):
|
|
|
1270
1366
|
return ops.sqrt(ops.sum(x**2))
|
|
1271
1367
|
|
|
1272
1368
|
|
|
1369
|
+
def L1(x):
|
|
1370
|
+
"""L1 norm of a tensor.
|
|
1371
|
+
|
|
1372
|
+
Implementation of L1 norm: https://mathworld.wolfram.com/L1-Norm.html
|
|
1373
|
+
"""
|
|
1374
|
+
return ops.sum(ops.abs(x))
|
|
1375
|
+
|
|
1376
|
+
|
|
1273
1377
|
def linear_sum_assignment(cost):
|
|
1274
1378
|
"""Greedy linear sum assignment.
|
|
1275
1379
|
|
|
@@ -1325,3 +1429,150 @@ else:
|
|
|
1325
1429
|
def safe_vectorize(pyfunc, excluded=None, signature=None):
|
|
1326
1430
|
"""Just a wrapper around ops.vectorize."""
|
|
1327
1431
|
return ops.vectorize(pyfunc, excluded=excluded, signature=signature)
|
|
1432
|
+
|
|
1433
|
+
|
|
1434
|
+
def apply_along_axis(func1d, axis, arr, *args, **kwargs):
|
|
1435
|
+
"""Apply a function to 1D array slices along an axis.
|
|
1436
|
+
|
|
1437
|
+
Keras implementation of numpy.apply_along_axis using keras.ops.vectorized_map.
|
|
1438
|
+
|
|
1439
|
+
Args:
|
|
1440
|
+
func1d: A callable function with signature ``func1d(arr, /, *args, **kwargs)``
|
|
1441
|
+
where ``*args`` and ``**kwargs`` are the additional positional and keyword
|
|
1442
|
+
arguments passed to apply_along_axis.
|
|
1443
|
+
axis: Integer axis along which to apply the function.
|
|
1444
|
+
arr: The array over which to apply the function.
|
|
1445
|
+
*args: Additional positional arguments passed through to func1d.
|
|
1446
|
+
**kwargs: Additional keyword arguments passed through to func1d.
|
|
1447
|
+
|
|
1448
|
+
Returns:
|
|
1449
|
+
The result of func1d applied along the specified axis.
|
|
1450
|
+
"""
|
|
1451
|
+
# Convert to keras tensor
|
|
1452
|
+
arr = ops.convert_to_tensor(arr)
|
|
1453
|
+
|
|
1454
|
+
# Get array dimensions
|
|
1455
|
+
num_dims = len(arr.shape)
|
|
1456
|
+
|
|
1457
|
+
# Canonicalize axis (handle negative indices)
|
|
1458
|
+
if axis < 0:
|
|
1459
|
+
axis = num_dims + axis
|
|
1460
|
+
|
|
1461
|
+
if axis < 0 or axis >= num_dims:
|
|
1462
|
+
raise ValueError(f"axis {axis} is out of bounds for array of dimension {num_dims}")
|
|
1463
|
+
|
|
1464
|
+
# Create a wrapper function that applies func1d with the additional arguments
|
|
1465
|
+
def func(slice_arr):
|
|
1466
|
+
return func1d(slice_arr, *args, **kwargs)
|
|
1467
|
+
|
|
1468
|
+
# Recursively build up vectorized maps following the JAX pattern
|
|
1469
|
+
# For dimensions after the target axis (right side)
|
|
1470
|
+
for i in range(1, num_dims - axis):
|
|
1471
|
+
prev_func = func
|
|
1472
|
+
|
|
1473
|
+
def make_func(f, dim_offset):
|
|
1474
|
+
def vectorized_func(x):
|
|
1475
|
+
# Move the dimension we want to map over to the front
|
|
1476
|
+
perm = list(range(len(x.shape)))
|
|
1477
|
+
perm[0], perm[dim_offset] = perm[dim_offset], perm[0]
|
|
1478
|
+
x_moved = ops.transpose(x, perm)
|
|
1479
|
+
result = ops.vectorized_map(f, x_moved)
|
|
1480
|
+
# Move the result dimension back if needed
|
|
1481
|
+
if len(result.shape) > 0:
|
|
1482
|
+
result_perm = list(range(len(result.shape)))
|
|
1483
|
+
if len(result_perm) > dim_offset:
|
|
1484
|
+
result_perm[0], result_perm[dim_offset] = (
|
|
1485
|
+
result_perm[dim_offset],
|
|
1486
|
+
result_perm[0],
|
|
1487
|
+
)
|
|
1488
|
+
result = ops.transpose(result, result_perm)
|
|
1489
|
+
return result
|
|
1490
|
+
|
|
1491
|
+
return vectorized_func
|
|
1492
|
+
|
|
1493
|
+
func = make_func(prev_func, i)
|
|
1494
|
+
|
|
1495
|
+
# For dimensions before the target axis (left side)
|
|
1496
|
+
for i in range(axis):
|
|
1497
|
+
prev_func = func
|
|
1498
|
+
|
|
1499
|
+
def make_func(f):
|
|
1500
|
+
return lambda x: ops.vectorized_map(f, x)
|
|
1501
|
+
|
|
1502
|
+
func = make_func(prev_func)
|
|
1503
|
+
|
|
1504
|
+
return func(arr)
|
|
1505
|
+
|
|
1506
|
+
|
|
1507
|
+
def correlate(x, y, mode="full"):
|
|
1508
|
+
"""
|
|
1509
|
+
Complex correlation via splitting real and imaginary parts.
|
|
1510
|
+
Equivalent to np.correlate(x, y, mode).
|
|
1511
|
+
|
|
1512
|
+
NOTE: this function exists because tensorflow does not support complex correlation.
|
|
1513
|
+
NOTE: tensorflow also handles padding differently than numpy, so we manually pad the input.
|
|
1514
|
+
|
|
1515
|
+
Args:
|
|
1516
|
+
x: np.ndarray (complex or real)
|
|
1517
|
+
y: np.ndarray (complex or real)
|
|
1518
|
+
mode: "full", "valid", or "same"
|
|
1519
|
+
"""
|
|
1520
|
+
x = ops.convert_to_tensor(x)
|
|
1521
|
+
y = ops.convert_to_tensor(y)
|
|
1522
|
+
|
|
1523
|
+
is_complex = "complex" in ops.dtype(x) or "complex" in ops.dtype(y)
|
|
1524
|
+
|
|
1525
|
+
# Cast to complex64 if real
|
|
1526
|
+
if not is_complex:
|
|
1527
|
+
x = ops.cast(x, "complex64")
|
|
1528
|
+
y = ops.cast(y, "complex64")
|
|
1529
|
+
|
|
1530
|
+
# Split into real and imaginary
|
|
1531
|
+
xr, xi = ops.real(x), ops.imag(x)
|
|
1532
|
+
yr, yi = ops.real(y), ops.imag(y)
|
|
1533
|
+
|
|
1534
|
+
# Pad to do full correlation
|
|
1535
|
+
pad_left = ops.shape(y)[0] - 1
|
|
1536
|
+
pad_right = ops.shape(y)[0] - 1
|
|
1537
|
+
xr = ops.pad(xr, [[pad_left, pad_right]])
|
|
1538
|
+
xi = ops.pad(xi, [[pad_left, pad_right]])
|
|
1539
|
+
|
|
1540
|
+
# Correlation: sum over x[n] * conj(y[n+k])
|
|
1541
|
+
rr = ops.correlate(xr, yr, mode="valid")
|
|
1542
|
+
ii = ops.correlate(xi, yi, mode="valid")
|
|
1543
|
+
ri = ops.correlate(xr, yi, mode="valid")
|
|
1544
|
+
ir = ops.correlate(xi, yr, mode="valid")
|
|
1545
|
+
|
|
1546
|
+
real_part = rr + ii
|
|
1547
|
+
imag_part = ir - ri
|
|
1548
|
+
|
|
1549
|
+
real_part = ops.cast(real_part, "complex64")
|
|
1550
|
+
imag_part = ops.cast(imag_part, "complex64")
|
|
1551
|
+
|
|
1552
|
+
complex_tensor = real_part + 1j * imag_part
|
|
1553
|
+
|
|
1554
|
+
# Extract relevant part based on mode
|
|
1555
|
+
full_length = ops.shape(real_part)[0]
|
|
1556
|
+
x_len = ops.shape(x)[0]
|
|
1557
|
+
y_len = ops.shape(y)[0]
|
|
1558
|
+
|
|
1559
|
+
if mode == "same":
|
|
1560
|
+
# Return output of length max(M, N)
|
|
1561
|
+
target_len = ops.maximum(x_len, y_len)
|
|
1562
|
+
start = ops.floor((full_length - target_len) / 2)
|
|
1563
|
+
start = ops.cast(start, "int32")
|
|
1564
|
+
end = start + target_len
|
|
1565
|
+
complex_tensor = complex_tensor[start:end]
|
|
1566
|
+
elif mode == "valid":
|
|
1567
|
+
# Return output of length max(M, N) - min(M, N) + 1
|
|
1568
|
+
target_len = ops.maximum(x_len, y_len) - ops.minimum(x_len, y_len) + 1
|
|
1569
|
+
start = ops.ceil((full_length - target_len) / 2)
|
|
1570
|
+
start = ops.cast(start, "int32")
|
|
1571
|
+
end = start + target_len
|
|
1572
|
+
complex_tensor = complex_tensor[start:end]
|
|
1573
|
+
# For "full" mode, use the entire result (no slicing needed)
|
|
1574
|
+
|
|
1575
|
+
if is_complex:
|
|
1576
|
+
return complex_tensor
|
|
1577
|
+
else:
|
|
1578
|
+
return ops.real(complex_tensor)
|
zea/tools/fit_scan_cone.py
CHANGED
|
@@ -83,8 +83,7 @@ def filter_edge_points_by_boundary(edge_points, is_left=True, min_cone_half_angl
|
|
|
83
83
|
|
|
84
84
|
|
|
85
85
|
def detect_cone_parameters(image, min_cone_half_angle_deg=20, threshold=15):
|
|
86
|
-
"""
|
|
87
|
-
Detect the ultrasound cone parameters from a grayscale image.
|
|
86
|
+
"""Detect the ultrasound cone parameters from a grayscale image.
|
|
88
87
|
|
|
89
88
|
This function performs the following steps:
|
|
90
89
|
1. Thresholds the image to create a binary mask
|
|
@@ -209,6 +208,7 @@ def detect_cone_parameters(image, min_cone_half_angle_deg=20, threshold=15):
|
|
|
209
208
|
apex_x = left_a + left_b * apex_y
|
|
210
209
|
|
|
211
210
|
# Calculate cone height
|
|
211
|
+
max_y = ops.cast(max_y, apex_y.dtype)
|
|
212
212
|
cone_height = max_y - apex_y
|
|
213
213
|
|
|
214
214
|
# Calculate opening angle from the line slopes
|
zea/tools/selection_tool.py
CHANGED
|
@@ -250,10 +250,13 @@ def add_shape_from_mask(ax, mask, **kwargs):
|
|
|
250
250
|
Returns:
|
|
251
251
|
plt.ax: matplotlib axis with shape added
|
|
252
252
|
"""
|
|
253
|
-
#
|
|
254
|
-
|
|
253
|
+
# Pad mask to ensure edge contours are found
|
|
254
|
+
padded_mask = np.pad(mask, pad_width=1, mode="constant", constant_values=0)
|
|
255
|
+
contours = measure.find_contours(padded_mask, 0.5)
|
|
255
256
|
patches = []
|
|
256
257
|
for contour in contours:
|
|
258
|
+
# Remove padding offset
|
|
259
|
+
contour -= 1
|
|
257
260
|
path = pltPath(contour[:, ::-1])
|
|
258
261
|
patch = PathPatch(path, **kwargs)
|
|
259
262
|
patches.append(ax.add_patch(patch))
|
|
@@ -593,10 +596,11 @@ def interpolate_masks(
|
|
|
593
596
|
assert all(mask.shape == mask_shape for mask in masks), "All masks must have the same shape."
|
|
594
597
|
|
|
595
598
|
# distribute number of frames over number of masks
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
599
|
+
base_frames = num_frames // (number_of_masks - 1)
|
|
600
|
+
remainder = num_frames % (number_of_masks - 1)
|
|
601
|
+
num_frames_per_segment = [base_frames] * (number_of_masks - 1)
|
|
602
|
+
for i in range(remainder):
|
|
603
|
+
num_frames_per_segment[i] += 1
|
|
600
604
|
|
|
601
605
|
if rectangle:
|
|
602
606
|
# get the rectangles
|
|
@@ -615,7 +619,6 @@ def interpolate_masks(
|
|
|
615
619
|
for _rectangle in rectangles:
|
|
616
620
|
interpolated_masks.append(reconstruct_mask_from_rectangle(_rectangle, mask_shape))
|
|
617
621
|
return interpolated_masks
|
|
618
|
-
|
|
619
622
|
# get the contours
|
|
620
623
|
polygons = []
|
|
621
624
|
for mask in masks:
|
|
@@ -726,6 +729,17 @@ def update_imshow_with_mask(
|
|
|
726
729
|
return imshow_obj, mask_obj
|
|
727
730
|
|
|
728
731
|
|
|
732
|
+
def ask_for_title():
|
|
733
|
+
print("What are you selecting?")
|
|
734
|
+
title = input("Enter a title for the selection: ")
|
|
735
|
+
if not title:
|
|
736
|
+
raise ValueError("Title cannot be empty.")
|
|
737
|
+
# Convert title to snake_case
|
|
738
|
+
title = title.strip().replace(" ", "_").lower()
|
|
739
|
+
print(f"Title set to: {title}")
|
|
740
|
+
return title
|
|
741
|
+
|
|
742
|
+
|
|
729
743
|
def main():
|
|
730
744
|
"""Main function for interactive selector on multiple images."""
|
|
731
745
|
print(
|
|
@@ -751,6 +765,7 @@ def main():
|
|
|
751
765
|
raise e
|
|
752
766
|
print("No more images selected. Continuing...")
|
|
753
767
|
|
|
768
|
+
title = ask_for_title()
|
|
754
769
|
selector = ask_for_selection_tool()
|
|
755
770
|
|
|
756
771
|
if same_images is True:
|
|
@@ -829,6 +844,10 @@ def main():
|
|
|
829
844
|
else:
|
|
830
845
|
add_shape_from_mask(axs, interpolated_masks[0], alpha=0.5)
|
|
831
846
|
|
|
847
|
+
filestem = Path(file.parent / f"{file.stem}_{title}_annotations.gif")
|
|
848
|
+
np.save(filestem.with_suffix(".npy"), interpolated_masks)
|
|
849
|
+
print(f"Succesfully saved interpolated masks to {log.yellow(filestem.with_suffix('.npy'))}")
|
|
850
|
+
|
|
832
851
|
fps = ask_save_animation_with_fps()
|
|
833
852
|
|
|
834
853
|
ani = FuncAnimation(
|
|
@@ -838,9 +857,9 @@ def main():
|
|
|
838
857
|
fargs=(axs, imshow_obj, images, interpolated_masks, selector),
|
|
839
858
|
interval=1000 / fps,
|
|
840
859
|
)
|
|
841
|
-
filename =
|
|
860
|
+
filename = filestem.with_suffix(".gif")
|
|
842
861
|
ani.save(filename, writer="pillow")
|
|
843
|
-
print(f"Succesfully saved animation as {filename}")
|
|
862
|
+
print(f"Succesfully saved animation as {log.yellow(filename)}")
|
|
844
863
|
|
|
845
864
|
|
|
846
865
|
if __name__ == "__main__":
|
|
@@ -1,7 +1,8 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: zea
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.6
|
|
4
4
|
Summary: A Toolbox for Cognitive Ultrasound Imaging. Provides a set of tools for processing of ultrasound data, all built in your favorite machine learning framework.
|
|
5
|
+
License-File: LICENSE
|
|
5
6
|
Keywords: ultrasound,machine learning,beamforming
|
|
6
7
|
Author: Tristan Stevens
|
|
7
8
|
Author-email: t.s.w.stevens@tue.nl
|
|
@@ -21,6 +22,7 @@ Provides-Extra: display
|
|
|
21
22
|
Provides-Extra: display-headless
|
|
22
23
|
Provides-Extra: docs
|
|
23
24
|
Provides-Extra: jax
|
|
25
|
+
Provides-Extra: models
|
|
24
26
|
Provides-Extra: tests
|
|
25
27
|
Requires-Dist: IPython ; extra == "dev"
|
|
26
28
|
Requires-Dist: IPython ; extra == "docs"
|
|
@@ -33,13 +35,14 @@ Requires-Dist: furo ; extra == "dev"
|
|
|
33
35
|
Requires-Dist: furo ; extra == "docs"
|
|
34
36
|
Requires-Dist: h5py (>=3.11)
|
|
35
37
|
Requires-Dist: huggingface_hub (>=0.26)
|
|
38
|
+
Requires-Dist: imageio[ffmpeg] (>=2.0)
|
|
36
39
|
Requires-Dist: ipykernel (>=6.29.5) ; extra == "dev"
|
|
37
40
|
Requires-Dist: ipykernel (>=6.29.5) ; extra == "tests"
|
|
38
41
|
Requires-Dist: ipywidgets ; extra == "dev"
|
|
39
42
|
Requires-Dist: ipywidgets ; extra == "tests"
|
|
40
43
|
Requires-Dist: jax ; extra == "backends"
|
|
41
44
|
Requires-Dist: jax[cuda12-pip] (>=0.4.26) ; extra == "jax"
|
|
42
|
-
Requires-Dist: keras (>=3.
|
|
45
|
+
Requires-Dist: keras (>=3.11)
|
|
43
46
|
Requires-Dist: matplotlib (>=3.8)
|
|
44
47
|
Requires-Dist: mock ; extra == "dev"
|
|
45
48
|
Requires-Dist: mock ; extra == "docs"
|
|
@@ -48,6 +51,8 @@ Requires-Dist: myst-parser ; extra == "docs"
|
|
|
48
51
|
Requires-Dist: nbsphinx ; extra == "dev"
|
|
49
52
|
Requires-Dist: nbsphinx ; extra == "docs"
|
|
50
53
|
Requires-Dist: numpy (>=1.24)
|
|
54
|
+
Requires-Dist: onnxruntime (>=1.15) ; extra == "dev"
|
|
55
|
+
Requires-Dist: onnxruntime (>=1.15) ; extra == "models"
|
|
51
56
|
Requires-Dist: opencv-python (>=4) ; extra == "display"
|
|
52
57
|
Requires-Dist: opencv-python-headless (>=4) ; extra == "dev"
|
|
53
58
|
Requires-Dist: opencv-python-headless (>=4) ; extra == "display-headless"
|
|
@@ -69,6 +74,8 @@ Requires-Dist: scikit-learn (>=1.4)
|
|
|
69
74
|
Requires-Dist: scipy (>=1.13)
|
|
70
75
|
Requires-Dist: sphinx ; extra == "dev"
|
|
71
76
|
Requires-Dist: sphinx ; extra == "docs"
|
|
77
|
+
Requires-Dist: sphinx-argparse ; extra == "dev"
|
|
78
|
+
Requires-Dist: sphinx-argparse ; extra == "docs"
|
|
72
79
|
Requires-Dist: sphinx-autobuild ; extra == "dev"
|
|
73
80
|
Requires-Dist: sphinx-autobuild ; extra == "docs"
|
|
74
81
|
Requires-Dist: sphinx-autodoc-typehints ; extra == "dev"
|