kaiko-eva 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Files changed (168)
  1. eva/core/callbacks/__init__.py +3 -2
  2. eva/core/callbacks/config.py +143 -0
  3. eva/core/callbacks/writers/__init__.py +6 -3
  4. eva/core/callbacks/writers/embeddings/__init__.py +6 -0
  5. eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
  6. eva/core/callbacks/writers/embeddings/base.py +192 -0
  7. eva/core/callbacks/writers/embeddings/classification.py +117 -0
  8. eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
  9. eva/core/callbacks/writers/embeddings/typings.py +38 -0
  10. eva/core/data/datasets/__init__.py +10 -2
  11. eva/core/data/datasets/classification/__init__.py +5 -2
  12. eva/core/data/datasets/classification/embeddings.py +15 -135
  13. eva/core/data/datasets/classification/multi_embeddings.py +110 -0
  14. eva/core/data/datasets/embeddings.py +167 -0
  15. eva/core/data/splitting/__init__.py +6 -0
  16. eva/core/data/splitting/random.py +41 -0
  17. eva/core/data/splitting/stratified.py +56 -0
  18. eva/core/data/transforms/__init__.py +3 -1
  19. eva/core/data/transforms/padding/__init__.py +5 -0
  20. eva/core/data/transforms/padding/pad_2d_tensor.py +38 -0
  21. eva/core/data/transforms/sampling/__init__.py +5 -0
  22. eva/core/data/transforms/sampling/sample_from_axis.py +40 -0
  23. eva/core/loggers/__init__.py +7 -0
  24. eva/core/loggers/dummy.py +38 -0
  25. eva/core/loggers/experimental_loggers.py +8 -0
  26. eva/core/loggers/log/__init__.py +6 -0
  27. eva/core/loggers/log/image.py +71 -0
  28. eva/core/loggers/log/parameters.py +74 -0
  29. eva/core/loggers/log/utils.py +13 -0
  30. eva/core/loggers/loggers.py +6 -0
  31. eva/core/metrics/__init__.py +6 -2
  32. eva/core/metrics/defaults/__init__.py +10 -3
  33. eva/core/metrics/defaults/classification/__init__.py +1 -1
  34. eva/core/metrics/defaults/classification/binary.py +0 -9
  35. eva/core/metrics/defaults/classification/multiclass.py +0 -8
  36. eva/core/metrics/defaults/segmentation/__init__.py +5 -0
  37. eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
  38. eva/core/metrics/generalized_dice.py +59 -0
  39. eva/core/metrics/mean_iou.py +120 -0
  40. eva/core/metrics/structs/schemas.py +3 -1
  41. eva/core/models/__init__.py +3 -1
  42. eva/core/models/modules/head.py +16 -15
  43. eva/core/models/modules/module.py +25 -1
  44. eva/core/models/modules/typings.py +14 -1
  45. eva/core/models/modules/utils/batch_postprocess.py +37 -5
  46. eva/core/models/networks/__init__.py +1 -2
  47. eva/core/models/networks/mlp.py +2 -2
  48. eva/core/models/transforms/__init__.py +6 -0
  49. eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
  50. eva/core/models/transforms/extract_patch_features.py +47 -0
  51. eva/core/models/wrappers/__init__.py +13 -0
  52. eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
  53. eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
  54. eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
  55. eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
  56. eva/core/trainers/_recorder.py +69 -7
  57. eva/core/trainers/functional.py +23 -5
  58. eva/core/trainers/trainer.py +20 -6
  59. eva/core/utils/__init__.py +6 -0
  60. eva/core/utils/clone.py +27 -0
  61. eva/core/utils/memory.py +28 -0
  62. eva/core/utils/operations.py +26 -0
  63. eva/core/utils/parser.py +20 -0
  64. eva/vision/__init__.py +2 -2
  65. eva/vision/callbacks/__init__.py +5 -0
  66. eva/vision/callbacks/loggers/__init__.py +5 -0
  67. eva/vision/callbacks/loggers/batch/__init__.py +5 -0
  68. eva/vision/callbacks/loggers/batch/base.py +130 -0
  69. eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
  70. eva/vision/data/datasets/__init__.py +24 -4
  71. eva/vision/data/datasets/_utils.py +3 -3
  72. eva/vision/data/datasets/_validators.py +15 -2
  73. eva/vision/data/datasets/classification/__init__.py +6 -2
  74. eva/vision/data/datasets/classification/bach.py +10 -15
  75. eva/vision/data/datasets/classification/base.py +17 -24
  76. eva/vision/data/datasets/classification/camelyon16.py +244 -0
  77. eva/vision/data/datasets/classification/crc.py +10 -15
  78. eva/vision/data/datasets/classification/mhist.py +10 -15
  79. eva/vision/data/datasets/classification/panda.py +184 -0
  80. eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
  81. eva/vision/data/datasets/classification/wsi.py +105 -0
  82. eva/vision/data/datasets/segmentation/__init__.py +15 -2
  83. eva/vision/data/datasets/segmentation/_utils.py +38 -0
  84. eva/vision/data/datasets/segmentation/base.py +31 -47
  85. eva/vision/data/datasets/segmentation/bcss.py +236 -0
  86. eva/vision/data/datasets/segmentation/consep.py +156 -0
  87. eva/vision/data/datasets/segmentation/embeddings.py +34 -0
  88. eva/vision/data/datasets/segmentation/lits.py +178 -0
  89. eva/vision/data/datasets/segmentation/monusac.py +236 -0
  90. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +325 -0
  91. eva/vision/data/datasets/wsi.py +187 -0
  92. eva/vision/data/transforms/__init__.py +3 -2
  93. eva/vision/data/transforms/common/__init__.py +2 -1
  94. eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
  95. eva/vision/data/transforms/common/resize_and_crop.py +6 -7
  96. eva/vision/data/transforms/normalization/__init__.py +6 -0
  97. eva/vision/data/transforms/normalization/clamp.py +43 -0
  98. eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
  99. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
  100. eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
  101. eva/vision/data/wsi/__init__.py +16 -0
  102. eva/vision/data/wsi/backends/__init__.py +69 -0
  103. eva/vision/data/wsi/backends/base.py +115 -0
  104. eva/vision/data/wsi/backends/openslide.py +73 -0
  105. eva/vision/data/wsi/backends/pil.py +52 -0
  106. eva/vision/data/wsi/backends/tiffslide.py +42 -0
  107. eva/vision/data/wsi/patching/__init__.py +6 -0
  108. eva/vision/data/wsi/patching/coordinates.py +98 -0
  109. eva/vision/data/wsi/patching/mask.py +123 -0
  110. eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
  111. eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
  112. eva/vision/data/wsi/patching/samplers/base.py +48 -0
  113. eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
  114. eva/vision/data/wsi/patching/samplers/grid.py +47 -0
  115. eva/vision/data/wsi/patching/samplers/random.py +41 -0
  116. eva/vision/losses/__init__.py +5 -0
  117. eva/vision/losses/dice.py +40 -0
  118. eva/vision/models/__init__.py +4 -2
  119. eva/vision/models/modules/__init__.py +5 -0
  120. eva/vision/models/modules/semantic_segmentation.py +161 -0
  121. eva/vision/models/networks/__init__.py +1 -2
  122. eva/vision/models/networks/backbones/__init__.py +6 -0
  123. eva/vision/models/networks/backbones/_utils.py +39 -0
  124. eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
  125. eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
  126. eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
  127. eva/vision/models/networks/backbones/pathology/histai.py +46 -0
  128. eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
  129. eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
  130. eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
  131. eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
  132. eva/vision/models/networks/backbones/registry.py +47 -0
  133. eva/vision/models/networks/backbones/timm/__init__.py +5 -0
  134. eva/vision/models/networks/backbones/timm/backbones.py +54 -0
  135. eva/vision/models/networks/backbones/universal/__init__.py +8 -0
  136. eva/vision/models/networks/backbones/universal/vit.py +54 -0
  137. eva/vision/models/networks/decoders/__init__.py +6 -0
  138. eva/vision/models/networks/decoders/decoder.py +7 -0
  139. eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
  140. eva/vision/models/networks/decoders/segmentation/common.py +74 -0
  141. eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
  142. eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
  143. eva/vision/models/wrappers/__init__.py +6 -0
  144. eva/vision/models/wrappers/from_registry.py +48 -0
  145. eva/vision/models/wrappers/from_timm.py +68 -0
  146. eva/vision/utils/colormap.py +77 -0
  147. eva/vision/utils/convert.py +67 -0
  148. eva/vision/utils/io/__init__.py +10 -4
  149. eva/vision/utils/io/image.py +21 -2
  150. eva/vision/utils/io/mat.py +36 -0
  151. eva/vision/utils/io/nifti.py +40 -15
  152. eva/vision/utils/io/text.py +10 -3
  153. kaiko_eva-0.1.0.dist-info/METADATA +553 -0
  154. kaiko_eva-0.1.0.dist-info/RECORD +205 -0
  155. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
  156. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
  157. eva/core/callbacks/writers/embeddings.py +0 -169
  158. eva/core/callbacks/writers/typings.py +0 -23
  159. eva/core/models/networks/transforms/__init__.py +0 -5
  160. eva/core/models/networks/wrappers/__init__.py +0 -8
  161. eva/vision/data/datasets/classification/total_segmentator.py +0 -213
  162. eva/vision/data/datasets/segmentation/total_segmentator.py +0 -212
  163. eva/vision/models/networks/postprocesses/__init__.py +0 -5
  164. eva/vision/models/networks/postprocesses/cls.py +0 -25
  165. kaiko_eva-0.0.1.dist-info/METADATA +0 -405
  166. kaiko_eva-0.0.1.dist-info/RECORD +0 -110
  167. /eva/core/models/{networks → wrappers}/_utils.py +0 -0
  168. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
eva/core/data/datasets/classification/__init__.py
@@ -1,5 +1,8 @@
-"""Classification datasets API."""
+"""Embedding classification datasets API."""
 
 from eva.core.data.datasets.classification.embeddings import EmbeddingsClassificationDataset
+from eva.core.data.datasets.classification.multi_embeddings import (
+    MultiEmbeddingsClassificationDataset,
+)
 
-__all__ = ["EmbeddingsClassificationDataset"]
+__all__ = ["EmbeddingsClassificationDataset", "MultiEmbeddingsClassificationDataset"]
eva/core/data/datasets/classification/embeddings.py
@@ -1,154 +1,34 @@
 """Embeddings classification dataset."""
 
 import os
-from typing import Callable, Dict, Tuple
 
-import numpy as np
-import pandas as pd
 import torch
 from typing_extensions import override
 
-from eva.core.data.datasets import base
-from eva.core.utils import io
+from eva.core.data.datasets import embeddings as embeddings_base
 
 
-class EmbeddingsClassificationDataset(base.Dataset):
-    """Embeddings classification dataset."""
-
-    default_column_mapping: Dict[str, str] = {
-        "data": "embeddings",
-        "target": "target",
-        "split": "split",
-    }
-    """The default column mapping of the variables to the manifest columns."""
-
-    def __init__(
-        self,
-        root: str,
-        manifest_file: str,
-        split: str | None = None,
-        column_mapping: Dict[str, str] = default_column_mapping,
-        embeddings_transforms: Callable | None = None,
-        target_transforms: Callable | None = None,
-    ) -> None:
-        """Initialize dataset.
-
-        Expects a manifest file listing the paths of .pt files that contain
-        tensor embeddings of shape [embedding_dim] or [1, embedding_dim].
-
-        Args:
-            root: Root directory of the dataset.
-            manifest_file: The path to the manifest file, which is relative to
-                the `root` argument.
-            split: The dataset split to use. The `split` column of the manifest
-                file will be splitted based on this value.
-            column_mapping: Defines the map between the variables and the manifest
-                columns. It will overwrite the `default_column_mapping` with
-                the provided values, so that `column_mapping` can contain only the
-                values which are altered or missing.
-            embeddings_transforms: A function/transform that transforms the embedding.
-            target_transforms: A function/transform that transforms the target.
-        """
-        super().__init__()
-
-        self._root = root
-        self._manifest_file = manifest_file
-        self._split = split
-        self._column_mapping = self.default_column_mapping | column_mapping
-        self._embeddings_transforms = embeddings_transforms
-        self._target_transforms = target_transforms
-
-        self._data: pd.DataFrame
-
-    def filename(self, index: int) -> str:
-        """Returns the filename of the `index`'th data sample.
-
-        Note that this is the relative file path to the root.
-
-        Args:
-            index: The index of the data-sample to select.
-
-        Returns:
-            The filename of the `index`'th data sample.
-        """
-        return self._data.at[index, self._column_mapping["data"]]
+class EmbeddingsClassificationDataset(embeddings_base.EmbeddingsDataset[torch.Tensor]):
+    """Embeddings dataset class for classification tasks."""
 
     @override
-    def setup(self):
-        self._data = self._load_manifest()
-
-    def __getitem__(self, index) -> Tuple[torch.Tensor, np.ndarray]:
-        """Returns the `index`'th data sample.
-
-        Args:
-            index: The index of the data-sample to select.
-
-        Returns:
-            A data sample and its target.
-        """
-        embeddings = self._load_embeddings(index)
-        target = self._load_target(index)
-        return self._apply_transforms(embeddings, target)
-
-    def __len__(self) -> int:
-        """Returns the total length of the data."""
-        return len(self._data)
-
     def _load_embeddings(self, index: int) -> torch.Tensor:
-        """Returns the `index`'th embedding sample.
-
-        Args:
-            index: The index of the data sample to load.
-
-        Returns:
-            The sample embedding as an array.
-        """
         filename = self.filename(index)
         embeddings_path = os.path.join(self._root, filename)
         tensor = torch.load(embeddings_path, map_location="cpu")
+        if isinstance(tensor, list):
+            if len(tensor) > 1:
+                raise ValueError(
+                    f"Expected a single tensor in the .pt file, but found {len(tensor)}."
+                )
+            tensor = tensor[0]
         return tensor.squeeze(0)
 
-    def _load_target(self, index: int) -> np.ndarray:
-        """Returns the `index`'th target sample.
-
-        Args:
-            index: The index of the data sample to load.
-
-        Returns:
-            The sample target as an array.
-        """
+    @override
+    def _load_target(self, index: int) -> torch.Tensor:
         target = self._data.at[index, self._column_mapping["target"]]
-        return np.asarray(target, dtype=np.int64)
-
-    def _load_manifest(self) -> pd.DataFrame:
-        """Loads manifest file and filters the data based on the split column.
-
-        Returns:
-            The data as a pandas DataFrame.
-        """
-        manifest_path = os.path.join(self._root, self._manifest_file)
-        data = io.read_dataframe(manifest_path)
-        if self._split is not None:
-            filtered_data = data.loc[data[self._column_mapping["split"]] == self._split]
-            data = filtered_data.reset_index(drop=True)
-        return data
+        return torch.tensor(target, dtype=torch.int64)
 
-    def _apply_transforms(
-        self, embeddings: torch.Tensor, target: np.ndarray
-    ) -> Tuple[torch.Tensor, np.ndarray]:
-        """Applies the transforms to the provided data and returns them.
-
-        Args:
-            embeddings: The embeddings to be transformed.
-            target: The training target.
-
-        Returns:
-            A tuple with the embeddings and the target transformed.
-        """
-        if self._embeddings_transforms is not None:
-            embeddings = self._embeddings_transforms(embeddings)
-
-        if self._target_transforms is not None:
-            target = self._target_transforms(target)
-
-        return embeddings, target
+    @override
+    def __len__(self) -> int:
+        return len(self._data)
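
For context, a minimal usage sketch of the slimmed-down class; the root directory, manifest name, and split value below are hypothetical, and `setup()` is called manually here only for illustration:

from eva.core.data.datasets.classification import EmbeddingsClassificationDataset

dataset = EmbeddingsClassificationDataset(
    root="/data/embeddings/patch_camelyon",  # hypothetical root directory
    manifest_file="manifest.csv",            # hypothetical manifest name
    split="train",
)
dataset.setup()  # loads the manifest and filters it by the split column
embedding, target = dataset[0]  # [embedding_dim] float tensor, scalar int64 tensor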
eva/core/data/datasets/classification/multi_embeddings.py
@@ -0,0 +1,110 @@
+"""Dataset class where a sample corresponds to multiple embeddings."""
+
+import os
+from typing import Callable, Dict, List, Literal
+
+import numpy as np
+import torch
+from typing_extensions import override
+
+from eva.core.data.datasets import embeddings as embeddings_base
+
+
+class MultiEmbeddingsClassificationDataset(embeddings_base.EmbeddingsDataset[torch.Tensor]):
+    """Dataset class where a sample corresponds to multiple embeddings.
+
+    Example use case: Slide level dataset where each slide has multiple patch embeddings.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        manifest_file: str,
+        split: Literal["train", "val", "test"],
+        column_mapping: Dict[str, str] = embeddings_base.default_column_mapping,
+        embeddings_transforms: Callable | None = None,
+        target_transforms: Callable | None = None,
+    ):
+        """Initialize dataset.
+
+        Expects a manifest file listing the paths of `.pt` files containing tensor embeddings.
+
+        The manifest must have a `column_mapping["multi_id"]` column that contains the
+        unique identifier of each group of embeddings. For oncology datasets, this would
+        usually be the slide id. Each row in the manifest file points to a .pt file that
+        can contain one or multiple embeddings (either as a list or stacked tensors). There
+        can also be multiple rows for the same `multi_id`, in which case the embeddings from
+        the different .pt files corresponding to that same `multi_id` will be stacked along
+        the first dimension.
+
+        Args:
+            root: Root directory of the dataset.
+            manifest_file: The path to the manifest file, which is relative to
+                the `root` argument.
+            split: The dataset split to use. The manifest is filtered by its
+                `split` column using this value.
+            column_mapping: Defines the map between the variables and the manifest
+                columns. It will overwrite the `default_column_mapping` with
+                the provided values, so that `column_mapping` can contain only the
+                values which are altered or missing.
+            embeddings_transforms: A function/transform that transforms the embedding.
+            target_transforms: A function/transform that transforms the target.
+        """
+        super().__init__(
+            manifest_file=manifest_file,
+            root=root,
+            split=split,
+            column_mapping=column_mapping,
+            embeddings_transforms=embeddings_transforms,
+            target_transforms=target_transforms,
+        )
+
+        self._multi_ids: List[int]
+
+    @override
+    def setup(self):
+        super().setup()
+        self._multi_ids = list(self._data[self._column_mapping["multi_id"]].unique())
+
+    @override
+    def _load_embeddings(self, index: int) -> torch.Tensor:
+        """Loads and stacks all embeddings corresponding to the `index`'th multi_id."""
+        # Get all embeddings for the given index (multi_id)
+        multi_id = self._multi_ids[index]
+        embedding_paths = self._data.loc[
+            self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["path"]
+        ].to_list()
+
+        # Load embeddings and stack them across the first dimension
+        embeddings = []
+        for path in embedding_paths:
+            embedding = torch.load(os.path.join(self._root, path), map_location="cpu")
+            if isinstance(embedding, list):
+                embedding = torch.stack(embedding, dim=0)
+            embeddings.append(embedding.unsqueeze(0) if embedding.ndim == 1 else embedding)
+        embeddings = torch.cat(embeddings, dim=0)
+
+        if not embeddings.ndim == 2:
+            raise ValueError(f"Expected 2D tensor, got {embeddings.ndim} for {multi_id}.")
+
+        return embeddings
+
+    @override
+    def _load_target(self, index: int) -> np.ndarray:
+        """Returns the target corresponding to the `index`'th multi_id.
+
+        This method assumes that all the embeddings corresponding to the same `multi_id`
+        have the same target. If this is not the case, it will raise an error.
+        """
+        multi_id = self._multi_ids[index]
+        targets = self._data.loc[
+            self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["target"]
+        ]
+
+        if not targets.nunique() == 1:
+            raise ValueError(f"Multiple targets found for {multi_id}.")
+
+        return np.asarray(targets.iloc[0], dtype=np.int64)
+
+    @override
+    def __len__(self) -> int:
+        return len(self._multi_ids)
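
The grouping logic is easiest to see with a toy manifest; in this hypothetical example two rows share one `wsi_id` (the default `multi_id` column), so their patch embeddings are concatenated into a single sample:

from eva.core.data.datasets.classification import MultiEmbeddingsClassificationDataset

# Hypothetical manifest contents (CSV):
#   embeddings,target,split,wsi_id
#   slide_1/patches_000.pt,1,train,slide_1
#   slide_1/patches_001.pt,1,train,slide_1
dataset = MultiEmbeddingsClassificationDataset(
    root="/data/embeddings/camelyon16",  # hypothetical
    manifest_file="manifest.csv",        # hypothetical
    split="train",
)
dataset.setup()
embeddings, target = dataset[0]  # embeddings: [n_patches, embedding_dim], one row per patch
len(dataset)  # number of unique wsi_id values, not manifest rows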
eva/core/data/datasets/embeddings.py
@@ -0,0 +1,167 @@
+"""Base dataset class for Embeddings."""
+
+import abc
+import multiprocessing
+import os
+from typing import Callable, Dict, Generic, Literal, Tuple, TypeVar
+
+import pandas as pd
+import torch
+from typing_extensions import override
+
+from eva.core.data.datasets import base
+from eva.core.utils import io
+
+TargetType = TypeVar("TargetType")
+"""The target data type."""
+
+
+default_column_mapping: Dict[str, str] = {
+    "path": "embeddings",
+    "target": "target",
+    "split": "split",
+    "multi_id": "wsi_id",
+}
+"""The default column mapping of the variables to the manifest columns."""
+
+
+class EmbeddingsDataset(base.Dataset, Generic[TargetType]):
+    """Abstract base class for embedding datasets."""
+
+    def __init__(
+        self,
+        root: str,
+        manifest_file: str,
+        split: Literal["train", "val", "test"] | None = None,
+        column_mapping: Dict[str, str] = default_column_mapping,
+        embeddings_transforms: Callable | None = None,
+        target_transforms: Callable | None = None,
+    ) -> None:
+        """Initialize dataset.
+
+        Expects a manifest file listing the paths of .pt files that contain
+        tensor embeddings of shape [embedding_dim] or [1, embedding_dim].
+
+        Args:
+            root: Root directory of the dataset.
+            manifest_file: The path to the manifest file, which is relative to
+                the `root` argument.
+            split: The dataset split to use. The manifest is filtered by its
+                `split` column using this value.
+            column_mapping: Defines the map between the variables and the manifest
+                columns. It will overwrite the `default_column_mapping` with
+                the provided values, so that `column_mapping` can contain only the
+                values which are altered or missing.
+            embeddings_transforms: A function/transform that transforms the embedding.
+            target_transforms: A function/transform that transforms the target.
+        """
+        super().__init__()
+
+        self._root = root
+        self._manifest_file = manifest_file
+        self._split = split
+        self._column_mapping = default_column_mapping | column_mapping
+        self._embeddings_transforms = embeddings_transforms
+        self._target_transforms = target_transforms
+
+        self._data: pd.DataFrame
+
+        self._set_multiprocessing_start_method()
+
+    def filename(self, index: int) -> str:
+        """Returns the filename of the `index`'th data sample.
+
+        Note that this is the relative file path to the root.
+
+        Args:
+            index: The index of the data-sample to select.
+
+        Returns:
+            The filename of the `index`'th data sample.
+        """
+        return self._data.at[index, self._column_mapping["path"]]
+
+    @override
+    def setup(self):
+        self._data = self._load_manifest()
+
+    @abc.abstractmethod
+    def __len__(self) -> int:
+        """Returns the total length of the data."""
+
+    def __getitem__(self, index) -> Tuple[torch.Tensor, TargetType]:
+        """Returns the `index`'th data sample.
+
+        Args:
+            index: The index of the data-sample to select.
+
+        Returns:
+            A data sample and its target.
+        """
+        embeddings = self._load_embeddings(index)
+        target = self._load_target(index)
+        return self._apply_transforms(embeddings, target)
+
+    @abc.abstractmethod
+    def _load_embeddings(self, index: int) -> torch.Tensor:
+        """Returns the `index`'th embedding sample.
+
+        Args:
+            index: The index of the data sample to load.
+
+        Returns:
+            The embedding sample as a tensor.
+        """
+
+    @abc.abstractmethod
+    def _load_target(self, index: int) -> TargetType:
+        """Returns the `index`'th target sample.
+
+        Args:
+            index: The index of the data sample to load.
+
+        Returns:
+            The sample target as an array.
+        """
+
+    def _load_manifest(self) -> pd.DataFrame:
+        """Loads manifest file and filters the data based on the split column.
+
+        Returns:
+            The data as a pandas DataFrame.
+        """
+        manifest_path = os.path.join(self._root, self._manifest_file)
+        data = io.read_dataframe(manifest_path)
+        if self._split is not None:
+            filtered_data = data.loc[data[self._column_mapping["split"]] == self._split]
+            data = filtered_data.reset_index(drop=True)
+        return data
+
+    def _apply_transforms(
+        self, embeddings: torch.Tensor, target: TargetType
+    ) -> Tuple[torch.Tensor, TargetType]:
+        """Applies the transforms to the provided data and returns them.
+
+        Args:
+            embeddings: The embeddings to be transformed.
+            target: The training target.
+
+        Returns:
+            A tuple with the embeddings and the target transformed.
+        """
+        if self._embeddings_transforms is not None:
+            embeddings = self._embeddings_transforms(embeddings)
+
+        if self._target_transforms is not None:
+            target = self._target_transforms(target)
+
+        return embeddings, target
+
+    def _set_multiprocessing_start_method(self):
+        """Sets the multiprocessing start method to spawn.
+
+        If the start method is not set explicitly, the torch data loaders will
+        use the OS default method, which on some Unix systems is `fork` and
+        can lead to runtime issues such as deadlocks in this context.
+        """
+        multiprocessing.set_start_method("spawn", force=True)
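
New task types can hook into this base by implementing the three abstract members. A rough sketch, assuming one embedding file and one float target per manifest row; the subclass name and loading details are illustrative, not part of the package:

import os

import torch
from typing_extensions import override

from eva.core.data.datasets import embeddings as embeddings_base


class RegressionEmbeddingsDataset(embeddings_base.EmbeddingsDataset[torch.Tensor]):
    """Hypothetical subclass with one embedding and one float target per manifest row."""

    @override
    def __len__(self) -> int:
        return len(self._data)

    @override
    def _load_embeddings(self, index: int) -> torch.Tensor:
        # `filename` resolves the manifest's "path" column, relative to the root.
        path = os.path.join(self._root, self.filename(index))
        return torch.load(path, map_location="cpu")

    @override
    def _load_target(self, index: int) -> torch.Tensor:
        value = self._data.at[index, self._column_mapping["target"]]
        return torch.tensor(value, dtype=torch.float32)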
eva/core/data/splitting/__init__.py
@@ -0,0 +1,6 @@
+"""Dataset splitting API."""
+
+from eva.core.data.splitting.random import random_split
+from eva.core.data.splitting.stratified import stratified_split
+
+__all__ = ["random_split", "stratified_split"]
eva/core/data/splitting/random.py
@@ -0,0 +1,41 @@
+"""Functions for random splitting."""
+
+from typing import Any, List, Sequence, Tuple
+
+import numpy as np
+
+
+def random_split(
+    samples: Sequence[Any],
+    train_ratio: float,
+    val_ratio: float,
+    test_ratio: float = 0.0,
+    seed: int = 42,
+) -> Tuple[List[int], List[int], List[int] | None]:
+    """Splits the samples into random train, validation, and test (optional) sets.
+
+    Args:
+        samples: The samples to split.
+        train_ratio: The ratio of the training set.
+        val_ratio: The ratio of the validation set.
+        test_ratio: The ratio of the test set (optional).
+        seed: The seed for reproducibility.
+
+    Returns:
+        The indices of the train, validation, and test sets as lists.
+    """
+    if train_ratio + val_ratio + (test_ratio or 0) != 1:
+        raise ValueError("The sum of the ratios must be equal to 1.")
+
+    np.random.seed(seed)
+    n_samples = len(samples)
+    indices = np.random.permutation(n_samples)
+
+    n_train = int(np.floor(train_ratio * n_samples))
+    n_val = n_samples - n_train if test_ratio == 0.0 else int(np.floor(val_ratio * n_samples)) or 1
+
+    train_indices = list(indices[:n_train])
+    val_indices = list(indices[n_train : n_train + n_val])
+    test_indices = list(indices[n_train + n_val :]) if test_ratio > 0.0 else None
+
+    return train_indices, val_indices, test_indices
+
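
A quick sketch of the resulting index lists (the sample data is made up; the counts follow from the floor-based allocation above):

from eva.core.data.splitting import random_split

samples = list(range(8))  # made-up sample list
train_idx, val_idx, test_idx = random_split(
    samples, train_ratio=0.5, val_ratio=0.25, test_ratio=0.25
)
# len(train_idx) == 4, len(val_idx) == 2, len(test_idx) == 2.
# With test_ratio=0.0 the remainder goes to validation and test_idx is None.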
eva/core/data/splitting/stratified.py
@@ -0,0 +1,56 @@
+"""Functions for stratified splitting."""
+
+from typing import Any, List, Sequence, Tuple
+
+import numpy as np
+
+
+def stratified_split(
+    samples: Sequence[Any],
+    targets: Sequence[Any],
+    train_ratio: float,
+    val_ratio: float,
+    test_ratio: float = 0.0,
+    seed: int = 42,
+) -> Tuple[List[int], List[int], List[int] | None]:
+    """Splits the samples into stratified train, validation, and test (optional) sets.
+
+    Args:
+        samples: The samples to split.
+        targets: The corresponding targets used for stratification.
+        train_ratio: The ratio of the training set.
+        val_ratio: The ratio of the validation set.
+        test_ratio: The ratio of the test set (optional).
+        seed: The seed for reproducibility.
+
+    Returns:
+        The indices of the train, validation, and test sets.
+    """
+    if len(samples) != len(targets):
+        raise ValueError("The number of samples and targets must be equal.")
+    if train_ratio + val_ratio + (test_ratio or 0) != 1:
+        raise ValueError("The sum of the ratios must be equal to 1.")
+
+    np.random.seed(seed)
+    unique_classes, y_indices = np.unique(targets, return_inverse=True)
+    n_classes = unique_classes.shape[0]
+
+    train_indices, val_indices, test_indices = [], [], []
+
+    for c in range(n_classes):
+        class_indices = np.where(y_indices == c)[0]
+        np.random.shuffle(class_indices)
+
+        n_train = int(np.floor(train_ratio * len(class_indices))) or 1
+        n_val = (
+            len(class_indices) - n_train
+            if test_ratio == 0.0
+            else int(np.floor(val_ratio * len(class_indices))) or 1
+        )
+
+        train_indices.extend(class_indices[:n_train])
+        val_indices.extend(class_indices[n_train : n_train + n_val])
+        if test_ratio > 0.0:
+            test_indices.extend(class_indices[n_train + n_val :])
+
+    return train_indices, val_indices, test_indices or None
+
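
And the stratified variant, which allocates per class; a balanced two-class toy example:

from eva.core.data.splitting import stratified_split

samples = list(range(8))            # made-up samples
targets = [0, 0, 0, 0, 1, 1, 1, 1]  # two balanced classes
train_idx, val_idx, test_idx = stratified_split(
    samples, targets, train_ratio=0.5, val_ratio=0.5
)
# Each class contributes floor(0.5 * 4) = 2 train and 2 validation indices,
# so both splits stay balanced; test_idx is None since test_ratio=0.0.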
eva/core/data/transforms/__init__.py
@@ -1,5 +1,7 @@
 """Core data transforms."""
 
 from eva.core.data.transforms.dtype import ArrayToFloatTensor, ArrayToTensor
+from eva.core.data.transforms.padding import Pad2DTensor
+from eva.core.data.transforms.sampling import SampleFromAxis
 
-__all__ = ["ArrayToFloatTensor", "ArrayToTensor"]
+__all__ = ["ArrayToFloatTensor", "ArrayToTensor", "Pad2DTensor", "SampleFromAxis"]
eva/core/data/transforms/padding/__init__.py
@@ -0,0 +1,5 @@
+"""Padding related transformations."""
+
+from eva.core.data.transforms.padding.pad_2d_tensor import Pad2DTensor
+
+__all__ = ["Pad2DTensor"]
eva/core/data/transforms/padding/pad_2d_tensor.py
@@ -0,0 +1,38 @@
+"""Padding transformation for 2D tensors."""
+
+import torch
+import torch.nn.functional
+
+
+class Pad2DTensor:
+    """Pads a 2D tensor to a fixed size along the first dimension."""
+
+    def __init__(self, pad_size: int, pad_value: int | float = float("-inf")):
+        """Initialize the transformation.
+
+        Args:
+            pad_size: The size to pad the tensor to. If the tensor is larger than this size,
+                no padding will be applied.
+            pad_value: The value to use for padding.
+        """
+        self._pad_size = pad_size
+        self._pad_value = pad_value
+
+    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
+        """Call method for the transformation.
+
+        Args:
+            tensor: The input tensor of shape [n, embedding_dim].
+
+        Returns:
+            A tensor of shape [max(n, pad_size), embedding_dim].
+        """
+        n_pad_values = self._pad_size - tensor.size(0)
+        if n_pad_values > 0:
+            tensor = torch.nn.functional.pad(
+                tensor,
+                pad=(0, 0, 0, n_pad_values),
+                mode="constant",
+                value=self._pad_value,
+            )
+        return tensor
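
A short sketch of the padding behavior; the -inf default presumably suits downstream max-style pooling over patch embeddings, since padded rows can never win the max:

import torch

from eva.core.data.transforms import Pad2DTensor

pad = Pad2DTensor(pad_size=5, pad_value=0.0)
short = pad(torch.ones(3, 8))      # -> [5, 8]; the last two rows are filled with 0.0
unchanged = pad(torch.ones(7, 8))  # -> [7, 8]; already >= pad_size, returned as-is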
eva/core/data/transforms/sampling/__init__.py
@@ -0,0 +1,5 @@
+"""Sampling related transformations."""
+
+from eva.core.data.transforms.sampling.sample_from_axis import SampleFromAxis
+
+__all__ = ["SampleFromAxis"]
eva/core/data/transforms/sampling/sample_from_axis.py
@@ -0,0 +1,40 @@
+"""Sampling transformations."""
+
+import torch
+
+
+class SampleFromAxis:
+    """Samples n_samples entries from a tensor along a given axis."""
+
+    def __init__(self, n_samples: int, seed: int = 42, axis: int = 0):
+        """Initialize the transformation.
+
+        Args:
+            n_samples: The number of samples to draw.
+            seed: The seed to use for sampling.
+            axis: The axis along which to sample.
+        """
+        self._seed = seed
+        self._n_samples = n_samples
+        self._axis = axis
+        self._generator = self._get_generator()
+
+    def _get_generator(self):
+        """Return a torch random generator with fixed seed."""
+        generator = torch.Generator()
+        generator.manual_seed(self._seed)
+        return generator
+
+    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
+        """Call method for the transformation.
+
+        Args:
+            tensor: The input tensor of shape [n, embedding_dim].
+
+        Returns:
+            A tensor of shape [n_samples, embedding_dim].
+        """
+        indices = torch.randperm(tensor.size(self._axis), generator=self._generator)[
+            : self._n_samples
+        ]
+        return tensor.index_select(self._axis, indices)
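
A usage sketch; note that the generator is seeded once at construction, so the sequence of draws is reproducible across runs, while successive calls still draw fresh permutations:

import torch

from eva.core.data.transforms import SampleFromAxis

sampler = SampleFromAxis(n_samples=100, seed=42)
subset = sampler(torch.randn(1000, 384))  # -> [100, 384], rows drawn without replacement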
eva/core/loggers/__init__.py
@@ -0,0 +1,7 @@
+"""Experimental loggers API."""
+
+from eva.core.loggers.dummy import DummyLogger
+from eva.core.loggers.experimental_loggers import ExperimentalLoggers
+from eva.core.loggers.log import log_parameters
+
+__all__ = ["DummyLogger", "ExperimentalLoggers", "log_parameters"]