zea 0.0.0__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108) hide show
  1. zea/__init__.py +74 -0
  2. zea/__main__.py +82 -0
  3. zea/agent/__init__.py +27 -0
  4. zea/agent/gumbel.py +107 -0
  5. zea/agent/masks.py +192 -0
  6. zea/agent/selection.py +478 -0
  7. zea/backend/__init__.py +114 -0
  8. zea/backend/autograd.py +181 -0
  9. zea/backend/jax/__init__.py +70 -0
  10. zea/backend/tensorflow/__init__.py +66 -0
  11. zea/backend/tensorflow/dataloader.py +372 -0
  12. zea/backend/tensorflow/layers/__init__.py +0 -0
  13. zea/backend/tensorflow/layers/apodization.py +37 -0
  14. zea/backend/tensorflow/layers/utils.py +105 -0
  15. zea/backend/tensorflow/losses.py +64 -0
  16. zea/backend/tensorflow/models/__init__.py +0 -0
  17. zea/backend/tensorflow/models/lista.py +112 -0
  18. zea/backend/tensorflow/scripts/convert-echonet-dynamic.py +139 -0
  19. zea/backend/tensorflow/scripts/convert-taesd.py +88 -0
  20. zea/backend/tensorflow/utils/__init__.py +0 -0
  21. zea/backend/tensorflow/utils/callbacks.py +4 -0
  22. zea/backend/tensorflow/utils/utils.py +35 -0
  23. zea/backend/tf2jax.py +10 -0
  24. zea/backend/torch/__init__.py +74 -0
  25. zea/backend/torch/losses.py +64 -0
  26. zea/beamform/__init__.py +20 -0
  27. zea/beamform/beamformer.py +498 -0
  28. zea/beamform/delays.py +153 -0
  29. zea/beamform/lens_correction.py +198 -0
  30. zea/beamform/pfield.py +406 -0
  31. zea/beamform/phantoms.py +43 -0
  32. zea/beamform/pixelgrid.py +130 -0
  33. zea/config.py +520 -0
  34. zea/data/__init__.py +55 -0
  35. zea/data/__main__.py +31 -0
  36. zea/data/augmentations.py +329 -0
  37. zea/data/convert/__init__.py +6 -0
  38. zea/data/convert/camus.py +259 -0
  39. zea/data/convert/echonet.py +438 -0
  40. zea/data/convert/echonetlvh/README.md +9 -0
  41. zea/data/convert/echonetlvh/convert_raw_to_usbmd.py +499 -0
  42. zea/data/convert/echonetlvh/precompute_crop.py +251 -0
  43. zea/data/convert/images.py +135 -0
  44. zea/data/convert/matlab.py +1230 -0
  45. zea/data/convert/picmus.py +182 -0
  46. zea/data/data_format.py +718 -0
  47. zea/data/dataloader.py +409 -0
  48. zea/data/datasets.py +650 -0
  49. zea/data/file.py +844 -0
  50. zea/data/layers.py +167 -0
  51. zea/data/preset_utils.py +109 -0
  52. zea/data/utils.py +90 -0
  53. zea/datapaths.py +566 -0
  54. zea/display.py +666 -0
  55. zea/interface.py +546 -0
  56. zea/internal/cache.py +284 -0
  57. zea/internal/checks.py +301 -0
  58. zea/internal/config/create.py +161 -0
  59. zea/internal/config/parameters.py +123 -0
  60. zea/internal/config/validation.py +165 -0
  61. zea/internal/convert.py +150 -0
  62. zea/internal/core.py +314 -0
  63. zea/internal/device.py +414 -0
  64. zea/internal/git_info.py +44 -0
  65. zea/internal/operators.py +72 -0
  66. zea/internal/parameters.py +425 -0
  67. zea/internal/registry.py +203 -0
  68. zea/internal/setup_zea.py +223 -0
  69. zea/internal/viewer.py +450 -0
  70. zea/io_lib.py +350 -0
  71. zea/log.py +356 -0
  72. zea/metrics.py +158 -0
  73. zea/models/__init__.py +88 -0
  74. zea/models/base.py +195 -0
  75. zea/models/carotid_segmenter.py +168 -0
  76. zea/models/dense.py +132 -0
  77. zea/models/diffusion.py +842 -0
  78. zea/models/echonet.py +181 -0
  79. zea/models/generative.py +75 -0
  80. zea/models/gmm.py +208 -0
  81. zea/models/layers.py +69 -0
  82. zea/models/lpips.py +181 -0
  83. zea/models/preset_utils.py +414 -0
  84. zea/models/presets.py +100 -0
  85. zea/models/taesd.py +245 -0
  86. zea/models/unet.py +207 -0
  87. zea/models/utils.py +60 -0
  88. zea/ops.py +3026 -0
  89. zea/probes.py +224 -0
  90. zea/scan.py +619 -0
  91. zea/simulator.py +343 -0
  92. zea/tensor_ops.py +1327 -0
  93. zea/tools/__init__.py +8 -0
  94. zea/tools/fit_scan_cone.py +709 -0
  95. zea/tools/hf.py +174 -0
  96. zea/tools/selection_tool.py +847 -0
  97. zea/tools/wndb.py +22 -0
  98. zea/utils.py +664 -0
  99. zea/visualize.py +634 -0
  100. zea/zea_darkmode.mplstyle +798 -0
  101. zea-0.0.2.dist-info/LICENSE +202 -0
  102. zea-0.0.2.dist-info/METADATA +115 -0
  103. zea-0.0.2.dist-info/RECORD +105 -0
  104. {zea-0.0.0.dist-info → zea-0.0.2.dist-info}/WHEEL +1 -2
  105. zea-0.0.2.dist-info/entry_points.txt +3 -0
  106. zea-0.0.0.dist-info/METADATA +0 -17
  107. zea-0.0.0.dist-info/RECORD +0 -5
  108. zea-0.0.0.dist-info/top_level.txt +0 -1
zea/__init__.py CHANGED
@@ -0,0 +1,74 @@
1
+ """``zea``: *A Toolbox for Cognitive Ultrasound Imaging.*"""
2
+
3
+ import importlib.util
4
+ import os
5
+
6
+ from . import log
7
+
8
+ # dynamically add __version__ attribute (see pyproject.toml)
9
+ # __version__ = __import__("importlib.metadata").metadata.version(__package__)
10
+ __version__ = "0.0.2"
11
+
12
+
13
def setup():
    """Initialize the zea package.

    Checks that an ML backend can be found, then imports Keras and logs which
    backend Keras reports.

    Raises:
        ImportError: If none of torch, tensorflow or jax can be found.
    """

    def _check_backend_installed():
        """Raise an ImportError with install hints when no ML backend is found.

        Returns without doing anything as soon as one of torch, tensorflow or
        jax is discoverable through ``importlib.util.find_spec``.
        """
        candidates = ("torch", "tensorflow", "jax")
        if any(importlib.util.find_spec(name) is not None for name in candidates):
            return

        # Pick the install guide matching KERAS_BACKEND; any value without a
        # dedicated guide falls back to the Keras getting-started page.
        backend_env = os.environ.get("KERAS_BACKEND", "numpy")
        guide_url = {
            "torch": "https://pytorch.org/get-started/locally/",
            "tensorflow": "https://www.tensorflow.org/install",
            "jax": "https://docs.jax.dev/en/latest/installation.html",
        }.get(backend_env, "https://keras.io/getting_started/")

        message = (
            "No ML backend (torch, tensorflow, jax) installed in current environment. "
            f"Please install at least one ML backend before importing {__package__} or "
            f"any other library. Current KERAS_BACKEND is set to '{backend_env}', "
            f"please install it first, see: {guide_url}. One simple alternative is to "
            f"install with default backend: `pip install {__package__}[jax]`."
        )
        raise ImportError(message)

    _check_backend_installed()

    # Keras is imported only after the backend check has passed.
    import keras

    log.info(f"Using backend {keras.backend.backend()!r}")
46
+
47
+
48
# Run the initialization once at import time, then remove the helper from the
# module namespace. The name `setup` is bound again further down by the import
# from `.internal.setup_zea`.
setup()
del setup
51
+
52
+ from . import (
53
+ agent,
54
+ beamform,
55
+ data,
56
+ display,
57
+ io_lib,
58
+ metrics,
59
+ models,
60
+ simulator,
61
+ tensor_ops,
62
+ utils,
63
+ visualize,
64
+ )
65
+ from .config import Config
66
+ from .data.datasets import Dataset, Folder
67
+ from .data.file import File, load_file
68
+ from .datapaths import set_data_paths
69
+ from .interface import Interface
70
+ from .internal.device import init_device
71
+ from .internal.setup_zea import set_backend, setup, setup_config
72
+ from .ops import Pipeline
73
+ from .probes import Probe
74
+ from .scan import Scan
zea/__main__.py ADDED
@@ -0,0 +1,82 @@
1
+ """Main entry point for zea
2
+
3
+ Run as `zea --config path/to/config.yaml` to start the zea interface.
4
+ Or do not pass a config file to open a file dialog to choose a config file.
5
+
6
+ """
7
+
8
+ import argparse
9
+ import sys
10
+ from pathlib import Path
11
+
12
+ from zea import log
13
+ from zea.visualize import set_mpl_style
14
+
15
+
16
def get_args():
    """Build the zea command line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes ``config``,
            ``task``, ``backend``, ``skip_validate_file`` and ``gui``.
    """
    # (flags, keyword options) in the order they are registered on the parser.
    argument_specs = [
        (
            ("-c", "--config"),
            {"type": str, "default": None, "help": "path to config file."},
        ),
        (
            ("-t", "--task"),
            {
                "default": "view",
                "choices": ["view"],
                "type": str,
                "help": "which task to run",
            },
        ),
        (
            ("--backend",),
            {
                "default": None,
                "type": str,
                "help": (
                    "Keras backend to use. Default is the one set by the environment "
                    "variable KERAS_BACKEND."
                ),
            },
        ),
        (
            ("--skip_validate_file",),
            {
                "default": False,
                "action": "store_true",
                "help": "Skip zea file integrity checks. Use with caution.",
            },
        ),
        (
            ("--gui",),
            {"default": False, "action": argparse.BooleanOptionalAction},
        ),
    ]

    parser = argparse.ArgumentParser(description="Process ultrasound data.")
    for flags, options in argument_specs:
        parser.add_argument(*flags, **options)

    return parser.parse_args()
46
+
47
+
48
def main():
    """Run the zea command line interface.

    Parses the command line, applies the matplotlib style, optionally sets the
    Keras backend, calls ``setup`` with the config argument and runs the
    requested task.

    Raises:
        ValueError: If the requested task is not ``"view"``.
    """
    cli_args = get_args()

    set_mpl_style()

    # The backend is set before ``keras`` is imported further down.
    if cli_args.backend:
        from zea.internal.setup_zea import set_backend

        set_backend(cli_args.backend)

    # Append the directory containing this file to the import path.
    package_dir = Path(__file__).parent.resolve()
    sys.path.append(str(package_dir))

    import keras

    from zea.interface import Interface
    from zea.internal.setup_zea import setup

    config = setup(cli_args.config)

    if cli_args.task != "view":
        raise ValueError(f"Unknown task {cli_args.task}, see `zea --help` for available tasks.")

    interface = Interface(
        config,
        validate_file=not cli_args.skip_validate_file,
    )

    log.info(f"Using {keras.backend.backend()} backend")
    interface.run(plot=True)
79
+
80
+
81
+ if __name__ == "__main__":
82
+ main()
zea/agent/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ """Agent subpackage for closing action-perception loop in ultrasound imaging.
2
+
3
+ The `agent` subpackage provides tools and utilities for agent-based algorithms within the ``zea`` framework, including mask generation and action selection strategies. See :mod:`zea.agent.masks` and :mod:`zea.agent.selection` for key functions implementing intelligent focused transmit selection, such as the :class:`zea.agent.selection.GreedyEntropy` algorithm.
4
+
5
+ For a practical example, see :doc:`../notebooks/agent/agent_example`.
6
+
7
+ Example usage
8
+ ^^^^^^^^^^^^^
9
+
10
+ .. code-block:: python
11
+
12
+ import zea
13
+ import numpy as np
14
+
15
+ agent = zea.agent.selection.GreedyEntropy(
16
+ n_actions=7,
17
+ n_possible_actions=112,
18
+ img_width=112,
19
+ img_height=112,
20
+ )
21
+
22
+ # (batch, samples, height, width)
23
+ particles = np.random.rand(1, 10, 112, 112)
24
+ lines, mask = agent.sample(particles)
25
+ """
26
+
27
+ from . import masks, selection
zea/agent/gumbel.py ADDED
@@ -0,0 +1,107 @@
1
+ """Gumbel-Softmax trick implemented with the multi-backend ``keras.ops``."""
2
+
3
+ import keras
4
+ import numpy as np
5
+ from keras import ops
6
+
7
# Product function used for shape arithmetic in this module:
#   - jax does not allow shapes to be tensors, so ``np.prod`` is used there;
#   - on every other backend ``ops.prod`` is used, which allows tensorflow tracing.
prod = np.prod if keras.backend.backend() == "jax" else ops.prod
13
+
14
+
15
class SubsetOperator:
    """SubsetOperator applies the Gumbel-Softmax trick for continuous top-k selection.

    Args:
        k (int): The number of elements to select.
        tau (float, optional): The temperature parameter for Gumbel-Softmax. Defaults to 1.0.
        hard (bool, optional): Whether to use straight-through Gumbel-Softmax. Defaults to False.
        n_value_dims (int, optional): Number of trailing value dimensions, forwarded to
            :func:`hard_straight_through` when ``hard=True``. Defaults to 1.
            E.g. for an image mask, ``n_value_dims=2``.

    Sources:
        - `Reparameterizable Subset Sampling via Continuous Relaxations <https://github.com/ermongroup/subsets>`_
        - `Sampling Subsets with Gumbel-Top Relaxations <https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/DL2/sampling/subsets.html>`_
    """  # noqa: E501

    def __init__(self, k, tau=1.0, hard=False, n_value_dims=1):
        self.k = k
        self.tau = tau
        self.hard = hard
        # Smallest positive normal float32; keeps the arguments of ``log`` away from zero.
        self.EPSILON = np.finfo(np.float32).tiny
        self.n_value_dims = n_value_dims  # for a image mask: n_value_dims=2

    def gumbel_sample(self, shape):
        """Samples from Gumbel(0,1) distribution

        Args:
            shape: Shape of the noise tensor to draw.

        Returns:
            Tensor: Gumbel noise of the requested shape.
        """
        uniform = keras.random.uniform(shape, minval=0, maxval=1)
        return -ops.log(-ops.log(uniform + self.EPSILON) + self.EPSILON)

    def __call__(self, scores):
        """Draw a relaxed k-hot selection from ``scores``.

        Args:
            scores (Tensor): Unnormalized scores. The softmax is taken over
                axis 1 (presumably shape ``(batch, n_elements)`` -- TODO confirm).

        Returns:
            Tensor: Relaxed k-hot tensor with the same shape as ``scores``; when
                ``hard=True`` it is passed through :func:`hard_straight_through`.
        """
        # Gumbel-Softmax trick to make the sampling differentiable
        gumbel_noise = self.gumbel_sample(ops.shape(scores))
        scores = scores + gumbel_noise

        # Continuous top-k selection
        khot = ops.zeros_like(scores)
        onehot_approx = ops.zeros_like(scores)

        for _ in range(self.k):
            # Element-wise clamp from below. ``ops.maximum`` is required here:
            # ``ops.max`` is a reduction whose second positional argument is
            # ``axis``, so passing EPSILON to it does not clamp anything.
            khot_mask = ops.maximum(1.0 - onehot_approx, self.EPSILON)
            # Down-weight what has already been selected before the next softmax.
            scores = scores + ops.log(khot_mask)
            onehot_approx = ops.softmax(scores / self.tau, axis=1)
            khot = khot + onehot_approx

        # Optionally convert soft selection to hard selection using straight-through estimator
        if self.hard:
            return hard_straight_through(khot, self.k, self.n_value_dims)
        return khot
62
+
63
+
64
def hard_straight_through(khot_orig, k, n_value_dims=1):
    """Turn a soft k-hot tensor into a hard one with a straight-through gradient.

    The trailing ``n_value_dims`` dimensions are flattened, the ``k`` largest
    entries of every flattened row are set to one (all others to zero), and the
    result is combined with the soft input so that the output carries the hard
    values while staying connected to ``khot_orig``.

    Args:
        khot_orig (Tensor): The original (soft) k-hot encoded tensor.
        k (int): The number of top elements to select per flattened row.
        n_value_dims (int, optional): The number of value dimensions in the input tensor.
            Defaults to 1. E.g. for a 2D image mask, `n_value_dims=2`.

    Returns:
        Tensor: The tensor after applying the hard straight-through estimator,
            with the same shape as `khot_orig`.
    """
    full_shape = ops.shape(khot_orig)

    # Collapse everything except the value dimensions into one leading axis.
    n_values = prod(full_shape[-n_value_dims:])
    flat = ops.reshape(khot_orig, (-1, n_values))
    flat_shape = ops.shape(flat)

    # Column positions of the k largest entries in every row.
    top_indices = ops.top_k(flat, k)[1]

    # Pair every selected column with its row index: (row, column) per entry.
    row_ids = ops.repeat(ops.arange(flat_shape[0]), k)
    col_ids = ops.reshape(top_indices, (-1,))
    scatter_indices = ops.stack([row_ids, col_ids], axis=-1)

    # Ones at the selected positions, zeros elsewhere.
    hard = ops.scatter(
        scatter_indices,
        ops.ones(prod(ops.shape(top_indices)), "float32"),
        flat_shape,
    )

    # Straight-through estimator: hard values, gradient path through ``flat``.
    straight_through = hard - ops.stop_gradient(flat) + flat

    return ops.reshape(straight_through, full_shape)
zea/agent/masks.py ADDED
@@ -0,0 +1,192 @@
1
+ """
2
+ Mask generation utilities.
3
+
4
+ These masks are used as a measurement operator for focused scan-line subsampling.
5
+ """
6
+
7
+ from typing import List
8
+
9
+ import keras
10
+ from keras import ops
11
+
12
+ from zea import tensor_ops
13
+ from zea.agent.gumbel import hard_straight_through
14
+
15
+ _DEFAULT_DTYPE = "bool"
16
+
17
+
18
def indices_to_k_hot(
    indices: List[int],
    n_possible_actions: int,
    dtype=_DEFAULT_DTYPE,
):
    """Convert a list of indices to a k-hot encoded vector.

    A k-hot encoded vector is suitable during tracing when the number of actions can change.
    This is the default representation for actions in zea.

    Args:
        indices (List[int]): Positions that are set to one.
        n_possible_actions (int): Length of the resulting vector.
        dtype (str, optional): Data type of the mask. Defaults to _DEFAULT_DTYPE.

    Returns:
        Tensor: k-hot-encoded vector of shape (n_possible_actions).
    """
    empty_mask = ops.zeros(n_possible_actions, dtype=dtype)
    # ``scatter_update`` takes one index row per update, hence the extra axis.
    positions = ops.expand_dims(indices, axis=1)
    ones = ops.ones(len(indices), dtype=dtype)
    return ops.scatter_update(empty_mask, positions, ones)
40
+
41
+
42
def k_hot_to_indices(selected_lines, n_actions: int, fill_value=-1):
    """Convert k-hot encoded lines to indices of selected actions.

    Args:
        selected_lines (Tensor): k-hot encoded lines of shape (batch_size, n_possible_actions).
        n_actions (int): Number of indices returned per row.
        fill_value (int, optional): Value used to pad rows with fewer than
            `n_actions` selected entries. Defaults to -1.

    Returns:
        Tensor: Indices of selected actions of shape (batch_size, n_actions).
            If there are fewer than `n_actions` selected, the remaining indices will be
            filled with `fill_value`.
    """

    def _row_to_indices(k_hot_row):
        # Fixed-size lookup of the positive entries of a single row.
        hits = tensor_ops.nonzero(k_hot_row > 0, size=n_actions, fill_value=fill_value)
        return hits[0]

    return ops.vectorized_map(_row_to_indices, selected_lines)
63
+
64
+
65
def random_uniform_lines(
    n_actions: int,
    n_possible_actions: int,
    n_masks: int,
    seed: int | keras.random.SeedGenerator = None,
    dtype=_DEFAULT_DTYPE,
):
    """Generate k-hot line vectors with randomly chosen lines.

    Guarantees precisely n_actions per mask.

    Args:
        n_actions (int): Number of actions to be selected.
        n_possible_actions (int): Number of possible actions.
        n_masks (int): Number of masks to generate.
        seed (int | SeedGenerator | jax.random.key, optional): Seed for random number generation.
            Defaults to None.
        dtype (str, optional): Data type of the returned masks. Defaults to _DEFAULT_DTYPE.

    Returns:
        Tensor: k-hot-encoded line vectors of shape (n_masks, n_possible_actions).
            Needs to be converted to image size.
    """
    # Draw a uniform score per line; the n_actions highest scores in each row
    # become the selected lines.
    scores = keras.random.uniform([n_masks, n_possible_actions], seed=seed, dtype="float32")
    k_hot = hard_straight_through(scores, n_actions)
    return ops.cast(k_hot, dtype=dtype)
90
+
91
+
92
def _assert_equal_spacing(n_actions, n_possible_actions):
    """Assert that ``n_actions`` divides ``n_possible_actions`` without remainder."""
    leftover = n_possible_actions % n_actions
    assert leftover == 0, (
        "Number of actions must divide evenly into possible actions to use equispaced sampling."
    )
96
+
97
+
98
def initial_equispaced_lines(
    n_actions, n_possible_actions, dtype=_DEFAULT_DTYPE, assert_equal_spacing=True
):
    """Generate an initial equispaced k-hot line mask.

    For example, if ``n_actions=2`` and ``n_possible_actions=6``,
    then ``initial_mask=[1, 0, 0, 1, 0, 0]``.

    Args:
        n_actions (int): Number of actions to be selected.
        n_possible_actions (int): Number of possible actions.
        dtype (str, optional): Data type of the mask. Defaults to _DEFAULT_DTYPE.
        assert_equal_spacing (bool, optional): If True, asserts that
            `n_possible_actions` is divisible by `n_actions`, so that all
            selected lines are exactly the same distance apart. If False, the
            indices are taken from an integer linspace and the spacing may
            differ between lines. Defaults to True.

    Returns:
        Tensor: k-hot-encoded line vector of shape (n_possible_actions).
            Needs to be converted to image size.
    """
    if assert_equal_spacing:
        _assert_equal_spacing(n_actions, n_possible_actions)
        stride = n_possible_actions // n_actions
        selected_indices = ops.arange(0, n_possible_actions, stride)
    else:
        selected_indices = ops.linspace(0, n_possible_actions - 1, n_actions, dtype="int32")

    return indices_to_k_hot(selected_indices, n_possible_actions, dtype=dtype)
126
+
127
+
128
def next_equispaced_lines(previous_lines, shift=1):
    """Roll an equispaced mask along its last axis.

    Args:
        previous_lines (Tensor): Mask of shape (..., n_possible_actions).
        shift (int, optional): Number of positions to roll to the right. Defaults to 1.

    Returns:
        Tensor: The rolled mask, same shape as `previous_lines`.
    """
    rolled = ops.roll(previous_lines, shift=shift, axis=-1)
    return rolled
134
+
135
+
136
def lines_to_im_size(lines, img_size: tuple):
    """Expand k-hot-encoded line vectors into image-sized masks.

    Every entry of a line vector is repeated along the width until the vector
    spans the image width, and the resulting row is repeated for every image row.

    Args:
        lines (Tensor): shape is (n_masks, n_possible_actions)
        img_size (tuple): (height, width)

    Returns:
        Tensor: Masks of shape (n_masks, height, width)
    """
    height, width = img_size
    n_lines = ops.shape(lines)[1]

    remainder = width % n_lines
    assert remainder == 0, (
        f"Width must be divisible by number of actions. Got remainder: {remainder}."
    )

    # Stretch every line over ``width // n_lines`` neighbouring columns.
    columns = ops.repeat(lines, width // n_lines, axis=1)

    # Insert a row axis and copy the row ``height`` times.
    return ops.repeat(columns[:, None], height, axis=1)
161
+
162
+
163
def make_line_mask(
    line_indices: List[int],
    image_shape: List[int],
    line_width: int = 1,
    dtype=_DEFAULT_DTYPE,
):
    """
    Creates a mask with vertical (i.e. second axis) lines at specified indices.

    Args:
        line_indices (List[int]): Indices of the lines to draw. They index into
            ``width // line_width`` line positions.
        image_shape (List[int]): The shape of the image as [height, width, channels].
        line_width (int, optional): The width of each line. Defaults to 1.
        dtype (str, optional): The data type of the mask. Defaults to _DEFAULT_DTYPE.

    Returns:
        mask (Tensor): A tensor of shape (height, width, channels) with lines drawn
            at the specified indices.
    """
    height, width, channels = image_shape
    n_lines = width // line_width

    # k-hot vector over the line positions, with a leading axis of size one
    # because lines_to_im_size works on (n_masks, n_possible_actions).
    k_hot = ops.expand_dims(indices_to_k_hot(line_indices, n_lines, dtype=dtype), axis=0)

    # Single (height, width) mask.
    mask_2d = lines_to_im_size(k_hot, (height, width))[0]

    # Copy the mask across the channel axis: (height, width, channels).
    return ops.broadcast_to(mask_2d[..., None], (height, width, channels))