monai-weekly 1.5.dev2508__py3-none-any.whl → 1.5.dev2510__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/deepedit/interaction.py +1 -1
- monai/apps/deepgrow/interaction.py +1 -1
- monai/apps/detection/networks/retinanet_detector.py +1 -1
- monai/apps/detection/networks/retinanet_network.py +5 -5
- monai/apps/detection/utils/box_coder.py +2 -2
- monai/apps/mmars/mmars.py +1 -1
- monai/apps/reconstruction/networks/blocks/varnetblock.py +1 -1
- monai/bundle/scripts.py +42 -20
- monai/data/dataset.py +2 -9
- monai/data/utils.py +1 -1
- monai/data/video_dataset.py +1 -1
- monai/engines/evaluator.py +11 -16
- monai/engines/trainer.py +11 -17
- monai/engines/utils.py +1 -1
- monai/engines/workflow.py +2 -2
- monai/fl/client/monai_algo.py +1 -1
- monai/handlers/checkpoint_loader.py +1 -1
- monai/inferers/inferer.py +35 -17
- monai/inferers/merger.py +16 -13
- monai/losses/perceptual.py +1 -1
- monai/losses/sure_loss.py +1 -1
- monai/networks/blocks/crossattention.py +1 -6
- monai/networks/blocks/feature_pyramid_network.py +4 -2
- monai/networks/blocks/selfattention.py +1 -6
- monai/networks/blocks/upsample.py +3 -11
- monai/networks/layers/vector_quantizer.py +2 -2
- monai/networks/nets/hovernet.py +5 -4
- monai/networks/nets/resnet.py +2 -2
- monai/networks/nets/senet.py +1 -1
- monai/networks/nets/swin_unetr.py +46 -49
- monai/networks/nets/transchex.py +3 -2
- monai/networks/nets/vista3d.py +7 -7
- monai/networks/utils.py +5 -4
- monai/transforms/intensity/array.py +1 -1
- monai/transforms/spatial/array.py +6 -6
- monai/utils/misc.py +1 -1
- monai/utils/state_cacher.py +1 -1
- {monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2510.dist-info}/METADATA +4 -3
- {monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2510.dist-info}/RECORD +60 -60
- {monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2510.dist-info}/WHEEL +1 -1
- tests/bundle/test_bundle_download.py +16 -6
- tests/config/test_cv2_dist.py +1 -2
- tests/inferers/test_controlnet_inferers.py +9 -0
- tests/integration/test_integration_bundle_run.py +2 -4
- tests/integration/test_integration_classification_2d.py +1 -1
- tests/integration/test_integration_fast_train.py +2 -2
- tests/integration/test_integration_segmentation_3d.py +1 -1
- tests/metrics/test_compute_multiscalessim_metric.py +3 -3
- tests/metrics/test_surface_dice.py +3 -3
- tests/networks/nets/test_autoencoderkl.py +1 -1
- tests/networks/nets/test_controlnet.py +1 -1
- tests/networks/nets/test_diffusion_model_unet.py +1 -1
- tests/networks/nets/test_network_consistency.py +1 -1
- tests/networks/nets/test_swin_unetr.py +1 -1
- tests/networks/nets/test_transformer.py +1 -1
- tests/networks/test_save_state.py +1 -1
- {monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2510.dist-info}/LICENSE +0 -0
- {monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2510.dist-info}/top_level.txt +0 -0
monai/__init__.py
CHANGED
monai/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json
|
|
8
8
|
|
9
9
|
version_json = '''
|
10
10
|
{
|
11
|
-
"date": "2025-
|
11
|
+
"date": "2025-03-09T02:16:22+0000",
|
12
12
|
"dirty": false,
|
13
13
|
"error": null,
|
14
|
-
"full-revisionid": "
|
15
|
-
"version": "1.5.
|
14
|
+
"full-revisionid": "19fadf962d87a21e1d0edf8d72299e82f7611140",
|
15
|
+
"version": "1.5.dev2510"
|
16
16
|
}
|
17
17
|
''' # END VERSION_JSON
|
18
18
|
|
@@ -67,7 +67,7 @@ class Interaction:
|
|
67
67
|
engine.network.eval()
|
68
68
|
with torch.no_grad():
|
69
69
|
if engine.amp:
|
70
|
-
with torch.
|
70
|
+
with torch.autocast("cuda"):
|
71
71
|
predictions = engine.inferer(inputs, engine.network)
|
72
72
|
else:
|
73
73
|
predictions = engine.inferer(inputs, engine.network)
|
@@ -180,7 +180,7 @@ class RetinaNetDetector(nn.Module):
|
|
180
180
|
nesterov=True,
|
181
181
|
)
|
182
182
|
torch.save(detector.network.state_dict(), 'model.pt') # save model
|
183
|
-
detector.network.load_state_dict(torch.load('model.pt')) # load model
|
183
|
+
detector.network.load_state_dict(torch.load('model.pt', weights_only=True)) # load model
|
184
184
|
"""
|
185
185
|
|
186
186
|
def __init__(
|
@@ -88,8 +88,8 @@ class RetinaNetClassificationHead(nn.Module):
|
|
88
88
|
|
89
89
|
for layer in self.conv.children():
|
90
90
|
if isinstance(layer, conv_type): # type: ignore
|
91
|
-
torch.nn.init.normal_(layer.weight, std=0.01)
|
92
|
-
torch.nn.init.constant_(layer.bias, 0)
|
91
|
+
torch.nn.init.normal_(layer.weight, std=0.01) # type: ignore[arg-type]
|
92
|
+
torch.nn.init.constant_(layer.bias, 0) # type: ignore[arg-type]
|
93
93
|
|
94
94
|
self.cls_logits = conv_type(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
|
95
95
|
torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
|
@@ -167,8 +167,8 @@ class RetinaNetRegressionHead(nn.Module):
|
|
167
167
|
|
168
168
|
for layer in self.conv.children():
|
169
169
|
if isinstance(layer, conv_type): # type: ignore
|
170
|
-
torch.nn.init.normal_(layer.weight, std=0.01)
|
171
|
-
torch.nn.init.zeros_(layer.bias)
|
170
|
+
torch.nn.init.normal_(layer.weight, std=0.01) # type: ignore[arg-type]
|
171
|
+
torch.nn.init.zeros_(layer.bias) # type: ignore[arg-type]
|
172
172
|
|
173
173
|
def forward(self, x: list[Tensor]) -> list[Tensor]:
|
174
174
|
"""
|
@@ -297,7 +297,7 @@ class RetinaNet(nn.Module):
|
|
297
297
|
)
|
298
298
|
self.feature_extractor = feature_extractor
|
299
299
|
|
300
|
-
self.feature_map_channels: int = self.feature_extractor.out_channels
|
300
|
+
self.feature_map_channels: int = self.feature_extractor.out_channels # type: ignore[assignment]
|
301
301
|
self.num_anchors = num_anchors
|
302
302
|
self.classification_head = RetinaNetClassificationHead(
|
303
303
|
self.feature_map_channels, self.num_anchors, self.num_classes, spatial_dims=self.spatial_dims
|
@@ -221,7 +221,7 @@ class BoxCoder:
|
|
221
221
|
|
222
222
|
pred_ctr_xyx_axis = dxyz_axis * whd_axis[:, None] + ctr_xyz_axis[:, None]
|
223
223
|
pred_whd_axis = torch.exp(dwhd_axis) * whd_axis[:, None]
|
224
|
-
pred_whd_axis = pred_whd_axis.to(dxyz_axis.dtype)
|
224
|
+
pred_whd_axis = pred_whd_axis.to(dxyz_axis.dtype) # type: ignore[union-attr]
|
225
225
|
|
226
226
|
# When convert float32 to float16, Inf or Nan may occur
|
227
227
|
if torch.isnan(pred_whd_axis).any() or torch.isinf(pred_whd_axis).any():
|
@@ -229,7 +229,7 @@ class BoxCoder:
|
|
229
229
|
|
230
230
|
# Distance from center to box's corner.
|
231
231
|
c_to_c_whd_axis = (
|
232
|
-
torch.tensor(0.5, dtype=pred_ctr_xyx_axis.dtype, device=pred_whd_axis.device) * pred_whd_axis
|
232
|
+
torch.tensor(0.5, dtype=pred_ctr_xyx_axis.dtype, device=pred_whd_axis.device) * pred_whd_axis # type: ignore[arg-type]
|
233
233
|
)
|
234
234
|
|
235
235
|
pred_boxes.append(pred_ctr_xyx_axis - c_to_c_whd_axis)
|
monai/apps/mmars/mmars.py
CHANGED
@@ -241,7 +241,7 @@ def load_from_mmar(
|
|
241
241
|
return torch.jit.load(_model_file, map_location=map_location)
|
242
242
|
|
243
243
|
# loading with `torch.load`
|
244
|
-
model_dict = torch.load(_model_file, map_location=map_location)
|
244
|
+
model_dict = torch.load(_model_file, map_location=map_location, weights_only=True)
|
245
245
|
if weights_only:
|
246
246
|
return model_dict.get(model_key, model_dict) # model_dict[model_key] or model_dict directly
|
247
247
|
|
@@ -55,7 +55,7 @@ class VarNetBlock(nn.Module):
|
|
55
55
|
Returns:
|
56
56
|
Output of DC block with the same shape as x
|
57
57
|
"""
|
58
|
-
return torch.where(mask, x - ref_kspace, self.zeros) * self.dc_weight
|
58
|
+
return torch.where(mask, x - ref_kspace, self.zeros) * self.dc_weight # type: ignore
|
59
59
|
|
60
60
|
def forward(self, current_kspace: Tensor, ref_kspace: Tensor, mask: Tensor, sens_maps: Tensor) -> Tensor:
|
61
61
|
"""
|
monai/bundle/scripts.py
CHANGED
@@ -15,6 +15,7 @@ import ast
|
|
15
15
|
import json
|
16
16
|
import os
|
17
17
|
import re
|
18
|
+
import urllib
|
18
19
|
import warnings
|
19
20
|
import zipfile
|
20
21
|
from collections.abc import Mapping, Sequence
|
@@ -58,7 +59,7 @@ from monai.utils import (
|
|
58
59
|
validate, _ = optional_import("jsonschema", name="validate")
|
59
60
|
ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError")
|
60
61
|
Checkpoint, has_ignite = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint")
|
61
|
-
|
62
|
+
requests, has_requests = optional_import("requests")
|
62
63
|
onnx, _ = optional_import("onnx")
|
63
64
|
huggingface_hub, _ = optional_import("huggingface_hub")
|
64
65
|
|
@@ -206,6 +207,16 @@ def _download_from_monaihosting(download_path: Path, filename: str, version: str
|
|
206
207
|
extractall(filepath=filepath, output_dir=download_path, has_base=True)
|
207
208
|
|
208
209
|
|
210
|
+
def _download_from_bundle_info(download_path: Path, filename: str, version: str, progress: bool) -> None:
|
211
|
+
bundle_info = get_bundle_info(bundle_name=filename, version=version)
|
212
|
+
if not bundle_info:
|
213
|
+
raise ValueError(f"Bundle info not found for {filename} v{version}.")
|
214
|
+
url = bundle_info["browser_download_url"]
|
215
|
+
filepath = download_path / f"{filename}_v{version}.zip"
|
216
|
+
download_url(url=url, filepath=filepath, hash_val=None, progress=progress)
|
217
|
+
extractall(filepath=filepath, output_dir=download_path, has_base=True)
|
218
|
+
|
219
|
+
|
209
220
|
def _add_ngc_prefix(name: str, prefix: str = "monai_") -> str:
|
210
221
|
if name.startswith(prefix):
|
211
222
|
return name
|
@@ -222,7 +233,7 @@ def _get_all_download_files(request_url: str, headers: dict | None = None) -> li
|
|
222
233
|
if not has_requests:
|
223
234
|
raise ValueError("requests package is required, please install it.")
|
224
235
|
headers = {} if headers is None else headers
|
225
|
-
response =
|
236
|
+
response = requests.get(request_url, headers=headers)
|
226
237
|
response.raise_for_status()
|
227
238
|
model_info = json.loads(response.text)
|
228
239
|
|
@@ -266,7 +277,7 @@ def _download_from_ngc_private(
|
|
266
277
|
request_url = _get_ngc_private_bundle_url(model_name=filename, version=version, repo=repo)
|
267
278
|
if has_requests:
|
268
279
|
headers = {} if headers is None else headers
|
269
|
-
response =
|
280
|
+
response = requests.get(request_url, headers=headers)
|
270
281
|
response.raise_for_status()
|
271
282
|
else:
|
272
283
|
raise ValueError("NGC API requires requests package. Please install it.")
|
@@ -289,7 +300,7 @@ def _get_ngc_token(api_key, retry=0):
|
|
289
300
|
url = "https://authn.nvidia.com/token?service=ngc"
|
290
301
|
headers = {"Accept": "application/json", "Authorization": "ApiKey " + api_key}
|
291
302
|
if has_requests:
|
292
|
-
response =
|
303
|
+
response = requests.get(url, headers=headers)
|
293
304
|
if not response.ok:
|
294
305
|
# retry 3 times, if failed, raise an error.
|
295
306
|
if retry < 3:
|
@@ -303,14 +314,17 @@ def _get_ngc_token(api_key, retry=0):
|
|
303
314
|
|
304
315
|
def _get_latest_bundle_version_monaihosting(name):
|
305
316
|
full_url = f"{MONAI_HOSTING_BASE_URL}/{name.lower()}"
|
306
|
-
requests_get, has_requests = optional_import("requests", name="get")
|
307
317
|
if has_requests:
|
308
|
-
resp =
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
318
|
+
resp = requests.get(full_url)
|
319
|
+
try:
|
320
|
+
resp.raise_for_status()
|
321
|
+
model_info = json.loads(resp.text)
|
322
|
+
return model_info["model"]["latestVersionIdStr"]
|
323
|
+
except requests.exceptions.HTTPError:
|
324
|
+
# for monaihosting bundles, if cannot find the version, get from model zoo model_info.json
|
325
|
+
return get_bundle_versions(name)["latest_version"]
|
326
|
+
|
327
|
+
raise ValueError("NGC API requires requests package. Please install it.")
|
314
328
|
|
315
329
|
|
316
330
|
def _examine_monai_version(monai_version: str) -> tuple[bool, str]:
|
@@ -388,14 +402,14 @@ def _get_latest_bundle_version_ngc(name: str, repo: str | None = None, headers:
|
|
388
402
|
version_header = {"Accept-Encoding": "gzip, deflate"} # Excluding 'zstd' to fit NGC requirements
|
389
403
|
if headers:
|
390
404
|
version_header.update(headers)
|
391
|
-
resp =
|
405
|
+
resp = requests.get(version_endpoint, headers=version_header)
|
392
406
|
resp.raise_for_status()
|
393
407
|
model_info = json.loads(resp.text)
|
394
408
|
latest_versions = _list_latest_versions(model_info)
|
395
409
|
|
396
410
|
for version in latest_versions:
|
397
411
|
file_endpoint = base_url + f"/{name.lower()}/versions/{version}/files/configs/metadata.json"
|
398
|
-
resp =
|
412
|
+
resp = requests.get(file_endpoint, headers=headers)
|
399
413
|
metadata = json.loads(resp.text)
|
400
414
|
resp.raise_for_status()
|
401
415
|
# if the package version is not available or the model is compatible with the package version
|
@@ -585,7 +599,16 @@ def download(
|
|
585
599
|
name_ver = "_v".join([name_, version_]) if version_ is not None else name_
|
586
600
|
_download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_)
|
587
601
|
elif source_ == "monaihosting":
|
588
|
-
|
602
|
+
try:
|
603
|
+
_download_from_monaihosting(
|
604
|
+
download_path=bundle_dir_, filename=name_, version=version_, progress=progress_
|
605
|
+
)
|
606
|
+
except urllib.error.HTTPError:
|
607
|
+
# for monaihosting bundles, if cannot download from default host, download according to bundle_info
|
608
|
+
_download_from_bundle_info(
|
609
|
+
download_path=bundle_dir_, filename=name_, version=version_, progress=progress_
|
610
|
+
)
|
611
|
+
|
589
612
|
elif source_ == "ngc":
|
590
613
|
_download_from_ngc(
|
591
614
|
download_path=bundle_dir_,
|
@@ -737,7 +760,7 @@ def load(
|
|
737
760
|
if load_ts_module is True:
|
738
761
|
return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files)
|
739
762
|
# loading with `torch.load`
|
740
|
-
model_dict = torch.load(full_path, map_location=torch.device(device))
|
763
|
+
model_dict = torch.load(full_path, map_location=torch.device(device), weights_only=True)
|
741
764
|
|
742
765
|
if not isinstance(model_dict, Mapping):
|
743
766
|
warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.")
|
@@ -792,9 +815,9 @@ def _get_all_bundles_info(
|
|
792
815
|
|
793
816
|
if auth_token is not None:
|
794
817
|
headers = {"Authorization": f"Bearer {auth_token}"}
|
795
|
-
resp =
|
818
|
+
resp = requests.get(request_url, headers=headers)
|
796
819
|
else:
|
797
|
-
resp =
|
820
|
+
resp = requests.get(request_url)
|
798
821
|
resp.raise_for_status()
|
799
822
|
else:
|
800
823
|
raise ValueError("requests package is required, please install it.")
|
@@ -1256,9 +1279,8 @@ def verify_net_in_out(
|
|
1256
1279
|
if input_dtype == torch.float16:
|
1257
1280
|
# fp16 can only be executed in gpu mode
|
1258
1281
|
net.to("cuda")
|
1259
|
-
from torch.cuda.amp import autocast
|
1260
1282
|
|
1261
|
-
with autocast():
|
1283
|
+
with torch.autocast("cuda"):
|
1262
1284
|
output = net(test_data.cuda(), **extra_forward_args_)
|
1263
1285
|
net.to(device_)
|
1264
1286
|
else:
|
@@ -1307,7 +1329,7 @@ def _export(
|
|
1307
1329
|
# here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver
|
1308
1330
|
Checkpoint.load_objects(to_load={key_in_ckpt: net}, checkpoint=ckpt_file)
|
1309
1331
|
else:
|
1310
|
-
ckpt = torch.load(ckpt_file)
|
1332
|
+
ckpt = torch.load(ckpt_file, weights_only=True)
|
1311
1333
|
copy_model_state(dst=net, src=ckpt if key_in_ckpt == "" else ckpt[key_in_ckpt])
|
1312
1334
|
|
1313
1335
|
# Use the given converter to convert a model and save with metadata, config content
|
monai/data/dataset.py
CHANGED
@@ -22,7 +22,6 @@ import time
|
|
22
22
|
import warnings
|
23
23
|
from collections.abc import Callable, Sequence
|
24
24
|
from copy import copy, deepcopy
|
25
|
-
from inspect import signature
|
26
25
|
from multiprocessing.managers import ListProxy
|
27
26
|
from multiprocessing.pool import ThreadPool
|
28
27
|
from pathlib import Path
|
@@ -372,10 +371,7 @@ class PersistentDataset(Dataset):
|
|
372
371
|
|
373
372
|
if hashfile is not None and hashfile.is_file(): # cache hit
|
374
373
|
try:
|
375
|
-
|
376
|
-
return torch.load(hashfile, weights_only=False)
|
377
|
-
else:
|
378
|
-
return torch.load(hashfile)
|
374
|
+
return torch.load(hashfile, weights_only=False)
|
379
375
|
except PermissionError as e:
|
380
376
|
if sys.platform != "win32":
|
381
377
|
raise e
|
@@ -1674,7 +1670,4 @@ class GDSDataset(PersistentDataset):
|
|
1674
1670
|
if meta_hash_file_name in self._meta_cache:
|
1675
1671
|
return self._meta_cache[meta_hash_file_name]
|
1676
1672
|
else:
|
1677
|
-
|
1678
|
-
return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
|
1679
|
-
else:
|
1680
|
-
return torch.load(self.cache_dir / meta_hash_file_name)
|
1673
|
+
return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
|
monai/data/utils.py
CHANGED
@@ -753,7 +753,7 @@ def affine_to_spacing(affine: NdarrayTensor, r: int = 3, dtype=float, suppress_z
|
|
753
753
|
if isinstance(_affine, torch.Tensor):
|
754
754
|
spacing = torch.sqrt(torch.sum(_affine * _affine, dim=0))
|
755
755
|
else:
|
756
|
-
spacing = np.sqrt(np.sum(_affine * _affine, axis=0))
|
756
|
+
spacing = np.sqrt(np.sum(_affine * _affine, axis=0)) # type: ignore[operator]
|
757
757
|
if suppress_zeros:
|
758
758
|
spacing[spacing == 0] = 1.0
|
759
759
|
spacing_, *_ = convert_to_dst_type(spacing, dst=affine, dtype=dtype)
|
monai/data/video_dataset.py
CHANGED
@@ -177,7 +177,7 @@ class VideoFileDataset(Dataset, VideoDataset):
|
|
177
177
|
for codec, ext in all_codecs.items():
|
178
178
|
writer = cv2.VideoWriter()
|
179
179
|
fname = os.path.join(tmp_dir, f"test{ext}")
|
180
|
-
fourcc = cv2.VideoWriter_fourcc(*codec)
|
180
|
+
fourcc = cv2.VideoWriter_fourcc(*codec) # type: ignore[attr-defined]
|
181
181
|
noviderr = writer.open(fname, fourcc, 1, (10, 10))
|
182
182
|
if noviderr:
|
183
183
|
codecs[codec] = ext
|
monai/engines/evaluator.py
CHANGED
@@ -28,7 +28,7 @@ from monai.transforms import Transform
|
|
28
28
|
from monai.utils import ForwardMode, IgniteInfo, ensure_tuple, min_version, optional_import
|
29
29
|
from monai.utils.enums import CommonKeys as Keys
|
30
30
|
from monai.utils.enums import EngineStatsKeys as ESKeys
|
31
|
-
from monai.utils.module import look_up_option
|
31
|
+
from monai.utils.module import look_up_option
|
32
32
|
|
33
33
|
if TYPE_CHECKING:
|
34
34
|
from ignite.engine import Engine, EventEnum
|
@@ -82,8 +82,8 @@ class Evaluator(Workflow):
|
|
82
82
|
default to `True`.
|
83
83
|
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
|
84
84
|
`device`, `non_blocking`.
|
85
|
-
amp_kwargs: dict of the args for `torch.
|
86
|
-
https://pytorch.org/docs/stable/amp.html#torch.
|
85
|
+
amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
|
86
|
+
https://pytorch.org/docs/stable/amp.html#torch.autocast.
|
87
87
|
|
88
88
|
"""
|
89
89
|
|
@@ -214,8 +214,8 @@ class SupervisedEvaluator(Evaluator):
|
|
214
214
|
default to `True`.
|
215
215
|
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
|
216
216
|
`device`, `non_blocking`.
|
217
|
-
amp_kwargs: dict of the args for `torch.
|
218
|
-
https://pytorch.org/docs/stable/amp.html#torch.
|
217
|
+
amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
|
218
|
+
https://pytorch.org/docs/stable/amp.html#torch.autocast.
|
219
219
|
compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to
|
220
220
|
`torch.Tensor` before forward pass, then converted back afterward with copied meta information.
|
221
221
|
compile_kwargs: dict of the args for `torch.compile()` API, for more details:
|
@@ -269,13 +269,8 @@ class SupervisedEvaluator(Evaluator):
|
|
269
269
|
amp_kwargs=amp_kwargs,
|
270
270
|
)
|
271
271
|
if compile:
|
272
|
-
if
|
273
|
-
|
274
|
-
network = torch.compile(network, **compile_kwargs) # type: ignore[assignment]
|
275
|
-
else:
|
276
|
-
warnings.warn(
|
277
|
-
"Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done"
|
278
|
-
)
|
272
|
+
compile_kwargs = {} if compile_kwargs is None else compile_kwargs
|
273
|
+
network = torch.compile(network, **compile_kwargs) # type: ignore[assignment]
|
279
274
|
self.network = network
|
280
275
|
self.compile = compile
|
281
276
|
self.inferer = SimpleInferer() if inferer is None else inferer
|
@@ -329,7 +324,7 @@ class SupervisedEvaluator(Evaluator):
|
|
329
324
|
# execute forward computation
|
330
325
|
with engine.mode(engine.network):
|
331
326
|
if engine.amp:
|
332
|
-
with torch.
|
327
|
+
with torch.autocast("cuda", **engine.amp_kwargs):
|
333
328
|
engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
|
334
329
|
else:
|
335
330
|
engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
|
@@ -399,8 +394,8 @@ class EnsembleEvaluator(Evaluator):
|
|
399
394
|
default to `True`.
|
400
395
|
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
|
401
396
|
`device`, `non_blocking`.
|
402
|
-
amp_kwargs: dict of the args for `torch.
|
403
|
-
https://pytorch.org/docs/stable/amp.html#torch.
|
397
|
+
amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
|
398
|
+
https://pytorch.org/docs/stable/amp.html#torch.autocast.
|
404
399
|
|
405
400
|
"""
|
406
401
|
|
@@ -492,7 +487,7 @@ class EnsembleEvaluator(Evaluator):
|
|
492
487
|
for idx, network in enumerate(engine.networks):
|
493
488
|
with engine.mode(network):
|
494
489
|
if engine.amp:
|
495
|
-
with torch.
|
490
|
+
with torch.autocast("cuda", **engine.amp_kwargs):
|
496
491
|
if isinstance(engine.state.output, dict):
|
497
492
|
engine.state.output.update(
|
498
493
|
{engine.pred_keys[idx]: engine.inferer(inputs, network, *args, **kwargs)}
|
monai/engines/trainer.py
CHANGED
@@ -27,7 +27,6 @@ from monai.transforms import Transform
|
|
27
27
|
from monai.utils import AdversarialIterationEvents, AdversarialKeys, GanKeys, IgniteInfo, min_version, optional_import
|
28
28
|
from monai.utils.enums import CommonKeys as Keys
|
29
29
|
from monai.utils.enums import EngineStatsKeys as ESKeys
|
30
|
-
from monai.utils.module import pytorch_after
|
31
30
|
|
32
31
|
if TYPE_CHECKING:
|
33
32
|
from ignite.engine import Engine, EventEnum
|
@@ -126,8 +125,8 @@ class SupervisedTrainer(Trainer):
|
|
126
125
|
more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
|
127
126
|
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
|
128
127
|
`device`, `non_blocking`.
|
129
|
-
amp_kwargs: dict of the args for `torch.
|
130
|
-
https://pytorch.org/docs/stable/amp.html#torch.
|
128
|
+
amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
|
129
|
+
https://pytorch.org/docs/stable/amp.html#torch.autocast.
|
131
130
|
compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to
|
132
131
|
`torch.Tensor` before forward pass, then converted back afterward with copied meta information.
|
133
132
|
compile_kwargs: dict of the args for `torch.compile()` API, for more details:
|
@@ -183,13 +182,8 @@ class SupervisedTrainer(Trainer):
|
|
183
182
|
amp_kwargs=amp_kwargs,
|
184
183
|
)
|
185
184
|
if compile:
|
186
|
-
if
|
187
|
-
|
188
|
-
network = torch.compile(network, **compile_kwargs) # type: ignore[assignment]
|
189
|
-
else:
|
190
|
-
warnings.warn(
|
191
|
-
"Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done"
|
192
|
-
)
|
185
|
+
compile_kwargs = {} if compile_kwargs is None else compile_kwargs
|
186
|
+
network = torch.compile(network, **compile_kwargs) # type: ignore[assignment]
|
193
187
|
self.network = network
|
194
188
|
self.compile = compile
|
195
189
|
self.optimizer = optimizer
|
@@ -255,7 +249,7 @@ class SupervisedTrainer(Trainer):
|
|
255
249
|
engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
|
256
250
|
|
257
251
|
if engine.amp and engine.scaler is not None:
|
258
|
-
with torch.
|
252
|
+
with torch.autocast("cuda", **engine.amp_kwargs):
|
259
253
|
_compute_pred_loss()
|
260
254
|
engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
|
261
255
|
engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
|
@@ -341,8 +335,8 @@ class GanTrainer(Trainer):
|
|
341
335
|
more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
|
342
336
|
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
|
343
337
|
`device`, `non_blocking`.
|
344
|
-
amp_kwargs: dict of the args for `torch.
|
345
|
-
https://pytorch.org/docs/stable/amp.html#torch.
|
338
|
+
amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
|
339
|
+
https://pytorch.org/docs/stable/amp.html#torch.autocast.
|
346
340
|
|
347
341
|
"""
|
348
342
|
|
@@ -518,8 +512,8 @@ class AdversarialTrainer(Trainer):
|
|
518
512
|
more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
|
519
513
|
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
|
520
514
|
`device`, `non_blocking`.
|
521
|
-
amp_kwargs: dict of the args for `torch.
|
522
|
-
https://pytorch.org/docs/stable/amp.html#torch.
|
515
|
+
amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
|
516
|
+
https://pytorch.org/docs/stable/amp.html#torch.autocast.
|
523
517
|
"""
|
524
518
|
|
525
519
|
def __init__(
|
@@ -689,7 +683,7 @@ class AdversarialTrainer(Trainer):
|
|
689
683
|
engine.state.g_optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
|
690
684
|
|
691
685
|
if engine.amp and engine.state.g_scaler is not None:
|
692
|
-
with torch.
|
686
|
+
with torch.autocast("cuda", **engine.amp_kwargs):
|
693
687
|
_compute_generator_loss()
|
694
688
|
|
695
689
|
engine.state.output[Keys.LOSS] = (
|
@@ -737,7 +731,7 @@ class AdversarialTrainer(Trainer):
|
|
737
731
|
engine.state.d_network.zero_grad(set_to_none=engine.optim_set_to_none)
|
738
732
|
|
739
733
|
if engine.amp and engine.state.d_scaler is not None:
|
740
|
-
with torch.
|
734
|
+
with torch.autocast("cuda", **engine.amp_kwargs):
|
741
735
|
_compute_discriminator_loss()
|
742
736
|
|
743
737
|
engine.state.d_scaler.scale(engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS]).backward()
|
monai/engines/utils.py
CHANGED
@@ -309,7 +309,7 @@ class VPredictionPrepareBatch(DiffusionPrepareBatch):
|
|
309
309
|
self.scheduler = scheduler
|
310
310
|
|
311
311
|
def get_target(self, images, noise, timesteps):
|
312
|
-
return self.scheduler.get_velocity(images, noise, timesteps)
|
312
|
+
return self.scheduler.get_velocity(images, noise, timesteps) # type: ignore[operator]
|
313
313
|
|
314
314
|
|
315
315
|
def default_make_latent(
|
monai/engines/workflow.py
CHANGED
@@ -90,8 +90,8 @@ class Workflow(Engine):
|
|
90
90
|
default to `True`.
|
91
91
|
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
|
92
92
|
`device`, `non_blocking`.
|
93
|
-
amp_kwargs: dict of the args for `torch.
|
94
|
-
https://pytorch.org/docs/stable/amp.html#torch.
|
93
|
+
amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
|
94
|
+
https://pytorch.org/docs/stable/amp.html#torch.autocast.
|
95
95
|
|
96
96
|
Raises:
|
97
97
|
TypeError: When ``data_loader`` is not a ``torch.utils.data.DataLoader``.
|
monai/fl/client/monai_algo.py
CHANGED
@@ -574,7 +574,7 @@ class MonaiAlgo(ClientAlgo, MonaiAlgoStats):
|
|
574
574
|
model_path = os.path.join(self.bundle_root, cast(str, self.model_filepaths[model_type]))
|
575
575
|
if not os.path.isfile(model_path):
|
576
576
|
raise ValueError(f"No best model checkpoint exists at {model_path}")
|
577
|
-
weights = torch.load(model_path, map_location="cpu")
|
577
|
+
weights = torch.load(model_path, map_location="cpu", weights_only=True)
|
578
578
|
# if weights contain several state dicts, use the one defined by `save_dict_key`
|
579
579
|
if isinstance(weights, dict) and self.save_dict_key in weights:
|
580
580
|
weights = weights.get(self.save_dict_key)
|
@@ -122,7 +122,7 @@ class CheckpointLoader:
|
|
122
122
|
Args:
|
123
123
|
engine: Ignite Engine, it can be a trainer, validator or evaluator.
|
124
124
|
"""
|
125
|
-
checkpoint = torch.load(self.load_path, map_location=self.map_location)
|
125
|
+
checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=False)
|
126
126
|
|
127
127
|
k, _ = list(self.load_dict.items())[0]
|
128
128
|
# single object and checkpoint is directly a state_dict
|