monai-weekly 1.4.dev2428-py3-none-any.whl → 1.4.dev2430-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/auto3dseg/hpo_gen.py +1 -1
  4. monai/apps/detection/utils/anchor_utils.py +2 -2
  5. monai/apps/pathology/transforms/post/array.py +7 -4
  6. monai/auto3dseg/analyzer.py +1 -1
  7. monai/bundle/scripts.py +204 -22
  8. monai/bundle/utils.py +1 -0
  9. monai/data/dataset_summary.py +1 -0
  10. monai/data/meta_tensor.py +2 -2
  11. monai/data/test_time_augmentation.py +2 -0
  12. monai/data/utils.py +9 -6
  13. monai/data/wsi_reader.py +2 -2
  14. monai/engines/__init__.py +3 -1
  15. monai/engines/trainer.py +281 -2
  16. monai/engines/utils.py +76 -1
  17. monai/handlers/mlflow_handler.py +21 -4
  18. monai/inferers/__init__.py +5 -0
  19. monai/inferers/inferer.py +1279 -1
  20. monai/metrics/cumulative_average.py +2 -0
  21. monai/metrics/panoptic_quality.py +1 -1
  22. monai/metrics/rocauc.py +2 -2
  23. monai/networks/blocks/__init__.py +3 -0
  24. monai/networks/blocks/attention_utils.py +128 -0
  25. monai/networks/blocks/crossattention.py +168 -0
  26. monai/networks/blocks/rel_pos_embedding.py +56 -0
  27. monai/networks/blocks/selfattention.py +74 -5
  28. monai/networks/blocks/spade_norm.py +95 -0
  29. monai/networks/blocks/spatialattention.py +82 -0
  30. monai/networks/blocks/transformerblock.py +25 -4
  31. monai/networks/blocks/upsample.py +22 -10
  32. monai/networks/layers/__init__.py +2 -1
  33. monai/networks/layers/factories.py +12 -1
  34. monai/networks/layers/simplelayers.py +1 -1
  35. monai/networks/layers/utils.py +14 -1
  36. monai/networks/layers/vector_quantizer.py +233 -0
  37. monai/networks/nets/__init__.py +9 -0
  38. monai/networks/nets/autoencoderkl.py +702 -0
  39. monai/networks/nets/controlnet.py +465 -0
  40. monai/networks/nets/diffusion_model_unet.py +1913 -0
  41. monai/networks/nets/patchgan_discriminator.py +230 -0
  42. monai/networks/nets/quicknat.py +8 -6
  43. monai/networks/nets/resnet.py +3 -4
  44. monai/networks/nets/spade_autoencoderkl.py +480 -0
  45. monai/networks/nets/spade_diffusion_model_unet.py +934 -0
  46. monai/networks/nets/spade_network.py +435 -0
  47. monai/networks/nets/swin_unetr.py +4 -3
  48. monai/networks/nets/transformer.py +157 -0
  49. monai/networks/nets/vqvae.py +472 -0
  50. monai/networks/schedulers/__init__.py +17 -0
  51. monai/networks/schedulers/ddim.py +294 -0
  52. monai/networks/schedulers/ddpm.py +250 -0
  53. monai/networks/schedulers/pndm.py +316 -0
  54. monai/networks/schedulers/scheduler.py +205 -0
  55. monai/networks/utils.py +22 -0
  56. monai/transforms/croppad/array.py +8 -8
  57. monai/transforms/croppad/dictionary.py +4 -4
  58. monai/transforms/croppad/functional.py +1 -1
  59. monai/transforms/regularization/array.py +4 -0
  60. monai/transforms/spatial/array.py +1 -1
  61. monai/transforms/utils_create_transform_ims.py +2 -4
  62. monai/utils/__init__.py +1 -0
  63. monai/utils/misc.py +5 -4
  64. monai/utils/ordering.py +207 -0
  65. monai/visualize/class_activation_maps.py +5 -5
  66. monai/visualize/img2tensorboard.py +3 -1
  67. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/METADATA +1 -1
  68. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/RECORD +71 -50
  69. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/WHEEL +1 -1
  70. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/LICENSE +0 -0
  71. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/top_level.txt +0 -0
monai/networks/nets/controlnet.py (new file)
@@ -0,0 +1,465 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# =========================================================================
+# Adapted from https://github.com/huggingface/diffusers
+# which has the following license:
+# https://github.com/huggingface/diffusers/blob/main/LICENSE
+#
+# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+import torch
+from torch import nn
+
+from monai.networks.blocks import Convolution
+from monai.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_timestep_embedding
+from monai.utils import ensure_tuple_rep
+
+
+class ControlNetConditioningEmbedding(nn.Module):
+    """
+    Network to encode the conditioning into a latent space.
+    """
+
+    def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, channels: Sequence[int]):
+        super().__init__()
+
+        self.conv_in = Convolution(
+            spatial_dims=spatial_dims,
+            in_channels=in_channels,
+            out_channels=channels[0],
+            strides=1,
+            kernel_size=3,
+            padding=1,
+            adn_ordering="A",
+            act="SWISH",
+        )
+
+        self.blocks = nn.ModuleList([])
+
+        for i in range(len(channels) - 1):
+            channel_in = channels[i]
+            channel_out = channels[i + 1]
+            self.blocks.append(
+                Convolution(
+                    spatial_dims=spatial_dims,
+                    in_channels=channel_in,
+                    out_channels=channel_in,
+                    strides=1,
+                    kernel_size=3,
+                    padding=1,
+                    adn_ordering="A",
+                    act="SWISH",
+                )
+            )
+
+            self.blocks.append(
+                Convolution(
+                    spatial_dims=spatial_dims,
+                    in_channels=channel_in,
+                    out_channels=channel_out,
+                    strides=2,
+                    kernel_size=3,
+                    padding=1,
+                    adn_ordering="A",
+                    act="SWISH",
+                )
+            )
+
+        self.conv_out = zero_module(
+            Convolution(
+                spatial_dims=spatial_dims,
+                in_channels=channels[-1],
+                out_channels=out_channels,
+                strides=1,
+                kernel_size=3,
+                padding=1,
+                conv_only=True,
+            )
+        )
+
+    def forward(self, conditioning):
+        embedding = self.conv_in(conditioning)
+
+        for block in self.blocks:
+            embedding = block(embedding)
+
+        embedding = self.conv_out(embedding)
+
+        return embedding
+
+
+def zero_module(module):
+    for p in module.parameters():
+        nn.init.zeros_(p)
+    return module
+
+
+class ControlNet(nn.Module):
+    """
+    Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image
+    Diffusion Models" (https://arxiv.org/abs/2302.05543)
+
+    Args:
+        spatial_dims: number of spatial dimensions.
+        in_channels: number of input channels.
+        num_res_blocks: number of residual blocks (see ResnetBlock) per level.
+        channels: tuple of block output channels.
+        attention_levels: list of levels to add attention.
+        norm_num_groups: number of groups for the normalization.
+        norm_eps: epsilon for the normalization.
+        resblock_updown: if True use residual blocks for up/downsampling.
+        num_head_channels: number of channels in each attention head.
+        with_conditioning: if True add spatial transformers to perform conditioning.
+        transformer_num_layers: number of layers of Transformer blocks to use.
+        cross_attention_dim: number of context dimensions to use.
+        num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
+            classes.
+        upcast_attention: if True, upcast attention operations to full precision.
+        conditioning_embedding_in_channels: number of input channels for the conditioning embedding.
+        conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding.
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
+        channels: Sequence[int] = (32, 64, 64, 64),
+        attention_levels: Sequence[bool] = (False, False, True, True),
+        norm_num_groups: int = 32,
+        norm_eps: float = 1e-6,
+        resblock_updown: bool = False,
+        num_head_channels: int | Sequence[int] = 8,
+        with_conditioning: bool = False,
+        transformer_num_layers: int = 1,
+        cross_attention_dim: int | None = None,
+        num_class_embeds: int | None = None,
+        upcast_attention: bool = False,
+        conditioning_embedding_in_channels: int = 1,
+        conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256),
+    ) -> None:
+        super().__init__()
+        if with_conditioning is True and cross_attention_dim is None:
+            raise ValueError(
+                "ControlNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
+                "to be specified when with_conditioning=True."
+            )
+        if cross_attention_dim is not None and with_conditioning is False:
+            raise ValueError(
+                "ControlNet expects with_conditioning=True when specifying the cross_attention_dim."
+            )
+
+        # all channel counts should be multiples of norm_num_groups
+        if any((out_channel % norm_num_groups) != 0 for out_channel in channels):
+            raise ValueError(
+                f"ControlNet expects all channels to be a multiple of norm_num_groups, but got"
+                f" channels={channels} and norm_num_groups={norm_num_groups}"
+            )
+
+        if len(channels) != len(attention_levels):
+            raise ValueError(
+                f"ControlNet expects channels to have the same length as attention_levels, but got "
+                f"channels={channels} and attention_levels={attention_levels}"
+            )
+
+        if isinstance(num_head_channels, int):
+            num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))
+
+        if len(num_head_channels) != len(attention_levels):
+            raise ValueError(
+                "num_head_channels should have the same length as attention_levels, but got "
+                f"num_head_channels={num_head_channels} and attention_levels={attention_levels}. For the i levels "
+                "without attention, i.e. `attention_levels[i]=False`, the num_head_channels[i] will be ignored."
+            )
+
+        if isinstance(num_res_blocks, int):
+            num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels))
+
+        if len(num_res_blocks) != len(channels):
+            raise ValueError(
+                f"`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
+                f"`channels`, but got num_res_blocks={num_res_blocks} and channels={channels}."
+            )
+
+        self.in_channels = in_channels
+        self.block_out_channels = channels
+        self.num_res_blocks = num_res_blocks
+        self.attention_levels = attention_levels
+        self.num_head_channels = num_head_channels
+        self.with_conditioning = with_conditioning
+
+        # input
+        self.conv_in = Convolution(
+            spatial_dims=spatial_dims,
+            in_channels=in_channels,
+            out_channels=channels[0],
+            strides=1,
+            kernel_size=3,
+            padding=1,
+            conv_only=True,
+        )
+
+        # time
+        time_embed_dim = channels[0] * 4
+        self.time_embed = nn.Sequential(
+            nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim)
+        )
+
+        # class embedding
+        self.num_class_embeds = num_class_embeds
+        if num_class_embeds is not None:
+            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+
+        # control net conditioning embedding
+        self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
+            spatial_dims=spatial_dims,
+            in_channels=conditioning_embedding_in_channels,
+            channels=conditioning_embedding_num_channels,
+            out_channels=channels[0],
+        )
+
+        # down
+        self.down_blocks = nn.ModuleList([])
+        self.controlnet_down_blocks = nn.ModuleList([])
+        output_channel = channels[0]
+
+        controlnet_block = Convolution(
+            spatial_dims=spatial_dims,
+            in_channels=output_channel,
+            out_channels=output_channel,
+            strides=1,
+            kernel_size=1,
+            padding=0,
+            conv_only=True,
+        )
+        controlnet_block = zero_module(controlnet_block.conv)
+        self.controlnet_down_blocks.append(controlnet_block)
+
+        for i in range(len(channels)):
+            input_channel = output_channel
+            output_channel = channels[i]
+            is_final_block = i == len(channels) - 1
+
+            down_block = get_down_block(
+                spatial_dims=spatial_dims,
+                in_channels=input_channel,
+                out_channels=output_channel,
+                temb_channels=time_embed_dim,
+                num_res_blocks=num_res_blocks[i],
+                norm_num_groups=norm_num_groups,
+                norm_eps=norm_eps,
+                add_downsample=not is_final_block,
+                resblock_updown=resblock_updown,
+                with_attn=(attention_levels[i] and not with_conditioning),
+                with_cross_attn=(attention_levels[i] and with_conditioning),
+                num_head_channels=num_head_channels[i],
+                transformer_num_layers=transformer_num_layers,
+                cross_attention_dim=cross_attention_dim,
+                upcast_attention=upcast_attention,
+            )
+
+            self.down_blocks.append(down_block)
+
+            for _ in range(num_res_blocks[i]):
+                controlnet_block = Convolution(
+                    spatial_dims=spatial_dims,
+                    in_channels=output_channel,
+                    out_channels=output_channel,
+                    strides=1,
+                    kernel_size=1,
+                    padding=0,
+                    conv_only=True,
+                )
+                controlnet_block = zero_module(controlnet_block)
+                self.controlnet_down_blocks.append(controlnet_block)
+            #
+            if not is_final_block:
+                controlnet_block = Convolution(
+                    spatial_dims=spatial_dims,
+                    in_channels=output_channel,
+                    out_channels=output_channel,
+                    strides=1,
+                    kernel_size=1,
+                    padding=0,
+                    conv_only=True,
+                )
+                controlnet_block = zero_module(controlnet_block)
+                self.controlnet_down_blocks.append(controlnet_block)
+
+        # mid
+        mid_block_channel = channels[-1]
+
+        self.middle_block = get_mid_block(
+            spatial_dims=spatial_dims,
+            in_channels=mid_block_channel,
+            temb_channels=time_embed_dim,
+            norm_num_groups=norm_num_groups,
+            norm_eps=norm_eps,
+            with_conditioning=with_conditioning,
+            num_head_channels=num_head_channels[-1],
+            transformer_num_layers=transformer_num_layers,
+            cross_attention_dim=cross_attention_dim,
+            upcast_attention=upcast_attention,
+        )
+
+        controlnet_block = Convolution(
+            spatial_dims=spatial_dims,
+            in_channels=output_channel,
+            out_channels=output_channel,
+            strides=1,
+            kernel_size=1,
+            padding=0,
+            conv_only=True,
+        )
+        controlnet_block = zero_module(controlnet_block)
+        self.controlnet_mid_block = controlnet_block
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        timesteps: torch.Tensor,
+        controlnet_cond: torch.Tensor,
+        conditioning_scale: float = 1.0,
+        context: torch.Tensor | None = None,
+        class_labels: torch.Tensor | None = None,
+    ) -> tuple[list[torch.Tensor], torch.Tensor]:
+        """
+        Args:
+            x: input tensor (N, C, H, W, [D]).
+            timesteps: timestep tensor (N,).
+            controlnet_cond: controlnet conditioning tensor (N, C, H, W, [D]).
+            conditioning_scale: conditioning scale.
+            context: context tensor (N, 1, cross_attention_dim), where cross_attention_dim is specified in the model init.
+            class_labels: class labels tensor (N,).
+        """
+        # 1. time
+        t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])
+
+        # timesteps does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=x.dtype)
+        emb = self.time_embed(t_emb)
+
+        # 2. class
+        if self.num_class_embeds is not None:
+            if class_labels is None:
+                raise ValueError("class_labels should be provided when num_class_embeds > 0")
+            class_emb = self.class_embedding(class_labels)
+            class_emb = class_emb.to(dtype=x.dtype)
+            emb = emb + class_emb
+
+        # 3. initial convolution
+        h = self.conv_in(x)
+
+        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
+
+        h += controlnet_cond
+
+        # 4. down
+        if context is not None and self.with_conditioning is False:
+            raise ValueError("model should have with_conditioning = True if context is provided")
+        down_block_res_samples: list[torch.Tensor] = [h]
+        for downsample_block in self.down_blocks:
+            h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)
+            for residual in res_samples:
+                down_block_res_samples.append(residual)
+
+        # 5. mid
+        h = self.middle_block(hidden_states=h, temb=emb, context=context)
+
+        # 6. Control net blocks
+        controlnet_down_block_res_samples = []
+
+        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
+            down_block_res_sample = controlnet_block(down_block_res_sample)
+            controlnet_down_block_res_samples.append(down_block_res_sample)
+
+        down_block_res_samples = controlnet_down_block_res_samples
+
+        mid_block_res_sample: torch.Tensor = self.controlnet_mid_block(h)
+
+        # 7. scaling
+        down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples]
+        mid_block_res_sample *= conditioning_scale
+
+        return down_block_res_samples, mid_block_res_sample
+
+    def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
+        """
+        Load a state dict from a ControlNet trained with
+        [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).
+
+        Args:
+            old_state_dict: state dict from the old ControlNet model.
+        """
+
+        new_state_dict = self.state_dict()
+        # if all keys match, just load the state dict
+        if all(k in new_state_dict for k in old_state_dict):
+            print("All keys match, loading state dict.")
+            self.load_state_dict(old_state_dict)
+            return
+
+        if verbose:
+            # print all new_state_dict keys that are not in old_state_dict
+            for k in new_state_dict:
+                if k not in old_state_dict:
+                    print(f"key {k} not found in old state dict")
+            # and vice versa
+            print("----------------------------------------------")
+            for k in old_state_dict:
+                if k not in new_state_dict:
+                    print(f"key {k} not found in new state dict")
+
+        # copy over all matching keys
+        for k in new_state_dict:
+            if k in old_state_dict:
+                new_state_dict[k] = old_state_dict[k]
+
+        # fix the attention blocks
+        attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k]
+        for block in attention_blocks:
+            new_state_dict[f"{block}.attn1.qkv.weight"] = torch.cat(
+                [
+                    old_state_dict[f"{block}.attn1.to_q.weight"],
+                    old_state_dict[f"{block}.attn1.to_k.weight"],
+                    old_state_dict[f"{block}.attn1.to_v.weight"],
+                ],
+                dim=0,
+            )
+
+            # projection
+            new_state_dict[f"{block}.attn1.out_proj.weight"] = old_state_dict[f"{block}.attn1.to_out.0.weight"]
+            new_state_dict[f"{block}.attn1.out_proj.bias"] = old_state_dict[f"{block}.attn1.to_out.0.bias"]
+
+            new_state_dict[f"{block}.attn2.out_proj.weight"] = old_state_dict[f"{block}.attn2.to_out.0.weight"]
+            new_state_dict[f"{block}.attn2.out_proj.bias"] = old_state_dict[f"{block}.attn2.to_out.0.bias"]
+
+        self.load_state_dict(new_state_dict)
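
A quick smoke test of the new ControlNet class (a sketch for orientation, not part of the release): the constructor and forward signatures follow the diff above, while the tensor sizes are illustrative assumptions. Note that the conditioning image must be 2**(len(conditioning_embedding_num_channels) - 1) times larger than x along each spatial axis, because every channel transition in ControlNetConditioningEmbedding includes a stride-2 convolution (a factor of 8 with the default four channel values).

import torch

from monai.networks.nets.controlnet import ControlNet

# instantiate with the defaults shown in the diff (2D, single-channel input)
net = ControlNet(spatial_dims=2, in_channels=1).eval()

x = torch.randn(1, 1, 32, 32)  # noisy model input (N, C, H, W)
timesteps = torch.randint(0, 1000, (1,))  # one timestep per batch element
cond = torch.randn(1, 1, 256, 256)  # conditioning image, 8x the spatial size of x

with torch.no_grad():
    down_samples, mid_sample = net(x, timesteps, controlnet_cond=cond, conditioning_scale=1.0)

print(len(down_samples), mid_sample.shape)  # 12 down residuals plus one mid residual

Because zero_module zero-initializes every controlnet projection (the conditioning embedding's conv_out, the per-level 1x1 blocks, and the mid block), a freshly constructed ControlNet emits all-zero residuals and therefore leaves a frozen diffusion UNet's behavior unchanged until training opens the control path. Migrating weights from the separate MONAI Generative code base goes through the new load_old_state_dict helper; the checkpoint path below is a hypothetical placeholder:

old_sd = torch.load("controlnet_monai_generative.pt")  # hypothetical checkpoint of the old ControlNet
net.load_old_state_dict(old_sd, verbose=True)  # fuses to_q/to_k/to_v into the new qkv projection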