kaiko-eva 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kaiko-eva might be problematic.
- eva/core/callbacks/__init__.py +3 -2
- eva/core/callbacks/config.py +143 -0
- eva/core/callbacks/writers/__init__.py +6 -3
- eva/core/callbacks/writers/embeddings/__init__.py +6 -0
- eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
- eva/core/callbacks/writers/embeddings/base.py +192 -0
- eva/core/callbacks/writers/embeddings/classification.py +117 -0
- eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
- eva/core/callbacks/writers/embeddings/typings.py +38 -0
- eva/core/data/datasets/__init__.py +10 -2
- eva/core/data/datasets/classification/__init__.py +5 -2
- eva/core/data/datasets/classification/embeddings.py +15 -135
- eva/core/data/datasets/classification/multi_embeddings.py +110 -0
- eva/core/data/datasets/embeddings.py +167 -0
- eva/core/data/splitting/__init__.py +6 -0
- eva/core/data/splitting/random.py +41 -0
- eva/core/data/splitting/stratified.py +56 -0
- eva/core/data/transforms/__init__.py +3 -1
- eva/core/data/transforms/padding/__init__.py +5 -0
- eva/core/data/transforms/padding/pad_2d_tensor.py +38 -0
- eva/core/data/transforms/sampling/__init__.py +5 -0
- eva/core/data/transforms/sampling/sample_from_axis.py +40 -0
- eva/core/loggers/__init__.py +7 -0
- eva/core/loggers/dummy.py +38 -0
- eva/core/loggers/experimental_loggers.py +8 -0
- eva/core/loggers/log/__init__.py +6 -0
- eva/core/loggers/log/image.py +71 -0
- eva/core/loggers/log/parameters.py +74 -0
- eva/core/loggers/log/utils.py +13 -0
- eva/core/loggers/loggers.py +6 -0
- eva/core/metrics/__init__.py +6 -2
- eva/core/metrics/defaults/__init__.py +10 -3
- eva/core/metrics/defaults/classification/__init__.py +1 -1
- eva/core/metrics/defaults/classification/binary.py +0 -9
- eva/core/metrics/defaults/classification/multiclass.py +0 -8
- eva/core/metrics/defaults/segmentation/__init__.py +5 -0
- eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
- eva/core/metrics/generalized_dice.py +59 -0
- eva/core/metrics/mean_iou.py +120 -0
- eva/core/metrics/structs/schemas.py +3 -1
- eva/core/models/__init__.py +3 -1
- eva/core/models/modules/head.py +16 -15
- eva/core/models/modules/module.py +25 -1
- eva/core/models/modules/typings.py +14 -1
- eva/core/models/modules/utils/batch_postprocess.py +37 -5
- eva/core/models/networks/__init__.py +1 -2
- eva/core/models/networks/mlp.py +2 -2
- eva/core/models/transforms/__init__.py +6 -0
- eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
- eva/core/models/transforms/extract_patch_features.py +47 -0
- eva/core/models/wrappers/__init__.py +13 -0
- eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
- eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
- eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
- eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
- eva/core/trainers/_recorder.py +69 -7
- eva/core/trainers/functional.py +23 -5
- eva/core/trainers/trainer.py +20 -6
- eva/core/utils/__init__.py +6 -0
- eva/core/utils/clone.py +27 -0
- eva/core/utils/memory.py +28 -0
- eva/core/utils/operations.py +26 -0
- eva/core/utils/parser.py +20 -0
- eva/vision/__init__.py +2 -2
- eva/vision/callbacks/__init__.py +5 -0
- eva/vision/callbacks/loggers/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/base.py +130 -0
- eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
- eva/vision/data/datasets/__init__.py +24 -4
- eva/vision/data/datasets/_utils.py +3 -3
- eva/vision/data/datasets/_validators.py +15 -2
- eva/vision/data/datasets/classification/__init__.py +6 -2
- eva/vision/data/datasets/classification/bach.py +10 -15
- eva/vision/data/datasets/classification/base.py +17 -24
- eva/vision/data/datasets/classification/camelyon16.py +244 -0
- eva/vision/data/datasets/classification/crc.py +10 -15
- eva/vision/data/datasets/classification/mhist.py +10 -15
- eva/vision/data/datasets/classification/panda.py +184 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
- eva/vision/data/datasets/classification/wsi.py +105 -0
- eva/vision/data/datasets/segmentation/__init__.py +15 -2
- eva/vision/data/datasets/segmentation/_utils.py +38 -0
- eva/vision/data/datasets/segmentation/base.py +31 -47
- eva/vision/data/datasets/segmentation/bcss.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +156 -0
- eva/vision/data/datasets/segmentation/embeddings.py +34 -0
- eva/vision/data/datasets/segmentation/lits.py +178 -0
- eva/vision/data/datasets/segmentation/monusac.py +236 -0
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +325 -0
- eva/vision/data/datasets/wsi.py +187 -0
- eva/vision/data/transforms/__init__.py +3 -2
- eva/vision/data/transforms/common/__init__.py +2 -1
- eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
- eva/vision/data/transforms/common/resize_and_crop.py +6 -7
- eva/vision/data/transforms/normalization/__init__.py +6 -0
- eva/vision/data/transforms/normalization/clamp.py +43 -0
- eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
- eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
- eva/vision/data/wsi/__init__.py +16 -0
- eva/vision/data/wsi/backends/__init__.py +69 -0
- eva/vision/data/wsi/backends/base.py +115 -0
- eva/vision/data/wsi/backends/openslide.py +73 -0
- eva/vision/data/wsi/backends/pil.py +52 -0
- eva/vision/data/wsi/backends/tiffslide.py +42 -0
- eva/vision/data/wsi/patching/__init__.py +6 -0
- eva/vision/data/wsi/patching/coordinates.py +98 -0
- eva/vision/data/wsi/patching/mask.py +123 -0
- eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
- eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
- eva/vision/data/wsi/patching/samplers/base.py +48 -0
- eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
- eva/vision/data/wsi/patching/samplers/grid.py +47 -0
- eva/vision/data/wsi/patching/samplers/random.py +41 -0
- eva/vision/losses/__init__.py +5 -0
- eva/vision/losses/dice.py +40 -0
- eva/vision/models/__init__.py +4 -2
- eva/vision/models/modules/__init__.py +5 -0
- eva/vision/models/modules/semantic_segmentation.py +161 -0
- eva/vision/models/networks/__init__.py +1 -2
- eva/vision/models/networks/backbones/__init__.py +6 -0
- eva/vision/models/networks/backbones/_utils.py +39 -0
- eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
- eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
- eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
- eva/vision/models/networks/backbones/pathology/histai.py +46 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
- eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
- eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
- eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
- eva/vision/models/networks/backbones/registry.py +47 -0
- eva/vision/models/networks/backbones/timm/__init__.py +5 -0
- eva/vision/models/networks/backbones/timm/backbones.py +54 -0
- eva/vision/models/networks/backbones/universal/__init__.py +8 -0
- eva/vision/models/networks/backbones/universal/vit.py +54 -0
- eva/vision/models/networks/decoders/__init__.py +6 -0
- eva/vision/models/networks/decoders/decoder.py +7 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
- eva/vision/models/networks/decoders/segmentation/common.py +74 -0
- eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
- eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
- eva/vision/models/wrappers/__init__.py +6 -0
- eva/vision/models/wrappers/from_registry.py +48 -0
- eva/vision/models/wrappers/from_timm.py +68 -0
- eva/vision/utils/colormap.py +77 -0
- eva/vision/utils/convert.py +67 -0
- eva/vision/utils/io/__init__.py +10 -4
- eva/vision/utils/io/image.py +21 -2
- eva/vision/utils/io/mat.py +36 -0
- eva/vision/utils/io/nifti.py +40 -15
- eva/vision/utils/io/text.py +10 -3
- kaiko_eva-0.1.0.dist-info/METADATA +553 -0
- kaiko_eva-0.1.0.dist-info/RECORD +205 -0
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
- eva/core/callbacks/writers/embeddings.py +0 -169
- eva/core/callbacks/writers/typings.py +0 -23
- eva/core/models/networks/transforms/__init__.py +0 -5
- eva/core/models/networks/wrappers/__init__.py +0 -8
- eva/vision/data/datasets/classification/total_segmentator.py +0 -213
- eva/vision/data/datasets/segmentation/total_segmentator.py +0 -212
- eva/vision/models/networks/postprocesses/__init__.py +0 -5
- eva/vision/models/networks/postprocesses/cls.py +0 -25
- kaiko_eva-0.0.1.dist-info/METADATA +0 -405
- kaiko_eva-0.0.1.dist-info/RECORD +0 -110
- /eva/core/models/{networks → wrappers}/_utils.py +0 -0
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
eva/vision/models/networks/backbones/pathology/mahmood.py
@@ -0,0 +1,62 @@
+"""Pathology FMs from MahmoodLab."""
+
+import os
+from pathlib import Path
+from typing import Tuple
+
+import huggingface_hub
+from loguru import logger
+from torch import nn
+
+from eva.vision.models import wrappers
+from eva.vision.models.networks.backbones.registry import register_model
+
+
+@register_model("pathology/mahmood_uni")
+def mahmood_uni(
+    dynamic_img_size: bool = True,
+    out_indices: int | Tuple[int, ...] | None = None,
+    hf_token: str | None = None,
+    download_dir: str = os.path.join(str(Path.home()), ".cache/eva"),
+) -> nn.Module:
+    """Initializes UNI model from MahmoodLab.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+        hf_token: HuggingFace token to download the model.
+        download_dir: Directory to download the model checkpoint.
+
+    Returns:
+        The model instance.
+    """
+    token = hf_token or os.environ.get("HF_TOKEN")
+    if not token:
+        raise ValueError(
+            "Please provide a HuggingFace token to download the model. "
+            "You can either pass it as an argument or set the env variable HF_TOKEN."
+        )
+
+    checkpoint_path = os.path.join(download_dir, "pytorch_model.bin")
+
+    if not os.path.exists(checkpoint_path):
+        logger.info(f"Downloading the model checkpoint to {download_dir} ...")
+        os.makedirs(download_dir, exist_ok=True)
+        huggingface_hub.login(token=token)
+        huggingface_hub.hf_hub_download(
+            "MahmoodLab/UNI",
+            filename="pytorch_model.bin",
+            local_dir=download_dir,
+            force_download=True,
+        )
+
+    return wrappers.TimmModel(
+        model_name="vit_large_patch16_224",
+        out_indices=out_indices,
+        model_kwargs={
+            "init_values": 1e-5,
+            "dynamic_img_size": dynamic_img_size,
+        },
+        checkpoint_path=checkpoint_path,
+    )
eva/vision/models/networks/backbones/pathology/owkin.py
@@ -0,0 +1,22 @@
+"""Pathology FMs from owkin."""
+
+from typing import Tuple
+
+from torch import nn
+
+from eva.vision.models.networks.backbones import _utils
+from eva.vision.models.networks.backbones.registry import register_model
+
+
+@register_model("pathology/owkin_phikon")
+def owkin_phikon(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
+    """Initializes the phikon pathology FM by owkin (https://huggingface.co/owkin/phikon).
+
+    Args:
+        out_indices: Whether and which multi-level patch embeddings to return.
+            Currently only out_indices=1 is supported.
+
+    Returns:
+        The model instance.
+    """
+    return _utils.load_hugingface_model(model_name="owkin/phikon", out_indices=out_indices)
eva/vision/models/networks/backbones/registry.py
@@ -0,0 +1,47 @@
+"""Backbone Model Registry."""
+
+from typing import Any, Callable, Dict, List
+
+import torch.nn as nn
+
+
+class BackboneModelRegistry:
+    """A model registry for accessing backbone models by name."""
+
+    _registry: Dict[str, Callable[..., nn.Module]] = {}
+
+    @classmethod
+    def register(cls, name: str) -> Callable:
+        """Decorator to register a new model."""
+
+        def decorator(model_fn: Callable[..., nn.Module]) -> Callable[..., nn.Module]:
+            if name in cls._registry:
+                raise ValueError(f"Model {name} is already registered.")
+            cls._registry[name] = model_fn
+            return model_fn
+
+        return decorator
+
+    @classmethod
+    def get(cls, model_name: str) -> Callable[..., nn.Module]:
+        """Gets a model function from the registry."""
+        if model_name not in cls._registry:
+            raise ValueError(f"Model {model_name} not found in the registry.")
+        return cls._registry[model_name]
+
+    @classmethod
+    def load_model(cls, model_name: str, model_kwargs: Dict[str, Any] | None = None) -> nn.Module:
+        """Loads & initializes a model class from the registry."""
+        model_fn = cls.get(model_name)
+        return model_fn(**(model_kwargs or {}))
+
+    @classmethod
+    def list_models(cls) -> List[str]:
+        """List all models in the registry."""
+        register_models = [name for name in cls._registry.keys() if not name.startswith("timm")]
+        return register_models + ["timm/<model_name>"]
+
+
+def register_model(name: str) -> Callable:
+    """Simple decorator to register a model."""
+    return BackboneModelRegistry.register(name)
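Taken together, registry.py provides both the registration path (the `register_model` decorator) and the lookup path (`get`, `load_model`, `list_models`). The following usage sketch is editorial and not part of the diff; it assumes the released package is installed, and the backbone name `my_lab/my_vit` and the toy module are made up for illustration:

    import torch
    from torch import nn

    from eva.vision.models.networks.backbones.registry import (
        BackboneModelRegistry,
        register_model,
    )


    @register_model("my_lab/my_vit")  # hypothetical name; re-registering a name raises ValueError
    def my_vit(embed_dim: int = 384) -> nn.Module:
        """Any callable returning an nn.Module can be registered."""
        return nn.Sequential(nn.Flatten(), nn.LazyLinear(embed_dim))


    # Resolve by name and instantiate; model_kwargs are forwarded to the registered callable.
    model = BackboneModelRegistry.load_model("my_lab/my_vit", model_kwargs={"embed_dim": 384})
    features = model(torch.randn(2, 3, 16, 16))  # toy forward pass: (2, 768) -> (2, 384)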
eva/vision/models/networks/backbones/timm/backbones.py
@@ -0,0 +1,54 @@
+"""timm backbones."""
+
+import functools
+from typing import Tuple
+
+import timm
+from loguru import logger
+from torch import nn
+
+from eva.vision.models import wrappers
+from eva.vision.models.networks.backbones.registry import BackboneModelRegistry
+
+
+def timm_model(
+    model_name: str,
+    checkpoint_path: str | None = None,
+    pretrained: bool = False,
+    dynamic_img_size: bool = True,
+    out_indices: int | Tuple[int, ...] | None = None,
+    **kwargs,
+) -> nn.Module:
+    """Initializes any ViT model from timm with weights from a specified checkpoint.
+
+    Args:
+        model_name: The name of the model to load.
+        checkpoint_path: The path to the checkpoint file.
+        pretrained: If set to `True`, load pretrained ImageNet-1k weights.
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+        **kwargs: Additional arguments to pass to the model
+
+    Returns:
+        The VIT model instance.
+    """
+    logger.info(
+        f"Loading timm model {model_name}"
+        + (f"using checkpoint {checkpoint_path}" if checkpoint_path else "")
+    )
+    return wrappers.TimmModel(
+        model_name=model_name,
+        checkpoint_path=checkpoint_path or "",
+        pretrained=pretrained,
+        out_indices=out_indices,
+        model_kwargs=kwargs,
+    )
+
+
+BackboneModelRegistry._registry.update(
+    {
+        f"timm/{model_name}": functools.partial(timm_model, model_name=model_name)
+        for model_name in timm.list_models()
+    }
+)
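Because the `_registry.update(...)` call above registers a `functools.partial` of `timm_model` for every name returned by `timm.list_models()`, any timm architecture becomes reachable through the registry under a `timm/` prefix. An editorial sketch, assuming the package and timm are installed and that this backbones module has been imported (the new `timm/__init__.py` in the file list suggests this happens when the backbones package is imported):

    from eva.vision.models.networks.backbones.registry import BackboneModelRegistry

    # Extra kwargs are forwarded to timm_model (pretrained, dynamic_img_size, out_indices, ...).
    backbone = BackboneModelRegistry.load_model(
        "timm/vit_small_patch16_224",
        model_kwargs={"pretrained": False},
    )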
eva/vision/models/networks/backbones/universal/vit.py
@@ -0,0 +1,54 @@
+"""Vision Transformers base universal backbones."""
+
+from typing import Tuple
+
+import timm
+from torch import nn
+
+from eva.vision.models.networks.backbones.registry import register_model
+
+
+@register_model("universal/vit_small_patch16_224_random")
+def vit_small_patch16_224_random(
+    dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+) -> nn.Module:
+    """Initializes a ViTS-16 baseline model with random weights.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The torch ViTS-16 based foundation model.
+    """
+    return timm.create_model(
+        model_name="vit_small_patch16_224.dino",
+        pretrained=False,
+        features_only=out_indices is not None,
+        out_indices=out_indices,
+        dynamic_img_size=dynamic_img_size,
+    )
+
+
+@register_model("universal/vit_small_patch16_224_dino")
+def vit_small_patch16_224_dino(
+    dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+) -> nn.Module:
+    """Initializes a ViTS-16 baseline model pretrained w/ DINO.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The torch ViTS-16 based foundation model.
+    """
+    return timm.create_model(
+        model_name="vit_small_patch16_224.dino",
+        pretrained=True,
+        features_only=out_indices is not None,
+        out_indices=out_indices,
+        dynamic_img_size=dynamic_img_size,
+    )
eva/vision/models/networks/decoders/segmentation/__init__.py
@@ -0,0 +1,11 @@
+"""Segmentation decoder heads API."""
+
+from eva.vision.models.networks.decoders.segmentation.common import (
+    ConvDecoder1x1,
+    ConvDecoderMS,
+    SingleLinearDecoder,
+)
+from eva.vision.models.networks.decoders.segmentation.conv2d import ConvDecoder
+from eva.vision.models.networks.decoders.segmentation.linear import LinearDecoder
+
+__all__ = ["ConvDecoder1x1", "ConvDecoderMS", "SingleLinearDecoder", "ConvDecoder", "LinearDecoder"]
eva/vision/models/networks/decoders/segmentation/common.py
@@ -0,0 +1,74 @@
+"""Common semantic segmentation decoders.
+
+This module contains implementations of different types of decoder models
+used in semantic segmentation. These decoders convert the high-level features
+output by an encoder into pixel-wise predictions for segmentation tasks.
+"""
+
+from torch import nn
+
+from eva.vision.models.networks.decoders.segmentation import conv2d, linear
+
+
+class ConvDecoder1x1(conv2d.ConvDecoder):
+    """A convolutional decoder with a single 1x1 convolutional layer."""
+
+    def __init__(self, in_features: int, num_classes: int) -> None:
+        """Initializes the decoder.
+
+        Args:
+            in_features: The hidden dimension size of the embeddings.
+            num_classes: Number of output classes as channels.
+        """
+        super().__init__(
+            layers=nn.Conv2d(
+                in_channels=in_features,
+                out_channels=num_classes,
+                kernel_size=(1, 1),
+            ),
+        )
+
+
+class ConvDecoderMS(conv2d.ConvDecoder):
+    """A multi-stage convolutional decoder with upsampling and convolutional layers.
+
+    This decoder applies a series of upsampling and convolutional layers to transform
+    the input features into output predictions with the desired spatial resolution.
+
+    This decoder is based on the `+ms` segmentation decoder from DINOv2
+    (https://arxiv.org/pdf/2304.07193)
+    """
+
+    def __init__(self, in_features: int, num_classes: int) -> None:
+        """Initializes the decoder.
+
+        Args:
+            in_features: The hidden dimension size of the embeddings.
+            num_classes: Number of output classes as channels.
+        """
+        super().__init__(
+            layers=nn.Sequential(
+                nn.Upsample(scale_factor=2),
+                nn.Conv2d(in_features, 64, kernel_size=(3, 3), padding=(1, 1)),
+                nn.Upsample(scale_factor=2),
+                nn.Conv2d(64, num_classes, kernel_size=(3, 3), padding=(1, 1)),
+            ),
+        )
+
+
+class SingleLinearDecoder(linear.LinearDecoder):
+    """A simple linear decoder with a single fully connected layer."""
+
+    def __init__(self, in_features: int, num_classes: int) -> None:
+        """Initializes the decoder.
+
+        Args:
+            in_features: The hidden dimension size of the embeddings.
+            num_classes: Number of output classes as channels.
+        """
+        super().__init__(
+            layers=nn.Linear(
+                in_features=in_features,
+                out_features=num_classes,
+            ),
+        )
eva/vision/models/networks/decoders/segmentation/conv2d.py
@@ -0,0 +1,114 @@
+"""Convolutional based semantic segmentation decoder."""
+
+from typing import List, Tuple
+
+import torch
+from torch import nn
+from torch.nn import functional
+
+from eva.vision.models.networks.decoders import decoder
+
+
+class ConvDecoder(decoder.Decoder):
+    """Convolutional segmentation decoder."""
+
+    def __init__(self, layers: nn.Module) -> None:
+        """Initializes the convolutional based decoder head.
+
+        Here the input nn layers will be directly applied to the
+        features of shape (batch_size, hidden_size, n_patches_height,
+        n_patches_width), where n_patches is image_size / patch_size.
+        Note the n_patches is also known as grid_size.
+
+        Args:
+            layers: The convolutional layers to be used as the decoder head.
+        """
+        super().__init__()
+
+        self._layers = layers
+
+    def _forward_features(self, features: List[torch.Tensor]) -> torch.Tensor:
+        """Forward function for multi-level feature maps to a single one.
+
+        It will interpolate the features and concat them into a single tensor
+        on the dimension axis of the hidden size.
+
+        Example:
+            >>> features = [torch.Tensor(16, 384, 14, 14), torch.Size(16, 384, 14, 14)]
+            >>> output = self._forward_features(features)
+            >>> assert output.shape == torch.Size([16, 768, 14, 14])
+
+        Args:
+            features: List of multi-level image features of shape (batch_size,
+                hidden_size, n_patches_height, n_patches_width).
+
+        Returns:
+            A tensor of shape (batch_size, hidden_size, n_patches_height,
+            n_patches_width) which is feature map of the decoder head.
+        """
+        if not isinstance(features, list) or features[0].ndim != 4:
+            raise ValueError(
+                "Input features should be a list of four (4) dimensional inputs of "
+                "shape (batch_size, hidden_size, n_patches_height, n_patches_width)."
+            )
+
+        upsampled_features = [
+            functional.interpolate(
+                input=embeddings,
+                size=features[0].shape[2:],
+                mode="bilinear",
+                align_corners=False,
+            )
+            for embeddings in features
+        ]
+        return torch.cat(upsampled_features, dim=1)
+
+    def _forward_head(self, patch_embeddings: torch.Tensor) -> torch.Tensor:
+        """Forward of the decoder head.
+
+        Args:
+            patch_embeddings: The patch embeddings tensor of shape
+                (batch_size, hidden_size, n_patches_height, n_patches_width).
+
+        Returns:
+            The logits as a tensor (batch_size, n_classes, upscale_height, upscale_width).
+        """
+        return self._layers(patch_embeddings)
+
+    def _cls_seg(
+        self,
+        logits: torch.Tensor,
+        image_size: Tuple[int, int],
+    ) -> torch.Tensor:
+        """Classify each pixel of the image.
+
+        Args:
+            logits: The decoder outputs of shape (batch_size, n_classes,
+                height, width).
+            image_size: The target image size (height, width).
+
+        Returns:
+            Tensor containing scores for all of the classes with shape
+            (batch_size, n_classes, image_height, image_width).
+        """
+        return functional.interpolate(logits, image_size, mode="bilinear")
+
+    def forward(
+        self,
+        features: List[torch.Tensor],
+        image_size: Tuple[int, int],
+    ) -> torch.Tensor:
+        """Maps the patch embeddings to a segmentation mask of the image size.
+
+        Args:
+            features: List of multi-level image features of shape (batch_size,
+                hidden_size, n_patches_height, n_patches_width).
+            image_size: The target image size (height, width).
+
+        Returns:
+            Tensor containing scores for all of the classes with shape
+            (batch_size, n_classes, image_height, image_width).
+        """
+        patch_embeddings = self._forward_features(features)
+        logits = self._forward_head(patch_embeddings)
+        return self._cls_seg(logits, image_size)
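The forward path above (interpolate and concatenate the multi-level features, apply the convolutional head, then upsample to the target image size) can be exercised end to end with the `ConvDecoder1x1` subclass added in common.py. An editorial sketch, assuming the package is installed; batch size, hidden size, and grid size are arbitrary examples:

    import torch

    from eva.vision.models.networks.decoders.segmentation import ConvDecoder1x1

    decoder = ConvDecoder1x1(in_features=384, num_classes=5)
    # A single feature level of shape (batch, hidden, n_patches_h, n_patches_w); with
    # several levels the channels are concatenated, so in_features would need to match.
    features = [torch.randn(2, 384, 14, 14)]
    masks = decoder(features, image_size=(224, 224))
    assert masks.shape == torch.Size([2, 5, 224, 224])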
eva/vision/models/networks/decoders/segmentation/linear.py
@@ -0,0 +1,125 @@
+"""Linear based decoder."""
+
+from typing import List, Tuple
+
+import torch
+from torch import nn
+from torch.nn import functional
+
+from eva.vision.models.networks.decoders import decoder
+
+
+class LinearDecoder(decoder.Decoder):
+    """Linear decoder."""
+
+    def __init__(self, layers: nn.Module) -> None:
+        """Initializes the linear based decoder head.
+
+        Here the input nn layers will be applied to the reshaped
+        features (batch_size, patch_embeddings, hidden_size) from
+        the input (batch_size, hidden_size, height, width) and then
+        unwrapped again to (batch_size, n_classes, height, width).
+
+        Args:
+            layers: The linear layers to be used as the decoder head.
+        """
+        super().__init__()
+
+        self._layers = layers
+
+    def _forward_features(self, features: List[torch.Tensor]) -> torch.Tensor:
+        """Forward function for multi-level feature maps to a single one.
+
+        It will interpolate the features and concat them into a single tensor
+        on the dimension axis of the hidden size.
+
+        Example:
+            >>> features = [torch.Tensor(16, 384, 14, 14), torch.Size(16, 384, 14, 14)]
+            >>> output = self._forward_features(features)
+            >>> assert output.shape == torch.Size([16, 768, 14, 14])
+
+        Args:
+            features: List of multi-level image features of shape (batch_size,
+                hidden_size, n_patches_height, n_patches_width).
+
+        Returns:
+            A tensor of shape (batch_size, hidden_size, n_patches_height,
+            n_patches_width) which is feature map of the decoder head.
+        """
+        if not isinstance(features, list) or features[0].ndim != 4:
+            raise ValueError(
+                "Input features should be a list of four (4) dimensional inputs of "
+                "shape (batch_size, hidden_size, n_patches_height, n_patches_width)."
+            )
+
+        upsampled_features = [
+            functional.interpolate(
+                input=embeddings,
+                size=features[0].shape[2:],
+                mode="bilinear",
+                align_corners=False,
+            )
+            for embeddings in features
+        ]
+        return torch.cat(upsampled_features, dim=1)
+
+    def _forward_head(self, patch_embeddings: torch.Tensor) -> torch.Tensor:
+        """Forward of the decoder head.
+
+        Here the following transformations will take place:
+        - (batch_size, hidden_size, n_patches_height, n_patches_width)
+        - (batch_size, hidden_size, n_patches_height * n_patches_width)
+        - (batch_size, n_patches_height * n_patches_width, hidden_size)
+        - (batch_size, n_patches_height * n_patches_width, n_classes)
+        - (batch_size, n_classes, n_patches_height, n_patches_width)
+
+        Args:
+            patch_embeddings: The patch embeddings tensor of shape
+                (batch_size, hidden_size, n_patches_height, n_patches_width).
+
+        Returns:
+            The logits as a tensor (batch_size, n_classes, n_patches_height,
+            n_patches_width).
+        """
+        batch_size, hidden_size, height, width = patch_embeddings.shape
+        embeddings_reshaped = patch_embeddings.reshape(batch_size, hidden_size, height * width)
+        logits = self._layers(embeddings_reshaped.permute(0, 2, 1))
+        return logits.permute(0, 2, 1).reshape(batch_size, -1, height, width)
+
+    def _cls_seg(
+        self,
+        logits: torch.Tensor,
+        image_size: Tuple[int, int],
+    ) -> torch.Tensor:
+        """Classify each pixel of the image.
+
+        Args:
+            logits: The decoder outputs of shape (batch_size, n_classes,
+                height, width).
+            image_size: The target image size (height, width).
+
+        Returns:
+            Tensor containing scores for all of the classes with shape
+            (batch_size, n_classes, image_height, image_width).
+        """
+        return functional.interpolate(logits, image_size, mode="bilinear")
+
+    def forward(
+        self,
+        features: List[torch.Tensor],
+        image_size: Tuple[int, int],
+    ) -> torch.Tensor:
+        """Maps the patch embeddings to a segmentation mask of the image size.
+
+        Args:
+            features: List of multi-level image features of shape (batch_size,
+                hidden_size, n_patches_height, n_patches_width).
+            image_size: The target image size (height, width).
+
+        Returns:
+            Tensor containing scores for all of the classes with shape
+            (batch_size, n_classes, image_height, image_width).
+        """
+        patch_embeddings = self._forward_features(features)
+        logits = self._forward_head(patch_embeddings)
+        return self._cls_seg(logits, image_size)
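The five shape transformations listed in the `_forward_head` docstring above can be spelled out with plain torch ops. A standalone editorial snippet with arbitrary example sizes:

    import torch
    from torch import nn

    b, c, h, w = 2, 384, 14, 14      # (batch_size, hidden_size, n_patches_height, n_patches_width)
    head = nn.Linear(c, 5)           # a decoder head with 5 output classes

    x = torch.randn(b, c, h, w)
    x = x.reshape(b, c, h * w)       # (batch_size, hidden_size, n_patches)
    x = x.permute(0, 2, 1)           # (batch_size, n_patches, hidden_size)
    logits = head(x)                 # (batch_size, n_patches, n_classes)
    logits = logits.permute(0, 2, 1).reshape(b, -1, h, w)  # (batch_size, n_classes, h, w)
    assert logits.shape == (2, 5, 14, 14)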
eva/vision/models/wrappers/from_registry.py
@@ -0,0 +1,48 @@
+"""Vision backbone helper class."""
+
+from typing import Any, Callable, Dict
+
+from typing_extensions import override
+
+from eva.core.models import wrappers
+from eva.vision.models.networks.backbones import BackboneModelRegistry
+
+
+class ModelFromRegistry(wrappers.BaseModel):
+    """Wrapper class for vision backbone models.
+
+    This class can be used by load backbones available in eva's
+    model registry by name. New backbones can be registered by using
+    the `@register_model(model_name)` decorator.
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        model_kwargs: Dict[str, Any] | None = None,
+        model_extra_kwargs: Dict[str, Any] | None = None,
+        tensor_transforms: Callable | None = None,
+    ) -> None:
+        """Initializes the model.
+
+        Args:
+            model_name: The name of the model to load.
+            model_kwargs: The arguments used for instantiating the model.
+            model_extra_kwargs: Extra arguments used for instantiating the model.
+            tensor_transforms: The transforms to apply to the output tensor
+                produced by the model.
+        """
+        super().__init__(tensor_transforms=tensor_transforms)
+
+        self._model_name = model_name
+        self._model_kwargs = model_kwargs or {}
+        self._model_extra_kwargs = model_extra_kwargs or {}
+
+        self.load_model()
+
+    @override
+    def load_model(self) -> None:
+        self._model = BackboneModelRegistry.load_model(
+            self._model_name, self._model_kwargs | self._model_extra_kwargs
+        )
+        ModelFromRegistry.__name__ = self._model_name
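ModelFromRegistry ties the wrapper API to the registry introduced earlier, so any registered backbone (the pathology entries, the universal ViTs, or the `timm/` names) can be configured by string. An editorial usage sketch; it assumes the package is installed, that the pretrained DINO weights can be downloaded, and that the wrapper forwards calls to the wrapped backbone like a regular module:

    import torch

    from eva.vision.models.wrappers.from_registry import ModelFromRegistry

    backbone = ModelFromRegistry(
        model_name="universal/vit_small_patch16_224_dino",  # registered in universal/vit.py above
        model_kwargs={"out_indices": None},
    )
    with torch.no_grad():
        # The wrapper is expected to dispatch to the wrapped ViT-S backbone.
        embeddings = backbone(torch.randn(1, 3, 224, 224))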