monai-weekly 1.5.dev2508__py3-none-any.whl → 1.5.dev2510__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
Files changed (60)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/deepedit/interaction.py +1 -1
  4. monai/apps/deepgrow/interaction.py +1 -1
  5. monai/apps/detection/networks/retinanet_detector.py +1 -1
  6. monai/apps/detection/networks/retinanet_network.py +5 -5
  7. monai/apps/detection/utils/box_coder.py +2 -2
  8. monai/apps/mmars/mmars.py +1 -1
  9. monai/apps/reconstruction/networks/blocks/varnetblock.py +1 -1
  10. monai/bundle/scripts.py +42 -20
  11. monai/data/dataset.py +2 -9
  12. monai/data/utils.py +1 -1
  13. monai/data/video_dataset.py +1 -1
  14. monai/engines/evaluator.py +11 -16
  15. monai/engines/trainer.py +11 -17
  16. monai/engines/utils.py +1 -1
  17. monai/engines/workflow.py +2 -2
  18. monai/fl/client/monai_algo.py +1 -1
  19. monai/handlers/checkpoint_loader.py +1 -1
  20. monai/inferers/inferer.py +35 -17
  21. monai/inferers/merger.py +16 -13
  22. monai/losses/perceptual.py +1 -1
  23. monai/losses/sure_loss.py +1 -1
  24. monai/networks/blocks/crossattention.py +1 -6
  25. monai/networks/blocks/feature_pyramid_network.py +4 -2
  26. monai/networks/blocks/selfattention.py +1 -6
  27. monai/networks/blocks/upsample.py +3 -11
  28. monai/networks/layers/vector_quantizer.py +2 -2
  29. monai/networks/nets/hovernet.py +5 -4
  30. monai/networks/nets/resnet.py +2 -2
  31. monai/networks/nets/senet.py +1 -1
  32. monai/networks/nets/swin_unetr.py +46 -49
  33. monai/networks/nets/transchex.py +3 -2
  34. monai/networks/nets/vista3d.py +7 -7
  35. monai/networks/utils.py +5 -4
  36. monai/transforms/intensity/array.py +1 -1
  37. monai/transforms/spatial/array.py +6 -6
  38. monai/utils/misc.py +1 -1
  39. monai/utils/state_cacher.py +1 -1
  40. {monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2510.dist-info}/METADATA +4 -3
  41. {monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2510.dist-info}/RECORD +60 -60
  42. {monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2510.dist-info}/WHEEL +1 -1
  43. tests/bundle/test_bundle_download.py +16 -6
  44. tests/config/test_cv2_dist.py +1 -2
  45. tests/inferers/test_controlnet_inferers.py +9 -0
  46. tests/integration/test_integration_bundle_run.py +2 -4
  47. tests/integration/test_integration_classification_2d.py +1 -1
  48. tests/integration/test_integration_fast_train.py +2 -2
  49. tests/integration/test_integration_segmentation_3d.py +1 -1
  50. tests/metrics/test_compute_multiscalessim_metric.py +3 -3
  51. tests/metrics/test_surface_dice.py +3 -3
  52. tests/networks/nets/test_autoencoderkl.py +1 -1
  53. tests/networks/nets/test_controlnet.py +1 -1
  54. tests/networks/nets/test_diffusion_model_unet.py +1 -1
  55. tests/networks/nets/test_network_consistency.py +1 -1
  56. tests/networks/nets/test_swin_unetr.py +1 -1
  57. tests/networks/nets/test_transformer.py +1 -1
  58. tests/networks/test_save_state.py +1 -1
  59. {monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2510.dist-info}/LICENSE +0 -0
  60. {monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2510.dist-info}/top_level.txt +0 -0
monai/__init__.py CHANGED
@@ -136,4 +136,4 @@ except BaseException:
136
136
 
137
137
  if MONAIEnvVars.debug():
138
138
  raise
139
- __commit_id__ = "a7905909e785d1ef24103c32a2d3a5a36e1059a2"
139
+ __commit_id__ = "7c26e5af385eb5f7a813fa405c6f3fc87b7511fa"
monai/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2025-02-23T02:28:09+0000",
11
+ "date": "2025-03-09T02:16:22+0000",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "e55b5cbfbbba1800a968a9c06b2deaaa5c9bec54",
15
- "version": "1.5.dev2508"
14
+ "full-revisionid": "19fadf962d87a21e1d0edf8d72299e82f7611140",
15
+ "version": "1.5.dev2510"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
@@ -72,7 +72,7 @@ class Interaction:
72
72
 
73
73
  with torch.no_grad():
74
74
  if engine.amp:
75
- with torch.cuda.amp.autocast():
75
+ with torch.autocast("cuda"):
76
76
  predictions = engine.inferer(inputs, engine.network)
77
77
  else:
78
78
  predictions = engine.inferer(inputs, engine.network)
@@ -67,7 +67,7 @@ class Interaction:
67
67
  engine.network.eval()
68
68
  with torch.no_grad():
69
69
  if engine.amp:
70
- with torch.cuda.amp.autocast():
70
+ with torch.autocast("cuda"):
71
71
  predictions = engine.inferer(inputs, engine.network)
72
72
  else:
73
73
  predictions = engine.inferer(inputs, engine.network)
@@ -180,7 +180,7 @@ class RetinaNetDetector(nn.Module):
180
180
  nesterov=True,
181
181
  )
182
182
  torch.save(detector.network.state_dict(), 'model.pt') # save model
183
- detector.network.load_state_dict(torch.load('model.pt')) # load model
183
+ detector.network.load_state_dict(torch.load('model.pt', weights_only=True)) # load model
184
184
  """
185
185
 
186
186
  def __init__(
@@ -88,8 +88,8 @@ class RetinaNetClassificationHead(nn.Module):
88
88
 
89
89
  for layer in self.conv.children():
90
90
  if isinstance(layer, conv_type): # type: ignore
91
- torch.nn.init.normal_(layer.weight, std=0.01)
92
- torch.nn.init.constant_(layer.bias, 0)
91
+ torch.nn.init.normal_(layer.weight, std=0.01) # type: ignore[arg-type]
92
+ torch.nn.init.constant_(layer.bias, 0) # type: ignore[arg-type]
93
93
 
94
94
  self.cls_logits = conv_type(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
95
95
  torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
@@ -167,8 +167,8 @@ class RetinaNetRegressionHead(nn.Module):
167
167
 
168
168
  for layer in self.conv.children():
169
169
  if isinstance(layer, conv_type): # type: ignore
170
- torch.nn.init.normal_(layer.weight, std=0.01)
171
- torch.nn.init.zeros_(layer.bias)
170
+ torch.nn.init.normal_(layer.weight, std=0.01) # type: ignore[arg-type]
171
+ torch.nn.init.zeros_(layer.bias) # type: ignore[arg-type]
172
172
 
173
173
  def forward(self, x: list[Tensor]) -> list[Tensor]:
174
174
  """
@@ -297,7 +297,7 @@ class RetinaNet(nn.Module):
297
297
  )
298
298
  self.feature_extractor = feature_extractor
299
299
 
300
- self.feature_map_channels: int = self.feature_extractor.out_channels
300
+ self.feature_map_channels: int = self.feature_extractor.out_channels # type: ignore[assignment]
301
301
  self.num_anchors = num_anchors
302
302
  self.classification_head = RetinaNetClassificationHead(
303
303
  self.feature_map_channels, self.num_anchors, self.num_classes, spatial_dims=self.spatial_dims
@@ -221,7 +221,7 @@ class BoxCoder:
221
221
 
222
222
  pred_ctr_xyx_axis = dxyz_axis * whd_axis[:, None] + ctr_xyz_axis[:, None]
223
223
  pred_whd_axis = torch.exp(dwhd_axis) * whd_axis[:, None]
224
- pred_whd_axis = pred_whd_axis.to(dxyz_axis.dtype)
224
+ pred_whd_axis = pred_whd_axis.to(dxyz_axis.dtype) # type: ignore[union-attr]
225
225
 
226
226
  # When convert float32 to float16, Inf or Nan may occur
227
227
  if torch.isnan(pred_whd_axis).any() or torch.isinf(pred_whd_axis).any():
@@ -229,7 +229,7 @@ class BoxCoder:
229
229
 
230
230
  # Distance from center to box's corner.
231
231
  c_to_c_whd_axis = (
232
- torch.tensor(0.5, dtype=pred_ctr_xyx_axis.dtype, device=pred_whd_axis.device) * pred_whd_axis
232
+ torch.tensor(0.5, dtype=pred_ctr_xyx_axis.dtype, device=pred_whd_axis.device) * pred_whd_axis # type: ignore[arg-type]
233
233
  )
234
234
 
235
235
  pred_boxes.append(pred_ctr_xyx_axis - c_to_c_whd_axis)
monai/apps/mmars/mmars.py CHANGED
@@ -241,7 +241,7 @@ def load_from_mmar(
241
241
  return torch.jit.load(_model_file, map_location=map_location)
242
242
 
243
243
  # loading with `torch.load`
244
- model_dict = torch.load(_model_file, map_location=map_location)
244
+ model_dict = torch.load(_model_file, map_location=map_location, weights_only=True)
245
245
  if weights_only:
246
246
  return model_dict.get(model_key, model_dict) # model_dict[model_key] or model_dict directly
247
247
 
@@ -55,7 +55,7 @@ class VarNetBlock(nn.Module):
55
55
  Returns:
56
56
  Output of DC block with the same shape as x
57
57
  """
58
- return torch.where(mask, x - ref_kspace, self.zeros) * self.dc_weight
58
+ return torch.where(mask, x - ref_kspace, self.zeros) * self.dc_weight # type: ignore
59
59
 
60
60
  def forward(self, current_kspace: Tensor, ref_kspace: Tensor, mask: Tensor, sens_maps: Tensor) -> Tensor:
61
61
  """
monai/bundle/scripts.py CHANGED
@@ -15,6 +15,7 @@ import ast
15
15
  import json
16
16
  import os
17
17
  import re
18
+ import urllib
18
19
  import warnings
19
20
  import zipfile
20
21
  from collections.abc import Mapping, Sequence
@@ -58,7 +59,7 @@ from monai.utils import (
58
59
  validate, _ = optional_import("jsonschema", name="validate")
59
60
  ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError")
60
61
  Checkpoint, has_ignite = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint")
61
- requests_get, has_requests = optional_import("requests", name="get")
62
+ requests, has_requests = optional_import("requests")
62
63
  onnx, _ = optional_import("onnx")
63
64
  huggingface_hub, _ = optional_import("huggingface_hub")
64
65
 
@@ -206,6 +207,16 @@ def _download_from_monaihosting(download_path: Path, filename: str, version: str
206
207
  extractall(filepath=filepath, output_dir=download_path, has_base=True)
207
208
 
208
209
 
210
+ def _download_from_bundle_info(download_path: Path, filename: str, version: str, progress: bool) -> None:
211
+ bundle_info = get_bundle_info(bundle_name=filename, version=version)
212
+ if not bundle_info:
213
+ raise ValueError(f"Bundle info not found for {filename} v{version}.")
214
+ url = bundle_info["browser_download_url"]
215
+ filepath = download_path / f"{filename}_v{version}.zip"
216
+ download_url(url=url, filepath=filepath, hash_val=None, progress=progress)
217
+ extractall(filepath=filepath, output_dir=download_path, has_base=True)
218
+
219
+
209
220
  def _add_ngc_prefix(name: str, prefix: str = "monai_") -> str:
210
221
  if name.startswith(prefix):
211
222
  return name
@@ -222,7 +233,7 @@ def _get_all_download_files(request_url: str, headers: dict | None = None) -> li
222
233
  if not has_requests:
223
234
  raise ValueError("requests package is required, please install it.")
224
235
  headers = {} if headers is None else headers
225
- response = requests_get(request_url, headers=headers)
236
+ response = requests.get(request_url, headers=headers)
226
237
  response.raise_for_status()
227
238
  model_info = json.loads(response.text)
228
239
 
@@ -266,7 +277,7 @@ def _download_from_ngc_private(
266
277
  request_url = _get_ngc_private_bundle_url(model_name=filename, version=version, repo=repo)
267
278
  if has_requests:
268
279
  headers = {} if headers is None else headers
269
- response = requests_get(request_url, headers=headers)
280
+ response = requests.get(request_url, headers=headers)
270
281
  response.raise_for_status()
271
282
  else:
272
283
  raise ValueError("NGC API requires requests package. Please install it.")
@@ -289,7 +300,7 @@ def _get_ngc_token(api_key, retry=0):
289
300
  url = "https://authn.nvidia.com/token?service=ngc"
290
301
  headers = {"Accept": "application/json", "Authorization": "ApiKey " + api_key}
291
302
  if has_requests:
292
- response = requests_get(url, headers=headers)
303
+ response = requests.get(url, headers=headers)
293
304
  if not response.ok:
294
305
  # retry 3 times, if failed, raise an error.
295
306
  if retry < 3:
@@ -303,14 +314,17 @@ def _get_ngc_token(api_key, retry=0):
303
314
 
304
315
  def _get_latest_bundle_version_monaihosting(name):
305
316
  full_url = f"{MONAI_HOSTING_BASE_URL}/{name.lower()}"
306
- requests_get, has_requests = optional_import("requests", name="get")
307
317
  if has_requests:
308
- resp = requests_get(full_url)
309
- resp.raise_for_status()
310
- else:
311
- raise ValueError("NGC API requires requests package. Please install it.")
312
- model_info = json.loads(resp.text)
313
- return model_info["model"]["latestVersionIdStr"]
318
+ resp = requests.get(full_url)
319
+ try:
320
+ resp.raise_for_status()
321
+ model_info = json.loads(resp.text)
322
+ return model_info["model"]["latestVersionIdStr"]
323
+ except requests.exceptions.HTTPError:
324
+ # for monaihosting bundles, if cannot find the version, get from model zoo model_info.json
325
+ return get_bundle_versions(name)["latest_version"]
326
+
327
+ raise ValueError("NGC API requires requests package. Please install it.")
314
328
 
315
329
 
316
330
  def _examine_monai_version(monai_version: str) -> tuple[bool, str]:
@@ -388,14 +402,14 @@ def _get_latest_bundle_version_ngc(name: str, repo: str | None = None, headers:
388
402
  version_header = {"Accept-Encoding": "gzip, deflate"} # Excluding 'zstd' to fit NGC requirements
389
403
  if headers:
390
404
  version_header.update(headers)
391
- resp = requests_get(version_endpoint, headers=version_header)
405
+ resp = requests.get(version_endpoint, headers=version_header)
392
406
  resp.raise_for_status()
393
407
  model_info = json.loads(resp.text)
394
408
  latest_versions = _list_latest_versions(model_info)
395
409
 
396
410
  for version in latest_versions:
397
411
  file_endpoint = base_url + f"/{name.lower()}/versions/{version}/files/configs/metadata.json"
398
- resp = requests_get(file_endpoint, headers=headers)
412
+ resp = requests.get(file_endpoint, headers=headers)
399
413
  metadata = json.loads(resp.text)
400
414
  resp.raise_for_status()
401
415
  # if the package version is not available or the model is compatible with the package version
@@ -585,7 +599,16 @@ def download(
585
599
  name_ver = "_v".join([name_, version_]) if version_ is not None else name_
586
600
  _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_)
587
601
  elif source_ == "monaihosting":
588
- _download_from_monaihosting(download_path=bundle_dir_, filename=name_, version=version_, progress=progress_)
602
+ try:
603
+ _download_from_monaihosting(
604
+ download_path=bundle_dir_, filename=name_, version=version_, progress=progress_
605
+ )
606
+ except urllib.error.HTTPError:
607
+ # for monaihosting bundles, if cannot download from default host, download according to bundle_info
608
+ _download_from_bundle_info(
609
+ download_path=bundle_dir_, filename=name_, version=version_, progress=progress_
610
+ )
611
+
589
612
  elif source_ == "ngc":
590
613
  _download_from_ngc(
591
614
  download_path=bundle_dir_,
@@ -737,7 +760,7 @@ def load(
737
760
  if load_ts_module is True:
738
761
  return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files)
739
762
  # loading with `torch.load`
740
- model_dict = torch.load(full_path, map_location=torch.device(device))
763
+ model_dict = torch.load(full_path, map_location=torch.device(device), weights_only=True)
741
764
 
742
765
  if not isinstance(model_dict, Mapping):
743
766
  warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.")
@@ -792,9 +815,9 @@ def _get_all_bundles_info(
792
815
 
793
816
  if auth_token is not None:
794
817
  headers = {"Authorization": f"Bearer {auth_token}"}
795
- resp = requests_get(request_url, headers=headers)
818
+ resp = requests.get(request_url, headers=headers)
796
819
  else:
797
- resp = requests_get(request_url)
820
+ resp = requests.get(request_url)
798
821
  resp.raise_for_status()
799
822
  else:
800
823
  raise ValueError("requests package is required, please install it.")
@@ -1256,9 +1279,8 @@ def verify_net_in_out(
1256
1279
  if input_dtype == torch.float16:
1257
1280
  # fp16 can only be executed in gpu mode
1258
1281
  net.to("cuda")
1259
- from torch.cuda.amp import autocast
1260
1282
 
1261
- with autocast():
1283
+ with torch.autocast("cuda"):
1262
1284
  output = net(test_data.cuda(), **extra_forward_args_)
1263
1285
  net.to(device_)
1264
1286
  else:
@@ -1307,7 +1329,7 @@ def _export(
1307
1329
  # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver
1308
1330
  Checkpoint.load_objects(to_load={key_in_ckpt: net}, checkpoint=ckpt_file)
1309
1331
  else:
1310
- ckpt = torch.load(ckpt_file)
1332
+ ckpt = torch.load(ckpt_file, weights_only=True)
1311
1333
  copy_model_state(dst=net, src=ckpt if key_in_ckpt == "" else ckpt[key_in_ckpt])
1312
1334
 
1313
1335
  # Use the given converter to convert a model and save with metadata, config content
monai/data/dataset.py CHANGED
@@ -22,7 +22,6 @@ import time
22
22
  import warnings
23
23
  from collections.abc import Callable, Sequence
24
24
  from copy import copy, deepcopy
25
- from inspect import signature
26
25
  from multiprocessing.managers import ListProxy
27
26
  from multiprocessing.pool import ThreadPool
28
27
  from pathlib import Path
@@ -372,10 +371,7 @@ class PersistentDataset(Dataset):
372
371
 
373
372
  if hashfile is not None and hashfile.is_file(): # cache hit
374
373
  try:
375
- if "weights_only" in signature(torch.load).parameters:
376
- return torch.load(hashfile, weights_only=False)
377
- else:
378
- return torch.load(hashfile)
374
+ return torch.load(hashfile, weights_only=False)
379
375
  except PermissionError as e:
380
376
  if sys.platform != "win32":
381
377
  raise e
@@ -1674,7 +1670,4 @@ class GDSDataset(PersistentDataset):
1674
1670
  if meta_hash_file_name in self._meta_cache:
1675
1671
  return self._meta_cache[meta_hash_file_name]
1676
1672
  else:
1677
- if "weights_only" in signature(torch.load).parameters:
1678
- return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
1679
- else:
1680
- return torch.load(self.cache_dir / meta_hash_file_name)
1673
+ return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
monai/data/utils.py CHANGED
@@ -753,7 +753,7 @@ def affine_to_spacing(affine: NdarrayTensor, r: int = 3, dtype=float, suppress_z
753
753
  if isinstance(_affine, torch.Tensor):
754
754
  spacing = torch.sqrt(torch.sum(_affine * _affine, dim=0))
755
755
  else:
756
- spacing = np.sqrt(np.sum(_affine * _affine, axis=0))
756
+ spacing = np.sqrt(np.sum(_affine * _affine, axis=0)) # type: ignore[operator]
757
757
  if suppress_zeros:
758
758
  spacing[spacing == 0] = 1.0
759
759
  spacing_, *_ = convert_to_dst_type(spacing, dst=affine, dtype=dtype)
@@ -177,7 +177,7 @@ class VideoFileDataset(Dataset, VideoDataset):
177
177
  for codec, ext in all_codecs.items():
178
178
  writer = cv2.VideoWriter()
179
179
  fname = os.path.join(tmp_dir, f"test{ext}")
180
- fourcc = cv2.VideoWriter_fourcc(*codec)
180
+ fourcc = cv2.VideoWriter_fourcc(*codec) # type: ignore[attr-defined]
181
181
  noviderr = writer.open(fname, fourcc, 1, (10, 10))
182
182
  if noviderr:
183
183
  codecs[codec] = ext
@@ -28,7 +28,7 @@ from monai.transforms import Transform
28
28
  from monai.utils import ForwardMode, IgniteInfo, ensure_tuple, min_version, optional_import
29
29
  from monai.utils.enums import CommonKeys as Keys
30
30
  from monai.utils.enums import EngineStatsKeys as ESKeys
31
- from monai.utils.module import look_up_option, pytorch_after
31
+ from monai.utils.module import look_up_option
32
32
 
33
33
  if TYPE_CHECKING:
34
34
  from ignite.engine import Engine, EventEnum
@@ -82,8 +82,8 @@ class Evaluator(Workflow):
82
82
  default to `True`.
83
83
  to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
84
84
  `device`, `non_blocking`.
85
- amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
86
- https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
85
+ amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
86
+ https://pytorch.org/docs/stable/amp.html#torch.autocast.
87
87
 
88
88
  """
89
89
 
@@ -214,8 +214,8 @@ class SupervisedEvaluator(Evaluator):
214
214
  default to `True`.
215
215
  to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
216
216
  `device`, `non_blocking`.
217
- amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
218
- https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
217
+ amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
218
+ https://pytorch.org/docs/stable/amp.html#torch.autocast.
219
219
  compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to
220
220
  `torch.Tensor` before forward pass, then converted back afterward with copied meta information.
221
221
  compile_kwargs: dict of the args for `torch.compile()` API, for more details:
@@ -269,13 +269,8 @@ class SupervisedEvaluator(Evaluator):
269
269
  amp_kwargs=amp_kwargs,
270
270
  )
271
271
  if compile:
272
- if pytorch_after(2, 1):
273
- compile_kwargs = {} if compile_kwargs is None else compile_kwargs
274
- network = torch.compile(network, **compile_kwargs) # type: ignore[assignment]
275
- else:
276
- warnings.warn(
277
- "Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done"
278
- )
272
+ compile_kwargs = {} if compile_kwargs is None else compile_kwargs
273
+ network = torch.compile(network, **compile_kwargs) # type: ignore[assignment]
279
274
  self.network = network
280
275
  self.compile = compile
281
276
  self.inferer = SimpleInferer() if inferer is None else inferer
@@ -329,7 +324,7 @@ class SupervisedEvaluator(Evaluator):
329
324
  # execute forward computation
330
325
  with engine.mode(engine.network):
331
326
  if engine.amp:
332
- with torch.cuda.amp.autocast(**engine.amp_kwargs):
327
+ with torch.autocast("cuda", **engine.amp_kwargs):
333
328
  engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
334
329
  else:
335
330
  engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
@@ -399,8 +394,8 @@ class EnsembleEvaluator(Evaluator):
399
394
  default to `True`.
400
395
  to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
401
396
  `device`, `non_blocking`.
402
- amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
403
- https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
397
+ amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
398
+ https://pytorch.org/docs/stable/amp.html#torch.autocast.
404
399
 
405
400
  """
406
401
 
@@ -492,7 +487,7 @@ class EnsembleEvaluator(Evaluator):
492
487
  for idx, network in enumerate(engine.networks):
493
488
  with engine.mode(network):
494
489
  if engine.amp:
495
- with torch.cuda.amp.autocast(**engine.amp_kwargs):
490
+ with torch.autocast("cuda", **engine.amp_kwargs):
496
491
  if isinstance(engine.state.output, dict):
497
492
  engine.state.output.update(
498
493
  {engine.pred_keys[idx]: engine.inferer(inputs, network, *args, **kwargs)}
monai/engines/trainer.py CHANGED
@@ -27,7 +27,6 @@ from monai.transforms import Transform
27
27
  from monai.utils import AdversarialIterationEvents, AdversarialKeys, GanKeys, IgniteInfo, min_version, optional_import
28
28
  from monai.utils.enums import CommonKeys as Keys
29
29
  from monai.utils.enums import EngineStatsKeys as ESKeys
30
- from monai.utils.module import pytorch_after
31
30
 
32
31
  if TYPE_CHECKING:
33
32
  from ignite.engine import Engine, EventEnum
@@ -126,8 +125,8 @@ class SupervisedTrainer(Trainer):
126
125
  more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
127
126
  to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
128
127
  `device`, `non_blocking`.
129
- amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
130
- https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
128
+ amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
129
+ https://pytorch.org/docs/stable/amp.html#torch.autocast.
131
130
  compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to
132
131
  `torch.Tensor` before forward pass, then converted back afterward with copied meta information.
133
132
  compile_kwargs: dict of the args for `torch.compile()` API, for more details:
@@ -183,13 +182,8 @@ class SupervisedTrainer(Trainer):
183
182
  amp_kwargs=amp_kwargs,
184
183
  )
185
184
  if compile:
186
- if pytorch_after(2, 1):
187
- compile_kwargs = {} if compile_kwargs is None else compile_kwargs
188
- network = torch.compile(network, **compile_kwargs) # type: ignore[assignment]
189
- else:
190
- warnings.warn(
191
- "Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done"
192
- )
185
+ compile_kwargs = {} if compile_kwargs is None else compile_kwargs
186
+ network = torch.compile(network, **compile_kwargs) # type: ignore[assignment]
193
187
  self.network = network
194
188
  self.compile = compile
195
189
  self.optimizer = optimizer
@@ -255,7 +249,7 @@ class SupervisedTrainer(Trainer):
255
249
  engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
256
250
 
257
251
  if engine.amp and engine.scaler is not None:
258
- with torch.cuda.amp.autocast(**engine.amp_kwargs):
252
+ with torch.autocast("cuda", **engine.amp_kwargs):
259
253
  _compute_pred_loss()
260
254
  engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
261
255
  engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
@@ -341,8 +335,8 @@ class GanTrainer(Trainer):
341
335
  more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
342
336
  to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
343
337
  `device`, `non_blocking`.
344
- amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
345
- https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
338
+ amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
339
+ https://pytorch.org/docs/stable/amp.html#torch.autocast.
346
340
 
347
341
  """
348
342
 
@@ -518,8 +512,8 @@ class AdversarialTrainer(Trainer):
518
512
  more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
519
513
  to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
520
514
  `device`, `non_blocking`.
521
- amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
522
- https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
515
+ amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
516
+ https://pytorch.org/docs/stable/amp.html#torch.autocast.
523
517
  """
524
518
 
525
519
  def __init__(
@@ -689,7 +683,7 @@ class AdversarialTrainer(Trainer):
689
683
  engine.state.g_optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
690
684
 
691
685
  if engine.amp and engine.state.g_scaler is not None:
692
- with torch.cuda.amp.autocast(**engine.amp_kwargs):
686
+ with torch.autocast("cuda", **engine.amp_kwargs):
693
687
  _compute_generator_loss()
694
688
 
695
689
  engine.state.output[Keys.LOSS] = (
@@ -737,7 +731,7 @@ class AdversarialTrainer(Trainer):
737
731
  engine.state.d_network.zero_grad(set_to_none=engine.optim_set_to_none)
738
732
 
739
733
  if engine.amp and engine.state.d_scaler is not None:
740
- with torch.cuda.amp.autocast(**engine.amp_kwargs):
734
+ with torch.autocast("cuda", **engine.amp_kwargs):
741
735
  _compute_discriminator_loss()
742
736
 
743
737
  engine.state.d_scaler.scale(engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS]).backward()
monai/engines/utils.py CHANGED
@@ -309,7 +309,7 @@ class VPredictionPrepareBatch(DiffusionPrepareBatch):
309
309
  self.scheduler = scheduler
310
310
 
311
311
  def get_target(self, images, noise, timesteps):
312
- return self.scheduler.get_velocity(images, noise, timesteps)
312
+ return self.scheduler.get_velocity(images, noise, timesteps) # type: ignore[operator]
313
313
 
314
314
 
315
315
  def default_make_latent(
monai/engines/workflow.py CHANGED
@@ -90,8 +90,8 @@ class Workflow(Engine):
90
90
  default to `True`.
91
91
  to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
92
92
  `device`, `non_blocking`.
93
- amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
94
- https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
93
+ amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
94
+ https://pytorch.org/docs/stable/amp.html#torch.autocast.
95
95
 
96
96
  Raises:
97
97
  TypeError: When ``data_loader`` is not a ``torch.utils.data.DataLoader``.
@@ -574,7 +574,7 @@ class MonaiAlgo(ClientAlgo, MonaiAlgoStats):
574
574
  model_path = os.path.join(self.bundle_root, cast(str, self.model_filepaths[model_type]))
575
575
  if not os.path.isfile(model_path):
576
576
  raise ValueError(f"No best model checkpoint exists at {model_path}")
577
- weights = torch.load(model_path, map_location="cpu")
577
+ weights = torch.load(model_path, map_location="cpu", weights_only=True)
578
578
  # if weights contain several state dicts, use the one defined by `save_dict_key`
579
579
  if isinstance(weights, dict) and self.save_dict_key in weights:
580
580
  weights = weights.get(self.save_dict_key)
@@ -122,7 +122,7 @@ class CheckpointLoader:
122
122
  Args:
123
123
  engine: Ignite Engine, it can be a trainer, validator or evaluator.
124
124
  """
125
- checkpoint = torch.load(self.load_path, map_location=self.map_location)
125
+ checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=False)
126
126
 
127
127
  k, _ = list(self.load_dict.items())[0]
128
128
  # single object and checkpoint is directly a state_dict