kaiko-eva 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Files changed (159)
  1. eva/core/callbacks/__init__.py +2 -2
  2. eva/core/callbacks/writers/__init__.py +6 -3
  3. eva/core/callbacks/writers/embeddings/__init__.py +6 -0
  4. eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
  5. eva/core/callbacks/writers/embeddings/base.py +192 -0
  6. eva/core/callbacks/writers/embeddings/classification.py +117 -0
  7. eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
  8. eva/core/callbacks/writers/embeddings/typings.py +38 -0
  9. eva/core/data/datasets/__init__.py +2 -2
  10. eva/core/data/datasets/classification/__init__.py +8 -0
  11. eva/core/data/datasets/classification/embeddings.py +34 -0
  12. eva/core/data/datasets/{embeddings/classification → classification}/multi_embeddings.py +13 -9
  13. eva/core/data/datasets/{embeddings/base.py → embeddings.py} +47 -32
  14. eva/core/data/splitting/__init__.py +6 -0
  15. eva/core/data/splitting/random.py +41 -0
  16. eva/core/data/splitting/stratified.py +56 -0
  17. eva/core/loggers/experimental_loggers.py +2 -2
  18. eva/core/loggers/log/__init__.py +3 -2
  19. eva/core/loggers/log/image.py +71 -0
  20. eva/core/loggers/log/parameters.py +10 -0
  21. eva/core/loggers/loggers.py +6 -0
  22. eva/core/metrics/__init__.py +6 -2
  23. eva/core/metrics/defaults/__init__.py +10 -3
  24. eva/core/metrics/defaults/classification/__init__.py +1 -1
  25. eva/core/metrics/defaults/classification/binary.py +0 -9
  26. eva/core/metrics/defaults/classification/multiclass.py +0 -8
  27. eva/core/metrics/defaults/segmentation/__init__.py +5 -0
  28. eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
  29. eva/core/metrics/generalized_dice.py +59 -0
  30. eva/core/metrics/mean_iou.py +120 -0
  31. eva/core/metrics/structs/schemas.py +3 -1
  32. eva/core/models/__init__.py +3 -1
  33. eva/core/models/modules/head.py +10 -4
  34. eva/core/models/modules/typings.py +14 -1
  35. eva/core/models/modules/utils/batch_postprocess.py +37 -5
  36. eva/core/models/networks/__init__.py +1 -2
  37. eva/core/models/networks/mlp.py +2 -2
  38. eva/core/models/transforms/__init__.py +6 -0
  39. eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
  40. eva/core/models/transforms/extract_patch_features.py +47 -0
  41. eva/core/models/wrappers/__init__.py +13 -0
  42. eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
  43. eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
  44. eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
  45. eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
  46. eva/core/trainers/functional.py +1 -0
  47. eva/core/utils/__init__.py +6 -0
  48. eva/core/utils/clone.py +27 -0
  49. eva/core/utils/memory.py +28 -0
  50. eva/core/utils/operations.py +26 -0
  51. eva/core/utils/parser.py +20 -0
  52. eva/vision/__init__.py +2 -2
  53. eva/vision/callbacks/__init__.py +5 -0
  54. eva/vision/callbacks/loggers/__init__.py +5 -0
  55. eva/vision/callbacks/loggers/batch/__init__.py +5 -0
  56. eva/vision/callbacks/loggers/batch/base.py +130 -0
  57. eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
  58. eva/vision/data/datasets/__init__.py +30 -3
  59. eva/vision/data/datasets/_validators.py +15 -2
  60. eva/vision/data/datasets/classification/__init__.py +12 -1
  61. eva/vision/data/datasets/classification/bach.py +10 -15
  62. eva/vision/data/datasets/classification/base.py +17 -24
  63. eva/vision/data/datasets/classification/camelyon16.py +244 -0
  64. eva/vision/data/datasets/classification/crc.py +10 -15
  65. eva/vision/data/datasets/classification/mhist.py +10 -15
  66. eva/vision/data/datasets/classification/panda.py +184 -0
  67. eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
  68. eva/vision/data/datasets/classification/wsi.py +105 -0
  69. eva/vision/data/datasets/segmentation/__init__.py +15 -2
  70. eva/vision/data/datasets/segmentation/_utils.py +38 -0
  71. eva/vision/data/datasets/segmentation/base.py +16 -17
  72. eva/vision/data/datasets/segmentation/bcss.py +236 -0
  73. eva/vision/data/datasets/segmentation/consep.py +156 -0
  74. eva/vision/data/datasets/segmentation/embeddings.py +34 -0
  75. eva/vision/data/datasets/segmentation/lits.py +178 -0
  76. eva/vision/data/datasets/segmentation/monusac.py +236 -0
  77. eva/vision/data/datasets/segmentation/{total_segmentator.py → total_segmentator_2d.py} +130 -36
  78. eva/vision/data/datasets/wsi.py +187 -0
  79. eva/vision/data/transforms/__init__.py +3 -2
  80. eva/vision/data/transforms/common/__init__.py +2 -1
  81. eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
  82. eva/vision/data/transforms/common/resize_and_crop.py +6 -7
  83. eva/vision/data/transforms/normalization/__init__.py +6 -0
  84. eva/vision/data/transforms/normalization/clamp.py +43 -0
  85. eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
  86. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
  87. eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
  88. eva/vision/data/wsi/__init__.py +16 -0
  89. eva/vision/data/wsi/backends/__init__.py +69 -0
  90. eva/vision/data/wsi/backends/base.py +115 -0
  91. eva/vision/data/wsi/backends/openslide.py +73 -0
  92. eva/vision/data/wsi/backends/pil.py +52 -0
  93. eva/vision/data/wsi/backends/tiffslide.py +42 -0
  94. eva/vision/data/wsi/patching/__init__.py +6 -0
  95. eva/vision/data/wsi/patching/coordinates.py +98 -0
  96. eva/vision/data/wsi/patching/mask.py +123 -0
  97. eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
  98. eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
  99. eva/vision/data/wsi/patching/samplers/base.py +48 -0
  100. eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
  101. eva/vision/data/wsi/patching/samplers/grid.py +47 -0
  102. eva/vision/data/wsi/patching/samplers/random.py +41 -0
  103. eva/vision/losses/__init__.py +5 -0
  104. eva/vision/losses/dice.py +40 -0
  105. eva/vision/models/__init__.py +4 -2
  106. eva/vision/models/modules/__init__.py +5 -0
  107. eva/vision/models/modules/semantic_segmentation.py +161 -0
  108. eva/vision/models/networks/__init__.py +1 -2
  109. eva/vision/models/networks/backbones/__init__.py +6 -0
  110. eva/vision/models/networks/backbones/_utils.py +39 -0
  111. eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
  112. eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
  113. eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
  114. eva/vision/models/networks/backbones/pathology/histai.py +46 -0
  115. eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
  116. eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
  117. eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
  118. eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
  119. eva/vision/models/networks/backbones/registry.py +47 -0
  120. eva/vision/models/networks/backbones/timm/__init__.py +5 -0
  121. eva/vision/models/networks/backbones/timm/backbones.py +54 -0
  122. eva/vision/models/networks/backbones/universal/__init__.py +8 -0
  123. eva/vision/models/networks/backbones/universal/vit.py +54 -0
  124. eva/vision/models/networks/decoders/__init__.py +6 -0
  125. eva/vision/models/networks/decoders/decoder.py +7 -0
  126. eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
  127. eva/vision/models/networks/decoders/segmentation/common.py +74 -0
  128. eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
  129. eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
  130. eva/vision/models/wrappers/__init__.py +6 -0
  131. eva/vision/models/wrappers/from_registry.py +48 -0
  132. eva/vision/models/wrappers/from_timm.py +68 -0
  133. eva/vision/utils/colormap.py +77 -0
  134. eva/vision/utils/convert.py +56 -13
  135. eva/vision/utils/io/__init__.py +10 -4
  136. eva/vision/utils/io/image.py +21 -2
  137. eva/vision/utils/io/mat.py +36 -0
  138. eva/vision/utils/io/nifti.py +33 -12
  139. eva/vision/utils/io/text.py +10 -3
  140. kaiko_eva-0.1.1.dist-info/METADATA +553 -0
  141. kaiko_eva-0.1.1.dist-info/RECORD +205 -0
  142. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.1.dist-info}/WHEEL +1 -1
  143. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.1.dist-info}/entry_points.txt +2 -0
  144. eva/.DS_Store +0 -0
  145. eva/core/callbacks/writers/embeddings.py +0 -169
  146. eva/core/callbacks/writers/typings.py +0 -23
  147. eva/core/data/datasets/embeddings/__init__.py +0 -13
  148. eva/core/data/datasets/embeddings/classification/__init__.py +0 -10
  149. eva/core/data/datasets/embeddings/classification/embeddings.py +0 -66
  150. eva/core/models/networks/transforms/__init__.py +0 -5
  151. eva/core/models/networks/wrappers/__init__.py +0 -8
  152. eva/vision/models/.DS_Store +0 -0
  153. eva/vision/models/networks/.DS_Store +0 -0
  154. eva/vision/models/networks/postprocesses/__init__.py +0 -5
  155. eva/vision/models/networks/postprocesses/cls.py +0 -25
  156. kaiko_eva-0.0.2.dist-info/METADATA +0 -431
  157. kaiko_eva-0.0.2.dist-info/RECORD +0 -127
  158. /eva/core/models/{networks → wrappers}/_utils.py +0 -0
  159. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.1.dist-info}/licenses/LICENSE +0 -0

eva/core/callbacks/__init__.py
@@ -1,6 +1,6 @@
  """Callbacks API."""
 
  from eva.core.callbacks.config import ConfigurationLogger
- from eva.core.callbacks.writers import EmbeddingsWriter
+ from eva.core.callbacks.writers import ClassificationEmbeddingsWriter, SegmentationEmbeddingsWriter
 
- __all__ = ["ConfigurationLogger", "EmbeddingsWriter"]
+ __all__ = ["ConfigurationLogger", "ClassificationEmbeddingsWriter", "SegmentationEmbeddingsWriter"]

eva/core/callbacks/writers/__init__.py
@@ -1,5 +1,8 @@
- """Callbacks API."""
+ """Writers callbacks API."""
 
- from eva.core.callbacks.writers.embeddings import EmbeddingsWriter
+ from eva.core.callbacks.writers.embeddings import (
+     ClassificationEmbeddingsWriter,
+     SegmentationEmbeddingsWriter,
+ )
 
- __all__ = ["EmbeddingsWriter"]
+ __all__ = ["ClassificationEmbeddingsWriter", "SegmentationEmbeddingsWriter"]

eva/core/callbacks/writers/embeddings/__init__.py
@@ -0,0 +1,6 @@
+ """Embedding callback writers."""
+
+ from eva.core.callbacks.writers.embeddings.classification import ClassificationEmbeddingsWriter
+ from eva.core.callbacks.writers.embeddings.segmentation import SegmentationEmbeddingsWriter
+
+ __all__ = ["ClassificationEmbeddingsWriter", "SegmentationEmbeddingsWriter"]

eva/core/callbacks/writers/embeddings/_manifest.py
@@ -0,0 +1,71 @@
+ """Manifest file manager."""
+
+ import csv
+ import io
+ import os
+ from typing import Any, Dict, List
+
+ import _csv
+ import torch
+
+
+ class ManifestManager:
+     """Class for writing the embedding manifest files."""
+
+     def __init__(
+         self,
+         output_dir: str,
+         metadata_keys: List[str] | None = None,
+         overwrite: bool = False,
+     ) -> None:
+         """Initializes the writing manager.
+
+         Args:
+             output_dir: The directory where the embeddings will be saved.
+             metadata_keys: An optional list of keys to extract from the batch
+                 metadata and store as additional columns in the manifest file.
+             overwrite: Whether to overwrite the output directory.
+         """
+         self._output_dir = output_dir
+         self._metadata_keys = metadata_keys or []
+         self._overwrite = overwrite
+
+         self._manifest_file: io.TextIOWrapper
+         self._manifest_writer: _csv.Writer  # type: ignore
+
+         self._setup()
+
+     def _setup(self) -> None:
+         """Initializes the manifest file and sets the file object and writer."""
+         manifest_path = os.path.join(self._output_dir, "manifest.csv")
+         if os.path.exists(manifest_path) and not self._overwrite:
+             raise FileExistsError(
+                 f"A manifest file already exists at {manifest_path}, which indicates that the "
+                 "chosen output directory has been previously used for writing embeddings."
+             )
+         self._manifest_file = open(manifest_path, "w", newline="")
+         self._manifest_writer = csv.writer(self._manifest_file)
+         self._manifest_writer.writerow(
+             ["origin", "embeddings", "target", "split"] + self._metadata_keys
+         )
+
+     def update(
+         self,
+         input_name: str,
+         save_name: str,
+         target: str,
+         split: str | None,
+         metadata: Dict[str, Any] | None = None,
+     ) -> None:
+         """Adds a new entry to the manifest file."""
+         metadata_entries = _to_dict_values(metadata or {})
+         self._manifest_writer.writerow([input_name, save_name, target, split] + metadata_entries)
+
+     def close(self) -> None:
+         """Closes the manifest file."""
+         if self._manifest_file:
+             self._manifest_file.close()
+
+
+ def _to_dict_values(data: Dict[str, Any]) -> List[Any]:
+     return [value.item() if isinstance(value, torch.Tensor) else value for value in data.values()]
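
For orientation, here is a minimal sketch of driving the new manifest manager directly; the directory, file names, target and metadata values are made up for illustration and are not part of the release:

```python
import os
import tempfile

from eva.core.callbacks.writers.embeddings._manifest import ManifestManager

output_dir = tempfile.mkdtemp()  # hypothetical output directory

# One extra manifest column per metadata key, appended after origin/embeddings/target/split.
manager = ManifestManager(output_dir, metadata_keys=["wsi_id"])
manager.update(
    input_name="patches/slide_001/patch_042.png",  # assumed origin path
    save_name="slide_001.pt",
    target="1",
    split="train",
    metadata={"wsi_id": "slide_001"},
)
manager.close()

with open(os.path.join(output_dir, "manifest.csv")) as manifest:
    print(manifest.read())
# origin,embeddings,target,split,wsi_id
# patches/slide_001/patch_042.png,slide_001.pt,1,train,slide_001
```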

eva/core/callbacks/writers/embeddings/base.py
@@ -0,0 +1,192 @@
+ """Embeddings writer base class."""
+
+ import abc
+ import io
+ import os
+ from typing import Any, Dict, List, Sequence
+
+ import lightning.pytorch as pl
+ import torch
+ from lightning.pytorch import callbacks
+ from loguru import logger
+ from torch import multiprocessing, nn
+ from typing_extensions import override
+
+ from eva.core import utils
+ from eva.core.callbacks.writers.embeddings.typings import QUEUE_ITEM
+ from eva.core.models.modules.typings import INPUT_BATCH
+ from eva.core.utils import multiprocessing as eva_multiprocessing
+
+
+ class EmbeddingsWriter(callbacks.BasePredictionWriter, abc.ABC):
+     """Callback for writing generated embeddings to disk."""
+
+     def __init__(
+         self,
+         output_dir: str,
+         backbone: nn.Module | None = None,
+         dataloader_idx_map: Dict[int, str] | None = None,
+         metadata_keys: List[str] | None = None,
+         overwrite: bool = False,
+         save_every_n: int = 100,
+     ) -> None:
+         """Initializes a new EmbeddingsWriter instance.
+
+         This callback writes the embedding files in a separate process to avoid blocking the
+         main process where the model forward pass is executed.
+
+         Args:
+             output_dir: The directory where the embeddings will be saved.
+             backbone: A model to be used as feature extractor. If `None`,
+                 it will be expected that the input batch returns the features directly.
+             dataloader_idx_map: A dictionary mapping dataloader indices to their respective
+                 names (e.g. train, val, test).
+             metadata_keys: An optional list of keys to extract from the batch metadata and store
+                 as additional columns in the manifest file.
+             overwrite: Whether to overwrite if embeddings are already present in the specified
+                 output directory. If set to `False`, an error will be raised if embeddings are
+                 already present (recommended).
+             save_every_n: Interval for number of iterations to save the embeddings to disk.
+                 During this interval, the embeddings are accumulated in memory.
+         """
+         super().__init__(write_interval="batch")
+
+         self._output_dir = output_dir
+         self._backbone = backbone
+         self._dataloader_idx_map = dataloader_idx_map or {}
+         self._overwrite = overwrite
+         self._save_every_n = save_every_n
+         self._metadata_keys = metadata_keys or []
+
+         self._write_queue: multiprocessing.Queue
+         self._write_process: eva_multiprocessing.Process
+
+     @staticmethod
+     @abc.abstractmethod
+     def _process_write_queue(
+         write_queue: multiprocessing.Queue,
+         output_dir: str,
+         metadata_keys: List[str],
+         save_every_n: int,
+         overwrite: bool = False,
+     ) -> None:
+         """This function receives and processes items added by the main process to the queue.
+
+         Queue items contain the embedding tensors, targets and metadata which need to be
+         saved to disk (.pt files and manifest).
+         """
+
+     @override
+     def on_predict_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+         self._check_if_exists()
+         self._initialize_write_process()
+         self._write_process.start()
+
+         if self._backbone is not None:
+             self._backbone = self._backbone.to(pl_module.device)
+             self._backbone.eval()
+
+     @override
+     def write_on_batch_end(
+         self,
+         trainer: pl.Trainer,
+         pl_module: pl.LightningModule,
+         prediction: Any,
+         batch_indices: Sequence[int],
+         batch: INPUT_BATCH,
+         batch_idx: int,
+         dataloader_idx: int,
+     ) -> None:
+         dataset = trainer.predict_dataloaders[dataloader_idx].dataset  # type: ignore
+         _, targets, metadata = INPUT_BATCH(*batch)
+         split = self._dataloader_idx_map.get(dataloader_idx)
+         if not isinstance(targets, torch.Tensor):
+             raise ValueError(f"Targets ({type(targets)}) should be `torch.Tensor`.")
+
+         with torch.no_grad():
+             embeddings = self._get_embeddings(prediction)
+
+         for local_idx, global_idx in enumerate(batch_indices[: len(embeddings)]):
+             data_name = dataset.filename(global_idx)
+             save_name = os.path.splitext(data_name)[0] + ".pt"
+             embeddings_buffer, target_buffer = _as_io_buffers(
+                 embeddings[local_idx], targets[local_idx]
+             )
+             item_metadata = self._get_item_metadata(metadata, local_idx)
+             item = QUEUE_ITEM(
+                 prediction_buffer=embeddings_buffer,
+                 target_buffer=target_buffer,
+                 data_name=data_name,
+                 save_name=save_name,
+                 split=split,
+                 metadata=item_metadata,
+             )
+             self._write_queue.put(item)
+
+         self._write_process.check_exceptions()
+
+     @override
+     def on_predict_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+         self._write_queue.put(None)
+         self._write_process.join()
+         logger.info(f"Predictions and manifest saved to {self._output_dir}")
+
+     def _initialize_write_process(self) -> None:
+         self._write_queue = multiprocessing.Queue()
+         self._write_process = eva_multiprocessing.Process(
+             target=self._process_write_queue,
+             args=(
+                 self._write_queue,
+                 self._output_dir,
+                 self._metadata_keys,
+                 self._save_every_n,
+                 self._overwrite,
+             ),
+         )
+
+     @abc.abstractmethod
+     def _get_embeddings(self, tensor: torch.Tensor) -> torch.Tensor | List[List[torch.Tensor]]:
+         """Returns the embeddings from predictions."""
+
+     def _get_item_metadata(
+         self, metadata: Dict[str, Any] | None, local_idx: int
+     ) -> Dict[str, Any] | None:
+         """Returns the metadata for the item at the given local index."""
+         if not metadata:
+             if self._metadata_keys:
+                 raise ValueError("Metadata keys are provided but the batch metadata is empty.")
+             else:
+                 return None
+
+         item_metadata = {}
+         for key in self._metadata_keys:
+             if key not in metadata:
+                 raise KeyError(f"Metadata key '{key}' not found in the batch metadata.")
+             metadata_value = metadata[key][local_idx]
+             try:
+                 item_metadata[key] = utils.to_cpu(metadata_value)
+             except TypeError:
+                 item_metadata[key] = metadata_value
+
+         return item_metadata
+
+     def _check_if_exists(self) -> None:
+         """Checks if the output directory already exists and if it should be overwritten."""
+         try:
+             os.makedirs(self._output_dir, exist_ok=self._overwrite)
+         except FileExistsError as e:
+             raise FileExistsError(
+                 f"The embeddings output directory already exists: {self._output_dir}. This "
+                 "either means that they have been computed before or that a wrong output "
+                 "directory is being used. Consider using `eva fit` instead, selecting a "
+                 "different output directory or setting overwrite=True."
+             ) from e
+         os.makedirs(self._output_dir, exist_ok=True)
+
+
+ def _as_io_buffers(*items: torch.Tensor | List[torch.Tensor]) -> Sequence[io.BytesIO]:
+     """Loads torch tensors as io buffers."""
+     buffers = [io.BytesIO() for _ in range(len(items))]
+     for tensor, buffer in zip(items, buffers, strict=False):
+         torch.save(utils.clone(tensor), buffer)
+     return buffers

eva/core/callbacks/writers/embeddings/classification.py
@@ -0,0 +1,117 @@
+ """Embeddings writer for classification."""
+
+ import io
+ import os
+ from typing import Dict, List
+
+ import torch
+ from torch import multiprocessing
+ from typing_extensions import override
+
+ from eva.core.callbacks.writers.embeddings import base
+ from eva.core.callbacks.writers.embeddings._manifest import ManifestManager
+ from eva.core.callbacks.writers.embeddings.typings import ITEM_DICT_ENTRY, QUEUE_ITEM
+
+
+ class ClassificationEmbeddingsWriter(base.EmbeddingsWriter):
+     """Callback for writing generated embeddings to disk for classification tasks."""
+
+     @staticmethod
+     @override
+     def _process_write_queue(
+         write_queue: multiprocessing.Queue,
+         output_dir: str,
+         metadata_keys: List[str],
+         save_every_n: int,
+         overwrite: bool = False,
+     ) -> None:
+         """Processes the write queue and saves the predictions to disk.
+
+         Note that in Multi Instance Learning (MIL) scenarios, we can have multiple
+         embeddings per input data point. In that case, this function will save all
+         embeddings that correspond to the same data point as a list of tensors to
+         the same .pt file.
+         """
+         manifest_manager = ManifestManager(output_dir, metadata_keys, overwrite)
+         name_to_items: Dict[str, ITEM_DICT_ENTRY] = {}
+
+         counter = 0
+         while True:
+             item = write_queue.get()
+             if item is None:
+                 break
+             item = QUEUE_ITEM(*item)
+
+             if item.save_name in name_to_items:
+                 name_to_items[item.save_name].items.append(item)
+             else:
+                 name_to_items[item.save_name] = ITEM_DICT_ENTRY(items=[item], save_count=0)
+
+             if counter > 0 and counter % save_every_n == 0:
+                 name_to_items = _save_items(name_to_items, output_dir, manifest_manager)
+             counter += 1
+
+         if len(name_to_items) > 0:
+             _save_items(name_to_items, output_dir, manifest_manager)
+
+         manifest_manager.close()
+
+     @override
+     def _get_embeddings(self, tensor: torch.Tensor) -> torch.Tensor:
+         """Returns the embeddings from predictions."""
+         return self._backbone(tensor) if self._backbone else tensor
+
+
+ def _save_items(
+     name_to_items: Dict[str, ITEM_DICT_ENTRY],
+     output_dir: str,
+     manifest_manager: ManifestManager,
+ ) -> Dict[str, ITEM_DICT_ENTRY]:
+     """Saves predictions to disk and updates the manifest file.
+
+     Args:
+         name_to_items: A dictionary mapping save data point names to the corresponding queue items
+             holding the prediction tensors and the information for the manifest file.
+         output_dir: The directory where the embedding tensors & manifest will be saved.
+         manifest_manager: The manifest manager instance to update the manifest file.
+     """
+     for save_name, entry in name_to_items.items():
+         if len(entry.items) > 0:
+             save_path = os.path.join(output_dir, save_name)
+             is_first_save = entry.save_count == 0
+             if is_first_save:
+                 _, target, input_name, _, split, metadata = QUEUE_ITEM(*entry.items[0])
+                 target = torch.load(io.BytesIO(target.getbuffer()), map_location="cpu").item()
+                 manifest_manager.update(input_name, save_name, target, split, metadata)
+
+             prediction_buffers = [item.prediction_buffer for item in entry.items]
+             _save_predictions(prediction_buffers, save_path, is_first_save)
+             name_to_items[save_name].save_count += 1
+             name_to_items[save_name].items = []
+
+     return name_to_items
+
+
+ def _save_predictions(
+     prediction_buffers: List[io.BytesIO], save_path: str, is_first_save: bool
+ ) -> None:
+     """Saves the embedding tensors as list to .pt files.
+
+     If it's not the first save to this save_path, the new predictions are appended to
+     the existing ones and saved to the same file.
+
+     Example use-case: Save all patch embeddings corresponding to the same WSI to a single file.
+     """
+     predictions = [
+         torch.load(io.BytesIO(buffer.getbuffer()), map_location="cpu")
+         for buffer in prediction_buffers
+     ]
+
+     if not is_first_save:
+         previous_predictions = torch.load(save_path, map_location="cpu")
+         if not isinstance(previous_predictions, list):
+             raise ValueError("Previous predictions should be a list of tensors.")
+         predictions = predictions + previous_predictions
+
+     os.makedirs(os.path.dirname(save_path), exist_ok=True)
+     torch.save(predictions, save_path)
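
The writer is a Lightning prediction callback, so a typical wiring looks roughly as follows; this is a hypothetical sketch, and the output path, backbone and dataloaders are placeholders rather than values shipped with the package:

```python
import lightning.pytorch as pl
from torch import nn

from eva.core.callbacks.writers import ClassificationEmbeddingsWriter

writer = ClassificationEmbeddingsWriter(
    output_dir="./data/embeddings/patch_camelyon",  # assumed output location
    backbone=nn.Flatten(),  # stand-in feature extractor; normally a frozen foundation-model backbone
    dataloader_idx_map={0: "train", 1: "val", 2: "test"},
    overwrite=False,
)

trainer = pl.Trainer(callbacks=[writer])
# trainer.predict(model, dataloaders=[train_loader, val_loader, test_loader],
#                 return_predictions=False)
```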

eva/core/callbacks/writers/embeddings/segmentation.py
@@ -0,0 +1,78 @@
+ """Segmentation embeddings writer."""
+
+ import collections
+ import io
+ import os
+ from typing import List
+
+ import torch
+ from torch import multiprocessing
+ from typing_extensions import override
+
+ from eva.core.callbacks.writers.embeddings import base
+ from eva.core.callbacks.writers.embeddings._manifest import ManifestManager
+ from eva.core.callbacks.writers.embeddings.typings import QUEUE_ITEM
+
+
+ class SegmentationEmbeddingsWriter(base.EmbeddingsWriter):
+     """Callback for writing generated embeddings to disk."""
+
+     @staticmethod
+     @override
+     def _process_write_queue(
+         write_queue: multiprocessing.Queue,
+         output_dir: str,
+         metadata_keys: List[str],
+         save_every_n: int,
+         overwrite: bool = False,
+     ) -> None:
+         manifest_manager = ManifestManager(output_dir, metadata_keys, overwrite)
+         counter = collections.defaultdict(lambda: -1)
+         while True:
+             item = write_queue.get()
+             if item is None:
+                 break
+
+             embeddings_buffer, target_buffer, input_name, save_name, split, metadata = QUEUE_ITEM(
+                 *item
+             )
+             counter[save_name] += 1
+             save_name = save_name.replace(".pt", f"-{counter[save_name]}.pt")
+             target_filename = save_name.replace(".pt", "-mask.pt")
+
+             _save_embedding(embeddings_buffer, save_name, output_dir)
+             _save_embedding(target_buffer, target_filename, output_dir)
+             manifest_manager.update(input_name, save_name, target_filename, split, metadata)
+
+         manifest_manager.close()
+
+     @override
+     def _get_embeddings(self, tensor: torch.Tensor) -> torch.Tensor | List[List[torch.Tensor]]:
+         """Returns the embeddings from predictions."""
+
+         def _get_grouped_embeddings(embeddings: List[torch.Tensor]) -> List[List[torch.Tensor]]:
+             """Casts a list of multi-leveled batched embeddings to grouped per batch.
+
+             That is, for embeddings to be a list of shape (batch_size, hidden_dim, height, width),
+             such as `[(2, 192, 16, 16), (2, 192, 16, 16)]`, to be reshaped as a list of lists of
+             per batch multi-level embeddings, thus
+             `[ [(192, 16, 16), (192, 16, 16)], [(192, 16, 16), (192, 16, 16)] ]`.
+             """
+             batch_size = embeddings[0].shape[0]
+             grouped_embeddings = []
+             for batch_idx in range(batch_size):
+                 batch_list = [layer_embeddings[batch_idx] for layer_embeddings in embeddings]
+                 grouped_embeddings.append(batch_list)
+             return grouped_embeddings
+
+         embeddings = self._backbone(tensor) if self._backbone else tensor
+         if isinstance(embeddings, list):
+             embeddings = _get_grouped_embeddings(embeddings)
+         return embeddings
+
+
+ def _save_embedding(embeddings_buffer: io.BytesIO, save_name: str, output_dir: str) -> None:
+     save_path = os.path.join(output_dir, save_name)
+     prediction = torch.load(io.BytesIO(embeddings_buffer.getbuffer()), map_location="cpu")
+     os.makedirs(os.path.dirname(save_path), exist_ok=True)
+     torch.save(prediction, save_path)
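
The regrouping performed by `_get_grouped_embeddings` can be illustrated with plain tensors; this is only a shape-level sketch of the behaviour described in the docstring, not code from the package:

```python
import torch

# Two feature levels from the backbone, each batched with batch_size=2.
multi_level = [torch.rand(2, 192, 16, 16), torch.rand(2, 192, 16, 16)]

# Regrouped per sample: one list of per-level tensors for every item in the batch.
grouped = [[level[idx] for level in multi_level] for idx in range(multi_level[0].shape[0])]

assert len(grouped) == 2           # batch size
assert len(grouped[0]) == 2        # number of feature levels
assert grouped[0][0].shape == (192, 16, 16)
```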

eva/core/callbacks/writers/embeddings/typings.py
@@ -0,0 +1,38 @@
+ """Typing definitions for the writer callback functions."""
+
+ import dataclasses
+ import io
+ from typing import Any, Dict, List, NamedTuple
+
+
+ class QUEUE_ITEM(NamedTuple):
+     """The default input batch data scheme."""
+
+     prediction_buffer: io.BytesIO
+     """IO buffer containing the prediction tensor."""
+
+     target_buffer: io.BytesIO
+     """IO buffer containing the target tensor."""
+
+     data_name: str
+     """Name of the input data that was used to generate the embedding."""
+
+     save_name: str
+     """Name to store the generated embedding."""
+
+     split: str | None
+     """The dataset split the item belongs to (e.g. train, val, test)."""
+
+     metadata: Dict[str, Any] | None = None
+     """Dictionary holding additional metadata."""
+
+
+ @dataclasses.dataclass
+ class ITEM_DICT_ENTRY:
+     """Typing for holding queue items and number of save operations."""
+
+     items: List[QUEUE_ITEM]
+     """List of queue items."""
+
+     save_count: int
+     """Number of prior item batch saves to same file."""

eva/core/data/datasets/__init__.py
@@ -1,11 +1,11 @@
  """Datasets API."""
 
  from eva.core.data.datasets.base import Dataset
- from eva.core.data.datasets.dataset import TorchDataset
- from eva.core.data.datasets.embeddings import (
+ from eva.core.data.datasets.classification import (
      EmbeddingsClassificationDataset,
      MultiEmbeddingsClassificationDataset,
  )
+ from eva.core.data.datasets.dataset import TorchDataset
 
  __all__ = [
      "Dataset",

eva/core/data/datasets/classification/__init__.py
@@ -0,0 +1,8 @@
+ """Embedding cllassification datasets API."""
+
+ from eva.core.data.datasets.classification.embeddings import EmbeddingsClassificationDataset
+ from eva.core.data.datasets.classification.multi_embeddings import (
+     MultiEmbeddingsClassificationDataset,
+ )
+
+ __all__ = ["EmbeddingsClassificationDataset", "MultiEmbeddingsClassificationDataset"]

eva/core/data/datasets/classification/embeddings.py
@@ -0,0 +1,34 @@
+ """Embeddings classification dataset."""
+
+ import os
+
+ import torch
+ from typing_extensions import override
+
+ from eva.core.data.datasets import embeddings as embeddings_base
+
+
+ class EmbeddingsClassificationDataset(embeddings_base.EmbeddingsDataset[torch.Tensor]):
+     """Embeddings dataset class for classification tasks."""
+
+     @override
+     def _load_embeddings(self, index: int) -> torch.Tensor:
+         filename = self.filename(index)
+         embeddings_path = os.path.join(self._root, filename)
+         tensor = torch.load(embeddings_path, map_location="cpu")
+         if isinstance(tensor, list):
+             if len(tensor) > 1:
+                 raise ValueError(
+                     f"Expected a single tensor in the .pt file, but found {len(tensor)}."
+                 )
+             tensor = tensor[0]
+         return tensor.squeeze(0)
+
+     @override
+     def _load_target(self, index: int) -> torch.Tensor:
+         target = self._data.at[index, self._column_mapping["target"]]
+         return torch.tensor(target, dtype=torch.int64)
+
+     @override
+     def __len__(self) -> int:
+         return len(self._data)
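
As a rough sketch of the on-disk format this dataset expects, a `.pt` file containing either a single tensor or a single-element list (as produced by the classification writer); the file name and embedding dimension here are assumptions:

```python
import torch

torch.save([torch.rand(1, 384)], "example_embedding.pt")  # assumed 384-dim embedding

loaded = torch.load("example_embedding.pt", map_location="cpu")
embedding = loaded[0] if isinstance(loaded, list) else loaded
print(embedding.squeeze(0).shape)  # torch.Size([384])
```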

eva/core/data/datasets/classification/multi_embeddings.py
@@ -7,10 +7,10 @@ import numpy as np
  import torch
  from typing_extensions import override
 
- from eva.core.data.datasets.embeddings import base
+ from eva.core.data.datasets import embeddings as embeddings_base
 
 
- class MultiEmbeddingsClassificationDataset(base.EmbeddingsDataset):
+ class MultiEmbeddingsClassificationDataset(embeddings_base.EmbeddingsDataset[torch.Tensor]):
      """Dataset class for where a sample corresponds to multiple embeddings.
 
      Example use case: Slide level dataset where each slide has multiple patch embeddings.
@@ -21,7 +21,7 @@ class MultiEmbeddingsClassificationDataset(base.EmbeddingsDataset):
          root: str,
          manifest_file: str,
          split: Literal["train", "val", "test"],
-         column_mapping: Dict[str, str] = base.default_column_mapping,
+         column_mapping: Dict[str, str] = embeddings_base.default_column_mapping,
          embeddings_transforms: Callable | None = None,
          target_transforms: Callable | None = None,
      ):
@@ -32,9 +32,9 @@ class MultiEmbeddingsClassificationDataset(base.EmbeddingsDataset):
          The manifest must have a `column_mapping["multi_id"]` column that contains the
          unique identifier group of embeddings. For oncology datasets, this would be usually
          the slide id. Each row in the manifest file points to a .pt file that can contain
-         one or multiple embeddings. There can also be multiple rows for the same `multi_id`,
-         in which case the embeddings from the different .pt files corresponding to that same
-         `multi_id` will be stacked along the first dimension.
+         one or multiple embeddings (either as a list or stacked tensors). There can also be
+         multiple rows for the same `multi_id`, in which case the embeddings from the different
+         .pt files corresponding to that same `multi_id` will be stacked along the first dimension.
 
          Args:
              root: Root directory of the dataset.
@@ -73,10 +73,14 @@ class MultiEmbeddingsClassificationDataset(base.EmbeddingsDataset):
          embedding_paths = self._data.loc[
              self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["path"]
          ].to_list()
-         embedding_paths = [os.path.join(self._root, path) for path in embedding_paths]
 
          # Load embeddings and stack them accross the first dimension
-         embeddings = [torch.load(path, map_location="cpu") for path in embedding_paths]
+         embeddings = []
+         for path in embedding_paths:
+             embedding = torch.load(os.path.join(self._root, path), map_location="cpu")
+             if isinstance(embedding, list):
+                 embedding = torch.stack(embedding, dim=0)
+             embeddings.append(embedding.unsqueeze(0) if embedding.ndim == 1 else embedding)
          embeddings = torch.cat(embeddings, dim=0)
 
          if not embeddings.ndim == 2:
@@ -103,4 +107,4 @@ class MultiEmbeddingsClassificationDataset(base.EmbeddingsDataset):
 
      @override
      def __len__(self) -> int:
-         return len(self._data)
+         return len(self._multi_ids)
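
Finally, a hypothetical instantiation of the reworked multi-embeddings dataset, matching the constructor shown above; the paths are placeholders, and the manifest is expected to provide at least the `multi_id` and `path` columns referenced in the code:

```python
from eva.core.data.datasets import MultiEmbeddingsClassificationDataset

dataset = MultiEmbeddingsClassificationDataset(
    root="./data/embeddings/panda",  # assumed root containing the .pt files and manifest
    manifest_file="manifest.csv",
    split="train",
)
sample = dataset[0]  # embeddings for one multi_id, stacked along the first dimension, plus the target
```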