kaiko-eva 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eva/core/data/datasets/base.py +7 -2
- eva/core/models/modules/head.py +4 -2
- eva/core/models/modules/typings.py +2 -2
- eva/core/models/transforms/__init__.py +2 -1
- eva/core/models/transforms/as_discrete.py +57 -0
- eva/core/models/wrappers/_utils.py +121 -1
- eva/core/utils/suppress_logs.py +28 -0
- eva/vision/data/__init__.py +2 -2
- eva/vision/data/dataloaders/__init__.py +5 -0
- eva/vision/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/vision/data/dataloaders/collate_fn/collection.py +22 -0
- eva/vision/data/datasets/__init__.py +2 -2
- eva/vision/data/datasets/classification/bach.py +3 -4
- eva/vision/data/datasets/classification/bracs.py +3 -4
- eva/vision/data/datasets/classification/breakhis.py +3 -4
- eva/vision/data/datasets/classification/camelyon16.py +4 -5
- eva/vision/data/datasets/classification/crc.py +3 -4
- eva/vision/data/datasets/classification/gleason_arvaniti.py +3 -4
- eva/vision/data/datasets/classification/mhist.py +3 -4
- eva/vision/data/datasets/classification/panda.py +4 -5
- eva/vision/data/datasets/classification/patch_camelyon.py +3 -4
- eva/vision/data/datasets/classification/unitopatho.py +3 -4
- eva/vision/data/datasets/classification/wsi.py +6 -5
- eva/vision/data/datasets/segmentation/__init__.py +2 -2
- eva/vision/data/datasets/segmentation/_utils.py +47 -0
- eva/vision/data/datasets/segmentation/bcss.py +7 -8
- eva/vision/data/datasets/segmentation/btcv.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +6 -7
- eva/vision/data/datasets/segmentation/lits.py +9 -8
- eva/vision/data/datasets/segmentation/lits_balanced.py +2 -1
- eva/vision/data/datasets/segmentation/monusac.py +4 -5
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +12 -10
- eva/vision/data/datasets/vision.py +95 -4
- eva/vision/data/datasets/wsi.py +5 -5
- eva/vision/data/transforms/__init__.py +22 -3
- eva/vision/data/transforms/common/__init__.py +1 -2
- eva/vision/data/transforms/croppad/__init__.py +11 -0
- eva/vision/data/transforms/croppad/crop_foreground.py +110 -0
- eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +109 -0
- eva/vision/data/transforms/croppad/spatial_pad.py +67 -0
- eva/vision/data/transforms/intensity/__init__.py +11 -0
- eva/vision/data/transforms/intensity/rand_scale_intensity.py +59 -0
- eva/vision/data/transforms/intensity/rand_shift_intensity.py +55 -0
- eva/vision/data/transforms/intensity/scale_intensity_ranged.py +56 -0
- eva/vision/data/transforms/spatial/__init__.py +7 -0
- eva/vision/data/transforms/spatial/flip.py +72 -0
- eva/vision/data/transforms/spatial/rotate.py +53 -0
- eva/vision/data/transforms/spatial/spacing.py +69 -0
- eva/vision/data/transforms/utility/__init__.py +5 -0
- eva/vision/data/transforms/utility/ensure_channel_first.py +51 -0
- eva/vision/data/tv_tensors/__init__.py +5 -0
- eva/vision/data/tv_tensors/volume.py +61 -0
- eva/vision/metrics/segmentation/monai_dice.py +9 -2
- eva/vision/models/modules/semantic_segmentation.py +28 -20
- eva/vision/models/networks/backbones/__init__.py +9 -2
- eva/vision/models/networks/backbones/pathology/__init__.py +11 -2
- eva/vision/models/networks/backbones/pathology/bioptimus.py +47 -1
- eva/vision/models/networks/backbones/pathology/hkust.py +69 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +18 -0
- eva/vision/models/networks/backbones/radiology/__init__.py +11 -0
- eva/vision/models/networks/backbones/radiology/swin_unetr.py +231 -0
- eva/vision/models/networks/backbones/radiology/voco.py +75 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +6 -2
- eva/vision/models/networks/decoders/segmentation/linear.py +5 -10
- eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +8 -1
- eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py +104 -0
- eva/vision/utils/io/__init__.py +2 -0
- eva/vision/utils/io/nifti.py +91 -11
- {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/METADATA +3 -1
- {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/RECORD +73 -57
- {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/WHEEL +1 -1
- eva/vision/data/datasets/classification/base.py +0 -96
- eva/vision/data/datasets/segmentation/base.py +0 -96
- eva/vision/data/transforms/common/resize_and_clamp.py +0 -51
- eva/vision/data/transforms/normalization/__init__.py +0 -6
- eva/vision/data/transforms/normalization/clamp.py +0 -43
- eva/vision/data/transforms/normalization/functional/__init__.py +0 -5
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +0 -28
- eva/vision/data/transforms/normalization/rescale_intensity.py +0 -53
- eva/vision/metrics/segmentation/BUILD +0 -1
- eva/vision/models/networks/backbones/torchhub/__init__.py +0 -5
- eva/vision/models/networks/backbones/torchhub/backbones.py +0 -61
- {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/licenses/LICENSE +0 -0
eva/vision/models/networks/backbones/radiology/swin_unetr.py
ADDED
@@ -0,0 +1,231 @@
+"""Encoder based on Swin UNETR."""
+
+from typing import List, Tuple
+
+import torch
+from monai.inferers.inferer import Inferer
+from monai.networks.blocks import unetr_block
+from monai.networks.nets import swin_unetr
+from monai.utils import misc
+from torch import nn
+
+from eva.vision.models.networks.backbones.registry import register_model
+
+
+@register_model("radiology/swin_unetr_encoder")
+class SwinUNETREncoder(nn.Module):
+    """Swin transformer encoder based on UNETR [0].
+
+    - [0] UNETR: Transformers for 3D Medical Image Segmentation
+        https://arxiv.org/pdf/2103.10504
+    """
+
+    def __init__(
+        self,
+        in_channels: int = 1,
+        feature_size: int = 48,
+        spatial_dims: int = 3,
+        out_indices: int | None = None,
+        inferer: Inferer | None = None,
+        use_v2: bool = True,
+    ) -> None:
+        """Build the UNETR encoder.
+
+        Args:
+            in_channels: Number of input channels.
+            feature_size: The dimension of network feature size.
+            spatial_dims: Number of spatial dimensions.
+            out_indices: Number of feature outputs. If None,
+                the aggregated feature vector is returned.
+            inferer: An optional MONAI `Inferer` for efficient
+                inference during evaluation.
+            use_v2: Whether to use SwinTransformerV2.
+        """
+        super().__init__()
+
+        self._in_channels = in_channels
+        self._feature_size = feature_size
+        self._spatial_dims = spatial_dims
+        self._out_indices = out_indices
+        self._inferer = inferer
+        self._use_v2 = use_v2
+
+        self._window_size = misc.ensure_tuple_rep(7, spatial_dims)
+        self._patch_size = misc.ensure_tuple_rep(2, spatial_dims)
+
+        self.swinViT = swin_unetr.SwinTransformer(
+            in_chans=in_channels,
+            embed_dim=feature_size,
+            window_size=self._window_size,
+            patch_size=self._patch_size,
+            depths=(2, 2, 2, 2),
+            num_heads=(3, 6, 12, 24),
+            mlp_ratio=4.0,
+            qkv_bias=True,
+            drop_rate=0.0,
+            attn_drop_rate=0.0,
+            drop_path_rate=0.0,
+            norm_layer=torch.nn.LayerNorm,
+            spatial_dims=spatial_dims,
+            use_v2=use_v2,
+        )
+        self.encoder1 = unetr_block.UnetrBasicBlock(
+            spatial_dims=spatial_dims,
+            in_channels=in_channels,
+            out_channels=feature_size,
+            kernel_size=3,
+            stride=1,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.encoder2 = unetr_block.UnetrBasicBlock(
+            spatial_dims=spatial_dims,
+            in_channels=feature_size,
+            out_channels=feature_size,
+            kernel_size=3,
+            stride=1,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.encoder3 = unetr_block.UnetrBasicBlock(
+            spatial_dims=spatial_dims,
+            in_channels=2 * feature_size,
+            out_channels=2 * feature_size,
+            kernel_size=3,
+            stride=1,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.encoder4 = unetr_block.UnetrBasicBlock(
+            spatial_dims=spatial_dims,
+            in_channels=4 * feature_size,
+            out_channels=4 * feature_size,
+            kernel_size=3,
+            stride=1,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.encoder10 = unetr_block.UnetrBasicBlock(
+            spatial_dims=spatial_dims,
+            in_channels=16 * feature_size,
+            out_channels=16 * feature_size,
+            kernel_size=3,
+            stride=1,
+            norm_name="instance",
+            res_block=True,
+        )
+        self._pool_op = (
+            nn.AdaptiveAvgPool3d(output_size=(1, 1, 1))
+            if spatial_dims == 3
+            else nn.AdaptiveAvgPool2d(output_size=(1, 1))
+        )
+
+    def _forward_features(self, tensor: torch.Tensor) -> List[torch.Tensor]:
+        """Extracts feature maps from the Swin Transformer and encoder blocks.
+
+        Args:
+            tensor: Input tensor of shape (B, C, T, H, W).
+
+        Returns:
+            List of feature maps from encoder stages.
+        """
+        hidden_states = self.swinViT(tensor)
+        enc0 = self.encoder1(tensor)
+        enc1 = self.encoder2(hidden_states[0])
+        enc2 = self.encoder3(hidden_states[1])
+        enc3 = self.encoder4(hidden_states[2])
+        dec4 = self.encoder10(hidden_states[4])
+        return [enc0, enc1, enc2, enc3, hidden_states[3], dec4]
+
+    def forward_features(self, tensor: torch.Tensor) -> List[torch.Tensor]:
+        """Computes feature maps using either standard forward pass or inference mode.
+
+        If in inference mode (`self.training` is False) and an inference method
+        (`self._inferer`) is available, the `_inferer` is used to extract features.
+        Otherwise, `_forward_features` is called directly.
+
+        Args:
+            tensor: Input tensor of shape (B, C, T, H, W).
+
+        Returns:
+            List of feature maps from encoder stages.
+        """
+        if not self.training and self._inferer:
+            return self._inferer(inputs=tensor, network=self._forward_features)
+
+        return self._forward_features(tensor)
+
+    def forward_encoders(self, features: List[torch.Tensor]) -> torch.Tensor:
+        """Aggregates encoder features into a single feature vector.
+
+        Args:
+            features: List of feature maps from encoder stages.
+
+        Returns:
+            Aggregated feature vector (B, C').
+        """
+        batch_size = features[0].shape[0]
+        reduced_features = []
+        for patch_features in features[:4] + features[5:]:
+            hidden_features = self._pool_op(patch_features)
+            hidden_features_reduced = hidden_features.view(batch_size, -1)
+            reduced_features.append(hidden_features_reduced)
+        return torch.cat(reduced_features, dim=1)
+
+    def forward_head(self, features: List[torch.Tensor]) -> torch.Tensor:
+        """Casts last feature map into a single feature vector.
+
+        Args:
+            features: List of encoder feature maps.
+
+        Returns:
+            Aggregated feature vector (B, C').
+        """
+        last_feature_map = features[-1]
+        pooled_features = self._pool_op(last_feature_map)
+        return torch.flatten(pooled_features, 1)
+
+    def forward_embeddings(self, tensor: torch.Tensor) -> torch.Tensor:
+        """Computes the final aggregated feature vector.
+
+        Args:
+            tensor: Input tensor of shape (B, C, T, H, W).
+
+        Returns:
+            Aggregated feature vector of shape (B, C').
+        """
+        intermediates = self.forward_features(tensor)
+        return self.forward_encoders(intermediates)
+
+    def forward_intermediates(
+        self, tensor: torch.Tensor
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """Computes encoder features and their embeddings.
+
+        Args:
+            tensor: Input tensor of shape (B, C, T, H, W).
+
+        Returns:
+            Aggregated feature vector and list of intermediate features.
+        """
+        features = self.forward_features(tensor)
+        embeddings = self.forward_encoders(features)
+        return embeddings, features
+
+    def forward(self, tensor: torch.Tensor) -> torch.Tensor | List[torch.Tensor]:
+        """Forward pass through the encoder.
+
+        If `self._out_indices` is None, it returns the aggregated feature vector.
+        Otherwise, it returns the intermediate feature maps up to the specified index.
+
+        Args:
+            tensor: Input tensor of shape (B, C, T, H, W).
+
+        Returns:
+            Aggregated feature vector or intermediate features.
+        """
+        if self._out_indices is None:
+            return self.forward_embeddings(tensor)
+
+        intermediates = self.forward_features(tensor)
+        return intermediates[-1 * self._out_indices :]
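
A minimal usage sketch of the new encoder (an illustration, not part of the diff; the shapes are derived from the code above, assuming feature_size=48 and a single-channel 96x96x96 volume):

    import torch

    from eva.vision.models.networks.backbones.radiology.swin_unetr import SwinUNETREncoder

    encoder = SwinUNETREncoder(in_channels=1, feature_size=48, spatial_dims=3)
    volume = torch.randn(2, 1, 96, 96, 96)  # (B, C, T, H, W)

    # out_indices=None -> aggregated embedding: pooled features from five of
    # the six stages are concatenated, i.e. (1 + 1 + 2 + 4 + 16) * 48 = 1152.
    embedding = encoder(volume)
    print(embedding.shape)  # expected: torch.Size([2, 1152])

    # out_indices=N -> the last N intermediate feature maps instead.
    features = SwinUNETREncoder(out_indices=2)(volume)
    print([f.shape for f in features])  # expected: (2, 384, 6, 6, 6) and (2, 768, 3, 3, 3)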
eva/vision/models/networks/backbones/radiology/voco.py
ADDED
@@ -0,0 +1,75 @@
+"""VoCo Self-Supervised Encoders."""
+
+from typing_extensions import override
+
+from eva.core.models.wrappers import _utils
+from eva.vision.models.networks.backbones.radiology import swin_unetr
+from eva.vision.models.networks.backbones.registry import register_model
+
+
+class _VoCo(swin_unetr.SwinUNETREncoder):
+    """Base class for the VoCo self-supervised encoders."""
+
+    _checkpoint: str
+    """Path to the model state dict."""
+
+    _md5: str | None = None
+    """State dict MD5 validation code."""
+
+    def __init__(self, feature_size: int, out_indices: int | None = None) -> None:
+        """Initializes the model.
+
+        Args:
+            feature_size: Size of the last feature map of SwinUNETR.
+            out_indices: The number of feature maps from intermediate blocks
+                to be returned. If set to 1, only the last feature map is returned.
+        """
+        super().__init__(
+            in_channels=1,
+            feature_size=feature_size,
+            spatial_dims=3,
+            out_indices=out_indices,
+        )
+
+        self._load_checkpoint()
+
+    def _load_checkpoint(self) -> None:
+        """Loads the model checkpoint."""
+        state_dict = _utils.load_state_dict_from_url(self._checkpoint, md5=self._md5)
+        self.load_state_dict(state_dict)
+
+
+@register_model("radiology/voco_b")
+class VoCoB(_VoCo):
+    """VoCo Self-supervised pre-trained B model."""
+
+    _checkpoint = "https://huggingface.co/Luffy503/VoCo/resolve/main/VoCo_B_SSL_head.pt"
+    _md5 = "f80c4da2f81d700bdae3df188f2057eb"
+
+    @override
+    def __init__(self, out_indices: int | None = None) -> None:
+        super().__init__(feature_size=48, out_indices=out_indices)
+
+
+@register_model("radiology/voco_l")
+class VoCoL(_VoCo):
+    """VoCo Self-supervised pre-trained L model."""
+
+    _checkpoint = "https://huggingface.co/Luffy503/VoCo/resolve/main/VoCo_L_SSL_head.pt"
+    _md5 = "795095d1d43ef3808ec4c41798310136"
+
+    @override
+    def __init__(self, out_indices: int | None = None) -> None:
+        super().__init__(feature_size=96, out_indices=out_indices)
+
+
+@register_model("radiology/voco_h")
+class VoCoH(_VoCo):
+    """VoCo Self-supervised pre-trained H model."""
+
+    _checkpoint = "https://huggingface.co/Luffy503/VoCo/resolve/main/VoCo_H_SSL_head.pt"
+    _md5 = "76f95a474736b60bf5b8aad94643744d"
+
+    @override
+    def __init__(self, out_indices: int | None = None) -> None:
+        super().__init__(feature_size=192, out_indices=out_indices)
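
A short sketch (again an illustration, not from the diff) of constructing one of these backbones directly; note that __init__ downloads the checkpoint from Hugging Face and, when an MD5 is set, load_state_dict_from_url can validate it:

    from eva.vision.models.networks.backbones.radiology.voco import VoCoB

    # Downloads and loads VoCo_B_SSL_head.pt on construction; out_indices=1
    # makes forward return only the last feature map.
    backbone = VoCoB(out_indices=1)

The classes are also registered as "radiology/voco_b", "radiology/voco_l" and "radiology/voco_h" via register_model, so they should be addressable by name through eva's backbone registry.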
eva/vision/models/networks/decoders/segmentation/__init__.py
CHANGED
@@ -1,5 +1,6 @@
 """Segmentation decoder heads API."""
 
+from eva.vision.models.networks.decoders.segmentation.base import Decoder
 from eva.vision.models.networks.decoders.segmentation.decoder2d import Decoder2D
 from eva.vision.models.networks.decoders.segmentation.linear import LinearDecoder
 from eva.vision.models.networks.decoders.segmentation.semantic import (
@@ -7,13 +8,16 @@ from eva.vision.models.networks.decoders.segmentation.semantic import (
     ConvDecoderMS,
     ConvDecoderWithImage,
     SingleLinearDecoder,
+    SwinUNETRDecoder,
 )
 
 __all__ = [
+    "Decoder",
+    "Decoder2D",
     "ConvDecoder1x1",
     "ConvDecoderMS",
-    "SingleLinearDecoder",
     "ConvDecoderWithImage",
-    "Decoder2D",
     "LinearDecoder",
+    "SingleLinearDecoder",
+    "SwinUNETRDecoder",
 ]
eva/vision/models/networks/decoders/segmentation/linear.py
CHANGED
@@ -7,6 +7,7 @@ from torch import nn
 from torch.nn import functional
 
 from eva.vision.models.networks.decoders.segmentation import base
+from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs
 
 
 class LinearDecoder(base.Decoder):
@@ -104,22 +105,16 @@ class LinearDecoder(base.Decoder):
         """
         return functional.interpolate(logits, image_size, mode="bilinear")
 
-    def forward(
-        self,
-        features: List[torch.Tensor],
-        image_size: Tuple[int, int],
-    ) -> torch.Tensor:
+    def forward(self, decoder_inputs: DecoderInputs) -> torch.Tensor:
         """Maps the patch embeddings to a segmentation mask of the image size.
 
         Args:
-            features: List of multi-level image features of shape (batch_size,
-                hidden_size, n_patches_height, n_patches_width).
-            image_size: The target image size (height, width).
+            decoder_inputs: Inputs required by the decoder.
 
         Returns:
             Tensor containing scores for all of the classes with shape
             (batch_size, n_classes, image_height, image_width).
         """
-        patch_embeddings = self._forward_features(features)
+        patch_embeddings = self._forward_features(decoder_inputs.features)
         logits = self._forward_head(patch_embeddings)
-        return self._cls_seg(logits, image_size)
+        return self._cls_seg(logits, decoder_inputs.image_size)
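
The signature change above means callers now pass a single DecoderInputs object instead of separate features and image_size arguments. A hedged sketch of the new calling convention, assuming DecoderInputs accepts these two fields by keyword (its definition lives in segmentation/typings.py and is not shown in this diff) and that `decoder` is an already constructed LinearDecoder:

    import torch

    from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs

    # Hypothetical ViT-style patch features: (batch, hidden_size, h_patches, w_patches).
    features = [torch.randn(2, 384, 14, 14)]
    inputs = DecoderInputs(features=features, image_size=(224, 224))
    logits = decoder(inputs)  # (batch_size, n_classes, 224, 224)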
eva/vision/models/networks/decoders/segmentation/semantic/__init__.py
CHANGED
@@ -5,8 +5,15 @@ from eva.vision.models.networks.decoders.segmentation.semantic.common import (
     ConvDecoderMS,
     SingleLinearDecoder,
 )
+from eva.vision.models.networks.decoders.segmentation.semantic.swin_unetr import SwinUNETRDecoder
 from eva.vision.models.networks.decoders.segmentation.semantic.with_image import (
     ConvDecoderWithImage,
 )
 
-__all__ = [
+__all__ = [
+    "ConvDecoder1x1",
+    "ConvDecoderMS",
+    "ConvDecoderWithImage",
+    "SingleLinearDecoder",
+    "SwinUNETRDecoder",
+]
eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py
ADDED
@@ -0,0 +1,104 @@
+"""Decoder based on Swin UNETR."""
+
+from typing import List
+
+import torch
+from monai.networks.blocks import dynunet_block, unetr_block
+from torch import nn
+
+
+class SwinUNETRDecoder(nn.Module):
+    """Swin transformer decoder based on UNETR [0].
+
+    - [0] UNETR: Transformers for 3D Medical Image Segmentation
+        https://arxiv.org/pdf/2103.10504
+    """
+
+    def __init__(
+        self,
+        out_channels: int,
+        feature_size: int = 48,
+        spatial_dims: int = 3,
+    ) -> None:
+        """Builds the decoder.
+
+        Args:
+            out_channels: Number of output channels.
+            feature_size: Dimension of network feature size.
+            spatial_dims: Number of spatial dimensions.
+        """
+        super().__init__()
+
+        self.decoder5 = unetr_block.UnetrUpBlock(
+            spatial_dims=spatial_dims,
+            in_channels=16 * feature_size,
+            out_channels=8 * feature_size,
+            kernel_size=3,
+            upsample_kernel_size=2,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.decoder4 = unetr_block.UnetrUpBlock(
+            spatial_dims=spatial_dims,
+            in_channels=feature_size * 8,
+            out_channels=feature_size * 4,
+            kernel_size=3,
+            upsample_kernel_size=2,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.decoder3 = unetr_block.UnetrUpBlock(
+            spatial_dims=spatial_dims,
+            in_channels=feature_size * 4,
+            out_channels=feature_size * 2,
+            kernel_size=3,
+            upsample_kernel_size=2,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.decoder2 = unetr_block.UnetrUpBlock(
+            spatial_dims=spatial_dims,
+            in_channels=feature_size * 2,
+            out_channels=feature_size,
+            kernel_size=3,
+            upsample_kernel_size=2,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.decoder1 = unetr_block.UnetrUpBlock(
+            spatial_dims=spatial_dims,
+            in_channels=feature_size,
+            out_channels=feature_size,
+            kernel_size=3,
+            upsample_kernel_size=2,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.out = dynunet_block.UnetOutBlock(
+            spatial_dims=spatial_dims,
+            in_channels=feature_size,
+            out_channels=out_channels,
+        )
+
+    def _forward_features(self, features: List[torch.Tensor]) -> torch.Tensor:
+        """Forward function for multi-level feature maps to a single one."""
+        enc0, enc1, enc2, enc3, hid3, dec4 = features
+        dec3 = self.decoder5(dec4, hid3)
+        dec2 = self.decoder4(dec3, enc3)
+        dec1 = self.decoder3(dec2, enc2)
+        dec0 = self.decoder2(dec1, enc1)
+        out = self.decoder1(dec0, enc0)
+        return self.out(out)
+
+    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
+        """Maps the patch embeddings to a segmentation mask.
+
+        Args:
+            features: List of multi-level intermediate features from
+                :class:`SwinUNETREncoder`.
+
+        Returns:
+            Tensor containing scores for all of the classes with shape
+            (batch_size, n_classes, image_height, image_width).
+        """
+        return self._forward_features(features)
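
How the pieces fit together: the decoder's _forward_features unpacks exactly the six feature maps that SwinUNETREncoder._forward_features produces, so the two can be chained directly. A sketch under that assumption (the 14 output classes are an arbitrary example):

    import torch

    from eva.vision.models.networks.backbones.radiology.swin_unetr import SwinUNETREncoder
    from eva.vision.models.networks.decoders.segmentation.semantic.swin_unetr import SwinUNETRDecoder

    encoder = SwinUNETREncoder(feature_size=48)
    decoder = SwinUNETRDecoder(out_channels=14, feature_size=48)

    volume = torch.randn(1, 1, 96, 96, 96)
    features = encoder.forward_features(volume)  # six multi-scale feature maps
    logits = decoder(features)                   # expected: (1, 14, 96, 96, 96)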
eva/vision/utils/io/__init__.py
CHANGED
@@ -5,6 +5,7 @@ from eva.vision.utils.io.mat import read_mat, save_mat
 from eva.vision.utils.io.nifti import (
     fetch_nifti_axis_direction_code,
     fetch_nifti_shape,
+    nifti_to_array,
     read_nifti,
     save_array_as_nifti,
 )
@@ -16,6 +17,7 @@ __all__ = [
     "read_image_as_tensor",
     "fetch_nifti_shape",
     "fetch_nifti_axis_direction_code",
+    "nifti_to_array",
     "read_nifti",
     "save_array_as_nifti",
     "read_csv",
eva/vision/utils/io/nifti.py
CHANGED
@@ -1,3 +1,4 @@
+# type: ignore
 """NIfTI I/O related functions."""
 
 from typing import Any, Tuple
@@ -7,36 +8,63 @@
 import numpy.typing as npt
 from nibabel import orientations
 
+from eva.core.utils.suppress_logs import SuppressLogs
 from eva.vision.utils.io import _utils
 
 
 def read_nifti(
-    path: str,
-
+    path: str,
+    slice_index: int | None = None,
+    *,
+    orientation: str | None = None,
+    orientation_reference: str | None = None,
+) -> nib.nifti1.Nifti1Image:
     """Reads and loads a NIfTI image from a file path.
 
     Args:
         path: The path to the NIfTI file.
         slice_index: Whether to read only a slice from the file.
+        orientation: The orientation code to reorient the nifti image.
+        orientation_reference: Path to a NIfTI file which
+            will be used as a reference for the orientation
+            transform in case the file missing the pixdim array
+            in the NIfTI header.
         use_storage_dtype: Whether to cast the raw image
            array to the inferred type.
 
     Returns:
-        The image
+        The NIfTI image class instance.
 
     Raises:
         FileExistsError: If the path does not exist or it is unreachable.
         ValueError: If the input channel is invalid for the image.
    """
    _utils.check_file(path)
-    image_data
+    image_data = _load_nifti_silently(path)
     if slice_index is not None:
         image_data = image_data.slicer[:, :, slice_index : slice_index + 1]
+    if orientation:
+        image_data = _reorient(
+            image_data, orientation=orientation, reference_file=orientation_reference
+        )
 
-
-
-
+    return image_data
+
+
+def nifti_to_array(nii: nib.Nifti1Image, use_storage_dtype: bool = True) -> npt.NDArray[Any]:
+    """Converts a NIfTI image to a numpy array.
+
+    Args:
+        nii: The input NIfTI image.
+        use_storage_dtype: Whether to cast the raw image
+            array to the inferred type.
 
+    Returns:
+        The image as a numpy array (height, width, channels).
+    """
+    image_array = nii.get_fdata()
+    if use_storage_dtype:
+        image_array = image_array.astype(nii.get_data_dtype())
     return image_array
 
 
@@ -53,7 +81,7 @@ def save_array_as_nifti(
         filename: The name to save the image like.
         dtype: The data type to save the image.
     """
-    nifti_image = nib.Nifti1Image(array, affine=np.eye(4), dtype=dtype)
+    nifti_image = nib.Nifti1Image(array, affine=np.eye(4), dtype=dtype)
     nifti_image.to_filename(filename)
 
 
@@ -71,8 +99,22 @@ def fetch_nifti_shape(path: str) -> Tuple[int]:
         ValueError: If the input channel is invalid for the image.
     """
     _utils.check_file(path)
-
-    return
+    nii = _load_nifti_silently(path)
+    return nii.header.get_data_shape()  # type: ignore
+
+
+def fetch_nifti_orientation(path: str) -> npt.NDArray[Any]:
+    """Fetches the NIfTI image orientation.
+
+    Args:
+        path: The path to the NIfTI file.
+
+    Returns:
+        The array orientation.
+    """
+    _utils.check_file(path)
+    nii = _load_nifti_silently(path)
+    return nib.io_orientation(nii.affine)
 
 
 def fetch_nifti_axis_direction_code(path: str) -> str:
@@ -85,5 +127,43 @@ def fetch_nifti_axis_direction_code(path: str) -> str:
         The axis direction codes as string (e.g. "LAS").
     """
     _utils.check_file(path)
-    image_data: nib.Nifti1Image = nib.load(path)
+    image_data: nib.Nifti1Image = nib.load(path)
     return "".join(orientations.aff2axcodes(image_data.affine))
+
+
+def _load_nifti_silently(path: str) -> nib.Nifti1Image:
+    """Reads a NIfTI image in silent mode."""
+    with SuppressLogs():
+        return nib.load(path)
+    raise ValueError(f"Failed to load NIfTI file: {path}")
+
+
+def _reorient(
+    nii: nib.Nifti1Image,
+    /,
+    orientation: str | tuple[str, str, str] = "RAS",
+    reference_file: str | None = None,
+) -> nib.Nifti1Image:
+    """Reorients a NIfTI image to a specified orientation.
+
+    Args:
+        nii: The input NIfTI image.
+        orientation: Desired orientation expressed as a
+            three-character string (e.g., "RAS") or a tuple
+            (e.g., ("R", "A", "S")).
+        reference_file: Path to a reference NIfTI file whose
+            orientation should be used if the input image lacks
+            a valid affine transformation.
+
+    Returns:
+        The reoriented NIfTI image.
+    """
+    affine_matrix, _ = nii.get_qform(coded=True)
+    orig_ornt = (
+        fetch_nifti_orientation(reference_file)
+        if reference_file and affine_matrix is None
+        else nib.io_orientation(nii.affine)
+    )
+    targ_ornt = orientations.axcodes2ornt(orientation)
+    transform = orientations.ornt_transform(orig_ornt, targ_ornt)
+    return nii.as_reoriented(transform)
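
A minimal sketch of the resulting split in the I/O API: read_nifti now returns the Nifti1Image itself (optionally sliced and reoriented), and the new nifti_to_array handles the numpy conversion that read_nifti used to perform. The file path is a placeholder:

    from eva.vision.utils import io

    nii = io.read_nifti("scan.nii.gz", orientation="RAS")
    array = io.nifti_to_array(nii, use_storage_dtype=True)
    print(array.shape, array.dtype)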
{kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: kaiko-eva
-Version: 0.2.0
+Version: 0.2.1
 Summary: Evaluation Framework for oncology foundation models.
 Keywords: machine-learning,evaluation-framework,oncology,foundation-models
 Author-Email: Ioannis Gatopoulos <ioannis@kaiko.ai>, Nicolas Känzig <nicolas@kaiko.ai>, Roman Moser <roman@kaiko.ai>
@@ -241,6 +241,7 @@ Requires-Dist: scikit-image>=0.24.0; extra == "vision"
 Requires-Dist: imagesize>=1.4.1; extra == "vision"
 Requires-Dist: scipy>=1.14.0; extra == "vision"
 Requires-Dist: monai>=1.3.2; extra == "vision"
+Requires-Dist: einops>=0.8.1; extra == "vision"
 Provides-Extra: all
 Requires-Dist: h5py>=3.10.0; extra == "all"
 Requires-Dist: nibabel>=4.0.1; extra == "all"
@@ -253,6 +254,7 @@ Requires-Dist: scikit-image>=0.24.0; extra == "all"
 Requires-Dist: imagesize>=1.4.1; extra == "all"
 Requires-Dist: scipy>=1.14.0; extra == "all"
 Requires-Dist: monai>=1.3.2; extra == "all"
+Requires-Dist: einops>=0.8.1; extra == "all"
 Description-Content-Type: text/markdown
 
 <div align="center">