zea 0.0.7__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- zea/__init__.py +1 -1
- zea/backend/tensorflow/dataloader.py +0 -4
- zea/beamform/pixelgrid.py +1 -1
- zea/data/__init__.py +0 -9
- zea/data/augmentations.py +221 -28
- zea/data/convert/__init__.py +1 -6
- zea/data/convert/__main__.py +123 -0
- zea/data/convert/camus.py +99 -39
- zea/data/convert/echonet.py +183 -82
- zea/data/convert/echonetlvh/README.md +2 -3
- zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +173 -102
- zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
- zea/data/convert/echonetlvh/precompute_crop.py +43 -64
- zea/data/convert/picmus.py +37 -40
- zea/data/convert/utils.py +86 -0
- zea/data/convert/{matlab.py → verasonics.py} +33 -61
- zea/data/data_format.py +124 -4
- zea/data/dataloader.py +12 -7
- zea/data/datasets.py +109 -70
- zea/data/file.py +91 -82
- zea/data/file_operations.py +496 -0
- zea/data/preset_utils.py +1 -1
- zea/display.py +7 -8
- zea/internal/checks.py +6 -12
- zea/internal/operators.py +4 -0
- zea/io_lib.py +108 -160
- zea/models/__init__.py +1 -1
- zea/models/diffusion.py +62 -11
- zea/models/lv_segmentation.py +2 -0
- zea/ops.py +398 -158
- zea/scan.py +18 -8
- zea/tensor_ops.py +82 -62
- zea/tools/fit_scan_cone.py +90 -160
- zea/tracking/__init__.py +16 -0
- zea/tracking/base.py +94 -0
- zea/tracking/lucas_kanade.py +474 -0
- zea/tracking/segmentation.py +110 -0
- zea/utils.py +11 -2
- {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/METADATA +3 -1
- {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/RECORD +43 -35
- {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/WHEEL +0 -0
- {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/entry_points.txt +0 -0
- {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/licenses/LICENSE +0 -0
zea/io_lib.py
CHANGED

@@ -4,20 +4,17 @@ Use to quickly read and write files or interact with file system.
 """
 
 import functools
-import multiprocessing
 import os
 import time
 from io import BytesIO
 from pathlib import Path
+from typing import Generator
 
 import imageio
 import numpy as np
-import tqdm
-import yaml
 from PIL import Image, ImageSequence
 
 from zea import log
-from zea.data.file import File
 
 _SUPPORTED_VID_TYPES = [".mp4", ".gif"]
 _SUPPORTED_IMG_TYPES = [".jpg", ".png", ".JPEG", ".PNG", ".jpeg"]
@@ -123,7 +120,7 @@ def save_video(images, filename, fps=20, **kwargs):
         filename (str or Path): Filename to which data should be written.
         fps (int): Frames per second of rendered format.
         **kwargs: Additional keyword arguments passed to the specific save function.
-            For GIF files, this includes `shared_color_palette` (bool).
+            For GIF and mp4 files, this includes `shared_color_palette` (bool).
 
     Raises:
         ValueError: If the file extension is not supported.
@@ -133,14 +130,14 @@ def save_video(images, filename, fps=20, **kwargs):
     ext = filename.suffix.lower()
 
     if ext == ".mp4":
-        return save_to_mp4(images, filename, fps=fps)
+        return save_to_mp4(images, filename, fps=fps, **kwargs)
     elif ext == ".gif":
        return save_to_gif(images, filename, fps=fps, **kwargs)
     else:
         raise ValueError(f"Unsupported file extension: {ext}")
 
 
-def save_to_gif(images, filename, fps=20, shared_color_palette=False):
+def save_to_gif(images, filename, fps=20, shared_color_palette=True):
     """Saves a sequence of images to a GIF file.
 
     .. note::
@@ -156,9 +153,9 @@ def save_to_gif(images, filename, fps=20, shared_color_palette=False):
         fps (int): Frames per second of rendered format.
         shared_color_palette (bool, optional): If True, creates a global
             color palette across all frames, ensuring consistent colors
-            throughout the GIF. Defaults to False, …
-            of PIL.Image.save. Note: True …
-            …
+            throughout the GIF. Defaults to True, which is default behavior
+            of PIL.Image.save. Note: True increases speed and shrinks file
+            size for longer sequences.
 
     """
     images = preprocess_for_saving(images)
@@ -173,15 +170,8 @@ def save_to_gif(images, filename, fps=20, shared_color_palette=False):
 
     if shared_color_palette:
         # Apply the same palette to all frames without dithering for consistent color mapping
-        …
-        …
-        combined_image = Image.fromarray(all_colors.reshape(-1, 1, 3))
-
-        # Generate palette from all frames
-        global_palette = combined_image.quantize(
-            colors=256,
-            method=Image.MEDIANCUT,
-            kmeans=1,
+        global_palette = compute_global_palette_by_histogram(
+            pillow_imgs, bits_per_channel=5, palette_size=256
         )
 
         # Apply the same palette to all frames without dithering
@@ -208,7 +198,7 @@ def save_to_gif(images, filename, fps=20, shared_color_palette=False):
     log.success(f"Successfully saved GIF to -> {log.yellow(filename)}")
 
 
-def save_to_mp4(images, filename, fps=20):
+def save_to_mp4(images, filename, fps=20, shared_color_palette=False):
     """Saves a sequence of images to an MP4 file.
 
     .. note::
@@ -222,6 +212,10 @@ def save_to_mp4(images, filename, fps=20):
         which is then converted to RGB. Images should be uint8.
         filename (str or Path): Filename to which data should be written.
         fps (int): Frames per second of rendered format.
+        shared_color_palette (bool, optional): If True, creates a global
+            color palette across all frames, ensuring consistent colors
+            throughout the MP4. Note: True can cause slow saving for longer
+            sequences.
 
     Raises:
         ImportError: If imageio-ffmpeg is not installed.
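Taken together, the hunks above forward `**kwargs` from `save_video` to `save_to_mp4`, flip the GIF default for `shared_color_palette` to True, and add the same opt-in flag to MP4 export. A minimal usage sketch based on the signatures in this diff (the random frames are placeholder data, not from the package):

    import numpy as np
    from zea.io_lib import save_video

    # hypothetical 40-frame uint8 RGB clip
    frames = np.random.randint(0, 256, size=(40, 64, 64, 3), dtype=np.uint8)

    save_video(frames, "clip.gif", fps=20)  # global palette is now the default for GIF
    save_video(frames, "clip.mp4", fps=20, shared_color_palette=True)  # opt in for MP4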
@@ -249,166 +243,120 @@ def save_to_mp4(images, filename, fps=20):
     ) from exc
 
     try:
-        …
-        …
+        if shared_color_palette:
+            pillow_imgs = [Image.fromarray(img) for img in images]
+            global_palette = compute_global_palette_by_histogram(
+                pillow_imgs, bits_per_channel=5, palette_size=256
+            )
+            for img in pillow_imgs:
+                paletted_img = img.convert("RGB").quantize(
+                    palette=global_palette,
+                    dither=Image.NONE,
+                )
+                writer.append_data(np.array(paletted_img.convert("RGB")))
+        else:
+            # Write from numpy arrays directly
+            for image in images:
+                writer.append_data(image)
     finally:
         writer.close()
 
     return log.success(f"Successfully saved MP4 to -> {filename}")
 
 
-def search_file_tree(
-    directory,
-    filetypes=None,
-    write=True,
-    dataset_info_filename="dataset_info.yaml",
-    hdf5_key_for_length=None,
-    redo=False,
-    parallel=False,
-    verbose=True,
-):
-    """Lists all files in directory and sub-directories.
-
-    If dataset_info.yaml is detected in the directory, that file is read and used
-    to deduce the file paths. If not, the file paths are searched for in the
-    directory and written to a dataset_info.yaml file.
+def search_file_tree(directory, filetypes=None, verbose=True, relative=False) -> Generator:
+    """Traverse a directory tree and yield file paths matching specified file types.
 
     Args:
-        directory (str): …
-        filetypes (…
-        …
-            Defaults to None. If set, the number of frames in each hdf5 file is
-            calculated and stored in the dataset_info.yaml file. This is extra
-            functionality of ``search_file_tree`` and only works with hdf5 files.
-        redo (bool, optional): Whether to redo the search and overwrite the dataset_info.yaml file.
-        parallel (bool, optional): Whether to use multiprocessing for hdf5 shape reading.
-        verbose (bool, optional): Whether to print progress and info.
-
-    Returns:
-        dict: Dictionary containing file paths and total number of files.
-        Has the following structure:
-
-        .. code-block:: python
-
-            {
-                "file_paths": list of file paths,
-                "total_num_files": total number of files,
-                "file_lengths": list of number of frames in each hdf5 file,
-                "file_shapes": list of shapes of each image file,
-                "total_num_frames": total number of frames in all hdf5 files
-            }
-
+        directory (str or Path): The root directory to start the search.
+        filetypes (list of str, optional): List of file extensions to match.
+            If None, file types supported by `zea` are matched. Defaults to None.
+        verbose (bool, optional): If True, logs the search process. Defaults to True.
+        relative (bool, optional): If True, yields file paths relative to the
+            root directory. Defaults to False.
+
+    Yields:
+        Path: Paths of files matching the specified file types.
     """
-    directory = Path(directory)
-    if not directory.is_dir():
-        raise ValueError(
-            log.error(f"Directory {directory} does not exist. Please provide a valid directory.")
-        )
-    assert Path(dataset_info_filename).suffix == ".yaml", (
-        "Currently only YAML files are supported for dataset info file when "
-        f"using `search_file_tree`, got {dataset_info_filename}"
-    )
-
-    if (directory / dataset_info_filename).is_file() and not redo:
-        with open(directory / dataset_info_filename, "r", encoding="utf-8") as file:
-            dataset_info = yaml.load(file, Loader=yaml.FullLoader)
-
-        # Check if the file_shapes key is present in the dataset_info, otherwise redo the search
-        if "file_shapes" in dataset_info:
-            if verbose:
-                log.info(
-                    "Using pregenerated dataset info file: "
-                    f"{log.yellow(directory / dataset_info_filename)} ..."
-                )
-                log.info(f"...for reading file paths in {log.yellow(directory)}")
-            return dataset_info
-
-    if redo and verbose:
-        log.info(f"Overwriting dataset info file: {log.yellow(directory / dataset_info_filename)}")
-
-    # set default file type
-    if filetypes is None:
-        filetypes = _SUPPORTED_IMG_TYPES + _SUPPORTED_VID_TYPES + _SUPPORTED_ZEA_TYPES
-
-    file_paths = []
-
-    if isinstance(filetypes, str):
-        filetypes = [filetypes]
-
-    if hdf5_key_for_length is not None:
-        assert isinstance(hdf5_key_for_length, str), "hdf5_key_for_length must be a string"
-        assert set(filetypes).issubset({".hdf5", ".h5"}), (
-            "hdf5_key_for_length only works with when filetypes is set to "
-            f"`.hdf5` or `.h5`, got {filetypes}"
-        )
-
     # Traverse file tree to index all files from filetypes
     if verbose:
         log.info(f"Searching {log.yellow(directory)} for {filetypes} files...")
+
     for dirpath, _, filenames in os.walk(directory):
         for file in filenames:
             # Append to file_paths if it is a filetype file
             if Path(file).suffix in filetypes:
                 file_path = Path(dirpath) / file
-                …
+                if relative:
+                    file_path = file_path.relative_to(directory)
+                yield file_path
+
+
+def compute_global_palette_by_histogram(pillow_imgs, bits_per_channel=5, palette_size=256):
+    """Computes a global color palette for a sequence of images using histogram analysis.
+
+    Args:
+        pillow_imgs (list): List of pillow images. All images should be in RGB mode.
+        bits_per_channel (int, optional): Number of bits to use per color channel for histogram
+            binning. Can take values between 1 and 7. Defaults to 5.
+        palette_size (int, optional): Number of colors in the resulting palette. Defaults to 256.
+
+    Returns:
+        PIL.Image: A PIL 'P' mode image containing the computed color palette.
+
+    Raises:
+        ValueError: If bits_per_channel or palette_size is outside of range.
+    """
+    if not 1 <= bits_per_channel <= 7:
+        raise ValueError(f"bits_per_channel must be between 1 and 7, got {bits_per_channel}")
+    if not 1 <= palette_size <= 256:
+        raise ValueError(f"palette_size must be between 1 and 256, got {palette_size}")
+
+    # compute number of bins per channel by bitshift
+    bins_per = 1 << bits_per_channel
+    # compute total number of histogram bins for RGB
+    total_bins = bins_per**3
+    # counts per bin in the final histogram
+    counts = np.zeros(total_bins, dtype=np.int64)
+
+    shift = 8 - bits_per_channel
+    # Iterate images, accumulate bin counts
+    for img in pillow_imgs:
+        arr = np.array(img.convert("RGB"), dtype=np.uint8).reshape(-1, 3)
+        # reduce bits, compute bin index
+        r = (arr[:, 0] >> shift).astype(np.int32)
+        g = (arr[:, 1] >> shift).astype(np.int32)
+        b = (arr[:, 2] >> shift).astype(np.int32)
+        idx = (r * bins_per + g) * bins_per + b
+        # accumulate counts
+        bincount = np.bincount(idx, minlength=total_bins)
+        counts += bincount
+
+    # pick top bins
+    top_idx = np.argpartition(-counts, palette_size - 1)[:palette_size]
+
+    # sort top bins by frequency
+    top_idx = top_idx[np.argsort(-counts[top_idx])]
+
+    # convert bin index back to representative RGB (center of bin)
+    bins = np.array(
+        [((i // (bins_per * bins_per)), (i // bins_per) % bins_per, i % bins_per) for i in top_idx]
+    )
 
-    …
-    dataset_info["total_num_frames"] = sum(file_lengths)
+    # expand bin centers back to 8-bit values
+    center = (
+        (bins * (1 << (8 - bits_per_channel)) + (1 << (7 - bits_per_channel))).clip(0, 255)
+    ).astype(np.uint8)
+    palette_colors = center.reshape(-1, 3)  # shape (k, 3)
 
-    …
+    # build a PIL 'P' palette image from these colors
+    pal = np.zeros(768, dtype=np.uint8)  # 256*3 entries
+    pal[: palette_colors.size] = palette_colors.flatten()
+    palette_img = Image.new("P", (1, 1))
+    palette_img.putpalette(pal.tolist())
 
-    return dataset_info
+    return palette_img
 
 
 def matplotlib_figure_to_numpy(fig, **kwargs):
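Two behavioral notes on this file, each with a small sketch. First, `search_file_tree` no longer returns a `dataset_info` dict; it is now a generator, so results are consumed lazily (the directory name below is hypothetical):

    from zea.io_lib import search_file_tree

    # iterate lazily; paths come back relative to the root when relative=True
    for path in search_file_tree("my_dataset/", filetypes=[".hdf5"], relative=True):
        print(path)

    # or materialize everything at once
    paths = list(search_file_tree("my_dataset/", verbose=False))

Second, the new `compute_global_palette_by_histogram` buckets each 8-bit channel into `2**bits_per_channel` bins and maps a bin index back to the bin center, `bin * 2**(8 - bits) + 2**(7 - bits)`. With the default `bits_per_channel=5` that is `8 * bin + 4`, so for example pixel value 10 falls in bin 1 and is represented by 12. A toy round trip with synthetic frames, mirroring the quantize call used in `save_to_gif`:

    import numpy as np
    from PIL import Image
    from zea.io_lib import compute_global_palette_by_histogram

    # two tiny solid-color frames standing in for video frames
    frames = [Image.fromarray(np.full((8, 8, 3), c, dtype=np.uint8)) for c in (10, 200)]
    palette = compute_global_palette_by_histogram(frames, bits_per_channel=5, palette_size=256)

    # apply the shared palette without dithering, as the GIF/MP4 writers do
    quantized = frames[0].convert("RGB").quantize(palette=palette, dither=Image.NONE)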
zea/models/__init__.py
CHANGED

@@ -72,7 +72,7 @@ The following steps are recommended when adding a new model:
 
 1. Create a new module in the :mod:`zea.models` package for your model: ``zea.models.mymodel``.
 2. Add a model class that inherits from :class:`zea.models.base.Model`. For generative models, inherit from :class:`zea.models.generative.GenerativeModel` or :class:`zea.models.deepgenerative.DeepGenerativeModel` as appropriate. Make sure you implement the :meth:`call` method.
-3. Upload the pretrained model weights to `our Hugging Face <https://huggingface.co/…
+3. Upload the pretrained model weights to `our Hugging Face <https://huggingface.co/zeahub>`_. Should be a ``config.json`` and a ``model.weights.h5`` file. See the `Keras documentation <https://keras.io/guides/serialization_and_saving/>`_ for how those can be saved from your model. Simply drag and drop the files onto the Hugging Face website to upload them.
 
 .. tip::
     It is recommended to use the mentioned saving procedure. However, alternate saving methods are also possible; see the :class:`zea.models.echonet.EchoNet` module for an example. You then have to implement a :meth:`custom_load_weights` method in your model class.
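A minimal sketch of producing the two files named in step 3 with plain Keras, using a tiny stand-in model. The exact config layout expected for zea presets is not specified in this diff, so treat this as illustrative and compare against an existing zeahub repository:

    import json
    import keras

    # stand-in model; a real zea model would be saved the same way
    model = keras.Sequential([keras.Input(shape=(16,)), keras.layers.Dense(4)])

    model.save_weights("model.weights.h5")  # Keras 3 requires the .weights.h5 suffix
    with open("config.json", "w", encoding="utf-8") as f:
        json.dump(keras.saving.serialize_keras_object(model), f, indent=2)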
zea/models/diffusion.py
CHANGED

@@ -9,6 +9,10 @@ To try this model, simply load one of the available presets:
 
 >>> model = DiffusionModel.from_preset("diffusion-echonet-dynamic")  # doctest: +SKIP
 
+.. seealso::
+    A tutorial notebook where this model is used:
+    :doc:`../notebooks/models/diffusion_model_example`.
+
 """
 
 import abc
@@ -51,6 +55,9 @@ class DiffusionModel(DeepGenerativeModel):
         name="diffusion_model",
         guidance="dps",
         operator="inpainting",
+        ema_val=0.999,
+        min_t=0.0,
+        max_t=1.0,
         **kwargs,
     ):
         """Initialize a diffusion model.
@@ -58,17 +65,20 @@ class DiffusionModel(DeepGenerativeModel):
         Args:
             input_shape: Shape of the input data. Typically of the form
                 `(height, width, channels)` for images.
-            …
+            input_range: Range of the input data.
+            min_signal_rate: Minimum signal rate for the diffusion schedule.
+            max_signal_rate: Maximum signal rate for the diffusion schedule.
+            network_name: Name of the network architecture to use. Options are
+                "unet_time_conditional" or "dense_time_conditional".
+            network_kwargs: Additional keyword arguments for the network.
             name: Name of the model.
             guidance: Guidance method to use. Can be a string, or dict with
                 "name" and "params" keys. Additionally, can be a `DiffusionGuidance` object.
             operator: Operator to use. Can be a string, or dict with
                 "name" and "params" keys. Additionally, can be an `Operator` object.
-            …
+            ema_val: Exponential moving average value for the network weights.
+            min_t: Minimum diffusion time for sampling during training.
+            max_t: Maximum diffusion time for sampling during training.
             **kwargs: Additional arguments.
         """
         super().__init__(name=name, **kwargs)
@@ -79,10 +89,11 @@ class DiffusionModel(DeepGenerativeModel):
         self.max_signal_rate = max_signal_rate
         self.network_name = network_name
         self.network_kwargs = network_kwargs or {}
+        self.ema_val = ema_val
 
-        # reverse diffusion (i.e. sampling) goes from max_t to min_t
-        self.min_t = …
-        self.max_t = …
+        # reverse diffusion (i.e. sampling) goes from t = max_t to t = min_t
+        self.min_t = min_t
+        self.max_t = max_t
 
         if network_name == "unet_time_conditional":
             self.network = get_time_conditional_unetwork(
@@ -122,8 +133,11 @@ class DiffusionModel(DeepGenerativeModel):
                 "input_range": self.input_range,
                 "min_signal_rate": self.min_signal_rate,
                 "max_signal_rate": self.max_signal_rate,
+                "min_t": self.min_t,
+                "max_t": self.max_t,
                 "network_name": self.network_name,
                 "network_kwargs": self.network_kwargs,
+                "ema_val": self.ema_val,
             }
         )
         return config
@@ -316,8 +330,8 @@ class DiffusionModel(DeepGenerativeModel):
         # Sample uniform random diffusion times in [min_t, max_t]
         diffusion_times = keras.random.uniform(
             shape=[batch_size, *[1] * n_dims],
-            minval=self.…
-            maxval=self.…
+            minval=self.min_t,
+            maxval=self.max_t,
         )
         noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
 
@@ -337,6 +351,43 @@ class DiffusionModel(DeepGenerativeModel):
         self.noise_loss_tracker.update_state(noise_loss)
         self.image_loss_tracker.update_state(image_loss)
 
+        # track the exponential moving averages of weights.
+        # ema_network is used for inference / sampling
+        for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
+            ema_weight.assign(self.ema_val * ema_weight + (1 - self.ema_val) * weight)
+
+        return {m.name: m.result() for m in self.metrics}
+
+    def test_step(self, data):
+        """
+        Custom test step so that validation during model.fit() works on the diffusion model.
+        """
+        batch_size, *input_shape = ops.shape(data)
+        n_dims = len(input_shape)
+
+        noises = keras.random.normal(shape=ops.shape(data))
+
+        # sample uniform random diffusion times
+        diffusion_times = keras.random.uniform(
+            shape=[batch_size, *[1] * n_dims],
+            minval=self.min_t,
+            maxval=self.max_t,
+        )
+        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
+        # mix the images with noises accordingly
+        noisy_images = signal_rates * data + noise_rates * noises
+
+        # use the network to separate noisy images to their components
+        pred_noises, pred_images = self.denoise(
+            noisy_images, noise_rates, signal_rates, training=False
+        )
+
+        noise_loss = self.loss(noises, pred_noises)
+        image_loss = self.loss(data, pred_images)
+
+        self.noise_loss_tracker.update_state(noise_loss)
+        self.image_loss_tracker.update_state(image_loss)
+
         return {m.name: m.result() for m in self.metrics}
 
     def diffusion_schedule(self, diffusion_times):
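The EMA update added to `train_step` keeps a slow-moving copy of the network that is used for sampling. A scalar sketch of the same recursion (not package code) shows the time constant involved:

    ema_val = 0.999
    w, ema_w = 1.0, 0.0
    for _ in range(1000):
        # each step nudges the EMA 0.1% toward the online weight
        ema_w = ema_val * ema_w + (1 - ema_val) * w
    print(round(ema_w, 3))  # ~0.632, i.e. 1 - 0.999**1000

The new `test_step` mirrors `train_step` without the weight updates, which is what lets Keras compute the same noise/image losses on validation data during `model.fit()` or `model.evaluate()`.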
zea/models/lv_segmentation.py
CHANGED

@@ -44,6 +44,8 @@ from zea.models.base import BaseModel
 from zea.models.preset_utils import get_preset_loader, register_presets
 from zea.models.presets import augmented_camus_seg_presets
 
+INFERENCE_SIZE = 256
+
 
 @model_registry(name="augmented_camus_seg")
 class AugmentedCamusSeg(BaseModel):
|