zea 0.0.8__tar.gz → 0.0.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136)
  1. {zea-0.0.8 → zea-0.0.10}/PKG-INFO +5 -5
  2. {zea-0.0.8 → zea-0.0.10}/pyproject.toml +5 -5
  3. {zea-0.0.8 → zea-0.0.10}/zea/__init__.py +13 -7
  4. {zea-0.0.8 → zea-0.0.10}/zea/agent/masks.py +17 -5
  5. {zea-0.0.8 → zea-0.0.10}/zea/agent/selection.py +15 -6
  6. {zea-0.0.8 → zea-0.0.10}/zea/backend/__init__.py +18 -4
  7. {zea-0.0.8 → zea-0.0.10}/zea/backend/tensorflow/dataloader.py +1 -1
  8. {zea-0.0.8 → zea-0.0.10}/zea/beamform/beamformer.py +162 -91
  9. {zea-0.0.8 → zea-0.0.10}/zea/beamform/delays.py +12 -9
  10. {zea-0.0.8 → zea-0.0.10}/zea/beamform/lens_correction.py +0 -73
  11. {zea-0.0.8 → zea-0.0.10}/zea/beamform/pfield.py +6 -12
  12. zea-0.0.10/zea/beamform/phantoms.py +145 -0
  13. zea-0.0.10/zea/beamform/pixelgrid.py +189 -0
  14. {zea-0.0.8 → zea-0.0.10}/zea/config.py +2 -2
  15. {zea-0.0.8 → zea-0.0.10}/zea/data/augmentations.py +1 -1
  16. zea-0.0.10/zea/data/convert/__main__.py +174 -0
  17. {zea-0.0.8 → zea-0.0.10}/zea/data/convert/camus.py +8 -2
  18. {zea-0.0.8 → zea-0.0.10}/zea/data/convert/echonet.py +1 -1
  19. {zea-0.0.8 → zea-0.0.10}/zea/data/convert/echonetlvh/__init__.py +48 -51
  20. {zea-0.0.8 → zea-0.0.10}/zea/data/convert/echonetlvh/precompute_crop.py +12 -0
  21. {zea-0.0.8 → zea-0.0.10}/zea/data/convert/images.py +1 -1
  22. zea-0.0.10/zea/data/convert/verasonics.py +1503 -0
  23. {zea-0.0.8 → zea-0.0.10}/zea/data/data_format.py +53 -10
  24. {zea-0.0.8 → zea-0.0.10}/zea/data/datasets.py +0 -7
  25. {zea-0.0.8 → zea-0.0.10}/zea/data/file.py +35 -1
  26. {zea-0.0.8 → zea-0.0.10}/zea/data/file_operations.py +2 -0
  27. {zea-0.0.8 → zea-0.0.10}/zea/data/preset_utils.py +1 -1
  28. {zea-0.0.8 → zea-0.0.10}/zea/display.py +27 -10
  29. {zea-0.0.8 → zea-0.0.10}/zea/doppler.py +6 -6
  30. zea-0.0.10/zea/func/__init__.py +115 -0
  31. zea-0.0.8/zea/tensor_ops.py → zea-0.0.10/zea/func/tensor.py +36 -10
  32. zea-0.0.10/zea/func/ultrasound.py +596 -0
  33. {zea-0.0.8 → zea-0.0.10}/zea/internal/_generate_keras_ops.py +5 -5
  34. {zea-0.0.8 → zea-0.0.10}/zea/internal/config/parameters.py +1 -1
  35. {zea-0.0.8 → zea-0.0.10}/zea/internal/device.py +6 -1
  36. {zea-0.0.8 → zea-0.0.10}/zea/internal/dummy_scan.py +15 -8
  37. zea-0.0.10/zea/internal/notebooks.py +152 -0
  38. {zea-0.0.8 → zea-0.0.10}/zea/internal/parameters.py +20 -0
  39. {zea-0.0.8 → zea-0.0.10}/zea/internal/registry.py +1 -1
  40. {zea-0.0.8 → zea-0.0.10}/zea/metrics.py +88 -71
  41. {zea-0.0.8 → zea-0.0.10}/zea/models/__init__.py +1 -0
  42. {zea-0.0.8 → zea-0.0.10}/zea/models/diffusion.py +6 -1
  43. {zea-0.0.8 → zea-0.0.10}/zea/models/echonetlvh.py +1 -1
  44. {zea-0.0.8 → zea-0.0.10}/zea/models/gmm.py +1 -1
  45. zea-0.0.10/zea/models/hvae/__init__.py +243 -0
  46. zea-0.0.10/zea/models/hvae/model.py +1139 -0
  47. zea-0.0.10/zea/models/hvae/utils.py +616 -0
  48. {zea-0.0.8 → zea-0.0.10}/zea/models/layers.py +1 -1
  49. {zea-0.0.8 → zea-0.0.10}/zea/models/lpips.py +12 -2
  50. {zea-0.0.8 → zea-0.0.10}/zea/models/presets.py +16 -0
  51. zea-0.0.10/zea/ops/__init__.py +190 -0
  52. zea-0.0.10/zea/ops/base.py +441 -0
  53. {zea-0.0.8/zea → zea-0.0.10/zea/ops}/keras_ops.py +2 -2
  54. zea-0.0.10/zea/ops/pipeline.py +1482 -0
  55. zea-0.0.10/zea/ops/tensor.py +333 -0
  56. zea-0.0.10/zea/ops/ultrasound.py +1037 -0
  57. {zea-0.0.8 → zea-0.0.10}/zea/probes.py +4 -10
  58. {zea-0.0.8 → zea-0.0.10}/zea/scan.py +231 -94
  59. {zea-0.0.8 → zea-0.0.10}/zea/simulator.py +8 -8
  60. {zea-0.0.8 → zea-0.0.10}/zea/tools/fit_scan_cone.py +9 -7
  61. {zea-0.0.8 → zea-0.0.10}/zea/tools/selection_tool.py +14 -8
  62. {zea-0.0.8 → zea-0.0.10}/zea/tracking/lucas_kanade.py +1 -1
  63. {zea-0.0.8 → zea-0.0.10}/zea/tracking/segmentation.py +1 -1
  64. {zea-0.0.8 → zea-0.0.10}/zea/visualize.py +3 -1
  65. zea-0.0.8/zea/beamform/phantoms.py +0 -43
  66. zea-0.0.8/zea/beamform/pixelgrid.py +0 -131
  67. zea-0.0.8/zea/data/convert/__main__.py +0 -123
  68. zea-0.0.8/zea/data/convert/verasonics.py +0 -1209
  69. zea-0.0.8/zea/internal/notebooks.py +0 -39
  70. zea-0.0.8/zea/ops.py +0 -3534
  71. {zea-0.0.8 → zea-0.0.10}/LICENSE +0 -0
  72. {zea-0.0.8 → zea-0.0.10}/README.md +0 -0
  73. {zea-0.0.8 → zea-0.0.10}/zea/__main__.py +0 -0
  74. {zea-0.0.8 → zea-0.0.10}/zea/agent/__init__.py +0 -0
  75. {zea-0.0.8 → zea-0.0.10}/zea/agent/gumbel.py +0 -0
  76. {zea-0.0.8 → zea-0.0.10}/zea/backend/autograd.py +0 -0
  77. {zea-0.0.8 → zea-0.0.10}/zea/backend/jax/__init__.py +0 -0
  78. {zea-0.0.8 → zea-0.0.10}/zea/backend/tensorflow/__init__.py +0 -0
  79. {zea-0.0.8 → zea-0.0.10}/zea/backend/tensorflow/layers/__init__.py +0 -0
  80. {zea-0.0.8 → zea-0.0.10}/zea/backend/tensorflow/layers/apodization.py +0 -0
  81. {zea-0.0.8 → zea-0.0.10}/zea/backend/tensorflow/layers/utils.py +0 -0
  82. {zea-0.0.8 → zea-0.0.10}/zea/backend/tensorflow/losses.py +0 -0
  83. {zea-0.0.8 → zea-0.0.10}/zea/backend/tensorflow/models/__init__.py +0 -0
  84. {zea-0.0.8 → zea-0.0.10}/zea/backend/tensorflow/models/lista.py +0 -0
  85. {zea-0.0.8 → zea-0.0.10}/zea/backend/tensorflow/scripts/convert-echonet-dynamic.py +0 -0
  86. {zea-0.0.8 → zea-0.0.10}/zea/backend/tensorflow/scripts/convert-taesd.py +0 -0
  87. {zea-0.0.8 → zea-0.0.10}/zea/backend/tensorflow/utils/__init__.py +0 -0
  88. {zea-0.0.8 → zea-0.0.10}/zea/backend/tensorflow/utils/callbacks.py +0 -0
  89. {zea-0.0.8 → zea-0.0.10}/zea/backend/tensorflow/utils/utils.py +0 -0
  90. {zea-0.0.8 → zea-0.0.10}/zea/backend/tf2jax.py +0 -0
  91. {zea-0.0.8 → zea-0.0.10}/zea/backend/torch/__init__.py +0 -0
  92. {zea-0.0.8 → zea-0.0.10}/zea/backend/torch/losses.py +0 -0
  93. {zea-0.0.8 → zea-0.0.10}/zea/beamform/__init__.py +0 -0
  94. {zea-0.0.8 → zea-0.0.10}/zea/data/__init__.py +0 -0
  95. {zea-0.0.8 → zea-0.0.10}/zea/data/__main__.py +0 -0
  96. {zea-0.0.8 → zea-0.0.10}/zea/data/convert/__init__.py +0 -0
  97. {zea-0.0.8 → zea-0.0.10}/zea/data/convert/echonetlvh/README.md +0 -0
  98. {zea-0.0.8 → zea-0.0.10}/zea/data/convert/echonetlvh/manual_rejections.txt +0 -0
  99. {zea-0.0.8 → zea-0.0.10}/zea/data/convert/picmus.py +0 -0
  100. {zea-0.0.8 → zea-0.0.10}/zea/data/convert/utils.py +0 -0
  101. {zea-0.0.8 → zea-0.0.10}/zea/data/dataloader.py +0 -0
  102. {zea-0.0.8 → zea-0.0.10}/zea/data/layers.py +0 -0
  103. {zea-0.0.8 → zea-0.0.10}/zea/data/utils.py +0 -0
  104. {zea-0.0.8 → zea-0.0.10}/zea/datapaths.py +0 -0
  105. {zea-0.0.8 → zea-0.0.10}/zea/interface.py +0 -0
  106. {zea-0.0.8 → zea-0.0.10}/zea/internal/cache.py +0 -0
  107. {zea-0.0.8 → zea-0.0.10}/zea/internal/checks.py +0 -0
  108. {zea-0.0.8 → zea-0.0.10}/zea/internal/config/create.py +0 -0
  109. {zea-0.0.8 → zea-0.0.10}/zea/internal/config/validation.py +0 -0
  110. {zea-0.0.8 → zea-0.0.10}/zea/internal/core.py +0 -0
  111. {zea-0.0.8 → zea-0.0.10}/zea/internal/git_info.py +0 -0
  112. {zea-0.0.8 → zea-0.0.10}/zea/internal/operators.py +0 -0
  113. {zea-0.0.8 → zea-0.0.10}/zea/internal/setup_zea.py +0 -0
  114. {zea-0.0.8 → zea-0.0.10}/zea/internal/utils.py +0 -0
  115. {zea-0.0.8 → zea-0.0.10}/zea/internal/viewer.py +0 -0
  116. {zea-0.0.8 → zea-0.0.10}/zea/io_lib.py +0 -0
  117. {zea-0.0.8 → zea-0.0.10}/zea/log.py +0 -0
  118. {zea-0.0.8 → zea-0.0.10}/zea/models/base.py +0 -0
  119. {zea-0.0.8 → zea-0.0.10}/zea/models/carotid_segmenter.py +0 -0
  120. {zea-0.0.8 → zea-0.0.10}/zea/models/deeplabv3.py +0 -0
  121. {zea-0.0.8 → zea-0.0.10}/zea/models/dense.py +0 -0
  122. {zea-0.0.8 → zea-0.0.10}/zea/models/echonet.py +0 -0
  123. {zea-0.0.8 → zea-0.0.10}/zea/models/generative.py +0 -0
  124. {zea-0.0.8 → zea-0.0.10}/zea/models/lv_segmentation.py +0 -0
  125. {zea-0.0.8 → zea-0.0.10}/zea/models/preset_utils.py +0 -0
  126. {zea-0.0.8 → zea-0.0.10}/zea/models/regional_quality.py +0 -0
  127. {zea-0.0.8 → zea-0.0.10}/zea/models/taesd.py +0 -0
  128. {zea-0.0.8 → zea-0.0.10}/zea/models/unet.py +0 -0
  129. {zea-0.0.8 → zea-0.0.10}/zea/models/utils.py +0 -0
  130. {zea-0.0.8 → zea-0.0.10}/zea/tools/__init__.py +0 -0
  131. {zea-0.0.8 → zea-0.0.10}/zea/tools/hf.py +0 -0
  132. {zea-0.0.8 → zea-0.0.10}/zea/tools/wndb.py +0 -0
  133. {zea-0.0.8 → zea-0.0.10}/zea/tracking/__init__.py +0 -0
  134. {zea-0.0.8 → zea-0.0.10}/zea/tracking/base.py +0 -0
  135. {zea-0.0.8 → zea-0.0.10}/zea/utils.py +0 -0
  136. {zea-0.0.8 → zea-0.0.10}/zea/zea_darkmode.mplstyle +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: zea
3
- Version: 0.0.8
3
+ Version: 0.0.10
4
4
  Summary: A Toolbox for Cognitive Ultrasound Imaging. Provides a set of tools for processing of ultrasound data, all built in your favorite machine learning framework.
5
5
  License-File: LICENSE
6
6
  Keywords: ultrasound,machine learning,beamforming
@@ -44,8 +44,6 @@ Requires-Dist: jax ; extra == "backends"
44
44
  Requires-Dist: jax[cuda12-pip] (>=0.4.26) ; extra == "jax"
45
45
  Requires-Dist: keras (>=3.12)
46
46
  Requires-Dist: matplotlib (>=3.8)
47
- Requires-Dist: mock ; extra == "dev"
48
- Requires-Dist: mock ; extra == "docs"
49
47
  Requires-Dist: myst-parser ; extra == "dev"
50
48
  Requires-Dist: myst-parser ; extra == "docs"
51
49
  Requires-Dist: nbsphinx ; extra == "dev"
@@ -72,10 +70,10 @@ Requires-Dist: schema (>=0.7)
72
70
  Requires-Dist: scikit-image (>=0.23)
73
71
  Requires-Dist: scikit-learn (>=1.4)
74
72
  Requires-Dist: scipy (>=1.13)
73
+ Requires-Dist: simpleitk (>=2.2.1) ; extra == "dev"
74
+ Requires-Dist: simpleitk (>=2.2.1) ; extra == "tests"
75
75
  Requires-Dist: sphinx ; extra == "dev"
76
76
  Requires-Dist: sphinx ; extra == "docs"
77
- Requires-Dist: sphinx-argparse ; extra == "dev"
78
- Requires-Dist: sphinx-argparse ; extra == "docs"
79
77
  Requires-Dist: sphinx-autobuild ; extra == "dev"
80
78
  Requires-Dist: sphinx-autobuild ; extra == "docs"
81
79
  Requires-Dist: sphinx-autodoc-typehints ; extra == "dev"
@@ -86,6 +84,8 @@ Requires-Dist: sphinx-reredirects ; extra == "dev"
86
84
  Requires-Dist: sphinx-reredirects ; extra == "docs"
87
85
  Requires-Dist: sphinx_design ; extra == "dev"
88
86
  Requires-Dist: sphinx_design ; extra == "docs"
87
+ Requires-Dist: sphinxcontrib-autoprogram ; extra == "dev"
88
+ Requires-Dist: sphinxcontrib-autoprogram ; extra == "docs"
89
89
  Requires-Dist: sphinxcontrib-bibtex ; extra == "dev"
90
90
  Requires-Dist: sphinxcontrib-bibtex ; extra == "docs"
91
91
  Requires-Dist: tensorflow ; extra == "backends"
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "zea"
3
- version = "0.0.8"
3
+ version = "0.0.10"
4
4
  description = "A Toolbox for Cognitive Ultrasound Imaging. Provides a set of tools for processing of ultrasound data, all built in your favorite machine learning framework."
5
5
  authors = [
6
6
  { name = "Tristan Stevens", email = "t.s.w.stevens@tue.nl" },
@@ -54,6 +54,7 @@ dev = [
54
54
  "papermill>=2.4",
55
55
  "ipykernel>=6.29.5",
56
56
  "cloudpickle>=3.1.1",
57
+ "simpleitk>=2.2.1",
57
58
  "ipywidgets",
58
59
  "pre-commit",
59
60
  "ruff",
@@ -64,10 +65,9 @@ dev = [
64
65
  "sphinx-autodoc-typehints",
65
66
  "sphinx-copybutton",
66
67
  "sphinx_design",
67
- "sphinx-argparse",
68
+ "sphinxcontrib-autoprogram",
68
69
  "sphinx-reredirects",
69
70
  "sphinxcontrib-bibtex",
70
- "mock",
71
71
  "myst-parser",
72
72
  "nbsphinx",
73
73
  "furo",
@@ -84,6 +84,7 @@ tests = [
84
84
  "papermill>=2.4",
85
85
  "ipykernel>=6.29.5",
86
86
  "cloudpickle>=3.1.1",
87
+ "simpleitk>=2.2.1",
87
88
  "ipywidgets",
88
89
  "pre-commit",
89
90
  "ruff",
@@ -95,10 +96,9 @@ docs = [
95
96
  "sphinx-autodoc-typehints",
96
97
  "sphinx-copybutton",
97
98
  "sphinx_design",
98
- "sphinx-argparse",
99
+ "sphinxcontrib-autoprogram",
99
100
  "sphinx-reredirects",
100
101
  "sphinxcontrib-bibtex",
101
- "mock",
102
102
  "myst-parser",
103
103
  "nbsphinx",
104
104
  "furo",
@@ -2,12 +2,16 @@
2
2
 
3
3
  import importlib.util
4
4
  import os
5
+ from importlib.metadata import PackageNotFoundError, version
5
6
 
6
7
  from . import log
7
8
 
8
- # dynamically add __version__ attribute (see pyproject.toml)
9
- # __version__ = __import__("importlib.metadata").metadata.version(__package__)
10
- __version__ = "0.0.8"
9
+ try:
10
+ # dynamically add __version__ attribute (see pyproject.toml)
11
+ __version__ = version("zea")
12
+ except PackageNotFoundError:
13
+ # Package is not installed (e.g., running from source)
14
+ __version__ = "dev"
11
15
 
12
16
 
13
17
  def _bootstrap_backend():
@@ -80,8 +84,10 @@ def _bootstrap_backend():
80
84
  log.info(f"Using backend {keras_backend()!r}")
81
85
 
82
86
 
83
- # call and clean up namespace
84
- _bootstrap_backend()
87
+ # Skip backend bootstrap when building on ReadTheDocs
88
+ if os.environ.get("READTHEDOCS") != "True":
89
+ _bootstrap_backend()
90
+
85
91
  del _bootstrap_backend
86
92
 
87
93
  from . import (
@@ -89,12 +95,12 @@ from . import (
89
95
  beamform,
90
96
  data,
91
97
  display,
98
+ func,
92
99
  io_lib,
93
- keras_ops,
94
100
  metrics,
95
101
  models,
102
+ ops,
96
103
  simulator,
97
- tensor_ops,
98
104
  utils,
99
105
  visualize,
100
106
  )
@@ -4,13 +4,15 @@ Mask generation utilities.
4
4
  These masks are used as a measurement operator for focused scan-line subsampling.
5
5
  """
6
6
 
7
+ from __future__ import annotations
8
+
7
9
  from typing import List
8
10
 
9
11
  import keras
10
12
  from keras import ops
11
13
 
12
- from zea import tensor_ops
13
14
  from zea.agent.gumbel import hard_straight_through
15
+ from zea.func.tensor import nonzero
14
16
 
15
17
  _DEFAULT_DTYPE = "bool"
16
18
 
@@ -56,7 +58,7 @@ def k_hot_to_indices(selected_lines, n_actions: int, fill_value=-1):
56
58
 
57
59
  # Find nonzero indices for each frame
58
60
  def get_nonzero(row):
59
- return tensor_ops.nonzero(row > 0, size=n_actions, fill_value=fill_value)[0]
61
+ return nonzero(row > 0, size=n_actions, fill_value=fill_value)[0]
60
62
 
61
63
  indices = ops.vectorized_map(get_nonzero, selected_lines)
62
64
  return indices
@@ -117,11 +119,21 @@ def initial_equispaced_lines(
117
119
  Tensor: k-hot-encoded line vector of shape (n_possible_actions).
118
120
  Needs to be converted to image size.
119
121
  """
122
+ assert n_actions > 0, "Number of actions must be > 0."
123
+ assert n_possible_actions > 0, "Number of possible actions must be > 0."
124
+ assert n_actions <= n_possible_actions, (
125
+ "Number of actions must be less than or equal to number of possible actions."
126
+ )
127
+
120
128
  if assert_equal_spacing:
121
129
  _assert_equal_spacing(n_actions, n_possible_actions)
122
- selected_indices = ops.arange(0, n_possible_actions, n_possible_actions // n_actions)
123
- else:
124
- selected_indices = ops.linspace(0, n_possible_actions - 1, n_actions, dtype="int32")
130
+
131
+ # Distribute indices as evenly as possible
132
+ # This approach ensures spacing differs by at most 1 when not divisible
133
+ step = n_possible_actions / n_actions
134
+ selected_indices = ops.cast(
135
+ ops.round(ops.arange(0, n_actions, dtype="float32") * step), "int32"
136
+ )
125
137
 
126
138
  return indices_to_k_hot(selected_indices, n_possible_actions, dtype=dtype)
127
139
 
@@ -16,9 +16,9 @@ from typing import Callable
16
16
  import keras
17
17
  from keras import ops
18
18
 
19
- from zea import tensor_ops
20
19
  from zea.agent import masks
21
20
  from zea.backend.autograd import AutoGrad
21
+ from zea.func import tensor
22
22
  from zea.internal.registry import action_selection_registry
23
23
 
24
24
 
@@ -96,6 +96,7 @@ class GreedyEntropy(LinesActionModel):
96
96
  std_dev: float = 1,
97
97
  num_lines_to_update: int = 5,
98
98
  entropy_sigma: float = 1.0,
99
+ average_entropy_across_batch: bool = False,
99
100
  ):
100
101
  """Initialize the GreedyEntropy action selection model.
101
102
 
@@ -110,6 +111,10 @@ class GreedyEntropy(LinesActionModel):
110
111
  to update. Must be odd.
111
112
  entropy_sigma (float, optional): The standard deviation of the Gaussian
112
113
  Mixture components used to approximate the posterior.
114
+ average_entropy_across_batch (bool, optional): Whether to average entropy
115
+ across the batch when selecting lines. This can be useful when
116
+ selecting planes in 3D imaging, where the batch dimension represents
117
+ a third spatial dimension. Defaults to False.
113
118
  """
114
119
  super().__init__(n_actions, n_possible_actions, img_width, img_height)
115
120
 
@@ -117,6 +122,7 @@ class GreedyEntropy(LinesActionModel):
117
122
  # of the selected line is set to 0 once it's been selected.
118
123
  assert num_lines_to_update % 2 == 1, "num_samples must be odd."
119
124
  self.num_lines_to_update = num_lines_to_update
125
+ self.average_entropy_across_batch = average_entropy_across_batch
120
126
 
121
127
  # see here what I mean by upside_down_gaussian:
122
128
  # https://colab.research.google.com/drive/1CQp_Z6nADzOFsybdiH5Cag0vtVZjjioU?usp=sharing
@@ -153,7 +159,7 @@ class GreedyEntropy(LinesActionModel):
153
159
  assert particles.shape[1] > 1, "The entropy cannot be approximated using a single particle."
154
160
 
155
161
  if n_possible_actions is None:
156
- n_possible_actions = particles.shape[-1]
162
+ n_possible_actions = ops.shape(particles)[-1]
157
163
 
158
164
  # TODO: I think we only need to compute the lower triangular
159
165
  # of this matrix, since it's symmetric
@@ -164,7 +170,8 @@ class GreedyEntropy(LinesActionModel):
164
170
  # Vertically stack all columns corresponding with the same line
165
171
  # This way we can just sum across the height axis and get the entropy
166
172
  # for each pixel in a given line
167
- batch_size, n_particles, _, height, _ = gaussian_error_per_pixel_i_j.shape
173
+ batch_size, n_particles, _, height, _ = ops.shape(gaussian_error_per_pixel_i_j)
174
+
168
175
  gaussian_error_per_pixel_stacked = ops.transpose(
169
176
  ops.reshape(
170
177
  ops.transpose(gaussian_error_per_pixel_i_j, (0, 1, 2, 4, 3)),
@@ -274,6 +281,8 @@ class GreedyEntropy(LinesActionModel):
274
281
 
275
282
  pixelwise_entropy = self.compute_pixelwise_entropy(particles)
276
283
  linewise_entropy = ops.sum(pixelwise_entropy, axis=1)
284
+ if self.average_entropy_across_batch:
285
+ linewise_entropy = ops.expand_dims(ops.mean(linewise_entropy, axis=0), axis=0)
277
286
 
278
287
  # Greedily select best line, reweight entropies, and repeat
279
288
  all_selected_lines = []
@@ -334,7 +343,7 @@ class EquispacedLines(LinesActionModel):
334
343
  n_possible_actions: int,
335
344
  img_width: int,
336
345
  img_height: int,
337
- assert_equal_spacing=True,
346
+ assert_equal_spacing: bool = True,
338
347
  ):
339
348
  super().__init__(n_actions, n_possible_actions, img_width, img_height)
340
349
 
@@ -462,7 +471,7 @@ class CovarianceSamplingLines(LinesActionModel):
462
471
  particles = ops.reshape(particles, shape)
463
472
 
464
473
  # [batch_size, rows * stack_n_cols, n_possible_actions, n_possible_actions]
465
- cov_matrix = tensor_ops.batch_cov(particles)
474
+ cov_matrix = tensor.batch_cov(particles)
466
475
 
467
476
  # Sum over the row dimension [batch_size, n_possible_actions, n_possible_actions]
468
477
  cov_matrix = ops.sum(cov_matrix, axis=1)
@@ -477,7 +486,7 @@ class CovarianceSamplingLines(LinesActionModel):
477
486
  # Subsample the covariance matrix with random lines
478
487
  def subsample_with_mask(mask):
479
488
  """Subsample the covariance matrix with a single mask."""
480
- subsampled_cov_matrix = tensor_ops.boolean_mask(
489
+ subsampled_cov_matrix = tensor.boolean_mask(
481
490
  cov_matrix, mask, size=batch_size * self.n_actions**2
482
491
  )
483
492
  return ops.reshape(subsampled_cov_matrix, [batch_size, self.n_actions, self.n_actions])
@@ -59,8 +59,21 @@ def _import_torch():
59
59
  return None
60
60
 
61
61
 
62
+ def _get_backend():
63
+ try:
64
+ backend_result = keras.backend.backend()
65
+ if isinstance(backend_result, str):
66
+ return backend_result
67
+ else:
68
+ # to handle mocked backends during testing
69
+ return None
70
+ except Exception:
71
+ return None
72
+
73
+
62
74
  tf_mod = _import_tf()
63
75
  jax_mod = _import_jax()
76
+ backend = _get_backend()
64
77
 
65
78
 
66
79
  def tf_function(func=None, jit_compile=False, **kwargs):
@@ -131,7 +144,7 @@ class on_device:
131
144
  .. code-block:: python
132
145
 
133
146
  with zea.backend.on_device("gpu:3"):
134
- pipeline = zea.Pipeline([zea.keras_ops.Abs()])
147
+ pipeline = zea.Pipeline([zea.ops.Abs()])
135
148
  output = pipeline(data=keras.random.normal((10, 10))) # output is on "cuda:3"
136
149
  """
137
150
 
@@ -184,7 +197,7 @@ class on_device:
184
197
  self._context.__exit__(exc_type, exc_val, exc_tb)
185
198
 
186
199
 
187
- if keras.backend.backend() in ["tensorflow", "jax", "numpy"]:
200
+ if backend in [None, "tensorflow", "jax", "numpy"]:
188
201
 
189
202
  def func_on_device(func, device, *args, **kwargs):
190
203
  """Moves all tensor arguments of a function to a specified device before calling it.
@@ -199,7 +212,8 @@ if keras.backend.backend() in ["tensorflow", "jax", "numpy"]:
199
212
  """
200
213
  with on_device(device):
201
214
  return func(*args, **kwargs)
202
- elif keras.backend.backend() == "torch":
215
+
216
+ elif backend == "torch":
203
217
  from zea.backend.torch import func_on_device
204
218
  else:
205
- raise ValueError(f"Unsupported backend: {keras.backend.backend()}")
219
+ raise ValueError(f"Unsupported backend: {backend}")
@@ -12,8 +12,8 @@ from keras.src.trainers.data_adapters import TFDatasetAdapter
12
12
 
13
13
  from zea.data.dataloader import H5Generator
14
14
  from zea.data.layers import Resizer
15
+ from zea.func.tensor import translate
15
16
  from zea.internal.utils import find_methods_with_return_type
16
- from zea.tensor_ops import translate
17
17
 
18
18
  METHODS_THAT_RETURN_DATASET = find_methods_with_return_type(tf.data.Dataset, "DatasetV2")
19
19