kaiko-eva 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaiko-eva might be problematic.
- eva/core/callbacks/__init__.py +3 -2
- eva/core/callbacks/config.py +143 -0
- eva/core/callbacks/writers/__init__.py +6 -3
- eva/core/callbacks/writers/embeddings/__init__.py +6 -0
- eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
- eva/core/callbacks/writers/embeddings/base.py +192 -0
- eva/core/callbacks/writers/embeddings/classification.py +117 -0
- eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
- eva/core/callbacks/writers/embeddings/typings.py +38 -0
- eva/core/data/datasets/__init__.py +10 -2
- eva/core/data/datasets/classification/__init__.py +5 -2
- eva/core/data/datasets/classification/embeddings.py +15 -135
- eva/core/data/datasets/classification/multi_embeddings.py +110 -0
- eva/core/data/datasets/embeddings.py +167 -0
- eva/core/data/splitting/__init__.py +6 -0
- eva/core/data/splitting/random.py +41 -0
- eva/core/data/splitting/stratified.py +56 -0
- eva/core/data/transforms/__init__.py +3 -1
- eva/core/data/transforms/padding/__init__.py +5 -0
- eva/core/data/transforms/padding/pad_2d_tensor.py +38 -0
- eva/core/data/transforms/sampling/__init__.py +5 -0
- eva/core/data/transforms/sampling/sample_from_axis.py +40 -0
- eva/core/loggers/__init__.py +7 -0
- eva/core/loggers/dummy.py +38 -0
- eva/core/loggers/experimental_loggers.py +8 -0
- eva/core/loggers/log/__init__.py +6 -0
- eva/core/loggers/log/image.py +71 -0
- eva/core/loggers/log/parameters.py +74 -0
- eva/core/loggers/log/utils.py +13 -0
- eva/core/loggers/loggers.py +6 -0
- eva/core/metrics/__init__.py +6 -2
- eva/core/metrics/defaults/__init__.py +10 -3
- eva/core/metrics/defaults/classification/__init__.py +1 -1
- eva/core/metrics/defaults/classification/binary.py +0 -9
- eva/core/metrics/defaults/classification/multiclass.py +0 -8
- eva/core/metrics/defaults/segmentation/__init__.py +5 -0
- eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
- eva/core/metrics/generalized_dice.py +59 -0
- eva/core/metrics/mean_iou.py +120 -0
- eva/core/metrics/structs/schemas.py +3 -1
- eva/core/models/__init__.py +3 -1
- eva/core/models/modules/head.py +16 -15
- eva/core/models/modules/module.py +25 -1
- eva/core/models/modules/typings.py +14 -1
- eva/core/models/modules/utils/batch_postprocess.py +37 -5
- eva/core/models/networks/__init__.py +1 -2
- eva/core/models/networks/mlp.py +2 -2
- eva/core/models/transforms/__init__.py +6 -0
- eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
- eva/core/models/transforms/extract_patch_features.py +47 -0
- eva/core/models/wrappers/__init__.py +13 -0
- eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
- eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
- eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
- eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
- eva/core/trainers/_recorder.py +69 -7
- eva/core/trainers/functional.py +23 -5
- eva/core/trainers/trainer.py +20 -6
- eva/core/utils/__init__.py +6 -0
- eva/core/utils/clone.py +27 -0
- eva/core/utils/memory.py +28 -0
- eva/core/utils/operations.py +26 -0
- eva/core/utils/parser.py +20 -0
- eva/vision/__init__.py +2 -2
- eva/vision/callbacks/__init__.py +5 -0
- eva/vision/callbacks/loggers/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/base.py +130 -0
- eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
- eva/vision/data/datasets/__init__.py +24 -4
- eva/vision/data/datasets/_utils.py +3 -3
- eva/vision/data/datasets/_validators.py +15 -2
- eva/vision/data/datasets/classification/__init__.py +6 -2
- eva/vision/data/datasets/classification/bach.py +10 -15
- eva/vision/data/datasets/classification/base.py +17 -24
- eva/vision/data/datasets/classification/camelyon16.py +244 -0
- eva/vision/data/datasets/classification/crc.py +10 -15
- eva/vision/data/datasets/classification/mhist.py +10 -15
- eva/vision/data/datasets/classification/panda.py +184 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
- eva/vision/data/datasets/classification/wsi.py +105 -0
- eva/vision/data/datasets/segmentation/__init__.py +15 -2
- eva/vision/data/datasets/segmentation/_utils.py +38 -0
- eva/vision/data/datasets/segmentation/base.py +31 -47
- eva/vision/data/datasets/segmentation/bcss.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +156 -0
- eva/vision/data/datasets/segmentation/embeddings.py +34 -0
- eva/vision/data/datasets/segmentation/lits.py +178 -0
- eva/vision/data/datasets/segmentation/monusac.py +236 -0
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +325 -0
- eva/vision/data/datasets/wsi.py +187 -0
- eva/vision/data/transforms/__init__.py +3 -2
- eva/vision/data/transforms/common/__init__.py +2 -1
- eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
- eva/vision/data/transforms/common/resize_and_crop.py +6 -7
- eva/vision/data/transforms/normalization/__init__.py +6 -0
- eva/vision/data/transforms/normalization/clamp.py +43 -0
- eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
- eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
- eva/vision/data/wsi/__init__.py +16 -0
- eva/vision/data/wsi/backends/__init__.py +69 -0
- eva/vision/data/wsi/backends/base.py +115 -0
- eva/vision/data/wsi/backends/openslide.py +73 -0
- eva/vision/data/wsi/backends/pil.py +52 -0
- eva/vision/data/wsi/backends/tiffslide.py +42 -0
- eva/vision/data/wsi/patching/__init__.py +6 -0
- eva/vision/data/wsi/patching/coordinates.py +98 -0
- eva/vision/data/wsi/patching/mask.py +123 -0
- eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
- eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
- eva/vision/data/wsi/patching/samplers/base.py +48 -0
- eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
- eva/vision/data/wsi/patching/samplers/grid.py +47 -0
- eva/vision/data/wsi/patching/samplers/random.py +41 -0
- eva/vision/losses/__init__.py +5 -0
- eva/vision/losses/dice.py +40 -0
- eva/vision/models/__init__.py +4 -2
- eva/vision/models/modules/__init__.py +5 -0
- eva/vision/models/modules/semantic_segmentation.py +161 -0
- eva/vision/models/networks/__init__.py +1 -2
- eva/vision/models/networks/backbones/__init__.py +6 -0
- eva/vision/models/networks/backbones/_utils.py +39 -0
- eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
- eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
- eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
- eva/vision/models/networks/backbones/pathology/histai.py +46 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
- eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
- eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
- eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
- eva/vision/models/networks/backbones/registry.py +47 -0
- eva/vision/models/networks/backbones/timm/__init__.py +5 -0
- eva/vision/models/networks/backbones/timm/backbones.py +54 -0
- eva/vision/models/networks/backbones/universal/__init__.py +8 -0
- eva/vision/models/networks/backbones/universal/vit.py +54 -0
- eva/vision/models/networks/decoders/__init__.py +6 -0
- eva/vision/models/networks/decoders/decoder.py +7 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
- eva/vision/models/networks/decoders/segmentation/common.py +74 -0
- eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
- eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
- eva/vision/models/wrappers/__init__.py +6 -0
- eva/vision/models/wrappers/from_registry.py +48 -0
- eva/vision/models/wrappers/from_timm.py +68 -0
- eva/vision/utils/colormap.py +77 -0
- eva/vision/utils/convert.py +67 -0
- eva/vision/utils/io/__init__.py +10 -4
- eva/vision/utils/io/image.py +21 -2
- eva/vision/utils/io/mat.py +36 -0
- eva/vision/utils/io/nifti.py +40 -15
- eva/vision/utils/io/text.py +10 -3
- kaiko_eva-0.1.0.dist-info/METADATA +553 -0
- kaiko_eva-0.1.0.dist-info/RECORD +205 -0
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
- eva/core/callbacks/writers/embeddings.py +0 -169
- eva/core/callbacks/writers/typings.py +0 -23
- eva/core/models/networks/transforms/__init__.py +0 -5
- eva/core/models/networks/wrappers/__init__.py +0 -8
- eva/vision/data/datasets/classification/total_segmentator.py +0 -213
- eva/vision/data/datasets/segmentation/total_segmentator.py +0 -212
- eva/vision/models/networks/postprocesses/__init__.py +0 -5
- eva/vision/models/networks/postprocesses/cls.py +0 -25
- kaiko_eva-0.0.1.dist-info/METADATA +0 -405
- kaiko_eva-0.0.1.dist-info/RECORD +0 -110
- /eva/core/models/{networks → wrappers}/_utils.py +0 -0
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
eva/core/callbacks/__init__.py
CHANGED
@@ -1,5 +1,6 @@
 """Callbacks API."""
 
-from eva.core.callbacks.
+from eva.core.callbacks.config import ConfigurationLogger
+from eva.core.callbacks.writers import ClassificationEmbeddingsWriter, SegmentationEmbeddingsWriter
 
-__all__ = ["
+__all__ = ["ConfigurationLogger", "ClassificationEmbeddingsWriter", "SegmentationEmbeddingsWriter"]
eva/core/callbacks/config.py
ADDED
@@ -0,0 +1,143 @@
+"""Configuration logger callback."""
+
+import ast
+import os
+import sys
+from types import BuiltinFunctionType
+from typing import Any, Dict, List
+
+import lightning.pytorch as pl
+import yaml
+from lightning_fabric.utilities import cloud_io
+from loguru import logger as cli_logger
+from omegaconf import OmegaConf
+from typing_extensions import TypeGuard, override
+
+from eva.core import loggers
+
+
+class ConfigurationLogger(pl.Callback):
+    """Logs the submitted configuration to the experimental logger."""
+
+    _save_as: str = "config.yaml"
+
+    def __init__(self, verbose: bool = True) -> None:
+        """Initializes the callback.
+
+        Args:
+            verbose: Whether to print the configurations to print the
+                configuration to the terminal.
+        """
+        super().__init__()
+
+        self._verbose = verbose
+
+    @override
+    def setup(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        stage: str | None = None,
+    ) -> None:
+        log_dir = trainer.log_dir
+        if not _logdir_exists(log_dir):
+            return
+
+        configuration = _load_submitted_config()
+
+        if self._verbose:
+            config_as_text = yaml.dump(configuration, sort_keys=False)
+            print(f"Configuration:\033[94m\n---\n{config_as_text}\033[0m")
+
+        save_as = os.path.join(log_dir, self._save_as)
+        fs = cloud_io.get_filesystem(log_dir)
+        with fs.open(save_as, "w") as output_file:
+            yaml.dump(configuration, output_file, sort_keys=False)
+
+        loggers.log_parameters(trainer.loggers, tag="configuration", parameters=configuration)
+
+
+def _logdir_exists(logdir: str | None, verbose: bool = True) -> TypeGuard[str]:
+    """Checks if the trainer has a log directory.
+
+    Args:
+        logdir: Trainer's logdir.
+        name: The name to log with.
+        verbose: Whether to log if it does not exist.
+
+    Returns:
+        A bool indicating if the log directory exists or not.
+    """
+    exists = isinstance(logdir, str)
+    if not exists and verbose:
+        print("\n")
+        cli_logger.warning("Log directory is `None`. Configuration file will not be logged.\n")
+    return exists
+
+
+def _load_submitted_config() -> Dict[str, Any]:
+    """Retrieves and loads the submitted configuration.
+
+    Returns:
+        The path to the configuration file.
+    """
+    config_paths = _fetch_submitted_config_path()
+    return _load_yaml_files(config_paths)
+
+
+def _fetch_submitted_config_path() -> List[str]:
+    """Fetches the config path from command line arguments.
+
+    Returns:
+        The path to the configuration file.
+    """
+    return list(filter(lambda f: f.endswith(".yaml"), sys.argv))
+
+
+def _load_yaml_files(paths: List[str]) -> Dict[str, Any]:
+    """Loads yaml files and merge them from multiple paths.
+
+    Args:
+        paths: The paths to the yaml files.
+
+    Returns:
+        The merged configurations as a dictionary.
+    """
+    merged_config = {}
+    for config_path in paths:
+        fs = cloud_io.get_filesystem(config_path)
+        with fs.open(config_path, "r") as file:
+            omegaconf_file = OmegaConf.load(file)  # type: ignore
+            config_dict = OmegaConf.to_object(omegaconf_file)  # type: ignore
+            parsed_config = _type_resolver(config_dict)  # type: ignore
+            merged_config.update(parsed_config)
+    return merged_config
+
+
+def _type_resolver(mapping: Dict[str, Any]) -> Dict[str, Any]:
+    """Parses the string values of a dictionary in-place.
+
+    Args:
+        mapping: A dictionary object.
+
+    Returns:
+        The mapping with the formatted values.
+    """
+    for key, value in mapping.items():
+        if isinstance(value, dict):
+            formatted_value = _type_resolver(value)
+        elif isinstance(value, list) and isinstance(value[0], dict):
+            formatted_value = [_type_resolver(subvalue) for subvalue in value]
+        else:
+            try:
+                parsed_value = ast.literal_eval(value)  # type: ignore
+                formatted_value = (
+                    value if isinstance(parsed_value, BuiltinFunctionType) else parsed_value
+                )
+
+            except Exception:
+                formatted_value = value
+
+        mapping[key] = formatted_value
+
+    return mapping
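The new ConfigurationLogger collects every .yaml path found in sys.argv, merges the files with OmegaConf, prints and saves the result to the trainer's log directory, and forwards it to the experiment loggers via log_parameters. A minimal wiring sketch (the config and output paths below are illustrative, not taken from this release):

import lightning.pytorch as pl

from eva.core.callbacks import ConfigurationLogger

# The callback discovers the submitted config by scanning sys.argv for .yaml paths,
# so it is meant to run under a CLI invocation such as `eva fit --config configs/task.yaml`.
trainer = pl.Trainer(
    default_root_dir="logs/",  # trainer.log_dir must resolve, otherwise logging is skipped
    callbacks=[ConfigurationLogger(verbose=True)],
)
# During `setup`, the merged configuration is printed (verbose=True), written to
# <log_dir>/config.yaml, and passed to the experiment loggers with tag "configuration".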
eva/core/callbacks/writers/__init__.py
CHANGED
@@ -1,5 +1,8 @@
-"""
+"""Writers callbacks API."""
 
-from eva.core.callbacks.writers.embeddings import
+from eva.core.callbacks.writers.embeddings import (
+    ClassificationEmbeddingsWriter,
+    SegmentationEmbeddingsWriter,
+)
 
-__all__ = ["
+__all__ = ["ClassificationEmbeddingsWriter", "SegmentationEmbeddingsWriter"]
eva/core/callbacks/writers/embeddings/__init__.py
ADDED
@@ -0,0 +1,6 @@
+"""Embedding callback writers."""
+
+from eva.core.callbacks.writers.embeddings.classification import ClassificationEmbeddingsWriter
+from eva.core.callbacks.writers.embeddings.segmentation import SegmentationEmbeddingsWriter
+
+__all__ = ["ClassificationEmbeddingsWriter", "SegmentationEmbeddingsWriter"]
eva/core/callbacks/writers/embeddings/_manifest.py
ADDED
@@ -0,0 +1,71 @@
+"""Manifest file manager."""
+
+import csv
+import io
+import os
+from typing import Any, Dict, List
+
+import _csv
+import torch
+
+
+class ManifestManager:
+    """Class for writing the embedding manifest files."""
+
+    def __init__(
+        self,
+        output_dir: str,
+        metadata_keys: List[str] | None = None,
+        overwrite: bool = False,
+    ) -> None:
+        """Initializes the writing manager.
+
+        Args:
+            output_dir: The directory where the embeddings will be saved.
+            metadata_keys: An optional list of keys to extract from the batch
+                metadata and store as additional columns in the manifest file.
+            overwrite: Whether to overwrite the output directory.
+        """
+        self._output_dir = output_dir
+        self._metadata_keys = metadata_keys or []
+        self._overwrite = overwrite
+
+        self._manifest_file: io.TextIOWrapper
+        self._manifest_writer: _csv.Writer  # type: ignore
+
+        self._setup()
+
+    def _setup(self) -> None:
+        """Initializes the manifest file and sets the file object and writer."""
+        manifest_path = os.path.join(self._output_dir, "manifest.csv")
+        if os.path.exists(manifest_path) and not self._overwrite:
+            raise FileExistsError(
+                f"A manifest file already exists at {manifest_path}, which indicates that the "
+                "chosen output directory has been previously used for writing embeddings."
+            )
+        self._manifest_file = open(manifest_path, "w", newline="")
+        self._manifest_writer = csv.writer(self._manifest_file)
+        self._manifest_writer.writerow(
+            ["origin", "embeddings", "target", "split"] + self._metadata_keys
+        )
+
+    def update(
+        self,
+        input_name: str,
+        save_name: str,
+        target: str,
+        split: str | None,
+        metadata: Dict[str, Any] | None = None,
+    ) -> None:
+        """Adds a new entry to the manifest file."""
+        metadata_entries = _to_dict_values(metadata or {})
+        self._manifest_writer.writerow([input_name, save_name, target, split] + metadata_entries)
+
+    def close(self) -> None:
+        """Closes the manifest file."""
+        if self._manifest_file:
+            self._manifest_file.close()
+
+
+def _to_dict_values(data: Dict[str, Any]) -> List[Any]:
+    return [value.item() if isinstance(value, torch.Tensor) else value for value in data.values()]
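ManifestManager maintains a manifest.csv next to the written embeddings, with one row per saved file. A small sketch of the intended call sequence (file names are illustrative; the output directory is created beforehand by the writer callbacks):

from eva.core.callbacks.writers.embeddings._manifest import ManifestManager

# Assumes "embeddings/" already exists; the writer callbacks create it before
# spawning the write process.
manager = ManifestManager(output_dir="embeddings/", metadata_keys=["wsi_id"])
# Header row written on setup: origin, embeddings, target, split, wsi_id
manager.update(
    input_name="images/slide_0001.png",  # origin column
    save_name="slide_0001.pt",           # embeddings column
    target="3",                          # target column
    split="train",
    metadata={"wsi_id": "slide_0001"},
)
manager.close()  # flushes embeddings/manifest.csv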
eva/core/callbacks/writers/embeddings/base.py
ADDED
@@ -0,0 +1,192 @@
+"""Embeddings writer base class."""
+
+import abc
+import io
+import os
+from typing import Any, Dict, List, Sequence
+
+import lightning.pytorch as pl
+import torch
+from lightning.pytorch import callbacks
+from loguru import logger
+from torch import multiprocessing, nn
+from typing_extensions import override
+
+from eva.core import utils
+from eva.core.callbacks.writers.embeddings.typings import QUEUE_ITEM
+from eva.core.models.modules.typings import INPUT_BATCH
+from eva.core.utils import multiprocessing as eva_multiprocessing
+
+
+class EmbeddingsWriter(callbacks.BasePredictionWriter, abc.ABC):
+    """Callback for writing generated embeddings to disk."""
+
+    def __init__(
+        self,
+        output_dir: str,
+        backbone: nn.Module | None = None,
+        dataloader_idx_map: Dict[int, str] | None = None,
+        metadata_keys: List[str] | None = None,
+        overwrite: bool = False,
+        save_every_n: int = 100,
+    ) -> None:
+        """Initializes a new EmbeddingsWriter instance.
+
+        This callback writes the embedding files in a separate process to avoid blocking the
+        main process where the model forward pass is executed.
+
+        Args:
+            output_dir: The directory where the embeddings will be saved.
+            backbone: A model to be used as feature extractor. If `None`,
+                it will be expected that the input batch returns the features directly.
+            dataloader_idx_map: A dictionary mapping dataloader indices to their respective
+                names (e.g. train, val, test).
+            metadata_keys: An optional list of keys to extract from the batch metadata and store
+                as additional columns in the manifest file.
+            overwrite: Whether to overwrite if embeddings are already present in the specified
+                output directory. If set to `False`, an error will be raised if embeddings are
+                already present (recommended).
+            save_every_n: Interval for number of iterations to save the embeddings to disk.
+                During this interval, the embeddings are accumulated in memory.
+        """
+        super().__init__(write_interval="batch")
+
+        self._output_dir = output_dir
+        self._backbone = backbone
+        self._dataloader_idx_map = dataloader_idx_map or {}
+        self._overwrite = overwrite
+        self._save_every_n = save_every_n
+        self._metadata_keys = metadata_keys or []
+
+        self._write_queue: multiprocessing.Queue
+        self._write_process: eva_multiprocessing.Process
+
+    @staticmethod
+    @abc.abstractmethod
+    def _process_write_queue(
+        write_queue: multiprocessing.Queue,
+        output_dir: str,
+        metadata_keys: List[str],
+        save_every_n: int,
+        overwrite: bool = False,
+    ) -> None:
+        """This function receives and processes items added by the main process to the queue.
+
+        Queue items contain the embedding tensors, targets and metadata which need to be
+        saved to disk (.pt files and manifest).
+        """
+
+    @override
+    def on_predict_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+        self._check_if_exists()
+        self._initialize_write_process()
+        self._write_process.start()
+
+        if self._backbone is not None:
+            self._backbone = self._backbone.to(pl_module.device)
+            self._backbone.eval()
+
+    @override
+    def write_on_batch_end(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        prediction: Any,
+        batch_indices: Sequence[int],
+        batch: INPUT_BATCH,
+        batch_idx: int,
+        dataloader_idx: int,
+    ) -> None:
+        dataset = trainer.predict_dataloaders[dataloader_idx].dataset  # type: ignore
+        _, targets, metadata = INPUT_BATCH(*batch)
+        split = self._dataloader_idx_map.get(dataloader_idx)
+        if not isinstance(targets, torch.Tensor):
+            raise ValueError(f"Targets ({type(targets)}) should be `torch.Tensor`.")
+
+        with torch.no_grad():
+            embeddings = self._get_embeddings(prediction)
+
+        for local_idx, global_idx in enumerate(batch_indices[: len(embeddings)]):
+            data_name = dataset.filename(global_idx)
+            save_name = os.path.splitext(data_name)[0] + ".pt"
+            embeddings_buffer, target_buffer = _as_io_buffers(
+                embeddings[local_idx], targets[local_idx]
+            )
+            item_metadata = self._get_item_metadata(metadata, local_idx)
+            item = QUEUE_ITEM(
+                prediction_buffer=embeddings_buffer,
+                target_buffer=target_buffer,
+                data_name=data_name,
+                save_name=save_name,
+                split=split,
+                metadata=item_metadata,
+            )
+            self._write_queue.put(item)
+
+        self._write_process.check_exceptions()
+
+    @override
+    def on_predict_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+        self._write_queue.put(None)
+        self._write_process.join()
+        logger.info(f"Predictions and manifest saved to {self._output_dir}")
+
+    def _initialize_write_process(self) -> None:
+        self._write_queue = multiprocessing.Queue()
+        self._write_process = eva_multiprocessing.Process(
+            target=self._process_write_queue,
+            args=(
+                self._write_queue,
+                self._output_dir,
+                self._metadata_keys,
+                self._save_every_n,
+                self._overwrite,
+            ),
+        )
+
+    @abc.abstractmethod
+    def _get_embeddings(self, tensor: torch.Tensor) -> torch.Tensor | List[List[torch.Tensor]]:
+        """Returns the embeddings from predictions."""
+
+    def _get_item_metadata(
+        self, metadata: Dict[str, Any] | None, local_idx: int
+    ) -> Dict[str, Any] | None:
+        """Returns the metadata for the item at the given local index."""
+        if not metadata:
+            if self._metadata_keys:
+                raise ValueError("Metadata keys are provided but the batch metadata is empty.")
+            else:
+                return None
+
+        item_metadata = {}
+        for key in self._metadata_keys:
+            if key not in metadata:
+                raise KeyError(f"Metadata key '{key}' not found in the batch metadata.")
+            metadata_value = metadata[key][local_idx]
+            try:
+                item_metadata[key] = utils.to_cpu(metadata_value)
+            except TypeError:
+                item_metadata[key] = metadata_value
+
+        return item_metadata
+
+    def _check_if_exists(self) -> None:
+        """Checks if the output directory already exists and if it should be overwritten."""
+        try:
+            os.makedirs(self._output_dir, exist_ok=self._overwrite)
+        except FileExistsError as e:
+            raise FileExistsError(
+                f"The embeddings output directory already exists: {self._output_dir}. This "
+                "either means that they have been computed before or that a wrong output "
+                "directory is being used. Consider using `eva fit` instead, selecting a "
+                "different output directory or setting overwrite=True."
+            ) from e
+        os.makedirs(self._output_dir, exist_ok=True)
+
+
+def _as_io_buffers(*items: torch.Tensor | List[torch.Tensor]) -> Sequence[io.BytesIO]:
+    """Loads torch tensors as io buffers."""
+    buffers = [io.BytesIO() for _ in range(len(items))]
+    for tensor, buffer in zip(items, buffers, strict=False):
+        torch.save(utils.clone(tensor), buffer)
+    return buffers
+
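The base writer is a Lightning BasePredictionWriter: during trainer.predict the main process serializes one QUEUE_ITEM per sample in write_on_batch_end and a None sentinel at the end, while a spawned process drains the queue and persists the tensors. A minimal wiring sketch for the concrete classification writer that follows (the backbone, model and output path are placeholders, not from this release):

import lightning.pytorch as pl
from torch import nn

from eva.core.callbacks.writers import ClassificationEmbeddingsWriter

backbone = nn.Flatten()  # stand-in for a frozen feature extractor
writer = ClassificationEmbeddingsWriter(
    output_dir="data/embeddings/demo",  # placeholder output path
    backbone=backbone,
    dataloader_idx_map={0: "train", 1: "val", 2: "test"},
)
trainer = pl.Trainer(callbacks=[writer], accelerator="cpu")
# The predict datasets must expose a `filename(index)` method, which the writer uses to
# derive the .pt save names, and each batch is unpacked as INPUT_BATCH (data, targets, metadata).
# trainer.predict(model, datamodule=datamodule, return_predictions=False)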
eva/core/callbacks/writers/embeddings/classification.py
ADDED
@@ -0,0 +1,117 @@
+"""Embeddings writer for classification."""
+
+import io
+import os
+from typing import Dict, List
+
+import torch
+from torch import multiprocessing
+from typing_extensions import override
+
+from eva.core.callbacks.writers.embeddings import base
+from eva.core.callbacks.writers.embeddings._manifest import ManifestManager
+from eva.core.callbacks.writers.embeddings.typings import ITEM_DICT_ENTRY, QUEUE_ITEM
+
+
+class ClassificationEmbeddingsWriter(base.EmbeddingsWriter):
+    """Callback for writing generated embeddings to disk for classification tasks."""
+
+    @staticmethod
+    @override
+    def _process_write_queue(
+        write_queue: multiprocessing.Queue,
+        output_dir: str,
+        metadata_keys: List[str],
+        save_every_n: int,
+        overwrite: bool = False,
+    ) -> None:
+        """Processes the write queue and saves the predictions to disk.
+
+        Note that in Multi Instance Learning (MIL) scenarios, we can have multiple
+        embeddings per input data point. In that case, this function will save all
+        embeddings that correspond to the same data point as a list of tensors to
+        the same .pt file.
+        """
+        manifest_manager = ManifestManager(output_dir, metadata_keys, overwrite)
+        name_to_items: Dict[str, ITEM_DICT_ENTRY] = {}
+
+        counter = 0
+        while True:
+            item = write_queue.get()
+            if item is None:
+                break
+            item = QUEUE_ITEM(*item)
+
+            if item.save_name in name_to_items:
+                name_to_items[item.save_name].items.append(item)
+            else:
+                name_to_items[item.save_name] = ITEM_DICT_ENTRY(items=[item], save_count=0)
+
+            if counter > 0 and counter % save_every_n == 0:
+                name_to_items = _save_items(name_to_items, output_dir, manifest_manager)
+            counter += 1
+
+        if len(name_to_items) > 0:
+            _save_items(name_to_items, output_dir, manifest_manager)
+
+        manifest_manager.close()
+
+    @override
+    def _get_embeddings(self, tensor: torch.Tensor) -> torch.Tensor:
+        """Returns the embeddings from predictions."""
+        return self._backbone(tensor) if self._backbone else tensor
+
+
+def _save_items(
+    name_to_items: Dict[str, ITEM_DICT_ENTRY],
+    output_dir: str,
+    manifest_manager: ManifestManager,
+) -> Dict[str, ITEM_DICT_ENTRY]:
+    """Saves predictions to disk and updates the manifest file.
+
+    Args:
+        name_to_items: A dictionary mapping save data point names to the corresponding queue items
+            holding the prediction tensors and the information for the manifest file.
+        output_dir: The directory where the embedding tensors & manifest will be saved.
+        manifest_manager: The manifest manager instance to update the manifest file.
+    """
+    for save_name, entry in name_to_items.items():
+        if len(entry.items) > 0:
+            save_path = os.path.join(output_dir, save_name)
+            is_first_save = entry.save_count == 0
+            if is_first_save:
+                _, target, input_name, _, split, metadata = QUEUE_ITEM(*entry.items[0])
+                target = torch.load(io.BytesIO(target.getbuffer()), map_location="cpu").item()
+                manifest_manager.update(input_name, save_name, target, split, metadata)
+
+            prediction_buffers = [item.prediction_buffer for item in entry.items]
+            _save_predictions(prediction_buffers, save_path, is_first_save)
+            name_to_items[save_name].save_count += 1
+            name_to_items[save_name].items = []
+
+    return name_to_items
+
+
+def _save_predictions(
+    prediction_buffers: List[io.BytesIO], save_path: str, is_first_save: bool
+) -> None:
+    """Saves the embedding tensors as list to .pt files.
+
+    If it's not the first save to this save_path, the new predictions are appended to
+    the existing ones and saved to the same file.
+
+    Example use-case: Save all patch embeddings corresponding to the same WSI to a single file.
+    """
+    predictions = [
+        torch.load(io.BytesIO(buffer.getbuffer()), map_location="cpu")
+        for buffer in prediction_buffers
+    ]
+
+    if not is_first_save:
+        previous_predictions = torch.load(save_path, map_location="cpu")
+        if not isinstance(previous_predictions, list):
+            raise ValueError("Previous predictions should be a list of tensors.")
+        predictions = predictions + previous_predictions
+
+    os.makedirs(os.path.dirname(save_path), exist_ok=True)
+    torch.save(predictions, save_path)
+
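For classification, each input ends up as a single .pt file holding a list of embedding tensors (several per file in MIL setups, where all patch embeddings of the same slide are appended to one file), referenced from manifest.csv. An illustrative read-back of such a file (path and shapes are examples only, not from this release):

import torch

# Each saved file is a list of tensors, even when it contains a single embedding.
embeddings = torch.load("data/embeddings/demo/slide_0001.pt", map_location="cpu")
print(type(embeddings), len(embeddings), embeddings[0].shape)
# e.g. <class 'list'> 12 torch.Size([768])  -> 12 patch embeddings for one slide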
eva/core/callbacks/writers/embeddings/segmentation.py
ADDED
@@ -0,0 +1,78 @@
+"""Segmentation embeddings writer."""
+
+import collections
+import io
+import os
+from typing import List
+
+import torch
+from torch import multiprocessing
+from typing_extensions import override
+
+from eva.core.callbacks.writers.embeddings import base
+from eva.core.callbacks.writers.embeddings._manifest import ManifestManager
+from eva.core.callbacks.writers.embeddings.typings import QUEUE_ITEM
+
+
+class SegmentationEmbeddingsWriter(base.EmbeddingsWriter):
+    """Callback for writing generated embeddings to disk."""
+
+    @staticmethod
+    @override
+    def _process_write_queue(
+        write_queue: multiprocessing.Queue,
+        output_dir: str,
+        metadata_keys: List[str],
+        save_every_n: int,
+        overwrite: bool = False,
+    ) -> None:
+        manifest_manager = ManifestManager(output_dir, metadata_keys, overwrite)
+        counter = collections.defaultdict(lambda: -1)
+        while True:
+            item = write_queue.get()
+            if item is None:
+                break
+
+            embeddings_buffer, target_buffer, input_name, save_name, split, metadata = QUEUE_ITEM(
+                *item
+            )
+            counter[save_name] += 1
+            save_name = save_name.replace(".pt", f"-{counter[save_name]}.pt")
+            target_filename = save_name.replace(".pt", "-mask.pt")
+
+            _save_embedding(embeddings_buffer, save_name, output_dir)
+            _save_embedding(target_buffer, target_filename, output_dir)
+            manifest_manager.update(input_name, save_name, target_filename, split, metadata)
+
+        manifest_manager.close()
+
+    @override
+    def _get_embeddings(self, tensor: torch.Tensor) -> torch.Tensor | List[List[torch.Tensor]]:
+        """Returns the embeddings from predictions."""
+
+        def _get_grouped_embeddings(embeddings: List[torch.Tensor]) -> List[List[torch.Tensor]]:
+            """Casts a list of multi-leveled batched embeddings to grouped per batch.
+
+            That is, for embeddings to be a list of shape (batch_size, hidden_dim, height, width),
+            such as `[(2, 192, 16, 16), (2, 192, 16, 16)]`, to be reshaped as a list of lists of
+            per batch multi-level embeddings, thus
+            `[ [(192, 16, 16), (192, 16, 16)], [(192, 16, 16), (192, 16, 16)] ]`.
+            """
+            batch_size = embeddings[0].shape[0]
+            grouped_embeddings = []
+            for batch_idx in range(batch_size):
+                batch_list = [layer_embeddings[batch_idx] for layer_embeddings in embeddings]
+                grouped_embeddings.append(batch_list)
+            return grouped_embeddings
+
+        embeddings = self._backbone(tensor) if self._backbone else tensor
+        if isinstance(embeddings, list):
+            embeddings = _get_grouped_embeddings(embeddings)
+        return embeddings
+
+
+def _save_embedding(embeddings_buffer: io.BytesIO, save_name: str, output_dir: str) -> None:
+    save_path = os.path.join(output_dir, save_name)
+    prediction = torch.load(io.BytesIO(embeddings_buffer.getbuffer()), map_location="cpu")
+    os.makedirs(os.path.dirname(save_path), exist_ok=True)
+    torch.save(prediction, save_path)
+
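The segmentation writer additionally regroups multi-level backbone outputs: a list of per-layer batched feature maps becomes, per sample, the list of that sample's feature maps across layers, and each sample is written as <name>-<i>.pt with a matching <name>-<i>-mask.pt target. A standalone sketch of the regrouping, using the shapes from the docstring above:

import torch

per_layer = [torch.zeros(2, 192, 16, 16), torch.zeros(2, 192, 16, 16)]  # two layers, batch size 2
batch_size = per_layer[0].shape[0]
# One entry per sample, each holding that sample's feature map from every layer.
grouped = [[layer[idx] for layer in per_layer] for idx in range(batch_size)]
print(len(grouped), len(grouped[0]), grouped[0][0].shape)
# 2 2 torch.Size([192, 16, 16])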
eva/core/callbacks/writers/embeddings/typings.py
ADDED
@@ -0,0 +1,38 @@
+"""Typing definitions for the writer callback functions."""
+
+import dataclasses
+import io
+from typing import Any, Dict, List, NamedTuple
+
+
+class QUEUE_ITEM(NamedTuple):
+    """The default input batch data scheme."""
+
+    prediction_buffer: io.BytesIO
+    """IO buffer containing the prediction tensor."""
+
+    target_buffer: io.BytesIO
+    """IO buffer containing the target tensor."""
+
+    data_name: str
+    """Name of the input data that was used to generate the embedding."""
+
+    save_name: str
+    """Name to store the generated embedding."""
+
+    split: str | None
+    """The dataset split the item belongs to (e.g. train, val, test)."""
+
+    metadata: Dict[str, Any] | None = None
+    """Dictionary holding additional metadata."""
+
+
+@dataclasses.dataclass
+class ITEM_DICT_ENTRY:
+    """Typing for holding queue items and number of save operations."""
+
+    items: List[QUEUE_ITEM]
+    """List of queue items."""
+
+    save_count: int
+    """Number of prior item batch saves to same file."""
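QUEUE_ITEM is the payload exchanged between the two processes: the tensors travel as serialized io.BytesIO buffers rather than live tensors. An illustrative construction with dummy data (names are placeholders):

import io

import torch

from eva.core.callbacks.writers.embeddings.typings import QUEUE_ITEM

prediction_buffer, target_buffer = io.BytesIO(), io.BytesIO()
torch.save(torch.randn(768), prediction_buffer)  # dummy embedding
torch.save(torch.tensor(1), target_buffer)       # dummy target

item = QUEUE_ITEM(
    prediction_buffer=prediction_buffer,
    target_buffer=target_buffer,
    data_name="images/sample_0001.png",  # illustrative origin name
    save_name="sample_0001.pt",
    split="train",
)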
eva/core/data/datasets/__init__.py
CHANGED
@@ -1,7 +1,15 @@
 """Datasets API."""
 
 from eva.core.data.datasets.base import Dataset
-from eva.core.data.datasets.classification import
+from eva.core.data.datasets.classification import (
+    EmbeddingsClassificationDataset,
+    MultiEmbeddingsClassificationDataset,
+)
 from eva.core.data.datasets.dataset import TorchDataset
 
-__all__ = [
+__all__ = [
+    "Dataset",
+    "EmbeddingsClassificationDataset",
+    "MultiEmbeddingsClassificationDataset",
+    "TorchDataset",
+]