kaiko-eva 0.1.8__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eva/core/data/datasets/base.py +7 -2
- eva/core/data/datasets/classification/embeddings.py +2 -2
- eva/core/data/datasets/classification/multi_embeddings.py +2 -2
- eva/core/data/datasets/embeddings.py +4 -4
- eva/core/data/samplers/classification/balanced.py +19 -18
- eva/core/loggers/utils/wandb.py +33 -0
- eva/core/models/modules/head.py +5 -3
- eva/core/models/modules/typings.py +2 -2
- eva/core/models/transforms/__init__.py +2 -1
- eva/core/models/transforms/as_discrete.py +57 -0
- eva/core/models/wrappers/_utils.py +121 -1
- eva/core/trainers/functional.py +8 -5
- eva/core/trainers/trainer.py +32 -17
- eva/core/utils/suppress_logs.py +28 -0
- eva/vision/data/__init__.py +2 -2
- eva/vision/data/dataloaders/__init__.py +5 -0
- eva/vision/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/vision/data/dataloaders/collate_fn/collection.py +22 -0
- eva/vision/data/datasets/__init__.py +10 -2
- eva/vision/data/datasets/classification/__init__.py +9 -0
- eva/vision/data/datasets/classification/bach.py +3 -4
- eva/vision/data/datasets/classification/bracs.py +111 -0
- eva/vision/data/datasets/classification/breakhis.py +209 -0
- eva/vision/data/datasets/classification/camelyon16.py +4 -5
- eva/vision/data/datasets/classification/crc.py +3 -4
- eva/vision/data/datasets/classification/gleason_arvaniti.py +171 -0
- eva/vision/data/datasets/classification/mhist.py +3 -4
- eva/vision/data/datasets/classification/panda.py +4 -5
- eva/vision/data/datasets/classification/patch_camelyon.py +3 -4
- eva/vision/data/datasets/classification/unitopatho.py +158 -0
- eva/vision/data/datasets/classification/wsi.py +6 -5
- eva/vision/data/datasets/segmentation/__init__.py +2 -2
- eva/vision/data/datasets/segmentation/_utils.py +47 -0
- eva/vision/data/datasets/segmentation/bcss.py +7 -8
- eva/vision/data/datasets/segmentation/btcv.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +6 -7
- eva/vision/data/datasets/segmentation/embeddings.py +2 -2
- eva/vision/data/datasets/segmentation/lits.py +9 -8
- eva/vision/data/datasets/segmentation/lits_balanced.py +2 -1
- eva/vision/data/datasets/segmentation/monusac.py +4 -5
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +12 -10
- eva/vision/data/datasets/vision.py +95 -4
- eva/vision/data/datasets/wsi.py +5 -5
- eva/vision/data/transforms/__init__.py +22 -3
- eva/vision/data/transforms/common/__init__.py +1 -2
- eva/vision/data/transforms/croppad/__init__.py +11 -0
- eva/vision/data/transforms/croppad/crop_foreground.py +110 -0
- eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +109 -0
- eva/vision/data/transforms/croppad/spatial_pad.py +67 -0
- eva/vision/data/transforms/intensity/__init__.py +11 -0
- eva/vision/data/transforms/intensity/rand_scale_intensity.py +59 -0
- eva/vision/data/transforms/intensity/rand_shift_intensity.py +55 -0
- eva/vision/data/transforms/intensity/scale_intensity_ranged.py +56 -0
- eva/vision/data/transforms/spatial/__init__.py +7 -0
- eva/vision/data/transforms/spatial/flip.py +72 -0
- eva/vision/data/transforms/spatial/rotate.py +53 -0
- eva/vision/data/transforms/spatial/spacing.py +69 -0
- eva/vision/data/transforms/utility/__init__.py +5 -0
- eva/vision/data/transforms/utility/ensure_channel_first.py +51 -0
- eva/vision/data/tv_tensors/__init__.py +5 -0
- eva/vision/data/tv_tensors/volume.py +61 -0
- eva/vision/metrics/segmentation/monai_dice.py +9 -2
- eva/vision/models/modules/semantic_segmentation.py +28 -20
- eva/vision/models/networks/backbones/__init__.py +9 -2
- eva/vision/models/networks/backbones/pathology/__init__.py +11 -2
- eva/vision/models/networks/backbones/pathology/bioptimus.py +47 -1
- eva/vision/models/networks/backbones/pathology/hkust.py +69 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +18 -0
- eva/vision/models/networks/backbones/pathology/mahmood.py +46 -19
- eva/vision/models/networks/backbones/radiology/__init__.py +11 -0
- eva/vision/models/networks/backbones/radiology/swin_unetr.py +231 -0
- eva/vision/models/networks/backbones/radiology/voco.py +75 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +6 -2
- eva/vision/models/networks/decoders/segmentation/linear.py +5 -10
- eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +8 -1
- eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py +104 -0
- eva/vision/utils/io/__init__.py +2 -0
- eva/vision/utils/io/nifti.py +91 -11
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/METADATA +3 -1
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/RECORD +83 -62
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/WHEEL +1 -1
- eva/vision/data/datasets/classification/base.py +0 -96
- eva/vision/data/datasets/segmentation/base.py +0 -96
- eva/vision/data/transforms/common/resize_and_clamp.py +0 -51
- eva/vision/data/transforms/normalization/__init__.py +0 -6
- eva/vision/data/transforms/normalization/clamp.py +0 -43
- eva/vision/data/transforms/normalization/functional/__init__.py +0 -5
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +0 -28
- eva/vision/data/transforms/normalization/rescale_intensity.py +0 -53
- eva/vision/metrics/segmentation/BUILD +0 -1
- eva/vision/models/networks/backbones/torchhub/__init__.py +0 -5
- eva/vision/models/networks/backbones/torchhub/backbones.py +0 -61
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/licenses/LICENSE +0 -0
eva/core/data/datasets/base.py
CHANGED
@@ -1,6 +1,7 @@
 """Base dataset class."""
 
 import abc
+from typing import Generic, TypeVar
 
 from eva.core.data.datasets import dataset
 
@@ -55,11 +56,15 @@ class Dataset(dataset.TorchDataset):
        """
 
 
-class MapDataset(Dataset, abc.ABC):
+DataSample = TypeVar("DataSample")
+"""The data sample type."""
+
+
+class MapDataset(Dataset, abc.ABC, Generic[DataSample]):
    """Abstract base class for all map-style datasets."""
 
    @abc.abstractmethod
-    def __getitem__(self, index: int):
+    def __getitem__(self, index: int) -> DataSample:
        """Retrieves the item at the given index.
 
        Args:
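The `Generic[DataSample]` parameter above lets subclasses pin down the sample type that `__getitem__` returns. A minimal sketch of a typed subclass (toy data; it assumes `__getitem__` and `__len__` are the only members a concrete map-style dataset must provide):

from typing import Tuple

import torch

from eva.core.data.datasets.base import MapDataset


class PairDataset(MapDataset[Tuple[torch.Tensor, torch.Tensor]]):
    """Toy map-style dataset whose samples are (data, target) pairs."""

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Deterministic dummy sample, so the example is reproducible.
        data = torch.full((4,), float(index))
        target = torch.tensor(index % 2)
        return data, target

    def __len__(self) -> int:
        return 8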
eva/core/data/datasets/classification/embeddings.py
CHANGED
@@ -12,7 +12,7 @@ class EmbeddingsClassificationDataset(embeddings_base.EmbeddingsDataset[torch.Te
    """Embeddings dataset class for classification tasks."""
 
    @override
-    def _load_embeddings(self, index: int) -> torch.Tensor:
+    def load_embeddings(self, index: int) -> torch.Tensor:
        filename = self.filename(index)
        embeddings_path = os.path.join(self._root, filename)
        tensor = torch.load(embeddings_path, map_location="cpu")
@@ -25,7 +25,7 @@ class EmbeddingsClassificationDataset(embeddings_base.EmbeddingsDataset[torch.Te
        return tensor.squeeze(0)
 
    @override
-    def _load_target(self, index: int) -> torch.Tensor:
+    def load_target(self, index: int) -> torch.Tensor:
        target = self._data.at[index, self._column_mapping["target"]]
        return torch.tensor(target, dtype=torch.int64)
 
eva/core/data/datasets/classification/multi_embeddings.py
CHANGED
@@ -66,7 +66,7 @@ class MultiEmbeddingsClassificationDataset(embeddings_base.EmbeddingsDataset[tor
        self._multi_ids = list(self._data[self._column_mapping["multi_id"]].unique())
 
    @override
-    def _load_embeddings(self, index: int) -> torch.Tensor:
+    def load_embeddings(self, index: int) -> torch.Tensor:
        """Loads and stacks all embedding corresponding to the `index`'th multi_id."""
        # Get all embeddings for the given index (multi_id)
        multi_id = self._multi_ids[index]
@@ -89,7 +89,7 @@ class MultiEmbeddingsClassificationDataset(embeddings_base.EmbeddingsDataset[tor
        return embeddings
 
    @override
-    def _load_target(self, index: int) -> np.ndarray:
+    def load_target(self, index: int) -> np.ndarray:
        """Returns the target corresponding to the `index`'th multi_id.
 
        This method assumes that all the embeddings corresponding to the same `multi_id`
eva/core/data/datasets/embeddings.py
CHANGED
@@ -98,12 +98,12 @@ class EmbeddingsDataset(base.Dataset, Generic[TargetType]):
        Returns:
            A data sample and its target.
        """
-        embeddings = self._load_embeddings(index)
-        target = self._load_target(index)
+        embeddings = self.load_embeddings(index)
+        target = self.load_target(index)
        return self._apply_transforms(embeddings, target)
 
    @abc.abstractmethod
-    def _load_embeddings(self, index: int) -> torch.Tensor:
+    def load_embeddings(self, index: int) -> torch.Tensor:
        """Returns the `index`'th embedding sample.
 
        Args:
@@ -114,7 +114,7 @@ class EmbeddingsDataset(base.Dataset, Generic[TargetType]):
        """
 
    @abc.abstractmethod
-    def _load_target(self, index: int) -> TargetType:
+    def load_target(self, index: int) -> TargetType:
        """Returns the `index`'th target sample.
 
        Args:
eva/core/data/samplers/classification/balanced.py
CHANGED
@@ -4,6 +4,7 @@ from collections import defaultdict
 from typing import Dict, Iterator, List
 
 import numpy as np
+from loguru import logger
 from typing_extensions import override
 
 from eva.core.data import datasets
@@ -33,6 +34,7 @@ class BalancedSampler(SamplerWithDataSource[int]):
        self._replacement = replacement
        self._class_indices: Dict[int, List[int]] = defaultdict(list)
        self._random_generator = np.random.default_rng(seed)
+        self._indices: List[int] = []
 
    def __len__(self) -> int:
        """Returns the total number of samples."""
@@ -44,18 +46,7 @@ class BalancedSampler(SamplerWithDataSource[int]):
        Returns:
            Iterator yielding dataset indices.
        """
-        indices = []
-
-        for class_idx in self._class_indices:
-            class_indices = self._class_indices[class_idx]
-            sampled_indices = self._random_generator.choice(
-                class_indices, size=self._num_samples, replace=self._replacement
-            ).tolist()
-            indices.extend(sampled_indices)
-
-        self._random_generator.shuffle(indices)
-
-        return iter(indices)
+        return iter(self._indices)
 
    @override
    def set_dataset(self, data_source: datasets.MapDataset):
@@ -72,13 +63,13 @@ class BalancedSampler(SamplerWithDataSource[int]):
        self._make_indices()
 
    def _make_indices(self):
-        """
+        """Samples the indices for each class in the dataset."""
        self._class_indices.clear()
-
-
-
-
-
+        for idx in tqdm(range(len(self.data_source)), desc="Fetching class indices for sampler"):
+            if hasattr(self.data_source, "load_target"):
+                target = self.data_source.load_target(idx)  # type: ignore
+            else:
+                _, target, _ = DataSample(*self.data_source[idx])
            if target is None:
                raise ValueError("The dataset must return non-empty targets.")
            if target.numel() != 1:
@@ -94,3 +85,13 @@ class BalancedSampler(SamplerWithDataSource[int]):
                    f"Class {class_idx} has only {len(indices)} samples, "
                    f"which is less than the required {self._num_samples} samples."
                )
+
+        self._indices = []
+        for class_idx in self._class_indices:
+            class_indices = self._class_indices[class_idx]
+            sampled_indices = self._random_generator.choice(
+                class_indices, size=self._num_samples, replace=self._replacement
+            ).tolist()
+            self._indices.extend(sampled_indices)
+        self._random_generator.shuffle(self._indices)
+        logger.debug(f"Sampled indices: {self._indices}")
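Net effect of the sampler changes: indices are drawn once in `set_dataset` (via `_make_indices`) and `__iter__` merely replays the cached list, while datasets exposing the new public `load_target` hook let the sampler read labels without materializing full samples. A rough usage sketch with a duck-typed toy dataset (constructor argument name and module path are inferred from the diff, not confirmed):

import torch

from eva.core.data.samplers.classification.balanced import BalancedSampler


class ToyDataset:
    """Stand-in dataset exposing the fast-path `load_target` hook."""

    _targets = [0, 0, 0, 1, 1, 1, 1, 1]

    def __len__(self) -> int:
        return len(self._targets)

    def load_target(self, index: int) -> torch.Tensor:
        # The sampler calls this instead of loading the full sample.
        return torch.tensor(self._targets[index])


sampler = BalancedSampler(num_samples=3)
sampler.set_dataset(ToyDataset())  # indices are drawn here, not per __iter__ call
print(list(sampler))  # six indices, three per class, shuffled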
eva/core/loggers/utils/wandb.py
ADDED
@@ -0,0 +1,33 @@
+# type: ignore
+"""Utility functions for logging with Weights & Biases."""
+
+from typing import Any, Dict
+
+from loguru import logger
+
+
+def rename_active_run(name: str) -> None:
+    """Renames the current run."""
+    import wandb
+
+    if wandb.run:
+        wandb.run.name = name
+        wandb.run.save()
+    else:
+        logger.warning("No active wandb run found that could be renamed.")
+
+
+def init_run(name: str, init_kwargs: Dict[str, Any]) -> None:
+    """Initializes a new run. If there is an active run, it will be renamed and reused."""
+    import wandb
+
+    init_kwargs["name"] = name
+    rename_active_run(name)
+    wandb.init(**init_kwargs)
+
+
+def finish_run() -> None:
+    """Finish the current run."""
+    import wandb
+
+    wandb.finish()
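These helpers underpin the per-run W&B handling in `trainer.py` further below: each evaluation run is opened under its own name and closed before the next begins. A hedged sketch of that lifecycle (placeholder project settings; assumes a configured wandb environment or offline mode):

from eva.core.loggers.utils import wandb as wandb_utils

init_kwargs = {"project": "eva-demo", "mode": "offline"}  # placeholder settings

# First run: no run is active yet, so rename_active_run only logs a warning.
wandb_utils.init_run("bach_session_0", init_kwargs)
# ... fit / validate ...
wandb_utils.finish_run()

# Next run: a fresh wandb run is created under the new name.
wandb_utils.init_run("bach_session_1", init_kwargs)
wandb_utils.finish_run()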
eva/core/models/modules/head.py
CHANGED
@@ -1,6 +1,6 @@
-"""
+"""Neural Network Head Module."""
 
-from typing import Any, Callable, Dict
+from typing import Any, Callable, Dict, List
 
 import torch
 from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
@@ -108,7 +108,9 @@ class HeadModule(module.ModelModule):
        return self._batch_step(batch)
 
    @override
-    def predict_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> torch.Tensor:
+    def predict_step(
+        self, batch: INPUT_BATCH, *args: Any, **kwargs: Any
+    ) -> torch.Tensor | List[torch.Tensor]:
        tensor = INPUT_BATCH(*batch).data
        return tensor if self.backbone is None else self.backbone(tensor)
 
eva/core/models/modules/typings.py
CHANGED
@@ -1,6 +1,6 @@
 """Type annotations for model modules."""
 
-from typing import Any, Dict, NamedTuple
+from typing import Any, Dict, List, NamedTuple
 
 import lightning.pytorch as pl
 import torch
@@ -13,7 +13,7 @@ MODEL_TYPE = nn.Module | pl.LightningModule
 class INPUT_BATCH(NamedTuple):
    """The default input batch data scheme."""
 
-    data: torch.Tensor
+    data: torch.Tensor | List[torch.Tensor]
    """The data batch."""
 
    targets: torch.Tensor | None = None
eva/core/models/transforms/__init__.py
CHANGED
@@ -1,6 +1,7 @@
 """Model outputs transforms API."""
 
+from eva.core.models.transforms.as_discrete import AsDiscrete
 from eva.core.models.transforms.extract_cls_features import ExtractCLSFeatures
 from eva.core.models.transforms.extract_patch_features import ExtractPatchFeatures
 
-__all__ = ["ExtractCLSFeatures", "ExtractPatchFeatures"]
+__all__ = ["AsDiscrete", "ExtractCLSFeatures", "ExtractPatchFeatures"]
eva/core/models/transforms/as_discrete.py
ADDED
@@ -0,0 +1,57 @@
+"""Defines the AsDiscrete transformation."""
+
+import torch
+
+
+class AsDiscrete:
+    """Convert the logits tensor to discrete values."""
+
+    def __init__(
+        self,
+        argmax: bool = False,
+        to_onehot: int | bool | None = None,
+        threshold: float | None = None,
+    ) -> None:
+        """Convert the input tensor/array into discrete values.
+
+        Args:
+            argmax: Whether to execute argmax function on input data before transform.
+            to_onehot: if not None, convert input data into the one-hot format with
+                specified number of classes. If bool, it will try to infer the number
+                of classes.
+            threshold: If not None, threshold the float values to int number 0 or 1
+                with specified threshold.
+        """
+        super().__init__()
+
+        self._argmax = argmax
+        self._to_onehot = to_onehot
+        self._threshold = threshold
+
+    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
+        """Call method for the transformation."""
+        if self._argmax:
+            tensor = torch.argmax(tensor, dim=1, keepdim=True)
+
+        if self._to_onehot is not None:
+            tensor = _one_hot(tensor, num_classes=self._to_onehot, dim=1, dtype=torch.long)
+
+        if self._threshold is not None:
+            tensor = tensor >= self._threshold
+
+        return tensor
+
+
+def _one_hot(
+    tensor: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1
+) -> torch.Tensor:
+    """Convert input tensor into one-hot format (implementation taken from MONAI)."""
+    shape = list(tensor.shape)
+    if shape[dim] != 1:
+        raise AssertionError(f"Input tensor must have 1 channel at dim {dim}.")
+
+    shape[dim] = num_classes
+    o = torch.zeros(size=shape, dtype=dtype, device=tensor.device)
+    tensor = o.scatter_(dim=dim, index=tensor.long(), value=1)
+
+    return tensor
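A quick usage sketch of the transform above, turning per-class logits into one-hot predictions (shapes are illustrative only):

import torch

from eva.core.models.transforms import AsDiscrete

logits = torch.randn(4, 3)  # batch of 4 samples, 3 class logits each

# Argmax over the class dimension, then one-hot encode back to 3 channels.
transform = AsDiscrete(argmax=True, to_onehot=3)
discrete = transform(logits)

print(discrete.shape)       # torch.Size([4, 3])
print(discrete.sum(dim=1))  # tensor([1, 1, 1, 1]) -- one active class per sample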
eva/core/models/wrappers/_utils.py
CHANGED
@@ -1,8 +1,17 @@
 """Utilities and helper functions for models."""
 
+import hashlib
+import os
+import sys
+from typing import Any, Dict
+
+import torch
+from fsspec.core import url_to_fs
 from lightning_fabric.utilities import cloud_io
 from loguru import logger
-from torch import nn
+from torch import hub, nn
+
+from eva.core.utils.progress_bar import tqdm
 
 
 def load_model_weights(model: nn.Module, checkpoint_path: str) -> None:
@@ -23,3 +32,114 @@ def load_model_weights(model: nn.Module, checkpoint_path: str) -> None:
    model.load_state_dict(checkpoint, strict=True)
 
    logger.info(f"Loading weights from '{checkpoint_path}' completed successfully.")
+
+
+def load_state_dict_from_url(
+    url: str,
+    *,
+    model_dir: str | None = None,
+    filename: str | None = None,
+    progress: bool = True,
+    md5: str | None = None,
+    force: bool = False,
+) -> Dict[str, Any]:
+    """Loads the Torch serialized object at the given URL.
+
+    If the object is already present and valid in `model_dir`, it's
+    deserialized and returned.
+
+    The default value of ``model_dir`` is ``<hub_dir>/checkpoints`` where
+    ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`.
+
+    Args:
+        url: URL of the object to download.
+        model_dir: Directory in which to save the object.
+        filename: Name for the downloaded file. Filename from ``url`` will be used if not set.
+        progress: Whether or not to display a progress bar to stderr.
+        md5: MD5 file code to check whether the file is valid. If not, it will re-download it.
+        force: Whether to download the file regardless if it exists.
+    """
+    model_dir = model_dir or os.path.join(hub.get_dir(), "checkpoints")
+    os.makedirs(model_dir, exist_ok=True)
+
+    cached_file = os.path.join(model_dir, filename or os.path.basename(url))
+    if force or not os.path.exists(cached_file) or not _check_integrity(cached_file, md5):
+        sys.stderr.write(f"Downloading: '{url}' to {cached_file}\n")
+        _download_url_to_file(url, cached_file, progress=progress)
+        if md5 is None or not _check_integrity(cached_file, md5):
+            sys.stderr.write(f"File MD5: {_calculate_md5(cached_file)}\n")
+
+    return torch.load(cached_file, map_location="cpu")
+
+
+def _download_url_to_file(
+    url: str,
+    dst: str,
+    *,
+    progress: bool = True,
+) -> None:
+    """Download object at the given URL to a local path.
+
+    Args:
+        url: URL of the object to download.
+        dst: Full path where object will be saved.
+        chunk_size: The size of each chunk to read in bytes.
+        progress: Whether or not to display a progress bar to stderr.
+    """
+    try:
+        _download_with_fsspec(url=url, dst=dst, progress=progress)
+    except Exception:
+        try:
+            hub.download_url_to_file(url=url, dst=dst, progress=progress)
+        except Exception as hub_e:
+            raise RuntimeError(
+                f"Failed to download file from {url} using both fsspec and hub."
+            ) from hub_e
+
+
+def _download_with_fsspec(
+    url: str,
+    dst: str,
+    *,
+    chunk_size: int = 1024 * 1024,
+    progress: bool = True,
+) -> None:
+    """Download object at the given URL to a local path using fsspec.
+
+    Args:
+        url: URL of the object to download.
+        dst: Full path where object will be saved.
+        chunk_size: The size of each chunk to read in bytes.
+        progress: Whether or not to display a progress bar to stderr.
+    """
+    filesystem, _ = url_to_fs(url, anon=False)
+    total_size_bytes = filesystem.size(url)
+    with (
+        filesystem.open(url, "rb") as remote_file,
+        tqdm(
+            total=total_size_bytes,
+            unit="iB",
+            unit_scale=True,
+            unit_divisor=1024,
+            disable=not progress,
+        ) as pbar,
+    ):
+        with open(dst, "wb") as local_file:
+            while True:
+                data = remote_file.read(chunk_size)
+                if not data:
+                    break
+
+                local_file.write(data)
+                pbar.update(chunk_size)
+
+
+def _calculate_md5(path: str) -> str:
+    """Calculate the md5 hash of a file."""
+    with open(path, "rb") as file:
+        return hashlib.md5(file.read(), usedforsecurity=False).hexdigest()
+
+
+def _check_integrity(path: str, md5: str | None) -> bool:
+    """Check if the file matches the specified md5 hash."""
+    return (md5 is None) or (md5 == _calculate_md5(path))
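A hedged usage sketch of the new checkpoint helper (the URL and paths are placeholders, not real artifacts; any fsspec-supported scheme should work, with torch.hub as the fallback downloader):

from eva.core.models.wrappers._utils import load_state_dict_from_url

state_dict = load_state_dict_from_url(
    "https://example.com/checkpoints/backbone.pth",  # hypothetical artifact URL
    model_dir="/tmp/eva-checkpoints",
    md5=None,  # set to the expected hash to enable integrity checking
    force=False,  # reuse the cached file when present and valid
)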
eva/core/trainers/functional.py
CHANGED
@@ -39,7 +39,7 @@ def run_evaluation_session(
        base_trainer,
        base_model,
        datamodule,
-        run_id=f"run_{run_index}",
+        run_id=run_index,
        verbose=not verbose,
    )
    recorder.update(validation_scores, test_scores)
@@ -51,7 +51,7 @@ def run_evaluation(
    base_model: modules.ModelModule,
    datamodule: datamodules.DataModule,
    *,
-    run_id: str | None = None,
+    run_id: int | None = None,
    verbose: bool = True,
 ) -> Tuple[_EVALUATE_OUTPUT, _EVALUATE_OUTPUT | None]:
    """Fits and evaluates a model out-of-place.
@@ -61,7 +61,6 @@ def run_evaluation(
        base_model: The model module to use but not modify.
        datamodule: The data module.
        run_id: The run id to be appended to the output log directory.
-            If `None`, it will use the log directory of the trainer as is.
        verbose: Whether to print the validation and test metrics
            in the end of the training.
 
@@ -70,8 +69,12 @@ def run_evaluation(
    """
    trainer, model = _utils.clone(base_trainer, base_model)
    model.configure_model()
-    trainer.setup_log_dirs(run_id or "")
-    return fit_and_validate(trainer, model, datamodule, verbose=verbose)
+
+    trainer.init_logger_run(run_id)
+    results = fit_and_validate(trainer, model, datamodule, verbose=verbose)
+    trainer.finish_logger_run(run_id)
+
+    return results
 
 
 def fit_and_validate(
eva/core/trainers/trainer.py
CHANGED
@@ -12,6 +12,7 @@ from typing_extensions import override
 
 from eva.core import loggers as eva_loggers
 from eva.core.data import datamodules
+from eva.core.loggers.utils import wandb as wandb_utils
 from eva.core.models import modules
 from eva.core.trainers import _logging, functional
 
@@ -53,7 +54,7 @@ class Trainer(pl_trainer.Trainer):
        self._session_id: str = _logging.generate_session_id()
        self._log_dir: str = self.default_log_dir
 
-        self.setup_log_dirs()
+        self.init_logger_run(0)
 
    @property
    def default_log_dir(self) -> str:
@@ -65,31 +66,45 @@ class Trainer(pl_trainer.Trainer):
    def log_dir(self) -> str | None:
        return self.strategy.broadcast(self._log_dir)
 
-    def setup_log_dirs(self, subdirectory: str = "") -> None:
-        """
+    def init_logger_run(self, run_id: int | None) -> None:
+        """Setup the loggers & log directories when starting a new run.
 
        Args:
-
+            run_id: The id of the current run.
        """
+        subdirectory = f"run_{run_id}" if run_id is not None else ""
        self._log_dir = os.path.join(self.default_root_dir, self._session_id, subdirectory)
 
        enabled_loggers = []
-
-
-
-
-
-
-
-
-
-
-
-
-
+        for logger in self.loggers or []:
+            if isinstance(logger, (pl_loggers.CSVLogger, pl_loggers.TensorBoardLogger)):
+                if not cloud_io._is_local_file_protocol(self.default_root_dir):
+                    loguru.logger.warning(
+                        f"Skipped {type(logger).__name__} as remote storage is not supported."
+                    )
+                    continue
+                else:
+                    logger._root_dir = self.default_root_dir
+                    logger._name = self._session_id
+                    logger._version = subdirectory
+            elif isinstance(logger, pl_loggers.WandbLogger):
+                task_name = self.default_root_dir.split("/")[-1]
+                run_name = os.getenv("WANDB_RUN_NAME", f"{task_name}_{self._session_id}")
+                wandb_utils.init_run(f"{run_name}_{run_id}", logger._wandb_init)
+            enabled_loggers.append(logger)
 
        self._loggers = enabled_loggers or [eva_loggers.DummyLogger(self._log_dir)]
 
+    def finish_logger_run(self, run_id: int | None) -> None:
+        """Finish the current run in the enabled loggers.
+
+        Args:
+            run_id: The id of the current run.
+        """
+        for logger in self.loggers or []:
+            if isinstance(logger, pl_loggers.WandbLogger):
+                wandb_utils.finish_run()
+
    def run_evaluation_session(
        self,
        model: modules.ModelModule,
eva/core/utils/suppress_logs.py
ADDED
@@ -0,0 +1,28 @@
+"""Context manager to temporarily suppress all logging outputs."""
+
+import logging
+import sys
+from types import TracebackType
+from typing import Type
+
+
+class SuppressLogs:
+    """Context manager to suppress all logs but print exceptions if they occur."""
+
+    def __enter__(self) -> None:
+        """Temporarily increase log level to suppress all logs."""
+        self._logger = logging.getLogger()
+        self._previous_level = self._logger.level
+        self._logger.setLevel(logging.CRITICAL + 1)
+
+    def __exit__(
+        self,
+        exc_type: Type[BaseException] | None,
+        exc_value: BaseException | None,
+        traceback: TracebackType | None,
+    ) -> bool:
+        """Restores the previous logging level and print exceptions."""
+        self._logger.setLevel(self._previous_level)
+        if exc_value:
+            print(f"Error: {exc_value}", file=sys.stderr)
+        return False
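A minimal sketch of the context manager in action, using only standard-library logging (since __exit__ returns False, exceptions raised inside the block still propagate after being printed):

import logging

from eva.core.utils.suppress_logs import SuppressLogs

logging.basicConfig(level=logging.INFO)

logging.info("visible")
with SuppressLogs():
    # Root logger level is raised above CRITICAL, so this is swallowed.
    logging.critical("suppressed")
logging.info("visible again after the previous level is restored")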
eva/vision/data/dataloaders/collate_fn/collection.py
ADDED
@@ -0,0 +1,22 @@
+"""Data only collate filter function."""
+
+from typing import Any, List
+
+import torch
+
+from eva.core.models.modules.typings import INPUT_BATCH
+
+
+def collection_collate(batch: List[List[INPUT_BATCH]]) -> Any:
+    """Collate function for stacking a collection of data samples.
+
+    Args:
+        batch: The batch to be collated.
+
+    Returns:
+        The collated batch.
+    """
+    tensors, targets, metadata = zip(*batch, strict=False)
+    batch_tensors = torch.cat(list(map(torch.stack, tensors)))
+    batch_targets = torch.cat(list(map(torch.stack, targets)))
+    return batch_tensors, batch_targets, metadata
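Despite the `List[List[INPUT_BATCH]]` annotation, the unpacking implies each dataset item is a `(tensors, targets, metadata)` triple whose first two fields are sequences of same-shaped tensors. A toy sketch (the `collate_fn` re-export is assumed from the new `__init__.py` modules):

import torch

from eva.vision.data.dataloaders.collate_fn import collection_collate

# Two dataset items; each yields a collection of per-patch tensors and targets.
item_a = ([torch.rand(8), torch.rand(8)], [torch.tensor(0), torch.tensor(0)], {"slide": "a"})
item_b = ([torch.rand(8), torch.rand(8)], [torch.tensor(1), torch.tensor(1)], {"slide": "b"})

data, targets, metadata = collection_collate([item_a, item_b])
print(data.shape)     # torch.Size([4, 8]) -- per-item stacks, concatenated
print(targets.shape)  # torch.Size([4])
print(metadata)       # ({'slide': 'a'}, {'slide': 'b'})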
eva/vision/data/datasets/__init__.py
CHANGED
@@ -2,19 +2,23 @@
 
 from eva.vision.data.datasets.classification import (
    BACH,
+    BRACS,
    CRC,
    MHIST,
    PANDA,
+    BreaKHis,
    Camelyon16,
+    GleasonArvaniti,
    PANDASmall,
    PatchCamelyon,
+    UniToPatho,
    WsiClassificationDataset,
 )
 from eva.vision.data.datasets.segmentation import (
    BCSS,
+    BTCV,
    CoNSeP,
    EmbeddingsSegmentationDataset,
-    ImageSegmentation,
    LiTS,
    LiTSBalanced,
    MoNuSAC,
@@ -25,17 +29,21 @@ from eva.vision.data.datasets.wsi import MultiWsiDataset, WsiDataset
 
 __all__ = [
    "BACH",
+    "BTCV",
    "BCSS",
+    "BreaKHis",
+    "BRACS",
    "CRC",
+    "GleasonArvaniti",
    "MHIST",
    "PANDA",
    "PANDASmall",
    "Camelyon16",
    "PatchCamelyon",
+    "UniToPatho",
    "WsiClassificationDataset",
    "CoNSeP",
    "EmbeddingsSegmentationDataset",
-    "ImageSegmentation",
    "LiTS",
    "LiTSBalanced",
    "MoNuSAC",