kaiko-eva 0.1.8__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eva/core/data/datasets/base.py +7 -2
- eva/core/data/datasets/classification/embeddings.py +2 -2
- eva/core/data/datasets/classification/multi_embeddings.py +2 -2
- eva/core/data/datasets/embeddings.py +4 -4
- eva/core/data/samplers/classification/balanced.py +19 -18
- eva/core/loggers/utils/wandb.py +33 -0
- eva/core/models/modules/head.py +5 -3
- eva/core/models/modules/typings.py +2 -2
- eva/core/models/transforms/__init__.py +2 -1
- eva/core/models/transforms/as_discrete.py +57 -0
- eva/core/models/wrappers/_utils.py +121 -1
- eva/core/trainers/functional.py +8 -5
- eva/core/trainers/trainer.py +32 -17
- eva/core/utils/suppress_logs.py +28 -0
- eva/vision/data/__init__.py +2 -2
- eva/vision/data/dataloaders/__init__.py +5 -0
- eva/vision/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/vision/data/dataloaders/collate_fn/collection.py +22 -0
- eva/vision/data/datasets/__init__.py +10 -2
- eva/vision/data/datasets/classification/__init__.py +9 -0
- eva/vision/data/datasets/classification/bach.py +3 -4
- eva/vision/data/datasets/classification/bracs.py +111 -0
- eva/vision/data/datasets/classification/breakhis.py +209 -0
- eva/vision/data/datasets/classification/camelyon16.py +4 -5
- eva/vision/data/datasets/classification/crc.py +3 -4
- eva/vision/data/datasets/classification/gleason_arvaniti.py +171 -0
- eva/vision/data/datasets/classification/mhist.py +3 -4
- eva/vision/data/datasets/classification/panda.py +4 -5
- eva/vision/data/datasets/classification/patch_camelyon.py +3 -4
- eva/vision/data/datasets/classification/unitopatho.py +158 -0
- eva/vision/data/datasets/classification/wsi.py +6 -5
- eva/vision/data/datasets/segmentation/__init__.py +2 -2
- eva/vision/data/datasets/segmentation/_utils.py +47 -0
- eva/vision/data/datasets/segmentation/bcss.py +7 -8
- eva/vision/data/datasets/segmentation/btcv.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +6 -7
- eva/vision/data/datasets/segmentation/embeddings.py +2 -2
- eva/vision/data/datasets/segmentation/lits.py +9 -8
- eva/vision/data/datasets/segmentation/lits_balanced.py +2 -1
- eva/vision/data/datasets/segmentation/monusac.py +4 -5
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +12 -10
- eva/vision/data/datasets/vision.py +95 -4
- eva/vision/data/datasets/wsi.py +5 -5
- eva/vision/data/transforms/__init__.py +22 -3
- eva/vision/data/transforms/common/__init__.py +1 -2
- eva/vision/data/transforms/croppad/__init__.py +11 -0
- eva/vision/data/transforms/croppad/crop_foreground.py +110 -0
- eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +109 -0
- eva/vision/data/transforms/croppad/spatial_pad.py +67 -0
- eva/vision/data/transforms/intensity/__init__.py +11 -0
- eva/vision/data/transforms/intensity/rand_scale_intensity.py +59 -0
- eva/vision/data/transforms/intensity/rand_shift_intensity.py +55 -0
- eva/vision/data/transforms/intensity/scale_intensity_ranged.py +56 -0
- eva/vision/data/transforms/spatial/__init__.py +7 -0
- eva/vision/data/transforms/spatial/flip.py +72 -0
- eva/vision/data/transforms/spatial/rotate.py +53 -0
- eva/vision/data/transforms/spatial/spacing.py +69 -0
- eva/vision/data/transforms/utility/__init__.py +5 -0
- eva/vision/data/transforms/utility/ensure_channel_first.py +51 -0
- eva/vision/data/tv_tensors/__init__.py +5 -0
- eva/vision/data/tv_tensors/volume.py +61 -0
- eva/vision/metrics/segmentation/monai_dice.py +9 -2
- eva/vision/models/modules/semantic_segmentation.py +28 -20
- eva/vision/models/networks/backbones/__init__.py +9 -2
- eva/vision/models/networks/backbones/pathology/__init__.py +11 -2
- eva/vision/models/networks/backbones/pathology/bioptimus.py +47 -1
- eva/vision/models/networks/backbones/pathology/hkust.py +69 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +18 -0
- eva/vision/models/networks/backbones/pathology/mahmood.py +46 -19
- eva/vision/models/networks/backbones/radiology/__init__.py +11 -0
- eva/vision/models/networks/backbones/radiology/swin_unetr.py +231 -0
- eva/vision/models/networks/backbones/radiology/voco.py +75 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +6 -2
- eva/vision/models/networks/decoders/segmentation/linear.py +5 -10
- eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +8 -1
- eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py +104 -0
- eva/vision/utils/io/__init__.py +2 -0
- eva/vision/utils/io/nifti.py +91 -11
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/METADATA +3 -1
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/RECORD +83 -62
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/WHEEL +1 -1
- eva/vision/data/datasets/classification/base.py +0 -96
- eva/vision/data/datasets/segmentation/base.py +0 -96
- eva/vision/data/transforms/common/resize_and_clamp.py +0 -51
- eva/vision/data/transforms/normalization/__init__.py +0 -6
- eva/vision/data/transforms/normalization/clamp.py +0 -43
- eva/vision/data/transforms/normalization/functional/__init__.py +0 -5
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +0 -28
- eva/vision/data/transforms/normalization/rescale_intensity.py +0 -53
- eva/vision/metrics/segmentation/BUILD +0 -1
- eva/vision/models/networks/backbones/torchhub/__init__.py +0 -5
- eva/vision/models/networks/backbones/torchhub/backbones.py +0 -61
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/licenses/LICENSE +0 -0
eva/vision/models/networks/backbones/pathology/bioptimus.py

@@ -5,6 +5,9 @@ from typing import Tuple
 import timm
 from torch import nn
 
+from eva.core.models import transforms
+from eva.vision.models import wrappers
+from eva.vision.models.networks.backbones import _utils
 from eva.vision.models.networks.backbones.registry import register_model
 
 
@@ -13,7 +16,9 @@ def bioptimus_h_optimus_0(
     dynamic_img_size: bool = True,
     out_indices: int | Tuple[int, ...] | None = None,
 ) -> nn.Module:
-    """Initializes the
+    """Initializes the H-Optimus-0 pathology FM by Bioptimus.
+
+    See https://huggingface.co/bioptimus/H-optimus-0 for details.
 
     Args:
         dynamic_img_size: Whether to allow the interpolation embedding
@@ -32,3 +37,44 @@ def bioptimus_h_optimus_0(
         out_indices=out_indices,
         features_only=out_indices is not None,
     )
+
+
+@register_model("pathology/bioptimus_h0_mini")
+def bioptimus_h0_mini(
+    dynamic_img_size: bool = True,
+    out_indices: int | Tuple[int, ...] | None = None,
+    hf_token: str | None = None,
+    include_patch_tokens: bool = False,
+) -> nn.Module:
+    """Initializes H0-mini (ViT-B) pathology FM by Bioptimus.
+
+    This model was distilled from H-Optimus-0 on 40M TCGA tiles.
+
+    See https://huggingface.co/bioptimus/H0-mini for details.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+        hf_token: HuggingFace token to download the model.
+        include_patch_tokens: Whether to combine the mean aggregated patch tokens with cls token.
+
+    Returns:
+        The model instance.
+    """
+    _utils.huggingface_login(hf_token)
+    return wrappers.TimmModel(
+        model_name="hf-hub:bioptimus/H0-mini",
+        out_indices=out_indices,
+        pretrained=True,
+        model_kwargs={
+            "dynamic_img_size": dynamic_img_size,
+            "mlp_layer": timm.layers.SwiGLUPacked,
+            "act_layer": nn.SiLU,
+        },
+        tensor_transforms=(
+            transforms.ExtractCLSFeatures(include_patch_tokens=include_patch_tokens)
+            if out_indices is None
+            else None
+        ),
+    )
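For reference, a minimal usage sketch of the new entry point (not part of the diff; the token placeholder and input shape are illustrative):

    import torch
    from eva.vision.models.networks.backbones.pathology import bioptimus

    model = bioptimus.bioptimus_h0_mini(hf_token="<your-hf-token>")  # gated HF repo
    model.eval()
    with torch.no_grad():
        embedding = model(torch.randn(1, 3, 224, 224))  # CLS features per tile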
eva/vision/models/networks/backbones/pathology/hkust.py (new file)

@@ -0,0 +1,69 @@
+"""Pathology FMs from Hong Kong University of Science and Technology."""
+
+import re
+from typing import Tuple
+
+import timm
+from torch import nn
+
+from eva.core.models.wrappers import _utils
+from eva.vision.models.networks.backbones.registry import register_model
+
+
+@register_model("pathology/hkust_gpfm")
+def hkust_gpfm(
+    dynamic_img_size: bool = True,
+    out_indices: int | Tuple[int, ...] | None = None,
+) -> nn.Module:
+    """Initializes GPFM model from Hong Kong University of Science and Technology.
+
+    Ma, J., Guo, Z., Zhou, F., Wang, Y., Xu, Y., et al. (2024).
+    Towards a generalizable pathology foundation model via unified knowledge
+    distillation (arXiv No. 2407.18449). arXiv. https://arxiv.org/abs/2407.18449
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The model instance.
+    """
+    return timm.create_model(
+        model_name="vit_large_patch14_dinov2",
+        pretrained=True,
+        pretrained_cfg={
+            "state_dict": _load_state_dict(),
+            "num_classes": 0,
+        },
+        out_indices=out_indices,
+        features_only=out_indices is not None,
+        **{
+            "img_size": 224,
+            "patch_size": 14,
+            "init_values": 1e-5,
+            "qkv_bias": True,
+            "dynamic_img_size": dynamic_img_size,
+        },
+    )
+
+
+def _load_state_dict() -> dict:
+    """Loads the state dict with model weights from github."""
+    state_dict = _utils.load_state_dict_from_url(
+        url="https://github.com/birkhoffkiki/GPFM/releases/download/ckpt/GPFM.pth",
+        md5="0dc7e345de84f385d09c8c782b4b3236",
+    )
+    return _convert_state_dict(state_dict["teacher"])
+
+
+def _convert_state_dict(state_dict: dict) -> dict:
+    """Rename state dict keys to match timm's format."""
+    state_dict = {
+        re.sub(r"blocks\.\d+\.(\d+)", r"blocks.\1", key.replace("backbone.", "")): value
+        for key, value in state_dict.items()
+    }
+    remove_keys = ["mask_token"] + [key for key in state_dict.keys() if "dino_head" in key]
+    for key in remove_keys:
+        state_dict.pop(key)
+    return state_dict
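The key-conversion helper is the subtle part of this file; a standalone sketch (the key is made up for illustration) of what the rewrite does to a checkpoint entry:

    import re

    # Strip the "backbone." prefix and collapse chunked block indices such as
    # "blocks.0.3" into "blocks.3", matching timm's flat block naming.
    key = "backbone.blocks.0.3.attn.qkv.weight"
    print(re.sub(r"blocks\.\d+\.(\d+)", r"blocks.\1", key.replace("backbone.", "")))
    # -> blocks.3.attn.qkv.weight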
eva/vision/models/networks/backbones/pathology/kaiko.py

@@ -5,9 +5,27 @@ from typing import Tuple
 import torch
 from torch import nn
 
+from eva.vision.models.networks.backbones import _utils
 from eva.vision.models.networks.backbones.registry import register_model
 
 
+@register_model("pathology/kaiko_midnight_12k")
+def kaiko_midnight_12k(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
+    """Initializes the Midnight-12k pathology FM by kaiko.ai.
+
+    Args:
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The model instance.
+    """
+    return _utils.load_hugingface_model(
+        model_name="kaiko-ai/midnight",
+        out_indices=out_indices,
+        model_kwargs={"trust_remote_code": True},
+    )
+
+
 @register_model("pathology/kaiko_vits16")
 def kaiko_vits16(
     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
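Usage mirrors the other registered backbones (sketch; calling the factory directly):

    from eva.vision.models.networks.backbones.pathology import kaiko

    # Downloads kaiko-ai/midnight from the HuggingFace Hub with remote code enabled.
    model = kaiko.kaiko_midnight_12k()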
eva/vision/models/networks/backbones/pathology/mahmood.py

@@ -1,11 +1,9 @@
 """Pathology FMs from MahmoodLab."""
 
-import os
-from pathlib import Path
 from typing import Tuple
 
-import huggingface_hub
-from loguru import logger
+import timm
+import torch
 from torch import nn
 
 from eva.vision.models import wrappers
@@ -18,7 +16,6 @@ def mahmood_uni(
     dynamic_img_size: bool = True,
     out_indices: int | Tuple[int, ...] | None = None,
     hf_token: str | None = None,
-    download_dir: str = os.path.join(str(Path.home()), ".cache/eva"),
 ) -> nn.Module:
     """Initializes UNI model from MahmoodLab.
 
@@ -27,29 +24,59 @@ def mahmood_uni(
             the grid size (interpolate abs and/or ROPE pos) in the forward pass.
         out_indices: Whether and which multi-level patch embeddings to return.
         hf_token: HuggingFace token to download the model.
-        download_dir: Directory to download the model checkpoint.
 
     Returns:
         The model instance.
     """
-    checkpoint_path = os.path.join(download_dir, "pytorch_model.bin")
-    if not os.path.exists(checkpoint_path):
-        logger.info(f"Downloading the model checkpoint to {download_dir} ...")
-        os.makedirs(download_dir, exist_ok=True)
-        _utils.huggingface_login(hf_token)
-        huggingface_hub.hf_hub_download(
-            "MahmoodLab/UNI",
-            filename="pytorch_model.bin",
-            local_dir=download_dir,
-            force_download=True,
-        )
+    _utils.huggingface_login(hf_token)
 
     return wrappers.TimmModel(
-        model_name="
+        model_name="hf-hub:MahmoodLab/uni",
+        pretrained=True,
         out_indices=out_indices,
         model_kwargs={
             "init_values": 1e-5,
             "dynamic_img_size": dynamic_img_size,
         },
-        checkpoint_path=checkpoint_path,
+    )
+
+
+@register_model("pathology/mahmood_uni2_h")
+def mahmood_uni2_h(
+    dynamic_img_size: bool = True,
+    out_indices: int | Tuple[int, ...] | None = None,
+    hf_token: str | None = None,
+) -> nn.Module:
+    """Initializes UNI model from MahmoodLab.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+        hf_token: HuggingFace token to download the model.
+
+    Returns:
+        The model instance.
+    """
+    _utils.huggingface_login(hf_token)
+
+    return wrappers.TimmModel(
+        model_name="hf-hub:MahmoodLab/UNI2-h",
+        pretrained=True,
+        out_indices=out_indices,
+        model_kwargs={
+            "img_size": 224,
+            "patch_size": 14,
+            "depth": 24,
+            "num_heads": 24,
+            "init_values": 1e-5,
+            "embed_dim": 1536,
+            "mlp_ratio": 2.66667 * 2,
+            "num_classes": 0,
+            "no_embed_class": True,
+            "mlp_layer": timm.layers.SwiGLUPacked,
+            "act_layer": torch.nn.SiLU,
+            "reg_tokens": 8,
+            "dynamic_img_size": dynamic_img_size,
+        },
     )
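Net effect: both UNI variants now stream weights directly from the HuggingFace Hub instead of caching a checkpoint under ~/.cache/eva. A sketch (token placeholder is illustrative; the repos are gated):

    from eva.vision.models.networks.backbones.pathology import mahmood

    model = mahmood.mahmood_uni2_h(hf_token="<your-hf-token>")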
eva/vision/models/networks/backbones/radiology/__init__.py (new file)

@@ -0,0 +1,11 @@
+"""Vision Radiology Model Backbones API."""
+
+from eva.vision.models.networks.backbones.radiology.swin_unetr import SwinUNETREncoder
+from eva.vision.models.networks.backbones.radiology.voco import VoCoB, VoCoH, VoCoL
+
+__all__ = [
+    "VoCoB",
+    "VoCoL",
+    "VoCoH",
+    "SwinUNETREncoder",
+]
eva/vision/models/networks/backbones/radiology/swin_unetr.py (new file)

@@ -0,0 +1,231 @@
+"""Encoder based on Swin UNETR."""
+
+from typing import List, Tuple
+
+import torch
+from monai.inferers.inferer import Inferer
+from monai.networks.blocks import unetr_block
+from monai.networks.nets import swin_unetr
+from monai.utils import misc
+from torch import nn
+
+from eva.vision.models.networks.backbones.registry import register_model
+
+
+@register_model("radiology/swin_unetr_encoder")
+class SwinUNETREncoder(nn.Module):
+    """Swin transformer encoder based on UNETR [0].
+
+    - [0] UNETR: Transformers for 3D Medical Image Segmentation
+        https://arxiv.org/pdf/2103.10504
+    """
+
+    def __init__(
+        self,
+        in_channels: int = 1,
+        feature_size: int = 48,
+        spatial_dims: int = 3,
+        out_indices: int | None = None,
+        inferer: Inferer | None = None,
+        use_v2: bool = True,
+    ) -> None:
+        """Build the UNETR encoder.
+
+        Args:
+            in_channels: Number of input channels.
+            feature_size: The dimension of network feature size.
+            spatial_dims: Number of spatial dimensions.
+            out_indices: Number of feature outputs. If None,
+                the aggregated feature vector is returned.
+            inferer: An optional MONAI `Inferer` for efficient
+                inference during evaluation.
+            use_v2: Whether to use SwinTransformerV2.
+        """
+        super().__init__()
+
+        self._in_channels = in_channels
+        self._feature_size = feature_size
+        self._spatial_dims = spatial_dims
+        self._out_indices = out_indices
+        self._inferer = inferer
+        self._use_v2 = use_v2
+
+        self._window_size = misc.ensure_tuple_rep(7, spatial_dims)
+        self._patch_size = misc.ensure_tuple_rep(2, spatial_dims)
+
+        self.swinViT = swin_unetr.SwinTransformer(
+            in_chans=in_channels,
+            embed_dim=feature_size,
+            window_size=self._window_size,
+            patch_size=self._patch_size,
+            depths=(2, 2, 2, 2),
+            num_heads=(3, 6, 12, 24),
+            mlp_ratio=4.0,
+            qkv_bias=True,
+            drop_rate=0.0,
+            attn_drop_rate=0.0,
+            drop_path_rate=0.0,
+            norm_layer=torch.nn.LayerNorm,
+            spatial_dims=spatial_dims,
+            use_v2=use_v2,
+        )
+        self.encoder1 = unetr_block.UnetrBasicBlock(
+            spatial_dims=spatial_dims,
+            in_channels=in_channels,
+            out_channels=feature_size,
+            kernel_size=3,
+            stride=1,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.encoder2 = unetr_block.UnetrBasicBlock(
+            spatial_dims=spatial_dims,
+            in_channels=feature_size,
+            out_channels=feature_size,
+            kernel_size=3,
+            stride=1,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.encoder3 = unetr_block.UnetrBasicBlock(
+            spatial_dims=spatial_dims,
+            in_channels=2 * feature_size,
+            out_channels=2 * feature_size,
+            kernel_size=3,
+            stride=1,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.encoder4 = unetr_block.UnetrBasicBlock(
+            spatial_dims=spatial_dims,
+            in_channels=4 * feature_size,
+            out_channels=4 * feature_size,
+            kernel_size=3,
+            stride=1,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.encoder10 = unetr_block.UnetrBasicBlock(
+            spatial_dims=spatial_dims,
+            in_channels=16 * feature_size,
+            out_channels=16 * feature_size,
+            kernel_size=3,
+            stride=1,
+            norm_name="instance",
+            res_block=True,
+        )
+        self._pool_op = (
+            nn.AdaptiveAvgPool3d(output_size=(1, 1, 1))
+            if spatial_dims == 3
+            else nn.AdaptiveAvgPool2d(output_size=(1, 1))
+        )
+
+    def _forward_features(self, tensor: torch.Tensor) -> List[torch.Tensor]:
+        """Extracts feature maps from the Swin Transformer and encoder blocks.
+
+        Args:
+            tensor: Input tensor of shape (B, C, T, H, W).
+
+        Returns:
+            List of feature maps from encoder stages.
+        """
+        hidden_states = self.swinViT(tensor)
+        enc0 = self.encoder1(tensor)
+        enc1 = self.encoder2(hidden_states[0])
+        enc2 = self.encoder3(hidden_states[1])
+        enc3 = self.encoder4(hidden_states[2])
+        dec4 = self.encoder10(hidden_states[4])
+        return [enc0, enc1, enc2, enc3, hidden_states[3], dec4]
+
+    def forward_features(self, tensor: torch.Tensor) -> List[torch.Tensor]:
+        """Computes feature maps using either standard forward pass or inference mode.
+
+        If in inference mode (`self.training` is False) and an inference method
+        (`self._inferer`) is available, the `_inferer` is used to extract features.
+        Otherwise, `_forward_features` is called directly.
+
+        Args:
+            tensor: Input tensor of shape (B, C, T, H, W).
+
+        Returns:
+            List of feature maps from encoder stages.
+        """
+        if not self.training and self._inferer:
+            return self._inferer(inputs=tensor, network=self._forward_features)
+
+        return self._forward_features(tensor)
+
+    def forward_encoders(self, features: List[torch.Tensor]) -> torch.Tensor:
+        """Aggregates encoder features into a single feature vector.
+
+        Args:
+            features: List of feature maps from encoder stages.
+
+        Returns:
+            Aggregated feature vector (B, C').
+        """
+        batch_size = features[0].shape[0]
+        reduced_features = []
+        for patch_features in features[:4] + features[5:]:
+            hidden_features = self._pool_op(patch_features)
+            hidden_features_reduced = hidden_features.view(batch_size, -1)
+            reduced_features.append(hidden_features_reduced)
+        return torch.cat(reduced_features, dim=1)
+
+    def forward_head(self, features: List[torch.Tensor]) -> torch.Tensor:
+        """Casts last feature map into a single feature vector.
+
+        Args:
+            features: List of encoder feature maps.
+
+        Returns:
+            Aggregated feature vector (B, C').
+        """
+        last_feature_map = features[-1]
+        pooled_features = self._pool_op(last_feature_map)
+        return torch.flatten(pooled_features, 1)
+
+    def forward_embeddings(self, tensor: torch.Tensor) -> torch.Tensor:
+        """Computes the final aggregated feature vector.
+
+        Args:
+            tensor: Input tensor of shape (B, C, T, H, W).
+
+        Returns:
+            Aggregated feature vector of shape (B, C').
+        """
+        intermediates = self.forward_features(tensor)
+        return self.forward_encoders(intermediates)
+
+    def forward_intermediates(
+        self, tensor: torch.Tensor
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """Computes encoder features and their embeddings.
+
+        Args:
+            tensor: Input tensor of shape (B, C, T, H, W).
+
+        Returns:
+            Aggregated feature vector and list of intermediate features.
+        """
+        features = self.forward_features(tensor)
+        embeddings = self.forward_encoders(features)
+        return embeddings, features
+
+    def forward(self, tensor: torch.Tensor) -> torch.Tensor | List[torch.Tensor]:
+        """Forward pass through the encoder.
+
+        If `self._out_indices` is None, it returns the aggregated feature vector.
+        Otherwise, it returns the intermediate feature maps up to the specified index.
+
+        Args:
+            tensor: Input tensor of shape (B, C, T, H, W).
+
+        Returns:
+            Aggregated feature vector or intermediate features.
+        """
+        if self._out_indices is None:
+            return self.forward_embeddings(tensor)
+
+        intermediates = self.forward_features(tensor)
+        return intermediates[-1 * self._out_indices :]
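For orientation, a quick shape-check sketch (illustrative assumptions: monai installed, a 96x96x96 dummy volume; the pooled width follows from the default feature_size=48 via 48+48+96+192+768 = 1152):

    import torch
    from eva.vision.models.networks.backbones.radiology import swin_unetr

    encoder = swin_unetr.SwinUNETREncoder(in_channels=1, feature_size=48, spatial_dims=3)
    volume = torch.randn(1, 1, 96, 96, 96)  # (B, C, T, H, W)
    embedding = encoder(volume)  # out_indices=None -> pooled vector, expected (1, 1152)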
eva/vision/models/networks/backbones/radiology/voco.py (new file)

@@ -0,0 +1,75 @@
+"""VoCo Self-Supervised Encoders."""
+
+from typing_extensions import override
+
+from eva.core.models.wrappers import _utils
+from eva.vision.models.networks.backbones.radiology import swin_unetr
+from eva.vision.models.networks.backbones.registry import register_model
+
+
+class _VoCo(swin_unetr.SwinUNETREncoder):
+    """Base class for the VoCo self-supervised encoders."""
+
+    _checkpoint: str
+    """Path to the model state dict."""
+
+    _md5: str | None = None
+    """State dict MD5 validation code."""
+
+    def __init__(self, feature_size: int, out_indices: int | None = None) -> None:
+        """Initializes the model.
+
+        Args:
+            feature_size: Size of the last feature map of SwinUNETR.
+            out_indices: The number of feature maps from intermediate blocks
+                to be returned. If set to 1, only the last feature map is returned.
+        """
+        super().__init__(
+            in_channels=1,
+            feature_size=feature_size,
+            spatial_dims=3,
+            out_indices=out_indices,
+        )
+
+        self._load_checkpoint()
+
+    def _load_checkpoint(self) -> None:
+        """Loads the model checkpoint."""
+        state_dict = _utils.load_state_dict_from_url(self._checkpoint, md5=self._md5)
+        self.load_state_dict(state_dict)
+
+
+@register_model("radiology/voco_b")
+class VoCoB(_VoCo):
+    """VoCo Self-supervised pre-trained B model."""
+
+    _checkpoint = "https://huggingface.co/Luffy503/VoCo/resolve/main/VoCo_B_SSL_head.pt"
+    _md5 = "f80c4da2f81d700bdae3df188f2057eb"
+
+    @override
+    def __init__(self, out_indices: int | None = None) -> None:
+        super().__init__(feature_size=48, out_indices=out_indices)
+
+
+@register_model("radiology/voco_l")
+class VoCoL(_VoCo):
+    """VoCo Self-supervised pre-trained L model."""
+
+    _checkpoint = "https://huggingface.co/Luffy503/VoCo/resolve/main/VoCo_L_SSL_head.pt"
+    _md5 = "795095d1d43ef3808ec4c41798310136"
+
+    @override
+    def __init__(self, out_indices: int | None = None) -> None:
+        super().__init__(feature_size=96, out_indices=out_indices)
+
+
+@register_model("radiology/voco_h")
+class VoCoH(_VoCo):
+    """VoCo Self-supervised pre-trained H model."""
+
+    _checkpoint = "https://huggingface.co/Luffy503/VoCo/resolve/main/VoCo_H_SSL_head.pt"
+    _md5 = "76f95a474736b60bf5b8aad94643744d"
+
+    @override
+    def __init__(self, out_indices: int | None = None) -> None:
+        super().__init__(feature_size=192, out_indices=out_indices)
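The three registered variants differ only in feature_size (48/96/192) and checkpoint URL; weights are downloaded and MD5-checked on construction. A one-line usage sketch:

    from eva.vision.models.networks.backbones.radiology import voco

    encoder = voco.VoCoB()  # SwinUNETREncoder backbone with VoCo SSL weights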
eva/vision/models/networks/decoders/segmentation/__init__.py

@@ -1,5 +1,6 @@
 """Segmentation decoder heads API."""
 
+from eva.vision.models.networks.decoders.segmentation.base import Decoder
 from eva.vision.models.networks.decoders.segmentation.decoder2d import Decoder2D
 from eva.vision.models.networks.decoders.segmentation.linear import LinearDecoder
 from eva.vision.models.networks.decoders.segmentation.semantic import (
@@ -7,13 +8,16 @@ from eva.vision.models.networks.decoders.segmentation.semantic import (
     ConvDecoderMS,
     ConvDecoderWithImage,
     SingleLinearDecoder,
+    SwinUNETRDecoder,
 )
 
 __all__ = [
+    "Decoder",
+    "Decoder2D",
     "ConvDecoder1x1",
     "ConvDecoderMS",
-    "SingleLinearDecoder",
     "ConvDecoderWithImage",
-    "Decoder2D",
     "LinearDecoder",
+    "SingleLinearDecoder",
+    "SwinUNETRDecoder",
 ]
eva/vision/models/networks/decoders/segmentation/linear.py

@@ -7,6 +7,7 @@ from torch import nn
 from torch.nn import functional
 
 from eva.vision.models.networks.decoders.segmentation import base
+from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs
 
 
 class LinearDecoder(base.Decoder):
@@ -104,22 +105,16 @@ class LinearDecoder(base.Decoder):
         """
         return functional.interpolate(logits, image_size, mode="bilinear")
 
-    def forward(
-        self,
-        features: List[torch.Tensor],
-        image_size: Tuple[int, int],
-    ) -> torch.Tensor:
+    def forward(self, decoder_inputs: DecoderInputs) -> torch.Tensor:
         """Maps the patch embeddings to a segmentation mask of the image size.
 
         Args:
-            features: List of multi-level image features of shape (batch_size,
-                hidden_size, n_patches_height, n_patches_width).
-            image_size: The target image size (height, width).
+            decoder_inputs: Inputs required by the decoder.
 
         Returns:
             Tensor containing scores for all of the classes with shape
             (batch_size, n_classes, image_height, image_width).
         """
-        patch_embeddings = self._forward_features(features)
+        patch_embeddings = self._forward_features(decoder_inputs.features)
         logits = self._forward_head(patch_embeddings)
-        return self._cls_seg(logits, image_size)
+        return self._cls_seg(logits, decoder_inputs.image_size)
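The refactor changes LinearDecoder's calling convention: a single DecoderInputs bundle replaces the separate features/image_size arguments. A hedged sketch (the constructor arguments and the optionality of any further DecoderInputs fields are assumptions, not shown in this hunk):

    import torch
    from eva.vision.models.networks.decoders.segmentation import LinearDecoder
    from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs

    decoder = LinearDecoder(in_features=384, num_classes=5)  # assumed signature
    features = [torch.randn(2, 384, 14, 14)]  # (batch, hidden, patches_h, patches_w)
    inputs = DecoderInputs(features=features, image_size=(224, 224))
    mask_logits = decoder(inputs)  # (2, 5, 224, 224)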
eva/vision/models/networks/decoders/segmentation/semantic/__init__.py

@@ -5,8 +5,15 @@ from eva.vision.models.networks.decoders.segmentation.semantic.common import (
     ConvDecoderMS,
     SingleLinearDecoder,
 )
+from eva.vision.models.networks.decoders.segmentation.semantic.swin_unetr import SwinUNETRDecoder
 from eva.vision.models.networks.decoders.segmentation.semantic.with_image import (
     ConvDecoderWithImage,
 )
 
-__all__ = [
+__all__ = [
+    "ConvDecoder1x1",
+    "ConvDecoderMS",
+    "ConvDecoderWithImage",
+    "SingleLinearDecoder",
+    "SwinUNETRDecoder",
+]