kaiko-eva 0.2.1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eva/core/data/dataloaders/__init__.py +2 -1
- eva/core/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/core/data/dataloaders/collate_fn/collate.py +24 -0
- eva/core/data/dataloaders/dataloader.py +4 -0
- eva/core/interface/interface.py +34 -1
- eva/core/metrics/defaults/classification/multiclass.py +45 -35
- eva/core/models/modules/__init__.py +2 -1
- eva/core/models/modules/scheduler.py +51 -0
- eva/core/models/transforms/extract_cls_features.py +1 -1
- eva/core/models/transforms/extract_patch_features.py +1 -1
- eva/core/models/wrappers/base.py +17 -14
- eva/core/models/wrappers/from_function.py +5 -4
- eva/core/models/wrappers/from_torchhub.py +5 -6
- eva/core/models/wrappers/huggingface.py +8 -5
- eva/core/models/wrappers/onnx.py +4 -4
- eva/core/trainers/_recorder.py +4 -1
- eva/core/trainers/functional.py +40 -43
- eva/core/utils/factory.py +66 -0
- eva/core/utils/registry.py +42 -0
- eva/core/utils/requirements.py +26 -0
- eva/language/__init__.py +13 -0
- eva/language/data/__init__.py +5 -0
- eva/language/data/datasets/__init__.py +9 -0
- eva/language/data/datasets/classification/__init__.py +7 -0
- eva/language/data/datasets/classification/base.py +63 -0
- eva/language/data/datasets/classification/pubmedqa.py +149 -0
- eva/language/data/datasets/language.py +13 -0
- eva/language/models/__init__.py +25 -0
- eva/language/models/modules/__init__.py +5 -0
- eva/language/models/modules/text.py +85 -0
- eva/language/models/modules/typings.py +16 -0
- eva/language/models/wrappers/__init__.py +11 -0
- eva/language/models/wrappers/huggingface.py +69 -0
- eva/language/models/wrappers/litellm.py +77 -0
- eva/language/models/wrappers/vllm.py +149 -0
- eva/language/utils/__init__.py +5 -0
- eva/language/utils/str_to_int_tensor.py +95 -0
- eva/vision/data/dataloaders/__init__.py +2 -1
- eva/vision/data/dataloaders/worker_init.py +35 -0
- eva/vision/data/datasets/__init__.py +5 -5
- eva/vision/data/datasets/segmentation/__init__.py +4 -4
- eva/vision/data/datasets/segmentation/btcv.py +3 -0
- eva/vision/data/datasets/segmentation/consep.py +5 -4
- eva/vision/data/datasets/segmentation/lits17.py +231 -0
- eva/vision/data/datasets/segmentation/metadata/__init__.py +1 -0
- eva/vision/data/datasets/segmentation/metadata/_msd_task7_pancreas.py +287 -0
- eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +243 -0
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +1 -1
- eva/vision/data/transforms/__init__.py +11 -2
- eva/vision/data/transforms/base/__init__.py +5 -0
- eva/vision/data/transforms/base/monai.py +27 -0
- eva/vision/data/transforms/common/__init__.py +2 -1
- eva/vision/data/transforms/common/squeeze.py +24 -0
- eva/vision/data/transforms/croppad/__init__.py +4 -0
- eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +74 -0
- eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -2
- eva/vision/data/transforms/croppad/rand_spatial_crop.py +89 -0
- eva/vision/data/transforms/intensity/rand_scale_intensity.py +6 -2
- eva/vision/data/transforms/intensity/rand_shift_intensity.py +8 -4
- eva/vision/models/modules/semantic_segmentation.py +27 -11
- eva/vision/models/networks/backbones/__init__.py +2 -3
- eva/vision/models/networks/backbones/_utils.py +1 -1
- eva/vision/models/networks/backbones/pathology/bioptimus.py +4 -4
- eva/vision/models/networks/backbones/pathology/gigapath.py +2 -2
- eva/vision/models/networks/backbones/pathology/histai.py +3 -3
- eva/vision/models/networks/backbones/pathology/hkust.py +2 -2
- eva/vision/models/networks/backbones/pathology/kaiko.py +7 -7
- eva/vision/models/networks/backbones/pathology/lunit.py +3 -3
- eva/vision/models/networks/backbones/pathology/mahmood.py +3 -3
- eva/vision/models/networks/backbones/pathology/owkin.py +3 -3
- eva/vision/models/networks/backbones/pathology/paige.py +3 -3
- eva/vision/models/networks/backbones/radiology/swin_unetr.py +2 -2
- eva/vision/models/networks/backbones/radiology/voco.py +5 -5
- eva/vision/models/networks/backbones/registry.py +2 -44
- eva/vision/models/networks/backbones/timm/backbones.py +2 -2
- eva/vision/models/networks/backbones/universal/__init__.py +8 -1
- eva/vision/models/networks/backbones/universal/vit.py +53 -3
- eva/vision/models/networks/decoders/segmentation/decoder2d.py +1 -1
- eva/vision/models/networks/decoders/segmentation/linear.py +1 -1
- eva/vision/models/networks/decoders/segmentation/semantic/common.py +2 -2
- eva/vision/models/networks/decoders/segmentation/typings.py +1 -1
- eva/vision/models/wrappers/from_registry.py +14 -9
- eva/vision/models/wrappers/from_timm.py +6 -5
- {kaiko_eva-0.2.1.dist-info → kaiko_eva-0.3.0.dist-info}/METADATA +22 -12
- {kaiko_eva-0.2.1.dist-info → kaiko_eva-0.3.0.dist-info}/RECORD +89 -58
- {kaiko_eva-0.2.1.dist-info → kaiko_eva-0.3.0.dist-info}/WHEEL +1 -1
- eva/vision/data/datasets/segmentation/lits.py +0 -199
- eva/vision/data/datasets/segmentation/lits_balanced.py +0 -94
- /eva/vision/data/datasets/segmentation/{_total_segmentator.py → metadata/_total_segmentator.py} +0 -0
- {kaiko_eva-0.2.1.dist-info → kaiko_eva-0.3.0.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.2.1.dist-info → kaiko_eva-0.3.0.dist-info}/licenses/LICENSE +0 -0
eva/language/models/wrappers/litellm.py

```diff
@@ -0,0 +1,77 @@
+"""LLM wrapper for litellm models."""
+
+from typing import Any, Dict, List
+
+from litellm import batch_completion  # type: ignore
+from loguru import logger
+from typing_extensions import override
+
+from eva.core.models.wrappers import base
+
+
+class LiteLLMTextModel(base.BaseModel[List[str], List[str]]):
+    """Wrapper class for using litellm for chat-based text generation.
+
+    This wrapper uses litellm's `batch_completion` function, which accepts
+    one list of message dicts per prompt. The `model_forward` method converts
+    each string prompt into a chat message with the "user" role and returns
+    the generated responses.
+    """
+
+    def __init__(
+        self,
+        model_name_or_path: str,
+        model_kwargs: Dict[str, Any] | None = None,
+    ) -> None:
+        """Initializes the litellm chat model wrapper.
+
+        Args:
+            model_name_or_path: The model identifier (or name) for litellm
+                (e.g., "openai/gpt-4o" or "anthropic/claude-3-sonnet-20240229").
+            model_kwargs: Additional keyword arguments to pass during
+                generation (e.g., `temperature`, `max_tokens`).
+        """
+        super().__init__()
+        self._model_name_or_path = model_name_or_path
+        self._model_kwargs = model_kwargs or {}
+        self.load_model()
+
+    @override
+    def load_model(self) -> None:
+        """Prepares the litellm model.
+
+        Note:
+            litellm doesn't require an explicit loading step; models are called
+            directly during generation. This method exists for API consistency.
+        """
+        pass
+
+    @override
+    def model_forward(self, prompts: List[str]) -> List[str]:
+        """Generates text using litellm.
+
+        Args:
+            prompts: A list of prompts, each converted into a "user" message.
+
+        Returns:
+            A list of generated text responses, one per prompt. A RuntimeError
+            is raised if generation failed for any prompt.
+        """
+        messages = [[{"role": "user", "content": prompt}] for prompt in prompts]
+
+        responses = batch_completion(
+            model=self._model_name_or_path,
+            messages=messages,
+            **self._model_kwargs,
+        )
+
+        results = []
+        for i, response in enumerate(responses):
+            if isinstance(response, Exception):
+                error_msg = f"Error generating text for prompt {i}: {response}"
+                logger.error(error_msg)
+                raise RuntimeError(error_msg)
+            else:
+                results.append(response["choices"][0]["message"]["content"])
+
+        return results
```
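A minimal usage sketch of the new wrapper (hypothetical model id and prompt; assumes `litellm` is installed and the provider credentials, e.g. `OPENAI_API_KEY`, are set in the environment):

```python
# Sketch only: eva's text modules normally drive this wrapper, but
# `model_forward` can also be called directly.
from eva.language.models.wrappers.litellm import LiteLLMTextModel

model = LiteLLMTextModel(
    model_name_or_path="openai/gpt-4o",  # any litellm-routable model id
    model_kwargs={"temperature": 0.0, "max_tokens": 16},
)
responses = model.model_forward(["Does aspirin reduce fever? Answer yes or no."])
print(responses)  # e.g. ["Yes"]
```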
eva/language/models/wrappers/vllm.py

```diff
@@ -0,0 +1,149 @@
+"""LLM wrapper for vLLM models."""
+
+from typing import Any, Dict, List, Sequence
+
+from loguru import logger
+from typing_extensions import override
+
+try:
+    from vllm import LLM, SamplingParams  # type: ignore
+    from vllm.inputs import TokensPrompt  # type: ignore
+    from vllm.transformers_utils.tokenizer import AnyTokenizer  # type: ignore
+except ImportError as e:
+    raise ImportError(
+        "vLLM is required for VLLMTextModel but not installed. "
+        "vLLM must be installed manually as it requires CUDA and is not included in dependencies. "
+        "Install with: pip install vllm. "
+        "Note: vLLM requires Linux with CUDA support for optimal performance. "
+        "For alternatives, consider using HuggingFaceTextModel or LiteLLMTextModel."
+    ) from e
+
+from eva.core.models.wrappers import base
+
+
+class VLLMTextModel(base.BaseModel):
+    """Wrapper class for using vLLM for text generation.
+
+    This wrapper loads a vLLM model, sets up the tokenizer and sampling
+    parameters, and uses a chat template to convert a plain string prompt
+    into the proper input format for vLLM generation. It then returns the
+    generated text response.
+    """
+
+    def __init__(
+        self,
+        model_name_or_path: str,
+        model_kwargs: Dict[str, Any] | None = None,
+        generation_kwargs: Dict[str, Any] | None = None,
+    ) -> None:
+        """Initializes the vLLM model wrapper.
+
+        Args:
+            model_name_or_path: The model identifier (e.g., a Hugging Face
+                repo ID or local path).
+            model_kwargs: Arguments required to initialize the vLLM model,
+                see [link](https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py)
+                for more information.
+            generation_kwargs: Arguments used to generate the output; these
+                must align with the arguments of
+                [vllm.SamplingParams](https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py).
+
+        """
+        super().__init__()
+        self._model_name_or_path = model_name_or_path
+        self._model_kwargs = model_kwargs or {}
+        self._generation_kwargs = generation_kwargs or {}
+
+        # Postpone heavy LLM initialisation to avoid pickling issues
+        self._llm_model: LLM | None = None
+        self._llm_tokenizer: AnyTokenizer | None = None
+
+    @override
+    def load_model(self) -> None:
+        """Creates the vLLM engine on first use.
+
+        This lazy initialisation keeps the wrapper picklable by Ray / Lightning.
+        """
+        if self._llm_model is not None:
+            return
+        self._llm_model = LLM(model=self._model_name_or_path, **self._model_kwargs)
+        if self._llm_model is None:
+            raise RuntimeError("Model not initialized")
+        self._llm_tokenizer = self._llm_model.get_tokenizer()
+
+    def _apply_chat_template(self, prompts: Sequence[str]) -> list[TokensPrompt]:
+        """Applies the chat template to the messages.
+
+        Args:
+            prompts: List of raw user strings.
+
+        Returns:
+            List of encoded messages.
+
+        Raises:
+            ValueError: If the tokenizer does not have a chat template.
+        """
+        self.load_model()
+        if self._llm_tokenizer is None:
+            raise RuntimeError("Tokenizer not initialized")
+
+        if not hasattr(self._llm_tokenizer, "chat_template"):
+            raise ValueError("Tokenizer does not have a chat template.")
+
+        chat_messages = [[{"role": "user", "content": p}] for p in prompts]
+        encoded_messages = self._llm_tokenizer.apply_chat_template(
+            chat_messages,  # type: ignore
+            tokenize=True,
+            add_generation_prompt=True,
+        )
+
+        # Check for a double start token (BOS)
+        if (
+            hasattr(self._llm_tokenizer, "bos_token_id")
+            and self._llm_tokenizer.bos_token_id is not None
+            and isinstance(encoded_messages, list)
+            and len(encoded_messages) >= 2
+            and encoded_messages[0] == self._llm_tokenizer.bos_token_id
+            and encoded_messages[1] == self._llm_tokenizer.bos_token_id
+        ):
+
+            logger.warning("Found a double start token in the input_ids. Removing it.")
+            encoded_messages = encoded_messages[1:]
+
+        result = []
+        for encoded_message in encoded_messages:
+            if isinstance(encoded_message, (list, tuple)):
+                # Ensure all elements are integers
+                token_ids = [
+                    int(token) if isinstance(token, (int, str)) and str(token).isdigit() else 0
+                    for token in encoded_message
+                ]
+            else:
+                # Handle the single-token case
+                token_id = (
+                    int(encoded_message)
+                    if isinstance(encoded_message, (int, str)) and str(encoded_message).isdigit()
+                    else 0
+                )
+                token_ids = [token_id]
+
+            result.append(TokensPrompt(prompt_token_ids=token_ids))
+
+        return result
+
+    def generate(self, prompts: List[str]) -> List[str]:
+        """Generates text for the given prompts using the vLLM model.
+
+        Args:
+            prompts: A list of string prompts for generation.
+
+        Returns:
+            A list of generated text responses, one per prompt.
+        """
+        self.load_model()
+        if self._llm_model is None:
+            raise RuntimeError("Model not initialized")
+
+        prompt_tokens = self._apply_chat_template(prompts)
+        outputs = self._llm_model.generate(prompt_tokens, SamplingParams(**self._generation_kwargs))
+        return [output.outputs[0].text for output in outputs]
```
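A corresponding usage sketch (illustrative repo id and kwargs; requires a Linux machine with CUDA and `pip install vllm`):

```python
# Sketch only: `generate` lazily builds the engine, applies the tokenizer's
# chat template to each prompt, and decodes one completion per prompt.
from eva.language.models.wrappers.vllm import VLLMTextModel

model = VLLMTextModel(
    model_name_or_path="Qwen/Qwen2.5-1.5B-Instruct",  # illustrative repo id
    model_kwargs={"dtype": "bfloat16"},  # forwarded to vllm.LLM
    generation_kwargs={"temperature": 0.0, "max_tokens": 32},  # vllm.SamplingParams
)
print(model.generate(["Summarize the PubMedQA task in one sentence."]))
```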
eva/language/utils/str_to_int_tensor.py

```diff
@@ -0,0 +1,95 @@
+"""Transform utilities for post-processing predictions."""
+
+import re
+from typing import Any, Dict, List, Union
+
+import torch
+
+
+class CastStrToIntTensor:
+    """Casts string predictions to a torch.Tensor of ints using regex mapping.
+
+    This transform is useful when model outputs are text responses (e.g., 'yes', 'no', 'maybe')
+    that need to be converted into integer tensors for evaluation. It uses regex patterns
+    to map text responses to integer labels, making it flexible for various classification tasks.
+
+    Supports single values, lists of strings, or lists of integers.
+
+    Example:
+        >>> # Default mapping for yes/no/maybe classification
+        >>> transform = CastStrToIntTensor()
+        >>> transform(['yes', 'no', 'maybe'])
+        tensor([1, 0, 2])
+        >>> transform('yes')
+        tensor([1])
+
+        >>> # Custom mapping
+        >>> transform = CastStrToIntTensor({r'positive|good': 1, r'negative|bad': 0})
+        >>> transform(['positive', 'bad'])
+        tensor([1, 0])
+    """
+
+    def __init__(self, mapping: Dict[str, int] | None = None):
+        """Initialize the transform with a regex-to-integer mapping.
+
+        Args:
+            mapping: Dictionary mapping regex patterns to integers. If None, uses default
+                yes/no/maybe mapping: {'no': 0, 'yes': 1, 'maybe': 2}
+        """
+        if mapping is None:
+            self.mapping = {r"\bno\b": 0, r"\byes\b": 1, r"\bmaybe\b": 2}
+        else:
+            self.mapping = mapping
+
+        self.compiled_patterns = [
+            (re.compile(pattern, re.IGNORECASE), value) for pattern, value in self.mapping.items()
+        ]
+
+    def __call__(self, values: Union[str, List[str], List[int]]) -> torch.Tensor:
+        """Convert string or list of strings/ints to a torch.Tensor of integers.
+
+        Args:
+            values: A string, or a list of strings/integers representing responses.
+
+        Returns:
+            A 1D torch.Tensor of integers.
+
+        Raises:
+            ValueError: If any value cannot be mapped to an integer.
+        """
+        return torch.tensor(
+            [self._cast_single(v) for v in (values if isinstance(values, list) else [values])],
+            dtype=torch.int,
+        )
+
+    def _cast_single(self, value: Any) -> int:
+        """Casts a single value to an integer using regex mapping.
+
+        Args:
+            value: A single value to convert (typically a string or int).
+
+        Returns:
+            The value as an integer.
+
+        Raises:
+            ValueError: If the value cannot be mapped to an integer.
+        """
+        if isinstance(value, int):
+            return value
+
+        if not isinstance(value, str):
+            value = str(value)
+
+        value = value.strip()
+
+        for pattern, mapped_value in self.compiled_patterns:
+            if pattern.search(value):
+                return mapped_value
+
+        try:
+            return int(value)
+        except (ValueError, TypeError) as e:
+            raise ValueError(
+                f"Cannot map value to int: {value!r}. "
+                f"Available patterns: {list(self.mapping.keys())}"
+            ) from e
```
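In practice this transform sits between free-text generation and metric computation. A sketch with made-up predictions, using torchmetrics (which eva already uses for its metric defaults):

```python
# Sketch: map generated answers to integer labels, then score them.
import torch
from torchmetrics.classification import MulticlassAccuracy

from eva.language.utils.str_to_int_tensor import CastStrToIntTensor

transform = CastStrToIntTensor()  # default mapping: no -> 0, yes -> 1, maybe -> 2
preds = transform(["Yes, it does.", "no", "Maybe."])  # tensor([1, 0, 2])
target = torch.tensor([1, 0, 1])  # made-up ground truth
print(MulticlassAccuracy(num_classes=3)(preds, target))
```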
eva/vision/data/dataloaders/worker_init.py

```diff
@@ -0,0 +1,35 @@
+"""Dataloader worker init functions."""
+
+import random
+
+import numpy as np
+import torch
+import torch.utils.data
+import torchvision.transforms.v2
+
+from eva.vision.data.transforms import base
+
+
+def seed_worker(worker_id: int) -> None:
+    """Sets the random seed for each dataloader worker process.
+
+    Usage:
+    `torch.utils.data.DataLoader(..., worker_init_fn=seed_worker)`
+
+    Args:
+        worker_id: The ID of the worker process.
+    """
+    worker_seed = (torch.initial_seed() + worker_id) % 2**32
+    np.random.seed(worker_seed)
+    random.seed(worker_seed)
+    torch.manual_seed(worker_seed)
+
+    worker_info = torch.utils.data.get_worker_info()
+    if worker_info is not None and hasattr(worker_info, "dataset"):
+        dataset = worker_info.dataset  # type: ignore
+        if hasattr(dataset, "_transforms"):
+            transforms = dataset._transforms  # type: ignore
+            if isinstance(transforms, torchvision.transforms.v2.Compose):
+                for transform in transforms.transforms:
+                    if isinstance(transform, base.RandomMonaiTransform):
+                        transform.set_random_state(seed=worker_seed)
```
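The hook is attached via the `worker_init_fn` argument; a self-contained sketch (placeholder dataset, seed value arbitrary):

```python
# Sketch: reproducible worker seeding. Pairing seed_worker with a seeded
# generator makes torch.initial_seed(), and hence every worker seed, fixed.
import torch
from torch.utils.data import DataLoader, TensorDataset

from eva.vision.data.dataloaders.worker_init import seed_worker

dataset = TensorDataset(torch.arange(32).float())  # stand-in dataset
loader = DataLoader(
    dataset,
    batch_size=8,
    num_workers=4,
    worker_init_fn=seed_worker,  # reseeds numpy/random/torch and MONAI transforms
    generator=torch.Generator().manual_seed(42),
)
```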
eva/vision/data/datasets/__init__.py

```diff
@@ -19,9 +19,9 @@ from eva.vision.data.datasets.segmentation import (
     BTCV,
     CoNSeP,
     EmbeddingsSegmentationDataset,
-    LiTS,
-    LiTSBalanced,
+    LiTS17,
     MoNuSAC,
+    MSDTask7Pancreas,
     TotalSegmentator2D,
 )
 from eva.vision.data.datasets.vision import VisionDataset
```
```diff
@@ -40,14 +40,14 @@ __all__ = [
     "PANDASmall",
     "Camelyon16",
     "PatchCamelyon",
+    "TotalSegmentator2D",
     "UniToPatho",
     "WsiClassificationDataset",
     "CoNSeP",
     "EmbeddingsSegmentationDataset",
-    "LiTS",
-    "LiTSBalanced",
+    "LiTS17",
+    "MSDTask7Pancreas",
     "MoNuSAC",
-    "TotalSegmentator2D",
     "VisionDataset",
     "MultiWsiDataset",
     "WsiDataset",
```
eva/vision/data/datasets/segmentation/__init__.py

```diff
@@ -4,9 +4,9 @@ from eva.vision.data.datasets.segmentation.bcss import BCSS
 from eva.vision.data.datasets.segmentation.btcv import BTCV
 from eva.vision.data.datasets.segmentation.consep import CoNSeP
 from eva.vision.data.datasets.segmentation.embeddings import EmbeddingsSegmentationDataset
-from eva.vision.data.datasets.segmentation.lits import LiTS
-from eva.vision.data.datasets.segmentation.lits_balanced import LiTSBalanced
+from eva.vision.data.datasets.segmentation.lits17 import LiTS17
 from eva.vision.data.datasets.segmentation.monusac import MoNuSAC
+from eva.vision.data.datasets.segmentation.msd_task7_pancreas import MSDTask7Pancreas
 from eva.vision.data.datasets.segmentation.total_segmentator_2d import TotalSegmentator2D
 
 __all__ = [
```
```diff
@@ -14,8 +14,8 @@ __all__ = [
     "BTCV",
     "CoNSeP",
     "EmbeddingsSegmentationDataset",
-    "LiTS",
-    "LiTSBalanced",
+    "LiTS17",
+    "MSDTask7Pancreas",
     "MoNuSAC",
     "TotalSegmentator2D",
 ]
```
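For downstream code, the rename means imports such as `LiTS`/`LiTSBalanced` no longer resolve and need updating:

```python
# Before (0.2.1):
# from eva.vision.data.datasets import LiTS, LiTSBalanced
# After (0.3.0):
from eva.vision.data.datasets import LiTS17, MSDTask7Pancreas
```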
eva/vision/data/datasets/segmentation/btcv.py

```diff
@@ -10,6 +10,7 @@ from torchvision import tv_tensors
 from torchvision.datasets import utils as data_utils
 from typing_extensions import override
 
+from eva.core.utils import requirements
 from eva.vision.data import tv_tensors as eva_tv_tensors
 from eva.vision.data.datasets import _utils as _data_utils
 from eva.vision.data.datasets.segmentation import _utils
```
```diff
@@ -105,6 +106,8 @@ class BTCV(VisionDataset[eva_tv_tensors.Volume, tv_tensors.Mask]):
 
     @override
     def validate(self) -> None:
+        requirements.check_dependencies(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})
+
         def _valid_sample(index: int) -> bool:
             """Indicates if the sample files exist and are reachable."""
             volume_file, segmentation_file = self._samples[self._indices[index]]
```
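The body of the new `eva/core/utils/requirements.py` helper is not shown in this diff; inferred from the call site alone, a plausible sketch (hypothetical, the released implementation may differ):

```python
# Hypothetical sketch of check_dependencies, inferred from its call site;
# not the released implementation.
from importlib import metadata
from typing import Dict

from packaging.version import Version


def check_dependencies(requirements: Dict[str, str]) -> None:
    """Raises if an installed package is older than its required version."""
    for package, min_version in requirements.items():
        installed = metadata.version(package)
        if Version(installed) < Version(min_version):
            raise RuntimeError(f"{package}>={min_version} is required, found {installed}.")
```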
eva/vision/data/datasets/segmentation/consep.py

```diff
@@ -108,6 +108,11 @@ class CoNSeP(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
             n_classes=5,
             first_and_last_labels=((self.classes[0], self.classes[-1])),
         )
+        n_expected = self._expected_dataset_lengths[None]
+        if len(self._file_paths) != n_expected:
+            raise ValueError(
+                f"Expected {n_expected} images, found {len(self._file_paths)} in {self._root}."
+            )
 
     @override
     def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, tv_tensors.Mask, Dict[str, Any]]:
```
```diff
@@ -135,10 +140,6 @@ class CoNSeP(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
     def _load_file_paths(self, split: Literal["train", "val"] | None = None) -> List[str]:
         """Loads the file paths of the corresponding dataset split."""
         paths = list(glob.glob(os.path.join(self._root, "**/Images/*.png"), recursive=True))
-        n_expected = self._expected_dataset_lengths[None]
-        if len(paths) != n_expected:
-            raise ValueError(f"Expected {n_expected} images, found {len(paths)} in {self._root}.")
-
         if split is not None:
             split_to_folder = {"train": "Train", "val": "Test"}
             paths = filter(lambda p: split_to_folder[split] == p.split("/")[-3], paths)
```