kaiko-eva 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (168)
  1. eva/core/callbacks/__init__.py +3 -2
  2. eva/core/callbacks/config.py +143 -0
  3. eva/core/callbacks/writers/__init__.py +6 -3
  4. eva/core/callbacks/writers/embeddings/__init__.py +6 -0
  5. eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
  6. eva/core/callbacks/writers/embeddings/base.py +192 -0
  7. eva/core/callbacks/writers/embeddings/classification.py +117 -0
  8. eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
  9. eva/core/callbacks/writers/embeddings/typings.py +38 -0
  10. eva/core/data/datasets/__init__.py +10 -2
  11. eva/core/data/datasets/classification/__init__.py +5 -2
  12. eva/core/data/datasets/classification/embeddings.py +15 -135
  13. eva/core/data/datasets/classification/multi_embeddings.py +110 -0
  14. eva/core/data/datasets/embeddings.py +167 -0
  15. eva/core/data/splitting/__init__.py +6 -0
  16. eva/core/data/splitting/random.py +41 -0
  17. eva/core/data/splitting/stratified.py +56 -0
  18. eva/core/data/transforms/__init__.py +3 -1
  19. eva/core/data/transforms/padding/__init__.py +5 -0
  20. eva/core/data/transforms/padding/pad_2d_tensor.py +38 -0
  21. eva/core/data/transforms/sampling/__init__.py +5 -0
  22. eva/core/data/transforms/sampling/sample_from_axis.py +40 -0
  23. eva/core/loggers/__init__.py +7 -0
  24. eva/core/loggers/dummy.py +38 -0
  25. eva/core/loggers/experimental_loggers.py +8 -0
  26. eva/core/loggers/log/__init__.py +6 -0
  27. eva/core/loggers/log/image.py +71 -0
  28. eva/core/loggers/log/parameters.py +74 -0
  29. eva/core/loggers/log/utils.py +13 -0
  30. eva/core/loggers/loggers.py +6 -0
  31. eva/core/metrics/__init__.py +6 -2
  32. eva/core/metrics/defaults/__init__.py +10 -3
  33. eva/core/metrics/defaults/classification/__init__.py +1 -1
  34. eva/core/metrics/defaults/classification/binary.py +0 -9
  35. eva/core/metrics/defaults/classification/multiclass.py +0 -8
  36. eva/core/metrics/defaults/segmentation/__init__.py +5 -0
  37. eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
  38. eva/core/metrics/generalized_dice.py +59 -0
  39. eva/core/metrics/mean_iou.py +120 -0
  40. eva/core/metrics/structs/schemas.py +3 -1
  41. eva/core/models/__init__.py +3 -1
  42. eva/core/models/modules/head.py +16 -15
  43. eva/core/models/modules/module.py +25 -1
  44. eva/core/models/modules/typings.py +14 -1
  45. eva/core/models/modules/utils/batch_postprocess.py +37 -5
  46. eva/core/models/networks/__init__.py +1 -2
  47. eva/core/models/networks/mlp.py +2 -2
  48. eva/core/models/transforms/__init__.py +6 -0
  49. eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
  50. eva/core/models/transforms/extract_patch_features.py +47 -0
  51. eva/core/models/wrappers/__init__.py +13 -0
  52. eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
  53. eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
  54. eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
  55. eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
  56. eva/core/trainers/_recorder.py +69 -7
  57. eva/core/trainers/functional.py +23 -5
  58. eva/core/trainers/trainer.py +20 -6
  59. eva/core/utils/__init__.py +6 -0
  60. eva/core/utils/clone.py +27 -0
  61. eva/core/utils/memory.py +28 -0
  62. eva/core/utils/operations.py +26 -0
  63. eva/core/utils/parser.py +20 -0
  64. eva/vision/__init__.py +2 -2
  65. eva/vision/callbacks/__init__.py +5 -0
  66. eva/vision/callbacks/loggers/__init__.py +5 -0
  67. eva/vision/callbacks/loggers/batch/__init__.py +5 -0
  68. eva/vision/callbacks/loggers/batch/base.py +130 -0
  69. eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
  70. eva/vision/data/datasets/__init__.py +24 -4
  71. eva/vision/data/datasets/_utils.py +3 -3
  72. eva/vision/data/datasets/_validators.py +15 -2
  73. eva/vision/data/datasets/classification/__init__.py +6 -2
  74. eva/vision/data/datasets/classification/bach.py +10 -15
  75. eva/vision/data/datasets/classification/base.py +17 -24
  76. eva/vision/data/datasets/classification/camelyon16.py +244 -0
  77. eva/vision/data/datasets/classification/crc.py +10 -15
  78. eva/vision/data/datasets/classification/mhist.py +10 -15
  79. eva/vision/data/datasets/classification/panda.py +184 -0
  80. eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
  81. eva/vision/data/datasets/classification/wsi.py +105 -0
  82. eva/vision/data/datasets/segmentation/__init__.py +15 -2
  83. eva/vision/data/datasets/segmentation/_utils.py +38 -0
  84. eva/vision/data/datasets/segmentation/base.py +31 -47
  85. eva/vision/data/datasets/segmentation/bcss.py +236 -0
  86. eva/vision/data/datasets/segmentation/consep.py +156 -0
  87. eva/vision/data/datasets/segmentation/embeddings.py +34 -0
  88. eva/vision/data/datasets/segmentation/lits.py +178 -0
  89. eva/vision/data/datasets/segmentation/monusac.py +236 -0
  90. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +325 -0
  91. eva/vision/data/datasets/wsi.py +187 -0
  92. eva/vision/data/transforms/__init__.py +3 -2
  93. eva/vision/data/transforms/common/__init__.py +2 -1
  94. eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
  95. eva/vision/data/transforms/common/resize_and_crop.py +6 -7
  96. eva/vision/data/transforms/normalization/__init__.py +6 -0
  97. eva/vision/data/transforms/normalization/clamp.py +43 -0
  98. eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
  99. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
  100. eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
  101. eva/vision/data/wsi/__init__.py +16 -0
  102. eva/vision/data/wsi/backends/__init__.py +69 -0
  103. eva/vision/data/wsi/backends/base.py +115 -0
  104. eva/vision/data/wsi/backends/openslide.py +73 -0
  105. eva/vision/data/wsi/backends/pil.py +52 -0
  106. eva/vision/data/wsi/backends/tiffslide.py +42 -0
  107. eva/vision/data/wsi/patching/__init__.py +6 -0
  108. eva/vision/data/wsi/patching/coordinates.py +98 -0
  109. eva/vision/data/wsi/patching/mask.py +123 -0
  110. eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
  111. eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
  112. eva/vision/data/wsi/patching/samplers/base.py +48 -0
  113. eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
  114. eva/vision/data/wsi/patching/samplers/grid.py +47 -0
  115. eva/vision/data/wsi/patching/samplers/random.py +41 -0
  116. eva/vision/losses/__init__.py +5 -0
  117. eva/vision/losses/dice.py +40 -0
  118. eva/vision/models/__init__.py +4 -2
  119. eva/vision/models/modules/__init__.py +5 -0
  120. eva/vision/models/modules/semantic_segmentation.py +161 -0
  121. eva/vision/models/networks/__init__.py +1 -2
  122. eva/vision/models/networks/backbones/__init__.py +6 -0
  123. eva/vision/models/networks/backbones/_utils.py +39 -0
  124. eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
  125. eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
  126. eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
  127. eva/vision/models/networks/backbones/pathology/histai.py +46 -0
  128. eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
  129. eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
  130. eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
  131. eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
  132. eva/vision/models/networks/backbones/registry.py +47 -0
  133. eva/vision/models/networks/backbones/timm/__init__.py +5 -0
  134. eva/vision/models/networks/backbones/timm/backbones.py +54 -0
  135. eva/vision/models/networks/backbones/universal/__init__.py +8 -0
  136. eva/vision/models/networks/backbones/universal/vit.py +54 -0
  137. eva/vision/models/networks/decoders/__init__.py +6 -0
  138. eva/vision/models/networks/decoders/decoder.py +7 -0
  139. eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
  140. eva/vision/models/networks/decoders/segmentation/common.py +74 -0
  141. eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
  142. eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
  143. eva/vision/models/wrappers/__init__.py +6 -0
  144. eva/vision/models/wrappers/from_registry.py +48 -0
  145. eva/vision/models/wrappers/from_timm.py +68 -0
  146. eva/vision/utils/colormap.py +77 -0
  147. eva/vision/utils/convert.py +67 -0
  148. eva/vision/utils/io/__init__.py +10 -4
  149. eva/vision/utils/io/image.py +21 -2
  150. eva/vision/utils/io/mat.py +36 -0
  151. eva/vision/utils/io/nifti.py +40 -15
  152. eva/vision/utils/io/text.py +10 -3
  153. kaiko_eva-0.1.0.dist-info/METADATA +553 -0
  154. kaiko_eva-0.1.0.dist-info/RECORD +205 -0
  155. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
  156. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
  157. eva/core/callbacks/writers/embeddings.py +0 -169
  158. eva/core/callbacks/writers/typings.py +0 -23
  159. eva/core/models/networks/transforms/__init__.py +0 -5
  160. eva/core/models/networks/wrappers/__init__.py +0 -8
  161. eva/vision/data/datasets/classification/total_segmentator.py +0 -213
  162. eva/vision/data/datasets/segmentation/total_segmentator.py +0 -212
  163. eva/vision/models/networks/postprocesses/__init__.py +0 -5
  164. eva/vision/models/networks/postprocesses/cls.py +0 -25
  165. kaiko_eva-0.0.1.dist-info/METADATA +0 -405
  166. kaiko_eva-0.0.1.dist-info/RECORD +0 -110
  167. /eva/core/models/{networks → wrappers}/_utils.py +0 -0
  168. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
eva/vision/data/wsi/patching/samplers/random.py
@@ -0,0 +1,41 @@
+ """Random sampler."""
+
+ import random
+ from typing import Generator, Tuple
+
+ from eva.vision.data.wsi.patching.samplers import _utils, base
+
+
+ class RandomSampler(base.Sampler):
+     """Sample patch coordinates randomly.
+
+     Args:
+         n_samples: The number of samples to return.
+         seed: The random seed.
+     """
+
+     def __init__(self, n_samples: int = 1, seed: int = 42):
+         """Initializes the sampler."""
+         self.seed = seed
+         self.n_samples = n_samples
+
+     def sample(
+         self,
+         width: int,
+         height: int,
+         layer_shape: Tuple[int, int],
+     ) -> Generator[Tuple[int, int], None, None]:
+         """Sample random patches.
+
+         Args:
+             width: The width of the patches.
+             height: The height of the patches.
+             layer_shape: The shape of the layer.
+         """
+         _utils.validate_dimensions(width, height, layer_shape)
+         _utils.set_seed(self.seed)
+
+         x_max, y_max = layer_shape[0], layer_shape[1]
+         for _ in range(self.n_samples):
+             x, y = random.randint(0, x_max - width), random.randint(0, y_max - height)  # nosec
+             yield x, y
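
A minimal usage sketch (not part of the diff), driving the new sampler with exactly the signature shown above:

from eva.vision.data.wsi.patching.samplers.random import RandomSampler

# Draw 4 random 224x224 patch origins from a 10_000 x 8_000 slide layer.
sampler = RandomSampler(n_samples=4, seed=42)
for x, y in sampler.sample(width=224, height=224, layer_shape=(10_000, 8_000)):
    print(x, y)  # top-left coordinates, guaranteed to fit inside the layer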
eva/vision/losses/__init__.py
@@ -0,0 +1,5 @@
+ """Loss functions API."""
+
+ from eva.vision.losses.dice import DiceLoss
+
+ __all__ = ["DiceLoss"]
eva/vision/losses/dice.py
@@ -0,0 +1,40 @@
+ """Dice loss."""
+
+ import torch
+ from monai import losses
+ from monai.networks import one_hot  # type: ignore
+ from typing_extensions import override
+
+
+ class DiceLoss(losses.DiceLoss):  # type: ignore
+     """Computes the average Dice loss between two tensors.
+
+     Extends the implementation from MONAI
+     - to support semantic target labels (meaning targets of shape BHW)
+     - to support `ignore_index` functionality
+     """
+
+     def __init__(self, *args, ignore_index: int | None = None, **kwargs) -> None:
+         """Initialize the DiceLoss with support for ignore_index.
+
+         Args:
+             args: Positional arguments from the base class.
+             ignore_index: Specifies a target value that is ignored and
+                 does not contribute to the input gradient.
+             kwargs: Keyword arguments from the base class.
+         """
+         super().__init__(*args, **kwargs)
+
+         self.ignore_index = ignore_index
+
+     @override
+     def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:  # noqa
+         if self.ignore_index is not None:
+             mask = targets != self.ignore_index
+             targets = targets * mask
+             inputs = torch.mul(inputs, mask.unsqueeze(1) if mask.ndim == 3 else mask)
+
+         if targets.ndim == 3:
+             targets = one_hot(targets[:, None, ...], num_classes=inputs.shape[1])
+
+         return super().forward(inputs, targets)
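
A minimal usage sketch (not part of the diff), exercising both extensions at once: BHW integer targets are one-hot encoded internally, and pixels carrying the `ignore_index` value are masked out of both tensors before the MONAI loss runs. `softmax` is a keyword of the MONAI base class.

import torch

from eva.vision.losses import DiceLoss

criterion = DiceLoss(softmax=True, ignore_index=255)
logits = torch.randn(2, 3, 64, 64)           # B=2 batches of C=3 class logits
targets = torch.randint(0, 3, (2, 64, 64))   # semantic labels of shape BHW
targets[0, :8, :8] = 255                     # this region is excluded from the loss
loss = criterion(logits, targets)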
eva/vision/models/__init__.py
@@ -1,5 +1,7 @@
  """Vision Models API."""

- from eva.vision.models import networks
+ from eva.vision.models import networks, wrappers
+ from eva.vision.models.networks import backbones
+ from eva.vision.models.wrappers import ModelFromRegistry, TimmModel

- __all__ = ["networks"]
+ __all__ = ["networks", "wrappers", "backbones", "ModelFromRegistry", "TimmModel"]
eva/vision/models/modules/__init__.py
@@ -0,0 +1,5 @@
+ """Vision modules API."""
+
+ from eva.vision.models.modules.semantic_segmentation import SemanticSegmentationModule
+
+ __all__ = ["SemanticSegmentationModule"]
eva/vision/models/modules/semantic_segmentation.py
@@ -0,0 +1,161 @@
+ """Neural Network Semantic Segmentation Module."""
+
+ from typing import Any, Callable, Dict, Iterable, List, Tuple
+
+ import torch
+ from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
+ from torch import nn, optim
+ from torch.optim import lr_scheduler
+ from typing_extensions import override
+
+ from eva.core.metrics import structs as metrics_lib
+ from eva.core.models.modules import module
+ from eva.core.models.modules.typings import INPUT_BATCH, INPUT_TENSOR_BATCH
+ from eva.core.models.modules.utils import batch_postprocess, grad
+ from eva.core.utils import parser
+ from eva.vision.models.networks import decoders
+
+
+ class SemanticSegmentationModule(module.ModelModule):
+     """Neural network semantic segmentation module for training on patch embeddings."""
+
+     def __init__(
+         self,
+         decoder: decoders.Decoder,
+         criterion: Callable[..., torch.Tensor],
+         encoder: Dict[str, Any] | Callable[[torch.Tensor], List[torch.Tensor]] | None = None,
+         lr_multiplier_encoder: float = 0.0,
+         optimizer: OptimizerCallable = optim.AdamW,
+         lr_scheduler: LRSchedulerCallable = lr_scheduler.ConstantLR,
+         metrics: metrics_lib.MetricsSchema | None = None,
+         postprocess: batch_postprocess.BatchPostProcess | None = None,
+     ) -> None:
+         """Initializes the semantic segmentation module.
+
+         Args:
+             decoder: The decoder model.
+             criterion: The loss function to use.
+             encoder: The encoder model. If `None`, it is expected
+                 that the input batch returns the features directly.
+                 If passed as a dictionary, it will be parsed to an object
+                 during the `configure_model` step.
+             lr_multiplier_encoder: The learning rate multiplier for the
+                 encoder parameters. If `0`, it will freeze the encoder.
+             optimizer: The optimizer to use.
+             lr_scheduler: The learning rate scheduler to use.
+             metrics: The metric groups to track.
+             postprocess: A list of helper functions to apply after the
+                 loss and before the metrics calculation to the model
+                 predictions and targets.
+         """
+         super().__init__(metrics=metrics, postprocess=postprocess)
+
+         self.decoder = decoder
+         self.criterion = criterion
+         self.encoder = encoder  # type: ignore
+         self.lr_multiplier_encoder = lr_multiplier_encoder
+         self.optimizer = optimizer
+         self.lr_scheduler = lr_scheduler
+
+     @override
+     def configure_model(self) -> None:
+         self._freeze_encoder()
+
+         if isinstance(self.encoder, dict):
+             self.encoder: Callable[[torch.Tensor], List[torch.Tensor]] = parser.parse_object(
+                 self.encoder,
+                 expected_type=nn.Module,
+             )
+
+     @override
+     def configure_optimizers(self) -> Any:
+         optimizer = self.optimizer(
+             [
+                 {"params": self.decoder.parameters()},
+                 {
+                     "params": self._encoder_trainable_parameters(),
+                     "lr": self._base_lr * self.lr_multiplier_encoder,
+                 },
+             ]
+         )
+         lr_scheduler = self.lr_scheduler(optimizer)
+         return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
+
+     @override
+     def forward(
+         self,
+         inputs: torch.Tensor,
+         to_size: Tuple[int, int] | None = None,
+         *args: Any,
+         **kwargs: Any,
+     ) -> torch.Tensor:
+         """Maps the input tensor (image tensor or embeddings) to masks.
+
+         If `inputs` is an image tensor, then `self.encoder` should be
+         implemented; otherwise it will be interpreted as embeddings,
+         in which case `to_size` should be given.
+         """
+         if self.encoder is None and to_size is None:
+             raise ValueError(
+                 "Please provide the expected `to_size` that the "
+                 "decoder should map the embeddings (`inputs`) to."
+             )
+
+         patch_embeddings = self.encoder(inputs) if self.encoder else inputs
+         return self.decoder(patch_embeddings, to_size or inputs.shape[-2:])
+
+     @override
+     def training_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+         return self._batch_step(batch)
+
+     @override
+     def validation_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+         return self._batch_step(batch)
+
+     @override
+     def test_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+         return self._batch_step(batch)
+
+     @override
+     def predict_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> torch.Tensor:
+         tensor = INPUT_BATCH(*batch).data
+         return self.encoder(tensor) if isinstance(self.encoder, nn.Module) else tensor
+
+     @property
+     def _base_lr(self) -> float:
+         """Returns the base learning rate."""
+         base_optimizer = self.optimizer(self.parameters())
+         return base_optimizer.param_groups[-1]["lr"]
+
+     def _encoder_trainable_parameters(self) -> Iterable[torch.Tensor]:
+         """Returns the trainable parameters of the encoder."""
+         return (
+             self.encoder.parameters()
+             if isinstance(self.encoder, nn.Module) and self.lr_multiplier_encoder > 0
+             else iter(())
+         )
+
+     def _freeze_encoder(self) -> None:
+         """If initialized, it freezes the encoder network."""
+         if isinstance(self.encoder, nn.Module) and self.lr_multiplier_encoder == 0:
+             grad.deactivate_requires_grad(self.encoder)
+
+     def _batch_step(self, batch: INPUT_TENSOR_BATCH) -> STEP_OUTPUT:
+         """Performs a model forward step and calculates the loss.
+
+         Args:
+             batch: The desired batch to process.
+
+         Returns:
+             The batch step output.
+         """
+         data, targets, metadata = INPUT_TENSOR_BATCH(*batch)
+         predictions = self(data, to_size=targets.shape[-2:])
+         loss = self.criterion(predictions, targets)
+         return {
+             "loss": loss,
+             "targets": targets,
+             "predictions": predictions,
+             "metadata": metadata,
+         }
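
The encoder can be handed over either as an instantiated callable or as a dictionary that `configure_model` parses into an `nn.Module`. A wiring sketch (not part of the diff): the `class_path`/`init_args` layout is an assumption based on the jsonargparse/Lightning CLI conventions this module imports, and the decoder construction is left as a placeholder since the decoder signatures are not shown in this hunk.

from eva.vision.losses import DiceLoss
from eva.vision.models.modules import SemanticSegmentationModule

module = SemanticSegmentationModule(
    decoder=my_decoder,  # placeholder: any `decoders.Decoder`, e.g. from decoders/segmentation/
    criterion=DiceLoss(softmax=True, ignore_index=255),
    encoder={  # parsed into an nn.Module during `configure_model` (assumed dict layout)
        "class_path": "eva.vision.models.wrappers.ModelFromRegistry",
        "init_args": {"model_name": "pathology/kaiko_vits16"},
    },
    lr_multiplier_encoder=0.0,  # 0.0 freezes the encoder; > 0 trains it at base_lr * multiplier
)

With `lr_multiplier_encoder=0.0` the encoder is frozen in `configure_model` and `_encoder_trainable_parameters` returns an empty iterator, so only the decoder is optimized.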
eva/vision/models/networks/__init__.py
@@ -1,6 +1,5 @@
  """Vision Networks API."""

- from eva.vision.models.networks import postprocesses
  from eva.vision.models.networks.abmil import ABMIL

- __all__ = ["postprocesses", "ABMIL"]
+ __all__ = ["ABMIL"]
eva/vision/models/networks/backbones/__init__.py
@@ -0,0 +1,6 @@
+ """Vision Model Backbones API."""
+
+ from eva.vision.models.networks.backbones import pathology, timm, universal
+ from eva.vision.models.networks.backbones.registry import BackboneModelRegistry, register_model
+
+ __all__ = ["pathology", "timm", "universal", "BackboneModelRegistry", "register_model"]
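
The pathology, timm, and universal backbones below all register themselves through this API. A sketch of the same pattern for a custom backbone (the decorator usage mirrors the bundled files; the lookup side of `BackboneModelRegistry` lives in registry.py, which this section does not show):

from torch import nn

from eva.vision.models.networks.backbones.registry import register_model


@register_model("custom/identity")
def identity_backbone() -> nn.Module:
    """A trivial backbone, made available under the registered name."""
    return nn.Identity()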
eva/vision/models/networks/backbones/_utils.py
@@ -0,0 +1,39 @@
+ """Utils for backbone networks."""
+
+ from typing import Any, Dict, Tuple
+
+ from torch import nn
+
+ from eva import models
+ from eva.core.models import transforms
+
+
+ def load_hugingface_model(
+     model_name: str,
+     out_indices: int | Tuple[int, ...] | None,
+     model_kwargs: Dict[str, Any] | None = None,
+     transform_args: Dict[str, Any] | None = None,
+ ) -> nn.Module:
+     """Helper function to load HuggingFace models.
+
+     Args:
+         model_name: The model name to load.
+         out_indices: Whether and which multi-level patch embeddings to return.
+             Currently only out_indices=1 is supported.
+         model_kwargs: The arguments used for instantiating the model.
+         transform_args: The arguments used for instantiating the transform.
+
+     Returns: The model instance.
+     """
+     if out_indices is None:
+         tensor_transforms = transforms.ExtractCLSFeatures(**(transform_args or {}))
+     elif out_indices == 1:
+         tensor_transforms = transforms.ExtractPatchFeatures(**(transform_args or {}))
+     else:
+         raise ValueError(f"out_indices={out_indices} is currently not supported.")
+
+     return models.HuggingFaceModel(
+         model_name_or_path=model_name,
+         tensor_transforms=tensor_transforms,
+         model_kwargs=model_kwargs,
+     )
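
A sketch of calling the helper directly (not part of the diff; `owkin/phikon` is used as an illustrative Hugging Face repo id, matching the owkin backbone added elsewhere in this release):

from eva.vision.models.networks.backbones import _utils

# out_indices=None wraps the model output with ExtractCLSFeatures;
# out_indices=1 would switch to ExtractPatchFeatures instead.
model = _utils.load_hugingface_model(
    model_name="owkin/phikon",
    out_indices=None,
)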
eva/vision/models/networks/backbones/pathology/__init__.py
@@ -0,0 +1,31 @@
+ """Vision Pathology Model Backbones API."""
+
+ from eva.vision.models.networks.backbones.pathology.bioptimus import bioptimus_h_optimus_0
+ from eva.vision.models.networks.backbones.pathology.gigapath import prov_gigapath
+ from eva.vision.models.networks.backbones.pathology.histai import histai_hibou_b, histai_hibou_l
+ from eva.vision.models.networks.backbones.pathology.kaiko import (
+     kaiko_vitb8,
+     kaiko_vitb16,
+     kaiko_vitl14,
+     kaiko_vits8,
+     kaiko_vits16,
+ )
+ from eva.vision.models.networks.backbones.pathology.lunit import lunit_vits8, lunit_vits16
+ from eva.vision.models.networks.backbones.pathology.mahmood import mahmood_uni
+ from eva.vision.models.networks.backbones.pathology.owkin import owkin_phikon
+
+ __all__ = [
+     "kaiko_vitb16",
+     "kaiko_vitb8",
+     "kaiko_vitl14",
+     "kaiko_vits16",
+     "kaiko_vits8",
+     "owkin_phikon",
+     "lunit_vits16",
+     "lunit_vits8",
+     "mahmood_uni",
+     "bioptimus_h_optimus_0",
+     "prov_gigapath",
+     "histai_hibou_b",
+     "histai_hibou_l",
+ ]
eva/vision/models/networks/backbones/pathology/bioptimus.py
@@ -0,0 +1,34 @@
+ """Pathology FMs from Bioptimus."""
+
+ from typing import Tuple
+
+ import timm
+ from torch import nn
+
+ from eva.vision.models.networks.backbones.registry import register_model
+
+
+ @register_model("pathology/bioptimus_h_optimus_0")
+ def bioptimus_h_optimus_0(
+     dynamic_img_size: bool = True,
+     out_indices: int | Tuple[int, ...] | None = None,
+ ) -> nn.Module:
+     """Initializes the h_optimus_0 pathology FM by Bioptimus.
+
+     Args:
+         dynamic_img_size: Whether to allow the position embedding
+             to be interpolated at `forward()` time when the image grid
+             changes from the original.
+         out_indices: Whether and which multi-level patch embeddings to return.
+
+     Returns:
+         The model instance.
+     """
+     return timm.create_model(
+         model_name="hf-hub:bioptimus/H-optimus-0",
+         pretrained=True,
+         init_values=1e-5,
+         dynamic_img_size=dynamic_img_size,
+         out_indices=out_indices,
+         features_only=out_indices is not None,
+     )
eva/vision/models/networks/backbones/pathology/gigapath.py
@@ -0,0 +1,33 @@
+ """Pathology FMs from other/mixed entities."""
+
+ from typing import Tuple
+
+ import timm
+ from torch import nn
+
+ from eva.vision.models.networks.backbones.registry import register_model
+
+
+ @register_model("pathology/prov_gigapath")
+ def prov_gigapath(
+     dynamic_img_size: bool = True,
+     out_indices: int | Tuple[int, ...] | None = None,
+ ) -> nn.Module:
+     """Initializes the Prov-GigaPath pathology FM.
+
+     Args:
+         dynamic_img_size: Whether to allow the position embedding
+             to be interpolated at `forward()` time when the image grid
+             changes from the original.
+         out_indices: Whether and which multi-level patch embeddings to return.
+
+     Returns:
+         The model instance.
+     """
+     return timm.create_model(
+         model_name="hf_hub:prov-gigapath/prov-gigapath",
+         pretrained=True,
+         dynamic_img_size=dynamic_img_size,
+         out_indices=out_indices,
+         features_only=out_indices is not None,
+     )
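
Both timm-based factories above enable `features_only` exactly when `out_indices` is given, which switches the model from a token-level forward to a multi-level feature extractor. A sketch of the distinction on a plain timm ViT (assumes a timm version with ViT feature-extraction support, which these factories also require; the hf-hub weights are gated, so a stand-in architecture with random weights is used here):

import timm
import torch

model = timm.create_model(
    "vit_small_patch16_224",  # stand-in for the hf-hub models above
    pretrained=False,
    features_only=True,
    out_indices=1,  # recent timm treats an int as "the last n feature maps"
)
features = model(torch.randn(1, 3, 224, 224))
print([f.shape for f in features])  # one NCHW patch-token grid, e.g. (1, 384, 14, 14)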
eva/vision/models/networks/backbones/pathology/histai.py
@@ -0,0 +1,46 @@
+ """Pathology FMs from hist.ai."""
+
+ from typing import Tuple
+
+ from torch import nn
+
+ from eva.vision.models.networks.backbones import _utils
+ from eva.vision.models.networks.backbones.registry import register_model
+
+
+ @register_model("pathology/histai_hibou_b")
+ def histai_hibou_b(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
+     """Initializes the hibou-B pathology FM by hist.ai (https://huggingface.co/histai/hibou-B).
+
+     Args:
+         out_indices: Whether and which multi-level patch embeddings to return.
+             Currently only out_indices=1 is supported.
+
+     Returns:
+         The model instance.
+     """
+     return _utils.load_hugingface_model(
+         model_name="histai/hibou-B",
+         out_indices=out_indices,
+         model_kwargs={"trust_remote_code": True},
+         transform_args={"ignore_remaining_dims": True} if out_indices is not None else None,
+     )
+
+
+ @register_model("pathology/histai_hibou_l")
+ def histai_hibou_l(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
+     """Initializes the hibou-L pathology FM by hist.ai (https://huggingface.co/histai/hibou-L).
+
+     Args:
+         out_indices: Whether and which multi-level patch embeddings to return.
+             Currently only out_indices=1 is supported.
+
+     Returns:
+         The model instance.
+     """
+     return _utils.load_hugingface_model(
+         model_name="histai/hibou-L",
+         out_indices=out_indices,
+         model_kwargs={"trust_remote_code": True},
+         transform_args={"ignore_remaining_dims": True} if out_indices is not None else None,
+     )
eva/vision/models/networks/backbones/pathology/kaiko.py
@@ -0,0 +1,123 @@
+ """Pathology FMs from kaiko.ai."""
+
+ from typing import Tuple
+
+ import torch
+ from torch import nn
+
+ from eva.vision.models.networks.backbones.registry import register_model
+
+
+ @register_model("pathology/kaiko_vits16")
+ def kaiko_vits16(
+     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+ ) -> nn.Module:
+     """Initializes the ViTS-16 pathology FM by kaiko.ai.
+
+     Args:
+         dynamic_img_size: Support different input image sizes by allowing to change
+             the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+         out_indices: Whether and which multi-level patch embeddings to return.
+
+     Returns:
+         The model instance.
+     """
+     return torch.hub.load(  # type: ignore
+         repo_or_dir="kaiko-ai/towards_large_pathology_fms",
+         model="vits16",
+         trust_repo=True,
+         dynamic_img_size=dynamic_img_size,
+         out_indices=out_indices,
+     )
+
+
+ @register_model("pathology/kaiko_vits8")
+ def kaiko_vits8(
+     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+ ) -> nn.Module:
+     """Initializes the ViTS-8 pathology FM by kaiko.ai.
+
+     Args:
+         dynamic_img_size: Support different input image sizes by allowing to change
+             the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+         out_indices: Whether and which multi-level patch embeddings to return.
+
+     Returns:
+         The model instance.
+     """
+     return torch.hub.load(  # type: ignore
+         repo_or_dir="kaiko-ai/towards_large_pathology_fms",
+         model="vits8",
+         trust_repo=True,
+         dynamic_img_size=dynamic_img_size,
+         out_indices=out_indices,
+     )
+
+
+ @register_model("pathology/kaiko_vitb16")
+ def kaiko_vitb16(
+     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+ ) -> nn.Module:
+     """Initializes the ViTB-16 pathology FM by kaiko.ai.
+
+     Args:
+         dynamic_img_size: Support different input image sizes by allowing to change
+             the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+         out_indices: Whether and which multi-level patch embeddings to return.
+
+     Returns:
+         The model instance.
+     """
+     return torch.hub.load(  # type: ignore
+         repo_or_dir="kaiko-ai/towards_large_pathology_fms",
+         model="vitb16",
+         trust_repo=True,
+         dynamic_img_size=dynamic_img_size,
+         out_indices=out_indices,
+     )
+
+
+ @register_model("pathology/kaiko_vitb8")
+ def kaiko_vitb8(
+     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+ ) -> nn.Module:
+     """Initializes the ViTB-8 pathology FM by kaiko.ai.
+
+     Args:
+         dynamic_img_size: Support different input image sizes by allowing to change
+             the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+         out_indices: Whether and which multi-level patch embeddings to return.
+
+     Returns:
+         The model instance.
+     """
+     return torch.hub.load(  # type: ignore
+         repo_or_dir="kaiko-ai/towards_large_pathology_fms",
+         model="vitb8",
+         trust_repo=True,
+         dynamic_img_size=dynamic_img_size,
+         out_indices=out_indices,
+     )
+
+
+ @register_model("pathology/kaiko_vitl14")
+ def kaiko_vitl14(
+     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+ ) -> nn.Module:
+     """Initializes the ViTL-14 pathology FM by kaiko.ai.
+
+     Args:
+         dynamic_img_size: Support different input image sizes by allowing to change
+             the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+         out_indices: Whether and which multi-level patch embeddings to return.
+
+     Returns:
+         The model instance.
+     """
+     return torch.hub.load(  # type: ignore
+         repo_or_dir="kaiko-ai/towards_large_pathology_fms",
+         model="vitl14",
+         trust_repo=True,
+         dynamic_img_size=dynamic_img_size,
+         out_indices=out_indices,
+     )
eva/vision/models/networks/backbones/pathology/lunit.py
@@ -0,0 +1,68 @@
+ """Pathology FMs from Lunit.
+
+ Source: https://github.com/lunit-io/benchmark-ssl-pathology/releases
+
+ For training the vit-s models the following standardization parameters were used:
+
+ mean: [ 0.70322989, 0.53606487, 0.66096631 ]
+ std: [ 0.21716536, 0.26081574, 0.20723464 ]
+ """
+
+ from typing import Tuple
+
+ from torch import nn
+
+ from eva.vision.models import wrappers
+ from eva.vision.models.networks.backbones.registry import register_model
+
+ VITS_URL_PREFIX = (
+     "https://github.com/lunit-io/benchmark-ssl-pathology/releases/download/pretrained-weights"
+ )
+
+
+ @register_model("pathology/lunit_vits16")
+ def lunit_vits16(
+     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+ ) -> nn.Module:
+     """Initializes the ViTS-16 pathology FM by Lunit.
+
+     Args:
+         dynamic_img_size: Support different input image sizes by allowing to change
+             the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+         out_indices: Whether and which multi-level patch embeddings to return.
+
+     Returns:
+         The model instance.
+     """
+     return wrappers.TimmModel(
+         model_name="vit_small_patch16_224.dino",
+         out_indices=out_indices,
+         model_kwargs={
+             "dynamic_img_size": dynamic_img_size,
+         },
+         checkpoint_path=f"{VITS_URL_PREFIX}/dino_vit_small_patch16_ep200.torch",
+     )
+
+
+ @register_model("pathology/lunit_vits8")
+ def lunit_vits8(
+     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+ ) -> nn.Module:
+     """Initializes the ViTS-8 pathology FM by Lunit.
+
+     Args:
+         dynamic_img_size: Support different input image sizes by allowing to change
+             the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+         out_indices: Whether and which multi-level patch embeddings to return.
+
+     Returns:
+         The model instance.
+     """
+     return wrappers.TimmModel(
+         model_name="vit_small_patch8_224.dino",
+         out_indices=out_indices,
+         model_kwargs={
+             "dynamic_img_size": dynamic_img_size,
+         },
+         checkpoint_path=f"{VITS_URL_PREFIX}/dino_vit_small_patch8_ep200.torch",
+     )
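
Since every factory in this section is registered under a `pathology/...` name, the intended entry point appears to be the `ModelFromRegistry` wrapper exported from `eva.vision.models`. A closing sketch (its constructor lives in from_registry.py, which this section does not show, so the `model_name` keyword and the output shape are assumptions):

import torch

from eva.vision.models import ModelFromRegistry

backbone = ModelFromRegistry(model_name="pathology/lunit_vits16")
embeddings = backbone(torch.randn(1, 3, 224, 224))  # a CLS-style embedding, e.g. (1, 384) for ViT-S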