crfm-helm 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of crfm-helm might be problematic.

Files changed (125)
  1. {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/METADATA +19 -5
  2. {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/RECORD +121 -76
  3. helm/benchmark/adaptation/adapter_spec.py +32 -31
  4. helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
  5. helm/benchmark/adaptation/adapters/multimodal/multimodal_prompt.py +7 -0
  6. helm/benchmark/adaptation/adapters/multimodal/test_multimodal_prompt.py +2 -0
  7. helm/benchmark/annotation/air_bench_annotator.py +64 -0
  8. helm/benchmark/annotation/annotator_factory.py +6 -0
  9. helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +1 -1
  10. helm/benchmark/annotation/live_qa_annotator.py +84 -0
  11. helm/benchmark/annotation/medication_qa_annotator.py +81 -0
  12. helm/benchmark/augmentations/perturbation.py +17 -1
  13. helm/benchmark/augmentations/test_perturbation.py +30 -0
  14. helm/benchmark/augmentations/translate_perturbation.py +1 -0
  15. helm/benchmark/huggingface_registration.py +16 -6
  16. helm/benchmark/metrics/air_bench_metrics.py +56 -0
  17. helm/benchmark/metrics/efficiency_metrics.py +9 -2
  18. helm/benchmark/metrics/evaluate_reference_metrics.py +16 -0
  19. helm/benchmark/metrics/fin_qa_metrics.py +60 -0
  20. helm/benchmark/metrics/fin_qa_metrics_helper.py +398 -0
  21. helm/benchmark/metrics/gpt4v_originality_critique_metrics.py +126 -0
  22. helm/benchmark/metrics/instruction_following_critique_metrics.py +1 -0
  23. helm/benchmark/metrics/live_qa_metrics.py +23 -0
  24. helm/benchmark/metrics/medication_qa_metrics.py +23 -0
  25. helm/benchmark/metrics/prometheus_vision_critique_metrics.py +185 -0
  26. helm/benchmark/metrics/reka_vibe_critique_metrics.py +158 -0
  27. helm/benchmark/metrics/unitxt_metrics.py +20 -10
  28. helm/benchmark/metrics/vision_language/emd_utils.py +4 -0
  29. helm/benchmark/metrics/vision_language/image_metrics.py +104 -21
  30. helm/benchmark/model_metadata_registry.py +5 -1
  31. helm/benchmark/presentation/schema.py +54 -4
  32. helm/benchmark/presentation/test_schema.py +11 -0
  33. helm/benchmark/run.py +16 -2
  34. helm/benchmark/run_expander.py +112 -63
  35. helm/benchmark/run_spec_factory.py +15 -10
  36. helm/benchmark/run_specs/air_bench_run_specs.py +40 -0
  37. helm/benchmark/run_specs/classic_run_specs.py +15 -11
  38. helm/benchmark/run_specs/decodingtrust_run_specs.py +3 -1
  39. helm/benchmark/run_specs/experimental_run_specs.py +33 -0
  40. helm/benchmark/run_specs/finance_run_specs.py +33 -0
  41. helm/benchmark/run_specs/vlm_run_specs.py +444 -65
  42. helm/benchmark/scenarios/air_bench_scenario.py +50 -0
  43. helm/benchmark/scenarios/ci_mcqa_scenario.py +80 -0
  44. helm/benchmark/scenarios/entity_data_imputation_scenario.py +8 -2
  45. helm/benchmark/scenarios/fin_qa_scenario.py +117 -0
  46. helm/benchmark/scenarios/legalbench_scenario.py +6 -2
  47. helm/benchmark/scenarios/math_scenario.py +1 -1
  48. helm/benchmark/scenarios/test_air_bench_scenario.py +27 -0
  49. helm/benchmark/scenarios/vision_language/a_okvqa_scenario.py +83 -0
  50. helm/benchmark/scenarios/vision_language/bingo_scenario.py +3 -3
  51. helm/benchmark/scenarios/vision_language/crossmodal_3600_scenario.py +134 -0
  52. helm/benchmark/scenarios/vision_language/flickr30k_scenario.py +74 -0
  53. helm/benchmark/scenarios/vision_language/gqa_scenario.py +91 -0
  54. helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +4 -2
  55. helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +13 -2
  56. helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +1 -5
  57. helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +1 -5
  58. helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +5 -3
  59. helm/benchmark/scenarios/vision_language/math_vista_scenario.py +117 -0
  60. helm/benchmark/scenarios/vision_language/mm_safety_bench_scenario.py +103 -0
  61. helm/benchmark/scenarios/vision_language/mscoco_captioning_scenario.py +92 -0
  62. helm/benchmark/scenarios/vision_language/mscoco_categorization_scenario.py +117 -0
  63. helm/benchmark/scenarios/vision_language/originality_scenario.py +35 -0
  64. helm/benchmark/scenarios/vision_language/pairs_scenario.py +247 -0
  65. helm/benchmark/scenarios/vision_language/unicorn_scenario.py +3 -3
  66. helm/benchmark/scenarios/vision_language/vibe_eval_scenario.py +95 -0
  67. helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +2 -2
  68. helm/benchmark/scenarios/vision_language/vqa_scenario.py +4 -2
  69. helm/benchmark/static/schema_air_bench.yaml +3149 -0
  70. helm/benchmark/static/schema_classic.yaml +3 -59
  71. helm/benchmark/static/schema_finance.yaml +143 -0
  72. helm/benchmark/static/schema_image2structure.yaml +447 -0
  73. helm/benchmark/static/schema_instruction_following.yaml +3 -52
  74. helm/benchmark/static/schema_lite.yaml +3 -61
  75. helm/benchmark/static/schema_medical.yaml +255 -0
  76. helm/benchmark/static/schema_mmlu.yaml +3 -61
  77. helm/benchmark/static/schema_tables.yaml +200 -0
  78. helm/benchmark/static/schema_thai.yaml +223 -0
  79. helm/benchmark/static/schema_unitxt.yaml +3 -61
  80. helm/benchmark/static/schema_vhelm.yaml +824 -0
  81. helm/benchmark/static/schema_vhelm_lite.yaml +109 -0
  82. helm/benchmark/static_build/assets/air-overview-d2e6c49f.png +0 -0
  83. helm/benchmark/static_build/assets/index-30dbceba.js +10 -0
  84. helm/benchmark/static_build/assets/index-66b02d40.css +1 -0
  85. helm/benchmark/static_build/assets/overview-74aea3d8.png +0 -0
  86. helm/benchmark/static_build/assets/process-flow-bd2eba96.png +0 -0
  87. helm/benchmark/static_build/index.html +2 -2
  88. helm/clients/anthropic_client.py +78 -14
  89. helm/clients/auto_client.py +11 -0
  90. helm/clients/client.py +24 -7
  91. helm/clients/cohere_client.py +98 -3
  92. helm/clients/huggingface_client.py +71 -12
  93. helm/clients/openai_client.py +11 -5
  94. helm/clients/reka_client.py +189 -0
  95. helm/clients/test_client.py +3 -3
  96. helm/clients/test_huggingface_client.py +19 -3
  97. helm/clients/test_together_client.py +72 -2
  98. helm/clients/together_client.py +199 -2
  99. helm/clients/vertexai_client.py +117 -64
  100. helm/clients/vision_language/huggingface_vision2seq_client.py +145 -0
  101. helm/clients/vision_language/huggingface_vlm_client.py +12 -4
  102. helm/clients/vision_language/idefics_client.py +2 -2
  103. helm/clients/vision_language/paligemma_client.py +146 -0
  104. helm/clients/vision_language/palmyra_vision_client.py +84 -0
  105. helm/clients/yi_client.py +31 -0
  106. helm/common/critique_request.py +10 -1
  107. helm/common/images_utils.py +29 -3
  108. helm/config/model_deployments.yaml +504 -12
  109. helm/config/model_metadata.yaml +579 -52
  110. helm/config/tokenizer_configs.yaml +100 -1
  111. helm/proxy/critique/model_critique_client.py +32 -4
  112. helm/proxy/services/server_service.py +1 -1
  113. helm/tokenizers/auto_tokenizer.py +1 -1
  114. helm/tokenizers/cohere_tokenizer.py +44 -2
  115. helm/tokenizers/huggingface_tokenizer.py +36 -13
  116. helm/tokenizers/test_cohere_tokenizer.py +39 -0
  117. helm/tokenizers/test_huggingface_tokenizer.py +5 -1
  118. helm/benchmark/static/schema_vlm.yaml +0 -576
  119. helm/benchmark/static_build/assets/index-5088afcb.css +0 -1
  120. helm/benchmark/static_build/assets/index-d839df55.js +0 -9
  121. helm/benchmark/test_model_deployment_definition.py +0 -90
  122. {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/LICENSE +0 -0
  123. {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/WHEEL +0 -0
  124. {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/entry_points.txt +0 -0
  125. {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/top_level.txt +0 -0
helm/benchmark/static_build/index.html CHANGED
@@ -7,11 +7,11 @@
  <title>Holistic Evaluation of Language Models (HELM)</title>
  <meta name="description" content="The Holistic Evaluation of Language Models (HELM) serves as a living benchmark for transparency in language models. Providing broad coverage and recognizing incompleteness, multi-metric measurements, and standardization. All data and analysis are freely accessible on the website for exploration and study." />
  <script type="text/javascript" src="./config.js"></script>
- <script type="module" crossorigin src="./assets/index-d839df55.js"></script>
+ <script type="module" crossorigin src="./assets/index-30dbceba.js"></script>
  <link rel="modulepreload" crossorigin href="./assets/react-d4a0b69b.js">
  <link rel="modulepreload" crossorigin href="./assets/recharts-6d337683.js">
  <link rel="modulepreload" crossorigin href="./assets/tremor-54a99cc4.js">
- <link rel="stylesheet" href="./assets/index-5088afcb.css">
+ <link rel="stylesheet" href="./assets/index-66b02d40.css">
  </head>
  <body class="block">
  <div id="root"></div>
helm/clients/anthropic_client.py CHANGED
@@ -1,6 +1,8 @@
  from typing import Any, Dict, List, Optional, TypedDict, Union, cast
  import json
+ import os
  import requests
+ import tempfile
  import time
  import urllib.parse

@@ -68,6 +70,9 @@ class AnthropicClient(CachingClient):
  MAX_COMPLETION_LENGTH: int = (
  8192 # See https://docs.google.com/document/d/1vX6xgoA-KEKxqtMlBVAqYvE8KUfZ7ABCjTxAjf1T5kI/edit#
  )
+ # An Anthropic error message: "At least one of the image dimensions exceed max allowed size: 8000 pixels"
+ MAX_IMAGE_DIMENSION: int = 8000
+
  ADDITIONAL_TOKENS: int = 5
  PROMPT_ANSWER_START: str = "The answer is "

@@ -206,7 +211,7 @@


  def _is_content_moderation_failure(response: Dict) -> bool:
- """Return whether a a response failed because of the content moderation filter."""
+ """Return whether a response failed because of the content moderation filter."""
  if (
  "error" in response
  and "message" in response["error"]
@@ -238,7 +243,9 @@ class AnthropicMessagesResponseError(Exception):

  class AnthropicMessagesClient(CachingClient):
  # Source: https://docs.anthropic.com/claude/docs/models-overview
- MAX_OUTPUT_TOKENS = 4096
+ MAX_OUTPUT_TOKENS: int = 4096
+
+ MAX_IMAGE_SIZE_BYTES: int = 5242880 # 5MB

  def __init__(
  self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig, api_key: Optional[str] = None
@@ -273,7 +280,7 @@ class AnthropicMessagesClient(CachingClient):
  # TODO(#2439): Refactor out Request validation
  if request.messages is not None or request.prompt:
  raise AnthropicMessagesRequestError(
- "Exactly one of Request.messages, Request.prompt or Request.multimodel_prompt should be set"
+ "Exactly one of Request.messages, Request.prompt or Request.multimodal_prompt should be set"
  )
  blocks: List[Union[TextBlockParam, ImageBlockParam]] = []
  for media_object in request.multimodal_prompt.media_objects:
@@ -282,9 +289,53 @@
  if not media_object.location:
  raise Exception("MediaObject of image type has missing location field value")

- from helm.common.images_utils import encode_base64
+ from helm.common.images_utils import (
+ encode_base64,
+ get_dimensions,
+ copy_image,
+ resize_image_to_max_file_size,
+ )
+
+ image_location: str = media_object.location
+ base64_image: str
+
+ image_width, image_height = get_dimensions(media_object.location)
+ if (
+ image_width > AnthropicClient.MAX_IMAGE_DIMENSION
+ or image_height > AnthropicClient.MAX_IMAGE_DIMENSION
+ ):
+ hlog(
+ f"WARNING: Image {image_location} exceeds max allowed size: "
+ f"{AnthropicClient.MAX_IMAGE_DIMENSION} pixels"
+ )
+ # Save the resized image to a temporary file
+ with tempfile.NamedTemporaryFile(suffix=".jpg") as temp_file:
+ hlog(f"Resizing image to temporary path: {temp_file.name}")
+ copy_image(
+ src=image_location,
+ dest=temp_file.name,
+ width=min(image_width, AnthropicClient.MAX_IMAGE_DIMENSION),
+ height=min(image_height, AnthropicClient.MAX_IMAGE_DIMENSION),
+ )
+ base64_image = encode_base64(temp_file.name, format="JPEG")
+
+ elif os.path.getsize(image_location) > AnthropicMessagesClient.MAX_IMAGE_SIZE_BYTES:
+ hlog(
+ f"WARNING: Image {image_location} exceeds max allowed size: "
+ f"{AnthropicMessagesClient.MAX_IMAGE_SIZE_BYTES} bytes"
+ )
+ # Resize the image so it is smaller than the max allowed size
+ with tempfile.NamedTemporaryFile(suffix=".jpg") as temp_file:
+ hlog(f"Resizing image to temporary path: {temp_file.name}")
+ resize_image_to_max_file_size(
+ src=image_location,
+ dest=temp_file.name,
+ max_size_in_bytes=AnthropicMessagesClient.MAX_IMAGE_SIZE_BYTES,
+ )
+ base64_image = encode_base64(temp_file.name, format="JPEG")
+ else:
+ base64_image = encode_base64(image_location, format="JPEG")

- base64_image: str = encode_base64(media_object.location, format="JPEG")
  image_block: ImageBlockParam = {
  "type": "image",
  "source": {
@@ -302,7 +353,9 @@
  "type": "text",
  "text": media_object.text,
  }
- blocks.append(text_block)
+ # Anthropic does not support empty text blocks
+ if media_object.text.strip():
+ blocks.append(text_block)
  messages = [{"role": "user", "content": blocks}]

  else:
@@ -338,14 +391,25 @@
  return response
  raise

- cache_key = CachingClient.make_cache_key(
- {
- "completion_index": completion_index,
- **raw_request,
- },
- request,
- )
- response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+ try:
+ cache_key = CachingClient.make_cache_key(
+ {
+ "completion_index": completion_index,
+ **raw_request,
+ },
+ request,
+ )
+ response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+ except AnthropicMessagesResponseError:
+ hlog("WARNING: Response has empty content")
+ return RequestResult(
+ success=False,
+ cached=False,
+ error="Anthropic response has empty content",
+ completions=[],
+ embedding=[],
+ error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
+ )

  if _is_content_moderation_failure(response):
  hlog(
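Note on the image handling above: get_dimensions, copy_image, and resize_image_to_max_file_size come from helm/common/images_utils.py, which also changed in this release (+29 -3) but is not shown here. The sketch below is only an assumption of how a size-capped JPEG resize could behave, not the actual HELM helper.

import os
from PIL import Image

def resize_to_max_file_size(src: str, dest: str, max_size_in_bytes: int) -> None:
    """Downscale an image repeatedly until the saved JPEG fits under max_size_in_bytes."""
    image = Image.open(src).convert("RGB")
    scale = 1.0
    while True:
        width = max(int(image.width * scale), 1)
        height = max(int(image.height * scale), 1)
        image.resize((width, height)).save(dest, format="JPEG")
        if os.path.getsize(dest) <= max_size_in_bytes:
            return
        scale *= 0.9  # shrink by 10% per attempt and re-check the file size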
helm/clients/auto_client.py CHANGED
@@ -5,6 +5,7 @@ from typing import Any, Dict, Mapping, Optional
  from retrying import Attempt, RetryError

  from helm.benchmark.model_deployment_registry import ModelDeployment, get_model_deployment
+ from helm.benchmark.tokenizer_config_registry import get_tokenizer_config
  from helm.common.file_caches.file_cache import FileCache
  from helm.common.file_caches.local_file_cache import LocalFileCache
  from helm.common.credentials_utils import provide_api_key
@@ -88,6 +89,10 @@ class AutoClient(Client):
  "location": lambda: self.credentials.get(host_organization + "Location", None), # VertexAI
  "hf_auth_token": lambda: self.credentials.get("huggingfaceAuthToken", None), # HuggingFace
  "file_cache": lambda: self._get_file_cache(host_organization), # Text-to-image models
+ "endpoint": lambda: self.credentials.get(host_organization + "Endpoint", None), # Palmyra
+ "end_of_text_token": lambda: self._get_end_of_text_token(
+ tokenizer_name=model_deployment.tokenizer_name or model_deployment.name
+ ),
  },
  )
  client = create_object(client_spec)
@@ -213,3 +218,9 @@
  # Initialize `FileCache` for text-to-image model APIs
  local_file_cache_path: str = os.path.join(self.file_storage_path, "output", host_organization)
  return LocalFileCache(local_file_cache_path, file_extension="png")
+
+ def _get_end_of_text_token(self, tokenizer_name: str) -> Optional[str]:
+ tokenizer_config = get_tokenizer_config(tokenizer_name)
+ if tokenizer_config is None:
+ raise ValueError(f"Could not find tokenizer_config for tokenizer {tokenizer_name}")
+ return tokenizer_config.end_of_text_token
helm/clients/client.py CHANGED
@@ -39,13 +39,17 @@ class CachingClient(Client):
  """
  if request.random is not None:
  assert "random" not in raw_request
- cache_key: Mapping = {**raw_request, "random": request.random}
+ return {**raw_request, "random": request.random}
  else:
- cache_key = raw_request
- return cache_key
+ return {**raw_request}


- def truncate_sequence(sequence: GeneratedOutput, request: Request, print_warning: bool = True) -> GeneratedOutput:
+ def truncate_sequence(
+ sequence: GeneratedOutput,
+ request: Request,
+ end_of_text_token: Optional[str] = None,
+ print_warning: bool = True,
+ ) -> GeneratedOutput:
  """
  Certain providers have bugs where they aren't respecting max_tokens,
  stop_sequences and the end of text token, so as a hack, we have to manually
@@ -64,7 +68,11 @@ def truncate_sequence(sequence: GeneratedOutput, request: Request, print_warning
  hlog("WARNING: don't know how to handle echo_prompt and max_tokens > 0, not truncating")
  return sequence

- for stop in request.stop_sequences:
+ if end_of_text_token:
+ stop_sequences = request.stop_sequences + [end_of_text_token]
+ else:
+ stop_sequences = request.stop_sequences
+ for stop in stop_sequences:
  # Find `stop` in the text
  try:
  new_text = sequence.text[: sequence.text.index(stop)]
@@ -116,7 +124,12 @@


  def truncate_and_tokenize_response_text(
- text: str, request: Request, tokenizer: Tokenizer, tokenizer_name: str, original_finish_reason: str = "endoftext"
+ text: str,
+ request: Request,
+ tokenizer: Tokenizer,
+ tokenizer_name: str,
+ end_of_text_token: Optional[str] = None,
+ original_finish_reason: str = "endoftext",
  ) -> GeneratedOutput:
  """Truncate a string-only response to respect stop_sequences and max_tokens.

@@ -139,7 +152,11 @@
  if request.echo_prompt:
  raise Exception("truncate_and_tokenize_response_text() does not support requests with echo_prompt = True")

- for stop_sequence in request.stop_sequences:
+ if end_of_text_token:
+ stop_sequences = request.stop_sequences + [end_of_text_token]
+ else:
+ stop_sequences = request.stop_sequences
+ for stop_sequence in stop_sequences:
  try:
  text = text[: text.index(stop_sequence)]
  finish_reason = "stop"
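The new end_of_text_token parameter in both functions above is simply treated as one extra stop sequence during truncation. A minimal, self-contained sketch of that behavior (illustrative names, not the library functions):

from typing import List, Optional

def truncate_at_stops(text: str, stop_sequences: List[str], end_of_text_token: Optional[str] = None) -> str:
    # Append the tokenizer's end-of-text token as an additional stop string, as the diff does.
    if end_of_text_token:
        stop_sequences = stop_sequences + [end_of_text_token]
    for stop in stop_sequences:
        index = text.find(stop)
        if index != -1:
            text = text[:index]  # keep only the text before the stop string
    return text

assert truncate_at_stops("Hello<|endoftext|> world", [], "<|endoftext|>") == "Hello"

The token itself is looked up from the tokenizer configuration by AutoClient._get_end_of_text_token (see helm/clients/auto_client.py above) and injected into clients such as HuggingFaceClient.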
helm/clients/cohere_client.py CHANGED
@@ -1,8 +1,9 @@
  import json
  import requests
- from typing import List
+ from typing import List, Optional, Sequence, TypedDict

  from helm.common.cache import CacheConfig
+ from helm.common.optional_dependencies import handle_module_not_found_error
  from helm.common.request import (
  wrap_request_time,
  EMBEDDING_UNAVAILABLE_REQUEST_RESULT,
@@ -11,8 +12,13 @@ from helm.common.request import (
  GeneratedOutput,
  Token,
  )
- from .client import CachingClient, truncate_sequence
- from .cohere_utils import get_cohere_url, DEFAULT_COHERE_API_VERSION
+ from helm.clients.client import CachingClient, truncate_sequence
+ from helm.clients.cohere_utils import get_cohere_url, DEFAULT_COHERE_API_VERSION
+
+ try:
+ import cohere
+ except ModuleNotFoundError as e:
+ handle_module_not_found_error(e, ["cohere"])


  class CohereClient(CachingClient):
@@ -152,3 +158,92 @@
  completions=completions,
  embedding=[],
  )
+
+
+ class CohereRawChatRequest(TypedDict):
+ message: str
+ model: Optional[str]
+ preamble: Optional[str]
+ chat_history: Optional[Sequence[cohere.ChatMessage]]
+ temperature: Optional[float]
+ max_tokens: Optional[int]
+ k: Optional[int]
+ p: Optional[float]
+ seed: Optional[float]
+ stop_sequences: Optional[Sequence[str]]
+ frequency_penalty: Optional[float]
+ presence_penalty: Optional[float]
+
+
+ def convert_to_raw_chat_request(request: Request) -> CohereRawChatRequest:
+ # TODO: Support chat
+ model = request.model.replace("cohere/", "")
+ return {
+ "message": request.prompt,
+ "model": model,
+ "preamble": None,
+ "chat_history": None,
+ "temperature": request.temperature,
+ "max_tokens": request.max_tokens,
+ "k": request.top_k_per_token,
+ "p": request.top_p,
+ "stop_sequences": request.stop_sequences,
+ "seed": float(request.random) if request.random is not None else None,
+ "frequency_penalty": request.frequency_penalty,
+ "presence_penalty": request.presence_penalty,
+ }
+
+
+ class CohereChatClient(CachingClient):
+ """
+ Leverages the chat endpoint: https://docs.cohere.com/reference/chat
+
+ Cohere models will only support chat soon: https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat
+ """
+
+ def __init__(self, api_key: str, cache_config: CacheConfig):
+ super().__init__(cache_config=cache_config)
+ self.client = cohere.Client(api_key=api_key)
+
+ def make_request(self, request: Request) -> RequestResult:
+ if request.embedding:
+ return EMBEDDING_UNAVAILABLE_REQUEST_RESULT
+ # TODO: Support multiple completions
+ assert request.num_completions == 1, "CohereChatClient only supports num_completions=1"
+ # TODO: Support messages
+ assert not request.messages, "CohereChatClient currently does not support the messages API"
+
+ raw_request: CohereRawChatRequest = convert_to_raw_chat_request(request)
+
+ try:
+
+ def do_it():
+ """
+ Send the request to the Cohere Chat API. Responses will be structured like this:
+ cohere.Chat {
+ message: What's up?
+ text: Hey there! How's it going? I'm doing well, thank you for asking 😊.
+ ...
+ }
+ """
+ raw_response = self.client.chat(**raw_request).dict()
+ assert "text" in raw_response, f"Response does not contain text: {raw_response}"
+ return raw_response
+
+ response, cached = self.cache.get(raw_request, wrap_request_time(do_it))
+ except (requests.exceptions.RequestException, AssertionError) as e:
+ error: str = f"CohereClient error: {e}"
+ return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+
+ completions: List[GeneratedOutput] = []
+ completion: GeneratedOutput = GeneratedOutput(text=response["text"], logprob=0.0, tokens=[])
+ completions.append(completion)
+
+ return RequestResult(
+ success=True,
+ cached=cached,
+ request_time=response["request_time"],
+ request_datetime=response["request_datetime"],
+ completions=completions,
+ embedding=[],
+ )
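A rough usage sketch of the new chat path: convert_to_raw_chat_request strips the "cohere/" prefix from the model name and maps HELM Request fields onto the Cohere chat payload. The model name and parameter values below are illustrative only, not taken from this release.

from helm.common.request import Request

request = Request(
    model="cohere/command-r",  # hypothetical model name, for illustration only
    prompt="Name the capital of France.",
    temperature=0.0,
    max_tokens=16,
)
raw_chat_request = convert_to_raw_chat_request(request)
# raw_chat_request["message"] == "Name the capital of France."
# raw_chat_request["model"] == "command-r"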
helm/clients/huggingface_client.py CHANGED
@@ -17,6 +17,7 @@ from helm.common.request import (
  GeneratedOutput,
  Token,
  )
+ from helm.tokenizers.tokenizer import Tokenizer
  from .client import CachingClient, truncate_sequence
  from helm.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer, WrappedPreTrainedTokenizer
  from threading import Lock
@@ -53,7 +54,13 @@ class HuggingFaceRequest(TypedDict):
  class HuggingFaceServer:
  """A thin wrapper around a Hugging Face AutoModelForCausalLM for HuggingFaceClient to call."""

- def __init__(self, pretrained_model_name_or_path: str, **kwargs):
+ def __init__(
+ self,
+ pretrained_model_name_or_path: str,
+ wrapped_tokenizer: WrappedPreTrainedTokenizer,
+ openvino=False,
+ **kwargs,
+ ):
  if torch.cuda.is_available():
  hlog("CUDA is available, initializing with a GPU...")
  self.device: str = "cuda:0"
@@ -61,13 +68,44 @@
  self.device = "cpu"
  with htrack_block(f"Loading Hugging Face model {pretrained_model_name_or_path}"):
  # WARNING this may fail if your GPU does not have enough memory
- self.model = AutoModelForCausalLM.from_pretrained(
- pretrained_model_name_or_path, trust_remote_code=True, **kwargs
- ).to(self.device)
- with htrack_block(f"Loading Hugging Face tokenizer for model {pretrained_model_name_or_path}"):
- self.wrapped_tokenizer: WrappedPreTrainedTokenizer = HuggingFaceTokenizer.create_tokenizer(
- pretrained_model_name_or_path, **kwargs
- )
+ if openvino:
+ """
+ Optimum Intel provides a simple interface to optimize Transformer models and convert them to \
+ OpenVINO™ Intermediate Representation (IR) format to accelerate end-to-end pipelines on \
+ Intel® architectures using OpenVINO™ runtime.
+ """
+ from helm.common.optional_dependencies import handle_module_not_found_error
+
+ try:
+ from optimum.intel.openvino import OVModelForCausalLM
+ except ModuleNotFoundError as e:
+ handle_module_not_found_error(e, ["openvino"])
+
+ self.device = "cpu"
+ # Security issue: currently we trust remote code by default.
+ # We retain this temporarily to maintain reverse compatibility.
+ # TODO: Delete if-else and don't set trust_remote_code=True
+ if "trust_remote_code" in kwargs:
+ self.model = OVModelForCausalLM.from_pretrained(
+ pretrained_model_name_or_path, export=True, **kwargs
+ ).to(self.device)
+ else:
+ self.model = OVModelForCausalLM.from_pretrained(
+ pretrained_model_name_or_path, export=True, trust_remote_code=True, **kwargs
+ ).to(self.device)
+ else:
+ # Security issue: currently we trust remote code by default.
+ # We retain this temporarily to maintain reverse compatibility.
+ # TODO: Delete if-else and don't set trust_remote_code=True
+ if "trust_remote_code" in kwargs:
+ self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs).to(
+ self.device
+ )
+ else:
+ self.model = AutoModelForCausalLM.from_pretrained(
+ pretrained_model_name_or_path, trust_remote_code=True, **kwargs
+ ).to(self.device)
+ self.wrapped_tokenizer = wrapped_tokenizer

  def serve_request(self, raw_request: HuggingFaceRequest) -> Dict:
  with self.wrapped_tokenizer as tokenizer:
@@ -170,7 +208,12 @@
  _servers_lock: Lock = Lock()

  @staticmethod
- def get_server(helm_model_name: str, pretrained_model_name_or_path: str, **kwargs) -> Any:
+ def get_server(
+ helm_model_name: str,
+ pretrained_model_name_or_path: str,
+ wrapped_tokenizer: WrappedPreTrainedTokenizer,
+ **kwargs,
+ ) -> Any:
  """
  Checks if the desired HuggingFaceModel is cached. Creates the HuggingFaceModel if it's not cached.
  Returns the HuggingFaceModel.
@@ -182,7 +225,7 @@
  f"for HELM model {helm_model_name} with Hugging Face Transformers"
  ):
  HuggingFaceServerFactory._servers[helm_model_name] = HuggingFaceServer(
- pretrained_model_name_or_path, **kwargs
+ pretrained_model_name_or_path, wrapped_tokenizer, **kwargs
  )

  return HuggingFaceServerFactory._servers[helm_model_name]
@@ -214,10 +257,25 @@ def _process_huggingface_client_kwargs(raw_kwargs: Dict[str, Any]):


  class HuggingFaceClient(CachingClient):
- def __init__(self, cache_config: CacheConfig, pretrained_model_name_or_path: Optional[str] = None, **kwargs):
+ def __init__(
+ self,
+ cache_config: CacheConfig,
+ tokenizer: Tokenizer,
+ pretrained_model_name_or_path: Optional[str] = None,
+ end_of_text_token: Optional[str] = None,
+ **kwargs,
+ ):
  super().__init__(cache_config=cache_config)
  self._pretrained_model_name_or_path = pretrained_model_name_or_path
+ if not isinstance(tokenizer, HuggingFaceTokenizer):
+ raise ValueError(
+ f"Tokenizer for Hugging Face model {pretrained_model_name_or_path} must be a HuggingFaceTokenizer, "
+ "but instead it is {tokenizer}"
+ )
+ self._wrapped_tokenizer: WrappedPreTrainedTokenizer = tokenizer.get_wrapped_tokenizer()
+ self._tokenizer = tokenizer
  self._kwargs = _process_huggingface_client_kwargs(kwargs)
+ self._end_of_text_token = end_of_text_token

  def make_request(self, request: Request) -> RequestResult:
  # Embedding not supported for this model
@@ -242,6 +300,7 @@
  huggingface_model: HuggingFaceServer = HuggingFaceServerFactory.get_server(
  helm_model_name=request.model,
  pretrained_model_name_or_path=pretrained_model_name_or_path,
+ wrapped_tokenizer=self._wrapped_tokenizer,
  **self._kwargs,
  )

@@ -284,7 +343,7 @@
  sequence_logprob += logprob

  completion = GeneratedOutput(text=raw_completion["text"], logprob=sequence_logprob, tokens=tokens)
- completion = truncate_sequence(completion, request)
+ completion = truncate_sequence(completion, request, end_of_text_token=self._end_of_text_token)
  completions.append(completion)

  return RequestResult(
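HuggingFaceClient now receives the HELM tokenizer object and hands its wrapped tokenizer to HuggingFaceServer, so the underlying Hugging Face tokenizer is loaded once and shared rather than re-created per server. The sketch below illustrates the general idea of a lock-guarded "wrapped" tokenizer; it is an assumption about the pattern, not the actual WrappedPreTrainedTokenizer class.

from threading import Lock

class WrappedTokenizerSketch:
    """Context manager that serializes access to a single shared tokenizer instance."""

    def __init__(self, tokenizer):
        self._tokenizer = tokenizer
        self._lock = Lock()

    def __enter__(self):
        self._lock.acquire()
        return self._tokenizer

    def __exit__(self, exc_type, exc_value, traceback):
        self._lock.release()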
helm/clients/openai_client.py CHANGED
@@ -60,8 +60,7 @@ class OpenAIClient(CachingClient):

  def _get_cache_key(self, raw_request: Dict, request: Request):
  cache_key = CachingClient.make_cache_key(raw_request, request)
- if is_vlm(request.model):
- assert request.multimodal_prompt is not None
+ if request.multimodal_prompt:
  prompt_key: str = generate_uid_for_multimodal_prompt(request.multimodal_prompt)
  cache_key = {**cache_key, "multimodal_prompt": prompt_key}
  del cache_key["messages"]
@@ -103,6 +102,14 @@

  def _make_chat_request(self, request: Request) -> RequestResult:
  messages: Optional[List[Dict[str, Union[str, Any]]]] = request.messages
+ if (
+ (request.prompt and request.messages)
+ or (request.prompt and request.multimodal_prompt)
+ or (request.messages and request.multimodal_prompt)
+ ):
+ raise ValueError(
+ f"More than one of `prompt`, `messages` and `multimodal_prompt` was set in request: {request}"
+ )
  if request.messages is not None:
  # Checks that all messages have a role and some content
  for message in request.messages:
@@ -130,9 +137,8 @@
  from helm.common.images_utils import encode_base64

  base64_image: str = encode_base64(media_object.location)
- content.append(
- {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
- )
+ image_object: Dict[str, str] = {"url": f"data:image/jpeg;base64,{base64_image}"}
+ content.append({"type": "image_url", "image_url": image_object})
  elif media_object.is_type(TEXT_TYPE):
  if media_object.text is None:
  raise ValueError("MediaObject of text type has missing text field value")
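For reference, the loop above produces a chat payload whose user message holds a list of typed content parts. An illustrative example (the base64 string and prompt text are placeholders) of what is sent for one image plus one text segment:

content = [
    {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,<BASE64_IMAGE>"}},
    {"type": "text", "text": "Describe the image."},
]
messages = [{"role": "user", "content": content}]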