kaiko-eva 0.0.0.dev6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kaiko-eva might be problematic.
- eva/.DS_Store +0 -0
- eva/__init__.py +33 -0
- eva/__main__.py +18 -0
- eva/__version__.py +25 -0
- eva/core/__init__.py +19 -0
- eva/core/callbacks/__init__.py +5 -0
- eva/core/callbacks/writers/__init__.py +5 -0
- eva/core/callbacks/writers/embeddings.py +169 -0
- eva/core/callbacks/writers/typings.py +23 -0
- eva/core/cli/__init__.py +5 -0
- eva/core/cli/cli.py +19 -0
- eva/core/cli/logo.py +38 -0
- eva/core/cli/setup.py +89 -0
- eva/core/data/__init__.py +14 -0
- eva/core/data/dataloaders/__init__.py +5 -0
- eva/core/data/dataloaders/dataloader.py +80 -0
- eva/core/data/datamodules/__init__.py +6 -0
- eva/core/data/datamodules/call.py +33 -0
- eva/core/data/datamodules/datamodule.py +108 -0
- eva/core/data/datamodules/schemas.py +62 -0
- eva/core/data/datasets/__init__.py +7 -0
- eva/core/data/datasets/base.py +53 -0
- eva/core/data/datasets/classification/__init__.py +5 -0
- eva/core/data/datasets/classification/embeddings.py +154 -0
- eva/core/data/datasets/dataset.py +6 -0
- eva/core/data/samplers/__init__.py +5 -0
- eva/core/data/samplers/sampler.py +6 -0
- eva/core/data/transforms/__init__.py +5 -0
- eva/core/data/transforms/dtype/__init__.py +5 -0
- eva/core/data/transforms/dtype/array.py +28 -0
- eva/core/interface/__init__.py +5 -0
- eva/core/interface/interface.py +79 -0
- eva/core/metrics/__init__.py +17 -0
- eva/core/metrics/average_loss.py +47 -0
- eva/core/metrics/binary_balanced_accuracy.py +22 -0
- eva/core/metrics/defaults/__init__.py +6 -0
- eva/core/metrics/defaults/classification/__init__.py +6 -0
- eva/core/metrics/defaults/classification/binary.py +76 -0
- eva/core/metrics/defaults/classification/multiclass.py +80 -0
- eva/core/metrics/structs/__init__.py +9 -0
- eva/core/metrics/structs/collection.py +6 -0
- eva/core/metrics/structs/metric.py +6 -0
- eva/core/metrics/structs/module.py +115 -0
- eva/core/metrics/structs/schemas.py +47 -0
- eva/core/metrics/structs/typings.py +15 -0
- eva/core/models/__init__.py +13 -0
- eva/core/models/modules/__init__.py +7 -0
- eva/core/models/modules/head.py +113 -0
- eva/core/models/modules/inference.py +37 -0
- eva/core/models/modules/module.py +190 -0
- eva/core/models/modules/typings.py +23 -0
- eva/core/models/modules/utils/__init__.py +6 -0
- eva/core/models/modules/utils/batch_postprocess.py +57 -0
- eva/core/models/modules/utils/grad.py +23 -0
- eva/core/models/networks/__init__.py +6 -0
- eva/core/models/networks/_utils.py +25 -0
- eva/core/models/networks/mlp.py +69 -0
- eva/core/models/networks/transforms/__init__.py +5 -0
- eva/core/models/networks/transforms/extract_cls_features.py +25 -0
- eva/core/models/networks/wrappers/__init__.py +8 -0
- eva/core/models/networks/wrappers/base.py +47 -0
- eva/core/models/networks/wrappers/from_function.py +58 -0
- eva/core/models/networks/wrappers/huggingface.py +37 -0
- eva/core/models/networks/wrappers/onnx.py +47 -0
- eva/core/trainers/__init__.py +6 -0
- eva/core/trainers/_logging.py +81 -0
- eva/core/trainers/_recorder.py +149 -0
- eva/core/trainers/_utils.py +12 -0
- eva/core/trainers/functional.py +113 -0
- eva/core/trainers/trainer.py +97 -0
- eva/core/utils/__init__.py +1 -0
- eva/core/utils/io/__init__.py +5 -0
- eva/core/utils/io/dataframe.py +21 -0
- eva/core/utils/multiprocessing.py +44 -0
- eva/core/utils/workers.py +21 -0
- eva/vision/__init__.py +14 -0
- eva/vision/data/__init__.py +5 -0
- eva/vision/data/datasets/__init__.py +22 -0
- eva/vision/data/datasets/_utils.py +50 -0
- eva/vision/data/datasets/_validators.py +44 -0
- eva/vision/data/datasets/classification/__init__.py +15 -0
- eva/vision/data/datasets/classification/bach.py +174 -0
- eva/vision/data/datasets/classification/base.py +103 -0
- eva/vision/data/datasets/classification/crc.py +176 -0
- eva/vision/data/datasets/classification/mhist.py +106 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +203 -0
- eva/vision/data/datasets/classification/total_segmentator.py +212 -0
- eva/vision/data/datasets/segmentation/__init__.py +6 -0
- eva/vision/data/datasets/segmentation/base.py +112 -0
- eva/vision/data/datasets/segmentation/total_segmentator.py +212 -0
- eva/vision/data/datasets/structs.py +17 -0
- eva/vision/data/datasets/vision.py +43 -0
- eva/vision/data/transforms/__init__.py +5 -0
- eva/vision/data/transforms/common/__init__.py +5 -0
- eva/vision/data/transforms/common/resize_and_crop.py +44 -0
- eva/vision/models/__init__.py +5 -0
- eva/vision/models/networks/__init__.py +6 -0
- eva/vision/models/networks/abmil.py +176 -0
- eva/vision/models/networks/postprocesses/__init__.py +5 -0
- eva/vision/models/networks/postprocesses/cls.py +25 -0
- eva/vision/utils/__init__.py +5 -0
- eva/vision/utils/io/__init__.py +12 -0
- eva/vision/utils/io/_utils.py +29 -0
- eva/vision/utils/io/image.py +54 -0
- eva/vision/utils/io/nifti.py +50 -0
- eva/vision/utils/io/text.py +18 -0
- kaiko_eva-0.0.0.dev6.dist-info/METADATA +393 -0
- kaiko_eva-0.0.0.dev6.dist-info/RECORD +111 -0
- kaiko_eva-0.0.0.dev6.dist-info/WHEEL +4 -0
- kaiko_eva-0.0.0.dev6.dist-info/entry_points.txt +4 -0
- kaiko_eva-0.0.0.dev6.dist-info/licenses/LICENSE +201 -0
eva/core/trainers/_logging.py
ADDED
@@ -0,0 +1,81 @@
+"""Helper functions and utilities for trainer logging."""
+
+import hashlib
+import sys
+from datetime import datetime
+
+from lightning_fabric.utilities import cloud_io
+from loguru import logger
+
+
+def generate_session_id() -> str:
+    """Generates and returns a unique string ID of an experiment.
+
+    The ID is composed of the run timestamp and its config hash. If the
+    configuration hash is an empty string, it will use only the timestamp.
+    """
+    timestamp = _generate_timestamp_hash()
+    config_hash = _generate_config_hash()
+    return f"{timestamp}_{config_hash}" if config_hash else timestamp
+
+
+def _generate_timestamp_hash() -> str:
+    """Generates a time-based hash id."""
+    timestamp = datetime.now()
+    return timestamp.strftime("%Y%m%d-%H%M%S%f")
+
+
+def _generate_config_hash(max_hash_len: int = 8) -> str:
+    """Generates a hash id based on a yaml configuration file.
+
+    Args:
+        max_hash_len: The maximum length of the produced hash id.
+    """
+    config_path = _fetch_config_path()
+    if config_path is None:
+        logger.warning(
+            "Found none or multiple configuration files in the command line arguments. "
+            "No configuration hash will be created for this experiment."
+        )
+        return ""
+
+    return _generate_hash_from_config(config_path, max_hash_len)
+
+
+def _fetch_config_path() -> str | None:
+    """Retrieves the configuration path from the command line arguments.
+
+    It returns `None` if none or multiple configuration files are
+    found in the system arguments.
+
+    Returns:
+        The path to the configuration file.
+    """
+    inputs = sys.argv
+    config_paths = [inputs[i + 1] for i, arg in enumerate(inputs) if arg == "--config"]
+    if len(config_paths) != 1:
+        # TODO: combine the multiple configuration files
+        # and produce a hash for the merged one.
+        return None
+
+    return config_paths[0]
+
+
+def _generate_hash_from_config(path: str, max_hash_len: int = 8) -> str:
+    """Returns a hash from the contents of the configuration file.
+
+    Args:
+        path: Path to the configuration file.
+        max_hash_len: Maximum length of the returned hash.
+
+    Returns:
+        Hash of the configuration file content.
+    """
+    fs = cloud_io.get_filesystem(path)
+    with fs.open(path, "r") as stream:
+        config = stream.read()
+    if isinstance(config, str):
+        config = config.encode("utf-8")
+    config_sha256 = hashlib.sha256(config)
+    hash_id = config_sha256.hexdigest()
+    return hash_id[:max_hash_len]
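The session id above is simply `<timestamp>_<8-char sha256 prefix of the config file>`. A standalone sketch of that scheme (plain stdlib, not the package API; the config contents are invented):

    import hashlib
    from datetime import datetime

    config_bytes = b"trainer:\n  max_epochs: 10\n"  # hypothetical config file contents
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S%f")  # e.g. 20240101-120000123456
    config_hash = hashlib.sha256(config_bytes).hexdigest()[:8]  # first 8 hex chars
    print(f"{timestamp}_{config_hash}")  # the id used as the trainer's session directory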
eva/core/trainers/_recorder.py
ADDED
@@ -0,0 +1,149 @@
+"""Multi-run summary recorder."""
+
+import collections
+import json
+import os
+import statistics
+import sys
+from typing import Any, Dict, List, Mapping
+
+from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT
+from lightning_fabric.utilities import cloud_io
+from loguru import logger
+from omegaconf import OmegaConf
+from toolz import dicttoolz
+
+SESSION_METRICS = Mapping[str, List[float]]
+"""Session metrics type-hint."""
+
+
+class SessionRecorder:
+    """Multi-run (session) summary logger."""
+
+    def __init__(
+        self,
+        output_dir: str,
+        results_file: str = "results.json",
+        config_file: str = "config.yaml",
+    ) -> None:
+        """Initializes the recorder.
+
+        Args:
+            output_dir: The destination folder to save the results.
+            results_file: The name of the results json file.
+            config_file: The name of the yaml configuration file.
+        """
+        self._output_dir = output_dir
+        self._results_file = results_file
+        self._config_file = config_file
+
+        self._validation_metrics: List[SESSION_METRICS] = []
+        self._test_metrics: List[SESSION_METRICS] = []
+
+    @property
+    def filename(self) -> str:
+        """Returns the output filename."""
+        return os.path.join(self._output_dir, self._results_file)
+
+    @property
+    def config_path(self) -> str | None:
+        """Returns the path to the .yaml configuration file from sys args, if available."""
+        if "--config" in sys.argv:
+            try:
+                config_path = sys.argv[sys.argv.index("--config") + 1]
+                if not config_path.endswith(".yaml"):
+                    logger.warning(f"Unexpected config file {config_path}, should be a .yaml file.")
+                else:
+                    return config_path
+            except IndexError as e:
+                logger.warning(f"Failed to fetch config_path from sys args {e}")
+
+    def update(
+        self,
+        validation_scores: _EVALUATE_OUTPUT,
+        test_scores: _EVALUATE_OUTPUT | None = None,
+    ) -> None:
+        """Updates the state of the tracked metrics in-place."""
+        self._update_validation_metrics(validation_scores)
+        self._update_test_metrics(test_scores)
+
+    def compute(self) -> Dict[str, List[Dict[str, Any]]]:
+        """Computes and returns the session statistics."""
+        validation_statistics = list(map(_calculate_statistics, self._validation_metrics))
+        test_statistics = list(map(_calculate_statistics, self._test_metrics))
+        return {"val": validation_statistics, "test": test_statistics}
+
+    def export(self) -> Dict[str, Any]:
+        """Exports the results."""
+        statistics = self.compute()
+        return {"metrics": statistics}
+
+    def save(self) -> None:
+        """Saves the recorded results."""
+        results = self.export()
+        _save_json(results, self.filename)
+        self._save_config()
+
+    def reset(self) -> None:
+        """Resets the state of the tracked metrics."""
+        self._validation_metrics = []
+        self._test_metrics = []
+
+    def _update_validation_metrics(self, metrics: _EVALUATE_OUTPUT) -> None:
+        """Updates the validation metrics in-place."""
+        self._validation_metrics = _update_session_metrics(self._validation_metrics, metrics)
+
+    def _update_test_metrics(self, metrics: _EVALUATE_OUTPUT | None) -> None:
+        """Updates the test metrics in-place."""
+        if metrics:
+            self._test_metrics = _update_session_metrics(self._test_metrics, metrics)
+
+    def _save_config(self) -> None:
+        """Saves the config yaml with resolved env placeholders to the output directory."""
+        if self.config_path:
+            config = OmegaConf.load(self.config_path)
+            fs = cloud_io.get_filesystem(self._output_dir, anon=False)
+            with fs.open(os.path.join(self._output_dir, self._config_file), "w") as file:
+                config_yaml = OmegaConf.to_yaml(config, resolve=True)
+                file.write(config_yaml)
+
+
+def _update_session_metrics(
+    session_metrics: List[SESSION_METRICS],
+    run_metrics: _EVALUATE_OUTPUT,
+) -> List[SESSION_METRICS]:
+    """Updates and returns the given session metrics with the new run metrics."""
+    session_metrics = session_metrics or _init_session_metrics(len(run_metrics))
+    for index, dataset_metrics in enumerate(run_metrics):
+        for name, value in dataset_metrics.items():
+            session_metrics[index][name].append(value)
+    return session_metrics
+
+
+def _init_session_metrics(n_datasets: int) -> List[SESSION_METRICS]:
+    """Returns the initialized session metrics."""
+    return [collections.defaultdict(list) for _ in range(n_datasets)]
+
+
+def _calculate_statistics(session_metrics: SESSION_METRICS) -> Dict[str, float | List[float]]:
+    """Calculates the metric statistics of a dataset session run."""
+
+    def _calculate_metric_statistics(values: List[float]) -> Dict[str, float | List[float]]:
+        """Calculates and returns the metric statistics."""
+        mean = statistics.mean(values)
+        stdev = statistics.stdev(values) if len(values) > 1 else 0
+        return {"mean": mean, "stdev": stdev, "values": values}
+
+    return dicttoolz.valmap(_calculate_metric_statistics, session_metrics)
+
+
+def _save_json(data: Dict[str, Any], save_as: str = "data.json"):
+    """Saves data to a json file."""
+    if not save_as.endswith(".json"):
+        raise ValueError(f"Expected a '.json' filename, but got '{save_as}'.")
+
+    output_dir = os.path.dirname(save_as)
+    fs = cloud_io.get_filesystem(output_dir, anon=False)
+    fs.makedirs(output_dir, exist_ok=True)
+    with fs.open(save_as, "w") as file:
+        json.dump(data, file, indent=4, sort_keys=True)
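A minimal sketch of how the recorder aggregates metrics across runs, assuming this wheel is installed (`_recorder` is a private module; the metric names and scores here are invented):

    from eva.core.trainers._recorder import SessionRecorder

    recorder = SessionRecorder(output_dir="logs/demo")
    # Each update takes Lightning-style scores: a list with one dict per dataloader.
    recorder.update(validation_scores=[{"val/accuracy": 0.91}], test_scores=[{"test/accuracy": 0.89}])
    recorder.update(validation_scores=[{"val/accuracy": 0.93}], test_scores=[{"test/accuracy": 0.90}])

    # compute() maps each metric to its mean/stdev and the raw per-run values, e.g.
    # {"val": [{"val/accuracy": {"mean": 0.92, "stdev": 0.014..., "values": [0.91, 0.93]}}], ...}
    print(recorder.compute())
    recorder.save()  # writes logs/demo/results.json (and config.yaml if `--config` was passed)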
eva/core/trainers/_utils.py
ADDED
@@ -0,0 +1,12 @@
+"""Training related utilities."""
+
+import copy
+from collections import abc
+from typing import Any
+
+
+def clone(*inputs: Any) -> Any:
+    """Deep copies a list of objects and returns them."""
+    if not isinstance(inputs, abc.Iterable):
+        return copy.deepcopy(inputs)
+    return [copy.deepcopy(obj) for obj in inputs]
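`clone` is what lets the session functions below reuse one base trainer/model per run without mutating them. A small sketch of its contract (assuming the wheel is installed; the dicts stand in for arbitrary objects):

    from eva.core.trainers._utils import clone

    base = {"lr": 0.001}
    copy_a, copy_b = clone(base, {"layers": 2})  # returns a list of deep copies
    copy_a["lr"] = 0.1
    assert base["lr"] == 0.001  # the base object is left untouched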
eva/core/trainers/functional.py
ADDED
@@ -0,0 +1,113 @@
+"""Fit session related functions."""
+
+from typing import Tuple
+
+from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT
+
+from eva.core.data import datamodules
+from eva.core.models import modules
+from eva.core.trainers import _recorder, _utils
+from eva.core.trainers import trainer as eva_trainer
+
+
+def run_evaluation_session(
+    base_trainer: eva_trainer.Trainer,
+    base_model: modules.ModelModule,
+    datamodule: datamodules.DataModule,
+    *,
+    n_runs: int = 1,
+) -> None:
+    """Runs a downstream evaluation session out-of-place.
+
+    It performs an evaluation run (fit and evaluate) on the model
+    multiple times. Note that since the input `base_trainer` and
+    `base_model` are cloned, the input objects are not
+    modified.
+
+    Args:
+        base_trainer: The base trainer module to use.
+        base_model: The base model module to use.
+        datamodule: The data module.
+        n_runs: The number of runs (fit and evaluate) to perform.
+    """
+    recorder = _recorder.SessionRecorder(output_dir=base_trainer.default_log_dir)
+    for run_index in range(n_runs):
+        validation_scores, test_scores = run_evaluation(
+            base_trainer, base_model, datamodule, run_id=f"run_{run_index}"
+        )
+        recorder.update(validation_scores, test_scores)
+    recorder.save()
+
+
+def run_evaluation(
+    base_trainer: eva_trainer.Trainer,
+    base_model: modules.ModelModule,
+    datamodule: datamodules.DataModule,
+    *,
+    run_id: str | None = None,
+) -> Tuple[_EVALUATE_OUTPUT, _EVALUATE_OUTPUT | None]:
+    """Fits and evaluates a model out-of-place.
+
+    Args:
+        base_trainer: The base trainer to use but not modify.
+        base_model: The model module to use but not modify.
+        datamodule: The data module.
+        run_id: The run id to be appended to the output log directory.
+            If `None`, it will use the log directory of the trainer as is.
+
+    Returns:
+        A tuple with the validation and the test metrics (if defined).
+    """
+    trainer, model = _utils.clone(base_trainer, base_model)
+    trainer.setup_log_dirs(run_id or "")
+    return fit_and_validate(trainer, model, datamodule)
+
+
+def fit_and_validate(
+    trainer: eva_trainer.Trainer,
+    model: modules.ModelModule,
+    datamodule: datamodules.DataModule,
+) -> Tuple[_EVALUATE_OUTPUT, _EVALUATE_OUTPUT | None]:
+    """Fits and evaluates a model in-place.
+
+    If a test set is defined in the datamodule, it will evaluate the model
+    on the test set as well.
+
+    Args:
+        trainer: The trainer module to use and update in-place.
+        model: The model module to use and update in-place.
+        datamodule: The data module.
+
+    Returns:
+        A tuple with the validation and the test metrics (if defined).
+    """
+    trainer.fit(model, datamodule=datamodule)
+    validation_scores = trainer.validate(datamodule=datamodule)
+    test_scores = None if datamodule.datasets.test is None else trainer.test(datamodule=datamodule)
+    return validation_scores, test_scores
+
+
+def infer_model(
+    base_trainer: eva_trainer.Trainer,
+    base_model: modules.ModelModule,
+    datamodule: datamodules.DataModule,
+    *,
+    return_predictions: bool = False,
+) -> None:
+    """Performs model inference out-of-place.
+
+    Note that the input `base_model` and `base_trainer` are
+    not modified.
+
+    Args:
+        base_trainer: The base trainer to use but not modify.
+        base_model: The model module to use but not modify.
+        datamodule: The data module.
+        return_predictions: Whether to return the model predictions.
+    """
+    trainer, model = _utils.clone(base_trainer, base_model)
+    return trainer.predict(
+        model=model,
+        datamodule=datamodule,
+        return_predictions=return_predictions,
+    )
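Taken together, one evaluation session is: clone the trainer and model, redirect logs to a `run_<i>` subdirectory, fit, validate, optionally test, and aggregate. A hedged, non-runnable-as-is sketch; `model` and `datamodule` stand for an eva `ModelModule` / `DataModule` built elsewhere (their constructors are not part of this diff):

    from eva.core.trainers import functional
    from eva.core.trainers.trainer import Trainer

    trainer = Trainer(max_epochs=1)  # accepts the usual lightning Trainer kwargs

    # One fit + validate (+ test) cycle; both inputs are cloned first, so the
    # base trainer/model keep their state, and logs land under .../run_0:
    val_scores, test_scores = functional.run_evaluation(trainer, model, datamodule, run_id="run_0")

    # Or a full session: n_runs independent cycles whose scores are aggregated
    # into results.json under trainer.default_log_dir by the SessionRecorder:
    functional.run_evaluation_session(trainer, model, datamodule, n_runs=3)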
eva/core/trainers/trainer.py
ADDED
@@ -0,0 +1,97 @@
+"""Core trainer module."""
+
+import os
+from typing import Any
+
+from lightning.pytorch import loggers as pl_loggers
+from lightning.pytorch import trainer as pl_trainer
+from lightning.pytorch.utilities import argparse
+from typing_extensions import override
+
+from eva.core.data import datamodules
+from eva.core.models import modules
+from eva.core.trainers import _logging, functional
+
+
+class Trainer(pl_trainer.Trainer):
+    """Core trainer class.
+
+    This is an extended version of lightning's core trainer class.
+    """
+
+    @argparse._defaults_from_env_vars
+    def __init__(
+        self,
+        *args: Any,
+        default_root_dir: str = "logs",
+        n_runs: int = 1,
+        **kwargs: Any,
+    ) -> None:
+        """Initializes the trainer.
+
+        For the input arguments, refer to ::class::`lightning.pytorch.Trainer`.
+
+        Args:
+            args: Positional arguments of ::class::`lightning.pytorch.Trainer`.
+            default_root_dir: The default root directory to store the output logs.
+                Unlike in ::class::`lightning.pytorch.Trainer`, this path is the
+                prioritized destination point.
+            n_runs: The number of runs (fit and evaluate) to perform in an evaluation session.
+            kwargs: Keyword arguments of ::class::`lightning.pytorch.Trainer`.
+        """
+        super().__init__(*args, default_root_dir=default_root_dir, **kwargs)
+
+        self._n_runs = n_runs
+
+        self._session_id: str = _logging.generate_session_id()
+        self._log_dir: str = self.default_log_dir
+
+        self.setup_log_dirs()
+
+    @property
+    def default_log_dir(self) -> str:
+        """Returns the default log directory."""
+        return os.path.join(self.default_root_dir, self._session_id)
+
+    @property
+    @override
+    def log_dir(self) -> str | None:
+        return self.strategy.broadcast(self._log_dir)
+
+    def setup_log_dirs(self, subdirectory: str = "") -> None:
+        """Sets up the logging directory of the trainer and experiment loggers in-place.
+
+        Args:
+            subdirectory: The subdirectory to append to the output log directory.
+        """
+        self._log_dir = os.path.join(self.default_root_dir, self._session_id, subdirectory)
+        os.fspath(self._log_dir)
+
+        for logger in self.loggers:
+            if isinstance(logger, (pl_loggers.CSVLogger, pl_loggers.TensorBoardLogger)):
+                logger._root_dir = self.default_root_dir
+                logger._name = self._session_id
+                logger._version = subdirectory
+
+    def run_evaluation_session(
+        self,
+        model: modules.ModelModule,
+        datamodule: datamodules.DataModule,
+    ) -> None:
+        """Runs an evaluation session out-of-place.
+
+        It performs an evaluation run (fit and evaluate) on the model
+        `self._n_runs` times. Note that the input `model` is not
+        modified, so the weights of the input model remain
+        as they are.
+
+        Args:
+            model: The base model module to evaluate.
+            datamodule: The data module.
+        """
+        functional.run_evaluation_session(
+            base_trainer=self,
+            base_model=model,
+            datamodule=datamodule,
+            n_runs=self._n_runs,
+        )
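A quick sketch of the resulting log layout (assuming the wheel is installed; the session id shown is illustrative):

    from eva.core.trainers.trainer import Trainer

    trainer = Trainer(default_root_dir="logs", n_runs=2, max_epochs=1)
    print(trainer.default_log_dir)  # e.g. logs/20240101-120000123456_1a2b3c4d
    trainer.setup_log_dirs("run_0")
    print(trainer.log_dir)          # e.g. logs/20240101-120000123456_1a2b3c4d/run_0
    # trainer.run_evaluation_session(model, datamodule) would then fit/evaluate twice.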
eva/core/utils/__init__.py
ADDED
@@ -0,0 +1 @@
+"""Utilities and library level helper functionalities."""
eva/core/utils/io/dataframe.py
ADDED
@@ -0,0 +1,21 @@
+"""DataFrame related I/O operations."""
+
+import pandas as pd
+
+
+def read_dataframe(path: str) -> pd.DataFrame:
+    """Reads and loads a DataFrame file.
+
+    Args:
+        path: The path to the manifest file.
+
+    Returns:
+        The data of the file as a `DataFrame`.
+    """
+    if path.endswith(".csv"):
+        data = pd.read_csv(path)
+    elif path.endswith(".parquet"):
+        data = pd.read_parquet(path)
+    else:
+        raise ValueError(f"Failed to load manifest file at '{path}'.")
+    return data
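A runnable sketch of the extension-based dispatch above (assuming the wheel and pandas are installed; the manifest columns are invented):

    import pandas as pd

    from eva.core.utils.io.dataframe import read_dataframe

    pd.DataFrame({"path": ["a.png"], "split": ["train"]}).to_csv("manifest.csv", index=False)
    print(read_dataframe("manifest.csv"))  # dispatches to pd.read_csv

    try:
        read_dataframe("manifest.txt")  # unsupported extension, checked before any I/O
    except ValueError as err:
        print(err)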
eva/core/utils/multiprocessing.py
ADDED
@@ -0,0 +1,44 @@
+"""Multiprocessing utilities."""
+
+import multiprocessing
+import sys
+import traceback
+from typing import Any
+
+
+class Process(multiprocessing.Process):
+    """Multiprocessing wrapper with logic to propagate exceptions to the parent process.
+
+    Source: https://stackoverflow.com/a/33599967/4992248
+    """
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        """Initialize the process."""
+        multiprocessing.Process.__init__(self, *args, **kwargs)
+
+        self._parent_conn, self._child_conn = multiprocessing.Pipe()
+        self._exception = None
+
+    def run(self) -> None:
+        """Run the process."""
+        try:
+            multiprocessing.Process.run(self)
+            self._child_conn.send(None)
+        except Exception as e:
+            tb = traceback.format_exc()
+            self._child_conn.send((e, tb))
+
+    @property
+    def exception(self):
+        """Property that contains exception information from the process."""
+        if self._parent_conn.poll():
+            self._exception = self._parent_conn.recv()
+        return self._exception
+
+    def check_exceptions(self) -> None:
+        """Checks for a raised exception and propagates it to the parent process."""
+        if not self.is_alive():
+            if self.exception:
+                error, traceback = self.exception
+                sys.stderr.write(traceback + "\n")
+                raise error
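A runnable sketch of the exception propagation (assuming the wheel is installed): the child's traceback is written to the parent's stderr and the original error is re-raised there.

    from eva.core.utils.multiprocessing import Process

    def _boom() -> None:
        raise RuntimeError("failure inside the child process")

    if __name__ == "__main__":
        process = Process(target=_boom)
        process.start()
        process.join()
        process.check_exceptions()  # prints the child traceback, then re-raises RuntimeError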
eva/core/utils/workers.py
ADDED
@@ -0,0 +1,21 @@
+"""Processing workers utilities and helper functions."""
+
+import multiprocessing
+from typing import Any, Callable
+
+
+def main_worker_only(func: Callable) -> Any:
+    """Function decorator which executes the function only on the main process / worker."""
+
+    def wrapper(*args: Any, **kwargs: Any) -> Any:
+        """Wrapper function for the decorated method."""
+        if is_main_worker():
+            return func(*args, **kwargs)
+
+    return wrapper
+
+
+def is_main_worker() -> bool:
+    """Returns whether the main process / worker is currently used."""
+    process = multiprocessing.current_process()
+    return process.name == "MainProcess"
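A runnable sketch (assuming the wheel is installed): the decorated function runs on the main process and silently returns `None` inside worker subprocesses, e.g. DataLoader workers.

    from eva.core.utils.workers import is_main_worker, main_worker_only

    @main_worker_only
    def log_once(message: str) -> None:
        print(message)

    print(is_main_worker())                   # True when executed as a plain script
    log_once("hello from the main process")   # prints; in a worker it is a no-op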
eva/vision/__init__.py
ADDED
@@ -0,0 +1,14 @@
+"""eva vision API."""
+
+try:
+    from eva.vision import models, utils
+    from eva.vision.data import datasets, transforms
+except ImportError as e:
+    msg = (
+        "eva vision requirements are not installed.\n\n"
+        "Please pip install as follows:\n"
+        '  python -m pip install "eva[vision]" --upgrade'
+    )
+    raise ImportError(str(e) + "\n\n" + msg) from e
+
+__all__ = ["models", "utils", "datasets", "transforms"]
eva/vision/data/datasets/__init__.py
ADDED
@@ -0,0 +1,22 @@
+"""Vision Datasets API."""
+
+from eva.vision.data.datasets.classification import (
+    BACH,
+    CRC,
+    MHIST,
+    PatchCamelyon,
+    TotalSegmentatorClassification,
+)
+from eva.vision.data.datasets.segmentation import ImageSegmentation, TotalSegmentator2D
+from eva.vision.data.datasets.vision import VisionDataset
+
+__all__ = [
+    "BACH",
+    "CRC",
+    "MHIST",
+    "ImageSegmentation",
+    "PatchCamelyon",
+    "TotalSegmentatorClassification",
+    "TotalSegmentator2D",
+    "VisionDataset",
+]
eva/vision/data/datasets/_utils.py
ADDED
@@ -0,0 +1,50 @@
+"""Dataset related functions and helper functions."""
+
+from typing import List, Tuple
+
+
+def indices_to_ranges(indices: List[int]) -> List[Tuple[int, int]]:
+    """Turns a list of indices into a list of ranges.
+
+    The produced range intervals are half-open inequalities: start <= x < end.
+
+    Args:
+        indices: The list of indices to produce the ranges from.
+
+    Returns:
+        A list of half-open intervals.
+
+    Example:
+        >>> indices = [0, 1, 2, 4, 6, 7, 8]
+        >>> ranges = indices_to_ranges(indices)
+        >>> assert ranges == [(0, 3), (4, 5), (6, 9)]
+    """
+    ranges = []
+    start_index = 0
+    for i, current in enumerate(indices):
+        if i + 1 < len(indices) and current + 1 == indices[i + 1]:
+            continue
+
+        start = indices[start_index]
+        end = start if start_index == i else current
+        ranges.append((start, end + 1))
+        start_index = i + 1
+
+    return ranges
+
+
+def ranges_to_indices(ranges: List[Tuple[int, int]]) -> List[int]:
+    """Unpacks a list of ranges to individual indices.
+
+    Args:
+        ranges: The list of ranges to produce the indices from.
+
+    Returns:
+        A list of the indices.
+
+    Example:
+        >>> ranges = [(0, 3), (4, 5), (6, 9)]
+        >>> indices = ranges_to_indices(ranges)
+        >>> assert indices == [0, 1, 2, 4, 6, 7, 8]
+    """
+    return [index for start, end in ranges for index in range(start, end)]
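A runnable round-trip check for the two helpers above (assuming the wheel with the `vision` extra is installed): since every range is the half-open interval [start, end), the functions are mutually inverse for sorted, duplicate-free indices.

    from eva.vision.data.datasets._utils import indices_to_ranges, ranges_to_indices

    indices = [0, 1, 2, 4, 6, 7, 8]
    ranges = indices_to_ranges(indices)
    assert ranges == [(0, 3), (4, 5), (6, 9)]    # consecutive runs become intervals
    assert ranges_to_indices(ranges) == indices  # unpacking restores the indices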