kaiko-eva 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (84)
  1. eva/core/data/datasets/base.py +7 -2
  2. eva/core/models/modules/head.py +4 -2
  3. eva/core/models/modules/typings.py +2 -2
  4. eva/core/models/transforms/__init__.py +2 -1
  5. eva/core/models/transforms/as_discrete.py +57 -0
  6. eva/core/models/wrappers/_utils.py +121 -1
  7. eva/core/utils/suppress_logs.py +28 -0
  8. eva/vision/data/__init__.py +2 -2
  9. eva/vision/data/dataloaders/__init__.py +5 -0
  10. eva/vision/data/dataloaders/collate_fn/__init__.py +5 -0
  11. eva/vision/data/dataloaders/collate_fn/collection.py +22 -0
  12. eva/vision/data/datasets/__init__.py +2 -2
  13. eva/vision/data/datasets/classification/bach.py +3 -4
  14. eva/vision/data/datasets/classification/bracs.py +3 -4
  15. eva/vision/data/datasets/classification/breakhis.py +3 -4
  16. eva/vision/data/datasets/classification/camelyon16.py +4 -5
  17. eva/vision/data/datasets/classification/crc.py +3 -4
  18. eva/vision/data/datasets/classification/gleason_arvaniti.py +3 -4
  19. eva/vision/data/datasets/classification/mhist.py +3 -4
  20. eva/vision/data/datasets/classification/panda.py +4 -5
  21. eva/vision/data/datasets/classification/patch_camelyon.py +3 -4
  22. eva/vision/data/datasets/classification/unitopatho.py +3 -4
  23. eva/vision/data/datasets/classification/wsi.py +6 -5
  24. eva/vision/data/datasets/segmentation/__init__.py +2 -2
  25. eva/vision/data/datasets/segmentation/_utils.py +47 -0
  26. eva/vision/data/datasets/segmentation/bcss.py +7 -8
  27. eva/vision/data/datasets/segmentation/btcv.py +236 -0
  28. eva/vision/data/datasets/segmentation/consep.py +6 -7
  29. eva/vision/data/datasets/segmentation/lits.py +9 -8
  30. eva/vision/data/datasets/segmentation/lits_balanced.py +2 -1
  31. eva/vision/data/datasets/segmentation/monusac.py +4 -5
  32. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +12 -10
  33. eva/vision/data/datasets/vision.py +95 -4
  34. eva/vision/data/datasets/wsi.py +5 -5
  35. eva/vision/data/transforms/__init__.py +22 -3
  36. eva/vision/data/transforms/common/__init__.py +1 -2
  37. eva/vision/data/transforms/croppad/__init__.py +11 -0
  38. eva/vision/data/transforms/croppad/crop_foreground.py +110 -0
  39. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +109 -0
  40. eva/vision/data/transforms/croppad/spatial_pad.py +67 -0
  41. eva/vision/data/transforms/intensity/__init__.py +11 -0
  42. eva/vision/data/transforms/intensity/rand_scale_intensity.py +59 -0
  43. eva/vision/data/transforms/intensity/rand_shift_intensity.py +55 -0
  44. eva/vision/data/transforms/intensity/scale_intensity_ranged.py +56 -0
  45. eva/vision/data/transforms/spatial/__init__.py +7 -0
  46. eva/vision/data/transforms/spatial/flip.py +72 -0
  47. eva/vision/data/transforms/spatial/rotate.py +53 -0
  48. eva/vision/data/transforms/spatial/spacing.py +69 -0
  49. eva/vision/data/transforms/utility/__init__.py +5 -0
  50. eva/vision/data/transforms/utility/ensure_channel_first.py +51 -0
  51. eva/vision/data/tv_tensors/__init__.py +5 -0
  52. eva/vision/data/tv_tensors/volume.py +61 -0
  53. eva/vision/metrics/segmentation/monai_dice.py +9 -2
  54. eva/vision/models/modules/semantic_segmentation.py +28 -20
  55. eva/vision/models/networks/backbones/__init__.py +9 -2
  56. eva/vision/models/networks/backbones/pathology/__init__.py +11 -2
  57. eva/vision/models/networks/backbones/pathology/bioptimus.py +47 -1
  58. eva/vision/models/networks/backbones/pathology/hkust.py +69 -0
  59. eva/vision/models/networks/backbones/pathology/kaiko.py +18 -0
  60. eva/vision/models/networks/backbones/radiology/__init__.py +11 -0
  61. eva/vision/models/networks/backbones/radiology/swin_unetr.py +231 -0
  62. eva/vision/models/networks/backbones/radiology/voco.py +75 -0
  63. eva/vision/models/networks/decoders/segmentation/__init__.py +6 -2
  64. eva/vision/models/networks/decoders/segmentation/linear.py +5 -10
  65. eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +8 -1
  66. eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py +104 -0
  67. eva/vision/utils/io/__init__.py +2 -0
  68. eva/vision/utils/io/nifti.py +91 -11
  69. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/METADATA +3 -1
  70. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/RECORD +73 -57
  71. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/WHEEL +1 -1
  72. eva/vision/data/datasets/classification/base.py +0 -96
  73. eva/vision/data/datasets/segmentation/base.py +0 -96
  74. eva/vision/data/transforms/common/resize_and_clamp.py +0 -51
  75. eva/vision/data/transforms/normalization/__init__.py +0 -6
  76. eva/vision/data/transforms/normalization/clamp.py +0 -43
  77. eva/vision/data/transforms/normalization/functional/__init__.py +0 -5
  78. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +0 -28
  79. eva/vision/data/transforms/normalization/rescale_intensity.py +0 -53
  80. eva/vision/metrics/segmentation/BUILD +0 -1
  81. eva/vision/models/networks/backbones/torchhub/__init__.py +0 -5
  82. eva/vision/models/networks/backbones/torchhub/backbones.py +0 -61
  83. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/entry_points.txt +0 -0
  84. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/licenses/LICENSE +0 -0
eva/vision/models/networks/backbones/radiology/swin_unetr.py (new file)
@@ -0,0 +1,231 @@
+ """Encoder based on Swin UNETR."""
+
+ from typing import List, Tuple
+
+ import torch
+ from monai.inferers.inferer import Inferer
+ from monai.networks.blocks import unetr_block
+ from monai.networks.nets import swin_unetr
+ from monai.utils import misc
+ from torch import nn
+
+ from eva.vision.models.networks.backbones.registry import register_model
+
+
+ @register_model("radiology/swin_unetr_encoder")
+ class SwinUNETREncoder(nn.Module):
+     """Swin transformer encoder based on UNETR [0].
+
+     - [0] UNETR: Transformers for 3D Medical Image Segmentation
+       https://arxiv.org/pdf/2103.10504
+     """
+
+     def __init__(
+         self,
+         in_channels: int = 1,
+         feature_size: int = 48,
+         spatial_dims: int = 3,
+         out_indices: int | None = None,
+         inferer: Inferer | None = None,
+         use_v2: bool = True,
+     ) -> None:
+         """Build the UNETR encoder.
+
+         Args:
+             in_channels: Number of input channels.
+             feature_size: The dimension of network feature size.
+             spatial_dims: Number of spatial dimensions.
+             out_indices: Number of feature outputs. If None,
+                 the aggregated feature vector is returned.
+             inferer: An optional MONAI `Inferer` for efficient
+                 inference during evaluation.
+             use_v2: Whether to use SwinTransformerV2.
+         """
+         super().__init__()
+
+         self._in_channels = in_channels
+         self._feature_size = feature_size
+         self._spatial_dims = spatial_dims
+         self._out_indices = out_indices
+         self._inferer = inferer
+         self._use_v2 = use_v2
+
+         self._window_size = misc.ensure_tuple_rep(7, spatial_dims)
+         self._patch_size = misc.ensure_tuple_rep(2, spatial_dims)
+
+         self.swinViT = swin_unetr.SwinTransformer(
+             in_chans=in_channels,
+             embed_dim=feature_size,
+             window_size=self._window_size,
+             patch_size=self._patch_size,
+             depths=(2, 2, 2, 2),
+             num_heads=(3, 6, 12, 24),
+             mlp_ratio=4.0,
+             qkv_bias=True,
+             drop_rate=0.0,
+             attn_drop_rate=0.0,
+             drop_path_rate=0.0,
+             norm_layer=torch.nn.LayerNorm,
+             spatial_dims=spatial_dims,
+             use_v2=use_v2,
+         )
+         self.encoder1 = unetr_block.UnetrBasicBlock(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             out_channels=feature_size,
+             kernel_size=3,
+             stride=1,
+             norm_name="instance",
+             res_block=True,
+         )
+         self.encoder2 = unetr_block.UnetrBasicBlock(
+             spatial_dims=spatial_dims,
+             in_channels=feature_size,
+             out_channels=feature_size,
+             kernel_size=3,
+             stride=1,
+             norm_name="instance",
+             res_block=True,
+         )
+         self.encoder3 = unetr_block.UnetrBasicBlock(
+             spatial_dims=spatial_dims,
+             in_channels=2 * feature_size,
+             out_channels=2 * feature_size,
+             kernel_size=3,
+             stride=1,
+             norm_name="instance",
+             res_block=True,
+         )
+         self.encoder4 = unetr_block.UnetrBasicBlock(
+             spatial_dims=spatial_dims,
+             in_channels=4 * feature_size,
+             out_channels=4 * feature_size,
+             kernel_size=3,
+             stride=1,
+             norm_name="instance",
+             res_block=True,
+         )
+         self.encoder10 = unetr_block.UnetrBasicBlock(
+             spatial_dims=spatial_dims,
+             in_channels=16 * feature_size,
+             out_channels=16 * feature_size,
+             kernel_size=3,
+             stride=1,
+             norm_name="instance",
+             res_block=True,
+         )
+         self._pool_op = (
+             nn.AdaptiveAvgPool3d(output_size=(1, 1, 1))
+             if spatial_dims == 3
+             else nn.AdaptiveAvgPool2d(output_size=(1, 1))
+         )
+
+     def _forward_features(self, tensor: torch.Tensor) -> List[torch.Tensor]:
+         """Extracts feature maps from the Swin Transformer and encoder blocks.
+
+         Args:
+             tensor: Input tensor of shape (B, C, T, H, W).
+
+         Returns:
+             List of feature maps from encoder stages.
+         """
+         hidden_states = self.swinViT(tensor)
+         enc0 = self.encoder1(tensor)
+         enc1 = self.encoder2(hidden_states[0])
+         enc2 = self.encoder3(hidden_states[1])
+         enc3 = self.encoder4(hidden_states[2])
+         dec4 = self.encoder10(hidden_states[4])
+         return [enc0, enc1, enc2, enc3, hidden_states[3], dec4]
+
+     def forward_features(self, tensor: torch.Tensor) -> List[torch.Tensor]:
+         """Computes feature maps using either standard forward pass or inference mode.
+
+         If in inference mode (`self.training` is False) and an inference method
+         (`self._inferer`) is available, the `_inferer` is used to extract features.
+         Otherwise, `_forward_features` is called directly.
+
+         Args:
+             tensor: Input tensor of shape (B, C, T, H, W).
+
+         Returns:
+             List of feature maps from encoder stages.
+         """
+         if not self.training and self._inferer:
+             return self._inferer(inputs=tensor, network=self._forward_features)
+
+         return self._forward_features(tensor)
+
+     def forward_encoders(self, features: List[torch.Tensor]) -> torch.Tensor:
+         """Aggregates encoder features into a single feature vector.
+
+         Args:
+             features: List of feature maps from encoder stages.
+
+         Returns:
+             Aggregated feature vector (B, C').
+         """
+         batch_size = features[0].shape[0]
+         reduced_features = []
+         for patch_features in features[:4] + features[5:]:
+             hidden_features = self._pool_op(patch_features)
+             hidden_features_reduced = hidden_features.view(batch_size, -1)
+             reduced_features.append(hidden_features_reduced)
+         return torch.cat(reduced_features, dim=1)
+
+     def forward_head(self, features: List[torch.Tensor]) -> torch.Tensor:
+         """Casts last feature map into a single feature vector.
+
+         Args:
+             features: List of encoder feature maps.
+
+         Returns:
+             Aggregated feature vector (B, C').
+         """
+         last_feature_map = features[-1]
+         pooled_features = self._pool_op(last_feature_map)
+         return torch.flatten(pooled_features, 1)
+
+     def forward_embeddings(self, tensor: torch.Tensor) -> torch.Tensor:
+         """Computes the final aggregated feature vector.
+
+         Args:
+             tensor: Input tensor of shape (B, C, T, H, W).
+
+         Returns:
+             Aggregated feature vector of shape (B, C').
+         """
+         intermediates = self.forward_features(tensor)
+         return self.forward_encoders(intermediates)
+
+     def forward_intermediates(
+         self, tensor: torch.Tensor
+     ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+         """Computes encoder features and their embeddings.
+
+         Args:
+             tensor: Input tensor of shape (B, C, T, H, W).
+
+         Returns:
+             Aggregated feature vector and list of intermediate features.
+         """
+         features = self.forward_features(tensor)
+         embeddings = self.forward_encoders(features)
+         return embeddings, features
+
+     def forward(self, tensor: torch.Tensor) -> torch.Tensor | List[torch.Tensor]:
+         """Forward pass through the encoder.
+
+         If `self._out_indices` is None, it returns the aggregated feature vector.
+         Otherwise, it returns the intermediate feature maps up to the specified index.
+
+         Args:
+             tensor: Input tensor of shape (B, C, T, H, W).
+
+         Returns:
+             Aggregated feature vector or intermediate features.
+         """
+         if self._out_indices is None:
+             return self.forward_embeddings(tensor)
+
+         intermediates = self.forward_features(tensor)
+         return intermediates[-1 * self._out_indices :]
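
For orientation, a minimal usage sketch of the new encoder (not part of the diff; the input shape and sizes below are illustrative assumptions):

    import torch

    from eva.vision.models.networks.backbones.radiology.swin_unetr import SwinUNETREncoder

    # out_indices=None (default): pools and concatenates the encoder stages into one vector.
    encoder = SwinUNETREncoder(in_channels=1, feature_size=48, spatial_dims=3)
    volume = torch.randn(1, 1, 96, 96, 96)  # (B, C, T, H, W); 96^3 is an assumed crop size
    embedding = encoder(volume)  # aggregated feature vector of shape (B, C')

    # out_indices=N: returns the last N intermediate feature maps instead.
    encoder_ms = SwinUNETREncoder(out_indices=4)
    feature_maps = encoder_ms(volume)  # list of 4 tensors at decreasing resolutions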
eva/vision/models/networks/backbones/radiology/voco.py (new file)
@@ -0,0 +1,75 @@
+ """VoCo Self-Supervised Encoders."""
+
+ from typing_extensions import override
+
+ from eva.core.models.wrappers import _utils
+ from eva.vision.models.networks.backbones.radiology import swin_unetr
+ from eva.vision.models.networks.backbones.registry import register_model
+
+
+ class _VoCo(swin_unetr.SwinUNETREncoder):
+     """Base class for the VoCo self-supervised encoders."""
+
+     _checkpoint: str
+     """Path to the model state dict."""
+
+     _md5: str | None = None
+     """State dict MD5 validation code."""
+
+     def __init__(self, feature_size: int, out_indices: int | None = None) -> None:
+         """Initializes the model.
+
+         Args:
+             feature_size: Size of the last feature map of SwinUNETR.
+             out_indices: The number of feature maps from intermediate blocks
+                 to be returned. If set to 1, only the last feature map is returned.
+         """
+         super().__init__(
+             in_channels=1,
+             feature_size=feature_size,
+             spatial_dims=3,
+             out_indices=out_indices,
+         )
+
+         self._load_checkpoint()
+
+     def _load_checkpoint(self) -> None:
+         """Loads the model checkpoint."""
+         state_dict = _utils.load_state_dict_from_url(self._checkpoint, md5=self._md5)
+         self.load_state_dict(state_dict)
+
+
+ @register_model("radiology/voco_b")
+ class VoCoB(_VoCo):
+     """VoCo Self-supervised pre-trained B model."""
+
+     _checkpoint = "https://huggingface.co/Luffy503/VoCo/resolve/main/VoCo_B_SSL_head.pt"
+     _md5 = "f80c4da2f81d700bdae3df188f2057eb"
+
+     @override
+     def __init__(self, out_indices: int | None = None) -> None:
+         super().__init__(feature_size=48, out_indices=out_indices)
+
+
+ @register_model("radiology/voco_l")
+ class VoCoL(_VoCo):
+     """VoCo Self-supervised pre-trained L model."""
+
+     _checkpoint = "https://huggingface.co/Luffy503/VoCo/resolve/main/VoCo_L_SSL_head.pt"
+     _md5 = "795095d1d43ef3808ec4c41798310136"
+
+     @override
+     def __init__(self, out_indices: int | None = None) -> None:
+         super().__init__(feature_size=96, out_indices=out_indices)
+
+
+ @register_model("radiology/voco_h")
+ class VoCoH(_VoCo):
+     """VoCo Self-supervised pre-trained H model."""
+
+     _checkpoint = "https://huggingface.co/Luffy503/VoCo/resolve/main/VoCo_H_SSL_head.pt"
+     _md5 = "76f95a474736b60bf5b8aad94643744d"
+
+     @override
+     def __init__(self, out_indices: int | None = None) -> None:
+         super().__init__(feature_size=192, out_indices=out_indices)
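
A hedged sketch of instantiating one of these registered backbones; the constructor eagerly downloads and MD5-verifies the checkpoint, so network access (or a warm cache) is assumed:

    import torch

    from eva.vision.models.networks.backbones.radiology.voco import VoCoB

    model = VoCoB()  # fetches VoCo_B_SSL_head.pt and loads the state dict
    model.eval()

    with torch.no_grad():
        embedding = model(torch.randn(1, 1, 64, 64, 64))  # aggregated feature vector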
eva/vision/models/networks/decoders/segmentation/__init__.py
@@ -1,5 +1,6 @@
  """Segmentation decoder heads API."""

+ from eva.vision.models.networks.decoders.segmentation.base import Decoder
  from eva.vision.models.networks.decoders.segmentation.decoder2d import Decoder2D
  from eva.vision.models.networks.decoders.segmentation.linear import LinearDecoder
  from eva.vision.models.networks.decoders.segmentation.semantic import (
@@ -7,13 +8,16 @@ from eva.vision.models.networks.decoders.segmentation.semantic import (
      ConvDecoderMS,
      ConvDecoderWithImage,
      SingleLinearDecoder,
+     SwinUNETRDecoder,
  )

  __all__ = [
+     "Decoder",
+     "Decoder2D",
      "ConvDecoder1x1",
      "ConvDecoderMS",
-     "SingleLinearDecoder",
      "ConvDecoderWithImage",
-     "Decoder2D",
      "LinearDecoder",
+     "SingleLinearDecoder",
+     "SwinUNETRDecoder",
  ]
eva/vision/models/networks/decoders/segmentation/linear.py
@@ -7,6 +7,7 @@ from torch import nn
  from torch.nn import functional

  from eva.vision.models.networks.decoders.segmentation import base
+ from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs


  class LinearDecoder(base.Decoder):
@@ -104,22 +105,16 @@ class LinearDecoder(base.Decoder):
          """
          return functional.interpolate(logits, image_size, mode="bilinear")

-     def forward(
-         self,
-         features: List[torch.Tensor],
-         image_size: Tuple[int, int],
-     ) -> torch.Tensor:
+     def forward(self, decoder_inputs: DecoderInputs) -> torch.Tensor:
          """Maps the patch embeddings to a segmentation mask of the image size.

          Args:
-             features: List of multi-level image features of shape (batch_size,
-                 hidden_size, n_patches_height, n_patches_width).
-             image_size: The target image size (height, width).
+             decoder_inputs: Inputs required by the decoder.

          Returns:
              Tensor containing scores for all of the classes with shape
              (batch_size, n_classes, image_height, image_width).
          """
-         patch_embeddings = self._forward_features(features)
+         patch_embeddings = self._forward_features(decoder_inputs.features)
          logits = self._forward_head(patch_embeddings)
-         return self._cls_seg(logits, image_size)
+         return self._cls_seg(logits, decoder_inputs.image_size)
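
The new calling convention, sketched under the assumption that `DecoderInputs` accepts `features` and `image_size` keywords (the only fields this diff accesses); `decoder` and `features` are hypothetical placeholders:

    from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs

    # Before: logits = decoder(features, image_size)
    logits = decoder(DecoderInputs(features=features, image_size=(224, 224)))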
eva/vision/models/networks/decoders/segmentation/semantic/__init__.py
@@ -5,8 +5,15 @@ from eva.vision.models.networks.decoders.segmentation.semantic.common import (
      ConvDecoderMS,
      SingleLinearDecoder,
  )
+ from eva.vision.models.networks.decoders.segmentation.semantic.swin_unetr import SwinUNETRDecoder
  from eva.vision.models.networks.decoders.segmentation.semantic.with_image import (
      ConvDecoderWithImage,
  )

- __all__ = ["ConvDecoder1x1", "ConvDecoderMS", "SingleLinearDecoder", "ConvDecoderWithImage"]
+ __all__ = [
+     "ConvDecoder1x1",
+     "ConvDecoderMS",
+     "ConvDecoderWithImage",
+     "SingleLinearDecoder",
+     "SwinUNETRDecoder",
+ ]
eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py (new file)
@@ -0,0 +1,104 @@
+ """Decoder based on Swin UNETR."""
+
+ from typing import List
+
+ import torch
+ from monai.networks.blocks import dynunet_block, unetr_block
+ from torch import nn
+
+
+ class SwinUNETRDecoder(nn.Module):
+     """Swin transformer decoder based on UNETR [0].
+
+     - [0] UNETR: Transformers for 3D Medical Image Segmentation
+       https://arxiv.org/pdf/2103.10504
+     """
+
+     def __init__(
+         self,
+         out_channels: int,
+         feature_size: int = 48,
+         spatial_dims: int = 3,
+     ) -> None:
+         """Builds the decoder.
+
+         Args:
+             out_channels: Number of output channels.
+             feature_size: Dimension of network feature size.
+             spatial_dims: Number of spatial dimensions.
+         """
+         super().__init__()
+
+         self.decoder5 = unetr_block.UnetrUpBlock(
+             spatial_dims=spatial_dims,
+             in_channels=16 * feature_size,
+             out_channels=8 * feature_size,
+             kernel_size=3,
+             upsample_kernel_size=2,
+             norm_name="instance",
+             res_block=True,
+         )
+         self.decoder4 = unetr_block.UnetrUpBlock(
+             spatial_dims=spatial_dims,
+             in_channels=feature_size * 8,
+             out_channels=feature_size * 4,
+             kernel_size=3,
+             upsample_kernel_size=2,
+             norm_name="instance",
+             res_block=True,
+         )
+         self.decoder3 = unetr_block.UnetrUpBlock(
+             spatial_dims=spatial_dims,
+             in_channels=feature_size * 4,
+             out_channels=feature_size * 2,
+             kernel_size=3,
+             upsample_kernel_size=2,
+             norm_name="instance",
+             res_block=True,
+         )
+         self.decoder2 = unetr_block.UnetrUpBlock(
+             spatial_dims=spatial_dims,
+             in_channels=feature_size * 2,
+             out_channels=feature_size,
+             kernel_size=3,
+             upsample_kernel_size=2,
+             norm_name="instance",
+             res_block=True,
+         )
+         self.decoder1 = unetr_block.UnetrUpBlock(
+             spatial_dims=spatial_dims,
+             in_channels=feature_size,
+             out_channels=feature_size,
+             kernel_size=3,
+             upsample_kernel_size=2,
+             norm_name="instance",
+             res_block=True,
+         )
+         self.out = dynunet_block.UnetOutBlock(
+             spatial_dims=spatial_dims,
+             in_channels=feature_size,
+             out_channels=out_channels,
+         )
+
+     def _forward_features(self, features: List[torch.Tensor]) -> torch.Tensor:
+         """Forward function for multi-level feature maps to a single one."""
+         enc0, enc1, enc2, enc3, hid3, dec4 = features
+         dec3 = self.decoder5(dec4, hid3)
+         dec2 = self.decoder4(dec3, enc3)
+         dec1 = self.decoder3(dec2, enc2)
+         dec0 = self.decoder2(dec1, enc1)
+         out = self.decoder1(dec0, enc0)
+         return self.out(out)
+
+     def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
+         """Maps the patch embeddings to a segmentation mask.
+
+         Args:
+             features: List of multi-level intermediate features from
+                 :class:`SwinUNETREncoder`.
+
+         Returns:
+             Tensor containing scores for all of the classes with shape
+             (batch_size, n_classes, image_height, image_width).
+         """
+         return self._forward_features(features)
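
A sketch of pairing the new encoder and decoder; the wiring follows the six feature maps the encoder returns and the decoder unpacks, while the class count and volume size are assumptions:

    import torch

    from eva.vision.models.networks.backbones.radiology.swin_unetr import SwinUNETREncoder
    from eva.vision.models.networks.decoders.segmentation.semantic import SwinUNETRDecoder

    encoder = SwinUNETREncoder(feature_size=48, out_indices=6)  # expose all six feature maps
    decoder = SwinUNETRDecoder(out_channels=14, feature_size=48)  # 14 classes is illustrative

    volume = torch.randn(1, 1, 96, 96, 96)
    features = encoder(volume)  # [enc0, enc1, enc2, enc3, hid3, dec4]
    logits = decoder(features)  # (1, 14, 96, 96, 96)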
eva/vision/utils/io/__init__.py
@@ -5,6 +5,7 @@ from eva.vision.utils.io.mat import read_mat, save_mat
  from eva.vision.utils.io.nifti import (
      fetch_nifti_axis_direction_code,
      fetch_nifti_shape,
+     nifti_to_array,
      read_nifti,
      save_array_as_nifti,
  )
@@ -16,6 +17,7 @@ __all__ = [
      "read_image_as_tensor",
      "fetch_nifti_shape",
      "fetch_nifti_axis_direction_code",
+     "nifti_to_array",
      "read_nifti",
      "save_array_as_nifti",
      "read_csv",
eva/vision/utils/io/nifti.py
@@ -1,3 +1,4 @@
+ # type: ignore
  """NIfTI I/O related functions."""

  from typing import Any, Tuple
@@ -7,36 +8,63 @@ import numpy as np
  import numpy.typing as npt
  from nibabel import orientations

+ from eva.core.utils.suppress_logs import SuppressLogs
  from eva.vision.utils.io import _utils


  def read_nifti(
-     path: str, slice_index: int | None = None, *, use_storage_dtype: bool = True
- ) -> npt.NDArray[Any]:
+     path: str,
+     slice_index: int | None = None,
+     *,
+     orientation: str | None = None,
+     orientation_reference: str | None = None,
+ ) -> nib.nifti1.Nifti1Image:
      """Reads and loads a NIfTI image from a file path.

      Args:
          path: The path to the NIfTI file.
          slice_index: Whether to read only a slice from the file.
+         orientation: The orientation code to reorient the nifti image.
+         orientation_reference: Path to a NIfTI file which
+             will be used as a reference for the orientation
+             transform in case the file is missing the pixdim array
+             in the NIfTI header.
          use_storage_dtype: Whether to cast the raw image
              array to the inferred type.

      Returns:
-         The image as a numpy array (height, width, channels).
+         The NIfTI image class instance.

      Raises:
          FileExistsError: If the path does not exist or it is unreachable.
          ValueError: If the input channel is invalid for the image.
      """
      _utils.check_file(path)
-     image_data: nib.Nifti1Image = nib.load(path)  # type: ignore
+     image_data = _load_nifti_silently(path)
      if slice_index is not None:
          image_data = image_data.slicer[:, :, slice_index : slice_index + 1]
+     if orientation:
+         image_data = _reorient(
+             image_data, orientation=orientation, reference_file=orientation_reference
+         )

-     image_array = image_data.get_fdata()
-     if use_storage_dtype:
-         image_array = image_array.astype(image_data.get_data_dtype())
+     return image_data
+
+
+ def nifti_to_array(nii: nib.Nifti1Image, use_storage_dtype: bool = True) -> npt.NDArray[Any]:
+     """Converts a NIfTI image to a numpy array.
+
+     Args:
+         nii: The input NIfTI image.
+         use_storage_dtype: Whether to cast the raw image
+             array to the inferred type.

+     Returns:
+         The image as a numpy array (height, width, channels).
+     """
+     image_array = nii.get_fdata()
+     if use_storage_dtype:
+         image_array = image_array.astype(nii.get_data_dtype())
      return image_array


@@ -53,7 +81,7 @@ def save_array_as_nifti(
          filename: The name to save the image like.
          dtype: The data type to save the image.
      """
-     nifti_image = nib.Nifti1Image(array, affine=np.eye(4), dtype=dtype)  # type: ignore
+     nifti_image = nib.Nifti1Image(array, affine=np.eye(4), dtype=dtype)
      nifti_image.to_filename(filename)


@@ -71,8 +99,22 @@ def fetch_nifti_shape(path: str) -> Tuple[int]:
          ValueError: If the input channel is invalid for the image.
      """
      _utils.check_file(path)
-     image = nib.load(path)  # type: ignore
-     return image.header.get_data_shape()  # type: ignore
+     nii = _load_nifti_silently(path)
+     return nii.header.get_data_shape()  # type: ignore
+
+
+ def fetch_nifti_orientation(path: str) -> npt.NDArray[Any]:
+     """Fetches the NIfTI image orientation.
+
+     Args:
+         path: The path to the NIfTI file.
+
+     Returns:
+         The array orientation.
+     """
+     _utils.check_file(path)
+     nii = _load_nifti_silently(path)
+     return nib.io_orientation(nii.affine)


  def fetch_nifti_axis_direction_code(path: str) -> str:
@@ -85,5 +127,43 @@ def fetch_nifti_axis_direction_code(path: str) -> str:
          The axis direction codes as string (e.g. "LAS").
      """
      _utils.check_file(path)
-     image_data: nib.Nifti1Image = nib.load(path)  # type: ignore
+     image_data: nib.Nifti1Image = nib.load(path)
      return "".join(orientations.aff2axcodes(image_data.affine))
+
+
+ def _load_nifti_silently(path: str) -> nib.Nifti1Image:
+     """Reads a NIfTI image in silent mode."""
+     with SuppressLogs():
+         return nib.load(path)
+     raise ValueError(f"Failed to load NIfTI file: {path}")
+
+
+ def _reorient(
+     nii: nib.Nifti1Image,
+     /,
+     orientation: str | tuple[str, str, str] = "RAS",
+     reference_file: str | None = None,
+ ) -> nib.Nifti1Image:
+     """Reorients a NIfTI image to a specified orientation.
+
+     Args:
+         nii: The input NIfTI image.
+         orientation: Desired orientation expressed as a
+             three-character string (e.g., "RAS") or a tuple
+             (e.g., ("R", "A", "S")).
+         reference_file: Path to a reference NIfTI file whose
+             orientation should be used if the input image lacks
+             a valid affine transformation.
+
+     Returns:
+         The reoriented NIfTI image.
+     """
+     affine_matrix, _ = nii.get_qform(coded=True)
+     orig_ornt = (
+         fetch_nifti_orientation(reference_file)
+         if reference_file and affine_matrix is None
+         else nib.io_orientation(nii.affine)
+     )
+     targ_ornt = orientations.axcodes2ornt(orientation)
+     transform = orientations.ornt_transform(orig_ornt, targ_ornt)
+     return nii.as_reoriented(transform)
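
A short sketch of the resulting split read/convert API (the file path is a placeholder):

    from eva.vision.utils import io

    nii = io.read_nifti("scan.nii.gz", orientation="RAS")  # now returns a Nifti1Image
    array = io.nifti_to_array(nii)  # numpy array, cast to the storage dtype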
{kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: kaiko-eva
- Version: 0.2.0
+ Version: 0.2.1
  Summary: Evaluation Framework for oncology foundation models.
  Keywords: machine-learning,evaluation-framework,oncology,foundation-models
  Author-Email: Ioannis Gatopoulos <ioannis@kaiko.ai>, =?utf-8?q?Nicolas_K=C3=A4nzig?= <nicolas@kaiko.ai>, Roman Moser <roman@kaiko.ai>
@@ -241,6 +241,7 @@ Requires-Dist: scikit-image>=0.24.0; extra == "vision"
  Requires-Dist: imagesize>=1.4.1; extra == "vision"
  Requires-Dist: scipy>=1.14.0; extra == "vision"
  Requires-Dist: monai>=1.3.2; extra == "vision"
+ Requires-Dist: einops>=0.8.1; extra == "vision"
  Provides-Extra: all
  Requires-Dist: h5py>=3.10.0; extra == "all"
  Requires-Dist: nibabel>=4.0.1; extra == "all"
@@ -253,6 +254,7 @@ Requires-Dist: scikit-image>=0.24.0; extra == "all"
  Requires-Dist: imagesize>=1.4.1; extra == "all"
  Requires-Dist: scipy>=1.14.0; extra == "all"
  Requires-Dist: monai>=1.3.2; extra == "all"
+ Requires-Dist: einops>=0.8.1; extra == "all"
  Description-Content-Type: text/markdown

  <div align="center">