kaiko-eva 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eva/core/callbacks/config.py +11 -6
- eva/core/callbacks/writers/embeddings/base.py +44 -10
- eva/core/data/samplers/classification/balanced.py +24 -12
- eva/core/loggers/utils/wandb.py +4 -1
- eva/core/trainers/trainer.py +11 -1
- eva/core/utils/__init__.py +2 -1
- eva/core/utils/distributed.py +12 -0
- eva/core/utils/paths.py +14 -0
- eva/core/utils/requirements.py +52 -6
- eva/language/callbacks/writers/prediction.py +44 -19
- eva/language/data/datasets/classification/pubmedqa.py +1 -1
- eva/language/models/modules/language.py +7 -6
- eva/language/models/typings.py +19 -2
- eva/language/models/wrappers/base.py +4 -4
- eva/language/models/wrappers/huggingface.py +14 -4
- eva/language/models/wrappers/litellm.py +14 -4
- eva/multimodal/models/modules/vision_language.py +6 -5
- eva/multimodal/models/networks/alibaba.py +1 -0
- eva/multimodal/models/networks/others.py +2 -1
- eva/multimodal/models/wrappers/base.py +4 -3
- eva/multimodal/models/wrappers/huggingface.py +26 -13
- eva/multimodal/models/wrappers/litellm.py +4 -2
- eva/multimodal/utils/batch/__init__.py +5 -0
- eva/multimodal/utils/batch/unpack.py +11 -0
- eva/vision/data/datasets/classification/breakhis.py +5 -8
- eva/vision/data/datasets/classification/panda.py +12 -5
- eva/vision/data/datasets/segmentation/btcv.py +1 -1
- eva/vision/data/datasets/segmentation/consep.py +1 -1
- eva/vision/data/datasets/segmentation/lits17.py +1 -1
- eva/vision/data/datasets/segmentation/monusac.py +15 -6
- eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +1 -1
- eva/vision/data/transforms/base/__init__.py +2 -1
- eva/vision/data/transforms/base/monai.py +2 -2
- eva/vision/data/transforms/base/torchvision.py +33 -0
- eva/vision/data/transforms/common/squeeze.py +6 -3
- eva/vision/data/transforms/croppad/crop_foreground.py +8 -7
- eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +6 -5
- eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -5
- eva/vision/data/transforms/croppad/rand_spatial_crop.py +8 -7
- eva/vision/data/transforms/croppad/spatial_pad.py +6 -6
- eva/vision/data/transforms/intensity/rand_scale_intensity.py +3 -3
- eva/vision/data/transforms/intensity/rand_shift_intensity.py +3 -3
- eva/vision/data/transforms/intensity/scale_intensity_ranged.py +5 -5
- eva/vision/data/transforms/spatial/flip.py +8 -7
- eva/vision/data/transforms/spatial/resize.py +5 -4
- eva/vision/data/transforms/spatial/rotate.py +8 -7
- eva/vision/data/transforms/spatial/spacing.py +7 -6
- eva/vision/data/transforms/utility/ensure_channel_first.py +6 -6
- eva/vision/models/networks/backbones/universal/vit.py +24 -0
- {kaiko_eva-0.4.0.dist-info → kaiko_eva-0.4.1.dist-info}/METADATA +8 -2
- {kaiko_eva-0.4.0.dist-info → kaiko_eva-0.4.1.dist-info}/RECORD +54 -49
- {kaiko_eva-0.4.0.dist-info → kaiko_eva-0.4.1.dist-info}/WHEEL +0 -0
- {kaiko_eva-0.4.0.dist-info → kaiko_eva-0.4.1.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.4.0.dist-info → kaiko_eva-0.4.1.dist-info}/licenses/LICENSE +0 -0

eva/language/models/wrappers/litellm.py

@@ -16,7 +16,7 @@ from litellm.exceptions import (
 from loguru import logger
 from typing_extensions import override
 
-from eva.language.models.typings import TextBatch
+from eva.language.models.typings import ModelOutput, TextBatch
 from eva.language.models.wrappers import base
 from eva.language.utils.text import messages as message_utils
 
@@ -32,6 +32,14 @@ RETRYABLE_ERRORS = (
 class LiteLLMModel(base.LanguageModel):
     """Wrapper class for LiteLLM language models."""
 
+    _default_model_kwargs = {
+        "temperature": 0.0,
+        "max_completion_tokens": 1024,
+        "top_p": 1.0,
+        "seed": 42,
+    }
+    """Default API model parameters for evaluation."""
+
     def __init__(
         self,
         model_name: str,
@@ -51,9 +59,10 @@ class LiteLLMModel(base.LanguageModel):
         super().__init__(system_prompt=system_prompt)
 
         self.model_name = model_name
-        self.model_kwargs = model_kwargs or {}
+        self.model_kwargs = self._default_model_kwargs | (model_kwargs or {})
 
         litellm.suppress_debug_info = True
+        litellm.drop_params = True
 
         if log_level is not None:
             logging.getLogger("LiteLLM").setLevel(log_level)
@@ -94,16 +103,17 @@ class LiteLLMModel(base.LanguageModel):
             f"Retrying due to {details.get('exception') or 'Unknown error'}"
         ),
     )
-    def model_forward(self, batch: List[List[Dict[str, Any]]]) ->
+    def model_forward(self, batch: List[List[Dict[str, Any]]]) -> ModelOutput:
         """Generates output text through API calls via LiteLLM's batch completion functionality."""
         outputs = batch_completion(model=self.model_name, messages=batch, **self.model_kwargs)
         self._raise_exceptions(outputs)
 
-        return [
+        generated_text = [
             output["choices"][0]["message"]["content"]
             for output in outputs
             if output["choices"][0]["message"]["role"] == "assistant"
         ]
+        return ModelOutput(generated_text=generated_text)
 
     def _raise_exceptions(self, outputs: list):
         for output in outputs:
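
The two kwargs changes above work together: the evaluation defaults are applied first and then overridden by whatever the caller passes, while `litellm.drop_params = True` lets LiteLLM drop parameters a given provider does not accept instead of erroring. A minimal sketch of the merge semantics (plain Python, not eva code; the user values are hypothetical):

```python
_default_model_kwargs = {
    "temperature": 0.0,
    "max_completion_tokens": 1024,
    "top_p": 1.0,
    "seed": 42,
}

user_kwargs = {"temperature": 0.7}  # hypothetical user-supplied overrides

# `|` builds a new dict: keys from the right operand win on conflict.
merged = _default_model_kwargs | user_kwargs

print(merged["temperature"])            # 0.7  -> user override wins
print(merged["max_completion_tokens"])  # 1024 -> evaluation default kept
```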

eva/multimodal/models/modules/vision_language.py

@@ -1,6 +1,6 @@
 """Model module for vision-language models."""
 
-from typing import Any
+from typing import Any
 
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 from torch import nn
@@ -9,6 +9,7 @@ from typing_extensions import override
 from eva.core.metrics import structs as metrics_lib
 from eva.core.models.modules import module
 from eva.core.models.modules.utils import batch_postprocess
+from eva.language.models.typings import ModelOutput
 from eva.multimodal.models.typings import TextImageBatch
 
 
@@ -33,7 +34,7 @@ class VisionLanguageModule(module.ModelModule):
         self.model = model
 
     @override
-    def forward(self, batch: TextImageBatch, *args: Any, **kwargs: Any) ->
+    def forward(self, batch: TextImageBatch, *args: Any, **kwargs: Any) -> ModelOutput:
         return self.model(batch)
 
     @override
@@ -46,10 +47,10 @@ class VisionLanguageModule(module.ModelModule):
 
     def _batch_step(self, batch: TextImageBatch) -> STEP_OUTPUT:
         text, _, targets, metadata = TextImageBatch(*batch)
-
+        output = self.forward(batch)
         return {
             "inputs": text,
-            "predictions":
+            "predictions": output.pop("generated_text"), # type: ignore
             "targets": targets,
             "metadata": metadata,
-        }
+        } | output

eva/multimodal/models/networks/others.py

@@ -20,7 +20,7 @@ class PathoR13b(wrappers.HuggingFaceModel):
         attn_implementation: str = "flash_attention_2",
     ):
         """Initialize the Patho-R1-3B model."""
-        requirements.
+        requirements.check_min_versions(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})
 
         if not os.getenv("HF_TOKEN"):
             raise ValueError("HF_TOKEN env variable must be set.")
@@ -44,4 +44,5 @@ class PathoR13b(wrappers.HuggingFaceModel):
                 "max_pixels": 451584, # 672*672
             },
             system_prompt=system_prompt,
+            image_key="images",
         )

eva/multimodal/models/wrappers/base.py

@@ -1,16 +1,17 @@
 """Base class for vision language model wrappers."""
 
 import abc
-from typing import Any, Callable
+from typing import Any, Callable
 
 from typing_extensions import override
 
 from eva.core.models.wrappers import base
 from eva.language.data.messages import ModelSystemMessage
+from eva.language.models.typings import ModelOutput
 from eva.multimodal.models.typings import TextImageBatch
 
 
-class VisionLanguageModel(base.BaseModel[TextImageBatch, List[str]]):
+class VisionLanguageModel(base.BaseModel[TextImageBatch, ModelOutput]):
     """Base class for multimodal models.
 
     Classes that inherit from this should implement the following methods:
@@ -36,7 +37,7 @@ class VisionLanguageModel(base.BaseModel[TextImageBatch, List[str]]):
         self.system_message = ModelSystemMessage(content=system_prompt) if system_prompt else None
 
     @override
-    def forward(self, batch: TextImageBatch) ->
+    def forward(self, batch: TextImageBatch) -> ModelOutput:
         """Forward pass of the model."""
         inputs = self.format_inputs(batch)
         return super().forward(inputs)

eva/multimodal/models/wrappers/huggingface.py

@@ -9,10 +9,11 @@ from loguru import logger
 from torch import nn
 from typing_extensions import override
 
-from eva.language.models.typings import TextBatch
+from eva.language.models.typings import ModelOutput, TextBatch
 from eva.language.utils.text import messages as language_message_utils
 from eva.multimodal.models.typings import TextImageBatch
 from eva.multimodal.models.wrappers import base
+from eva.multimodal.utils.batch import unpack_batch
 from eva.multimodal.utils.text import messages as message_utils
 
 
@@ -27,6 +28,14 @@ class HuggingFaceModel(base.VisionLanguageModel):
         generation_kwargs: Additional generation arguments.
     """
 
+    _default_generation_kwargs = {
+        "temperature": 0.0,
+        "max_new_tokens": 1024,
+        "do_sample": False,
+        "top_p": 1.0,
+    }
+    """Default HF model parameters for evaluation."""
+
     def __init__(
         self,
         model_name_or_path: str,
@@ -35,6 +44,7 @@ class HuggingFaceModel(base.VisionLanguageModel):
         system_prompt: str | None = None,
         processor_kwargs: Dict[str, Any] | None = None,
         generation_kwargs: Dict[str, Any] | None = None,
+        image_key: str = "image",
     ):
         """Initialize the HuggingFace model wrapper.
 
@@ -45,6 +55,7 @@ class HuggingFaceModel(base.VisionLanguageModel):
             system_prompt: System prompt to use.
             processor_kwargs: Additional processor arguments.
             generation_kwargs: Additional generation arguments.
+            image_key: The key used for image inputs in the chat template.
         """
         super().__init__(system_prompt=system_prompt)
 
@@ -52,7 +63,8 @@ class HuggingFaceModel(base.VisionLanguageModel):
         self.model_kwargs = model_kwargs or {}
         self.base_model_class = model_class
         self.processor_kwargs = processor_kwargs or {}
-        self.generation_kwargs = generation_kwargs or {}
+        self.generation_kwargs = self._default_generation_kwargs | (generation_kwargs or {})
+        self.image_key = image_key
 
         self.processor = self.load_processor()
         self.model = self.load_model()
@@ -72,7 +84,7 @@ class HuggingFaceModel(base.VisionLanguageModel):
                 "pixel_values": ...
             }
         """
-        message_batch, image_batch, _, _ =
+        message_batch, image_batch, _, _ = unpack_batch(batch)
         with_images = image_batch is not None
 
         message_batch = language_message_utils.batch_insert_system_message(
@@ -105,12 +117,12 @@ class HuggingFaceModel(base.VisionLanguageModel):
         }
 
         if with_images:
-            processor_inputs[
+            processor_inputs[self.image_key] = [[image] for image in image_batch]
 
         return self.processor(**processor_inputs).to(self.model.device) # type: ignore
 
     @override
-    def model_forward(self, batch: Dict[str, torch.Tensor]) ->
+    def model_forward(self, batch: Dict[str, torch.Tensor]) -> ModelOutput:
         """Generates text output from the model. Is called by the `generate` method.
 
         Args:
@@ -121,8 +133,14 @@ class HuggingFaceModel(base.VisionLanguageModel):
         Returns:
             A dictionary containing the processed input and the model's output.
         """
-
-
+        output_ids = self.model.generate(**batch, **self.generation_kwargs) # type: ignore
+
+        return ModelOutput(
+            generated_text=self._decode_output(output_ids, batch["input_ids"].shape[-1]),
+            input_ids=batch.get("input_ids"),
+            output_ids=output_ids,
+            attention_mask=batch.get("attention_mask"),
+        )
 
     @override
     def load_model(self) -> nn.Module:
@@ -148,15 +166,10 @@ class HuggingFaceModel(base.VisionLanguageModel):
     def load_processor(self) -> Callable:
         """Initialize the processor."""
         return transformers.AutoProcessor.from_pretrained(
-            self.model_name_or_path,
+            self.processor_kwargs.pop("model_name_or_path", self.model_name_or_path),
            **self.processor_kwargs,
         )
 
-    def _unpack_batch(self, batch: TextImageBatch | TextBatch) -> tuple:
-        if isinstance(batch, TextImageBatch):
-            return batch.text, batch.image, batch.target, batch.metadata
-        return batch.text, None, batch.target, batch.metadata
-
     def _decode_output(self, output: torch.Tensor, instruction_length: int) -> List[str]:
         """Decode the model's batch output to text.
 
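
One detail in the `load_processor` change above: the processor can now be loaded from a different repo than the model by placing a `model_name_or_path` entry in `processor_kwargs`; `pop` both selects the override and removes it so it is not forwarded twice. A plain-dict sketch of that behavior (the repo names are hypothetical, not eva defaults):

```python
processor_kwargs = {"model_name_or_path": "org/alt-processor", "use_fast": True}
model_name_or_path = "org/base-model"

# Returns the override when present (and removes it), else falls back to the model path.
processor_repo = processor_kwargs.pop("model_name_or_path", model_name_or_path)

print(processor_repo)    # org/alt-processor
print(processor_kwargs)  # {'use_fast': True} -> what still reaches AutoProcessor.from_pretrained
```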

eva/multimodal/models/wrappers/litellm.py

@@ -6,9 +6,11 @@ from typing import Any, Dict, List
 from typing_extensions import override
 
 from eva.language.models import wrappers as language_wrappers
+from eva.language.models.typings import ModelOutput
 from eva.language.utils.text import messages as language_message_utils
 from eva.multimodal.models.typings import TextImageBatch
 from eva.multimodal.models.wrappers import base
+from eva.multimodal.utils.batch import unpack_batch
 from eva.multimodal.utils.text import messages as message_utils
 
 
@@ -42,7 +44,7 @@ class LiteLLMModel(base.VisionLanguageModel):
 
     @override
     def format_inputs(self, batch: TextImageBatch) -> List[List[Dict[str, Any]]]:
-        message_batch, image_batch, _, _ =
+        message_batch, image_batch, _, _ = unpack_batch(batch)
 
         message_batch = language_message_utils.batch_insert_system_message(
             message_batch, self.system_message
@@ -52,5 +54,5 @@ class LiteLLMModel(base.VisionLanguageModel):
         return list(map(message_utils.format_litellm_message, message_batch, image_batch))
 
     @override
-    def model_forward(self, batch: List[List[Dict[str, Any]]]) ->
+    def model_forward(self, batch: List[List[Dict[str, Any]]]) -> ModelOutput:
         return self.language_model.model_forward(batch)

eva/multimodal/utils/batch/unpack.py (new file)

@@ -0,0 +1,11 @@
+"""Unpack batch utility function."""
+
+from eva.language.models.typings import TextBatch
+from eva.multimodal.models.typings import TextImageBatch
+
+
+def unpack_batch(batch: TextImageBatch | TextBatch) -> tuple:
+    """Unpacks a TextImageBatch or TextBatch into its components."""
+    if isinstance(batch, TextImageBatch):
+        return batch.text, batch.image, batch.target, batch.metadata
+    return batch.text, None, batch.target, batch.metadata
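
This helper replaces the private `HuggingFaceModel._unpack_batch` method removed above, so both multimodal wrappers can share it. A self-contained sketch of the calling pattern; the NamedTuples below are stand-ins that only mirror the fields the function reads, not eva's actual `TextBatch`/`TextImageBatch` definitions:

```python
from typing import Any, NamedTuple


class FakeTextBatch(NamedTuple):  # stand-in for eva.language.models.typings.TextBatch
    text: Any
    target: Any
    metadata: Any


class FakeTextImageBatch(NamedTuple):  # stand-in for eva.multimodal.models.typings.TextImageBatch
    text: Any
    image: Any
    target: Any
    metadata: Any


def unpack_batch(batch) -> tuple:
    """Same shape of logic as the new eva.multimodal.utils.batch.unpack_batch."""
    if isinstance(batch, FakeTextImageBatch):
        return batch.text, batch.image, batch.target, batch.metadata
    return batch.text, None, batch.target, batch.metadata  # text-only batch: no image


text, image, target, metadata = unpack_batch(FakeTextBatch("prompt", 1, {}))
assert image is None  # callers branch on this, e.g. `with_images = image_batch is not None`
```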

eva/vision/data/datasets/classification/breakhis.py

@@ -101,11 +101,6 @@ class BreaKHis(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
     def class_to_idx(self) -> Dict[str, int]:
         return {label: index for index, label in enumerate(self.classes)}
 
-    @property
-    def _dataset_path(self) -> str:
-        """Returns the path of the image data of the dataset."""
-        return os.path.join(self._root, "BreaKHis_v1", "histology_slides")
-
     @functools.cached_property
     def _image_files(self) -> List[str]:
         """Return the list of image files in the dataset.
@@ -115,14 +110,14 @@ class BreaKHis(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
         """
         image_files = []
         for magnification in self._magnifications:
-            files_pattern = os.path.join(self.
+            files_pattern = os.path.join(self._root, f"**/{magnification}", "*.png")
             image_files.extend(list(glob.glob(files_pattern, recursive=True)))
         return sorted(image_files)
 
     @override
     def filename(self, index: int) -> str:
         image_path = self._image_files[self._indices[index]]
-        return os.path.relpath(image_path, self.
+        return os.path.relpath(image_path, self._root)
 
     @override
     def prepare_data(self) -> None:
@@ -136,6 +131,8 @@ class BreaKHis(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
 
     @override
     def validate(self) -> None:
+        if not os.path.exists(self._root):
+            raise RuntimeError(f"Dataset not found at {self._root}.")
         _validators.check_dataset_integrity(
             self,
             length=self._expected_dataset_lengths[self._split],
@@ -164,7 +161,7 @@ class BreaKHis(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
     def _download_dataset(self) -> None:
         """Downloads the dataset."""
         for resource in self._resources:
-            if os.path.isdir(self.
+            if os.path.isdir(self._root):
                 continue
 
             self._print_license()

eva/vision/data/datasets/classification/panda.py

@@ -12,6 +12,7 @@ from torchvision.datasets import utils
 from torchvision.transforms.v2 import functional
 from typing_extensions import override
 
+from eva.core import utils as core_utils
 from eva.core.data import splitting
 from eva.vision.data.datasets import _validators, structs, vision, wsi
 from eva.vision.data.wsi.patching import samplers
@@ -50,6 +51,7 @@ class PANDA(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, torch.Te
         image_transforms: Callable | None = None,
         coords_path: str | None = None,
         seed: int = 42,
+        download_dir: str | None = None,
     ) -> None:
         """Initializes the dataset.
 
@@ -64,10 +66,13 @@ class PANDA(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, torch.Te
             image_transforms: Transforms to apply to the extracted image patches.
             coords_path: File path to save the patch coordinates as .csv.
             seed: Random seed for reproducibility.
+            download_dir: Directory to download the dataset resources to. If None,
+                defaults to eva's home directory.
         """
         self._split = split
         self._root = root
         self._seed = seed
+        self._download_dir = download_dir or os.path.join(core_utils.home_dir(), "data", "panda")
 
         self._download_resources()
 
@@ -92,7 +97,7 @@ class PANDA(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, torch.Te
     @functools.cached_property
     def annotations(self) -> pd.DataFrame:
         """Loads the dataset labels."""
-        path = os.path.join(self.
+        path = os.path.join(self._download_dir, "train_with_noisy_labels.csv")
         return pd.read_csv(path, index_col="image_id")
 
     @override
@@ -100,14 +105,16 @@ class PANDA(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, torch.Te
         _validators.check_dataset_exists(self._root, False)
 
         if not os.path.isdir(os.path.join(self._root, "train_images")):
-            raise FileNotFoundError("'train_images'
-        if not os.path.isfile(os.path.join(self.
-            raise FileNotFoundError(
+            raise FileNotFoundError(f"'train_images' dir not found in folder: {self._root}")
+        if not os.path.isfile(os.path.join(self._download_dir, "train_with_noisy_labels.csv")):
+            raise FileNotFoundError(
+                f"'train_with_noisy_labels.csv' file not found in folder: {self._download_dir}"
+            )
 
     def _download_resources(self) -> None:
         """Downloads the dataset resources."""
         for resource in self._resources:
-            utils.download_url(resource.url, self.
+            utils.download_url(resource.url, self._download_dir, resource.filename, resource.md5)
 
     @override
     def validate(self) -> None:

eva/vision/data/datasets/segmentation/btcv.py

@@ -106,7 +106,7 @@ class BTCV(VisionDataset[eva_tv_tensors.Volume, tv_tensors.Mask]):
 
     @override
     def validate(self) -> None:
-        requirements.
+        requirements.check_min_versions(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})
 
         def _valid_sample(index: int) -> bool:
             """Indicates if the sample files exist and are reachable."""

eva/vision/data/datasets/segmentation/consep.py

@@ -108,7 +108,7 @@ class CoNSeP(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, tv_tens
             n_classes=5,
             first_and_last_labels=((self.classes[0], self.classes[-1])),
         )
-        n_expected = self._expected_dataset_lengths[
+        n_expected = self._expected_dataset_lengths[self._split]
         if len(self._file_paths) != n_expected:
             raise ValueError(
                 f"Expected {n_expected} images, found {len(self._file_paths)} in {self._root}."

eva/vision/data/datasets/segmentation/lits17.py

@@ -123,7 +123,7 @@ class LiTS17(VisionDataset[eva_tv_tensors.Volume, tv_tensors.Mask]):
 
     @override
     def validate(self) -> None:
-        requirements.
+        requirements.check_min_versions(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})
 
         def _valid_sample(index: int) -> bool:
             """Indicates if the sample files exist and are reachable."""

eva/vision/data/datasets/segmentation/monusac.py

@@ -15,6 +15,7 @@ from torchvision import tv_tensors
 from torchvision.datasets import utils
 from typing_extensions import override
 
+from eva.core import utils as core_utils
 from eva.core.utils.progress_bar import tqdm
 from eva.vision.data.datasets import _validators, structs, vision
 from eva.vision.utils import io
@@ -55,6 +56,7 @@ class MoNuSAC(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
         root: str,
         split: Literal["train", "test"],
         export_masks: bool = True,
+        processed_dir: str | None = None,
         download: bool = False,
         transforms: Callable | None = None,
     ) -> None:
@@ -66,6 +68,8 @@ class MoNuSAC(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
             split: Dataset split to use.
             export_masks: Whether to export, save and use the semantic label masks
                 from disk.
+            processed_dir: Directory where to store the processed masks.
+                Only used if `export_masks` is `True`.
             download: Whether to download the data for the specified split.
                 Note that the download will be executed only by additionally
                 calling the :meth:`prepare_data` method and if the data does not
@@ -79,6 +83,9 @@ class MoNuSAC(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
         self._split = split
         self._export_masks = export_masks
         self._download = download
+        self._processed_dir = processed_dir or os.path.join(
+            core_utils.home_dir(), "data", "processed", "monusac"
+        )
 
     @property
     @override
@@ -155,10 +162,7 @@ class MoNuSAC(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
 
     def _export_semantic_label_masks(self) -> None:
         """Export semantic label masks to disk."""
-        mask_files = [
-            (index, filename.replace(".tif", ".npy"))
-            for index, filename in enumerate(self._image_files)
-        ]
+        mask_files = [(i, self._processed_filename(i)) for i in range(len(self._image_files))]
         to_export = filter(lambda x: not os.path.isfile(x[1]), mask_files)
         for sample_index, filename in tqdm(
             list(to_export),
@@ -166,6 +170,7 @@ class MoNuSAC(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
             leave=False,
         ):
             semantic_labels = self._get_semantic_mask(sample_index)
+            os.makedirs(os.path.dirname(filename), exist_ok=True)
             np.save(filename, semantic_labels)
 
     def _load_semantic_mask_file(self, index: int) -> npt.NDArray[Any]:
@@ -177,8 +182,7 @@ class MoNuSAC(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
         Returns:
             Loaded mask as a numpy array.
         """
-
-        return np.load(mask_filename)
+        return np.load(self._processed_filename(index))
 
     def _get_semantic_mask(self, index: int) -> npt.NDArray[Any]:
         """Builds and loads the semantic label mask from the XML annotations.
@@ -216,6 +220,11 @@ class MoNuSAC(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
 
         return semantic_labels
 
+    def _processed_filename(self, index: int) -> str:
+        """Returns the path of the processed mask for a given index."""
+        relative_path = os.path.relpath(self._image_files[index], self._root)
+        return os.path.join(self._processed_dir, relative_path).replace(".tif", ".npy")
+
     def _download_dataset(self) -> None:
         """Downloads the dataset."""
         self._print_license()
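
The net effect of the MoNuSAC changes is that exported semantic masks no longer sit next to the source `.tif` files but under a separate `processed_dir`, mirroring each image's path relative to `root` and swapping the extension. A small sketch of the path mapping with hypothetical paths (the real default comes from `core_utils.home_dir()`, whose exact location is not shown in this diff):

```python
import os

root = "/data/monusac"                                     # hypothetical dataset root
processed_dir = "/home/user/.eva/data/processed/monusac"   # assumed default-style location
image_file = "/data/monusac/train/patient_1/image_1.tif"   # hypothetical image path

relative_path = os.path.relpath(image_file, root)          # train/patient_1/image_1.tif
mask_file = os.path.join(processed_dir, relative_path).replace(".tif", ".npy")

print(mask_file)
# /home/user/.eva/data/processed/monusac/train/patient_1/image_1.npy
```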

eva/vision/data/datasets/segmentation/msd_task7_pancreas.py

@@ -95,7 +95,7 @@ class MSDTask7Pancreas(VisionDataset[eva_tv_tensors.Volume, tv_tensors.Mask]):
 
     @override
     def validate(self) -> None:
-        requirements.
+        requirements.check_min_versions(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})
 
         def _valid_sample(index: int) -> bool:
             """Indicates if the sample files exist and are reachable."""

eva/vision/data/transforms/base/__init__.py

@@ -1,5 +1,6 @@
 """Base classes for transforms."""
 
 from eva.vision.data.transforms.base.monai import RandomMonaiTransform
+from eva.vision.data.transforms.base.torchvision import TorchvisionTransformV2
 
-__all__ = ["RandomMonaiTransform"]
+__all__ = ["RandomMonaiTransform", "TorchvisionTransformV2"]

eva/vision/data/transforms/base/monai.py

@@ -2,10 +2,10 @@
 
 import abc
 
-from
+from eva.vision.data.transforms.base.torchvision import TorchvisionTransformV2
 
 
-class RandomMonaiTransform(
+class RandomMonaiTransform(TorchvisionTransformV2, abc.ABC):
     """Base class for MONAI transform wrappers."""
 
     @abc.abstractmethod

eva/vision/data/transforms/base/torchvision.py (new file)

@@ -0,0 +1,33 @@
+"""Base class for torchvision.v2 transforms."""
+
+import abc
+from typing import Any, Dict, List
+
+from torchvision.transforms import v2
+
+
+class TorchvisionTransformV2(v2.Transform, abc.ABC):
+    """Wrapper for torchvision.v2.Transform.
+
+    This class ensures compatibility both with >=0.21.0 and older versions,
+    as torchvision 0.21.0 introduced a new transform API where they
+    renamed the following methods:
+
+    - `_get_params` -> `make_params`
+    - `_transform` -> `transform`
+    """
+
+    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        """Called internally before calling transform() on each input."""
+        return {}
+
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        return self.make_params(flat_inputs)
+
+    @abc.abstractmethod
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        """Applies the transformation to the input."""
+        raise NotImplementedError
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return self.transform(inpt, params)
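
The `Squeeze` and `CropForeground` hunks below rebase eva's transforms onto this class so they only implement the new-style `make_params`/`transform` hooks. A minimal, hypothetical subclass to illustrate the pattern (the `ClampIntensity` class is illustrative only, not part of eva):

```python
from typing import Any, Dict

import torch

from eva.vision.data.transforms.base import TorchvisionTransformV2  # exported in the __init__ hunk above


class ClampIntensity(TorchvisionTransformV2):
    """Clamps tensor inputs to a fixed value range (illustrative example)."""

    def __init__(self, min_value: float = 0.0, max_value: float = 1.0) -> None:
        super().__init__()
        self._min, self._max = min_value, max_value

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Written against the >=0.21 hook names; on older torchvision the base
        # class forwards the legacy `_get_params`/`_transform` calls here.
        return torch.clamp(inpt, self._min, self._max) if torch.is_tensor(inpt) else inpt


# usage: ClampIntensity()(tensor) dispatches through torchvision's v2.Transform machinery
```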

eva/vision/data/transforms/common/squeeze.py

@@ -4,10 +4,12 @@ from typing import Any
 
 import torch
 from torchvision import tv_tensors
-from
+from typing_extensions import override
 
+from eva.vision.data.transforms import base
 
-class Squeeze(v2.Transform):
+
+class Squeeze(base.TorchvisionTransformV2):
     """Squeezes the input tensor accross all or specified dimensions."""
 
     def __init__(self, dim: int | list[int] | None = None):
@@ -19,6 +21,7 @@ class Squeeze(v2.Transform):
         super().__init__()
         self._dim = dim
 
-
+    @override
+    def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
         output = torch.squeeze(inpt) if self._dim is None else torch.squeeze(inpt, dim=self._dim)
         return tv_tensors.wrap(output, like=inpt)

eva/vision/data/transforms/croppad/crop_foreground.py

@@ -8,13 +8,13 @@ from monai.config import type_definitions
 from monai.transforms.croppad import array as monai_croppad_transforms
 from monai.utils.enums import PytorchPadMode
 from torchvision import tv_tensors
-from torchvision.transforms import v2
 from typing_extensions import override
 
 from eva.vision.data import tv_tensors as eva_tv_tensors
+from eva.vision.data.transforms import base
 
 
-class CropForeground(v2.Transform):
+class CropForeground(base.TorchvisionTransformV2):
     """Crop an image using a bounding box.
 
     The bounding box is generated by selecting foreground using select_fn
@@ -74,19 +74,20 @@ class CropForeground(v2.Transform):
             **pad_kwargs,
         )
 
-
+    @override
+    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         volume = next(inpt for inpt in flat_inputs if isinstance(inpt, eva_tv_tensors.Volume))
         box_start, box_end = self._foreground_crop.compute_bounding_box(volume)
         return {"box_start": box_start, "box_end": box_end}
 
     @functools.singledispatchmethod
     @override
-    def
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
-    @
-    @
-    @
+    @transform.register(tv_tensors.Image)
+    @transform.register(eva_tv_tensors.Volume)
+    @transform.register(tv_tensors.Mask)
     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
         inpt_foreground_cropped = self._foreground_crop.crop_pad(
             inpt, params["box_start"], params["box_end"]