crfm-helm 0.5.5__py3-none-any.whl → 0.5.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of crfm-helm might be problematic.
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.6.dist-info}/METADATA +27 -13
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.6.dist-info}/RECORD +203 -156
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.6.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +3 -3
- helm/benchmark/adaptation/adapters/test_adapter.py +4 -4
- helm/benchmark/annotation/air_bench_annotator.py +1 -1
- helm/benchmark/annotation/bigcodebench_annotator.py +3 -3
- helm/benchmark/annotation/bird_sql_annotator.py +2 -2
- helm/benchmark/annotation/chw_care_plan_annotator.py +7 -12
- helm/benchmark/annotation/ehr_sql_annotator.py +2 -2
- helm/benchmark/annotation/helpdesk_call_summarization_annotator.py +7 -7
- helm/benchmark/annotation/mimic_bhc_annotator.py +100 -0
- helm/benchmark/annotation/model_as_judge.py +12 -16
- helm/benchmark/annotation/omni_math_annotator.py +13 -14
- helm/benchmark/annotation/wildbench_annotator.py +9 -9
- helm/benchmark/executor.py +11 -12
- helm/benchmark/metrics/aci_bench_metrics.py +9 -29
- helm/benchmark/metrics/bias_word_lists.py +1 -1
- helm/benchmark/metrics/chw_care_plan_metrics.py +10 -30
- helm/benchmark/metrics/classification_metrics.py +3 -3
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/conv_fin_qa_calc_metrics.py +2 -2
- helm/benchmark/metrics/dischargeme_metrics.py +9 -29
- helm/benchmark/metrics/efficiency_metrics.py +3 -3
- helm/benchmark/metrics/gpt4_audio_refusal_metrics.py +145 -0
- helm/benchmark/metrics/ifeval_metrics.py +2 -2
- helm/benchmark/metrics/kpi_edgar_metrics.py +121 -0
- helm/benchmark/metrics/llm_jury_metrics.py +46 -0
- helm/benchmark/metrics/med_dialog_metrics.py +9 -29
- helm/benchmark/metrics/medalign_metrics.py +9 -29
- helm/benchmark/metrics/medi_qa_metrics.py +9 -29
- helm/benchmark/metrics/medication_qa_metrics.py +10 -30
- helm/benchmark/metrics/melt_bias_metric.py +234 -0
- helm/benchmark/metrics/melt_bias_word_lists.py +1367 -0
- helm/benchmark/metrics/melt_metric_specs.py +43 -0
- helm/benchmark/metrics/melt_toxicity_metric.py +107 -0
- helm/benchmark/metrics/mental_health_metrics.py +9 -29
- helm/benchmark/metrics/metric_service.py +11 -11
- helm/benchmark/metrics/mimic_bhc_metrics.py +14 -0
- helm/benchmark/metrics/mimic_rrs_metrics.py +9 -29
- helm/benchmark/metrics/mtsamples_procedures_metrics.py +9 -29
- helm/benchmark/metrics/mtsamples_replicate_metrics.py +9 -29
- helm/benchmark/metrics/openai_mrcr_metrics.py +52 -0
- helm/benchmark/metrics/ruler_qa_metrics.py +34 -0
- helm/benchmark/metrics/starr_patient_instructions_metrics.py +9 -29
- helm/benchmark/metrics/summac/model_summac.py +1 -2
- helm/benchmark/metrics/summarization_metrics.py +2 -1
- helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +2 -2
- helm/benchmark/metrics/toxicity_metrics.py +2 -2
- helm/benchmark/metrics/unitxt_metrics.py +3 -4
- helm/benchmark/metrics/vision_language/emd_utils.py +4 -2
- helm/benchmark/metrics/vision_language/image_utils.py +2 -2
- helm/benchmark/model_deployment_registry.py +6 -8
- helm/benchmark/presentation/contamination.py +3 -3
- helm/benchmark/presentation/create_plots.py +33 -12
- helm/benchmark/presentation/run_display.py +13 -0
- helm/benchmark/presentation/schema.py +2 -1
- helm/benchmark/presentation/summarize.py +76 -59
- helm/benchmark/reeval_run.py +3 -4
- helm/benchmark/reeval_runner.py +3 -3
- helm/benchmark/run.py +78 -73
- helm/benchmark/run_expander.py +12 -1
- helm/benchmark/run_spec_factory.py +7 -6
- helm/benchmark/run_specs/audio_run_specs.py +52 -8
- helm/benchmark/run_specs/enterprise_run_specs.py +20 -0
- helm/benchmark/run_specs/experimental_run_specs.py +31 -1
- helm/benchmark/run_specs/long_context_run_specs.py +67 -15
- helm/benchmark/run_specs/medhelm_run_specs.py +146 -41
- helm/benchmark/run_specs/melt_run_specs.py +783 -0
- helm/benchmark/run_specs/speech_disorder_audio_run_specs.py +169 -0
- helm/benchmark/run_specs/vlm_run_specs.py +28 -0
- helm/benchmark/runner.py +5 -5
- helm/benchmark/scenarios/aci_bench_scenario.py +7 -1
- helm/benchmark/scenarios/audio_language/air_bench_chat_scenario.py +3 -1
- helm/benchmark/scenarios/audio_language/air_bench_foundation_scenario.py +5 -5
- helm/benchmark/scenarios/audio_language/corebench_scenario.py +77 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_asr_classification.py +103 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_classification_scenario.py +110 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_disorder_breakdown_scenario.py +78 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_disorder_symptoms_scenario.py +109 -0
- helm/benchmark/scenarios/audio_language/vocal_sound_scenario.py +15 -1
- helm/benchmark/scenarios/audio_language/voxceleb2_scenario.py +1 -2
- helm/benchmark/scenarios/autobencher_capabilities_scenario.py +2 -2
- helm/benchmark/scenarios/chw_care_plan_scenario.py +14 -13
- helm/benchmark/scenarios/clear_scenario.py +11 -7
- helm/benchmark/scenarios/dischargeme_scenario.py +36 -21
- helm/benchmark/scenarios/ehr_sql_scenario.py +7 -1
- helm/benchmark/scenarios/ehrshot_scenario.py +28 -55
- helm/benchmark/scenarios/grammar.py +2 -2
- helm/benchmark/scenarios/headqa_scenario.py +6 -1
- helm/benchmark/scenarios/infinite_bench_en_qa_scenario.py +85 -0
- helm/benchmark/scenarios/{infinite_bench_sum_scenario.py → infinite_bench_en_sum_scenario.py} +10 -13
- helm/benchmark/scenarios/kpi_edgar_scenario.py +151 -0
- helm/benchmark/scenarios/med_dialog_scenario.py +6 -1
- helm/benchmark/scenarios/medalign_scenario.py +9 -3
- helm/benchmark/scenarios/medalign_scenario_helper.py +8 -5
- helm/benchmark/scenarios/medbullets_scenario.py +7 -2
- helm/benchmark/scenarios/medcalc_bench_scenario.py +4 -2
- helm/benchmark/scenarios/medec_scenario.py +6 -1
- helm/benchmark/scenarios/medhallu_scenario.py +7 -1
- helm/benchmark/scenarios/medi_qa_scenario.py +10 -4
- helm/benchmark/scenarios/medication_qa_scenario.py +7 -1
- helm/benchmark/scenarios/melt_ir_scenario.py +171 -0
- helm/benchmark/scenarios/melt_knowledge_scenario.py +246 -0
- helm/benchmark/scenarios/melt_lm_scenarios.py +252 -0
- helm/benchmark/scenarios/melt_scenarios.py +793 -0
- helm/benchmark/scenarios/melt_srn_scenario.py +342 -0
- helm/benchmark/scenarios/melt_synthetic_reasoning_scenario.py +222 -0
- helm/benchmark/scenarios/melt_translation_scenario.py +152 -0
- helm/benchmark/scenarios/mental_health_scenario.py +16 -5
- helm/benchmark/scenarios/mimic_bhc_scenario.py +12 -7
- helm/benchmark/scenarios/mimic_rrs_scenario.py +17 -8
- helm/benchmark/scenarios/mimiciv_billing_code_scenario.py +14 -8
- helm/benchmark/scenarios/mmlu_pro_scenario.py +1 -1
- helm/benchmark/scenarios/mtsamples_procedures_scenario.py +5 -2
- helm/benchmark/scenarios/mtsamples_replicate_scenario.py +3 -2
- helm/benchmark/scenarios/n2c2_ct_matching_scenario.py +11 -5
- helm/benchmark/scenarios/numeracy_scenario.py +2 -1
- helm/benchmark/scenarios/openai_mrcr_scenario.py +79 -0
- helm/benchmark/scenarios/pubmed_qa_scenario.py +6 -1
- helm/benchmark/scenarios/race_based_med_scenario.py +18 -8
- helm/benchmark/scenarios/ruler_qa_scenario_helper.py +2 -2
- helm/benchmark/scenarios/ruler_qa_scenarios.py +2 -2
- helm/benchmark/scenarios/shc_bmt_scenario.py +12 -6
- helm/benchmark/scenarios/shc_cdi_scenario.py +11 -6
- helm/benchmark/scenarios/shc_conf_scenario.py +12 -6
- helm/benchmark/scenarios/shc_ent_scenario.py +11 -6
- helm/benchmark/scenarios/shc_gip_scenario.py +13 -5
- helm/benchmark/scenarios/shc_privacy_scenario.py +78 -0
- helm/benchmark/scenarios/shc_proxy_scenario.py +76 -0
- helm/benchmark/scenarios/shc_ptbm_scenario.py +12 -7
- helm/benchmark/scenarios/shc_sei_scenario.py +12 -7
- helm/benchmark/scenarios/shc_sequoia_scenario.py +13 -5
- helm/benchmark/scenarios/starr_patient_instructions_scenario.py +15 -8
- helm/benchmark/scenarios/test_infinite_bench_en_qa_scenario.py +18 -0
- helm/benchmark/scenarios/test_infinite_bench_en_sum_scenario.py +31 -0
- helm/benchmark/scenarios/truthful_qa_scenario.py +2 -1
- helm/benchmark/scenarios/vision_language/msr_vtt_scenario.py +75 -0
- helm/benchmark/server.py +2 -1
- helm/benchmark/static/schema_audio.yaml +60 -49
- helm/benchmark/static/schema_enterprise.yaml +21 -0
- helm/benchmark/static/schema_long_context.yaml +63 -20
- helm/benchmark/static/schema_medhelm.yaml +272 -213
- helm/benchmark/static/schema_melt.yaml +1257 -0
- helm/benchmark/static/schema_slphelm.yaml +162 -0
- helm/benchmark/static/schema_vhelm.yaml +26 -26
- helm/benchmark/static/schema_video.yaml +219 -0
- helm/benchmark/static_build/assets/index-94295e78.js +10 -0
- helm/benchmark/static_build/assets/index-b9779128.css +1 -0
- helm/benchmark/static_build/assets/medhelm-overview-eac29843.png +0 -0
- helm/benchmark/static_build/assets/{tremor-9cefc3c5.js → tremor-38a10867.js} +1 -1
- helm/benchmark/static_build/index.html +4 -4
- helm/benchmark/window_services/encoder_decoder_window_service.py +3 -3
- helm/benchmark/window_services/test_utils.py +3 -4
- helm/benchmark/window_services/tokenizer_service.py +7 -8
- helm/clients/anthropic_client.py +69 -29
- helm/clients/audio_language/diva_llama_client.py +4 -2
- helm/clients/audio_language/qwen2_5_omni_client.py +197 -0
- helm/clients/audio_language/qwen2_audiolm_client.py +8 -6
- helm/clients/audio_language/qwen_audiolm_client.py +4 -2
- helm/clients/audio_language/test.py +62 -0
- helm/clients/bedrock_client.py +3 -1
- helm/clients/client.py +7 -7
- helm/clients/grok_client.py +36 -0
- helm/clients/huggingface_client.py +42 -3
- helm/clients/huggingface_pipeline_client.py +138 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +1 -1
- helm/clients/image_generation/dalle_mini/model/modeling.py +1 -1
- helm/clients/image_generation/dalle_mini/model/processor.py +1 -1
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +1 -1
- helm/clients/openai_client.py +100 -54
- helm/clients/openai_responses_client.py +174 -0
- helm/clients/palmyra_client.py +2 -5
- helm/clients/reka_client.py +2 -2
- helm/clients/together_client.py +31 -4
- helm/clients/vertexai_client.py +6 -0
- helm/clients/vision_language/huggingface_vision2seq_client.py +6 -4
- helm/clients/vision_language/huggingface_vlm_client.py +2 -2
- helm/clients/vision_language/idefics_client.py +6 -2
- helm/clients/vision_language/paligemma_client.py +2 -2
- helm/clients/vision_language/qwen2_vlm_client.py +66 -53
- helm/clients/vision_language/qwen_vlm_client.py +7 -5
- helm/clients/writer_client.py +102 -0
- helm/common/context.py +80 -0
- helm/common/credentials_utils.py +5 -5
- helm/common/general.py +9 -2
- helm/common/hierarchical_logger.py +46 -3
- helm/common/local_context.py +140 -0
- helm/common/remote_context.py +61 -0
- helm/common/request.py +8 -0
- helm/config/model_deployments.yaml +864 -193
- helm/config/model_metadata.yaml +667 -53
- helm/config/tokenizer_configs.yaml +144 -3
- helm/proxy/cli.py +3 -1
- helm/proxy/critique/mechanical_turk_utils.py +1 -1
- helm/proxy/services/server_service.py +21 -85
- helm/tokenizers/grok_tokenizer.py +53 -0
- helm/tokenizers/huggingface_tokenizer.py +1 -1
- helm/tokenizers/test_grok_tokenizer.py +33 -0
- helm/benchmark/scenarios/test_infinite_bench_sum_scenario.py +0 -46
- helm/benchmark/static_build/assets/index-262903c1.js +0 -10
- helm/benchmark/static_build/assets/index-42060d71.css +0 -1
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.6.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.6.dist-info}/licenses/LICENSE +0 -0
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.6.dist-info}/top_level.txt +0 -0
- /helm/benchmark/static_build/assets/{medhelm-overview-3ddfcd65.png → medhelm-v1-overview-3ddfcd65.png} +0 -0
helm/clients/bedrock_client.py
CHANGED
@@ -117,10 +117,12 @@ class BedrockNovaClient(CachingClient):
         tokenizer_name: str,
         assumed_role: Optional[str] = None,
         region: Optional[str] = None,
+        bedrock_model_id: Optional[str] = None,
     ):
         super().__init__(cache_config=cache_config)
         self.tokenizer = tokenizer
         self.tokenizer_name = tokenizer_name
+        self.bedrock_model_id = bedrock_model_id
         self.bedrock_client = get_bedrock_client_v1(
             assumed_role=assumed_role or os.environ.get("BEDROCK_ASSUME_ROLE", None),
             region=region,
@@ -144,7 +146,7 @@ class BedrockNovaClient(CachingClient):
         messages = self._get_messages_from_request(request)

         return {
-            "modelId": model_id,
+            "modelId": self.bedrock_model_id or model_id,
             "inferenceConfig": {
                 "temperature": request.temperature,
                 "maxTokens": request.max_tokens,
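For context, a minimal sketch (not from the package) of the precedence the new `bedrock_model_id` deployment argument introduces: an explicitly configured Bedrock model id wins over the id derived from the HELM model name. The model ids below are only illustrative.

from typing import Optional

def resolve_model_id(bedrock_model_id: Optional[str], default_model_id: str) -> str:
    # Mirrors the `self.bedrock_model_id or model_id` expression in the hunk above.
    return bedrock_model_id or default_model_id

assert resolve_model_id(None, "amazon.nova-lite-v1:0") == "amazon.nova-lite-v1:0"
assert resolve_model_id("us.amazon.nova-lite-v1:0", "amazon.nova-lite-v1:0") == "us.amazon.nova-lite-v1:0"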
helm/clients/client.py
CHANGED
@@ -2,7 +2,7 @@ import json
 from abc import ABC, abstractmethod
 from typing import List, Mapping, Optional, cast

-from helm.common.hierarchical_logger import hlog
+from helm.common.hierarchical_logger import hwarn
 from helm.common.media_object import MultimediaObject, TEXT_TYPE
 from helm.common.request import Request, RequestResult, GeneratedOutput, Token
 from helm.common.cache import Cache, CacheConfig
@@ -65,7 +65,7 @@ def truncate_sequence(
     # where max_tokens = 0, so there's nothing to truncate.
     if request.echo_prompt:
         if request.max_tokens != 0:
-            hlog("WARNING: don't know how to handle echo_prompt and max_tokens > 0, not truncating")
+            hwarn("don't know how to handle echo_prompt and max_tokens > 0, not truncating")
         return sequence

     if end_of_text_token:
@@ -90,8 +90,8 @@ def truncate_sequence(
             new_tokens.append(token)

         if len(new_text) < len(sequence.text) and len(new_tokens) == len(sequence.tokens):
-            hlog(
-                f"WARNING: Stripped characters from text ({len(sequence.text)} -> {len(new_text)}), "
+            hwarn(
+                f"Stripped characters from text ({len(sequence.text)} -> {len(new_text)}), "
                 f"but wasn't able to strip the tokens"
             )

@@ -99,14 +99,14 @@ def truncate_sequence(
         new_logprob = sum(token.logprob for token in new_tokens)

         if print_warning:
-            hlog(f"WARNING: truncate_sequence needs to strip {json.dumps(stop)}")
+            hwarn(f"truncate_sequence needs to strip {json.dumps(stop)}")

         sequence = GeneratedOutput(text=new_text, logprob=new_logprob, tokens=new_tokens)

     # Truncate based on the max number of tokens.
     if len(sequence.tokens) > request.max_tokens:
         if print_warning:
-            hlog(f"WARNING: truncate_sequence needs to truncate {len(sequence.tokens)} down to {request.max_tokens}")
+            hwarn(f"truncate_sequence needs to truncate {len(sequence.tokens)} down to {request.max_tokens}")
         new_tokens = sequence.tokens[: request.max_tokens]

         # This is imperfect stitching together of tokens, so just to make sure this is okay
@@ -114,7 +114,7 @@ def truncate_sequence(
         # Usually, in our benchmark, max_tokens is active when it's 1, so hopefully this isn't an issue.
         new_text = "".join(token.text for token in new_tokens)
         if not sequence.text.startswith(new_text):
-            hlog(f"WARNING: {json.dumps(sequence.text)} does not start with truncated text {json.dumps(new_text)}")
+            hwarn(f"{json.dumps(sequence.text)} does not start with truncated text {json.dumps(new_text)}")

         new_logprob = sum(token.logprob for token in new_tokens)

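These hunks replace ad-hoc warning logging with the `hwarn` helper from `helm.common.hierarchical_logger` (itself extended in this release, see +46 -3 in the file list above), presumably so warning formatting lives in one place. A minimal usage sketch, assuming crfm-helm 0.5.6 is installed:

from helm.common.hierarchical_logger import hwarn

# Emits a warning through HELM's hierarchical logger, as truncate_sequence now does.
hwarn("truncate_sequence needs to truncate 512 down to 256")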
helm/clients/grok_client.py
ADDED
@@ -0,0 +1,36 @@
+from typing import Any, Dict, Optional
+
+from helm.clients.openai_client import OpenAIClient
+from helm.common.cache import CacheConfig
+from helm.common.request import Request
+from helm.tokenizers.tokenizer import Tokenizer
+
+
+class GrokChatClient(OpenAIClient):
+
+    BASE_URL = "https://api.x.ai/v1"
+
+    _UNSUPPORTED_ARGUMENTS = ["presence_penalty", "frequency_penalty"]
+
+    def __init__(
+        self,
+        tokenizer: Tokenizer,
+        tokenizer_name: str,
+        cache_config: CacheConfig,
+        api_key: Optional[str] = None,
+    ):
+        super().__init__(
+            tokenizer=tokenizer,
+            tokenizer_name=tokenizer_name,
+            cache_config=cache_config,
+            api_key=api_key,
+            org_id=None,
+            base_url="https://api.x.ai/v1",
+        )
+
+    def _make_chat_raw_request(self, request: Request) -> Dict[str, Any]:
+        raw_request = super()._make_chat_raw_request(request)
+        for unsupported_argument in self._UNSUPPORTED_ARGUMENTS:
+            if unsupported_argument in raw_request:
+                del raw_request[unsupported_argument]
+        return raw_request
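GrokChatClient reuses OpenAIClient's chat request construction, points it at the x.ai base URL, and strips the parameters it lists as unsupported. A standalone sketch of just that filtering step (the request dictionary and model name below are hypothetical):

from typing import Any, Dict

_UNSUPPORTED_ARGUMENTS = ["presence_penalty", "frequency_penalty"]

def strip_unsupported_arguments(raw_request: Dict[str, Any]) -> Dict[str, Any]:
    # Same idea as GrokChatClient._make_chat_raw_request, but returning a filtered copy
    # instead of deleting keys in place.
    return {key: value for key, value in raw_request.items() if key not in _UNSUPPORTED_ARGUMENTS}

print(strip_unsupported_arguments({"model": "grok-2", "temperature": 0.0, "presence_penalty": 0.0}))
# {'model': 'grok-2', 'temperature': 0.0}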
helm/clients/huggingface_client.py
CHANGED
@@ -8,7 +8,7 @@ from transformers.generation.stopping_criteria import (
 from typing import Any, Dict, List, Optional, TypedDict

 from helm.common.cache import CacheConfig
-from helm.common.hierarchical_logger import htrack_block, hlog
+from helm.common.hierarchical_logger import htrack_block, hlog, hwarn
 from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.common.request import (
     wrap_request_time,
@@ -18,6 +18,7 @@ from helm.common.request import (
     GeneratedOutput,
     Token,
 )
+from helm.proxy.retry import NonRetriableException
 from helm.tokenizers.tokenizer import Tokenizer
 from helm.clients.client import CachingClient, truncate_sequence
 from helm.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer, WrappedPreTrainedTokenizer
@@ -256,6 +257,7 @@ class HuggingFaceClient(CachingClient):
         tokenizer: Tokenizer,
         pretrained_model_name_or_path: Optional[str] = None,
         end_of_text_token: Optional[str] = None,
+        apply_chat_template: Optional[bool] = None,
         **kwargs,
     ):
         super().__init__(cache_config=cache_config)
@@ -266,9 +268,46 @@ class HuggingFaceClient(CachingClient):
                 "but instead it is {tokenizer}"
             )
         self._wrapped_tokenizer: WrappedPreTrainedTokenizer = tokenizer.get_wrapped_tokenizer()
-        self._tokenizer = tokenizer
         self._kwargs = _process_huggingface_client_kwargs(kwargs)
         self._end_of_text_token = end_of_text_token
+        # If the user did not explicitly configure whether the model is a chat model with `apply_chat_template` arg,
+        # auto-infer if the model is a chat model based on whether the tokenizer has a chat template.
+        # Note: Auto-inference is incorrect for some non-chat models that still have chat templates
+        # e.g. Qwen2, Qwen 2.5.
+        # For these models, the `apply_chat_template` arg should be explicitly set to false.
+        if apply_chat_template is not None:
+            self._apply_chat_template = apply_chat_template
+        else:
+            with self._wrapped_tokenizer as hf_tokenizer:
+                self._apply_chat_template = bool(hf_tokenizer.chat_template)
+            hwarn(
+                f"Automatically set `apply_chat_template` to {self._apply_chat_template} based on "
+                "whether the tokenizer has a chat template. "
+                "If this is incorrect, please explicitly set `apply_chat_template`."
+            )
+
+    def get_prompt(self, request: Request) -> str:
+        if request.prompt and request.messages:
+            raise NonRetriableException(f"More than one of `prompt` and `messages` was set in request: {request}")
+        # Chat model expects a list of messages as input
+        if self._apply_chat_template:
+            with self._wrapped_tokenizer as tokenizer:
+                if request.messages:
+                    prompt = tokenizer.apply_chat_template(request.messages, tokenize=False)
+                    assert isinstance(prompt, str)
+                    return prompt
+                else:
+                    prompt = tokenizer.apply_chat_template(
+                        [{"role": "user", "content": request.prompt}], tokenize=False
+                    )
+                    assert isinstance(prompt, str)
+                    return prompt
+        # Base non-chat model expects a string as input
+        else:
+            if request.messages:
+                raise NonRetriableException("Chat mesages not supported by non-chat model")
+            else:
+                return request.prompt

     def make_request(self, request: Request) -> RequestResult:
         # Embedding not supported for this model
@@ -277,7 +316,7 @@ class HuggingFaceClient(CachingClient):

         raw_request: HuggingFaceRequest = {
             "engine": request.model_engine,
-            "prompt": request.prompt,
+            "prompt": self.get_prompt(request),
             "temperature": 1e-7 if request.temperature == 0 else request.temperature,
             "num_return_sequences": request.num_completions,
             "max_new_tokens": request.max_tokens,
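The main addition here is chat-template handling: when the tokenizer ships a chat template, the client now builds the prompt with it instead of passing the raw prompt string. A standalone sketch of the same detection logic using the public transformers API (the checkpoint name is only an example):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# Mirrors HuggingFaceClient: treat the model as a chat model if the tokenizer has a
# chat template, unless `apply_chat_template` is configured explicitly.
apply_chat_template = bool(tokenizer.chat_template)

if apply_chat_template:
    prompt = tokenizer.apply_chat_template([{"role": "user", "content": "What is 2 + 2?"}], tokenize=False)
else:
    prompt = "What is 2 + 2?"
print(prompt)

As the inline comments in the diff note, some base models (e.g. Qwen2, Qwen 2.5) ship chat templates even though they are not chat models, which is why `apply_chat_template` can still be set explicitly in the model deployment.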
helm/clients/huggingface_pipeline_client.py
ADDED
@@ -0,0 +1,138 @@
+from threading import Lock
+from typing import Any, Dict, List, Optional, Union
+
+import transformers
+
+from helm.clients.client import CachingClient
+from helm.common.cache import CacheConfig
+from helm.common.hierarchical_logger import htrack_block, hwarn
+from helm.common.request import GeneratedOutput, Request, RequestResult, wrap_request_time
+from helm.proxy.retry import NonRetriableException
+
+
+_pipelines: Dict[str, transformers.Pipeline] = {}
+_pipelines_lock: Lock = Lock()
+
+
+def _get_pipeline(
+    helm_model_name: str,
+    pipeline_kwargs: Dict[str, Any],
+) -> Any:
+    """
+    Checks if the desired HuggingFaceModel is cached. Creates the HuggingFaceModel if it's not cached.
+    Returns the HuggingFaceModel.
+    """
+    global _pipelines
+    global _pipelines_lock
+    with _pipelines_lock:
+        if helm_model_name not in _pipelines:
+            huggingface_model_name = pipeline_kwargs["model"]
+            with htrack_block(
+                f"Loading HuggingFace model {huggingface_model_name} (kwargs={pipeline_kwargs}) "
+                f"for HELM model {helm_model_name} with transformers.pipeline"
+            ):
+                _pipelines[helm_model_name] = transformers.pipeline(**pipeline_kwargs)
+
+    return _pipelines[helm_model_name]
+
+
+class HuggingFacePipelineClient(CachingClient):
+    def __init__(
+        self,
+        cache_config: CacheConfig,
+        model_name: str,
+        pretrained_model_name_or_path: Optional[str] = None,
+        apply_chat_template: Optional[bool] = None,
+        **kwargs,
+    ):
+        # Include `pretrained_model_name_or_path` parameter so that model deployments can use
+        # the `pretrained_model_name_or_path` arg to override `model_name`
+        super().__init__(cache_config=cache_config)
+        self._helm_model_name = model_name
+        self._pipeline_kwargs = {
+            "model": pretrained_model_name_or_path or self._helm_model_name,
+            "task": "text-generation",
+            **kwargs,
+        }
+        self._pipeline = _get_pipeline(self._helm_model_name, self._pipeline_kwargs)
+        if apply_chat_template is not None:
+            self._apply_chat_template = apply_chat_template
+        else:
+            # If the user did not explicitly configure whether the model is a chat model with `apply_chat_template` arg,
+            # auto-infer if the model is a chat model based on whether the tokenizer has a chat template.
+            # Note: Auto-inference is incorrect for some non-chat models that still have chat templates
+            # e.g. Qwen2, Qwen 2.5.
+            # For these models, the `apply_chat_template` arg should be explicitly set to false.
+            self._apply_chat_template = bool(self._pipeline.tokenizer.chat_template)
+            hwarn(
+                f"Automatically set `apply_chat_template` to {self._apply_chat_template} based on "
+                "whether the tokenizer has a chat template. "
+                "If this is incorrect, please explicitly set `apply_chat_template`."
+            )
+
+    def make_text_inputs(self, request: Request) -> Union[str, List[Dict[str, str]]]:
+        if request.prompt and request.messages:
+            raise NonRetriableException(f"More than one of `prompt` and `messages` was set in request: {request}")
+        # Chat model expects a list of messages as input
+        if self._apply_chat_template:
+            if request.messages:
+                return request.messages
+            else:
+                return [{"role": "user", "content": request.prompt}]
+        # Base non-chat model expects a string as input
+        else:
+            if request.messages:
+                raise NonRetriableException("Chat mesages not supported by non-chat model")
+            else:
+                return request.prompt
+
+    def make_request(self, request: Request) -> RequestResult:
+        """Make a request"""
+        if request.model != self._helm_model_name:
+            raise NonRetriableException(
+                f"This instance of HuggingFacePipelineClient has loaded model {self._helm_model_name} but the request was for model {request.model}"  # noqa: E501
+            )
+        completions: List[GeneratedOutput] = []
+        do_sample = request.temperature > 0.0
+        raw_request = {
+            "text_inputs": self.make_text_inputs(request),
+            "return_full_text": request.echo_prompt,
+            "temperature": request.temperature if do_sample else None,
+            "num_return_sequences": request.num_completions,
+            "max_new_tokens": request.max_tokens,
+            "top_p": request.top_p,
+            "top_k": request.top_k_per_token if do_sample else None,
+            "do_sample": do_sample,
+            "return_dict_in_generate": True,
+        }
+        if request.stop_sequences:
+            stop_sequence_ids = self._pipeline.tokenizer(
+                request.stop_sequences, return_token_type_ids=False, add_special_tokens=False
+            )
+            if len(stop_sequence_ids.input_ids) == 1 and len(stop_sequence_ids.input_ids[0]) == 1:
+                raw_request["eos_token_id"] = stop_sequence_ids.input_ids[0][0]
+            else:
+                raise NonRetriableException(
+                    "Multiple stop sequences and stop sequences of multiple tokens, are not yet supported by HuggingFacePipelineClient"  # noqa: E501
+                )
+
+        def do_it() -> Dict[str, Any]:
+            pipeline_outputs = self._pipeline(**raw_request)
+            return {"outputs": pipeline_outputs}
+
+        cache_key = CachingClient.make_cache_key(
+            {"pipeline_kwargs": self._pipeline_kwargs, **raw_request},
+            request,
+        )
+
+        response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+        for raw_output in response["outputs"]:
+            completions.append(GeneratedOutput(text=raw_output["generated_text"], logprob=0, tokens=[]))
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=response["request_time"],
+            request_datetime=response["request_datetime"],
+            completions=completions,
+            embedding=[],
+        )
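The new client is a thin wrapper around `transformers.pipeline`: one pipeline per HELM model is created lazily behind a lock and then called with the generation parameters above. Outside of HELM, the underlying call looks roughly like this (the checkpoint name is only an example):

import transformers

# Build a text-generation pipeline, as _get_pipeline does for each HELM model name.
pipe = transformers.pipeline(task="text-generation", model="gpt2")
outputs = pipe("The capital of France is", max_new_tokens=5, do_sample=False, return_full_text=False)
print(outputs[0]["generated_text"])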
helm/clients/image_generation/dalle_mini/model/configuration.py
CHANGED
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" DalleBart model configuration"""
+"""DalleBart model configuration"""
 import warnings

 from transformers.configuration_utils import PretrainedConfig
helm/clients/image_generation/dalle_mini/model/modeling.py
CHANGED
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" DalleBart model."""
+"""DalleBart model."""

 import math
 from functools import partial
helm/clients/openai_client.py
CHANGED
@@ -1,13 +1,16 @@
 # mypy: check_untyped_defs = False
 from dataclasses import replace
+import re
 from typing import Any, Dict, List, Optional, cast, Union, Callable

+from openai import OpenAIError
+
 from helm.benchmark.model_metadata_registry import is_vlm
 from helm.common import multimodal_request_utils
 from helm.common.cache import CacheConfig
-from helm.common.media_object import TEXT_TYPE, MultimediaObject
-from helm.common.request import ErrorFlags, wrap_request_time, Request, RequestResult, GeneratedOutput, Token
-from helm.common.hierarchical_logger import hlog
+from helm.common.media_object import TEXT_TYPE, MultimediaObject, MediaObject
+from helm.common.request import ErrorFlags, Thinking, wrap_request_time, Request, RequestResult, GeneratedOutput, Token
+from helm.common.hierarchical_logger import hlog, hwarn
 from helm.common.object_spec import get_class_by_name
 from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.common.tokenization_request import (
@@ -24,8 +27,13 @@ except ModuleNotFoundError as e:
     handle_module_not_found_error(e, ["openai"])


-class OpenAIClient(CachingClient):
-    END_OF_TEXT: str = "<|endoftext|>"
+class OpenAIClientUtils:
+    """Methods used by both the chat completions client and the responses API client"""
+
+    @classmethod
+    def is_reasoning_model(cls, model_engine: str) -> bool:
+        # All OpenAI reasoning models start "o[somenumber]", so we regexp for that to future proof things
+        return bool(re.match(r"^o\d+", model_engine))

     # Error OpenAI throws when the image in the prompt violates their content policy
     INAPPROPRIATE_IMAGE_ERROR: str = "Your input image may contain content that is not allowed by our safety system"
@@ -49,6 +57,56 @@ class OpenAIClient(CachingClient):
         "See https://labs.openai.com/policies/content-policy for more information."
     )

+    @classmethod
+    def handle_openai_error(cls, e: OpenAIError, request: Request):
+        if cls.INAPPROPRIATE_IMAGE_ERROR in str(e) or cls.INAPPROPRIATE_PROMPT_ERROR in str(e):
+            hwarn(f"Failed safety check: {str(request)}")
+            empty_completion = GeneratedOutput(
+                text="",
+                logprob=0,
+                tokens=[],
+                finish_reason={"reason": cls.CONTENT_POLICY_VIOLATED_FINISH_REASON},
+            )
+            return RequestResult(
+                success=True,
+                cached=False,
+                request_time=0,
+                completions=[empty_completion] * request.num_completions,
+                embedding=[],
+            )
+        elif cls.OPENAI_SERVER_ERROR in str(e):
+            # Handle these errors by returning an empty completion to unblock
+            hwarn(f"OpenAI server error for request: {str(request)}")
+            empty_completion = GeneratedOutput(
+                text="",
+                logprob=0,
+                tokens=[],
+                finish_reason={"reason": cls.OPENAI_SERVER_ERROR},
+            )
+            return RequestResult(
+                success=True,
+                cached=False,
+                request_time=0,
+                completions=[empty_completion] * request.num_completions,
+                embedding=[],
+            )
+        elif cls.INAPPROPRIATE_PROMPT_AZURE_ERROR in str(e) or cls.INAPPROPRIATE_PROMPT_MICROSOFT_ERROR in str(e):
+            return RequestResult(
+                success=False,
+                cached=False,
+                error="Content blocked by Azure's content management filter",
+                completions=[],
+                embedding=[],
+                error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
+            )
+
+        error: str = f"OpenAI error: {e}"
+        return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+
+
+class OpenAIClient(CachingClient):
+    END_OF_TEXT: str = "<|endoftext|>"
+
     def __init__(
         self,
         tokenizer: Tokenizer,
@@ -118,7 +176,7 @@ class OpenAIClient(CachingClient):
             embedding=embedding,
         )

-    def _make_chat_request(self, request: Request) -> RequestResult:
+    def _make_chat_raw_request(self, request: Request) -> Dict[str, Any]:
         messages: Optional[List[Dict[str, Union[str, Any]]]] = request.messages
         if (
             (request.prompt and request.messages)
@@ -137,7 +195,7 @@ class OpenAIClient(CachingClient):
             if request.messages[-1]["role"] != "user":
                 raise ValueError("Last message must have role 'user'")
             if request.prompt != "":
-                hlog("WARNING: Since message is set, prompt will be ignored")
+                hwarn("Since message is set, prompt will be ignored")
         else:
             # Convert prompt into a single message
             # For now, put the whole prompt in a single user message, and expect the response
@@ -223,7 +281,7 @@ class OpenAIClient(CachingClient):
         # Refer to the "Reasoning models" documentation further discussion of o1 model limitations:
         # https://platform.openai.com/docs/guides/reasoning
         model_engine: str = request.model_engine
-        if
+        if OpenAIClientUtils.is_reasoning_model(model_engine):
            # Avoid error:
            # "Unsupported parameter: 'max_tokens' is not supported with this model. Use 'max_completion_tokens' instead."  # noqa: E501
            # Note that openai>=1.45 is needed for this
@@ -241,8 +299,13 @@ class OpenAIClient(CachingClient):
             # 'code': 'unsupported_parameter'}}"
             raw_request.pop("temperature", None)

+            # The following parameters also happen to be unsupported by the o-series (code unsupported_parameter)
+            raw_request.pop("top_p", None)
+            raw_request.pop("frequency_penalty", None)
+            raw_request.pop("presence_penalty", None)
+
             if self.reasoning_effort:
-                raw_request["reasoning_effort"] =
+                raw_request["reasoning_effort"] = self.reasoning_effort
         elif is_vlm(request.model):
             # Avoid error:
             # "Invalid type for 'stop': expected an unsupported value, but got null instead."
@@ -258,6 +321,10 @@ class OpenAIClient(CachingClient):
             # OpenAI error: Error code: 400 - {'error': {'message': "[{'type': 'string_type', 'loc': ('body', 'stop', 'str'), 'msg': 'Input should be a valid string', 'input': None}, {'type': 'list_type', 'loc': ('body', 'stop', 'list[str]'), 'msg': 'Input should be a valid list', 'input': None}, {'type': 'list_type', 'loc': ('body', 'stop', 'list[list[int]]'), 'msg': 'Input should be a valid list', 'input': None}]", 'type': 'invalid_request_error', 'param': None, 'code': None}}  # noqa: 3501
             if raw_request["stop"] is None:
                 raw_request.pop("stop")
+        return raw_request
+
+    def _make_chat_request(self, request: Request) -> RequestResult:
+        raw_request = self._make_chat_raw_request(request)

         def do_it() -> Dict[str, Any]:
             return self.client.chat.completions.create(**raw_request).model_dump(mode="json")
@@ -266,49 +333,7 @@ class OpenAIClient(CachingClient):
             cache_key = self._get_cache_key(raw_request, request)
             response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
         except openai.OpenAIError as e:
-            if self.INAPPROPRIATE_IMAGE_ERROR in str(e) or self.INAPPROPRIATE_PROMPT_ERROR in str(e):
-                hlog(f"Failed safety check: {str(request)}")
-                empty_completion = GeneratedOutput(
-                    text="",
-                    logprob=0,
-                    tokens=[],
-                    finish_reason={"reason": self.CONTENT_POLICY_VIOLATED_FINISH_REASON},
-                )
-                return RequestResult(
-                    success=True,
-                    cached=False,
-                    request_time=0,
-                    completions=[empty_completion] * request.num_completions,
-                    embedding=[],
-                )
-            elif self.OPENAI_SERVER_ERROR in str(e):
-                # Handle these errors by returning an empty completion to unblock
-                hlog(f"OpenAI server error for request: {str(request)}")
-                empty_completion = GeneratedOutput(
-                    text="",
-                    logprob=0,
-                    tokens=[],
-                    finish_reason={"reason": self.OPENAI_SERVER_ERROR},
-                )
-                return RequestResult(
-                    success=True,
-                    cached=False,
-                    request_time=0,
-                    completions=[empty_completion] * request.num_completions,
-                    embedding=[],
-                )
-            elif self.INAPPROPRIATE_PROMPT_AZURE_ERROR in str(e) or self.INAPPROPRIATE_PROMPT_MICROSOFT_ERROR in str(e):
-                return RequestResult(
-                    success=False,
-                    cached=False,
-                    error="Content blocked by Azure's content management filter",
-                    completions=[],
-                    embedding=[],
-                    error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
-                )
-
-            error: str = f"OpenAI error: {e}"
-            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+            return OpenAIClientUtils.handle_openai_error(e, request)

         completions: List[GeneratedOutput] = []
         for raw_completion in response["choices"]:
@@ -338,11 +363,20 @@ class OpenAIClient(CachingClient):
             tokens: List[Token] = [
                 Token(text=cast(str, raw_token), logprob=0) for raw_token in tokenization_result.raw_tokens
             ]
+            # vLLM has a optional `reasoning_content` field in the message
+            # that is not in the standard OpenAI API.
+            # This field is also used by some model providers such as Grok.
+            thinking = (
+                Thinking(text=raw_completion["message"]["reasoning_content"])
+                if "reasoning_content" in raw_completion["message"]
+                else None
+            )
             completion = GeneratedOutput(
                 text=text,
                 logprob=0,  # OpenAI does not provide logprobs
                 tokens=tokens,
                 finish_reason={"reason": raw_completion["finish_reason"]},
+                thinking=thinking,
             )
             completions.append(truncate_sequence(completion, request))  # Truncate the text by stop sequences

@@ -459,7 +493,7 @@ class OpenAIClient(CachingClient):
     def make_request(self, request: Request) -> RequestResult:
         if request.embedding:
             return self._make_embedding_request(request)
-        elif "whisper" in request.model_engine:
+        elif "whisper" in request.model_engine or "transcribe" in request.model_engine:
             return self._make_transcription_request(request)
         else:
             return self._make_chat_request(request)
@@ -536,6 +570,18 @@ class OpenAITranscriptionThenCompletionClient(Client):
         # Now make the request to the completion model with just a text-only prompt and no audio
         # Use the same decoding parameters as the original request
         # Ensure to set multimodal_prompt to None so the request is treated as text-only.
-        return self._openai_client.make_request(
+        request_result: RequestResult = self._openai_client.make_request(
             replace(request, prompt=text_prompt, model=f"openai/{completion_model}", multimodal_prompt=None)
         )
+
+        # Also include the generated transcript to the request result
+        completions_with_transcript: List[GeneratedOutput] = [
+            replace(
+                completion,
+                multimodal_content=MultimediaObject(
+                    media_objects=[MediaObject(text=text_prompt, content_type="text/plain")]
+                ),
+            )
+            for completion in request_result.completions
+        ]
+        return replace(request_result, completions=completions_with_transcript)