kaiko-eva 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (101)
  1. eva/core/callbacks/config.py +4 -0
  2. eva/core/cli/setup.py +1 -1
  3. eva/core/data/dataloaders/__init__.py +1 -2
  4. eva/core/data/dataloaders/dataloader.py +3 -1
  5. eva/core/data/samplers/random.py +17 -10
  6. eva/core/interface/interface.py +21 -0
  7. eva/core/loggers/log/__init__.py +2 -1
  8. eva/core/loggers/log/table.py +73 -0
  9. eva/core/models/modules/module.py +2 -2
  10. eva/core/models/wrappers/base.py +2 -2
  11. eva/core/models/wrappers/from_function.py +3 -3
  12. eva/core/models/wrappers/from_torchhub.py +9 -7
  13. eva/core/models/wrappers/huggingface.py +4 -5
  14. eva/core/models/wrappers/onnx.py +5 -5
  15. eva/core/trainers/trainer.py +2 -0
  16. eva/language/__init__.py +2 -1
  17. eva/language/callbacks/__init__.py +5 -0
  18. eva/language/callbacks/writers/__init__.py +5 -0
  19. eva/language/callbacks/writers/prediction.py +176 -0
  20. eva/language/data/dataloaders/__init__.py +5 -0
  21. eva/language/data/dataloaders/collate_fn/__init__.py +5 -0
  22. eva/language/data/dataloaders/collate_fn/text.py +57 -0
  23. eva/language/data/datasets/__init__.py +3 -1
  24. eva/language/data/datasets/{language.py → base.py} +1 -1
  25. eva/language/data/datasets/classification/base.py +3 -43
  26. eva/language/data/datasets/classification/pubmedqa.py +36 -4
  27. eva/language/data/datasets/prediction.py +151 -0
  28. eva/language/data/datasets/schemas.py +18 -0
  29. eva/language/data/datasets/text.py +92 -0
  30. eva/language/data/datasets/typings.py +39 -0
  31. eva/language/data/messages.py +60 -0
  32. eva/language/models/__init__.py +15 -11
  33. eva/language/models/modules/__init__.py +2 -2
  34. eva/language/models/modules/language.py +93 -0
  35. eva/language/models/networks/__init__.py +12 -0
  36. eva/language/models/networks/alibaba.py +26 -0
  37. eva/language/models/networks/api/__init__.py +11 -0
  38. eva/language/models/networks/api/anthropic.py +34 -0
  39. eva/language/models/networks/registry.py +5 -0
  40. eva/language/models/typings.py +39 -0
  41. eva/language/models/wrappers/__init__.py +13 -5
  42. eva/language/models/wrappers/base.py +47 -0
  43. eva/language/models/wrappers/from_registry.py +54 -0
  44. eva/language/models/wrappers/huggingface.py +44 -8
  45. eva/language/models/wrappers/litellm.py +81 -46
  46. eva/language/models/wrappers/vllm.py +37 -13
  47. eva/language/utils/__init__.py +2 -1
  48. eva/language/utils/str_to_int_tensor.py +20 -12
  49. eva/language/utils/text/__init__.py +5 -0
  50. eva/language/utils/text/messages.py +113 -0
  51. eva/multimodal/__init__.py +6 -0
  52. eva/multimodal/callbacks/__init__.py +5 -0
  53. eva/multimodal/callbacks/writers/__init__.py +5 -0
  54. eva/multimodal/callbacks/writers/prediction.py +39 -0
  55. eva/multimodal/data/__init__.py +5 -0
  56. eva/multimodal/data/dataloaders/__init__.py +5 -0
  57. eva/multimodal/data/dataloaders/collate_fn/__init__.py +5 -0
  58. eva/multimodal/data/dataloaders/collate_fn/text_image.py +28 -0
  59. eva/multimodal/data/datasets/__init__.py +6 -0
  60. eva/multimodal/data/datasets/base.py +13 -0
  61. eva/multimodal/data/datasets/multiple_choice/__init__.py +5 -0
  62. eva/multimodal/data/datasets/multiple_choice/patch_camelyon.py +80 -0
  63. eva/multimodal/data/datasets/schemas.py +14 -0
  64. eva/multimodal/data/datasets/text_image.py +77 -0
  65. eva/multimodal/data/datasets/typings.py +27 -0
  66. eva/multimodal/models/__init__.py +8 -0
  67. eva/multimodal/models/modules/__init__.py +5 -0
  68. eva/multimodal/models/modules/vision_language.py +55 -0
  69. eva/multimodal/models/networks/__init__.py +14 -0
  70. eva/multimodal/models/networks/alibaba.py +39 -0
  71. eva/multimodal/models/networks/api/__init__.py +11 -0
  72. eva/multimodal/models/networks/api/anthropic.py +34 -0
  73. eva/multimodal/models/networks/others.py +47 -0
  74. eva/multimodal/models/networks/registry.py +5 -0
  75. eva/multimodal/models/typings.py +27 -0
  76. eva/multimodal/models/wrappers/__init__.py +13 -0
  77. eva/multimodal/models/wrappers/base.py +47 -0
  78. eva/multimodal/models/wrappers/from_registry.py +54 -0
  79. eva/multimodal/models/wrappers/huggingface.py +180 -0
  80. eva/multimodal/models/wrappers/litellm.py +56 -0
  81. eva/multimodal/utils/__init__.py +1 -0
  82. eva/multimodal/utils/image/__init__.py +5 -0
  83. eva/multimodal/utils/image/encode.py +28 -0
  84. eva/multimodal/utils/text/__init__.py +1 -0
  85. eva/multimodal/utils/text/messages.py +79 -0
  86. eva/vision/data/datasets/classification/patch_camelyon.py +8 -6
  87. eva/vision/data/transforms/__init__.py +2 -1
  88. eva/vision/data/transforms/spatial/__init__.py +2 -1
  89. eva/vision/data/transforms/spatial/functional/__init__.py +5 -0
  90. eva/vision/data/transforms/spatial/functional/resize.py +26 -0
  91. eva/vision/data/transforms/spatial/resize.py +62 -0
  92. eva/vision/models/wrappers/from_registry.py +6 -5
  93. eva/vision/models/wrappers/from_timm.py +6 -4
  94. {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/METADATA +10 -2
  95. {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/RECORD +98 -40
  96. eva/core/data/dataloaders/collate_fn/__init__.py +0 -5
  97. eva/core/data/dataloaders/collate_fn/collate.py +0 -24
  98. eva/language/models/modules/text.py +0 -85
  99. {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/WHEEL +0 -0
  100. {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/entry_points.txt +0 -0
  101. {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/licenses/LICENSE +0 -0

eva/multimodal/models/networks/api/anthropic.py
@@ -0,0 +1,34 @@
+"""Models from Anthropic."""
+
+import os
+
+from eva.multimodal.models import wrappers
+from eva.multimodal.models.networks.registry import model_registry
+
+
+class _Claude(wrappers.LiteLLMModel):
+    """Base class for Claude models."""
+
+    def __init__(self, model_name: str, system_prompt: str | None = None):
+        if not os.getenv("ANTHROPIC_API_KEY"):
+            raise ValueError("ANTHROPIC_API_KEY env variable must be set.")
+
+        super().__init__(model_name=model_name, system_prompt=system_prompt)
+
+
+@model_registry.register("anthropic/claude-3-5-sonnet-20240620")
+class Claude35Sonnet20240620(_Claude):
+    """Claude 3.5 Sonnet (June 2024) model."""
+
+    def __init__(self, system_prompt: str | None = None):
+        """Initialize the model."""
+        super().__init__(model_name="claude-3-5-sonnet-20240620", system_prompt=system_prompt)
+
+
+@model_registry.register("anthropic/claude-3-7-sonnet-20250219")
+class Claude37Sonnet20250219(_Claude):
+    """Claude 3.7 Sonnet (February 2025) model."""
+
+    def __init__(self, system_prompt: str | None = None):
+        """Initialize the model."""
+        super().__init__(model_name="claude-3-7-sonnet-20250219", system_prompt=system_prompt)

eva/multimodal/models/networks/others.py
@@ -0,0 +1,47 @@
+"""Models from other providers (non-major entities)."""
+
+import os
+
+import torch
+
+from eva.core.utils import requirements
+from eva.multimodal.models import wrappers
+from eva.multimodal.models.networks.registry import model_registry
+
+
+@model_registry.register("others/wenchuanzhang_patho-r1-3b")
+class PathoR13b(wrappers.HuggingFaceModel):
+    """Patho-R1-3B model by Wenchuan Zhang."""
+
+    def __init__(
+        self,
+        system_prompt: str | None = None,
+        cache_dir: str | None = None,
+        attn_implementation: str = "flash_attention_2",
+    ):
+        """Initialize the Patho-R1-3B model."""
+        requirements.check_dependencies(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})
+
+        if not os.getenv("HF_TOKEN"):
+            raise ValueError("HF_TOKEN env variable must be set.")
+
+        super().__init__(
+            model_name_or_path="WenchuanZhang/Patho-R1-3B",
+            model_class="Qwen2_5_VLForConditionalGeneration",
+            model_kwargs={
+                "torch_dtype": torch.float16,
+                "trust_remote_code": True,
+                "cache_dir": cache_dir,
+                "attn_implementation": attn_implementation,
+            },
+            generation_kwargs={
+                "max_new_tokens": 512,
+                "do_sample": False,
+            },
+            processor_kwargs={
+                "padding": True,
+                "padding_side": "left",
+                "max_pixels": 451584,  # 672*672
+            },
+            system_prompt=system_prompt,
+        )

eva/multimodal/models/networks/registry.py
@@ -0,0 +1,5 @@
+"""Multimodal Model Registry."""
+
+from eva.core.utils.registry import Registry
+
+model_registry = Registry()
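
Since the registry is a plain Registry instance, new networks are added with its register decorator, mirroring the entries in anthropic.py and others.py above. A hypothetical registration sketch (the registry key, class name, and provider identifier are illustrative only):

from eva.multimodal.models import wrappers
from eva.multimodal.models.networks.registry import model_registry


@model_registry.register("my_org/my-vlm")  # illustrative registry key
class MyVLM(wrappers.LiteLLMModel):
    """Hypothetical vision-language model registered under a custom name."""

    def __init__(self, system_prompt: str | None = None):
        super().__init__(model_name="my-provider/my-vlm", system_prompt=system_prompt)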

eva/multimodal/models/typings.py
@@ -0,0 +1,27 @@
+"""Type definitions for multimodal models."""
+
+from typing import Any, Dict, Generic, List, TypeVar
+
+from torchvision import tv_tensors
+from typing_extensions import NamedTuple
+
+from eva.language.data.messages import MessageSeries
+
+TargetType = TypeVar("TargetType")
+"""The target data type."""
+
+
+class TextImageBatch(NamedTuple, Generic[TargetType]):
+    """Text and image sample with target and metadata."""
+
+    text: List[MessageSeries]
+    """A batch of conversations with one or multiple messages each."""
+
+    image: List[tv_tensors.Image]
+    """Image tensor."""
+
+    target: TargetType | None
+    """Target data."""
+
+    metadata: Dict[str, Any] | None
+    """Additional metadata."""

eva/multimodal/models/wrappers/__init__.py
@@ -0,0 +1,13 @@
+"""Multimodal Wrapper API."""
+
+from eva.multimodal.models.wrappers.base import VisionLanguageModel
+from eva.multimodal.models.wrappers.from_registry import ModelFromRegistry
+from eva.multimodal.models.wrappers.huggingface import HuggingFaceModel
+from eva.multimodal.models.wrappers.litellm import LiteLLMModel
+
+__all__ = [
+    "HuggingFaceModel",
+    "LiteLLMModel",
+    "ModelFromRegistry",
+    "VisionLanguageModel",
+]

eva/multimodal/models/wrappers/base.py
@@ -0,0 +1,47 @@
+"""Base class for vision language model wrappers."""
+
+import abc
+from typing import Any, Callable, List
+
+from typing_extensions import override
+
+from eva.core.models.wrappers import base
+from eva.language.data.messages import ModelSystemMessage
+from eva.multimodal.models.typings import TextImageBatch
+
+
+class VisionLanguageModel(base.BaseModel[TextImageBatch, List[str]]):
+    """Base class for multimodal models.
+
+    Classes that inherit from this should implement the following methods:
+    - `load_model`: Loads & instantiates the model.
+    - `model_forward`: Implements the forward pass of the model. For API models,
+        this can be an API call.
+    - `format_inputs`: Preprocesses and converts the input batch into the format
+        expected by the `model_forward` method.
+    """
+
+    def __init__(
+        self, system_prompt: str | None, output_transforms: Callable | None = None
+    ) -> None:
+        """Creates a new model instance.
+
+        Args:
+            system_prompt: The system prompt to use for the model (optional).
+            output_transforms: Optional transforms to apply to the output of
+                the model's forward pass.
+        """
+        super().__init__(transforms=output_transforms)
+
+        self.system_message = ModelSystemMessage(content=system_prompt) if system_prompt else None
+
+    @override
+    def forward(self, batch: TextImageBatch) -> List[str]:
+        """Forward pass of the model."""
+        inputs = self.format_inputs(batch)
+        return super().forward(inputs)
+
+    @abc.abstractmethod
+    def format_inputs(self, batch: TextImageBatch) -> Any:
+        """Converts the inputs into the format expected by the model."""
+        raise NotImplementedError
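
A toy subclass sketch showing the three hooks the docstring above asks for; it assumes the base wrapper passes the output of format_inputs on to model_forward, and the echo behaviour is a stand-in for a real model:

from typing import Any, List

from torch import nn
from typing_extensions import override

from eva.multimodal.models.typings import TextImageBatch
from eva.multimodal.models.wrappers.base import VisionLanguageModel


class EchoModel(VisionLanguageModel):
    """Toy wrapper that returns the last message of each conversation."""

    def __init__(self, system_prompt: str | None = None):
        super().__init__(system_prompt=system_prompt)
        self.model = self.load_model()  # mirrors the concrete wrappers in this diff

    @override
    def load_model(self) -> nn.Module:
        return nn.Identity()  # no real weights needed for this sketch

    @override
    def format_inputs(self, batch: TextImageBatch) -> Any:
        # Each conversation is a sequence of messages exposing a `.content` field.
        return [str(conversation[-1].content) for conversation in batch.text]

    @override
    def model_forward(self, batch: Any) -> List[str]:
        return list(batch)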

eva/multimodal/models/wrappers/from_registry.py
@@ -0,0 +1,54 @@
+"""Vision backbone helper class."""
+
+from typing import Any, Callable, Dict, List
+
+from torch import nn
+from typing_extensions import override
+
+from eva.core.models.wrappers import base
+from eva.core.utils import factory
+from eva.multimodal.models.networks.registry import model_registry
+from eva.multimodal.models.typings import TextImageBatch
+
+
+class ModelFromRegistry(base.BaseModel[TextImageBatch, List[str]]):
+    """Wrapper class for vision backbone models.
+
+    This class can be used by load backbones available in eva's
+    model registry by name. New backbones can be registered by using
+    the `@backbone_registry.register(model_name)` decorator.
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        model_kwargs: Dict[str, Any] | None = None,
+        model_extra_kwargs: Dict[str, Any] | None = None,
+        transforms: Callable | None = None,
+    ) -> None:
+        """Initializes the model.
+
+        Args:
+            model_name: The name of the model to load.
+            model_kwargs: The arguments used for instantiating the model.
+            model_extra_kwargs: Extra arguments used for instantiating the model.
+            transforms: The transforms to apply to the output tensor
+                produced by the model.
+        """
+        super().__init__(transforms=transforms)
+
+        self._model_name = model_name
+        self._model_kwargs = model_kwargs or {}
+        self._model_extra_kwargs = model_extra_kwargs or {}
+
+        self.model = self.load_model()
+
+    @override
+    def load_model(self) -> nn.Module:
+        ModelFromRegistry.__name__ = self._model_name
+
+        return factory.ModuleFactory(
+            registry=model_registry,
+            name=self._model_name,
+            init_args=self._model_kwargs | self._model_extra_kwargs,
+        )
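
A usage sketch for ModelFromRegistry with one of the names registered earlier in this diff; model_kwargs is forwarded to the registered class' constructor, and this particular entry requires ANTHROPIC_API_KEY at construction time:

from eva.multimodal.models.wrappers import ModelFromRegistry

model = ModelFromRegistry(
    model_name="anthropic/claude-3-5-sonnet-20240620",
    model_kwargs={"system_prompt": "Answer with a single letter."},  # illustrative prompt
)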

eva/multimodal/models/wrappers/huggingface.py
@@ -0,0 +1,180 @@
+"""HuggingFace Vision-Language Model Wrapper."""
+
+import functools
+from typing import Any, Callable, Dict, List
+
+import torch
+import transformers
+from loguru import logger
+from torch import nn
+from typing_extensions import override
+
+from eva.language.models.typings import TextBatch
+from eva.language.utils.text import messages as language_message_utils
+from eva.multimodal.models.typings import TextImageBatch
+from eva.multimodal.models.wrappers import base
+from eva.multimodal.utils.text import messages as message_utils
+
+
+class HuggingFaceModel(base.VisionLanguageModel):
+    """Lightweight wrapper for Huggingface VLMs.
+
+    Args:
+        model_name_or_path: The name of the model to use.
+        model_class: The class of the model to use.
+        model_kwargs: Additional model arguments.
+        processor_kwargs: Additional processor arguments.
+        generation_kwargs: Additional generation arguments.
+    """
+
+    def __init__(
+        self,
+        model_name_or_path: str,
+        model_class: str,
+        model_kwargs: Dict[str, Any] | None = None,
+        system_prompt: str | None = None,
+        processor_kwargs: Dict[str, Any] | None = None,
+        generation_kwargs: Dict[str, Any] | None = None,
+    ):
+        """Initialize the HuggingFace model wrapper.
+
+        Args:
+            model_name_or_path: The name or path of the model to use.
+            model_class: The class of the model to use.
+            model_kwargs: Additional model arguments.
+            system_prompt: System prompt to use.
+            processor_kwargs: Additional processor arguments.
+            generation_kwargs: Additional generation arguments.
+        """
+        super().__init__(system_prompt=system_prompt)
+
+        self.model_name_or_path = model_name_or_path
+        self.model_kwargs = model_kwargs or {}
+        self.base_model_class = model_class
+        self.processor_kwargs = processor_kwargs or {}
+        self.generation_kwargs = generation_kwargs or {}
+
+        self.processor = self.load_processor()
+        self.model = self.load_model()
+
+    @override
+    def format_inputs(self, batch: TextImageBatch | TextBatch) -> Dict[str, torch.Tensor]:
+        """Formats inputs for HuggingFace models.
+
+        Args:
+            batch: A batch of text and image inputs.
+
+        Returns:
+            A dictionary produced by the provided processor following a format like:
+            {
+                "input_ids": ...,
+                "attention_mask": ...,
+                "pixel_values": ...
+            }
+        """
+        message_batch, image_batch, _, _ = self._unpack_batch(batch)
+        with_images = image_batch is not None
+
+        message_batch = language_message_utils.batch_insert_system_message(
+            message_batch, self.system_message
+        )
+        message_batch = list(map(language_message_utils.combine_system_messages, message_batch))
+
+        if self.processor.chat_template is not None:  # type: ignore
+            templated_text = [
+                self.processor.apply_chat_template(  # type: ignore
+                    message,
+                    add_generation_prompt=True,
+                    tokenize=False,
+                )
+                for message in map(
+                    functools.partial(
+                        message_utils.format_huggingface_message,
+                        with_images=with_images,
+                    ),
+                    message_batch,
+                )
+            ]
+        else:
+            raise NotImplementedError("Currently only chat models are supported.")
+
+        processor_inputs = {
+            "text": templated_text,
+            "return_tensors": "pt",
+            **self.processor_kwargs,
+        }
+
+        if with_images:
+            processor_inputs["image"] = [[image] for image in image_batch]
+
+        return self.processor(**processor_inputs).to(self.model.device)  # type: ignore
+
+    @override
+    def model_forward(self, batch: Dict[str, torch.Tensor]) -> List[str]:
+        """Generates text output from the model. Is called by the `generate` method.
+
+        Args:
+            batch: A dictionary containing the input data, which may include:
+                - "text": List of messages formatted for the model.
+                - "image": List of image tensors.
+
+        Returns:
+            A dictionary containing the processed input and the model's output.
+        """
+        output = self.model.generate(**batch, **self.generation_kwargs)  # type: ignore
+        return self._decode_output(output, batch["input_ids"].shape[-1])
+
+    @override
+    def load_model(self) -> nn.Module:
+        """Setting up the model. Used for delayed model initialization.
+
+        Raises:
+            ValueError: If the model class is not found in transformers or if the model
+                does not support gradient checkpointing but it is enabled.
+        """
+        logger.info(f"Configuring model: {self.model_name_or_path}")
+        if hasattr(transformers, self.base_model_class):
+            model_class = getattr(transformers, self.base_model_class)
+        else:
+            raise ValueError(f"Model class {self.base_model_class} not found in transformers")
+
+        model = model_class.from_pretrained(self.model_name_or_path, **self.model_kwargs)
+
+        if not hasattr(model, "generate"):
+            raise ValueError(f"Model {self.model_name_or_path} does not support generation. ")
+
+        return model
+
+    def load_processor(self) -> Callable:
+        """Initialize the processor."""
+        return transformers.AutoProcessor.from_pretrained(
+            self.model_name_or_path,
+            **self.processor_kwargs,
+        )
+
+    def _unpack_batch(self, batch: TextImageBatch | TextBatch) -> tuple:
+        if isinstance(batch, TextImageBatch):
+            return batch.text, batch.image, batch.target, batch.metadata
+        return batch.text, None, batch.target, batch.metadata
+
+    def _decode_output(self, output: torch.Tensor, instruction_length: int) -> List[str]:
+        """Decode the model's batch output to text.
+
+        Args:
+            output: The raw output from the model.
+            instruction_length: The length of the instruction in the input.
+
+        Returns:
+            A list of decoded text responses.
+        """
+        decoded_input = self.processor.batch_decode(  # type: ignore
+            output[:, :instruction_length], skip_special_tokens=True
+        )
+        decoded_output = self.processor.batch_decode(  # type: ignore
+            output[:, instruction_length:], skip_special_tokens=True
+        )
+
+        logger.debug(f"Decoded input: {decoded_input}")
+        logger.debug(f"Decoded output: {decoded_output}")
+
+        return decoded_output

eva/multimodal/models/wrappers/litellm.py
@@ -0,0 +1,56 @@
+"""LiteLLM vision-language model wrapper."""
+
+import logging
+from typing import Any, Dict, List
+
+from typing_extensions import override
+
+from eva.language.models import wrappers as language_wrappers
+from eva.language.utils.text import messages as language_message_utils
+from eva.multimodal.models.typings import TextImageBatch
+from eva.multimodal.models.wrappers import base
+from eva.multimodal.utils.text import messages as message_utils
+
+
+class LiteLLMModel(base.VisionLanguageModel):
+    """Wrapper class for LiteLLM vision-language models."""
+
+    def __init__(
+        self,
+        model_name: str,
+        model_kwargs: Dict[str, Any] | None = None,
+        system_prompt: str | None = None,
+        log_level: int | None = logging.INFO,
+    ):
+        """Initialize the LiteLLM Wrapper.
+
+        Args:
+            model_name: The name of the model to use.
+            model_kwargs: Additional keyword arguments to pass during
+                generation (e.g., `temperature`, `max_tokens`).
+            system_prompt: The system prompt to use (optional).
+            log_level: Optional logging level for LiteLLM. Defaults to WARNING.
+        """
+        super().__init__(system_prompt=system_prompt)
+
+        self.language_model = language_wrappers.LiteLLMModel(
+            model_name=model_name,
+            model_kwargs=model_kwargs,
+            system_prompt=system_prompt,
+            log_level=log_level,
+        )
+
+    @override
+    def format_inputs(self, batch: TextImageBatch) -> List[List[Dict[str, Any]]]:
+        message_batch, image_batch, _, _ = TextImageBatch(*batch)
+
+        message_batch = language_message_utils.batch_insert_system_message(
+            message_batch, self.system_message
+        )
+        message_batch = list(map(language_message_utils.combine_system_messages, message_batch))
+
+        return list(map(message_utils.format_litellm_message, message_batch, image_batch))
+
+    @override
+    def model_forward(self, batch: List[List[Dict[str, Any]]]) -> List[str]:
+        return self.language_model.model_forward(batch)

eva/multimodal/utils/__init__.py
@@ -0,0 +1 @@
+"""Multimodal utilities API."""

eva/multimodal/utils/image/__init__.py
@@ -0,0 +1,5 @@
+"""Multimodal image utilities API."""
+
+from eva.multimodal.utils.image.encode import encode_image
+
+__all__ = ["encode_image"]

eva/multimodal/utils/image/encode.py
@@ -0,0 +1,28 @@
+"""Image encoding utilities."""
+
+import base64
+import io
+from typing import Literal
+
+from torchvision import tv_tensors
+from torchvision.transforms.v2 import functional as F
+
+
+def encode_image(image: tv_tensors.Image, encoding: Literal["base64"]) -> str:
+    """Encodes an image tensor into a string format.
+
+    Args:
+        image: The image tensor to encode.
+        encoding: The encoding format to use. Currently only supports "base64".
+
+    Returns:
+        An encoded string representation of the image.
+    """
+    match encoding:
+        case "base64":
+            image_bytes = io.BytesIO()
+            F.to_pil_image(image).save(image_bytes, format="PNG", optimize=True)
+            image_bytes.seek(0)
+            return base64.b64encode(image_bytes.getvalue()).decode("utf-8")
+        case _:
+            raise ValueError(f"Unsupported encoding type: {encoding}. Supported: 'base64'")
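
A short usage sketch of encode_image on a dummy uint8 CHW tensor:

import torch
from torchvision import tv_tensors

from eva.multimodal.utils.image import encode_image

image = tv_tensors.Image(torch.zeros(3, 32, 32, dtype=torch.uint8))  # dummy black patch
encoded = encode_image(image, encoding="base64")
print(encoded[:32], "...")  # base64-encoded PNG bytes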

eva/multimodal/utils/text/__init__.py
@@ -0,0 +1 @@
+"""Multimodal text utilities API."""

eva/multimodal/utils/text/messages.py
@@ -0,0 +1,79 @@
+"""Message formatting utilities for multimodal models."""
+
+from typing import Any, Dict, List
+
+from torchvision import tv_tensors
+
+from eva.language import utils as language_utils
+from eva.language.data.messages import MessageSeries, Role
+from eva.multimodal.utils import image as image_utils
+
+
+def format_huggingface_message(
+    message: MessageSeries, with_images: bool = False
+) -> List[Dict[str, Any]]:
+    """Formats a message series into a format suitable for Huggingface models."""
+    if not with_images:
+        return language_utils.format_chat_message(message)
+
+    formatted_message = []
+    for item in message:
+        if item.role == Role.SYSTEM:
+            formatted_message += language_utils.format_chat_message([item])
+        else:
+            formatted_message.append(
+                {
+                    "role": item.role,
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": str(item.content),
+                        },
+                        {"type": "image"},
+                    ],
+                }
+            )
+    return formatted_message
+
+
+def format_litellm_message(
+    message: MessageSeries, image: tv_tensors.Image | None
+) -> List[Dict[str, Any]]:
+    """Format a message series for LiteLLM API.
+
+    Args:
+        message: The message series to format.
+        image: Optional image to include in the message.
+
+    Returns:
+        A list of formatted message dictionaries.
+    """
+    if image is None:
+        return language_utils.format_chat_message(message)
+
+    formatted_message = []
+    for item in message:
+        if item.role == Role.SYSTEM:
+            formatted_message += language_utils.format_chat_message([item])
+        else:
+            formatted_message.append(
+                {
+                    "role": item.role,
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": str(item.content),
+                        },
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": (
+                                    f"data:image/png;base64,"
+                                    f"{image_utils.encode_image(image, encoding='base64')}"
+                                )
+                            },
+                        },
+                    ],
+                }
+            )
+    return formatted_message
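
For a conversation with one user message plus an image, format_litellm_message should yield per-message dictionaries of roughly the following shape; the question text is illustrative, the role may be a Role enum member rather than a plain string, and the base64 payload is truncated:

formatted_entry = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Does this patch contain tumor tissue?"},
        {
            "type": "image_url",
            "image_url": {"url": "data:image/png;base64,iVBORw0KGgo..."},
        },
    ],
}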

eva/vision/data/datasets/classification/patch_camelyon.py
@@ -61,6 +61,13 @@ class PatchCamelyon(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
     ]
     """Test resources."""
 
+    _expected_length = {
+        "train": 262144,
+        "val": 32768,
+        "test": 32768,
+    }
+    """Expected dataset length for each split."""
+
     _license: str = (
         "Creative Commons Zero v1.0 Universal (https://choosealicense.com/licenses/cc0-1.0/)"
     )

@@ -113,14 +120,9 @@ class PatchCamelyon(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
 
     @override
     def validate(self) -> None:
-        expected_length = {
-            "train": 262144,
-            "val": 32768,
-            "test": 32768,
-        }
         _validators.check_dataset_integrity(
             self,
-            length=expected_length.get(self._split, 0),
+            length=self._expected_length.get(self._split, 0),
             n_classes=2,
             first_and_last_labels=("no_tumor", "tumor"),
         )

eva/vision/data/transforms/__init__.py
@@ -13,10 +13,11 @@ from eva.vision.data.transforms.intensity import (
     RandShiftIntensity,
     ScaleIntensityRange,
 )
-from eva.vision.data.transforms.spatial import RandFlip, RandRotate90, Spacing
+from eva.vision.data.transforms.spatial import RandFlip, RandRotate90, Resize, Spacing
 from eva.vision.data.transforms.utility import EnsureChannelFirst
 
 __all__ = [
+    "Resize",
     "ResizeAndCrop",
     "Squeeze",
     "CropForeground",

eva/vision/data/transforms/spatial/__init__.py
@@ -1,7 +1,8 @@
 """Transforms for spatial operations."""
 
 from eva.vision.data.transforms.spatial.flip import RandFlip
+from eva.vision.data.transforms.spatial.resize import Resize
 from eva.vision.data.transforms.spatial.rotate import RandRotate90
 from eva.vision.data.transforms.spatial.spacing import Spacing
 
-__all__ = ["Spacing", "RandFlip", "RandRotate90"]
+__all__ = ["Spacing", "RandFlip", "RandRotate90", "Resize"]

eva/vision/data/transforms/spatial/functional/__init__.py
@@ -0,0 +1,5 @@
+"""Functional API for spatial transforms."""
+
+from eva.vision.data.transforms.spatial.functional.resize import resize_to_max_bytes
+
+__all__ = ["resize_to_max_bytes"]