monai-weekly 1.5.dev2509__py3-none-any.whl → 1.5.dev2511__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/deepedit/interaction.py +1 -1
- monai/apps/deepgrow/interaction.py +1 -1
- monai/apps/detection/networks/retinanet_detector.py +1 -1
- monai/apps/detection/networks/retinanet_network.py +5 -5
- monai/apps/detection/utils/box_coder.py +2 -2
- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +4 -0
- monai/apps/mmars/mmars.py +1 -1
- monai/apps/reconstruction/networks/blocks/varnetblock.py +1 -1
- monai/bundle/scripts.py +3 -4
- monai/data/dataset.py +2 -9
- monai/data/utils.py +1 -1
- monai/data/video_dataset.py +1 -1
- monai/engines/evaluator.py +11 -16
- monai/engines/trainer.py +11 -17
- monai/engines/utils.py +1 -1
- monai/engines/workflow.py +2 -2
- monai/fl/client/monai_algo.py +1 -1
- monai/handlers/checkpoint_loader.py +1 -1
- monai/inferers/inferer.py +33 -13
- monai/inferers/merger.py +16 -13
- monai/losses/perceptual.py +1 -1
- monai/losses/sure_loss.py +1 -1
- monai/networks/blocks/crossattention.py +1 -6
- monai/networks/blocks/feature_pyramid_network.py +4 -2
- monai/networks/blocks/selfattention.py +1 -6
- monai/networks/blocks/upsample.py +3 -11
- monai/networks/layers/vector_quantizer.py +2 -2
- monai/networks/nets/hovernet.py +5 -4
- monai/networks/nets/resnet.py +2 -2
- monai/networks/nets/senet.py +1 -1
- monai/networks/nets/swin_unetr.py +46 -49
- monai/networks/nets/transchex.py +3 -2
- monai/networks/nets/vista3d.py +7 -7
- monai/networks/schedulers/__init__.py +1 -0
- monai/networks/schedulers/rectified_flow.py +322 -0
- monai/networks/utils.py +5 -4
- monai/transforms/intensity/array.py +1 -1
- monai/transforms/spatial/array.py +6 -6
- monai/utils/misc.py +1 -1
- monai/utils/state_cacher.py +1 -1
- {monai_weekly-1.5.dev2509.dist-info → monai_weekly-1.5.dev2511.dist-info}/METADATA +4 -3
- {monai_weekly-1.5.dev2509.dist-info → monai_weekly-1.5.dev2511.dist-info}/RECORD +66 -64
- {monai_weekly-1.5.dev2509.dist-info → monai_weekly-1.5.dev2511.dist-info}/WHEEL +1 -1
- tests/bundle/test_bundle_download.py +16 -6
- tests/config/test_cv2_dist.py +1 -2
- tests/inferers/test_controlnet_inferers.py +96 -32
- tests/inferers/test_diffusion_inferer.py +99 -1
- tests/inferers/test_latent_diffusion_inferer.py +217 -211
- tests/integration/test_integration_bundle_run.py +2 -4
- tests/integration/test_integration_classification_2d.py +1 -1
- tests/integration/test_integration_fast_train.py +2 -2
- tests/integration/test_integration_segmentation_3d.py +1 -1
- tests/metrics/test_compute_multiscalessim_metric.py +3 -3
- tests/metrics/test_surface_dice.py +3 -3
- tests/networks/nets/test_autoencoderkl.py +1 -1
- tests/networks/nets/test_controlnet.py +1 -1
- tests/networks/nets/test_diffusion_model_unet.py +1 -1
- tests/networks/nets/test_network_consistency.py +1 -1
- tests/networks/nets/test_swin_unetr.py +1 -1
- tests/networks/nets/test_transformer.py +1 -1
- tests/networks/schedulers/test_scheduler_rflow.py +105 -0
- tests/networks/test_save_state.py +1 -1
- {monai_weekly-1.5.dev2509.dist-info → monai_weekly-1.5.dev2511.dist-info}/LICENSE +0 -0
- {monai_weekly-1.5.dev2509.dist-info → monai_weekly-1.5.dev2511.dist-info}/top_level.txt +0 -0
monai/networks/blocks/feature_pyramid_network.py
CHANGED
@@ -54,7 +54,9 @@ from __future__ import annotations
 
 from collections import OrderedDict
 from collections.abc import Callable
+from typing import cast
 
+import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
 
@@ -194,8 +196,8 @@ class FeaturePyramidNetwork(nn.Module):
         conv_type_: type[nn.Module] = Conv[Conv.CONV, spatial_dims]
         for m in self.modules():
             if isinstance(m, conv_type_):
-                nn.init.kaiming_uniform_(m.weight, a=1)
-                nn.init.constant_(m.bias, 0.0)
+                nn.init.kaiming_uniform_(cast(torch.Tensor, m.weight), a=1)
+                nn.init.constant_(cast(torch.Tensor, m.bias), 0.0)
 
         if extra_blocks is not None:
             if not isinstance(extra_blocks, ExtraFPNBlock):
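The `cast` calls above only change what the static type checker sees; they are no-ops at runtime. A minimal standalone sketch of the same pattern (illustrative only, not the MONAI module):

```python
from typing import cast

import torch
import torch.nn as nn


class TinyConvNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        conv_type_: type[nn.Module] = nn.Conv2d
        for m in self.modules():
            if isinstance(m, conv_type_):
                # cast() is a no-op at runtime; it only narrows the static type
                # of m.weight / m.bias so nn.init accepts them without warnings.
                nn.init.kaiming_uniform_(cast(torch.Tensor, m.weight), a=1)
                nn.init.constant_(cast(torch.Tensor, m.bias), 0.0)


if __name__ == "__main__":
    out = TinyConvNet()(torch.randn(1, 3, 8, 8))
    print(out.shape)  # torch.Size([1, 8, 8, 8])
```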
monai/networks/blocks/selfattention.py
CHANGED
@@ -18,7 +18,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from monai.networks.layers.utils import get_rel_pos_embedding_layer
-from monai.utils import optional_import
+from monai.utils import optional_import
 
 Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
 
@@ -90,11 +90,6 @@ class SABlock(nn.Module):
         if causal and sequence_length is None:
             raise ValueError("sequence_length is necessary for causal attention.")
 
-        if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0):
-            raise ValueError(
-                "use_flash_attention is only supported for PyTorch versions >= 2.0."
-                "Upgrade your PyTorch or set the flag to False."
-            )
         if use_flash_attention and save_attn:
             raise ValueError(
                 "save_attn has been set to True, but use_flash_attention is also set"
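The removed guard rejected `use_flash_attention` on older PyTorch; the flash path in recent PyTorch is exposed through `torch.nn.functional.scaled_dot_product_attention`, available since 2.0. A standalone sketch of that call under those assumptions (not the MONAI block itself):

```python
import torch
import torch.nn.functional as F

# toy shapes: batch 2, 4 heads, sequence length 16, head dim 8
q = torch.randn(2, 4, 16, 8)
k = torch.randn(2, 4, 16, 8)
v = torch.randn(2, 4, 16, 8)

# fused attention entry point in PyTorch >= 2.0; falls back to a math
# implementation when no fused kernel applies (e.g. on CPU)
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([2, 4, 16, 8])
```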
monai/networks/blocks/upsample.py
CHANGED
@@ -17,8 +17,8 @@ import torch
 import torch.nn as nn
 
 from monai.networks.layers.factories import Conv, Pad, Pool
-from monai.networks.utils import
-from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option
+from monai.networks.utils import icnr_init, pixelshuffle
+from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option
 
 __all__ = ["Upsample", "UpSample", "SubpixelUpsample", "Subpixelupsample", "SubpixelUpSample"]
 
@@ -164,15 +164,7 @@ class UpSample(nn.Sequential):
                 align_corners=align_corners,
             )
 
-
-            # https://github.com/pytorch/pytorch/issues/86679. This issue is solved in PyTorch 2.1
-            if pytorch_after(major=2, minor=1):
-                self.add_module("upsample_non_trainable", upsample)
-            else:
-                self.add_module(
-                    "upsample_non_trainable",
-                    CastTempType(initial_type=torch.bfloat16, temporary_type=torch.float32, submodule=upsample),
-                )
+            self.add_module("upsample_non_trainable", upsample)
             if post_conv:
                 self.add_module("postconv", post_conv)
         elif up_mode == UpsampleMode.PIXELSHUFFLE:
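The removed branch worked around pytorch/pytorch#86679 (fixed in PyTorch 2.1, per the removed comment) by wrapping the non-trainable upsample in a dtype-casting helper; it is now registered directly. A tiny standalone sketch of registering a module on an `nn.Sequential` subclass the same way (illustrative, not the MONAI class):

```python
import torch
import torch.nn as nn


class TinyUpsample(nn.Sequential):
    def __init__(self, scale_factor: float = 2.0) -> None:
        super().__init__()
        upsample = nn.Upsample(scale_factor=scale_factor, mode="nearest")
        # registered directly; no casting wrapper is needed on PyTorch >= 2.1
        self.add_module("upsample_non_trainable", upsample)


if __name__ == "__main__":
    x = torch.randn(1, 3, 8, 8)
    print(TinyUpsample()(x).shape)  # torch.Size([1, 3, 16, 16])
```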
monai/networks/layers/vector_quantizer.py
CHANGED
@@ -100,7 +100,7 @@ class EMAQuantizer(nn.Module):
             torch.Tensor: Quantization indices of shape [B,H,W,D,1]
 
         """
-        with torch.
+        with torch.autocast("cuda", enabled=False):
             encoding_indices_view = list(inputs.shape)
             del encoding_indices_view[1]
 
@@ -138,7 +138,7 @@ class EMAQuantizer(nn.Module):
         Returns:
             torch.Tensor: Quantize space representation of encoding_indices in channel first format.
         """
-        with torch.
+        with torch.autocast("cuda", enabled=False):
             embedding: torch.Tensor = (
                 self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous()
             )
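Both quantizer hunks move to the `torch.autocast(device_type, enabled=False)` context manager, the current spelling for locally disabling mixed precision (the older `torch.cuda.amp.autocast(...)` form is deprecated in recent PyTorch). A minimal sketch of the pattern, assuming nothing about the MONAI class itself:

```python
import torch


def mse_in_full_precision(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # pick an autocast device type that exists on this machine
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    # disable autocast locally so this block always runs in full precision,
    # even when the surrounding code enabled mixed precision
    with torch.autocast(device_type, enabled=False):
        return torch.mean((pred.float() - target.float()) ** 2)


if __name__ == "__main__":
    a, b = torch.randn(4, 8), torch.randn(4, 8)
    print(mse_in_full_precision(a, b).item())
```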
monai/networks/nets/hovernet.py
CHANGED
@@ -633,9 +633,9 @@ def _remap_preact_resnet_model(model_url: str):
     # download the pretrained weights into torch hub's default dir
     weights_dir = os.path.join(torch.hub.get_dir(), "preact-resnet50.pth")
     download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False)
-
-
-
+    map_location = None if torch.cuda.is_available() else torch.device("cpu")
+    state_dict = torch.load(weights_dir, map_location=map_location, weights_only=True)["desc"]
+
     for key in list(state_dict.keys()):
         new_key = None
         if pattern_conv0.match(key):
@@ -668,7 +668,8 @@ def _remap_standard_resnet_model(model_url: str, state_dict_key: str | None = No
     # download the pretrained weights into torch hub's default dir
     weights_dir = os.path.join(torch.hub.get_dir(), "resnet50.pth")
     download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False)
-
+    map_location = None if torch.cuda.is_available() else torch.device("cpu")
+    state_dict = torch.load(weights_dir, map_location=map_location, weights_only=True)
     if state_dict_key is not None:
         state_dict = state_dict[state_dict_key]
 
monai/networks/nets/resnet.py
CHANGED
@@ -493,7 +493,7 @@ def _resnet(
     if isinstance(pretrained, str):
         if Path(pretrained).exists():
             logger.info(f"Loading weights from {pretrained}...")
-            model_state_dict = torch.load(pretrained, map_location=device)
+            model_state_dict = torch.load(pretrained, map_location=device, weights_only=True)
         else:
             # Throw error
             raise FileNotFoundError("The pretrained checkpoint file is not found")
@@ -665,7 +665,7 @@ def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", dat
             raise EntryNotFoundError(
                 f"{filename} not found on {medicalnet_huggingface_repo_basename}{resnet_depth}"
             ) from None
-        checkpoint = torch.load(pretrained_path, map_location=torch.device(device))
+        checkpoint = torch.load(pretrained_path, map_location=torch.device(device), weights_only=True)
     else:
         raise NotImplementedError("Supported resnet_depth are: [10, 18, 34, 50, 101, 152, 200]")
     logger.info(f"{filename} downloaded")
monai/networks/nets/senet.py
CHANGED
@@ -302,7 +302,7 @@ def _load_state_dict(model: nn.Module, arch: str, progress: bool):
 
     if isinstance(model_url, dict):
         download_url(model_url["url"], filepath=model_url["filename"])
-        state_dict = torch.load(model_url["filename"], map_location=None)
+        state_dict = torch.load(model_url["filename"], map_location=None, weights_only=True)
     else:
        state_dict = load_state_dict_from_url(model_url, progress=progress)
     for key in list(state_dict.keys()):
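The checkpoint-loading changes in hovernet.py, resnet.py, senet.py (and transchex.py below) all add `weights_only=True` to `torch.load`, which restricts unpickling to tensors and other allow-listed types and matches the default behaviour of PyTorch 2.6+. A small self-contained sketch of the pattern, with a hypothetical file name:

```python
import torch

# save a plain state dict (tensors only), then load it back safely
state = {"layer.weight": torch.randn(4, 4), "layer.bias": torch.zeros(4)}
torch.save(state, "example_weights.pth")  # hypothetical path for illustration

map_location = None if torch.cuda.is_available() else torch.device("cpu")
# weights_only=True refuses to unpickle arbitrary Python objects
loaded = torch.load("example_weights.pth", map_location=map_location, weights_only=True)
print(sorted(loaded.keys()))
```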
monai/networks/nets/swin_unetr.py
CHANGED
@@ -272,53 +272,50 @@ class SwinUNETR(nn.Module):
         self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels)
 
     def load_from(self, weights):
+        layers1_0: BasicLayer = self.swinViT.layers1[0]  # type: ignore[assignment]
+        layers2_0: BasicLayer = self.swinViT.layers2[0]  # type: ignore[assignment]
+        layers3_0: BasicLayer = self.swinViT.layers3[0]  # type: ignore[assignment]
+        layers4_0: BasicLayer = self.swinViT.layers4[0]  # type: ignore[assignment]
+        wstate = weights["state_dict"]
+
         with torch.no_grad():
-            self.swinViT.patch_embed.proj.weight.copy_(
-            self.swinViT.patch_embed.proj.bias.copy_(
-            for bname, block in
-                block.load_from(weights, n_block=bname, layer="layers1")
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            )
-
-
-
-
-
-
-
-
-
-            weights
-
-
-
-
-
-
-            self.swinViT.layers4[0].downsample.reduction.weight.copy_(
-                weights["state_dict"]["module.layers4.0.downsample.reduction.weight"]
-            )
-            self.swinViT.layers4[0].downsample.norm.weight.copy_(
-                weights["state_dict"]["module.layers4.0.downsample.norm.weight"]
-            )
-            self.swinViT.layers4[0].downsample.norm.bias.copy_(
-                weights["state_dict"]["module.layers4.0.downsample.norm.bias"]
-            )
+            self.swinViT.patch_embed.proj.weight.copy_(wstate["module.patch_embed.proj.weight"])
+            self.swinViT.patch_embed.proj.bias.copy_(wstate["module.patch_embed.proj.bias"])
+            for bname, block in layers1_0.blocks.named_children():
+                block.load_from(weights, n_block=bname, layer="layers1")  # type: ignore[operator]
+
+            if layers1_0.downsample is not None:
+                d = layers1_0.downsample
+                d.reduction.weight.copy_(wstate["module.layers1.0.downsample.reduction.weight"])  # type: ignore
+                d.norm.weight.copy_(wstate["module.layers1.0.downsample.norm.weight"])  # type: ignore
+                d.norm.bias.copy_(wstate["module.layers1.0.downsample.norm.bias"])  # type: ignore
+
+            for bname, block in layers2_0.blocks.named_children():
+                block.load_from(weights, n_block=bname, layer="layers2")  # type: ignore[operator]
+
+            if layers2_0.downsample is not None:
+                d = layers2_0.downsample
+                d.reduction.weight.copy_(wstate["module.layers2.0.downsample.reduction.weight"])  # type: ignore
+                d.norm.weight.copy_(wstate["module.layers2.0.downsample.norm.weight"])  # type: ignore
+                d.norm.bias.copy_(wstate["module.layers2.0.downsample.norm.bias"])  # type: ignore
+
+            for bname, block in layers3_0.blocks.named_children():
+                block.load_from(weights, n_block=bname, layer="layers3")  # type: ignore[operator]
+
+            if layers3_0.downsample is not None:
+                d = layers3_0.downsample
+                d.reduction.weight.copy_(wstate["module.layers3.0.downsample.reduction.weight"])  # type: ignore
+                d.norm.weight.copy_(wstate["module.layers3.0.downsample.norm.weight"])  # type: ignore
+                d.norm.bias.copy_(wstate["module.layers3.0.downsample.norm.bias"])  # type: ignore
+
+            for bname, block in layers4_0.blocks.named_children():
+                block.load_from(weights, n_block=bname, layer="layers4")  # type: ignore[operator]
+
+            if layers4_0.downsample is not None:
+                d = layers4_0.downsample
+                d.reduction.weight.copy_(wstate["module.layers4.0.downsample.reduction.weight"])  # type: ignore
+                d.norm.weight.copy_(wstate["module.layers4.0.downsample.norm.weight"])  # type: ignore
+                d.norm.bias.copy_(wstate["module.layers4.0.downsample.norm.bias"])  # type: ignore
 
     @torch.jit.unused
     def _check_input_size(self, spatial_shape):
@@ -532,7 +529,7 @@ class WindowAttention(nn.Module):
         q = q * self.scale
         attn = q @ k.transpose(-2, -1)
         relative_position_bias = self.relative_position_bias_table[
-            self.relative_position_index.clone()[:n, :n].reshape(-1)
+            self.relative_position_index.clone()[:n, :n].reshape(-1)  # type: ignore[operator]
         ].reshape(n, n, -1)
         relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
         attn = attn + relative_position_bias.unsqueeze(0)
@@ -691,7 +688,7 @@ class SwinTransformerBlock(nn.Module):
             self.norm1.weight.copy_(weights["state_dict"][root + block_names[0]])
             self.norm1.bias.copy_(weights["state_dict"][root + block_names[1]])
             self.attn.relative_position_bias_table.copy_(weights["state_dict"][root + block_names[2]])
-            self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]])
+            self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]])  # type: ignore[operator]
             self.attn.qkv.weight.copy_(weights["state_dict"][root + block_names[4]])
             self.attn.qkv.bias.copy_(weights["state_dict"][root + block_names[5]])
             self.attn.proj.weight.copy_(weights["state_dict"][root + block_names[6]])
@@ -1118,7 +1115,7 @@ def filter_swinunetr(key, value):
     )
     ssl_weights_path = "./ssl_pretrained_weights.pth"
     download_url(resource, ssl_weights_path)
-    ssl_weights = torch.load(ssl_weights_path)["model"]
+    ssl_weights = torch.load(ssl_weights_path, weights_only=True)["model"]
 
     dst_dict, loaded, not_loaded = copy_model_state(model, ssl_weights, filter_func=filter_swinunetr)
 
monai/networks/nets/transchex.py
CHANGED
@@ -43,7 +43,7 @@ class BertPreTrainedModel(nn.Module):
 
     def init_bert_weights(self, module):
         if isinstance(module, (nn.Linear, nn.Embedding)):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)  # type: ignore[union-attr,arg-type]
         elif isinstance(module, torch.nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
@@ -68,7 +68,8 @@ class BertPreTrainedModel(nn.Module):
         weights_path = cached_file(path_or_repo_id, filename, cache_dir=cache_dir)
         model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs)
         if state_dict is None and not from_tf:
-
+            map_location = "cpu" if not torch.cuda.is_available() else None
+            state_dict = torch.load(weights_path, map_location=map_location, weights_only=True)
         if from_tf:
             return load_tf_weights_in_bert(model, weights_path)
         old_keys = []
monai/networks/nets/vista3d.py
CHANGED
@@ -315,7 +315,7 @@ class VISTA3D(nn.Module):
         """
         if auto_freeze != self.auto_freeze:
             if hasattr(self.image_encoder, "set_auto_grad"):
-                self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)
+                self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)  # type: ignore[operator]
             else:
                 for param in self.image_encoder.parameters():
                     param.requires_grad = (not auto_freeze) and (not point_freeze)
@@ -325,7 +325,7 @@ class VISTA3D(nn.Module):
 
         if point_freeze != self.point_freeze:
             if hasattr(self.image_encoder, "set_auto_grad"):
-                self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)
+                self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)  # type: ignore[operator]
             else:
                 for param in self.image_encoder.parameters():
                     param.requires_grad = (not auto_freeze) and (not point_freeze)
@@ -543,10 +543,10 @@ class PointMappingSAM(nn.Module):
         point_embedding = self.pe_layer.forward_with_coords(points, out_shape)  # type: ignore
         point_embedding[point_labels == -1] = 0.0
         point_embedding[point_labels == -1] += self.not_a_point_embed.weight
-        point_embedding[point_labels == 0] += self.point_embeddings[0].weight
-        point_embedding[point_labels == 1] += self.point_embeddings[1].weight
-        point_embedding[point_labels == 2] += self.point_embeddings[0].weight + self.special_class_embed.weight
-        point_embedding[point_labels == 3] += self.point_embeddings[1].weight + self.special_class_embed.weight
+        point_embedding[point_labels == 0] += self.point_embeddings[0].weight  # type: ignore[arg-type]
+        point_embedding[point_labels == 1] += self.point_embeddings[1].weight  # type: ignore[arg-type]
+        point_embedding[point_labels == 2] += self.point_embeddings[0].weight + self.special_class_embed.weight  # type: ignore[operator]
+        point_embedding[point_labels == 3] += self.point_embeddings[1].weight + self.special_class_embed.weight  # type: ignore[operator]
         output_tokens = self.mask_tokens.weight
 
         output_tokens = output_tokens.unsqueeze(0).expand(point_embedding.size(0), -1, -1)
@@ -884,7 +884,7 @@ class PositionEmbeddingRandom(nn.Module):
         coords = 2 * coords - 1
         # [bs=1,N=2,2] @ [2,128]
         # [bs=1, N=2, 128]
-        coords = coords @ self.positional_encoding_gaussian_matrix
+        coords = coords @ self.positional_encoding_gaussian_matrix  # type: ignore[operator]
         coords = 2 * np.pi * coords
         # outputs d_1 x ... x d_n x C shape
         # [bs=1, N=2, 128+128=256]
monai/networks/schedulers/rectified_flow.py
ADDED
@@ -0,0 +1,322 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# =========================================================================
+# Adapted from https://github.com/hpcaitech/Open-Sora/blob/main/opensora/schedulers/rf/rectified_flow.py
+# which has the following license:
+# https://github.com/hpcaitech/Open-Sora/blob/main/LICENSE
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+
+from __future__ import annotations
+
+from typing import Union
+
+import numpy as np
+import torch
+from torch.distributions import LogisticNormal
+
+from monai.utils import StrEnum
+
+from .ddpm import DDPMPredictionType
+from .scheduler import Scheduler
+
+
+class RFlowPredictionType(StrEnum):
+    """
+    Set of valid prediction type names for the RFlow scheduler's `prediction_type` argument.
+
+    v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf
+    """
+
+    V_PREDICTION = DDPMPredictionType.V_PREDICTION
+
+
+def timestep_transform(
+    t, input_img_size_numel, base_img_size_numel=32 * 32 * 32, scale=1.0, num_train_timesteps=1000, spatial_dim=3
+):
+    """
+    Applies a transformation to the timestep based on image resolution scaling.
+
+    Args:
+        t (torch.Tensor): The original timestep(s).
+        input_img_size_numel (torch.Tensor): The input image's size (H * W * D).
+        base_img_size_numel (int): reference H*W*D size, usually smaller than input_img_size_numel.
+        scale (float): Scaling factor for the transformation.
+        num_train_timesteps (int): Total number of training timesteps.
+        spatial_dim (int): Number of spatial dimensions in the image.
+
+    Returns:
+        torch.Tensor: Transformed timestep(s).
+    """
+    t = t / num_train_timesteps
+    ratio_space = (input_img_size_numel / base_img_size_numel) ** (1.0 / spatial_dim)
+
+    ratio = ratio_space * scale
+    new_t = ratio * t / (1 + (ratio - 1) * t)
+
+    new_t = new_t * num_train_timesteps
+    return new_t
+
+
+class RFlowScheduler(Scheduler):
+    """
+    A rectified flow scheduler for guiding the diffusion process in a generative model.
+
+    Supports uniform and logit-normal sampling methods, timestep transformation for
+    different resolutions, and noise addition during diffusion.
+
+    Args:
+        num_train_timesteps (int): Total number of training timesteps.
+        use_discrete_timesteps (bool): Whether to use discrete timesteps.
+        sample_method (str): Training time step sampling method ('uniform' or 'logit-normal').
+        loc (float): Location parameter for logit-normal distribution, used only if sample_method='logit-normal'.
+        scale (float): Scale parameter for logit-normal distribution, used only if sample_method='logit-normal'.
+        use_timestep_transform (bool): Whether to apply timestep transformation.
+            If true, there will be more inference timesteps at early(noisy) stages for larger image volumes.
+        transform_scale (float): Scaling factor for timestep transformation, used only if use_timestep_transform=True.
+        steps_offset (int): Offset added to computed timesteps, used only if use_timestep_transform=True.
+        base_img_size_numel (int): Reference image volume size for scaling, used only if use_timestep_transform=True.
+        spatial_dim (int): 2 or 3, incidcating 2D or 3D images, used only if use_timestep_transform=True.
+
+    Example:
+
+        .. code-block:: python
+
+            # define a scheduler
+            noise_scheduler = RFlowScheduler(
+                num_train_timesteps = 1000,
+                use_discrete_timesteps = True,
+                sample_method = 'logit-normal',
+                use_timestep_transform = True,
+                base_img_size_numel = 32 * 32 * 32,
+                spatial_dim = 3
+            )
+
+            # during training
+            inputs = torch.ones(2,4,64,64,32)
+            noise = torch.randn_like(inputs)
+            timesteps = noise_scheduler.sample_timesteps(inputs)
+            noisy_inputs = noise_scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
+            predicted_velocity = diffusion_unet(
+                x=noisy_inputs,
+                timesteps=timesteps
+            )
+            loss = loss_l1(predicted_velocity, (inputs - noise))
+
+            # during inference
+            noisy_inputs = torch.randn(2,4,64,64,32)
+            input_img_size_numel = torch.prod(torch.tensor(noisy_inputs.shape[-3:])
+            noise_scheduler.set_timesteps(
+                num_inference_steps=30, input_img_size_numel=input_img_size_numel)
+            )
+            all_next_timesteps = torch.cat(
+                (noise_scheduler.timesteps[1:], torch.tensor([0], dtype=noise_scheduler.timesteps.dtype))
+            )
+            for t, next_t in tqdm(
+                zip(noise_scheduler.timesteps, all_next_timesteps),
+                total=min(len(noise_scheduler.timesteps), len(all_next_timesteps)),
+            ):
+                predicted_velocity = diffusion_unet(
+                    x=noisy_inputs,
+                    timesteps=timesteps
+                )
+                noisy_inputs, _ = noise_scheduler.step(predicted_velocity, t, noisy_inputs, next_t)
+            final_output = noisy_inputs
+    """
+
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        use_discrete_timesteps: bool = True,
+        sample_method: str = "uniform",
+        loc: float = 0.0,
+        scale: float = 1.0,
+        use_timestep_transform: bool = False,
+        transform_scale: float = 1.0,
+        steps_offset: int = 0,
+        base_img_size_numel: int = 32 * 32 * 32,
+        spatial_dim: int = 3,
+    ):
+        # rectified flow only accepts velocity prediction
+        self.prediction_type = RFlowPredictionType.V_PREDICTION
+
+        self.num_train_timesteps = num_train_timesteps
+        self.use_discrete_timesteps = use_discrete_timesteps
+        self.base_img_size_numel = base_img_size_numel
+        self.spatial_dim = spatial_dim
+
+        # sample method
+        if sample_method not in ["uniform", "logit-normal"]:
+            raise ValueError(
+                f"sample_method = {sample_method}, which has to be chosen from ['uniform', 'logit-normal']."
+            )
+        self.sample_method = sample_method
+        if sample_method == "logit-normal":
+            self.distribution = LogisticNormal(torch.tensor([loc]), torch.tensor([scale]))
+            self.sample_t = lambda x: self.distribution.sample((x.shape[0],))[:, 0].to(x.device)
+
+        # timestep transform
+        self.use_timestep_transform = use_timestep_transform
+        self.transform_scale = transform_scale
+        self.steps_offset = steps_offset
+
+    def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
+        """
+        Add noise to the original samples.
+
+        Args:
+            original_samples: original samples
+            noise: noise to add to samples
+            timesteps: timesteps tensor with shape of (N,), indicating the timestep to be computed for each sample.
+
+        Returns:
+            noisy_samples: sample with added noise
+        """
+        timepoints: torch.Tensor = timesteps.float() / self.num_train_timesteps
+        timepoints = 1 - timepoints  # [1,1/1000]
+
+        # expand timepoint to noise shape
+        if noise.ndim == 5:
+            timepoints = timepoints[..., None, None, None, None].expand(-1, *noise.shape[1:])
+        elif noise.ndim == 4:
+            timepoints = timepoints[..., None, None, None].expand(-1, *noise.shape[1:])
+        else:
+            raise ValueError(f"noise tensor has to be 4D or 5D tensor, yet got shape of {noise.shape}")
+
+        noisy_samples: torch.Tensor = timepoints * original_samples + (1 - timepoints) * noise
+
+        return noisy_samples
+
+    def set_timesteps(
+        self,
+        num_inference_steps: int,
+        device: str | torch.device | None = None,
+        input_img_size_numel: int | None = None,
+    ) -> None:
+        """
+        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
+            device: target device to put the data.
+            input_img_size_numel: int, H*W*D of the image, used with self.use_timestep_transform is True.
+        """
+        if num_inference_steps > self.num_train_timesteps or num_inference_steps < 1:
+            raise ValueError(
+                f"`num_inference_steps`: {num_inference_steps} should be at least 1, "
+                "and cannot be larger than `self.num_train_timesteps`:"
+                f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                f" maximal {self.num_train_timesteps} timesteps."
+            )
+
+        self.num_inference_steps = num_inference_steps
+        # prepare timesteps
+        timesteps = [
+            (1.0 - i / self.num_inference_steps) * self.num_train_timesteps for i in range(self.num_inference_steps)
+        ]
+        if self.use_discrete_timesteps:
+            timesteps = [int(round(t)) for t in timesteps]
+        if self.use_timestep_transform:
+            timesteps = [
+                timestep_transform(
+                    t,
+                    input_img_size_numel=input_img_size_numel,
+                    base_img_size_numel=self.base_img_size_numel,
+                    num_train_timesteps=self.num_train_timesteps,
+                    spatial_dim=self.spatial_dim,
+                )
+                for t in timesteps
+            ]
+        timesteps_np = np.array(timesteps).astype(np.float16)
+        if self.use_discrete_timesteps:
+            timesteps_np = timesteps_np.astype(np.int64)
+        self.timesteps = torch.from_numpy(timesteps_np).to(device)
+        self.timesteps += self.steps_offset
+
+    def sample_timesteps(self, x_start):
+        """
+        Randomly samples training timesteps using the chosen sampling method.
+
+        Args:
+            x_start (torch.Tensor): The input tensor for sampling.
+
+        Returns:
+            torch.Tensor: Sampled timesteps.
+        """
+        if self.sample_method == "uniform":
+            t = torch.rand((x_start.shape[0],), device=x_start.device) * self.num_train_timesteps
+        elif self.sample_method == "logit-normal":
+            t = self.sample_t(x_start) * self.num_train_timesteps
+
+        if self.use_discrete_timesteps:
+            t = t.long()
+
+        if self.use_timestep_transform:
+            input_img_size_numel = torch.prod(torch.tensor(x_start.shape[2:]))
+            t = timestep_transform(
+                t,
+                input_img_size_numel=input_img_size_numel,
+                base_img_size_numel=self.base_img_size_numel,
+                num_train_timesteps=self.num_train_timesteps,
+                spatial_dim=len(x_start.shape) - 2,
+            )
+
+        return t
+
+    def step(
+        self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep: Union[int, None] = None
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Predicts the next sample in the diffusion process.
+
+        Args:
+            model_output (torch.Tensor): Output from the trained diffusion model.
+            timestep (int): Current timestep in the diffusion chain.
+            sample (torch.Tensor): Current sample in the process.
+            next_timestep (Union[int, None]): Optional next timestep.
+
+        Returns:
+            tuple[torch.Tensor, torch.Tensor]: Predicted sample at the next step and additional info.
+        """
+        # Ensure num_inference_steps exists and is a valid integer
+        if not hasattr(self, "num_inference_steps") or not isinstance(self.num_inference_steps, int):
+            raise AttributeError(
+                "num_inference_steps is missing or not an integer in the class."
+                "Please run self.set_timesteps(num_inference_steps,device,input_img_size_numel) to set it."
+            )
+
+        v_pred = model_output
+
+        if next_timestep is not None:
+            next_timestep = int(next_timestep)
+            dt: float = (
+                float(timestep - next_timestep) / self.num_train_timesteps
+            )  # Now next_timestep is guaranteed to be int
+        else:
+            dt = (
+                1.0 / float(self.num_inference_steps) if self.num_inference_steps > 0 else 0.0
+            )  # Avoid division by zero
+
+        pred_post_sample = sample + v_pred * dt
+        pred_original_sample = sample + v_pred * timestep / self.num_train_timesteps
+
+        return pred_post_sample, pred_original_sample