kaiko-eva 0.2.2__py3-none-any.whl → 0.3.1__py3-none-any.whl
- eva/core/data/dataloaders/__init__.py +2 -1
- eva/core/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/core/data/dataloaders/collate_fn/collate.py +24 -0
- eva/core/data/dataloaders/dataloader.py +4 -0
- eva/core/interface/interface.py +34 -1
- eva/core/metrics/defaults/classification/multiclass.py +45 -35
- eva/core/models/modules/__init__.py +2 -1
- eva/core/models/modules/scheduler.py +51 -0
- eva/core/models/transforms/extract_cls_features.py +1 -1
- eva/core/models/transforms/extract_patch_features.py +1 -1
- eva/core/models/wrappers/base.py +17 -14
- eva/core/models/wrappers/from_function.py +5 -4
- eva/core/models/wrappers/from_torchhub.py +5 -6
- eva/core/models/wrappers/huggingface.py +8 -5
- eva/core/models/wrappers/onnx.py +4 -4
- eva/core/trainers/functional.py +40 -43
- eva/core/utils/factory.py +66 -0
- eva/core/utils/registry.py +42 -0
- eva/core/utils/requirements.py +26 -0
- eva/language/__init__.py +13 -0
- eva/language/data/__init__.py +5 -0
- eva/language/data/datasets/__init__.py +9 -0
- eva/language/data/datasets/classification/__init__.py +7 -0
- eva/language/data/datasets/classification/base.py +63 -0
- eva/language/data/datasets/classification/pubmedqa.py +149 -0
- eva/language/data/datasets/language.py +13 -0
- eva/language/models/__init__.py +25 -0
- eva/language/models/modules/__init__.py +5 -0
- eva/language/models/modules/text.py +85 -0
- eva/language/models/modules/typings.py +16 -0
- eva/language/models/wrappers/__init__.py +11 -0
- eva/language/models/wrappers/huggingface.py +69 -0
- eva/language/models/wrappers/litellm.py +77 -0
- eva/language/models/wrappers/vllm.py +149 -0
- eva/language/utils/__init__.py +5 -0
- eva/language/utils/str_to_int_tensor.py +95 -0
- eva/vision/data/dataloaders/__init__.py +2 -1
- eva/vision/data/dataloaders/worker_init.py +35 -0
- eva/vision/data/datasets/__init__.py +5 -5
- eva/vision/data/datasets/segmentation/__init__.py +4 -4
- eva/vision/data/datasets/segmentation/btcv.py +3 -0
- eva/vision/data/datasets/segmentation/consep.py +5 -4
- eva/vision/data/datasets/segmentation/lits17.py +231 -0
- eva/vision/data/datasets/segmentation/metadata/__init__.py +1 -0
- eva/vision/data/datasets/segmentation/metadata/_msd_task7_pancreas.py +287 -0
- eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +243 -0
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +1 -1
- eva/vision/data/transforms/__init__.py +11 -2
- eva/vision/data/transforms/base/__init__.py +5 -0
- eva/vision/data/transforms/base/monai.py +27 -0
- eva/vision/data/transforms/common/__init__.py +2 -1
- eva/vision/data/transforms/common/squeeze.py +24 -0
- eva/vision/data/transforms/croppad/__init__.py +4 -0
- eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +74 -0
- eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -2
- eva/vision/data/transforms/croppad/rand_spatial_crop.py +89 -0
- eva/vision/data/transforms/intensity/rand_scale_intensity.py +6 -2
- eva/vision/data/transforms/intensity/rand_shift_intensity.py +8 -4
- eva/vision/models/modules/semantic_segmentation.py +18 -7
- eva/vision/models/networks/backbones/__init__.py +2 -3
- eva/vision/models/networks/backbones/_utils.py +1 -1
- eva/vision/models/networks/backbones/pathology/bioptimus.py +4 -4
- eva/vision/models/networks/backbones/pathology/gigapath.py +2 -2
- eva/vision/models/networks/backbones/pathology/histai.py +3 -3
- eva/vision/models/networks/backbones/pathology/hkust.py +2 -2
- eva/vision/models/networks/backbones/pathology/kaiko.py +7 -7
- eva/vision/models/networks/backbones/pathology/lunit.py +3 -3
- eva/vision/models/networks/backbones/pathology/mahmood.py +3 -3
- eva/vision/models/networks/backbones/pathology/owkin.py +3 -3
- eva/vision/models/networks/backbones/pathology/paige.py +3 -3
- eva/vision/models/networks/backbones/radiology/swin_unetr.py +2 -2
- eva/vision/models/networks/backbones/radiology/voco.py +5 -5
- eva/vision/models/networks/backbones/registry.py +2 -44
- eva/vision/models/networks/backbones/timm/backbones.py +2 -2
- eva/vision/models/networks/backbones/universal/__init__.py +8 -1
- eva/vision/models/networks/backbones/universal/vit.py +53 -3
- eva/vision/models/networks/decoders/segmentation/decoder2d.py +1 -1
- eva/vision/models/networks/decoders/segmentation/linear.py +1 -1
- eva/vision/models/networks/decoders/segmentation/semantic/common.py +2 -2
- eva/vision/models/networks/decoders/segmentation/typings.py +1 -1
- eva/vision/models/wrappers/from_registry.py +14 -9
- eva/vision/models/wrappers/from_timm.py +6 -5
- {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/METADATA +10 -2
- {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/RECORD +88 -57
- {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/WHEEL +1 -1
- eva/vision/data/datasets/segmentation/lits.py +0 -199
- eva/vision/data/datasets/segmentation/lits_balanced.py +0 -94
- /eva/vision/data/datasets/segmentation/{_total_segmentator.py → metadata/_total_segmentator.py} +0 -0
- {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/licenses/LICENSE +0 -0
eva/vision/models/modules/semantic_segmentation.py

@@ -1,4 +1,4 @@
-"""
+"""Neural Network Semantic Segmentation Module."""
 
 import functools
 from typing import Any, Callable, Dict, Iterable, List, Tuple
@@ -12,7 +12,7 @@ from torch.optim import lr_scheduler
 from typing_extensions import override
 
 from eva.core.metrics import structs as metrics_lib
-from eva.core.models.modules import module
+from eva.core.models.modules import SchedulerConfiguration, module
 from eva.core.models.modules.typings import INPUT_BATCH, INPUT_TENSOR_BATCH
 from eva.core.models.modules.utils import batch_postprocess, grad, submodule_state_dict
 from eva.core.utils import parser
@@ -32,7 +32,7 @@ class SemanticSegmentationModule(module.ModelModule):
         lr_multiplier_encoder: float = 0.0,
         inferer: Inferer | None = None,
         optimizer: OptimizerCallable = optim.AdamW,
-        lr_scheduler: LRSchedulerCallable = lr_scheduler.ConstantLR,
+        lr_scheduler: LRSchedulerCallable | SchedulerConfiguration = lr_scheduler.ConstantLR,
         metrics: metrics_lib.MetricsSchema | None = None,
         postprocess: batch_postprocess.BatchPostProcess | None = None,
         save_decoder_only: bool = True,
@@ -116,7 +116,7 @@ class SemanticSegmentationModule(module.ModelModule):
     def forward(
         self,
         tensor: torch.Tensor,
-        to_size: Tuple[int,
+        to_size: Tuple[int, ...],
         *args: Any,
         **kwargs: Any,
     ) -> torch.Tensor:
@@ -174,7 +174,8 @@ class SemanticSegmentationModule(module.ModelModule):
             The batch step output.
         """
         data, targets, metadata = INPUT_TENSOR_BATCH(*batch)
-
+        to_size = targets.shape[-self.spatial_dims :] if self.inferer is None else None
+        predictions = self(data, to_size=to_size)
         loss = self.criterion(predictions, targets)
         return {
             "loss": loss,
@@ -183,11 +184,21 @@ class SemanticSegmentationModule(module.ModelModule):
             "metadata": metadata,
         }
 
-    def _forward_networks(
+    def _forward_networks(
+        self, tensor: torch.Tensor, to_size: Tuple[int, ...] | None = None
+    ) -> torch.Tensor:
         """Passes the input tensor through the encoder and decoder."""
-
+        if self.encoder:
+            to_size = to_size or tuple(tensor.shape[-self.spatial_dims :])
+            features = self.encoder(tensor)
+        else:
+            if to_size is None:
+                raise ValueError("`to_size` must be provided when no encoder is used.")
+            features = tensor
+
         if isinstance(self.decoder, segmentation.Decoder):
             if not isinstance(features, list):
                 raise ValueError(f"Expected a list of feature map tensors, got {type(features)}.")
             return self.decoder(DecoderInputs(features, to_size, tensor))
+
         return self.decoder(features)
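The `to_size` changes above make the module resolve the decoder's output size from the targets' trailing spatial dimensions when no inferer is configured. A minimal standalone sketch of that resolution (not eva code), assuming a 2D task with `spatial_dims = 2`:

import torch

spatial_dims = 2                         # the module's spatial dimensionality (2D here)
targets = torch.zeros(4, 1, 224, 224)    # (batch, classes, H, W) segmentation masks
inferer = None                           # no sliding-window inferer configured

# Mirrors the new batch step: predictions are upsampled to the targets' spatial size.
to_size = targets.shape[-spatial_dims:] if inferer is None else None
assert tuple(to_size) == (224, 224)      # matches the Tuple[int, ...] forward() argument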
eva/vision/models/networks/backbones/__init__.py

@@ -1,13 +1,12 @@
 """Vision Model Backbones API."""
 
 from eva.vision.models.networks.backbones import pathology, radiology, timm, universal
-from eva.vision.models.networks.backbones.registry import
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 __all__ = [
     "radiology",
     "pathology",
     "timm",
     "universal",
-    "
-    "register_model",
+    "backbone_registry",
 ]
eva/vision/models/networks/backbones/pathology/bioptimus.py

@@ -8,10 +8,10 @@ from torch import nn
 from eva.core.models import transforms
 from eva.vision.models import wrappers
 from eva.vision.models.networks.backbones import _utils
-from eva.vision.models.networks.backbones.registry import
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 
-@
+@backbone_registry.register("pathology/bioptimus_h_optimus_0")
 def bioptimus_h_optimus_0(
     dynamic_img_size: bool = True,
     out_indices: int | Tuple[int, ...] | None = None,
@@ -39,7 +39,7 @@ def bioptimus_h_optimus_0(
     )
 
 
-@
+@backbone_registry.register("pathology/bioptimus_h0_mini")
 def bioptimus_h0_mini(
     dynamic_img_size: bool = True,
     out_indices: int | Tuple[int, ...] | None = None,
@@ -72,7 +72,7 @@ def bioptimus_h0_mini(
             "mlp_layer": timm.layers.SwiGLUPacked,
             "act_layer": nn.SiLU,
         },
-
+        transforms=(
             transforms.ExtractCLSFeatures(include_patch_tokens=include_patch_tokens)
             if out_indices is None
             else None
eva/vision/models/networks/backbones/pathology/gigapath.py

@@ -5,10 +5,10 @@ from typing import Tuple
 import timm
 from torch import nn
 
-from eva.vision.models.networks.backbones.registry import
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 
-@
+@backbone_registry.register("pathology/prov_gigapath")
 def prov_gigapath(
     dynamic_img_size: bool = True,
     out_indices: int | Tuple[int, ...] | None = None,
eva/vision/models/networks/backbones/pathology/histai.py

@@ -5,10 +5,10 @@ from typing import Tuple
 from torch import nn
 
 from eva.vision.models.networks.backbones import _utils
-from eva.vision.models.networks.backbones.registry import
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 
-@
+@backbone_registry.register("pathology/histai_hibou_b")
 def histai_hibou_b(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
     """Initializes the hibou-B pathology FM by hist.ai (https://huggingface.co/histai/hibou-B).
 
@@ -30,7 +30,7 @@ def histai_hibou_b(out_indices: int | Tuple[int, ...] | None = None) -> nn.Modul
     )
 
 
-@
+@backbone_registry.register("pathology/histai_hibou_l")
 def histai_hibou_l(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
     """Initializes the hibou-L pathology FM by hist.ai (https://huggingface.co/histai/hibou-L).
 
eva/vision/models/networks/backbones/pathology/hkust.py

@@ -7,10 +7,10 @@ import timm
 from torch import nn
 
 from eva.core.models.wrappers import _utils
-from eva.vision.models.networks.backbones.registry import
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 
-@
+@backbone_registry.register("pathology/hkust_gpfm")
 def hkust_gpfm(
     dynamic_img_size: bool = True,
     out_indices: int | Tuple[int, ...] | None = None,
eva/vision/models/networks/backbones/pathology/kaiko.py

@@ -6,10 +6,10 @@ import torch
 from torch import nn
 
 from eva.vision.models.networks.backbones import _utils
-from eva.vision.models.networks.backbones.registry import
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 
-@
+@backbone_registry.register("pathology/kaiko_midnight_12k")
 def kaiko_midnight_12k(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
     """Initializes the Midnight-12k pathology FM by kaiko.ai.
 
@@ -26,7 +26,7 @@ def kaiko_midnight_12k(out_indices: int | Tuple[int, ...] | None = None) -> nn.M
     )
 
 
-@
+@backbone_registry.register("pathology/kaiko_vits16")
 def kaiko_vits16(
     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
 ) -> nn.Module:
@@ -49,7 +49,7 @@ def kaiko_vits16(
     )
 
 
-@
+@backbone_registry.register("pathology/kaiko_vits8")
 def kaiko_vits8(
     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
 ) -> nn.Module:
@@ -72,7 +72,7 @@ def kaiko_vits8(
     )
 
 
-@
+@backbone_registry.register("pathology/kaiko_vitb16")
 def kaiko_vitb16(
     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
 ) -> nn.Module:
@@ -95,7 +95,7 @@ def kaiko_vitb16(
     )
 
 
-@
+@backbone_registry.register("pathology/kaiko_vitb8")
 def kaiko_vitb8(
     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
 ) -> nn.Module:
@@ -118,7 +118,7 @@ def kaiko_vitb8(
     )
 
 
-@
+@backbone_registry.register("pathology/kaiko_vitl14")
 def kaiko_vitl14(
     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
 ) -> nn.Module:
eva/vision/models/networks/backbones/pathology/lunit.py

@@ -13,14 +13,14 @@ from typing import Tuple
 from torch import nn
 
 from eva.vision.models import wrappers
-from eva.vision.models.networks.backbones.registry import
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 VITS_URL_PREFIX = (
     "https://github.com/lunit-io/benchmark-ssl-pathology/releases/download/pretrained-weights"
 )
 
 
-@
+@backbone_registry.register("pathology/lunit_vits16")
 def lunit_vits16(
     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
 ) -> nn.Module:
@@ -44,7 +44,7 @@ def lunit_vits16(
     )
 
 
-@
+@backbone_registry.register("pathology/lunit_vits8")
 def lunit_vits8(
     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
 ) -> nn.Module:
eva/vision/models/networks/backbones/pathology/mahmood.py

@@ -8,10 +8,10 @@ from torch import nn
 
 from eva.vision.models import wrappers
 from eva.vision.models.networks.backbones import _utils
-from eva.vision.models.networks.backbones.registry import
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 
-@
+@backbone_registry.register("pathology/mahmood_uni")
 def mahmood_uni(
     dynamic_img_size: bool = True,
     out_indices: int | Tuple[int, ...] | None = None,
@@ -41,7 +41,7 @@ def mahmood_uni(
     )
 
 
-@
+@backbone_registry.register("pathology/mahmood_uni2_h")
 def mahmood_uni2_h(
     dynamic_img_size: bool = True,
     out_indices: int | Tuple[int, ...] | None = None,
eva/vision/models/networks/backbones/pathology/owkin.py

@@ -5,10 +5,10 @@ from typing import Tuple
 from torch import nn
 
 from eva.vision.models.networks.backbones import _utils
-from eva.vision.models.networks.backbones.registry import
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 
-@
+@backbone_registry.register("pathology/owkin_phikon")
 def owkin_phikon(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
     """Initializes the phikon pathology FM by owkin (https://huggingface.co/owkin/phikon).
 
@@ -22,7 +22,7 @@ def owkin_phikon(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
     return _utils.load_hugingface_model(model_name="owkin/phikon", out_indices=out_indices)
 
 
-@
+@backbone_registry.register("pathology/owkin_phikon_v2")
 def owkin_phikon_v2(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
     """Initializes the phikon-v2 pathology FM by owkin (https://huggingface.co/owkin/phikon-v2).
 
eva/vision/models/networks/backbones/pathology/paige.py

@@ -11,10 +11,10 @@ import torch.nn as nn
 from eva.core.models import transforms
 from eva.vision.models import wrappers
 from eva.vision.models.networks.backbones import _utils
-from eva.vision.models.networks.backbones.registry import
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 
-@
+@backbone_registry.register("pathology/paige_virchow2")
 def paige_virchow2(
     dynamic_img_size: bool = True,
     out_indices: int | Tuple[int, ...] | None = None,
@@ -43,7 +43,7 @@ def paige_virchow2(
             "mlp_layer": timm.layers.SwiGLUPacked,
             "act_layer": nn.SiLU,
         },
-
+        transforms=(
            transforms.ExtractCLSFeatures(include_patch_tokens=include_patch_tokens)
            if out_indices is None
            else None
eva/vision/models/networks/backbones/radiology/swin_unetr.py

@@ -9,10 +9,10 @@ from monai.networks.nets import swin_unetr
 from monai.utils import misc
 from torch import nn
 
-from eva.vision.models.networks.backbones.registry import
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 
-@
+@backbone_registry.register("radiology/swin_unetr_encoder")
 class SwinUNETREncoder(nn.Module):
     """Swin transformer encoder based on UNETR [0].
 
eva/vision/models/networks/backbones/radiology/voco.py

@@ -4,10 +4,10 @@ from typing_extensions import override
 
 from eva.core.models.wrappers import _utils
 from eva.vision.models.networks.backbones.radiology import swin_unetr
-from eva.vision.models.networks.backbones.registry import
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 
-class _VoCo(swin_unetr.SwinUNETREncoder):
+class _VoCo(swin_unetr.SwinUNETREncoder):  # type: ignore
     """Base class for the VoCo self-supervised encoders."""
 
     _checkpoint: str
@@ -39,7 +39,7 @@ class _VoCo(swin_unetr.SwinUNETREncoder):
         self.load_state_dict(state_dict)
 
 
-@
+@backbone_registry.register("radiology/voco_b")
 class VoCoB(_VoCo):
     """VoCo Self-supervised pre-trained B model."""
 
@@ -51,7 +51,7 @@ class VoCoB(_VoCo):
         super().__init__(feature_size=48, out_indices=out_indices)
 
 
-@
+@backbone_registry.register("radiology/voco_l")
 class VoCoL(_VoCo):
     """VoCo Self-supervised pre-trained L model."""
 
@@ -63,7 +63,7 @@ class VoCoL(_VoCo):
         super().__init__(feature_size=96, out_indices=out_indices)
 
 
-@
+@backbone_registry.register("radiology/voco_h")
 class VoCoH(_VoCo):
     """VoCo Self-supervised pre-trained H model."""
 
eva/vision/models/networks/backbones/registry.py

@@ -1,47 +1,5 @@
 """Backbone Model Registry."""
 
-from
+from eva.core.utils.registry import Registry
 
-
-
-
-class BackboneModelRegistry:
-    """A model registry for accessing backbone models by name."""
-
-    _registry: Dict[str, Callable[..., nn.Module]] = {}
-
-    @classmethod
-    def register(cls, name: str) -> Callable:
-        """Decorator to register a new model."""
-
-        def decorator(model_fn: Callable[..., nn.Module]) -> Callable[..., nn.Module]:
-            if name in cls._registry:
-                raise ValueError(f"Model {name} is already registered.")
-            cls._registry[name] = model_fn
-            return model_fn
-
-        return decorator
-
-    @classmethod
-    def get(cls, model_name: str) -> Callable[..., nn.Module]:
-        """Gets a model function from the registry."""
-        if model_name not in cls._registry:
-            raise ValueError(f"Model {model_name} not found in the registry.")
-        return cls._registry[model_name]
-
-    @classmethod
-    def load_model(cls, model_name: str, model_kwargs: Dict[str, Any] | None = None) -> nn.Module:
-        """Loads & initializes a model class from the registry."""
-        model_fn = cls.get(model_name)
-        return model_fn(**(model_kwargs or {}))
-
-    @classmethod
-    def list_models(cls) -> List[str]:
-        """List all models in the registry."""
-        register_models = [name for name in cls._registry.keys() if not name.startswith("timm")]
-        return register_models + ["timm/<model_name>"]
-
-
-def register_model(name: str) -> Callable:
-    """Simple decorator to register a model."""
-    return BackboneModelRegistry.register(name)
+backbone_registry = Registry()
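The class-based `BackboneModelRegistry` and its `register_model` helper are replaced by a shared `Registry` instance from `eva.core.utils.registry`. A minimal sketch of registering and looking up a custom backbone, using only the `register` decorator and the internal `_registry` dict that this diff shows; any further `Registry` API would be an assumption:

from torch import nn

from eva.vision.models.networks.backbones.registry import backbone_registry


@backbone_registry.register("custom/tiny_conv")  # names are namespaced, e.g. "pathology/..."
def tiny_conv() -> nn.Module:
    """A toy backbone used only to illustrate registration."""
    return nn.Conv2d(3, 16, kernel_size=3)


model = backbone_registry._registry["custom/tiny_conv"]()  # look up the factory and call it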
eva/vision/models/networks/backbones/timm/backbones.py

@@ -8,7 +8,7 @@ from loguru import logger
 from torch import nn
 
 from eva.vision.models import wrappers
-from eva.vision.models.networks.backbones.registry import
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 
 def timm_model(
@@ -46,7 +46,7 @@ def timm_model(
     )
 
 
-
+backbone_registry._registry.update(
     {
         f"timm/{model_name}": functools.partial(timm_model, model_name=model_name)
         for model_name in timm.list_models()
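After this update every installed timm architecture is registered under a `timm/` prefix, replacing the old `list_models()` placeholder entry `timm/<model_name>`. A quick check, assuming timm is installed and that importing the backbones package triggers the module-level registration above:

import timm

from eva.vision.models.networks.backbones import backbone_registry  # import runs the update

some_model = timm.list_models()[0]
assert f"timm/{some_model}" in backbone_registry._registry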
eva/vision/models/networks/backbones/universal/__init__.py

@@ -1,8 +1,15 @@
 """Universal Vision Model Backbones API."""
 
 from eva.vision.models.networks.backbones.universal.vit import (
+    vit_base_patch16_224_dino_1chan,
     vit_small_patch16_224_dino,
+    vit_small_patch16_224_dino_1chan,
     vit_small_patch16_224_random,
 )
 
-__all__ = [
+__all__ = [
+    "vit_small_patch16_224_dino",
+    "vit_small_patch16_224_random",
+    "vit_small_patch16_224_dino_1chan",
+    "vit_base_patch16_224_dino_1chan",
+]
eva/vision/models/networks/backbones/universal/vit.py

@@ -5,10 +5,10 @@ from typing import Tuple
 import timm
 from torch import nn
 
-from eva.vision.models.networks.backbones.registry import
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 
-@
+@backbone_registry.register("universal/vit_small_patch16_224_random")
 def vit_small_patch16_224_random(
     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
 ) -> nn.Module:
@@ -31,7 +31,7 @@ def vit_small_patch16_224_random(
     )
 
 
-@
+@backbone_registry.register("universal/vit_small_patch16_224_dino")
 def vit_small_patch16_224_dino(
     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
 ) -> nn.Module:
@@ -52,3 +52,53 @@ def vit_small_patch16_224_dino(
         out_indices=out_indices,
         dynamic_img_size=dynamic_img_size,
     )
+
+
+@backbone_registry.register("universal/vit_small_patch16_224_dino_1chan")
+def vit_small_patch16_224_dino_1chan(
+    dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+) -> nn.Module:
+    """Initializes a ViTS-16 baseline model pretrained w/ DINO for single-channel images.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The torch ViTS-16 based foundation model.
+    """
+    return timm.create_model(
+        model_name="vit_small_patch16_224.dino",
+        in_chans=1,
+        num_classes=0,
+        pretrained=True,
+        features_only=out_indices is not None,
+        out_indices=out_indices,
+        dynamic_img_size=dynamic_img_size,
+    )
+
+
+@backbone_registry.register("universal/vit_base_patch16_224_dino_1chan")
+def vit_base_patch16_224_dino_1chan(
+    dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+) -> nn.Module:
+    """Initializes a ViTB-16 baseline model pretrained w/ DINO for single-channel images.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The torch ViTB-16 based foundation model.
+    """
+    return timm.create_model(
+        model_name="vit_base_patch16_224.dino",
+        in_chans=1,
+        num_classes=0,
+        pretrained=True,
+        features_only=out_indices is not None,
+        out_indices=out_indices,
+        dynamic_img_size=dynamic_img_size,
+    )
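A usage sketch for the new single-channel baselines, mirroring the `timm.create_model` call above; `pretrained=False` is substituted here only to skip the DINO weight download in a quick shape check:

import timm
import torch

model = timm.create_model(
    model_name="vit_small_patch16_224.dino",
    in_chans=1,         # grayscale input, as in vit_small_patch16_224_dino_1chan
    num_classes=0,      # no classification head: the forward pass returns the embedding
    pretrained=False,
)
embedding = model(torch.randn(1, 1, 224, 224))
print(embedding.shape)  # torch.Size([1, 384]) for ViT-S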
eva/vision/models/networks/decoders/segmentation/linear.py

@@ -117,4 +117,4 @@ class LinearDecoder(base.Decoder):
         """
         patch_embeddings = self._forward_features(decoder_inputs.features)
         logits = self._forward_head(patch_embeddings)
-        return self._cls_seg(logits, decoder_inputs.image_size)
+        return self._cls_seg(logits, decoder_inputs.image_size)  # type: ignore
eva/vision/models/networks/decoders/segmentation/semantic/common.py

@@ -1,7 +1,7 @@
 """Common semantic segmentation decoders.
 
-This module contains implementations of different types of decoder models
-used in semantic segmentation. These decoders convert the high-level features
+This module contains implementations of different types of decoder models
+used in semantic segmentation. These decoders convert the high-level features
 output by an encoder into pixel-wise predictions for segmentation tasks.
 """
 
eva/vision/models/networks/decoders/segmentation/typings.py

@@ -11,7 +11,7 @@ class DecoderInputs(NamedTuple):
     features: List[torch.Tensor]
     """List of image features generated by the encoder from the original images."""
 
-    image_size: Tuple[int,
+    image_size: Tuple[int, ...]
     """Size of the original input images to be used for upsampling."""
 
     images: torch.Tensor | None = None
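A small sketch of constructing the NamedTuple with the widened `image_size` type; the feature tensor is a dummy and the direct module import path is an assumption based on the file list:

import torch

from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs

features = [torch.randn(2, 384, 14, 14)]        # encoder feature maps (dummy values)
inputs = DecoderInputs(features, (224, 224))    # image_size is now Tuple[int, ...]
assert inputs.images is None                    # the original images remain optional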
eva/vision/models/wrappers/from_registry.py

@@ -2,18 +2,20 @@
 
 from typing import Any, Callable, Dict
 
+import torch
 from typing_extensions import override
 
-from eva.core.models import
-from eva.
+from eva.core.models.wrappers import base
+from eva.core.utils import factory
+from eva.vision.models.networks.backbones import backbone_registry
 
 
-class ModelFromRegistry(
+class ModelFromRegistry(base.BaseModel[torch.Tensor, torch.Tensor]):
     """Wrapper class for vision backbone models.
 
     This class can be used by load backbones available in eva's
     model registry by name. New backbones can be registered by using
-    the `@
+    the `@backbone_registry.register(model_name)` decorator.
     """
 
     def __init__(
@@ -21,7 +23,7 @@ class ModelFromRegistry(wrappers.BaseModel):
         model_name: str,
         model_kwargs: Dict[str, Any] | None = None,
         model_extra_kwargs: Dict[str, Any] | None = None,
-
+        transforms: Callable | None = None,
     ) -> None:
         """Initializes the model.
 
@@ -29,10 +31,10 @@ class ModelFromRegistry(wrappers.BaseModel):
             model_name: The name of the model to load.
            model_kwargs: The arguments used for instantiating the model.
            model_extra_kwargs: Extra arguments used for instantiating the model.
-
+            transforms: The transforms to apply to the output tensor
                produced by the model.
        """
-        super().__init__(
+        super().__init__(transforms=transforms)
 
         self._model_name = model_name
         self._model_kwargs = model_kwargs or {}
@@ -42,7 +44,10 @@ class ModelFromRegistry(wrappers.BaseModel):
 
     @override
     def load_model(self) -> None:
-        self._model =
-
+        self._model = factory.ModuleFactory(
+            registry=backbone_registry,
+            name=self._model_name,
+            init_args=self._model_kwargs | self._model_extra_kwargs,
         )
+
         ModelFromRegistry.__name__ = self._model_name
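A hedged usage sketch for the reworked wrapper: the keyword arguments come from the `__init__` signature above and the model name from the universal backbones diff, while the import path is inferred from the file list:

from eva.vision.models.wrappers.from_registry import ModelFromRegistry

model = ModelFromRegistry(
    model_name="universal/vit_small_patch16_224_dino",
    model_kwargs={"out_indices": None},   # forwarded to the registered backbone function
)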