monai-weekly 1.5.dev2509__py3-none-any.whl → 1.5.dev2511__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/deepedit/interaction.py +1 -1
- monai/apps/deepgrow/interaction.py +1 -1
- monai/apps/detection/networks/retinanet_detector.py +1 -1
- monai/apps/detection/networks/retinanet_network.py +5 -5
- monai/apps/detection/utils/box_coder.py +2 -2
- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +4 -0
- monai/apps/mmars/mmars.py +1 -1
- monai/apps/reconstruction/networks/blocks/varnetblock.py +1 -1
- monai/bundle/scripts.py +3 -4
- monai/data/dataset.py +2 -9
- monai/data/utils.py +1 -1
- monai/data/video_dataset.py +1 -1
- monai/engines/evaluator.py +11 -16
- monai/engines/trainer.py +11 -17
- monai/engines/utils.py +1 -1
- monai/engines/workflow.py +2 -2
- monai/fl/client/monai_algo.py +1 -1
- monai/handlers/checkpoint_loader.py +1 -1
- monai/inferers/inferer.py +33 -13
- monai/inferers/merger.py +16 -13
- monai/losses/perceptual.py +1 -1
- monai/losses/sure_loss.py +1 -1
- monai/networks/blocks/crossattention.py +1 -6
- monai/networks/blocks/feature_pyramid_network.py +4 -2
- monai/networks/blocks/selfattention.py +1 -6
- monai/networks/blocks/upsample.py +3 -11
- monai/networks/layers/vector_quantizer.py +2 -2
- monai/networks/nets/hovernet.py +5 -4
- monai/networks/nets/resnet.py +2 -2
- monai/networks/nets/senet.py +1 -1
- monai/networks/nets/swin_unetr.py +46 -49
- monai/networks/nets/transchex.py +3 -2
- monai/networks/nets/vista3d.py +7 -7
- monai/networks/schedulers/__init__.py +1 -0
- monai/networks/schedulers/rectified_flow.py +322 -0
- monai/networks/utils.py +5 -4
- monai/transforms/intensity/array.py +1 -1
- monai/transforms/spatial/array.py +6 -6
- monai/utils/misc.py +1 -1
- monai/utils/state_cacher.py +1 -1
- {monai_weekly-1.5.dev2509.dist-info → monai_weekly-1.5.dev2511.dist-info}/METADATA +4 -3
- {monai_weekly-1.5.dev2509.dist-info → monai_weekly-1.5.dev2511.dist-info}/RECORD +66 -64
- {monai_weekly-1.5.dev2509.dist-info → monai_weekly-1.5.dev2511.dist-info}/WHEEL +1 -1
- tests/bundle/test_bundle_download.py +16 -6
- tests/config/test_cv2_dist.py +1 -2
- tests/inferers/test_controlnet_inferers.py +96 -32
- tests/inferers/test_diffusion_inferer.py +99 -1
- tests/inferers/test_latent_diffusion_inferer.py +217 -211
- tests/integration/test_integration_bundle_run.py +2 -4
- tests/integration/test_integration_classification_2d.py +1 -1
- tests/integration/test_integration_fast_train.py +2 -2
- tests/integration/test_integration_segmentation_3d.py +1 -1
- tests/metrics/test_compute_multiscalessim_metric.py +3 -3
- tests/metrics/test_surface_dice.py +3 -3
- tests/networks/nets/test_autoencoderkl.py +1 -1
- tests/networks/nets/test_controlnet.py +1 -1
- tests/networks/nets/test_diffusion_model_unet.py +1 -1
- tests/networks/nets/test_network_consistency.py +1 -1
- tests/networks/nets/test_swin_unetr.py +1 -1
- tests/networks/nets/test_transformer.py +1 -1
- tests/networks/schedulers/test_scheduler_rflow.py +105 -0
- tests/networks/test_save_state.py +1 -1
- {monai_weekly-1.5.dev2509.dist-info → monai_weekly-1.5.dev2511.dist-info}/LICENSE +0 -0
- {monai_weekly-1.5.dev2509.dist-info → monai_weekly-1.5.dev2511.dist-info}/top_level.txt +0 -0
monai/__init__.py
CHANGED
monai/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json

 version_json = '''
 {
- "date": "2025-03-
+ "date": "2025-03-16T02:30:38+0000",
  "dirty": false,
  "error": null,
- "full-revisionid": "
- "version": "1.5.dev2509"
+ "full-revisionid": "7876647f87c763d854f9546bbc60e12f13af84a6",
+ "version": "1.5.dev2511"
 }
 ''' # END VERSION_JSON
@@ -67,7 +67,7 @@ class Interaction:
         engine.network.eval()
         with torch.no_grad():
             if engine.amp:
-                with torch.cuda.amp.autocast():
+                with torch.autocast("cuda"):
                     predictions = engine.inferer(inputs, engine.network)
             else:
                 predictions = engine.inferer(inputs, engine.network)
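A recurring change in this release replaces the deprecated `torch.cuda.amp.autocast()` context manager with the device-scoped `torch.autocast("cuda")`. A minimal sketch of the migration (the toy network and input are illustrative, not part of the diff):

    import torch
    import torch.nn as nn

    net = nn.Linear(8, 2).cuda()  # stand-in for engine.network
    x = torch.randn(4, 8, device="cuda")

    # before (emits a FutureWarning on recent PyTorch):
    #   with torch.cuda.amp.autocast():
    #       y = net(x)

    # after, as used across the engines and inferers in this diff:
    with torch.autocast("cuda"):
        y = net(x)  # eligible ops run in reduced precision under autocast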
monai/apps/detection/networks/retinanet_detector.py
CHANGED
@@ -180,7 +180,7 @@ class RetinaNetDetector(nn.Module):
             nesterov=True,
         )
         torch.save(detector.network.state_dict(), 'model.pt')  # save model
-        detector.network.load_state_dict(torch.load('model.pt'))  # load model
+        detector.network.load_state_dict(torch.load('model.pt', weights_only=True))  # load model
     """

     def __init__(
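This and several following call sites add `weights_only=True` to `torch.load`, which swaps the full pickle machinery for a restricted unpickler that only reconstructs tensors and allow-listed container types. A small sketch of the safe pattern (file name is illustrative):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    torch.save(model.state_dict(), "model.pt")

    # a plain state_dict round-trips fine under the restricted unpickler;
    # arbitrary pickled objects would raise instead of executing code
    state = torch.load("model.pt", weights_only=True)
    model.load_state_dict(state)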
monai/apps/detection/networks/retinanet_network.py
CHANGED
@@ -88,8 +88,8 @@ class RetinaNetClassificationHead(nn.Module):

         for layer in self.conv.children():
             if isinstance(layer, conv_type):  # type: ignore
-                torch.nn.init.normal_(layer.weight, std=0.01)
-                torch.nn.init.constant_(layer.bias, 0)
+                torch.nn.init.normal_(layer.weight, std=0.01)  # type: ignore[arg-type]
+                torch.nn.init.constant_(layer.bias, 0)  # type: ignore[arg-type]

         self.cls_logits = conv_type(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
         torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
@@ -167,8 +167,8 @@ class RetinaNetRegressionHead(nn.Module):

         for layer in self.conv.children():
             if isinstance(layer, conv_type):  # type: ignore
-                torch.nn.init.normal_(layer.weight, std=0.01)
-                torch.nn.init.zeros_(layer.bias)
+                torch.nn.init.normal_(layer.weight, std=0.01)  # type: ignore[arg-type]
+                torch.nn.init.zeros_(layer.bias)  # type: ignore[arg-type]

     def forward(self, x: list[Tensor]) -> list[Tensor]:
         """
@@ -297,7 +297,7 @@ class RetinaNet(nn.Module):
         )
         self.feature_extractor = feature_extractor

-        self.feature_map_channels: int = self.feature_extractor.out_channels
+        self.feature_map_channels: int = self.feature_extractor.out_channels  # type: ignore[assignment]
         self.num_anchors = num_anchors
         self.classification_head = RetinaNetClassificationHead(
             self.feature_map_channels, self.num_anchors, self.num_classes, spatial_dims=self.spatial_dims
monai/apps/detection/utils/box_coder.py
CHANGED
@@ -221,7 +221,7 @@ class BoxCoder:

         pred_ctr_xyx_axis = dxyz_axis * whd_axis[:, None] + ctr_xyz_axis[:, None]
         pred_whd_axis = torch.exp(dwhd_axis) * whd_axis[:, None]
-        pred_whd_axis = pred_whd_axis.to(dxyz_axis.dtype)
+        pred_whd_axis = pred_whd_axis.to(dxyz_axis.dtype)  # type: ignore[union-attr]

         # When convert float32 to float16, Inf or Nan may occur
         if torch.isnan(pred_whd_axis).any() or torch.isinf(pred_whd_axis).any():
@@ -229,7 +229,7 @@ class BoxCoder:

         # Distance from center to box's corner.
         c_to_c_whd_axis = (
-            torch.tensor(0.5, dtype=pred_ctr_xyx_axis.dtype, device=pred_whd_axis.device) * pred_whd_axis
+            torch.tensor(0.5, dtype=pred_ctr_xyx_axis.dtype, device=pred_whd_axis.device) * pred_whd_axis  # type: ignore[arg-type]
         )

         pred_boxes.append(pred_ctr_xyx_axis - c_to_c_whd_axis)
monai/apps/generation/maisi/networks/autoencoderkl_maisi.py
CHANGED
@@ -232,6 +232,10 @@ class MaisiConvolution(nn.Module):
         if self.print_info:
             logger.info(f"Number of splits: {self.num_splits}")

+        if self.dim_split <= 1 and self.num_splits <= 1:
+            x = self.conv(x)
+            return x
+
         # compute size of splits
         l = x.size(self.dim_split + 2)
         split_size = l // self.num_splits
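The new guard returns early when no splitting is configured, skipping the slice/concat machinery entirely. A rough sketch of the control flow, assuming attributes as in the hunk and ignoring the padding/overlap handling of the real MaisiConvolution:

    import torch
    import torch.nn as nn

    class SplitConvSketch(nn.Module):
        def __init__(self, num_splits: int = 1, dim_split: int = 0) -> None:
            super().__init__()
            self.conv = nn.Conv3d(1, 1, kernel_size=3, padding=1)
            self.num_splits = num_splits
            self.dim_split = dim_split

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # fast path added in this release: single split, no slicing needed
            if self.dim_split <= 1 and self.num_splits <= 1:
                return self.conv(x)
            # otherwise split along the chosen spatial dim and re-join
            splits = torch.chunk(x, self.num_splits, dim=self.dim_split + 2)
            return torch.cat([self.conv(s) for s in splits], dim=self.dim_split + 2)

    out = SplitConvSketch()(torch.randn(1, 1, 8, 8, 8))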
monai/apps/mmars/mmars.py
CHANGED
@@ -241,7 +241,7 @@ def load_from_mmar(
         return torch.jit.load(_model_file, map_location=map_location)

     # loading with `torch.load`
-    model_dict = torch.load(_model_file, map_location=map_location)
+    model_dict = torch.load(_model_file, map_location=map_location, weights_only=True)
     if weights_only:
         return model_dict.get(model_key, model_dict)  # model_dict[model_key] or model_dict directly
monai/apps/reconstruction/networks/blocks/varnetblock.py
CHANGED
@@ -55,7 +55,7 @@ class VarNetBlock(nn.Module):
         Returns:
             Output of DC block with the same shape as x
         """
-        return torch.where(mask, x - ref_kspace, self.zeros) * self.dc_weight
+        return torch.where(mask, x - ref_kspace, self.zeros) * self.dc_weight  # type: ignore

     def forward(self, current_kspace: Tensor, ref_kspace: Tensor, mask: Tensor, sens_maps: Tensor) -> Tensor:
         """
monai/bundle/scripts.py
CHANGED
@@ -760,7 +760,7 @@ def load(
     if load_ts_module is True:
         return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files)
     # loading with `torch.load`
-    model_dict = torch.load(full_path, map_location=torch.device(device))
+    model_dict = torch.load(full_path, map_location=torch.device(device), weights_only=True)

     if not isinstance(model_dict, Mapping):
         warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.")
@@ -1279,9 +1279,8 @@ def verify_net_in_out(
     if input_dtype == torch.float16:
         # fp16 can only be executed in gpu mode
         net.to("cuda")
-        from torch.cuda.amp import autocast

-        with autocast():
+        with torch.autocast("cuda"):
             output = net(test_data.cuda(), **extra_forward_args_)
         net.to(device_)
     else:
@@ -1330,7 +1329,7 @@ def _export(
         # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver
         Checkpoint.load_objects(to_load={key_in_ckpt: net}, checkpoint=ckpt_file)
     else:
-        ckpt = torch.load(ckpt_file)
+        ckpt = torch.load(ckpt_file, weights_only=True)
     copy_model_state(dst=net, src=ckpt if key_in_ckpt == "" else ckpt[key_in_ckpt])

     # Use the given converter to convert a model and save with metadata, config content
monai/data/dataset.py
CHANGED
@@ -22,7 +22,6 @@ import time
 import warnings
 from collections.abc import Callable, Sequence
 from copy import copy, deepcopy
-from inspect import signature
 from multiprocessing.managers import ListProxy
 from multiprocessing.pool import ThreadPool
 from pathlib import Path
@@ -372,10 +371,7 @@ class PersistentDataset(Dataset):

         if hashfile is not None and hashfile.is_file():  # cache hit
             try:
-                if "weights_only" in signature(torch.load).parameters:
-                    return torch.load(hashfile, weights_only=False)
-                else:
-                    return torch.load(hashfile)
+                return torch.load(hashfile, weights_only=False)
             except PermissionError as e:
                 if sys.platform != "win32":
                     raise e
@@ -1674,7 +1670,4 @@ class GDSDataset(PersistentDataset):
         if meta_hash_file_name in self._meta_cache:
             return self._meta_cache[meta_hash_file_name]
         else:
-            if "weights_only" in signature(torch.load).parameters:
-                return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
-            else:
-                return torch.load(self.cache_dir / meta_hash_file_name)
+            return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
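The dataset caches move in the opposite direction: they now pass `weights_only=False` unconditionally and drop the `signature(torch.load)` capability probe, which assumes a PyTorch new enough to accept the kwarg. Cache entries are arbitrary pickled structures (e.g. MONAI MetaTensor with metadata) that the restricted unpickler may refuse. A sketch of why the permissive form is needed here (file name is illustrative):

    import tempfile
    from pathlib import Path

    import torch

    cache_dir = Path(tempfile.mkdtemp())
    hashfile = cache_dir / "abc123.pt"  # hypothetical cache-hash file name
    item = {"image": torch.zeros(2, 2), "meta": {"filename": "case_001.nii.gz"}}
    torch.save(item, hashfile)

    # cache entries may hold arbitrary objects, so the restricted
    # weights_only=True unpickler is not an option here
    loaded = torch.load(hashfile, weights_only=False)
    assert loaded["meta"]["filename"] == "case_001.nii.gz"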
monai/data/utils.py
CHANGED
@@ -753,7 +753,7 @@ def affine_to_spacing(affine: NdarrayTensor, r: int = 3, dtype=float, suppress_z
     if isinstance(_affine, torch.Tensor):
         spacing = torch.sqrt(torch.sum(_affine * _affine, dim=0))
     else:
-        spacing = np.sqrt(np.sum(_affine * _affine, axis=0))
+        spacing = np.sqrt(np.sum(_affine * _affine, axis=0))  # type: ignore[operator]
     if suppress_zeros:
         spacing[spacing == 0] = 1.0
     spacing_, *_ = convert_to_dst_type(spacing, dst=affine, dtype=dtype)
monai/data/video_dataset.py
CHANGED
@@ -177,7 +177,7 @@ class VideoFileDataset(Dataset, VideoDataset):
         for codec, ext in all_codecs.items():
             writer = cv2.VideoWriter()
             fname = os.path.join(tmp_dir, f"test{ext}")
-            fourcc = cv2.VideoWriter_fourcc(*codec)
+            fourcc = cv2.VideoWriter_fourcc(*codec)  # type: ignore[attr-defined]
             noviderr = writer.open(fname, fourcc, 1, (10, 10))
             if noviderr:
                 codecs[codec] = ext
monai/engines/evaluator.py
CHANGED
@@ -28,7 +28,7 @@ from monai.transforms import Transform
 from monai.utils import ForwardMode, IgniteInfo, ensure_tuple, min_version, optional_import
 from monai.utils.enums import CommonKeys as Keys
 from monai.utils.enums import EngineStatsKeys as ESKeys
-from monai.utils.module import look_up_option, pytorch_after
+from monai.utils.module import look_up_option

 if TYPE_CHECKING:
     from ignite.engine import Engine, EventEnum
@@ -82,8 +82,8 @@ class Evaluator(Workflow):
             default to `True`.
         to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
             `device`, `non_blocking`.
-        amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
-            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
+        amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
+            https://pytorch.org/docs/stable/amp.html#torch.autocast.

     """

@@ -214,8 +214,8 @@ class SupervisedEvaluator(Evaluator):
             default to `True`.
         to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
             `device`, `non_blocking`.
-        amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
-            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
+        amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
+            https://pytorch.org/docs/stable/amp.html#torch.autocast.
         compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to
             `torch.Tensor` before forward pass, then converted back afterward with copied meta information.
         compile_kwargs: dict of the args for `torch.compile()` API, for more details:
@@ -269,13 +269,8 @@ class SupervisedEvaluator(Evaluator):
             amp_kwargs=amp_kwargs,
         )
         if compile:
-            if pytorch_after(2, 1):
-                compile_kwargs = {} if compile_kwargs is None else compile_kwargs
-                network = torch.compile(network, **compile_kwargs)  # type: ignore[assignment]
-            else:
-                warnings.warn(
-                    "Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done"
-                )
+            compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+            network = torch.compile(network, **compile_kwargs)  # type: ignore[assignment]
         self.network = network
         self.compile = compile
         self.inferer = SimpleInferer() if inferer is None else inferer
@@ -329,7 +324,7 @@ class SupervisedEvaluator(Evaluator):
         # execute forward computation
         with engine.mode(engine.network):
             if engine.amp:
-                with torch.cuda.amp.autocast(**engine.amp_kwargs):
+                with torch.autocast("cuda", **engine.amp_kwargs):
                     engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
             else:
                 engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
@@ -399,8 +394,8 @@ class EnsembleEvaluator(Evaluator):
             default to `True`.
         to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
             `device`, `non_blocking`.
-        amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
-            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
+        amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
+            https://pytorch.org/docs/stable/amp.html#torch.autocast.

     """

@@ -492,7 +487,7 @@ class EnsembleEvaluator(Evaluator):
         for idx, network in enumerate(engine.networks):
             with engine.mode(network):
                 if engine.amp:
-                    with torch.cuda.amp.autocast(**engine.amp_kwargs):
+                    with torch.autocast("cuda", **engine.amp_kwargs):
                         if isinstance(engine.state.output, dict):
                             engine.state.output.update(
                                 {engine.pred_keys[idx]: engine.inferer(inputs, network, *args, **kwargs)}
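With the version gate removed, `compile=True` now assumes PyTorch >= 2.1 and calls `torch.compile` unconditionally. A minimal sketch of the simplified branch (model and kwargs are illustrative):

    import torch
    import torch.nn as nn

    network = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
    compile_kwargs = None  # as passed through the engine constructor

    # the pytorch_after(2, 1) check and its warning branch are gone
    compile_kwargs = {} if compile_kwargs is None else compile_kwargs
    network = torch.compile(network, **compile_kwargs)

    out = network(torch.randn(4, 8))  # first call triggers graph capture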
monai/engines/trainer.py
CHANGED
@@ -27,7 +27,6 @@ from monai.transforms import Transform
 from monai.utils import AdversarialIterationEvents, AdversarialKeys, GanKeys, IgniteInfo, min_version, optional_import
 from monai.utils.enums import CommonKeys as Keys
 from monai.utils.enums import EngineStatsKeys as ESKeys
-from monai.utils.module import pytorch_after

 if TYPE_CHECKING:
     from ignite.engine import Engine, EventEnum
@@ -126,8 +125,8 @@ class SupervisedTrainer(Trainer):
             more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
         to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
             `device`, `non_blocking`.
-        amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
-            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
+        amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
+            https://pytorch.org/docs/stable/amp.html#torch.autocast.
         compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to
             `torch.Tensor` before forward pass, then converted back afterward with copied meta information.
         compile_kwargs: dict of the args for `torch.compile()` API, for more details:
@@ -183,13 +182,8 @@ class SupervisedTrainer(Trainer):
             amp_kwargs=amp_kwargs,
         )
         if compile:
-            if pytorch_after(2, 1):
-                compile_kwargs = {} if compile_kwargs is None else compile_kwargs
-                network = torch.compile(network, **compile_kwargs)  # type: ignore[assignment]
-            else:
-                warnings.warn(
-                    "Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done"
-                )
+            compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+            network = torch.compile(network, **compile_kwargs)  # type: ignore[assignment]
         self.network = network
         self.compile = compile
         self.optimizer = optimizer
@@ -255,7 +249,7 @@ class SupervisedTrainer(Trainer):
             engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)

         if engine.amp and engine.scaler is not None:
-            with torch.cuda.amp.autocast(**engine.amp_kwargs):
+            with torch.autocast("cuda", **engine.amp_kwargs):
                 _compute_pred_loss()
             engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
             engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
@@ -341,8 +335,8 @@ class GanTrainer(Trainer):
             more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
         to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
             `device`, `non_blocking`.
-        amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
-            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
+        amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
+            https://pytorch.org/docs/stable/amp.html#torch.autocast.

     """

@@ -518,8 +512,8 @@ class AdversarialTrainer(Trainer):
             more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
         to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
             `device`, `non_blocking`.
-        amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
-            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
+        amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
+            https://pytorch.org/docs/stable/amp.html#torch.autocast.
     """

     def __init__(
@@ -689,7 +683,7 @@ class AdversarialTrainer(Trainer):
             engine.state.g_optimizer.zero_grad(set_to_none=engine.optim_set_to_none)

         if engine.amp and engine.state.g_scaler is not None:
-            with torch.cuda.amp.autocast(**engine.amp_kwargs):
+            with torch.autocast("cuda", **engine.amp_kwargs):
                 _compute_generator_loss()

             engine.state.output[Keys.LOSS] = (
@@ -737,7 +731,7 @@ class AdversarialTrainer(Trainer):
             engine.state.d_network.zero_grad(set_to_none=engine.optim_set_to_none)

         if engine.amp and engine.state.d_scaler is not None:
-            with torch.cuda.amp.autocast(**engine.amp_kwargs):
+            with torch.autocast("cuda", **engine.amp_kwargs):
                 _compute_discriminator_loss()

             engine.state.d_scaler.scale(engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS]).backward()
monai/engines/utils.py
CHANGED
@@ -309,7 +309,7 @@ class VPredictionPrepareBatch(DiffusionPrepareBatch):
         self.scheduler = scheduler

     def get_target(self, images, noise, timesteps):
-        return self.scheduler.get_velocity(images, noise, timesteps)
+        return self.scheduler.get_velocity(images, noise, timesteps)  # type: ignore[operator]


 def default_make_latent(
monai/engines/workflow.py
CHANGED
@@ -90,8 +90,8 @@ class Workflow(Engine):
             default to `True`.
         to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
             `device`, `non_blocking`.
-        amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
-            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
+        amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
+            https://pytorch.org/docs/stable/amp.html#torch.autocast.

     Raises:
         TypeError: When ``data_loader`` is not a ``torch.utils.data.DataLoader``.
monai/fl/client/monai_algo.py
CHANGED
@@ -574,7 +574,7 @@ class MonaiAlgo(ClientAlgo, MonaiAlgoStats):
         model_path = os.path.join(self.bundle_root, cast(str, self.model_filepaths[model_type]))
         if not os.path.isfile(model_path):
             raise ValueError(f"No best model checkpoint exists at {model_path}")
-        weights = torch.load(model_path, map_location="cpu")
+        weights = torch.load(model_path, map_location="cpu", weights_only=True)
         # if weights contain several state dicts, use the one defined by `save_dict_key`
         if isinstance(weights, dict) and self.save_dict_key in weights:
             weights = weights.get(self.save_dict_key)
monai/handlers/checkpoint_loader.py
CHANGED
@@ -122,7 +122,7 @@ class CheckpointLoader:
         Args:
             engine: Ignite Engine, it can be a trainer, validator or evaluator.
         """
-        checkpoint = torch.load(self.load_path, map_location=self.map_location)
+        checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=False)

         k, _ = list(self.load_dict.items())[0]
         # single object and checkpoint is directly a state_dict
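`CheckpointLoader` pins `weights_only=False` instead: full training checkpoints may bundle optimizer state, schedulers, or other non-tensor objects that the restricted unpickler rejects, so the handler keeps the permissive loader and relies on the checkpoint being trusted. A sketch of the distinction (path is illustrative):

    import torch
    import torch.nn as nn

    net = nn.Linear(4, 2)
    opt = torch.optim.SGD(net.parameters(), lr=0.1)
    torch.save({"net": net.state_dict(), "opt": opt.state_dict(), "epoch": 3}, "checkpoint.pt")

    # permissive load for trusted, possibly heterogeneous checkpoints
    checkpoint = torch.load("checkpoint.pt", map_location="cpu", weights_only=False)
    net.load_state_dict(checkpoint["net"])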
monai/inferers/inferer.py
CHANGED
@@ -39,7 +39,7 @@ from monai.networks.nets import (
     SPADEAutoencoderKL,
     SPADEDiffusionModelUNet,
 )
-from monai.networks.schedulers import Scheduler
+from monai.networks.schedulers import RFlowScheduler, Scheduler
 from monai.transforms import CenterSpatialCrop, SpatialPad
 from monai.utils import BlendMode, Ordering, PatchKeys, PytorchPadMode, ensure_tuple, optional_import
 from monai.visualize import CAM, GradCAM, GradCAMpp
@@ -859,12 +859,18 @@ class DiffusionInferer(Inferer):
         if not scheduler:
             scheduler = self.scheduler
         image = input_noise
+
+        all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype)))
         if verbose and has_tqdm:
-            progress_bar = tqdm(scheduler.timesteps)
+            progress_bar = tqdm(
+                zip(scheduler.timesteps, all_next_timesteps),
+                total=min(len(scheduler.timesteps), len(all_next_timesteps)),
+            )
         else:
-            progress_bar = iter(scheduler.timesteps)
+            progress_bar = iter(zip(scheduler.timesteps, all_next_timesteps))
         intermediates = []
-        for t in progress_bar:
+
+        for t, next_t in progress_bar:
             # 1. predict noise model_output
             diffusion_model = (
                 partial(diffusion_model, seg=seg)
@@ -882,9 +888,13 @@ class DiffusionInferer(Inferer):
         )

         # 2. compute previous image: x_t -> x_t-1
-        image, _ = scheduler.step(model_output, t, image)
+        if not isinstance(scheduler, RFlowScheduler):
+            image, _ = scheduler.step(model_output, t, image)  # type: ignore
+        else:
+            image, _ = scheduler.step(model_output, t, image, next_t)  # type: ignore
         if save_intermediates and t % intermediate_steps == 0:
             intermediates.append(image)
+
         if save_intermediates:
             return image, intermediates
         else:
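The sampling loops now carry `(t, next_t)` pairs because a rectified-flow step integrates from the current timestep toward the next one, while DDPM-style schedulers keep the two-argument `step`. A toy sketch of the dispatch, with stub classes standing in for the real schedulers (the actual `RFlowScheduler.step` signature lives in the new `rectified_flow.py` and is not shown in this diff):

    import torch

    class ToyDDPMLike:
        def step(self, model_output, t, sample):
            return sample - 0.1 * model_output, None

    class ToyRFlowLike:
        # assumed: an Euler step along the predicted velocity from t to next_t
        def step(self, model_output, t, sample, next_t):
            dt = (t - next_t) / 1000.0
            return sample - dt * model_output, None

    for scheduler in (ToyDDPMLike(), ToyRFlowLike()):
        image = torch.zeros(1, 1, 2, 2)
        model_output = torch.ones_like(image)
        t, next_t = torch.tensor(1000.0), torch.tensor(900.0)
        if not isinstance(scheduler, ToyRFlowLike):
            image, _ = scheduler.step(model_output, t, image)
        else:
            image, _ = scheduler.step(model_output, t, image, next_t)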
@@ -986,8 +996,8 @@ class DiffusionInferer(Inferer):
         predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image

         # get the posterior mean and variance
-        posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)
-        posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)
+        posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)  # type: ignore[operator]
+        posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)  # type: ignore[operator]

         log_posterior_variance = torch.log(posterior_variance)
         log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance
@@ -1392,12 +1402,18 @@ class ControlNetDiffusionInferer(DiffusionInferer):
         if not scheduler:
             scheduler = self.scheduler
         image = input_noise
+
+        all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype)))
         if verbose and has_tqdm:
-            progress_bar = tqdm(scheduler.timesteps)
+            progress_bar = tqdm(
+                zip(scheduler.timesteps, all_next_timesteps),
+                total=min(len(scheduler.timesteps), len(all_next_timesteps)),
+            )
         else:
-            progress_bar = iter(scheduler.timesteps)
+            progress_bar = iter(zip(scheduler.timesteps, all_next_timesteps))
         intermediates = []
-        for t in progress_bar:
+
+        for t, next_t in progress_bar:
             diffuse = diffusion_model
             if isinstance(diffusion_model, SPADEDiffusionModelUNet):
                 diffuse = partial(diffusion_model, seg=seg)
@@ -1436,7 +1452,11 @@ class ControlNetDiffusionInferer(DiffusionInferer):
         )

         # 3. compute previous image: x_t -> x_t-1
-        image, _ = scheduler.step(model_output, t, image)
+        if not isinstance(scheduler, RFlowScheduler):
+            image, _ = scheduler.step(model_output, t, image)  # type: ignore
+        else:
+            image, _ = scheduler.step(model_output, t, image, next_t)  # type: ignore
+
         if save_intermediates and t % intermediate_steps == 0:
             intermediates.append(image)
         if save_intermediates:
@@ -1562,8 +1582,8 @@ class ControlNetDiffusionInferer(DiffusionInferer):
         predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image

         # get the posterior mean and variance
-        posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)
-        posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)
+        posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)  # type: ignore[operator]
+        posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)  # type: ignore[operator]

         log_posterior_variance = torch.log(posterior_variance)
         log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance
monai/inferers/merger.py
CHANGED
@@ -53,8 +53,11 @@ class Merger(ABC):
         cropped_shape: Sequence[int] | None = None,
         device: torch.device | str | None = None,
     ) -> None:
-        self.merged_shape = merged_shape
-        self.cropped_shape = self.merged_shape if cropped_shape is None else cropped_shape
+        if merged_shape is None:
+            raise ValueError("Argument `merged_shape` must be provided")
+
+        self.merged_shape: tuple[int, ...] = tuple(merged_shape)
+        self.cropped_shape: tuple[int, ...] = tuple(self.merged_shape if cropped_shape is None else cropped_shape)
         self.device = device
         self.is_finalized = False
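`Merger.__init__` now fails fast on a missing `merged_shape` and normalizes both shapes to tuples, so subclasses see immutable shapes with a concrete type. A standalone sketch of the new behavior:

    from __future__ import annotations

    from collections.abc import Sequence

    def init_shapes(merged_shape: Sequence[int] | None, cropped_shape: Sequence[int] | None = None):
        # mirrors the validation added to Merger.__init__
        if merged_shape is None:
            raise ValueError("Argument `merged_shape` must be provided")
        merged: tuple[int, ...] = tuple(merged_shape)
        cropped: tuple[int, ...] = tuple(merged if cropped_shape is None else cropped_shape)
        return merged, cropped

    print(init_shapes([1, 3, 64, 64]))  # ((1, 3, 64, 64), (1, 3, 64, 64))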
@@ -231,9 +234,9 @@ class ZarrAvgMerger(Merger):
         dtype: np.dtype | str = "float32",
         value_dtype: np.dtype | str = "float32",
         count_dtype: np.dtype | str = "uint8",
-        store: zarr.storage.Store | str = "merged.zarr",
-        value_store: zarr.storage.Store | str | None = None,
-        count_store: zarr.storage.Store | str | None = None,
+        store: zarr.storage.Store | str = "merged.zarr",  # type: ignore
+        value_store: zarr.storage.Store | str | None = None,  # type: ignore
+        count_store: zarr.storage.Store | str | None = None,  # type: ignore
         compressor: str | None = None,
         value_compressor: str | None = None,
         count_compressor: str | None = None,
@@ -251,18 +254,18 @@ class ZarrAvgMerger(Merger):
         if version_geq(get_package_version("zarr"), "3.0.0"):
             if value_store is None:
                 self.tmpdir = TemporaryDirectory()
-                self.value_store = zarr.storage.LocalStore(self.tmpdir.name)
+                self.value_store = zarr.storage.LocalStore(self.tmpdir.name)  # type: ignore
             else:
-                self.value_store = value_store
+                self.value_store = value_store  # type: ignore
             if count_store is None:
                 self.tmpdir = TemporaryDirectory()
-                self.count_store = zarr.storage.LocalStore(self.tmpdir.name)
+                self.count_store = zarr.storage.LocalStore(self.tmpdir.name)  # type: ignore
             else:
-                self.count_store = count_store
+                self.count_store = count_store  # type: ignore
         else:
             self.tmpdir = None
-            self.value_store = zarr.storage.TempStore() if value_store is None else value_store
-            self.count_store = zarr.storage.TempStore() if count_store is None else count_store
+            self.value_store = zarr.storage.TempStore() if value_store is None else value_store  # type: ignore
+            self.count_store = zarr.storage.TempStore() if count_store is None else count_store  # type: ignore
         self.chunks = chunks
         self.compressor = compressor
         self.value_compressor = value_compressor
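zarr 3.0 removed `TempStore`, so the merger chooses the store class by installed version: a `LocalStore` over a `TemporaryDirectory` on zarr >= 3, `TempStore` otherwise. A version-gated sketch of the same idea, assuming the `packaging` library is available for the comparison:

    from tempfile import TemporaryDirectory

    import zarr
    from packaging.version import Version

    # zarr >= 3 dropped TempStore; LocalStore over a temp dir replaces it
    if Version(zarr.__version__) >= Version("3.0.0"):
        tmpdir = TemporaryDirectory()
        store = zarr.storage.LocalStore(tmpdir.name)
    else:
        store = zarr.storage.TempStore()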
@@ -314,7 +317,7 @@ class ZarrAvgMerger(Merger):
         map_slice = ensure_tuple_size(map_slice, values.ndim, pad_val=slice(None), pad_from_start=True)
         with self.lock:
             self.values[map_slice] += values.numpy()
-            self.counts[map_slice] += 1
+            self.counts[map_slice] += 1  # type: ignore[operator]

     def finalize(self) -> zarr.Array:
         """
@@ -332,7 +335,7 @@ class ZarrAvgMerger(Merger):
         if not self.is_finalized:
             # use chunks for division to fit into memory
             for chunk in iterate_over_chunks(self.values.chunks, self.values.cdata_shape):
-                self.output[chunk] = self.values[chunk] / self.counts[chunk]
+                self.output[chunk] = self.values[chunk] / self.counts[chunk]  # type: ignore[operator]
             # finalize the shape
             self.output.resize(self.cropped_shape)
             # set finalize flag to protect performing in-place division again
monai/losses/perceptual.py
CHANGED
@@ -374,7 +374,7 @@ class TorchvisionModelPerceptualSimilarity(nn.Module):
         else:
             network = torchvision.models.resnet50(weights=None)
             if pretrained is True:
-                state_dict = torch.load(pretrained_path)
+                state_dict = torch.load(pretrained_path, weights_only=True)
                 if pretrained_state_dict_key is not None:
                     state_dict = state_dict[pretrained_state_dict_key]
                 network.load_state_dict(state_dict)
monai/losses/sure_loss.py
CHANGED
@@ -92,7 +92,7 @@ def sure_loss_function(
     y_ref = operator(x)

     # get perturbed output
-    x_perturbed = x + eps * perturb_noise
+    x_perturbed = x + eps * perturb_noise  # type: ignore
     y_perturbed = operator(x_perturbed)
     # divergence
     divergence = torch.sum(1.0 / eps * torch.matmul(perturb_noise.permute(0, 1, 3, 2), y_perturbed - y_ref))  # type: ignore
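The SURE loss estimates the divergence of the operator with a single finite-difference probe; the perturbation line only gains a type-ignore here, but the computation is worth seeing end to end. A toy sketch with the identity standing in for `operator` (the real one is a reconstruction network):

    import torch

    def operator(x: torch.Tensor) -> torch.Tensor:
        return x  # identity stand-in

    x = torch.randn(1, 1, 4, 4)
    eps = 1e-3
    perturb_noise = torch.randn_like(x)

    # finite-difference divergence estimate, as in sure_loss_function
    y_ref = operator(x)
    x_perturbed = x + eps * perturb_noise
    y_perturbed = operator(x_perturbed)
    divergence = torch.sum(1.0 / eps * torch.matmul(perturb_noise.permute(0, 1, 3, 2), y_perturbed - y_ref))
    # for the identity operator the expected value is the number of pixels (16 here)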
monai/networks/blocks/crossattention.py
CHANGED
@@ -17,7 +17,7 @@ import torch
 import torch.nn as nn

 from monai.networks.layers.utils import get_rel_pos_embedding_layer
-from monai.utils import optional_import, pytorch_after
+from monai.utils import optional_import

 Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")

@@ -84,11 +84,6 @@ class CrossAttentionBlock(nn.Module):
         if causal and sequence_length is None:
             raise ValueError("sequence_length is necessary for causal attention.")

-        if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0):
-            raise ValueError(
-                "use_flash_attention is only supported for PyTorch versions >= 2.0."
-                "Upgrade your PyTorch or set the flag to False."
-            )
         if use_flash_attention and save_attn:
             raise ValueError(
                 "save_attn has been set to True, but use_flash_attention is also set"