monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2430__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/auto3dseg/hpo_gen.py +1 -1
- monai/apps/detection/utils/anchor_utils.py +2 -2
- monai/apps/pathology/transforms/post/array.py +7 -4
- monai/auto3dseg/analyzer.py +1 -1
- monai/bundle/scripts.py +204 -22
- monai/bundle/utils.py +1 -0
- monai/data/dataset_summary.py +1 -0
- monai/data/meta_tensor.py +2 -2
- monai/data/test_time_augmentation.py +2 -0
- monai/data/utils.py +9 -6
- monai/data/wsi_reader.py +2 -2
- monai/engines/__init__.py +3 -1
- monai/engines/trainer.py +281 -2
- monai/engines/utils.py +76 -1
- monai/handlers/mlflow_handler.py +21 -4
- monai/inferers/__init__.py +5 -0
- monai/inferers/inferer.py +1279 -1
- monai/metrics/cumulative_average.py +2 -0
- monai/metrics/panoptic_quality.py +1 -1
- monai/metrics/rocauc.py +2 -2
- monai/networks/blocks/__init__.py +3 -0
- monai/networks/blocks/attention_utils.py +128 -0
- monai/networks/blocks/crossattention.py +168 -0
- monai/networks/blocks/rel_pos_embedding.py +56 -0
- monai/networks/blocks/selfattention.py +74 -5
- monai/networks/blocks/spade_norm.py +95 -0
- monai/networks/blocks/spatialattention.py +82 -0
- monai/networks/blocks/transformerblock.py +25 -4
- monai/networks/blocks/upsample.py +22 -10
- monai/networks/layers/__init__.py +2 -1
- monai/networks/layers/factories.py +12 -1
- monai/networks/layers/simplelayers.py +1 -1
- monai/networks/layers/utils.py +14 -1
- monai/networks/layers/vector_quantizer.py +233 -0
- monai/networks/nets/__init__.py +9 -0
- monai/networks/nets/autoencoderkl.py +702 -0
- monai/networks/nets/controlnet.py +465 -0
- monai/networks/nets/diffusion_model_unet.py +1913 -0
- monai/networks/nets/patchgan_discriminator.py +230 -0
- monai/networks/nets/quicknat.py +8 -6
- monai/networks/nets/resnet.py +3 -4
- monai/networks/nets/spade_autoencoderkl.py +480 -0
- monai/networks/nets/spade_diffusion_model_unet.py +934 -0
- monai/networks/nets/spade_network.py +435 -0
- monai/networks/nets/swin_unetr.py +4 -3
- monai/networks/nets/transformer.py +157 -0
- monai/networks/nets/vqvae.py +472 -0
- monai/networks/schedulers/__init__.py +17 -0
- monai/networks/schedulers/ddim.py +294 -0
- monai/networks/schedulers/ddpm.py +250 -0
- monai/networks/schedulers/pndm.py +316 -0
- monai/networks/schedulers/scheduler.py +205 -0
- monai/networks/utils.py +22 -0
- monai/transforms/croppad/array.py +8 -8
- monai/transforms/croppad/dictionary.py +4 -4
- monai/transforms/croppad/functional.py +1 -1
- monai/transforms/regularization/array.py +4 -0
- monai/transforms/spatial/array.py +1 -1
- monai/transforms/utils_create_transform_ims.py +2 -4
- monai/utils/__init__.py +1 -0
- monai/utils/misc.py +5 -4
- monai/utils/ordering.py +207 -0
- monai/visualize/class_activation_maps.py +5 -5
- monai/visualize/img2tensorboard.py +3 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/METADATA +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/RECORD +71 -50
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/WHEEL +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/top_level.txt +0 -0
monai/engines/trainer.py
CHANGED
@@ -24,7 +24,7 @@ from monai.engines.utils import IterationEvents, default_make_latent, default_me
|
|
24
24
|
from monai.engines.workflow import Workflow
|
25
25
|
from monai.inferers import Inferer, SimpleInferer
|
26
26
|
from monai.transforms import Transform
|
27
|
-
from monai.utils import GanKeys, min_version, optional_import
|
27
|
+
from monai.utils import AdversarialIterationEvents, AdversarialKeys, GanKeys, min_version, optional_import
|
28
28
|
from monai.utils.enums import CommonKeys as Keys
|
29
29
|
from monai.utils.enums import EngineStatsKeys as ESKeys
|
30
30
|
from monai.utils.module import pytorch_after
|
@@ -37,7 +37,7 @@ else:
|
|
37
37
|
Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric")
|
38
38
|
EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum")
|
39
39
|
|
40
|
-
__all__ = ["Trainer", "SupervisedTrainer", "GanTrainer"]
|
40
|
+
__all__ = ["Trainer", "SupervisedTrainer", "GanTrainer", "AdversarialTrainer"]
|
41
41
|
|
42
42
|
|
43
43
|
class Trainer(Workflow):
|
@@ -471,3 +471,282 @@ class GanTrainer(Trainer):
|
|
471
471
|
GanKeys.GLOSS: g_loss.item(),
|
472
472
|
GanKeys.DLOSS: d_total_loss.item(),
|
473
473
|
}
|
474
|
+
|
475
|
+
|
476
|
+
class AdversarialTrainer(Trainer):
|
477
|
+
"""
|
478
|
+
Standard supervised training workflow for adversarial loss enabled neural networks.
|
479
|
+
|
480
|
+
Args:
|
481
|
+
device: an object representing the device on which to run.
|
482
|
+
max_epochs: the total epoch number for engine to run.
|
483
|
+
train_data_loader: Core ignite engines uses `DataLoader` for training loop batchdata.
|
484
|
+
g_network: generator (G) network architecture.
|
485
|
+
g_optimizer: G optimizer function.
|
486
|
+
g_loss_function: G loss function for adversarial training.
|
487
|
+
recon_loss_function: G loss function for reconstructions.
|
488
|
+
d_network: discriminator (D) network architecture.
|
489
|
+
d_optimizer: D optimizer function.
|
490
|
+
d_loss_function: D loss function for adversarial training.
|
491
|
+
epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`.
|
492
|
+
non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to
|
493
|
+
the host. For other cases, this argument has no effect.
|
494
|
+
prepare_batch: function to parse image and label for current iteration.
|
495
|
+
iteration_update: the callable function for every iteration, expect to accept `engine` and `batchdata` as input
|
496
|
+
parameters. if not provided, use `self._iteration()` instead.
|
497
|
+
g_inferer: inference method to execute G model forward. Defaults to ``SimpleInferer()``.
|
498
|
+
d_inferer: inference method to execute D model forward. Defaults to ``SimpleInferer()``.
|
499
|
+
postprocessing: execute additional transformation for the model output data. Typically, several Tensor based
|
500
|
+
transforms composed by `Compose`. Defaults to None
|
501
|
+
key_train_metric: compute metric when every iteration completed, and save average value to engine.state.metrics
|
502
|
+
when epoch completed. key_train_metric is the main metric to compare and save the checkpoint into files.
|
503
|
+
additional_metrics: more Ignite metrics that also attach to Ignite Engine.
|
504
|
+
metric_cmp_fn: function to compare current key metric with previous best key metric value, it must accept 2 args
|
505
|
+
(current_metric, previous_best) and return a bool result: if `True`, will update `best_metric` and
|
506
|
+
`best_metric_epoch` with current metric and epoch, default to `greater than`.
|
507
|
+
train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
|
508
|
+
CheckpointHandler, StatsHandler, etc.
|
509
|
+
amp: whether to enable auto-mixed-precision training, default is False.
|
510
|
+
event_names: additional custom ignite events that will register to the engine.
|
511
|
+
new events can be a list of str or `ignite.engine.events.EventEnum`.
|
512
|
+
event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
|
513
|
+
for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
|
514
|
+
#ignite.engine.engine.Engine.register_events.
|
515
|
+
decollate: whether to decollate the batch-first data to a list of data after model computation, recommend
|
516
|
+
`decollate=True` when `postprocessing` uses components from `monai.transforms`. default to `True`.
|
517
|
+
optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None.
|
518
|
+
more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
|
519
|
+
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
|
520
|
+
`device`, `non_blocking`.
|
521
|
+
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
|
522
|
+
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
|
523
|
+
"""
|
524
|
+
|
525
|
+
def __init__(
|
526
|
+
self,
|
527
|
+
device: torch.device | str,
|
528
|
+
max_epochs: int,
|
529
|
+
train_data_loader: Iterable | DataLoader,
|
530
|
+
g_network: torch.nn.Module,
|
531
|
+
g_optimizer: Optimizer,
|
532
|
+
g_loss_function: Callable,
|
533
|
+
recon_loss_function: Callable,
|
534
|
+
d_network: torch.nn.Module,
|
535
|
+
d_optimizer: Optimizer,
|
536
|
+
d_loss_function: Callable,
|
537
|
+
epoch_length: int | None = None,
|
538
|
+
non_blocking: bool = False,
|
539
|
+
prepare_batch: Callable = default_prepare_batch,
|
540
|
+
iteration_update: Callable | None = None,
|
541
|
+
g_inferer: Inferer | None = None,
|
542
|
+
d_inferer: Inferer | None = None,
|
543
|
+
postprocessing: Transform | None = None,
|
544
|
+
key_train_metric: dict[str, Metric] | None = None,
|
545
|
+
additional_metrics: dict[str, Metric] | None = None,
|
546
|
+
metric_cmp_fn: Callable = default_metric_cmp_fn,
|
547
|
+
train_handlers: Sequence | None = None,
|
548
|
+
amp: bool = False,
|
549
|
+
event_names: list[str | EventEnum | type[EventEnum]] | None = None,
|
550
|
+
event_to_attr: dict | None = None,
|
551
|
+
decollate: bool = True,
|
552
|
+
optim_set_to_none: bool = False,
|
553
|
+
to_kwargs: dict | None = None,
|
554
|
+
amp_kwargs: dict | None = None,
|
555
|
+
):
|
556
|
+
super().__init__(
|
557
|
+
device=device,
|
558
|
+
max_epochs=max_epochs,
|
559
|
+
data_loader=train_data_loader,
|
560
|
+
epoch_length=epoch_length,
|
561
|
+
non_blocking=non_blocking,
|
562
|
+
prepare_batch=prepare_batch,
|
563
|
+
iteration_update=iteration_update,
|
564
|
+
postprocessing=postprocessing,
|
565
|
+
key_metric=key_train_metric,
|
566
|
+
additional_metrics=additional_metrics,
|
567
|
+
metric_cmp_fn=metric_cmp_fn,
|
568
|
+
handlers=train_handlers,
|
569
|
+
amp=amp,
|
570
|
+
event_names=event_names,
|
571
|
+
event_to_attr=event_to_attr,
|
572
|
+
decollate=decollate,
|
573
|
+
to_kwargs=to_kwargs,
|
574
|
+
amp_kwargs=amp_kwargs,
|
575
|
+
)
|
576
|
+
|
577
|
+
self.register_events(*AdversarialIterationEvents)
|
578
|
+
|
579
|
+
self.state.g_network = g_network
|
580
|
+
self.state.g_optimizer = g_optimizer
|
581
|
+
self.state.g_loss_function = g_loss_function
|
582
|
+
self.state.recon_loss_function = recon_loss_function
|
583
|
+
|
584
|
+
self.state.d_network = d_network
|
585
|
+
self.state.d_optimizer = d_optimizer
|
586
|
+
self.state.d_loss_function = d_loss_function
|
587
|
+
|
588
|
+
self.g_inferer = SimpleInferer() if g_inferer is None else g_inferer
|
589
|
+
self.d_inferer = SimpleInferer() if d_inferer is None else d_inferer
|
590
|
+
|
591
|
+
self.state.g_scaler = torch.cuda.amp.GradScaler() if self.amp else None
|
592
|
+
self.state.d_scaler = torch.cuda.amp.GradScaler() if self.amp else None
|
593
|
+
|
594
|
+
self.optim_set_to_none = optim_set_to_none
|
595
|
+
self._complete_state_dict_user_keys()
|
596
|
+
|
597
|
+
def _complete_state_dict_user_keys(self) -> None:
|
598
|
+
"""
|
599
|
+
This method appends to the _state_dict_user_keys AdversarialTrainer's elements that are required for
|
600
|
+
checkpoint saving.
|
601
|
+
|
602
|
+
Follows the example found at:
|
603
|
+
https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html#ignite.engine.engine.Engine.state_dict
|
604
|
+
"""
|
605
|
+
self._state_dict_user_keys.extend(
|
606
|
+
["g_network", "g_optimizer", "d_network", "d_optimizer", "g_scaler", "d_scaler"]
|
607
|
+
)
|
608
|
+
|
609
|
+
g_loss_state_dict = getattr(self.state.g_loss_function, "state_dict", None)
|
610
|
+
if callable(g_loss_state_dict):
|
611
|
+
self._state_dict_user_keys.append("g_loss_function")
|
612
|
+
|
613
|
+
d_loss_state_dict = getattr(self.state.d_loss_function, "state_dict", None)
|
614
|
+
if callable(d_loss_state_dict):
|
615
|
+
self._state_dict_user_keys.append("d_loss_function")
|
616
|
+
|
617
|
+
recon_loss_state_dict = getattr(self.state.recon_loss_function, "state_dict", None)
|
618
|
+
if callable(recon_loss_state_dict):
|
619
|
+
self._state_dict_user_keys.append("recon_loss_function")
|
620
|
+
|
621
|
+
def _iteration(
|
622
|
+
self, engine: AdversarialTrainer, batchdata: dict[str, torch.Tensor]
|
623
|
+
) -> dict[str, torch.Tensor | int | float | bool]:
|
624
|
+
"""
|
625
|
+
Callback function for the Adversarial Training processing logic of 1 iteration in Ignite Engine.
|
626
|
+
Return below items in a dictionary:
|
627
|
+
- IMAGE: image Tensor data for model input, already moved to device.
|
628
|
+
- LABEL: label Tensor data corresponding to the image, already moved to device. In case of Unsupervised
|
629
|
+
Learning this is equal to IMAGE.
|
630
|
+
- PRED: prediction result of model.
|
631
|
+
- LOSS: loss value computed by loss functions of the generator (reconstruction and adversarial summed up).
|
632
|
+
- AdversarialKeys.REALS: real images from the batch. Are the same as IMAGE.
|
633
|
+
- AdversarialKeys.FAKES: fake images generated by the generator. Are the same as PRED.
|
634
|
+
- AdversarialKeys.REAL_LOGITS: logits of the discriminator for the real images.
|
635
|
+
- AdversarialKeys.FAKE_LOGITS: logits of the discriminator for the fake images.
|
636
|
+
- AdversarialKeys.RECONSTRUCTION_LOSS: loss value computed by the reconstruction loss function.
|
637
|
+
- AdversarialKeys.GENERATOR_LOSS: loss value computed by the generator loss function. It is the
|
638
|
+
discriminator loss for the fake images. That is backpropagated through the generator only.
|
639
|
+
- AdversarialKeys.DISCRIMINATOR_LOSS: loss value computed by the discriminator loss function. It is the
|
640
|
+
discriminator loss for the real images and the fake images. That is backpropagated through the
|
641
|
+
discriminator only.
|
642
|
+
|
643
|
+
Args:
|
644
|
+
engine: `AdversarialTrainer` to execute operation for an iteration.
|
645
|
+
batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.
|
646
|
+
|
647
|
+
Raises:
|
648
|
+
ValueError: must provide batch data for current iteration.
|
649
|
+
|
650
|
+
"""
|
651
|
+
|
652
|
+
if batchdata is None:
|
653
|
+
raise ValueError("Must provide batch data for current iteration.")
|
654
|
+
batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)
|
655
|
+
|
656
|
+
if len(batch) == 2:
|
657
|
+
inputs, targets = batch
|
658
|
+
args: tuple = ()
|
659
|
+
kwargs: dict = {}
|
660
|
+
else:
|
661
|
+
inputs, targets, args, kwargs = batch
|
662
|
+
|
663
|
+
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets, AdversarialKeys.REALS: inputs}
|
664
|
+
|
665
|
+
def _compute_generator_loss() -> None:
|
666
|
+
engine.state.output[AdversarialKeys.FAKES] = engine.g_inferer(
|
667
|
+
inputs, engine.state.g_network, *args, **kwargs
|
668
|
+
)
|
669
|
+
engine.state.output[Keys.PRED] = engine.state.output[AdversarialKeys.FAKES]
|
670
|
+
engine.fire_event(AdversarialIterationEvents.GENERATOR_FORWARD_COMPLETED)
|
671
|
+
|
672
|
+
engine.state.output[AdversarialKeys.FAKE_LOGITS] = engine.d_inferer(
|
673
|
+
engine.state.output[AdversarialKeys.FAKES].float().contiguous(), engine.state.d_network, *args, **kwargs
|
674
|
+
)
|
675
|
+
engine.fire_event(AdversarialIterationEvents.GENERATOR_DISCRIMINATOR_FORWARD_COMPLETED)
|
676
|
+
|
677
|
+
engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS] = engine.state.recon_loss_function(
|
678
|
+
engine.state.output[AdversarialKeys.FAKES], targets
|
679
|
+
).mean()
|
680
|
+
engine.fire_event(AdversarialIterationEvents.RECONSTRUCTION_LOSS_COMPLETED)
|
681
|
+
|
682
|
+
engine.state.output[AdversarialKeys.GENERATOR_LOSS] = engine.state.g_loss_function(
|
683
|
+
engine.state.output[AdversarialKeys.FAKE_LOGITS]
|
684
|
+
).mean()
|
685
|
+
engine.fire_event(AdversarialIterationEvents.GENERATOR_LOSS_COMPLETED)
|
686
|
+
|
687
|
+
# Train Generator
|
688
|
+
engine.state.g_network.train()
|
689
|
+
engine.state.g_optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
|
690
|
+
|
691
|
+
if engine.amp and engine.state.g_scaler is not None:
|
692
|
+
with torch.cuda.amp.autocast(**engine.amp_kwargs):
|
693
|
+
_compute_generator_loss()
|
694
|
+
|
695
|
+
engine.state.output[Keys.LOSS] = (
|
696
|
+
engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS]
|
697
|
+
+ engine.state.output[AdversarialKeys.GENERATOR_LOSS]
|
698
|
+
)
|
699
|
+
engine.state.g_scaler.scale(engine.state.output[Keys.LOSS]).backward()
|
700
|
+
engine.fire_event(AdversarialIterationEvents.GENERATOR_BACKWARD_COMPLETED)
|
701
|
+
engine.state.g_scaler.step(engine.state.g_optimizer)
|
702
|
+
engine.state.g_scaler.update()
|
703
|
+
else:
|
704
|
+
_compute_generator_loss()
|
705
|
+
(
|
706
|
+
engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS]
|
707
|
+
+ engine.state.output[AdversarialKeys.GENERATOR_LOSS]
|
708
|
+
).backward()
|
709
|
+
engine.fire_event(AdversarialIterationEvents.GENERATOR_BACKWARD_COMPLETED)
|
710
|
+
engine.state.g_optimizer.step()
|
711
|
+
engine.fire_event(AdversarialIterationEvents.GENERATOR_MODEL_COMPLETED)
|
712
|
+
|
713
|
+
def _compute_discriminator_loss() -> None:
|
714
|
+
engine.state.output[AdversarialKeys.REAL_LOGITS] = engine.d_inferer(
|
715
|
+
engine.state.output[AdversarialKeys.REALS].contiguous().detach(),
|
716
|
+
engine.state.d_network,
|
717
|
+
*args,
|
718
|
+
**kwargs,
|
719
|
+
)
|
720
|
+
engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_REALS_FORWARD_COMPLETED)
|
721
|
+
|
722
|
+
engine.state.output[AdversarialKeys.FAKE_LOGITS] = engine.d_inferer(
|
723
|
+
engine.state.output[AdversarialKeys.FAKES].contiguous().detach(),
|
724
|
+
engine.state.d_network,
|
725
|
+
*args,
|
726
|
+
**kwargs,
|
727
|
+
)
|
728
|
+
engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_FAKES_FORWARD_COMPLETED)
|
729
|
+
|
730
|
+
engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS] = engine.state.d_loss_function(
|
731
|
+
engine.state.output[AdversarialKeys.REAL_LOGITS], engine.state.output[AdversarialKeys.FAKE_LOGITS]
|
732
|
+
).mean()
|
733
|
+
engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_LOSS_COMPLETED)
|
734
|
+
|
735
|
+
# Train Discriminator
|
736
|
+
engine.state.d_network.train()
|
737
|
+
engine.state.d_network.zero_grad(set_to_none=engine.optim_set_to_none)
|
738
|
+
|
739
|
+
if engine.amp and engine.state.d_scaler is not None:
|
740
|
+
with torch.cuda.amp.autocast(**engine.amp_kwargs):
|
741
|
+
_compute_discriminator_loss()
|
742
|
+
|
743
|
+
engine.state.d_scaler.scale(engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS]).backward()
|
744
|
+
engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_BACKWARD_COMPLETED)
|
745
|
+
engine.state.d_scaler.step(engine.state.d_optimizer)
|
746
|
+
engine.state.d_scaler.update()
|
747
|
+
else:
|
748
|
+
_compute_discriminator_loss()
|
749
|
+
engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS].backward()
|
750
|
+
engine.state.d_optimizer.step()
|
751
|
+
|
752
|
+
return engine.state.output
|
monai/engines/utils.py
CHANGED
@@ -13,9 +13,10 @@ from __future__ import annotations
|
|
13
13
|
|
14
14
|
from abc import ABC, abstractmethod
|
15
15
|
from collections.abc import Callable, Sequence
|
16
|
-
from typing import TYPE_CHECKING, Any, cast
|
16
|
+
from typing import TYPE_CHECKING, Any, Mapping, cast
|
17
17
|
|
18
18
|
import torch
|
19
|
+
import torch.nn as nn
|
19
20
|
|
20
21
|
from monai.config import IgniteInfo
|
21
22
|
from monai.transforms import apply_transform
|
@@ -36,6 +37,8 @@ __all__ = [
|
|
36
37
|
"PrepareBatch",
|
37
38
|
"PrepareBatchDefault",
|
38
39
|
"PrepareBatchExtraInput",
|
40
|
+
"DiffusionPrepareBatch",
|
41
|
+
"VPredictionPrepareBatch",
|
39
42
|
"default_make_latent",
|
40
43
|
"engine_apply_transform",
|
41
44
|
"default_metric_cmp_fn",
|
@@ -238,6 +241,78 @@ class PrepareBatchExtraInput(PrepareBatch):
|
|
238
241
|
return cast(torch.Tensor, image), cast(torch.Tensor, label), tuple(args_), kwargs_
|
239
242
|
|
240
243
|
|
244
|
+
class DiffusionPrepareBatch(PrepareBatch):
|
245
|
+
"""
|
246
|
+
This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training.
|
247
|
+
|
248
|
+
Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and
|
249
|
+
return the image and noise field as the image/target pair plus the noise field in the kwargs under the key "noise".
|
250
|
+
This assumes the inferer being used in conjunction with this class expects a "noise" parameter to be provided.
|
251
|
+
|
252
|
+
If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition
|
253
|
+
field to be passed to the inferer. This will appear in the keyword arguments under the key "condition".
|
254
|
+
|
255
|
+
"""
|
256
|
+
|
257
|
+
def __init__(self, num_train_timesteps: int, condition_name: str | None = None) -> None:
|
258
|
+
self.condition_name = condition_name
|
259
|
+
self.num_train_timesteps = num_train_timesteps
|
260
|
+
|
261
|
+
def get_noise(self, images: torch.Tensor) -> torch.Tensor:
|
262
|
+
"""Returns the noise tensor for input tensor `images`, override this for different noise distributions."""
|
263
|
+
return torch.randn_like(images)
|
264
|
+
|
265
|
+
def get_timesteps(self, images: torch.Tensor) -> torch.Tensor:
|
266
|
+
"""Get a timestep, by default this is a random integer between 0 and `self.num_train_timesteps`."""
|
267
|
+
return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long()
|
268
|
+
|
269
|
+
def get_target(self, images: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
|
270
|
+
"""Return the target for the loss function, this is the `noise` value by default."""
|
271
|
+
return noise
|
272
|
+
|
273
|
+
def __call__(
|
274
|
+
self,
|
275
|
+
batchdata: dict[str, torch.Tensor],
|
276
|
+
device: str | torch.device | None = None,
|
277
|
+
non_blocking: bool = False,
|
278
|
+
**kwargs: Any,
|
279
|
+
) -> tuple[torch.Tensor, torch.Tensor, tuple, dict]:
|
280
|
+
images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs)
|
281
|
+
noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs)
|
282
|
+
timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs)
|
283
|
+
|
284
|
+
target = self.get_target(images, noise, timesteps).to(device, non_blocking=non_blocking, **kwargs)
|
285
|
+
infer_kwargs = {"noise": noise, "timesteps": timesteps}
|
286
|
+
|
287
|
+
if self.condition_name is not None and isinstance(batchdata, Mapping):
|
288
|
+
infer_kwargs["condition"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs)
|
289
|
+
|
290
|
+
# return input, target, arguments, and keyword arguments where noise is the target and also a keyword value
|
291
|
+
return images, target, (), infer_kwargs
|
292
|
+
|
293
|
+
|
294
|
+
class VPredictionPrepareBatch(DiffusionPrepareBatch):
|
295
|
+
"""
|
296
|
+
This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training.
|
297
|
+
|
298
|
+
Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and
|
299
|
+
from this compute the velocity using the provided scheduler. This value is used as the target in place of the
|
300
|
+
noise field itself although the noise is field is in the kwargs under the key "noise". This assumes the inferer
|
301
|
+
being used in conjunction with this class expects a "noise" parameter to be provided.
|
302
|
+
|
303
|
+
If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition
|
304
|
+
field to be passed to the inferer. This will appear in the keyword arguments under the key "condition".
|
305
|
+
|
306
|
+
"""
|
307
|
+
|
308
|
+
def __init__(self, scheduler: nn.Module, num_train_timesteps: int, condition_name: str | None = None) -> None:
|
309
|
+
super().__init__(num_train_timesteps=num_train_timesteps, condition_name=condition_name)
|
310
|
+
self.scheduler = scheduler
|
311
|
+
|
312
|
+
def get_target(self, images, noise, timesteps):
|
313
|
+
return self.scheduler.get_velocity(images, noise, timesteps)
|
314
|
+
|
315
|
+
|
241
316
|
def default_make_latent(
|
242
317
|
num_latents: int,
|
243
318
|
latent_size: int,
|
monai/handlers/mlflow_handler.py
CHANGED
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Any
|
|
21
21
|
import torch
|
22
22
|
from torch.utils.data import Dataset
|
23
23
|
|
24
|
+
from monai.apps.utils import get_logger
|
24
25
|
from monai.config import IgniteInfo
|
25
26
|
from monai.utils import CommonKeys, ensure_tuple, min_version, optional_import
|
26
27
|
|
@@ -29,6 +30,9 @@ mlflow, _ = optional_import("mlflow", descriptor="Please install mlflow before u
|
|
29
30
|
mlflow.entities, _ = optional_import(
|
30
31
|
"mlflow.entities", descriptor="Please install mlflow.entities before using MLFlowHandler."
|
31
32
|
)
|
33
|
+
MlflowException, _ = optional_import(
|
34
|
+
"mlflow.exceptions", name="MlflowException", descriptor="Please install mlflow before using MLFlowHandler."
|
35
|
+
)
|
32
36
|
pandas, _ = optional_import("pandas", descriptor="Please install pandas for recording the dataset.")
|
33
37
|
tqdm, _ = optional_import("tqdm", "4.47.0", min_version, "tqdm")
|
34
38
|
|
@@ -41,6 +45,8 @@ else:
|
|
41
45
|
|
42
46
|
DEFAULT_TAG = "Loss"
|
43
47
|
|
48
|
+
logger = get_logger(module_name=__name__)
|
49
|
+
|
44
50
|
|
45
51
|
class MLFlowHandler:
|
46
52
|
"""
|
@@ -236,10 +242,21 @@ class MLFlowHandler:
|
|
236
242
|
def _set_experiment(self):
|
237
243
|
experiment = self.experiment
|
238
244
|
if not experiment:
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
245
|
+
for _retry_time in range(3):
|
246
|
+
try:
|
247
|
+
experiment = self.client.get_experiment_by_name(self.experiment_name)
|
248
|
+
if not experiment:
|
249
|
+
experiment_id = self.client.create_experiment(self.experiment_name)
|
250
|
+
experiment = self.client.get_experiment(experiment_id)
|
251
|
+
break
|
252
|
+
except MlflowException as e:
|
253
|
+
if "RESOURCE_ALREADY_EXISTS" in str(e):
|
254
|
+
logger.warning("Experiment already exists; delaying before retrying.")
|
255
|
+
time.sleep(1)
|
256
|
+
if _retry_time == 2:
|
257
|
+
raise e
|
258
|
+
else:
|
259
|
+
raise e
|
243
260
|
|
244
261
|
if experiment.lifecycle_stage != mlflow.entities.LifecycleStage.ACTIVE:
|
245
262
|
raise ValueError(f"Cannot set a deleted experiment '{self.experiment_name}' as the active experiment")
|
monai/inferers/__init__.py
CHANGED
@@ -12,13 +12,18 @@
|
|
12
12
|
from __future__ import annotations
|
13
13
|
|
14
14
|
from .inferer import (
|
15
|
+
ControlNetDiffusionInferer,
|
16
|
+
ControlNetLatentDiffusionInferer,
|
17
|
+
DiffusionInferer,
|
15
18
|
Inferer,
|
19
|
+
LatentDiffusionInferer,
|
16
20
|
PatchInferer,
|
17
21
|
SaliencyInferer,
|
18
22
|
SimpleInferer,
|
19
23
|
SliceInferer,
|
20
24
|
SlidingWindowInferer,
|
21
25
|
SlidingWindowInfererAdapt,
|
26
|
+
VQVAETransformerInferer,
|
22
27
|
)
|
23
28
|
from .merger import AvgMerger, Merger, ZarrAvgMerger
|
24
29
|
from .splitter import SlidingWindowSplitter, Splitter, WSISlidingWindowSplitter
|