kaiko-eva 0.0.0.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (111)
  1. eva/.DS_Store +0 -0
  2. eva/__init__.py +33 -0
  3. eva/__main__.py +18 -0
  4. eva/__version__.py +25 -0
  5. eva/core/__init__.py +19 -0
  6. eva/core/callbacks/__init__.py +5 -0
  7. eva/core/callbacks/writers/__init__.py +5 -0
  8. eva/core/callbacks/writers/embeddings.py +169 -0
  9. eva/core/callbacks/writers/typings.py +23 -0
  10. eva/core/cli/__init__.py +5 -0
  11. eva/core/cli/cli.py +19 -0
  12. eva/core/cli/logo.py +38 -0
  13. eva/core/cli/setup.py +89 -0
  14. eva/core/data/__init__.py +14 -0
  15. eva/core/data/dataloaders/__init__.py +5 -0
  16. eva/core/data/dataloaders/dataloader.py +80 -0
  17. eva/core/data/datamodules/__init__.py +6 -0
  18. eva/core/data/datamodules/call.py +33 -0
  19. eva/core/data/datamodules/datamodule.py +108 -0
  20. eva/core/data/datamodules/schemas.py +62 -0
  21. eva/core/data/datasets/__init__.py +7 -0
  22. eva/core/data/datasets/base.py +53 -0
  23. eva/core/data/datasets/classification/__init__.py +5 -0
  24. eva/core/data/datasets/classification/embeddings.py +154 -0
  25. eva/core/data/datasets/dataset.py +6 -0
  26. eva/core/data/samplers/__init__.py +5 -0
  27. eva/core/data/samplers/sampler.py +6 -0
  28. eva/core/data/transforms/__init__.py +5 -0
  29. eva/core/data/transforms/dtype/__init__.py +5 -0
  30. eva/core/data/transforms/dtype/array.py +28 -0
  31. eva/core/interface/__init__.py +5 -0
  32. eva/core/interface/interface.py +79 -0
  33. eva/core/metrics/__init__.py +17 -0
  34. eva/core/metrics/average_loss.py +47 -0
  35. eva/core/metrics/binary_balanced_accuracy.py +22 -0
  36. eva/core/metrics/defaults/__init__.py +6 -0
  37. eva/core/metrics/defaults/classification/__init__.py +6 -0
  38. eva/core/metrics/defaults/classification/binary.py +76 -0
  39. eva/core/metrics/defaults/classification/multiclass.py +80 -0
  40. eva/core/metrics/structs/__init__.py +9 -0
  41. eva/core/metrics/structs/collection.py +6 -0
  42. eva/core/metrics/structs/metric.py +6 -0
  43. eva/core/metrics/structs/module.py +115 -0
  44. eva/core/metrics/structs/schemas.py +47 -0
  45. eva/core/metrics/structs/typings.py +15 -0
  46. eva/core/models/__init__.py +13 -0
  47. eva/core/models/modules/__init__.py +7 -0
  48. eva/core/models/modules/head.py +113 -0
  49. eva/core/models/modules/inference.py +37 -0
  50. eva/core/models/modules/module.py +190 -0
  51. eva/core/models/modules/typings.py +23 -0
  52. eva/core/models/modules/utils/__init__.py +6 -0
  53. eva/core/models/modules/utils/batch_postprocess.py +57 -0
  54. eva/core/models/modules/utils/grad.py +23 -0
  55. eva/core/models/networks/__init__.py +6 -0
  56. eva/core/models/networks/_utils.py +25 -0
  57. eva/core/models/networks/mlp.py +69 -0
  58. eva/core/models/networks/transforms/__init__.py +5 -0
  59. eva/core/models/networks/transforms/extract_cls_features.py +25 -0
  60. eva/core/models/networks/wrappers/__init__.py +8 -0
  61. eva/core/models/networks/wrappers/base.py +47 -0
  62. eva/core/models/networks/wrappers/from_function.py +58 -0
  63. eva/core/models/networks/wrappers/huggingface.py +37 -0
  64. eva/core/models/networks/wrappers/onnx.py +47 -0
  65. eva/core/trainers/__init__.py +6 -0
  66. eva/core/trainers/_logging.py +81 -0
  67. eva/core/trainers/_recorder.py +149 -0
  68. eva/core/trainers/_utils.py +12 -0
  69. eva/core/trainers/functional.py +113 -0
  70. eva/core/trainers/trainer.py +97 -0
  71. eva/core/utils/__init__.py +1 -0
  72. eva/core/utils/io/__init__.py +5 -0
  73. eva/core/utils/io/dataframe.py +21 -0
  74. eva/core/utils/multiprocessing.py +44 -0
  75. eva/core/utils/workers.py +21 -0
  76. eva/vision/__init__.py +14 -0
  77. eva/vision/data/__init__.py +5 -0
  78. eva/vision/data/datasets/__init__.py +22 -0
  79. eva/vision/data/datasets/_utils.py +50 -0
  80. eva/vision/data/datasets/_validators.py +44 -0
  81. eva/vision/data/datasets/classification/__init__.py +15 -0
  82. eva/vision/data/datasets/classification/bach.py +174 -0
  83. eva/vision/data/datasets/classification/base.py +103 -0
  84. eva/vision/data/datasets/classification/crc.py +176 -0
  85. eva/vision/data/datasets/classification/mhist.py +106 -0
  86. eva/vision/data/datasets/classification/patch_camelyon.py +203 -0
  87. eva/vision/data/datasets/classification/total_segmentator.py +212 -0
  88. eva/vision/data/datasets/segmentation/__init__.py +6 -0
  89. eva/vision/data/datasets/segmentation/base.py +112 -0
  90. eva/vision/data/datasets/segmentation/total_segmentator.py +212 -0
  91. eva/vision/data/datasets/structs.py +17 -0
  92. eva/vision/data/datasets/vision.py +43 -0
  93. eva/vision/data/transforms/__init__.py +5 -0
  94. eva/vision/data/transforms/common/__init__.py +5 -0
  95. eva/vision/data/transforms/common/resize_and_crop.py +44 -0
  96. eva/vision/models/__init__.py +5 -0
  97. eva/vision/models/networks/__init__.py +6 -0
  98. eva/vision/models/networks/abmil.py +176 -0
  99. eva/vision/models/networks/postprocesses/__init__.py +5 -0
  100. eva/vision/models/networks/postprocesses/cls.py +25 -0
  101. eva/vision/utils/__init__.py +5 -0
  102. eva/vision/utils/io/__init__.py +12 -0
  103. eva/vision/utils/io/_utils.py +29 -0
  104. eva/vision/utils/io/image.py +54 -0
  105. eva/vision/utils/io/nifti.py +50 -0
  106. eva/vision/utils/io/text.py +18 -0
  107. kaiko_eva-0.0.0.dev6.dist-info/METADATA +393 -0
  108. kaiko_eva-0.0.0.dev6.dist-info/RECORD +111 -0
  109. kaiko_eva-0.0.0.dev6.dist-info/WHEEL +4 -0
  110. kaiko_eva-0.0.0.dev6.dist-info/entry_points.txt +4 -0
  111. kaiko_eva-0.0.0.dev6.dist-info/licenses/LICENSE +201 -0
eva/core/trainers/_logging.py ADDED
@@ -0,0 +1,81 @@
+ """Helper functions and utilities for trainer logging."""
+
+ import hashlib
+ import sys
+ from datetime import datetime
+
+ from lightning_fabric.utilities import cloud_io
+ from loguru import logger
+
+
+ def generate_session_id() -> str:
+     """Generates and returns a unique string ID of an experiment.
+
+     The ID is composed of the run timestamp and its config hash. If the
+     configuration hash is an empty string, only the timestamp is used.
+     """
+     timestamp = _generate_timestamp_hash()
+     config_hash = _generate_config_hash()
+     return f"{timestamp}_{config_hash}" if config_hash else timestamp
+
+
+ def _generate_timestamp_hash() -> str:
+     """Generates a time-based hash id."""
+     timestamp = datetime.now()
+     return timestamp.strftime("%Y%m%d-%H%M%S%f")
+
+
+ def _generate_config_hash(max_hash_len: int = 8) -> str:
+     """Generates a hash id based on a yaml configuration file.
+
+     Args:
+         max_hash_len: The maximum length of the produced hash id.
+     """
+     config_path = _fetch_config_path()
+     if config_path is None:
+         logger.warning(
+             "No or multiple configuration files found in the command line arguments. "
+             "No configuration hash will be created for this experiment."
+         )
+         return ""
+
+     return _generate_hash_from_config(config_path, max_hash_len)
+
+
+ def _fetch_config_path() -> str | None:
+     """Retrieves the configuration path from the command line arguments.
+
+     It returns `None` if no or multiple configuration files are found
+     in the system arguments.
+
+     Returns:
+         The path to the configuration file.
+     """
+     inputs = sys.argv
+     config_paths = [inputs[i + 1] for i, arg in enumerate(inputs) if arg == "--config"]
+     if len(config_paths) != 1:
+         # TODO: combine the multiple configuration files
+         # and produce a hash for the merged one.
+         return None
+
+     return config_paths[0]
+
+
+ def _generate_hash_from_config(path: str, max_hash_len: int = 8) -> str:
+     """Returns a hash from the contents of the configuration file.
+
+     Args:
+         path: Path to the configuration file.
+         max_hash_len: Maximum length of the returned hash.
+
+     Returns:
+         Hash of the configuration file content.
+     """
+     fs = cloud_io.get_filesystem(path)
+     with fs.open(path, "r") as stream:
+         config = stream.read()
+     if isinstance(config, str):
+         config = config.encode("utf-8")
+     config_sha256 = hashlib.sha256(config)
+     hash_id = config_sha256.hexdigest()
+     return hash_id[:max_hash_len]
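
For orientation, here is a minimal standalone sketch of how such a config hash is derived, mirroring `_generate_hash_from_config` above; the yaml content is hypothetical:

```python
import hashlib

# Hash the raw bytes of a (hypothetical) yaml config, then truncate.
config = "trainer:\n  n_runs: 5\n".encode("utf-8")
config_hash = hashlib.sha256(config).hexdigest()[:8]
print(config_hash)  # deterministic 8-character hex prefix
```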
eva/core/trainers/_recorder.py ADDED
@@ -0,0 +1,149 @@
+ """Multi-run summary recorder."""
+
+ import collections
+ import json
+ import os
+ import statistics
+ import sys
+ from typing import Any, Dict, List, Mapping
+
+ from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT
+ from lightning_fabric.utilities import cloud_io
+ from loguru import logger
+ from omegaconf import OmegaConf
+ from toolz import dicttoolz
+
+ SESSION_METRICS = Mapping[str, List[float]]
+ """Session metrics type-hint."""
+
+
+ class SessionRecorder:
+     """Multi-run (session) summary logger."""
+
+     def __init__(
+         self,
+         output_dir: str,
+         results_file: str = "results.json",
+         config_file: str = "config.yaml",
+     ) -> None:
+         """Initializes the recorder.
+
+         Args:
+             output_dir: The destination folder to save the results.
+             results_file: The name of the results json file.
+             config_file: The name of the yaml configuration file.
+         """
+         self._output_dir = output_dir
+         self._results_file = results_file
+         self._config_file = config_file
+
+         self._validation_metrics: List[SESSION_METRICS] = []
+         self._test_metrics: List[SESSION_METRICS] = []
+
+     @property
+     def filename(self) -> str:
+         """Returns the output filename."""
+         return os.path.join(self._output_dir, self._results_file)
+
+     @property
+     def config_path(self) -> str | None:
+         """Returns the path to the .yaml configuration file from sys args, if available."""
+         if "--config" in sys.argv:
+             try:
+                 config_path = sys.argv[sys.argv.index("--config") + 1]
+                 if not config_path.endswith(".yaml"):
+                     logger.warning(f"Unexpected config file {config_path}, should be a .yaml file.")
+                 else:
+                     return config_path
+             except IndexError as e:
+                 logger.warning(f"Failed to fetch config_path from sys args: {e}")
+
+     def update(
+         self,
+         validation_scores: _EVALUATE_OUTPUT,
+         test_scores: _EVALUATE_OUTPUT | None = None,
+     ) -> None:
+         """Updates the state of the tracked metrics in-place."""
+         self._update_validation_metrics(validation_scores)
+         self._update_test_metrics(test_scores)
+
+     def compute(self) -> Dict[str, List[Dict[str, Any]]]:
+         """Computes and returns the session statistics."""
+         validation_statistics = list(map(_calculate_statistics, self._validation_metrics))
+         test_statistics = list(map(_calculate_statistics, self._test_metrics))
+         return {"val": validation_statistics, "test": test_statistics}
+
+     def export(self) -> Dict[str, Any]:
+         """Exports the results."""
+         statistics = self.compute()
+         return {"metrics": statistics}
+
+     def save(self) -> None:
+         """Saves the recorded results."""
+         results = self.export()
+         _save_json(results, self.filename)
+         self._save_config()
+
+     def reset(self) -> None:
+         """Resets the state of the tracked metrics."""
+         self._validation_metrics = []
+         self._test_metrics = []
+
+     def _update_validation_metrics(self, metrics: _EVALUATE_OUTPUT) -> None:
+         """Updates the validation metrics in-place."""
+         self._validation_metrics = _update_session_metrics(self._validation_metrics, metrics)
+
+     def _update_test_metrics(self, metrics: _EVALUATE_OUTPUT | None) -> None:
+         """Updates the test metrics in-place."""
+         if metrics:
+             self._test_metrics = _update_session_metrics(self._test_metrics, metrics)
+
+     def _save_config(self) -> None:
+         """Saves the config yaml with resolved env placeholders to the output directory."""
+         if self.config_path:
+             config = OmegaConf.load(self.config_path)
+             fs = cloud_io.get_filesystem(self._output_dir, anon=False)
+             with fs.open(os.path.join(self._output_dir, self._config_file), "w") as file:
+                 config_yaml = OmegaConf.to_yaml(config, resolve=True)
+                 file.write(config_yaml)
+
+
+ def _update_session_metrics(
+     session_metrics: List[SESSION_METRICS],
+     run_metrics: _EVALUATE_OUTPUT,
+ ) -> List[SESSION_METRICS]:
+     """Updates and returns the given session metrics with the new run metrics."""
+     session_metrics = session_metrics or _init_session_metrics(len(run_metrics))
+     for index, dataset_metrics in enumerate(run_metrics):
+         for name, value in dataset_metrics.items():
+             session_metrics[index][name].append(value)
+     return session_metrics
+
+
+ def _init_session_metrics(n_datasets: int) -> List[SESSION_METRICS]:
+     """Returns the initialized session metrics."""
+     return [collections.defaultdict(list) for _ in range(n_datasets)]
+
+
+ def _calculate_statistics(session_metrics: SESSION_METRICS) -> Dict[str, float | List[float]]:
+     """Calculates the metric statistics of a dataset session run."""
+
+     def _calculate_metric_statistics(values: List[float]) -> Dict[str, float | List[float]]:
+         """Calculates and returns the metric statistics."""
+         mean = statistics.mean(values)
+         stdev = statistics.stdev(values) if len(values) > 1 else 0
+         return {"mean": mean, "stdev": stdev, "values": values}
+
+     return dicttoolz.valmap(_calculate_metric_statistics, session_metrics)
+
+
+ def _save_json(data: Dict[str, Any], save_as: str = "data.json") -> None:
+     """Saves data to a json file."""
+     if not save_as.endswith(".json"):
+         raise ValueError(f"Filename must end with '.json', got '{save_as}'.")
+
+     output_dir = os.path.dirname(save_as)
+     fs = cloud_io.get_filesystem(output_dir, anon=False)
+     fs.makedirs(output_dir, exist_ok=True)
+     with fs.open(save_as, "w") as file:
+         json.dump(data, file, indent=4, sort_keys=True)
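
To make the aggregation concrete, here is a small self-contained sketch of the per-metric statistics computed by `_calculate_statistics`; the run values are hypothetical:

```python
import statistics
from collections import defaultdict

# Two runs over the same dataset, each yielding one metrics dict.
runs = [{"accuracy": 0.91, "loss": 0.30}, {"accuracy": 0.93, "loss": 0.27}]

session = defaultdict(list)
for run in runs:
    for name, value in run.items():
        session[name].append(value)  # as in `_update_session_metrics`

for name, values in session.items():
    stdev = statistics.stdev(values) if len(values) > 1 else 0
    print(name, {"mean": statistics.mean(values), "stdev": stdev, "values": values})
```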
eva/core/trainers/_utils.py ADDED
@@ -0,0 +1,12 @@
+ """Training related utilities."""
+
+ import copy
+ from typing import Any
+
+
+ def clone(*inputs: Any) -> Any:
+     """Deep copies the input objects and returns them."""
+     # A single input is returned as one copy; multiple inputs as a list of copies.
+     if len(inputs) == 1:
+         return copy.deepcopy(inputs[0])
+     return [copy.deepcopy(obj) for obj in inputs]
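
A quick usage sketch of the intended semantics (a single argument yields one copy, several yield a list); the `config` dict is hypothetical:

```python
from eva.core.trainers._utils import clone

config = {"lr": 1e-3}
single = clone(config)          # one deep copy
pair = clone(config, config)    # a list of two independent copies
assert single is not config
assert pair[0] is not pair[1]
```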
eva/core/trainers/functional.py ADDED
@@ -0,0 +1,113 @@
+ """Fit session related functions."""
+
+ from typing import Any, Tuple
+
+ from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT
+
+ from eva.core.data import datamodules
+ from eva.core.models import modules
+ from eva.core.trainers import _recorder, _utils
+ from eva.core.trainers import trainer as eva_trainer
+
+
+ def run_evaluation_session(
+     base_trainer: eva_trainer.Trainer,
+     base_model: modules.ModelModule,
+     datamodule: datamodules.DataModule,
+     *,
+     n_runs: int = 1,
+ ) -> None:
+     """Runs a downstream evaluation session out-of-place.
+
+     It performs an evaluation run (fit and evaluate) on the model
+     multiple times. Note that the input `base_trainer` and
+     `base_model` are cloned, so the input objects are not modified.
+
+     Args:
+         base_trainer: The base trainer module to use.
+         base_model: The base model module to use.
+         datamodule: The data module.
+         n_runs: The number of runs (fit and evaluate) to perform.
+     """
+     recorder = _recorder.SessionRecorder(output_dir=base_trainer.default_log_dir)
+     for run_index in range(n_runs):
+         validation_scores, test_scores = run_evaluation(
+             base_trainer, base_model, datamodule, run_id=f"run_{run_index}"
+         )
+         recorder.update(validation_scores, test_scores)
+     recorder.save()
+
+
+ def run_evaluation(
+     base_trainer: eva_trainer.Trainer,
+     base_model: modules.ModelModule,
+     datamodule: datamodules.DataModule,
+     *,
+     run_id: str | None = None,
+ ) -> Tuple[_EVALUATE_OUTPUT, _EVALUATE_OUTPUT | None]:
+     """Fits and evaluates a model out-of-place.
+
+     Args:
+         base_trainer: The base trainer to use but not modify.
+         base_model: The model module to use but not modify.
+         datamodule: The data module.
+         run_id: The run id to be appended to the output log directory.
+             If `None`, the log directory of the trainer is used as is.
+
+     Returns:
+         A tuple with the validation and the test metrics (if they exist).
+     """
+     trainer, model = _utils.clone(base_trainer, base_model)
+     trainer.setup_log_dirs(run_id or "")
+     return fit_and_validate(trainer, model, datamodule)
+
+
+ def fit_and_validate(
+     trainer: eva_trainer.Trainer,
+     model: modules.ModelModule,
+     datamodule: datamodules.DataModule,
+ ) -> Tuple[_EVALUATE_OUTPUT, _EVALUATE_OUTPUT | None]:
+     """Fits and evaluates a model in-place.
+
+     If a test set is configured in the datamodule, the model is
+     evaluated on it as well.
+
+     Args:
+         trainer: The trainer module to use and update in-place.
+         model: The model module to use and update in-place.
+         datamodule: The data module.
+
+     Returns:
+         A tuple with the validation and the test metrics (if they exist).
+     """
+     trainer.fit(model, datamodule=datamodule)
+     validation_scores = trainer.validate(datamodule=datamodule)
+     test_scores = None if datamodule.datasets.test is None else trainer.test(datamodule=datamodule)
+     return validation_scores, test_scores
+
+
+ def infer_model(
+     base_trainer: eva_trainer.Trainer,
+     base_model: modules.ModelModule,
+     datamodule: datamodules.DataModule,
+     *,
+     return_predictions: bool = False,
+ ) -> Any:
+     """Performs model inference out-of-place.
+
+     Note that the input `base_model` and `base_trainer` are
+     not modified.
+
+     Args:
+         base_trainer: The base trainer to use but not modify.
+         base_model: The model module to use but not modify.
+         datamodule: The data module.
+         return_predictions: Whether to return the model predictions.
+     """
+     trainer, model = _utils.clone(base_trainer, base_model)
+     return trainer.predict(
+         model=model,
+         datamodule=datamodule,
+         return_predictions=return_predictions,
+     )
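
A sketch of how these functions compose; `model` and `datamodule` stand in for an eva `ModelModule` and `DataModule` built elsewhere (e.g. from a yaml config), so this is illustrative rather than runnable as-is:

```python
from eva.core.trainers import functional
from eva.core.trainers.trainer import Trainer

base_trainer = Trainer(max_epochs=10, n_runs=5)
functional.run_evaluation_session(
    base_trainer=base_trainer,
    base_model=model,        # hypothetical ModelModule
    datamodule=datamodule,   # hypothetical DataModule
    n_runs=5,
)
# Each run is cloned from the base trainer/model, logged under run_<i>,
# and the aggregated metrics are saved to <default_log_dir>/results.json.
```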
eva/core/trainers/trainer.py ADDED
@@ -0,0 +1,97 @@
+ """Core trainer module."""
+
+ import os
+ from typing import Any
+
+ from lightning.pytorch import loggers as pl_loggers
+ from lightning.pytorch import trainer as pl_trainer
+ from lightning.pytorch.utilities import argparse
+ from typing_extensions import override
+
+ from eva.core.data import datamodules
+ from eva.core.models import modules
+ from eva.core.trainers import _logging, functional
+
+
+ class Trainer(pl_trainer.Trainer):
+     """Core trainer class.
+
+     This is an extended version of lightning's core trainer class.
+     """
+
+     @argparse._defaults_from_env_vars
+     def __init__(
+         self,
+         *args: Any,
+         default_root_dir: str = "logs",
+         n_runs: int = 1,
+         **kwargs: Any,
+     ) -> None:
+         """Initializes the trainer.
+
+         For the input arguments, refer to :class:`lightning.pytorch.Trainer`.
+
+         Args:
+             args: Positional arguments of :class:`lightning.pytorch.Trainer`.
+             default_root_dir: The default root directory to store the output logs.
+                 Unlike in :class:`lightning.pytorch.Trainer`, this path takes
+                 priority as the output destination.
+             n_runs: The number of runs (fit and evaluate) to perform in an evaluation session.
+             kwargs: Keyword arguments of :class:`lightning.pytorch.Trainer`.
+         """
+         super().__init__(*args, default_root_dir=default_root_dir, **kwargs)
+
+         self._n_runs = n_runs
+
+         self._session_id: str = _logging.generate_session_id()
+         self._log_dir: str = self.default_log_dir
+
+         self.setup_log_dirs()
+
+     @property
+     def default_log_dir(self) -> str:
+         """Returns the default log directory."""
+         return os.path.join(self.default_root_dir, self._session_id)
+
+     @property
+     @override
+     def log_dir(self) -> str | None:
+         return self.strategy.broadcast(self._log_dir)
+
+     def setup_log_dirs(self, subdirectory: str = "") -> None:
+         """Sets up the logging directory of the trainer and experimental loggers in-place.
+
+         Args:
+             subdirectory: The subdirectory to append to the output log directory.
+         """
+         self._log_dir = os.path.join(self.default_root_dir, self._session_id, subdirectory)
+         os.fspath(self._log_dir)
+
+         for logger in self.loggers:
+             if isinstance(logger, (pl_loggers.CSVLogger, pl_loggers.TensorBoardLogger)):
+                 logger._root_dir = self.default_root_dir
+                 logger._name = self._session_id
+                 logger._version = subdirectory
+
+     def run_evaluation_session(
+         self,
+         model: modules.ModelModule,
+         datamodule: datamodules.DataModule,
+     ) -> None:
+         """Runs an evaluation session out-of-place.
+
+         It performs an evaluation run (fit and evaluate) on the model
+         `self._n_runs` times. Note that the input `model` is not
+         modified, so the weights of the input model remain as they are.
+
+         Args:
+             model: The base model module to evaluate.
+             datamodule: The data module.
+         """
+         functional.run_evaluation_session(
+             base_trainer=self,
+             base_model=model,
+             datamodule=datamodule,
+             n_runs=self._n_runs,
+         )
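
Based on the definitions above, a short sketch of the resulting directory layout; the session id shown is hypothetical:

```python
from eva.core.trainers.trainer import Trainer

trainer = Trainer(default_root_dir="logs", n_runs=3)
# default_log_dir == "logs/<session-id>", where the session id combines the
# run timestamp and config hash, e.g. "logs/20240101-120000123456_ab12cd34".
print(trainer.default_log_dir)

# Appending a run subdirectory, as done per run by `run_evaluation`:
trainer.setup_log_dirs("run_0")  # -> "logs/<session-id>/run_0"
```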
eva/core/utils/__init__.py ADDED
@@ -0,0 +1 @@
+ """Utilities and library level helper functionalities."""
eva/core/utils/io/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """Core I/O utilities."""
+
+ from eva.core.utils.io.dataframe import read_dataframe
+
+ __all__ = ["read_dataframe"]
eva/core/utils/io/dataframe.py ADDED
@@ -0,0 +1,21 @@
+ """DataFrame related I/O operations."""
+
+ import pandas as pd
+
+
+ def read_dataframe(path: str) -> pd.DataFrame:
+     """Reads and loads a DataFrame file.
+
+     Args:
+         path: The path to the manifest file.
+
+     Returns:
+         The data of the file as a `DataFrame`.
+     """
+     if path.endswith(".csv"):
+         data = pd.read_csv(path)
+     elif path.endswith(".parquet"):
+         data = pd.read_parquet(path)
+     else:
+         raise ValueError(f"Unsupported manifest file format at '{path}'. Expected '.csv' or '.parquet'.")
+     return data
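
A brief usage sketch, writing a hypothetical csv manifest and reading it back:

```python
import pandas as pd
from eva.core.utils.io import read_dataframe

# Create a toy manifest file for illustration.
pd.DataFrame({"path": ["img_0.png"], "target": [0]}).to_csv("manifest.csv", index=False)
df = read_dataframe("manifest.csv")  # ".parquet" files are handled the same way
print(df)
```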
eva/core/utils/multiprocessing.py ADDED
@@ -0,0 +1,44 @@
+ """Multiprocessing utilities."""
+
+ import multiprocessing
+ import sys
+ import traceback
+ from typing import Any
+
+
+ class Process(multiprocessing.Process):
+     """Multiprocessing wrapper with logic to propagate exceptions to the parent process.
+
+     Source: https://stackoverflow.com/a/33599967/4992248
+     """
+
+     def __init__(self, *args: Any, **kwargs: Any) -> None:
+         """Initializes the process."""
+         multiprocessing.Process.__init__(self, *args, **kwargs)
+
+         self._parent_conn, self._child_conn = multiprocessing.Pipe()
+         self._exception = None
+
+     def run(self) -> None:
+         """Runs the process."""
+         try:
+             multiprocessing.Process.run(self)
+             self._child_conn.send(None)
+         except Exception as e:
+             tb = traceback.format_exc()
+             self._child_conn.send((e, tb))
+
+     @property
+     def exception(self):
+         """Property that contains exception information from the process."""
+         if self._parent_conn.poll():
+             self._exception = self._parent_conn.recv()
+         return self._exception
+
+     def check_exceptions(self) -> None:
+         """Checks for an exception and, if found, re-raises it in the parent process."""
+         if not self.is_alive():
+             if self.exception:
+                 error, tb = self.exception
+                 sys.stderr.write(tb + "\n")
+                 raise error
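
A minimal sketch of the exception propagation; the failing target function is hypothetical:

```python
from eva.core.utils.multiprocessing import Process


def _fail() -> None:
    raise RuntimeError("boom")


if __name__ == "__main__":
    process = Process(target=_fail)
    process.start()
    process.join()
    process.check_exceptions()  # re-raises RuntimeError("boom") in the parent
```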
eva/core/utils/workers.py ADDED
@@ -0,0 +1,21 @@
+ """Processing workers utilities and helper functions."""
+
+ import multiprocessing
+ from typing import Any, Callable
+
+
+ def main_worker_only(func: Callable) -> Callable:
+     """Function decorator which executes the wrapped function only on the main process / worker."""
+
+     def wrapper(*args: Any, **kwargs: Any) -> Any:
+         """Wrapper function for the decorated method."""
+         if is_main_worker():
+             return func(*args, **kwargs)
+
+     return wrapper
+
+
+ def is_main_worker() -> bool:
+     """Returns whether the main process / worker is currently used."""
+     process = multiprocessing.current_process()
+     return process.name == "MainProcess"
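
And a short sketch of the decorator in use; `log_once` is a hypothetical helper:

```python
from eva.core.utils.workers import main_worker_only


@main_worker_only
def log_once(message: str) -> None:
    print(message)


log_once("hello")  # prints in the main process; a no-op inside DataLoader workers
```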
eva/vision/__init__.py ADDED
@@ -0,0 +1,14 @@
+ """eva vision API."""
+
+ try:
+     from eva.vision import models, utils
+     from eva.vision.data import datasets, transforms
+ except ImportError as e:
+     msg = (
+         "eva vision requirements are not installed.\n\n"
+         "Please pip install as follows:\n"
+         '  python -m pip install "eva[vision]" --upgrade'
+     )
+     raise ImportError(str(e) + "\n\n" + msg) from e
+
+ __all__ = ["models", "utils", "datasets", "transforms"]
eva/vision/data/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """Vision data API."""
+
+ from eva.vision.data import datasets, transforms
+
+ __all__ = ["datasets", "transforms"]
eva/vision/data/datasets/__init__.py ADDED
@@ -0,0 +1,22 @@
+ """Vision Datasets API."""
+
+ from eva.vision.data.datasets.classification import (
+     BACH,
+     CRC,
+     MHIST,
+     PatchCamelyon,
+     TotalSegmentatorClassification,
+ )
+ from eva.vision.data.datasets.segmentation import ImageSegmentation, TotalSegmentator2D
+ from eva.vision.data.datasets.vision import VisionDataset
+
+ __all__ = [
+     "BACH",
+     "CRC",
+     "MHIST",
+     "ImageSegmentation",
+     "PatchCamelyon",
+     "TotalSegmentatorClassification",
+     "TotalSegmentator2D",
+     "VisionDataset",
+ ]
eva/vision/data/datasets/_utils.py ADDED
@@ -0,0 +1,50 @@
+ """Dataset related functions and helpers."""
+
+ from typing import List, Tuple
+
+
+ def indices_to_ranges(indices: List[int]) -> List[Tuple[int, int]]:
+     """Turns a list of indices into a list of ranges.
+
+     The produced range intervals are half-open inequalities: start <= x < end.
+
+     Args:
+         indices: The list of indices to produce the ranges from.
+
+     Returns:
+         A list of half-open intervals.
+
+     Example:
+         >>> indices = [0, 1, 2, 4, 6, 7, 8]
+         >>> ranges = indices_to_ranges(indices)
+         >>> assert ranges == [(0, 3), (4, 5), (6, 9)]
+     """
+     ranges = []
+     start_index = 0
+     for i, current in enumerate(indices):
+         if i + 1 < len(indices) and current + 1 == indices[i + 1]:
+             continue
+
+         start = indices[start_index]
+         end = start if start_index == i else current
+         ranges.append((start, end + 1))
+         start_index = i + 1
+
+     return ranges
+
+
+ def ranges_to_indices(ranges: List[Tuple[int, int]]) -> List[int]:
+     """Unpacks a list of ranges to individual indices.
+
+     Args:
+         ranges: The list of ranges to produce the indices from.
+
+     Returns:
+         A list of the indices.
+
+     Example:
+         >>> ranges = [(0, 3), (4, 5), (6, 9)]
+         >>> indices = ranges_to_indices(ranges)
+         >>> assert indices == [0, 1, 2, 4, 6, 7, 8]
+     """
+     return [index for start, end in ranges for index in range(start, end)]