kaiko-eva 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kaiko-eva might be problematic.

Files changed (131)
  1. eva/core/callbacks/config.py +15 -6
  2. eva/core/callbacks/writers/embeddings/base.py +44 -10
  3. eva/core/cli/setup.py +1 -1
  4. eva/core/data/dataloaders/__init__.py +1 -2
  5. eva/core/data/samplers/classification/balanced.py +24 -12
  6. eva/core/data/samplers/random.py +17 -10
  7. eva/core/interface/interface.py +21 -0
  8. eva/core/loggers/utils/wandb.py +4 -1
  9. eva/core/models/modules/module.py +2 -2
  10. eva/core/models/wrappers/base.py +2 -2
  11. eva/core/models/wrappers/from_function.py +3 -3
  12. eva/core/models/wrappers/from_torchhub.py +9 -7
  13. eva/core/models/wrappers/huggingface.py +4 -5
  14. eva/core/models/wrappers/onnx.py +5 -5
  15. eva/core/trainers/trainer.py +13 -1
  16. eva/core/utils/__init__.py +2 -1
  17. eva/core/utils/distributed.py +12 -0
  18. eva/core/utils/paths.py +14 -0
  19. eva/core/utils/requirements.py +52 -6
  20. eva/language/__init__.py +2 -1
  21. eva/language/callbacks/__init__.py +5 -0
  22. eva/language/callbacks/writers/__init__.py +5 -0
  23. eva/language/callbacks/writers/prediction.py +201 -0
  24. eva/language/data/dataloaders/__init__.py +5 -0
  25. eva/language/data/dataloaders/collate_fn/__init__.py +5 -0
  26. eva/language/data/dataloaders/collate_fn/text.py +57 -0
  27. eva/language/data/datasets/__init__.py +3 -1
  28. eva/language/data/datasets/{language.py → base.py} +1 -1
  29. eva/language/data/datasets/classification/base.py +3 -43
  30. eva/language/data/datasets/classification/pubmedqa.py +36 -4
  31. eva/language/data/datasets/prediction.py +151 -0
  32. eva/language/data/datasets/schemas.py +18 -0
  33. eva/language/data/datasets/text.py +92 -0
  34. eva/language/data/datasets/typings.py +39 -0
  35. eva/language/data/messages.py +60 -0
  36. eva/language/models/__init__.py +15 -11
  37. eva/language/models/modules/__init__.py +2 -2
  38. eva/language/models/modules/language.py +94 -0
  39. eva/language/models/networks/__init__.py +12 -0
  40. eva/language/models/networks/alibaba.py +26 -0
  41. eva/language/models/networks/api/__init__.py +11 -0
  42. eva/language/models/networks/api/anthropic.py +34 -0
  43. eva/language/models/networks/registry.py +5 -0
  44. eva/language/models/typings.py +56 -0
  45. eva/language/models/wrappers/__init__.py +13 -5
  46. eva/language/models/wrappers/base.py +47 -0
  47. eva/language/models/wrappers/from_registry.py +54 -0
  48. eva/language/models/wrappers/huggingface.py +57 -11
  49. eva/language/models/wrappers/litellm.py +91 -46
  50. eva/language/models/wrappers/vllm.py +37 -13
  51. eva/language/utils/__init__.py +2 -1
  52. eva/language/utils/str_to_int_tensor.py +20 -12
  53. eva/language/utils/text/__init__.py +5 -0
  54. eva/language/utils/text/messages.py +113 -0
  55. eva/multimodal/__init__.py +6 -0
  56. eva/multimodal/callbacks/__init__.py +5 -0
  57. eva/multimodal/callbacks/writers/__init__.py +5 -0
  58. eva/multimodal/callbacks/writers/prediction.py +39 -0
  59. eva/multimodal/data/__init__.py +5 -0
  60. eva/multimodal/data/dataloaders/__init__.py +5 -0
  61. eva/multimodal/data/dataloaders/collate_fn/__init__.py +5 -0
  62. eva/multimodal/data/dataloaders/collate_fn/text_image.py +28 -0
  63. eva/multimodal/data/datasets/__init__.py +6 -0
  64. eva/multimodal/data/datasets/base.py +13 -0
  65. eva/multimodal/data/datasets/multiple_choice/__init__.py +5 -0
  66. eva/multimodal/data/datasets/multiple_choice/patch_camelyon.py +80 -0
  67. eva/multimodal/data/datasets/schemas.py +14 -0
  68. eva/multimodal/data/datasets/text_image.py +77 -0
  69. eva/multimodal/data/datasets/typings.py +27 -0
  70. eva/multimodal/models/__init__.py +8 -0
  71. eva/multimodal/models/modules/__init__.py +5 -0
  72. eva/multimodal/models/modules/vision_language.py +56 -0
  73. eva/multimodal/models/networks/__init__.py +14 -0
  74. eva/multimodal/models/networks/alibaba.py +40 -0
  75. eva/multimodal/models/networks/api/__init__.py +11 -0
  76. eva/multimodal/models/networks/api/anthropic.py +34 -0
  77. eva/multimodal/models/networks/others.py +48 -0
  78. eva/multimodal/models/networks/registry.py +5 -0
  79. eva/multimodal/models/typings.py +27 -0
  80. eva/multimodal/models/wrappers/__init__.py +13 -0
  81. eva/multimodal/models/wrappers/base.py +48 -0
  82. eva/multimodal/models/wrappers/from_registry.py +54 -0
  83. eva/multimodal/models/wrappers/huggingface.py +193 -0
  84. eva/multimodal/models/wrappers/litellm.py +58 -0
  85. eva/multimodal/utils/__init__.py +1 -0
  86. eva/multimodal/utils/batch/__init__.py +5 -0
  87. eva/multimodal/utils/batch/unpack.py +11 -0
  88. eva/multimodal/utils/image/__init__.py +5 -0
  89. eva/multimodal/utils/image/encode.py +28 -0
  90. eva/multimodal/utils/text/__init__.py +1 -0
  91. eva/multimodal/utils/text/messages.py +79 -0
  92. eva/vision/data/datasets/classification/breakhis.py +5 -8
  93. eva/vision/data/datasets/classification/panda.py +12 -5
  94. eva/vision/data/datasets/classification/patch_camelyon.py +8 -6
  95. eva/vision/data/datasets/segmentation/btcv.py +1 -1
  96. eva/vision/data/datasets/segmentation/consep.py +1 -1
  97. eva/vision/data/datasets/segmentation/lits17.py +1 -1
  98. eva/vision/data/datasets/segmentation/monusac.py +15 -6
  99. eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +1 -1
  100. eva/vision/data/transforms/__init__.py +2 -1
  101. eva/vision/data/transforms/base/__init__.py +2 -1
  102. eva/vision/data/transforms/base/monai.py +2 -2
  103. eva/vision/data/transforms/base/torchvision.py +33 -0
  104. eva/vision/data/transforms/common/squeeze.py +6 -3
  105. eva/vision/data/transforms/croppad/crop_foreground.py +8 -7
  106. eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +6 -5
  107. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -5
  108. eva/vision/data/transforms/croppad/rand_spatial_crop.py +8 -7
  109. eva/vision/data/transforms/croppad/spatial_pad.py +6 -6
  110. eva/vision/data/transforms/intensity/rand_scale_intensity.py +3 -3
  111. eva/vision/data/transforms/intensity/rand_shift_intensity.py +3 -3
  112. eva/vision/data/transforms/intensity/scale_intensity_ranged.py +5 -5
  113. eva/vision/data/transforms/spatial/__init__.py +2 -1
  114. eva/vision/data/transforms/spatial/flip.py +8 -7
  115. eva/vision/data/transforms/spatial/functional/__init__.py +5 -0
  116. eva/vision/data/transforms/spatial/functional/resize.py +26 -0
  117. eva/vision/data/transforms/spatial/resize.py +63 -0
  118. eva/vision/data/transforms/spatial/rotate.py +8 -7
  119. eva/vision/data/transforms/spatial/spacing.py +7 -6
  120. eva/vision/data/transforms/utility/ensure_channel_first.py +6 -6
  121. eva/vision/models/networks/backbones/universal/vit.py +24 -0
  122. eva/vision/models/wrappers/from_registry.py +6 -5
  123. eva/vision/models/wrappers/from_timm.py +6 -4
  124. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/METADATA +17 -3
  125. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/RECORD +128 -66
  126. eva/core/data/dataloaders/collate_fn/__init__.py +0 -5
  127. eva/core/data/dataloaders/collate_fn/collate.py +0 -24
  128. eva/language/models/modules/text.py +0 -85
  129. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/WHEEL +0 -0
  130. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/entry_points.txt +0 -0
  131. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/licenses/LICENSE +0 -0

eva/vision/data/transforms/base/monai.py
@@ -2,10 +2,10 @@
 
 import abc
 
-from torchvision.transforms import v2
+from eva.vision.data.transforms.base.torchvision import TorchvisionTransformV2
 
 
-class RandomMonaiTransform(v2.Transform, abc.ABC):
+class RandomMonaiTransform(TorchvisionTransformV2, abc.ABC):
     """Base class for MONAI transform wrappers."""
 
     @abc.abstractmethod

eva/vision/data/transforms/base/torchvision.py (new file)
@@ -0,0 +1,33 @@
+"""Base class for torchvision.v2 transforms."""
+
+import abc
+from typing import Any, Dict, List
+
+from torchvision.transforms import v2
+
+
+class TorchvisionTransformV2(v2.Transform, abc.ABC):
+    """Wrapper for torchvision.v2.Transform.
+
+    This class ensures compatibility both with >=0.21.0 and older versions,
+    as torchvision 0.21.0 introduced a new transform API where they
+    renamed the following methods:
+
+    - `_get_params` -> `make_params`
+    - `_transform` -> `transform`
+    """
+
+    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        """Called internally before calling transform() on each input."""
+        return {}
+
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        return self.make_params(flat_inputs)
+
+    @abc.abstractmethod
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        """Applies the transformation to the input."""
+        raise NotImplementedError
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return self.transform(inpt, params)
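
A shim like this keeps subclasses on the new-style hooks only. As a minimal sketch of how a custom transform would build on it (the `AddOffset` class below is a hypothetical example, not part of eva):

    from typing import Any, Dict, List

    import torch

    from eva.vision.data.transforms.base.torchvision import TorchvisionTransformV2


    class AddOffset(TorchvisionTransformV2):
        """Adds a constant offset to tensor inputs; other inputs pass through unchanged."""

        def __init__(self, offset: float = 1.0) -> None:
            super().__init__()
            self._offset = offset

        def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
            # Parameters shared across all inputs of one sample would be computed here.
            return {}

        def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
            return inpt + self._offset if isinstance(inpt, torch.Tensor) else inpt

On torchvision older than 0.21 the parent class still calls `_get_params`/`_transform`; the shim forwards those to `make_params`/`transform`, so the subclass stays unchanged across versions.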

eva/vision/data/transforms/common/squeeze.py
@@ -4,10 +4,12 @@ from typing import Any
 
 import torch
 from torchvision import tv_tensors
-from torchvision.transforms import v2
+from typing_extensions import override
 
+from eva.vision.data.transforms import base
 
-class Squeeze(v2.Transform):
+
+class Squeeze(base.TorchvisionTransformV2):
     """Squeezes the input tensor accross all or specified dimensions."""
 
     def __init__(self, dim: int | list[int] | None = None):
@@ -19,6 +21,7 @@ class Squeeze(v2.Transform):
         super().__init__()
         self._dim = dim
 
-    def _transform(self, inpt: Any, params: dict[str, Any]) -> Any:
+    @override
+    def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
        output = torch.squeeze(inpt) if self._dim is None else torch.squeeze(inpt, dim=self._dim)
        return tv_tensors.wrap(output, like=inpt)

eva/vision/data/transforms/croppad/crop_foreground.py
@@ -8,13 +8,13 @@ from monai.config import type_definitions
 from monai.transforms.croppad import array as monai_croppad_transforms
 from monai.utils.enums import PytorchPadMode
 from torchvision import tv_tensors
-from torchvision.transforms import v2
 from typing_extensions import override
 
 from eva.vision.data import tv_tensors as eva_tv_tensors
+from eva.vision.data.transforms import base
 
 
-class CropForeground(v2.Transform):
+class CropForeground(base.TorchvisionTransformV2):
     """Crop an image using a bounding box.
 
     The bounding box is generated by selecting foreground using select_fn
@@ -74,19 +74,20 @@
             **pad_kwargs,
         )
 
-    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+    @override
+    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         volume = next(inpt for inpt in flat_inputs if isinstance(inpt, eva_tv_tensors.Volume))
         box_start, box_end = self._foreground_crop.compute_bounding_box(volume)
         return {"box_start": box_start, "box_end": box_end}
 
     @functools.singledispatchmethod
     @override
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
-    @_transform.register(tv_tensors.Image)
-    @_transform.register(eva_tv_tensors.Volume)
-    @_transform.register(tv_tensors.Mask)
+    @transform.register(tv_tensors.Image)
+    @transform.register(eva_tv_tensors.Volume)
+    @transform.register(tv_tensors.Mask)
     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
         inpt_foreground_cropped = self._foreground_crop.crop_pad(
             inpt, params["box_start"], params["box_end"]

eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py
@@ -56,19 +56,20 @@ class RandCropByLabelClasses(base.RandomMonaiTransform):
     def set_random_state(self, seed: int) -> None:
         self._rand_crop.set_random_state(seed)
 
-    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+    @override
+    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         mask = next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.Mask))
         self._rand_crop.randomize(label=mask)
         return {}
 
     @functools.singledispatchmethod
     @override
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
-    @_transform.register(tv_tensors.Image)
-    @_transform.register(eva_tv_tensors.Volume)
-    @_transform.register(tv_tensors.Mask)
+    @transform.register(tv_tensors.Image)
+    @transform.register(eva_tv_tensors.Volume)
+    @transform.register(tv_tensors.Mask)
     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
         inpt_foreground_crops = self._rand_crop(img=inpt, randomize=False)
         return [tv_tensors.wrap(crop, like=inpt) for crop in inpt_foreground_crops]

eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py
@@ -95,19 +95,20 @@ class RandCropByPosNegLabel(base.RandomMonaiTransform):
     def set_random_state(self, seed: int) -> None:
         self._rand_crop.set_random_state(seed)
 
-    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+    @override
+    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         mask = next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.Mask))
         self._rand_crop.randomize(label=mask)
         return {}
 
     @functools.singledispatchmethod
     @override
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
-    @_transform.register(tv_tensors.Image)
-    @_transform.register(eva_tv_tensors.Volume)
-    @_transform.register(tv_tensors.Mask)
+    @transform.register(tv_tensors.Image)
+    @transform.register(eva_tv_tensors.Volume)
+    @transform.register(tv_tensors.Mask)
     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
         inpt_foreground_crops = self._rand_crop(img=inpt, randomize=False)
         return [tv_tensors.wrap(crop, like=inpt) for crop in inpt_foreground_crops]

eva/vision/data/transforms/croppad/rand_spatial_crop.py
@@ -5,14 +5,14 @@ from typing import Any, Dict, List, Sequence, Tuple
 
 from monai.transforms.croppad import array as monai_croppad_transforms
 from torchvision import tv_tensors
-from torchvision.transforms import v2
 from torchvision.transforms.v2 import _utils as tv_utils
 from typing_extensions import override
 
 from eva.vision.data import tv_tensors as eva_tv_tensors
+from eva.vision.data.transforms import base
 
 
-class RandSpatialCrop(v2.Transform):
+class RandSpatialCrop(base.TorchvisionTransformV2):
     """Crop image with random size or specific size ROI.
 
     It can crop at a random position as center or at the image center.
@@ -62,19 +62,20 @@ class RandSpatialCrop(v2.Transform):
         """Set the random state for the transform."""
         self._rand_spatial_crop.set_random_state(seed)
 
-    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+    @override
+    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         t, h, w = tv_utils.query_chw(flat_inputs)
         self._rand_spatial_crop.randomize((t, h, w))
         return {}
 
     @functools.singledispatchmethod
     @override
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
-    @_transform.register(tv_tensors.Image)
-    @_transform.register(eva_tv_tensors.Volume)
-    @_transform.register(tv_tensors.Mask)
+    @transform.register(tv_tensors.Image)
+    @transform.register(eva_tv_tensors.Volume)
+    @transform.register(tv_tensors.Mask)
     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
         slices = self._get_crop_slices()
         inpt_rand_crop = self._cropper(inpt, slices=slices)

eva/vision/data/transforms/croppad/spatial_pad.py
@@ -6,13 +6,13 @@ from typing import Any, Dict, Sequence
 from monai.transforms.croppad import array as monai_croppad_transforms
 from monai.utils.enums import Method, PytorchPadMode
 from torchvision import tv_tensors
-from torchvision.transforms import v2
 from typing_extensions import override
 
 from eva.vision.data import tv_tensors as eva_tv_tensors
+from eva.vision.data.transforms import base
 
 
-class SpatialPad(v2.Transform):
+class SpatialPad(base.TorchvisionTransformV2):
     """Performs padding to the data.
 
     Padding is applied symmetric for all sides or all on one side for each dimension.
@@ -56,12 +56,12 @@ class SpatialPad(v2.Transform):
 
     @functools.singledispatchmethod
     @override
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
-    @_transform.register(tv_tensors.Image)
-    @_transform.register(eva_tv_tensors.Volume)
-    @_transform.register(tv_tensors.Mask)
+    @transform.register(tv_tensors.Image)
+    @transform.register(eva_tv_tensors.Volume)
+    @transform.register(tv_tensors.Mask)
     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
         inpt_padded = self._spatial_pad(inpt)
         return tv_tensors.wrap(inpt_padded, like=inpt)

eva/vision/data/transforms/intensity/rand_scale_intensity.py
@@ -53,11 +53,11 @@ class RandScaleIntensity(base.RandomMonaiTransform):
 
     @functools.singledispatchmethod
     @override
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
-    @_transform.register(tv_tensors.Image)
-    @_transform.register(eva_tv_tensors.Volume)
+    @transform.register(tv_tensors.Image)
+    @transform.register(eva_tv_tensors.Volume)
     def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
         inpt_scaled = self._rand_scale_intensity(inpt)
         return tv_tensors.wrap(inpt_scaled, like=inpt)

eva/vision/data/transforms/intensity/rand_shift_intensity.py
@@ -49,11 +49,11 @@ class RandShiftIntensity(base.RandomMonaiTransform):
 
     @functools.singledispatchmethod
     @override
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
-    @_transform.register(tv_tensors.Image)
-    @_transform.register(eva_tv_tensors.Volume)
+    @transform.register(tv_tensors.Image)
+    @transform.register(eva_tv_tensors.Volume)
     def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
         inpt_scaled = self._rand_shift_intensity(inpt)
         return tv_tensors.wrap(inpt_scaled, like=inpt)

eva/vision/data/transforms/intensity/scale_intensity_ranged.py
@@ -5,13 +5,13 @@ from typing import Any, Dict, Tuple
 
 from monai.transforms.intensity import array as monai_intensity_transforms
 from torchvision import tv_tensors
-from torchvision.transforms import v2
 from typing_extensions import override
 
 from eva.vision.data import tv_tensors as eva_tv_tensors
+from eva.vision.data.transforms import base
 
 
-class ScaleIntensityRange(v2.Transform):
+class ScaleIntensityRange(base.TorchvisionTransformV2):
     """Intensity scaling transform.
 
     Scaling from [a_min, a_max] to [b_min, b_max] with clip option.
@@ -46,11 +46,11 @@ class ScaleIntensityRange(v2.Transform):
 
     @functools.singledispatchmethod
     @override
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
-    @_transform.register(tv_tensors.Image)
-    @_transform.register(eva_tv_tensors.Volume)
+    @transform.register(tv_tensors.Image)
+    @transform.register(eva_tv_tensors.Volume)
     def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
         inpt_scaled = self._scale_intensity_range(inpt)
         return tv_tensors.wrap(inpt_scaled, like=inpt)

eva/vision/data/transforms/spatial/__init__.py
@@ -1,7 +1,8 @@
 """Transforms for spatial operations."""
 
 from eva.vision.data.transforms.spatial.flip import RandFlip
+from eva.vision.data.transforms.spatial.resize import Resize
 from eva.vision.data.transforms.spatial.rotate import RandRotate90
 from eva.vision.data.transforms.spatial.spacing import Spacing
 
-__all__ = ["Spacing", "RandFlip", "RandRotate90"]
+__all__ = ["Spacing", "RandFlip", "RandRotate90", "Resize"]

eva/vision/data/transforms/spatial/flip.py
@@ -6,13 +6,13 @@ from typing import Any, Dict, List, Sequence
 import torch
 from monai.transforms.spatial import array as monai_spatial_transforms
 from torchvision import tv_tensors
-from torchvision.transforms import v2
 from typing_extensions import override
 
 from eva.vision.data import tv_tensors as eva_tv_tensors
+from eva.vision.data.transforms import base
 
 
-class RandFlip(v2.Transform):
+class RandFlip(base.TorchvisionTransformV2):
     """Randomly flips the image along axes."""
 
     def __init__(
@@ -45,23 +45,24 @@ class RandFlip(v2.Transform):
         else:
             self._flips = [monai_spatial_transforms.RandFlip(prob=prob, spatial_axis=spatial_axes)]
 
-    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+    @override
+    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         for flip in self._flips:
             flip.randomize(None)
         return {}
 
     @functools.singledispatchmethod
     @override
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
-    @_transform.register(tv_tensors.Image)
-    @_transform.register(eva_tv_tensors.Volume)
+    @transform.register(tv_tensors.Image)
+    @transform.register(eva_tv_tensors.Volume)
     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
         inpt_flipped = self._apply_flips(inpt)
         return tv_tensors.wrap(inpt_flipped, like=inpt)
 
-    @_transform.register(tv_tensors.Mask)
+    @transform.register(tv_tensors.Mask)
     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
         inpt_flipped = torch.tensor(self._apply_flips(inpt), dtype=torch.long)
         return tv_tensors.wrap(inpt_flipped, like=inpt)

eva/vision/data/transforms/spatial/functional/__init__.py (new file)
@@ -0,0 +1,5 @@
+"""Functional API for spatial transforms."""
+
+from eva.vision.data.transforms.spatial.functional.resize import resize_to_max_bytes
+
+__all__ = ["resize_to_max_bytes"]

eva/vision/data/transforms/spatial/functional/resize.py (new file)
@@ -0,0 +1,26 @@
+"""Functional resizing utilities."""
+
+import io
+from typing import Tuple
+
+from PIL import Image
+from torchvision import tv_tensors
+from torchvision.transforms.v2 import functional as F
+
+
+def resize_to_max_bytes(image: tv_tensors.Image, max_bytes: int) -> tv_tensors.Image:
+    """Resize the image to fit within the specified byte size."""
+    image_pil = F.to_pil_image(image)
+    image_bytes = io.BytesIO()
+    image_pil.save(image_bytes, format="PNG", optimize=True)
+
+    while image_bytes.tell() > max_bytes:
+        size: Tuple[int, int] = image_pil.size  # type: ignore
+        w, h = size
+        scale = (max_bytes / image_bytes.tell()) ** 0.5
+        new_size = (max(1, int(h * scale)), max(1, int(w * scale)))
+        image_pil = image_pil.resize(new_size, Image.Resampling.LANCZOS)
+        image_bytes = io.BytesIO()
+        image_pil.save(image_bytes, format="PNG", optimize=True)
+
+    return tv_tensors.Image(F.pil_to_tensor(image_pil))
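
The helper repeatedly scales the image by the square root of the byte ratio and re-encodes it as PNG until it fits. A rough usage sketch (the random test image and the 1 MB budget are purely illustrative):

    import torch
    from torchvision import tv_tensors

    from eva.vision.data.transforms.spatial.functional import resize_to_max_bytes

    # Shrink a noisy RGB image until its optimized PNG encoding fits the budget.
    image = tv_tensors.Image(torch.randint(0, 256, (3, 2048, 2048), dtype=torch.uint8))
    resized = resize_to_max_bytes(image, max_bytes=1_000_000)
    print(resized.shape)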

eva/vision/data/transforms/spatial/resize.py (new file)
@@ -0,0 +1,63 @@
+"""Image resize transforms."""
+
+import functools
+from typing import Any, Dict
+
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from typing_extensions import override
+
+from eva.vision.data.transforms import base
+from eva.vision.data.transforms.spatial import functional
+
+
+class Resize(base.TorchvisionTransformV2):
+    """Resize transform for images with spatial or byte-based constraints.
+
+    This transform provides two mutually exclusive modes of resizing:
+    1. Spatial resizing: Resize to a specific (height, width) dimension
+    2. Byte-based resizing: Resize to fit within a maximum byte size
+
+    The latter is particularly useful for API models (e.g. Claude 3.7) that
+    have strict byte size limits for image inputs.
+    """
+
+    def __init__(self, size: tuple[int, int] | None = None, max_bytes: int | None = None) -> None:
+        """Initializes the transform.
+
+        Args:
+            size: Target size as (height, width) tuple for spatial resizing.
+                If provided, max_bytes must be None.
+            max_bytes: Maximum allowed byte size for the image.
+                If provided, size must be None. Must be a positive integer.
+
+        Raises:
+            ValueError: If both size and max_bytes are provided, or if max_bytes
+                is not a positive integer.
+        """
+        if size is not None and max_bytes is not None:
+            raise ValueError("Cannot provide both 'size' and 'max_bytes' parameters.")
+        if max_bytes is not None and max_bytes <= 0:
+            raise ValueError("'max_bytes' must be a positive integer.")
+
+        super().__init__()
+
+        self.size = size
+        self.max_bytes = max_bytes
+        self.resize_fn = None
+
+        if size is not None:
+            self.resize_fn = v2.Resize(size=size)
+        elif max_bytes is not None:
+            self.resize_fn = functools.partial(functional.resize_to_max_bytes, max_bytes=max_bytes)
+
+    @functools.singledispatchmethod
+    @override
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @transform.register(tv_tensors.Image)
+    @transform.register(tv_tensors.Mask)
+    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        inpt_resized = self.resize_fn(inpt) if self.resize_fn is not None else inpt
+        return tv_tensors.wrap(inpt_resized, like=inpt)
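
A brief sketch of both modes (`Resize` is re-exported from `eva.vision.data.transforms.spatial`; the byte budget below is an arbitrary illustration):

    import torch
    from torchvision import tv_tensors

    from eva.vision.data.transforms.spatial import Resize

    image = tv_tensors.Image(torch.randint(0, 256, (3, 512, 512), dtype=torch.uint8))

    # Spatial mode: fixed (height, width) output via torchvision's v2.Resize.
    resized = Resize(size=(224, 224))(image)

    # Byte mode: re-encode and shrink until the PNG representation fits the budget.
    capped = Resize(max_bytes=1_000_000)(image)

Passing both `size` and `max_bytes` raises a ValueError, since the two modes are mutually exclusive.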

eva/vision/data/transforms/spatial/rotate.py
@@ -5,13 +5,13 @@ from typing import Any, Dict, List
 
 from monai.transforms.spatial import array as monai_spatial_transforms
 from torchvision import tv_tensors
-from torchvision.transforms import v2
 from typing_extensions import override
 
 from eva.vision.data import tv_tensors as eva_tv_tensors
+from eva.vision.data.transforms import base
 
 
-class RandRotate90(v2.Transform):
+class RandRotate90(base.TorchvisionTransformV2):
     """Rotate input tensors by 90 degrees."""
 
     def __init__(
@@ -36,18 +36,19 @@ class RandRotate90(v2.Transform):
             prob=prob, max_k=max_k, spatial_axes=spatial_axes
         )
 
-    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+    @override
+    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         self._rotate.randomize()
         return {}
 
     @functools.singledispatchmethod
     @override
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
-    @_transform.register(tv_tensors.Image)
-    @_transform.register(eva_tv_tensors.Volume)
-    @_transform.register(tv_tensors.Mask)
+    @transform.register(tv_tensors.Image)
+    @transform.register(eva_tv_tensors.Volume)
+    @transform.register(tv_tensors.Mask)
     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
         inpt_rotated = self._rotate(img=inpt, randomize=False)
         return tv_tensors.wrap(inpt_rotated, like=inpt)

eva/vision/data/transforms/spatial/spacing.py
@@ -8,13 +8,13 @@ import torch
 from monai.data import meta_tensor
 from monai.transforms.spatial import array as monai_spatial_transforms
 from torchvision import tv_tensors
-from torchvision.transforms import v2
 from typing_extensions import override
 
 from eva.vision.data import tv_tensors as eva_tv_tensors
+from eva.vision.data.transforms import base
 
 
-class Spacing(v2.Transform):
+class Spacing(base.TorchvisionTransformV2):
     """Resample input image into the specified `pixdim`.
 
     - Expects tensors of shape `[C, T, H, W]`.
@@ -43,7 +43,8 @@ class Spacing(v2.Transform):
         self._spacing = monai_spatial_transforms.Spacing(pixdim=pixdim, recompute_affine=True)
         self._affine = None
 
-    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+    @override
+    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         self._affine = next(
             inpt.affine for inpt in flat_inputs if isinstance(inpt, eva_tv_tensors.Volume)
         )
@@ -51,17 +52,17 @@
 
     @functools.singledispatchmethod
     @override
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
-    @_transform.register(eva_tv_tensors.Volume)
+    @transform.register(eva_tv_tensors.Volume)
    def _(self, inpt: eva_tv_tensors.Volume, params: Dict[str, Any]) -> Any:
         inpt_spacing = self._spacing(inpt.to_meta_tensor(), mode="bilinear")
         if not isinstance(inpt_spacing, meta_tensor.MetaTensor):
             raise ValueError(f"Expected MetaTensor, got {type(inpt_spacing)}")
         return eva_tv_tensors.Volume.from_meta_tensor(inpt_spacing)
 
-    @_transform.register(tv_tensors.Mask)
+    @transform.register(tv_tensors.Mask)
     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
         inpt_spacing = self._spacing(
             meta_tensor.MetaTensor(inpt, affine=self._affine), mode="nearest"

eva/vision/data/transforms/utility/ensure_channel_first.py
@@ -5,13 +5,13 @@ from typing import Any, Dict
 
 from monai.transforms.utility import array as monai_utility_transforms
 from torchvision import tv_tensors
-from torchvision.transforms import v2
 from typing_extensions import override
 
 from eva.vision.data import tv_tensors as eva_tv_tensors
+from eva.vision.data.transforms import base
 
 
-class EnsureChannelFirst(v2.Transform):
+class EnsureChannelFirst(base.TorchvisionTransformV2):
     """Adjust or add the channel dimension of input data to ensure `channel_first` shape."""
 
     def __init__(
@@ -40,12 +40,12 @@ class EnsureChannelFirst(v2.Transform):
 
     @functools.singledispatchmethod
     @override
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
-    @_transform.register(tv_tensors.Image)
-    @_transform.register(eva_tv_tensors.Volume)
-    @_transform.register(tv_tensors.Mask)
+    @transform.register(tv_tensors.Image)
+    @transform.register(eva_tv_tensors.Volume)
+    @transform.register(tv_tensors.Mask)
     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
         inpt_channel_first = self._ensure_channel_first(inpt)
         return tv_tensors.wrap(inpt_channel_first, like=inpt)

eva/vision/models/networks/backbones/universal/vit.py
@@ -54,6 +54,30 @@ def vit_small_patch16_224_dino(
     )
 
 
+@backbone_registry.register("universal/vit_tiny_patch16_224_random")
+def vit_tiny_patch16_224_random(
+    dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+) -> nn.Module:
+    """Initializes a ViT-Tiny16 baseline model with random weights.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The torch ViTS-16 based foundation model.
+    """
+    return timm.create_model(
+        model_name="vit_tiny_patch16_224",
+        pretrained=False,
+        num_classes=0,
+        features_only=out_indices is not None,
+        out_indices=out_indices,
+        dynamic_img_size=dynamic_img_size,
+    )
+
+
 @backbone_registry.register("universal/vit_small_patch16_224_dino_1chan")
 def vit_small_patch16_224_dino_1chan(
     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
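
Since the new entry simply forwards to timm with random weights, an equivalent module can be built directly for quick local checks; a sketch mirroring the registered call (usage outside the registry, shapes noted for orientation):

    import timm
    import torch

    # Same call the registry entry makes: ViT-Tiny/16, random weights, no classifier head.
    model = timm.create_model(
        "vit_tiny_patch16_224",
        pretrained=False,
        num_classes=0,
        dynamic_img_size=True,
    )
    embeddings = model(torch.randn(1, 3, 224, 224))  # pooled features of shape [1, 192]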

eva/vision/models/wrappers/from_registry.py
@@ -3,6 +3,7 @@
 from typing import Any, Callable, Dict
 
 import torch
+from torch import nn
 from typing_extensions import override
 
 from eva.core.models.wrappers import base
@@ -40,14 +41,14 @@ class ModelFromRegistry(base.BaseModel[torch.Tensor, torch.Tensor]):
         self._model_kwargs = model_kwargs or {}
         self._model_extra_kwargs = model_extra_kwargs or {}
 
-        self.load_model()
+        self.model = self.load_model()
 
     @override
-    def load_model(self) -> None:
-        self._model = factory.ModuleFactory(
+    def load_model(self) -> nn.Module:
+        ModelFromRegistry.__name__ = self._model_name
+
+        return factory.ModuleFactory(
             registry=backbone_registry,
             name=self._model_name,
             init_args=self._model_kwargs | self._model_extra_kwargs,
         )
-
-        ModelFromRegistry.__name__ = self._model_name
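
The net effect of this hunk: `load_model()` now returns the backbone instead of assigning it as a side effect, the constructor stores it on `self.model`, and the wrapper class is rebound to the registry entry name. A small sketch of what that looks like from the outside (the `model_name` keyword is an assumption inferred from the attribute names in the diff, not shown in the hunk):

    import torch

    from eva.vision.models.wrappers.from_registry import ModelFromRegistry

    wrapper = ModelFromRegistry(model_name="universal/vit_tiny_patch16_224_random")
    assert isinstance(wrapper.model, torch.nn.Module)  # populated by __init__ via load_model()
    print(type(wrapper).__name__)  # rebound to the registry entry name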