kaiko-eva 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaiko-eva might be problematic. Click here for more details.
- eva/core/data/datasets/base.py +7 -2
- eva/core/models/modules/head.py +4 -2
- eva/core/models/modules/typings.py +2 -2
- eva/core/models/transforms/__init__.py +2 -1
- eva/core/models/transforms/as_discrete.py +57 -0
- eva/core/models/wrappers/_utils.py +121 -1
- eva/core/trainers/_recorder.py +4 -1
- eva/core/utils/suppress_logs.py +28 -0
- eva/vision/data/__init__.py +2 -2
- eva/vision/data/dataloaders/__init__.py +5 -0
- eva/vision/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/vision/data/dataloaders/collate_fn/collection.py +22 -0
- eva/vision/data/datasets/__init__.py +2 -2
- eva/vision/data/datasets/classification/bach.py +3 -4
- eva/vision/data/datasets/classification/bracs.py +3 -4
- eva/vision/data/datasets/classification/breakhis.py +3 -4
- eva/vision/data/datasets/classification/camelyon16.py +4 -5
- eva/vision/data/datasets/classification/crc.py +3 -4
- eva/vision/data/datasets/classification/gleason_arvaniti.py +3 -4
- eva/vision/data/datasets/classification/mhist.py +3 -4
- eva/vision/data/datasets/classification/panda.py +4 -5
- eva/vision/data/datasets/classification/patch_camelyon.py +3 -4
- eva/vision/data/datasets/classification/unitopatho.py +3 -4
- eva/vision/data/datasets/classification/wsi.py +6 -5
- eva/vision/data/datasets/segmentation/__init__.py +2 -2
- eva/vision/data/datasets/segmentation/_utils.py +47 -0
- eva/vision/data/datasets/segmentation/bcss.py +7 -8
- eva/vision/data/datasets/segmentation/btcv.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +6 -7
- eva/vision/data/datasets/segmentation/lits.py +9 -8
- eva/vision/data/datasets/segmentation/lits_balanced.py +2 -1
- eva/vision/data/datasets/segmentation/monusac.py +4 -5
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +12 -10
- eva/vision/data/datasets/vision.py +95 -4
- eva/vision/data/datasets/wsi.py +5 -5
- eva/vision/data/transforms/__init__.py +22 -3
- eva/vision/data/transforms/common/__init__.py +1 -2
- eva/vision/data/transforms/croppad/__init__.py +11 -0
- eva/vision/data/transforms/croppad/crop_foreground.py +110 -0
- eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +109 -0
- eva/vision/data/transforms/croppad/spatial_pad.py +67 -0
- eva/vision/data/transforms/intensity/__init__.py +11 -0
- eva/vision/data/transforms/intensity/rand_scale_intensity.py +59 -0
- eva/vision/data/transforms/intensity/rand_shift_intensity.py +55 -0
- eva/vision/data/transforms/intensity/scale_intensity_ranged.py +56 -0
- eva/vision/data/transforms/spatial/__init__.py +7 -0
- eva/vision/data/transforms/spatial/flip.py +72 -0
- eva/vision/data/transforms/spatial/rotate.py +53 -0
- eva/vision/data/transforms/spatial/spacing.py +69 -0
- eva/vision/data/transforms/utility/__init__.py +5 -0
- eva/vision/data/transforms/utility/ensure_channel_first.py +51 -0
- eva/vision/data/tv_tensors/__init__.py +5 -0
- eva/vision/data/tv_tensors/volume.py +61 -0
- eva/vision/metrics/segmentation/monai_dice.py +9 -2
- eva/vision/models/modules/semantic_segmentation.py +32 -19
- eva/vision/models/networks/backbones/__init__.py +9 -2
- eva/vision/models/networks/backbones/pathology/__init__.py +11 -2
- eva/vision/models/networks/backbones/pathology/bioptimus.py +47 -1
- eva/vision/models/networks/backbones/pathology/hkust.py +69 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +18 -0
- eva/vision/models/networks/backbones/radiology/__init__.py +11 -0
- eva/vision/models/networks/backbones/radiology/swin_unetr.py +231 -0
- eva/vision/models/networks/backbones/radiology/voco.py +75 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +6 -2
- eva/vision/models/networks/decoders/segmentation/linear.py +5 -10
- eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +8 -1
- eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py +104 -0
- eva/vision/utils/io/__init__.py +2 -0
- eva/vision/utils/io/nifti.py +91 -11
- {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/METADATA +16 -12
- {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/RECORD +74 -58
- {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/WHEEL +1 -1
- eva/vision/data/datasets/classification/base.py +0 -96
- eva/vision/data/datasets/segmentation/base.py +0 -96
- eva/vision/data/transforms/common/resize_and_clamp.py +0 -51
- eva/vision/data/transforms/normalization/__init__.py +0 -6
- eva/vision/data/transforms/normalization/clamp.py +0 -43
- eva/vision/data/transforms/normalization/functional/__init__.py +0 -5
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +0 -28
- eva/vision/data/transforms/normalization/rescale_intensity.py +0 -53
- eva/vision/metrics/segmentation/BUILD +0 -1
- eva/vision/models/networks/backbones/torchhub/__init__.py +0 -5
- eva/vision/models/networks/backbones/torchhub/backbones.py +0 -61
- {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
"""Rotation transforms."""
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
from typing import Any, Dict, List
|
|
5
|
+
|
|
6
|
+
from monai.transforms.spatial import array as monai_spatial_transforms
|
|
7
|
+
from torchvision import tv_tensors
|
|
8
|
+
from torchvision.transforms import v2
|
|
9
|
+
from typing_extensions import override
|
|
10
|
+
|
|
11
|
+
from eva.vision.data import tv_tensors as eva_tv_tensors
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class RandRotate90(v2.Transform):
    """Randomly rotate image, volume and mask inputs in steps of 90 degrees.

    Wraps MONAI's array-based `RandRotate90` so that it can be used as a
    `torchvision.transforms.v2` transform. Inputs of any other type are
    returned untouched.
    """

    def __init__(
        self,
        prob: float = 0.1,
        max_k: int = 3,
        spatial_axes: tuple[int, int] = (1, 2),
    ) -> None:
        """Initializes the transform.

        Args:
            prob: Probability of applying the rotation. With the default
                of 0.1, a rotated array is returned 10% of the time.
            max_k: Upper bound for the number of 90 degree rotations; the
                count is sampled from `np.random.randint(max_k) + 1`.
            spatial_axes: The two spatial axes spanning the rotation plane.
                With the default `(1, 2)`, a `[C, T, H, W]` input is rotated
                in the (H, W) plane, as MONAI does not count the leading
                channel dimension.
        """
        super().__init__()

        self._rotate = monai_spatial_transforms.RandRotate90(
            prob=prob,
            max_k=max_k,
            spatial_axes=spatial_axes,
        )

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # The random state is drawn here and the per-input handler below is
        # invoked with `randomize=False`, so it reuses this draw instead of
        # sampling a new one for every input.
        self._rotate.randomize()
        return {}

    @functools.singledispatchmethod
    @override
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Fallback for types without a registered handler: pass through.
        return inpt

    @_transform.register(tv_tensors.Image)
    @_transform.register(eva_tv_tensors.Volume)
    @_transform.register(tv_tensors.Mask)
    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Apply the rotation drawn in `_get_params`, then restore the input's
        # tv_tensor type on the result.
        return tv_tensors.wrap(self._rotate(img=inpt, randomize=False), like=inpt)
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"""Spacing resample transform."""
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
from typing import Any, Dict, List, Sequence
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torch
|
|
8
|
+
from monai.data import meta_tensor
|
|
9
|
+
from monai.transforms.spatial import array as monai_spatial_transforms
|
|
10
|
+
from torchvision import tv_tensors
|
|
11
|
+
from torchvision.transforms import v2
|
|
12
|
+
from typing_extensions import override
|
|
13
|
+
|
|
14
|
+
from eva.vision.data import tv_tensors as eva_tv_tensors
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Spacing(v2.Transform):
    """Resample input image into the specified `pixdim`.

    Wraps MONAI's array-based `Spacing` transform for use as a
    `torchvision.transforms.v2` transform.

    - Expects tensors of shape `[C, T, H, W]`.
    - Volumes are resampled with `bilinear` interpolation, masks with
      `nearest` interpolation (and are cast to `torch.long` afterwards).
    - The affine matrix used for the masks is read from the `Volume` that
      is passed in the same call.
    """

    def __init__(
        self,
        pixdim: Sequence[float] | float | np.ndarray,
    ) -> None:
        """Initializes the transform.

        Args:
            pixdim: Output voxel spacing. If a single number is provided,
                it is used for the first dimension. Items of the `pixdim`
                sequence map to the spatial dimensions of the input image:
                if the sequence is longer than the number of spatial
                dimensions, the extra part is ignored; if it is shorter,
                it is padded with the last value (e.g. for a 3D image,
                `[1.0, 2.0]` is padded to `[1.0, 2.0, 2.0]`). For
                non-positive components, the transform uses the
                corresponding components of the original pixdim, which is
                computed from the `affine` matrix of the input image.
        """
        super().__init__()

        self._spacing = monai_spatial_transforms.Spacing(pixdim=pixdim, recompute_affine=True)
        self._affine = None

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Caches the affine matrix of the first `Volume` among the inputs.

        Raises:
            ValueError: If none of the inputs is a `Volume`.
        """
        # A bare `next()` on an exhausted generator raises `StopIteration`,
        # which an enclosing iterator would interpret as a normal end of
        # iteration. Use a default and fail loudly instead.
        volume = next(
            (inpt for inpt in flat_inputs if isinstance(inpt, eva_tv_tensors.Volume)),
            None,
        )
        if volume is None:
            raise ValueError(
                "`Spacing` requires a `Volume` among its inputs to read the affine matrix from."
            )
        self._affine = volume.affine
        return {}

    @functools.singledispatchmethod
    @override
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Fallback for types without a registered handler: pass through.
        return inpt

    @_transform.register(eva_tv_tensors.Volume)
    def _(self, inpt: eva_tv_tensors.Volume, params: Dict[str, Any]) -> Any:
        inpt_spacing = self._spacing(inpt.to_meta_tensor(), mode="bilinear")
        if not isinstance(inpt_spacing, meta_tensor.MetaTensor):
            raise ValueError(f"Expected MetaTensor, got {type(inpt_spacing)}")
        return eva_tv_tensors.Volume.from_meta_tensor(inpt_spacing)

    @_transform.register(tv_tensors.Mask)
    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Masks carry no affine of their own; reuse the one cached from the
        # volume in `_get_params`. Nearest interpolation is used for masks.
        inpt_spacing = self._spacing(
            meta_tensor.MetaTensor(inpt, affine=self._affine), mode="nearest"
        )
        return tv_tensors.wrap(inpt_spacing.to(dtype=torch.long), like=inpt)
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
"""Adjust or add the channel dimension of input data to ensure `channel_first` shape."""
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
from typing import Any, Dict
|
|
5
|
+
|
|
6
|
+
from monai.transforms.utility import array as monai_utility_transforms
|
|
7
|
+
from torchvision import tv_tensors
|
|
8
|
+
from torchvision.transforms import v2
|
|
9
|
+
from typing_extensions import override
|
|
10
|
+
|
|
11
|
+
from eva.vision.data import tv_tensors as eva_tv_tensors
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class EnsureChannelFirst(v2.Transform):
    """Adjust or add the channel dimension of input data to ensure `channel_first` shape.

    Wraps MONAI's array-based `EnsureChannelFirst` for use as a
    `torchvision.transforms.v2` transform. Image, volume and mask inputs
    are converted; inputs of any other type are returned untouched.
    """

    def __init__(
        self,
        strict_check: bool = True,
        channel_dim: None | str | int = None,
    ) -> None:
        """Initializes the transform.

        Args:
            strict_check: Whether to raise an error when the meta
                information is insufficient.
            channel_dim: The original channel dimension (integer) of the
                input array. It overrides the `original_channel_dim` from
                a provided MetaTensor input. If the input array doesn't
                have a channel dimension, this value should be
                ``'no_channel'``. If set to `None`, this class relies on
                `img` or `meta_dict` to provide the channel dimension.
        """
        super().__init__()

        monai_kwargs = {"strict_check": strict_check, "channel_dim": channel_dim}
        self._ensure_channel_first = monai_utility_transforms.EnsureChannelFirst(**monai_kwargs)

    @functools.singledispatchmethod
    @override
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Fallback for types without a registered handler: pass through.
        return inpt

    @_transform.register(tv_tensors.Image)
    @_transform.register(eva_tv_tensors.Volume)
    @_transform.register(tv_tensors.Mask)
    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Convert with MONAI, then restore the input's tv_tensor type.
        return tv_tensors.wrap(self._ensure_channel_first(inpt), like=inpt)
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""Custom `tv_tensors` type for 3D Volumes."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, Optional, Union
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from monai.data import meta_tensor
|
|
7
|
+
from torchvision import tv_tensors
|
|
8
|
+
from typing_extensions import override
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Volume(tv_tensors.Video):
    """:class:`torchvision.TVTensor` subclass for 3D volumes.

    - Adds optional metadata and affine matrix to the tensor.
    - Expects tensors to be of shape `[..., T, C, H, W]`.

    Args:
        data: Any data that can be turned into a tensor with :func:`torch.as_tensor`.
        affine: Affine matrix of the volume. Expected to be of shape `[4, 4]`, and
            columns to correspond to [T, H, W, (translation)] dimensions. Note that
            `nibabel` by default uses [H, W, T, (translation)] order for affine matrices.
        metadata: Metadata associated with the volume.
        dtype: Desired data type. If omitted, will be inferred from `data`.
        device: Desired device.
        requires_grad: Whether autograd should record operations.
    """

    # Class-level fallbacks, read by instances that were not built through
    # `__new__` and therefore carry no `affine` / `metadata` of their own.
    affine: torch.Tensor | None = None
    metadata: Dict[str, Any] | None = None

    @override
    def __new__(
        cls,
        data: Any,
        affine: torch.Tensor | None = None,
        metadata: Dict[str, Any] | None = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> "Volume":
        # Kept for backward compatibility: instances that are re-created
        # without going through this constructor fall back to the values of
        # the most recently constructed volume.
        # NOTE(review): presumably `tv_tensors.wrap(..., like=volume)` is such
        # a path - confirm, and consider propagating the attributes there so
        # this shared state can be dropped.
        cls.affine = affine
        cls.metadata = metadata

        volume = super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)  # type: ignore
        # Store the values on the instance, so that constructing another
        # `Volume` afterwards does not overwrite them for this one.
        volume.affine = affine
        volume.metadata = metadata
        return volume

    @classmethod
    def from_meta_tensor(cls, meta_tensor: meta_tensor.MetaTensor) -> "Volume":
        """Creates an instance from a :class:`monai.data.meta_tensor.MetaTensor`."""
        return cls(
            meta_tensor.data,
            affine=meta_tensor.affine,
            metadata=meta_tensor.meta,
            dtype=meta_tensor.dtype,
            device=meta_tensor.device,
            requires_grad=meta_tensor.requires_grad,
        )  # type: ignore

    def to_meta_tensor(self) -> meta_tensor.MetaTensor:
        """Converts the volume to a :class:`monai.data.meta_tensor.MetaTensor`."""
        return meta_tensor.MetaTensor(self, affine=self.affine, meta=self.metadata)

    def __repr__(self, *, tensor_contents: Any = None) -> str:
        """Returns the string representation of the object."""
        return self._make_repr()
|
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
"""Wrapper for dice score metric from MONAI."""
|
|
2
2
|
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
3
5
|
from monai.metrics.meandice import DiceMetric
|
|
4
6
|
from typing_extensions import override
|
|
5
7
|
|
|
@@ -14,6 +16,7 @@ class MonaiDiceScore(wrappers.MonaiMetricWrapper):
|
|
|
14
16
|
self,
|
|
15
17
|
num_classes: int,
|
|
16
18
|
include_background: bool = True,
|
|
19
|
+
input_format: Literal["one-hot", "index"] = "index",
|
|
17
20
|
reduction: str = "mean",
|
|
18
21
|
ignore_index: int | None = None,
|
|
19
22
|
**kwargs,
|
|
@@ -24,6 +27,8 @@ class MonaiDiceScore(wrappers.MonaiMetricWrapper):
|
|
|
24
27
|
num_classes: The number of classes in the dataset.
|
|
25
28
|
include_background: Whether to include the background class in the computation.
|
|
26
29
|
reduction: The method to reduce the dice score. Options are `"mean"`, `"sum"`, `"none"`.
|
|
30
|
+
input_format: Choose between "one-hot" for one-hot encoded tensors or "index"
|
|
31
|
+
for index tensors.
|
|
27
32
|
ignore_index: Integer specifying a target class to ignore. If given, this class
|
|
28
33
|
index does not contribute to the returned score.
|
|
29
34
|
kwargs: Additional keyword arguments for instantiating monai's `DiceMetric` class.
|
|
@@ -40,11 +45,13 @@ class MonaiDiceScore(wrappers.MonaiMetricWrapper):
|
|
|
40
45
|
self.reduction = reduction
|
|
41
46
|
self.orig_num_classes = num_classes
|
|
42
47
|
self.ignore_index = ignore_index
|
|
48
|
+
self.input_format = input_format
|
|
43
49
|
|
|
44
50
|
@override
|
|
45
51
|
def update(self, preds, target):
|
|
46
|
-
|
|
47
|
-
|
|
52
|
+
if self.input_format == "index":
|
|
53
|
+
preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes)
|
|
54
|
+
target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes)
|
|
48
55
|
if self.ignore_index is not None:
|
|
49
56
|
preds, target = _utils.apply_ignore_index(preds, target, self.ignore_index)
|
|
50
57
|
return super().update(preds, target)
|
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
"""Neural Network Semantic Segmentation Module."""
|
|
2
2
|
|
|
3
|
+
import functools
|
|
3
4
|
from typing import Any, Callable, Dict, Iterable, List, Tuple
|
|
4
5
|
|
|
5
6
|
import torch
|
|
6
7
|
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
|
|
7
8
|
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
|
9
|
+
from monai.inferers.inferer import Inferer
|
|
8
10
|
from torch import nn, optim
|
|
9
11
|
from torch.optim import lr_scheduler
|
|
10
12
|
from typing_extensions import override
|
|
@@ -15,6 +17,7 @@ from eva.core.models.modules.typings import INPUT_BATCH, INPUT_TENSOR_BATCH
|
|
|
15
17
|
from eva.core.models.modules.utils import batch_postprocess, grad, submodule_state_dict
|
|
16
18
|
from eva.core.utils import parser
|
|
17
19
|
from eva.vision.models.networks import decoders
|
|
20
|
+
from eva.vision.models.networks.decoders import segmentation
|
|
18
21
|
from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs
|
|
19
22
|
|
|
20
23
|
|
|
@@ -23,15 +26,17 @@ class SemanticSegmentationModule(module.ModelModule):
|
|
|
23
26
|
|
|
24
27
|
def __init__(
|
|
25
28
|
self,
|
|
26
|
-
decoder: decoders.Decoder,
|
|
29
|
+
decoder: decoders.Decoder | nn.Module,
|
|
27
30
|
criterion: Callable[..., torch.Tensor],
|
|
28
31
|
encoder: Dict[str, Any] | Callable[[torch.Tensor], List[torch.Tensor]] | None = None,
|
|
29
32
|
lr_multiplier_encoder: float = 0.0,
|
|
33
|
+
inferer: Inferer | None = None,
|
|
30
34
|
optimizer: OptimizerCallable = optim.AdamW,
|
|
31
35
|
lr_scheduler: LRSchedulerCallable = lr_scheduler.ConstantLR,
|
|
32
36
|
metrics: metrics_lib.MetricsSchema | None = None,
|
|
33
37
|
postprocess: batch_postprocess.BatchPostProcess | None = None,
|
|
34
38
|
save_decoder_only: bool = True,
|
|
39
|
+
spatial_dims: int = 2,
|
|
35
40
|
) -> None:
|
|
36
41
|
"""Initializes the neural net head module.
|
|
37
42
|
|
|
@@ -44,6 +49,8 @@ class SemanticSegmentationModule(module.ModelModule):
|
|
|
44
49
|
during the `configure_model` step.
|
|
45
50
|
lr_multiplier_encoder: The learning rate multiplier for the
|
|
46
51
|
encoder parameters. If `0`, it will freeze the encoder.
|
|
52
|
+
inferer: An optional MONAI `Inferer` for inference
|
|
53
|
+
postprocess during evaluation.
|
|
47
54
|
optimizer: The optimizer to use.
|
|
48
55
|
lr_scheduler: The learning rate scheduler to use.
|
|
49
56
|
metrics: The metric groups to track.
|
|
@@ -52,6 +59,8 @@ class SemanticSegmentationModule(module.ModelModule):
|
|
|
52
59
|
predictions and targets.
|
|
53
60
|
save_decoder_only: Whether to save only the decoder during checkpointing. If False,
|
|
54
61
|
will also save the encoder (not recommended when frozen).
|
|
62
|
+
spatial_dims: The number of spatial dimensions, 2 for 2D
|
|
63
|
+
and 3 for 3D segmentation.
|
|
55
64
|
"""
|
|
56
65
|
super().__init__(metrics=metrics, postprocess=postprocess)
|
|
57
66
|
|
|
@@ -62,6 +71,8 @@ class SemanticSegmentationModule(module.ModelModule):
|
|
|
62
71
|
self.optimizer = optimizer
|
|
63
72
|
self.lr_scheduler = lr_scheduler
|
|
64
73
|
self.save_decoder_only = save_decoder_only
|
|
74
|
+
self.inferer = inferer
|
|
75
|
+
self.spatial_dims = spatial_dims
|
|
65
76
|
|
|
66
77
|
@override
|
|
67
78
|
def configure_model(self) -> None:
|
|
@@ -104,25 +115,16 @@ class SemanticSegmentationModule(module.ModelModule):
|
|
|
104
115
|
@override
|
|
105
116
|
def forward(
|
|
106
117
|
self,
|
|
107
|
-
|
|
108
|
-
to_size: Tuple[int, int]
|
|
118
|
+
tensor: torch.Tensor,
|
|
119
|
+
to_size: Tuple[int, int],
|
|
109
120
|
*args: Any,
|
|
110
121
|
**kwargs: Any,
|
|
111
122
|
) -> torch.Tensor:
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
"""
|
|
118
|
-
if self.encoder is None and to_size is None:
|
|
119
|
-
raise ValueError(
|
|
120
|
-
"Please provide the expected `to_size` that the "
|
|
121
|
-
"decoder should map the embeddings (`inputs`) to."
|
|
122
|
-
)
|
|
123
|
-
features = self.encoder(inputs) if self.encoder else inputs
|
|
124
|
-
decoder_inputs = DecoderInputs(features, to_size or inputs.shape[-2:], inputs) # type: ignore
|
|
125
|
-
return self.decoder(decoder_inputs)
|
|
123
|
+
return (
|
|
124
|
+
self.inferer(tensor, network=functools.partial(self._forward_networks, to_size=to_size))
|
|
125
|
+
if self.inferer is not None and not self.training
|
|
126
|
+
else self._forward_networks(tensor, to_size=to_size)
|
|
127
|
+
)
|
|
126
128
|
|
|
127
129
|
@override
|
|
128
130
|
def training_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
|
|
@@ -137,7 +139,9 @@ class SemanticSegmentationModule(module.ModelModule):
|
|
|
137
139
|
return self._batch_step(batch)
|
|
138
140
|
|
|
139
141
|
@override
|
|
140
|
-
def predict_step(
|
|
142
|
+
def predict_step(
|
|
143
|
+
self, batch: INPUT_BATCH, *args: Any, **kwargs: Any
|
|
144
|
+
) -> torch.Tensor | List[torch.Tensor]:
|
|
141
145
|
tensor = INPUT_BATCH(*batch).data
|
|
142
146
|
return self.encoder(tensor) if isinstance(self.encoder, nn.Module) else tensor
|
|
143
147
|
|
|
@@ -170,7 +174,7 @@ class SemanticSegmentationModule(module.ModelModule):
|
|
|
170
174
|
The batch step output.
|
|
171
175
|
"""
|
|
172
176
|
data, targets, metadata = INPUT_TENSOR_BATCH(*batch)
|
|
173
|
-
predictions = self(data, to_size=targets.shape[-
|
|
177
|
+
predictions = self(data, to_size=targets.shape[-self.spatial_dims :])
|
|
174
178
|
loss = self.criterion(predictions, targets)
|
|
175
179
|
return {
|
|
176
180
|
"loss": loss,
|
|
@@ -178,3 +182,12 @@ class SemanticSegmentationModule(module.ModelModule):
|
|
|
178
182
|
"predictions": predictions,
|
|
179
183
|
"metadata": metadata,
|
|
180
184
|
}
|
|
185
|
+
|
|
186
|
+
def _forward_networks(self, tensor: torch.Tensor, to_size: Tuple[int, int]) -> torch.Tensor:
    """Passes the input tensor through the encoder and decoder.

    Args:
        tensor: The input tensor. It is fed to the encoder when one is
            set, otherwise it is handed to the decoder directly.
        to_size: The target size forwarded to `segmentation.Decoder`
            decoders through `DecoderInputs`; unused for other decoders.

    Returns:
        The output of the decoder.

    Raises:
        ValueError: If the decoder is a `segmentation.Decoder` and the
            features are not a list.
    """
    if self.encoder:
        features = self.encoder(tensor)
    else:
        features = tensor

    # Any decoder that is not a `segmentation.Decoder` consumes the
    # features as they are.
    if not isinstance(self.decoder, segmentation.Decoder):
        return self.decoder(features)

    if isinstance(features, list):
        return self.decoder(DecoderInputs(features, to_size, tensor))
    raise ValueError(f"Expected a list of feature map tensors, got {type(features)}.")
|
|
@@ -1,6 +1,13 @@
|
|
|
1
1
|
"""Vision Model Backbones API."""
|
|
2
2
|
|
|
3
|
-
from eva.vision.models.networks.backbones import pathology,
|
|
3
|
+
from eva.vision.models.networks.backbones import pathology, radiology, timm, universal
|
|
4
4
|
from eva.vision.models.networks.backbones.registry import BackboneModelRegistry, register_model
|
|
5
5
|
|
|
6
|
-
__all__ = [
|
|
6
|
+
__all__ = [
|
|
7
|
+
"radiology",
|
|
8
|
+
"pathology",
|
|
9
|
+
"timm",
|
|
10
|
+
"universal",
|
|
11
|
+
"BackboneModelRegistry",
|
|
12
|
+
"register_model",
|
|
13
|
+
]
|
|
@@ -1,9 +1,14 @@
|
|
|
1
1
|
"""Vision Pathology Model Backbones API."""
|
|
2
2
|
|
|
3
|
-
from eva.vision.models.networks.backbones.pathology.bioptimus import
|
|
3
|
+
from eva.vision.models.networks.backbones.pathology.bioptimus import (
|
|
4
|
+
bioptimus_h0_mini,
|
|
5
|
+
bioptimus_h_optimus_0,
|
|
6
|
+
)
|
|
4
7
|
from eva.vision.models.networks.backbones.pathology.gigapath import prov_gigapath
|
|
5
8
|
from eva.vision.models.networks.backbones.pathology.histai import histai_hibou_b, histai_hibou_l
|
|
9
|
+
from eva.vision.models.networks.backbones.pathology.hkust import hkust_gpfm
|
|
6
10
|
from eva.vision.models.networks.backbones.pathology.kaiko import (
|
|
11
|
+
kaiko_midnight_12k,
|
|
7
12
|
kaiko_vitb8,
|
|
8
13
|
kaiko_vitb16,
|
|
9
14
|
kaiko_vitl14,
|
|
@@ -11,11 +16,12 @@ from eva.vision.models.networks.backbones.pathology.kaiko import (
|
|
|
11
16
|
kaiko_vits16,
|
|
12
17
|
)
|
|
13
18
|
from eva.vision.models.networks.backbones.pathology.lunit import lunit_vits8, lunit_vits16
|
|
14
|
-
from eva.vision.models.networks.backbones.pathology.mahmood import mahmood_uni
|
|
19
|
+
from eva.vision.models.networks.backbones.pathology.mahmood import mahmood_uni, mahmood_uni2_h
|
|
15
20
|
from eva.vision.models.networks.backbones.pathology.owkin import owkin_phikon, owkin_phikon_v2
|
|
16
21
|
from eva.vision.models.networks.backbones.pathology.paige import paige_virchow2
|
|
17
22
|
|
|
18
23
|
__all__ = [
|
|
24
|
+
"kaiko_midnight_12k",
|
|
19
25
|
"kaiko_vitb16",
|
|
20
26
|
"kaiko_vitb8",
|
|
21
27
|
"kaiko_vitl14",
|
|
@@ -26,9 +32,12 @@ __all__ = [
|
|
|
26
32
|
"lunit_vits16",
|
|
27
33
|
"lunit_vits8",
|
|
28
34
|
"mahmood_uni",
|
|
35
|
+
"mahmood_uni2_h",
|
|
29
36
|
"bioptimus_h_optimus_0",
|
|
37
|
+
"bioptimus_h0_mini",
|
|
30
38
|
"prov_gigapath",
|
|
31
39
|
"histai_hibou_b",
|
|
32
40
|
"histai_hibou_l",
|
|
33
41
|
"paige_virchow2",
|
|
42
|
+
"hkust_gpfm",
|
|
34
43
|
]
|
|
@@ -5,6 +5,9 @@ from typing import Tuple
|
|
|
5
5
|
import timm
|
|
6
6
|
from torch import nn
|
|
7
7
|
|
|
8
|
+
from eva.core.models import transforms
|
|
9
|
+
from eva.vision.models import wrappers
|
|
10
|
+
from eva.vision.models.networks.backbones import _utils
|
|
8
11
|
from eva.vision.models.networks.backbones.registry import register_model
|
|
9
12
|
|
|
10
13
|
|
|
@@ -13,7 +16,9 @@ def bioptimus_h_optimus_0(
|
|
|
13
16
|
dynamic_img_size: bool = True,
|
|
14
17
|
out_indices: int | Tuple[int, ...] | None = None,
|
|
15
18
|
) -> nn.Module:
|
|
16
|
-
"""Initializes the
|
|
19
|
+
"""Initializes the H-Optimus-0 pathology FM by Bioptimus.
|
|
20
|
+
|
|
21
|
+
See https://huggingface.co/bioptimus/H-optimus-0 for details.
|
|
17
22
|
|
|
18
23
|
Args:
|
|
19
24
|
dynamic_img_size: Whether to allow the interpolation embedding
|
|
@@ -32,3 +37,44 @@ def bioptimus_h_optimus_0(
|
|
|
32
37
|
out_indices=out_indices,
|
|
33
38
|
features_only=out_indices is not None,
|
|
34
39
|
)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@register_model("pathology/bioptimus_h0_mini")
def bioptimus_h0_mini(
    dynamic_img_size: bool = True,
    out_indices: int | Tuple[int, ...] | None = None,
    hf_token: str | None = None,
    include_patch_tokens: bool = False,
) -> nn.Module:
    """Initializes H0-mini (ViT-B) pathology FM by Bioptimus.

    This model was distilled from H-Optimus-0 on 40M TCGA tiles.

    See https://huggingface.co/bioptimus/H0-mini for details.

    Args:
        dynamic_img_size: Support different input image sizes by allowing to change
            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
        out_indices: Whether and which multi-level patch embeddings to return.
        hf_token: HuggingFace token to download the model.
        include_patch_tokens: Whether to combine the mean aggregated patch tokens with cls token.

    Returns:
        The model instance.
    """
    _utils.huggingface_login(hf_token)

    timm_kwargs = {
        "dynamic_img_size": dynamic_img_size,
        "mlp_layer": timm.layers.SwiGLUPacked,
        "act_layer": nn.SiLU,
    }

    # The CLS feature extraction is only attached when no multi-level
    # patch embeddings are requested.
    if out_indices is None:
        output_transforms = transforms.ExtractCLSFeatures(
            include_patch_tokens=include_patch_tokens
        )
    else:
        output_transforms = None

    return wrappers.TimmModel(
        model_name="hf-hub:bioptimus/H0-mini",
        out_indices=out_indices,
        pretrained=True,
        model_kwargs=timm_kwargs,
        tensor_transforms=output_transforms,
    )
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"""Pathology FMs from Hong Kong University of Science and Technology."""
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from typing import Tuple
|
|
5
|
+
|
|
6
|
+
import timm
|
|
7
|
+
from torch import nn
|
|
8
|
+
|
|
9
|
+
from eva.core.models.wrappers import _utils
|
|
10
|
+
from eva.vision.models.networks.backbones.registry import register_model
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@register_model("pathology/hkust_gpfm")
def hkust_gpfm(
    dynamic_img_size: bool = True,
    out_indices: int | Tuple[int, ...] | None = None,
) -> nn.Module:
    """Initializes GPFM model from Hong Kong University of Science and Technology.

    Ma, J., Guo, Z., Zhou, F., Wang, Y., Xu, Y., et al. (2024).
    Towards a generalizable pathology foundation model via unified knowledge
    distillation (arXiv No. 2407.18449). arXiv. https://arxiv.org/abs/2407.18449

    Args:
        dynamic_img_size: Support different input image sizes by allowing to change
            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
        out_indices: Whether and which multi-level patch embeddings to return.

    Returns:
        The model instance.
    """
    # The converted checkpoint weights are handed to timm through the
    # pretrained configuration.
    pretrained_cfg = {
        "state_dict": _load_state_dict(),
        "num_classes": 0,
    }
    return timm.create_model(
        model_name="vit_large_patch14_dinov2",
        pretrained=True,
        pretrained_cfg=pretrained_cfg,
        out_indices=out_indices,
        features_only=out_indices is not None,
        img_size=224,
        patch_size=14,
        init_values=1e-5,
        qkv_bias=True,
        dynamic_img_size=dynamic_img_size,
    )
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _load_state_dict() -> dict:
    """Loads the state dict with model weights from github.

    Returns:
        The entries stored under the `"teacher"` key of the downloaded
        checkpoint, with their keys renamed by `_convert_state_dict`.
    """
    checkpoint = _utils.load_state_dict_from_url(
        url="https://github.com/birkhoffkiki/GPFM/releases/download/ckpt/GPFM.pth",
        md5="0dc7e345de84f385d09c8c782b4b3236",
    )
    teacher_weights = checkpoint["teacher"]
    return _convert_state_dict(teacher_weights)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _convert_state_dict(state_dict: dict) -> dict:
|
|
61
|
+
"""Rename state dict keys to match timm's format."""
|
|
62
|
+
state_dict = {
|
|
63
|
+
re.sub(r"blocks\.\d+\.(\d+)", r"blocks.\1", key.replace("backbone.", "")): value
|
|
64
|
+
for key, value in state_dict.items()
|
|
65
|
+
}
|
|
66
|
+
remove_keys = ["mask_token"] + [key for key in state_dict.keys() if "dino_head" in key]
|
|
67
|
+
for key in remove_keys:
|
|
68
|
+
state_dict.pop(key)
|
|
69
|
+
return state_dict
|
|
@@ -5,9 +5,27 @@ from typing import Tuple
|
|
|
5
5
|
import torch
|
|
6
6
|
from torch import nn
|
|
7
7
|
|
|
8
|
+
from eva.vision.models.networks.backbones import _utils
|
|
8
9
|
from eva.vision.models.networks.backbones.registry import register_model
|
|
9
10
|
|
|
10
11
|
|
|
12
|
+
@register_model("pathology/kaiko_midnight_12k")
def kaiko_midnight_12k(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
    """Initializes the Midnight-12k pathology FM by kaiko.ai.

    Args:
        out_indices: Whether and which multi-level patch embeddings to return.

    Returns:
        The model instance.
    """
    # NOTE(review): `trust_remote_code` is forwarded to the loader as a model
    # keyword argument; presumably it allows code shipped with the
    # `kaiko-ai/midnight` repository to be executed - confirm before changing.
    hf_model_kwargs = {"trust_remote_code": True}
    return _utils.load_hugingface_model(
        model_name="kaiko-ai/midnight",
        out_indices=out_indices,
        model_kwargs=hf_model_kwargs,
    )
|
|
27
|
+
|
|
28
|
+
|
|
11
29
|
@register_model("pathology/kaiko_vits16")
|
|
12
30
|
def kaiko_vits16(
|
|
13
31
|
dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""Vision Radiology Model Backbones API."""
|
|
2
|
+
|
|
3
|
+
from eva.vision.models.networks.backbones.radiology.swin_unetr import SwinUNETREncoder
|
|
4
|
+
from eva.vision.models.networks.backbones.radiology.voco import VoCoB, VoCoH, VoCoL
|
|
5
|
+
|
|
6
|
+
__all__ = [
|
|
7
|
+
"VoCoB",
|
|
8
|
+
"VoCoL",
|
|
9
|
+
"VoCoH",
|
|
10
|
+
"SwinUNETREncoder",
|
|
11
|
+
]
|