monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2430__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/auto3dseg/hpo_gen.py +1 -1
- monai/apps/detection/utils/anchor_utils.py +2 -2
- monai/apps/pathology/transforms/post/array.py +7 -4
- monai/auto3dseg/analyzer.py +1 -1
- monai/bundle/scripts.py +204 -22
- monai/bundle/utils.py +1 -0
- monai/data/dataset_summary.py +1 -0
- monai/data/meta_tensor.py +2 -2
- monai/data/test_time_augmentation.py +2 -0
- monai/data/utils.py +9 -6
- monai/data/wsi_reader.py +2 -2
- monai/engines/__init__.py +3 -1
- monai/engines/trainer.py +281 -2
- monai/engines/utils.py +76 -1
- monai/handlers/mlflow_handler.py +21 -4
- monai/inferers/__init__.py +5 -0
- monai/inferers/inferer.py +1279 -1
- monai/metrics/cumulative_average.py +2 -0
- monai/metrics/panoptic_quality.py +1 -1
- monai/metrics/rocauc.py +2 -2
- monai/networks/blocks/__init__.py +3 -0
- monai/networks/blocks/attention_utils.py +128 -0
- monai/networks/blocks/crossattention.py +168 -0
- monai/networks/blocks/rel_pos_embedding.py +56 -0
- monai/networks/blocks/selfattention.py +74 -5
- monai/networks/blocks/spade_norm.py +95 -0
- monai/networks/blocks/spatialattention.py +82 -0
- monai/networks/blocks/transformerblock.py +25 -4
- monai/networks/blocks/upsample.py +22 -10
- monai/networks/layers/__init__.py +2 -1
- monai/networks/layers/factories.py +12 -1
- monai/networks/layers/simplelayers.py +1 -1
- monai/networks/layers/utils.py +14 -1
- monai/networks/layers/vector_quantizer.py +233 -0
- monai/networks/nets/__init__.py +9 -0
- monai/networks/nets/autoencoderkl.py +702 -0
- monai/networks/nets/controlnet.py +465 -0
- monai/networks/nets/diffusion_model_unet.py +1913 -0
- monai/networks/nets/patchgan_discriminator.py +230 -0
- monai/networks/nets/quicknat.py +8 -6
- monai/networks/nets/resnet.py +3 -4
- monai/networks/nets/spade_autoencoderkl.py +480 -0
- monai/networks/nets/spade_diffusion_model_unet.py +934 -0
- monai/networks/nets/spade_network.py +435 -0
- monai/networks/nets/swin_unetr.py +4 -3
- monai/networks/nets/transformer.py +157 -0
- monai/networks/nets/vqvae.py +472 -0
- monai/networks/schedulers/__init__.py +17 -0
- monai/networks/schedulers/ddim.py +294 -0
- monai/networks/schedulers/ddpm.py +250 -0
- monai/networks/schedulers/pndm.py +316 -0
- monai/networks/schedulers/scheduler.py +205 -0
- monai/networks/utils.py +22 -0
- monai/transforms/croppad/array.py +8 -8
- monai/transforms/croppad/dictionary.py +4 -4
- monai/transforms/croppad/functional.py +1 -1
- monai/transforms/regularization/array.py +4 -0
- monai/transforms/spatial/array.py +1 -1
- monai/transforms/utils_create_transform_ims.py +2 -4
- monai/utils/__init__.py +1 -0
- monai/utils/misc.py +5 -4
- monai/utils/ordering.py +207 -0
- monai/visualize/class_activation_maps.py +5 -5
- monai/visualize/img2tensorboard.py +3 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/METADATA +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/RECORD +71 -50
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/WHEEL +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/top_level.txt +0 -0
monai/networks/nets/spade_network.py (new file)
@@ -0,0 +1,435 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Sequence

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from monai.networks.blocks import Convolution
from monai.networks.blocks.spade_norm import SPADE
from monai.networks.layers import Act
from monai.networks.layers.utils import get_act_layer
from monai.utils.enums import StrEnum

__all__ = ["SPADENet"]


class UpsamplingModes(StrEnum):
    bicubic = "bicubic"
    nearest = "nearest"
    bilinear = "bilinear"


class SPADENetResBlock(nn.Module):
    """
    Creates a Residual Block with SPADE normalisation.

    Args:
        spatial_dims: number of spatial dimensions
        in_channels: number of input channels
        out_channels: number of output channels
        label_nc: number of semantic channels that will be taken into account in SPADE normalisation blocks
        spade_intermediate_channels: number of intermediate channels in the middle conv. layers in SPADE normalisation blocks
        norm: base normalisation type used on top of SPADE
        act: activation layer type
        kernel_size: convolutional kernel size
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        label_nc: int,
        spade_intermediate_channels: int = 128,
        norm: str | tuple = "INSTANCE",
        act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}),
        kernel_size: int = 3,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.int_channels = min(self.in_channels, self.out_channels)
        self.learned_shortcut = self.in_channels != self.out_channels
        self.conv_0 = Convolution(
            spatial_dims=spatial_dims, in_channels=self.in_channels, out_channels=self.int_channels, act=None, norm=None
        )
        self.conv_1 = Convolution(
            spatial_dims=spatial_dims,
            in_channels=self.int_channels,
            out_channels=self.out_channels,
            act=None,
            norm=None,
        )
        self.activation = get_act_layer(act)
        self.norm_0 = SPADE(
            label_nc=label_nc,
            norm_nc=self.in_channels,
            kernel_size=kernel_size,
            spatial_dims=spatial_dims,
            hidden_channels=spade_intermediate_channels,
            norm=norm,
        )
        self.norm_1 = SPADE(
            label_nc=label_nc,
            norm_nc=self.int_channels,
            kernel_size=kernel_size,
            spatial_dims=spatial_dims,
            hidden_channels=spade_intermediate_channels,
            norm=norm,
        )

        if self.learned_shortcut:
            self.conv_s = Convolution(
                spatial_dims=spatial_dims,
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                act=None,
                norm=None,
                kernel_size=1,
            )
            self.norm_s = SPADE(
                label_nc=label_nc,
                norm_nc=self.in_channels,
                kernel_size=kernel_size,
                spatial_dims=spatial_dims,
                hidden_channels=spade_intermediate_channels,
                norm=norm,
            )

    def forward(self, x, seg):
        x_s = self.shortcut(x, seg)
        dx = self.conv_0(self.activation(self.norm_0(x, seg)))
        dx = self.conv_1(self.activation(self.norm_1(dx, seg)))
        out = x_s + dx
        return out

    def shortcut(self, x, seg):
        if self.learned_shortcut:
            x_s = self.conv_s(self.norm_s(x, seg))
        else:
            x_s = x
        return x_s


class SPADEEncoder(nn.Module):
    """
    Encoding branch of a VAE compatible with a SPADE-like generator.

    Args:
        spatial_dims: number of spatial dimensions
        in_channels: number of input channels
        z_dim: latent space dimension of the VAE containing the image style information
        channels: number of output channels after each downsampling block
        input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers
            of the autoencoder (HxWx[D])
        kernel_size: convolutional kernel size
        norm: normalisation layer type
        act: activation type
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        z_dim: int,
        channels: Sequence[int],
        input_shape: Sequence[int],
        kernel_size: int = 3,
        norm: str | tuple = "INSTANCE",
        act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}),
    ):
        super().__init__()
        self.in_channels = in_channels
        self.z_dim = z_dim
        self.channels = channels
        if len(input_shape) != spatial_dims:
            raise ValueError("Length of parameter input_shape must match spatial_dims; got %s" % (input_shape,))
        for s_ind, s_ in enumerate(input_shape):
            if s_ / (2 ** len(channels)) != s_ // (2 ** len(channels)):
                raise ValueError(
                    "Each dimension of your input must be divisible by 2 ** (autoencoder depth). "
                    "The shape in position %d, %d, is not divisible by %d." % (s_ind, s_, len(channels))
                )
        self.input_shape = input_shape
        self.latent_spatial_shape = [s_ // (2 ** len(self.channels)) for s_ in self.input_shape]
        blocks = []
        ch_init = self.in_channels
        for _, ch_value in enumerate(channels):
            blocks.append(
                Convolution(
                    spatial_dims=spatial_dims,
                    in_channels=ch_init,
                    out_channels=ch_value,
                    strides=2,
                    kernel_size=kernel_size,
                    norm=norm,
                    act=act,
                )
            )
            ch_init = ch_value

        self.blocks = nn.ModuleList(blocks)
        self.fc_mu = nn.Linear(
            in_features=np.prod(self.latent_spatial_shape) * self.channels[-1], out_features=self.z_dim
        )
        self.fc_var = nn.Linear(
            in_features=np.prod(self.latent_spatial_shape) * self.channels[-1], out_features=self.z_dim
        )

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        logvar = self.fc_var(x)
        return mu, logvar

    def encode(self, x):
        for block in self.blocks:
            x = block(x)
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        logvar = self.fc_var(x)
        return self.reparameterize(mu, logvar)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps.mul(std) + mu


class SPADEDecoder(nn.Module):
    """
    Decoder branch of a SPADE-like generator. It can be used independently, without an encoding branch,
    behaving like a GAN, or coupled to a SPADE encoder.

    Args:
        spatial_dims: number of spatial dimensions
        out_channels: number of output channels
        label_nc: number of semantic channels used for the SPADE normalisation blocks
        input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers
        channels: number of output channels after each upsampling block
        z_dim: latent space dimension of the VAE containing the image style information (None if encoder is not used)
        is_vae: whether the decoder is going to be coupled to an autoencoder or not (true: yes, false: no)
        spade_intermediate_channels: number of channels in the intermediate layers of the SPADE normalisation blocks
        norm: base normalisation type
        act: activation layer type
        last_act: activation layer type for the last layer of the network (can differ from previous)
        kernel_size: convolutional kernel size
        upsampling_mode: upsampling mode (nearest, bilinear etc.)
    """

    def __init__(
        self,
        spatial_dims: int,
        out_channels: int,
        label_nc: int,
        input_shape: Sequence[int],
        channels: list[int],
        z_dim: int | None = None,
        is_vae: bool = True,
        spade_intermediate_channels: int = 128,
        norm: str | tuple = "INSTANCE",
        act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}),
        last_act: str | tuple | None = (Act.LEAKYRELU, {"negative_slope": 0.2}),
        kernel_size: int = 3,
        upsampling_mode: str = UpsamplingModes.nearest.value,
    ):
        super().__init__()
        self.is_vae = is_vae
        self.out_channels = out_channels
        self.label_nc = label_nc
        self.num_channels = channels
        if len(input_shape) != spatial_dims:
            raise ValueError("Length of parameter input_shape must match spatial_dims; got %s" % (input_shape,))
        for s_ind, s_ in enumerate(input_shape):
            if s_ / (2 ** len(channels)) != s_ // (2 ** len(channels)):
                raise ValueError(
                    "Each dimension of your input must be divisible by 2 ** (autoencoder depth). "
                    "The shape in position %d, %d, is not divisible by %d." % (s_ind, s_, len(channels))
                )
        self.latent_spatial_shape = [s_ // (2 ** len(self.num_channels)) for s_ in input_shape]

        if not self.is_vae:
            self.conv_init = Convolution(
                spatial_dims=spatial_dims, in_channels=label_nc, out_channels=channels[0], kernel_size=kernel_size
            )
        elif self.is_vae and z_dim is None:
            raise ValueError(
                "If the network is used in VAE-GAN mode, parameter z_dim "
                "(number of latent channels in the VAE) must be populated."
            )
        else:
            self.fc = nn.Linear(z_dim, np.prod(self.latent_spatial_shape) * channels[0])

        self.z_dim = z_dim
        blocks = []
        channels.append(self.out_channels)
        self.upsampling = torch.nn.Upsample(scale_factor=2, mode=upsampling_mode)
        for ch_ind, ch_value in enumerate(channels[:-1]):
            blocks.append(
                SPADENetResBlock(
                    spatial_dims=spatial_dims,
                    in_channels=ch_value,
                    out_channels=channels[ch_ind + 1],
                    label_nc=label_nc,
                    spade_intermediate_channels=spade_intermediate_channels,
                    norm=norm,
                    kernel_size=kernel_size,
                    act=act,
                )
            )

        self.blocks = torch.nn.ModuleList(blocks)
        self.last_conv = Convolution(
            spatial_dims=spatial_dims,
            in_channels=channels[-1],
            out_channels=out_channels,
            padding=(kernel_size - 1) // 2,
            kernel_size=kernel_size,
            norm=None,
            act=last_act,
        )

    def forward(self, seg, z: torch.Tensor | None = None):
        """
        Args:
            seg: input BxCxHxW[xD] semantic map on which the output is conditioned
            z: latent vector output by the encoder if self.is_vae is True. When is_vae is
                False, z is a random noise vector.

        Returns:
            decoded tensor conditioned on seg.
        """
        if not self.is_vae:
            x = F.interpolate(seg, size=tuple(self.latent_spatial_shape))
            x = self.conv_init(x)
        else:
            if (
                z is None and self.z_dim is not None
            ):  # even though this network is a VAE (self.is_vae), you should be able to sample from noise as well
                z = torch.randn(seg.size(0), self.z_dim, dtype=torch.float32, device=seg.get_device())
            x = self.fc(z)
            x = x.view(*[-1, self.num_channels[0]] + self.latent_spatial_shape)

        for res_block in self.blocks:
            x = res_block(x, seg)
            x = self.upsampling(x)

        x = self.last_conv(x)
        return x


class SPADENet(nn.Module):
    """
    SPADE Network, implemented based on the code by Park, T. et al. in
    "Semantic Image Synthesis with Spatially-Adaptive Normalization"
    (https://github.com/NVlabs/SPADE).

    Args:
        spatial_dims: number of spatial dimensions
        in_channels: number of input channels
        out_channels: number of output channels
        label_nc: number of semantic channels used for the SPADE normalisation blocks
        input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers
        channels: number of output channels after each downsampling block
        z_dim: latent space dimension of the VAE containing the image style information (None if encoder is not used)
        is_vae: whether the decoder is going to be coupled to an autoencoder (true) or not (false)
        spade_intermediate_channels: number of channels in the intermediate layers of the SPADE normalisation blocks
        norm: base normalisation type
        act: activation layer type
        last_act: activation layer type for the last layer of the network (can differ from previous)
        kernel_size: convolutional kernel size
        upsampling_mode: upsampling mode (nearest, bilinear etc.)
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        label_nc: int,
        input_shape: Sequence[int],
        channels: list[int],
        z_dim: int | None = None,
        is_vae: bool = True,
        spade_intermediate_channels: int = 128,
        norm: str | tuple = "INSTANCE",
        act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}),
        last_act: str | tuple | None = (Act.LEAKYRELU, {"negative_slope": 0.2}),
        kernel_size: int = 3,
        upsampling_mode: str = UpsamplingModes.nearest.value,
    ):
        super().__init__()
        self.is_vae = is_vae
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.channels = channels
        self.label_nc = label_nc
        self.input_shape = input_shape

        if self.is_vae:
            if z_dim is None:
                raise ValueError("The latent space dimension mapped by parameter z_dim cannot be None if is_vae is True.")
            else:
                self.encoder = SPADEEncoder(
                    spatial_dims=spatial_dims,
                    in_channels=in_channels,
                    z_dim=z_dim,
                    channels=channels,
                    input_shape=input_shape,
                    kernel_size=kernel_size,
                    norm=norm,
                    act=act,
                )

        decoder_channels = channels
        decoder_channels.reverse()

        self.decoder = SPADEDecoder(
            spatial_dims=spatial_dims,
            out_channels=out_channels,
            label_nc=label_nc,
            input_shape=input_shape,
            channels=decoder_channels,
            z_dim=z_dim,
            is_vae=is_vae,
            spade_intermediate_channels=spade_intermediate_channels,
            norm=norm,
            act=act,
            last_act=last_act,
            kernel_size=kernel_size,
            upsampling_mode=upsampling_mode,
        )

    def forward(self, seg: torch.Tensor, x: torch.Tensor | None = None):
        z = None
        if self.is_vae:
            z_mu, z_logvar = self.encoder(x)
            z = self.encoder.reparameterize(z_mu, z_logvar)
            return self.decoder(seg, z), z_mu, z_logvar
        else:
            return (self.decoder(seg, z),)

    def encode(self, x: torch.Tensor):
        if self.is_vae:
            return self.encoder.encode(x)
        else:
            return None

    def decode(self, seg: torch.Tensor, z: torch.Tensor | None = None):
        return self.decoder(seg, z)
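The new SPADENet ties the encoder and decoder above into a single module. As a rough illustration of how the pieces fit together, the following sketch (not part of the diff) instantiates a small 2-D VAE-GAN configuration; the channel counts, input shape and latent size are illustrative assumptions, not values taken from the package.

import torch

from monai.networks.nets.spade_network import SPADENet

# Small 2-D VAE-GAN configuration (illustrative sizes): 3 semantic channels,
# 64x64 inputs, two downsampling levels, 16-dimensional style latent.
net = SPADENet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    label_nc=3,
    input_shape=(64, 64),
    channels=[32, 64],
    z_dim=16,
    is_vae=True,
)

image = torch.randn(2, 1, 64, 64)  # style image encoded by the VAE branch
seg = torch.rand(2, 3, 64, 64)     # semantic map the decoder is conditioned on
reconstruction, z_mu, z_logvar = net(seg, image)
print(reconstruction.shape)        # expected: torch.Size([2, 1, 64, 64])

With is_vae=True the forward pass returns the decoded image plus the latent mean and log-variance, so the usual VAE KL term can be computed alongside the adversarial and reconstruction losses.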
monai/networks/nets/swin_unetr.py
@@ -347,7 +347,7 @@ def window_partition(x, window_size):
         x: input tensor.
         window_size: local window size.
     """
-    x_shape = x.size()
+    x_shape = x.size()  # length 4 or 5 only
     if len(x_shape) == 5:
         b, d, h, w, c = x_shape
         x = x.view(
@@ -363,10 +363,11 @@ def window_partition(x, window_size):
         windows = (
             x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0] * window_size[1] * window_size[2], c)
         )
-
+    else:  # if len(x_shape) == 4:
         b, h, w, c = x.shape
         x = x.view(b, h // window_size[0], window_size[0], w // window_size[1], window_size[1], c)
         windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], c)
+
     return windows


@@ -613,7 +614,7 @@ class SwinTransformerBlock(nn.Module):
             _, dp, hp, wp, _ = x.shape
             dims = [b, dp, hp, wp]

-        elif len(x_shape) == 4:
+        else:  # elif len(x_shape) == 4
             b, h, w, c = x.shape
             window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
             pad_l = pad_t = 0
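For reference, window_partition operates on channels-last feature maps, and the 4-D branch touched above handles 2-D inputs. A quick sketch of that path (the shapes below are illustrative, not taken from the diff):

import torch

from monai.networks.nets.swin_unetr import window_partition

x = torch.randn(1, 14, 14, 48)                      # channels-last (b, h, w, c) feature map
windows = window_partition(x, window_size=(7, 7))   # exercises the 4-D branch edited above
print(windows.shape)                                # torch.Size([4, 49, 48]): 2x2 windows of 7x7 tokens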
monai/networks/nets/transformer.py (new file)
@@ -0,0 +1,157 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import torch
import torch.nn as nn

from monai.networks.blocks import TransformerBlock

__all__ = ["DecoderOnlyTransformer"]


class AbsolutePositionalEmbedding(nn.Module):
    """Absolute positional embedding.

    Args:
        max_seq_len: Maximum sequence length.
        embedding_dim: Dimensionality of the embedding.
    """

    def __init__(self, max_seq_len: int, embedding_dim: int) -> None:
        super().__init__()
        self.max_seq_len = max_seq_len
        self.embedding_dim = embedding_dim
        self.embedding = nn.Embedding(max_seq_len, embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len = x.size()
        positions = torch.arange(seq_len, device=x.device).repeat(batch_size, 1)
        embedding: torch.Tensor = self.embedding(positions)
        return embedding


class DecoderOnlyTransformer(nn.Module):
    """Decoder-only (autoregressive) Transformer model.

    Args:
        num_tokens: Number of tokens in the vocabulary.
        max_seq_len: Maximum sequence length.
        attn_layers_dim: Dimensionality of the attention layers.
        attn_layers_depth: Number of attention layers.
        attn_layers_heads: Number of attention heads.
        with_cross_attention: Whether to use cross attention for conditioning.
        embedding_dropout_rate: Dropout rate for the embedding.
    """

    def __init__(
        self,
        num_tokens: int,
        max_seq_len: int,
        attn_layers_dim: int,
        attn_layers_depth: int,
        attn_layers_heads: int,
        with_cross_attention: bool = False,
        embedding_dropout_rate: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_tokens = num_tokens
        self.max_seq_len = max_seq_len
        self.attn_layers_dim = attn_layers_dim
        self.attn_layers_depth = attn_layers_depth
        self.attn_layers_heads = attn_layers_heads
        self.with_cross_attention = with_cross_attention

        self.token_embeddings = nn.Embedding(num_tokens, attn_layers_dim)
        self.position_embeddings = AbsolutePositionalEmbedding(max_seq_len=max_seq_len, embedding_dim=attn_layers_dim)
        self.embedding_dropout = nn.Dropout(embedding_dropout_rate)

        self.blocks = nn.ModuleList(
            [
                TransformerBlock(
                    hidden_size=attn_layers_dim,
                    mlp_dim=attn_layers_dim * 4,
                    num_heads=attn_layers_heads,
                    dropout_rate=0.0,
                    qkv_bias=False,
                    causal=True,
                    sequence_length=max_seq_len,
                    with_cross_attention=with_cross_attention,
                )
                for _ in range(attn_layers_depth)
            ]
        )

        self.to_logits = nn.Linear(attn_layers_dim, num_tokens)

    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
        tok_emb = self.token_embeddings(x)
        pos_emb = self.position_embeddings(x)
        x = self.embedding_dropout(tok_emb + pos_emb)

        for block in self.blocks:
            x = block(x, context=context)
        logits: torch.Tensor = self.to_logits(x)
        return logits

    def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
        """
        Load a state dict from a DecoderOnlyTransformer trained with
        [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).

        Args:
            old_state_dict: state dict from the old DecoderOnlyTransformer model.
        """

        new_state_dict = self.state_dict()
        # if all keys match, just load the state dict
        if all(k in new_state_dict for k in old_state_dict):
            print("All keys match, loading state dict.")
            self.load_state_dict(old_state_dict)
            return

        if verbose:
            # print all new_state_dict keys that are not in old_state_dict
            for k in new_state_dict:
                if k not in old_state_dict:
                    print(f"key {k} not found in old state dict")
            # and vice versa
            print("----------------------------------------------")
            for k in old_state_dict:
                if k not in new_state_dict:
                    print(f"key {k} not found in new state dict")

        # copy over all matching keys
        for k in new_state_dict:
            if k in old_state_dict:
                new_state_dict[k] = old_state_dict[k]

        # fix the attention blocks: fuse the old separate q/k/v projections into a single qkv weight
        attention_blocks = [k.replace(".attn.qkv.weight", "") for k in new_state_dict if "attn.qkv.weight" in k]
        for block in attention_blocks:
            new_state_dict[f"{block}.attn.qkv.weight"] = torch.cat(
                [
                    old_state_dict[f"{block}.attn.to_q.weight"],
                    old_state_dict[f"{block}.attn.to_k.weight"],
                    old_state_dict[f"{block}.attn.to_v.weight"],
                ],
                dim=0,
            )

        # fix the renamed norm blocks: norm2 -> norm_cross_attn, norm3 -> norm2
        for k in old_state_dict:
            if "norm2" in k:
                new_state_dict[k.replace("norm2", "norm_cross_attn")] = old_state_dict[k]
            if "norm3" in k:
                new_state_dict[k.replace("norm3", "norm2")] = old_state_dict[k]

        self.load_state_dict(new_state_dict)
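DecoderOnlyTransformer consumes integer token sequences (for example, indices produced by the VQ-VAE added in this release) and returns next-token logits. A minimal sketch (not part of the diff), with an assumed vocabulary size and layer configuration:

import torch

from monai.networks.nets.transformer import DecoderOnlyTransformer

model = DecoderOnlyTransformer(
    num_tokens=32,        # e.g. a VQ-VAE codebook size (illustrative)
    max_seq_len=64,
    attn_layers_dim=96,   # must be divisible by attn_layers_heads
    attn_layers_depth=2,
    attn_layers_heads=8,
)

tokens = torch.randint(0, 32, (4, 64))  # batch of token index sequences
logits = model(tokens)                  # causal self-attention, no cross-attention context
print(logits.shape)                     # torch.Size([4, 64, 32])

For conditioned generation, constructing the model with with_cross_attention=True and passing a context tensor to forward routes the conditioning through the cross-attention path of each TransformerBlock.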