kaiko-eva 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kaiko-eva might be problematic; see the registry's release page for more details.

Files changed (168):
  1. eva/core/callbacks/__init__.py +3 -2
  2. eva/core/callbacks/config.py +143 -0
  3. eva/core/callbacks/writers/__init__.py +6 -3
  4. eva/core/callbacks/writers/embeddings/__init__.py +6 -0
  5. eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
  6. eva/core/callbacks/writers/embeddings/base.py +192 -0
  7. eva/core/callbacks/writers/embeddings/classification.py +117 -0
  8. eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
  9. eva/core/callbacks/writers/embeddings/typings.py +38 -0
  10. eva/core/data/datasets/__init__.py +10 -2
  11. eva/core/data/datasets/classification/__init__.py +5 -2
  12. eva/core/data/datasets/classification/embeddings.py +15 -135
  13. eva/core/data/datasets/classification/multi_embeddings.py +110 -0
  14. eva/core/data/datasets/embeddings.py +167 -0
  15. eva/core/data/splitting/__init__.py +6 -0
  16. eva/core/data/splitting/random.py +41 -0
  17. eva/core/data/splitting/stratified.py +56 -0
  18. eva/core/data/transforms/__init__.py +3 -1
  19. eva/core/data/transforms/padding/__init__.py +5 -0
  20. eva/core/data/transforms/padding/pad_2d_tensor.py +38 -0
  21. eva/core/data/transforms/sampling/__init__.py +5 -0
  22. eva/core/data/transforms/sampling/sample_from_axis.py +40 -0
  23. eva/core/loggers/__init__.py +7 -0
  24. eva/core/loggers/dummy.py +38 -0
  25. eva/core/loggers/experimental_loggers.py +8 -0
  26. eva/core/loggers/log/__init__.py +6 -0
  27. eva/core/loggers/log/image.py +71 -0
  28. eva/core/loggers/log/parameters.py +74 -0
  29. eva/core/loggers/log/utils.py +13 -0
  30. eva/core/loggers/loggers.py +6 -0
  31. eva/core/metrics/__init__.py +6 -2
  32. eva/core/metrics/defaults/__init__.py +10 -3
  33. eva/core/metrics/defaults/classification/__init__.py +1 -1
  34. eva/core/metrics/defaults/classification/binary.py +0 -9
  35. eva/core/metrics/defaults/classification/multiclass.py +0 -8
  36. eva/core/metrics/defaults/segmentation/__init__.py +5 -0
  37. eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
  38. eva/core/metrics/generalized_dice.py +59 -0
  39. eva/core/metrics/mean_iou.py +120 -0
  40. eva/core/metrics/structs/schemas.py +3 -1
  41. eva/core/models/__init__.py +3 -1
  42. eva/core/models/modules/head.py +16 -15
  43. eva/core/models/modules/module.py +25 -1
  44. eva/core/models/modules/typings.py +14 -1
  45. eva/core/models/modules/utils/batch_postprocess.py +37 -5
  46. eva/core/models/networks/__init__.py +1 -2
  47. eva/core/models/networks/mlp.py +2 -2
  48. eva/core/models/transforms/__init__.py +6 -0
  49. eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
  50. eva/core/models/transforms/extract_patch_features.py +47 -0
  51. eva/core/models/wrappers/__init__.py +13 -0
  52. eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
  53. eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
  54. eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
  55. eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
  56. eva/core/trainers/_recorder.py +69 -7
  57. eva/core/trainers/functional.py +23 -5
  58. eva/core/trainers/trainer.py +20 -6
  59. eva/core/utils/__init__.py +6 -0
  60. eva/core/utils/clone.py +27 -0
  61. eva/core/utils/memory.py +28 -0
  62. eva/core/utils/operations.py +26 -0
  63. eva/core/utils/parser.py +20 -0
  64. eva/vision/__init__.py +2 -2
  65. eva/vision/callbacks/__init__.py +5 -0
  66. eva/vision/callbacks/loggers/__init__.py +5 -0
  67. eva/vision/callbacks/loggers/batch/__init__.py +5 -0
  68. eva/vision/callbacks/loggers/batch/base.py +130 -0
  69. eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
  70. eva/vision/data/datasets/__init__.py +24 -4
  71. eva/vision/data/datasets/_utils.py +3 -3
  72. eva/vision/data/datasets/_validators.py +15 -2
  73. eva/vision/data/datasets/classification/__init__.py +6 -2
  74. eva/vision/data/datasets/classification/bach.py +10 -15
  75. eva/vision/data/datasets/classification/base.py +17 -24
  76. eva/vision/data/datasets/classification/camelyon16.py +244 -0
  77. eva/vision/data/datasets/classification/crc.py +10 -15
  78. eva/vision/data/datasets/classification/mhist.py +10 -15
  79. eva/vision/data/datasets/classification/panda.py +184 -0
  80. eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
  81. eva/vision/data/datasets/classification/wsi.py +105 -0
  82. eva/vision/data/datasets/segmentation/__init__.py +15 -2
  83. eva/vision/data/datasets/segmentation/_utils.py +38 -0
  84. eva/vision/data/datasets/segmentation/base.py +31 -47
  85. eva/vision/data/datasets/segmentation/bcss.py +236 -0
  86. eva/vision/data/datasets/segmentation/consep.py +156 -0
  87. eva/vision/data/datasets/segmentation/embeddings.py +34 -0
  88. eva/vision/data/datasets/segmentation/lits.py +178 -0
  89. eva/vision/data/datasets/segmentation/monusac.py +236 -0
  90. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +325 -0
  91. eva/vision/data/datasets/wsi.py +187 -0
  92. eva/vision/data/transforms/__init__.py +3 -2
  93. eva/vision/data/transforms/common/__init__.py +2 -1
  94. eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
  95. eva/vision/data/transforms/common/resize_and_crop.py +6 -7
  96. eva/vision/data/transforms/normalization/__init__.py +6 -0
  97. eva/vision/data/transforms/normalization/clamp.py +43 -0
  98. eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
  99. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
  100. eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
  101. eva/vision/data/wsi/__init__.py +16 -0
  102. eva/vision/data/wsi/backends/__init__.py +69 -0
  103. eva/vision/data/wsi/backends/base.py +115 -0
  104. eva/vision/data/wsi/backends/openslide.py +73 -0
  105. eva/vision/data/wsi/backends/pil.py +52 -0
  106. eva/vision/data/wsi/backends/tiffslide.py +42 -0
  107. eva/vision/data/wsi/patching/__init__.py +6 -0
  108. eva/vision/data/wsi/patching/coordinates.py +98 -0
  109. eva/vision/data/wsi/patching/mask.py +123 -0
  110. eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
  111. eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
  112. eva/vision/data/wsi/patching/samplers/base.py +48 -0
  113. eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
  114. eva/vision/data/wsi/patching/samplers/grid.py +47 -0
  115. eva/vision/data/wsi/patching/samplers/random.py +41 -0
  116. eva/vision/losses/__init__.py +5 -0
  117. eva/vision/losses/dice.py +40 -0
  118. eva/vision/models/__init__.py +4 -2
  119. eva/vision/models/modules/__init__.py +5 -0
  120. eva/vision/models/modules/semantic_segmentation.py +161 -0
  121. eva/vision/models/networks/__init__.py +1 -2
  122. eva/vision/models/networks/backbones/__init__.py +6 -0
  123. eva/vision/models/networks/backbones/_utils.py +39 -0
  124. eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
  125. eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
  126. eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
  127. eva/vision/models/networks/backbones/pathology/histai.py +46 -0
  128. eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
  129. eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
  130. eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
  131. eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
  132. eva/vision/models/networks/backbones/registry.py +47 -0
  133. eva/vision/models/networks/backbones/timm/__init__.py +5 -0
  134. eva/vision/models/networks/backbones/timm/backbones.py +54 -0
  135. eva/vision/models/networks/backbones/universal/__init__.py +8 -0
  136. eva/vision/models/networks/backbones/universal/vit.py +54 -0
  137. eva/vision/models/networks/decoders/__init__.py +6 -0
  138. eva/vision/models/networks/decoders/decoder.py +7 -0
  139. eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
  140. eva/vision/models/networks/decoders/segmentation/common.py +74 -0
  141. eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
  142. eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
  143. eva/vision/models/wrappers/__init__.py +6 -0
  144. eva/vision/models/wrappers/from_registry.py +48 -0
  145. eva/vision/models/wrappers/from_timm.py +68 -0
  146. eva/vision/utils/colormap.py +77 -0
  147. eva/vision/utils/convert.py +67 -0
  148. eva/vision/utils/io/__init__.py +10 -4
  149. eva/vision/utils/io/image.py +21 -2
  150. eva/vision/utils/io/mat.py +36 -0
  151. eva/vision/utils/io/nifti.py +40 -15
  152. eva/vision/utils/io/text.py +10 -3
  153. kaiko_eva-0.1.0.dist-info/METADATA +553 -0
  154. kaiko_eva-0.1.0.dist-info/RECORD +205 -0
  155. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
  156. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
  157. eva/core/callbacks/writers/embeddings.py +0 -169
  158. eva/core/callbacks/writers/typings.py +0 -23
  159. eva/core/models/networks/transforms/__init__.py +0 -5
  160. eva/core/models/networks/wrappers/__init__.py +0 -8
  161. eva/vision/data/datasets/classification/total_segmentator.py +0 -213
  162. eva/vision/data/datasets/segmentation/total_segmentator.py +0 -212
  163. eva/vision/models/networks/postprocesses/__init__.py +0 -5
  164. eva/vision/models/networks/postprocesses/cls.py +0 -25
  165. kaiko_eva-0.0.1.dist-info/METADATA +0 -405
  166. kaiko_eva-0.0.1.dist-info/RECORD +0 -110
  167. /eva/core/models/{networks → wrappers}/_utils.py +0 -0
  168. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,68 @@
1
+ """Model wrapper for timm models."""
2
+
3
+ from typing import Any, Callable, Dict, Tuple
4
+ from urllib import parse
5
+
6
+ import timm
7
+ from typing_extensions import override
8
+
9
+ from eva.core.models import wrappers
10
+
11
+
12
class TimmModel(wrappers.BaseModel):
    """Model wrapper for `timm` models.

    Note that only models with `forward_intermediates`
    method are currently supported.
    """

    def __init__(
        self,
        model_name: str,
        pretrained: bool = True,
        checkpoint_path: str = "",
        out_indices: int | Tuple[int, ...] | None = None,
        model_kwargs: Dict[str, Any] | None = None,
        tensor_transforms: Callable | None = None,
    ) -> None:
        """Initializes the encoder.

        Args:
            model_name: Name of model to instantiate.
            pretrained: If set to `True`, load pretrained ImageNet-1k weights.
                When `checkpoint_path` is set, weights are always loaded
                (this flag is then treated as `True`).
            checkpoint_path: Local path, `file://` URI or remote URL of a
                checkpoint to load instead of the default pretrained weights.
            out_indices: Returns last n blocks if `int`, select matching
                indices if sequence. When set, the model is built with
                `features_only=True`; if `None`, a regular (non feature-only)
                timm model is built.
            model_kwargs: Extra model arguments forwarded to
                `timm.create_model`.
            tensor_transforms: The transforms to apply to the output tensor
                produced by the model.
        """
        super().__init__(tensor_transforms=tensor_transforms)

        self._model_name = model_name
        self._pretrained = pretrained
        self._checkpoint_path = checkpoint_path
        self._out_indices = out_indices
        self._model_kwargs = model_kwargs or {}

        self.load_model()

    @override
    def load_model(self) -> None:
        """Builds and loads the timm model as feature extractor."""
        self._model = timm.create_model(
            model_name=self._model_name,
            # An explicit checkpoint always requires loading weights.
            pretrained=True if self._checkpoint_path else self._pretrained,
            pretrained_cfg=self._pretrained_cfg,
            out_indices=self._out_indices,
            features_only=self._out_indices is not None,
            **self._model_kwargs,
        )
        # NOTE(review): this renames the class object itself, so every
        # TimmModel instance reports the most recently loaded model name.
        TimmModel.__name__ = self._model_name

    @property
    def _pretrained_cfg(self) -> Dict[str, Any]:
        """Builds the timm `pretrained_cfg` pointing at `checkpoint_path`.

        Returns:
            An empty dict when no checkpoint is configured (timm defaults
            apply), otherwise a config with either a `file` (local checkpoint)
            or `url` (remote checkpoint) entry and `num_classes=0`.
        """
        if not self._checkpoint_path:
            return {}

        parsed = parse.urlparse(self._checkpoint_path)
        if parsed.scheme == "file":
            # Local import: only needed to turn `file://` URIs into plain paths,
            # since the `file` entry is expected to be a filesystem path.
            from urllib import request

            return {"file": request.url2pathname(parsed.path), "num_classes": 0}
        if parsed.scheme == "" or len(parsed.scheme) == 1:
            # No scheme, or a single-letter "scheme" which is a Windows drive
            # letter (e.g. `C:\\models\\ckpt.pth`): a local path.
            return {"file": self._checkpoint_path, "num_classes": 0}
        return {"url": self._checkpoint_path, "num_classes": 0}
@@ -0,0 +1,77 @@
1
+ """Color mapping constants."""
2
+
3
COLORS = [
    (0, 0, 0),  # Black
    (255, 0, 0),  # Red
    (0, 255, 0),  # Green
    (0, 0, 255),  # Blue
    (255, 255, 0),  # Yellow
    (255, 0, 255),  # Magenta
    (0, 255, 255),  # Cyan
    (128, 128, 0),  # Olive
    (128, 0, 128),  # Purple
    (0, 128, 128),  # Teal
    (192, 192, 192),  # Silver
    (128, 128, 128),  # Gray
    (255, 165, 0),  # Orange
    (210, 105, 30),  # Chocolate
    (0, 128, 0),  # Web Green
    (255, 192, 203),  # Pink
    (255, 69, 0),  # Red-Orange
    (255, 140, 0),  # Dark Orange
    (135, 206, 235),  # Sky Blue (previously a duplicate of Cyan)
    (0, 255, 127),  # Spring Green
    (0, 0, 139),  # Dark Blue
    (255, 20, 147),  # Deep Pink
    (139, 69, 19),  # Saddle Brown
    (0, 100, 0),  # Dark Green
    (106, 90, 205),  # Slate Blue
    (138, 43, 226),  # Blue-Violet
    (218, 165, 32),  # Goldenrod
    (199, 21, 133),  # Medium Violet Red
    (70, 130, 180),  # Steel Blue
    (165, 42, 42),  # Brown
    (128, 0, 0),  # Maroon
    (218, 112, 214),  # Orchid (Fuchsia was identical to Magenta)
    (210, 180, 140),  # Tan
    (0, 0, 128),  # Navy
    (139, 0, 139),  # Dark Magenta
    (144, 238, 144),  # Light Green
    (46, 139, 87),  # Sea Green
    (255, 215, 0),  # Gold (previously a duplicate of Yellow)
    (154, 205, 50),  # Yellow Green
    (0, 191, 255),  # Deep Sky Blue
    (0, 250, 154),  # Medium Spring Green
    (250, 128, 114),  # Salmon
    (255, 105, 180),  # Hot Pink
    (204, 255, 204),  # Pastel Light Green
    (51, 0, 51),  # Very Dark Magenta
    (255, 102, 0),  # Vivid Orange
    (127, 255, 0),  # Chartreuse (Bright Green was identical to Green)
    (51, 153, 255),  # Blue-Purple
    (51, 51, 255),  # Bright Blue
    (204, 0, 0),  # Dark Red
    (90, 90, 90),  # Very Dark Gray
    (255, 255, 51),  # Pastel Yellow
    (255, 153, 255),  # Pink-Magenta
    (153, 0, 76),  # Dark Pink
    (51, 25, 0),  # Very Dark Brown
    (102, 51, 0),  # Dark Brown
    (0, 0, 51),  # Very Dark Blue
    (180, 180, 180),  # Dark Gray
    (102, 255, 204),  # Pastel Green
    (0, 102, 0),  # Dark Green
    (220, 245, 20),  # Lime Yellow
    (255, 204, 204),  # Pastel Pink
    (0, 204, 255),  # Pastel Blue
    (240, 240, 240),  # Light Gray
    (153, 153, 0),  # Dark Yellow
    (102, 0, 51),  # Dark Red-Pink
    (0, 51, 0),  # Very Dark Green
    (255, 102, 204),  # Magenta Pink
    (204, 0, 102),  # Red-Pink
]
"""RGB colors (all entries are distinct, so each class id gets its own color)."""

# Class ids 0..len(COLORS)-1 map positionally onto COLORS; id 255
# (NOTE(review): presumably an ignore/void label) is drawn white.
COLORMAP = dict(enumerate(COLORS)) | {255: (255, 255, 255)}
"""Class id to RGB color mapping."""
@@ -0,0 +1,67 @@
1
+ """Image conversion related functionalities."""
2
+
3
+ from typing import Iterable
4
+
5
+ import torch
6
+ from torchvision.transforms.v2 import functional
7
+
8
+
9
def descale_and_denorm_image(
    image: torch.Tensor,
    mean: Iterable[float] = (0.0, 0.0, 0.0),
    std: Iterable[float] = (1.0, 1.0, 1.0),
    inplace: bool = True,
) -> torch.Tensor:
    """Turns a normalized float image tensor back into a (0, 255) uint8 tensor.

    The per-channel normalization is reverted first (`_descale_image`), and
    the result is then stretched to the (0, 255) range and cast to uint8
    (`_denorm_image`).

    Args:
        image: An image float tensor.
        mean: The mean that the image channels are normalized with.
        std: The std that the image channels are normalized with.
        inplace: When `False`, the input tensor is cloned before processing.
            NOTE(review): `functional.normalize` appears to be out-of-place
            by default, which would make this clone redundant — confirm.

    Returns:
        The image tensor of range (0, 255) range as uint8.
    """
    working_image = image if inplace else image.clone()
    descaled = _descale_image(working_image, mean=mean, std=std)
    return _denorm_image(descaled)
31
+
32
+
33
def _descale_image(
    image: torch.Tensor,
    mean: Iterable[float] = (0.0, 0.0, 0.0),
    std: Iterable[float] = (1.0, 1.0, 1.0),
) -> torch.Tensor:
    """Reverts a per-channel `(x - mean) / std` normalization.

    Args:
        image: An image float tensor.
        mean: The normalized channels mean values.
        std: The normalized channels std values.

    Returns:
        The de-normalized image tensor, i.e. `image * std + mean` per channel
        (within (0., 1.) if the image was in that range before normalizing).
    """
    # Materialize once: `std` is consumed twice below, which would silently
    # produce an empty list for a one-shot iterable such as a generator.
    means = tuple(mean)
    stds = tuple(std)
    # `normalize` computes (x - m) / s; with m = -mean / std and s = 1 / std
    # this equals x * std + mean, the inverse of the forward normalization.
    return functional.normalize(
        image,
        mean=[-cmean / cstd for cmean, cstd in zip(means, stds, strict=False)],
        std=[1 / cstd for cstd in stds],
    )
53
+
54
+
55
def _denorm_image(image: torch.Tensor) -> torch.Tensor:
    """Min-max rescales an image tensor to the (0, 255) range as uint8.

    The tensor minimum maps to 0 and its maximum to 255; the final uint8 cast
    truncates fractional values. A constant tensor (zero value range) maps to
    all zeros instead of producing NaNs from a division by zero.

    Args:
        image: An image float tensor.

    Returns:
        The image tensor of range (0, 255) range as uint8.
    """
    image_scaled = image - image.min()
    value_range = image_scaled.max()
    # Guard against 0 / 0 for constant images, which would yield NaNs and an
    # undefined uint8 conversion.
    if value_range > 0:
        image_scaled /= value_range
    image_scaled *= 255
    return image_scaled.to(dtype=torch.uint8)
@@ -1,12 +1,18 @@
1
1
  """Vision I/O utilities."""
2
2
 
3
- from eva.vision.utils.io.image import read_image
4
- from eva.vision.utils.io.nifti import fetch_total_nifti_slices, read_nifti_slice
3
+ from eva.vision.utils.io.image import read_image, read_image_as_array, read_image_as_tensor
4
+ from eva.vision.utils.io.mat import read_mat, save_mat
5
+ from eva.vision.utils.io.nifti import fetch_nifti_shape, read_nifti, save_array_as_nifti
5
6
  from eva.vision.utils.io.text import read_csv
6
7
 
7
8
  __all__ = [
8
9
  "read_image",
9
- "fetch_total_nifti_slices",
10
- "read_nifti_slice",
10
+ "read_image_as_array",
11
+ "read_image_as_tensor",
12
+ "fetch_nifti_shape",
13
+ "read_nifti",
14
+ "save_array_as_nifti",
11
15
  "read_csv",
16
+ "read_mat",
17
+ "save_mat",
12
18
  ]
@@ -3,6 +3,8 @@
3
3
  import cv2
4
4
  import numpy as np
5
5
  import numpy.typing as npt
6
+ from torchvision import tv_tensors
7
+ from torchvision.transforms.v2 import functional
6
8
 
7
9
  from eva.vision.utils.io import _utils
8
10
 
@@ -14,7 +16,7 @@ def read_image(path: str) -> npt.NDArray[np.uint8]:
14
16
  path: The path of the image file.
15
17
 
16
18
  Returns:
17
- The RGB image as a numpy array.
19
+ The RGB image as a numpy array (HxWxC).
18
20
 
19
21
  Raises:
20
22
  FileExistsError: If the path does not exist or it is unreachable.
@@ -23,6 +25,23 @@ def read_image(path: str) -> npt.NDArray[np.uint8]:
23
25
  return read_image_as_array(path, cv2.IMREAD_COLOR)
24
26
 
25
27
 
28
def read_image_as_tensor(path: str) -> tv_tensors.Image:
    """Loads an RGB image file as a torchvision image tensor.

    The file is decoded by `read_image` into an HxWxC array, which
    `functional.to_image` then converts into a channel-first tensor.

    Args:
        path: The path of the image file.

    Returns:
        The RGB image as a torch tensor (CxHxW).

    Raises:
        FileExistsError: If the path does not exist or it is unreachable.
        IOError: If the image could not be loaded.
    """
    return functional.to_image(read_image(path))
43
+
44
+
26
45
  def read_image_as_array(path: str, flags: int = cv2.IMREAD_UNCHANGED) -> npt.NDArray[np.uint8]:
27
46
  """Reads and loads an image file as a numpy array.
28
47
 
@@ -51,4 +70,4 @@ def read_image_as_array(path: str, flags: int = cv2.IMREAD_UNCHANGED) -> npt.NDA
51
70
  if image.ndim == 2 and flags == cv2.IMREAD_COLOR:
52
71
  image = image[:, :, np.newaxis]
53
72
 
54
- return np.asarray(image).astype(np.uint8)
73
+ return np.asarray(image, dtype=np.uint8)
@@ -0,0 +1,36 @@
1
+ """mat I/O related functions."""
2
+
3
+ import os
4
+ from typing import Any, Dict
5
+
6
+ import numpy.typing as npt
7
+ import scipy.io
8
+
9
+ from eva.vision.utils.io import _utils
10
+
11
+
12
def read_mat(path: str) -> Dict[str, npt.NDArray[Any]]:
    """Loads the variables stored in a mat file.

    Args:
        path: The path to the mat file.

    Returns:
        The dictionary produced by `scipy.io.loadmat`, keyed by variable name.
        NOTE(review): loadmat may also include metadata entries (e.g.
        `__header__`) that are not arrays — confirm against callers.

    Raises:
        FileExistsError: If the path does not exist or it is unreachable.
    """
    _utils.check_file(path)
    mat_contents = scipy.io.loadmat(path)
    return mat_contents
26
+
27
+
28
def save_mat(path: str, data: Dict[str, npt.NDArray[Any]]) -> None:
    """Saves a dictionary of arrays as a mat file.

    Missing parent directories of `path` are created first.

    Args:
        path: The path to save the mat file.
        data: The dictionary containing the data to save.
    """
    directory = os.path.dirname(path)
    # A bare filename has no directory part and `os.makedirs("")` raises,
    # so only create directories when there is one to create.
    if directory:
        os.makedirs(directory, exist_ok=True)
    scipy.io.savemat(path, data)
@@ -1,23 +1,27 @@
1
1
  """NIfTI I/O related functions."""
2
2
 
3
- from typing import Any
3
+ from typing import Any, Tuple
4
4
 
5
5
  import nibabel as nib
6
+ import numpy as np
6
7
  import numpy.typing as npt
7
8
 
8
9
  from eva.vision.utils.io import _utils
9
10
 
10
11
 
11
def read_nifti(
    path: str, slice_index: int | None = None, *, use_storage_dtype: bool = True
) -> npt.NDArray[Any]:
    """Reads and loads a NIfTI image from a file path.

    Args:
        path: The path to the NIfTI file.
        slice_index: Index along the third axis of the single slice to read.
            Negative values count from the end of that axis. If `None`, the
            full image is read.
        use_storage_dtype: Whether to cast the array returned by `get_fdata`
            to the data type the image is stored with (`get_data_dtype`).

    Returns:
        The image as a numpy array. When `slice_index` is given, the third
        axis is kept with size 1.
        NOTE(review): the axis order follows the file's data layout —
        confirm it matches the (height, width, ...) layout callers expect.

    Raises:
        FileExistsError: If the path does not exist or it is unreachable.
        Errors raised by `nib.load` for unreadable files propagate unchanged.
    """
    _utils.check_file(path)
    image_data = nib.load(path)  # type: ignore
    if slice_index is not None:
        if slice_index < 0:
            # `slicer[:, :, -1:0]` would be an empty selection, so resolve
            # negative indices against the third-axis length first.
            slice_index += image_data.shape[2]  # type: ignore
        image_data = image_data.slicer[:, :, slice_index : slice_index + 1]  # type: ignore

    image_array = image_data.get_fdata()  # type: ignore
    if use_storage_dtype:
        image_array = image_array.astype(image_data.get_data_dtype())  # type: ignore

    return image_array
40
+
41
+
42
def save_array_as_nifti(
    array: npt.ArrayLike,
    filename: str,
    *,
    dtype: npt.DTypeLike | None = np.int64,
) -> None:
    """Saves a numpy array as a NIfTI image file.

    The image is written with an identity affine (`np.eye(4)`).

    Args:
        array: The image array to save.
        filename: The file path to write the image to.
        dtype: The data type to store the image with.
    """
    nifti_image = nib.Nifti1Image(array, affine=np.eye(4), dtype=dtype)  # type: ignore
    # The previous `header.get_xyzt_units()` call was a pure getter whose
    # result was discarded, so it has been dropped.
    nifti_image.to_filename(filename)
58
+
59
+
60
def fetch_nifti_shape(path: str) -> Tuple[int, ...]:
    """Fetches the NIfTI image shape from a file.

    The shape is taken from the image header (`header.get_data_shape()`).

    Args:
        path: The path to the NIfTI file.

    Returns:
        The image shape, one entry per data dimension.

    Raises:
        FileExistsError: If the path does not exist or it is unreachable.
        Errors raised by `nib.load` for unreadable files propagate unchanged.
    """
    _utils.check_file(path)
    image = nib.load(path)  # type: ignore
    return image.header.get_data_shape()  # type: ignore
@@ -4,15 +4,22 @@ import csv
4
4
  from typing import Dict, List
5
5
 
6
6
 
7
def read_csv(
    path: str,
    *,
    delimiter: str = ",",
    encoding: str = "utf-8",
) -> List[Dict[str, str]]:
    """Parses a CSV file into one dictionary per data row.

    The first row provides the keys. Whitespace directly after a delimiter
    is ignored (`skipinitialspace=True`).

    Args:
        path: The path to the CSV file.
        delimiter: The character that separates fields in the CSV file.
        encoding: The encoding of the CSV file.

    Returns:
        A list of dictionaries representing the data in the CSV file.
    """
    with open(path, newline="", encoding=encoding) as csv_file:
        reader = csv.DictReader(csv_file, delimiter=delimiter, skipinitialspace=True)
        rows = [row for row in reader]
    return rows