monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2430__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/auto3dseg/hpo_gen.py +1 -1
  4. monai/apps/detection/utils/anchor_utils.py +2 -2
  5. monai/apps/pathology/transforms/post/array.py +7 -4
  6. monai/auto3dseg/analyzer.py +1 -1
  7. monai/bundle/scripts.py +204 -22
  8. monai/bundle/utils.py +1 -0
  9. monai/data/dataset_summary.py +1 -0
  10. monai/data/meta_tensor.py +2 -2
  11. monai/data/test_time_augmentation.py +2 -0
  12. monai/data/utils.py +9 -6
  13. monai/data/wsi_reader.py +2 -2
  14. monai/engines/__init__.py +3 -1
  15. monai/engines/trainer.py +281 -2
  16. monai/engines/utils.py +76 -1
  17. monai/handlers/mlflow_handler.py +21 -4
  18. monai/inferers/__init__.py +5 -0
  19. monai/inferers/inferer.py +1279 -1
  20. monai/metrics/cumulative_average.py +2 -0
  21. monai/metrics/panoptic_quality.py +1 -1
  22. monai/metrics/rocauc.py +2 -2
  23. monai/networks/blocks/__init__.py +3 -0
  24. monai/networks/blocks/attention_utils.py +128 -0
  25. monai/networks/blocks/crossattention.py +168 -0
  26. monai/networks/blocks/rel_pos_embedding.py +56 -0
  27. monai/networks/blocks/selfattention.py +74 -5
  28. monai/networks/blocks/spade_norm.py +95 -0
  29. monai/networks/blocks/spatialattention.py +82 -0
  30. monai/networks/blocks/transformerblock.py +25 -4
  31. monai/networks/blocks/upsample.py +22 -10
  32. monai/networks/layers/__init__.py +2 -1
  33. monai/networks/layers/factories.py +12 -1
  34. monai/networks/layers/simplelayers.py +1 -1
  35. monai/networks/layers/utils.py +14 -1
  36. monai/networks/layers/vector_quantizer.py +233 -0
  37. monai/networks/nets/__init__.py +9 -0
  38. monai/networks/nets/autoencoderkl.py +702 -0
  39. monai/networks/nets/controlnet.py +465 -0
  40. monai/networks/nets/diffusion_model_unet.py +1913 -0
  41. monai/networks/nets/patchgan_discriminator.py +230 -0
  42. monai/networks/nets/quicknat.py +8 -6
  43. monai/networks/nets/resnet.py +3 -4
  44. monai/networks/nets/spade_autoencoderkl.py +480 -0
  45. monai/networks/nets/spade_diffusion_model_unet.py +934 -0
  46. monai/networks/nets/spade_network.py +435 -0
  47. monai/networks/nets/swin_unetr.py +4 -3
  48. monai/networks/nets/transformer.py +157 -0
  49. monai/networks/nets/vqvae.py +472 -0
  50. monai/networks/schedulers/__init__.py +17 -0
  51. monai/networks/schedulers/ddim.py +294 -0
  52. monai/networks/schedulers/ddpm.py +250 -0
  53. monai/networks/schedulers/pndm.py +316 -0
  54. monai/networks/schedulers/scheduler.py +205 -0
  55. monai/networks/utils.py +22 -0
  56. monai/transforms/croppad/array.py +8 -8
  57. monai/transforms/croppad/dictionary.py +4 -4
  58. monai/transforms/croppad/functional.py +1 -1
  59. monai/transforms/regularization/array.py +4 -0
  60. monai/transforms/spatial/array.py +1 -1
  61. monai/transforms/utils_create_transform_ims.py +2 -4
  62. monai/utils/__init__.py +1 -0
  63. monai/utils/misc.py +5 -4
  64. monai/utils/ordering.py +207 -0
  65. monai/visualize/class_activation_maps.py +5 -5
  66. monai/visualize/img2tensorboard.py +3 -1
  67. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/METADATA +1 -1
  68. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/RECORD +71 -50
  69. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/WHEEL +1 -1
  70. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/LICENSE +0 -0
  71. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/monai/networks/nets/spade_diffusion_model_unet.py
@@ -0,0 +1,934 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# =========================================================================
+# Adapted from https://github.com/huggingface/diffusers
+# which has the following license:
+# https://github.com/huggingface/diffusers/blob/main/LICENSE
+#
+# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+import torch
+from torch import nn
+
+from monai.networks.blocks import Convolution, SpatialAttentionBlock
+from monai.networks.blocks.spade_norm import SPADE
+from monai.networks.nets.diffusion_model_unet import (
+    DiffusionUnetDownsample,
+    DiffusionUNetResnetBlock,
+    SpatialTransformer,
+    WrappedUpsample,
+    get_down_block,
+    get_mid_block,
+    get_timestep_embedding,
+    zero_module,
+)
+from monai.utils import ensure_tuple_rep
+
+__all__ = ["SPADEDiffusionModelUNet"]
+
+
+class SPADEDiffResBlock(nn.Module):
+    """
+    Residual block with timestep conditioning and SPADE norm.
+    Enables SPADE normalisation for semantic conditioning (Park et al. (2019): https://github.com/NVlabs/SPADE).
+
+    Args:
+        spatial_dims: number of spatial dimensions.
+        in_channels: number of input channels.
+        temb_channels: number of timestep embedding channels.
+        label_nc: number of semantic channels for SPADE normalisation.
+        out_channels: number of output channels.
+        up: if True, performs upsampling.
+        down: if True, performs downsampling.
+        norm_num_groups: number of groups for the group normalization.
+        norm_eps: epsilon for the group normalization.
+        spade_intermediate_channels: number of intermediate channels for the SPADE block layer.
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        temb_channels: int,
+        label_nc: int,
+        out_channels: int | None = None,
+        up: bool = False,
+        down: bool = False,
+        norm_num_groups: int = 32,
+        norm_eps: float = 1e-6,
+        spade_intermediate_channels: int = 128,
+    ) -> None:
+        super().__init__()
+        self.spatial_dims = spatial_dims
+        self.channels = in_channels
+        self.emb_channels = temb_channels
+        self.out_channels = out_channels or in_channels
+        self.up = up
+        self.down = down
+
+        self.norm1 = SPADE(
+            label_nc=label_nc,
+            norm_nc=in_channels,
+            norm="GROUP",
+            norm_params={"num_groups": norm_num_groups, "eps": norm_eps, "affine": True},
+            hidden_channels=spade_intermediate_channels,
+            kernel_size=3,
+            spatial_dims=spatial_dims,
+        )
+
+        self.nonlinearity = nn.SiLU()
+        self.conv1 = Convolution(
+            spatial_dims=spatial_dims,
+            in_channels=in_channels,
+            out_channels=self.out_channels,
+            strides=1,
+            kernel_size=3,
+            padding=1,
+            conv_only=True,
+        )
+
+        self.upsample = self.downsample = None
+        if self.up:
+            self.upsample = WrappedUpsample(
+                spatial_dims=spatial_dims,
+                mode="nontrainable",
+                in_channels=in_channels,
+                out_channels=in_channels,
+                interp_mode="nearest",
+                scale_factor=2.0,
+                align_corners=None,
+            )
+        elif down:
+            self.downsample = DiffusionUnetDownsample(spatial_dims, in_channels, use_conv=False)
+
+        self.time_emb_proj = nn.Linear(temb_channels, self.out_channels)
+
+        self.norm2 = SPADE(
+            label_nc=label_nc,
+            norm_nc=self.out_channels,
+            norm="GROUP",
+            norm_params={"num_groups": norm_num_groups, "eps": norm_eps, "affine": True},
+            hidden_channels=spade_intermediate_channels,
+            kernel_size=3,
+            spatial_dims=spatial_dims,
+        )
+        self.conv2 = zero_module(
+            Convolution(
+                spatial_dims=spatial_dims,
+                in_channels=self.out_channels,
+                out_channels=self.out_channels,
+                strides=1,
+                kernel_size=3,
+                padding=1,
+                conv_only=True,
+            )
+        )
+        self.skip_connection: nn.Module
+
+        if self.out_channels == in_channels:
+            self.skip_connection = nn.Identity()
+        else:
+            self.skip_connection = Convolution(
+                spatial_dims=spatial_dims,
+                in_channels=in_channels,
+                out_channels=self.out_channels,
+                strides=1,
+                kernel_size=1,
+                padding=0,
+                conv_only=True,
+            )
+
+    def forward(self, x: torch.Tensor, emb: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:
+        h = x
+        h = self.norm1(h, seg)
+        h = self.nonlinearity(h)
+
+        if self.upsample is not None:
+            x = self.upsample(x)
+            h = self.upsample(h)
+        elif self.downsample is not None:
+            x = self.downsample(x)
+            h = self.downsample(h)
+
+        h = self.conv1(h)
+
+        if self.spatial_dims == 2:
+            temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None]
+        else:
+            temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None]
+        h = h + temb
+
+        h = self.norm2(h, seg)
+        h = self.nonlinearity(h)
+        h = self.conv2(h)
+        output: torch.Tensor = self.skip_connection(x) + h
+        return output
+
+
+class SPADEUpBlock(nn.Module):
+    """
+    UNet's up block containing resnet and upsampler blocks.
+    Enables SPADE normalisation for semantic conditioning (Park et al. (2019): https://github.com/NVlabs/SPADE).
+
+    Args:
+        spatial_dims: number of spatial dimensions.
+        in_channels: number of input channels.
+        prev_output_channel: number of channels from the residual connection.
+        out_channels: number of output channels.
+        temb_channels: number of timestep embedding channels.
+        label_nc: number of semantic channels for SPADE normalisation.
+        num_res_blocks: number of residual blocks.
+        norm_num_groups: number of groups for the group normalization.
+        norm_eps: epsilon for the group normalization.
+        add_upsample: if True, add an upsample block.
+        resblock_updown: if True, use residual blocks for upsampling.
+        spade_intermediate_channels: number of intermediate channels for the SPADE block layer.
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        prev_output_channel: int,
+        out_channels: int,
+        temb_channels: int,
+        label_nc: int,
+        num_res_blocks: int = 1,
+        norm_num_groups: int = 32,
+        norm_eps: float = 1e-6,
+        add_upsample: bool = True,
+        resblock_updown: bool = False,
+        spade_intermediate_channels: int = 128,
+    ) -> None:
+        super().__init__()
+        self.resblock_updown = resblock_updown
+        resnets = []
+
+        for i in range(num_res_blocks):
+            res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            resnets.append(
+                SPADEDiffResBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    label_nc=label_nc,
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    spade_intermediate_channels=spade_intermediate_channels,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+
+        self.upsampler: nn.Module | None
+        if add_upsample:
+            if resblock_updown:
+                self.upsampler = DiffusionUNetResnetBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=out_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    up=True,
+                )
+            else:
+                post_conv = Convolution(
+                    spatial_dims=spatial_dims,
+                    in_channels=out_channels,
+                    out_channels=out_channels,
+                    strides=1,
+                    kernel_size=3,
+                    padding=1,
+                    conv_only=True,
+                )
+                self.upsampler = WrappedUpsample(
+                    spatial_dims=spatial_dims,
+                    mode="nontrainable",
+                    in_channels=out_channels,
+                    out_channels=out_channels,
+                    interp_mode="nearest",
+                    scale_factor=2.0,
+                    post_conv=post_conv,
+                    align_corners=None,
+                )
+        else:
+            self.upsampler = None
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        res_hidden_states_list: list[torch.Tensor],
+        temb: torch.Tensor,
+        seg: torch.Tensor,
+        context: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        del context
+        for resnet in self.resnets:
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_list[-1]
+            res_hidden_states_list = res_hidden_states_list[:-1]
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+            hidden_states = resnet(hidden_states, temb, seg)
+
+        if self.upsampler is not None:
+            hidden_states = self.upsampler(hidden_states, temb)
+
+        return hidden_states
+
+
+class SPADEAttnUpBlock(nn.Module):
+    """
+    UNet's up block containing resnet, upsampler, and self-attention blocks.
+    Enables SPADE normalisation for semantic conditioning (Park et al. (2019): https://github.com/NVlabs/SPADE).
+
+    Args:
+        spatial_dims: number of spatial dimensions.
+        in_channels: number of input channels.
+        prev_output_channel: number of channels from the residual connection.
+        out_channels: number of output channels.
+        temb_channels: number of timestep embedding channels.
+        label_nc: number of semantic channels for SPADE normalisation.
+        num_res_blocks: number of residual blocks.
+        norm_num_groups: number of groups for the group normalization.
+        norm_eps: epsilon for the group normalization.
+        add_upsample: if True, add an upsample block.
+        resblock_updown: if True, use residual blocks for upsampling.
+        num_head_channels: number of channels in each attention head.
+        spade_intermediate_channels: number of intermediate channels for the SPADE block layer.
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        prev_output_channel: int,
+        out_channels: int,
+        temb_channels: int,
+        label_nc: int,
+        num_res_blocks: int = 1,
+        norm_num_groups: int = 32,
+        norm_eps: float = 1e-6,
+        add_upsample: bool = True,
+        resblock_updown: bool = False,
+        num_head_channels: int = 1,
+        spade_intermediate_channels: int = 128,
+    ) -> None:
+        super().__init__()
+        self.resblock_updown = resblock_updown
+        resnets = []
+        attentions = []
+
+        for i in range(num_res_blocks):
+            res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            resnets.append(
+                SPADEDiffResBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    label_nc=label_nc,
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    spade_intermediate_channels=spade_intermediate_channels,
+                )
+            )
+            attentions.append(
+                SpatialAttentionBlock(
+                    spatial_dims=spatial_dims,
+                    num_channels=out_channels,
+                    num_head_channels=num_head_channels,
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+        self.attentions = nn.ModuleList(attentions)
+
+        self.upsampler: nn.Module | None
+        if add_upsample:
+            if resblock_updown:
+                self.upsampler = DiffusionUNetResnetBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=out_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    up=True,
+                )
+            else:
+                post_conv = Convolution(
+                    spatial_dims=spatial_dims,
+                    in_channels=out_channels,
+                    out_channels=out_channels,
+                    strides=1,
+                    kernel_size=3,
+                    padding=1,
+                    conv_only=True,
+                )
+                self.upsampler = WrappedUpsample(
+                    spatial_dims=spatial_dims,
+                    mode="nontrainable",
+                    in_channels=out_channels,
+                    out_channels=out_channels,
+                    interp_mode="nearest",
+                    scale_factor=2.0,
+                    post_conv=post_conv,
+                    align_corners=None,
+                )
+        else:
+            self.upsampler = None
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        res_hidden_states_list: list[torch.Tensor],
+        temb: torch.Tensor,
+        seg: torch.Tensor,
+        context: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        del context
+        for resnet, attn in zip(self.resnets, self.attentions):
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_list[-1]
+            res_hidden_states_list = res_hidden_states_list[:-1]
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+            hidden_states = resnet(hidden_states, temb, seg)
+            hidden_states = attn(hidden_states).contiguous()
+
+        if self.upsampler is not None:
+            hidden_states = self.upsampler(hidden_states, temb)
+
+        return hidden_states
+
+
+class SPADECrossAttnUpBlock(nn.Module):
+    """
+    UNet's up block containing resnet, upsampler, and cross-attention blocks.
+    Enables SPADE normalisation for semantic conditioning (Park et al. (2019): https://github.com/NVlabs/SPADE).
+
+    Args:
+        spatial_dims: number of spatial dimensions.
+        in_channels: number of input channels.
+        prev_output_channel: number of channels from the residual connection.
+        out_channels: number of output channels.
+        temb_channels: number of timestep embedding channels.
+        label_nc: number of semantic channels for SPADE normalisation.
+        num_res_blocks: number of residual blocks.
+        norm_num_groups: number of groups for the group normalization.
+        norm_eps: epsilon for the group normalization.
+        add_upsample: if True, add an upsample block.
+        resblock_updown: if True, use residual blocks for upsampling.
+        num_head_channels: number of channels in each attention head.
+        transformer_num_layers: number of layers of Transformer blocks to use.
+        cross_attention_dim: number of context dimensions to use.
+        upcast_attention: if True, upcast attention operations to full precision.
+        spade_intermediate_channels: number of intermediate channels for the SPADE block layer.
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        prev_output_channel: int,
+        out_channels: int,
+        temb_channels: int,
+        label_nc: int,
+        num_res_blocks: int = 1,
+        norm_num_groups: int = 32,
+        norm_eps: float = 1e-6,
+        add_upsample: bool = True,
+        resblock_updown: bool = False,
+        num_head_channels: int = 1,
+        transformer_num_layers: int = 1,
+        cross_attention_dim: int | None = None,
+        upcast_attention: bool = False,
+        spade_intermediate_channels: int = 128,
+    ) -> None:
+        super().__init__()
+        self.resblock_updown = resblock_updown
+        resnets = []
+        attentions = []
+
+        for i in range(num_res_blocks):
+            res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            resnets.append(
+                SPADEDiffResBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    label_nc=label_nc,
+                    spade_intermediate_channels=spade_intermediate_channels,
+                )
+            )
+            attentions.append(
+                SpatialTransformer(
+                    spatial_dims=spatial_dims,
+                    in_channels=out_channels,
+                    num_attention_heads=out_channels // num_head_channels,
+                    num_head_channels=num_head_channels,
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    num_layers=transformer_num_layers,
+                    cross_attention_dim=cross_attention_dim,
+                    upcast_attention=upcast_attention,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        self.upsampler: nn.Module | None
+        if add_upsample:
+            if resblock_updown:
+                self.upsampler = DiffusionUNetResnetBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=out_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    up=True,
+                )
+            else:
+                post_conv = Convolution(
+                    spatial_dims=spatial_dims,
+                    in_channels=out_channels,
+                    out_channels=out_channels,
+                    strides=1,
+                    kernel_size=3,
+                    padding=1,
+                    conv_only=True,
+                )
+                self.upsampler = WrappedUpsample(
+                    spatial_dims=spatial_dims,
+                    mode="nontrainable",
+                    in_channels=out_channels,
+                    out_channels=out_channels,
+                    interp_mode="nearest",
+                    scale_factor=2.0,
+                    post_conv=post_conv,
+                    align_corners=None,
+                )
+        else:
+            self.upsampler = None
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        res_hidden_states_list: list[torch.Tensor],
+        temb: torch.Tensor,
+        seg: torch.Tensor | None = None,
+        context: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        for resnet, attn in zip(self.resnets, self.attentions):
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_list[-1]
+            res_hidden_states_list = res_hidden_states_list[:-1]
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+            hidden_states = resnet(hidden_states, temb, seg)
+            hidden_states = attn(hidden_states, context=context).contiguous()
+
+        if self.upsampler is not None:
+            hidden_states = self.upsampler(hidden_states, temb)
+
+        return hidden_states
+
+
+def get_spade_up_block(
+    spatial_dims: int,
+    in_channels: int,
+    prev_output_channel: int,
+    out_channels: int,
+    temb_channels: int,
+    num_res_blocks: int,
+    norm_num_groups: int,
+    norm_eps: float,
+    add_upsample: bool,
+    resblock_updown: bool,
+    with_attn: bool,
+    with_cross_attn: bool,
+    num_head_channels: int,
+    transformer_num_layers: int,
+    label_nc: int,
+    cross_attention_dim: int | None,
+    upcast_attention: bool = False,
+    spade_intermediate_channels: int = 128,
+) -> nn.Module:
+    if with_attn:
+        return SPADEAttnUpBlock(
+            spatial_dims=spatial_dims,
+            in_channels=in_channels,
+            prev_output_channel=prev_output_channel,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            label_nc=label_nc,
+            num_res_blocks=num_res_blocks,
+            norm_num_groups=norm_num_groups,
+            norm_eps=norm_eps,
+            add_upsample=add_upsample,
+            resblock_updown=resblock_updown,
+            num_head_channels=num_head_channels,
+            spade_intermediate_channels=spade_intermediate_channels,
+        )
+    elif with_cross_attn:
+        return SPADECrossAttnUpBlock(
+            spatial_dims=spatial_dims,
+            in_channels=in_channels,
+            prev_output_channel=prev_output_channel,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            label_nc=label_nc,
+            num_res_blocks=num_res_blocks,
+            norm_num_groups=norm_num_groups,
+            norm_eps=norm_eps,
+            add_upsample=add_upsample,
+            resblock_updown=resblock_updown,
+            num_head_channels=num_head_channels,
+            transformer_num_layers=transformer_num_layers,
+            cross_attention_dim=cross_attention_dim,
+            upcast_attention=upcast_attention,
+            spade_intermediate_channels=spade_intermediate_channels,
+        )
+    else:
+        return SPADEUpBlock(
+            spatial_dims=spatial_dims,
+            in_channels=in_channels,
+            prev_output_channel=prev_output_channel,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            label_nc=label_nc,
+            num_res_blocks=num_res_blocks,
+            norm_num_groups=norm_num_groups,
+            norm_eps=norm_eps,
+            add_upsample=add_upsample,
+            resblock_updown=resblock_updown,
+            spade_intermediate_channels=spade_intermediate_channels,
+        )
+
+
+class SPADEDiffusionModelUNet(nn.Module):
+    """
+    UNet network with timestep embedding and attention mechanisms for conditioning, with added SPADE normalization for
+    semantic conditioning (Park et al. (2019): https://github.com/NVlabs/SPADE). An example tutorial can be found at
+    https://github.com/Project-MONAI/GenerativeModels/tree/main/tutorials/generative/2d_spade_ldm
+
+    Args:
+        spatial_dims: number of spatial dimensions.
+        in_channels: number of input channels.
+        out_channels: number of output channels.
+        label_nc: number of semantic channels for SPADE normalisation.
+        num_res_blocks: number of residual blocks (see ResnetBlock) per level.
+        channels: tuple of block output channels.
+        attention_levels: list of levels to add attention.
+        norm_num_groups: number of groups for the normalization.
+        norm_eps: epsilon for the normalization.
+        resblock_updown: if True, use residual blocks for up/downsampling.
+        num_head_channels: number of channels in each attention head.
+        with_conditioning: if True, add spatial transformers to perform conditioning.
+        transformer_num_layers: number of layers of Transformer blocks to use.
+        cross_attention_dim: number of context dimensions to use.
+        num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
+            classes.
+        upcast_attention: if True, upcast attention operations to full precision.
+        spade_intermediate_channels: number of intermediate channels for the SPADE block layer.
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        out_channels: int,
+        label_nc: int,
+        num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
+        channels: Sequence[int] = (32, 64, 64, 64),
+        attention_levels: Sequence[bool] = (False, False, True, True),
+        norm_num_groups: int = 32,
+        norm_eps: float = 1e-6,
+        resblock_updown: bool = False,
+        num_head_channels: int | Sequence[int] = 8,
+        with_conditioning: bool = False,
+        transformer_num_layers: int = 1,
+        cross_attention_dim: int | None = None,
+        num_class_embeds: int | None = None,
+        upcast_attention: bool = False,
+        spade_intermediate_channels: int = 128,
+    ) -> None:
+        super().__init__()
+        if with_conditioning is True and cross_attention_dim is None:
+            raise ValueError(
+                "SPADEDiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
+                "when using with_conditioning."
+            )
+        if cross_attention_dim is not None and with_conditioning is False:
+            raise ValueError(
+                "SPADEDiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim."
+            )
+
+        # All channel counts should be a multiple of num_groups
+        if any((out_channel % norm_num_groups) != 0 for out_channel in channels):
+            raise ValueError("SPADEDiffusionModelUNet expects all num_channels being multiple of norm_num_groups")
+
+        if len(channels) != len(attention_levels):
+            raise ValueError("SPADEDiffusionModelUNet expects num_channels being same size of attention_levels")
+
+        if isinstance(num_head_channels, int):
+            num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))
+
+        if len(num_head_channels) != len(attention_levels):
+            raise ValueError(
+                "num_head_channels should have the same length as attention_levels. For the i levels without attention,"
+                " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored."
+            )
+
+        if isinstance(num_res_blocks, int):
+            num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels))
+
+        if len(num_res_blocks) != len(channels):
+            raise ValueError(
+                "`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
+                "`num_channels`."
+            )
+
+        self.in_channels = in_channels
+        self.block_out_channels = channels
+        self.out_channels = out_channels
+        self.num_res_blocks = num_res_blocks
+        self.attention_levels = attention_levels
+        self.num_head_channels = num_head_channels
+        self.with_conditioning = with_conditioning
+        self.label_nc = label_nc
+
+        # input
+        self.conv_in = Convolution(
+            spatial_dims=spatial_dims,
+            in_channels=in_channels,
+            out_channels=channels[0],
+            strides=1,
+            kernel_size=3,
+            padding=1,
+            conv_only=True,
+        )
+
+        # time
+        time_embed_dim = channels[0] * 4
+        self.time_embed = nn.Sequential(
+            nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim)
+        )
+
+        # class embedding
+        self.num_class_embeds = num_class_embeds
+        if num_class_embeds is not None:
+            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+
+        # down
+        self.down_blocks = nn.ModuleList([])
+        output_channel = channels[0]
+        for i in range(len(channels)):
+            input_channel = output_channel
+            output_channel = channels[i]
+            is_final_block = i == len(channels) - 1
+
+            down_block = get_down_block(
+                spatial_dims=spatial_dims,
+                in_channels=input_channel,
+                out_channels=output_channel,
+                temb_channels=time_embed_dim,
+                num_res_blocks=num_res_blocks[i],
+                norm_num_groups=norm_num_groups,
+                norm_eps=norm_eps,
+                add_downsample=not is_final_block,
+                resblock_updown=resblock_updown,
+                with_attn=(attention_levels[i] and not with_conditioning),
+                with_cross_attn=(attention_levels[i] and with_conditioning),
+                num_head_channels=num_head_channels[i],
+                transformer_num_layers=transformer_num_layers,
+                cross_attention_dim=cross_attention_dim,
+                upcast_attention=upcast_attention,
+            )
+
+            self.down_blocks.append(down_block)
+
+        # mid
+        self.middle_block = get_mid_block(
+            spatial_dims=spatial_dims,
+            in_channels=channels[-1],
+            temb_channels=time_embed_dim,
+            norm_num_groups=norm_num_groups,
+            norm_eps=norm_eps,
+            with_conditioning=with_conditioning,
+            num_head_channels=num_head_channels[-1],
+            transformer_num_layers=transformer_num_layers,
+            cross_attention_dim=cross_attention_dim,
+            upcast_attention=upcast_attention,
+        )
+
+        # up
+        self.up_blocks = nn.ModuleList([])
+        reversed_block_out_channels = list(reversed(channels))
+        reversed_num_res_blocks = list(reversed(num_res_blocks))
+        reversed_attention_levels = list(reversed(attention_levels))
+        reversed_num_head_channels = list(reversed(num_head_channels))
+        output_channel = reversed_block_out_channels[0]
+        for i in range(len(reversed_block_out_channels)):
+            prev_output_channel = output_channel
+            output_channel = reversed_block_out_channels[i]
+            input_channel = reversed_block_out_channels[min(i + 1, len(channels) - 1)]
+
+            is_final_block = i == len(channels) - 1
+
+            up_block = get_spade_up_block(
+                spatial_dims=spatial_dims,
+                in_channels=input_channel,
+                prev_output_channel=prev_output_channel,
+                out_channels=output_channel,
+                temb_channels=time_embed_dim,
+                num_res_blocks=reversed_num_res_blocks[i] + 1,
+                norm_num_groups=norm_num_groups,
+                norm_eps=norm_eps,
+                add_upsample=not is_final_block,
+                resblock_updown=resblock_updown,
+                with_attn=(reversed_attention_levels[i] and not with_conditioning),
+                with_cross_attn=(reversed_attention_levels[i] and with_conditioning),
+                num_head_channels=reversed_num_head_channels[i],
+                transformer_num_layers=transformer_num_layers,
+                cross_attention_dim=cross_attention_dim,
+                upcast_attention=upcast_attention,
+                label_nc=label_nc,
+                spade_intermediate_channels=spade_intermediate_channels,
+            )
+
+            self.up_blocks.append(up_block)
+
+        # out
+        self.out = nn.Sequential(
+            nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[0], eps=norm_eps, affine=True),
+            nn.SiLU(),
+            zero_module(
+                Convolution(
+                    spatial_dims=spatial_dims,
+                    in_channels=channels[0],
+                    out_channels=out_channels,
+                    strides=1,
+                    kernel_size=3,
+                    padding=1,
+                    conv_only=True,
+                )
+            ),
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        timesteps: torch.Tensor,
+        seg: torch.Tensor,
+        context: torch.Tensor | None = None,
+        class_labels: torch.Tensor | None = None,
+        down_block_additional_residuals: tuple[torch.Tensor] | None = None,
+        mid_block_additional_residual: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """
+        Args:
+            x: input tensor (N, C, SpatialDims).
+            timesteps: timestep tensor (N,).
+            seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm.
+            context: context tensor (N, 1, ContextDim).
+            class_labels: class labels tensor (N,).
+            down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims).
+            mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims).
+        """
+        # 1. time
+        t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])
+
+        # timesteps does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=x.dtype)
+        emb = self.time_embed(t_emb)
+
+        # 2. class
+        if self.num_class_embeds is not None:
+            if class_labels is None:
+                raise ValueError("class_labels should be provided when num_class_embeds > 0")
+            class_emb = self.class_embedding(class_labels)
+            class_emb = class_emb.to(dtype=x.dtype)
+            emb = emb + class_emb
+
+        # 3. initial convolution
+        h = self.conv_in(x)
+
+        # 4. down
+        if context is not None and self.with_conditioning is False:
+            raise ValueError("model should have with_conditioning = True if context is provided")
+        down_block_res_samples: list[torch.Tensor] = [h]
+        for downsample_block in self.down_blocks:
+            h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)
+            for residual in res_samples:
+                down_block_res_samples.append(residual)
+
+        # Additional residual connections for ControlNets
+        if down_block_additional_residuals is not None:
+            new_down_block_res_samples: list[torch.Tensor] = [h]
+            for down_block_res_sample, down_block_additional_residual in zip(
+                down_block_res_samples, down_block_additional_residuals
+            ):
+                down_block_res_sample = down_block_res_sample + down_block_additional_residual
+                new_down_block_res_samples.append(down_block_res_sample)
+
+            down_block_res_samples = new_down_block_res_samples
+
+        # 5. mid
+        h = self.middle_block(hidden_states=h, temb=emb, context=context)
+
+        # Additional residual connections for ControlNets
+        if mid_block_additional_residual is not None:
+            h = h + mid_block_additional_residual
+
+        # 6. up
+        for upsample_block in self.up_blocks:
+            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+            h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, seg=seg, temb=emb, context=context)
+
+        # 7. output block
+        output: torch.Tensor = self.out(h)
+
+        return output
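
For orientation, below is a minimal usage sketch of the newly added SPADEDiffusionModelUNet. The configuration values and tensor shapes are illustrative rather than taken from the diff, and the sketch assumes the class is re-exported from monai.networks.nets (consistent with the monai/networks/nets/__init__.py change listed above). The network takes a noisy input x, a per-batch timesteps tensor, and a semantic map seg with label_nc channels, and returns a tensor with out_channels channels:

    import torch

    from monai.networks.nets import SPADEDiffusionModelUNet

    # Illustrative 2D setup; every entry in `channels` must be a multiple of norm_num_groups (32 by default).
    net = SPADEDiffusionModelUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        label_nc=3,  # number of channels in the SPADE semantic map
        channels=(64, 64, 128),
        attention_levels=(False, False, True),
        num_res_blocks=1,
        num_head_channels=64,
    )

    x = torch.randn(1, 1, 64, 64)             # noisy image (N, C, H, W)
    timesteps = torch.randint(0, 1000, (1,))  # one timestep per batch element
    seg = torch.randn(1, 3, 64, 64)           # semantic map (N, label_nc, H, W), typically one-hot

    noise_pred = net(x=x, timesteps=timesteps, seg=seg)  # -> torch.Size([1, 1, 64, 64])

Note that SPADE normalisation is applied only on the decoder path (the get_spade_up_block branches above), so seg conditions the up blocks while the down and mid blocks follow the plain DiffusionModelUNet behaviour.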