monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2430__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/auto3dseg/hpo_gen.py +1 -1
  4. monai/apps/detection/utils/anchor_utils.py +2 -2
  5. monai/apps/pathology/transforms/post/array.py +7 -4
  6. monai/auto3dseg/analyzer.py +1 -1
  7. monai/bundle/scripts.py +204 -22
  8. monai/bundle/utils.py +1 -0
  9. monai/data/dataset_summary.py +1 -0
  10. monai/data/meta_tensor.py +2 -2
  11. monai/data/test_time_augmentation.py +2 -0
  12. monai/data/utils.py +9 -6
  13. monai/data/wsi_reader.py +2 -2
  14. monai/engines/__init__.py +3 -1
  15. monai/engines/trainer.py +281 -2
  16. monai/engines/utils.py +76 -1
  17. monai/handlers/mlflow_handler.py +21 -4
  18. monai/inferers/__init__.py +5 -0
  19. monai/inferers/inferer.py +1279 -1
  20. monai/metrics/cumulative_average.py +2 -0
  21. monai/metrics/panoptic_quality.py +1 -1
  22. monai/metrics/rocauc.py +2 -2
  23. monai/networks/blocks/__init__.py +3 -0
  24. monai/networks/blocks/attention_utils.py +128 -0
  25. monai/networks/blocks/crossattention.py +168 -0
  26. monai/networks/blocks/rel_pos_embedding.py +56 -0
  27. monai/networks/blocks/selfattention.py +74 -5
  28. monai/networks/blocks/spade_norm.py +95 -0
  29. monai/networks/blocks/spatialattention.py +82 -0
  30. monai/networks/blocks/transformerblock.py +25 -4
  31. monai/networks/blocks/upsample.py +22 -10
  32. monai/networks/layers/__init__.py +2 -1
  33. monai/networks/layers/factories.py +12 -1
  34. monai/networks/layers/simplelayers.py +1 -1
  35. monai/networks/layers/utils.py +14 -1
  36. monai/networks/layers/vector_quantizer.py +233 -0
  37. monai/networks/nets/__init__.py +9 -0
  38. monai/networks/nets/autoencoderkl.py +702 -0
  39. monai/networks/nets/controlnet.py +465 -0
  40. monai/networks/nets/diffusion_model_unet.py +1913 -0
  41. monai/networks/nets/patchgan_discriminator.py +230 -0
  42. monai/networks/nets/quicknat.py +8 -6
  43. monai/networks/nets/resnet.py +3 -4
  44. monai/networks/nets/spade_autoencoderkl.py +480 -0
  45. monai/networks/nets/spade_diffusion_model_unet.py +934 -0
  46. monai/networks/nets/spade_network.py +435 -0
  47. monai/networks/nets/swin_unetr.py +4 -3
  48. monai/networks/nets/transformer.py +157 -0
  49. monai/networks/nets/vqvae.py +472 -0
  50. monai/networks/schedulers/__init__.py +17 -0
  51. monai/networks/schedulers/ddim.py +294 -0
  52. monai/networks/schedulers/ddpm.py +250 -0
  53. monai/networks/schedulers/pndm.py +316 -0
  54. monai/networks/schedulers/scheduler.py +205 -0
  55. monai/networks/utils.py +22 -0
  56. monai/transforms/croppad/array.py +8 -8
  57. monai/transforms/croppad/dictionary.py +4 -4
  58. monai/transforms/croppad/functional.py +1 -1
  59. monai/transforms/regularization/array.py +4 -0
  60. monai/transforms/spatial/array.py +1 -1
  61. monai/transforms/utils_create_transform_ims.py +2 -4
  62. monai/utils/__init__.py +1 -0
  63. monai/utils/misc.py +5 -4
  64. monai/utils/ordering.py +207 -0
  65. monai/visualize/class_activation_maps.py +5 -5
  66. monai/visualize/img2tensorboard.py +3 -1
  67. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/METADATA +1 -1
  68. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/RECORD +71 -50
  69. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/WHEEL +1 -1
  70. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/LICENSE +0 -0
  71. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/top_level.txt +0 -0
monai/networks/nets/spade_autoencoderkl.py (new file)
@@ -0,0 +1,480 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ from collections.abc import Sequence
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from monai.networks.blocks import Convolution, SpatialAttentionBlock, Upsample
+ from monai.networks.blocks.spade_norm import SPADE
+ from monai.networks.nets.autoencoderkl import Encoder
+ from monai.utils import ensure_tuple_rep
+
+ __all__ = ["SPADEAutoencoderKL"]
+
+
+ class SPADEResBlock(nn.Module):
+     """
+     Residual block consisting of a cascade of two normalisation + activation + convolution blocks, with a
+     residual connection between input and output.
+     Enables SPADE normalisation for semantic conditioning (Park et al. (2019): https://github.com/NVlabs/SPADE).
+
+     Args:
+         spatial_dims: number of spatial dimensions (1D, 2D, 3D).
+         in_channels: input channels to the layer.
+         norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of
+             channels is divisible by this number.
+         norm_eps: epsilon for the normalisation.
+         out_channels: number of output channels.
+         label_nc: number of semantic channels for SPADE normalisation.
+         spade_intermediate_channels: number of intermediate channels for the SPADE block layer.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         norm_num_groups: int,
+         norm_eps: float,
+         out_channels: int,
+         label_nc: int,
+         spade_intermediate_channels: int,
+     ) -> None:
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = in_channels if out_channels is None else out_channels
+         self.norm1 = SPADE(
+             label_nc=label_nc,
+             norm_nc=in_channels,
+             norm="GROUP",
+             norm_params={"num_groups": norm_num_groups, "affine": False},
+             hidden_channels=spade_intermediate_channels,
+             kernel_size=3,
+             spatial_dims=spatial_dims,
+         )
+         self.conv1 = Convolution(
+             spatial_dims=spatial_dims,
+             in_channels=self.in_channels,
+             out_channels=self.out_channels,
+             strides=1,
+             kernel_size=3,
+             padding=1,
+             conv_only=True,
+         )
+         self.norm2 = SPADE(
+             label_nc=label_nc,
+             norm_nc=out_channels,
+             norm="GROUP",
+             norm_params={"num_groups": norm_num_groups, "affine": False},
+             hidden_channels=spade_intermediate_channels,
+             kernel_size=3,
+             spatial_dims=spatial_dims,
+         )
+         self.conv2 = Convolution(
+             spatial_dims=spatial_dims,
+             in_channels=self.out_channels,
+             out_channels=self.out_channels,
+             strides=1,
+             kernel_size=3,
+             padding=1,
+             conv_only=True,
+         )
+
+         self.nin_shortcut: nn.Module
+         if self.in_channels != self.out_channels:
+             self.nin_shortcut = Convolution(
+                 spatial_dims=spatial_dims,
+                 in_channels=self.in_channels,
+                 out_channels=self.out_channels,
+                 strides=1,
+                 kernel_size=1,
+                 padding=0,
+                 conv_only=True,
+             )
+         else:
+             self.nin_shortcut = nn.Identity()
+
+     def forward(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:
+         h = x
+         h = self.norm1(h, seg)
+         h = F.silu(h)
+         h = self.conv1(h)
+         h = self.norm2(h, seg)
+         h = F.silu(h)
+         h = self.conv2(h)
+
+         x = self.nin_shortcut(x)
+
+         return x + h
+
+
+ class SPADEDecoder(nn.Module):
+     """
+     Convolutional cascade upsampling from a spatial latent space into an image space.
+     Enables SPADE normalisation for semantic conditioning (Park et al. (2019): https://github.com/NVlabs/SPADE).
+
+     Args:
+         spatial_dims: number of spatial dimensions (1D, 2D, 3D).
+         channels: sequence of block output channels.
+         in_channels: number of channels in the bottom layer (latent space) of the autoencoder.
+         out_channels: number of output channels.
+         num_res_blocks: number of residual blocks (see ResBlock) per level.
+         norm_num_groups: number of groups for the GroupNorm layers; channels must be divisible by this number.
+         norm_eps: epsilon for the normalisation.
+         attention_levels: indicate which levels from channels contain an attention block.
+         label_nc: number of semantic channels for SPADE normalisation.
+         with_nonlocal_attn: if True, use the non-local attention block.
+         spade_intermediate_channels: number of intermediate channels for the SPADE block layer.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         channels: Sequence[int],
+         in_channels: int,
+         out_channels: int,
+         num_res_blocks: Sequence[int],
+         norm_num_groups: int,
+         norm_eps: float,
+         attention_levels: Sequence[bool],
+         label_nc: int,
+         with_nonlocal_attn: bool = True,
+         spade_intermediate_channels: int = 128,
+     ) -> None:
+         super().__init__()
+         self.spatial_dims = spatial_dims
+         self.channels = channels
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.num_res_blocks = num_res_blocks
+         self.norm_num_groups = norm_num_groups
+         self.norm_eps = norm_eps
+         self.attention_levels = attention_levels
+         self.label_nc = label_nc
+
+         reversed_block_out_channels = list(reversed(channels))
+
+         blocks: list[nn.Module] = []
+
+         # Initial convolution
+         blocks.append(
+             Convolution(
+                 spatial_dims=spatial_dims,
+                 in_channels=in_channels,
+                 out_channels=reversed_block_out_channels[0],
+                 strides=1,
+                 kernel_size=3,
+                 padding=1,
+                 conv_only=True,
+             )
+         )
+
+         # Non-local attention block
+         if with_nonlocal_attn is True:
+             blocks.append(
+                 SPADEResBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=reversed_block_out_channels[0],
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     out_channels=reversed_block_out_channels[0],
+                     label_nc=label_nc,
+                     spade_intermediate_channels=spade_intermediate_channels,
+                 )
+             )
+             blocks.append(
+                 SpatialAttentionBlock(
+                     spatial_dims=spatial_dims,
+                     num_channels=reversed_block_out_channels[0],
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                 )
+             )
+             blocks.append(
+                 SPADEResBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=reversed_block_out_channels[0],
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     out_channels=reversed_block_out_channels[0],
+                     label_nc=label_nc,
+                     spade_intermediate_channels=spade_intermediate_channels,
+                 )
+             )
+
+         reversed_attention_levels = list(reversed(attention_levels))
+         reversed_num_res_blocks = list(reversed(num_res_blocks))
+         block_out_ch = reversed_block_out_channels[0]
+         for i in range(len(reversed_block_out_channels)):
+             block_in_ch = block_out_ch
+             block_out_ch = reversed_block_out_channels[i]
+             is_final_block = i == len(channels) - 1
+
+             for _ in range(reversed_num_res_blocks[i]):
+                 blocks.append(
+                     SPADEResBlock(
+                         spatial_dims=spatial_dims,
+                         in_channels=block_in_ch,
+                         norm_num_groups=norm_num_groups,
+                         norm_eps=norm_eps,
+                         out_channels=block_out_ch,
+                         label_nc=label_nc,
+                         spade_intermediate_channels=spade_intermediate_channels,
+                     )
+                 )
+                 block_in_ch = block_out_ch
+
+             if reversed_attention_levels[i]:
+                 blocks.append(
+                     SpatialAttentionBlock(
+                         spatial_dims=spatial_dims,
+                         num_channels=block_in_ch,
+                         norm_num_groups=norm_num_groups,
+                         norm_eps=norm_eps,
+                     )
+                 )
+
+             if not is_final_block:
+                 post_conv = Convolution(
+                     spatial_dims=spatial_dims,
+                     in_channels=block_in_ch,
+                     out_channels=block_in_ch,
+                     strides=1,
+                     kernel_size=3,
+                     padding=1,
+                     conv_only=True,
+                 )
+                 blocks.append(
+                     Upsample(
+                         spatial_dims=spatial_dims,
+                         mode="nontrainable",
+                         in_channels=block_in_ch,
+                         out_channels=block_in_ch,
+                         interp_mode="nearest",
+                         scale_factor=2.0,
+                         post_conv=post_conv,
+                         align_corners=None,
+                     )
+                 )
+
+         blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True))
+         blocks.append(
+             Convolution(
+                 spatial_dims=spatial_dims,
+                 in_channels=block_in_ch,
+                 out_channels=out_channels,
+                 strides=1,
+                 kernel_size=3,
+                 padding=1,
+                 conv_only=True,
+             )
+         )
+
+         self.blocks = nn.ModuleList(blocks)
+
+     def forward(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:
+         for block in self.blocks:
+             if isinstance(block, SPADEResBlock):
+                 x = block(x, seg)
+             else:
+                 x = block(x)
+         return x
+
+
+ class SPADEAutoencoderKL(nn.Module):
+     """
+     Autoencoder model with a KL-regularized latent space, based on
+     Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" (https://arxiv.org/abs/2112.10752)
+     and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" (https://arxiv.org/abs/2209.07162).
+     Enables SPADE normalisation for semantic conditioning (Park et al. (2019): https://github.com/NVlabs/SPADE).
+
+     Args:
+         spatial_dims: number of spatial dimensions (1D, 2D, 3D).
+         label_nc: number of semantic channels for SPADE normalisation.
+         in_channels: number of input channels.
+         out_channels: number of output channels.
+         num_res_blocks: number of residual blocks (see ResBlock) per level.
+         channels: sequence of block output channels.
+         attention_levels: sequence of levels to add attention.
+         latent_channels: latent embedding dimension.
+         norm_num_groups: number of groups for the GroupNorm layers; channels must be divisible by this number.
+         norm_eps: epsilon for the normalisation.
+         with_encoder_nonlocal_attn: if True, use the non-local attention block in the encoder.
+         with_decoder_nonlocal_attn: if True, use the non-local attention block in the decoder.
+         spade_intermediate_channels: number of intermediate channels for the SPADE block layer.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         label_nc: int,
+         in_channels: int = 1,
+         out_channels: int = 1,
+         num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
+         channels: Sequence[int] = (32, 64, 64, 64),
+         attention_levels: Sequence[bool] = (False, False, True, True),
+         latent_channels: int = 3,
+         norm_num_groups: int = 32,
+         norm_eps: float = 1e-6,
+         with_encoder_nonlocal_attn: bool = True,
+         with_decoder_nonlocal_attn: bool = True,
+         spade_intermediate_channels: int = 128,
+     ) -> None:
+         super().__init__()
+
+         # All channel counts must be a multiple of num_groups
+         if any((out_channel % norm_num_groups) != 0 for out_channel in channels):
+             raise ValueError("SPADEAutoencoderKL expects all channels to be a multiple of norm_num_groups")
+
+         if len(channels) != len(attention_levels):
+             raise ValueError("SPADEAutoencoderKL expects channels to have the same length as attention_levels")
+
+         if isinstance(num_res_blocks, int):
+             num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels))
+
+         if len(num_res_blocks) != len(channels):
+             raise ValueError(
+                 "`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
+                 "`channels`."
+             )
+
+         self.encoder = Encoder(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             channels=channels,
+             out_channels=latent_channels,
+             num_res_blocks=num_res_blocks,
+             norm_num_groups=norm_num_groups,
+             norm_eps=norm_eps,
+             attention_levels=attention_levels,
+             with_nonlocal_attn=with_encoder_nonlocal_attn,
+         )
+         self.decoder = SPADEDecoder(
+             spatial_dims=spatial_dims,
+             channels=channels,
+             in_channels=latent_channels,
+             out_channels=out_channels,
+             num_res_blocks=num_res_blocks,
+             norm_num_groups=norm_num_groups,
+             norm_eps=norm_eps,
+             attention_levels=attention_levels,
+             label_nc=label_nc,
+             with_nonlocal_attn=with_decoder_nonlocal_attn,
+             spade_intermediate_channels=spade_intermediate_channels,
+         )
+         self.quant_conv_mu = Convolution(
+             spatial_dims=spatial_dims,
+             in_channels=latent_channels,
+             out_channels=latent_channels,
+             strides=1,
+             kernel_size=1,
+             padding=0,
+             conv_only=True,
+         )
+         self.quant_conv_log_sigma = Convolution(
+             spatial_dims=spatial_dims,
+             in_channels=latent_channels,
+             out_channels=latent_channels,
+             strides=1,
+             kernel_size=1,
+             padding=0,
+             conv_only=True,
+         )
+         self.post_quant_conv = Convolution(
+             spatial_dims=spatial_dims,
+             in_channels=latent_channels,
+             out_channels=latent_channels,
+             strides=1,
+             kernel_size=1,
+             padding=0,
+             conv_only=True,
+         )
+         self.latent_channels = latent_channels
+
+     def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         """
+         Forwards an image through the spatial encoder, obtaining the latent mean and sigma representations.
+
+         Args:
+             x: BxCx[SPATIAL DIMS] tensor
+
+         """
+         h = self.encoder(x)
+         z_mu = self.quant_conv_mu(h)
+         z_log_var = self.quant_conv_log_sigma(h)
+         z_log_var = torch.clamp(z_log_var, -30.0, 20.0)
+         z_sigma = torch.exp(z_log_var / 2)
+
+         return z_mu, z_sigma
+
+     def sampling(self, z_mu: torch.Tensor, z_sigma: torch.Tensor) -> torch.Tensor:
+         """
+         From the mean and standard deviation representations obtained by encoding an image into the latent
+         space, draws a sample via the reparameterisation trick: Gaussian noise is scaled by the standard
+         deviation (sigma) and shifted by the mean.
+
+         Args:
+             z_mu: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] mean vector obtained by the encoder when you encode an image
+             z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] standard deviation vector obtained by the encoder when you
+                 encode an image
+
+         Returns:
+             sample of shape Bx[Z_CHANNELS]x[LATENT SPACE SIZE]
+         """
+         eps = torch.randn_like(z_sigma)
+         z_vae = z_mu + eps * z_sigma
+         return z_vae
+
+     def reconstruct(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:
+         """
+         Encodes and decodes an input image.
+
+         Args:
+             x: BxCx[SPATIAL DIMENSIONS] tensor.
+             seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm.
+
+         Returns:
+             reconstructed image, of the same shape as the input
+         """
+         z_mu, _ = self.encode(x)
+         reconstruction = self.decode(z_mu, seg)
+         return reconstruction
+
+     def decode(self, z: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:
+         """
+         Forwards a latent space sample through the decoder.
+
+         Args:
+             z: Bx[Z_CHANNELS]x[LATENT SPACE SHAPE] latent sample.
+             seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm.
+
+         Returns:
+             decoded image tensor
+         """
+         z = self.post_quant_conv(z)
+         dec: torch.Tensor = self.decoder(z, seg)
+         return dec
+
+     def forward(self, x: torch.Tensor, seg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         z_mu, z_sigma = self.encode(x)
+         z = self.sampling(z_mu, z_sigma)
+         reconstruction = self.decode(z, seg)
+         return reconstruction, z_mu, z_sigma
+
+     def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor:
+         z_mu, z_sigma = self.encode(x)
+         z = self.sampling(z_mu, z_sigma)
+         return z
+
+     def decode_stage_2_outputs(self, z: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:
+         image = self.decode(z, seg)
+         return image
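For orientation, a minimal usage sketch of the new SPADEAutoencoderKL follows. It is not part of the diff above: the configuration, tensor shapes, and the KL penalty at the end are illustrative assumptions (the class is imported from monai.networks.nets, which the updated __init__.py re-exports).

# Minimal usage sketch (illustrative assumptions; not part of the packaged diff above).
import torch

from monai.networks.nets import SPADEAutoencoderKL

# Toy 2D configuration with 3 semantic label channels; every channel count is a
# multiple of norm_num_groups, as the constructor checks.
model = SPADEAutoencoderKL(
    spatial_dims=2,
    label_nc=3,
    in_channels=1,
    out_channels=1,
    num_res_blocks=1,
    channels=(32, 64, 64),
    attention_levels=(False, False, True),
    latent_channels=3,
    norm_num_groups=32,
)

x = torch.randn(2, 1, 64, 64)    # BxCxHxW image batch
seg = torch.randn(2, 3, 64, 64)  # Bx[label_nc]xHxW semantic map (e.g. one-hot labels)

# Full pass: encode -> reparameterised sample -> SPADE-conditioned decode.
reconstruction, z_mu, z_sigma = model(x, seg)  # reconstruction: (2, 1, 64, 64)

# Stage-2 helpers, as used when pairing the autoencoder with a latent diffusion model.
z = model.encode_stage_2_inputs(x)            # latent sample; (2, 3, 16, 16) with this config
x_hat = model.decode_stage_2_outputs(z, seg)  # SPADE decoding also needs the semantic map

# For context only: the KL penalty typically applied to (z_mu, z_sigma) when training
# such a KL-regularized autoencoder (the loss itself is not defined in this file).
kl = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3])
kl_loss = kl.mean()

Note that only the decoder is SPADE-conditioned: encode and encode_stage_2_inputs take the image alone, while decode, reconstruct, and decode_stage_2_outputs all require the semantic map seg alongside the latent.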