monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2430__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/auto3dseg/hpo_gen.py +1 -1
  4. monai/apps/detection/utils/anchor_utils.py +2 -2
  5. monai/apps/pathology/transforms/post/array.py +7 -4
  6. monai/auto3dseg/analyzer.py +1 -1
  7. monai/bundle/scripts.py +204 -22
  8. monai/bundle/utils.py +1 -0
  9. monai/data/dataset_summary.py +1 -0
  10. monai/data/meta_tensor.py +2 -2
  11. monai/data/test_time_augmentation.py +2 -0
  12. monai/data/utils.py +9 -6
  13. monai/data/wsi_reader.py +2 -2
  14. monai/engines/__init__.py +3 -1
  15. monai/engines/trainer.py +281 -2
  16. monai/engines/utils.py +76 -1
  17. monai/handlers/mlflow_handler.py +21 -4
  18. monai/inferers/__init__.py +5 -0
  19. monai/inferers/inferer.py +1279 -1
  20. monai/metrics/cumulative_average.py +2 -0
  21. monai/metrics/panoptic_quality.py +1 -1
  22. monai/metrics/rocauc.py +2 -2
  23. monai/networks/blocks/__init__.py +3 -0
  24. monai/networks/blocks/attention_utils.py +128 -0
  25. monai/networks/blocks/crossattention.py +168 -0
  26. monai/networks/blocks/rel_pos_embedding.py +56 -0
  27. monai/networks/blocks/selfattention.py +74 -5
  28. monai/networks/blocks/spade_norm.py +95 -0
  29. monai/networks/blocks/spatialattention.py +82 -0
  30. monai/networks/blocks/transformerblock.py +25 -4
  31. monai/networks/blocks/upsample.py +22 -10
  32. monai/networks/layers/__init__.py +2 -1
  33. monai/networks/layers/factories.py +12 -1
  34. monai/networks/layers/simplelayers.py +1 -1
  35. monai/networks/layers/utils.py +14 -1
  36. monai/networks/layers/vector_quantizer.py +233 -0
  37. monai/networks/nets/__init__.py +9 -0
  38. monai/networks/nets/autoencoderkl.py +702 -0
  39. monai/networks/nets/controlnet.py +465 -0
  40. monai/networks/nets/diffusion_model_unet.py +1913 -0
  41. monai/networks/nets/patchgan_discriminator.py +230 -0
  42. monai/networks/nets/quicknat.py +8 -6
  43. monai/networks/nets/resnet.py +3 -4
  44. monai/networks/nets/spade_autoencoderkl.py +480 -0
  45. monai/networks/nets/spade_diffusion_model_unet.py +934 -0
  46. monai/networks/nets/spade_network.py +435 -0
  47. monai/networks/nets/swin_unetr.py +4 -3
  48. monai/networks/nets/transformer.py +157 -0
  49. monai/networks/nets/vqvae.py +472 -0
  50. monai/networks/schedulers/__init__.py +17 -0
  51. monai/networks/schedulers/ddim.py +294 -0
  52. monai/networks/schedulers/ddpm.py +250 -0
  53. monai/networks/schedulers/pndm.py +316 -0
  54. monai/networks/schedulers/scheduler.py +205 -0
  55. monai/networks/utils.py +22 -0
  56. monai/transforms/croppad/array.py +8 -8
  57. monai/transforms/croppad/dictionary.py +4 -4
  58. monai/transforms/croppad/functional.py +1 -1
  59. monai/transforms/regularization/array.py +4 -0
  60. monai/transforms/spatial/array.py +1 -1
  61. monai/transforms/utils_create_transform_ims.py +2 -4
  62. monai/utils/__init__.py +1 -0
  63. monai/utils/misc.py +5 -4
  64. monai/utils/ordering.py +207 -0
  65. monai/visualize/class_activation_maps.py +5 -5
  66. monai/visualize/img2tensorboard.py +3 -1
  67. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/METADATA +1 -1
  68. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/RECORD +71 -50
  69. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/WHEEL +1 -1
  70. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/LICENSE +0 -0
  71. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/top_level.txt +0 -0
monai/networks/blocks/spatialattention.py
@@ -0,0 +1,82 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ # http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+
+ from monai.networks.blocks import SABlock
+ from monai.utils import optional_import
+
+ Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
+
+
+ class SpatialAttentionBlock(nn.Module):
+     """Perform spatial self-attention on the input tensor.
+
+     The input tensor is reshaped to B x (x_dim * y_dim [ * z_dim]) x C, where C is the number of channels, and then
+     self-attention is performed on the reshaped tensor. The output tensor is reshaped back to the original shape.
+
+     Args:
+         spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
+         num_channels: number of input channels. Must be divisible by num_head_channels.
+         num_head_channels: number of channels per head.
+         attention_dtype: cast attention operations to this dtype.
+
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         num_channels: int,
+         num_head_channels: int | None = None,
+         norm_num_groups: int = 32,
+         norm_eps: float = 1e-6,
+         attention_dtype: Optional[torch.dtype] = None,
+     ) -> None:
+         super().__init__()
+
+         self.spatial_dims = spatial_dims
+         self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True)
+         # check that num_channels is divisible by num_head_channels
+         if num_head_channels is not None and num_channels % num_head_channels != 0:
+             raise ValueError("num_channels must be divisible by num_head_channels")
+         num_heads = num_channels // num_head_channels if num_head_channels is not None else 1
+         self.attn = SABlock(
+             hidden_size=num_channels, num_heads=num_heads, qkv_bias=True, attention_dtype=attention_dtype
+         )
+
+     def forward(self, x: torch.Tensor):
+         residual = x
+
+         if self.spatial_dims == 1:
+             h = x.shape[2]
+             rearrange_input = Rearrange("b c h -> b h c")
+             rearrange_output = Rearrange("b h c -> b c h", h=h)
+         elif self.spatial_dims == 2:
+             h, w = x.shape[2], x.shape[3]
+             rearrange_input = Rearrange("b c h w -> b (h w) c")
+             rearrange_output = Rearrange("b (h w) c -> b c h w", h=h, w=w)
+         else:
+             h, w, d = x.shape[2], x.shape[3], x.shape[4]
+             rearrange_input = Rearrange("b c h w d -> b (h w d) c")
+             rearrange_output = Rearrange("b (h w d) c -> b c h w d", h=h, w=w, d=d)
+
+         x = self.norm(x)
+         x = rearrange_input(x)  # B x (x_dim * y_dim [ * z_dim]) x C
+
+         x = self.attn(x)
+         x = rearrange_output(x)  # B x C x x_dim * y_dim [ * z_dim]
+         x = x + residual
+         return x
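A minimal usage sketch for the new SpatialAttentionBlock (requires einops; the channel counts and spatial size below are illustrative, not taken from the package):

    import torch

    from monai.networks.blocks.spatialattention import SpatialAttentionBlock

    # 2D feature map of shape B x C x H x W; num_channels must be divisible by num_head_channels
    block = SpatialAttentionBlock(spatial_dims=2, num_channels=64, num_head_channels=16)
    x = torch.randn(2, 64, 32, 32)
    out = block(x)  # same shape as the input, residual connection included
    print(out.shape)  # torch.Size([2, 64, 32, 32])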
monai/networks/blocks/transformerblock.py
@@ -11,10 +11,12 @@

  from __future__ import annotations

+ from typing import Optional
+
+ import torch
  import torch.nn as nn

- from monai.networks.blocks.mlp import MLPBlock
- from monai.networks.blocks.selfattention import SABlock
+ from monai.networks.blocks import CrossAttentionBlock, MLPBlock, SABlock


  class TransformerBlock(nn.Module):
@@ -31,6 +33,9 @@ class TransformerBlock(nn.Module):
          dropout_rate: float = 0.0,
          qkv_bias: bool = False,
          save_attn: bool = False,
+         causal: bool = False,
+         sequence_length: int | None = None,
+         with_cross_attention: bool = False,
      ) -> None:
          """
          Args:
@@ -53,10 +58,26 @@

          self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate)
          self.norm1 = nn.LayerNorm(hidden_size)
-         self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias, save_attn)
+         self.attn = SABlock(
+             hidden_size,
+             num_heads,
+             dropout_rate,
+             qkv_bias=qkv_bias,
+             save_attn=save_attn,
+             causal=causal,
+             sequence_length=sequence_length,
+         )
          self.norm2 = nn.LayerNorm(hidden_size)
+         self.with_cross_attention = with_cross_attention
+
+         self.norm_cross_attn = nn.LayerNorm(hidden_size)
+         self.cross_attn = CrossAttentionBlock(
+             hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, qkv_bias=qkv_bias, causal=False
+         )

-     def forward(self, x):
+     def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
          x = x + self.attn(self.norm1(x))
+         if self.with_cross_attention:
+             x = x + self.cross_attn(self.norm_cross_attn(x), context=context)
          x = x + self.mlp(self.norm2(x))
          return x
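A sketch of how the extended TransformerBlock might be driven with a conditioning sequence, assuming the signature added above; the hidden sizes and sequence lengths are illustrative:

    import torch

    from monai.networks.blocks import TransformerBlock

    # self-attention plus cross-attention against an external conditioning sequence
    block = TransformerBlock(hidden_size=256, mlp_dim=1024, num_heads=8, with_cross_attention=True)
    x = torch.randn(2, 100, 256)       # B x seq_len x hidden_size
    context = torch.randn(2, 77, 256)  # conditioning embeddings with the same hidden size
    out = block(x, context=context)    # B x 100 x 256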
monai/networks/blocks/upsample.py
@@ -17,8 +17,8 @@ import torch
  import torch.nn as nn

  from monai.networks.layers.factories import Conv, Pad, Pool
- from monai.networks.utils import icnr_init, pixelshuffle
- from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option
+ from monai.networks.utils import CastTempType, icnr_init, pixelshuffle
+ from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option, pytorch_after

  __all__ = ["Upsample", "UpSample", "SubpixelUpsample", "Subpixelupsample", "SubpixelUpSample"]

@@ -50,6 +50,7 @@ class UpSample(nn.Sequential):
          size: tuple[int] | int | None = None,
          mode: UpsampleMode | str = UpsampleMode.DECONV,
          pre_conv: nn.Module | str | None = "default",
+         post_conv: nn.Module | None = None,
          interp_mode: str = InterpolateMode.LINEAR,
          align_corners: bool | None = True,
          bias: bool = True,
@@ -71,6 +72,7 @@
              pre_conv: a conv block applied before upsampling. Defaults to "default".
                  When ``conv_block`` is ``"default"``, one reserved conv layer will be utilized when
                  Only used in the "nontrainable" or "pixelshuffle" mode.
+             post_conv: a conv block applied after upsampling. Defaults to None. Only used in the "nontrainable" mode.
              interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
                  Only used in the "nontrainable" mode.
                  If ends with ``"linear"`` will use ``spatial dims`` to determine the correct interpolation.
@@ -154,15 +156,25 @@
              linear_mode = [InterpolateMode.LINEAR, InterpolateMode.BILINEAR, InterpolateMode.TRILINEAR]
              if interp_mode in linear_mode:  # choose mode based on dimensions
                  interp_mode = linear_mode[spatial_dims - 1]
-             self.add_module(
-                 "upsample_non_trainable",
-                 nn.Upsample(
-                     size=size,
-                     scale_factor=None if size else scale_factor_,
-                     mode=interp_mode.value,
-                     align_corners=align_corners,
-                 ),
+
+             upsample = nn.Upsample(
+                 size=size,
+                 scale_factor=None if size else scale_factor_,
+                 mode=interp_mode.value,
+                 align_corners=align_corners,
              )
+
+             # Cast to float32 as 'upsample_nearest2d_out_frame' op does not support bfloat16
+             # https://github.com/pytorch/pytorch/issues/86679. This issue is solved in PyTorch 2.1
+             if pytorch_after(major=2, minor=1):
+                 self.add_module("upsample_non_trainable", upsample)
+             else:
+                 self.add_module(
+                     "upsample_non_trainable",
+                     CastTempType(initial_type=torch.bfloat16, temporary_type=torch.float32, submodule=upsample),
+                 )
+             if post_conv:
+                 self.add_module("postconv", post_conv)
          elif up_mode == UpsampleMode.PIXELSHUFFLE:
              self.add_module(
                  "pixelshuffle",
monai/networks/layers/__init__.py
@@ -14,7 +14,7 @@ from __future__ import annotations
  from .conjugate_gradient import ConjugateGradient
  from .convutils import calculate_out_shape, gaussian_1d, polyval, same_padding, stride_minus_kernel_padding
  from .drop_path import DropPath
- from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, split_args
+ from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, RelPosEmbedding, split_args
  from .filtering import BilateralFilter, PHLFilter, TrainableBilateralFilter, TrainableJointBilateralFilter
  from .gmm import GaussianMixtureModel
  from .simplelayers import (
@@ -38,4 +38,5 @@ from .simplelayers import (
  )
  from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push
  from .utils import get_act_layer, get_dropout_layer, get_norm_layer, get_pool_layer
+ from .vector_quantizer import EMAQuantizer, VectorQuantizer
  from .weight_init import _no_grad_trunc_normal_, trunc_normal_
monai/networks/layers/factories.py
@@ -70,7 +70,7 @@ import torch.nn as nn
  from monai.networks.utils import has_nvfuser_instance_norm
  from monai.utils import ComponentStore, look_up_option, optional_import

- __all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"]
+ __all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "RelPosEmbedding", "split_args"]


  class LayerFactory(ComponentStore):
@@ -201,6 +201,10 @@ Act = LayerFactory(name="Activation layers", description="Factory for creating a
  Conv = LayerFactory(name="Convolution layers", description="Factory for creating convolution layers.")
  Pool = LayerFactory(name="Pooling layers", description="Factory for creating pooling layers.")
  Pad = LayerFactory(name="Padding layers", description="Factory for creating padding layers.")
+ RelPosEmbedding = LayerFactory(
+     name="Relative positional embedding layers",
+     description="Factory for creating relative positional embedding factory",
+ )


  @Dropout.factory_function("dropout")
@@ -468,3 +472,10 @@ def constant_pad_factory(dim: int) -> type[nn.ConstantPad1d | nn.ConstantPad2d |
      """
      types = (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d)
      return types[dim - 1]
+
+
+ @RelPosEmbedding.factory_function("decomposed")
+ def decomposed_rel_pos_embedding() -> type[nn.Module]:
+     from monai.networks.blocks.rel_pos_embedding import DecomposedRelativePosEmbedding
+
+     return DecomposedRelativePosEmbedding
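The new factory can be queried like the existing Act/Norm/Conv factories; a small sketch (the returned class and its constructor live in rel_pos_embedding.py, which is not shown in this hunk):

    from monai.networks.layers.factories import RelPosEmbedding

    # the name registered above maps to the DecomposedRelativePosEmbedding class
    embedding_cls = RelPosEmbedding["decomposed"]
    print(embedding_cls.__name__)  # DecomposedRelativePosEmbedding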
monai/networks/layers/simplelayers.py
@@ -452,7 +452,7 @@ def get_binary_kernel(window_size: Sequence[int], dtype=torch.float, device=None

  def median_filter(
      in_tensor: torch.Tensor,
-     kernel_size: Sequence[int] = (3, 3, 3),
+     kernel_size: Sequence[int] | int = (3, 3, 3),
      spatial_dims: int = 3,
      kernel: torch.Tensor | None = None,
      **kwargs,
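A sketch of the widened signature; that a bare int is expanded over all spatial dims is implied by the new annotation rather than shown in this hunk, so treat the first call below as an assumption:

    import torch

    from monai.networks.layers import median_filter

    img = torch.rand(1, 1, 16, 16, 16)  # B x C x spatial dims
    out = median_filter(img, kernel_size=3, spatial_dims=3)  # assumed equivalent to (3, 3, 3)
    out_seq = median_filter(img, kernel_size=(3, 3, 3), spatial_dims=3)
    print(out.shape, out_seq.shape)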
monai/networks/layers/utils.py
@@ -11,9 +11,11 @@

  from __future__ import annotations

+ from typing import Optional
+
  import torch.nn

- from monai.networks.layers.factories import Act, Dropout, Norm, Pool, split_args
+ from monai.networks.layers.factories import Act, Dropout, Norm, Pool, RelPosEmbedding, split_args
  from monai.utils import has_option

  __all__ = ["get_norm_layer", "get_act_layer", "get_dropout_layer", "get_pool_layer"]
@@ -124,3 +126,14 @@ def get_pool_layer(name: tuple | str, spatial_dims: int | None = 1):
      pool_name, pool_args = split_args(name)
      pool_type = Pool[pool_name, spatial_dims]
      return pool_type(**pool_args)
+
+
+ def get_rel_pos_embedding_layer(name: tuple | str, s_input_dims: Optional[tuple], c_dim: int, num_heads: int):
+     embedding_name, embedding_args = split_args(name)
+     embedding_type = RelPosEmbedding[embedding_name]
+     # create a dictionary with the default values which can be overridden by embedding_args
+     kw_args = {"s_input_dims": s_input_dims, "c_dim": c_dim, "num_heads": num_heads, **embedding_args}
+     # filter out unused argument names
+     kw_args = {k: v for k, v in kw_args.items() if has_option(embedding_type, k)}
+
+     return embedding_type(**kw_args)
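A sketch of the new helper; the keyword names follow the kw_args dictionary above, while the concrete values (a 14x14 token grid, 64 channels per head, 8 heads) are illustrative assumptions:

    from monai.networks.layers.utils import get_rel_pos_embedding_layer

    # hypothetical values: a 14x14 token grid, 64 channels per head, 8 attention heads
    rel_pos = get_rel_pos_embedding_layer("decomposed", s_input_dims=(14, 14), c_dim=64, num_heads=8)
    print(type(rel_pos).__name__)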
monai/networks/layers/vector_quantizer.py
@@ -0,0 +1,233 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ # http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ from typing import Sequence, Tuple
+
+ import torch
+ from torch import nn
+
+ __all__ = ["VectorQuantizer", "EMAQuantizer"]
+
+
+ class EMAQuantizer(nn.Module):
+     """
+     Vector Quantization module using Exponential Moving Average (EMA) to learn the codebook parameters based on Neural
+     Discrete Representation Learning by Oord et al. (https://arxiv.org/abs/1711.00937) and the official implementation
+     that can be found at https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py#L148 and commit
+     58d9a2746493717a7c9252938da7efa6006f3739.
+
+     This module is not compatible with TorchScript while working in a Distributed Data Parallelism Module. This is due
+     to lack of TorchScript support for torch.distributed module as per https://github.com/pytorch/pytorch/issues/41353
+     on 22/10/2022. If you want to TorchScript your model, please set `ddp_sync` to False.
+
+     Args:
+         spatial_dims: number of spatial dimensions of the input.
+         num_embeddings: number of atomic elements in the codebook.
+         embedding_dim: number of channels of the input and atomic elements.
+         commitment_cost: scaling factor of the MSE loss between input and its quantized version. Defaults to 0.25.
+         decay: EMA decay. Defaults to 0.99.
+         epsilon: epsilon value. Defaults to 1e-5.
+         embedding_init: initialization method for the codebook. Defaults to "normal".
+         ddp_sync: whether to synchronize the codebook across processes. Defaults to True.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         num_embeddings: int,
+         embedding_dim: int,
+         commitment_cost: float = 0.25,
+         decay: float = 0.99,
+         epsilon: float = 1e-5,
+         embedding_init: str = "normal",
+         ddp_sync: bool = True,
+     ):
+         super().__init__()
+         self.spatial_dims: int = spatial_dims
+         self.embedding_dim: int = embedding_dim
+         self.num_embeddings: int = num_embeddings
+
+         assert self.spatial_dims in [2, 3], ValueError(
+             f"EMAQuantizer only supports 4D and 5D tensor inputs but received spatial dims {spatial_dims}."
+         )
+
+         self.embedding: torch.nn.Embedding = torch.nn.Embedding(self.num_embeddings, self.embedding_dim)
+         if embedding_init == "normal":
+             # Initialization is passed since the default one is normal inside the nn.Embedding
+             pass
+         elif embedding_init == "kaiming_uniform":
+             torch.nn.init.kaiming_uniform_(self.embedding.weight.data, mode="fan_in", nonlinearity="linear")
+         self.embedding.weight.requires_grad = False
+
+         self.commitment_cost: float = commitment_cost
+
+         self.register_buffer("ema_cluster_size", torch.zeros(self.num_embeddings))
+         self.register_buffer("ema_w", self.embedding.weight.data.clone())
+         # declare types for mypy
+         self.ema_cluster_size: torch.Tensor
+         self.ema_w: torch.Tensor
+         self.decay: float = decay
+         self.epsilon: float = epsilon
+
+         self.ddp_sync: bool = ddp_sync
+
+         # Precalculating required permutation shapes
+         self.flatten_permutation = [0] + list(range(2, self.spatial_dims + 2)) + [1]
+         self.quantization_permutation: Sequence[int] = [0, self.spatial_dims + 1] + list(
+             range(1, self.spatial_dims + 1)
+         )
+
+     def quantize(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """
+         Given an input it projects it to the quantized space and returns additional tensors needed for EMA loss.
+
+         Args:
+             inputs: Encoding space tensors of shape [B, C, H, W, D].
+
+         Returns:
+             torch.Tensor: Flatten version of the input of shape [B*H*W*D, C].
+             torch.Tensor: One-hot representation of the quantization indices of shape [B*H*W*D, self.num_embeddings].
+             torch.Tensor: Quantization indices of shape [B,H,W,D,1]
+
+         """
+         with torch.cuda.amp.autocast(enabled=False):
+             encoding_indices_view = list(inputs.shape)
+             del encoding_indices_view[1]
+
+             inputs = inputs.float()
+
+             # Converting to channel last format
+             flat_input = inputs.permute(self.flatten_permutation).contiguous().view(-1, self.embedding_dim)
+
+             # Calculate Euclidean distances
+             distances = (
+                 (flat_input**2).sum(dim=1, keepdim=True)
+                 + (self.embedding.weight.t() ** 2).sum(dim=0, keepdim=True)
+                 - 2 * torch.mm(flat_input, self.embedding.weight.t())
+             )
+
+             # Mapping distances to indexes
+             encoding_indices = torch.max(-distances, dim=1)[1]
+             encodings = torch.nn.functional.one_hot(encoding_indices, self.num_embeddings).float()
+
+             # Quantize and reshape
+             encoding_indices = encoding_indices.view(encoding_indices_view)
+
+         return flat_input, encodings, encoding_indices
+
+     def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor:
+         """
+         Given encoding indices of shape [B,D,H,W,1] embeds them in the quantized space
+         [B, D, H, W, self.embedding_dim] and reshapes them to [B, self.embedding_dim, D, H, W] to be fed to the
+         decoder.
+
+         Args:
+             embedding_indices: Tensor in channel last format which holds indices referencing atomic
+                 elements from self.embedding
+
+         Returns:
+             torch.Tensor: Quantize space representation of encoding_indices in channel first format.
+         """
+         with torch.cuda.amp.autocast(enabled=False):
+             embedding: torch.Tensor = (
+                 self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous()
+             )
+             return embedding
+
+     def distributed_synchronization(self, encodings_sum: torch.Tensor, dw: torch.Tensor) -> None:
+         """
+         TorchScript does not support torch.distributed.all_reduce. This function is a bypassing trick based on the
+         example: https://pytorch.org/docs/stable/generated/torch.jit.unused.html#torch.jit.unused
+
+         Args:
+             encodings_sum: The summation of one hot representation of what encoding was used for each
+                 position.
+             dw: The multiplication of the one hot representation of what encoding was used for each
+                 position with the flattened input.
+
+         Returns:
+             None
+         """
+         if self.ddp_sync and torch.distributed.is_initialized():
+             torch.distributed.all_reduce(tensor=encodings_sum, op=torch.distributed.ReduceOp.SUM)
+             torch.distributed.all_reduce(tensor=dw, op=torch.distributed.ReduceOp.SUM)
+         else:
+             pass
+
+     def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         flat_input, encodings, encoding_indices = self.quantize(inputs)
+         quantized = self.embed(encoding_indices)
+
+         # Use EMA to update the embedding vectors
+         if self.training:
+             with torch.no_grad():
+                 encodings_sum = encodings.sum(0)
+                 dw = torch.mm(encodings.t(), flat_input)
+
+                 if self.ddp_sync:
+                     self.distributed_synchronization(encodings_sum, dw)
+
+                 self.ema_cluster_size.data.mul_(self.decay).add_(torch.mul(encodings_sum, 1 - self.decay))
+
+                 # Laplace smoothing of the cluster size
+                 n = self.ema_cluster_size.sum()
+                 weights = (self.ema_cluster_size + self.epsilon) / (n + self.num_embeddings * self.epsilon) * n
+                 self.ema_w.data.mul_(self.decay).add_(torch.mul(dw, 1 - self.decay))
+                 self.embedding.weight.data.copy_(self.ema_w / weights.unsqueeze(1))
+
+         # Encoding Loss
+         loss = self.commitment_cost * torch.nn.functional.mse_loss(quantized.detach(), inputs)
+
+         # Straight Through Estimator
+         quantized = inputs + (quantized - inputs).detach()
+
+         return quantized, loss, encoding_indices
+
+
+ class VectorQuantizer(torch.nn.Module):
+     """
+     Vector Quantization wrapper that is needed as a workaround for the AMP to isolate the non fp16 compatible parts of
+     the quantization in their own class.
+
+     Args:
+         quantizer (torch.nn.Module): Quantizer module that needs to return its quantized representation, loss and index
+             based quantized representation.
+     """
+
+     def __init__(self, quantizer: EMAQuantizer):
+         super().__init__()
+
+         self.quantizer: EMAQuantizer = quantizer
+
+         self.perplexity: torch.Tensor = torch.rand(1)
+
+     def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         quantized, loss, encoding_indices = self.quantizer(inputs)
+         # Perplexity calculations
+         avg_probs = (
+             torch.histc(encoding_indices.float(), bins=self.quantizer.num_embeddings, max=self.quantizer.num_embeddings)
+             .float()
+             .div(encoding_indices.numel())
+         )
+
+         self.perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+
+         return loss, quantized
+
+     def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor:
+         return self.quantizer.embed(embedding_indices=embedding_indices)
+
+     def quantize(self, encodings: torch.Tensor) -> torch.Tensor:
+         output = self.quantizer(encodings)
+         encoding_indices: torch.Tensor = output[2]
+         return encoding_indices
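A short usage sketch wiring the two new classes together, following the docstrings and forward signatures above; batch size, codebook size and spatial shape are illustrative:

    import torch

    from monai.networks.layers.vector_quantizer import EMAQuantizer, VectorQuantizer

    # the AMP-safe wrapper holds the EMA codebook, as described in the docstrings above
    quantizer = VectorQuantizer(EMAQuantizer(spatial_dims=2, num_embeddings=128, embedding_dim=32))
    z = torch.randn(4, 32, 16, 16)   # encoder output: B x embedding_dim x H x W
    loss, quantized = quantizer(z)   # commitment loss and straight-through quantized tensor
    indices = quantizer.quantize(z)  # codebook indices of shape B x H x W
    print(quantized.shape, indices.shape, quantizer.perplexity.item())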
monai/networks/nets/__init__.py
@@ -14,9 +14,11 @@ from __future__ import annotations
  from .ahnet import AHnet, Ahnet, AHNet
  from .attentionunet import AttentionUnet
  from .autoencoder import AutoEncoder
+ from .autoencoderkl import AutoencoderKL
  from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet
  from .basic_unetplusplus import BasicUNetPlusPlus, BasicUnetPlusPlus, BasicunetPlusPlus, basicunetplusplus
  from .classifier import Classifier, Critic, Discriminator
+ from .controlnet import ControlNet
  from .daf3d import DAF3D
  from .densenet import (
      DenseNet,
@@ -34,6 +36,7 @@ from .densenet import (
      densenet201,
      densenet264,
  )
+ from .diffusion_model_unet import DiffusionModelUNet
  from .dints import DiNTS, TopologyConstruction, TopologyInstance, TopologySearch
  from .dynunet import DynUNet, DynUnet, Dynunet
  from .efficientnet import (
@@ -52,6 +55,7 @@ from .highresnet import HighResBlock, HighResNet
  from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet
  from .milmodel import MILModel
  from .netadapter import NetAdapter
+ from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator
  from .quicknat import Quicknat
  from .regressor import Regressor
  from .regunet import GlobalNet, LocalNet, RegUNet
@@ -104,9 +108,13 @@ from .senet import (
      seresnext50,
      seresnext101,
  )
+ from .spade_autoencoderkl import SPADEAutoencoderKL
+ from .spade_diffusion_model_unet import SPADEDiffusionModelUNet
+ from .spade_network import SPADENet
  from .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR
  from .torchvision_fc import TorchVisionFCModel
  from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex
+ from .transformer import DecoderOnlyTransformer
  from .unet import UNet, Unet
  from .unetr import UNETR
  from .varautoencoder import VarAutoEncoder
@@ -114,3 +122,4 @@ from .vit import ViT
  from .vitautoenc import ViTAutoEnc
  from .vnet import VNet
  from .voxelmorph import VoxelMorph, VoxelMorphUNet
+ from .vqvae import VQVAE
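A quick sketch of the newly exported names (taken directly from the imports added above); it only checks that the generative-model networks are importable from the core package:

    from monai.networks.nets import (
        VQVAE,
        AutoencoderKL,
        ControlNet,
        DecoderOnlyTransformer,
        DiffusionModelUNet,
        MultiScalePatchDiscriminator,
        PatchDiscriminator,
        SPADEAutoencoderKL,
        SPADEDiffusionModelUNet,
        SPADENet,
    )

    print(DiffusionModelUNet.__module__)  # monai.networks.nets.diffusion_model_unet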