monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2429__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/detection/utils/anchor_utils.py +2 -2
  4. monai/apps/pathology/transforms/post/array.py +1 -0
  5. monai/bundle/scripts.py +106 -8
  6. monai/bundle/utils.py +1 -0
  7. monai/data/dataset_summary.py +1 -0
  8. monai/data/utils.py +9 -6
  9. monai/data/wsi_reader.py +2 -2
  10. monai/engines/__init__.py +3 -1
  11. monai/engines/trainer.py +281 -2
  12. monai/engines/utils.py +76 -1
  13. monai/handlers/mlflow_handler.py +21 -4
  14. monai/inferers/__init__.py +5 -0
  15. monai/inferers/inferer.py +1279 -1
  16. monai/networks/blocks/__init__.py +3 -0
  17. monai/networks/blocks/attention_utils.py +128 -0
  18. monai/networks/blocks/crossattention.py +166 -0
  19. monai/networks/blocks/rel_pos_embedding.py +56 -0
  20. monai/networks/blocks/selfattention.py +72 -5
  21. monai/networks/blocks/spade_norm.py +95 -0
  22. monai/networks/blocks/spatialattention.py +82 -0
  23. monai/networks/blocks/transformerblock.py +24 -4
  24. monai/networks/blocks/upsample.py +22 -10
  25. monai/networks/layers/__init__.py +2 -1
  26. monai/networks/layers/factories.py +12 -1
  27. monai/networks/layers/utils.py +14 -1
  28. monai/networks/layers/vector_quantizer.py +233 -0
  29. monai/networks/nets/__init__.py +9 -0
  30. monai/networks/nets/autoencoderkl.py +702 -0
  31. monai/networks/nets/controlnet.py +465 -0
  32. monai/networks/nets/diffusion_model_unet.py +1913 -0
  33. monai/networks/nets/patchgan_discriminator.py +230 -0
  34. monai/networks/nets/quicknat.py +2 -0
  35. monai/networks/nets/resnet.py +3 -4
  36. monai/networks/nets/spade_autoencoderkl.py +480 -0
  37. monai/networks/nets/spade_diffusion_model_unet.py +934 -0
  38. monai/networks/nets/spade_network.py +435 -0
  39. monai/networks/nets/swin_unetr.py +4 -3
  40. monai/networks/nets/transformer.py +157 -0
  41. monai/networks/nets/vqvae.py +472 -0
  42. monai/networks/schedulers/__init__.py +17 -0
  43. monai/networks/schedulers/ddim.py +294 -0
  44. monai/networks/schedulers/ddpm.py +250 -0
  45. monai/networks/schedulers/pndm.py +316 -0
  46. monai/networks/schedulers/scheduler.py +205 -0
  47. monai/networks/utils.py +22 -0
  48. monai/transforms/regularization/array.py +4 -0
  49. monai/transforms/utils_create_transform_ims.py +2 -4
  50. monai/utils/__init__.py +1 -0
  51. monai/utils/misc.py +5 -4
  52. monai/utils/ordering.py +207 -0
  53. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2429.dist-info}/METADATA +1 -1
  54. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2429.dist-info}/RECORD +57 -36
  55. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2429.dist-info}/WHEEL +1 -1
  56. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2429.dist-info}/LICENSE +0 -0
  57. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2429.dist-info}/top_level.txt +0 -0
monai/__init__.py CHANGED
@@ -93,4 +93,4 @@ except BaseException:

    if MONAIEnvVars.debug():
        raise
- __commit_id__ = "14b086b553693f5d344ff054f37d12ce6839da06"
+ __commit_id__ = "d020facccbd3afe979fce68c24703dcda47234f6"
monai/_version.py CHANGED
@@ -8,11 +8,11 @@ import json

version_json = '''
{
- "date": "2024-07-14T02:21:44+0000",
+ "date": "2024-07-21T02:18:43+0000",
 "dirty": false,
 "error": null,
- "full-revisionid": "2d242d8f1b2876133bcafbe7fa5d967728a74998",
- "version": "1.4.dev2428"
+ "full-revisionid": "023827f344ab6236402ad609f0e938c308f68ca0",
+ "version": "1.4.dev2429"
}
''' # END VERSION_JSON

monai/apps/detection/utils/anchor_utils.py CHANGED
@@ -189,7 +189,7 @@ class AnchorGenerator(nn.Module):
            w_ratios = 1 / area_scale
            h_ratios = area_scale
        # if 3d, w:h:d = 1:aspect_ratios[:,0]:aspect_ratios[:,1]
-         elif self.spatial_dims == 3:
+         else:
            area_scale = torch.pow(aspect_ratios_t[:, 0] * aspect_ratios_t[:, 1], 1 / 3.0)
            w_ratios = 1 / area_scale
            h_ratios = aspect_ratios_t[:, 0] / area_scale
@@ -199,7 +199,7 @@ class AnchorGenerator(nn.Module):
        hs = (h_ratios[:, None] * scales_t[None, :]).view(-1)
        if self.spatial_dims == 2:
            base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2.0
-         elif self.spatial_dims == 3:
+         else:  # elif self.spatial_dims == 3:
            ds = (d_ratios[:, None] * scales_t[None, :]).view(-1)
            base_anchors = torch.stack([-ws, -hs, -ds, ws, hs, ds], dim=1) / 2.0

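Both anchor_utils.py edits replace a non-exhaustive `elif self.spatial_dims == 3:` branch with `else:`. Since `spatial_dims` is validated elsewhere to be 2 or 3, behavior is unchanged, but every code path now binds the variables the following lines use. A sketch of the pattern, with hypothetical names:

    def scale_ratio(spatial_dims: int, area_scale: float) -> float:
        """Return a width ratio; `spatial_dims` is assumed pre-validated to 2 or 3."""
        if spatial_dims == 2:
            ratio = 1.0 / area_scale
        else:  # spatial_dims == 3; an `elif` here would leave `ratio` possibly unbound
            ratio = 1.0 / (area_scale ** (1.0 / 3.0))
        return ratio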
monai/apps/pathology/transforms/post/array.py CHANGED
@@ -379,6 +379,7 @@ class GenerateSuccinctContour(Transform):
        """

        p_delta = (current[0] - previous[0], current[1] - previous[1])
+         row, col = -1, -1

        if p_delta in ((0.0, 1.0), (0.5, 0.5), (1.0, 0.0)):
            row = int(current[0] + 0.5)
monai/bundle/scripts.py CHANGED
@@ -16,6 +16,7 @@ import json
import os
import re
import warnings
+ import zipfile
from collections.abc import Mapping, Sequence
from pathlib import Path
from pydoc import locate
@@ -171,6 +172,10 @@ def _get_ngc_bundle_url(model_name: str, version: str) -> str:
    return f"https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit/{model_name.lower()}/versions/{version}/zip"


+ def _get_ngc_private_bundle_url(model_name: str, version: str, repo: str) -> str:
+     return f"https://api.ngc.nvidia.com/v2/{repo}/models/{model_name.lower()}/versions/{version}/zip"
+
+
def _get_monaihosting_bundle_url(model_name: str, version: str) -> str:
    monaihosting_root_path = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting"
    return f"{monaihosting_root_path}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip"
@@ -219,6 +224,48 @@ def _download_from_ngc(
    extractall(filepath=filepath, output_dir=extract_path, has_base=True)


+ def _download_from_ngc_private(
+     download_path: Path, filename: str, version: str, remove_prefix: str | None, repo: str, headers: dict | None = None
+ ) -> None:
+     # ensure prefix is contained
+     filename = _add_ngc_prefix(filename)
+     request_url = _get_ngc_private_bundle_url(model_name=filename, version=version, repo=repo)
+     if has_requests:
+         headers = {} if headers is None else headers
+         response = requests_get(request_url, headers=headers)
+         response.raise_for_status()
+     else:
+         raise ValueError("NGC API requires requests package. Please install it.")
+
+     zip_path = download_path / f"{filename}_v{version}.zip"
+     with open(zip_path, "wb") as f:
+         f.write(response.content)
+     logger.info(f"Downloading: {zip_path}.")
+     if remove_prefix:
+         filename = _remove_ngc_prefix(filename, prefix=remove_prefix)
+     extract_path = download_path / f"{filename}"
+     with zipfile.ZipFile(zip_path, "r") as z:
+         z.extractall(extract_path)
+     logger.info(f"Writing into directory: {extract_path}.")
+
+
+ def _get_ngc_token(api_key, retry=0):
+     """Try to connect to NGC."""
+     url = "https://authn.nvidia.com/token?service=ngc"
+     headers = {"Accept": "application/json", "Authorization": "ApiKey " + api_key}
+     if has_requests:
+         response = requests_get(url, headers=headers)
+         if not response.ok:
+             # retry 3 times, if failed, raise an error.
+             if retry < 3:
+                 logger.info(f"Retrying {retry} time(s) to GET {url}.")
+                 return _get_ngc_token(url, retry + 1)
+             raise RuntimeError("NGC API response is not ok. Failed to get token.")
+         else:
+             token = response.json()["token"]
+             return token
+
+
def _get_latest_bundle_version_monaihosting(name):
    url = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting"
    full_url = f"{url}/{name.lower()}"
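The two helpers above implement NGC's key-for-token exchange and an authenticated zip download. A minimal sketch of the same exchange with plain `requests` (endpoint and header format are taken from the diff; the key value is a placeholder):

    import requests  # the bundle code reaches this via optional_import("requests", name="get")

    api_key = "<NGC API key>"  # placeholder; normally read from the NGC_API_KEY env var
    resp = requests.get(
        "https://authn.nvidia.com/token?service=ngc",
        headers={"Accept": "application/json", "Authorization": "ApiKey " + api_key},
    )
    resp.raise_for_status()
    token = resp.json()["token"]
    # subsequent model/zip requests then carry: {"Authorization": f"Bearer {token}"}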
@@ -227,12 +274,28 @@ def _get_latest_bundle_version_monaihosting(name):
        resp = requests_get(full_url)
        resp.raise_for_status()
    else:
-         raise ValueError("NGC API requires requests package. Please install it.")
+         raise ValueError("NGC API requires requests package. Please install it.")
    model_info = json.loads(resp.text)
    return model_info["model"]["latestVersionIdStr"]


- def _get_latest_bundle_version(source: str, name: str, repo: str) -> dict[str, list[str] | str] | Any | None:
+ def _get_latest_bundle_version_private_registry(name, repo, headers=None):
+     url = f"https://api.ngc.nvidia.com/v2/{repo}/models"
+     full_url = f"{url}/{name.lower()}"
+     requests_get, has_requests = optional_import("requests", name="get")
+     if has_requests:
+         headers = {} if headers is None else headers
+         resp = requests_get(full_url, headers=headers)
+         resp.raise_for_status()
+     else:
+         raise ValueError("NGC API requires requests package. Please install it.")
+     model_info = json.loads(resp.text)
+     return model_info["model"]["latestVersionIdStr"]
+
+
+ def _get_latest_bundle_version(
+     source: str, name: str, repo: str, **kwargs: Any
+ ) -> dict[str, list[str] | str] | Any | None:
    if source == "ngc":
        name = _add_ngc_prefix(name)
        model_dict = _get_all_ngc_models(name)
@@ -242,6 +305,10 @@ def _get_latest_bundle_version(source: str, name: str, repo: str) -> dict[str, l
        return None
    elif source == "monaihosting":
        return _get_latest_bundle_version_monaihosting(name)
+     elif source == "ngc_private":
+         headers = kwargs.pop("headers", {})
+         name = _add_ngc_prefix(name)
+         return _get_latest_bundle_version_private_registry(name, repo, headers)
    elif source == "github":
        repo_owner, repo_name, tag_name = repo.split("/")
        return get_bundle_versions(name, repo=f"{repo_owner}/{repo_name}", tag=tag_name)["latest_version"]
@@ -308,6 +375,9 @@ def download(
        # Execute this module as a CLI entry, and download bundle via URL:
        python -m monai.bundle download --name <bundle_name> --url <url>

+         # Execute this module as a CLI entry, and download bundle from ngc_private with latest version:
+         python -m monai.bundle download --name <bundle_name> --source "ngc_private" --bundle_dir "./" --repo "org/org_name"
+
        # Set default args of `run` in a JSON / YAML file, help to record and simplify the command line.
        # Other args still can override the default args at runtime.
        # The content of the JSON / YAML file is a dictionary. For example:
@@ -328,10 +398,13 @@ def download(
            Default is `bundle` subfolder under `torch.hub.get_dir()`.
        source: storage location name. This argument is used when `url` is `None`.
            In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and
-             it should be "ngc", "monaihosting", "github", or "huggingface_hub".
+             it should be "ngc", "monaihosting", "github", "ngc_private", or "huggingface_hub".
+             If source is "ngc_private", you need specify the NGC_API_KEY in the environment variable.
        repo: repo name. This argument is used when `url` is `None` and `source` is "github" or "huggingface_hub".
            If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag".
            If `source` is "huggingface_hub", it should be in the form of "repo_owner/repo_name".
+             If `source` is "ngc_private", it should be in the form of "org/org_name" or "org/org_name/team/team_name",
+             or you can specify the environment variable NGC_ORG and NGC_TEAM.
        url: url to download the data. If not `None`, data will be downloaded directly
            and `source` will not be checked.
            If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`.
@@ -363,11 +436,18 @@ def download(

    bundle_dir_ = _process_bundle_dir(bundle_dir_)
    if repo_ is None:
-         repo_ = "Project-MONAI/model-zoo/hosting_storage_v1"
-     if len(repo_.split("/")) != 3 and source_ != "huggingface_hub":
-         raise ValueError("repo should be in the form of `repo_owner/repo_name/release_tag`.")
+         org_ = os.getenv("NGC_ORG", None)
+         team_ = os.getenv("NGC_TEAM", None)
+         if org_ is not None and source_ == "ngc_private":
+             repo_ = f"org/{org_}/team/{team_}" if team_ is not None else f"org/{org_}"
+         else:
+             repo_ = "Project-MONAI/model-zoo/hosting_storage_v1"
+     if len(repo_.split("/")) not in (2, 4) and source_ == "ngc_private":
+         raise ValueError(f"repo should be in the form of `org/org_name/team/team_name` or `org/org_name`, got {repo_}.")
+     if len(repo_.split("/")) != 3 and source_ == "github":
+         raise ValueError(f"repo should be in the form of `repo_owner/repo_name/release_tag`, got {repo_}.")
    elif len(repo_.split("/")) != 2 and source_ == "huggingface_hub":
-         raise ValueError("Hugging Face Hub repo should be in the form of `repo_owner/repo_name`")
+         raise ValueError(f"Hugging Face Hub repo should be in the form of `repo_owner/repo_name`, got {repo_}.")
    if url_ is not None:
        if name_ is not None:
            filepath = bundle_dir_ / f"{name_}.zip"
@@ -376,10 +456,19 @@ def download(
        download_url(url=url_, filepath=filepath, hash_val=None, progress=progress_)
        extractall(filepath=filepath, output_dir=bundle_dir_, has_base=True)
    else:
+         headers = {}
        if name_ is None:
            raise ValueError(f"To download from source: {source_}, `name` must be provided.")
+         if source == "ngc_private":
+             api_key = os.getenv("NGC_API_KEY", None)
+             if api_key is None:
+                 raise ValueError("API key is required for ngc_private source.")
+             else:
+                 token = _get_ngc_token(api_key)
+                 headers = {"Authorization": f"Bearer {token}"}
+
        if version_ is None:
-             version_ = _get_latest_bundle_version(source=source_, name=name_, repo=repo_)
+             version_ = _get_latest_bundle_version(source=source_, name=name_, repo=repo_, headers=headers)
        if source_ == "github":
            if version_ is not None:
                name_ = "_v".join([name_, version_])
@@ -394,6 +483,15 @@ def download(
            remove_prefix=remove_prefix_,
            progress=progress_,
        )
+     elif source_ == "ngc_private":
+         _download_from_ngc_private(
+             download_path=bundle_dir_,
+             filename=name_,
+             version=version_,
+             remove_prefix=remove_prefix_,
+             repo=repo_,
+             headers=headers,
+         )
    elif source_ == "huggingface_hub":
        extract_path = os.path.join(bundle_dir_, name_)
        huggingface_hub.snapshot_download(repo_id=repo_, revision=version_, local_dir=extract_path)
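Taken together, these changes wire `ngc_private` through `download`: the repo defaults from `NGC_ORG`/`NGC_TEAM`, the `NGC_API_KEY` key is exchanged for a bearer token, and the token authorizes both the version lookup and the zip download. A usage sketch; the bundle and org names are placeholders, not values from this diff:

    import os
    from monai.bundle import download

    os.environ["NGC_API_KEY"] = "<api key>"  # required; exchanged for a bearer token internally
    os.environ["NGC_ORG"] = "<org_name>"     # optional; otherwise pass repo= explicitly

    # repo accepts "org/<org_name>" or "org/<org_name>/team/<team_name>"
    download(name="<bundle_name>", source="ngc_private", bundle_dir="./")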
monai/bundle/utils.py CHANGED
@@ -221,6 +221,7 @@ def load_bundle_config(bundle_path: str, *config_names: str, **load_kw_args: Any
        raise ValueError(f"Cannot find config file '{full_cname}'")

    ardata = archive.read(full_cname)
+     cdata = {}

    if full_cname.lower().endswith("json"):
        cdata = json.loads(ardata, **load_kw_args)
monai/data/dataset_summary.py CHANGED
@@ -84,6 +84,7 @@ class DatasetSummary:
        """

        for data in self.data_loader:
+             meta_dict = {}
            if isinstance(data[self.image_key], MetaTensor):
                meta_dict = data[self.image_key].meta
            elif self.meta_key in data:
monai/data/utils.py CHANGED
@@ -53,10 +53,6 @@ from monai.utils import (
    pytorch_after,
)

- if pytorch_after(1, 13):
-     # import private code for reuse purposes, comment in case things break in the future
-     from torch.utils.data._utils.collate import collate_tensor_fn, default_collate_fn_map
-
pd, _ = optional_import("pandas")
DataFrame, _ = optional_import("pandas", name="DataFrame")
nib, _ = optional_import("nibabel")
@@ -454,8 +450,13 @@ def collate_meta_tensor_fn(batch, *, collate_fn_map=None):
    Collate a sequence of meta tensor into a single batched metatensor. This is called by `collage_meta_tensor`
    and so should not be used as a collate function directly in dataloaders.
    """
-     collate_fn = collate_tensor_fn if pytorch_after(1, 13) else default_collate
-     collated = collate_fn(batch)  # type: ignore
+     if pytorch_after(1, 13):
+         from torch.utils.data._utils.collate import collate_tensor_fn  # imported here for pylint/mypy issues
+
+         collated = collate_tensor_fn(batch)
+     else:
+         collated = default_collate(batch)
+
    meta_dicts = [i.meta or TraceKeys.NONE for i in batch]
    common_ = set.intersection(*[set(d.keys()) for d in meta_dicts if isinstance(d, dict)])
    if common_:
@@ -496,6 +497,8 @@ def list_data_collate(batch: Sequence):

    if pytorch_after(1, 13):
        # needs to go here to avoid circular import
+         from torch.utils.data._utils.collate import default_collate_fn_map
+
        from monai.data.meta_tensor import MetaTensor

        default_collate_fn_map.update({MetaTensor: collate_meta_tensor_fn})
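Both collate edits move the same private import (`torch.utils.data._utils.collate`) from module import time to call time, so importing `monai.data.utils` no longer breaks if that private module moves; caller-visible behavior is unchanged. A small sanity check, assuming PyTorch >= 1.13:

    import torch
    from monai.data import MetaTensor, list_data_collate

    batch = [{"img": MetaTensor(torch.zeros(1, 8, 8))} for _ in range(4)]
    out = list_data_collate(batch)  # dict batches still collate into a single MetaTensor
    print(type(out["img"]).__name__, tuple(out["img"].shape))  # MetaTensor (4, 1, 8, 8)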
monai/data/wsi_reader.py CHANGED
@@ -1097,8 +1097,8 @@ class TiffFileWSIReader(BaseWSIReader):
        ):
            unit = wsi.pages[level].tags.get("ResolutionUnit")
            if unit is not None:
-                 unit = str(unit.value)[8:]
-             else:
+                 unit = str(unit.value.name)
+             if unit is None or len(unit) == 0:
                warnings.warn("The resolution unit is missing. `micrometer` will be used as default.")
                unit = "micrometer"

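The previous code recovered the unit name by slicing the enum's string form, which assumes `str(member)` renders as `RESUNIT.CENTIMETER`; on Python 3.11+ `IntEnum.__str__` returns the plain integer value, so the slice yields an empty string. Reading `.name` is version-independent, and the new emptiness check keeps the `micrometer` fallback. A sketch with a stand-in enum (not tifffile's actual class):

    from enum import IntEnum

    class RESUNIT(IntEnum):  # stand-in for tifffile's ResolutionUnit tag value
        NONE = 1
        INCH = 2
        CENTIMETER = 3

    v = RESUNIT.CENTIMETER
    print(str(v)[8:])  # "CENTIMETER" on <=3.10; "" on 3.11+ where str(v) == "3"
    print(v.name)      # "CENTIMETER" on every version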
monai/engines/__init__.py CHANGED
@@ -12,12 +12,14 @@
from __future__ import annotations

from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator
- from .trainer import GanTrainer, SupervisedTrainer, Trainer
+ from .trainer import AdversarialTrainer, GanTrainer, SupervisedTrainer, Trainer
from .utils import (
+     DiffusionPrepareBatch,
    IterationEvents,
    PrepareBatch,
    PrepareBatchDefault,
    PrepareBatchExtraInput,
+     VPredictionPrepareBatch,
    default_make_latent,
    default_metric_cmp_fn,
    default_prepare_batch,
monai/engines/trainer.py CHANGED
@@ -24,7 +24,7 @@ from monai.engines.utils import IterationEvents, default_make_latent, default_me
from monai.engines.workflow import Workflow
from monai.inferers import Inferer, SimpleInferer
from monai.transforms import Transform
- from monai.utils import GanKeys, min_version, optional_import
+ from monai.utils import AdversarialIterationEvents, AdversarialKeys, GanKeys, min_version, optional_import
from monai.utils.enums import CommonKeys as Keys
from monai.utils.enums import EngineStatsKeys as ESKeys
from monai.utils.module import pytorch_after
@@ -37,7 +37,7 @@ else:
    Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric")
    EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum")

- __all__ = ["Trainer", "SupervisedTrainer", "GanTrainer"]
+ __all__ = ["Trainer", "SupervisedTrainer", "GanTrainer", "AdversarialTrainer"]


class Trainer(Workflow):
@@ -471,3 +471,282 @@ class GanTrainer(Trainer):
            GanKeys.GLOSS: g_loss.item(),
            GanKeys.DLOSS: d_total_loss.item(),
        }
+
+
+ class AdversarialTrainer(Trainer):
+     """
+     Standard supervised training workflow for adversarial loss enabled neural networks.
+
+     Args:
+         device: an object representing the device on which to run.
+         max_epochs: the total epoch number for engine to run.
+         train_data_loader: Core ignite engines uses `DataLoader` for training loop batchdata.
+         g_network: ''generator'' (G) network architecture.
+         g_optimizer: G optimizer function.
+         g_loss_function: G loss function for adversarial training.
+         recon_loss_function: G loss function for reconstructions.
+         d_network: discriminator (D) network architecture.
+         d_optimizer: D optimizer function.
+         d_loss_function: D loss function for adversarial training..
+         epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`.
+         non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to
+             the host. For other cases, this argument has no effect.
+         prepare_batch: function to parse image and label for current iteration.
+         iteration_update: the callable function for every iteration, expect to accept `engine` and `batchdata` as input
+             parameters. if not provided, use `self._iteration()` instead.
+         g_inferer: inference method to execute G model forward. Defaults to ``SimpleInferer()``.
+         d_inferer: inference method to execute D model forward. Defaults to ``SimpleInferer()``.
+         postprocessing: execute additional transformation for the model output data. Typically, several Tensor based
+             transforms composed by `Compose`. Defaults to None
+         key_train_metric: compute metric when every iteration completed, and save average value to engine.state.metrics
+             when epoch completed. key_train_metric is the main metric to compare and save the checkpoint into files.
+         additional_metrics: more Ignite metrics that also attach to Ignite Engine.
+         metric_cmp_fn: function to compare current key metric with previous best key metric value, it must accept 2 args
+             (current_metric, previous_best) and return a bool result: if `True`, will update 'best_metric` and
+             `best_metric_epoch` with current metric and epoch, default to `greater than`.
+         train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
+             CheckpointHandler, StatsHandler, etc.
+         amp: whether to enable auto-mixed-precision training, default is False.
+         event_names: additional custom ignite events that will register to the engine.
+             new events can be a list of str or `ignite.engine.events.EventEnum`.
+         event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
+             for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
+             #ignite.engine.engine.Engine.register_events.
+         decollate: whether to decollate the batch-first data to a list of data after model computation, recommend
+             `decollate=True` when `postprocessing` uses components from `monai.transforms`. default to `True`.
+         optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None.
+             more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
+         to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
+             `device`, `non_blocking`.
+         amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
+             https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
+     """
+
+     def __init__(
+         self,
+         device: torch.device | str,
+         max_epochs: int,
+         train_data_loader: Iterable | DataLoader,
+         g_network: torch.nn.Module,
+         g_optimizer: Optimizer,
+         g_loss_function: Callable,
+         recon_loss_function: Callable,
+         d_network: torch.nn.Module,
+         d_optimizer: Optimizer,
+         d_loss_function: Callable,
+         epoch_length: int | None = None,
+         non_blocking: bool = False,
+         prepare_batch: Callable = default_prepare_batch,
+         iteration_update: Callable | None = None,
+         g_inferer: Inferer | None = None,
+         d_inferer: Inferer | None = None,
+         postprocessing: Transform | None = None,
+         key_train_metric: dict[str, Metric] | None = None,
+         additional_metrics: dict[str, Metric] | None = None,
+         metric_cmp_fn: Callable = default_metric_cmp_fn,
+         train_handlers: Sequence | None = None,
+         amp: bool = False,
+         event_names: list[str | EventEnum | type[EventEnum]] | None = None,
+         event_to_attr: dict | None = None,
+         decollate: bool = True,
+         optim_set_to_none: bool = False,
+         to_kwargs: dict | None = None,
+         amp_kwargs: dict | None = None,
+     ):
+         super().__init__(
+             device=device,
+             max_epochs=max_epochs,
+             data_loader=train_data_loader,
+             epoch_length=epoch_length,
+             non_blocking=non_blocking,
+             prepare_batch=prepare_batch,
+             iteration_update=iteration_update,
+             postprocessing=postprocessing,
+             key_metric=key_train_metric,
+             additional_metrics=additional_metrics,
+             metric_cmp_fn=metric_cmp_fn,
+             handlers=train_handlers,
+             amp=amp,
+             event_names=event_names,
+             event_to_attr=event_to_attr,
+             decollate=decollate,
+             to_kwargs=to_kwargs,
+             amp_kwargs=amp_kwargs,
+         )
+
+         self.register_events(*AdversarialIterationEvents)
+
+         self.state.g_network = g_network
+         self.state.g_optimizer = g_optimizer
+         self.state.g_loss_function = g_loss_function
+         self.state.recon_loss_function = recon_loss_function
+
+         self.state.d_network = d_network
+         self.state.d_optimizer = d_optimizer
+         self.state.d_loss_function = d_loss_function
+
+         self.g_inferer = SimpleInferer() if g_inferer is None else g_inferer
+         self.d_inferer = SimpleInferer() if d_inferer is None else d_inferer
+
+         self.state.g_scaler = torch.cuda.amp.GradScaler() if self.amp else None
+         self.state.d_scaler = torch.cuda.amp.GradScaler() if self.amp else None
+
+         self.optim_set_to_none = optim_set_to_none
+         self._complete_state_dict_user_keys()
+
+     def _complete_state_dict_user_keys(self) -> None:
+         """
+         This method appends to the _state_dict_user_keys AdversarialTrainer's elements that are required for
+         checkpoint saving.
+
+         Follows the example found at:
+         https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html#ignite.engine.engine.Engine.state_dict
+         """
+         self._state_dict_user_keys.extend(
+             ["g_network", "g_optimizer", "d_network", "d_optimizer", "g_scaler", "d_scaler"]
+         )
+
+         g_loss_state_dict = getattr(self.state.g_loss_function, "state_dict", None)
+         if callable(g_loss_state_dict):
+             self._state_dict_user_keys.append("g_loss_function")
+
+         d_loss_state_dict = getattr(self.state.d_loss_function, "state_dict", None)
+         if callable(d_loss_state_dict):
+             self._state_dict_user_keys.append("d_loss_function")
+
+         recon_loss_state_dict = getattr(self.state.recon_loss_function, "state_dict", None)
+         if callable(recon_loss_state_dict):
+             self._state_dict_user_keys.append("recon_loss_function")
+
+     def _iteration(
+         self, engine: AdversarialTrainer, batchdata: dict[str, torch.Tensor]
+     ) -> dict[str, torch.Tensor | int | float | bool]:
+         """
+         Callback function for the Adversarial Training processing logic of 1 iteration in Ignite Engine.
+         Return below items in a dictionary:
+             - IMAGE: image Tensor data for model input, already moved to device.
+             - LABEL: label Tensor data corresponding to the image, already moved to device. In case of Unsupervised
+                 Learning this is equal to IMAGE.
+             - PRED: prediction result of model.
+             - LOSS: loss value computed by loss functions of the generator (reconstruction and adversarial summed up).
+             - AdversarialKeys.REALS: real images from the batch. Are the same as IMAGE.
+             - AdversarialKeys.FAKES: fake images generated by the generator. Are the same as PRED.
+             - AdversarialKeys.REAL_LOGITS: logits of the discriminator for the real images.
+             - AdversarialKeys.FAKE_LOGITS: logits of the discriminator for the fake images.
+             - AdversarialKeys.RECONSTRUCTION_LOSS: loss value computed by the reconstruction loss function.
+             - AdversarialKeys.GENERATOR_LOSS: loss value computed by the generator loss function. It is the
+                 discriminator loss for the fake images. That is backpropagated through the generator only.
+             - AdversarialKeys.DISCRIMINATOR_LOSS: loss value computed by the discriminator loss function. It is the
+                 discriminator loss for the real images and the fake images. That is backpropagated through the
+                 discriminator only.
+
+         Args:
+             engine: `AdversarialTrainer` to execute operation for an iteration.
+             batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.
+
+         Raises:
+             ValueError: must provide batch data for current iteration.
+
+         """
+
+         if batchdata is None:
+             raise ValueError("Must provide batch data for current iteration.")
+         batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)
+
+         if len(batch) == 2:
+             inputs, targets = batch
+             args: tuple = ()
+             kwargs: dict = {}
+         else:
+             inputs, targets, args, kwargs = batch
+
+         engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets, AdversarialKeys.REALS: inputs}
+
+         def _compute_generator_loss() -> None:
+             engine.state.output[AdversarialKeys.FAKES] = engine.g_inferer(
+                 inputs, engine.state.g_network, *args, **kwargs
+             )
+             engine.state.output[Keys.PRED] = engine.state.output[AdversarialKeys.FAKES]
+             engine.fire_event(AdversarialIterationEvents.GENERATOR_FORWARD_COMPLETED)
+
+             engine.state.output[AdversarialKeys.FAKE_LOGITS] = engine.d_inferer(
+                 engine.state.output[AdversarialKeys.FAKES].float().contiguous(), engine.state.d_network, *args, **kwargs
+             )
+             engine.fire_event(AdversarialIterationEvents.GENERATOR_DISCRIMINATOR_FORWARD_COMPLETED)
+
+             engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS] = engine.state.recon_loss_function(
+                 engine.state.output[AdversarialKeys.FAKES], targets
+             ).mean()
+             engine.fire_event(AdversarialIterationEvents.RECONSTRUCTION_LOSS_COMPLETED)
+
+             engine.state.output[AdversarialKeys.GENERATOR_LOSS] = engine.state.g_loss_function(
+                 engine.state.output[AdversarialKeys.FAKE_LOGITS]
+             ).mean()
+             engine.fire_event(AdversarialIterationEvents.GENERATOR_LOSS_COMPLETED)
+
+         # Train Generator
+         engine.state.g_network.train()
+         engine.state.g_optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
+
+         if engine.amp and engine.state.g_scaler is not None:
+             with torch.cuda.amp.autocast(**engine.amp_kwargs):
+                 _compute_generator_loss()
+
+             engine.state.output[Keys.LOSS] = (
+                 engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS]
+                 + engine.state.output[AdversarialKeys.GENERATOR_LOSS]
+             )
+             engine.state.g_scaler.scale(engine.state.output[Keys.LOSS]).backward()
+             engine.fire_event(AdversarialIterationEvents.GENERATOR_BACKWARD_COMPLETED)
+             engine.state.g_scaler.step(engine.state.g_optimizer)
+             engine.state.g_scaler.update()
+         else:
+             _compute_generator_loss()
+             (
+                 engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS]
+                 + engine.state.output[AdversarialKeys.GENERATOR_LOSS]
+             ).backward()
+             engine.fire_event(AdversarialIterationEvents.GENERATOR_BACKWARD_COMPLETED)
+             engine.state.g_optimizer.step()
+         engine.fire_event(AdversarialIterationEvents.GENERATOR_MODEL_COMPLETED)
+
+         def _compute_discriminator_loss() -> None:
+             engine.state.output[AdversarialKeys.REAL_LOGITS] = engine.d_inferer(
+                 engine.state.output[AdversarialKeys.REALS].contiguous().detach(),
+                 engine.state.d_network,
+                 *args,
+                 **kwargs,
+             )
+             engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_REALS_FORWARD_COMPLETED)
+
+             engine.state.output[AdversarialKeys.FAKE_LOGITS] = engine.d_inferer(
+                 engine.state.output[AdversarialKeys.FAKES].contiguous().detach(),
+                 engine.state.d_network,
+                 *args,
+                 **kwargs,
+             )
+             engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_FAKES_FORWARD_COMPLETED)
+
+             engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS] = engine.state.d_loss_function(
+                 engine.state.output[AdversarialKeys.REAL_LOGITS], engine.state.output[AdversarialKeys.FAKE_LOGITS]
+             ).mean()
+             engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_LOSS_COMPLETED)
+
+         # Train Discriminator
+         engine.state.d_network.train()
+         engine.state.d_network.zero_grad(set_to_none=engine.optim_set_to_none)
+
+         if engine.amp and engine.state.d_scaler is not None:
+             with torch.cuda.amp.autocast(**engine.amp_kwargs):
+                 _compute_discriminator_loss()
+
+             engine.state.d_scaler.scale(engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS]).backward()
+             engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_BACKWARD_COMPLETED)
+             engine.state.d_scaler.step(engine.state.d_optimizer)
+             engine.state.d_scaler.update()
+         else:
+             _compute_discriminator_loss()
+             engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS].backward()
+             engine.state.d_optimizer.step()
+
+         return engine.state.output
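For orientation, a minimal construction sketch for the new `AdversarialTrainer`. The networks, losses, and data are illustrative stand-ins, not values from this diff; it assumes `pytorch-ignite` is installed, and the loss callables only need to return tensors that `.mean()` can reduce:

    import torch
    from torch.utils.data import DataLoader
    from monai.engines import AdversarialTrainer

    g_net = torch.nn.Conv2d(1, 1, 3, padding=1)  # stand-in reconstruction generator
    d_net = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(64, 1))  # stand-in discriminator

    # the default prepare_batch expects dict batches with "image"/"label" keys
    data = [{"image": torch.rand(1, 8, 8), "label": torch.rand(1, 8, 8)} for _ in range(8)]

    trainer = AdversarialTrainer(
        device=torch.device("cpu"),
        max_epochs=1,
        train_data_loader=DataLoader(data, batch_size=2),
        g_network=g_net,
        g_optimizer=torch.optim.Adam(g_net.parameters(), 1e-3),
        g_loss_function=lambda fake_logits: torch.nn.functional.softplus(-fake_logits),
        recon_loss_function=torch.nn.L1Loss(),
        d_network=d_net,
        d_optimizer=torch.optim.Adam(d_net.parameters(), 1e-3),
        d_loss_function=lambda real, fake: torch.nn.functional.softplus(-real) + torch.nn.functional.softplus(fake),
    )
    trainer.run()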