monai-weekly 1.5.dev2508__py3-none-any.whl → 1.5.dev2510__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (60)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/deepedit/interaction.py +1 -1
  4. monai/apps/deepgrow/interaction.py +1 -1
  5. monai/apps/detection/networks/retinanet_detector.py +1 -1
  6. monai/apps/detection/networks/retinanet_network.py +5 -5
  7. monai/apps/detection/utils/box_coder.py +2 -2
  8. monai/apps/mmars/mmars.py +1 -1
  9. monai/apps/reconstruction/networks/blocks/varnetblock.py +1 -1
  10. monai/bundle/scripts.py +42 -20
  11. monai/data/dataset.py +2 -9
  12. monai/data/utils.py +1 -1
  13. monai/data/video_dataset.py +1 -1
  14. monai/engines/evaluator.py +11 -16
  15. monai/engines/trainer.py +11 -17
  16. monai/engines/utils.py +1 -1
  17. monai/engines/workflow.py +2 -2
  18. monai/fl/client/monai_algo.py +1 -1
  19. monai/handlers/checkpoint_loader.py +1 -1
  20. monai/inferers/inferer.py +35 -17
  21. monai/inferers/merger.py +16 -13
  22. monai/losses/perceptual.py +1 -1
  23. monai/losses/sure_loss.py +1 -1
  24. monai/networks/blocks/crossattention.py +1 -6
  25. monai/networks/blocks/feature_pyramid_network.py +4 -2
  26. monai/networks/blocks/selfattention.py +1 -6
  27. monai/networks/blocks/upsample.py +3 -11
  28. monai/networks/layers/vector_quantizer.py +2 -2
  29. monai/networks/nets/hovernet.py +5 -4
  30. monai/networks/nets/resnet.py +2 -2
  31. monai/networks/nets/senet.py +1 -1
  32. monai/networks/nets/swin_unetr.py +46 -49
  33. monai/networks/nets/transchex.py +3 -2
  34. monai/networks/nets/vista3d.py +7 -7
  35. monai/networks/utils.py +5 -4
  36. monai/transforms/intensity/array.py +1 -1
  37. monai/transforms/spatial/array.py +6 -6
  38. monai/utils/misc.py +1 -1
  39. monai/utils/state_cacher.py +1 -1
  40. {monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2510.dist-info}/METADATA +4 -3
  41. {monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2510.dist-info}/RECORD +60 -60
  42. {monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2510.dist-info}/WHEEL +1 -1
  43. tests/bundle/test_bundle_download.py +16 -6
  44. tests/config/test_cv2_dist.py +1 -2
  45. tests/inferers/test_controlnet_inferers.py +9 -0
  46. tests/integration/test_integration_bundle_run.py +2 -4
  47. tests/integration/test_integration_classification_2d.py +1 -1
  48. tests/integration/test_integration_fast_train.py +2 -2
  49. tests/integration/test_integration_segmentation_3d.py +1 -1
  50. tests/metrics/test_compute_multiscalessim_metric.py +3 -3
  51. tests/metrics/test_surface_dice.py +3 -3
  52. tests/networks/nets/test_autoencoderkl.py +1 -1
  53. tests/networks/nets/test_controlnet.py +1 -1
  54. tests/networks/nets/test_diffusion_model_unet.py +1 -1
  55. tests/networks/nets/test_network_consistency.py +1 -1
  56. tests/networks/nets/test_swin_unetr.py +1 -1
  57. tests/networks/nets/test_transformer.py +1 -1
  58. tests/networks/test_save_state.py +1 -1
  59. {monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2510.dist-info}/LICENSE +0 -0
  60. {monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2510.dist-info}/top_level.txt +0 -0
monai/inferers/inferer.py CHANGED
@@ -882,7 +882,7 @@ class DiffusionInferer(Inferer):
     )
 
     # 2. compute previous image: x_t -> x_t-1
-    image, _ = scheduler.step(model_output, t, image)
+    image, _ = scheduler.step(model_output, t, image)  # type: ignore[operator]
     if save_intermediates and t % intermediate_steps == 0:
         intermediates.append(image)
     if save_intermediates:
@@ -986,8 +986,8 @@ class DiffusionInferer(Inferer):
     predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image
 
     # get the posterior mean and variance
-    posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)
-    posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)
+    posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)  # type: ignore[operator]
+    posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)  # type: ignore[operator]
 
     log_posterior_variance = torch.log(posterior_variance)
     log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance
@@ -1334,13 +1334,15 @@ class ControlNetDiffusionInferer(DiffusionInferer):
     raise NotImplementedError(f"{mode} condition is not supported")
 
     noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
-    down_block_res_samples, mid_block_res_sample = controlnet(
-        x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond
-    )
+
     if mode == "concat" and condition is not None:
         noisy_image = torch.cat([noisy_image, condition], dim=1)
         condition = None
 
+    down_block_res_samples, mid_block_res_sample = controlnet(
+        x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond, context=condition
+    )
+
     diffuse = diffusion_model
     if isinstance(diffusion_model, SPADEDiffusionModelUNet):
         diffuse = partial(diffusion_model, seg=seg)
@@ -1396,17 +1398,21 @@ class ControlNetDiffusionInferer(DiffusionInferer):
     progress_bar = iter(scheduler.timesteps)
     intermediates = []
     for t in progress_bar:
-        # 1. ControlNet forward
-        down_block_res_samples, mid_block_res_sample = controlnet(
-            x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond
-        )
-        # 2. predict noise model_output
         diffuse = diffusion_model
         if isinstance(diffusion_model, SPADEDiffusionModelUNet):
             diffuse = partial(diffusion_model, seg=seg)
 
         if mode == "concat" and conditioning is not None:
+            # 1. Conditioning
            model_input = torch.cat([image, conditioning], dim=1)
+            # 2. ControlNet forward
+            down_block_res_samples, mid_block_res_sample = controlnet(
+                x=model_input,
+                timesteps=torch.Tensor((t,)).to(input_noise.device),
+                controlnet_cond=cn_cond,
+                context=None,
+            )
+            # 3. predict noise model_output
            model_output = diffuse(
                model_input,
                timesteps=torch.Tensor((t,)).to(input_noise.device),
@@ -1415,6 +1421,12 @@ class ControlNetDiffusionInferer(DiffusionInferer):
                mid_block_additional_residual=mid_block_res_sample,
            )
        else:
+            down_block_res_samples, mid_block_res_sample = controlnet(
+                x=image,
+                timesteps=torch.Tensor((t,)).to(input_noise.device),
+                controlnet_cond=cn_cond,
+                context=conditioning,
+            )
            model_output = diffuse(
                image,
                timesteps=torch.Tensor((t,)).to(input_noise.device),
@@ -1424,7 +1436,7 @@ class ControlNetDiffusionInferer(DiffusionInferer):
            )
 
        # 3. compute previous image: x_t -> x_t-1
-        image, _ = scheduler.step(model_output, t, image)
+        image, _ = scheduler.step(model_output, t, image)  # type: ignore[operator]
        if save_intermediates and t % intermediate_steps == 0:
            intermediates.append(image)
    if save_intermediates:
@@ -1485,9 +1497,6 @@ class ControlNetDiffusionInferer(DiffusionInferer):
    for t in progress_bar:
        timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
        noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
-        down_block_res_samples, mid_block_res_sample = controlnet(
-            x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond
-        )
 
        diffuse = diffusion_model
        if isinstance(diffusion_model, SPADEDiffusionModelUNet):
@@ -1495,6 +1504,9 @@ class ControlNetDiffusionInferer(DiffusionInferer):
 
        if mode == "concat" and conditioning is not None:
            noisy_image = torch.cat([noisy_image, conditioning], dim=1)
+            down_block_res_samples, mid_block_res_sample = controlnet(
+                x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond, context=None
+            )
            model_output = diffuse(
                noisy_image,
                timesteps=timesteps,
@@ -1503,6 +1515,12 @@ class ControlNetDiffusionInferer(DiffusionInferer):
                mid_block_additional_residual=mid_block_res_sample,
            )
        else:
+            down_block_res_samples, mid_block_res_sample = controlnet(
+                x=noisy_image,
+                timesteps=torch.Tensor((t,)).to(inputs.device),
+                controlnet_cond=cn_cond,
+                context=conditioning,
+            )
            model_output = diffuse(
                x=noisy_image,
                timesteps=timesteps,
@@ -1544,8 +1562,8 @@ class ControlNetDiffusionInferer(DiffusionInferer):
    predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image
 
    # get the posterior mean and variance
-    posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)
-    posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)
+    posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)  # type: ignore[operator]
+    posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)  # type: ignore[operator]
 
    log_posterior_variance = torch.log(posterior_variance)
    log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance
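Across the ControlNetDiffusionInferer hunks above, the ControlNet forward pass is moved after the conditioning step and gains a `context` argument: in `concat` mode the conditioning is concatenated onto the input channels and `context=None` is passed, otherwise the conditioning is forwarded to the ControlNet as cross-attention context. A condensed sketch of the resulting call pattern (variable names follow the diff; `t_batch` is a stand-in for the `torch.Tensor((t,)).to(...)` timestep tensor used there):

    if mode == "concat" and conditioning is not None:
        # conditioning enters through the input channels, so no cross-attention context
        model_input = torch.cat([image, conditioning], dim=1)
        down_block_res_samples, mid_block_res_sample = controlnet(
            x=model_input, timesteps=t_batch, controlnet_cond=cn_cond, context=None
        )
    else:
        # conditioning is handed to the ControlNet as cross-attention context
        down_block_res_samples, mid_block_res_sample = controlnet(
            x=image, timesteps=t_batch, controlnet_cond=cn_cond, context=conditioning
        )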
monai/inferers/merger.py CHANGED
@@ -53,8 +53,11 @@ class Merger(ABC):
     cropped_shape: Sequence[int] | None = None,
     device: torch.device | str | None = None,
 ) -> None:
-    self.merged_shape = merged_shape
-    self.cropped_shape = self.merged_shape if cropped_shape is None else cropped_shape
+    if merged_shape is None:
+        raise ValueError("Argument `merged_shape` must be provided")
+
+    self.merged_shape: tuple[int, ...] = tuple(merged_shape)
+    self.cropped_shape: tuple[int, ...] = tuple(self.merged_shape if cropped_shape is None else cropped_shape)
     self.device = device
     self.is_finalized = False
 
@@ -231,9 +234,9 @@ class ZarrAvgMerger(Merger):
     dtype: np.dtype | str = "float32",
     value_dtype: np.dtype | str = "float32",
     count_dtype: np.dtype | str = "uint8",
-    store: zarr.storage.Store | str = "merged.zarr",
-    value_store: zarr.storage.Store | str | None = None,
-    count_store: zarr.storage.Store | str | None = None,
+    store: zarr.storage.Store | str = "merged.zarr",  # type: ignore
+    value_store: zarr.storage.Store | str | None = None,  # type: ignore
+    count_store: zarr.storage.Store | str | None = None,  # type: ignore
     compressor: str | None = None,
     value_compressor: str | None = None,
     count_compressor: str | None = None,
@@ -251,18 +254,18 @@ class ZarrAvgMerger(Merger):
     if version_geq(get_package_version("zarr"), "3.0.0"):
         if value_store is None:
             self.tmpdir = TemporaryDirectory()
-            self.value_store = zarr.storage.LocalStore(self.tmpdir.name)
+            self.value_store = zarr.storage.LocalStore(self.tmpdir.name)  # type: ignore
         else:
-            self.value_store = value_store
+            self.value_store = value_store  # type: ignore
         if count_store is None:
             self.tmpdir = TemporaryDirectory()
-            self.count_store = zarr.storage.LocalStore(self.tmpdir.name)
+            self.count_store = zarr.storage.LocalStore(self.tmpdir.name)  # type: ignore
         else:
-            self.count_store = count_store
+            self.count_store = count_store  # type: ignore
     else:
         self.tmpdir = None
-        self.value_store = zarr.storage.TempStore() if value_store is None else value_store
-        self.count_store = zarr.storage.TempStore() if count_store is None else count_store
+        self.value_store = zarr.storage.TempStore() if value_store is None else value_store  # type: ignore
+        self.count_store = zarr.storage.TempStore() if count_store is None else count_store  # type: ignore
     self.chunks = chunks
     self.compressor = compressor
     self.value_compressor = value_compressor
@@ -314,7 +317,7 @@ class ZarrAvgMerger(Merger):
     map_slice = ensure_tuple_size(map_slice, values.ndim, pad_val=slice(None), pad_from_start=True)
     with self.lock:
         self.values[map_slice] += values.numpy()
-        self.counts[map_slice] += 1
+        self.counts[map_slice] += 1  # type: ignore[operator]
 
 def finalize(self) -> zarr.Array:
     """
@@ -332,7 +335,7 @@ class ZarrAvgMerger(Merger):
     if not self.is_finalized:
         # use chunks for division to fit into memory
         for chunk in iterate_over_chunks(self.values.chunks, self.values.cdata_shape):
-            self.output[chunk] = self.values[chunk] / self.counts[chunk]
+            self.output[chunk] = self.values[chunk] / self.counts[chunk]  # type: ignore[operator]
         # finalize the shape
         self.output.resize(self.cropped_shape)
         # set finalize flag to protect performing in-place division again
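With the change to `Merger.__init__`, a missing `merged_shape` now fails fast and both shapes are stored as tuples. A minimal usage sketch, using `AvgMerger` as an illustrative subclass:

    from monai.inferers import AvgMerger

    merger = AvgMerger(merged_shape=(1, 3, 64, 64))  # merged_shape/cropped_shape are normalized to tuples
    # AvgMerger(merged_shape=None) now raises ValueError: Argument `merged_shape` must be provided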
monai/losses/perceptual.py CHANGED
@@ -374,7 +374,7 @@ class TorchvisionModelPerceptualSimilarity(nn.Module):
     else:
         network = torchvision.models.resnet50(weights=None)
     if pretrained is True:
-        state_dict = torch.load(pretrained_path)
+        state_dict = torch.load(pretrained_path, weights_only=True)
         if pretrained_state_dict_key is not None:
             state_dict = state_dict[pretrained_state_dict_key]
         network.load_state_dict(state_dict)
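This is one of several call sites in this release that switch `torch.load` to `weights_only=True`, which restricts unpickling to tensors and other allow-listed types (the keyword is available in PyTorch 1.13 and later). The shared pattern, with a placeholder checkpoint path:

    state_dict = torch.load("/path/to/checkpoint.pth", map_location="cpu", weights_only=True)
    network.load_state_dict(state_dict)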
monai/losses/sure_loss.py CHANGED
@@ -92,7 +92,7 @@ def sure_loss_function(
     y_ref = operator(x)
 
     # get perturbed output
-    x_perturbed = x + eps * perturb_noise
+    x_perturbed = x + eps * perturb_noise  # type: ignore
     y_perturbed = operator(x_perturbed)
     # divergence
     divergence = torch.sum(1.0 / eps * torch.matmul(perturb_noise.permute(0, 1, 3, 2), y_perturbed - y_ref))  # type: ignore
monai/networks/blocks/crossattention.py CHANGED
@@ -17,7 +17,7 @@ import torch
 import torch.nn as nn
 
 from monai.networks.layers.utils import get_rel_pos_embedding_layer
-from monai.utils import optional_import, pytorch_after
+from monai.utils import optional_import
 
 Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
 
@@ -84,11 +84,6 @@ class CrossAttentionBlock(nn.Module):
     if causal and sequence_length is None:
         raise ValueError("sequence_length is necessary for causal attention.")
 
-    if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0):
-        raise ValueError(
-            "use_flash_attention is only supported for PyTorch versions >= 2.0."
-            "Upgrade your PyTorch or set the flag to False."
-        )
     if use_flash_attention and save_attn:
         raise ValueError(
             "save_attn has been set to True, but use_flash_attention is also set"
monai/networks/blocks/feature_pyramid_network.py CHANGED
@@ -54,7 +54,9 @@ from __future__ import annotations
 
 from collections import OrderedDict
 from collections.abc import Callable
+from typing import cast
 
+import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
 
@@ -194,8 +196,8 @@ class FeaturePyramidNetwork(nn.Module):
     conv_type_: type[nn.Module] = Conv[Conv.CONV, spatial_dims]
     for m in self.modules():
         if isinstance(m, conv_type_):
-            nn.init.kaiming_uniform_(m.weight, a=1)
-            nn.init.constant_(m.bias, 0.0)
+            nn.init.kaiming_uniform_(cast(torch.Tensor, m.weight), a=1)
+            nn.init.constant_(cast(torch.Tensor, m.bias), 0.0)
 
     if extra_blocks is not None:
         if not isinstance(extra_blocks, ExtraFPNBlock):
monai/networks/blocks/selfattention.py CHANGED
@@ -18,7 +18,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from monai.networks.layers.utils import get_rel_pos_embedding_layer
-from monai.utils import optional_import, pytorch_after
+from monai.utils import optional_import
 
 Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
 
@@ -90,11 +90,6 @@ class SABlock(nn.Module):
     if causal and sequence_length is None:
         raise ValueError("sequence_length is necessary for causal attention.")
 
-    if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0):
-        raise ValueError(
-            "use_flash_attention is only supported for PyTorch versions >= 2.0."
-            "Upgrade your PyTorch or set the flag to False."
-        )
     if use_flash_attention and save_attn:
         raise ValueError(
             "save_attn has been set to True, but use_flash_attention is also set"
monai/networks/blocks/upsample.py CHANGED
@@ -17,8 +17,8 @@ import torch
 import torch.nn as nn
 
 from monai.networks.layers.factories import Conv, Pad, Pool
-from monai.networks.utils import CastTempType, icnr_init, pixelshuffle
-from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option, pytorch_after
+from monai.networks.utils import icnr_init, pixelshuffle
+from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option
 
 __all__ = ["Upsample", "UpSample", "SubpixelUpsample", "Subpixelupsample", "SubpixelUpSample"]
 
@@ -164,15 +164,7 @@ class UpSample(nn.Sequential):
         align_corners=align_corners,
     )
 
-    # Cast to float32 as 'upsample_nearest2d_out_frame' op does not support bfloat16
-    # https://github.com/pytorch/pytorch/issues/86679. This issue is solved in PyTorch 2.1
-    if pytorch_after(major=2, minor=1):
-        self.add_module("upsample_non_trainable", upsample)
-    else:
-        self.add_module(
-            "upsample_non_trainable",
-            CastTempType(initial_type=torch.bfloat16, temporary_type=torch.float32, submodule=upsample),
-        )
+    self.add_module("upsample_non_trainable", upsample)
     if post_conv:
         self.add_module("postconv", post_conv)
 elif up_mode == UpsampleMode.PIXELSHUFFLE:
monai/networks/layers/vector_quantizer.py CHANGED
@@ -100,7 +100,7 @@ class EMAQuantizer(nn.Module):
         torch.Tensor: Quantization indices of shape [B,H,W,D,1]
 
     """
-    with torch.cuda.amp.autocast(enabled=False):
+    with torch.autocast("cuda", enabled=False):
         encoding_indices_view = list(inputs.shape)
         del encoding_indices_view[1]
 
@@ -138,7 +138,7 @@ class EMAQuantizer(nn.Module):
     Returns:
         torch.Tensor: Quantize space representation of encoding_indices in channel first format.
     """
-    with torch.cuda.amp.autocast(enabled=False):
+    with torch.autocast("cuda", enabled=False):
         embedding: torch.Tensor = (
             self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous()
         )
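Both hunks replace the deprecated `torch.cuda.amp.autocast(enabled=False)` context manager with the device-agnostic `torch.autocast("cuda", enabled=False)` form; the effect is the same, keeping the block in full precision even inside an outer autocast region. A minimal sketch:

    import torch

    x = torch.rand(2, 3, device="cuda") if torch.cuda.is_available() else torch.rand(2, 3)
    with torch.autocast("cuda", enabled=False):
        y = x @ x.t()  # runs in x's original dtype rather than the autocast dtype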
monai/networks/nets/hovernet.py CHANGED
@@ -633,9 +633,9 @@ def _remap_preact_resnet_model(model_url: str):
     # download the pretrained weights into torch hub's default dir
     weights_dir = os.path.join(torch.hub.get_dir(), "preact-resnet50.pth")
     download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False)
-    state_dict = torch.load(weights_dir, map_location=None if torch.cuda.is_available() else torch.device("cpu"))[
-        "desc"
-    ]
+    map_location = None if torch.cuda.is_available() else torch.device("cpu")
+    state_dict = torch.load(weights_dir, map_location=map_location, weights_only=True)["desc"]
+
     for key in list(state_dict.keys()):
         new_key = None
         if pattern_conv0.match(key):
@@ -668,7 +668,8 @@ def _remap_standard_resnet_model(model_url: str, state_dict_key: str | None = No
     # download the pretrained weights into torch hub's default dir
     weights_dir = os.path.join(torch.hub.get_dir(), "resnet50.pth")
     download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False)
-    state_dict = torch.load(weights_dir, map_location=None if torch.cuda.is_available() else torch.device("cpu"))
+    map_location = None if torch.cuda.is_available() else torch.device("cpu")
+    state_dict = torch.load(weights_dir, map_location=map_location, weights_only=True)
     if state_dict_key is not None:
         state_dict = state_dict[state_dict_key]
 
monai/networks/nets/resnet.py CHANGED
@@ -493,7 +493,7 @@ def _resnet(
     if isinstance(pretrained, str):
         if Path(pretrained).exists():
             logger.info(f"Loading weights from {pretrained}...")
-            model_state_dict = torch.load(pretrained, map_location=device)
+            model_state_dict = torch.load(pretrained, map_location=device, weights_only=True)
         else:
             # Throw error
             raise FileNotFoundError("The pretrained checkpoint file is not found")
@@ -665,7 +665,7 @@ def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", dat
             raise EntryNotFoundError(
                 f"{filename} not found on {medicalnet_huggingface_repo_basename}{resnet_depth}"
             ) from None
         checkpoint = torch.load(pretrained_path, map_location=torch.device(device))
-        checkpoint = torch.load(pretrained_path, map_location=torch.device(device))
+        checkpoint = torch.load(pretrained_path, map_location=torch.device(device), weights_only=True)
     else:
         raise NotImplementedError("Supported resnet_depth are: [10, 18, 34, 50, 101, 152, 200]")
     logger.info(f"{filename} downloaded")
monai/networks/nets/senet.py CHANGED
@@ -302,7 +302,7 @@ def _load_state_dict(model: nn.Module, arch: str, progress: bool):
 
     if isinstance(model_url, dict):
         download_url(model_url["url"], filepath=model_url["filename"])
-        state_dict = torch.load(model_url["filename"], map_location=None)
+        state_dict = torch.load(model_url["filename"], map_location=None, weights_only=True)
     else:
         state_dict = load_state_dict_from_url(model_url, progress=progress)
     for key in list(state_dict.keys()):
monai/networks/nets/swin_unetr.py CHANGED
@@ -272,53 +272,50 @@ class SwinUNETR(nn.Module):
         self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels)
 
     def load_from(self, weights):
+        layers1_0: BasicLayer = self.swinViT.layers1[0]  # type: ignore[assignment]
+        layers2_0: BasicLayer = self.swinViT.layers2[0]  # type: ignore[assignment]
+        layers3_0: BasicLayer = self.swinViT.layers3[0]  # type: ignore[assignment]
+        layers4_0: BasicLayer = self.swinViT.layers4[0]  # type: ignore[assignment]
+        wstate = weights["state_dict"]
+
         with torch.no_grad():
-            self.swinViT.patch_embed.proj.weight.copy_(weights["state_dict"]["module.patch_embed.proj.weight"])
-            self.swinViT.patch_embed.proj.bias.copy_(weights["state_dict"]["module.patch_embed.proj.bias"])
-            for bname, block in self.swinViT.layers1[0].blocks.named_children():
-                block.load_from(weights, n_block=bname, layer="layers1")
-            self.swinViT.layers1[0].downsample.reduction.weight.copy_(
-                weights["state_dict"]["module.layers1.0.downsample.reduction.weight"]
-            )
-            self.swinViT.layers1[0].downsample.norm.weight.copy_(
-                weights["state_dict"]["module.layers1.0.downsample.norm.weight"]
-            )
-            self.swinViT.layers1[0].downsample.norm.bias.copy_(
-                weights["state_dict"]["module.layers1.0.downsample.norm.bias"]
-            )
-            for bname, block in self.swinViT.layers2[0].blocks.named_children():
-                block.load_from(weights, n_block=bname, layer="layers2")
-            self.swinViT.layers2[0].downsample.reduction.weight.copy_(
-                weights["state_dict"]["module.layers2.0.downsample.reduction.weight"]
-            )
-            self.swinViT.layers2[0].downsample.norm.weight.copy_(
-                weights["state_dict"]["module.layers2.0.downsample.norm.weight"]
-            )
-            self.swinViT.layers2[0].downsample.norm.bias.copy_(
-                weights["state_dict"]["module.layers2.0.downsample.norm.bias"]
-            )
-            for bname, block in self.swinViT.layers3[0].blocks.named_children():
-                block.load_from(weights, n_block=bname, layer="layers3")
-            self.swinViT.layers3[0].downsample.reduction.weight.copy_(
-                weights["state_dict"]["module.layers3.0.downsample.reduction.weight"]
-            )
-            self.swinViT.layers3[0].downsample.norm.weight.copy_(
-                weights["state_dict"]["module.layers3.0.downsample.norm.weight"]
-            )
-            self.swinViT.layers3[0].downsample.norm.bias.copy_(
-                weights["state_dict"]["module.layers3.0.downsample.norm.bias"]
-            )
-            for bname, block in self.swinViT.layers4[0].blocks.named_children():
-                block.load_from(weights, n_block=bname, layer="layers4")
-            self.swinViT.layers4[0].downsample.reduction.weight.copy_(
-                weights["state_dict"]["module.layers4.0.downsample.reduction.weight"]
-            )
-            self.swinViT.layers4[0].downsample.norm.weight.copy_(
-                weights["state_dict"]["module.layers4.0.downsample.norm.weight"]
-            )
-            self.swinViT.layers4[0].downsample.norm.bias.copy_(
-                weights["state_dict"]["module.layers4.0.downsample.norm.bias"]
-            )
+            self.swinViT.patch_embed.proj.weight.copy_(wstate["module.patch_embed.proj.weight"])
+            self.swinViT.patch_embed.proj.bias.copy_(wstate["module.patch_embed.proj.bias"])
+            for bname, block in layers1_0.blocks.named_children():
+                block.load_from(weights, n_block=bname, layer="layers1")  # type: ignore[operator]
+
+            if layers1_0.downsample is not None:
+                d = layers1_0.downsample
+                d.reduction.weight.copy_(wstate["module.layers1.0.downsample.reduction.weight"])  # type: ignore
+                d.norm.weight.copy_(wstate["module.layers1.0.downsample.norm.weight"])  # type: ignore
+                d.norm.bias.copy_(wstate["module.layers1.0.downsample.norm.bias"])  # type: ignore
+
+            for bname, block in layers2_0.blocks.named_children():
+                block.load_from(weights, n_block=bname, layer="layers2")  # type: ignore[operator]
+
+            if layers2_0.downsample is not None:
+                d = layers2_0.downsample
+                d.reduction.weight.copy_(wstate["module.layers2.0.downsample.reduction.weight"])  # type: ignore
+                d.norm.weight.copy_(wstate["module.layers2.0.downsample.norm.weight"])  # type: ignore
+                d.norm.bias.copy_(wstate["module.layers2.0.downsample.norm.bias"])  # type: ignore
+
+            for bname, block in layers3_0.blocks.named_children():
+                block.load_from(weights, n_block=bname, layer="layers3")  # type: ignore[operator]
+
+            if layers3_0.downsample is not None:
+                d = layers3_0.downsample
+                d.reduction.weight.copy_(wstate["module.layers3.0.downsample.reduction.weight"])  # type: ignore
+                d.norm.weight.copy_(wstate["module.layers3.0.downsample.norm.weight"])  # type: ignore
+                d.norm.bias.copy_(wstate["module.layers3.0.downsample.norm.bias"])  # type: ignore
+
+            for bname, block in layers4_0.blocks.named_children():
+                block.load_from(weights, n_block=bname, layer="layers4")  # type: ignore[operator]
+
+            if layers4_0.downsample is not None:
+                d = layers4_0.downsample
+                d.reduction.weight.copy_(wstate["module.layers4.0.downsample.reduction.weight"])  # type: ignore
+                d.norm.weight.copy_(wstate["module.layers4.0.downsample.norm.weight"])  # type: ignore
+                d.norm.bias.copy_(wstate["module.layers4.0.downsample.norm.bias"])  # type: ignore
 
     @torch.jit.unused
     def _check_input_size(self, spatial_shape):
@@ -532,7 +529,7 @@ class WindowAttention(nn.Module):
     q = q * self.scale
     attn = q @ k.transpose(-2, -1)
     relative_position_bias = self.relative_position_bias_table[
-        self.relative_position_index.clone()[:n, :n].reshape(-1)
+        self.relative_position_index.clone()[:n, :n].reshape(-1)  # type: ignore[operator]
     ].reshape(n, n, -1)
     relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
     attn = attn + relative_position_bias.unsqueeze(0)
@@ -691,7 +688,7 @@ class SwinTransformerBlock(nn.Module):
     self.norm1.weight.copy_(weights["state_dict"][root + block_names[0]])
     self.norm1.bias.copy_(weights["state_dict"][root + block_names[1]])
     self.attn.relative_position_bias_table.copy_(weights["state_dict"][root + block_names[2]])
-    self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]])
+    self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]])  # type: ignore[operator]
     self.attn.qkv.weight.copy_(weights["state_dict"][root + block_names[4]])
     self.attn.qkv.bias.copy_(weights["state_dict"][root + block_names[5]])
     self.attn.proj.weight.copy_(weights["state_dict"][root + block_names[6]])
@@ -1118,7 +1115,7 @@ def filter_swinunetr(key, value):
     )
     ssl_weights_path = "./ssl_pretrained_weights.pth"
     download_url(resource, ssl_weights_path)
-    ssl_weights = torch.load(ssl_weights_path)["model"]
+    ssl_weights = torch.load(ssl_weights_path, weights_only=True)["model"]
 
     dst_dict, loaded, not_loaded = copy_model_state(model, ssl_weights, filter_func=filter_swinunetr)
 
monai/networks/nets/transchex.py CHANGED
@@ -43,7 +43,7 @@ class BertPreTrainedModel(nn.Module):
 
     def init_bert_weights(self, module):
         if isinstance(module, (nn.Linear, nn.Embedding)):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)  # type: ignore[union-attr,arg-type]
         elif isinstance(module, torch.nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
@@ -68,7 +68,8 @@ class BertPreTrainedModel(nn.Module):
     weights_path = cached_file(path_or_repo_id, filename, cache_dir=cache_dir)
     model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs)
     if state_dict is None and not from_tf:
-        state_dict = torch.load(weights_path, map_location="cpu" if not torch.cuda.is_available() else None)
+        map_location = "cpu" if not torch.cuda.is_available() else None
+        state_dict = torch.load(weights_path, map_location=map_location, weights_only=True)
     if from_tf:
         return load_tf_weights_in_bert(model, weights_path)
     old_keys = []
monai/networks/nets/vista3d.py CHANGED
@@ -315,7 +315,7 @@ class VISTA3D(nn.Module):
     """
     if auto_freeze != self.auto_freeze:
         if hasattr(self.image_encoder, "set_auto_grad"):
-            self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)
+            self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)  # type: ignore[operator]
         else:
             for param in self.image_encoder.parameters():
                 param.requires_grad = (not auto_freeze) and (not point_freeze)
@@ -325,7 +325,7 @@ class VISTA3D(nn.Module):
 
     if point_freeze != self.point_freeze:
         if hasattr(self.image_encoder, "set_auto_grad"):
-            self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)
+            self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)  # type: ignore[operator]
         else:
             for param in self.image_encoder.parameters():
                 param.requires_grad = (not auto_freeze) and (not point_freeze)
@@ -543,10 +543,10 @@ class PointMappingSAM(nn.Module):
     point_embedding = self.pe_layer.forward_with_coords(points, out_shape)  # type: ignore
     point_embedding[point_labels == -1] = 0.0
     point_embedding[point_labels == -1] += self.not_a_point_embed.weight
-    point_embedding[point_labels == 0] += self.point_embeddings[0].weight
-    point_embedding[point_labels == 1] += self.point_embeddings[1].weight
-    point_embedding[point_labels == 2] += self.point_embeddings[0].weight + self.special_class_embed.weight
-    point_embedding[point_labels == 3] += self.point_embeddings[1].weight + self.special_class_embed.weight
+    point_embedding[point_labels == 0] += self.point_embeddings[0].weight  # type: ignore[arg-type]
+    point_embedding[point_labels == 1] += self.point_embeddings[1].weight  # type: ignore[arg-type]
+    point_embedding[point_labels == 2] += self.point_embeddings[0].weight + self.special_class_embed.weight  # type: ignore[operator]
+    point_embedding[point_labels == 3] += self.point_embeddings[1].weight + self.special_class_embed.weight  # type: ignore[operator]
     output_tokens = self.mask_tokens.weight
 
     output_tokens = output_tokens.unsqueeze(0).expand(point_embedding.size(0), -1, -1)
@@ -884,7 +884,7 @@ class PositionEmbeddingRandom(nn.Module):
     coords = 2 * coords - 1
     # [bs=1,N=2,2] @ [2,128]
     # [bs=1, N=2, 128]
-    coords = coords @ self.positional_encoding_gaussian_matrix
+    coords = coords @ self.positional_encoding_gaussian_matrix  # type: ignore[operator]
     coords = 2 * np.pi * coords
     # outputs d_1 x ... x d_n x C shape
     # [bs=1, N=2, 128+128=256]
monai/networks/utils.py CHANGED
@@ -22,7 +22,7 @@ from collections import OrderedDict
 from collections.abc import Callable, Mapping, Sequence
 from contextlib import contextmanager
 from copy import deepcopy
-from typing import Any
+from typing import Any, Iterable
 
 import numpy as np
 import torch
@@ -1238,7 +1238,7 @@ class CastToFloat(torch.nn.Module):
 
     def forward(self, x):
         dtype = x.dtype
-        with torch.amp.autocast("cuda", enabled=False):
+        with torch.autocast("cuda", enabled=False):
             ret = self.mod.forward(x.to(torch.float32)).to(dtype)
         return ret
 
@@ -1255,7 +1255,7 @@ class CastToFloatAll(torch.nn.Module):
 
     def forward(self, *args):
         from_dtype = args[0].dtype
-        with torch.amp.autocast("cuda", enabled=False):
+        with torch.autocast("cuda", enabled=False):
             ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32))
         return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype)
 
@@ -1291,7 +1291,8 @@ def simple_replace(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable
     def expansion_fn(mod: nn.Module) -> nn.Module | None:
         if not isinstance(mod, base_t):
             return None
-        args = [getattr(mod, name, None) for name in mod.__constants__]
+        constants: Iterable = mod.__constants__  # type: ignore[assignment]
+        args = [getattr(mod, name, None) for name in constants]
         out = dest_t(*args)
         return out
 
monai/transforms/intensity/array.py CHANGED
@@ -1856,7 +1856,7 @@ class RandHistogramShift(RandomizableTransform):
     indices = ns.searchsorted(xp.reshape(-1), x.reshape(-1)) - 1
     indices = ns.clip(indices, 0, len(m) - 1)
 
-    f = (m[indices] * x.reshape(-1) + b[indices]).reshape(x.shape)
+    f: NdarrayOrTensor = (m[indices] * x.reshape(-1) + b[indices]).reshape(x.shape)
     f[x < xp[0]] = fp[0]
     f[x > xp[-1]] = fp[-1]
     return f