kaiko-eva 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kaiko-eva might be problematic.

Files changed (168)
  1. eva/core/callbacks/__init__.py +3 -2
  2. eva/core/callbacks/config.py +143 -0
  3. eva/core/callbacks/writers/__init__.py +6 -3
  4. eva/core/callbacks/writers/embeddings/__init__.py +6 -0
  5. eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
  6. eva/core/callbacks/writers/embeddings/base.py +192 -0
  7. eva/core/callbacks/writers/embeddings/classification.py +117 -0
  8. eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
  9. eva/core/callbacks/writers/embeddings/typings.py +38 -0
  10. eva/core/data/datasets/__init__.py +10 -2
  11. eva/core/data/datasets/classification/__init__.py +5 -2
  12. eva/core/data/datasets/classification/embeddings.py +15 -135
  13. eva/core/data/datasets/classification/multi_embeddings.py +110 -0
  14. eva/core/data/datasets/embeddings.py +167 -0
  15. eva/core/data/splitting/__init__.py +6 -0
  16. eva/core/data/splitting/random.py +41 -0
  17. eva/core/data/splitting/stratified.py +56 -0
  18. eva/core/data/transforms/__init__.py +3 -1
  19. eva/core/data/transforms/padding/__init__.py +5 -0
  20. eva/core/data/transforms/padding/pad_2d_tensor.py +38 -0
  21. eva/core/data/transforms/sampling/__init__.py +5 -0
  22. eva/core/data/transforms/sampling/sample_from_axis.py +40 -0
  23. eva/core/loggers/__init__.py +7 -0
  24. eva/core/loggers/dummy.py +38 -0
  25. eva/core/loggers/experimental_loggers.py +8 -0
  26. eva/core/loggers/log/__init__.py +6 -0
  27. eva/core/loggers/log/image.py +71 -0
  28. eva/core/loggers/log/parameters.py +74 -0
  29. eva/core/loggers/log/utils.py +13 -0
  30. eva/core/loggers/loggers.py +6 -0
  31. eva/core/metrics/__init__.py +6 -2
  32. eva/core/metrics/defaults/__init__.py +10 -3
  33. eva/core/metrics/defaults/classification/__init__.py +1 -1
  34. eva/core/metrics/defaults/classification/binary.py +0 -9
  35. eva/core/metrics/defaults/classification/multiclass.py +0 -8
  36. eva/core/metrics/defaults/segmentation/__init__.py +5 -0
  37. eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
  38. eva/core/metrics/generalized_dice.py +59 -0
  39. eva/core/metrics/mean_iou.py +120 -0
  40. eva/core/metrics/structs/schemas.py +3 -1
  41. eva/core/models/__init__.py +3 -1
  42. eva/core/models/modules/head.py +16 -15
  43. eva/core/models/modules/module.py +25 -1
  44. eva/core/models/modules/typings.py +14 -1
  45. eva/core/models/modules/utils/batch_postprocess.py +37 -5
  46. eva/core/models/networks/__init__.py +1 -2
  47. eva/core/models/networks/mlp.py +2 -2
  48. eva/core/models/transforms/__init__.py +6 -0
  49. eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
  50. eva/core/models/transforms/extract_patch_features.py +47 -0
  51. eva/core/models/wrappers/__init__.py +13 -0
  52. eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
  53. eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
  54. eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
  55. eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
  56. eva/core/trainers/_recorder.py +69 -7
  57. eva/core/trainers/functional.py +23 -5
  58. eva/core/trainers/trainer.py +20 -6
  59. eva/core/utils/__init__.py +6 -0
  60. eva/core/utils/clone.py +27 -0
  61. eva/core/utils/memory.py +28 -0
  62. eva/core/utils/operations.py +26 -0
  63. eva/core/utils/parser.py +20 -0
  64. eva/vision/__init__.py +2 -2
  65. eva/vision/callbacks/__init__.py +5 -0
  66. eva/vision/callbacks/loggers/__init__.py +5 -0
  67. eva/vision/callbacks/loggers/batch/__init__.py +5 -0
  68. eva/vision/callbacks/loggers/batch/base.py +130 -0
  69. eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
  70. eva/vision/data/datasets/__init__.py +24 -4
  71. eva/vision/data/datasets/_utils.py +3 -3
  72. eva/vision/data/datasets/_validators.py +15 -2
  73. eva/vision/data/datasets/classification/__init__.py +6 -2
  74. eva/vision/data/datasets/classification/bach.py +10 -15
  75. eva/vision/data/datasets/classification/base.py +17 -24
  76. eva/vision/data/datasets/classification/camelyon16.py +244 -0
  77. eva/vision/data/datasets/classification/crc.py +10 -15
  78. eva/vision/data/datasets/classification/mhist.py +10 -15
  79. eva/vision/data/datasets/classification/panda.py +184 -0
  80. eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
  81. eva/vision/data/datasets/classification/wsi.py +105 -0
  82. eva/vision/data/datasets/segmentation/__init__.py +15 -2
  83. eva/vision/data/datasets/segmentation/_utils.py +38 -0
  84. eva/vision/data/datasets/segmentation/base.py +31 -47
  85. eva/vision/data/datasets/segmentation/bcss.py +236 -0
  86. eva/vision/data/datasets/segmentation/consep.py +156 -0
  87. eva/vision/data/datasets/segmentation/embeddings.py +34 -0
  88. eva/vision/data/datasets/segmentation/lits.py +178 -0
  89. eva/vision/data/datasets/segmentation/monusac.py +236 -0
  90. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +325 -0
  91. eva/vision/data/datasets/wsi.py +187 -0
  92. eva/vision/data/transforms/__init__.py +3 -2
  93. eva/vision/data/transforms/common/__init__.py +2 -1
  94. eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
  95. eva/vision/data/transforms/common/resize_and_crop.py +6 -7
  96. eva/vision/data/transforms/normalization/__init__.py +6 -0
  97. eva/vision/data/transforms/normalization/clamp.py +43 -0
  98. eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
  99. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
  100. eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
  101. eva/vision/data/wsi/__init__.py +16 -0
  102. eva/vision/data/wsi/backends/__init__.py +69 -0
  103. eva/vision/data/wsi/backends/base.py +115 -0
  104. eva/vision/data/wsi/backends/openslide.py +73 -0
  105. eva/vision/data/wsi/backends/pil.py +52 -0
  106. eva/vision/data/wsi/backends/tiffslide.py +42 -0
  107. eva/vision/data/wsi/patching/__init__.py +6 -0
  108. eva/vision/data/wsi/patching/coordinates.py +98 -0
  109. eva/vision/data/wsi/patching/mask.py +123 -0
  110. eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
  111. eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
  112. eva/vision/data/wsi/patching/samplers/base.py +48 -0
  113. eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
  114. eva/vision/data/wsi/patching/samplers/grid.py +47 -0
  115. eva/vision/data/wsi/patching/samplers/random.py +41 -0
  116. eva/vision/losses/__init__.py +5 -0
  117. eva/vision/losses/dice.py +40 -0
  118. eva/vision/models/__init__.py +4 -2
  119. eva/vision/models/modules/__init__.py +5 -0
  120. eva/vision/models/modules/semantic_segmentation.py +161 -0
  121. eva/vision/models/networks/__init__.py +1 -2
  122. eva/vision/models/networks/backbones/__init__.py +6 -0
  123. eva/vision/models/networks/backbones/_utils.py +39 -0
  124. eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
  125. eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
  126. eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
  127. eva/vision/models/networks/backbones/pathology/histai.py +46 -0
  128. eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
  129. eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
  130. eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
  131. eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
  132. eva/vision/models/networks/backbones/registry.py +47 -0
  133. eva/vision/models/networks/backbones/timm/__init__.py +5 -0
  134. eva/vision/models/networks/backbones/timm/backbones.py +54 -0
  135. eva/vision/models/networks/backbones/universal/__init__.py +8 -0
  136. eva/vision/models/networks/backbones/universal/vit.py +54 -0
  137. eva/vision/models/networks/decoders/__init__.py +6 -0
  138. eva/vision/models/networks/decoders/decoder.py +7 -0
  139. eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
  140. eva/vision/models/networks/decoders/segmentation/common.py +74 -0
  141. eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
  142. eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
  143. eva/vision/models/wrappers/__init__.py +6 -0
  144. eva/vision/models/wrappers/from_registry.py +48 -0
  145. eva/vision/models/wrappers/from_timm.py +68 -0
  146. eva/vision/utils/colormap.py +77 -0
  147. eva/vision/utils/convert.py +67 -0
  148. eva/vision/utils/io/__init__.py +10 -4
  149. eva/vision/utils/io/image.py +21 -2
  150. eva/vision/utils/io/mat.py +36 -0
  151. eva/vision/utils/io/nifti.py +40 -15
  152. eva/vision/utils/io/text.py +10 -3
  153. kaiko_eva-0.1.0.dist-info/METADATA +553 -0
  154. kaiko_eva-0.1.0.dist-info/RECORD +205 -0
  155. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
  156. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
  157. eva/core/callbacks/writers/embeddings.py +0 -169
  158. eva/core/callbacks/writers/typings.py +0 -23
  159. eva/core/models/networks/transforms/__init__.py +0 -5
  160. eva/core/models/networks/wrappers/__init__.py +0 -8
  161. eva/vision/data/datasets/classification/total_segmentator.py +0 -213
  162. eva/vision/data/datasets/segmentation/total_segmentator.py +0 -212
  163. eva/vision/models/networks/postprocesses/__init__.py +0 -5
  164. eva/vision/models/networks/postprocesses/cls.py +0 -25
  165. kaiko_eva-0.0.1.dist-info/METADATA +0 -405
  166. kaiko_eva-0.0.1.dist-info/RECORD +0 -110
  167. /eva/core/models/{networks → wrappers}/_utils.py +0 -0
  168. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0

eva/core/models/modules/module.py
@@ -4,6 +4,7 @@ from typing import Any, Mapping

 import lightning.pytorch as pl
 import torch
+from lightning.pytorch.strategies.single_device import SingleDeviceStrategy
 from lightning.pytorch.utilities import memory
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 from typing_extensions import override
@@ -46,6 +47,21 @@ class ModelModule(pl.LightningModule):
         """The default post-processes."""
         return batch_postprocess.BatchPostProcess()

+    @property
+    def metrics_device(self) -> torch.device:
+        """Returns the device by which the metrics should be calculated.
+
+        We allocate the metrics to CPU when operating on single device, as
+        it is much faster, but to GPU when employing multiple ones, as DDP
+        strategy requires the metrics to be allocated to the module's GPU.
+        """
+        move_to_cpu = isinstance(self.trainer.strategy, SingleDeviceStrategy)
+        return torch.device("cpu") if move_to_cpu else self.device
+
+    @override
+    def on_fit_start(self) -> None:
+        self.metrics.to(device=self.metrics_device)
+
     @override
     def on_train_batch_end(
         self,
@@ -59,6 +75,10 @@ class ModelModule(pl.LightningModule):
             batch_outputs=outputs,
         )

+    @override
+    def on_validation_start(self) -> None:
+        self.metrics.to(device=self.metrics_device)
+
     @override
     def on_validation_batch_end(
         self,
@@ -78,6 +98,10 @@ class ModelModule(pl.LightningModule):
     def on_validation_epoch_end(self) -> None:
         self._compute_and_log_metrics(self.metrics.validation_metrics)

+    @override
+    def on_test_start(self) -> None:
+        self.metrics.to(device=self.metrics_device)
+
     @override
     def on_test_batch_end(
         self,
@@ -110,7 +134,7 @@ class ModelModule(pl.LightningModule):
             The updated outputs.
         """
         self._postprocess(outputs)
-        return memory.recursive_detach(outputs, to_cpu=self.device.type == "cpu")
+        return memory.recursive_detach(outputs, to_cpu=self.metrics_device.type == "cpu")

     def _forward_and_log_metrics(
         self,
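
Note: the metrics_device property above decides where the torchmetrics state lives. A minimal standalone sketch of the same rule (the helper name and inputs are assumptions for illustration, not part of the diff):

import torch
from lightning.pytorch.strategies.single_device import SingleDeviceStrategy


def select_metrics_device(strategy: object, module_device: torch.device) -> torch.device:
    # Single-device runs keep metrics on CPU (faster); multi-device (DDP) runs
    # keep them on the module's own device, as each rank syncs its metric state.
    move_to_cpu = isinstance(strategy, SingleDeviceStrategy)
    return torch.device("cpu") if move_to_cpu else module_device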

eva/core/models/modules/typings.py
@@ -16,7 +16,20 @@ class INPUT_BATCH(NamedTuple):
     data: torch.Tensor
     """The data batch."""

-    targets: torch.Tensor | Dict[str, Any] | None = None
+    targets: torch.Tensor | None = None
+    """The target batch."""
+
+    metadata: Dict[str, Any] | None = None
+    """The associated metadata."""
+
+
+class INPUT_TENSOR_BATCH(NamedTuple):
+    """Tensor based input batch data scheme."""
+
+    data: torch.Tensor
+    """The data batch."""
+
+    targets: torch.Tensor
     """The target batch."""

     metadata: Dict[str, Any] | None = None

eva/core/models/modules/utils/batch_postprocess.py
@@ -2,9 +2,10 @@

 import dataclasses
 import functools
-from typing import Callable, List
+from typing import Any, Callable, Dict, List

 import torch
+from jsonargparse import _util
 from lightning.pytorch.utilities.types import STEP_OUTPUT

 Transform = Callable[[torch.Tensor], torch.Tensor]
@@ -15,10 +16,10 @@ Transform = Callable[[torch.Tensor], torch.Tensor]
 class BatchPostProcess:
     """Batch post-processes transform schema."""

-    targets_transforms: List[Transform] | None = None
+    targets_transforms: List[Transform | Dict[str, Any]] | None = None
     """Holds the common train and evaluation metrics."""

-    predictions_transforms: List[Transform] | None = None
+    predictions_transforms: List[Transform | Dict[str, Any]] | None = None
     """Holds the common train and evaluation metrics."""

     def __call__(self, outputs: STEP_OUTPUT) -> None:
@@ -35,12 +36,13 @@ class BatchPostProcess:

         if "targets" in outputs and self.targets_transforms is not None:
             outputs["targets"] = _apply_transforms(
-                outputs["targets"], transforms=self.targets_transforms
+                outputs["targets"], transforms=_parse_callable_inputs(self.targets_transforms)
             )

         if "predictions" in outputs and self.predictions_transforms is not None:
             outputs["predictions"] = _apply_transforms(
-                outputs["predictions"], transforms=self.predictions_transforms
+                outputs["predictions"],
+                transforms=_parse_callable_inputs(self.predictions_transforms),
             )


@@ -55,3 +57,33 @@ def _apply_transforms(tensor: torch.Tensor, transforms: List[Transform]) -> torc
         The processed tensor.
     """
     return functools.reduce(lambda tensor, transform: transform(tensor), transforms, tensor)
+
+
+def _parse_callable_inputs(inputs: List[Callable | Dict[str, Any]]) -> List[Callable]:
+    """Parses the inputs which where passed as dictionary to callable objects."""
+    parsed = []
+    for item in inputs:
+        if isinstance(item, dict):
+            item = _parse_dict(item)
+        parsed.append(item)
+    return parsed
+
+
+def _parse_dict(item: Dict[str, Any]) -> Callable:
+    """Parses the input dictionary to a partial callable object."""
+    if not _is_valid_dict(item):
+        raise ValueError(
+            "Transform dictionary format is not valid. "
+            "It must contain a key 'class_path' and optionally 'init_args' for "
+            "the function and additional call arguments."
+        )
+
+    return functools.partial(
+        _util.import_object(item["class_path"]),
+        **item.get("init_args", {}),
+    )
+
+
+def _is_valid_dict(item: Dict[str, Any], /) -> bool:
+    """Checks if the input has the valid structure."""
+    return "class_path" in item and set(item.keys()) <= {"class_path", "init_args"}

eva/core/models/networks/__init__.py
@@ -1,6 +1,5 @@
 """Networks API."""

 from eva.core.models.networks.mlp import MLP
-from eva.core.models.networks.wrappers import HuggingFaceModel, ModelFromFunction, ONNXModel

-__all__ = ["ModelFromFunction", "HuggingFaceModel", "ONNXModel", "MLP"]
+__all__ = ["MLP"]

eva/core/models/networks/mlp.py
@@ -1,6 +1,6 @@
 """Multi-layer Perceptron (MLP) implemented in PyTorch."""

-from typing import Type
+from typing import Tuple, Type

 import torch
 import torch.nn as nn
@@ -13,7 +13,7 @@ class MLP(nn.Module):
         self,
         input_size: int,
         output_size: int,
-        hidden_layer_sizes: tuple[int, ...] | None = None,
+        hidden_layer_sizes: Tuple[int, ...] | None = None,
         hidden_activation_fn: Type[torch.nn.Module] | None = nn.ReLU,
         output_activation_fn: Type[torch.nn.Module] | None = None,
         dropout: float = 0.0,

eva/core/models/transforms/__init__.py
@@ -0,0 +1,6 @@
+"""Model outputs transforms API."""
+
+from eva.core.models.transforms.extract_cls_features import ExtractCLSFeatures
+from eva.core.models.transforms.extract_patch_features import ExtractPatchFeatures
+
+__all__ = ["ExtractCLSFeatures", "ExtractPatchFeatures"]

eva/core/models/{networks/transforms → transforms}/extract_cls_features.py
@@ -7,6 +7,14 @@ from transformers import modeling_outputs
 class ExtractCLSFeatures:
     """Extracts the CLS token from a ViT model output."""

+    def __init__(self, cls_index: int = 0) -> None:
+        """Initializes the transformation.
+
+        Args:
+            cls_index: The index of the CLS token in the output tensor.
+        """
+        self._cls_index = cls_index
+
     def __call__(
         self, tensor: torch.Tensor | modeling_outputs.BaseModelOutputWithPooling
     ) -> torch.Tensor:
@@ -16,9 +24,9 @@ class ExtractCLSFeatures:
             tensor: The tensor representing the model output.
         """
         if isinstance(tensor, torch.Tensor):
-            transformed_tensor = tensor[:, 0, :]
+            transformed_tensor = tensor[:, self._cls_index, :]
         elif isinstance(tensor, modeling_outputs.BaseModelOutputWithPooling):
-            transformed_tensor = tensor.last_hidden_state[:, 0, :]
+            transformed_tensor = tensor.last_hidden_state[:, self._cls_index, :]
         else:
             raise ValueError(f"Unsupported type {type(tensor)}")


eva/core/models/transforms/extract_patch_features.py
@@ -0,0 +1,47 @@
+"""Transforms for extracting the patch features from a model output."""
+
+import math
+from typing import List
+
+import torch
+from transformers import modeling_outputs
+
+
+class ExtractPatchFeatures:
+    """Extracts the patch features from a ViT model output."""
+
+    def __init__(self, ignore_remaining_dims: bool = False) -> None:
+        """Initializes the transformation.
+
+        Args:
+            ignore_remaining_dims: If set to `True`, ignore the remaining dimensions
+                of the patch grid if it is not a square number.
+        """
+        self._ignore_remaining_dims = ignore_remaining_dims
+
+    def __call__(
+        self, tensor: torch.Tensor | modeling_outputs.BaseModelOutputWithPooling
+    ) -> List[torch.Tensor]:
+        """Call method for the transformation.
+
+        Args:
+            tensor: The raw embeddings of the model.
+
+        Returns:
+            A tensor (batch_size, hidden_size, n_patches_height, n_patches_width)
+            representing the model output.
+        """
+        if isinstance(tensor, modeling_outputs.BaseModelOutputWithPooling):
+            features = tensor.last_hidden_state[:, 1:, :].permute(0, 2, 1)
+            batch_size, hidden_size, patch_grid = features.shape
+            height = width = int(math.sqrt(patch_grid))
+            if height * width != patch_grid:
+                if self._ignore_remaining_dims:
+                    features = features[:, :, : height * width]
+                else:
+                    raise ValueError(f"Patch grid size must be a square number {patch_grid}.")
+            patch_embeddings = features.view(batch_size, hidden_size, height, width)
+        else:
+            raise ValueError(f"Unsupported type {type(tensor)}")
+
+        return [patch_embeddings]
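
Note: both transforms are exported from eva.core.models.transforms (see the new __init__ above). A brief sketch with illustrative shapes: for a ViT-style output of 197 tokens (1 CLS token plus a 14x14 patch grid) and hidden size 768, ExtractCLSFeatures slices a (batch, 768) embedding while ExtractPatchFeatures drops the CLS token and rebuilds a (batch, 768, 14, 14) feature map.

import torch
from transformers import modeling_outputs

from eva.core.models.transforms import ExtractCLSFeatures, ExtractPatchFeatures

# Fake ViT output: 1 CLS token followed by a 14x14 grid of patch tokens.
output = modeling_outputs.BaseModelOutputWithPooling(
    last_hidden_state=torch.randn(2, 197, 768),
)

cls_embedding = ExtractCLSFeatures(cls_index=0)(output)   # default keeps the previous behaviour
(patch_features,) = ExtractPatchFeatures()(output)

print(cls_embedding.shape)    # torch.Size([2, 768])
print(patch_features.shape)   # torch.Size([2, 768, 14, 14])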

eva/core/models/wrappers/__init__.py
@@ -0,0 +1,13 @@
+"""Model Wrappers API."""
+
+from eva.core.models.wrappers.base import BaseModel
+from eva.core.models.wrappers.from_function import ModelFromFunction
+from eva.core.models.wrappers.huggingface import HuggingFaceModel
+from eva.core.models.wrappers.onnx import ONNXModel
+
+__all__ = [
+    "BaseModel",
+    "ModelFromFunction",
+    "HuggingFaceModel",
+    "ONNXModel",
+]

eva/core/models/{networks/wrappers → wrappers}/base.py
@@ -22,6 +22,8 @@ class BaseModel(nn.Module):

         self._output_transforms = tensor_transforms

+        self._model: Callable[..., torch.Tensor] | nn.Module
+
     @override
     def forward(self, tensor: torch.Tensor) -> torch.Tensor:
         tensor = self.model_forward(tensor)
@@ -32,14 +34,13 @@ class BaseModel(nn.Module):
         """Loads the model."""
         raise NotImplementedError

-    @abc.abstractmethod
     def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
         """Implements the forward pass of the model.

         Args:
             tensor: The input tensor to the model.
         """
-        raise NotImplementedError
+        return self._model(tensor)

     def _apply_transforms(self, tensor: torch.Tensor) -> torch.Tensor:
         if self._output_transforms is not None:

eva/core/models/{networks/wrappers → wrappers}/from_function.py
@@ -3,12 +3,10 @@
 from typing import Any, Callable, Dict

 import jsonargparse
-import torch
 from torch import nn
 from typing_extensions import override

-from eva.core.models.networks import _utils
-from eva.core.models.networks.wrappers import base
+from eva.core.models.wrappers import _utils, base


 class ModelFromFunction(base.BaseModel):
@@ -36,23 +34,18 @@ class ModelFromFunction(base.BaseModel):
             tensor_transforms: The transforms to apply to the output tensor
                 produced by the model.
         """
-        super().__init__()
+        super().__init__(tensor_transforms=tensor_transforms)

         self._path = path
         self._arguments = arguments
         self._checkpoint_path = checkpoint_path
-        self._tensor_transforms = tensor_transforms

-        self._model = self.load_model()
+        self.load_model()

     @override
-    def load_model(self) -> nn.Module:
+    def load_model(self) -> None:
         class_path = jsonargparse.class_from_function(self._path, func_return=nn.Module)
         model = class_path(**self._arguments or {})
         if self._checkpoint_path is not None:
             _utils.load_model_weights(model, self._checkpoint_path)
-        return model
-
-    @override
-    def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
-        return self._model(tensor)
+        self._model = model

eva/core/models/{networks/wrappers → wrappers}/huggingface.py
@@ -1,18 +1,22 @@
 """Wrappers for HuggingFace `transformers` models."""

-from typing import Any, Callable
+from typing import Any, Callable, Dict

-import torch
 import transformers
 from typing_extensions import override

-from eva.core.models.networks.wrappers import base
+from eva.core.models.wrappers import base


 class HuggingFaceModel(base.BaseModel):
     """Wrapper class for loading HuggingFace `transformers` models."""

-    def __init__(self, model_name_or_path: str, tensor_transforms: Callable | None = None) -> None:
+    def __init__(
+        self,
+        model_name_or_path: str,
+        tensor_transforms: Callable | None = None,
+        model_kwargs: Dict[str, Any] | None = None,
+    ) -> None:
         """Initializes the model.

         Args:
@@ -21,17 +25,17 @@ class HuggingFaceModel(base.BaseModel):
                 model hub.
             tensor_transforms: The transforms to apply to the output tensor
                 produced by the model.
+            model_kwargs: The arguments used for instantiating the model.
         """
         super().__init__(tensor_transforms=tensor_transforms)

         self._model_name_or_path = model_name_or_path
-        self._model = self.load_model()
+        self._model_kwargs = model_kwargs or {}

-    @override
-    def load_model(self) -> Any:
-        config = transformers.AutoConfig.from_pretrained(self._model_name_or_path)
-        return transformers.AutoModel.from_pretrained(self._model_name_or_path, config=config)
+        self.load_model()

     @override
-    def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
-        return self._model(tensor)
+    def load_model(self) -> None:
+        self._model = transformers.AutoModel.from_pretrained(
+            self._model_name_or_path, **self._model_kwargs
+        )
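
Note: extra keyword arguments now flow straight into transformers.AutoModel.from_pretrained. A hypothetical instantiation sketch (the checkpoint name and kwargs are illustrative, not prescribed by the diff):

from eva.core.models.transforms import ExtractCLSFeatures
from eva.core.models.wrappers import HuggingFaceModel

backbone = HuggingFaceModel(
    model_name_or_path="owkin/phikon",            # example HF checkpoint
    tensor_transforms=ExtractCLSFeatures(),       # keep only the CLS embedding
    model_kwargs={"add_pooling_layer": False},    # forwarded to from_pretrained
)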

eva/core/models/{networks/wrappers → wrappers}/onnx.py
@@ -6,7 +6,7 @@ import onnxruntime as ort
 import torch
 from typing_extensions import override

-from eva.core.models.networks.wrappers import base
+from eva.core.models.wrappers import base


 class ONNXModel(base.BaseModel):
@@ -29,19 +29,22 @@ class ONNXModel(base.BaseModel):

         self._path = path
         self._device = device
-        self._model = self.load_model()
+
+        self.load_model()

     @override
     def load_model(self) -> Any:
         if self._device == "cuda" and not torch.cuda.is_available():
             raise ValueError("Device is set to 'cuda', but CUDA is not available.")
         provider = "CUDAExecutionProvider" if self._device == "cuda" else "CPUExecutionProvider"
-        return ort.InferenceSession(self._path, providers=[provider])
+        self._model = ort.InferenceSession(self._path, providers=[provider])  # type: ignore

     @override
     def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
         # TODO: Use IO binding to avoid copying the tensor to CPU.
         # https://onnxruntime.ai/docs/api/python/api_summary.html#data-on-device
+        if not isinstance(self._model, ort.InferenceSession):
+            raise ValueError("Model is not loaded.")
         inputs = {self._model.get_inputs()[0].name: tensor.detach().cpu().numpy()}
         outputs = self._model.run(None, inputs)[0]
         return torch.from_numpy(outputs).float().to(tensor.device)

eva/core/trainers/_recorder.py
@@ -5,18 +5,41 @@ import json
 import os
 import statistics
 import sys
-from typing import Any, Dict, List, Mapping
+from typing import Dict, List, Mapping, TypedDict

 from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT
 from lightning_fabric.utilities import cloud_io
 from loguru import logger
 from omegaconf import OmegaConf
+from rich import console as rich_console
+from rich import table as rich_table
 from toolz import dicttoolz

 SESSION_METRICS = Mapping[str, List[float]]
 """Session metrics type-hint."""


+class SESSION_STATISTICS(TypedDict):
+    """Type-hint for aggregated metrics of multiple runs with mean & stdev."""
+
+    mean: float
+    stdev: float
+    values: List[float]
+
+
+class STAGE_RESULTS(TypedDict):
+    """Type-hint for metrics statstics for val & test stages."""
+
+    val: List[Dict[str, SESSION_STATISTICS]]
+    test: List[Dict[str, SESSION_STATISTICS]]
+
+
+class RESULTS_DICT(TypedDict):
+    """Type-hint for the final results dictionary."""
+
+    metrics: STAGE_RESULTS
+
+
 class SessionRecorder:
     """Multi-run (session) summary logger."""

@@ -25,6 +48,7 @@ class SessionRecorder:
         output_dir: str,
         results_file: str = "results.json",
         config_file: str = "config.yaml",
+        verbose: bool = True,
     ) -> None:
         """Initializes the recorder.

@@ -32,10 +56,12 @@ class SessionRecorder:
             output_dir: The destination folder to save the results.
             results_file: The name of the results json file.
             config_file: The name of the yaml configuration file.
+            verbose: Whether to print the session metrics.
         """
         self._output_dir = output_dir
         self._results_file = results_file
         self._config_file = config_file
+        self._verbose = verbose

         self._validation_metrics: List[SESSION_METRICS] = []
         self._test_metrics: List[SESSION_METRICS] = []
@@ -67,13 +93,13 @@ class SessionRecorder:
         self._update_validation_metrics(validation_scores)
         self._update_test_metrics(test_scores)

-    def compute(self) -> Dict[str, List[Dict[str, Any]]]:
+    def compute(self) -> STAGE_RESULTS:
         """Computes and returns the session statistics."""
         validation_statistics = list(map(_calculate_statistics, self._validation_metrics))
         test_statistics = list(map(_calculate_statistics, self._test_metrics))
         return {"val": validation_statistics, "test": test_statistics}

-    def export(self) -> Dict[str, Any]:
+    def export(self) -> RESULTS_DICT:
         """Exports the results."""
         statistics = self.compute()
         return {"metrics": statistics}
@@ -83,6 +109,8 @@ class SessionRecorder:
         results = self.export()
         _save_json(results, self.filename)
         self._save_config()
+        if self._verbose:
+            _print_results(results)

     def reset(self) -> None:
         """Resets the state of the tracked metrics."""
@@ -125,10 +153,10 @@ def _init_session_metrics(n_datasets: int) -> List[SESSION_METRICS]:
     return [collections.defaultdict(list) for _ in range(n_datasets)]


-def _calculate_statistics(session_metrics: SESSION_METRICS) -> Dict[str, float | List[float]]:
+def _calculate_statistics(session_metrics: SESSION_METRICS) -> Dict[str, SESSION_STATISTICS]:
     """Calculate the metric statistics of a dataset session run."""

-    def _calculate_metric_statistics(values: List[float]) -> Dict[str, float | List[float]]:
+    def _calculate_metric_statistics(values: List[float]) -> SESSION_STATISTICS:
         """Calculates and returns the metric statistics."""
         mean = statistics.mean(values)
         stdev = statistics.stdev(values) if len(values) > 1 else 0
@@ -137,7 +165,7 @@ def _calculate_statistics(session_metrics: SESSION_METRICS) -> Dict[str, float |
     return dicttoolz.valmap(_calculate_metric_statistics, session_metrics)


-def _save_json(data: Dict[str, Any], save_as: str = "data.json"):
+def _save_json(data: RESULTS_DICT, save_as: str = "data.json"):
     """Saves data to a json file."""
     if not save_as.endswith(".json"):
         raise ValueError()
@@ -146,4 +174,38 @@ def _save_json(data: Dict[str, Any], save_as: str = "data.json"):
     fs = cloud_io.get_filesystem(output_dir, anon=False)
     fs.makedirs(output_dir, exist_ok=True)
     with fs.open(save_as, "w") as file:
-        json.dump(data, file, indent=4, sort_keys=True)
+        json.dump(data, file, indent=2, sort_keys=True)
+
+
+def _print_results(results: RESULTS_DICT) -> None:
+    """Prints the results to the console."""
+    try:
+        for stage in ["val", "test"]:
+            for dataset_idx in range(len(results["metrics"][stage])):
+                _print_table(results["metrics"][stage][dataset_idx], stage, dataset_idx)
+    except Exception as e:
+        logger.error(f"Failed to print the results: {e}")
+
+
+def _print_table(metrics_dict: Dict[str, SESSION_STATISTICS], stage: str, dataset_idx: int):
+    """Prints the metrics of a single dataset as a table."""
+    metrics_table = rich_table.Table(
+        title=f"\n{stage.capitalize()} Dataset {dataset_idx}", title_style="bold"
+    )
+    metrics_table.add_column("Metric", style="cyan")
+    metrics_table.add_column("Mean", style="magenta")
+    metrics_table.add_column("Stdev", style="magenta")
+    metrics_table.add_column("All", style="magenta")
+
+    n_runs = len(metrics_dict[next(iter(metrics_dict))]["values"])
+    for metric_name, metric_dict in metrics_dict.items():
+        row = [
+            metric_name,
+            f'{metric_dict["mean"]:.3f}',
+            f'{metric_dict["stdev"]:.3f}',
+            ", ".join(f'{metric_dict["values"][i]:.3f}' for i in range(n_runs)),
+        ]
+        metrics_table.add_row(*row)
+
+    console = rich_console.Console()
+    console.print(metrics_table)
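
Note: with the new TypedDicts and rich tables, SessionRecorder.save() writes a structured results.json and, when verbose, prints one table per dataset and stage. A sketch of the resulting layout implied by RESULTS_DICT, STAGE_RESULTS and SESSION_STATISTICS (metric names and numbers are illustrative):

results = {
    "metrics": {
        "val": [{"MulticlassAccuracy": {"mean": 0.861, "stdev": 0.004, "values": [0.858, 0.865]}}],
        "test": [{"MulticlassAccuracy": {"mean": 0.842, "stdev": 0.006, "values": [0.838, 0.846]}}],
    }
}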

eva/core/trainers/functional.py
@@ -16,6 +16,7 @@ def run_evaluation_session(
     datamodule: datamodules.DataModule,
     *,
     n_runs: int = 1,
+    verbose: bool = True,
 ) -> None:
     """Runs a downstream evaluation session out-of-place.

@@ -29,11 +30,17 @@ def run_evaluation_session(
         base_model: The base model module to use.
         datamodule: The data module.
         n_runs: The amount of runs (fit and evaluate) to perform.
+        verbose: Whether to verbose the session metrics instead of
+            these of each individual runs and vice-versa.
     """
-    recorder = _recorder.SessionRecorder(output_dir=base_trainer.default_log_dir)
+    recorder = _recorder.SessionRecorder(output_dir=base_trainer.default_log_dir, verbose=verbose)
     for run_index in range(n_runs):
         validation_scores, test_scores = run_evaluation(
-            base_trainer, base_model, datamodule, run_id=f"run_{run_index}"
+            base_trainer,
+            base_model,
+            datamodule,
+            run_id=f"run_{run_index}",
+            verbose=not verbose,
         )
         recorder.update(validation_scores, test_scores)
     recorder.save()
@@ -45,6 +52,7 @@ def run_evaluation(
     datamodule: datamodules.DataModule,
     *,
     run_id: str | None = None,
+    verbose: bool = True,
 ) -> Tuple[_EVALUATE_OUTPUT, _EVALUATE_OUTPUT | None]:
     """Fits and evaluates a model out-of-place.

@@ -54,19 +62,23 @@ def run_evaluation(
         datamodule: The data module.
         run_id: The run id to be appended to the output log directory.
             If `None`, it will use the log directory of the trainer as is.
+        verbose: Whether to print the validation and test metrics
+            in the end of the training.

     Returns:
         A tuple of with the validation and the test metrics (if exists).
     """
     trainer, model = _utils.clone(base_trainer, base_model)
+    model.configure_model()
     trainer.setup_log_dirs(run_id or "")
-    return fit_and_validate(trainer, model, datamodule)
+    return fit_and_validate(trainer, model, datamodule, verbose=verbose)


 def fit_and_validate(
     trainer: eva_trainer.Trainer,
     model: modules.ModelModule,
     datamodule: datamodules.DataModule,
+    verbose: bool = True,
 ) -> Tuple[_EVALUATE_OUTPUT, _EVALUATE_OUTPUT | None]:
     """Fits and evaluates a model in-place.

@@ -77,13 +89,19 @@ def fit_and_validate(
         trainer: The trainer module to use and update in-place.
         model: The model module to use and update in-place.
         datamodule: The data module.
+        verbose: Whether to print the validation and test metrics
+            in the end of the training.

     Returns:
         A tuple of with the validation and the test metrics (if exists).
     """
     trainer.fit(model, datamodule=datamodule)
-    validation_scores = trainer.validate(datamodule=datamodule)
-    test_scores = None if datamodule.datasets.test is None else trainer.test(datamodule=datamodule)
+    validation_scores = trainer.validate(datamodule=datamodule, verbose=verbose)
+    test_scores = (
+        None
+        if datamodule.datasets.test is None
+        else trainer.test(datamodule=datamodule, verbose=verbose)
+    )
     return validation_scores, test_scores
