kaiko-eva 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (159)
  1. eva/core/callbacks/__init__.py +2 -2
  2. eva/core/callbacks/writers/__init__.py +6 -3
  3. eva/core/callbacks/writers/embeddings/__init__.py +6 -0
  4. eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
  5. eva/core/callbacks/writers/embeddings/base.py +192 -0
  6. eva/core/callbacks/writers/embeddings/classification.py +117 -0
  7. eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
  8. eva/core/callbacks/writers/embeddings/typings.py +38 -0
  9. eva/core/data/datasets/__init__.py +2 -2
  10. eva/core/data/datasets/classification/__init__.py +8 -0
  11. eva/core/data/datasets/classification/embeddings.py +34 -0
  12. eva/core/data/datasets/{embeddings/classification → classification}/multi_embeddings.py +13 -9
  13. eva/core/data/datasets/{embeddings/base.py → embeddings.py} +47 -32
  14. eva/core/data/splitting/__init__.py +6 -0
  15. eva/core/data/splitting/random.py +41 -0
  16. eva/core/data/splitting/stratified.py +56 -0
  17. eva/core/loggers/experimental_loggers.py +2 -2
  18. eva/core/loggers/log/__init__.py +3 -2
  19. eva/core/loggers/log/image.py +71 -0
  20. eva/core/loggers/log/parameters.py +10 -0
  21. eva/core/loggers/loggers.py +6 -0
  22. eva/core/metrics/__init__.py +6 -2
  23. eva/core/metrics/defaults/__init__.py +10 -3
  24. eva/core/metrics/defaults/classification/__init__.py +1 -1
  25. eva/core/metrics/defaults/classification/binary.py +0 -9
  26. eva/core/metrics/defaults/classification/multiclass.py +0 -8
  27. eva/core/metrics/defaults/segmentation/__init__.py +5 -0
  28. eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
  29. eva/core/metrics/generalized_dice.py +59 -0
  30. eva/core/metrics/mean_iou.py +120 -0
  31. eva/core/metrics/structs/schemas.py +3 -1
  32. eva/core/models/__init__.py +3 -1
  33. eva/core/models/modules/head.py +10 -4
  34. eva/core/models/modules/typings.py +14 -1
  35. eva/core/models/modules/utils/batch_postprocess.py +37 -5
  36. eva/core/models/networks/__init__.py +1 -2
  37. eva/core/models/networks/mlp.py +2 -2
  38. eva/core/models/transforms/__init__.py +6 -0
  39. eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
  40. eva/core/models/transforms/extract_patch_features.py +47 -0
  41. eva/core/models/wrappers/__init__.py +13 -0
  42. eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
  43. eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
  44. eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
  45. eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
  46. eva/core/trainers/functional.py +1 -0
  47. eva/core/utils/__init__.py +6 -0
  48. eva/core/utils/clone.py +27 -0
  49. eva/core/utils/memory.py +28 -0
  50. eva/core/utils/operations.py +26 -0
  51. eva/core/utils/parser.py +20 -0
  52. eva/vision/__init__.py +2 -2
  53. eva/vision/callbacks/__init__.py +5 -0
  54. eva/vision/callbacks/loggers/__init__.py +5 -0
  55. eva/vision/callbacks/loggers/batch/__init__.py +5 -0
  56. eva/vision/callbacks/loggers/batch/base.py +130 -0
  57. eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
  58. eva/vision/data/datasets/__init__.py +30 -3
  59. eva/vision/data/datasets/_validators.py +15 -2
  60. eva/vision/data/datasets/classification/__init__.py +12 -1
  61. eva/vision/data/datasets/classification/bach.py +10 -15
  62. eva/vision/data/datasets/classification/base.py +17 -24
  63. eva/vision/data/datasets/classification/camelyon16.py +244 -0
  64. eva/vision/data/datasets/classification/crc.py +10 -15
  65. eva/vision/data/datasets/classification/mhist.py +10 -15
  66. eva/vision/data/datasets/classification/panda.py +184 -0
  67. eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
  68. eva/vision/data/datasets/classification/wsi.py +105 -0
  69. eva/vision/data/datasets/segmentation/__init__.py +15 -2
  70. eva/vision/data/datasets/segmentation/_utils.py +38 -0
  71. eva/vision/data/datasets/segmentation/base.py +16 -17
  72. eva/vision/data/datasets/segmentation/bcss.py +236 -0
  73. eva/vision/data/datasets/segmentation/consep.py +156 -0
  74. eva/vision/data/datasets/segmentation/embeddings.py +34 -0
  75. eva/vision/data/datasets/segmentation/lits.py +178 -0
  76. eva/vision/data/datasets/segmentation/monusac.py +236 -0
  77. eva/vision/data/datasets/segmentation/{total_segmentator.py → total_segmentator_2d.py} +130 -36
  78. eva/vision/data/datasets/wsi.py +187 -0
  79. eva/vision/data/transforms/__init__.py +3 -2
  80. eva/vision/data/transforms/common/__init__.py +2 -1
  81. eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
  82. eva/vision/data/transforms/common/resize_and_crop.py +6 -7
  83. eva/vision/data/transforms/normalization/__init__.py +6 -0
  84. eva/vision/data/transforms/normalization/clamp.py +43 -0
  85. eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
  86. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
  87. eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
  88. eva/vision/data/wsi/__init__.py +16 -0
  89. eva/vision/data/wsi/backends/__init__.py +69 -0
  90. eva/vision/data/wsi/backends/base.py +115 -0
  91. eva/vision/data/wsi/backends/openslide.py +73 -0
  92. eva/vision/data/wsi/backends/pil.py +52 -0
  93. eva/vision/data/wsi/backends/tiffslide.py +42 -0
  94. eva/vision/data/wsi/patching/__init__.py +6 -0
  95. eva/vision/data/wsi/patching/coordinates.py +98 -0
  96. eva/vision/data/wsi/patching/mask.py +123 -0
  97. eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
  98. eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
  99. eva/vision/data/wsi/patching/samplers/base.py +48 -0
  100. eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
  101. eva/vision/data/wsi/patching/samplers/grid.py +47 -0
  102. eva/vision/data/wsi/patching/samplers/random.py +41 -0
  103. eva/vision/losses/__init__.py +5 -0
  104. eva/vision/losses/dice.py +40 -0
  105. eva/vision/models/__init__.py +4 -2
  106. eva/vision/models/modules/__init__.py +5 -0
  107. eva/vision/models/modules/semantic_segmentation.py +161 -0
  108. eva/vision/models/networks/__init__.py +1 -2
  109. eva/vision/models/networks/backbones/__init__.py +6 -0
  110. eva/vision/models/networks/backbones/_utils.py +39 -0
  111. eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
  112. eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
  113. eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
  114. eva/vision/models/networks/backbones/pathology/histai.py +46 -0
  115. eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
  116. eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
  117. eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
  118. eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
  119. eva/vision/models/networks/backbones/registry.py +47 -0
  120. eva/vision/models/networks/backbones/timm/__init__.py +5 -0
  121. eva/vision/models/networks/backbones/timm/backbones.py +54 -0
  122. eva/vision/models/networks/backbones/universal/__init__.py +8 -0
  123. eva/vision/models/networks/backbones/universal/vit.py +54 -0
  124. eva/vision/models/networks/decoders/__init__.py +6 -0
  125. eva/vision/models/networks/decoders/decoder.py +7 -0
  126. eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
  127. eva/vision/models/networks/decoders/segmentation/common.py +74 -0
  128. eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
  129. eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
  130. eva/vision/models/wrappers/__init__.py +6 -0
  131. eva/vision/models/wrappers/from_registry.py +48 -0
  132. eva/vision/models/wrappers/from_timm.py +68 -0
  133. eva/vision/utils/colormap.py +77 -0
  134. eva/vision/utils/convert.py +56 -13
  135. eva/vision/utils/io/__init__.py +10 -4
  136. eva/vision/utils/io/image.py +21 -2
  137. eva/vision/utils/io/mat.py +36 -0
  138. eva/vision/utils/io/nifti.py +33 -12
  139. eva/vision/utils/io/text.py +10 -3
  140. kaiko_eva-0.1.1.dist-info/METADATA +553 -0
  141. kaiko_eva-0.1.1.dist-info/RECORD +205 -0
  142. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.1.dist-info}/WHEEL +1 -1
  143. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.1.dist-info}/entry_points.txt +2 -0
  144. eva/.DS_Store +0 -0
  145. eva/core/callbacks/writers/embeddings.py +0 -169
  146. eva/core/callbacks/writers/typings.py +0 -23
  147. eva/core/data/datasets/embeddings/__init__.py +0 -13
  148. eva/core/data/datasets/embeddings/classification/__init__.py +0 -10
  149. eva/core/data/datasets/embeddings/classification/embeddings.py +0 -66
  150. eva/core/models/networks/transforms/__init__.py +0 -5
  151. eva/core/models/networks/wrappers/__init__.py +0 -8
  152. eva/vision/models/.DS_Store +0 -0
  153. eva/vision/models/networks/.DS_Store +0 -0
  154. eva/vision/models/networks/postprocesses/__init__.py +0 -5
  155. eva/vision/models/networks/postprocesses/cls.py +0 -25
  156. kaiko_eva-0.0.2.dist-info/METADATA +0 -431
  157. kaiko_eva-0.0.2.dist-info/RECORD +0 -127
  158. /eva/core/models/{networks → wrappers}/_utils.py +0 -0
  159. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.1.dist-info}/licenses/LICENSE +0 -0

eva/vision/models/networks/backbones/pathology/histai.py
@@ -0,0 +1,46 @@
+"""Pathology FMs from hist.ai."""
+
+from typing import Tuple
+
+from torch import nn
+
+from eva.vision.models.networks.backbones import _utils
+from eva.vision.models.networks.backbones.registry import register_model
+
+
+@register_model("pathology/histai_hibou_b")
+def histai_hibou_b(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
+    """Initializes the hibou-B pathology FM by hist.ai (https://huggingface.co/histai/hibou-B).
+
+    Args:
+        out_indices: Whether and which multi-level patch embeddings to return.
+            Currently only out_indices=1 is supported.
+
+    Returns:
+        The model instance.
+    """
+    return _utils.load_hugingface_model(
+        model_name="histai/hibou-B",
+        out_indices=out_indices,
+        model_kwargs={"trust_remote_code": True},
+        transform_args={"ignore_remaining_dims": True} if out_indices is not None else None,
+    )
+
+
+@register_model("pathology/histai_hibou_l")
+def histai_hibou_l(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
+    """Initializes the hibou-L pathology FM by hist.ai (https://huggingface.co/histai/hibou-L).
+
+    Args:
+        out_indices: Whether and which multi-level patch embeddings to return.
+            Currently only out_indices=1 is supported.
+
+    Returns:
+        The model instance.
+    """
+    return _utils.load_hugingface_model(
+        model_name="histai/hibou-L",
+        out_indices=out_indices,
+        model_kwargs={"trust_remote_code": True},
+        transform_args={"ignore_remaining_dims": True} if out_indices is not None else None,
+    )

eva/vision/models/networks/backbones/pathology/kaiko.py
@@ -0,0 +1,123 @@
+"""Pathology FMs from kaiko.ai."""
+
+from typing import Tuple
+
+import torch
+from torch import nn
+
+from eva.vision.models.networks.backbones.registry import register_model
+
+
+@register_model("pathology/kaiko_vits16")
+def kaiko_vits16(
+    dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+) -> nn.Module:
+    """Initializes the ViTS-16 pathology FM by kaiko.ai.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The model instance.
+    """
+    return torch.hub.load(  # type: ignore
+        repo_or_dir="kaiko-ai/towards_large_pathology_fms",
+        model="vits16",
+        trust_repo=True,
+        dynamic_img_size=dynamic_img_size,
+        out_indices=out_indices,
+    )
+
+
+@register_model("pathology/kaiko_vits8")
+def kaiko_vits8(
+    dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+) -> nn.Module:
+    """Initializes the ViTS-8 pathology FM by kaiko.ai.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The model instance.
+    """
+    return torch.hub.load(  # type: ignore
+        repo_or_dir="kaiko-ai/towards_large_pathology_fms",
+        model="vits8",
+        trust_repo=True,
+        dynamic_img_size=dynamic_img_size,
+        out_indices=out_indices,
+    )
+
+
+@register_model("pathology/kaiko_vitb16")
+def kaiko_vitb16(
+    dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+) -> nn.Module:
+    """Initializes the ViTB-16 pathology FM by kaiko.ai.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The model instance.
+    """
+    return torch.hub.load(  # type: ignore
+        repo_or_dir="kaiko-ai/towards_large_pathology_fms",
+        model="vitb16",
+        trust_repo=True,
+        dynamic_img_size=dynamic_img_size,
+        out_indices=out_indices,
+    )
+
+
+@register_model("pathology/kaiko_vitb8")
+def kaiko_vitb8(
+    dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+) -> nn.Module:
+    """Initializes the ViTB-8 pathology FM by kaiko.ai.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The model instance.
+    """
+    return torch.hub.load(  # type: ignore
+        repo_or_dir="kaiko-ai/towards_large_pathology_fms",
+        model="vitb8",
+        trust_repo=True,
+        dynamic_img_size=dynamic_img_size,
+        out_indices=out_indices,
+    )
+
+
+@register_model("pathology/kaiko_vitl14")
+def kaiko_vitl14(
+    dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+) -> nn.Module:
+    """Initializes the ViTL-14 pathology FM by kaiko.ai.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The model instance.
+    """
+    return torch.hub.load(  # type: ignore
+        repo_or_dir="kaiko-ai/towards_large_pathology_fms",
+        model="vitl14",
+        trust_repo=True,
+        dynamic_img_size=dynamic_img_size,
+        out_indices=out_indices,
+    )
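
The registered functions above are thin pass-throughs to torch.hub, so the same backbone can be pulled directly; a minimal usage sketch with the same arguments as kaiko_vits16 (weights download on first call, network access assumed):

    import torch

    # Load the kaiko.ai ViT-S/16 backbone directly via torch.hub, mirroring
    # the call inside kaiko_vits16 above.
    backbone = torch.hub.load(
        repo_or_dir="kaiko-ai/towards_large_pathology_fms",
        model="vits16",
        trust_repo=True,
        dynamic_img_size=True,
        out_indices=None,
    )
    backbone.eval()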

eva/vision/models/networks/backbones/pathology/lunit.py
@@ -0,0 +1,68 @@
+"""Pathology FMs from Lunit.
+
+Source: https://github.com/lunit-io/benchmark-ssl-pathology/releases
+
+For training the vit-s models the following standardization parameters were used:
+
+mean: [ 0.70322989, 0.53606487, 0.66096631 ]
+std: [ 0.21716536, 0.26081574, 0.20723464 ]
+"""
+
+from typing import Tuple
+
+from torch import nn
+
+from eva.vision.models import wrappers
+from eva.vision.models.networks.backbones.registry import register_model
+
+VITS_URL_PREFIX = (
+    "https://github.com/lunit-io/benchmark-ssl-pathology/releases/download/pretrained-weights"
+)
+
+
+@register_model("pathology/lunit_vits16")
+def lunit_vits16(
+    dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+) -> nn.Module:
+    """Initializes the ViTS-16 pathology FM by lunit.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The model instance.
+    """
+    return wrappers.TimmModel(
+        model_name="vit_small_patch16_224.dino",
+        out_indices=out_indices,
+        model_kwargs={
+            "dynamic_img_size": dynamic_img_size,
+        },
+        checkpoint_path=f"{VITS_URL_PREFIX}/dino_vit_small_patch16_ep200.torch",
+    )
+
+
+@register_model("pathology/lunit_vits8")
+def lunit_vits8(
+    dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+) -> nn.Module:
+    """Initializes the ViTS-8 pathology FM by lunit.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The model instance.
+    """
+    return wrappers.TimmModel(
+        model_name="vit_small_patch8_224.dino",
+        out_indices=out_indices,
+        model_kwargs={
+            "dynamic_img_size": dynamic_img_size,
+        },
+        checkpoint_path=f"{VITS_URL_PREFIX}/dino_vit_small_patch8_ep200.torch",
+    )
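
The module docstring above records the standardization statistics used to train Lunit's ViT-S models; a minimal preprocessing sketch applying them, assuming torchvision is available:

    import torch
    from torchvision import transforms

    # Input standardization matching the mean/std quoted in the module docstring.
    normalize = transforms.Normalize(
        mean=[0.70322989, 0.53606487, 0.66096631],
        std=[0.21716536, 0.26081574, 0.20723464],
    )
    patch = normalize(torch.rand(3, 224, 224))  # dummy RGB patch in [0, 1]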

eva/vision/models/networks/backbones/pathology/mahmood.py
@@ -0,0 +1,62 @@
+"""Pathology FMs from MahmoodLab."""
+
+import os
+from pathlib import Path
+from typing import Tuple
+
+import huggingface_hub
+from loguru import logger
+from torch import nn
+
+from eva.vision.models import wrappers
+from eva.vision.models.networks.backbones.registry import register_model
+
+
+@register_model("pathology/mahmood_uni")
+def mahmood_uni(
+    dynamic_img_size: bool = True,
+    out_indices: int | Tuple[int, ...] | None = None,
+    hf_token: str | None = None,
+    download_dir: str = os.path.join(str(Path.home()), ".cache/eva"),
+) -> nn.Module:
+    """Initializes UNI model from MahmoodLab.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+        hf_token: HuggingFace token to download the model.
+        download_dir: Directory to download the model checkpoint.
+
+    Returns:
+        The model instance.
+    """
+    token = hf_token or os.environ.get("HF_TOKEN")
+    if not token:
+        raise ValueError(
+            "Please provide a HuggingFace token to download the model. "
+            "You can either pass it as an argument or set the env variable HF_TOKEN."
+        )
+
+    checkpoint_path = os.path.join(download_dir, "pytorch_model.bin")
+
+    if not os.path.exists(checkpoint_path):
+        logger.info(f"Downloading the model checkpoint to {download_dir} ...")
+        os.makedirs(download_dir, exist_ok=True)
+        huggingface_hub.login(token=token)
+        huggingface_hub.hf_hub_download(
+            "MahmoodLab/UNI",
+            filename="pytorch_model.bin",
+            local_dir=download_dir,
+            force_download=True,
+        )
+
+    return wrappers.TimmModel(
+        model_name="vit_large_patch16_224",
+        out_indices=out_indices,
+        model_kwargs={
+            "init_values": 1e-5,
+            "dynamic_img_size": dynamic_img_size,
+        },
+        checkpoint_path=checkpoint_path,
+    )
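
mahmood_uni refuses to run without a HuggingFace token, since the UNI weights are gated; a hypothetical usage sketch (the token value is a placeholder, and the import path is inferred from the file list above):

    import os

    from eva.vision.models.networks.backbones.pathology import mahmood

    # "<your-token>" is a placeholder; alternatively pass hf_token= directly.
    os.environ["HF_TOKEN"] = "<your-token>"
    backbone = mahmood.mahmood_uni()  # downloads pytorch_model.bin to ~/.cache/eva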

eva/vision/models/networks/backbones/pathology/owkin.py
@@ -0,0 +1,22 @@
+"""Pathology FMs from owkin."""
+
+from typing import Tuple
+
+from torch import nn
+
+from eva.vision.models.networks.backbones import _utils
+from eva.vision.models.networks.backbones.registry import register_model
+
+
+@register_model("pathology/owkin_phikon")
+def owkin_phikon(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
+    """Initializes the phikon pathology FM by owkin (https://huggingface.co/owkin/phikon).
+
+    Args:
+        out_indices: Whether and which multi-level patch embeddings to return.
+            Currently only out_indices=1 is supported.
+
+    Returns:
+        The model instance.
+    """
+    return _utils.load_hugingface_model(model_name="owkin/phikon", out_indices=out_indices)

eva/vision/models/networks/backbones/registry.py
@@ -0,0 +1,47 @@
+"""Backbone Model Registry."""
+
+from typing import Any, Callable, Dict, List
+
+import torch.nn as nn
+
+
+class BackboneModelRegistry:
+    """A model registry for accessing backbone models by name."""
+
+    _registry: Dict[str, Callable[..., nn.Module]] = {}
+
+    @classmethod
+    def register(cls, name: str) -> Callable:
+        """Decorator to register a new model."""
+
+        def decorator(model_fn: Callable[..., nn.Module]) -> Callable[..., nn.Module]:
+            if name in cls._registry:
+                raise ValueError(f"Model {name} is already registered.")
+            cls._registry[name] = model_fn
+            return model_fn
+
+        return decorator
+
+    @classmethod
+    def get(cls, model_name: str) -> Callable[..., nn.Module]:
+        """Gets a model function from the registry."""
+        if model_name not in cls._registry:
+            raise ValueError(f"Model {model_name} not found in the registry.")
+        return cls._registry[model_name]
+
+    @classmethod
+    def load_model(cls, model_name: str, model_kwargs: Dict[str, Any] | None = None) -> nn.Module:
+        """Loads & initializes a model class from the registry."""
+        model_fn = cls.get(model_name)
+        return model_fn(**(model_kwargs or {}))
+
+    @classmethod
+    def list_models(cls) -> List[str]:
+        """List all models in the registry."""
+        register_models = [name for name in cls._registry.keys() if not name.startswith("timm")]
+        return register_models + ["timm/<model_name>"]
+
+
+def register_model(name: str) -> Callable:
+    """Simple decorator to register a model."""
+    return BackboneModelRegistry.register(name)
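
BackboneModelRegistry keys model factory functions by name; a minimal sketch registering and loading a toy backbone through the decorator and load_model:

    from torch import nn

    from eva.vision.models.networks.backbones.registry import (
        BackboneModelRegistry,
        register_model,
    )

    # Register a toy backbone under a new name; duplicate names raise ValueError.
    @register_model("custom/identity")
    def identity_backbone() -> nn.Module:
        return nn.Identity()

    # Look up and instantiate by name; model_kwargs are forwarded to the factory.
    backbone = BackboneModelRegistry.load_model("custom/identity")
    print(BackboneModelRegistry.list_models())  # [..., "timm/<model_name>"]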

eva/vision/models/networks/backbones/timm/__init__.py
@@ -0,0 +1,5 @@
+"""timm backbones API."""
+
+from eva.vision.models.networks.backbones.timm.backbones import timm_model
+
+__all__ = ["timm_model"]

eva/vision/models/networks/backbones/timm/backbones.py
@@ -0,0 +1,54 @@
+"""timm backbones."""
+
+import functools
+from typing import Tuple
+
+import timm
+from loguru import logger
+from torch import nn
+
+from eva.vision.models import wrappers
+from eva.vision.models.networks.backbones.registry import BackboneModelRegistry
+
+
+def timm_model(
+    model_name: str,
+    checkpoint_path: str | None = None,
+    pretrained: bool = False,
+    dynamic_img_size: bool = True,
+    out_indices: int | Tuple[int, ...] | None = None,
+    **kwargs,
+) -> nn.Module:
+    """Initializes any ViT model from timm with weights from a specified checkpoint.
+
+    Args:
+        model_name: The name of the model to load.
+        checkpoint_path: The path to the checkpoint file.
+        pretrained: If set to `True`, load pretrained ImageNet-1k weights.
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+        **kwargs: Additional arguments to pass to the model.
+
+    Returns:
+        The ViT model instance.
+    """
+    logger.info(
+        f"Loading timm model {model_name}"
+        + (f" using checkpoint {checkpoint_path}" if checkpoint_path else "")
+    )
+    return wrappers.TimmModel(
+        model_name=model_name,
+        checkpoint_path=checkpoint_path or "",
+        pretrained=pretrained,
+        out_indices=out_indices,
+        model_kwargs={"dynamic_img_size": dynamic_img_size, **kwargs},
+    )
+
+
+BackboneModelRegistry._registry.update(
+    {
+        f"timm/{model_name}": functools.partial(timm_model, model_name=model_name)
+        for model_name in timm.list_models()
+    }
+)
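
The registry update above exposes every timm architecture under a timm/<model_name> key; a usage sketch, assuming the backbones package has been imported so the entries are populated:

    import eva.vision.models.networks.backbones  # populates the timm/ entries

    from eva.vision.models.networks.backbones.registry import BackboneModelRegistry

    # model_kwargs are forwarded to timm_model via functools.partial.
    backbone = BackboneModelRegistry.load_model(
        "timm/vit_small_patch16_224",
        model_kwargs={"pretrained": False},
    )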

eva/vision/models/networks/backbones/universal/__init__.py
@@ -0,0 +1,8 @@
+"""Universal Vision Model Backbones API."""
+
+from eva.vision.models.networks.backbones.universal.vit import (
+    vit_small_patch16_224_dino,
+    vit_small_patch16_224_random,
+)
+
+__all__ = ["vit_small_patch16_224_dino", "vit_small_patch16_224_random"]

eva/vision/models/networks/backbones/universal/vit.py
@@ -0,0 +1,54 @@
+"""Vision Transformers base universal backbones."""
+
+from typing import Tuple
+
+import timm
+from torch import nn
+
+from eva.vision.models.networks.backbones.registry import register_model
+
+
+@register_model("universal/vit_small_patch16_224_random")
+def vit_small_patch16_224_random(
+    dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+) -> nn.Module:
+    """Initializes a ViTS-16 baseline model with random weights.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The torch ViTS-16 based foundation model.
+    """
+    return timm.create_model(
+        model_name="vit_small_patch16_224.dino",
+        pretrained=False,
+        features_only=out_indices is not None,
+        out_indices=out_indices,
+        dynamic_img_size=dynamic_img_size,
+    )
+
+
+@register_model("universal/vit_small_patch16_224_dino")
+def vit_small_patch16_224_dino(
+    dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+) -> nn.Module:
+    """Initializes a ViTS-16 baseline model pretrained w/ DINO.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The torch ViTS-16 based foundation model.
+    """
+    return timm.create_model(
+        model_name="vit_small_patch16_224.dino",
+        pretrained=True,
+        features_only=out_indices is not None,
+        out_indices=out_indices,
+        dynamic_img_size=dynamic_img_size,
+    )
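
When out_indices is set, the model is built with features_only=True, so the forward pass returns a list of feature maps rather than a single embedding; a sketch, assuming a timm version whose ViTs support features_only:

    import torch

    from eva.vision.models.networks.backbones.universal import vit

    # out_indices=1 asks timm for the last feature map only (features_only=True).
    backbone = vit.vit_small_patch16_224_random(out_indices=1)
    features = backbone(torch.rand(1, 3, 224, 224))
    # Expected: a list with one NCHW tensor, e.g. torch.Size([1, 384, 14, 14]).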

eva/vision/models/networks/decoders/__init__.py
@@ -0,0 +1,6 @@
+"""Decoder heads API."""
+
+from eva.vision.models.networks.decoders import segmentation
+from eva.vision.models.networks.decoders.decoder import Decoder
+
+__all__ = ["segmentation", "Decoder"]

eva/vision/models/networks/decoders/decoder.py
@@ -0,0 +1,7 @@
+"""Semantic segmentation decoder base class."""
+
+from torch import nn
+
+
+class Decoder(nn.Module):
+    """Semantic segmentation decoder base class."""

eva/vision/models/networks/decoders/segmentation/__init__.py
@@ -0,0 +1,11 @@
+"""Segmentation decoder heads API."""
+
+from eva.vision.models.networks.decoders.segmentation.common import (
+    ConvDecoder1x1,
+    ConvDecoderMS,
+    SingleLinearDecoder,
+)
+from eva.vision.models.networks.decoders.segmentation.conv2d import ConvDecoder
+from eva.vision.models.networks.decoders.segmentation.linear import LinearDecoder
+
+__all__ = ["ConvDecoder1x1", "ConvDecoderMS", "SingleLinearDecoder", "ConvDecoder", "LinearDecoder"]

eva/vision/models/networks/decoders/segmentation/common.py
@@ -0,0 +1,74 @@
+"""Common semantic segmentation decoders.
+
+This module contains implementations of different types of decoder models
+used in semantic segmentation. These decoders convert the high-level features
+output by an encoder into pixel-wise predictions for segmentation tasks.
+"""
+
+from torch import nn
+
+from eva.vision.models.networks.decoders.segmentation import conv2d, linear
+
+
+class ConvDecoder1x1(conv2d.ConvDecoder):
+    """A convolutional decoder with a single 1x1 convolutional layer."""
+
+    def __init__(self, in_features: int, num_classes: int) -> None:
+        """Initializes the decoder.
+
+        Args:
+            in_features: The hidden dimension size of the embeddings.
+            num_classes: Number of output classes as channels.
+        """
+        super().__init__(
+            layers=nn.Conv2d(
+                in_channels=in_features,
+                out_channels=num_classes,
+                kernel_size=(1, 1),
+            ),
+        )
+
+
+class ConvDecoderMS(conv2d.ConvDecoder):
+    """A multi-stage convolutional decoder with upsampling and convolutional layers.
+
+    This decoder applies a series of upsampling and convolutional layers to transform
+    the input features into output predictions with the desired spatial resolution.
+
+    This decoder is based on the `+ms` segmentation decoder from DINOv2
+    (https://arxiv.org/pdf/2304.07193)
+    """
+
+    def __init__(self, in_features: int, num_classes: int) -> None:
+        """Initializes the decoder.
+
+        Args:
+            in_features: The hidden dimension size of the embeddings.
+            num_classes: Number of output classes as channels.
+        """
+        super().__init__(
+            layers=nn.Sequential(
+                nn.Upsample(scale_factor=2),
+                nn.Conv2d(in_features, 64, kernel_size=(3, 3), padding=(1, 1)),
+                nn.Upsample(scale_factor=2),
+                nn.Conv2d(64, num_classes, kernel_size=(3, 3), padding=(1, 1)),
+            ),
+        )
+
+
+class SingleLinearDecoder(linear.LinearDecoder):
+    """A simple linear decoder with a single fully connected layer."""
+
+    def __init__(self, in_features: int, num_classes: int) -> None:
+        """Initializes the decoder.
+
+        Args:
+            in_features: The hidden dimension size of the embeddings.
+            num_classes: Number of output classes as channels.
+        """
+        super().__init__(
+            layers=nn.Linear(
+                in_features=in_features,
+                out_features=num_classes,
+            ),
+        )
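
ConvDecoderMS stacks two 2x-upsample/conv stages, so each forward pass quadruples the spatial resolution of the patch-embedding grid; a standalone sketch of that layer stack (384 input features and 5 classes are assumed example values):

    import torch
    from torch import nn

    # The `+ms` layer stack in isolation: two 2x upsampling stages turn a
    # 14x14 patch grid into a 56x56 logit map.
    layers = nn.Sequential(
        nn.Upsample(scale_factor=2),
        nn.Conv2d(384, 64, kernel_size=(3, 3), padding=(1, 1)),
        nn.Upsample(scale_factor=2),
        nn.Conv2d(64, 5, kernel_size=(3, 3), padding=(1, 1)),
    )
    logits = layers(torch.rand(1, 384, 14, 14))
    print(logits.shape)  # torch.Size([1, 5, 56, 56])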