crfm-helm 0.5.1__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of crfm-helm might be problematic. Click here for more details.
- {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.2.dist-info}/METADATA +13 -3
- {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.2.dist-info}/RECORD +96 -63
- helm/benchmark/adaptation/adapter_spec.py +32 -31
- helm/benchmark/annotation/air_bench_annotator.py +64 -0
- helm/benchmark/annotation/annotator_factory.py +6 -0
- helm/benchmark/annotation/live_qa_annotator.py +84 -0
- helm/benchmark/annotation/medication_qa_annotator.py +81 -0
- helm/benchmark/augmentations/translate_perturbation.py +1 -0
- helm/benchmark/huggingface_registration.py +16 -6
- helm/benchmark/metrics/air_bench_metrics.py +56 -0
- helm/benchmark/metrics/fin_qa_metrics.py +60 -0
- helm/benchmark/metrics/fin_qa_metrics_helper.py +398 -0
- helm/benchmark/metrics/gpt4v_originality_critique_metrics.py +126 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +1 -0
- helm/benchmark/metrics/live_qa_metrics.py +23 -0
- helm/benchmark/metrics/medication_qa_metrics.py +23 -0
- helm/benchmark/metrics/prometheus_vision_critique_metrics.py +185 -0
- helm/benchmark/metrics/reka_vibe_critique_metrics.py +158 -0
- helm/benchmark/metrics/unitxt_metrics.py +20 -10
- helm/benchmark/metrics/vision_language/emd_utils.py +4 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +29 -71
- helm/benchmark/presentation/schema.py +54 -4
- helm/benchmark/presentation/test_schema.py +11 -0
- helm/benchmark/run.py +16 -2
- helm/benchmark/run_expander.py +77 -0
- helm/benchmark/run_spec_factory.py +4 -0
- helm/benchmark/run_specs/air_bench_run_specs.py +40 -0
- helm/benchmark/run_specs/classic_run_specs.py +15 -11
- helm/benchmark/run_specs/decodingtrust_run_specs.py +3 -1
- helm/benchmark/run_specs/experimental_run_specs.py +33 -0
- helm/benchmark/run_specs/finance_run_specs.py +33 -0
- helm/benchmark/run_specs/vlm_run_specs.py +168 -45
- helm/benchmark/scenarios/air_bench_scenario.py +50 -0
- helm/benchmark/scenarios/ci_mcqa_scenario.py +80 -0
- helm/benchmark/scenarios/entity_data_imputation_scenario.py +8 -2
- helm/benchmark/scenarios/fin_qa_scenario.py +117 -0
- helm/benchmark/scenarios/test_air_bench_scenario.py +27 -0
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +3 -3
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +13 -2
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +1 -5
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +0 -4
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +4 -2
- helm/benchmark/scenarios/vision_language/pairs_scenario.py +6 -5
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +3 -3
- helm/benchmark/scenarios/vision_language/vibe_eval_scenario.py +95 -0
- helm/benchmark/static/schema_air_bench.yaml +3149 -0
- helm/benchmark/static/schema_classic.yaml +3 -59
- helm/benchmark/static/schema_finance.yaml +143 -0
- helm/benchmark/static/schema_image2structure.yaml +254 -111
- helm/benchmark/static/schema_instruction_following.yaml +3 -52
- helm/benchmark/static/schema_lite.yaml +3 -61
- helm/benchmark/static/schema_medical.yaml +255 -0
- helm/benchmark/static/schema_mmlu.yaml +3 -61
- helm/benchmark/static/schema_tables.yaml +200 -0
- helm/benchmark/static/schema_thai.yaml +223 -0
- helm/benchmark/static/schema_unitxt.yaml +3 -61
- helm/benchmark/static/{schema_vlm.yaml → schema_vhelm.yaml} +294 -293
- helm/benchmark/static/schema_vhelm_lite.yaml +4 -59
- helm/benchmark/static_build/assets/air-overview-d2e6c49f.png +0 -0
- helm/benchmark/static_build/assets/index-30dbceba.js +10 -0
- helm/benchmark/static_build/assets/index-66b02d40.css +1 -0
- helm/benchmark/static_build/assets/overview-74aea3d8.png +0 -0
- helm/benchmark/static_build/assets/process-flow-bd2eba96.png +0 -0
- helm/benchmark/static_build/index.html +2 -2
- helm/clients/anthropic_client.py +43 -9
- helm/clients/auto_client.py +11 -0
- helm/clients/client.py +24 -7
- helm/clients/cohere_client.py +98 -3
- helm/clients/huggingface_client.py +71 -12
- helm/clients/openai_client.py +9 -2
- helm/clients/reka_client.py +189 -0
- helm/clients/test_client.py +3 -3
- helm/clients/test_huggingface_client.py +19 -3
- helm/clients/test_together_client.py +72 -2
- helm/clients/together_client.py +129 -23
- helm/clients/vertexai_client.py +62 -18
- helm/clients/vision_language/huggingface_vlm_client.py +1 -0
- helm/clients/vision_language/paligemma_client.py +146 -0
- helm/clients/vision_language/palmyra_vision_client.py +84 -0
- helm/clients/yi_client.py +31 -0
- helm/common/critique_request.py +10 -1
- helm/common/images_utils.py +19 -0
- helm/config/model_deployments.yaml +412 -18
- helm/config/model_metadata.yaml +447 -25
- helm/config/tokenizer_configs.yaml +93 -1
- helm/proxy/critique/model_critique_client.py +32 -4
- helm/proxy/services/server_service.py +1 -1
- helm/tokenizers/auto_tokenizer.py +1 -1
- helm/tokenizers/cohere_tokenizer.py +44 -2
- helm/tokenizers/huggingface_tokenizer.py +36 -13
- helm/tokenizers/test_cohere_tokenizer.py +39 -0
- helm/tokenizers/test_huggingface_tokenizer.py +5 -1
- helm/benchmark/static_build/assets/index-737eef9e.js +0 -10
- helm/benchmark/static_build/assets/index-878a1094.css +0 -1
- {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.2.dist-info}/LICENSE +0 -0
- {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.2.dist-info}/WHEEL +0 -0
- {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.2.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.2.dist-info}/top_level.txt +0 -0
helm/clients/test_client.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
from helm.common.
|
|
2
|
-
from helm.tokenizers.
|
|
1
|
+
from helm.common.cache_backend_config import BlackHoleCacheBackendConfig
|
|
2
|
+
from helm.tokenizers.auto_tokenizer import AutoTokenizer
|
|
3
3
|
from .client import truncate_sequence, truncate_and_tokenize_response_text
|
|
4
4
|
from typing import List
|
|
5
5
|
from helm.common.request import Request, GeneratedOutput, Token
|
|
@@ -52,8 +52,8 @@ def test_truncate_sequence():
|
|
|
52
52
|
|
|
53
53
|
|
|
54
54
|
def test_truncate_and_tokenize_response_text():
|
|
55
|
-
tokenizer = HuggingFaceTokenizer(BlackHoleCacheConfig())
|
|
56
55
|
tokenizer_name = "huggingface/gpt2"
|
|
56
|
+
tokenizer = AutoTokenizer(credentials={}, cache_backend_config=BlackHoleCacheBackendConfig())
|
|
57
57
|
|
|
58
58
|
# No truncation
|
|
59
59
|
response = truncate_and_tokenize_response_text(
|
|
@@ -3,12 +3,18 @@ import pytest
|
|
|
3
3
|
from helm.common.cache import BlackHoleCacheConfig
|
|
4
4
|
from helm.common.request import Request, RequestResult
|
|
5
5
|
from helm.clients.huggingface_client import HuggingFaceClient
|
|
6
|
+
from helm.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
class TestHuggingFaceClient:
|
|
9
10
|
def test_gpt2(self):
|
|
11
|
+
tokenizer = HuggingFaceTokenizer(
|
|
12
|
+
BlackHoleCacheConfig(), "huggingface/gpt2", pretrained_model_name_or_path="openai/gpt2"
|
|
13
|
+
)
|
|
10
14
|
client = HuggingFaceClient(
|
|
11
|
-
cache_config=BlackHoleCacheConfig(),
|
|
15
|
+
cache_config=BlackHoleCacheConfig(),
|
|
16
|
+
tokenizer=tokenizer,
|
|
17
|
+
pretrained_model_name_or_path="openai-community/gpt2",
|
|
12
18
|
)
|
|
13
19
|
prompt: str = "I am a computer scientist."
|
|
14
20
|
result: RequestResult = client.make_request(
|
|
@@ -29,8 +35,13 @@ class TestHuggingFaceClient:
|
|
|
29
35
|
|
|
30
36
|
@pytest.mark.skip(reason="GPT-J 6B is 22 GB and extremely slow without a GPU.")
|
|
31
37
|
def test_gptj_6b(self):
|
|
38
|
+
tokenizer = HuggingFaceTokenizer(
|
|
39
|
+
BlackHoleCacheConfig(), "huggingface/gpt2", pretrained_model_name_or_path="openai/gpt2"
|
|
40
|
+
)
|
|
32
41
|
client = HuggingFaceClient(
|
|
33
|
-
cache_config=BlackHoleCacheConfig(),
|
|
42
|
+
cache_config=BlackHoleCacheConfig(),
|
|
43
|
+
tokenizer=tokenizer,
|
|
44
|
+
pretrained_model_name_or_path="openai-community/gpt2",
|
|
34
45
|
)
|
|
35
46
|
result: RequestResult = client.make_request(
|
|
36
47
|
Request(
|
|
@@ -45,8 +56,13 @@ class TestHuggingFaceClient:
|
|
|
45
56
|
assert len(result.completions) == 3
|
|
46
57
|
|
|
47
58
|
def test_logprob(self):
|
|
59
|
+
tokenizer = HuggingFaceTokenizer(
|
|
60
|
+
BlackHoleCacheConfig(), "huggingface/gpt2", pretrained_model_name_or_path="openai/gpt2"
|
|
61
|
+
)
|
|
48
62
|
client = HuggingFaceClient(
|
|
49
|
-
cache_config=BlackHoleCacheConfig(),
|
|
63
|
+
cache_config=BlackHoleCacheConfig(),
|
|
64
|
+
tokenizer=tokenizer,
|
|
65
|
+
pretrained_model_name_or_path="openai-community/gpt2",
|
|
50
66
|
)
|
|
51
67
|
prompt: str = "I am a computer scientist."
|
|
52
68
|
result: RequestResult = client.make_request(
|
|
@@ -2,10 +2,10 @@ import os
|
|
|
2
2
|
import pytest
|
|
3
3
|
import tempfile
|
|
4
4
|
|
|
5
|
-
from helm.common.cache import SqliteCacheConfig
|
|
5
|
+
from helm.common.cache import BlackHoleCacheConfig, SqliteCacheConfig
|
|
6
6
|
from helm.common.request import Request
|
|
7
7
|
|
|
8
|
-
from .together_client import TogetherClient, TogetherClientError
|
|
8
|
+
from .together_client import TogetherClient, TogetherChatClient, TogetherCompletionClient, TogetherClientError
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
class TestTogetherClient:
|
|
@@ -107,3 +107,73 @@ class TestTogetherClient:
|
|
|
107
107
|
model_deployment="together/redpajama-incite-base-3b-v1",
|
|
108
108
|
)
|
|
109
109
|
)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@pytest.mark.models
|
|
113
|
+
def test_together_chat_client_make_request():
|
|
114
|
+
# Requires setting TOGETHER_API_KEY environment variable.
|
|
115
|
+
client = TogetherChatClient(
|
|
116
|
+
cache_config=BlackHoleCacheConfig(), api_key=None, together_model="meta-llama/Llama-3-8b-chat-hf"
|
|
117
|
+
)
|
|
118
|
+
request = Request(
|
|
119
|
+
model="meta/llama-3-8b-instruct",
|
|
120
|
+
model_deployment="together/llama-3-8b-instruct",
|
|
121
|
+
prompt="Elephants are one of the most",
|
|
122
|
+
temperature=0.0,
|
|
123
|
+
max_tokens=10,
|
|
124
|
+
)
|
|
125
|
+
result = client.make_request(request)
|
|
126
|
+
assert result.success
|
|
127
|
+
assert not result.cached
|
|
128
|
+
assert result.embedding == []
|
|
129
|
+
assert len(result.completions) == 1
|
|
130
|
+
assert result.completions[0].text == "...intelligent animals on Earth!assistant"
|
|
131
|
+
assert result.completions[0].logprob == 0.0
|
|
132
|
+
result_token_strings = [token.text for token in result.completions[0].tokens]
|
|
133
|
+
assert result_token_strings == [
|
|
134
|
+
"...",
|
|
135
|
+
"int",
|
|
136
|
+
"elligent",
|
|
137
|
+
" animals",
|
|
138
|
+
" on",
|
|
139
|
+
" Earth",
|
|
140
|
+
"!",
|
|
141
|
+
"<|eot_id|>",
|
|
142
|
+
"<|start_header_id|>",
|
|
143
|
+
"assistant",
|
|
144
|
+
]
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
@pytest.mark.models
|
|
148
|
+
def test_together_completion_client_make_request():
|
|
149
|
+
# Requires setting TOGETHER_API_KEY environment variable.
|
|
150
|
+
client = TogetherCompletionClient(
|
|
151
|
+
cache_config=BlackHoleCacheConfig(), api_key=None, together_model="meta-llama/Llama-3-8b-hf"
|
|
152
|
+
)
|
|
153
|
+
request = Request(
|
|
154
|
+
model="meta/llama-3-8b",
|
|
155
|
+
model_deployment="together/llama-3-8b",
|
|
156
|
+
prompt="Elephants are one of the most",
|
|
157
|
+
temperature=0.0,
|
|
158
|
+
max_tokens=10,
|
|
159
|
+
)
|
|
160
|
+
result = client.make_request(request)
|
|
161
|
+
assert result.success
|
|
162
|
+
assert not result.cached
|
|
163
|
+
assert result.embedding == []
|
|
164
|
+
assert len(result.completions) == 1
|
|
165
|
+
assert result.completions[0].text == " popular animals in the world. They are known for"
|
|
166
|
+
assert result.completions[0].logprob == 0.0
|
|
167
|
+
result_token_strings = [token.text for token in result.completions[0].tokens]
|
|
168
|
+
assert result_token_strings == [
|
|
169
|
+
" popular",
|
|
170
|
+
" animals",
|
|
171
|
+
" in",
|
|
172
|
+
" the",
|
|
173
|
+
" world",
|
|
174
|
+
".",
|
|
175
|
+
" They",
|
|
176
|
+
" are",
|
|
177
|
+
" known",
|
|
178
|
+
" for",
|
|
179
|
+
]
|
helm/clients/together_client.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from copy import deepcopy
|
|
2
2
|
from itertools import zip_longest
|
|
3
|
-
|
|
3
|
+
import threading
|
|
4
|
+
from typing import List, Dict, Any, Mapping, Optional, TypedDict, Union
|
|
4
5
|
|
|
5
6
|
import requests
|
|
6
7
|
from retrying import retry
|
|
@@ -12,7 +13,7 @@ from helm.clients.client import CachingClient, truncate_sequence, cleanup_str
|
|
|
12
13
|
|
|
13
14
|
try:
|
|
14
15
|
from together import Together
|
|
15
|
-
from together.types import ChatCompletionResponse
|
|
16
|
+
from together.types import ChatCompletionResponse, CompletionResponse
|
|
16
17
|
except ModuleNotFoundError as e:
|
|
17
18
|
handle_module_not_found_error(e, ["together"])
|
|
18
19
|
|
|
@@ -282,6 +283,24 @@ class TogetherClient(CachingClient):
|
|
|
282
283
|
)
|
|
283
284
|
|
|
284
285
|
|
|
286
|
+
_MODEL_TO_DEFAULT_STOP_TOKENS: Optional[Mapping[str, List[str]]] = None
|
|
287
|
+
_MODEL_TO_DEFAULT_STOP_TOKENS_LOCK = threading.Lock()
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def get_default_stop_tokens_for_model(together_model: str, together_client: Together) -> List[str]:
|
|
291
|
+
global _MODEL_TO_DEFAULT_STOP_TOKENS
|
|
292
|
+
global _MODEL_TO_DEFAULT_STOP_TOKENS_LOCK
|
|
293
|
+
with _MODEL_TO_DEFAULT_STOP_TOKENS_LOCK:
|
|
294
|
+
if _MODEL_TO_DEFAULT_STOP_TOKENS is None:
|
|
295
|
+
_MODEL_TO_DEFAULT_STOP_TOKENS = {}
|
|
296
|
+
for model in together_client.models.list():
|
|
297
|
+
_MODEL_TO_DEFAULT_STOP_TOKENS[model.id.lower()] = model.config["stop"]
|
|
298
|
+
stop_tokens = _MODEL_TO_DEFAULT_STOP_TOKENS.get(together_model.lower())
|
|
299
|
+
if stop_tokens is None:
|
|
300
|
+
raise ValueError(f"Unknown together_model {together_model}")
|
|
301
|
+
return stop_tokens
|
|
302
|
+
|
|
303
|
+
|
|
285
304
|
class TogetherRawChatRequest(TypedDict):
|
|
286
305
|
messages: List[Dict[str, str]]
|
|
287
306
|
model: str
|
|
@@ -295,34 +314,38 @@ class TogetherRawChatRequest(TypedDict):
|
|
|
295
314
|
n: int
|
|
296
315
|
|
|
297
316
|
|
|
298
|
-
def convert_to_raw_chat_request(request: Request) -> TogetherRawChatRequest:
|
|
299
|
-
if request.messages:
|
|
300
|
-
messages = request.messages
|
|
301
|
-
else:
|
|
302
|
-
messages = [{"role": "user", "content": request.prompt}]
|
|
303
|
-
return {
|
|
304
|
-
"messages": messages,
|
|
305
|
-
"model": request.model,
|
|
306
|
-
"max_tokens": request.max_tokens,
|
|
307
|
-
"stop": request.stop_sequences,
|
|
308
|
-
"temperature": request.temperature,
|
|
309
|
-
"top_p": request.top_p,
|
|
310
|
-
"top_k": request.top_k_per_token,
|
|
311
|
-
"logprobs": min(request.top_k_per_token, 1),
|
|
312
|
-
"echo": request.echo_prompt,
|
|
313
|
-
"n": request.num_completions,
|
|
314
|
-
}
|
|
315
|
-
|
|
316
|
-
|
|
317
317
|
class TogetherChatClient(CachingClient):
|
|
318
318
|
"""Client that uses the Python Together library for chat models."""
|
|
319
319
|
|
|
320
|
-
def __init__(self, cache_config: CacheConfig, api_key: str, together_model: Optional[str] = None):
|
|
320
|
+
def __init__(self, cache_config: CacheConfig, api_key: Optional[str], together_model: Optional[str] = None):
|
|
321
321
|
super().__init__(cache_config=cache_config)
|
|
322
322
|
self._client = Together(api_key=api_key)
|
|
323
|
+
self._together_model = together_model
|
|
324
|
+
|
|
325
|
+
def convert_to_raw_chat_request(self, request: Request) -> TogetherRawChatRequest:
|
|
326
|
+
if request.messages:
|
|
327
|
+
messages = request.messages
|
|
328
|
+
else:
|
|
329
|
+
messages = [{"role": "user", "content": request.prompt}]
|
|
330
|
+
if self._together_model is not None:
|
|
331
|
+
model = self._together_model
|
|
332
|
+
else:
|
|
333
|
+
model = request.model
|
|
334
|
+
return {
|
|
335
|
+
"messages": messages,
|
|
336
|
+
"model": model,
|
|
337
|
+
"max_tokens": request.max_tokens,
|
|
338
|
+
"stop": request.stop_sequences + get_default_stop_tokens_for_model(model, self._client),
|
|
339
|
+
"temperature": request.temperature,
|
|
340
|
+
"top_p": request.top_p,
|
|
341
|
+
"top_k": request.top_k_per_token,
|
|
342
|
+
"logprobs": min(request.top_k_per_token, 1),
|
|
343
|
+
"echo": request.echo_prompt,
|
|
344
|
+
"n": request.num_completions,
|
|
345
|
+
}
|
|
323
346
|
|
|
324
347
|
def make_request(self, request: Request) -> RequestResult:
|
|
325
|
-
raw_request = convert_to_raw_chat_request(request)
|
|
348
|
+
raw_request = self.convert_to_raw_chat_request(request)
|
|
326
349
|
cache_key = CachingClient.make_cache_key(raw_request, request)
|
|
327
350
|
|
|
328
351
|
def do_it() -> Dict[Any, Any]:
|
|
@@ -363,3 +386,86 @@ class TogetherChatClient(CachingClient):
|
|
|
363
386
|
completions=generated_outputs,
|
|
364
387
|
embedding=[],
|
|
365
388
|
)
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
class TogetherRawCompletionRequest(TypedDict):
|
|
392
|
+
prompt: str
|
|
393
|
+
model: str
|
|
394
|
+
max_tokens: int
|
|
395
|
+
stop: List[str]
|
|
396
|
+
temperature: float
|
|
397
|
+
top_p: float
|
|
398
|
+
top_k: int
|
|
399
|
+
logprobs: int
|
|
400
|
+
echo: bool
|
|
401
|
+
n: int
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
class TogetherCompletionClient(CachingClient):
|
|
405
|
+
"""Client that uses the Python Together library for text completion models."""
|
|
406
|
+
|
|
407
|
+
def __init__(self, cache_config: CacheConfig, api_key: Optional[str], together_model: Optional[str] = None):
|
|
408
|
+
super().__init__(cache_config=cache_config)
|
|
409
|
+
self._client = Together(api_key=api_key)
|
|
410
|
+
self._together_model = together_model
|
|
411
|
+
|
|
412
|
+
def convert_to_raw_completion_request(self, request: Request) -> TogetherRawCompletionRequest:
|
|
413
|
+
if self._together_model is not None:
|
|
414
|
+
model = self._together_model
|
|
415
|
+
else:
|
|
416
|
+
model = request.model
|
|
417
|
+
return {
|
|
418
|
+
"prompt": request.prompt,
|
|
419
|
+
"model": model,
|
|
420
|
+
"max_tokens": request.max_tokens,
|
|
421
|
+
"stop": request.stop_sequences + get_default_stop_tokens_for_model(model, self._client),
|
|
422
|
+
"temperature": request.temperature,
|
|
423
|
+
"top_p": request.top_p,
|
|
424
|
+
"top_k": request.top_k_per_token,
|
|
425
|
+
"logprobs": min(request.top_k_per_token, 1),
|
|
426
|
+
"echo": request.echo_prompt,
|
|
427
|
+
"n": request.num_completions,
|
|
428
|
+
}
|
|
429
|
+
|
|
430
|
+
def make_request(self, request: Request) -> RequestResult:
|
|
431
|
+
raw_request = self.convert_to_raw_completion_request(request)
|
|
432
|
+
cache_key = CachingClient.make_cache_key(raw_request, request)
|
|
433
|
+
|
|
434
|
+
def do_it() -> Dict[Any, Any]:
|
|
435
|
+
response = self._client.completions.create(**raw_request)
|
|
436
|
+
return response.model_dump(mode="json")
|
|
437
|
+
|
|
438
|
+
try:
|
|
439
|
+
raw_response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
|
|
440
|
+
response = CompletionResponse.model_validate(raw_response)
|
|
441
|
+
except Exception as error:
|
|
442
|
+
return RequestResult(
|
|
443
|
+
success=False,
|
|
444
|
+
cached=False,
|
|
445
|
+
error=str(error),
|
|
446
|
+
completions=[],
|
|
447
|
+
embedding=[],
|
|
448
|
+
)
|
|
449
|
+
|
|
450
|
+
generated_outputs: List[GeneratedOutput] = []
|
|
451
|
+
for choice in response.choices:
|
|
452
|
+
# NOTE: Together always returns None for choice.finish_reason
|
|
453
|
+
# NOTE: Together does not return logprobs for the whole generated output, only for individual tokens
|
|
454
|
+
tokens: List[Token] = []
|
|
455
|
+
if choice.logprobs:
|
|
456
|
+
for token_text, token_logprob in zip_longest(
|
|
457
|
+
choice.logprobs.tokens or [], choice.logprobs.token_logprobs or []
|
|
458
|
+
):
|
|
459
|
+
if token_text is None:
|
|
460
|
+
break
|
|
461
|
+
tokens.append(Token(text=token_text, logprob=token_logprob or 0.0))
|
|
462
|
+
assert choice.text
|
|
463
|
+
generated_outputs.append(GeneratedOutput(text=choice.text, logprob=0.0, tokens=tokens))
|
|
464
|
+
return RequestResult(
|
|
465
|
+
success=True,
|
|
466
|
+
cached=cached,
|
|
467
|
+
request_time=raw_response["request_time"],
|
|
468
|
+
request_datetime=raw_response["request_datetime"],
|
|
469
|
+
completions=generated_outputs,
|
|
470
|
+
embedding=[],
|
|
471
|
+
)
|
helm/clients/vertexai_client.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import requests
|
|
2
2
|
from abc import ABC, abstractmethod
|
|
3
3
|
from threading import Lock
|
|
4
|
-
from typing import Any, Dict, Optional, List, Union
|
|
4
|
+
from typing import Any, Dict, Mapping, Optional, List, Union
|
|
5
5
|
|
|
6
6
|
from helm.common.cache import CacheConfig
|
|
7
7
|
from helm.common.media_object import TEXT_TYPE
|
|
@@ -26,22 +26,62 @@ class VertexAIContentBlockedError(Exception):
|
|
|
26
26
|
pass
|
|
27
27
|
|
|
28
28
|
|
|
29
|
+
class SafetySettingPresets:
|
|
30
|
+
BLOCK_NONE = "block_none" # Disable all blocking
|
|
31
|
+
DEFAULT = "default" # Use default safety settings
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _get_safety_settings_for_preset(
|
|
35
|
+
safety_settings_preset: Optional[str],
|
|
36
|
+
) -> Optional[Dict[HarmCategory, SafetySetting.HarmBlockThreshold]]:
|
|
37
|
+
"""Get the safety settings for the safety_settings_preset.
|
|
38
|
+
|
|
39
|
+
If safety_settings_preset is None, use the default value of BLOCK_NONE (*not* DEFAULT)."""
|
|
40
|
+
if safety_settings_preset is None or safety_settings_preset == SafetySettingPresets.BLOCK_NONE:
|
|
41
|
+
return {
|
|
42
|
+
harm_category: SafetySetting.HarmBlockThreshold(SafetySetting.HarmBlockThreshold.BLOCK_NONE)
|
|
43
|
+
for harm_category in iter(HarmCategory)
|
|
44
|
+
}
|
|
45
|
+
elif safety_settings_preset == SafetySettingPresets.DEFAULT:
|
|
46
|
+
return None
|
|
47
|
+
else:
|
|
48
|
+
raise ValueError(f"Unknown safety_settings_preset: {safety_settings_preset}")
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _get_model_name_for_request(request: Request) -> str:
|
|
52
|
+
# We have to strip "-safety-" suffixes from model names because they are not part of the Vertex AI model name
|
|
53
|
+
# TODO: Clean up this hack
|
|
54
|
+
return request.model_engine.split("-safety-")[0]
|
|
55
|
+
|
|
56
|
+
|
|
29
57
|
class VertexAIClient(CachingClient, ABC):
|
|
30
58
|
"""Client for Vertex AI models"""
|
|
31
59
|
|
|
32
|
-
def __init__(
|
|
60
|
+
def __init__(
|
|
61
|
+
self, cache_config: CacheConfig, project_id: str, location: str, safety_settings_preset: Optional[str] = None
|
|
62
|
+
) -> None:
|
|
33
63
|
super().__init__(cache_config=cache_config)
|
|
34
64
|
self.project_id = project_id
|
|
35
65
|
self.location = location
|
|
36
66
|
|
|
37
|
-
|
|
38
|
-
self.safety_settings
|
|
39
|
-
harm_category: SafetySetting.HarmBlockThreshold(SafetySetting.HarmBlockThreshold.BLOCK_NONE)
|
|
40
|
-
for harm_category in iter(HarmCategory)
|
|
41
|
-
}
|
|
67
|
+
self.safety_settings_preset = safety_settings_preset
|
|
68
|
+
self.safety_settings = _get_safety_settings_for_preset(safety_settings_preset)
|
|
42
69
|
|
|
43
70
|
vertexai.init(project=self.project_id, location=self.location)
|
|
44
71
|
|
|
72
|
+
def make_cache_key_with_safety_settings_preset(self, raw_request: Mapping, request: Request) -> Mapping:
|
|
73
|
+
"""Construct the key for the cache using the raw request.
|
|
74
|
+
|
|
75
|
+
Add `self.safety_settings_preset` to the key, if not None."""
|
|
76
|
+
if self.safety_settings_preset is not None:
|
|
77
|
+
assert "safety_settings_preset" not in raw_request
|
|
78
|
+
return {
|
|
79
|
+
**CachingClient.make_cache_key(raw_request, request),
|
|
80
|
+
"safety_settings_preset": self.safety_settings_preset,
|
|
81
|
+
}
|
|
82
|
+
else:
|
|
83
|
+
return CachingClient.make_cache_key(raw_request, request)
|
|
84
|
+
|
|
45
85
|
@abstractmethod
|
|
46
86
|
def make_request(self, request: Request) -> RequestResult:
|
|
47
87
|
raise NotImplementedError
|
|
@@ -71,7 +111,7 @@ class VertexAITextClient(VertexAIClient):
|
|
|
71
111
|
}
|
|
72
112
|
|
|
73
113
|
completions: List[GeneratedOutput] = []
|
|
74
|
-
model_name: str = request
|
|
114
|
+
model_name: str = _get_model_name_for_request(request)
|
|
75
115
|
|
|
76
116
|
try:
|
|
77
117
|
|
|
@@ -87,9 +127,9 @@ class VertexAITextClient(VertexAIClient):
|
|
|
87
127
|
# We need to include the engine's name to differentiate among requests made for different model
|
|
88
128
|
# engines since the engine name is not included in the request itself.
|
|
89
129
|
# Same for the prompt.
|
|
90
|
-
cache_key =
|
|
130
|
+
cache_key = self.make_cache_key_with_safety_settings_preset(
|
|
91
131
|
{
|
|
92
|
-
"engine":
|
|
132
|
+
"engine": model_name,
|
|
93
133
|
"prompt": request.prompt,
|
|
94
134
|
**parameters,
|
|
95
135
|
},
|
|
@@ -177,7 +217,7 @@ class VertexAIChatClient(VertexAIClient):
|
|
|
177
217
|
}
|
|
178
218
|
|
|
179
219
|
completions: List[GeneratedOutput] = []
|
|
180
|
-
model_name: str = request
|
|
220
|
+
model_name: str = _get_model_name_for_request(request)
|
|
181
221
|
model = self.get_model(model_name)
|
|
182
222
|
|
|
183
223
|
try:
|
|
@@ -197,7 +237,7 @@ class VertexAIChatClient(VertexAIClient):
|
|
|
197
237
|
|
|
198
238
|
# Depending on the version of the Vertex AI library and the type of prompt blocking,
|
|
199
239
|
# prompt blocking can show up in many ways, so this defensively handles most of these ways
|
|
200
|
-
if response.prompt_feedback.block_reason:
|
|
240
|
+
if response.prompt_feedback and response.prompt_feedback.block_reason:
|
|
201
241
|
raise VertexAIContentBlockedError(
|
|
202
242
|
f"Prompt blocked with reason: {response.prompt_feedback.block_reason}"
|
|
203
243
|
)
|
|
@@ -209,8 +249,10 @@ class VertexAIChatClient(VertexAIClient):
|
|
|
209
249
|
# content blocking can show up in many ways, so this defensively handles most of these ways
|
|
210
250
|
if candidate.finish_reason in VertexAIChatClient.CONTENT_BLOCKED_FINISH_REASONS:
|
|
211
251
|
raise VertexAIContentBlockedError(f"Content blocked with reason: {candidate.finish_reason}")
|
|
252
|
+
if not candidate.content:
|
|
253
|
+
raise VertexAIContentBlockedError(f"No content in candidate: {candidate}")
|
|
212
254
|
if not candidate.content.parts:
|
|
213
|
-
raise VertexAIContentBlockedError(f"No parts in candidate: {candidate}")
|
|
255
|
+
raise VertexAIContentBlockedError(f"No content parts in candidate: {candidate}")
|
|
214
256
|
predictions.append({"text": candidate.content.text})
|
|
215
257
|
# TODO: Extract more information from the response
|
|
216
258
|
return {"predictions": predictions}
|
|
@@ -218,7 +260,7 @@ class VertexAIChatClient(VertexAIClient):
|
|
|
218
260
|
# We need to include the engine's name to differentiate among requests made for different model
|
|
219
261
|
# engines since the engine name is not included in the request itself.
|
|
220
262
|
# Same for the prompt.
|
|
221
|
-
cache_key =
|
|
263
|
+
cache_key = self.make_cache_key_with_safety_settings_preset(
|
|
222
264
|
{
|
|
223
265
|
"model_name": model_name,
|
|
224
266
|
"prompt": request.prompt,
|
|
@@ -313,7 +355,7 @@ class VertexAIChatClient(VertexAIClient):
|
|
|
313
355
|
}
|
|
314
356
|
|
|
315
357
|
completions: List[GeneratedOutput] = []
|
|
316
|
-
model_name: str = request
|
|
358
|
+
model_name: str = _get_model_name_for_request(request)
|
|
317
359
|
model = self.get_model(model_name)
|
|
318
360
|
|
|
319
361
|
request_time = 0
|
|
@@ -330,7 +372,7 @@ class VertexAIChatClient(VertexAIClient):
|
|
|
330
372
|
)
|
|
331
373
|
# Depending on the version of the Vertex AI library and the type of prompt blocking,
|
|
332
374
|
# prompt blocking can show up in many ways, so this defensively handles most of these ways
|
|
333
|
-
if response.prompt_feedback.block_reason:
|
|
375
|
+
if response.prompt_feedback and response.prompt_feedback.block_reason:
|
|
334
376
|
raise VertexAIContentBlockedError(
|
|
335
377
|
f"Prompt blocked with reason: {response.prompt_feedback.block_reason}"
|
|
336
378
|
)
|
|
@@ -345,15 +387,17 @@ class VertexAIChatClient(VertexAIClient):
|
|
|
345
387
|
# content blocking can show up in many ways, so this defensively handles most of these ways
|
|
346
388
|
if candidate.finish_reason in VertexAIChatClient.CONTENT_BLOCKED_FINISH_REASONS:
|
|
347
389
|
raise VertexAIContentBlockedError(f"Content blocked with reason: {candidate.finish_reason}")
|
|
390
|
+
if not candidate.content:
|
|
391
|
+
raise VertexAIContentBlockedError(f"No content in candidate: {candidate}")
|
|
348
392
|
if not candidate.content.parts:
|
|
349
|
-
raise VertexAIContentBlockedError(f"No parts in candidate: {candidate}")
|
|
393
|
+
raise VertexAIContentBlockedError(f"No content parts in candidate: {candidate}")
|
|
350
394
|
return {"predictions": [{"text": candidate.text}]}
|
|
351
395
|
|
|
352
396
|
raw_cache_key = {"model_name": model_name, "prompt": prompt_key, **parameters}
|
|
353
397
|
if completion_index > 0:
|
|
354
398
|
raw_cache_key["completion_index"] = completion_index
|
|
355
399
|
|
|
356
|
-
cache_key =
|
|
400
|
+
cache_key = self.make_cache_key_with_safety_settings_preset(raw_cache_key, request)
|
|
357
401
|
response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
|
|
358
402
|
except requests.exceptions.RequestException as e:
|
|
359
403
|
error: str = f"Gemini Vision error: {e}"
|
|
@@ -38,6 +38,7 @@ class HuggingFaceVLMClient(CachingClient):
|
|
|
38
38
|
"huggingface/llava-v1.6-vicuna-13b-hf": "llava-hf/llava-v1.6-vicuna-13b-hf",
|
|
39
39
|
"huggingface/llava-v1.6-mistral-7b-hf": "llava-hf/llava-v1.6-mistral-7b-hf",
|
|
40
40
|
"huggingface/llava-v1.6-34b-hf": "llava-hf/llava-v1.6-34b-hf",
|
|
41
|
+
"huggingface/prometheus-vision-13b-v1.0-hf": "PahaII/prometheus-vision-13b-v1.0-hf",
|
|
41
42
|
}
|
|
42
43
|
|
|
43
44
|
def __init__(self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig):
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
from threading import Lock
|
|
2
|
+
from typing import Any, Dict, List, Optional
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
|
|
7
|
+
|
|
8
|
+
from helm.common.cache import CacheConfig
|
|
9
|
+
from helm.common.images_utils import open_image
|
|
10
|
+
from helm.common.gpu_utils import get_torch_device_name
|
|
11
|
+
from helm.common.hierarchical_logger import hlog, htrack_block
|
|
12
|
+
from helm.common.media_object import TEXT_TYPE
|
|
13
|
+
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
14
|
+
from helm.common.request import Request, RequestResult, GeneratedOutput, Token
|
|
15
|
+
from helm.common.tokenization_request import TokenizationRequest
|
|
16
|
+
from helm.common.request import wrap_request_time
|
|
17
|
+
from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt
|
|
18
|
+
from helm.tokenizers.tokenizer import Tokenizer
|
|
19
|
+
|
|
20
|
+
try:
|
|
21
|
+
from PIL import Image
|
|
22
|
+
except ModuleNotFoundError as e:
|
|
23
|
+
handle_module_not_found_error(e, ["images"])
|
|
24
|
+
|
|
25
|
+
# Added to solve: cutlassF: no kernel found to launch!
|
|
26
|
+
torch.backends.cuda.enable_mem_efficient_sdp(False)
|
|
27
|
+
torch.backends.cuda.enable_flash_sdp(False)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass(frozen=True)
class LoadedPaliGemmaForConditionalGeneration:
    """Immutable bundle of a PaliGemma model and the processor loaded from the same checkpoint."""

    # The PaliGemma generation model.
    model: PaliGemmaForConditionalGeneration
    # Processor that encodes text + images into model inputs and decodes generated ids.
    processor: AutoProcessor
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Process-wide store of loaded checkpoints, keyed by the checkpoint name passed to
# `PaliGemmaClient._get_model`; shared by every client instance.
_models: Dict[str, Optional[LoadedPaliGemmaForConditionalGeneration]] = dict()
# Guards `_models` so that concurrent callers load a given checkpoint only once.
_models_lock: Lock = Lock()
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class PaliGemmaClient(CachingClient):
    """
    PaliGemma is a versatile and lightweight vision-language model (VLM) inspired by PaLI-3
    and based on open components such as the SigLIP vision model and the Gemma language model.
    It takes both image and text as input and generates text as output, supporting multiple languages.
    It is designed for class-leading fine-tune performance on a wide range of vision-language tasks
    such as image and short video caption, visual question answering, text reading, object detection
    and object segmentation.

    This client runs the model locally through Hugging Face `transformers`. Loaded checkpoints are
    stored in the module-level `_models` dict and shared by all instances of this client.
    """

    def __init__(self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig):
        super().__init__(cache_config=cache_config)
        # Used to split each generated text into tokens for the returned `GeneratedOutput`s.
        self.tokenizer = tokenizer
        self.tokenizer_name = tokenizer_name
        # Device the encoded prompt tensors are moved to before generation.
        self._device: str = get_torch_device_name()

    def _get_model(self, checkpoint: str) -> LoadedPaliGemmaForConditionalGeneration:
        """Return the model and processor for `checkpoint`, loading them on first use.

        The lookup-and-load runs under `_models_lock`, so concurrent callers load each checkpoint
        at most once; subsequent calls return the instance stored in `_models`.
        """
        # `_models` is only mutated (never rebound), so no `global` declaration is needed.
        with _models_lock:
            loaded_model_processor = _models.get(checkpoint)
            if loaded_model_processor is None:
                hlog(f"Loading model {checkpoint} and caching in memory...")
                model = PaliGemmaForConditionalGeneration.from_pretrained(
                    checkpoint, torch_dtype=torch.bfloat16, device_map="auto"
                ).eval()
                processor = AutoProcessor.from_pretrained(checkpoint)
                loaded_model_processor = LoadedPaliGemmaForConditionalGeneration(model, processor)
                _models[checkpoint] = loaded_model_processor
        return loaded_model_processor

    def make_request(self, request: Request) -> RequestResult:
        """Generate `request.num_completions` completions for an image + text prompt.

        Image media objects are converted to RGB; text media objects are joined with newlines to
        form the prompt text. Each completion is decoded greedily (`do_sample=False`) and cached
        under its own key (completion index `i`). A `RuntimeError` raised while generating is
        reported as an unsuccessful `RequestResult` rather than propagated.

        Raises:
            ValueError: if a text media object has no text, or a media object has an
                unsupported type.
        """
        assert request.multimodal_prompt is not None, "Multimodal prompt is required"

        loaded_model_processor: LoadedPaliGemmaForConditionalGeneration = self._get_model(request.model_deployment)
        model = loaded_model_processor.model
        processor = loaded_model_processor.processor
        # Shared by the cache key and the `generate` call so the two cannot drift apart.
        generation_args = {"max_new_tokens": request.max_tokens}

        images: List[Image.Image] = []
        prompt_pieces: List[str] = []
        for media_object in request.multimodal_prompt.media_objects:
            if media_object.is_type("image") and media_object.location:
                images.append(open_image(media_object.location).convert("RGB"))
            elif media_object.is_type(TEXT_TYPE):
                if media_object.text is None:
                    raise ValueError("MediaObject of text type has missing text field value")
                prompt_pieces.append(media_object.text)
            else:
                raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
        prompt_text: str = "\n".join(prompt_pieces)
        model_inputs = processor(text=prompt_text, images=images, return_tensors="pt").to(self._device)
        # Number of prompt tokens; used to strip the echoed prompt from the generated sequence.
        input_len = model_inputs["input_ids"].shape[-1]

        def do_it() -> Dict[str, Any]:
            # Independent of the completion index: decoding is greedy, so it is defined once.
            with torch.inference_mode():
                generation = model.generate(**model_inputs, **generation_args, do_sample=False)[0]
                if not request.echo_prompt:
                    generation = generation[input_len:]
                decoded = processor.decode(generation, skip_special_tokens=True)
                return {"output": decoded}

        # Loop-invariant part of the cache key.
        prompt_uid = generate_uid_for_multimodal_prompt(request.multimodal_prompt)

        completions: List[GeneratedOutput] = []
        # Defaults keep the final result well-defined when no completions are requested
        # (previously `cached` and `result` were unbound in that case).
        cached: bool = False
        request_time: Optional[float] = None
        with htrack_block(f"Generating for prompt: {prompt_text}"):
            try:
                concat_results: List[Dict[str, Any]] = []
                for i_completion in range(request.num_completions):
                    # Include the prompt and model name in the cache key
                    cache_key = CachingClient.make_cache_key(
                        raw_request={
                            "n": request.num_completions,
                            "i": i_completion,
                            "model": request.model,
                            "prompt": prompt_uid,
                            **generation_args,
                        },
                        request=request,
                    )
                    result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
                    concat_results.append(result)
            except RuntimeError as model_error:
                return RequestResult(success=False, cached=False, error=str(model_error), completions=[], embedding=[])

            for result in concat_results:
                text = result["output"]
                hlog(f"Generated text: {text}")
                tokenization_result = self.tokenizer.tokenize(
                    TokenizationRequest(text, tokenizer=self.tokenizer_name, encode=False)
                )
                tokens: List[Token] = [
                    Token(text=str(raw_token), logprob=0) for raw_token in tokenization_result.raw_tokens
                ]
                completions.append(GeneratedOutput(text=text, logprob=0, tokens=tokens))

            # As before, `cached` and `request_time` reflect the last completion's cache lookup.
            if concat_results:
                request_time = concat_results[-1]["request_time"]

        return RequestResult(
            success=True,
            cached=cached,
            request_time=request_time,
            completions=completions,
            embedding=[],
        )
|