monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2430__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/auto3dseg/hpo_gen.py +1 -1
  4. monai/apps/detection/utils/anchor_utils.py +2 -2
  5. monai/apps/pathology/transforms/post/array.py +7 -4
  6. monai/auto3dseg/analyzer.py +1 -1
  7. monai/bundle/scripts.py +204 -22
  8. monai/bundle/utils.py +1 -0
  9. monai/data/dataset_summary.py +1 -0
  10. monai/data/meta_tensor.py +2 -2
  11. monai/data/test_time_augmentation.py +2 -0
  12. monai/data/utils.py +9 -6
  13. monai/data/wsi_reader.py +2 -2
  14. monai/engines/__init__.py +3 -1
  15. monai/engines/trainer.py +281 -2
  16. monai/engines/utils.py +76 -1
  17. monai/handlers/mlflow_handler.py +21 -4
  18. monai/inferers/__init__.py +5 -0
  19. monai/inferers/inferer.py +1279 -1
  20. monai/metrics/cumulative_average.py +2 -0
  21. monai/metrics/panoptic_quality.py +1 -1
  22. monai/metrics/rocauc.py +2 -2
  23. monai/networks/blocks/__init__.py +3 -0
  24. monai/networks/blocks/attention_utils.py +128 -0
  25. monai/networks/blocks/crossattention.py +168 -0
  26. monai/networks/blocks/rel_pos_embedding.py +56 -0
  27. monai/networks/blocks/selfattention.py +74 -5
  28. monai/networks/blocks/spade_norm.py +95 -0
  29. monai/networks/blocks/spatialattention.py +82 -0
  30. monai/networks/blocks/transformerblock.py +25 -4
  31. monai/networks/blocks/upsample.py +22 -10
  32. monai/networks/layers/__init__.py +2 -1
  33. monai/networks/layers/factories.py +12 -1
  34. monai/networks/layers/simplelayers.py +1 -1
  35. monai/networks/layers/utils.py +14 -1
  36. monai/networks/layers/vector_quantizer.py +233 -0
  37. monai/networks/nets/__init__.py +9 -0
  38. monai/networks/nets/autoencoderkl.py +702 -0
  39. monai/networks/nets/controlnet.py +465 -0
  40. monai/networks/nets/diffusion_model_unet.py +1913 -0
  41. monai/networks/nets/patchgan_discriminator.py +230 -0
  42. monai/networks/nets/quicknat.py +8 -6
  43. monai/networks/nets/resnet.py +3 -4
  44. monai/networks/nets/spade_autoencoderkl.py +480 -0
  45. monai/networks/nets/spade_diffusion_model_unet.py +934 -0
  46. monai/networks/nets/spade_network.py +435 -0
  47. monai/networks/nets/swin_unetr.py +4 -3
  48. monai/networks/nets/transformer.py +157 -0
  49. monai/networks/nets/vqvae.py +472 -0
  50. monai/networks/schedulers/__init__.py +17 -0
  51. monai/networks/schedulers/ddim.py +294 -0
  52. monai/networks/schedulers/ddpm.py +250 -0
  53. monai/networks/schedulers/pndm.py +316 -0
  54. monai/networks/schedulers/scheduler.py +205 -0
  55. monai/networks/utils.py +22 -0
  56. monai/transforms/croppad/array.py +8 -8
  57. monai/transforms/croppad/dictionary.py +4 -4
  58. monai/transforms/croppad/functional.py +1 -1
  59. monai/transforms/regularization/array.py +4 -0
  60. monai/transforms/spatial/array.py +1 -1
  61. monai/transforms/utils_create_transform_ims.py +2 -4
  62. monai/utils/__init__.py +1 -0
  63. monai/utils/misc.py +5 -4
  64. monai/utils/ordering.py +207 -0
  65. monai/visualize/class_activation_maps.py +5 -5
  66. monai/visualize/img2tensorboard.py +3 -1
  67. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/METADATA +1 -1
  68. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/RECORD +71 -50
  69. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/WHEEL +1 -1
  70. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/LICENSE +0 -0
  71. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,435 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ # http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ from typing import Sequence
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from monai.networks.blocks import Convolution
+ from monai.networks.blocks.spade_norm import SPADE
+ from monai.networks.layers import Act
+ from monai.networks.layers.utils import get_act_layer
+ from monai.utils.enums import StrEnum
+
+ __all__ = ["SPADENet"]
+
+
+ class UpsamplingModes(StrEnum):
+     bicubic = "bicubic"
+     nearest = "nearest"
+     bilinear = "bilinear"
+
+
+ class SPADENetResBlock(nn.Module):
+     """
+     Creates a Residual Block with SPADE normalisation.
+
+     Args:
+         spatial_dims: number of spatial dimensions
+         in_channels: number of input channels
+         out_channels: number of output channels
+         label_nc: number of semantic channels that will be taken into account in SPADE normalisation blocks
+         spade_intermediate_channels: number of intermediate channels in the middle conv. layers in SPADE normalisation blocks
+         norm: base normalisation type used on top of SPADE
+         act: activation layer type
+         kernel_size: convolutional kernel size
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         out_channels: int,
+         label_nc: int,
+         spade_intermediate_channels: int = 128,
+         norm: str | tuple = "INSTANCE",
+         act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}),
+         kernel_size: int = 3,
+     ):
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.int_channels = min(self.in_channels, self.out_channels)
+         self.learned_shortcut = self.in_channels != self.out_channels
+         self.conv_0 = Convolution(
+             spatial_dims=spatial_dims, in_channels=self.in_channels, out_channels=self.int_channels, act=None, norm=None
+         )
+         self.conv_1 = Convolution(
+             spatial_dims=spatial_dims,
+             in_channels=self.int_channels,
+             out_channels=self.out_channels,
+             act=None,
+             norm=None,
+         )
+         self.activation = get_act_layer(act)
+         self.norm_0 = SPADE(
+             label_nc=label_nc,
+             norm_nc=self.in_channels,
+             kernel_size=kernel_size,
+             spatial_dims=spatial_dims,
+             hidden_channels=spade_intermediate_channels,
+             norm=norm,
+         )
+         self.norm_1 = SPADE(
+             label_nc=label_nc,
+             norm_nc=self.int_channels,
+             kernel_size=kernel_size,
+             spatial_dims=spatial_dims,
+             hidden_channels=spade_intermediate_channels,
+             norm=norm,
+         )
+
+         if self.learned_shortcut:
+             self.conv_s = Convolution(
+                 spatial_dims=spatial_dims,
+                 in_channels=self.in_channels,
+                 out_channels=self.out_channels,
+                 act=None,
+                 norm=None,
+                 kernel_size=1,
+             )
+             self.norm_s = SPADE(
+                 label_nc=label_nc,
+                 norm_nc=self.in_channels,
+                 kernel_size=kernel_size,
+                 spatial_dims=spatial_dims,
+                 hidden_channels=spade_intermediate_channels,
+                 norm=norm,
+             )
+
+     def forward(self, x, seg):
+         x_s = self.shortcut(x, seg)
+         dx = self.conv_0(self.activation(self.norm_0(x, seg)))
+         dx = self.conv_1(self.activation(self.norm_1(dx, seg)))
+         out = x_s + dx
+         return out
+
+     def shortcut(self, x, seg):
+         if self.learned_shortcut:
+             x_s = self.conv_s(self.norm_s(x, seg))
+         else:
+             x_s = x
+         return x_s
+
+
+ class SPADEEncoder(nn.Module):
+     """
+     Encoding branch of a VAE compatible with a SPADE-like generator.
+
+     Args:
+         spatial_dims: number of spatial dimensions
+         in_channels: number of input channels
+         z_dim: latent space dimension of the VAE containing the image style information
+         channels: number of output channels after each downsampling block
+         input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers
+             of the autoencoder (HxWx[D])
+         kernel_size: convolutional kernel size
+         norm: normalisation layer type
+         act: activation type
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         z_dim: int,
+         channels: Sequence[int],
+         input_shape: Sequence[int],
+         kernel_size: int = 3,
+         norm: str | tuple = "INSTANCE",
+         act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}),
+     ):
+         super().__init__()
+         self.in_channels = in_channels
+         self.z_dim = z_dim
+         self.channels = channels
+         if len(input_shape) != spatial_dims:
+             raise ValueError("Length of parameter input_shape must match spatial_dims; got %s" % (input_shape,))
+         for s_ind, s_ in enumerate(input_shape):
+             if s_ / (2 ** len(channels)) != s_ // (2 ** len(channels)):
+                 raise ValueError(
+                     "Each dimension of your input must be divisible by 2 ** (autoencoder depth). "
+                     "The shape %d at position %d is not divisible by %d." % (s_, s_ind, 2 ** len(channels))
+                 )
+         self.input_shape = input_shape
+         self.latent_spatial_shape = [s_ // (2 ** len(self.channels)) for s_ in self.input_shape]
+         blocks = []
+         ch_init = self.in_channels
+         for _, ch_value in enumerate(channels):
+             blocks.append(
+                 Convolution(
+                     spatial_dims=spatial_dims,
+                     in_channels=ch_init,
+                     out_channels=ch_value,
+                     strides=2,
+                     kernel_size=kernel_size,
+                     norm=norm,
+                     act=act,
+                 )
+             )
+             ch_init = ch_value
+
+         self.blocks = nn.ModuleList(blocks)
+         self.fc_mu = nn.Linear(
+             in_features=np.prod(self.latent_spatial_shape) * self.channels[-1], out_features=self.z_dim
+         )
+         self.fc_var = nn.Linear(
+             in_features=np.prod(self.latent_spatial_shape) * self.channels[-1], out_features=self.z_dim
+         )
+
+     def forward(self, x):
+         for block in self.blocks:
+             x = block(x)
+         x = x.view(x.size(0), -1)
+         mu = self.fc_mu(x)
+         logvar = self.fc_var(x)
+         return mu, logvar
+
+     def encode(self, x):
+         for block in self.blocks:
+             x = block(x)
+         x = x.view(x.size(0), -1)
+         mu = self.fc_mu(x)
+         logvar = self.fc_var(x)
+         return self.reparameterize(mu, logvar)
+
+     def reparameterize(self, mu, logvar):
+         std = torch.exp(0.5 * logvar)
+         eps = torch.randn_like(std)
+         return eps.mul(std) + mu
+
+
+ class SPADEDecoder(nn.Module):
+     """
+     Decoder branch of a SPADE-like generator. It can be used independently, without an encoding branch,
+     behaving like a GAN, or coupled to a SPADE encoder.
+
+     Args:
+         spatial_dims: number of spatial dimensions
+         out_channels: number of output channels
+         label_nc: number of semantic channels used for the SPADE normalisation blocks
+         input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers
+         channels: number of output channels after each downsampling block
+         z_dim: latent space dimension of the VAE containing the image style information (None if encoder is not used)
+         is_vae: whether the decoder is going to be coupled to an autoencoder or not (true: yes, false: no)
+         spade_intermediate_channels: number of channels in the intermediate layers of the SPADE normalisation blocks
+         norm: base normalisation type
+         act: activation layer type
+         last_act: activation layer type for the last layer of the network (can differ from previous)
+         kernel_size: convolutional kernel size
+         upsampling_mode: upsampling mode (nearest, bilinear etc.)
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         out_channels: int,
+         label_nc: int,
+         input_shape: Sequence[int],
+         channels: list[int],
+         z_dim: int | None = None,
+         is_vae: bool = True,
+         spade_intermediate_channels: int = 128,
+         norm: str | tuple = "INSTANCE",
+         act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}),
+         last_act: str | tuple | None = (Act.LEAKYRELU, {"negative_slope": 0.2}),
+         kernel_size: int = 3,
+         upsampling_mode: str = UpsamplingModes.nearest.value,
+     ):
+         super().__init__()
+         self.is_vae = is_vae
+         self.out_channels = out_channels
+         self.label_nc = label_nc
+         self.num_channels = channels
+         if len(input_shape) != spatial_dims:
+             raise ValueError("Length of parameter input_shape must match spatial_dims; got %s" % (input_shape,))
+         for s_ind, s_ in enumerate(input_shape):
+             if s_ / (2 ** len(channels)) != s_ // (2 ** len(channels)):
+                 raise ValueError(
+                     "Each dimension of your input must be divisible by 2 ** (autoencoder depth). "
+                     "The shape %d at position %d is not divisible by %d." % (s_, s_ind, 2 ** len(channels))
+                 )
+         self.latent_spatial_shape = [s_ // (2 ** len(self.num_channels)) for s_ in input_shape]
+
+         if not self.is_vae:
+             self.conv_init = Convolution(
+                 spatial_dims=spatial_dims, in_channels=label_nc, out_channels=channels[0], kernel_size=kernel_size
+             )
+         elif self.is_vae and z_dim is None:
+             raise ValueError(
+                 "If the network is used in VAE-GAN mode, parameter z_dim "
+                 "(number of latent channels in the VAE) must be populated."
+             )
+         else:
+             self.fc = nn.Linear(z_dim, np.prod(self.latent_spatial_shape) * channels[0])
+
+         self.z_dim = z_dim
+         blocks = []
+         channels.append(self.out_channels)
+         self.upsampling = torch.nn.Upsample(scale_factor=2, mode=upsampling_mode)
+         for ch_ind, ch_value in enumerate(channels[:-1]):
+             blocks.append(
+                 SPADENetResBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=ch_value,
+                     out_channels=channels[ch_ind + 1],
+                     label_nc=label_nc,
+                     spade_intermediate_channels=spade_intermediate_channels,
+                     norm=norm,
+                     kernel_size=kernel_size,
+                     act=act,
+                 )
+             )
+
+         self.blocks = torch.nn.ModuleList(blocks)
+         self.last_conv = Convolution(
+             spatial_dims=spatial_dims,
+             in_channels=channels[-1],
+             out_channels=out_channels,
+             padding=(kernel_size - 1) // 2,
+             kernel_size=kernel_size,
+             norm=None,
+             act=last_act,
+         )
+
+     def forward(self, seg, z: torch.Tensor | None = None):
+         """
+         Args:
+             seg: input BxCxHxW[xD] semantic map on which the output is conditioned
+             z: latent vector output by the encoder if self.is_vae is True. When is_vae is
+                 False, z is a random noise vector.
+
+         Returns:
+             the generated image tensor
+         """
+         if not self.is_vae:
+             x = F.interpolate(seg, size=tuple(self.latent_spatial_shape))
+             x = self.conv_init(x)
+         else:
+             if (
+                 z is None and self.z_dim is not None
+             ):  # even though this network is a VAE (self.is_vae), you should be able to sample from noise as well
+                 z = torch.randn(seg.size(0), self.z_dim, dtype=torch.float32, device=seg.get_device())
+             x = self.fc(z)
+             x = x.view(*[-1, self.num_channels[0]] + self.latent_spatial_shape)
+
+         for res_block in self.blocks:
+             x = res_block(x, seg)
+             x = self.upsampling(x)
+
+         x = self.last_conv(x)
+         return x
+
+
+ class SPADENet(nn.Module):
+     """
+     SPADE Network, implemented based on the code by Park, T. et al. in
+     "Semantic Image Synthesis with Spatially-Adaptive Normalization"
+     (https://github.com/NVlabs/SPADE).
+
+     Args:
+         spatial_dims: number of spatial dimensions
+         in_channels: number of input channels
+         out_channels: number of output channels
+         label_nc: number of semantic channels used for the SPADE normalisation blocks
+         input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers
+         channels: number of output channels after each downsampling block
+         z_dim: latent space dimension of the VAE containing the image style information (None if encoder is not used)
+         is_vae: whether the decoder is going to be coupled to an autoencoder (true) or not (false)
+         spade_intermediate_channels: number of channels in the intermediate layers of the SPADE normalisation blocks
+         norm: base normalisation type
+         act: activation layer type
+         last_act: activation layer type for the last layer of the network (can differ from previous)
+         kernel_size: convolutional kernel size
+         upsampling_mode: upsampling mode (nearest, bilinear etc.)
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         out_channels: int,
+         label_nc: int,
+         input_shape: Sequence[int],
+         channels: list[int],
+         z_dim: int | None = None,
+         is_vae: bool = True,
+         spade_intermediate_channels: int = 128,
+         norm: str | tuple = "INSTANCE",
+         act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}),
+         last_act: str | tuple | None = (Act.LEAKYRELU, {"negative_slope": 0.2}),
+         kernel_size: int = 3,
+         upsampling_mode: str = UpsamplingModes.nearest.value,
+     ):
+         super().__init__()
+         self.is_vae = is_vae
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.channels = channels
+         self.label_nc = label_nc
+         self.input_shape = input_shape
+
+         if self.is_vae:
+             if z_dim is None:
+                 raise ValueError("The latent space dimension mapped by parameter z_dim cannot be None if is_vae is True.")
+             else:
+                 self.encoder = SPADEEncoder(
+                     spatial_dims=spatial_dims,
+                     in_channels=in_channels,
+                     z_dim=z_dim,
+                     channels=channels,
+                     input_shape=input_shape,
+                     kernel_size=kernel_size,
+                     norm=norm,
+                     act=act,
+                 )
+
+         decoder_channels = channels
+         decoder_channels.reverse()
+
+         self.decoder = SPADEDecoder(
+             spatial_dims=spatial_dims,
+             out_channels=out_channels,
+             label_nc=label_nc,
+             input_shape=input_shape,
+             channels=decoder_channels,
+             z_dim=z_dim,
+             is_vae=is_vae,
+             spade_intermediate_channels=spade_intermediate_channels,
+             norm=norm,
+             act=act,
+             last_act=last_act,
+             kernel_size=kernel_size,
+             upsampling_mode=upsampling_mode,
+         )
+
+     def forward(self, seg: torch.Tensor, x: torch.Tensor | None = None):
+         z = None
+         if self.is_vae:
+             z_mu, z_logvar = self.encoder(x)
+             z = self.encoder.reparameterize(z_mu, z_logvar)
+             return self.decoder(seg, z), z_mu, z_logvar
+         else:
+             return (self.decoder(seg, z),)
+
+     def encode(self, x: torch.Tensor):
+         if self.is_vae:
+             return self.encoder.encode(x)
+         else:
+             return None
+
+     def decode(self, seg: torch.Tensor, z: torch.Tensor | None = None):
+         return self.decoder(seg, z)
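
For orientation only (not part of the released diff), here is a minimal usage sketch of the SPADENet class added above, imported from the new monai/networks/nets/spade_network.py module listed in the file table. The 2-D shapes, channel widths and label count are arbitrary example values:

    import torch
    from monai.networks.nets.spade_network import SPADENet

    # example values only: 64x64 2-D inputs, 3 one-hot semantic classes, small channel widths
    net = SPADENet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        label_nc=3,
        input_shape=(64, 64),
        channels=[16, 32, 64],
        z_dim=8,
        is_vae=True,
    )
    seg = torch.randn(2, 3, 64, 64)  # semantic map the output is conditioned on (B, label_nc, H, W)
    img = torch.randn(2, 1, 64, 64)  # image whose style is encoded by the VAE branch
    recon, z_mu, z_logvar = net(seg, img)  # decoded image plus VAE posterior statistics
    # with is_vae=False the forward call instead returns a one-element tuple holding the decoded image

Note that each spatial dimension of input_shape must be divisible by 2 ** len(channels), as enforced by the constructors above.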
@@ -347,7 +347,7 @@ def window_partition(x, window_size):
          x: input tensor.
          window_size: local window size.
      """
-     x_shape = x.size()
+     x_shape = x.size()  # length 4 or 5 only
      if len(x_shape) == 5:
          b, d, h, w, c = x_shape
          x = x.view(
@@ -363,10 +363,11 @@ def window_partition(x, window_size):
          windows = (
              x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0] * window_size[1] * window_size[2], c)
          )
-     elif len(x_shape) == 4:
+     else:  # if len(x_shape) == 4:
          b, h, w, c = x.shape
          x = x.view(b, h // window_size[0], window_size[0], w // window_size[1], window_size[1], c)
          windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], c)
+
      return windows
 
 
@@ -613,7 +614,7 @@ class SwinTransformerBlock(nn.Module):
              _, dp, hp, wp, _ = x.shape
              dims = [b, dp, hp, wp]
 
-         elif len(x_shape) == 4:
+         else:  # elif len(x_shape) == 4
              b, h, w, c = x.shape
              window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
              pad_l = pad_t = 0
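
The three hunks above only annotate the existing 4-D/5-D branching; they presumably belong to monai/networks/nets/swin_unetr.py, the only Swin-related file in the list above. As a quick sanity sketch of the 4-D branch being rewritten, with illustrative shapes only:

    import torch
    from monai.networks.nets.swin_unetr import window_partition

    x = torch.randn(2, 14, 14, 96)         # 4-D input (B, H, W, C)
    windows = window_partition(x, (7, 7))  # partition into non-overlapping 7x7 windows
    print(windows.shape)                   # torch.Size([8, 49, 96]) = (B * num_windows, 7*7, C)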
@@ -0,0 +1,157 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ # http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ import torch
+ import torch.nn as nn
+
+ from monai.networks.blocks import TransformerBlock
+
+ __all__ = ["DecoderOnlyTransformer"]
+
+
+ class AbsolutePositionalEmbedding(nn.Module):
+     """Absolute positional embedding.
+
+     Args:
+         max_seq_len: Maximum sequence length.
+         embedding_dim: Dimensionality of the embedding.
+     """
+
+     def __init__(self, max_seq_len: int, embedding_dim: int) -> None:
+         super().__init__()
+         self.max_seq_len = max_seq_len
+         self.embedding_dim = embedding_dim
+         self.embedding = nn.Embedding(max_seq_len, embedding_dim)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         batch_size, seq_len = x.size()
+         positions = torch.arange(seq_len, device=x.device).repeat(batch_size, 1)
+         embedding: torch.Tensor = self.embedding(positions)
+         return embedding
+
+
+ class DecoderOnlyTransformer(nn.Module):
+     """Decoder-only (Autoregressive) Transformer model.
+
+     Args:
+         num_tokens: Number of tokens in the vocabulary.
+         max_seq_len: Maximum sequence length.
+         attn_layers_dim: Dimensionality of the attention layers.
+         attn_layers_depth: Number of attention layers.
+         attn_layers_heads: Number of attention heads.
+         with_cross_attention: Whether to use cross attention for conditioning.
+         embedding_dropout_rate: Dropout rate for the embedding.
+     """
+
+     def __init__(
+         self,
+         num_tokens: int,
+         max_seq_len: int,
+         attn_layers_dim: int,
+         attn_layers_depth: int,
+         attn_layers_heads: int,
+         with_cross_attention: bool = False,
+         embedding_dropout_rate: float = 0.0,
+     ) -> None:
+         super().__init__()
+         self.num_tokens = num_tokens
+         self.max_seq_len = max_seq_len
+         self.attn_layers_dim = attn_layers_dim
+         self.attn_layers_depth = attn_layers_depth
+         self.attn_layers_heads = attn_layers_heads
+         self.with_cross_attention = with_cross_attention
+
+         self.token_embeddings = nn.Embedding(num_tokens, attn_layers_dim)
+         self.position_embeddings = AbsolutePositionalEmbedding(max_seq_len=max_seq_len, embedding_dim=attn_layers_dim)
+         self.embedding_dropout = nn.Dropout(embedding_dropout_rate)
+
+         self.blocks = nn.ModuleList(
+             [
+                 TransformerBlock(
+                     hidden_size=attn_layers_dim,
+                     mlp_dim=attn_layers_dim * 4,
+                     num_heads=attn_layers_heads,
+                     dropout_rate=0.0,
+                     qkv_bias=False,
+                     causal=True,
+                     sequence_length=max_seq_len,
+                     with_cross_attention=with_cross_attention,
+                 )
+                 for _ in range(attn_layers_depth)
+             ]
+         )
+
+         self.to_logits = nn.Linear(attn_layers_dim, num_tokens)
+
+     def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
+         tok_emb = self.token_embeddings(x)
+         pos_emb = self.position_embeddings(x)
+         x = self.embedding_dropout(tok_emb + pos_emb)
+
+         for block in self.blocks:
+             x = block(x, context=context)
+         logits: torch.Tensor = self.to_logits(x)
+         return logits
+
+     def load_old_state_dict(self, old_state_dict: dict, verbose: bool = False) -> None:
+         """
+         Load a state dict from a DecoderOnlyTransformer trained with
+         [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).
+
+         Args:
+             old_state_dict: state dict from the old DecoderOnlyTransformer model.
+             verbose: whether to print the keys that do not match between the old and new state dicts.
+         """
+
+         new_state_dict = self.state_dict()
+         # if all keys match, just load the state dict
+         if all(k in new_state_dict for k in old_state_dict):
+             print("All keys match, loading state dict.")
+             self.load_state_dict(old_state_dict)
+             return
+
+         if verbose:
+             # print all new_state_dict keys that are not in old_state_dict
+             for k in new_state_dict:
+                 if k not in old_state_dict:
+                     print(f"key {k} not found in old state dict")
+             # and vice versa
+             print("----------------------------------------------")
+             for k in old_state_dict:
+                 if k not in new_state_dict:
+                     print(f"key {k} not found in new state dict")
+
+         # copy over all matching keys
+         for k in new_state_dict:
+             if k in old_state_dict:
+                 new_state_dict[k] = old_state_dict[k]
+
+         # fix the attention blocks: concatenate the old separate q/k/v projections into the fused qkv weight
+         attention_blocks = [k.replace(".attn.qkv.weight", "") for k in new_state_dict if "attn.qkv.weight" in k]
+         for block in attention_blocks:
+             new_state_dict[f"{block}.attn.qkv.weight"] = torch.cat(
+                 [
+                     old_state_dict[f"{block}.attn.to_q.weight"],
+                     old_state_dict[f"{block}.attn.to_k.weight"],
+                     old_state_dict[f"{block}.attn.to_v.weight"],
+                 ],
+                 dim=0,
+             )
+
+         # fix the renamed norm blocks: norm2 -> norm_cross_attn, norm3 -> norm2
+         for k in old_state_dict:
+             if "norm2" in k:
+                 new_state_dict[k.replace("norm2", "norm_cross_attn")] = old_state_dict[k]
+             if "norm3" in k:
+                 new_state_dict[k.replace("norm3", "norm2")] = old_state_dict[k]
+
+         self.load_state_dict(new_state_dict)
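
To make the new decoder-only model concrete, here is a minimal usage sketch (not part of the diff) based on the constructor and forward signatures shown above, importing from the new monai/networks/nets/transformer.py module; the hyperparameter values are arbitrary examples:

    import torch
    from monai.networks.nets.transformer import DecoderOnlyTransformer

    # a tiny autoregressive model over a 32-token vocabulary; values are illustrative only
    model = DecoderOnlyTransformer(
        num_tokens=32,
        max_seq_len=16,
        attn_layers_dim=64,
        attn_layers_depth=2,
        attn_layers_heads=4,
    )
    tokens = torch.randint(0, 32, (2, 16))  # (batch, sequence) of token indices
    logits = model(tokens)                  # (2, 16, 32): next-token logits at every position

Passing with_cross_attention=True and a context tensor to forward enables the conditioning path added to TransformerBlock elsewhere in this release.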