kaiko-eva 0.1.8__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. eva/core/data/datasets/base.py +7 -2
  2. eva/core/data/datasets/classification/embeddings.py +2 -2
  3. eva/core/data/datasets/classification/multi_embeddings.py +2 -2
  4. eva/core/data/datasets/embeddings.py +4 -4
  5. eva/core/data/samplers/classification/balanced.py +19 -18
  6. eva/core/loggers/utils/wandb.py +33 -0
  7. eva/core/models/modules/head.py +5 -3
  8. eva/core/models/modules/typings.py +2 -2
  9. eva/core/models/transforms/__init__.py +2 -1
  10. eva/core/models/transforms/as_discrete.py +57 -0
  11. eva/core/models/wrappers/_utils.py +121 -1
  12. eva/core/trainers/functional.py +8 -5
  13. eva/core/trainers/trainer.py +32 -17
  14. eva/core/utils/suppress_logs.py +28 -0
  15. eva/vision/data/__init__.py +2 -2
  16. eva/vision/data/dataloaders/__init__.py +5 -0
  17. eva/vision/data/dataloaders/collate_fn/__init__.py +5 -0
  18. eva/vision/data/dataloaders/collate_fn/collection.py +22 -0
  19. eva/vision/data/datasets/__init__.py +10 -2
  20. eva/vision/data/datasets/classification/__init__.py +9 -0
  21. eva/vision/data/datasets/classification/bach.py +3 -4
  22. eva/vision/data/datasets/classification/bracs.py +111 -0
  23. eva/vision/data/datasets/classification/breakhis.py +209 -0
  24. eva/vision/data/datasets/classification/camelyon16.py +4 -5
  25. eva/vision/data/datasets/classification/crc.py +3 -4
  26. eva/vision/data/datasets/classification/gleason_arvaniti.py +171 -0
  27. eva/vision/data/datasets/classification/mhist.py +3 -4
  28. eva/vision/data/datasets/classification/panda.py +4 -5
  29. eva/vision/data/datasets/classification/patch_camelyon.py +3 -4
  30. eva/vision/data/datasets/classification/unitopatho.py +158 -0
  31. eva/vision/data/datasets/classification/wsi.py +6 -5
  32. eva/vision/data/datasets/segmentation/__init__.py +2 -2
  33. eva/vision/data/datasets/segmentation/_utils.py +47 -0
  34. eva/vision/data/datasets/segmentation/bcss.py +7 -8
  35. eva/vision/data/datasets/segmentation/btcv.py +236 -0
  36. eva/vision/data/datasets/segmentation/consep.py +6 -7
  37. eva/vision/data/datasets/segmentation/embeddings.py +2 -2
  38. eva/vision/data/datasets/segmentation/lits.py +9 -8
  39. eva/vision/data/datasets/segmentation/lits_balanced.py +2 -1
  40. eva/vision/data/datasets/segmentation/monusac.py +4 -5
  41. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +12 -10
  42. eva/vision/data/datasets/vision.py +95 -4
  43. eva/vision/data/datasets/wsi.py +5 -5
  44. eva/vision/data/transforms/__init__.py +22 -3
  45. eva/vision/data/transforms/common/__init__.py +1 -2
  46. eva/vision/data/transforms/croppad/__init__.py +11 -0
  47. eva/vision/data/transforms/croppad/crop_foreground.py +110 -0
  48. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +109 -0
  49. eva/vision/data/transforms/croppad/spatial_pad.py +67 -0
  50. eva/vision/data/transforms/intensity/__init__.py +11 -0
  51. eva/vision/data/transforms/intensity/rand_scale_intensity.py +59 -0
  52. eva/vision/data/transforms/intensity/rand_shift_intensity.py +55 -0
  53. eva/vision/data/transforms/intensity/scale_intensity_ranged.py +56 -0
  54. eva/vision/data/transforms/spatial/__init__.py +7 -0
  55. eva/vision/data/transforms/spatial/flip.py +72 -0
  56. eva/vision/data/transforms/spatial/rotate.py +53 -0
  57. eva/vision/data/transforms/spatial/spacing.py +69 -0
  58. eva/vision/data/transforms/utility/__init__.py +5 -0
  59. eva/vision/data/transforms/utility/ensure_channel_first.py +51 -0
  60. eva/vision/data/tv_tensors/__init__.py +5 -0
  61. eva/vision/data/tv_tensors/volume.py +61 -0
  62. eva/vision/metrics/segmentation/monai_dice.py +9 -2
  63. eva/vision/models/modules/semantic_segmentation.py +28 -20
  64. eva/vision/models/networks/backbones/__init__.py +9 -2
  65. eva/vision/models/networks/backbones/pathology/__init__.py +11 -2
  66. eva/vision/models/networks/backbones/pathology/bioptimus.py +47 -1
  67. eva/vision/models/networks/backbones/pathology/hkust.py +69 -0
  68. eva/vision/models/networks/backbones/pathology/kaiko.py +18 -0
  69. eva/vision/models/networks/backbones/pathology/mahmood.py +46 -19
  70. eva/vision/models/networks/backbones/radiology/__init__.py +11 -0
  71. eva/vision/models/networks/backbones/radiology/swin_unetr.py +231 -0
  72. eva/vision/models/networks/backbones/radiology/voco.py +75 -0
  73. eva/vision/models/networks/decoders/segmentation/__init__.py +6 -2
  74. eva/vision/models/networks/decoders/segmentation/linear.py +5 -10
  75. eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +8 -1
  76. eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py +104 -0
  77. eva/vision/utils/io/__init__.py +2 -0
  78. eva/vision/utils/io/nifti.py +91 -11
  79. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/METADATA +3 -1
  80. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/RECORD +83 -62
  81. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/WHEEL +1 -1
  82. eva/vision/data/datasets/classification/base.py +0 -96
  83. eva/vision/data/datasets/segmentation/base.py +0 -96
  84. eva/vision/data/transforms/common/resize_and_clamp.py +0 -51
  85. eva/vision/data/transforms/normalization/__init__.py +0 -6
  86. eva/vision/data/transforms/normalization/clamp.py +0 -43
  87. eva/vision/data/transforms/normalization/functional/__init__.py +0 -5
  88. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +0 -28
  89. eva/vision/data/transforms/normalization/rescale_intensity.py +0 -53
  90. eva/vision/metrics/segmentation/BUILD +0 -1
  91. eva/vision/models/networks/backbones/torchhub/__init__.py +0 -5
  92. eva/vision/models/networks/backbones/torchhub/backbones.py +0 -61
  93. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/entry_points.txt +0 -0
  94. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/licenses/LICENSE +0 -0
eva/core/data/datasets/base.py
@@ -1,6 +1,7 @@
  """Base dataset class."""

  import abc
+ from typing import Generic, TypeVar

  from eva.core.data.datasets import dataset

@@ -55,11 +56,15 @@ class Dataset(dataset.TorchDataset):
  """


- class MapDataset(Dataset):
+ DataSample = TypeVar("DataSample")
+ """The data sample type."""
+
+
+ class MapDataset(Dataset, abc.ABC, Generic[DataSample]):
  """Abstract base class for all map-style datasets."""

  @abc.abstractmethod
- def __getitem__(self, index: int):
+ def __getitem__(self, index: int) -> DataSample:
  """Retrieves the item at the given index.

  Args:
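
Note: with MapDataset now generic, subclasses can declare the type of sample they return. A minimal illustrative sketch (the subclass below is hypothetical and assumes no other abstract members need overriding):

    from typing import Tuple

    import torch

    from eva.core.data.datasets.base import MapDataset


    class ToyPairs(MapDataset[Tuple[torch.Tensor, torch.Tensor]]):
        """Hypothetical dataset whose samples are (embedding, target) pairs."""

        def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
            # The DataSample type parameter now documents the __getitem__ return type.
            return torch.randn(16), torch.tensor(index % 2)

        def __len__(self) -> int:
            return 8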
eva/core/data/datasets/classification/embeddings.py
@@ -12,7 +12,7 @@ class EmbeddingsClassificationDataset(embeddings_base.EmbeddingsDataset[torch.Te
  """Embeddings dataset class for classification tasks."""

  @override
- def _load_embeddings(self, index: int) -> torch.Tensor:
+ def load_embeddings(self, index: int) -> torch.Tensor:
  filename = self.filename(index)
  embeddings_path = os.path.join(self._root, filename)
  tensor = torch.load(embeddings_path, map_location="cpu")
@@ -25,7 +25,7 @@ class EmbeddingsClassificationDataset(embeddings_base.EmbeddingsDataset[torch.Te
  return tensor.squeeze(0)

  @override
- def _load_target(self, index: int) -> torch.Tensor:
+ def load_target(self, index: int) -> torch.Tensor:
  target = self._data.at[index, self._column_mapping["target"]]
  return torch.tensor(target, dtype=torch.int64)

eva/core/data/datasets/classification/multi_embeddings.py
@@ -66,7 +66,7 @@ class MultiEmbeddingsClassificationDataset(embeddings_base.EmbeddingsDataset[tor
  self._multi_ids = list(self._data[self._column_mapping["multi_id"]].unique())

  @override
- def _load_embeddings(self, index: int) -> torch.Tensor:
+ def load_embeddings(self, index: int) -> torch.Tensor:
  """Loads and stacks all embedding corresponding to the `index`'th multi_id."""
  # Get all embeddings for the given index (multi_id)
  multi_id = self._multi_ids[index]
@@ -89,7 +89,7 @@ class MultiEmbeddingsClassificationDataset(embeddings_base.EmbeddingsDataset[tor
  return embeddings

  @override
- def _load_target(self, index: int) -> np.ndarray:
+ def load_target(self, index: int) -> np.ndarray:
  """Returns the target corresponding to the `index`'th multi_id.

  This method assumes that all the embeddings corresponding to the same `multi_id`
eva/core/data/datasets/embeddings.py
@@ -98,12 +98,12 @@ class EmbeddingsDataset(base.Dataset, Generic[TargetType]):
  Returns:
  A data sample and its target.
  """
- embeddings = self._load_embeddings(index)
- target = self._load_target(index)
+ embeddings = self.load_embeddings(index)
+ target = self.load_target(index)
  return self._apply_transforms(embeddings, target)

  @abc.abstractmethod
- def _load_embeddings(self, index: int) -> torch.Tensor:
+ def load_embeddings(self, index: int) -> torch.Tensor:
  """Returns the `index`'th embedding sample.

  Args:
@@ -114,7 +114,7 @@ class EmbeddingsDataset(base.Dataset, Generic[TargetType]):
  """

  @abc.abstractmethod
- def _load_target(self, index: int) -> TargetType:
+ def load_target(self, index: int) -> TargetType:
  """Returns the `index`'th target sample.

  Args:
eva/core/data/samplers/classification/balanced.py
@@ -4,6 +4,7 @@ from collections import defaultdict
  from typing import Dict, Iterator, List

  import numpy as np
+ from loguru import logger
  from typing_extensions import override

  from eva.core.data import datasets
@@ -33,6 +34,7 @@ class BalancedSampler(SamplerWithDataSource[int]):
  self._replacement = replacement
  self._class_indices: Dict[int, List[int]] = defaultdict(list)
  self._random_generator = np.random.default_rng(seed)
+ self._indices: List[int] = []

  def __len__(self) -> int:
  """Returns the total number of samples."""
@@ -44,18 +46,7 @@ class BalancedSampler(SamplerWithDataSource[int]):
  Returns:
  Iterator yielding dataset indices.
  """
- indices = []
-
- for class_idx in self._class_indices:
- class_indices = self._class_indices[class_idx]
- sampled_indices = self._random_generator.choice(
- class_indices, size=self._num_samples, replace=self._replacement
- ).tolist()
- indices.extend(sampled_indices)
-
- self._random_generator.shuffle(indices)
-
- return iter(indices)
+ return iter(self._indices)

  @override
  def set_dataset(self, data_source: datasets.MapDataset):
@@ -72,13 +63,13 @@ class BalancedSampler(SamplerWithDataSource[int]):
  self._make_indices()

  def _make_indices(self):
- """Builds indices for each class in the dataset."""
+ """Samples the indices for each class in the dataset."""
  self._class_indices.clear()
-
- for idx in tqdm(
- range(len(self.data_source)), desc="Fetching class indices for balanced sampler"
- ):
- _, target, _ = DataSample(*self.data_source[idx])
+ for idx in tqdm(range(len(self.data_source)), desc="Fetching class indices for sampler"):
+ if hasattr(self.data_source, "load_target"):
+ target = self.data_source.load_target(idx) # type: ignore
+ else:
+ _, target, _ = DataSample(*self.data_source[idx])
  if target is None:
  raise ValueError("The dataset must return non-empty targets.")
  if target.numel() != 1:
@@ -94,3 +85,13 @@ class BalancedSampler(SamplerWithDataSource[int]):
  f"Class {class_idx} has only {len(indices)} samples, "
  f"which is less than the required {self._num_samples} samples."
  )
+
+ self._indices = []
+ for class_idx in self._class_indices:
+ class_indices = self._class_indices[class_idx]
+ sampled_indices = self._random_generator.choice(
+ class_indices, size=self._num_samples, replace=self._replacement
+ ).tolist()
+ self._indices.extend(sampled_indices)
+ self._random_generator.shuffle(self._indices)
+ logger.debug(f"Sampled indices: {self._indices}")
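
Note: the sampler now prefers the public load_target hook (introduced by the renames above) when the dataset provides it, so class indices can be built without loading embeddings. A rough sketch of that interaction, with a hypothetical dataset and assumed constructor arguments:

    import torch

    from eva.core.data.samplers.classification.balanced import BalancedSampler


    class TinyDataset:
        """Illustrative dataset exposing the public load_target hook."""

        _targets = [0, 0, 0, 1, 1, 1]

        def __len__(self) -> int:
            return len(self._targets)

        def load_target(self, index: int) -> torch.Tensor:
            # Called by the sampler; avoids materializing the full sample.
            return torch.tensor(self._targets[index])

        def __getitem__(self, index: int):
            return torch.randn(16), self.load_target(index), {}


    sampler = BalancedSampler(num_samples=2, seed=42)  # assumed signature
    sampler.set_dataset(TinyDataset())
    print(list(sampler))  # four pre-sampled indices, two per class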
eva/core/loggers/utils/wandb.py (new file)
@@ -0,0 +1,33 @@
+ # type: ignore
+ """Utility functions for logging with Weights & Biases."""
+
+ from typing import Any, Dict
+
+ from loguru import logger
+
+
+ def rename_active_run(name: str) -> None:
+ """Renames the current run."""
+ import wandb
+
+ if wandb.run:
+ wandb.run.name = name
+ wandb.run.save()
+ else:
+ logger.warning("No active wandb run found that could be renamed.")
+
+
+ def init_run(name: str, init_kwargs: Dict[str, Any]) -> None:
+ """Initializes a new run. If there is an active run, it will be renamed and reused."""
+ import wandb
+
+ init_kwargs["name"] = name
+ rename_active_run(name)
+ wandb.init(**init_kwargs)
+
+
+ def finish_run() -> None:
+ """Finish the current run."""
+ import wandb
+
+ wandb.finish()
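
Note: a brief usage sketch of these helpers (project and run names are illustrative). wandb is imported lazily inside each function, so the dependency is only required when W&B logging is actually used:

    import wandb

    from eva.core.loggers.utils import wandb as wandb_utils

    wandb_utils.init_run("fold_0", init_kwargs={"project": "eva-demo"})
    wandb.log({"val/accuracy": 0.9})  # logs to the run created (or renamed) above
    wandb_utils.finish_run()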
eva/core/models/modules/head.py
@@ -1,6 +1,6 @@
- """"Neural Network Head Module."""
+ """Neural Network Head Module."""

- from typing import Any, Callable, Dict
+ from typing import Any, Callable, Dict, List

  import torch
  from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
@@ -108,7 +108,9 @@ class HeadModule(module.ModelModule):
  return self._batch_step(batch)

  @override
- def predict_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> torch.Tensor:
+ def predict_step(
+ self, batch: INPUT_BATCH, *args: Any, **kwargs: Any
+ ) -> torch.Tensor | List[torch.Tensor]:
  tensor = INPUT_BATCH(*batch).data
  return tensor if self.backbone is None else self.backbone(tensor)

eva/core/models/modules/typings.py
@@ -1,6 +1,6 @@
  """Type annotations for model modules."""

- from typing import Any, Dict, NamedTuple
+ from typing import Any, Dict, List, NamedTuple

  import lightning.pytorch as pl
  import torch
@@ -13,7 +13,7 @@ MODEL_TYPE = nn.Module | pl.LightningModule
  class INPUT_BATCH(NamedTuple):
  """The default input batch data scheme."""

- data: torch.Tensor
+ data: torch.Tensor | List[torch.Tensor]
  """The data batch."""

  targets: torch.Tensor | None = None
eva/core/models/transforms/__init__.py
@@ -1,6 +1,7 @@
  """Model outputs transforms API."""

+ from eva.core.models.transforms.as_discrete import AsDiscrete
  from eva.core.models.transforms.extract_cls_features import ExtractCLSFeatures
  from eva.core.models.transforms.extract_patch_features import ExtractPatchFeatures

- __all__ = ["ExtractCLSFeatures", "ExtractPatchFeatures"]
+ __all__ = ["AsDiscrete", "ExtractCLSFeatures", "ExtractPatchFeatures"]
eva/core/models/transforms/as_discrete.py (new file)
@@ -0,0 +1,57 @@
+ """Defines the AsDiscrete transformation."""
+
+ import torch
+
+
+ class AsDiscrete:
+ """Convert the logits tensor to discrete values."""
+
+ def __init__(
+ self,
+ argmax: bool = False,
+ to_onehot: int | bool | None = None,
+ threshold: float | None = None,
+ ) -> None:
+ """Convert the input tensor/array into discrete values.
+
+ Args:
+ argmax: Whether to execute argmax function on input data before transform.
+ to_onehot: if not None, convert input data into the one-hot format with
+ specified number of classes. If bool, it will try to infer the number
+ of classes.
+ threshold: If not None, threshold the float values to int number 0 or 1
+ with specified threshold.
+ """
+ super().__init__()
+
+ self._argmax = argmax
+ self._to_onehot = to_onehot
+ self._threshold = threshold
+
+ def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
+ """Call method for the transformation."""
+ if self._argmax:
+ tensor = torch.argmax(tensor, dim=1, keepdim=True)
+
+ if self._to_onehot is not None:
+ tensor = _one_hot(tensor, num_classes=self._to_onehot, dim=1, dtype=torch.long)
+
+ if self._threshold is not None:
+ tensor = tensor >= self._threshold
+
+ return tensor
+
+
+ def _one_hot(
+ tensor: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1
+ ) -> torch.Tensor:
+ """Convert input tensor into one-hot format (implementation taken from MONAI)."""
+ shape = list(tensor.shape)
+ if shape[dim] != 1:
+ raise AssertionError(f"Input tensor must have 1 channel at dim {dim}.")
+
+ shape[dim] = num_classes
+ o = torch.zeros(size=shape, dtype=dtype, device=tensor.device)
+ tensor = o.scatter_(dim=dim, index=tensor.long(), value=1)
+
+ return tensor
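
Note: a small usage sketch of the new transform on segmentation-style logits (shapes are illustrative):

    import torch

    from eva.core.models.transforms import AsDiscrete

    logits = torch.randn(2, 4, 8, 8)  # (batch, num_classes, height, width)
    to_discrete = AsDiscrete(argmax=True, to_onehot=4)
    one_hot = to_discrete(logits)
    print(one_hot.shape, one_hot.dtype)  # torch.Size([2, 4, 8, 8]) torch.int64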
eva/core/models/wrappers/_utils.py
@@ -1,8 +1,17 @@
  """Utilities and helper functions for models."""

+ import hashlib
+ import os
+ import sys
+ from typing import Any, Dict
+
+ import torch
+ from fsspec.core import url_to_fs
  from lightning_fabric.utilities import cloud_io
  from loguru import logger
- from torch import nn
+ from torch import hub, nn
+
+ from eva.core.utils.progress_bar import tqdm


  def load_model_weights(model: nn.Module, checkpoint_path: str) -> None:
@@ -23,3 +32,114 @@ def load_model_weights(model: nn.Module, checkpoint_path: str) -> None:
  model.load_state_dict(checkpoint, strict=True)

  logger.info(f"Loading weights from '{checkpoint_path}' completed successfully.")
+
+
+ def load_state_dict_from_url(
+ url: str,
+ *,
+ model_dir: str | None = None,
+ filename: str | None = None,
+ progress: bool = True,
+ md5: str | None = None,
+ force: bool = False,
+ ) -> Dict[str, Any]:
+ """Loads the Torch serialized object at the given URL.
+
+ If the object is already present and valid in `model_dir`, it's
+ deserialized and returned.
+
+ The default value of ``model_dir`` is ``<hub_dir>/checkpoints`` where
+ ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`.
+
+ Args:
+ url: URL of the object to download.
+ model_dir: Directory in which to save the object.
+ filename: Name for the downloaded file. Filename from ``url`` will be used if not set.
+ progress: Whether or not to display a progress bar to stderr.
+ md5: MD5 file code to check whether the file is valid. If not, it will re-download it.
+ force: Whether to download the file regardless if it exists.
+ """
+ model_dir = model_dir or os.path.join(hub.get_dir(), "checkpoints")
+ os.makedirs(model_dir, exist_ok=True)
+
+ cached_file = os.path.join(model_dir, filename or os.path.basename(url))
+ if force or not os.path.exists(cached_file) or not _check_integrity(cached_file, md5):
+ sys.stderr.write(f"Downloading: '{url}' to {cached_file}\n")
+ _download_url_to_file(url, cached_file, progress=progress)
+ if md5 is None or not _check_integrity(cached_file, md5):
+ sys.stderr.write(f"File MD5: {_calculate_md5(cached_file)}\n")
+
+ return torch.load(cached_file, map_location="cpu")
+
+
+ def _download_url_to_file(
+ url: str,
+ dst: str,
+ *,
+ progress: bool = True,
+ ) -> None:
+ """Download object at the given URL to a local path.
+
+ Args:
+ url: URL of the object to download.
+ dst: Full path where object will be saved.
+ chunk_size: The size of each chunk to read in bytes.
+ progress: Whether or not to display a progress bar to stderr.
+ """
+ try:
+ _download_with_fsspec(url=url, dst=dst, progress=progress)
+ except Exception:
+ try:
+ hub.download_url_to_file(url=url, dst=dst, progress=progress)
+ except Exception as hub_e:
+ raise RuntimeError(
+ f"Failed to download file from {url} using both fsspec and hub."
+ ) from hub_e
+
+
+ def _download_with_fsspec(
+ url: str,
+ dst: str,
+ *,
+ chunk_size: int = 1024 * 1024,
+ progress: bool = True,
+ ) -> None:
+ """Download object at the given URL to a local path using fsspec.
+
+ Args:
+ url: URL of the object to download.
+ dst: Full path where object will be saved.
+ chunk_size: The size of each chunk to read in bytes.
+ progress: Whether or not to display a progress bar to stderr.
+ """
+ filesystem, _ = url_to_fs(url, anon=False)
+ total_size_bytes = filesystem.size(url)
+ with (
+ filesystem.open(url, "rb") as remote_file,
+ tqdm(
+ total=total_size_bytes,
+ unit="iB",
+ unit_scale=True,
+ unit_divisor=1024,
+ disable=not progress,
+ ) as pbar,
+ ):
+ with open(dst, "wb") as local_file:
+ while True:
+ data = remote_file.read(chunk_size)
+ if not data:
+ break
+
+ local_file.write(data)
+ pbar.update(chunk_size)
+
+
+ def _calculate_md5(path: str) -> str:
+ """Calculate the md5 hash of a file."""
+ with open(path, "rb") as file:
+ return hashlib.md5(file.read(), usedforsecurity=False).hexdigest()
+
+
+ def _check_integrity(path: str, md5: str | None) -> bool:
+ """Check if the file matches the specified md5 hash."""
+ return (md5 is None) or (md5 == _calculate_md5(path))
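
Note: an illustrative call to the new helper (the URL and cache directory are hypothetical). Because the download first goes through fsspec and only falls back to torch.hub, non-HTTP locations such as s3:// URLs can also be used when a matching fsspec backend is installed:

    from eva.core.models.wrappers import _utils

    state_dict = _utils.load_state_dict_from_url(
        "https://example.com/checkpoints/backbone.pth",  # hypothetical checkpoint URL
        model_dir="/tmp/eva-checkpoints",
        md5=None,  # pass an MD5 hex digest to enable the integrity check
    )
    print(sorted(state_dict)[:5])  # a few parameter names from the loaded state dict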
eva/core/trainers/functional.py
@@ -39,7 +39,7 @@ def run_evaluation_session(
  base_trainer,
  base_model,
  datamodule,
- run_id=f"run_{run_index}",
+ run_id=run_index,
  verbose=not verbose,
  )
  recorder.update(validation_scores, test_scores)
@@ -51,7 +51,7 @@ def run_evaluation(
  base_model: modules.ModelModule,
  datamodule: datamodules.DataModule,
  *,
- run_id: str | None = None,
+ run_id: int | None = None,
  verbose: bool = True,
  ) -> Tuple[_EVALUATE_OUTPUT, _EVALUATE_OUTPUT | None]:
  """Fits and evaluates a model out-of-place.
@@ -61,7 +61,6 @@ def run_evaluation(
  base_model: The model module to use but not modify.
  datamodule: The data module.
  run_id: The run id to be appended to the output log directory.
- If `None`, it will use the log directory of the trainer as is.
  verbose: Whether to print the validation and test metrics
  in the end of the training.

@@ -70,8 +69,12 @@ def run_evaluation(
  """
  trainer, model = _utils.clone(base_trainer, base_model)
  model.configure_model()
- trainer.setup_log_dirs(run_id or "")
- return fit_and_validate(trainer, model, datamodule, verbose=verbose)
+
+ trainer.init_logger_run(run_id)
+ results = fit_and_validate(trainer, model, datamodule, verbose=verbose)
+ trainer.finish_logger_run(run_id)
+
+ return results


  def fit_and_validate(
eva/core/trainers/trainer.py
@@ -12,6 +12,7 @@ from typing_extensions import override

  from eva.core import loggers as eva_loggers
  from eva.core.data import datamodules
+ from eva.core.loggers.utils import wandb as wandb_utils
  from eva.core.models import modules
  from eva.core.trainers import _logging, functional

@@ -53,7 +54,7 @@ class Trainer(pl_trainer.Trainer):
  self._session_id: str = _logging.generate_session_id()
  self._log_dir: str = self.default_log_dir

- self.setup_log_dirs()
+ self.init_logger_run(0)

  @property
  def default_log_dir(self) -> str:
@@ -65,31 +66,45 @@ class Trainer(pl_trainer.Trainer):
  def log_dir(self) -> str | None:
  return self.strategy.broadcast(self._log_dir)

- def setup_log_dirs(self, subdirectory: str = "") -> None:
- """Setups the logging directory of the trainer and experimental loggers in-place.
+ def init_logger_run(self, run_id: int | None) -> None:
+ """Setup the loggers & log directories when starting a new run.

  Args:
- subdirectory: Whether to append a subdirectory to the output log.
+ run_id: The id of the current run.
  """
+ subdirectory = f"run_{run_id}" if run_id is not None else ""
  self._log_dir = os.path.join(self.default_root_dir, self._session_id, subdirectory)

  enabled_loggers = []
- if isinstance(self.loggers, list) and len(self.loggers) > 0:
- for logger in self.loggers:
- if isinstance(logger, (pl_loggers.CSVLogger, pl_loggers.TensorBoardLogger)):
- if not cloud_io._is_local_file_protocol(self.default_root_dir):
- loguru.logger.warning(
- f"Skipped {type(logger).__name__} as remote storage is not supported."
- )
- continue
- else:
- logger._root_dir = self.default_root_dir
- logger._name = self._session_id
- logger._version = subdirectory
- enabled_loggers.append(logger)
+ for logger in self.loggers or []:
+ if isinstance(logger, (pl_loggers.CSVLogger, pl_loggers.TensorBoardLogger)):
+ if not cloud_io._is_local_file_protocol(self.default_root_dir):
+ loguru.logger.warning(
+ f"Skipped {type(logger).__name__} as remote storage is not supported."
+ )
+ continue
+ else:
+ logger._root_dir = self.default_root_dir
+ logger._name = self._session_id
+ logger._version = subdirectory
+ elif isinstance(logger, pl_loggers.WandbLogger):
+ task_name = self.default_root_dir.split("/")[-1]
+ run_name = os.getenv("WANDB_RUN_NAME", f"{task_name}_{self._session_id}")
+ wandb_utils.init_run(f"{run_name}_{run_id}", logger._wandb_init)
+ enabled_loggers.append(logger)

  self._loggers = enabled_loggers or [eva_loggers.DummyLogger(self._log_dir)]

+ def finish_logger_run(self, run_id: int | None) -> None:
+ """Finish the current run in the enabled loggers.
+
+ Args:
+ run_id: The id of the current run.
+ """
+ for logger in self.loggers or []:
+ if isinstance(logger, pl_loggers.WandbLogger):
+ wandb_utils.finish_run()
+
  def run_evaluation_session(
  self,
  model: modules.ModelModule,
eva/core/utils/suppress_logs.py (new file)
@@ -0,0 +1,28 @@
+ """Context manager to temporarily suppress all logging outputs."""
+
+ import logging
+ import sys
+ from types import TracebackType
+ from typing import Type
+
+
+ class SuppressLogs:
+ """Context manager to suppress all logs but print exceptions if they occur."""
+
+ def __enter__(self) -> None:
+ """Temporarily increase log level to suppress all logs."""
+ self._logger = logging.getLogger()
+ self._previous_level = self._logger.level
+ self._logger.setLevel(logging.CRITICAL + 1)
+
+ def __exit__(
+ self,
+ exc_type: Type[BaseException] | None,
+ exc_value: BaseException | None,
+ traceback: TracebackType | None,
+ ) -> bool:
+ """Restores the previous logging level and print exceptions."""
+ self._logger.setLevel(self._previous_level)
+ if exc_value:
+ print(f"Error: {exc_value}", file=sys.stderr)
+ return False
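
Note: a brief usage sketch. Inside the context the root logger level is raised above CRITICAL, so standard-library logging calls are silenced, and the previous level is restored on exit:

    import logging

    from eva.core.utils.suppress_logs import SuppressLogs

    logging.basicConfig(level=logging.INFO)
    with SuppressLogs():
        logging.getLogger("noisy.dependency").warning("not shown")
    logging.getLogger(__name__).info("logging is active again")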
eva/vision/data/__init__.py
@@ -1,5 +1,5 @@
  """Vision data API."""

- from eva.vision.data import datasets, transforms
+ from eva.vision.data import datasets, transforms, tv_tensors

- __all__ = ["datasets", "transforms"]
+ __all__ = ["datasets", "transforms", "tv_tensors"]
eva/vision/data/dataloaders/__init__.py (new file)
@@ -0,0 +1,5 @@
+ """Dataloader related utilities and functions."""
+
+ from eva.vision.data.dataloaders import collate_fn
+
+ __all__ = ["collate_fn"]
eva/vision/data/dataloaders/collate_fn/__init__.py (new file)
@@ -0,0 +1,5 @@
+ """Dataloader collate API."""
+
+ from eva.vision.data.dataloaders.collate_fn.collection import collection_collate
+
+ __all__ = ["collection_collate"]
eva/vision/data/dataloaders/collate_fn/collection.py (new file)
@@ -0,0 +1,22 @@
+ """Data only collate filter function."""
+
+ from typing import Any, List
+
+ import torch
+
+ from eva.core.models.modules.typings import INPUT_BATCH
+
+
+ def collection_collate(batch: List[List[INPUT_BATCH]]) -> Any:
+ """Collate function for stacking a collection of data samples.
+
+ Args:
+ batch: The batch to be collated.
+
+ Returns:
+ The collated batch.
+ """
+ tensors, targets, metadata = zip(*batch, strict=False)
+ batch_tensors = torch.cat(list(map(torch.stack, tensors)))
+ batch_targets = torch.cat(list(map(torch.stack, targets)))
+ return batch_tensors, batch_targets, metadata
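
Note: a worked sketch of how the collate function stacks a collection-style batch. The per-item structure used here (a list of tensors, a list of targets, and a metadata object) is inferred from the implementation and is illustrative only:

    import torch

    from eva.vision.data.dataloaders.collate_fn import collection_collate

    item_a = (
        [torch.zeros(3, 8, 8), torch.ones(3, 8, 8)],  # two patches
        [torch.tensor(0), torch.tensor(1)],           # their targets
        {"slide": "a"},
    )
    item_b = ([torch.ones(3, 8, 8)], [torch.tensor(1)], {"slide": "b"})

    images, targets, metadata = collection_collate([item_a, item_b])
    print(images.shape)   # torch.Size([3, 3, 8, 8])
    print(targets.shape)  # torch.Size([3])
    print(metadata)       # ({'slide': 'a'}, {'slide': 'b'})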
eva/vision/data/datasets/__init__.py
@@ -2,19 +2,23 @@

  from eva.vision.data.datasets.classification import (
  BACH,
+ BRACS,
  CRC,
  MHIST,
  PANDA,
+ BreaKHis,
  Camelyon16,
+ GleasonArvaniti,
  PANDASmall,
  PatchCamelyon,
+ UniToPatho,
  WsiClassificationDataset,
  )
  from eva.vision.data.datasets.segmentation import (
  BCSS,
+ BTCV,
  CoNSeP,
  EmbeddingsSegmentationDataset,
- ImageSegmentation,
  LiTS,
  LiTSBalanced,
  MoNuSAC,
@@ -25,17 +29,21 @@ from eva.vision.data.datasets.wsi import MultiWsiDataset, WsiDataset

  __all__ = [
  "BACH",
+ "BTCV",
  "BCSS",
+ "BreaKHis",
+ "BRACS",
  "CRC",
+ "GleasonArvaniti",
  "MHIST",
  "PANDA",
  "PANDASmall",
  "Camelyon16",
  "PatchCamelyon",
+ "UniToPatho",
  "WsiClassificationDataset",
  "CoNSeP",
  "EmbeddingsSegmentationDataset",
- "ImageSegmentation",
  "LiTS",
  "LiTSBalanced",
  "MoNuSAC",