kaiko-eva 0.0.0.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kaiko-eva might be problematic.

Files changed (111)
  1. eva/.DS_Store +0 -0
  2. eva/__init__.py +33 -0
  3. eva/__main__.py +18 -0
  4. eva/__version__.py +25 -0
  5. eva/core/__init__.py +19 -0
  6. eva/core/callbacks/__init__.py +5 -0
  7. eva/core/callbacks/writers/__init__.py +5 -0
  8. eva/core/callbacks/writers/embeddings.py +169 -0
  9. eva/core/callbacks/writers/typings.py +23 -0
  10. eva/core/cli/__init__.py +5 -0
  11. eva/core/cli/cli.py +19 -0
  12. eva/core/cli/logo.py +38 -0
  13. eva/core/cli/setup.py +89 -0
  14. eva/core/data/__init__.py +14 -0
  15. eva/core/data/dataloaders/__init__.py +5 -0
  16. eva/core/data/dataloaders/dataloader.py +80 -0
  17. eva/core/data/datamodules/__init__.py +6 -0
  18. eva/core/data/datamodules/call.py +33 -0
  19. eva/core/data/datamodules/datamodule.py +108 -0
  20. eva/core/data/datamodules/schemas.py +62 -0
  21. eva/core/data/datasets/__init__.py +7 -0
  22. eva/core/data/datasets/base.py +53 -0
  23. eva/core/data/datasets/classification/__init__.py +5 -0
  24. eva/core/data/datasets/classification/embeddings.py +154 -0
  25. eva/core/data/datasets/dataset.py +6 -0
  26. eva/core/data/samplers/__init__.py +5 -0
  27. eva/core/data/samplers/sampler.py +6 -0
  28. eva/core/data/transforms/__init__.py +5 -0
  29. eva/core/data/transforms/dtype/__init__.py +5 -0
  30. eva/core/data/transforms/dtype/array.py +28 -0
  31. eva/core/interface/__init__.py +5 -0
  32. eva/core/interface/interface.py +79 -0
  33. eva/core/metrics/__init__.py +17 -0
  34. eva/core/metrics/average_loss.py +47 -0
  35. eva/core/metrics/binary_balanced_accuracy.py +22 -0
  36. eva/core/metrics/defaults/__init__.py +6 -0
  37. eva/core/metrics/defaults/classification/__init__.py +6 -0
  38. eva/core/metrics/defaults/classification/binary.py +76 -0
  39. eva/core/metrics/defaults/classification/multiclass.py +80 -0
  40. eva/core/metrics/structs/__init__.py +9 -0
  41. eva/core/metrics/structs/collection.py +6 -0
  42. eva/core/metrics/structs/metric.py +6 -0
  43. eva/core/metrics/structs/module.py +115 -0
  44. eva/core/metrics/structs/schemas.py +47 -0
  45. eva/core/metrics/structs/typings.py +15 -0
  46. eva/core/models/__init__.py +13 -0
  47. eva/core/models/modules/__init__.py +7 -0
  48. eva/core/models/modules/head.py +113 -0
  49. eva/core/models/modules/inference.py +37 -0
  50. eva/core/models/modules/module.py +190 -0
  51. eva/core/models/modules/typings.py +23 -0
  52. eva/core/models/modules/utils/__init__.py +6 -0
  53. eva/core/models/modules/utils/batch_postprocess.py +57 -0
  54. eva/core/models/modules/utils/grad.py +23 -0
  55. eva/core/models/networks/__init__.py +6 -0
  56. eva/core/models/networks/_utils.py +25 -0
  57. eva/core/models/networks/mlp.py +69 -0
  58. eva/core/models/networks/transforms/__init__.py +5 -0
  59. eva/core/models/networks/transforms/extract_cls_features.py +25 -0
  60. eva/core/models/networks/wrappers/__init__.py +8 -0
  61. eva/core/models/networks/wrappers/base.py +47 -0
  62. eva/core/models/networks/wrappers/from_function.py +58 -0
  63. eva/core/models/networks/wrappers/huggingface.py +37 -0
  64. eva/core/models/networks/wrappers/onnx.py +47 -0
  65. eva/core/trainers/__init__.py +6 -0
  66. eva/core/trainers/_logging.py +81 -0
  67. eva/core/trainers/_recorder.py +149 -0
  68. eva/core/trainers/_utils.py +12 -0
  69. eva/core/trainers/functional.py +113 -0
  70. eva/core/trainers/trainer.py +97 -0
  71. eva/core/utils/__init__.py +1 -0
  72. eva/core/utils/io/__init__.py +5 -0
  73. eva/core/utils/io/dataframe.py +21 -0
  74. eva/core/utils/multiprocessing.py +44 -0
  75. eva/core/utils/workers.py +21 -0
  76. eva/vision/__init__.py +14 -0
  77. eva/vision/data/__init__.py +5 -0
  78. eva/vision/data/datasets/__init__.py +22 -0
  79. eva/vision/data/datasets/_utils.py +50 -0
  80. eva/vision/data/datasets/_validators.py +44 -0
  81. eva/vision/data/datasets/classification/__init__.py +15 -0
  82. eva/vision/data/datasets/classification/bach.py +174 -0
  83. eva/vision/data/datasets/classification/base.py +103 -0
  84. eva/vision/data/datasets/classification/crc.py +176 -0
  85. eva/vision/data/datasets/classification/mhist.py +106 -0
  86. eva/vision/data/datasets/classification/patch_camelyon.py +203 -0
  87. eva/vision/data/datasets/classification/total_segmentator.py +212 -0
  88. eva/vision/data/datasets/segmentation/__init__.py +6 -0
  89. eva/vision/data/datasets/segmentation/base.py +112 -0
  90. eva/vision/data/datasets/segmentation/total_segmentator.py +212 -0
  91. eva/vision/data/datasets/structs.py +17 -0
  92. eva/vision/data/datasets/vision.py +43 -0
  93. eva/vision/data/transforms/__init__.py +5 -0
  94. eva/vision/data/transforms/common/__init__.py +5 -0
  95. eva/vision/data/transforms/common/resize_and_crop.py +44 -0
  96. eva/vision/models/__init__.py +5 -0
  97. eva/vision/models/networks/__init__.py +6 -0
  98. eva/vision/models/networks/abmil.py +176 -0
  99. eva/vision/models/networks/postprocesses/__init__.py +5 -0
  100. eva/vision/models/networks/postprocesses/cls.py +25 -0
  101. eva/vision/utils/__init__.py +5 -0
  102. eva/vision/utils/io/__init__.py +12 -0
  103. eva/vision/utils/io/_utils.py +29 -0
  104. eva/vision/utils/io/image.py +54 -0
  105. eva/vision/utils/io/nifti.py +50 -0
  106. eva/vision/utils/io/text.py +18 -0
  107. kaiko_eva-0.0.0.dev6.dist-info/METADATA +393 -0
  108. kaiko_eva-0.0.0.dev6.dist-info/RECORD +111 -0
  109. kaiko_eva-0.0.0.dev6.dist-info/WHEEL +4 -0
  110. kaiko_eva-0.0.0.dev6.dist-info/entry_points.txt +4 -0
  111. kaiko_eva-0.0.0.dev6.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,212 @@
+ """TotalSegmentator 2D segmentation dataset class."""
+
+ import functools
+ import os
+ from glob import glob
+ from typing import Callable, Dict, List, Literal, Tuple
+
+ import numpy as np
+ from torchvision.datasets import utils
+ from typing_extensions import override
+
+ from eva.vision.data.datasets import _utils, _validators, structs
+ from eva.vision.data.datasets.segmentation import base
+ from eva.vision.utils import io
+
+
+ class TotalSegmentator2D(base.ImageSegmentation):
+     """TotalSegmentator 2D segmentation dataset."""
+
+     _train_index_ranges: List[Tuple[int, int]] = [(0, 83)]
+     """Train range indices."""
+
+     _val_index_ranges: List[Tuple[int, int]] = [(83, 103)]
+     """Validation range indices."""
+
+     _n_slices_per_image: int = 20
+     """The number of slices to sample per 3D CT scan image."""
+
+     _resources_full: List[structs.DownloadResource] = [
+         structs.DownloadResource(
+             filename="Totalsegmentator_dataset_v201.zip",
+             url="https://zenodo.org/records/10047292/files/Totalsegmentator_dataset_v201.zip",
+             md5="fe250e5718e0a3b5df4c4ea9d58a62fe",
+         ),
+     ]
+     """Resources for the full dataset version."""
+
+     _resources_small: List[structs.DownloadResource] = [
+         structs.DownloadResource(
+             filename="Totalsegmentator_dataset_small_v201.zip",
+             url="https://zenodo.org/records/10047263/files/Totalsegmentator_dataset_small_v201.zip",
+             md5="6b5524af4b15e6ba06ef2d700c0c73e0",
+         ),
+     ]
+     """Resources for the small dataset version."""
+
+     def __init__(
+         self,
+         root: str,
+         split: Literal["train", "val"] | None,
+         version: Literal["small", "full"] = "small",
+         download: bool = False,
+         image_transforms: Callable | None = None,
+         target_transforms: Callable | None = None,
+         image_target_transforms: Callable | None = None,
+     ) -> None:
+         """Initialize dataset.
+
+         Args:
+             root: Path to the root directory of the dataset. The dataset will
+                 be downloaded and extracted here, if it does not already exist.
+             split: Dataset split to use. If `None`, the entire dataset is used.
+             version: The version of the dataset to initialize.
+             download: Whether to download the data for the specified split.
+                 Note that the download will be executed only by additionally
+                 calling the :meth:`prepare_data` method and if the data does not
+                 exist yet on disk.
+             image_transforms: A function/transform that takes in an image
+                 and returns a transformed version.
+             target_transforms: A function/transform that takes in the target
+                 and transforms it.
+             image_target_transforms: A function/transform that takes in an
+                 image and a label and returns the transformed versions of both.
+                 This transform happens after the `image_transforms` and
+                 `target_transforms`.
+         """
+         super().__init__(
+             image_transforms=image_transforms,
+             target_transforms=target_transforms,
+             image_target_transforms=image_target_transforms,
+         )
+
+         self._root = root
+         self._split = split
+         self._version = version
+         self._download = download
+
+         self._samples_dirs: List[str] = []
+         self._indices: List[int] = []
+
+     @functools.cached_property
+     @override
+     def classes(self) -> List[str]:
+         def get_filename(path: str) -> str:
+             """Returns the filename from the full path."""
+             return os.path.basename(path).split(".")[0]
+
+         first_sample_labels = os.path.join(
+             self._root, self._samples_dirs[0], "segmentations", "*.nii.gz"
+         )
+         return sorted(map(get_filename, glob(first_sample_labels)))
+
+     @property
+     @override
+     def class_to_idx(self) -> Dict[str, int]:
+         return {label: index for index, label in enumerate(self.classes)}
+
+     @override
+     def filename(self, index: int) -> str:
+         sample_dir = self._samples_dirs[self._indices[index]]
+         return os.path.join(sample_dir, "ct.nii.gz")
+
+     @override
+     def prepare_data(self) -> None:
+         if self._download:
+             self._download_dataset()
+
+     @override
+     def configure(self) -> None:
+         self._samples_dirs = self._fetch_samples_dirs()
+         self._indices = self._create_indices()
+
+     @override
+     def validate(self) -> None:
+         _validators.check_dataset_integrity(
+             self,
+             length=1660 if self._split == "train" else 400,
+             n_classes=117,
+             first_and_last_labels=("adrenal_gland_left", "vertebrae_T9"),
+         )
+
+     @override
+     def __len__(self) -> int:
+         return len(self._indices) * self._n_slices_per_image
+
+     @override
+     def load_image(self, index: int) -> np.ndarray:
+         image_path = self._get_image_path(index)
+         slice_index = self._get_sample_slice_index(index)
+         image_array = io.read_nifti_slice(image_path, slice_index)
+         return image_array.repeat(3, axis=2)
+
+     @override
+     def load_mask(self, index: int) -> np.ndarray:
+         masks_dir = self._get_masks_dir(index)
+         slice_index = self._get_sample_slice_index(index)
+         mask_paths = (os.path.join(masks_dir, label + ".nii.gz") for label in self.classes)
+         masks = [io.read_nifti_slice(path, slice_index) for path in mask_paths]
+         return np.concatenate(masks, axis=-1)
+
+     def _get_masks_dir(self, index: int) -> str:
+         """Returns the directory of the corresponding masks."""
+         sample_dir = self._get_sample_dir(index)
+         return os.path.join(self._root, sample_dir, "segmentations")
+
+     def _get_image_path(self, index: int) -> str:
+         """Returns the corresponding image path."""
+         sample_dir = self._get_sample_dir(index)
+         return os.path.join(self._root, sample_dir, "ct.nii.gz")
+
+     def _get_sample_dir(self, index: int) -> str:
+         """Returns the corresponding sample directory."""
+         sample_index = self._indices[index // self._n_slices_per_image]
+         return self._samples_dirs[sample_index]
+
+     def _get_sample_slice_index(self, index: int) -> int:
+         """Returns the corresponding slice index."""
+         image_path = self._get_image_path(index)
+         total_slices = io.fetch_total_nifti_slices(image_path)
+         slice_indices = np.linspace(0, total_slices - 1, num=self._n_slices_per_image, dtype=int)
+         return slice_indices[index % self._n_slices_per_image]
+
+     def _fetch_samples_dirs(self) -> List[str]:
+         """Returns the names of all the samples across all splits of the dataset."""
+         sample_filenames = [
+             filename
+             for filename in os.listdir(self._root)
+             if os.path.isdir(os.path.join(self._root, filename))
+         ]
+         return sorted(sample_filenames)
+
+     def _create_indices(self) -> List[int]:
+         """Builds the dataset indices for the specified split."""
+         split_index_ranges = {
+             "train": self._train_index_ranges,
+             "val": self._val_index_ranges,
+             None: [(0, 103)],
+         }
+         index_ranges = split_index_ranges.get(self._split)
+         if index_ranges is None:
+             raise ValueError("Invalid data split. Use 'train', 'val' or `None`.")
+
+         return _utils.ranges_to_indices(index_ranges)
+
+     def _download_dataset(self) -> None:
+         """Downloads the dataset."""
+         dataset_resources = {
+             "small": self._resources_small,
+             "full": self._resources_full,
+             None: (0, 103),
+         }
+         resources = dataset_resources.get(self._version)
+         if resources is None:
+             raise ValueError("Invalid data version. Use 'small' or 'full'.")
+
+         for resource in resources:
+             utils.download_and_extract_archive(
+                 resource.url,
+                 download_root=self._root,
+                 filename=resource.filename,
+                 remove_finished=True,
+             )
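
For orientation, a minimal usage sketch of the class above. It is not taken from the package docs: the local root path is a placeholder, and the setup sequence is an assumption based only on the methods visible in this diff (the `ImageSegmentation` base class may expect additional lifecycle calls):

    from eva.vision.data.datasets.segmentation.total_segmentator import TotalSegmentator2D

    # Hypothetical local root; the archive is downloaded here when download=True.
    dataset = TotalSegmentator2D(
        root="./data/total_segmentator", split="train", version="small", download=True
    )
    dataset.prepare_data()   # triggers the download, if the data is not on disk yet
    dataset.configure()      # scans the sample directories and builds the split indices
    image = dataset.load_image(0)  # one CT slice, repeated to 3 channels: (H, W, 3)
    mask = dataset.load_mask(0)    # per-class segmentation masks stacked on the last axis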
@@ -0,0 +1,17 @@
+ """Helper dataclasses and data structures for vision datasets."""
+
+ import dataclasses
+
+
+ @dataclasses.dataclass(frozen=True)
+ class DownloadResource:
+     """Contains download information for a specific resource."""
+
+     filename: str
+     """The filename of the resource."""
+
+     url: str
+     """The URL of the resource."""
+
+     md5: str | None = None
+     """The MD5 hash of the resource."""
@@ -0,0 +1,43 @@
+ """Vision Dataset base class."""
+
+ import abc
+ from typing import Generic, TypeVar
+
+ from eva.core.data.datasets import base
+
+ DataSample = TypeVar("DataSample")
+ """The data sample type."""
+
+
+ class VisionDataset(base.Dataset, abc.ABC, Generic[DataSample]):
+     """Base dataset class for vision tasks."""
+
+     @abc.abstractmethod
+     def filename(self, index: int) -> str:
+         """Returns the filename of the `index`'th data sample.
+
+         Note that this is the relative file path to the root.
+
+         Args:
+             index: The index of the data sample to select.
+
+         Returns:
+             The filename of the `index`'th data sample.
+         """
+
+     @abc.abstractmethod
+     def __getitem__(self, index: int) -> DataSample:
+         """Returns the `index`'th data sample.
+
+         Args:
+             index: The index of the data sample to select.
+
+         Returns:
+             A data sample and its target.
+         """
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def __len__(self) -> int:
+         """Returns the total length of the data."""
+         raise NotImplementedError
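
A minimal sketch of a custom subclass, assuming `base.Dataset` adds no further required methods beyond the three abstract ones shown here; the class name, filename scheme, and data are hypothetical:

    import numpy as np
    from eva.vision.data.datasets.vision import VisionDataset

    class RandomImages(VisionDataset[np.ndarray]):
        """Toy dataset that fabricates images; it only illustrates the abstract interface."""

        def filename(self, index: int) -> str:
            return f"random_{index}.png"  # hypothetical path, relative to a root directory

        def __getitem__(self, index: int) -> np.ndarray:
            return np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8)

        def __len__(self) -> int:
            return 16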
@@ -0,0 +1,5 @@
+ """Vision data transforms."""
+
+ from eva.vision.data.transforms.common import ResizeAndCrop
+
+ __all__ = ["ResizeAndCrop"]
@@ -0,0 +1,5 @@
+ """Common vision transforms."""
+
+ from eva.vision.data.transforms.common.resize_and_crop import ResizeAndCrop
+
+ __all__ = ["ResizeAndCrop"]
@@ -0,0 +1,44 @@
+ """Resizes and normalizes the input image."""
+
+ from typing import Callable, Sequence
+
+ import torch
+ import torchvision.transforms.v2 as torch_transforms
+
+
+ class ResizeAndCrop(torch_transforms.Compose):
+     """Resizes, crops and normalizes an input image while preserving its aspect ratio."""
+
+     def __init__(
+         self,
+         size: int | Sequence[int] = 224,
+         mean: Sequence[float] = (0.5, 0.5, 0.5),
+         std: Sequence[float] = (0.5, 0.5, 0.5),
+     ) -> None:
+         """Initializes the transform object.
+
+         Args:
+             size: Desired output size of the crop. If size is an `int` instead
+                 of a sequence like (h, w), a square crop (size, size) is made.
+             mean: Sequence of means for each image channel.
+             std: Sequence of standard deviations for each image channel.
+         """
+         self._size = size
+         self._mean = mean
+         self._std = std
+
+         super().__init__(transforms=self._build_transforms())
+
+     def _build_transforms(self) -> Sequence[Callable]:
+         """Builds and returns the list of transforms."""
+         transforms = [
+             torch_transforms.ToImage(),
+             torch_transforms.Resize(size=self._size),
+             torch_transforms.CenterCrop(size=self._size),
+             torch_transforms.ToDtype(torch.float32, scale=True),
+             torch_transforms.Normalize(
+                 mean=self._mean,
+                 std=self._std,
+             ),
+         ]
+         return transforms
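
A short usage sketch, assuming the transform is called directly on an HWC uint8 array (torchvision v2 transforms accept numpy arrays via `ToImage`); the input here is random dummy data:

    import numpy as np
    from eva.vision.data.transforms import ResizeAndCrop

    transform = ResizeAndCrop(size=224, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    image = np.random.randint(0, 255, size=(512, 384, 3), dtype=np.uint8)  # dummy HWC image
    tensor = transform(image)  # float32 tensor of shape (3, 224, 224), normalized per channel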
@@ -0,0 +1,5 @@
+ """Vision Models API."""
+
+ from eva.vision.models import networks
+
+ __all__ = ["networks"]
@@ -0,0 +1,6 @@
+ """Vision Networks API."""
+
+ from eva.vision.models.networks import postprocesses
+ from eva.vision.models.networks.abmil import ABMIL
+
+ __all__ = ["postprocesses", "ABMIL"]
@@ -0,0 +1,176 @@
+ """ABMIL Network."""
+
+ from typing import Type
+
+ import torch
+ import torch.nn as nn
+
+ from eva.core.models.networks import MLP
+
+
+ class ABMIL(torch.nn.Module):
+     """ABMIL network for multiple instance learning classification tasks.
+
+     Takes an array of patch level embeddings per slide as input. This implementation supports
+     batched inputs of shape (`batch_size`, `n_instances`, `input_size`). For slides with fewer
+     than `n_instances` patches, you can apply padding and provide a mask tensor to the forward
+     pass.
+
+     The original implementation from [1] was used as a reference:
+     https://github.com/AMLab-Amsterdam/AttentionDeepMIL/blob/master/model.py
+
+     Notes:
+         - use_bias: The paper didn't use bias in their formalism, but their published
+           example code inadvertently does.
+         - To prevent near-equal dot-product similarities caused by concentration of measure
+           in large input embedding dimensionalities (>128), we added the option to project
+           the input embeddings to a lower dimensionality.
+
+     [1] Maximilian Ilse, Jakub M. Tomczak, Max Welling, "Attention-based Deep Multiple
+     Instance Learning", 2018
+     https://arxiv.org/abs/1802.04712
+     """
+
+     def __init__(
+         self,
+         input_size: int,
+         output_size: int,
+         projected_input_size: int | None,
+         hidden_size_attention: int = 128,
+         hidden_sizes_mlp: tuple = (128, 64),
+         use_bias: bool = True,
+         dropout_input_embeddings: float = 0.0,
+         dropout_attention: float = 0.0,
+         dropout_mlp: float = 0.0,
+         pad_value: int | float | None = float("-inf"),
+     ) -> None:
+         """Initializes the ABMIL network.
+
+         Args:
+             input_size: input embedding dimension
+             output_size: number of classes
+             projected_input_size: size of the projected input. If `None`, no projection is
+                 performed.
+             hidden_size_attention: hidden dimension in attention network
+             hidden_sizes_mlp: dimensions for hidden layers in last mlp
+             use_bias: whether to use bias in the attention network
+             dropout_input_embeddings: dropout rate for the input embeddings
+             dropout_attention: dropout rate for the attention network and classifier
+             dropout_mlp: dropout rate for the final MLP network
+             pad_value: Value indicating padding in the input tensor. If specified, entries
+                 with this value will be masked. If set to `None`, no masking is applied.
+         """
+         super().__init__()
+
+         self._pad_value = pad_value
+
+         if projected_input_size:
+             self.projector = nn.Sequential(
+                 nn.Linear(input_size, projected_input_size, bias=True),
+                 nn.Dropout(p=dropout_input_embeddings),
+             )
+             input_size = projected_input_size
+         else:
+             self.projector = nn.Dropout(p=dropout_input_embeddings)
+
+         self.gated_attention = GatedAttention(
+             input_dim=input_size,
+             hidden_dim=hidden_size_attention,
+             dropout=dropout_attention,
+             n_classes=1,
+             use_bias=use_bias,
+         )
+
+         self.classifier = MLP(
+             input_size=input_size,
+             output_size=output_size,
+             hidden_layer_sizes=hidden_sizes_mlp,
+             dropout=dropout_mlp,
+             hidden_activation_fn=nn.ReLU,
+         )
+
+     def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
+         """Forward pass.
+
+         Args:
+             input_tensor: Tensor with expected shape of (batch_size, n_instances, input_size).
+         """
+         input_tensor, mask = self._mask_values(input_tensor, self._pad_value)
+
+         # (batch_size, n_instances, input_size) -> (batch_size, n_instances, projected_input_size)
+         input_tensor = self.projector(input_tensor)
+
+         attention_logits = self.gated_attention(input_tensor)  # (batch_size, n_instances, 1)
+         if mask is not None:
+             # fill masked values with -inf, which will yield 0s after softmax
+             attention_logits = attention_logits.masked_fill(mask, float("-inf"))
+
+         attention_weights = nn.functional.softmax(attention_logits, dim=1)
+         # (batch_size, n_instances, 1)
+
+         attention_result = torch.matmul(torch.transpose(attention_weights, 1, 2), input_tensor)
+         # (batch_size, 1, projected_input_size)
+
+         attention_result = torch.squeeze(attention_result, 1)  # (batch_size, projected_input_size)
+
+         return self.classifier(attention_result)  # (batch_size, output_size)
+
+     def _mask_values(self, input_tensor: torch.Tensor, pad_value: float | None):
+         """Masks the padded values in the input tensor."""
+         if pad_value is None:
+             return input_tensor, None
+         else:
+             # (batch_size, n_instances, input_size)
+             mask = input_tensor == pad_value
+
+             # (batch_size, n_instances, input_size) -> (batch_size, n_instances, 1)
+             mask = mask.all(dim=-1, keepdim=True)
+
+             # Fill masked values with 0, so that they don't contribute to dense layers
+             input_tensor = input_tensor.masked_fill(mask, 0)
+
+             return input_tensor, mask
+
+
+ class GatedAttention(nn.Module):
+     """Attention mechanism with Sigmoid Gating using 3 linear layers."""
+
+     def __init__(
+         self,
+         input_dim: int,
+         hidden_dim: int,
+         dropout: float = 0.25,
+         n_classes: int = 1,
+         use_bias: bool = True,
+         activation_a: Type[nn.Module] = nn.Tanh,
+         activation_b: Type[nn.Module] = nn.Sigmoid,
+     ):
+         """Initializes the GatedAttention network.
+
+         Args:
+             input_dim: input feature dimension
+             hidden_dim: hidden layer dimension
+             dropout: dropout rate
+             n_classes: number of classes
+             use_bias: whether to use bias in the linear layers
+             activation_a: activation function for attention_a.
+             activation_b: activation function for attention_b.
+         """
+         super().__init__()
+
+         def make_attention(activation: nn.Module):
+             return nn.Sequential(
+                 nn.Linear(input_dim, hidden_dim, bias=use_bias), nn.Dropout(p=dropout), activation
+             )
+
+         self.attention_a = make_attention(activation_a())
+         self.attention_b = make_attention(activation_b())
+         self.attention_c = nn.Linear(hidden_dim, n_classes)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass."""
+         a = self.attention_a(x)  # [..., hidden_dim]
+         b = self.attention_b(x)  # [..., hidden_dim]
+         att = a.mul(b)  # [..., hidden_dim]
+         att = self.attention_c(att)  # [..., n_classes]
+         return att
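
A rough usage sketch for the network above; the sizes are arbitrary, and padding the first slide with the default pad value of -inf is only meant to exercise the masking path:

    import torch
    from eva.vision.models.networks import ABMIL

    model = ABMIL(input_size=384, output_size=2, projected_input_size=128)
    batch = torch.randn(4, 1000, 384)  # 4 slides, 1000 patch embeddings of size 384 each
    batch[0, 500:] = float("-inf")     # pad a slide that only has 500 patches
    logits = model(batch)              # shape: (4, 2)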
@@ -0,0 +1,5 @@
+ """Model post-process transforms."""
+
+ from eva.vision.models.networks.postprocesses.cls import ExtractCLSFeatures
+
+ __all__ = ["ExtractCLSFeatures"]
@@ -0,0 +1,25 @@
+ """Transforms for extracting the CLS output from a model output."""
+
+ import torch
+ from transformers import modeling_outputs
+
+
+ class ExtractCLSFeatures:
+     """Extracts the CLS token from a ViT model output."""
+
+     def __call__(
+         self, tensor: torch.Tensor | modeling_outputs.BaseModelOutputWithPooling
+     ) -> torch.Tensor:
+         """Call method for the transformation.
+
+         Args:
+             tensor: The tensor representing the model output.
+         """
+         if isinstance(tensor, torch.Tensor):
+             transformed_tensor = tensor[:, 0, :]
+         elif isinstance(tensor, modeling_outputs.BaseModelOutputWithPooling):
+             transformed_tensor = tensor.last_hidden_state[:, 0, :]
+         else:
+             raise ValueError(f"Unsupported type {type(tensor)}")
+
+         return transformed_tensor
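
A small sketch with a dummy tensor shaped like a ViT output, to show what the transform returns (the sizes are arbitrary):

    import torch
    from eva.vision.models.networks.postprocesses import ExtractCLSFeatures

    extract_cls = ExtractCLSFeatures()
    hidden_states = torch.randn(8, 197, 768)   # (batch, tokens, hidden_dim), dummy data
    cls_features = extract_cls(hidden_states)  # CLS token features, shape (8, 768)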
@@ -0,0 +1,5 @@
+ """Vision utilities and helper functions."""
+
+ from eva.vision.utils import io
+
+ __all__ = ["io"]
@@ -0,0 +1,12 @@
+ """Vision I/O utilities."""
+
+ from eva.vision.utils.io.image import read_image
+ from eva.vision.utils.io.nifti import fetch_total_nifti_slices, read_nifti_slice
+ from eva.vision.utils.io.text import read_csv
+
+ __all__ = [
+     "read_image",
+     "fetch_total_nifti_slices",
+     "read_nifti_slice",
+     "read_csv",
+ ]
@@ -0,0 +1,29 @@
+ """File IO utilities."""
+
+ import os
+
+
+ def is_file(path: str) -> bool:
+     """Checks if the input path is a valid file.
+
+     Args:
+         path: The file path to be checked.
+
+     Returns:
+         A boolean value indicating whether the file exists and is non-empty.
+     """
+     return os.path.exists(path) and os.stat(path).st_size != 0 and os.path.isfile(path)
+
+
+ def check_file(path: str) -> None:
+     """Checks whether the input path is a valid file and raises an error otherwise.
+
+     Args:
+         path: The file path to be checked.
+     """
+     if not is_file(path):
+         raise FileExistsError(
+             f"Input '{path if isinstance(path, str) else type(path)}' "
+             "could not be recognized as a valid file. Please verify "
+             "that the file exists and is reachable."
+         )
@@ -0,0 +1,54 @@
+ """Image I/O related functions."""
+
+ import cv2
+ import numpy as np
+ import numpy.typing as npt
+
+ from eva.vision.utils.io import _utils
+
+
+ def read_image(path: str) -> npt.NDArray[np.uint8]:
+     """Reads and loads the image from a file path as an RGB image.
+
+     Args:
+         path: The path of the image file.
+
+     Returns:
+         The RGB image as a numpy array.
+
+     Raises:
+         FileExistsError: If the path does not exist or it is unreachable.
+         IOError: If the image could not be loaded.
+     """
+     return read_image_as_array(path, cv2.IMREAD_COLOR)
+
+
+ def read_image_as_array(path: str, flags: int = cv2.IMREAD_UNCHANGED) -> npt.NDArray[np.uint8]:
+     """Reads and loads an image file as a numpy array.
+
+     Args:
+         path: The path to the image file.
+         flags: Specifies the way in which the image should be read.
+
+     Returns:
+         The image as a numpy array.
+
+     Raises:
+         FileExistsError: If the path does not exist or it is unreachable.
+         IOError: If the image could not be loaded.
+     """
+     _utils.check_file(path)
+     image = cv2.imread(path, flags=flags)
+     if image is None:
+         raise IOError(
+             f"Input '{path}' could not be loaded. "
+             "Please verify that the path is a valid image file."
+         )
+
+     if image.ndim == 3:
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+     if image.ndim == 2 and flags == cv2.IMREAD_COLOR:
+         image = image[:, :, np.newaxis]
+
+     return np.asarray(image).astype(np.uint8)
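
A one-line usage sketch of the helper above; the file path is a placeholder:

    from eva.vision.utils import io

    image = io.read_image("path/to/patch.png")  # returns an (H, W, 3) uint8 RGB array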