zea 0.0.5.tar.gz → 0.0.6.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. {zea-0.0.5 → zea-0.0.6}/PKG-INFO +6 -2
  2. {zea-0.0.5 → zea-0.0.6}/pyproject.toml +8 -1
  3. {zea-0.0.5 → zea-0.0.6}/zea/__init__.py +1 -1
  4. {zea-0.0.5 → zea-0.0.6}/zea/agent/selection.py +166 -0
  5. {zea-0.0.5 → zea-0.0.6}/zea/backend/__init__.py +89 -0
  6. zea-0.0.6/zea/backend/jax/__init__.py +33 -0
  7. zea-0.0.6/zea/backend/tensorflow/__init__.py +17 -0
  8. zea-0.0.6/zea/backend/torch/__init__.py +39 -0
  9. {zea-0.0.5 → zea-0.0.6}/zea/data/layers.py +2 -3
  10. {zea-0.0.5 → zea-0.0.6}/zea/internal/registry.py +1 -1
  11. zea-0.0.6/zea/metrics.py +450 -0
  12. {zea-0.0.5 → zea-0.0.6}/zea/models/diffusion.py +14 -14
  13. {zea-0.0.5 → zea-0.0.6}/zea/models/echonetlvh.py +0 -11
  14. zea-0.0.6/zea/models/lv_segmentation.py +79 -0
  15. {zea-0.0.5 → zea-0.0.6}/zea/models/presets.py +36 -0
  16. zea-0.0.6/zea/models/regional_quality.py +122 -0
  17. {zea-0.0.5 → zea-0.0.6}/zea/ops.py +24 -11
  18. {zea-0.0.5 → zea-0.0.6}/zea/tensor_ops.py +101 -0
  19. zea-0.0.5/zea/backend/jax/__init__.py +0 -70
  20. zea-0.0.5/zea/backend/tensorflow/__init__.py +0 -66
  21. zea-0.0.5/zea/backend/torch/__init__.py +0 -74
  22. zea-0.0.5/zea/metrics.py +0 -158
  23. {zea-0.0.5 → zea-0.0.6}/LICENSE +0 -0
  24. {zea-0.0.5 → zea-0.0.6}/README.md +0 -0
  25. {zea-0.0.5 → zea-0.0.6}/zea/__main__.py +0 -0
  26. {zea-0.0.5 → zea-0.0.6}/zea/agent/__init__.py +0 -0
  27. {zea-0.0.5 → zea-0.0.6}/zea/agent/gumbel.py +0 -0
  28. {zea-0.0.5 → zea-0.0.6}/zea/agent/masks.py +0 -0
  29. {zea-0.0.5 → zea-0.0.6}/zea/backend/autograd.py +0 -0
  30. {zea-0.0.5 → zea-0.0.6}/zea/backend/tensorflow/dataloader.py +0 -0
  31. {zea-0.0.5 → zea-0.0.6}/zea/backend/tensorflow/layers/__init__.py +0 -0
  32. {zea-0.0.5 → zea-0.0.6}/zea/backend/tensorflow/layers/apodization.py +0 -0
  33. {zea-0.0.5 → zea-0.0.6}/zea/backend/tensorflow/layers/utils.py +0 -0
  34. {zea-0.0.5 → zea-0.0.6}/zea/backend/tensorflow/losses.py +0 -0
  35. {zea-0.0.5 → zea-0.0.6}/zea/backend/tensorflow/models/__init__.py +0 -0
  36. {zea-0.0.5 → zea-0.0.6}/zea/backend/tensorflow/models/lista.py +0 -0
  37. {zea-0.0.5 → zea-0.0.6}/zea/backend/tensorflow/scripts/convert-echonet-dynamic.py +0 -0
  38. {zea-0.0.5 → zea-0.0.6}/zea/backend/tensorflow/scripts/convert-taesd.py +0 -0
  39. {zea-0.0.5 → zea-0.0.6}/zea/backend/tensorflow/utils/__init__.py +0 -0
  40. {zea-0.0.5 → zea-0.0.6}/zea/backend/tensorflow/utils/callbacks.py +0 -0
  41. {zea-0.0.5 → zea-0.0.6}/zea/backend/tensorflow/utils/utils.py +0 -0
  42. {zea-0.0.5 → zea-0.0.6}/zea/backend/tf2jax.py +0 -0
  43. {zea-0.0.5 → zea-0.0.6}/zea/backend/torch/losses.py +0 -0
  44. {zea-0.0.5 → zea-0.0.6}/zea/beamform/__init__.py +0 -0
  45. {zea-0.0.5 → zea-0.0.6}/zea/beamform/beamformer.py +0 -0
  46. {zea-0.0.5 → zea-0.0.6}/zea/beamform/delays.py +0 -0
  47. {zea-0.0.5 → zea-0.0.6}/zea/beamform/lens_correction.py +0 -0
  48. {zea-0.0.5 → zea-0.0.6}/zea/beamform/pfield.py +0 -0
  49. {zea-0.0.5 → zea-0.0.6}/zea/beamform/phantoms.py +0 -0
  50. {zea-0.0.5 → zea-0.0.6}/zea/beamform/pixelgrid.py +0 -0
  51. {zea-0.0.5 → zea-0.0.6}/zea/config.py +0 -0
  52. {zea-0.0.5 → zea-0.0.6}/zea/data/__init__.py +0 -0
  53. {zea-0.0.5 → zea-0.0.6}/zea/data/__main__.py +0 -0
  54. {zea-0.0.5 → zea-0.0.6}/zea/data/augmentations.py +0 -0
  55. {zea-0.0.5 → zea-0.0.6}/zea/data/convert/__init__.py +0 -0
  56. {zea-0.0.5 → zea-0.0.6}/zea/data/convert/camus.py +0 -0
  57. {zea-0.0.5 → zea-0.0.6}/zea/data/convert/echonet.py +0 -0
  58. {zea-0.0.5 → zea-0.0.6}/zea/data/convert/echonetlvh/README.md +0 -0
  59. {zea-0.0.5 → zea-0.0.6}/zea/data/convert/echonetlvh/convert_raw_to_usbmd.py +0 -0
  60. {zea-0.0.5 → zea-0.0.6}/zea/data/convert/echonetlvh/precompute_crop.py +0 -0
  61. {zea-0.0.5 → zea-0.0.6}/zea/data/convert/images.py +0 -0
  62. {zea-0.0.5 → zea-0.0.6}/zea/data/convert/matlab.py +0 -0
  63. {zea-0.0.5 → zea-0.0.6}/zea/data/convert/picmus.py +0 -0
  64. {zea-0.0.5 → zea-0.0.6}/zea/data/data_format.py +0 -0
  65. {zea-0.0.5 → zea-0.0.6}/zea/data/dataloader.py +0 -0
  66. {zea-0.0.5 → zea-0.0.6}/zea/data/datasets.py +0 -0
  67. {zea-0.0.5 → zea-0.0.6}/zea/data/file.py +0 -0
  68. {zea-0.0.5 → zea-0.0.6}/zea/data/preset_utils.py +0 -0
  69. {zea-0.0.5 → zea-0.0.6}/zea/data/utils.py +0 -0
  70. {zea-0.0.5 → zea-0.0.6}/zea/datapaths.py +0 -0
  71. {zea-0.0.5 → zea-0.0.6}/zea/display.py +0 -0
  72. {zea-0.0.5 → zea-0.0.6}/zea/doppler.py +0 -0
  73. {zea-0.0.5 → zea-0.0.6}/zea/interface.py +0 -0
  74. {zea-0.0.5 → zea-0.0.6}/zea/internal/_generate_keras_ops.py +0 -0
  75. {zea-0.0.5 → zea-0.0.6}/zea/internal/cache.py +0 -0
  76. {zea-0.0.5 → zea-0.0.6}/zea/internal/checks.py +0 -0
  77. {zea-0.0.5 → zea-0.0.6}/zea/internal/config/create.py +0 -0
  78. {zea-0.0.5 → zea-0.0.6}/zea/internal/config/parameters.py +0 -0
  79. {zea-0.0.5 → zea-0.0.6}/zea/internal/config/validation.py +0 -0
  80. {zea-0.0.5 → zea-0.0.6}/zea/internal/core.py +0 -0
  81. {zea-0.0.5 → zea-0.0.6}/zea/internal/device.py +0 -0
  82. {zea-0.0.5 → zea-0.0.6}/zea/internal/git_info.py +0 -0
  83. {zea-0.0.5 → zea-0.0.6}/zea/internal/notebooks.py +0 -0
  84. {zea-0.0.5 → zea-0.0.6}/zea/internal/operators.py +0 -0
  85. {zea-0.0.5 → zea-0.0.6}/zea/internal/parameters.py +0 -0
  86. {zea-0.0.5 → zea-0.0.6}/zea/internal/setup_zea.py +0 -0
  87. {zea-0.0.5 → zea-0.0.6}/zea/internal/viewer.py +0 -0
  88. {zea-0.0.5 → zea-0.0.6}/zea/io_lib.py +0 -0
  89. {zea-0.0.5 → zea-0.0.6}/zea/keras_ops.py +0 -0
  90. {zea-0.0.5 → zea-0.0.6}/zea/log.py +0 -0
  91. {zea-0.0.5 → zea-0.0.6}/zea/models/__init__.py +0 -0
  92. {zea-0.0.5 → zea-0.0.6}/zea/models/base.py +0 -0
  93. {zea-0.0.5 → zea-0.0.6}/zea/models/carotid_segmenter.py +0 -0
  94. {zea-0.0.5 → zea-0.0.6}/zea/models/deeplabv3.py +0 -0
  95. {zea-0.0.5 → zea-0.0.6}/zea/models/dense.py +0 -0
  96. {zea-0.0.5 → zea-0.0.6}/zea/models/echonet.py +0 -0
  97. {zea-0.0.5 → zea-0.0.6}/zea/models/generative.py +0 -0
  98. {zea-0.0.5 → zea-0.0.6}/zea/models/gmm.py +0 -0
  99. {zea-0.0.5 → zea-0.0.6}/zea/models/layers.py +0 -0
  100. {zea-0.0.5 → zea-0.0.6}/zea/models/lpips.py +0 -0
  101. {zea-0.0.5 → zea-0.0.6}/zea/models/preset_utils.py +0 -0
  102. {zea-0.0.5 → zea-0.0.6}/zea/models/taesd.py +0 -0
  103. {zea-0.0.5 → zea-0.0.6}/zea/models/unet.py +0 -0
  104. {zea-0.0.5 → zea-0.0.6}/zea/models/utils.py +0 -0
  105. {zea-0.0.5 → zea-0.0.6}/zea/probes.py +0 -0
  106. {zea-0.0.5 → zea-0.0.6}/zea/scan.py +0 -0
  107. {zea-0.0.5 → zea-0.0.6}/zea/simulator.py +0 -0
  108. {zea-0.0.5 → zea-0.0.6}/zea/tools/__init__.py +0 -0
  109. {zea-0.0.5 → zea-0.0.6}/zea/tools/fit_scan_cone.py +0 -0
  110. {zea-0.0.5 → zea-0.0.6}/zea/tools/hf.py +0 -0
  111. {zea-0.0.5 → zea-0.0.6}/zea/tools/selection_tool.py +0 -0
  112. {zea-0.0.5 → zea-0.0.6}/zea/tools/wndb.py +0 -0
  113. {zea-0.0.5 → zea-0.0.6}/zea/utils.py +0 -0
  114. {zea-0.0.5 → zea-0.0.6}/zea/visualize.py +0 -0
  115. {zea-0.0.5 → zea-0.0.6}/zea/zea_darkmode.mplstyle +0 -0

{zea-0.0.5 → zea-0.0.6}/PKG-INFO
@@ -1,7 +1,8 @@
- Metadata-Version: 2.3
+ Metadata-Version: 2.4
  Name: zea
- Version: 0.0.5
+ Version: 0.0.6
  Summary: A Toolbox for Cognitive Ultrasound Imaging. Provides a set of tools for processing of ultrasound data, all built in your favorite machine learning framework.
+ License-File: LICENSE
  Keywords: ultrasound,machine learning,beamforming
  Author: Tristan Stevens
  Author-email: t.s.w.stevens@tue.nl

@@ -21,6 +22,7 @@ Provides-Extra: display
  Provides-Extra: display-headless
  Provides-Extra: docs
  Provides-Extra: jax
+ Provides-Extra: models
  Provides-Extra: tests
  Requires-Dist: IPython ; extra == "dev"
  Requires-Dist: IPython ; extra == "docs"

@@ -49,6 +51,8 @@ Requires-Dist: myst-parser ; extra == "docs"
  Requires-Dist: nbsphinx ; extra == "dev"
  Requires-Dist: nbsphinx ; extra == "docs"
  Requires-Dist: numpy (>=1.24)
+ Requires-Dist: onnxruntime (>=1.15) ; extra == "dev"
+ Requires-Dist: onnxruntime (>=1.15) ; extra == "models"
  Requires-Dist: opencv-python (>=4) ; extra == "display"
  Requires-Dist: opencv-python-headless (>=4) ; extra == "dev"
  Requires-Dist: opencv-python-headless (>=4) ; extra == "display-headless"

{zea-0.0.5 → zea-0.0.6}/pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "zea"
- version = "0.0.5"
+ version = "0.0.6"
  description = "A Toolbox for Cognitive Ultrasound Imaging. Provides a set of tools for processing of ultrasound data, all built in your favorite machine learning framework."
  authors = [
      { name = "Tristan Stevens", email = "t.s.w.stevens@tue.nl" },

@@ -74,6 +74,8 @@ dev = [
      "IPython",
      # display
      "opencv-python-headless>=4",
+     # models
+     "onnxruntime>=1.15",
  ]

  tests = [

@@ -110,6 +112,11 @@ display = [
  display-headless = [
      "opencv-python-headless>=4",
  ]
+ # most models don't needs these optional dependencies
+ # but some do, so we list them here
+ models = [
+     "onnxruntime>=1.15",
+ ]
  # these are just here for .readthedocs.yaml
  # please for proper install (with GPU support)
  # install manually
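
Since these are standard PEP 508 extras, the new optional ONNX Runtime dependency can be pulled in with the usual pip syntax, e.g. pip install "zea[models]" (or via the dev extra, which now bundles it).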

{zea-0.0.5 → zea-0.0.6}/zea/__init__.py
@@ -7,7 +7,7 @@ from . import log

  # dynamically add __version__ attribute (see pyproject.toml)
  # __version__ = __import__("importlib.metadata").metadata.version(__package__)
- __version__ = "0.0.5"
+ __version__ = "0.0.6"


  def _bootstrap_backend():

{zea-0.0.5 → zea-0.0.6}/zea/agent/selection.py
@@ -11,11 +11,14 @@ For a comprehensive example usage, see: :doc:`../notebooks/agent/agent_example`
  All strategies are stateless, meaning that they do not maintain any internal state.
  """

+ from typing import Callable
+
  import keras
  from keras import ops

  from zea import tensor_ops
  from zea.agent import masks
+ from zea.backend.autograd import AutoGrad
  from zea.internal.registry import action_selection_registry



@@ -493,3 +496,166 @@ class CovarianceSamplingLines(LinesActionModel):
          best_mask = ops.squeeze(best_mask, axis=0)

          return best_mask, self.lines_to_im_size(best_mask)
+
+
+ class TaskBasedLines(GreedyEntropy):
+     """Task-based line selection for maximizing information gain.
+
+     This action selection strategy chooses lines to maximize information gain with respect
+     to a downstream task outcome. It uses gradient-based saliency to identify which image
+     regions contribute most to task uncertainty, then selects lines accordingly.
+     """
+
+     def __init__(
+         self,
+         n_actions: int,
+         n_possible_actions: int,
+         img_width: int,
+         img_height: int,
+         downstream_task_function: Callable,
+         mean: float = 0,
+         std_dev: float = 1,
+         num_lines_to_update: int = 5,
+         **kwargs,
+     ):
+         """Initialize the TaskBasedLines action selection model.
+
+         Args:
+             n_actions (int): The number of actions the agent can take.
+             n_possible_actions (int): The number of possible actions (line positions).
+             img_width (int): The width of the input image.
+             img_height (int): The height of the input image.
+             downstream_task_function (Callable): A differentiable function that takes a
+                 batch of inputs and produces scalar outputs. This represents the downstream
+                 task for which information gain should be maximized.
+             mean (float, optional): The mean of the RBF used for reweighting. Defaults to 0.
+             std_dev (float, optional): The standard deviation of the RBF used for reweighting.
+                 Defaults to 1.
+             num_lines_to_update (int, optional): The number of lines around the selected line
+                 to update during reweighting. Must be odd. Defaults to 5.
+             **kwargs: Additional keyword arguments passed to the parent class.
+         """
+         super().__init__(
+             n_actions,
+             n_possible_actions,
+             img_width,
+             img_height,
+             mean,
+             std_dev,
+             num_lines_to_update,
+         )
+         self.downstream_task_function = downstream_task_function
+
+     def compute_output_and_saliency_propagation(self, particles):
+         """Compute saliency-weighted posterior variance for task-based selection.
+
+         This method computes how much each pixel contributes to the variance of the
+         downstream task output. It uses automatic differentiation to compute gradients
+         of the task function with respect to each particle, then weights the posterior
+         variance by the squared mean gradient.
+
+         Args:
+             particles (Tensor): Particles of shape (batch_size, n_particles, height, width)
+                 representing the posterior distribution over images.
+
+         Returns:
+             Tensor: Pixelwise contribution to downstream task variance,
+                 of shape (batch_size, height, width). Higher values indicate pixels
+                 that contribute more to task uncertainty.
+         """
+         autograd = AutoGrad()
+
+         autograd.set_function(self.downstream_task_function)
+         downstream_grad_and_value_fn = autograd.get_gradient_and_value_jit_fn()
+         jacobian, _ = ops.vectorized_map(
+             lambda p: ops.vectorized_map(
+                 downstream_grad_and_value_fn,
+                 p,
+             ),
+             particles,
+         )
+
+         posterior_variance = ops.var(particles, axis=1)
+         mean_jacobian = ops.mean(jacobian, axis=1)
+         return posterior_variance * (mean_jacobian**2)
+
+     def sum_neighbouring_columns_into_n_possible_actions(self, full_linewise_salience):
+         """Aggregate column-wise saliency into line-wise saliency scores.
+
+         This method groups neighboring columns together to create saliency scores
+         for each possible line action. Since each line action may correspond to
+         multiple image columns, this aggregation is necessary to match the action space.
+
+         Args:
+             full_linewise_salience (Tensor): Saliency values for each column,
+                 of shape (batch_size, full_image_width).
+
+         Returns:
+             Tensor: Aggregated saliency scores for each possible action,
+                 of shape (batch_size, n_possible_actions).
+
+         Raises:
+             AssertionError: If the image width is not evenly divisible by n_possible_actions.
+         """
+         batch_size = ops.shape(full_linewise_salience)[0]
+         full_image_width = ops.shape(full_linewise_salience)[1]
+         assert full_image_width % self.n_possible_actions == 0, (
+             "n_possible_actions must divide evenly into image width"
+         )
+         cols_per_action = full_image_width // self.n_possible_actions
+         stacked_linewise_salience = ops.reshape(
+             full_linewise_salience,
+             (batch_size, self.n_possible_actions, cols_per_action),
+         )
+         return ops.sum(stacked_linewise_salience, axis=2)
+
+     def sample(self, particles):
+         """Sample actions using task-based information gain maximization.
+
+         This method computes which lines would provide the most information about
+         the downstream task by:
+         1. Computing pixelwise contribution to task variance using gradients
+         2. Aggregating contributions into line-wise scores
+         3. Greedily selecting lines with highest contribution scores
+         4. Reweighting scores around selected lines (inherited from GreedyEntropy)
+
+         Args:
+             particles (Tensor): Particles representing the posterior distribution,
+                 of shape (batch_size, n_particles, height, width).
+
+         Returns:
+             Tuple[Tensor, Tensor, Tensor]:
+                 - selected_lines_k_hot: Selected lines as k-hot vectors,
+                   shaped (batch_size, n_possible_actions)
+                 - masks: Binary masks of shape (batch_size, img_height, img_width)
+                 - pixelwise_contribution_to_var_dst: Pixelwise contribution to downstream
+                   task variance, of shape (batch_size, height, width)
+
+         Note:
+             Unlike the parent GreedyEntropy class, this method returns an additional
+             tensor containing the pixelwise contribution scores for analysis.
+         """
+         pixelwise_contribution_to_var_dst = self.compute_output_and_saliency_propagation(particles)
+         linewise_contribution_to_var_dst = ops.sum(pixelwise_contribution_to_var_dst, axis=1)
+         actionwise_contribution_to_var_dst = self.sum_neighbouring_columns_into_n_possible_actions(
+             linewise_contribution_to_var_dst
+         )
+
+         # Greedily select best line, reweight entropies, and repeat
+         all_selected_lines = []
+         for _ in range(self.n_actions):
+             max_contribution_line, actionwise_contribution_to_var_dst = ops.vectorized_map(
+                 self.select_line_and_reweight_entropy,
+                 actionwise_contribution_to_var_dst,
+             )
+             all_selected_lines.append(max_contribution_line)
+
+         selected_lines_k_hot = ops.any(
+             ops.one_hot(all_selected_lines, self.n_possible_actions, dtype=masks._DEFAULT_DTYPE),
+             axis=0,
+         )
+         return (
+             selected_lines_k_hot,
+             self.lines_to_im_size(selected_lines_k_hot),
+             pixelwise_contribution_to_var_dst,
+         )
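
To make the new API concrete, here is a minimal sketch of driving TaskBasedLines with a toy differentiable task. The task function and all shapes below are illustrative assumptions, not part of the release; the constructor arguments follow the docstring above.

    import keras
    from keras import ops

    from zea.agent.selection import TaskBasedLines

    def mean_brightness(image):
        # Toy stand-in for a downstream task: any differentiable scalar output works.
        return ops.mean(image)

    agent = TaskBasedLines(
        n_actions=4,             # lines to acquire per step
        n_possible_actions=28,   # candidate line positions
        img_width=112,           # 112 / 28 = 4 columns per line action
        img_height=112,
        downstream_task_function=mean_brightness,
    )

    # Posterior samples over images: (batch, n_particles, height, width).
    particles = keras.random.normal((2, 16, 112, 112))
    lines_k_hot, acquisition_masks, saliency = agent.sample(particles)
    # lines_k_hot: (2, 28); acquisition_masks: (2, 112, 112); saliency: (2, 112, 112)

Note that the image width must divide evenly by n_possible_actions, per the assertion in sum_neighbouring_columns_into_n_possible_actions.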

{zea-0.0.5 → zea-0.0.6}/zea/backend/__init__.py
@@ -25,6 +25,8 @@ Key Features

  """

+ from contextlib import nullcontext
+
  import keras

  from zea import log

@@ -114,3 +116,90 @@ def _jit_compile(func, jax=True, tensorflow=True, **kwargs):
      log.warning("Initialize zea.Pipeline with jit_options=None to suppress this warning.")
      log.warning("Falling back to non-compiled mode.")
      return func
+
+
+ class on_device:
+     """Context manager to set the device regardless of backend.
+
+     For the `torch` backend, you need to manually move the model and data to the device before
+     using this context manager.
+
+     Args:
+         device (str): Device string, e.g. ``'cuda'``, ``'gpu'``, or ``'cpu'``.
+
+     Example:
+         .. code-block:: python
+
+             with zea.backend.on_device("gpu:3"):
+                 pipeline = zea.Pipeline([zea.keras_ops.Abs()])
+                 output = pipeline(data=keras.random.normal((10, 10))) # output is on "cuda:3"
+     """
+
+     def __init__(self, device: str):
+         self.device = self.get_device(device)
+         self._context = self.get_context(self.device)
+
+     def get_context(self, device):
+         if device is None:
+             return nullcontext()
+
+         if keras.backend.backend() == "tensorflow":
+             import tensorflow as tf
+
+             return tf.device(device)
+
+         if keras.backend.backend() == "jax":
+             import jax
+
+             return jax.default_device(device)
+         if keras.backend.backend() == "torch":
+             import torch
+
+             return torch.device(device)
+
+         return nullcontext()
+
+     def get_device(self, device: str):
+         if device is None:
+             return None
+
+         device = device.lower()
+
+         if keras.backend.backend() == "tensorflow":
+             return device.replace("cuda", "gpu")
+
+         if keras.backend.backend() == "jax":
+             from zea.backend.jax import str_to_jax_device
+
+             device = device.replace("cuda", "gpu")
+             return str_to_jax_device(device)
+
+         if keras.backend.backend() == "torch":
+             return device.replace("gpu", "cuda")
+
+     def __enter__(self):
+         self._context.__enter__()
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         self._context.__exit__(exc_type, exc_val, exc_tb)
+
+
+ if keras.backend.backend() in ["tensorflow", "jax", "numpy"]:
+
+     def func_on_device(func, device, *args, **kwargs):
+         """Moves all tensor arguments of a function to a specified device before calling it.
+
+         Args:
+             func (callable): Function to be called.
+             device (str): Device to move tensors to.
+             *args: Positional arguments to be passed to the function.
+             **kwargs: Keyword arguments to be passed to the function.
+         Returns:
+             The output of the function.
+         """
+         with on_device(device):
+             return func(*args, **kwargs)
+ elif keras.backend.backend() == "torch":
+     from zea.backend.torch import func_on_device
+ else:
+     raise ValueError(f"Unsupported backend: {keras.backend.backend()}")
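
A short usage sketch of the new helpers, assuming a toy function and an illustrative CPU device string; on the torch backend, func_on_device resolves to the re-export from zea.backend.torch shown further below.

    import keras
    from zea.backend import func_on_device, on_device

    def scale(x, factor=2.0):
        # Toy function; any callable over tensors works.
        return x * factor

    x = keras.random.normal((4, 4))

    # Run with tensor arguments placed on the requested device.
    y = func_on_device(scale, "cpu:0", x, factor=3.0)

    # Or use the context manager directly (torch users move data/models themselves).
    with on_device("cpu:0"):
        y = scale(x, factor=3.0)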

zea-0.0.6/zea/backend/jax/__init__.py
@@ -0,0 +1,33 @@
+ """Jax utilities for zea."""
+
+ import jax
+
+
+ def str_to_jax_device(device):
+     """Convert a device string to a JAX device.
+     Args:
+         device (str): Device string, e.g. ``'gpu:0'``, or ``'cpu:0'``.
+     Returns:
+         jax.Device: The corresponding JAX device.
+     """
+
+     if not isinstance(device, str):
+         raise ValueError(f"Device must be a string, got {type(device)}")
+
+     device = device.lower().replace("cuda", "gpu")
+
+     device = device.split(":")
+     if len(device) == 2:
+         device_type, device_number = device
+         device_number = int(device_number)
+     else:
+         # if no device number is specified, use the first device
+         device_type = device[0]
+         device_number = 0
+
+     available = jax.devices(device_type)
+     if len(available) == 0:
+         raise ValueError(f"No JAX devices available for type '{device_type}'.")
+     if device_number < 0 or device_number >= len(available):
+         raise ValueError(f"Device '{device}' is not available; JAX devices found: {available}")
+     return available[device_number]
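
Taken at face value, the helper behaves as below (JAX backend only; the device strings are illustrative):

    from zea.backend.jax import str_to_jax_device

    dev = str_to_jax_device("cpu")     # no index given -> first CPU device
    dev = str_to_jax_device("cpu:0")   # same device, explicit index
    # "cuda:1" is normalized to "gpu:1"; a ValueError is raised if no such device exists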

zea-0.0.6/zea/backend/tensorflow/__init__.py
@@ -0,0 +1,17 @@
+ """Tensorflow Ultrasound Beamforming Library.
+
+ Initialize modules for registries.
+ """
+
+ import sys
+ from pathlib import PosixPath
+
+ import numpy as np
+
+ # Convert PosixPath objects to strings in sys.path
+ # this is necessary due to weird TF bug when importing
+ sys.path = [str(p) if isinstance(p, PosixPath) else p for p in sys.path]
+
+ import tensorflow as tf # noqa: E402
+
+ from .dataloader import make_dataloader # noqa: E402

zea-0.0.6/zea/backend/torch/__init__.py
@@ -0,0 +1,39 @@
+ """Pytorch Ultrasound Beamforming Library.
+
+ Initialize modules for registries.
+ """
+
+ import torch
+
+
+ def func_on_device(func, device, *args, **kwargs):
+     """Moves all tensor arguments of a function to a specified device before calling it.
+
+     Args:
+         func (callable): Function to be called.
+         device (str or torch.device): Device to move tensors to.
+         *args: Positional arguments to be passed to the function.
+         **kwargs: Keyword arguments to be passed to the function.
+     Returns:
+         The output of the function.
+     """
+     if device is None:
+         return func(*args, **kwargs)
+
+     if isinstance(device, str):
+         device = torch.device(device)
+
+     def move_to_device(x):
+         if isinstance(x, torch.Tensor):
+             return x.to(device)
+         elif isinstance(x, (list, tuple)):
+             return type(x)(move_to_device(i) for i in x)
+         elif isinstance(x, dict):
+             return {k: move_to_device(v) for k, v in x.items()}
+         else:
+             return x
+
+     args = move_to_device(args)
+     kwargs = move_to_device(kwargs)
+
+     return func(*args, **kwargs)
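
Because move_to_device recurses through lists, tuples, and dicts, nested payloads work too; a brief sketch with an illustrative batch:

    import torch
    from zea.backend.torch import func_on_device

    def total(batch):
        return batch["a"].sum() + batch["b"][0].sum()

    batch = {"a": torch.ones(3), "b": [torch.arange(3.0)]}
    out = func_on_device(total, "cpu", batch)  # every tensor in `batch` is moved first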

{zea-0.0.5 → zea-0.0.6}/zea/data/layers.py
@@ -25,7 +25,7 @@ class Resizer(TFDataLayer):
      Resize layer for resizing images. Can deal with N-dimensional images.
      Can do resize, center_crop, random_crop and crop_or_pad.

-     Can be used in tf.data pipelines.
+     Can be used in `tf.data` pipelines.
      """

      def __init__(

@@ -36,7 +36,6 @@ class Resizer(TFDataLayer):
          seed: int | None = None,
          **resize_kwargs,
      ):
-         # noqa: E501
          """
          Initializes the data loader with the specified parameters.


@@ -47,7 +46,7 @@ class Resizer(TFDataLayer):
              ['random_crop'](https://keras.io/api/layers/preprocessing_layers/image_augmentation/random_crop/),
              ['resize'](https://keras.io/api/layers/preprocessing_layers/image_preprocessing/resizing/),
              'crop_or_pad': resizes an image to a target width and height by either centrally
-             cropping the image, padding it evenly with zeros or a combination of both.
+                 cropping the image, padding it evenly with zeros or a combination of both.
              resize_axes (tuple | None, optional): The axes along which to resize.
                  Must be of length 2. Defaults to None. In that case, can only process
                  default tensors of shape (batch, height, width, channels), where the

{zea-0.0.5 → zea-0.0.6}/zea/internal/registry.py
@@ -200,7 +200,7 @@ tf_beamformer_registry = RegisterDecorator(items_to_register=["name", "framework"])

  torch_beamformer_registry = RegisterDecorator(items_to_register=["name", "framework"])

- metrics_registry = RegisterDecorator(items_to_register=["name", "framework", "supervised"])
+ metrics_registry = RegisterDecorator(items_to_register=["name", "paired"])

  checks_registry = RegisterDecorator(items_to_register=["data_type"])
  ops_registry = RegisterDecorator(items_to_register=["name"])
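
The metrics registry thus now tags each metric by name and by whether it is paired (comparing two inputs) rather than by framework and supervision. A hedged sketch of what registration might look like under the new schema; the keyword-argument call pattern is inferred from items_to_register and not verified against the 0.0.6 sources:

    from keras import ops
    from zea.internal.registry import metrics_registry

    @metrics_registry(name="mse", paired=True)  # inferred call pattern, see note above
    def mse(y_true, y_pred):
        # Paired metric: needs both a reference and a prediction.
        return ops.mean((y_true - y_pred) ** 2)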