kaiko-eva 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this release of kaiko-eva has been flagged as potentially problematic; see the advisory on the registry page for details.

@@ -102,6 +102,9 @@ class Interface:
             model: The model module to use but not modify.
             data: The data module containing validation data.
         """
+        if getattr(data.datasets, "val", None) is None:
+            raise ValueError("The provided data module does not contain a validation dataset.")
+
         eva_trainer.run_evaluation_session(
             base_trainer=trainer,
             base_model=model,
@@ -110,3 +113,22 @@ class Interface:
             n_runs=trainer.n_runs,
             verbose=trainer.n_runs > 1,
         )
+
+    def test(
+        self,
+        trainer: eva_trainer.Trainer,
+        model: modules.ModelModule,
+        data: datamodules.DataModule,
+    ) -> None:
+        """Same as validate, but runs the test stage."""
+        if getattr(data.datasets, "test", None) is None:
+            raise ValueError("The provided data module does not contain a test dataset.")
+
+        eva_trainer.run_evaluation_session(
+            base_trainer=trainer,
+            base_model=model,
+            datamodule=data,
+            stages=["test"],
+            n_runs=trainer.n_runs,
+            verbose=trainer.n_runs > 1,
+        )
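
Both entry points now fail fast when the required split is missing from the data module. A minimal usage sketch based only on the signatures above; the trainer, model, and data objects are assumed to be pre-configured eva components, which this diff does not show:

    interface = Interface()
    interface.validate(trainer=trainer, model=model, data=data)  # requires data.datasets.val
    interface.test(trainer=trainer, model=model, data=data)      # requires data.datasets.test
    # Either call raises ValueError up front if the corresponding dataset split is None.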
@@ -86,10 +86,13 @@ class SessionRecorder:

     def update(
         self,
-        validation_scores: _EVALUATE_OUTPUT,
+        validation_scores: _EVALUATE_OUTPUT | None = None,
         test_scores: _EVALUATE_OUTPUT | None = None,
     ) -> None:
         """Updates the state of the tracked metrics in-place."""
+        if validation_scores is None and test_scores is None:
+            raise ValueError("At least one of validation_scores or test_scores must be provided.")
+
         self._update_validation_metrics(validation_scores)
         self._update_test_metrics(test_scores)

@@ -117,9 +120,10 @@ class SessionRecorder:
         self._validation_metrics = []
         self._test_metrics = []

-    def _update_validation_metrics(self, metrics: _EVALUATE_OUTPUT) -> None:
+    def _update_validation_metrics(self, metrics: _EVALUATE_OUTPUT | None) -> None:
         """Updates the validation metrics in-place."""
-        self._validation_metrics = _update_session_metrics(self._validation_metrics, metrics)
+        if metrics:
+            self._validation_metrics = _update_session_metrics(self._validation_metrics, metrics)

     def _update_test_metrics(self, metrics: _EVALUATE_OUTPUT | None) -> None:
         """Updates the test metrics in-place."""
@@ -47,7 +47,7 @@ def run_evaluation_session(
             stages=stages,
             verbose=not verbose,
         )
-        if validation_scores:
+        if validation_scores or test_scores:
             recorder.update(validation_scores, test_scores)
     recorder.save()

@@ -89,7 +89,7 @@ def run_evaluation(

     if "fit" in stages:
         trainer.fit(model, datamodule=datamodule)
-    if "validate" in stages:
+    if "validate" in stages and getattr(datamodule.datasets, "val", None) is not None:
         validation_scores = trainer.validate(
             model=model,
             datamodule=datamodule,
@@ -22,7 +22,6 @@ from eva.vision.data.datasets.segmentation import (
     LiTS17,
     MoNuSAC,
     MSDTask7Pancreas,
-    TotalSegmentator2D,
 )
 from eva.vision.data.datasets.vision import VisionDataset
 from eva.vision.data.datasets.wsi import MultiWsiDataset, WsiDataset
@@ -40,7 +39,6 @@ __all__ = [
     "PANDASmall",
     "Camelyon16",
     "PatchCamelyon",
-    "TotalSegmentator2D",
     "UniToPatho",
     "WsiClassificationDataset",
     "CoNSeP",
@@ -7,7 +7,6 @@ from eva.vision.data.datasets.segmentation.embeddings import EmbeddingsSegmentat
 from eva.vision.data.datasets.segmentation.lits17 import LiTS17
 from eva.vision.data.datasets.segmentation.monusac import MoNuSAC
 from eva.vision.data.datasets.segmentation.msd_task7_pancreas import MSDTask7Pancreas
-from eva.vision.data.datasets.segmentation.total_segmentator_2d import TotalSegmentator2D

 __all__ = [
     "BCSS",
@@ -17,5 +16,4 @@ __all__ = [
     "LiTS17",
     "MSDTask7Pancreas",
     "MoNuSAC",
-    "TotalSegmentator2D",
 ]
@@ -9,6 +9,7 @@ from eva.vision.models.networks.decoders.segmentation.semantic import (
     ConvDecoderWithImage,
     SingleLinearDecoder,
     SwinUNETRDecoder,
+    SwinUNETRDecoderWithProjection,
 )

 __all__ = [
@@ -20,4 +21,5 @@ __all__ = [
     "LinearDecoder",
     "SingleLinearDecoder",
     "SwinUNETRDecoder",
+    "SwinUNETRDecoderWithProjection",
 ]
@@ -5,7 +5,10 @@ from eva.vision.models.networks.decoders.segmentation.semantic.common import (
     ConvDecoderMS,
     SingleLinearDecoder,
 )
-from eva.vision.models.networks.decoders.segmentation.semantic.swin_unetr import SwinUNETRDecoder
+from eva.vision.models.networks.decoders.segmentation.semantic.swin_unetr import (
+    SwinUNETRDecoder,
+    SwinUNETRDecoderWithProjection,
+)
 from eva.vision.models.networks.decoders.segmentation.semantic.with_image import (
     ConvDecoderWithImage,
 )
@@ -16,4 +19,5 @@ __all__ = [
     "ConvDecoderWithImage",
     "SingleLinearDecoder",
     "SwinUNETRDecoder",
+    "SwinUNETRDecoderWithProjection",
 ]
@@ -102,3 +102,154 @@ class SwinUNETRDecoder(nn.Module):
             (batch_size, n_classes, image_height, image_width).
         """
         return self._forward_features(features)
+
+
+class SwinUNETRDecoderWithProjection(nn.Module):
+    """Swin transformer decoder based on UNETR [0].
+
+    This implementation adds additional projection layers to reduce
+    the number of channels in the feature maps before applying the upscaling
+    convolutional blocks. This reduces the number of trainable parameters
+    significantly and is useful when scaling up the encoder architecture.
+
+    - [0] UNETR: Transformers for 3D Medical Image Segmentation
+        https://arxiv.org/pdf/2103.10504
+    """
+
+    def __init__(
+        self,
+        out_channels: int,
+        feature_size: int = 48,
+        spatial_dims: int = 3,
+        project_dims: list[int] | None = None,
+        checkpoint_path: str | None = None,
+    ) -> None:
+        """Builds the decoder.
+
+        Args:
+            out_channels: Number of output channels.
+            feature_size: Dimension of network feature size.
+            spatial_dims: Number of spatial dimensions.
+            project_dims: List of 6 dimensions to project encoder features to.
+                If None, uses default channel progression based on feature_size.
+                This is not part of the original implementation, but helps
+                to reduce the number of decoder parameters when scaling up
+                the encoder architecture (feature_size).
+            checkpoint_path: Path to the checkpoint file.
+        """
+        super().__init__()
+
+        self._checkpoint_path = checkpoint_path
+        self._project_dims = project_dims
+
+        if project_dims is not None and len(project_dims) != 6:
+            raise ValueError(
+                f"project_dims must have exactly 6 dimensions, got {len(project_dims)}"
+            )
+
+        channel_dims = project_dims or [
+            feature_size,
+            feature_size,
+            feature_size * 2,
+            feature_size * 4,
+            feature_size * 8,
+            feature_size * 16,
+        ]
+
+        self.decoder5 = unetr_block.UnetrUpBlock(
+            spatial_dims=spatial_dims,
+            in_channels=channel_dims[5],
+            out_channels=channel_dims[4],
+            kernel_size=3,
+            upsample_kernel_size=2,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.decoder4 = unetr_block.UnetrUpBlock(
+            spatial_dims=spatial_dims,
+            in_channels=channel_dims[4],
+            out_channels=channel_dims[3],
+            kernel_size=3,
+            upsample_kernel_size=2,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.decoder3 = unetr_block.UnetrUpBlock(
+            spatial_dims=spatial_dims,
+            in_channels=channel_dims[3],
+            out_channels=channel_dims[2],
+            kernel_size=3,
+            upsample_kernel_size=2,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.decoder2 = unetr_block.UnetrUpBlock(
+            spatial_dims=spatial_dims,
+            in_channels=channel_dims[2],
+            out_channels=channel_dims[1],
+            kernel_size=3,
+            upsample_kernel_size=2,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.decoder1 = unetr_block.UnetrUpBlock(
+            spatial_dims=spatial_dims,
+            in_channels=channel_dims[1],
+            out_channels=channel_dims[0],
+            kernel_size=3,
+            upsample_kernel_size=2,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.out = dynunet_block.UnetOutBlock(
+            spatial_dims=spatial_dims,
+            in_channels=channel_dims[0],
+            out_channels=out_channels,
+        )
+
+        if self._project_dims:
+            conv_layer = nn.Conv2d if spatial_dims == 2 else nn.Conv3d
+            self.proj_enc0 = conv_layer(feature_size, self._project_dims[0], kernel_size=1)
+            self.proj_enc1 = conv_layer(feature_size, self._project_dims[1], kernel_size=1)
+            self.proj_enc2 = conv_layer(feature_size * 2, self._project_dims[2], kernel_size=1)
+            self.proj_enc3 = conv_layer(feature_size * 4, self._project_dims[3], kernel_size=1)
+            self.proj_hid3 = conv_layer(feature_size * 8, self._project_dims[4], kernel_size=1)
+            self.proj_dec4 = conv_layer(feature_size * 16, self._project_dims[5], kernel_size=1)
+
+    def _forward_features(self, features: List[torch.Tensor]) -> torch.Tensor:
+        """Forward function for multi-level feature maps to a single one."""
+        enc0, enc1, enc2, enc3, hid3, dec4 = self._project_features(features)
+        dec3 = self.decoder5(dec4, hid3)
+        dec2 = self.decoder4(dec3, enc3)
+        dec1 = self.decoder3(dec2, enc2)
+        dec0 = self.decoder2(dec1, enc1)
+        out = self.decoder1(dec0, enc0)
+        return self.out(out)
+
+    def _project_features(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
+        """Projects features using 1x1 to reduce number of channels."""
+        return (
+            [
+                self.proj_enc0(features[0]),
+                self.proj_enc1(features[1]),
+                self.proj_enc2(features[2]),
+                self.proj_enc3(features[3]),
+                self.proj_hid3(features[4]),
+                self.proj_dec4(features[5]),
+            ]
+            if self._project_dims
+            else features
+        )
+
+    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
+        """Maps the patch embeddings to a segmentation mask.
+
+        Args:
+            features: List of multi-level intermediate features from
+                :class:`SwinUNETREncoder`.
+
+        Returns:
+            Tensor containing scores for all of the classes with shape
+            (batch_size, n_classes, image_height, image_width).
+        """
+        return self._forward_features(features)
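
A minimal smoke-test sketch for the new decoder. The six feature-map shapes below are assumptions derived from the default channel progression documented above (feature_size=48, spatial_dims=2), not values taken from this diff:

    import torch
    from eva.vision.models.networks.decoders import segmentation

    decoder = segmentation.SwinUNETRDecoderWithProjection(
        out_channels=5,
        feature_size=48,
        spatial_dims=2,
        project_dims=[24, 24, 48, 96, 192, 384],  # halves each stage's channels
    )
    features = [
        torch.randn(1, 48, 128, 128),   # enc0
        torch.randn(1, 48, 64, 64),     # enc1
        torch.randn(1, 96, 32, 32),     # enc2
        torch.randn(1, 192, 16, 16),    # enc3
        torch.randn(1, 384, 8, 8),      # hid3
        torch.randn(1, 768, 4, 4),      # dec4
    ]
    logits = decoder(features)  # shape: (1, 5, 128, 128)

Each 1x1 projection shrinks a skip connection before the corresponding UnetrUpBlock, so the up-convolutions operate on the reduced channel counts; with project_dims=None the class behaves like the plain SwinUNETRDecoder channel layout.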
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: kaiko-eva
-Version: 0.3.1
+Version: 0.3.2
 Summary: Evaluation Framework for oncology foundation models.
 Keywords: machine-learning,evaluation-framework,oncology,foundation-models
 Author-Email: Ioannis Gatopoulos <ioannis@kaiko.ai>, =?utf-8?q?Nicolas_K=C3=A4nzig?= <nicolas@kaiko.ai>, Roman Moser <roman@kaiko.ai>
@@ -48,7 +48,7 @@ eva/core/data/transforms/padding/pad_2d_tensor.py,sha256=J4maGFmeQf9IHRxt5kU-6eI
 eva/core/data/transforms/sampling/__init__.py,sha256=BFKbvRjlZrwS0GcNrM54ZSWt6PrQARfFlXM1jJ-wpvo,149
 eva/core/data/transforms/sampling/sample_from_axis.py,sha256=Zbhp94lVa70WQKmSOKMTsOMe2c7wLqNZto7JqWhSdtI,1229
 eva/core/interface/__init__.py,sha256=chdpKXipxe1NP-Fgr_d9r6X1gMna0XiEa38waJ6FzTM,98
-eva/core/interface/interface.py,sha256=VNagHKsr2T7Ufm1rpA0NCnMi6F2qpKEmMBq_3eGZIRE,3826
+eva/core/interface/interface.py,sha256=EeBrIqUlO497BOZGZrtO-fZnc_BhIJrrqyQmRfqWXcw,4591
 eva/core/loggers/__init__.py,sha256=4YMLNlN9LnuKqhBI1R1keh69dmMD-2lcH3HKwwyn380,266
 eva/core/loggers/dummy.py,sha256=Y7ypH0ecSAIkkZ5LzTmNNEzlKkqeaHfUNMCDKVOg6D4,1204
 eva/core/loggers/experimental_loggers.py,sha256=p5uCK_9QCYufRhE-LZQUJWbhGElyobX_zRM78yX4p2o,230
@@ -99,9 +99,9 @@ eva/core/models/wrappers/huggingface.py,sha256=-_fA81YRnoMc7O7SbrnCEj1dM_xArpQ8W
 eva/core/models/wrappers/onnx.py,sha256=34li_xSwPryN8nJDrFyif_Hve1AEH7Ry9E_lZmf7JJM,1834
 eva/core/trainers/__init__.py,sha256=jhsKJF7HAae7EOiG3gKIAHH_h3dZlTE2JRcCHJmOzJc,208
 eva/core/trainers/_logging.py,sha256=gi4FqPy2GuVmh0WZY6mYwF7zMPvnoFA050B0XdCP6PU,2571
-eva/core/trainers/_recorder.py,sha256=uD17l_WVveFuWuann59VU9iJ-Jumdh9F6vnAcL3M_FU,7855
+eva/core/trainers/_recorder.py,sha256=M-BJHLgqGxR_MSV6f_WC7GN2JHYEEinV1-hNLpH667A,8062
 eva/core/trainers/_utils.py,sha256=M3h8lVhUmkeSiEXpX9hRdMvThGFCnTP15gv-hd1CZkc,321
-eva/core/trainers/functional.py,sha256=tsBfpXjEQ8BiBJ9wZWp0AUUOOxy7UUrLX4GSjQZTeCs,4510
+eva/core/trainers/functional.py,sha256=_Mw-NIPU2tPffxpK5t3sHBmVI6u163phCpoJFiauH7E,4583
 eva/core/trainers/trainer.py,sha256=a3OwLWOZKDqxayrd0ugUmxJKyQx6XDb4GHtdL8-AEV0,4826
 eva/core/utils/__init__.py,sha256=cndVBvtYxEW7hykH39GCNVI86zkXNn8Lw2A0sUJHS04,237
 eva/core/utils/clone.py,sha256=qcThZOuAs1cs0uV3BL5eKeM2VIBjuRPBe1t-NiUFM5Y,569
@@ -146,7 +146,7 @@ eva/vision/data/dataloaders/__init__.py,sha256=9ykBD4vyZ-Yv3IEnqvVcSMURS-gXWjOun
 eva/vision/data/dataloaders/collate_fn/__init__.py,sha256=GCvJaeILmAc_-lhGw8yzj2cC2KG4i1PvSWAyVzPKvVo,146
 eva/vision/data/dataloaders/collate_fn/collection.py,sha256=45s9fKjVBnqfnuGWmJZMtt_DDGnfuf7qkWe0QmxXMKo,611
 eva/vision/data/dataloaders/worker_init.py,sha256=lFWywHGCC4QxHeDXrneF8DQ45XG3WmVltEELJrPyLz0,1182
-eva/vision/data/datasets/__init__.py,sha256=s3h4w71LiM6dT6AYWzCG2-nexkSuuTWixw4KrCGAhS8,1026
+eva/vision/data/datasets/__init__.py,sha256=_04LqKv46oUXdmQAlSmWkgZYueHFwM-0iiOSMuFnFDw,976
 eva/vision/data/datasets/_utils.py,sha256=epPcaYE4w2_LtUKLLQJh6qQxUNVBe22JA06k4WUerYQ,1430
 eva/vision/data/datasets/_validators.py,sha256=77WZj8ewsuxUjW5WegJ-7zDuR6WdF5JbaOYdywhKIK4,2594
 eva/vision/data/datasets/classification/__init__.py,sha256=5fOGZxKGPeMCf3Jd9qAOYADPrkZnYg97_QE4DC79AMI,1074
@@ -161,7 +161,7 @@ eva/vision/data/datasets/classification/panda.py,sha256=HVfCvByyajdo5o_waqTpzZWC
 eva/vision/data/datasets/classification/patch_camelyon.py,sha256=1yXkfP680qxkQUFAPKRFbZv0cHAFx23s2vvT9th2nKM,7149
 eva/vision/data/datasets/classification/unitopatho.py,sha256=IO3msEsuOnmdcYZxF-eBpo0K97y54rWFmCb_KxuF4bk,5129
 eva/vision/data/datasets/classification/wsi.py,sha256=YMGxU8ECjudizt_uXUevuPS8k66HxtEQ7M2IZJmL6kE,4079
-eva/vision/data/datasets/segmentation/__init__.py,sha256=y_BjUj6kF-WeouSz0CCpPdOdX7n5hUrqsZGF68Xu9Hw,784
+eva/vision/data/datasets/segmentation/__init__.py,sha256=f0q9tzk4ahaZfrw_SgIE_puk_D7qmkSCKX1FP9aJITU,668
 eva/vision/data/datasets/segmentation/_utils.py,sha256=aXUHrnbefP6-OgSvDQHqssFKhUwETul_8aosqYiOfm8,3065
 eva/vision/data/datasets/segmentation/bcss.py,sha256=rqk6VqK0QCHLFnMnDuHd1JPJVK5_C6WnsmnNSKBw6Uo,8230
 eva/vision/data/datasets/segmentation/btcv.py,sha256=9rlEqGyb2SGJBY6Oj42FlHajQF8csf1Jq6jeuPSsfXI,8396
@@ -170,10 +170,8 @@ eva/vision/data/datasets/segmentation/embeddings.py,sha256=RsTuAwGEJPnWPY7q3pwcj
 eva/vision/data/datasets/segmentation/lits17.py,sha256=kcSCKxsgtUuCD1YEYvrb_L_BgOtZC8xDq1lX8ldSZc4,7635
 eva/vision/data/datasets/segmentation/metadata/__init__.py,sha256=o9Od0v6N9dNdf8hfefn2QaNNCD2sZMvc2K58zHA_Nrg,24
 eva/vision/data/datasets/segmentation/metadata/_msd_task7_pancreas.py,sha256=O2-ye0A7wIjcI_D857uvpYw-jckTqfhBUrhinqSNWq0,2553
-eva/vision/data/datasets/segmentation/metadata/_total_segmentator.py,sha256=DTaQaAisY7j1h0-zYk1_81Sr4b3D9PTMieYX0PMPtIc,3127
 eva/vision/data/datasets/segmentation/monusac.py,sha256=iv9-MFaTsGfGV1u6_lQNcSEeSpmVBDQC1Oa123iEtu0,8410
 eva/vision/data/datasets/segmentation/msd_task7_pancreas.py,sha256=dTsPD73PAP15VOXdHnX4eQqbpz2jGpCB31YISzinUd4,8964
-eva/vision/data/datasets/segmentation/total_segmentator_2d.py,sha256=TGz67AGuv8_Bm5DM5TyCtzRTuGXOuctZZNxdQtBxF1g,16987
 eva/vision/data/datasets/structs.py,sha256=RaTDW-B36PumcR5gymhCiX-r8GiKqIFcjqoEEjjFyUE,389
 eva/vision/data/datasets/vision.py,sha256=-_WRiyICMgqABR6Ay_RKBMfsPGwgx9MQfCA7WChHo24,3219
 eva/vision/data/datasets/wsi.py,sha256=dEAT_Si_Qb3qdSovUPeoiWeoPb7m-NGYqq44e3UXHk8,8384
@@ -256,13 +254,13 @@ eva/vision/models/networks/backbones/timm/backbones.py,sha256=ZbF9MMiL4Ylyy79XLe
 eva/vision/models/networks/backbones/universal/__init__.py,sha256=xgn3crSqlmUPYz-t2CR1zDKxhlyAEeApA-a6Y_eWQvc,417
 eva/vision/models/networks/backbones/universal/vit.py,sha256=To0OzwpuX5Y5PwjGidwV0Ssq3xa81dve081buwG_Ofg,3658
 eva/vision/models/networks/decoders/__init__.py,sha256=RXFWmoYw2i6E9VOUCJmU8c72icHannVuo-cUKy6fnLM,200
-eva/vision/models/networks/decoders/segmentation/__init__.py,sha256=SqmxtzxwBRF8g2hsiqe0o3Nr0HFK97azTnWLyqsYigY,652
+eva/vision/models/networks/decoders/segmentation/__init__.py,sha256=yVrRo2OisNRAlxDjWJGwipKA9HGeqRXd1ZL88eltoy4,726
 eva/vision/models/networks/decoders/segmentation/base.py,sha256=b2TIJKiJR9vejVRpNyedMJLPTrpHhAEXvco8atb9TPU,411
 eva/vision/models/networks/decoders/segmentation/decoder2d.py,sha256=HRonYTSriiq13aZCSNiYUc484qfOhkVT0yFiMW06CDc,4472
 eva/vision/models/networks/decoders/segmentation/linear.py,sha256=ui3-Y0rl4VEF75-sUghaF29P9wpxCVlp5iR_Ym-utUE,4666
-eva/vision/models/networks/decoders/segmentation/semantic/__init__.py,sha256=2yol7W1ARXL-Ge7gYxjUzaGTjH6nfMBlNqQJHprEWGg,539
+eva/vision/models/networks/decoders/segmentation/semantic/__init__.py,sha256=9QnepLMzQVE-wAZJXx0napVutg1HtkbDERcPsoevWGg,622
 eva/vision/models/networks/decoders/segmentation/semantic/common.py,sha256=FSf_eI-FaBroxPRJd4TiV97RCreauJh1IznIVzBT2eg,2528
-eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py,sha256=ODUpnJrpDQl0m8CC2SPnE_lpFflzS0GSiCZOmrjL6uQ,3373
+eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py,sha256=eSFvHng2lrc-Wd4g9CW4z8-yfKndbl0c7-sKhOautBU,9170
 eva/vision/models/networks/decoders/segmentation/semantic/with_image.py,sha256=I5PyGKKo8DcXYcw4xlCFzuavRJNRrzGT-szpDidMPXI,3516
 eva/vision/models/networks/decoders/segmentation/typings.py,sha256=rY4CXp0MNF16SHnx9TgGjXI_r8bVGSqAWdR835hXndg,537
 eva/vision/models/wrappers/__init__.py,sha256=ogmr-eeVuGaOCcsuxSp6PGyauP2QqWTb8dGTtbC7lRU,210
@@ -277,8 +275,8 @@ eva/vision/utils/io/image.py,sha256=IdOkr5MYqhYHz8U9drZ7wULTM3YHwCWSjZlu_Qdl4GQ,
 eva/vision/utils/io/mat.py,sha256=qpGifyjmpE0Xhv567Si7-zxKrgkgE0sywP70cHiLFGU,808
 eva/vision/utils/io/nifti.py,sha256=TFMgNhLqIK3sl3RjIRXEABM7FmSQjqVOwk1vXkuvX2w,4983
 eva/vision/utils/io/text.py,sha256=qYgfo_ZaDZWfG02NkVVYzo5QFySqdCCz5uLA9d-zXtI,701
-kaiko_eva-0.3.1.dist-info/METADATA,sha256=gXYGvp6Ap95944atE7L9Dxk8AnmuVhn22sHAC2iIl_g,25704
-kaiko_eva-0.3.1.dist-info/WHEEL,sha256=9P2ygRxDrTJz3gsagc0Z96ukrxjr-LFBGOgv3AuKlCA,90
-kaiko_eva-0.3.1.dist-info/entry_points.txt,sha256=6CSLu9bmQYJSXEg8gbOzRhxH0AGs75BB-vPm3VvfcNE,88
-kaiko_eva-0.3.1.dist-info/licenses/LICENSE,sha256=e6AEzr7j_R-PYr2qLO-JwLn8y70jbVD3U2mxbRmwcI4,11338
-kaiko_eva-0.3.1.dist-info/RECORD,,
+kaiko_eva-0.3.2.dist-info/METADATA,sha256=3OdB75bdgEKDkAhIh75c3WcpevnOsemYcMfBEG0MKy8,25704
+kaiko_eva-0.3.2.dist-info/WHEEL,sha256=9P2ygRxDrTJz3gsagc0Z96ukrxjr-LFBGOgv3AuKlCA,90
+kaiko_eva-0.3.2.dist-info/entry_points.txt,sha256=6CSLu9bmQYJSXEg8gbOzRhxH0AGs75BB-vPm3VvfcNE,88
+kaiko_eva-0.3.2.dist-info/licenses/LICENSE,sha256=e6AEzr7j_R-PYr2qLO-JwLn8y70jbVD3U2mxbRmwcI4,11338
+kaiko_eva-0.3.2.dist-info/RECORD,,
@@ -1,91 +0,0 @@
-"""Utils for TotalSegmentator dataset classes."""
-
-from typing import Dict
-
-reduced_class_mappings: Dict[str, str] = {
-    # Abdominal Organs
-    "spleen": "spleen",
-    "kidney_right": "kidney",
-    "kidney_left": "kidney",
-    "gallbladder": "gallbladder",
-    "liver": "liver",
-    "stomach": "stomach",
-    "pancreas": "pancreas",
-    "small_bowel": "small_bowel",
-    "duodenum": "duodenum",
-    "colon": "colon",
-    # Endocrine System
-    "adrenal_gland_right": "adrenal_gland",
-    "adrenal_gland_left": "adrenal_gland",
-    "thyroid_gland": "thyroid_gland",
-    # Respiratory System
-    "lung_upper_lobe_left": "lungs",
-    "lung_lower_lobe_left": "lungs",
-    "lung_upper_lobe_right": "lungs",
-    "lung_middle_lobe_right": "lungs",
-    "lung_lower_lobe_right": "lungs",
-    "trachea": "trachea",
-    "esophagus": "esophagus",
-    # Urogenital System
-    "urinary_bladder": "urogenital_system",
-    "prostate": "urogenital_system",
-    "kidney_cyst_left": "kidney_cyst",
-    "kidney_cyst_right": "kidney_cyst",
-    # Vertebral Column
-    **{f"vertebrae_{v}": "vertebrae" for v in ["C1", "C2", "C3", "C4", "C5", "C6", "C7"]},
-    **{f"vertebrae_{v}": "vertebrae" for v in [f"T{i}" for i in range(1, 13)]},
-    **{f"vertebrae_{v}": "vertebrae" for v in [f"L{i}" for i in range(1, 6)]},
-    "vertebrae_S1": "vertebrae",
-    "sacrum": "sacral_spine",
-    # Cardiovascular System
-    "heart": "heart",
-    "aorta": "aorta",
-    "pulmonary_vein": "veins",
-    "brachiocephalic_trunk": "arteries",
-    "subclavian_artery_right": "arteries",
-    "subclavian_artery_left": "arteries",
-    "common_carotid_artery_right": "arteries",
-    "common_carotid_artery_left": "arteries",
-    "brachiocephalic_vein_left": "veins",
-    "brachiocephalic_vein_right": "veins",
-    "atrial_appendage_left": "atrial_appendage",
-    "superior_vena_cava": "veins",
-    "inferior_vena_cava": "veins",
-    "portal_vein_and_splenic_vein": "veins",
-    "iliac_artery_left": "arteries",
-    "iliac_artery_right": "arteries",
-    "iliac_vena_left": "veins",
-    "iliac_vena_right": "veins",
-    # Upper Extremity Bones
-    "humerus_left": "humerus",
-    "humerus_right": "humerus",
-    "scapula_left": "scapula",
-    "scapula_right": "scapula",
-    "clavicula_left": "clavicula",
-    "clavicula_right": "clavicula",
-    # Lower Extremity Bones
-    "femur_left": "femur",
-    "femur_right": "femur",
-    "hip_left": "hip",
-    "hip_right": "hip",
-    # Muscles
-    "gluteus_maximus_left": "gluteus",
-    "gluteus_maximus_right": "gluteus",
-    "gluteus_medius_left": "gluteus",
-    "gluteus_medius_right": "gluteus",
-    "gluteus_minimus_left": "gluteus",
-    "gluteus_minimus_right": "gluteus",
-    "autochthon_left": "autochthon",
-    "autochthon_right": "autochthon",
-    "iliopsoas_left": "iliopsoas",
-    "iliopsoas_right": "iliopsoas",
-    # Central Nervous System
-    "brain": "brain",
-    "spinal_cord": "spinal_cord",
-    # Skull and Thoracic Cage
-    "skull": "skull",
-    **{f"rib_left_{i}": "ribs" for i in range(1, 13)},
-    **{f"rib_right_{i}": "ribs" for i in range(1, 13)},
-    "costal_cartilages": "ribs",
-    "sternum": "sternum",
-}
@@ -1,414 +0,0 @@
-"""TotalSegmentator 2D segmentation dataset class."""
-
-import functools
-import hashlib
-import os
-import re
-from glob import glob
-from pathlib import Path
-from typing import Any, Callable, Dict, List, Literal, Tuple
-
-import numpy as np
-import numpy.typing as npt
-import torch
-from torchvision import tv_tensors
-from torchvision.datasets import utils
-from typing_extensions import override
-
-from eva.core.utils import io as core_io
-from eva.core.utils import multiprocessing
-from eva.vision.data.datasets import _validators, structs, vision
-from eva.vision.data.datasets.segmentation.metadata import _total_segmentator
-from eva.vision.utils import io
-
-
-class TotalSegmentator2D(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
-    """TotalSegmentator 2D segmentation dataset."""
-
-    _expected_dataset_lengths: Dict[str, int] = {
-        "train_small": 35089,
-        "val_small": 1283,
-        "train_full": 278190,
-        "val_full": 14095,
-        "test_full": 25578,
-    }
-    """Dataset version and split to the expected size."""
-
-    _sample_every_n_slices: int | None = None
-    """The amount of slices to sub-sample per 3D CT scan image."""
-
-    _resources_full: List[structs.DownloadResource] = [
-        structs.DownloadResource(
-            filename="Totalsegmentator_dataset_v201.zip",
-            url="https://zenodo.org/records/10047292/files/Totalsegmentator_dataset_v201.zip",
-            md5="fe250e5718e0a3b5df4c4ea9d58a62fe",
-        ),
-    ]
-    """Resources for the full dataset version."""
-
-    _resources_small: List[structs.DownloadResource] = [
-        structs.DownloadResource(
-            filename="Totalsegmentator_dataset_small_v201.zip",
-            url="https://zenodo.org/records/10047263/files/Totalsegmentator_dataset_small_v201.zip",
-            md5="6b5524af4b15e6ba06ef2d700c0c73e0",
-        ),
-    ]
-    """Resources for the small dataset version."""
-
-    _license: str = (
-        "Creative Commons Attribution 4.0 International "
-        "(https://creativecommons.org/licenses/by/4.0/deed.en)"
-    )
-    """Dataset license."""
-
-    def __init__(
-        self,
-        root: str,
-        split: Literal["train", "val", "test"] | None,
-        version: Literal["small", "full"] | None = "full",
-        download: bool = False,
-        classes: List[str] | None = None,
-        class_mappings: Dict[str, str] | None = _total_segmentator.reduced_class_mappings,
-        optimize_mask_loading: bool = True,
-        decompress: bool = True,
-        num_workers: int = 10,
-        transforms: Callable | None = None,
-    ) -> None:
-        """Initialize dataset.
-
-        Args:
-            root: Path to the root directory of the dataset. The dataset will
-                be downloaded and extracted here, if it does not already exist.
-            split: Dataset split to use. If `None`, the entire dataset is used.
-            version: The version of the dataset to initialize. If `None`, it will
-                use the files located at root as is and wont perform any checks.
-            download: Whether to download the data for the specified split.
-                Note that the download will be executed only by additionally
-                calling the :meth:`prepare_data` method and if the data does not
-                exist yet on disk.
-            classes: Whether to configure the dataset with a subset of classes.
-                If `None`, it will use all of them.
-            class_mappings: A dictionary that maps the original class names to a
-                reduced set of classes. If `None`, it will use the original classes.
-            optimize_mask_loading: Whether to pre-process the segmentation masks
-                in order to optimize the loading time. In the `setup` method, it
-                will reformat the binary one-hot masks to a semantic mask and store
-                it on disk.
-            decompress: Whether to decompress the ct.nii.gz files when preparing the data.
-                The label masks won't be decompressed, but when enabling optimize_mask_loading
-                it will export the semantic label masks to a single file in uncompressed .nii
-                format.
-            num_workers: The number of workers to use for optimizing the masks &
-                decompressing the .gz files.
-            transforms: A function/transforms that takes in an image and a target
-                mask and returns the transformed versions of both.
-
-        """
-        super().__init__(transforms=transforms)
-
-        self._root = root
-        self._split = split
-        self._version = version
-        self._download = download
-        self._classes = classes
-        self._optimize_mask_loading = optimize_mask_loading
-        self._decompress = decompress
-        self._num_workers = num_workers
-        self._class_mappings = class_mappings
-
-        if self._classes and self._class_mappings:
-            raise ValueError("Both 'classes' and 'class_mappings' cannot be set at the same time.")
-
-        self._samples_dirs: List[str] = []
-        self._indices: List[Tuple[int, int]] = []
-
-    @functools.cached_property
-    @override
-    def classes(self) -> List[str]:
-        def get_filename(path: str) -> str:
-            """Returns the filename from the full path."""
-            return os.path.basename(path).split(".")[0]
-
-        first_sample_labels = os.path.join(self._root, "s0011", "segmentations", "*.nii.gz")
-        all_classes = sorted(map(get_filename, glob(first_sample_labels)))
-        if self._classes:
-            is_subset = all(name in all_classes for name in self._classes)
-            if not is_subset:
-                raise ValueError("Provided class names are not subset of the original ones.")
-            classes = sorted(self._classes)
-        elif self._class_mappings:
-            is_subset = all(name in all_classes for name in self._class_mappings.keys())
-            if not is_subset:
-                raise ValueError("Provided class names are not subset of the original ones.")
-            classes = sorted(set(self._class_mappings.values()))
-        else:
-            classes = all_classes
-        return ["background"] + classes
-
-    @property
-    @override
-    def class_to_idx(self) -> Dict[str, int]:
-        return {label: index for index, label in enumerate(self.classes)}
-
-    @property
-    def _file_suffix(self) -> str:
-        return "nii" if self._decompress else "nii.gz"
-
-    @functools.cached_property
-    def _classes_hash(self) -> str:
-        return hashlib.md5(str(self.classes).encode(), usedforsecurity=False).hexdigest()
-
-    @override
-    def filename(self, index: int) -> str:
-        sample_idx, _ = self._indices[index]
-        sample_dir = self._samples_dirs[sample_idx]
-        return os.path.join(sample_dir, f"ct.{self._file_suffix}")
-
-    @override
-    def prepare_data(self) -> None:
-        if self._download:
-            self._download_dataset()
-        if self._decompress:
-            self._decompress_files()
-        self._samples_dirs = self._fetch_samples_dirs()
-        if self._optimize_mask_loading:
-            self._export_semantic_label_masks()
-
-    @override
-    def configure(self) -> None:
-        self._indices = self._create_indices()
-
-    @override
-    def validate(self) -> None:
-        if self._version is None or self._sample_every_n_slices is not None:
-            return
-
-        if self._classes:
-            last_label = self._classes[-1]
-            n_classes = len(self._classes)
-        elif self._class_mappings:
-            classes = sorted(set(self._class_mappings.values()))
-            last_label = classes[-1]
-            n_classes = len(classes)
-        else:
-            last_label = "vertebrae_T9"
-            n_classes = 117
-
-        _validators.check_dataset_integrity(
-            self,
-            length=self._expected_dataset_lengths.get(f"{self._split}_{self._version}", 0),
-            n_classes=n_classes + 1,
-            first_and_last_labels=("background", last_label),
-        )
-
-    @override
-    def __len__(self) -> int:
-        return len(self._indices)
-
-    @override
-    def load_data(self, index: int) -> tv_tensors.Image:
-        sample_index, slice_index = self._indices[index]
-        image_path = self._get_image_path(sample_index)
-        image_nii = io.read_nifti(image_path, slice_index)
-        image_array = io.nifti_to_array(image_nii)
-        image_array = self._fix_orientation(image_array)
-        return tv_tensors.Image(image_array.copy().transpose(2, 0, 1))
-
-    @override
-    def load_target(self, index: int) -> tv_tensors.Mask:
-        if self._optimize_mask_loading:
-            mask = self._load_semantic_label_mask(index)
-        else:
-            mask = self._load_target(index)
-        mask = self._fix_orientation(mask)
-        return tv_tensors.Mask(mask.copy().squeeze(), dtype=torch.int64)  # type: ignore
-
-    @override
-    def load_metadata(self, index: int) -> Dict[str, Any]:
-        _, slice_index = self._indices[index]
-        return {"slice_index": slice_index}
-
-    def _load_target(self, index: int) -> npt.NDArray[Any]:
-        sample_index, slice_index = self._indices[index]
-        return self._load_masks_as_semantic_label(sample_index, slice_index)
-
-    def _load_semantic_label_mask(self, index: int) -> npt.NDArray[Any]:
-        """Loads the segmentation mask from a semantic label NifTi file."""
-        sample_index, slice_index = self._indices[index]
-        nii = io.read_nifti(self._get_optimized_masks_file(sample_index), slice_index)
-        return io.nifti_to_array(nii)
-
-    def _load_masks_as_semantic_label(
-        self, sample_index: int, slice_index: int | None = None
-    ) -> npt.NDArray[Any]:
-        """Loads binary masks as a semantic label mask.
-
-        Args:
-            sample_index: The data sample index.
-            slice_index: Whether to return only a specific slice.
-        """
-        masks_dir = self._get_masks_dir(sample_index)
-        classes = self._class_mappings.keys() if self._class_mappings else self.classes[1:]
-        mask_paths = [os.path.join(masks_dir, f"{label}.nii.gz") for label in classes]
-        binary_masks = [io.nifti_to_array(io.read_nifti(path, slice_index)) for path in mask_paths]
-
-        if self._class_mappings:
-            mapped_binary_masks = [np.zeros_like(binary_masks[0], dtype=np.bool_)] * len(
-                self.classes[1:]
-            )
-            for original_class, mapped_class in self._class_mappings.items():
-                mapped_index = self.class_to_idx[mapped_class] - 1
-                original_index = list(self._class_mappings.keys()).index(original_class)
-                mapped_binary_masks[mapped_index] = np.logical_or(
-                    mapped_binary_masks[mapped_index], binary_masks[original_index]
-                )
-            binary_masks = mapped_binary_masks
-
-        background_mask = np.zeros_like(binary_masks[0])
-        return np.argmax([background_mask] + binary_masks, axis=0)
-
-    def _export_semantic_label_masks(self) -> None:
-        """Exports the segmentation binary masks (one-hot) to semantic labels."""
-        mask_classes_file = os.path.join(f"{self._get_optimized_masks_root()}/classes.txt")
-        if os.path.isfile(mask_classes_file):
-            with open(mask_classes_file, "r") as file:
-                if file.read() != str(self.classes):
-                    raise ValueError(
-                        "Optimized masks hash doesn't match the current classes or mappings."
-                    )
-            return
-
-        total_samples = len(self._samples_dirs)
-        semantic_labels = [
-            (index, self._get_optimized_masks_file(index)) for index in range(total_samples)
-        ]
-        to_export = filter(lambda x: not os.path.isfile(x[1]), semantic_labels)
-
-        def _process_mask(sample_index: Any, filename: str) -> None:
-            semantic_labels = self._load_masks_as_semantic_label(sample_index)
-            os.makedirs(os.path.dirname(filename), exist_ok=True)
-            io.save_array_as_nifti(semantic_labels, filename)
-
-        multiprocessing.run_with_threads(
-            _process_mask,
-            list(to_export),
-            num_workers=self._num_workers,
-            progress_desc=">> Exporting optimized semantic mask",
-            return_results=False,
-        )
-
-        os.makedirs(os.path.dirname(mask_classes_file), exist_ok=True)
-        with open(mask_classes_file, "w") as file:
-            file.write(str(self.classes))
-
-    def _fix_orientation(self, array: npt.NDArray):
-        """Fixes orientation such that table is at the bottom & liver on the left."""
-        array = np.rot90(array)
-        array = np.flip(array, axis=1)
-        return array
-
-    def _get_image_path(self, sample_index: int) -> str:
-        """Returns the corresponding image path."""
-        sample_dir = self._samples_dirs[sample_index]
-        return os.path.join(self._root, sample_dir, f"ct.{self._file_suffix}")
-
-    def _get_masks_dir(self, sample_index: int) -> str:
-        """Returns the directory of the corresponding masks."""
-        sample_dir = self._samples_dirs[sample_index]
-        return os.path.join(self._root, sample_dir, "segmentations")
-
-    def _get_optimized_masks_root(self) -> str:
-        """Returns the directory of the optimized masks."""
-        return os.path.join(self._root, f"processed/masks/{self._classes_hash}")
-
-    def _get_optimized_masks_file(self, sample_index: int) -> str:
-        """Returns the semantic label filename."""
-        return os.path.join(
-            f"{self._get_optimized_masks_root()}/{self._samples_dirs[sample_index]}/masks.nii"
-        )
-
-    def _get_number_of_slices_per_sample(self, sample_index: int) -> int:
-        """Returns the total amount of slices of a sample."""
-        image_path = self._get_image_path(sample_index)
-        image_shape = io.fetch_nifti_shape(image_path)
-        return image_shape[-1]
-
-    def _fetch_samples_dirs(self) -> List[str]:
-        """Returns the name of all the samples of all the splits of the dataset."""
-        sample_filenames = [
-            filename
-            for filename in os.listdir(self._root)
-            if os.path.isdir(os.path.join(self._root, filename)) and re.match(r"^s\d{4}$", filename)
-        ]
-        return sorted(sample_filenames)
-
-    def _get_split_indices(self) -> List[int]:
-        """Returns the samples indices that corresponding the dataset split and version."""
-        metadata_file = os.path.join(self._root, "meta.csv")
-        metadata = io.read_csv(metadata_file, delimiter=";", encoding="utf-8-sig")
-
-        match self._split:
-            case "train":
-                image_ids = [item["image_id"] for item in metadata if item["split"] == "train"]
-            case "val":
-                image_ids = [item["image_id"] for item in metadata if item["split"] == "val"]
-            case "test":
-                image_ids = [item["image_id"] for item in metadata if item["split"] == "test"]
-            case _:
-                image_ids = self._samples_dirs
-
-        return sorted(map(self._samples_dirs.index, image_ids))
-
-    def _create_indices(self) -> List[Tuple[int, int]]:
-        """Builds the dataset indices for the specified split.
-
-        Returns:
-            A list of tuples, where the first value indicates the
-            sample index which the second its corresponding slice
-            index.
-        """
-        indices = [
-            (sample_idx, slide_idx)
-            for sample_idx in self._get_split_indices()
-            for slide_idx in range(self._get_number_of_slices_per_sample(sample_idx))
-            if slide_idx % (self._sample_every_n_slices or 1) == 0
-        ]
-        return indices
-
-    def _download_dataset(self) -> None:
-        """Downloads the dataset."""
-        dataset_resources = {
-            "small": self._resources_small,
-            "full": self._resources_full,
-        }
-        resources = dataset_resources.get(self._version or "")
-        if resources is None:
-            raise ValueError(
-                f"Can't download data version '{self._version}'. Use 'small' or 'full'."
-            )
-
-        self._print_license()
-        for resource in resources:
-            if os.path.isdir(self._root):
-                continue
-
-            utils.download_and_extract_archive(
-                resource.url,
-                download_root=self._root,
-                filename=resource.filename,
-                remove_finished=True,
-            )
-
-    def _decompress_files(self) -> None:
-        compressed_paths = Path(self._root).rglob("*/ct.nii.gz")
-        multiprocessing.run_with_threads(
-            core_io.gunzip_file,
-            [(str(path),) for path in compressed_paths],
-            num_workers=self._num_workers,
-            progress_desc=">> Decompressing .gz files",
-            return_results=False,
-        )
-
-    def _print_license(self) -> None:
-        """Prints the dataset license."""
-        print(f"Dataset license: {self._license}")