kaiko-eva 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eva/core/callbacks/__init__.py +2 -2
- eva/core/callbacks/writers/__init__.py +6 -3
- eva/core/callbacks/writers/embeddings/__init__.py +6 -0
- eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
- eva/core/callbacks/writers/embeddings/base.py +192 -0
- eva/core/callbacks/writers/embeddings/classification.py +117 -0
- eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
- eva/core/callbacks/writers/embeddings/typings.py +38 -0
- eva/core/data/datasets/__init__.py +2 -2
- eva/core/data/datasets/classification/__init__.py +8 -0
- eva/core/data/datasets/classification/embeddings.py +34 -0
- eva/core/data/datasets/{embeddings/classification → classification}/multi_embeddings.py +13 -9
- eva/core/data/datasets/{embeddings/base.py → embeddings.py} +47 -32
- eva/core/data/splitting/__init__.py +6 -0
- eva/core/data/splitting/random.py +41 -0
- eva/core/data/splitting/stratified.py +56 -0
- eva/core/loggers/experimental_loggers.py +2 -2
- eva/core/loggers/log/__init__.py +3 -2
- eva/core/loggers/log/image.py +71 -0
- eva/core/loggers/log/parameters.py +10 -0
- eva/core/loggers/loggers.py +6 -0
- eva/core/metrics/__init__.py +6 -2
- eva/core/metrics/defaults/__init__.py +10 -3
- eva/core/metrics/defaults/classification/__init__.py +1 -1
- eva/core/metrics/defaults/classification/binary.py +0 -9
- eva/core/metrics/defaults/classification/multiclass.py +0 -8
- eva/core/metrics/defaults/segmentation/__init__.py +5 -0
- eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
- eva/core/metrics/generalized_dice.py +59 -0
- eva/core/metrics/mean_iou.py +120 -0
- eva/core/metrics/structs/schemas.py +3 -1
- eva/core/models/__init__.py +3 -1
- eva/core/models/modules/head.py +10 -4
- eva/core/models/modules/typings.py +14 -1
- eva/core/models/modules/utils/batch_postprocess.py +37 -5
- eva/core/models/networks/__init__.py +1 -2
- eva/core/models/networks/mlp.py +2 -2
- eva/core/models/transforms/__init__.py +6 -0
- eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
- eva/core/models/transforms/extract_patch_features.py +47 -0
- eva/core/models/wrappers/__init__.py +13 -0
- eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
- eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
- eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
- eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
- eva/core/trainers/functional.py +1 -0
- eva/core/utils/__init__.py +6 -0
- eva/core/utils/clone.py +27 -0
- eva/core/utils/memory.py +28 -0
- eva/core/utils/operations.py +26 -0
- eva/core/utils/parser.py +20 -0
- eva/vision/__init__.py +2 -2
- eva/vision/callbacks/__init__.py +5 -0
- eva/vision/callbacks/loggers/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/base.py +130 -0
- eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
- eva/vision/data/datasets/__init__.py +30 -3
- eva/vision/data/datasets/_validators.py +15 -2
- eva/vision/data/datasets/classification/__init__.py +12 -1
- eva/vision/data/datasets/classification/bach.py +10 -15
- eva/vision/data/datasets/classification/base.py +17 -24
- eva/vision/data/datasets/classification/camelyon16.py +244 -0
- eva/vision/data/datasets/classification/crc.py +10 -15
- eva/vision/data/datasets/classification/mhist.py +10 -15
- eva/vision/data/datasets/classification/panda.py +184 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
- eva/vision/data/datasets/classification/wsi.py +105 -0
- eva/vision/data/datasets/segmentation/__init__.py +15 -2
- eva/vision/data/datasets/segmentation/_utils.py +38 -0
- eva/vision/data/datasets/segmentation/base.py +16 -17
- eva/vision/data/datasets/segmentation/bcss.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +156 -0
- eva/vision/data/datasets/segmentation/embeddings.py +34 -0
- eva/vision/data/datasets/segmentation/lits.py +178 -0
- eva/vision/data/datasets/segmentation/monusac.py +236 -0
- eva/vision/data/datasets/segmentation/{total_segmentator.py → total_segmentator_2d.py} +130 -36
- eva/vision/data/datasets/wsi.py +187 -0
- eva/vision/data/transforms/__init__.py +3 -2
- eva/vision/data/transforms/common/__init__.py +2 -1
- eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
- eva/vision/data/transforms/common/resize_and_crop.py +6 -7
- eva/vision/data/transforms/normalization/__init__.py +6 -0
- eva/vision/data/transforms/normalization/clamp.py +43 -0
- eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
- eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
- eva/vision/data/wsi/__init__.py +16 -0
- eva/vision/data/wsi/backends/__init__.py +69 -0
- eva/vision/data/wsi/backends/base.py +115 -0
- eva/vision/data/wsi/backends/openslide.py +73 -0
- eva/vision/data/wsi/backends/pil.py +52 -0
- eva/vision/data/wsi/backends/tiffslide.py +42 -0
- eva/vision/data/wsi/patching/__init__.py +6 -0
- eva/vision/data/wsi/patching/coordinates.py +98 -0
- eva/vision/data/wsi/patching/mask.py +123 -0
- eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
- eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
- eva/vision/data/wsi/patching/samplers/base.py +48 -0
- eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
- eva/vision/data/wsi/patching/samplers/grid.py +47 -0
- eva/vision/data/wsi/patching/samplers/random.py +41 -0
- eva/vision/losses/__init__.py +5 -0
- eva/vision/losses/dice.py +40 -0
- eva/vision/models/__init__.py +4 -2
- eva/vision/models/modules/__init__.py +5 -0
- eva/vision/models/modules/semantic_segmentation.py +161 -0
- eva/vision/models/networks/__init__.py +1 -2
- eva/vision/models/networks/backbones/__init__.py +6 -0
- eva/vision/models/networks/backbones/_utils.py +39 -0
- eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
- eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
- eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
- eva/vision/models/networks/backbones/pathology/histai.py +46 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
- eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
- eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
- eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
- eva/vision/models/networks/backbones/registry.py +47 -0
- eva/vision/models/networks/backbones/timm/__init__.py +5 -0
- eva/vision/models/networks/backbones/timm/backbones.py +54 -0
- eva/vision/models/networks/backbones/universal/__init__.py +8 -0
- eva/vision/models/networks/backbones/universal/vit.py +54 -0
- eva/vision/models/networks/decoders/__init__.py +6 -0
- eva/vision/models/networks/decoders/decoder.py +7 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
- eva/vision/models/networks/decoders/segmentation/common.py +74 -0
- eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
- eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
- eva/vision/models/wrappers/__init__.py +6 -0
- eva/vision/models/wrappers/from_registry.py +48 -0
- eva/vision/models/wrappers/from_timm.py +68 -0
- eva/vision/utils/colormap.py +77 -0
- eva/vision/utils/convert.py +56 -13
- eva/vision/utils/io/__init__.py +10 -4
- eva/vision/utils/io/image.py +21 -2
- eva/vision/utils/io/mat.py +36 -0
- eva/vision/utils/io/nifti.py +33 -12
- eva/vision/utils/io/text.py +10 -3
- kaiko_eva-0.1.1.dist-info/METADATA +553 -0
- kaiko_eva-0.1.1.dist-info/RECORD +205 -0
- {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.1.dist-info}/WHEEL +1 -1
- {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.1.dist-info}/entry_points.txt +2 -0
- eva/.DS_Store +0 -0
- eva/core/callbacks/writers/embeddings.py +0 -169
- eva/core/callbacks/writers/typings.py +0 -23
- eva/core/data/datasets/embeddings/__init__.py +0 -13
- eva/core/data/datasets/embeddings/classification/__init__.py +0 -10
- eva/core/data/datasets/embeddings/classification/embeddings.py +0 -66
- eva/core/models/networks/transforms/__init__.py +0 -5
- eva/core/models/networks/wrappers/__init__.py +0 -8
- eva/vision/models/.DS_Store +0 -0
- eva/vision/models/networks/.DS_Store +0 -0
- eva/vision/models/networks/postprocesses/__init__.py +0 -5
- eva/vision/models/networks/postprocesses/cls.py +0 -25
- kaiko_eva-0.0.2.dist-info/METADATA +0 -431
- kaiko_eva-0.0.2.dist-info/RECORD +0 -127
- /eva/core/models/{networks → wrappers}/_utils.py +0 -0
- {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.1.dist-info}/licenses/LICENSE +0 -0
eva/vision/data/wsi/patching/samplers/_utils.py
ADDED
@@ -0,0 +1,50 @@
+import random
+from typing import Tuple
+
+import numpy as np
+
+
+def set_seed(seed: int) -> None:
+    random.seed(seed)
+    np.random.seed(seed)
+
+
+def get_grid_coords_and_indices(
+    layer_shape: Tuple[int, int],
+    width: int,
+    height: int,
+    overlap: Tuple[int, int],
+    shuffle: bool = True,
+    seed: int = 42,
+):
+    """Get grid coordinates and indices.
+
+    Args:
+        layer_shape: The shape of the layer.
+        width: The width of the patches.
+        height: The height of the patches.
+        overlap: The overlap between patches in the grid.
+        shuffle: Whether to shuffle the indices.
+        seed: The random seed.
+    """
+    x_range = range(0, layer_shape[0] - width + 1, width - overlap[0])
+    y_range = range(0, layer_shape[1] - height + 1, height - overlap[1])
+    x_y = [(x, y) for x in x_range for y in y_range]
+
+    indices = list(range(len(x_y)))
+    if shuffle:
+        set_seed(seed)
+        np.random.shuffle(indices)
+    return x_y, indices
+
+
+def validate_dimensions(width: int, height: int, layer_shape: Tuple[int, int]) -> None:
+    """Checks if the width / height is bigger than the layer shape.
+
+    Args:
+        width: The width of the patches.
+        height: The height of the patches.
+        layer_shape: The shape of the layer.
+    """
+    if width > layer_shape[0] or height > layer_shape[1]:
+        raise ValueError("The width / height cannot be bigger than the layer shape.")
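A standalone sketch of the grid arithmetic in `get_grid_coords_and_indices` above (illustrative numbers, not part of the diff): the step size is the patch size minus the overlap, and coordinates that would overrun the layer are dropped.

from typing import List, Tuple

def grid_coords(layer_shape: Tuple[int, int], width: int, height: int,
                overlap: Tuple[int, int]) -> List[Tuple[int, int]]:
    # Step is patch size minus overlap; patches exceeding the layer are dropped.
    x_range = range(0, layer_shape[0] - width + 1, width - overlap[0])
    y_range = range(0, layer_shape[1] - height + 1, height - overlap[1])
    return [(x, y) for x in x_range for y in y_range]

# A 1000x1000 layer tiled with 400x400 patches and 100px overlap:
# step 300 -> x, y in {0, 300, 600}, i.e. 9 coordinates.
print(len(grid_coords((1000, 1000), 400, 400, (100, 100))))  # 9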
eva/vision/data/wsi/patching/samplers/base.py
ADDED
@@ -0,0 +1,48 @@
+"""Base classes for samplers."""
+
+import abc
+from typing import Generator, Tuple
+
+from eva.vision.data.wsi.patching.mask import Mask
+
+
+class Sampler(abc.ABC):
+    """Base class for samplers."""
+
+    @abc.abstractmethod
+    def sample(
+        self,
+        width: int,
+        height: int,
+        layer_shape: Tuple[int, int],
+        mask: Mask | None = None,
+    ) -> Generator[Tuple[int, int], None, None]:
+        """Sample patch coordinates.
+
+        Args:
+            width: The width of the patches.
+            height: The height of the patches.
+            layer_shape: The shape of the layer.
+            mask: Tuple containing the mask array and the scaling factor with respect to the
+                provided layer_shape. Optional, only required for samplers with foreground
+                filtering.
+
+        Returns:
+            A generator producing sampled patch coordinates.
+        """
+
+
+class ForegroundSampler(Sampler):
+    """Base class for samplers with foreground filtering capabilities."""
+
+    @abc.abstractmethod
+    def is_foreground(
+        self,
+        mask: Mask,
+        x: int,
+        y: int,
+        width: int,
+        height: int,
+        min_foreground_ratio: float,
+    ) -> bool:
+        """Check if a patch contains sufficient foreground."""
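The abstract `sample` interface above is what the concrete samplers below implement; a minimal hypothetical subclass (not in the package) showing the contract:

from typing import Generator, Tuple

from eva.vision.data.wsi.patching.samplers import base


class CenterSampler(base.Sampler):
    """Hypothetical sampler yielding the single patch centered in the layer."""

    def sample(
        self,
        width: int,
        height: int,
        layer_shape: Tuple[int, int],
        mask=None,
    ) -> Generator[Tuple[int, int], None, None]:
        # One (x, y) top-left coordinate, centered within the layer bounds.
        yield (layer_shape[0] - width) // 2, (layer_shape[1] - height) // 2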
eva/vision/data/wsi/patching/samplers/foreground_grid.py
ADDED
@@ -0,0 +1,99 @@
+"""Foreground grid sampler."""
+
+from typing import Tuple
+
+from eva.vision.data.wsi.patching.mask import Mask
+from eva.vision.data.wsi.patching.samplers import _utils, base
+
+
+class ForegroundGridSampler(base.ForegroundSampler):
+    """Sample patches based on a grid, only returning patches containing foreground."""
+
+    def __init__(
+        self,
+        max_samples: int = 20,
+        overlap: Tuple[int, int] = (0, 0),
+        min_foreground_ratio: float = 0.35,
+        seed: int = 42,
+    ) -> None:
+        """Initializes the sampler.
+
+        Args:
+            max_samples: The maximum number of samples to return.
+            overlap: The overlap between patches in the grid.
+            min_foreground_ratio: The minimum amount of foreground
+                within a sampled patch.
+            seed: The random seed.
+        """
+        self.max_samples = max_samples
+        self.overlap = overlap
+        self.min_foreground_ratio = min_foreground_ratio
+        self.seed = seed
+
+    def sample(
+        self,
+        width: int,
+        height: int,
+        layer_shape: Tuple[int, int],
+        mask: Mask,
+    ):
+        """Sample patches from a grid containing foreground.
+
+        Args:
+            width: The width of the patches.
+            height: The height of the patches.
+            layer_shape: The shape of the layer.
+            mask: The mask of the image.
+        """
+        _utils.validate_dimensions(width, height, layer_shape)
+        x_y, indices = _utils.get_grid_coords_and_indices(
+            layer_shape, width, height, self.overlap, seed=self.seed
+        )
+
+        count = 0
+        for i in indices:
+            if count >= self.max_samples:
+                break
+
+            if self.is_foreground(
+                mask=mask,
+                x=x_y[i][0],
+                y=x_y[i][1],
+                width=width,
+                height=height,
+                min_foreground_ratio=self.min_foreground_ratio,
+            ):
+                count += 1
+                yield x_y[i]
+
+    def is_foreground(
+        self,
+        mask: Mask,
+        x: int,
+        y: int,
+        width: int,
+        height: int,
+        min_foreground_ratio: float,
+    ) -> bool:
+        """Check if a patch contains sufficient foreground.
+
+        Args:
+            mask: The mask of the image.
+            x: The x-coordinate of the patch.
+            y: The y-coordinate of the patch.
+            width: The width of the patch.
+            height: The height of the patch.
+            min_foreground_ratio: The minimum amount of foreground in the patch.
+        """
+        x_, y_ = self._scale_coords(x, y, mask.scale_factors)
+        width_, height_ = self._scale_coords(width, height, mask.scale_factors)
+        patch_mask = mask.mask_array[y_ : y_ + height_, x_ : x_ + width_]
+        return patch_mask.sum() / patch_mask.size >= min_foreground_ratio
+
+    def _scale_coords(
+        self,
+        x: int,
+        y: int,
+        scale_factors: Tuple[float, float],
+    ) -> Tuple[int, int]:
+        return int(x / scale_factors[0]), int(y / scale_factors[1])
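The `is_foreground` check above maps layer coordinates into mask space via `scale_factors` and thresholds the mean mask value; a self-contained numpy sketch of that arithmetic (the mask layout is assumed for illustration):

import numpy as np

# Stand-in for Mask: a binary tissue mask at 1/8 of the layer resolution,
# so layer coordinates are divided by scale_factors to index it.
mask_array = np.zeros((125, 125), dtype=np.uint8)
mask_array[:, :60] = 1  # left part of the slide is tissue
scale_factors = (8.0, 8.0)

def is_foreground(x, y, width, height, min_foreground_ratio=0.35):
    x_, y_ = int(x / scale_factors[0]), int(y / scale_factors[1])
    w_, h_ = int(width / scale_factors[0]), int(height / scale_factors[1])
    patch = mask_array[y_ : y_ + h_, x_ : x_ + w_]
    return patch.sum() / patch.size >= min_foreground_ratio

print(is_foreground(0, 0, 200, 200))    # True: fully inside the tissue region
print(is_foreground(800, 0, 200, 200))  # False: right side is background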
eva/vision/data/wsi/patching/samplers/grid.py
ADDED
@@ -0,0 +1,47 @@
+"""Grid sampler."""
+
+from typing import Generator, Tuple
+
+from eva.vision.data.wsi.patching.samplers import _utils, base
+
+
+class GridSampler(base.Sampler):
+    """Sample patches based on a grid.
+
+    Args:
+        max_samples: The maximum number of samples to return.
+        overlap: The overlap between patches in the grid.
+        seed: The random seed.
+    """
+
+    def __init__(
+        self,
+        max_samples: int | None = None,
+        overlap: Tuple[int, int] = (0, 0),
+        seed: int = 42,
+    ):
+        """Initializes the sampler."""
+        self.max_samples = max_samples
+        self.overlap = overlap
+        self.seed = seed
+
+    def sample(
+        self,
+        width: int,
+        height: int,
+        layer_shape: Tuple[int, int],
+    ) -> Generator[Tuple[int, int], None, None]:
+        """Sample patches from a grid.
+
+        Args:
+            width: The width of the patches.
+            height: The height of the patches.
+            layer_shape: The shape of the layer.
+        """
+        _utils.validate_dimensions(width, height, layer_shape)
+        x_y, indices = _utils.get_grid_coords_and_indices(
+            layer_shape, width, height, self.overlap, seed=self.seed
+        )
+        max_samples = len(indices) if self.max_samples is None else self.max_samples
+        for i in indices[:max_samples]:
+            yield x_y[i]
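Since `GridSampler` needs no mask, it can be exercised directly; a usage sketch:

from eva.vision.data.wsi.patching.samplers.grid import GridSampler

sampler = GridSampler(max_samples=4, overlap=(0, 0), seed=42)
# Yields at most 4 (x, y) coordinates from the 224px grid, shuffled
# deterministically by the seed before truncation.
coords = list(sampler.sample(width=224, height=224, layer_shape=(1000, 1000)))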
eva/vision/data/wsi/patching/samplers/random.py
ADDED
@@ -0,0 +1,41 @@
+"""Random sampler."""
+
+import random
+from typing import Generator, Tuple
+
+from eva.vision.data.wsi.patching.samplers import _utils, base
+
+
+class RandomSampler(base.Sampler):
+    """Sample patch coordinates randomly.
+
+    Args:
+        n_samples: The number of samples to return.
+        seed: The random seed.
+    """
+
+    def __init__(self, n_samples: int = 1, seed: int = 42):
+        """Initializes the sampler."""
+        self.seed = seed
+        self.n_samples = n_samples
+
+    def sample(
+        self,
+        width: int,
+        height: int,
+        layer_shape: Tuple[int, int],
+    ) -> Generator[Tuple[int, int], None, None]:
+        """Sample random patches.
+
+        Args:
+            width: The width of the patches.
+            height: The height of the patches.
+            layer_shape: The shape of the layer.
+        """
+        _utils.validate_dimensions(width, height, layer_shape)
+        _utils.set_seed(self.seed)
+
+        x_max, y_max = layer_shape[0], layer_shape[1]
+        for _ in range(self.n_samples):
+            x, y = random.randint(0, x_max - width), random.randint(0, y_max - height)  # nosec
+            yield x, y
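Note that `RandomSampler.sample` reseeds the global random state on every call (via `_utils.set_seed`), so repeated calls produce identical coordinates:

from eva.vision.data.wsi.patching.samplers.random import RandomSampler

sampler = RandomSampler(n_samples=3, seed=42)
first = list(sampler.sample(224, 224, layer_shape=(1000, 1000)))
second = list(sampler.sample(224, 224, layer_shape=(1000, 1000)))
assert first == second  # set_seed(self.seed) runs at the start of each call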
eva/vision/losses/dice.py
ADDED
@@ -0,0 +1,40 @@
+"""Dice loss."""
+
+import torch
+from monai import losses
+from monai.networks import one_hot  # type: ignore
+from typing_extensions import override
+
+
+class DiceLoss(losses.DiceLoss):  # type: ignore
+    """Computes the average Dice loss between two tensors.
+
+    Extends the implementation from MONAI
+    - to support semantic target labels (meaning targets of shape BHW)
+    - to support `ignore_index` functionality
+    """
+
+    def __init__(self, *args, ignore_index: int | None = None, **kwargs) -> None:
+        """Initialize the DiceLoss with support for ignore_index.
+
+        Args:
+            args: Positional arguments from the base class.
+            ignore_index: Specifies a target value that is ignored and
+                does not contribute to the input gradient.
+            kwargs: Keyword arguments from the base class.
+        """
+        super().__init__(*args, **kwargs)
+
+        self.ignore_index = ignore_index
+
+    @override
+    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:  # noqa
+        if self.ignore_index is not None:
+            mask = targets != self.ignore_index
+            targets = targets * mask
+            inputs = torch.mul(inputs, mask.unsqueeze(1) if mask.ndim == 3 else mask)
+
+        if targets.ndim == 3:
+            targets = one_hot(targets[:, None, ...], num_classes=inputs.shape[1])
+
+        return super().forward(inputs, targets)
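A usage sketch for the extended loss (requires MONAI; `softmax=True` is an illustrative base-class argument, not something the diff prescribes):

import torch

from eva.vision.losses.dice import DiceLoss

criterion = DiceLoss(softmax=True, ignore_index=255)
logits = torch.randn(2, 3, 64, 64)           # B x C x H x W class scores
targets = torch.randint(0, 3, (2, 64, 64))   # B x H x W semantic labels
targets[0, :8, :8] = 255                     # pixels excluded from the loss
loss = criterion(logits, targets)            # targets are one-hot encoded internally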
eva/vision/models/__init__.py
CHANGED
@@ -1,5 +1,7 @@
 """Vision Models API."""
 
-from eva.vision.models import networks
+from eva.vision.models import networks, wrappers
+from eva.vision.models.networks import backbones
+from eva.vision.models.wrappers import ModelFromRegistry, TimmModel
 
-__all__ = ["networks"]
+__all__ = ["networks", "wrappers", "backbones", "ModelFromRegistry", "TimmModel"]
eva/vision/models/modules/semantic_segmentation.py
ADDED
@@ -0,0 +1,161 @@
+"""Neural Network Semantic Segmentation Module."""
+
+from typing import Any, Callable, Dict, Iterable, List, Tuple
+
+import torch
+from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+from torch import nn, optim
+from torch.optim import lr_scheduler
+from typing_extensions import override
+
+from eva.core.metrics import structs as metrics_lib
+from eva.core.models.modules import module
+from eva.core.models.modules.typings import INPUT_BATCH, INPUT_TENSOR_BATCH
+from eva.core.models.modules.utils import batch_postprocess, grad
+from eva.core.utils import parser
+from eva.vision.models.networks import decoders
+
+
+class SemanticSegmentationModule(module.ModelModule):
+    """Neural network semantic segmentation module for training on patch embeddings."""
+
+    def __init__(
+        self,
+        decoder: decoders.Decoder,
+        criterion: Callable[..., torch.Tensor],
+        encoder: Dict[str, Any] | Callable[[torch.Tensor], List[torch.Tensor]] | None = None,
+        lr_multiplier_encoder: float = 0.0,
+        optimizer: OptimizerCallable = optim.AdamW,
+        lr_scheduler: LRSchedulerCallable = lr_scheduler.ConstantLR,
+        metrics: metrics_lib.MetricsSchema | None = None,
+        postprocess: batch_postprocess.BatchPostProcess | None = None,
+    ) -> None:
+        """Initializes the neural net head module.
+
+        Args:
+            decoder: The decoder model.
+            criterion: The loss function to use.
+            encoder: The encoder model. If `None`, it will be expected
+                that the input batch returns the features directly.
+                If passed as a dictionary, it will be parsed to an object
+                during the `configure_model` step.
+            lr_multiplier_encoder: The learning rate multiplier for the
+                encoder parameters. If `0`, it will freeze the encoder.
+            optimizer: The optimizer to use.
+            lr_scheduler: The learning rate scheduler to use.
+            metrics: The metric groups to track.
+            postprocess: A list of helper functions to apply after the
+                loss and before the metrics calculation to the model
+                predictions and targets.
+        """
+        super().__init__(metrics=metrics, postprocess=postprocess)
+
+        self.decoder = decoder
+        self.criterion = criterion
+        self.encoder = encoder  # type: ignore
+        self.lr_multiplier_encoder = lr_multiplier_encoder
+        self.optimizer = optimizer
+        self.lr_scheduler = lr_scheduler
+
+    @override
+    def configure_model(self) -> None:
+        self._freeze_encoder()
+
+        if isinstance(self.encoder, dict):
+            self.encoder: Callable[[torch.Tensor], List[torch.Tensor]] = parser.parse_object(
+                self.encoder,
+                expected_type=nn.Module,
+            )
+
+    @override
+    def configure_optimizers(self) -> Any:
+        optimizer = self.optimizer(
+            [
+                {"params": self.decoder.parameters()},
+                {
+                    "params": self._encoder_trainable_parameters(),
+                    "lr": self._base_lr * self.lr_multiplier_encoder,
+                },
+            ]
+        )
+        lr_scheduler = self.lr_scheduler(optimizer)
+        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
+
+    @override
+    def forward(
+        self,
+        inputs: torch.Tensor,
+        to_size: Tuple[int, int] | None = None,
+        *args: Any,
+        **kwargs: Any,
+    ) -> torch.Tensor:
+        """Maps the input tensor (image tensor or embeddings) to masks.
+
+        If `inputs` is an image tensor, then `self.encoder`
+        should be implemented, otherwise it will be interpreted
+        as embeddings, where the `to_size` should be given.
+        """
+        if self.encoder is None and to_size is None:
+            raise ValueError(
+                "Please provide the expected `to_size` that the "
+                "decoder should map the embeddings (`inputs`) to."
+            )
+
+        patch_embeddings = self.encoder(inputs) if self.encoder else inputs
+        return self.decoder(patch_embeddings, to_size or inputs.shape[-2:])
+
+    @override
+    def training_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        return self._batch_step(batch)
+
+    @override
+    def validation_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        return self._batch_step(batch)
+
+    @override
+    def test_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        return self._batch_step(batch)
+
+    @override
+    def predict_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> torch.Tensor:
+        tensor = INPUT_BATCH(*batch).data
+        return self.encoder(tensor) if isinstance(self.encoder, nn.Module) else tensor
+
+    @property
+    def _base_lr(self) -> float:
+        """Returns the base learning rate."""
+        base_optimizer = self.optimizer(self.parameters())
+        return base_optimizer.param_groups[-1]["lr"]
+
+    def _encoder_trainable_parameters(self) -> Iterable[torch.Tensor]:
+        """Returns the trainable parameters of the encoder."""
+        return (
+            self.encoder.parameters()
+            if isinstance(self.encoder, nn.Module) and self.lr_multiplier_encoder > 0
+            else iter(())
+        )
+
+    def _freeze_encoder(self) -> None:
+        """If initialized, it freezes the encoder network."""
+        if isinstance(self.encoder, nn.Module) and self.lr_multiplier_encoder == 0:
+            grad.deactivate_requires_grad(self.encoder)
+
+    def _batch_step(self, batch: INPUT_TENSOR_BATCH) -> STEP_OUTPUT:
+        """Performs a model forward step and calculates the loss.
+
+        Args:
+            batch: The desired batch to process.
+
+        Returns:
+            The batch step output.
+        """
+        data, targets, metadata = INPUT_TENSOR_BATCH(*batch)
+        predictions = self(data, to_size=targets.shape[-2:])
+        loss = self.criterion(predictions, targets)
+        return {
+            "loss": loss,
+            "targets": targets,
+            "predictions": predictions,
+            "metadata": metadata,
+        }
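The decoder contract implied by `forward` above is `decoder(patch_embeddings, to_size)`; a hypothetical stand-in decoder (not one of the package's `decoders` classes) illustrating the shapes:

import torch
from torch import nn


class ToyDecoder(nn.Module):
    """Hypothetical decoder: B x N x D patch embeddings -> B x C x H x W masks."""

    def __init__(self, embed_dim: int = 384, num_classes: int = 3) -> None:
        super().__init__()
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, patch_embeddings: torch.Tensor, to_size) -> torch.Tensor:
        logits = self.classifier(patch_embeddings)  # B x N x C
        b, n, c = logits.shape
        side = int(n**0.5)  # assumes a square patch grid
        grid = logits.permute(0, 2, 1).reshape(b, c, side, side)
        return nn.functional.interpolate(grid, size=to_size, mode="bilinear")


embeddings = torch.randn(2, 196, 384)  # e.g. 14 x 14 ViT patch tokens
masks = ToyDecoder()(embeddings, to_size=(224, 224))  # -> 2 x 3 x 224 x 224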
eva/vision/models/networks/backbones/__init__.py
ADDED
@@ -0,0 +1,6 @@
+"""Vision Model Backbones API."""
+
+from eva.vision.models.networks.backbones import pathology, timm, universal
+from eva.vision.models.networks.backbones.registry import BackboneModelRegistry, register_model
+
+__all__ = ["pathology", "timm", "universal", "BackboneModelRegistry", "register_model"]
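Based on the decorator usage in the pathology modules below, registering a custom backbone presumably looks like this (hypothetical entry):

from torch import nn

from eva.vision.models.networks.backbones import register_model


@register_model("custom/identity")  # hypothetical registry name
def custom_identity() -> nn.Module:
    """Toy backbone factory; real entries return pretrained encoders."""
    return nn.Identity()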
eva/vision/models/networks/backbones/_utils.py
ADDED
@@ -0,0 +1,39 @@
+"""Utils for backbone networks."""
+
+from typing import Any, Dict, Tuple
+
+from torch import nn
+
+from eva import models
+from eva.core.models import transforms
+
+
+def load_hugingface_model(
+    model_name: str,
+    out_indices: int | Tuple[int, ...] | None,
+    model_kwargs: Dict[str, Any] | None = None,
+    transform_args: Dict[str, Any] | None = None,
+) -> nn.Module:
+    """Helper function to load HuggingFace models.
+
+    Args:
+        model_name: The model name to load.
+        out_indices: Whether and which multi-level patch embeddings to return.
+            Currently only out_indices=1 is supported.
+        model_kwargs: The arguments used for instantiating the model.
+        transform_args: The arguments used for instantiating the transform.
+
+    Returns: The model instance.
+    """
+    if out_indices is None:
+        tensor_transforms = transforms.ExtractCLSFeatures(**(transform_args or {}))
+    elif out_indices == 1:
+        tensor_transforms = transforms.ExtractPatchFeatures(**(transform_args or {}))
+    else:
+        raise ValueError(f"out_indices={out_indices} is currently not supported.")
+
+    return models.HuggingFaceModel(
+        model_name_or_path=model_name,
+        tensor_transforms=tensor_transforms,
+        model_kwargs=model_kwargs,
+    )
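A hedged call sketch (the model id is taken from the `owkin_phikon` entry in the file list; exact behavior depends on `eva.models.HuggingFaceModel`, and weights are downloaded on first use):

from eva.vision.models.networks.backbones import _utils

# out_indices=None wraps the HF model with ExtractCLSFeatures (CLS token only);
# out_indices=1 selects ExtractPatchFeatures; any other value raises ValueError.
backbone = _utils.load_hugingface_model(model_name="owkin/phikon", out_indices=None)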
eva/vision/models/networks/backbones/pathology/__init__.py
ADDED
@@ -0,0 +1,31 @@
+"""Vision Pathology Model Backbones API."""
+
+from eva.vision.models.networks.backbones.pathology.bioptimus import bioptimus_h_optimus_0
+from eva.vision.models.networks.backbones.pathology.gigapath import prov_gigapath
+from eva.vision.models.networks.backbones.pathology.histai import histai_hibou_b, histai_hibou_l
+from eva.vision.models.networks.backbones.pathology.kaiko import (
+    kaiko_vitb8,
+    kaiko_vitb16,
+    kaiko_vitl14,
+    kaiko_vits8,
+    kaiko_vits16,
+)
+from eva.vision.models.networks.backbones.pathology.lunit import lunit_vits8, lunit_vits16
+from eva.vision.models.networks.backbones.pathology.mahmood import mahmood_uni
+from eva.vision.models.networks.backbones.pathology.owkin import owkin_phikon
+
+__all__ = [
+    "kaiko_vitb16",
+    "kaiko_vitb8",
+    "kaiko_vitl14",
+    "kaiko_vits16",
+    "kaiko_vits8",
+    "owkin_phikon",
+    "lunit_vits16",
+    "lunit_vits8",
+    "mahmood_uni",
+    "bioptimus_h_optimus_0",
+    "prov_gigapath",
+    "histai_hibou_b",
+    "histai_hibou_l",
+]
eva/vision/models/networks/backbones/pathology/bioptimus.py
ADDED
@@ -0,0 +1,34 @@
+"""Pathology FMs from Bioptimus."""
+
+from typing import Tuple
+
+import timm
+from torch import nn
+
+from eva.vision.models.networks.backbones.registry import register_model
+
+
+@register_model("pathology/bioptimus_h_optimus_0")
+def bioptimus_h_optimus_0(
+    dynamic_img_size: bool = True,
+    out_indices: int | Tuple[int, ...] | None = None,
+) -> nn.Module:
+    """Initializes the h_optimus_0 pathology FM by Bioptimus.
+
+    Args:
+        dynamic_img_size: Whether to allow the interpolation embedding
+            to be interpolated at `forward()` time when image grid changes
+            from original.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The model instance.
+    """
+    return timm.create_model(
+        model_name="hf-hub:bioptimus/H-optimus-0",
+        pretrained=True,
+        init_values=1e-5,
+        dynamic_img_size=dynamic_img_size,
+        out_indices=out_indices,
+        features_only=out_indices is not None,
+    )
eva/vision/models/networks/backbones/pathology/gigapath.py
ADDED
@@ -0,0 +1,33 @@
+"""Pathology FMs from other/mixed entities."""
+
+from typing import Tuple
+
+import timm
+from torch import nn
+
+from eva.vision.models.networks.backbones.registry import register_model
+
+
+@register_model("pathology/prov_gigapath")
+def prov_gigapath(
+    dynamic_img_size: bool = True,
+    out_indices: int | Tuple[int, ...] | None = None,
+) -> nn.Module:
+    """Initializes the Prov-GigaPath pathology FM.
+
+    Args:
+        dynamic_img_size: Whether to allow the interpolation embedding
+            to be interpolated at `forward()` time when image grid changes
+            from original.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The model instance.
+    """
+    return timm.create_model(
+        model_name="hf_hub:prov-gigapath/prov-gigapath",
+        pretrained=True,
+        dynamic_img_size=dynamic_img_size,
+        out_indices=out_indices,
+        features_only=out_indices is not None,
+    )
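In both factories above, `out_indices` toggles timm's `features_only` mode. A sketch with a small public timm model showing the two return types (a recent timm is assumed for `features_only` on ViTs; the pathology weights themselves are large and partly gated):

import timm
import torch

x = torch.randn(1, 3, 224, 224)

# out_indices=None path: a plain classifier-style module.
model = timm.create_model("vit_small_patch16_224", pretrained=False, dynamic_img_size=True)
print(model(x).shape)  # torch.Size([1, 1000])

# out_indices path: features_only=True returns a list of feature maps.
extractor = timm.create_model(
    "vit_small_patch16_224", pretrained=False, features_only=True, out_indices=(11,)
)
print([f.shape for f in extractor(x)])  # [torch.Size([1, 384, 14, 14])]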