monai-weekly 1.5.dev2509__py3-none-any.whl → 1.5.dev2510__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/deepedit/interaction.py +1 -1
  4. monai/apps/deepgrow/interaction.py +1 -1
  5. monai/apps/detection/networks/retinanet_detector.py +1 -1
  6. monai/apps/detection/networks/retinanet_network.py +5 -5
  7. monai/apps/detection/utils/box_coder.py +2 -2
  8. monai/apps/mmars/mmars.py +1 -1
  9. monai/apps/reconstruction/networks/blocks/varnetblock.py +1 -1
  10. monai/bundle/scripts.py +3 -4
  11. monai/data/dataset.py +2 -9
  12. monai/data/utils.py +1 -1
  13. monai/data/video_dataset.py +1 -1
  14. monai/engines/evaluator.py +11 -16
  15. monai/engines/trainer.py +11 -17
  16. monai/engines/utils.py +1 -1
  17. monai/engines/workflow.py +2 -2
  18. monai/fl/client/monai_algo.py +1 -1
  19. monai/handlers/checkpoint_loader.py +1 -1
  20. monai/inferers/inferer.py +6 -6
  21. monai/inferers/merger.py +16 -13
  22. monai/losses/perceptual.py +1 -1
  23. monai/losses/sure_loss.py +1 -1
  24. monai/networks/blocks/crossattention.py +1 -6
  25. monai/networks/blocks/feature_pyramid_network.py +4 -2
  26. monai/networks/blocks/selfattention.py +1 -6
  27. monai/networks/blocks/upsample.py +3 -11
  28. monai/networks/layers/vector_quantizer.py +2 -2
  29. monai/networks/nets/hovernet.py +5 -4
  30. monai/networks/nets/resnet.py +2 -2
  31. monai/networks/nets/senet.py +1 -1
  32. monai/networks/nets/swin_unetr.py +46 -49
  33. monai/networks/nets/transchex.py +3 -2
  34. monai/networks/nets/vista3d.py +7 -7
  35. monai/networks/utils.py +5 -4
  36. monai/transforms/intensity/array.py +1 -1
  37. monai/transforms/spatial/array.py +6 -6
  38. monai/utils/misc.py +1 -1
  39. monai/utils/state_cacher.py +1 -1
  40. {monai_weekly-1.5.dev2509.dist-info → monai_weekly-1.5.dev2510.dist-info}/METADATA +4 -3
  41. {monai_weekly-1.5.dev2509.dist-info → monai_weekly-1.5.dev2510.dist-info}/RECORD +59 -59
  42. tests/bundle/test_bundle_download.py +16 -6
  43. tests/config/test_cv2_dist.py +1 -2
  44. tests/integration/test_integration_bundle_run.py +2 -4
  45. tests/integration/test_integration_classification_2d.py +1 -1
  46. tests/integration/test_integration_fast_train.py +2 -2
  47. tests/integration/test_integration_segmentation_3d.py +1 -1
  48. tests/metrics/test_compute_multiscalessim_metric.py +3 -3
  49. tests/metrics/test_surface_dice.py +3 -3
  50. tests/networks/nets/test_autoencoderkl.py +1 -1
  51. tests/networks/nets/test_controlnet.py +1 -1
  52. tests/networks/nets/test_diffusion_model_unet.py +1 -1
  53. tests/networks/nets/test_network_consistency.py +1 -1
  54. tests/networks/nets/test_swin_unetr.py +1 -1
  55. tests/networks/nets/test_transformer.py +1 -1
  56. tests/networks/test_save_state.py +1 -1
  57. {monai_weekly-1.5.dev2509.dist-info → monai_weekly-1.5.dev2510.dist-info}/LICENSE +0 -0
  58. {monai_weekly-1.5.dev2509.dist-info → monai_weekly-1.5.dev2510.dist-info}/WHEEL +0 -0
  59. {monai_weekly-1.5.dev2509.dist-info → monai_weekly-1.5.dev2510.dist-info}/top_level.txt +0 -0
monai/__init__.py CHANGED
@@ -136,4 +136,4 @@ except BaseException:
136
136
 
137
137
  if MONAIEnvVars.debug():
138
138
  raise
139
- __commit_id__ = "a09c1f08461cec3d2131fde3939ef38c3c4ad5fc"
139
+ __commit_id__ = "7c26e5af385eb5f7a813fa405c6f3fc87b7511fa"
monai/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2025-03-02T02:29:03+0000",
11
+ "date": "2025-03-09T02:16:22+0000",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "5f85a7bfd54b91be03213999a7c177bfe2d583b2",
15
- "version": "1.5.dev2509"
14
+ "full-revisionid": "19fadf962d87a21e1d0edf8d72299e82f7611140",
15
+ "version": "1.5.dev2510"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
@@ -72,7 +72,7 @@ class Interaction:
72
72
 
73
73
  with torch.no_grad():
74
74
  if engine.amp:
75
- with torch.cuda.amp.autocast():
75
+ with torch.autocast("cuda"):
76
76
  predictions = engine.inferer(inputs, engine.network)
77
77
  else:
78
78
  predictions = engine.inferer(inputs, engine.network)
@@ -67,7 +67,7 @@ class Interaction:
67
67
  engine.network.eval()
68
68
  with torch.no_grad():
69
69
  if engine.amp:
70
- with torch.cuda.amp.autocast():
70
+ with torch.autocast("cuda"):
71
71
  predictions = engine.inferer(inputs, engine.network)
72
72
  else:
73
73
  predictions = engine.inferer(inputs, engine.network)
@@ -180,7 +180,7 @@ class RetinaNetDetector(nn.Module):
180
180
  nesterov=True,
181
181
  )
182
182
  torch.save(detector.network.state_dict(), 'model.pt') # save model
183
- detector.network.load_state_dict(torch.load('model.pt')) # load model
183
+ detector.network.load_state_dict(torch.load('model.pt', weights_only=True)) # load model
184
184
  """
185
185
 
186
186
  def __init__(
@@ -88,8 +88,8 @@ class RetinaNetClassificationHead(nn.Module):
88
88
 
89
89
  for layer in self.conv.children():
90
90
  if isinstance(layer, conv_type): # type: ignore
91
- torch.nn.init.normal_(layer.weight, std=0.01)
92
- torch.nn.init.constant_(layer.bias, 0)
91
+ torch.nn.init.normal_(layer.weight, std=0.01) # type: ignore[arg-type]
92
+ torch.nn.init.constant_(layer.bias, 0) # type: ignore[arg-type]
93
93
 
94
94
  self.cls_logits = conv_type(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
95
95
  torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
@@ -167,8 +167,8 @@ class RetinaNetRegressionHead(nn.Module):
167
167
 
168
168
  for layer in self.conv.children():
169
169
  if isinstance(layer, conv_type): # type: ignore
170
- torch.nn.init.normal_(layer.weight, std=0.01)
171
- torch.nn.init.zeros_(layer.bias)
170
+ torch.nn.init.normal_(layer.weight, std=0.01) # type: ignore[arg-type]
171
+ torch.nn.init.zeros_(layer.bias) # type: ignore[arg-type]
172
172
 
173
173
  def forward(self, x: list[Tensor]) -> list[Tensor]:
174
174
  """
@@ -297,7 +297,7 @@ class RetinaNet(nn.Module):
297
297
  )
298
298
  self.feature_extractor = feature_extractor
299
299
 
300
- self.feature_map_channels: int = self.feature_extractor.out_channels
300
+ self.feature_map_channels: int = self.feature_extractor.out_channels # type: ignore[assignment]
301
301
  self.num_anchors = num_anchors
302
302
  self.classification_head = RetinaNetClassificationHead(
303
303
  self.feature_map_channels, self.num_anchors, self.num_classes, spatial_dims=self.spatial_dims
@@ -221,7 +221,7 @@ class BoxCoder:
221
221
 
222
222
  pred_ctr_xyx_axis = dxyz_axis * whd_axis[:, None] + ctr_xyz_axis[:, None]
223
223
  pred_whd_axis = torch.exp(dwhd_axis) * whd_axis[:, None]
224
- pred_whd_axis = pred_whd_axis.to(dxyz_axis.dtype)
224
+ pred_whd_axis = pred_whd_axis.to(dxyz_axis.dtype) # type: ignore[union-attr]
225
225
 
226
226
  # When convert float32 to float16, Inf or Nan may occur
227
227
  if torch.isnan(pred_whd_axis).any() or torch.isinf(pred_whd_axis).any():
@@ -229,7 +229,7 @@ class BoxCoder:
229
229
 
230
230
  # Distance from center to box's corner.
231
231
  c_to_c_whd_axis = (
232
- torch.tensor(0.5, dtype=pred_ctr_xyx_axis.dtype, device=pred_whd_axis.device) * pred_whd_axis
232
+ torch.tensor(0.5, dtype=pred_ctr_xyx_axis.dtype, device=pred_whd_axis.device) * pred_whd_axis # type: ignore[arg-type]
233
233
  )
234
234
 
235
235
  pred_boxes.append(pred_ctr_xyx_axis - c_to_c_whd_axis)
monai/apps/mmars/mmars.py CHANGED
@@ -241,7 +241,7 @@ def load_from_mmar(
241
241
  return torch.jit.load(_model_file, map_location=map_location)
242
242
 
243
243
  # loading with `torch.load`
244
- model_dict = torch.load(_model_file, map_location=map_location)
244
+ model_dict = torch.load(_model_file, map_location=map_location, weights_only=True)
245
245
  if weights_only:
246
246
  return model_dict.get(model_key, model_dict) # model_dict[model_key] or model_dict directly
247
247
 
@@ -55,7 +55,7 @@ class VarNetBlock(nn.Module):
55
55
  Returns:
56
56
  Output of DC block with the same shape as x
57
57
  """
58
- return torch.where(mask, x - ref_kspace, self.zeros) * self.dc_weight
58
+ return torch.where(mask, x - ref_kspace, self.zeros) * self.dc_weight # type: ignore
59
59
 
60
60
  def forward(self, current_kspace: Tensor, ref_kspace: Tensor, mask: Tensor, sens_maps: Tensor) -> Tensor:
61
61
  """
monai/bundle/scripts.py CHANGED
@@ -760,7 +760,7 @@ def load(
760
760
  if load_ts_module is True:
761
761
  return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files)
762
762
  # loading with `torch.load`
763
- model_dict = torch.load(full_path, map_location=torch.device(device))
763
+ model_dict = torch.load(full_path, map_location=torch.device(device), weights_only=True)
764
764
 
765
765
  if not isinstance(model_dict, Mapping):
766
766
  warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.")
@@ -1279,9 +1279,8 @@ def verify_net_in_out(
1279
1279
  if input_dtype == torch.float16:
1280
1280
  # fp16 can only be executed in gpu mode
1281
1281
  net.to("cuda")
1282
- from torch.cuda.amp import autocast
1283
1282
 
1284
- with autocast():
1283
+ with torch.autocast("cuda"):
1285
1284
  output = net(test_data.cuda(), **extra_forward_args_)
1286
1285
  net.to(device_)
1287
1286
  else:
@@ -1330,7 +1329,7 @@ def _export(
1330
1329
  # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver
1331
1330
  Checkpoint.load_objects(to_load={key_in_ckpt: net}, checkpoint=ckpt_file)
1332
1331
  else:
1333
- ckpt = torch.load(ckpt_file)
1332
+ ckpt = torch.load(ckpt_file, weights_only=True)
1334
1333
  copy_model_state(dst=net, src=ckpt if key_in_ckpt == "" else ckpt[key_in_ckpt])
1335
1334
 
1336
1335
  # Use the given converter to convert a model and save with metadata, config content
monai/data/dataset.py CHANGED
@@ -22,7 +22,6 @@ import time
22
22
  import warnings
23
23
  from collections.abc import Callable, Sequence
24
24
  from copy import copy, deepcopy
25
- from inspect import signature
26
25
  from multiprocessing.managers import ListProxy
27
26
  from multiprocessing.pool import ThreadPool
28
27
  from pathlib import Path
@@ -372,10 +371,7 @@ class PersistentDataset(Dataset):
372
371
 
373
372
  if hashfile is not None and hashfile.is_file(): # cache hit
374
373
  try:
375
- if "weights_only" in signature(torch.load).parameters:
376
- return torch.load(hashfile, weights_only=False)
377
- else:
378
- return torch.load(hashfile)
374
+ return torch.load(hashfile, weights_only=False)
379
375
  except PermissionError as e:
380
376
  if sys.platform != "win32":
381
377
  raise e
@@ -1674,7 +1670,4 @@ class GDSDataset(PersistentDataset):
1674
1670
  if meta_hash_file_name in self._meta_cache:
1675
1671
  return self._meta_cache[meta_hash_file_name]
1676
1672
  else:
1677
- if "weights_only" in signature(torch.load).parameters:
1678
- return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
1679
- else:
1680
- return torch.load(self.cache_dir / meta_hash_file_name)
1673
+ return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
monai/data/utils.py CHANGED
@@ -753,7 +753,7 @@ def affine_to_spacing(affine: NdarrayTensor, r: int = 3, dtype=float, suppress_z
753
753
  if isinstance(_affine, torch.Tensor):
754
754
  spacing = torch.sqrt(torch.sum(_affine * _affine, dim=0))
755
755
  else:
756
- spacing = np.sqrt(np.sum(_affine * _affine, axis=0))
756
+ spacing = np.sqrt(np.sum(_affine * _affine, axis=0)) # type: ignore[operator]
757
757
  if suppress_zeros:
758
758
  spacing[spacing == 0] = 1.0
759
759
  spacing_, *_ = convert_to_dst_type(spacing, dst=affine, dtype=dtype)
@@ -177,7 +177,7 @@ class VideoFileDataset(Dataset, VideoDataset):
177
177
  for codec, ext in all_codecs.items():
178
178
  writer = cv2.VideoWriter()
179
179
  fname = os.path.join(tmp_dir, f"test{ext}")
180
- fourcc = cv2.VideoWriter_fourcc(*codec)
180
+ fourcc = cv2.VideoWriter_fourcc(*codec) # type: ignore[attr-defined]
181
181
  noviderr = writer.open(fname, fourcc, 1, (10, 10))
182
182
  if noviderr:
183
183
  codecs[codec] = ext
@@ -28,7 +28,7 @@ from monai.transforms import Transform
28
28
  from monai.utils import ForwardMode, IgniteInfo, ensure_tuple, min_version, optional_import
29
29
  from monai.utils.enums import CommonKeys as Keys
30
30
  from monai.utils.enums import EngineStatsKeys as ESKeys
31
- from monai.utils.module import look_up_option, pytorch_after
31
+ from monai.utils.module import look_up_option
32
32
 
33
33
  if TYPE_CHECKING:
34
34
  from ignite.engine import Engine, EventEnum
@@ -82,8 +82,8 @@ class Evaluator(Workflow):
82
82
  default to `True`.
83
83
  to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
84
84
  `device`, `non_blocking`.
85
- amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
86
- https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
85
+ amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
86
+ https://pytorch.org/docs/stable/amp.html#torch.autocast.
87
87
 
88
88
  """
89
89
 
@@ -214,8 +214,8 @@ class SupervisedEvaluator(Evaluator):
214
214
  default to `True`.
215
215
  to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
216
216
  `device`, `non_blocking`.
217
- amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
218
- https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
217
+ amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
218
+ https://pytorch.org/docs/stable/amp.html#torch.autocast.
219
219
  compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to
220
220
  `torch.Tensor` before forward pass, then converted back afterward with copied meta information.
221
221
  compile_kwargs: dict of the args for `torch.compile()` API, for more details:
@@ -269,13 +269,8 @@ class SupervisedEvaluator(Evaluator):
269
269
  amp_kwargs=amp_kwargs,
270
270
  )
271
271
  if compile:
272
- if pytorch_after(2, 1):
273
- compile_kwargs = {} if compile_kwargs is None else compile_kwargs
274
- network = torch.compile(network, **compile_kwargs) # type: ignore[assignment]
275
- else:
276
- warnings.warn(
277
- "Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done"
278
- )
272
+ compile_kwargs = {} if compile_kwargs is None else compile_kwargs
273
+ network = torch.compile(network, **compile_kwargs) # type: ignore[assignment]
279
274
  self.network = network
280
275
  self.compile = compile
281
276
  self.inferer = SimpleInferer() if inferer is None else inferer
@@ -329,7 +324,7 @@ class SupervisedEvaluator(Evaluator):
329
324
  # execute forward computation
330
325
  with engine.mode(engine.network):
331
326
  if engine.amp:
332
- with torch.cuda.amp.autocast(**engine.amp_kwargs):
327
+ with torch.autocast("cuda", **engine.amp_kwargs):
333
328
  engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
334
329
  else:
335
330
  engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
@@ -399,8 +394,8 @@ class EnsembleEvaluator(Evaluator):
399
394
  default to `True`.
400
395
  to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
401
396
  `device`, `non_blocking`.
402
- amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
403
- https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
397
+ amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
398
+ https://pytorch.org/docs/stable/amp.html#torch.autocast.
404
399
 
405
400
  """
406
401
 
@@ -492,7 +487,7 @@ class EnsembleEvaluator(Evaluator):
492
487
  for idx, network in enumerate(engine.networks):
493
488
  with engine.mode(network):
494
489
  if engine.amp:
495
- with torch.cuda.amp.autocast(**engine.amp_kwargs):
490
+ with torch.autocast("cuda", **engine.amp_kwargs):
496
491
  if isinstance(engine.state.output, dict):
497
492
  engine.state.output.update(
498
493
  {engine.pred_keys[idx]: engine.inferer(inputs, network, *args, **kwargs)}
monai/engines/trainer.py CHANGED
@@ -27,7 +27,6 @@ from monai.transforms import Transform
27
27
  from monai.utils import AdversarialIterationEvents, AdversarialKeys, GanKeys, IgniteInfo, min_version, optional_import
28
28
  from monai.utils.enums import CommonKeys as Keys
29
29
  from monai.utils.enums import EngineStatsKeys as ESKeys
30
- from monai.utils.module import pytorch_after
31
30
 
32
31
  if TYPE_CHECKING:
33
32
  from ignite.engine import Engine, EventEnum
@@ -126,8 +125,8 @@ class SupervisedTrainer(Trainer):
126
125
  more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
127
126
  to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
128
127
  `device`, `non_blocking`.
129
- amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
130
- https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
128
+ amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
129
+ https://pytorch.org/docs/stable/amp.html#torch.autocast.
131
130
  compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to
132
131
  `torch.Tensor` before forward pass, then converted back afterward with copied meta information.
133
132
  compile_kwargs: dict of the args for `torch.compile()` API, for more details:
@@ -183,13 +182,8 @@ class SupervisedTrainer(Trainer):
183
182
  amp_kwargs=amp_kwargs,
184
183
  )
185
184
  if compile:
186
- if pytorch_after(2, 1):
187
- compile_kwargs = {} if compile_kwargs is None else compile_kwargs
188
- network = torch.compile(network, **compile_kwargs) # type: ignore[assignment]
189
- else:
190
- warnings.warn(
191
- "Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done"
192
- )
185
+ compile_kwargs = {} if compile_kwargs is None else compile_kwargs
186
+ network = torch.compile(network, **compile_kwargs) # type: ignore[assignment]
193
187
  self.network = network
194
188
  self.compile = compile
195
189
  self.optimizer = optimizer
@@ -255,7 +249,7 @@ class SupervisedTrainer(Trainer):
255
249
  engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
256
250
 
257
251
  if engine.amp and engine.scaler is not None:
258
- with torch.cuda.amp.autocast(**engine.amp_kwargs):
252
+ with torch.autocast("cuda", **engine.amp_kwargs):
259
253
  _compute_pred_loss()
260
254
  engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
261
255
  engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
@@ -341,8 +335,8 @@ class GanTrainer(Trainer):
341
335
  more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
342
336
  to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
343
337
  `device`, `non_blocking`.
344
- amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
345
- https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
338
+ amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
339
+ https://pytorch.org/docs/stable/amp.html#torch.autocast.
346
340
 
347
341
  """
348
342
 
@@ -518,8 +512,8 @@ class AdversarialTrainer(Trainer):
518
512
  more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
519
513
  to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
520
514
  `device`, `non_blocking`.
521
- amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
522
- https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
515
+ amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
516
+ https://pytorch.org/docs/stable/amp.html#torch.autocast.
523
517
  """
524
518
 
525
519
  def __init__(
@@ -689,7 +683,7 @@ class AdversarialTrainer(Trainer):
689
683
  engine.state.g_optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
690
684
 
691
685
  if engine.amp and engine.state.g_scaler is not None:
692
- with torch.cuda.amp.autocast(**engine.amp_kwargs):
686
+ with torch.autocast("cuda", **engine.amp_kwargs):
693
687
  _compute_generator_loss()
694
688
 
695
689
  engine.state.output[Keys.LOSS] = (
@@ -737,7 +731,7 @@ class AdversarialTrainer(Trainer):
737
731
  engine.state.d_network.zero_grad(set_to_none=engine.optim_set_to_none)
738
732
 
739
733
  if engine.amp and engine.state.d_scaler is not None:
740
- with torch.cuda.amp.autocast(**engine.amp_kwargs):
734
+ with torch.autocast("cuda", **engine.amp_kwargs):
741
735
  _compute_discriminator_loss()
742
736
 
743
737
  engine.state.d_scaler.scale(engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS]).backward()
monai/engines/utils.py CHANGED
@@ -309,7 +309,7 @@ class VPredictionPrepareBatch(DiffusionPrepareBatch):
309
309
  self.scheduler = scheduler
310
310
 
311
311
  def get_target(self, images, noise, timesteps):
312
- return self.scheduler.get_velocity(images, noise, timesteps)
312
+ return self.scheduler.get_velocity(images, noise, timesteps) # type: ignore[operator]
313
313
 
314
314
 
315
315
  def default_make_latent(
monai/engines/workflow.py CHANGED
@@ -90,8 +90,8 @@ class Workflow(Engine):
90
90
  default to `True`.
91
91
  to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
92
92
  `device`, `non_blocking`.
93
- amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
94
- https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
93
+ amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
94
+ https://pytorch.org/docs/stable/amp.html#torch.autocast.
95
95
 
96
96
  Raises:
97
97
  TypeError: When ``data_loader`` is not a ``torch.utils.data.DataLoader``.
@@ -574,7 +574,7 @@ class MonaiAlgo(ClientAlgo, MonaiAlgoStats):
574
574
  model_path = os.path.join(self.bundle_root, cast(str, self.model_filepaths[model_type]))
575
575
  if not os.path.isfile(model_path):
576
576
  raise ValueError(f"No best model checkpoint exists at {model_path}")
577
- weights = torch.load(model_path, map_location="cpu")
577
+ weights = torch.load(model_path, map_location="cpu", weights_only=True)
578
578
  # if weights contain several state dicts, use the one defined by `save_dict_key`
579
579
  if isinstance(weights, dict) and self.save_dict_key in weights:
580
580
  weights = weights.get(self.save_dict_key)
@@ -122,7 +122,7 @@ class CheckpointLoader:
122
122
  Args:
123
123
  engine: Ignite Engine, it can be a trainer, validator or evaluator.
124
124
  """
125
- checkpoint = torch.load(self.load_path, map_location=self.map_location)
125
+ checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=False)
126
126
 
127
127
  k, _ = list(self.load_dict.items())[0]
128
128
  # single object and checkpoint is directly a state_dict
monai/inferers/inferer.py CHANGED
@@ -882,7 +882,7 @@ class DiffusionInferer(Inferer):
882
882
  )
883
883
 
884
884
  # 2. compute previous image: x_t -> x_t-1
885
- image, _ = scheduler.step(model_output, t, image)
885
+ image, _ = scheduler.step(model_output, t, image) # type: ignore[operator]
886
886
  if save_intermediates and t % intermediate_steps == 0:
887
887
  intermediates.append(image)
888
888
  if save_intermediates:
@@ -986,8 +986,8 @@ class DiffusionInferer(Inferer):
986
986
  predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image
987
987
 
988
988
  # get the posterior mean and variance
989
- posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)
990
- posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)
989
+ posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) # type: ignore[operator]
990
+ posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) # type: ignore[operator]
991
991
 
992
992
  log_posterior_variance = torch.log(posterior_variance)
993
993
  log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance
@@ -1436,7 +1436,7 @@ class ControlNetDiffusionInferer(DiffusionInferer):
1436
1436
  )
1437
1437
 
1438
1438
  # 3. compute previous image: x_t -> x_t-1
1439
- image, _ = scheduler.step(model_output, t, image)
1439
+ image, _ = scheduler.step(model_output, t, image) # type: ignore[operator]
1440
1440
  if save_intermediates and t % intermediate_steps == 0:
1441
1441
  intermediates.append(image)
1442
1442
  if save_intermediates:
@@ -1562,8 +1562,8 @@ class ControlNetDiffusionInferer(DiffusionInferer):
1562
1562
  predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image
1563
1563
 
1564
1564
  # get the posterior mean and variance
1565
- posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)
1566
- posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)
1565
+ posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) # type: ignore[operator]
1566
+ posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) # type: ignore[operator]
1567
1567
 
1568
1568
  log_posterior_variance = torch.log(posterior_variance)
1569
1569
  log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance
monai/inferers/merger.py CHANGED
@@ -53,8 +53,11 @@ class Merger(ABC):
53
53
  cropped_shape: Sequence[int] | None = None,
54
54
  device: torch.device | str | None = None,
55
55
  ) -> None:
56
- self.merged_shape = merged_shape
57
- self.cropped_shape = self.merged_shape if cropped_shape is None else cropped_shape
56
+ if merged_shape is None:
57
+ raise ValueError("Argument `merged_shape` must be provided")
58
+
59
+ self.merged_shape: tuple[int, ...] = tuple(merged_shape)
60
+ self.cropped_shape: tuple[int, ...] = tuple(self.merged_shape if cropped_shape is None else cropped_shape)
58
61
  self.device = device
59
62
  self.is_finalized = False
60
63
 
@@ -231,9 +234,9 @@ class ZarrAvgMerger(Merger):
231
234
  dtype: np.dtype | str = "float32",
232
235
  value_dtype: np.dtype | str = "float32",
233
236
  count_dtype: np.dtype | str = "uint8",
234
- store: zarr.storage.Store | str = "merged.zarr",
235
- value_store: zarr.storage.Store | str | None = None,
236
- count_store: zarr.storage.Store | str | None = None,
237
+ store: zarr.storage.Store | str = "merged.zarr", # type: ignore
238
+ value_store: zarr.storage.Store | str | None = None, # type: ignore
239
+ count_store: zarr.storage.Store | str | None = None, # type: ignore
237
240
  compressor: str | None = None,
238
241
  value_compressor: str | None = None,
239
242
  count_compressor: str | None = None,
@@ -251,18 +254,18 @@ class ZarrAvgMerger(Merger):
251
254
  if version_geq(get_package_version("zarr"), "3.0.0"):
252
255
  if value_store is None:
253
256
  self.tmpdir = TemporaryDirectory()
254
- self.value_store = zarr.storage.LocalStore(self.tmpdir.name)
257
+ self.value_store = zarr.storage.LocalStore(self.tmpdir.name) # type: ignore
255
258
  else:
256
- self.value_store = value_store
259
+ self.value_store = value_store # type: ignore
257
260
  if count_store is None:
258
261
  self.tmpdir = TemporaryDirectory()
259
- self.count_store = zarr.storage.LocalStore(self.tmpdir.name)
262
+ self.count_store = zarr.storage.LocalStore(self.tmpdir.name) # type: ignore
260
263
  else:
261
- self.count_store = count_store
264
+ self.count_store = count_store # type: ignore
262
265
  else:
263
266
  self.tmpdir = None
264
- self.value_store = zarr.storage.TempStore() if value_store is None else value_store
265
- self.count_store = zarr.storage.TempStore() if count_store is None else count_store
267
+ self.value_store = zarr.storage.TempStore() if value_store is None else value_store # type: ignore
268
+ self.count_store = zarr.storage.TempStore() if count_store is None else count_store # type: ignore
266
269
  self.chunks = chunks
267
270
  self.compressor = compressor
268
271
  self.value_compressor = value_compressor
@@ -314,7 +317,7 @@ class ZarrAvgMerger(Merger):
314
317
  map_slice = ensure_tuple_size(map_slice, values.ndim, pad_val=slice(None), pad_from_start=True)
315
318
  with self.lock:
316
319
  self.values[map_slice] += values.numpy()
317
- self.counts[map_slice] += 1
320
+ self.counts[map_slice] += 1 # type: ignore[operator]
318
321
 
319
322
  def finalize(self) -> zarr.Array:
320
323
  """
@@ -332,7 +335,7 @@ class ZarrAvgMerger(Merger):
332
335
  if not self.is_finalized:
333
336
  # use chunks for division to fit into memory
334
337
  for chunk in iterate_over_chunks(self.values.chunks, self.values.cdata_shape):
335
- self.output[chunk] = self.values[chunk] / self.counts[chunk]
338
+ self.output[chunk] = self.values[chunk] / self.counts[chunk] # type: ignore[operator]
336
339
  # finalize the shape
337
340
  self.output.resize(self.cropped_shape)
338
341
  # set finalize flag to protect performing in-place division again
@@ -374,7 +374,7 @@ class TorchvisionModelPerceptualSimilarity(nn.Module):
374
374
  else:
375
375
  network = torchvision.models.resnet50(weights=None)
376
376
  if pretrained is True:
377
- state_dict = torch.load(pretrained_path)
377
+ state_dict = torch.load(pretrained_path, weights_only=True)
378
378
  if pretrained_state_dict_key is not None:
379
379
  state_dict = state_dict[pretrained_state_dict_key]
380
380
  network.load_state_dict(state_dict)
monai/losses/sure_loss.py CHANGED
@@ -92,7 +92,7 @@ def sure_loss_function(
92
92
  y_ref = operator(x)
93
93
 
94
94
  # get perturbed output
95
- x_perturbed = x + eps * perturb_noise
95
+ x_perturbed = x + eps * perturb_noise # type: ignore
96
96
  y_perturbed = operator(x_perturbed)
97
97
  # divergence
98
98
  divergence = torch.sum(1.0 / eps * torch.matmul(perturb_noise.permute(0, 1, 3, 2), y_perturbed - y_ref)) # type: ignore
@@ -17,7 +17,7 @@ import torch
17
17
  import torch.nn as nn
18
18
 
19
19
  from monai.networks.layers.utils import get_rel_pos_embedding_layer
20
- from monai.utils import optional_import, pytorch_after
20
+ from monai.utils import optional_import
21
21
 
22
22
  Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
23
23
 
@@ -84,11 +84,6 @@ class CrossAttentionBlock(nn.Module):
84
84
  if causal and sequence_length is None:
85
85
  raise ValueError("sequence_length is necessary for causal attention.")
86
86
 
87
- if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0):
88
- raise ValueError(
89
- "use_flash_attention is only supported for PyTorch versions >= 2.0."
90
- "Upgrade your PyTorch or set the flag to False."
91
- )
92
87
  if use_flash_attention and save_attn:
93
88
  raise ValueError(
94
89
  "save_attn has been set to True, but use_flash_attention is also set"
@@ -54,7 +54,9 @@ from __future__ import annotations
54
54
 
55
55
  from collections import OrderedDict
56
56
  from collections.abc import Callable
57
+ from typing import cast
57
58
 
59
+ import torch
58
60
  import torch.nn.functional as F
59
61
  from torch import Tensor, nn
60
62
 
@@ -194,8 +196,8 @@ class FeaturePyramidNetwork(nn.Module):
194
196
  conv_type_: type[nn.Module] = Conv[Conv.CONV, spatial_dims]
195
197
  for m in self.modules():
196
198
  if isinstance(m, conv_type_):
197
- nn.init.kaiming_uniform_(m.weight, a=1)
198
- nn.init.constant_(m.bias, 0.0)
199
+ nn.init.kaiming_uniform_(cast(torch.Tensor, m.weight), a=1)
200
+ nn.init.constant_(cast(torch.Tensor, m.bias), 0.0)
199
201
 
200
202
  if extra_blocks is not None:
201
203
  if not isinstance(extra_blocks, ExtraFPNBlock):
@@ -18,7 +18,7 @@ import torch.nn as nn
18
18
  import torch.nn.functional as F
19
19
 
20
20
  from monai.networks.layers.utils import get_rel_pos_embedding_layer
21
- from monai.utils import optional_import, pytorch_after
21
+ from monai.utils import optional_import
22
22
 
23
23
  Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
24
24
 
@@ -90,11 +90,6 @@ class SABlock(nn.Module):
90
90
  if causal and sequence_length is None:
91
91
  raise ValueError("sequence_length is necessary for causal attention.")
92
92
 
93
- if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0):
94
- raise ValueError(
95
- "use_flash_attention is only supported for PyTorch versions >= 2.0."
96
- "Upgrade your PyTorch or set the flag to False."
97
- )
98
93
  if use_flash_attention and save_attn:
99
94
  raise ValueError(
100
95
  "save_attn has been set to True, but use_flash_attention is also set"
@@ -17,8 +17,8 @@ import torch
17
17
  import torch.nn as nn
18
18
 
19
19
  from monai.networks.layers.factories import Conv, Pad, Pool
20
- from monai.networks.utils import CastTempType, icnr_init, pixelshuffle
21
- from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option, pytorch_after
20
+ from monai.networks.utils import icnr_init, pixelshuffle
21
+ from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option
22
22
 
23
23
  __all__ = ["Upsample", "UpSample", "SubpixelUpsample", "Subpixelupsample", "SubpixelUpSample"]
24
24
 
@@ -164,15 +164,7 @@ class UpSample(nn.Sequential):
164
164
  align_corners=align_corners,
165
165
  )
166
166
 
167
- # Cast to float32 as 'upsample_nearest2d_out_frame' op does not support bfloat16
168
- # https://github.com/pytorch/pytorch/issues/86679. This issue is solved in PyTorch 2.1
169
- if pytorch_after(major=2, minor=1):
170
- self.add_module("upsample_non_trainable", upsample)
171
- else:
172
- self.add_module(
173
- "upsample_non_trainable",
174
- CastTempType(initial_type=torch.bfloat16, temporary_type=torch.float32, submodule=upsample),
175
- )
167
+ self.add_module("upsample_non_trainable", upsample)
176
168
  if post_conv:
177
169
  self.add_module("postconv", post_conv)
178
170
  elif up_mode == UpsampleMode.PIXELSHUFFLE: