kaiko-eva 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff reflects the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release.

Files changed (131)
  1. eva/core/callbacks/config.py +15 -6
  2. eva/core/callbacks/writers/embeddings/base.py +44 -10
  3. eva/core/cli/setup.py +1 -1
  4. eva/core/data/dataloaders/__init__.py +1 -2
  5. eva/core/data/samplers/classification/balanced.py +24 -12
  6. eva/core/data/samplers/random.py +17 -10
  7. eva/core/interface/interface.py +21 -0
  8. eva/core/loggers/utils/wandb.py +4 -1
  9. eva/core/models/modules/module.py +2 -2
  10. eva/core/models/wrappers/base.py +2 -2
  11. eva/core/models/wrappers/from_function.py +3 -3
  12. eva/core/models/wrappers/from_torchhub.py +9 -7
  13. eva/core/models/wrappers/huggingface.py +4 -5
  14. eva/core/models/wrappers/onnx.py +5 -5
  15. eva/core/trainers/trainer.py +13 -1
  16. eva/core/utils/__init__.py +2 -1
  17. eva/core/utils/distributed.py +12 -0
  18. eva/core/utils/paths.py +14 -0
  19. eva/core/utils/requirements.py +52 -6
  20. eva/language/__init__.py +2 -1
  21. eva/language/callbacks/__init__.py +5 -0
  22. eva/language/callbacks/writers/__init__.py +5 -0
  23. eva/language/callbacks/writers/prediction.py +201 -0
  24. eva/language/data/dataloaders/__init__.py +5 -0
  25. eva/language/data/dataloaders/collate_fn/__init__.py +5 -0
  26. eva/language/data/dataloaders/collate_fn/text.py +57 -0
  27. eva/language/data/datasets/__init__.py +3 -1
  28. eva/language/data/datasets/{language.py → base.py} +1 -1
  29. eva/language/data/datasets/classification/base.py +3 -43
  30. eva/language/data/datasets/classification/pubmedqa.py +36 -4
  31. eva/language/data/datasets/prediction.py +151 -0
  32. eva/language/data/datasets/schemas.py +18 -0
  33. eva/language/data/datasets/text.py +92 -0
  34. eva/language/data/datasets/typings.py +39 -0
  35. eva/language/data/messages.py +60 -0
  36. eva/language/models/__init__.py +15 -11
  37. eva/language/models/modules/__init__.py +2 -2
  38. eva/language/models/modules/language.py +94 -0
  39. eva/language/models/networks/__init__.py +12 -0
  40. eva/language/models/networks/alibaba.py +26 -0
  41. eva/language/models/networks/api/__init__.py +11 -0
  42. eva/language/models/networks/api/anthropic.py +34 -0
  43. eva/language/models/networks/registry.py +5 -0
  44. eva/language/models/typings.py +56 -0
  45. eva/language/models/wrappers/__init__.py +13 -5
  46. eva/language/models/wrappers/base.py +47 -0
  47. eva/language/models/wrappers/from_registry.py +54 -0
  48. eva/language/models/wrappers/huggingface.py +57 -11
  49. eva/language/models/wrappers/litellm.py +91 -46
  50. eva/language/models/wrappers/vllm.py +37 -13
  51. eva/language/utils/__init__.py +2 -1
  52. eva/language/utils/str_to_int_tensor.py +20 -12
  53. eva/language/utils/text/__init__.py +5 -0
  54. eva/language/utils/text/messages.py +113 -0
  55. eva/multimodal/__init__.py +6 -0
  56. eva/multimodal/callbacks/__init__.py +5 -0
  57. eva/multimodal/callbacks/writers/__init__.py +5 -0
  58. eva/multimodal/callbacks/writers/prediction.py +39 -0
  59. eva/multimodal/data/__init__.py +5 -0
  60. eva/multimodal/data/dataloaders/__init__.py +5 -0
  61. eva/multimodal/data/dataloaders/collate_fn/__init__.py +5 -0
  62. eva/multimodal/data/dataloaders/collate_fn/text_image.py +28 -0
  63. eva/multimodal/data/datasets/__init__.py +6 -0
  64. eva/multimodal/data/datasets/base.py +13 -0
  65. eva/multimodal/data/datasets/multiple_choice/__init__.py +5 -0
  66. eva/multimodal/data/datasets/multiple_choice/patch_camelyon.py +80 -0
  67. eva/multimodal/data/datasets/schemas.py +14 -0
  68. eva/multimodal/data/datasets/text_image.py +77 -0
  69. eva/multimodal/data/datasets/typings.py +27 -0
  70. eva/multimodal/models/__init__.py +8 -0
  71. eva/multimodal/models/modules/__init__.py +5 -0
  72. eva/multimodal/models/modules/vision_language.py +56 -0
  73. eva/multimodal/models/networks/__init__.py +14 -0
  74. eva/multimodal/models/networks/alibaba.py +40 -0
  75. eva/multimodal/models/networks/api/__init__.py +11 -0
  76. eva/multimodal/models/networks/api/anthropic.py +34 -0
  77. eva/multimodal/models/networks/others.py +48 -0
  78. eva/multimodal/models/networks/registry.py +5 -0
  79. eva/multimodal/models/typings.py +27 -0
  80. eva/multimodal/models/wrappers/__init__.py +13 -0
  81. eva/multimodal/models/wrappers/base.py +48 -0
  82. eva/multimodal/models/wrappers/from_registry.py +54 -0
  83. eva/multimodal/models/wrappers/huggingface.py +193 -0
  84. eva/multimodal/models/wrappers/litellm.py +58 -0
  85. eva/multimodal/utils/__init__.py +1 -0
  86. eva/multimodal/utils/batch/__init__.py +5 -0
  87. eva/multimodal/utils/batch/unpack.py +11 -0
  88. eva/multimodal/utils/image/__init__.py +5 -0
  89. eva/multimodal/utils/image/encode.py +28 -0
  90. eva/multimodal/utils/text/__init__.py +1 -0
  91. eva/multimodal/utils/text/messages.py +79 -0
  92. eva/vision/data/datasets/classification/breakhis.py +5 -8
  93. eva/vision/data/datasets/classification/panda.py +12 -5
  94. eva/vision/data/datasets/classification/patch_camelyon.py +8 -6
  95. eva/vision/data/datasets/segmentation/btcv.py +1 -1
  96. eva/vision/data/datasets/segmentation/consep.py +1 -1
  97. eva/vision/data/datasets/segmentation/lits17.py +1 -1
  98. eva/vision/data/datasets/segmentation/monusac.py +15 -6
  99. eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +1 -1
  100. eva/vision/data/transforms/__init__.py +2 -1
  101. eva/vision/data/transforms/base/__init__.py +2 -1
  102. eva/vision/data/transforms/base/monai.py +2 -2
  103. eva/vision/data/transforms/base/torchvision.py +33 -0
  104. eva/vision/data/transforms/common/squeeze.py +6 -3
  105. eva/vision/data/transforms/croppad/crop_foreground.py +8 -7
  106. eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +6 -5
  107. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -5
  108. eva/vision/data/transforms/croppad/rand_spatial_crop.py +8 -7
  109. eva/vision/data/transforms/croppad/spatial_pad.py +6 -6
  110. eva/vision/data/transforms/intensity/rand_scale_intensity.py +3 -3
  111. eva/vision/data/transforms/intensity/rand_shift_intensity.py +3 -3
  112. eva/vision/data/transforms/intensity/scale_intensity_ranged.py +5 -5
  113. eva/vision/data/transforms/spatial/__init__.py +2 -1
  114. eva/vision/data/transforms/spatial/flip.py +8 -7
  115. eva/vision/data/transforms/spatial/functional/__init__.py +5 -0
  116. eva/vision/data/transforms/spatial/functional/resize.py +26 -0
  117. eva/vision/data/transforms/spatial/resize.py +63 -0
  118. eva/vision/data/transforms/spatial/rotate.py +8 -7
  119. eva/vision/data/transforms/spatial/spacing.py +7 -6
  120. eva/vision/data/transforms/utility/ensure_channel_first.py +6 -6
  121. eva/vision/models/networks/backbones/universal/vit.py +24 -0
  122. eva/vision/models/wrappers/from_registry.py +6 -5
  123. eva/vision/models/wrappers/from_timm.py +6 -4
  124. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/METADATA +17 -3
  125. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/RECORD +128 -66
  126. eva/core/data/dataloaders/collate_fn/__init__.py +0 -5
  127. eva/core/data/dataloaders/collate_fn/collate.py +0 -24
  128. eva/language/models/modules/text.py +0 -85
  129. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/WHEEL +0 -0
  130. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/entry_points.txt +0 -0
  131. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/licenses/LICENSE +0 -0
eva/multimodal/models/wrappers/huggingface.py (new file)
@@ -0,0 +1,193 @@
+ """HuggingFace Vision-Language Model Wrapper."""
+
+ import functools
+ from typing import Any, Callable, Dict, List
+
+ import torch
+ import transformers
+ from loguru import logger
+ from torch import nn
+ from typing_extensions import override
+
+ from eva.language.models.typings import ModelOutput, TextBatch
+ from eva.language.utils.text import messages as language_message_utils
+ from eva.multimodal.models.typings import TextImageBatch
+ from eva.multimodal.models.wrappers import base
+ from eva.multimodal.utils.batch import unpack_batch
+ from eva.multimodal.utils.text import messages as message_utils
+
+
+ class HuggingFaceModel(base.VisionLanguageModel):
+     """Lightweight wrapper for Huggingface VLMs.
+
+     Args:
+         model_name_or_path: The name of the model to use.
+         model_class: The class of the model to use.
+         model_kwargs: Additional model arguments.
+         processor_kwargs: Additional processor arguments.
+         generation_kwargs: Additional generation arguments.
+     """
+
+     _default_generation_kwargs = {
+         "temperature": 0.0,
+         "max_new_tokens": 1024,
+         "do_sample": False,
+         "top_p": 1.0,
+     }
+     """Default HF model parameters for evaluation."""
+
+     def __init__(
+         self,
+         model_name_or_path: str,
+         model_class: str,
+         model_kwargs: Dict[str, Any] | None = None,
+         system_prompt: str | None = None,
+         processor_kwargs: Dict[str, Any] | None = None,
+         generation_kwargs: Dict[str, Any] | None = None,
+         image_key: str = "image",
+     ):
+         """Initialize the HuggingFace model wrapper.
+
+         Args:
+             model_name_or_path: The name or path of the model to use.
+             model_class: The class of the model to use.
+             model_kwargs: Additional model arguments.
+             system_prompt: System prompt to use.
+             processor_kwargs: Additional processor arguments.
+             generation_kwargs: Additional generation arguments.
+             image_key: The key used for image inputs in the chat template.
+         """
+         super().__init__(system_prompt=system_prompt)
+
+         self.model_name_or_path = model_name_or_path
+         self.model_kwargs = model_kwargs or {}
+         self.base_model_class = model_class
+         self.processor_kwargs = processor_kwargs or {}
+         self.generation_kwargs = self._default_generation_kwargs | (generation_kwargs or {})
+         self.image_key = image_key
+
+         self.processor = self.load_processor()
+         self.model = self.load_model()
+
+     @override
+     def format_inputs(self, batch: TextImageBatch | TextBatch) -> Dict[str, torch.Tensor]:
+         """Formats inputs for HuggingFace models.
+
+         Args:
+             batch: A batch of text and image inputs.
+
+         Returns:
+             A dictionary produced by the provided processor following a format like:
+             {
+                 "input_ids": ...,
+                 "attention_mask": ...,
+                 "pixel_values": ...
+             }
+         """
+         message_batch, image_batch, _, _ = unpack_batch(batch)
+         with_images = image_batch is not None
+
+         message_batch = language_message_utils.batch_insert_system_message(
+             message_batch, self.system_message
+         )
+         message_batch = list(map(language_message_utils.combine_system_messages, message_batch))
+
+         if self.processor.chat_template is not None:  # type: ignore
+             templated_text = [
+                 self.processor.apply_chat_template(  # type: ignore
+                     message,
+                     add_generation_prompt=True,
+                     tokenize=False,
+                 )
+                 for message in map(
+                     functools.partial(
+                         message_utils.format_huggingface_message,
+                         with_images=with_images,
+                     ),
+                     message_batch,
+                 )
+             ]
+         else:
+             raise NotImplementedError("Currently only chat models are supported.")
+
+         processor_inputs = {
+             "text": templated_text,
+             "return_tensors": "pt",
+             **self.processor_kwargs,
+         }
+
+         if with_images:
+             processor_inputs[self.image_key] = [[image] for image in image_batch]
+
+         return self.processor(**processor_inputs).to(self.model.device)  # type: ignore
+
+     @override
+     def model_forward(self, batch: Dict[str, torch.Tensor]) -> ModelOutput:
+         """Generates text output from the model. Is called by the `generate` method.
+
+         Args:
+             batch: A dictionary containing the input data, which may include:
+                 - "text": List of messages formatted for the model.
+                 - "image": List of image tensors.
+
+         Returns:
+             A dictionary containing the processed input and the model's output.
+         """
+         output_ids = self.model.generate(**batch, **self.generation_kwargs)  # type: ignore
+
+         return ModelOutput(
+             generated_text=self._decode_output(output_ids, batch["input_ids"].shape[-1]),
+             input_ids=batch.get("input_ids"),
+             output_ids=output_ids,
+             attention_mask=batch.get("attention_mask"),
+         )
+
+     @override
+     def load_model(self) -> nn.Module:
+         """Setting up the model. Used for delayed model initialization.
+
+         Raises:
+             ValueError: If the model class is not found in transformers or if the model
+                 does not support gradient checkpointing but it is enabled.
+         """
+         logger.info(f"Configuring model: {self.model_name_or_path}")
+         if hasattr(transformers, self.base_model_class):
+             model_class = getattr(transformers, self.base_model_class)
+         else:
+             raise ValueError(f"Model class {self.base_model_class} not found in transformers")
+
+         model = model_class.from_pretrained(self.model_name_or_path, **self.model_kwargs)
+
+         if not hasattr(model, "generate"):
+             raise ValueError(f"Model {self.model_name_or_path} does not support generation. ")
+
+         return model
+
+     def load_processor(self) -> Callable:
+         """Initialize the processor."""
+         return transformers.AutoProcessor.from_pretrained(
+             self.processor_kwargs.pop("model_name_or_path", self.model_name_or_path),
+             **self.processor_kwargs,
+         )
+
+     def _decode_output(self, output: torch.Tensor, instruction_length: int) -> List[str]:
+         """Decode the model's batch output to text.
+
+         Args:
+             output: The raw output from the model.
+             instruction_length: The length of the instruction in the input.
+
+         Returns:
+             A list of decoded text responses.
+         """
+         decoded_input = self.processor.batch_decode(  # type: ignore
+             output[:, :instruction_length], skip_special_tokens=True
+         )
+         decoded_output = self.processor.batch_decode(  # type: ignore
+             output[:, instruction_length:], skip_special_tokens=True
+         )
+
+         logger.debug(f"Decoded input: {decoded_input}")
+         logger.debug(f"Decoded output: {decoded_output}")
+
+         return decoded_output
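
For orientation, a minimal usage sketch of the new wrapper follows. The import path matches the module layout above; the checkpoint name, the `torch_dtype`/`device_map` arguments, and the overridden generation settings are illustrative assumptions, not values taken from this diff.

# Hedged usage sketch (not part of the release): constructing the wrapper for a
# chat-capable VLM. `model_class` must name an attribute of `transformers`, as
# enforced by `load_model` above; everything else here is an assumption.
from eva.multimodal.models.wrappers.huggingface import HuggingFaceModel

vlm = HuggingFaceModel(
    model_name_or_path="Qwen/Qwen2-VL-2B-Instruct",  # hypothetical checkpoint
    model_class="Qwen2VLForConditionalGeneration",  # resolved via getattr(transformers, ...)
    model_kwargs={"torch_dtype": "auto", "device_map": "auto"},
    generation_kwargs={"max_new_tokens": 64},  # merged over the defaults above
    system_prompt="You are a concise pathology assistant.",
)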

eva/multimodal/models/wrappers/litellm.py (new file)
@@ -0,0 +1,58 @@
+ """LiteLLM vision-language model wrapper."""
+
+ import logging
+ from typing import Any, Dict, List
+
+ from typing_extensions import override
+
+ from eva.language.models import wrappers as language_wrappers
+ from eva.language.models.typings import ModelOutput
+ from eva.language.utils.text import messages as language_message_utils
+ from eva.multimodal.models.typings import TextImageBatch
+ from eva.multimodal.models.wrappers import base
+ from eva.multimodal.utils.batch import unpack_batch
+ from eva.multimodal.utils.text import messages as message_utils
+
+
+ class LiteLLMModel(base.VisionLanguageModel):
+     """Wrapper class for LiteLLM vision-language models."""
+
+     def __init__(
+         self,
+         model_name: str,
+         model_kwargs: Dict[str, Any] | None = None,
+         system_prompt: str | None = None,
+         log_level: int | None = logging.INFO,
+     ):
+         """Initialize the LiteLLM Wrapper.
+
+         Args:
+             model_name: The name of the model to use.
+             model_kwargs: Additional keyword arguments to pass during
+                 generation (e.g., `temperature`, `max_tokens`).
+             system_prompt: The system prompt to use (optional).
+             log_level: Optional logging level for LiteLLM. Defaults to WARNING.
+         """
+         super().__init__(system_prompt=system_prompt)
+
+         self.language_model = language_wrappers.LiteLLMModel(
+             model_name=model_name,
+             model_kwargs=model_kwargs,
+             system_prompt=system_prompt,
+             log_level=log_level,
+         )
+
+     @override
+     def format_inputs(self, batch: TextImageBatch) -> List[List[Dict[str, Any]]]:
+         message_batch, image_batch, _, _ = unpack_batch(batch)
+
+         message_batch = language_message_utils.batch_insert_system_message(
+             message_batch, self.system_message
+         )
+         message_batch = list(map(language_message_utils.combine_system_messages, message_batch))
+
+         return list(map(message_utils.format_litellm_message, message_batch, image_batch))
+
+     @override
+     def model_forward(self, batch: List[List[Dict[str, Any]]]) -> ModelOutput:
+         return self.language_model.model_forward(batch)
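
A comparable sketch for the LiteLLM-backed wrapper; the model identifier is a hypothetical example, and generation is delegated to the `eva.language` LiteLLM wrapper constructed in `__init__` above.

# Hedged usage sketch (not part of the release); the model id is illustrative.
from eva.multimodal.models.wrappers.litellm import LiteLLMModel

vlm = LiteLLMModel(
    model_name="anthropic/claude-3-5-sonnet-20241022",  # hypothetical LiteLLM model id
    model_kwargs={"temperature": 0.0, "max_tokens": 128},
    system_prompt="Answer with exactly one word: tumor or no_tumor.",
)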

eva/multimodal/utils/__init__.py (new file)
@@ -0,0 +1 @@
+ """Multimodal utilities API."""

eva/multimodal/utils/batch/__init__.py (new file)
@@ -0,0 +1,5 @@
+ """Multimodal batch utilities API."""
+
+ from eva.multimodal.utils.batch.unpack import unpack_batch
+
+ __all__ = ["unpack_batch"]

eva/multimodal/utils/batch/unpack.py (new file)
@@ -0,0 +1,11 @@
+ """Unpack batch utility function."""
+
+ from eva.language.models.typings import TextBatch
+ from eva.multimodal.models.typings import TextImageBatch
+
+
+ def unpack_batch(batch: TextImageBatch | TextBatch) -> tuple:
+     """Unpacks a TextImageBatch or TextBatch into its components."""
+     if isinstance(batch, TextImageBatch):
+         return batch.text, batch.image, batch.target, batch.metadata
+     return batch.text, None, batch.target, batch.metadata

eva/multimodal/utils/image/__init__.py (new file)
@@ -0,0 +1,5 @@
+ """Multimodal image utilities API."""
+
+ from eva.multimodal.utils.image.encode import encode_image
+
+ __all__ = ["encode_image"]

eva/multimodal/utils/image/encode.py (new file)
@@ -0,0 +1,28 @@
+ """Image encoding utilities."""
+
+ import base64
+ import io
+ from typing import Literal
+
+ from torchvision import tv_tensors
+ from torchvision.transforms.v2 import functional as F
+
+
+ def encode_image(image: tv_tensors.Image, encoding: Literal["base64"]) -> str:
+     """Encodes an image tensor into a string format.
+
+     Args:
+         image: The image tensor to encode.
+         encoding: The encoding format to use. Currently only supports "base64".
+
+     Returns:
+         An encoded string representation of the image.
+     """
+     match encoding:
+         case "base64":
+             image_bytes = io.BytesIO()
+             F.to_pil_image(image).save(image_bytes, format="PNG", optimize=True)
+             image_bytes.seek(0)
+             return base64.b64encode(image_bytes.getvalue()).decode("utf-8")
+         case _:
+             raise ValueError(f"Unsupported encoding type: {encoding}. Supported: 'base64'")

eva/multimodal/utils/text/__init__.py (new file)
@@ -0,0 +1 @@
+ """Multimodal text utilities API."""

eva/multimodal/utils/text/messages.py (new file)
@@ -0,0 +1,79 @@
+ """Message formatting utilities for multimodal models."""
+
+ from typing import Any, Dict, List
+
+ from torchvision import tv_tensors
+
+ from eva.language import utils as language_utils
+ from eva.language.data.messages import MessageSeries, Role
+ from eva.multimodal.utils import image as image_utils
+
+
+ def format_huggingface_message(
+     message: MessageSeries, with_images: bool = False
+ ) -> List[Dict[str, Any]]:
+     """Formats a message series into a format suitable for Huggingface models."""
+     if not with_images:
+         return language_utils.format_chat_message(message)
+
+     formatted_message = []
+     for item in message:
+         if item.role == Role.SYSTEM:
+             formatted_message += language_utils.format_chat_message([item])
+         else:
+             formatted_message.append(
+                 {
+                     "role": item.role,
+                     "content": [
+                         {
+                             "type": "text",
+                             "text": str(item.content),
+                         },
+                         {"type": "image"},
+                     ],
+                 }
+             )
+     return formatted_message
+
+
+ def format_litellm_message(
+     message: MessageSeries, image: tv_tensors.Image | None
+ ) -> List[Dict[str, Any]]:
+     """Format a message series for LiteLLM API.
+
+     Args:
+         message: The message series to format.
+         image: Optional image to include in the message.
+
+     Returns:
+         A list of formatted message dictionaries.
+     """
+     if image is None:
+         return language_utils.format_chat_message(message)
+
+     formatted_message = []
+     for item in message:
+         if item.role == Role.SYSTEM:
+             formatted_message += language_utils.format_chat_message([item])
+         else:
+             formatted_message.append(
+                 {
+                     "role": item.role,
+                     "content": [
+                         {
+                             "type": "text",
+                             "text": str(item.content),
+                         },
+                         {
+                             "type": "image_url",
+                             "image_url": {
+                                 "url": (
+                                     f"data:image/png;base64,"
+                                     f"{image_utils.encode_image(image, encoding='base64')}"
+                                 )
+                             },
+                         },
+                     ],
+                 }
+             )
+     return formatted_message
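
The structure `format_litellm_message` emits for a single non-system turn with an image attached looks roughly like the following; the role value comes from `item.role` (rendered here as a plain "user" string) and the question text and payload are placeholders.

# Illustrative output shape, derived from the function above; values are placeholders.
formatted_turn = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Does this patch contain tumor tissue?"},
        {
            "type": "image_url",
            "image_url": {"url": "data:image/png;base64,<encoded PNG bytes>"},
        },
    ],
}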

eva/vision/data/datasets/classification/breakhis.py
@@ -101,11 +101,6 @@ class BreaKHis(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
      def class_to_idx(self) -> Dict[str, int]:
          return {label: index for index, label in enumerate(self.classes)}
  
-     @property
-     def _dataset_path(self) -> str:
-         """Returns the path of the image data of the dataset."""
-         return os.path.join(self._root, "BreaKHis_v1", "histology_slides")
-
      @functools.cached_property
      def _image_files(self) -> List[str]:
          """Return the list of image files in the dataset.
@@ -115,14 +110,14 @@ class BreaKHis(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
          """
          image_files = []
          for magnification in self._magnifications:
-             files_pattern = os.path.join(self._dataset_path, f"**/{magnification}", "*.png")
+             files_pattern = os.path.join(self._root, f"**/{magnification}", "*.png")
              image_files.extend(list(glob.glob(files_pattern, recursive=True)))
          return sorted(image_files)
  
      @override
      def filename(self, index: int) -> str:
          image_path = self._image_files[self._indices[index]]
-         return os.path.relpath(image_path, self._dataset_path)
+         return os.path.relpath(image_path, self._root)
  
      @override
      def prepare_data(self) -> None:
@@ -136,6 +131,8 @@ class BreaKHis(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
  
      @override
      def validate(self) -> None:
+         if not os.path.exists(self._root):
+             raise RuntimeError(f"Dataset not found at {self._root}.")
          _validators.check_dataset_integrity(
              self,
              length=self._expected_dataset_lengths[self._split],
@@ -164,7 +161,7 @@ class BreaKHis(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
      def _download_dataset(self) -> None:
          """Downloads the dataset."""
          for resource in self._resources:
-             if os.path.isdir(self._dataset_path):
+             if os.path.isdir(self._root):
                  continue
  
              self._print_license()

eva/vision/data/datasets/classification/panda.py
@@ -12,6 +12,7 @@ from torchvision.datasets import utils
  from torchvision.transforms.v2 import functional
  from typing_extensions import override
  
+ from eva.core import utils as core_utils
  from eva.core.data import splitting
  from eva.vision.data.datasets import _validators, structs, vision, wsi
  from eva.vision.data.wsi.patching import samplers
@@ -50,6 +51,7 @@ class PANDA(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
          image_transforms: Callable | None = None,
          coords_path: str | None = None,
          seed: int = 42,
+         download_dir: str | None = None,
      ) -> None:
          """Initializes the dataset.
  
@@ -64,10 +66,13 @@ class PANDA(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
              image_transforms: Transforms to apply to the extracted image patches.
              coords_path: File path to save the patch coordinates as .csv.
              seed: Random seed for reproducibility.
+             download_dir: Directory to download the dataset resources to. If None,
+                 defaults to eva's home directory.
          """
          self._split = split
          self._root = root
          self._seed = seed
+         self._download_dir = download_dir or os.path.join(core_utils.home_dir(), "data", "panda")
  
          self._download_resources()
  
@@ -92,7 +97,7 @@ class PANDA(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
      @functools.cached_property
      def annotations(self) -> pd.DataFrame:
          """Loads the dataset labels."""
-         path = os.path.join(self._root, "train_with_noisy_labels.csv")
+         path = os.path.join(self._download_dir, "train_with_noisy_labels.csv")
          return pd.read_csv(path, index_col="image_id")
  
      @override
@@ -100,14 +105,16 @@ class PANDA(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
          _validators.check_dataset_exists(self._root, False)
  
          if not os.path.isdir(os.path.join(self._root, "train_images")):
-             raise FileNotFoundError("'train_images' directory not found in the root folder.")
-         if not os.path.isfile(os.path.join(self._root, "train_with_noisy_labels.csv")):
-             raise FileNotFoundError("'train.csv' file not found in the root folder.")
+             raise FileNotFoundError(f"'train_images' dir not found in folder: {self._root}")
+         if not os.path.isfile(os.path.join(self._download_dir, "train_with_noisy_labels.csv")):
+             raise FileNotFoundError(
+                 f"'train_with_noisy_labels.csv' file not found in folder: {self._download_dir}"
+             )
  
      def _download_resources(self) -> None:
          """Downloads the dataset resources."""
          for resource in self._resources:
-             utils.download_url(resource.url, self._root, resource.filename, resource.md5)
+             utils.download_url(resource.url, self._download_dir, resource.filename, resource.md5)
  
      @override
      def validate(self) -> None:

eva/vision/data/datasets/classification/patch_camelyon.py
@@ -61,6 +61,13 @@ class PatchCamelyon(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
      ]
      """Test resources."""
  
+     _expected_length = {
+         "train": 262144,
+         "val": 32768,
+         "test": 32768,
+     }
+     """Expected dataset length for each split."""
+
      _license: str = (
          "Creative Commons Zero v1.0 Universal (https://choosealicense.com/licenses/cc0-1.0/)"
      )
@@ -113,14 +120,9 @@ class PatchCamelyon(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
  
      @override
      def validate(self) -> None:
-         expected_length = {
-             "train": 262144,
-             "val": 32768,
-             "test": 32768,
-         }
          _validators.check_dataset_integrity(
              self,
-             length=expected_length.get(self._split, 0),
+             length=self._expected_length.get(self._split, 0),
              n_classes=2,
              first_and_last_labels=("no_tumor", "tumor"),
          )

eva/vision/data/datasets/segmentation/btcv.py
@@ -106,7 +106,7 @@ class BTCV(VisionDataset[eva_tv_tensors.Volume, tv_tensors.Mask]):
  
      @override
      def validate(self) -> None:
-         requirements.check_dependencies(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})
+         requirements.check_min_versions(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})
  
          def _valid_sample(index: int) -> bool:
              """Indicates if the sample files exist and are reachable."""

eva/vision/data/datasets/segmentation/consep.py
@@ -108,7 +108,7 @@ class CoNSeP(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
              n_classes=5,
              first_and_last_labels=((self.classes[0], self.classes[-1])),
          )
-         n_expected = self._expected_dataset_lengths[None]
+         n_expected = self._expected_dataset_lengths[self._split]
          if len(self._file_paths) != n_expected:
              raise ValueError(
                  f"Expected {n_expected} images, found {len(self._file_paths)} in {self._root}."

eva/vision/data/datasets/segmentation/lits17.py
@@ -123,7 +123,7 @@ class LiTS17(VisionDataset[eva_tv_tensors.Volume, tv_tensors.Mask]):
  
      @override
      def validate(self) -> None:
-         requirements.check_dependencies(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})
+         requirements.check_min_versions(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})
  
          def _valid_sample(index: int) -> bool:
              """Indicates if the sample files exist and are reachable."""

eva/vision/data/datasets/segmentation/monusac.py
@@ -15,6 +15,7 @@ from torchvision import tv_tensors
  from torchvision.datasets import utils
  from typing_extensions import override
  
+ from eva.core import utils as core_utils
  from eva.core.utils.progress_bar import tqdm
  from eva.vision.data.datasets import _validators, structs, vision
  from eva.vision.utils import io
@@ -55,6 +56,7 @@ class MoNuSAC(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
          root: str,
          split: Literal["train", "test"],
          export_masks: bool = True,
+         processed_dir: str | None = None,
          download: bool = False,
          transforms: Callable | None = None,
      ) -> None:
@@ -66,6 +68,8 @@ class MoNuSAC(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
              split: Dataset split to use.
              export_masks: Whether to export, save and use the semantic label masks
                  from disk.
+             processed_dir: Directory where to store the processed masks.
+                 Only used if `export_masks` is `True`.
              download: Whether to download the data for the specified split.
                  Note that the download will be executed only by additionally
                  calling the :meth:`prepare_data` method and if the data does not
@@ -79,6 +83,9 @@ class MoNuSAC(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
          self._split = split
          self._export_masks = export_masks
          self._download = download
+         self._processed_dir = processed_dir or os.path.join(
+             core_utils.home_dir(), "data", "processed", "monusac"
+         )
  
      @property
      @override
@@ -155,10 +162,7 @@ class MoNuSAC(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
  
      def _export_semantic_label_masks(self) -> None:
          """Export semantic label masks to disk."""
-         mask_files = [
-             (index, filename.replace(".tif", ".npy"))
-             for index, filename in enumerate(self._image_files)
-         ]
+         mask_files = [(i, self._processed_filename(i)) for i in range(len(self._image_files))]
          to_export = filter(lambda x: not os.path.isfile(x[1]), mask_files)
          for sample_index, filename in tqdm(
              list(to_export),
@@ -166,6 +170,7 @@ class MoNuSAC(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
              leave=False,
          ):
              semantic_labels = self._get_semantic_mask(sample_index)
+             os.makedirs(os.path.dirname(filename), exist_ok=True)
              np.save(filename, semantic_labels)
  
      def _load_semantic_mask_file(self, index: int) -> npt.NDArray[Any]:
@@ -177,8 +182,7 @@ class MoNuSAC(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
          Returns:
              Loaded mask as a numpy array.
          """
-         mask_filename = self._image_files[index].replace(".tif", ".npy")
-         return np.load(mask_filename)
+         return np.load(self._processed_filename(index))
  
      def _get_semantic_mask(self, index: int) -> npt.NDArray[Any]:
          """Builds and loads the semantic label mask from the XML annotations.
@@ -216,6 +220,11 @@ class MoNuSAC(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
  
          return semantic_labels
  
+     def _processed_filename(self, index: int) -> str:
+         """Returns the path of the processed mask for a given index."""
+         relative_path = os.path.relpath(self._image_files[index], self._root)
+         return os.path.join(self._processed_dir, relative_path).replace(".tif", ".npy")
+
      def _download_dataset(self) -> None:
          """Downloads the dataset."""
          self._print_license()

eva/vision/data/datasets/segmentation/msd_task7_pancreas.py
@@ -95,7 +95,7 @@ class MSDTask7Pancreas(VisionDataset[eva_tv_tensors.Volume, tv_tensors.Mask]):
  
      @override
      def validate(self) -> None:
-         requirements.check_dependencies(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})
+         requirements.check_min_versions(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})
  
          def _valid_sample(index: int) -> bool:
              """Indicates if the sample files exist and are reachable."""

eva/vision/data/transforms/__init__.py
@@ -13,10 +13,11 @@ from eva.vision.data.transforms.intensity import (
      RandShiftIntensity,
      ScaleIntensityRange,
  )
- from eva.vision.data.transforms.spatial import RandFlip, RandRotate90, Spacing
+ from eva.vision.data.transforms.spatial import RandFlip, RandRotate90, Resize, Spacing
  from eva.vision.data.transforms.utility import EnsureChannelFirst
  
  __all__ = [
+     "Resize",
      "ResizeAndCrop",
      "Squeeze",
      "CropForeground",

eva/vision/data/transforms/base/__init__.py
@@ -1,5 +1,6 @@
  """Base classes for transforms."""
  
  from eva.vision.data.transforms.base.monai import RandomMonaiTransform
+ from eva.vision.data.transforms.base.torchvision import TorchvisionTransformV2
  
- __all__ = ["RandomMonaiTransform"]
+ __all__ = ["RandomMonaiTransform", "TorchvisionTransformV2"]