monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2429__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/detection/utils/anchor_utils.py +2 -2
- monai/apps/pathology/transforms/post/array.py +1 -0
- monai/bundle/scripts.py +106 -8
- monai/bundle/utils.py +1 -0
- monai/data/dataset_summary.py +1 -0
- monai/data/utils.py +9 -6
- monai/data/wsi_reader.py +2 -2
- monai/engines/__init__.py +3 -1
- monai/engines/trainer.py +281 -2
- monai/engines/utils.py +76 -1
- monai/handlers/mlflow_handler.py +21 -4
- monai/inferers/__init__.py +5 -0
- monai/inferers/inferer.py +1279 -1
- monai/networks/blocks/__init__.py +3 -0
- monai/networks/blocks/attention_utils.py +128 -0
- monai/networks/blocks/crossattention.py +166 -0
- monai/networks/blocks/rel_pos_embedding.py +56 -0
- monai/networks/blocks/selfattention.py +72 -5
- monai/networks/blocks/spade_norm.py +95 -0
- monai/networks/blocks/spatialattention.py +82 -0
- monai/networks/blocks/transformerblock.py +24 -4
- monai/networks/blocks/upsample.py +22 -10
- monai/networks/layers/__init__.py +2 -1
- monai/networks/layers/factories.py +12 -1
- monai/networks/layers/utils.py +14 -1
- monai/networks/layers/vector_quantizer.py +233 -0
- monai/networks/nets/__init__.py +9 -0
- monai/networks/nets/autoencoderkl.py +702 -0
- monai/networks/nets/controlnet.py +465 -0
- monai/networks/nets/diffusion_model_unet.py +1913 -0
- monai/networks/nets/patchgan_discriminator.py +230 -0
- monai/networks/nets/quicknat.py +2 -0
- monai/networks/nets/resnet.py +3 -4
- monai/networks/nets/spade_autoencoderkl.py +480 -0
- monai/networks/nets/spade_diffusion_model_unet.py +934 -0
- monai/networks/nets/spade_network.py +435 -0
- monai/networks/nets/swin_unetr.py +4 -3
- monai/networks/nets/transformer.py +157 -0
- monai/networks/nets/vqvae.py +472 -0
- monai/networks/schedulers/__init__.py +17 -0
- monai/networks/schedulers/ddim.py +294 -0
- monai/networks/schedulers/ddpm.py +250 -0
- monai/networks/schedulers/pndm.py +316 -0
- monai/networks/schedulers/scheduler.py +205 -0
- monai/networks/utils.py +22 -0
- monai/transforms/regularization/array.py +4 -0
- monai/transforms/utils_create_transform_ims.py +2 -4
- monai/utils/__init__.py +1 -0
- monai/utils/misc.py +5 -4
- monai/utils/ordering.py +207 -0
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2429.dist-info}/METADATA +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2429.dist-info}/RECORD +57 -36
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2429.dist-info}/WHEEL +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2429.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2429.dist-info}/top_level.txt +0 -0
monai/engines/utils.py
CHANGED
@@ -13,9 +13,10 @@ from __future__ import annotations
|
|
13
13
|
|
14
14
|
from abc import ABC, abstractmethod
|
15
15
|
from collections.abc import Callable, Sequence
|
16
|
-
from typing import TYPE_CHECKING, Any, cast
|
16
|
+
from typing import TYPE_CHECKING, Any, Mapping, cast
|
17
17
|
|
18
18
|
import torch
|
19
|
+
import torch.nn as nn
|
19
20
|
|
20
21
|
from monai.config import IgniteInfo
|
21
22
|
from monai.transforms import apply_transform
|
@@ -36,6 +37,8 @@ __all__ = [
|
|
36
37
|
"PrepareBatch",
|
37
38
|
"PrepareBatchDefault",
|
38
39
|
"PrepareBatchExtraInput",
|
40
|
+
"DiffusionPrepareBatch",
|
41
|
+
"VPredictionPrepareBatch",
|
39
42
|
"default_make_latent",
|
40
43
|
"engine_apply_transform",
|
41
44
|
"default_metric_cmp_fn",
|
@@ -238,6 +241,78 @@ class PrepareBatchExtraInput(PrepareBatch):
|
|
238
241
|
return cast(torch.Tensor, image), cast(torch.Tensor, label), tuple(args_), kwargs_
|
239
242
|
|
240
243
|
|
244
|
+
class DiffusionPrepareBatch(PrepareBatch):
|
245
|
+
"""
|
246
|
+
This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training.
|
247
|
+
|
248
|
+
Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and
|
249
|
+
return the image and noise field as the image/target pair plus the noise field in the kwargs under the key "noise".
|
250
|
+
This assumes the inferer being used in conjunction with this class expects a "noise" parameter to be provided.
|
251
|
+
|
252
|
+
If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition
|
253
|
+
field to be passed to the inferer. This will appear in the keyword arguments under the key "condition".
|
254
|
+
|
255
|
+
"""
|
256
|
+
|
257
|
+
def __init__(self, num_train_timesteps: int, condition_name: str | None = None) -> None:
|
258
|
+
self.condition_name = condition_name
|
259
|
+
self.num_train_timesteps = num_train_timesteps
|
260
|
+
|
261
|
+
def get_noise(self, images: torch.Tensor) -> torch.Tensor:
|
262
|
+
"""Returns the noise tensor for input tensor `images`, override this for different noise distributions."""
|
263
|
+
return torch.randn_like(images)
|
264
|
+
|
265
|
+
def get_timesteps(self, images: torch.Tensor) -> torch.Tensor:
|
266
|
+
"""Get a timestep, by default this is a random integer between 0 and `self.num_train_timesteps`."""
|
267
|
+
return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long()
|
268
|
+
|
269
|
+
def get_target(self, images: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
|
270
|
+
"""Return the target for the loss function, this is the `noise` value by default."""
|
271
|
+
return noise
|
272
|
+
|
273
|
+
def __call__(
|
274
|
+
self,
|
275
|
+
batchdata: dict[str, torch.Tensor],
|
276
|
+
device: str | torch.device | None = None,
|
277
|
+
non_blocking: bool = False,
|
278
|
+
**kwargs: Any,
|
279
|
+
) -> tuple[torch.Tensor, torch.Tensor, tuple, dict]:
|
280
|
+
images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs)
|
281
|
+
noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs)
|
282
|
+
timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs)
|
283
|
+
|
284
|
+
target = self.get_target(images, noise, timesteps).to(device, non_blocking=non_blocking, **kwargs)
|
285
|
+
infer_kwargs = {"noise": noise, "timesteps": timesteps}
|
286
|
+
|
287
|
+
if self.condition_name is not None and isinstance(batchdata, Mapping):
|
288
|
+
infer_kwargs["condition"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs)
|
289
|
+
|
290
|
+
# return input, target, arguments, and keyword arguments where noise is the target and also a keyword value
|
291
|
+
return images, target, (), infer_kwargs
|
292
|
+
|
293
|
+
|
294
|
+
class VPredictionPrepareBatch(DiffusionPrepareBatch):
|
295
|
+
"""
|
296
|
+
This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training.
|
297
|
+
|
298
|
+
Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and
|
299
|
+
from this compute the velocity using the provided scheduler. This value is used as the target in place of the
|
300
|
+
noise field itself although the noise field is in the kwargs under the key "noise". This assumes the inferer
|
301
|
+
being used in conjunction with this class expects a "noise" parameter to be provided.
|
302
|
+
|
303
|
+
If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition
|
304
|
+
field to be passed to the inferer. This will appear in the keyword arguments under the key "condition".
|
305
|
+
|
306
|
+
"""
|
307
|
+
|
308
|
+
def __init__(self, scheduler: nn.Module, num_train_timesteps: int, condition_name: str | None = None) -> None:
|
309
|
+
super().__init__(num_train_timesteps=num_train_timesteps, condition_name=condition_name)
|
310
|
+
self.scheduler = scheduler
|
311
|
+
|
312
|
+
def get_target(self, images, noise, timesteps):
|
313
|
+
return self.scheduler.get_velocity(images, noise, timesteps)
|
314
|
+
|
315
|
+
|
241
316
|
def default_make_latent(
|
242
317
|
num_latents: int,
|
243
318
|
latent_size: int,
|
monai/handlers/mlflow_handler.py
CHANGED
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Any
|
|
21
21
|
import torch
|
22
22
|
from torch.utils.data import Dataset
|
23
23
|
|
24
|
+
from monai.apps.utils import get_logger
|
24
25
|
from monai.config import IgniteInfo
|
25
26
|
from monai.utils import CommonKeys, ensure_tuple, min_version, optional_import
|
26
27
|
|
@@ -29,6 +30,9 @@ mlflow, _ = optional_import("mlflow", descriptor="Please install mlflow before u
|
|
29
30
|
mlflow.entities, _ = optional_import(
|
30
31
|
"mlflow.entities", descriptor="Please install mlflow.entities before using MLFlowHandler."
|
31
32
|
)
|
33
|
+
MlflowException, _ = optional_import(
|
34
|
+
"mlflow.exceptions", name="MlflowException", descriptor="Please install mlflow before using MLFlowHandler."
|
35
|
+
)
|
32
36
|
pandas, _ = optional_import("pandas", descriptor="Please install pandas for recording the dataset.")
|
33
37
|
tqdm, _ = optional_import("tqdm", "4.47.0", min_version, "tqdm")
|
34
38
|
|
@@ -41,6 +45,8 @@ else:
|
|
41
45
|
|
42
46
|
DEFAULT_TAG = "Loss"
|
43
47
|
|
48
|
+
logger = get_logger(module_name=__name__)
|
49
|
+
|
44
50
|
|
45
51
|
class MLFlowHandler:
|
46
52
|
"""
|
@@ -236,10 +242,21 @@ class MLFlowHandler:
|
|
236
242
|
def _set_experiment(self):
|
237
243
|
experiment = self.experiment
|
238
244
|
if not experiment:
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
245
|
+
for _retry_time in range(3):
|
246
|
+
try:
|
247
|
+
experiment = self.client.get_experiment_by_name(self.experiment_name)
|
248
|
+
if not experiment:
|
249
|
+
experiment_id = self.client.create_experiment(self.experiment_name)
|
250
|
+
experiment = self.client.get_experiment(experiment_id)
|
251
|
+
break
|
252
|
+
except MlflowException as e:
|
253
|
+
if "RESOURCE_ALREADY_EXISTS" in str(e):
|
254
|
+
logger.warning("Experiment already exists; delaying before retrying.")
|
255
|
+
time.sleep(1)
|
256
|
+
if _retry_time == 2:
|
257
|
+
raise e
|
258
|
+
else:
|
259
|
+
raise e
|
243
260
|
|
244
261
|
if experiment.lifecycle_stage != mlflow.entities.LifecycleStage.ACTIVE:
|
245
262
|
raise ValueError(f"Cannot set a deleted experiment '{self.experiment_name}' as the active experiment")
|
monai/inferers/__init__.py
CHANGED
@@ -12,13 +12,18 @@
|
|
12
12
|
from __future__ import annotations
|
13
13
|
|
14
14
|
from .inferer import (
|
15
|
+
ControlNetDiffusionInferer,
|
16
|
+
ControlNetLatentDiffusionInferer,
|
17
|
+
DiffusionInferer,
|
15
18
|
Inferer,
|
19
|
+
LatentDiffusionInferer,
|
16
20
|
PatchInferer,
|
17
21
|
SaliencyInferer,
|
18
22
|
SimpleInferer,
|
19
23
|
SliceInferer,
|
20
24
|
SlidingWindowInferer,
|
21
25
|
SlidingWindowInfererAdapt,
|
26
|
+
VQVAETransformerInferer,
|
22
27
|
)
|
23
28
|
from .merger import AvgMerger, Merger, ZarrAvgMerger
|
24
29
|
from .splitter import SlidingWindowSplitter, Splitter, WSISlidingWindowSplitter
|