kaiko-eva 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


Files changed (63)
  1. eva/core/callbacks/writers/embeddings/base.py +3 -4
  2. eva/core/data/dataloaders/dataloader.py +2 -2
  3. eva/core/data/splitting/random.py +6 -5
  4. eva/core/data/splitting/stratified.py +12 -6
  5. eva/core/losses/__init__.py +5 -0
  6. eva/core/losses/cross_entropy.py +27 -0
  7. eva/core/metrics/__init__.py +0 -4
  8. eva/core/metrics/defaults/__init__.py +0 -2
  9. eva/core/models/modules/module.py +9 -9
  10. eva/core/models/transforms/extract_cls_features.py +17 -9
  11. eva/core/models/transforms/extract_patch_features.py +23 -11
  12. eva/core/utils/progress_bar.py +15 -0
  13. eva/vision/data/datasets/__init__.py +4 -0
  14. eva/vision/data/datasets/classification/__init__.py +2 -1
  15. eva/vision/data/datasets/classification/camelyon16.py +4 -1
  16. eva/vision/data/datasets/classification/panda.py +17 -1
  17. eva/vision/data/datasets/classification/wsi.py +4 -1
  18. eva/vision/data/datasets/segmentation/__init__.py +2 -0
  19. eva/vision/data/datasets/segmentation/consep.py +2 -2
  20. eva/vision/data/datasets/segmentation/lits.py +49 -29
  21. eva/vision/data/datasets/segmentation/lits_balanced.py +93 -0
  22. eva/vision/data/datasets/segmentation/monusac.py +7 -7
  23. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +2 -2
  24. eva/vision/data/datasets/wsi.py +37 -1
  25. eva/vision/data/wsi/patching/coordinates.py +9 -1
  26. eva/vision/data/wsi/patching/samplers/_utils.py +2 -8
  27. eva/vision/data/wsi/patching/samplers/random.py +4 -2
  28. eva/vision/losses/__init__.py +2 -2
  29. eva/vision/losses/dice.py +75 -8
  30. eva/vision/metrics/__init__.py +11 -0
  31. eva/vision/metrics/defaults/__init__.py +7 -0
  32. eva/{core → vision}/metrics/defaults/segmentation/__init__.py +1 -1
  33. eva/{core → vision}/metrics/defaults/segmentation/multiclass.py +2 -1
  34. eva/vision/metrics/segmentation/BUILD +1 -0
  35. eva/vision/metrics/segmentation/__init__.py +9 -0
  36. eva/vision/metrics/segmentation/_utils.py +69 -0
  37. eva/{core/metrics → vision/metrics/segmentation}/generalized_dice.py +12 -10
  38. eva/vision/metrics/segmentation/mean_iou.py +57 -0
  39. eva/vision/models/modules/semantic_segmentation.py +4 -3
  40. eva/vision/models/networks/backbones/_utils.py +12 -0
  41. eva/vision/models/networks/backbones/pathology/__init__.py +4 -1
  42. eva/vision/models/networks/backbones/pathology/histai.py +8 -2
  43. eva/vision/models/networks/backbones/pathology/mahmood.py +2 -9
  44. eva/vision/models/networks/backbones/pathology/owkin.py +14 -0
  45. eva/vision/models/networks/backbones/pathology/paige.py +51 -0
  46. eva/vision/models/networks/decoders/__init__.py +1 -1
  47. eva/vision/models/networks/decoders/segmentation/__init__.py +12 -4
  48. eva/vision/models/networks/decoders/segmentation/base.py +16 -0
  49. eva/vision/models/networks/decoders/segmentation/{conv2d.py → decoder2d.py} +26 -22
  50. eva/vision/models/networks/decoders/segmentation/linear.py +2 -2
  51. eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +12 -0
  52. eva/vision/models/networks/decoders/segmentation/{common.py → semantic/common.py} +3 -3
  53. eva/vision/models/networks/decoders/segmentation/semantic/with_image.py +94 -0
  54. eva/vision/models/networks/decoders/segmentation/typings.py +18 -0
  55. eva/vision/utils/io/__init__.py +7 -1
  56. eva/vision/utils/io/nifti.py +19 -4
  57. {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.3.dist-info}/METADATA +3 -34
  58. {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.3.dist-info}/RECORD +61 -48
  59. {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.3.dist-info}/WHEEL +1 -1
  60. eva/core/metrics/mean_iou.py +0 -120
  61. eva/vision/models/networks/decoders/decoder.py +0 -7
  62. {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.3.dist-info}/entry_points.txt +0 -0
  63. {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.3.dist-info}/licenses/LICENSE +0 -0

eva/vision/metrics/segmentation/_utils.py
@@ -0,0 +1,69 @@
+"""Utils for segmentation metric collections."""
+
+from typing import Tuple
+
+import torch
+
+
+def apply_ignore_index(
+    preds: torch.Tensor, target: torch.Tensor, ignore_index: int, num_classes: int
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Applies the ignore index to the predictions and target tensors.
+
+    1. Masks the values in the target tensor that correspond to the ignored index.
+    2. Remove the channel corresponding to the ignored index from both tensors.
+
+    Args:
+        preds: The predictions tensor. Expected to be of shape `(N,C,...)`.
+        target: The target tensor. Expected to be of shape `(N,C,...)`.
+        ignore_index: The index to ignore.
+        num_classes: The number of classes.
+
+    Returns:
+        The modified predictions and target tensors of shape `(N,C-1,...)`.
+    """
+    if ignore_index < 0:
+        raise ValueError("ignore_index must be a non-negative integer")
+
+    ignore_mask = preds[:, ignore_index] == 1
+    target = target * (~ignore_mask.unsqueeze(1))
+
+    preds = _ignore_tensor_channel(preds, ignore_index)
+    target = _ignore_tensor_channel(target, ignore_index)
+
+    return preds, target
+
+
+def index_to_one_hot(tensor: torch.Tensor, num_classes: int) -> torch.Tensor:
+    """Converts an index tensor to a one-hot tensor.
+
+    Args:
+        tensor: The index tensor to convert. Expected to be of shape `(N,...)`.
+        num_classes: The number of classes to one-hot encode.
+
+    Returns:
+        A one-hot tensor of shape `(N,C,...)`.
+    """
+    if not _is_one_hot(tensor):
+        tensor = torch.nn.functional.one_hot(tensor.long(), num_classes=num_classes).movedim(-1, 1)
+    return tensor
+
+
+def _ignore_tensor_channel(tensor: torch.Tensor, ignore_index: int) -> torch.Tensor:
+    """Removes the channel corresponding to the specified ignore index.
+
+    Args:
+        tensor: The tensor to remove the channel from. Expected to be of shape `(N,C,...)`.
+        ignore_index: The index of the channel dimension (C) to remove.
+
+    Returns:
+        A tensor without the specified channel `(N,C-1,...)`.
+    """
+    if ignore_index < 0:
+        raise ValueError("ignore_index must be a non-negative integer")
+    return torch.cat([tensor[:, :ignore_index], tensor[:, ignore_index + 1 :]], dim=1)
+
+
+def _is_one_hot(tensor: torch.Tensor, expected_dim: int = 4) -> bool:
+    """Checks if the tensor is a one-hot tensor."""
+    return bool((tensor.bool() == tensor).all()) and tensor.ndim == expected_dim
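
The helpers above are small enough to sanity-check by hand. A minimal sketch of their behaviour on toy tensors (shapes and values chosen purely for illustration, assuming the wheel is installed as `eva`):

```python
import torch

from eva.vision.metrics.segmentation import _utils

# Two classes, one sample, 2x2 spatial grid; class 1 will be ignored.
preds = torch.tensor([[[0, 1], [1, 0]]])   # (N=1, H=2, W=2) index tensor
target = torch.tensor([[[0, 1], [1, 1]]])

preds_oh = _utils.index_to_one_hot(preds, num_classes=2)    # -> (1, 2, 2, 2), channels first
target_oh = _utils.index_to_one_hot(target, num_classes=2)

# Zeroes target pixels where class 1 was predicted, then drops channel 1 from both tensors.
preds_oh, target_oh = _utils.apply_ignore_index(preds_oh, target_oh, ignore_index=1, num_classes=2)
print(preds_oh.shape, target_oh.shape)  # torch.Size([1, 1, 2, 2]) twice
```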

eva/{core/metrics → vision/metrics/segmentation}/generalized_dice.py
@@ -6,6 +6,8 @@ import torch
 from torchmetrics import segmentation
 from typing_extensions import override
 
+from eva.vision.metrics.segmentation import _utils
+
 
 class GeneralizedDiceScore(segmentation.GeneralizedDiceScore):
     """Defines the Generalized Dice Score.
@@ -30,8 +32,6 @@ class GeneralizedDiceScore(segmentation.GeneralizedDiceScore):
             include_background: Whether to include the background class in the computation
             weight_type: The type of weight to apply to each class. Can be one of `"square"`,
                 `"simple"`, or `"linear"`.
-            input_format: What kind of input the function receives. Choose between ``"one-hot"``
-                for one-hot encoded tensors or ``"index"`` for index tensors.
             ignore_index: Integer specifying a target class to ignore. If given, this class
                 index does not contribute to the returned score, regardless of reduction method.
             per_class: Whether to compute the IoU for each class separately. If set to ``False``,
@@ -39,21 +39,23 @@ class GeneralizedDiceScore(segmentation.GeneralizedDiceScore):
             kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
         """
         super().__init__(
-            num_classes=num_classes,
+            num_classes=num_classes
+            - (ignore_index is not None)
+            + (ignore_index == 0 and not include_background),
             include_background=include_background,
             weight_type=weight_type,
             per_class=per_class,
             **kwargs,
         )
-
+        self.orig_num_classes = num_classes
         self.ignore_index = ignore_index
 
     @override
     def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
+        preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes)
+        target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes)
         if self.ignore_index is not None:
-            mask = target != self.ignore_index
-            mask = mask.all(dim=-1, keepdim=True)
-            preds = preds * mask
-            target = target * mask
-
-        super().update(preds=preds, target=target)
+            preds, target = _utils.apply_ignore_index(
+                preds, target, self.ignore_index, self.num_classes
+            )
+        super().update(preds=preds.long(), target=target.long())
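
With the one-hot conversion handled internally, the metric can now be fed plain index masks. A minimal usage sketch (class count and shapes are illustrative; the constructor is assumed to mirror the MeanIoU signature shown in the next file):

```python
import torch

from eva.vision.metrics.segmentation.generalized_dice import GeneralizedDiceScore

# Three classes, with class 0 excluded from the score via ignore_index.
metric = GeneralizedDiceScore(num_classes=3, ignore_index=0)

preds = torch.randint(0, 3, (4, 16, 16))   # (N, H, W) predicted class indices
target = torch.randint(0, 3, (4, 16, 16))  # (N, H, W) ground-truth class indices

metric.update(preds, target)
print(metric.compute())
```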

eva/vision/metrics/segmentation/mean_iou.py
@@ -0,0 +1,57 @@
+"""MeanIoU metric for semantic segmentation."""
+
+from typing import Any
+
+import torch
+from torchmetrics import segmentation
+from typing_extensions import override
+
+from eva.vision.metrics.segmentation import _utils
+
+
+class MeanIoU(segmentation.MeanIoU):
+    """MeanIoU (mIOU) metric for semantic segmentation.
+
+    It expands the `torchmetrics` class by including an `ignore_index`
+    functionality.
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        include_background: bool = True,
+        ignore_index: int | None = None,
+        per_class: bool = False,
+        **kwargs: Any,
+    ) -> None:
+        """Initializes the metric.
+
+        Args:
+            num_classes: The number of classes in the segmentation problem.
+            include_background: Whether to include the background class in the computation
+            ignore_index: Integer specifying a target class to ignore. If given, this class
+                index does not contribute to the returned score, regardless of reduction method.
+            per_class: Whether to compute the IoU for each class separately. If set to ``False``,
+                the metric will compute the mean IoU over all classes.
+            kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+        """
+        super().__init__(
+            num_classes=num_classes
+            - (ignore_index is not None)
+            + (ignore_index == 0 and not include_background),
+            include_background=include_background,
+            per_class=per_class,
+            **kwargs,
+        )
+        self.orig_num_classes = num_classes
+        self.ignore_index = ignore_index
+
+    @override
+    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
+        preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes)
+        target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes)
+        if self.ignore_index is not None:
+            preds, target = _utils.apply_ignore_index(
+                preds, target, self.ignore_index, self.num_classes
+            )
+        super().update(preds=preds.long(), target=target.long())
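
A corresponding sketch for the patched MeanIoU, here with per-class scores (values illustrative):

```python
import torch

from eva.vision.metrics.segmentation.mean_iou import MeanIoU

# Four classes; class 3 is ignored, so three per-class IoU values are returned.
metric = MeanIoU(num_classes=4, ignore_index=3, per_class=True)

preds = torch.randint(0, 4, (2, 32, 32))
target = torch.randint(0, 4, (2, 32, 32))

metric.update(preds, target)
print(metric.compute())
```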

eva/vision/models/modules/semantic_segmentation.py
@@ -15,6 +15,7 @@ from eva.core.models.modules.typings import INPUT_BATCH, INPUT_TENSOR_BATCH
 from eva.core.models.modules.utils import batch_postprocess, grad
 from eva.core.utils import parser
 from eva.vision.models.networks import decoders
+from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs
 
 
 class SemanticSegmentationModule(module.ModelModule):
@@ -101,9 +102,9 @@
                 "Please provide the expected `to_size` that the "
                 "decoder should map the embeddings (`inputs`) to."
             )
-
-        patch_embeddings = self.encoder(inputs) if self.encoder else inputs
-        return self.decoder(patch_embeddings, to_size or inputs.shape[-2:])
+        features = self.encoder(inputs) if self.encoder else inputs
+        decoder_inputs = DecoderInputs(features, inputs.shape[-2:], inputs)  # type: ignore
+        return self.decoder(decoder_inputs)
 
     @override
     def training_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
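
The decoder call contract changes from positional tensors to a single `DecoderInputs` object. A rough sketch of what the module now hands to the decoder, assuming `DecoderInputs` is a named tuple of `(features, image_size, images)` as the construction above suggests:

```python
import torch

from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs

features = torch.rand(2, 768, 14, 14)   # encoder patch embeddings (N, hidden, h, w)
images = torch.rand(2, 3, 224, 224)     # original input batch

decoder_inputs = DecoderInputs(features, images.shape[-2:], images)
# masks = decoder(decoder_inputs)       # hypothetical decoder instance -> (2, n_classes, 224, 224)
```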

eva/vision/models/networks/backbones/_utils.py
@@ -1,7 +1,9 @@
 """Utilis for backbone networks."""
 
+import os
 from typing import Any, Dict, Tuple
 
+import huggingface_hub
 from torch import nn
 
 from eva import models
@@ -37,3 +39,13 @@ def load_hugingface_model(
         tensor_transforms=tensor_transforms,
         model_kwargs=model_kwargs,
     )
+
+
+def huggingface_login(hf_token: str | None = None):
+    token = hf_token or os.environ.get("HF_TOKEN")
+    if not token:
+        raise ValueError(
+            "Please provide a HuggingFace token to download the model. "
+            "You can either pass it as an argument or set the env variable HF_TOKEN."
+        )
+    huggingface_hub.login(token=token)
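
The new helper centralizes the token handling that `mahmood_uni` previously did inline (see the hunks below). A usage sketch:

```python
from eva.vision.models.networks.backbones import _utils

# Pass the token explicitly, or set the HF_TOKEN environment variable instead;
# a ValueError is raised when neither is provided.
_utils.huggingface_login(hf_token="hf_xxx")  # placeholder token, illustration only
```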

eva/vision/models/networks/backbones/pathology/__init__.py
@@ -12,7 +12,8 @@ from eva.vision.models.networks.backbones.pathology.kaiko import (
 )
 from eva.vision.models.networks.backbones.pathology.lunit import lunit_vits8, lunit_vits16
 from eva.vision.models.networks.backbones.pathology.mahmood import mahmood_uni
-from eva.vision.models.networks.backbones.pathology.owkin import owkin_phikon
+from eva.vision.models.networks.backbones.pathology.owkin import owkin_phikon, owkin_phikon_v2
+from eva.vision.models.networks.backbones.pathology.paige import paige_virchow2
 
 __all__ = [
     "kaiko_vitb16",
@@ -21,6 +22,7 @@ __all__ = [
     "kaiko_vits16",
     "kaiko_vits8",
     "owkin_phikon",
+    "owkin_phikon_v2",
     "lunit_vits16",
     "lunit_vits8",
     "mahmood_uni",
@@ -28,4 +30,5 @@ __all__ = [
     "prov_gigapath",
     "histai_hibou_b",
     "histai_hibou_l",
+    "paige_virchow2",
 ]

eva/vision/models/networks/backbones/pathology/histai.py
@@ -12,6 +12,9 @@ from eva.vision.models.networks.backbones.registry import register_model
 def histai_hibou_b(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
     """Initializes the hibou-B pathology FM by hist.ai (https://huggingface.co/histai/hibou-B).
 
+    Uses a customized implementation of the DINOv2 architecture from the transformers
+    library to add support for registers, which requires the trust_remote_code=True flag.
+
     Args:
         out_indices: Whether and which multi-level patch embeddings to return.
             Currently only out_indices=1 is supported.
@@ -23,7 +26,7 @@ def histai_hibou_b(out_indices: int | Tuple[int, ...] | None = None) -> nn.Modul
         model_name="histai/hibou-B",
         out_indices=out_indices,
         model_kwargs={"trust_remote_code": True},
-        transform_args={"ignore_remaining_dims": True} if out_indices is not None else None,
+        transform_args={"num_register_tokens": 4} if out_indices is not None else None,
     )
 
 
@@ -31,6 +34,9 @@ def histai_hibou_b(out_indices: int | Tuple[int, ...] | None = None) -> nn.Modul
 def histai_hibou_l(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
     """Initializes the hibou-L pathology FM by hist.ai (https://huggingface.co/histai/hibou-L).
 
+    Uses a customized implementation of the DINOv2 architecture from the transformers
+    library to add support for registers, which requires the trust_remote_code=True flag.
+
     Args:
         out_indices: Whether and which multi-level patch embeddings to return.
             Currently only out_indices=1 is supported.
@@ -42,5 +48,5 @@ def histai_hibou_l(out_indices: int | Tuple[int, ...] | None = None) -> nn.Modul
         model_name="histai/hibou-L",
         out_indices=out_indices,
         model_kwargs={"trust_remote_code": True},
-        transform_args={"ignore_remaining_dims": True} if out_indices is not None else None,
+        transform_args={"num_register_tokens": 4} if out_indices is not None else None,
     )

eva/vision/models/networks/backbones/pathology/mahmood.py
@@ -9,6 +9,7 @@ from loguru import logger
 from torch import nn
 
 from eva.vision.models import wrappers
+from eva.vision.models.networks.backbones import _utils
 from eva.vision.models.networks.backbones.registry import register_model
 
 
@@ -31,19 +32,11 @@ def mahmood_uni(
     Returns:
         The model instance.
     """
-    token = hf_token or os.environ.get("HF_TOKEN")
-    if not token:
-        raise ValueError(
-            "Please provide a HuggingFace token to download the model. "
-            "You can either pass it as an argument or set the env variable HF_TOKEN."
-        )
-
     checkpoint_path = os.path.join(download_dir, "pytorch_model.bin")
-
     if not os.path.exists(checkpoint_path):
         logger.info(f"Downloading the model checkpoint to {download_dir} ...")
         os.makedirs(download_dir, exist_ok=True)
-        huggingface_hub.login(token=token)
+        _utils.huggingface_login(hf_token)
         huggingface_hub.hf_hub_download(
             "MahmoodLab/UNI",
             filename="pytorch_model.bin",

eva/vision/models/networks/backbones/pathology/owkin.py
@@ -20,3 +20,17 @@ def owkin_phikon(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
         The model instance.
     """
     return _utils.load_hugingface_model(model_name="owkin/phikon", out_indices=out_indices)
+
+
+@register_model("pathology/owkin_phikon_v2")
+def owkin_phikon_v2(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
+    """Initializes the phikon-v2 pathology FM by owkin (https://huggingface.co/owkin/phikon-v2).
+
+    Args:
+        out_indices: Whether and which multi-level patch embeddings to return.
+            Currently only out_indices=1 is supported.
+
+    Returns:
+        The model instance.
+    """
+    return _utils.load_hugingface_model(model_name="owkin/phikon-v2", out_indices=out_indices)
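
As with the other registered backbones, the new entry can be instantiated directly from the pathology package (a hedged sketch; weights are fetched from the HuggingFace Hub):

```python
from eva.vision.models.networks.backbones.pathology import owkin_phikon_v2

backbone = owkin_phikon_v2()                   # CLS-token features by default
backbone_ms = owkin_phikon_v2(out_indices=1)   # last-level patch embeddings instead
```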

eva/vision/models/networks/backbones/pathology/paige.py
@@ -0,0 +1,51 @@
+"""Pathology FMs from paige.ai.
+
+Source: https://huggingface.co/paige-ai/
+"""
+
+from typing import Tuple
+
+import timm
+import torch.nn as nn
+
+from eva.core.models import transforms
+from eva.vision.models import wrappers
+from eva.vision.models.networks.backbones import _utils
+from eva.vision.models.networks.backbones.registry import register_model
+
+
+@register_model("pathology/paige_virchow2")
+def paige_virchow2(
+    dynamic_img_size: bool = True,
+    out_indices: int | Tuple[int, ...] | None = None,
+    hf_token: str | None = None,
+    include_patch_tokens: bool = False,
+) -> nn.Module:
+    """Initializes the Virchow2 pathology FM by paige.ai.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+        include_patch_tokens: Whether to combine the mean aggregated patch tokens with cls token.
+        hf_token: HuggingFace token to download the model.
+
+    Returns:
+        The model instance.
+    """
+    _utils.huggingface_login(hf_token)
+    return wrappers.TimmModel(
+        model_name="hf-hub:paige-ai/Virchow2",
+        out_indices=out_indices,
+        pretrained=True,
+        model_kwargs={
+            "dynamic_img_size": dynamic_img_size,
+            "mlp_layer": timm.layers.SwiGLUPacked,
+            "act_layer": nn.SiLU,
+        },
+        tensor_transforms=(
+            transforms.ExtractCLSFeatures(include_patch_tokens=include_patch_tokens)
+            if out_indices is None
+            else None
+        ),
+    )
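
A usage sketch for the new Virchow2 entry point; the weights are gated, so a HuggingFace token (argument or HF_TOKEN environment variable) is required:

```python
from eva.vision.models.networks.backbones.pathology import paige_virchow2

# Placeholder token, illustration only; combines mean patch tokens with the CLS token.
backbone = paige_virchow2(hf_token="hf_xxx", include_patch_tokens=True)
```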

eva/vision/models/networks/decoders/__init__.py
@@ -1,6 +1,6 @@
 """Decoder heads API."""
 
 from eva.vision.models.networks.decoders import segmentation
-from eva.vision.models.networks.decoders.decoder import Decoder
+from eva.vision.models.networks.decoders.segmentation.base import Decoder
 
 __all__ = ["segmentation", "Decoder"]

eva/vision/models/networks/decoders/segmentation/__init__.py
@@ -1,11 +1,19 @@
 """Segmentation decoder heads API."""
 
-from eva.vision.models.networks.decoders.segmentation.common import (
+from eva.vision.models.networks.decoders.segmentation.decoder2d import Decoder2D
+from eva.vision.models.networks.decoders.segmentation.linear import LinearDecoder
+from eva.vision.models.networks.decoders.segmentation.semantic import (
     ConvDecoder1x1,
     ConvDecoderMS,
+    ConvDecoderWithImage,
     SingleLinearDecoder,
 )
-from eva.vision.models.networks.decoders.segmentation.conv2d import ConvDecoder
-from eva.vision.models.networks.decoders.segmentation.linear import LinearDecoder
 
-__all__ = ["ConvDecoder1x1", "ConvDecoderMS", "SingleLinearDecoder", "ConvDecoder", "LinearDecoder"]
+__all__ = [
+    "ConvDecoder1x1",
+    "ConvDecoderMS",
+    "SingleLinearDecoder",
+    "ConvDecoderWithImage",
+    "Decoder2D",
+    "LinearDecoder",
+]

eva/vision/models/networks/decoders/segmentation/base.py
@@ -0,0 +1,16 @@
+"""Semantic segmentation decoder base class."""
+
+import abc
+
+import torch
+from torch import nn
+
+from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs
+
+
+class Decoder(nn.Module, abc.ABC):
+    """Abstract base class for segmentation decoders."""
+
+    @abc.abstractmethod
+    def forward(self, decoder_inputs: DecoderInputs) -> torch.Tensor:
+        """Forward pass of the decoder."""

eva/vision/models/networks/decoders/segmentation/{conv2d.py → decoder2d.py}
@@ -1,19 +1,20 @@
 """Convolutional based semantic segmentation decoder."""
 
-from typing import List, Tuple
+from typing import List, Sequence, Tuple
 
 import torch
 from torch import nn
 from torch.nn import functional
 
-from eva.vision.models.networks.decoders import decoder
+from eva.vision.models.networks.decoders.segmentation import base
+from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs
 
 
-class ConvDecoder(decoder.Decoder):
-    """Convolutional segmentation decoder."""
+class Decoder2D(base.Decoder):
+    """Segmentation decoder for 2D applications."""
 
-    def __init__(self, layers: nn.Module) -> None:
-        """Initializes the convolutional based decoder head.
+    def __init__(self, layers: nn.Module, combine_features: bool = True) -> None:
+        """Initializes the based decoder head.
 
         Here the input nn layers will be directly applied to the
         features of shape (batch_size, hidden_size, n_patches_height,
@@ -21,13 +22,16 @@
         Note the n_patches is also known as grid_size.
 
         Args:
-            layers: The convolutional layers to be used as the decoder head.
+            layers: The layers to be used as the decoder head.
+            combine_features: Whether to combine the features from different
+                feature levels into one tensor before applying the decoder head.
         """
         super().__init__()
 
         self._layers = layers
+        self._combine_features = combine_features
 
-    def _forward_features(self, features: List[torch.Tensor]) -> torch.Tensor:
+    def _forward_features(self, features: torch.Tensor | List[torch.Tensor]) -> torch.Tensor:
         """Forward function for multi-level feature maps to a single one.
 
         It will interpolate the features and concat them into a single tensor
@@ -46,6 +50,8 @@
             A tensor of shape (batch_size, hidden_size, n_patches_height,
             n_patches_width) which is feature map of the decoder head.
         """
+        if isinstance(features, torch.Tensor):
+            features = [features]
         if not isinstance(features, list) or features[0].ndim != 4:
             raise ValueError(
                 "Input features should be a list of four (4) dimensional inputs of "
@@ -63,7 +69,9 @@
         ]
         return torch.cat(upsampled_features, dim=1)
 
-    def _forward_head(self, patch_embeddings: torch.Tensor) -> torch.Tensor:
+    def _forward_head(
+        self, patch_embeddings: torch.Tensor | Sequence[torch.Tensor]
+    ) -> torch.Tensor:
         """Forward of the decoder head.
 
         Args:
@@ -75,12 +83,12 @@
         """
         return self._layers(patch_embeddings)
 
-    def _cls_seg(
+    def _upscale(
         self,
         logits: torch.Tensor,
         image_size: Tuple[int, int],
     ) -> torch.Tensor:
-        """Classify each pixel of the image.
+        """Upscales the calculated logits to the target image size.
 
         Args:
             logits: The decoder outputs of shape (batch_size, n_classes,
@@ -93,22 +101,18 @@
         """
         return functional.interpolate(logits, image_size, mode="bilinear")
 
-    def forward(
-        self,
-        features: List[torch.Tensor],
-        image_size: Tuple[int, int],
-    ) -> torch.Tensor:
+    def forward(self, decoder_inputs: DecoderInputs) -> torch.Tensor:
        """Maps the patch embeddings to a segmentation mask of the image size.
 
         Args:
-            features: List of multi-level image features of shape (batch_size,
-                hidden_size, n_patches_height, n_patches_width).
-            image_size: The target image size (height, width).
+            decoder_inputs: Inputs required by the decoder.
 
         Returns:
             Tensor containing scores for all of the classes with shape
             (batch_size, n_classes, image_height, image_width).
         """
-        patch_embeddings = self._forward_features(features)
-        logits = self._forward_head(patch_embeddings)
-        return self._cls_seg(logits, image_size)
+        features, image_size, _ = DecoderInputs(*decoder_inputs)
+        if self._combine_features:
+            features = self._forward_features(features)
+        logits = self._forward_head(features)
+        return self._upscale(logits, image_size)
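
Putting the renamed decoder together with `DecoderInputs`, a minimal end-to-end sketch (hidden size and class count arbitrary; the `DecoderInputs` field order is assumed from the hunks above):

```python
import torch
from torch import nn

from eva.vision.models.networks.decoders.segmentation import Decoder2D
from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs

decoder = Decoder2D(layers=nn.Conv2d(768, 5, kernel_size=1))

features = [torch.rand(2, 768, 14, 14)]   # single feature level from the encoder
images = torch.rand(2, 3, 224, 224)       # original input batch
masks = decoder(DecoderInputs(features, images.shape[-2:], images))
print(masks.shape)  # torch.Size([2, 5, 224, 224])
```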

eva/vision/models/networks/decoders/segmentation/linear.py
@@ -6,10 +6,10 @@ import torch
 from torch import nn
 from torch.nn import functional
 
-from eva.vision.models.networks.decoders import decoder
+from eva.vision.models.networks.decoders.segmentation import base
 
 
-class LinearDecoder(decoder.Decoder):
+class LinearDecoder(base.Decoder):
     """Linear decoder."""
 
     def __init__(self, layers: nn.Module) -> None:

eva/vision/models/networks/decoders/segmentation/semantic/__init__.py
@@ -0,0 +1,12 @@
+"""Semantic Segmentation decoder heads API."""
+
+from eva.vision.models.networks.decoders.segmentation.semantic.common import (
+    ConvDecoder1x1,
+    ConvDecoderMS,
+    SingleLinearDecoder,
+)
+from eva.vision.models.networks.decoders.segmentation.semantic.with_image import (
+    ConvDecoderWithImage,
+)
+
+__all__ = ["ConvDecoder1x1", "ConvDecoderMS", "SingleLinearDecoder", "ConvDecoderWithImage"]

eva/vision/models/networks/decoders/segmentation/{common.py → semantic/common.py}
@@ -7,10 +7,10 @@ output by an encoder into pixel-wise predictions for segmentation tasks.
 
 from torch import nn
 
-from eva.vision.models.networks.decoders.segmentation import conv2d, linear
+from eva.vision.models.networks.decoders.segmentation import decoder2d, linear
 
 
-class ConvDecoder1x1(conv2d.ConvDecoder):
+class ConvDecoder1x1(decoder2d.Decoder2D):
     """A convolutional decoder with a single 1x1 convolutional layer."""
 
     def __init__(self, in_features: int, num_classes: int) -> None:
@@ -29,7 +29,7 @@ class ConvDecoder1x1(conv2d.ConvDecoder):
     )
 
 
-class ConvDecoderMS(conv2d.ConvDecoder):
+class ConvDecoderMS(decoder2d.Decoder2D):
     """A multi-stage convolutional decoder with upsampling and convolutional layers.
 
     This decoder applies a series of upsampling and convolutional layers to transform