zea 0.0.7__py3-none-any.whl → 0.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. zea/__init__.py +3 -3
  2. zea/agent/masks.py +2 -2
  3. zea/agent/selection.py +3 -3
  4. zea/backend/__init__.py +1 -1
  5. zea/backend/tensorflow/dataloader.py +1 -5
  6. zea/beamform/beamformer.py +4 -2
  7. zea/beamform/pfield.py +2 -2
  8. zea/beamform/pixelgrid.py +1 -1
  9. zea/data/__init__.py +0 -9
  10. zea/data/augmentations.py +222 -29
  11. zea/data/convert/__init__.py +1 -6
  12. zea/data/convert/__main__.py +164 -0
  13. zea/data/convert/camus.py +106 -40
  14. zea/data/convert/echonet.py +184 -83
  15. zea/data/convert/echonetlvh/README.md +2 -3
  16. zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +174 -103
  17. zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
  18. zea/data/convert/echonetlvh/precompute_crop.py +43 -64
  19. zea/data/convert/picmus.py +37 -40
  20. zea/data/convert/utils.py +86 -0
  21. zea/data/convert/verasonics.py +1247 -0
  22. zea/data/data_format.py +124 -6
  23. zea/data/dataloader.py +12 -7
  24. zea/data/datasets.py +109 -70
  25. zea/data/file.py +119 -82
  26. zea/data/file_operations.py +496 -0
  27. zea/data/preset_utils.py +2 -2
  28. zea/display.py +8 -9
  29. zea/doppler.py +5 -5
  30. zea/func/__init__.py +109 -0
  31. zea/{tensor_ops.py → func/tensor.py} +113 -69
  32. zea/func/ultrasound.py +500 -0
  33. zea/internal/_generate_keras_ops.py +5 -5
  34. zea/internal/checks.py +6 -12
  35. zea/internal/operators.py +4 -0
  36. zea/io_lib.py +108 -160
  37. zea/metrics.py +6 -5
  38. zea/models/__init__.py +1 -1
  39. zea/models/diffusion.py +63 -12
  40. zea/models/echonetlvh.py +1 -1
  41. zea/models/gmm.py +1 -1
  42. zea/models/lv_segmentation.py +2 -0
  43. zea/ops/__init__.py +188 -0
  44. zea/ops/base.py +442 -0
  45. zea/{keras_ops.py → ops/keras_ops.py} +2 -2
  46. zea/ops/pipeline.py +1472 -0
  47. zea/ops/tensor.py +356 -0
  48. zea/ops/ultrasound.py +890 -0
  49. zea/probes.py +2 -10
  50. zea/scan.py +35 -28
  51. zea/tools/fit_scan_cone.py +90 -160
  52. zea/tools/selection_tool.py +1 -1
  53. zea/tracking/__init__.py +16 -0
  54. zea/tracking/base.py +94 -0
  55. zea/tracking/lucas_kanade.py +474 -0
  56. zea/tracking/segmentation.py +110 -0
  57. zea/utils.py +11 -2
  58. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/METADATA +5 -1
  59. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/RECORD +62 -48
  60. zea/data/convert/matlab.py +0 -1237
  61. zea/ops.py +0 -3294
  62. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/WHEEL +0 -0
  63. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/entry_points.txt +0 -0
  64. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/licenses/LICENSE +0 -0
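The rename and add/remove entries above amount to a module reorganization. A sketch of the import-path changes implied by the file list, assuming the public names themselves were kept:

```python
# Import-path changes implied by the file list above (0.0.7 -> 0.0.9).
# Only the modules moved; the names are assumed unchanged.
# zea.tensor_ops          -> zea.func.tensor
# zea.keras_ops           -> zea.ops.keras_ops
# zea.ops (single module) -> zea.ops package (base, pipeline, tensor, ultrasound, ...)
from zea.func import tensor   # was: from zea import tensor_ops
from zea.ops import pipeline  # was part of the monolithic zea/ops.py
```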
zea/io_lib.py CHANGED
@@ -4,20 +4,17 @@ Use to quickly read and write files or interact with file system.
 """
 
 import functools
-import multiprocessing
 import os
 import time
 from io import BytesIO
 from pathlib import Path
+from typing import Generator
 
 import imageio
 import numpy as np
-import tqdm
-import yaml
 from PIL import Image, ImageSequence
 
 from zea import log
-from zea.data.file import File
 
 _SUPPORTED_VID_TYPES = [".mp4", ".gif"]
 _SUPPORTED_IMG_TYPES = [".jpg", ".png", ".JPEG", ".PNG", ".jpeg"]
@@ -123,7 +120,7 @@ def save_video(images, filename, fps=20, **kwargs):
         filename (str or Path): Filename to which data should be written.
         fps (int): Frames per second of rendered format.
         **kwargs: Additional keyword arguments passed to the specific save function.
-            For GIF files, this includes `shared_color_palette` (bool).
+            For GIF and mp4 files, this includes `shared_color_palette` (bool).
 
     Raises:
         ValueError: If the file extension is not supported.
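For illustration, a minimal sketch of forwarding this kwarg through `save_video`; the frame data and file name are made up, and MP4 writing additionally requires imageio-ffmpeg to be installed:

```python
import numpy as np
from pathlib import Path
from zea.io_lib import save_video

# 30 random RGB frames (uint8); values are illustrative only.
frames = np.random.randint(0, 256, size=(30, 64, 64, 3), dtype=np.uint8)
# The kwarg is forwarded to save_to_mp4 (new in 0.0.9) or save_to_gif.
save_video(frames, Path("example.mp4"), fps=20, shared_color_palette=True)
```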
@@ -133,14 +130,14 @@ def save_video(images, filename, fps=20, **kwargs):
     ext = filename.suffix.lower()
 
     if ext == ".mp4":
-        return save_to_mp4(images, filename, fps=fps)
+        return save_to_mp4(images, filename, fps=fps, **kwargs)
     elif ext == ".gif":
         return save_to_gif(images, filename, fps=fps, **kwargs)
     else:
         raise ValueError(f"Unsupported file extension: {ext}")
 
 
-def save_to_gif(images, filename, fps=20, shared_color_palette=False):
+def save_to_gif(images, filename, fps=20, shared_color_palette=True):
     """Saves a sequence of images to a GIF file.
 
     .. note::
@@ -156,9 +153,9 @@ def save_to_gif(images, filename, fps=20, shared_color_palette=False):
         fps (int): Frames per second of rendered format.
         shared_color_palette (bool, optional): If True, creates a global
            color palette across all frames, ensuring consistent colors
-            throughout the GIF. Defaults to False, which is default behavior
-            of PIL.Image.save. Note: True can cause slow saving for longer
-            sequences, and also lead to larger file sizes in some cases.
+            throughout the GIF. Defaults to True, which is default behavior
+            of PIL.Image.save. Note: True increases speed and shrinks file
+            size for longer sequences.
 
     """
     images = preprocess_for_saving(images)
@@ -173,15 +170,8 @@ def save_to_gif(images, filename, fps=20, shared_color_palette=False):
 
     if shared_color_palette:
         # Apply the same palette to all frames without dithering for consistent color mapping
-        # Convert all images to RGB and combine their colors for palette generation
-        all_colors = np.vstack([np.array(img.convert("RGB")).reshape(-1, 3) for img in pillow_imgs])
-        combined_image = Image.fromarray(all_colors.reshape(-1, 1, 3))
-
-        # Generate palette from all frames
-        global_palette = combined_image.quantize(
-            colors=256,
-            method=Image.MEDIANCUT,
-            kmeans=1,
+        global_palette = compute_global_palette_by_histogram(
+            pillow_imgs, bits_per_channel=5, palette_size=256
         )
 
         # Apply the same palette to all frames without dithering
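A minimal usage sketch of the new helper (defined later in this diff) outside of `save_to_gif`; the frames here are synthetic and only illustrate the call pattern:

```python
import numpy as np
from PIL import Image
from zea.io_lib import compute_global_palette_by_histogram

# Four synthetic RGB frames; with bits_per_channel=5 each pixel maps to one of
# 32**3 = 32768 histogram bins, and the 256 most frequent bins form the palette.
frames = [
    Image.fromarray(np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8))
    for _ in range(4)
]
palette = compute_global_palette_by_histogram(frames, bits_per_channel=5, palette_size=256)
quantized = [img.quantize(palette=palette, dither=Image.NONE) for img in frames]
```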
@@ -208,7 +198,7 @@ def save_to_gif(images, filename, fps=20, shared_color_palette=False):
     log.success(f"Successfully saved GIF to -> {log.yellow(filename)}")
 
 
-def save_to_mp4(images, filename, fps=20):
+def save_to_mp4(images, filename, fps=20, shared_color_palette=False):
     """Saves a sequence of images to an MP4 file.
 
     .. note::
@@ -222,6 +212,10 @@ def save_to_mp4(images, filename, fps=20):
             which is then converted to RGB. Images should be uint8.
         filename (str or Path): Filename to which data should be written.
         fps (int): Frames per second of rendered format.
+        shared_color_palette (bool, optional): If True, creates a global
+            color palette across all frames, ensuring consistent colors
+            throughout the MP4. Note: True can cause slow saving for longer
+            sequences.
 
     Raises:
         ImportError: If imageio-ffmpeg is not installed.
@@ -249,166 +243,120 @@
         ) from exc
 
     try:
-        for image in images:
-            writer.append_data(image)
+        if shared_color_palette:
+            pillow_imgs = [Image.fromarray(img) for img in images]
+            global_palette = compute_global_palette_by_histogram(
+                pillow_imgs, bits_per_channel=5, palette_size=256
+            )
+            for img in pillow_imgs:
+                paletted_img = img.convert("RGB").quantize(
+                    palette=global_palette,
+                    dither=Image.NONE,
+                )
+                writer.append_data(np.array(paletted_img.convert("RGB")))
+        else:
+            # Write from numpy arrays directly
+            for image in images:
+                writer.append_data(image)
     finally:
         writer.close()
 
     return log.success(f"Successfully saved MP4 to -> {filename}")
 
 
-def search_file_tree(
-    directory,
-    filetypes=None,
-    write=True,
-    dataset_info_filename="dataset_info.yaml",
-    hdf5_key_for_length=None,
-    redo=False,
-    parallel=False,
-    verbose=True,
-):
-    """Lists all files in directory and sub-directories.
-
-    If dataset_info.yaml is detected in the directory, that file is read and used
-    to deduce the file paths. If not, the file paths are searched for in the
-    directory and written to a dataset_info.yaml file.
+def search_file_tree(directory, filetypes=None, verbose=True, relative=False) -> Generator:
+    """Traverse a directory tree and yield file paths matching specified file types.
 
     Args:
-        directory (str): Path to base directory to start file search.
-        filetypes (str or list, optional): Filetypes to look for in directory.
-            Defaults to image types (.png etc.). Make sure to include the dot.
-        write (bool, optional): Whether to write to dataset_info.yaml file.
-            Defaults to True. If False, the file paths are not written to file
-            and simply returned.
-        dataset_info_filename (str, optional): Name of dataset info file.
-            Defaults to "dataset_info.yaml", but can be changed to any name.
-        hdf5_key_for_length (str, optional): Key to use for getting length of hdf5 files.
-            Defaults to None. If set, the number of frames in each hdf5 file is
-            calculated and stored in the dataset_info.yaml file. This is extra
-            functionality of ``search_file_tree`` and only works with hdf5 files.
-        redo (bool, optional): Whether to redo the search and overwrite the dataset_info.yaml file.
-        parallel (bool, optional): Whether to use multiprocessing for hdf5 shape reading.
-        verbose (bool, optional): Whether to print progress and info.
-
-    Returns:
-        dict: Dictionary containing file paths and total number of files.
-            Has the following structure:
-
-            .. code-block:: python
-
-                {
-                    "file_paths": list of file paths,
-                    "total_num_files": total number of files,
-                    "file_lengths": list of number of frames in each hdf5 file,
-                    "file_shapes": list of shapes of each image file,
-                    "total_num_frames": total number of frames in all hdf5 files
-                }
-
+        directory (str or Path): The root directory to start the search.
+        filetypes (list of str, optional): List of file extensions to match.
+            If None, file types supported by `zea` are matched. Defaults to None.
+        verbose (bool, optional): If True, logs the search process. Defaults to True.
+        relative (bool, optional): If True, yields file paths relative to the
+            root directory. Defaults to False.
+
+    Yields:
+        Path: Paths of files matching the specified file types.
     """
-    directory = Path(directory)
-    if not directory.is_dir():
-        raise ValueError(
-            log.error(f"Directory {directory} does not exist. Please provide a valid directory.")
-        )
-    assert Path(dataset_info_filename).suffix == ".yaml", (
-        "Currently only YAML files are supported for dataset info file when "
-        f"using `search_file_tree`, got {dataset_info_filename}"
-    )
-
-    if (directory / dataset_info_filename).is_file() and not redo:
-        with open(directory / dataset_info_filename, "r", encoding="utf-8") as file:
-            dataset_info = yaml.load(file, Loader=yaml.FullLoader)
-
-        # Check if the file_shapes key is present in the dataset_info, otherwise redo the search
-        if "file_shapes" in dataset_info:
-            if verbose:
-                log.info(
-                    "Using pregenerated dataset info file: "
-                    f"{log.yellow(directory / dataset_info_filename)} ..."
-                )
-                log.info(f"...for reading file paths in {log.yellow(directory)}")
-            return dataset_info
-
-    if redo and verbose:
-        log.info(f"Overwriting dataset info file: {log.yellow(directory / dataset_info_filename)}")
-
-    # set default file type
-    if filetypes is None:
-        filetypes = _SUPPORTED_IMG_TYPES + _SUPPORTED_VID_TYPES + _SUPPORTED_ZEA_TYPES
-
-    file_paths = []
-
-    if isinstance(filetypes, str):
-        filetypes = [filetypes]
-
-    if hdf5_key_for_length is not None:
-        assert isinstance(hdf5_key_for_length, str), "hdf5_key_for_length must be a string"
-        assert set(filetypes).issubset({".hdf5", ".h5"}), (
-            "hdf5_key_for_length only works with when filetypes is set to "
-            f"`.hdf5` or `.h5`, got {filetypes}"
-        )
-
     # Traverse file tree to index all files from filetypes
     if verbose:
         log.info(f"Searching {log.yellow(directory)} for {filetypes} files...")
+
     for dirpath, _, filenames in os.walk(directory):
         for file in filenames:
             # Append to file_paths if it is a filetype file
             if Path(file).suffix in filetypes:
                 file_path = Path(dirpath) / file
-                file_path = file_path.relative_to(directory)
-                file_paths.append(str(file_path))
-
-    if hdf5_key_for_length is not None:
-        # using multiprocessing to speed up reading hdf5 files
-        # and getting the number of frames in each file
-        if verbose:
-            log.info("Getting number of frames in each hdf5 file...")
-
-        get_shape_partial = functools.partial(File.get_shape, key=hdf5_key_for_length)
-        # make sure to call search_file_tree from within a function
-        # or use if __name__ == "__main__":
-        # to avoid freezing the main process
-        absolute_file_paths = [directory / file for file in file_paths]
-        if parallel:
-            with multiprocessing.Pool() as pool:
-                file_shapes = list(
-                    tqdm.tqdm(
-                        pool.imap(
-                            get_shape_partial,
-                            absolute_file_paths,
-                        ),
-                        total=len(file_paths),
-                        desc="Getting number of frames in each hdf5 file",
-                        disable=not verbose,
-                    )
-                )
-        else:
-            file_shapes = []
-            for file_path in tqdm.tqdm(
-                absolute_file_paths,
-                desc="Getting number of frames in each hdf5 file",
-                disable=not verbose,
-            ):
-                file_shapes.append(File.get_shape(file_path, hdf5_key_for_length))
-
-    assert len(file_paths) > 0, f"No image files were found in: {directory}"
-    if verbose:
-        log.info(f"Found {len(file_paths)} image files in {log.yellow(directory)}")
-        log.info(f"Writing dataset info to {log.yellow(directory / dataset_info_filename)}")
+                if relative:
+                    file_path = file_path.relative_to(directory)
+                yield file_path
+
+
+def compute_global_palette_by_histogram(pillow_imgs, bits_per_channel=5, palette_size=256):
+    """Computes a global color palette for a sequence of images using histogram analysis.
+
+    Args:
+        pillow_imgs (list): List of pillow images. All images should be in RGB mode.
+        bits_per_channel (int, optional): Number of bits to use per color channel for histogram
+            binning. Can take values between 1 and 7. Defaults to 5.
+        palette_size (int, optional): Number of colors in the resulting palette. Defaults to 256.
+
+    Returns:
+        PIL.Image: A PIL 'P' mode image containing the computed color palette.
+
+    Raises:
+        ValueError: If bits_per_channel or palette_size is outside of range.
+    """
+
+    if not 1 <= bits_per_channel <= 7:
+        raise ValueError(f"bits_per_channel must be between 1 and 7, got {bits_per_channel}")
+    if not 1 <= palette_size <= 256:
+        raise ValueError(f"palette_size must be between 1 and 256, got {palette_size}")
+
+    # compute number of bins per channel by bitshift
+    bins_per = 1 << bits_per_channel
+    # compute total number of histogram bins for RGB
+    total_bins = bins_per**3
+    # counts per bin in the final histogram
+    counts = np.zeros(total_bins, dtype=np.int64)
+
+    shift = 8 - bits_per_channel
+    # Iterate images, accumulate bin counts
+    for img in pillow_imgs:
+        arr = np.array(img.convert("RGB"), dtype=np.uint8).reshape(-1, 3)
+        # reduce bits, compute bin index
+        r = (arr[:, 0] >> shift).astype(np.int32)
+        g = (arr[:, 1] >> shift).astype(np.int32)
+        b = (arr[:, 2] >> shift).astype(np.int32)
+        idx = (r * bins_per + g) * bins_per + b
+        # accumulate counts
+        bincount = np.bincount(idx, minlength=total_bins)
+        counts += bincount
+
+    # pick top bins
+    top_idx = np.argpartition(-counts, palette_size - 1)[:palette_size]
+
+    # sort top bins by frequency
+    top_idx = top_idx[np.argsort(-counts[top_idx])]
+
+    # convert bin index back to representative RGB (center of bin)
+    bins = np.array(
+        [((i // (bins_per * bins_per)), (i // bins_per) % bins_per, i % bins_per) for i in top_idx]
+    )
 
-    dataset_info = {"file_paths": file_paths, "total_num_files": len(file_paths)}
-    if len(file_shapes) > 0:
-        dataset_info["file_shapes"] = file_shapes
-        file_lengths = [shape[0] for shape in file_shapes]
-        dataset_info["file_lengths"] = file_lengths
-        dataset_info["total_num_frames"] = sum(file_lengths)
+    # expand bin centers back to 8-bit values
+    center = (
+        (bins * (1 << (8 - bits_per_channel)) + (1 << (7 - bits_per_channel))).clip(0, 255)
+    ).astype(np.uint8)
+    palette_colors = center.reshape(-1, 3)  # shape (k,3)
 
-    if write:
-        with open(directory / dataset_info_filename, "w", encoding="utf-8") as file:
-            yaml.dump(dataset_info, file)
+    # build a PIL 'P' palette image from these colors
+    pal = np.zeros(768, dtype=np.uint8)  # 256*3 entries
+    pal[: palette_colors.size] = palette_colors.flatten()
+    palette_img = Image.new("P", (1, 1))
+    palette_img.putpalette(pal.tolist())
 
-    return dataset_info
+    return palette_img
 
 
 def matplotlib_figure_to_numpy(fig, **kwargs):
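Note that `search_file_tree` is now a generator: the 0.0.7 version returned a dict (and could cache it in dataset_info.yaml), while the 0.0.9 version lazily yields matching paths. A small usage sketch with an illustrative directory and file types:

```python
from zea.io_lib import search_file_tree

# Lazily iterate over all HDF5 files below a (made-up) dataset root;
# relative=True yields paths relative to that root instead of absolute ones.
for path in search_file_tree("datasets/echonet", filetypes=[".hdf5", ".h5"], relative=True):
    print(path)
```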
zea/metrics.py CHANGED
@@ -7,12 +7,13 @@ import keras
 import numpy as np
 from keras import ops
 
-from zea import log, tensor_ops
+from zea import log
 from zea.backend import func_on_device
+from zea.func import tensor
+from zea.func.tensor import translate
 from zea.internal.registry import metrics_registry
 from zea.internal.utils import reduce_to_signature
 from zea.models.lpips import LPIPS
-from zea.tensor_ops import translate
 
 
 def get_metric(name, **kwargs):
@@ -197,7 +198,7 @@ def ssim(
 
     # Construct a 1D convolution.
    def filter_fn_1(z):
-        return tensor_ops.correlate(z, ops.flip(filt), mode="valid")
+        return tensor.correlate(z, ops.flip(filt), mode="valid")
 
     # Apply the vectorized filter along the y axis.
     def filter_fn_y(z):
@@ -300,7 +301,7 @@ def get_lpips(image_range, batch_size=None, clip=False):
 
         imgs = ops.stack([img1, img2], axis=-1)
         n_batch_dims = ops.ndim(img1) - 3
-        return tensor_ops.func_with_one_batch_dim(
+        return tensor.func_with_one_batch_dim(
             unstack_lpips, imgs, n_batch_dims, batch_size=batch_size
         )
@@ -372,7 +373,7 @@ class Metrics:
         # Because most metric functions do not support batching, we vmap over the batch axes.
         metric_fn = fun
         for ax in reversed(batch_axes):
-            metric_fn = tensor_ops.vmap(metric_fn, in_axes=ax, _use_torch_vmap=True)
+            metric_fn = tensor.vmap(metric_fn, in_axes=ax, _use_torch_vmap=True)
 
         out = func_on_device(metric_fn, device, y_true, y_pred)
 
zea/models/__init__.py CHANGED
@@ -72,7 +72,7 @@ The following steps are recommended when adding a new model:
 
 1. Create a new module in the :mod:`zea.models` package for your model: ``zea.models.mymodel``.
 2. Add a model class that inherits from :class:`zea.models.base.Model`. For generative models, inherit from :class:`zea.models.generative.GenerativeModel` or :class:`zea.models.deepgenerative.DeepGenerativeModel` as appropriate. Make sure you implement the :meth:`call` method.
-3. Upload the pretrained model weights to `our Hugging Face <https://huggingface.co/zea>`_. Should be a ``config.json`` and a ``model.weights.h5`` file. See `Keras documentation <https://keras.io/guides/serialization_and_saving/>`_ how those can be saved from your model. Simply drag and drop the files to the Hugging Face website to upload them.
+3. Upload the pretrained model weights to `our Hugging Face <https://huggingface.co/zeahub>`_. Should be a ``config.json`` and a ``model.weights.h5`` file. See `Keras documentation <https://keras.io/guides/serialization_and_saving/>`_ how those can be saved from your model. Simply drag and drop the files to the Hugging Face website to upload them.
 
 .. tip::
     It is recommended to use the mentioned saving procedure. However, alternate saving methods are also possible, see the :class:`zea.models.echonet.EchoNet` module for an example. You do now have to implement a :meth:`custom_load_weights` method in your model class.
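One plausible way to produce the `config.json` and `model.weights.h5` files mentioned in step 3, assuming the model follows the standard Keras `get_config()` / `save_weights()` conventions; the toy model below only stands in for your trained model:

```python
import json
import keras

# Toy stand-in for your trained model (illustrative only).
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])

with open("config.json", "w") as f:
    json.dump(model.get_config(), f)    # architecture / hyperparameters
model.save_weights("model.weights.h5")  # Keras requires the .weights.h5 suffix
```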
zea/models/diffusion.py CHANGED
@@ -9,6 +9,10 @@ To try this model, simply load one of the available presets:
 
 >>> model = DiffusionModel.from_preset("diffusion-echonet-dynamic") # doctest: +SKIP
 
+.. seealso::
+    A tutorial notebook where this model is used:
+    :doc:`../notebooks/models/diffusion_model_example`.
+
 """
 
 import abc
@@ -19,6 +23,7 @@ from keras import ops
 
 from zea.backend import _import_tf, jit
 from zea.backend.autograd import AutoGrad
+from zea.func.tensor import L2, fori_loop, split_seed
 from zea.internal.core import Object
 from zea.internal.operators import Operator
 from zea.internal.registry import diffusion_guidance_registry, model_registry, operator_registry
@@ -29,7 +34,6 @@ from zea.models.preset_utils import register_presets
 from zea.models.presets import diffusion_model_presets
 from zea.models.unet import get_time_conditional_unetwork
 from zea.models.utils import LossTrackerWrapper
-from zea.tensor_ops import L2, fori_loop, split_seed
 
 tf = _import_tf()
 
@@ -51,6 +55,9 @@ class DiffusionModel(DeepGenerativeModel):
         name="diffusion_model",
         guidance="dps",
         operator="inpainting",
+        ema_val=0.999,
+        min_t=0.0,
+        max_t=1.0,
         **kwargs,
     ):
         """Initialize a diffusion model.
@@ -58,17 +65,20 @@ class DiffusionModel(DeepGenerativeModel):
         Args:
             input_shape: Shape of the input data. Typically of the form
                 `(height, width, channels)` for images.
-            widths: List of filter widths for the UNet.
-            block_depth: Number of residual blocks in each UNet block.
-            timesteps: Number of diffusion timesteps.
-            beta_start: Initial noise schedule value.
-            beta_end: Final noise schedule value.
+            input_range: Range of the input data.
+            min_signal_rate: Minimum signal rate for the diffusion schedule.
+            max_signal_rate: Maximum signal rate for the diffusion schedule.
+            network_name: Name of the network architecture to use. Options are
+                "unet_time_conditional" or "dense_time_conditional".
+            network_kwargs: Additional keyword arguments for the network.
             name: Name of the model.
             guidance: Guidance method to use. Can be a string, or dict with
                 "name" and "params" keys. Additionally, can be a `DiffusionGuidance` object.
             operator: Operator to use. Can be a string, or dict with
                 "name" and "params" keys. Additionally, can be a `Operator` object.
-
+            ema_val: Exponential moving average value for the network weights.
+            min_t: Minimum diffusion time for sampling during training.
+            max_t: Maximum diffusion time for sampling during training.
             **kwargs: Additional arguments.
         """
         super().__init__(name=name, **kwargs)
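A hypothetical instantiation sketch showing where the new keywords fit; the shape, range, guidance and operator values below are illustrative, and only `ema_val`, `min_t` and `max_t` are new in 0.0.9:

```python
from zea.models.diffusion import DiffusionModel

# Illustrative values; see the docstring above for the meaning of each argument.
model = DiffusionModel(
    input_shape=(112, 112, 1),
    input_range=(0, 255),
    guidance="dps",
    operator="inpainting",
    ema_val=0.999,  # EMA decay used to update the sampling (ema) network
    min_t=0.0,      # diffusion-time interval sampled during training
    max_t=1.0,
)
```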
@@ -79,10 +89,11 @@ class DiffusionModel(DeepGenerativeModel):
         self.max_signal_rate = max_signal_rate
         self.network_name = network_name
         self.network_kwargs = network_kwargs or {}
+        self.ema_val = ema_val
 
-        # reverse diffusion (i.e. sampling) goes from max_t to min_t
-        self.min_t = 0.0
-        self.max_t = 1.0
+        # reverse diffusion (i.e. sampling) goes from t = max_t to t = min_t
+        self.min_t = min_t
+        self.max_t = max_t
 
         if network_name == "unet_time_conditional":
             self.network = get_time_conditional_unetwork(
@@ -122,8 +133,11 @@ class DiffusionModel(DeepGenerativeModel):
                 "input_range": self.input_range,
                 "min_signal_rate": self.min_signal_rate,
                 "max_signal_rate": self.max_signal_rate,
+                "min_t": self.min_t,
+                "max_t": self.max_t,
                 "network_name": self.network_name,
                 "network_kwargs": self.network_kwargs,
+                "ema_val": self.ema_val,
             }
         )
         return config
@@ -316,8 +330,8 @@
         # Sample uniform random diffusion times in [min_t, max_t]
         diffusion_times = keras.random.uniform(
             shape=[batch_size, *[1] * n_dims],
-            minval=self.min_signal_rate,
-            maxval=self.max_signal_rate,
+            minval=self.min_t,
+            maxval=self.max_t,
         )
         noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
 
@@ -337,6 +351,43 @@
         self.noise_loss_tracker.update_state(noise_loss)
         self.image_loss_tracker.update_state(image_loss)
 
+        # track the exponential moving averages of weights.
+        # ema_network is used for inference / sampling
+        for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
+            ema_weight.assign(self.ema_val * ema_weight + (1 - self.ema_val) * weight)
+
+        return {m.name: m.result() for m in self.metrics}
+
+    def test_step(self, data):
+        """
+        Custom test step so we can call model.fit() on the diffusion model.
+        """
+        batch_size, *input_shape = ops.shape(data)
+        n_dims = len(input_shape)
+
+        noises = keras.random.normal(shape=ops.shape(data))
+
+        # sample uniform random diffusion times
+        diffusion_times = keras.random.uniform(
+            shape=[batch_size, *[1] * n_dims],
+            minval=self.min_t,
+            maxval=self.max_t,
+        )
+        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
+        # mix the images with noises accordingly
+        noisy_images = signal_rates * data + noise_rates * noises
+
+        # use the network to separate noisy images to their components
+        pred_noises, pred_images = self.denoise(
+            noisy_images, noise_rates, signal_rates, training=False
+        )
+
+        noise_loss = self.loss(noises, pred_noises)
+        image_loss = self.loss(data, pred_images)
+
+        self.noise_loss_tracker.update_state(noise_loss)
+        self.image_loss_tracker.update_state(image_loss)
+
         return {m.name: m.result() for m in self.metrics}
 
     def diffusion_schedule(self, diffusion_times):
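For reference, a toy illustration of the EMA update added to `train_step` (numbers are made up): with `ema_val=0.999` each step nudges the averaged weight only 0.1% toward the current weight, so `ema_network` tracks a slowly moving average of the trained weights and is the copy used for sampling.

```python
# Toy version of: ema_weight.assign(ema_val * ema_weight + (1 - ema_val) * weight)
ema_val = 0.999
weight, ema_weight = 0.80, 0.50
ema_weight = ema_val * ema_weight + (1 - ema_val) * weight
print(round(ema_weight, 4))  # 0.5003
```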
zea/models/echonetlvh.py CHANGED
@@ -25,12 +25,12 @@ To try this model, simply load one of the available presets:
 import numpy as np
 from keras import ops
 
+from zea.func.tensor import translate
 from zea.internal.registry import model_registry
 from zea.models.base import BaseModel
 from zea.models.deeplabv3 import DeeplabV3Plus
 from zea.models.preset_utils import register_presets
 from zea.models.presets import echonet_lvh_presets
-from zea.tensor_ops import translate
 
 
 @model_registry(name="echonetlvh")
zea/models/gmm.py CHANGED
@@ -4,8 +4,8 @@ import keras
 import numpy as np
 from keras import ops
 
+from zea.func.tensor import linear_sum_assignment
 from zea.models.generative import GenerativeModel
-from zea.tensor_ops import linear_sum_assignment
 
 
 class GaussianMixtureModel(GenerativeModel):
zea/models/lv_segmentation.py CHANGED
@@ -44,6 +44,8 @@ from zea.models.base import BaseModel
 from zea.models.preset_utils import get_preset_loader, register_presets
 from zea.models.presets import augmented_camus_seg_presets
 
+INFERENCE_SIZE = 256
+
 
 @model_registry(name="augmented_camus_seg")
 class AugmentedCamusSeg(BaseModel):