kaiko-eva 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. eva/core/data/datasets/base.py +7 -2
  2. eva/core/models/modules/head.py +4 -2
  3. eva/core/models/modules/typings.py +2 -2
  4. eva/core/models/transforms/__init__.py +2 -1
  5. eva/core/models/transforms/as_discrete.py +57 -0
  6. eva/core/models/wrappers/_utils.py +121 -1
  7. eva/core/utils/suppress_logs.py +28 -0
  8. eva/vision/data/__init__.py +2 -2
  9. eva/vision/data/dataloaders/__init__.py +5 -0
  10. eva/vision/data/dataloaders/collate_fn/__init__.py +5 -0
  11. eva/vision/data/dataloaders/collate_fn/collection.py +22 -0
  12. eva/vision/data/datasets/__init__.py +2 -2
  13. eva/vision/data/datasets/classification/bach.py +3 -4
  14. eva/vision/data/datasets/classification/bracs.py +3 -4
  15. eva/vision/data/datasets/classification/breakhis.py +3 -4
  16. eva/vision/data/datasets/classification/camelyon16.py +4 -5
  17. eva/vision/data/datasets/classification/crc.py +3 -4
  18. eva/vision/data/datasets/classification/gleason_arvaniti.py +3 -4
  19. eva/vision/data/datasets/classification/mhist.py +3 -4
  20. eva/vision/data/datasets/classification/panda.py +4 -5
  21. eva/vision/data/datasets/classification/patch_camelyon.py +3 -4
  22. eva/vision/data/datasets/classification/unitopatho.py +3 -4
  23. eva/vision/data/datasets/classification/wsi.py +6 -5
  24. eva/vision/data/datasets/segmentation/__init__.py +2 -2
  25. eva/vision/data/datasets/segmentation/_utils.py +47 -0
  26. eva/vision/data/datasets/segmentation/bcss.py +7 -8
  27. eva/vision/data/datasets/segmentation/btcv.py +236 -0
  28. eva/vision/data/datasets/segmentation/consep.py +6 -7
  29. eva/vision/data/datasets/segmentation/lits.py +9 -8
  30. eva/vision/data/datasets/segmentation/lits_balanced.py +2 -1
  31. eva/vision/data/datasets/segmentation/monusac.py +4 -5
  32. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +12 -10
  33. eva/vision/data/datasets/vision.py +95 -4
  34. eva/vision/data/datasets/wsi.py +5 -5
  35. eva/vision/data/transforms/__init__.py +22 -3
  36. eva/vision/data/transforms/common/__init__.py +1 -2
  37. eva/vision/data/transforms/croppad/__init__.py +11 -0
  38. eva/vision/data/transforms/croppad/crop_foreground.py +110 -0
  39. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +109 -0
  40. eva/vision/data/transforms/croppad/spatial_pad.py +67 -0
  41. eva/vision/data/transforms/intensity/__init__.py +11 -0
  42. eva/vision/data/transforms/intensity/rand_scale_intensity.py +59 -0
  43. eva/vision/data/transforms/intensity/rand_shift_intensity.py +55 -0
  44. eva/vision/data/transforms/intensity/scale_intensity_ranged.py +56 -0
  45. eva/vision/data/transforms/spatial/__init__.py +7 -0
  46. eva/vision/data/transforms/spatial/flip.py +72 -0
  47. eva/vision/data/transforms/spatial/rotate.py +53 -0
  48. eva/vision/data/transforms/spatial/spacing.py +69 -0
  49. eva/vision/data/transforms/utility/__init__.py +5 -0
  50. eva/vision/data/transforms/utility/ensure_channel_first.py +51 -0
  51. eva/vision/data/tv_tensors/__init__.py +5 -0
  52. eva/vision/data/tv_tensors/volume.py +61 -0
  53. eva/vision/metrics/segmentation/monai_dice.py +9 -2
  54. eva/vision/models/modules/semantic_segmentation.py +28 -20
  55. eva/vision/models/networks/backbones/__init__.py +9 -2
  56. eva/vision/models/networks/backbones/pathology/__init__.py +11 -2
  57. eva/vision/models/networks/backbones/pathology/bioptimus.py +47 -1
  58. eva/vision/models/networks/backbones/pathology/hkust.py +69 -0
  59. eva/vision/models/networks/backbones/pathology/kaiko.py +18 -0
  60. eva/vision/models/networks/backbones/radiology/__init__.py +11 -0
  61. eva/vision/models/networks/backbones/radiology/swin_unetr.py +231 -0
  62. eva/vision/models/networks/backbones/radiology/voco.py +75 -0
  63. eva/vision/models/networks/decoders/segmentation/__init__.py +6 -2
  64. eva/vision/models/networks/decoders/segmentation/linear.py +5 -10
  65. eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +8 -1
  66. eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py +104 -0
  67. eva/vision/utils/io/__init__.py +2 -0
  68. eva/vision/utils/io/nifti.py +91 -11
  69. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/METADATA +3 -1
  70. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/RECORD +73 -57
  71. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/WHEEL +1 -1
  72. eva/vision/data/datasets/classification/base.py +0 -96
  73. eva/vision/data/datasets/segmentation/base.py +0 -96
  74. eva/vision/data/transforms/common/resize_and_clamp.py +0 -51
  75. eva/vision/data/transforms/normalization/__init__.py +0 -6
  76. eva/vision/data/transforms/normalization/clamp.py +0 -43
  77. eva/vision/data/transforms/normalization/functional/__init__.py +0 -5
  78. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +0 -28
  79. eva/vision/data/transforms/normalization/rescale_intensity.py +0 -53
  80. eva/vision/metrics/segmentation/BUILD +0 -1
  81. eva/vision/models/networks/backbones/torchhub/__init__.py +0 -5
  82. eva/vision/models/networks/backbones/torchhub/backbones.py +0 -61
  83. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/entry_points.txt +0 -0
  84. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,53 @@
1
+ """Rotation transforms."""
2
+
3
+ import functools
4
+ from typing import Any, Dict, List
5
+
6
+ from monai.transforms.spatial import array as monai_spatial_transforms
7
+ from torchvision import tv_tensors
8
+ from torchvision.transforms import v2
9
+ from typing_extensions import override
10
+
11
+ from eva.vision.data import tv_tensors as eva_tv_tensors
12
+
13
+
14
class RandRotate90(v2.Transform):
    """Randomly rotates images, volumes and masks by a multiple of 90 degrees.

    The random state (whether to rotate, and by how many quarter turns) is drawn
    once in `_get_params` and then applied with `randomize=False` to every
    supported input, so all inputs of one call receive the same rotation.
    """

    def __init__(
        self,
        prob: float = 0.1,
        max_k: int = 3,
        spatial_axes: tuple[int, int] = (1, 2),
    ) -> None:
        """Initializes the transform.

        Args:
            prob: probability of rotating.
                (Default 0.1, with 10% probability it returns a rotated array)
            max_k: number of rotations will be sampled from `np.random.randint(max_k) + 1`.
            spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes.
                Default: (1, 2), so for [C, T, H, W] will rotate along (H, W) plane (MONAI ignores
                the first C dimension).
        """
        super().__init__()

        monai_kwargs = {"prob": prob, "max_k": max_k, "spatial_axes": spatial_axes}
        self._rotate = monai_spatial_transforms.RandRotate90(**monai_kwargs)

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # Sample the random state once per call; `_transform` reuses it.
        self._rotate.randomize()
        no_params: Dict[str, Any] = {}
        return no_params

    @functools.singledispatchmethod
    @override
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Fallback for unsupported input types: pass through unchanged.
        return inpt

    @_transform.register(tv_tensors.Image)
    @_transform.register(eva_tv_tensors.Volume)
    @_transform.register(tv_tensors.Mask)
    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # `randomize=False` applies the state drawn in `_get_params`; the MONAI
        # output is re-wrapped into the input's tv_tensor type.
        return tv_tensors.wrap(self._rotate(img=inpt, randomize=False), like=inpt)
@@ -0,0 +1,69 @@
1
+ """Spacing resample transform."""
2
+
3
+ import functools
4
+ from typing import Any, Dict, List, Sequence
5
+
6
+ import numpy as np
7
+ import torch
8
+ from monai.data import meta_tensor
9
+ from monai.transforms.spatial import array as monai_spatial_transforms
10
+ from torchvision import tv_tensors
11
+ from torchvision.transforms import v2
12
+ from typing_extensions import override
13
+
14
+ from eva.vision.data import tv_tensors as eva_tv_tensors
15
+
16
+
17
class Spacing(v2.Transform):
    """Resample input image into the specified `pixdim`.

    - Expects tensors of shape `[C, T, H, W]`.
    - `Volume` inputs are resampled with bilinear interpolation, using their own
      affine matrix.
    - `Mask` inputs are resampled with nearest-neighbour interpolation and cast to
      `torch.long`. They use the affine matrix of the first `Volume` passed in the
      same call, so a mask must be transformed together with its volume.
    - All other inputs are returned unchanged.
    """

    def __init__(
        self,
        pixdim: Sequence[float] | float | np.ndarray,
    ) -> None:
        """Initializes the transform.

        Args:
            pixdim: output voxel spacing. if providing a single number,
                will use it for the first dimension. Items of the pixdim
                sequence map to the spatial dimensions of input image, if
                length of pixdim sequence is longer than image spatial
                dimensions, will ignore the longer part, if shorter, will
                pad with the last value. For example, for 3D image if pixdim
                is [1.0, 2.0] it will be padded to [1.0, 2.0, 2.0] if the
                components of the `pixdim` are non-positive values, the
                transform will use the corresponding components of the original
                pixdim, which is computed from the `affine` matrix of input image.
        """
        super().__init__()

        self._spacing = monai_spatial_transforms.Spacing(pixdim=pixdim, recompute_affine=True)
        # Affine of the volume seen in the latest call. Kept for backward
        # compatibility; the per-call value is handed to `_transform` via `params`.
        self._affine = None

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Looks up the affine matrix that `Mask` inputs are resampled with.

        Args:
            flat_inputs: The flattened inputs of the current call.

        Returns:
            A dict with key `"affine"`: the affine of the first `Volume` input,
            or `None` if the call contains no `Volume` (and no `Mask`).

        Raises:
            ValueError: If a `Mask` is passed without any `Volume`, since there is
                then no affine matrix to resample the mask with.
        """
        volumes = [inpt for inpt in flat_inputs if isinstance(inpt, eva_tv_tensors.Volume)]
        if not volumes:
            # Raise explicitly rather than letting `next()` raise StopIteration,
            # which an enclosing iterator / for-loop would treat as a normal end of
            # iteration and thus silently swallow.
            if any(isinstance(inpt, tv_tensors.Mask) for inpt in flat_inputs):
                raise ValueError(
                    "`Spacing` requires a `Volume` input to resample `Mask` inputs, "
                    "but none was provided."
                )
            self._affine = None
            return {"affine": None}

        affine = volumes[0].affine
        self._affine = affine
        return {"affine": affine}

    @functools.singledispatchmethod
    @override
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Fallback for unsupported input types: pass through unchanged.
        return inpt

    @_transform.register(eva_tv_tensors.Volume)
    def _(self, inpt: eva_tv_tensors.Volume, params: Dict[str, Any]) -> Any:
        inpt_spacing = self._spacing(inpt.to_meta_tensor(), mode="bilinear")
        if not isinstance(inpt_spacing, meta_tensor.MetaTensor):
            raise ValueError(f"Expected MetaTensor, got {type(inpt_spacing)}")
        return eva_tv_tensors.Volume.from_meta_tensor(inpt_spacing)

    @_transform.register(tv_tensors.Mask)
    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Use the volume's affine so the mask stays aligned with its volume;
        # nearest-neighbour interpolation keeps label values intact.
        affine = params.get("affine", self._affine)
        inpt_spacing = self._spacing(meta_tensor.MetaTensor(inpt, affine=affine), mode="nearest")
        return tv_tensors.wrap(inpt_spacing.to(dtype=torch.long), like=inpt)
@@ -0,0 +1,5 @@
1
+ """Transforms for utility operations."""
2
+
3
+ from eva.vision.data.transforms.utility.ensure_channel_first import EnsureChannelFirst
4
+
5
+ __all__ = ["EnsureChannelFirst"]
@@ -0,0 +1,51 @@
1
+ """Adjust or add the channel dimension of input data to ensure `channel_first` shape."""
2
+
3
+ import functools
4
+ from typing import Any, Dict
5
+
6
+ from monai.transforms.utility import array as monai_utility_transforms
7
+ from torchvision import tv_tensors
8
+ from torchvision.transforms import v2
9
+ from typing_extensions import override
10
+
11
+ from eva.vision.data import tv_tensors as eva_tv_tensors
12
+
13
+
14
class EnsureChannelFirst(v2.Transform):
    """Adjust or add the channel dimension of input data to ensure `channel_first` shape.

    Wraps MONAI's `EnsureChannelFirst`. It applies it to `Image`, `Volume` and
    `Mask` inputs and re-wraps each result into the input's tv_tensor type. Any
    other input is returned unchanged.

    NOTE(review): inputs are handed to MONAI as-is rather than as MONAI
    `MetaTensor`s, so MONAI presumably has no metadata to infer the channel
    dimension from. `channel_dim` likely needs to be set explicitly when
    `strict_check=True` — confirm.
    """

    def __init__(
        self,
        strict_check: bool = True,
        channel_dim: None | str | int = None,
    ) -> None:
        """Initializes the transform.

        Args:
            strict_check: whether to raise an error when the meta information is insufficient.
            channel_dim: This argument can be used to specify the original channel dimension
                (integer) of the input array.
                It overrides the `original_channel_dim` from provided MetaTensor input.
                If the input array doesn't have a channel dim, this value should be
                ``'no_channel'``.
                If this is set to `None`, this class relies on `img` or `meta_dict` to provide
                the channel dimension.
        """
        super().__init__()

        monai_kwargs = {"strict_check": strict_check, "channel_dim": channel_dim}
        self._ensure_channel_first = monai_utility_transforms.EnsureChannelFirst(**monai_kwargs)

    @functools.singledispatchmethod
    @override
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Fallback for unsupported input types: pass through unchanged.
        return inpt

    @_transform.register(tv_tensors.Image)
    @_transform.register(eva_tv_tensors.Volume)
    @_transform.register(tv_tensors.Mask)
    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return tv_tensors.wrap(self._ensure_channel_first(inpt), like=inpt)
@@ -0,0 +1,5 @@
1
+ """Custom `tv_tensors` types for torchvision."""
2
+
3
+ from eva.vision.data.tv_tensors.volume import Volume
4
+
5
+ __all__ = ["Volume"]
@@ -0,0 +1,61 @@
1
+ """Custom `tv_tensors` type for 3D Volumes."""
2
+
3
+ from typing import Any, Dict, Optional, Union
4
+
5
+ import torch
6
+ from monai.data import meta_tensor
7
+ from torchvision import tv_tensors
8
+ from typing_extensions import override
9
+
10
+
11
class Volume(tv_tensors.Video):
    """:class:`torchvision.TVTensor` subclass for 3D volumes.

    - Adds optional metadata and affine matrix to the tensor.
    - Expects tensors to be of shape `[..., T, C, H, W]`.

    Args:
        data: Any data that can be turned into a tensor with :func:`torch.as_tensor`.
        affine: Affine matrix of the volume. Expected to be of shape `[4, 4]`, and
            columns to correspond to [T, H, W, (translation)] dimensions. Note that
            `nibabel` by default uses [H, W, T, (translation)] order for affine matrices.
        metadata: Metadata associated with the volume.
        dtype: Desired data type. If omitted, will be inferred from `data`.
        device: Desired device.
        requires_grad: Whether autograd should record operations.
    """

    # Class-level fallbacks for `affine` / `metadata`. Instances created via
    # `__new__` carry their own values (set below). Tensors that become a `Volume`
    # without going through `__new__` have no instance values and resolve to these
    # class attributes instead. Presumably that happens e.g. with
    # `tv_tensors.wrap(..., like=volume)` or torch ops re-wrapping outputs — confirm.
    affine: torch.Tensor | None = None
    metadata: Dict[str, Any] | None = None

    @override
    def __new__(
        cls,
        data: Any,
        affine: torch.Tensor | None = None,
        metadata: Dict[str, Any] | None = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> "Volume":
        # Backward compatibility: keep refreshing the class-level fallbacks. Volumes
        # re-wrapped without `__new__` then still see the most recently constructed
        # volume's affine/metadata, exactly as before.
        # NOTE(review): this fallback is shared by all volumes; re-wrapping code
        # should ideally carry affine/metadata explicitly.
        cls.affine = affine
        cls.metadata = metadata

        volume = super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)  # type: ignore
        # Per-instance values: a volume keeps its own affine/metadata even after
        # other volumes have been constructed.
        volume.affine = affine
        volume.metadata = metadata
        return volume

    @classmethod
    def from_meta_tensor(cls, meta_tensor: meta_tensor.MetaTensor) -> "Volume":
        """Creates an instance from a :class:`monai.data.meta_tensor.MetaTensor`.

        Copies the data, affine, metadata, dtype, device and `requires_grad` flag.
        """
        # The parameter name shadows the `meta_tensor` module inside this body; the
        # annotation above is evaluated at definition time and still refers to the module.
        return cls(
            meta_tensor.data,
            affine=meta_tensor.affine,
            metadata=meta_tensor.meta,
            dtype=meta_tensor.dtype,
            device=meta_tensor.device,
            requires_grad=meta_tensor.requires_grad,
        )  # type: ignore

    def to_meta_tensor(self) -> meta_tensor.MetaTensor:
        """Converts the volume to a :class:`monai.data.meta_tensor.MetaTensor`."""
        return meta_tensor.MetaTensor(self, affine=self.affine, meta=self.metadata)

    def __repr__(self, *, tensor_contents: Any = None) -> str:
        """Returns the string representation of the object."""
        return self._make_repr()
@@ -1,5 +1,7 @@
1
1
  """Wrapper for dice score metric from MONAI."""
2
2
 
3
+ from typing import Literal
4
+
3
5
  from monai.metrics.meandice import DiceMetric
4
6
  from typing_extensions import override
5
7
 
@@ -14,6 +16,7 @@ class MonaiDiceScore(wrappers.MonaiMetricWrapper):
14
16
  self,
15
17
  num_classes: int,
16
18
  include_background: bool = True,
19
+ input_format: Literal["one-hot", "index"] = "index",
17
20
  reduction: str = "mean",
18
21
  ignore_index: int | None = None,
19
22
  **kwargs,
@@ -24,6 +27,8 @@ class MonaiDiceScore(wrappers.MonaiMetricWrapper):
24
27
  num_classes: The number of classes in the dataset.
25
28
  include_background: Whether to include the background class in the computation.
26
29
  reduction: The method to reduce the dice score. Options are `"mean"`, `"sum"`, `"none"`.
30
+ input_format: Choose between "one-hot" for one-hot encoded tensors or "index"
31
+ for index tensors.
27
32
  ignore_index: Integer specifying a target class to ignore. If given, this class
28
33
  index does not contribute to the returned score.
29
34
  kwargs: Additional keyword arguments for instantiating monai's `DiceMetric` class.
@@ -40,11 +45,13 @@ class MonaiDiceScore(wrappers.MonaiMetricWrapper):
40
45
  self.reduction = reduction
41
46
  self.orig_num_classes = num_classes
42
47
  self.ignore_index = ignore_index
48
+ self.input_format = input_format
43
49
 
44
50
  @override
45
51
  def update(self, preds, target):
46
- preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes)
47
- target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes)
52
+ if self.input_format == "index":
53
+ preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes)
54
+ target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes)
48
55
  if self.ignore_index is not None:
49
56
  preds, target = _utils.apply_ignore_index(preds, target, self.ignore_index)
50
57
  return super().update(preds, target)
@@ -1,10 +1,11 @@
1
1
  """"Neural Network Semantic Segmentation Module."""
2
2
 
3
- from typing import Any, Callable, Dict, Iterable, List, Tuple
3
+ from typing import Any, Callable, Dict, Iterable, List
4
4
 
5
5
  import torch
6
6
  from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
7
7
  from lightning.pytorch.utilities.types import STEP_OUTPUT
8
+ from monai.inferers.inferer import Inferer
8
9
  from torch import nn, optim
9
10
  from torch.optim import lr_scheduler
10
11
  from typing_extensions import override
@@ -15,6 +16,7 @@ from eva.core.models.modules.typings import INPUT_BATCH, INPUT_TENSOR_BATCH
15
16
  from eva.core.models.modules.utils import batch_postprocess, grad, submodule_state_dict
16
17
  from eva.core.utils import parser
17
18
  from eva.vision.models.networks import decoders
19
+ from eva.vision.models.networks.decoders import segmentation
18
20
  from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs
19
21
 
20
22
 
@@ -23,10 +25,11 @@ class SemanticSegmentationModule(module.ModelModule):
23
25
 
24
26
  def __init__(
25
27
  self,
26
- decoder: decoders.Decoder,
28
+ decoder: decoders.Decoder | nn.Module,
27
29
  criterion: Callable[..., torch.Tensor],
28
30
  encoder: Dict[str, Any] | Callable[[torch.Tensor], List[torch.Tensor]] | None = None,
29
31
  lr_multiplier_encoder: float = 0.0,
32
+ inferer: Inferer | None = None,
30
33
  optimizer: OptimizerCallable = optim.AdamW,
31
34
  lr_scheduler: LRSchedulerCallable = lr_scheduler.ConstantLR,
32
35
  metrics: metrics_lib.MetricsSchema | None = None,
@@ -44,6 +47,8 @@ class SemanticSegmentationModule(module.ModelModule):
44
47
  during the `configure_model` step.
45
48
  lr_multiplier_encoder: The learning rate multiplier for the
46
49
  encoder parameters. If `0`, it will freeze the encoder.
50
+ inferer: An optional MONAI `Inferer` for inference
51
+ postprocess during evaluation.
47
52
  optimizer: The optimizer to use.
48
53
  lr_scheduler: The learning rate scheduler to use.
49
54
  metrics: The metric groups to track.
@@ -62,6 +67,7 @@ class SemanticSegmentationModule(module.ModelModule):
62
67
  self.optimizer = optimizer
63
68
  self.lr_scheduler = lr_scheduler
64
69
  self.save_decoder_only = save_decoder_only
70
+ self.inferer = inferer
65
71
 
66
72
  @override
67
73
  def configure_model(self) -> None:
@@ -104,25 +110,15 @@ class SemanticSegmentationModule(module.ModelModule):
104
110
  @override
105
111
  def forward(
106
112
  self,
107
- inputs: torch.Tensor,
108
- to_size: Tuple[int, int] | None = None,
113
+ tensor: torch.Tensor,
109
114
  *args: Any,
110
115
  **kwargs: Any,
111
116
  ) -> torch.Tensor:
112
- """Maps the input tensor (image tensor or embeddings) to masks.
113
-
114
- If `inputs` is image tensor, then the `self.encoder`
115
- should be implemented, otherwise it will be interpreted
116
- as embeddings, where the `to_size` should be given.
117
- """
118
- if self.encoder is None and to_size is None:
119
- raise ValueError(
120
- "Please provide the expected `to_size` that the "
121
- "decoder should map the embeddings (`inputs`) to."
122
- )
123
- features = self.encoder(inputs) if self.encoder else inputs
124
- decoder_inputs = DecoderInputs(features, to_size or inputs.shape[-2:], inputs) # type: ignore
125
- return self.decoder(decoder_inputs)
117
+ return (
118
+ self.inferer(tensor, network=self._forward_networks)
119
+ if self.inferer is not None and not self.training
120
+ else self._forward_networks(tensor)
121
+ )
126
122
 
127
123
  @override
128
124
  def training_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
@@ -137,7 +133,9 @@ class SemanticSegmentationModule(module.ModelModule):
137
133
  return self._batch_step(batch)
138
134
 
139
135
  @override
140
- def predict_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> torch.Tensor:
136
+ def predict_step(
137
+ self, batch: INPUT_BATCH, *args: Any, **kwargs: Any
138
+ ) -> torch.Tensor | List[torch.Tensor]:
141
139
  tensor = INPUT_BATCH(*batch).data
142
140
  return self.encoder(tensor) if isinstance(self.encoder, nn.Module) else tensor
143
141
 
@@ -170,7 +168,7 @@ class SemanticSegmentationModule(module.ModelModule):
170
168
  The batch step output.
171
169
  """
172
170
  data, targets, metadata = INPUT_TENSOR_BATCH(*batch)
173
- predictions = self(data, to_size=targets.shape[-2:])
171
+ predictions = self(data)
174
172
  loss = self.criterion(predictions, targets)
175
173
  return {
176
174
  "loss": loss,
@@ -178,3 +176,13 @@ class SemanticSegmentationModule(module.ModelModule):
178
176
  "predictions": predictions,
179
177
  "metadata": metadata,
180
178
  }
179
+
180
+ def _forward_networks(self, tensor: torch.Tensor) -> torch.Tensor:
181
+ """Passes the input tensor through the encoder and decoder."""
182
+ features = self.encoder(tensor) if self.encoder else tensor
183
+ if isinstance(self.decoder, segmentation.Decoder):
184
+ if not isinstance(features, list):
185
+ raise ValueError(f"Expected a list of feature map tensors, got {type(features)}.")
186
+ image_size = (tensor.shape[-2], tensor.shape[-1])
187
+ return self.decoder(DecoderInputs(features, image_size, tensor))
188
+ return self.decoder(features)
@@ -1,6 +1,13 @@
1
1
  """Vision Model Backbones API."""
2
2
 
3
- from eva.vision.models.networks.backbones import pathology, timm, torchhub, universal
3
+ from eva.vision.models.networks.backbones import pathology, radiology, timm, universal
4
4
  from eva.vision.models.networks.backbones.registry import BackboneModelRegistry, register_model
5
5
 
6
- __all__ = ["pathology", "timm", "torchhub", "universal", "BackboneModelRegistry", "register_model"]
6
+ __all__ = [
7
+ "radiology",
8
+ "pathology",
9
+ "timm",
10
+ "universal",
11
+ "BackboneModelRegistry",
12
+ "register_model",
13
+ ]
@@ -1,9 +1,14 @@
1
1
  """Vision Pathology Model Backbones API."""
2
2
 
3
- from eva.vision.models.networks.backbones.pathology.bioptimus import bioptimus_h_optimus_0
3
+ from eva.vision.models.networks.backbones.pathology.bioptimus import (
4
+ bioptimus_h0_mini,
5
+ bioptimus_h_optimus_0,
6
+ )
4
7
  from eva.vision.models.networks.backbones.pathology.gigapath import prov_gigapath
5
8
  from eva.vision.models.networks.backbones.pathology.histai import histai_hibou_b, histai_hibou_l
9
+ from eva.vision.models.networks.backbones.pathology.hkust import hkust_gpfm
6
10
  from eva.vision.models.networks.backbones.pathology.kaiko import (
11
+ kaiko_midnight_12k,
7
12
  kaiko_vitb8,
8
13
  kaiko_vitb16,
9
14
  kaiko_vitl14,
@@ -11,11 +16,12 @@ from eva.vision.models.networks.backbones.pathology.kaiko import (
11
16
  kaiko_vits16,
12
17
  )
13
18
  from eva.vision.models.networks.backbones.pathology.lunit import lunit_vits8, lunit_vits16
14
- from eva.vision.models.networks.backbones.pathology.mahmood import mahmood_uni
19
+ from eva.vision.models.networks.backbones.pathology.mahmood import mahmood_uni, mahmood_uni2_h
15
20
  from eva.vision.models.networks.backbones.pathology.owkin import owkin_phikon, owkin_phikon_v2
16
21
  from eva.vision.models.networks.backbones.pathology.paige import paige_virchow2
17
22
 
18
23
  __all__ = [
24
+ "kaiko_midnight_12k",
19
25
  "kaiko_vitb16",
20
26
  "kaiko_vitb8",
21
27
  "kaiko_vitl14",
@@ -26,9 +32,12 @@ __all__ = [
26
32
  "lunit_vits16",
27
33
  "lunit_vits8",
28
34
  "mahmood_uni",
35
+ "mahmood_uni2_h",
29
36
  "bioptimus_h_optimus_0",
37
+ "bioptimus_h0_mini",
30
38
  "prov_gigapath",
31
39
  "histai_hibou_b",
32
40
  "histai_hibou_l",
33
41
  "paige_virchow2",
42
+ "hkust_gpfm",
34
43
  ]
@@ -5,6 +5,9 @@ from typing import Tuple
5
5
  import timm
6
6
  from torch import nn
7
7
 
8
+ from eva.core.models import transforms
9
+ from eva.vision.models import wrappers
10
+ from eva.vision.models.networks.backbones import _utils
8
11
  from eva.vision.models.networks.backbones.registry import register_model
9
12
 
10
13
 
@@ -13,7 +16,9 @@ def bioptimus_h_optimus_0(
13
16
  dynamic_img_size: bool = True,
14
17
  out_indices: int | Tuple[int, ...] | None = None,
15
18
  ) -> nn.Module:
16
- """Initializes the h_optimus_0 pathology FM by Bioptimus.
19
+ """Initializes the H-Optimus-0 pathology FM by Bioptimus.
20
+
21
+ See https://huggingface.co/bioptimus/H-optimus-0 for details.
17
22
 
18
23
  Args:
19
24
  dynamic_img_size: Whether to allow the interpolation embedding
@@ -32,3 +37,44 @@ def bioptimus_h_optimus_0(
32
37
  out_indices=out_indices,
33
38
  features_only=out_indices is not None,
34
39
  )
40
+
41
+
42
@register_model("pathology/bioptimus_h0_mini")
def bioptimus_h0_mini(
    dynamic_img_size: bool = True,
    out_indices: int | Tuple[int, ...] | None = None,
    hf_token: str | None = None,
    include_patch_tokens: bool = False,
) -> nn.Module:
    """Initializes H0-mini (ViT-B) pathology FM by Bioptimus.

    This model was distilled from H-Optimus-0 on 40M TCGA tiles.

    See https://huggingface.co/bioptimus/H0-mini for details.

    Args:
        dynamic_img_size: Support different input image sizes by allowing to change
            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
        out_indices: Whether and which multi-level patch embeddings to return.
        hf_token: HuggingFace token to download the model.
        include_patch_tokens: Whether to combine the mean aggregated patch tokens with cls token.

    Returns:
        The model instance.
    """
    _utils.huggingface_login(hf_token)

    # Extra timm kwargs: the architecture uses a SwiGLU-packed MLP with SiLU activation.
    model_kwargs = {
        "dynamic_img_size": dynamic_img_size,
        "mlp_layer": timm.layers.SwiGLUPacked,
        "act_layer": nn.SiLU,
    }

    # Extract CLS features only when no intermediate layers are requested;
    # otherwise apply no tensor transform.
    if out_indices is None:
        tensor_transforms = transforms.ExtractCLSFeatures(include_patch_tokens=include_patch_tokens)
    else:
        tensor_transforms = None

    return wrappers.TimmModel(
        model_name="hf-hub:bioptimus/H0-mini",
        out_indices=out_indices,
        pretrained=True,
        model_kwargs=model_kwargs,
        tensor_transforms=tensor_transforms,
    )
@@ -0,0 +1,69 @@
1
+ """Pathology FMs from Hong Kong University of Science and Technology."""
2
+
3
+ import re
4
+ from typing import Tuple
5
+
6
+ import timm
7
+ from torch import nn
8
+
9
+ from eva.core.models.wrappers import _utils
10
+ from eva.vision.models.networks.backbones.registry import register_model
11
+
12
+
13
@register_model("pathology/hkust_gpfm")
def hkust_gpfm(
    dynamic_img_size: bool = True,
    out_indices: int | Tuple[int, ...] | None = None,
) -> nn.Module:
    """Initializes GPFM model from Hong Kong University of Science and Technology.

    Ma, J., Guo, Z., Zhou, F., Wang, Y., Xu, Y., et al. (2024).
    Towards a generalizable pathology foundation model via unified knowledge
    distillation (arXiv No. 2407.18449). arXiv. https://arxiv.org/abs/2407.18449

    The model is built as timm's `vit_large_patch14_dinov2` with the converted GPFM
    weights injected through `pretrained_cfg["state_dict"]`.

    Args:
        dynamic_img_size: Support different input image sizes by allowing to change
            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
        out_indices: Whether and which multi-level patch embeddings to return.

    Returns:
        The model instance.
    """
    state_dict = _load_state_dict()
    return timm.create_model(
        model_name="vit_large_patch14_dinov2",
        pretrained=True,
        pretrained_cfg={"state_dict": state_dict, "num_classes": 0},
        out_indices=out_indices,
        # Return feature maps only when specific layers are requested.
        features_only=out_indices is not None,
        img_size=224,
        patch_size=14,
        init_values=1e-5,
        qkv_bias=True,
        dynamic_img_size=dynamic_img_size,
    )
49
+
50
+
51
def _load_state_dict() -> dict:
    """Downloads the GPFM checkpoint and returns its converted teacher weights.

    The checkpoint is fetched from the GPFM GitHub release and checked against
    its md5 hash. Its `"teacher"` entry is then renamed to timm's key format.
    """
    checkpoint = _utils.load_state_dict_from_url(
        url="https://github.com/birkhoffkiki/GPFM/releases/download/ckpt/GPFM.pth",
        md5="0dc7e345de84f385d09c8c782b4b3236",
    )
    teacher_weights = checkpoint["teacher"]
    return _convert_state_dict(teacher_weights)
58
+
59
+
60
+ def _convert_state_dict(state_dict: dict) -> dict:
61
+ """Rename state dict keys to match timm's format."""
62
+ state_dict = {
63
+ re.sub(r"blocks\.\d+\.(\d+)", r"blocks.\1", key.replace("backbone.", "")): value
64
+ for key, value in state_dict.items()
65
+ }
66
+ remove_keys = ["mask_token"] + [key for key in state_dict.keys() if "dino_head" in key]
67
+ for key in remove_keys:
68
+ state_dict.pop(key)
69
+ return state_dict
@@ -5,9 +5,27 @@ from typing import Tuple
5
5
  import torch
6
6
  from torch import nn
7
7
 
8
+ from eva.vision.models.networks.backbones import _utils
8
9
  from eva.vision.models.networks.backbones.registry import register_model
9
10
 
10
11
 
12
@register_model("pathology/kaiko_midnight_12k")
def kaiko_midnight_12k(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
    """Initializes the Midnight-12k pathology FM by kaiko.ai.

    The model `kaiko-ai/midnight` is loaded through `_utils.load_hugingface_model`
    with `trust_remote_code=True`.

    Args:
        out_indices: Whether and which multi-level patch embeddings to return.

    Returns:
        The model instance.
    """
    loader_kwargs = {"trust_remote_code": True}
    return _utils.load_hugingface_model(
        model_name="kaiko-ai/midnight",
        out_indices=out_indices,
        model_kwargs=loader_kwargs,
    )
27
+
28
+
11
29
  @register_model("pathology/kaiko_vits16")
12
30
  def kaiko_vits16(
13
31
  dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
@@ -0,0 +1,11 @@
1
+ """Vision Radiology Model Backbones API."""
2
+
3
+ from eva.vision.models.networks.backbones.radiology.swin_unetr import SwinUNETREncoder
4
+ from eva.vision.models.networks.backbones.radiology.voco import VoCoB, VoCoH, VoCoL
5
+
6
+ __all__ = [
7
+ "VoCoB",
8
+ "VoCoL",
9
+ "VoCoH",
10
+ "SwinUNETREncoder",
11
+ ]