kaiko-eva 0.1.8__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. eva/core/data/datasets/base.py +7 -2
  2. eva/core/data/datasets/classification/embeddings.py +2 -2
  3. eva/core/data/datasets/classification/multi_embeddings.py +2 -2
  4. eva/core/data/datasets/embeddings.py +4 -4
  5. eva/core/data/samplers/classification/balanced.py +19 -18
  6. eva/core/loggers/utils/wandb.py +33 -0
  7. eva/core/models/modules/head.py +5 -3
  8. eva/core/models/modules/typings.py +2 -2
  9. eva/core/models/transforms/__init__.py +2 -1
  10. eva/core/models/transforms/as_discrete.py +57 -0
  11. eva/core/models/wrappers/_utils.py +121 -1
  12. eva/core/trainers/functional.py +8 -5
  13. eva/core/trainers/trainer.py +32 -17
  14. eva/core/utils/suppress_logs.py +28 -0
  15. eva/vision/data/__init__.py +2 -2
  16. eva/vision/data/dataloaders/__init__.py +5 -0
  17. eva/vision/data/dataloaders/collate_fn/__init__.py +5 -0
  18. eva/vision/data/dataloaders/collate_fn/collection.py +22 -0
  19. eva/vision/data/datasets/__init__.py +10 -2
  20. eva/vision/data/datasets/classification/__init__.py +9 -0
  21. eva/vision/data/datasets/classification/bach.py +3 -4
  22. eva/vision/data/datasets/classification/bracs.py +111 -0
  23. eva/vision/data/datasets/classification/breakhis.py +209 -0
  24. eva/vision/data/datasets/classification/camelyon16.py +4 -5
  25. eva/vision/data/datasets/classification/crc.py +3 -4
  26. eva/vision/data/datasets/classification/gleason_arvaniti.py +171 -0
  27. eva/vision/data/datasets/classification/mhist.py +3 -4
  28. eva/vision/data/datasets/classification/panda.py +4 -5
  29. eva/vision/data/datasets/classification/patch_camelyon.py +3 -4
  30. eva/vision/data/datasets/classification/unitopatho.py +158 -0
  31. eva/vision/data/datasets/classification/wsi.py +6 -5
  32. eva/vision/data/datasets/segmentation/__init__.py +2 -2
  33. eva/vision/data/datasets/segmentation/_utils.py +47 -0
  34. eva/vision/data/datasets/segmentation/bcss.py +7 -8
  35. eva/vision/data/datasets/segmentation/btcv.py +236 -0
  36. eva/vision/data/datasets/segmentation/consep.py +6 -7
  37. eva/vision/data/datasets/segmentation/embeddings.py +2 -2
  38. eva/vision/data/datasets/segmentation/lits.py +9 -8
  39. eva/vision/data/datasets/segmentation/lits_balanced.py +2 -1
  40. eva/vision/data/datasets/segmentation/monusac.py +4 -5
  41. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +12 -10
  42. eva/vision/data/datasets/vision.py +95 -4
  43. eva/vision/data/datasets/wsi.py +5 -5
  44. eva/vision/data/transforms/__init__.py +22 -3
  45. eva/vision/data/transforms/common/__init__.py +1 -2
  46. eva/vision/data/transforms/croppad/__init__.py +11 -0
  47. eva/vision/data/transforms/croppad/crop_foreground.py +110 -0
  48. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +109 -0
  49. eva/vision/data/transforms/croppad/spatial_pad.py +67 -0
  50. eva/vision/data/transforms/intensity/__init__.py +11 -0
  51. eva/vision/data/transforms/intensity/rand_scale_intensity.py +59 -0
  52. eva/vision/data/transforms/intensity/rand_shift_intensity.py +55 -0
  53. eva/vision/data/transforms/intensity/scale_intensity_ranged.py +56 -0
  54. eva/vision/data/transforms/spatial/__init__.py +7 -0
  55. eva/vision/data/transforms/spatial/flip.py +72 -0
  56. eva/vision/data/transforms/spatial/rotate.py +53 -0
  57. eva/vision/data/transforms/spatial/spacing.py +69 -0
  58. eva/vision/data/transforms/utility/__init__.py +5 -0
  59. eva/vision/data/transforms/utility/ensure_channel_first.py +51 -0
  60. eva/vision/data/tv_tensors/__init__.py +5 -0
  61. eva/vision/data/tv_tensors/volume.py +61 -0
  62. eva/vision/metrics/segmentation/monai_dice.py +9 -2
  63. eva/vision/models/modules/semantic_segmentation.py +28 -20
  64. eva/vision/models/networks/backbones/__init__.py +9 -2
  65. eva/vision/models/networks/backbones/pathology/__init__.py +11 -2
  66. eva/vision/models/networks/backbones/pathology/bioptimus.py +47 -1
  67. eva/vision/models/networks/backbones/pathology/hkust.py +69 -0
  68. eva/vision/models/networks/backbones/pathology/kaiko.py +18 -0
  69. eva/vision/models/networks/backbones/pathology/mahmood.py +46 -19
  70. eva/vision/models/networks/backbones/radiology/__init__.py +11 -0
  71. eva/vision/models/networks/backbones/radiology/swin_unetr.py +231 -0
  72. eva/vision/models/networks/backbones/radiology/voco.py +75 -0
  73. eva/vision/models/networks/decoders/segmentation/__init__.py +6 -2
  74. eva/vision/models/networks/decoders/segmentation/linear.py +5 -10
  75. eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +8 -1
  76. eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py +104 -0
  77. eva/vision/utils/io/__init__.py +2 -0
  78. eva/vision/utils/io/nifti.py +91 -11
  79. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/METADATA +3 -1
  80. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/RECORD +83 -62
  81. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/WHEEL +1 -1
  82. eva/vision/data/datasets/classification/base.py +0 -96
  83. eva/vision/data/datasets/segmentation/base.py +0 -96
  84. eva/vision/data/transforms/common/resize_and_clamp.py +0 -51
  85. eva/vision/data/transforms/normalization/__init__.py +0 -6
  86. eva/vision/data/transforms/normalization/clamp.py +0 -43
  87. eva/vision/data/transforms/normalization/functional/__init__.py +0 -5
  88. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +0 -28
  89. eva/vision/data/transforms/normalization/rescale_intensity.py +0 -53
  90. eva/vision/metrics/segmentation/BUILD +0 -1
  91. eva/vision/models/networks/backbones/torchhub/__init__.py +0 -5
  92. eva/vision/models/networks/backbones/torchhub/backbones.py +0 -61
  93. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/entry_points.txt +0 -0
  94. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/licenses/LICENSE +0 -0

eva/vision/models/networks/backbones/pathology/bioptimus.py (+47 -1)

@@ -5,6 +5,9 @@ from typing import Tuple
 import timm
 from torch import nn
 
+from eva.core.models import transforms
+from eva.vision.models import wrappers
+from eva.vision.models.networks.backbones import _utils
 from eva.vision.models.networks.backbones.registry import register_model
 
 

@@ -13,7 +16,9 @@ def bioptimus_h_optimus_0(
     dynamic_img_size: bool = True,
     out_indices: int | Tuple[int, ...] | None = None,
 ) -> nn.Module:
-    """Initializes the h_optimus_0 pathology FM by Bioptimus.
+    """Initializes the H-Optimus-0 pathology FM by Bioptimus.
+
+    See https://huggingface.co/bioptimus/H-optimus-0 for details.
 
     Args:
         dynamic_img_size: Whether to allow the interpolation embedding

@@ -32,3 +37,44 @@ def bioptimus_h_optimus_0(
         out_indices=out_indices,
         features_only=out_indices is not None,
     )
+
+
+@register_model("pathology/bioptimus_h0_mini")
+def bioptimus_h0_mini(
+    dynamic_img_size: bool = True,
+    out_indices: int | Tuple[int, ...] | None = None,
+    hf_token: str | None = None,
+    include_patch_tokens: bool = False,
+) -> nn.Module:
+    """Initializes H0-mini (ViT-B) pathology FM by Bioptimus.
+
+    This model was distilled from H-Optimus-0 on 40M TCGA tiles.
+
+    See https://huggingface.co/bioptimus/H0-mini for details.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+        hf_token: HuggingFace token to download the model.
+        include_patch_tokens: Whether to combine the mean aggregated patch tokens with cls token.
+
+    Returns:
+        The model instance.
+    """
+    _utils.huggingface_login(hf_token)
+    return wrappers.TimmModel(
+        model_name="hf-hub:bioptimus/H0-mini",
+        out_indices=out_indices,
+        pretrained=True,
+        model_kwargs={
+            "dynamic_img_size": dynamic_img_size,
+            "mlp_layer": timm.layers.SwiGLUPacked,
+            "act_layer": nn.SiLU,
+        },
+        tensor_transforms=(
+            transforms.ExtractCLSFeatures(include_patch_tokens=include_patch_tokens)
+            if out_indices is None
+            else None
+        ),
+    )
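
The new H0-mini entry point above is a plain factory function; a minimal usage sketch follows (the sketch is not part of the diff, the token value is a placeholder, and access to the gated bioptimus/H0-mini repository is assumed):

import torch
from eva.vision.models.networks.backbones.pathology import bioptimus

# Default output is the CLS embedding; include_patch_tokens=True would
# concatenate the mean of the patch tokens to it.
model = bioptimus.bioptimus_h0_mini(hf_token="hf_...")  # placeholder token
model.eval()
with torch.no_grad():
    embedding = model(torch.randn(1, 3, 224, 224))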

eva/vision/models/networks/backbones/pathology/hkust.py (new file, +69)

@@ -0,0 +1,69 @@
+"""Pathology FMs from Hong Kong University of Science and Technology."""
+
+import re
+from typing import Tuple
+
+import timm
+from torch import nn
+
+from eva.core.models.wrappers import _utils
+from eva.vision.models.networks.backbones.registry import register_model
+
+
+@register_model("pathology/hkust_gpfm")
+def hkust_gpfm(
+    dynamic_img_size: bool = True,
+    out_indices: int | Tuple[int, ...] | None = None,
+) -> nn.Module:
+    """Initializes GPFM model from Hong Kong University of Science and Technology.
+
+    Ma, J., Guo, Z., Zhou, F., Wang, Y., Xu, Y., et al. (2024).
+    Towards a generalizable pathology foundation model via unified knowledge
+    distillation (arXiv No. 2407.18449). arXiv. https://arxiv.org/abs/2407.18449
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The model instance.
+    """
+    return timm.create_model(
+        model_name="vit_large_patch14_dinov2",
+        pretrained=True,
+        pretrained_cfg={
+            "state_dict": _load_state_dict(),
+            "num_classes": 0,
+        },
+        out_indices=out_indices,
+        features_only=out_indices is not None,
+        **{
+            "img_size": 224,
+            "patch_size": 14,
+            "init_values": 1e-5,
+            "qkv_bias": True,
+            "dynamic_img_size": dynamic_img_size,
+        },
+    )
+
+
+def _load_state_dict() -> dict:
+    """Loads the state dict with model weights from github."""
+    state_dict = _utils.load_state_dict_from_url(
+        url="https://github.com/birkhoffkiki/GPFM/releases/download/ckpt/GPFM.pth",
+        md5="0dc7e345de84f385d09c8c782b4b3236",
+    )
+    return _convert_state_dict(state_dict["teacher"])
+
+
+def _convert_state_dict(state_dict: dict) -> dict:
+    """Rename state dict keys to match timm's format."""
+    state_dict = {
+        re.sub(r"blocks\.\d+\.(\d+)", r"blocks.\1", key.replace("backbone.", "")): value
+        for key, value in state_dict.items()
+    }
+    remove_keys = ["mask_token"] + [key for key in state_dict.keys() if "dino_head" in key]
+    for key in remove_keys:
+        state_dict.pop(key)
+    return state_dict
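
The subtle part of the new GPFM loader is the key renaming in _convert_state_dict: DINOv2-style teacher checkpoints store transformer blocks in chunks as blocks.<chunk>.<global_index>, while timm expects a flat blocks.<index> layout. A small illustration, assuming a hypothetical checkpoint key of that shape:

import re

key = "backbone.blocks.0.3.attn.qkv.weight"  # hypothetical GPFM checkpoint key
renamed = re.sub(r"blocks\.\d+\.(\d+)", r"blocks.\1", key.replace("backbone.", ""))
print(renamed)  # blocks.3.attn.qkv.weight, i.e. timm's flat block numbering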

eva/vision/models/networks/backbones/pathology/kaiko.py (+18 -0)

@@ -5,9 +5,27 @@ from typing import Tuple
 import torch
 from torch import nn
 
+from eva.vision.models.networks.backbones import _utils
 from eva.vision.models.networks.backbones.registry import register_model
 
 
+@register_model("pathology/kaiko_midnight_12k")
+def kaiko_midnight_12k(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
+    """Initializes the Midnight-12k pathology FM by kaiko.ai.
+
+    Args:
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The model instance.
+    """
+    return _utils.load_hugingface_model(
+        model_name="kaiko-ai/midnight",
+        out_indices=out_indices,
+        model_kwargs={"trust_remote_code": True},
+    )
+
+
 @register_model("pathology/kaiko_vits16")
 def kaiko_vits16(
     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
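
A hedged sketch of using the new Midnight-12k entry point (not part of the diff; it needs network access to the kaiko-ai/midnight repository, and the factory enables trust_remote_code internally as shown above):

from eva.vision.models.networks.backbones.pathology import kaiko

backbone = kaiko.kaiko_midnight_12k()                  # default embedding output
multi_scale = kaiko.kaiko_midnight_12k(out_indices=4)  # last 4 feature maps instead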

eva/vision/models/networks/backbones/pathology/mahmood.py (+46 -19)

@@ -1,11 +1,9 @@
 """Pathology FMs from MahmoodLab."""
 
-import os
-from pathlib import Path
 from typing import Tuple
 
-import huggingface_hub
-from loguru import logger
+import timm
+import torch
 from torch import nn
 
 from eva.vision.models import wrappers

@@ -18,7 +16,6 @@ def mahmood_uni(
     dynamic_img_size: bool = True,
     out_indices: int | Tuple[int, ...] | None = None,
     hf_token: str | None = None,
-    download_dir: str = os.path.join(str(Path.home()), ".cache/eva"),
 ) -> nn.Module:
     """Initializes UNI model from MahmoodLab.
 

@@ -27,29 +24,59 @@
             the grid size (interpolate abs and/or ROPE pos) in the forward pass.
         out_indices: Whether and which multi-level patch embeddings to return.
         hf_token: HuggingFace token to download the model.
-        download_dir: Directory to download the model checkpoint.
 
     Returns:
         The model instance.
     """
-    checkpoint_path = os.path.join(download_dir, "pytorch_model.bin")
-    if not os.path.exists(checkpoint_path):
-        logger.info(f"Downloading the model checkpoint to {download_dir} ...")
-        os.makedirs(download_dir, exist_ok=True)
-        _utils.huggingface_login(hf_token)
-        huggingface_hub.hf_hub_download(
-            "MahmoodLab/UNI",
-            filename="pytorch_model.bin",
-            local_dir=download_dir,
-            force_download=True,
-        )
+    _utils.huggingface_login(hf_token)
 
     return wrappers.TimmModel(
-        model_name="vit_large_patch16_224",
+        model_name="hf-hub:MahmoodLab/uni",
+        pretrained=True,
         out_indices=out_indices,
         model_kwargs={
             "init_values": 1e-5,
             "dynamic_img_size": dynamic_img_size,
         },
-        checkpoint_path=checkpoint_path,
+    )
+
+
+@register_model("pathology/mahmood_uni2_h")
+def mahmood_uni2_h(
+    dynamic_img_size: bool = True,
+    out_indices: int | Tuple[int, ...] | None = None,
+    hf_token: str | None = None,
+) -> nn.Module:
+    """Initializes UNI model from MahmoodLab.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+        hf_token: HuggingFace token to download the model.
+
+    Returns:
+        The model instance.
+    """
+    _utils.huggingface_login(hf_token)
+
+    return wrappers.TimmModel(
+        model_name="hf-hub:MahmoodLab/UNI2-h",
+        pretrained=True,
+        out_indices=out_indices,
+        model_kwargs={
+            "img_size": 224,
+            "patch_size": 14,
+            "depth": 24,
+            "num_heads": 24,
+            "init_values": 1e-5,
+            "embed_dim": 1536,
+            "mlp_ratio": 2.66667 * 2,
+            "num_classes": 0,
+            "no_embed_class": True,
+            "mlp_layer": timm.layers.SwiGLUPacked,
+            "act_layer": torch.nn.SiLU,
+            "reg_tokens": 8,
+            "dynamic_img_size": dynamic_img_size,
+        },
     )
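
Behaviorally, UNI no longer downloads pytorch_model.bin into ~/.cache/eva by hand; both UNI variants now resolve their weights through timm's hf-hub support, so the removed download_dir argument is replaced by the standard HuggingFace hub cache. A usage sketch under those assumptions (tokens are placeholders for the gated MahmoodLab repositories):

from eva.vision.models.networks.backbones.pathology import mahmood

uni = mahmood.mahmood_uni(hf_token="hf_...")       # ViT-L/16 UNI
uni2 = mahmood.mahmood_uni2_h(hf_token="hf_...")   # larger UNI2-h configuration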

eva/vision/models/networks/backbones/radiology/__init__.py (new file, +11)

@@ -0,0 +1,11 @@
+"""Vision Radiology Model Backbones API."""
+
+from eva.vision.models.networks.backbones.radiology.swin_unetr import SwinUNETREncoder
+from eva.vision.models.networks.backbones.radiology.voco import VoCoB, VoCoH, VoCoL
+
+__all__ = [
+    "VoCoB",
+    "VoCoL",
+    "VoCoH",
+    "SwinUNETREncoder",
+]

eva/vision/models/networks/backbones/radiology/swin_unetr.py (new file, +231)

@@ -0,0 +1,231 @@
+"""Encoder based on Swin UNETR."""
+
+from typing import List, Tuple
+
+import torch
+from monai.inferers.inferer import Inferer
+from monai.networks.blocks import unetr_block
+from monai.networks.nets import swin_unetr
+from monai.utils import misc
+from torch import nn
+
+from eva.vision.models.networks.backbones.registry import register_model
+
+
+@register_model("radiology/swin_unetr_encoder")
+class SwinUNETREncoder(nn.Module):
+    """Swin transformer encoder based on UNETR [0].
+
+    - [0] UNETR: Transformers for 3D Medical Image Segmentation
+        https://arxiv.org/pdf/2103.10504
+    """
+
+    def __init__(
+        self,
+        in_channels: int = 1,
+        feature_size: int = 48,
+        spatial_dims: int = 3,
+        out_indices: int | None = None,
+        inferer: Inferer | None = None,
+        use_v2: bool = True,
+    ) -> None:
+        """Build the UNETR encoder.
+
+        Args:
+            in_channels: Number of input channels.
+            feature_size: The dimension of network feature size.
+            spatial_dims: Number of spatial dimensions.
+            out_indices: Number of feature outputs. If None,
+                the aggregated feature vector is returned.
+            inferer: An optional MONAI `Inferer` for efficient
+                inference during evaluation.
+            use_v2: Whether to use SwinTransformerV2.
+        """
+        super().__init__()
+
+        self._in_channels = in_channels
+        self._feature_size = feature_size
+        self._spatial_dims = spatial_dims
+        self._out_indices = out_indices
+        self._inferer = inferer
+        self._use_v2 = use_v2
+
+        self._window_size = misc.ensure_tuple_rep(7, spatial_dims)
+        self._patch_size = misc.ensure_tuple_rep(2, spatial_dims)
+
+        self.swinViT = swin_unetr.SwinTransformer(
+            in_chans=in_channels,
+            embed_dim=feature_size,
+            window_size=self._window_size,
+            patch_size=self._patch_size,
+            depths=(2, 2, 2, 2),
+            num_heads=(3, 6, 12, 24),
+            mlp_ratio=4.0,
+            qkv_bias=True,
+            drop_rate=0.0,
+            attn_drop_rate=0.0,
+            drop_path_rate=0.0,
+            norm_layer=torch.nn.LayerNorm,
+            spatial_dims=spatial_dims,
+            use_v2=use_v2,
+        )
+        self.encoder1 = unetr_block.UnetrBasicBlock(
+            spatial_dims=spatial_dims,
+            in_channels=in_channels,
+            out_channels=feature_size,
+            kernel_size=3,
+            stride=1,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.encoder2 = unetr_block.UnetrBasicBlock(
+            spatial_dims=spatial_dims,
+            in_channels=feature_size,
+            out_channels=feature_size,
+            kernel_size=3,
+            stride=1,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.encoder3 = unetr_block.UnetrBasicBlock(
+            spatial_dims=spatial_dims,
+            in_channels=2 * feature_size,
+            out_channels=2 * feature_size,
+            kernel_size=3,
+            stride=1,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.encoder4 = unetr_block.UnetrBasicBlock(
+            spatial_dims=spatial_dims,
+            in_channels=4 * feature_size,
+            out_channels=4 * feature_size,
+            kernel_size=3,
+            stride=1,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.encoder10 = unetr_block.UnetrBasicBlock(
+            spatial_dims=spatial_dims,
+            in_channels=16 * feature_size,
+            out_channels=16 * feature_size,
+            kernel_size=3,
+            stride=1,
+            norm_name="instance",
+            res_block=True,
+        )
+        self._pool_op = (
+            nn.AdaptiveAvgPool3d(output_size=(1, 1, 1))
+            if spatial_dims == 3
+            else nn.AdaptiveAvgPool2d(output_size=(1, 1))
+        )
+
+    def _forward_features(self, tensor: torch.Tensor) -> List[torch.Tensor]:
+        """Extracts feature maps from the Swin Transformer and encoder blocks.
+
+        Args:
+            tensor: Input tensor of shape (B, C, T, H, W).
+
+        Returns:
+            List of feature maps from encoder stages.
+        """
+        hidden_states = self.swinViT(tensor)
+        enc0 = self.encoder1(tensor)
+        enc1 = self.encoder2(hidden_states[0])
+        enc2 = self.encoder3(hidden_states[1])
+        enc3 = self.encoder4(hidden_states[2])
+        dec4 = self.encoder10(hidden_states[4])
+        return [enc0, enc1, enc2, enc3, hidden_states[3], dec4]
+
+    def forward_features(self, tensor: torch.Tensor) -> List[torch.Tensor]:
+        """Computes feature maps using either standard forward pass or inference mode.
+
+        If in inference mode (`self.training` is False) and an inference method
+        (`self._inferer`) is available, the `_inferer` is used to extract features.
+        Otherwise, `_forward_features` is called directly.
+
+        Args:
+            tensor: Input tensor of shape (B, C, T, H, W).
+
+        Returns:
+            List of feature maps from encoder stages.
+        """
+        if not self.training and self._inferer:
+            return self._inferer(inputs=tensor, network=self._forward_features)
+
+        return self._forward_features(tensor)
+
+    def forward_encoders(self, features: List[torch.Tensor]) -> torch.Tensor:
+        """Aggregates encoder features into a single feature vector.
+
+        Args:
+            features: List of feature maps from encoder stages.
+
+        Returns:
+            Aggregated feature vector (B, C').
+        """
+        batch_size = features[0].shape[0]
+        reduced_features = []
+        for patch_features in features[:4] + features[5:]:
+            hidden_features = self._pool_op(patch_features)
+            hidden_features_reduced = hidden_features.view(batch_size, -1)
+            reduced_features.append(hidden_features_reduced)
+        return torch.cat(reduced_features, dim=1)
+
+    def forward_head(self, features: List[torch.Tensor]) -> torch.Tensor:
+        """Casts last feature map into a single feature vector.
+
+        Args:
+            features: List of encoder feature maps.
+
+        Returns:
+            Aggregated feature vector (B, C').
+        """
+        last_feature_map = features[-1]
+        pooled_features = self._pool_op(last_feature_map)
+        return torch.flatten(pooled_features, 1)
+
+    def forward_embeddings(self, tensor: torch.Tensor) -> torch.Tensor:
+        """Computes the final aggregated feature vector.
+
+        Args:
+            tensor: Input tensor of shape (B, C, T, H, W).
+
+        Returns:
+            Aggregated feature vector of shape (B, C').
+        """
+        intermediates = self.forward_features(tensor)
+        return self.forward_encoders(intermediates)
+
+    def forward_intermediates(
+        self, tensor: torch.Tensor
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """Computes encoder features and their embeddings.
+
+        Args:
+            tensor: Input tensor of shape (B, C, T, H, W).
+
+        Returns:
+            Aggregated feature vector and list of intermediate features.
+        """
+        features = self.forward_features(tensor)
+        embeddings = self.forward_encoders(features)
+        return embeddings, features
+
+    def forward(self, tensor: torch.Tensor) -> torch.Tensor | List[torch.Tensor]:
+        """Forward pass through the encoder.
+
+        If `self._out_indices` is None, it returns the aggregated feature vector.
+        Otherwise, it returns the intermediate feature maps up to the specified index.
+
+        Args:
+            tensor: Input tensor of shape (B, C, T, H, W).
+
+        Returns:
+            Aggregated feature vector or intermediate features.
+        """
+        if self._out_indices is None:
+            return self.forward_embeddings(tensor)
+
+        intermediates = self.forward_features(tensor)
+        return intermediates[-1 * self._out_indices :]
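
Two output modes follow from the code above: with out_indices=None the forward pass pools and concatenates enc0, enc1, enc2, enc3 and dec4 (forward_encoders skips hidden_states[3]), giving (1 + 1 + 2 + 4 + 16) * feature_size = 24 * feature_size channels, i.e. 1152 for the default feature_size=48; with an integer out_indices it returns the last N feature maps. A shape sketch under those deductions (not part of the diff; the input size is illustrative and chosen divisible by 32):

import torch
from eva.vision.models.networks.backbones.radiology import swin_unetr

encoder = swin_unetr.SwinUNETREncoder()  # feature_size=48, spatial_dims=3
volume = torch.randn(1, 1, 64, 64, 64)   # (B, C, T, H, W)

embeddings = encoder(volume)             # aggregated vector, expected shape (1, 1152)
feature_maps = swin_unetr.SwinUNETREncoder(out_indices=2)(volume)  # last 2 maps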

eva/vision/models/networks/backbones/radiology/voco.py (new file, +75)

@@ -0,0 +1,75 @@
+"""VoCo Self-Supervised Encoders."""
+
+from typing_extensions import override
+
+from eva.core.models.wrappers import _utils
+from eva.vision.models.networks.backbones.radiology import swin_unetr
+from eva.vision.models.networks.backbones.registry import register_model
+
+
+class _VoCo(swin_unetr.SwinUNETREncoder):
+    """Base class for the VoCo self-supervised encoders."""
+
+    _checkpoint: str
+    """Path to the model state dict."""
+
+    _md5: str | None = None
+    """State dict MD5 validation code."""
+
+    def __init__(self, feature_size: int, out_indices: int | None = None) -> None:
+        """Initializes the model.
+
+        Args:
+            feature_size: Size of the last feature map of SwinUNETR.
+            out_indices: The number of feature maps from intermediate blocks
+                to be returned. If set to 1, only the last feature map is returned.
+        """
+        super().__init__(
+            in_channels=1,
+            feature_size=feature_size,
+            spatial_dims=3,
+            out_indices=out_indices,
+        )
+
+        self._load_checkpoint()
+
+    def _load_checkpoint(self) -> None:
+        """Loads the model checkpoint."""
+        state_dict = _utils.load_state_dict_from_url(self._checkpoint, md5=self._md5)
+        self.load_state_dict(state_dict)
+
+
+@register_model("radiology/voco_b")
+class VoCoB(_VoCo):
+    """VoCo Self-supervised pre-trained B model."""
+
+    _checkpoint = "https://huggingface.co/Luffy503/VoCo/resolve/main/VoCo_B_SSL_head.pt"
+    _md5 = "f80c4da2f81d700bdae3df188f2057eb"
+
+    @override
+    def __init__(self, out_indices: int | None = None) -> None:
+        super().__init__(feature_size=48, out_indices=out_indices)
+
+
+@register_model("radiology/voco_l")
+class VoCoL(_VoCo):
+    """VoCo Self-supervised pre-trained L model."""
+
+    _checkpoint = "https://huggingface.co/Luffy503/VoCo/resolve/main/VoCo_L_SSL_head.pt"
+    _md5 = "795095d1d43ef3808ec4c41798310136"
+
+    @override
+    def __init__(self, out_indices: int | None = None) -> None:
+        super().__init__(feature_size=96, out_indices=out_indices)
+
+
+@register_model("radiology/voco_h")
+class VoCoH(_VoCo):
+    """VoCo Self-supervised pre-trained H model."""
+
+    _checkpoint = "https://huggingface.co/Luffy503/VoCo/resolve/main/VoCo_H_SSL_head.pt"
+    _md5 = "76f95a474736b60bf5b8aad94643744d"
+
+    @override
+    def __init__(self, out_indices: int | None = None) -> None:
+        super().__init__(feature_size=192, out_indices=out_indices)
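
The three VoCo variants differ only in feature_size (48, 96, 192) and checkpoint URL; the base class fetches and MD5-checks the weights at construction time. A brief sketch (not part of the diff; instantiation downloads the checkpoint on first use):

from eva.vision.models.networks.backbones.radiology import voco

model = voco.VoCoB()                       # aggregated embeddings by default
last_map_only = voco.VoCoB(out_indices=1)  # returns only the last feature map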

eva/vision/models/networks/decoders/segmentation/__init__.py (+6 -2)

@@ -1,5 +1,6 @@
 """Segmentation decoder heads API."""
 
+from eva.vision.models.networks.decoders.segmentation.base import Decoder
 from eva.vision.models.networks.decoders.segmentation.decoder2d import Decoder2D
 from eva.vision.models.networks.decoders.segmentation.linear import LinearDecoder
 from eva.vision.models.networks.decoders.segmentation.semantic import (

@@ -7,13 +8,16 @@ from eva.vision.models.networks.decoders.segmentation.semantic import (
     ConvDecoderMS,
     ConvDecoderWithImage,
     SingleLinearDecoder,
+    SwinUNETRDecoder,
 )
 
 __all__ = [
+    "Decoder",
+    "Decoder2D",
     "ConvDecoder1x1",
     "ConvDecoderMS",
-    "SingleLinearDecoder",
     "ConvDecoderWithImage",
-    "Decoder2D",
     "LinearDecoder",
+    "SingleLinearDecoder",
+    "SwinUNETRDecoder",
 ]

eva/vision/models/networks/decoders/segmentation/linear.py (+5 -10)

@@ -7,6 +7,7 @@ from torch import nn
 from torch.nn import functional
 
 from eva.vision.models.networks.decoders.segmentation import base
+from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs
 
 
 class LinearDecoder(base.Decoder):

@@ -104,22 +105,16 @@ class LinearDecoder(base.Decoder):
         """
         return functional.interpolate(logits, image_size, mode="bilinear")
 
-    def forward(
-        self,
-        features: List[torch.Tensor],
-        image_size: Tuple[int, int],
-    ) -> torch.Tensor:
+    def forward(self, decoder_inputs: DecoderInputs) -> torch.Tensor:
         """Maps the patch embeddings to a segmentation mask of the image size.
 
         Args:
-            features: List of multi-level image features of shape (batch_size,
-                hidden_size, n_patches_height, n_patches_width).
-            image_size: The target image size (height, width).
+            decoder_inputs: Inputs required by the decoder.
 
         Returns:
             Tensor containing scores for all of the classes with shape
             (batch_size, n_classes, image_height, image_width).
         """
-        patch_embeddings = self._forward_features(features)
+        patch_embeddings = self._forward_features(decoder_inputs.features)
         logits = self._forward_head(patch_embeddings)
-        return self._cls_seg(logits, image_size)
+        return self._cls_seg(logits, decoder_inputs.image_size)
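
The forward signature change means callers now pass one DecoderInputs object instead of separate features and image_size arguments. A hypothetical migration sketch (the real DecoderInputs in eva/vision/models/networks/decoders/segmentation/typings.py may carry additional fields beyond the two used here):

import torch
from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs

features = [torch.randn(2, 768, 14, 14)]  # (B, hidden, patches_h, patches_w)
inputs = DecoderInputs(features=features, image_size=(224, 224))
# before: decoder(features, (224, 224)); after: decoder(inputs)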

eva/vision/models/networks/decoders/segmentation/semantic/__init__.py (+8 -1)

@@ -5,8 +5,15 @@ from eva.vision.models.networks.decoders.segmentation.semantic.common import (
     ConvDecoderMS,
     SingleLinearDecoder,
 )
+from eva.vision.models.networks.decoders.segmentation.semantic.swin_unetr import SwinUNETRDecoder
 from eva.vision.models.networks.decoders.segmentation.semantic.with_image import (
     ConvDecoderWithImage,
 )
 
-__all__ = ["ConvDecoder1x1", "ConvDecoderMS", "SingleLinearDecoder", "ConvDecoderWithImage"]
+__all__ = [
+    "ConvDecoder1x1",
+    "ConvDecoderMS",
+    "ConvDecoderWithImage",
+    "SingleLinearDecoder",
+    "SwinUNETRDecoder",
+]