monai-weekly 1.5.dev2508__py3-none-any.whl → 1.5.dev2510__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/deepedit/interaction.py +1 -1
- monai/apps/deepgrow/interaction.py +1 -1
- monai/apps/detection/networks/retinanet_detector.py +1 -1
- monai/apps/detection/networks/retinanet_network.py +5 -5
- monai/apps/detection/utils/box_coder.py +2 -2
- monai/apps/mmars/mmars.py +1 -1
- monai/apps/reconstruction/networks/blocks/varnetblock.py +1 -1
- monai/bundle/scripts.py +42 -20
- monai/data/dataset.py +2 -9
- monai/data/utils.py +1 -1
- monai/data/video_dataset.py +1 -1
- monai/engines/evaluator.py +11 -16
- monai/engines/trainer.py +11 -17
- monai/engines/utils.py +1 -1
- monai/engines/workflow.py +2 -2
- monai/fl/client/monai_algo.py +1 -1
- monai/handlers/checkpoint_loader.py +1 -1
- monai/inferers/inferer.py +35 -17
- monai/inferers/merger.py +16 -13
- monai/losses/perceptual.py +1 -1
- monai/losses/sure_loss.py +1 -1
- monai/networks/blocks/crossattention.py +1 -6
- monai/networks/blocks/feature_pyramid_network.py +4 -2
- monai/networks/blocks/selfattention.py +1 -6
- monai/networks/blocks/upsample.py +3 -11
- monai/networks/layers/vector_quantizer.py +2 -2
- monai/networks/nets/hovernet.py +5 -4
- monai/networks/nets/resnet.py +2 -2
- monai/networks/nets/senet.py +1 -1
- monai/networks/nets/swin_unetr.py +46 -49
- monai/networks/nets/transchex.py +3 -2
- monai/networks/nets/vista3d.py +7 -7
- monai/networks/utils.py +5 -4
- monai/transforms/intensity/array.py +1 -1
- monai/transforms/spatial/array.py +6 -6
- monai/utils/misc.py +1 -1
- monai/utils/state_cacher.py +1 -1
- {monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2510.dist-info}/METADATA +4 -3
- {monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2510.dist-info}/RECORD +60 -60
- {monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2510.dist-info}/WHEEL +1 -1
- tests/bundle/test_bundle_download.py +16 -6
- tests/config/test_cv2_dist.py +1 -2
- tests/inferers/test_controlnet_inferers.py +9 -0
- tests/integration/test_integration_bundle_run.py +2 -4
- tests/integration/test_integration_classification_2d.py +1 -1
- tests/integration/test_integration_fast_train.py +2 -2
- tests/integration/test_integration_segmentation_3d.py +1 -1
- tests/metrics/test_compute_multiscalessim_metric.py +3 -3
- tests/metrics/test_surface_dice.py +3 -3
- tests/networks/nets/test_autoencoderkl.py +1 -1
- tests/networks/nets/test_controlnet.py +1 -1
- tests/networks/nets/test_diffusion_model_unet.py +1 -1
- tests/networks/nets/test_network_consistency.py +1 -1
- tests/networks/nets/test_swin_unetr.py +1 -1
- tests/networks/nets/test_transformer.py +1 -1
- tests/networks/test_save_state.py +1 -1
- {monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2510.dist-info}/LICENSE +0 -0
- {monai_weekly-1.5.dev2508.dist-info → monai_weekly-1.5.dev2510.dist-info}/top_level.txt +0 -0
monai/inferers/inferer.py
CHANGED
@@ -882,7 +882,7 @@ class DiffusionInferer(Inferer):
                 )

             # 2. compute previous image: x_t -> x_t-1
-            image, _ = scheduler.step(model_output, t, image)
+            image, _ = scheduler.step(model_output, t, image)  # type: ignore[operator]
             if save_intermediates and t % intermediate_steps == 0:
                 intermediates.append(image)
         if save_intermediates:
@@ -986,8 +986,8 @@ class DiffusionInferer(Inferer):
             predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image

             # get the posterior mean and variance
-            posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)
-            posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)
+            posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)  # type: ignore[operator]
+            posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)  # type: ignore[operator]

             log_posterior_variance = torch.log(posterior_variance)
             log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance
@@ -1334,13 +1334,15 @@ class ControlNetDiffusionInferer(DiffusionInferer):
             raise NotImplementedError(f"{mode} condition is not supported")

         noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
-        down_block_res_samples, mid_block_res_sample = controlnet(
-            x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond
-        )
+
         if mode == "concat" and condition is not None:
             noisy_image = torch.cat([noisy_image, condition], dim=1)
             condition = None

+        down_block_res_samples, mid_block_res_sample = controlnet(
+            x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond, context=condition
+        )
+
         diffuse = diffusion_model
         if isinstance(diffusion_model, SPADEDiffusionModelUNet):
             diffuse = partial(diffusion_model, seg=seg)
@@ -1396,17 +1398,21 @@ class ControlNetDiffusionInferer(DiffusionInferer):
             progress_bar = iter(scheduler.timesteps)
         intermediates = []
         for t in progress_bar:
-            # 1. ControlNet forward
-            down_block_res_samples, mid_block_res_sample = controlnet(
-                x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond
-            )
-            # 2. predict noise model_output
             diffuse = diffusion_model
             if isinstance(diffusion_model, SPADEDiffusionModelUNet):
                 diffuse = partial(diffusion_model, seg=seg)

             if mode == "concat" and conditioning is not None:
+                # 1. Conditioning
                 model_input = torch.cat([image, conditioning], dim=1)
+                # 2. ControlNet forward
+                down_block_res_samples, mid_block_res_sample = controlnet(
+                    x=model_input,
+                    timesteps=torch.Tensor((t,)).to(input_noise.device),
+                    controlnet_cond=cn_cond,
+                    context=None,
+                )
+                # 3. predict noise model_output
                 model_output = diffuse(
                     model_input,
                     timesteps=torch.Tensor((t,)).to(input_noise.device),
@@ -1415,6 +1421,12 @@ class ControlNetDiffusionInferer(DiffusionInferer):
                     mid_block_additional_residual=mid_block_res_sample,
                 )
             else:
+                down_block_res_samples, mid_block_res_sample = controlnet(
+                    x=image,
+                    timesteps=torch.Tensor((t,)).to(input_noise.device),
+                    controlnet_cond=cn_cond,
+                    context=conditioning,
+                )
                 model_output = diffuse(
                     image,
                     timesteps=torch.Tensor((t,)).to(input_noise.device),
@@ -1424,7 +1436,7 @@ class ControlNetDiffusionInferer(DiffusionInferer):
                 )

             # 3. compute previous image: x_t -> x_t-1
-            image, _ = scheduler.step(model_output, t, image)
+            image, _ = scheduler.step(model_output, t, image)  # type: ignore[operator]
             if save_intermediates and t % intermediate_steps == 0:
                 intermediates.append(image)
         if save_intermediates:
@@ -1485,9 +1497,6 @@ class ControlNetDiffusionInferer(DiffusionInferer):
         for t in progress_bar:
             timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
             noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
-            down_block_res_samples, mid_block_res_sample = controlnet(
-                x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond
-            )

             diffuse = diffusion_model
             if isinstance(diffusion_model, SPADEDiffusionModelUNet):
@@ -1495,6 +1504,9 @@ class ControlNetDiffusionInferer(DiffusionInferer):

             if mode == "concat" and conditioning is not None:
                 noisy_image = torch.cat([noisy_image, conditioning], dim=1)
+                down_block_res_samples, mid_block_res_sample = controlnet(
+                    x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond, context=None
+                )
                 model_output = diffuse(
                     noisy_image,
                     timesteps=timesteps,
@@ -1503,6 +1515,12 @@ class ControlNetDiffusionInferer(DiffusionInferer):
                     mid_block_additional_residual=mid_block_res_sample,
                 )
             else:
+                down_block_res_samples, mid_block_res_sample = controlnet(
+                    x=noisy_image,
+                    timesteps=torch.Tensor((t,)).to(inputs.device),
+                    controlnet_cond=cn_cond,
+                    context=conditioning,
+                )
                 model_output = diffuse(
                     x=noisy_image,
                     timesteps=timesteps,
@@ -1544,8 +1562,8 @@ class ControlNetDiffusionInferer(DiffusionInferer):
             predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image

             # get the posterior mean and variance
-            posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)
-            posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)
+            posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)  # type: ignore[operator]
+            posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)  # type: ignore[operator]

             log_posterior_variance = torch.log(posterior_variance)
             log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance
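The hunks above move the ControlNet forward pass after the concat-mode conditioning and route the conditioning through the new `context` argument. The snippet below is a minimal sketch of that call order only; `fake_controlnet` and the tensor shapes are illustrative stand-ins, not MONAI APIs.

import torch

def fake_controlnet(x, timesteps, controlnet_cond, context=None):
    # stand-in for a ControlNet: return dummy down/mid residuals of the right shape
    return [torch.zeros_like(x)], torch.zeros_like(x)

image = torch.randn(1, 1, 8, 8)
conditioning = torch.randn(1, 1, 8, 8)
cn_cond = torch.randn(1, 1, 8, 8)
t = torch.tensor([10])

mode = "concat"
if mode == "concat":
    # concat mode: condition the input first, then call the controlnet with context=None
    model_input = torch.cat([image, conditioning], dim=1)
    down_res, mid_res = fake_controlnet(x=model_input, timesteps=t, controlnet_cond=cn_cond, context=None)
else:
    # cross-attention mode: the conditioning is passed via context instead
    down_res, mid_res = fake_controlnet(x=image, timesteps=t, controlnet_cond=cn_cond, context=conditioning)
print(down_res[0].shape, mid_res.shape)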
monai/inferers/merger.py
CHANGED
@@ -53,8 +53,11 @@ class Merger(ABC):
         cropped_shape: Sequence[int] | None = None,
         device: torch.device | str | None = None,
     ) -> None:
-        self.merged_shape = merged_shape
-        self.cropped_shape = self.merged_shape if cropped_shape is None else cropped_shape
+        if merged_shape is None:
+            raise ValueError("Argument `merged_shape` must be provided")
+
+        self.merged_shape: tuple[int, ...] = tuple(merged_shape)
+        self.cropped_shape: tuple[int, ...] = tuple(self.merged_shape if cropped_shape is None else cropped_shape)
         self.device = device
         self.is_finalized = False

@@ -231,9 +234,9 @@ class ZarrAvgMerger(Merger):
         dtype: np.dtype | str = "float32",
         value_dtype: np.dtype | str = "float32",
         count_dtype: np.dtype | str = "uint8",
-        store: zarr.storage.Store | str = "merged.zarr",
-        value_store: zarr.storage.Store | str | None = None,
-        count_store: zarr.storage.Store | str | None = None,
+        store: zarr.storage.Store | str = "merged.zarr",  # type: ignore
+        value_store: zarr.storage.Store | str | None = None,  # type: ignore
+        count_store: zarr.storage.Store | str | None = None,  # type: ignore
         compressor: str | None = None,
         value_compressor: str | None = None,
         count_compressor: str | None = None,
@@ -251,18 +254,18 @@ class ZarrAvgMerger(Merger):
         if version_geq(get_package_version("zarr"), "3.0.0"):
             if value_store is None:
                 self.tmpdir = TemporaryDirectory()
-                self.value_store = zarr.storage.LocalStore(self.tmpdir.name)
+                self.value_store = zarr.storage.LocalStore(self.tmpdir.name)  # type: ignore
             else:
-                self.value_store = value_store
+                self.value_store = value_store  # type: ignore
             if count_store is None:
                 self.tmpdir = TemporaryDirectory()
-                self.count_store = zarr.storage.LocalStore(self.tmpdir.name)
+                self.count_store = zarr.storage.LocalStore(self.tmpdir.name)  # type: ignore
             else:
-                self.count_store = count_store
+                self.count_store = count_store  # type: ignore
         else:
             self.tmpdir = None
-            self.value_store = zarr.storage.TempStore() if value_store is None else value_store
-            self.count_store = zarr.storage.TempStore() if count_store is None else count_store
+            self.value_store = zarr.storage.TempStore() if value_store is None else value_store  # type: ignore
+            self.count_store = zarr.storage.TempStore() if count_store is None else count_store  # type: ignore
         self.chunks = chunks
         self.compressor = compressor
         self.value_compressor = value_compressor
@@ -314,7 +317,7 @@ class ZarrAvgMerger(Merger):
         map_slice = ensure_tuple_size(map_slice, values.ndim, pad_val=slice(None), pad_from_start=True)
         with self.lock:
             self.values[map_slice] += values.numpy()
-            self.counts[map_slice] += 1
+            self.counts[map_slice] += 1  # type: ignore[operator]

     def finalize(self) -> zarr.Array:
         """
@@ -332,7 +335,7 @@ class ZarrAvgMerger(Merger):
         if not self.is_finalized:
             # use chunks for division to fit into memory
             for chunk in iterate_over_chunks(self.values.chunks, self.values.cdata_shape):
-                self.output[chunk] = self.values[chunk] / self.counts[chunk]
+                self.output[chunk] = self.values[chunk] / self.counts[chunk]  # type: ignore[operator]
             # finalize the shape
             self.output.resize(self.cropped_shape)
             # set finalize flag to protect performing in-place division again
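A minimal sketch of the stricter `merged_shape` handling introduced in `Merger.__init__` above, using `monai.inferers.AvgMerger` as an example subclass (its `merged_shape` keyword and `aggregate(values, location)` call are assumed here): passing `None` now fails fast with a `ValueError` instead of surfacing later as an attribute error.

import torch
from monai.inferers import AvgMerger

# merged_shape is now validated and stored as a tuple by the Merger base class
merger = AvgMerger(merged_shape=(1, 1, 4, 4))
merger.aggregate(torch.ones(1, 1, 2, 2), location=(0, 0))
print(merger.finalize().shape)  # torch.Size([1, 1, 4, 4])

try:
    AvgMerger(merged_shape=None)  # rejected up front after this change
except ValueError as err:
    print("rejected:", err)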
monai/losses/perceptual.py
CHANGED
@@ -374,7 +374,7 @@ class TorchvisionModelPerceptualSimilarity(nn.Module):
         else:
             network = torchvision.models.resnet50(weights=None)
             if pretrained is True:
-                state_dict = torch.load(pretrained_path)
+                state_dict = torch.load(pretrained_path, weights_only=True)
                 if pretrained_state_dict_key is not None:
                     state_dict = state_dict[pretrained_state_dict_key]
                 network.load_state_dict(state_dict)
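The change above (repeated for several networks later in this diff) adds `weights_only=True` to `torch.load`, which restricts unpickling to tensors and plain containers. A small self-contained sketch of the loading pattern; the checkpoint file name is illustrative only.

import torch

torch.save({"state_dict": {"w": torch.zeros(3)}}, "example_ckpt.pth")

# equivalent of the updated calls, e.g. torch.load(pretrained_path, weights_only=True)
state = torch.load("example_ckpt.pth", map_location="cpu", weights_only=True)
print(state["state_dict"]["w"].shape)  # torch.Size([3])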
monai/losses/sure_loss.py
CHANGED
@@ -92,7 +92,7 @@ def sure_loss_function(
    y_ref = operator(x)

    # get perturbed output
-    x_perturbed = x + eps * perturb_noise
+    x_perturbed = x + eps * perturb_noise  # type: ignore
    y_perturbed = operator(x_perturbed)
    # divergence
    divergence = torch.sum(1.0 / eps * torch.matmul(perturb_noise.permute(0, 1, 3, 2), y_perturbed - y_ref))  # type: ignore
monai/networks/blocks/crossattention.py
CHANGED
@@ -17,7 +17,7 @@ import torch
 import torch.nn as nn

 from monai.networks.layers.utils import get_rel_pos_embedding_layer
-from monai.utils import optional_import, pytorch_after
+from monai.utils import optional_import

 Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")

@@ -84,11 +84,6 @@ class CrossAttentionBlock(nn.Module):
         if causal and sequence_length is None:
             raise ValueError("sequence_length is necessary for causal attention.")

-        if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0):
-            raise ValueError(
-                "use_flash_attention is only supported for PyTorch versions >= 2.0."
-                "Upgrade your PyTorch or set the flag to False."
-            )
         if use_flash_attention and save_attn:
             raise ValueError(
                 "save_attn has been set to True, but use_flash_attention is also set"
monai/networks/blocks/feature_pyramid_network.py
CHANGED
@@ -54,7 +54,9 @@ from __future__ import annotations

 from collections import OrderedDict
 from collections.abc import Callable
+from typing import cast

+import torch
 import torch.nn.functional as F
 from torch import Tensor, nn

@@ -194,8 +196,8 @@ class FeaturePyramidNetwork(nn.Module):
         conv_type_: type[nn.Module] = Conv[Conv.CONV, spatial_dims]
         for m in self.modules():
             if isinstance(m, conv_type_):
-                nn.init.kaiming_uniform_(m.weight, a=1)
-                nn.init.constant_(m.bias, 0.0)
+                nn.init.kaiming_uniform_(cast(torch.Tensor, m.weight), a=1)
+                nn.init.constant_(cast(torch.Tensor, m.bias), 0.0)

         if extra_blocks is not None:
             if not isinstance(extra_blocks, ExtraFPNBlock):
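The `cast(torch.Tensor, ...)` calls above only affect static type checking; at runtime `cast` returns its argument unchanged. A tiny sketch of the same initialization pattern on a standalone convolution (the module here is illustrative, not the FPN layers).

from typing import cast

import torch
import torch.nn as nn

m = nn.Conv2d(1, 1, 3)
nn.init.kaiming_uniform_(cast(torch.Tensor, m.weight), a=1)
nn.init.constant_(cast(torch.Tensor, m.bias), 0.0)
print(m.weight.shape, m.bias.shape)  # torch.Size([1, 1, 3, 3]) torch.Size([1])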
monai/networks/blocks/selfattention.py
CHANGED
@@ -18,7 +18,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 from monai.networks.layers.utils import get_rel_pos_embedding_layer
-from monai.utils import optional_import, pytorch_after
+from monai.utils import optional_import

 Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")

@@ -90,11 +90,6 @@ class SABlock(nn.Module):
         if causal and sequence_length is None:
             raise ValueError("sequence_length is necessary for causal attention.")

-        if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0):
-            raise ValueError(
-                "use_flash_attention is only supported for PyTorch versions >= 2.0."
-                "Upgrade your PyTorch or set the flag to False."
-            )
         if use_flash_attention and save_attn:
             raise ValueError(
                 "save_attn has been set to True, but use_flash_attention is also set"
monai/networks/blocks/upsample.py
CHANGED
@@ -17,8 +17,8 @@ import torch
 import torch.nn as nn

 from monai.networks.layers.factories import Conv, Pad, Pool
-from monai.networks.utils import CastTempType, icnr_init, pixelshuffle
-from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option, pytorch_after
+from monai.networks.utils import icnr_init, pixelshuffle
+from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option

 __all__ = ["Upsample", "UpSample", "SubpixelUpsample", "Subpixelupsample", "SubpixelUpSample"]

@@ -164,15 +164,7 @@ class UpSample(nn.Sequential):
                 align_corners=align_corners,
             )

-            # Cast to float32 as 'upsample_nearest2d_out_frame' op does not support bfloat16
-            # https://github.com/pytorch/pytorch/issues/86679. This issue is solved in PyTorch 2.1
-            if pytorch_after(major=2, minor=1):
-                self.add_module("upsample_non_trainable", upsample)
-            else:
-                self.add_module(
-                    "upsample_non_trainable",
-                    CastTempType(initial_type=torch.bfloat16, temporary_type=torch.float32, submodule=upsample),
-                )
+            self.add_module("upsample_non_trainable", upsample)
             if post_conv:
                 self.add_module("postconv", post_conv)
         elif up_mode == UpsampleMode.PIXELSHUFFLE:
monai/networks/layers/vector_quantizer.py
CHANGED
@@ -100,7 +100,7 @@ class EMAQuantizer(nn.Module):
             torch.Tensor: Quantization indices of shape [B,H,W,D,1]

         """
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch.autocast("cuda", enabled=False):
             encoding_indices_view = list(inputs.shape)
             del encoding_indices_view[1]

@@ -138,7 +138,7 @@ class EMAQuantizer(nn.Module):
         Returns:
             torch.Tensor: Quantize space representation of encoding_indices in channel first format.
         """
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch.autocast("cuda", enabled=False):
             embedding: torch.Tensor = (
                 self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous()
             )
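These hunks (like the `CastToFloat`/`CastToFloatAll` ones in `monai/networks/utils.py` below) replace the deprecated `torch.cuda.amp.autocast(enabled=False)` context with the device-typed `torch.autocast("cuda", enabled=False)`. A minimal sketch of a disabled-autocast region; because it is disabled, the computation simply keeps its input dtype.

import torch

x = torch.randn(4, 4)
# disabled autocast region: no dtype conversion is applied inside
with torch.autocast("cuda", enabled=False):
    y = x @ x.T
print(y.dtype)  # torch.float32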
monai/networks/nets/hovernet.py
CHANGED
@@ -633,9 +633,9 @@ def _remap_preact_resnet_model(model_url: str):
     # download the pretrained weights into torch hub's default dir
     weights_dir = os.path.join(torch.hub.get_dir(), "preact-resnet50.pth")
     download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False)
-    state_dict = torch.load(weights_dir, map_location=None if torch.cuda.is_available() else torch.device("cpu"))[
-        "desc"
-    ]
+    map_location = None if torch.cuda.is_available() else torch.device("cpu")
+    state_dict = torch.load(weights_dir, map_location=map_location, weights_only=True)["desc"]
+
     for key in list(state_dict.keys()):
         new_key = None
         if pattern_conv0.match(key):
@@ -668,7 +668,8 @@ def _remap_standard_resnet_model(model_url: str, state_dict_key: str | None = None):
     # download the pretrained weights into torch hub's default dir
     weights_dir = os.path.join(torch.hub.get_dir(), "resnet50.pth")
     download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False)
-    state_dict = torch.load(weights_dir, map_location=None if torch.cuda.is_available() else torch.device("cpu"))
+    map_location = None if torch.cuda.is_available() else torch.device("cpu")
+    state_dict = torch.load(weights_dir, map_location=map_location, weights_only=True)
     if state_dict_key is not None:
         state_dict = state_dict[state_dict_key]

monai/networks/nets/resnet.py
CHANGED
@@ -493,7 +493,7 @@ def _resnet(
     if isinstance(pretrained, str):
         if Path(pretrained).exists():
             logger.info(f"Loading weights from {pretrained}...")
-            model_state_dict = torch.load(pretrained, map_location=device)
+            model_state_dict = torch.load(pretrained, map_location=device, weights_only=True)
         else:
             # Throw error
             raise FileNotFoundError("The pretrained checkpoint file is not found")
@@ -665,7 +665,7 @@ def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", datasets23: bool = True):
                 raise EntryNotFoundError(
                     f"{filename} not found on {medicalnet_huggingface_repo_basename}{resnet_depth}"
                 ) from None
-        checkpoint = torch.load(pretrained_path, map_location=torch.device(device))
+        checkpoint = torch.load(pretrained_path, map_location=torch.device(device), weights_only=True)
     else:
         raise NotImplementedError("Supported resnet_depth are: [10, 18, 34, 50, 101, 152, 200]")
     logger.info(f"{filename} downloaded")
monai/networks/nets/senet.py
CHANGED
@@ -302,7 +302,7 @@ def _load_state_dict(model: nn.Module, arch: str, progress: bool):

     if isinstance(model_url, dict):
         download_url(model_url["url"], filepath=model_url["filename"])
-        state_dict = torch.load(model_url["filename"], map_location=None)
+        state_dict = torch.load(model_url["filename"], map_location=None, weights_only=True)
     else:
         state_dict = load_state_dict_from_url(model_url, progress=progress)
     for key in list(state_dict.keys()):
monai/networks/nets/swin_unetr.py
CHANGED
@@ -272,53 +272,50 @@ class SwinUNETR(nn.Module):
         self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels)

     def load_from(self, weights):
+        layers1_0: BasicLayer = self.swinViT.layers1[0]  # type: ignore[assignment]
+        layers2_0: BasicLayer = self.swinViT.layers2[0]  # type: ignore[assignment]
+        layers3_0: BasicLayer = self.swinViT.layers3[0]  # type: ignore[assignment]
+        layers4_0: BasicLayer = self.swinViT.layers4[0]  # type: ignore[assignment]
+        wstate = weights["state_dict"]
+
         with torch.no_grad():
-            self.swinViT.patch_embed.proj.weight.copy_(weights["state_dict"]["module.patch_embed.proj.weight"])
-            self.swinViT.patch_embed.proj.bias.copy_(weights["state_dict"]["module.patch_embed.proj.bias"])
-            for bname, block in self.swinViT.layers1[0].blocks.named_children():
-                block.load_from(weights, n_block=bname, layer="layers1")
-            self.swinViT.layers1[0].downsample.reduction.weight.copy_(
-                weights["state_dict"]["module.layers1.0.downsample.reduction.weight"]
-            )
-            self.swinViT.layers1[0].downsample.norm.weight.copy_(
-                weights["state_dict"]["module.layers1.0.downsample.norm.weight"]
-            )
-            self.swinViT.layers1[0].downsample.norm.bias.copy_(
-                weights["state_dict"]["module.layers1.0.downsample.norm.bias"]
-            )
-            for bname, block in self.swinViT.layers2[0].blocks.named_children():
-                block.load_from(weights, n_block=bname, layer="layers2")
-            self.swinViT.layers2[0].downsample.reduction.weight.copy_(
-                weights["state_dict"]["module.layers2.0.downsample.reduction.weight"]
-            )
-            self.swinViT.layers2[0].downsample.norm.weight.copy_(
-                weights["state_dict"]["module.layers2.0.downsample.norm.weight"]
-            )
-            self.swinViT.layers2[0].downsample.norm.bias.copy_(
-                weights["state_dict"]["module.layers2.0.downsample.norm.bias"]
-            )
-            for bname, block in self.swinViT.layers3[0].blocks.named_children():
-                block.load_from(weights, n_block=bname, layer="layers3")
-            self.swinViT.layers3[0].downsample.reduction.weight.copy_(
-                weights["state_dict"]["module.layers3.0.downsample.reduction.weight"]
-            )
-            self.swinViT.layers3[0].downsample.norm.weight.copy_(
-                weights["state_dict"]["module.layers3.0.downsample.norm.weight"]
-            )
-            self.swinViT.layers3[0].downsample.norm.bias.copy_(
-                weights["state_dict"]["module.layers3.0.downsample.norm.bias"]
-            )
-            for bname, block in self.swinViT.layers4[0].blocks.named_children():
-                block.load_from(weights, n_block=bname, layer="layers4")
-            self.swinViT.layers4[0].downsample.reduction.weight.copy_(
-                weights["state_dict"]["module.layers4.0.downsample.reduction.weight"]
-            )
-            self.swinViT.layers4[0].downsample.norm.weight.copy_(
-                weights["state_dict"]["module.layers4.0.downsample.norm.weight"]
-            )
-            self.swinViT.layers4[0].downsample.norm.bias.copy_(
-                weights["state_dict"]["module.layers4.0.downsample.norm.bias"]
-            )
+            self.swinViT.patch_embed.proj.weight.copy_(wstate["module.patch_embed.proj.weight"])
+            self.swinViT.patch_embed.proj.bias.copy_(wstate["module.patch_embed.proj.bias"])
+            for bname, block in layers1_0.blocks.named_children():
+                block.load_from(weights, n_block=bname, layer="layers1")  # type: ignore[operator]
+
+            if layers1_0.downsample is not None:
+                d = layers1_0.downsample
+                d.reduction.weight.copy_(wstate["module.layers1.0.downsample.reduction.weight"])  # type: ignore
+                d.norm.weight.copy_(wstate["module.layers1.0.downsample.norm.weight"])  # type: ignore
+                d.norm.bias.copy_(wstate["module.layers1.0.downsample.norm.bias"])  # type: ignore
+
+            for bname, block in layers2_0.blocks.named_children():
+                block.load_from(weights, n_block=bname, layer="layers2")  # type: ignore[operator]
+
+            if layers2_0.downsample is not None:
+                d = layers2_0.downsample
+                d.reduction.weight.copy_(wstate["module.layers2.0.downsample.reduction.weight"])  # type: ignore
+                d.norm.weight.copy_(wstate["module.layers2.0.downsample.norm.weight"])  # type: ignore
+                d.norm.bias.copy_(wstate["module.layers2.0.downsample.norm.bias"])  # type: ignore
+
+            for bname, block in layers3_0.blocks.named_children():
+                block.load_from(weights, n_block=bname, layer="layers3")  # type: ignore[operator]
+
+            if layers3_0.downsample is not None:
+                d = layers3_0.downsample
+                d.reduction.weight.copy_(wstate["module.layers3.0.downsample.reduction.weight"])  # type: ignore
+                d.norm.weight.copy_(wstate["module.layers3.0.downsample.norm.weight"])  # type: ignore
+                d.norm.bias.copy_(wstate["module.layers3.0.downsample.norm.bias"])  # type: ignore
+
+            for bname, block in layers4_0.blocks.named_children():
+                block.load_from(weights, n_block=bname, layer="layers4")  # type: ignore[operator]
+
+            if layers4_0.downsample is not None:
+                d = layers4_0.downsample
+                d.reduction.weight.copy_(wstate["module.layers4.0.downsample.reduction.weight"])  # type: ignore
+                d.norm.weight.copy_(wstate["module.layers4.0.downsample.norm.weight"])  # type: ignore
+                d.norm.bias.copy_(wstate["module.layers4.0.downsample.norm.bias"])  # type: ignore

     @torch.jit.unused
     def _check_input_size(self, spatial_shape):
@@ -532,7 +529,7 @@ class WindowAttention(nn.Module):
         q = q * self.scale
         attn = q @ k.transpose(-2, -1)
         relative_position_bias = self.relative_position_bias_table[
-            self.relative_position_index.clone()[:n, :n].reshape(-1)
+            self.relative_position_index.clone()[:n, :n].reshape(-1)  # type: ignore[operator]
         ].reshape(n, n, -1)
         relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
         attn = attn + relative_position_bias.unsqueeze(0)
@@ -691,7 +688,7 @@ class SwinTransformerBlock(nn.Module):
             self.norm1.weight.copy_(weights["state_dict"][root + block_names[0]])
             self.norm1.bias.copy_(weights["state_dict"][root + block_names[1]])
             self.attn.relative_position_bias_table.copy_(weights["state_dict"][root + block_names[2]])
-            self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]])
+            self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]])  # type: ignore[operator]
             self.attn.qkv.weight.copy_(weights["state_dict"][root + block_names[4]])
             self.attn.qkv.bias.copy_(weights["state_dict"][root + block_names[5]])
             self.attn.proj.weight.copy_(weights["state_dict"][root + block_names[6]])
@@ -1118,7 +1115,7 @@ def filter_swinunetr(key, value):
         )
         ssl_weights_path = "./ssl_pretrained_weights.pth"
         download_url(resource, ssl_weights_path)
-        ssl_weights = torch.load(ssl_weights_path)["model"]
+        ssl_weights = torch.load(ssl_weights_path, weights_only=True)["model"]

         dst_dict, loaded, not_loaded = copy_model_state(model, ssl_weights, filter_func=filter_swinunetr)

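The rewritten `load_from` above iterates `named_children()` of each stage's blocks and only touches `downsample` when it is present. A small generic sketch of that guarded-copy pattern; the `Stage` class and key names here are illustrative, not the MONAI modules.

import torch
import torch.nn as nn

class Stage(nn.Module):
    def __init__(self, with_downsample: bool):
        super().__init__()
        self.blocks = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
        self.downsample = nn.Linear(4, 2) if with_downsample else None

def load_stage(stage: Stage, state: dict, prefix: str) -> None:
    with torch.no_grad():
        # copy block weights by child name ("0", "1", ...)
        for bname, block in stage.blocks.named_children():
            block.weight.copy_(state[f"{prefix}.blocks.{bname}.weight"])
        # only copy the downsample weights when the stage actually has one
        if stage.downsample is not None:
            stage.downsample.weight.copy_(state[f"{prefix}.downsample.weight"])

stage = Stage(with_downsample=True)
state = {f"layers1.blocks.{i}.weight": torch.zeros(4, 4) for i in range(2)}
state["layers1.downsample.weight"] = torch.zeros(2, 4)
load_stage(stage, state, "layers1")
print(stage.downsample.weight.sum().item())  # 0.0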
monai/networks/nets/transchex.py
CHANGED
@@ -43,7 +43,7 @@ class BertPreTrainedModel(nn.Module):

    def init_bert_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)  # type: ignore[union-attr,arg-type]
        elif isinstance(module, torch.nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
@@ -68,7 +68,8 @@ class BertPreTrainedModel(nn.Module):
        weights_path = cached_file(path_or_repo_id, filename, cache_dir=cache_dir)
        model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs)
        if state_dict is None and not from_tf:
-            state_dict = torch.load(weights_path, map_location="cpu" if not torch.cuda.is_available() else None)
+            map_location = "cpu" if not torch.cuda.is_available() else None
+            state_dict = torch.load(weights_path, map_location=map_location, weights_only=True)
        if from_tf:
            return load_tf_weights_in_bert(model, weights_path)
        old_keys = []
monai/networks/nets/vista3d.py
CHANGED
@@ -315,7 +315,7 @@ class VISTA3D(nn.Module):
         """
         if auto_freeze != self.auto_freeze:
             if hasattr(self.image_encoder, "set_auto_grad"):
-                self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)
+                self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)  # type: ignore[operator]
             else:
                 for param in self.image_encoder.parameters():
                     param.requires_grad = (not auto_freeze) and (not point_freeze)
@@ -325,7 +325,7 @@ class VISTA3D(nn.Module):

         if point_freeze != self.point_freeze:
             if hasattr(self.image_encoder, "set_auto_grad"):
-                self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)
+                self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)  # type: ignore[operator]
             else:
                 for param in self.image_encoder.parameters():
                     param.requires_grad = (not auto_freeze) and (not point_freeze)
@@ -543,10 +543,10 @@ class PointMappingSAM(nn.Module):
         point_embedding = self.pe_layer.forward_with_coords(points, out_shape)  # type: ignore
         point_embedding[point_labels == -1] = 0.0
         point_embedding[point_labels == -1] += self.not_a_point_embed.weight
-        point_embedding[point_labels == 0] += self.point_embeddings[0].weight
-        point_embedding[point_labels == 1] += self.point_embeddings[1].weight
-        point_embedding[point_labels == 2] += self.point_embeddings[0].weight + self.special_class_embed.weight
-        point_embedding[point_labels == 3] += self.point_embeddings[1].weight + self.special_class_embed.weight
+        point_embedding[point_labels == 0] += self.point_embeddings[0].weight  # type: ignore[arg-type]
+        point_embedding[point_labels == 1] += self.point_embeddings[1].weight  # type: ignore[arg-type]
+        point_embedding[point_labels == 2] += self.point_embeddings[0].weight + self.special_class_embed.weight  # type: ignore[operator]
+        point_embedding[point_labels == 3] += self.point_embeddings[1].weight + self.special_class_embed.weight  # type: ignore[operator]
         output_tokens = self.mask_tokens.weight

         output_tokens = output_tokens.unsqueeze(0).expand(point_embedding.size(0), -1, -1)
@@ -884,7 +884,7 @@ class PositionEmbeddingRandom(nn.Module):
         coords = 2 * coords - 1
         # [bs=1,N=2,2] @ [2,128]
         # [bs=1, N=2, 128]
-        coords = coords @ self.positional_encoding_gaussian_matrix
+        coords = coords @ self.positional_encoding_gaussian_matrix  # type: ignore[operator]
         coords = 2 * np.pi * coords
         # outputs d_1 x ... x d_n x C shape
         # [bs=1, N=2, 128+128=256]
monai/networks/utils.py
CHANGED
@@ -22,7 +22,7 @@ from collections import OrderedDict
 from collections.abc import Callable, Mapping, Sequence
 from contextlib import contextmanager
 from copy import deepcopy
-from typing import Any
+from typing import Any, Iterable

 import numpy as np
 import torch
@@ -1238,7 +1238,7 @@ class CastToFloat(torch.nn.Module):

     def forward(self, x):
         dtype = x.dtype
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch.autocast("cuda", enabled=False):
             ret = self.mod.forward(x.to(torch.float32)).to(dtype)
         return ret

@@ -1255,7 +1255,7 @@ class CastToFloatAll(torch.nn.Module):

     def forward(self, *args):
         from_dtype = args[0].dtype
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch.autocast("cuda", enabled=False):
             ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32))
         return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype)

@@ -1291,7 +1291,8 @@ def simple_replace(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable
     def expansion_fn(mod: nn.Module) -> nn.Module | None:
         if not isinstance(mod, base_t):
             return None
-        args = [getattr(mod, name, None) for name in mod.__constants__]
+        constants: Iterable = mod.__constants__  # type: ignore[assignment]
+        args = [getattr(mod, name, None) for name in constants]
        out = dest_t(*args)
        return out

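The updated `expansion_fn` above reads a module's constructor constants through `nn.Module.__constants__` (typed as `Iterable` for mypy) and rebuilds an equivalent module of the destination type. A small sketch of that lookup on a standard module; `nn.ReLU` is used only as an illustration.

import torch.nn as nn

src = nn.ReLU(inplace=True)
# __constants__ lists constructor attributes, e.g. ["inplace"] for ReLU
args = [getattr(src, name, None) for name in src.__constants__]
dst = nn.ReLU(*args)
print(dst.inplace)  # True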
monai/transforms/intensity/array.py
CHANGED
@@ -1856,7 +1856,7 @@ class RandHistogramShift(RandomizableTransform):
         indices = ns.searchsorted(xp.reshape(-1), x.reshape(-1)) - 1
         indices = ns.clip(indices, 0, len(m) - 1)

-        f = (m[indices] * x.reshape(-1) + b[indices]).reshape(x.shape)
+        f: NdarrayOrTensor = (m[indices] * x.reshape(-1) + b[indices]).reshape(x.shape)
         f[x < xp[0]] = fp[0]
         f[x > xp[-1]] = fp[-1]
         return f