zea 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in the supported public registries. It is provided for informational purposes only.
- zea/__init__.py +54 -19
- zea/agent/__init__.py +12 -12
- zea/agent/masks.py +2 -1
- zea/backend/tensorflow/dataloader.py +2 -1
- zea/beamform/beamformer.py +100 -50
- zea/beamform/lens_correction.py +9 -2
- zea/beamform/pfield.py +9 -2
- zea/config.py +34 -25
- zea/data/__init__.py +22 -16
- zea/data/convert/camus.py +2 -1
- zea/data/convert/echonet.py +4 -4
- zea/data/convert/echonetlvh/convert_raw_to_usbmd.py +1 -1
- zea/data/convert/matlab.py +11 -4
- zea/data/data_format.py +31 -30
- zea/data/datasets.py +7 -5
- zea/data/file.py +104 -2
- zea/data/layers.py +3 -3
- zea/datapaths.py +16 -4
- zea/display.py +7 -5
- zea/interface.py +14 -16
- zea/internal/_generate_keras_ops.py +6 -7
- zea/internal/cache.py +2 -49
- zea/internal/config/validation.py +1 -2
- zea/internal/core.py +69 -6
- zea/internal/device.py +6 -2
- zea/internal/dummy_scan.py +330 -0
- zea/internal/operators.py +114 -2
- zea/internal/parameters.py +101 -70
- zea/internal/setup_zea.py +5 -6
- zea/internal/utils.py +282 -0
- zea/io_lib.py +247 -19
- zea/keras_ops.py +74 -4
- zea/log.py +9 -7
- zea/metrics.py +15 -7
- zea/models/__init__.py +30 -20
- zea/models/base.py +30 -14
- zea/models/carotid_segmenter.py +19 -4
- zea/models/diffusion.py +173 -12
- zea/models/echonet.py +22 -8
- zea/models/echonetlvh.py +31 -7
- zea/models/lpips.py +19 -2
- zea/models/lv_segmentation.py +28 -11
- zea/models/preset_utils.py +5 -5
- zea/models/regional_quality.py +30 -10
- zea/models/taesd.py +21 -5
- zea/models/unet.py +15 -1
- zea/ops.py +390 -196
- zea/probes.py +6 -6
- zea/scan.py +109 -49
- zea/simulator.py +24 -21
- zea/tensor_ops.py +406 -302
- zea/tools/hf.py +1 -1
- zea/tools/selection_tool.py +47 -86
- zea/utils.py +92 -480
- zea/visualize.py +177 -39
- {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/METADATA +4 -2
- zea-0.0.7.dist-info/RECORD +114 -0
- zea-0.0.6.dist-info/RECORD +0 -112
- {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/WHEEL +0 -0
- {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/entry_points.txt +0 -0
- {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/licenses/LICENSE +0 -0
zea/tensor_ops.py
CHANGED
@@ -130,162 +130,6 @@ def extend_n_dims(arr, axis, n_dims):
     return ops.reshape(arr, new_shape)


-def vmap(fun, in_axes=0, out_axes=0):
-    """Vectorized map.
-
-    For torch and jax backends, this uses the native vmap implementation.
-    For other backends, this a wrapper that uses `ops.vectorized_map` under the hood.
-
-    Args:
-        fun: The function to be mapped.
-        in_axes: The axis or axes to be mapped over in the input.
-            Can be an integer, a tuple of integers, or None.
-            If None, the corresponding argument is not mapped over.
-            Defaults to 0.
-        out_axes: The axis or axes to be mapped over in the output.
-            Can be an integer, a tuple of integers, or None.
-            If None, the corresponding output is not mapped over.
-            Defaults to 0.
-
-    Returns:
-        A function that applies `fun` in a vectorized manner over the specified axes.
-
-    Raises:
-        ValueError: If the backend does not support vmap.
-    """
-    if keras.backend.backend() == "jax":
-        import jax
-
-        return jax.vmap(fun, in_axes=in_axes, out_axes=out_axes)
-    elif keras.backend.backend() == "torch":
-        import torch
-
-        return torch.vmap(fun, in_dims=in_axes, out_dims=out_axes)
-    else:
-        return manual_vmap(fun, in_axes=in_axes, out_axes=out_axes)
-
-
-def manual_vmap(fun, in_axes=0, out_axes=0):
-    """Manual vectorized map for backends that do not support vmap."""
-
-    def find_map_length(args, in_axes):
-        """Find the length of the axis to map over."""
-        # NOTE: only needed for numpy, the other backends can handle a singleton dimension
-        for arg, axis in zip(args, in_axes):
-            if axis is None:
-                continue
-
-            return ops.shape(arg)[axis]
-        return 1
-
-    def _moveaxes(args, in_axes, out_axes):
-        """Move axes of the input arguments."""
-        args = list(args)
-        for i, (arg, in_axis, out_axis) in enumerate(zip(args, in_axes, out_axes)):
-            if in_axis is not None:
-                args[i] = ops.moveaxis(arg, in_axis, out_axis)
-            else:
-                args[i] = ops.repeat(arg[None], find_map_length(args, in_axes), axis=out_axis)
-        return tuple(args)
-
-    def _fun(args):
-        return fun(*args)
-
-    def wrapper(*args):
-        # If in_axes or out_axes is an int, convert to tuple
-        if isinstance(in_axes, int):
-            _in_axes = (in_axes,) * len(args)
-        else:
-            _in_axes = in_axes
-        if isinstance(out_axes, int):
-            _out_axes = (out_axes,) * len(args)
-        else:
-            _out_axes = out_axes
-        zeros = (0,) * len(args)
-
-        # Check that in_axes and out_axes are tuples
-        if not isinstance(_in_axes, tuple):
-            raise ValueError("in_axes must be an int or a tuple of ints.")
-        if not isinstance(_out_axes, tuple):
-            raise ValueError("out_axes must be an int or a tuple of ints.")
-
-        args = _moveaxes(args, _in_axes, zeros)
-        outputs = ops.vectorized_map(_fun, tuple(args))
-
-        tuple_output = isinstance(outputs, (tuple, list))
-        if not tuple_output:
-            outputs = (outputs,)
-
-        outputs = _moveaxes(outputs, zeros, _out_axes)
-
-        if not tuple_output:
-            outputs = outputs[0]
-
-        return outputs
-
-    return wrapper
-
-
-def func_with_one_batch_dim(
-    func,
-    tensor,
-    n_batch_dims: int,
-    batch_size: int | None = None,
-    func_axis: int | None = None,
-    **kwargs,
-):
-    """Wraps a function to apply it to an input tensor with one or more batch dimensions.
-
-    The function will be executed in parallel on all batch elements.
-
-    Args:
-        func (function): The function to apply to the image.
-            Will take the `func_axis` output from the function.
-        tensor (Tensor): The input tensor.
-        n_batch_dims (int): The number of batch dimensions in the input tensor.
-            Expects the input to start with n_batch_dims batch dimensions. Defaults to 2.
-        batch_size (int, optional): Integer specifying the size of the batch for
-            each step to execute in parallel. Defaults to None, in which case the function
-            will run everything in parallel.
-        func_axis (int, optional): If `func` returns mulitple outputs, this axis will be returned.
-        **kwargs: Additional keyword arguments to pass to the function.
-
-    Returns:
-        The output tensor with the same batch dimensions as the input tensor.
-
-    Raises:
-        ValueError: If the number of batch dimensions is greater than the rank of the input tensor.
-    """
-    # Extract the shape of the batch dimensions from the input tensor
-    batch_dims = ops.shape(tensor)[:n_batch_dims]
-
-    # Extract the shape of the remaining (non-batch) dimensions
-    other_dims = ops.shape(tensor)[n_batch_dims:]
-
-    # Reshape the input tensor to merge all batch dimensions into one
-    reshaped_input = ops.reshape(tensor, [-1, *other_dims])
-
-    # Apply the given function to the reshaped input tensor
-    if batch_size is None:
-        reshaped_output = func(reshaped_input, **kwargs)
-    else:
-        reshaped_output = batched_map(func, reshaped_input, batch_size=batch_size)
-
-    # If the function returns multiple outputs, select the one corresponding to `func_axis`
-    if isinstance(reshaped_output, (tuple, list)):
-        if func_axis is None:
-            raise ValueError(
-                "func_axis must be specified when the function returns multiple outputs."
-            )
-        reshaped_output = reshaped_output[func_axis]
-
-    # Extract the shape of the output tensor after applying the function (excluding the batch dim)
-    output_other_dims = ops.shape(reshaped_output)[1:]
-
-    # Reshape the output tensor to restore the original batch dimensions
-    return ops.reshape(reshaped_output, [*batch_dims, *output_other_dims])
-
-
 def matrix_power(matrix, power):
     """Compute the power of a square matrix.

@@ -416,97 +260,346 @@ def batch_cov(x, rowvar=True, bias=False, ddof=None):
     return cov_matrices


-def
-    """
+def simple_map(function, elements):
+    """Like `ops.map` but no tracing or jit compilation."""

-
-
-
+    if elements is None:
+        return function(None)
+
+    multiple_inputs = isinstance(elements, (list, tuple))

-
-
+    outputs = []
+    if not multiple_inputs:
+        batch_size = elements.shape[0]
+        for index in range(batch_size):
+            outputs.append(function(elements[index]))
     else:
-
-
-
+        batch_size = elements[0].shape[0]
+        for index in range(batch_size):
+            outputs.append(function([e[index] if e is not None else None for e in elements]))
+
+    if isinstance(outputs[0], (list, tuple)):
+        return [ops.stack(tensors) for tensors in zip(*outputs)]
+    else:
+        return ops.stack(outputs)
+
+
+if keras.backend.backend() == "numpy":
+
+    def vectorized_map(function, elements):
+        """Fixes keras.ops.vectorized_map in numpy backend with multiple outputs."""
+        if not isinstance(elements, (list, tuple)):
+            return np.stack([function(x) for x in elements])
+        else:
+            batch_size = elements[0].shape[0]
+            output_store = []
+            for index in range(batch_size):
+                output_store.append(function([x[index] for x in elements]))
+            if isinstance(output_store[0], (list, tuple)):
+                return [np.stack(tensors) for tensors in zip(*output_store)]
+            else:
+                return np.stack(output_store)
+else:
+    vectorized_map = ops.vectorized_map
+
+
+def _find_map_length(args, in_axes) -> int:
+    """Find the length of the axis to map over."""
+    for arg, axis in zip(args, in_axes):
+        if axis is None or arg is None:
+            continue
+
+        return ops.shape(arg)[axis]
+
+    raise ValueError("At least one in_axes must be non-None to determine map length.")
+
+
+def _repeat_int_to_tuple(value, length) -> Tuple[Union[int, None], ...]:
+    """Convert an int or None to a tuple of length `length`."""
+    if isinstance(value, int) or value is None:
+        return (value,) * length
+    elif not isinstance(value, tuple):
+        raise ValueError("Value must be an int, None, or a tuple.")
+    return value
+
+
+def _map(fun, in_axes=0, out_axes=0, map_fn=None, _use_torch_vmap=False):
+    """Mapping function, vectorized by default.
+
+    For jax, this uses the native vmap implementation.
+    For other backends, this a wrapper that uses `ops.vectorized_map` under the hood.
+
+    Probably you want to use `zea.tensor_ops.vmap` instead, which uses this function
+    with additional batching/chunking support.
+
+    Args:
+        fun: The function to be mapped.
+        in_axes: The axis or axes to be mapped over in the input.
+            Can be an integer, a tuple of integers, or None.
+            If None, the corresponding argument is not mapped over.
+            Defaults to 0.
+        out_axes: The axis or axes to be mapped over in the output.
+            Can be an integer, a tuple of integers, or None.
+            If None, the corresponding output is not mapped over.
+            Defaults to 0.
+        map_fn: The mapping function to use. If None, defaults to `ops.vectorized_map`.
+        _use_torch_vmap: If True, uses PyTorch's native vmap implementation.
+
+    Returns:
+        A function that applies `fun` (in a vectorized manner) over the specified axes.
+    """

+    # Use native vmap for JAX backend when map_fn is not provided
+    if keras.backend.backend() == "jax" and map_fn is None:
+        import jax

-
-
+        return jax.vmap(fun, in_axes=in_axes, out_axes=out_axes)
+
+    # Use native vmap for PyTorch backend when map_fn is not provided and _use_torch_vmap is True
+    if keras.backend.backend() == "torch" and map_fn is None and _use_torch_vmap:
+        import torch
+
+        return torch.vmap(fun, in_dims=in_axes, out_dims=out_axes)
+
+    # Default to keras vectorized_map if map_fn not provided
+    if map_fn is None:
+        map_fn = vectorized_map
+
+    def _moveaxes(args, in_axes, out_axes) -> tuple:
+        """Move axes of the input arguments."""
+        args = list(args)
+        new_args = []
+        map_length = _find_map_length(args, in_axes)
+        for arg, in_axis, out_axis in zip(args, in_axes, out_axes):
+            if arg is None:
+                # filter out None arguments
+                continue
+            if out_axis is None:
+                new_args.append(ops.take(arg, 0, axis=in_axis))
+            elif in_axis is not None:
+                new_args.append(ops.moveaxis(arg, in_axis, out_axis))
+            else:
+                new_args.append(
+                    ops.repeat(ops.expand_dims(arg, out_axis), map_length, axis=out_axis)
+                )
+        return tuple(new_args)
+
+    def _partial_at(func, idx, value) -> callable:
+        """Return a new function with value inserted at index idx in args."""
+
+        def wrapper(*args, **kwargs):
+            args = list(args)
+            args[idx:idx] = [value]
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    def mapped_wrapper(*args):
+        _in_axes = _repeat_int_to_tuple(in_axes, len(args))
+
+        # Move mapped axes to front
+        zeros = (0,) * len(args)
+        none_indexes = [i for i, arg in enumerate(args) if arg is None]
+        args = _moveaxes(args, _in_axes, zeros)
+
+        # Build function with None arguments prefilled
+        _prefilled_none_fn = fun
+        for none_idx in none_indexes:
+            _prefilled_none_fn = _partial_at(_prefilled_none_fn, none_idx, None)
+
+        def _fun(args):
+            return _prefilled_none_fn(*args)
+
+        outputs = map_fn(_fun, tuple(args))
+
+        # Wrap outputs to tuple for easier processing
+        tuple_output = isinstance(outputs, (tuple, list))
+        if not tuple_output:
+            outputs = (outputs,)
+
+        _out_axes = _repeat_int_to_tuple(out_axes, len(outputs))
+
+        # Move mapped axes back to original position
+        outputs = _moveaxes(outputs, zeros, _out_axes)
+
+        if not tuple_output:
+            outputs = outputs[0]
+
+        return outputs
+
+    return mapped_wrapper
+
+
+def vmap(
+    fun,
+    in_axes=0,
+    out_axes=0,
+    batch_size=None,
+    chunks=None,
+    fn_supports_batch=False,
+    disable_jit=False,
+    _use_torch_vmap=False,
+):
+    """`vmap` with batching or chunking support to avoid memory issues.
+
+    Basically a wrapper around `vmap` that splits the input into batches or chunks
+    to avoid memory issues with large inputs.

     Args:
-
-
-
-
-
-
-
-
-
+        fun: Function to be mapped.
+        in_axes: Axis or axes to be mapped over in the input.
+        out_axes: Axis or axes to be mapped over in the output.
+        batch_size: Size of the batch for each step. If `None`, the function will be equivalent
+            to `vmap`. If `1`, will be equivalent to `map`. Mutually exclusive with `chunks`.
+        chunks: Number of chunks to split the input into. If `None` or `1`, the function will be
+            equivalent to `vmap`. Mutually exclusive with `batch_size`.
+        fn_supports_batch: If True, assumes that `fun` can already handle batched inputs.
+            In this case, this function will only handle padding and reshaping for batching.
+        disable_jit: If True, disables JIT compilation for backends that support it.
+            This can be useful for debugging. Will fall back to simple mapping.
+        _use_torch_vmap: If True, uses PyTorch's native vmap implementation.
+            Advantage: you can apply `vmap` multiple times without issues.
+            Disadvantage: does not support None arguments.
+    Returns:
+        A function that applies `fun` in a batched manner over the specified axes.
+    """
+
+    # Mutually exclusive arguments
+    assert not (batch_size is not None and chunks is not None), (
+        "batch_size and chunks are mutually exclusive. Please specify only one of them."
+    )
+
+    if batch_size is not None:
+        assert batch_size > 0, "batch_size must be greater than 0."
+    if chunks is not None:
+        assert chunks > 0, "chunks must be greater than 0."
+
+    no_chunks_or_batch = batch_size is None and (chunks is None or chunks == 1)
+
+    if fn_supports_batch and no_chunks_or_batch:
+        return fun
+
+    if not fn_supports_batch:
+        # vmap to support batches
+        fun = _map(
+            fun,
+            in_axes=in_axes,
+            out_axes=out_axes,
+            map_fn=None if not disable_jit else simple_map,
+            _use_torch_vmap=_use_torch_vmap,
+        )
+
+    if no_chunks_or_batch:
+        return fun
+
+    # map (sequentially) to support batches/chunks that fit in memory
+    map_fn = ops.map if not disable_jit else simple_map
+
+    def batched_fun(*args):
+        _in_axes = _repeat_int_to_tuple(in_axes, len(args))
+        total_length = _find_map_length(args, _in_axes)
+
+        if chunks is not None:
+            _batch_size = np.ceil(total_length / chunks).astype(int)
+        else:
+            _batch_size = batch_size
+
+        new_args = []
+        for arg, in_axis in zip(args, _in_axes):
+            if in_axis is None or arg is None:
+                new_args.append(arg)
+                continue
+            padded_arg = pad_array_to_divisible(arg, _batch_size, axis=in_axis)
+            reshaped_arg = reshape_axis(padded_arg, (-1, _batch_size), axis=in_axis)
+            new_args.append(reshaped_arg)
+
+        outputs = _map(fun, in_axes=_in_axes, out_axes=out_axes, map_fn=map_fn)(*new_args)
+
+        # Wrap outputs to tuple for easier processing
+        tuple_output = isinstance(outputs, (tuple, list))
+        if not tuple_output:
+            outputs = (outputs,)
+
+        new_outputs = []
+        _out_axes = _repeat_int_to_tuple(out_axes, len(outputs))
+        for output, out_axis in zip(outputs, _out_axes):
+            if out_axis is None:
+                new_outputs.append(output)
+                continue
+            reshaped_output = flatten(output, start_dim=out_axis, end_dim=out_axis + 1)
+            cropped_output = ops.take(reshaped_output, ops.arange(total_length), axis=out_axis)
+            new_outputs.append(cropped_output)
+
+        if tuple_output:
+            return tuple(new_outputs)
+        else:
+            return new_outputs[0]
+
+    return batched_fun
+
+
+def func_with_one_batch_dim(
+    func,
+    tensor,
+    n_batch_dims: int,
+    batch_size: int | None = None,
+    func_axis: int | None = None,
+    **kwargs,
+):
+    """Wraps a function to apply it to an input tensor with one or more batch dimensions.
+
+    The function will be executed in parallel on all batch elements.
+
+    Args:
+        func (function): The function to apply to the image.
+            Will take the `func_axis` output from the function.
+        tensor (Tensor): The input tensor.
+        n_batch_dims (int): The number of batch dimensions in the input tensor.
+            Expects the input to start with n_batch_dims batch dimensions. Defaults to 2.
+        batch_size (int, optional): Integer specifying the size of the batch for
+            each step to execute in parallel. Defaults to None, in which case the function
+            will run everything in parallel.
+        func_axis (int, optional): If `func` returns mulitple outputs, this axis will be returned.
+        **kwargs: Additional keyword arguments to pass to the function.

     Returns:
-        The
+        The output tensor with the same batch dimensions as the input tensor.

-
+    Raises:
+        ValueError: If the number of batch dimensions is greater than the rank of the input tensor.
     """
-
-
-
-    #
-
-    assert all(
-        ops.shape(xs)[0] == ops.shape(v)[0] for v in batch_kwargs.values() if v is not None
-    ), "All batch kwargs must have the same first dimension size as xs."
-
-    total = ops.shape(xs)[0]
-    # TODO: could be rewritten with ops.cond such that it also works for jit=True.
-    if not jit and batch_size is not None and total <= batch_size:
-        return f(xs, **batch_kwargs)
-
-    ## Non-jitted version: simply iterate over batches.
-    if not jit:
-        bs = batch_size or 1 # Default batch size to 1 if not specified.
-        outputs = []
-        for i in range(0, total, bs):
-            idx = slice(i, i + bs)
-            current_kwargs = {k: v[idx] for k, v in batch_kwargs.items()}
-            outputs.append(f(xs[idx], **current_kwargs))
-        return ops.concatenate(outputs, axis=0)
-
-    ## Jitted version.
-
-    # Helper to create the batched function for use with ops.map.
-    def create_batched_f(kw_keys):
-        def batched_f(inputs):
-            x, *kw_values = inputs
-            kw = dict(zip(kw_keys, kw_values))
-            return f(x, **kw)
-
-        return batched_f
+    # Extract the shape of the batch dimensions from the input tensor
+    batch_dims = ops.shape(tensor)[:n_batch_dims]
+
+    # Extract the shape of the remaining (non-batch) dimensions
+    other_dims = ops.shape(tensor)[n_batch_dims:]

+    # Reshape the input tensor to merge all batch dimensions into one
+    reshaped_input = ops.reshape(tensor, [-1, *other_dims])
+
+    # Apply the given function to the reshaped input tensor
     if batch_size is None:
-
-
-
-
-
-
-
-
-    #
-
-
-
-
-
-        reshaped_kwargs[k] = ops.reshape(v_padded, (-1, batch_size) + ops.shape(v_padded)[1:])
+        reshaped_output = func(reshaped_input, **kwargs)
+    else:
+        reshaped_output = vmap(
+            lambda *args: func(*args, **kwargs),
+            batch_size=batch_size,
+            fn_supports_batch=True,
+        )(reshaped_input)
+
+    # If the function returns multiple outputs, select the one corresponding to `func_axis`
+    if isinstance(reshaped_output, (tuple, list)):
+        if func_axis is None:
+            raise ValueError(
+                "func_axis must be specified when the function returns multiple outputs."
+            )
+        reshaped_output = reshaped_output[func_axis]

-
-
-
-
+    # Extract the shape of the output tensor after applying the function (excluding the batch dim)
+    output_other_dims = ops.shape(reshaped_output)[1:]
+
+    # Reshape the output tensor to restore the original batch dimensions
+    return ops.reshape(reshaped_output, [*batch_dims, *output_other_dims])


 if keras.backend.backend() == "jax":
@@ -740,14 +833,16 @@ def stack_volume_data_along_axis(data, batch_axis: int, stack_axis: int, number:
         Tensor: Reshaped tensor with data stacked along stack_axis.

     Example:
-        ..
+        .. doctest::

-            import keras
+            >>> import keras
+            >>> from zea.tensor_ops import stack_volume_data_along_axis

-            data = keras.random.uniform((10, 20, 30))
-            # stacking along 1st axis with 2 frames per block
-            stacked_data = stack_volume_data_along_axis(data, 0, 1, 2)
-            stacked_data.shape
+            >>> data = keras.random.uniform((10, 20, 30))
+            >>> # stacking along 1st axis with 2 frames per block
+            >>> stacked_data = stack_volume_data_along_axis(data, 0, 1, 2)
+            >>> stacked_data.shape
+            (5, 40, 30)
     """
     blocks = int(ops.ceil(data.shape[batch_axis] / number))
     data = pad_array_to_divisible(data, axis=batch_axis, N=blocks, mode="reflect")
@@ -782,13 +877,15 @@ def split_volume_data_from_axis(data, batch_axis: int, stack_axis: int, number:
         Tensor: Reshaped tensor with data split back to original format.

     Example:
-        ..
+        .. doctest::

-            import keras
+            >>> import keras
+            >>> from zea.tensor_ops import split_volume_data_from_axis

-            data = keras.random.uniform((20, 10, 30))
-            split_data = split_volume_data_from_axis(data, 0, 1, 2, 2)
-            split_data.shape
+            >>> data = keras.random.uniform((20, 10, 30))
+            >>> split_data = split_volume_data_from_axis(data, 0, 1, 2, 2)
+            >>> split_data.shape
+            (39, 5, 30)
     """
     if data.shape[stack_axis] == 1:
         # in this case it was a broadcasted axis which does not need to be split
@@ -904,14 +1001,17 @@ def check_patches_fit(
         in the original image and the new image shape if the patches do not fit.

     Example:
-        ..
-
-
-
-
-
-            patches_fit
-
+        .. doctest::
+
+            >>> from zea.tensor_ops import check_patches_fit
+            >>> image_shape = (10, 10)
+            >>> patch_shape = (4, 4)
+            >>> overlap = (2, 2)
+            >>> patches_fit, new_shape = check_patches_fit(image_shape, patch_shape, overlap)
+            >>> patches_fit
+            True
+            >>> new_shape
+            (10, 10)
     """
     if overlap:
         stride = (np.array(patch_shape) - np.array(overlap)).astype(int)
@@ -968,13 +1068,15 @@ def images_to_patches(
         [batch, #patch_y, #patch_x, patch_size_y, patch_size_x, #channels].

     Example:
-        ..
+        .. doctest::

-            import keras
+            >>> import keras
+            >>> from zea.tensor_ops import images_to_patches

-            images = keras.random.uniform((2, 8, 8, 3))
-            patches = images_to_patches(images, patch_shape=(4, 4), overlap=(2, 2))
-            patches.shape
+            >>> images = keras.random.uniform((2, 8, 8, 3))
+            >>> patches = images_to_patches(images, patch_shape=(4, 4), overlap=(2, 2))
+            >>> patches.shape
+            (2, 3, 3, 4, 4, 3)
     """
     assert len(images.shape) == 4, (
         f"input array should have 4 dimensions, but has {len(images.shape)} dimensions"
@@ -1052,13 +1154,15 @@ def patches_to_images(
         images (Tensor): Reconstructed batch of images from batch of patches.

     Example:
-        ..
+        .. doctest::

-            import keras
+            >>> import keras
+            >>> from zea.tensor_ops import patches_to_images

-            patches = keras.random.uniform((2, 3, 3, 4, 4, 3))
-            images = patches_to_images(patches, image_shape=(8, 8, 3), overlap=(2, 2))
-            images.shape
+            >>> patches = keras.random.uniform((2, 3, 3, 4, 4, 3))
+            >>> images = patches_to_images(patches, image_shape=(8, 8, 3), overlap=(2, 2))
+            >>> images.shape
+            (2, 8, 8, 3)
     """
     # Input validation
     assert len(image_shape) == 3, "image_shape must have 3 dimensions: (height, width, channels)."
@@ -1138,14 +1242,16 @@ def reshape_axis(data, newshape: tuple, axis: int):
         axis (int): axis to reshape.

     Example:
-        ..
+        .. doctest::

-            import keras
+            >>> import keras
+            >>> from zea.tensor_ops import reshape_axis

-            data = keras.random.uniform((3, 4, 5))
-            newshape = (2, 2)
-            reshaped_data = reshape_axis(data, newshape, axis=1)
-            reshaped_data.shape
+            >>> data = keras.random.uniform((3, 4, 5))
+            >>> newshape = (2, 2)
+            >>> reshaped_data = reshape_axis(data, newshape, axis=1)
+            >>> reshaped_data.shape
+            (3, 2, 2, 5)
     """
     axis = map_negative_indices([axis], data.ndim)[0]
     shape = list(ops.shape(data)) # list
@@ -1359,17 +1465,17 @@ def fori_loop(lower, upper, body_fun, init_val, disable_jit=False):


 def L2(x):
-    """L2 norm of a tensor.
+    """L2 norm of a real tensor.

-    Implementation of L2 norm: https://mathworld.wolfram.com/L2-Norm.html
+    Implementation of L2 norm for real vectors: https://mathworld.wolfram.com/L2-Norm.html
     """
     return ops.sqrt(ops.sum(x**2))


 def L1(x):
-    """L1 norm of a tensor.
+    """L1 norm of a real tensor.

-    Implementation of L1 norm: https://mathworld.wolfram.com/L1-Norm.html
+    Implementation of L1 norm for real vectors: https://mathworld.wolfram.com/L1-Norm.html
     """
     return ops.sum(ops.abs(x))

@@ -1403,34 +1509,6 @@ def sinc(x, eps=keras.config.epsilon()):
     return ops.sin(x + eps) / (x + eps)


-if keras.backend.backend() == "tensorflow":
-
-    def safe_vectorize(
-        pyfunc,
-        excluded=None,
-        signature=None,
-    ):
-        """Just a wrapper around ops.vectorize.
-
-        Because tensorflow does not support multiple arguments to ops.vectorize(func)(...)
-        We will just map the function manually.
-        """
-
-        def _map(*args):
-            outputs = []
-            for i in range(ops.shape(args[0])[0]):
-                outputs.append(pyfunc(*[arg[i] for arg in args]))
-            return ops.stack(outputs)
-
-        return _map
-
-else:
-
-    def safe_vectorize(pyfunc, excluded=None, signature=None):
-        """Just a wrapper around ops.vectorize."""
-        return ops.vectorize(pyfunc, excluded=excluded, signature=signature)
-
-
 def apply_along_axis(func1d, axis, arr, *args, **kwargs):
     """Apply a function to 1D array slices along an axis.

@@ -1476,7 +1554,7 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
         perm = list(range(len(x.shape)))
         perm[0], perm[dim_offset] = perm[dim_offset], perm[0]
         x_moved = ops.transpose(x, perm)
-        result =
+        result = vectorized_map(f, x_moved)
         # Move the result dimension back if needed
         if len(result.shape) > 0:
             result_perm = list(range(len(result.shape)))
|
|
|
1497
1575
|
prev_func = func
|
|
1498
1576
|
|
|
1499
1577
|
def make_func(f):
|
|
1500
|
-
return lambda x:
|
|
1578
|
+
return lambda x: vectorized_map(f, x)
|
|
1501
1579
|
|
|
1502
1580
|
func = make_func(prev_func)
|
|
1503
1581
|
|
|
@@ -1576,3 +1654,29 @@ def correlate(x, y, mode="full"):
         return complex_tensor
     else:
         return ops.real(complex_tensor)
+
+
+def translate(array, range_from=None, range_to=(0, 255)):
+    """Map values in array from one range to other.
+
+    Args:
+        array (ndarray): input array.
+        range_from (Tuple, optional): lower and upper bound of original array.
+            Defaults to min and max of array.
+        range_to (Tuple, optional): lower and upper bound to which array should be mapped.
+            Defaults to (0, 255).
+
+    Returns:
+        (ndarray): translated array
+    """
+    if range_from is None:
+        left_min, left_max = ops.min(array), ops.max(array)
+    else:
+        left_min, left_max = range_from
+    right_min, right_max = range_to
+
+    # Convert the left range into a 0-1 range (float)
+    value_scaled = (array - left_min) / (left_max - left_min)
+
+    # Convert the 0-1 range into a value in the right range.
+    return right_min + (value_scaled * (right_max - right_min))
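Usage sketch: the main user-facing change in this file is that `zea.tensor_ops.vmap` now takes `batch_size`/`chunks` arguments so the mapped function can be run piecewise instead of over the whole input at once. The call below is a minimal, hedged example based only on the docstring added above; the mapped function, input shape, and printed result are illustrative assumptions, not part of the release.

    import keras
    from zea.tensor_ops import vmap

    def scale_row(row):
        # Toy per-element function; stands in for any heavier per-row computation.
        return row * 2.0

    x = keras.random.uniform((1000, 64))

    # Map `scale_row` over axis 0, processing 100 rows per step so that very
    # large inputs do not need to be materialized in one vectorized call.
    batched_scale = vmap(scale_row, in_axes=0, out_axes=0, batch_size=100)
    y = batched_scale(x)
    print(y.shape)  # expected: (1000, 64)

With `batch_size=None` (the default) this should behave like a plain `vmap`; `chunks=4` would instead split the mapped axis into four roughly equal pieces, per the docstring above.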