monai-weekly 1.4.dev2430__py3-none-any.whl → 1.4.dev2434__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +43 -25
  4. monai/apps/generation/maisi/networks/controlnet_maisi.py +15 -18
  5. monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +18 -18
  6. monai/bundle/config_parser.py +2 -2
  7. monai/bundle/reference_resolver.py +18 -1
  8. monai/bundle/scripts.py +45 -22
  9. monai/bundle/utils.py +3 -1
  10. monai/data/utils.py +1 -1
  11. monai/data/wsi_datasets.py +3 -3
  12. monai/losses/__init__.py +1 -0
  13. monai/losses/dice.py +10 -1
  14. monai/losses/nacl_loss.py +139 -0
  15. monai/networks/blocks/crossattention.py +48 -26
  16. monai/networks/blocks/mlp.py +16 -4
  17. monai/networks/blocks/selfattention.py +75 -23
  18. monai/networks/blocks/spatialattention.py +16 -1
  19. monai/networks/blocks/transformerblock.py +17 -2
  20. monai/networks/nets/__init__.py +2 -1
  21. monai/networks/nets/autoencoderkl.py +55 -22
  22. monai/networks/nets/cell_sam_wrapper.py +92 -0
  23. monai/networks/nets/controlnet.py +24 -22
  24. monai/networks/nets/diffusion_model_unet.py +159 -19
  25. monai/networks/nets/segresnet_ds.py +127 -1
  26. monai/networks/nets/spade_autoencoderkl.py +24 -2
  27. monai/networks/nets/spade_diffusion_model_unet.py +39 -2
  28. monai/networks/nets/transformer.py +17 -17
  29. monai/networks/nets/vista3d.py +908 -0
  30. monai/networks/utils.py +3 -3
  31. monai/transforms/__init__.py +1 -0
  32. monai/transforms/io/array.py +1 -1
  33. monai/transforms/post/array.py +2 -1
  34. monai/transforms/spatial/functional.py +1 -1
  35. monai/transforms/transform.py +2 -2
  36. monai/transforms/utils.py +183 -0
  37. monai/{apps/generation/maisi/utils/morphological_ops.py → transforms/utils_morphological_ops.py} +2 -0
  38. monai/transforms/utils_pytorch_numpy_unification.py +2 -2
  39. monai/utils/module.py +7 -6
  40. {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/METADATA +83 -81
  41. {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/RECORD +44 -41
  42. {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/WHEEL +1 -1
  43. {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/LICENSE +0 -0
  44. {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/top_level.txt +0 -0
monai/networks/nets/autoencoderkl.py

@@ -157,6 +157,10 @@ class Encoder(nn.Module):
         norm_eps: epsilon for the normalization.
         attention_levels: indicate which level from num_channels contain an attention block.
         with_nonlocal_attn: if True use non-local attention block.
+        include_fc: whether to include the final linear layer. Default to True.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
     """

     def __init__(
@@ -170,6 +174,9 @@ class Encoder(nn.Module):
         norm_eps: float,
         attention_levels: Sequence[bool],
         with_nonlocal_attn: bool = True,
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         self.spatial_dims = spatial_dims
@@ -220,6 +227,9 @@ class Encoder(nn.Module):
                             num_channels=input_channel,
                             norm_num_groups=norm_num_groups,
                             norm_eps=norm_eps,
+                            include_fc=include_fc,
+                            use_combined_linear=use_combined_linear,
+                            use_flash_attention=use_flash_attention,
                         )
                     )

@@ -243,6 +253,9 @@ class Encoder(nn.Module):
                     num_channels=channels[-1],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
+                    include_fc=include_fc,
+                    use_combined_linear=use_combined_linear,
+                    use_flash_attention=use_flash_attention,
                 )
             )
             blocks.append(
@@ -291,6 +304,10 @@ class Decoder(nn.Module):
         attention_levels: indicate which level from num_channels contain an attention block.
         with_nonlocal_attn: if True use non-local attention block.
         use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder.
+        include_fc: whether to include the final linear layer. Default to True.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
     """

     def __init__(
@@ -305,6 +322,9 @@ class Decoder(nn.Module):
         attention_levels: Sequence[bool],
         with_nonlocal_attn: bool = True,
         use_convtranspose: bool = False,
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         self.spatial_dims = spatial_dims
@@ -350,6 +370,9 @@ class Decoder(nn.Module):
                     num_channels=reversed_block_out_channels[0],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
+                    include_fc=include_fc,
+                    use_combined_linear=use_combined_linear,
+                    use_flash_attention=use_flash_attention,
                 )
             )
             blocks.append(
@@ -389,6 +412,9 @@ class Decoder(nn.Module):
                             num_channels=block_in_ch,
                             norm_num_groups=norm_num_groups,
                             norm_eps=norm_eps,
+                            include_fc=include_fc,
+                            use_combined_linear=use_combined_linear,
+                            use_flash_attention=use_flash_attention,
                         )
                     )

@@ -463,6 +489,10 @@ class AutoencoderKL(nn.Module):
         with_decoder_nonlocal_attn: if True use non-local attention block in the decoder.
         use_checkpoint: if True, use activation checkpoint to save memory.
         use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder.
+        include_fc: whether to include the final linear layer in the attention block. Default to True.
+        use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
+        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
     """

     def __init__(
@@ -480,6 +510,9 @@ class AutoencoderKL(nn.Module):
         with_decoder_nonlocal_attn: bool = True,
         use_checkpoint: bool = False,
         use_convtranspose: bool = False,
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()

@@ -499,7 +532,7 @@ class AutoencoderKL(nn.Module):
                 "`num_channels`."
             )

-        self.encoder = Encoder(
+        self.encoder: nn.Module = Encoder(
             spatial_dims=spatial_dims,
             in_channels=in_channels,
             channels=channels,
@@ -509,8 +542,11 @@ class AutoencoderKL(nn.Module):
             norm_eps=norm_eps,
             attention_levels=attention_levels,
             with_nonlocal_attn=with_encoder_nonlocal_attn,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
+            use_flash_attention=use_flash_attention,
         )
-        self.decoder = Decoder(
+        self.decoder: nn.Module = Decoder(
             spatial_dims=spatial_dims,
             channels=channels,
             in_channels=latent_channels,
@@ -521,6 +557,9 @@ class AutoencoderKL(nn.Module):
             attention_levels=attention_levels,
             with_nonlocal_attn=with_decoder_nonlocal_attn,
             use_convtranspose=use_convtranspose,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
+            use_flash_attention=use_flash_attention,
         )
         self.quant_conv_mu = Convolution(
             spatial_dims=spatial_dims,
@@ -665,27 +704,18 @@ class AutoencoderKL(nn.Module):
         # copy over all matching keys
         for k in new_state_dict:
             if k in old_state_dict:
-                new_state_dict[k] = old_state_dict[k]
+                new_state_dict[k] = old_state_dict.pop(k)

         # fix the attention blocks
-        attention_blocks = [k.replace(".attn.qkv.weight", "") for k in new_state_dict if "attn.qkv.weight" in k]
+        attention_blocks = [k.replace(".attn.to_q.weight", "") for k in new_state_dict if "attn.to_q.weight" in k]
         for block in attention_blocks:
-            new_state_dict[f"{block}.attn.qkv.weight"] = torch.cat(
-                [
-                    old_state_dict[f"{block}.to_q.weight"],
-                    old_state_dict[f"{block}.to_k.weight"],
-                    old_state_dict[f"{block}.to_v.weight"],
-                ],
-                dim=0,
-            )
-            new_state_dict[f"{block}.attn.qkv.bias"] = torch.cat(
-                [
-                    old_state_dict[f"{block}.to_q.bias"],
-                    old_state_dict[f"{block}.to_k.bias"],
-                    old_state_dict[f"{block}.to_v.bias"],
-                ],
-                dim=0,
-            )
+            new_state_dict[f"{block}.attn.to_q.weight"] = old_state_dict.pop(f"{block}.to_q.weight")
+            new_state_dict[f"{block}.attn.to_k.weight"] = old_state_dict.pop(f"{block}.to_k.weight")
+            new_state_dict[f"{block}.attn.to_v.weight"] = old_state_dict.pop(f"{block}.to_v.weight")
+            new_state_dict[f"{block}.attn.to_q.bias"] = old_state_dict.pop(f"{block}.to_q.bias")
+            new_state_dict[f"{block}.attn.to_k.bias"] = old_state_dict.pop(f"{block}.to_k.bias")
+            new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias")
+
             # old version did not have a projection so set these to the identity
             new_state_dict[f"{block}.attn.out_proj.weight"] = torch.eye(
                 new_state_dict[f"{block}.attn.out_proj.weight"].shape[0]
@@ -698,5 +728,8 @@ class AutoencoderKL(nn.Module):
         for k in new_state_dict:
             if "postconv" in k:
                 old_name = k.replace("postconv", "conv")
-                new_state_dict[k] = old_state_dict[old_name]
-        self.load_state_dict(new_state_dict)
+                new_state_dict[k] = old_state_dict.pop(old_name)
+        if verbose:
+            # print all remaining keys in old_state_dict
+            print("remaining keys in old_state_dict:", old_state_dict.keys())
+        self.load_state_dict(new_state_dict, strict=True)
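
The three new attention options (include_fc, use_combined_linear, use_flash_attention) are simply threaded from the AutoencoderKL constructor through Encoder and Decoder into their attention blocks. A minimal usage sketch follows; the flag names come from the hunks above, while the remaining constructor arguments and their values are illustrative assumptions, not the package's defaults.

# Sketch only: enabling the new attention flags on AutoencoderKL.
import torch
from monai.networks.nets import AutoencoderKL

model = AutoencoderKL(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(64, 128, 128),                # illustrative; must match attention_levels in length
    attention_levels=(False, False, True),
    num_res_blocks=2,
    latent_channels=3,
    include_fc=True,                        # keep the final linear layer in the attention blocks
    use_combined_linear=False,              # separate to_q/to_k/to_v projections
    use_flash_attention=True,               # needs a PyTorch version providing scaled_dot_product_attention
)

with torch.no_grad():
    reconstruction, z_mu, z_sigma = model(torch.randn(1, 1, 64, 64))

With use_combined_linear=False the attention weights are stored as separate to_q/to_k/to_v projections, which is why the state-dict migration above now copies those keys individually instead of concatenating them into a single qkv tensor.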
monai/networks/nets/cell_sam_wrapper.py (new file)

@@ -0,0 +1,92 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from monai.utils import optional_import
+
+build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b")
+
+_all__ = ["CellSamWrapper"]
+
+
+class CellSamWrapper(torch.nn.Module):
+    """
+    CellSamWrapper is thin wrapper around SAM model https://github.com/facebookresearch/segment-anything
+    with an image only decoder, that can be used for segmentation tasks.
+
+
+    Args:
+        auto_resize_inputs: whether to resize inputs before passing to the network.
+            (usually they need be resized, unless they are already at the expected size)
+        network_resize_roi: expected input size for the network.
+            (currently SAM expects 1024x1024)
+        checkpoint: checkpoint file to load the SAM weights from.
+            (this can be downloaded from SAM repo https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth)
+        return_features: whether to return features from SAM encoder
+            (without using decoder/upsampling to the original input size)
+
+    """
+
+    def __init__(
+        self,
+        auto_resize_inputs=True,
+        network_resize_roi=(1024, 1024),
+        checkpoint="sam_vit_b_01ec64.pth",
+        return_features=False,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+
+        self.network_resize_roi = network_resize_roi
+        self.auto_resize_inputs = auto_resize_inputs
+        self.return_features = return_features
+
+        if not has_sam:
+            raise ValueError(
+                "SAM is not installed, please run: pip install git+https://github.com/facebookresearch/segment-anything.git"
+            )
+
+        model = build_sam_vit_b(checkpoint=checkpoint)
+
+        model.prompt_encoder = None
+        model.mask_decoder = None
+
+        model.mask_decoder = nn.Sequential(
+            nn.BatchNorm2d(num_features=256),
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
+            nn.BatchNorm2d(num_features=128),
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(128, 3, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True),
+        )
+
+        self.model = model
+
+    def forward(self, x):
+        sh = x.shape[2:]
+
+        if self.auto_resize_inputs:
+            x = F.interpolate(x, size=self.network_resize_roi, mode="bilinear")
+
+        x = self.model.image_encoder(x)
+
+        if not self.return_features:
+            x = self.model.mask_decoder(x)
+            if self.auto_resize_inputs:
+                x = F.interpolate(x, size=sh, mode="bilinear")
+
+        return x
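
A usage sketch for the new wrapper, assuming segment-anything is installed and the ViT-B checkpoint referenced in the docstring has been downloaded locally; the batch and input size below are illustrative.

# Sketch only: running the new CellSamWrapper on a 3-channel image tensor.
import torch
from monai.networks.nets.cell_sam_wrapper import CellSamWrapper

model = CellSamWrapper(
    auto_resize_inputs=True,            # inputs are resized to network_resize_roi before the encoder
    network_resize_roi=(1024, 1024),    # SAM's expected input size
    checkpoint="sam_vit_b_01ec64.pth",  # path to the downloaded SAM ViT-B weights
    return_features=False,              # run the image-only decoder and resize back to the input size
).eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 256, 256))  # SAM's image encoder expects 3-channel input

# With return_features=False the replacement mask_decoder produces 3 output channels and
# auto_resize_inputs restores the original spatial size, so out.shape == (1, 3, 256, 256).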
monai/networks/nets/controlnet.py

@@ -143,6 +143,10 @@ class ControlNet(nn.Module):
         upcast_attention: if True, upcast attention operations to full precision.
         conditioning_embedding_in_channels: number of input channels for the conditioning embedding.
         conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding.
+        include_fc: whether to include the final linear layer. Default to True.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
+        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
     """

     def __init__(
@@ -163,28 +167,29 @@ class ControlNet(nn.Module):
         upcast_attention: bool = False,
         conditioning_embedding_in_channels: int = 1,
         conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256),
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         if with_conditioning is True and cross_attention_dim is None:
             raise ValueError(
-                "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
+                "ControlNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
                 "to be specified when with_conditioning=True."
             )
         if cross_attention_dim is not None and with_conditioning is False:
-            raise ValueError(
-                "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim."
-            )
+            raise ValueError("ControlNet expects with_conditioning=True when specifying the cross_attention_dim.")

         # All number of channels should be multiple of num_groups
         if any((out_channel % norm_num_groups) != 0 for out_channel in channels):
             raise ValueError(
-                f"DiffusionModelUNet expects all channels to be a multiple of norm_num_groups, but got"
+                f"ControlNet expects all channels to be a multiple of norm_num_groups, but got"
                 f" channels={channels} and norm_num_groups={norm_num_groups}"
             )

         if len(channels) != len(attention_levels):
             raise ValueError(
-                f"DiffusionModelUNet expects channels to have the same length as attention_levels, but got "
+                f"ControlNet expects channels to have the same length as attention_levels, but got "
                 f"channels={channels} and attention_levels={attention_levels}"
             )

@@ -282,6 +287,9 @@ class ControlNet(nn.Module):
                 transformer_num_layers=transformer_num_layers,
                 cross_attention_dim=cross_attention_dim,
                 upcast_attention=upcast_attention,
+                include_fc=include_fc,
+                use_combined_linear=use_combined_linear,
+                use_flash_attention=use_flash_attention,
             )

             self.down_blocks.append(down_block)
@@ -326,6 +334,9 @@ class ControlNet(nn.Module):
             transformer_num_layers=transformer_num_layers,
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
+            use_flash_attention=use_flash_attention,
         )

         controlnet_block = Convolution(
@@ -441,25 +452,16 @@ class ControlNet(nn.Module):
         # copy over all matching keys
         for k in new_state_dict:
             if k in old_state_dict:
-                new_state_dict[k] = old_state_dict[k]
+                new_state_dict[k] = old_state_dict.pop(k)

         # fix the attention blocks
-        attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k]
+        attention_blocks = [k.replace(".out_proj.weight", "") for k in new_state_dict if "out_proj.weight" in k]
         for block in attention_blocks:
-            new_state_dict[f"{block}.attn1.qkv.weight"] = torch.cat(
-                [
-                    old_state_dict[f"{block}.attn1.to_q.weight"],
-                    old_state_dict[f"{block}.attn1.to_k.weight"],
-                    old_state_dict[f"{block}.attn1.to_v.weight"],
-                ],
-                dim=0,
-            )
-
             # projection
-            new_state_dict[f"{block}.attn1.out_proj.weight"] = old_state_dict[f"{block}.attn1.to_out.0.weight"]
-            new_state_dict[f"{block}.attn1.out_proj.bias"] = old_state_dict[f"{block}.attn1.to_out.0.bias"]
-
-            new_state_dict[f"{block}.attn2.out_proj.weight"] = old_state_dict[f"{block}.attn2.to_out.0.weight"]
-            new_state_dict[f"{block}.attn2.out_proj.bias"] = old_state_dict[f"{block}.attn2.to_out.0.bias"]
+            new_state_dict[f"{block}.out_proj.weight"] = old_state_dict.pop(f"{block}.to_out.0.weight")
+            new_state_dict[f"{block}.out_proj.bias"] = old_state_dict.pop(f"{block}.to_out.0.bias")

+        if verbose:
+            # print all remaining keys in old_state_dict
+            print("remaining keys in old_state_dict:", old_state_dict.keys())
         self.load_state_dict(new_state_dict)
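
The same three attention flags are passed through to ControlNet's down and mid blocks. Below is a minimal construction sketch; only the flag names and the with_conditioning/cross_attention_dim rule come from the hunks above, and the remaining arguments and their values are illustrative assumptions.

# Sketch only: a ControlNet with the new attention flags.
from monai.networks.nets import ControlNet

controlnet = ControlNet(
    spatial_dims=2,
    in_channels=1,
    channels=(32, 64, 64),                # every entry must be divisible by norm_num_groups
    attention_levels=(False, True, True),
    num_res_blocks=1,
    norm_num_groups=32,
    with_conditioning=True,               # required whenever cross_attention_dim is given (see the check above)
    cross_attention_dim=64,
    include_fc=True,
    use_combined_linear=False,
    use_flash_attention=False,
)

As in the autoencoder, the old-checkpoint migration now pops keys out of the old state dict as it copies them, and with verbose=True it prints whatever is left over, which makes silently dropped weights easy to spot.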