monai-weekly 1.5.dev2509__py3-none-any.whl → 1.5.dev2511__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/deepedit/interaction.py +1 -1
  4. monai/apps/deepgrow/interaction.py +1 -1
  5. monai/apps/detection/networks/retinanet_detector.py +1 -1
  6. monai/apps/detection/networks/retinanet_network.py +5 -5
  7. monai/apps/detection/utils/box_coder.py +2 -2
  8. monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +4 -0
  9. monai/apps/mmars/mmars.py +1 -1
  10. monai/apps/reconstruction/networks/blocks/varnetblock.py +1 -1
  11. monai/bundle/scripts.py +3 -4
  12. monai/data/dataset.py +2 -9
  13. monai/data/utils.py +1 -1
  14. monai/data/video_dataset.py +1 -1
  15. monai/engines/evaluator.py +11 -16
  16. monai/engines/trainer.py +11 -17
  17. monai/engines/utils.py +1 -1
  18. monai/engines/workflow.py +2 -2
  19. monai/fl/client/monai_algo.py +1 -1
  20. monai/handlers/checkpoint_loader.py +1 -1
  21. monai/inferers/inferer.py +33 -13
  22. monai/inferers/merger.py +16 -13
  23. monai/losses/perceptual.py +1 -1
  24. monai/losses/sure_loss.py +1 -1
  25. monai/networks/blocks/crossattention.py +1 -6
  26. monai/networks/blocks/feature_pyramid_network.py +4 -2
  27. monai/networks/blocks/selfattention.py +1 -6
  28. monai/networks/blocks/upsample.py +3 -11
  29. monai/networks/layers/vector_quantizer.py +2 -2
  30. monai/networks/nets/hovernet.py +5 -4
  31. monai/networks/nets/resnet.py +2 -2
  32. monai/networks/nets/senet.py +1 -1
  33. monai/networks/nets/swin_unetr.py +46 -49
  34. monai/networks/nets/transchex.py +3 -2
  35. monai/networks/nets/vista3d.py +7 -7
  36. monai/networks/schedulers/__init__.py +1 -0
  37. monai/networks/schedulers/rectified_flow.py +322 -0
  38. monai/networks/utils.py +5 -4
  39. monai/transforms/intensity/array.py +1 -1
  40. monai/transforms/spatial/array.py +6 -6
  41. monai/utils/misc.py +1 -1
  42. monai/utils/state_cacher.py +1 -1
  43. {monai_weekly-1.5.dev2509.dist-info → monai_weekly-1.5.dev2511.dist-info}/METADATA +4 -3
  44. {monai_weekly-1.5.dev2509.dist-info → monai_weekly-1.5.dev2511.dist-info}/RECORD +66 -64
  45. {monai_weekly-1.5.dev2509.dist-info → monai_weekly-1.5.dev2511.dist-info}/WHEEL +1 -1
  46. tests/bundle/test_bundle_download.py +16 -6
  47. tests/config/test_cv2_dist.py +1 -2
  48. tests/inferers/test_controlnet_inferers.py +96 -32
  49. tests/inferers/test_diffusion_inferer.py +99 -1
  50. tests/inferers/test_latent_diffusion_inferer.py +217 -211
  51. tests/integration/test_integration_bundle_run.py +2 -4
  52. tests/integration/test_integration_classification_2d.py +1 -1
  53. tests/integration/test_integration_fast_train.py +2 -2
  54. tests/integration/test_integration_segmentation_3d.py +1 -1
  55. tests/metrics/test_compute_multiscalessim_metric.py +3 -3
  56. tests/metrics/test_surface_dice.py +3 -3
  57. tests/networks/nets/test_autoencoderkl.py +1 -1
  58. tests/networks/nets/test_controlnet.py +1 -1
  59. tests/networks/nets/test_diffusion_model_unet.py +1 -1
  60. tests/networks/nets/test_network_consistency.py +1 -1
  61. tests/networks/nets/test_swin_unetr.py +1 -1
  62. tests/networks/nets/test_transformer.py +1 -1
  63. tests/networks/schedulers/test_scheduler_rflow.py +105 -0
  64. tests/networks/test_save_state.py +1 -1
  65. {monai_weekly-1.5.dev2509.dist-info → monai_weekly-1.5.dev2511.dist-info}/LICENSE +0 -0
  66. {monai_weekly-1.5.dev2509.dist-info → monai_weekly-1.5.dev2511.dist-info}/top_level.txt +0 -0

monai/networks/blocks/feature_pyramid_network.py
@@ -54,7 +54,9 @@ from __future__ import annotations
 
 from collections import OrderedDict
 from collections.abc import Callable
+from typing import cast
 
+import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
 
@@ -194,8 +196,8 @@ class FeaturePyramidNetwork(nn.Module):
         conv_type_: type[nn.Module] = Conv[Conv.CONV, spatial_dims]
         for m in self.modules():
             if isinstance(m, conv_type_):
-                nn.init.kaiming_uniform_(m.weight, a=1)
-                nn.init.constant_(m.bias, 0.0)
+                nn.init.kaiming_uniform_(cast(torch.Tensor, m.weight), a=1)
+                nn.init.constant_(cast(torch.Tensor, m.bias), 0.0)
 
         if extra_blocks is not None:
             if not isinstance(extra_blocks, ExtraFPNBlock):
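The `typing.cast` calls only inform the type checker that `m.weight` and `m.bias` are tensors; at runtime `cast` returns its argument unchanged. A minimal sketch of the same pattern (the `nn.Conv2d` layer is only an illustration, not part of this diff):

    from typing import cast

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(3, 8, kernel_size=3)
    # cast() is a no-op at runtime; it only tells the type checker the attribute is a Tensor.
    nn.init.kaiming_uniform_(cast(torch.Tensor, conv.weight), a=1)
    nn.init.constant_(cast(torch.Tensor, conv.bias), 0.0)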
monai/networks/blocks/selfattention.py
@@ -18,7 +18,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from monai.networks.layers.utils import get_rel_pos_embedding_layer
-from monai.utils import optional_import, pytorch_after
+from monai.utils import optional_import
 
 Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
 
@@ -90,11 +90,6 @@ class SABlock(nn.Module):
         if causal and sequence_length is None:
             raise ValueError("sequence_length is necessary for causal attention.")
 
-        if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0):
-            raise ValueError(
-                "use_flash_attention is only supported for PyTorch versions >= 2.0."
-                "Upgrade your PyTorch or set the flag to False."
-            )
         if use_flash_attention and save_attn:
             raise ValueError(
                 "save_attn has been set to True, but use_flash_attention is also set"
monai/networks/blocks/upsample.py
@@ -17,8 +17,8 @@ import torch
 import torch.nn as nn
 
 from monai.networks.layers.factories import Conv, Pad, Pool
-from monai.networks.utils import CastTempType, icnr_init, pixelshuffle
-from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option, pytorch_after
+from monai.networks.utils import icnr_init, pixelshuffle
+from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option
 
 __all__ = ["Upsample", "UpSample", "SubpixelUpsample", "Subpixelupsample", "SubpixelUpSample"]
 
@@ -164,15 +164,7 @@ class UpSample(nn.Sequential):
                 align_corners=align_corners,
             )
 
-            # Cast to float32 as 'upsample_nearest2d_out_frame' op does not support bfloat16
-            # https://github.com/pytorch/pytorch/issues/86679. This issue is solved in PyTorch 2.1
-            if pytorch_after(major=2, minor=1):
-                self.add_module("upsample_non_trainable", upsample)
-            else:
-                self.add_module(
-                    "upsample_non_trainable",
-                    CastTempType(initial_type=torch.bfloat16, temporary_type=torch.float32, submodule=upsample),
-                )
+            self.add_module("upsample_non_trainable", upsample)
             if post_conv:
                 self.add_module("postconv", post_conv)
         elif up_mode == UpsampleMode.PIXELSHUFFLE:
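The deleted branch worked around nearest-neighbour upsampling not supporting bfloat16 before PyTorch 2.1 (pytorch/pytorch#86679); removing it assumes the op now handles bfloat16 natively. A small check of that assumption, which should run on PyTorch >= 2.1:

    import torch
    import torch.nn as nn

    up = nn.Upsample(scale_factor=2, mode="nearest")
    x = torch.zeros(1, 1, 8, 8, dtype=torch.bfloat16)
    # No CastTempType(bfloat16 -> float32) wrapper needed on PyTorch >= 2.1.
    print(up(x).dtype)  # torch.bfloat16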
monai/networks/layers/vector_quantizer.py
@@ -100,7 +100,7 @@ class EMAQuantizer(nn.Module):
             torch.Tensor: Quantization indices of shape [B,H,W,D,1]
 
         """
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch.autocast("cuda", enabled=False):
             encoding_indices_view = list(inputs.shape)
             del encoding_indices_view[1]
 
@@ -138,7 +138,7 @@ class EMAQuantizer(nn.Module):
         Returns:
             torch.Tensor: Quantize space representation of encoding_indices in channel first format.
         """
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch.autocast("cuda", enabled=False):
             embedding: torch.Tensor = (
                 self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous()
            )
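`torch.cuda.amp.autocast(enabled=False)` and `torch.autocast("cuda", enabled=False)` behave the same here; the latter is the non-deprecated, device-agnostic context manager. A minimal sketch of the new spelling (the matmul is only an illustration):

    import torch

    x = torch.randn(4, 4)
    # Keep this region out of autocast so it runs in full precision.
    with torch.autocast("cuda", enabled=False):
        y = x.float() @ x.float().t()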
monai/networks/nets/hovernet.py
@@ -633,9 +633,9 @@ def _remap_preact_resnet_model(model_url: str):
     # download the pretrained weights into torch hub's default dir
     weights_dir = os.path.join(torch.hub.get_dir(), "preact-resnet50.pth")
     download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False)
-    state_dict = torch.load(weights_dir, map_location=None if torch.cuda.is_available() else torch.device("cpu"))[
-        "desc"
-    ]
+    map_location = None if torch.cuda.is_available() else torch.device("cpu")
+    state_dict = torch.load(weights_dir, map_location=map_location, weights_only=True)["desc"]
+
     for key in list(state_dict.keys()):
         new_key = None
         if pattern_conv0.match(key):
@@ -668,7 +668,8 @@ def _remap_standard_resnet_model(model_url: str, state_dict_key: str | None = No
     # download the pretrained weights into torch hub's default dir
     weights_dir = os.path.join(torch.hub.get_dir(), "resnet50.pth")
     download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False)
-    state_dict = torch.load(weights_dir, map_location=None if torch.cuda.is_available() else torch.device("cpu"))
+    map_location = None if torch.cuda.is_available() else torch.device("cpu")
+    state_dict = torch.load(weights_dir, map_location=map_location, weights_only=True)
 
     if state_dict_key is not None:
         state_dict = state_dict[state_dict_key]
monai/networks/nets/resnet.py
@@ -493,7 +493,7 @@ def _resnet(
         if isinstance(pretrained, str):
             if Path(pretrained).exists():
                 logger.info(f"Loading weights from {pretrained}...")
-                model_state_dict = torch.load(pretrained, map_location=device)
+                model_state_dict = torch.load(pretrained, map_location=device, weights_only=True)
             else:
                 # Throw error
                 raise FileNotFoundError("The pretrained checkpoint file is not found")
@@ -665,7 +665,7 @@ def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", dat
                 raise EntryNotFoundError(
                     f"{filename} not found on {medicalnet_huggingface_repo_basename}{resnet_depth}"
                 ) from None
-        checkpoint = torch.load(pretrained_path, map_location=torch.device(device))
+        checkpoint = torch.load(pretrained_path, map_location=torch.device(device), weights_only=True)
     else:
         raise NotImplementedError("Supported resnet_depth are: [10, 18, 34, 50, 101, 152, 200]")
     logger.info(f"{filename} downloaded")
monai/networks/nets/senet.py
@@ -302,7 +302,7 @@ def _load_state_dict(model: nn.Module, arch: str, progress: bool):
 
     if isinstance(model_url, dict):
         download_url(model_url["url"], filepath=model_url["filename"])
-        state_dict = torch.load(model_url["filename"], map_location=None)
+        state_dict = torch.load(model_url["filename"], map_location=None, weights_only=True)
     else:
         state_dict = load_state_dict_from_url(model_url, progress=progress)
     for key in list(state_dict.keys()):
monai/networks/nets/swin_unetr.py
@@ -272,53 +272,50 @@ class SwinUNETR(nn.Module):
         self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels)
 
     def load_from(self, weights):
+        layers1_0: BasicLayer = self.swinViT.layers1[0]  # type: ignore[assignment]
+        layers2_0: BasicLayer = self.swinViT.layers2[0]  # type: ignore[assignment]
+        layers3_0: BasicLayer = self.swinViT.layers3[0]  # type: ignore[assignment]
+        layers4_0: BasicLayer = self.swinViT.layers4[0]  # type: ignore[assignment]
+        wstate = weights["state_dict"]
+
         with torch.no_grad():
-            self.swinViT.patch_embed.proj.weight.copy_(weights["state_dict"]["module.patch_embed.proj.weight"])
-            self.swinViT.patch_embed.proj.bias.copy_(weights["state_dict"]["module.patch_embed.proj.bias"])
-            for bname, block in self.swinViT.layers1[0].blocks.named_children():
-                block.load_from(weights, n_block=bname, layer="layers1")
-            self.swinViT.layers1[0].downsample.reduction.weight.copy_(
-                weights["state_dict"]["module.layers1.0.downsample.reduction.weight"]
-            )
-            self.swinViT.layers1[0].downsample.norm.weight.copy_(
-                weights["state_dict"]["module.layers1.0.downsample.norm.weight"]
-            )
-            self.swinViT.layers1[0].downsample.norm.bias.copy_(
-                weights["state_dict"]["module.layers1.0.downsample.norm.bias"]
-            )
-            for bname, block in self.swinViT.layers2[0].blocks.named_children():
-                block.load_from(weights, n_block=bname, layer="layers2")
-            self.swinViT.layers2[0].downsample.reduction.weight.copy_(
-                weights["state_dict"]["module.layers2.0.downsample.reduction.weight"]
-            )
-            self.swinViT.layers2[0].downsample.norm.weight.copy_(
-                weights["state_dict"]["module.layers2.0.downsample.norm.weight"]
-            )
-            self.swinViT.layers2[0].downsample.norm.bias.copy_(
-                weights["state_dict"]["module.layers2.0.downsample.norm.bias"]
-            )
-            for bname, block in self.swinViT.layers3[0].blocks.named_children():
-                block.load_from(weights, n_block=bname, layer="layers3")
-            self.swinViT.layers3[0].downsample.reduction.weight.copy_(
-                weights["state_dict"]["module.layers3.0.downsample.reduction.weight"]
-            )
-            self.swinViT.layers3[0].downsample.norm.weight.copy_(
-                weights["state_dict"]["module.layers3.0.downsample.norm.weight"]
-            )
-            self.swinViT.layers3[0].downsample.norm.bias.copy_(
-                weights["state_dict"]["module.layers3.0.downsample.norm.bias"]
-            )
-            for bname, block in self.swinViT.layers4[0].blocks.named_children():
-                block.load_from(weights, n_block=bname, layer="layers4")
-            self.swinViT.layers4[0].downsample.reduction.weight.copy_(
-                weights["state_dict"]["module.layers4.0.downsample.reduction.weight"]
-            )
-            self.swinViT.layers4[0].downsample.norm.weight.copy_(
-                weights["state_dict"]["module.layers4.0.downsample.norm.weight"]
-            )
-            self.swinViT.layers4[0].downsample.norm.bias.copy_(
-                weights["state_dict"]["module.layers4.0.downsample.norm.bias"]
-            )
+            self.swinViT.patch_embed.proj.weight.copy_(wstate["module.patch_embed.proj.weight"])
+            self.swinViT.patch_embed.proj.bias.copy_(wstate["module.patch_embed.proj.bias"])
+            for bname, block in layers1_0.blocks.named_children():
+                block.load_from(weights, n_block=bname, layer="layers1")  # type: ignore[operator]
+
+            if layers1_0.downsample is not None:
+                d = layers1_0.downsample
+                d.reduction.weight.copy_(wstate["module.layers1.0.downsample.reduction.weight"])  # type: ignore
+                d.norm.weight.copy_(wstate["module.layers1.0.downsample.norm.weight"])  # type: ignore
+                d.norm.bias.copy_(wstate["module.layers1.0.downsample.norm.bias"])  # type: ignore
+
+            for bname, block in layers2_0.blocks.named_children():
+                block.load_from(weights, n_block=bname, layer="layers2")  # type: ignore[operator]
+
+            if layers2_0.downsample is not None:
+                d = layers2_0.downsample
+                d.reduction.weight.copy_(wstate["module.layers2.0.downsample.reduction.weight"])  # type: ignore
+                d.norm.weight.copy_(wstate["module.layers2.0.downsample.norm.weight"])  # type: ignore
+                d.norm.bias.copy_(wstate["module.layers2.0.downsample.norm.bias"])  # type: ignore
+
+            for bname, block in layers3_0.blocks.named_children():
+                block.load_from(weights, n_block=bname, layer="layers3")  # type: ignore[operator]
+
+            if layers3_0.downsample is not None:
+                d = layers3_0.downsample
+                d.reduction.weight.copy_(wstate["module.layers3.0.downsample.reduction.weight"])  # type: ignore
+                d.norm.weight.copy_(wstate["module.layers3.0.downsample.norm.weight"])  # type: ignore
+                d.norm.bias.copy_(wstate["module.layers3.0.downsample.norm.bias"])  # type: ignore
+
+            for bname, block in layers4_0.blocks.named_children():
+                block.load_from(weights, n_block=bname, layer="layers4")  # type: ignore[operator]
+
+            if layers4_0.downsample is not None:
+                d = layers4_0.downsample
+                d.reduction.weight.copy_(wstate["module.layers4.0.downsample.reduction.weight"])  # type: ignore
+                d.norm.weight.copy_(wstate["module.layers4.0.downsample.norm.weight"])  # type: ignore
+                d.norm.bias.copy_(wstate["module.layers4.0.downsample.norm.bias"])  # type: ignore
 
     @torch.jit.unused
     def _check_input_size(self, spatial_shape):
@@ -532,7 +529,7 @@ class WindowAttention(nn.Module):
         q = q * self.scale
         attn = q @ k.transpose(-2, -1)
         relative_position_bias = self.relative_position_bias_table[
-            self.relative_position_index.clone()[:n, :n].reshape(-1)
+            self.relative_position_index.clone()[:n, :n].reshape(-1)  # type: ignore[operator]
         ].reshape(n, n, -1)
         relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
         attn = attn + relative_position_bias.unsqueeze(0)
@@ -691,7 +688,7 @@ class SwinTransformerBlock(nn.Module):
             self.norm1.weight.copy_(weights["state_dict"][root + block_names[0]])
             self.norm1.bias.copy_(weights["state_dict"][root + block_names[1]])
             self.attn.relative_position_bias_table.copy_(weights["state_dict"][root + block_names[2]])
-            self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]])
+            self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]])  # type: ignore[operator]
             self.attn.qkv.weight.copy_(weights["state_dict"][root + block_names[4]])
             self.attn.qkv.bias.copy_(weights["state_dict"][root + block_names[5]])
             self.attn.proj.weight.copy_(weights["state_dict"][root + block_names[6]])
@@ -1118,7 +1115,7 @@ def filter_swinunetr(key, value):
         )
         ssl_weights_path = "./ssl_pretrained_weights.pth"
         download_url(resource, ssl_weights_path)
-        ssl_weights = torch.load(ssl_weights_path)["model"]
+        ssl_weights = torch.load(ssl_weights_path, weights_only=True)["model"]
 
         dst_dict, loaded, not_loaded = copy_model_state(model, ssl_weights, filter_func=filter_swinunetr)
 
monai/networks/nets/transchex.py
@@ -43,7 +43,7 @@ class BertPreTrainedModel(nn.Module):
 
     def init_bert_weights(self, module):
         if isinstance(module, (nn.Linear, nn.Embedding)):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)  # type: ignore[union-attr,arg-type]
         elif isinstance(module, torch.nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
@@ -68,7 +68,8 @@ class BertPreTrainedModel(nn.Module):
         weights_path = cached_file(path_or_repo_id, filename, cache_dir=cache_dir)
         model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs)
         if state_dict is None and not from_tf:
-            state_dict = torch.load(weights_path, map_location="cpu" if not torch.cuda.is_available() else None)
+            map_location = "cpu" if not torch.cuda.is_available() else None
+            state_dict = torch.load(weights_path, map_location=map_location, weights_only=True)
         if from_tf:
             return load_tf_weights_in_bert(model, weights_path)
         old_keys = []
monai/networks/nets/vista3d.py
@@ -315,7 +315,7 @@ class VISTA3D(nn.Module):
         """
         if auto_freeze != self.auto_freeze:
             if hasattr(self.image_encoder, "set_auto_grad"):
-                self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)
+                self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)  # type: ignore[operator]
             else:
                 for param in self.image_encoder.parameters():
                     param.requires_grad = (not auto_freeze) and (not point_freeze)
@@ -325,7 +325,7 @@ class VISTA3D(nn.Module):
 
         if point_freeze != self.point_freeze:
             if hasattr(self.image_encoder, "set_auto_grad"):
-                self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)
+                self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)  # type: ignore[operator]
             else:
                 for param in self.image_encoder.parameters():
                     param.requires_grad = (not auto_freeze) and (not point_freeze)
@@ -543,10 +543,10 @@ class PointMappingSAM(nn.Module):
         point_embedding = self.pe_layer.forward_with_coords(points, out_shape)  # type: ignore
         point_embedding[point_labels == -1] = 0.0
         point_embedding[point_labels == -1] += self.not_a_point_embed.weight
-        point_embedding[point_labels == 0] += self.point_embeddings[0].weight
-        point_embedding[point_labels == 1] += self.point_embeddings[1].weight
-        point_embedding[point_labels == 2] += self.point_embeddings[0].weight + self.special_class_embed.weight
-        point_embedding[point_labels == 3] += self.point_embeddings[1].weight + self.special_class_embed.weight
+        point_embedding[point_labels == 0] += self.point_embeddings[0].weight  # type: ignore[arg-type]
+        point_embedding[point_labels == 1] += self.point_embeddings[1].weight  # type: ignore[arg-type]
+        point_embedding[point_labels == 2] += self.point_embeddings[0].weight + self.special_class_embed.weight  # type: ignore[operator]
+        point_embedding[point_labels == 3] += self.point_embeddings[1].weight + self.special_class_embed.weight  # type: ignore[operator]
         output_tokens = self.mask_tokens.weight
 
         output_tokens = output_tokens.unsqueeze(0).expand(point_embedding.size(0), -1, -1)
@@ -884,7 +884,7 @@ class PositionEmbeddingRandom(nn.Module):
         coords = 2 * coords - 1
         # [bs=1,N=2,2] @ [2,128]
         # [bs=1, N=2, 128]
-        coords = coords @ self.positional_encoding_gaussian_matrix
+        coords = coords @ self.positional_encoding_gaussian_matrix  # type: ignore[operator]
         coords = 2 * np.pi * coords
         # outputs d_1 x ... x d_n x C shape
         # [bs=1, N=2, 128+128=256]
monai/networks/schedulers/__init__.py
@@ -14,4 +14,5 @@ from __future__ import annotations
 from .ddim import DDIMScheduler
 from .ddpm import DDPMScheduler
 from .pndm import PNDMScheduler
+from .rectified_flow import RFlowScheduler
 from .scheduler import NoiseSchedules, Scheduler
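With this re-export, the new scheduler is available directly from the package namespace:

    from monai.networks.schedulers import RFlowScheduler

    scheduler = RFlowScheduler(num_train_timesteps=1000)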
monai/networks/schedulers/rectified_flow.py (new file)
@@ -0,0 +1,322 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# =========================================================================
+# Adapted from https://github.com/hpcaitech/Open-Sora/blob/main/opensora/schedulers/rf/rectified_flow.py
+# which has the following license:
+# https://github.com/hpcaitech/Open-Sora/blob/main/LICENSE
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+
+from __future__ import annotations
+
+from typing import Union
+
+import numpy as np
+import torch
+from torch.distributions import LogisticNormal
+
+from monai.utils import StrEnum
+
+from .ddpm import DDPMPredictionType
+from .scheduler import Scheduler
+
+
+class RFlowPredictionType(StrEnum):
+    """
+    Set of valid prediction type names for the RFlow scheduler's `prediction_type` argument.
+
+    v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf
+    """
+
+    V_PREDICTION = DDPMPredictionType.V_PREDICTION
+
+
+def timestep_transform(
+    t, input_img_size_numel, base_img_size_numel=32 * 32 * 32, scale=1.0, num_train_timesteps=1000, spatial_dim=3
+):
+    """
+    Applies a transformation to the timestep based on image resolution scaling.
+
+    Args:
+        t (torch.Tensor): The original timestep(s).
+        input_img_size_numel (torch.Tensor): The input image's size (H * W * D).
+        base_img_size_numel (int): reference H*W*D size, usually smaller than input_img_size_numel.
+        scale (float): Scaling factor for the transformation.
+        num_train_timesteps (int): Total number of training timesteps.
+        spatial_dim (int): Number of spatial dimensions in the image.
+
+    Returns:
+        torch.Tensor: Transformed timestep(s).
+    """
+    t = t / num_train_timesteps
+    ratio_space = (input_img_size_numel / base_img_size_numel) ** (1.0 / spatial_dim)
+
+    ratio = ratio_space * scale
+    new_t = ratio * t / (1 + (ratio - 1) * t)
+
+    new_t = new_t * num_train_timesteps
+    return new_t
+
+
+class RFlowScheduler(Scheduler):
+    """
+    A rectified flow scheduler for guiding the diffusion process in a generative model.
+
+    Supports uniform and logit-normal sampling methods, timestep transformation for
+    different resolutions, and noise addition during diffusion.
+
+    Args:
+        num_train_timesteps (int): Total number of training timesteps.
+        use_discrete_timesteps (bool): Whether to use discrete timesteps.
+        sample_method (str): Training time step sampling method ('uniform' or 'logit-normal').
+        loc (float): Location parameter for logit-normal distribution, used only if sample_method='logit-normal'.
+        scale (float): Scale parameter for logit-normal distribution, used only if sample_method='logit-normal'.
+        use_timestep_transform (bool): Whether to apply timestep transformation.
+            If true, there will be more inference timesteps at early (noisy) stages for larger image volumes.
+        transform_scale (float): Scaling factor for timestep transformation, used only if use_timestep_transform=True.
+        steps_offset (int): Offset added to computed timesteps, used only if use_timestep_transform=True.
+        base_img_size_numel (int): Reference image volume size for scaling, used only if use_timestep_transform=True.
+        spatial_dim (int): 2 or 3, indicating 2D or 3D images, used only if use_timestep_transform=True.
+
+    Example:
+
+        .. code-block:: python
+
+            # define a scheduler
+            noise_scheduler = RFlowScheduler(
+                num_train_timesteps = 1000,
+                use_discrete_timesteps = True,
+                sample_method = 'logit-normal',
+                use_timestep_transform = True,
+                base_img_size_numel = 32 * 32 * 32,
+                spatial_dim = 3
+            )
+
+            # during training
+            inputs = torch.ones(2,4,64,64,32)
+            noise = torch.randn_like(inputs)
+            timesteps = noise_scheduler.sample_timesteps(inputs)
+            noisy_inputs = noise_scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
+            predicted_velocity = diffusion_unet(
+                x=noisy_inputs,
+                timesteps=timesteps
+            )
+            loss = loss_l1(predicted_velocity, (inputs - noise))
+
+            # during inference
+            noisy_inputs = torch.randn(2,4,64,64,32)
+            input_img_size_numel = torch.prod(torch.tensor(noisy_inputs.shape[-3:]))
+            noise_scheduler.set_timesteps(
+                num_inference_steps=30, input_img_size_numel=input_img_size_numel
+            )
+            all_next_timesteps = torch.cat(
+                (noise_scheduler.timesteps[1:], torch.tensor([0], dtype=noise_scheduler.timesteps.dtype))
+            )
+            for t, next_t in tqdm(
+                zip(noise_scheduler.timesteps, all_next_timesteps),
+                total=min(len(noise_scheduler.timesteps), len(all_next_timesteps)),
+            ):
+                predicted_velocity = diffusion_unet(
+                    x=noisy_inputs,
+                    timesteps=t
+                )
+                noisy_inputs, _ = noise_scheduler.step(predicted_velocity, t, noisy_inputs, next_t)
+            final_output = noisy_inputs
+    """
+
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        use_discrete_timesteps: bool = True,
+        sample_method: str = "uniform",
+        loc: float = 0.0,
+        scale: float = 1.0,
+        use_timestep_transform: bool = False,
+        transform_scale: float = 1.0,
+        steps_offset: int = 0,
+        base_img_size_numel: int = 32 * 32 * 32,
+        spatial_dim: int = 3,
+    ):
+        # rectified flow only accepts velocity prediction
+        self.prediction_type = RFlowPredictionType.V_PREDICTION
+
+        self.num_train_timesteps = num_train_timesteps
+        self.use_discrete_timesteps = use_discrete_timesteps
+        self.base_img_size_numel = base_img_size_numel
+        self.spatial_dim = spatial_dim
+
+        # sample method
+        if sample_method not in ["uniform", "logit-normal"]:
+            raise ValueError(
+                f"sample_method = {sample_method}, which has to be chosen from ['uniform', 'logit-normal']."
+            )
+        self.sample_method = sample_method
+        if sample_method == "logit-normal":
+            self.distribution = LogisticNormal(torch.tensor([loc]), torch.tensor([scale]))
+            self.sample_t = lambda x: self.distribution.sample((x.shape[0],))[:, 0].to(x.device)
+
+        # timestep transform
+        self.use_timestep_transform = use_timestep_transform
+        self.transform_scale = transform_scale
+        self.steps_offset = steps_offset
+
+    def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
+        """
+        Add noise to the original samples.
+
+        Args:
+            original_samples: original samples
+            noise: noise to add to samples
+            timesteps: timesteps tensor with shape of (N,), indicating the timestep to be computed for each sample.
+
+        Returns:
+            noisy_samples: sample with added noise
+        """
+        timepoints: torch.Tensor = timesteps.float() / self.num_train_timesteps
+        timepoints = 1 - timepoints  # [1,1/1000]
+
+        # expand timepoint to noise shape
+        if noise.ndim == 5:
+            timepoints = timepoints[..., None, None, None, None].expand(-1, *noise.shape[1:])
+        elif noise.ndim == 4:
+            timepoints = timepoints[..., None, None, None].expand(-1, *noise.shape[1:])
+        else:
+            raise ValueError(f"noise tensor has to be 4D or 5D tensor, yet got shape of {noise.shape}")
+
+        noisy_samples: torch.Tensor = timepoints * original_samples + (1 - timepoints) * noise
+
+        return noisy_samples
+
+    def set_timesteps(
+        self,
+        num_inference_steps: int,
+        device: str | torch.device | None = None,
+        input_img_size_numel: int | None = None,
+    ) -> None:
+        """
+        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
+            device: target device to put the data.
+            input_img_size_numel: int, H*W*D of the image, used when self.use_timestep_transform is True.
+        """
+        if num_inference_steps > self.num_train_timesteps or num_inference_steps < 1:
+            raise ValueError(
+                f"`num_inference_steps`: {num_inference_steps} should be at least 1, "
+                "and cannot be larger than `self.num_train_timesteps`:"
+                f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                f" maximal {self.num_train_timesteps} timesteps."
+            )
+
+        self.num_inference_steps = num_inference_steps
+        # prepare timesteps
+        timesteps = [
+            (1.0 - i / self.num_inference_steps) * self.num_train_timesteps for i in range(self.num_inference_steps)
+        ]
+        if self.use_discrete_timesteps:
+            timesteps = [int(round(t)) for t in timesteps]
+        if self.use_timestep_transform:
+            timesteps = [
+                timestep_transform(
+                    t,
+                    input_img_size_numel=input_img_size_numel,
+                    base_img_size_numel=self.base_img_size_numel,
+                    num_train_timesteps=self.num_train_timesteps,
+                    spatial_dim=self.spatial_dim,
+                )
+                for t in timesteps
+            ]
+        timesteps_np = np.array(timesteps).astype(np.float16)
+        if self.use_discrete_timesteps:
+            timesteps_np = timesteps_np.astype(np.int64)
+        self.timesteps = torch.from_numpy(timesteps_np).to(device)
+        self.timesteps += self.steps_offset
+
+    def sample_timesteps(self, x_start):
+        """
+        Randomly samples training timesteps using the chosen sampling method.
+
+        Args:
+            x_start (torch.Tensor): The input tensor for sampling.
+
+        Returns:
+            torch.Tensor: Sampled timesteps.
+        """
+        if self.sample_method == "uniform":
+            t = torch.rand((x_start.shape[0],), device=x_start.device) * self.num_train_timesteps
+        elif self.sample_method == "logit-normal":
+            t = self.sample_t(x_start) * self.num_train_timesteps
+
+        if self.use_discrete_timesteps:
+            t = t.long()
+
+        if self.use_timestep_transform:
+            input_img_size_numel = torch.prod(torch.tensor(x_start.shape[2:]))
+            t = timestep_transform(
+                t,
+                input_img_size_numel=input_img_size_numel,
+                base_img_size_numel=self.base_img_size_numel,
+                num_train_timesteps=self.num_train_timesteps,
+                spatial_dim=len(x_start.shape) - 2,
+            )
+
+        return t
+
+    def step(
+        self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep: Union[int, None] = None
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Predicts the next sample in the diffusion process.
+
+        Args:
+            model_output (torch.Tensor): Output from the trained diffusion model.
+            timestep (int): Current timestep in the diffusion chain.
+            sample (torch.Tensor): Current sample in the process.
+            next_timestep (Union[int, None]): Optional next timestep.
+
+        Returns:
+            tuple[torch.Tensor, torch.Tensor]: Predicted sample at the next step and additional info.
+        """
+        # Ensure num_inference_steps exists and is a valid integer
+        if not hasattr(self, "num_inference_steps") or not isinstance(self.num_inference_steps, int):
+            raise AttributeError(
+                "num_inference_steps is missing or not an integer in the class. "
+                "Please run self.set_timesteps(num_inference_steps,device,input_img_size_numel) to set it."
+            )
+
+        v_pred = model_output
+
+        if next_timestep is not None:
+            next_timestep = int(next_timestep)
+            dt: float = (
+                float(timestep - next_timestep) / self.num_train_timesteps
+            )  # Now next_timestep is guaranteed to be int
+        else:
+            dt = (
+                1.0 / float(self.num_inference_steps) if self.num_inference_steps > 0 else 0.0
+            )  # Avoid division by zero
+
+        pred_post_sample = sample + v_pred * dt
+        pred_original_sample = sample + v_pred * timestep / self.num_train_timesteps
+
+        return pred_post_sample, pred_original_sample
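For reference, `timestep_transform` above first normalises `t` by `num_train_timesteps`, then maps it through `new_t = r * t / (1 + (r - 1) * t)` with `r = (input_img_size_numel / base_img_size_numel) ** (1 / spatial_dim) * scale`, so larger volumes are steered toward the noisy end of the schedule. A small worked check against the code above:

    import torch

    from monai.networks.schedulers.rectified_flow import timestep_transform

    t = torch.tensor([250.0, 500.0, 750.0])
    # A 128^3 volume against the 32^3 reference gives r = (128**3 / 32**3) ** (1 / 3) = 4.
    print(timestep_transform(t, input_img_size_numel=128**3, base_img_size_numel=32**3, spatial_dim=3))
    # tensor([571.4286, 800.0000, 923.0769]) -- timesteps shift toward t = 1000 (pure noise).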