kaiko-eva 0.1.1__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eva/core/callbacks/writers/embeddings/base.py +3 -4
- eva/core/data/dataloaders/dataloader.py +2 -2
- eva/core/data/splitting/random.py +6 -5
- eva/core/data/splitting/stratified.py +12 -6
- eva/core/losses/__init__.py +5 -0
- eva/core/losses/cross_entropy.py +27 -0
- eva/core/metrics/__init__.py +0 -4
- eva/core/metrics/defaults/__init__.py +0 -2
- eva/core/models/modules/module.py +9 -9
- eva/core/models/transforms/extract_cls_features.py +17 -9
- eva/core/models/transforms/extract_patch_features.py +23 -11
- eva/core/utils/io/__init__.py +2 -1
- eva/core/utils/io/gz.py +28 -0
- eva/core/utils/multiprocessing.py +46 -1
- eva/core/utils/progress_bar.py +15 -0
- eva/vision/callbacks/loggers/batch/segmentation.py +7 -4
- eva/vision/data/datasets/__init__.py +4 -0
- eva/vision/data/datasets/classification/__init__.py +2 -1
- eva/vision/data/datasets/classification/camelyon16.py +4 -1
- eva/vision/data/datasets/classification/panda.py +17 -1
- eva/vision/data/datasets/classification/wsi.py +4 -1
- eva/vision/data/datasets/segmentation/__init__.py +2 -0
- eva/vision/data/datasets/segmentation/consep.py +2 -2
- eva/vision/data/datasets/segmentation/lits.py +49 -29
- eva/vision/data/datasets/segmentation/lits_balanced.py +93 -0
- eva/vision/data/datasets/segmentation/monusac.py +7 -7
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +50 -18
- eva/vision/data/datasets/wsi.py +37 -1
- eva/vision/data/wsi/patching/coordinates.py +9 -1
- eva/vision/data/wsi/patching/samplers/_utils.py +2 -8
- eva/vision/data/wsi/patching/samplers/random.py +4 -2
- eva/vision/losses/__init__.py +2 -2
- eva/vision/losses/dice.py +75 -8
- eva/vision/metrics/__init__.py +11 -0
- eva/vision/metrics/defaults/__init__.py +7 -0
- eva/{core → vision}/metrics/defaults/segmentation/__init__.py +1 -1
- eva/{core → vision}/metrics/defaults/segmentation/multiclass.py +2 -1
- eva/vision/metrics/segmentation/BUILD +1 -0
- eva/vision/metrics/segmentation/__init__.py +9 -0
- eva/vision/metrics/segmentation/_utils.py +69 -0
- eva/{core/metrics → vision/metrics/segmentation}/generalized_dice.py +12 -10
- eva/vision/metrics/segmentation/mean_iou.py +57 -0
- eva/vision/models/modules/semantic_segmentation.py +4 -3
- eva/vision/models/networks/backbones/_utils.py +12 -0
- eva/vision/models/networks/backbones/pathology/__init__.py +4 -1
- eva/vision/models/networks/backbones/pathology/histai.py +8 -2
- eva/vision/models/networks/backbones/pathology/mahmood.py +2 -9
- eva/vision/models/networks/backbones/pathology/owkin.py +14 -0
- eva/vision/models/networks/backbones/pathology/paige.py +51 -0
- eva/vision/models/networks/decoders/__init__.py +1 -1
- eva/vision/models/networks/decoders/segmentation/__init__.py +12 -4
- eva/vision/models/networks/decoders/segmentation/base.py +16 -0
- eva/vision/models/networks/decoders/segmentation/{conv2d.py → decoder2d.py} +26 -22
- eva/vision/models/networks/decoders/segmentation/linear.py +2 -2
- eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +12 -0
- eva/vision/models/networks/decoders/segmentation/{common.py → semantic/common.py} +3 -3
- eva/vision/models/networks/decoders/segmentation/semantic/with_image.py +94 -0
- eva/vision/models/networks/decoders/segmentation/typings.py +18 -0
- eva/vision/utils/colormap.py +20 -0
- eva/vision/utils/io/__init__.py +7 -1
- eva/vision/utils/io/nifti.py +19 -4
- {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.5.dist-info}/METADATA +8 -39
- {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.5.dist-info}/RECORD +66 -52
- {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.5.dist-info}/WHEEL +1 -1
- eva/core/metrics/mean_iou.py +0 -120
- eva/vision/models/networks/decoders/decoder.py +0 -7
- {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.5.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.5.dist-info}/licenses/LICENSE +0 -0
eva/vision/losses/dice.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
|
1
|
-
"""Dice loss."""
|
|
1
|
+
"""Dice based loss functions."""
|
|
2
|
+
|
|
3
|
+
from typing import Sequence, Tuple
|
|
2
4
|
|
|
3
5
|
import torch
|
|
4
6
|
from monai import losses
|
|
@@ -12,29 +14,94 @@ class DiceLoss(losses.DiceLoss): # type: ignore
|
|
|
12
14
|
Extends the implementation from MONAI
|
|
13
15
|
- to support semantic target labels (meaning targets of shape BHW)
|
|
14
16
|
- to support `ignore_index` functionality
|
|
17
|
+
- accept weight argument in list format
|
|
15
18
|
"""
|
|
16
19
|
|
|
17
|
-
def __init__(
|
|
18
|
-
|
|
20
|
+
def __init__(
|
|
21
|
+
self,
|
|
22
|
+
*args,
|
|
23
|
+
ignore_index: int | None = None,
|
|
24
|
+
weight: Sequence[float] | torch.Tensor | None = None,
|
|
25
|
+
**kwargs,
|
|
26
|
+
) -> None:
|
|
27
|
+
"""Initialize the DiceLoss.
|
|
19
28
|
|
|
20
29
|
Args:
|
|
21
30
|
args: Positional arguments from the base class.
|
|
22
31
|
ignore_index: Specifies a target value that is ignored and
|
|
23
32
|
does not contribute to the input gradient.
|
|
33
|
+
weight: A list of weights to assign to each class.
|
|
24
34
|
kwargs: Key-word arguments from the base class.
|
|
25
35
|
"""
|
|
26
|
-
|
|
36
|
+
if weight is not None and not isinstance(weight, torch.Tensor):
|
|
37
|
+
weight = torch.tensor(weight)
|
|
38
|
+
|
|
39
|
+
super().__init__(*args, **kwargs, weight=weight)
|
|
27
40
|
|
|
28
41
|
self.ignore_index = ignore_index
|
|
29
42
|
|
|
30
43
|
@override
|
|
31
44
|
def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: # noqa
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
targets = targets * mask
|
|
35
|
-
inputs = torch.mul(inputs, mask.unsqueeze(1) if mask.ndim == 3 else mask)
|
|
45
|
+
inputs, targets = _apply_ignore_index(inputs, targets, self.ignore_index)
|
|
46
|
+
targets = _to_one_hot(targets, num_classes=inputs.shape[1])
|
|
36
47
|
|
|
37
48
|
if targets.ndim == 3:
|
|
38
49
|
targets = one_hot(targets[:, None, ...], num_classes=inputs.shape[1])
|
|
39
50
|
|
|
40
51
|
return super().forward(inputs, targets)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class DiceCELoss(losses.dice.DiceCELoss):
|
|
55
|
+
"""Combination of Dice and Cross Entropy Loss.
|
|
56
|
+
|
|
57
|
+
Extends the implementation from MONAI
|
|
58
|
+
- to support semantic target labels (meaning targets of shape BHW)
|
|
59
|
+
- to support `ignore_index` functionality
|
|
60
|
+
- accept weight argument in list format
|
|
61
|
+
"""
|
|
62
|
+
|
|
63
|
+
def __init__(
|
|
64
|
+
self,
|
|
65
|
+
*args,
|
|
66
|
+
ignore_index: int | None = None,
|
|
67
|
+
weight: Sequence[float] | torch.Tensor | None = None,
|
|
68
|
+
**kwargs,
|
|
69
|
+
) -> None:
|
|
70
|
+
"""Initialize the DiceCELoss.
|
|
71
|
+
|
|
72
|
+
Args:
|
|
73
|
+
args: Positional arguments from the base class.
|
|
74
|
+
ignore_index: Specifies a target value that is ignored and
|
|
75
|
+
does not contribute to the input gradient.
|
|
76
|
+
weight: A list of weights to assign to each class.
|
|
77
|
+
kwargs: Key-word arguments from the base class.
|
|
78
|
+
"""
|
|
79
|
+
if weight is not None and not isinstance(weight, torch.Tensor):
|
|
80
|
+
weight = torch.tensor(weight)
|
|
81
|
+
|
|
82
|
+
super().__init__(*args, **kwargs, weight=weight)
|
|
83
|
+
|
|
84
|
+
self.ignore_index = ignore_index
|
|
85
|
+
|
|
86
|
+
@override
|
|
87
|
+
def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: # noqa
|
|
88
|
+
inputs, targets = _apply_ignore_index(inputs, targets, self.ignore_index)
|
|
89
|
+
targets = _to_one_hot(targets, num_classes=inputs.shape[1])
|
|
90
|
+
|
|
91
|
+
return super().forward(inputs, targets)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _apply_ignore_index(
|
|
95
|
+
inputs: torch.Tensor, targets: torch.Tensor, ignore_index: int | None
|
|
96
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
97
|
+
if ignore_index is not None:
|
|
98
|
+
mask = targets != ignore_index
|
|
99
|
+
targets = targets * mask
|
|
100
|
+
inputs = torch.mul(inputs, mask.unsqueeze(1) if mask.ndim == 3 else mask)
|
|
101
|
+
return inputs, targets
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def _to_one_hot(tensor: torch.Tensor, num_classes: int) -> torch.Tensor:
|
|
105
|
+
if tensor.ndim == 3:
|
|
106
|
+
return one_hot(tensor[:, None, ...], num_classes=num_classes)
|
|
107
|
+
return tensor
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""Default metric collections API."""
|
|
2
|
+
|
|
3
|
+
from eva.vision.metrics.defaults.segmentation import MulticlassSegmentationMetrics
|
|
4
|
+
from eva.vision.metrics.segmentation.generalized_dice import GeneralizedDiceScore
|
|
5
|
+
from eva.vision.metrics.segmentation.mean_iou import MeanIoU
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"MulticlassSegmentationMetrics",
|
|
9
|
+
"GeneralizedDiceScore",
|
|
10
|
+
"MeanIoU",
|
|
11
|
+
]
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
"""Default segmentation metric collections API."""
|
|
2
2
|
|
|
3
|
-
from eva.
|
|
3
|
+
from eva.vision.metrics.defaults.segmentation.multiclass import MulticlassSegmentationMetrics
|
|
4
4
|
|
|
5
5
|
__all__ = ["MulticlassSegmentationMetrics"]
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""Default metric collection for multiclass semantic segmentation tasks."""
|
|
2
2
|
|
|
3
|
-
from eva.core.metrics import
|
|
3
|
+
from eva.core.metrics import structs
|
|
4
|
+
from eva.vision.metrics.segmentation import generalized_dice, mean_iou
|
|
4
5
|
|
|
5
6
|
|
|
6
7
|
class MulticlassSegmentationMetrics(structs.MetricCollection):
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
python_sources()
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"""Utils for segmentation metric collections."""
|
|
2
|
+
|
|
3
|
+
from typing import Tuple
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def apply_ignore_index(
|
|
9
|
+
preds: torch.Tensor, target: torch.Tensor, ignore_index: int, num_classes: int
|
|
10
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
11
|
+
"""Applies the ignore index to the predictions and target tensors.
|
|
12
|
+
|
|
13
|
+
1. Masks the values in the target tensor that correspond to the ignored index.
|
|
14
|
+
2. Remove the channel corresponding to the ignored index from both tensors.
|
|
15
|
+
|
|
16
|
+
Args:
|
|
17
|
+
preds: The predictions tensor. Expected to be of shape `(N,C,...)`.
|
|
18
|
+
target: The target tensor. Expected to be of shape `(N,C,...)`.
|
|
19
|
+
ignore_index: The index to ignore.
|
|
20
|
+
num_classes: The number of classes.
|
|
21
|
+
|
|
22
|
+
Returns:
|
|
23
|
+
The modified predictions and target tensors of shape `(N,C-1,...)`.
|
|
24
|
+
"""
|
|
25
|
+
if ignore_index < 0:
|
|
26
|
+
raise ValueError("ignore_index must be a non-negative integer")
|
|
27
|
+
|
|
28
|
+
ignore_mask = preds[:, ignore_index] == 1
|
|
29
|
+
target = target * (~ignore_mask.unsqueeze(1))
|
|
30
|
+
|
|
31
|
+
preds = _ignore_tensor_channel(preds, ignore_index)
|
|
32
|
+
target = _ignore_tensor_channel(target, ignore_index)
|
|
33
|
+
|
|
34
|
+
return preds, target
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def index_to_one_hot(tensor: torch.Tensor, num_classes: int) -> torch.Tensor:
|
|
38
|
+
"""Converts an index tensor to a one-hot tensor.
|
|
39
|
+
|
|
40
|
+
Args:
|
|
41
|
+
tensor: The index tensor to convert. Expected to be of shape `(N,...)`.
|
|
42
|
+
num_classes: The number of classes to one-hot encode.
|
|
43
|
+
|
|
44
|
+
Returns:
|
|
45
|
+
A one-hot tensor of shape `(N,C,...)`.
|
|
46
|
+
"""
|
|
47
|
+
if not _is_one_hot(tensor):
|
|
48
|
+
tensor = torch.nn.functional.one_hot(tensor.long(), num_classes=num_classes).movedim(-1, 1)
|
|
49
|
+
return tensor
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _ignore_tensor_channel(tensor: torch.Tensor, ignore_index: int) -> torch.Tensor:
|
|
53
|
+
"""Removes the channel corresponding to the specified ignore index.
|
|
54
|
+
|
|
55
|
+
Args:
|
|
56
|
+
tensor: The tensor to remove the channel from. Expected to be of shape `(N,C,...)`.
|
|
57
|
+
ignore_index: The index of the channel dimension (C) to remove.
|
|
58
|
+
|
|
59
|
+
Returns:
|
|
60
|
+
A tensor without the specified channel `(N,C-1,...)`.
|
|
61
|
+
"""
|
|
62
|
+
if ignore_index < 0:
|
|
63
|
+
raise ValueError("ignore_index must be a non-negative integer")
|
|
64
|
+
return torch.cat([tensor[:, :ignore_index], tensor[:, ignore_index + 1 :]], dim=1)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _is_one_hot(tensor: torch.Tensor, expected_dim: int = 4) -> bool:
|
|
68
|
+
"""Checks if the tensor is a one-hot tensor."""
|
|
69
|
+
return bool((tensor.bool() == tensor).all()) and tensor.ndim == expected_dim
|
|
@@ -6,6 +6,8 @@ import torch
|
|
|
6
6
|
from torchmetrics import segmentation
|
|
7
7
|
from typing_extensions import override
|
|
8
8
|
|
|
9
|
+
from eva.vision.metrics.segmentation import _utils
|
|
10
|
+
|
|
9
11
|
|
|
10
12
|
class GeneralizedDiceScore(segmentation.GeneralizedDiceScore):
|
|
11
13
|
"""Defines the Generalized Dice Score.
|
|
@@ -30,8 +32,6 @@ class GeneralizedDiceScore(segmentation.GeneralizedDiceScore):
|
|
|
30
32
|
include_background: Whether to include the background class in the computation
|
|
31
33
|
weight_type: The type of weight to apply to each class. Can be one of `"square"`,
|
|
32
34
|
`"simple"`, or `"linear"`.
|
|
33
|
-
input_format: What kind of input the function receives. Choose between ``"one-hot"``
|
|
34
|
-
for one-hot encoded tensors or ``"index"`` for index tensors.
|
|
35
35
|
ignore_index: Integer specifying a target class to ignore. If given, this class
|
|
36
36
|
index does not contribute to the returned score, regardless of reduction method.
|
|
37
37
|
per_class: Whether to compute the IoU for each class separately. If set to ``False``,
|
|
@@ -39,21 +39,23 @@ class GeneralizedDiceScore(segmentation.GeneralizedDiceScore):
|
|
|
39
39
|
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
|
|
40
40
|
"""
|
|
41
41
|
super().__init__(
|
|
42
|
-
num_classes=num_classes
|
|
42
|
+
num_classes=num_classes
|
|
43
|
+
- (ignore_index is not None)
|
|
44
|
+
+ (ignore_index == 0 and not include_background),
|
|
43
45
|
include_background=include_background,
|
|
44
46
|
weight_type=weight_type,
|
|
45
47
|
per_class=per_class,
|
|
46
48
|
**kwargs,
|
|
47
49
|
)
|
|
48
|
-
|
|
50
|
+
self.orig_num_classes = num_classes
|
|
49
51
|
self.ignore_index = ignore_index
|
|
50
52
|
|
|
51
53
|
@override
|
|
52
54
|
def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
|
|
55
|
+
preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes)
|
|
56
|
+
target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes)
|
|
53
57
|
if self.ignore_index is not None:
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
super().update(preds=preds, target=target)
|
|
58
|
+
preds, target = _utils.apply_ignore_index(
|
|
59
|
+
preds, target, self.ignore_index, self.num_classes
|
|
60
|
+
)
|
|
61
|
+
super().update(preds=preds.long(), target=target.long())
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""MeanIoU metric for semantic segmentation."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torchmetrics import segmentation
|
|
7
|
+
from typing_extensions import override
|
|
8
|
+
|
|
9
|
+
from eva.vision.metrics.segmentation import _utils
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class MeanIoU(segmentation.MeanIoU):
|
|
13
|
+
"""MeanIoU (mIOU) metric for semantic segmentation.
|
|
14
|
+
|
|
15
|
+
It expands the `torchmetrics` class by including an `ignore_index`
|
|
16
|
+
functionality.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
def __init__(
|
|
20
|
+
self,
|
|
21
|
+
num_classes: int,
|
|
22
|
+
include_background: bool = True,
|
|
23
|
+
ignore_index: int | None = None,
|
|
24
|
+
per_class: bool = False,
|
|
25
|
+
**kwargs: Any,
|
|
26
|
+
) -> None:
|
|
27
|
+
"""Initializes the metric.
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
num_classes: The number of classes in the segmentation problem.
|
|
31
|
+
include_background: Whether to include the background class in the computation
|
|
32
|
+
ignore_index: Integer specifying a target class to ignore. If given, this class
|
|
33
|
+
index does not contribute to the returned score, regardless of reduction method.
|
|
34
|
+
per_class: Whether to compute the IoU for each class separately. If set to ``False``,
|
|
35
|
+
the metric will compute the mean IoU over all classes.
|
|
36
|
+
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
|
|
37
|
+
"""
|
|
38
|
+
super().__init__(
|
|
39
|
+
num_classes=num_classes
|
|
40
|
+
- (ignore_index is not None)
|
|
41
|
+
+ (ignore_index == 0 and not include_background),
|
|
42
|
+
include_background=include_background,
|
|
43
|
+
per_class=per_class,
|
|
44
|
+
**kwargs,
|
|
45
|
+
)
|
|
46
|
+
self.orig_num_classes = num_classes
|
|
47
|
+
self.ignore_index = ignore_index
|
|
48
|
+
|
|
49
|
+
@override
|
|
50
|
+
def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
|
|
51
|
+
preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes)
|
|
52
|
+
target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes)
|
|
53
|
+
if self.ignore_index is not None:
|
|
54
|
+
preds, target = _utils.apply_ignore_index(
|
|
55
|
+
preds, target, self.ignore_index, self.num_classes
|
|
56
|
+
)
|
|
57
|
+
super().update(preds=preds.long(), target=target.long())
|
|
@@ -15,6 +15,7 @@ from eva.core.models.modules.typings import INPUT_BATCH, INPUT_TENSOR_BATCH
|
|
|
15
15
|
from eva.core.models.modules.utils import batch_postprocess, grad
|
|
16
16
|
from eva.core.utils import parser
|
|
17
17
|
from eva.vision.models.networks import decoders
|
|
18
|
+
from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs
|
|
18
19
|
|
|
19
20
|
|
|
20
21
|
class SemanticSegmentationModule(module.ModelModule):
|
|
@@ -101,9 +102,9 @@ class SemanticSegmentationModule(module.ModelModule):
|
|
|
101
102
|
"Please provide the expected `to_size` that the "
|
|
102
103
|
"decoder should map the embeddings (`inputs`) to."
|
|
103
104
|
)
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
return self.decoder(
|
|
105
|
+
features = self.encoder(inputs) if self.encoder else inputs
|
|
106
|
+
decoder_inputs = DecoderInputs(features, to_size or inputs.shape[-2:], inputs) # type: ignore
|
|
107
|
+
return self.decoder(decoder_inputs)
|
|
107
108
|
|
|
108
109
|
@override
|
|
109
110
|
def training_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
|
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
"""Utilis for backbone networks."""
|
|
2
2
|
|
|
3
|
+
import os
|
|
3
4
|
from typing import Any, Dict, Tuple
|
|
4
5
|
|
|
6
|
+
import huggingface_hub
|
|
5
7
|
from torch import nn
|
|
6
8
|
|
|
7
9
|
from eva import models
|
|
@@ -37,3 +39,13 @@ def load_hugingface_model(
|
|
|
37
39
|
tensor_transforms=tensor_transforms,
|
|
38
40
|
model_kwargs=model_kwargs,
|
|
39
41
|
)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def huggingface_login(hf_token: str | None = None):
|
|
45
|
+
token = hf_token or os.environ.get("HF_TOKEN")
|
|
46
|
+
if not token:
|
|
47
|
+
raise ValueError(
|
|
48
|
+
"Please provide a HuggingFace token to download the model. "
|
|
49
|
+
"You can either pass it as an argument or set the env variable HF_TOKEN."
|
|
50
|
+
)
|
|
51
|
+
huggingface_hub.login(token=token)
|
|
@@ -12,7 +12,8 @@ from eva.vision.models.networks.backbones.pathology.kaiko import (
|
|
|
12
12
|
)
|
|
13
13
|
from eva.vision.models.networks.backbones.pathology.lunit import lunit_vits8, lunit_vits16
|
|
14
14
|
from eva.vision.models.networks.backbones.pathology.mahmood import mahmood_uni
|
|
15
|
-
from eva.vision.models.networks.backbones.pathology.owkin import owkin_phikon
|
|
15
|
+
from eva.vision.models.networks.backbones.pathology.owkin import owkin_phikon, owkin_phikon_v2
|
|
16
|
+
from eva.vision.models.networks.backbones.pathology.paige import paige_virchow2
|
|
16
17
|
|
|
17
18
|
__all__ = [
|
|
18
19
|
"kaiko_vitb16",
|
|
@@ -21,6 +22,7 @@ __all__ = [
|
|
|
21
22
|
"kaiko_vits16",
|
|
22
23
|
"kaiko_vits8",
|
|
23
24
|
"owkin_phikon",
|
|
25
|
+
"owkin_phikon_v2",
|
|
24
26
|
"lunit_vits16",
|
|
25
27
|
"lunit_vits8",
|
|
26
28
|
"mahmood_uni",
|
|
@@ -28,4 +30,5 @@ __all__ = [
|
|
|
28
30
|
"prov_gigapath",
|
|
29
31
|
"histai_hibou_b",
|
|
30
32
|
"histai_hibou_l",
|
|
33
|
+
"paige_virchow2",
|
|
31
34
|
]
|
|
@@ -12,6 +12,9 @@ from eva.vision.models.networks.backbones.registry import register_model
|
|
|
12
12
|
def histai_hibou_b(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
|
|
13
13
|
"""Initializes the hibou-B pathology FM by hist.ai (https://huggingface.co/histai/hibou-B).
|
|
14
14
|
|
|
15
|
+
Uses a customized implementation of the DINOv2 architecture from the transformers
|
|
16
|
+
library to add support for registers, which requires the trust_remote_code=True flag.
|
|
17
|
+
|
|
15
18
|
Args:
|
|
16
19
|
out_indices: Whether and which multi-level patch embeddings to return.
|
|
17
20
|
Currently only out_indices=1 is supported.
|
|
@@ -23,7 +26,7 @@ def histai_hibou_b(out_indices: int | Tuple[int, ...] | None = None) -> nn.Modul
|
|
|
23
26
|
model_name="histai/hibou-B",
|
|
24
27
|
out_indices=out_indices,
|
|
25
28
|
model_kwargs={"trust_remote_code": True},
|
|
26
|
-
transform_args={"
|
|
29
|
+
transform_args={"num_register_tokens": 4} if out_indices is not None else None,
|
|
27
30
|
)
|
|
28
31
|
|
|
29
32
|
|
|
@@ -31,6 +34,9 @@ def histai_hibou_b(out_indices: int | Tuple[int, ...] | None = None) -> nn.Modul
|
|
|
31
34
|
def histai_hibou_l(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
|
|
32
35
|
"""Initializes the hibou-L pathology FM by hist.ai (https://huggingface.co/histai/hibou-L).
|
|
33
36
|
|
|
37
|
+
Uses a customized implementation of the DINOv2 architecture from the transformers
|
|
38
|
+
library to add support for registers, which requires the trust_remote_code=True flag.
|
|
39
|
+
|
|
34
40
|
Args:
|
|
35
41
|
out_indices: Whether and which multi-level patch embeddings to return.
|
|
36
42
|
Currently only out_indices=1 is supported.
|
|
@@ -42,5 +48,5 @@ def histai_hibou_l(out_indices: int | Tuple[int, ...] | None = None) -> nn.Modul
|
|
|
42
48
|
model_name="histai/hibou-L",
|
|
43
49
|
out_indices=out_indices,
|
|
44
50
|
model_kwargs={"trust_remote_code": True},
|
|
45
|
-
transform_args={"
|
|
51
|
+
transform_args={"num_register_tokens": 4} if out_indices is not None else None,
|
|
46
52
|
)
|
|
@@ -9,6 +9,7 @@ from loguru import logger
|
|
|
9
9
|
from torch import nn
|
|
10
10
|
|
|
11
11
|
from eva.vision.models import wrappers
|
|
12
|
+
from eva.vision.models.networks.backbones import _utils
|
|
12
13
|
from eva.vision.models.networks.backbones.registry import register_model
|
|
13
14
|
|
|
14
15
|
|
|
@@ -31,19 +32,11 @@ def mahmood_uni(
|
|
|
31
32
|
Returns:
|
|
32
33
|
The model instance.
|
|
33
34
|
"""
|
|
34
|
-
token = hf_token or os.environ.get("HF_TOKEN")
|
|
35
|
-
if not token:
|
|
36
|
-
raise ValueError(
|
|
37
|
-
"Please provide a HuggingFace token to download the model. "
|
|
38
|
-
"You can either pass it as an argument or set the env variable HF_TOKEN."
|
|
39
|
-
)
|
|
40
|
-
|
|
41
35
|
checkpoint_path = os.path.join(download_dir, "pytorch_model.bin")
|
|
42
|
-
|
|
43
36
|
if not os.path.exists(checkpoint_path):
|
|
44
37
|
logger.info(f"Downloading the model checkpoint to {download_dir} ...")
|
|
45
38
|
os.makedirs(download_dir, exist_ok=True)
|
|
46
|
-
|
|
39
|
+
_utils.huggingface_login(hf_token)
|
|
47
40
|
huggingface_hub.hf_hub_download(
|
|
48
41
|
"MahmoodLab/UNI",
|
|
49
42
|
filename="pytorch_model.bin",
|
|
@@ -20,3 +20,17 @@ def owkin_phikon(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
|
|
|
20
20
|
The model instance.
|
|
21
21
|
"""
|
|
22
22
|
return _utils.load_hugingface_model(model_name="owkin/phikon", out_indices=out_indices)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@register_model("pathology/owkin_phikon_v2")
|
|
26
|
+
def owkin_phikon_v2(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
|
|
27
|
+
"""Initializes the phikon-v2 pathology FM by owkin (https://huggingface.co/owkin/phikon-v2).
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
out_indices: Whether and which multi-level patch embeddings to return.
|
|
31
|
+
Currently only out_indices=1 is supported.
|
|
32
|
+
|
|
33
|
+
Returns:
|
|
34
|
+
The model instance.
|
|
35
|
+
"""
|
|
36
|
+
return _utils.load_hugingface_model(model_name="owkin/phikon-v2", out_indices=out_indices)
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
"""Pathology FMs from paige.ai.
|
|
2
|
+
|
|
3
|
+
Source: https://huggingface.co/paige-ai/
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import Tuple
|
|
7
|
+
|
|
8
|
+
import timm
|
|
9
|
+
import torch.nn as nn
|
|
10
|
+
|
|
11
|
+
from eva.core.models import transforms
|
|
12
|
+
from eva.vision.models import wrappers
|
|
13
|
+
from eva.vision.models.networks.backbones import _utils
|
|
14
|
+
from eva.vision.models.networks.backbones.registry import register_model
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@register_model("pathology/paige_virchow2")
|
|
18
|
+
def paige_virchow2(
|
|
19
|
+
dynamic_img_size: bool = True,
|
|
20
|
+
out_indices: int | Tuple[int, ...] | None = None,
|
|
21
|
+
hf_token: str | None = None,
|
|
22
|
+
include_patch_tokens: bool = False,
|
|
23
|
+
) -> nn.Module:
|
|
24
|
+
"""Initializes the Virchow2 pathology FM by paige.ai.
|
|
25
|
+
|
|
26
|
+
Args:
|
|
27
|
+
dynamic_img_size: Support different input image sizes by allowing to change
|
|
28
|
+
the grid size (interpolate abs and/or ROPE pos) in the forward pass.
|
|
29
|
+
out_indices: Whether and which multi-level patch embeddings to return.
|
|
30
|
+
include_patch_tokens: Whether to combine the mean aggregated patch tokens with cls token.
|
|
31
|
+
hf_token: HuggingFace token to download the model.
|
|
32
|
+
|
|
33
|
+
Returns:
|
|
34
|
+
The model instance.
|
|
35
|
+
"""
|
|
36
|
+
_utils.huggingface_login(hf_token)
|
|
37
|
+
return wrappers.TimmModel(
|
|
38
|
+
model_name="hf-hub:paige-ai/Virchow2",
|
|
39
|
+
out_indices=out_indices,
|
|
40
|
+
pretrained=True,
|
|
41
|
+
model_kwargs={
|
|
42
|
+
"dynamic_img_size": dynamic_img_size,
|
|
43
|
+
"mlp_layer": timm.layers.SwiGLUPacked,
|
|
44
|
+
"act_layer": nn.SiLU,
|
|
45
|
+
},
|
|
46
|
+
tensor_transforms=(
|
|
47
|
+
transforms.ExtractCLSFeatures(include_patch_tokens=include_patch_tokens)
|
|
48
|
+
if out_indices is None
|
|
49
|
+
else None
|
|
50
|
+
),
|
|
51
|
+
)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Decoder heads API."""
|
|
2
2
|
|
|
3
3
|
from eva.vision.models.networks.decoders import segmentation
|
|
4
|
-
from eva.vision.models.networks.decoders.
|
|
4
|
+
from eva.vision.models.networks.decoders.segmentation.base import Decoder
|
|
5
5
|
|
|
6
6
|
__all__ = ["segmentation", "Decoder"]
|
|
@@ -1,11 +1,19 @@
|
|
|
1
1
|
"""Segmentation decoder heads API."""
|
|
2
2
|
|
|
3
|
-
from eva.vision.models.networks.decoders.segmentation.
|
|
3
|
+
from eva.vision.models.networks.decoders.segmentation.decoder2d import Decoder2D
|
|
4
|
+
from eva.vision.models.networks.decoders.segmentation.linear import LinearDecoder
|
|
5
|
+
from eva.vision.models.networks.decoders.segmentation.semantic import (
|
|
4
6
|
ConvDecoder1x1,
|
|
5
7
|
ConvDecoderMS,
|
|
8
|
+
ConvDecoderWithImage,
|
|
6
9
|
SingleLinearDecoder,
|
|
7
10
|
)
|
|
8
|
-
from eva.vision.models.networks.decoders.segmentation.conv2d import ConvDecoder
|
|
9
|
-
from eva.vision.models.networks.decoders.segmentation.linear import LinearDecoder
|
|
10
11
|
|
|
11
|
-
__all__ = [
|
|
12
|
+
__all__ = [
|
|
13
|
+
"ConvDecoder1x1",
|
|
14
|
+
"ConvDecoderMS",
|
|
15
|
+
"SingleLinearDecoder",
|
|
16
|
+
"ConvDecoderWithImage",
|
|
17
|
+
"Decoder2D",
|
|
18
|
+
"LinearDecoder",
|
|
19
|
+
]
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""Semantic segmentation decoder base class."""
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torch import nn
|
|
7
|
+
|
|
8
|
+
from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Decoder(nn.Module, abc.ABC):
|
|
12
|
+
"""Abstract base class for segmentation decoders."""
|
|
13
|
+
|
|
14
|
+
@abc.abstractmethod
|
|
15
|
+
def forward(self, decoder_inputs: DecoderInputs) -> torch.Tensor:
|
|
16
|
+
"""Forward pass of the decoder."""
|