kaiko-eva 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eva/core/callbacks/__init__.py +3 -2
- eva/core/callbacks/config.py +143 -0
- eva/core/callbacks/writers/__init__.py +6 -3
- eva/core/callbacks/writers/embeddings/__init__.py +6 -0
- eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
- eva/core/callbacks/writers/embeddings/base.py +192 -0
- eva/core/callbacks/writers/embeddings/classification.py +117 -0
- eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
- eva/core/callbacks/writers/embeddings/typings.py +38 -0
- eva/core/data/datasets/__init__.py +10 -2
- eva/core/data/datasets/classification/__init__.py +5 -2
- eva/core/data/datasets/classification/embeddings.py +15 -135
- eva/core/data/datasets/classification/multi_embeddings.py +110 -0
- eva/core/data/datasets/embeddings.py +167 -0
- eva/core/data/splitting/__init__.py +6 -0
- eva/core/data/splitting/random.py +41 -0
- eva/core/data/splitting/stratified.py +56 -0
- eva/core/data/transforms/__init__.py +3 -1
- eva/core/data/transforms/padding/__init__.py +5 -0
- eva/core/data/transforms/padding/pad_2d_tensor.py +38 -0
- eva/core/data/transforms/sampling/__init__.py +5 -0
- eva/core/data/transforms/sampling/sample_from_axis.py +40 -0
- eva/core/loggers/__init__.py +7 -0
- eva/core/loggers/dummy.py +38 -0
- eva/core/loggers/experimental_loggers.py +8 -0
- eva/core/loggers/log/__init__.py +6 -0
- eva/core/loggers/log/image.py +71 -0
- eva/core/loggers/log/parameters.py +74 -0
- eva/core/loggers/log/utils.py +13 -0
- eva/core/loggers/loggers.py +6 -0
- eva/core/metrics/__init__.py +6 -2
- eva/core/metrics/defaults/__init__.py +10 -3
- eva/core/metrics/defaults/classification/__init__.py +1 -1
- eva/core/metrics/defaults/classification/binary.py +0 -9
- eva/core/metrics/defaults/classification/multiclass.py +0 -8
- eva/core/metrics/defaults/segmentation/__init__.py +5 -0
- eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
- eva/core/metrics/generalized_dice.py +59 -0
- eva/core/metrics/mean_iou.py +120 -0
- eva/core/metrics/structs/schemas.py +3 -1
- eva/core/models/__init__.py +3 -1
- eva/core/models/modules/head.py +16 -15
- eva/core/models/modules/module.py +25 -1
- eva/core/models/modules/typings.py +14 -1
- eva/core/models/modules/utils/batch_postprocess.py +37 -5
- eva/core/models/networks/__init__.py +1 -2
- eva/core/models/networks/mlp.py +2 -2
- eva/core/models/transforms/__init__.py +6 -0
- eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
- eva/core/models/transforms/extract_patch_features.py +47 -0
- eva/core/models/wrappers/__init__.py +13 -0
- eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
- eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
- eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
- eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
- eva/core/trainers/_recorder.py +69 -7
- eva/core/trainers/functional.py +23 -5
- eva/core/trainers/trainer.py +20 -6
- eva/core/utils/__init__.py +6 -0
- eva/core/utils/clone.py +27 -0
- eva/core/utils/memory.py +28 -0
- eva/core/utils/operations.py +26 -0
- eva/core/utils/parser.py +20 -0
- eva/vision/__init__.py +2 -2
- eva/vision/callbacks/__init__.py +5 -0
- eva/vision/callbacks/loggers/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/base.py +130 -0
- eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
- eva/vision/data/datasets/__init__.py +24 -4
- eva/vision/data/datasets/_utils.py +3 -3
- eva/vision/data/datasets/_validators.py +15 -2
- eva/vision/data/datasets/classification/__init__.py +6 -2
- eva/vision/data/datasets/classification/bach.py +10 -15
- eva/vision/data/datasets/classification/base.py +17 -24
- eva/vision/data/datasets/classification/camelyon16.py +244 -0
- eva/vision/data/datasets/classification/crc.py +10 -15
- eva/vision/data/datasets/classification/mhist.py +10 -15
- eva/vision/data/datasets/classification/panda.py +184 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
- eva/vision/data/datasets/classification/wsi.py +105 -0
- eva/vision/data/datasets/segmentation/__init__.py +15 -2
- eva/vision/data/datasets/segmentation/_utils.py +38 -0
- eva/vision/data/datasets/segmentation/base.py +31 -47
- eva/vision/data/datasets/segmentation/bcss.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +156 -0
- eva/vision/data/datasets/segmentation/embeddings.py +34 -0
- eva/vision/data/datasets/segmentation/lits.py +178 -0
- eva/vision/data/datasets/segmentation/monusac.py +236 -0
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +325 -0
- eva/vision/data/datasets/wsi.py +187 -0
- eva/vision/data/transforms/__init__.py +3 -2
- eva/vision/data/transforms/common/__init__.py +2 -1
- eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
- eva/vision/data/transforms/common/resize_and_crop.py +6 -7
- eva/vision/data/transforms/normalization/__init__.py +6 -0
- eva/vision/data/transforms/normalization/clamp.py +43 -0
- eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
- eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
- eva/vision/data/wsi/__init__.py +16 -0
- eva/vision/data/wsi/backends/__init__.py +69 -0
- eva/vision/data/wsi/backends/base.py +115 -0
- eva/vision/data/wsi/backends/openslide.py +73 -0
- eva/vision/data/wsi/backends/pil.py +52 -0
- eva/vision/data/wsi/backends/tiffslide.py +42 -0
- eva/vision/data/wsi/patching/__init__.py +6 -0
- eva/vision/data/wsi/patching/coordinates.py +98 -0
- eva/vision/data/wsi/patching/mask.py +123 -0
- eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
- eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
- eva/vision/data/wsi/patching/samplers/base.py +48 -0
- eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
- eva/vision/data/wsi/patching/samplers/grid.py +47 -0
- eva/vision/data/wsi/patching/samplers/random.py +41 -0
- eva/vision/losses/__init__.py +5 -0
- eva/vision/losses/dice.py +40 -0
- eva/vision/models/__init__.py +4 -2
- eva/vision/models/modules/__init__.py +5 -0
- eva/vision/models/modules/semantic_segmentation.py +161 -0
- eva/vision/models/networks/__init__.py +1 -2
- eva/vision/models/networks/backbones/__init__.py +6 -0
- eva/vision/models/networks/backbones/_utils.py +39 -0
- eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
- eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
- eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
- eva/vision/models/networks/backbones/pathology/histai.py +46 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
- eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
- eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
- eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
- eva/vision/models/networks/backbones/registry.py +47 -0
- eva/vision/models/networks/backbones/timm/__init__.py +5 -0
- eva/vision/models/networks/backbones/timm/backbones.py +54 -0
- eva/vision/models/networks/backbones/universal/__init__.py +8 -0
- eva/vision/models/networks/backbones/universal/vit.py +54 -0
- eva/vision/models/networks/decoders/__init__.py +6 -0
- eva/vision/models/networks/decoders/decoder.py +7 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
- eva/vision/models/networks/decoders/segmentation/common.py +74 -0
- eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
- eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
- eva/vision/models/wrappers/__init__.py +6 -0
- eva/vision/models/wrappers/from_registry.py +48 -0
- eva/vision/models/wrappers/from_timm.py +68 -0
- eva/vision/utils/colormap.py +77 -0
- eva/vision/utils/convert.py +67 -0
- eva/vision/utils/io/__init__.py +10 -4
- eva/vision/utils/io/image.py +21 -2
- eva/vision/utils/io/mat.py +36 -0
- eva/vision/utils/io/nifti.py +40 -15
- eva/vision/utils/io/text.py +10 -3
- kaiko_eva-0.1.0.dist-info/METADATA +553 -0
- kaiko_eva-0.1.0.dist-info/RECORD +205 -0
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
- eva/core/callbacks/writers/embeddings.py +0 -169
- eva/core/callbacks/writers/typings.py +0 -23
- eva/core/models/networks/transforms/__init__.py +0 -5
- eva/core/models/networks/wrappers/__init__.py +0 -8
- eva/vision/data/datasets/classification/total_segmentator.py +0 -213
- eva/vision/data/datasets/segmentation/total_segmentator.py +0 -212
- eva/vision/models/networks/postprocesses/__init__.py +0 -5
- eva/vision/models/networks/postprocesses/cls.py +0 -25
- kaiko_eva-0.0.1.dist-info/METADATA +0 -405
- kaiko_eva-0.0.1.dist-info/RECORD +0 -110
- /eva/core/models/{networks → wrappers}/_utils.py +0 -0
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
eva/vision/data/wsi/patching/samplers/random.py
ADDED
@@ -0,0 +1,41 @@
+"""Random sampler."""
+
+import random
+from typing import Generator, Tuple
+
+from eva.vision.data.wsi.patching.samplers import _utils, base
+
+
+class RandomSampler(base.Sampler):
+    """Sample patch coordinates randomly.
+
+    Args:
+        n_samples: The number of samples to return.
+        seed: The random seed.
+    """
+
+    def __init__(self, n_samples: int = 1, seed: int = 42):
+        """Initializes the sampler."""
+        self.seed = seed
+        self.n_samples = n_samples
+
+    def sample(
+        self,
+        width: int,
+        height: int,
+        layer_shape: Tuple[int, int],
+    ) -> Generator[Tuple[int, int], None, None]:
+        """Sample random patches.
+
+        Args:
+            width: The width of the patches.
+            height: The height of the patches.
+            layer_shape: The shape of the layer.
+        """
+        _utils.validate_dimensions(width, height, layer_shape)
+        _utils.set_seed(self.seed)
+
+        x_max, y_max = layer_shape[0], layer_shape[1]
+        for _ in range(self.n_samples):
+            x, y = random.randint(0, x_max - width), random.randint(0, y_max - height)  # nosec
+            yield x, y
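For orientation, a minimal usage sketch of the new sampler (assuming `RandomSampler` is re-exported from the samplers package `__init__.py` listed above; the sizes are illustrative):

from eva.vision.data.wsi.patching.samplers import RandomSampler

sampler = RandomSampler(n_samples=3, seed=42)
# Yields three top-left (x, y) patch coordinates within the layer bounds.
for x, y in sampler.sample(width=224, height=224, layer_shape=(10_000, 8_000)):
    print(x, y)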
eva/vision/losses/dice.py
ADDED
@@ -0,0 +1,40 @@
+"""Dice loss."""
+
+import torch
+from monai import losses
+from monai.networks import one_hot  # type: ignore
+from typing_extensions import override
+
+
+class DiceLoss(losses.DiceLoss):  # type: ignore
+    """Computes the average Dice loss between two tensors.
+
+    Extends the implementation from MONAI
+    - to support semantic target labels (meaning targets of shape BHW)
+    - to support `ignore_index` functionality
+    """
+
+    def __init__(self, *args, ignore_index: int | None = None, **kwargs) -> None:
+        """Initialize the DiceLoss with support for ignore_index.
+
+        Args:
+            args: Positional arguments from the base class.
+            ignore_index: Specifies a target value that is ignored and
+                does not contribute to the input gradient.
+            kwargs: Keyword arguments from the base class.
+        """
+        super().__init__(*args, **kwargs)
+
+        self.ignore_index = ignore_index
+
+    @override
+    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:  # noqa
+        if self.ignore_index is not None:
+            mask = targets != self.ignore_index
+            targets = targets * mask
+            inputs = torch.mul(inputs, mask.unsqueeze(1) if mask.ndim == 3 else mask)
+
+        if targets.ndim == 3:
+            targets = one_hot(targets[:, None, ...], num_classes=inputs.shape[1])
+
+        return super().forward(inputs, targets)
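A short sketch of the `ignore_index` behavior added here (assuming `DiceLoss` is exported from `eva.vision.losses`; the shapes and the label value 255 are illustrative):

import torch
from eva.vision.losses import DiceLoss

criterion = DiceLoss(softmax=True, ignore_index=255)
logits = torch.randn(2, 4, 64, 64)           # B x C x H x W predictions
targets = torch.randint(0, 4, (2, 64, 64))   # B x H x W semantic labels
targets[0, :8, :8] = 255                     # this region is masked out of the loss
loss = criterion(logits, targets)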
eva/vision/models/__init__.py
CHANGED
@@ -1,5 +1,7 @@
 """Vision Models API."""
 
-from eva.vision.models import networks
+from eva.vision.models import networks, wrappers
+from eva.vision.models.networks import backbones
+from eva.vision.models.wrappers import ModelFromRegistry, TimmModel
 
-__all__ = ["networks"]
+__all__ = ["networks", "wrappers", "backbones", "ModelFromRegistry", "TimmModel"]
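With these exports, a registered backbone can be resolved by its registry key; a minimal sketch (the exact `ModelFromRegistry` constructor signature is not shown in this hunk and is an assumption):

from eva.vision.models import ModelFromRegistry

# "pathology/kaiko_vits16" is one of the keys registered later in this diff.
backbone = ModelFromRegistry(model_name="pathology/kaiko_vits16")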
eva/vision/models/modules/semantic_segmentation.py
ADDED
@@ -0,0 +1,161 @@
+"""Neural Network Semantic Segmentation Module."""
+
+from typing import Any, Callable, Dict, Iterable, List, Tuple
+
+import torch
+from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+from torch import nn, optim
+from torch.optim import lr_scheduler
+from typing_extensions import override
+
+from eva.core.metrics import structs as metrics_lib
+from eva.core.models.modules import module
+from eva.core.models.modules.typings import INPUT_BATCH, INPUT_TENSOR_BATCH
+from eva.core.models.modules.utils import batch_postprocess, grad
+from eva.core.utils import parser
+from eva.vision.models.networks import decoders
+
+
+class SemanticSegmentationModule(module.ModelModule):
+    """Neural network semantic segmentation module for training on patch embeddings."""
+
+    def __init__(
+        self,
+        decoder: decoders.Decoder,
+        criterion: Callable[..., torch.Tensor],
+        encoder: Dict[str, Any] | Callable[[torch.Tensor], List[torch.Tensor]] | None = None,
+        lr_multiplier_encoder: float = 0.0,
+        optimizer: OptimizerCallable = optim.AdamW,
+        lr_scheduler: LRSchedulerCallable = lr_scheduler.ConstantLR,
+        metrics: metrics_lib.MetricsSchema | None = None,
+        postprocess: batch_postprocess.BatchPostProcess | None = None,
+    ) -> None:
+        """Initializes the neural net head module.
+
+        Args:
+            decoder: The decoder model.
+            criterion: The loss function to use.
+            encoder: The encoder model. If `None`, it is expected
+                that the input batch returns the features directly.
+                If passed as a dictionary, it will be parsed to an
+                object during the `configure_model` step.
+            lr_multiplier_encoder: The learning rate multiplier for the
+                encoder parameters. If `0`, it will freeze the encoder.
+            optimizer: The optimizer to use.
+            lr_scheduler: The learning rate scheduler to use.
+            metrics: The metric groups to track.
+            postprocess: A list of helper functions to apply after the
+                loss and before the metrics calculation to the model
+                predictions and targets.
+        """
+        super().__init__(metrics=metrics, postprocess=postprocess)
+
+        self.decoder = decoder
+        self.criterion = criterion
+        self.encoder = encoder  # type: ignore
+        self.lr_multiplier_encoder = lr_multiplier_encoder
+        self.optimizer = optimizer
+        self.lr_scheduler = lr_scheduler
+
+    @override
+    def configure_model(self) -> None:
+        self._freeze_encoder()
+
+        if isinstance(self.encoder, dict):
+            self.encoder: Callable[[torch.Tensor], List[torch.Tensor]] = parser.parse_object(
+                self.encoder,
+                expected_type=nn.Module,
+            )
+
+    @override
+    def configure_optimizers(self) -> Any:
+        optimizer = self.optimizer(
+            [
+                {"params": self.decoder.parameters()},
+                {
+                    "params": self._encoder_trainable_parameters(),
+                    "lr": self._base_lr * self.lr_multiplier_encoder,
+                },
+            ]
+        )
+        lr_scheduler = self.lr_scheduler(optimizer)
+        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
+
+    @override
+    def forward(
+        self,
+        inputs: torch.Tensor,
+        to_size: Tuple[int, int] | None = None,
+        *args: Any,
+        **kwargs: Any,
+    ) -> torch.Tensor:
+        """Maps the input tensor (image tensor or embeddings) to masks.
+
+        If `inputs` is an image tensor, `self.encoder` should be set;
+        otherwise the input is interpreted as embeddings, in which
+        case `to_size` must be given.
+        """
+        if self.encoder is None and to_size is None:
+            raise ValueError(
+                "Please provide the expected `to_size` that the "
+                "decoder should map the embeddings (`inputs`) to."
+            )
+
+        patch_embeddings = self.encoder(inputs) if self.encoder else inputs
+        return self.decoder(patch_embeddings, to_size or inputs.shape[-2:])
+
+    @override
+    def training_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        return self._batch_step(batch)
+
+    @override
+    def validation_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        return self._batch_step(batch)
+
+    @override
+    def test_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        return self._batch_step(batch)
+
+    @override
+    def predict_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> torch.Tensor:
+        tensor = INPUT_BATCH(*batch).data
+        return self.encoder(tensor) if isinstance(self.encoder, nn.Module) else tensor
+
+    @property
+    def _base_lr(self) -> float:
+        """Returns the base learning rate."""
+        base_optimizer = self.optimizer(self.parameters())
+        return base_optimizer.param_groups[-1]["lr"]
+
+    def _encoder_trainable_parameters(self) -> Iterable[torch.Tensor]:
+        """Returns the trainable parameters of the encoder."""
+        return (
+            self.encoder.parameters()
+            if isinstance(self.encoder, nn.Module) and self.lr_multiplier_encoder > 0
+            else iter(())
+        )
+
+    def _freeze_encoder(self) -> None:
+        """If initialized, it freezes the encoder network."""
+        if isinstance(self.encoder, nn.Module) and self.lr_multiplier_encoder == 0:
+            grad.deactivate_requires_grad(self.encoder)
+
+    def _batch_step(self, batch: INPUT_TENSOR_BATCH) -> STEP_OUTPUT:
+        """Performs a model forward step and calculates the loss.
+
+        Args:
+            batch: The desired batch to process.
+
+        Returns:
+            The batch step output.
+        """
+        data, targets, metadata = INPUT_TENSOR_BATCH(*batch)
+        predictions = self(data, to_size=targets.shape[-2:])
+        loss = self.criterion(predictions, targets)
+        return {
+            "loss": loss,
+            "targets": targets,
+            "predictions": predictions,
+            "metadata": metadata,
+        }
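A wiring sketch for the new module, training a decoder on frozen-encoder embeddings. `ConvDecoder1x1` is a hypothetical decoder name (the concrete decoders live under `decoders/segmentation/` in this diff), and the export of `SemanticSegmentationModule` from `eva.vision.models.modules` is assumed:

from eva.vision.losses import DiceLoss
from eva.vision.models.modules import SemanticSegmentationModule
from eva.vision.models.networks.decoders import segmentation

module = SemanticSegmentationModule(
    decoder=segmentation.ConvDecoder1x1(in_features=384, num_classes=4),  # hypothetical name
    criterion=DiceLoss(softmax=True),
    lr_multiplier_encoder=0.0,  # 0 keeps the encoder frozen; only the decoder trains
)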
eva/vision/models/networks/backbones/__init__.py
ADDED
@@ -0,0 +1,6 @@
+"""Vision Model Backbones API."""
+
+from eva.vision.models.networks.backbones import pathology, timm, universal
+from eva.vision.models.networks.backbones.registry import BackboneModelRegistry, register_model
+
+__all__ = ["pathology", "timm", "universal", "BackboneModelRegistry", "register_model"]
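This is the registration pattern the exports enable, used by every pathology backbone below; `custom/my_vits16` is a made-up key:

from torch import nn

from eva.vision.models.networks.backbones import register_model

@register_model("custom/my_vits16")
def my_vits16() -> nn.Module:
    """Registers a custom backbone factory under the given key."""
    return nn.Identity()  # stand-in for a real backbone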
eva/vision/models/networks/backbones/_utils.py
ADDED
@@ -0,0 +1,39 @@
+"""Utils for backbone networks."""
+
+from typing import Any, Dict, Tuple
+
+from torch import nn
+
+from eva import models
+from eva.core.models import transforms
+
+
+def load_hugingface_model(
+    model_name: str,
+    out_indices: int | Tuple[int, ...] | None,
+    model_kwargs: Dict[str, Any] | None = None,
+    transform_args: Dict[str, Any] | None = None,
+) -> nn.Module:
+    """Helper function to load HuggingFace models.
+
+    Args:
+        model_name: The model name to load.
+        out_indices: Whether and which multi-level patch embeddings to return.
+            Currently only out_indices=1 is supported.
+        model_kwargs: The arguments used for instantiating the model.
+        transform_args: The arguments used for instantiating the transform.
+
+    Returns: The model instance.
+    """
+    if out_indices is None:
+        tensor_transforms = transforms.ExtractCLSFeatures(**(transform_args or {}))
+    elif out_indices == 1:
+        tensor_transforms = transforms.ExtractPatchFeatures(**(transform_args or {}))
+    else:
+        raise ValueError(f"out_indices={out_indices} is currently not supported.")
+
+    return models.HuggingFaceModel(
+        model_name_or_path=model_name,
+        tensor_transforms=tensor_transforms,
+        model_kwargs=model_kwargs,
+    )
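The helper is exercised by the hist.ai wrappers later in this diff, for example:

from eva.vision.models.networks.backbones import _utils

model = _utils.load_hugingface_model(
    model_name="histai/hibou-B",
    out_indices=1,  # 1 -> ExtractPatchFeatures; None -> ExtractCLSFeatures
    model_kwargs={"trust_remote_code": True},
)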
eva/vision/models/networks/backbones/pathology/__init__.py
ADDED
@@ -0,0 +1,31 @@
+"""Vision Pathology Model Backbones API."""
+
+from eva.vision.models.networks.backbones.pathology.bioptimus import bioptimus_h_optimus_0
+from eva.vision.models.networks.backbones.pathology.gigapath import prov_gigapath
+from eva.vision.models.networks.backbones.pathology.histai import histai_hibou_b, histai_hibou_l
+from eva.vision.models.networks.backbones.pathology.kaiko import (
+    kaiko_vitb8,
+    kaiko_vitb16,
+    kaiko_vitl14,
+    kaiko_vits8,
+    kaiko_vits16,
+)
+from eva.vision.models.networks.backbones.pathology.lunit import lunit_vits8, lunit_vits16
+from eva.vision.models.networks.backbones.pathology.mahmood import mahmood_uni
+from eva.vision.models.networks.backbones.pathology.owkin import owkin_phikon
+
+__all__ = [
+    "kaiko_vitb16",
+    "kaiko_vitb8",
+    "kaiko_vitl14",
+    "kaiko_vits16",
+    "kaiko_vits8",
+    "owkin_phikon",
+    "lunit_vits16",
+    "lunit_vits8",
+    "mahmood_uni",
+    "bioptimus_h_optimus_0",
+    "prov_gigapath",
+    "histai_hibou_b",
+    "histai_hibou_l",
+]
eva/vision/models/networks/backbones/pathology/bioptimus.py
ADDED
@@ -0,0 +1,34 @@
+"""Pathology FMs from Bioptimus."""
+
+from typing import Tuple
+
+import timm
+from torch import nn
+
+from eva.vision.models.networks.backbones.registry import register_model
+
+
+@register_model("pathology/bioptimus_h_optimus_0")
+def bioptimus_h_optimus_0(
+    dynamic_img_size: bool = True,
+    out_indices: int | Tuple[int, ...] | None = None,
+) -> nn.Module:
+    """Initializes the h_optimus_0 pathology FM by Bioptimus.
+
+    Args:
+        dynamic_img_size: Whether to allow the position embedding to be
+            interpolated at `forward()` time when the image grid changes
+            from the original.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The model instance.
+    """
+    return timm.create_model(
+        model_name="hf-hub:bioptimus/H-optimus-0",
+        pretrained=True,
+        init_values=1e-5,
+        dynamic_img_size=dynamic_img_size,
+        out_indices=out_indices,
+        features_only=out_indices is not None,
+    )
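The `out_indices`/`features_only` switch above selects between the plain ViT and timm's multi-level feature extractor; a sketch (weights are fetched from the HF Hub on first call):

from eva.vision.models.networks.backbones.pathology import bioptimus_h_optimus_0

backbone = bioptimus_h_optimus_0()                # plain ViT forward
extractor = bioptimus_h_optimus_0(out_indices=1)  # timm features_only wrapper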
eva/vision/models/networks/backbones/pathology/gigapath.py
ADDED
@@ -0,0 +1,33 @@
+"""Pathology FMs from other/mixed entities."""
+
+from typing import Tuple
+
+import timm
+from torch import nn
+
+from eva.vision.models.networks.backbones.registry import register_model
+
+
+@register_model("pathology/prov_gigapath")
+def prov_gigapath(
+    dynamic_img_size: bool = True,
+    out_indices: int | Tuple[int, ...] | None = None,
+) -> nn.Module:
+    """Initializes the Prov-GigaPath pathology FM.
+
+    Args:
+        dynamic_img_size: Whether to allow the position embedding to be
+            interpolated at `forward()` time when the image grid changes
+            from the original.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The model instance.
+    """
+    return timm.create_model(
+        model_name="hf_hub:prov-gigapath/prov-gigapath",
+        pretrained=True,
+        dynamic_img_size=dynamic_img_size,
+        out_indices=out_indices,
+        features_only=out_indices is not None,
+    )
eva/vision/models/networks/backbones/pathology/histai.py
ADDED
@@ -0,0 +1,46 @@
+"""Pathology FMs from hist.ai."""
+
+from typing import Tuple
+
+from torch import nn
+
+from eva.vision.models.networks.backbones import _utils
+from eva.vision.models.networks.backbones.registry import register_model
+
+
+@register_model("pathology/histai_hibou_b")
+def histai_hibou_b(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
+    """Initializes the hibou-B pathology FM by hist.ai (https://huggingface.co/histai/hibou-B).
+
+    Args:
+        out_indices: Whether and which multi-level patch embeddings to return.
+            Currently only out_indices=1 is supported.
+
+    Returns:
+        The model instance.
+    """
+    return _utils.load_hugingface_model(
+        model_name="histai/hibou-B",
+        out_indices=out_indices,
+        model_kwargs={"trust_remote_code": True},
+        transform_args={"ignore_remaining_dims": True} if out_indices is not None else None,
+    )
+
+
+@register_model("pathology/histai_hibou_l")
+def histai_hibou_l(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
+    """Initializes the hibou-L pathology FM by hist.ai (https://huggingface.co/histai/hibou-L).
+
+    Args:
+        out_indices: Whether and which multi-level patch embeddings to return.
+            Currently only out_indices=1 is supported.
+
+    Returns:
+        The model instance.
+    """
+    return _utils.load_hugingface_model(
+        model_name="histai/hibou-L",
+        out_indices=out_indices,
+        model_kwargs={"trust_remote_code": True},
+        transform_args={"ignore_remaining_dims": True} if out_indices is not None else None,
+    )
eva/vision/models/networks/backbones/pathology/kaiko.py
ADDED
@@ -0,0 +1,123 @@
+"""Pathology FMs from kaiko.ai."""
+
+from typing import Tuple
+
+import torch
+from torch import nn
+
+from eva.vision.models.networks.backbones.registry import register_model
+
+
+@register_model("pathology/kaiko_vits16")
+def kaiko_vits16(
+    dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+) -> nn.Module:
+    """Initializes the ViTS-16 pathology FM by kaiko.ai.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing the grid
+            size to change (interpolating abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The model instance.
+    """
+    return torch.hub.load(  # type: ignore
+        repo_or_dir="kaiko-ai/towards_large_pathology_fms",
+        model="vits16",
+        trust_repo=True,
+        dynamic_img_size=dynamic_img_size,
+        out_indices=out_indices,
+    )
+
+
+@register_model("pathology/kaiko_vits8")
+def kaiko_vits8(
+    dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+) -> nn.Module:
+    """Initializes the ViTS-8 pathology FM by kaiko.ai.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing the grid
+            size to change (interpolating abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The model instance.
+    """
+    return torch.hub.load(  # type: ignore
+        repo_or_dir="kaiko-ai/towards_large_pathology_fms",
+        model="vits8",
+        trust_repo=True,
+        dynamic_img_size=dynamic_img_size,
+        out_indices=out_indices,
+    )
+
+
+@register_model("pathology/kaiko_vitb16")
+def kaiko_vitb16(
+    dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+) -> nn.Module:
+    """Initializes the ViTB-16 pathology FM by kaiko.ai.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing the grid
+            size to change (interpolating abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The model instance.
+    """
+    return torch.hub.load(  # type: ignore
+        repo_or_dir="kaiko-ai/towards_large_pathology_fms",
+        model="vitb16",
+        trust_repo=True,
+        dynamic_img_size=dynamic_img_size,
+        out_indices=out_indices,
+    )
+
+
+@register_model("pathology/kaiko_vitb8")
+def kaiko_vitb8(
+    dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+) -> nn.Module:
+    """Initializes the ViTB-8 pathology FM by kaiko.ai.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing the grid
+            size to change (interpolating abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The model instance.
+    """
+    return torch.hub.load(  # type: ignore
+        repo_or_dir="kaiko-ai/towards_large_pathology_fms",
+        model="vitb8",
+        trust_repo=True,
+        dynamic_img_size=dynamic_img_size,
+        out_indices=out_indices,
+    )
+
+
+@register_model("pathology/kaiko_vitl14")
+def kaiko_vitl14(
+    dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+) -> nn.Module:
+    """Initializes the ViTL-14 pathology FM by kaiko.ai.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing the grid
+            size to change (interpolating abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The model instance.
+    """
+    return torch.hub.load(  # type: ignore
+        repo_or_dir="kaiko-ai/towards_large_pathology_fms",
+        model="vitl14",
+        trust_repo=True,
+        dynamic_img_size=dynamic_img_size,
+        out_indices=out_indices,
+    )
eva/vision/models/networks/backbones/pathology/lunit.py
ADDED
@@ -0,0 +1,68 @@
+"""Pathology FMs from Lunit.
+
+Source: https://github.com/lunit-io/benchmark-ssl-pathology/releases
+
+For training the ViT-S models, the following standardization parameters were used:
+
+mean: [ 0.70322989, 0.53606487, 0.66096631 ]
+std: [ 0.21716536, 0.26081574, 0.20723464 ]
+"""
+
+from typing import Tuple
+
+from torch import nn
+
+from eva.vision.models import wrappers
+from eva.vision.models.networks.backbones.registry import register_model
+
+VITS_URL_PREFIX = (
+    "https://github.com/lunit-io/benchmark-ssl-pathology/releases/download/pretrained-weights"
+)
+
+
+@register_model("pathology/lunit_vits16")
+def lunit_vits16(
+    dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+) -> nn.Module:
+    """Initializes the ViTS-16 pathology FM by Lunit.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing the grid
+            size to change (interpolating abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The model instance.
+    """
+    return wrappers.TimmModel(
+        model_name="vit_small_patch16_224.dino",
+        out_indices=out_indices,
+        model_kwargs={
+            "dynamic_img_size": dynamic_img_size,
+        },
+        checkpoint_path=f"{VITS_URL_PREFIX}/dino_vit_small_patch16_ep200.torch",
+    )
+
+
+@register_model("pathology/lunit_vits8")
+def lunit_vits8(
+    dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+) -> nn.Module:
+    """Initializes the ViTS-8 pathology FM by Lunit.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing the grid
+            size to change (interpolating abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The model instance.
+    """
+    return wrappers.TimmModel(
+        model_name="vit_small_patch8_224.dino",
+        out_indices=out_indices,
+        model_kwargs={
+            "dynamic_img_size": dynamic_img_size,
+        },
+        checkpoint_path=f"{VITS_URL_PREFIX}/dino_vit_small_patch8_ep200.torch",
+    )
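Expressed as a torchvision transform, the standardization constants documented in the module docstring above would be (pairing them with these checkpoints is left to the caller):

from torchvision import transforms

lunit_normalize = transforms.Normalize(
    mean=[0.70322989, 0.53606487, 0.66096631],
    std=[0.21716536, 0.26081574, 0.20723464],
)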