crfm-helm 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of crfm-helm has been flagged as potentially problematic.
- {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/METADATA +19 -5
- {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/RECORD +121 -76
- helm/benchmark/adaptation/adapter_spec.py +32 -31
- helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/multimodal_prompt.py +7 -0
- helm/benchmark/adaptation/adapters/multimodal/test_multimodal_prompt.py +2 -0
- helm/benchmark/annotation/air_bench_annotator.py +64 -0
- helm/benchmark/annotation/annotator_factory.py +6 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +1 -1
- helm/benchmark/annotation/live_qa_annotator.py +84 -0
- helm/benchmark/annotation/medication_qa_annotator.py +81 -0
- helm/benchmark/augmentations/perturbation.py +17 -1
- helm/benchmark/augmentations/test_perturbation.py +30 -0
- helm/benchmark/augmentations/translate_perturbation.py +1 -0
- helm/benchmark/huggingface_registration.py +16 -6
- helm/benchmark/metrics/air_bench_metrics.py +56 -0
- helm/benchmark/metrics/efficiency_metrics.py +9 -2
- helm/benchmark/metrics/evaluate_reference_metrics.py +16 -0
- helm/benchmark/metrics/fin_qa_metrics.py +60 -0
- helm/benchmark/metrics/fin_qa_metrics_helper.py +398 -0
- helm/benchmark/metrics/gpt4v_originality_critique_metrics.py +126 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +1 -0
- helm/benchmark/metrics/live_qa_metrics.py +23 -0
- helm/benchmark/metrics/medication_qa_metrics.py +23 -0
- helm/benchmark/metrics/prometheus_vision_critique_metrics.py +185 -0
- helm/benchmark/metrics/reka_vibe_critique_metrics.py +158 -0
- helm/benchmark/metrics/unitxt_metrics.py +20 -10
- helm/benchmark/metrics/vision_language/emd_utils.py +4 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +104 -21
- helm/benchmark/model_metadata_registry.py +5 -1
- helm/benchmark/presentation/schema.py +54 -4
- helm/benchmark/presentation/test_schema.py +11 -0
- helm/benchmark/run.py +16 -2
- helm/benchmark/run_expander.py +112 -63
- helm/benchmark/run_spec_factory.py +15 -10
- helm/benchmark/run_specs/air_bench_run_specs.py +40 -0
- helm/benchmark/run_specs/classic_run_specs.py +15 -11
- helm/benchmark/run_specs/decodingtrust_run_specs.py +3 -1
- helm/benchmark/run_specs/experimental_run_specs.py +33 -0
- helm/benchmark/run_specs/finance_run_specs.py +33 -0
- helm/benchmark/run_specs/vlm_run_specs.py +444 -65
- helm/benchmark/scenarios/air_bench_scenario.py +50 -0
- helm/benchmark/scenarios/ci_mcqa_scenario.py +80 -0
- helm/benchmark/scenarios/entity_data_imputation_scenario.py +8 -2
- helm/benchmark/scenarios/fin_qa_scenario.py +117 -0
- helm/benchmark/scenarios/legalbench_scenario.py +6 -2
- helm/benchmark/scenarios/math_scenario.py +1 -1
- helm/benchmark/scenarios/test_air_bench_scenario.py +27 -0
- helm/benchmark/scenarios/vision_language/a_okvqa_scenario.py +83 -0
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +3 -3
- helm/benchmark/scenarios/vision_language/crossmodal_3600_scenario.py +134 -0
- helm/benchmark/scenarios/vision_language/flickr30k_scenario.py +74 -0
- helm/benchmark/scenarios/vision_language/gqa_scenario.py +91 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +4 -2
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +13 -2
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +1 -5
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +1 -5
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +5 -3
- helm/benchmark/scenarios/vision_language/math_vista_scenario.py +117 -0
- helm/benchmark/scenarios/vision_language/mm_safety_bench_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/mscoco_captioning_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/mscoco_categorization_scenario.py +117 -0
- helm/benchmark/scenarios/vision_language/originality_scenario.py +35 -0
- helm/benchmark/scenarios/vision_language/pairs_scenario.py +247 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +3 -3
- helm/benchmark/scenarios/vision_language/vibe_eval_scenario.py +95 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +2 -2
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +4 -2
- helm/benchmark/static/schema_air_bench.yaml +3149 -0
- helm/benchmark/static/schema_classic.yaml +3 -59
- helm/benchmark/static/schema_finance.yaml +143 -0
- helm/benchmark/static/schema_image2structure.yaml +447 -0
- helm/benchmark/static/schema_instruction_following.yaml +3 -52
- helm/benchmark/static/schema_lite.yaml +3 -61
- helm/benchmark/static/schema_medical.yaml +255 -0
- helm/benchmark/static/schema_mmlu.yaml +3 -61
- helm/benchmark/static/schema_tables.yaml +200 -0
- helm/benchmark/static/schema_thai.yaml +223 -0
- helm/benchmark/static/schema_unitxt.yaml +3 -61
- helm/benchmark/static/schema_vhelm.yaml +824 -0
- helm/benchmark/static/schema_vhelm_lite.yaml +109 -0
- helm/benchmark/static_build/assets/air-overview-d2e6c49f.png +0 -0
- helm/benchmark/static_build/assets/index-30dbceba.js +10 -0
- helm/benchmark/static_build/assets/index-66b02d40.css +1 -0
- helm/benchmark/static_build/assets/overview-74aea3d8.png +0 -0
- helm/benchmark/static_build/assets/process-flow-bd2eba96.png +0 -0
- helm/benchmark/static_build/index.html +2 -2
- helm/clients/anthropic_client.py +78 -14
- helm/clients/auto_client.py +11 -0
- helm/clients/client.py +24 -7
- helm/clients/cohere_client.py +98 -3
- helm/clients/huggingface_client.py +71 -12
- helm/clients/openai_client.py +11 -5
- helm/clients/reka_client.py +189 -0
- helm/clients/test_client.py +3 -3
- helm/clients/test_huggingface_client.py +19 -3
- helm/clients/test_together_client.py +72 -2
- helm/clients/together_client.py +199 -2
- helm/clients/vertexai_client.py +117 -64
- helm/clients/vision_language/huggingface_vision2seq_client.py +145 -0
- helm/clients/vision_language/huggingface_vlm_client.py +12 -4
- helm/clients/vision_language/idefics_client.py +2 -2
- helm/clients/vision_language/paligemma_client.py +146 -0
- helm/clients/vision_language/palmyra_vision_client.py +84 -0
- helm/clients/yi_client.py +31 -0
- helm/common/critique_request.py +10 -1
- helm/common/images_utils.py +29 -3
- helm/config/model_deployments.yaml +504 -12
- helm/config/model_metadata.yaml +579 -52
- helm/config/tokenizer_configs.yaml +100 -1
- helm/proxy/critique/model_critique_client.py +32 -4
- helm/proxy/services/server_service.py +1 -1
- helm/tokenizers/auto_tokenizer.py +1 -1
- helm/tokenizers/cohere_tokenizer.py +44 -2
- helm/tokenizers/huggingface_tokenizer.py +36 -13
- helm/tokenizers/test_cohere_tokenizer.py +39 -0
- helm/tokenizers/test_huggingface_tokenizer.py +5 -1
- helm/benchmark/static/schema_vlm.yaml +0 -576
- helm/benchmark/static_build/assets/index-5088afcb.css +0 -1
- helm/benchmark/static_build/assets/index-d839df55.js +0 -9
- helm/benchmark/test_model_deployment_definition.py +0 -90
- {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/LICENSE +0 -0
- {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/WHEEL +0 -0
- {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/top_level.txt +0 -0
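The file-by-file diffs follow below. If you want to confirm which of the two wheels is installed locally before reading them, the standard library can report the installed distribution version directly (a minimal sketch; the distribution name comes from the wheel filenames above):

# Check the locally installed crfm-helm version (standard library only).
from importlib.metadata import PackageNotFoundError, version

try:
    print("crfm-helm", version("crfm-helm"))
except PackageNotFoundError:
    print("crfm-helm is not installed")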
helm/benchmark/static_build/index.html
CHANGED
@@ -7,11 +7,11 @@
     <title>Holistic Evaluation of Language Models (HELM)</title>
     <meta name="description" content="The Holistic Evaluation of Language Models (HELM) serves as a living benchmark for transparency in language models. Providing broad coverage and recognizing incompleteness, multi-metric measurements, and standardization. All data and analysis are freely accessible on the website for exploration and study." />
     <script type="text/javascript" src="./config.js"></script>
-    <script type="module" crossorigin src="./assets/index-d839df55.js"></script>
+    <script type="module" crossorigin src="./assets/index-30dbceba.js"></script>
     <link rel="modulepreload" crossorigin href="./assets/react-d4a0b69b.js">
     <link rel="modulepreload" crossorigin href="./assets/recharts-6d337683.js">
     <link rel="modulepreload" crossorigin href="./assets/tremor-54a99cc4.js">
-    <link rel="stylesheet" href="./assets/index-5088afcb.css">
+    <link rel="stylesheet" href="./assets/index-66b02d40.css">
   </head>
   <body class="block">
     <div id="root"></div>
helm/clients/anthropic_client.py
CHANGED
@@ -1,6 +1,8 @@
 from typing import Any, Dict, List, Optional, TypedDict, Union, cast
 import json
+import os
 import requests
+import tempfile
 import time
 import urllib.parse
 
@@ -68,6 +70,9 @@ class AnthropicClient(CachingClient):
     MAX_COMPLETION_LENGTH: int = (
         8192  # See https://docs.google.com/document/d/1vX6xgoA-KEKxqtMlBVAqYvE8KUfZ7ABCjTxAjf1T5kI/edit#
     )
+    # An Anthropic error message: "At least one of the image dimensions exceed max allowed size: 8000 pixels"
+    MAX_IMAGE_DIMENSION: int = 8000
+
     ADDITIONAL_TOKENS: int = 5
     PROMPT_ANSWER_START: str = "The answer is "
 
@@ -206,7 +211,7 @@
 
 
 def _is_content_moderation_failure(response: Dict) -> bool:
-    """Return whether a
+    """Return whether a response failed because of the content moderation filter."""
     if (
         "error" in response
         and "message" in response["error"]
@@ -238,7 +243,9 @@ class AnthropicMessagesResponseError(Exception):
 
 class AnthropicMessagesClient(CachingClient):
     # Source: https://docs.anthropic.com/claude/docs/models-overview
-    MAX_OUTPUT_TOKENS = 4096
+    MAX_OUTPUT_TOKENS: int = 4096
+
+    MAX_IMAGE_SIZE_BYTES: int = 5242880  # 5MB
 
     def __init__(
         self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig, api_key: Optional[str] = None
@@ -273,7 +280,7 @@ class AnthropicMessagesClient(CachingClient):
             # TODO(#2439): Refactor out Request validation
             if request.messages is not None or request.prompt:
                 raise AnthropicMessagesRequestError(
-                    "Exactly one of Request.messages, Request.prompt or Request.
+                    "Exactly one of Request.messages, Request.prompt or Request.multimodal_prompt should be set"
                 )
             blocks: List[Union[TextBlockParam, ImageBlockParam]] = []
             for media_object in request.multimodal_prompt.media_objects:
@@ -282,9 +289,53 @@
                     if not media_object.location:
                         raise Exception("MediaObject of image type has missing location field value")
 
-                    from helm.common.images_utils import
+                    from helm.common.images_utils import (
+                        encode_base64,
+                        get_dimensions,
+                        copy_image,
+                        resize_image_to_max_file_size,
+                    )
+
+                    image_location: str = media_object.location
+                    base64_image: str
+
+                    image_width, image_height = get_dimensions(media_object.location)
+                    if (
+                        image_width > AnthropicClient.MAX_IMAGE_DIMENSION
+                        or image_height > AnthropicClient.MAX_IMAGE_DIMENSION
+                    ):
+                        hlog(
+                            f"WARNING: Image {image_location} exceeds max allowed size: "
+                            f"{AnthropicClient.MAX_IMAGE_DIMENSION} pixels"
+                        )
+                        # Save the resized image to a temporary file
+                        with tempfile.NamedTemporaryFile(suffix=".jpg") as temp_file:
+                            hlog(f"Resizing image to temporary path: {temp_file.name}")
+                            copy_image(
+                                src=image_location,
+                                dest=temp_file.name,
+                                width=min(image_width, AnthropicClient.MAX_IMAGE_DIMENSION),
+                                height=min(image_height, AnthropicClient.MAX_IMAGE_DIMENSION),
+                            )
+                            base64_image = encode_base64(temp_file.name, format="JPEG")
+
+                    elif os.path.getsize(image_location) > AnthropicMessagesClient.MAX_IMAGE_SIZE_BYTES:
+                        hlog(
+                            f"WARNING: Image {image_location} exceeds max allowed size: "
+                            f"{AnthropicMessagesClient.MAX_IMAGE_SIZE_BYTES} bytes"
+                        )
+                        # Resize the image so it is smaller than the max allowed size
+                        with tempfile.NamedTemporaryFile(suffix=".jpg") as temp_file:
+                            hlog(f"Resizing image to temporary path: {temp_file.name}")
+                            resize_image_to_max_file_size(
+                                src=image_location,
+                                dest=temp_file.name,
+                                max_size_in_bytes=AnthropicMessagesClient.MAX_IMAGE_SIZE_BYTES,
+                            )
+                            base64_image = encode_base64(temp_file.name, format="JPEG")
+                    else:
+                        base64_image = encode_base64(image_location, format="JPEG")
 
-                    base64_image: str = encode_base64(media_object.location, format="JPEG")
                     image_block: ImageBlockParam = {
                         "type": "image",
                         "source": {
@@ -302,7 +353,9 @@ class AnthropicMessagesClient(CachingClient):
                         "type": "text",
                         "text": media_object.text,
                     }
-                    blocks
+                    # Anthropic does not support empty text blocks
+                    if media_object.text.strip():
+                        blocks.append(text_block)
             messages = [{"role": "user", "content": blocks}]
 
         else:
@@ -338,14 +391,25 @@ class AnthropicMessagesClient(CachingClient):
                 return response
             raise
 
-        cache_key = CachingClient.make_cache_key(
-            {
-                "completion_index": completion_index,
-                **raw_request,
-            },
-            request,
-        )
-        response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+        try:
+            cache_key = CachingClient.make_cache_key(
+                {
+                    "completion_index": completion_index,
+                    **raw_request,
+                },
+                request,
+            )
+            response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+        except AnthropicMessagesResponseError:
+            hlog("WARNING: Response has empty content")
+            return RequestResult(
+                success=False,
+                cached=False,
+                error="Anthropic response has empty content",
+                completions=[],
+                embedding=[],
+                error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
+            )
 
         if _is_content_moderation_failure(response):
             hlog(
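The Anthropic changes above cap image inputs at a per-side pixel limit and a payload byte limit, resizing via helm.common.images_utils before base64-encoding. Below is a rough standalone sketch of the same resize-then-encode idea using Pillow directly; the function name and the quality-lowering loop are illustrative and not HELM's own utilities, though the constants mirror the limits in the diff.

# Rough sketch (not HELM's code): downscale an image below Anthropic's pixel and byte limits,
# then base64-encode it, mirroring the checks added in the diff above.
import base64
import io

from PIL import Image

MAX_IMAGE_DIMENSION = 8000       # per-side pixel limit from the diff
MAX_IMAGE_SIZE_BYTES = 5242880   # 5MB payload limit from the diff


def encode_image_for_anthropic(path: str) -> str:
    image = Image.open(path).convert("RGB")
    # Cap the longer side while preserving aspect ratio.
    image.thumbnail((MAX_IMAGE_DIMENSION, MAX_IMAGE_DIMENSION))
    quality = 95
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG", quality=quality)
    # Re-encode at lower quality until the payload fits under the byte cap.
    while buffer.tell() > MAX_IMAGE_SIZE_BYTES and quality > 30:
        quality -= 10
        buffer = io.BytesIO()
        image.save(buffer, format="JPEG", quality=quality)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")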
helm/clients/auto_client.py
CHANGED
@@ -5,6 +5,7 @@ from typing import Any, Dict, Mapping, Optional
 from retrying import Attempt, RetryError
 
 from helm.benchmark.model_deployment_registry import ModelDeployment, get_model_deployment
+from helm.benchmark.tokenizer_config_registry import get_tokenizer_config
 from helm.common.file_caches.file_cache import FileCache
 from helm.common.file_caches.local_file_cache import LocalFileCache
 from helm.common.credentials_utils import provide_api_key
@@ -88,6 +89,10 @@ class AutoClient(Client):
                 "location": lambda: self.credentials.get(host_organization + "Location", None),  # VertexAI
                 "hf_auth_token": lambda: self.credentials.get("huggingfaceAuthToken", None),  # HuggingFace
                 "file_cache": lambda: self._get_file_cache(host_organization),  # Text-to-image models
+                "endpoint": lambda: self.credentials.get(host_organization + "Endpoint", None),  # Palmyra
+                "end_of_text_token": lambda: self._get_end_of_text_token(
+                    tokenizer_name=model_deployment.tokenizer_name or model_deployment.name
+                ),
             },
         )
         client = create_object(client_spec)
@@ -213,3 +218,9 @@ class AutoClient(Client):
         # Initialize `FileCache` for text-to-image model APIs
         local_file_cache_path: str = os.path.join(self.file_storage_path, "output", host_organization)
         return LocalFileCache(local_file_cache_path, file_extension="png")
+
+    def _get_end_of_text_token(self, tokenizer_name: str) -> Optional[str]:
+        tokenizer_config = get_tokenizer_config(tokenizer_name)
+        if tokenizer_config is None:
+            raise ValueError(f"Could not find tokenizer_config for tokenizer {tokenizer_name}")
+        return tokenizer_config.end_of_text_token
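The auto_client change above lets AutoClient inject an end_of_text_token, resolved from the tokenizer config registry, into clients that accept one. A minimal sketch of that lookup-and-fail-loudly pattern is shown here; the dataclass, registry contents, and function name are illustrative stand-ins, not HELM's tokenizer_config_registry.

# Sketch of the lookup pattern behind AutoClient._get_end_of_text_token above:
# resolve a tokenizer name to its configured end-of-text token, failing loudly when unknown.
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass(frozen=True)
class TokenizerConfig:
    name: str
    end_of_text_token: Optional[str]


# Illustrative registry; HELM's real configs live in helm/config/tokenizer_configs.yaml.
_TOKENIZER_CONFIGS: Dict[str, TokenizerConfig] = {
    "huggingface/gpt2": TokenizerConfig(name="huggingface/gpt2", end_of_text_token="<|endoftext|>"),
}


def get_end_of_text_token(tokenizer_name: str) -> Optional[str]:
    config = _TOKENIZER_CONFIGS.get(tokenizer_name)
    if config is None:
        raise ValueError(f"Could not find tokenizer_config for tokenizer {tokenizer_name}")
    return config.end_of_text_token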
helm/clients/client.py
CHANGED
@@ -39,13 +39,17 @@ class CachingClient(Client):
         """
         if request.random is not None:
             assert "random" not in raw_request
-
+            return {**raw_request, "random": request.random}
         else:
-
-        return cache_key
+            return {**raw_request}
 
 
-def truncate_sequence(
+def truncate_sequence(
+    sequence: GeneratedOutput,
+    request: Request,
+    end_of_text_token: Optional[str] = None,
+    print_warning: bool = True,
+) -> GeneratedOutput:
     """
     Certain providers have bugs where they aren't respecting max_tokens,
     stop_sequences and the end of text token, so as a hack, we have to manually
@@ -64,7 +68,11 @@ def truncate_sequence(sequence: GeneratedOutput, request: Request, print_warning
         hlog("WARNING: don't know how to handle echo_prompt and max_tokens > 0, not truncating")
         return sequence
 
-
+    if end_of_text_token:
+        stop_sequences = request.stop_sequences + [end_of_text_token]
+    else:
+        stop_sequences = request.stop_sequences
+    for stop in stop_sequences:
         # Find `stop` in the text
         try:
             new_text = sequence.text[: sequence.text.index(stop)]
@@ -116,7 +124,12 @@
 
 
 def truncate_and_tokenize_response_text(
-    text: str,
+    text: str,
+    request: Request,
+    tokenizer: Tokenizer,
+    tokenizer_name: str,
+    end_of_text_token: Optional[str] = None,
+    original_finish_reason: str = "endoftext",
 ) -> GeneratedOutput:
     """Truncate a string-only response to respect stop_sequences and max_tokens.
 
@@ -139,7 +152,11 @@ def truncate_and_tokenize_response_text(
     if request.echo_prompt:
         raise Exception("truncate_and_tokenize_response_text() does not support requests with echo_prompt = True")
 
-
+    if end_of_text_token:
+        stop_sequences = request.stop_sequences + [end_of_text_token]
+    else:
+        stop_sequences = request.stop_sequences
+    for stop_sequence in stop_sequences:
         try:
             text = text[: text.index(stop_sequence)]
             finish_reason = "stop"
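Both truncation helpers above now treat the tokenizer's end-of-text token as an extra stop sequence when one is provided. A small self-contained sketch of that idea (not HELM's exact function, just the truncation logic) is shown below.

# Sketch: truncate generated text at the earliest stop sequence, optionally treating the
# tokenizer's end-of-text token as an additional stop, mirroring truncate_sequence above.
from typing import List, Optional


def truncate_at_stop(text: str, stop_sequences: List[str], end_of_text_token: Optional[str] = None) -> str:
    stops = list(stop_sequences)
    if end_of_text_token:
        stops.append(end_of_text_token)
    for stop in stops:
        index = text.find(stop)
        if index != -1:
            text = text[:index]  # drop the stop sequence and everything after it
    return text


# Example: "</s>" acts as a stop even though the request only listed "\n\n".
print(truncate_at_stop("Paris</s> is the capital", ["\n\n"], end_of_text_token="</s>"))  # -> "Paris"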
helm/clients/cohere_client.py
CHANGED
@@ -1,8 +1,9 @@
 import json
 import requests
-from typing import List
+from typing import List, Optional, Sequence, TypedDict
 
 from helm.common.cache import CacheConfig
+from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.common.request import (
     wrap_request_time,
     EMBEDDING_UNAVAILABLE_REQUEST_RESULT,
@@ -11,8 +12,13 @@ from helm.common.request import (
     GeneratedOutput,
     Token,
 )
-from .client import CachingClient, truncate_sequence
-from .cohere_utils import get_cohere_url, DEFAULT_COHERE_API_VERSION
+from helm.clients.client import CachingClient, truncate_sequence
+from helm.clients.cohere_utils import get_cohere_url, DEFAULT_COHERE_API_VERSION
+
+try:
+    import cohere
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["cohere"])
 
 
 class CohereClient(CachingClient):
@@ -152,3 +158,92 @@
             completions=completions,
             embedding=[],
         )
+
+
+class CohereRawChatRequest(TypedDict):
+    message: str
+    model: Optional[str]
+    preamble: Optional[str]
+    chat_history: Optional[Sequence[cohere.ChatMessage]]
+    temperature: Optional[float]
+    max_tokens: Optional[int]
+    k: Optional[int]
+    p: Optional[float]
+    seed: Optional[float]
+    stop_sequences: Optional[Sequence[str]]
+    frequency_penalty: Optional[float]
+    presence_penalty: Optional[float]
+
+
+def convert_to_raw_chat_request(request: Request) -> CohereRawChatRequest:
+    # TODO: Support chat
+    model = request.model.replace("cohere/", "")
+    return {
+        "message": request.prompt,
+        "model": model,
+        "preamble": None,
+        "chat_history": None,
+        "temperature": request.temperature,
+        "max_tokens": request.max_tokens,
+        "k": request.top_k_per_token,
+        "p": request.top_p,
+        "stop_sequences": request.stop_sequences,
+        "seed": float(request.random) if request.random is not None else None,
+        "frequency_penalty": request.frequency_penalty,
+        "presence_penalty": request.presence_penalty,
+    }
+
+
+class CohereChatClient(CachingClient):
+    """
+    Leverages the chat endpoint: https://docs.cohere.com/reference/chat
+
+    Cohere models will only support chat soon: https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat
+    """
+
+    def __init__(self, api_key: str, cache_config: CacheConfig):
+        super().__init__(cache_config=cache_config)
+        self.client = cohere.Client(api_key=api_key)
+
+    def make_request(self, request: Request) -> RequestResult:
+        if request.embedding:
+            return EMBEDDING_UNAVAILABLE_REQUEST_RESULT
+        # TODO: Support multiple completions
+        assert request.num_completions == 1, "CohereChatClient only supports num_completions=1"
+        # TODO: Support messages
+        assert not request.messages, "CohereChatClient currently does not support the messages API"
+
+        raw_request: CohereRawChatRequest = convert_to_raw_chat_request(request)
+
+        try:
+
+            def do_it():
+                """
+                Send the request to the Cohere Chat API. Responses will be structured like this:
+                cohere.Chat {
+                    message: What's up?
+                    text: Hey there! How's it going? I'm doing well, thank you for asking 😊.
+                    ...
+                }
+                """
+                raw_response = self.client.chat(**raw_request).dict()
+                assert "text" in raw_response, f"Response does not contain text: {raw_response}"
+                return raw_response
+
+            response, cached = self.cache.get(raw_request, wrap_request_time(do_it))
+        except (requests.exceptions.RequestException, AssertionError) as e:
+            error: str = f"CohereClient error: {e}"
+            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+
+        completions: List[GeneratedOutput] = []
+        completion: GeneratedOutput = GeneratedOutput(text=response["text"], logprob=0.0, tokens=[])
+        completions.append(completion)
+
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=response["request_time"],
+            request_datetime=response["request_datetime"],
+            completions=completions,
+            embedding=[],
+        )
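The new CohereChatClient above builds a plain keyword-argument dict from a HELM Request and passes it to cohere.Client.chat. A rough usage sketch of that call shape, assuming the cohere SDK used in the diff and a COHERE_API_KEY environment variable; the model id and prompt here are illustrative.

# Sketch: how CohereChatClient's raw request maps onto the Cohere chat SDK call above.
import os

import cohere

client = cohere.Client(api_key=os.environ["COHERE_API_KEY"])

raw_request = {
    "message": "What is the tallest mountain in the world?",
    "model": "command-r",   # illustrative model id
    "temperature": 0.0,
    "max_tokens": 64,
    "stop_sequences": ["\n"],
}

response = client.chat(**raw_request)  # same call shape as do_it() in the diff above
print(response.text)                   # the generated text, read as raw_response["text"] in the diff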
helm/clients/huggingface_client.py
CHANGED
@@ -17,6 +17,7 @@ from helm.common.request import (
     GeneratedOutput,
     Token,
 )
+from helm.tokenizers.tokenizer import Tokenizer
 from .client import CachingClient, truncate_sequence
 from helm.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer, WrappedPreTrainedTokenizer
 from threading import Lock
@@ -53,7 +54,13 @@ class HuggingFaceRequest(TypedDict):
 class HuggingFaceServer:
     """A thin wrapper around a Hugging Face AutoModelForCausalLM for HuggingFaceClient to call."""
 
-    def __init__(
+    def __init__(
+        self,
+        pretrained_model_name_or_path: str,
+        wrapped_tokenizer: WrappedPreTrainedTokenizer,
+        openvino=False,
+        **kwargs,
+    ):
         if torch.cuda.is_available():
             hlog("CUDA is available, initializing with a GPU...")
             self.device: str = "cuda:0"
@@ -61,13 +68,44 @@
             self.device = "cpu"
         with htrack_block(f"Loading Hugging Face model {pretrained_model_name_or_path}"):
             # WARNING this may fail if your GPU does not have enough memory
-
-
-
-
-
-
-
+            if openvino:
+                """
+                Optimum Intel provides a simple interface to optimize Transformer models and convert them to \
+                OpenVINO™ Intermediate Representation (IR) format to accelerate end-to-end pipelines on \
+                Intel® architectures using OpenVINO™ runtime.
+                """
+                from helm.common.optional_dependencies import handle_module_not_found_error
+
+                try:
+                    from optimum.intel.openvino import OVModelForCausalLM
+                except ModuleNotFoundError as e:
+                    handle_module_not_found_error(e, ["openvino"])
+
+                self.device = "cpu"
+                # Security issue: currently we trust remote code by default.
+                # We retain this temporarily to maintain reverse compatibility.
+                # TODO: Delete if-else and don't set trust_remote_code=True
+                if "trust_remote_code" in kwargs:
+                    self.model = OVModelForCausalLM.from_pretrained(
+                        pretrained_model_name_or_path, export=True, **kwargs
+                    ).to(self.device)
+                else:
+                    self.model = OVModelForCausalLM.from_pretrained(
+                        pretrained_model_name_or_path, export=True, trust_remote_code=True, **kwargs
+                    ).to(self.device)
+            else:
+                # Security issue: currently we trust remote code by default.
+                # We retain this temporarily to maintain reverse compatibility.
+                # TODO: Delete if-else and don't set trust_remote_code=True
+                if "trust_remote_code" in kwargs:
+                    self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs).to(
+                        self.device
+                    )
+                else:
+                    self.model = AutoModelForCausalLM.from_pretrained(
+                        pretrained_model_name_or_path, trust_remote_code=True, **kwargs
+                    ).to(self.device)
+        self.wrapped_tokenizer = wrapped_tokenizer
 
     def serve_request(self, raw_request: HuggingFaceRequest) -> Dict:
         with self.wrapped_tokenizer as tokenizer:
@@ -170,7 +208,12 @@ class HuggingFaceServerFactory:
     _servers_lock: Lock = Lock()
 
     @staticmethod
-    def get_server(
+    def get_server(
+        helm_model_name: str,
+        pretrained_model_name_or_path: str,
+        wrapped_tokenizer: WrappedPreTrainedTokenizer,
+        **kwargs,
+    ) -> Any:
         """
         Checks if the desired HuggingFaceModel is cached. Creates the HuggingFaceModel if it's not cached.
         Returns the HuggingFaceModel.
@@ -182,7 +225,7 @@ class HuggingFaceServerFactory:
                     f"for HELM model {helm_model_name} with Hugging Face Transformers"
                 ):
                     HuggingFaceServerFactory._servers[helm_model_name] = HuggingFaceServer(
-                        pretrained_model_name_or_path, **kwargs
+                        pretrained_model_name_or_path, wrapped_tokenizer, **kwargs
                     )
 
         return HuggingFaceServerFactory._servers[helm_model_name]
@@ -214,10 +257,25 @@ def _process_huggingface_client_kwargs(raw_kwargs: Dict[str, Any]):
 
 
 class HuggingFaceClient(CachingClient):
-    def __init__(
+    def __init__(
+        self,
+        cache_config: CacheConfig,
+        tokenizer: Tokenizer,
+        pretrained_model_name_or_path: Optional[str] = None,
+        end_of_text_token: Optional[str] = None,
+        **kwargs,
+    ):
         super().__init__(cache_config=cache_config)
         self._pretrained_model_name_or_path = pretrained_model_name_or_path
+        if not isinstance(tokenizer, HuggingFaceTokenizer):
+            raise ValueError(
+                f"Tokenizer for Hugging Face model {pretrained_model_name_or_path} must be a HuggingFaceTokenizer, "
+                "but instead it is {tokenizer}"
+            )
+        self._wrapped_tokenizer: WrappedPreTrainedTokenizer = tokenizer.get_wrapped_tokenizer()
+        self._tokenizer = tokenizer
         self._kwargs = _process_huggingface_client_kwargs(kwargs)
+        self._end_of_text_token = end_of_text_token
 
     def make_request(self, request: Request) -> RequestResult:
         # Embedding not supported for this model
@@ -242,6 +300,7 @@
         huggingface_model: HuggingFaceServer = HuggingFaceServerFactory.get_server(
             helm_model_name=request.model,
             pretrained_model_name_or_path=pretrained_model_name_or_path,
+            wrapped_tokenizer=self._wrapped_tokenizer,
             **self._kwargs,
         )
 
@@ -284,7 +343,7 @@
                 sequence_logprob += logprob
 
             completion = GeneratedOutput(text=raw_completion["text"], logprob=sequence_logprob, tokens=tokens)
-            completion = truncate_sequence(completion, request)
+            completion = truncate_sequence(completion, request, end_of_text_token=self._end_of_text_token)
             completions.append(completion)
 
         return RequestResult(
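The HuggingFaceServer change above gives the same wrapper two loading paths: an OpenVINO path through optimum-intel and the default transformers path. A condensed sketch of those two branches follows; the checkpoint name is illustrative, and the OpenVINO branch assumes optimum-intel is installed.

# Sketch: the two model-loading paths added to HuggingFaceServer above, condensed.
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_causal_lm(pretrained_model_name_or_path: str, openvino: bool = False):
    if openvino:
        # Export the checkpoint to OpenVINO IR and run it on CPU via optimum-intel.
        from optimum.intel.openvino import OVModelForCausalLM

        return OVModelForCausalLM.from_pretrained(pretrained_model_name_or_path, export=True)
    # Default path: regular transformers model.
    return AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)


# Illustrative usage with a small public checkpoint.
model = load_causal_lm("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("The HELM benchmark", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=10)[0]))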
helm/clients/openai_client.py
CHANGED
@@ -60,8 +60,7 @@ class OpenAIClient(CachingClient):
 
     def _get_cache_key(self, raw_request: Dict, request: Request):
         cache_key = CachingClient.make_cache_key(raw_request, request)
-        if
-            assert request.multimodal_prompt is not None
+        if request.multimodal_prompt:
             prompt_key: str = generate_uid_for_multimodal_prompt(request.multimodal_prompt)
             cache_key = {**cache_key, "multimodal_prompt": prompt_key}
             del cache_key["messages"]
@@ -103,6 +102,14 @@
 
     def _make_chat_request(self, request: Request) -> RequestResult:
         messages: Optional[List[Dict[str, Union[str, Any]]]] = request.messages
+        if (
+            (request.prompt and request.messages)
+            or (request.prompt and request.multimodal_prompt)
+            or (request.messages and request.multimodal_prompt)
+        ):
+            raise ValueError(
+                f"More than one of `prompt`, `messages` and `multimodal_prompt` was set in request: {request}"
+            )
         if request.messages is not None:
             # Checks that all messages have a role and some content
             for message in request.messages:
@@ -130,9 +137,8 @@
                     from helm.common.images_utils import encode_base64
 
                     base64_image: str = encode_base64(media_object.location)
-
-
-                    )
+                    image_object: Dict[str, str] = {"url": f"data:image/jpeg;base64,{base64_image}"}
+                    content.append({"type": "image_url", "image_url": image_object})
                 elif media_object.is_type(TEXT_TYPE):
                     if media_object.text is None:
                         raise ValueError("MediaObject of text type has missing text field value")