kaiko-eva 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (41)
  1. eva/.DS_Store +0 -0
  2. eva/core/callbacks/__init__.py +2 -1
  3. eva/core/callbacks/config.py +143 -0
  4. eva/core/data/datasets/__init__.py +10 -2
  5. eva/core/data/datasets/embeddings/__init__.py +13 -0
  6. eva/core/data/datasets/{classification/embeddings.py → embeddings/base.py} +41 -43
  7. eva/core/data/datasets/embeddings/classification/__init__.py +10 -0
  8. eva/core/data/datasets/embeddings/classification/embeddings.py +66 -0
  9. eva/core/data/datasets/embeddings/classification/multi_embeddings.py +106 -0
  10. eva/core/data/transforms/__init__.py +3 -1
  11. eva/core/data/transforms/padding/__init__.py +5 -0
  12. eva/core/data/transforms/padding/pad_2d_tensor.py +38 -0
  13. eva/core/data/transforms/sampling/__init__.py +5 -0
  14. eva/core/data/transforms/sampling/sample_from_axis.py +40 -0
  15. eva/core/loggers/__init__.py +7 -0
  16. eva/core/loggers/dummy.py +38 -0
  17. eva/core/loggers/experimental_loggers.py +8 -0
  18. eva/core/loggers/log/__init__.py +5 -0
  19. eva/core/loggers/log/parameters.py +64 -0
  20. eva/core/loggers/log/utils.py +13 -0
  21. eva/core/models/modules/head.py +6 -11
  22. eva/core/models/modules/module.py +25 -1
  23. eva/core/trainers/_recorder.py +69 -7
  24. eva/core/trainers/functional.py +22 -5
  25. eva/core/trainers/trainer.py +20 -6
  26. eva/vision/data/datasets/__init__.py +1 -8
  27. eva/vision/data/datasets/_utils.py +3 -3
  28. eva/vision/data/datasets/classification/__init__.py +1 -8
  29. eva/vision/data/datasets/segmentation/base.py +20 -35
  30. eva/vision/data/datasets/segmentation/total_segmentator.py +88 -69
  31. eva/vision/models/.DS_Store +0 -0
  32. eva/vision/models/networks/.DS_Store +0 -0
  33. eva/vision/utils/convert.py +24 -0
  34. eva/vision/utils/io/nifti.py +10 -6
  35. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.0.2.dist-info}/METADATA +51 -25
  36. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.0.2.dist-info}/RECORD +39 -22
  37. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.0.2.dist-info}/WHEEL +1 -1
  38. eva/core/data/datasets/classification/__init__.py +0 -5
  39. eva/vision/data/datasets/classification/total_segmentator.py +0 -213
  40. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.0.2.dist-info}/entry_points.txt +0 -0
  41. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.0.2.dist-info}/licenses/LICENSE +0 -0
eva/core/loggers/log/parameters.py
@@ -0,0 +1,64 @@
+"""Text log functionality."""
+
+import functools
+from typing import Any, Dict
+
+import yaml
+
+from eva.core.loggers import experimental_loggers as loggers_lib
+from eva.core.loggers.log import utils
+
+
+@functools.singledispatch
+def log_parameters(
+    logger,
+    tag: str,
+    parameters: Dict[str, Any],
+) -> None:
+    """Adds parameters to the logger.
+
+    Args:
+        logger: The desired logger.
+        tag: The log tag.
+        parameters: The parameters to log.
+    """
+    utils.raise_not_supported(logger, "parameters")
+
+
+@log_parameters.register
+def _(
+    loggers: list,
+    tag: str,
+    parameters: Dict[str, Any],
+) -> None:
+    """Adds parameters to a list of supported loggers."""
+    for logger in loggers:
+        log_parameters(logger, tag=tag, parameters=parameters)
+
+
+@log_parameters.register
+def _(
+    logger: loggers_lib.TensorBoardLogger,
+    tag: str,
+    parameters: Dict[str, Any],
+) -> None:
+    """Adds parameters to a TensorBoard logger."""
+    as_markdown_text = _yaml_to_markdown(parameters)
+    logger.experiment.add_text(
+        tag=tag,
+        text_string=as_markdown_text,
+        global_step=0,
+    )
+
+
+def _yaml_to_markdown(data: Dict[str, Any]) -> str:
+    """Casts yaml data to markdown.
+
+    Args:
+        data: The yaml data.
+
+    Returns:
+        A markdown-friendly formatted string.
+    """
+    text = yaml.dump(data, sort_keys=False)
+    return f"```yaml\n{text}```"
eva/core/loggers/log/utils.py
@@ -0,0 +1,13 @@
+"""Logging related utilities."""
+
+from loguru import logger as cli_logger
+
+from eva.core.loggers import ExperimentalLoggers
+
+
+def raise_not_supported(logger: ExperimentalLoggers, data_type: str) -> None:
+    """Raises a warning for not supported tasks from the given logger."""
+    print("\n")
+    cli_logger.debug(
+        f"Logger '{logger.__class__.__name__}' is not supported for " f"'{data_type}' data."
+    )
eva/core/models/modules/head.py
@@ -54,9 +54,14 @@ class HeadModule(module.ModelModule):
         self.optimizer = optimizer
         self.lr_scheduler = lr_scheduler
 
+    @override
+    def configure_model(self) -> Any:
+        if self.backbone is not None:
+            grad.deactivate_requires_grad(self.backbone)
+
     @override
     def configure_optimizers(self) -> Any:
-        parameters = list(self.head.parameters())
+        parameters = self.head.parameters()
         optimizer = self.optimizer(parameters)
         lr_scheduler = self.lr_scheduler(optimizer)
         return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
@@ -66,11 +71,6 @@ class HeadModule(module.ModelModule):
         features = tensor if self.backbone is None else self.backbone(tensor)
         return self.head(features).squeeze(-1)
 
-    @override
-    def on_fit_start(self) -> None:
-        if self.backbone is not None:
-            grad.deactivate_requires_grad(self.backbone)
-
     @override
     def training_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         return self._batch_step(batch)
@@ -88,11 +88,6 @@ class HeadModule(module.ModelModule):
         tensor = INPUT_BATCH(*batch).data
         return tensor if self.backbone is None else self.backbone(tensor)
 
-    @override
-    def on_fit_end(self) -> None:
-        if self.backbone is not None:
-            grad.activate_requires_grad(self.backbone)
-
     def _batch_step(self, batch: INPUT_BATCH) -> STEP_OUTPUT:
         """Performs a model forward step and calculates the loss.
 
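With this change the backbone is frozen in `configure_model`, which Lightning invokes before fitting begins, and is no longer re-enabled in `on_fit_end`, so it stays frozen for the module's whole lifetime. The `list(...)` around `self.head.parameters()` was also redundant, since optimizers accept parameter iterators directly. A plain-PyTorch sketch of what the freezing utility presumably does (the real `grad.deactivate_requires_grad` helper is not shown in this diff):

```python
import torch.nn as nn


def deactivate_requires_grad(module: nn.Module) -> None:
    """Excludes every parameter of `module` from gradient computation."""
    for parameter in module.parameters():
        parameter.requires_grad = False
```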
eva/core/models/modules/module.py
@@ -4,6 +4,7 @@ from typing import Any, Mapping
 
 import lightning.pytorch as pl
 import torch
+from lightning.pytorch.strategies.single_device import SingleDeviceStrategy
 from lightning.pytorch.utilities import memory
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 from typing_extensions import override
@@ -46,6 +47,21 @@ class ModelModule(pl.LightningModule):
         """The default post-processes."""
         return batch_postprocess.BatchPostProcess()
 
+    @property
+    def metrics_device(self) -> torch.device:
+        """Returns the device on which the metrics should be calculated.
+
+        We allocate the metrics to CPU when operating on a single device, as
+        it is much faster, but to GPU when employing multiple ones, as the DDP
+        strategy requires the metrics to be allocated to the module's GPU.
+        """
+        move_to_cpu = isinstance(self.trainer.strategy, SingleDeviceStrategy)
+        return torch.device("cpu") if move_to_cpu else self.device
+
+    @override
+    def on_fit_start(self) -> None:
+        self.metrics.to(device=self.metrics_device)
+
     @override
     def on_train_batch_end(
         self,
@@ -59,6 +75,10 @@ class ModelModule(pl.LightningModule):
             batch_outputs=outputs,
         )
 
+    @override
+    def on_validation_start(self) -> None:
+        self.metrics.to(device=self.metrics_device)
+
     @override
     def on_validation_batch_end(
         self,
@@ -78,6 +98,10 @@ class ModelModule(pl.LightningModule):
     def on_validation_epoch_end(self) -> None:
         self._compute_and_log_metrics(self.metrics.validation_metrics)
 
+    @override
+    def on_test_start(self) -> None:
+        self.metrics.to(device=self.metrics_device)
+
     @override
     def on_test_batch_end(
         self,
@@ -110,7 +134,7 @@ class ModelModule(pl.LightningModule):
             The updated outputs.
         """
         self._postprocess(outputs)
-        return memory.recursive_detach(outputs, to_cpu=self.device.type == "cpu")
+        return memory.recursive_detach(outputs, to_cpu=self.metrics_device.type == "cpu")
 
     def _forward_and_log_metrics(
         self,
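A standalone restatement of the device rule introduced above, illustrative only (not eva code):

```python
import torch
from lightning.pytorch.strategies.single_device import SingleDeviceStrategy


def resolve_metrics_device(strategy, module_device: torch.device) -> torch.device:
    # Metrics live on CPU for single-device runs (faster updates), but must
    # stay on the module's device under DDP so distributed reduction works.
    if isinstance(strategy, SingleDeviceStrategy):
        return torch.device("cpu")
    return module_device
```

The matching `recursive_detach` change means batch outputs are now moved to CPU exactly when the metrics themselves live there, keeping tensors and metrics on the same device.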
eva/core/trainers/_recorder.py
@@ -5,18 +5,41 @@ import json
 import os
 import statistics
 import sys
-from typing import Any, Dict, List, Mapping
+from typing import Dict, List, Mapping, TypedDict
 
 from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT
 from lightning_fabric.utilities import cloud_io
 from loguru import logger
 from omegaconf import OmegaConf
+from rich import console as rich_console
+from rich import table as rich_table
 from toolz import dicttoolz
 
 SESSION_METRICS = Mapping[str, List[float]]
 """Session metrics type-hint."""
 
 
+class SESSION_STATISTICS(TypedDict):
+    """Type-hint for aggregated metrics of multiple runs with mean & stdev."""
+
+    mean: float
+    stdev: float
+    values: List[float]
+
+
+class STAGE_RESULTS(TypedDict):
+    """Type-hint for metrics statistics for val & test stages."""
+
+    val: List[Dict[str, SESSION_STATISTICS]]
+    test: List[Dict[str, SESSION_STATISTICS]]
+
+
+class RESULTS_DICT(TypedDict):
+    """Type-hint for the final results dictionary."""
+
+    metrics: STAGE_RESULTS
+
+
 class SessionRecorder:
     """Multi-run (session) summary logger."""
 
@@ -25,6 +48,7 @@ class SessionRecorder:
         output_dir: str,
         results_file: str = "results.json",
         config_file: str = "config.yaml",
+        verbose: bool = True,
     ) -> None:
         """Initializes the recorder.
 
@@ -32,10 +56,12 @@ class SessionRecorder:
             output_dir: The destination folder to save the results.
             results_file: The name of the results json file.
             config_file: The name of the yaml configuration file.
+            verbose: Whether to print the session metrics.
         """
         self._output_dir = output_dir
         self._results_file = results_file
         self._config_file = config_file
+        self._verbose = verbose
 
         self._validation_metrics: List[SESSION_METRICS] = []
         self._test_metrics: List[SESSION_METRICS] = []
@@ -67,13 +93,13 @@ class SessionRecorder:
         self._update_validation_metrics(validation_scores)
         self._update_test_metrics(test_scores)
 
-    def compute(self) -> Dict[str, List[Dict[str, Any]]]:
+    def compute(self) -> STAGE_RESULTS:
         """Computes and returns the session statistics."""
         validation_statistics = list(map(_calculate_statistics, self._validation_metrics))
         test_statistics = list(map(_calculate_statistics, self._test_metrics))
         return {"val": validation_statistics, "test": test_statistics}
 
-    def export(self) -> Dict[str, Any]:
+    def export(self) -> RESULTS_DICT:
         """Exports the results."""
         statistics = self.compute()
         return {"metrics": statistics}
@@ -83,6 +109,8 @@ class SessionRecorder:
         results = self.export()
         _save_json(results, self.filename)
         self._save_config()
+        if self._verbose:
+            _print_results(results)
 
     def reset(self) -> None:
         """Resets the state of the tracked metrics."""
@@ -125,10 +153,10 @@ def _init_session_metrics(n_datasets: int) -> List[SESSION_METRICS]:
     return [collections.defaultdict(list) for _ in range(n_datasets)]
 
 
-def _calculate_statistics(session_metrics: SESSION_METRICS) -> Dict[str, float | List[float]]:
+def _calculate_statistics(session_metrics: SESSION_METRICS) -> Dict[str, SESSION_STATISTICS]:
     """Calculate the metric statistics of a dataset session run."""
 
-    def _calculate_metric_statistics(values: List[float]) -> Dict[str, float | List[float]]:
+    def _calculate_metric_statistics(values: List[float]) -> SESSION_STATISTICS:
         """Calculates and returns the metric statistics."""
         mean = statistics.mean(values)
         stdev = statistics.stdev(values) if len(values) > 1 else 0
@@ -137,7 +165,7 @@ def _calculate_statistics(session_metrics: SESSION_METRICS) -> Dict[str, float |
     return dicttoolz.valmap(_calculate_metric_statistics, session_metrics)
 
 
-def _save_json(data: Dict[str, Any], save_as: str = "data.json"):
+def _save_json(data: RESULTS_DICT, save_as: str = "data.json"):
     """Saves data to a json file."""
     if not save_as.endswith(".json"):
         raise ValueError()
@@ -146,4 +174,38 @@ def _save_json(data: Dict[str, Any], save_as: str = "data.json"):
     fs = cloud_io.get_filesystem(output_dir, anon=False)
     fs.makedirs(output_dir, exist_ok=True)
     with fs.open(save_as, "w") as file:
-        json.dump(data, file, indent=4, sort_keys=True)
+        json.dump(data, file, indent=2, sort_keys=True)
+
+
+def _print_results(results: RESULTS_DICT) -> None:
+    """Prints the results to the console."""
+    try:
+        for stage in ["val", "test"]:
+            for dataset_idx in range(len(results["metrics"][stage])):
+                _print_table(results["metrics"][stage][dataset_idx], stage, dataset_idx)
+    except Exception as e:
+        logger.error(f"Failed to print the results: {e}")
+
+
+def _print_table(metrics_dict: Dict[str, SESSION_STATISTICS], stage: str, dataset_idx: int):
+    """Prints the metrics of a single dataset as a table."""
+    metrics_table = rich_table.Table(
+        title=f"\n{stage.capitalize()} Dataset {dataset_idx}", title_style="bold"
+    )
+    metrics_table.add_column("Metric", style="cyan")
+    metrics_table.add_column("Mean", style="magenta")
+    metrics_table.add_column("Stdev", style="magenta")
+    metrics_table.add_column("All", style="magenta")
+
+    n_runs = len(metrics_dict[next(iter(metrics_dict))]["values"])
+    for metric_name, metric_dict in metrics_dict.items():
+        row = [
+            metric_name,
+            f'{metric_dict["mean"]:.3f}',
+            f'{metric_dict["stdev"]:.3f}',
+            ", ".join(f'{metric_dict["values"][i]:.3f}' for i in range(n_runs)),
+        ]
+        metrics_table.add_row(*row)
+
+    console = rich_console.Console()
+    console.print(metrics_table)
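Given the TypedDicts above, the exported `results.json` now has a fixed shape; the metric name and numbers below are made up for illustration:

```python
# Hypothetical contents of results.json for a 2-run session with one
# validation and one test dataset.
results = {
    "metrics": {
        "val": [
            {"MulticlassAccuracy": {"mean": 0.912, "stdev": 0.008, "values": [0.906, 0.918]}}
        ],
        "test": [
            {"MulticlassAccuracy": {"mean": 0.894, "stdev": 0.012, "values": [0.885, 0.902]}}
        ],
    }
}
```

When `verbose=True`, `_print_results` additionally renders one rich table per stage and dataset, with the Metric/Mean/Stdev/All column layout defined in `_print_table`.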
eva/core/trainers/functional.py
@@ -16,6 +16,7 @@ def run_evaluation_session(
     datamodule: datamodules.DataModule,
     *,
     n_runs: int = 1,
+    verbose: bool = True,
 ) -> None:
     """Runs a downstream evaluation session out-of-place.
 
@@ -29,11 +30,17 @@ def run_evaluation_session(
         base_model: The base model module to use.
         datamodule: The data module.
         n_runs: The amount of runs (fit and evaluate) to perform.
+        verbose: Whether to print the session metrics instead of
+            those of each individual run, and vice versa.
     """
-    recorder = _recorder.SessionRecorder(output_dir=base_trainer.default_log_dir)
+    recorder = _recorder.SessionRecorder(output_dir=base_trainer.default_log_dir, verbose=verbose)
     for run_index in range(n_runs):
         validation_scores, test_scores = run_evaluation(
-            base_trainer, base_model, datamodule, run_id=f"run_{run_index}"
+            base_trainer,
+            base_model,
+            datamodule,
+            run_id=f"run_{run_index}",
+            verbose=not verbose,
         )
         recorder.update(validation_scores, test_scores)
     recorder.save()
@@ -45,6 +52,7 @@ def run_evaluation(
     datamodule: datamodules.DataModule,
     *,
     run_id: str | None = None,
+    verbose: bool = True,
 ) -> Tuple[_EVALUATE_OUTPUT, _EVALUATE_OUTPUT | None]:
     """Fits and evaluates a model out-of-place.
 
@@ -54,19 +62,22 @@ def run_evaluation(
         datamodule: The data module.
         run_id: The run id to be appended to the output log directory.
             If `None`, it will use the log directory of the trainer as is.
+        verbose: Whether to print the validation and test metrics
+            at the end of training.
 
     Returns:
         A tuple with the validation and the test metrics (if present).
     """
     trainer, model = _utils.clone(base_trainer, base_model)
     trainer.setup_log_dirs(run_id or "")
-    return fit_and_validate(trainer, model, datamodule)
+    return fit_and_validate(trainer, model, datamodule, verbose=verbose)
 
 
 def fit_and_validate(
     trainer: eva_trainer.Trainer,
     model: modules.ModelModule,
     datamodule: datamodules.DataModule,
+    verbose: bool = True,
 ) -> Tuple[_EVALUATE_OUTPUT, _EVALUATE_OUTPUT | None]:
     """Fits and evaluates a model in-place.
 
@@ -77,13 +88,19 @@ def fit_and_validate(
         trainer: The trainer module to use and update in-place.
         model: The model module to use and update in-place.
         datamodule: The data module.
+        verbose: Whether to print the validation and test metrics
+            at the end of training.
 
     Returns:
         A tuple with the validation and the test metrics (if present).
     """
     trainer.fit(model, datamodule=datamodule)
-    validation_scores = trainer.validate(datamodule=datamodule)
-    test_scores = None if datamodule.datasets.test is None else trainer.test(datamodule=datamodule)
+    validation_scores = trainer.validate(datamodule=datamodule, verbose=verbose)
+    test_scores = (
+        None
+        if datamodule.datasets.test is None
+        else trainer.test(datamodule=datamodule, verbose=verbose)
+    )
     return validation_scores, test_scores
 
 
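A minimal call sketch; `trainer`, `model`, and `datamodule` are placeholders for fully configured eva instances:

```python
from eva.core.trainers import functional

# With verbose=True, only the aggregated session metrics are printed; the
# per-run validate/test tables are suppressed (each run gets verbose=not verbose).
functional.run_evaluation_session(
    trainer,      # eva.core.trainers.Trainer
    model,        # eva.core.models.modules.ModelModule
    datamodule,   # eva.core.data.datamodules.DataModule
    n_runs=5,
    verbose=True,
)
```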
eva/core/trainers/trainer.py
@@ -3,11 +3,14 @@
 import os
 from typing import Any
 
+import loguru
 from lightning.pytorch import loggers as pl_loggers
 from lightning.pytorch import trainer as pl_trainer
 from lightning.pytorch.utilities import argparse
+from lightning_fabric.utilities import cloud_io
 from typing_extensions import override
 
+from eva.core import loggers as eva_loggers
 from eva.core.data import datamodules
 from eva.core.models import modules
 from eva.core.trainers import _logging, functional
@@ -65,13 +68,23 @@ class Trainer(pl_trainer.Trainer):
             subdirectory: Whether to append a subdirectory to the output log.
         """
         self._log_dir = os.path.join(self.default_root_dir, self._session_id, subdirectory)
-        os.fspath(self._log_dir)
 
-        for logger in self.loggers:
-            if isinstance(logger, (pl_loggers.CSVLogger, pl_loggers.TensorBoardLogger)):
-                logger._root_dir = self.default_root_dir
-                logger._name = self._session_id
-                logger._version = subdirectory
+        enabled_loggers = []
+        if isinstance(self.loggers, list) and len(self.loggers) > 0:
+            for logger in self.loggers:
+                if isinstance(logger, (pl_loggers.CSVLogger, pl_loggers.TensorBoardLogger)):
+                    if not cloud_io._is_local_file_protocol(self.default_root_dir):
+                        loguru.logger.warning(
+                            f"Skipped {type(logger).__name__} as remote storage is not supported."
+                        )
+                        continue
+                    else:
+                        logger._root_dir = self.default_root_dir
+                        logger._name = self._session_id
+                        logger._version = subdirectory
+                enabled_loggers.append(logger)
+
+        self._loggers = enabled_loggers or [eva_loggers.DummyLogger(self._log_dir)]
 
     def run_evaluation_session(
         self,
@@ -94,4 +107,5 @@ class Trainer(pl_trainer.Trainer):
             base_model=model,
             datamodule=datamodule,
             n_runs=self._n_runs,
+            verbose=self._n_runs > 1,
         )
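The new guard leans on a private Lightning Fabric helper to detect remote filesystems; a quick behavioral check, assuming the helper classifies plain paths as local and URI schemes such as `s3://` as remote:

```python
from lightning_fabric.utilities import cloud_io

assert cloud_io._is_local_file_protocol("/tmp/eva/logs")         # loggers kept
assert not cloud_io._is_local_file_protocol("s3://bucket/logs")  # loggers skipped
```

When every file-based logger is skipped this way, the trainer falls back to `eva_loggers.DummyLogger`, so downstream code can still assume a logger exists.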
eva/vision/data/datasets/__init__.py
@@ -1,12 +1,6 @@
 """Vision Datasets API."""
 
-from eva.vision.data.datasets.classification import (
-    BACH,
-    CRC,
-    MHIST,
-    PatchCamelyon,
-    TotalSegmentatorClassification,
-)
+from eva.vision.data.datasets.classification import BACH, CRC, MHIST, PatchCamelyon
 from eva.vision.data.datasets.segmentation import ImageSegmentation, TotalSegmentator2D
 from eva.vision.data.datasets.vision import VisionDataset
 
@@ -16,7 +10,6 @@ __all__ = [
     "MHIST",
     "ImageSegmentation",
     "PatchCamelyon",
-    "TotalSegmentatorClassification",
     "TotalSegmentator2D",
     "VisionDataset",
 ]
eva/vision/data/datasets/_utils.py
@@ -1,6 +1,6 @@
 """Dataset related function and helper functions."""
 
-from typing import List, Tuple
+from typing import List, Sequence, Tuple
 
 
 def indices_to_ranges(indices: List[int]) -> List[Tuple[int, int]]:
@@ -33,11 +33,11 @@ def indices_to_ranges(indices: List[int]) -> List[Tuple[int, int]]:
     return ranges
 
 
-def ranges_to_indices(ranges: List[Tuple[int, int]]) -> List[int]:
+def ranges_to_indices(ranges: Sequence[Tuple[int, int]]) -> List[int]:
     """Unpacks a list of ranges to individual indices.
 
     Args:
-        ranges: The list of ranges to produce the indices from.
+        ranges: A sequence of ranges to produce the indices from.
 
     Return:
         A list of the indices.
eva/vision/data/datasets/classification/__init__.py
@@ -4,12 +4,5 @@ from eva.vision.data.datasets.classification.bach import BACH
 from eva.vision.data.datasets.classification.crc import CRC
 from eva.vision.data.datasets.classification.mhist import MHIST
 from eva.vision.data.datasets.classification.patch_camelyon import PatchCamelyon
-from eva.vision.data.datasets.classification.total_segmentator import TotalSegmentatorClassification
 
-__all__ = [
-    "BACH",
-    "CRC",
-    "MHIST",
-    "PatchCamelyon",
-    "TotalSegmentatorClassification",
-]
+__all__ = ["BACH", "CRC", "MHIST", "PatchCamelyon"]
eva/vision/data/datasets/segmentation/base.py
@@ -3,38 +3,28 @@
 import abc
 from typing import Any, Callable, Dict, List, Tuple
 
-import numpy as np
+from torchvision import tv_tensors
 from typing_extensions import override
 
 from eva.vision.data.datasets import vision
 
 
-class ImageSegmentation(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], abc.ABC):
+class ImageSegmentation(vision.VisionDataset[Tuple[tv_tensors.Image, tv_tensors.Mask]], abc.ABC):
     """Image segmentation abstract dataset."""
 
     def __init__(
         self,
-        image_transforms: Callable | None = None,
-        target_transforms: Callable | None = None,
-        image_target_transforms: Callable | None = None,
+        transforms: Callable | None = None,
     ) -> None:
         """Initializes the image segmentation base class.
 
         Args:
-            image_transforms: A function/transform that takes in an image
-                and returns a transformed version.
-            target_transforms: A function/transform that takes in the target
-                and transforms it.
-            image_target_transforms: A function/transforms that takes in an
+            transforms: A function/transforms that takes in an
                 image and a label and returns the transformed versions of both.
-                This transform happens after the `image_transforms` and
-                `target_transforms`.
         """
         super().__init__()
 
-        self._image_transforms = image_transforms
-        self._target_transforms = target_transforms
-        self._image_target_transforms = image_target_transforms
+        self._transforms = transforms
 
     @property
     def classes(self) -> List[str] | None:
@@ -56,25 +46,26 @@ class ImageSegmentation(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], abc
         """
 
     @abc.abstractmethod
-    def load_image(self, index: int) -> np.ndarray:
+    def load_image(self, index: int) -> tv_tensors.Image:
         """Loads and returns the `index`'th image sample.
 
         Args:
             index: The index of the data sample to load.
 
         Returns:
-            The image as a numpy array.
+            An image torchvision tensor (channels, height, width).
         """
 
     @abc.abstractmethod
-    def load_mask(self, index: int) -> np.ndarray:
-        """Returns the `index`'th target mask sample.
+    def load_mask(self, index: int) -> tv_tensors.Mask:
+        """Returns the `index`'th target masks sample.
 
         Args:
-            index: The index of the data sample target mask to load.
+            index: The index of the data sample target masks to load.
 
         Returns:
-            The sample mask as a stack of binary mask arrays (label, height, width).
+            The semantic mask as a (H x W) shaped tensor with integer
+            values which represent the pixel class id.
         """
 
     @abc.abstractmethod
@@ -83,30 +74,24 @@ class ImageSegmentation(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], abc
         raise NotImplementedError
 
     @override
-    def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
+    def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, tv_tensors.Mask]:
         image = self.load_image(index)
         mask = self.load_mask(index)
         return self._apply_transforms(image, mask)
 
     def _apply_transforms(
-        self, image: np.ndarray, target: np.ndarray
-    ) -> Tuple[np.ndarray, np.ndarray]:
+        self, image: tv_tensors.Image, mask: tv_tensors.Mask
+    ) -> Tuple[tv_tensors.Image, tv_tensors.Mask]:
         """Applies the transforms to the provided data and returns them.
 
         Args:
             image: The desired image.
-            target: The target of the image.
+            mask: The target segmentation mask.
 
         Returns:
-            A tuple with the image and the target transformed.
+            A tuple with the image and the masks transformed.
         """
-        if self._image_transforms is not None:
-            image = self._image_transforms(image)
+        if self._transforms is not None:
+            image, mask = self._transforms(image, mask)
 
-        if self._target_transforms is not None:
-            target = self._target_transforms(target)
-
-        if self._image_target_transforms is not None:
-            image, target = self._image_target_transforms(image, target)
-
-        return image, target
+        return image, mask
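The three separate transform hooks collapse into a single `transforms` callable that receives the image and mask together, which matches torchvision v2 joint transforms. A minimal subclass sketch with synthetic data, assuming `load_image`, `load_mask`, and `__len__` are the only abstract members:

```python
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2
from typing_extensions import override

from eva.vision.data.datasets import ImageSegmentation


class ToySegmentation(ImageSegmentation):
    """Synthetic dataset used only to illustrate the new transforms API."""

    @override
    def load_image(self, index: int) -> tv_tensors.Image:
        return tv_tensors.Image(torch.rand(3, 64, 64))  # (channels, height, width)

    @override
    def load_mask(self, index: int) -> tv_tensors.Mask:
        return tv_tensors.Mask(torch.randint(0, 2, (64, 64)))  # (H x W) class ids

    def __len__(self) -> int:
        return 4


# v2 transforms operate on the image and mask jointly, flipping both in sync.
dataset = ToySegmentation(transforms=v2.RandomHorizontalFlip(p=1.0))
image, mask = dataset[0]
```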