kaiko-eva 0.2.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. eva/core/data/dataloaders/__init__.py +2 -1
  2. eva/core/data/dataloaders/collate_fn/__init__.py +5 -0
  3. eva/core/data/dataloaders/collate_fn/collate.py +24 -0
  4. eva/core/data/dataloaders/dataloader.py +4 -0
  5. eva/core/interface/interface.py +34 -1
  6. eva/core/metrics/defaults/classification/multiclass.py +45 -35
  7. eva/core/models/modules/__init__.py +2 -1
  8. eva/core/models/modules/scheduler.py +51 -0
  9. eva/core/models/transforms/extract_cls_features.py +1 -1
  10. eva/core/models/transforms/extract_patch_features.py +1 -1
  11. eva/core/models/wrappers/base.py +17 -14
  12. eva/core/models/wrappers/from_function.py +5 -4
  13. eva/core/models/wrappers/from_torchhub.py +5 -6
  14. eva/core/models/wrappers/huggingface.py +8 -5
  15. eva/core/models/wrappers/onnx.py +4 -4
  16. eva/core/trainers/_recorder.py +4 -1
  17. eva/core/trainers/functional.py +40 -43
  18. eva/core/utils/factory.py +66 -0
  19. eva/core/utils/registry.py +42 -0
  20. eva/core/utils/requirements.py +26 -0
  21. eva/language/__init__.py +13 -0
  22. eva/language/data/__init__.py +5 -0
  23. eva/language/data/datasets/__init__.py +9 -0
  24. eva/language/data/datasets/classification/__init__.py +7 -0
  25. eva/language/data/datasets/classification/base.py +63 -0
  26. eva/language/data/datasets/classification/pubmedqa.py +149 -0
  27. eva/language/data/datasets/language.py +13 -0
  28. eva/language/models/__init__.py +25 -0
  29. eva/language/models/modules/__init__.py +5 -0
  30. eva/language/models/modules/text.py +85 -0
  31. eva/language/models/modules/typings.py +16 -0
  32. eva/language/models/wrappers/__init__.py +11 -0
  33. eva/language/models/wrappers/huggingface.py +69 -0
  34. eva/language/models/wrappers/litellm.py +77 -0
  35. eva/language/models/wrappers/vllm.py +149 -0
  36. eva/language/utils/__init__.py +5 -0
  37. eva/language/utils/str_to_int_tensor.py +95 -0
  38. eva/vision/data/dataloaders/__init__.py +2 -1
  39. eva/vision/data/dataloaders/worker_init.py +35 -0
  40. eva/vision/data/datasets/__init__.py +5 -5
  41. eva/vision/data/datasets/segmentation/__init__.py +4 -4
  42. eva/vision/data/datasets/segmentation/btcv.py +3 -0
  43. eva/vision/data/datasets/segmentation/consep.py +5 -4
  44. eva/vision/data/datasets/segmentation/lits17.py +231 -0
  45. eva/vision/data/datasets/segmentation/metadata/__init__.py +1 -0
  46. eva/vision/data/datasets/segmentation/metadata/_msd_task7_pancreas.py +287 -0
  47. eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +243 -0
  48. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +1 -1
  49. eva/vision/data/transforms/__init__.py +11 -2
  50. eva/vision/data/transforms/base/__init__.py +5 -0
  51. eva/vision/data/transforms/base/monai.py +27 -0
  52. eva/vision/data/transforms/common/__init__.py +2 -1
  53. eva/vision/data/transforms/common/squeeze.py +24 -0
  54. eva/vision/data/transforms/croppad/__init__.py +4 -0
  55. eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +74 -0
  56. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -2
  57. eva/vision/data/transforms/croppad/rand_spatial_crop.py +89 -0
  58. eva/vision/data/transforms/intensity/rand_scale_intensity.py +6 -2
  59. eva/vision/data/transforms/intensity/rand_shift_intensity.py +8 -4
  60. eva/vision/models/modules/semantic_segmentation.py +27 -11
  61. eva/vision/models/networks/backbones/__init__.py +2 -3
  62. eva/vision/models/networks/backbones/_utils.py +1 -1
  63. eva/vision/models/networks/backbones/pathology/bioptimus.py +4 -4
  64. eva/vision/models/networks/backbones/pathology/gigapath.py +2 -2
  65. eva/vision/models/networks/backbones/pathology/histai.py +3 -3
  66. eva/vision/models/networks/backbones/pathology/hkust.py +2 -2
  67. eva/vision/models/networks/backbones/pathology/kaiko.py +7 -7
  68. eva/vision/models/networks/backbones/pathology/lunit.py +3 -3
  69. eva/vision/models/networks/backbones/pathology/mahmood.py +3 -3
  70. eva/vision/models/networks/backbones/pathology/owkin.py +3 -3
  71. eva/vision/models/networks/backbones/pathology/paige.py +3 -3
  72. eva/vision/models/networks/backbones/radiology/swin_unetr.py +2 -2
  73. eva/vision/models/networks/backbones/radiology/voco.py +5 -5
  74. eva/vision/models/networks/backbones/registry.py +2 -44
  75. eva/vision/models/networks/backbones/timm/backbones.py +2 -2
  76. eva/vision/models/networks/backbones/universal/__init__.py +8 -1
  77. eva/vision/models/networks/backbones/universal/vit.py +53 -3
  78. eva/vision/models/networks/decoders/segmentation/decoder2d.py +1 -1
  79. eva/vision/models/networks/decoders/segmentation/linear.py +1 -1
  80. eva/vision/models/networks/decoders/segmentation/semantic/common.py +2 -2
  81. eva/vision/models/networks/decoders/segmentation/typings.py +1 -1
  82. eva/vision/models/wrappers/from_registry.py +14 -9
  83. eva/vision/models/wrappers/from_timm.py +6 -5
  84. {kaiko_eva-0.2.1.dist-info → kaiko_eva-0.3.0.dist-info}/METADATA +22 -12
  85. {kaiko_eva-0.2.1.dist-info → kaiko_eva-0.3.0.dist-info}/RECORD +89 -58
  86. {kaiko_eva-0.2.1.dist-info → kaiko_eva-0.3.0.dist-info}/WHEEL +1 -1
  87. eva/vision/data/datasets/segmentation/lits.py +0 -199
  88. eva/vision/data/datasets/segmentation/lits_balanced.py +0 -94
  89. /eva/vision/data/datasets/segmentation/{_total_segmentator.py → metadata/_total_segmentator.py} +0 -0
  90. {kaiko_eva-0.2.1.dist-info → kaiko_eva-0.3.0.dist-info}/entry_points.txt +0 -0
  91. {kaiko_eva-0.2.1.dist-info → kaiko_eva-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,77 @@
+"""LLM wrapper for litellm models."""
+
+from typing import Any, Dict, List
+
+from litellm import batch_completion  # type: ignore
+from loguru import logger
+from typing_extensions import override
+
+from eva.core.models.wrappers import base
+
+
+class LiteLLMTextModel(base.BaseModel[List[str], List[str]]):
+    """Wrapper class for using litellm for chat-based text generation.
+
+    This wrapper uses litellm's `batch_completion` function, which accepts a
+    list of message dicts. The `model_forward` method converts each string
+    prompt into a chat message with the "user" role before generation.
+    """
+
+    def __init__(
+        self,
+        model_name_or_path: str,
+        model_kwargs: Dict[str, Any] | None = None,
+    ) -> None:
+        """Initializes the litellm chat model wrapper.
+
+        Args:
+            model_name_or_path: The model identifier (or name) for litellm
+                (e.g., "openai/gpt-4o" or "anthropic/claude-3-sonnet-20240229").
+            model_kwargs: Additional keyword arguments to pass during
+                generation (e.g., `temperature`, `max_tokens`).
+        """
+        super().__init__()
+        self._model_name_or_path = model_name_or_path
+        self._model_kwargs = model_kwargs or {}
+        self.load_model()
+
+    @override
+    def load_model(self) -> None:
+        """Prepares the litellm model.
+
+        Note:
+            litellm doesn't require an explicit loading step; models are called
+            directly during generation. This method exists for API consistency.
+        """
+        pass
+
+    @override
+    def model_forward(self, prompts: List[str]) -> List[str]:
+        """Generates text using litellm.
+
+        Args:
+            prompts: A list of prompts, each converted into a "user" message.
+
+        Returns:
+            A list of generated text responses.
+
+        Raises:
+            RuntimeError: If generation fails for any of the prompts.
+        """
+        messages = [[{"role": "user", "content": prompt}] for prompt in prompts]
+
+        responses = batch_completion(
+            model=self._model_name_or_path,
+            messages=messages,
+            **self._model_kwargs,
+        )
+
+        results = []
+        for i, response in enumerate(responses):
+            if isinstance(response, Exception):
+                error_msg = f"Error generating text for prompt {i}: {response}"
+                logger.error(error_msg)
+                raise RuntimeError(error_msg)
+            results.append(response["choices"][0]["message"]["content"])
+
+        return results
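
For orientation, a minimal usage sketch of the new wrapper; this is not part of the diff, the model name is illustrative, and litellm credentials (e.g. OPENAI_API_KEY) are assumed to be set in the environment:

    from eva.language.models.wrappers.litellm import LiteLLMTextModel

    # Illustrative identifier; any litellm-supported "provider/model" string works.
    model = LiteLLMTextModel(
        "openai/gpt-4o",
        model_kwargs={"temperature": 0.0, "max_tokens": 64},
    )
    print(model.model_forward(["Is aspirin an NSAID? Answer yes or no."]))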
@@ -0,0 +1,149 @@
+"""LLM wrapper for vLLM models."""
+
+from typing import Any, Dict, List, Sequence
+
+from loguru import logger
+from typing_extensions import override
+
+try:
+    from vllm import LLM, SamplingParams  # type: ignore
+    from vllm.inputs import TokensPrompt  # type: ignore
+    from vllm.transformers_utils.tokenizer import AnyTokenizer  # type: ignore
+except ImportError as e:
+    raise ImportError(
+        "vLLM is required for VLLMTextModel but is not installed. "
+        "It must be installed manually, as it requires CUDA and is not included in "
+        "the dependencies. Install it with: pip install vllm. "
+        "Note: vLLM requires Linux with CUDA support for optimal performance. "
+        "As alternatives, consider using HuggingFaceTextModel or LiteLLMTextModel."
+    ) from e
+
+from eva.core.models.wrappers import base
+
+
+class VLLMTextModel(base.BaseModel):
+    """Wrapper class for using vLLM for text generation.
+
+    This wrapper loads a vLLM model, sets up the tokenizer and sampling
+    parameters, and uses a chat template to convert plain string prompts
+    into the proper input format for vLLM generation. It then returns the
+    generated text responses.
+    """
+
+    def __init__(
+        self,
+        model_name_or_path: str,
+        model_kwargs: Dict[str, Any] | None = None,
+        generation_kwargs: Dict[str, Any] | None = None,
+    ) -> None:
+        """Initializes the vLLM model wrapper.
+
+        Args:
+            model_name_or_path: The model identifier (e.g., a Hugging Face
+                repo ID or local path).
+            model_kwargs: Arguments used to initialize the vLLM model; see
+                [vllm.LLM](https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py)
+                for more information.
+            generation_kwargs: Arguments used to generate the output; these must
+                align with the arguments of
+                [vllm.SamplingParams](https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py).
+        """
+        super().__init__()
+        self._model_name_or_path = model_name_or_path
+        self._model_kwargs = model_kwargs or {}
+        self._generation_kwargs = generation_kwargs or {}
+
+        # Postpone heavy LLM initialisation to avoid pickling issues
+        self._llm_model: LLM | None = None
+        self._llm_tokenizer: AnyTokenizer | None = None
+
+    @override
+    def load_model(self) -> None:
+        """Creates the vLLM engine on first use.
+
+        This lazy initialisation keeps the wrapper picklable by Ray / Lightning.
+        """
+        if self._llm_model is not None:
+            return
+        self._llm_model = LLM(model=self._model_name_or_path, **self._model_kwargs)
+        self._llm_tokenizer = self._llm_model.get_tokenizer()
+
+    def _apply_chat_template(self, prompts: Sequence[str]) -> list[TokensPrompt]:
+        """Applies the chat template to the messages.
+
+        Args:
+            prompts: List of raw user strings.
+
+        Returns:
+            List of encoded messages.
+
+        Raises:
+            ValueError: If the tokenizer does not have a chat template.
+        """
+        self.load_model()
+        if self._llm_tokenizer is None:
+            raise RuntimeError("Tokenizer not initialized")
+
+        if not hasattr(self._llm_tokenizer, "chat_template"):
+            raise ValueError("Tokenizer does not have a chat template.")
+
+        chat_messages = [[{"role": "user", "content": p}] for p in prompts]
+        encoded_messages = self._llm_tokenizer.apply_chat_template(
+            chat_messages,  # type: ignore
+            tokenize=True,
+            add_generation_prompt=True,
+        )
+
+        # Check each encoded message for a duplicated start token (BOS).
+        bos_token_id = getattr(self._llm_tokenizer, "bos_token_id", None)
+        if bos_token_id is not None and isinstance(encoded_messages, list):
+            for index, ids in enumerate(encoded_messages):
+                if isinstance(ids, list) and len(ids) >= 2 and ids[0] == ids[1] == bos_token_id:
+                    logger.warning("Found a double start token in the input_ids. Removing it.")
+                    encoded_messages[index] = ids[1:]
+
+        result = []
+        for encoded_message in encoded_messages:
+            if isinstance(encoded_message, (list, tuple)):
+                # Ensure all elements are integers
+                token_ids = [
+                    int(token) if isinstance(token, (int, str)) and str(token).isdigit() else 0
+                    for token in encoded_message
+                ]
+            else:
+                # Handle single token case
+                token_id = (
+                    int(encoded_message)
+                    if isinstance(encoded_message, (int, str)) and str(encoded_message).isdigit()
+                    else 0
+                )
+                token_ids = [token_id]
+
+            result.append(TokensPrompt(prompt_token_ids=token_ids))
+
+        return result
+
+    def generate(self, prompts: List[str]) -> List[str]:
+        """Generates text for the given prompts using the vLLM model.
+
+        Args:
+            prompts: A list of string prompts for generation.
+
+        Returns:
+            A list of generated text responses.
+        """
+        self.load_model()
+        if self._llm_model is None:
+            raise RuntimeError("Model not initialized")
+
+        prompt_tokens = self._apply_chat_template(prompts)
+        outputs = self._llm_model.generate(prompt_tokens, SamplingParams(**self._generation_kwargs))
+        return [output.outputs[0].text for output in outputs]
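
Likewise, a hedged sketch of driving the vLLM wrapper; it assumes a Linux host with CUDA and vllm installed, and the checkpoint and sampling values are placeholders:

    from eva.language.models.wrappers.vllm import VLLMTextModel

    model = VLLMTextModel(
        "mistralai/Mistral-7B-Instruct-v0.2",  # placeholder Hugging Face repo ID
        model_kwargs={"dtype": "bfloat16"},
        generation_kwargs={"temperature": 0.0, "max_tokens": 128},
    )
    print(model.generate(["Name one liver segmentation dataset."]))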
@@ -0,0 +1,5 @@
+"""Language utilities and helper functions."""
+
+from eva.language.utils.str_to_int_tensor import CastStrToIntTensor
+
+__all__ = ["CastStrToIntTensor"]
@@ -0,0 +1,95 @@
+"""Transform utilities for post-processing predictions."""
+
+import re
+from typing import Any, Dict, List, Union
+
+import torch
+
+
+class CastStrToIntTensor:
+    """Casts string predictions to a torch.Tensor of ints using regex mapping.
+
+    This transform is useful when model outputs are text responses (e.g., 'yes', 'no', 'maybe')
+    that need to be converted into integer tensors for evaluation. It uses regex patterns
+    to map text responses to integer labels, making it flexible for various classification tasks.
+
+    Supports single values, lists of strings, or lists of integers.
+
+    Example:
+        >>> # Default mapping for yes/no/maybe classification
+        >>> transform = CastStrToIntTensor()
+        >>> transform(['yes', 'no', 'maybe'])
+        tensor([1, 0, 2])
+        >>> transform('yes')
+        tensor([1])
+
+        >>> # Custom mapping
+        >>> transform = CastStrToIntTensor({r'positive|good': 1, r'negative|bad': 0})
+        >>> transform(['positive', 'bad'])
+        tensor([1, 0])
+    """
+
+    def __init__(self, mapping: Dict[str, int] | None = None):
+        """Initialize the transform with a regex-to-integer mapping.
+
+        Args:
+            mapping: Dictionary mapping regex patterns to integers. If None, uses default
+                yes/no/maybe mapping: {'no': 0, 'yes': 1, 'maybe': 2}
+        """
+        if mapping is None:
+            self.mapping = {r"\bno\b": 0, r"\byes\b": 1, r"\bmaybe\b": 2}
+        else:
+            self.mapping = mapping
+
+        self.compiled_patterns = [
+            (re.compile(pattern, re.IGNORECASE), value) for pattern, value in self.mapping.items()
+        ]
+
+    def __call__(self, values: Union[str, List[str], List[int]]) -> torch.Tensor:
+        """Convert string or list of strings/ints to a torch.Tensor of integers.
+
+        Args:
+            values: A string, or a list of strings/integers representing responses.
+
+        Returns:
+            A 1D torch.Tensor of integers.
+
+        Raises:
+            ValueError: If any value cannot be mapped to an integer.
+        """
+        return torch.tensor(
+            [self._cast_single(v) for v in (values if isinstance(values, list) else [values])],
+            dtype=torch.int,
+        )
+
+    def _cast_single(self, value: Any) -> int:
+        """Casts a single value to an integer using regex mapping.
+
+        Args:
+            value: A single value to convert (typically a string or int).
+
+        Returns:
+            The value as an integer.
+
+        Raises:
+            ValueError: If the value cannot be mapped to an integer.
+        """
+        if isinstance(value, int):
+            return value
+
+        if not isinstance(value, str):
+            value = str(value)
+
+        value = value.strip()
+
+        for pattern, mapped_value in self.compiled_patterns:
+            if pattern.search(value):
+                return mapped_value
+
+        try:
+            return int(value)
+        except (ValueError, TypeError) as e:
+            raise ValueError(
+                f"Cannot map value to int: {value!r}. "
+                f"Available patterns: {list(self.mapping.keys())}"
+            ) from e
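
A short sketch of where this transform sits: mapping free-text generations (e.g. from the wrappers above) to integer labels before computing classification metrics. Only CastStrToIntTensor comes from this release; the inputs below are made up:

    from eva.language.utils import CastStrToIntTensor

    to_labels = CastStrToIntTensor()  # default yes/no/maybe mapping

    generations = ["Yes, the study supports this.", "no", "Maybe."]
    print(to_labels(generations))  # tensor([1, 0, 2], dtype=torch.int32)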
@@ -1,5 +1,6 @@
 """Dataloader related utilities and functions."""
 
 from eva.vision.data.dataloaders import collate_fn
+from eva.vision.data.dataloaders.worker_init import seed_worker
 
-__all__ = ["collate_fn"]
+__all__ = ["collate_fn", "seed_worker"]
@@ -0,0 +1,35 @@
+"""Dataloader worker init functions."""
+
+import random
+
+import numpy as np
+import torch
+import torch.utils.data
+import torchvision.transforms.v2
+
+from eva.vision.data.transforms import base
+
+
+def seed_worker(worker_id: int) -> None:
+    """Sets the random seed for each dataloader worker process.
+
+    Usage:
+        `torch.utils.data.DataLoader(..., worker_init_fn=seed_worker)`
+
+    Args:
+        worker_id: The ID of the worker process.
+    """
+    worker_seed = (torch.initial_seed() + worker_id) % 2**32
+    np.random.seed(worker_seed)
+    random.seed(worker_seed)
+    torch.manual_seed(worker_seed)
+
+    worker_info = torch.utils.data.get_worker_info()
+    if worker_info is not None and hasattr(worker_info, "dataset"):
+        dataset = worker_info.dataset  # type: ignore
+        if hasattr(dataset, "_transforms"):
+            transforms = dataset._transforms  # type: ignore
+            if isinstance(transforms, torchvision.transforms.v2.Compose):
+                for transform in transforms.transforms:
+                    if isinstance(transform, base.RandomMonaiTransform):
+                        transform.set_random_state(seed=worker_seed)
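
A sketch of wiring the new seed_worker into a loader; the TensorDataset is a stand-in, and any eva vision dataset is used the same way:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from eva.vision.data.dataloaders import seed_worker

    dataset = TensorDataset(torch.arange(8, dtype=torch.float32))  # stand-in dataset
    loader = DataLoader(dataset, batch_size=2, num_workers=2, worker_init_fn=seed_worker)
    batches = list(loader)  # each worker seeds numpy/random/torch deterministically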
@@ -19,9 +19,9 @@ from eva.vision.data.datasets.segmentation import (
     BTCV,
     CoNSeP,
     EmbeddingsSegmentationDataset,
-    LiTS,
-    LiTSBalanced,
+    LiTS17,
     MoNuSAC,
+    MSDTask7Pancreas,
     TotalSegmentator2D,
 )
 from eva.vision.data.datasets.vision import VisionDataset
@@ -40,14 +40,14 @@ __all__ = [
     "PANDASmall",
     "Camelyon16",
     "PatchCamelyon",
+    "TotalSegmentator2D",
     "UniToPatho",
     "WsiClassificationDataset",
     "CoNSeP",
     "EmbeddingsSegmentationDataset",
-    "LiTS",
-    "LiTSBalanced",
+    "LiTS17",
+    "MSDTask7Pancreas",
     "MoNuSAC",
-    "TotalSegmentator2D",
     "VisionDataset",
     "MultiWsiDataset",
     "WsiDataset",
@@ -4,9 +4,9 @@ from eva.vision.data.datasets.segmentation.bcss import BCSS
 from eva.vision.data.datasets.segmentation.btcv import BTCV
 from eva.vision.data.datasets.segmentation.consep import CoNSeP
 from eva.vision.data.datasets.segmentation.embeddings import EmbeddingsSegmentationDataset
-from eva.vision.data.datasets.segmentation.lits import LiTS
-from eva.vision.data.datasets.segmentation.lits_balanced import LiTSBalanced
+from eva.vision.data.datasets.segmentation.lits17 import LiTS17
 from eva.vision.data.datasets.segmentation.monusac import MoNuSAC
+from eva.vision.data.datasets.segmentation.msd_task7_pancreas import MSDTask7Pancreas
 from eva.vision.data.datasets.segmentation.total_segmentator_2d import TotalSegmentator2D
 
 __all__ = [
@@ -14,8 +14,8 @@ __all__ = [
     "BTCV",
     "CoNSeP",
     "EmbeddingsSegmentationDataset",
-    "LiTS",
-    "LiTSBalanced",
+    "LiTS17",
+    "MSDTask7Pancreas",
     "MoNuSAC",
     "TotalSegmentator2D",
 ]
@@ -10,6 +10,7 @@ from torchvision import tv_tensors
 from torchvision.datasets import utils as data_utils
 from typing_extensions import override
 
+from eva.core.utils import requirements
 from eva.vision.data import tv_tensors as eva_tv_tensors
 from eva.vision.data.datasets import _utils as _data_utils
 from eva.vision.data.datasets.segmentation import _utils
@@ -105,6 +106,8 @@ class BTCV(VisionDataset[eva_tv_tensors.Volume, tv_tensors.Mask]):
 
     @override
     def validate(self) -> None:
+        requirements.check_dependencies(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})
+
         def _valid_sample(index: int) -> bool:
             """Indicates if the sample files exist and are reachable."""
             volume_file, segmentation_file = self._samples[self._indices[index]]
@@ -108,6 +108,11 @@ class CoNSeP(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, tv_tens
             n_classes=5,
             first_and_last_labels=((self.classes[0], self.classes[-1])),
         )
+        n_expected = self._expected_dataset_lengths[None]
+        if len(self._file_paths) != n_expected:
+            raise ValueError(
+                f"Expected {n_expected} images, found {len(self._file_paths)} in {self._root}."
+            )
 
     @override
     def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, tv_tensors.Mask, Dict[str, Any]]:
@@ -135,10 +140,6 @@ class CoNSeP(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, tv_tens
     def _load_file_paths(self, split: Literal["train", "val"] | None = None) -> List[str]:
         """Loads the file paths of the corresponding dataset split."""
         paths = list(glob.glob(os.path.join(self._root, "**/Images/*.png"), recursive=True))
-        n_expected = self._expected_dataset_lengths[None]
-        if len(paths) != n_expected:
-            raise ValueError(f"Expected {n_expected} images, found {len(paths)} in {self._root}.")
-
         if split is not None:
             split_to_folder = {"train": "Train", "val": "Test"}
             paths = filter(lambda p: split_to_folder[split] == p.split("/")[-3], paths)