zea 0.0.8__tar.gz → 0.0.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {zea-0.0.8 → zea-0.0.10}/PKG-INFO +5 -5
- {zea-0.0.8 → zea-0.0.10}/pyproject.toml +5 -5
- {zea-0.0.8 → zea-0.0.10}/zea/__init__.py +13 -7
- {zea-0.0.8 → zea-0.0.10}/zea/agent/masks.py +17 -5
- {zea-0.0.8 → zea-0.0.10}/zea/agent/selection.py +15 -6
- {zea-0.0.8 → zea-0.0.10}/zea/backend/__init__.py +18 -4
- {zea-0.0.8 → zea-0.0.10}/zea/backend/tensorflow/dataloader.py +1 -1
- {zea-0.0.8 → zea-0.0.10}/zea/beamform/beamformer.py +162 -91
- {zea-0.0.8 → zea-0.0.10}/zea/beamform/delays.py +12 -9
- {zea-0.0.8 → zea-0.0.10}/zea/beamform/lens_correction.py +0 -73
- {zea-0.0.8 → zea-0.0.10}/zea/beamform/pfield.py +6 -12
- zea-0.0.10/zea/beamform/phantoms.py +145 -0
- zea-0.0.10/zea/beamform/pixelgrid.py +189 -0
- {zea-0.0.8 → zea-0.0.10}/zea/config.py +2 -2
- {zea-0.0.8 → zea-0.0.10}/zea/data/augmentations.py +1 -1
- zea-0.0.10/zea/data/convert/__main__.py +174 -0
- {zea-0.0.8 → zea-0.0.10}/zea/data/convert/camus.py +8 -2
- {zea-0.0.8 → zea-0.0.10}/zea/data/convert/echonet.py +1 -1
- {zea-0.0.8 → zea-0.0.10}/zea/data/convert/echonetlvh/__init__.py +48 -51
- {zea-0.0.8 → zea-0.0.10}/zea/data/convert/echonetlvh/precompute_crop.py +12 -0
- {zea-0.0.8 → zea-0.0.10}/zea/data/convert/images.py +1 -1
- zea-0.0.10/zea/data/convert/verasonics.py +1503 -0
- {zea-0.0.8 → zea-0.0.10}/zea/data/data_format.py +53 -10
- {zea-0.0.8 → zea-0.0.10}/zea/data/datasets.py +0 -7
- {zea-0.0.8 → zea-0.0.10}/zea/data/file.py +35 -1
- {zea-0.0.8 → zea-0.0.10}/zea/data/file_operations.py +2 -0
- {zea-0.0.8 → zea-0.0.10}/zea/data/preset_utils.py +1 -1
- {zea-0.0.8 → zea-0.0.10}/zea/display.py +27 -10
- {zea-0.0.8 → zea-0.0.10}/zea/doppler.py +6 -6
- zea-0.0.10/zea/func/__init__.py +115 -0
- zea-0.0.8/zea/tensor_ops.py → zea-0.0.10/zea/func/tensor.py +36 -10
- zea-0.0.10/zea/func/ultrasound.py +596 -0
- {zea-0.0.8 → zea-0.0.10}/zea/internal/_generate_keras_ops.py +5 -5
- {zea-0.0.8 → zea-0.0.10}/zea/internal/config/parameters.py +1 -1
- {zea-0.0.8 → zea-0.0.10}/zea/internal/device.py +6 -1
- {zea-0.0.8 → zea-0.0.10}/zea/internal/dummy_scan.py +15 -8
- zea-0.0.10/zea/internal/notebooks.py +152 -0
- {zea-0.0.8 → zea-0.0.10}/zea/internal/parameters.py +20 -0
- {zea-0.0.8 → zea-0.0.10}/zea/internal/registry.py +1 -1
- {zea-0.0.8 → zea-0.0.10}/zea/metrics.py +88 -71
- {zea-0.0.8 → zea-0.0.10}/zea/models/__init__.py +1 -0
- {zea-0.0.8 → zea-0.0.10}/zea/models/diffusion.py +6 -1
- {zea-0.0.8 → zea-0.0.10}/zea/models/echonetlvh.py +1 -1
- {zea-0.0.8 → zea-0.0.10}/zea/models/gmm.py +1 -1
- zea-0.0.10/zea/models/hvae/__init__.py +243 -0
- zea-0.0.10/zea/models/hvae/model.py +1139 -0
- zea-0.0.10/zea/models/hvae/utils.py +616 -0
- {zea-0.0.8 → zea-0.0.10}/zea/models/layers.py +1 -1
- {zea-0.0.8 → zea-0.0.10}/zea/models/lpips.py +12 -2
- {zea-0.0.8 → zea-0.0.10}/zea/models/presets.py +16 -0
- zea-0.0.10/zea/ops/__init__.py +190 -0
- zea-0.0.10/zea/ops/base.py +441 -0
- {zea-0.0.8/zea → zea-0.0.10/zea/ops}/keras_ops.py +2 -2
- zea-0.0.10/zea/ops/pipeline.py +1482 -0
- zea-0.0.10/zea/ops/tensor.py +333 -0
- zea-0.0.10/zea/ops/ultrasound.py +1037 -0
- {zea-0.0.8 → zea-0.0.10}/zea/probes.py +4 -10
- {zea-0.0.8 → zea-0.0.10}/zea/scan.py +231 -94
- {zea-0.0.8 → zea-0.0.10}/zea/simulator.py +8 -8
- {zea-0.0.8 → zea-0.0.10}/zea/tools/fit_scan_cone.py +9 -7
- {zea-0.0.8 → zea-0.0.10}/zea/tools/selection_tool.py +14 -8
- {zea-0.0.8 → zea-0.0.10}/zea/tracking/lucas_kanade.py +1 -1
- {zea-0.0.8 → zea-0.0.10}/zea/tracking/segmentation.py +1 -1
- {zea-0.0.8 → zea-0.0.10}/zea/visualize.py +3 -1
- zea-0.0.8/zea/beamform/phantoms.py +0 -43
- zea-0.0.8/zea/beamform/pixelgrid.py +0 -131
- zea-0.0.8/zea/data/convert/__main__.py +0 -123
- zea-0.0.8/zea/data/convert/verasonics.py +0 -1209
- zea-0.0.8/zea/internal/notebooks.py +0 -39
- zea-0.0.8/zea/ops.py +0 -3534
- {zea-0.0.8 → zea-0.0.10}/LICENSE +0 -0
- {zea-0.0.8 → zea-0.0.10}/README.md +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/__main__.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/agent/__init__.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/agent/gumbel.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/backend/autograd.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/backend/jax/__init__.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/backend/tensorflow/__init__.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/backend/tensorflow/layers/__init__.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/backend/tensorflow/layers/apodization.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/backend/tensorflow/layers/utils.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/backend/tensorflow/losses.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/backend/tensorflow/models/__init__.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/backend/tensorflow/models/lista.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/backend/tensorflow/scripts/convert-echonet-dynamic.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/backend/tensorflow/scripts/convert-taesd.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/backend/tensorflow/utils/__init__.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/backend/tensorflow/utils/callbacks.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/backend/tensorflow/utils/utils.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/backend/tf2jax.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/backend/torch/__init__.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/backend/torch/losses.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/beamform/__init__.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/data/__init__.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/data/__main__.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/data/convert/__init__.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/data/convert/echonetlvh/README.md +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/data/convert/echonetlvh/manual_rejections.txt +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/data/convert/picmus.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/data/convert/utils.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/data/dataloader.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/data/layers.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/data/utils.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/datapaths.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/interface.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/internal/cache.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/internal/checks.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/internal/config/create.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/internal/config/validation.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/internal/core.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/internal/git_info.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/internal/operators.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/internal/setup_zea.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/internal/utils.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/internal/viewer.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/io_lib.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/log.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/models/base.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/models/carotid_segmenter.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/models/deeplabv3.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/models/dense.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/models/echonet.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/models/generative.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/models/lv_segmentation.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/models/preset_utils.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/models/regional_quality.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/models/taesd.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/models/unet.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/models/utils.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/tools/__init__.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/tools/hf.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/tools/wndb.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/tracking/__init__.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/tracking/base.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/utils.py +0 -0
- {zea-0.0.8 → zea-0.0.10}/zea/zea_darkmode.mplstyle +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: zea
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.10
|
|
4
4
|
Summary: A Toolbox for Cognitive Ultrasound Imaging. Provides a set of tools for processing of ultrasound data, all built in your favorite machine learning framework.
|
|
5
5
|
License-File: LICENSE
|
|
6
6
|
Keywords: ultrasound,machine learning,beamforming
|
|
@@ -44,8 +44,6 @@ Requires-Dist: jax ; extra == "backends"
|
|
|
44
44
|
Requires-Dist: jax[cuda12-pip] (>=0.4.26) ; extra == "jax"
|
|
45
45
|
Requires-Dist: keras (>=3.12)
|
|
46
46
|
Requires-Dist: matplotlib (>=3.8)
|
|
47
|
-
Requires-Dist: mock ; extra == "dev"
|
|
48
|
-
Requires-Dist: mock ; extra == "docs"
|
|
49
47
|
Requires-Dist: myst-parser ; extra == "dev"
|
|
50
48
|
Requires-Dist: myst-parser ; extra == "docs"
|
|
51
49
|
Requires-Dist: nbsphinx ; extra == "dev"
|
|
@@ -72,10 +70,10 @@ Requires-Dist: schema (>=0.7)
|
|
|
72
70
|
Requires-Dist: scikit-image (>=0.23)
|
|
73
71
|
Requires-Dist: scikit-learn (>=1.4)
|
|
74
72
|
Requires-Dist: scipy (>=1.13)
|
|
73
|
+
Requires-Dist: simpleitk (>=2.2.1) ; extra == "dev"
|
|
74
|
+
Requires-Dist: simpleitk (>=2.2.1) ; extra == "tests"
|
|
75
75
|
Requires-Dist: sphinx ; extra == "dev"
|
|
76
76
|
Requires-Dist: sphinx ; extra == "docs"
|
|
77
|
-
Requires-Dist: sphinx-argparse ; extra == "dev"
|
|
78
|
-
Requires-Dist: sphinx-argparse ; extra == "docs"
|
|
79
77
|
Requires-Dist: sphinx-autobuild ; extra == "dev"
|
|
80
78
|
Requires-Dist: sphinx-autobuild ; extra == "docs"
|
|
81
79
|
Requires-Dist: sphinx-autodoc-typehints ; extra == "dev"
|
|
@@ -86,6 +84,8 @@ Requires-Dist: sphinx-reredirects ; extra == "dev"
|
|
|
86
84
|
Requires-Dist: sphinx-reredirects ; extra == "docs"
|
|
87
85
|
Requires-Dist: sphinx_design ; extra == "dev"
|
|
88
86
|
Requires-Dist: sphinx_design ; extra == "docs"
|
|
87
|
+
Requires-Dist: sphinxcontrib-autoprogram ; extra == "dev"
|
|
88
|
+
Requires-Dist: sphinxcontrib-autoprogram ; extra == "docs"
|
|
89
89
|
Requires-Dist: sphinxcontrib-bibtex ; extra == "dev"
|
|
90
90
|
Requires-Dist: sphinxcontrib-bibtex ; extra == "docs"
|
|
91
91
|
Requires-Dist: tensorflow ; extra == "backends"
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "zea"
|
|
3
|
-
version = "0.0.
|
|
3
|
+
version = "0.0.10"
|
|
4
4
|
description = "A Toolbox for Cognitive Ultrasound Imaging. Provides a set of tools for processing of ultrasound data, all built in your favorite machine learning framework."
|
|
5
5
|
authors = [
|
|
6
6
|
{ name = "Tristan Stevens", email = "t.s.w.stevens@tue.nl" },
|
|
@@ -54,6 +54,7 @@ dev = [
|
|
|
54
54
|
"papermill>=2.4",
|
|
55
55
|
"ipykernel>=6.29.5",
|
|
56
56
|
"cloudpickle>=3.1.1",
|
|
57
|
+
"simpleitk>=2.2.1",
|
|
57
58
|
"ipywidgets",
|
|
58
59
|
"pre-commit",
|
|
59
60
|
"ruff",
|
|
@@ -64,10 +65,9 @@ dev = [
|
|
|
64
65
|
"sphinx-autodoc-typehints",
|
|
65
66
|
"sphinx-copybutton",
|
|
66
67
|
"sphinx_design",
|
|
67
|
-
"
|
|
68
|
+
"sphinxcontrib-autoprogram",
|
|
68
69
|
"sphinx-reredirects",
|
|
69
70
|
"sphinxcontrib-bibtex",
|
|
70
|
-
"mock",
|
|
71
71
|
"myst-parser",
|
|
72
72
|
"nbsphinx",
|
|
73
73
|
"furo",
|
|
@@ -84,6 +84,7 @@ tests = [
|
|
|
84
84
|
"papermill>=2.4",
|
|
85
85
|
"ipykernel>=6.29.5",
|
|
86
86
|
"cloudpickle>=3.1.1",
|
|
87
|
+
"simpleitk>=2.2.1",
|
|
87
88
|
"ipywidgets",
|
|
88
89
|
"pre-commit",
|
|
89
90
|
"ruff",
|
|
@@ -95,10 +96,9 @@ docs = [
|
|
|
95
96
|
"sphinx-autodoc-typehints",
|
|
96
97
|
"sphinx-copybutton",
|
|
97
98
|
"sphinx_design",
|
|
98
|
-
"
|
|
99
|
+
"sphinxcontrib-autoprogram",
|
|
99
100
|
"sphinx-reredirects",
|
|
100
101
|
"sphinxcontrib-bibtex",
|
|
101
|
-
"mock",
|
|
102
102
|
"myst-parser",
|
|
103
103
|
"nbsphinx",
|
|
104
104
|
"furo",
|
|
@@ -2,12 +2,16 @@
|
|
|
2
2
|
|
|
3
3
|
import importlib.util
|
|
4
4
|
import os
|
|
5
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
5
6
|
|
|
6
7
|
from . import log
|
|
7
8
|
|
|
8
|
-
|
|
9
|
-
# __version__
|
|
10
|
-
__version__ = "
|
|
9
|
+
try:
|
|
10
|
+
# dynamically add __version__ attribute (see pyproject.toml)
|
|
11
|
+
__version__ = version("zea")
|
|
12
|
+
except PackageNotFoundError:
|
|
13
|
+
# Package is not installed (e.g., running from source)
|
|
14
|
+
__version__ = "dev"
|
|
11
15
|
|
|
12
16
|
|
|
13
17
|
def _bootstrap_backend():
|
|
@@ -80,8 +84,10 @@ def _bootstrap_backend():
|
|
|
80
84
|
log.info(f"Using backend {keras_backend()!r}")
|
|
81
85
|
|
|
82
86
|
|
|
83
|
-
#
|
|
84
|
-
|
|
87
|
+
# Skip backend bootstrap when building on ReadTheDocs
|
|
88
|
+
if os.environ.get("READTHEDOCS") != "True":
|
|
89
|
+
_bootstrap_backend()
|
|
90
|
+
|
|
85
91
|
del _bootstrap_backend
|
|
86
92
|
|
|
87
93
|
from . import (
|
|
@@ -89,12 +95,12 @@ from . import (
|
|
|
89
95
|
beamform,
|
|
90
96
|
data,
|
|
91
97
|
display,
|
|
98
|
+
func,
|
|
92
99
|
io_lib,
|
|
93
|
-
keras_ops,
|
|
94
100
|
metrics,
|
|
95
101
|
models,
|
|
102
|
+
ops,
|
|
96
103
|
simulator,
|
|
97
|
-
tensor_ops,
|
|
98
104
|
utils,
|
|
99
105
|
visualize,
|
|
100
106
|
)
|
|
@@ -4,13 +4,15 @@ Mask generation utilities.
|
|
|
4
4
|
These masks are used as a measurement operator for focused scan-line subsampling.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
7
9
|
from typing import List
|
|
8
10
|
|
|
9
11
|
import keras
|
|
10
12
|
from keras import ops
|
|
11
13
|
|
|
12
|
-
from zea import tensor_ops
|
|
13
14
|
from zea.agent.gumbel import hard_straight_through
|
|
15
|
+
from zea.func.tensor import nonzero
|
|
14
16
|
|
|
15
17
|
_DEFAULT_DTYPE = "bool"
|
|
16
18
|
|
|
@@ -56,7 +58,7 @@ def k_hot_to_indices(selected_lines, n_actions: int, fill_value=-1):
|
|
|
56
58
|
|
|
57
59
|
# Find nonzero indices for each frame
|
|
58
60
|
def get_nonzero(row):
|
|
59
|
-
return
|
|
61
|
+
return nonzero(row > 0, size=n_actions, fill_value=fill_value)[0]
|
|
60
62
|
|
|
61
63
|
indices = ops.vectorized_map(get_nonzero, selected_lines)
|
|
62
64
|
return indices
|
|
@@ -117,11 +119,21 @@ def initial_equispaced_lines(
|
|
|
117
119
|
Tensor: k-hot-encoded line vector of shape (n_possible_actions).
|
|
118
120
|
Needs to be converted to image size.
|
|
119
121
|
"""
|
|
122
|
+
assert n_actions > 0, "Number of actions must be > 0."
|
|
123
|
+
assert n_possible_actions > 0, "Number of possible actions must be > 0."
|
|
124
|
+
assert n_actions <= n_possible_actions, (
|
|
125
|
+
"Number of actions must be less than or equal to number of possible actions."
|
|
126
|
+
)
|
|
127
|
+
|
|
120
128
|
if assert_equal_spacing:
|
|
121
129
|
_assert_equal_spacing(n_actions, n_possible_actions)
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
130
|
+
|
|
131
|
+
# Distribute indices as evenly as possible
|
|
132
|
+
# This approach ensures spacing differs by at most 1 when not divisible
|
|
133
|
+
step = n_possible_actions / n_actions
|
|
134
|
+
selected_indices = ops.cast(
|
|
135
|
+
ops.round(ops.arange(0, n_actions, dtype="float32") * step), "int32"
|
|
136
|
+
)
|
|
125
137
|
|
|
126
138
|
return indices_to_k_hot(selected_indices, n_possible_actions, dtype=dtype)
|
|
127
139
|
|
|
@@ -16,9 +16,9 @@ from typing import Callable
|
|
|
16
16
|
import keras
|
|
17
17
|
from keras import ops
|
|
18
18
|
|
|
19
|
-
from zea import tensor_ops
|
|
20
19
|
from zea.agent import masks
|
|
21
20
|
from zea.backend.autograd import AutoGrad
|
|
21
|
+
from zea.func import tensor
|
|
22
22
|
from zea.internal.registry import action_selection_registry
|
|
23
23
|
|
|
24
24
|
|
|
@@ -96,6 +96,7 @@ class GreedyEntropy(LinesActionModel):
|
|
|
96
96
|
std_dev: float = 1,
|
|
97
97
|
num_lines_to_update: int = 5,
|
|
98
98
|
entropy_sigma: float = 1.0,
|
|
99
|
+
average_entropy_across_batch: bool = False,
|
|
99
100
|
):
|
|
100
101
|
"""Initialize the GreedyEntropy action selection model.
|
|
101
102
|
|
|
@@ -110,6 +111,10 @@ class GreedyEntropy(LinesActionModel):
|
|
|
110
111
|
to update. Must be odd.
|
|
111
112
|
entropy_sigma (float, optional): The standard deviation of the Gaussian
|
|
112
113
|
Mixture components used to approximate the posterior.
|
|
114
|
+
average_entropy_across_batch (bool, optional): Whether to average entropy
|
|
115
|
+
across the batch when selecting lines. This can be useful when
|
|
116
|
+
selecting planes in 3D imaging, where the batch dimension represents
|
|
117
|
+
a third spatial dimension. Defaults to False.
|
|
113
118
|
"""
|
|
114
119
|
super().__init__(n_actions, n_possible_actions, img_width, img_height)
|
|
115
120
|
|
|
@@ -117,6 +122,7 @@ class GreedyEntropy(LinesActionModel):
|
|
|
117
122
|
# of the selected line is set to 0 once it's been selected.
|
|
118
123
|
assert num_lines_to_update % 2 == 1, "num_samples must be odd."
|
|
119
124
|
self.num_lines_to_update = num_lines_to_update
|
|
125
|
+
self.average_entropy_across_batch = average_entropy_across_batch
|
|
120
126
|
|
|
121
127
|
# see here what I mean by upside_down_gaussian:
|
|
122
128
|
# https://colab.research.google.com/drive/1CQp_Z6nADzOFsybdiH5Cag0vtVZjjioU?usp=sharing
|
|
@@ -153,7 +159,7 @@ class GreedyEntropy(LinesActionModel):
|
|
|
153
159
|
assert particles.shape[1] > 1, "The entropy cannot be approximated using a single particle."
|
|
154
160
|
|
|
155
161
|
if n_possible_actions is None:
|
|
156
|
-
n_possible_actions =
|
|
162
|
+
n_possible_actions = ops.shape(particles)[-1]
|
|
157
163
|
|
|
158
164
|
# TODO: I think we only need to compute the lower triangular
|
|
159
165
|
# of this matrix, since it's symmetric
|
|
@@ -164,7 +170,8 @@ class GreedyEntropy(LinesActionModel):
|
|
|
164
170
|
# Vertically stack all columns corresponding with the same line
|
|
165
171
|
# This way we can just sum across the height axis and get the entropy
|
|
166
172
|
# for each pixel in a given line
|
|
167
|
-
batch_size, n_particles, _, height, _ =
|
|
173
|
+
batch_size, n_particles, _, height, _ = ops.shape(gaussian_error_per_pixel_i_j)
|
|
174
|
+
|
|
168
175
|
gaussian_error_per_pixel_stacked = ops.transpose(
|
|
169
176
|
ops.reshape(
|
|
170
177
|
ops.transpose(gaussian_error_per_pixel_i_j, (0, 1, 2, 4, 3)),
|
|
@@ -274,6 +281,8 @@ class GreedyEntropy(LinesActionModel):
|
|
|
274
281
|
|
|
275
282
|
pixelwise_entropy = self.compute_pixelwise_entropy(particles)
|
|
276
283
|
linewise_entropy = ops.sum(pixelwise_entropy, axis=1)
|
|
284
|
+
if self.average_entropy_across_batch:
|
|
285
|
+
linewise_entropy = ops.expand_dims(ops.mean(linewise_entropy, axis=0), axis=0)
|
|
277
286
|
|
|
278
287
|
# Greedily select best line, reweight entropies, and repeat
|
|
279
288
|
all_selected_lines = []
|
|
@@ -334,7 +343,7 @@ class EquispacedLines(LinesActionModel):
|
|
|
334
343
|
n_possible_actions: int,
|
|
335
344
|
img_width: int,
|
|
336
345
|
img_height: int,
|
|
337
|
-
assert_equal_spacing=True,
|
|
346
|
+
assert_equal_spacing: bool = True,
|
|
338
347
|
):
|
|
339
348
|
super().__init__(n_actions, n_possible_actions, img_width, img_height)
|
|
340
349
|
|
|
@@ -462,7 +471,7 @@ class CovarianceSamplingLines(LinesActionModel):
|
|
|
462
471
|
particles = ops.reshape(particles, shape)
|
|
463
472
|
|
|
464
473
|
# [batch_size, rows * stack_n_cols, n_possible_actions, n_possible_actions]
|
|
465
|
-
cov_matrix =
|
|
474
|
+
cov_matrix = tensor.batch_cov(particles)
|
|
466
475
|
|
|
467
476
|
# Sum over the row dimension [batch_size, n_possible_actions, n_possible_actions]
|
|
468
477
|
cov_matrix = ops.sum(cov_matrix, axis=1)
|
|
@@ -477,7 +486,7 @@ class CovarianceSamplingLines(LinesActionModel):
|
|
|
477
486
|
# Subsample the covariance matrix with random lines
|
|
478
487
|
def subsample_with_mask(mask):
|
|
479
488
|
"""Subsample the covariance matrix with a single mask."""
|
|
480
|
-
subsampled_cov_matrix =
|
|
489
|
+
subsampled_cov_matrix = tensor.boolean_mask(
|
|
481
490
|
cov_matrix, mask, size=batch_size * self.n_actions**2
|
|
482
491
|
)
|
|
483
492
|
return ops.reshape(subsampled_cov_matrix, [batch_size, self.n_actions, self.n_actions])
|
|
@@ -59,8 +59,21 @@ def _import_torch():
|
|
|
59
59
|
return None
|
|
60
60
|
|
|
61
61
|
|
|
62
|
+
def _get_backend():
|
|
63
|
+
try:
|
|
64
|
+
backend_result = keras.backend.backend()
|
|
65
|
+
if isinstance(backend_result, str):
|
|
66
|
+
return backend_result
|
|
67
|
+
else:
|
|
68
|
+
# to handle mocked backends during testing
|
|
69
|
+
return None
|
|
70
|
+
except Exception:
|
|
71
|
+
return None
|
|
72
|
+
|
|
73
|
+
|
|
62
74
|
tf_mod = _import_tf()
|
|
63
75
|
jax_mod = _import_jax()
|
|
76
|
+
backend = _get_backend()
|
|
64
77
|
|
|
65
78
|
|
|
66
79
|
def tf_function(func=None, jit_compile=False, **kwargs):
|
|
@@ -131,7 +144,7 @@ class on_device:
|
|
|
131
144
|
.. code-block:: python
|
|
132
145
|
|
|
133
146
|
with zea.backend.on_device("gpu:3"):
|
|
134
|
-
pipeline = zea.Pipeline([zea.
|
|
147
|
+
pipeline = zea.Pipeline([zea.ops.Abs()])
|
|
135
148
|
output = pipeline(data=keras.random.normal((10, 10))) # output is on "cuda:3"
|
|
136
149
|
"""
|
|
137
150
|
|
|
@@ -184,7 +197,7 @@ class on_device:
|
|
|
184
197
|
self._context.__exit__(exc_type, exc_val, exc_tb)
|
|
185
198
|
|
|
186
199
|
|
|
187
|
-
if
|
|
200
|
+
if backend in [None, "tensorflow", "jax", "numpy"]:
|
|
188
201
|
|
|
189
202
|
def func_on_device(func, device, *args, **kwargs):
|
|
190
203
|
"""Moves all tensor arguments of a function to a specified device before calling it.
|
|
@@ -199,7 +212,8 @@ if keras.backend.backend() in ["tensorflow", "jax", "numpy"]:
|
|
|
199
212
|
"""
|
|
200
213
|
with on_device(device):
|
|
201
214
|
return func(*args, **kwargs)
|
|
202
|
-
|
|
215
|
+
|
|
216
|
+
elif backend == "torch":
|
|
203
217
|
from zea.backend.torch import func_on_device
|
|
204
218
|
else:
|
|
205
|
-
raise ValueError(f"Unsupported backend: {
|
|
219
|
+
raise ValueError(f"Unsupported backend: {backend}")
|
|
@@ -12,8 +12,8 @@ from keras.src.trainers.data_adapters import TFDatasetAdapter
|
|
|
12
12
|
|
|
13
13
|
from zea.data.dataloader import H5Generator
|
|
14
14
|
from zea.data.layers import Resizer
|
|
15
|
+
from zea.func.tensor import translate
|
|
15
16
|
from zea.internal.utils import find_methods_with_return_type
|
|
16
|
-
from zea.tensor_ops import translate
|
|
17
17
|
|
|
18
18
|
METHODS_THAT_RETURN_DATASET = find_methods_with_return_type(tf.data.Dataset, "DatasetV2")
|
|
19
19
|
|