monai-weekly 1.5.dev2509__py3-none-any.whl → 1.5.dev2510__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/deepedit/interaction.py +1 -1
- monai/apps/deepgrow/interaction.py +1 -1
- monai/apps/detection/networks/retinanet_detector.py +1 -1
- monai/apps/detection/networks/retinanet_network.py +5 -5
- monai/apps/detection/utils/box_coder.py +2 -2
- monai/apps/mmars/mmars.py +1 -1
- monai/apps/reconstruction/networks/blocks/varnetblock.py +1 -1
- monai/bundle/scripts.py +3 -4
- monai/data/dataset.py +2 -9
- monai/data/utils.py +1 -1
- monai/data/video_dataset.py +1 -1
- monai/engines/evaluator.py +11 -16
- monai/engines/trainer.py +11 -17
- monai/engines/utils.py +1 -1
- monai/engines/workflow.py +2 -2
- monai/fl/client/monai_algo.py +1 -1
- monai/handlers/checkpoint_loader.py +1 -1
- monai/inferers/inferer.py +6 -6
- monai/inferers/merger.py +16 -13
- monai/losses/perceptual.py +1 -1
- monai/losses/sure_loss.py +1 -1
- monai/networks/blocks/crossattention.py +1 -6
- monai/networks/blocks/feature_pyramid_network.py +4 -2
- monai/networks/blocks/selfattention.py +1 -6
- monai/networks/blocks/upsample.py +3 -11
- monai/networks/layers/vector_quantizer.py +2 -2
- monai/networks/nets/hovernet.py +5 -4
- monai/networks/nets/resnet.py +2 -2
- monai/networks/nets/senet.py +1 -1
- monai/networks/nets/swin_unetr.py +46 -49
- monai/networks/nets/transchex.py +3 -2
- monai/networks/nets/vista3d.py +7 -7
- monai/networks/utils.py +5 -4
- monai/transforms/intensity/array.py +1 -1
- monai/transforms/spatial/array.py +6 -6
- monai/utils/misc.py +1 -1
- monai/utils/state_cacher.py +1 -1
- {monai_weekly-1.5.dev2509.dist-info → monai_weekly-1.5.dev2510.dist-info}/METADATA +4 -3
- {monai_weekly-1.5.dev2509.dist-info → monai_weekly-1.5.dev2510.dist-info}/RECORD +59 -59
- tests/bundle/test_bundle_download.py +16 -6
- tests/config/test_cv2_dist.py +1 -2
- tests/integration/test_integration_bundle_run.py +2 -4
- tests/integration/test_integration_classification_2d.py +1 -1
- tests/integration/test_integration_fast_train.py +2 -2
- tests/integration/test_integration_segmentation_3d.py +1 -1
- tests/metrics/test_compute_multiscalessim_metric.py +3 -3
- tests/metrics/test_surface_dice.py +3 -3
- tests/networks/nets/test_autoencoderkl.py +1 -1
- tests/networks/nets/test_controlnet.py +1 -1
- tests/networks/nets/test_diffusion_model_unet.py +1 -1
- tests/networks/nets/test_network_consistency.py +1 -1
- tests/networks/nets/test_swin_unetr.py +1 -1
- tests/networks/nets/test_transformer.py +1 -1
- tests/networks/test_save_state.py +1 -1
- {monai_weekly-1.5.dev2509.dist-info → monai_weekly-1.5.dev2510.dist-info}/LICENSE +0 -0
- {monai_weekly-1.5.dev2509.dist-info → monai_weekly-1.5.dev2510.dist-info}/WHEEL +0 -0
- {monai_weekly-1.5.dev2509.dist-info → monai_weekly-1.5.dev2510.dist-info}/top_level.txt +0 -0
@@ -100,7 +100,7 @@ class EMAQuantizer(nn.Module):
|
|
100
100
|
torch.Tensor: Quantization indices of shape [B,H,W,D,1]
|
101
101
|
|
102
102
|
"""
|
103
|
-
with torch.
|
103
|
+
with torch.autocast("cuda", enabled=False):
|
104
104
|
encoding_indices_view = list(inputs.shape)
|
105
105
|
del encoding_indices_view[1]
|
106
106
|
|
@@ -138,7 +138,7 @@ class EMAQuantizer(nn.Module):
|
|
138
138
|
Returns:
|
139
139
|
torch.Tensor: Quantize space representation of encoding_indices in channel first format.
|
140
140
|
"""
|
141
|
-
with torch.
|
141
|
+
with torch.autocast("cuda", enabled=False):
|
142
142
|
embedding: torch.Tensor = (
|
143
143
|
self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous()
|
144
144
|
)
|
monai/networks/nets/hovernet.py
CHANGED
@@ -633,9 +633,9 @@ def _remap_preact_resnet_model(model_url: str):
|
|
633
633
|
# download the pretrained weights into torch hub's default dir
|
634
634
|
weights_dir = os.path.join(torch.hub.get_dir(), "preact-resnet50.pth")
|
635
635
|
download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False)
|
636
|
-
|
637
|
-
|
638
|
-
|
636
|
+
map_location = None if torch.cuda.is_available() else torch.device("cpu")
|
637
|
+
state_dict = torch.load(weights_dir, map_location=map_location, weights_only=True)["desc"]
|
638
|
+
|
639
639
|
for key in list(state_dict.keys()):
|
640
640
|
new_key = None
|
641
641
|
if pattern_conv0.match(key):
|
@@ -668,7 +668,8 @@ def _remap_standard_resnet_model(model_url: str, state_dict_key: str | None = No
|
|
668
668
|
# download the pretrained weights into torch hub's default dir
|
669
669
|
weights_dir = os.path.join(torch.hub.get_dir(), "resnet50.pth")
|
670
670
|
download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False)
|
671
|
-
|
671
|
+
map_location = None if torch.cuda.is_available() else torch.device("cpu")
|
672
|
+
state_dict = torch.load(weights_dir, map_location=map_location, weights_only=True)
|
672
673
|
if state_dict_key is not None:
|
673
674
|
state_dict = state_dict[state_dict_key]
|
674
675
|
|
monai/networks/nets/resnet.py
CHANGED
@@ -493,7 +493,7 @@ def _resnet(
|
|
493
493
|
if isinstance(pretrained, str):
|
494
494
|
if Path(pretrained).exists():
|
495
495
|
logger.info(f"Loading weights from {pretrained}...")
|
496
|
-
model_state_dict = torch.load(pretrained, map_location=device)
|
496
|
+
model_state_dict = torch.load(pretrained, map_location=device, weights_only=True)
|
497
497
|
else:
|
498
498
|
# Throw error
|
499
499
|
raise FileNotFoundError("The pretrained checkpoint file is not found")
|
@@ -665,7 +665,7 @@ def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", dat
|
|
665
665
|
raise EntryNotFoundError(
|
666
666
|
f"{filename} not found on {medicalnet_huggingface_repo_basename}{resnet_depth}"
|
667
667
|
) from None
|
668
|
-
checkpoint = torch.load(pretrained_path, map_location=torch.device(device))
|
668
|
+
checkpoint = torch.load(pretrained_path, map_location=torch.device(device), weights_only=True)
|
669
669
|
else:
|
670
670
|
raise NotImplementedError("Supported resnet_depth are: [10, 18, 34, 50, 101, 152, 200]")
|
671
671
|
logger.info(f"{filename} downloaded")
|
monai/networks/nets/senet.py
CHANGED
@@ -302,7 +302,7 @@ def _load_state_dict(model: nn.Module, arch: str, progress: bool):
|
|
302
302
|
|
303
303
|
if isinstance(model_url, dict):
|
304
304
|
download_url(model_url["url"], filepath=model_url["filename"])
|
305
|
-
state_dict = torch.load(model_url["filename"], map_location=None)
|
305
|
+
state_dict = torch.load(model_url["filename"], map_location=None, weights_only=True)
|
306
306
|
else:
|
307
307
|
state_dict = load_state_dict_from_url(model_url, progress=progress)
|
308
308
|
for key in list(state_dict.keys()):
|
@@ -272,53 +272,50 @@ class SwinUNETR(nn.Module):
|
|
272
272
|
self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels)
|
273
273
|
|
274
274
|
def load_from(self, weights):
|
275
|
+
layers1_0: BasicLayer = self.swinViT.layers1[0] # type: ignore[assignment]
|
276
|
+
layers2_0: BasicLayer = self.swinViT.layers2[0] # type: ignore[assignment]
|
277
|
+
layers3_0: BasicLayer = self.swinViT.layers3[0] # type: ignore[assignment]
|
278
|
+
layers4_0: BasicLayer = self.swinViT.layers4[0] # type: ignore[assignment]
|
279
|
+
wstate = weights["state_dict"]
|
280
|
+
|
275
281
|
with torch.no_grad():
|
276
|
-
self.swinViT.patch_embed.proj.weight.copy_(
|
277
|
-
self.swinViT.patch_embed.proj.bias.copy_(
|
278
|
-
for bname, block in
|
279
|
-
block.load_from(weights, n_block=bname, layer="layers1")
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
)
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
weights
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
self.swinViT.layers4[0].downsample.reduction.weight.copy_(
|
314
|
-
weights["state_dict"]["module.layers4.0.downsample.reduction.weight"]
|
315
|
-
)
|
316
|
-
self.swinViT.layers4[0].downsample.norm.weight.copy_(
|
317
|
-
weights["state_dict"]["module.layers4.0.downsample.norm.weight"]
|
318
|
-
)
|
319
|
-
self.swinViT.layers4[0].downsample.norm.bias.copy_(
|
320
|
-
weights["state_dict"]["module.layers4.0.downsample.norm.bias"]
|
321
|
-
)
|
282
|
+
self.swinViT.patch_embed.proj.weight.copy_(wstate["module.patch_embed.proj.weight"])
|
283
|
+
self.swinViT.patch_embed.proj.bias.copy_(wstate["module.patch_embed.proj.bias"])
|
284
|
+
for bname, block in layers1_0.blocks.named_children():
|
285
|
+
block.load_from(weights, n_block=bname, layer="layers1") # type: ignore[operator]
|
286
|
+
|
287
|
+
if layers1_0.downsample is not None:
|
288
|
+
d = layers1_0.downsample
|
289
|
+
d.reduction.weight.copy_(wstate["module.layers1.0.downsample.reduction.weight"]) # type: ignore
|
290
|
+
d.norm.weight.copy_(wstate["module.layers1.0.downsample.norm.weight"]) # type: ignore
|
291
|
+
d.norm.bias.copy_(wstate["module.layers1.0.downsample.norm.bias"]) # type: ignore
|
292
|
+
|
293
|
+
for bname, block in layers2_0.blocks.named_children():
|
294
|
+
block.load_from(weights, n_block=bname, layer="layers2") # type: ignore[operator]
|
295
|
+
|
296
|
+
if layers2_0.downsample is not None:
|
297
|
+
d = layers2_0.downsample
|
298
|
+
d.reduction.weight.copy_(wstate["module.layers2.0.downsample.reduction.weight"]) # type: ignore
|
299
|
+
d.norm.weight.copy_(wstate["module.layers2.0.downsample.norm.weight"]) # type: ignore
|
300
|
+
d.norm.bias.copy_(wstate["module.layers2.0.downsample.norm.bias"]) # type: ignore
|
301
|
+
|
302
|
+
for bname, block in layers3_0.blocks.named_children():
|
303
|
+
block.load_from(weights, n_block=bname, layer="layers3") # type: ignore[operator]
|
304
|
+
|
305
|
+
if layers3_0.downsample is not None:
|
306
|
+
d = layers3_0.downsample
|
307
|
+
d.reduction.weight.copy_(wstate["module.layers3.0.downsample.reduction.weight"]) # type: ignore
|
308
|
+
d.norm.weight.copy_(wstate["module.layers3.0.downsample.norm.weight"]) # type: ignore
|
309
|
+
d.norm.bias.copy_(wstate["module.layers3.0.downsample.norm.bias"]) # type: ignore
|
310
|
+
|
311
|
+
for bname, block in layers4_0.blocks.named_children():
|
312
|
+
block.load_from(weights, n_block=bname, layer="layers4") # type: ignore[operator]
|
313
|
+
|
314
|
+
if layers4_0.downsample is not None:
|
315
|
+
d = layers4_0.downsample
|
316
|
+
d.reduction.weight.copy_(wstate["module.layers4.0.downsample.reduction.weight"]) # type: ignore
|
317
|
+
d.norm.weight.copy_(wstate["module.layers4.0.downsample.norm.weight"]) # type: ignore
|
318
|
+
d.norm.bias.copy_(wstate["module.layers4.0.downsample.norm.bias"]) # type: ignore
|
322
319
|
|
323
320
|
@torch.jit.unused
|
324
321
|
def _check_input_size(self, spatial_shape):
|
@@ -532,7 +529,7 @@ class WindowAttention(nn.Module):
|
|
532
529
|
q = q * self.scale
|
533
530
|
attn = q @ k.transpose(-2, -1)
|
534
531
|
relative_position_bias = self.relative_position_bias_table[
|
535
|
-
self.relative_position_index.clone()[:n, :n].reshape(-1)
|
532
|
+
self.relative_position_index.clone()[:n, :n].reshape(-1) # type: ignore[operator]
|
536
533
|
].reshape(n, n, -1)
|
537
534
|
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
|
538
535
|
attn = attn + relative_position_bias.unsqueeze(0)
|
@@ -691,7 +688,7 @@ class SwinTransformerBlock(nn.Module):
|
|
691
688
|
self.norm1.weight.copy_(weights["state_dict"][root + block_names[0]])
|
692
689
|
self.norm1.bias.copy_(weights["state_dict"][root + block_names[1]])
|
693
690
|
self.attn.relative_position_bias_table.copy_(weights["state_dict"][root + block_names[2]])
|
694
|
-
self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]])
|
691
|
+
self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]]) # type: ignore[operator]
|
695
692
|
self.attn.qkv.weight.copy_(weights["state_dict"][root + block_names[4]])
|
696
693
|
self.attn.qkv.bias.copy_(weights["state_dict"][root + block_names[5]])
|
697
694
|
self.attn.proj.weight.copy_(weights["state_dict"][root + block_names[6]])
|
@@ -1118,7 +1115,7 @@ def filter_swinunetr(key, value):
|
|
1118
1115
|
)
|
1119
1116
|
ssl_weights_path = "./ssl_pretrained_weights.pth"
|
1120
1117
|
download_url(resource, ssl_weights_path)
|
1121
|
-
ssl_weights = torch.load(ssl_weights_path)["model"]
|
1118
|
+
ssl_weights = torch.load(ssl_weights_path, weights_only=True)["model"]
|
1122
1119
|
|
1123
1120
|
dst_dict, loaded, not_loaded = copy_model_state(model, ssl_weights, filter_func=filter_swinunetr)
|
1124
1121
|
|
monai/networks/nets/transchex.py
CHANGED
@@ -43,7 +43,7 @@ class BertPreTrainedModel(nn.Module):
|
|
43
43
|
|
44
44
|
def init_bert_weights(self, module):
|
45
45
|
if isinstance(module, (nn.Linear, nn.Embedding)):
|
46
|
-
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
46
|
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) # type: ignore[union-attr,arg-type]
|
47
47
|
elif isinstance(module, torch.nn.LayerNorm):
|
48
48
|
module.bias.data.zero_()
|
49
49
|
module.weight.data.fill_(1.0)
|
@@ -68,7 +68,8 @@ class BertPreTrainedModel(nn.Module):
|
|
68
68
|
weights_path = cached_file(path_or_repo_id, filename, cache_dir=cache_dir)
|
69
69
|
model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs)
|
70
70
|
if state_dict is None and not from_tf:
|
71
|
-
|
71
|
+
map_location = "cpu" if not torch.cuda.is_available() else None
|
72
|
+
state_dict = torch.load(weights_path, map_location=map_location, weights_only=True)
|
72
73
|
if from_tf:
|
73
74
|
return load_tf_weights_in_bert(model, weights_path)
|
74
75
|
old_keys = []
|
monai/networks/nets/vista3d.py
CHANGED
@@ -315,7 +315,7 @@ class VISTA3D(nn.Module):
|
|
315
315
|
"""
|
316
316
|
if auto_freeze != self.auto_freeze:
|
317
317
|
if hasattr(self.image_encoder, "set_auto_grad"):
|
318
|
-
self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)
|
318
|
+
self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze) # type: ignore[operator]
|
319
319
|
else:
|
320
320
|
for param in self.image_encoder.parameters():
|
321
321
|
param.requires_grad = (not auto_freeze) and (not point_freeze)
|
@@ -325,7 +325,7 @@ class VISTA3D(nn.Module):
|
|
325
325
|
|
326
326
|
if point_freeze != self.point_freeze:
|
327
327
|
if hasattr(self.image_encoder, "set_auto_grad"):
|
328
|
-
self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)
|
328
|
+
self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze) # type: ignore[operator]
|
329
329
|
else:
|
330
330
|
for param in self.image_encoder.parameters():
|
331
331
|
param.requires_grad = (not auto_freeze) and (not point_freeze)
|
@@ -543,10 +543,10 @@ class PointMappingSAM(nn.Module):
|
|
543
543
|
point_embedding = self.pe_layer.forward_with_coords(points, out_shape) # type: ignore
|
544
544
|
point_embedding[point_labels == -1] = 0.0
|
545
545
|
point_embedding[point_labels == -1] += self.not_a_point_embed.weight
|
546
|
-
point_embedding[point_labels == 0] += self.point_embeddings[0].weight
|
547
|
-
point_embedding[point_labels == 1] += self.point_embeddings[1].weight
|
548
|
-
point_embedding[point_labels == 2] += self.point_embeddings[0].weight + self.special_class_embed.weight
|
549
|
-
point_embedding[point_labels == 3] += self.point_embeddings[1].weight + self.special_class_embed.weight
|
546
|
+
point_embedding[point_labels == 0] += self.point_embeddings[0].weight # type: ignore[arg-type]
|
547
|
+
point_embedding[point_labels == 1] += self.point_embeddings[1].weight # type: ignore[arg-type]
|
548
|
+
point_embedding[point_labels == 2] += self.point_embeddings[0].weight + self.special_class_embed.weight # type: ignore[operator]
|
549
|
+
point_embedding[point_labels == 3] += self.point_embeddings[1].weight + self.special_class_embed.weight # type: ignore[operator]
|
550
550
|
output_tokens = self.mask_tokens.weight
|
551
551
|
|
552
552
|
output_tokens = output_tokens.unsqueeze(0).expand(point_embedding.size(0), -1, -1)
|
@@ -884,7 +884,7 @@ class PositionEmbeddingRandom(nn.Module):
|
|
884
884
|
coords = 2 * coords - 1
|
885
885
|
# [bs=1,N=2,2] @ [2,128]
|
886
886
|
# [bs=1, N=2, 128]
|
887
|
-
coords = coords @ self.positional_encoding_gaussian_matrix
|
887
|
+
coords = coords @ self.positional_encoding_gaussian_matrix # type: ignore[operator]
|
888
888
|
coords = 2 * np.pi * coords
|
889
889
|
# outputs d_1 x ... x d_n x C shape
|
890
890
|
# [bs=1, N=2, 128+128=256]
|
monai/networks/utils.py
CHANGED
@@ -22,7 +22,7 @@ from collections import OrderedDict
|
|
22
22
|
from collections.abc import Callable, Mapping, Sequence
|
23
23
|
from contextlib import contextmanager
|
24
24
|
from copy import deepcopy
|
25
|
-
from typing import Any
|
25
|
+
from typing import Any, Iterable
|
26
26
|
|
27
27
|
import numpy as np
|
28
28
|
import torch
|
@@ -1238,7 +1238,7 @@ class CastToFloat(torch.nn.Module):
|
|
1238
1238
|
|
1239
1239
|
def forward(self, x):
|
1240
1240
|
dtype = x.dtype
|
1241
|
-
with torch.
|
1241
|
+
with torch.autocast("cuda", enabled=False):
|
1242
1242
|
ret = self.mod.forward(x.to(torch.float32)).to(dtype)
|
1243
1243
|
return ret
|
1244
1244
|
|
@@ -1255,7 +1255,7 @@ class CastToFloatAll(torch.nn.Module):
|
|
1255
1255
|
|
1256
1256
|
def forward(self, *args):
|
1257
1257
|
from_dtype = args[0].dtype
|
1258
|
-
with torch.
|
1258
|
+
with torch.autocast("cuda", enabled=False):
|
1259
1259
|
ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32))
|
1260
1260
|
return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype)
|
1261
1261
|
|
@@ -1291,7 +1291,8 @@ def simple_replace(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable
|
|
1291
1291
|
def expansion_fn(mod: nn.Module) -> nn.Module | None:
|
1292
1292
|
if not isinstance(mod, base_t):
|
1293
1293
|
return None
|
1294
|
-
|
1294
|
+
constants: Iterable = mod.__constants__ # type: ignore[assignment]
|
1295
|
+
args = [getattr(mod, name, None) for name in constants]
|
1295
1296
|
out = dest_t(*args)
|
1296
1297
|
return out
|
1297
1298
|
|
@@ -1856,7 +1856,7 @@ class RandHistogramShift(RandomizableTransform):
|
|
1856
1856
|
indices = ns.searchsorted(xp.reshape(-1), x.reshape(-1)) - 1
|
1857
1857
|
indices = ns.clip(indices, 0, len(m) - 1)
|
1858
1858
|
|
1859
|
-
f = (m[indices] * x.reshape(-1) + b[indices]).reshape(x.shape)
|
1859
|
+
f: NdarrayOrTensor = (m[indices] * x.reshape(-1) + b[indices]).reshape(x.shape)
|
1860
1860
|
f[x < xp[0]] = fp[0]
|
1861
1861
|
f[x > xp[-1]] = fp[-1]
|
1862
1862
|
return f
|
@@ -1758,13 +1758,13 @@ class AffineGrid(LazyTransform):
|
|
1758
1758
|
if self.affine is None:
|
1759
1759
|
affine = torch.eye(spatial_dims + 1, device=_device)
|
1760
1760
|
if self.rotate_params:
|
1761
|
-
affine @= create_rotate(spatial_dims, self.rotate_params, device=_device, backend=_b)
|
1761
|
+
affine @= create_rotate(spatial_dims, self.rotate_params, device=_device, backend=_b) # type: ignore[assignment]
|
1762
1762
|
if self.shear_params:
|
1763
|
-
affine @= create_shear(spatial_dims, self.shear_params, device=_device, backend=_b)
|
1763
|
+
affine @= create_shear(spatial_dims, self.shear_params, device=_device, backend=_b) # type: ignore[assignment]
|
1764
1764
|
if self.translate_params:
|
1765
|
-
affine @= create_translate(spatial_dims, self.translate_params, device=_device, backend=_b)
|
1765
|
+
affine @= create_translate(spatial_dims, self.translate_params, device=_device, backend=_b) # type: ignore[assignment]
|
1766
1766
|
if self.scale_params:
|
1767
|
-
affine @= create_scale(spatial_dims, self.scale_params, device=_device, backend=_b)
|
1767
|
+
affine @= create_scale(spatial_dims, self.scale_params, device=_device, backend=_b) # type: ignore[assignment]
|
1768
1768
|
else:
|
1769
1769
|
affine = self.affine # type: ignore
|
1770
1770
|
affine = to_affine_nd(spatial_dims, affine)
|
@@ -1780,7 +1780,7 @@ class AffineGrid(LazyTransform):
|
|
1780
1780
|
grid_ = ((affine @ sc) @ grid_.view((grid_.shape[0], -1))).view([-1] + list(grid_.shape[1:]))
|
1781
1781
|
else:
|
1782
1782
|
grid_ = (affine @ grid_.view((grid_.shape[0], -1))).view([-1] + list(grid_.shape[1:]))
|
1783
|
-
return grid_, affine
|
1783
|
+
return grid_, affine # type: ignore[return-value]
|
1784
1784
|
|
1785
1785
|
|
1786
1786
|
class RandAffineGrid(Randomizable, LazyTransform):
|
@@ -3257,7 +3257,7 @@ class GridPatch(Transform, MultiSampleTrait):
|
|
3257
3257
|
tuple[NdarrayOrTensor, numpy.ndarray]: tuple of filtered patches and locations.
|
3258
3258
|
"""
|
3259
3259
|
n_dims = len(image_np.shape)
|
3260
|
-
idx = argwhere(image_np.sum(tuple(range(1, n_dims))) < self.threshold).reshape(-1)
|
3260
|
+
idx = argwhere(image_np.sum(tuple(range(1, n_dims))) < self.threshold).reshape(-1) # type: ignore[operator]
|
3261
3261
|
idx_np = convert_data_type(idx, np.ndarray)[0]
|
3262
3262
|
return image_np[idx], locations[idx_np]
|
3263
3263
|
|
monai/utils/misc.py
CHANGED
monai/utils/state_cacher.py
CHANGED
@@ -124,7 +124,7 @@ class StateCacher:
|
|
124
124
|
fn = self.cached[key]["obj"] # pytype: disable=attribute-error
|
125
125
|
if not os.path.exists(fn): # pytype: disable=wrong-arg-types
|
126
126
|
raise RuntimeError(f"Failed to load state in {fn}. File doesn't exist anymore.")
|
127
|
-
data_obj = torch.load(fn, map_location=lambda storage, location: storage)
|
127
|
+
data_obj = torch.load(fn, map_location=lambda storage, location: storage, weights_only=False)
|
128
128
|
# copy back to device if necessary
|
129
129
|
if "device" in self.cached[key]:
|
130
130
|
data_obj = data_obj.to(self.cached[key]["device"])
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: monai-weekly
|
3
|
-
Version: 1.5.
|
3
|
+
Version: 1.5.dev2510
|
4
4
|
Summary: AI Toolkit for Healthcare Imaging
|
5
5
|
Home-page: https://monai.io/
|
6
6
|
Author: MONAI Consortium
|
@@ -29,8 +29,9 @@ Classifier: Typing :: Typed
|
|
29
29
|
Requires-Python: >=3.9
|
30
30
|
Description-Content-Type: text/markdown; charset=UTF-8
|
31
31
|
License-File: LICENSE
|
32
|
-
Requires-Dist: torch>=
|
33
|
-
Requires-Dist:
|
32
|
+
Requires-Dist: torch>=2.3.0; sys_platform != "win32"
|
33
|
+
Requires-Dist: torch>=2.4.1; sys_platform == "win32"
|
34
|
+
Requires-Dist: numpy<3.0,>=1.24
|
34
35
|
Provides-Extra: all
|
35
36
|
Requires-Dist: nibabel; extra == "all"
|
36
37
|
Requires-Dist: ninja; extra == "all"
|