crfm-helm 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release.
This version of crfm-helm might be problematic.
- {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/METADATA +19 -5
- {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/RECORD +121 -76
- helm/benchmark/adaptation/adapter_spec.py +32 -31
- helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/multimodal_prompt.py +7 -0
- helm/benchmark/adaptation/adapters/multimodal/test_multimodal_prompt.py +2 -0
- helm/benchmark/annotation/air_bench_annotator.py +64 -0
- helm/benchmark/annotation/annotator_factory.py +6 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +1 -1
- helm/benchmark/annotation/live_qa_annotator.py +84 -0
- helm/benchmark/annotation/medication_qa_annotator.py +81 -0
- helm/benchmark/augmentations/perturbation.py +17 -1
- helm/benchmark/augmentations/test_perturbation.py +30 -0
- helm/benchmark/augmentations/translate_perturbation.py +1 -0
- helm/benchmark/huggingface_registration.py +16 -6
- helm/benchmark/metrics/air_bench_metrics.py +56 -0
- helm/benchmark/metrics/efficiency_metrics.py +9 -2
- helm/benchmark/metrics/evaluate_reference_metrics.py +16 -0
- helm/benchmark/metrics/fin_qa_metrics.py +60 -0
- helm/benchmark/metrics/fin_qa_metrics_helper.py +398 -0
- helm/benchmark/metrics/gpt4v_originality_critique_metrics.py +126 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +1 -0
- helm/benchmark/metrics/live_qa_metrics.py +23 -0
- helm/benchmark/metrics/medication_qa_metrics.py +23 -0
- helm/benchmark/metrics/prometheus_vision_critique_metrics.py +185 -0
- helm/benchmark/metrics/reka_vibe_critique_metrics.py +158 -0
- helm/benchmark/metrics/unitxt_metrics.py +20 -10
- helm/benchmark/metrics/vision_language/emd_utils.py +4 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +104 -21
- helm/benchmark/model_metadata_registry.py +5 -1
- helm/benchmark/presentation/schema.py +54 -4
- helm/benchmark/presentation/test_schema.py +11 -0
- helm/benchmark/run.py +16 -2
- helm/benchmark/run_expander.py +112 -63
- helm/benchmark/run_spec_factory.py +15 -10
- helm/benchmark/run_specs/air_bench_run_specs.py +40 -0
- helm/benchmark/run_specs/classic_run_specs.py +15 -11
- helm/benchmark/run_specs/decodingtrust_run_specs.py +3 -1
- helm/benchmark/run_specs/experimental_run_specs.py +33 -0
- helm/benchmark/run_specs/finance_run_specs.py +33 -0
- helm/benchmark/run_specs/vlm_run_specs.py +444 -65
- helm/benchmark/scenarios/air_bench_scenario.py +50 -0
- helm/benchmark/scenarios/ci_mcqa_scenario.py +80 -0
- helm/benchmark/scenarios/entity_data_imputation_scenario.py +8 -2
- helm/benchmark/scenarios/fin_qa_scenario.py +117 -0
- helm/benchmark/scenarios/legalbench_scenario.py +6 -2
- helm/benchmark/scenarios/math_scenario.py +1 -1
- helm/benchmark/scenarios/test_air_bench_scenario.py +27 -0
- helm/benchmark/scenarios/vision_language/a_okvqa_scenario.py +83 -0
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +3 -3
- helm/benchmark/scenarios/vision_language/crossmodal_3600_scenario.py +134 -0
- helm/benchmark/scenarios/vision_language/flickr30k_scenario.py +74 -0
- helm/benchmark/scenarios/vision_language/gqa_scenario.py +91 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +4 -2
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +13 -2
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +1 -5
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +1 -5
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +5 -3
- helm/benchmark/scenarios/vision_language/math_vista_scenario.py +117 -0
- helm/benchmark/scenarios/vision_language/mm_safety_bench_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/mscoco_captioning_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/mscoco_categorization_scenario.py +117 -0
- helm/benchmark/scenarios/vision_language/originality_scenario.py +35 -0
- helm/benchmark/scenarios/vision_language/pairs_scenario.py +247 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +3 -3
- helm/benchmark/scenarios/vision_language/vibe_eval_scenario.py +95 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +2 -2
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +4 -2
- helm/benchmark/static/schema_air_bench.yaml +3149 -0
- helm/benchmark/static/schema_classic.yaml +3 -59
- helm/benchmark/static/schema_finance.yaml +143 -0
- helm/benchmark/static/schema_image2structure.yaml +447 -0
- helm/benchmark/static/schema_instruction_following.yaml +3 -52
- helm/benchmark/static/schema_lite.yaml +3 -61
- helm/benchmark/static/schema_medical.yaml +255 -0
- helm/benchmark/static/schema_mmlu.yaml +3 -61
- helm/benchmark/static/schema_tables.yaml +200 -0
- helm/benchmark/static/schema_thai.yaml +223 -0
- helm/benchmark/static/schema_unitxt.yaml +3 -61
- helm/benchmark/static/schema_vhelm.yaml +824 -0
- helm/benchmark/static/schema_vhelm_lite.yaml +109 -0
- helm/benchmark/static_build/assets/air-overview-d2e6c49f.png +0 -0
- helm/benchmark/static_build/assets/index-30dbceba.js +10 -0
- helm/benchmark/static_build/assets/index-66b02d40.css +1 -0
- helm/benchmark/static_build/assets/overview-74aea3d8.png +0 -0
- helm/benchmark/static_build/assets/process-flow-bd2eba96.png +0 -0
- helm/benchmark/static_build/index.html +2 -2
- helm/clients/anthropic_client.py +78 -14
- helm/clients/auto_client.py +11 -0
- helm/clients/client.py +24 -7
- helm/clients/cohere_client.py +98 -3
- helm/clients/huggingface_client.py +71 -12
- helm/clients/openai_client.py +11 -5
- helm/clients/reka_client.py +189 -0
- helm/clients/test_client.py +3 -3
- helm/clients/test_huggingface_client.py +19 -3
- helm/clients/test_together_client.py +72 -2
- helm/clients/together_client.py +199 -2
- helm/clients/vertexai_client.py +117 -64
- helm/clients/vision_language/huggingface_vision2seq_client.py +145 -0
- helm/clients/vision_language/huggingface_vlm_client.py +12 -4
- helm/clients/vision_language/idefics_client.py +2 -2
- helm/clients/vision_language/paligemma_client.py +146 -0
- helm/clients/vision_language/palmyra_vision_client.py +84 -0
- helm/clients/yi_client.py +31 -0
- helm/common/critique_request.py +10 -1
- helm/common/images_utils.py +29 -3
- helm/config/model_deployments.yaml +504 -12
- helm/config/model_metadata.yaml +579 -52
- helm/config/tokenizer_configs.yaml +100 -1
- helm/proxy/critique/model_critique_client.py +32 -4
- helm/proxy/services/server_service.py +1 -1
- helm/tokenizers/auto_tokenizer.py +1 -1
- helm/tokenizers/cohere_tokenizer.py +44 -2
- helm/tokenizers/huggingface_tokenizer.py +36 -13
- helm/tokenizers/test_cohere_tokenizer.py +39 -0
- helm/tokenizers/test_huggingface_tokenizer.py +5 -1
- helm/benchmark/static/schema_vlm.yaml +0 -576
- helm/benchmark/static_build/assets/index-5088afcb.css +0 -1
- helm/benchmark/static_build/assets/index-d839df55.js +0 -9
- helm/benchmark/test_model_deployment_definition.py +0 -90
- {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/LICENSE +0 -0
- {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/WHEEL +0 -0
- {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/top_level.txt +0 -0
helm/clients/reka_client.py
ADDED
@@ -0,0 +1,189 @@
+# mypy: check_untyped_defs = False
+import requests
+from typing import Any, Dict, List, Optional, TypedDict
+
+from helm.proxy.retry import NonRetriableException
+from helm.common.cache import CacheConfig
+from helm.common.media_object import TEXT_TYPE
+from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput
+from helm.common.hierarchical_logger import hlog
+from helm.common.optional_dependencies import handle_module_not_found_error
+from helm.tokenizers.tokenizer import Tokenizer
+from .client import CachingClient, truncate_and_tokenize_response_text
+
+try:
+    import reka
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["reka-api"])
+
+
+class RekaAIRequest(TypedDict):
+    """Data passed between make_request and _send_request. Used as the cache key."""
+
+    model_name: str
+    conversation_history: List[Dict[str, str]]
+    request_output_len: int
+    temperature: float
+    runtime_top_p: float
+    random_seed: Optional[int]
+    stop_words: Optional[List[str]]
+    presence_penalty: float
+    frequency_penalty: float
+
+
+class RekaClient(CachingClient):
+    REKA_CHAT_ROLE_MAPPING: Dict[str, str] = {
+        "user": "human",
+        "assistant": "model",
+    }
+
+    def __init__(
+        self,
+        tokenizer: Tokenizer,
+        tokenizer_name: str,
+        cache_config: CacheConfig,
+        api_key: Optional[str] = None,
+    ):
+        super().__init__(cache_config=cache_config)
+        self.tokenizer = tokenizer
+        self.tokenizer_name = tokenizer_name
+        self.client = reka
+        self.client.API_KEY = api_key
+
+    def _is_reka_model_engine(self, model_engine: str) -> bool:
+        if (
+            model_engine.startswith("reka-edge")
+            or model_engine.startswith("reka-flash")
+            or model_engine.startswith("reka-core")
+        ):
+            return True
+        else:
+            return False
+
+    def _get_model_for_request(self, request: Request) -> str:
+        return request.model_engine
+
+    def _get_random_seed(self, request: Request, completion_index: int) -> Optional[int]:
+        if request.random is None and completion_index == 0:
+            return None
+
+        # Treat the user's request.random as an integer for the random seed.
+        try:
+            request_random_seed = int(request.random) if request.random is not None else 0
+        except ValueError:
+            raise NonRetriableException("RekaAIClient only supports integer values for request.random")
+
+        # A large prime is used so that the resulting values are unlikely to collide
+        # with request.random values chosen by the user.
+        fixed_large_prime = 1911011
+        completion_index_random_seed = completion_index * fixed_large_prime
+
+        return request_random_seed + completion_index_random_seed
+
+    def _convert_messages_to_reka_chat_history(self, messages: List[Dict[str, Any]]):
+        chat_history = []
+        num_images: int = 0
+        for chat_turn, message in enumerate(messages):
+            role = message["role"]
+            content = message["content"]
+            current_chat_history: Dict[str, Any] = {
+                "type": self.REKA_CHAT_ROLE_MAPPING[role],
+                "text": "",  # text placeholder
+                "media_url": None,
+            }
+            for item in content:
+                if item["type"] == "image_url":
+                    if chat_turn == 0 and num_images == 0:
+                        current_chat_history["media_url"] = item["image_url"]["url"]
+                        num_images += 1
+                    else:
+                        raise ValueError(
+                            f"Only the first message can contain one image. Found image input "
+                            f"in message {chat_turn + 1}"
+                        )
+                elif item["type"] == "text":
+                    current_chat_history["text"] = item["text"]
+                else:
+                    raise ValueError(f"Unrecognized message type {item['type']}")
+            chat_history.append(current_chat_history)
+        return chat_history
+
+    def make_request(self, request: Request) -> RequestResult:
+        completions: List[GeneratedOutput] = []
+        messages: Optional[List[Dict[str, Any]]] = request.messages
+        reka_chat_history: List[Dict[str, Any]]
+        if messages is not None:
+            # Checks that all messages have a role and some content
+            for message in messages:
+                if not message.get("role") or not message.get("content"):
+                    raise ValueError("All messages must have a role and content")
+            # Checks that the last role is "user"
+            if messages[-1]["role"] != "user":
+                raise ValueError("Last message must have role 'user'")
+            if request.prompt != "":
+                hlog("WARNING: Since message is set, prompt will be ignored")
+            reka_chat_history = self._convert_messages_to_reka_chat_history(messages)
+        else:
+            current_chat_history: Dict[str, Any] = {
+                "type": "human",
+                "text": "",
+                "media_url": None,
+            }
+            if request.multimodal_prompt is not None:
+                for media_object in request.multimodal_prompt.media_objects:
+                    if media_object.is_type("image") and media_object.location:
+                        from helm.common.images_utils import encode_base64
+
+                        base64_image: str = encode_base64(media_object.location)
+                        current_chat_history["media_url"] = f"data:image/jpeg;base64,{base64_image}"
+                    elif media_object.is_type(TEXT_TYPE):
+                        if media_object.text is None:
+                            raise ValueError("MediaObject of text type has missing text field value")
+                        current_chat_history["text"] = media_object.text
+                    else:
+                        raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
+
+            else:
+                current_chat_history["text"] = request.prompt
+            reka_chat_history = [current_chat_history]
+
+        # `num_completions` is not supported, so instead make `num_completions` separate requests.
+        for completion_index in range(request.num_completions):
+            try:
+                raw_request: RekaAIRequest = {
+                    "model_name": self._get_model_for_request(request),
+                    "conversation_history": reka_chat_history,  # we only use chat_history as the input
+                    "request_output_len": request.max_tokens,
+                    "temperature": request.temperature,
+                    "random_seed": self._get_random_seed(request, completion_index),
+                    "stop_words": request.stop_sequences or None,  # API doesn't like empty list
+                    "runtime_top_p": request.top_p,
+                    "presence_penalty": request.presence_penalty,
+                    "frequency_penalty": request.frequency_penalty,
+                }
+
+                def do_it() -> Dict[str, Any]:
+                    return self.client.chat(**raw_request)
+
+                response, cached = self.cache.get(raw_request, wrap_request_time(do_it))
+            except (requests.exceptions.RequestException, AssertionError) as e:
+                error: str = f"RekaClient error: {e}"
+                return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+
+            response_message: Dict[str, Any] = response
+            assert response_message["type"] == "model"
+            response_text: str = response_message["text"]
+
+            # The Reka API doesn't support echo. If `echo_prompt` is true, combine the prompt and completion.
+            text: str = request.prompt + response_text if request.echo_prompt else response_text
+            completion = truncate_and_tokenize_response_text(text, request, self.tokenizer, self.tokenizer_name)
+            completions.append(completion)
+
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=response["request_time"],
+            request_datetime=response.get("request_datetime"),
+            completions=completions,
+            embedding=[],
+        )
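
For orientation, the sketch below shows one way the new RekaClient could be exercised directly; it is a minimal, hypothetical example and not part of the diff. In normal HELM usage the client is typically constructed by the auto client from model_deployments.yaml, so the direct construction, the placeholder API key, and the "reka/reka-core" model and deployment names used here are illustrative assumptions.

# Hypothetical usage sketch for the new RekaClient (not part of the diff).
from helm.common.cache import BlackHoleCacheConfig
from helm.common.cache_backend_config import BlackHoleCacheBackendConfig
from helm.common.request import Request
from helm.clients.reka_client import RekaClient
from helm.tokenizers.auto_tokenizer import AutoTokenizer

# Build a tokenizer the same way the updated test_client.py does.
tokenizer = AutoTokenizer(credentials={}, cache_backend_config=BlackHoleCacheBackendConfig())
client = RekaClient(
    tokenizer=tokenizer,
    tokenizer_name="huggingface/gpt2",  # tokenizer name chosen for illustration only
    cache_config=BlackHoleCacheConfig(),
    api_key="YOUR_REKA_API_KEY",  # placeholder; a real Reka API key is required
)
result = client.make_request(
    Request(
        model="reka/reka-core",  # assumed model name; see model_metadata.yaml for the actual entries
        model_deployment="reka/reka-core",  # assumed deployment name; see model_deployments.yaml
        prompt="Elephants are one of the most",
        temperature=0.0,
        max_tokens=10,
    )
)
print(result.completions[0].text if result.success else result.error)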
helm/clients/test_client.py
CHANGED
@@ -1,5 +1,5 @@
-from helm.common.
-from helm.tokenizers.
+from helm.common.cache_backend_config import BlackHoleCacheBackendConfig
+from helm.tokenizers.auto_tokenizer import AutoTokenizer
 from .client import truncate_sequence, truncate_and_tokenize_response_text
 from typing import List
 from helm.common.request import Request, GeneratedOutput, Token
@@ -52,8 +52,8 @@ def test_truncate_sequence():
 
 
 def test_truncate_and_tokenize_response_text():
-    tokenizer = HuggingFaceTokenizer(BlackHoleCacheConfig())
     tokenizer_name = "huggingface/gpt2"
+    tokenizer = AutoTokenizer(credentials={}, cache_backend_config=BlackHoleCacheBackendConfig())
 
     # No truncation
     response = truncate_and_tokenize_response_text(
helm/clients/test_huggingface_client.py
CHANGED
@@ -3,12 +3,18 @@ import pytest
 from helm.common.cache import BlackHoleCacheConfig
 from helm.common.request import Request, RequestResult
 from helm.clients.huggingface_client import HuggingFaceClient
+from helm.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer
 
 
 class TestHuggingFaceClient:
     def test_gpt2(self):
+        tokenizer = HuggingFaceTokenizer(
+            BlackHoleCacheConfig(), "huggingface/gpt2", pretrained_model_name_or_path="openai/gpt2"
+        )
         client = HuggingFaceClient(
-            cache_config=BlackHoleCacheConfig(),
+            cache_config=BlackHoleCacheConfig(),
+            tokenizer=tokenizer,
+            pretrained_model_name_or_path="openai-community/gpt2",
         )
         prompt: str = "I am a computer scientist."
         result: RequestResult = client.make_request(
@@ -29,8 +35,13 @@ class TestHuggingFaceClient:
 
     @pytest.mark.skip(reason="GPT-J 6B is 22 GB and extremely slow without a GPU.")
     def test_gptj_6b(self):
+        tokenizer = HuggingFaceTokenizer(
+            BlackHoleCacheConfig(), "huggingface/gpt2", pretrained_model_name_or_path="openai/gpt2"
+        )
         client = HuggingFaceClient(
-            cache_config=BlackHoleCacheConfig(),
+            cache_config=BlackHoleCacheConfig(),
+            tokenizer=tokenizer,
+            pretrained_model_name_or_path="openai-community/gpt2",
         )
         result: RequestResult = client.make_request(
             Request(
@@ -45,8 +56,13 @@ class TestHuggingFaceClient:
         assert len(result.completions) == 3
 
     def test_logprob(self):
+        tokenizer = HuggingFaceTokenizer(
+            BlackHoleCacheConfig(), "huggingface/gpt2", pretrained_model_name_or_path="openai/gpt2"
+        )
         client = HuggingFaceClient(
-            cache_config=BlackHoleCacheConfig(),
+            cache_config=BlackHoleCacheConfig(),
+            tokenizer=tokenizer,
+            pretrained_model_name_or_path="openai-community/gpt2",
        )
         prompt: str = "I am a computer scientist."
         result: RequestResult = client.make_request(
helm/clients/test_together_client.py
CHANGED
@@ -2,10 +2,10 @@ import os
 import pytest
 import tempfile
 
-from helm.common.cache import SqliteCacheConfig
+from helm.common.cache import BlackHoleCacheConfig, SqliteCacheConfig
 from helm.common.request import Request
 
-from .together_client import TogetherClient, TogetherClientError
+from .together_client import TogetherClient, TogetherChatClient, TogetherCompletionClient, TogetherClientError
 
 
 class TestTogetherClient:
@@ -107,3 +107,73 @@ class TestTogetherClient:
                 model_deployment="together/redpajama-incite-base-3b-v1",
             )
         )
+
+
+@pytest.mark.models
+def test_together_chat_client_make_request():
+    # Requires setting TOGETHER_API_KEY environment variable.
+    client = TogetherChatClient(
+        cache_config=BlackHoleCacheConfig(), api_key=None, together_model="meta-llama/Llama-3-8b-chat-hf"
+    )
+    request = Request(
+        model="meta/llama-3-8b-instruct",
+        model_deployment="together/llama-3-8b-instruct",
+        prompt="Elephants are one of the most",
+        temperature=0.0,
+        max_tokens=10,
+    )
+    result = client.make_request(request)
+    assert result.success
+    assert not result.cached
+    assert result.embedding == []
+    assert len(result.completions) == 1
+    assert result.completions[0].text == "...intelligent animals on Earth!assistant"
+    assert result.completions[0].logprob == 0.0
+    result_token_strings = [token.text for token in result.completions[0].tokens]
+    assert result_token_strings == [
+        "...",
+        "int",
+        "elligent",
+        " animals",
+        " on",
+        " Earth",
+        "!",
+        "<|eot_id|>",
+        "<|start_header_id|>",
+        "assistant",
+    ]
+
+
+@pytest.mark.models
+def test_together_completion_client_make_request():
+    # Requires setting TOGETHER_API_KEY environment variable.
+    client = TogetherCompletionClient(
+        cache_config=BlackHoleCacheConfig(), api_key=None, together_model="meta-llama/Llama-3-8b-hf"
+    )
+    request = Request(
+        model="meta/llama-3-8b",
+        model_deployment="together/llama-3-8b",
+        prompt="Elephants are one of the most",
+        temperature=0.0,
+        max_tokens=10,
+    )
+    result = client.make_request(request)
+    assert result.success
+    assert not result.cached
+    assert result.embedding == []
+    assert len(result.completions) == 1
+    assert result.completions[0].text == " popular animals in the world. They are known for"
+    assert result.completions[0].logprob == 0.0
+    result_token_strings = [token.text for token in result.completions[0].tokens]
+    assert result_token_strings == [
+        " popular",
+        " animals",
+        " in",
+        " the",
+        " world",
+        ".",
+        " They",
+        " are",
+        " known",
+        " for",
+    ]
helm/clients/together_client.py
CHANGED
@@ -1,12 +1,21 @@
 from copy import deepcopy
-from
+from itertools import zip_longest
+import threading
+from typing import List, Dict, Any, Mapping, Optional, TypedDict, Union
 
 import requests
 from retrying import retry
 
 from helm.common.cache import CacheConfig
+from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput, Token
-from .client import CachingClient, truncate_sequence, cleanup_str
+from helm.clients.client import CachingClient, truncate_sequence, cleanup_str
+
+try:
+    from together import Together
+    from together.types import ChatCompletionResponse, CompletionResponse
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["together"])
 
 
 class _RewriteRequestTags:
@@ -272,3 +281,191 @@ class TogetherClient(CachingClient):
             completions=completions,
             embedding=[],
         )
+
+
+_MODEL_TO_DEFAULT_STOP_TOKENS: Optional[Mapping[str, List[str]]] = None
+_MODEL_TO_DEFAULT_STOP_TOKENS_LOCK = threading.Lock()
+
+
+def get_default_stop_tokens_for_model(together_model: str, together_client: Together) -> List[str]:
+    global _MODEL_TO_DEFAULT_STOP_TOKENS
+    global _MODEL_TO_DEFAULT_STOP_TOKENS_LOCK
+    with _MODEL_TO_DEFAULT_STOP_TOKENS_LOCK:
+        if _MODEL_TO_DEFAULT_STOP_TOKENS is None:
+            _MODEL_TO_DEFAULT_STOP_TOKENS = {}
+            for model in together_client.models.list():
+                _MODEL_TO_DEFAULT_STOP_TOKENS[model.id.lower()] = model.config["stop"]
+        stop_tokens = _MODEL_TO_DEFAULT_STOP_TOKENS.get(together_model.lower())
+        if stop_tokens is None:
+            raise ValueError(f"Unknown together_model {together_model}")
+        return stop_tokens
+
+
+class TogetherRawChatRequest(TypedDict):
+    messages: List[Dict[str, str]]
+    model: str
+    max_tokens: int
+    stop: List[str]
+    temperature: float
+    top_p: float
+    top_k: int
+    logprobs: int
+    echo: bool
+    n: int
+
+
+class TogetherChatClient(CachingClient):
+    """Client that uses the Python Together library for chat models."""
+
+    def __init__(self, cache_config: CacheConfig, api_key: Optional[str], together_model: Optional[str] = None):
+        super().__init__(cache_config=cache_config)
+        self._client = Together(api_key=api_key)
+        self._together_model = together_model
+
+    def convert_to_raw_chat_request(self, request: Request) -> TogetherRawChatRequest:
+        if request.messages:
+            messages = request.messages
+        else:
+            messages = [{"role": "user", "content": request.prompt}]
+        if self._together_model is not None:
+            model = self._together_model
+        else:
+            model = request.model
+        return {
+            "messages": messages,
+            "model": model,
+            "max_tokens": request.max_tokens,
+            "stop": request.stop_sequences + get_default_stop_tokens_for_model(model, self._client),
+            "temperature": request.temperature,
+            "top_p": request.top_p,
+            "top_k": request.top_k_per_token,
+            "logprobs": min(request.top_k_per_token, 1),
+            "echo": request.echo_prompt,
+            "n": request.num_completions,
+        }
+
+    def make_request(self, request: Request) -> RequestResult:
+        raw_request = self.convert_to_raw_chat_request(request)
+        cache_key = CachingClient.make_cache_key(raw_request, request)
+
+        def do_it() -> Dict[Any, Any]:
+            response = self._client.chat.completions.create(**raw_request)
+            return response.model_dump(mode="json")
+
+        try:
+            raw_response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+            response = ChatCompletionResponse.model_validate(raw_response)
+        except Exception as error:
+            return RequestResult(
+                success=False,
+                cached=False,
+                error=str(error),
+                completions=[],
+                embedding=[],
+            )
+
+        generated_outputs: List[GeneratedOutput] = []
+        for choice in response.choices:
+            # NOTE: Together always returns None for choice.finish_reason
+            # NOTE: Together does not return logprobs for the whole generated output, only for individual tokens
+            tokens: List[Token] = []
+            if choice.logprobs:
+                for token_text, token_logprob in zip_longest(
+                    choice.logprobs.tokens or [], choice.logprobs.token_logprobs or []
+                ):
+                    if token_text is None:
+                        break
+                    tokens.append(Token(text=token_text, logprob=token_logprob or 0.0))
+            assert choice.message.role == "assistant"
+            generated_outputs.append(GeneratedOutput(text=choice.message.content, logprob=0.0, tokens=tokens))
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=raw_response["request_time"],
+            request_datetime=raw_response["request_datetime"],
+            completions=generated_outputs,
+            embedding=[],
+        )
+
+
+class TogetherRawCompletionRequest(TypedDict):
+    prompt: str
+    model: str
+    max_tokens: int
+    stop: List[str]
+    temperature: float
+    top_p: float
+    top_k: int
+    logprobs: int
+    echo: bool
+    n: int
+
+
+class TogetherCompletionClient(CachingClient):
+    """Client that uses the Python Together library for text completion models."""
+
+    def __init__(self, cache_config: CacheConfig, api_key: Optional[str], together_model: Optional[str] = None):
+        super().__init__(cache_config=cache_config)
+        self._client = Together(api_key=api_key)
+        self._together_model = together_model
+
+    def convert_to_raw_completion_request(self, request: Request) -> TogetherRawCompletionRequest:
+        if self._together_model is not None:
+            model = self._together_model
+        else:
+            model = request.model
+        return {
+            "prompt": request.prompt,
+            "model": model,
+            "max_tokens": request.max_tokens,
+            "stop": request.stop_sequences + get_default_stop_tokens_for_model(model, self._client),
+            "temperature": request.temperature,
+            "top_p": request.top_p,
+            "top_k": request.top_k_per_token,
+            "logprobs": min(request.top_k_per_token, 1),
+            "echo": request.echo_prompt,
+            "n": request.num_completions,
+        }
+
+    def make_request(self, request: Request) -> RequestResult:
+        raw_request = self.convert_to_raw_completion_request(request)
+        cache_key = CachingClient.make_cache_key(raw_request, request)
+
+        def do_it() -> Dict[Any, Any]:
+            response = self._client.completions.create(**raw_request)
+            return response.model_dump(mode="json")
+
+        try:
+            raw_response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+            response = CompletionResponse.model_validate(raw_response)
+        except Exception as error:
+            return RequestResult(
+                success=False,
+                cached=False,
+                error=str(error),
+                completions=[],
+                embedding=[],
+            )
+
+        generated_outputs: List[GeneratedOutput] = []
+        for choice in response.choices:
+            # NOTE: Together always returns None for choice.finish_reason
+            # NOTE: Together does not return logprobs for the whole generated output, only for individual tokens
+            tokens: List[Token] = []
+            if choice.logprobs:
+                for token_text, token_logprob in zip_longest(
+                    choice.logprobs.tokens or [], choice.logprobs.token_logprobs or []
+                ):
+                    if token_text is None:
+                        break
+                    tokens.append(Token(text=token_text, logprob=token_logprob or 0.0))
+            assert choice.text
+            generated_outputs.append(GeneratedOutput(text=choice.text, logprob=0.0, tokens=tokens))
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=raw_response["request_time"],
+            request_datetime=raw_response["request_datetime"],
+            completions=generated_outputs,
+            embedding=[],
+        )