crfm-helm 0.5.0__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff shows the content of publicly released versions of this package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in that registry.
Potentially problematic release: this version of crfm-helm might be problematic.
- {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.1.dist-info}/METADATA +7 -3
- {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.1.dist-info}/RECORD +53 -41
- helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/multimodal_prompt.py +7 -0
- helm/benchmark/adaptation/adapters/multimodal/test_multimodal_prompt.py +2 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +1 -1
- helm/benchmark/augmentations/perturbation.py +17 -1
- helm/benchmark/augmentations/test_perturbation.py +30 -0
- helm/benchmark/metrics/efficiency_metrics.py +9 -2
- helm/benchmark/metrics/evaluate_reference_metrics.py +16 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +142 -17
- helm/benchmark/model_metadata_registry.py +5 -1
- helm/benchmark/run_expander.py +35 -63
- helm/benchmark/run_spec_factory.py +11 -10
- helm/benchmark/run_specs/vlm_run_specs.py +294 -38
- helm/benchmark/scenarios/legalbench_scenario.py +6 -2
- helm/benchmark/scenarios/math_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/a_okvqa_scenario.py +83 -0
- helm/benchmark/scenarios/vision_language/crossmodal_3600_scenario.py +134 -0
- helm/benchmark/scenarios/vision_language/flickr30k_scenario.py +74 -0
- helm/benchmark/scenarios/vision_language/gqa_scenario.py +91 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +4 -2
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/math_vista_scenario.py +117 -0
- helm/benchmark/scenarios/vision_language/mm_safety_bench_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/mscoco_captioning_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/mscoco_categorization_scenario.py +117 -0
- helm/benchmark/scenarios/vision_language/originality_scenario.py +35 -0
- helm/benchmark/scenarios/vision_language/pairs_scenario.py +246 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +2 -2
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +4 -2
- helm/benchmark/static/schema_image2structure.yaml +304 -0
- helm/benchmark/static/schema_vhelm_lite.yaml +164 -0
- helm/benchmark/static/schema_vlm.yaml +257 -10
- helm/benchmark/static_build/assets/index-737eef9e.js +10 -0
- helm/benchmark/static_build/assets/index-878a1094.css +1 -0
- helm/benchmark/static_build/index.html +2 -2
- helm/clients/anthropic_client.py +36 -6
- helm/clients/openai_client.py +2 -3
- helm/clients/together_client.py +93 -2
- helm/clients/vertexai_client.py +59 -50
- helm/clients/vision_language/huggingface_vision2seq_client.py +145 -0
- helm/clients/vision_language/huggingface_vlm_client.py +11 -4
- helm/clients/vision_language/idefics_client.py +2 -2
- helm/common/images_utils.py +10 -3
- helm/config/model_deployments.yaml +100 -2
- helm/config/model_metadata.yaml +136 -31
- helm/config/tokenizer_configs.yaml +7 -0
- helm/benchmark/static_build/assets/index-5088afcb.css +0 -1
- helm/benchmark/static_build/assets/index-d839df55.js +0 -9
- helm/benchmark/test_model_deployment_definition.py +0 -90
- {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.1.dist-info}/LICENSE +0 -0
- {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.1.dist-info}/WHEEL +0 -0
- {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.1.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.1.dist-info}/top_level.txt +0 -0
helm/benchmark/static_build/index.html
CHANGED

@@ -7,11 +7,11 @@
     <title>Holistic Evaluation of Language Models (HELM)</title>
     <meta name="description" content="The Holistic Evaluation of Language Models (HELM) serves as a living benchmark for transparency in language models. Providing broad coverage and recognizing incompleteness, multi-metric measurements, and standardization. All data and analysis are freely accessible on the website for exploration and study." />
     <script type="text/javascript" src="./config.js"></script>
-    <script type="module" crossorigin src="./assets/index-d839df55.js"></script>
+    <script type="module" crossorigin src="./assets/index-737eef9e.js"></script>
     <link rel="modulepreload" crossorigin href="./assets/react-d4a0b69b.js">
     <link rel="modulepreload" crossorigin href="./assets/recharts-6d337683.js">
     <link rel="modulepreload" crossorigin href="./assets/tremor-54a99cc4.js">
-    <link rel="stylesheet" href="./assets/index-5088afcb.css">
+    <link rel="stylesheet" href="./assets/index-878a1094.css">
   </head>
   <body class="block">
     <div id="root"></div>

helm/clients/anthropic_client.py
CHANGED

@@ -1,6 +1,7 @@
 from typing import Any, Dict, List, Optional, TypedDict, Union, cast
 import json
 import requests
+import tempfile
 import time
 import urllib.parse
 
@@ -68,6 +69,9 @@ class AnthropicClient(CachingClient):
     MAX_COMPLETION_LENGTH: int = (
         8192  # See https://docs.google.com/document/d/1vX6xgoA-KEKxqtMlBVAqYvE8KUfZ7ABCjTxAjf1T5kI/edit#
     )
+    # An Anthropic error message: "At least one of the image dimensions exceed max allowed size: 8000 pixels"
+    MAX_IMAGE_DIMENSION: int = 8000
+
     ADDITIONAL_TOKENS: int = 5
     PROMPT_ANSWER_START: str = "The answer is "
 
@@ -206,7 +210,7 @@ class AnthropicClient(CachingClient):
 
 
     def _is_content_moderation_failure(response: Dict) -> bool:
-        """Return whether a
+        """Return whether a response failed because of the content moderation filter."""
         if (
             "error" in response
             and "message" in response["error"]
@@ -238,7 +242,7 @@ class AnthropicMessagesResponseError(Exception):
 
 class AnthropicMessagesClient(CachingClient):
     # Source: https://docs.anthropic.com/claude/docs/models-overview
-    MAX_OUTPUT_TOKENS = 4096
+    MAX_OUTPUT_TOKENS: int = 4096
 
     def __init__(
         self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig, api_key: Optional[str] = None
@@ -273,7 +277,7 @@ class AnthropicMessagesClient(CachingClient):
         # TODO(#2439): Refactor out Request validation
         if request.messages is not None or request.prompt:
             raise AnthropicMessagesRequestError(
-                "Exactly one of Request.messages, Request.prompt or Request.
+                "Exactly one of Request.messages, Request.prompt or Request.multimodal_prompt should be set"
             )
         blocks: List[Union[TextBlockParam, ImageBlockParam]] = []
         for media_object in request.multimodal_prompt.media_objects:
@@ -282,9 +286,33 @@ class AnthropicMessagesClient(CachingClient):
                 if not media_object.location:
                     raise Exception("MediaObject of image type has missing location field value")
 
-                from helm.common.images_utils import encode_base64
+                from helm.common.images_utils import encode_base64, get_dimensions, copy_image
+
+                image_location: str = media_object.location
+                base64_image: str
+
+                image_width, image_height = get_dimensions(media_object.location)
+                if (
+                    image_width > AnthropicClient.MAX_IMAGE_DIMENSION
+                    or image_height > AnthropicClient.MAX_IMAGE_DIMENSION
+                ):
+                    hlog(
+                        f"WARNING: Image {image_location} exceeds max allowed size: "
+                        f"{AnthropicClient.MAX_IMAGE_DIMENSION} pixels"
+                    )
+                    # Save the resized image to a temporary file
+                    with tempfile.NamedTemporaryFile(suffix=".jpg") as temp_file:
+                        hlog(f"Resizing image to temporary path: {temp_file.name}")
+                        copy_image(
+                            src=image_location,
+                            dest=temp_file.name,
+                            width=min(image_width, AnthropicClient.MAX_IMAGE_DIMENSION),
+                            height=min(image_height, AnthropicClient.MAX_IMAGE_DIMENSION),
+                        )
+                        base64_image = encode_base64(temp_file.name, format="JPEG")
+                else:
+                    base64_image = encode_base64(image_location, format="JPEG")
 
-                base64_image: str = encode_base64(media_object.location, format="JPEG")
                 image_block: ImageBlockParam = {
                     "type": "image",
                     "source": {
@@ -302,7 +330,9 @@ class AnthropicMessagesClient(CachingClient):
                     "type": "text",
                     "text": media_object.text,
                 }
-                blocks
+                # Anthropic does not support empty text blocks
+                if media_object.text.strip():
+                    blocks.append(text_block)
             messages = [{"role": "user", "content": blocks}]
 
         else:

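For context, the new guard in AnthropicMessagesClient downsizes any image whose width or height exceeds MAX_IMAGE_DIMENSION before base64-encoding it. Below is a minimal standalone sketch of the same flow using Pillow directly; the helper names get_dimensions, copy_image, and encode_base64 come from helm.common.images_utils in the diff above, while the function name and aspect-ratio handling here are illustrative assumptions.

# Illustrative sketch only: mirrors the resize-then-encode logic above using Pillow
# instead of helm.common.images_utils; this is not HELM's actual implementation.
import base64
import io

from PIL import Image  # assumes Pillow is installed

MAX_IMAGE_DIMENSION = 8000  # Anthropic rejects images larger than 8000 pixels on either side


def encode_image_for_anthropic(image_location: str) -> str:
    """Return a base64-encoded JPEG whose dimensions fit within the allowed maximum."""
    with Image.open(image_location) as image:
        if max(image.size) > MAX_IMAGE_DIMENSION:
            # Downscale in place, preserving the aspect ratio
            image.thumbnail((MAX_IMAGE_DIMENSION, MAX_IMAGE_DIMENSION))
        buffer = io.BytesIO()
        image.convert("RGB").save(buffer, format="JPEG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")
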
helm/clients/openai_client.py
CHANGED

@@ -130,9 +130,8 @@ class OpenAIClient(CachingClient):
                 from helm.common.images_utils import encode_base64
 
                 base64_image: str = encode_base64(media_object.location)
-
-
-                )
+                image_object: Dict[str, str] = {"url": f"data:image/jpeg;base64,{base64_image}"}
+                content.append({"type": "image_url", "image_url": image_object})
             elif media_object.is_type(TEXT_TYPE):
                 if media_object.text is None:
                     raise ValueError("MediaObject of text type has missing text field value")

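The rewritten block sends each image to the OpenAI chat API as a data URL inside an image_url content part. A minimal sketch of the resulting message payload follows; the file name and prompt text are placeholders.

# Illustrative payload shape only; "example.jpg" and the prompt text are placeholders.
import base64

with open("example.jpg", "rb") as image_file:
    base64_image = base64.b64encode(image_file.read()).decode("utf-8")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
        ],
    }
]
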
helm/clients/together_client.py
CHANGED

@@ -1,12 +1,20 @@
 from copy import deepcopy
-from
+from itertools import zip_longest
+from typing import List, Dict, Any, Optional, TypedDict, Union
 
 import requests
 from retrying import retry
 
 from helm.common.cache import CacheConfig
+from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput, Token
-from .client import CachingClient, truncate_sequence, cleanup_str
+from helm.clients.client import CachingClient, truncate_sequence, cleanup_str
+
+try:
+    from together import Together
+    from together.types import ChatCompletionResponse
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["together"])
 
 
 class _RewriteRequestTags:
@@ -272,3 +280,86 @@ class TogetherClient(CachingClient):
             completions=completions,
             embedding=[],
         )
+
+
+class TogetherRawChatRequest(TypedDict):
+    messages: List[Dict[str, str]]
+    model: str
+    max_tokens: int
+    stop: List[str]
+    temperature: float
+    top_p: float
+    top_k: int
+    logprobs: int
+    echo: bool
+    n: int
+
+
+def convert_to_raw_chat_request(request: Request) -> TogetherRawChatRequest:
+    if request.messages:
+        messages = request.messages
+    else:
+        messages = [{"role": "user", "content": request.prompt}]
+    return {
+        "messages": messages,
+        "model": request.model,
+        "max_tokens": request.max_tokens,
+        "stop": request.stop_sequences,
+        "temperature": request.temperature,
+        "top_p": request.top_p,
+        "top_k": request.top_k_per_token,
+        "logprobs": min(request.top_k_per_token, 1),
+        "echo": request.echo_prompt,
+        "n": request.num_completions,
+    }
+
+
+class TogetherChatClient(CachingClient):
+    """Client that uses the Python Together library for chat models."""
+
+    def __init__(self, cache_config: CacheConfig, api_key: str, together_model: Optional[str] = None):
+        super().__init__(cache_config=cache_config)
+        self._client = Together(api_key=api_key)
+
+    def make_request(self, request: Request) -> RequestResult:
+        raw_request = convert_to_raw_chat_request(request)
+        cache_key = CachingClient.make_cache_key(raw_request, request)
+
+        def do_it() -> Dict[Any, Any]:
+            response = self._client.chat.completions.create(**raw_request)
+            return response.model_dump(mode="json")
+
+        try:
+            raw_response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+            response = ChatCompletionResponse.model_validate(raw_response)
+        except Exception as error:
+            return RequestResult(
+                success=False,
+                cached=False,
+                error=str(error),
+                completions=[],
+                embedding=[],
+            )
+
+        generated_outputs: List[GeneratedOutput] = []
+        for choice in response.choices:
+            # NOTE: Together always returns None for choice.finish_reason
+            # NOTE: Together does not return logprobs for the whole generated output, only for individual tokens
+            tokens: List[Token] = []
+            if choice.logprobs:
+                for token_text, token_logprob in zip_longest(
+                    choice.logprobs.tokens or [], choice.logprobs.token_logprobs or []
+                ):
+                    if token_text is None:
+                        break
+                    tokens.append(Token(text=token_text, logprob=token_logprob or 0.0))
+            assert choice.message.role == "assistant"
+            generated_outputs.append(GeneratedOutput(text=choice.message.content, logprob=0.0, tokens=tokens))
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=raw_response["request_time"],
+            request_datetime=raw_response["request_datetime"],
+            completions=generated_outputs,
+            embedding=[],
+        )

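The new TogetherChatClient wraps the official together SDK, and convert_to_raw_chat_request maps a HELM Request onto the keyword arguments of the SDK's chat-completions call. A rough standalone equivalent is sketched below; the API key and model name are placeholders and the parameter values are arbitrary.

# Illustrative standalone call only; API key and model name are placeholders.
from together import Together

client = Together(api_key="YOUR_API_KEY")
response = client.chat.completions.create(
    model="some-org/some-chat-model",  # placeholder model name
    messages=[{"role": "user", "content": "Say hello."}],
    max_tokens=32,
    stop=[],
    temperature=0.0,
    top_p=1.0,
    top_k=1,
    logprobs=1,
    echo=False,
    n=1,
)
print(response.choices[0].message.content)
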
helm/clients/vertexai_client.py
CHANGED

@@ -4,7 +4,6 @@ from threading import Lock
 from typing import Any, Dict, Optional, List, Union
 
 from helm.common.cache import CacheConfig
-from helm.common.hierarchical_logger import hlog
 from helm.common.media_object import TEXT_TYPE
 from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput, ErrorFlags
@@ -131,12 +130,6 @@ class VertexAITextClient(VertexAIClient):
 class VertexAIChatClient(VertexAIClient):
     """Client for Vertex AI chat models (e.g., Gemini). Supports multimodal prompts."""
 
-    # Set the finish reason to this if the prompt violates the content policy
-    CONTENT_POLICY_VIOLATED_FINISH_REASON: str = "The prompt violates Google's content policy."
-
-    # Gemini returns this error for certain valid requests
-    CONTENT_HAS_NO_PARTS_ERROR: str = "Content has no parts."
-
     # Enum taken from:
     # https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1beta1#google.cloud.aiplatform.v1beta1.Candidate.FinishReason
     # We don't directly import this enum because it can differ between different Vertex AI library versions.
@@ -149,7 +142,7 @@ class VertexAIChatClient(VertexAIClient):
     ]
 
     @staticmethod
-    def get_model(model_name: str) ->
+    def get_model(model_name: str) -> GenerativeModel:
         global _models_lock
         global _models
 
@@ -202,21 +195,22 @@ class VertexAIChatClient(VertexAIClient):
             )
             candidates: List[Candidate] = response.candidates
 
-            # Depending on the version of the Vertex AI library and the type of
-            #
+            # Depending on the version of the Vertex AI library and the type of prompt blocking,
+            # prompt blocking can show up in many ways, so this defensively handles most of these ways
+            if response.prompt_feedback.block_reason:
+                raise VertexAIContentBlockedError(
+                    f"Prompt blocked with reason: {response.prompt_feedback.block_reason}"
+                )
             if not candidates:
-                raise VertexAIContentBlockedError("No candidates in response
+                raise VertexAIContentBlockedError(f"No candidates in response: {response}")
             predictions: List[Dict[str, Any]] = []
             for candidate in candidates:
-
-
-
-
-
-
-                # For now, we don't cache blocked requests, because we are trying to get the
-                # content blocking removed.
-                raise VertexAIContentBlockedError("Content has no parts due to content blocking")
+                # Depending on the version of the Vertex AI library and the type of prompt blocking,
+                # content blocking can show up in many ways, so this defensively handles most of these ways
+                if candidate.finish_reason in VertexAIChatClient.CONTENT_BLOCKED_FINISH_REASONS:
+                    raise VertexAIContentBlockedError(f"Content blocked with reason: {candidate.finish_reason}")
+                if not candidate.content.parts:
+                    raise VertexAIContentBlockedError(f"No parts in candidate: {candidate}")
                 predictions.append({"text": candidate.content.text})
                 # TODO: Extract more information from the response
             return {"predictions": predictions}
@@ -234,11 +228,11 @@ class VertexAIChatClient(VertexAIClient):
             )
 
             response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
-        except VertexAIContentBlockedError:
+        except VertexAIContentBlockedError as e:
             return RequestResult(
                 success=False,
                 cached=False,
-                error="
+                error=f"Content blocked: {str(e)}",
                 completions=[],
                 embedding=[],
                 error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
@@ -252,7 +246,7 @@ class VertexAIChatClient(VertexAIClient):
             return RequestResult(
                 success=False,
                 cached=False,
-                error="
+                error=f"Content blocked error in cached response: {str(response)}",
                 completions=[],
                 embedding=[],
                 error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
@@ -266,7 +260,7 @@ class VertexAIChatClient(VertexAIClient):
                 return RequestResult(
                     success=False,
                     cached=False,
-                    error="
+                    error=f"Content blocked error in cached prediction: {str(prediction)}",
                     completions=[],
                     embedding=[],
                     error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
@@ -291,21 +285,6 @@ class VertexAIChatClient(VertexAIClient):
         )
 
     def _make_multimodal_request(self, request: Request) -> RequestResult:
-        def complete_for_valid_error(error_message: str) -> RequestResult:
-            empty_completion = GeneratedOutput(
-                text="",
-                logprob=0,
-                tokens=[],
-                finish_reason={"reason": error_message},
-            )
-            return RequestResult(
-                success=True,
-                cached=False,
-                request_time=0,
-                completions=[empty_completion] * request.num_completions,
-                embedding=[],
-            )
-
         # Contents can either be text or a list of multimodal content made up of text, images or other content
         contents: Union[str, List[Union[str, Any]]] = request.prompt
         # Used to generate a unique cache key for this specific request
@@ -346,14 +325,29 @@ class VertexAIChatClient(VertexAIClient):
         try:
 
             def do_it() -> Dict[str, Any]:
-
+                response: GenerationResponse = model.generate_content(
                     contents, generation_config=parameters, safety_settings=self.safety_settings
                 )
-
-
-
-
-
+                # Depending on the version of the Vertex AI library and the type of prompt blocking,
+                # prompt blocking can show up in many ways, so this defensively handles most of these ways
+                if response.prompt_feedback.block_reason:
+                    raise VertexAIContentBlockedError(
+                        f"Prompt blocked with reason: {response.prompt_feedback.block_reason}"
+                    )
+                if not response.candidates:
+                    raise VertexAIContentBlockedError(f"No candidates in response: {response}")
+                # We should only have one candidate
+                assert (
+                    len(response.candidates) == 1
+                ), f"Expected 1 candidate since candidate_count is 1, got {len(response.candidates)}."
+                candidate = response.candidates[0]
+                # Depending on the version of the Vertex AI library and the type of prompt blocking,
+                # content blocking can show up in many ways, so this defensively handles most of these ways
+                if candidate.finish_reason in VertexAIChatClient.CONTENT_BLOCKED_FINISH_REASONS:
+                    raise VertexAIContentBlockedError(f"Content blocked with reason: {candidate.finish_reason}")
+                if not candidate.content.parts:
+                    raise VertexAIContentBlockedError(f"No parts in candidate: {candidate}")
+                return {"predictions": [{"text": candidate.text}]}
 
             raw_cache_key = {"model_name": model_name, "prompt": prompt_key, **parameters}
             if completion_index > 0:
@@ -361,15 +355,30 @@ class VertexAIChatClient(VertexAIClient):
 
             cache_key = CachingClient.make_cache_key(raw_cache_key, request)
             response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
-        except
-            if str(e) == self.CONTENT_HAS_NO_PARTS_ERROR:
-                return complete_for_valid_error(self.CONTENT_HAS_NO_PARTS_ERROR)
-
+        except requests.exceptions.RequestException as e:
            error: str = f"Gemini Vision error: {e}"
            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+        except VertexAIContentBlockedError as e:
+            return RequestResult(
+                success=False,
+                cached=False,
+                error=f"Content blocked: {str(e)}",
+                completions=[],
+                embedding=[],
+                error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
+            )
 
         if "error" in response:
-            return
+            return RequestResult(
+                success=False,
+                cached=True,
+                error=f"Content blocked error in cached response: {str(response)}",
+                completions=[],
+                embedding=[],
+                error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
+                request_time=response["request_time"],
+                request_datetime=response["request_datetime"],
+            )
 
         response_text = response["predictions"][0]["text"]
         completion = GeneratedOutput(text=response_text, logprob=0, tokens=[])

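The changes above replace the old sentinel strings with defensive checks on every Gemini response: a blocked prompt, an empty candidate list, a blocked finish reason, or a candidate with no parts all raise VertexAIContentBlockedError. A condensed sketch of that check is shown below, with the response typed loosely and the finish-reason list left as a placeholder.

# Condensed illustrative sketch of the content-blocking checks above; the finish-reason
# values are placeholders (HELM keeps its own CONTENT_BLOCKED_FINISH_REASONS list).
from typing import Any, List


class VertexAIContentBlockedError(Exception):
    pass


CONTENT_BLOCKED_FINISH_REASONS: List[int] = []  # placeholder; populate from the FinishReason enum


def extract_text_or_raise(response: Any) -> str:
    """Return the candidate text, raising if the prompt or the completion was blocked."""
    if response.prompt_feedback.block_reason:
        raise VertexAIContentBlockedError(f"Prompt blocked with reason: {response.prompt_feedback.block_reason}")
    if not response.candidates:
        raise VertexAIContentBlockedError(f"No candidates in response: {response}")
    candidate = response.candidates[0]
    if candidate.finish_reason in CONTENT_BLOCKED_FINISH_REASONS:
        raise VertexAIContentBlockedError(f"Content blocked with reason: {candidate.finish_reason}")
    if not candidate.content.parts:
        raise VertexAIContentBlockedError(f"No parts in candidate: {candidate}")
    return candidate.text
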
helm/clients/vision_language/huggingface_vision2seq_client.py
ADDED

@@ -0,0 +1,145 @@
+from threading import Lock
+from typing import Any, Dict, List, Optional
+
+from dataclasses import dataclass
+from transformers import AutoProcessor, AutoModelForVision2Seq
+from transformers.image_utils import load_image
+import torch
+
+from helm.common.cache import CacheConfig
+from helm.common.gpu_utils import get_torch_device_name, is_cuda_available
+from helm.common.hierarchical_logger import hlog, htrack_block
+from helm.common.media_object import TEXT_TYPE
+from helm.common.request import Request, RequestResult, GeneratedOutput, Token
+from helm.common.request import wrap_request_time
+from helm.common.tokenization_request import TokenizationRequest
+from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt
+from helm.tokenizers.tokenizer import Tokenizer
+
+
+@dataclass(frozen=True)
+class Vision2SeqModelProcessor:
+    """Loaded model and processor."""
+
+    model: AutoModelForVision2Seq
+    processor: AutoProcessor
+
+
+_models_lock: Lock = Lock()
+_models: Dict[str, Optional[Vision2SeqModelProcessor]] = {
+    "HuggingFaceM4/idefics2-8b": None,
+}
+
+
+class HuggingFaceVision2SeqClient(CachingClient):
+    """
+    Models for Vision2Seq models from HuggingFace.
+    """
+
+    ASSISTANT_PREFIX: str = "Assistant:"
+
+    def __init__(self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig):
+        super().__init__(cache_config=cache_config)
+        self.tokenizer = tokenizer
+        self.tokenizer_name = tokenizer_name
+        self._device: str = get_torch_device_name()
+
+    def _get_model(self, checkpoint: str) -> Vision2SeqModelProcessor:
+        global _models_lock
+        global _models
+
+        # Ensure that only one thread is loading the model at a time
+        with _models_lock:
+            loaded_model_processor = _models[checkpoint]
+            if loaded_model_processor is None:
+                hlog(f"Loading model {checkpoint} and caching in memory...")
+                torch_dtype: torch.dtype = torch.float16 if is_cuda_available() else torch.float32
+                model = AutoModelForVision2Seq.from_pretrained(checkpoint, torch_dtype=torch_dtype).to(self._device)
+                processor = AutoProcessor.from_pretrained(checkpoint)
+
+                _models[checkpoint] = Vision2SeqModelProcessor(model, processor)
+                loaded_model_processor = _models[checkpoint]
+
+        assert loaded_model_processor is not None
+        return loaded_model_processor
+
+    def make_request(self, request: Request) -> RequestResult:
+        assert request.model_deployment in _models, f"Not a valid model for this client: {request.model_deployment}"
+        assert request.multimodal_prompt is not None, "Multimodal prompt is required"
+
+        loaded_model_processor: Vision2SeqModelProcessor = self._get_model(request.model_deployment)
+        model = loaded_model_processor.model
+        processor = loaded_model_processor.processor
+
+        generation_args: Dict[str, Any] = {
+            "max_new_tokens": request.max_tokens,
+        }
+
+        image_paths: List[str] = []
+        multimodal_prompt: List[Dict[str, str]] = []
+        for media_object in request.multimodal_prompt.media_objects:
+            if media_object.is_type("image") and media_object.location:
+                image_paths.append(media_object.location)
+                multimodal_prompt.append({"type": "image"})
+            elif media_object.is_type(TEXT_TYPE):
+                if media_object.text is None:
+                    raise ValueError("MediaObject of text type has missing text field value")
+
+                multimodal_prompt.append({"type": "text", "text": media_object.text})
+            else:
+                raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
+
+        completions: List[GeneratedOutput] = []
+        with htrack_block(f"Generating for prompt: {request.multimodal_prompt.text}"):
+            try:
+
+                def do_it() -> Dict[str, Any]:
+                    messages = [{"role": "user", "content": multimodal_prompt}]
+                    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+                    inputs = processor(
+                        text=[prompt] * request.num_completions,
+                        images=[
+                            [load_image(image_path) for image_path in image_paths]
+                            for _ in range(request.num_completions)
+                        ],
+                        return_tensors="pt",
+                    )
+                    inputs = {k: v.to(self._device) for k, v in inputs.items()}
+
+                    # Generate
+                    generated_ids = model.generate(**inputs, **generation_args)
+                    generated_texts: List[str] = processor.batch_decode(generated_ids, skip_special_tokens=True)
+                    return {"output": generated_texts}
+
+                # Include the prompt and model name in the cache key
+                cache_key = CachingClient.make_cache_key(
+                    raw_request={
+                        "n": request.num_completions,
+                        "model": request.model,
+                        "prompt": generate_uid_for_multimodal_prompt(request.multimodal_prompt),
+                        **generation_args,
+                    },
+                    request=request,
+                )
+                result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+            except RuntimeError as model_error:
+                return RequestResult(success=False, cached=False, error=str(model_error), completions=[], embedding=[])
+
+            for text in result["output"]:
+                hlog(f"Generated text: {text}")
+                assert self.ASSISTANT_PREFIX in text, f"Expected {self.ASSISTANT_PREFIX} in the output"
+                text = text.rpartition(self.ASSISTANT_PREFIX)[-1]
+                hlog(f"Truncated: {text}")
+                tokenization_result = self.tokenizer.tokenize(
+                    TokenizationRequest(text, tokenizer=self.tokenizer_name, encode=False)
+                )
+                tokens: List[Token] = [Token(text=str(text), logprob=0) for text in tokenization_result.raw_tokens]
+                completions.append(GeneratedOutput(text=text, logprob=0, tokens=tokens))
+
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=result["request_time"],
+            completions=completions,
+            embedding=[],
+        )

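The new client drives the standard transformers Vision2Seq flow: apply the processor's chat template, tokenize the prompt together with its images, generate, then batch-decode. A pared-down standalone sketch for a single image and prompt follows; the checkpoint, image path, and prompt text are placeholders.

# Illustrative standalone sketch of the transformers calls used above; paths and prompt are placeholders.
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor
from transformers.image_utils import load_image

checkpoint = "HuggingFaceM4/idefics2-8b"
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained(checkpoint)
model = AutoModelForVision2Seq.from_pretrained(
    checkpoint, torch_dtype=torch.float16 if device == "cuda" else torch.float32
).to(device)

messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is in this image?"}]}]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=[prompt], images=[[load_image("example.jpg")]], return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

generated_ids = model.generate(**inputs, max_new_tokens=64)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
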
helm/clients/vision_language/huggingface_vlm_client.py
CHANGED

@@ -25,7 +25,7 @@ except ModuleNotFoundError as e:
 
 class HuggingFaceVLMClient(CachingClient):
     """
-    General
+    General client for VLM models from HuggingFace.
     """
 
     _models_lock: Lock = Lock()
@@ -34,6 +34,10 @@ class HuggingFaceVLMClient(CachingClient):
         "huggingface/llava-1.5-7b-hf": "llava-hf/llava-1.5-7b-hf",
         "huggingface/llava-1.5-13b-hf": "llava-hf/llava-1.5-13b-hf",
         "huggingface/bakLlava-v1-hf": "llava-hf/bakLlava-v1-hf",
+        "huggingface/llava-v1.6-vicuna-7b-hf": "llava-hf/llava-v1.6-vicuna-7b-hf",
+        "huggingface/llava-v1.6-vicuna-13b-hf": "llava-hf/llava-v1.6-vicuna-13b-hf",
+        "huggingface/llava-v1.6-mistral-7b-hf": "llava-hf/llava-v1.6-mistral-7b-hf",
+        "huggingface/llava-v1.6-34b-hf": "llava-hf/llava-v1.6-34b-hf",
     }
 
     def __init__(self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig):
@@ -45,7 +49,7 @@ class HuggingFaceVLMClient(CachingClient):
         with self._models_lock:
             model_id: str = self._models_aliases.get(model_name, model_name)
             if model_id not in self._models:
-                self._models[model_id] = pipeline("image-to-text", model=model_id)
+                self._models[model_id] = pipeline("image-to-text", model=model_id, device_map="auto")
         return self._models[model_id]
 
     def make_request(self, request: Request) -> RequestResult:
@@ -90,11 +94,14 @@ class HuggingFaceVLMClient(CachingClient):
             except RuntimeError as e:
                 return RequestResult(success=False, cached=False, error=str(e), completions=[], embedding=[])
 
+        output: str = result["generated_text"]
+        if "ASSISTANT: " in output:
+            output = output.split("ASSISTANT: ")[1]
         tokenization_result: TokenizationRequestResult = self.tokenizer.tokenize(
-            TokenizationRequest(
+            TokenizationRequest(output, tokenizer=self.tokenizer_name)
         )
         tokens: List[Token] = [Token(text=str(text), logprob=0) for text in tokenization_result.raw_tokens]
-        completions: List[GeneratedOutput] = [GeneratedOutput(text=
+        completions: List[GeneratedOutput] = [GeneratedOutput(text=output, logprob=0, tokens=tokens)]
         return RequestResult(
             success=True,
             cached=cached,

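The LLaVA checkpoints added here are all served through the same transformers image-to-text pipeline, now constructed with device_map="auto" so weights are placed automatically. A minimal usage sketch follows; the model name, image path, and prompt template are placeholders.

# Illustrative pipeline usage only; model name, image path, and prompt template are placeholders.
from transformers import pipeline

image_to_text = pipeline("image-to-text", model="llava-hf/llava-1.5-7b-hf", device_map="auto")
result = image_to_text("example.jpg", prompt="USER: <image>\nDescribe this image.\nASSISTANT:")
print(result[0]["generated_text"])
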
helm/clients/vision_language/idefics_client.py
CHANGED

@@ -88,7 +88,7 @@ class IDEFICSClient(CachingClient):
 
         input_args: Dict[str, Union[str, bool]] = {"return_tensors": "pt"}
         generation_args = {
-            "
+            "max_new_tokens": request.max_tokens,
             "bad_words_ids": processor.tokenizer(self.BAD_WORD_TOKENS, add_special_tokens=False).input_ids,
         }
 
@@ -140,7 +140,7 @@ class IDEFICSClient(CachingClient):
 
         # Truncate the output text as IDEFICS outputs the entire sequence including the prompt
         if "instruct" in request.model:
-            assert self.ASSISTANT_PREFIX in text, f"Expected {self.ASSISTANT_PREFIX} in the output"
+            assert self.ASSISTANT_PREFIX in text, f"Expected {self.ASSISTANT_PREFIX} in the output: {text}"
             text = text.rpartition(self.ASSISTANT_PREFIX)[-1]
         else:
             # Best we can do is to remove the text portion of the prompt from the output