monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2430__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/auto3dseg/hpo_gen.py +1 -1
  4. monai/apps/detection/utils/anchor_utils.py +2 -2
  5. monai/apps/pathology/transforms/post/array.py +7 -4
  6. monai/auto3dseg/analyzer.py +1 -1
  7. monai/bundle/scripts.py +204 -22
  8. monai/bundle/utils.py +1 -0
  9. monai/data/dataset_summary.py +1 -0
  10. monai/data/meta_tensor.py +2 -2
  11. monai/data/test_time_augmentation.py +2 -0
  12. monai/data/utils.py +9 -6
  13. monai/data/wsi_reader.py +2 -2
  14. monai/engines/__init__.py +3 -1
  15. monai/engines/trainer.py +281 -2
  16. monai/engines/utils.py +76 -1
  17. monai/handlers/mlflow_handler.py +21 -4
  18. monai/inferers/__init__.py +5 -0
  19. monai/inferers/inferer.py +1279 -1
  20. monai/metrics/cumulative_average.py +2 -0
  21. monai/metrics/panoptic_quality.py +1 -1
  22. monai/metrics/rocauc.py +2 -2
  23. monai/networks/blocks/__init__.py +3 -0
  24. monai/networks/blocks/attention_utils.py +128 -0
  25. monai/networks/blocks/crossattention.py +168 -0
  26. monai/networks/blocks/rel_pos_embedding.py +56 -0
  27. monai/networks/blocks/selfattention.py +74 -5
  28. monai/networks/blocks/spade_norm.py +95 -0
  29. monai/networks/blocks/spatialattention.py +82 -0
  30. monai/networks/blocks/transformerblock.py +25 -4
  31. monai/networks/blocks/upsample.py +22 -10
  32. monai/networks/layers/__init__.py +2 -1
  33. monai/networks/layers/factories.py +12 -1
  34. monai/networks/layers/simplelayers.py +1 -1
  35. monai/networks/layers/utils.py +14 -1
  36. monai/networks/layers/vector_quantizer.py +233 -0
  37. monai/networks/nets/__init__.py +9 -0
  38. monai/networks/nets/autoencoderkl.py +702 -0
  39. monai/networks/nets/controlnet.py +465 -0
  40. monai/networks/nets/diffusion_model_unet.py +1913 -0
  41. monai/networks/nets/patchgan_discriminator.py +230 -0
  42. monai/networks/nets/quicknat.py +8 -6
  43. monai/networks/nets/resnet.py +3 -4
  44. monai/networks/nets/spade_autoencoderkl.py +480 -0
  45. monai/networks/nets/spade_diffusion_model_unet.py +934 -0
  46. monai/networks/nets/spade_network.py +435 -0
  47. monai/networks/nets/swin_unetr.py +4 -3
  48. monai/networks/nets/transformer.py +157 -0
  49. monai/networks/nets/vqvae.py +472 -0
  50. monai/networks/schedulers/__init__.py +17 -0
  51. monai/networks/schedulers/ddim.py +294 -0
  52. monai/networks/schedulers/ddpm.py +250 -0
  53. monai/networks/schedulers/pndm.py +316 -0
  54. monai/networks/schedulers/scheduler.py +205 -0
  55. monai/networks/utils.py +22 -0
  56. monai/transforms/croppad/array.py +8 -8
  57. monai/transforms/croppad/dictionary.py +4 -4
  58. monai/transforms/croppad/functional.py +1 -1
  59. monai/transforms/regularization/array.py +4 -0
  60. monai/transforms/spatial/array.py +1 -1
  61. monai/transforms/utils_create_transform_ims.py +2 -4
  62. monai/utils/__init__.py +1 -0
  63. monai/utils/misc.py +5 -4
  64. monai/utils/ordering.py +207 -0
  65. monai/visualize/class_activation_maps.py +5 -5
  66. monai/visualize/img2tensorboard.py +3 -1
  67. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/METADATA +1 -1
  68. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/RECORD +71 -50
  69. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/WHEEL +1 -1
  70. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/LICENSE +0 -0
  71. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1913 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ # http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ # =========================================================================
+ # Adapted from https://github.com/huggingface/diffusers
+ # which has the following license:
+ # https://github.com/huggingface/diffusers/blob/main/LICENSE
+ #
+ # Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # =========================================================================
+
+ from __future__ import annotations
+
+ import math
+ from collections.abc import Sequence
+
+ import torch
+ from torch import nn
+
+ from monai.networks.blocks import Convolution, CrossAttentionBlock, MLPBlock, SABlock, SpatialAttentionBlock, Upsample
+ from monai.networks.layers.factories import Pool
+ from monai.utils import ensure_tuple_rep, optional_import
+
+ Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
+
+ __all__ = ["DiffusionModelUNet"]
+
+
+ def zero_module(module: nn.Module) -> nn.Module:
+     """
+     Zero out the parameters of a module and return it.
+     """
+     for p in module.parameters():
+         p.detach().zero_()
+     return module
+
+
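Because the blocks below wrap their final convolutions in zero_module, each residual branch starts as an identity mapping, a standard trick for stabilizing early diffusion training. A minimal sketch of the effect (the nn.Conv2d layer is illustrative only):

conv = zero_module(nn.Conv2d(8, 8, kernel_size=1))
out = conv(torch.randn(1, 8, 4, 4))
print(out.abs().sum())  # tensor(0.): the zeroed branch contributes nothing at initialization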
+ class DiffusionUNetTransformerBlock(nn.Module):
+     """
+     A Transformer block that allows for the input dimension to differ from the hidden dimension.
+
+     Args:
+         num_channels: number of channels in the input and output.
+         num_attention_heads: number of heads to use for multi-head attention.
+         num_head_channels: number of channels in each attention head.
+         dropout: dropout probability to use.
+         cross_attention_dim: size of the context vector for cross attention.
+         upcast_attention: if True, upcast attention operations to full precision.
+
+     """
+
+     def __init__(
+         self,
+         num_channels: int,
+         num_attention_heads: int,
+         num_head_channels: int,
+         dropout: float = 0.0,
+         cross_attention_dim: int | None = None,
+         upcast_attention: bool = False,
+     ) -> None:
+         super().__init__()
+         self.attn1 = SABlock(
+             hidden_size=num_attention_heads * num_head_channels,
+             hidden_input_size=num_channels,
+             num_heads=num_attention_heads,
+             dim_head=num_head_channels,
+             dropout_rate=dropout,
+             attention_dtype=torch.float if upcast_attention else None,
+         )
+         self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout)
+         self.attn2 = CrossAttentionBlock(
+             hidden_size=num_attention_heads * num_head_channels,
+             num_heads=num_attention_heads,
+             hidden_input_size=num_channels,
+             context_input_size=cross_attention_dim,
+             dim_head=num_head_channels,
+             dropout_rate=dropout,
+             attention_dtype=torch.float if upcast_attention else None,
+         )
+         self.norm1 = nn.LayerNorm(num_channels)
+         self.norm2 = nn.LayerNorm(num_channels)
+         self.norm3 = nn.LayerNorm(num_channels)
+
+     def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
+         # 1. Self-Attention
+         x = self.attn1(self.norm1(x)) + x
+
+         # 2. Cross-Attention
+         x = self.attn2(self.norm2(x), context=context) + x
+
+         # 3. Feed-forward
+         x = self.ff(self.norm3(x)) + x
+         return x
+
+
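A minimal usage sketch for the block above, with assumed shapes; the residual additions require SABlock and CrossAttentionBlock to project back to hidden_input_size, so the token dimension is preserved:

block = DiffusionUNetTransformerBlock(
    num_channels=64, num_attention_heads=4, num_head_channels=16, cross_attention_dim=128
)
x = torch.randn(2, 100, 64)   # (batch, tokens, channels)
ctx = torch.randn(2, 8, 128)  # (batch, context tokens, cross_attention_dim)
print(block(x, context=ctx).shape)  # torch.Size([2, 100, 64])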
+ class SpatialTransformer(nn.Module):
+     """
+     Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
+     standard transformer action. Finally, reshape to image.
+
+     Args:
+         spatial_dims: number of spatial dimensions.
+         in_channels: number of channels in the input and output.
+         num_attention_heads: number of heads to use for multi-head attention.
+         num_head_channels: number of channels in each attention head.
+         num_layers: number of layers of Transformer blocks to use.
+         dropout: dropout probability to use.
+         norm_num_groups: number of groups for the normalization.
+         norm_eps: epsilon for the normalization.
+         cross_attention_dim: number of context dimensions to use.
+         upcast_attention: if True, upcast attention operations to full precision.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         num_attention_heads: int,
+         num_head_channels: int,
+         num_layers: int = 1,
+         dropout: float = 0.0,
+         norm_num_groups: int = 32,
+         norm_eps: float = 1e-6,
+         cross_attention_dim: int | None = None,
+         upcast_attention: bool = False,
+     ) -> None:
+         super().__init__()
+         self.spatial_dims = spatial_dims
+         self.in_channels = in_channels
+         inner_dim = num_attention_heads * num_head_channels
+
+         self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True)
+
+         self.proj_in = Convolution(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             out_channels=inner_dim,
+             strides=1,
+             kernel_size=1,
+             padding=0,
+             conv_only=True,
+         )
+
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 DiffusionUNetTransformerBlock(
+                     num_channels=inner_dim,
+                     num_attention_heads=num_attention_heads,
+                     num_head_channels=num_head_channels,
+                     dropout=dropout,
+                     cross_attention_dim=cross_attention_dim,
+                     upcast_attention=upcast_attention,
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+
+         self.proj_out = zero_module(
+             Convolution(
+                 spatial_dims=spatial_dims,
+                 in_channels=inner_dim,
+                 out_channels=in_channels,
+                 strides=1,
+                 kernel_size=1,
+                 padding=0,
+                 conv_only=True,
+             )
+         )
+
+     def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
+         # note: if no context is given, cross-attention defaults to self-attention
+         batch = channel = height = width = depth = -1
+         if self.spatial_dims == 2:
+             batch, channel, height, width = x.shape
+         if self.spatial_dims == 3:
+             batch, channel, height, width, depth = x.shape
+
+         residual = x
+         x = self.norm(x)
+         x = self.proj_in(x)
+
+         inner_dim = x.shape[1]
+
+         if self.spatial_dims == 2:
+             x = x.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+         if self.spatial_dims == 3:
+             x = x.permute(0, 2, 3, 4, 1).reshape(batch, height * width * depth, inner_dim)
+
+         for block in self.transformer_blocks:
+             x = block(x, context=context)
+
+         if self.spatial_dims == 2:
+             x = x.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+         if self.spatial_dims == 3:
+             x = x.reshape(batch, height, width, depth, inner_dim).permute(0, 4, 1, 2, 3).contiguous()
+
+         x = self.proj_out(x)
+         return x + residual
+
+
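A sketch of the round trip the forward pass performs: an image-shaped tensor is flattened to tokens, run through the transformer blocks, and reshaped back, so the spatial shape is preserved (example values assumed):

st = SpatialTransformer(
    spatial_dims=2, in_channels=64, num_attention_heads=4, num_head_channels=16, cross_attention_dim=128
)
x = torch.randn(2, 64, 16, 16)
ctx = torch.randn(2, 8, 128)
print(st(x, context=ctx).shape)  # torch.Size([2, 64, 16, 16])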
+ def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000) -> torch.Tensor:
+     """
+     Create sinusoidal timestep embeddings following the implementation in Ho et al. "Denoising Diffusion Probabilistic
+     Models" https://arxiv.org/abs/2006.11239.
+
+     Args:
+         timesteps: a 1-D Tensor of N indices, one per batch element.
+         embedding_dim: the dimension of the output.
+         max_period: controls the minimum frequency of the embeddings.
+     """
+     if timesteps.ndim != 1:
+         raise ValueError("Timesteps should be a 1d-array")
+
+     half_dim = embedding_dim // 2
+     exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
+     freqs = torch.exp(exponent / half_dim)
+
+     args = timesteps[:, None].float() * freqs[None, :]
+     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+
+     # zero pad
+     if embedding_dim % 2 == 1:
+         embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0))
+
+     return embedding
+
+
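For example, embedding four timesteps into 64 dimensions produces one sinusoidal feature row per timestep (cosine terms in the first 32 columns, sine terms in the last 32):

t = torch.tensor([0, 10, 100, 999])
emb = get_timestep_embedding(t, embedding_dim=64)
print(emb.shape)  # torch.Size([4, 64])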
+ class DiffusionUnetDownsample(nn.Module):
+     """
+     Downsampling layer.
+
+     Args:
+         spatial_dims: number of spatial dimensions.
+         num_channels: number of input channels.
+         use_conv: if True, uses a Convolution instead of average pooling for downsampling. When use_conv is False,
+             the number of output channels must equal the number of input channels.
+         out_channels: number of output channels.
+         padding: controls the amount of implicit zero-padding on both sides for padding number of points
+             for each dimension.
+     """
+
+     def __init__(
+         self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1
+     ) -> None:
+         super().__init__()
+         self.num_channels = num_channels
+         self.out_channels = out_channels or num_channels
+         self.use_conv = use_conv
+         if use_conv:
+             self.op = Convolution(
+                 spatial_dims=spatial_dims,
+                 in_channels=self.num_channels,
+                 out_channels=self.out_channels,
+                 strides=2,
+                 kernel_size=3,
+                 padding=padding,
+                 conv_only=True,
+             )
+         else:
+             if self.num_channels != self.out_channels:
+                 raise ValueError("num_channels and out_channels must be equal when use_conv=False")
+             self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2)
+
+     def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
+         del emb
+         if x.shape[1] != self.num_channels:
+             raise ValueError(
+                 f"Input number of channels ({x.shape[1]}) is not equal to expected number of channels "
+                 f"({self.num_channels})"
+             )
+         output: torch.Tensor = self.op(x)
+         return output
+
+
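Either branch halves each spatial dimension; a quick sketch with the convolutional variant (input size assumed):

down = DiffusionUnetDownsample(spatial_dims=2, num_channels=32, use_conv=True)
print(down(torch.randn(1, 32, 64, 64)).shape)  # torch.Size([1, 32, 32, 32])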
+ class WrappedUpsample(Upsample):
+     """
+     Wraps MONAI upsample block to allow for calling with timestep embeddings.
+     """
+
+     def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
+         del emb
+         upsampled: torch.Tensor = super().forward(x)
+         return upsampled
+
+
+ class DiffusionUNetResnetBlock(nn.Module):
+     """
+     Residual block with timestep conditioning.
+
+     Args:
+         spatial_dims: The number of spatial dimensions.
+         in_channels: number of input channels.
+         temb_channels: number of timestep embedding channels.
+         out_channels: number of output channels.
+         up: if True, performs upsampling.
+         down: if True, performs downsampling.
+         norm_num_groups: number of groups for the group normalization.
+         norm_eps: epsilon for the group normalization.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         temb_channels: int,
+         out_channels: int | None = None,
+         up: bool = False,
+         down: bool = False,
+         norm_num_groups: int = 32,
+         norm_eps: float = 1e-6,
+     ) -> None:
+         super().__init__()
+         self.spatial_dims = spatial_dims
+         self.channels = in_channels
+         self.emb_channels = temb_channels
+         self.out_channels = out_channels or in_channels
+         self.up = up
+         self.down = down
+
+         self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True)
+         self.nonlinearity = nn.SiLU()
+         self.conv1 = Convolution(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             out_channels=self.out_channels,
+             strides=1,
+             kernel_size=3,
+             padding=1,
+             conv_only=True,
+         )
+
+         self.upsample = self.downsample = None
+         if self.up:
+             self.upsample = WrappedUpsample(
+                 spatial_dims=spatial_dims,
+                 mode="nontrainable",
+                 in_channels=in_channels,
+                 out_channels=in_channels,
+                 interp_mode="nearest",
+                 scale_factor=2.0,
+                 align_corners=None,
+             )
+         elif down:
+             self.downsample = DiffusionUnetDownsample(spatial_dims, in_channels, use_conv=False)
+
+         self.time_emb_proj = nn.Linear(temb_channels, self.out_channels)
+
+         self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=self.out_channels, eps=norm_eps, affine=True)
+         self.conv2 = zero_module(
+             Convolution(
+                 spatial_dims=spatial_dims,
+                 in_channels=self.out_channels,
+                 out_channels=self.out_channels,
+                 strides=1,
+                 kernel_size=3,
+                 padding=1,
+                 conv_only=True,
+             )
+         )
+         self.skip_connection: nn.Module
+         if self.out_channels == in_channels:
+             self.skip_connection = nn.Identity()
+         else:
+             self.skip_connection = Convolution(
+                 spatial_dims=spatial_dims,
+                 in_channels=in_channels,
+                 out_channels=self.out_channels,
+                 strides=1,
+                 kernel_size=1,
+                 padding=0,
+                 conv_only=True,
+             )
+
+     def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
+         h = x
+         h = self.norm1(h)
+         h = self.nonlinearity(h)
+
+         if self.upsample is not None:
+             x = self.upsample(x)
+             h = self.upsample(h)
+         elif self.downsample is not None:
+             x = self.downsample(x)
+             h = self.downsample(h)
+
+         h = self.conv1(h)
+
+         if self.spatial_dims == 2:
+             temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None]
+         else:
+             temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None]
+         h = h + temb
+
+         h = self.norm2(h)
+         h = self.nonlinearity(h)
+         h = self.conv2(h)
+         output: torch.Tensor = self.skip_connection(x) + h
+         return output
+
+
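A minimal sketch of how the timestep embedding conditions this block; the emb tensor here stands in for the output of the UNet's time_embed MLP (shapes assumed):

block = DiffusionUNetResnetBlock(spatial_dims=2, in_channels=32, temb_channels=128, out_channels=64)
x = torch.randn(1, 32, 16, 16)
emb = torch.randn(1, 128)  # placeholder for the projected timestep embedding
print(block(x, emb).shape)  # torch.Size([1, 64, 16, 16])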
+ class DownBlock(nn.Module):
+     """
+     Unet's down block containing resnet and downsampler blocks.
+
+     Args:
+         spatial_dims: The number of spatial dimensions.
+         in_channels: number of input channels.
+         out_channels: number of output channels.
+         temb_channels: number of timestep embedding channels.
+         num_res_blocks: number of residual blocks.
+         norm_num_groups: number of groups for the group normalization.
+         norm_eps: epsilon for the group normalization.
+         add_downsample: if True add downsample block.
+         resblock_updown: if True use residual blocks for downsampling.
+         downsample_padding: padding used in the downsampling block.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         out_channels: int,
+         temb_channels: int,
+         num_res_blocks: int = 1,
+         norm_num_groups: int = 32,
+         norm_eps: float = 1e-6,
+         add_downsample: bool = True,
+         resblock_updown: bool = False,
+         downsample_padding: int = 1,
+     ) -> None:
+         super().__init__()
+         self.resblock_updown = resblock_updown
+
+         resnets = []
+
+         for i in range(num_res_blocks):
+             in_channels = in_channels if i == 0 else out_channels
+             resnets.append(
+                 DiffusionUNetResnetBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=in_channels,
+                     out_channels=out_channels,
+                     temb_channels=temb_channels,
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                 )
+             )
+
+         self.resnets = nn.ModuleList(resnets)
+
+         if add_downsample:
+             self.downsampler: nn.Module | None
+             if resblock_updown:
+                 self.downsampler = DiffusionUNetResnetBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=out_channels,
+                     out_channels=out_channels,
+                     temb_channels=temb_channels,
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     down=True,
+                 )
+             else:
+                 self.downsampler = DiffusionUnetDownsample(
+                     spatial_dims=spatial_dims,
+                     num_channels=out_channels,
+                     use_conv=True,
+                     out_channels=out_channels,
+                     padding=downsample_padding,
+                 )
+         else:
+             self.downsampler = None
+
+     def forward(
+         self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
+     ) -> tuple[torch.Tensor, list[torch.Tensor]]:
+         del context
+         output_states = []
+
+         for resnet in self.resnets:
+             hidden_states = resnet(hidden_states, temb)
+             output_states.append(hidden_states)
+
+         if self.downsampler is not None:
+             hidden_states = self.downsampler(hidden_states, temb)
+             output_states.append(hidden_states)
+
+         return hidden_states, output_states
+
+
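The returned output_states are the per-resolution skip tensors that the matching up block later concatenates back in; a sketch with assumed sizes:

blk = DownBlock(spatial_dims=2, in_channels=32, out_channels=64, temb_channels=128, num_res_blocks=2)
h, skips = blk(torch.randn(1, 32, 32, 32), torch.randn(1, 128))
print(h.shape, len(skips))  # torch.Size([1, 64, 16, 16]) 3 -- two resnet outputs plus the downsampled one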
+ class AttnDownBlock(nn.Module):
+     """
+     Unet's down block containing resnet, downsampler and self-attention blocks.
+
+     Args:
+         spatial_dims: The number of spatial dimensions.
+         in_channels: number of input channels.
+         out_channels: number of output channels.
+         temb_channels: number of timestep embedding channels.
+         num_res_blocks: number of residual blocks.
+         norm_num_groups: number of groups for the group normalization.
+         norm_eps: epsilon for the group normalization.
+         add_downsample: if True add downsample block.
+         resblock_updown: if True use residual blocks for downsampling.
+         downsample_padding: padding used in the downsampling block.
+         num_head_channels: number of channels in each attention head.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         out_channels: int,
+         temb_channels: int,
+         num_res_blocks: int = 1,
+         norm_num_groups: int = 32,
+         norm_eps: float = 1e-6,
+         add_downsample: bool = True,
+         resblock_updown: bool = False,
+         downsample_padding: int = 1,
+         num_head_channels: int = 1,
+     ) -> None:
+         super().__init__()
+         self.resblock_updown = resblock_updown
+
+         resnets = []
+         attentions = []
+
+         for i in range(num_res_blocks):
+             in_channels = in_channels if i == 0 else out_channels
+             resnets.append(
+                 DiffusionUNetResnetBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=in_channels,
+                     out_channels=out_channels,
+                     temb_channels=temb_channels,
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                 )
+             )
+             attentions.append(
+                 SpatialAttentionBlock(
+                     spatial_dims=spatial_dims,
+                     num_channels=out_channels,
+                     num_head_channels=num_head_channels,
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                 )
+             )
+
+         self.attentions = nn.ModuleList(attentions)
+         self.resnets = nn.ModuleList(resnets)
+
+         self.downsampler: nn.Module | None
+         if add_downsample:
+             if resblock_updown:
+                 self.downsampler = DiffusionUNetResnetBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=out_channels,
+                     out_channels=out_channels,
+                     temb_channels=temb_channels,
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     down=True,
+                 )
+             else:
+                 self.downsampler = DiffusionUnetDownsample(
+                     spatial_dims=spatial_dims,
+                     num_channels=out_channels,
+                     use_conv=True,
+                     out_channels=out_channels,
+                     padding=downsample_padding,
+                 )
+         else:
+             self.downsampler = None
+
+     def forward(
+         self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
+     ) -> tuple[torch.Tensor, list[torch.Tensor]]:
+         del context
+         output_states = []
+
+         for resnet, attn in zip(self.resnets, self.attentions):
+             hidden_states = resnet(hidden_states, temb)
+             hidden_states = attn(hidden_states).contiguous()
+             output_states.append(hidden_states)
+
+         if self.downsampler is not None:
+             hidden_states = self.downsampler(hidden_states, temb)
+             output_states.append(hidden_states)
+
+         return hidden_states, output_states
+
+
+ class CrossAttnDownBlock(nn.Module):
+     """
+     Unet's down block containing resnet, downsampler and cross-attention blocks.
+
+     Args:
+         spatial_dims: number of spatial dimensions.
+         in_channels: number of input channels.
+         out_channels: number of output channels.
+         temb_channels: number of timestep embedding channels.
+         num_res_blocks: number of residual blocks.
+         norm_num_groups: number of groups for the group normalization.
+         norm_eps: epsilon for the group normalization.
+         add_downsample: if True add downsample block.
+         resblock_updown: if True use residual blocks for downsampling.
+         downsample_padding: padding used in the downsampling block.
+         num_head_channels: number of channels in each attention head.
+         transformer_num_layers: number of layers of Transformer blocks to use.
+         cross_attention_dim: number of context dimensions to use.
+         upcast_attention: if True, upcast attention operations to full precision.
+         dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         out_channels: int,
+         temb_channels: int,
+         num_res_blocks: int = 1,
+         norm_num_groups: int = 32,
+         norm_eps: float = 1e-6,
+         add_downsample: bool = True,
+         resblock_updown: bool = False,
+         downsample_padding: int = 1,
+         num_head_channels: int = 1,
+         transformer_num_layers: int = 1,
+         cross_attention_dim: int | None = None,
+         upcast_attention: bool = False,
+         dropout_cattn: float = 0.0,
+     ) -> None:
+         super().__init__()
+         self.resblock_updown = resblock_updown
+
+         resnets = []
+         attentions = []
+
+         for i in range(num_res_blocks):
+             in_channels = in_channels if i == 0 else out_channels
+             resnets.append(
+                 DiffusionUNetResnetBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=in_channels,
+                     out_channels=out_channels,
+                     temb_channels=temb_channels,
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                 )
+             )
+
+             attentions.append(
+                 SpatialTransformer(
+                     spatial_dims=spatial_dims,
+                     in_channels=out_channels,
+                     num_attention_heads=out_channels // num_head_channels,
+                     num_head_channels=num_head_channels,
+                     num_layers=transformer_num_layers,
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     cross_attention_dim=cross_attention_dim,
+                     upcast_attention=upcast_attention,
+                     dropout=dropout_cattn,
+                 )
+             )
+
+         self.attentions = nn.ModuleList(attentions)
+         self.resnets = nn.ModuleList(resnets)
+
+         self.downsampler: nn.Module | None
+         if add_downsample:
+             if resblock_updown:
+                 self.downsampler = DiffusionUNetResnetBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=out_channels,
+                     out_channels=out_channels,
+                     temb_channels=temb_channels,
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     down=True,
+                 )
+             else:
+                 self.downsampler = DiffusionUnetDownsample(
+                     spatial_dims=spatial_dims,
+                     num_channels=out_channels,
+                     use_conv=True,
+                     out_channels=out_channels,
+                     padding=downsample_padding,
+                 )
+         else:
+             self.downsampler = None
+
+     def forward(
+         self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
+     ) -> tuple[torch.Tensor, list[torch.Tensor]]:
+         output_states = []
+
+         for resnet, attn in zip(self.resnets, self.attentions):
+             hidden_states = resnet(hidden_states, temb)
+             hidden_states = attn(hidden_states, context=context).contiguous()
+             output_states.append(hidden_states)
+
+         if self.downsampler is not None:
+             hidden_states = self.downsampler(hidden_states, temb)
+             output_states.append(hidden_states)
+
+         return hidden_states, output_states
+
+
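The cross-attention variant additionally consumes the conditioning context at every attention layer; a sketch with assumed shapes (num_attention_heads is derived as out_channels // num_head_channels):

blk = CrossAttnDownBlock(
    spatial_dims=2, in_channels=32, out_channels=64, temb_channels=128, num_head_channels=8, cross_attention_dim=16
)
h, skips = blk(torch.randn(1, 32, 32, 32), torch.randn(1, 128), context=torch.randn(1, 4, 16))
print(h.shape, len(skips))  # torch.Size([1, 64, 16, 16]) 2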
+ class AttnMidBlock(nn.Module):
+     """
+     Unet's mid block containing resnet and self-attention blocks.
+
+     Args:
+         spatial_dims: The number of spatial dimensions.
+         in_channels: number of input channels.
+         temb_channels: number of timestep embedding channels.
+         norm_num_groups: number of groups for the group normalization.
+         norm_eps: epsilon for the group normalization.
+         num_head_channels: number of channels in each attention head.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         temb_channels: int,
+         norm_num_groups: int = 32,
+         norm_eps: float = 1e-6,
+         num_head_channels: int = 1,
+     ) -> None:
+         super().__init__()
+
+         self.resnet_1 = DiffusionUNetResnetBlock(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             out_channels=in_channels,
+             temb_channels=temb_channels,
+             norm_num_groups=norm_num_groups,
+             norm_eps=norm_eps,
+         )
+         self.attention = SpatialAttentionBlock(
+             spatial_dims=spatial_dims,
+             num_channels=in_channels,
+             num_head_channels=num_head_channels,
+             norm_num_groups=norm_num_groups,
+             norm_eps=norm_eps,
+         )
+
+         self.resnet_2 = DiffusionUNetResnetBlock(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             out_channels=in_channels,
+             temb_channels=temb_channels,
+             norm_num_groups=norm_num_groups,
+             norm_eps=norm_eps,
+         )
+
+     def forward(
+         self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
+     ) -> torch.Tensor:
+         del context
+         hidden_states = self.resnet_1(hidden_states, temb)
+         hidden_states = self.attention(hidden_states).contiguous()
+         hidden_states = self.resnet_2(hidden_states, temb)
+
+         return hidden_states
+
+
+ class CrossAttnMidBlock(nn.Module):
+     """
+     Unet's mid block containing resnet and cross-attention blocks.
+
+     Args:
+         spatial_dims: The number of spatial dimensions.
+         in_channels: number of input channels.
+         temb_channels: number of timestep embedding channels.
+         norm_num_groups: number of groups for the group normalization.
+         norm_eps: epsilon for the group normalization.
+         num_head_channels: number of channels in each attention head.
+         transformer_num_layers: number of layers of Transformer blocks to use.
+         cross_attention_dim: number of context dimensions to use.
+         upcast_attention: if True, upcast attention operations to full precision.
+         dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         temb_channels: int,
+         norm_num_groups: int = 32,
+         norm_eps: float = 1e-6,
+         num_head_channels: int = 1,
+         transformer_num_layers: int = 1,
+         cross_attention_dim: int | None = None,
+         upcast_attention: bool = False,
+         dropout_cattn: float = 0.0,
+     ) -> None:
+         super().__init__()
+
+         self.resnet_1 = DiffusionUNetResnetBlock(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             out_channels=in_channels,
+             temb_channels=temb_channels,
+             norm_num_groups=norm_num_groups,
+             norm_eps=norm_eps,
+         )
+         self.attention = SpatialTransformer(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             num_attention_heads=in_channels // num_head_channels,
+             num_head_channels=num_head_channels,
+             num_layers=transformer_num_layers,
+             norm_num_groups=norm_num_groups,
+             norm_eps=norm_eps,
+             cross_attention_dim=cross_attention_dim,
+             upcast_attention=upcast_attention,
+             dropout=dropout_cattn,
+         )
+         self.resnet_2 = DiffusionUNetResnetBlock(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             out_channels=in_channels,
+             temb_channels=temb_channels,
+             norm_num_groups=norm_num_groups,
+             norm_eps=norm_eps,
+         )
+
+     def forward(
+         self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
+     ) -> torch.Tensor:
+         hidden_states = self.resnet_1(hidden_states, temb)
+         hidden_states = self.attention(hidden_states, context=context)
+         hidden_states = self.resnet_2(hidden_states, temb)
+
+         return hidden_states
+
+
+ class UpBlock(nn.Module):
+     """
+     Unet's up block containing resnet and upsampler blocks.
+
+     Args:
+         spatial_dims: The number of spatial dimensions.
+         in_channels: number of input channels.
+         prev_output_channel: number of channels from residual connection.
+         out_channels: number of output channels.
+         temb_channels: number of timestep embedding channels.
+         num_res_blocks: number of residual blocks.
+         norm_num_groups: number of groups for the group normalization.
+         norm_eps: epsilon for the group normalization.
+         add_upsample: if True add upsample block.
+         resblock_updown: if True use residual blocks for upsampling.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         prev_output_channel: int,
+         out_channels: int,
+         temb_channels: int,
+         num_res_blocks: int = 1,
+         norm_num_groups: int = 32,
+         norm_eps: float = 1e-6,
+         add_upsample: bool = True,
+         resblock_updown: bool = False,
+     ) -> None:
+         super().__init__()
+         self.resblock_updown = resblock_updown
+         resnets = []
+
+         for i in range(num_res_blocks):
+             res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels
+             resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+             resnets.append(
+                 DiffusionUNetResnetBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=resnet_in_channels + res_skip_channels,
+                     out_channels=out_channels,
+                     temb_channels=temb_channels,
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                 )
+             )
+
+         self.resnets = nn.ModuleList(resnets)
+
+         self.upsampler: nn.Module | None
+         if add_upsample:
+             if resblock_updown:
+                 self.upsampler = DiffusionUNetResnetBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=out_channels,
+                     out_channels=out_channels,
+                     temb_channels=temb_channels,
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     up=True,
+                 )
+             else:
+                 post_conv = Convolution(
+                     spatial_dims=spatial_dims,
+                     in_channels=out_channels,
+                     out_channels=out_channels,
+                     strides=1,
+                     kernel_size=3,
+                     padding=1,
+                     conv_only=True,
+                 )
+                 self.upsampler = WrappedUpsample(
+                     spatial_dims=spatial_dims,
+                     mode="nontrainable",
+                     in_channels=out_channels,
+                     out_channels=out_channels,
+                     interp_mode="nearest",
+                     scale_factor=2.0,
+                     post_conv=post_conv,
+                     align_corners=None,
+                 )
+         else:
+             self.upsampler = None
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         res_hidden_states_list: list[torch.Tensor],
+         temb: torch.Tensor,
+         context: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         del context
+         for resnet in self.resnets:
+             # pop res hidden states
+             res_hidden_states = res_hidden_states_list[-1]
+             res_hidden_states_list = res_hidden_states_list[:-1]
+             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+             hidden_states = resnet(hidden_states, temb)
+
+         if self.upsampler is not None:
+             hidden_states = self.upsampler(hidden_states, temb)
+
+         return hidden_states
+
+
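Note that skip tensors are consumed from the end of res_hidden_states_list, so the deepest features are concatenated first. A sketch matching the channel bookkeeping above (prev_output_channel feeds the first resnet, in_channels the last skip):

up = UpBlock(spatial_dims=2, in_channels=32, prev_output_channel=64, out_channels=64, temb_channels=128, num_res_blocks=2)
h = torch.randn(1, 64, 8, 8)
skips = [torch.randn(1, 32, 8, 8), torch.randn(1, 64, 8, 8)]  # popped right to left
print(up(h, skips, torch.randn(1, 128)).shape)  # torch.Size([1, 64, 16, 16])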
+ class AttnUpBlock(nn.Module):
+     """
+     Unet's up block containing resnet, upsampler, and self-attention blocks.
+
+     Args:
+         spatial_dims: The number of spatial dimensions.
+         in_channels: number of input channels.
+         prev_output_channel: number of channels from residual connection.
+         out_channels: number of output channels.
+         temb_channels: number of timestep embedding channels.
+         num_res_blocks: number of residual blocks.
+         norm_num_groups: number of groups for the group normalization.
+         norm_eps: epsilon for the group normalization.
+         add_upsample: if True add upsample block.
+         resblock_updown: if True use residual blocks for upsampling.
+         num_head_channels: number of channels in each attention head.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         prev_output_channel: int,
+         out_channels: int,
+         temb_channels: int,
+         num_res_blocks: int = 1,
+         norm_num_groups: int = 32,
+         norm_eps: float = 1e-6,
+         add_upsample: bool = True,
+         resblock_updown: bool = False,
+         num_head_channels: int = 1,
+     ) -> None:
+         super().__init__()
+         self.resblock_updown = resblock_updown
+
+         resnets = []
+         attentions = []
+
+         for i in range(num_res_blocks):
+             res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels
+             resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+             resnets.append(
+                 DiffusionUNetResnetBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=resnet_in_channels + res_skip_channels,
+                     out_channels=out_channels,
+                     temb_channels=temb_channels,
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                 )
+             )
+             attentions.append(
+                 SpatialAttentionBlock(
+                     spatial_dims=spatial_dims,
+                     num_channels=out_channels,
+                     num_head_channels=num_head_channels,
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                 )
+             )
+
+         self.resnets = nn.ModuleList(resnets)
+         self.attentions = nn.ModuleList(attentions)
+
+         self.upsampler: nn.Module | None
+         if add_upsample:
+             if resblock_updown:
+                 self.upsampler = DiffusionUNetResnetBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=out_channels,
+                     out_channels=out_channels,
+                     temb_channels=temb_channels,
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     up=True,
+                 )
+             else:
+                 post_conv = Convolution(
+                     spatial_dims=spatial_dims,
+                     in_channels=out_channels,
+                     out_channels=out_channels,
+                     strides=1,
+                     kernel_size=3,
+                     padding=1,
+                     conv_only=True,
+                 )
+                 self.upsampler = WrappedUpsample(
+                     spatial_dims=spatial_dims,
+                     mode="nontrainable",
+                     in_channels=out_channels,
+                     out_channels=out_channels,
+                     interp_mode="nearest",
+                     scale_factor=2.0,
+                     post_conv=post_conv,
+                     align_corners=None,
+                 )
+         else:
+             self.upsampler = None
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         res_hidden_states_list: list[torch.Tensor],
+         temb: torch.Tensor,
+         context: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         del context
+         for resnet, attn in zip(self.resnets, self.attentions):
+             # pop res hidden states
+             res_hidden_states = res_hidden_states_list[-1]
+             res_hidden_states_list = res_hidden_states_list[:-1]
+             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+             hidden_states = resnet(hidden_states, temb)
+             hidden_states = attn(hidden_states).contiguous()
+
+         if self.upsampler is not None:
+             hidden_states = self.upsampler(hidden_states, temb)
+
+         return hidden_states
+
+
+ class CrossAttnUpBlock(nn.Module):
+     """
+     Unet's up block containing resnet, upsampler, and cross-attention blocks.
+
+     Args:
+         spatial_dims: The number of spatial dimensions.
+         in_channels: number of input channels.
+         prev_output_channel: number of channels from residual connection.
+         out_channels: number of output channels.
+         temb_channels: number of timestep embedding channels.
+         num_res_blocks: number of residual blocks.
+         norm_num_groups: number of groups for the group normalization.
+         norm_eps: epsilon for the group normalization.
+         add_upsample: if True add upsample block.
+         resblock_updown: if True use residual blocks for upsampling.
+         num_head_channels: number of channels in each attention head.
+         transformer_num_layers: number of layers of Transformer blocks to use.
+         cross_attention_dim: number of context dimensions to use.
+         upcast_attention: if True, upcast attention operations to full precision.
+         dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         prev_output_channel: int,
+         out_channels: int,
+         temb_channels: int,
+         num_res_blocks: int = 1,
+         norm_num_groups: int = 32,
+         norm_eps: float = 1e-6,
+         add_upsample: bool = True,
+         resblock_updown: bool = False,
+         num_head_channels: int = 1,
+         transformer_num_layers: int = 1,
+         cross_attention_dim: int | None = None,
+         upcast_attention: bool = False,
+         dropout_cattn: float = 0.0,
+     ) -> None:
+         super().__init__()
+         self.resblock_updown = resblock_updown
+
+         resnets = []
+         attentions = []
+
+         for i in range(num_res_blocks):
+             res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels
+             resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+             resnets.append(
+                 DiffusionUNetResnetBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=resnet_in_channels + res_skip_channels,
+                     out_channels=out_channels,
+                     temb_channels=temb_channels,
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                 )
+             )
+             attentions.append(
+                 SpatialTransformer(
+                     spatial_dims=spatial_dims,
+                     in_channels=out_channels,
+                     num_attention_heads=out_channels // num_head_channels,
+                     num_head_channels=num_head_channels,
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     num_layers=transformer_num_layers,
+                     cross_attention_dim=cross_attention_dim,
+                     upcast_attention=upcast_attention,
+                     dropout=dropout_cattn,
+                 )
+             )
+
+         self.attentions = nn.ModuleList(attentions)
+         self.resnets = nn.ModuleList(resnets)
+
+         self.upsampler: nn.Module | None
+         if add_upsample:
+             if resblock_updown:
+                 self.upsampler = DiffusionUNetResnetBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=out_channels,
+                     out_channels=out_channels,
+                     temb_channels=temb_channels,
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     up=True,
+                 )
+             else:
+                 post_conv = Convolution(
+                     spatial_dims=spatial_dims,
+                     in_channels=out_channels,
+                     out_channels=out_channels,
+                     strides=1,
+                     kernel_size=3,
+                     padding=1,
+                     conv_only=True,
+                 )
+                 self.upsampler = WrappedUpsample(
+                     spatial_dims=spatial_dims,
+                     mode="nontrainable",
+                     in_channels=out_channels,
+                     out_channels=out_channels,
+                     interp_mode="nearest",
+                     scale_factor=2.0,
+                     post_conv=post_conv,
+                     align_corners=None,
+                 )
+         else:
+             self.upsampler = None
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         res_hidden_states_list: list[torch.Tensor],
+         temb: torch.Tensor,
+         context: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         for resnet, attn in zip(self.resnets, self.attentions):
+             # pop res hidden states
+             res_hidden_states = res_hidden_states_list[-1]
+             res_hidden_states_list = res_hidden_states_list[:-1]
+             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+             hidden_states = resnet(hidden_states, temb)
+             hidden_states = attn(hidden_states, context=context)
+
+         if self.upsampler is not None:
+             hidden_states = self.upsampler(hidden_states, temb)
+
+         return hidden_states
+
+
+ def get_down_block(
+     spatial_dims: int,
+     in_channels: int,
+     out_channels: int,
+     temb_channels: int,
+     num_res_blocks: int,
+     norm_num_groups: int,
+     norm_eps: float,
+     add_downsample: bool,
+     resblock_updown: bool,
+     with_attn: bool,
+     with_cross_attn: bool,
+     num_head_channels: int,
+     transformer_num_layers: int,
+     cross_attention_dim: int | None,
+     upcast_attention: bool = False,
+     dropout_cattn: float = 0.0,
+ ) -> nn.Module:
+     if with_attn:
+         return AttnDownBlock(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             temb_channels=temb_channels,
+             num_res_blocks=num_res_blocks,
+             norm_num_groups=norm_num_groups,
+             norm_eps=norm_eps,
+             add_downsample=add_downsample,
+             resblock_updown=resblock_updown,
+             num_head_channels=num_head_channels,
+         )
+     elif with_cross_attn:
+         return CrossAttnDownBlock(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             temb_channels=temb_channels,
+             num_res_blocks=num_res_blocks,
+             norm_num_groups=norm_num_groups,
+             norm_eps=norm_eps,
+             add_downsample=add_downsample,
+             resblock_updown=resblock_updown,
+             num_head_channels=num_head_channels,
+             transformer_num_layers=transformer_num_layers,
+             cross_attention_dim=cross_attention_dim,
+             upcast_attention=upcast_attention,
+             dropout_cattn=dropout_cattn,
+         )
+     else:
+         return DownBlock(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             temb_channels=temb_channels,
+             num_res_blocks=num_res_blocks,
+             norm_num_groups=norm_num_groups,
+             norm_eps=norm_eps,
+             add_downsample=add_downsample,
+             resblock_updown=resblock_updown,
+         )
+
+
+ def get_mid_block(
+     spatial_dims: int,
+     in_channels: int,
+     temb_channels: int,
+     norm_num_groups: int,
+     norm_eps: float,
+     with_conditioning: bool,
+     num_head_channels: int,
+     transformer_num_layers: int,
+     cross_attention_dim: int | None,
+     upcast_attention: bool = False,
+     dropout_cattn: float = 0.0,
+ ) -> nn.Module:
+     if with_conditioning:
+         return CrossAttnMidBlock(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             temb_channels=temb_channels,
+             norm_num_groups=norm_num_groups,
+             norm_eps=norm_eps,
+             num_head_channels=num_head_channels,
+             transformer_num_layers=transformer_num_layers,
+             cross_attention_dim=cross_attention_dim,
+             upcast_attention=upcast_attention,
+             dropout_cattn=dropout_cattn,
+         )
+     else:
+         return AttnMidBlock(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             temb_channels=temb_channels,
+             norm_num_groups=norm_num_groups,
+             norm_eps=norm_eps,
+             num_head_channels=num_head_channels,
+         )
+
+
+ def get_up_block(
+     spatial_dims: int,
+     in_channels: int,
+     prev_output_channel: int,
+     out_channels: int,
+     temb_channels: int,
+     num_res_blocks: int,
+     norm_num_groups: int,
+     norm_eps: float,
+     add_upsample: bool,
+     resblock_updown: bool,
+     with_attn: bool,
+     with_cross_attn: bool,
+     num_head_channels: int,
+     transformer_num_layers: int,
+     cross_attention_dim: int | None,
+     upcast_attention: bool = False,
+     dropout_cattn: float = 0.0,
+ ) -> nn.Module:
+     if with_attn:
+         return AttnUpBlock(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             prev_output_channel=prev_output_channel,
+             out_channels=out_channels,
+             temb_channels=temb_channels,
+             num_res_blocks=num_res_blocks,
+             norm_num_groups=norm_num_groups,
+             norm_eps=norm_eps,
+             add_upsample=add_upsample,
+             resblock_updown=resblock_updown,
+             num_head_channels=num_head_channels,
+         )
+     elif with_cross_attn:
+         return CrossAttnUpBlock(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             prev_output_channel=prev_output_channel,
+             out_channels=out_channels,
+             temb_channels=temb_channels,
+             num_res_blocks=num_res_blocks,
+             norm_num_groups=norm_num_groups,
+             norm_eps=norm_eps,
+             add_upsample=add_upsample,
+             resblock_updown=resblock_updown,
+             num_head_channels=num_head_channels,
+             transformer_num_layers=transformer_num_layers,
+             cross_attention_dim=cross_attention_dim,
+             upcast_attention=upcast_attention,
+             dropout_cattn=dropout_cattn,
+         )
+     else:
+         return UpBlock(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             prev_output_channel=prev_output_channel,
+             out_channels=out_channels,
+             temb_channels=temb_channels,
+             num_res_blocks=num_res_blocks,
+             norm_num_groups=norm_num_groups,
+             norm_eps=norm_eps,
+             add_upsample=add_upsample,
+             resblock_updown=resblock_updown,
+         )
+
+
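These factories are how the DiffusionModelUNet constructor below selects a block per level: plain, self-attention, or cross-attention. A sketch of the dispatch (argument values assumed):

db = get_down_block(
    spatial_dims=2, in_channels=32, out_channels=64, temb_channels=128, num_res_blocks=1,
    norm_num_groups=32, norm_eps=1e-6, add_downsample=True, resblock_updown=False,
    with_attn=True, with_cross_attn=False, num_head_channels=8,
    transformer_num_layers=1, cross_attention_dim=None,
)
print(type(db).__name__)  # AttnDownBlock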
1396
+ class DiffusionModelUNet(nn.Module):
1397
+ """
1398
+ Unet network with timestep embedding and attention mechanisms for conditioning based on
1399
+ Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752
1400
+ and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162
1401
+
1402
+ Args:
1403
+ spatial_dims: number of spatial dimensions.
1404
+ in_channels: number of input channels.
1405
+ out_channels: number of output channels.
1406
+ num_res_blocks: number of residual blocks (see _ResnetBlock) per level.
1407
+ channels: tuple of block output channels.
1408
+ attention_levels: list of levels to add attention.
1409
+ norm_num_groups: number of groups for the normalization.
1410
+ norm_eps: epsilon for the normalization.
1411
+ resblock_updown: if True use residual blocks for up/downsampling.
1412
+ num_head_channels: number of channels in each attention head.
1413
+ with_conditioning: if True add spatial transformers to perform conditioning.
1414
+ transformer_num_layers: number of layers of Transformer blocks to use.
1415
+ cross_attention_dim: number of context dimensions to use.
1416
+ num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
1417
+ classes.
1418
+ upcast_attention: if True, upcast attention operations to full precision.
1419
+ dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers
1420
+ """
1421
+
1422
+ def __init__(
1423
+ self,
1424
+ spatial_dims: int,
1425
+ in_channels: int,
1426
+ out_channels: int,
1427
+ num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
1428
+ channels: Sequence[int] = (32, 64, 64, 64),
1429
+ attention_levels: Sequence[bool] = (False, False, True, True),
1430
+ norm_num_groups: int = 32,
1431
+ norm_eps: float = 1e-6,
1432
+ resblock_updown: bool = False,
1433
+ num_head_channels: int | Sequence[int] = 8,
1434
+ with_conditioning: bool = False,
1435
+ transformer_num_layers: int = 1,
1436
+ cross_attention_dim: int | None = None,
1437
+ num_class_embeds: int | None = None,
1438
+ upcast_attention: bool = False,
1439
+ dropout_cattn: float = 0.0,
1440
+ ) -> None:
1441
+ super().__init__()
1442
+ if with_conditioning is True and cross_attention_dim is None:
1443
+ raise ValueError(
1444
+ "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
1445
+ "when using with_conditioning."
1446
+ )
1447
+ if cross_attention_dim is not None and with_conditioning is False:
1448
+ raise ValueError(
1449
+ "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim."
1450
+ )
1451
+ if dropout_cattn > 1.0 or dropout_cattn < 0.0:
1452
+ raise ValueError("Dropout cannot be negative or >1.0!")
1453
+
1454
+ # All number of channels should be multiple of num_groups
1455
+ if any((out_channel % norm_num_groups) != 0 for out_channel in channels):
1456
+ raise ValueError("DiffusionModelUNet expects all num_channels being multiple of norm_num_groups")
1457
+
1458
+ if len(channels) != len(attention_levels):
1459
+ raise ValueError("DiffusionModelUNet expects num_channels being same size of attention_levels")
1460
+
1461
+ if isinstance(num_head_channels, int):
1462
+ num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))
1463
+
1464
+ if len(num_head_channels) != len(attention_levels):
1465
+ raise ValueError(
1466
+ "num_head_channels should have the same length as attention_levels. For the i levels without attention,"
1467
+ " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored."
1468
+ )
1469
+
1470
+ if isinstance(num_res_blocks, int):
1471
+ num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels))
1472
+
1473
+ if len(num_res_blocks) != len(channels):
1474
+ raise ValueError(
1475
+ "`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
1476
+ "`num_channels`."
1477
+ )
1478
+
1479
+ self.in_channels = in_channels
1480
+ self.block_out_channels = channels
1481
+ self.out_channels = out_channels
1482
+ self.num_res_blocks = num_res_blocks
1483
+ self.attention_levels = attention_levels
1484
+ self.num_head_channels = num_head_channels
1485
+ self.with_conditioning = with_conditioning
1486
+
1487
+ # input
1488
+ self.conv_in = Convolution(
1489
+ spatial_dims=spatial_dims,
1490
+ in_channels=in_channels,
1491
+ out_channels=channels[0],
1492
+ strides=1,
1493
+ kernel_size=3,
1494
+ padding=1,
1495
+ conv_only=True,
1496
+ )
1497
+
1498
+ # time
1499
+ time_embed_dim = channels[0] * 4
1500
+ self.time_embed = nn.Sequential(
1501
+ nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim)
1502
+ )
1503
+
1504
+ # class embedding
1505
+ self.num_class_embeds = num_class_embeds
1506
+ if num_class_embeds is not None:
1507
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
1508
+
1509
+ # down
1510
+ self.down_blocks = nn.ModuleList([])
1511
+ output_channel = channels[0]
1512
+ for i in range(len(channels)):
1513
+ input_channel = output_channel
1514
+ output_channel = channels[i]
1515
+ is_final_block = i == len(channels) - 1
1516
+
1517
+ down_block = get_down_block(
1518
+ spatial_dims=spatial_dims,
1519
+ in_channels=input_channel,
1520
+ out_channels=output_channel,
1521
+ temb_channels=time_embed_dim,
1522
+ num_res_blocks=num_res_blocks[i],
1523
+ norm_num_groups=norm_num_groups,
1524
+ norm_eps=norm_eps,
1525
+ add_downsample=not is_final_block,
1526
+ resblock_updown=resblock_updown,
1527
+ with_attn=(attention_levels[i] and not with_conditioning),
1528
+ with_cross_attn=(attention_levels[i] and with_conditioning),
1529
+ num_head_channels=num_head_channels[i],
1530
+ transformer_num_layers=transformer_num_layers,
1531
+ cross_attention_dim=cross_attention_dim,
1532
+ upcast_attention=upcast_attention,
1533
+ dropout_cattn=dropout_cattn,
1534
+ )
1535
+
1536
+ self.down_blocks.append(down_block)
1537
+
1538
+ # mid
1539
+ self.middle_block = get_mid_block(
1540
+ spatial_dims=spatial_dims,
1541
+ in_channels=channels[-1],
1542
+ temb_channels=time_embed_dim,
1543
+ norm_num_groups=norm_num_groups,
1544
+ norm_eps=norm_eps,
1545
+ with_conditioning=with_conditioning,
1546
+ num_head_channels=num_head_channels[-1],
1547
+ transformer_num_layers=transformer_num_layers,
1548
+ cross_attention_dim=cross_attention_dim,
1549
+ upcast_attention=upcast_attention,
1550
+ dropout_cattn=dropout_cattn,
1551
+ )
1552
+
1553
+ # up
+         self.up_blocks = nn.ModuleList([])
+         reversed_block_out_channels = list(reversed(channels))
+         reversed_num_res_blocks = list(reversed(num_res_blocks))
+         reversed_attention_levels = list(reversed(attention_levels))
+         reversed_num_head_channels = list(reversed(num_head_channels))
+         output_channel = reversed_block_out_channels[0]
+         for i in range(len(reversed_block_out_channels)):
+             prev_output_channel = output_channel
+             output_channel = reversed_block_out_channels[i]
+             input_channel = reversed_block_out_channels[min(i + 1, len(channels) - 1)]
+
+             is_final_block = i == len(channels) - 1
+
+             up_block = get_up_block(
+                 spatial_dims=spatial_dims,
+                 in_channels=input_channel,
+                 prev_output_channel=prev_output_channel,
+                 out_channels=output_channel,
+                 temb_channels=time_embed_dim,
+                 num_res_blocks=reversed_num_res_blocks[i] + 1,
+                 norm_num_groups=norm_num_groups,
+                 norm_eps=norm_eps,
+                 add_upsample=not is_final_block,
+                 resblock_updown=resblock_updown,
+                 with_attn=(reversed_attention_levels[i] and not with_conditioning),
+                 with_cross_attn=(reversed_attention_levels[i] and with_conditioning),
+                 num_head_channels=reversed_num_head_channels[i],
+                 transformer_num_layers=transformer_num_layers,
+                 cross_attention_dim=cross_attention_dim,
+                 upcast_attention=upcast_attention,
+                 dropout_cattn=dropout_cattn,
+             )
+
+             self.up_blocks.append(up_block)
+
+         # out
+         self.out = nn.Sequential(
+             nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[0], eps=norm_eps, affine=True),
+             nn.SiLU(),
+             zero_module(
+                 Convolution(
+                     spatial_dims=spatial_dims,
+                     in_channels=channels[0],
+                     out_channels=out_channels,
+                     strides=1,
+                     kernel_size=3,
+                     padding=1,
+                     conv_only=True,
+                 )
+             ),
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         timesteps: torch.Tensor,
+         context: torch.Tensor | None = None,
+         class_labels: torch.Tensor | None = None,
+         down_block_additional_residuals: tuple[torch.Tensor] | None = None,
+         mid_block_additional_residual: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         """
+         Args:
+             x: input tensor (N, C, SpatialDims).
+             timesteps: timestep tensor (N,).
+             context: context tensor (N, 1, ContextDim).
+             class_labels: class label tensor (N,).
+             down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims).
+             mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims).
+         """
+         # 1. time
+         t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])
+
+         # get_timestep_embedding carries no weights and always returns float32 tensors,
+         # but time_embed may be running in fp16, so we need to cast here.
+         # there might be better ways to encapsulate this.
+         t_emb = t_emb.to(dtype=x.dtype)
+         emb = self.time_embed(t_emb)
+
+         # 2. class
+         if self.num_class_embeds is not None:
+             if class_labels is None:
+                 raise ValueError("class_labels should be provided when num_class_embeds > 0")
+             class_emb = self.class_embedding(class_labels)
+             class_emb = class_emb.to(dtype=x.dtype)
+             emb = emb + class_emb
+
+         # 3. initial convolution
+         h = self.conv_in(x)
+
+         # 4. down
+         if context is not None and self.with_conditioning is False:
+             raise ValueError("model should have with_conditioning = True if context is provided")
+         down_block_res_samples: list[torch.Tensor] = [h]
+         for downsample_block in self.down_blocks:
+             h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)
+             for residual in res_samples:
+                 down_block_res_samples.append(residual)
+
+         # additional residual connections for ControlNets
+         if down_block_additional_residuals is not None:
+             new_down_block_res_samples: list[torch.Tensor] = []
+             for down_block_res_sample, down_block_additional_residual in zip(
+                 down_block_res_samples, down_block_additional_residuals
+             ):
+                 down_block_res_sample = down_block_res_sample + down_block_additional_residual
+                 new_down_block_res_samples += [down_block_res_sample]
+
+             down_block_res_samples = new_down_block_res_samples
+
+         # 5. mid
+         h = self.middle_block(hidden_states=h, temb=emb, context=context)
+
+         # additional residual connection for ControlNets
+         if mid_block_additional_residual is not None:
+             h = h + mid_block_additional_residual
+
+         # 6. up
+         for upsample_block in self.up_blocks:
+             res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+             down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+             h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context)
+
+         # 7. output block
+         output: torch.Tensor = self.out(h)
+
+         return output
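For orientation, a minimal usage sketch of this forward pass, assuming a deliberately small 2D configuration (all values below are illustrative, not the defaults):

    import torch
    from monai.networks.nets import DiffusionModelUNet

    net = DiffusionModelUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=(8, 8, 8),
        attention_levels=(False, False, True),
        num_res_blocks=1,
        num_head_channels=8,
        norm_num_groups=8,
    )
    x = torch.randn(2, 1, 32, 32)       # noisy image batch (N, C, H, W)
    t = torch.randint(0, 1000, (2,))    # one timestep per batch element
    noise_pred = net(x=x, timesteps=t)  # same shape as x: (2, 1, 32, 32)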
+
+     def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
+         """
+         Load a state dict from a DiffusionModelUNet trained with
+         [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).
+
+         Args:
+             old_state_dict: state dict from the old DiffusionModelUNet model.
+         """
+
+         new_state_dict = self.state_dict()
+         # if all keys match, just load the state dict
+         if all(k in new_state_dict for k in old_state_dict):
+             print("All keys match, loading state dict.")
+             self.load_state_dict(old_state_dict)
+             return
+
+         if verbose:
+             # print all new_state_dict keys that are not in old_state_dict
+             for k in new_state_dict:
+                 if k not in old_state_dict:
+                     print(f"key {k} not found in old state dict")
+             # and vice versa
+             print("----------------------------------------------")
+             for k in old_state_dict:
+                 if k not in new_state_dict:
+                     print(f"key {k} not found in new state dict")
+
+         # copy over all matching keys
+         for k in new_state_dict:
+             if k in old_state_dict:
+                 new_state_dict[k] = old_state_dict[k]
+
+         # fix the attention blocks
+         attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k]
+         for block in attention_blocks:
+             new_state_dict[f"{block}.attn1.qkv.weight"] = torch.cat(
+                 [
+                     old_state_dict[f"{block}.attn1.to_q.weight"],
+                     old_state_dict[f"{block}.attn1.to_k.weight"],
+                     old_state_dict[f"{block}.attn1.to_v.weight"],
+                 ],
+                 dim=0,
+             )
+
+             # projection
+             new_state_dict[f"{block}.attn1.out_proj.weight"] = old_state_dict[f"{block}.attn1.to_out.0.weight"]
+             new_state_dict[f"{block}.attn1.out_proj.bias"] = old_state_dict[f"{block}.attn1.to_out.0.bias"]
+
+             new_state_dict[f"{block}.attn2.out_proj.weight"] = old_state_dict[f"{block}.attn2.to_out.0.weight"]
+             new_state_dict[f"{block}.attn2.out_proj.bias"] = old_state_dict[f"{block}.attn2.to_out.0.bias"]
+         # fix the upsample conv blocks which were renamed postconv
+         for k in new_state_dict:
+             if "postconv" in k:
+                 old_name = k.replace("postconv", "conv")
+                 new_state_dict[k] = old_state_dict[old_name]
+         self.load_state_dict(new_state_dict)
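The q/k/v concatenation above works because stacking the three projection matrices row-wise is equivalent to applying them separately; a small standalone check of that equivalence (shapes illustrative):

    import torch

    embed_dim = 8
    to_q, to_k, to_v = (torch.randn(embed_dim, embed_dim) for _ in range(3))

    # fused weight, stacked in q, k, v order as in load_old_state_dict above
    qkv = torch.cat([to_q, to_k, to_v], dim=0)  # (3 * embed_dim, embed_dim)

    x = torch.randn(4, embed_dim)
    q, k, v = (x @ qkv.t()).chunk(3, dim=-1)    # one matmul instead of three
    assert torch.allclose(q, x @ to_q.t(), atol=1e-6)
    assert torch.allclose(v, x @ to_v.t(), atol=1e-6)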
+
+
+ class DiffusionModelEncoder(nn.Module):
+     """
+     Classification network based on the encoder of the diffusion model, followed by fully connected layers.
+     This network is based on Wolleb et al. "Diffusion Models for Medical Anomaly Detection"
+     (https://arxiv.org/abs/2203.04306).
+
+     Args:
+         spatial_dims: number of spatial dimensions.
+         in_channels: number of input channels.
+         out_channels: number of output channels.
+         num_res_blocks: number of residual blocks (see _ResnetBlock) per level.
+         channels: tuple of block output channels.
+         attention_levels: list of levels to add attention.
+         norm_num_groups: number of groups for the normalization.
+         norm_eps: epsilon for the normalization.
+         resblock_updown: if True use residual blocks for downsampling.
+         num_head_channels: number of channels in each attention head.
+         with_conditioning: if True add spatial transformers to perform conditioning.
+         transformer_num_layers: number of layers of Transformer blocks to use.
+         cross_attention_dim: number of context dimensions to use.
+         num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes.
+         upcast_attention: if True, upcast attention operations to full precision.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         out_channels: int,
+         num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
+         channels: Sequence[int] = (32, 64, 64, 64),
+         attention_levels: Sequence[bool] = (False, False, True, True),
+         norm_num_groups: int = 32,
+         norm_eps: float = 1e-6,
+         resblock_updown: bool = False,
+         num_head_channels: int | Sequence[int] = 8,
+         with_conditioning: bool = False,
+         transformer_num_layers: int = 1,
+         cross_attention_dim: int | None = None,
+         num_class_embeds: int | None = None,
+         upcast_attention: bool = False,
+     ) -> None:
+         super().__init__()
+         if with_conditioning is True and cross_attention_dim is None:
+             raise ValueError(
+                 "DiffusionModelEncoder expects dimension of the cross-attention conditioning (cross_attention_dim) "
+                 "when using with_conditioning."
+             )
+         if cross_attention_dim is not None and with_conditioning is False:
+             raise ValueError(
+                 "DiffusionModelEncoder expects with_conditioning=True when specifying the cross_attention_dim."
+             )
+
+         # all channel counts should be multiples of norm_num_groups
+         if any((out_channel % norm_num_groups) != 0 for out_channel in channels):
+             raise ValueError("DiffusionModelEncoder expects all channels to be a multiple of norm_num_groups")
+         if len(channels) != len(attention_levels):
+             raise ValueError("DiffusionModelEncoder expects channels to have the same length as attention_levels")
+
+         if isinstance(num_head_channels, int):
+             num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))
+
+         if isinstance(num_res_blocks, int):
+             num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels))
+
+         if len(num_head_channels) != len(attention_levels):
+             raise ValueError(
+                 "num_head_channels should have the same length as attention_levels. For levels without attention,"
+                 " i.e. `attention_levels[i]=False`, the corresponding num_head_channels[i] is ignored."
+             )
+
+         self.in_channels = in_channels
+         self.block_out_channels = channels
+         self.out_channels = out_channels
+         self.num_res_blocks = num_res_blocks
+         self.attention_levels = attention_levels
+         self.num_head_channels = num_head_channels
+         self.with_conditioning = with_conditioning
+
+         # input
+         self.conv_in = Convolution(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             out_channels=channels[0],
+             strides=1,
+             kernel_size=3,
+             padding=1,
+             conv_only=True,
+         )
+
+         # time
+         time_embed_dim = channels[0] * 4
+         self.time_embed = nn.Sequential(
+             nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim)
+         )
+
+         # class embedding
+         self.num_class_embeds = num_class_embeds
+         if num_class_embeds is not None:
+             self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+
+         # down
+         self.down_blocks = nn.ModuleList([])
+         output_channel = channels[0]
+         for i in range(len(channels)):
+             input_channel = output_channel
+             output_channel = channels[i]
+             is_final_block = i == len(channels)  # always False, so every level downsamples
+
+             down_block = get_down_block(
+                 spatial_dims=spatial_dims,
+                 in_channels=input_channel,
+                 out_channels=output_channel,
+                 temb_channels=time_embed_dim,
+                 num_res_blocks=num_res_blocks[i],
+                 norm_num_groups=norm_num_groups,
+                 norm_eps=norm_eps,
+                 add_downsample=not is_final_block,
+                 resblock_updown=resblock_updown,
+                 with_attn=(attention_levels[i] and not with_conditioning),
+                 with_cross_attn=(attention_levels[i] and with_conditioning),
+                 num_head_channels=num_head_channels[i],
+                 transformer_num_layers=transformer_num_layers,
+                 cross_attention_dim=cross_attention_dim,
+                 upcast_attention=upcast_attention,
+             )
+
+             self.down_blocks.append(down_block)
+
+         self.out = nn.Sequential(nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels))
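+         # NB: the hard-coded Linear(4096, 512) above assumes the flattened encoder
+         # output has exactly 4096 features, e.g. a 2D 128x128 input with the default
+         # channels (32, 64, 64, 64): four downsampling levels give 64 * 8 * 8 = 4096.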
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         timesteps: torch.Tensor,
+         context: torch.Tensor | None = None,
+         class_labels: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         """
+         Args:
+             x: input tensor (N, C, SpatialDims).
+             timesteps: timestep tensor (N,).
+             context: context tensor (N, 1, ContextDim).
+             class_labels: class label tensor (N,).
+         """
+         # 1. time
+         t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])
+
+         # get_timestep_embedding carries no weights and always returns float32 tensors,
+         # but time_embed may be running in fp16, so we need to cast here.
+         # there might be better ways to encapsulate this.
+         t_emb = t_emb.to(dtype=x.dtype)
+         emb = self.time_embed(t_emb)
+
+         # 2. class
+         if self.num_class_embeds is not None:
+             if class_labels is None:
+                 raise ValueError("class_labels should be provided when num_class_embeds > 0")
+             class_emb = self.class_embedding(class_labels)
+             class_emb = class_emb.to(dtype=x.dtype)
+             emb = emb + class_emb
+
+         # 3. initial convolution
+         h = self.conv_in(x)
+
+         # 4. down
+         if context is not None and self.with_conditioning is False:
+             raise ValueError("model should have with_conditioning = True if context is provided")
+         for downsample_block in self.down_blocks:
+             h, _ = downsample_block(hidden_states=h, temb=emb, context=context)
+
+         h = h.reshape(h.shape[0], -1)
+         output: torch.Tensor = self.out(h)
+
+         return output
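A matching usage sketch for the encoder, with an illustrative configuration chosen so that the flattened features hit the hard-coded 4096 (two downsampling levels take 64x64 to 16x16, and 16 * 16 * 16 = 4096); the import path assumes the class is exported from monai.networks.nets as in this release:

    import torch
    from monai.networks.nets import DiffusionModelEncoder

    net = DiffusionModelEncoder(
        spatial_dims=2,
        in_channels=1,
        out_channels=2,             # e.g. normal vs. anomalous
        channels=(8, 16),
        attention_levels=(False, False),
        num_res_blocks=1,
        norm_num_groups=8,
    )
    x = torch.randn(2, 1, 64, 64)
    t = torch.randint(0, 1000, (2,))
    logits = net(x=x, timesteps=t)  # -> (2, 2)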