kaiko-eva 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. eva/core/data/dataloaders/__init__.py +2 -1
  2. eva/core/data/dataloaders/collate_fn/__init__.py +5 -0
  3. eva/core/data/dataloaders/collate_fn/collate.py +24 -0
  4. eva/core/data/dataloaders/dataloader.py +4 -0
  5. eva/core/interface/interface.py +34 -1
  6. eva/core/metrics/defaults/classification/multiclass.py +45 -35
  7. eva/core/models/modules/__init__.py +2 -1
  8. eva/core/models/modules/scheduler.py +51 -0
  9. eva/core/models/transforms/extract_cls_features.py +1 -1
  10. eva/core/models/transforms/extract_patch_features.py +1 -1
  11. eva/core/models/wrappers/base.py +17 -14
  12. eva/core/models/wrappers/from_function.py +5 -4
  13. eva/core/models/wrappers/from_torchhub.py +5 -6
  14. eva/core/models/wrappers/huggingface.py +8 -5
  15. eva/core/models/wrappers/onnx.py +4 -4
  16. eva/core/trainers/functional.py +40 -43
  17. eva/core/utils/factory.py +66 -0
  18. eva/core/utils/registry.py +42 -0
  19. eva/core/utils/requirements.py +26 -0
  20. eva/language/__init__.py +13 -0
  21. eva/language/data/__init__.py +5 -0
  22. eva/language/data/datasets/__init__.py +9 -0
  23. eva/language/data/datasets/classification/__init__.py +7 -0
  24. eva/language/data/datasets/classification/base.py +63 -0
  25. eva/language/data/datasets/classification/pubmedqa.py +149 -0
  26. eva/language/data/datasets/language.py +13 -0
  27. eva/language/models/__init__.py +25 -0
  28. eva/language/models/modules/__init__.py +5 -0
  29. eva/language/models/modules/text.py +85 -0
  30. eva/language/models/modules/typings.py +16 -0
  31. eva/language/models/wrappers/__init__.py +11 -0
  32. eva/language/models/wrappers/huggingface.py +69 -0
  33. eva/language/models/wrappers/litellm.py +77 -0
  34. eva/language/models/wrappers/vllm.py +149 -0
  35. eva/language/utils/__init__.py +5 -0
  36. eva/language/utils/str_to_int_tensor.py +95 -0
  37. eva/vision/data/dataloaders/__init__.py +2 -1
  38. eva/vision/data/dataloaders/worker_init.py +35 -0
  39. eva/vision/data/datasets/__init__.py +5 -5
  40. eva/vision/data/datasets/segmentation/__init__.py +4 -4
  41. eva/vision/data/datasets/segmentation/btcv.py +3 -0
  42. eva/vision/data/datasets/segmentation/consep.py +5 -4
  43. eva/vision/data/datasets/segmentation/lits17.py +231 -0
  44. eva/vision/data/datasets/segmentation/metadata/__init__.py +1 -0
  45. eva/vision/data/datasets/segmentation/metadata/_msd_task7_pancreas.py +287 -0
  46. eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +243 -0
  47. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +1 -1
  48. eva/vision/data/transforms/__init__.py +11 -2
  49. eva/vision/data/transforms/base/__init__.py +5 -0
  50. eva/vision/data/transforms/base/monai.py +27 -0
  51. eva/vision/data/transforms/common/__init__.py +2 -1
  52. eva/vision/data/transforms/common/squeeze.py +24 -0
  53. eva/vision/data/transforms/croppad/__init__.py +4 -0
  54. eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +74 -0
  55. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -2
  56. eva/vision/data/transforms/croppad/rand_spatial_crop.py +89 -0
  57. eva/vision/data/transforms/intensity/rand_scale_intensity.py +6 -2
  58. eva/vision/data/transforms/intensity/rand_shift_intensity.py +8 -4
  59. eva/vision/models/modules/semantic_segmentation.py +18 -7
  60. eva/vision/models/networks/backbones/__init__.py +2 -3
  61. eva/vision/models/networks/backbones/_utils.py +1 -1
  62. eva/vision/models/networks/backbones/pathology/bioptimus.py +4 -4
  63. eva/vision/models/networks/backbones/pathology/gigapath.py +2 -2
  64. eva/vision/models/networks/backbones/pathology/histai.py +3 -3
  65. eva/vision/models/networks/backbones/pathology/hkust.py +2 -2
  66. eva/vision/models/networks/backbones/pathology/kaiko.py +7 -7
  67. eva/vision/models/networks/backbones/pathology/lunit.py +3 -3
  68. eva/vision/models/networks/backbones/pathology/mahmood.py +3 -3
  69. eva/vision/models/networks/backbones/pathology/owkin.py +3 -3
  70. eva/vision/models/networks/backbones/pathology/paige.py +3 -3
  71. eva/vision/models/networks/backbones/radiology/swin_unetr.py +2 -2
  72. eva/vision/models/networks/backbones/radiology/voco.py +5 -5
  73. eva/vision/models/networks/backbones/registry.py +2 -44
  74. eva/vision/models/networks/backbones/timm/backbones.py +2 -2
  75. eva/vision/models/networks/backbones/universal/__init__.py +8 -1
  76. eva/vision/models/networks/backbones/universal/vit.py +53 -3
  77. eva/vision/models/networks/decoders/segmentation/decoder2d.py +1 -1
  78. eva/vision/models/networks/decoders/segmentation/linear.py +1 -1
  79. eva/vision/models/networks/decoders/segmentation/semantic/common.py +2 -2
  80. eva/vision/models/networks/decoders/segmentation/typings.py +1 -1
  81. eva/vision/models/wrappers/from_registry.py +14 -9
  82. eva/vision/models/wrappers/from_timm.py +6 -5
  83. {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.0.dist-info}/METADATA +10 -2
  84. {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.0.dist-info}/RECORD +88 -57
  85. {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.0.dist-info}/WHEEL +1 -1
  86. eva/vision/data/datasets/segmentation/lits.py +0 -199
  87. eva/vision/data/datasets/segmentation/lits_balanced.py +0 -94
  88. /eva/vision/data/datasets/segmentation/{_total_segmentator.py → metadata/_total_segmentator.py} +0 -0
  89. {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.0.dist-info}/entry_points.txt +0 -0
  90. {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,231 @@
1
+ """LiTS17 dataset."""
2
+
3
+ import glob
4
+ import os
5
+ import re
6
+ from typing import Any, Callable, Dict, List, Literal, Tuple
7
+
8
+ from torchvision import tv_tensors
9
+ from typing_extensions import override
10
+
11
+ from eva.core.utils import requirements
12
+ from eva.vision.data import tv_tensors as eva_tv_tensors
13
+ from eva.vision.data.datasets import _utils as _data_utils
14
+ from eva.vision.data.datasets.segmentation import _utils
15
+ from eva.vision.data.datasets.vision import VisionDataset
16
+
17
+
18
class LiTS17(VisionDataset[eva_tv_tensors.Volume, tv_tensors.Mask]):
    """LiTS17 - Liver Tumor Segmentation Challenge 2017.

    More info:
    - The Liver Tumor Segmentation Benchmark (LiTS)
      https://arxiv.org/pdf/1901.04056
    - Dataset Split
      https://github.com/Luffy03/Large-Scale-Medical/blob/main/Downstream/monai/LiTs/dataset_lits.json
    - Data needs to be manually downloaded from:
      https://drive.google.com/drive/folders/0B0vscETPGI1-Q1h1WFdEM2FHSUE
    """

    _train_index_ranges: List[Tuple[int, int]] = [
        (0, 2),
        (4, 14),
        (15, 16),
        (18, 48),
        (50, 51),
        (52, 57),
        (58, 65),
        (66, 67),
        (71, 74),
        (75, 81),
        (82, 85),
        (86, 92),
        (93, 99),
        (102, 103),
        (104, 116),
        (117, 123),
        (124, 126),
        (127, 131),
    ]
    """Train range indices."""

    _val_index_ranges: List[Tuple[int, int]] = [
        (2, 4),
        (14, 15),
        (16, 18),
        (48, 50),
        (51, 52),
        (57, 58),
        (65, 66),
        (67, 68),
        (70, 71),
        (74, 75),
        (81, 82),
        (85, 86),
        (92, 93),
        (99, 102),
        (103, 104),
        (116, 117),
        (123, 124),
        (126, 127),
    ]
    """Validation range indices."""

    _split_index_ranges = {
        "train": _train_index_ranges,
        "val": _val_index_ranges,
        None: [(0, 128)],
    }
    """Sample indices for the dataset splits."""

    def __init__(
        self,
        root: str,
        split: Literal["train", "val"] | None = None,
        transforms: Callable | None = None,
    ) -> None:
        """Initializes the dataset.

        Args:
            root: Path to the dataset root directory.
            split: Dataset split to use ('train' or 'val').
                If None, it uses the full dataset.
            transforms: A callable object for applying data transformations.
                If None, no transformations are applied.
        """
        super().__init__()

        self._root = root
        self._split = split
        self._transforms = transforms

        # Populated by `configure()`: file-ID -> (volume path, segmentation path),
        # and the positional-index -> file-ID mapping for the selected split.
        self._samples: Dict[int, Tuple[str, str]]
        self._indices: List[int]

    @property
    @override
    def classes(self) -> List[str]:
        return ["background", "liver", "tumor"]

    @property
    @override
    def class_to_idx(self) -> Dict[str, int]:
        return {label: index for index, label in enumerate(self.classes)}

    @override
    def filename(self, index: int) -> str:
        return os.path.relpath(self._samples[self._indices[index]][0], self._root)

    @override
    def configure(self) -> None:
        self._samples = self._find_samples()
        self._indices = self._make_indices()

    @override
    def validate(self) -> None:
        requirements.check_dependencies(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})

        def _valid_sample(index: int) -> bool:
            """Indicates if the sample files exist and are reachable."""
            volume_file, segmentation_file = self._samples[self._indices[index]]
            return os.path.isfile(volume_file) and os.path.isfile(segmentation_file)

        if len(self._samples) < len(self._indices):
            raise OSError(f"Dataset is missing {len(self._indices) - len(self._samples)} files.")

        # Fix: `_samples` is keyed by file ID, not positional index, so the
        # positional loop index must be mapped through `_indices` first
        # (the previous code did `self._samples[i]`, which picks the wrong
        # entry — or raises KeyError — for the 'train'/'val' splits).
        invalid_samples = [
            self._samples[self._indices[i]] for i in range(len(self)) if not _valid_sample(i)
        ]
        if invalid_samples:
            raise OSError(
                f"Dataset '{self.__class__.__qualname__}' contains missing or "
                f"corrupted samples ({len(invalid_samples)} in total). "
                f"Examples of missing files: {str(invalid_samples[:10])[:-1]}, ...]. "
            )

    @override
    def __getitem__(
        self, index: int
    ) -> tuple[eva_tv_tensors.Volume, tv_tensors.Mask, dict[str, Any]]:
        volume = self.load_data(index)
        mask = self.load_target(index)
        metadata = self.load_metadata(index) or {}
        volume_tensor, mask_tensor = self._apply_transforms(volume, mask)
        return volume_tensor, mask_tensor, metadata

    @override
    def __len__(self) -> int:
        return len(self._indices)

    @override
    def load_data(self, index: int) -> eva_tv_tensors.Volume:
        """Loads the CT volume for a given sample.

        Args:
            index: The index of the desired sample.

        Returns:
            Tensor representing the CT volume of shape `[T, C, H, W]`.
        """
        ct_scan_file, _ = self._samples[self._indices[index]]
        return _utils.load_volume_tensor(ct_scan_file)

    @override
    def load_target(self, index: int) -> tv_tensors.Mask:
        """Loads the segmentation mask for a given sample.

        Args:
            index: The index of the desired sample.

        Returns:
            Tensor representing the segmentation mask of shape `[T, C, H, W]`.
        """
        ct_scan_file, mask_file = self._samples[self._indices[index]]
        return _utils.load_mask_tensor(mask_file, ct_scan_file)

    def _apply_transforms(
        self, ct_scan: eva_tv_tensors.Volume, mask: tv_tensors.Mask
    ) -> Tuple[eva_tv_tensors.Volume, tv_tensors.Mask]:
        """Applies transformations to the provided data.

        Args:
            ct_scan: The CT volume tensor.
            mask: The segmentation mask tensor.

        Returns:
            A tuple containing the transformed CT and mask tensors.
        """
        return self._transforms(ct_scan, mask) if self._transforms else (ct_scan, mask)

    def _find_samples(self) -> Dict[int, Tuple[str, str]]:
        """Retrieves the file paths for the CT volumes and segmentation.

        Returns:
            A dictionary mapping file IDs to tuples of volume and segmentation file paths.
        """

        def filename_id(filename: str) -> int:
            """Extracts the trailing numeric ID from a dataset filename."""
            matches = re.match(r".*(?:\D|^)(\d+)", filename)
            if matches is None:
                # Fix: interpolate the offending filename into the message
                # (the f-string previously had no placeholder at all).
                raise ValueError(f"Filename '{filename}' is not valid.")
            return int(matches.group(1))

        volume_files_pattern = os.path.join(self._root, "**", "volume-*.nii")
        volume_filenames = glob.glob(volume_files_pattern, recursive=True)
        volume_ids = {filename_id(filename): filename for filename in volume_filenames}

        segmentation_files_pattern = os.path.join(self._root, "**", "segmentation-*.nii")
        segmentation_filenames = glob.glob(segmentation_files_pattern, recursive=True)
        segmentation_ids = {filename_id(filename): filename for filename in segmentation_filenames}

        # Keep only IDs for which both the volume and its segmentation exist.
        return {
            file_id: (volume_ids[file_id], segmentation_ids[file_id])
            for file_id in sorted(volume_ids.keys() & segmentation_ids.keys())
        }

    def _make_indices(self) -> List[int]:
        """Builds the dataset indices for the specified split."""
        index_ranges = self._split_index_ranges.get(self._split)
        if index_ranges is None:
            raise ValueError("Invalid data split. Use 'train', 'val' or `None`.")

        return _data_utils.ranges_to_indices(index_ranges)
@@ -0,0 +1 @@
1
+ """Dataset Metadata."""
@@ -0,0 +1,287 @@
1
"""File IDs for the MSDTask7Pancreas dataset splits."""

# Case IDs of the MSD Task07 Pancreas scans assigned to the training split.
train_ids = [
    1, 4, 5, 6, 10, 15, 16, 18, 21, 24,
    28, 29, 37, 41, 42, 46, 48, 49, 50, 51,
    52, 55, 56, 58, 61, 64, 66, 67, 70, 71,
    74, 75, 77, 78, 80, 81, 83, 84, 86, 89,
    91, 92, 93, 95, 98, 99, 103, 104, 105, 106,
    107, 109, 110, 111, 113, 114, 117, 119, 122, 124,
    126, 127, 129, 130, 137, 138, 140, 145, 147, 148,
    149, 157, 158, 159, 160, 165, 166, 167, 169, 170,
    172, 173, 175, 178, 179, 180, 181, 182, 186, 187,
    191, 193, 194, 196, 197, 200, 201, 203, 204, 207,
    209, 210, 211, 212, 213, 214, 215, 217, 218, 219,
    222, 224, 225, 226, 227, 228, 229, 230, 231, 234,
    235, 236, 239, 241, 242, 243, 244, 246, 247, 249,
    253, 254, 255, 256, 258, 259, 261, 262, 264, 265,
    266, 267, 268, 269, 270, 274, 275, 276, 277, 278,
    279, 280, 283, 284, 285, 286, 287, 289, 290, 291,
    292, 293, 294, 295, 296, 297, 298, 299, 300, 301,
    302, 303, 304, 305, 308, 309, 310, 311, 312, 313,
    315, 316, 318, 320, 321, 323, 325, 326, 327, 328,
    329, 330, 331, 333, 334, 336, 339, 342, 343, 344,
    345, 346, 347, 348, 350, 351, 354, 355, 356, 357,
    358, 360, 361, 362, 364, 365, 366, 367, 369, 370,
    372, 374, 375, 376, 377, 378, 379, 380, 382, 385,
    386, 387, 388, 389, 391, 392, 393, 395, 398, 399,
    400, 401, 402, 404, 405, 406, 409, 410, 411, 412,
    413, 414, 415, 416, 418, 419, 421,
]

# Case IDs of the MSD Task07 Pancreas scans assigned to the validation split.
val_ids = [
    12, 19, 25, 32, 35, 40, 43, 45, 69, 87,
    88, 94, 96, 100, 101, 102, 120, 125, 131, 135,
    155, 183, 198, 199,
]