crfm-helm 0.5.1__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of crfm-helm has been flagged as possibly problematic.
- {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.3.dist-info}/METADATA +41 -57
- {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.3.dist-info}/RECORD +197 -152
- {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.3.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +32 -31
- helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +12 -5
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +12 -12
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +8 -8
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +77 -9
- helm/benchmark/adaptation/common_adapter_specs.py +2 -0
- helm/benchmark/annotation/air_bench_annotator.py +64 -0
- helm/benchmark/annotation/annotator_factory.py +6 -0
- helm/benchmark/annotation/anthropic_red_team_annotator.py +70 -0
- helm/benchmark/annotation/call_center_annotator.py +247 -0
- helm/benchmark/annotation/financebench_annotator.py +79 -0
- helm/benchmark/annotation/harm_bench_annotator.py +68 -0
- helm/benchmark/annotation/{image2structure → image2struct}/latex_compiler_annotator.py +2 -2
- helm/benchmark/annotation/{image2structure → image2struct}/lilypond_compiler_annotator.py +5 -3
- helm/benchmark/annotation/{image2structure → image2struct}/webpage_compiler_annotator.py +5 -5
- helm/benchmark/annotation/live_qa_annotator.py +71 -0
- helm/benchmark/annotation/medication_qa_annotator.py +68 -0
- helm/benchmark/annotation/model_as_judge.py +45 -0
- helm/benchmark/annotation/simple_safety_tests_annotator.py +64 -0
- helm/benchmark/annotation/xstest_annotator.py +110 -0
- helm/benchmark/augmentations/translate_perturbation.py +1 -0
- helm/benchmark/huggingface_registration.py +16 -6
- helm/benchmark/metrics/air_bench_metrics.py +56 -0
- helm/benchmark/metrics/annotation_metrics.py +108 -0
- helm/benchmark/metrics/bhasa_metrics.py +188 -0
- helm/benchmark/metrics/bhasa_metrics_specs.py +10 -0
- helm/benchmark/metrics/code_metrics_helper.py +11 -1
- helm/benchmark/metrics/fin_qa_metrics.py +60 -0
- helm/benchmark/metrics/fin_qa_metrics_helper.py +398 -0
- helm/benchmark/metrics/gpt4v_originality_critique_metrics.py +126 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +1 -0
- helm/benchmark/metrics/live_qa_metrics.py +23 -0
- helm/benchmark/metrics/medication_qa_metrics.py +23 -0
- helm/benchmark/metrics/prometheus_vision_critique_metrics.py +185 -0
- helm/benchmark/metrics/reka_vibe_critique_metrics.py +158 -0
- helm/benchmark/metrics/safety_metrics.py +57 -0
- helm/benchmark/metrics/summac/model_summac.py +3 -3
- helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +2 -2
- helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +4 -4
- helm/benchmark/metrics/unitxt_metrics.py +20 -10
- helm/benchmark/metrics/vision_language/emd_utils.py +4 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +30 -72
- helm/benchmark/metrics/vision_language/image_utils.py +1 -1
- helm/benchmark/model_metadata_registry.py +3 -3
- helm/benchmark/presentation/schema.py +54 -4
- helm/benchmark/presentation/test_run_entry.py +1 -0
- helm/benchmark/presentation/test_schema.py +11 -0
- helm/benchmark/run.py +31 -2
- helm/benchmark/run_expander.py +113 -10
- helm/benchmark/run_spec_factory.py +4 -0
- helm/benchmark/run_specs/air_bench_run_specs.py +40 -0
- helm/benchmark/run_specs/bhasa_run_specs.py +638 -0
- helm/benchmark/run_specs/call_center_run_specs.py +152 -0
- helm/benchmark/run_specs/classic_run_specs.py +15 -11
- helm/benchmark/run_specs/decodingtrust_run_specs.py +11 -9
- helm/benchmark/run_specs/experimental_run_specs.py +85 -0
- helm/benchmark/run_specs/finance_run_specs.py +110 -0
- helm/benchmark/run_specs/safety_run_specs.py +154 -0
- helm/benchmark/run_specs/vlm_run_specs.py +251 -57
- helm/benchmark/scenarios/air_bench_scenario.py +50 -0
- helm/benchmark/scenarios/anthropic_red_team_scenario.py +71 -0
- helm/benchmark/scenarios/banking77_scenario.py +51 -0
- helm/benchmark/scenarios/bhasa_scenario.py +1798 -0
- helm/benchmark/scenarios/call_center_scenario.py +84 -0
- helm/benchmark/scenarios/ci_mcqa_scenario.py +80 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +2 -1
- helm/benchmark/scenarios/entity_data_imputation_scenario.py +8 -2
- helm/benchmark/scenarios/ewok_scenario.py +116 -0
- helm/benchmark/scenarios/fin_qa_scenario.py +119 -0
- helm/benchmark/scenarios/financebench_scenario.py +53 -0
- helm/benchmark/scenarios/harm_bench_scenario.py +59 -0
- helm/benchmark/scenarios/scenario.py +1 -1
- helm/benchmark/scenarios/simple_safety_tests_scenario.py +33 -0
- helm/benchmark/scenarios/test_air_bench_scenario.py +27 -0
- helm/benchmark/scenarios/test_commonsense_scenario.py +21 -0
- helm/benchmark/scenarios/test_ewok_scenario.py +25 -0
- helm/benchmark/scenarios/test_financebench_scenario.py +26 -0
- helm/benchmark/scenarios/test_gsm_scenario.py +31 -0
- helm/benchmark/scenarios/test_legalbench_scenario.py +30 -0
- helm/benchmark/scenarios/test_math_scenario.py +2 -8
- helm/benchmark/scenarios/test_med_qa_scenario.py +30 -0
- helm/benchmark/scenarios/test_mmlu_scenario.py +33 -0
- helm/benchmark/scenarios/test_narrativeqa_scenario.py +73 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +4 -4
- helm/benchmark/scenarios/vision_language/a_okvqa_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +5 -5
- helm/benchmark/scenarios/vision_language/crossmodal_3600_scenario.py +2 -1
- helm/benchmark/scenarios/vision_language/exams_v_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/fair_face_scenario.py +136 -0
- helm/benchmark/scenarios/vision_language/flickr30k_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/gqa_scenario.py +2 -2
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/chart2csv_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/{image2structure/image2structure_scenario.py → image2struct/image2struct_scenario.py} +13 -2
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/latex_scenario.py +3 -7
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/musicsheet_scenario.py +1 -5
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/utils_latex.py +31 -39
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage/driver.py +1 -1
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage/utils.py +1 -1
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage_scenario.py +44 -13
- helm/benchmark/scenarios/vision_language/math_vista_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +3 -3
- helm/benchmark/scenarios/vision_language/mm_safety_bench_scenario.py +2 -2
- helm/benchmark/scenarios/vision_language/mme_scenario.py +21 -18
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/pairs_scenario.py +7 -6
- helm/benchmark/scenarios/vision_language/pope_scenario.py +2 -1
- helm/benchmark/scenarios/vision_language/real_world_qa_scenario.py +57 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +7 -5
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +5 -5
- helm/benchmark/scenarios/vision_language/vibe_eval_scenario.py +98 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +3 -1
- helm/benchmark/scenarios/xstest_scenario.py +35 -0
- helm/benchmark/server.py +1 -6
- helm/benchmark/static/schema_air_bench.yaml +3149 -0
- helm/benchmark/static/schema_bhasa.yaml +709 -0
- helm/benchmark/static/schema_call_center.yaml +232 -0
- helm/benchmark/static/schema_classic.yaml +3 -59
- helm/benchmark/static/schema_cleva.yaml +768 -0
- helm/benchmark/static/schema_decodingtrust.yaml +444 -0
- helm/benchmark/static/schema_ewok.yaml +367 -0
- helm/benchmark/static/schema_finance.yaml +189 -0
- helm/benchmark/static/schema_image2struct.yaml +588 -0
- helm/benchmark/static/schema_instruction_following.yaml +3 -52
- helm/benchmark/static/schema_lite.yaml +3 -61
- helm/benchmark/static/schema_medical.yaml +255 -0
- helm/benchmark/static/schema_mmlu.yaml +3 -61
- helm/benchmark/static/schema_safety.yaml +247 -0
- helm/benchmark/static/schema_tables.yaml +317 -0
- helm/benchmark/static/schema_thai.yaml +244 -0
- helm/benchmark/static/schema_unitxt.yaml +3 -61
- helm/benchmark/static/{schema_vlm.yaml → schema_vhelm.yaml} +304 -298
- helm/benchmark/static/schema_vhelm_lite.yaml +4 -59
- helm/benchmark/static_build/assets/accenture-6f97eeda.png +0 -0
- helm/benchmark/static_build/assets/air-overview-d2e6c49f.png +0 -0
- helm/benchmark/static_build/assets/aisingapore-6dfc9acf.png +0 -0
- helm/benchmark/static_build/assets/cresta-9e22b983.png +0 -0
- helm/benchmark/static_build/assets/cuhk-8c5631e9.png +0 -0
- helm/benchmark/static_build/assets/index-05c76bb1.css +1 -0
- helm/benchmark/static_build/assets/index-58f97dcd.js +10 -0
- helm/benchmark/static_build/assets/overview-74aea3d8.png +0 -0
- helm/benchmark/static_build/assets/process-flow-bd2eba96.png +0 -0
- helm/benchmark/static_build/assets/scb10x-204bd786.png +0 -0
- helm/benchmark/static_build/assets/wellsfargo-a86a6c4a.png +0 -0
- helm/benchmark/static_build/index.html +2 -2
- helm/benchmark/window_services/test_openai_window_service.py +8 -8
- helm/clients/ai21_client.py +71 -1
- helm/clients/anthropic_client.py +50 -28
- helm/clients/auto_client.py +11 -0
- helm/clients/client.py +24 -7
- helm/clients/cohere_client.py +98 -3
- helm/clients/huggingface_client.py +79 -19
- helm/clients/nvidia_nim_client.py +35 -0
- helm/clients/openai_client.py +11 -5
- helm/clients/palmyra_client.py +25 -0
- helm/clients/perspective_api_client.py +11 -6
- helm/clients/reka_client.py +189 -0
- helm/clients/test_client.py +7 -9
- helm/clients/test_huggingface_client.py +19 -3
- helm/clients/test_together_client.py +72 -2
- helm/clients/together_client.py +129 -23
- helm/clients/vertexai_client.py +62 -18
- helm/clients/vision_language/huggingface_vlm_client.py +1 -0
- helm/clients/vision_language/open_flamingo_client.py +1 -2
- helm/clients/vision_language/paligemma_client.py +146 -0
- helm/clients/vision_language/palmyra_vision_client.py +99 -0
- helm/clients/yi_client.py +31 -0
- helm/common/critique_request.py +10 -1
- helm/common/images_utils.py +25 -0
- helm/common/mongo_key_value_store.py +2 -1
- helm/common/request.py +16 -0
- helm/config/model_deployments.yaml +740 -363
- helm/config/model_metadata.yaml +824 -128
- helm/config/tokenizer_configs.yaml +207 -10
- helm/proxy/critique/model_critique_client.py +32 -4
- helm/proxy/example_queries.py +14 -21
- helm/proxy/services/server_service.py +2 -3
- helm/proxy/token_counters/test_auto_token_counter.py +2 -2
- helm/tokenizers/ai21_tokenizer.py +51 -59
- helm/tokenizers/auto_tokenizer.py +1 -1
- helm/tokenizers/cohere_tokenizer.py +29 -62
- helm/tokenizers/huggingface_tokenizer.py +35 -13
- helm/tokenizers/test_ai21_tokenizer.py +48 -0
- helm/tokenizers/test_cohere_tokenizer.py +39 -0
- helm/tokenizers/test_huggingface_tokenizer.py +5 -1
- helm/benchmark/static/benchmarking.css +0 -156
- helm/benchmark/static/benchmarking.js +0 -1705
- helm/benchmark/static/config.js +0 -3
- helm/benchmark/static/general.js +0 -122
- helm/benchmark/static/images/crfm-logo.png +0 -0
- helm/benchmark/static/images/helm-logo-simple.png +0 -0
- helm/benchmark/static/images/helm-logo.png +0 -0
- helm/benchmark/static/images/language-model-helm.png +0 -0
- helm/benchmark/static/images/organizations/ai21.png +0 -0
- helm/benchmark/static/images/organizations/anthropic.png +0 -0
- helm/benchmark/static/images/organizations/bigscience.png +0 -0
- helm/benchmark/static/images/organizations/cohere.png +0 -0
- helm/benchmark/static/images/organizations/eleutherai.png +0 -0
- helm/benchmark/static/images/organizations/google.png +0 -0
- helm/benchmark/static/images/organizations/meta.png +0 -0
- helm/benchmark/static/images/organizations/microsoft.png +0 -0
- helm/benchmark/static/images/organizations/nvidia.png +0 -0
- helm/benchmark/static/images/organizations/openai.png +0 -0
- helm/benchmark/static/images/organizations/together.png +0 -0
- helm/benchmark/static/images/organizations/tsinghua-keg.png +0 -0
- helm/benchmark/static/images/organizations/yandex.png +0 -0
- helm/benchmark/static/images/scenarios-by-metrics.png +0 -0
- helm/benchmark/static/images/taxonomy-scenarios.png +0 -0
- helm/benchmark/static/index.html +0 -68
- helm/benchmark/static/info-icon.png +0 -0
- helm/benchmark/static/json-urls.js +0 -69
- helm/benchmark/static/plot-captions.js +0 -27
- helm/benchmark/static/schema_image2structure.yaml +0 -304
- helm/benchmark/static/utils.js +0 -285
- helm/benchmark/static_build/assets/index-737eef9e.js +0 -10
- helm/benchmark/static_build/assets/index-878a1094.css +0 -1
- helm/benchmark/window_services/ai21_window_service.py +0 -247
- helm/benchmark/window_services/cohere_window_service.py +0 -101
- helm/benchmark/window_services/test_ai21_window_service.py +0 -163
- helm/benchmark/window_services/test_cohere_window_service.py +0 -75
- helm/benchmark/window_services/test_cohere_window_service_utils.py +0 -8328
- helm/benchmark/window_services/test_ice_window_service.py +0 -327
- helm/tokenizers/ice_tokenizer.py +0 -30
- helm/tokenizers/test_ice_tokenizer.py +0 -57
- {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.3.dist-info}/LICENSE +0 -0
- {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.3.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.3.dist-info}/top_level.txt +0 -0
- /helm/benchmark/annotation/{image2structure → image2struct}/__init__.py +0 -0
- /helm/benchmark/annotation/{image2structure → image2struct}/image_compiler_annotator.py +0 -0
- /helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/__init__.py +0 -0
- /helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage/__init__.py +0 -0
- /helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage/jekyll_server.py +0 -0
helm/clients/client.py
CHANGED

@@ -39,13 +39,17 @@ class CachingClient(Client):
         """
         if request.random is not None:
             assert "random" not in raw_request
-            cache_key = {**raw_request, "random": request.random}
+            return {**raw_request, "random": request.random}
         else:
-            cache_key = raw_request
-        return cache_key
+            return {**raw_request}
 
 
-def truncate_sequence(sequence: GeneratedOutput, request: Request, print_warning: bool = True) -> GeneratedOutput:
+def truncate_sequence(
+    sequence: GeneratedOutput,
+    request: Request,
+    end_of_text_token: Optional[str] = None,
+    print_warning: bool = True,
+) -> GeneratedOutput:
     """
     Certain providers have bugs where they aren't respecting max_tokens,
     stop_sequences and the end of text token, so as a hack, we have to manually
@@ -64,7 +68,11 @@ def truncate_sequence(sequence: GeneratedOutput, request: Request, print_warning
         hlog("WARNING: don't know how to handle echo_prompt and max_tokens > 0, not truncating")
         return sequence
 
-    for stop in request.stop_sequences:
+    if end_of_text_token:
+        stop_sequences = request.stop_sequences + [end_of_text_token]
+    else:
+        stop_sequences = request.stop_sequences
+    for stop in stop_sequences:
         # Find `stop` in the text
         try:
             new_text = sequence.text[: sequence.text.index(stop)]
@@ -116,7 +124,12 @@ def truncate_and_tokenize_response_text(
 
 
 def truncate_and_tokenize_response_text(
-    text: str, request: Request, tokenizer: Tokenizer, tokenizer_name: str, original_finish_reason: str = "endoftext"
+    text: str,
+    request: Request,
+    tokenizer: Tokenizer,
+    tokenizer_name: str,
+    end_of_text_token: Optional[str] = None,
+    original_finish_reason: str = "endoftext",
 ) -> GeneratedOutput:
     """Truncate a string-only response to respect stop_sequences and max_tokens.
 
@@ -139,7 +152,11 @@ def truncate_and_tokenize_response_text(
     if request.echo_prompt:
         raise Exception("truncate_and_tokenize_response_text() does not support requests with echo_prompt = True")
 
-    for stop_sequence in request.stop_sequences:
+    if end_of_text_token:
+        stop_sequences = request.stop_sequences + [end_of_text_token]
+    else:
+        stop_sequences = request.stop_sequences
+    for stop_sequence in stop_sequences:
         try:
             text = text[: text.index(stop_sequence)]
             finish_reason = "stop"
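The thread running through these hunks is the new optional end_of_text_token argument, which both truncation helpers fold into the request's stop sequences. A minimal sketch of the resulting behavior (the model name, prompt, and token strings below are illustrative, not taken from the diff):

from helm.clients.client import truncate_sequence
from helm.common.request import GeneratedOutput, Request

# Request is a dataclass; fields not set here keep their defaults.
request = Request(model="huggingface/gpt2", prompt="Say hello", stop_sequences=["\n"])
completion = GeneratedOutput(text="hello<|endoftext|>leftover", logprob=0.0, tokens=[])

# With end_of_text_token set, "<|endoftext|>" behaves like one more stop
# sequence, so the completion is cut at the first stop marker found.
truncated = truncate_sequence(completion, request, end_of_text_token="<|endoftext|>")
assert truncated.text == "hello"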
helm/clients/cohere_client.py
CHANGED

@@ -1,8 +1,9 @@
 import json
 import requests
-from typing import List
+from typing import List, Optional, Sequence, TypedDict
 
 from helm.common.cache import CacheConfig
+from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.common.request import (
     wrap_request_time,
     EMBEDDING_UNAVAILABLE_REQUEST_RESULT,
@@ -11,8 +12,13 @@ from helm.common.request import (
     GeneratedOutput,
     Token,
 )
-from .client import CachingClient, truncate_sequence
-from .cohere_utils import get_cohere_url, DEFAULT_COHERE_API_VERSION
+from helm.clients.client import CachingClient, truncate_sequence
+from helm.clients.cohere_utils import get_cohere_url, DEFAULT_COHERE_API_VERSION
+
+try:
+    import cohere
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["cohere"])
 
 
 class CohereClient(CachingClient):
@@ -152,3 +158,92 @@ class CohereClient(CachingClient):
             completions=completions,
             embedding=[],
         )
+
+
+class CohereRawChatRequest(TypedDict):
+    message: str
+    model: Optional[str]
+    preamble: Optional[str]
+    chat_history: Optional[Sequence[cohere.ChatMessage]]
+    temperature: Optional[float]
+    max_tokens: Optional[int]
+    k: Optional[int]
+    p: Optional[float]
+    seed: Optional[float]
+    stop_sequences: Optional[Sequence[str]]
+    frequency_penalty: Optional[float]
+    presence_penalty: Optional[float]
+
+
+def convert_to_raw_chat_request(request: Request) -> CohereRawChatRequest:
+    # TODO: Support chat
+    model = request.model.replace("cohere/", "")
+    return {
+        "message": request.prompt,
+        "model": model,
+        "preamble": None,
+        "chat_history": None,
+        "temperature": request.temperature,
+        "max_tokens": request.max_tokens,
+        "k": request.top_k_per_token,
+        "p": request.top_p,
+        "stop_sequences": request.stop_sequences,
+        "seed": float(request.random) if request.random is not None else None,
+        "frequency_penalty": request.frequency_penalty,
+        "presence_penalty": request.presence_penalty,
+    }
+
+
+class CohereChatClient(CachingClient):
+    """
+    Leverages the chat endpoint: https://docs.cohere.com/reference/chat
+
+    Cohere models will only support chat soon: https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat
+    """
+
+    def __init__(self, api_key: str, cache_config: CacheConfig):
+        super().__init__(cache_config=cache_config)
+        self.client = cohere.Client(api_key=api_key)
+
+    def make_request(self, request: Request) -> RequestResult:
+        if request.embedding:
+            return EMBEDDING_UNAVAILABLE_REQUEST_RESULT
+        # TODO: Support multiple completions
+        assert request.num_completions == 1, "CohereChatClient only supports num_completions=1"
+        # TODO: Support messages
+        assert not request.messages, "CohereChatClient currently does not support the messages API"
+
+        raw_request: CohereRawChatRequest = convert_to_raw_chat_request(request)
+
+        try:
+
+            def do_it():
+                """
+                Send the request to the Cohere Chat API. Responses will be structured like this:
+                cohere.Chat {
+                    message: What's up?
+                    text: Hey there! How's it going? I'm doing well, thank you for asking 😊.
+                    ...
+                }
+                """
+                raw_response = self.client.chat(**raw_request).dict()
+                assert "text" in raw_response, f"Response does not contain text: {raw_response}"
+                return raw_response
+
+            response, cached = self.cache.get(raw_request, wrap_request_time(do_it))
+        except (requests.exceptions.RequestException, AssertionError) as e:
+            error: str = f"CohereClient error: {e}"
+            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+
+        completions: List[GeneratedOutput] = []
+        completion: GeneratedOutput = GeneratedOutput(text=response["text"], logprob=0.0, tokens=[])
+        completions.append(completion)
+
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=response["request_time"],
+            request_datetime=response["request_datetime"],
+            completions=completions,
+            embedding=[],
+        )
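For reference, a small sketch of how the new conversion behaves; it assumes the cohere extra is installed (the module imports cohere at load time) and uses an illustrative model name:

from helm.clients.cohere_client import convert_to_raw_chat_request
from helm.common.request import Request

request = Request(model="cohere/command-r", prompt="What's up?", max_tokens=64)
raw = convert_to_raw_chat_request(request)

# The "cohere/" prefix is stripped and the prompt becomes the single-turn chat
# `message`; preamble and chat_history stay None until chat support lands.
assert raw["model"] == "command-r"
assert raw["message"] == "What's up?"
assert raw["chat_history"] is None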
helm/clients/huggingface_client.py
CHANGED

@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, TypedDict
 
 from helm.common.cache import CacheConfig
 from helm.common.hierarchical_logger import htrack_block, hlog
+from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.common.request import (
     wrap_request_time,
     EMBEDDING_UNAVAILABLE_REQUEST_RESULT,
@@ -17,6 +18,7 @@ from helm.common.request import (
     GeneratedOutput,
     Token,
 )
+from helm.tokenizers.tokenizer import Tokenizer
 from .client import CachingClient, truncate_sequence
 from helm.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer, WrappedPreTrainedTokenizer
 from threading import Lock
@@ -53,27 +55,65 @@ class HuggingFaceRequest(TypedDict):
 class HuggingFaceServer:
     """A thin wrapper around a Hugging Face AutoModelForCausalLM for HuggingFaceClient to call."""
 
-    def __init__(
-        …
+    def __init__(
+        self,
+        pretrained_model_name_or_path: str,
+        wrapped_tokenizer: WrappedPreTrainedTokenizer,
+        openvino: bool = False,
+        **kwargs,
+    ):
+        self.device: Optional[str]
+        if "device_map" in kwargs:
+            try:
+                import accelerate  # noqa: F401
+            except ModuleNotFoundError as e:
+                handle_module_not_found_error(e, ["accelerate"])
+            hlog(f'Hugging Face device_map set to "{kwargs["device_map"]}".')
+            self.device = None
+        elif torch.cuda.is_available():
+            hlog('Hugging Face device set to "cuda:0" because CUDA is available.')
+            self.device = "cuda:0"
         else:
+            hlog('Hugging Face device set to "cpu" because CUDA is unavailable.')
             self.device = "cpu"
+
+        # Security issue: currently we trust remote code by default.
+        # We retain this temporarily to maintain reverse compatibility.
+        # TODO: Delete if-else and don't set trust_remote_code=True
+        if "trust_remote_code" not in kwargs:
+            kwargs["trust_remote_code"] = True
+
         with htrack_block(f"Loading Hugging Face model {pretrained_model_name_or_path}"):
             # WARNING this may fail if your GPU does not have enough memory
-            …
+            if openvino:
+                # Optimum Intel provides a simple interface to optimize Transformer models and convert them to \
+                # OpenVINO™ Intermediate Representation (IR) format to accelerate end-to-end pipelines on \
+                # Intel® architectures using OpenVINO™ runtime.
+                try:
+                    from optimum.intel.openvino import OVModelForCausalLM
+                except ModuleNotFoundError as e:
+                    handle_module_not_found_error(e, ["openvino"])
+
+                self.device = "cpu"
+                self.model = OVModelForCausalLM.from_pretrained(
+                    pretrained_model_name_or_path, export=True, **kwargs
+                ).to(self.device)
+            elif self.device is None:
+                # kwargs contains device_map=auto
+                # Do not call to() because accelerate will take care of model device placement.
+                self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            else:
+                self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs).to(
+                    self.device
+                )
+        self.wrapped_tokenizer = wrapped_tokenizer
 
     def serve_request(self, raw_request: HuggingFaceRequest) -> Dict:
         with self.wrapped_tokenizer as tokenizer:
             encoded_input = tokenizer(raw_request["prompt"], return_tensors="pt", return_token_type_ids=False).to(
-                self.device
+                0 if self.device is None else self.device
             )
+
         stopping_criteria: Optional[StoppingCriteriaList] = None
         optional_args = {}
         if len(raw_request["stop_sequences"]) > 0:
@@ -170,7 +210,12 @@ class HuggingFaceServerFactory:
     _servers_lock: Lock = Lock()
 
     @staticmethod
-    def get_server(helm_model_name: str, pretrained_model_name_or_path: str, **kwargs) -> Any:
+    def get_server(
+        helm_model_name: str,
+        pretrained_model_name_or_path: str,
+        wrapped_tokenizer: WrappedPreTrainedTokenizer,
+        **kwargs,
+    ) -> Any:
         """
         Checks if the desired HuggingFaceModel is cached. Creates the HuggingFaceModel if it's not cached.
         Returns the HuggingFaceModel.
@@ -182,7 +227,7 @@ class HuggingFaceServerFactory:
             f"for HELM model {helm_model_name} with Hugging Face Transformers"
         ):
             HuggingFaceServerFactory._servers[helm_model_name] = HuggingFaceServer(
-                pretrained_model_name_or_path, **kwargs
+                pretrained_model_name_or_path, wrapped_tokenizer, **kwargs
             )
 
         return HuggingFaceServerFactory._servers[helm_model_name]
@@ -206,18 +251,32 @@ def _process_huggingface_client_kwargs(raw_kwargs: Dict[str, Any]):
     # e.g. the string "torch.bfloat16" is converted to torch.bfloat16
     torch_dtype = processed_kwargs.get(TORCH_DTYPE_KEY)
     if torch_dtype and isinstance(torch_dtype, str):
-        if …
-            …
-        processed_kwargs[TORCH_DTYPE_KEY] = getattr(torch, torch_dtype[len(TORCH_DTYPE_VALUE_PREFIX) :])
+        if torch_dtype.startswith(TORCH_DTYPE_VALUE_PREFIX):
+            processed_kwargs[TORCH_DTYPE_KEY] = getattr(torch, torch_dtype[len(TORCH_DTYPE_VALUE_PREFIX) :])
 
     return processed_kwargs
 
 
 class HuggingFaceClient(CachingClient):
-    def __init__(self, cache_config: CacheConfig, pretrained_model_name_or_path: Optional[str] = None, **kwargs):
+    def __init__(
+        self,
+        cache_config: CacheConfig,
+        tokenizer: Tokenizer,
+        pretrained_model_name_or_path: Optional[str] = None,
+        end_of_text_token: Optional[str] = None,
+        **kwargs,
+    ):
         super().__init__(cache_config=cache_config)
         self._pretrained_model_name_or_path = pretrained_model_name_or_path
+        if not isinstance(tokenizer, HuggingFaceTokenizer):
+            raise ValueError(
+                f"Tokenizer for Hugging Face model {pretrained_model_name_or_path} must be a HuggingFaceTokenizer, "
+                "but instead it is {tokenizer}"
+            )
+        self._wrapped_tokenizer: WrappedPreTrainedTokenizer = tokenizer.get_wrapped_tokenizer()
+        self._tokenizer = tokenizer
         self._kwargs = _process_huggingface_client_kwargs(kwargs)
+        self._end_of_text_token = end_of_text_token
 
     def make_request(self, request: Request) -> RequestResult:
         # Embedding not supported for this model
@@ -242,6 +301,7 @@ class HuggingFaceClient(CachingClient):
         huggingface_model: HuggingFaceServer = HuggingFaceServerFactory.get_server(
             helm_model_name=request.model,
             pretrained_model_name_or_path=pretrained_model_name_or_path,
+            wrapped_tokenizer=self._wrapped_tokenizer,
             **self._kwargs,
         )
@@ -284,7 +344,7 @@ class HuggingFaceClient(CachingClient):
                 sequence_logprob += logprob
 
             completion = GeneratedOutput(text=raw_completion["text"], logprob=sequence_logprob, tokens=tokens)
-            completion = truncate_sequence(completion, request)
+            completion = truncate_sequence(completion, request, end_of_text_token=self._end_of_text_token)
             completions.append(completion)
 
         return RequestResult(
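The net effect: HuggingFaceClient now demands a HuggingFaceTokenizer and shares its wrapped transformers tokenizer with the server (instead of the server loading its own), device placement understands device_map, and trust_remote_code defaults to True for backward compatibility. A rough construction sketch; the HuggingFaceTokenizer constructor arguments shown are an assumption, not confirmed by this diff, so check the class before copying:

from helm.common.cache import BlackHoleCacheConfig
from helm.clients.huggingface_client import HuggingFaceClient
from helm.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer

# Assumed HuggingFaceTokenizer signature; "gpt2" and the tokenizer name are illustrative.
tokenizer = HuggingFaceTokenizer(
    cache_config=BlackHoleCacheConfig(),
    tokenizer_name="huggingface/gpt2",
    pretrained_model_name_or_path="gpt2",
)
client = HuggingFaceClient(
    cache_config=BlackHoleCacheConfig(),
    tokenizer=tokenizer,  # anything that is not a HuggingFaceTokenizer now raises ValueError
    pretrained_model_name_or_path="gpt2",
    end_of_text_token="<|endoftext|>",  # forwarded to truncate_sequence on every completion
)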
helm/clients/nvidia_nim_client.py
ADDED

@@ -0,0 +1,35 @@
+from typing import Optional
+
+from helm.clients.openai_client import OpenAIClient
+from helm.common.cache import CacheConfig
+from helm.common.request import Request
+from helm.tokenizers.tokenizer import Tokenizer
+
+
+class NvidiaNimClient(OpenAIClient):
+
+    BASE_URL = "https://integrate.api.nvidia.com/v1"
+
+    def __init__(
+        self,
+        tokenizer: Tokenizer,
+        tokenizer_name: str,
+        cache_config: CacheConfig,
+        api_key: Optional[str] = None,
+    ):
+        self.tokenizer = tokenizer
+        self.tokenizer_name = tokenizer_name
+        super().__init__(
+            tokenizer=tokenizer,
+            tokenizer_name=tokenizer_name,
+            cache_config=cache_config,
+            api_key=api_key,
+            org_id=None,
+            base_url=NvidiaNimClient.BASE_URL,
+        )
+
+    def _get_model_for_request(self, request: Request) -> str:
+        return request.model
+
+    def _is_chat_model_engine(self, model_engine: str) -> bool:
+        return True
helm/clients/openai_client.py
CHANGED

@@ -12,8 +12,8 @@ from helm.common.tokenization_request import (
     TokenizationRequest,
     TokenizationRequestResult,
 )
-from helm.tokenizers.tokenizer import Tokenizer
 from .client import CachingClient, truncate_sequence, generate_uid_for_multimodal_prompt
+from helm.tokenizers.tokenizer import Tokenizer
 
 try:
     import openai
@@ -60,8 +60,7 @@ class OpenAIClient(CachingClient):
 
     def _get_cache_key(self, raw_request: Dict, request: Request):
         cache_key = CachingClient.make_cache_key(raw_request, request)
-        if …
-            assert request.multimodal_prompt is not None
+        if request.multimodal_prompt:
             prompt_key: str = generate_uid_for_multimodal_prompt(request.multimodal_prompt)
             cache_key = {**cache_key, "multimodal_prompt": prompt_key}
             del cache_key["messages"]
@@ -103,6 +102,14 @@ class OpenAIClient(CachingClient):
 
     def _make_chat_request(self, request: Request) -> RequestResult:
         messages: Optional[List[Dict[str, Union[str, Any]]]] = request.messages
+        if (
+            (request.prompt and request.messages)
+            or (request.prompt and request.multimodal_prompt)
+            or (request.messages and request.multimodal_prompt)
+        ):
+            raise ValueError(
+                f"More than one of `prompt`, `messages` and `multimodal_prompt` was set in request: {request}"
+            )
         if request.messages is not None:
             # Checks that all messages have a role and some content
             for message in request.messages:
@@ -125,6 +132,7 @@ class OpenAIClient(CachingClient):
         content: Union[str, List[Union[str, Any]]]
         if request.multimodal_prompt is not None:
             content = []
+            request.validate()
             for media_object in request.multimodal_prompt.media_objects:
                 if media_object.is_type("image") and media_object.location:
                     from helm.common.images_utils import encode_base64
@@ -133,8 +141,6 @@ class OpenAIClient(CachingClient):
                     image_object: Dict[str, str] = {"url": f"data:image/jpeg;base64,{base64_image}"}
                     content.append({"type": "image_url", "image_url": image_object})
                 elif media_object.is_type(TEXT_TYPE):
-                    if media_object.text is None:
-                        raise ValueError("MediaObject of text type has missing text field value")
                     content.append({"type": media_object.type, "text": media_object.text})
                 else:
                     raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
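The substantive addition is the mutual-exclusion guard in _make_chat_request: at most one of prompt, messages, and multimodal_prompt may be set. A sketch of a request that now fails fast (field values are illustrative):

from helm.common.request import Request

ambiguous = Request(
    model="openai/gpt-4o-mini",
    prompt="Hello",  # a plain prompt...
    messages=[{"role": "user", "content": "Hello"}],  # ...plus chat messages
)
# OpenAIClient._make_chat_request(ambiguous) now raises:
#   ValueError: More than one of `prompt`, `messages` and `multimodal_prompt`
#   was set in request: Request(...)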
helm/clients/palmyra_client.py
CHANGED

@@ -3,6 +3,7 @@ import json
 import requests
 from typing import Any, Dict, List
 
+from helm.clients.openai_client import OpenAIClient
 from helm.common.cache import CacheConfig
 from helm.common.hierarchical_logger import hlog
 from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput, Token, ErrorFlags
@@ -142,3 +143,27 @@ class PalmyraClient(CachingClient):
             completions=completions,
             embedding=[],
         )
+
+
+class PalmyraChatClient(OpenAIClient):
+    """Sends request to a Palmyra model using a OpenAI-compatible Chat API."""
+
+    def __init__(
+        self,
+        tokenizer: Tokenizer,
+        tokenizer_name: str,
+        cache_config: CacheConfig,
+        api_key: str,
+        base_url: str,
+    ):
+        super().__init__(
+            tokenizer=tokenizer,
+            tokenizer_name=tokenizer_name,
+            cache_config=cache_config,
+            api_key=api_key,
+            org_id=None,
+            base_url=base_url,
+        )
+
+    def _is_chat_model_engine(self, model_engine: str) -> bool:
+        return True
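PalmyraChatClient mirrors the NVIDIA NIM client above: OpenAIClient pointed at a Writer-hosted OpenAI-compatible endpoint, chat-only. Unlike the NIM client it takes base_url as a parameter, so the endpoint comes from configuration rather than a constant. A construction sketch with the same hedges as above; the base_url and environment variable are hypothetical, with the real values expected to come from deployment configuration (likely model_deployments.yaml):

import os

from helm.common.cache import BlackHoleCacheConfig
from helm.clients.palmyra_client import PalmyraChatClient
from helm.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer

client = PalmyraChatClient(
    tokenizer=HuggingFaceTokenizer(
        cache_config=BlackHoleCacheConfig(),
        tokenizer_name="huggingface/gpt2",  # stand-in; a real deployment would name a Palmyra tokenizer
        pretrained_model_name_or_path="gpt2",
    ),
    tokenizer_name="huggingface/gpt2",
    cache_config=BlackHoleCacheConfig(),
    api_key=os.environ.get("WRITER_API_KEY", ""),
    base_url="https://api.example.com/v1",  # hypothetical endpoint
)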
helm/clients/perspective_api_client.py
CHANGED

@@ -4,16 +4,21 @@ from dataclasses import asdict
 from typing import Any, List, Dict, Optional
 
 from dacite import from_dict
-from googleapiclient import discovery
-from googleapiclient.errors import BatchError, HttpError
-from googleapiclient.http import BatchHttpRequest
-from httplib2 import HttpLib2Error
+
 from helm.clients.toxicity_classifier_client import ToxicityClassifierClient
+from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.proxy.retry import NonRetriableException
-
 from helm.common.cache import Cache, CacheConfig
 from helm.common.perspective_api_request import ToxicityAttributes, PerspectiveAPIRequest, PerspectiveAPIRequestResult
-
+
+try:
+    from googleapiclient import discovery
+    from googleapiclient.errors import BatchError, HttpError
+    from googleapiclient.http import BatchHttpRequest
+    from httplib2 import HttpLib2Error
+    from google.auth.exceptions import DefaultCredentialsError
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["metrics"])
 
 
 class PerspectiveAPIClientCredentialsError(NonRetriableException):
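The import reshuffle applies HELM's optional-dependency convention: extras-only imports are wrapped so a missing package surfaces as an actionable install hint rather than a bare ModuleNotFoundError. The pattern in isolation (the ["metrics"] extra name is taken from this diff; the exact wording of the raised error is up to handle_module_not_found_error):

from helm.common.optional_dependencies import handle_module_not_found_error

try:
    from googleapiclient import discovery  # noqa: F401  # ships with the `metrics` extra
except ModuleNotFoundError as e:
    # Re-raises with a message pointing at the named pip extras,
    # e.g. installing crfm-helm with the [metrics] extra.
    handle_module_not_found_error(e, ["metrics"])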
|