kaiko-eva 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of kaiko-eva might be problematic.

Files changed (85)
  1. eva/core/data/datasets/base.py +7 -2
  2. eva/core/models/modules/head.py +4 -2
  3. eva/core/models/modules/typings.py +2 -2
  4. eva/core/models/transforms/__init__.py +2 -1
  5. eva/core/models/transforms/as_discrete.py +57 -0
  6. eva/core/models/wrappers/_utils.py +121 -1
  7. eva/core/trainers/_recorder.py +4 -1
  8. eva/core/utils/suppress_logs.py +28 -0
  9. eva/vision/data/__init__.py +2 -2
  10. eva/vision/data/dataloaders/__init__.py +5 -0
  11. eva/vision/data/dataloaders/collate_fn/__init__.py +5 -0
  12. eva/vision/data/dataloaders/collate_fn/collection.py +22 -0
  13. eva/vision/data/datasets/__init__.py +2 -2
  14. eva/vision/data/datasets/classification/bach.py +3 -4
  15. eva/vision/data/datasets/classification/bracs.py +3 -4
  16. eva/vision/data/datasets/classification/breakhis.py +3 -4
  17. eva/vision/data/datasets/classification/camelyon16.py +4 -5
  18. eva/vision/data/datasets/classification/crc.py +3 -4
  19. eva/vision/data/datasets/classification/gleason_arvaniti.py +3 -4
  20. eva/vision/data/datasets/classification/mhist.py +3 -4
  21. eva/vision/data/datasets/classification/panda.py +4 -5
  22. eva/vision/data/datasets/classification/patch_camelyon.py +3 -4
  23. eva/vision/data/datasets/classification/unitopatho.py +3 -4
  24. eva/vision/data/datasets/classification/wsi.py +6 -5
  25. eva/vision/data/datasets/segmentation/__init__.py +2 -2
  26. eva/vision/data/datasets/segmentation/_utils.py +47 -0
  27. eva/vision/data/datasets/segmentation/bcss.py +7 -8
  28. eva/vision/data/datasets/segmentation/btcv.py +236 -0
  29. eva/vision/data/datasets/segmentation/consep.py +6 -7
  30. eva/vision/data/datasets/segmentation/lits.py +9 -8
  31. eva/vision/data/datasets/segmentation/lits_balanced.py +2 -1
  32. eva/vision/data/datasets/segmentation/monusac.py +4 -5
  33. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +12 -10
  34. eva/vision/data/datasets/vision.py +95 -4
  35. eva/vision/data/datasets/wsi.py +5 -5
  36. eva/vision/data/transforms/__init__.py +22 -3
  37. eva/vision/data/transforms/common/__init__.py +1 -2
  38. eva/vision/data/transforms/croppad/__init__.py +11 -0
  39. eva/vision/data/transforms/croppad/crop_foreground.py +110 -0
  40. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +109 -0
  41. eva/vision/data/transforms/croppad/spatial_pad.py +67 -0
  42. eva/vision/data/transforms/intensity/__init__.py +11 -0
  43. eva/vision/data/transforms/intensity/rand_scale_intensity.py +59 -0
  44. eva/vision/data/transforms/intensity/rand_shift_intensity.py +55 -0
  45. eva/vision/data/transforms/intensity/scale_intensity_ranged.py +56 -0
  46. eva/vision/data/transforms/spatial/__init__.py +7 -0
  47. eva/vision/data/transforms/spatial/flip.py +72 -0
  48. eva/vision/data/transforms/spatial/rotate.py +53 -0
  49. eva/vision/data/transforms/spatial/spacing.py +69 -0
  50. eva/vision/data/transforms/utility/__init__.py +5 -0
  51. eva/vision/data/transforms/utility/ensure_channel_first.py +51 -0
  52. eva/vision/data/tv_tensors/__init__.py +5 -0
  53. eva/vision/data/tv_tensors/volume.py +61 -0
  54. eva/vision/metrics/segmentation/monai_dice.py +9 -2
  55. eva/vision/models/modules/semantic_segmentation.py +32 -19
  56. eva/vision/models/networks/backbones/__init__.py +9 -2
  57. eva/vision/models/networks/backbones/pathology/__init__.py +11 -2
  58. eva/vision/models/networks/backbones/pathology/bioptimus.py +47 -1
  59. eva/vision/models/networks/backbones/pathology/hkust.py +69 -0
  60. eva/vision/models/networks/backbones/pathology/kaiko.py +18 -0
  61. eva/vision/models/networks/backbones/radiology/__init__.py +11 -0
  62. eva/vision/models/networks/backbones/radiology/swin_unetr.py +231 -0
  63. eva/vision/models/networks/backbones/radiology/voco.py +75 -0
  64. eva/vision/models/networks/decoders/segmentation/__init__.py +6 -2
  65. eva/vision/models/networks/decoders/segmentation/linear.py +5 -10
  66. eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +8 -1
  67. eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py +104 -0
  68. eva/vision/utils/io/__init__.py +2 -0
  69. eva/vision/utils/io/nifti.py +91 -11
  70. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/METADATA +16 -12
  71. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/RECORD +74 -58
  72. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/WHEEL +1 -1
  73. eva/vision/data/datasets/classification/base.py +0 -96
  74. eva/vision/data/datasets/segmentation/base.py +0 -96
  75. eva/vision/data/transforms/common/resize_and_clamp.py +0 -51
  76. eva/vision/data/transforms/normalization/__init__.py +0 -6
  77. eva/vision/data/transforms/normalization/clamp.py +0 -43
  78. eva/vision/data/transforms/normalization/functional/__init__.py +0 -5
  79. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +0 -28
  80. eva/vision/data/transforms/normalization/rescale_intensity.py +0 -53
  81. eva/vision/metrics/segmentation/BUILD +0 -1
  82. eva/vision/models/networks/backbones/torchhub/__init__.py +0 -5
  83. eva/vision/models/networks/backbones/torchhub/backbones.py +0 -61
  84. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/entry_points.txt +0 -0
  85. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/licenses/LICENSE +0 -0
eva/vision/data/transforms/spatial/rotate.py
@@ -0,0 +1,53 @@
+ """Rotation transforms."""
+
+ import functools
+ from typing import Any, Dict, List
+
+ from monai.transforms.spatial import array as monai_spatial_transforms
+ from torchvision import tv_tensors
+ from torchvision.transforms import v2
+ from typing_extensions import override
+
+ from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+ class RandRotate90(v2.Transform):
+     """Rotate input tensors by 90 degrees."""
+
+     def __init__(
+         self,
+         prob: float = 0.1,
+         max_k: int = 3,
+         spatial_axes: tuple[int, int] = (1, 2),
+     ) -> None:
+         """Initializes the transform.
+
+         Args:
+             prob: probability of rotating.
+                 (Default 0.1, with 10% probability it returns a rotated array)
+             max_k: number of rotations will be sampled from `np.random.randint(max_k) + 1`.
+             spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes.
+                 Default: (1, 2), so for [C, T, H, W] will rotate along (H, W) plane (MONAI ignores
+                 the first C dimension).
+         """
+         super().__init__()
+
+         self._rotate = monai_spatial_transforms.RandRotate90(
+             prob=prob, max_k=max_k, spatial_axes=spatial_axes
+         )
+
+     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+         self._rotate.randomize()
+         return {}
+
+     @functools.singledispatchmethod
+     @override
+     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         return inpt
+
+     @_transform.register(tv_tensors.Image)
+     @_transform.register(eva_tv_tensors.Volume)
+     @_transform.register(tv_tensors.Mask)
+     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         inpt_rotated = self._rotate(img=inpt, randomize=False)
+         return tv_tensors.wrap(inpt_rotated, like=inpt)
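A minimal usage sketch of the new transform (not part of the release): it composes like any torchvision v2 transform, and the [C, T, H, W] input layout follows the docstring above.

    import torch
    from torchvision import tv_tensors

    from eva.vision.data.transforms.spatial.rotate import RandRotate90

    volume = tv_tensors.Image(torch.rand(1, 8, 64, 64))  # [C, T, H, W]
    rotate = RandRotate90(prob=1.0, max_k=3, spatial_axes=(1, 2))
    rotated = rotate(volume)  # rotated by a random multiple of 90 degrees in the (H, W) plane
    assert isinstance(rotated, tv_tensors.Image)  # type preserved via tv_tensors.wrap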
eva/vision/data/transforms/spatial/spacing.py
@@ -0,0 +1,69 @@
+ """Spacing resample transform."""
+
+ import functools
+ from typing import Any, Dict, List, Sequence
+
+ import numpy as np
+ import torch
+ from monai.data import meta_tensor
+ from monai.transforms.spatial import array as monai_spatial_transforms
+ from torchvision import tv_tensors
+ from torchvision.transforms import v2
+ from typing_extensions import override
+
+ from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+ class Spacing(v2.Transform):
+     """Resample input image into the specified `pixdim`.
+
+     - Expects tensors of shape `[C, T, H, W]`.
+     """
+
+     def __init__(
+         self,
+         pixdim: Sequence[float] | float | np.ndarray,
+     ) -> None:
+         """Initializes the transform.
+
+         Args:
+             pixdim: output voxel spacing. if providing a single number,
+                 will use it for the first dimension. Items of the pixdim
+                 sequence map to the spatial dimensions of input image, if
+                 length of pixdim sequence is longer than image spatial
+                 dimensions, will ignore the longer part, if shorter, will
+                 pad with the last value. For example, for 3D image if pixdim
+                 is [1.0, 2.0] it will be padded to [1.0, 2.0, 2.0] if the
+                 components of the `pixdim` are non-positive values, the
+                 transform will use the corresponding components of the original
+                 pixdim, which is computed from the `affine` matrix of input image.
+         """
+         super().__init__()
+
+         self._spacing = monai_spatial_transforms.Spacing(pixdim=pixdim, recompute_affine=True)
+         self._affine = None
+
+     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+         self._affine = next(
+             inpt.affine for inpt in flat_inputs if isinstance(inpt, eva_tv_tensors.Volume)
+         )
+         return {}
+
+     @functools.singledispatchmethod
+     @override
+     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         return inpt
+
+     @_transform.register(eva_tv_tensors.Volume)
+     def _(self, inpt: eva_tv_tensors.Volume, params: Dict[str, Any]) -> Any:
+         inpt_spacing = self._spacing(inpt.to_meta_tensor(), mode="bilinear")
+         if not isinstance(inpt_spacing, meta_tensor.MetaTensor):
+             raise ValueError(f"Expected MetaTensor, got {type(inpt_spacing)}")
+         return eva_tv_tensors.Volume.from_meta_tensor(inpt_spacing)
+
+     @_transform.register(tv_tensors.Mask)
+     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         inpt_spacing = self._spacing(
+             meta_tensor.MetaTensor(inpt, affine=self._affine), mode="nearest"
+         )
+         return tv_tensors.wrap(inpt_spacing.to(dtype=torch.long), like=inpt)
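A sketch of resampling the Volume type (introduced later in this diff) to a coarser voxel spacing; the identity affine (unit spacing on every axis) is an assumption for illustration.

    import torch

    from eva.vision.data import tv_tensors as eva_tv_tensors
    from eva.vision.data.transforms.spatial.spacing import Spacing

    volume = eva_tv_tensors.Volume(
        torch.rand(1, 16, 64, 64),  # [C, T, H, W]
        affine=torch.eye(4, dtype=torch.float64),  # assumes 1.0 mm spacing per axis
    )
    resampled = Spacing(pixdim=(2.0, 2.0, 2.0))(volume)
    print(resampled.shape)  # roughly [1, 8, 32, 32]: each spatial dim halved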
eva/vision/data/transforms/utility/__init__.py
@@ -0,0 +1,5 @@
+ """Transforms for utility operations."""
+
+ from eva.vision.data.transforms.utility.ensure_channel_first import EnsureChannelFirst
+
+ __all__ = ["EnsureChannelFirst"]
eva/vision/data/transforms/utility/ensure_channel_first.py
@@ -0,0 +1,51 @@
+ """Adjust or add the channel dimension of input data to ensure `channel_first` shape."""
+
+ import functools
+ from typing import Any, Dict
+
+ from monai.transforms.utility import array as monai_utility_transforms
+ from torchvision import tv_tensors
+ from torchvision.transforms import v2
+ from typing_extensions import override
+
+ from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+ class EnsureChannelFirst(v2.Transform):
+     """Adjust or add the channel dimension of input data to ensure `channel_first` shape."""
+
+     def __init__(
+         self,
+         strict_check: bool = True,
+         channel_dim: None | str | int = None,
+     ) -> None:
+         """Initializes the transform.
+
+         Args:
+             strict_check: whether to raise an error when the meta information is insufficient.
+             channel_dim: This argument can be used to specify the original channel dimension
+                 (integer) of the input array.
+                 It overrides the `original_channel_dim` from provided MetaTensor input.
+                 If the input array doesn't have a channel dim, this value should be
+                 ``'no_channel'``.
+                 If this is set to `None`, this class relies on `img` or `meta_dict` to provide
+                 the channel dimension.
+         """
+         super().__init__()
+
+         self._ensure_channel_first = monai_utility_transforms.EnsureChannelFirst(
+             strict_check=strict_check,
+             channel_dim=channel_dim,
+         )
+
+     @functools.singledispatchmethod
+     @override
+     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         return inpt
+
+     @_transform.register(tv_tensors.Image)
+     @_transform.register(eva_tv_tensors.Volume)
+     @_transform.register(tv_tensors.Mask)
+     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         inpt_channel_first = self._ensure_channel_first(inpt)
+         return tv_tensors.wrap(inpt_channel_first, like=inpt)
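A sketch of the no-metadata path, assuming MONAI's ``'no_channel'`` convention for inputs that lack a channel axis.

    import torch
    from torchvision import tv_tensors

    from eva.vision.data.transforms.utility import EnsureChannelFirst

    mask = tv_tensors.Mask(torch.zeros(16, 64, 64, dtype=torch.long))  # [T, H, W], no channel dim
    to_channel_first = EnsureChannelFirst(channel_dim="no_channel")
    print(to_channel_first(mask).shape)  # [1, 16, 64, 64]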
eva/vision/data/tv_tensors/__init__.py
@@ -0,0 +1,5 @@
+ """Custom `tv_tensors` types for torchvision."""
+
+ from eva.vision.data.tv_tensors.volume import Volume
+
+ __all__ = ["Volume"]
eva/vision/data/tv_tensors/volume.py
@@ -0,0 +1,61 @@
+ """Custom `tv_tensors` type for 3D Volumes."""
+
+ from typing import Any, Dict, Optional, Union
+
+ import torch
+ from monai.data import meta_tensor
+ from torchvision import tv_tensors
+ from typing_extensions import override
+
+
+ class Volume(tv_tensors.Video):
+     """:class:`torchvision.TVTensor` subclass for 3D volumes.
+
+     - Adds optional metadata and affine matrix to the tensor.
+     - Expects tensors to be of shape `[..., T, C, H, W]`.
+
+     Args:
+         data: Any data that can be turned into a tensor with :func:`torch.as_tensor`.
+         affine: Affine matrix of the volume. Expected to be of shape `[4, 4]`, and
+             columns to correspond to [T, H, W, (translation)] dimensions. Note that
+             `nibabel` by default uses [H, W, T, (translation)] order for affine matrices.
+         metadata: Metadata associated with the volume.
+         dtype: Desired data type. If omitted, will be inferred from `data`.
+         device: Desired device.
+         requires_grad: Whether autograd should record operations.
+     """
+
+     @override
+     def __new__(
+         cls,
+         data: Any,
+         affine: torch.Tensor | None = None,
+         metadata: Dict[str, Any] | None = None,
+         dtype: Optional[torch.dtype] = None,
+         device: Optional[Union[torch.device, str, int]] = None,
+         requires_grad: Optional[bool] = None,
+     ) -> "Volume":
+         cls.affine = affine
+         cls.metadata = metadata
+
+         return super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)  # type: ignore
+
+     @classmethod
+     def from_meta_tensor(cls, meta_tensor: meta_tensor.MetaTensor) -> "Volume":
+         """Creates an instance from a :class:`monai.data.meta_tensor.MetaTensor`."""
+         return cls(
+             meta_tensor.data,
+             affine=meta_tensor.affine,
+             metadata=meta_tensor.meta,
+             dtype=meta_tensor.dtype,
+             device=meta_tensor.device,
+             requires_grad=meta_tensor.requires_grad,
+         )  # type: ignore
+
+     def to_meta_tensor(self) -> meta_tensor.MetaTensor:
+         """Converts the volume to a :class:`monai.data.meta_tensor.MetaTensor`."""
+         return meta_tensor.MetaTensor(self, affine=self.affine, meta=self.metadata)
+
+     def __repr__(self, *, tensor_contents: Any = None) -> str:
+         """Returns the string representation of the object."""
+         return self._make_repr()
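A round-trip sketch between the new Volume type and MONAI's MetaTensor; the metadata key is hypothetical.

    import torch

    from eva.vision.data.tv_tensors import Volume

    volume = Volume(
        torch.rand(1, 16, 64, 64),
        affine=torch.eye(4),
        metadata={"case_id": "demo"},  # hypothetical metadata entry
    )
    meta = volume.to_meta_tensor()  # MetaTensor carrying affine and metadata
    restored = Volume.from_meta_tensor(meta)  # back to the tv_tensor subclass
    assert torch.equal(volume, restored)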
eva/vision/metrics/segmentation/monai_dice.py
@@ -1,5 +1,7 @@
  """Wrapper for dice score metric from MONAI."""

+ from typing import Literal
+
  from monai.metrics.meandice import DiceMetric
  from typing_extensions import override

@@ -14,6 +16,7 @@ class MonaiDiceScore(wrappers.MonaiMetricWrapper):
          self,
          num_classes: int,
          include_background: bool = True,
+         input_format: Literal["one-hot", "index"] = "index",
          reduction: str = "mean",
          ignore_index: int | None = None,
          **kwargs,
@@ -24,6 +27,8 @@ class MonaiDiceScore(wrappers.MonaiMetricWrapper):
              num_classes: The number of classes in the dataset.
              include_background: Whether to include the background class in the computation.
              reduction: The method to reduce the dice score. Options are `"mean"`, `"sum"`, `"none"`.
+             input_format: Choose between "one-hot" for one-hot encoded tensors or "index"
+                 for index tensors.
              ignore_index: Integer specifying a target class to ignore. If given, this class
                  index does not contribute to the returned score.
              kwargs: Additional keyword arguments for instantiating monai's `DiceMetric` class.
@@ -40,11 +45,13 @@ class MonaiDiceScore(wrappers.MonaiMetricWrapper):
          self.reduction = reduction
          self.orig_num_classes = num_classes
          self.ignore_index = ignore_index
+         self.input_format = input_format

      @override
      def update(self, preds, target):
-         preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes)
-         target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes)
+         if self.input_format == "index":
+             preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes)
+             target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes)
          if self.ignore_index is not None:
              preds, target = _utils.apply_ignore_index(preds, target, self.ignore_index)
          return super().update(preds, target)
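A sketch of the new flag: with input_format="one-hot" the index-to-one-hot conversion is skipped, so pre-encoded [B, C, ...] tensors pass straight through (the import path and the wrapper's compute() behavior are assumptions inferred from the file list above).

    import torch

    from eva.vision.metrics.segmentation.monai_dice import MonaiDiceScore

    metric = MonaiDiceScore(num_classes=3, input_format="one-hot")

    target = torch.nn.functional.one_hot(torch.randint(0, 3, (2, 8, 8)), num_classes=3)
    target = target.permute(0, 3, 1, 2)  # [B, C, H, W] one-hot masks
    metric.update(target, target)  # identical preds and target
    print(metric.compute())  # perfect overlap, dice score of 1.0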
eva/vision/models/modules/semantic_segmentation.py
@@ -1,10 +1,12 @@
  """Neural Network Semantic Segmentation Module."""

+ import functools
  from typing import Any, Callable, Dict, Iterable, List, Tuple

  import torch
  from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
  from lightning.pytorch.utilities.types import STEP_OUTPUT
+ from monai.inferers.inferer import Inferer
  from torch import nn, optim
  from torch.optim import lr_scheduler
  from typing_extensions import override
@@ -15,6 +17,7 @@ from eva.core.models.modules.typings import INPUT_BATCH, INPUT_TENSOR_BATCH
  from eva.core.models.modules.utils import batch_postprocess, grad, submodule_state_dict
  from eva.core.utils import parser
  from eva.vision.models.networks import decoders
+ from eva.vision.models.networks.decoders import segmentation
  from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs


@@ -23,15 +26,17 @@ class SemanticSegmentationModule(module.ModelModule):

      def __init__(
          self,
-         decoder: decoders.Decoder,
+         decoder: decoders.Decoder | nn.Module,
          criterion: Callable[..., torch.Tensor],
          encoder: Dict[str, Any] | Callable[[torch.Tensor], List[torch.Tensor]] | None = None,
          lr_multiplier_encoder: float = 0.0,
+         inferer: Inferer | None = None,
          optimizer: OptimizerCallable = optim.AdamW,
          lr_scheduler: LRSchedulerCallable = lr_scheduler.ConstantLR,
          metrics: metrics_lib.MetricsSchema | None = None,
          postprocess: batch_postprocess.BatchPostProcess | None = None,
          save_decoder_only: bool = True,
+         spatial_dims: int = 2,
      ) -> None:
          """Initializes the neural net head module.

@@ -44,6 +49,8 @@
                  during the `configure_model` step.
              lr_multiplier_encoder: The learning rate multiplier for the
                  encoder parameters. If `0`, it will freeze the encoder.
+             inferer: An optional MONAI `Inferer` for inference
+                 postprocess during evaluation.
              optimizer: The optimizer to use.
              lr_scheduler: The learning rate scheduler to use.
              metrics: The metric groups to track.
@@ -52,6 +59,8 @@
                  predictions and targets.
              save_decoder_only: Whether to save only the decoder during checkpointing. If False,
                  will also save the encoder (not recommended when frozen).
+             spatial_dims: The number of spatial dimensions, 2 for 2D
+                 and 3 for 3D segmentation.
          """
          super().__init__(metrics=metrics, postprocess=postprocess)

@@ -62,6 +71,8 @@
          self.optimizer = optimizer
          self.lr_scheduler = lr_scheduler
          self.save_decoder_only = save_decoder_only
+         self.inferer = inferer
+         self.spatial_dims = spatial_dims

      @override
      def configure_model(self) -> None:
@@ -104,25 +115,16 @@
      @override
      def forward(
          self,
-         inputs: torch.Tensor,
-         to_size: Tuple[int, int] | None = None,
+         tensor: torch.Tensor,
+         to_size: Tuple[int, int],
          *args: Any,
          **kwargs: Any,
      ) -> torch.Tensor:
-         """Maps the input tensor (image tensor or embeddings) to masks.
-
-         If `inputs` is image tensor, then the `self.encoder`
-         should be implemented, otherwise it will be interpreted
-         as embeddings, where the `to_size` should be given.
-         """
-         if self.encoder is None and to_size is None:
-             raise ValueError(
-                 "Please provide the expected `to_size` that the "
-                 "decoder should map the embeddings (`inputs`) to."
-             )
-         features = self.encoder(inputs) if self.encoder else inputs
-         decoder_inputs = DecoderInputs(features, to_size or inputs.shape[-2:], inputs)  # type: ignore
-         return self.decoder(decoder_inputs)
+         return (
+             self.inferer(tensor, network=functools.partial(self._forward_networks, to_size=to_size))
+             if self.inferer is not None and not self.training
+             else self._forward_networks(tensor, to_size=to_size)
+         )

      @override
      def training_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
@@ -137,7 +139,9 @@
          return self._batch_step(batch)

      @override
-     def predict_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> torch.Tensor:
+     def predict_step(
+         self, batch: INPUT_BATCH, *args: Any, **kwargs: Any
+     ) -> torch.Tensor | List[torch.Tensor]:
          tensor = INPUT_BATCH(*batch).data
          return self.encoder(tensor) if isinstance(self.encoder, nn.Module) else tensor

@@ -170,7 +174,7 @@
              The batch step output.
          """
          data, targets, metadata = INPUT_TENSOR_BATCH(*batch)
-         predictions = self(data, to_size=targets.shape[-2:])
+         predictions = self(data, to_size=targets.shape[-self.spatial_dims :])
          loss = self.criterion(predictions, targets)
          return {
              "loss": loss,
@@ -178,3 +182,12 @@
              "predictions": predictions,
              "metadata": metadata,
          }
+
+     def _forward_networks(self, tensor: torch.Tensor, to_size: Tuple[int, int]) -> torch.Tensor:
+         """Passes the input tensor through the encoder and decoder."""
+         features = self.encoder(tensor) if self.encoder else tensor
+         if isinstance(self.decoder, segmentation.Decoder):
+             if not isinstance(features, list):
+                 raise ValueError(f"Expected a list of feature map tensors, got {type(features)}.")
+             return self.decoder(DecoderInputs(features, to_size, tensor))
+         return self.decoder(features)
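A configuration sketch for the new `inferer` and `spatial_dims` arguments; the decoder and criterion below are toy placeholders rather than eva's real components, and MONAI's SlidingWindowInferer is one plausible choice for 3D evaluation.

    from monai.inferers import SlidingWindowInferer
    from torch import nn

    from eva.vision.models.modules.semantic_segmentation import SemanticSegmentationModule

    module = SemanticSegmentationModule(
        decoder=nn.Conv3d(1, 2, kernel_size=1),  # toy voxel-wise decoder (plain nn.Module)
        criterion=nn.CrossEntropyLoss(),
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4),  # used only outside training
        spatial_dims=3,  # `to_size` becomes targets.shape[-3:]
    )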
eva/vision/models/networks/backbones/__init__.py
@@ -1,6 +1,13 @@
  """Vision Model Backbones API."""

- from eva.vision.models.networks.backbones import pathology, timm, torchhub, universal
+ from eva.vision.models.networks.backbones import pathology, radiology, timm, universal
  from eva.vision.models.networks.backbones.registry import BackboneModelRegistry, register_model

- __all__ = ["pathology", "timm", "torchhub", "universal", "BackboneModelRegistry", "register_model"]
+ __all__ = [
+     "radiology",
+     "pathology",
+     "timm",
+     "universal",
+     "BackboneModelRegistry",
+     "register_model",
+ ]
eva/vision/models/networks/backbones/pathology/__init__.py
@@ -1,9 +1,14 @@
  """Vision Pathology Model Backbones API."""

- from eva.vision.models.networks.backbones.pathology.bioptimus import bioptimus_h_optimus_0
+ from eva.vision.models.networks.backbones.pathology.bioptimus import (
+     bioptimus_h0_mini,
+     bioptimus_h_optimus_0,
+ )
  from eva.vision.models.networks.backbones.pathology.gigapath import prov_gigapath
  from eva.vision.models.networks.backbones.pathology.histai import histai_hibou_b, histai_hibou_l
+ from eva.vision.models.networks.backbones.pathology.hkust import hkust_gpfm
  from eva.vision.models.networks.backbones.pathology.kaiko import (
+     kaiko_midnight_12k,
      kaiko_vitb8,
      kaiko_vitb16,
      kaiko_vitl14,
@@ -11,11 +16,12 @@ from eva.vision.models.networks.backbones.pathology.kaiko import (
      kaiko_vits16,
  )
  from eva.vision.models.networks.backbones.pathology.lunit import lunit_vits8, lunit_vits16
- from eva.vision.models.networks.backbones.pathology.mahmood import mahmood_uni
+ from eva.vision.models.networks.backbones.pathology.mahmood import mahmood_uni, mahmood_uni2_h
  from eva.vision.models.networks.backbones.pathology.owkin import owkin_phikon, owkin_phikon_v2
  from eva.vision.models.networks.backbones.pathology.paige import paige_virchow2

  __all__ = [
+     "kaiko_midnight_12k",
      "kaiko_vitb16",
      "kaiko_vitb8",
      "kaiko_vitl14",
@@ -26,9 +32,12 @@ __all__ = [
      "lunit_vits16",
      "lunit_vits8",
      "mahmood_uni",
+     "mahmood_uni2_h",
      "bioptimus_h_optimus_0",
+     "bioptimus_h0_mini",
      "prov_gigapath",
      "histai_hibou_b",
      "histai_hibou_l",
      "paige_virchow2",
+     "hkust_gpfm",
  ]
eva/vision/models/networks/backbones/pathology/bioptimus.py
@@ -5,6 +5,9 @@ from typing import Tuple
  import timm
  from torch import nn

+ from eva.core.models import transforms
+ from eva.vision.models import wrappers
+ from eva.vision.models.networks.backbones import _utils
  from eva.vision.models.networks.backbones.registry import register_model


@@ -13,7 +16,9 @@ def bioptimus_h_optimus_0(
      dynamic_img_size: bool = True,
      out_indices: int | Tuple[int, ...] | None = None,
  ) -> nn.Module:
-     """Initializes the h_optimus_0 pathology FM by Bioptimus.
+     """Initializes the H-Optimus-0 pathology FM by Bioptimus.
+
+     See https://huggingface.co/bioptimus/H-optimus-0 for details.

      Args:
          dynamic_img_size: Whether to allow the interpolation embedding
@@ -32,3 +37,44 @@ def bioptimus_h_optimus_0(
          out_indices=out_indices,
          features_only=out_indices is not None,
      )
+
+
+ @register_model("pathology/bioptimus_h0_mini")
+ def bioptimus_h0_mini(
+     dynamic_img_size: bool = True,
+     out_indices: int | Tuple[int, ...] | None = None,
+     hf_token: str | None = None,
+     include_patch_tokens: bool = False,
+ ) -> nn.Module:
+     """Initializes H0-mini (ViT-B) pathology FM by Bioptimus.
+
+     This model was distilled from H-Optimus-0 on 40M TCGA tiles.
+
+     See https://huggingface.co/bioptimus/H0-mini for details.
+
+     Args:
+         dynamic_img_size: Support different input image sizes by allowing to change
+             the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+         out_indices: Whether and which multi-level patch embeddings to return.
+         hf_token: HuggingFace token to download the model.
+         include_patch_tokens: Whether to combine the mean aggregated patch tokens with cls token.
+
+     Returns:
+         The model instance.
+     """
+     _utils.huggingface_login(hf_token)
+     return wrappers.TimmModel(
+         model_name="hf-hub:bioptimus/H0-mini",
+         out_indices=out_indices,
+         pretrained=True,
+         model_kwargs={
+             "dynamic_img_size": dynamic_img_size,
+             "mlp_layer": timm.layers.SwiGLUPacked,
+             "act_layer": nn.SiLU,
+         },
+         tensor_transforms=(
+             transforms.ExtractCLSFeatures(include_patch_tokens=include_patch_tokens)
+             if out_indices is None
+             else None
+         ),
+     )
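A loading sketch for the new backbone; H0-mini is a gated HuggingFace repository, so the placeholder token is assumed to have been granted access.

    import torch

    from eva.vision.models.networks.backbones.pathology.bioptimus import bioptimus_h0_mini

    model = bioptimus_h0_mini(hf_token="hf_...")  # placeholder token
    with torch.inference_mode():
        features = model(torch.rand(1, 3, 224, 224))  # CLS embedding by default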
eva/vision/models/networks/backbones/pathology/hkust.py
@@ -0,0 +1,69 @@
+ """Pathology FMs from Hong Kong University of Science and Technology."""
+
+ import re
+ from typing import Tuple
+
+ import timm
+ from torch import nn
+
+ from eva.core.models.wrappers import _utils
+ from eva.vision.models.networks.backbones.registry import register_model
+
+
+ @register_model("pathology/hkust_gpfm")
+ def hkust_gpfm(
+     dynamic_img_size: bool = True,
+     out_indices: int | Tuple[int, ...] | None = None,
+ ) -> nn.Module:
+     """Initializes GPFM model from Hong Kong University of Science and Technology.
+
+     Ma, J., Guo, Z., Zhou, F., Wang, Y., Xu, Y., et al. (2024).
+     Towards a generalizable pathology foundation model via unified knowledge
+     distillation (arXiv No. 2407.18449). arXiv. https://arxiv.org/abs/2407.18449
+
+     Args:
+         dynamic_img_size: Support different input image sizes by allowing to change
+             the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+         out_indices: Whether and which multi-level patch embeddings to return.
+
+     Returns:
+         The model instance.
+     """
+     return timm.create_model(
+         model_name="vit_large_patch14_dinov2",
+         pretrained=True,
+         pretrained_cfg={
+             "state_dict": _load_state_dict(),
+             "num_classes": 0,
+         },
+         out_indices=out_indices,
+         features_only=out_indices is not None,
+         **{
+             "img_size": 224,
+             "patch_size": 14,
+             "init_values": 1e-5,
+             "qkv_bias": True,
+             "dynamic_img_size": dynamic_img_size,
+         },
+     )
+
+
+ def _load_state_dict() -> dict:
+     """Loads the state dict with model weights from github."""
+     state_dict = _utils.load_state_dict_from_url(
+         url="https://github.com/birkhoffkiki/GPFM/releases/download/ckpt/GPFM.pth",
+         md5="0dc7e345de84f385d09c8c782b4b3236",
+     )
+     return _convert_state_dict(state_dict["teacher"])
+
+
+ def _convert_state_dict(state_dict: dict) -> dict:
+     """Rename state dict keys to match timm's format."""
+     state_dict = {
+         re.sub(r"blocks\.\d+\.(\d+)", r"blocks.\1", key.replace("backbone.", "")): value
+         for key, value in state_dict.items()
+     }
+     remove_keys = ["mask_token"] + [key for key in state_dict.keys() if "dino_head" in key]
+     for key in remove_keys:
+         state_dict.pop(key)
+     return state_dict
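To illustrate the renaming in `_convert_state_dict`, here is what the regex does to one hypothetical teacher-checkpoint key (DINOv2-style teachers chunk their block names):

    import re

    key = "backbone.blocks.2.10.mlp.fc1.weight"  # hypothetical checkpoint key
    renamed = re.sub(r"blocks\.\d+\.(\d+)", r"blocks.\1", key.replace("backbone.", ""))
    print(renamed)  # blocks.10.mlp.fc1.weight, matching timm's flat block indexing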
eva/vision/models/networks/backbones/pathology/kaiko.py
@@ -5,9 +5,27 @@ from typing import Tuple
  import torch
  from torch import nn

+ from eva.vision.models.networks.backbones import _utils
  from eva.vision.models.networks.backbones.registry import register_model


+ @register_model("pathology/kaiko_midnight_12k")
+ def kaiko_midnight_12k(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
+     """Initializes the Midnight-12k pathology FM by kaiko.ai.
+
+     Args:
+         out_indices: Whether and which multi-level patch embeddings to return.
+
+     Returns:
+         The model instance.
+     """
+     return _utils.load_hugingface_model(
+         model_name="kaiko-ai/midnight",
+         out_indices=out_indices,
+         model_kwargs={"trust_remote_code": True},
+     )
+
+
  @register_model("pathology/kaiko_vits16")
  def kaiko_vits16(
      dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
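The registry entry makes the backbone addressable as "pathology/kaiko_midnight_12k"; calling the factory directly is the minimal sketch (weights download from the HuggingFace Hub on first use):

    from eva.vision.models.networks.backbones.pathology.kaiko import kaiko_midnight_12k

    model = kaiko_midnight_12k()  # fetches kaiko-ai/midnight with trust_remote_code=True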
eva/vision/models/networks/backbones/radiology/__init__.py
@@ -0,0 +1,11 @@
+ """Vision Radiology Model Backbones API."""
+
+ from eva.vision.models.networks.backbones.radiology.swin_unetr import SwinUNETREncoder
+ from eva.vision.models.networks.backbones.radiology.voco import VoCoB, VoCoH, VoCoL
+
+ __all__ = [
+     "VoCoB",
+     "VoCoL",
+     "VoCoH",
+     "SwinUNETREncoder",
+ ]