monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2429__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/detection/utils/anchor_utils.py +2 -2
  4. monai/apps/pathology/transforms/post/array.py +1 -0
  5. monai/bundle/scripts.py +106 -8
  6. monai/bundle/utils.py +1 -0
  7. monai/data/dataset_summary.py +1 -0
  8. monai/data/utils.py +9 -6
  9. monai/data/wsi_reader.py +2 -2
  10. monai/engines/__init__.py +3 -1
  11. monai/engines/trainer.py +281 -2
  12. monai/engines/utils.py +76 -1
  13. monai/handlers/mlflow_handler.py +21 -4
  14. monai/inferers/__init__.py +5 -0
  15. monai/inferers/inferer.py +1279 -1
  16. monai/networks/blocks/__init__.py +3 -0
  17. monai/networks/blocks/attention_utils.py +128 -0
  18. monai/networks/blocks/crossattention.py +166 -0
  19. monai/networks/blocks/rel_pos_embedding.py +56 -0
  20. monai/networks/blocks/selfattention.py +72 -5
  21. monai/networks/blocks/spade_norm.py +95 -0
  22. monai/networks/blocks/spatialattention.py +82 -0
  23. monai/networks/blocks/transformerblock.py +24 -4
  24. monai/networks/blocks/upsample.py +22 -10
  25. monai/networks/layers/__init__.py +2 -1
  26. monai/networks/layers/factories.py +12 -1
  27. monai/networks/layers/utils.py +14 -1
  28. monai/networks/layers/vector_quantizer.py +233 -0
  29. monai/networks/nets/__init__.py +9 -0
  30. monai/networks/nets/autoencoderkl.py +702 -0
  31. monai/networks/nets/controlnet.py +465 -0
  32. monai/networks/nets/diffusion_model_unet.py +1913 -0
  33. monai/networks/nets/patchgan_discriminator.py +230 -0
  34. monai/networks/nets/quicknat.py +2 -0
  35. monai/networks/nets/resnet.py +3 -4
  36. monai/networks/nets/spade_autoencoderkl.py +480 -0
  37. monai/networks/nets/spade_diffusion_model_unet.py +934 -0
  38. monai/networks/nets/spade_network.py +435 -0
  39. monai/networks/nets/swin_unetr.py +4 -3
  40. monai/networks/nets/transformer.py +157 -0
  41. monai/networks/nets/vqvae.py +472 -0
  42. monai/networks/schedulers/__init__.py +17 -0
  43. monai/networks/schedulers/ddim.py +294 -0
  44. monai/networks/schedulers/ddpm.py +250 -0
  45. monai/networks/schedulers/pndm.py +316 -0
  46. monai/networks/schedulers/scheduler.py +205 -0
  47. monai/networks/utils.py +22 -0
  48. monai/transforms/regularization/array.py +4 -0
  49. monai/transforms/utils_create_transform_ims.py +2 -4
  50. monai/utils/__init__.py +1 -0
  51. monai/utils/misc.py +5 -4
  52. monai/utils/ordering.py +207 -0
  53. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2429.dist-info}/METADATA +1 -1
  54. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2429.dist-info}/RECORD +57 -36
  55. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2429.dist-info}/WHEEL +1 -1
  56. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2429.dist-info}/LICENSE +0 -0
  57. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2429.dist-info}/top_level.txt +0 -0
monai/engines/utils.py CHANGED
@@ -13,9 +13,10 @@ from __future__ import annotations
13
13
 
14
14
  from abc import ABC, abstractmethod
15
15
  from collections.abc import Callable, Sequence
16
- from typing import TYPE_CHECKING, Any, cast
16
+ from typing import TYPE_CHECKING, Any, Mapping, cast
17
17
 
18
18
  import torch
19
+ import torch.nn as nn
19
20
 
20
21
  from monai.config import IgniteInfo
21
22
  from monai.transforms import apply_transform
@@ -36,6 +37,8 @@ __all__ = [
36
37
  "PrepareBatch",
37
38
  "PrepareBatchDefault",
38
39
  "PrepareBatchExtraInput",
40
+ "DiffusionPrepareBatch",
41
+ "VPredictionPrepareBatch",
39
42
  "default_make_latent",
40
43
  "engine_apply_transform",
41
44
  "default_metric_cmp_fn",
@@ -238,6 +241,78 @@ class PrepareBatchExtraInput(PrepareBatch):
238
241
  return cast(torch.Tensor, image), cast(torch.Tensor, label), tuple(args_), kwargs_
239
242
 
240
243
 
244
+ class DiffusionPrepareBatch(PrepareBatch):
245
+ """
246
+ This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training.
247
+
248
+ Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and
249
+ return the image and noise field as the image/target pair plus the noise field the kwargs under the key "noise".
250
+ This assumes the inferer being used in conjunction with this class expects a "noise" parameter to be provided.
251
+
252
+ If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition
253
+ field to be passed to the inferer. This will appear in the keyword arguments under the key "condition".
254
+
255
+ """
256
+
257
+ def __init__(self, num_train_timesteps: int, condition_name: str | None = None) -> None:
258
+ self.condition_name = condition_name
259
+ self.num_train_timesteps = num_train_timesteps
260
+
261
+ def get_noise(self, images: torch.Tensor) -> torch.Tensor:
262
+ """Returns the noise tensor for input tensor `images`, override this for different noise distributions."""
263
+ return torch.randn_like(images)
264
+
265
+ def get_timesteps(self, images: torch.Tensor) -> torch.Tensor:
266
+ """Get a timestep, by default this is a random integer between 0 and `self.num_train_timesteps`."""
267
+ return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long()
268
+
269
+ def get_target(self, images: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
270
+ """Return the target for the loss function, this is the `noise` value by default."""
271
+ return noise
272
+
273
+ def __call__(
274
+ self,
275
+ batchdata: dict[str, torch.Tensor],
276
+ device: str | torch.device | None = None,
277
+ non_blocking: bool = False,
278
+ **kwargs: Any,
279
+ ) -> tuple[torch.Tensor, torch.Tensor, tuple, dict]:
280
+ images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs)
281
+ noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs)
282
+ timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs)
283
+
284
+ target = self.get_target(images, noise, timesteps).to(device, non_blocking=non_blocking, **kwargs)
285
+ infer_kwargs = {"noise": noise, "timesteps": timesteps}
286
+
287
+ if self.condition_name is not None and isinstance(batchdata, Mapping):
288
+ infer_kwargs["condition"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs)
289
+
290
+ # return input, target, arguments, and keyword arguments where noise is the target and also a keyword value
291
+ return images, target, (), infer_kwargs
292
+
293
+
294
+ class VPredictionPrepareBatch(DiffusionPrepareBatch):
295
+ """
296
+ This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training.
297
+
298
+ Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and
299
+ from this compute the velocity using the provided scheduler. This value is used as the target in place of the
300
+ noise field itself although the noise is field is in the kwargs under the key "noise". This assumes the inferer
301
+ being used in conjunction with this class expects a "noise" parameter to be provided.
302
+
303
+ If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition
304
+ field to be passed to the inferer. This will appear in the keyword arguments under the key "condition".
305
+
306
+ """
307
+
308
+ def __init__(self, scheduler: nn.Module, num_train_timesteps: int, condition_name: str | None = None) -> None:
309
+ super().__init__(num_train_timesteps=num_train_timesteps, condition_name=condition_name)
310
+ self.scheduler = scheduler
311
+
312
+ def get_target(self, images, noise, timesteps):
313
+ return self.scheduler.get_velocity(images, noise, timesteps)
314
+
315
+
241
316
  def default_make_latent(
242
317
  num_latents: int,
243
318
  latent_size: int,
monai/handlers/mlflow_handler.py CHANGED
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Any
21
21
  import torch
22
22
  from torch.utils.data import Dataset
23
23
 
24
+ from monai.apps.utils import get_logger
24
25
  from monai.config import IgniteInfo
25
26
  from monai.utils import CommonKeys, ensure_tuple, min_version, optional_import
26
27
 
@@ -29,6 +30,9 @@ mlflow, _ = optional_import("mlflow", descriptor="Please install mlflow before u
29
30
  mlflow.entities, _ = optional_import(
30
31
  "mlflow.entities", descriptor="Please install mlflow.entities before using MLFlowHandler."
31
32
  )
33
+ MlflowException, _ = optional_import(
34
+ "mlflow.exceptions", name="MlflowException", descriptor="Please install mlflow before using MLFlowHandler."
35
+ )
32
36
  pandas, _ = optional_import("pandas", descriptor="Please install pandas for recording the dataset.")
33
37
  tqdm, _ = optional_import("tqdm", "4.47.0", min_version, "tqdm")
34
38
 
@@ -41,6 +45,8 @@ else:
41
45
 
42
46
  DEFAULT_TAG = "Loss"
43
47
 
48
+ logger = get_logger(module_name=__name__)
49
+
44
50
 
45
51
  class MLFlowHandler:
46
52
  """
@@ -236,10 +242,21 @@ class MLFlowHandler:
236
242
  def _set_experiment(self):
237
243
  experiment = self.experiment
238
244
  if not experiment:
239
- experiment = self.client.get_experiment_by_name(self.experiment_name)
240
- if not experiment:
241
- experiment_id = self.client.create_experiment(self.experiment_name)
242
- experiment = self.client.get_experiment(experiment_id)
245
+ for _retry_time in range(3):
246
+ try:
247
+ experiment = self.client.get_experiment_by_name(self.experiment_name)
248
+ if not experiment:
249
+ experiment_id = self.client.create_experiment(self.experiment_name)
250
+ experiment = self.client.get_experiment(experiment_id)
251
+ break
252
+ except MlflowException as e:
253
+ if "RESOURCE_ALREADY_EXISTS" in str(e):
254
+ logger.warning("Experiment already exists; delaying before retrying.")
255
+ time.sleep(1)
256
+ if _retry_time == 2:
257
+ raise e
258
+ else:
259
+ raise e
243
260
 
244
261
  if experiment.lifecycle_stage != mlflow.entities.LifecycleStage.ACTIVE:
245
262
  raise ValueError(f"Cannot set a deleted experiment '{self.experiment_name}' as the active experiment")
monai/inferers/__init__.py CHANGED
@@ -12,13 +12,18 @@
12
12
  from __future__ import annotations
13
13
 
14
14
  from .inferer import (
15
+ ControlNetDiffusionInferer,
16
+ ControlNetLatentDiffusionInferer,
17
+ DiffusionInferer,
15
18
  Inferer,
19
+ LatentDiffusionInferer,
16
20
  PatchInferer,
17
21
  SaliencyInferer,
18
22
  SimpleInferer,
19
23
  SliceInferer,
20
24
  SlidingWindowInferer,
21
25
  SlidingWindowInfererAdapt,
26
+ VQVAETransformerInferer,
22
27
  )
23
28
  from .merger import AvgMerger, Merger, ZarrAvgMerger
24
29
  from .splitter import SlidingWindowSplitter, Splitter, WSISlidingWindowSplitter