kaiko-eva 0.1.8__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. eva/core/data/datasets/base.py +7 -2
  2. eva/core/data/datasets/classification/embeddings.py +2 -2
  3. eva/core/data/datasets/classification/multi_embeddings.py +2 -2
  4. eva/core/data/datasets/embeddings.py +4 -4
  5. eva/core/data/samplers/classification/balanced.py +19 -18
  6. eva/core/loggers/utils/wandb.py +33 -0
  7. eva/core/models/modules/head.py +5 -3
  8. eva/core/models/modules/typings.py +2 -2
  9. eva/core/models/transforms/__init__.py +2 -1
  10. eva/core/models/transforms/as_discrete.py +57 -0
  11. eva/core/models/wrappers/_utils.py +121 -1
  12. eva/core/trainers/functional.py +8 -5
  13. eva/core/trainers/trainer.py +32 -17
  14. eva/core/utils/suppress_logs.py +28 -0
  15. eva/vision/data/__init__.py +2 -2
  16. eva/vision/data/dataloaders/__init__.py +5 -0
  17. eva/vision/data/dataloaders/collate_fn/__init__.py +5 -0
  18. eva/vision/data/dataloaders/collate_fn/collection.py +22 -0
  19. eva/vision/data/datasets/__init__.py +10 -2
  20. eva/vision/data/datasets/classification/__init__.py +9 -0
  21. eva/vision/data/datasets/classification/bach.py +3 -4
  22. eva/vision/data/datasets/classification/bracs.py +111 -0
  23. eva/vision/data/datasets/classification/breakhis.py +209 -0
  24. eva/vision/data/datasets/classification/camelyon16.py +4 -5
  25. eva/vision/data/datasets/classification/crc.py +3 -4
  26. eva/vision/data/datasets/classification/gleason_arvaniti.py +171 -0
  27. eva/vision/data/datasets/classification/mhist.py +3 -4
  28. eva/vision/data/datasets/classification/panda.py +4 -5
  29. eva/vision/data/datasets/classification/patch_camelyon.py +3 -4
  30. eva/vision/data/datasets/classification/unitopatho.py +158 -0
  31. eva/vision/data/datasets/classification/wsi.py +6 -5
  32. eva/vision/data/datasets/segmentation/__init__.py +2 -2
  33. eva/vision/data/datasets/segmentation/_utils.py +47 -0
  34. eva/vision/data/datasets/segmentation/bcss.py +7 -8
  35. eva/vision/data/datasets/segmentation/btcv.py +236 -0
  36. eva/vision/data/datasets/segmentation/consep.py +6 -7
  37. eva/vision/data/datasets/segmentation/embeddings.py +2 -2
  38. eva/vision/data/datasets/segmentation/lits.py +9 -8
  39. eva/vision/data/datasets/segmentation/lits_balanced.py +2 -1
  40. eva/vision/data/datasets/segmentation/monusac.py +4 -5
  41. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +12 -10
  42. eva/vision/data/datasets/vision.py +95 -4
  43. eva/vision/data/datasets/wsi.py +5 -5
  44. eva/vision/data/transforms/__init__.py +22 -3
  45. eva/vision/data/transforms/common/__init__.py +1 -2
  46. eva/vision/data/transforms/croppad/__init__.py +11 -0
  47. eva/vision/data/transforms/croppad/crop_foreground.py +110 -0
  48. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +109 -0
  49. eva/vision/data/transforms/croppad/spatial_pad.py +67 -0
  50. eva/vision/data/transforms/intensity/__init__.py +11 -0
  51. eva/vision/data/transforms/intensity/rand_scale_intensity.py +59 -0
  52. eva/vision/data/transforms/intensity/rand_shift_intensity.py +55 -0
  53. eva/vision/data/transforms/intensity/scale_intensity_ranged.py +56 -0
  54. eva/vision/data/transforms/spatial/__init__.py +7 -0
  55. eva/vision/data/transforms/spatial/flip.py +72 -0
  56. eva/vision/data/transforms/spatial/rotate.py +53 -0
  57. eva/vision/data/transforms/spatial/spacing.py +69 -0
  58. eva/vision/data/transforms/utility/__init__.py +5 -0
  59. eva/vision/data/transforms/utility/ensure_channel_first.py +51 -0
  60. eva/vision/data/tv_tensors/__init__.py +5 -0
  61. eva/vision/data/tv_tensors/volume.py +61 -0
  62. eva/vision/metrics/segmentation/monai_dice.py +9 -2
  63. eva/vision/models/modules/semantic_segmentation.py +28 -20
  64. eva/vision/models/networks/backbones/__init__.py +9 -2
  65. eva/vision/models/networks/backbones/pathology/__init__.py +11 -2
  66. eva/vision/models/networks/backbones/pathology/bioptimus.py +47 -1
  67. eva/vision/models/networks/backbones/pathology/hkust.py +69 -0
  68. eva/vision/models/networks/backbones/pathology/kaiko.py +18 -0
  69. eva/vision/models/networks/backbones/pathology/mahmood.py +46 -19
  70. eva/vision/models/networks/backbones/radiology/__init__.py +11 -0
  71. eva/vision/models/networks/backbones/radiology/swin_unetr.py +231 -0
  72. eva/vision/models/networks/backbones/radiology/voco.py +75 -0
  73. eva/vision/models/networks/decoders/segmentation/__init__.py +6 -2
  74. eva/vision/models/networks/decoders/segmentation/linear.py +5 -10
  75. eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +8 -1
  76. eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py +104 -0
  77. eva/vision/utils/io/__init__.py +2 -0
  78. eva/vision/utils/io/nifti.py +91 -11
  79. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/METADATA +3 -1
  80. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/RECORD +83 -62
  81. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/WHEEL +1 -1
  82. eva/vision/data/datasets/classification/base.py +0 -96
  83. eva/vision/data/datasets/segmentation/base.py +0 -96
  84. eva/vision/data/transforms/common/resize_and_clamp.py +0 -51
  85. eva/vision/data/transforms/normalization/__init__.py +0 -6
  86. eva/vision/data/transforms/normalization/clamp.py +0 -43
  87. eva/vision/data/transforms/normalization/functional/__init__.py +0 -5
  88. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +0 -28
  89. eva/vision/data/transforms/normalization/rescale_intensity.py +0 -53
  90. eva/vision/metrics/segmentation/BUILD +0 -1
  91. eva/vision/models/networks/backbones/torchhub/__init__.py +0 -5
  92. eva/vision/models/networks/backbones/torchhub/backbones.py +0 -61
  93. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/entry_points.txt +0 -0
  94. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/licenses/LICENSE +0 -0
eva/vision/data/transforms/intensity/rand_shift_intensity.py
@@ -0,0 +1,55 @@
+ """Intensity shifting transform."""
+
+ import functools
+ from typing import Any, Dict
+
+ from monai.transforms.intensity import array as monai_intensity_transforms
+ from torchvision import tv_tensors
+ from torchvision.transforms import v2
+ from typing_extensions import override
+
+ from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+ class RandShiftIntensity(v2.Transform):
+     """Randomly shift intensity with randomly picked offset."""
+
+     def __init__(
+         self,
+         offsets: tuple[float, float] | float,
+         safe: bool = False,
+         prob: float = 0.1,
+         channel_wise: bool = False,
+     ) -> None:
+         """Initializes the transform.
+
+         Args:
+             offsets: Offset range to randomly shift.
+                 If a single number, the offset is picked from (-offsets, offsets).
+             safe: If `True`, performs a safe dtype conversion on intensity overflow.
+                 E.g., without it `[256, -12]` -> `[array(0), array(244)]`; with it,
+                 `[256, -12]` -> `[array(255), array(0)]`.
+             prob: Probability of shift.
+             channel_wise: If True, shift intensity on each channel separately.
+                 For each channel, a random offset will be chosen. Please ensure
+                 that the first dimension represents the channel of the image if True.
+         """
+         super().__init__()
+
+         self._rand_swift_intensity = monai_intensity_transforms.RandShiftIntensity(
+             offsets=offsets,
+             safe=safe,
+             prob=prob,
+             channel_wise=channel_wise,
+         )
+
+     @functools.singledispatchmethod
+     @override
+     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         return inpt
+
+     @_transform.register(tv_tensors.Image)
+     @_transform.register(eva_tv_tensors.Volume)
+     def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
+         inpt_scaled = self._rand_swift_intensity(inpt)
+         return tv_tensors.wrap(inpt_scaled, like=inpt)
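
For orientation, a minimal usage sketch (not part of the diff): the module path is taken from the file list above, and the tensor shape and values are illustrative.

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms import v2

    from eva.vision.data.transforms.intensity.rand_shift_intensity import RandShiftIntensity

    image = tv_tensors.Image(torch.rand(1, 64, 64))          # [C, H, W], values in [0, 1]
    pipeline = v2.Compose([RandShiftIntensity(offsets=0.1, prob=1.0)])
    shifted = pipeline(image)  # non-Image/Volume inputs pass through unchanged
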
eva/vision/data/transforms/intensity/scale_intensity_ranged.py
@@ -0,0 +1,56 @@
+ """Intensity scaling transform."""
+
+ import functools
+ from typing import Any, Dict, Tuple
+
+ from monai.transforms.intensity import array as monai_intensity_transforms
+ from torchvision import tv_tensors
+ from torchvision.transforms import v2
+ from typing_extensions import override
+
+ from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+ class ScaleIntensityRange(v2.Transform):
+     """Intensity scaling transform.
+
+     Scaling from [a_min, a_max] to [b_min, b_max] with clip option.
+
+     When `b_min` or `b_max` are `None`, `scaled_array * (b_max - b_min) + b_min`
+     will be skipped. If `clip=True`, when `b_min`/`b_max` is None, the clipping
+     is not performed on the corresponding edge.
+     """
+
+     def __init__(
+         self,
+         input_range: Tuple[float, float],
+         output_range: Tuple[float, float] | None = None,
+         clip: bool = True,
+     ) -> None:
+         """Initializes the transform.
+
+         Args:
+             input_range: Intensity original range min and max.
+             output_range: Intensity target range min and max.
+             clip: Whether to perform clip after scaling.
+         """
+         super().__init__()
+
+         self._scale_intensity_range = monai_intensity_transforms.ScaleIntensityRange(
+             a_min=input_range[0],
+             a_max=input_range[1],
+             b_min=output_range[0] if output_range else None,
+             b_max=output_range[1] if output_range else None,
+             clip=clip,
+         )
+
+     @functools.singledispatchmethod
+     @override
+     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         return inpt
+
+     @_transform.register(tv_tensors.Image)
+     @_transform.register(eva_tv_tensors.Volume)
+     def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
+         inpt_scaled = self._scale_intensity_range(inpt)
+         return tv_tensors.wrap(inpt_scaled, like=inpt)
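
A minimal usage sketch (not part of the diff); the CT-style intensity window is an arbitrary illustration:

    import torch
    from torchvision import tv_tensors

    from eva.vision.data.transforms.intensity.scale_intensity_ranged import ScaleIntensityRange

    scan = tv_tensors.Image(torch.randint(-1000, 1000, (1, 96, 96)).float())
    rescale = ScaleIntensityRange(input_range=(-175.0, 250.0), output_range=(0.0, 1.0), clip=True)
    normalized = rescale(scan)  # values outside the input range end up clipped to [0, 1]
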
eva/vision/data/transforms/spatial/__init__.py
@@ -0,0 +1,7 @@
+ """Transforms for spatial operations."""
+
+ from eva.vision.data.transforms.spatial.flip import RandFlip
+ from eva.vision.data.transforms.spatial.rotate import RandRotate90
+ from eva.vision.data.transforms.spatial.spacing import Spacing
+
+ __all__ = ["Spacing", "RandFlip", "RandRotate90"]
eva/vision/data/transforms/spatial/flip.py
@@ -0,0 +1,72 @@
+ """Flip transforms."""
+
+ import functools
+ from typing import Any, Dict, List, Sequence
+
+ import torch
+ from monai.transforms.spatial import array as monai_spatial_transforms
+ from torchvision import tv_tensors
+ from torchvision.transforms import v2
+ from typing_extensions import override
+
+ from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+ class RandFlip(v2.Transform):
+     """Randomly flips the image along axes."""
+
+     def __init__(
+         self,
+         prob: float = 0.1,
+         spatial_axes: Sequence[int] | int | None = None,
+         apply_per_axis: bool = True,
+     ) -> None:
+         """Initializes the transform.
+
+         Args:
+             prob: Probability of flipping.
+             spatial_axes: Spatial axes along which to flip over. Default is None.
+             apply_per_axis: If True, applies a random flip transform to each
+                 axis individually (if `spatial_axes` is a sequence of multiple axes).
+                 If False, applies a single random flip transform to all axes at once.
+         """
+         super().__init__()
+
+         if apply_per_axis:
+             if not isinstance(spatial_axes, (list, tuple)):
+                 raise ValueError(
+                     "`spatial_axes` is expected to be a sequence when `apply_per_axis` "
+                     f"is enabled, got {type(spatial_axes)}"
+                 )
+             self._flips = [
+                 monai_spatial_transforms.RandFlip(prob=prob, spatial_axis=axis)
+                 for axis in spatial_axes
+             ]
+         else:
+             self._flips = [monai_spatial_transforms.RandFlip(prob=prob, spatial_axis=spatial_axes)]
+
+     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+         for flip in self._flips:
+             flip.randomize(None)
+         return {}
+
+     @functools.singledispatchmethod
+     @override
+     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         return inpt
+
+     @_transform.register(tv_tensors.Image)
+     @_transform.register(eva_tv_tensors.Volume)
+     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         inpt_flipped = self._apply_flips(inpt)
+         return tv_tensors.wrap(inpt_flipped, like=inpt)
+
+     @_transform.register(tv_tensors.Mask)
+     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         inpt_flipped = torch.tensor(self._apply_flips(inpt), dtype=torch.long)
+         return tv_tensors.wrap(inpt_flipped, like=inpt)
+
+     def _apply_flips(self, inpt: Any) -> Any:
+         for flip in self._flips:
+             inpt = flip(img=inpt, randomize=False)
+         return inpt
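
A minimal usage sketch (not part of the diff) showing that one call applies the same randomly sampled flips to an image and its mask; shapes are illustrative:

    import torch
    from torchvision import tv_tensors

    from eva.vision.data.transforms.spatial import RandFlip

    image = tv_tensors.Image(torch.rand(1, 32, 64, 64))                  # [C, T, H, W]
    mask = tv_tensors.Mask(torch.zeros(1, 32, 64, 64, dtype=torch.long))
    flip = RandFlip(prob=0.5, spatial_axes=(1, 2), apply_per_axis=True)
    flipped_image, flipped_mask = flip(image, mask)  # flips are sampled once per call
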
eva/vision/data/transforms/spatial/rotate.py
@@ -0,0 +1,53 @@
+ """Rotation transforms."""
+
+ import functools
+ from typing import Any, Dict, List
+
+ from monai.transforms.spatial import array as monai_spatial_transforms
+ from torchvision import tv_tensors
+ from torchvision.transforms import v2
+ from typing_extensions import override
+
+ from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+ class RandRotate90(v2.Transform):
+     """Rotate input tensors by 90 degrees."""
+
+     def __init__(
+         self,
+         prob: float = 0.1,
+         max_k: int = 3,
+         spatial_axes: tuple[int, int] = (1, 2),
+     ) -> None:
+         """Initializes the transform.
+
+         Args:
+             prob: probability of rotating.
+                 (Default 0.1, with 10% probability it returns a rotated array)
+             max_k: number of rotations will be sampled from `np.random.randint(max_k) + 1`.
+             spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes.
+                 Default: (1, 2), so for [C, T, H, W] will rotate along (H, W) plane (MONAI ignores
+                 the first C dimension).
+         """
+         super().__init__()
+
+         self._rotate = monai_spatial_transforms.RandRotate90(
+             prob=prob, max_k=max_k, spatial_axes=spatial_axes
+         )
+
+     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+         self._rotate.randomize()
+         return {}
+
+     @functools.singledispatchmethod
+     @override
+     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         return inpt
+
+     @_transform.register(tv_tensors.Image)
+     @_transform.register(eva_tv_tensors.Volume)
+     @_transform.register(tv_tensors.Mask)
+     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         inpt_rotated = self._rotate(img=inpt, randomize=False)
+         return tv_tensors.wrap(inpt_rotated, like=inpt)
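
Usage mirrors `RandFlip`; a minimal sketch with illustrative shapes (not part of the diff):

    import torch
    from torchvision import tv_tensors

    from eva.vision.data.transforms.spatial import RandRotate90

    image = tv_tensors.Image(torch.rand(1, 16, 64, 64))                  # [C, T, H, W]
    mask = tv_tensors.Mask(torch.zeros(1, 16, 64, 64, dtype=torch.long))
    rotate = RandRotate90(prob=1.0, max_k=3, spatial_axes=(1, 2))
    rotated_image, rotated_mask = rotate(image, mask)  # same k and plane for both inputs
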
eva/vision/data/transforms/spatial/spacing.py
@@ -0,0 +1,69 @@
+ """Spacing resample transform."""
+
+ import functools
+ from typing import Any, Dict, List, Sequence
+
+ import numpy as np
+ import torch
+ from monai.data import meta_tensor
+ from monai.transforms.spatial import array as monai_spatial_transforms
+ from torchvision import tv_tensors
+ from torchvision.transforms import v2
+ from typing_extensions import override
+
+ from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+ class Spacing(v2.Transform):
+     """Resample input image into the specified `pixdim`.
+
+     - Expects tensors of shape `[C, T, H, W]`.
+     """
+
+     def __init__(
+         self,
+         pixdim: Sequence[float] | float | np.ndarray,
+     ) -> None:
+         """Initializes the transform.
+
+         Args:
+             pixdim: Output voxel spacing. If a single number is provided,
+                 it is used for the first dimension. Items of the pixdim
+                 sequence map to the spatial dimensions of the input image; if
+                 the pixdim sequence is longer than the image's spatial
+                 dimensions, the extra items are ignored, and if shorter, it is
+                 padded with the last value. For example, for a 3D image a
+                 pixdim of [1.0, 2.0] is padded to [1.0, 2.0, 2.0]. If any
+                 components of `pixdim` are non-positive, the transform uses
+                 the corresponding components of the original pixdim, which is
+                 computed from the `affine` matrix of the input image.
+         """
+         super().__init__()
+
+         self._spacing = monai_spatial_transforms.Spacing(pixdim=pixdim, recompute_affine=True)
+         self._affine = None
+
+     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+         self._affine = next(
+             inpt.affine for inpt in flat_inputs if isinstance(inpt, eva_tv_tensors.Volume)
+         )
+         return {}
+
+     @functools.singledispatchmethod
+     @override
+     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         return inpt
+
+     @_transform.register(eva_tv_tensors.Volume)
+     def _(self, inpt: eva_tv_tensors.Volume, params: Dict[str, Any]) -> Any:
+         inpt_spacing = self._spacing(inpt.to_meta_tensor(), mode="bilinear")
+         if not isinstance(inpt_spacing, meta_tensor.MetaTensor):
+             raise ValueError(f"Expected MetaTensor, got {type(inpt_spacing)}")
+         return eva_tv_tensors.Volume.from_meta_tensor(inpt_spacing)
+
+     @_transform.register(tv_tensors.Mask)
+     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         inpt_spacing = self._spacing(
+             meta_tensor.MetaTensor(inpt, affine=self._affine), mode="nearest"
+         )
+         return tv_tensors.wrap(inpt_spacing.to(dtype=torch.long), like=inpt)
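
A minimal sketch (not part of the diff) of resampling a `Volume`/`Mask` pair; the 3 mm affine and the 1.5 mm target spacing are arbitrary, and the mask branch reuses the affine captured from the volume:

    import torch
    from torchvision import tv_tensors

    from eva.vision.data.transforms.spatial import Spacing
    from eva.vision.data.tv_tensors import Volume

    affine = torch.diag(torch.tensor([3.0, 3.0, 3.0, 1.0]))              # 3 mm voxels
    volume = Volume(torch.rand(1, 32, 64, 64), affine=affine)            # [C, T, H, W]
    mask = tv_tensors.Mask(torch.zeros(1, 32, 64, 64, dtype=torch.long))
    spacing = Spacing(pixdim=(1.5, 1.5, 1.5))
    volume_rs, mask_rs = spacing(volume, mask)  # bilinear for the volume, nearest for the mask
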
eva/vision/data/transforms/utility/__init__.py
@@ -0,0 +1,5 @@
+ """Transforms for utility operations."""
+
+ from eva.vision.data.transforms.utility.ensure_channel_first import EnsureChannelFirst
+
+ __all__ = ["EnsureChannelFirst"]
eva/vision/data/transforms/utility/ensure_channel_first.py
@@ -0,0 +1,51 @@
+ """Adjust or add the channel dimension of input data to ensure `channel_first` shape."""
+
+ import functools
+ from typing import Any, Dict
+
+ from monai.transforms.utility import array as monai_utility_transforms
+ from torchvision import tv_tensors
+ from torchvision.transforms import v2
+ from typing_extensions import override
+
+ from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+ class EnsureChannelFirst(v2.Transform):
+     """Adjust or add the channel dimension of input data to ensure `channel_first` shape."""
+
+     def __init__(
+         self,
+         strict_check: bool = True,
+         channel_dim: None | str | int = None,
+     ) -> None:
+         """Initializes the transform.
+
+         Args:
+             strict_check: whether to raise an error when the meta information is insufficient.
+             channel_dim: This argument can be used to specify the original channel dimension
+                 (integer) of the input array.
+                 It overrides the `original_channel_dim` from provided MetaTensor input.
+                 If the input array doesn't have a channel dim, this value should be
+                 ``'no_channel'``.
+                 If this is set to `None`, this class relies on `img` or `meta_dict` to provide
+                 the channel dimension.
+         """
+         super().__init__()
+
+         self._ensure_channel_first = monai_utility_transforms.EnsureChannelFirst(
+             strict_check=strict_check,
+             channel_dim=channel_dim,
+         )
+
+     @functools.singledispatchmethod
+     @override
+     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         return inpt
+
+     @_transform.register(tv_tensors.Image)
+     @_transform.register(eva_tv_tensors.Volume)
+     @_transform.register(tv_tensors.Mask)
+     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         inpt_channel_first = self._ensure_channel_first(inpt)
+         return tv_tensors.wrap(inpt_channel_first, like=inpt)
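
A minimal sketch (not part of the diff); the channel-last input and the explicit `channel_dim=-1` are illustrative and bypass metadata-based inference:

    import torch
    from torchvision import tv_tensors

    from eva.vision.data.transforms.utility import EnsureChannelFirst

    hwc_image = tv_tensors.Image(torch.rand(64, 64, 3))          # channel-last [H, W, C]
    ensure = EnsureChannelFirst(strict_check=False, channel_dim=-1)
    chw_image = ensure(hwc_image)                                # -> [3, 64, 64], channel-first
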
eva/vision/data/tv_tensors/__init__.py
@@ -0,0 +1,5 @@
+ """Custom `tv_tensors` types for torchvision."""
+
+ from eva.vision.data.tv_tensors.volume import Volume
+
+ __all__ = ["Volume"]
eva/vision/data/tv_tensors/volume.py
@@ -0,0 +1,61 @@
+ """Custom `tv_tensors` type for 3D Volumes."""
+
+ from typing import Any, Dict, Optional, Union
+
+ import torch
+ from monai.data import meta_tensor
+ from torchvision import tv_tensors
+ from typing_extensions import override
+
+
+ class Volume(tv_tensors.Video):
+     """:class:`torchvision.TVTensor` subclass for 3D volumes.
+
+     - Adds optional metadata and affine matrix to the tensor.
+     - Expects tensors to be of shape `[..., T, C, H, W]`.
+
+     Args:
+         data: Any data that can be turned into a tensor with :func:`torch.as_tensor`.
+         affine: Affine matrix of the volume. Expected to be of shape `[4, 4]`, and
+             columns to correspond to [T, H, W, (translation)] dimensions. Note that
+             `nibabel` by default uses [H, W, T, (translation)] order for affine matrices.
+         metadata: Metadata associated with the volume.
+         dtype: Desired data type. If omitted, will be inferred from `data`.
+         device: Desired device.
+         requires_grad: Whether autograd should record operations.
+     """
+
+     @override
+     def __new__(
+         cls,
+         data: Any,
+         affine: torch.Tensor | None = None,
+         metadata: Dict[str, Any] | None = None,
+         dtype: Optional[torch.dtype] = None,
+         device: Optional[Union[torch.device, str, int]] = None,
+         requires_grad: Optional[bool] = None,
+     ) -> "Volume":
+         cls.affine = affine
+         cls.metadata = metadata
+
+         return super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)  # type: ignore
+
+     @classmethod
+     def from_meta_tensor(cls, meta_tensor: meta_tensor.MetaTensor) -> "Volume":
+         """Creates an instance from a :class:`monai.data.meta_tensor.MetaTensor`."""
+         return cls(
+             meta_tensor.data,
+             affine=meta_tensor.affine,
+             metadata=meta_tensor.meta,
+             dtype=meta_tensor.dtype,
+             device=meta_tensor.device,
+             requires_grad=meta_tensor.requires_grad,
+         )  # type: ignore
+
+     def to_meta_tensor(self) -> meta_tensor.MetaTensor:
+         """Converts the volume to a :class:`monai.data.meta_tensor.MetaTensor`."""
+         return meta_tensor.MetaTensor(self, affine=self.affine, meta=self.metadata)
+
+     def __repr__(self, *, tensor_contents: Any = None) -> str:
+         """Returns the string representation of the object."""
+         return self._make_repr()
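
A minimal sketch (not part of the diff) of the round trip between `Volume` and MONAI's `MetaTensor`; the identity affine and the metadata dict are placeholders:

    import torch

    from eva.vision.data.tv_tensors import Volume

    affine = torch.eye(4)                                        # identity spacing/orientation
    volume = Volume(torch.rand(1, 16, 64, 64), affine=affine, metadata={"case_id": "demo"})
    meta = volume.to_meta_tensor()                               # MONAI MetaTensor with affine/meta
    restored = Volume.from_meta_tensor(meta)                     # back to the tv_tensor type
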
eva/vision/metrics/segmentation/monai_dice.py
@@ -1,5 +1,7 @@
  """Wrapper for dice score metric from MONAI."""
 
+ from typing import Literal
+
  from monai.metrics.meandice import DiceMetric
  from typing_extensions import override
 
@@ -14,6 +16,7 @@ class MonaiDiceScore(wrappers.MonaiMetricWrapper):
          self,
          num_classes: int,
          include_background: bool = True,
+         input_format: Literal["one-hot", "index"] = "index",
          reduction: str = "mean",
          ignore_index: int | None = None,
          **kwargs,
@@ -24,6 +27,8 @@ class MonaiDiceScore(wrappers.MonaiMetricWrapper):
              num_classes: The number of classes in the dataset.
              include_background: Whether to include the background class in the computation.
              reduction: The method to reduce the dice score. Options are `"mean"`, `"sum"`, `"none"`.
+             input_format: Choose between "one-hot" for one-hot encoded tensors or "index"
+                 for index tensors.
              ignore_index: Integer specifying a target class to ignore. If given, this class
                  index does not contribute to the returned score.
              kwargs: Additional keyword arguments for instantiating monai's `DiceMetric` class.
@@ -40,11 +45,13 @@ class MonaiDiceScore(wrappers.MonaiMetricWrapper):
          self.reduction = reduction
          self.orig_num_classes = num_classes
          self.ignore_index = ignore_index
+         self.input_format = input_format
 
      @override
      def update(self, preds, target):
-         preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes)
-         target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes)
+         if self.input_format == "index":
+             preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes)
+             target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes)
          if self.ignore_index is not None:
              preds, target = _utils.apply_ignore_index(preds, target, self.ignore_index)
          return super().update(preds, target)
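
A minimal sketch (not part of the diff) of the new `input_format` switch, assuming the wrapper keeps a torchmetrics-style `update`/`compute` interface; shapes are illustrative:

    import torch

    from eva.vision.metrics.segmentation.monai_dice import MonaiDiceScore

    dice = MonaiDiceScore(num_classes=3, input_format="index")   # default: class-index tensors
    preds = torch.randint(0, 3, (2, 64, 64))
    target = torch.randint(0, 3, (2, 64, 64))
    dice.update(preds, target)                                   # converted to one-hot internally
    score = dice.compute()
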
eva/vision/models/modules/semantic_segmentation.py
@@ -1,10 +1,11 @@
  """"Neural Network Semantic Segmentation Module."""
 
- from typing import Any, Callable, Dict, Iterable, List, Tuple
+ from typing import Any, Callable, Dict, Iterable, List
 
  import torch
  from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
  from lightning.pytorch.utilities.types import STEP_OUTPUT
+ from monai.inferers.inferer import Inferer
  from torch import nn, optim
  from torch.optim import lr_scheduler
  from typing_extensions import override
@@ -15,6 +16,7 @@ from eva.core.models.modules.typings import INPUT_BATCH, INPUT_TENSOR_BATCH
  from eva.core.models.modules.utils import batch_postprocess, grad, submodule_state_dict
  from eva.core.utils import parser
  from eva.vision.models.networks import decoders
+ from eva.vision.models.networks.decoders import segmentation
  from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs
 
 
@@ -23,10 +25,11 @@ class SemanticSegmentationModule(module.ModelModule):
 
      def __init__(
          self,
-         decoder: decoders.Decoder,
+         decoder: decoders.Decoder | nn.Module,
          criterion: Callable[..., torch.Tensor],
          encoder: Dict[str, Any] | Callable[[torch.Tensor], List[torch.Tensor]] | None = None,
          lr_multiplier_encoder: float = 0.0,
+         inferer: Inferer | None = None,
          optimizer: OptimizerCallable = optim.AdamW,
          lr_scheduler: LRSchedulerCallable = lr_scheduler.ConstantLR,
          metrics: metrics_lib.MetricsSchema | None = None,
@@ -44,6 +47,8 @@
                  during the `configure_model` step.
              lr_multiplier_encoder: The learning rate multiplier for the
                  encoder parameters. If `0`, it will freeze the encoder.
+             inferer: An optional MONAI `Inferer` for inference
+                 postprocess during evaluation.
              optimizer: The optimizer to use.
              lr_scheduler: The learning rate scheduler to use.
              metrics: The metric groups to track.
@@ -62,6 +67,7 @@
          self.optimizer = optimizer
          self.lr_scheduler = lr_scheduler
          self.save_decoder_only = save_decoder_only
+         self.inferer = inferer
 
      @override
      def configure_model(self) -> None:
@@ -104,25 +110,15 @@
      @override
      def forward(
          self,
-         inputs: torch.Tensor,
-         to_size: Tuple[int, int] | None = None,
+         tensor: torch.Tensor,
          *args: Any,
          **kwargs: Any,
      ) -> torch.Tensor:
-         """Maps the input tensor (image tensor or embeddings) to masks.
-
-         If `inputs` is image tensor, then the `self.encoder`
-         should be implemented, otherwise it will be interpreted
-         as embeddings, where the `to_size` should be given.
-         """
-         if self.encoder is None and to_size is None:
-             raise ValueError(
-                 "Please provide the expected `to_size` that the "
-                 "decoder should map the embeddings (`inputs`) to."
-             )
-         features = self.encoder(inputs) if self.encoder else inputs
-         decoder_inputs = DecoderInputs(features, to_size or inputs.shape[-2:], inputs)  # type: ignore
-         return self.decoder(decoder_inputs)
+         return (
+             self.inferer(tensor, network=self._forward_networks)
+             if self.inferer is not None and not self.training
+             else self._forward_networks(tensor)
+         )
 
      @override
      def training_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
@@ -137,7 +133,9 @@
          return self._batch_step(batch)
 
      @override
-     def predict_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> torch.Tensor:
+     def predict_step(
+         self, batch: INPUT_BATCH, *args: Any, **kwargs: Any
+     ) -> torch.Tensor | List[torch.Tensor]:
          tensor = INPUT_BATCH(*batch).data
          return self.encoder(tensor) if isinstance(self.encoder, nn.Module) else tensor
 
@@ -170,7 +168,7 @@
              The batch step output.
          """
          data, targets, metadata = INPUT_TENSOR_BATCH(*batch)
-         predictions = self(data, to_size=targets.shape[-2:])
+         predictions = self(data)
          loss = self.criterion(predictions, targets)
          return {
              "loss": loss,
@@ -178,3 +176,13 @@
              "predictions": predictions,
              "metadata": metadata,
          }
+
+     def _forward_networks(self, tensor: torch.Tensor) -> torch.Tensor:
+         """Passes the input tensor through the encoder and decoder."""
+         features = self.encoder(tensor) if self.encoder else tensor
+         if isinstance(self.decoder, segmentation.Decoder):
+             if not isinstance(features, list):
+                 raise ValueError(f"Expected a list of feature map tensors, got {type(features)}.")
+             image_size = (tensor.shape[-2], tensor.shape[-1])
+             return self.decoder(DecoderInputs(features, image_size, tensor))
+         return self.decoder(features)
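
A configuration sketch (not part of the diff) for the new `inferer` argument; `SlidingWindowInferer` is a standard MONAI inferer, while `my_decoder`, `my_loss`, and `my_encoder` are placeholders:

    from monai.inferers import SlidingWindowInferer

    inferer = SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=2, overlap=0.25)

    # module = SemanticSegmentationModule(
    #     decoder=my_decoder,       # any nn.Module is accepted now, not only decoders.Decoder
    #     criterion=my_loss,
    #     encoder=my_encoder,
    #     inferer=inferer,          # used in forward() only when the module is not training
    # )
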
eva/vision/models/networks/backbones/__init__.py
@@ -1,6 +1,13 @@
  """Vision Model Backbones API."""
 
- from eva.vision.models.networks.backbones import pathology, timm, torchhub, universal
+ from eva.vision.models.networks.backbones import pathology, radiology, timm, universal
  from eva.vision.models.networks.backbones.registry import BackboneModelRegistry, register_model
 
- __all__ = ["pathology", "timm", "torchhub", "universal", "BackboneModelRegistry", "register_model"]
+ __all__ = [
+     "radiology",
+     "pathology",
+     "timm",
+     "universal",
+     "BackboneModelRegistry",
+     "register_model",
+ ]
eva/vision/models/networks/backbones/pathology/__init__.py
@@ -1,9 +1,14 @@
  """Vision Pathology Model Backbones API."""
 
- from eva.vision.models.networks.backbones.pathology.bioptimus import bioptimus_h_optimus_0
+ from eva.vision.models.networks.backbones.pathology.bioptimus import (
+     bioptimus_h0_mini,
+     bioptimus_h_optimus_0,
+ )
  from eva.vision.models.networks.backbones.pathology.gigapath import prov_gigapath
  from eva.vision.models.networks.backbones.pathology.histai import histai_hibou_b, histai_hibou_l
+ from eva.vision.models.networks.backbones.pathology.hkust import hkust_gpfm
  from eva.vision.models.networks.backbones.pathology.kaiko import (
+     kaiko_midnight_12k,
      kaiko_vitb8,
      kaiko_vitb16,
      kaiko_vitl14,
@@ -11,11 +16,12 @@ from eva.vision.models.networks.backbones.pathology.kaiko import (
      kaiko_vits16,
  )
  from eva.vision.models.networks.backbones.pathology.lunit import lunit_vits8, lunit_vits16
- from eva.vision.models.networks.backbones.pathology.mahmood import mahmood_uni
+ from eva.vision.models.networks.backbones.pathology.mahmood import mahmood_uni, mahmood_uni2_h
  from eva.vision.models.networks.backbones.pathology.owkin import owkin_phikon, owkin_phikon_v2
  from eva.vision.models.networks.backbones.pathology.paige import paige_virchow2
 
  __all__ = [
+     "kaiko_midnight_12k",
      "kaiko_vitb16",
      "kaiko_vitb8",
      "kaiko_vitl14",
@@ -26,9 +32,12 @@ __all__ = [
      "lunit_vits16",
      "lunit_vits8",
      "mahmood_uni",
+     "mahmood_uni2_h",
      "bioptimus_h_optimus_0",
+     "bioptimus_h0_mini",
      "prov_gigapath",
      "histai_hibou_b",
      "histai_hibou_l",
      "paige_virchow2",
+     "hkust_gpfm",
  ]