kaiko-eva 0.1.8__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff compares two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- eva/core/data/datasets/base.py +7 -2
- eva/core/data/datasets/classification/embeddings.py +2 -2
- eva/core/data/datasets/classification/multi_embeddings.py +2 -2
- eva/core/data/datasets/embeddings.py +4 -4
- eva/core/data/samplers/classification/balanced.py +19 -18
- eva/core/loggers/utils/wandb.py +33 -0
- eva/core/models/modules/head.py +5 -3
- eva/core/models/modules/typings.py +2 -2
- eva/core/models/transforms/__init__.py +2 -1
- eva/core/models/transforms/as_discrete.py +57 -0
- eva/core/models/wrappers/_utils.py +121 -1
- eva/core/trainers/functional.py +8 -5
- eva/core/trainers/trainer.py +32 -17
- eva/core/utils/suppress_logs.py +28 -0
- eva/vision/data/__init__.py +2 -2
- eva/vision/data/dataloaders/__init__.py +5 -0
- eva/vision/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/vision/data/dataloaders/collate_fn/collection.py +22 -0
- eva/vision/data/datasets/__init__.py +10 -2
- eva/vision/data/datasets/classification/__init__.py +9 -0
- eva/vision/data/datasets/classification/bach.py +3 -4
- eva/vision/data/datasets/classification/bracs.py +111 -0
- eva/vision/data/datasets/classification/breakhis.py +209 -0
- eva/vision/data/datasets/classification/camelyon16.py +4 -5
- eva/vision/data/datasets/classification/crc.py +3 -4
- eva/vision/data/datasets/classification/gleason_arvaniti.py +171 -0
- eva/vision/data/datasets/classification/mhist.py +3 -4
- eva/vision/data/datasets/classification/panda.py +4 -5
- eva/vision/data/datasets/classification/patch_camelyon.py +3 -4
- eva/vision/data/datasets/classification/unitopatho.py +158 -0
- eva/vision/data/datasets/classification/wsi.py +6 -5
- eva/vision/data/datasets/segmentation/__init__.py +2 -2
- eva/vision/data/datasets/segmentation/_utils.py +47 -0
- eva/vision/data/datasets/segmentation/bcss.py +7 -8
- eva/vision/data/datasets/segmentation/btcv.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +6 -7
- eva/vision/data/datasets/segmentation/embeddings.py +2 -2
- eva/vision/data/datasets/segmentation/lits.py +9 -8
- eva/vision/data/datasets/segmentation/lits_balanced.py +2 -1
- eva/vision/data/datasets/segmentation/monusac.py +4 -5
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +12 -10
- eva/vision/data/datasets/vision.py +95 -4
- eva/vision/data/datasets/wsi.py +5 -5
- eva/vision/data/transforms/__init__.py +22 -3
- eva/vision/data/transforms/common/__init__.py +1 -2
- eva/vision/data/transforms/croppad/__init__.py +11 -0
- eva/vision/data/transforms/croppad/crop_foreground.py +110 -0
- eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +109 -0
- eva/vision/data/transforms/croppad/spatial_pad.py +67 -0
- eva/vision/data/transforms/intensity/__init__.py +11 -0
- eva/vision/data/transforms/intensity/rand_scale_intensity.py +59 -0
- eva/vision/data/transforms/intensity/rand_shift_intensity.py +55 -0
- eva/vision/data/transforms/intensity/scale_intensity_ranged.py +56 -0
- eva/vision/data/transforms/spatial/__init__.py +7 -0
- eva/vision/data/transforms/spatial/flip.py +72 -0
- eva/vision/data/transforms/spatial/rotate.py +53 -0
- eva/vision/data/transforms/spatial/spacing.py +69 -0
- eva/vision/data/transforms/utility/__init__.py +5 -0
- eva/vision/data/transforms/utility/ensure_channel_first.py +51 -0
- eva/vision/data/tv_tensors/__init__.py +5 -0
- eva/vision/data/tv_tensors/volume.py +61 -0
- eva/vision/metrics/segmentation/monai_dice.py +9 -2
- eva/vision/models/modules/semantic_segmentation.py +28 -20
- eva/vision/models/networks/backbones/__init__.py +9 -2
- eva/vision/models/networks/backbones/pathology/__init__.py +11 -2
- eva/vision/models/networks/backbones/pathology/bioptimus.py +47 -1
- eva/vision/models/networks/backbones/pathology/hkust.py +69 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +18 -0
- eva/vision/models/networks/backbones/pathology/mahmood.py +46 -19
- eva/vision/models/networks/backbones/radiology/__init__.py +11 -0
- eva/vision/models/networks/backbones/radiology/swin_unetr.py +231 -0
- eva/vision/models/networks/backbones/radiology/voco.py +75 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +6 -2
- eva/vision/models/networks/decoders/segmentation/linear.py +5 -10
- eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +8 -1
- eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py +104 -0
- eva/vision/utils/io/__init__.py +2 -0
- eva/vision/utils/io/nifti.py +91 -11
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/METADATA +3 -1
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/RECORD +83 -62
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/WHEEL +1 -1
- eva/vision/data/datasets/classification/base.py +0 -96
- eva/vision/data/datasets/segmentation/base.py +0 -96
- eva/vision/data/transforms/common/resize_and_clamp.py +0 -51
- eva/vision/data/transforms/normalization/__init__.py +0 -6
- eva/vision/data/transforms/normalization/clamp.py +0 -43
- eva/vision/data/transforms/normalization/functional/__init__.py +0 -5
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +0 -28
- eva/vision/data/transforms/normalization/rescale_intensity.py +0 -53
- eva/vision/metrics/segmentation/BUILD +0 -1
- eva/vision/models/networks/backbones/torchhub/__init__.py +0 -5
- eva/vision/models/networks/backbones/torchhub/backbones.py +0 -61
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/licenses/LICENSE +0 -0

--- a/eva/vision/data/datasets/classification/base.py
+++ /dev/null
@@ -1,96 +0,0 @@
-"""Base for image classification datasets."""
-
-import abc
-from typing import Any, Callable, Dict, List, Tuple
-
-import torch
-from torchvision import tv_tensors
-from typing_extensions import override
-
-from eva.vision.data.datasets import vision
-
-
-class ImageClassification(vision.VisionDataset[Tuple[tv_tensors.Image, torch.Tensor]], abc.ABC):
-    """Image classification abstract dataset."""
-
-    def __init__(
-        self,
-        transforms: Callable | None = None,
-    ) -> None:
-        """Initializes the image classification dataset.
-
-        Args:
-            transforms: A function/transform which returns a transformed
-                version of the raw data samples.
-        """
-        super().__init__()
-
-        self._transforms = transforms
-
-    @property
-    def classes(self) -> List[str] | None:
-        """Returns the list with names of the dataset names."""
-
-    @property
-    def class_to_idx(self) -> Dict[str, int] | None:
-        """Returns a mapping of the class name to its target index."""
-
-    def load_metadata(self, index: int) -> Dict[str, Any] | None:
-        """Returns the dataset metadata.
-
-        Args:
-            index: The index of the data sample to return the metadata of.
-
-        Returns:
-            The sample metadata.
-        """
-
-    @abc.abstractmethod
-    def load_image(self, index: int) -> tv_tensors.Image:
-        """Returns the `index`'th image sample.
-
-        Args:
-            index: The index of the data sample to load.
-
-        Returns:
-            The image as a numpy array.
-        """
-
-    @abc.abstractmethod
-    def load_target(self, index: int) -> torch.Tensor:
-        """Returns the `index`'th target sample.
-
-        Args:
-            index: The index of the data sample to load.
-
-        Returns:
-            The sample target as an array.
-        """
-
-    @abc.abstractmethod
-    @override
-    def __len__(self) -> int:
-        raise NotImplementedError
-
-    @override
-    def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
-        image = self.load_image(index)
-        target = self.load_target(index)
-        image, target = self._apply_transforms(image, target)
-        return image, target, self.load_metadata(index) or {}
-
-    def _apply_transforms(
-        self, image: tv_tensors.Image, target: torch.Tensor
-    ) -> Tuple[tv_tensors.Image, torch.Tensor]:
-        """Applies the transforms to the provided data and returns them.
-
-        Args:
-            image: The desired image.
-            target: The target of the image.
-
-        Returns:
-            A tuple with the image and the target transformed.
-        """
-        if self._transforms is not None:
-            image, target = self._transforms(image, target)
-        return image, target
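
For orientation, the removed `ImageClassification` base required subclasses to implement `load_image`, `load_target` and `__len__`, with `__getitem__` wiring them together through `_apply_transforms`. A minimal hypothetical sketch of such a subclass follows; the dataset name, file list and label source are illustrative, not part of eva:

```python
import torch
from torchvision import tv_tensors
from torchvision.io import read_image
from typing_extensions import override


class ToyPatchDataset(ImageClassification):  # the base class removed above
    """Illustrative subclass; `paths` and `labels` are assumed inputs."""

    def __init__(self, paths, labels, transforms=None):
        super().__init__(transforms=transforms)
        self._paths, self._labels = paths, labels

    @override
    def load_image(self, index: int) -> tv_tensors.Image:
        # Decode the image file into a (C, H, W) uint8 tensor.
        return tv_tensors.Image(read_image(self._paths[index]))

    @override
    def load_target(self, index: int) -> torch.Tensor:
        return torch.tensor(self._labels[index], dtype=torch.long)

    @override
    def __len__(self) -> int:
        return len(self._paths)
```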

--- a/eva/vision/data/datasets/segmentation/base.py
+++ /dev/null
@@ -1,96 +0,0 @@
-"""Base for image segmentation datasets."""
-
-import abc
-from typing import Any, Callable, Dict, List, Tuple
-
-from torchvision import tv_tensors
-from typing_extensions import override
-
-from eva.vision.data.datasets import vision
-
-
-class ImageSegmentation(vision.VisionDataset[Tuple[tv_tensors.Image, tv_tensors.Mask]], abc.ABC):
-    """Image segmentation abstract dataset."""
-
-    def __init__(self, transforms: Callable | None = None) -> None:
-        """Initializes the image segmentation base class.
-
-        Args:
-            transforms: A function/transforms that takes in an
-                image and a label and returns the transformed versions of both.
-        """
-        super().__init__()
-
-        self._transforms = transforms
-
-    @property
-    def classes(self) -> List[str] | None:
-        """Returns the list with names of the dataset names."""
-
-    @property
-    def class_to_idx(self) -> Dict[str, int] | None:
-        """Returns a mapping of the class name to its target index."""
-
-    @abc.abstractmethod
-    def load_image(self, index: int) -> tv_tensors.Image:
-        """Loads and returns the `index`'th image sample.
-
-        Args:
-            index: The index of the data sample to load.
-
-        Returns:
-            An image torchvision tensor (channels, height, width).
-        """
-
-    @abc.abstractmethod
-    def load_mask(self, index: int) -> tv_tensors.Mask:
-        """Returns the `index`'th target masks sample.
-
-        Args:
-            index: The index of the data sample target masks to load.
-
-        Returns:
-            The semantic mask as a (H x W) shaped tensor with integer
-                values which represent the pixel class id.
-        """
-
-    def load_metadata(self, index: int) -> Dict[str, Any] | None:
-        """Returns the dataset metadata.
-
-        Args:
-            index: The index of the data sample to return the metadata of.
-                If `None`, it will return the metadata of the current dataset.
-
-        Returns:
-            The sample metadata.
-        """
-
-    @abc.abstractmethod
-    @override
-    def __len__(self) -> int:
-        raise NotImplementedError
-
-    @override
-    def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, tv_tensors.Mask, Dict[str, Any]]:
-        image = self.load_image(index)
-        mask = self.load_mask(index)
-        metadata = self.load_metadata(index) or {}
-        image_tensor, mask_tensor = self._apply_transforms(image, mask)
-        return image_tensor, mask_tensor, metadata
-
-    def _apply_transforms(
-        self, image: tv_tensors.Image, mask: tv_tensors.Mask
-    ) -> Tuple[tv_tensors.Image, tv_tensors.Mask]:
-        """Applies the transforms to the provided data and returns them.
-
-        Args:
-            image: The desired image.
-            mask: The target segmentation mask.
-
-        Returns:
-            A tuple with the image and the masks transformed.
-        """
-        if self._transforms is not None:
-            image, mask = self._transforms(image, mask)
-
-        return image, mask
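
Note that `_apply_transforms` in this removed base passed the image and mask to a single callable together; with torchvision's v2 transforms, that is what kept random spatial augmentations consistent across both. A small sketch of such a joint pipeline, with illustrative parameter values:

```python
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

# Geometric ops are applied identically to Image and Mask inputs; passing
# a single dtype to ToDtype converts only images, so the integer class-id
# mask passes through unchanged.
joint_transforms = v2.Compose([
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomCrop(size=224),
    v2.ToDtype(torch.float32, scale=True),
])

image = tv_tensors.Image(torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8))
mask = tv_tensors.Mask(torch.zeros(256, 256, dtype=torch.long))
image_t, mask_t = joint_transforms(image, mask)  # same flip and crop for both
```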

--- a/eva/vision/data/transforms/common/resize_and_clamp.py
+++ /dev/null
@@ -1,51 +0,0 @@
-"""Specialized transforms for resizing, clamping and range normalizing."""
-
-from typing import Callable, Sequence, Tuple
-
-from torchvision.transforms import v2
-
-from eva.vision.data.transforms import normalization
-
-
-class ResizeAndClamp(v2.Compose):
-    """Resizes, crops, clamps and normalizes an input image."""
-
-    def __init__(
-        self,
-        size: int | Sequence[int] = 224,
-        clamp_range: Tuple[int, int] = (-1024, 1024),
-        mean: Sequence[float] = (0.0, 0.0, 0.0),
-        std: Sequence[float] = (1.0, 1.0, 1.0),
-    ) -> None:
-        """Initializes the transform object.
-
-        Args:
-            size: Desired output size of the crop. If size is an `int` instead
-                of sequence like (h, w), a square crop (size, size) is made.
-            clamp_range: The lower and upper bound to clamp the pixel values.
-            mean: Sequence of means for each image channel.
-            std: Sequence of standard deviations for each image channel.
-        """
-        self._size = size
-        self._clamp_range = clamp_range
-        self._mean = mean
-        self._std = std
-
-        super().__init__(transforms=self._build_transforms())
-
-    def _build_transforms(self) -> Sequence[Callable]:
-        """Builds and returns the list of transforms."""
-        transforms = [
-            v2.Resize(size=self._size),
-            v2.CenterCrop(size=self._size),
-            normalization.Clamp(out_range=self._clamp_range),
-            normalization.RescaleIntensity(
-                in_range=self._clamp_range,
-                out_range=(0.0, 1.0),
-            ),
-            v2.Normalize(
-                mean=self._mean,
-                std=self._std,
-            ),
-        ]
-        return transforms
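
The default `clamp_range=(-1024, 1024)` suggests this pipeline targeted CT-style intensity values. A usage sketch of the removed class on a synthetic input; the tensor values and sizes are illustrative, and the example uses 3 channels because `v2.Normalize` expects one mean/std entry per image channel:

```python
import torch
from torchvision import tv_tensors

# Fake 3-channel slice with out-of-range intensities.
image = tv_tensors.Image(torch.empty(3, 512, 512).uniform_(-2000.0, 3000.0))

transform = ResizeAndClamp(size=224, clamp_range=(-1024, 1024))
output = transform(image)
# output: 224x224, clamped to [-1024, 1024], rescaled to [0, 1],
# then normalized with the given mean/std (identity by default).
```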

--- a/eva/vision/data/transforms/normalization/clamp.py
+++ /dev/null
@@ -1,43 +0,0 @@
-"""Image clamp transform."""
-
-import functools
-from typing import Any, Dict, Tuple
-
-import torch
-import torchvision.transforms.v2 as torch_transforms
-from torchvision import tv_tensors
-from typing_extensions import override
-
-
-class Clamp(torch_transforms.Transform):
-    """Clamps all elements in input into a specific range."""
-
-    def __init__(self, out_range: Tuple[int, int]) -> None:
-        """Initializes the transform.
-
-        Args:
-            out_range: The lower and upper bound of the range to
-                be clamped to.
-        """
-        super().__init__()
-
-        self._out_range = out_range
-
-    @functools.singledispatchmethod
-    @override
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return inpt
-
-    @_transform.register(torch.Tensor)
-    def _(self, inpt: torch.Tensor, params: Dict[str, Any]) -> Any:
-        return torch.clamp(inpt, min=self._out_range[0], max=self._out_range[1])
-
-    @_transform.register(tv_tensors.Image)
-    def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
-        inpt_clamp = torch.clamp(inpt, min=self._out_range[0], max=self._out_range[1])
-        return tv_tensors.wrap(inpt_clamp, like=inpt)
-
-    @_transform.register(tv_tensors.BoundingBoxes)
-    @_transform.register(tv_tensors.Mask)
-    def _(self, inpt: tv_tensors.BoundingBoxes | tv_tensors.Mask, params: Dict[str, Any]) -> Any:
-        return inpt
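
The `functools.singledispatchmethod` pattern used here dispatches `_transform` on the runtime type of the input, which is how the transform clamps plain tensors and `Image`s while passing `Mask` and `BoundingBoxes` through untouched. A stripped-down illustration of the same dispatch mechanism (the class and its methods are illustrative only):

```python
import functools


class TypeDispatchedOp:
    """Dispatches on the type of the first non-self argument."""

    @functools.singledispatchmethod
    def apply(self, value):
        return value  # default: pass through unchanged

    @apply.register(int)
    def _(self, value: int):
        return value * 2

    @apply.register(str)
    def _(self, value: str):
        return value.upper()


op = TypeDispatchedOp()
assert op.apply(3) == 6        # int overload
assert op.apply("ab") == "AB"  # str overload
assert op.apply(1.5) == 1.5    # falls back to the default
```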

--- a/eva/vision/data/transforms/normalization/functional/rescale_intensity.py
+++ /dev/null
@@ -1,28 +0,0 @@
-"""Intensity level functions."""
-
-import sys
-from typing import Tuple
-
-import torch
-
-
-def rescale_intensity(
-    image: torch.Tensor,
-    in_range: Tuple[float, float] | None = None,
-    out_range: Tuple[float, float] = (0.0, 1.0),
-) -> torch.Tensor:
-    """Stretches or shrinks the image intensity levels.
-
-    Args:
-        image: The image tensor as float-type.
-        in_range: The input data range. If `None`, it will
-            fetch the min and max of the input image.
-        out_range: The desired intensity range of the output.
-
-    Returns:
-        The image tensor after stretching or shrinking its intensity levels.
-    """
-    imin, imax = in_range or (image.min(), image.max())
-    omin, omax = out_range
-    image_scaled = (image - imin) / (imax - imin + sys.float_info.epsilon)
-    return image_scaled * (omax - omin) + omin
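
As a quick check of the arithmetic above: with `in_range=(-1024, 1024)` the denominator is 2048, so an input of 0 maps to (0 + 1024) / 2048 = 0.5 in the default `(0.0, 1.0)` output range, up to the epsilon guard in the denominator:

```python
import torch

x = torch.tensor([-1024.0, 0.0, 1024.0])
y = rescale_intensity(x, in_range=(-1024, 1024), out_range=(0.0, 1.0))
# y is approximately tensor([0.0, 0.5, 1.0])
```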

--- a/eva/vision/data/transforms/normalization/rescale_intensity.py
+++ /dev/null
@@ -1,53 +0,0 @@
-"""Intensity level scaling transform."""
-
-import functools
-from typing import Any, Dict, Tuple
-
-import torch
-import torchvision.transforms.v2 as torch_transforms
-from torchvision import tv_tensors
-from typing_extensions import override
-
-from eva.vision.data.transforms.normalization import functional
-
-
-class RescaleIntensity(torch_transforms.Transform):
-    """Stretches or shrinks the image intensity levels."""
-
-    def __init__(
-        self,
-        in_range: Tuple[float, float] | None = None,
-        out_range: Tuple[float, float] = (0.0, 1.0),
-    ) -> None:
-        """Initializes the transform.
-
-        Args:
-            in_range: The input data range. If `None`, it will
-                fetch the min and max of the input image.
-            out_range: The desired intensity range of the output.
-        """
-        super().__init__()
-
-        self._in_range = in_range
-        self._out_range = out_range
-
-    @functools.singledispatchmethod
-    @override
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return inpt
-
-    @_transform.register(torch.Tensor)
-    def _(self, inpt: torch.Tensor, params: Dict[str, Any]) -> Any:
-        return functional.rescale_intensity(
-            inpt, in_range=self._in_range, out_range=self._out_range
-        )
-
-    @_transform.register(tv_tensors.Image)
-    def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
-        scaled_inpt = functional.rescale_intensity(inpt, out_range=self._out_range)
-        return tv_tensors.wrap(scaled_inpt, like=inpt)
-
-    @_transform.register(tv_tensors.BoundingBoxes)
-    @_transform.register(tv_tensors.Mask)
-    def _(self, inpt: tv_tensors.BoundingBoxes | tv_tensors.Mask, params: Dict[str, Any]) -> Any:
-        return inpt

--- a/eva/vision/metrics/segmentation/BUILD
+++ /dev/null
@@ -1 +0,0 @@
-python_sources()

--- a/eva/vision/models/networks/backbones/torchhub/backbones.py
+++ /dev/null
@@ -1,61 +0,0 @@
-"""torch.hub backbones."""
-
-import functools
-from typing import Tuple
-
-import torch
-from loguru import logger
-from torch import nn
-
-from eva.core.models import wrappers
-from eva.vision.models.networks.backbones.registry import BackboneModelRegistry
-
-HUB_REPOS = ["facebookresearch/dinov2:main", "kaiko-ai/towards_large_pathology_fms"]
-"""List of torch.hub repositories for which to add the models to the registry."""
-
-
-def torch_hub_model(
-    model_name: str,
-    repo_or_dir: str,
-    checkpoint_path: str | None = None,
-    pretrained: bool = False,
-    out_indices: int | Tuple[int, ...] | None = None,
-    **kwargs,
-) -> nn.Module:
-    """Initializes any ViT model from torch.hub with weights from a specified checkpoint.
-
-    Args:
-        model_name: The name of the model to load.
-        repo_or_dir: The torch.hub repository or local directory to load the model from.
-        checkpoint_path: The path to the checkpoint file.
-        pretrained: If set to `True`, load pretrained model weights if available.
-        out_indices: Whether and which multi-level patch embeddings to return.
-        **kwargs: Additional arguments to pass to the model
-
-    Returns:
-        The VIT model instance.
-    """
-    logger.info(
-        f"Loading torch.hub model {model_name} from {repo_or_dir}"
-        + (f"using checkpoint {checkpoint_path}" if checkpoint_path else "")
-    )
-
-    return wrappers.TorchHubModel(
-        model_name=model_name,
-        repo_or_dir=repo_or_dir,
-        pretrained=pretrained,
-        checkpoint_path=checkpoint_path or "",
-        out_indices=out_indices,
-        model_kwargs=kwargs,
-    )
-
-
-BackboneModelRegistry._registry.update(
-    {
-        f"torchhub/{repo}:{model_name}": functools.partial(
-            torch_hub_model, model_name=model_name, repo_or_dir=repo
-        )
-        for repo in HUB_REPOS
-        for model_name in torch.hub.list(repo, verbose=False)
-    }
-)
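
For reference, the removed registry population was built on standard `torch.hub` calls, listing each repo's entry points at import time. A direct-usage sketch of those underlying APIs (requires network access; `dinov2_vits14` is one of the entry points the DINOv2 repo exposes):

```python
import torch

# List the entry points a hub repo exposes, then load one of them.
entry_points = torch.hub.list("facebookresearch/dinov2:main", verbose=False)
model = torch.hub.load("facebookresearch/dinov2:main", "dinov2_vits14")
model.eval()
```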

{kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/entry_points.txt
File without changes

{kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/licenses/LICENSE
File without changes