kaiko-eva 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (159)
  1. eva/core/callbacks/__init__.py +2 -2
  2. eva/core/callbacks/writers/__init__.py +6 -3
  3. eva/core/callbacks/writers/embeddings/__init__.py +6 -0
  4. eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
  5. eva/core/callbacks/writers/embeddings/base.py +192 -0
  6. eva/core/callbacks/writers/embeddings/classification.py +117 -0
  7. eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
  8. eva/core/callbacks/writers/embeddings/typings.py +38 -0
  9. eva/core/data/datasets/__init__.py +2 -2
  10. eva/core/data/datasets/classification/__init__.py +8 -0
  11. eva/core/data/datasets/classification/embeddings.py +34 -0
  12. eva/core/data/datasets/{embeddings/classification → classification}/multi_embeddings.py +13 -9
  13. eva/core/data/datasets/{embeddings/base.py → embeddings.py} +47 -32
  14. eva/core/data/splitting/__init__.py +6 -0
  15. eva/core/data/splitting/random.py +41 -0
  16. eva/core/data/splitting/stratified.py +56 -0
  17. eva/core/loggers/experimental_loggers.py +2 -2
  18. eva/core/loggers/log/__init__.py +3 -2
  19. eva/core/loggers/log/image.py +71 -0
  20. eva/core/loggers/log/parameters.py +10 -0
  21. eva/core/loggers/loggers.py +6 -0
  22. eva/core/metrics/__init__.py +6 -2
  23. eva/core/metrics/defaults/__init__.py +10 -3
  24. eva/core/metrics/defaults/classification/__init__.py +1 -1
  25. eva/core/metrics/defaults/classification/binary.py +0 -9
  26. eva/core/metrics/defaults/classification/multiclass.py +0 -8
  27. eva/core/metrics/defaults/segmentation/__init__.py +5 -0
  28. eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
  29. eva/core/metrics/generalized_dice.py +59 -0
  30. eva/core/metrics/mean_iou.py +120 -0
  31. eva/core/metrics/structs/schemas.py +3 -1
  32. eva/core/models/__init__.py +3 -1
  33. eva/core/models/modules/head.py +10 -4
  34. eva/core/models/modules/typings.py +14 -1
  35. eva/core/models/modules/utils/batch_postprocess.py +37 -5
  36. eva/core/models/networks/__init__.py +1 -2
  37. eva/core/models/networks/mlp.py +2 -2
  38. eva/core/models/transforms/__init__.py +6 -0
  39. eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
  40. eva/core/models/transforms/extract_patch_features.py +47 -0
  41. eva/core/models/wrappers/__init__.py +13 -0
  42. eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
  43. eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
  44. eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
  45. eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
  46. eva/core/trainers/functional.py +1 -0
  47. eva/core/utils/__init__.py +6 -0
  48. eva/core/utils/clone.py +27 -0
  49. eva/core/utils/memory.py +28 -0
  50. eva/core/utils/operations.py +26 -0
  51. eva/core/utils/parser.py +20 -0
  52. eva/vision/__init__.py +2 -2
  53. eva/vision/callbacks/__init__.py +5 -0
  54. eva/vision/callbacks/loggers/__init__.py +5 -0
  55. eva/vision/callbacks/loggers/batch/__init__.py +5 -0
  56. eva/vision/callbacks/loggers/batch/base.py +130 -0
  57. eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
  58. eva/vision/data/datasets/__init__.py +30 -3
  59. eva/vision/data/datasets/_validators.py +15 -2
  60. eva/vision/data/datasets/classification/__init__.py +12 -1
  61. eva/vision/data/datasets/classification/bach.py +10 -15
  62. eva/vision/data/datasets/classification/base.py +17 -24
  63. eva/vision/data/datasets/classification/camelyon16.py +244 -0
  64. eva/vision/data/datasets/classification/crc.py +10 -15
  65. eva/vision/data/datasets/classification/mhist.py +10 -15
  66. eva/vision/data/datasets/classification/panda.py +184 -0
  67. eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
  68. eva/vision/data/datasets/classification/wsi.py +105 -0
  69. eva/vision/data/datasets/segmentation/__init__.py +15 -2
  70. eva/vision/data/datasets/segmentation/_utils.py +38 -0
  71. eva/vision/data/datasets/segmentation/base.py +16 -17
  72. eva/vision/data/datasets/segmentation/bcss.py +236 -0
  73. eva/vision/data/datasets/segmentation/consep.py +156 -0
  74. eva/vision/data/datasets/segmentation/embeddings.py +34 -0
  75. eva/vision/data/datasets/segmentation/lits.py +178 -0
  76. eva/vision/data/datasets/segmentation/monusac.py +236 -0
  77. eva/vision/data/datasets/segmentation/{total_segmentator.py → total_segmentator_2d.py} +130 -36
  78. eva/vision/data/datasets/wsi.py +187 -0
  79. eva/vision/data/transforms/__init__.py +3 -2
  80. eva/vision/data/transforms/common/__init__.py +2 -1
  81. eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
  82. eva/vision/data/transforms/common/resize_and_crop.py +6 -7
  83. eva/vision/data/transforms/normalization/__init__.py +6 -0
  84. eva/vision/data/transforms/normalization/clamp.py +43 -0
  85. eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
  86. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
  87. eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
  88. eva/vision/data/wsi/__init__.py +16 -0
  89. eva/vision/data/wsi/backends/__init__.py +69 -0
  90. eva/vision/data/wsi/backends/base.py +115 -0
  91. eva/vision/data/wsi/backends/openslide.py +73 -0
  92. eva/vision/data/wsi/backends/pil.py +52 -0
  93. eva/vision/data/wsi/backends/tiffslide.py +42 -0
  94. eva/vision/data/wsi/patching/__init__.py +6 -0
  95. eva/vision/data/wsi/patching/coordinates.py +98 -0
  96. eva/vision/data/wsi/patching/mask.py +123 -0
  97. eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
  98. eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
  99. eva/vision/data/wsi/patching/samplers/base.py +48 -0
  100. eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
  101. eva/vision/data/wsi/patching/samplers/grid.py +47 -0
  102. eva/vision/data/wsi/patching/samplers/random.py +41 -0
  103. eva/vision/losses/__init__.py +5 -0
  104. eva/vision/losses/dice.py +40 -0
  105. eva/vision/models/__init__.py +4 -2
  106. eva/vision/models/modules/__init__.py +5 -0
  107. eva/vision/models/modules/semantic_segmentation.py +161 -0
  108. eva/vision/models/networks/__init__.py +1 -2
  109. eva/vision/models/networks/backbones/__init__.py +6 -0
  110. eva/vision/models/networks/backbones/_utils.py +39 -0
  111. eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
  112. eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
  113. eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
  114. eva/vision/models/networks/backbones/pathology/histai.py +46 -0
  115. eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
  116. eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
  117. eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
  118. eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
  119. eva/vision/models/networks/backbones/registry.py +47 -0
  120. eva/vision/models/networks/backbones/timm/__init__.py +5 -0
  121. eva/vision/models/networks/backbones/timm/backbones.py +54 -0
  122. eva/vision/models/networks/backbones/universal/__init__.py +8 -0
  123. eva/vision/models/networks/backbones/universal/vit.py +54 -0
  124. eva/vision/models/networks/decoders/__init__.py +6 -0
  125. eva/vision/models/networks/decoders/decoder.py +7 -0
  126. eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
  127. eva/vision/models/networks/decoders/segmentation/common.py +74 -0
  128. eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
  129. eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
  130. eva/vision/models/wrappers/__init__.py +6 -0
  131. eva/vision/models/wrappers/from_registry.py +48 -0
  132. eva/vision/models/wrappers/from_timm.py +68 -0
  133. eva/vision/utils/colormap.py +77 -0
  134. eva/vision/utils/convert.py +56 -13
  135. eva/vision/utils/io/__init__.py +10 -4
  136. eva/vision/utils/io/image.py +21 -2
  137. eva/vision/utils/io/mat.py +36 -0
  138. eva/vision/utils/io/nifti.py +33 -12
  139. eva/vision/utils/io/text.py +10 -3
  140. kaiko_eva-0.1.0.dist-info/METADATA +553 -0
  141. kaiko_eva-0.1.0.dist-info/RECORD +205 -0
  142. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
  143. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
  144. eva/.DS_Store +0 -0
  145. eva/core/callbacks/writers/embeddings.py +0 -169
  146. eva/core/callbacks/writers/typings.py +0 -23
  147. eva/core/data/datasets/embeddings/__init__.py +0 -13
  148. eva/core/data/datasets/embeddings/classification/__init__.py +0 -10
  149. eva/core/data/datasets/embeddings/classification/embeddings.py +0 -66
  150. eva/core/models/networks/transforms/__init__.py +0 -5
  151. eva/core/models/networks/wrappers/__init__.py +0 -8
  152. eva/vision/models/.DS_Store +0 -0
  153. eva/vision/models/networks/.DS_Store +0 -0
  154. eva/vision/models/networks/postprocesses/__init__.py +0 -5
  155. eva/vision/models/networks/postprocesses/cls.py +0 -25
  156. kaiko_eva-0.0.2.dist-info/METADATA +0 -431
  157. kaiko_eva-0.0.2.dist-info/RECORD +0 -127
  158. /eva/core/models/{networks → wrappers}/_utils.py +0 -0
  159. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
eva/vision/data/wsi/patching/samplers/_utils.py
@@ -0,0 +1,50 @@
+ import random
+ from typing import Tuple
+
+ import numpy as np
+
+
+ def set_seed(seed: int) -> None:
+     random.seed(seed)
+     np.random.seed(seed)
+
+
+ def get_grid_coords_and_indices(
+     layer_shape: Tuple[int, int],
+     width: int,
+     height: int,
+     overlap: Tuple[int, int],
+     shuffle: bool = True,
+     seed: int = 42,
+ ):
+     """Get grid coordinates and indices.
+
+     Args:
+         layer_shape: The shape of the layer.
+         width: The width of the patches.
+         height: The height of the patches.
+         overlap: The overlap between patches in the grid.
+         shuffle: Whether to shuffle the indices.
+         seed: The random seed.
+     """
+     x_range = range(0, layer_shape[0] - width + 1, width - overlap[0])
+     y_range = range(0, layer_shape[1] - height + 1, height - overlap[1])
+     x_y = [(x, y) for x in x_range for y in y_range]
+
+     indices = list(range(len(x_y)))
+     if shuffle:
+         set_seed(seed)
+         np.random.shuffle(indices)
+     return x_y, indices
+
+
+ def validate_dimensions(width: int, height: int, layer_shape: Tuple[int, int]) -> None:
+     """Checks if the width / height is bigger than the layer shape.
+
+     Args:
+         width: The width of the patches.
+         height: The height of the patches.
+         layer_shape: The shape of the layer.
+     """
+     if width > layer_shape[0] or height > layer_shape[1]:
+         raise ValueError("The width / height cannot be bigger than the layer shape.")
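As a quick sanity check of the grid helper above, a minimal sketch (assuming only this module and `numpy`): a 100×100 layer tiled into 50×50 patches with no overlap yields four patch origins.

```python
from eva.vision.data.wsi.patching.samplers import _utils

# 100x100 layer, 50x50 patches, no overlap -> a 2x2 grid of origins.
x_y, indices = _utils.get_grid_coords_and_indices(
    layer_shape=(100, 100), width=50, height=50, overlap=(0, 0), shuffle=False
)
print(x_y)      # [(0, 0), (0, 50), (50, 0), (50, 50)]
print(indices)  # [0, 1, 2, 3]
```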
eva/vision/data/wsi/patching/samplers/base.py
@@ -0,0 +1,48 @@
+ """Base classes for samplers."""
+
+ import abc
+ from typing import Generator, Tuple
+
+ from eva.vision.data.wsi.patching.mask import Mask
+
+
+ class Sampler(abc.ABC):
+     """Base class for samplers."""
+
+     @abc.abstractmethod
+     def sample(
+         self,
+         width: int,
+         height: int,
+         layer_shape: Tuple[int, int],
+         mask: Mask | None = None,
+     ) -> Generator[Tuple[int, int], None, None]:
+         """Sample patch coordinates.
+
+         Args:
+             width: The width of the patches.
+             height: The height of the patches.
+             layer_shape: The shape of the layer.
+             mask: Tuple containing the mask array and the scaling factor with respect to the
+                 provided layer_shape. Optional, only required for samplers with foreground
+                 filtering.
+
+         Returns:
+             A generator producing sampled patch coordinates.
+         """
+
+
+ class ForegroundSampler(Sampler):
+     """Base class for samplers with foreground filtering capabilities."""
+
+     @abc.abstractmethod
+     def is_foreground(
+         self,
+         mask: Mask,
+         x: int,
+         y: int,
+         width: int,
+         height: int,
+         min_foreground_ratio: float,
+     ) -> bool:
+         """Check if a patch contains sufficient foreground."""
eva/vision/data/wsi/patching/samplers/foreground_grid.py
@@ -0,0 +1,99 @@
+ """Foreground grid sampler."""
+
+ from typing import Tuple
+
+ from eva.vision.data.wsi.patching.mask import Mask
+ from eva.vision.data.wsi.patching.samplers import _utils, base
+
+
+ class ForegroundGridSampler(base.ForegroundSampler):
+     """Sample patches based on a grid, only returning patches containing foreground."""
+
+     def __init__(
+         self,
+         max_samples: int = 20,
+         overlap: Tuple[int, int] = (0, 0),
+         min_foreground_ratio: float = 0.35,
+         seed: int = 42,
+     ) -> None:
+         """Initializes the sampler.
+
+         Args:
+             max_samples: The maximum number of samples to return.
+             overlap: The overlap between patches in the grid.
+             min_foreground_ratio: The minimum amount of foreground
+                 within a sampled patch.
+             seed: The random seed.
+         """
+         self.max_samples = max_samples
+         self.overlap = overlap
+         self.min_foreground_ratio = min_foreground_ratio
+         self.seed = seed
+
+     def sample(
+         self,
+         width: int,
+         height: int,
+         layer_shape: Tuple[int, int],
+         mask: Mask,
+     ):
+         """Sample patches from a grid containing foreground.
+
+         Args:
+             width: The width of the patches.
+             height: The height of the patches.
+             layer_shape: The shape of the layer.
+             mask: The mask of the image.
+         """
+         _utils.validate_dimensions(width, height, layer_shape)
+         x_y, indices = _utils.get_grid_coords_and_indices(
+             layer_shape, width, height, self.overlap, seed=self.seed
+         )
+
+         count = 0
+         for i in indices:
+             if count >= self.max_samples:
+                 break
+
+             if self.is_foreground(
+                 mask=mask,
+                 x=x_y[i][0],
+                 y=x_y[i][1],
+                 width=width,
+                 height=height,
+                 min_foreground_ratio=self.min_foreground_ratio,
+             ):
+                 count += 1
+                 yield x_y[i]
+
+     def is_foreground(
+         self,
+         mask: Mask,
+         x: int,
+         y: int,
+         width: int,
+         height: int,
+         min_foreground_ratio: float,
+     ) -> bool:
+         """Check if a patch contains sufficient foreground.
+
+         Args:
+             mask: The mask of the image.
+             x: The x-coordinate of the patch.
+             y: The y-coordinate of the patch.
+             width: The width of the patch.
+             height: The height of the patch.
+             min_foreground_ratio: The minimum amount of foreground in the patch.
+         """
+         x_, y_ = self._scale_coords(x, y, mask.scale_factors)
+         width_, height_ = self._scale_coords(width, height, mask.scale_factors)
+         patch_mask = mask.mask_array[y_ : y_ + height_, x_ : x_ + width_]
+         return patch_mask.sum() / patch_mask.size >= min_foreground_ratio
+
+     def _scale_coords(
+         self,
+         x: int,
+         y: int,
+         scale_factors: Tuple[float, float],
+     ) -> Tuple[int, int]:
+         return int(x / scale_factors[0]), int(y / scale_factors[1])
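The foreground check only reads `mask.mask_array` and `mask.scale_factors`, so a rough usage sketch can stand in a plain namespace object for the real `Mask` (whose constructor lives in `patching/mask.py` and is not shown in this diff):

```python
from types import SimpleNamespace

import numpy as np

from eva.vision.data.wsi.patching.samplers.foreground_grid import ForegroundGridSampler

# Stand-in for eva's Mask: only .mask_array and .scale_factors are accessed here.
mask_array = np.zeros((50, 50), dtype=np.uint8)
mask_array[:25, :25] = 1  # top-left quadrant is tissue/foreground
mask = SimpleNamespace(mask_array=mask_array, scale_factors=(2.0, 2.0))

sampler = ForegroundGridSampler(max_samples=4, min_foreground_ratio=0.5)
coords = list(sampler.sample(width=25, height=25, layer_shape=(100, 100), mask=mask))
print(coords)  # only the four origins overlapping the foreground quadrant
```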
eva/vision/data/wsi/patching/samplers/grid.py
@@ -0,0 +1,47 @@
+ """Grid sampler."""
+
+ from typing import Generator, Tuple
+
+ from eva.vision.data.wsi.patching.samplers import _utils, base
+
+
+ class GridSampler(base.Sampler):
+     """Sample patches based on a grid.
+
+     Args:
+         max_samples: The maximum number of samples to return.
+         overlap: The overlap between patches in the grid.
+         seed: The random seed.
+     """
+
+     def __init__(
+         self,
+         max_samples: int | None = None,
+         overlap: Tuple[int, int] = (0, 0),
+         seed: int = 42,
+     ):
+         """Initializes the sampler."""
+         self.max_samples = max_samples
+         self.overlap = overlap
+         self.seed = seed
+
+     def sample(
+         self,
+         width: int,
+         height: int,
+         layer_shape: Tuple[int, int],
+     ) -> Generator[Tuple[int, int], None, None]:
+         """Sample patches from a grid.
+
+         Args:
+             width: The width of the patches.
+             height: The height of the patches.
+             layer_shape: The shape of the layer.
+         """
+         _utils.validate_dimensions(width, height, layer_shape)
+         x_y, indices = _utils.get_grid_coords_and_indices(
+             layer_shape, width, height, self.overlap, seed=self.seed
+         )
+         max_samples = len(indices) if self.max_samples is None else self.max_samples
+         for i in indices[:max_samples]:
+             yield x_y[i]
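A usage sketch: with `max_samples=None` the sampler exhausts the grid, and the fixed seed makes the shuffled traversal order reproducible.

```python
from eva.vision.data.wsi.patching.samplers.grid import GridSampler

sampler = GridSampler(max_samples=None, overlap=(0, 0), seed=42)
for x, y in sampler.sample(width=50, height=50, layer_shape=(100, 100)):
    print(x, y)  # the four 50x50 grid origins, in seeded shuffled order
```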
eva/vision/data/wsi/patching/samplers/random.py
@@ -0,0 +1,41 @@
+ """Random sampler."""
+
+ import random
+ from typing import Generator, Tuple
+
+ from eva.vision.data.wsi.patching.samplers import _utils, base
+
+
+ class RandomSampler(base.Sampler):
+     """Sample patch coordinates randomly.
+
+     Args:
+         n_samples: The number of samples to return.
+         seed: The random seed.
+     """
+
+     def __init__(self, n_samples: int = 1, seed: int = 42):
+         """Initializes the sampler."""
+         self.seed = seed
+         self.n_samples = n_samples
+
+     def sample(
+         self,
+         width: int,
+         height: int,
+         layer_shape: Tuple[int, int],
+     ) -> Generator[Tuple[int, int], None, None]:
+         """Sample random patches.
+
+         Args:
+             width: The width of the patches.
+             height: The height of the patches.
+             layer_shape: The shape of the layer.
+         """
+         _utils.validate_dimensions(width, height, layer_shape)
+         _utils.set_seed(self.seed)
+
+         x_max, y_max = layer_shape[0], layer_shape[1]
+         for _ in range(self.n_samples):
+             x, y = random.randint(0, x_max - width), random.randint(0, y_max - height)  # nosec
+             yield x, y
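And the random counterpart, seeded for reproducibility:

```python
from eva.vision.data.wsi.patching.samplers.random import RandomSampler

sampler = RandomSampler(n_samples=3, seed=42)
print(list(sampler.sample(width=224, height=224, layer_shape=(1000, 1000))))
```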
eva/vision/losses/__init__.py
@@ -0,0 +1,5 @@
+ """Loss functions API."""
+
+ from eva.vision.losses.dice import DiceLoss
+
+ __all__ = ["DiceLoss"]
eva/vision/losses/dice.py
@@ -0,0 +1,40 @@
+ """Dice loss."""
+
+ import torch
+ from monai import losses
+ from monai.networks import one_hot  # type: ignore
+ from typing_extensions import override
+
+
+ class DiceLoss(losses.DiceLoss):  # type: ignore
+     """Computes the average Dice loss between two tensors.
+
+     Extends the implementation from MONAI
+     - to support semantic target labels (meaning targets of shape BHW)
+     - to support `ignore_index` functionality
+     """
+
+     def __init__(self, *args, ignore_index: int | None = None, **kwargs) -> None:
+         """Initialize the DiceLoss with support for ignore_index.
+
+         Args:
+             args: Positional arguments from the base class.
+             ignore_index: Specifies a target value that is ignored and
+                 does not contribute to the input gradient.
+             kwargs: Keyword arguments from the base class.
+         """
+         super().__init__(*args, **kwargs)
+
+         self.ignore_index = ignore_index
+
+     @override
+     def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:  # noqa
+         if self.ignore_index is not None:
+             mask = targets != self.ignore_index
+             targets = targets * mask
+             inputs = torch.mul(inputs, mask.unsqueeze(1) if mask.ndim == 3 else mask)
+
+         if targets.ndim == 3:
+             targets = one_hot(targets[:, None, ...], num_classes=inputs.shape[1])
+
+         return super().forward(inputs, targets)
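A usage sketch under the assumption that `monai` is installed (`softmax=True` is an option of the MONAI base class): semantic targets of shape BHW are accepted directly, and pixels labelled with `ignore_index` are zeroed out of both inputs and targets before the Dice computation.

```python
import torch

from eva.vision.losses import DiceLoss

criterion = DiceLoss(softmax=True, ignore_index=255)
logits = torch.randn(2, 3, 8, 8)           # (B, num_classes, H, W)
targets = torch.randint(0, 3, (2, 8, 8))   # semantic labels of shape BHW
targets[0, :2, :2] = 255                   # region excluded from the loss
print(criterion(logits, targets).item())
```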
eva/vision/models/__init__.py
@@ -1,5 +1,7 @@
  """Vision Models API."""

- from eva.vision.models import networks
+ from eva.vision.models import networks, wrappers
+ from eva.vision.models.networks import backbones
+ from eva.vision.models.wrappers import ModelFromRegistry, TimmModel

- __all__ = ["networks"]
+ __all__ = ["networks", "wrappers", "backbones", "ModelFromRegistry", "TimmModel"]
eva/vision/models/modules/__init__.py
@@ -0,0 +1,5 @@
+ """Vision modules API."""
+
+ from eva.vision.models.modules.semantic_segmentation import SemanticSegmentationModule
+
+ __all__ = ["SemanticSegmentationModule"]
eva/vision/models/modules/semantic_segmentation.py
@@ -0,0 +1,161 @@
+ """Neural Network Semantic Segmentation Module."""
+
+ from typing import Any, Callable, Dict, Iterable, List, Tuple
+
+ import torch
+ from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
+ from torch import nn, optim
+ from torch.optim import lr_scheduler
+ from typing_extensions import override
+
+ from eva.core.metrics import structs as metrics_lib
+ from eva.core.models.modules import module
+ from eva.core.models.modules.typings import INPUT_BATCH, INPUT_TENSOR_BATCH
+ from eva.core.models.modules.utils import batch_postprocess, grad
+ from eva.core.utils import parser
+ from eva.vision.models.networks import decoders
+
+
+ class SemanticSegmentationModule(module.ModelModule):
+     """Neural network semantic segmentation module for training on patch embeddings."""
+
+     def __init__(
+         self,
+         decoder: decoders.Decoder,
+         criterion: Callable[..., torch.Tensor],
+         encoder: Dict[str, Any] | Callable[[torch.Tensor], List[torch.Tensor]] | None = None,
+         lr_multiplier_encoder: float = 0.0,
+         optimizer: OptimizerCallable = optim.AdamW,
+         lr_scheduler: LRSchedulerCallable = lr_scheduler.ConstantLR,
+         metrics: metrics_lib.MetricsSchema | None = None,
+         postprocess: batch_postprocess.BatchPostProcess | None = None,
+     ) -> None:
+         """Initializes the semantic segmentation module.
+
+         Args:
+             decoder: The decoder model.
+             criterion: The loss function to use.
+             encoder: The encoder model. If `None`, it will be expected
+                 that the input batch returns the features directly.
+                 If passed as a dictionary, it will be parsed to an object
+                 during the `configure_model` step.
+             lr_multiplier_encoder: The learning rate multiplier for the
+                 encoder parameters. If `0`, it will freeze the encoder.
+             optimizer: The optimizer to use.
+             lr_scheduler: The learning rate scheduler to use.
+             metrics: The metric groups to track.
+             postprocess: A list of helper functions to apply after the
+                 loss and before the metrics calculation to the model
+                 predictions and targets.
+         """
+         super().__init__(metrics=metrics, postprocess=postprocess)
+
+         self.decoder = decoder
+         self.criterion = criterion
+         self.encoder = encoder  # type: ignore
+         self.lr_multiplier_encoder = lr_multiplier_encoder
+         self.optimizer = optimizer
+         self.lr_scheduler = lr_scheduler
+
+     @override
+     def configure_model(self) -> None:
+         self._freeze_encoder()
+
+         if isinstance(self.encoder, dict):
+             self.encoder: Callable[[torch.Tensor], List[torch.Tensor]] = parser.parse_object(
+                 self.encoder,
+                 expected_type=nn.Module,
+             )
+
+     @override
+     def configure_optimizers(self) -> Any:
+         optimizer = self.optimizer(
+             [
+                 {"params": self.decoder.parameters()},
+                 {
+                     "params": self._encoder_trainable_parameters(),
+                     "lr": self._base_lr * self.lr_multiplier_encoder,
+                 },
+             ]
+         )
+         lr_scheduler = self.lr_scheduler(optimizer)
+         return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
+
+     @override
+     def forward(
+         self,
+         inputs: torch.Tensor,
+         to_size: Tuple[int, int] | None = None,
+         *args: Any,
+         **kwargs: Any,
+     ) -> torch.Tensor:
+         """Maps the input tensor (image tensor or embeddings) to masks.
+
+         If `inputs` is an image tensor, then `self.encoder`
+         should be implemented, otherwise it will be interpreted
+         as embeddings, where `to_size` should be given.
+         """
+         if self.encoder is None and to_size is None:
+             raise ValueError(
+                 "Please provide the expected `to_size` that the "
+                 "decoder should map the embeddings (`inputs`) to."
+             )
+
+         patch_embeddings = self.encoder(inputs) if self.encoder else inputs
+         return self.decoder(patch_embeddings, to_size or inputs.shape[-2:])
+
+     @override
+     def training_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+         return self._batch_step(batch)
+
+     @override
+     def validation_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+         return self._batch_step(batch)
+
+     @override
+     def test_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+         return self._batch_step(batch)
+
+     @override
+     def predict_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> torch.Tensor:
+         tensor = INPUT_BATCH(*batch).data
+         return self.encoder(tensor) if isinstance(self.encoder, nn.Module) else tensor
+
+     @property
+     def _base_lr(self) -> float:
+         """Returns the base learning rate."""
+         base_optimizer = self.optimizer(self.parameters())
+         return base_optimizer.param_groups[-1]["lr"]
+
+     def _encoder_trainable_parameters(self) -> Iterable[torch.Tensor]:
+         """Returns the trainable parameters of the encoder."""
+         return (
+             self.encoder.parameters()
+             if isinstance(self.encoder, nn.Module) and self.lr_multiplier_encoder > 0
+             else iter(())
+         )
+
+     def _freeze_encoder(self) -> None:
+         """If initialized, it freezes the encoder network."""
+         if isinstance(self.encoder, nn.Module) and self.lr_multiplier_encoder == 0:
+             grad.deactivate_requires_grad(self.encoder)
+
+     def _batch_step(self, batch: INPUT_TENSOR_BATCH) -> STEP_OUTPUT:
+         """Performs a model forward step and calculates the loss.
+
+         Args:
+             batch: The desired batch to process.
+
+         Returns:
+             The batch step output.
+         """
+         data, targets, metadata = INPUT_TENSOR_BATCH(*batch)
+         predictions = self(data, to_size=targets.shape[-2:])
+         loss = self.criterion(predictions, targets)
+         return {
+             "loss": loss,
+             "targets": targets,
+             "predictions": predictions,
+             "metadata": metadata,
+         }
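To show how these pieces connect, a rough sketch: the decoder below is a hypothetical stand-in (the constructors of the `decoders/segmentation` classes added in this release are not shown here), and with `encoder=None` the module maps pre-computed patch embeddings straight to masks of the requested size.

```python
import torch
from torch import nn

from eva.vision.losses import DiceLoss
from eva.vision.models.modules import SemanticSegmentationModule


class ToyDecoder(nn.Module):
    """Hypothetical decoder: per-pixel class logits, upsampled to `to_size`."""

    def __init__(self, in_features: int, num_classes: int) -> None:
        super().__init__()
        self.classifier = nn.Conv2d(in_features, num_classes, kernel_size=1)

    def forward(self, patch_embeddings: torch.Tensor, to_size) -> torch.Tensor:
        logits = self.classifier(patch_embeddings)
        return nn.functional.interpolate(logits, size=tuple(to_size), mode="bilinear")


module = SemanticSegmentationModule(
    decoder=ToyDecoder(in_features=384, num_classes=3),
    criterion=DiceLoss(softmax=True, ignore_index=255),
    encoder=None,  # batches already contain embeddings
)
masks = module(torch.randn(2, 384, 14, 14), to_size=(224, 224))
print(masks.shape)  # torch.Size([2, 3, 224, 224])
```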
eva/vision/models/networks/__init__.py
@@ -1,6 +1,5 @@
  """Vision Networks API."""

- from eva.vision.models.networks import postprocesses
  from eva.vision.models.networks.abmil import ABMIL

- __all__ = ["postprocesses", "ABMIL"]
+ __all__ = ["ABMIL"]
eva/vision/models/networks/backbones/__init__.py
@@ -0,0 +1,6 @@
+ """Vision Model Backbones API."""
+
+ from eva.vision.models.networks.backbones import pathology, timm, universal
+ from eva.vision.models.networks.backbones.registry import BackboneModelRegistry, register_model
+
+ __all__ = ["pathology", "timm", "universal", "BackboneModelRegistry", "register_model"]
eva/vision/models/networks/backbones/_utils.py
@@ -0,0 +1,39 @@
+ """Utils for backbone networks."""
+
+ from typing import Any, Dict, Tuple
+
+ from torch import nn
+
+ from eva import models
+ from eva.core.models import transforms
+
+
+ def load_hugingface_model(
+     model_name: str,
+     out_indices: int | Tuple[int, ...] | None,
+     model_kwargs: Dict[str, Any] | None = None,
+     transform_args: Dict[str, Any] | None = None,
+ ) -> nn.Module:
+     """Helper function to load HuggingFace models.
+
+     Args:
+         model_name: The model name to load.
+         out_indices: Whether and which multi-level patch embeddings to return.
+             Currently only out_indices=1 is supported.
+         model_kwargs: The arguments used for instantiating the model.
+         transform_args: The arguments used for instantiating the transform.
+
+     Returns: The model instance.
+     """
+     if out_indices is None:
+         tensor_transforms = transforms.ExtractCLSFeatures(**(transform_args or {}))
+     elif out_indices == 1:
+         tensor_transforms = transforms.ExtractPatchFeatures(**(transform_args or {}))
+     else:
+         raise ValueError(f"out_indices={out_indices} is currently not supported.")
+
+     return models.HuggingFaceModel(
+         model_name_or_path=model_name,
+         tensor_transforms=tensor_transforms,
+         model_kwargs=model_kwargs,
+     )
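A usage sketch of the helper (note the function is spelled `load_hugingface_model` in this release; `owkin/phikon` is used purely as an illustrative HuggingFace model id):

```python
from eva.vision.models.networks.backbones import _utils

# CLS features only (out_indices=None); downloads weights from the HF hub.
model = _utils.load_hugingface_model(model_name="owkin/phikon", out_indices=None)
```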
eva/vision/models/networks/backbones/pathology/__init__.py
@@ -0,0 +1,31 @@
+ """Vision Pathology Model Backbones API."""
+
+ from eva.vision.models.networks.backbones.pathology.bioptimus import bioptimus_h_optimus_0
+ from eva.vision.models.networks.backbones.pathology.gigapath import prov_gigapath
+ from eva.vision.models.networks.backbones.pathology.histai import histai_hibou_b, histai_hibou_l
+ from eva.vision.models.networks.backbones.pathology.kaiko import (
+     kaiko_vitb8,
+     kaiko_vitb16,
+     kaiko_vitl14,
+     kaiko_vits8,
+     kaiko_vits16,
+ )
+ from eva.vision.models.networks.backbones.pathology.lunit import lunit_vits8, lunit_vits16
+ from eva.vision.models.networks.backbones.pathology.mahmood import mahmood_uni
+ from eva.vision.models.networks.backbones.pathology.owkin import owkin_phikon
+
+ __all__ = [
+     "kaiko_vitb16",
+     "kaiko_vitb8",
+     "kaiko_vitl14",
+     "kaiko_vits16",
+     "kaiko_vits8",
+     "owkin_phikon",
+     "lunit_vits16",
+     "lunit_vits8",
+     "mahmood_uni",
+     "bioptimus_h_optimus_0",
+     "prov_gigapath",
+     "histai_hibou_b",
+     "histai_hibou_l",
+ ]
eva/vision/models/networks/backbones/pathology/bioptimus.py
@@ -0,0 +1,34 @@
+ """Pathology FMs from Bioptimus."""
+
+ from typing import Tuple
+
+ import timm
+ from torch import nn
+
+ from eva.vision.models.networks.backbones.registry import register_model
+
+
+ @register_model("pathology/bioptimus_h_optimus_0")
+ def bioptimus_h_optimus_0(
+     dynamic_img_size: bool = True,
+     out_indices: int | Tuple[int, ...] | None = None,
+ ) -> nn.Module:
+     """Initializes the h_optimus_0 pathology FM by Bioptimus.
+
+     Args:
+         dynamic_img_size: Whether to allow the position embeddings
+             to be interpolated at `forward()` time when the image grid
+             changes from the original.
+         out_indices: Whether and which multi-level patch embeddings to return.
+
+     Returns:
+         The model instance.
+     """
+     return timm.create_model(
+         model_name="hf-hub:bioptimus/H-optimus-0",
+         pretrained=True,
+         init_values=1e-5,
+         dynamic_img_size=dynamic_img_size,
+         out_indices=out_indices,
+         features_only=out_indices is not None,
+     )
eva/vision/models/networks/backbones/pathology/gigapath.py
@@ -0,0 +1,33 @@
+ """Pathology FMs from other/mixed entities."""
+
+ from typing import Tuple
+
+ import timm
+ from torch import nn
+
+ from eva.vision.models.networks.backbones.registry import register_model
+
+
+ @register_model("pathology/prov_gigapath")
+ def prov_gigapath(
+     dynamic_img_size: bool = True,
+     out_indices: int | Tuple[int, ...] | None = None,
+ ) -> nn.Module:
+     """Initializes the Prov-GigaPath pathology FM.
+
+     Args:
+         dynamic_img_size: Whether to allow the position embeddings
+             to be interpolated at `forward()` time when the image grid
+             changes from the original.
+         out_indices: Whether and which multi-level patch embeddings to return.
+
+     Returns:
+         The model instance.
+     """
+     return timm.create_model(
+         model_name="hf_hub:prov-gigapath/prov-gigapath",
+         pretrained=True,
+         dynamic_img_size=dynamic_img_size,
+         out_indices=out_indices,
+         features_only=out_indices is not None,
+     )
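The same `@register_model` decorator shown in these backbone modules can register custom entries; a minimal sketch, assuming `BackboneModelRegistry` (defined in `registry.py`, not shown in this diff) later resolves the name:

```python
from torch import nn

from eva.vision.models.networks.backbones.registry import register_model


@register_model("pathology/my_toy_backbone")
def my_toy_backbone() -> nn.Module:
    """Hypothetical backbone registered under a custom name."""
    return nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
    )
```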