kaiko-eva 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kaiko-eva might be problematic.

Files changed (168):
  1. eva/core/callbacks/__init__.py +3 -2
  2. eva/core/callbacks/config.py +143 -0
  3. eva/core/callbacks/writers/__init__.py +6 -3
  4. eva/core/callbacks/writers/embeddings/__init__.py +6 -0
  5. eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
  6. eva/core/callbacks/writers/embeddings/base.py +192 -0
  7. eva/core/callbacks/writers/embeddings/classification.py +117 -0
  8. eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
  9. eva/core/callbacks/writers/embeddings/typings.py +38 -0
  10. eva/core/data/datasets/__init__.py +10 -2
  11. eva/core/data/datasets/classification/__init__.py +5 -2
  12. eva/core/data/datasets/classification/embeddings.py +15 -135
  13. eva/core/data/datasets/classification/multi_embeddings.py +110 -0
  14. eva/core/data/datasets/embeddings.py +167 -0
  15. eva/core/data/splitting/__init__.py +6 -0
  16. eva/core/data/splitting/random.py +41 -0
  17. eva/core/data/splitting/stratified.py +56 -0
  18. eva/core/data/transforms/__init__.py +3 -1
  19. eva/core/data/transforms/padding/__init__.py +5 -0
  20. eva/core/data/transforms/padding/pad_2d_tensor.py +38 -0
  21. eva/core/data/transforms/sampling/__init__.py +5 -0
  22. eva/core/data/transforms/sampling/sample_from_axis.py +40 -0
  23. eva/core/loggers/__init__.py +7 -0
  24. eva/core/loggers/dummy.py +38 -0
  25. eva/core/loggers/experimental_loggers.py +8 -0
  26. eva/core/loggers/log/__init__.py +6 -0
  27. eva/core/loggers/log/image.py +71 -0
  28. eva/core/loggers/log/parameters.py +74 -0
  29. eva/core/loggers/log/utils.py +13 -0
  30. eva/core/loggers/loggers.py +6 -0
  31. eva/core/metrics/__init__.py +6 -2
  32. eva/core/metrics/defaults/__init__.py +10 -3
  33. eva/core/metrics/defaults/classification/__init__.py +1 -1
  34. eva/core/metrics/defaults/classification/binary.py +0 -9
  35. eva/core/metrics/defaults/classification/multiclass.py +0 -8
  36. eva/core/metrics/defaults/segmentation/__init__.py +5 -0
  37. eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
  38. eva/core/metrics/generalized_dice.py +59 -0
  39. eva/core/metrics/mean_iou.py +120 -0
  40. eva/core/metrics/structs/schemas.py +3 -1
  41. eva/core/models/__init__.py +3 -1
  42. eva/core/models/modules/head.py +16 -15
  43. eva/core/models/modules/module.py +25 -1
  44. eva/core/models/modules/typings.py +14 -1
  45. eva/core/models/modules/utils/batch_postprocess.py +37 -5
  46. eva/core/models/networks/__init__.py +1 -2
  47. eva/core/models/networks/mlp.py +2 -2
  48. eva/core/models/transforms/__init__.py +6 -0
  49. eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
  50. eva/core/models/transforms/extract_patch_features.py +47 -0
  51. eva/core/models/wrappers/__init__.py +13 -0
  52. eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
  53. eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
  54. eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
  55. eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
  56. eva/core/trainers/_recorder.py +69 -7
  57. eva/core/trainers/functional.py +23 -5
  58. eva/core/trainers/trainer.py +20 -6
  59. eva/core/utils/__init__.py +6 -0
  60. eva/core/utils/clone.py +27 -0
  61. eva/core/utils/memory.py +28 -0
  62. eva/core/utils/operations.py +26 -0
  63. eva/core/utils/parser.py +20 -0
  64. eva/vision/__init__.py +2 -2
  65. eva/vision/callbacks/__init__.py +5 -0
  66. eva/vision/callbacks/loggers/__init__.py +5 -0
  67. eva/vision/callbacks/loggers/batch/__init__.py +5 -0
  68. eva/vision/callbacks/loggers/batch/base.py +130 -0
  69. eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
  70. eva/vision/data/datasets/__init__.py +24 -4
  71. eva/vision/data/datasets/_utils.py +3 -3
  72. eva/vision/data/datasets/_validators.py +15 -2
  73. eva/vision/data/datasets/classification/__init__.py +6 -2
  74. eva/vision/data/datasets/classification/bach.py +10 -15
  75. eva/vision/data/datasets/classification/base.py +17 -24
  76. eva/vision/data/datasets/classification/camelyon16.py +244 -0
  77. eva/vision/data/datasets/classification/crc.py +10 -15
  78. eva/vision/data/datasets/classification/mhist.py +10 -15
  79. eva/vision/data/datasets/classification/panda.py +184 -0
  80. eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
  81. eva/vision/data/datasets/classification/wsi.py +105 -0
  82. eva/vision/data/datasets/segmentation/__init__.py +15 -2
  83. eva/vision/data/datasets/segmentation/_utils.py +38 -0
  84. eva/vision/data/datasets/segmentation/base.py +31 -47
  85. eva/vision/data/datasets/segmentation/bcss.py +236 -0
  86. eva/vision/data/datasets/segmentation/consep.py +156 -0
  87. eva/vision/data/datasets/segmentation/embeddings.py +34 -0
  88. eva/vision/data/datasets/segmentation/lits.py +178 -0
  89. eva/vision/data/datasets/segmentation/monusac.py +236 -0
  90. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +325 -0
  91. eva/vision/data/datasets/wsi.py +187 -0
  92. eva/vision/data/transforms/__init__.py +3 -2
  93. eva/vision/data/transforms/common/__init__.py +2 -1
  94. eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
  95. eva/vision/data/transforms/common/resize_and_crop.py +6 -7
  96. eva/vision/data/transforms/normalization/__init__.py +6 -0
  97. eva/vision/data/transforms/normalization/clamp.py +43 -0
  98. eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
  99. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
  100. eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
  101. eva/vision/data/wsi/__init__.py +16 -0
  102. eva/vision/data/wsi/backends/__init__.py +69 -0
  103. eva/vision/data/wsi/backends/base.py +115 -0
  104. eva/vision/data/wsi/backends/openslide.py +73 -0
  105. eva/vision/data/wsi/backends/pil.py +52 -0
  106. eva/vision/data/wsi/backends/tiffslide.py +42 -0
  107. eva/vision/data/wsi/patching/__init__.py +6 -0
  108. eva/vision/data/wsi/patching/coordinates.py +98 -0
  109. eva/vision/data/wsi/patching/mask.py +123 -0
  110. eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
  111. eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
  112. eva/vision/data/wsi/patching/samplers/base.py +48 -0
  113. eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
  114. eva/vision/data/wsi/patching/samplers/grid.py +47 -0
  115. eva/vision/data/wsi/patching/samplers/random.py +41 -0
  116. eva/vision/losses/__init__.py +5 -0
  117. eva/vision/losses/dice.py +40 -0
  118. eva/vision/models/__init__.py +4 -2
  119. eva/vision/models/modules/__init__.py +5 -0
  120. eva/vision/models/modules/semantic_segmentation.py +161 -0
  121. eva/vision/models/networks/__init__.py +1 -2
  122. eva/vision/models/networks/backbones/__init__.py +6 -0
  123. eva/vision/models/networks/backbones/_utils.py +39 -0
  124. eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
  125. eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
  126. eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
  127. eva/vision/models/networks/backbones/pathology/histai.py +46 -0
  128. eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
  129. eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
  130. eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
  131. eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
  132. eva/vision/models/networks/backbones/registry.py +47 -0
  133. eva/vision/models/networks/backbones/timm/__init__.py +5 -0
  134. eva/vision/models/networks/backbones/timm/backbones.py +54 -0
  135. eva/vision/models/networks/backbones/universal/__init__.py +8 -0
  136. eva/vision/models/networks/backbones/universal/vit.py +54 -0
  137. eva/vision/models/networks/decoders/__init__.py +6 -0
  138. eva/vision/models/networks/decoders/decoder.py +7 -0
  139. eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
  140. eva/vision/models/networks/decoders/segmentation/common.py +74 -0
  141. eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
  142. eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
  143. eva/vision/models/wrappers/__init__.py +6 -0
  144. eva/vision/models/wrappers/from_registry.py +48 -0
  145. eva/vision/models/wrappers/from_timm.py +68 -0
  146. eva/vision/utils/colormap.py +77 -0
  147. eva/vision/utils/convert.py +67 -0
  148. eva/vision/utils/io/__init__.py +10 -4
  149. eva/vision/utils/io/image.py +21 -2
  150. eva/vision/utils/io/mat.py +36 -0
  151. eva/vision/utils/io/nifti.py +40 -15
  152. eva/vision/utils/io/text.py +10 -3
  153. kaiko_eva-0.1.0.dist-info/METADATA +553 -0
  154. kaiko_eva-0.1.0.dist-info/RECORD +205 -0
  155. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
  156. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
  157. eva/core/callbacks/writers/embeddings.py +0 -169
  158. eva/core/callbacks/writers/typings.py +0 -23
  159. eva/core/models/networks/transforms/__init__.py +0 -5
  160. eva/core/models/networks/wrappers/__init__.py +0 -8
  161. eva/vision/data/datasets/classification/total_segmentator.py +0 -213
  162. eva/vision/data/datasets/segmentation/total_segmentator.py +0 -212
  163. eva/vision/models/networks/postprocesses/__init__.py +0 -5
  164. eva/vision/models/networks/postprocesses/cls.py +0 -25
  165. kaiko_eva-0.0.1.dist-info/METADATA +0 -405
  166. kaiko_eva-0.0.1.dist-info/RECORD +0 -110
  167. /eva/core/models/{networks → wrappers}/_utils.py +0 -0
  168. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
eva/core/trainers/trainer.py CHANGED
@@ -3,11 +3,14 @@
 import os
 from typing import Any
 
+import loguru
 from lightning.pytorch import loggers as pl_loggers
 from lightning.pytorch import trainer as pl_trainer
 from lightning.pytorch.utilities import argparse
+from lightning_fabric.utilities import cloud_io
 from typing_extensions import override
 
+from eva.core import loggers as eva_loggers
 from eva.core.data import datamodules
 from eva.core.models import modules
 from eva.core.trainers import _logging, functional
@@ -65,13 +68,23 @@ class Trainer(pl_trainer.Trainer):
             subdirectory: Whether to append a subdirectory to the output log.
         """
         self._log_dir = os.path.join(self.default_root_dir, self._session_id, subdirectory)
-        os.fspath(self._log_dir)
 
-        for logger in self.loggers:
-            if isinstance(logger, (pl_loggers.CSVLogger, pl_loggers.TensorBoardLogger)):
-                logger._root_dir = self.default_root_dir
-                logger._name = self._session_id
-                logger._version = subdirectory
+        enabled_loggers = []
+        if isinstance(self.loggers, list) and len(self.loggers) > 0:
+            for logger in self.loggers:
+                if isinstance(logger, (pl_loggers.CSVLogger, pl_loggers.TensorBoardLogger)):
+                    if not cloud_io._is_local_file_protocol(self.default_root_dir):
+                        loguru.logger.warning(
+                            f"Skipped {type(logger).__name__} as remote storage is not supported."
+                        )
+                        continue
+                    else:
+                        logger._root_dir = self.default_root_dir
+                        logger._name = self._session_id
+                        logger._version = subdirectory
+                enabled_loggers.append(logger)
+
+        self._loggers = enabled_loggers or [eva_loggers.DummyLogger(self._log_dir)]
 
     def run_evaluation_session(
         self,
@@ -94,4 +107,5 @@ class Trainer(pl_trainer.Trainer):
             base_model=model,
             datamodule=datamodule,
             n_runs=self._n_runs,
+            verbose=self._n_runs > 1,
         )
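The gist of the trainer change above: CSV and TensorBoard loggers can only write to local file systems, so when `default_root_dir` points at remote storage they are now skipped with a warning and a `DummyLogger` takes their place. The guard is Lightning Fabric's fsspec-protocol check used in the hunk; a quick illustrative sketch (the paths below are made up):

    from lightning_fabric.utilities import cloud_io

    cloud_io._is_local_file_protocol("logs/eva")              # True: plain file-system path
    cloud_io._is_local_file_protocol("s3://bucket/eva-logs")  # False: remote storage protocol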
eva/core/utils/__init__.py CHANGED
@@ -1 +1,7 @@
 """Utilities and library level helper functionalities."""
+
+from eva.core.utils.clone import clone
+from eva.core.utils.memory import to_cpu
+from eva.core.utils.operations import numeric_sort
+
+__all__ = ["clone", "to_cpu", "numeric_sort"]
eva/core/utils/clone.py ADDED
@@ -0,0 +1,27 @@
+"""Clone related functions."""
+
+import functools
+from typing import Any, Dict, List
+
+import torch
+
+
+@functools.singledispatch
+def clone(tensor_type: Any) -> Any:
+    """Clone tensor objects."""
+    raise TypeError(f"Unsupported input type: {type(tensor_type)}.")
+
+
+@clone.register
+def _(tensor: torch.Tensor) -> torch.Tensor:
+    return tensor.clone()
+
+
+@clone.register
+def _(tensors: list) -> List[torch.Tensor]:
+    return list(map(clone, tensors))
+
+
+@clone.register
+def _(tensors: dict) -> Dict[str, torch.Tensor]:
+    return {key: clone(tensors[key]) for key in tensors}
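The helper dispatches on the input type and recurses through lists and dicts. A minimal usage sketch (the batch below is a made-up example):

    import torch

    from eva.core.utils import clone

    batch = {"embeddings": torch.rand(4, 768), "targets": [torch.tensor(1), torch.tensor(0)]}
    copied = clone(batch)  # recursively clones every tensor in the container

    assert copied["embeddings"].data_ptr() != batch["embeddings"].data_ptr()
    # clone(42) would raise TypeError: unsupported input type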
eva/core/utils/memory.py ADDED
@@ -0,0 +1,28 @@
+"""Memory related functions."""
+
+import functools
+from typing import Any, Dict, List
+
+import torch
+
+
+@functools.singledispatch
+def to_cpu(tensor_type: Any) -> Any:
+    """Moves tensor objects to `cpu`."""
+    raise TypeError(f"Unsupported input type: {type(tensor_type)}.")
+
+
+@to_cpu.register
+def _(tensor: torch.Tensor) -> torch.Tensor:
+    detached_tensor = tensor.detach()
+    return detached_tensor.cpu()
+
+
+@to_cpu.register
+def _(tensors: list) -> List[torch.Tensor]:
+    return list(map(to_cpu, tensors))
+
+
+@to_cpu.register
+def _(tensors: dict) -> Dict[str, torch.Tensor]:
+    return {key: to_cpu(tensors[key]) for key in tensors}
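Same dispatch pattern as `clone`, but detaching and moving to CPU, which is useful for accumulating step outputs without holding GPU memory or autograd graphs. A made-up example:

    import torch

    from eva.core.utils import to_cpu

    logits = torch.rand(2, 3, requires_grad=True) * 2.0
    outputs = {"logits": logits, "extras": [logits.sum()]}

    cpu_outputs = to_cpu(outputs)  # every tensor detached and moved to CPU
    assert not cpu_outputs["logits"].requires_grad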
eva/core/utils/operations.py ADDED
@@ -0,0 +1,26 @@
+"""Functional operations."""
+
+import re
+from typing import Iterable, List
+
+
+def numeric_sort(item: Iterable[str], /) -> List[str]:
+    """Sorts an iterable of strings treating embedded numbers as numeric values.
+
+    Here the strings are compared based on their numeric value rather than their
+    string representation.
+
+    Args:
+        item: An iterable of strings to be sorted.
+
+    Returns:
+        A list of strings sorted based on their numeric values.
+    """
+    return sorted(
+        item,
+        key=lambda value: re.sub(
+            r"(\d+)",
+            lambda num: f"{int(num[0]):010d}",
+            value,
+        ),
+    )
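The sort key zero-pads every embedded number to ten digits, so numeric order wins over lexicographic order. For example (made-up file names):

    from eva.core.utils import numeric_sort

    files = ["slide_10.pt", "slide_2.pt", "slide_1.pt"]
    sorted(files)        # ['slide_1.pt', 'slide_10.pt', 'slide_2.pt'] (lexicographic)
    numeric_sort(files)  # ['slide_1.pt', 'slide_2.pt', 'slide_10.pt'] (numeric-aware)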
eva/core/utils/parser.py ADDED
@@ -0,0 +1,20 @@
+"""Parsing related helper functions."""
+
+from typing import Any, Dict
+
+import jsonargparse
+
+
+def parse_object(config: Dict[str, Any], expected_type: Any = Any) -> Any:
+    """Parses an object which is defined as a dictionary."""
+    parser = jsonargparse.ArgumentParser()
+    parser.add_argument("module", type=expected_type)
+    configuration = parser.parse_object({"module": config})
+    init_object = parser.instantiate_classes(configuration)
+    obj_module = init_object.module
+    if isinstance(obj_module, jsonargparse.Namespace):
+        raise ValueError(
+            f"Failed to parse object '{obj_module.class_path}'. "
+            "Please check that the initialized arguments are valid."
+        )
+    return obj_module
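This wraps jsonargparse's `class_path`/`init_args` convention for instantiating objects from plain dictionaries. A hedged sketch of how it can be called (the `torch.nn.Linear` config is an illustrative assumption, not taken from the diff):

    import torch

    from eva.core.utils.parser import parse_object

    config = {
        "class_path": "torch.nn.Linear",  # hypothetical example target class
        "init_args": {"in_features": 8, "out_features": 2},
    }
    head = parse_object(config, expected_type=torch.nn.Module)  # -> torch.nn.Linear(8, 2)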
eva/vision/__init__.py CHANGED
@@ -1,7 +1,7 @@
 """eva vision API."""
 
 try:
-    from eva.vision import models, utils
+    from eva.vision import callbacks, losses, models, utils
     from eva.vision.data import datasets, transforms
 except ImportError as e:
     msg = (
@@ -11,4 +11,4 @@ except ImportError as e:
     )
     raise ImportError(str(e) + "\n\n" + msg) from e
 
-__all__ = ["models", "utils", "datasets", "transforms"]
+__all__ = ["callbacks", "losses", "models", "utils", "datasets", "transforms"]
eva/vision/callbacks/__init__.py ADDED
@@ -0,0 +1,5 @@
+"""Vision callbacks API."""
+
+from eva.vision.callbacks.loggers import SemanticSegmentationLogger
+
+__all__ = ["SemanticSegmentationLogger"]
eva/vision/callbacks/loggers/__init__.py ADDED
@@ -0,0 +1,5 @@
+"""Vision logging related callbacks API."""
+
+from eva.vision.callbacks.loggers.batch import SemanticSegmentationLogger
+
+__all__ = ["SemanticSegmentationLogger"]
eva/vision/callbacks/loggers/batch/__init__.py ADDED
@@ -0,0 +1,5 @@
+"""Batch related loggers callbacks API."""
+
+from eva.vision.callbacks.loggers.batch.segmentation import SemanticSegmentationLogger
+
+__all__ = ["SemanticSegmentationLogger"]
eva/vision/callbacks/loggers/batch/base.py ADDED
@@ -0,0 +1,130 @@
+"""Base batch callback logger."""
+
+import abc
+
+from lightning import pytorch as pl
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+from typing_extensions import override
+
+from eva.core.models.modules.typings import INPUT_TENSOR_BATCH
+
+
+class BatchLogger(pl.Callback, abc.ABC):
+    """Logs training and validation batch assets."""
+
+    _batch_idx_to_log: int = 0
+    """The batch index to log."""
+
+    def __init__(
+        self,
+        log_every_n_epochs: int | None = None,
+        log_every_n_steps: int | None = None,
+    ) -> None:
+        """Initializes the callback object.
+
+        Args:
+            log_every_n_epochs: Epoch-wise logging frequency.
+            log_every_n_steps: Step-wise logging frequency.
+        """
+        super().__init__()
+
+        if log_every_n_epochs is None and log_every_n_steps is None:
+            raise ValueError(
+                "Please configure the logging frequency through "
+                "`log_every_n_epochs` or `log_every_n_steps`."
+            )
+        if None not in [log_every_n_epochs, log_every_n_steps]:
+            raise ValueError(
+                "Arguments `log_every_n_epochs` and `log_every_n_steps` "
+                "are mutually exclusive. Please configure one of them."
+            )
+
+        self._log_every_n_epochs = log_every_n_epochs
+        self._log_every_n_steps = log_every_n_steps
+
+    @override
+    def on_train_batch_end(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        outputs: STEP_OUTPUT,
+        batch: INPUT_TENSOR_BATCH,
+        batch_idx: int,
+    ) -> None:
+        if self._skip_logging(trainer, batch_idx if self._log_every_n_epochs else None):
+            return
+
+        self._log_batch(
+            trainer=trainer,
+            batch=batch,
+            outputs=outputs,
+            tag="BatchTrain",
+        )
+
+    @override
+    def on_validation_batch_end(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        outputs: STEP_OUTPUT,
+        batch: INPUT_TENSOR_BATCH,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        if self._skip_logging(trainer, batch_idx):
+            return
+
+        self._log_batch(
+            trainer=trainer,
+            batch=batch,
+            outputs=outputs,
+            tag="BatchValidation",
+        )
+
+    @abc.abstractmethod
+    def _log_batch(
+        self,
+        trainer: pl.Trainer,
+        outputs: STEP_OUTPUT,
+        batch: INPUT_TENSOR_BATCH,
+        tag: str,
+    ) -> None:
+        """Logs the batch data.
+
+        Args:
+            trainer: The trainer.
+            outputs: The output of the train / val step.
+            batch: The data batch.
+            tag: The log tag.
+        """
+
+    def _skip_logging(
+        self,
+        trainer: pl.Trainer,
+        batch_idx: int | None = None,
+    ) -> bool:
+        """Determines whether to skip the logging step.
+
+        Args:
+            trainer: The trainer.
+            batch_idx: The batch index.
+
+        Returns:
+            A boolean indicating whether to skip the step execution.
+        """
+        if trainer.global_step in [0, 1]:
+            return False
+
+        skip_due_frequency = any(
+            [
+                (trainer.current_epoch + 1) % (self._log_every_n_epochs or 1) != 0,
+                (trainer.global_step + 1) % (self._log_every_n_steps or 1) != 0,
+            ]
+        )
+
+        conditions = [
+            skip_due_frequency,
+            not trainer.is_global_zero,
+            batch_idx != self._batch_idx_to_log if batch_idx else False,
+        ]
+        return any(conditions)
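Subclasses only need to implement the abstract `_log_batch` hook; the frequency gating and rank-zero checks come for free. A minimal hypothetical subclass (the `LossPrinter` name and its behavior are illustrative, not part of the package):

    from lightning import pytorch as pl
    from lightning.pytorch.utilities.types import STEP_OUTPUT
    from typing_extensions import override

    from eva.core.models.modules.typings import INPUT_TENSOR_BATCH
    from eva.vision.callbacks.loggers.batch import base


    class LossPrinter(base.BatchLogger):
        """Toy batch logger that prints the step loss (illustrative only)."""

        @override
        def _log_batch(
            self,
            trainer: pl.Trainer,
            outputs: STEP_OUTPUT,
            batch: INPUT_TENSOR_BATCH,
            tag: str,
        ) -> None:
            loss = outputs.get("loss") if isinstance(outputs, dict) else outputs
            print(f"[{tag}] step={trainer.global_step} loss={loss}")


    callback = LossPrinter(log_every_n_steps=100)  # exactly one frequency argument is required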
eva/vision/callbacks/loggers/batch/segmentation.py ADDED
@@ -0,0 +1,188 @@
+"""Segmentation datasets related data loggers."""
+
+from typing import List, Tuple
+
+import torch
+import torchvision
+from lightning import pytorch as pl
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+from torch.nn import functional
+from typing_extensions import override
+
+from eva.core.loggers import log
+from eva.core.models.modules.typings import INPUT_TENSOR_BATCH
+from eva.core.utils import to_cpu
+from eva.vision.callbacks.loggers.batch import base
+from eva.vision.utils import colormap, convert
+
+
+class SemanticSegmentationLogger(base.BatchLogger):
+    """Logs the segmentation batch."""
+
+    def __init__(
+        self,
+        max_samples: int = 10,
+        number_of_images_per_subgrid_row: int = 1,
+        log_images: bool = True,
+        mean: Tuple[float, ...] = (0.0, 0.0, 0.0),
+        std: Tuple[float, ...] = (1.0, 1.0, 1.0),
+        log_every_n_epochs: int | None = None,
+        log_every_n_steps: int | None = None,
+    ) -> None:
+        """Initializes the callback object.
+
+        Args:
+            max_samples: The maximum number of images displayed in the grid.
+            number_of_images_per_subgrid_row: Number of images displayed in each
+                row of each sub-grid (that is, images, targets and predictions).
+            log_images: Whether to log the input batch images.
+            mean: The mean of the input images to de-normalize from.
+            std: The std of the input images to de-normalize from.
+            log_every_n_epochs: Epoch-wise logging frequency.
+            log_every_n_steps: Step-wise logging frequency.
+        """
+        super().__init__(
+            log_every_n_epochs=log_every_n_epochs,
+            log_every_n_steps=log_every_n_steps,
+        )
+
+        self._max_samples = max_samples
+        self._number_of_images_per_subgrid_row = number_of_images_per_subgrid_row
+        self._log_images = log_images
+        self._mean = mean
+        self._std = std
+
+    @override
+    def _log_batch(
+        self,
+        trainer: pl.Trainer,
+        outputs: STEP_OUTPUT,
+        batch: INPUT_TENSOR_BATCH,
+        tag: str,
+    ) -> None:
+        predictions = outputs.get("predictions") if isinstance(outputs, dict) else None
+        if predictions is None:
+            raise ValueError("Key `predictions` is missing from the output.")
+
+        data_batch, target_batch = batch[0], batch[1]
+        data, targets, predictions = _subsample_tensors(
+            tensors_stack=[data_batch, target_batch, predictions],
+            max_samples=self._max_samples,
+        )
+        data, targets, predictions = to_cpu([data, targets, predictions])
+        predictions = torch.argmax(predictions, dim=1)
+
+        target_images = list(map(_draw_semantic_mask, targets))
+        prediction_images = list(map(_draw_semantic_mask, predictions))
+        image_groups = [target_images, prediction_images]
+
+        if self._log_images:
+            images = list(map(self._format_image, data))
+            overlay_targets = [
+                _overlay_mask(image, mask) for image, mask in zip(images, targets, strict=False)
+            ]
+            overlay_predictions = [
+                _overlay_mask(image, mask) for image, mask in zip(images, predictions, strict=False)
+            ]
+            image_groups = [images, overlay_targets, overlay_predictions] + image_groups
+
+        image_grid = _make_grid_from_image_groups(
+            image_groups, self._number_of_images_per_subgrid_row
+        )
+
+        log.log_image(
+            trainer.loggers,
+            image=image_grid,
+            tag=tag,
+            step=trainer.global_step,
+        )
+
+    def _format_image(self, image: torch.Tensor) -> torch.Tensor:
+        """De-scales an image tensor to a (0, 255) uint8 tensor."""
+        return convert.descale_and_denorm_image(image, mean=self._mean, std=self._std)
+
+
+def _subsample_tensors(
+    tensors_stack: List[torch.Tensor],
+    max_samples: int,
+) -> List[torch.Tensor]:
+    """Sub-samples tensors from a list of tensors in-place.
+
+    Args:
+        tensors_stack: A list of tensors.
+        max_samples: The maximum number of images displayed in the grid.
+
+    Returns:
+        A sub-sample of the input tensors stack.
+    """
+    for i, tensor in enumerate(tensors_stack):
+        tensors_stack[i] = tensor[:max_samples]
+    return tensors_stack
+
+
+def _draw_semantic_mask(tensor: torch.Tensor) -> torch.Tensor:
+    """Draws a semantic mask as an RGB image tensor.
+
+    The input semantic mask is a (H x W) shaped tensor with
+    integer values which represent the pixel class id.
+
+    Args:
+        tensor: The semantic mask (H x W) of integer class id values.
+
+    Returns:
+        The mask as an RGB image tensor of range [0., 255.].
+    """
+    tensor = torch.squeeze(tensor)
+    height, width = tensor.shape[-2], tensor.shape[-1]
+    red, green, blue = torch.zeros((3, height, width), dtype=torch.uint8)
+    for class_id, color in colormap.COLORMAP.items():
+        indices = tensor == class_id
+        red[indices], green[indices], blue[indices] = color
+    return torch.stack([red, green, blue])
+
+
+def _overlay_mask(image: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+    """Overlays a segmentation mask onto an image.
+
+    Args:
+        image: A 3D tensor of shape (C, H, W) representing the image.
+        mask: A 2D tensor of shape (H, W) representing the segmentation mask.
+            Each pixel in the mask corresponds to a class label.
+
+    Returns:
+        A tensor of the same shape as the input image (C, H, W) with the
+        segmentation mask overlaid on top. The output image retains the
+        original color channels but with the mask applied, using the colors
+        from the predefined colormap.
+    """
+    binary_masks = functional.one_hot(mask).permute(2, 0, 1).to(dtype=torch.bool)
+    return torchvision.utils.draw_segmentation_masks(
+        image, binary_masks[1:], alpha=0.65, colors=colormap.COLORS[1:]  # type: ignore
+    )
+
+
+def _make_grid_from_image_groups(
+    image_groups: List[List[torch.Tensor]],
+    number_of_images_per_subgrid_row: int = 2,
+) -> torch.Tensor:
+    """Creates a single image grid from image groups.
+
+    For example, it can combine the input images, targets and predictions
+    into a single image.
+
+    Args:
+        image_groups: A list of lists of image tensors of shape (C x H x W),
+            all of the same size.
+        number_of_images_per_subgrid_row: Number of images displayed in each
+            row of the sub-grid.
+
+    Returns:
+        An image grid as a `torch.Tensor`.
+    """
+    return torchvision.utils.make_grid(
+        [
+            torchvision.utils.make_grid(image_group, nrow=number_of_images_per_subgrid_row)
+            for image_group in image_groups
+        ],
+        nrow=len(image_groups),
+    )
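A hedged wiring sketch: since `Trainer` subclasses the Lightning trainer, the callback should plug in through the standard `callbacks` argument (the normalization values below are placeholders and must match your data transforms):

    from eva.core.trainers.trainer import Trainer
    from eva.vision.callbacks import SemanticSegmentationLogger

    trainer = Trainer(
        callbacks=[
            SemanticSegmentationLogger(
                log_every_n_epochs=1,
                mean=(0.5, 0.5, 0.5),  # de-normalization mean (placeholder)
                std=(0.5, 0.5, 0.5),   # de-normalization std (placeholder)
            )
        ],
    )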
eva/vision/data/datasets/__init__.py CHANGED
@@ -4,19 +4,39 @@ from eva.vision.data.datasets.classification import (
     BACH,
     CRC,
     MHIST,
+    PANDA,
+    Camelyon16,
     PatchCamelyon,
-    TotalSegmentatorClassification,
+    WsiClassificationDataset,
+)
+from eva.vision.data.datasets.segmentation import (
+    BCSS,
+    CoNSeP,
+    EmbeddingsSegmentationDataset,
+    ImageSegmentation,
+    LiTS,
+    MoNuSAC,
+    TotalSegmentator2D,
 )
-from eva.vision.data.datasets.segmentation import ImageSegmentation, TotalSegmentator2D
 from eva.vision.data.datasets.vision import VisionDataset
+from eva.vision.data.datasets.wsi import MultiWsiDataset, WsiDataset
 
 __all__ = [
     "BACH",
+    "BCSS",
     "CRC",
     "MHIST",
-    "ImageSegmentation",
+    "PANDA",
+    "Camelyon16",
     "PatchCamelyon",
-    "TotalSegmentatorClassification",
+    "WsiClassificationDataset",
+    "CoNSeP",
+    "EmbeddingsSegmentationDataset",
+    "ImageSegmentation",
+    "LiTS",
+    "MoNuSAC",
     "TotalSegmentator2D",
     "VisionDataset",
+    "MultiWsiDataset",
+    "WsiDataset",
 ]
eva/vision/data/datasets/_utils.py CHANGED
@@ -1,6 +1,6 @@
 """Dataset related function and helper functions."""
 
-from typing import List, Tuple
+from typing import List, Sequence, Tuple
 
 
 def indices_to_ranges(indices: List[int]) -> List[Tuple[int, int]]:
@@ -33,11 +33,11 @@ def indices_to_ranges(indices: List[int]) -> List[Tuple[int, int]]:
     return ranges
 
 
-def ranges_to_indices(ranges: List[Tuple[int, int]]) -> List[int]:
+def ranges_to_indices(ranges: Sequence[Tuple[int, int]]) -> List[int]:
     """Unpacks a list of ranges to individual indices.
 
     Args:
-        ranges: The list of ranges to produce the indices from.
+        ranges: A sequence of ranges to produce the indices from.
 
     Return:
         A list of the indices.
eva/vision/data/datasets/_validators.py CHANGED
@@ -13,7 +13,7 @@ _SUFFIX_ERROR_MESSAGE = "Please verify that the data are properly downloaded and
 def check_dataset_integrity(
     dataset: vision.VisionDataset,
     *,
-    length: int,
+    length: int | None,
     n_classes: int,
     first_and_last_labels: Tuple[str, str],
 ) -> None:
@@ -23,7 +23,7 @@
         ValueError: If the input dataset's values do not
             match the expected ones.
     """
-    if len(dataset) != length:
+    if length and len(dataset) != length:
         raise ValueError(
             f"Dataset's '{dataset.__class__.__qualname__}' length "
             f"({len(dataset)}) does not match the expected one ({length}). "
@@ -57,3 +57,16 @@ def check_dataset_exists(dataset_dir: str, download_available: bool) -> None:
     if download_available:
         error_message += " You can set `download=True` to download the dataset automatically."
     raise FileNotFoundError(error_message)
+
+
+def check_number_of_files(file_paths: List[str], expected_length: int, split: str | None) -> None:
+    """Verifies the number of files in the dataset.
+
+    Raises:
+        ValueError: If the number of files in the dataset does not match the expected one.
+    """
+    if len(file_paths) != expected_length:
+        raise ValueError(
+            f"Expected {expected_length} files for split '{split}', found {len(file_paths)}. "
+            f"{_SUFFIX_ERROR_MESSAGE}"
+        )
eva/vision/data/datasets/classification/__init__.py CHANGED
@@ -1,15 +1,19 @@
 """Image classification datasets API."""
 
 from eva.vision.data.datasets.classification.bach import BACH
+from eva.vision.data.datasets.classification.camelyon16 import Camelyon16
 from eva.vision.data.datasets.classification.crc import CRC
 from eva.vision.data.datasets.classification.mhist import MHIST
+from eva.vision.data.datasets.classification.panda import PANDA
 from eva.vision.data.datasets.classification.patch_camelyon import PatchCamelyon
-from eva.vision.data.datasets.classification.total_segmentator import TotalSegmentatorClassification
+from eva.vision.data.datasets.classification.wsi import WsiClassificationDataset
 
 __all__ = [
     "BACH",
     "CRC",
     "MHIST",
     "PatchCamelyon",
-    "TotalSegmentatorClassification",
+    "WsiClassificationDataset",
+    "PANDA",
+    "Camelyon16",
 ]