kaiko-eva 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kaiko-eva might be problematic.

Files changed (159)
  1. eva/core/callbacks/__init__.py +2 -2
  2. eva/core/callbacks/writers/__init__.py +6 -3
  3. eva/core/callbacks/writers/embeddings/__init__.py +6 -0
  4. eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
  5. eva/core/callbacks/writers/embeddings/base.py +192 -0
  6. eva/core/callbacks/writers/embeddings/classification.py +117 -0
  7. eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
  8. eva/core/callbacks/writers/embeddings/typings.py +38 -0
  9. eva/core/data/datasets/__init__.py +2 -2
  10. eva/core/data/datasets/classification/__init__.py +8 -0
  11. eva/core/data/datasets/classification/embeddings.py +34 -0
  12. eva/core/data/datasets/{embeddings/classification → classification}/multi_embeddings.py +13 -9
  13. eva/core/data/datasets/{embeddings/base.py → embeddings.py} +47 -32
  14. eva/core/data/splitting/__init__.py +6 -0
  15. eva/core/data/splitting/random.py +41 -0
  16. eva/core/data/splitting/stratified.py +56 -0
  17. eva/core/loggers/experimental_loggers.py +2 -2
  18. eva/core/loggers/log/__init__.py +3 -2
  19. eva/core/loggers/log/image.py +71 -0
  20. eva/core/loggers/log/parameters.py +10 -0
  21. eva/core/loggers/loggers.py +6 -0
  22. eva/core/metrics/__init__.py +6 -2
  23. eva/core/metrics/defaults/__init__.py +10 -3
  24. eva/core/metrics/defaults/classification/__init__.py +1 -1
  25. eva/core/metrics/defaults/classification/binary.py +0 -9
  26. eva/core/metrics/defaults/classification/multiclass.py +0 -8
  27. eva/core/metrics/defaults/segmentation/__init__.py +5 -0
  28. eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
  29. eva/core/metrics/generalized_dice.py +59 -0
  30. eva/core/metrics/mean_iou.py +120 -0
  31. eva/core/metrics/structs/schemas.py +3 -1
  32. eva/core/models/__init__.py +3 -1
  33. eva/core/models/modules/head.py +10 -4
  34. eva/core/models/modules/typings.py +14 -1
  35. eva/core/models/modules/utils/batch_postprocess.py +37 -5
  36. eva/core/models/networks/__init__.py +1 -2
  37. eva/core/models/networks/mlp.py +2 -2
  38. eva/core/models/transforms/__init__.py +6 -0
  39. eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
  40. eva/core/models/transforms/extract_patch_features.py +47 -0
  41. eva/core/models/wrappers/__init__.py +13 -0
  42. eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
  43. eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
  44. eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
  45. eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
  46. eva/core/trainers/functional.py +1 -0
  47. eva/core/utils/__init__.py +6 -0
  48. eva/core/utils/clone.py +27 -0
  49. eva/core/utils/memory.py +28 -0
  50. eva/core/utils/operations.py +26 -0
  51. eva/core/utils/parser.py +20 -0
  52. eva/vision/__init__.py +2 -2
  53. eva/vision/callbacks/__init__.py +5 -0
  54. eva/vision/callbacks/loggers/__init__.py +5 -0
  55. eva/vision/callbacks/loggers/batch/__init__.py +5 -0
  56. eva/vision/callbacks/loggers/batch/base.py +130 -0
  57. eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
  58. eva/vision/data/datasets/__init__.py +30 -3
  59. eva/vision/data/datasets/_validators.py +15 -2
  60. eva/vision/data/datasets/classification/__init__.py +12 -1
  61. eva/vision/data/datasets/classification/bach.py +10 -15
  62. eva/vision/data/datasets/classification/base.py +17 -24
  63. eva/vision/data/datasets/classification/camelyon16.py +244 -0
  64. eva/vision/data/datasets/classification/crc.py +10 -15
  65. eva/vision/data/datasets/classification/mhist.py +10 -15
  66. eva/vision/data/datasets/classification/panda.py +184 -0
  67. eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
  68. eva/vision/data/datasets/classification/wsi.py +105 -0
  69. eva/vision/data/datasets/segmentation/__init__.py +15 -2
  70. eva/vision/data/datasets/segmentation/_utils.py +38 -0
  71. eva/vision/data/datasets/segmentation/base.py +16 -17
  72. eva/vision/data/datasets/segmentation/bcss.py +236 -0
  73. eva/vision/data/datasets/segmentation/consep.py +156 -0
  74. eva/vision/data/datasets/segmentation/embeddings.py +34 -0
  75. eva/vision/data/datasets/segmentation/lits.py +178 -0
  76. eva/vision/data/datasets/segmentation/monusac.py +236 -0
  77. eva/vision/data/datasets/segmentation/{total_segmentator.py → total_segmentator_2d.py} +130 -36
  78. eva/vision/data/datasets/wsi.py +187 -0
  79. eva/vision/data/transforms/__init__.py +3 -2
  80. eva/vision/data/transforms/common/__init__.py +2 -1
  81. eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
  82. eva/vision/data/transforms/common/resize_and_crop.py +6 -7
  83. eva/vision/data/transforms/normalization/__init__.py +6 -0
  84. eva/vision/data/transforms/normalization/clamp.py +43 -0
  85. eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
  86. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
  87. eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
  88. eva/vision/data/wsi/__init__.py +16 -0
  89. eva/vision/data/wsi/backends/__init__.py +69 -0
  90. eva/vision/data/wsi/backends/base.py +115 -0
  91. eva/vision/data/wsi/backends/openslide.py +73 -0
  92. eva/vision/data/wsi/backends/pil.py +52 -0
  93. eva/vision/data/wsi/backends/tiffslide.py +42 -0
  94. eva/vision/data/wsi/patching/__init__.py +6 -0
  95. eva/vision/data/wsi/patching/coordinates.py +98 -0
  96. eva/vision/data/wsi/patching/mask.py +123 -0
  97. eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
  98. eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
  99. eva/vision/data/wsi/patching/samplers/base.py +48 -0
  100. eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
  101. eva/vision/data/wsi/patching/samplers/grid.py +47 -0
  102. eva/vision/data/wsi/patching/samplers/random.py +41 -0
  103. eva/vision/losses/__init__.py +5 -0
  104. eva/vision/losses/dice.py +40 -0
  105. eva/vision/models/__init__.py +4 -2
  106. eva/vision/models/modules/__init__.py +5 -0
  107. eva/vision/models/modules/semantic_segmentation.py +161 -0
  108. eva/vision/models/networks/__init__.py +1 -2
  109. eva/vision/models/networks/backbones/__init__.py +6 -0
  110. eva/vision/models/networks/backbones/_utils.py +39 -0
  111. eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
  112. eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
  113. eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
  114. eva/vision/models/networks/backbones/pathology/histai.py +46 -0
  115. eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
  116. eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
  117. eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
  118. eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
  119. eva/vision/models/networks/backbones/registry.py +47 -0
  120. eva/vision/models/networks/backbones/timm/__init__.py +5 -0
  121. eva/vision/models/networks/backbones/timm/backbones.py +54 -0
  122. eva/vision/models/networks/backbones/universal/__init__.py +8 -0
  123. eva/vision/models/networks/backbones/universal/vit.py +54 -0
  124. eva/vision/models/networks/decoders/__init__.py +6 -0
  125. eva/vision/models/networks/decoders/decoder.py +7 -0
  126. eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
  127. eva/vision/models/networks/decoders/segmentation/common.py +74 -0
  128. eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
  129. eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
  130. eva/vision/models/wrappers/__init__.py +6 -0
  131. eva/vision/models/wrappers/from_registry.py +48 -0
  132. eva/vision/models/wrappers/from_timm.py +68 -0
  133. eva/vision/utils/colormap.py +77 -0
  134. eva/vision/utils/convert.py +56 -13
  135. eva/vision/utils/io/__init__.py +10 -4
  136. eva/vision/utils/io/image.py +21 -2
  137. eva/vision/utils/io/mat.py +36 -0
  138. eva/vision/utils/io/nifti.py +33 -12
  139. eva/vision/utils/io/text.py +10 -3
  140. kaiko_eva-0.1.1.dist-info/METADATA +553 -0
  141. kaiko_eva-0.1.1.dist-info/RECORD +205 -0
  142. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.1.dist-info}/WHEEL +1 -1
  143. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.1.dist-info}/entry_points.txt +2 -0
  144. eva/.DS_Store +0 -0
  145. eva/core/callbacks/writers/embeddings.py +0 -169
  146. eva/core/callbacks/writers/typings.py +0 -23
  147. eva/core/data/datasets/embeddings/__init__.py +0 -13
  148. eva/core/data/datasets/embeddings/classification/__init__.py +0 -10
  149. eva/core/data/datasets/embeddings/classification/embeddings.py +0 -66
  150. eva/core/models/networks/transforms/__init__.py +0 -5
  151. eva/core/models/networks/wrappers/__init__.py +0 -8
  152. eva/vision/models/.DS_Store +0 -0
  153. eva/vision/models/networks/.DS_Store +0 -0
  154. eva/vision/models/networks/postprocesses/__init__.py +0 -5
  155. eva/vision/models/networks/postprocesses/cls.py +0 -25
  156. kaiko_eva-0.0.2.dist-info/METADATA +0 -431
  157. kaiko_eva-0.0.2.dist-info/RECORD +0 -127
  158. /eva/core/models/{networks → wrappers}/_utils.py +0 -0
  159. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.1.dist-info}/licenses/LICENSE +0 -0
eva/core/models/__init__.py CHANGED
@@ -1,12 +1,14 @@
 """Models API."""
 
 from eva.core.models.modules import HeadModule, InferenceModule
-from eva.core.models.networks import MLP, HuggingFaceModel, ModelFromFunction, ONNXModel
+from eva.core.models.networks import MLP
+from eva.core.models.wrappers import BaseModel, HuggingFaceModel, ModelFromFunction, ONNXModel
 
 __all__ = [
     "HeadModule",
     "InferenceModule",
     "MLP",
+    "BaseModel",
     "HuggingFaceModel",
     "ModelFromFunction",
     "ONNXModel",
eva/core/models/modules/head.py CHANGED
@@ -1,11 +1,11 @@
 """Neural Network Head Module."""
 
-from typing import Any, Callable
+from typing import Any, Callable, Dict
 
 import torch
 from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
 from lightning.pytorch.utilities.types import STEP_OUTPUT
-from torch import optim
+from torch import nn, optim
 from torch.optim import lr_scheduler
 from typing_extensions import override
 
@@ -13,6 +13,7 @@ from eva.core.metrics import structs as metrics_lib
 from eva.core.models.modules import module
 from eva.core.models.modules.typings import INPUT_BATCH, MODEL_TYPE
 from eva.core.models.modules.utils import batch_postprocess, grad
+from eva.core.utils import parser
 
 
 class HeadModule(module.ModelModule):
@@ -24,7 +25,7 @@ class HeadModule(module.ModelModule):
 
     def __init__(
         self,
-        head: MODEL_TYPE,
+        head: Dict[str, Any] | MODEL_TYPE,
         criterion: Callable[..., torch.Tensor],
         backbone: MODEL_TYPE | None = None,
         optimizer: OptimizerCallable = optim.Adam,
@@ -36,6 +37,8 @@ class HeadModule(module.ModelModule):
 
         Args:
             head: The neural network that would be trained on the features.
+                If it's a dictionary, it will be parsed to an object during the
+                `configure_model` step.
             criterion: The loss function to use.
             backbone: The feature extractor. If `None`, it will be expected
                 that the input batch returns the features directly.
@@ -48,7 +51,7 @@ class HeadModule(module.ModelModule):
         """
         super().__init__(metrics=metrics, postprocess=postprocess)
 
-        self.head = head
+        self.head = head  # type: ignore
         self.criterion = criterion
         self.backbone = backbone
         self.optimizer = optimizer
@@ -59,6 +62,9 @@ class HeadModule(module.ModelModule):
         if self.backbone is not None:
             grad.deactivate_requires_grad(self.backbone)
 
+        if isinstance(self.head, dict):
+            self.head: MODEL_TYPE = parser.parse_object(self.head, expected_type=nn.Module)
+
     @override
     def configure_optimizers(self) -> Any:
         parameters = self.head.parameters()
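The dictionary form of `head` follows jsonargparse's `class_path`/`init_args` scheme and is only resolved inside `configure_model`. A minimal sketch of both forms (the layer sizes are illustrative):

from torch import nn

from eva.core.models import HeadModule

# Eager form: pass an instantiated module.
head_module = HeadModule(head=nn.Linear(384, 4), criterion=nn.CrossEntropyLoss())

# Lazy form: pass a jsonargparse-style dictionary, resolved to an
# `nn.Module` when `configure_model` runs.
lazy_head_module = HeadModule(
    head={"class_path": "torch.nn.Linear", "init_args": {"in_features": 384, "out_features": 4}},
    criterion=nn.CrossEntropyLoss(),
)
lazy_head_module.configure_model()  # self.head is now a torch.nn.Linear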
eva/core/models/modules/typings.py CHANGED
@@ -16,7 +16,20 @@ class INPUT_BATCH(NamedTuple):
     data: torch.Tensor
     """The data batch."""
 
-    targets: torch.Tensor | Dict[str, Any] | None = None
+    targets: torch.Tensor | None = None
+    """The target batch."""
+
+    metadata: Dict[str, Any] | None = None
+    """The associated metadata."""
+
+
+class INPUT_TENSOR_BATCH(NamedTuple):
+    """Tensor based input batch data scheme."""
+
+    data: torch.Tensor
+    """The data batch."""
+
+    targets: torch.Tensor
     """The target batch."""
 
     metadata: Dict[str, Any] | None = None
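A short sketch of the two schemes; `INPUT_TENSOR_BATCH` is the stricter variant whose `targets` field can no longer be `None`:

import torch

from eva.core.models.modules.typings import INPUT_BATCH, INPUT_TENSOR_BATCH

batch = INPUT_BATCH(data=torch.randn(4, 3))  # targets and metadata default to None
tensor_batch = INPUT_TENSOR_BATCH(data=torch.randn(4, 3), targets=torch.zeros(4))
assert tensor_batch.targets is not None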
eva/core/models/modules/utils/batch_postprocess.py CHANGED
@@ -2,9 +2,10 @@
 
 import dataclasses
 import functools
-from typing import Callable, List
+from typing import Any, Callable, Dict, List
 
 import torch
+from jsonargparse import _util
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 
 Transform = Callable[[torch.Tensor], torch.Tensor]
@@ -15,10 +16,10 @@ Transform = Callable[[torch.Tensor], torch.Tensor]
 class BatchPostProcess:
     """Batch post-processes transform schema."""
 
-    targets_transforms: List[Transform] | None = None
+    targets_transforms: List[Transform | Dict[str, Any]] | None = None
     """Holds the common train and evaluation metrics."""
 
-    predictions_transforms: List[Transform] | None = None
+    predictions_transforms: List[Transform | Dict[str, Any]] | None = None
     """Holds the common train and evaluation metrics."""
 
     def __call__(self, outputs: STEP_OUTPUT) -> None:
@@ -35,12 +36,13 @@ class BatchPostProcess:
 
         if "targets" in outputs and self.targets_transforms is not None:
             outputs["targets"] = _apply_transforms(
-                outputs["targets"], transforms=self.targets_transforms
+                outputs["targets"], transforms=_parse_callable_inputs(self.targets_transforms)
             )
 
         if "predictions" in outputs and self.predictions_transforms is not None:
             outputs["predictions"] = _apply_transforms(
-                outputs["predictions"], transforms=self.predictions_transforms
+                outputs["predictions"],
+                transforms=_parse_callable_inputs(self.predictions_transforms),
             )
 
 
@@ -55,3 +57,33 @@ def _apply_transforms(tensor: torch.Tensor, transforms: List[Transform]) -> torc
         The processed tensor.
     """
     return functools.reduce(lambda tensor, transform: transform(tensor), transforms, tensor)
+
+
+def _parse_callable_inputs(inputs: List[Callable | Dict[str, Any]]) -> List[Callable]:
+    """Parses the inputs which were passed as dictionaries to callable objects."""
+    parsed = []
+    for item in inputs:
+        if isinstance(item, dict):
+            item = _parse_dict(item)
+        parsed.append(item)
+    return parsed
+
+
+def _parse_dict(item: Dict[str, Any]) -> Callable:
+    """Parses the input dictionary to a partial callable object."""
+    if not _is_valid_dict(item):
+        raise ValueError(
+            "Transform dictionary format is not valid. "
+            "It must contain a key 'class_path' and optionally 'init_args' for "
+            "the function and additional call arguments."
+        )
+
+    return functools.partial(
+        _util.import_object(item["class_path"]),
+        **item.get("init_args", {}),
+    )
+
+
+def _is_valid_dict(item: Dict[str, Any], /) -> bool:
+    """Checks if the input has the valid structure."""
+    return "class_path" in item and set(item.keys()) <= {"class_path", "init_args"}
eva/core/models/networks/__init__.py CHANGED
@@ -1,6 +1,5 @@
 """Networks API."""
 
 from eva.core.models.networks.mlp import MLP
-from eva.core.models.networks.wrappers import HuggingFaceModel, ModelFromFunction, ONNXModel
 
-__all__ = ["ModelFromFunction", "HuggingFaceModel", "ONNXModel", "MLP"]
+__all__ = ["MLP"]
eva/core/models/networks/mlp.py CHANGED
@@ -1,6 +1,6 @@
 """Multi-layer Perceptron (MLP) implemented in PyTorch."""
 
-from typing import Type
+from typing import Tuple, Type
 
 import torch
 import torch.nn as nn
@@ -13,7 +13,7 @@ class MLP(nn.Module):
         self,
         input_size: int,
         output_size: int,
-        hidden_layer_sizes: tuple[int, ...] | None = None,
+        hidden_layer_sizes: Tuple[int, ...] | None = None,
         hidden_activation_fn: Type[torch.nn.Module] | None = nn.ReLU,
         output_activation_fn: Type[torch.nn.Module] | None = None,
         dropout: float = 0.0,
eva/core/models/transforms/__init__.py ADDED
@@ -0,0 +1,6 @@
+"""Model outputs transforms API."""
+
+from eva.core.models.transforms.extract_cls_features import ExtractCLSFeatures
+from eva.core.models.transforms.extract_patch_features import ExtractPatchFeatures
+
+__all__ = ["ExtractCLSFeatures", "ExtractPatchFeatures"]
eva/core/models/{networks/transforms → transforms}/extract_cls_features.py RENAMED
@@ -7,6 +7,14 @@ from transformers import modeling_outputs
 class ExtractCLSFeatures:
     """Extracts the CLS token from a ViT model output."""
 
+    def __init__(self, cls_index: int = 0) -> None:
+        """Initializes the transformation.
+
+        Args:
+            cls_index: The index of the CLS token in the output tensor.
+        """
+        self._cls_index = cls_index
+
     def __call__(
         self, tensor: torch.Tensor | modeling_outputs.BaseModelOutputWithPooling
     ) -> torch.Tensor:
@@ -16,9 +24,9 @@ class ExtractCLSFeatures:
             tensor: The tensor representing the model output.
         """
         if isinstance(tensor, torch.Tensor):
-            transformed_tensor = tensor[:, 0, :]
+            transformed_tensor = tensor[:, self._cls_index, :]
         elif isinstance(tensor, modeling_outputs.BaseModelOutputWithPooling):
-            transformed_tensor = tensor.last_hidden_state[:, 0, :]
+            transformed_tensor = tensor.last_hidden_state[:, self._cls_index, :]
         else:
             raise ValueError(f"Unsupported type {type(tensor)}")
 
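A quick usage sketch: for a plain ViT token tensor the transform simply slices out the configured token index.

import torch

from eva.core.models.transforms import ExtractCLSFeatures

tokens = torch.randn(8, 197, 384)  # (batch, 1 CLS + 196 patch tokens, hidden)
cls_features = ExtractCLSFeatures()(tokens)
assert cls_features.shape == (8, 384)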
eva/core/models/transforms/extract_patch_features.py ADDED
@@ -0,0 +1,47 @@
+"""Transforms for extracting the patch features from a model output."""
+
+import math
+from typing import List
+
+import torch
+from transformers import modeling_outputs
+
+
+class ExtractPatchFeatures:
+    """Extracts the patch features from a ViT model output."""
+
+    def __init__(self, ignore_remaining_dims: bool = False) -> None:
+        """Initializes the transformation.
+
+        Args:
+            ignore_remaining_dims: If set to `True`, ignore the remaining dimensions
+                of the patch grid if it is not a square number.
+        """
+        self._ignore_remaining_dims = ignore_remaining_dims
+
+    def __call__(
+        self, tensor: torch.Tensor | modeling_outputs.BaseModelOutputWithPooling
+    ) -> List[torch.Tensor]:
+        """Call method for the transformation.
+
+        Args:
+            tensor: The raw embeddings of the model.
+
+        Returns:
+            A tensor (batch_size, hidden_size, n_patches_height, n_patches_width)
+            representing the model output.
+        """
+        if isinstance(tensor, modeling_outputs.BaseModelOutputWithPooling):
+            features = tensor.last_hidden_state[:, 1:, :].permute(0, 2, 1)
+            batch_size, hidden_size, patch_grid = features.shape
+            height = width = int(math.sqrt(patch_grid))
+            if height * width != patch_grid:
+                if self._ignore_remaining_dims:
+                    features = features[:, :, : height * width]
+                else:
+                    raise ValueError(f"Patch grid size must be a square number {patch_grid}.")
+            patch_embeddings = features.view(batch_size, hidden_size, height, width)
+        else:
+            raise ValueError(f"Unsupported type {type(tensor)}")
+
+        return [patch_embeddings]
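A sketch of the expected shapes, assuming a ViT-style `BaseModelOutputWithPooling` with one CLS token followed by a square 14×14 patch grid:

import torch
from transformers import modeling_outputs

from eva.core.models.transforms import ExtractPatchFeatures

output = modeling_outputs.BaseModelOutputWithPooling(
    last_hidden_state=torch.randn(2, 1 + 14 * 14, 384),
)
(patch_features,) = ExtractPatchFeatures()(output)
assert patch_features.shape == (2, 384, 14, 14)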
eva/core/models/wrappers/__init__.py ADDED
@@ -0,0 +1,13 @@
+"""Model Wrappers API."""
+
+from eva.core.models.wrappers.base import BaseModel
+from eva.core.models.wrappers.from_function import ModelFromFunction
+from eva.core.models.wrappers.huggingface import HuggingFaceModel
+from eva.core.models.wrappers.onnx import ONNXModel
+
+__all__ = [
+    "BaseModel",
+    "ModelFromFunction",
+    "HuggingFaceModel",
+    "ONNXModel",
+]
eva/core/models/{networks/wrappers → wrappers}/base.py RENAMED
@@ -22,6 +22,8 @@ class BaseModel(nn.Module):
 
         self._output_transforms = tensor_transforms
 
+        self._model: Callable[..., torch.Tensor] | nn.Module
+
     @override
     def forward(self, tensor: torch.Tensor) -> torch.Tensor:
         tensor = self.model_forward(tensor)
@@ -32,14 +34,13 @@ class BaseModel(nn.Module):
         """Loads the model."""
         raise NotImplementedError
 
-    @abc.abstractmethod
     def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
         """Implements the forward pass of the model.
 
         Args:
             tensor: The input tensor to the model.
         """
-        raise NotImplementedError
+        return self._model(tensor)
 
     def _apply_transforms(self, tensor: torch.Tensor) -> torch.Tensor:
         if self._output_transforms is not None:
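With the default `model_forward` now delegating to `self._model`, a subclass only has to populate that attribute in `load_model`. A hypothetical minimal subclass (not part of the package):

import torch
from torch import nn
from typing_extensions import override

from eva.core.models.wrappers import BaseModel


class IdentityWrapper(BaseModel):
    """Hypothetical wrapper around a no-op model."""

    def __init__(self) -> None:
        super().__init__(tensor_transforms=None)
        self.load_model()

    @override
    def load_model(self) -> None:
        self._model = nn.Identity()


wrapper = IdentityWrapper()
assert torch.equal(wrapper(torch.ones(2)), torch.ones(2))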
eva/core/models/{networks/wrappers → wrappers}/from_function.py RENAMED
@@ -3,12 +3,10 @@
 from typing import Any, Callable, Dict
 
 import jsonargparse
-import torch
 from torch import nn
 from typing_extensions import override
 
-from eva.core.models.networks import _utils
-from eva.core.models.networks.wrappers import base
+from eva.core.models.wrappers import _utils, base
 
 
 class ModelFromFunction(base.BaseModel):
@@ -36,23 +34,18 @@ class ModelFromFunction(base.BaseModel):
         tensor_transforms: The transforms to apply to the output tensor
             produced by the model.
         """
-        super().__init__()
+        super().__init__(tensor_transforms=tensor_transforms)
 
         self._path = path
         self._arguments = arguments
         self._checkpoint_path = checkpoint_path
-        self._tensor_transforms = tensor_transforms
 
-        self._model = self.load_model()
+        self.load_model()
 
     @override
-    def load_model(self) -> nn.Module:
+    def load_model(self) -> None:
         class_path = jsonargparse.class_from_function(self._path, func_return=nn.Module)
         model = class_path(**self._arguments or {})
         if self._checkpoint_path is not None:
             _utils.load_model_weights(model, self._checkpoint_path)
-        return model
-
-    @override
-    def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
-        return self._model(tensor)
+        self._model = model
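For illustration, wrapping a model factory function; this assumes torchvision is installed, but any function returning an `nn.Module` works:

import torchvision

from eva.core.models.wrappers import ModelFromFunction

model = ModelFromFunction(
    path=torchvision.models.resnet18,  # any Callable[..., nn.Module]
    arguments={"num_classes": 4},      # forwarded to the function
)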
eva/core/models/{networks/wrappers → wrappers}/huggingface.py RENAMED
@@ -1,18 +1,22 @@
 """Wrappers for HuggingFace `transformers` models."""
 
-from typing import Any, Callable
+from typing import Any, Callable, Dict
 
-import torch
 import transformers
 from typing_extensions import override
 
-from eva.core.models.networks.wrappers import base
+from eva.core.models.wrappers import base
 
 
 class HuggingFaceModel(base.BaseModel):
     """Wrapper class for loading HuggingFace `transformers` models."""
 
-    def __init__(self, model_name_or_path: str, tensor_transforms: Callable | None = None) -> None:
+    def __init__(
+        self,
+        model_name_or_path: str,
+        tensor_transforms: Callable | None = None,
+        model_kwargs: Dict[str, Any] | None = None,
+    ) -> None:
         """Initializes the model.
 
         Args:
@@ -21,17 +25,17 @@ class HuggingFaceModel(base.BaseModel):
                 model hub.
             tensor_transforms: The transforms to apply to the output tensor
                 produced by the model.
+            model_kwargs: The arguments used for instantiating the model.
         """
         super().__init__(tensor_transforms=tensor_transforms)
 
         self._model_name_or_path = model_name_or_path
-        self._model = self.load_model()
+        self._model_kwargs = model_kwargs or {}
 
-    @override
-    def load_model(self) -> Any:
-        config = transformers.AutoConfig.from_pretrained(self._model_name_or_path)
-        return transformers.AutoModel.from_pretrained(self._model_name_or_path, config=config)
+        self.load_model()
 
     @override
-    def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
-        return self._model(tensor)
+    def load_model(self) -> None:
+        self._model = transformers.AutoModel.from_pretrained(
+            self._model_name_or_path, **self._model_kwargs
+        )
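`model_kwargs` is forwarded verbatim to `transformers.AutoModel.from_pretrained`. An illustrative sketch (the checkpoint name is only an example, and instantiating the wrapper downloads its weights):

from eva.core.models.wrappers import HuggingFaceModel

model = HuggingFaceModel(
    model_name_or_path="owkin/phikon",
    model_kwargs={"output_hidden_states": True},
)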
eva/core/models/{networks/wrappers → wrappers}/onnx.py RENAMED
@@ -6,7 +6,7 @@ import onnxruntime as ort
 import torch
 from typing_extensions import override
 
-from eva.core.models.networks.wrappers import base
+from eva.core.models.wrappers import base
 
 
 class ONNXModel(base.BaseModel):
@@ -29,19 +29,22 @@ class ONNXModel(base.BaseModel):
 
         self._path = path
         self._device = device
-        self._model = self.load_model()
+
+        self.load_model()
 
     @override
     def load_model(self) -> Any:
         if self._device == "cuda" and not torch.cuda.is_available():
             raise ValueError("Device is set to 'cuda', but CUDA is not available.")
         provider = "CUDAExecutionProvider" if self._device == "cuda" else "CPUExecutionProvider"
-        return ort.InferenceSession(self._path, providers=[provider])
+        self._model = ort.InferenceSession(self._path, providers=[provider])  # type: ignore
 
     @override
     def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
         # TODO: Use IO binding to avoid copying the tensor to CPU.
         # https://onnxruntime.ai/docs/api/python/api_summary.html#data-on-device
+        if not isinstance(self._model, ort.InferenceSession):
+            raise ValueError("Model is not loaded.")
         inputs = {self._model.get_inputs()[0].name: tensor.detach().cpu().numpy()}
         outputs = self._model.run(None, inputs)[0]
         return torch.from_numpy(outputs).float().to(tensor.device)
eva/core/trainers/functional.py CHANGED
@@ -69,6 +69,7 @@ def run_evaluation(
         A tuple with the validation and the test metrics (if present).
     """
     trainer, model = _utils.clone(base_trainer, base_model)
+    model.configure_model()
     trainer.setup_log_dirs(run_id or "")
     return fit_and_validate(trainer, model, datamodule, verbose=verbose)
 
eva/core/utils/__init__.py CHANGED
@@ -1 +1,7 @@
 """Utilities and library level helper functionalities."""
+
+from eva.core.utils.clone import clone
+from eva.core.utils.memory import to_cpu
+from eva.core.utils.operations import numeric_sort
+
+__all__ = ["clone", "to_cpu", "numeric_sort"]
eva/core/utils/clone.py ADDED
@@ -0,0 +1,27 @@
+"""Clone related functions."""
+
+import functools
+from typing import Any, Dict, List
+
+import torch
+
+
+@functools.singledispatch
+def clone(tensor_type: Any) -> Any:
+    """Clones tensor objects."""
+    raise TypeError(f"Unsupported input type: {type(tensor_type)}.")
+
+
+@clone.register
+def _(tensor: torch.Tensor) -> torch.Tensor:
+    return tensor.clone()
+
+
+@clone.register
+def _(tensors: list) -> List[torch.Tensor]:
+    return list(map(clone, tensors))
+
+
+@clone.register
+def _(tensors: dict) -> Dict[str, torch.Tensor]:
+    return {key: clone(tensors[key]) for key in tensors}
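The singledispatch recurses through lists and dicts, so nested batch structures are cloned element-wise:

import torch

from eva.core.utils import clone

batch = {"features": torch.ones(2, 3), "labels": [torch.zeros(2)]}
copied = clone(batch)
copied["features"] += 1
assert batch["features"].sum() == 6  # the original tensors are untouched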
eva/core/utils/memory.py ADDED
@@ -0,0 +1,28 @@
+"""Memory related functions."""
+
+import functools
+from typing import Any, Dict, List
+
+import torch
+
+
+@functools.singledispatch
+def to_cpu(tensor_type: Any) -> Any:
+    """Moves tensor objects to `cpu`."""
+    raise TypeError(f"Unsupported input type: {type(tensor_type)}.")
+
+
+@to_cpu.register
+def _(tensor: torch.Tensor) -> torch.Tensor:
+    detached_tensor = tensor.detach()
+    return detached_tensor.cpu()
+
+
+@to_cpu.register
+def _(tensors: list) -> List[torch.Tensor]:
+    return list(map(to_cpu, tensors))
+
+
+@to_cpu.register
+def _(tensors: dict) -> Dict[str, torch.Tensor]:
+    return {key: to_cpu(tensors[key]) for key in tensors}
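`to_cpu` detaches before moving, so it is safe to call on tensors that still require grad:

import torch

from eva.core.utils import to_cpu

outputs = {"logits": torch.randn(4, 2, requires_grad=True)}
cpu_outputs = to_cpu(outputs)
assert not cpu_outputs["logits"].requires_grad
assert cpu_outputs["logits"].device.type == "cpu"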
eva/core/utils/operations.py ADDED
@@ -0,0 +1,26 @@
+"""Functional operations."""
+
+import re
+from typing import Iterable, List
+
+
+def numeric_sort(item: Iterable[str], /) -> List[str]:
+    """Sorts an iterable of strings treating embedded numbers as numeric values.
+
+    Here the strings are compared based on their numeric value rather than their
+    string representation.
+
+    Args:
+        item: An iterable of strings to be sorted.
+
+    Returns:
+        A list of strings sorted based on their numeric values.
+    """
+    return sorted(
+        item,
+        key=lambda value: re.sub(
+            r"(\d+)",
+            lambda num: f"{int(num[0]):010d}",
+            value,
+        ),
+    )
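The zero-padding trick makes embedded numbers compare numerically, which matters for file names such as patch or slide indices:

from eva.core.utils import numeric_sort

files = ["slide_10.tiff", "slide_2.tiff", "slide_1.tiff"]
assert sorted(files) == ["slide_1.tiff", "slide_10.tiff", "slide_2.tiff"]
assert numeric_sort(files) == ["slide_1.tiff", "slide_2.tiff", "slide_10.tiff"]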
eva/core/utils/parser.py ADDED
@@ -0,0 +1,20 @@
+"""Parsing related helper functions."""
+
+from typing import Any, Dict
+
+import jsonargparse
+
+
+def parse_object(config: Dict[str, Any], expected_type: Any = Any) -> Any:
+    """Parses an object which is defined as a dictionary."""
+    parser = jsonargparse.ArgumentParser()
+    parser.add_argument("module", type=expected_type)
+    configuration = parser.parse_object({"module": config})
+    init_object = parser.instantiate_classes(configuration)
+    obj_module = init_object.module
+    if isinstance(obj_module, jsonargparse.Namespace):
+        raise ValueError(
+            f"Failed to parse object '{obj_module.class_path}'. "
+            "Please check that the initialization arguments are valid."
+        )
+    return obj_module
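A minimal sketch of the helper that `HeadModule.configure_model` relies on:

from torch import nn

from eva.core.utils import parser

activation = parser.parse_object({"class_path": "torch.nn.ReLU"}, expected_type=nn.Module)
assert isinstance(activation, nn.ReLU)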
eva/vision/__init__.py CHANGED
@@ -1,7 +1,7 @@
 """eva vision API."""
 
 try:
-    from eva.vision import models, utils
+    from eva.vision import callbacks, losses, models, utils
     from eva.vision.data import datasets, transforms
 except ImportError as e:
     msg = (
@@ -11,4 +11,4 @@ except ImportError as e:
     )
     raise ImportError(str(e) + "\n\n" + msg) from e
 
-__all__ = ["models", "utils", "datasets", "transforms"]
+__all__ = ["callbacks", "losses", "models", "utils", "datasets", "transforms"]
eva/vision/callbacks/__init__.py ADDED
@@ -0,0 +1,5 @@
+"""Vision callbacks API."""
+
+from eva.vision.callbacks.loggers import SemanticSegmentationLogger
+
+__all__ = ["SemanticSegmentationLogger"]

eva/vision/callbacks/loggers/__init__.py ADDED
@@ -0,0 +1,5 @@
+"""Vision logging related callbacks API."""
+
+from eva.vision.callbacks.loggers.batch import SemanticSegmentationLogger
+
+__all__ = ["SemanticSegmentationLogger"]

eva/vision/callbacks/loggers/batch/__init__.py ADDED
@@ -0,0 +1,5 @@
+"""Batch related loggers callbacks API."""
+
+from eva.vision.callbacks.loggers.batch.segmentation import SemanticSegmentationLogger
+
+__all__ = ["SemanticSegmentationLogger"]