kaiko-eva 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaiko-eva might be problematic.
- eva/core/callbacks/__init__.py +2 -2
- eva/core/callbacks/writers/__init__.py +6 -3
- eva/core/callbacks/writers/embeddings/__init__.py +6 -0
- eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
- eva/core/callbacks/writers/embeddings/base.py +192 -0
- eva/core/callbacks/writers/embeddings/classification.py +117 -0
- eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
- eva/core/callbacks/writers/embeddings/typings.py +38 -0
- eva/core/data/datasets/__init__.py +2 -2
- eva/core/data/datasets/classification/__init__.py +8 -0
- eva/core/data/datasets/classification/embeddings.py +34 -0
- eva/core/data/datasets/{embeddings/classification → classification}/multi_embeddings.py +13 -9
- eva/core/data/datasets/{embeddings/base.py → embeddings.py} +47 -32
- eva/core/data/splitting/__init__.py +6 -0
- eva/core/data/splitting/random.py +41 -0
- eva/core/data/splitting/stratified.py +56 -0
- eva/core/loggers/experimental_loggers.py +2 -2
- eva/core/loggers/log/__init__.py +3 -2
- eva/core/loggers/log/image.py +71 -0
- eva/core/loggers/log/parameters.py +10 -0
- eva/core/loggers/loggers.py +6 -0
- eva/core/metrics/__init__.py +6 -2
- eva/core/metrics/defaults/__init__.py +10 -3
- eva/core/metrics/defaults/classification/__init__.py +1 -1
- eva/core/metrics/defaults/classification/binary.py +0 -9
- eva/core/metrics/defaults/classification/multiclass.py +0 -8
- eva/core/metrics/defaults/segmentation/__init__.py +5 -0
- eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
- eva/core/metrics/generalized_dice.py +59 -0
- eva/core/metrics/mean_iou.py +120 -0
- eva/core/metrics/structs/schemas.py +3 -1
- eva/core/models/__init__.py +3 -1
- eva/core/models/modules/head.py +10 -4
- eva/core/models/modules/typings.py +14 -1
- eva/core/models/modules/utils/batch_postprocess.py +37 -5
- eva/core/models/networks/__init__.py +1 -2
- eva/core/models/networks/mlp.py +2 -2
- eva/core/models/transforms/__init__.py +6 -0
- eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
- eva/core/models/transforms/extract_patch_features.py +47 -0
- eva/core/models/wrappers/__init__.py +13 -0
- eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
- eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
- eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
- eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
- eva/core/trainers/functional.py +1 -0
- eva/core/utils/__init__.py +6 -0
- eva/core/utils/clone.py +27 -0
- eva/core/utils/memory.py +28 -0
- eva/core/utils/operations.py +26 -0
- eva/core/utils/parser.py +20 -0
- eva/vision/__init__.py +2 -2
- eva/vision/callbacks/__init__.py +5 -0
- eva/vision/callbacks/loggers/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/base.py +130 -0
- eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
- eva/vision/data/datasets/__init__.py +30 -3
- eva/vision/data/datasets/_validators.py +15 -2
- eva/vision/data/datasets/classification/__init__.py +12 -1
- eva/vision/data/datasets/classification/bach.py +10 -15
- eva/vision/data/datasets/classification/base.py +17 -24
- eva/vision/data/datasets/classification/camelyon16.py +244 -0
- eva/vision/data/datasets/classification/crc.py +10 -15
- eva/vision/data/datasets/classification/mhist.py +10 -15
- eva/vision/data/datasets/classification/panda.py +184 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
- eva/vision/data/datasets/classification/wsi.py +105 -0
- eva/vision/data/datasets/segmentation/__init__.py +15 -2
- eva/vision/data/datasets/segmentation/_utils.py +38 -0
- eva/vision/data/datasets/segmentation/base.py +16 -17
- eva/vision/data/datasets/segmentation/bcss.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +156 -0
- eva/vision/data/datasets/segmentation/embeddings.py +34 -0
- eva/vision/data/datasets/segmentation/lits.py +178 -0
- eva/vision/data/datasets/segmentation/monusac.py +236 -0
- eva/vision/data/datasets/segmentation/{total_segmentator.py → total_segmentator_2d.py} +130 -36
- eva/vision/data/datasets/wsi.py +187 -0
- eva/vision/data/transforms/__init__.py +3 -2
- eva/vision/data/transforms/common/__init__.py +2 -1
- eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
- eva/vision/data/transforms/common/resize_and_crop.py +6 -7
- eva/vision/data/transforms/normalization/__init__.py +6 -0
- eva/vision/data/transforms/normalization/clamp.py +43 -0
- eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
- eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
- eva/vision/data/wsi/__init__.py +16 -0
- eva/vision/data/wsi/backends/__init__.py +69 -0
- eva/vision/data/wsi/backends/base.py +115 -0
- eva/vision/data/wsi/backends/openslide.py +73 -0
- eva/vision/data/wsi/backends/pil.py +52 -0
- eva/vision/data/wsi/backends/tiffslide.py +42 -0
- eva/vision/data/wsi/patching/__init__.py +6 -0
- eva/vision/data/wsi/patching/coordinates.py +98 -0
- eva/vision/data/wsi/patching/mask.py +123 -0
- eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
- eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
- eva/vision/data/wsi/patching/samplers/base.py +48 -0
- eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
- eva/vision/data/wsi/patching/samplers/grid.py +47 -0
- eva/vision/data/wsi/patching/samplers/random.py +41 -0
- eva/vision/losses/__init__.py +5 -0
- eva/vision/losses/dice.py +40 -0
- eva/vision/models/__init__.py +4 -2
- eva/vision/models/modules/__init__.py +5 -0
- eva/vision/models/modules/semantic_segmentation.py +161 -0
- eva/vision/models/networks/__init__.py +1 -2
- eva/vision/models/networks/backbones/__init__.py +6 -0
- eva/vision/models/networks/backbones/_utils.py +39 -0
- eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
- eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
- eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
- eva/vision/models/networks/backbones/pathology/histai.py +46 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
- eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
- eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
- eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
- eva/vision/models/networks/backbones/registry.py +47 -0
- eva/vision/models/networks/backbones/timm/__init__.py +5 -0
- eva/vision/models/networks/backbones/timm/backbones.py +54 -0
- eva/vision/models/networks/backbones/universal/__init__.py +8 -0
- eva/vision/models/networks/backbones/universal/vit.py +54 -0
- eva/vision/models/networks/decoders/__init__.py +6 -0
- eva/vision/models/networks/decoders/decoder.py +7 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
- eva/vision/models/networks/decoders/segmentation/common.py +74 -0
- eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
- eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
- eva/vision/models/wrappers/__init__.py +6 -0
- eva/vision/models/wrappers/from_registry.py +48 -0
- eva/vision/models/wrappers/from_timm.py +68 -0
- eva/vision/utils/colormap.py +77 -0
- eva/vision/utils/convert.py +56 -13
- eva/vision/utils/io/__init__.py +10 -4
- eva/vision/utils/io/image.py +21 -2
- eva/vision/utils/io/mat.py +36 -0
- eva/vision/utils/io/nifti.py +33 -12
- eva/vision/utils/io/text.py +10 -3
- kaiko_eva-0.1.0.dist-info/METADATA +553 -0
- kaiko_eva-0.1.0.dist-info/RECORD +205 -0
- {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
- {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
- eva/.DS_Store +0 -0
- eva/core/callbacks/writers/embeddings.py +0 -169
- eva/core/callbacks/writers/typings.py +0 -23
- eva/core/data/datasets/embeddings/__init__.py +0 -13
- eva/core/data/datasets/embeddings/classification/__init__.py +0 -10
- eva/core/data/datasets/embeddings/classification/embeddings.py +0 -66
- eva/core/models/networks/transforms/__init__.py +0 -5
- eva/core/models/networks/wrappers/__init__.py +0 -8
- eva/vision/models/.DS_Store +0 -0
- eva/vision/models/networks/.DS_Store +0 -0
- eva/vision/models/networks/postprocesses/__init__.py +0 -5
- eva/vision/models/networks/postprocesses/cls.py +0 -25
- kaiko_eva-0.0.2.dist-info/METADATA +0 -431
- kaiko_eva-0.0.2.dist-info/RECORD +0 -127
- /eva/core/models/{networks → wrappers}/_utils.py +0 -0
- {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
eva/core/models/__init__.py
CHANGED
@@ -1,12 +1,14 @@
 """Models API."""
 
 from eva.core.models.modules import HeadModule, InferenceModule
-from eva.core.models.networks import MLP
+from eva.core.models.networks import MLP
+from eva.core.models.wrappers import BaseModel, HuggingFaceModel, ModelFromFunction, ONNXModel
 
 __all__ = [
     "HeadModule",
     "InferenceModule",
     "MLP",
+    "BaseModel",
     "HuggingFaceModel",
     "ModelFromFunction",
     "ONNXModel",
eva/core/models/modules/head.py
CHANGED
@@ -1,11 +1,11 @@
 """"Neural Network Head Module."""
 
-from typing import Any, Callable
+from typing import Any, Callable, Dict
 
 import torch
 from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
 from lightning.pytorch.utilities.types import STEP_OUTPUT
-from torch import optim
+from torch import nn, optim
 from torch.optim import lr_scheduler
 from typing_extensions import override
 
@@ -13,6 +13,7 @@ from eva.core.metrics import structs as metrics_lib
 from eva.core.models.modules import module
 from eva.core.models.modules.typings import INPUT_BATCH, MODEL_TYPE
 from eva.core.models.modules.utils import batch_postprocess, grad
+from eva.core.utils import parser
 
 
 class HeadModule(module.ModelModule):
@@ -24,7 +25,7 @@ class HeadModule(module.ModelModule):
 
     def __init__(
         self,
-        head: MODEL_TYPE,
+        head: Dict[str, Any] | MODEL_TYPE,
         criterion: Callable[..., torch.Tensor],
         backbone: MODEL_TYPE | None = None,
         optimizer: OptimizerCallable = optim.Adam,
@@ -36,6 +37,8 @@ class HeadModule(module.ModelModule):
 
         Args:
             head: The neural network that would be trained on the features.
+                If its a dictionary, it will be parsed to an object during the
+                `configure_model` step.
             criterion: The loss function to use.
             backbone: The feature extractor. If `None`, it will be expected
                 that the input batch returns the features directly.
@@ -48,7 +51,7 @@ class HeadModule(module.ModelModule):
         """
         super().__init__(metrics=metrics, postprocess=postprocess)
 
-        self.head = head
+        self.head = head  # type: ignore
         self.criterion = criterion
         self.backbone = backbone
         self.optimizer = optimizer
@@ -59,6 +62,9 @@ class HeadModule(module.ModelModule):
         if self.backbone is not None:
             grad.deactivate_requires_grad(self.backbone)
 
+        if isinstance(self.head, dict):
+            self.head: MODEL_TYPE = parser.parse_object(self.head, expected_type=nn.Module)
+
     @override
     def configure_optimizers(self) -> Any:
         parameters = self.head.parameters()
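HeadModule can now take its `head` as a jsonargparse-style dictionary, which `configure_model` resolves via `parser.parse_object`. A minimal usage sketch, not from the package (the Linear sizes and criterion are illustrative):

import torch.nn as nn

from eva.core.models.modules import HeadModule

# Hypothetical config: the head is declared as a dict and only built
# when Lightning calls `configure_model`.
module = HeadModule(
    head={
        "class_path": "torch.nn.Linear",
        "init_args": {"in_features": 384, "out_features": 4},
    },
    criterion=nn.CrossEntropyLoss(),
)
module.configure_model()  # parses the dict into an nn.Module
assert isinstance(module.head, nn.Module)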
eva/core/models/modules/typings.py
CHANGED
@@ -16,7 +16,20 @@ class INPUT_BATCH(NamedTuple):
     data: torch.Tensor
     """The data batch."""
 
-    targets: torch.Tensor |
+    targets: torch.Tensor | None = None
+    """The target batch."""
+
+    metadata: Dict[str, Any] | None = None
+    """The associated metadata."""
+
+
+class INPUT_TENSOR_BATCH(NamedTuple):
+    """Tensor based input batch data scheme."""
+
+    data: torch.Tensor
+    """The data batch."""
+
+    targets: torch.Tensor
     """The target batch."""
 
     metadata: Dict[str, Any] | None = None
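With `targets` now optional, prediction-only batches type-check against `INPUT_BATCH`, while the new `INPUT_TENSOR_BATCH` keeps targets mandatory for training code. A small sketch:

import torch

from eva.core.models.modules.typings import INPUT_BATCH

# Inference-style batch: targets and metadata default to None.
batch = INPUT_BATCH(data=torch.randn(4, 384))
assert batch.targets is None and batch.metadata is None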
eva/core/models/modules/utils/batch_postprocess.py
CHANGED
@@ -2,9 +2,10 @@
 
 import dataclasses
 import functools
-from typing import Callable, List
+from typing import Any, Callable, Dict, List
 
 import torch
+from jsonargparse import _util
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 
 Transform = Callable[[torch.Tensor], torch.Tensor]
@@ -15,10 +16,10 @@ Transform = Callable[[torch.Tensor], torch.Tensor]
 class BatchPostProcess:
     """Batch post-processes transform schema."""
 
-    targets_transforms: List[Transform] | None = None
+    targets_transforms: List[Transform | Dict[str, Any]] | None = None
     """Holds the common train and evaluation metrics."""
 
-    predictions_transforms: List[Transform] | None = None
+    predictions_transforms: List[Transform | Dict[str, Any]] | None = None
     """Holds the common train and evaluation metrics."""
 
     def __call__(self, outputs: STEP_OUTPUT) -> None:
@@ -35,12 +36,13 @@ class BatchPostProcess:
 
         if "targets" in outputs and self.targets_transforms is not None:
             outputs["targets"] = _apply_transforms(
-                outputs["targets"], transforms=self.targets_transforms
+                outputs["targets"], transforms=_parse_callable_inputs(self.targets_transforms)
             )
 
         if "predictions" in outputs and self.predictions_transforms is not None:
             outputs["predictions"] = _apply_transforms(
-                outputs["predictions"],
+                outputs["predictions"],
+                transforms=_parse_callable_inputs(self.predictions_transforms),
             )
 
 
@@ -55,3 +57,33 @@ def _apply_transforms(tensor: torch.Tensor, transforms: List[Transform]) -> torc
         The processed tensor.
     """
     return functools.reduce(lambda tensor, transform: transform(tensor), transforms, tensor)
+
+
+def _parse_callable_inputs(inputs: List[Callable | Dict[str, Any]]) -> List[Callable]:
+    """Parses the inputs which where passed as dictionary to callable objects."""
+    parsed = []
+    for item in inputs:
+        if isinstance(item, dict):
+            item = _parse_dict(item)
+        parsed.append(item)
+    return parsed
+
+
+def _parse_dict(item: Dict[str, Any]) -> Callable:
+    """Parses the input dictionary to a partial callable object."""
+    if not _is_valid_dict(item):
+        raise ValueError(
+            "Transform dictionary format is not valid. "
+            "It must contain a key 'class_path' and optionally 'init_args' for "
+            "the function and additional call arguments."
+        )
+
+    return functools.partial(
+        _util.import_object(item["class_path"]),
+        **item.get("init_args", {}),
+    )
+
+
+def _is_valid_dict(item: Dict[str, Any], /) -> bool:
+    """Checks if the input has the valid structure."""
+    return "class_path" in item and set(item.keys()) <= {"class_path", "init_args"}
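The transforms lists now accept `class_path`/`init_args` dictionaries alongside plain callables; `_parse_callable_inputs` resolves each dict into a `functools.partial` at call time. A usage sketch, assuming `BatchPostProcess` is the dataclass shown above (the `torch.argmax` transform is illustrative):

import torch

from eva.core.models.modules.utils.batch_postprocess import BatchPostProcess

postprocess = BatchPostProcess(
    predictions_transforms=[{"class_path": "torch.argmax", "init_args": {"dim": -1}}],
)
outputs = {"predictions": torch.randn(4, 10)}
postprocess(outputs)  # dict resolves to functools.partial(torch.argmax, dim=-1)
assert outputs["predictions"].shape == (4,)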
eva/core/models/networks/mlp.py
CHANGED
@@ -1,6 +1,6 @@
 """Multi-layer Perceptron (MLP) implemented in PyTorch."""
 
-from typing import Type
+from typing import Tuple, Type
 
 import torch
 import torch.nn as nn
@@ -13,7 +13,7 @@ class MLP(nn.Module):
         self,
         input_size: int,
         output_size: int,
-        hidden_layer_sizes:
+        hidden_layer_sizes: Tuple[int, ...] | None = None,
         hidden_activation_fn: Type[torch.nn.Module] | None = nn.ReLU,
         output_activation_fn: Type[torch.nn.Module] | None = None,
         dropout: float = 0.0,
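`hidden_layer_sizes` is now typed as an optional tuple. A brief sketch of the intended call shape (the sizes are illustrative, not from the package):

import torch

from eva.core.models.networks import MLP

mlp = MLP(input_size=384, output_size=4, hidden_layer_sizes=(128, 64))
logits = mlp(torch.randn(2, 384))  # expected output shape: (2, 4)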
eva/core/models/{networks/transforms → transforms}/extract_cls_features.py
RENAMED
@@ -7,6 +7,14 @@ from transformers import modeling_outputs
 class ExtractCLSFeatures:
     """Extracts the CLS token from a ViT model output."""
 
+    def __init__(self, cls_index: int = 0) -> None:
+        """Initializes the transformation.
+
+        Args:
+            cls_index: The index of the CLS token in the output tensor.
+        """
+        self._cls_index = cls_index
+
     def __call__(
         self, tensor: torch.Tensor | modeling_outputs.BaseModelOutputWithPooling
     ) -> torch.Tensor:
@@ -16,9 +24,9 @@ class ExtractCLSFeatures:
             tensor: The tensor representing the model output.
         """
         if isinstance(tensor, torch.Tensor):
-            transformed_tensor = tensor[:,
+            transformed_tensor = tensor[:, self._cls_index, :]
         elif isinstance(tensor, modeling_outputs.BaseModelOutputWithPooling):
-            transformed_tensor = tensor.last_hidden_state[:,
+            transformed_tensor = tensor.last_hidden_state[:, self._cls_index, :]
         else:
             raise ValueError(f"Unsupported type {type(tensor)}")
 
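The transform is now configurable via `cls_index` instead of hard-coding the token position. A sketch on a raw token tensor, assuming the relocated `eva.core.models.transforms` package re-exports the class:

import torch

from eva.core.models.transforms import ExtractCLSFeatures

tokens = torch.randn(2, 197, 384)  # e.g. 196 ViT patch tokens plus a leading CLS token
cls_features = ExtractCLSFeatures(cls_index=0)(tokens)
assert cls_features.shape == (2, 384)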
eva/core/models/transforms/extract_patch_features.py
ADDED
@@ -0,0 +1,47 @@
+"""Transforms for extracting the patch features from a model output."""
+
+import math
+from typing import List
+
+import torch
+from transformers import modeling_outputs
+
+
+class ExtractPatchFeatures:
+    """Extracts the patch features from a ViT model output."""
+
+    def __init__(self, ignore_remaining_dims: bool = False) -> None:
+        """Initializes the transformation.
+
+        Args:
+            ignore_remaining_dims: If set to `True`, ignore the remaining dimensions
+                of the patch grid if it is not a square number.
+        """
+        self._ignore_remaining_dims = ignore_remaining_dims
+
+    def __call__(
+        self, tensor: torch.Tensor | modeling_outputs.BaseModelOutputWithPooling
+    ) -> List[torch.Tensor]:
+        """Call method for the transformation.
+
+        Args:
+            tensor: The raw embeddings of the model.
+
+        Returns:
+            A tensor (batch_size, hidden_size, n_patches_height, n_patches_width)
+            representing the model output.
+        """
+        if isinstance(tensor, modeling_outputs.BaseModelOutputWithPooling):
+            features = tensor.last_hidden_state[:, 1:, :].permute(0, 2, 1)
+            batch_size, hidden_size, patch_grid = features.shape
+            height = width = int(math.sqrt(patch_grid))
+            if height * width != patch_grid:
+                if self._ignore_remaining_dims:
+                    features = features[:, :, : height * width]
+                else:
+                    raise ValueError(f"Patch grid size must be a square number {patch_grid}.")
+            patch_embeddings = features.view(batch_size, hidden_size, height, width)
+        else:
+            raise ValueError(f"Unsupported type {type(tensor)}")
+
+        return [patch_embeddings]
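Plain tensors raise here, so inputs must arrive as a `BaseModelOutputWithPooling`. A sketch with a square 14x14 patch grid (shapes illustrative):

import torch
from transformers import modeling_outputs

from eva.core.models.transforms import ExtractPatchFeatures

output = modeling_outputs.BaseModelOutputWithPooling(
    last_hidden_state=torch.randn(2, 197, 384),  # CLS token + 14 * 14 patch tokens
)
(patch_features,) = ExtractPatchFeatures()(output)
assert patch_features.shape == (2, 384, 14, 14)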
eva/core/models/wrappers/__init__.py
ADDED
@@ -0,0 +1,13 @@
+"""Model Wrappers API."""
+
+from eva.core.models.wrappers.base import BaseModel
+from eva.core.models.wrappers.from_function import ModelFromFunction
+from eva.core.models.wrappers.huggingface import HuggingFaceModel
+from eva.core.models.wrappers.onnx import ONNXModel
+
+__all__ = [
+    "BaseModel",
+    "ModelFromFunction",
+    "HuggingFaceModel",
+    "ONNXModel",
+]
eva/core/models/{networks/wrappers → wrappers}/base.py
RENAMED
@@ -22,6 +22,8 @@ class BaseModel(nn.Module):
 
         self._output_transforms = tensor_transforms
 
+        self._model: Callable[..., torch.Tensor] | nn.Module
+
     @override
     def forward(self, tensor: torch.Tensor) -> torch.Tensor:
         tensor = self.model_forward(tensor)
@@ -32,14 +34,13 @@ class BaseModel(nn.Module):
         """Loads the model."""
         raise NotImplementedError
 
-    @abc.abstractmethod
     def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
         """Implements the forward pass of the model.
 
         Args:
             tensor: The input tensor to the model.
         """
-
+        return self._model(tensor)
 
     def _apply_transforms(self, tensor: torch.Tensor) -> torch.Tensor:
         if self._output_transforms is not None:
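Dropping `@abc.abstractmethod` and giving `model_forward` a default body means subclasses only need to populate `self._model` inside `load_model`. A minimal subclass sketch, not from the package, assuming the constructor's `tensor_transforms` argument defaults to `None`:

from torch import nn
from typing_extensions import override

from eva.core.models.wrappers import BaseModel


class IdentityModel(BaseModel):
    """Hypothetical wrapper: load_model sets _model, forward is inherited."""

    def __init__(self) -> None:
        super().__init__()
        self.load_model()

    @override
    def load_model(self) -> None:
        self._model = nn.Identity()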
eva/core/models/{networks/wrappers → wrappers}/from_function.py
RENAMED
@@ -3,12 +3,10 @@
 from typing import Any, Callable, Dict
 
 import jsonargparse
-import torch
 from torch import nn
 from typing_extensions import override
 
-from eva.core.models.
-from eva.core.models.networks.wrappers import base
+from eva.core.models.wrappers import _utils, base
 
 
 class ModelFromFunction(base.BaseModel):
@@ -36,23 +34,18 @@ class ModelFromFunction(base.BaseModel):
             tensor_transforms: The transforms to apply to the output tensor
                 produced by the model.
         """
-        super().__init__()
+        super().__init__(tensor_transforms=tensor_transforms)
 
         self._path = path
         self._arguments = arguments
         self._checkpoint_path = checkpoint_path
-        self._tensor_transforms = tensor_transforms
 
-        self.
+        self.load_model()
 
     @override
-    def load_model(self) ->
+    def load_model(self) -> None:
         class_path = jsonargparse.class_from_function(self._path, func_return=nn.Module)
         model = class_path(**self._arguments or {})
         if self._checkpoint_path is not None:
             _utils.load_model_weights(model, self._checkpoint_path)
-
-
-    @override
-    def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
-        return self._model(tensor)
+        self._model = model
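With the eager `load_model()` in `__init__` and the inherited `model_forward`, the wrapper is usable right after construction. A sketch wrapping timm's factory, assuming `timm` is installed (the model name is illustrative):

import timm
import torch

from eva.core.models.wrappers import ModelFromFunction

backbone = ModelFromFunction(
    path=timm.create_model,
    arguments={"model_name": "vit_small_patch16_224", "num_classes": 0},
)
features = backbone(torch.randn(1, 3, 224, 224))  # forwarded through self._model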
eva/core/models/{networks/wrappers → wrappers}/huggingface.py
RENAMED
@@ -1,18 +1,22 @@
 """Wrappers for HuggingFace `transformers` models."""
 
-from typing import Any, Callable
+from typing import Any, Callable, Dict
 
-import torch
 import transformers
 from typing_extensions import override
 
-from eva.core.models.
+from eva.core.models.wrappers import base
 
 
 class HuggingFaceModel(base.BaseModel):
     """Wrapper class for loading HuggingFace `transformers` models."""
 
-    def __init__(
+    def __init__(
+        self,
+        model_name_or_path: str,
+        tensor_transforms: Callable | None = None,
+        model_kwargs: Dict[str, Any] | None = None,
+    ) -> None:
         """Initializes the model.
 
         Args:
@@ -21,17 +25,17 @@ class HuggingFaceModel(base.BaseModel):
                 model hub.
             tensor_transforms: The transforms to apply to the output tensor
                 produced by the model.
+            model_kwargs: The arguments used for instantiating the model.
         """
         super().__init__(tensor_transforms=tensor_transforms)
 
         self._model_name_or_path = model_name_or_path
-        self.
+        self._model_kwargs = model_kwargs or {}
 
-
-    def load_model(self) -> Any:
-        config = transformers.AutoConfig.from_pretrained(self._model_name_or_path)
-        return transformers.AutoModel.from_pretrained(self._model_name_or_path, config=config)
+        self.load_model()
 
     @override
-    def
-
+    def load_model(self) -> None:
+        self._model = transformers.AutoModel.from_pretrained(
+            self._model_name_or_path, **self._model_kwargs
+        )
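The rewritten wrapper loads eagerly and forwards `model_kwargs` straight to `AutoModel.from_pretrained`; the previous `AutoConfig` round-trip is gone. A sketch with an illustrative hub id (the checkpoint is downloaded on first use):

from eva.core.models.transforms import ExtractCLSFeatures
from eva.core.models.wrappers import HuggingFaceModel

model = HuggingFaceModel(
    model_name_or_path="owkin/phikon",  # illustrative model id
    tensor_transforms=ExtractCLSFeatures(),
    model_kwargs={"output_hidden_states": True},
)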
eva/core/models/{networks/wrappers → wrappers}/onnx.py
RENAMED
@@ -6,7 +6,7 @@ import onnxruntime as ort
 import torch
 from typing_extensions import override
 
-from eva.core.models.
+from eva.core.models.wrappers import base
 
 
 class ONNXModel(base.BaseModel):
@@ -29,19 +29,22 @@ class ONNXModel(base.BaseModel):
 
         self._path = path
         self._device = device
-
+
+        self.load_model()
 
     @override
     def load_model(self) -> Any:
         if self._device == "cuda" and not torch.cuda.is_available():
             raise ValueError("Device is set to 'cuda', but CUDA is not available.")
         provider = "CUDAExecutionProvider" if self._device == "cuda" else "CPUExecutionProvider"
-
+        self._model = ort.InferenceSession(self._path, providers=[provider])  # type: ignore
 
     @override
     def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
         # TODO: Use IO binding to avoid copying the tensor to CPU.
         # https://onnxruntime.ai/docs/api/python/api_summary.html#data-on-device
+        if not isinstance(self._model, ort.InferenceSession):
+            raise ValueError("Model is not loaded.")
         inputs = {self._model.get_inputs()[0].name: tensor.detach().cpu().numpy()}
         outputs = self._model.run(None, inputs)[0]
         return torch.from_numpy(outputs).float().to(tensor.device)
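Session creation now happens in `__init__`, so a bad path or missing CUDA fails at construction rather than on the first forward. A sketch (the file path is hypothetical):

import torch

from eva.core.models.wrappers import ONNXModel

model = ONNXModel(path="backbone.onnx", device="cpu")  # hypothetical exported model
embeddings = model(torch.randn(1, 3, 224, 224))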
eva/core/trainers/functional.py
CHANGED
@@ -69,6 +69,7 @@ def run_evaluation(
         A tuple of with the validation and the test metrics (if exists).
     """
     trainer, model = _utils.clone(base_trainer, base_model)
+    model.configure_model()
     trainer.setup_log_dirs(run_id or "")
     return fit_and_validate(trainer, model, datamodule, verbose=verbose)
 
eva/core/utils/__init__.py
CHANGED
eva/core/utils/clone.py
ADDED
@@ -0,0 +1,27 @@
+"""Clone related functions."""
+
+import functools
+from typing import Any, Dict, List
+
+import torch
+
+
+@functools.singledispatch
+def clone(tensor_type: Any) -> Any:
+    """Clone tensor objects."""
+    raise TypeError(f"Unsupported input type: {type(input)}.")
+
+
+@clone.register
+def _(tensor: torch.Tensor) -> torch.Tensor:
+    return tensor.clone()
+
+
+@clone.register
+def _(tensors: list) -> List[torch.Tensor]:
+    return list(map(clone, tensors))
+
+
+@clone.register
+def _(tensors: dict) -> Dict[str, torch.Tensor]:
+    return {key: clone(tensors[key]) for key in tensors}
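Note that the fallback's error message formats the builtin `input` rather than the `tensor_type` argument. A usage sketch over a nested batch structure:

import torch

from eva.core.utils.clone import clone

batch = {"data": torch.ones(2, 3), "targets": [torch.zeros(2)]}
copy = clone(batch)  # the dict and list registrations recurse into the tensors
assert copy["data"] is not batch["data"]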
eva/core/utils/memory.py
ADDED
@@ -0,0 +1,28 @@
+"""Memory related functions."""
+
+import functools
+from typing import Any, Dict, List
+
+import torch
+
+
+@functools.singledispatch
+def to_cpu(tensor_type: Any) -> Any:
+    """Moves tensor objects to `cpu`."""
+    raise TypeError(f"Unsupported input type: {type(tensor_type)}.")
+
+
+@to_cpu.register
+def _(tensor: torch.Tensor) -> torch.Tensor:
+    detached_tensor = tensor.detach()
+    return detached_tensor.cpu()
+
+
+@to_cpu.register
+def _(tensors: list) -> List[torch.Tensor]:
+    return list(map(to_cpu, tensors))
+
+
+@to_cpu.register
+def _(tensors: dict) -> Dict[str, torch.Tensor]:
+    return {key: to_cpu(tensors[key]) for key in tensors}
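`to_cpu` mirrors `clone` with the same singledispatch pattern, detaching before the device move. A sketch:

import torch

from eva.core.utils.memory import to_cpu

outputs = {"logits": torch.randn(4, 2, requires_grad=True)}
cpu_outputs = to_cpu(outputs)
assert not cpu_outputs["logits"].requires_grad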
eva/core/utils/operations.py
ADDED
@@ -0,0 +1,26 @@
+"""Functional operations."""
+
+import re
+from typing import Iterable, List
+
+
+def numeric_sort(item: Iterable[str], /) -> List[str]:
+    """Sorts an iterable of strings treating embedded numbers as numeric values.
+
+    Here the strings are compared based on their numeric value rather than their
+    string representation.
+
+    Args:
+        item: An iterable of strings to be sorted.
+
+    Returns:
+        A list of strings sorted based on their numeric values.
+    """
+    return sorted(
+        item,
+        key=lambda value: re.sub(
+            r"(\d+)",
+            lambda num: f"{int(num[0]):010d}",
+            value,
+        ),
+    )
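The sort key zero-pads every digit run to width 10, so lexicographic order matches numeric order. For example:

from eva.core.utils.operations import numeric_sort

names = ["patch_10.png", "patch_2.png", "patch_1.png"]
assert numeric_sort(names) == ["patch_1.png", "patch_2.png", "patch_10.png"]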
eva/core/utils/parser.py
ADDED
@@ -0,0 +1,20 @@
+"""Parsing related helper functions."""
+
+from typing import Any, Dict
+
+import jsonargparse
+
+
+def parse_object(config: Dict[str, Any], expected_type: Any = Any) -> Any:
+    """Parse object which is defined as dictionary."""
+    parser = jsonargparse.ArgumentParser()
+    parser.add_argument("module", type=expected_type)
+    configuration = parser.parse_object({"module": config})
+    init_object = parser.instantiate_classes(configuration)
+    obj_module = init_object.module
+    if isinstance(obj_module, jsonargparse.Namespace):
+        raise ValueError(
+            f"Failed to parsed object '{obj_module.class_path}'. "
+            "Please check that the initialized arguments are valid."
+        )
+    return obj_module
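`parse_object` is the jsonargparse shim behind the dict-based `head` in `HeadModule`. A sketch (the Linear config is illustrative):

from torch import nn

from eva.core.utils.parser import parse_object

head = parse_object(
    {"class_path": "torch.nn.Linear", "init_args": {"in_features": 8, "out_features": 2}},
    expected_type=nn.Module,
)
assert isinstance(head, nn.Linear)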
eva/vision/__init__.py
CHANGED
@@ -1,7 +1,7 @@
 """eva vision API."""
 
 try:
-    from eva.vision import models, utils
+    from eva.vision import callbacks, losses, models, utils
     from eva.vision.data import datasets, transforms
 except ImportError as e:
     msg = (
@@ -11,4 +11,4 @@ except ImportError as e:
     )
     raise ImportError(str(e) + "\n\n" + msg) from e
 
-__all__ = ["models", "utils", "datasets", "transforms"]
+__all__ = ["callbacks", "losses", "models", "utils", "datasets", "transforms"]