crfm-helm 0.5.5__py3-none-any.whl → 0.5.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of crfm-helm might be problematic.
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.6.dist-info}/METADATA +27 -13
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.6.dist-info}/RECORD +203 -156
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.6.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +3 -3
- helm/benchmark/adaptation/adapters/test_adapter.py +4 -4
- helm/benchmark/annotation/air_bench_annotator.py +1 -1
- helm/benchmark/annotation/bigcodebench_annotator.py +3 -3
- helm/benchmark/annotation/bird_sql_annotator.py +2 -2
- helm/benchmark/annotation/chw_care_plan_annotator.py +7 -12
- helm/benchmark/annotation/ehr_sql_annotator.py +2 -2
- helm/benchmark/annotation/helpdesk_call_summarization_annotator.py +7 -7
- helm/benchmark/annotation/mimic_bhc_annotator.py +100 -0
- helm/benchmark/annotation/model_as_judge.py +12 -16
- helm/benchmark/annotation/omni_math_annotator.py +13 -14
- helm/benchmark/annotation/wildbench_annotator.py +9 -9
- helm/benchmark/executor.py +11 -12
- helm/benchmark/metrics/aci_bench_metrics.py +9 -29
- helm/benchmark/metrics/bias_word_lists.py +1 -1
- helm/benchmark/metrics/chw_care_plan_metrics.py +10 -30
- helm/benchmark/metrics/classification_metrics.py +3 -3
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/conv_fin_qa_calc_metrics.py +2 -2
- helm/benchmark/metrics/dischargeme_metrics.py +9 -29
- helm/benchmark/metrics/efficiency_metrics.py +3 -3
- helm/benchmark/metrics/gpt4_audio_refusal_metrics.py +145 -0
- helm/benchmark/metrics/ifeval_metrics.py +2 -2
- helm/benchmark/metrics/kpi_edgar_metrics.py +121 -0
- helm/benchmark/metrics/llm_jury_metrics.py +46 -0
- helm/benchmark/metrics/med_dialog_metrics.py +9 -29
- helm/benchmark/metrics/medalign_metrics.py +9 -29
- helm/benchmark/metrics/medi_qa_metrics.py +9 -29
- helm/benchmark/metrics/medication_qa_metrics.py +10 -30
- helm/benchmark/metrics/melt_bias_metric.py +234 -0
- helm/benchmark/metrics/melt_bias_word_lists.py +1367 -0
- helm/benchmark/metrics/melt_metric_specs.py +43 -0
- helm/benchmark/metrics/melt_toxicity_metric.py +107 -0
- helm/benchmark/metrics/mental_health_metrics.py +9 -29
- helm/benchmark/metrics/metric_service.py +11 -11
- helm/benchmark/metrics/mimic_bhc_metrics.py +14 -0
- helm/benchmark/metrics/mimic_rrs_metrics.py +9 -29
- helm/benchmark/metrics/mtsamples_procedures_metrics.py +9 -29
- helm/benchmark/metrics/mtsamples_replicate_metrics.py +9 -29
- helm/benchmark/metrics/openai_mrcr_metrics.py +52 -0
- helm/benchmark/metrics/ruler_qa_metrics.py +34 -0
- helm/benchmark/metrics/starr_patient_instructions_metrics.py +9 -29
- helm/benchmark/metrics/summac/model_summac.py +1 -2
- helm/benchmark/metrics/summarization_metrics.py +2 -1
- helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +2 -2
- helm/benchmark/metrics/toxicity_metrics.py +2 -2
- helm/benchmark/metrics/unitxt_metrics.py +3 -4
- helm/benchmark/metrics/vision_language/emd_utils.py +4 -2
- helm/benchmark/metrics/vision_language/image_utils.py +2 -2
- helm/benchmark/model_deployment_registry.py +6 -8
- helm/benchmark/presentation/contamination.py +3 -3
- helm/benchmark/presentation/create_plots.py +33 -12
- helm/benchmark/presentation/run_display.py +13 -0
- helm/benchmark/presentation/schema.py +2 -1
- helm/benchmark/presentation/summarize.py +76 -59
- helm/benchmark/reeval_run.py +3 -4
- helm/benchmark/reeval_runner.py +3 -3
- helm/benchmark/run.py +78 -73
- helm/benchmark/run_expander.py +12 -1
- helm/benchmark/run_spec_factory.py +7 -6
- helm/benchmark/run_specs/audio_run_specs.py +52 -8
- helm/benchmark/run_specs/enterprise_run_specs.py +20 -0
- helm/benchmark/run_specs/experimental_run_specs.py +31 -1
- helm/benchmark/run_specs/long_context_run_specs.py +67 -15
- helm/benchmark/run_specs/medhelm_run_specs.py +146 -41
- helm/benchmark/run_specs/melt_run_specs.py +783 -0
- helm/benchmark/run_specs/speech_disorder_audio_run_specs.py +169 -0
- helm/benchmark/run_specs/vlm_run_specs.py +28 -0
- helm/benchmark/runner.py +5 -5
- helm/benchmark/scenarios/aci_bench_scenario.py +7 -1
- helm/benchmark/scenarios/audio_language/air_bench_chat_scenario.py +3 -1
- helm/benchmark/scenarios/audio_language/air_bench_foundation_scenario.py +5 -5
- helm/benchmark/scenarios/audio_language/corebench_scenario.py +77 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_asr_classification.py +103 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_classification_scenario.py +110 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_disorder_breakdown_scenario.py +78 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_disorder_symptoms_scenario.py +109 -0
- helm/benchmark/scenarios/audio_language/vocal_sound_scenario.py +15 -1
- helm/benchmark/scenarios/audio_language/voxceleb2_scenario.py +1 -2
- helm/benchmark/scenarios/autobencher_capabilities_scenario.py +2 -2
- helm/benchmark/scenarios/chw_care_plan_scenario.py +14 -13
- helm/benchmark/scenarios/clear_scenario.py +11 -7
- helm/benchmark/scenarios/dischargeme_scenario.py +36 -21
- helm/benchmark/scenarios/ehr_sql_scenario.py +7 -1
- helm/benchmark/scenarios/ehrshot_scenario.py +28 -55
- helm/benchmark/scenarios/grammar.py +2 -2
- helm/benchmark/scenarios/headqa_scenario.py +6 -1
- helm/benchmark/scenarios/infinite_bench_en_qa_scenario.py +85 -0
- helm/benchmark/scenarios/{infinite_bench_sum_scenario.py → infinite_bench_en_sum_scenario.py} +10 -13
- helm/benchmark/scenarios/kpi_edgar_scenario.py +151 -0
- helm/benchmark/scenarios/med_dialog_scenario.py +6 -1
- helm/benchmark/scenarios/medalign_scenario.py +9 -3
- helm/benchmark/scenarios/medalign_scenario_helper.py +8 -5
- helm/benchmark/scenarios/medbullets_scenario.py +7 -2
- helm/benchmark/scenarios/medcalc_bench_scenario.py +4 -2
- helm/benchmark/scenarios/medec_scenario.py +6 -1
- helm/benchmark/scenarios/medhallu_scenario.py +7 -1
- helm/benchmark/scenarios/medi_qa_scenario.py +10 -4
- helm/benchmark/scenarios/medication_qa_scenario.py +7 -1
- helm/benchmark/scenarios/melt_ir_scenario.py +171 -0
- helm/benchmark/scenarios/melt_knowledge_scenario.py +246 -0
- helm/benchmark/scenarios/melt_lm_scenarios.py +252 -0
- helm/benchmark/scenarios/melt_scenarios.py +793 -0
- helm/benchmark/scenarios/melt_srn_scenario.py +342 -0
- helm/benchmark/scenarios/melt_synthetic_reasoning_scenario.py +222 -0
- helm/benchmark/scenarios/melt_translation_scenario.py +152 -0
- helm/benchmark/scenarios/mental_health_scenario.py +16 -5
- helm/benchmark/scenarios/mimic_bhc_scenario.py +12 -7
- helm/benchmark/scenarios/mimic_rrs_scenario.py +17 -8
- helm/benchmark/scenarios/mimiciv_billing_code_scenario.py +14 -8
- helm/benchmark/scenarios/mmlu_pro_scenario.py +1 -1
- helm/benchmark/scenarios/mtsamples_procedures_scenario.py +5 -2
- helm/benchmark/scenarios/mtsamples_replicate_scenario.py +3 -2
- helm/benchmark/scenarios/n2c2_ct_matching_scenario.py +11 -5
- helm/benchmark/scenarios/numeracy_scenario.py +2 -1
- helm/benchmark/scenarios/openai_mrcr_scenario.py +79 -0
- helm/benchmark/scenarios/pubmed_qa_scenario.py +6 -1
- helm/benchmark/scenarios/race_based_med_scenario.py +18 -8
- helm/benchmark/scenarios/ruler_qa_scenario_helper.py +2 -2
- helm/benchmark/scenarios/ruler_qa_scenarios.py +2 -2
- helm/benchmark/scenarios/shc_bmt_scenario.py +12 -6
- helm/benchmark/scenarios/shc_cdi_scenario.py +11 -6
- helm/benchmark/scenarios/shc_conf_scenario.py +12 -6
- helm/benchmark/scenarios/shc_ent_scenario.py +11 -6
- helm/benchmark/scenarios/shc_gip_scenario.py +13 -5
- helm/benchmark/scenarios/shc_privacy_scenario.py +78 -0
- helm/benchmark/scenarios/shc_proxy_scenario.py +76 -0
- helm/benchmark/scenarios/shc_ptbm_scenario.py +12 -7
- helm/benchmark/scenarios/shc_sei_scenario.py +12 -7
- helm/benchmark/scenarios/shc_sequoia_scenario.py +13 -5
- helm/benchmark/scenarios/starr_patient_instructions_scenario.py +15 -8
- helm/benchmark/scenarios/test_infinite_bench_en_qa_scenario.py +18 -0
- helm/benchmark/scenarios/test_infinite_bench_en_sum_scenario.py +31 -0
- helm/benchmark/scenarios/truthful_qa_scenario.py +2 -1
- helm/benchmark/scenarios/vision_language/msr_vtt_scenario.py +75 -0
- helm/benchmark/server.py +2 -1
- helm/benchmark/static/schema_audio.yaml +60 -49
- helm/benchmark/static/schema_enterprise.yaml +21 -0
- helm/benchmark/static/schema_long_context.yaml +63 -20
- helm/benchmark/static/schema_medhelm.yaml +272 -213
- helm/benchmark/static/schema_melt.yaml +1257 -0
- helm/benchmark/static/schema_slphelm.yaml +162 -0
- helm/benchmark/static/schema_vhelm.yaml +26 -26
- helm/benchmark/static/schema_video.yaml +219 -0
- helm/benchmark/static_build/assets/index-94295e78.js +10 -0
- helm/benchmark/static_build/assets/index-b9779128.css +1 -0
- helm/benchmark/static_build/assets/medhelm-overview-eac29843.png +0 -0
- helm/benchmark/static_build/assets/{tremor-9cefc3c5.js → tremor-38a10867.js} +1 -1
- helm/benchmark/static_build/index.html +4 -4
- helm/benchmark/window_services/encoder_decoder_window_service.py +3 -3
- helm/benchmark/window_services/test_utils.py +3 -4
- helm/benchmark/window_services/tokenizer_service.py +7 -8
- helm/clients/anthropic_client.py +69 -29
- helm/clients/audio_language/diva_llama_client.py +4 -2
- helm/clients/audio_language/qwen2_5_omni_client.py +197 -0
- helm/clients/audio_language/qwen2_audiolm_client.py +8 -6
- helm/clients/audio_language/qwen_audiolm_client.py +4 -2
- helm/clients/audio_language/test.py +62 -0
- helm/clients/bedrock_client.py +3 -1
- helm/clients/client.py +7 -7
- helm/clients/grok_client.py +36 -0
- helm/clients/huggingface_client.py +42 -3
- helm/clients/huggingface_pipeline_client.py +138 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +1 -1
- helm/clients/image_generation/dalle_mini/model/modeling.py +1 -1
- helm/clients/image_generation/dalle_mini/model/processor.py +1 -1
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +1 -1
- helm/clients/openai_client.py +100 -54
- helm/clients/openai_responses_client.py +174 -0
- helm/clients/palmyra_client.py +2 -5
- helm/clients/reka_client.py +2 -2
- helm/clients/together_client.py +31 -4
- helm/clients/vertexai_client.py +6 -0
- helm/clients/vision_language/huggingface_vision2seq_client.py +6 -4
- helm/clients/vision_language/huggingface_vlm_client.py +2 -2
- helm/clients/vision_language/idefics_client.py +6 -2
- helm/clients/vision_language/paligemma_client.py +2 -2
- helm/clients/vision_language/qwen2_vlm_client.py +66 -53
- helm/clients/vision_language/qwen_vlm_client.py +7 -5
- helm/clients/writer_client.py +102 -0
- helm/common/context.py +80 -0
- helm/common/credentials_utils.py +5 -5
- helm/common/general.py +9 -2
- helm/common/hierarchical_logger.py +46 -3
- helm/common/local_context.py +140 -0
- helm/common/remote_context.py +61 -0
- helm/common/request.py +8 -0
- helm/config/model_deployments.yaml +864 -193
- helm/config/model_metadata.yaml +667 -53
- helm/config/tokenizer_configs.yaml +144 -3
- helm/proxy/cli.py +3 -1
- helm/proxy/critique/mechanical_turk_utils.py +1 -1
- helm/proxy/services/server_service.py +21 -85
- helm/tokenizers/grok_tokenizer.py +53 -0
- helm/tokenizers/huggingface_tokenizer.py +1 -1
- helm/tokenizers/test_grok_tokenizer.py +33 -0
- helm/benchmark/scenarios/test_infinite_bench_sum_scenario.py +0 -46
- helm/benchmark/static_build/assets/index-262903c1.js +0 -10
- helm/benchmark/static_build/assets/index-42060d71.css +0 -1
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.6.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.6.dist-info}/licenses/LICENSE +0 -0
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.6.dist-info}/top_level.txt +0 -0
- /helm/benchmark/static_build/assets/{medhelm-overview-3ddfcd65.png → medhelm-v1-overview-3ddfcd65.png} +0 -0
helm/clients/openai_responses_client.py
ADDED

@@ -0,0 +1,174 @@
+# mypy: check_untyped_defs = False
+import dataclasses
+from typing import Any, Dict, List, Optional, Union
+
+
+from helm.clients.openai_client import OpenAIClientUtils
+from helm.common.cache import CacheConfig
+from helm.common.media_object import TEXT_TYPE
+from helm.common.request import (
+    Thinking,
+    wrap_request_time,
+    Request,
+    RequestResult,
+    GeneratedOutput,
+)
+from helm.common.optional_dependencies import handle_module_not_found_error
+from helm.clients.client import (
+    CachingClient,
+    truncate_and_tokenize_response_text,
+    generate_uid_for_multimodal_prompt,
+)
+from helm.tokenizers.tokenizer import Tokenizer
+
+try:
+    import openai
+    from openai import OpenAI
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["openai"])
+
+
+class OpenAIResponseClient(CachingClient):
+    def __init__(
+        self,
+        tokenizer: Tokenizer,
+        tokenizer_name: str,
+        cache_config: CacheConfig,
+        api_key: Optional[str] = None,
+        org_id: Optional[str] = None,
+        base_url: Optional[str] = None,
+        reasoning_effort: Optional[str] = None,
+        openai_model_name: Optional[str] = None,
+    ):
+        super().__init__(cache_config=cache_config)
+        self.tokenizer = tokenizer
+        self.tokenizer_name = tokenizer_name
+        self.client = OpenAI(
+            api_key=api_key,
+            organization=org_id,
+            base_url=base_url,
+        )
+        self.reasoning_effort = reasoning_effort
+        self.openai_model_name = openai_model_name
+
+    def _get_cache_key(self, raw_request: Dict, request: Request):
+        cache_key = CachingClient.make_cache_key(raw_request, request)
+        if request.multimodal_prompt:
+            prompt_key: str = generate_uid_for_multimodal_prompt(request.multimodal_prompt)
+            cache_key = {**cache_key, "multimodal_prompt": prompt_key}
+        return cache_key
+
+    def _make_raw_request(self, request: Request) -> dict[str, Any]:
+        input: Union[str, List[Dict[str, Any]]]
+        if request.multimodal_prompt is not None:
+            content = []
+            request.validate()
+            for media_object in request.multimodal_prompt.media_objects:
+                if media_object.is_type("image") and media_object.location:
+                    from helm.common.images_utils import encode_base64
+
+                    base64_image: str = encode_base64(media_object.location)
+                    content.append(
+                        {
+                            "type": "input_image",
+                            "image_url": f"data:image/jpeg;base64,{base64_image}",
+                        }
+                    )
+                elif media_object.is_type(TEXT_TYPE):
+                    assert media_object.text is not None
+                    content.append({"type": "input_text", "text": media_object.text})
+                else:
+                    raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
+            input = [{"role": "user", "content": content}]
+        else:
+            input = request.prompt
+
+        raw_request: Dict[str, Any] = {
+            "model": self._get_model_for_request(request),
+            "input": input,
+            "top_p": request.top_p,
+            # API errors if max_output_tokens is less than 16
+            # (Error you get: "Invalid 'max_output_tokens': integer below minimum value.
+            # Expected a value >= 16, but got 5 instead.")
+            "max_output_tokens": max(16, request.max_tokens),
+            "temperature": request.temperature,
+            # Don't store responses for later retrieval
+            "store": False,
+        }
+        if self.reasoning_effort:
+            raw_request["reasoning"] = {"effort": self.reasoning_effort}
+        # If o-series model, get reasoning summaries
+        # Plus other changes
+        model_engine: str = request.model_engine
+        if OpenAIClientUtils.is_reasoning_model(model_engine):
+            raw_request["reasoning"]["summary"] = "detailed"
+            # Avoid error:
+            # "Error code: 400 - {'error': {'message': "Unsupported parameter: 'temperature' is
+            # not supported with this model.", 'type': 'invalid_request_error', 'param': 'temperature',
+            # 'code': 'unsupported_parameter'}}"
+            raw_request.pop("temperature", None)
+
+            # The following parameters also happen to be unsupported by the o-series (code unsupported_parameter)
+            raw_request.pop("top_p", None)
+
+        return raw_request
+
+    def _get_model_for_request(self, request: Request) -> str:
+        return self.openai_model_name or request.model_engine
+
+    def make_request(self, request: Request) -> RequestResult:
+        # Content can either be text or a list of multimodal content made up of text and images:
+        # https://platform.openai.com/docs/api-reference/responses/create
+        raw_request = self._make_raw_request(request)
+
+        # The responses API does not support a "num_completions" parameter,
+        # so we need to handle it ourselves with a simple loop
+        completions: list[GeneratedOutput] = []
+        for _ in range(request.num_completions):
+
+            def do_it() -> Dict[str, Any]:
+                raw_response = self.client.responses.create(**raw_request).model_dump(mode="json")
+                assert not raw_response.get("error", None), f"Error in response: {raw_response}"
+                return raw_response
+
+            try:
+                cache_key = self._get_cache_key(raw_request, request)
+                response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+            except openai.OpenAIError as e:
+                return OpenAIClientUtils.handle_openai_error(e, request)
+
+            # We can only return one completion really,
+            # but we get an array of messages back, so we need to concatenate them
+            reasoning_output = ""
+            text_output = ""
+
+            if request.echo_prompt:
+                text_output += request.prompt
+            for output in response["output"]:
+                output_type = output["type"]  # one of "message" or "reasoning" from API observation
+                is_reasoning_output = output_type == "reasoning"
+
+                if is_reasoning_output:
+                    reasoning_output += "\n".join([raw_output["text"] for raw_output in output["summary"]])
+                else:
+                    text_output += "\n".join([raw_output["text"] for raw_output in output["content"]])
+
+            completion = truncate_and_tokenize_response_text(
+                text_output,
+                request,
+                self.tokenizer,
+                self.tokenizer_name,
+                original_finish_reason="",
+            )
+            if reasoning_output:
+                completion = dataclasses.replace(completion, thinking=Thinking(text=reasoning_output))
+            completions.append(completion)

+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=response["request_time"],
+            request_datetime=response.get("request_datetime"),
+            completions=completions,
+            embedding=[],
+        )
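For orientation, here is a minimal standalone sketch of how the make_request loop above separates reasoning items from message items. The mock_response structure below is an assumption inferred only from the fields the new client reads ("type", "summary", "content"); it is not a captured Responses API payload.

# Minimal sketch; mock_response mimics only the fields make_request reads.
mock_response = {
    "output": [
        {"type": "reasoning", "summary": [{"text": "Compare the two options first."}]},
        {"type": "message", "content": [{"text": "Option A is cheaper."}]},
    ]
}

reasoning_output = ""
text_output = ""
for output in mock_response["output"]:
    if output["type"] == "reasoning":
        # Reasoning items contribute to the Thinking text.
        reasoning_output += "\n".join(item["text"] for item in output["summary"])
    else:
        # Message items contribute to the visible completion text.
        text_output += "\n".join(item["text"] for item in output["content"])

print(reasoning_output)  # Compare the two options first.
print(text_output)       # Option A is cheaper.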
helm/clients/palmyra_client.py
CHANGED

@@ -5,7 +5,7 @@ from typing import Any, Dict, List
 
 from helm.clients.openai_client import OpenAIClient
 from helm.common.cache import CacheConfig
-from helm.common.hierarchical_logger import
+from helm.common.hierarchical_logger import hwarn
 from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput, Token, ErrorFlags
 from helm.common.tokenization_request import (
     TokenizationRequest,
@@ -103,10 +103,7 @@ class PalmyraClient(CachingClient):
             return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
 
         if _is_content_moderation_failure(response):
-
-                f"WARNING: Returning empty request for {request.model_deployment} "
-                "due to content moderation filter"
-            )
+            hwarn(f"Returning empty request for {request.model_deployment} " "due to content moderation filter")
             return RequestResult(
                 success=False,
                 cached=False,
helm/clients/reka_client.py
CHANGED

@@ -6,7 +6,7 @@ from helm.proxy.retry import NonRetriableException
 from helm.common.cache import CacheConfig
 from helm.common.media_object import TEXT_TYPE
 from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput
-from helm.common.hierarchical_logger import
+from helm.common.hierarchical_logger import hwarn
 from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.tokenizers.tokenizer import Tokenizer
 from helm.clients.client import CachingClient, truncate_and_tokenize_response_text
@@ -121,7 +121,7 @@ class RekaClient(CachingClient):
             if messages[-1]["role"] != "user":
                 raise ValueError("Last message must have role 'user'")
             if request.prompt != "":
-
+                hwarn("Since message is set, prompt will be ignored")
             reka_chat_history = self._convert_messages_to_reka_chat_history(messages)
         else:
             current_chat_history: Dict[str, Any] = {
helm/clients/together_client.py
CHANGED

@@ -1,7 +1,8 @@
 from copy import deepcopy
 from itertools import zip_longest
+import re
 import threading
-from typing import Callable, List, Dict, Any, Mapping, Optional, TypedDict, Union
+from typing import Callable, List, Dict, Any, Mapping, Optional, Tuple, TypedDict, Union
 from typing_extensions import NotRequired
 
 import requests
@@ -11,7 +12,7 @@ from helm.common.cache import CacheConfig
 from helm.common.media_object import IMAGE_TYPE, TEXT_TYPE
 from helm.common.object_spec import get_class_by_name
 from helm.common.optional_dependencies import handle_module_not_found_error
-from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput, Token
+from helm.common.request import Thinking, wrap_request_time, Request, RequestResult, GeneratedOutput, Token
 from helm.clients.client import CachingClient, truncate_sequence, cleanup_str
 
 try:
@@ -100,6 +101,19 @@ class JobNotFinishedError(TogetherClientError):
     pass
 
 
+def _parse_thinking(input: str) -> Tuple[str, str]:
+    """Return a tuple of thinking text and output text."""
+    match = re.match(r"<think>\n(.*)\n</think>\n{0,2}(.*)", input, re.DOTALL)
+    if match:
+        return (match.group(1), match.group(2))
+
+    match = re.match(r"<think>\n?(.*)", input, re.DOTALL)
+    if match:
+        return (match.group(1), "")
+
+    return (input, "")
+
+
 class TogetherClient(CachingClient):
     """
     Client for the models where we evaluate offline. Since the queries are handled offline, the `TogetherClient` just
@@ -328,12 +342,14 @@ class TogetherChatClient(CachingClient):
         together_model: Optional[str] = None,
         disable_logprobs: Optional[bool] = None,
         output_processor: Optional[str] = None,
+        parse_thinking: Optional[bool] = None,
     ):
         super().__init__(cache_config=cache_config)
         self._client = Together(api_key=api_key)
         self._together_model = together_model
         self._disable_logprobs = bool(disable_logprobs)
         # self.output_processor is actually a function, not a class
+        self._parse_thinking = bool(parse_thinking)
 
         self.output_processor: Optional[Callable[[str], str]] = (
             get_class_by_name(output_processor) if output_processor else None
@@ -424,11 +440,21 @@ class TogetherChatClient(CachingClient):
                    if token_text is None:
                        break
                    tokens.append(Token(text=token_text, logprob=token_logprob or 0.0))
+            logprob = sum([token.logprob for token in tokens]) if tokens else 0.0
            assert choice.message.role == "assistant"
            output_text = choice.message.content
            if self.output_processor:
                output_text = self.output_processor(output_text)
-
+
+            if self._parse_thinking:
+                thinking_text, output_text = _parse_thinking(output_text)
+                generated_outputs.append(
+                    GeneratedOutput(
+                        text=output_text, logprob=logprob, tokens=tokens, thinking=Thinking(text=thinking_text)
+                    )
+                )
+            else:
+                generated_outputs.append(GeneratedOutput(text=output_text, logprob=logprob, tokens=tokens))
        return RequestResult(
            success=True,
            cached=cached,
@@ -521,8 +547,9 @@ class TogetherCompletionClient(CachingClient):
                    if token_text is None:
                        break
                    tokens.append(Token(text=token_text, logprob=token_logprob or 0.0))
+            logprob = sum([token.logprob for token in tokens]) if tokens else 0.0
            assert choice.text
-            generated_outputs.append(GeneratedOutput(text=choice.text, logprob=
+            generated_outputs.append(GeneratedOutput(text=choice.text, logprob=logprob, tokens=tokens))
        return RequestResult(
            success=True,
            cached=cached,
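The _parse_thinking helper added above is self-contained, so its behavior is easy to check in isolation. The snippet below restates the same regex logic (copied from the diff) and runs it on a few made-up strings to show the three cases: a closed <think> block, a truncated block, and no block at all.

import re
from typing import Tuple


def _parse_thinking(text: str) -> Tuple[str, str]:
    """Return (thinking text, output text); same logic as the helper in together_client.py."""
    match = re.match(r"<think>\n(.*)\n</think>\n{0,2}(.*)", text, re.DOTALL)
    if match:
        return (match.group(1), match.group(2))
    match = re.match(r"<think>\n?(.*)", text, re.DOTALL)
    if match:
        return (match.group(1), "")
    return (text, "")


# Closed block: thinking and answer are separated.
print(_parse_thinking("<think>\nAdd 2 and 2.\n</think>\n\nThe answer is 4."))
# -> ('Add 2 and 2.', 'The answer is 4.')

# Truncated block (no </think>): everything counts as thinking.
print(_parse_thinking("<think>\nStill reasoning..."))
# -> ('Still reasoning...', '')

# No <think> tag: the function returns the whole input in the first slot and an empty second element.
print(_parse_thinking("No thinking tags here."))
# -> ('No thinking tags here.', '')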
helm/clients/vertexai_client.py
CHANGED

@@ -360,6 +360,12 @@ class VertexAIChatClient(VertexAIClient):
             for media_object in request.multimodal_prompt.media_objects:
                 if media_object.is_type("image") and media_object.location:
                     contents.append(Part.from_image(Image.load_from_file(media_object.location)))
+                elif media_object.is_type("video") and media_object.location:
+                    # Following this example
+                    # https://cloud.google.com/vertex-ai/generative-ai/docs/samples/googlegenaisdk-textgen-with-local-video
+                    with open(media_object.location, "rb") as fp:
+                        video_content = fp.read()
+                    contents.append(Part.from_data(data=video_content, mime_type=media_object.content_type))
                 elif media_object.is_type("audio") and media_object.location:
                     contents.append(
                         Part.from_data(get_contents_as_bytes(media_object.location), mime_type=media_object.content_type)
helm/clients/vision_language/huggingface_vision2seq_client.py
CHANGED

@@ -95,8 +95,8 @@ class HuggingFaceVision2SeqClient(CachingClient):
 
         def do_it() -> Dict[str, Any]:
             messages = [{"role": "user", "content": multimodal_prompt}]
-            prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
-            inputs = processor(
+            prompt = processor.apply_chat_template(messages, add_generation_prompt=True)  # type: ignore
+            inputs = processor(  # type: ignore
                 text=[prompt] * request.num_completions,
                 images=[
                     [load_image(image_path) for image_path in image_paths]
@@ -107,8 +107,10 @@ class HuggingFaceVision2SeqClient(CachingClient):
             inputs = {k: v.to(self._device) for k, v in inputs.items()}
 
             # Generate
-            generated_ids = model.generate(**inputs, **generation_args)
-            generated_texts: List[str] = processor.batch_decode(
+            generated_ids = model.generate(**inputs, **generation_args)  # type: ignore
+            generated_texts: List[str] = processor.batch_decode(  # type: ignore
+                generated_ids, skip_special_tokens=True
+            )
             return {"output": generated_texts}
 
         # Include the prompt and model name in the cache key
helm/clients/vision_language/huggingface_vlm_client.py
CHANGED

@@ -50,7 +50,7 @@ class HuggingFaceVLMClient(CachingClient):
         with self._models_lock:
             model_id: str = self._models_aliases.get(model_name, model_name)
             if model_id not in self._models:
-                self._models[model_id] = pipeline("image-to-text", model=model_id, device_map="auto")
+                self._models[model_id] = pipeline("image-to-text", model=model_id, device_map="auto")  # type: ignore
             return self._models[model_id]
 
     def make_request(self, request: Request) -> RequestResult:
@@ -80,7 +80,7 @@ class HuggingFaceVLMClient(CachingClient):
 
         def do_it() -> Dict[str, Any]:
             model: ImageToTextPipeline = self._get_model(request.model_deployment)
-            outputs = model(image, prompt=prompt, generate_kwargs=generation_args)
+            outputs = model(image, prompt=prompt, generate_kwargs=generation_args)  # type: ignore
             return outputs[0]
 
         cache_key = CachingClient.make_cache_key(
helm/clients/vision_language/idefics_client.py
CHANGED

@@ -89,14 +89,18 @@ class IDEFICSClient(CachingClient):
         input_args: Dict[str, Union[str, bool]] = {"return_tensors": "pt"}
         generation_args = {
             "max_new_tokens": request.max_tokens,
-            "bad_words_ids": processor.tokenizer(
+            "bad_words_ids": processor.tokenizer(  # type: ignore
+                self.BAD_WORD_TOKENS, add_special_tokens=False
+            ).input_ids,
         }
 
         if self.END_OF_UTTERANCE_TOKEN in request.stop_sequences:
             # Following https://huggingface.co/HuggingFaceM4/idefics-80b-instruct,
             # specify <end_of_utterance> as an exit condition.
             input_args["add_end_of_utterance_token"] = False
-            exit_condition = processor.tokenizer(
+            exit_condition = processor.tokenizer(  # type: ignore
+                self.END_OF_UTTERANCE_TOKEN, add_special_tokens=False
+            ).input_ids
             generation_args["eos_token_id"] = exit_condition
 
         multimodal_prompt: List[Union[str, Image.Image]] = []
helm/clients/vision_language/paligemma_client.py
CHANGED

@@ -93,7 +93,7 @@ class PaliGemmaClient(CachingClient):
                 else:
                     raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
             prompt_text: str = "\n".join(prompt_pieces)
-            model_inputs = processor(text=prompt_text, images=images, return_tensors="pt").to(self._device)
+            model_inputs = processor(text=prompt_text, images=images, return_tensors="pt").to(self._device)  # type: ignore
             input_len = model_inputs["input_ids"].shape[-1]
 
             completions: List[GeneratedOutput] = []
@@ -109,7 +109,7 @@ class PaliGemmaClient(CachingClient):
                 )[0]
                 if not request.echo_prompt:
                     generation = generation[input_len:]
-                decoded = processor.decode(generation, skip_special_tokens=True)
+                decoded = processor.decode(generation, skip_special_tokens=True)  # type: ignore
                 return {"output": decoded}
 
             # Include the prompt and model name in the cache key
helm/clients/vision_language/qwen2_vlm_client.py
CHANGED

@@ -2,7 +2,7 @@ from threading import Lock
 from typing import Any, Dict, List, Optional
 from dataclasses import dataclass
 
-from transformers import
+from transformers import AutoProcessor
 from qwen_vl_utils import process_vision_info
 import torch
 
@@ -16,15 +16,20 @@ from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt
 
 
 @dataclass(frozen=True)
-class
-    model:
+class LoadedModelProcessor:
+    model: Any
     processor: AutoProcessor
 
 
+# Global cache for all models
 _models_lock: Lock = Lock()
-_models: Dict[str, Optional[
+_models: Dict[str, Optional[LoadedModelProcessor]] = {
     "Qwen/Qwen2-VL-7B-Instruct": None,
     "Qwen/Qwen2-VL-72B-Instruct": None,
+    "Qwen/Qwen2.5-VL-3B-Instruct": None,
+    "Qwen/Qwen2.5-VL-7B-Instruct": None,
+    "Qwen/Qwen2.5-VL-32B-Instruct": None,
+    "Qwen/Qwen2.5-VL-72B-Instruct": None,
 }
 
 
@@ -38,50 +43,52 @@ class Qwen2VLMClient(CachingClient):
             return "Qwen/Qwen2-VL-7B-Instruct"
         elif helm_model_name == "qwen2-vl-72b-instruct":
             return "Qwen/Qwen2-VL-72B-Instruct"
+        elif helm_model_name == "qwen2.5-vl-3b-instruct":
+            return "Qwen/Qwen2.5-VL-3B-Instruct"
+        elif helm_model_name == "qwen2.5-vl-7b-instruct":
+            return "Qwen/Qwen2.5-VL-7B-Instruct"
+        elif helm_model_name == "qwen2.5-vl-32b-instruct":
+            return "Qwen/Qwen2.5-VL-32B-Instruct"
+        elif helm_model_name == "qwen2.5-vl-72b-instruct":
+            return "Qwen/Qwen2.5-VL-72B-Instruct"
         else:
             raise ValueError(f"Unhandled model name: {helm_model_name}")
 
-    def _get_model(self, helm_model_name: str) ->
-
-        global _models
+    def _get_model(self, helm_model_name: str) -> LoadedModelProcessor:
+        from transformers import Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration
 
-
+        global _models_lock, _models
 
+        model_name = self._get_model_name(helm_model_name)
         with _models_lock:
             loaded = _models[model_name]
             if loaded is None:
                 hlog(f"Loading model {model_name} and caching in memory...")
-                #
-
-
-
-
-
-
+                # Use different loading routines depending on whether it's Qwen2.5 or Qwen2.
+                if "2.5" in model_name:
+                    # Qwen2.5: by default use torch_dtype="auto". You can enable flash_attention_2 if desired.
+                    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                        model_name,
+                        torch_dtype=torch.bfloat16,
+                        device_map="auto",
+                        attn_implementation="flash_attention_2",
+                    ).eval()
+                else:
+                    model = Qwen2VLForConditionalGeneration.from_pretrained(
+                        model_name,
+                        torch_dtype=torch.bfloat16,
+                        device_map="auto",
+                        attn_implementation="flash_attention_2",
+                    ).eval()
                 processor = AutoProcessor.from_pretrained(model_name)
-                loaded =
+                loaded = LoadedModelProcessor(model=model, processor=processor)
                 _models[model_name] = loaded
-
             return loaded
 
     def make_request(self, request: Request) -> RequestResult:
         assert request.multimodal_prompt is not None, "Multimodal prompt is required"
-
-
-        processor = loaded.processor
-
-        # Build Qwen2 messages
-        # We assume all media objects go into a single "user" message:
-        # messages = [
-        #     {
-        #         "role": "user",
-        #         "content": [
-        #             {"type": "image", "image": "file:///path/to/image1.jpg"},
-        #             {"type": "image", "image": "file:///path/to/image2.jpg"},
-        #             {"type": "text", "text": "Describe these images."}
-        #         ]
-        #     }
-        # ]
+
+        # Build messages by collating all media objects into a single "user" message.
         message_content = []
         for media_object in request.multimodal_prompt.media_objects:
             if media_object.is_type("image") and media_object.location:
@@ -95,18 +102,6 @@ class Qwen2VLMClient(CachingClient):
 
         messages = [{"role": "user", "content": message_content}]
 
-        # Prepare text and vision inputs
-        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        image_inputs, video_inputs = process_vision_info(messages)
-
-        inputs = processor(
-            text=[text],
-            images=image_inputs,
-            videos=video_inputs,
-            padding=True,
-            return_tensors="pt",
-        ).to(self._device)
-
         generation_args = {
             "max_new_tokens": request.max_tokens,
         }
@@ -116,23 +111,38 @@ class Qwen2VLMClient(CachingClient):
         request_datetime: Optional[int] = None
         all_cached: bool = True
 
-        with htrack_block(f"Generating for prompt: {text}"):
+        with htrack_block(f"Generating for prompt: {request.multimodal_prompt.text}"):
            for completion_index in range(request.num_completions):
                try:
 
                    def do_it() -> Dict[str, Any]:
+                        loaded = self._get_model(request.model_engine)
+                        model = loaded.model
+                        processor = loaded.processor
+
+                        # Prepare text and vision inputs.
+                        text = processor.apply_chat_template(  # type: ignore
+                            messages, tokenize=False, add_generation_prompt=True
+                        )
+                        image_inputs, video_inputs = process_vision_info(messages)
+                        inputs = processor(  # type: ignore
+                            text=[text],
+                            images=image_inputs,
+                            videos=video_inputs,
+                            padding=True,
+                            return_tensors="pt",
+                        ).to(self._device)
+
                        generated_ids = model.generate(**inputs, **generation_args)
-                        # Remove the input prefix from outputs
+                        # Remove the input prefix from outputs.
                        generated_ids_trimmed = [
                            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
                        ]
-                        output_text = processor.batch_decode(
+                        output_text = processor.batch_decode(  # type: ignore
                            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
                        )
-                        # There's only one batch element
-                        completion = output_text[0]
                        # For simplicity, we split tokens by whitespace.
-
+                        completion = output_text[0]
                        tokens = completion.split()
                        return {"output": (completion, tokens)}
 
@@ -148,7 +158,11 @@ class Qwen2VLMClient(CachingClient):
                    result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
                except RuntimeError as model_error:
                    return RequestResult(
-                        success=False,
+                        success=False,
+                        cached=False,
+                        error=str(model_error),
+                        completions=[],
+                        embedding=[],
                    )
 
                text_out, tokens = result["output"]
@@ -160,7 +174,6 @@ class Qwen2VLMClient(CachingClient):
                    )
                )
                hlog(f"Generated: {text_out}")
-
                request_time += result["request_time"]
                request_datetime = request_datetime or result.get("request_datetime")
                all_cached = all_cached and cached
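The qwen2_vlm_client.py changes above reuse a lock-guarded lazy-initialization pattern for the global model cache (_models_lock plus a _models dict of Optionals). Below is a minimal standalone sketch of that pattern under stated assumptions: DummyModel and its construction are placeholders standing in for the real from_pretrained(...) calls and are not part of HELM.

from dataclasses import dataclass
from threading import Lock
from typing import Dict, Optional


@dataclass(frozen=True)
class DummyModel:
    # Placeholder for the (model, processor) pair held by LoadedModelProcessor.
    name: str


_models_lock: Lock = Lock()
_models: Dict[str, Optional[DummyModel]] = {"Qwen/Qwen2.5-VL-7B-Instruct": None}


def get_model(model_name: str) -> DummyModel:
    # Load the model at most once; later calls return the cached instance.
    with _models_lock:
        loaded = _models[model_name]
        if loaded is None:
            loaded = DummyModel(name=model_name)  # stand-in for from_pretrained(...).eval()
            _models[model_name] = loaded
        return loaded


assert get_model("Qwen/Qwen2.5-VL-7B-Instruct") is get_model("Qwen/Qwen2.5-VL-7B-Instruct")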
helm/clients/vision_language/qwen_vlm_client.py
CHANGED

@@ -115,14 +115,16 @@ class QwenVLMClient(CachingClient):
 
         def do_it() -> Dict[str, Any]:
             if request.model_engine == "qwen-vl-chat":
-                completion, _ = model.chat(
+                completion, _ = model.chat(  # type: ignore
+                    tokenizer, query=tokenizer.from_list_format(query), history=None  # type: ignore
+                )
             else:
-                inputs = tokenizer(tokenizer.from_list_format(query), return_tensors="pt")
+                inputs = tokenizer(tokenizer.from_list_format(query), return_tensors="pt")  # type: ignore
                 inputs = inputs.to(self._device)
-                pred = model.generate(**inputs, **generation_args)
-                completion = tokenizer.decode(pred.cpu()[0], skip_special_tokens=False)
+                pred = model.generate(**inputs, **generation_args)  # type: ignore
+                completion = tokenizer.decode(pred.cpu()[0], skip_special_tokens=False)  # type: ignore
 
-            tokens: List[str] = tokenizer.tokenize(completion)
+            tokens: List[str] = tokenizer.tokenize(completion)  # type: ignore
             return {"output": (completion, tokens)}
 
         # Include the prompt and model name in the cache key