kaiko-eva 0.1.1__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. eva/core/callbacks/writers/embeddings/base.py +3 -4
  2. eva/core/data/dataloaders/dataloader.py +2 -2
  3. eva/core/data/splitting/random.py +6 -5
  4. eva/core/data/splitting/stratified.py +12 -6
  5. eva/core/losses/__init__.py +5 -0
  6. eva/core/losses/cross_entropy.py +27 -0
  7. eva/core/metrics/__init__.py +0 -4
  8. eva/core/metrics/defaults/__init__.py +0 -2
  9. eva/core/models/modules/module.py +9 -9
  10. eva/core/models/transforms/extract_cls_features.py +17 -9
  11. eva/core/models/transforms/extract_patch_features.py +23 -11
  12. eva/core/utils/io/__init__.py +2 -1
  13. eva/core/utils/io/gz.py +28 -0
  14. eva/core/utils/multiprocessing.py +46 -1
  15. eva/core/utils/progress_bar.py +15 -0
  16. eva/vision/callbacks/loggers/batch/segmentation.py +7 -4
  17. eva/vision/data/datasets/__init__.py +4 -0
  18. eva/vision/data/datasets/classification/__init__.py +2 -1
  19. eva/vision/data/datasets/classification/camelyon16.py +4 -1
  20. eva/vision/data/datasets/classification/panda.py +17 -1
  21. eva/vision/data/datasets/classification/wsi.py +4 -1
  22. eva/vision/data/datasets/segmentation/__init__.py +2 -0
  23. eva/vision/data/datasets/segmentation/consep.py +2 -2
  24. eva/vision/data/datasets/segmentation/lits.py +49 -29
  25. eva/vision/data/datasets/segmentation/lits_balanced.py +93 -0
  26. eva/vision/data/datasets/segmentation/monusac.py +7 -7
  27. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +50 -18
  28. eva/vision/data/datasets/wsi.py +37 -1
  29. eva/vision/data/wsi/patching/coordinates.py +9 -1
  30. eva/vision/data/wsi/patching/samplers/_utils.py +2 -8
  31. eva/vision/data/wsi/patching/samplers/random.py +4 -2
  32. eva/vision/losses/__init__.py +2 -2
  33. eva/vision/losses/dice.py +75 -8
  34. eva/vision/metrics/__init__.py +11 -0
  35. eva/vision/metrics/defaults/__init__.py +7 -0
  36. eva/{core → vision}/metrics/defaults/segmentation/__init__.py +1 -1
  37. eva/{core → vision}/metrics/defaults/segmentation/multiclass.py +2 -1
  38. eva/vision/metrics/segmentation/BUILD +1 -0
  39. eva/vision/metrics/segmentation/__init__.py +9 -0
  40. eva/vision/metrics/segmentation/_utils.py +69 -0
  41. eva/{core/metrics → vision/metrics/segmentation}/generalized_dice.py +12 -10
  42. eva/vision/metrics/segmentation/mean_iou.py +57 -0
  43. eva/vision/models/modules/semantic_segmentation.py +4 -3
  44. eva/vision/models/networks/backbones/_utils.py +12 -0
  45. eva/vision/models/networks/backbones/pathology/__init__.py +4 -1
  46. eva/vision/models/networks/backbones/pathology/histai.py +8 -2
  47. eva/vision/models/networks/backbones/pathology/mahmood.py +2 -9
  48. eva/vision/models/networks/backbones/pathology/owkin.py +14 -0
  49. eva/vision/models/networks/backbones/pathology/paige.py +51 -0
  50. eva/vision/models/networks/decoders/__init__.py +1 -1
  51. eva/vision/models/networks/decoders/segmentation/__init__.py +12 -4
  52. eva/vision/models/networks/decoders/segmentation/base.py +16 -0
  53. eva/vision/models/networks/decoders/segmentation/{conv2d.py → decoder2d.py} +26 -22
  54. eva/vision/models/networks/decoders/segmentation/linear.py +2 -2
  55. eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +12 -0
  56. eva/vision/models/networks/decoders/segmentation/{common.py → semantic/common.py} +3 -3
  57. eva/vision/models/networks/decoders/segmentation/semantic/with_image.py +94 -0
  58. eva/vision/models/networks/decoders/segmentation/typings.py +18 -0
  59. eva/vision/utils/colormap.py +20 -0
  60. eva/vision/utils/io/__init__.py +7 -1
  61. eva/vision/utils/io/nifti.py +19 -4
  62. {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.5.dist-info}/METADATA +8 -39
  63. {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.5.dist-info}/RECORD +66 -52
  64. {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.5.dist-info}/WHEEL +1 -1
  65. eva/core/metrics/mean_iou.py +0 -120
  66. eva/vision/models/networks/decoders/decoder.py +0 -7
  67. {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.5.dist-info}/entry_points.txt +0 -0
  68. {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.5.dist-info}/licenses/LICENSE +0 -0
eva/vision/losses/dice.py CHANGED
@@ -1,4 +1,6 @@
-"""Dice loss."""
+"""Dice based loss functions."""
+
+from typing import Sequence, Tuple
 
 import torch
 from monai import losses
@@ -12,29 +14,94 @@ class DiceLoss(losses.DiceLoss): # type: ignore
     Extends the implementation from MONAI
     - to support semantic target labels (meaning targets of shape BHW)
     - to support `ignore_index` functionality
+    - accept weight argument in list format
     """
 
-    def __init__(self, *args, ignore_index: int | None = None, **kwargs) -> None:
-        """Initialize the DiceLoss with support for ignore_index.
+    def __init__(
+        self,
+        *args,
+        ignore_index: int | None = None,
+        weight: Sequence[float] | torch.Tensor | None = None,
+        **kwargs,
+    ) -> None:
+        """Initialize the DiceLoss.
 
         Args:
             args: Positional arguments from the base class.
             ignore_index: Specifies a target value that is ignored and
                 does not contribute to the input gradient.
+            weight: A list of weights to assign to each class.
             kwargs: Key-word arguments from the base class.
         """
-        super().__init__(*args, **kwargs)
+        if weight is not None and not isinstance(weight, torch.Tensor):
+            weight = torch.tensor(weight)
+
+        super().__init__(*args, **kwargs, weight=weight)
 
         self.ignore_index = ignore_index
 
     @override
     def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:  # noqa
-        if self.ignore_index is not None:
-            mask = targets != self.ignore_index
-            targets = targets * mask
-            inputs = torch.mul(inputs, mask.unsqueeze(1) if mask.ndim == 3 else mask)
+        inputs, targets = _apply_ignore_index(inputs, targets, self.ignore_index)
+        targets = _to_one_hot(targets, num_classes=inputs.shape[1])
 
         if targets.ndim == 3:
             targets = one_hot(targets[:, None, ...], num_classes=inputs.shape[1])
 
         return super().forward(inputs, targets)
+
+
+class DiceCELoss(losses.dice.DiceCELoss):
+    """Combination of Dice and Cross Entropy Loss.
+
+    Extends the implementation from MONAI
+    - to support semantic target labels (meaning targets of shape BHW)
+    - to support `ignore_index` functionality
+    - accept weight argument in list format
+    """
+
+    def __init__(
+        self,
+        *args,
+        ignore_index: int | None = None,
+        weight: Sequence[float] | torch.Tensor | None = None,
+        **kwargs,
+    ) -> None:
+        """Initialize the DiceCELoss.
+
+        Args:
+            args: Positional arguments from the base class.
+            ignore_index: Specifies a target value that is ignored and
+                does not contribute to the input gradient.
+            weight: A list of weights to assign to each class.
+            kwargs: Key-word arguments from the base class.
+        """
+        if weight is not None and not isinstance(weight, torch.Tensor):
+            weight = torch.tensor(weight)
+
+        super().__init__(*args, **kwargs, weight=weight)
+
+        self.ignore_index = ignore_index
+
+    @override
+    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:  # noqa
+        inputs, targets = _apply_ignore_index(inputs, targets, self.ignore_index)
+        targets = _to_one_hot(targets, num_classes=inputs.shape[1])
+
+        return super().forward(inputs, targets)
+
+
+def _apply_ignore_index(
+    inputs: torch.Tensor, targets: torch.Tensor, ignore_index: int | None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if ignore_index is not None:
+        mask = targets != ignore_index
+        targets = targets * mask
+        inputs = torch.mul(inputs, mask.unsqueeze(1) if mask.ndim == 3 else mask)
+    return inputs, targets
+
+
+def _to_one_hot(tensor: torch.Tensor, num_classes: int) -> torch.Tensor:
+    if tensor.ndim == 3:
+        return one_hot(tensor[:, None, ...], num_classes=num_classes)
+    return tensor
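Note (not part of the diff): with this change both `DiceLoss` and the new `DiceCELoss` accept semantic targets of shape BHW, a plain list as class `weight`, and an `ignore_index`. A minimal usage sketch, assuming `DiceCELoss` is re-exported from `eva.vision.losses` (the +2 -2 change to its `__init__.py` above suggests this) and that MONAI's `softmax` kwarg is forwarded through `**kwargs`:

    import torch
    from eva.vision.losses import DiceCELoss

    loss_fn = DiceCELoss(
        softmax=True,            # MONAI kwarg, forwarded via **kwargs
        ignore_index=255,        # pixels labeled 255 are masked out
        weight=[0.2, 1.0, 1.0],  # list is converted to a tensor internally
    )
    logits = torch.randn(2, 3, 64, 64)          # (B, C, H, W)
    targets = torch.randint(0, 3, (2, 64, 64))  # (B, H, W) semantic labels
    loss = loss_fn(logits, targets)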
eva/vision/metrics/__init__.py ADDED
@@ -0,0 +1,11 @@
+"""Default metric collections API."""
+
+from eva.vision.metrics.defaults.segmentation import MulticlassSegmentationMetrics
+from eva.vision.metrics.segmentation.generalized_dice import GeneralizedDiceScore
+from eva.vision.metrics.segmentation.mean_iou import MeanIoU
+
+__all__ = [
+    "MulticlassSegmentationMetrics",
+    "GeneralizedDiceScore",
+    "MeanIoU",
+]
eva/vision/metrics/defaults/__init__.py ADDED
@@ -0,0 +1,7 @@
+"""Default metric collections API."""
+
+from eva.vision.metrics.defaults.segmentation import MulticlassSegmentationMetrics
+
+__all__ = [
+    "MulticlassSegmentationMetrics",
+]
eva/{core → vision}/metrics/defaults/segmentation/__init__.py RENAMED
@@ -1,5 +1,5 @@
 """Default segmentation metric collections API."""
 
-from eva.core.metrics.defaults.segmentation.multiclass import MulticlassSegmentationMetrics
+from eva.vision.metrics.defaults.segmentation.multiclass import MulticlassSegmentationMetrics
 
 __all__ = ["MulticlassSegmentationMetrics"]
eva/{core → vision}/metrics/defaults/segmentation/multiclass.py RENAMED
@@ -1,6 +1,7 @@
 """Default metric collection for multiclass semantic segmentation tasks."""
 
-from eva.core.metrics import generalized_dice, mean_iou, structs
+from eva.core.metrics import structs
+from eva.vision.metrics.segmentation import generalized_dice, mean_iou
 
 
 class MulticlassSegmentationMetrics(structs.MetricCollection):
eva/vision/metrics/segmentation/BUILD ADDED
@@ -0,0 +1 @@
+python_sources()
eva/vision/metrics/segmentation/__init__.py ADDED
@@ -0,0 +1,9 @@
+"""Segmentation metrics API."""
+
+from eva.vision.metrics.segmentation.generalized_dice import GeneralizedDiceScore
+from eva.vision.metrics.segmentation.mean_iou import MeanIoU
+
+__all__ = [
+    "GeneralizedDiceScore",
+    "MeanIoU",
+]
eva/vision/metrics/segmentation/_utils.py ADDED
@@ -0,0 +1,69 @@
+"""Utils for segmentation metric collections."""
+
+from typing import Tuple
+
+import torch
+
+
+def apply_ignore_index(
+    preds: torch.Tensor, target: torch.Tensor, ignore_index: int, num_classes: int
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Applies the ignore index to the predictions and target tensors.
+
+    1. Masks the values in the target tensor that correspond to the ignored index.
+    2. Remove the channel corresponding to the ignored index from both tensors.
+
+    Args:
+        preds: The predictions tensor. Expected to be of shape `(N,C,...)`.
+        target: The target tensor. Expected to be of shape `(N,C,...)`.
+        ignore_index: The index to ignore.
+        num_classes: The number of classes.
+
+    Returns:
+        The modified predictions and target tensors of shape `(N,C-1,...)`.
+    """
+    if ignore_index < 0:
+        raise ValueError("ignore_index must be a non-negative integer")
+
+    ignore_mask = preds[:, ignore_index] == 1
+    target = target * (~ignore_mask.unsqueeze(1))
+
+    preds = _ignore_tensor_channel(preds, ignore_index)
+    target = _ignore_tensor_channel(target, ignore_index)
+
+    return preds, target
+
+
+def index_to_one_hot(tensor: torch.Tensor, num_classes: int) -> torch.Tensor:
+    """Converts an index tensor to a one-hot tensor.
+
+    Args:
+        tensor: The index tensor to convert. Expected to be of shape `(N,...)`.
+        num_classes: The number of classes to one-hot encode.
+
+    Returns:
+        A one-hot tensor of shape `(N,C,...)`.
+    """
+    if not _is_one_hot(tensor):
+        tensor = torch.nn.functional.one_hot(tensor.long(), num_classes=num_classes).movedim(-1, 1)
+    return tensor
+
+
+def _ignore_tensor_channel(tensor: torch.Tensor, ignore_index: int) -> torch.Tensor:
+    """Removes the channel corresponding to the specified ignore index.
+
+    Args:
+        tensor: The tensor to remove the channel from. Expected to be of shape `(N,C,...)`.
+        ignore_index: The index of the channel dimension (C) to remove.
+
+    Returns:
+        A tensor without the specified channel `(N,C-1,...)`.
+    """
+    if ignore_index < 0:
+        raise ValueError("ignore_index must be a non-negative integer")
+    return torch.cat([tensor[:, :ignore_index], tensor[:, ignore_index + 1 :]], dim=1)
+
+
+def _is_one_hot(tensor: torch.Tensor, expected_dim: int = 4) -> bool:
+    """Checks if the tensor is a one-hot tensor."""
+    return bool((tensor.bool() == tensor).all()) and tensor.ndim == expected_dim
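Note (not part of the diff): a small sketch of the helper semantics. Index tensors are one-hot encoded to `(N, C, ...)`, and applying the ignore index zeroes the target wherever the prediction is the ignored class and drops that channel from both tensors:

    import torch
    from eva.vision.metrics.segmentation import _utils

    target = _utils.index_to_one_hot(torch.tensor([[[0, 1], [2, 2]]]), num_classes=3)
    preds = _utils.index_to_one_hot(torch.tensor([[[0, 1], [1, 2]]]), num_classes=3)
    print(preds.shape)  # torch.Size([1, 3, 2, 2])

    preds, target = _utils.apply_ignore_index(preds, target, ignore_index=0, num_classes=3)
    print(preds.shape)  # torch.Size([1, 2, 2, 2]) -- channel 0 removed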
eva/{core/metrics → vision/metrics/segmentation}/generalized_dice.py RENAMED
@@ -6,6 +6,8 @@ import torch
 from torchmetrics import segmentation
 from typing_extensions import override
 
+from eva.vision.metrics.segmentation import _utils
+
 
 class GeneralizedDiceScore(segmentation.GeneralizedDiceScore):
     """Defines the Generalized Dice Score.
@@ -30,8 +32,6 @@ class GeneralizedDiceScore(segmentation.GeneralizedDiceScore):
             include_background: Whether to include the background class in the computation
             weight_type: The type of weight to apply to each class. Can be one of `"square"`,
                 `"simple"`, or `"linear"`.
-            input_format: What kind of input the function receives. Choose between ``"one-hot"``
-                for one-hot encoded tensors or ``"index"`` for index tensors.
             ignore_index: Integer specifying a target class to ignore. If given, this class
                 index does not contribute to the returned score, regardless of reduction method.
             per_class: Whether to compute the IoU for each class separately. If set to ``False``,
@@ -39,21 +39,23 @@ class GeneralizedDiceScore(segmentation.GeneralizedDiceScore):
             kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
         """
         super().__init__(
-            num_classes=num_classes,
+            num_classes=num_classes
+            - (ignore_index is not None)
+            + (ignore_index == 0 and not include_background),
             include_background=include_background,
             weight_type=weight_type,
             per_class=per_class,
             **kwargs,
         )
-
+        self.orig_num_classes = num_classes
         self.ignore_index = ignore_index
 
     @override
     def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
+        preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes)
+        target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes)
         if self.ignore_index is not None:
-            mask = target != self.ignore_index
-            mask = mask.all(dim=-1, keepdim=True)
-            preds = preds * mask
-            target = target * mask
-
-        super().update(preds=preds, target=target)
+            preds, target = _utils.apply_ignore_index(
+                preds, target, self.ignore_index, self.num_classes
+            )
+        super().update(preds=preds.long(), target=target.long())
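Note (not part of the diff): compared to the upstream `torchmetrics` class, `update` now also accepts index tensors of shape `(N, H, W)`, one-hot encoding them internally via `orig_num_classes`, and the class count passed to the base class shrinks when an `ignore_index` is set. A usage sketch:

    import torch
    from eva.vision.metrics import GeneralizedDiceScore

    metric = GeneralizedDiceScore(num_classes=3, per_class=True, ignore_index=0)
    preds = torch.randint(0, 3, (4, 16, 16))   # (N, H, W) index predictions
    target = torch.randint(0, 3, (4, 16, 16))
    metric.update(preds, target)
    print(metric.compute())  # scores for the two non-ignored classes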
eva/vision/metrics/segmentation/mean_iou.py ADDED
@@ -0,0 +1,57 @@
+"""MeanIoU metric for semantic segmentation."""
+
+from typing import Any
+
+import torch
+from torchmetrics import segmentation
+from typing_extensions import override
+
+from eva.vision.metrics.segmentation import _utils
+
+
+class MeanIoU(segmentation.MeanIoU):
+    """MeanIoU (mIOU) metric for semantic segmentation.
+
+    It expands the `torchmetrics` class by including an `ignore_index`
+    functionality.
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        include_background: bool = True,
+        ignore_index: int | None = None,
+        per_class: bool = False,
+        **kwargs: Any,
+    ) -> None:
+        """Initializes the metric.
+
+        Args:
+            num_classes: The number of classes in the segmentation problem.
+            include_background: Whether to include the background class in the computation
+            ignore_index: Integer specifying a target class to ignore. If given, this class
+                index does not contribute to the returned score, regardless of reduction method.
+            per_class: Whether to compute the IoU for each class separately. If set to ``False``,
+                the metric will compute the mean IoU over all classes.
+            kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+        """
+        super().__init__(
+            num_classes=num_classes
+            - (ignore_index is not None)
+            + (ignore_index == 0 and not include_background),
+            include_background=include_background,
+            per_class=per_class,
+            **kwargs,
+        )
+        self.orig_num_classes = num_classes
+        self.ignore_index = ignore_index
+
+    @override
+    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
+        preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes)
+        target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes)
+        if self.ignore_index is not None:
+            preds, target = _utils.apply_ignore_index(
+                preds, target, self.ignore_index, self.num_classes
+            )
+        super().update(preds=preds.long(), target=target.long())
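Note (not part of the diff): the same conventions apply as for `GeneralizedDiceScore` above. For example:

    import torch
    from eva.vision.metrics import MeanIoU

    metric = MeanIoU(num_classes=4, ignore_index=3, per_class=True)
    preds = torch.randint(0, 4, (2, 8, 8))
    target = torch.randint(0, 4, (2, 8, 8))
    metric.update(preds, target)   # index tensors are one-hot encoded internally
    print(metric.compute())        # one IoU per remaining class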
eva/vision/models/modules/semantic_segmentation.py CHANGED
@@ -15,6 +15,7 @@ from eva.core.models.modules.typings import INPUT_BATCH, INPUT_TENSOR_BATCH
 from eva.core.models.modules.utils import batch_postprocess, grad
 from eva.core.utils import parser
 from eva.vision.models.networks import decoders
+from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs
 
 
 class SemanticSegmentationModule(module.ModelModule):
@@ -101,9 +102,9 @@ class SemanticSegmentationModule(module.ModelModule):
                 "Please provide the expected `to_size` that the "
                 "decoder should map the embeddings (`inputs`) to."
             )
-
-        patch_embeddings = self.encoder(inputs) if self.encoder else inputs
-        return self.decoder(patch_embeddings, to_size or inputs.shape[-2:])
+        features = self.encoder(inputs) if self.encoder else inputs
+        decoder_inputs = DecoderInputs(features, to_size or inputs.shape[-2:], inputs)  # type: ignore
+        return self.decoder(decoder_inputs)
 
     @override
     def training_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
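Note: `eva/vision/models/networks/decoders/segmentation/typings.py` (+18) is not shown in this view. Judging from the call site above, `DecoderInputs` is presumably a NamedTuple along these lines (a hypothetical reconstruction; field names are guesses, not the package's actual definition):

    # Hypothetical sketch of DecoderInputs -- see typings.py for the real definition.
    from typing import NamedTuple, Tuple

    import torch

    class DecoderInputs(NamedTuple):
        features: torch.Tensor                  # encoder output (patch embeddings)
        image_size: Tuple[int, int]             # target (H, W) the decoder maps to
        images: torch.Tensor | None = None      # raw inputs, for image-aware decoders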
eva/vision/models/networks/backbones/_utils.py CHANGED
@@ -1,7 +1,9 @@
 """Utilis for backbone networks."""
 
+import os
 from typing import Any, Dict, Tuple
 
+import huggingface_hub
 from torch import nn
 
 from eva import models
@@ -37,3 +39,13 @@ def load_hugingface_model(
         tensor_transforms=tensor_transforms,
         model_kwargs=model_kwargs,
     )
+
+
+def huggingface_login(hf_token: str | None = None):
+    token = hf_token or os.environ.get("HF_TOKEN")
+    if not token:
+        raise ValueError(
+            "Please provide a HuggingFace token to download the model. "
+            "You can either pass it as an argument or set the env variable HF_TOKEN."
+        )
+    huggingface_hub.login(token=token)
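Note (not part of the diff): this helper centralizes the token handling previously inlined in `mahmood_uni`; both that backbone and the new `paige_virchow2` now route through it. Either of the following works:

    import os

    os.environ["HF_TOKEN"] = "hf_..."   # option 1: set the env variable
    _utils.huggingface_login()          # picks up HF_TOKEN

    _utils.huggingface_login("hf_...")  # option 2: pass the token explicitly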
eva/vision/models/networks/backbones/pathology/__init__.py CHANGED
@@ -12,7 +12,8 @@ from eva.vision.models.networks.backbones.pathology.kaiko import (
 )
 from eva.vision.models.networks.backbones.pathology.lunit import lunit_vits8, lunit_vits16
 from eva.vision.models.networks.backbones.pathology.mahmood import mahmood_uni
-from eva.vision.models.networks.backbones.pathology.owkin import owkin_phikon
+from eva.vision.models.networks.backbones.pathology.owkin import owkin_phikon, owkin_phikon_v2
+from eva.vision.models.networks.backbones.pathology.paige import paige_virchow2
 
 __all__ = [
     "kaiko_vitb16",
@@ -21,6 +22,7 @@ __all__ = [
     "kaiko_vits16",
     "kaiko_vits8",
     "owkin_phikon",
+    "owkin_phikon_v2",
     "lunit_vits16",
     "lunit_vits8",
     "mahmood_uni",
@@ -28,4 +30,5 @@ __all__ = [
     "prov_gigapath",
     "histai_hibou_b",
     "histai_hibou_l",
+    "paige_virchow2",
 ]
eva/vision/models/networks/backbones/pathology/histai.py CHANGED
@@ -12,6 +12,9 @@ from eva.vision.models.networks.backbones.registry import register_model
 def histai_hibou_b(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
     """Initializes the hibou-B pathology FM by hist.ai (https://huggingface.co/histai/hibou-B).
 
+    Uses a customized implementation of the DINOv2 architecture from the transformers
+    library to add support for registers, which requires the trust_remote_code=True flag.
+
     Args:
         out_indices: Whether and which multi-level patch embeddings to return.
             Currently only out_indices=1 is supported.
@@ -23,7 +26,7 @@ def histai_hibou_b(out_indices: int | Tuple[int, ...] | None = None) -> nn.Modul
         model_name="histai/hibou-B",
         out_indices=out_indices,
         model_kwargs={"trust_remote_code": True},
-        transform_args={"ignore_remaining_dims": True} if out_indices is not None else None,
+        transform_args={"num_register_tokens": 4} if out_indices is not None else None,
     )
 
 
@@ -31,6 +34,9 @@ def histai_hibou_b(out_indices: int | Tuple[int, ...] | None = None) -> nn.Modul
 def histai_hibou_l(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
     """Initializes the hibou-L pathology FM by hist.ai (https://huggingface.co/histai/hibou-L).
 
+    Uses a customized implementation of the DINOv2 architecture from the transformers
+    library to add support for registers, which requires the trust_remote_code=True flag.
+
     Args:
         out_indices: Whether and which multi-level patch embeddings to return.
             Currently only out_indices=1 is supported.
@@ -42,5 +48,5 @@ def histai_hibou_l(out_indices: int | Tuple[int, ...] | None = None) -> nn.Modul
         model_name="histai/hibou-L",
         out_indices=out_indices,
         model_kwargs={"trust_remote_code": True},
-        transform_args={"ignore_remaining_dims": True} if out_indices is not None else None,
+        transform_args={"num_register_tokens": 4} if out_indices is not None else None,
     )
eva/vision/models/networks/backbones/pathology/mahmood.py CHANGED
@@ -9,6 +9,7 @@ from loguru import logger
 from torch import nn
 
 from eva.vision.models import wrappers
+from eva.vision.models.networks.backbones import _utils
 from eva.vision.models.networks.backbones.registry import register_model
 
 
@@ -31,19 +32,11 @@ def mahmood_uni(
     Returns:
         The model instance.
     """
-    token = hf_token or os.environ.get("HF_TOKEN")
-    if not token:
-        raise ValueError(
-            "Please provide a HuggingFace token to download the model. "
-            "You can either pass it as an argument or set the env variable HF_TOKEN."
-        )
-
     checkpoint_path = os.path.join(download_dir, "pytorch_model.bin")
-
     if not os.path.exists(checkpoint_path):
         logger.info(f"Downloading the model checkpoint to {download_dir} ...")
         os.makedirs(download_dir, exist_ok=True)
-        huggingface_hub.login(token=token)
+        _utils.huggingface_login(hf_token)
         huggingface_hub.hf_hub_download(
             "MahmoodLab/UNI",
             filename="pytorch_model.bin",
eva/vision/models/networks/backbones/pathology/owkin.py CHANGED
@@ -20,3 +20,17 @@ def owkin_phikon(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
         The model instance.
     """
     return _utils.load_hugingface_model(model_name="owkin/phikon", out_indices=out_indices)
+
+
+@register_model("pathology/owkin_phikon_v2")
+def owkin_phikon_v2(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
+    """Initializes the phikon-v2 pathology FM by owkin (https://huggingface.co/owkin/phikon-v2).
+
+    Args:
+        out_indices: Whether and which multi-level patch embeddings to return.
+            Currently only out_indices=1 is supported.
+
+    Returns:
+        The model instance.
+    """
+    return _utils.load_hugingface_model(model_name="owkin/phikon-v2", out_indices=out_indices)
eva/vision/models/networks/backbones/pathology/paige.py ADDED
@@ -0,0 +1,51 @@
+"""Pathology FMs from paige.ai.
+
+Source: https://huggingface.co/paige-ai/
+"""
+
+from typing import Tuple
+
+import timm
+import torch.nn as nn
+
+from eva.core.models import transforms
+from eva.vision.models import wrappers
+from eva.vision.models.networks.backbones import _utils
+from eva.vision.models.networks.backbones.registry import register_model
+
+
+@register_model("pathology/paige_virchow2")
+def paige_virchow2(
+    dynamic_img_size: bool = True,
+    out_indices: int | Tuple[int, ...] | None = None,
+    hf_token: str | None = None,
+    include_patch_tokens: bool = False,
+) -> nn.Module:
+    """Initializes the Virchow2 pathology FM by paige.ai.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+        include_patch_tokens: Whether to combine the mean aggregated patch tokens with cls token.
+        hf_token: HuggingFace token to download the model.
+
+    Returns:
+        The model instance.
+    """
+    _utils.huggingface_login(hf_token)
+    return wrappers.TimmModel(
+        model_name="hf-hub:paige-ai/Virchow2",
+        out_indices=out_indices,
+        pretrained=True,
+        model_kwargs={
+            "dynamic_img_size": dynamic_img_size,
+            "mlp_layer": timm.layers.SwiGLUPacked,
+            "act_layer": nn.SiLU,
+        },
+        tensor_transforms=(
+            transforms.ExtractCLSFeatures(include_patch_tokens=include_patch_tokens)
+            if out_indices is None
+            else None
+        ),
+    )
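Note (not part of the diff): Virchow2 is a gated HuggingFace repository, so loading it needs a token with access (via `huggingface_login` above). A usage sketch:

    from eva.vision.models.networks.backbones.pathology import paige_virchow2

    # Downloads hf-hub:paige-ai/Virchow2 via timm; requires HF_TOKEN or hf_token=...
    model = paige_virchow2(include_patch_tokens=True)

With `out_indices=None` the wrapper applies `ExtractCLSFeatures`, which per its docstring can combine the class token with the mean-aggregated patch tokens when `include_patch_tokens=True`.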
eva/vision/models/networks/decoders/__init__.py CHANGED
@@ -1,6 +1,6 @@
 """Decoder heads API."""
 
 from eva.vision.models.networks.decoders import segmentation
-from eva.vision.models.networks.decoders.decoder import Decoder
+from eva.vision.models.networks.decoders.segmentation.base import Decoder
 
 __all__ = ["segmentation", "Decoder"]
eva/vision/models/networks/decoders/segmentation/__init__.py CHANGED
@@ -1,11 +1,19 @@
 """Segmentation decoder heads API."""
 
-from eva.vision.models.networks.decoders.segmentation.common import (
+from eva.vision.models.networks.decoders.segmentation.decoder2d import Decoder2D
+from eva.vision.models.networks.decoders.segmentation.linear import LinearDecoder
+from eva.vision.models.networks.decoders.segmentation.semantic import (
     ConvDecoder1x1,
     ConvDecoderMS,
+    ConvDecoderWithImage,
     SingleLinearDecoder,
 )
-from eva.vision.models.networks.decoders.segmentation.conv2d import ConvDecoder
-from eva.vision.models.networks.decoders.segmentation.linear import LinearDecoder
 
-__all__ = ["ConvDecoder1x1", "ConvDecoderMS", "SingleLinearDecoder", "ConvDecoder", "LinearDecoder"]
+__all__ = [
+    "ConvDecoder1x1",
+    "ConvDecoderMS",
+    "SingleLinearDecoder",
+    "ConvDecoderWithImage",
+    "Decoder2D",
+    "LinearDecoder",
+]
eva/vision/models/networks/decoders/segmentation/base.py ADDED
@@ -0,0 +1,16 @@
+"""Semantic segmentation decoder base class."""
+
+import abc
+
+import torch
+from torch import nn
+
+from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs
+
+
+class Decoder(nn.Module, abc.ABC):
+    """Abstract base class for segmentation decoders."""
+
+    @abc.abstractmethod
+    def forward(self, decoder_inputs: DecoderInputs) -> torch.Tensor:
+        """Forward pass of the decoder."""