zea 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. zea/__init__.py +54 -19
  2. zea/agent/__init__.py +12 -12
  3. zea/agent/masks.py +2 -1
  4. zea/backend/tensorflow/dataloader.py +2 -1
  5. zea/beamform/beamformer.py +100 -50
  6. zea/beamform/lens_correction.py +9 -2
  7. zea/beamform/pfield.py +9 -2
  8. zea/config.py +34 -25
  9. zea/data/__init__.py +22 -16
  10. zea/data/convert/camus.py +2 -1
  11. zea/data/convert/echonet.py +4 -4
  12. zea/data/convert/echonetlvh/convert_raw_to_usbmd.py +1 -1
  13. zea/data/convert/matlab.py +11 -4
  14. zea/data/data_format.py +31 -30
  15. zea/data/datasets.py +7 -5
  16. zea/data/file.py +104 -2
  17. zea/data/layers.py +3 -3
  18. zea/datapaths.py +16 -4
  19. zea/display.py +7 -5
  20. zea/interface.py +14 -16
  21. zea/internal/_generate_keras_ops.py +6 -7
  22. zea/internal/cache.py +2 -49
  23. zea/internal/config/validation.py +1 -2
  24. zea/internal/core.py +69 -6
  25. zea/internal/device.py +6 -2
  26. zea/internal/dummy_scan.py +330 -0
  27. zea/internal/operators.py +114 -2
  28. zea/internal/parameters.py +101 -70
  29. zea/internal/setup_zea.py +5 -6
  30. zea/internal/utils.py +282 -0
  31. zea/io_lib.py +247 -19
  32. zea/keras_ops.py +74 -4
  33. zea/log.py +9 -7
  34. zea/metrics.py +15 -7
  35. zea/models/__init__.py +30 -20
  36. zea/models/base.py +30 -14
  37. zea/models/carotid_segmenter.py +19 -4
  38. zea/models/diffusion.py +173 -12
  39. zea/models/echonet.py +22 -8
  40. zea/models/echonetlvh.py +31 -7
  41. zea/models/lpips.py +19 -2
  42. zea/models/lv_segmentation.py +28 -11
  43. zea/models/preset_utils.py +5 -5
  44. zea/models/regional_quality.py +30 -10
  45. zea/models/taesd.py +21 -5
  46. zea/models/unet.py +15 -1
  47. zea/ops.py +390 -196
  48. zea/probes.py +6 -6
  49. zea/scan.py +109 -49
  50. zea/simulator.py +24 -21
  51. zea/tensor_ops.py +406 -302
  52. zea/tools/hf.py +1 -1
  53. zea/tools/selection_tool.py +47 -86
  54. zea/utils.py +92 -480
  55. zea/visualize.py +177 -39
  56. {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/METADATA +4 -2
  57. zea-0.0.7.dist-info/RECORD +114 -0
  58. zea-0.0.6.dist-info/RECORD +0 -112
  59. {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/WHEEL +0 -0
  60. {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/entry_points.txt +0 -0
  61. {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/licenses/LICENSE +0 -0
zea/tensor_ops.py CHANGED
@@ -130,162 +130,6 @@ def extend_n_dims(arr, axis, n_dims):
130
130
  return ops.reshape(arr, new_shape)
131
131
 
132
132
 
133
- def vmap(fun, in_axes=0, out_axes=0):
134
- """Vectorized map.
135
-
136
- For torch and jax backends, this uses the native vmap implementation.
137
- For other backends, this a wrapper that uses `ops.vectorized_map` under the hood.
138
-
139
- Args:
140
- fun: The function to be mapped.
141
- in_axes: The axis or axes to be mapped over in the input.
142
- Can be an integer, a tuple of integers, or None.
143
- If None, the corresponding argument is not mapped over.
144
- Defaults to 0.
145
- out_axes: The axis or axes to be mapped over in the output.
146
- Can be an integer, a tuple of integers, or None.
147
- If None, the corresponding output is not mapped over.
148
- Defaults to 0.
149
-
150
- Returns:
151
- A function that applies `fun` in a vectorized manner over the specified axes.
152
-
153
- Raises:
154
- ValueError: If the backend does not support vmap.
155
- """
156
- if keras.backend.backend() == "jax":
157
- import jax
158
-
159
- return jax.vmap(fun, in_axes=in_axes, out_axes=out_axes)
160
- elif keras.backend.backend() == "torch":
161
- import torch
162
-
163
- return torch.vmap(fun, in_dims=in_axes, out_dims=out_axes)
164
- else:
165
- return manual_vmap(fun, in_axes=in_axes, out_axes=out_axes)
166
-
167
-
168
- def manual_vmap(fun, in_axes=0, out_axes=0):
169
- """Manual vectorized map for backends that do not support vmap."""
170
-
171
- def find_map_length(args, in_axes):
172
- """Find the length of the axis to map over."""
173
- # NOTE: only needed for numpy, the other backends can handle a singleton dimension
174
- for arg, axis in zip(args, in_axes):
175
- if axis is None:
176
- continue
177
-
178
- return ops.shape(arg)[axis]
179
- return 1
180
-
181
- def _moveaxes(args, in_axes, out_axes):
182
- """Move axes of the input arguments."""
183
- args = list(args)
184
- for i, (arg, in_axis, out_axis) in enumerate(zip(args, in_axes, out_axes)):
185
- if in_axis is not None:
186
- args[i] = ops.moveaxis(arg, in_axis, out_axis)
187
- else:
188
- args[i] = ops.repeat(arg[None], find_map_length(args, in_axes), axis=out_axis)
189
- return tuple(args)
190
-
191
- def _fun(args):
192
- return fun(*args)
193
-
194
- def wrapper(*args):
195
- # If in_axes or out_axes is an int, convert to tuple
196
- if isinstance(in_axes, int):
197
- _in_axes = (in_axes,) * len(args)
198
- else:
199
- _in_axes = in_axes
200
- if isinstance(out_axes, int):
201
- _out_axes = (out_axes,) * len(args)
202
- else:
203
- _out_axes = out_axes
204
- zeros = (0,) * len(args)
205
-
206
- # Check that in_axes and out_axes are tuples
207
- if not isinstance(_in_axes, tuple):
208
- raise ValueError("in_axes must be an int or a tuple of ints.")
209
- if not isinstance(_out_axes, tuple):
210
- raise ValueError("out_axes must be an int or a tuple of ints.")
211
-
212
- args = _moveaxes(args, _in_axes, zeros)
213
- outputs = ops.vectorized_map(_fun, tuple(args))
214
-
215
- tuple_output = isinstance(outputs, (tuple, list))
216
- if not tuple_output:
217
- outputs = (outputs,)
218
-
219
- outputs = _moveaxes(outputs, zeros, _out_axes)
220
-
221
- if not tuple_output:
222
- outputs = outputs[0]
223
-
224
- return outputs
225
-
226
- return wrapper
227
-
228
-
229
- def func_with_one_batch_dim(
230
- func,
231
- tensor,
232
- n_batch_dims: int,
233
- batch_size: int | None = None,
234
- func_axis: int | None = None,
235
- **kwargs,
236
- ):
237
- """Wraps a function to apply it to an input tensor with one or more batch dimensions.
238
-
239
- The function will be executed in parallel on all batch elements.
240
-
241
- Args:
242
- func (function): The function to apply to the image.
243
- Will take the `func_axis` output from the function.
244
- tensor (Tensor): The input tensor.
245
- n_batch_dims (int): The number of batch dimensions in the input tensor.
246
- Expects the input to start with n_batch_dims batch dimensions. Defaults to 2.
247
- batch_size (int, optional): Integer specifying the size of the batch for
248
- each step to execute in parallel. Defaults to None, in which case the function
249
- will run everything in parallel.
250
- func_axis (int, optional): If `func` returns mulitple outputs, this axis will be returned.
251
- **kwargs: Additional keyword arguments to pass to the function.
252
-
253
- Returns:
254
- The output tensor with the same batch dimensions as the input tensor.
255
-
256
- Raises:
257
- ValueError: If the number of batch dimensions is greater than the rank of the input tensor.
258
- """
259
- # Extract the shape of the batch dimensions from the input tensor
260
- batch_dims = ops.shape(tensor)[:n_batch_dims]
261
-
262
- # Extract the shape of the remaining (non-batch) dimensions
263
- other_dims = ops.shape(tensor)[n_batch_dims:]
264
-
265
- # Reshape the input tensor to merge all batch dimensions into one
266
- reshaped_input = ops.reshape(tensor, [-1, *other_dims])
267
-
268
- # Apply the given function to the reshaped input tensor
269
- if batch_size is None:
270
- reshaped_output = func(reshaped_input, **kwargs)
271
- else:
272
- reshaped_output = batched_map(func, reshaped_input, batch_size=batch_size)
273
-
274
- # If the function returns multiple outputs, select the one corresponding to `func_axis`
275
- if isinstance(reshaped_output, (tuple, list)):
276
- if func_axis is None:
277
- raise ValueError(
278
- "func_axis must be specified when the function returns multiple outputs."
279
- )
280
- reshaped_output = reshaped_output[func_axis]
281
-
282
- # Extract the shape of the output tensor after applying the function (excluding the batch dim)
283
- output_other_dims = ops.shape(reshaped_output)[1:]
284
-
285
- # Reshape the output tensor to restore the original batch dimensions
286
- return ops.reshape(reshaped_output, [*batch_dims, *output_other_dims])
287
-
288
-
289
133
  def matrix_power(matrix, power):
290
134
  """Compute the power of a square matrix.
291
135
 
@@ -416,97 +260,346 @@ def batch_cov(x, rowvar=True, bias=False, ddof=None):
416
260
  return cov_matrices
417
261
 
418
262
 
419
- def patched_map(f, xs, patches: int, jit=True, **batch_kwargs):
420
- """Wrapper around `batched_map` for patching.
263
+ def simple_map(function, elements):
264
+ """Like `ops.map` but no tracing or jit compilation."""
421
265
 
422
- Allows you to specify the number of patches rather than the batch size.
423
- """
424
- assert patches > 0, "Number of patches must be greater than 0."
266
+ if elements is None:
267
+ return function(None)
268
+
269
+ multiple_inputs = isinstance(elements, (list, tuple))
425
270
 
426
- if patches == 1:
427
- return f(xs, **batch_kwargs)
271
+ outputs = []
272
+ if not multiple_inputs:
273
+ batch_size = elements.shape[0]
274
+ for index in range(batch_size):
275
+ outputs.append(function(elements[index]))
428
276
  else:
429
- length = ops.shape(xs)[0]
430
- batch_size = np.ceil(length / patches).astype(int)
431
- return batched_map(f, xs, batch_size, jit, **batch_kwargs)
277
+ batch_size = elements[0].shape[0]
278
+ for index in range(batch_size):
279
+ outputs.append(function([e[index] if e is not None else None for e in elements]))
280
+
281
+ if isinstance(outputs[0], (list, tuple)):
282
+ return [ops.stack(tensors) for tensors in zip(*outputs)]
283
+ else:
284
+ return ops.stack(outputs)
285
+
286
+
287
+ if keras.backend.backend() == "numpy":
288
+
289
+ def vectorized_map(function, elements):
290
+ """Fixes keras.ops.vectorized_map in numpy backend with multiple outputs."""
291
+ if not isinstance(elements, (list, tuple)):
292
+ return np.stack([function(x) for x in elements])
293
+ else:
294
+ batch_size = elements[0].shape[0]
295
+ output_store = []
296
+ for index in range(batch_size):
297
+ output_store.append(function([x[index] for x in elements]))
298
+ if isinstance(output_store[0], (list, tuple)):
299
+ return [np.stack(tensors) for tensors in zip(*output_store)]
300
+ else:
301
+ return np.stack(output_store)
302
+ else:
303
+ vectorized_map = ops.vectorized_map
304
+
305
+
306
+ def _find_map_length(args, in_axes) -> int:
307
+ """Find the length of the axis to map over."""
308
+ for arg, axis in zip(args, in_axes):
309
+ if axis is None or arg is None:
310
+ continue
311
+
312
+ return ops.shape(arg)[axis]
313
+
314
+ raise ValueError("At least one in_axes must be non-None to determine map length.")
315
+
316
+
317
+ def _repeat_int_to_tuple(value, length) -> Tuple[Union[int, None], ...]:
318
+ """Convert an int or None to a tuple of length `length`."""
319
+ if isinstance(value, int) or value is None:
320
+ return (value,) * length
321
+ elif not isinstance(value, tuple):
322
+ raise ValueError("Value must be an int, None, or a tuple.")
323
+ return value
324
+
325
+
326
+ def _map(fun, in_axes=0, out_axes=0, map_fn=None, _use_torch_vmap=False):
327
+ """Mapping function, vectorized by default.
328
+
329
+ For jax, this uses the native vmap implementation.
330
+ For other backends, this a wrapper that uses `ops.vectorized_map` under the hood.
331
+
332
+ Probably you want to use `zea.tensor_ops.vmap` instead, which uses this function
333
+ with additional batching/chunking support.
334
+
335
+ Args:
336
+ fun: The function to be mapped.
337
+ in_axes: The axis or axes to be mapped over in the input.
338
+ Can be an integer, a tuple of integers, or None.
339
+ If None, the corresponding argument is not mapped over.
340
+ Defaults to 0.
341
+ out_axes: The axis or axes to be mapped over in the output.
342
+ Can be an integer, a tuple of integers, or None.
343
+ If None, the corresponding output is not mapped over.
344
+ Defaults to 0.
345
+ map_fn: The mapping function to use. If None, defaults to `ops.vectorized_map`.
346
+ _use_torch_vmap: If True, uses PyTorch's native vmap implementation.
347
+
348
+ Returns:
349
+ A function that applies `fun` (in a vectorized manner) over the specified axes.
350
+ """
432
351
 
352
+ # Use native vmap for JAX backend when map_fn is not provided
353
+ if keras.backend.backend() == "jax" and map_fn is None:
354
+ import jax
433
355
 
434
- def batched_map(f, xs, batch_size=None, jit=True, **batch_kwargs):
435
- """Map a function over leading array axes.
356
+ return jax.vmap(fun, in_axes=in_axes, out_axes=out_axes)
357
+
358
+ # Use native vmap for PyTorch backend when map_fn is not provided and _use_torch_vmap is True
359
+ if keras.backend.backend() == "torch" and map_fn is None and _use_torch_vmap:
360
+ import torch
361
+
362
+ return torch.vmap(fun, in_dims=in_axes, out_dims=out_axes)
363
+
364
+ # Default to keras vectorized_map if map_fn not provided
365
+ if map_fn is None:
366
+ map_fn = vectorized_map
367
+
368
+ def _moveaxes(args, in_axes, out_axes) -> tuple:
369
+ """Move axes of the input arguments."""
370
+ args = list(args)
371
+ new_args = []
372
+ map_length = _find_map_length(args, in_axes)
373
+ for arg, in_axis, out_axis in zip(args, in_axes, out_axes):
374
+ if arg is None:
375
+ # filter out None arguments
376
+ continue
377
+ if out_axis is None:
378
+ new_args.append(ops.take(arg, 0, axis=in_axis))
379
+ elif in_axis is not None:
380
+ new_args.append(ops.moveaxis(arg, in_axis, out_axis))
381
+ else:
382
+ new_args.append(
383
+ ops.repeat(ops.expand_dims(arg, out_axis), map_length, axis=out_axis)
384
+ )
385
+ return tuple(new_args)
386
+
387
+ def _partial_at(func, idx, value) -> callable:
388
+ """Return a new function with value inserted at index idx in args."""
389
+
390
+ def wrapper(*args, **kwargs):
391
+ args = list(args)
392
+ args[idx:idx] = [value]
393
+ return func(*args, **kwargs)
394
+
395
+ return wrapper
396
+
397
+ def mapped_wrapper(*args):
398
+ _in_axes = _repeat_int_to_tuple(in_axes, len(args))
399
+
400
+ # Move mapped axes to front
401
+ zeros = (0,) * len(args)
402
+ none_indexes = [i for i, arg in enumerate(args) if arg is None]
403
+ args = _moveaxes(args, _in_axes, zeros)
404
+
405
+ # Build function with None arguments prefilled
406
+ _prefilled_none_fn = fun
407
+ for none_idx in none_indexes:
408
+ _prefilled_none_fn = _partial_at(_prefilled_none_fn, none_idx, None)
409
+
410
+ def _fun(args):
411
+ return _prefilled_none_fn(*args)
412
+
413
+ outputs = map_fn(_fun, tuple(args))
414
+
415
+ # Wrap outputs to tuple for easier processing
416
+ tuple_output = isinstance(outputs, (tuple, list))
417
+ if not tuple_output:
418
+ outputs = (outputs,)
419
+
420
+ _out_axes = _repeat_int_to_tuple(out_axes, len(outputs))
421
+
422
+ # Move mapped axes back to original position
423
+ outputs = _moveaxes(outputs, zeros, _out_axes)
424
+
425
+ if not tuple_output:
426
+ outputs = outputs[0]
427
+
428
+ return outputs
429
+
430
+ return mapped_wrapper
431
+
432
+
433
+ def vmap(
434
+ fun,
435
+ in_axes=0,
436
+ out_axes=0,
437
+ batch_size=None,
438
+ chunks=None,
439
+ fn_supports_batch=False,
440
+ disable_jit=False,
441
+ _use_torch_vmap=False,
442
+ ):
443
+ """`vmap` with batching or chunking support to avoid memory issues.
444
+
445
+ Basically a wrapper around `vmap` that splits the input into batches or chunks
446
+ to avoid memory issues with large inputs.
436
447
 
437
448
  Args:
438
- f (callable): Function to apply element-wise over the first axis.
439
- xs (Tensor): Values over which to map along the leading axis.
440
- batch_size (int, optional): Size of the batch for each step. Defaults to None,
441
- in which case the function will be equivalent to `ops.map`, and thus map over
442
- the leading axis.
443
- jit (bool, optional): If True, use a jitted version of the function for
444
- faster batched mapping. Else, loop over the data with the original function.
445
- batch_kwargs (dict, optional): Additional keyword arguments (tensors) to
446
- batch along with xs. Must have the same first dimension size as xs.
449
+ fun: Function to be mapped.
450
+ in_axes: Axis or axes to be mapped over in the input.
451
+ out_axes: Axis or axes to be mapped over in the output.
452
+ batch_size: Size of the batch for each step. If `None`, the function will be equivalent
453
+ to `vmap`. If `1`, will be equivalent to `map`. Mutually exclusive with `chunks`.
454
+ chunks: Number of chunks to split the input into. If `None` or `1`, the function will be
455
+ equivalent to `vmap`. Mutually exclusive with `batch_size`.
456
+ fn_supports_batch: If True, assumes that `fun` can already handle batched inputs.
457
+ In this case, this function will only handle padding and reshaping for batching.
458
+ disable_jit: If True, disables JIT compilation for backends that support it.
459
+ This can be useful for debugging. Will fall back to simple mapping.
460
+ _use_torch_vmap: If True, uses PyTorch's native vmap implementation.
461
+ Advantage: you can apply `vmap` multiple times without issues.
462
+ Disadvantage: does not support None arguments.
463
+ Returns:
464
+ A function that applies `fun` in a batched manner over the specified axes.
465
+ """
466
+
467
+ # Mutually exclusive arguments
468
+ assert not (batch_size is not None and chunks is not None), (
469
+ "batch_size and chunks are mutually exclusive. Please specify only one of them."
470
+ )
471
+
472
+ if batch_size is not None:
473
+ assert batch_size > 0, "batch_size must be greater than 0."
474
+ if chunks is not None:
475
+ assert chunks > 0, "chunks must be greater than 0."
476
+
477
+ no_chunks_or_batch = batch_size is None and (chunks is None or chunks == 1)
478
+
479
+ if fn_supports_batch and no_chunks_or_batch:
480
+ return fun
481
+
482
+ if not fn_supports_batch:
483
+ # vmap to support batches
484
+ fun = _map(
485
+ fun,
486
+ in_axes=in_axes,
487
+ out_axes=out_axes,
488
+ map_fn=None if not disable_jit else simple_map,
489
+ _use_torch_vmap=_use_torch_vmap,
490
+ )
491
+
492
+ if no_chunks_or_batch:
493
+ return fun
494
+
495
+ # map (sequentially) to support batches/chunks that fit in memory
496
+ map_fn = ops.map if not disable_jit else simple_map
497
+
498
+ def batched_fun(*args):
499
+ _in_axes = _repeat_int_to_tuple(in_axes, len(args))
500
+ total_length = _find_map_length(args, _in_axes)
501
+
502
+ if chunks is not None:
503
+ _batch_size = np.ceil(total_length / chunks).astype(int)
504
+ else:
505
+ _batch_size = batch_size
506
+
507
+ new_args = []
508
+ for arg, in_axis in zip(args, _in_axes):
509
+ if in_axis is None or arg is None:
510
+ new_args.append(arg)
511
+ continue
512
+ padded_arg = pad_array_to_divisible(arg, _batch_size, axis=in_axis)
513
+ reshaped_arg = reshape_axis(padded_arg, (-1, _batch_size), axis=in_axis)
514
+ new_args.append(reshaped_arg)
515
+
516
+ outputs = _map(fun, in_axes=_in_axes, out_axes=out_axes, map_fn=map_fn)(*new_args)
517
+
518
+ # Wrap outputs to tuple for easier processing
519
+ tuple_output = isinstance(outputs, (tuple, list))
520
+ if not tuple_output:
521
+ outputs = (outputs,)
522
+
523
+ new_outputs = []
524
+ _out_axes = _repeat_int_to_tuple(out_axes, len(outputs))
525
+ for output, out_axis in zip(outputs, _out_axes):
526
+ if out_axis is None:
527
+ new_outputs.append(output)
528
+ continue
529
+ reshaped_output = flatten(output, start_dim=out_axis, end_dim=out_axis + 1)
530
+ cropped_output = ops.take(reshaped_output, ops.arange(total_length), axis=out_axis)
531
+ new_outputs.append(cropped_output)
532
+
533
+ if tuple_output:
534
+ return tuple(new_outputs)
535
+ else:
536
+ return new_outputs[0]
537
+
538
+ return batched_fun
539
+
540
+
541
+ def func_with_one_batch_dim(
542
+ func,
543
+ tensor,
544
+ n_batch_dims: int,
545
+ batch_size: int | None = None,
546
+ func_axis: int | None = None,
547
+ **kwargs,
548
+ ):
549
+ """Wraps a function to apply it to an input tensor with one or more batch dimensions.
550
+
551
+ The function will be executed in parallel on all batch elements.
552
+
553
+ Args:
554
+ func (function): The function to apply to the image.
555
+ Will take the `func_axis` output from the function.
556
+ tensor (Tensor): The input tensor.
557
+ n_batch_dims (int): The number of batch dimensions in the input tensor.
558
+ Expects the input to start with n_batch_dims batch dimensions. Defaults to 2.
559
+ batch_size (int, optional): Integer specifying the size of the batch for
560
+ each step to execute in parallel. Defaults to None, in which case the function
561
+ will run everything in parallel.
562
+ func_axis (int, optional): If `func` returns mulitple outputs, this axis will be returned.
563
+ **kwargs: Additional keyword arguments to pass to the function.
447
564
 
448
565
  Returns:
449
- The mapped tensor(s).
566
+ The output tensor with the same batch dimensions as the input tensor.
450
567
 
451
- Idea taken from: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.map.html
568
+ Raises:
569
+ ValueError: If the number of batch dimensions is greater than the rank of the input tensor.
452
570
  """
453
- if batch_kwargs is None:
454
- batch_kwargs = {}
455
-
456
- # Ensure all batch kwargs have the same leading dimension as xs.
457
- if batch_kwargs:
458
- assert all(
459
- ops.shape(xs)[0] == ops.shape(v)[0] for v in batch_kwargs.values() if v is not None
460
- ), "All batch kwargs must have the same first dimension size as xs."
461
-
462
- total = ops.shape(xs)[0]
463
- # TODO: could be rewritten with ops.cond such that it also works for jit=True.
464
- if not jit and batch_size is not None and total <= batch_size:
465
- return f(xs, **batch_kwargs)
466
-
467
- ## Non-jitted version: simply iterate over batches.
468
- if not jit:
469
- bs = batch_size or 1 # Default batch size to 1 if not specified.
470
- outputs = []
471
- for i in range(0, total, bs):
472
- idx = slice(i, i + bs)
473
- current_kwargs = {k: v[idx] for k, v in batch_kwargs.items()}
474
- outputs.append(f(xs[idx], **current_kwargs))
475
- return ops.concatenate(outputs, axis=0)
476
-
477
- ## Jitted version.
478
-
479
- # Helper to create the batched function for use with ops.map.
480
- def create_batched_f(kw_keys):
481
- def batched_f(inputs):
482
- x, *kw_values = inputs
483
- kw = dict(zip(kw_keys, kw_values))
484
- return f(x, **kw)
485
-
486
- return batched_f
571
+ # Extract the shape of the batch dimensions from the input tensor
572
+ batch_dims = ops.shape(tensor)[:n_batch_dims]
573
+
574
+ # Extract the shape of the remaining (non-batch) dimensions
575
+ other_dims = ops.shape(tensor)[n_batch_dims:]
487
576
 
577
+ # Reshape the input tensor to merge all batch dimensions into one
578
+ reshaped_input = ops.reshape(tensor, [-1, *other_dims])
579
+
580
+ # Apply the given function to the reshaped input tensor
488
581
  if batch_size is None:
489
- batched_f = create_batched_f(list(batch_kwargs.keys()))
490
- return ops.map(batched_f, (xs, *batch_kwargs.values()))
491
-
492
- # Pad and reshape primary tensor.
493
- xs_padded = pad_array_to_divisible(xs, batch_size, axis=0)
494
- new_shape = (-1, batch_size) + ops.shape(xs_padded)[1:]
495
- xs_reshaped = ops.reshape(xs_padded, new_shape)
496
-
497
- # Pad and reshape batch_kwargs similarly.
498
- reshaped_kwargs = {}
499
- for k, v in batch_kwargs.items():
500
- if v is None:
501
- reshaped_kwargs[k] = None
502
- else:
503
- v_padded = pad_array_to_divisible(v, batch_size, axis=0)
504
- reshaped_kwargs[k] = ops.reshape(v_padded, (-1, batch_size) + ops.shape(v_padded)[1:])
582
+ reshaped_output = func(reshaped_input, **kwargs)
583
+ else:
584
+ reshaped_output = vmap(
585
+ lambda *args: func(*args, **kwargs),
586
+ batch_size=batch_size,
587
+ fn_supports_batch=True,
588
+ )(reshaped_input)
589
+
590
+ # If the function returns multiple outputs, select the one corresponding to `func_axis`
591
+ if isinstance(reshaped_output, (tuple, list)):
592
+ if func_axis is None:
593
+ raise ValueError(
594
+ "func_axis must be specified when the function returns multiple outputs."
595
+ )
596
+ reshaped_output = reshaped_output[func_axis]
505
597
 
506
- batched_f = create_batched_f(list(reshaped_kwargs.keys()))
507
- out = ops.map(batched_f, (xs_reshaped, *reshaped_kwargs.values()))
508
- out_reshaped = ops.reshape(out, (-1,) + ops.shape(out)[2:])
509
- return out_reshaped[:total] # Remove any padding added.
598
+ # Extract the shape of the output tensor after applying the function (excluding the batch dim)
599
+ output_other_dims = ops.shape(reshaped_output)[1:]
600
+
601
+ # Reshape the output tensor to restore the original batch dimensions
602
+ return ops.reshape(reshaped_output, [*batch_dims, *output_other_dims])
510
603
 
511
604
 
512
605
  if keras.backend.backend() == "jax":
@@ -740,14 +833,16 @@ def stack_volume_data_along_axis(data, batch_axis: int, stack_axis: int, number:
740
833
  Tensor: Reshaped tensor with data stacked along stack_axis.
741
834
 
742
835
  Example:
743
- .. code-block:: python
836
+ .. doctest::
744
837
 
745
- import keras
838
+ >>> import keras
839
+ >>> from zea.tensor_ops import stack_volume_data_along_axis
746
840
 
747
- data = keras.random.uniform((10, 20, 30))
748
- # stacking along 1st axis with 2 frames per block
749
- stacked_data = stack_volume_data_along_axis(data, 0, 1, 2)
750
- stacked_data.shape
841
+ >>> data = keras.random.uniform((10, 20, 30))
842
+ >>> # stacking along 1st axis with 2 frames per block
843
+ >>> stacked_data = stack_volume_data_along_axis(data, 0, 1, 2)
844
+ >>> stacked_data.shape
845
+ (5, 40, 30)
751
846
  """
752
847
  blocks = int(ops.ceil(data.shape[batch_axis] / number))
753
848
  data = pad_array_to_divisible(data, axis=batch_axis, N=blocks, mode="reflect")
@@ -782,13 +877,15 @@ def split_volume_data_from_axis(data, batch_axis: int, stack_axis: int, number:
782
877
  Tensor: Reshaped tensor with data split back to original format.
783
878
 
784
879
  Example:
785
- .. code-block:: python
880
+ .. doctest::
786
881
 
787
- import keras
882
+ >>> import keras
883
+ >>> from zea.tensor_ops import split_volume_data_from_axis
788
884
 
789
- data = keras.random.uniform((20, 10, 30))
790
- split_data = split_volume_data_from_axis(data, 0, 1, 2, 2)
791
- split_data.shape
885
+ >>> data = keras.random.uniform((20, 10, 30))
886
+ >>> split_data = split_volume_data_from_axis(data, 0, 1, 2, 2)
887
+ >>> split_data.shape
888
+ (39, 5, 30)
792
889
  """
793
890
  if data.shape[stack_axis] == 1:
794
891
  # in this case it was a broadcasted axis which does not need to be split
@@ -904,14 +1001,17 @@ def check_patches_fit(
904
1001
  in the original image and the new image shape if the patches do not fit.
905
1002
 
906
1003
  Example:
907
- .. code-block:: python
908
-
909
- image_shape = (10, 10)
910
- patch_shape = (4, 4)
911
- overlap = (2, 2)
912
- patches_fit, new_shape = check_patches_fit(image_shape, patch_shape, overlap)
913
- patches_fit
914
- new_shape
1004
+ .. doctest::
1005
+
1006
+ >>> from zea.tensor_ops import check_patches_fit
1007
+ >>> image_shape = (10, 10)
1008
+ >>> patch_shape = (4, 4)
1009
+ >>> overlap = (2, 2)
1010
+ >>> patches_fit, new_shape = check_patches_fit(image_shape, patch_shape, overlap)
1011
+ >>> patches_fit
1012
+ True
1013
+ >>> new_shape
1014
+ (10, 10)
915
1015
  """
916
1016
  if overlap:
917
1017
  stride = (np.array(patch_shape) - np.array(overlap)).astype(int)
@@ -968,13 +1068,15 @@ def images_to_patches(
968
1068
  [batch, #patch_y, #patch_x, patch_size_y, patch_size_x, #channels].
969
1069
 
970
1070
  Example:
971
- .. code-block:: python
1071
+ .. doctest::
972
1072
 
973
- import keras
1073
+ >>> import keras
1074
+ >>> from zea.tensor_ops import images_to_patches
974
1075
 
975
- images = keras.random.uniform((2, 8, 8, 3))
976
- patches = images_to_patches(images, patch_shape=(4, 4), overlap=(2, 2))
977
- patches.shape
1076
+ >>> images = keras.random.uniform((2, 8, 8, 3))
1077
+ >>> patches = images_to_patches(images, patch_shape=(4, 4), overlap=(2, 2))
1078
+ >>> patches.shape
1079
+ (2, 3, 3, 4, 4, 3)
978
1080
  """
979
1081
  assert len(images.shape) == 4, (
980
1082
  f"input array should have 4 dimensions, but has {len(images.shape)} dimensions"
@@ -1052,13 +1154,15 @@ def patches_to_images(
1052
1154
  images (Tensor): Reconstructed batch of images from batch of patches.
1053
1155
 
1054
1156
  Example:
1055
- .. code-block:: python
1157
+ .. doctest::
1056
1158
 
1057
- import keras
1159
+ >>> import keras
1160
+ >>> from zea.tensor_ops import patches_to_images
1058
1161
 
1059
- patches = keras.random.uniform((2, 3, 3, 4, 4, 3))
1060
- images = patches_to_images(patches, image_shape=(8, 8, 3), overlap=(2, 2))
1061
- images.shape
1162
+ >>> patches = keras.random.uniform((2, 3, 3, 4, 4, 3))
1163
+ >>> images = patches_to_images(patches, image_shape=(8, 8, 3), overlap=(2, 2))
1164
+ >>> images.shape
1165
+ (2, 8, 8, 3)
1062
1166
  """
1063
1167
  # Input validation
1064
1168
  assert len(image_shape) == 3, "image_shape must have 3 dimensions: (height, width, channels)."
@@ -1138,14 +1242,16 @@ def reshape_axis(data, newshape: tuple, axis: int):
1138
1242
  axis (int): axis to reshape.
1139
1243
 
1140
1244
  Example:
1141
- .. code-block:: python
1245
+ .. doctest::
1142
1246
 
1143
- import keras
1247
+ >>> import keras
1248
+ >>> from zea.tensor_ops import reshape_axis
1144
1249
 
1145
- data = keras.random.uniform((3, 4, 5))
1146
- newshape = (2, 2)
1147
- reshaped_data = reshape_axis(data, newshape, axis=1)
1148
- reshaped_data.shape
1250
+ >>> data = keras.random.uniform((3, 4, 5))
1251
+ >>> newshape = (2, 2)
1252
+ >>> reshaped_data = reshape_axis(data, newshape, axis=1)
1253
+ >>> reshaped_data.shape
1254
+ (3, 2, 2, 5)
1149
1255
  """
1150
1256
  axis = map_negative_indices([axis], data.ndim)[0]
1151
1257
  shape = list(ops.shape(data)) # list
@@ -1359,17 +1465,17 @@ def fori_loop(lower, upper, body_fun, init_val, disable_jit=False):
1359
1465
 
1360
1466
 
1361
1467
  def L2(x):
1362
- """L2 norm of a tensor.
1468
+ """L2 norm of a real tensor.
1363
1469
 
1364
- Implementation of L2 norm: https://mathworld.wolfram.com/L2-Norm.html
1470
+ Implementation of L2 norm for real vectors: https://mathworld.wolfram.com/L2-Norm.html
1365
1471
  """
1366
1472
  return ops.sqrt(ops.sum(x**2))
1367
1473
 
1368
1474
 
1369
1475
  def L1(x):
1370
- """L1 norm of a tensor.
1476
+ """L1 norm of a real tensor.
1371
1477
 
1372
- Implementation of L1 norm: https://mathworld.wolfram.com/L1-Norm.html
1478
+ Implementation of L1 norm for real vectors: https://mathworld.wolfram.com/L1-Norm.html
1373
1479
  """
1374
1480
  return ops.sum(ops.abs(x))
1375
1481
 
@@ -1403,34 +1509,6 @@ def sinc(x, eps=keras.config.epsilon()):
1403
1509
  return ops.sin(x + eps) / (x + eps)
1404
1510
 
1405
1511
 
1406
- if keras.backend.backend() == "tensorflow":
1407
-
1408
- def safe_vectorize(
1409
- pyfunc,
1410
- excluded=None,
1411
- signature=None,
1412
- ):
1413
- """Just a wrapper around ops.vectorize.
1414
-
1415
- Because tensorflow does not support multiple arguments to ops.vectorize(func)(...)
1416
- We will just map the function manually.
1417
- """
1418
-
1419
- def _map(*args):
1420
- outputs = []
1421
- for i in range(ops.shape(args[0])[0]):
1422
- outputs.append(pyfunc(*[arg[i] for arg in args]))
1423
- return ops.stack(outputs)
1424
-
1425
- return _map
1426
-
1427
- else:
1428
-
1429
- def safe_vectorize(pyfunc, excluded=None, signature=None):
1430
- """Just a wrapper around ops.vectorize."""
1431
- return ops.vectorize(pyfunc, excluded=excluded, signature=signature)
1432
-
1433
-
1434
1512
  def apply_along_axis(func1d, axis, arr, *args, **kwargs):
1435
1513
  """Apply a function to 1D array slices along an axis.
1436
1514
 
@@ -1476,7 +1554,7 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
1476
1554
  perm = list(range(len(x.shape)))
1477
1555
  perm[0], perm[dim_offset] = perm[dim_offset], perm[0]
1478
1556
  x_moved = ops.transpose(x, perm)
1479
- result = ops.vectorized_map(f, x_moved)
1557
+ result = vectorized_map(f, x_moved)
1480
1558
  # Move the result dimension back if needed
1481
1559
  if len(result.shape) > 0:
1482
1560
  result_perm = list(range(len(result.shape)))
@@ -1497,7 +1575,7 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
1497
1575
  prev_func = func
1498
1576
 
1499
1577
  def make_func(f):
1500
- return lambda x: ops.vectorized_map(f, x)
1578
+ return lambda x: vectorized_map(f, x)
1501
1579
 
1502
1580
  func = make_func(prev_func)
1503
1581
 
@@ -1576,3 +1654,29 @@ def correlate(x, y, mode="full"):
1576
1654
  return complex_tensor
1577
1655
  else:
1578
1656
  return ops.real(complex_tensor)
1657
+
1658
+
1659
+ def translate(array, range_from=None, range_to=(0, 255)):
1660
+ """Map values in array from one range to other.
1661
+
1662
+ Args:
1663
+ array (ndarray): input array.
1664
+ range_from (Tuple, optional): lower and upper bound of original array.
1665
+ Defaults to min and max of array.
1666
+ range_to (Tuple, optional): lower and upper bound to which array should be mapped.
1667
+ Defaults to (0, 255).
1668
+
1669
+ Returns:
1670
+ (ndarray): translated array
1671
+ """
1672
+ if range_from is None:
1673
+ left_min, left_max = ops.min(array), ops.max(array)
1674
+ else:
1675
+ left_min, left_max = range_from
1676
+ right_min, right_max = range_to
1677
+
1678
+ # Convert the left range into a 0-1 range (float)
1679
+ value_scaled = (array - left_min) / (left_max - left_min)
1680
+
1681
+ # Convert the 0-1 range into a value in the right range.
1682
+ return right_min + (value_scaled * (right_max - right_min))