kaiko-eva 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of kaiko-eva has been flagged by the registry.

Files changed (159)
  1. eva/core/callbacks/__init__.py +2 -2
  2. eva/core/callbacks/writers/__init__.py +6 -3
  3. eva/core/callbacks/writers/embeddings/__init__.py +6 -0
  4. eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
  5. eva/core/callbacks/writers/embeddings/base.py +192 -0
  6. eva/core/callbacks/writers/embeddings/classification.py +117 -0
  7. eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
  8. eva/core/callbacks/writers/embeddings/typings.py +38 -0
  9. eva/core/data/datasets/__init__.py +2 -2
  10. eva/core/data/datasets/classification/__init__.py +8 -0
  11. eva/core/data/datasets/classification/embeddings.py +34 -0
  12. eva/core/data/datasets/{embeddings/classification → classification}/multi_embeddings.py +13 -9
  13. eva/core/data/datasets/{embeddings/base.py → embeddings.py} +47 -32
  14. eva/core/data/splitting/__init__.py +6 -0
  15. eva/core/data/splitting/random.py +41 -0
  16. eva/core/data/splitting/stratified.py +56 -0
  17. eva/core/loggers/experimental_loggers.py +2 -2
  18. eva/core/loggers/log/__init__.py +3 -2
  19. eva/core/loggers/log/image.py +71 -0
  20. eva/core/loggers/log/parameters.py +10 -0
  21. eva/core/loggers/loggers.py +6 -0
  22. eva/core/metrics/__init__.py +6 -2
  23. eva/core/metrics/defaults/__init__.py +10 -3
  24. eva/core/metrics/defaults/classification/__init__.py +1 -1
  25. eva/core/metrics/defaults/classification/binary.py +0 -9
  26. eva/core/metrics/defaults/classification/multiclass.py +0 -8
  27. eva/core/metrics/defaults/segmentation/__init__.py +5 -0
  28. eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
  29. eva/core/metrics/generalized_dice.py +59 -0
  30. eva/core/metrics/mean_iou.py +120 -0
  31. eva/core/metrics/structs/schemas.py +3 -1
  32. eva/core/models/__init__.py +3 -1
  33. eva/core/models/modules/head.py +10 -4
  34. eva/core/models/modules/typings.py +14 -1
  35. eva/core/models/modules/utils/batch_postprocess.py +37 -5
  36. eva/core/models/networks/__init__.py +1 -2
  37. eva/core/models/networks/mlp.py +2 -2
  38. eva/core/models/transforms/__init__.py +6 -0
  39. eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
  40. eva/core/models/transforms/extract_patch_features.py +47 -0
  41. eva/core/models/wrappers/__init__.py +13 -0
  42. eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
  43. eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
  44. eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
  45. eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
  46. eva/core/trainers/functional.py +1 -0
  47. eva/core/utils/__init__.py +6 -0
  48. eva/core/utils/clone.py +27 -0
  49. eva/core/utils/memory.py +28 -0
  50. eva/core/utils/operations.py +26 -0
  51. eva/core/utils/parser.py +20 -0
  52. eva/vision/__init__.py +2 -2
  53. eva/vision/callbacks/__init__.py +5 -0
  54. eva/vision/callbacks/loggers/__init__.py +5 -0
  55. eva/vision/callbacks/loggers/batch/__init__.py +5 -0
  56. eva/vision/callbacks/loggers/batch/base.py +130 -0
  57. eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
  58. eva/vision/data/datasets/__init__.py +30 -3
  59. eva/vision/data/datasets/_validators.py +15 -2
  60. eva/vision/data/datasets/classification/__init__.py +12 -1
  61. eva/vision/data/datasets/classification/bach.py +10 -15
  62. eva/vision/data/datasets/classification/base.py +17 -24
  63. eva/vision/data/datasets/classification/camelyon16.py +244 -0
  64. eva/vision/data/datasets/classification/crc.py +10 -15
  65. eva/vision/data/datasets/classification/mhist.py +10 -15
  66. eva/vision/data/datasets/classification/panda.py +184 -0
  67. eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
  68. eva/vision/data/datasets/classification/wsi.py +105 -0
  69. eva/vision/data/datasets/segmentation/__init__.py +15 -2
  70. eva/vision/data/datasets/segmentation/_utils.py +38 -0
  71. eva/vision/data/datasets/segmentation/base.py +16 -17
  72. eva/vision/data/datasets/segmentation/bcss.py +236 -0
  73. eva/vision/data/datasets/segmentation/consep.py +156 -0
  74. eva/vision/data/datasets/segmentation/embeddings.py +34 -0
  75. eva/vision/data/datasets/segmentation/lits.py +178 -0
  76. eva/vision/data/datasets/segmentation/monusac.py +236 -0
  77. eva/vision/data/datasets/segmentation/{total_segmentator.py → total_segmentator_2d.py} +130 -36
  78. eva/vision/data/datasets/wsi.py +187 -0
  79. eva/vision/data/transforms/__init__.py +3 -2
  80. eva/vision/data/transforms/common/__init__.py +2 -1
  81. eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
  82. eva/vision/data/transforms/common/resize_and_crop.py +6 -7
  83. eva/vision/data/transforms/normalization/__init__.py +6 -0
  84. eva/vision/data/transforms/normalization/clamp.py +43 -0
  85. eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
  86. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
  87. eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
  88. eva/vision/data/wsi/__init__.py +16 -0
  89. eva/vision/data/wsi/backends/__init__.py +69 -0
  90. eva/vision/data/wsi/backends/base.py +115 -0
  91. eva/vision/data/wsi/backends/openslide.py +73 -0
  92. eva/vision/data/wsi/backends/pil.py +52 -0
  93. eva/vision/data/wsi/backends/tiffslide.py +42 -0
  94. eva/vision/data/wsi/patching/__init__.py +6 -0
  95. eva/vision/data/wsi/patching/coordinates.py +98 -0
  96. eva/vision/data/wsi/patching/mask.py +123 -0
  97. eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
  98. eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
  99. eva/vision/data/wsi/patching/samplers/base.py +48 -0
  100. eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
  101. eva/vision/data/wsi/patching/samplers/grid.py +47 -0
  102. eva/vision/data/wsi/patching/samplers/random.py +41 -0
  103. eva/vision/losses/__init__.py +5 -0
  104. eva/vision/losses/dice.py +40 -0
  105. eva/vision/models/__init__.py +4 -2
  106. eva/vision/models/modules/__init__.py +5 -0
  107. eva/vision/models/modules/semantic_segmentation.py +161 -0
  108. eva/vision/models/networks/__init__.py +1 -2
  109. eva/vision/models/networks/backbones/__init__.py +6 -0
  110. eva/vision/models/networks/backbones/_utils.py +39 -0
  111. eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
  112. eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
  113. eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
  114. eva/vision/models/networks/backbones/pathology/histai.py +46 -0
  115. eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
  116. eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
  117. eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
  118. eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
  119. eva/vision/models/networks/backbones/registry.py +47 -0
  120. eva/vision/models/networks/backbones/timm/__init__.py +5 -0
  121. eva/vision/models/networks/backbones/timm/backbones.py +54 -0
  122. eva/vision/models/networks/backbones/universal/__init__.py +8 -0
  123. eva/vision/models/networks/backbones/universal/vit.py +54 -0
  124. eva/vision/models/networks/decoders/__init__.py +6 -0
  125. eva/vision/models/networks/decoders/decoder.py +7 -0
  126. eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
  127. eva/vision/models/networks/decoders/segmentation/common.py +74 -0
  128. eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
  129. eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
  130. eva/vision/models/wrappers/__init__.py +6 -0
  131. eva/vision/models/wrappers/from_registry.py +48 -0
  132. eva/vision/models/wrappers/from_timm.py +68 -0
  133. eva/vision/utils/colormap.py +77 -0
  134. eva/vision/utils/convert.py +56 -13
  135. eva/vision/utils/io/__init__.py +10 -4
  136. eva/vision/utils/io/image.py +21 -2
  137. eva/vision/utils/io/mat.py +36 -0
  138. eva/vision/utils/io/nifti.py +33 -12
  139. eva/vision/utils/io/text.py +10 -3
  140. kaiko_eva-0.1.0.dist-info/METADATA +553 -0
  141. kaiko_eva-0.1.0.dist-info/RECORD +205 -0
  142. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
  143. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
  144. eva/.DS_Store +0 -0
  145. eva/core/callbacks/writers/embeddings.py +0 -169
  146. eva/core/callbacks/writers/typings.py +0 -23
  147. eva/core/data/datasets/embeddings/__init__.py +0 -13
  148. eva/core/data/datasets/embeddings/classification/__init__.py +0 -10
  149. eva/core/data/datasets/embeddings/classification/embeddings.py +0 -66
  150. eva/core/models/networks/transforms/__init__.py +0 -5
  151. eva/core/models/networks/wrappers/__init__.py +0 -8
  152. eva/vision/models/.DS_Store +0 -0
  153. eva/vision/models/networks/.DS_Store +0 -0
  154. eva/vision/models/networks/postprocesses/__init__.py +0 -5
  155. eva/vision/models/networks/postprocesses/cls.py +0 -25
  156. kaiko_eva-0.0.2.dist-info/METADATA +0 -431
  157. kaiko_eva-0.0.2.dist-info/RECORD +0 -127
  158. /eva/core/models/{networks → wrappers}/_utils.py +0 -0
  159. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
eva/vision/callbacks/loggers/batch/base.py
@@ -0,0 +1,130 @@
+"""Base batch callback logger."""
+
+import abc
+
+from lightning import pytorch as pl
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+from typing_extensions import override
+
+from eva.core.models.modules.typings import INPUT_TENSOR_BATCH
+
+
+class BatchLogger(pl.Callback, abc.ABC):
+    """Logs training and validation batch assets."""
+
+    _batch_idx_to_log: int = 0
+    """The batch index to log."""
+
+    def __init__(
+        self,
+        log_every_n_epochs: int | None = None,
+        log_every_n_steps: int | None = None,
+    ) -> None:
+        """Initializes the callback object.
+
+        Args:
+            log_every_n_epochs: Epoch-wise logging frequency.
+            log_every_n_steps: Step-wise logging frequency.
+        """
+        super().__init__()
+
+        if log_every_n_epochs is None and log_every_n_steps is None:
+            raise ValueError(
+                "Please configure the logging frequency through "
+                "`log_every_n_epochs` or `log_every_n_steps`."
+            )
+        if None not in [log_every_n_epochs, log_every_n_steps]:
+            raise ValueError(
+                "Arguments `log_every_n_epochs` and `log_every_n_steps` "
+                "are mutually exclusive. Please configure one of them."
+            )
+
+        self._log_every_n_epochs = log_every_n_epochs
+        self._log_every_n_steps = log_every_n_steps
+
+    @override
+    def on_train_batch_end(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        outputs: STEP_OUTPUT,
+        batch: INPUT_TENSOR_BATCH,
+        batch_idx: int,
+    ) -> None:
+        if self._skip_logging(trainer, batch_idx if self._log_every_n_epochs else None):
+            return
+
+        self._log_batch(
+            trainer=trainer,
+            batch=batch,
+            outputs=outputs,
+            tag="BatchTrain",
+        )
+
+    @override
+    def on_validation_batch_end(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        outputs: STEP_OUTPUT,
+        batch: INPUT_TENSOR_BATCH,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        if self._skip_logging(trainer, batch_idx):
+            return
+
+        self._log_batch(
+            trainer=trainer,
+            batch=batch,
+            outputs=outputs,
+            tag="BatchValidation",
+        )
+
+    @abc.abstractmethod
+    def _log_batch(
+        self,
+        trainer: pl.Trainer,
+        outputs: STEP_OUTPUT,
+        batch: INPUT_TENSOR_BATCH,
+        tag: str,
+    ) -> None:
+        """Logs the batch data.
+
+        Args:
+            trainer: The trainer.
+            outputs: The output of the train / val step.
+            batch: The data batch.
+            tag: The log tag.
+        """
+
+    def _skip_logging(
+        self,
+        trainer: pl.Trainer,
+        batch_idx: int | None = None,
+    ) -> bool:
+        """Determines whether to skip the logging step or not.
+
+        Args:
+            trainer: The trainer.
+            batch_idx: The batch index.
+
+        Returns:
+            A boolean indicating whether to skip the step execution.
+        """
+        if trainer.global_step in [0, 1]:
+            return False
+
+        skip_due_frequency = any(
+            [
+                (trainer.current_epoch + 1) % (self._log_every_n_epochs or 1) != 0,
+                (trainer.global_step + 1) % (self._log_every_n_steps or 1) != 0,
+            ]
+        )
+
+        conditions = [
+            skip_due_frequency,
+            not trainer.is_global_zero,
+            batch_idx != self._batch_idx_to_log if batch_idx else False,
+        ]
+        return any(conditions)
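
Note: `BatchLogger` leaves only `_log_batch` abstract, so a concrete logger just implements that hook. A minimal hypothetical sketch follows; the `ShapeLogger` name and its print-based "logging" are illustrative only and not part of the package:

from lightning import pytorch as pl
from lightning.pytorch.utilities.types import STEP_OUTPUT
from typing_extensions import override

from eva.core.models.modules.typings import INPUT_TENSOR_BATCH
from eva.vision.callbacks.loggers.batch import base


class ShapeLogger(base.BatchLogger):
    """Prints the tensor shapes of the batch selected for logging."""

    @override
    def _log_batch(
        self,
        trainer: pl.Trainer,
        outputs: STEP_OUTPUT,
        batch: INPUT_TENSOR_BATCH,
        tag: str,
    ) -> None:
        # Mirrors the (data, targets, ...) batch layout used by the package.
        data, targets = batch[0], batch[1]
        print(f"{tag} @ step {trainer.global_step}: {tuple(data.shape)} / {tuple(targets.shape)}")

Since the base constructor requires exactly one logging frequency, it would be instantiated as e.g. `ShapeLogger(log_every_n_steps=100)`.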
eva/vision/callbacks/loggers/batch/segmentation.py
@@ -0,0 +1,188 @@
+"""Segmentation datasets related data loggers."""
+
+from typing import List, Tuple
+
+import torch
+import torchvision
+from lightning import pytorch as pl
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+from torch.nn import functional
+from typing_extensions import override
+
+from eva.core.loggers import log
+from eva.core.models.modules.typings import INPUT_TENSOR_BATCH
+from eva.core.utils import to_cpu
+from eva.vision.callbacks.loggers.batch import base
+from eva.vision.utils import colormap, convert
+
+
+class SemanticSegmentationLogger(base.BatchLogger):
+    """Logs the segmentation batch."""
+
+    def __init__(
+        self,
+        max_samples: int = 10,
+        number_of_images_per_subgrid_row: int = 1,
+        log_images: bool = True,
+        mean: Tuple[float, ...] = (0.0, 0.0, 0.0),
+        std: Tuple[float, ...] = (1.0, 1.0, 1.0),
+        log_every_n_epochs: int | None = None,
+        log_every_n_steps: int | None = None,
+    ) -> None:
+        """Initializes the callback object.
+
+        Args:
+            max_samples: The maximum number of images displayed in the grid.
+            number_of_images_per_subgrid_row: Number of images displayed in each
+                row of each sub-grid (that is images, targets and predictions).
+            log_images: Whether to log the input batch images.
+            mean: The mean of the input images to de-normalize from.
+            std: The std of the input images to de-normalize from.
+            log_every_n_epochs: Epoch-wise logging frequency.
+            log_every_n_steps: Step-wise logging frequency.
+        """
+        super().__init__(
+            log_every_n_epochs=log_every_n_epochs,
+            log_every_n_steps=log_every_n_steps,
+        )
+
+        self._max_samples = max_samples
+        self._number_of_images_per_subgrid_row = number_of_images_per_subgrid_row
+        self._log_images = log_images
+        self._mean = mean
+        self._std = std
+
+    @override
+    def _log_batch(
+        self,
+        trainer: pl.Trainer,
+        outputs: STEP_OUTPUT,
+        batch: INPUT_TENSOR_BATCH,
+        tag: str,
+    ) -> None:
+        predictions = outputs.get("predictions") if isinstance(outputs, dict) else None
+        if predictions is None:
+            raise ValueError("Key `predictions` is missing from the output.")
+
+        data_batch, target_batch = batch[0], batch[1]
+        data, targets, predictions = _subsample_tensors(
+            tensors_stack=[data_batch, target_batch, predictions],
+            max_samples=self._max_samples,
+        )
+        data, targets, predictions = to_cpu([data, targets, predictions])
+        predictions = torch.argmax(predictions, dim=1)
+
+        target_images = list(map(_draw_semantic_mask, targets))
+        prediction_images = list(map(_draw_semantic_mask, predictions))
+        image_groups = [target_images, prediction_images]
+
+        if self._log_images:
+            images = list(map(self._format_image, data))
+            overlay_targets = [
+                _overlay_mask(image, mask) for image, mask in zip(images, targets, strict=False)
+            ]
+            overlay_predictions = [
+                _overlay_mask(image, mask) for image, mask in zip(images, predictions, strict=False)
+            ]
+            image_groups = [images, overlay_targets, overlay_predictions] + image_groups
+
+        image_grid = _make_grid_from_image_groups(
+            image_groups, self._number_of_images_per_subgrid_row
+        )
+
+        log.log_image(
+            trainer.loggers,
+            image=image_grid,
+            tag=tag,
+            step=trainer.global_step,
+        )
+
+    def _format_image(self, image: torch.Tensor) -> torch.Tensor:
+        """Descales an image tensor to a (0, 255) uint8 tensor."""
+        return convert.descale_and_denorm_image(image, mean=self._mean, std=self._std)
+
+
+def _subsample_tensors(
+    tensors_stack: List[torch.Tensor],
+    max_samples: int,
+) -> List[torch.Tensor]:
+    """Sub-samples tensors from a list of tensors in-place.
+
+    Args:
+        tensors_stack: A list of tensors.
+        max_samples: The maximum number of images
+            displayed in the grid.
+
+    Returns:
+        A sub-sample of the input tensors stack.
+    """
+    for i, tensor in enumerate(tensors_stack):
+        tensors_stack[i] = tensor[:max_samples]
+    return tensors_stack
+
+
+def _draw_semantic_mask(tensor: torch.Tensor) -> torch.Tensor:
+    """Draws a semantic mask to an image RGB tensor.
+
+    The input semantic mask is a (H x W) shaped tensor with
+    integer values which represent the pixel class id.
+
+    Args:
+        tensor: An image tensor of range [0., 1.].
+
+    Returns:
+        The image as a tensor of range [0., 255.].
+    """
+    tensor = torch.squeeze(tensor)
+    height, width = tensor.shape[-2], tensor.shape[-1]
+    red, green, blue = torch.zeros((3, height, width), dtype=torch.uint8)
+    for class_id, color in colormap.COLORMAP.items():
+        indices = tensor == class_id
+        red[indices], green[indices], blue[indices] = color
+    return torch.stack([red, green, blue])
+
+
+def _overlay_mask(image: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+    """Overlays a segmentation mask onto an image.
+
+    Args:
+        image: A 3D tensor of shape (C, H, W) representing the image.
+        mask: A 2D tensor of shape (H, W) representing the segmentation mask.
+            Each pixel in the mask corresponds to a class label.
+
+    Returns:
+        A tensor of the same shape as the input image (C, H, W) with the
+        segmentation mask overlaid on top. The output image retains the
+        original color channels but with the mask applied, using the colors
+        from the predefined colormap.
+    """
+    binary_masks = functional.one_hot(mask).permute(2, 0, 1).to(dtype=torch.bool)
+    return torchvision.utils.draw_segmentation_masks(
+        image, binary_masks[1:], alpha=0.65, colors=colormap.COLORS[1:]  # type: ignore
+    )
+
+
+def _make_grid_from_image_groups(
+    image_groups: List[List[torch.Tensor]],
+    number_of_images_per_subgrid_row: int = 2,
+) -> torch.Tensor:
+    """Creates a single image grid from image groups.
+
+    For example, it can combine the input images, targets and predictions into a single image.
+
+    Args:
+        image_groups: A list of lists of image tensors of shape (C x H x W)
+            all of the same size.
+        number_of_images_per_subgrid_row: Number of images displayed in each
+            row of the sub-grid.
+
+    Returns:
+        An image grid as a `torch.Tensor`.
+    """
+    return torchvision.utils.make_grid(
+        [
+            torchvision.utils.make_grid(image_group, nrow=number_of_images_per_subgrid_row)
+            for image_group in image_groups
+        ],
+        nrow=len(image_groups),
+    )
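
Note: as a standard Lightning callback, `SemanticSegmentationLogger` attaches directly to a trainer. A minimal sketch; the ImageNet mean/std values are an assumption for illustration and must match whatever normalization the dataloader actually applies:

from lightning import pytorch as pl

from eva.vision.callbacks.loggers.batch.segmentation import SemanticSegmentationLogger

trainer = pl.Trainer(
    callbacks=[
        SemanticSegmentationLogger(
            max_samples=8,
            log_images=True,
            mean=(0.485, 0.456, 0.406),  # assumed ImageNet-style normalization
            std=(0.229, 0.224, 0.225),
            log_every_n_epochs=1,  # exactly one of the two frequencies may be set
        )
    ],
)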
eva/vision/data/datasets/__init__.py
@@ -1,15 +1,42 @@
 """Vision Datasets API."""
 
-from eva.vision.data.datasets.classification import BACH, CRC, MHIST, PatchCamelyon
-from eva.vision.data.datasets.segmentation import ImageSegmentation, TotalSegmentator2D
+from eva.vision.data.datasets.classification import (
+    BACH,
+    CRC,
+    MHIST,
+    PANDA,
+    Camelyon16,
+    PatchCamelyon,
+    WsiClassificationDataset,
+)
+from eva.vision.data.datasets.segmentation import (
+    BCSS,
+    CoNSeP,
+    EmbeddingsSegmentationDataset,
+    ImageSegmentation,
+    LiTS,
+    MoNuSAC,
+    TotalSegmentator2D,
+)
 from eva.vision.data.datasets.vision import VisionDataset
+from eva.vision.data.datasets.wsi import MultiWsiDataset, WsiDataset
 
 __all__ = [
     "BACH",
+    "BCSS",
     "CRC",
     "MHIST",
-    "ImageSegmentation",
+    "PANDA",
+    "Camelyon16",
     "PatchCamelyon",
+    "WsiClassificationDataset",
+    "CoNSeP",
+    "EmbeddingsSegmentationDataset",
+    "ImageSegmentation",
+    "LiTS",
+    "MoNuSAC",
     "TotalSegmentator2D",
     "VisionDataset",
+    "MultiWsiDataset",
+    "WsiDataset",
 ]
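
Note: every name added to `__all__` above resolves from the same package root as before. A quick sketch of the expanded import surface (names taken from the diff; not an exhaustive list):

from eva.vision.data.datasets import PANDA, Camelyon16, LiTS, MultiWsiDataset, WsiDataset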
eva/vision/data/datasets/_validators.py
@@ -13,7 +13,7 @@ _SUFFIX_ERROR_MESSAGE = "Please verify that the data are properly downloaded and
 def check_dataset_integrity(
     dataset: vision.VisionDataset,
     *,
-    length: int,
+    length: int | None,
     n_classes: int,
     first_and_last_labels: Tuple[str, str],
 ) -> None:
@@ -23,7 +23,7 @@ def check_dataset_integrity(
         ValueError: If the input dataset's values do not
             match the expected ones.
     """
-    if len(dataset) != length:
+    if length and len(dataset) != length:
         raise ValueError(
             f"Dataset's '{dataset.__class__.__qualname__}' length "
             f"({len(dataset)}) does not match the expected one ({length}). "
@@ -57,3 +57,16 @@ def check_dataset_exists(dataset_dir: str, download_available: bool) -> None:
     if download_available:
         error_message += " You can set `download=True` to download the dataset automatically."
     raise FileNotFoundError(error_message)
+
+
+def check_number_of_files(file_paths: List[str], expected_length: int, split: str | None) -> None:
+    """Verifies the number of files in the dataset.
+
+    Raises:
+        ValueError: If the number of files in the dataset does not match the expected one.
+    """
+    if len(file_paths) != expected_length:
+        raise ValueError(
+            f"Expected {expected_length} files for split '{split}', but found {len(file_paths)}. "
+            f"{_SUFFIX_ERROR_MESSAGE}"
+        )
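
Note: the new `check_number_of_files` validator raises when the file count differs from the expected one. A minimal sketch with hypothetical paths:

from eva.vision.data.datasets import _validators

# Hypothetical file list, for illustration only.
file_paths = ["patches/img_0.png", "patches/img_1.png"]
_validators.check_number_of_files(file_paths, expected_length=2, split="train")  # passes
_validators.check_number_of_files(file_paths, expected_length=3, split="train")  # raises ValueError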
eva/vision/data/datasets/classification/__init__.py
@@ -1,8 +1,19 @@
 """Image classification datasets API."""
 
 from eva.vision.data.datasets.classification.bach import BACH
+from eva.vision.data.datasets.classification.camelyon16 import Camelyon16
 from eva.vision.data.datasets.classification.crc import CRC
 from eva.vision.data.datasets.classification.mhist import MHIST
+from eva.vision.data.datasets.classification.panda import PANDA
 from eva.vision.data.datasets.classification.patch_camelyon import PatchCamelyon
+from eva.vision.data.datasets.classification.wsi import WsiClassificationDataset
 
-__all__ = ["BACH", "CRC", "MHIST", "PatchCamelyon"]
+__all__ = [
+    "BACH",
+    "CRC",
+    "MHIST",
+    "PatchCamelyon",
+    "WsiClassificationDataset",
+    "PANDA",
+    "Camelyon16",
+]
eva/vision/data/datasets/classification/bach.py
@@ -3,7 +3,8 @@
 import os
 from typing import Callable, Dict, List, Literal, Tuple
 
-import numpy as np
+import torch
+from torchvision import tv_tensors
 from torchvision.datasets import folder, utils
 from typing_extensions import override
 
@@ -52,8 +53,7 @@ class BACH(base.ImageClassification):
         root: str,
         split: Literal["train", "val"] | None = None,
         download: bool = False,
-        image_transforms: Callable | None = None,
-        target_transforms: Callable | None = None,
+        transforms: Callable | None = None,
     ) -> None:
         """Initialize the dataset.
 
@@ -68,15 +68,10 @@ class BACH(base.ImageClassification):
                 Note that the download will be executed only by additionally
                 calling the :meth:`prepare_data` method and if the data does
                 not yet exist on disk.
-            image_transforms: A function/transform that takes in an image
-                and returns a transformed version.
-            target_transforms: A function/transform that takes in the target
-                and transforms it.
+            transforms: A function/transform which returns a transformed
+                version of the raw data samples.
         """
-        super().__init__(
-            image_transforms=image_transforms,
-            target_transforms=target_transforms,
-        )
+        super().__init__(transforms=transforms)
 
         self._root = root
         self._split = split
@@ -130,14 +125,14 @@ class BACH(base.ImageClassification):
         )
 
     @override
-    def load_image(self, index: int) -> np.ndarray:
+    def load_image(self, index: int) -> tv_tensors.Image:
         image_path, _ = self._samples[self._indices[index]]
-        return io.read_image(image_path)
+        return io.read_image_as_tensor(image_path)
 
     @override
-    def load_target(self, index: int) -> np.ndarray:
+    def load_target(self, index: int) -> torch.Tensor:
         _, target = self._samples[self._indices[index]]
-        return np.asarray(target, dtype=np.int64)
+        return torch.tensor(target, dtype=torch.long)
 
     @override
     def __len__(self) -> int:
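
Note: `BACH` now takes a single joint `transforms` callable in place of the removed `image_transforms`/`target_transforms` pair. A sketch using torchvision v2 transforms, which accept and return the (image, target) pair jointly; the root path and transform choice are placeholders:

import torch
from torchvision.transforms import v2

from eva.vision.data.datasets import BACH

dataset = BACH(
    root="data/bach",  # placeholder path
    split="train",
    transforms=v2.Compose([v2.Resize(224), v2.ToDtype(torch.float32, scale=True)]),
)
image, target, metadata = dataset[0]  # __getitem__ now also returns a metadata dict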
eva/vision/data/datasets/classification/base.py
@@ -3,32 +3,29 @@
 import abc
 from typing import Any, Callable, Dict, List, Tuple
 
-import numpy as np
+import torch
+from torchvision import tv_tensors
 from typing_extensions import override
 
 from eva.vision.data.datasets import vision
 
 
-class ImageClassification(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], abc.ABC):
+class ImageClassification(vision.VisionDataset[Tuple[tv_tensors.Image, torch.Tensor]], abc.ABC):
     """Image classification abstract dataset."""
 
     def __init__(
         self,
-        image_transforms: Callable | None = None,
-        target_transforms: Callable | None = None,
+        transforms: Callable | None = None,
     ) -> None:
         """Initializes the image classification dataset.
 
         Args:
-            image_transforms: A function/transform that takes in an image
-                and returns a transformed version.
-            target_transforms: A function/transform that takes in the target
-                and transforms it.
+            transforms: A function/transform which returns a transformed
+                version of the raw data samples.
         """
         super().__init__()
 
-        self._image_transforms = image_transforms
-        self._target_transforms = target_transforms
+        self._transforms = transforms
 
     @property
     def classes(self) -> List[str] | None:
@@ -38,19 +35,18 @@ class ImageClassification(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], abc.ABC):
     def class_to_idx(self) -> Dict[str, int] | None:
         """Returns a mapping of the class name to its target index."""
 
-    def load_metadata(self, index: int | None) -> Dict[str, Any] | List[Dict[str, Any]] | None:
+    def load_metadata(self, index: int) -> Dict[str, Any] | None:
         """Returns the dataset metadata.
 
         Args:
             index: The index of the data sample to return the metadata of.
-                If `None`, it will return the metadata of the current dataset.
 
         Returns:
             The sample metadata.
         """
 
     @abc.abstractmethod
-    def load_image(self, index: int) -> np.ndarray:
+    def load_image(self, index: int) -> tv_tensors.Image:
         """Returns the `index`'th image sample.
 
        Args:
@@ -61,7 +57,7 @@ class ImageClassification(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], abc.ABC):
         """
 
     @abc.abstractmethod
-    def load_target(self, index: int) -> np.ndarray:
+    def load_target(self, index: int) -> torch.Tensor:
         """Returns the `index`'th target sample.
 
         Args:
@@ -77,14 +73,15 @@ class ImageClassification(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], abc.ABC):
         raise NotImplementedError
 
     @override
-    def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
+    def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
         image = self.load_image(index)
         target = self.load_target(index)
-        return self._apply_transforms(image, target)
+        image, target = self._apply_transforms(image, target)
+        return image, target, self.load_metadata(index) or {}
 
     def _apply_transforms(
-        self, image: np.ndarray, target: np.ndarray
-    ) -> Tuple[np.ndarray, np.ndarray]:
+        self, image: tv_tensors.Image, target: torch.Tensor
+    ) -> Tuple[tv_tensors.Image, torch.Tensor]:
         """Applies the transforms to the provided data and returns them.
 
         Args:
@@ -94,10 +91,6 @@ class ImageClassification(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], abc.ABC):
         Returns:
             A tuple with the image and the target transformed.
         """
-        if self._image_transforms is not None:
-            image = self._image_transforms(image)
-
-        if self._target_transforms is not None:
-            target = self._target_transforms(target)
-
+        if self._transforms is not None:
+            image, target = self._transforms(image, target)
         return image, target
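
Note: the unified `transforms` contract in the base class is any callable that takes the (image, target) pair and returns both. A minimal hand-written example of such a callable (illustrative only, not part of the package):

from typing import Tuple

import torch
from torchvision import tv_tensors


def scale_image_only(
    image: tv_tensors.Image, target: torch.Tensor
) -> Tuple[tv_tensors.Image, torch.Tensor]:
    """Joint transform: rescales the image to [0, 1] and leaves the target untouched."""
    return tv_tensors.Image(image.to(torch.float32) / 255.0), target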