monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2430__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/auto3dseg/hpo_gen.py +1 -1
- monai/apps/detection/utils/anchor_utils.py +2 -2
- monai/apps/pathology/transforms/post/array.py +7 -4
- monai/auto3dseg/analyzer.py +1 -1
- monai/bundle/scripts.py +204 -22
- monai/bundle/utils.py +1 -0
- monai/data/dataset_summary.py +1 -0
- monai/data/meta_tensor.py +2 -2
- monai/data/test_time_augmentation.py +2 -0
- monai/data/utils.py +9 -6
- monai/data/wsi_reader.py +2 -2
- monai/engines/__init__.py +3 -1
- monai/engines/trainer.py +281 -2
- monai/engines/utils.py +76 -1
- monai/handlers/mlflow_handler.py +21 -4
- monai/inferers/__init__.py +5 -0
- monai/inferers/inferer.py +1279 -1
- monai/metrics/cumulative_average.py +2 -0
- monai/metrics/panoptic_quality.py +1 -1
- monai/metrics/rocauc.py +2 -2
- monai/networks/blocks/__init__.py +3 -0
- monai/networks/blocks/attention_utils.py +128 -0
- monai/networks/blocks/crossattention.py +168 -0
- monai/networks/blocks/rel_pos_embedding.py +56 -0
- monai/networks/blocks/selfattention.py +74 -5
- monai/networks/blocks/spade_norm.py +95 -0
- monai/networks/blocks/spatialattention.py +82 -0
- monai/networks/blocks/transformerblock.py +25 -4
- monai/networks/blocks/upsample.py +22 -10
- monai/networks/layers/__init__.py +2 -1
- monai/networks/layers/factories.py +12 -1
- monai/networks/layers/simplelayers.py +1 -1
- monai/networks/layers/utils.py +14 -1
- monai/networks/layers/vector_quantizer.py +233 -0
- monai/networks/nets/__init__.py +9 -0
- monai/networks/nets/autoencoderkl.py +702 -0
- monai/networks/nets/controlnet.py +465 -0
- monai/networks/nets/diffusion_model_unet.py +1913 -0
- monai/networks/nets/patchgan_discriminator.py +230 -0
- monai/networks/nets/quicknat.py +8 -6
- monai/networks/nets/resnet.py +3 -4
- monai/networks/nets/spade_autoencoderkl.py +480 -0
- monai/networks/nets/spade_diffusion_model_unet.py +934 -0
- monai/networks/nets/spade_network.py +435 -0
- monai/networks/nets/swin_unetr.py +4 -3
- monai/networks/nets/transformer.py +157 -0
- monai/networks/nets/vqvae.py +472 -0
- monai/networks/schedulers/__init__.py +17 -0
- monai/networks/schedulers/ddim.py +294 -0
- monai/networks/schedulers/ddpm.py +250 -0
- monai/networks/schedulers/pndm.py +316 -0
- monai/networks/schedulers/scheduler.py +205 -0
- monai/networks/utils.py +22 -0
- monai/transforms/croppad/array.py +8 -8
- monai/transforms/croppad/dictionary.py +4 -4
- monai/transforms/croppad/functional.py +1 -1
- monai/transforms/regularization/array.py +4 -0
- monai/transforms/spatial/array.py +1 -1
- monai/transforms/utils_create_transform_ims.py +2 -4
- monai/utils/__init__.py +1 -0
- monai/utils/misc.py +5 -4
- monai/utils/ordering.py +207 -0
- monai/visualize/class_activation_maps.py +5 -5
- monai/visualize/img2tensorboard.py +3 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/METADATA +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/RECORD +71 -50
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/WHEEL +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/top_level.txt +0 -0
monai/networks/blocks/spatialattention.py
ADDED
@@ -0,0 +1,82 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from monai.networks.blocks import SABlock
+from monai.utils import optional_import
+
+Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
+
+
+class SpatialAttentionBlock(nn.Module):
+    """Perform spatial self-attention on the input tensor.
+
+    The input tensor is reshaped to B x (x_dim * y_dim [ * z_dim]) x C, where C is the number of channels, and then
+    self-attention is performed on the reshaped tensor. The output tensor is reshaped back to the original shape.
+
+    Args:
+        spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
+        num_channels: number of input channels. Must be divisible by num_head_channels.
+        num_head_channels: number of channels per head.
+        attention_dtype: cast attention operations to this dtype.
+
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        num_channels: int,
+        num_head_channels: int | None = None,
+        norm_num_groups: int = 32,
+        norm_eps: float = 1e-6,
+        attention_dtype: Optional[torch.dtype] = None,
+    ) -> None:
+        super().__init__()
+
+        self.spatial_dims = spatial_dims
+        self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True)
+        # check num_head_channels is divisible by num_channels
+        if num_head_channels is not None and num_channels % num_head_channels != 0:
+            raise ValueError("num_channels must be divisible by num_head_channels")
+        num_heads = num_channels // num_head_channels if num_head_channels is not None else 1
+        self.attn = SABlock(
+            hidden_size=num_channels, num_heads=num_heads, qkv_bias=True, attention_dtype=attention_dtype
+        )
+
+    def forward(self, x: torch.Tensor):
+        residual = x
+
+        if self.spatial_dims == 1:
+            h = x.shape[2]
+            rearrange_input = Rearrange("b c h -> b h c")
+            rearrange_output = Rearrange("b h c -> b c h", h=h)
+        if self.spatial_dims == 2:
+            h, w = x.shape[2], x.shape[3]
+            rearrange_input = Rearrange("b c h w -> b (h w) c")
+            rearrange_output = Rearrange("b (h w) c -> b c h w", h=h, w=w)
+        else:
+            h, w, d = x.shape[2], x.shape[3], x.shape[4]
+            rearrange_input = Rearrange("b c h w d -> b (h w d) c")
+            rearrange_output = Rearrange("b (h w d) c -> b c h w d", h=h, w=w, d=d)
+
+        x = self.norm(x)
+        x = rearrange_input(x)  # B x (x_dim * y_dim [ * z_dim]) x C
+
+        x = self.attn(x)
+        x = rearrange_output(x)  # B x C x x_dim * y_dim * [z_dim]
+        x = x + residual
+        return x
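A minimal usage sketch for the new block (not part of the diff): it assumes `einops` is installed, since the module pulls in `Rearrange` through `optional_import`, and uses the 2D case documented in the docstring above.

```python
import torch

from monai.networks.blocks.spatialattention import SpatialAttentionBlock

# 2D case: 64 channels split into 2 heads of 32 channels each
block = SpatialAttentionBlock(spatial_dims=2, num_channels=64, num_head_channels=32)

x = torch.randn(1, 64, 32, 32)  # B x C x H x W
with torch.no_grad():
    y = block(x)

print(y.shape)  # torch.Size([1, 64, 32, 32]); the residual connection keeps the input shape
```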
monai/networks/blocks/transformerblock.py
CHANGED
@@ -11,10 +11,12 @@
 
 from __future__ import annotations
 
+from typing import Optional
+
+import torch
 import torch.nn as nn
 
-from monai.networks.blocks.mlp import MLPBlock
-from monai.networks.blocks.selfattention import SABlock
+from monai.networks.blocks import CrossAttentionBlock, MLPBlock, SABlock
 
 
 class TransformerBlock(nn.Module):
@@ -31,6 +33,9 @@ class TransformerBlock(nn.Module):
         dropout_rate: float = 0.0,
         qkv_bias: bool = False,
         save_attn: bool = False,
+        causal: bool = False,
+        sequence_length: int | None = None,
+        with_cross_attention: bool = False,
     ) -> None:
         """
         Args:
@@ -53,10 +58,26 @@ class TransformerBlock(nn.Module):
 
         self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate)
         self.norm1 = nn.LayerNorm(hidden_size)
-        self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias, save_attn)
+        self.attn = SABlock(
+            hidden_size,
+            num_heads,
+            dropout_rate,
+            qkv_bias=qkv_bias,
+            save_attn=save_attn,
+            causal=causal,
+            sequence_length=sequence_length,
+        )
         self.norm2 = nn.LayerNorm(hidden_size)
+        self.with_cross_attention = with_cross_attention
+
+        self.norm_cross_attn = nn.LayerNorm(hidden_size)
+        self.cross_attn = CrossAttentionBlock(
+            hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, qkv_bias=qkv_bias, causal=False
+        )
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
         x = x + self.attn(self.norm1(x))
+        if self.with_cross_attention:
+            x = x + self.cross_attn(self.norm_cross_attn(x), context=context)
         x = x + self.mlp(self.norm2(x))
         return x
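A hedged usage sketch for the extended block (not from the package): it assumes the leading constructor arguments `hidden_size`, `mlp_dim` and `num_heads` keep their meaning from earlier releases, and exercises the new `with_cross_attention` path shown in the diff.

```python
import torch

from monai.networks.blocks import TransformerBlock

# Decoder-style block: self-attention followed by cross-attention over an external context
block = TransformerBlock(
    hidden_size=256, mlp_dim=1024, num_heads=8, dropout_rate=0.0, with_cross_attention=True
)

x = torch.randn(2, 16, 256)        # B x tokens x hidden_size
context = torch.randn(2, 77, 256)  # B x context tokens x hidden_size (same hidden size assumed)
with torch.no_grad():
    y = block(x, context=context)

print(y.shape)  # torch.Size([2, 16, 256])
```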
monai/networks/blocks/upsample.py
CHANGED
@@ -17,8 +17,8 @@ import torch
 import torch.nn as nn
 
 from monai.networks.layers.factories import Conv, Pad, Pool
-from monai.networks.utils import icnr_init, pixelshuffle
-from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option
+from monai.networks.utils import CastTempType, icnr_init, pixelshuffle
+from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option, pytorch_after
 
 __all__ = ["Upsample", "UpSample", "SubpixelUpsample", "Subpixelupsample", "SubpixelUpSample"]
 
@@ -50,6 +50,7 @@ class UpSample(nn.Sequential):
         size: tuple[int] | int | None = None,
         mode: UpsampleMode | str = UpsampleMode.DECONV,
         pre_conv: nn.Module | str | None = "default",
+        post_conv: nn.Module | None = None,
         interp_mode: str = InterpolateMode.LINEAR,
         align_corners: bool | None = True,
         bias: bool = True,
@@ -71,6 +72,7 @@ class UpSample(nn.Sequential):
             pre_conv: a conv block applied before upsampling. Defaults to "default".
                 When ``conv_block`` is ``"default"``, one reserved conv layer will be utilized when
                 Only used in the "nontrainable" or "pixelshuffle" mode.
+            post_conv: a conv block applied after upsampling. Defaults to None. Only used in the "nontrainable" mode.
             interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
                 Only used in the "nontrainable" mode.
                 If ends with ``"linear"`` will use ``spatial dims`` to determine the correct interpolation.
@@ -154,15 +156,25 @@ class UpSample(nn.Sequential):
             linear_mode = [InterpolateMode.LINEAR, InterpolateMode.BILINEAR, InterpolateMode.TRILINEAR]
             if interp_mode in linear_mode:  # choose mode based on dimensions
                 interp_mode = linear_mode[spatial_dims - 1]
-            self.add_module(
-                "upsample_non_trainable",
-                nn.Upsample(
-                    size=size,
-                    scale_factor=None if size else scale_factor_,
-                    mode=interp_mode.value,
-                    align_corners=align_corners,
-                ),
+
+            upsample = nn.Upsample(
+                size=size,
+                scale_factor=None if size else scale_factor_,
+                mode=interp_mode.value,
+                align_corners=align_corners,
             )
+
+            # Cast to float32 as 'upsample_nearest2d_out_frame' op does not support bfloat16
+            # https://github.com/pytorch/pytorch/issues/86679. This issue is solved in PyTorch 2.1
+            if pytorch_after(major=2, minor=1):
+                self.add_module("upsample_non_trainable", upsample)
+            else:
+                self.add_module(
+                    "upsample_non_trainable",
+                    CastTempType(initial_type=torch.bfloat16, temporary_type=torch.float32, submodule=upsample),
+                )
+            if post_conv:
+                self.add_module("postconv", post_conv)
         elif up_mode == UpsampleMode.PIXELSHUFFLE:
             self.add_module(
                 "pixelshuffle",
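A sketch of the "nontrainable" path with the new `post_conv` argument, assuming the remaining constructor arguments behave as in earlier releases. On PyTorch older than 2.1 the same call transparently wraps `nn.Upsample` in `CastTempType`, as shown in the hunk above; nothing changes for the caller.

```python
import torch
import torch.nn as nn

from monai.networks.blocks import UpSample

# Non-trainable bilinear upsampling followed by the optional conv block added in this release
up = UpSample(
    spatial_dims=2,
    in_channels=8,
    out_channels=8,
    scale_factor=2,
    mode="nontrainable",
    interp_mode="bilinear",
    align_corners=False,
    post_conv=nn.Conv2d(8, 8, kernel_size=1),
)

x = torch.randn(1, 8, 16, 16)
with torch.no_grad():
    y = up(x)

print(y.shape)  # torch.Size([1, 8, 32, 32])
```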
monai/networks/layers/__init__.py
CHANGED
@@ -14,7 +14,7 @@ from __future__ import annotations
 from .conjugate_gradient import ConjugateGradient
 from .convutils import calculate_out_shape, gaussian_1d, polyval, same_padding, stride_minus_kernel_padding
 from .drop_path import DropPath
-from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, split_args
+from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, RelPosEmbedding, split_args
 from .filtering import BilateralFilter, PHLFilter, TrainableBilateralFilter, TrainableJointBilateralFilter
 from .gmm import GaussianMixtureModel
 from .simplelayers import (
@@ -38,4 +38,5 @@ from .simplelayers import (
 )
 from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push
 from .utils import get_act_layer, get_dropout_layer, get_norm_layer, get_pool_layer
+from .vector_quantizer import EMAQuantizer, VectorQuantizer
 from .weight_init import _no_grad_trunc_normal_, trunc_normal_
monai/networks/layers/factories.py
CHANGED
@@ -70,7 +70,7 @@ import torch.nn as nn
 from monai.networks.utils import has_nvfuser_instance_norm
 from monai.utils import ComponentStore, look_up_option, optional_import
 
-__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"]
+__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "RelPosEmbedding", "split_args"]
 
 
 class LayerFactory(ComponentStore):
@@ -201,6 +201,10 @@ Act = LayerFactory(name="Activation layers", description="Factory for creating activation layers.")
 Conv = LayerFactory(name="Convolution layers", description="Factory for creating convolution layers.")
 Pool = LayerFactory(name="Pooling layers", description="Factory for creating pooling layers.")
 Pad = LayerFactory(name="Padding layers", description="Factory for creating padding layers.")
+RelPosEmbedding = LayerFactory(
+    name="Relative positional embedding layers",
+    description="Factory for creating relative positional embedding factory",
+)
 
 
 @Dropout.factory_function("dropout")
@@ -468,3 +472,10 @@ def constant_pad_factory(dim: int) -> type[nn.ConstantPad1d | nn.ConstantPad2d | nn.ConstantPad3d]:
     """
     types = (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d)
     return types[dim - 1]
+
+
+@RelPosEmbedding.factory_function("decomposed")
+def decomposed_rel_pos_embedding() -> type[nn.Module]:
+    from monai.networks.blocks.rel_pos_embedding import DecomposedRelativePosEmbedding
+
+    return DecomposedRelativePosEmbedding
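A small sketch of how the new factory is meant to be queried, mirroring the existing `Act`/`Norm`/`Pool` factories; the lookup name `"decomposed"` comes from the registration in the hunk above.

```python
from monai.networks.layers.factories import RelPosEmbedding

# The factory lookup returns the registered class, here the decomposed relative
# positional embedding defined in monai.networks.blocks.rel_pos_embedding
embedding_cls = RelPosEmbedding["decomposed"]
print(embedding_cls.__name__)  # DecomposedRelativePosEmbedding
```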
monai/networks/layers/simplelayers.py
CHANGED
@@ -452,7 +452,7 @@ def get_binary_kernel(window_size: Sequence[int], dtype=torch.float, device=None):
 
 def median_filter(
     in_tensor: torch.Tensor,
-    kernel_size: Sequence[int] = (3, 3, 3),
+    kernel_size: Sequence[int] | int = (3, 3, 3),
     spatial_dims: int = 3,
     kernel: torch.Tensor | None = None,
     **kwargs,
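A sketch of the widened signature, assuming a bare int is expanded to one value per spatial dimension, which is what the new type hint suggests.

```python
import torch

from monai.networks.layers.simplelayers import median_filter

img = torch.rand(1, 1, 16, 16, 16)  # B x C x H x W x D

# A bare int is now accepted by the signature; it should expand to one value per spatial dim
out_a = median_filter(img, kernel_size=3, spatial_dims=3)
out_b = median_filter(img, kernel_size=(3, 3, 3), spatial_dims=3)

print(out_a.shape, out_b.shape)  # both keep the input shape
```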
monai/networks/layers/utils.py
CHANGED
@@ -11,9 +11,11 @@
 
 from __future__ import annotations
 
+from typing import Optional
+
 import torch.nn
 
-from monai.networks.layers.factories import Act, Dropout, Norm, Pool, split_args
+from monai.networks.layers.factories import Act, Dropout, Norm, Pool, RelPosEmbedding, split_args
 from monai.utils import has_option
 
 __all__ = ["get_norm_layer", "get_act_layer", "get_dropout_layer", "get_pool_layer"]
@@ -124,3 +126,14 @@ def get_pool_layer(name: tuple | str, spatial_dims: int | None = 1):
     pool_name, pool_args = split_args(name)
     pool_type = Pool[pool_name, spatial_dims]
     return pool_type(**pool_args)
+
+
+def get_rel_pos_embedding_layer(name: tuple | str, s_input_dims: Optional[tuple], c_dim: int, num_heads: int):
+    embedding_name, embedding_args = split_args(name)
+    embedding_type = RelPosEmbedding[embedding_name]
+    # create a dictionary with the default values which can be overridden by embedding_args
+    kw_args = {"s_input_dims": s_input_dims, "c_dim": c_dim, "num_heads": num_heads, **embedding_args}
+    # filter out unused argument names
+    kw_args = {k: v for k, v in kw_args.items() if has_option(embedding_type, k)}
+
+    return embedding_type(**kw_args)
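A sketch of the new helper, assuming `DecomposedRelativePosEmbedding` accepts the three keyword arguments the helper forwards (anything the target class does not accept is filtered out by `has_option`); the concrete values are illustrative only.

```python
from monai.networks.layers.utils import get_rel_pos_embedding_layer

# Build the "decomposed" relative positional embedding registered with the factory.
# s_input_dims is the spatial size of the attention feature map, c_dim the channel
# dimension; both are forwarded only if the target class accepts them.
rel_pos = get_rel_pos_embedding_layer("decomposed", s_input_dims=(14, 14), c_dim=32, num_heads=4)
print(type(rel_pos).__name__)
```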
monai/networks/layers/vector_quantizer.py
ADDED
@@ -0,0 +1,233 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Sequence, Tuple
+
+import torch
+from torch import nn
+
+__all__ = ["VectorQuantizer", "EMAQuantizer"]
+
+
+class EMAQuantizer(nn.Module):
+    """
+    Vector Quantization module using Exponential Moving Average (EMA) to learn the codebook parameters based on Neural
+    Discrete Representation Learning by Oord et al. (https://arxiv.org/abs/1711.00937) and the official implementation
+    that can be found at https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py#L148 and commit
+    58d9a2746493717a7c9252938da7efa6006f3739.
+
+    This module is not compatible with TorchScript while working in a Distributed Data Parallelism Module. This is due
+    to lack of TorchScript support for torch.distributed module as per https://github.com/pytorch/pytorch/issues/41353
+    on 22/10/2022. If you want to TorchScript your model, please turn set `ddp_sync` to False.
+
+    Args:
+        spatial_dims: number of spatial dimensions of the input.
+        num_embeddings: number of atomic elements in the codebook.
+        embedding_dim: number of channels of the input and atomic elements.
+        commitment_cost: scaling factor of the MSE loss between input and its quantized version. Defaults to 0.25.
+        decay: EMA decay. Defaults to 0.99.
+        epsilon: epsilon value. Defaults to 1e-5.
+        embedding_init: initialization method for the codebook. Defaults to "normal".
+        ddp_sync: whether to synchronize the codebook across processes. Defaults to True.
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        num_embeddings: int,
+        embedding_dim: int,
+        commitment_cost: float = 0.25,
+        decay: float = 0.99,
+        epsilon: float = 1e-5,
+        embedding_init: str = "normal",
+        ddp_sync: bool = True,
+    ):
+        super().__init__()
+        self.spatial_dims: int = spatial_dims
+        self.embedding_dim: int = embedding_dim
+        self.num_embeddings: int = num_embeddings
+
+        assert self.spatial_dims in [2, 3], ValueError(
+            f"EMAQuantizer only supports 4D and 5D tensor inputs but received spatial dims {spatial_dims}."
+        )
+
+        self.embedding: torch.nn.Embedding = torch.nn.Embedding(self.num_embeddings, self.embedding_dim)
+        if embedding_init == "normal":
+            # Initialization is passed since the default one is normal inside the nn.Embedding
+            pass
+        elif embedding_init == "kaiming_uniform":
+            torch.nn.init.kaiming_uniform_(self.embedding.weight.data, mode="fan_in", nonlinearity="linear")
+        self.embedding.weight.requires_grad = False
+
+        self.commitment_cost: float = commitment_cost
+
+        self.register_buffer("ema_cluster_size", torch.zeros(self.num_embeddings))
+        self.register_buffer("ema_w", self.embedding.weight.data.clone())
+        # declare types for mypy
+        self.ema_cluster_size: torch.Tensor
+        self.ema_w: torch.Tensor
+        self.decay: float = decay
+        self.epsilon: float = epsilon
+
+        self.ddp_sync: bool = ddp_sync
+
+        # Precalculating required permutation shapes
+        self.flatten_permutation = [0] + list(range(2, self.spatial_dims + 2)) + [1]
+        self.quantization_permutation: Sequence[int] = [0, self.spatial_dims + 1] + list(
+            range(1, self.spatial_dims + 1)
+        )
+
+    def quantize(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Given an input it projects it to the quantized space and returns additional tensors needed for EMA loss.
+
+        Args:
+            inputs: Encoding space tensors of shape [B, C, H, W, D].
+
+        Returns:
+            torch.Tensor: Flatten version of the input of shape [B*H*W*D, C].
+            torch.Tensor: One-hot representation of the quantization indices of shape [B*H*W*D, self.num_embeddings].
+            torch.Tensor: Quantization indices of shape [B,H,W,D,1]
+
+        """
+        with torch.cuda.amp.autocast(enabled=False):
+            encoding_indices_view = list(inputs.shape)
+            del encoding_indices_view[1]
+
+            inputs = inputs.float()
+
+            # Converting to channel last format
+            flat_input = inputs.permute(self.flatten_permutation).contiguous().view(-1, self.embedding_dim)
+
+            # Calculate Euclidean distances
+            distances = (
+                (flat_input**2).sum(dim=1, keepdim=True)
+                + (self.embedding.weight.t() ** 2).sum(dim=0, keepdim=True)
+                - 2 * torch.mm(flat_input, self.embedding.weight.t())
+            )
+
+            # Mapping distances to indexes
+            encoding_indices = torch.max(-distances, dim=1)[1]
+            encodings = torch.nn.functional.one_hot(encoding_indices, self.num_embeddings).float()
+
+            # Quantize and reshape
+            encoding_indices = encoding_indices.view(encoding_indices_view)
+
+        return flat_input, encodings, encoding_indices
+
+    def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor:
+        """
+        Given encoding indices of shape [B,D,H,W,1] embeds them in the quantized space
+        [B, D, H, W, self.embedding_dim] and reshapes them to [B, self.embedding_dim, D, H, W] to be fed to the
+        decoder.
+
+        Args:
+            embedding_indices: Tensor in channel last format which holds indices referencing atomic
+                elements from self.embedding
+
+        Returns:
+            torch.Tensor: Quantize space representation of encoding_indices in channel first format.
+        """
+        with torch.cuda.amp.autocast(enabled=False):
+            embedding: torch.Tensor = (
+                self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous()
+            )
+        return embedding
+
+    def distributed_synchronization(self, encodings_sum: torch.Tensor, dw: torch.Tensor) -> None:
+        """
+        TorchScript does not support torch.distributed.all_reduce. This function is a bypassing trick based on the
+        example: https://pytorch.org/docs/stable/generated/torch.jit.unused.html#torch.jit.unused
+
+        Args:
+            encodings_sum: The summation of one hot representation of what encoding was used for each
+                position.
+            dw: The multiplication of the one hot representation of what encoding was used for each
+                position with the flattened input.
+
+        Returns:
+            None
+        """
+        if self.ddp_sync and torch.distributed.is_initialized():
+            torch.distributed.all_reduce(tensor=encodings_sum, op=torch.distributed.ReduceOp.SUM)
+            torch.distributed.all_reduce(tensor=dw, op=torch.distributed.ReduceOp.SUM)
+        else:
+            pass
+
+    def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        flat_input, encodings, encoding_indices = self.quantize(inputs)
+        quantized = self.embed(encoding_indices)
+
+        # Use EMA to update the embedding vectors
+        if self.training:
+            with torch.no_grad():
+                encodings_sum = encodings.sum(0)
+                dw = torch.mm(encodings.t(), flat_input)
+
+                if self.ddp_sync:
+                    self.distributed_synchronization(encodings_sum, dw)
+
+                self.ema_cluster_size.data.mul_(self.decay).add_(torch.mul(encodings_sum, 1 - self.decay))
+
+                # Laplace smoothing of the cluster size
+                n = self.ema_cluster_size.sum()
+                weights = (self.ema_cluster_size + self.epsilon) / (n + self.num_embeddings * self.epsilon) * n
+                self.ema_w.data.mul_(self.decay).add_(torch.mul(dw, 1 - self.decay))
+                self.embedding.weight.data.copy_(self.ema_w / weights.unsqueeze(1))
+
+        # Encoding Loss
+        loss = self.commitment_cost * torch.nn.functional.mse_loss(quantized.detach(), inputs)
+
+        # Straight Through Estimator
+        quantized = inputs + (quantized - inputs).detach()
+
+        return quantized, loss, encoding_indices
+
+
+class VectorQuantizer(torch.nn.Module):
+    """
+    Vector Quantization wrapper that is needed as a workaround for the AMP to isolate the non fp16 compatible parts of
+    the quantization in their own class.
+
+    Args:
+        quantizer (torch.nn.Module): Quantizer module that needs to return its quantized representation, loss and index
+            based quantized representation.
+    """
+
+    def __init__(self, quantizer: EMAQuantizer):
+        super().__init__()
+
+        self.quantizer: EMAQuantizer = quantizer
+
+        self.perplexity: torch.Tensor = torch.rand(1)
+
+    def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        quantized, loss, encoding_indices = self.quantizer(inputs)
+        # Perplexity calculations
+        avg_probs = (
+            torch.histc(encoding_indices.float(), bins=self.quantizer.num_embeddings, max=self.quantizer.num_embeddings)
+            .float()
+            .div(encoding_indices.numel())
+        )
+
+        self.perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+
+        return loss, quantized
+
+    def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor:
+        return self.quantizer.embed(embedding_indices=embedding_indices)
+
+    def quantize(self, encodings: torch.Tensor) -> torch.Tensor:
+        output = self.quantizer(encodings)
+        encoding_indices: torch.Tensor = output[2]
+        return encoding_indices
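A usage sketch built only from the API shown above: an `EMAQuantizer` wrapped in the AMP-isolating `VectorQuantizer`, with the round trip between latents, codebook indices, and embeddings.

```python
import torch

from monai.networks.layers import EMAQuantizer, VectorQuantizer

# Codebook of 16 atoms, 8-channel latents, 2D feature maps (4D input tensors)
quantizer = VectorQuantizer(EMAQuantizer(spatial_dims=2, num_embeddings=16, embedding_dim=8))

z = torch.randn(2, 8, 32, 32)  # encoder output: B x embedding_dim x H x W
loss, z_q = quantizer(z)       # forward returns (commitment loss, quantized latents)

print(z_q.shape, loss.item(), quantizer.perplexity.item())

# Indices <-> embeddings round trip used by the VQ-VAE decoder
indices = quantizer.quantize(z)   # B x H x W integer codes
recon = quantizer.embed(indices)  # B x embedding_dim x H x W
```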
monai/networks/nets/__init__.py
CHANGED
@@ -14,9 +14,11 @@ from __future__ import annotations
 from .ahnet import AHnet, Ahnet, AHNet
 from .attentionunet import AttentionUnet
 from .autoencoder import AutoEncoder
+from .autoencoderkl import AutoencoderKL
 from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet
 from .basic_unetplusplus import BasicUNetPlusPlus, BasicUnetPlusPlus, BasicunetPlusPlus, basicunetplusplus
 from .classifier import Classifier, Critic, Discriminator
+from .controlnet import ControlNet
 from .daf3d import DAF3D
 from .densenet import (
     DenseNet,
@@ -34,6 +36,7 @@ from .densenet import (
     densenet201,
     densenet264,
 )
+from .diffusion_model_unet import DiffusionModelUNet
 from .dints import DiNTS, TopologyConstruction, TopologyInstance, TopologySearch
 from .dynunet import DynUNet, DynUnet, Dynunet
 from .efficientnet import (
@@ -52,6 +55,7 @@ from .highresnet import HighResBlock, HighResNet
 from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet
 from .milmodel import MILModel
 from .netadapter import NetAdapter
+from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator
 from .quicknat import Quicknat
 from .regressor import Regressor
 from .regunet import GlobalNet, LocalNet, RegUNet
@@ -104,9 +108,13 @@ from .senet import (
     seresnext50,
     seresnext101,
 )
+from .spade_autoencoderkl import SPADEAutoencoderKL
+from .spade_diffusion_model_unet import SPADEDiffusionModelUNet
+from .spade_network import SPADENet
 from .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR
 from .torchvision_fc import TorchVisionFCModel
 from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex
+from .transformer import DecoderOnlyTransformer
 from .unet import UNet, Unet
 from .unetr import UNETR
 from .varautoencoder import VarAutoEncoder
@@ -114,3 +122,4 @@ from .vit import ViT
 from .vitautoenc import ViTAutoEnc
 from .vnet import VNet
 from .voxelmorph import VoxelMorph, VoxelMorphUNet
+from .vqvae import VQVAE