kaiko-eva 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaiko-eva might be problematic.
- eva/core/callbacks/writers/embeddings/base.py +3 -4
- eva/core/data/dataloaders/dataloader.py +2 -2
- eva/core/data/splitting/random.py +6 -5
- eva/core/data/splitting/stratified.py +12 -6
- eva/core/losses/__init__.py +5 -0
- eva/core/losses/cross_entropy.py +27 -0
- eva/core/metrics/__init__.py +0 -4
- eva/core/metrics/defaults/__init__.py +0 -2
- eva/core/models/modules/module.py +9 -9
- eva/core/models/transforms/extract_cls_features.py +17 -9
- eva/core/models/transforms/extract_patch_features.py +23 -11
- eva/core/utils/progress_bar.py +15 -0
- eva/vision/data/datasets/__init__.py +4 -0
- eva/vision/data/datasets/classification/__init__.py +2 -1
- eva/vision/data/datasets/classification/camelyon16.py +4 -1
- eva/vision/data/datasets/classification/panda.py +17 -1
- eva/vision/data/datasets/classification/wsi.py +4 -1
- eva/vision/data/datasets/segmentation/__init__.py +2 -0
- eva/vision/data/datasets/segmentation/consep.py +2 -2
- eva/vision/data/datasets/segmentation/lits.py +49 -29
- eva/vision/data/datasets/segmentation/lits_balanced.py +93 -0
- eva/vision/data/datasets/segmentation/monusac.py +7 -7
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +2 -2
- eva/vision/data/datasets/wsi.py +37 -1
- eva/vision/data/wsi/patching/coordinates.py +9 -1
- eva/vision/data/wsi/patching/samplers/_utils.py +2 -8
- eva/vision/data/wsi/patching/samplers/random.py +4 -2
- eva/vision/losses/__init__.py +2 -2
- eva/vision/losses/dice.py +75 -8
- eva/vision/metrics/__init__.py +11 -0
- eva/vision/metrics/defaults/__init__.py +7 -0
- eva/{core → vision}/metrics/defaults/segmentation/__init__.py +1 -1
- eva/{core → vision}/metrics/defaults/segmentation/multiclass.py +2 -1
- eva/vision/metrics/segmentation/BUILD +1 -0
- eva/vision/metrics/segmentation/__init__.py +9 -0
- eva/vision/metrics/segmentation/_utils.py +69 -0
- eva/{core/metrics → vision/metrics/segmentation}/generalized_dice.py +12 -10
- eva/vision/metrics/segmentation/mean_iou.py +57 -0
- eva/vision/models/modules/semantic_segmentation.py +4 -3
- eva/vision/models/networks/backbones/_utils.py +12 -0
- eva/vision/models/networks/backbones/pathology/__init__.py +4 -1
- eva/vision/models/networks/backbones/pathology/histai.py +8 -2
- eva/vision/models/networks/backbones/pathology/mahmood.py +2 -9
- eva/vision/models/networks/backbones/pathology/owkin.py +14 -0
- eva/vision/models/networks/backbones/pathology/paige.py +51 -0
- eva/vision/models/networks/decoders/__init__.py +1 -1
- eva/vision/models/networks/decoders/segmentation/__init__.py +12 -4
- eva/vision/models/networks/decoders/segmentation/base.py +16 -0
- eva/vision/models/networks/decoders/segmentation/{conv2d.py → decoder2d.py} +26 -22
- eva/vision/models/networks/decoders/segmentation/linear.py +2 -2
- eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +12 -0
- eva/vision/models/networks/decoders/segmentation/{common.py → semantic/common.py} +3 -3
- eva/vision/models/networks/decoders/segmentation/semantic/with_image.py +94 -0
- eva/vision/models/networks/decoders/segmentation/typings.py +18 -0
- eva/vision/utils/io/__init__.py +7 -1
- eva/vision/utils/io/nifti.py +19 -4
- {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.3.dist-info}/METADATA +3 -34
- {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.3.dist-info}/RECORD +61 -48
- {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.3.dist-info}/WHEEL +1 -1
- eva/core/metrics/mean_iou.py +0 -120
- eva/vision/models/networks/decoders/decoder.py +0 -7
- {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.3.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.3.dist-info}/licenses/LICENSE +0 -0
eva/vision/metrics/segmentation/_utils.py
@@ -0,0 +1,69 @@
+"""Utils for segmentation metric collections."""
+
+from typing import Tuple
+
+import torch
+
+
+def apply_ignore_index(
+    preds: torch.Tensor, target: torch.Tensor, ignore_index: int, num_classes: int
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Applies the ignore index to the predictions and target tensors.
+
+    1. Masks the values in the target tensor that correspond to the ignored index.
+    2. Remove the channel corresponding to the ignored index from both tensors.
+
+    Args:
+        preds: The predictions tensor. Expected to be of shape `(N,C,...)`.
+        target: The target tensor. Expected to be of shape `(N,C,...)`.
+        ignore_index: The index to ignore.
+        num_classes: The number of classes.
+
+    Returns:
+        The modified predictions and target tensors of shape `(N,C-1,...)`.
+    """
+    if ignore_index < 0:
+        raise ValueError("ignore_index must be a non-negative integer")
+
+    ignore_mask = preds[:, ignore_index] == 1
+    target = target * (~ignore_mask.unsqueeze(1))
+
+    preds = _ignore_tensor_channel(preds, ignore_index)
+    target = _ignore_tensor_channel(target, ignore_index)
+
+    return preds, target
+
+
+def index_to_one_hot(tensor: torch.Tensor, num_classes: int) -> torch.Tensor:
+    """Converts an index tensor to a one-hot tensor.
+
+    Args:
+        tensor: The index tensor to convert. Expected to be of shape `(N,...)`.
+        num_classes: The number of classes to one-hot encode.
+
+    Returns:
+        A one-hot tensor of shape `(N,C,...)`.
+    """
+    if not _is_one_hot(tensor):
+        tensor = torch.nn.functional.one_hot(tensor.long(), num_classes=num_classes).movedim(-1, 1)
+    return tensor
+
+
+def _ignore_tensor_channel(tensor: torch.Tensor, ignore_index: int) -> torch.Tensor:
+    """Removes the channel corresponding to the specified ignore index.
+
+    Args:
+        tensor: The tensor to remove the channel from. Expected to be of shape `(N,C,...)`.
+        ignore_index: The index of the channel dimension (C) to remove.
+
+    Returns:
+        A tensor without the specified channel `(N,C-1,...)`.
+    """
+    if ignore_index < 0:
+        raise ValueError("ignore_index must be a non-negative integer")
+    return torch.cat([tensor[:, :ignore_index], tensor[:, ignore_index + 1 :]], dim=1)
+
+
+def _is_one_hot(tensor: torch.Tensor, expected_dim: int = 4) -> bool:
+    """Checks if the tensor is a one-hot tensor."""
+    return bool((tensor.bool() == tensor).all()) and tensor.ndim == expected_dim
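A minimal illustration of how these new helpers behave on small index masks (the shapes are chosen here for illustration, not taken from the diff):

import torch
from eva.vision.metrics.segmentation import _utils

preds = torch.randint(0, 3, (2, 8, 8))    # predicted index masks of shape (N, H, W)
target = torch.randint(0, 3, (2, 8, 8))   # ground-truth index masks of shape (N, H, W)

# Index tensors are expanded to one-hot (N, C, H, W); inputs that are already one-hot pass through.
preds_oh = _utils.index_to_one_hot(preds, num_classes=3)
target_oh = _utils.index_to_one_hot(target, num_classes=3)

# Pixels predicted as the ignored class are zeroed in the target, then that channel is dropped from both.
preds_oh, target_oh = _utils.apply_ignore_index(preds_oh, target_oh, ignore_index=0, num_classes=3)
print(preds_oh.shape, target_oh.shape)  # torch.Size([2, 2, 8, 8]) for both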
eva/{core/metrics → vision/metrics/segmentation}/generalized_dice.py
@@ -6,6 +6,8 @@ import torch
 from torchmetrics import segmentation
 from typing_extensions import override
 
+from eva.vision.metrics.segmentation import _utils
+
 
 class GeneralizedDiceScore(segmentation.GeneralizedDiceScore):
     """Defines the Generalized Dice Score.
@@ -30,8 +32,6 @@ class GeneralizedDiceScore(segmentation.GeneralizedDiceScore):
             include_background: Whether to include the background class in the computation
             weight_type: The type of weight to apply to each class. Can be one of `"square"`,
                 `"simple"`, or `"linear"`.
-            input_format: What kind of input the function receives. Choose between ``"one-hot"``
-                for one-hot encoded tensors or ``"index"`` for index tensors.
            ignore_index: Integer specifying a target class to ignore. If given, this class
                index does not contribute to the returned score, regardless of reduction method.
            per_class: Whether to compute the IoU for each class separately. If set to ``False``,
@@ -39,21 +39,23 @@ class GeneralizedDiceScore(segmentation.GeneralizedDiceScore):
             kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
         """
         super().__init__(
-            num_classes=num_classes
+            num_classes=num_classes
+            - (ignore_index is not None)
+            + (ignore_index == 0 and not include_background),
             include_background=include_background,
             weight_type=weight_type,
             per_class=per_class,
             **kwargs,
         )
-
+        self.orig_num_classes = num_classes
         self.ignore_index = ignore_index
 
     @override
     def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
+        preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes)
+        target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes)
         if self.ignore_index is not None:
-
-
-
-
-
-            super().update(preds=preds, target=target)
+            preds, target = _utils.apply_ignore_index(
+                preds, target, self.ignore_index, self.num_classes
+            )
+        super().update(preds=preds.long(), target=target.long())
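The adjusted num_classes expression above is easiest to read with concrete numbers; two worked cases, derived directly from the expression itself (include_background defaults to True in MeanIoU below):

num_classes, ignore_index, include_background = 4, 2, True
print(num_classes - (ignore_index is not None) + (ignore_index == 0 and not include_background))  # 3

num_classes, ignore_index, include_background = 4, 0, False
print(num_classes - (ignore_index is not None) + (ignore_index == 0 and not include_background))  # 4
# One class is subtracted whenever ignore_index is set; it is added back when the ignored class is the
# background and include_background is already False, since torchmetrics drops that class anyway.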
eva/vision/metrics/segmentation/mean_iou.py
@@ -0,0 +1,57 @@
+"""MeanIoU metric for semantic segmentation."""
+
+from typing import Any
+
+import torch
+from torchmetrics import segmentation
+from typing_extensions import override
+
+from eva.vision.metrics.segmentation import _utils
+
+
+class MeanIoU(segmentation.MeanIoU):
+    """MeanIoU (mIOU) metric for semantic segmentation.
+
+    It expands the `torchmetrics` class by including an `ignore_index`
+    functionality.
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        include_background: bool = True,
+        ignore_index: int | None = None,
+        per_class: bool = False,
+        **kwargs: Any,
+    ) -> None:
+        """Initializes the metric.
+
+        Args:
+            num_classes: The number of classes in the segmentation problem.
+            include_background: Whether to include the background class in the computation
+            ignore_index: Integer specifying a target class to ignore. If given, this class
+                index does not contribute to the returned score, regardless of reduction method.
+            per_class: Whether to compute the IoU for each class separately. If set to ``False``,
+                the metric will compute the mean IoU over all classes.
+            kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+        """
+        super().__init__(
+            num_classes=num_classes
+            - (ignore_index is not None)
+            + (ignore_index == 0 and not include_background),
+            include_background=include_background,
+            per_class=per_class,
+            **kwargs,
+        )
+        self.orig_num_classes = num_classes
+        self.ignore_index = ignore_index
+
+    @override
+    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
+        preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes)
+        target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes)
+        if self.ignore_index is not None:
+            preds, target = _utils.apply_ignore_index(
+                preds, target, self.ignore_index, self.num_classes
+            )
+        super().update(preds=preds.long(), target=target.long())
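A usage sketch for the new metric (the import path is assumed from the file location; the metric accepts index masks because update converts them to one-hot internally):

import torch
from eva.vision.metrics.segmentation.mean_iou import MeanIoU

metric = MeanIoU(num_classes=3, ignore_index=0)
preds = torch.randint(0, 3, (4, 16, 16))   # predicted index masks (N, H, W)
target = torch.randint(0, 3, (4, 16, 16))  # ground-truth index masks (N, H, W)
metric.update(preds, target)
print(metric.compute())                    # mean IoU over the two non-ignored classes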
eva/vision/models/modules/semantic_segmentation.py
@@ -15,6 +15,7 @@ from eva.core.models.modules.typings import INPUT_BATCH, INPUT_TENSOR_BATCH
 from eva.core.models.modules.utils import batch_postprocess, grad
 from eva.core.utils import parser
 from eva.vision.models.networks import decoders
+from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs
 
 
 class SemanticSegmentationModule(module.ModelModule):
@@ -101,9 +102,9 @@ class SemanticSegmentationModule(module.ModelModule):
                 "Please provide the expected `to_size` that the "
                 "decoder should map the embeddings (`inputs`) to."
             )
-
-
-        return self.decoder(
+        features = self.encoder(inputs) if self.encoder else inputs
+        decoder_inputs = DecoderInputs(features, inputs.shape[-2:], inputs)  # type: ignore
+        return self.decoder(decoder_inputs)
 
     @override
     def training_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
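The DecoderInputs container itself lives in eva/vision/models/networks/decoders/segmentation/typings.py (+18, not shown in this diff). From its usage here and in the Decoder2D hunks further down, it behaves like a three-field tuple; a rough sketch with assumed field names:

from typing import List, NamedTuple, Tuple
import torch

class DecoderInputs(NamedTuple):
    features: torch.Tensor | List[torch.Tensor]  # encoder output(s)
    image_size: Tuple[int, int]                  # target (height, width) to upscale the logits to
    images: torch.Tensor | None = None           # raw input images, e.g. for ConvDecoderWithImage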
eva/vision/models/networks/backbones/_utils.py
@@ -1,7 +1,9 @@
 """Utilis for backbone networks."""
 
+import os
 from typing import Any, Dict, Tuple
 
+import huggingface_hub
 from torch import nn
 
 from eva import models
@@ -37,3 +39,13 @@ def load_hugingface_model(
         tensor_transforms=tensor_transforms,
         model_kwargs=model_kwargs,
     )
+
+
+def huggingface_login(hf_token: str | None = None):
+    token = hf_token or os.environ.get("HF_TOKEN")
+    if not token:
+        raise ValueError(
+            "Please provide a HuggingFace token to download the model. "
+            "You can either pass it as an argument or set the env variable HF_TOKEN."
+        )
+    huggingface_hub.login(token=token)
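This extracts the HuggingFace authentication that was previously inlined in mahmood_uni (see the mahmood.py hunk below) into a shared helper; an illustrative call (the token value is a placeholder):

from eva.vision.models.networks.backbones import _utils

_utils.huggingface_login(hf_token="hf_xxx")  # or set the HF_TOKEN env variable and call it with no argument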
eva/vision/models/networks/backbones/pathology/__init__.py
@@ -12,7 +12,8 @@ from eva.vision.models.networks.backbones.pathology.kaiko import (
 )
 from eva.vision.models.networks.backbones.pathology.lunit import lunit_vits8, lunit_vits16
 from eva.vision.models.networks.backbones.pathology.mahmood import mahmood_uni
-from eva.vision.models.networks.backbones.pathology.owkin import owkin_phikon
+from eva.vision.models.networks.backbones.pathology.owkin import owkin_phikon, owkin_phikon_v2
+from eva.vision.models.networks.backbones.pathology.paige import paige_virchow2
 
 __all__ = [
     "kaiko_vitb16",
@@ -21,6 +22,7 @@ __all__ = [
     "kaiko_vits16",
     "kaiko_vits8",
     "owkin_phikon",
+    "owkin_phikon_v2",
     "lunit_vits16",
     "lunit_vits8",
     "mahmood_uni",
@@ -28,4 +30,5 @@ __all__ = [
     "prov_gigapath",
     "histai_hibou_b",
     "histai_hibou_l",
+    "paige_virchow2",
 ]
eva/vision/models/networks/backbones/pathology/histai.py
@@ -12,6 +12,9 @@ from eva.vision.models.networks.backbones.registry import register_model
 def histai_hibou_b(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
     """Initializes the hibou-B pathology FM by hist.ai (https://huggingface.co/histai/hibou-B).
 
+    Uses a customized implementation of the DINOv2 architecture from the transformers
+    library to add support for registers, which requires the trust_remote_code=True flag.
+
     Args:
         out_indices: Whether and which multi-level patch embeddings to return.
             Currently only out_indices=1 is supported.
@@ -23,7 +26,7 @@ def histai_hibou_b(out_indices: int | Tuple[int, ...] | None = None) -> nn.Modul
         model_name="histai/hibou-B",
         out_indices=out_indices,
         model_kwargs={"trust_remote_code": True},
-        transform_args={"
+        transform_args={"num_register_tokens": 4} if out_indices is not None else None,
     )
 
 
@@ -31,6 +34,9 @@ def histai_hibou_b(out_indices: int | Tuple[int, ...] | None = None) -> nn.Modul
 def histai_hibou_l(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
     """Initializes the hibou-L pathology FM by hist.ai (https://huggingface.co/histai/hibou-L).
 
+    Uses a customized implementation of the DINOv2 architecture from the transformers
+    library to add support for registers, which requires the trust_remote_code=True flag.
+
     Args:
         out_indices: Whether and which multi-level patch embeddings to return.
             Currently only out_indices=1 is supported.
@@ -42,5 +48,5 @@ def histai_hibou_l(out_indices: int | Tuple[int, ...] | None = None) -> nn.Modul
         model_name="histai/hibou-L",
         out_indices=out_indices,
         model_kwargs={"trust_remote_code": True},
-        transform_args={"
+        transform_args={"num_register_tokens": 4} if out_indices is not None else None,
     )
eva/vision/models/networks/backbones/pathology/mahmood.py
@@ -9,6 +9,7 @@ from loguru import logger
 from torch import nn
 
 from eva.vision.models import wrappers
+from eva.vision.models.networks.backbones import _utils
 from eva.vision.models.networks.backbones.registry import register_model
 
 
@@ -31,19 +32,11 @@ def mahmood_uni(
     Returns:
         The model instance.
     """
-    token = hf_token or os.environ.get("HF_TOKEN")
-    if not token:
-        raise ValueError(
-            "Please provide a HuggingFace token to download the model. "
-            "You can either pass it as an argument or set the env variable HF_TOKEN."
-        )
-
     checkpoint_path = os.path.join(download_dir, "pytorch_model.bin")
-
     if not os.path.exists(checkpoint_path):
         logger.info(f"Downloading the model checkpoint to {download_dir} ...")
         os.makedirs(download_dir, exist_ok=True)
-
+        _utils.huggingface_login(hf_token)
         huggingface_hub.hf_hub_download(
             "MahmoodLab/UNI",
             filename="pytorch_model.bin",
eva/vision/models/networks/backbones/pathology/owkin.py
@@ -20,3 +20,17 @@ def owkin_phikon(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
         The model instance.
     """
     return _utils.load_hugingface_model(model_name="owkin/phikon", out_indices=out_indices)
+
+
+@register_model("pathology/owkin_phikon_v2")
+def owkin_phikon_v2(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
+    """Initializes the phikon-v2 pathology FM by owkin (https://huggingface.co/owkin/phikon-v2).
+
+    Args:
+        out_indices: Whether and which multi-level patch embeddings to return.
+            Currently only out_indices=1 is supported.
+
+    Returns:
+        The model instance.
+    """
+    return _utils.load_hugingface_model(model_name="owkin/phikon-v2", out_indices=out_indices)
eva/vision/models/networks/backbones/pathology/paige.py
@@ -0,0 +1,51 @@
+"""Pathology FMs from paige.ai.
+
+Source: https://huggingface.co/paige-ai/
+"""
+
+from typing import Tuple
+
+import timm
+import torch.nn as nn
+
+from eva.core.models import transforms
+from eva.vision.models import wrappers
+from eva.vision.models.networks.backbones import _utils
+from eva.vision.models.networks.backbones.registry import register_model
+
+
+@register_model("pathology/paige_virchow2")
+def paige_virchow2(
+    dynamic_img_size: bool = True,
+    out_indices: int | Tuple[int, ...] | None = None,
+    hf_token: str | None = None,
+    include_patch_tokens: bool = False,
+) -> nn.Module:
+    """Initializes the Virchow2 pathology FM by paige.ai.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+        include_patch_tokens: Whether to combine the mean aggregated patch tokens with cls token.
+        hf_token: HuggingFace token to download the model.
+
+    Returns:
+        The model instance.
+    """
+    _utils.huggingface_login(hf_token)
+    return wrappers.TimmModel(
+        model_name="hf-hub:paige-ai/Virchow2",
+        out_indices=out_indices,
+        pretrained=True,
+        model_kwargs={
+            "dynamic_img_size": dynamic_img_size,
+            "mlp_layer": timm.layers.SwiGLUPacked,
+            "act_layer": nn.SiLU,
+        },
+        tensor_transforms=(
+            transforms.ExtractCLSFeatures(include_patch_tokens=include_patch_tokens)
+            if out_indices is None
+            else None
+        ),
+    )
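An illustrative way to load the new backbone (network access to huggingface.co and a valid token are required; the token below is a placeholder):

import torch
from eva.vision.models.networks.backbones.pathology.paige import paige_virchow2

backbone = paige_virchow2(hf_token="hf_xxx")         # downloads paige-ai/Virchow2 via timm
with torch.no_grad():
    features = backbone(torch.rand(1, 3, 224, 224))  # CLS-token features via ExtractCLSFeatures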
eva/vision/models/networks/decoders/__init__.py
@@ -1,6 +1,6 @@
 """Decoder heads API."""
 
 from eva.vision.models.networks.decoders import segmentation
-from eva.vision.models.networks.decoders.
+from eva.vision.models.networks.decoders.segmentation.base import Decoder
 
 __all__ = ["segmentation", "Decoder"]
eva/vision/models/networks/decoders/segmentation/__init__.py
@@ -1,11 +1,19 @@
 """Segmentation decoder heads API."""
 
-from eva.vision.models.networks.decoders.segmentation.
+from eva.vision.models.networks.decoders.segmentation.decoder2d import Decoder2D
+from eva.vision.models.networks.decoders.segmentation.linear import LinearDecoder
+from eva.vision.models.networks.decoders.segmentation.semantic import (
     ConvDecoder1x1,
     ConvDecoderMS,
+    ConvDecoderWithImage,
     SingleLinearDecoder,
 )
-from eva.vision.models.networks.decoders.segmentation.conv2d import ConvDecoder
-from eva.vision.models.networks.decoders.segmentation.linear import LinearDecoder
 
-__all__ = [
+__all__ = [
+    "ConvDecoder1x1",
+    "ConvDecoderMS",
+    "SingleLinearDecoder",
+    "ConvDecoderWithImage",
+    "Decoder2D",
+    "LinearDecoder",
+]
eva/vision/models/networks/decoders/segmentation/base.py
@@ -0,0 +1,16 @@
+"""Semantic segmentation decoder base class."""
+
+import abc
+
+import torch
+from torch import nn
+
+from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs
+
+
+class Decoder(nn.Module, abc.ABC):
+    """Abstract base class for segmentation decoders."""
+
+    @abc.abstractmethod
+    def forward(self, decoder_inputs: DecoderInputs) -> torch.Tensor:
+        """Forward pass of the decoder."""
eva/vision/models/networks/decoders/segmentation/{conv2d.py → decoder2d.py}
@@ -1,19 +1,20 @@
 """Convolutional based semantic segmentation decoder."""
 
-from typing import List, Tuple
+from typing import List, Sequence, Tuple
 
 import torch
 from torch import nn
 from torch.nn import functional
 
-from eva.vision.models.networks.decoders import
+from eva.vision.models.networks.decoders.segmentation import base
+from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs
 
 
-class
-    """
+class Decoder2D(base.Decoder):
+    """Segmentation decoder for 2D applications."""
 
-    def __init__(self, layers: nn.Module) -> None:
-        """Initializes the
+    def __init__(self, layers: nn.Module, combine_features: bool = True) -> None:
+        """Initializes the based decoder head.
 
         Here the input nn layers will be directly applied to the
         features of shape (batch_size, hidden_size, n_patches_height,
@@ -21,13 +22,16 @@ class ConvDecoder(decoder.Decoder):
         Note the n_patches is also known as grid_size.
 
         Args:
-            layers: The
+            layers: The layers to be used as the decoder head.
+            combine_features: Whether to combine the features from different
+                feature levels into one tensor before applying the decoder head.
         """
         super().__init__()
 
         self._layers = layers
+        self._combine_features = combine_features
 
-    def _forward_features(self, features: List[torch.Tensor]) -> torch.Tensor:
+    def _forward_features(self, features: torch.Tensor | List[torch.Tensor]) -> torch.Tensor:
         """Forward function for multi-level feature maps to a single one.
 
         It will interpolate the features and concat them into a single tensor
@@ -46,6 +50,8 @@ class ConvDecoder(decoder.Decoder):
             A tensor of shape (batch_size, hidden_size, n_patches_height,
             n_patches_width) which is feature map of the decoder head.
         """
+        if isinstance(features, torch.Tensor):
+            features = [features]
         if not isinstance(features, list) or features[0].ndim != 4:
             raise ValueError(
                 "Input features should be a list of four (4) dimensional inputs of "
@@ -63,7 +69,9 @@ class ConvDecoder(decoder.Decoder):
         ]
         return torch.cat(upsampled_features, dim=1)
 
-    def _forward_head(
+    def _forward_head(
+        self, patch_embeddings: torch.Tensor | Sequence[torch.Tensor]
+    ) -> torch.Tensor:
         """Forward of the decoder head.
 
         Args:
@@ -75,12 +83,12 @@ class ConvDecoder(decoder.Decoder):
         """
         return self._layers(patch_embeddings)
 
-    def
+    def _upscale(
         self,
         logits: torch.Tensor,
         image_size: Tuple[int, int],
     ) -> torch.Tensor:
-        """
+        """Upscales the calculated logits to the target image size.
 
         Args:
             logits: The decoder outputs of shape (batch_size, n_classes,
@@ -93,22 +101,18 @@ class ConvDecoder(decoder.Decoder):
         """
         return functional.interpolate(logits, image_size, mode="bilinear")
 
-    def forward(
-        self,
-        features: List[torch.Tensor],
-        image_size: Tuple[int, int],
-    ) -> torch.Tensor:
+    def forward(self, decoder_inputs: DecoderInputs) -> torch.Tensor:
         """Maps the patch embeddings to a segmentation mask of the image size.
 
         Args:
-
-                hidden_size, n_patches_height, n_patches_width).
-            image_size: The target image size (height, width).
+            decoder_inputs: Inputs required by the decoder.
 
         Returns:
             Tensor containing scores for all of the classes with shape
             (batch_size, n_classes, image_height, image_width).
         """
-
-
-
+        features, image_size, _ = DecoderInputs(*decoder_inputs)
+        if self._combine_features:
+            features = self._forward_features(features)
+        logits = self._forward_head(features)
+        return self._upscale(logits, image_size)
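A minimal sketch of calling the renamed decoder directly, assuming DecoderInputs unpacks as (features, image_size, images) as the hunks above suggest (shapes are illustrative):

import torch
from torch import nn
from eva.vision.models.networks.decoders.segmentation import Decoder2D
from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs

decoder = Decoder2D(layers=nn.Conv2d(384, 5, kernel_size=1))  # 384-dim patch embeddings, 5 classes
features = torch.rand(2, 384, 14, 14)                         # (batch, hidden, n_patches_h, n_patches_w)
logits = decoder(DecoderInputs(features, (224, 224), None))   # expected shape: (2, 5, 224, 224)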
eva/vision/models/networks/decoders/segmentation/linear.py
@@ -6,10 +6,10 @@ import torch
 from torch import nn
 from torch.nn import functional
 
-from eva.vision.models.networks.decoders import
+from eva.vision.models.networks.decoders.segmentation import base
 
 
-class LinearDecoder(
+class LinearDecoder(base.Decoder):
     """Linear decoder."""
 
     def __init__(self, layers: nn.Module) -> None:
eva/vision/models/networks/decoders/segmentation/semantic/__init__.py
@@ -0,0 +1,12 @@
+"""Semantic Segmentation decoder heads API."""
+
+from eva.vision.models.networks.decoders.segmentation.semantic.common import (
+    ConvDecoder1x1,
+    ConvDecoderMS,
+    SingleLinearDecoder,
+)
+from eva.vision.models.networks.decoders.segmentation.semantic.with_image import (
+    ConvDecoderWithImage,
+)
+
+__all__ = ["ConvDecoder1x1", "ConvDecoderMS", "SingleLinearDecoder", "ConvDecoderWithImage"]
eva/vision/models/networks/decoders/segmentation/{common.py → semantic/common.py}
@@ -7,10 +7,10 @@ output by an encoder into pixel-wise predictions for segmentation tasks.
 
 from torch import nn
 
-from eva.vision.models.networks.decoders.segmentation import
+from eva.vision.models.networks.decoders.segmentation import decoder2d, linear
 
 
-class ConvDecoder1x1(
+class ConvDecoder1x1(decoder2d.Decoder2D):
     """A convolutional decoder with a single 1x1 convolutional layer."""
 
     def __init__(self, in_features: int, num_classes: int) -> None:
@@ -29,7 +29,7 @@ class ConvDecoder1x1(conv2d.ConvDecoder):
         )
 
 
-class ConvDecoderMS(
+class ConvDecoderMS(decoder2d.Decoder2D):
     """A multi-stage convolutional decoder with upsampling and convolutional layers.
 
     This decoder applies a series of upsampling and convolutional layers to transform