kaiko-eva 0.1.8__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eva/core/data/datasets/base.py +7 -2
- eva/core/data/datasets/classification/embeddings.py +2 -2
- eva/core/data/datasets/classification/multi_embeddings.py +2 -2
- eva/core/data/datasets/embeddings.py +4 -4
- eva/core/data/samplers/classification/balanced.py +19 -18
- eva/core/loggers/utils/wandb.py +33 -0
- eva/core/models/modules/head.py +5 -3
- eva/core/models/modules/typings.py +2 -2
- eva/core/models/transforms/__init__.py +2 -1
- eva/core/models/transforms/as_discrete.py +57 -0
- eva/core/models/wrappers/_utils.py +121 -1
- eva/core/trainers/functional.py +8 -5
- eva/core/trainers/trainer.py +32 -17
- eva/core/utils/suppress_logs.py +28 -0
- eva/vision/data/__init__.py +2 -2
- eva/vision/data/dataloaders/__init__.py +5 -0
- eva/vision/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/vision/data/dataloaders/collate_fn/collection.py +22 -0
- eva/vision/data/datasets/__init__.py +10 -2
- eva/vision/data/datasets/classification/__init__.py +9 -0
- eva/vision/data/datasets/classification/bach.py +3 -4
- eva/vision/data/datasets/classification/bracs.py +111 -0
- eva/vision/data/datasets/classification/breakhis.py +209 -0
- eva/vision/data/datasets/classification/camelyon16.py +4 -5
- eva/vision/data/datasets/classification/crc.py +3 -4
- eva/vision/data/datasets/classification/gleason_arvaniti.py +171 -0
- eva/vision/data/datasets/classification/mhist.py +3 -4
- eva/vision/data/datasets/classification/panda.py +4 -5
- eva/vision/data/datasets/classification/patch_camelyon.py +3 -4
- eva/vision/data/datasets/classification/unitopatho.py +158 -0
- eva/vision/data/datasets/classification/wsi.py +6 -5
- eva/vision/data/datasets/segmentation/__init__.py +2 -2
- eva/vision/data/datasets/segmentation/_utils.py +47 -0
- eva/vision/data/datasets/segmentation/bcss.py +7 -8
- eva/vision/data/datasets/segmentation/btcv.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +6 -7
- eva/vision/data/datasets/segmentation/embeddings.py +2 -2
- eva/vision/data/datasets/segmentation/lits.py +9 -8
- eva/vision/data/datasets/segmentation/lits_balanced.py +2 -1
- eva/vision/data/datasets/segmentation/monusac.py +4 -5
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +12 -10
- eva/vision/data/datasets/vision.py +95 -4
- eva/vision/data/datasets/wsi.py +5 -5
- eva/vision/data/transforms/__init__.py +22 -3
- eva/vision/data/transforms/common/__init__.py +1 -2
- eva/vision/data/transforms/croppad/__init__.py +11 -0
- eva/vision/data/transforms/croppad/crop_foreground.py +110 -0
- eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +109 -0
- eva/vision/data/transforms/croppad/spatial_pad.py +67 -0
- eva/vision/data/transforms/intensity/__init__.py +11 -0
- eva/vision/data/transforms/intensity/rand_scale_intensity.py +59 -0
- eva/vision/data/transforms/intensity/rand_shift_intensity.py +55 -0
- eva/vision/data/transforms/intensity/scale_intensity_ranged.py +56 -0
- eva/vision/data/transforms/spatial/__init__.py +7 -0
- eva/vision/data/transforms/spatial/flip.py +72 -0
- eva/vision/data/transforms/spatial/rotate.py +53 -0
- eva/vision/data/transforms/spatial/spacing.py +69 -0
- eva/vision/data/transforms/utility/__init__.py +5 -0
- eva/vision/data/transforms/utility/ensure_channel_first.py +51 -0
- eva/vision/data/tv_tensors/__init__.py +5 -0
- eva/vision/data/tv_tensors/volume.py +61 -0
- eva/vision/metrics/segmentation/monai_dice.py +9 -2
- eva/vision/models/modules/semantic_segmentation.py +28 -20
- eva/vision/models/networks/backbones/__init__.py +9 -2
- eva/vision/models/networks/backbones/pathology/__init__.py +11 -2
- eva/vision/models/networks/backbones/pathology/bioptimus.py +47 -1
- eva/vision/models/networks/backbones/pathology/hkust.py +69 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +18 -0
- eva/vision/models/networks/backbones/pathology/mahmood.py +46 -19
- eva/vision/models/networks/backbones/radiology/__init__.py +11 -0
- eva/vision/models/networks/backbones/radiology/swin_unetr.py +231 -0
- eva/vision/models/networks/backbones/radiology/voco.py +75 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +6 -2
- eva/vision/models/networks/decoders/segmentation/linear.py +5 -10
- eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +8 -1
- eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py +104 -0
- eva/vision/utils/io/__init__.py +2 -0
- eva/vision/utils/io/nifti.py +91 -11
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/METADATA +3 -1
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/RECORD +83 -62
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/WHEEL +1 -1
- eva/vision/data/datasets/classification/base.py +0 -96
- eva/vision/data/datasets/segmentation/base.py +0 -96
- eva/vision/data/transforms/common/resize_and_clamp.py +0 -51
- eva/vision/data/transforms/normalization/__init__.py +0 -6
- eva/vision/data/transforms/normalization/clamp.py +0 -43
- eva/vision/data/transforms/normalization/functional/__init__.py +0 -5
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +0 -28
- eva/vision/data/transforms/normalization/rescale_intensity.py +0 -53
- eva/vision/metrics/segmentation/BUILD +0 -1
- eva/vision/models/networks/backbones/torchhub/__init__.py +0 -5
- eva/vision/models/networks/backbones/torchhub/backbones.py +0 -61
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/licenses/LICENSE +0 -0
eva/vision/data/transforms/intensity/rand_shift_intensity.py
@@ -0,0 +1,55 @@
+"""Intensity shifting transform."""
+
+import functools
+from typing import Any, Dict
+
+from monai.transforms.intensity import array as monai_intensity_transforms
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from typing_extensions import override
+
+from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+class RandShiftIntensity(v2.Transform):
+    """Randomly shift intensity with randomly picked offset."""
+
+    def __init__(
+        self,
+        offsets: tuple[float, float] | float,
+        safe: bool = False,
+        prob: float = 0.1,
+        channel_wise: bool = False,
+    ) -> None:
+        """Initializes the transform.
+
+        Args:
+            offsets: Offset range to randomly shift.
+                if single number, offset value is picked from (-offsets, offsets).
+            safe: If `True`, then do safe dtype convert when intensity overflow.
+                E.g., `[256, -12]` -> `[array(0), array(244)]`. If `True`, then
+                `[256, -12]` -> `[array(255), array(0)]`.
+            prob: Probability of shift.
+            channel_wise: If True, shift intensity on each channel separately.
+                For each channel, a random offset will be chosen. Please ensure
+                that the first dimension represents the channel of the image if True.
+        """
+        super().__init__()
+
+        self._rand_swift_intensity = monai_intensity_transforms.RandShiftIntensity(
+            offsets=offsets,
+            safe=safe,
+            prob=prob,
+            channel_wise=channel_wise,
+        )
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(tv_tensors.Image)
+    @_transform.register(eva_tv_tensors.Volume)
+    def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
+        inpt_scaled = self._rand_swift_intensity(inpt)
+        return tv_tensors.wrap(inpt_scaled, like=inpt)
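
As a quick orientation, a minimal usage sketch of this new wrapper (the shape and offset value below are illustrative, not eva defaults):

import torch
from torchvision import tv_tensors

from eva.vision.data.transforms.intensity.rand_shift_intensity import RandShiftIntensity

# Wrap a dummy image as a torchvision TVTensor so the registered singledispatch branch applies.
image = tv_tensors.Image(torch.rand(3, 224, 224))

# prob=1.0 guarantees the shift is applied; offsets=0.1 draws the offset from (-0.1, 0.1).
transform = RandShiftIntensity(offsets=0.1, prob=1.0)
shifted = transform(image)  # returned as a tv_tensors.Image via tv_tensors.wrap
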
eva/vision/data/transforms/intensity/scale_intensity_ranged.py
@@ -0,0 +1,56 @@
+"""Intensity scaling transform."""
+
+import functools
+from typing import Any, Dict, Tuple
+
+from monai.transforms.intensity import array as monai_intensity_transforms
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from typing_extensions import override
+
+from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+class ScaleIntensityRange(v2.Transform):
+    """Intensity scaling transform.
+
+    Scaling from [a_min, a_max] to [b_min, b_max] with clip option.
+
+    When `b_min` or `b_max` are `None`, `scaled_array * (b_max - b_min) + b_min`
+    will be skipped. If `clip=True`, when `b_min`/`b_max` is None, the clipping
+    is not performed on the corresponding edge.
+    """
+
+    def __init__(
+        self,
+        input_range: Tuple[float, float],
+        output_range: Tuple[float, float] | None = None,
+        clip: bool = True,
+    ) -> None:
+        """Initializes the transform.
+
+        Args:
+            input_range: Intensity original range min and max.
+            output_range: Intensity target range min and max.
+            clip: Whether to perform clip after scaling.
+        """
+        super().__init__()
+
+        self._scale_intensity_range = monai_intensity_transforms.ScaleIntensityRange(
+            a_min=input_range[0],
+            a_max=input_range[1],
+            b_min=output_range[0] if output_range else None,
+            b_max=output_range[1] if output_range else None,
+            clip=clip,
+        )
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(tv_tensors.Image)
+    @_transform.register(eva_tv_tensors.Volume)
+    def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
+        inpt_scaled = self._scale_intensity_range(inpt)
+        return tv_tensors.wrap(inpt_scaled, like=inpt)
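
A hedged sketch of the new range scaling, here used for CT-style intensity windowing into [0, 1] (the Hounsfield window is an illustrative choice, not a project default):

import torch
from torchvision import tv_tensors

from eva.vision.data.transforms.intensity.scale_intensity_ranged import ScaleIntensityRange

# Map intensities from [-175, 250] to [0, 1] and clip values that fall outside the input range.
windowing = ScaleIntensityRange(input_range=(-175.0, 250.0), output_range=(0.0, 1.0), clip=True)
scan = tv_tensors.Image(torch.randint(-1000, 1000, (1, 96, 96)).float())
scaled = windowing(scan)  # still a tv_tensors.Image, with values in [0, 1]
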
eva/vision/data/transforms/spatial/__init__.py
@@ -0,0 +1,7 @@
+"""Transforms for spatial operations."""
+
+from eva.vision.data.transforms.spatial.flip import RandFlip
+from eva.vision.data.transforms.spatial.rotate import RandRotate90
+from eva.vision.data.transforms.spatial.spacing import Spacing
+
+__all__ = ["Spacing", "RandFlip", "RandRotate90"]
eva/vision/data/transforms/spatial/flip.py
@@ -0,0 +1,72 @@
+"""Flip transforms."""
+
+import functools
+from typing import Any, Dict, List, Sequence
+
+import torch
+from monai.transforms.spatial import array as monai_spatial_transforms
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from typing_extensions import override
+
+from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+class RandFlip(v2.Transform):
+    """Randomly flips the image along axes."""
+
+    def __init__(
+        self,
+        prob: float = 0.1,
+        spatial_axes: Sequence[int] | int | None = None,
+        apply_per_axis: bool = True,
+    ) -> None:
+        """Initializes the transform.
+
+        Args:
+            prob: Probability of flipping.
+            spatial_axes: Spatial axes along which to flip over. Default is None.
+            apply_per_axis: If True, will apply a random flip transform to each
+                axis individually (if spatial_axes is a sequence of multiple axis).
+                If False, will apply a single random flip transform applied to all axes.
+        """
+        super().__init__()
+
+        if apply_per_axis:
+            if not isinstance(spatial_axes, (list, tuple)):
+                raise ValueError(
+                    "`spatial_axis` is expected to be sequence `apply_per_axis` "
+                    f"is enabled, got {type(spatial_axes)}"
+                )
+            self._flips = [
+                monai_spatial_transforms.RandFlip(prob=prob, spatial_axis=axis)
+                for axis in spatial_axes
+            ]
+        else:
+            self._flips = [monai_spatial_transforms.RandFlip(prob=prob, spatial_axis=spatial_axes)]
+
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        for flip in self._flips:
+            flip.randomize(None)
+        return {}
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(tv_tensors.Image)
+    @_transform.register(eva_tv_tensors.Volume)
+    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        inpt_flipped = self._apply_flips(inpt)
+        return tv_tensors.wrap(inpt_flipped, like=inpt)
+
+    @_transform.register(tv_tensors.Mask)
+    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        inpt_flipped = torch.tensor(self._apply_flips(inpt), dtype=torch.long)
+        return tv_tensors.wrap(inpt_flipped, like=inpt)
+
+    def _apply_flips(self, inpt: Any) -> Any:
+        for flip in self._flips:
+            inpt = flip(img=inpt, randomize=False)
+        return inpt
eva/vision/data/transforms/spatial/rotate.py
@@ -0,0 +1,53 @@
+"""Rotation transforms."""
+
+import functools
+from typing import Any, Dict, List
+
+from monai.transforms.spatial import array as monai_spatial_transforms
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from typing_extensions import override
+
+from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+class RandRotate90(v2.Transform):
+    """Rotate input tensors by 90 degrees."""
+
+    def __init__(
+        self,
+        prob: float = 0.1,
+        max_k: int = 3,
+        spatial_axes: tuple[int, int] = (1, 2),
+    ) -> None:
+        """Initializes the transform.
+
+        Args:
+            prob: probability of rotating.
+                (Default 0.1, with 10% probability it returns a rotated array)
+            max_k: number of rotations will be sampled from `np.random.randint(max_k) + 1`.
+            spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes.
+                Default: (1, 2), so for [C, T, H, W] will rotate along (H, W) plane (MONAI ignores
+                the first C dimension).
+        """
+        super().__init__()
+
+        self._rotate = monai_spatial_transforms.RandRotate90(
+            prob=prob, max_k=max_k, spatial_axes=spatial_axes
+        )
+
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        self._rotate.randomize()
+        return {}
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(tv_tensors.Image)
+    @_transform.register(eva_tv_tensors.Volume)
+    @_transform.register(tv_tensors.Mask)
+    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        inpt_rotated = self._rotate(img=inpt, randomize=False)
+        return tv_tensors.wrap(inpt_rotated, like=inpt)
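
The two spatial augmentations above compose with torchvision's v2 API; a minimal sketch, with illustrative axis choices and shapes:

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

from eva.vision.data.transforms.spatial.flip import RandFlip
from eva.vision.data.transforms.spatial.rotate import RandRotate90

augmentations = v2.Compose([
    RandFlip(prob=0.5, spatial_axes=(0, 1, 2)),           # one independent flip per spatial axis
    RandRotate90(prob=0.5, max_k=3, spatial_axes=(1, 2)),  # rotate in the (H, W) plane
])

volume = tv_tensors.Image(torch.rand(1, 16, 96, 96))            # [C, T, H, W]
mask = tv_tensors.Mask(torch.randint(0, 4, (1, 16, 96, 96)))    # same spatial layout
augmented_volume, augmented_mask = augmentations(volume, mask)  # the mask stays integer-typed
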
eva/vision/data/transforms/spatial/spacing.py
@@ -0,0 +1,69 @@
+"""Spacing resample transform."""
+
+import functools
+from typing import Any, Dict, List, Sequence
+
+import numpy as np
+import torch
+from monai.data import meta_tensor
+from monai.transforms.spatial import array as monai_spatial_transforms
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from typing_extensions import override
+
+from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+class Spacing(v2.Transform):
+    """Resample input image into the specified `pixdim`.
+
+    - Expects tensors of shape `[C, T, H, W]`.
+    """
+
+    def __init__(
+        self,
+        pixdim: Sequence[float] | float | np.ndarray,
+    ) -> None:
+        """Initializes the transform.
+
+        Args:
+            pixdim: output voxel spacing. if providing a single number,
+                will use it for the first dimension. Items of the pixdim
+                sequence map to the spatial dimensions of input image, if
+                length of pixdim sequence is longer than image spatial
+                dimensions, will ignore the longer part, if shorter, will
+                pad with the last value. For example, for 3D image if pixdim
+                is [1.0, 2.0] it will be padded to [1.0, 2.0, 2.0] if the
+                components of the `pixdim` are non-positive values, the
+                transform will use the corresponding components of the original
+                pixdim, which is computed from the `affine` matrix of input image.
+        """
+        super().__init__()
+
+        self._spacing = monai_spatial_transforms.Spacing(pixdim=pixdim, recompute_affine=True)
+        self._affine = None
+
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        self._affine = next(
+            inpt.affine for inpt in flat_inputs if isinstance(inpt, eva_tv_tensors.Volume)
+        )
+        return {}
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(eva_tv_tensors.Volume)
+    def _(self, inpt: eva_tv_tensors.Volume, params: Dict[str, Any]) -> Any:
+        inpt_spacing = self._spacing(inpt.to_meta_tensor(), mode="bilinear")
+        if not isinstance(inpt_spacing, meta_tensor.MetaTensor):
+            raise ValueError(f"Expected MetaTensor, got {type(inpt_spacing)}")
+        return eva_tv_tensors.Volume.from_meta_tensor(inpt_spacing)
+
+    @_transform.register(tv_tensors.Mask)
+    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        inpt_spacing = self._spacing(
+            meta_tensor.MetaTensor(inpt, affine=self._affine), mode="nearest"
+        )
+        return tv_tensors.wrap(inpt_spacing.to(dtype=torch.long), like=inpt)
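
A sketch of how the resampling can be driven by a `Volume` that carries its affine (the 2 mm source spacing and 1 mm target are illustrative):

import torch
from torchvision import tv_tensors

from eva.vision.data import tv_tensors as eva_tv_tensors
from eva.vision.data.transforms.spatial.spacing import Spacing

# A dummy volume with a 2 mm isotropic affine; Spacing resamples it to 1 mm.
affine = torch.diag(torch.tensor([2.0, 2.0, 2.0, 1.0]))
volume = eva_tv_tensors.Volume(torch.rand(1, 64, 64, 64), affine=affine)  # [C, T, H, W]
mask = tv_tensors.Mask(torch.randint(0, 3, (1, 64, 64, 64)))

resample = Spacing(pixdim=(1.0, 1.0, 1.0))
# The volume is resampled with bilinear interpolation, the mask with nearest neighbour,
# reusing the affine picked up from the Volume in `_get_params`.
volume_resampled, mask_resampled = resample(volume, mask)
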
eva/vision/data/transforms/utility/ensure_channel_first.py
@@ -0,0 +1,51 @@
+"""Adjust or add the channel dimension of input data to ensure `channel_first` shape."""
+
+import functools
+from typing import Any, Dict
+
+from monai.transforms.utility import array as monai_utility_transforms
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from typing_extensions import override
+
+from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+class EnsureChannelFirst(v2.Transform):
+    """Adjust or add the channel dimension of input data to ensure `channel_first` shape."""
+
+    def __init__(
+        self,
+        strict_check: bool = True,
+        channel_dim: None | str | int = None,
+    ) -> None:
+        """Initializes the transform.
+
+        Args:
+            strict_check: whether to raise an error when the meta information is insufficient.
+            channel_dim: This argument can be used to specify the original channel dimension
+                (integer) of the input array.
+                It overrides the `original_channel_dim` from provided MetaTensor input.
+                If the input array doesn't have a channel dim, this value should be
+                ``'no_channel'``.
+                If this is set to `None`, this class relies on `img` or `meta_dict` to provide
+                the channel dimension.
+        """
+        super().__init__()
+
+        self._ensure_channel_first = monai_utility_transforms.EnsureChannelFirst(
+            strict_check=strict_check,
+            channel_dim=channel_dim,
+        )
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(tv_tensors.Image)
+    @_transform.register(eva_tv_tensors.Volume)
+    @_transform.register(tv_tensors.Mask)
+    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        inpt_channel_first = self._ensure_channel_first(inpt)
+        return tv_tensors.wrap(inpt_channel_first, like=inpt)
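
A brief sketch of the wrapper on a channels-last image; `channel_dim=-1` is MONAI's way of declaring which axis currently holds the channels (values are illustrative):

import torch
from torchvision import tv_tensors

from eva.vision.data.transforms.utility.ensure_channel_first import EnsureChannelFirst

image = tv_tensors.Image(torch.rand(96, 96, 3))      # [H, W, C]
to_channel_first = EnsureChannelFirst(channel_dim=-1)
image_chw = to_channel_first(image)                  # reordered to [C, H, W], i.e. [3, 96, 96]
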
eva/vision/data/tv_tensors/volume.py
@@ -0,0 +1,61 @@
+"""Custom `tv_tensors` type for 3D Volumes."""
+
+from typing import Any, Dict, Optional, Union
+
+import torch
+from monai.data import meta_tensor
+from torchvision import tv_tensors
+from typing_extensions import override
+
+
+class Volume(tv_tensors.Video):
+    """:class:`torchvision.TVTensor` subclass for 3D volumes.
+
+    - Adds optional metadata and affine matrix to the tensor.
+    - Expects tensors to be of shape `[..., T, C, H, W]`.
+
+    Args:
+        data: Any data that can be turned into a tensor with :func:`torch.as_tensor`.
+        affine: Affine matrix of the volume. Expected to be of shape `[4, 4]`, and
+            columns to correspond to [T, H, W, (translation)] dimensions. Note that
+            `nibabel` by default uses [H, W, T, (translation)] order for affine matrices.
+        metadata: Metadata associated with the volume.
+        dtype: Desired data type. If omitted, will be inferred from `data`.
+        device: Desired device.
+        requires_grad: Whether autograd should record operations.
+    """
+
+    @override
+    def __new__(
+        cls,
+        data: Any,
+        affine: torch.Tensor | None = None,
+        metadata: Dict[str, Any] | None = None,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[Union[torch.device, str, int]] = None,
+        requires_grad: Optional[bool] = None,
+    ) -> "Volume":
+        cls.affine = affine
+        cls.metadata = metadata
+
+        return super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)  # type: ignore
+
+    @classmethod
+    def from_meta_tensor(cls, meta_tensor: meta_tensor.MetaTensor) -> "Volume":
+        """Creates an instance from a :class:`monai.data.meta_tensor.MetaTensor`."""
+        return cls(
+            meta_tensor.data,
+            affine=meta_tensor.affine,
+            metadata=meta_tensor.meta,
+            dtype=meta_tensor.dtype,
+            device=meta_tensor.device,
+            requires_grad=meta_tensor.requires_grad,
+        )  # type: ignore
+
+    def to_meta_tensor(self) -> meta_tensor.MetaTensor:
+        """Converts the volume to a :class:`monai.data.meta_tensor.MetaTensor`."""
+        return meta_tensor.MetaTensor(self, affine=self.affine, meta=self.metadata)
+
+    def __repr__(self, *, tensor_contents: Any = None) -> str:
+        """Returns the string representation of the object."""
+        return self._make_repr()
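
A small sketch of the round trip this class enables between torchvision-style tensors and MONAI `MetaTensor`s (the shape and metadata are illustrative):

import torch

from eva.vision.data import tv_tensors as eva_tv_tensors

affine = torch.eye(4)
volume = eva_tv_tensors.Volume(
    torch.rand(1, 32, 64, 64), affine=affine, metadata={"case_id": "demo"}
)

# Hand the volume to MONAI as a MetaTensor and convert the transformed result back.
as_meta = volume.to_meta_tensor()
restored = eva_tv_tensors.Volume.from_meta_tensor(as_meta)
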
eva/vision/metrics/segmentation/monai_dice.py
@@ -1,5 +1,7 @@
 """Wrapper for dice score metric from MONAI."""
 
+from typing import Literal
+
 from monai.metrics.meandice import DiceMetric
 from typing_extensions import override
 
@@ -14,6 +16,7 @@ class MonaiDiceScore(wrappers.MonaiMetricWrapper):
         self,
         num_classes: int,
         include_background: bool = True,
+        input_format: Literal["one-hot", "index"] = "index",
         reduction: str = "mean",
         ignore_index: int | None = None,
         **kwargs,
@@ -24,6 +27,8 @@ class MonaiDiceScore(wrappers.MonaiMetricWrapper):
             num_classes: The number of classes in the dataset.
             include_background: Whether to include the background class in the computation.
             reduction: The method to reduce the dice score. Options are `"mean"`, `"sum"`, `"none"`.
+            input_format: Choose between "one-hot" for one-hot encoded tensors or "index"
+                for index tensors.
             ignore_index: Integer specifying a target class to ignore. If given, this class
                 index does not contribute to the returned score.
             kwargs: Additional keyword arguments for instantiating monai's `DiceMetric` class.
@@ -40,11 +45,13 @@ class MonaiDiceScore(wrappers.MonaiMetricWrapper):
         self.reduction = reduction
         self.orig_num_classes = num_classes
         self.ignore_index = ignore_index
+        self.input_format = input_format
 
     @override
     def update(self, preds, target):
-        preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes)
-        target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes)
+        if self.input_format == "index":
+            preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes)
+            target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes)
         if self.ignore_index is not None:
             preds, target = _utils.apply_ignore_index(preds, target, self.ignore_index)
         return super().update(preds, target)
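
For illustration, the new `input_format` flag decides whether `update` receives index maps (the default, one-hot encoded internally) or tensors that are already one-hot; a hedged sketch with illustrative shapes:

import torch

from eva.vision.metrics.segmentation.monai_dice import MonaiDiceScore

dice = MonaiDiceScore(num_classes=4, input_format="index")  # index maps are one-hot encoded internally
preds = torch.randint(0, 4, (2, 32, 32))
target = torch.randint(0, 4, (2, 32, 32))
dice.update(preds, target)
score = dice.compute()
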
eva/vision/models/modules/semantic_segmentation.py
@@ -1,10 +1,11 @@
 """"Neural Network Semantic Segmentation Module."""
 
-from typing import Any, Callable, Dict, Iterable, List, Tuple
+from typing import Any, Callable, Dict, Iterable, List
 
 import torch
 from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
 from lightning.pytorch.utilities.types import STEP_OUTPUT
+from monai.inferers.inferer import Inferer
 from torch import nn, optim
 from torch.optim import lr_scheduler
 from typing_extensions import override
@@ -15,6 +16,7 @@ from eva.core.models.modules.typings import INPUT_BATCH, INPUT_TENSOR_BATCH
 from eva.core.models.modules.utils import batch_postprocess, grad, submodule_state_dict
 from eva.core.utils import parser
 from eva.vision.models.networks import decoders
+from eva.vision.models.networks.decoders import segmentation
 from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs
 
 
@@ -23,10 +25,11 @@ class SemanticSegmentationModule(module.ModelModule):
 
     def __init__(
         self,
-        decoder: decoders.Decoder,
+        decoder: decoders.Decoder | nn.Module,
         criterion: Callable[..., torch.Tensor],
         encoder: Dict[str, Any] | Callable[[torch.Tensor], List[torch.Tensor]] | None = None,
         lr_multiplier_encoder: float = 0.0,
+        inferer: Inferer | None = None,
         optimizer: OptimizerCallable = optim.AdamW,
         lr_scheduler: LRSchedulerCallable = lr_scheduler.ConstantLR,
         metrics: metrics_lib.MetricsSchema | None = None,
@@ -44,6 +47,8 @@ class SemanticSegmentationModule(module.ModelModule):
                 during the `configure_model` step.
             lr_multiplier_encoder: The learning rate multiplier for the
                 encoder parameters. If `0`, it will freeze the encoder.
+            inferer: An optional MONAI `Inferer` for inference
+                postprocess during evaluation.
             optimizer: The optimizer to use.
             lr_scheduler: The learning rate scheduler to use.
             metrics: The metric groups to track.
@@ -62,6 +67,7 @@ class SemanticSegmentationModule(module.ModelModule):
         self.optimizer = optimizer
         self.lr_scheduler = lr_scheduler
         self.save_decoder_only = save_decoder_only
+        self.inferer = inferer
 
     @override
     def configure_model(self) -> None:
@@ -104,25 +110,15 @@ class SemanticSegmentationModule(module.ModelModule):
     @override
     def forward(
         self,
-        inputs: torch.Tensor,
-        to_size: Tuple[int, int] | None = None,
+        tensor: torch.Tensor,
         *args: Any,
         **kwargs: Any,
     ) -> torch.Tensor:
-        """Maps the input tensor (image tensor or embeddings) to masks.
-
-        If `inputs` is image tensor, then the `self.encoder`
-        should be implemented, otherwise it will be interpreted
-        as embeddings, where the `to_size` is required.
-        """
-        if self.encoder is None and to_size is None:
-            raise ValueError(
-                "Please provide the expected `to_size` that the "
-                "decoder should map the embeddings (`inputs`) to."
-            )
-        features = self.encoder(inputs) if self.encoder else inputs
-        decoder_inputs = DecoderInputs(features, to_size or inputs.shape[-2:], inputs)  # type: ignore
-        return self.decoder(decoder_inputs)
+        return (
+            self.inferer(tensor, network=self._forward_networks)
+            if self.inferer is not None and not self.training
+            else self._forward_networks(tensor)
+        )
 
     @override
     def training_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
@@ -137,7 +133,9 @@ class SemanticSegmentationModule(module.ModelModule):
         return self._batch_step(batch)
 
     @override
-    def predict_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> torch.Tensor:
+    def predict_step(
+        self, batch: INPUT_BATCH, *args: Any, **kwargs: Any
+    ) -> torch.Tensor | List[torch.Tensor]:
         tensor = INPUT_BATCH(*batch).data
         return self.encoder(tensor) if isinstance(self.encoder, nn.Module) else tensor
 
@@ -170,7 +168,7 @@ class SemanticSegmentationModule(module.ModelModule):
             The batch step output.
         """
         data, targets, metadata = INPUT_TENSOR_BATCH(*batch)
-        predictions = self(data, to_size=targets.shape[-2:])
+        predictions = self(data)
         loss = self.criterion(predictions, targets)
         return {
             "loss": loss,
@@ -178,3 +176,13 @@ class SemanticSegmentationModule(module.ModelModule):
             "predictions": predictions,
             "metadata": metadata,
         }
+
+    def _forward_networks(self, tensor: torch.Tensor) -> torch.Tensor:
+        """Passes the input tensor through the encoder and decoder."""
+        features = self.encoder(tensor) if self.encoder else tensor
+        if isinstance(self.decoder, segmentation.Decoder):
+            if not isinstance(features, list):
+                raise ValueError(f"Expected a list of feature map tensors, got {type(features)}.")
+            image_size = (tensor.shape[-2], tensor.shape[-1])
+            return self.decoder(DecoderInputs(features, image_size, tensor))
+        return self.decoder(features)
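
The new `inferer` is only invoked outside training; one way it might be wired up, sketched with a MONAI sliding-window inferer and a placeholder decoder (the concrete networks and window size are assumptions, not eva defaults):

from monai.inferers import SlidingWindowInferer
from torch import nn

from eva.vision.models.modules.semantic_segmentation import SemanticSegmentationModule

# Placeholder decoder just to show the wiring; a plain nn.Module skips the DecoderInputs branch.
decoder = nn.Conv3d(1, 4, kernel_size=1)

module = SemanticSegmentationModule(
    decoder=decoder,
    criterion=nn.CrossEntropyLoss(),
    inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=2, overlap=0.25),
)
# During validation/testing, forward() routes batches through the sliding-window inferer;
# during training it calls the encoder/decoder path directly.
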
eva/vision/models/networks/backbones/__init__.py
@@ -1,6 +1,13 @@
 """Vision Model Backbones API."""
 
-from eva.vision.models.networks.backbones import pathology,
+from eva.vision.models.networks.backbones import pathology, radiology, timm, universal
 from eva.vision.models.networks.backbones.registry import BackboneModelRegistry, register_model
 
-__all__ = [
+__all__ = [
+    "radiology",
+    "pathology",
+    "timm",
+    "universal",
+    "BackboneModelRegistry",
+    "register_model",
+]
eva/vision/models/networks/backbones/pathology/__init__.py
@@ -1,9 +1,14 @@
 """Vision Pathology Model Backbones API."""
 
-from eva.vision.models.networks.backbones.pathology.bioptimus import bioptimus_h_optimus_0
+from eva.vision.models.networks.backbones.pathology.bioptimus import (
+    bioptimus_h0_mini,
+    bioptimus_h_optimus_0,
+)
 from eva.vision.models.networks.backbones.pathology.gigapath import prov_gigapath
 from eva.vision.models.networks.backbones.pathology.histai import histai_hibou_b, histai_hibou_l
+from eva.vision.models.networks.backbones.pathology.hkust import hkust_gpfm
 from eva.vision.models.networks.backbones.pathology.kaiko import (
+    kaiko_midnight_12k,
     kaiko_vitb8,
     kaiko_vitb16,
     kaiko_vitl14,
@@ -11,11 +16,12 @@ from eva.vision.models.networks.backbones.pathology.kaiko import (
     kaiko_vits16,
 )
 from eva.vision.models.networks.backbones.pathology.lunit import lunit_vits8, lunit_vits16
-from eva.vision.models.networks.backbones.pathology.mahmood import mahmood_uni
+from eva.vision.models.networks.backbones.pathology.mahmood import mahmood_uni, mahmood_uni2_h
 from eva.vision.models.networks.backbones.pathology.owkin import owkin_phikon, owkin_phikon_v2
 from eva.vision.models.networks.backbones.pathology.paige import paige_virchow2
 
 __all__ = [
+    "kaiko_midnight_12k",
     "kaiko_vitb16",
     "kaiko_vitb8",
     "kaiko_vitl14",
@@ -26,9 +32,12 @@ __all__ = [
     "lunit_vits16",
     "lunit_vits8",
     "mahmood_uni",
+    "mahmood_uni2_h",
     "bioptimus_h_optimus_0",
+    "bioptimus_h0_mini",
     "prov_gigapath",
     "histai_hibou_b",
     "histai_hibou_l",
     "paige_virchow2",
+    "hkust_gpfm",
 ]