crfm-helm 0.5.0__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (56)
  1. {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.1.dist-info}/METADATA +7 -3
  2. {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.1.dist-info}/RECORD +53 -41
  3. helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
  4. helm/benchmark/adaptation/adapters/multimodal/multimodal_prompt.py +7 -0
  5. helm/benchmark/adaptation/adapters/multimodal/test_multimodal_prompt.py +2 -0
  6. helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +1 -1
  7. helm/benchmark/augmentations/perturbation.py +17 -1
  8. helm/benchmark/augmentations/test_perturbation.py +30 -0
  9. helm/benchmark/metrics/efficiency_metrics.py +9 -2
  10. helm/benchmark/metrics/evaluate_reference_metrics.py +16 -0
  11. helm/benchmark/metrics/vision_language/image_metrics.py +142 -17
  12. helm/benchmark/model_metadata_registry.py +5 -1
  13. helm/benchmark/run_expander.py +35 -63
  14. helm/benchmark/run_spec_factory.py +11 -10
  15. helm/benchmark/run_specs/vlm_run_specs.py +294 -38
  16. helm/benchmark/scenarios/legalbench_scenario.py +6 -2
  17. helm/benchmark/scenarios/math_scenario.py +1 -1
  18. helm/benchmark/scenarios/vision_language/a_okvqa_scenario.py +83 -0
  19. helm/benchmark/scenarios/vision_language/crossmodal_3600_scenario.py +134 -0
  20. helm/benchmark/scenarios/vision_language/flickr30k_scenario.py +74 -0
  21. helm/benchmark/scenarios/vision_language/gqa_scenario.py +91 -0
  22. helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +4 -2
  23. helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +1 -1
  24. helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +1 -1
  25. helm/benchmark/scenarios/vision_language/math_vista_scenario.py +117 -0
  26. helm/benchmark/scenarios/vision_language/mm_safety_bench_scenario.py +103 -0
  27. helm/benchmark/scenarios/vision_language/mscoco_captioning_scenario.py +92 -0
  28. helm/benchmark/scenarios/vision_language/mscoco_categorization_scenario.py +117 -0
  29. helm/benchmark/scenarios/vision_language/originality_scenario.py +35 -0
  30. helm/benchmark/scenarios/vision_language/pairs_scenario.py +246 -0
  31. helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +2 -2
  32. helm/benchmark/scenarios/vision_language/vqa_scenario.py +4 -2
  33. helm/benchmark/static/schema_image2structure.yaml +304 -0
  34. helm/benchmark/static/schema_vhelm_lite.yaml +164 -0
  35. helm/benchmark/static/schema_vlm.yaml +257 -10
  36. helm/benchmark/static_build/assets/index-737eef9e.js +10 -0
  37. helm/benchmark/static_build/assets/index-878a1094.css +1 -0
  38. helm/benchmark/static_build/index.html +2 -2
  39. helm/clients/anthropic_client.py +36 -6
  40. helm/clients/openai_client.py +2 -3
  41. helm/clients/together_client.py +93 -2
  42. helm/clients/vertexai_client.py +59 -50
  43. helm/clients/vision_language/huggingface_vision2seq_client.py +145 -0
  44. helm/clients/vision_language/huggingface_vlm_client.py +11 -4
  45. helm/clients/vision_language/idefics_client.py +2 -2
  46. helm/common/images_utils.py +10 -3
  47. helm/config/model_deployments.yaml +100 -2
  48. helm/config/model_metadata.yaml +136 -31
  49. helm/config/tokenizer_configs.yaml +7 -0
  50. helm/benchmark/static_build/assets/index-5088afcb.css +0 -1
  51. helm/benchmark/static_build/assets/index-d839df55.js +0 -9
  52. helm/benchmark/test_model_deployment_definition.py +0 -90
  53. {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.1.dist-info}/LICENSE +0 -0
  54. {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.1.dist-info}/WHEEL +0 -0
  55. {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.1.dist-info}/entry_points.txt +0 -0
  56. {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.1.dist-info}/top_level.txt +0 -0

helm/benchmark/static_build/index.html
@@ -7,11 +7,11 @@
     <title>Holistic Evaluation of Language Models (HELM)</title>
     <meta name="description" content="The Holistic Evaluation of Language Models (HELM) serves as a living benchmark for transparency in language models. Providing broad coverage and recognizing incompleteness, multi-metric measurements, and standardization. All data and analysis are freely accessible on the website for exploration and study." />
     <script type="text/javascript" src="./config.js"></script>
-    <script type="module" crossorigin src="./assets/index-d839df55.js"></script>
+    <script type="module" crossorigin src="./assets/index-737eef9e.js"></script>
     <link rel="modulepreload" crossorigin href="./assets/react-d4a0b69b.js">
     <link rel="modulepreload" crossorigin href="./assets/recharts-6d337683.js">
     <link rel="modulepreload" crossorigin href="./assets/tremor-54a99cc4.js">
-    <link rel="stylesheet" href="./assets/index-5088afcb.css">
+    <link rel="stylesheet" href="./assets/index-878a1094.css">
   </head>
   <body class="block">
     <div id="root"></div>

helm/clients/anthropic_client.py
@@ -1,6 +1,7 @@
 from typing import Any, Dict, List, Optional, TypedDict, Union, cast
 import json
 import requests
+import tempfile
 import time
 import urllib.parse
 
@@ -68,6 +69,9 @@ class AnthropicClient(CachingClient):
     MAX_COMPLETION_LENGTH: int = (
         8192  # See https://docs.google.com/document/d/1vX6xgoA-KEKxqtMlBVAqYvE8KUfZ7ABCjTxAjf1T5kI/edit#
     )
+    # An Anthropic error message: "At least one of the image dimensions exceed max allowed size: 8000 pixels"
+    MAX_IMAGE_DIMENSION: int = 8000
+
     ADDITIONAL_TOKENS: int = 5
     PROMPT_ANSWER_START: str = "The answer is "
 
@@ -206,7 +210,7 @@ class AnthropicClient(CachingClient):
 
 
 def _is_content_moderation_failure(response: Dict) -> bool:
-    """Return whether a a response failed because of the content moderation filter."""
+    """Return whether a response failed because of the content moderation filter."""
     if (
         "error" in response
         and "message" in response["error"]
@@ -238,7 +242,7 @@ class AnthropicMessagesResponseError(Exception):
 
 class AnthropicMessagesClient(CachingClient):
     # Source: https://docs.anthropic.com/claude/docs/models-overview
-    MAX_OUTPUT_TOKENS = 4096
+    MAX_OUTPUT_TOKENS: int = 4096
 
     def __init__(
         self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig, api_key: Optional[str] = None
@@ -273,7 +277,7 @@ class AnthropicMessagesClient(CachingClient):
             # TODO(#2439): Refactor out Request validation
             if request.messages is not None or request.prompt:
                 raise AnthropicMessagesRequestError(
-                    "Exactly one of Request.messages, Request.prompt or Request.multimodel_prompt should be set"
+                    "Exactly one of Request.messages, Request.prompt or Request.multimodal_prompt should be set"
                 )
             blocks: List[Union[TextBlockParam, ImageBlockParam]] = []
             for media_object in request.multimodal_prompt.media_objects:
@@ -282,9 +286,33 @@ class AnthropicMessagesClient(CachingClient):
                     if not media_object.location:
                         raise Exception("MediaObject of image type has missing location field value")
 
-                    from helm.common.images_utils import encode_base64
+                    from helm.common.images_utils import encode_base64, get_dimensions, copy_image
+
+                    image_location: str = media_object.location
+                    base64_image: str
+
+                    image_width, image_height = get_dimensions(media_object.location)
+                    if (
+                        image_width > AnthropicClient.MAX_IMAGE_DIMENSION
+                        or image_height > AnthropicClient.MAX_IMAGE_DIMENSION
+                    ):
+                        hlog(
+                            f"WARNING: Image {image_location} exceeds max allowed size: "
+                            f"{AnthropicClient.MAX_IMAGE_DIMENSION} pixels"
+                        )
+                        # Save the resized image to a temporary file
+                        with tempfile.NamedTemporaryFile(suffix=".jpg") as temp_file:
+                            hlog(f"Resizing image to temporary path: {temp_file.name}")
+                            copy_image(
+                                src=image_location,
+                                dest=temp_file.name,
+                                width=min(image_width, AnthropicClient.MAX_IMAGE_DIMENSION),
+                                height=min(image_height, AnthropicClient.MAX_IMAGE_DIMENSION),
+                            )
+                            base64_image = encode_base64(temp_file.name, format="JPEG")
+                    else:
+                        base64_image = encode_base64(image_location, format="JPEG")
 
-                    base64_image: str = encode_base64(media_object.location, format="JPEG")
                     image_block: ImageBlockParam = {
                         "type": "image",
                         "source": {
@@ -302,7 +330,9 @@ class AnthropicMessagesClient(CachingClient):
                         "type": "text",
                         "text": media_object.text,
                     }
-                    blocks.append(text_block)
+                    # Anthropic does not support empty text blocks
+                    if media_object.text.strip():
+                        blocks.append(text_block)
             messages = [{"role": "user", "content": blocks}]
 
         else:
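
The resize path above leans on get_dimensions and copy_image from helm/common/images_utils.py (also touched in this release, +10 -3), whose bodies are not shown in this diff. As a rough sketch only, helpers with the signatures the Anthropic client calls could be written with Pillow along these lines; the actual HELM implementations may differ.

    # Hypothetical sketch of the image helpers called above; not the code shipped in images_utils.py.
    from typing import Tuple

    from PIL import Image


    def get_dimensions(image_location: str) -> Tuple[int, int]:
        """Return (width, height) of the image stored at the given path."""
        with Image.open(image_location) as image:
            return image.size


    def copy_image(src: str, dest: str, width: int, height: int) -> None:
        """Copy the image at src to dest, resizing it to width x height pixels."""
        with Image.open(src) as image:
            image.resize((width, height)).convert("RGB").save(dest)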

helm/clients/openai_client.py
@@ -130,9 +130,8 @@ class OpenAIClient(CachingClient):
                     from helm.common.images_utils import encode_base64
 
                     base64_image: str = encode_base64(media_object.location)
-                    content.append(
-                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
-                    )
+                    image_object: Dict[str, str] = {"url": f"data:image/jpeg;base64,{base64_image}"}
+                    content.append({"type": "image_url", "image_url": image_object})
                 elif media_object.is_type(TEXT_TYPE):
                     if media_object.text is None:
                         raise ValueError("MediaObject of text type has missing text field value")
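
The refactor above only names the intermediate dict; the payload the client sends is unchanged and follows the OpenAI chat vision content format. For orientation, a single user turn with one text part and one base64-encoded image ends up shaped roughly like this (values are placeholders):

    # Illustrative shape of the "content" list built above; the base64 payload is a placeholder.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,<BASE64_DATA>"}},
            ],
        }
    ]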

helm/clients/together_client.py
@@ -1,12 +1,20 @@
 from copy import deepcopy
-from typing import List, Dict, Any, Optional, Union
+from itertools import zip_longest
+from typing import List, Dict, Any, Optional, TypedDict, Union
 
 import requests
 from retrying import retry
 
 from helm.common.cache import CacheConfig
+from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput, Token
-from .client import CachingClient, truncate_sequence, cleanup_str
+from helm.clients.client import CachingClient, truncate_sequence, cleanup_str
+
+try:
+    from together import Together
+    from together.types import ChatCompletionResponse
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["together"])
 
 
 class _RewriteRequestTags:
@@ -272,3 +280,86 @@ class TogetherClient(CachingClient):
             completions=completions,
             embedding=[],
         )
+
+
+class TogetherRawChatRequest(TypedDict):
+    messages: List[Dict[str, str]]
+    model: str
+    max_tokens: int
+    stop: List[str]
+    temperature: float
+    top_p: float
+    top_k: int
+    logprobs: int
+    echo: bool
+    n: int
+
+
+def convert_to_raw_chat_request(request: Request) -> TogetherRawChatRequest:
+    if request.messages:
+        messages = request.messages
+    else:
+        messages = [{"role": "user", "content": request.prompt}]
+    return {
+        "messages": messages,
+        "model": request.model,
+        "max_tokens": request.max_tokens,
+        "stop": request.stop_sequences,
+        "temperature": request.temperature,
+        "top_p": request.top_p,
+        "top_k": request.top_k_per_token,
+        "logprobs": min(request.top_k_per_token, 1),
+        "echo": request.echo_prompt,
+        "n": request.num_completions,
+    }
+
+
+class TogetherChatClient(CachingClient):
+    """Client that uses the Python Together library for chat models."""
+
+    def __init__(self, cache_config: CacheConfig, api_key: str, together_model: Optional[str] = None):
+        super().__init__(cache_config=cache_config)
+        self._client = Together(api_key=api_key)
+
+    def make_request(self, request: Request) -> RequestResult:
+        raw_request = convert_to_raw_chat_request(request)
+        cache_key = CachingClient.make_cache_key(raw_request, request)
+
+        def do_it() -> Dict[Any, Any]:
+            response = self._client.chat.completions.create(**raw_request)
+            return response.model_dump(mode="json")
+
+        try:
+            raw_response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+            response = ChatCompletionResponse.model_validate(raw_response)
+        except Exception as error:
+            return RequestResult(
+                success=False,
+                cached=False,
+                error=str(error),
+                completions=[],
+                embedding=[],
+            )
+
+        generated_outputs: List[GeneratedOutput] = []
+        for choice in response.choices:
+            # NOTE: Together always returns None for choice.finish_reason
+            # NOTE: Together does not return logprobs for the whole generated output, only for individual tokens
+            tokens: List[Token] = []
+            if choice.logprobs:
+                for token_text, token_logprob in zip_longest(
+                    choice.logprobs.tokens or [], choice.logprobs.token_logprobs or []
+                ):
+                    if token_text is None:
+                        break
+                    tokens.append(Token(text=token_text, logprob=token_logprob or 0.0))
+            assert choice.message.role == "assistant"
+            generated_outputs.append(GeneratedOutput(text=choice.message.content, logprob=0.0, tokens=tokens))
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=raw_response["request_time"],
+            request_datetime=raw_response["request_datetime"],
+            completions=generated_outputs,
+            embedding=[],
+        )

helm/clients/vertexai_client.py
@@ -4,7 +4,6 @@ from threading import Lock
 from typing import Any, Dict, Optional, List, Union
 
 from helm.common.cache import CacheConfig
-from helm.common.hierarchical_logger import hlog
 from helm.common.media_object import TEXT_TYPE
 from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput, ErrorFlags
@@ -131,12 +130,6 @@ class VertexAITextClient(VertexAIClient):
 class VertexAIChatClient(VertexAIClient):
     """Client for Vertex AI chat models (e.g., Gemini). Supports multimodal prompts."""
 
-    # Set the finish reason to this if the prompt violates the content policy
-    CONTENT_POLICY_VIOLATED_FINISH_REASON: str = "The prompt violates Google's content policy."
-
-    # Gemini returns this error for certain valid requests
-    CONTENT_HAS_NO_PARTS_ERROR: str = "Content has no parts."
-
     # Enum taken from:
     # https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1beta1#google.cloud.aiplatform.v1beta1.Candidate.FinishReason
     # We don't directly import this enum because it can differ between different Vertex AI library versions.
@@ -149,7 +142,7 @@ class VertexAIChatClient(VertexAIClient):
     ]
 
     @staticmethod
-    def get_model(model_name: str) -> Any:
+    def get_model(model_name: str) -> GenerativeModel:
         global _models_lock
         global _models
 
@@ -202,21 +195,22 @@ class VertexAIChatClient(VertexAIClient):
             )
             candidates: List[Candidate] = response.candidates
 
-            # Depending on the version of the Vertex AI library and the type of content blocking,
-            # content blocking can show up in many ways, so this defensively handles most of these ways
+            # Depending on the version of the Vertex AI library and the type of prompt blocking,
+            # prompt blocking can show up in many ways, so this defensively handles most of these ways
+            if response.prompt_feedback.block_reason:
+                raise VertexAIContentBlockedError(
+                    f"Prompt blocked with reason: {response.prompt_feedback.block_reason}"
+                )
             if not candidates:
-                raise VertexAIContentBlockedError("No candidates in response due to content blocking")
+                raise VertexAIContentBlockedError(f"No candidates in response: {response}")
             predictions: List[Dict[str, Any]] = []
             for candidate in candidates:
-                if (
-                    candidate.finish_reason in VertexAIChatClient.CONTENT_BLOCKED_FINISH_REASONS
-                    or not candidate.content.parts
-                ):
-                    # The prediction was either blocked due to safety settings or the model stopped and returned
-                    # nothing (which also happens when the model is blocked).
-                    # For now, we don't cache blocked requests, because we are trying to get the
-                    # content blocking removed.
-                    raise VertexAIContentBlockedError("Content has no parts due to content blocking")
+                # Depending on the version of the Vertex AI library and the type of prompt blocking,
+                # content blocking can show up in many ways, so this defensively handles most of these ways
+                if candidate.finish_reason in VertexAIChatClient.CONTENT_BLOCKED_FINISH_REASONS:
+                    raise VertexAIContentBlockedError(f"Content blocked with reason: {candidate.finish_reason}")
+                if not candidate.content.parts:
+                    raise VertexAIContentBlockedError(f"No parts in candidate: {candidate}")
                 predictions.append({"text": candidate.content.text})
             # TODO: Extract more information from the response
             return {"predictions": predictions}
@@ -234,11 +228,11 @@ class VertexAIChatClient(VertexAIClient):
             )
 
             response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
-        except VertexAIContentBlockedError:
+        except VertexAIContentBlockedError as e:
             return RequestResult(
                 success=False,
                 cached=False,
-                error="Response was empty due to content moderation filter",
+                error=f"Content blocked: {str(e)}",
                 completions=[],
                 embedding=[],
                 error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
@@ -252,7 +246,7 @@ class VertexAIChatClient(VertexAIClient):
             return RequestResult(
                 success=False,
                 cached=False,
-                error="Response was empty due to content moderation filter",
+                error=f"Content blocked error in cached response: {str(response)}",
                 completions=[],
                 embedding=[],
                 error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
@@ -266,7 +260,7 @@ class VertexAIChatClient(VertexAIClient):
                 return RequestResult(
                     success=False,
                     cached=False,
-                    error="Response was empty due to content moderation filter",
+                    error=f"Content blocked error in cached prediction: {str(prediction)}",
                     completions=[],
                     embedding=[],
                     error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
@@ -291,21 +285,6 @@ class VertexAIChatClient(VertexAIClient):
         )
 
     def _make_multimodal_request(self, request: Request) -> RequestResult:
-        def complete_for_valid_error(error_message: str) -> RequestResult:
-            empty_completion = GeneratedOutput(
-                text="",
-                logprob=0,
-                tokens=[],
-                finish_reason={"reason": error_message},
-            )
-            return RequestResult(
-                success=True,
-                cached=False,
-                request_time=0,
-                completions=[empty_completion] * request.num_completions,
-                embedding=[],
-            )
-
         # Contents can either be text or a list of multimodal content made up of text, images or other content
         contents: Union[str, List[Union[str, Any]]] = request.prompt
         # Used to generate a unique cache key for this specific request
@@ -346,14 +325,29 @@ class VertexAIChatClient(VertexAIClient):
             try:
 
                 def do_it() -> Dict[str, Any]:
-                    raw_response = model.generate_content(
+                    response: GenerationResponse = model.generate_content(
                        contents, generation_config=parameters, safety_settings=self.safety_settings
                     )
-                    if raw_response._raw_response.prompt_feedback.block_reason != 0:
-                        hlog(f"Content blocked for prompt: {request.multimodal_prompt}")
-                        return {"error": self.CONTENT_POLICY_VIOLATED_FINISH_REASON}
-
-                    return {"predictions": [{"text": raw_response.candidates[0].text}]}
+                    # Depending on the version of the Vertex AI library and the type of prompt blocking,
+                    # prompt blocking can show up in many ways, so this defensively handles most of these ways
+                    if response.prompt_feedback.block_reason:
+                        raise VertexAIContentBlockedError(
+                            f"Prompt blocked with reason: {response.prompt_feedback.block_reason}"
+                        )
+                    if not response.candidates:
+                        raise VertexAIContentBlockedError(f"No candidates in response: {response}")
+                    # We should only have one candidate
+                    assert (
+                        len(response.candidates) == 1
+                    ), f"Expected 1 candidate since candidate_count is 1, got {len(response.candidates)}."
+                    candidate = response.candidates[0]
+                    # Depending on the version of the Vertex AI library and the type of prompt blocking,
+                    # content blocking can show up in many ways, so this defensively handles most of these ways
+                    if candidate.finish_reason in VertexAIChatClient.CONTENT_BLOCKED_FINISH_REASONS:
+                        raise VertexAIContentBlockedError(f"Content blocked with reason: {candidate.finish_reason}")
+                    if not candidate.content.parts:
+                        raise VertexAIContentBlockedError(f"No parts in candidate: {candidate}")
+                    return {"predictions": [{"text": candidate.text}]}
 
                 raw_cache_key = {"model_name": model_name, "prompt": prompt_key, **parameters}
                 if completion_index > 0:
@@ -361,15 +355,30 @@ class VertexAIChatClient(VertexAIClient):
 
                 cache_key = CachingClient.make_cache_key(raw_cache_key, request)
                 response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
-            except (requests.exceptions.RequestException, ValueError) as e:
-                if str(e) == self.CONTENT_HAS_NO_PARTS_ERROR:
-                    return complete_for_valid_error(self.CONTENT_HAS_NO_PARTS_ERROR)
-
+            except requests.exceptions.RequestException as e:
                 error: str = f"Gemini Vision error: {e}"
                 return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+            except VertexAIContentBlockedError as e:
+                return RequestResult(
+                    success=False,
+                    cached=False,
+                    error=f"Content blocked: {str(e)}",
+                    completions=[],
+                    embedding=[],
+                    error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
+                )
 
             if "error" in response:
-                return complete_for_valid_error(response["error"])
+                return RequestResult(
+                    success=False,
+                    cached=True,
+                    error=f"Content blocked error in cached response: {str(response)}",
+                    completions=[],
+                    embedding=[],
+                    error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
+                    request_time=response["request_time"],
+                    request_datetime=response["request_datetime"],
+                )
 
             response_text = response["predictions"][0]["text"]
             completion = GeneratedOutput(text=response_text, logprob=0, tokens=[])
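
All of the Vertex AI changes funnel blocked prompts and blocked candidates into a single exception type, VertexAIContentBlockedError, which the callers then turn into a non-retriable, non-fatal RequestResult. The diff applies the same checks in two places; the standalone helper below is only an illustration of that shared logic, not code from the release:

    # Condensed illustration of the defensive checks added above.
    # `response` stands in for a Vertex AI GenerationResponse object.
    def raise_if_blocked(response) -> None:
        if response.prompt_feedback.block_reason:
            raise VertexAIContentBlockedError(
                f"Prompt blocked with reason: {response.prompt_feedback.block_reason}"
            )
        if not response.candidates:
            raise VertexAIContentBlockedError(f"No candidates in response: {response}")
        for candidate in response.candidates:
            if candidate.finish_reason in VertexAIChatClient.CONTENT_BLOCKED_FINISH_REASONS:
                raise VertexAIContentBlockedError(f"Content blocked with reason: {candidate.finish_reason}")
            if not candidate.content.parts:
                raise VertexAIContentBlockedError(f"No parts in candidate: {candidate}")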

helm/clients/vision_language/huggingface_vision2seq_client.py
@@ -0,0 +1,145 @@
+from threading import Lock
+from typing import Any, Dict, List, Optional
+
+from dataclasses import dataclass
+from transformers import AutoProcessor, AutoModelForVision2Seq
+from transformers.image_utils import load_image
+import torch
+
+from helm.common.cache import CacheConfig
+from helm.common.gpu_utils import get_torch_device_name, is_cuda_available
+from helm.common.hierarchical_logger import hlog, htrack_block
+from helm.common.media_object import TEXT_TYPE
+from helm.common.request import Request, RequestResult, GeneratedOutput, Token
+from helm.common.request import wrap_request_time
+from helm.common.tokenization_request import TokenizationRequest
+from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt
+from helm.tokenizers.tokenizer import Tokenizer
+
+
+@dataclass(frozen=True)
+class Vision2SeqModelProcessor:
+    """Loaded model and processor."""
+
+    model: AutoModelForVision2Seq
+    processor: AutoProcessor
+
+
+_models_lock: Lock = Lock()
+_models: Dict[str, Optional[Vision2SeqModelProcessor]] = {
+    "HuggingFaceM4/idefics2-8b": None,
+}
+
+
+class HuggingFaceVision2SeqClient(CachingClient):
+    """
+    Models for Vision2Seq models from HuggingFace.
+    """
+
+    ASSISTANT_PREFIX: str = "Assistant:"
+
+    def __init__(self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig):
+        super().__init__(cache_config=cache_config)
+        self.tokenizer = tokenizer
+        self.tokenizer_name = tokenizer_name
+        self._device: str = get_torch_device_name()
+
+    def _get_model(self, checkpoint: str) -> Vision2SeqModelProcessor:
+        global _models_lock
+        global _models
+
+        # Ensure that only one thread is loading the model at a time
+        with _models_lock:
+            loaded_model_processor = _models[checkpoint]
+            if loaded_model_processor is None:
+                hlog(f"Loading model {checkpoint} and caching in memory...")
+                torch_dtype: torch.dtype = torch.float16 if is_cuda_available() else torch.float32
+                model = AutoModelForVision2Seq.from_pretrained(checkpoint, torch_dtype=torch_dtype).to(self._device)
+                processor = AutoProcessor.from_pretrained(checkpoint)
+
+                _models[checkpoint] = Vision2SeqModelProcessor(model, processor)
+                loaded_model_processor = _models[checkpoint]
+
+        assert loaded_model_processor is not None
+        return loaded_model_processor
+
+    def make_request(self, request: Request) -> RequestResult:
+        assert request.model_deployment in _models, f"Not a valid model for this client: {request.model_deployment}"
+        assert request.multimodal_prompt is not None, "Multimodal prompt is required"
+
+        loaded_model_processor: Vision2SeqModelProcessor = self._get_model(request.model_deployment)
+        model = loaded_model_processor.model
+        processor = loaded_model_processor.processor
+
+        generation_args: Dict[str, Any] = {
+            "max_new_tokens": request.max_tokens,
+        }
+
+        image_paths: List[str] = []
+        multimodal_prompt: List[Dict[str, str]] = []
+        for media_object in request.multimodal_prompt.media_objects:
+            if media_object.is_type("image") and media_object.location:
+                image_paths.append(media_object.location)
+                multimodal_prompt.append({"type": "image"})
+            elif media_object.is_type(TEXT_TYPE):
+                if media_object.text is None:
+                    raise ValueError("MediaObject of text type has missing text field value")
+
+                multimodal_prompt.append({"type": "text", "text": media_object.text})
+            else:
+                raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
+
+        completions: List[GeneratedOutput] = []
+        with htrack_block(f"Generating for prompt: {request.multimodal_prompt.text}"):
+            try:
+
+                def do_it() -> Dict[str, Any]:
+                    messages = [{"role": "user", "content": multimodal_prompt}]
+                    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+                    inputs = processor(
+                        text=[prompt] * request.num_completions,
+                        images=[
+                            [load_image(image_path) for image_path in image_paths]
+                            for _ in range(request.num_completions)
+                        ],
+                        return_tensors="pt",
+                    )
+                    inputs = {k: v.to(self._device) for k, v in inputs.items()}
+
+                    # Generate
+                    generated_ids = model.generate(**inputs, **generation_args)
+                    generated_texts: List[str] = processor.batch_decode(generated_ids, skip_special_tokens=True)
+                    return {"output": generated_texts}
+
+                # Include the prompt and model name in the cache key
+                cache_key = CachingClient.make_cache_key(
+                    raw_request={
+                        "n": request.num_completions,
+                        "model": request.model,
+                        "prompt": generate_uid_for_multimodal_prompt(request.multimodal_prompt),
+                        **generation_args,
+                    },
+                    request=request,
+                )
+                result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+            except RuntimeError as model_error:
+                return RequestResult(success=False, cached=False, error=str(model_error), completions=[], embedding=[])
+
+        for text in result["output"]:
+            hlog(f"Generated text: {text}")
+            assert self.ASSISTANT_PREFIX in text, f"Expected {self.ASSISTANT_PREFIX} in the output"
+            text = text.rpartition(self.ASSISTANT_PREFIX)[-1]
+            hlog(f"Truncated: {text}")
+            tokenization_result = self.tokenizer.tokenize(
+                TokenizationRequest(text, tokenizer=self.tokenizer_name, encode=False)
+            )
+            tokens: List[Token] = [Token(text=str(text), logprob=0) for text in tokenization_result.raw_tokens]
+            completions.append(GeneratedOutput(text=text, logprob=0, tokens=tokens))
+
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=result["request_time"],
+            completions=completions,
+            embedding=[],
+        )

helm/clients/vision_language/huggingface_vlm_client.py
@@ -25,7 +25,7 @@ except ModuleNotFoundError as e:
 
 class HuggingFaceVLMClient(CachingClient):
     """
-    General CLient for VLM models from HuggingFace.
+    General client for VLM models from HuggingFace.
     """
 
     _models_lock: Lock = Lock()
@@ -34,6 +34,10 @@ class HuggingFaceVLMClient(CachingClient):
         "huggingface/llava-1.5-7b-hf": "llava-hf/llava-1.5-7b-hf",
         "huggingface/llava-1.5-13b-hf": "llava-hf/llava-1.5-13b-hf",
         "huggingface/bakLlava-v1-hf": "llava-hf/bakLlava-v1-hf",
+        "huggingface/llava-v1.6-vicuna-7b-hf": "llava-hf/llava-v1.6-vicuna-7b-hf",
+        "huggingface/llava-v1.6-vicuna-13b-hf": "llava-hf/llava-v1.6-vicuna-13b-hf",
+        "huggingface/llava-v1.6-mistral-7b-hf": "llava-hf/llava-v1.6-mistral-7b-hf",
+        "huggingface/llava-v1.6-34b-hf": "llava-hf/llava-v1.6-34b-hf",
     }
 
     def __init__(self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig):
@@ -45,7 +49,7 @@ class HuggingFaceVLMClient(CachingClient):
         with self._models_lock:
             model_id: str = self._models_aliases.get(model_name, model_name)
             if model_id not in self._models:
-                self._models[model_id] = pipeline("image-to-text", model=model_id)
+                self._models[model_id] = pipeline("image-to-text", model=model_id, device_map="auto")
             return self._models[model_id]
 
     def make_request(self, request: Request) -> RequestResult:
@@ -90,11 +94,14 @@ class HuggingFaceVLMClient(CachingClient):
         except RuntimeError as e:
             return RequestResult(success=False, cached=False, error=str(e), completions=[], embedding=[])
 
+        output: str = result["generated_text"]
+        if "ASSISTANT: " in output:
+            output = output.split("ASSISTANT: ")[1]
         tokenization_result: TokenizationRequestResult = self.tokenizer.tokenize(
-            TokenizationRequest(result["generated_text"], tokenizer=self.tokenizer_name)
+            TokenizationRequest(output, tokenizer=self.tokenizer_name)
        )
         tokens: List[Token] = [Token(text=str(text), logprob=0) for text in tokenization_result.raw_tokens]
-        completions: List[GeneratedOutput] = [GeneratedOutput(text=result["generated_text"], logprob=0, tokens=tokens)]
+        completions: List[GeneratedOutput] = [GeneratedOutput(text=output, logprob=0, tokens=tokens)]
         return RequestResult(
             success=True,
             cached=cached,
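
The new post-processing strips the echoed chat prompt from the pipeline output before tokenization. A small self-contained illustration of the string handling (the output text is made up):

    # Made-up example of an image-to-text pipeline output for a LLaVA-style chat prompt.
    output = "USER: <image>\nWhat animal is this? ASSISTANT: A dog playing in the snow."
    if "ASSISTANT: " in output:
        output = output.split("ASSISTANT: ")[1]
    assert output == "A dog playing in the snow."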

helm/clients/vision_language/idefics_client.py
@@ -88,7 +88,7 @@ class IDEFICSClient(CachingClient):
 
         input_args: Dict[str, Union[str, bool]] = {"return_tensors": "pt"}
         generation_args = {
-            "max_length": request.max_tokens,
+            "max_new_tokens": request.max_tokens,
             "bad_words_ids": processor.tokenizer(self.BAD_WORD_TOKENS, add_special_tokens=False).input_ids,
         }
 
@@ -140,7 +140,7 @@ class IDEFICSClient(CachingClient):
 
         # Truncate the output text as IDEFICS outputs the entire sequence including the prompt
         if "instruct" in request.model:
-            assert self.ASSISTANT_PREFIX in text, f"Expected {self.ASSISTANT_PREFIX} in the output"
+            assert self.ASSISTANT_PREFIX in text, f"Expected {self.ASSISTANT_PREFIX} in the output: {text}"
            text = text.rpartition(self.ASSISTANT_PREFIX)[-1]
        else:
            # Best we can do is to remove the text portion of the prompt from the output
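
The switch from max_length to max_new_tokens matters because in transformers, max_length caps the total sequence (prompt plus completion) while max_new_tokens caps only the newly generated tokens; with a long multimodal prompt, max_length=request.max_tokens could leave no room to generate anything. A small standalone illustration using a text-only model (chosen so it runs without IDEFICS):

    # Sketch of the generate() semantics; gpt2 is used purely as a lightweight stand-in.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer("The quick brown fox", return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[1]

    # max_length counts the prompt too: with a 4-token prompt, at most 1 new token is produced here.
    capped_total = model.generate(**inputs, max_length=5)

    # max_new_tokens counts only the completion: up to 5 new tokens regardless of prompt length.
    capped_new = model.generate(**inputs, max_new_tokens=5)
    assert capped_new.shape[1] <= prompt_len + 5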