kaiko-eva 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.

Note: this release of kaiko-eva has been flagged as potentially problematic by the registry.

Files changed (54)
  1. eva/core/callbacks/config.py +11 -6
  2. eva/core/callbacks/writers/embeddings/base.py +44 -10
  3. eva/core/data/samplers/classification/balanced.py +24 -12
  4. eva/core/loggers/utils/wandb.py +4 -1
  5. eva/core/trainers/trainer.py +11 -1
  6. eva/core/utils/__init__.py +2 -1
  7. eva/core/utils/distributed.py +12 -0
  8. eva/core/utils/paths.py +14 -0
  9. eva/core/utils/requirements.py +52 -6
  10. eva/language/callbacks/writers/prediction.py +44 -19
  11. eva/language/data/datasets/classification/pubmedqa.py +1 -1
  12. eva/language/models/modules/language.py +7 -6
  13. eva/language/models/typings.py +19 -2
  14. eva/language/models/wrappers/base.py +4 -4
  15. eva/language/models/wrappers/huggingface.py +14 -4
  16. eva/language/models/wrappers/litellm.py +14 -4
  17. eva/multimodal/models/modules/vision_language.py +6 -5
  18. eva/multimodal/models/networks/alibaba.py +1 -0
  19. eva/multimodal/models/networks/others.py +2 -1
  20. eva/multimodal/models/wrappers/base.py +4 -3
  21. eva/multimodal/models/wrappers/huggingface.py +26 -13
  22. eva/multimodal/models/wrappers/litellm.py +4 -2
  23. eva/multimodal/utils/batch/__init__.py +5 -0
  24. eva/multimodal/utils/batch/unpack.py +11 -0
  25. eva/vision/data/datasets/classification/breakhis.py +5 -8
  26. eva/vision/data/datasets/classification/panda.py +12 -5
  27. eva/vision/data/datasets/segmentation/btcv.py +1 -1
  28. eva/vision/data/datasets/segmentation/consep.py +1 -1
  29. eva/vision/data/datasets/segmentation/lits17.py +1 -1
  30. eva/vision/data/datasets/segmentation/monusac.py +15 -6
  31. eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +1 -1
  32. eva/vision/data/transforms/base/__init__.py +2 -1
  33. eva/vision/data/transforms/base/monai.py +2 -2
  34. eva/vision/data/transforms/base/torchvision.py +33 -0
  35. eva/vision/data/transforms/common/squeeze.py +6 -3
  36. eva/vision/data/transforms/croppad/crop_foreground.py +8 -7
  37. eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +6 -5
  38. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -5
  39. eva/vision/data/transforms/croppad/rand_spatial_crop.py +8 -7
  40. eva/vision/data/transforms/croppad/spatial_pad.py +6 -6
  41. eva/vision/data/transforms/intensity/rand_scale_intensity.py +3 -3
  42. eva/vision/data/transforms/intensity/rand_shift_intensity.py +3 -3
  43. eva/vision/data/transforms/intensity/scale_intensity_ranged.py +5 -5
  44. eva/vision/data/transforms/spatial/flip.py +8 -7
  45. eva/vision/data/transforms/spatial/resize.py +5 -4
  46. eva/vision/data/transforms/spatial/rotate.py +8 -7
  47. eva/vision/data/transforms/spatial/spacing.py +7 -6
  48. eva/vision/data/transforms/utility/ensure_channel_first.py +6 -6
  49. eva/vision/models/networks/backbones/universal/vit.py +24 -0
  50. {kaiko_eva-0.4.0.dist-info → kaiko_eva-0.4.1.dist-info}/METADATA +8 -2
  51. {kaiko_eva-0.4.0.dist-info → kaiko_eva-0.4.1.dist-info}/RECORD +54 -49
  52. {kaiko_eva-0.4.0.dist-info → kaiko_eva-0.4.1.dist-info}/WHEEL +0 -0
  53. {kaiko_eva-0.4.0.dist-info → kaiko_eva-0.4.1.dist-info}/entry_points.txt +0 -0
  54. {kaiko_eva-0.4.0.dist-info → kaiko_eva-0.4.1.dist-info}/licenses/LICENSE +0 -0
eva/language/models/wrappers/litellm.py

@@ -16,7 +16,7 @@ from litellm.exceptions import (
 from loguru import logger
 from typing_extensions import override

-from eva.language.models.typings import TextBatch
+from eva.language.models.typings import ModelOutput, TextBatch
 from eva.language.models.wrappers import base
 from eva.language.utils.text import messages as message_utils

@@ -32,6 +32,14 @@ RETRYABLE_ERRORS = (
 class LiteLLMModel(base.LanguageModel):
     """Wrapper class for LiteLLM language models."""

+    _default_model_kwargs = {
+        "temperature": 0.0,
+        "max_completion_tokens": 1024,
+        "top_p": 1.0,
+        "seed": 42,
+    }
+    """Default API model parameters for evaluation."""
+
     def __init__(
         self,
         model_name: str,
@@ -51,9 +59,10 @@ class LiteLLMModel(base.LanguageModel):
         super().__init__(system_prompt=system_prompt)

         self.model_name = model_name
-        self.model_kwargs = model_kwargs or {}
+        self.model_kwargs = self._default_model_kwargs | (model_kwargs or {})

         litellm.suppress_debug_info = True
+        litellm.drop_params = True

         if log_level is not None:
             logging.getLogger("LiteLLM").setLevel(log_level)
@@ -94,16 +103,17 @@ class LiteLLMModel(base.LanguageModel):
             f"Retrying due to {details.get('exception') or 'Unknown error'}"
         ),
     )
-    def model_forward(self, batch: List[List[Dict[str, Any]]]) -> List[str]:
+    def model_forward(self, batch: List[List[Dict[str, Any]]]) -> ModelOutput:
        """Generates output text through API calls via LiteLLM's batch completion functionality."""
        outputs = batch_completion(model=self.model_name, messages=batch, **self.model_kwargs)
        self._raise_exceptions(outputs)

-        return [
+        generated_text = [
            output["choices"][0]["message"]["content"]
            for output in outputs
            if output["choices"][0]["message"]["role"] == "assistant"
        ]
+        return ModelOutput(generated_text=generated_text)

    def _raise_exceptions(self, outputs: list):
        for output in outputs:
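
The class-level defaults are merged with user-supplied kwargs via the dict union operator, so explicit values always win, and litellm.drop_params = True makes LiteLLM silently drop any of these parameters that a given provider does not accept. A minimal sketch of the merge semantics (plain dicts, independent of eva):

_default_model_kwargs = {"temperature": 0.0, "max_completion_tokens": 1024, "top_p": 1.0, "seed": 42}

# PEP 584 dict union: the right-hand side wins on key collisions.
model_kwargs = _default_model_kwargs | {"temperature": 0.7}
assert model_kwargs["temperature"] == 0.7  # user override applied
assert model_kwargs["seed"] == 42          # untouched defaults kept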
eva/multimodal/models/modules/vision_language.py

@@ -1,6 +1,6 @@
 """Model module for vision-language models."""

-from typing import Any, List
+from typing import Any

 from lightning.pytorch.utilities.types import STEP_OUTPUT
 from torch import nn
@@ -9,6 +9,7 @@ from typing_extensions import override
 from eva.core.metrics import structs as metrics_lib
 from eva.core.models.modules import module
 from eva.core.models.modules.utils import batch_postprocess
+from eva.language.models.typings import ModelOutput
 from eva.multimodal.models.typings import TextImageBatch


@@ -33,7 +34,7 @@ class VisionLanguageModule(module.ModelModule):
         self.model = model

     @override
-    def forward(self, batch: TextImageBatch, *args: Any, **kwargs: Any) -> List[str]:
+    def forward(self, batch: TextImageBatch, *args: Any, **kwargs: Any) -> ModelOutput:
         return self.model(batch)

     @override
@@ -46,10 +47,10 @@ class VisionLanguageModule(module.ModelModule):

     def _batch_step(self, batch: TextImageBatch) -> STEP_OUTPUT:
         text, _, targets, metadata = TextImageBatch(*batch)
-        predictions = self.forward(batch)
+        output = self.forward(batch)
         return {
             "inputs": text,
-            "predictions": predictions,
+            "predictions": output.pop("generated_text"),  # type: ignore
             "targets": targets,
             "metadata": metadata,
-        }
+        } | output
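
_batch_step now treats the wrapper's ModelOutput as a mapping: the generated text is popped out as the predictions, and whatever else the wrapper returned (e.g. input_ids, output_ids, attention_mask) is folded into the step output via dict union. A minimal sketch of the pattern, using a plain dict to stand in for ModelOutput (an assumption about its mapping-like behaviour implied by the .pop and | usage above):

# Hypothetical wrapper output; keys other than "generated_text" are passed through.
output = {"generated_text": ["yes", "no"], "output_ids": [[1, 2], [3, 4]]}

step_output = {
    "inputs": ["q1", "q2"],
    "predictions": output.pop("generated_text"),  # text predictions for metrics/writers
    "targets": [1, 0],
    "metadata": {},
} | output  # remaining fields (here: output_ids) ride along in the step output

assert step_output["predictions"] == ["yes", "no"]
assert "output_ids" in step_output and "generated_text" not in step_output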
eva/multimodal/models/networks/alibaba.py

@@ -36,4 +36,5 @@ class Qwen25VL7BInstruct(wrappers.HuggingFaceModel):
                 "max_pixels": 451584,  # 672*672
             },
             system_prompt=system_prompt,
+            image_key="images",
         )
eva/multimodal/models/networks/others.py

@@ -20,7 +20,7 @@ class PathoR13b(wrappers.HuggingFaceModel):
         attn_implementation: str = "flash_attention_2",
     ):
         """Initialize the Patho-R1-3B model."""
-        requirements.check_dependencies(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})
+        requirements.check_min_versions(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})

         if not os.getenv("HF_TOKEN"):
             raise ValueError("HF_TOKEN env variable must be set.")
@@ -44,4 +44,5 @@ class PathoR13b(wrappers.HuggingFaceModel):
                 "max_pixels": 451584,  # 672*672
             },
             system_prompt=system_prompt,
+            image_key="images",
         )
eva/multimodal/models/wrappers/base.py

@@ -1,16 +1,17 @@
 """Base class for vision language model wrappers."""

 import abc
-from typing import Any, Callable, List
+from typing import Any, Callable

 from typing_extensions import override

 from eva.core.models.wrappers import base
 from eva.language.data.messages import ModelSystemMessage
+from eva.language.models.typings import ModelOutput
 from eva.multimodal.models.typings import TextImageBatch


-class VisionLanguageModel(base.BaseModel[TextImageBatch, List[str]]):
+class VisionLanguageModel(base.BaseModel[TextImageBatch, ModelOutput]):
     """Base class for multimodal models.

     Classes that inherit from this should implement the following methods:
@@ -36,7 +37,7 @@ class VisionLanguageModel(base.BaseModel[TextImageBatch, List[str]]):
         self.system_message = ModelSystemMessage(content=system_prompt) if system_prompt else None

     @override
-    def forward(self, batch: TextImageBatch) -> List[str]:
+    def forward(self, batch: TextImageBatch) -> ModelOutput:
         """Forward pass of the model."""
         inputs = self.format_inputs(batch)
         return super().forward(inputs)
eva/multimodal/models/wrappers/huggingface.py

@@ -9,10 +9,11 @@ from loguru import logger
 from torch import nn
 from typing_extensions import override

-from eva.language.models.typings import TextBatch
+from eva.language.models.typings import ModelOutput, TextBatch
 from eva.language.utils.text import messages as language_message_utils
 from eva.multimodal.models.typings import TextImageBatch
 from eva.multimodal.models.wrappers import base
+from eva.multimodal.utils.batch import unpack_batch
 from eva.multimodal.utils.text import messages as message_utils


@@ -27,6 +28,14 @@ class HuggingFaceModel(base.VisionLanguageModel):
         generation_kwargs: Additional generation arguments.
     """

+    _default_generation_kwargs = {
+        "temperature": 0.0,
+        "max_new_tokens": 1024,
+        "do_sample": False,
+        "top_p": 1.0,
+    }
+    """Default HF model parameters for evaluation."""
+
     def __init__(
         self,
         model_name_or_path: str,
@@ -35,6 +44,7 @@ class HuggingFaceModel(base.VisionLanguageModel):
         system_prompt: str | None = None,
         processor_kwargs: Dict[str, Any] | None = None,
         generation_kwargs: Dict[str, Any] | None = None,
+        image_key: str = "image",
     ):
         """Initialize the HuggingFace model wrapper.

@@ -45,6 +55,7 @@ class HuggingFaceModel(base.VisionLanguageModel):
             system_prompt: System prompt to use.
             processor_kwargs: Additional processor arguments.
             generation_kwargs: Additional generation arguments.
+            image_key: The key used for image inputs in the chat template.
         """
         super().__init__(system_prompt=system_prompt)

@@ -52,7 +63,8 @@ class HuggingFaceModel(base.VisionLanguageModel):
         self.model_kwargs = model_kwargs or {}
         self.base_model_class = model_class
         self.processor_kwargs = processor_kwargs or {}
-        self.generation_kwargs = generation_kwargs or {}
+        self.generation_kwargs = self._default_generation_kwargs | (generation_kwargs or {})
+        self.image_key = image_key

         self.processor = self.load_processor()
         self.model = self.load_model()
@@ -72,7 +84,7 @@ class HuggingFaceModel(base.VisionLanguageModel):
                 "pixel_values": ...
             }
         """
-        message_batch, image_batch, _, _ = self._unpack_batch(batch)
+        message_batch, image_batch, _, _ = unpack_batch(batch)
         with_images = image_batch is not None

         message_batch = language_message_utils.batch_insert_system_message(
@@ -105,12 +117,12 @@ class HuggingFaceModel(base.VisionLanguageModel):
         }

         if with_images:
-            processor_inputs["image"] = [[image] for image in image_batch]
+            processor_inputs[self.image_key] = [[image] for image in image_batch]

         return self.processor(**processor_inputs).to(self.model.device)  # type: ignore

     @override
-    def model_forward(self, batch: Dict[str, torch.Tensor]) -> List[str]:
+    def model_forward(self, batch: Dict[str, torch.Tensor]) -> ModelOutput:
         """Generates text output from the model. Is called by the `generate` method.

         Args:
@@ -121,8 +133,14 @@ class HuggingFaceModel(base.VisionLanguageModel):
         Returns:
             A dictionary containing the processed input and the model's output.
         """
-        output = self.model.generate(**batch, **self.generation_kwargs)  # type: ignore
-        return self._decode_output(output, batch["input_ids"].shape[-1])
+        output_ids = self.model.generate(**batch, **self.generation_kwargs)  # type: ignore
+
+        return ModelOutput(
+            generated_text=self._decode_output(output_ids, batch["input_ids"].shape[-1]),
+            input_ids=batch.get("input_ids"),
+            output_ids=output_ids,
+            attention_mask=batch.get("attention_mask"),
+        )

     @override
     def load_model(self) -> nn.Module:
@@ -148,15 +166,10 @@ class HuggingFaceModel(base.VisionLanguageModel):
     def load_processor(self) -> Callable:
         """Initialize the processor."""
         return transformers.AutoProcessor.from_pretrained(
-            self.model_name_or_path,
+            self.processor_kwargs.pop("model_name_or_path", self.model_name_or_path),
             **self.processor_kwargs,
         )

-    def _unpack_batch(self, batch: TextImageBatch | TextBatch) -> tuple:
-        if isinstance(batch, TextImageBatch):
-            return batch.text, batch.image, batch.target, batch.metadata
-        return batch.text, None, batch.target, batch.metadata
-
     def _decode_output(self, output: torch.Tensor, instruction_length: int) -> List[str]:
         """Decode the model's batch output to text.
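
load_processor() can now load the processor from a different checkpoint than the model: a "model_name_or_path" entry inside processor_kwargs is popped out and takes precedence over the wrapper's own path. A small sketch of the precedence rule (repo names are placeholders):

processor_kwargs = {"model_name_or_path": "org/processor-repo", "use_fast": True}

# Mirrors the pop-with-default in load_processor(): an explicit entry wins,
# otherwise the wrapper's own checkpoint path is used.
name = processor_kwargs.pop("model_name_or_path", "org/model-repo")
assert name == "org/processor-repo"
assert processor_kwargs == {"use_fast": True}  # remaining kwargs go to AutoProcessor.from_pretrained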
eva/multimodal/models/wrappers/litellm.py

@@ -6,9 +6,11 @@ from typing import Any, Dict, List
 from typing_extensions import override

 from eva.language.models import wrappers as language_wrappers
+from eva.language.models.typings import ModelOutput
 from eva.language.utils.text import messages as language_message_utils
 from eva.multimodal.models.typings import TextImageBatch
 from eva.multimodal.models.wrappers import base
+from eva.multimodal.utils.batch import unpack_batch
 from eva.multimodal.utils.text import messages as message_utils


@@ -42,7 +44,7 @@ class LiteLLMModel(base.VisionLanguageModel):

     @override
     def format_inputs(self, batch: TextImageBatch) -> List[List[Dict[str, Any]]]:
-        message_batch, image_batch, _, _ = TextImageBatch(*batch)
+        message_batch, image_batch, _, _ = unpack_batch(batch)

         message_batch = language_message_utils.batch_insert_system_message(
             message_batch, self.system_message
@@ -52,5 +54,5 @@ class LiteLLMModel(base.VisionLanguageModel):
         return list(map(message_utils.format_litellm_message, message_batch, image_batch))

     @override
-    def model_forward(self, batch: List[List[Dict[str, Any]]]) -> List[str]:
+    def model_forward(self, batch: List[List[Dict[str, Any]]]) -> ModelOutput:
         return self.language_model.model_forward(batch)
eva/multimodal/utils/batch/__init__.py (new file)

@@ -0,0 +1,5 @@
+"""Multimodal batch utilities API."""
+
+from eva.multimodal.utils.batch.unpack import unpack_batch
+
+__all__ = ["unpack_batch"]
eva/multimodal/utils/batch/unpack.py (new file)

@@ -0,0 +1,11 @@
+"""Unpack batch utility function."""
+
+from eva.language.models.typings import TextBatch
+from eva.multimodal.models.typings import TextImageBatch
+
+
+def unpack_batch(batch: TextImageBatch | TextBatch) -> tuple:
+    """Unpacks a TextImageBatch or TextBatch into its components."""
+    if isinstance(batch, TextImageBatch):
+        return batch.text, batch.image, batch.target, batch.metadata
+    return batch.text, None, batch.target, batch.metadata
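
This promotes the former private HuggingFaceModel._unpack_batch helper into a shared utility used by both multimodal wrappers. A self-contained usage sketch; the named-tuple stand-ins below only mimic the field layout implied by the attribute access above and are not eva's actual typings:

from typing import Any, NamedTuple


class TextBatch(NamedTuple):  # stand-in for eva.language.models.typings.TextBatch
    text: list
    target: Any
    metadata: Any


class TextImageBatch(NamedTuple):  # stand-in for eva.multimodal.models.typings.TextImageBatch
    text: list
    image: list
    target: Any
    metadata: Any


def unpack_batch(batch):
    """Same logic as eva.multimodal.utils.batch.unpack_batch."""
    if isinstance(batch, TextImageBatch):
        return batch.text, batch.image, batch.target, batch.metadata
    return batch.text, None, batch.target, batch.metadata


text, image, target, meta = unpack_batch(TextBatch(text=["q"], target=0, metadata={}))
assert image is None  # text-only batches yield None in the image slot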
eva/vision/data/datasets/classification/breakhis.py

@@ -101,11 +101,6 @@ class BreaKHis(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
     def class_to_idx(self) -> Dict[str, int]:
         return {label: index for index, label in enumerate(self.classes)}

-    @property
-    def _dataset_path(self) -> str:
-        """Returns the path of the image data of the dataset."""
-        return os.path.join(self._root, "BreaKHis_v1", "histology_slides")
-
     @functools.cached_property
     def _image_files(self) -> List[str]:
         """Return the list of image files in the dataset.
@@ -115,14 +110,14 @@ class BreaKHis(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
         """
         image_files = []
         for magnification in self._magnifications:
-            files_pattern = os.path.join(self._dataset_path, f"**/{magnification}", "*.png")
+            files_pattern = os.path.join(self._root, f"**/{magnification}", "*.png")
             image_files.extend(list(glob.glob(files_pattern, recursive=True)))
         return sorted(image_files)

     @override
     def filename(self, index: int) -> str:
         image_path = self._image_files[self._indices[index]]
-        return os.path.relpath(image_path, self._dataset_path)
+        return os.path.relpath(image_path, self._root)

     @override
     def prepare_data(self) -> None:
@@ -136,6 +131,8 @@ class BreaKHis(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):

     @override
     def validate(self) -> None:
+        if not os.path.exists(self._root):
+            raise RuntimeError(f"Dataset not found at {self._root}.")
         _validators.check_dataset_integrity(
             self,
             length=self._expected_dataset_lengths[self._split],
@@ -164,7 +161,7 @@ class BreaKHis(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
     def _download_dataset(self) -> None:
         """Downloads the dataset."""
         for resource in self._resources:
-            if os.path.isdir(self._dataset_path):
+            if os.path.isdir(self._root):
                 continue

             self._print_license()
eva/vision/data/datasets/classification/panda.py

@@ -12,6 +12,7 @@ from torchvision.datasets import utils
 from torchvision.transforms.v2 import functional
 from typing_extensions import override

+from eva.core import utils as core_utils
 from eva.core.data import splitting
 from eva.vision.data.datasets import _validators, structs, vision, wsi
 from eva.vision.data.wsi.patching import samplers
@@ -50,6 +51,7 @@ class PANDA(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, torch.Te
         image_transforms: Callable | None = None,
         coords_path: str | None = None,
         seed: int = 42,
+        download_dir: str | None = None,
     ) -> None:
         """Initializes the dataset.

@@ -64,10 +66,13 @@ class PANDA(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, torch.Te
             image_transforms: Transforms to apply to the extracted image patches.
             coords_path: File path to save the patch coordinates as .csv.
             seed: Random seed for reproducibility.
+            download_dir: Directory to download the dataset resources to. If None,
+                defaults to eva's home directory.
         """
         self._split = split
         self._root = root
         self._seed = seed
+        self._download_dir = download_dir or os.path.join(core_utils.home_dir(), "data", "panda")

         self._download_resources()

@@ -92,7 +97,7 @@ class PANDA(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, torch.Te
     @functools.cached_property
     def annotations(self) -> pd.DataFrame:
         """Loads the dataset labels."""
-        path = os.path.join(self._root, "train_with_noisy_labels.csv")
+        path = os.path.join(self._download_dir, "train_with_noisy_labels.csv")
         return pd.read_csv(path, index_col="image_id")

     @override
@@ -100,14 +105,16 @@ class PANDA(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, torch.Te
         _validators.check_dataset_exists(self._root, False)

         if not os.path.isdir(os.path.join(self._root, "train_images")):
-            raise FileNotFoundError("'train_images' directory not found in the root folder.")
-        if not os.path.isfile(os.path.join(self._root, "train_with_noisy_labels.csv")):
-            raise FileNotFoundError("'train.csv' file not found in the root folder.")
+            raise FileNotFoundError(f"'train_images' dir not found in folder: {self._root}")
+        if not os.path.isfile(os.path.join(self._download_dir, "train_with_noisy_labels.csv")):
+            raise FileNotFoundError(
+                f"'train_with_noisy_labels.csv' file not found in folder: {self._download_dir}"
+            )

     def _download_resources(self) -> None:
         """Downloads the dataset resources."""
         for resource in self._resources:
-            utils.download_url(resource.url, self._root, resource.filename, resource.md5)
+            utils.download_url(resource.url, self._download_dir, resource.filename, resource.md5)

     @override
     def validate(self) -> None:
eva/vision/data/datasets/segmentation/btcv.py

@@ -106,7 +106,7 @@ class BTCV(VisionDataset[eva_tv_tensors.Volume, tv_tensors.Mask]):

     @override
     def validate(self) -> None:
-        requirements.check_dependencies(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})
+        requirements.check_min_versions(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})

         def _valid_sample(index: int) -> bool:
             """Indicates if the sample files exist and are reachable."""
eva/vision/data/datasets/segmentation/consep.py

@@ -108,7 +108,7 @@ class CoNSeP(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, tv_tens
             n_classes=5,
             first_and_last_labels=((self.classes[0], self.classes[-1])),
         )
-        n_expected = self._expected_dataset_lengths[None]
+        n_expected = self._expected_dataset_lengths[self._split]
         if len(self._file_paths) != n_expected:
             raise ValueError(
                 f"Expected {n_expected} images, found {len(self._file_paths)} in {self._root}."
eva/vision/data/datasets/segmentation/lits17.py

@@ -123,7 +123,7 @@ class LiTS17(VisionDataset[eva_tv_tensors.Volume, tv_tensors.Mask]):

     @override
     def validate(self) -> None:
-        requirements.check_dependencies(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})
+        requirements.check_min_versions(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})

         def _valid_sample(index: int) -> bool:
             """Indicates if the sample files exist and are reachable."""
eva/vision/data/datasets/segmentation/monusac.py

@@ -15,6 +15,7 @@ from torchvision import tv_tensors
 from torchvision.datasets import utils
 from typing_extensions import override

+from eva.core import utils as core_utils
 from eva.core.utils.progress_bar import tqdm
 from eva.vision.data.datasets import _validators, structs, vision
 from eva.vision.utils import io
@@ -55,6 +56,7 @@ class MoNuSAC(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
         root: str,
         split: Literal["train", "test"],
         export_masks: bool = True,
+        processed_dir: str | None = None,
         download: bool = False,
         transforms: Callable | None = None,
     ) -> None:
@@ -66,6 +68,8 @@ class MoNuSAC(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
             split: Dataset split to use.
             export_masks: Whether to export, save and use the semantic label masks
                 from disk.
+            processed_dir: Directory where to store the processed masks.
+                Only used if `export_masks` is `True`.
             download: Whether to download the data for the specified split.
                 Note that the download will be executed only by additionally
                 calling the :meth:`prepare_data` method and if the data does not
@@ -79,6 +83,9 @@ class MoNuSAC(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
         self._split = split
         self._export_masks = export_masks
         self._download = download
+        self._processed_dir = processed_dir or os.path.join(
+            core_utils.home_dir(), "data", "processed", "monusac"
+        )

     @property
     @override
@@ -155,10 +162,7 @@ class MoNuSAC(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):

     def _export_semantic_label_masks(self) -> None:
         """Export semantic label masks to disk."""
-        mask_files = [
-            (index, filename.replace(".tif", ".npy"))
-            for index, filename in enumerate(self._image_files)
-        ]
+        mask_files = [(i, self._processed_filename(i)) for i in range(len(self._image_files))]
         to_export = filter(lambda x: not os.path.isfile(x[1]), mask_files)
         for sample_index, filename in tqdm(
             list(to_export),
@@ -166,6 +170,7 @@ class MoNuSAC(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
             leave=False,
         ):
             semantic_labels = self._get_semantic_mask(sample_index)
+            os.makedirs(os.path.dirname(filename), exist_ok=True)
             np.save(filename, semantic_labels)

     def _load_semantic_mask_file(self, index: int) -> npt.NDArray[Any]:
@@ -177,8 +182,7 @@ class MoNuSAC(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
         Returns:
             Loaded mask as a numpy array.
         """
-        mask_filename = self._image_files[index].replace(".tif", ".npy")
-        return np.load(mask_filename)
+        return np.load(self._processed_filename(index))

     def _get_semantic_mask(self, index: int) -> npt.NDArray[Any]:
         """Builds and loads the semantic label mask from the XML annotations.
@@ -216,6 +220,11 @@ class MoNuSAC(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):

         return semantic_labels

+    def _processed_filename(self, index: int) -> str:
+        """Returns the path of the processed mask for a given index."""
+        relative_path = os.path.relpath(self._image_files[index], self._root)
+        return os.path.join(self._processed_dir, relative_path).replace(".tif", ".npy")
+
     def _download_dataset(self) -> None:
         """Downloads the dataset."""
         self._print_license()
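
Exported .npy masks now live under a dedicated processed directory instead of next to the raw .tif files, with the dataset-relative layout preserved. A quick sketch of the mapping done by _processed_filename (paths are illustrative; the actual default comes from core_utils.home_dir()):

import os

root = "/data/monusac"
processed_dir = "/home/user/.eva/data/processed/monusac"  # illustrative default location
image_file = os.path.join(root, "train", "patient_01", "image_01.tif")

relative_path = os.path.relpath(image_file, root)  # train/patient_01/image_01.tif
mask_file = os.path.join(processed_dir, relative_path).replace(".tif", ".npy")
print(mask_file)  # /home/user/.eva/data/processed/monusac/train/patient_01/image_01.npy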
eva/vision/data/datasets/segmentation/msd_task7_pancreas.py

@@ -95,7 +95,7 @@ class MSDTask7Pancreas(VisionDataset[eva_tv_tensors.Volume, tv_tensors.Mask]):

     @override
     def validate(self) -> None:
-        requirements.check_dependencies(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})
+        requirements.check_min_versions(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})

         def _valid_sample(index: int) -> bool:
             """Indicates if the sample files exist and are reachable."""
eva/vision/data/transforms/base/__init__.py

@@ -1,5 +1,6 @@
 """Base classes for transforms."""

 from eva.vision.data.transforms.base.monai import RandomMonaiTransform
+from eva.vision.data.transforms.base.torchvision import TorchvisionTransformV2

-__all__ = ["RandomMonaiTransform"]
+__all__ = ["RandomMonaiTransform", "TorchvisionTransformV2"]
eva/vision/data/transforms/base/monai.py

@@ -2,10 +2,10 @@

 import abc

-from torchvision.transforms import v2
+from eva.vision.data.transforms.base.torchvision import TorchvisionTransformV2


-class RandomMonaiTransform(v2.Transform, abc.ABC):
+class RandomMonaiTransform(TorchvisionTransformV2, abc.ABC):
     """Base class for MONAI transform wrappers."""

     @abc.abstractmethod
eva/vision/data/transforms/base/torchvision.py (new file)

@@ -0,0 +1,33 @@
+"""Base class for torchvision.v2 transforms."""
+
+import abc
+from typing import Any, Dict, List
+
+from torchvision.transforms import v2
+
+
+class TorchvisionTransformV2(v2.Transform, abc.ABC):
+    """Wrapper for torchvision.v2.Transform.
+
+    This class ensures compatibility both with >=0.21.0 and older versions,
+    as torchvision 0.21.0 introduced a new transform API where they
+    renamed the following methods:
+
+    - `_get_params` -> `make_params`
+    - `_transform` -> `transform`
+    """
+
+    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        """Called internally before calling transform() on each input."""
+        return {}
+
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        return self.make_params(flat_inputs)
+
+    @abc.abstractmethod
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        """Applies the transformation to the input."""
+        raise NotImplementedError
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return self.transform(inpt, params)
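
Subclasses implement only the new-style make_params/transform names; on torchvision < 0.21 the old-style _get_params/_transform hooks in this base class delegate to them, so the same class runs on both API versions. A minimal illustrative subclass (a toy transform, not part of the package):

from typing import Any, Dict, List

import torch

from eva.vision.data.transforms import base


class AddOffset(base.TorchvisionTransformV2):
    """Toy transform: adds a fixed offset to every tensor input."""

    def __init__(self, offset: float = 1.0) -> None:
        super().__init__()
        self._offset = offset

    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # Parameters are computed once per call and shared across all inputs.
        return {"offset": self._offset}

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return inpt + params["offset"]


transform = AddOffset(offset=2.0)
print(transform(torch.zeros(2, 3)))  # a tensor of 2.0s, regardless of torchvision version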
eva/vision/data/transforms/common/squeeze.py

@@ -4,10 +4,12 @@ from typing import Any

 import torch
 from torchvision import tv_tensors
-from torchvision.transforms import v2
+from typing_extensions import override

+from eva.vision.data.transforms import base

-class Squeeze(v2.Transform):
+
+class Squeeze(base.TorchvisionTransformV2):
     """Squeezes the input tensor accross all or specified dimensions."""

     def __init__(self, dim: int | list[int] | None = None):
@@ -19,6 +21,7 @@ class Squeeze(v2.Transform):
         super().__init__()
         self._dim = dim

-    def _transform(self, inpt: Any, params: dict[str, Any]) -> Any:
+    @override
+    def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
         output = torch.squeeze(inpt) if self._dim is None else torch.squeeze(inpt, dim=self._dim)
         return tv_tensors.wrap(output, like=inpt)
eva/vision/data/transforms/croppad/crop_foreground.py

@@ -8,13 +8,13 @@ from monai.config import type_definitions
 from monai.transforms.croppad import array as monai_croppad_transforms
 from monai.utils.enums import PytorchPadMode
 from torchvision import tv_tensors
-from torchvision.transforms import v2
 from typing_extensions import override

 from eva.vision.data import tv_tensors as eva_tv_tensors
+from eva.vision.data.transforms import base


-class CropForeground(v2.Transform):
+class CropForeground(base.TorchvisionTransformV2):
     """Crop an image using a bounding box.

     The bounding box is generated by selecting foreground using select_fn
@@ -74,19 +74,20 @@ class CropForeground(v2.Transform):
             **pad_kwargs,
         )

-    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+    @override
+    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         volume = next(inpt for inpt in flat_inputs if isinstance(inpt, eva_tv_tensors.Volume))
         box_start, box_end = self._foreground_crop.compute_bounding_box(volume)
         return {"box_start": box_start, "box_end": box_end}

     @functools.singledispatchmethod
     @override
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt

-    @_transform.register(tv_tensors.Image)
-    @_transform.register(eva_tv_tensors.Volume)
-    @_transform.register(tv_tensors.Mask)
+    @transform.register(tv_tensors.Image)
+    @transform.register(eva_tv_tensors.Volume)
+    @transform.register(tv_tensors.Mask)
     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
         inpt_foreground_cropped = self._foreground_crop.crop_pad(
             inpt, params["box_start"], params["box_end"]
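
CropForeground keeps its functools.singledispatchmethod dispatch, only renamed from _transform to transform so it matches the new base-class hook: the generic implementation passes unsupported inputs through, and the registered overloads handle Image/Volume/Mask. A self-contained sketch of the same dispatch pattern (toy types, not eva's):

import functools
from typing import Any, Dict


class Image(list):  # toy stand-in for a tv_tensor type
    pass


class Mask(list):  # toy stand-in for a tv_tensor type
    pass


class DropFirst:
    """Applies an op only to registered input types; everything else passes through."""

    @functools.singledispatchmethod
    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return inpt  # generic fallback: leave unknown inputs untouched

    @transform.register(Image)
    @transform.register(Mask)
    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return type(inpt)(inpt[1:])  # registered overload for Image and Mask


t = DropFirst()
print(t.transform(Image([1, 2, 3]), {}))  # [2, 3] (an Image) via the registered overload
print(t.transform("unchanged", {}))       # strings fall back to the generic branch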