kaiko-eva 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eva/core/callbacks/config.py +15 -6
- eva/core/callbacks/writers/embeddings/base.py +44 -10
- eva/core/cli/setup.py +1 -1
- eva/core/data/dataloaders/__init__.py +1 -2
- eva/core/data/samplers/classification/balanced.py +24 -12
- eva/core/data/samplers/random.py +17 -10
- eva/core/interface/interface.py +21 -0
- eva/core/loggers/utils/wandb.py +4 -1
- eva/core/models/modules/module.py +2 -2
- eva/core/models/wrappers/base.py +2 -2
- eva/core/models/wrappers/from_function.py +3 -3
- eva/core/models/wrappers/from_torchhub.py +9 -7
- eva/core/models/wrappers/huggingface.py +4 -5
- eva/core/models/wrappers/onnx.py +5 -5
- eva/core/trainers/trainer.py +13 -1
- eva/core/utils/__init__.py +2 -1
- eva/core/utils/distributed.py +12 -0
- eva/core/utils/paths.py +14 -0
- eva/core/utils/requirements.py +52 -6
- eva/language/__init__.py +2 -1
- eva/language/callbacks/__init__.py +5 -0
- eva/language/callbacks/writers/__init__.py +5 -0
- eva/language/callbacks/writers/prediction.py +201 -0
- eva/language/data/dataloaders/__init__.py +5 -0
- eva/language/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/language/data/dataloaders/collate_fn/text.py +57 -0
- eva/language/data/datasets/__init__.py +3 -1
- eva/language/data/datasets/{language.py → base.py} +1 -1
- eva/language/data/datasets/classification/base.py +3 -43
- eva/language/data/datasets/classification/pubmedqa.py +36 -4
- eva/language/data/datasets/prediction.py +151 -0
- eva/language/data/datasets/schemas.py +18 -0
- eva/language/data/datasets/text.py +92 -0
- eva/language/data/datasets/typings.py +39 -0
- eva/language/data/messages.py +60 -0
- eva/language/models/__init__.py +15 -11
- eva/language/models/modules/__init__.py +2 -2
- eva/language/models/modules/language.py +94 -0
- eva/language/models/networks/__init__.py +12 -0
- eva/language/models/networks/alibaba.py +26 -0
- eva/language/models/networks/api/__init__.py +11 -0
- eva/language/models/networks/api/anthropic.py +34 -0
- eva/language/models/networks/registry.py +5 -0
- eva/language/models/typings.py +56 -0
- eva/language/models/wrappers/__init__.py +13 -5
- eva/language/models/wrappers/base.py +47 -0
- eva/language/models/wrappers/from_registry.py +54 -0
- eva/language/models/wrappers/huggingface.py +57 -11
- eva/language/models/wrappers/litellm.py +91 -46
- eva/language/models/wrappers/vllm.py +37 -13
- eva/language/utils/__init__.py +2 -1
- eva/language/utils/str_to_int_tensor.py +20 -12
- eva/language/utils/text/__init__.py +5 -0
- eva/language/utils/text/messages.py +113 -0
- eva/multimodal/__init__.py +6 -0
- eva/multimodal/callbacks/__init__.py +5 -0
- eva/multimodal/callbacks/writers/__init__.py +5 -0
- eva/multimodal/callbacks/writers/prediction.py +39 -0
- eva/multimodal/data/__init__.py +5 -0
- eva/multimodal/data/dataloaders/__init__.py +5 -0
- eva/multimodal/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/multimodal/data/dataloaders/collate_fn/text_image.py +28 -0
- eva/multimodal/data/datasets/__init__.py +6 -0
- eva/multimodal/data/datasets/base.py +13 -0
- eva/multimodal/data/datasets/multiple_choice/__init__.py +5 -0
- eva/multimodal/data/datasets/multiple_choice/patch_camelyon.py +80 -0
- eva/multimodal/data/datasets/schemas.py +14 -0
- eva/multimodal/data/datasets/text_image.py +77 -0
- eva/multimodal/data/datasets/typings.py +27 -0
- eva/multimodal/models/__init__.py +8 -0
- eva/multimodal/models/modules/__init__.py +5 -0
- eva/multimodal/models/modules/vision_language.py +56 -0
- eva/multimodal/models/networks/__init__.py +14 -0
- eva/multimodal/models/networks/alibaba.py +40 -0
- eva/multimodal/models/networks/api/__init__.py +11 -0
- eva/multimodal/models/networks/api/anthropic.py +34 -0
- eva/multimodal/models/networks/others.py +48 -0
- eva/multimodal/models/networks/registry.py +5 -0
- eva/multimodal/models/typings.py +27 -0
- eva/multimodal/models/wrappers/__init__.py +13 -0
- eva/multimodal/models/wrappers/base.py +48 -0
- eva/multimodal/models/wrappers/from_registry.py +54 -0
- eva/multimodal/models/wrappers/huggingface.py +193 -0
- eva/multimodal/models/wrappers/litellm.py +58 -0
- eva/multimodal/utils/__init__.py +1 -0
- eva/multimodal/utils/batch/__init__.py +5 -0
- eva/multimodal/utils/batch/unpack.py +11 -0
- eva/multimodal/utils/image/__init__.py +5 -0
- eva/multimodal/utils/image/encode.py +28 -0
- eva/multimodal/utils/text/__init__.py +1 -0
- eva/multimodal/utils/text/messages.py +79 -0
- eva/vision/data/datasets/classification/breakhis.py +5 -8
- eva/vision/data/datasets/classification/panda.py +12 -5
- eva/vision/data/datasets/classification/patch_camelyon.py +8 -6
- eva/vision/data/datasets/segmentation/btcv.py +1 -1
- eva/vision/data/datasets/segmentation/consep.py +1 -1
- eva/vision/data/datasets/segmentation/lits17.py +1 -1
- eva/vision/data/datasets/segmentation/monusac.py +15 -6
- eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +1 -1
- eva/vision/data/transforms/__init__.py +2 -1
- eva/vision/data/transforms/base/__init__.py +2 -1
- eva/vision/data/transforms/base/monai.py +2 -2
- eva/vision/data/transforms/base/torchvision.py +33 -0
- eva/vision/data/transforms/common/squeeze.py +6 -3
- eva/vision/data/transforms/croppad/crop_foreground.py +8 -7
- eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +6 -5
- eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -5
- eva/vision/data/transforms/croppad/rand_spatial_crop.py +8 -7
- eva/vision/data/transforms/croppad/spatial_pad.py +6 -6
- eva/vision/data/transforms/intensity/rand_scale_intensity.py +3 -3
- eva/vision/data/transforms/intensity/rand_shift_intensity.py +3 -3
- eva/vision/data/transforms/intensity/scale_intensity_ranged.py +5 -5
- eva/vision/data/transforms/spatial/__init__.py +2 -1
- eva/vision/data/transforms/spatial/flip.py +8 -7
- eva/vision/data/transforms/spatial/functional/__init__.py +5 -0
- eva/vision/data/transforms/spatial/functional/resize.py +26 -0
- eva/vision/data/transforms/spatial/resize.py +63 -0
- eva/vision/data/transforms/spatial/rotate.py +8 -7
- eva/vision/data/transforms/spatial/spacing.py +7 -6
- eva/vision/data/transforms/utility/ensure_channel_first.py +6 -6
- eva/vision/models/networks/backbones/universal/vit.py +24 -0
- eva/vision/models/wrappers/from_registry.py +6 -5
- eva/vision/models/wrappers/from_timm.py +6 -4
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/METADATA +17 -3
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/RECORD +128 -66
- eva/core/data/dataloaders/collate_fn/__init__.py +0 -5
- eva/core/data/dataloaders/collate_fn/collate.py +0 -24
- eva/language/models/modules/text.py +0 -85
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/WHEEL +0 -0
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/licenses/LICENSE +0 -0
eva/vision/data/transforms/base/monai.py
@@ -2,10 +2,10 @@
 
 import abc
 
-from …
+from eva.vision.data.transforms.base.torchvision import TorchvisionTransformV2
 
 
-class RandomMonaiTransform(…
+class RandomMonaiTransform(TorchvisionTransformV2, abc.ABC):
     """Base class for MONAI transform wrappers."""
 
     @abc.abstractmethod
eva/vision/data/transforms/base/torchvision.py
@@ -0,0 +1,33 @@
+"""Base class for torchvision.v2 transforms."""
+
+import abc
+from typing import Any, Dict, List
+
+from torchvision.transforms import v2
+
+
+class TorchvisionTransformV2(v2.Transform, abc.ABC):
+    """Wrapper for torchvision.v2.Transform.
+
+    This class ensures compatibility both with >=0.21.0 and older versions,
+    as torchvision 0.21.0 introduced a new transform API where they
+    renamed the following methods:
+
+    - `_get_params` -> `make_params`
+    - `_transform` -> `transform`
+    """
+
+    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        """Called internally before calling transform() on each input."""
+        return {}
+
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        return self.make_params(flat_inputs)
+
+    @abc.abstractmethod
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        """Applies the transformation to the input."""
+        raise NotImplementedError
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return self.transform(inpt, params)
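The shim above lets eva's transforms implement only the new-style `make_params`/`transform` pair, while the `_get_params`/`_transform` forwarders keep older torchvision releases working. A minimal sketch of a subclass (the `AddOffset` transform and its parameter are hypothetical, for illustration only):

from typing import Any, Dict, List

import torch

from eva.vision.data.transforms.base.torchvision import TorchvisionTransformV2


class AddOffset(TorchvisionTransformV2):
    """Hypothetical transform: adds a constant offset to tensor inputs."""

    def __init__(self, offset: float = 1.0) -> None:
        super().__init__()
        self._offset = offset

    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # Computed once per call and passed to transform() for every input.
        return {"offset": self._offset}

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return inpt + params["offset"] if isinstance(inpt, torch.Tensor) else inpt

On torchvision >= 0.21 the pipeline calls `make_params` and `transform` directly; on older releases it calls `_get_params` and `_transform`, which the base class forwards to the same two methods.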
eva/vision/data/transforms/common/squeeze.py
@@ -4,10 +4,12 @@ from typing import Any
 
 import torch
 from torchvision import tv_tensors
-from …
+from typing_extensions import override
 
+from eva.vision.data.transforms import base
 
-class Squeeze(v2.Transform):
+
+class Squeeze(base.TorchvisionTransformV2):
     """Squeezes the input tensor accross all or specified dimensions."""
 
     def __init__(self, dim: int | list[int] | None = None):
@@ -19,6 +21,7 @@ class Squeeze(v2.Transform):
         super().__init__()
         self._dim = dim
 
-    …
+    @override
+    def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
         output = torch.squeeze(inpt) if self._dim is None else torch.squeeze(inpt, dim=self._dim)
         return tv_tensors.wrap(output, like=inpt)
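A short usage sketch for the updated `Squeeze`, assuming the import path follows the file location listed above (the tensor shape is illustrative):

import torch
from torchvision import tv_tensors

from eva.vision.data.transforms.common.squeeze import Squeeze

squeeze = Squeeze(dim=0)
image = tv_tensors.Image(torch.zeros(1, 16, 16))  # singleton channel dimension

squeezed = squeeze(image)  # wrapped back as tv_tensors.Image with shape (16, 16)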
eva/vision/data/transforms/croppad/crop_foreground.py
@@ -8,13 +8,13 @@ from monai.config import type_definitions
 from monai.transforms.croppad import array as monai_croppad_transforms
 from monai.utils.enums import PytorchPadMode
 from torchvision import tv_tensors
-from torchvision.transforms import v2
 from typing_extensions import override
 
 from eva.vision.data import tv_tensors as eva_tv_tensors
+from eva.vision.data.transforms import base
 
 
-class CropForeground(v2.Transform):
+class CropForeground(base.TorchvisionTransformV2):
     """Crop an image using a bounding box.
 
     The bounding box is generated by selecting foreground using select_fn
@@ -74,19 +74,20 @@ class CropForeground(v2.Transform):
             **pad_kwargs,
         )
 
-    …
+    @override
+    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         volume = next(inpt for inpt in flat_inputs if isinstance(inpt, eva_tv_tensors.Volume))
         box_start, box_end = self._foreground_crop.compute_bounding_box(volume)
         return {"box_start": box_start, "box_end": box_end}
 
     @functools.singledispatchmethod
     @override
-    def …
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
-    @…
-    @…
-    @…
+    @transform.register(tv_tensors.Image)
+    @transform.register(eva_tv_tensors.Volume)
+    @transform.register(tv_tensors.Mask)
     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
         inpt_foreground_cropped = self._foreground_crop.crop_pad(
             inpt, params["box_start"], params["box_end"]
eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py
@@ -56,19 +56,20 @@ class RandCropByLabelClasses(base.RandomMonaiTransform):
     def set_random_state(self, seed: int) -> None:
         self._rand_crop.set_random_state(seed)
 
-    …
+    @override
+    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         mask = next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.Mask))
         self._rand_crop.randomize(label=mask)
         return {}
 
     @functools.singledispatchmethod
     @override
-    def …
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
-    @…
-    @…
-    @…
+    @transform.register(tv_tensors.Image)
+    @transform.register(eva_tv_tensors.Volume)
+    @transform.register(tv_tensors.Mask)
     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
         inpt_foreground_crops = self._rand_crop(img=inpt, randomize=False)
         return [tv_tensors.wrap(crop, like=inpt) for crop in inpt_foreground_crops]
eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py
@@ -95,19 +95,20 @@ class RandCropByPosNegLabel(base.RandomMonaiTransform):
     def set_random_state(self, seed: int) -> None:
         self._rand_crop.set_random_state(seed)
 
-    …
+    @override
+    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         mask = next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.Mask))
         self._rand_crop.randomize(label=mask)
         return {}
 
     @functools.singledispatchmethod
     @override
-    def …
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
-    @…
-    @…
-    @…
+    @transform.register(tv_tensors.Image)
+    @transform.register(eva_tv_tensors.Volume)
+    @transform.register(tv_tensors.Mask)
     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
         inpt_foreground_crops = self._rand_crop(img=inpt, randomize=False)
         return [tv_tensors.wrap(crop, like=inpt) for crop in inpt_foreground_crops]
eva/vision/data/transforms/croppad/rand_spatial_crop.py
@@ -5,14 +5,14 @@ from typing import Any, Dict, List, Sequence, Tuple
 
 from monai.transforms.croppad import array as monai_croppad_transforms
 from torchvision import tv_tensors
-from torchvision.transforms import v2
 from torchvision.transforms.v2 import _utils as tv_utils
 from typing_extensions import override
 
 from eva.vision.data import tv_tensors as eva_tv_tensors
+from eva.vision.data.transforms import base
 
 
-class RandSpatialCrop(v2.Transform):
+class RandSpatialCrop(base.TorchvisionTransformV2):
     """Crop image with random size or specific size ROI.
 
     It can crop at a random position as center or at the image center.
@@ -62,19 +62,20 @@ class RandSpatialCrop(v2.Transform):
         """Set the random state for the transform."""
         self._rand_spatial_crop.set_random_state(seed)
 
-    …
+    @override
+    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         t, h, w = tv_utils.query_chw(flat_inputs)
         self._rand_spatial_crop.randomize((t, h, w))
         return {}
 
     @functools.singledispatchmethod
     @override
-    def …
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
-    @…
-    @…
-    @…
+    @transform.register(tv_tensors.Image)
+    @transform.register(eva_tv_tensors.Volume)
+    @transform.register(tv_tensors.Mask)
     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
         slices = self._get_crop_slices()
         inpt_rand_crop = self._cropper(inpt, slices=slices)
eva/vision/data/transforms/croppad/spatial_pad.py
@@ -6,13 +6,13 @@ from typing import Any, Dict, Sequence
 from monai.transforms.croppad import array as monai_croppad_transforms
 from monai.utils.enums import Method, PytorchPadMode
 from torchvision import tv_tensors
-from torchvision.transforms import v2
 from typing_extensions import override
 
 from eva.vision.data import tv_tensors as eva_tv_tensors
+from eva.vision.data.transforms import base
 
 
-class SpatialPad(v2.Transform):
+class SpatialPad(base.TorchvisionTransformV2):
     """Performs padding to the data.
 
     Padding is applied symmetric for all sides or all on one side for each dimension.
@@ -56,12 +56,12 @@ class SpatialPad(v2.Transform):
 
     @functools.singledispatchmethod
     @override
-    def …
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
-    @…
-    @…
-    @…
+    @transform.register(tv_tensors.Image)
+    @transform.register(eva_tv_tensors.Volume)
+    @transform.register(tv_tensors.Mask)
     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
         inpt_padded = self._spatial_pad(inpt)
         return tv_tensors.wrap(inpt_padded, like=inpt)
eva/vision/data/transforms/intensity/rand_scale_intensity.py
@@ -53,11 +53,11 @@ class RandScaleIntensity(base.RandomMonaiTransform):
 
     @functools.singledispatchmethod
     @override
-    def …
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
-    @…
-    @…
+    @transform.register(tv_tensors.Image)
+    @transform.register(eva_tv_tensors.Volume)
     def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
         inpt_scaled = self._rand_scale_intensity(inpt)
         return tv_tensors.wrap(inpt_scaled, like=inpt)
eva/vision/data/transforms/intensity/rand_shift_intensity.py
@@ -49,11 +49,11 @@ class RandShiftIntensity(base.RandomMonaiTransform):
 
     @functools.singledispatchmethod
     @override
-    def …
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
-    @…
-    @…
+    @transform.register(tv_tensors.Image)
+    @transform.register(eva_tv_tensors.Volume)
     def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
         inpt_scaled = self._rand_shift_intensity(inpt)
         return tv_tensors.wrap(inpt_scaled, like=inpt)
eva/vision/data/transforms/intensity/scale_intensity_ranged.py
@@ -5,13 +5,13 @@ from typing import Any, Dict, Tuple
 
 from monai.transforms.intensity import array as monai_intensity_transforms
 from torchvision import tv_tensors
-from torchvision.transforms import v2
 from typing_extensions import override
 
 from eva.vision.data import tv_tensors as eva_tv_tensors
+from eva.vision.data.transforms import base
 
 
-class ScaleIntensityRange(v2.Transform):
+class ScaleIntensityRange(base.TorchvisionTransformV2):
     """Intensity scaling transform.
 
     Scaling from [a_min, a_max] to [b_min, b_max] with clip option.
@@ -46,11 +46,11 @@ class ScaleIntensityRange(v2.Transform):
 
     @functools.singledispatchmethod
     @override
-    def …
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
-    @…
-    @…
+    @transform.register(tv_tensors.Image)
+    @transform.register(eva_tv_tensors.Volume)
     def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
         inpt_scaled = self._scale_intensity_range(inpt)
         return tv_tensors.wrap(inpt_scaled, like=inpt)
eva/vision/data/transforms/spatial/__init__.py
@@ -1,7 +1,8 @@
 """Transforms for spatial operations."""
 
 from eva.vision.data.transforms.spatial.flip import RandFlip
+from eva.vision.data.transforms.spatial.resize import Resize
 from eva.vision.data.transforms.spatial.rotate import RandRotate90
 from eva.vision.data.transforms.spatial.spacing import Spacing
 
-__all__ = ["Spacing", "RandFlip", "RandRotate90"]
+__all__ = ["Spacing", "RandFlip", "RandRotate90", "Resize"]
eva/vision/data/transforms/spatial/flip.py
@@ -6,13 +6,13 @@ from typing import Any, Dict, List, Sequence
 import torch
 from monai.transforms.spatial import array as monai_spatial_transforms
 from torchvision import tv_tensors
-from torchvision.transforms import v2
 from typing_extensions import override
 
 from eva.vision.data import tv_tensors as eva_tv_tensors
+from eva.vision.data.transforms import base
 
 
-class RandFlip(v2.Transform):
+class RandFlip(base.TorchvisionTransformV2):
     """Randomly flips the image along axes."""
 
     def __init__(
@@ -45,23 +45,24 @@ class RandFlip(v2.Transform):
         else:
             self._flips = [monai_spatial_transforms.RandFlip(prob=prob, spatial_axis=spatial_axes)]
 
-    …
+    @override
+    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         for flip in self._flips:
             flip.randomize(None)
         return {}
 
     @functools.singledispatchmethod
     @override
-    def …
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
-    @…
-    @…
+    @transform.register(tv_tensors.Image)
+    @transform.register(eva_tv_tensors.Volume)
     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
         inpt_flipped = self._apply_flips(inpt)
         return tv_tensors.wrap(inpt_flipped, like=inpt)
 
-    @…
+    @transform.register(tv_tensors.Mask)
     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
         inpt_flipped = torch.tensor(self._apply_flips(inpt), dtype=torch.long)
         return tv_tensors.wrap(inpt_flipped, like=inpt)
eva/vision/data/transforms/spatial/functional/resize.py
@@ -0,0 +1,26 @@
+"""Functional resizing utilities."""
+
+import io
+from typing import Tuple
+
+from PIL import Image
+from torchvision import tv_tensors
+from torchvision.transforms.v2 import functional as F
+
+
+def resize_to_max_bytes(image: tv_tensors.Image, max_bytes: int) -> tv_tensors.Image:
+    """Resize the image to fit within the specified byte size."""
+    image_pil = F.to_pil_image(image)
+    image_bytes = io.BytesIO()
+    image_pil.save(image_bytes, format="PNG", optimize=True)
+
+    while image_bytes.tell() > max_bytes:
+        size: Tuple[int, int] = image_pil.size  # type: ignore
+        w, h = size
+        scale = (max_bytes / image_bytes.tell()) ** 0.5
+        new_size = (max(1, int(h * scale)), max(1, int(w * scale)))
+        image_pil = image_pil.resize(new_size, Image.Resampling.LANCZOS)
+        image_bytes = io.BytesIO()
+        image_pil.save(image_bytes, format="PNG", optimize=True)
+
+    return tv_tensors.Image(F.pil_to_tensor(image_pil))
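The helper shrinks the image by the square root of the byte ratio (encoded size grows roughly with pixel area) and re-encodes it as PNG until it fits the budget. A usage sketch, assuming the `functional` package re-exports the helper as the `Resize` transform's own import suggests; the 100 kB limit and the synthetic image are illustrative:

import io

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

from eva.vision.data.transforms.spatial.functional import resize_to_max_bytes

image = tv_tensors.Image(torch.randint(0, 256, (3, 1024, 1024), dtype=torch.uint8))
resized = resize_to_max_bytes(image, max_bytes=100_000)

# Re-encoding the result as PNG confirms it fits the byte budget.
buffer = io.BytesIO()
F.to_pil_image(resized).save(buffer, format="PNG", optimize=True)
assert buffer.tell() <= 100_000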
eva/vision/data/transforms/spatial/resize.py
@@ -0,0 +1,63 @@
+"""Image resize transforms."""
+
+import functools
+from typing import Any, Dict
+
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from typing_extensions import override
+
+from eva.vision.data.transforms import base
+from eva.vision.data.transforms.spatial import functional
+
+
+class Resize(base.TorchvisionTransformV2):
+    """Resize transform for images with spatial or byte-based constraints.
+
+    This transform provides two mutually exclusive modes of resizing:
+    1. Spatial resizing: Resize to a specific (height, width) dimension
+    2. Byte-based resizing: Resize to fit within a maximum byte size
+
+    The latter is particularly useful for API models (e.g. Claude 3.7) that
+    have strict byte size limits for image inputs.
+    """
+
+    def __init__(self, size: tuple[int, int] | None = None, max_bytes: int | None = None) -> None:
+        """Initializes the transform.
+
+        Args:
+            size: Target size as (height, width) tuple for spatial resizing.
+                If provided, max_bytes must be None.
+            max_bytes: Maximum allowed byte size for the image.
+                If provided, size must be None. Must be a positive integer.
+
+        Raises:
+            ValueError: If both size and max_bytes are provided, or if max_bytes
+                is not a positive integer.
+        """
+        if size is not None and max_bytes is not None:
+            raise ValueError("Cannot provide both 'size' and 'max_bytes' parameters.")
+        if max_bytes is not None and max_bytes <= 0:
+            raise ValueError("'max_bytes' must be a positive integer.")
+
+        super().__init__()
+
+        self.size = size
+        self.max_bytes = max_bytes
+        self.resize_fn = None
+
+        if size is not None:
+            self.resize_fn = v2.Resize(size=size)
+        elif max_bytes is not None:
+            self.resize_fn = functools.partial(functional.resize_to_max_bytes, max_bytes=max_bytes)
+
+    @functools.singledispatchmethod
+    @override
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @transform.register(tv_tensors.Image)
+    @transform.register(tv_tensors.Mask)
+    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        inpt_resized = self.resize_fn(inpt) if self.resize_fn is not None else inpt
+        return tv_tensors.wrap(inpt_resized, like=inpt)
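A sketch of the two mutually exclusive modes, assuming `Resize` is composed like any other torchvision v2 transform (dataset wiring omitted, byte limit illustrative):

from torchvision.transforms import v2

from eva.vision.data.transforms.spatial import Resize

# Spatial mode: fixed (height, width).
to_224 = Resize(size=(224, 224))

# Byte-budget mode: shrink until the PNG encoding fits the limit, e.g. for
# API models with request-size caps.
to_under_5mb = Resize(max_bytes=5_000_000)

# Passing both arguments raises a ValueError, as the docstring above states.
pipeline = v2.Compose([to_under_5mb])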
eva/vision/data/transforms/spatial/rotate.py
@@ -5,13 +5,13 @@ from typing import Any, Dict, List
 
 from monai.transforms.spatial import array as monai_spatial_transforms
 from torchvision import tv_tensors
-from torchvision.transforms import v2
 from typing_extensions import override
 
 from eva.vision.data import tv_tensors as eva_tv_tensors
+from eva.vision.data.transforms import base
 
 
-class RandRotate90(v2.Transform):
+class RandRotate90(base.TorchvisionTransformV2):
     """Rotate input tensors by 90 degrees."""
 
     def __init__(
@@ -36,18 +36,19 @@ class RandRotate90(v2.Transform):
             prob=prob, max_k=max_k, spatial_axes=spatial_axes
         )
 
-    …
+    @override
+    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         self._rotate.randomize()
         return {}
 
     @functools.singledispatchmethod
     @override
-    def …
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
-    @…
-    @…
-    @…
+    @transform.register(tv_tensors.Image)
+    @transform.register(eva_tv_tensors.Volume)
+    @transform.register(tv_tensors.Mask)
     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
         inpt_rotated = self._rotate(img=inpt, randomize=False)
         return tv_tensors.wrap(inpt_rotated, like=inpt)
eva/vision/data/transforms/spatial/spacing.py
@@ -8,13 +8,13 @@ import torch
 from monai.data import meta_tensor
 from monai.transforms.spatial import array as monai_spatial_transforms
 from torchvision import tv_tensors
-from torchvision.transforms import v2
 from typing_extensions import override
 
 from eva.vision.data import tv_tensors as eva_tv_tensors
+from eva.vision.data.transforms import base
 
 
-class Spacing(v2.Transform):
+class Spacing(base.TorchvisionTransformV2):
     """Resample input image into the specified `pixdim`.
 
     - Expects tensors of shape `[C, T, H, W]`.
@@ -43,7 +43,8 @@ class Spacing(v2.Transform):
         self._spacing = monai_spatial_transforms.Spacing(pixdim=pixdim, recompute_affine=True)
         self._affine = None
 
-    …
+    @override
+    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         self._affine = next(
             inpt.affine for inpt in flat_inputs if isinstance(inpt, eva_tv_tensors.Volume)
         )
@@ -51,17 +52,17 @@ class Spacing(v2.Transform):
 
     @functools.singledispatchmethod
     @override
-    def …
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
-    @…
+    @transform.register(eva_tv_tensors.Volume)
     def _(self, inpt: eva_tv_tensors.Volume, params: Dict[str, Any]) -> Any:
         inpt_spacing = self._spacing(inpt.to_meta_tensor(), mode="bilinear")
         if not isinstance(inpt_spacing, meta_tensor.MetaTensor):
             raise ValueError(f"Expected MetaTensor, got {type(inpt_spacing)}")
         return eva_tv_tensors.Volume.from_meta_tensor(inpt_spacing)
 
-    @…
+    @transform.register(tv_tensors.Mask)
     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
         inpt_spacing = self._spacing(
             meta_tensor.MetaTensor(inpt, affine=self._affine), mode="nearest"
eva/vision/data/transforms/utility/ensure_channel_first.py
@@ -5,13 +5,13 @@ from typing import Any, Dict
 
 from monai.transforms.utility import array as monai_utility_transforms
 from torchvision import tv_tensors
-from torchvision.transforms import v2
 from typing_extensions import override
 
 from eva.vision.data import tv_tensors as eva_tv_tensors
+from eva.vision.data.transforms import base
 
 
-class EnsureChannelFirst(v2.Transform):
+class EnsureChannelFirst(base.TorchvisionTransformV2):
     """Adjust or add the channel dimension of input data to ensure `channel_first` shape."""
 
     def __init__(
@@ -40,12 +40,12 @@ class EnsureChannelFirst(v2.Transform):
 
     @functools.singledispatchmethod
     @override
-    def …
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
-    @…
-    @…
-    @…
+    @transform.register(tv_tensors.Image)
+    @transform.register(eva_tv_tensors.Volume)
+    @transform.register(tv_tensors.Mask)
     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
         inpt_channel_first = self._ensure_channel_first(inpt)
         return tv_tensors.wrap(inpt_channel_first, like=inpt)
eva/vision/models/networks/backbones/universal/vit.py
@@ -54,6 +54,30 @@ def vit_small_patch16_224_dino(
     )
 
 
+@backbone_registry.register("universal/vit_tiny_patch16_224_random")
+def vit_tiny_patch16_224_random(
+    dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+) -> nn.Module:
+    """Initializes a ViT-Tiny16 baseline model with random weights.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The torch ViTS-16 based foundation model.
+    """
+    return timm.create_model(
+        model_name="vit_tiny_patch16_224",
+        pretrained=False,
+        num_classes=0,
+        features_only=out_indices is not None,
+        out_indices=out_indices,
+        dynamic_img_size=dynamic_img_size,
+    )
+
+
 @backbone_registry.register("universal/vit_small_patch16_224_dino_1chan")
 def vit_small_patch16_224_dino_1chan(
     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
eva/vision/models/wrappers/from_registry.py
@@ -3,6 +3,7 @@
 from typing import Any, Callable, Dict
 
 import torch
+from torch import nn
 from typing_extensions import override
 
 from eva.core.models.wrappers import base
@@ -40,14 +41,14 @@ class ModelFromRegistry(base.BaseModel[torch.Tensor, torch.Tensor]):
         self._model_kwargs = model_kwargs or {}
         self._model_extra_kwargs = model_extra_kwargs or {}
 
-        self.load_model()
+        self.model = self.load_model()
 
     @override
-    def load_model(self) -> …
-        …
+    def load_model(self) -> nn.Module:
+        ModelFromRegistry.__name__ = self._model_name
+
+        return factory.ModuleFactory(
             registry=backbone_registry,
             name=self._model_name,
             init_args=self._model_kwargs | self._model_extra_kwargs,
         )
-
-        ModelFromRegistry.__name__ = self._model_name