monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2429__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/detection/utils/anchor_utils.py +2 -2
- monai/apps/pathology/transforms/post/array.py +1 -0
- monai/bundle/scripts.py +106 -8
- monai/bundle/utils.py +1 -0
- monai/data/dataset_summary.py +1 -0
- monai/data/utils.py +9 -6
- monai/data/wsi_reader.py +2 -2
- monai/engines/__init__.py +3 -1
- monai/engines/trainer.py +281 -2
- monai/engines/utils.py +76 -1
- monai/handlers/mlflow_handler.py +21 -4
- monai/inferers/__init__.py +5 -0
- monai/inferers/inferer.py +1279 -1
- monai/networks/blocks/__init__.py +3 -0
- monai/networks/blocks/attention_utils.py +128 -0
- monai/networks/blocks/crossattention.py +166 -0
- monai/networks/blocks/rel_pos_embedding.py +56 -0
- monai/networks/blocks/selfattention.py +72 -5
- monai/networks/blocks/spade_norm.py +95 -0
- monai/networks/blocks/spatialattention.py +82 -0
- monai/networks/blocks/transformerblock.py +24 -4
- monai/networks/blocks/upsample.py +22 -10
- monai/networks/layers/__init__.py +2 -1
- monai/networks/layers/factories.py +12 -1
- monai/networks/layers/utils.py +14 -1
- monai/networks/layers/vector_quantizer.py +233 -0
- monai/networks/nets/__init__.py +9 -0
- monai/networks/nets/autoencoderkl.py +702 -0
- monai/networks/nets/controlnet.py +465 -0
- monai/networks/nets/diffusion_model_unet.py +1913 -0
- monai/networks/nets/patchgan_discriminator.py +230 -0
- monai/networks/nets/quicknat.py +2 -0
- monai/networks/nets/resnet.py +3 -4
- monai/networks/nets/spade_autoencoderkl.py +480 -0
- monai/networks/nets/spade_diffusion_model_unet.py +934 -0
- monai/networks/nets/spade_network.py +435 -0
- monai/networks/nets/swin_unetr.py +4 -3
- monai/networks/nets/transformer.py +157 -0
- monai/networks/nets/vqvae.py +472 -0
- monai/networks/schedulers/__init__.py +17 -0
- monai/networks/schedulers/ddim.py +294 -0
- monai/networks/schedulers/ddpm.py +250 -0
- monai/networks/schedulers/pndm.py +316 -0
- monai/networks/schedulers/scheduler.py +205 -0
- monai/networks/utils.py +22 -0
- monai/transforms/regularization/array.py +4 -0
- monai/transforms/utils_create_transform_ims.py +2 -4
- monai/utils/__init__.py +1 -0
- monai/utils/misc.py +5 -4
- monai/utils/ordering.py +207 -0
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2429.dist-info}/METADATA +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2429.dist-info}/RECORD +57 -36
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2429.dist-info}/WHEEL +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2429.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2429.dist-info}/top_level.txt +0 -0
monai/__init__.py
CHANGED
monai/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json
|
|
8
8
|
|
9
9
|
version_json = '''
|
10
10
|
{
|
11
|
-
"date": "2024-07-
|
11
|
+
"date": "2024-07-21T02:18:43+0000",
|
12
12
|
"dirty": false,
|
13
13
|
"error": null,
|
14
|
-
"full-revisionid": "
|
15
|
-
"version": "1.4.
|
14
|
+
"full-revisionid": "023827f344ab6236402ad609f0e938c308f68ca0",
|
15
|
+
"version": "1.4.dev2429"
|
16
16
|
}
|
17
17
|
''' # END VERSION_JSON
|
18
18
|
|
@@ -189,7 +189,7 @@ class AnchorGenerator(nn.Module):
|
|
189
189
|
w_ratios = 1 / area_scale
|
190
190
|
h_ratios = area_scale
|
191
191
|
# if 3d, w:h:d = 1:aspect_ratios[:,0]:aspect_ratios[:,1]
|
192
|
-
|
192
|
+
else:
|
193
193
|
area_scale = torch.pow(aspect_ratios_t[:, 0] * aspect_ratios_t[:, 1], 1 / 3.0)
|
194
194
|
w_ratios = 1 / area_scale
|
195
195
|
h_ratios = aspect_ratios_t[:, 0] / area_scale
|
@@ -199,7 +199,7 @@ class AnchorGenerator(nn.Module):
|
|
199
199
|
hs = (h_ratios[:, None] * scales_t[None, :]).view(-1)
|
200
200
|
if self.spatial_dims == 2:
|
201
201
|
base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2.0
|
202
|
-
elif self.spatial_dims == 3:
|
202
|
+
else: # elif self.spatial_dims == 3:
|
203
203
|
ds = (d_ratios[:, None] * scales_t[None, :]).view(-1)
|
204
204
|
base_anchors = torch.stack([-ws, -hs, -ds, ws, hs, ds], dim=1) / 2.0
|
205
205
|
|
monai/bundle/scripts.py
CHANGED
@@ -16,6 +16,7 @@ import json
|
|
16
16
|
import os
|
17
17
|
import re
|
18
18
|
import warnings
|
19
|
+
import zipfile
|
19
20
|
from collections.abc import Mapping, Sequence
|
20
21
|
from pathlib import Path
|
21
22
|
from pydoc import locate
|
@@ -171,6 +172,10 @@ def _get_ngc_bundle_url(model_name: str, version: str) -> str:
|
|
171
172
|
return f"https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit/{model_name.lower()}/versions/{version}/zip"
|
172
173
|
|
173
174
|
|
175
|
+
def _get_ngc_private_bundle_url(model_name: str, version: str, repo: str) -> str:
|
176
|
+
return f"https://api.ngc.nvidia.com/v2/{repo}/models/{model_name.lower()}/versions/{version}/zip"
|
177
|
+
|
178
|
+
|
174
179
|
def _get_monaihosting_bundle_url(model_name: str, version: str) -> str:
|
175
180
|
monaihosting_root_path = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting"
|
176
181
|
return f"{monaihosting_root_path}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip"
|
@@ -219,6 +224,48 @@ def _download_from_ngc(
|
|
219
224
|
extractall(filepath=filepath, output_dir=extract_path, has_base=True)
|
220
225
|
|
221
226
|
|
227
|
+
def _download_from_ngc_private(
    download_path: Path, filename: str, version: str, remove_prefix: str | None, repo: str, headers: dict | None = None
) -> None:
    """
    Download a bundle zip from an NGC private registry and extract it under ``download_path``.

    The archive is written to ``<download_path>/<prefixed name>_v<version>.zip`` and unpacked into
    ``<download_path>/<name>``, where ``<name>`` is the prefixed name with ``remove_prefix`` stripped when
    ``remove_prefix`` is truthy.

    Args:
        download_path: directory that receives both the zip archive and the extracted bundle.
        filename: bundle name; ``_add_ngc_prefix`` is applied so the NGC prefix is present.
        version: bundle version to download.
        remove_prefix: prefix removed from the extraction folder name; falsy keeps the prefixed name.
        repo: registry path, e.g. ``org/<org_name>`` or ``org/<org_name>/team/<team_name>``.
        headers: optional HTTP headers (such as a bearer-token authorization) sent with the request.

    Raises:
        ValueError: when the ``requests`` package is not available.
    """
    ngc_name = _add_ngc_prefix(filename)
    bundle_url = _get_ngc_private_bundle_url(model_name=ngc_name, version=version, repo=repo)
    if not has_requests:
        raise ValueError("NGC API requires requests package. Please install it.")
    response = requests_get(bundle_url, headers=headers if headers is not None else {})
    response.raise_for_status()

    archive_path = download_path / f"{ngc_name}_v{version}.zip"
    with open(archive_path, "wb") as archive_file:
        archive_file.write(response.content)
    logger.info(f"Downloading: {archive_path}.")

    # the extraction folder optionally drops the NGC prefix so it matches the user-facing bundle name
    folder_name = _remove_ngc_prefix(ngc_name, prefix=remove_prefix) if remove_prefix else ngc_name
    extract_path = download_path / f"{folder_name}"
    with zipfile.ZipFile(archive_path, "r") as archive:
        archive.extractall(extract_path)
        logger.info(f"Writing into directory: {extract_path}.")
|
250
|
+
|
251
|
+
|
252
|
+
def _get_ngc_token(api_key: str, retry: int = 0) -> str:
    """
    Exchange an NGC API key for an access token from the NGC authentication service.

    The request is retried on non-OK responses until ``retry`` reaches 3. Starting from the default
    ``retry=0``, that means at most 4 requests in total.

    Args:
        api_key: NGC API key, sent as ``Authorization: ApiKey <api_key>``.
        retry: number of attempts already made; callers normally leave the default.

    Returns:
        The value of the ``"token"`` field of the JSON response.

    Raises:
        ValueError: when the ``requests`` package is not available.
        RuntimeError: when the service still answers with a non-OK status after the retries.
    """
    if not has_requests:
        # previously this fell through and returned None, which callers turned into "Bearer None"
        raise ValueError("NGC API requires requests package. Please install it.")
    url = "https://authn.nvidia.com/token?service=ngc"
    headers = {"Accept": "application/json", "Authorization": "ApiKey " + api_key}
    while True:
        response = requests_get(url, headers=headers)
        if response.ok:
            return response.json()["token"]
        if retry >= 3:
            raise RuntimeError("NGC API response is not ok. Failed to get token.")
        logger.info(f"Retrying {retry} time(s) to GET {url}.")
        # retry with the same api_key (the old recursive call mistakenly passed `url` as the key)
        retry += 1
|
267
|
+
|
268
|
+
|
222
269
|
def _get_latest_bundle_version_monaihosting(name):
|
223
270
|
url = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting"
|
224
271
|
full_url = f"{url}/{name.lower()}"
|
@@ -227,12 +274,28 @@ def _get_latest_bundle_version_monaihosting(name):
|
|
227
274
|
resp = requests_get(full_url)
|
228
275
|
resp.raise_for_status()
|
229
276
|
else:
|
230
|
-
raise ValueError("NGC API requires requests package.
|
277
|
+
raise ValueError("NGC API requires requests package. Please install it.")
|
231
278
|
model_info = json.loads(resp.text)
|
232
279
|
return model_info["model"]["latestVersionIdStr"]
|
233
280
|
|
234
281
|
|
235
|
-
def
|
282
|
+
def _get_latest_bundle_version_private_registry(name, repo, headers=None):
    """
    Query an NGC private registry for the latest version of a bundle.

    Args:
        name: bundle name; it is lower-cased in the request URL.
        repo: registry path, e.g. ``org/<org_name>`` or ``org/<org_name>/team/<team_name>``.
        headers: optional HTTP headers (such as a bearer-token authorization) sent with the request.

    Returns:
        The ``model.latestVersionIdStr`` field of the registry's JSON response.

    Raises:
        ValueError: when the ``requests`` package is not available.
    """
    model_url = f"https://api.ngc.nvidia.com/v2/{repo}/models/{name.lower()}"
    requests_get, has_requests = optional_import("requests", name="get")
    if not has_requests:
        raise ValueError("NGC API requires requests package. Please install it.")
    resp = requests_get(model_url, headers={} if headers is None else headers)
    resp.raise_for_status()
    payload = json.loads(resp.text)
    return payload["model"]["latestVersionIdStr"]
|
294
|
+
|
295
|
+
|
296
|
+
def _get_latest_bundle_version(
|
297
|
+
source: str, name: str, repo: str, **kwargs: Any
|
298
|
+
) -> dict[str, list[str] | str] | Any | None:
|
236
299
|
if source == "ngc":
|
237
300
|
name = _add_ngc_prefix(name)
|
238
301
|
model_dict = _get_all_ngc_models(name)
|
@@ -242,6 +305,10 @@ def _get_latest_bundle_version(source: str, name: str, repo: str) -> dict[str, l
|
|
242
305
|
return None
|
243
306
|
elif source == "monaihosting":
|
244
307
|
return _get_latest_bundle_version_monaihosting(name)
|
308
|
+
elif source == "ngc_private":
|
309
|
+
headers = kwargs.pop("headers", {})
|
310
|
+
name = _add_ngc_prefix(name)
|
311
|
+
return _get_latest_bundle_version_private_registry(name, repo, headers)
|
245
312
|
elif source == "github":
|
246
313
|
repo_owner, repo_name, tag_name = repo.split("/")
|
247
314
|
return get_bundle_versions(name, repo=f"{repo_owner}/{repo_name}", tag=tag_name)["latest_version"]
|
@@ -308,6 +375,9 @@ def download(
|
|
308
375
|
# Execute this module as a CLI entry, and download bundle via URL:
|
309
376
|
python -m monai.bundle download --name <bundle_name> --url <url>
|
310
377
|
|
378
|
+
# Execute this module as a CLI entry, and download bundle from ngc_private with latest version:
|
379
|
+
python -m monai.bundle download --name <bundle_name> --source "ngc_private" --bundle_dir "./" --repo "org/org_name"
|
380
|
+
|
311
381
|
# Set default args of `run` in a JSON / YAML file, help to record and simplify the command line.
|
312
382
|
# Other args still can override the default args at runtime.
|
313
383
|
# The content of the JSON / YAML file is a dictionary. For example:
|
@@ -328,10 +398,13 @@ def download(
|
|
328
398
|
Default is `bundle` subfolder under `torch.hub.get_dir()`.
|
329
399
|
source: storage location name. This argument is used when `url` is `None`.
|
330
400
|
In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and
|
331
|
-
it should be "ngc", "monaihosting", "github", or "huggingface_hub".
|
401
|
+
it should be "ngc", "monaihosting", "github", "ngc_private", or "huggingface_hub".
|
402
|
+
If source is "ngc_private", you need specify the NGC_API_KEY in the environment variable.
|
332
403
|
repo: repo name. This argument is used when `url` is `None` and `source` is "github" or "huggingface_hub".
|
333
404
|
If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag".
|
334
405
|
If `source` is "huggingface_hub", it should be in the form of "repo_owner/repo_name".
|
406
|
+
If `source` is "ngc_private", it should be in the form of "org/org_name" or "org/org_name/team/team_name",
|
407
|
+
or you can specify the environment variable NGC_ORG and NGC_TEAM.
|
335
408
|
url: url to download the data. If not `None`, data will be downloaded directly
|
336
409
|
and `source` will not be checked.
|
337
410
|
If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`.
|
@@ -363,11 +436,18 @@ def download(
|
|
363
436
|
|
364
437
|
bundle_dir_ = _process_bundle_dir(bundle_dir_)
|
365
438
|
if repo_ is None:
|
366
|
-
|
367
|
-
|
368
|
-
|
439
|
+
org_ = os.getenv("NGC_ORG", None)
|
440
|
+
team_ = os.getenv("NGC_TEAM", None)
|
441
|
+
if org_ is not None and source_ == "ngc_private":
|
442
|
+
repo_ = f"org/{org_}/team/{team_}" if team_ is not None else f"org/{org_}"
|
443
|
+
else:
|
444
|
+
repo_ = "Project-MONAI/model-zoo/hosting_storage_v1"
|
445
|
+
if len(repo_.split("/")) not in (2, 4) and source_ == "ngc_private":
|
446
|
+
raise ValueError(f"repo should be in the form of `org/org_name/team/team_name` or `org/org_name`, got {repo_}.")
|
447
|
+
if len(repo_.split("/")) != 3 and source_ == "github":
|
448
|
+
raise ValueError(f"repo should be in the form of `repo_owner/repo_name/release_tag`, got {repo_}.")
|
369
449
|
elif len(repo_.split("/")) != 2 and source_ == "huggingface_hub":
|
370
|
-
raise ValueError("Hugging Face Hub repo should be in the form of `repo_owner/repo_name
|
450
|
+
raise ValueError(f"Hugging Face Hub repo should be in the form of `repo_owner/repo_name`, got {repo_}.")
|
371
451
|
if url_ is not None:
|
372
452
|
if name_ is not None:
|
373
453
|
filepath = bundle_dir_ / f"{name_}.zip"
|
@@ -376,10 +456,19 @@ def download(
|
|
376
456
|
download_url(url=url_, filepath=filepath, hash_val=None, progress=progress_)
|
377
457
|
extractall(filepath=filepath, output_dir=bundle_dir_, has_base=True)
|
378
458
|
else:
|
459
|
+
headers = {}
|
379
460
|
if name_ is None:
|
380
461
|
raise ValueError(f"To download from source: {source_}, `name` must be provided.")
|
462
|
+
if source == "ngc_private":
|
463
|
+
api_key = os.getenv("NGC_API_KEY", None)
|
464
|
+
if api_key is None:
|
465
|
+
raise ValueError("API key is required for ngc_private source.")
|
466
|
+
else:
|
467
|
+
token = _get_ngc_token(api_key)
|
468
|
+
headers = {"Authorization": f"Bearer {token}"}
|
469
|
+
|
381
470
|
if version_ is None:
|
382
|
-
version_ = _get_latest_bundle_version(source=source_, name=name_, repo=repo_)
|
471
|
+
version_ = _get_latest_bundle_version(source=source_, name=name_, repo=repo_, headers=headers)
|
383
472
|
if source_ == "github":
|
384
473
|
if version_ is not None:
|
385
474
|
name_ = "_v".join([name_, version_])
|
@@ -394,6 +483,15 @@ def download(
|
|
394
483
|
remove_prefix=remove_prefix_,
|
395
484
|
progress=progress_,
|
396
485
|
)
|
486
|
+
elif source_ == "ngc_private":
|
487
|
+
_download_from_ngc_private(
|
488
|
+
download_path=bundle_dir_,
|
489
|
+
filename=name_,
|
490
|
+
version=version_,
|
491
|
+
remove_prefix=remove_prefix_,
|
492
|
+
repo=repo_,
|
493
|
+
headers=headers,
|
494
|
+
)
|
397
495
|
elif source_ == "huggingface_hub":
|
398
496
|
extract_path = os.path.join(bundle_dir_, name_)
|
399
497
|
huggingface_hub.snapshot_download(repo_id=repo_, revision=version_, local_dir=extract_path)
|
monai/bundle/utils.py
CHANGED
@@ -221,6 +221,7 @@ def load_bundle_config(bundle_path: str, *config_names: str, **load_kw_args: Any
|
|
221
221
|
raise ValueError(f"Cannot find config file '{full_cname}'")
|
222
222
|
|
223
223
|
ardata = archive.read(full_cname)
|
224
|
+
cdata = {}
|
224
225
|
|
225
226
|
if full_cname.lower().endswith("json"):
|
226
227
|
cdata = json.loads(ardata, **load_kw_args)
|
monai/data/dataset_summary.py
CHANGED
monai/data/utils.py
CHANGED
@@ -53,10 +53,6 @@ from monai.utils import (
|
|
53
53
|
pytorch_after,
|
54
54
|
)
|
55
55
|
|
56
|
-
if pytorch_after(1, 13):
|
57
|
-
# import private code for reuse purposes, comment in case things break in the future
|
58
|
-
from torch.utils.data._utils.collate import collate_tensor_fn, default_collate_fn_map
|
59
|
-
|
60
56
|
pd, _ = optional_import("pandas")
|
61
57
|
DataFrame, _ = optional_import("pandas", name="DataFrame")
|
62
58
|
nib, _ = optional_import("nibabel")
|
@@ -454,8 +450,13 @@ def collate_meta_tensor_fn(batch, *, collate_fn_map=None):
|
|
454
450
|
Collate a sequence of meta tensor into a single batched metatensor. This is called by `collage_meta_tensor`
|
455
451
|
and so should not be used as a collate function directly in dataloaders.
|
456
452
|
"""
|
457
|
-
|
458
|
-
|
453
|
+
if pytorch_after(1, 13):
|
454
|
+
from torch.utils.data._utils.collate import collate_tensor_fn # imported here for pylint/mypy issues
|
455
|
+
|
456
|
+
collated = collate_tensor_fn(batch)
|
457
|
+
else:
|
458
|
+
collated = default_collate(batch)
|
459
|
+
|
459
460
|
meta_dicts = [i.meta or TraceKeys.NONE for i in batch]
|
460
461
|
common_ = set.intersection(*[set(d.keys()) for d in meta_dicts if isinstance(d, dict)])
|
461
462
|
if common_:
|
@@ -496,6 +497,8 @@ def list_data_collate(batch: Sequence):
|
|
496
497
|
|
497
498
|
if pytorch_after(1, 13):
|
498
499
|
# needs to go here to avoid circular import
|
500
|
+
from torch.utils.data._utils.collate import default_collate_fn_map
|
501
|
+
|
499
502
|
from monai.data.meta_tensor import MetaTensor
|
500
503
|
|
501
504
|
default_collate_fn_map.update({MetaTensor: collate_meta_tensor_fn})
|
monai/data/wsi_reader.py
CHANGED
@@ -1097,8 +1097,8 @@ class TiffFileWSIReader(BaseWSIReader):
|
|
1097
1097
|
):
|
1098
1098
|
unit = wsi.pages[level].tags.get("ResolutionUnit")
|
1099
1099
|
if unit is not None:
|
1100
|
-
unit = str(unit.value)
|
1101
|
-
|
1100
|
+
unit = str(unit.value.name)
|
1101
|
+
if unit is None or len(unit) == 0:
|
1102
1102
|
warnings.warn("The resolution unit is missing. `micrometer` will be used as default.")
|
1103
1103
|
unit = "micrometer"
|
1104
1104
|
|
monai/engines/__init__.py
CHANGED
@@ -12,12 +12,14 @@
|
|
12
12
|
from __future__ import annotations
|
13
13
|
|
14
14
|
from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator
|
15
|
-
from .trainer import GanTrainer, SupervisedTrainer, Trainer
|
15
|
+
from .trainer import AdversarialTrainer, GanTrainer, SupervisedTrainer, Trainer
|
16
16
|
from .utils import (
|
17
|
+
DiffusionPrepareBatch,
|
17
18
|
IterationEvents,
|
18
19
|
PrepareBatch,
|
19
20
|
PrepareBatchDefault,
|
20
21
|
PrepareBatchExtraInput,
|
22
|
+
VPredictionPrepareBatch,
|
21
23
|
default_make_latent,
|
22
24
|
default_metric_cmp_fn,
|
23
25
|
default_prepare_batch,
|
monai/engines/trainer.py
CHANGED
@@ -24,7 +24,7 @@ from monai.engines.utils import IterationEvents, default_make_latent, default_me
|
|
24
24
|
from monai.engines.workflow import Workflow
|
25
25
|
from monai.inferers import Inferer, SimpleInferer
|
26
26
|
from monai.transforms import Transform
|
27
|
-
from monai.utils import GanKeys, min_version, optional_import
|
27
|
+
from monai.utils import AdversarialIterationEvents, AdversarialKeys, GanKeys, min_version, optional_import
|
28
28
|
from monai.utils.enums import CommonKeys as Keys
|
29
29
|
from monai.utils.enums import EngineStatsKeys as ESKeys
|
30
30
|
from monai.utils.module import pytorch_after
|
@@ -37,7 +37,7 @@ else:
|
|
37
37
|
Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric")
|
38
38
|
EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum")
|
39
39
|
|
40
|
-
__all__ = ["Trainer", "SupervisedTrainer", "GanTrainer"]
|
40
|
+
__all__ = ["Trainer", "SupervisedTrainer", "GanTrainer", "AdversarialTrainer"]
|
41
41
|
|
42
42
|
|
43
43
|
class Trainer(Workflow):
|
@@ -471,3 +471,282 @@ class GanTrainer(Trainer):
|
|
471
471
|
GanKeys.GLOSS: g_loss.item(),
|
472
472
|
GanKeys.DLOSS: d_total_loss.item(),
|
473
473
|
}
|
474
|
+
|
475
|
+
|
476
|
+
class AdversarialTrainer(Trainer):
    """
    Standard supervised training workflow for adversarial loss enabled neural networks.

    Every iteration runs two optimisation steps on the same batch:
    1. The generator (G) is updated with the sum of a reconstruction loss and an adversarial loss. The
       adversarial loss is computed from the discriminator's (D) logits on the generated images.
    2. D is updated with its own loss on the real images and the detached generated images.

    Args:
        device: an object representing the device on which to run.
        max_epochs: the total epoch number for engine to run.
        train_data_loader: Core ignite engines uses `DataLoader` for training loop batchdata.
        g_network: "generator" (G) network architecture.
        g_optimizer: G optimizer function.
        g_loss_function: G loss function for adversarial training. It is called with D's logits on the
            generated images.
        recon_loss_function: G loss function for reconstructions. It is called with the generated images and
            the targets.
        d_network: discriminator (D) network architecture.
        d_optimizer: D optimizer function.
        d_loss_function: D loss function for adversarial training. It is called with D's logits on the real
            images and on the generated images.
        epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to
            the host. For other cases, this argument has no effect.
        prepare_batch: function to parse image and label for current iteration. It must return either
            ``(inputs, targets)`` or ``(inputs, targets, args, kwargs)``. The ``args``/``kwargs`` are
            forwarded to both inferers.
        iteration_update: the callable function for every iteration, expect to accept `engine` and `batchdata` as input
            parameters. if not provided, use `self._iteration()` instead.
        g_inferer: inference method to execute G model forward. Defaults to ``SimpleInferer()``.
        d_inferer: inference method to execute D model forward. Defaults to ``SimpleInferer()``.
        postprocessing: execute additional transformation for the model output data. Typically, several Tensor based
            transforms composed by `Compose`. Defaults to None
        key_train_metric: compute metric when every iteration completed, and save average value to engine.state.metrics
            when epoch completed. key_train_metric is the main metric to compare and save the checkpoint into files.
        additional_metrics: more Ignite metrics that also attach to Ignite Engine.
        metric_cmp_fn: function to compare current key metric with previous best key metric value, it must accept 2 args
            (current_metric, previous_best) and return a bool result: if `True`, will update `best_metric` and
            `best_metric_epoch` with current metric and epoch, default to `greater than`.
        train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
            CheckpointHandler, StatsHandler, etc.
        amp: whether to enable auto-mixed-precision training, default is False.
        event_names: additional custom ignite events that will register to the engine.
            new events can be a list of str or `ignite.engine.events.EventEnum`.
        event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
            for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
            #ignite.engine.engine.Engine.register_events.
        decollate: whether to decollate the batch-first data to a list of data after model computation, recommend
            `decollate=True` when `postprocessing` uses components from `monai.transforms`. default to `True`.
        optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None.
            more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
        to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
            `device`, `non_blocking`.
        amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
    """

    def __init__(
        self,
        device: torch.device | str,
        max_epochs: int,
        train_data_loader: Iterable | DataLoader,
        g_network: torch.nn.Module,
        g_optimizer: Optimizer,
        g_loss_function: Callable,
        recon_loss_function: Callable,
        d_network: torch.nn.Module,
        d_optimizer: Optimizer,
        d_loss_function: Callable,
        epoch_length: int | None = None,
        non_blocking: bool = False,
        prepare_batch: Callable = default_prepare_batch,
        iteration_update: Callable | None = None,
        g_inferer: Inferer | None = None,
        d_inferer: Inferer | None = None,
        postprocessing: Transform | None = None,
        key_train_metric: dict[str, Metric] | None = None,
        additional_metrics: dict[str, Metric] | None = None,
        metric_cmp_fn: Callable = default_metric_cmp_fn,
        train_handlers: Sequence | None = None,
        amp: bool = False,
        event_names: list[str | EventEnum | type[EventEnum]] | None = None,
        event_to_attr: dict | None = None,
        decollate: bool = True,
        optim_set_to_none: bool = False,
        to_kwargs: dict | None = None,
        amp_kwargs: dict | None = None,
    ):
        super().__init__(
            device=device,
            max_epochs=max_epochs,
            data_loader=train_data_loader,
            epoch_length=epoch_length,
            non_blocking=non_blocking,
            prepare_batch=prepare_batch,
            iteration_update=iteration_update,
            postprocessing=postprocessing,
            key_metric=key_train_metric,
            additional_metrics=additional_metrics,
            metric_cmp_fn=metric_cmp_fn,
            handlers=train_handlers,
            amp=amp,
            event_names=event_names,
            event_to_attr=event_to_attr,
            decollate=decollate,
            to_kwargs=to_kwargs,
            amp_kwargs=amp_kwargs,
        )

        # make the fine-grained adversarial events available to handlers
        self.register_events(*AdversarialIterationEvents)

        # networks/optimizers/losses live on `engine.state` so they can be checkpointed via state_dict
        self.state.g_network = g_network
        self.state.g_optimizer = g_optimizer
        self.state.g_loss_function = g_loss_function
        self.state.recon_loss_function = recon_loss_function

        self.state.d_network = d_network
        self.state.d_optimizer = d_optimizer
        self.state.d_loss_function = d_loss_function

        self.g_inferer = SimpleInferer() if g_inferer is None else g_inferer
        self.d_inferer = SimpleInferer() if d_inferer is None else d_inferer

        # one gradient scaler per optimiser, only when AMP is enabled
        self.state.g_scaler = torch.cuda.amp.GradScaler() if self.amp else None
        self.state.d_scaler = torch.cuda.amp.GradScaler() if self.amp else None

        self.optim_set_to_none = optim_set_to_none
        self._complete_state_dict_user_keys()

    def _complete_state_dict_user_keys(self) -> None:
        """
        Append to `_state_dict_user_keys` the AdversarialTrainer elements that are required for
        checkpoint saving.

        Loss functions are only added when they expose a callable ``state_dict``.

        Follows the example found at:
        https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html#ignite.engine.engine.Engine.state_dict
        """
        self._state_dict_user_keys.extend(
            ["g_network", "g_optimizer", "d_network", "d_optimizer", "g_scaler", "d_scaler"]
        )

        for loss_key in ("g_loss_function", "d_loss_function", "recon_loss_function"):
            if callable(getattr(getattr(self.state, loss_key), "state_dict", None)):
                self._state_dict_user_keys.append(loss_key)

    def _iteration(
        self, engine: AdversarialTrainer, batchdata: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor | int | float | bool]:
        """
        Callback function for the Adversarial Training processing logic of 1 iteration in Ignite Engine.
        Return below items in a dictionary:
            - IMAGE: image Tensor data for model input, already moved to device.
            - LABEL: label Tensor data corresponding to the image, already moved to device. In case of Unsupervised
                Learning this is equal to IMAGE.
            - PRED: prediction result of model.
            - LOSS: loss value computed by loss functions of the generator (reconstruction and adversarial summed up).
                Set in both the AMP and the non-AMP code paths.
            - AdversarialKeys.REALS: real images from the batch. Are the same as IMAGE.
            - AdversarialKeys.FAKES: fake images generated by the generator. Are the same as PRED.
            - AdversarialKeys.REAL_LOGITS: logits of the discriminator for the real images.
            - AdversarialKeys.FAKE_LOGITS: logits of the discriminator for the fake images.
            - AdversarialKeys.RECONSTRUCTION_LOSS: loss value computed by the reconstruction loss function.
            - AdversarialKeys.GENERATOR_LOSS: loss value computed by the generator loss function. It is the
                discriminator loss for the fake images. That is backpropagated through the generator only.
            - AdversarialKeys.DISCRIMINATOR_LOSS: loss value computed by the discriminator loss function. It is the
                discriminator loss for the real images and the fake images. That is backpropagated through the
                discriminator only.

        Args:
            engine: `AdversarialTrainer` to execute operation for an iteration.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Raises:
            ValueError: must provide batch data for current iteration.

        """

        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")
        batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)

        # prepare_batch returns either (inputs, targets) or (inputs, targets, args, kwargs)
        if len(batch) == 2:
            inputs, targets = batch
            args: tuple = ()
            kwargs: dict = {}
        else:
            inputs, targets, args, kwargs = batch

        engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets, AdversarialKeys.REALS: inputs}

        def _compute_generator_loss() -> None:
            # G forward pass; the fakes double as the model prediction
            engine.state.output[AdversarialKeys.FAKES] = engine.g_inferer(
                inputs, engine.state.g_network, *args, **kwargs
            )
            engine.state.output[Keys.PRED] = engine.state.output[AdversarialKeys.FAKES]
            engine.fire_event(AdversarialIterationEvents.GENERATOR_FORWARD_COMPLETED)

            # D sees the non-detached fakes here so the adversarial loss can backpropagate into G
            engine.state.output[AdversarialKeys.FAKE_LOGITS] = engine.d_inferer(
                engine.state.output[AdversarialKeys.FAKES].float().contiguous(), engine.state.d_network, *args, **kwargs
            )
            engine.fire_event(AdversarialIterationEvents.GENERATOR_DISCRIMINATOR_FORWARD_COMPLETED)

            engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS] = engine.state.recon_loss_function(
                engine.state.output[AdversarialKeys.FAKES], targets
            ).mean()
            engine.fire_event(AdversarialIterationEvents.RECONSTRUCTION_LOSS_COMPLETED)

            engine.state.output[AdversarialKeys.GENERATOR_LOSS] = engine.state.g_loss_function(
                engine.state.output[AdversarialKeys.FAKE_LOGITS]
            ).mean()
            engine.fire_event(AdversarialIterationEvents.GENERATOR_LOSS_COMPLETED)

        # Train Generator
        engine.state.g_network.train()
        engine.state.g_optimizer.zero_grad(set_to_none=engine.optim_set_to_none)

        if engine.amp and engine.state.g_scaler is not None:
            with torch.cuda.amp.autocast(**engine.amp_kwargs):
                _compute_generator_loss()

            engine.state.output[Keys.LOSS] = (
                engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS]
                + engine.state.output[AdversarialKeys.GENERATOR_LOSS]
            )
            engine.state.g_scaler.scale(engine.state.output[Keys.LOSS]).backward()
            engine.fire_event(AdversarialIterationEvents.GENERATOR_BACKWARD_COMPLETED)
            engine.state.g_scaler.step(engine.state.g_optimizer)
            engine.state.g_scaler.update()
        else:
            _compute_generator_loss()
            # store the total G loss here as well, so that LOSS is available without AMP too
            engine.state.output[Keys.LOSS] = (
                engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS]
                + engine.state.output[AdversarialKeys.GENERATOR_LOSS]
            )
            engine.state.output[Keys.LOSS].backward()
            engine.fire_event(AdversarialIterationEvents.GENERATOR_BACKWARD_COMPLETED)
            engine.state.g_optimizer.step()
        engine.fire_event(AdversarialIterationEvents.GENERATOR_MODEL_COMPLETED)

        def _compute_discriminator_loss() -> None:
            # detached inputs: the D update must not propagate gradients into G
            engine.state.output[AdversarialKeys.REAL_LOGITS] = engine.d_inferer(
                engine.state.output[AdversarialKeys.REALS].contiguous().detach(),
                engine.state.d_network,
                *args,
                **kwargs,
            )
            engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_REALS_FORWARD_COMPLETED)

            engine.state.output[AdversarialKeys.FAKE_LOGITS] = engine.d_inferer(
                engine.state.output[AdversarialKeys.FAKES].contiguous().detach(),
                engine.state.d_network,
                *args,
                **kwargs,
            )
            engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_FAKES_FORWARD_COMPLETED)

            engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS] = engine.state.d_loss_function(
                engine.state.output[AdversarialKeys.REAL_LOGITS], engine.state.output[AdversarialKeys.FAKE_LOGITS]
            ).mean()
            engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_LOSS_COMPLETED)

        # Train Discriminator
        engine.state.d_network.train()
        # zeroing D's parameter grads also clears the grads that G's backward pass accumulated into D
        engine.state.d_network.zero_grad(set_to_none=engine.optim_set_to_none)

        if engine.amp and engine.state.d_scaler is not None:
            with torch.cuda.amp.autocast(**engine.amp_kwargs):
                _compute_discriminator_loss()

            engine.state.d_scaler.scale(engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS]).backward()
            engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_BACKWARD_COMPLETED)
            engine.state.d_scaler.step(engine.state.d_optimizer)
            engine.state.d_scaler.update()
        else:
            _compute_discriminator_loss()
            engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS].backward()
            # fire the same event as the AMP path so handlers behave identically with and without AMP
            engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_BACKWARD_COMPLETED)
            engine.state.d_optimizer.step()

        return engine.state.output
|