kaiko-eva 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (159)
  1. eva/core/callbacks/__init__.py +2 -2
  2. eva/core/callbacks/writers/__init__.py +6 -3
  3. eva/core/callbacks/writers/embeddings/__init__.py +6 -0
  4. eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
  5. eva/core/callbacks/writers/embeddings/base.py +192 -0
  6. eva/core/callbacks/writers/embeddings/classification.py +117 -0
  7. eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
  8. eva/core/callbacks/writers/embeddings/typings.py +38 -0
  9. eva/core/data/datasets/__init__.py +2 -2
  10. eva/core/data/datasets/classification/__init__.py +8 -0
  11. eva/core/data/datasets/classification/embeddings.py +34 -0
  12. eva/core/data/datasets/{embeddings/classification → classification}/multi_embeddings.py +13 -9
  13. eva/core/data/datasets/{embeddings/base.py → embeddings.py} +47 -32
  14. eva/core/data/splitting/__init__.py +6 -0
  15. eva/core/data/splitting/random.py +41 -0
  16. eva/core/data/splitting/stratified.py +56 -0
  17. eva/core/loggers/experimental_loggers.py +2 -2
  18. eva/core/loggers/log/__init__.py +3 -2
  19. eva/core/loggers/log/image.py +71 -0
  20. eva/core/loggers/log/parameters.py +10 -0
  21. eva/core/loggers/loggers.py +6 -0
  22. eva/core/metrics/__init__.py +6 -2
  23. eva/core/metrics/defaults/__init__.py +10 -3
  24. eva/core/metrics/defaults/classification/__init__.py +1 -1
  25. eva/core/metrics/defaults/classification/binary.py +0 -9
  26. eva/core/metrics/defaults/classification/multiclass.py +0 -8
  27. eva/core/metrics/defaults/segmentation/__init__.py +5 -0
  28. eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
  29. eva/core/metrics/generalized_dice.py +59 -0
  30. eva/core/metrics/mean_iou.py +120 -0
  31. eva/core/metrics/structs/schemas.py +3 -1
  32. eva/core/models/__init__.py +3 -1
  33. eva/core/models/modules/head.py +10 -4
  34. eva/core/models/modules/typings.py +14 -1
  35. eva/core/models/modules/utils/batch_postprocess.py +37 -5
  36. eva/core/models/networks/__init__.py +1 -2
  37. eva/core/models/networks/mlp.py +2 -2
  38. eva/core/models/transforms/__init__.py +6 -0
  39. eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
  40. eva/core/models/transforms/extract_patch_features.py +47 -0
  41. eva/core/models/wrappers/__init__.py +13 -0
  42. eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
  43. eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
  44. eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
  45. eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
  46. eva/core/trainers/functional.py +1 -0
  47. eva/core/utils/__init__.py +6 -0
  48. eva/core/utils/clone.py +27 -0
  49. eva/core/utils/memory.py +28 -0
  50. eva/core/utils/operations.py +26 -0
  51. eva/core/utils/parser.py +20 -0
  52. eva/vision/__init__.py +2 -2
  53. eva/vision/callbacks/__init__.py +5 -0
  54. eva/vision/callbacks/loggers/__init__.py +5 -0
  55. eva/vision/callbacks/loggers/batch/__init__.py +5 -0
  56. eva/vision/callbacks/loggers/batch/base.py +130 -0
  57. eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
  58. eva/vision/data/datasets/__init__.py +30 -3
  59. eva/vision/data/datasets/_validators.py +15 -2
  60. eva/vision/data/datasets/classification/__init__.py +12 -1
  61. eva/vision/data/datasets/classification/bach.py +10 -15
  62. eva/vision/data/datasets/classification/base.py +17 -24
  63. eva/vision/data/datasets/classification/camelyon16.py +244 -0
  64. eva/vision/data/datasets/classification/crc.py +10 -15
  65. eva/vision/data/datasets/classification/mhist.py +10 -15
  66. eva/vision/data/datasets/classification/panda.py +184 -0
  67. eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
  68. eva/vision/data/datasets/classification/wsi.py +105 -0
  69. eva/vision/data/datasets/segmentation/__init__.py +15 -2
  70. eva/vision/data/datasets/segmentation/_utils.py +38 -0
  71. eva/vision/data/datasets/segmentation/base.py +16 -17
  72. eva/vision/data/datasets/segmentation/bcss.py +236 -0
  73. eva/vision/data/datasets/segmentation/consep.py +156 -0
  74. eva/vision/data/datasets/segmentation/embeddings.py +34 -0
  75. eva/vision/data/datasets/segmentation/lits.py +178 -0
  76. eva/vision/data/datasets/segmentation/monusac.py +236 -0
  77. eva/vision/data/datasets/segmentation/{total_segmentator.py → total_segmentator_2d.py} +130 -36
  78. eva/vision/data/datasets/wsi.py +187 -0
  79. eva/vision/data/transforms/__init__.py +3 -2
  80. eva/vision/data/transforms/common/__init__.py +2 -1
  81. eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
  82. eva/vision/data/transforms/common/resize_and_crop.py +6 -7
  83. eva/vision/data/transforms/normalization/__init__.py +6 -0
  84. eva/vision/data/transforms/normalization/clamp.py +43 -0
  85. eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
  86. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
  87. eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
  88. eva/vision/data/wsi/__init__.py +16 -0
  89. eva/vision/data/wsi/backends/__init__.py +69 -0
  90. eva/vision/data/wsi/backends/base.py +115 -0
  91. eva/vision/data/wsi/backends/openslide.py +73 -0
  92. eva/vision/data/wsi/backends/pil.py +52 -0
  93. eva/vision/data/wsi/backends/tiffslide.py +42 -0
  94. eva/vision/data/wsi/patching/__init__.py +6 -0
  95. eva/vision/data/wsi/patching/coordinates.py +98 -0
  96. eva/vision/data/wsi/patching/mask.py +123 -0
  97. eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
  98. eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
  99. eva/vision/data/wsi/patching/samplers/base.py +48 -0
  100. eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
  101. eva/vision/data/wsi/patching/samplers/grid.py +47 -0
  102. eva/vision/data/wsi/patching/samplers/random.py +41 -0
  103. eva/vision/losses/__init__.py +5 -0
  104. eva/vision/losses/dice.py +40 -0
  105. eva/vision/models/__init__.py +4 -2
  106. eva/vision/models/modules/__init__.py +5 -0
  107. eva/vision/models/modules/semantic_segmentation.py +161 -0
  108. eva/vision/models/networks/__init__.py +1 -2
  109. eva/vision/models/networks/backbones/__init__.py +6 -0
  110. eva/vision/models/networks/backbones/_utils.py +39 -0
  111. eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
  112. eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
  113. eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
  114. eva/vision/models/networks/backbones/pathology/histai.py +46 -0
  115. eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
  116. eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
  117. eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
  118. eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
  119. eva/vision/models/networks/backbones/registry.py +47 -0
  120. eva/vision/models/networks/backbones/timm/__init__.py +5 -0
  121. eva/vision/models/networks/backbones/timm/backbones.py +54 -0
  122. eva/vision/models/networks/backbones/universal/__init__.py +8 -0
  123. eva/vision/models/networks/backbones/universal/vit.py +54 -0
  124. eva/vision/models/networks/decoders/__init__.py +6 -0
  125. eva/vision/models/networks/decoders/decoder.py +7 -0
  126. eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
  127. eva/vision/models/networks/decoders/segmentation/common.py +74 -0
  128. eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
  129. eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
  130. eva/vision/models/wrappers/__init__.py +6 -0
  131. eva/vision/models/wrappers/from_registry.py +48 -0
  132. eva/vision/models/wrappers/from_timm.py +68 -0
  133. eva/vision/utils/colormap.py +77 -0
  134. eva/vision/utils/convert.py +56 -13
  135. eva/vision/utils/io/__init__.py +10 -4
  136. eva/vision/utils/io/image.py +21 -2
  137. eva/vision/utils/io/mat.py +36 -0
  138. eva/vision/utils/io/nifti.py +33 -12
  139. eva/vision/utils/io/text.py +10 -3
  140. kaiko_eva-0.1.0.dist-info/METADATA +553 -0
  141. kaiko_eva-0.1.0.dist-info/RECORD +205 -0
  142. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
  143. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
  144. eva/.DS_Store +0 -0
  145. eva/core/callbacks/writers/embeddings.py +0 -169
  146. eva/core/callbacks/writers/typings.py +0 -23
  147. eva/core/data/datasets/embeddings/__init__.py +0 -13
  148. eva/core/data/datasets/embeddings/classification/__init__.py +0 -10
  149. eva/core/data/datasets/embeddings/classification/embeddings.py +0 -66
  150. eva/core/models/networks/transforms/__init__.py +0 -5
  151. eva/core/models/networks/wrappers/__init__.py +0 -8
  152. eva/vision/models/.DS_Store +0 -0
  153. eva/vision/models/networks/.DS_Store +0 -0
  154. eva/vision/models/networks/postprocesses/__init__.py +0 -5
  155. eva/vision/models/networks/postprocesses/cls.py +0 -25
  156. kaiko_eva-0.0.2.dist-info/METADATA +0 -431
  157. kaiko_eva-0.0.2.dist-info/RECORD +0 -127
  158. /eva/core/models/{networks → wrappers}/_utils.py +0 -0
  159. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0

eva/vision/models/networks/decoders/segmentation/conv2d.py
@@ -0,0 +1,114 @@
+ """Convolutional based semantic segmentation decoder."""
+
+ from typing import List, Tuple
+
+ import torch
+ from torch import nn
+ from torch.nn import functional
+
+ from eva.vision.models.networks.decoders import decoder
+
+
+ class ConvDecoder(decoder.Decoder):
+     """Convolutional segmentation decoder."""
+
+     def __init__(self, layers: nn.Module) -> None:
+         """Initializes the convolutional based decoder head.
+
+         Here the input nn layers will be directly applied to the
+         features of shape (batch_size, hidden_size, n_patches_height,
+         n_patches_width), where n_patches is image_size / patch_size.
+         Note that n_patches is also known as grid_size.
+
+         Args:
+             layers: The convolutional layers to be used as the decoder head.
+         """
+         super().__init__()
+
+         self._layers = layers
+
+     def _forward_features(self, features: List[torch.Tensor]) -> torch.Tensor:
+         """Merges multi-level feature maps into a single one.
+
+         It interpolates the features and concatenates them into a single
+         tensor along the hidden size dimension.
+
+         Example:
+             >>> features = [torch.Tensor(16, 384, 14, 14), torch.Tensor(16, 384, 14, 14)]
+             >>> output = self._forward_features(features)
+             >>> assert output.shape == torch.Size([16, 768, 14, 14])
+
+         Args:
+             features: List of multi-level image features of shape (batch_size,
+                 hidden_size, n_patches_height, n_patches_width).
+
+         Returns:
+             A tensor of shape (batch_size, hidden_size, n_patches_height,
+             n_patches_width) which is the feature map of the decoder head.
+         """
+         if not isinstance(features, list) or features[0].ndim != 4:
+             raise ValueError(
+                 "Input features should be a list of four (4) dimensional inputs of "
+                 "shape (batch_size, hidden_size, n_patches_height, n_patches_width)."
+             )
+
+         upsampled_features = [
+             functional.interpolate(
+                 input=embeddings,
+                 size=features[0].shape[2:],
+                 mode="bilinear",
+                 align_corners=False,
+             )
+             for embeddings in features
+         ]
+         return torch.cat(upsampled_features, dim=1)
+
+     def _forward_head(self, patch_embeddings: torch.Tensor) -> torch.Tensor:
+         """Forward of the decoder head.
+
+         Args:
+             patch_embeddings: The patch embeddings tensor of shape
+                 (batch_size, hidden_size, n_patches_height, n_patches_width).
+
+         Returns:
+             The logits as a tensor (batch_size, n_classes, upscale_height, upscale_width).
+         """
+         return self._layers(patch_embeddings)
+
+     def _cls_seg(
+         self,
+         logits: torch.Tensor,
+         image_size: Tuple[int, int],
+     ) -> torch.Tensor:
+         """Classify each pixel of the image.
+
+         Args:
+             logits: The decoder outputs of shape (batch_size, n_classes,
+                 height, width).
+             image_size: The target image size (height, width).
+
+         Returns:
+             Tensor containing scores for all of the classes with shape
+             (batch_size, n_classes, image_height, image_width).
+         """
+         return functional.interpolate(logits, image_size, mode="bilinear")
+
+     def forward(
+         self,
+         features: List[torch.Tensor],
+         image_size: Tuple[int, int],
+     ) -> torch.Tensor:
+         """Maps the patch embeddings to a segmentation mask of the image size.
+
+         Args:
+             features: List of multi-level image features of shape (batch_size,
+                 hidden_size, n_patches_height, n_patches_width).
+             image_size: The target image size (height, width).
+
+         Returns:
+             Tensor containing scores for all of the classes with shape
+             (batch_size, n_classes, image_height, image_width).
+         """
+         patch_embeddings = self._forward_features(features)
+         logits = self._forward_head(patch_embeddings)
+         return self._cls_seg(logits, image_size)
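
For orientation, a minimal usage sketch of the ConvDecoder added above; the backbone, shapes, and layer choice are illustrative assumptions, not part of the diff:

import torch
from torch import nn

from eva.vision.models.networks.decoders.segmentation.conv2d import ConvDecoder

# One feature level from a hypothetical ViT backbone:
# (batch_size=2, hidden_size=384, n_patches_height=14, n_patches_width=14).
features = [torch.rand(2, 384, 14, 14)]

# A 1x1 convolution mapping the hidden size to 5 segmentation classes.
decoder = ConvDecoder(layers=nn.Conv2d(384, 5, kernel_size=1))

# forward() concatenates the feature levels, applies the conv head on the
# patch grid, and bilinearly upsamples the logits to the image size.
masks = decoder(features, image_size=(224, 224))
assert masks.shape == (2, 5, 224, 224)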

eva/vision/models/networks/decoders/segmentation/linear.py
@@ -0,0 +1,125 @@
+ """Linear based decoder."""
+
+ from typing import List, Tuple
+
+ import torch
+ from torch import nn
+ from torch.nn import functional
+
+ from eva.vision.models.networks.decoders import decoder
+
+
+ class LinearDecoder(decoder.Decoder):
+     """Linear decoder."""
+
+     def __init__(self, layers: nn.Module) -> None:
+         """Initializes the linear based decoder head.
+
+         Here the input nn layers will be applied to the reshaped
+         features (batch_size, n_patches, hidden_size) from
+         the input (batch_size, hidden_size, height, width) and then
+         unwrapped again to (batch_size, n_classes, height, width).
+
+         Args:
+             layers: The linear layers to be used as the decoder head.
+         """
+         super().__init__()
+
+         self._layers = layers
+
+     def _forward_features(self, features: List[torch.Tensor]) -> torch.Tensor:
+         """Merges multi-level feature maps into a single one.
+
+         It interpolates the features and concatenates them into a single
+         tensor along the hidden size dimension.
+
+         Example:
+             >>> features = [torch.Tensor(16, 384, 14, 14), torch.Tensor(16, 384, 14, 14)]
+             >>> output = self._forward_features(features)
+             >>> assert output.shape == torch.Size([16, 768, 14, 14])
+
+         Args:
+             features: List of multi-level image features of shape (batch_size,
+                 hidden_size, n_patches_height, n_patches_width).
+
+         Returns:
+             A tensor of shape (batch_size, hidden_size, n_patches_height,
+             n_patches_width) which is the feature map of the decoder head.
+         """
+         if not isinstance(features, list) or features[0].ndim != 4:
+             raise ValueError(
+                 "Input features should be a list of four (4) dimensional inputs of "
+                 "shape (batch_size, hidden_size, n_patches_height, n_patches_width)."
+             )
+
+         upsampled_features = [
+             functional.interpolate(
+                 input=embeddings,
+                 size=features[0].shape[2:],
+                 mode="bilinear",
+                 align_corners=False,
+             )
+             for embeddings in features
+         ]
+         return torch.cat(upsampled_features, dim=1)
+
+     def _forward_head(self, patch_embeddings: torch.Tensor) -> torch.Tensor:
+         """Forward of the decoder head.
+
+         Here the following transformations take place:
+         - (batch_size, hidden_size, n_patches_height, n_patches_width)
+         - (batch_size, hidden_size, n_patches_height * n_patches_width)
+         - (batch_size, n_patches_height * n_patches_width, hidden_size)
+         - (batch_size, n_patches_height * n_patches_width, n_classes)
+         - (batch_size, n_classes, n_patches_height, n_patches_width)
+
+         Args:
+             patch_embeddings: The patch embeddings tensor of shape
+                 (batch_size, hidden_size, n_patches_height, n_patches_width).
+
+         Returns:
+             The logits as a tensor (batch_size, n_classes, n_patches_height,
+             n_patches_width).
+         """
+         batch_size, hidden_size, height, width = patch_embeddings.shape
+         embeddings_reshaped = patch_embeddings.reshape(batch_size, hidden_size, height * width)
+         logits = self._layers(embeddings_reshaped.permute(0, 2, 1))
+         return logits.permute(0, 2, 1).reshape(batch_size, -1, height, width)
+
+     def _cls_seg(
+         self,
+         logits: torch.Tensor,
+         image_size: Tuple[int, int],
+     ) -> torch.Tensor:
+         """Classify each pixel of the image.
+
+         Args:
+             logits: The decoder outputs of shape (batch_size, n_classes,
+                 height, width).
+             image_size: The target image size (height, width).
+
+         Returns:
+             Tensor containing scores for all of the classes with shape
+             (batch_size, n_classes, image_height, image_width).
+         """
+         return functional.interpolate(logits, image_size, mode="bilinear")
+
+     def forward(
+         self,
+         features: List[torch.Tensor],
+         image_size: Tuple[int, int],
+     ) -> torch.Tensor:
+         """Maps the patch embeddings to a segmentation mask of the image size.
+
+         Args:
+             features: List of multi-level image features of shape (batch_size,
+                 hidden_size, n_patches_height, n_patches_width).
+             image_size: The target image size (height, width).
+
+         Returns:
+             Tensor containing scores for all of the classes with shape
+             (batch_size, n_classes, image_height, image_width).
+         """
+         patch_embeddings = self._forward_features(features)
+         logits = self._forward_head(patch_embeddings)
+         return self._cls_seg(logits, image_size)
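
Similarly, a sketch for the LinearDecoder, under the same illustrative assumptions as the ConvDecoder example above:

import torch
from torch import nn

from eva.vision.models.networks.decoders.segmentation.linear import LinearDecoder

# (batch_size=2, hidden_size=384, n_patches_height=14, n_patches_width=14)
features = [torch.rand(2, 384, 14, 14)]

# A per-patch linear head: hidden_size -> n_classes.
decoder = LinearDecoder(layers=nn.Linear(in_features=384, out_features=5))

# The head runs on the flattened patch tokens, is reshaped back to the
# patch grid, and the logits are then upsampled to the target image size.
masks = decoder(features, image_size=(224, 224))
assert masks.shape == (2, 5, 224, 224)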

eva/vision/models/wrappers/__init__.py
@@ -0,0 +1,6 @@
+ """Vision Model Wrappers API."""
+
+ from eva.vision.models.wrappers.from_registry import ModelFromRegistry
+ from eva.vision.models.wrappers.from_timm import TimmModel
+
+ __all__ = ["TimmModel", "ModelFromRegistry"]

eva/vision/models/wrappers/from_registry.py
@@ -0,0 +1,48 @@
+ """Vision backbone helper class."""
+
+ from typing import Any, Callable, Dict
+
+ from typing_extensions import override
+
+ from eva.core.models import wrappers
+ from eva.vision.models.networks.backbones import BackboneModelRegistry
+
+
+ class ModelFromRegistry(wrappers.BaseModel):
+     """Wrapper class for vision backbone models.
+
+     This class can be used to load backbones available in eva's
+     model registry by name. New backbones can be registered by using
+     the `@register_model(model_name)` decorator.
+     """
+
+     def __init__(
+         self,
+         model_name: str,
+         model_kwargs: Dict[str, Any] | None = None,
+         model_extra_kwargs: Dict[str, Any] | None = None,
+         tensor_transforms: Callable | None = None,
+     ) -> None:
+         """Initializes the model.
+
+         Args:
+             model_name: The name of the model to load.
+             model_kwargs: The arguments used for instantiating the model.
+             model_extra_kwargs: Extra arguments used for instantiating the model.
+             tensor_transforms: The transforms to apply to the output tensor
+                 produced by the model.
+         """
+         super().__init__(tensor_transforms=tensor_transforms)
+
+         self._model_name = model_name
+         self._model_kwargs = model_kwargs or {}
+         self._model_extra_kwargs = model_extra_kwargs or {}
+
+         self.load_model()
+
+     @override
+     def load_model(self) -> None:
+         self._model = BackboneModelRegistry.load_model(
+             self._model_name, self._model_kwargs | self._model_extra_kwargs
+         )
+         ModelFromRegistry.__name__ = self._model_name
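
A usage sketch for ModelFromRegistry; the registry name below is hypothetical, as valid names are defined by the @register_model entries under eva/vision/models/networks/backbones/:

import torch

from eva.vision.models.wrappers import ModelFromRegistry

# "universal/vit_small_patch16_224_dino" is an assumed example name.
backbone = ModelFromRegistry(model_name="universal/vit_small_patch16_224_dino")

# Assuming the wrapper forwards like a regular nn.Module.
embeddings = backbone(torch.rand(1, 3, 224, 224))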

eva/vision/models/wrappers/from_timm.py
@@ -0,0 +1,68 @@
+ """Model wrapper for timm models."""
+
+ from typing import Any, Callable, Dict, Tuple
+ from urllib import parse
+
+ import timm
+ from typing_extensions import override
+
+ from eva.core.models import wrappers
+
+
+ class TimmModel(wrappers.BaseModel):
+     """Model wrapper for `timm` models.
+
+     Note that only models with a `forward_intermediates`
+     method are currently supported.
+     """
+
+     def __init__(
+         self,
+         model_name: str,
+         pretrained: bool = True,
+         checkpoint_path: str = "",
+         out_indices: int | Tuple[int, ...] | None = None,
+         model_kwargs: Dict[str, Any] | None = None,
+         tensor_transforms: Callable | None = None,
+     ) -> None:
+         """Initializes the encoder.
+
+         Args:
+             model_name: Name of model to instantiate.
+             pretrained: If set to `True`, load pretrained ImageNet-1k weights.
+             checkpoint_path: Path of checkpoint to load.
+             out_indices: Returns last n blocks if `int`, all if `None`, select
+                 matching indices if sequence.
+             model_kwargs: Extra model arguments.
+             tensor_transforms: The transforms to apply to the output tensor
+                 produced by the model.
+         """
+         super().__init__(tensor_transforms=tensor_transforms)
+
+         self._model_name = model_name
+         self._pretrained = pretrained
+         self._checkpoint_path = checkpoint_path
+         self._out_indices = out_indices
+         self._model_kwargs = model_kwargs or {}
+
+         self.load_model()
+
+     @override
+     def load_model(self) -> None:
+         """Builds and loads the timm model as feature extractor."""
+         self._model = timm.create_model(
+             model_name=self._model_name,
+             pretrained=True if self._checkpoint_path else self._pretrained,
+             pretrained_cfg=self._pretrained_cfg,
+             out_indices=self._out_indices,
+             features_only=self._out_indices is not None,
+             **self._model_kwargs,
+         )
+         TimmModel.__name__ = self._model_name
+
+     @property
+     def _pretrained_cfg(self) -> Dict[str, Any]:
+         if not self._checkpoint_path:
+             return {}
+         key = "file" if parse.urlparse(self._checkpoint_path).scheme in ("file", "") else "url"
+         return {key: self._checkpoint_path, "num_classes": 0}
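
A sketch of the TimmModel wrapper; whether a given timm architecture supports features_only with out_indices depends on the installed timm version exposing forward_intermediates, so treat the model name and flags as illustrative:

import torch

from eva.vision.models.wrappers import TimmModel

# pretrained=False avoids a weight download in this sketch; out_indices=1
# requests timm's features_only mode for the last block's feature map.
encoder = TimmModel(model_name="vit_small_patch16_224", pretrained=False, out_indices=1)

# In features_only mode, timm models return a list of feature maps.
features = encoder(torch.rand(1, 3, 224, 224))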

eva/vision/utils/colormap.py
@@ -0,0 +1,77 @@
+ """Color mapping constants."""
+
+ COLORS = [
+     (0, 0, 0),
+     (255, 0, 0),  # Red
+     (0, 255, 0),  # Green
+     (0, 0, 255),  # Blue
+     (255, 255, 0),  # Yellow
+     (255, 0, 255),  # Magenta
+     (0, 255, 255),  # Cyan
+     (128, 128, 0),  # Olive
+     (128, 0, 128),  # Purple
+     (0, 128, 128),  # Teal
+     (192, 192, 192),  # Silver
+     (128, 128, 128),  # Gray
+     (255, 165, 0),  # Orange
+     (210, 105, 30),  # Chocolate
+     (0, 128, 0),  # Lime
+     (255, 192, 203),  # Pink
+     (255, 69, 0),  # Red-Orange
+     (255, 140, 0),  # Dark Orange
+     (0, 255, 255),  # Sky Blue
+     (0, 255, 127),  # Spring Green
+     (0, 0, 139),  # Dark Blue
+     (255, 20, 147),  # Deep Pink
+     (139, 69, 19),  # Saddle Brown
+     (0, 100, 0),  # Dark Green
+     (106, 90, 205),  # Slate Blue
+     (138, 43, 226),  # Blue-Violet
+     (218, 165, 32),  # Goldenrod
+     (199, 21, 133),  # Medium Violet Red
+     (70, 130, 180),  # Steel Blue
+     (165, 42, 42),  # Brown
+     (128, 0, 0),  # Maroon
+     (255, 0, 255),  # Fuchsia
+     (210, 180, 140),  # Tan
+     (0, 0, 128),  # Navy
+     (139, 0, 139),  # Dark Magenta
+     (144, 238, 144),  # Light Green
+     (46, 139, 87),  # Sea Green
+     (255, 255, 0),  # Gold
+     (154, 205, 50),  # Yellow Green
+     (0, 191, 255),  # Deep Sky Blue
+     (0, 250, 154),  # Medium Spring Green
+     (250, 128, 114),  # Salmon
+     (255, 105, 180),  # Hot Pink
+     (204, 255, 204),  # Pastel Light Green
+     (51, 0, 51),  # Very Dark Magenta
+     (255, 102, 0),  # Dark Orange
+     (0, 255, 0),  # Bright Green
+     (51, 153, 255),  # Blue-Purple
+     (51, 51, 255),  # Bright Blue
+     (204, 0, 0),  # Dark Red
+     (90, 90, 90),  # Very Dark Gray
+     (255, 255, 51),  # Pastel Yellow
+     (255, 153, 255),  # Pink-Magenta
+     (153, 0, 76),  # Dark Pink
+     (51, 25, 0),  # Very Dark Brown
+     (102, 51, 0),  # Dark Brown
+     (0, 0, 51),  # Very Dark Blue
+     (180, 180, 180),  # Dark Gray
+     (102, 255, 204),  # Pastel Green
+     (0, 102, 0),  # Dark Green
+     (220, 245, 20),  # Lime Yellow
+     (255, 204, 204),  # Pastel Pink
+     (0, 204, 255),  # Pastel Blue
+     (240, 240, 240),  # Light Gray
+     (153, 153, 0),  # Dark Yellow
+     (102, 0, 51),  # Dark Red-Pink
+     (0, 51, 0),  # Very Dark Green
+     (255, 102, 204),  # Magenta Pink
+     (204, 0, 102),  # Red-Pink
+ ]
+ """RGB colors."""
+
+ COLORMAP = dict(enumerate(COLORS)) | {255: (255, 255, 255)}
+ """Class id to RGB color mapping."""

eva/vision/utils/convert.py
@@ -1,24 +1,67 @@
  """Image conversion related functionalities."""

- from typing import Any
+ from typing import Iterable

- import numpy as np
- import numpy.typing as npt
+ import torch
+ from torchvision.transforms.v2 import functional


- def to_8bit(image_array: npt.NDArray[Any]) -> npt.NDArray[np.uint8]:
-     """Casts an image of higher bit image (i.e. 16bit) to 8bit.
+ def descale_and_denorm_image(
+     image: torch.Tensor,
+     mean: Iterable[float] = (0.0, 0.0, 0.0),
+     std: Iterable[float] = (1.0, 1.0, 1.0),
+     inplace: bool = True,
+ ) -> torch.Tensor:
+     """De-scales and de-norms an image tensor to the (0, 255) range.

      Args:
-         image_array: The image array to convert.
+         image: An image float tensor.
+         mean: The mean that the image channels are normalized with.
+         std: The std that the image channels are normalized with.
+         inplace: Whether to perform the operation in-place.

      Returns:
-         The image as normalized as a 8-bit format.
+         The image tensor in the (0, 255) range as uint8.
      """
-     if np.issubdtype(image_array.dtype, np.integer):
-         image_array = image_array.astype(np.float64)
+     if not inplace:
+         image = image.clone()

-     image_scaled_array = image_array - image_array.min()
-     image_scaled_array /= image_scaled_array.max()
-     image_scaled_array *= 255
-     return image_scaled_array.astype(np.uint8)
+     norm_image = _descale_image(image, mean=mean, std=std)
+     return _denorm_image(norm_image)
+
+
+ def _descale_image(
+     image: torch.Tensor,
+     mean: Iterable[float] = (0.0, 0.0, 0.0),
+     std: Iterable[float] = (1.0, 1.0, 1.0),
+ ) -> torch.Tensor:
+     """De-scales an image tensor to the (0., 1.) range.
+
+     Args:
+         image: An image float tensor.
+         mean: The normalized channels mean values.
+         std: The normalized channels std values.
+
+     Returns:
+         The de-normalized image tensor in the (0., 1.) range.
+     """
+     return functional.normalize(
+         image,
+         mean=[-cmean / cstd for cmean, cstd in zip(mean, std, strict=False)],
+         std=[1 / cstd for cstd in std],
+     )
+
+
+ def _denorm_image(image: torch.Tensor) -> torch.Tensor:
+     """De-normalizes an image tensor from (0., 1.) to the (0, 255) range.
+
+     Args:
+         image: An image float tensor.
+
+     Returns:
+         The image tensor in the (0, 255) range as uint8.
+     """
+     image_scaled = image - image.min()
+     image_scaled /= image_scaled.max()
+     image_scaled *= 255
+     return image_scaled.to(dtype=torch.uint8)
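
A round-trip sketch of descale_and_denorm_image, assuming an image normalized with the usual ImageNet statistics (any training-time mean/std pair applies):

import torch

from eva.vision.utils.convert import descale_and_denorm_image

mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

# Normalize a random (0, 1) image the way a training pipeline might.
image = torch.rand(3, 224, 224)
normalized = (image - torch.tensor(mean).view(3, 1, 1)) / torch.tensor(std).view(3, 1, 1)

# Undo the normalization, then min-max scale back to uint8 (0, 255).
restored = descale_and_denorm_image(normalized, mean=mean, std=std, inplace=False)
assert restored.dtype == torch.uint8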

eva/vision/utils/io/__init__.py
@@ -1,12 +1,18 @@
  """Vision I/O utilities."""

- from eva.vision.utils.io.image import read_image
- from eva.vision.utils.io.nifti import fetch_total_nifti_slices, read_nifti_slice
+ from eva.vision.utils.io.image import read_image, read_image_as_array, read_image_as_tensor
+ from eva.vision.utils.io.mat import read_mat, save_mat
+ from eva.vision.utils.io.nifti import fetch_nifti_shape, read_nifti, save_array_as_nifti
  from eva.vision.utils.io.text import read_csv

  __all__ = [
      "read_image",
-     "fetch_total_nifti_slices",
-     "read_nifti_slice",
+     "read_image_as_array",
+     "read_image_as_tensor",
+     "fetch_nifti_shape",
+     "read_nifti",
+     "save_array_as_nifti",
      "read_csv",
+     "read_mat",
+     "save_mat",
  ]

eva/vision/utils/io/image.py
@@ -3,6 +3,8 @@
  import cv2
  import numpy as np
  import numpy.typing as npt
+ from torchvision import tv_tensors
+ from torchvision.transforms.v2 import functional

  from eva.vision.utils.io import _utils

@@ -14,7 +16,7 @@ def read_image(path: str) -> npt.NDArray[np.uint8]:
          path: The path of the image file.

      Returns:
-         The RGB image as a numpy array.
+         The RGB image as a numpy array (HxWxC).

      Raises:
          FileExistsError: If the path does not exist or it is unreachable.
@@ -23,6 +25,23 @@ def read_image(path: str) -> npt.NDArray[np.uint8]:
      return read_image_as_array(path, cv2.IMREAD_COLOR)


+ def read_image_as_tensor(path: str) -> tv_tensors.Image:
+     """Reads and loads the image from a file path as an RGB torch tensor.
+
+     Args:
+         path: The path of the image file.
+
+     Returns:
+         The RGB image as a torch tensor (CxHxW).
+
+     Raises:
+         FileExistsError: If the path does not exist or it is unreachable.
+         IOError: If the image could not be loaded.
+     """
+     image_array = read_image(path)
+     return functional.to_image(image_array)
+
+
  def read_image_as_array(path: str, flags: int = cv2.IMREAD_UNCHANGED) -> npt.NDArray[np.uint8]:
      """Reads and loads an image file as a numpy array.

@@ -51,4 +70,4 @@ def read_image_as_array(path: str, flags: int = cv2.IMREAD_UNCHANGED) -> npt.NDA
      if image.ndim == 2 and flags == cv2.IMREAD_COLOR:
          image = image[:, :, np.newaxis]

-     return np.asarray(image).astype(np.uint8)
+     return np.asarray(image, dtype=np.uint8)
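
The new reader in context; the file path below is a placeholder:

from eva.vision.utils.io import read_image, read_image_as_tensor

array = read_image("patch.png")             # (H, W, C) uint8 numpy array
tensor = read_image_as_tensor("patch.png")  # (C, H, W) uint8 torch tensor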

eva/vision/utils/io/mat.py
@@ -0,0 +1,36 @@
+ """mat I/O related functions."""
+
+ import os
+ from typing import Any, Dict
+
+ import numpy.typing as npt
+ import scipy.io
+
+ from eva.vision.utils.io import _utils
+
+
+ def read_mat(path: str) -> Dict[str, npt.NDArray[Any]]:
+     """Reads and loads a mat file.
+
+     Args:
+         path: The path to the mat file.
+
+     Returns:
+         The mat file contents as a dictionary.
+
+     Raises:
+         FileExistsError: If the path does not exist or it is unreachable.
+     """
+     _utils.check_file(path)
+     return scipy.io.loadmat(path)
+
+
+ def save_mat(path: str, data: Dict[str, npt.NDArray[Any]]) -> None:
+     """Saves a mat file.
+
+     Args:
+         path: The path to save the mat file.
+         data: The dictionary containing the data to save.
+     """
+     os.makedirs(os.path.dirname(path), exist_ok=True)
+     scipy.io.savemat(path, data)
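
Finally, a round-trip sketch of the new mat helpers ("/tmp/example/data.mat" is a placeholder path):

import numpy as np

from eva.vision.utils.io import read_mat, save_mat

# save_mat creates the parent directory before writing.
save_mat("/tmp/example/data.mat", {"mask": np.zeros((16, 16), dtype=np.uint8)})

# scipy's loadmat also returns metadata keys such as "__header__".
contents = read_mat("/tmp/example/data.mat")
mask = contents["mask"]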