crfm-helm 0.5.5__py3-none-any.whl → 0.5.7__py3-none-any.whl
This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of crfm-helm might be problematic.
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/METADATA +74 -53
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/RECORD +262 -182
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +3 -3
- helm/benchmark/adaptation/adapters/test_adapter.py +4 -4
- helm/benchmark/annotation/air_bench_annotator.py +2 -2
- helm/benchmark/annotation/bigcodebench_annotator.py +3 -3
- helm/benchmark/annotation/bird_sql_annotator.py +2 -2
- helm/benchmark/annotation/chw_care_plan_annotator.py +7 -12
- helm/benchmark/annotation/ehr_sql_annotator.py +2 -2
- helm/benchmark/annotation/helpdesk_call_summarization_annotator.py +7 -7
- helm/benchmark/annotation/live_qa_annotator.py +1 -1
- helm/benchmark/annotation/mimic_bhc_annotator.py +100 -0
- helm/benchmark/annotation/model_as_judge.py +12 -16
- helm/benchmark/annotation/omni_math_annotator.py +13 -14
- helm/benchmark/annotation/wildbench_annotator.py +9 -9
- helm/benchmark/executor.py +11 -12
- helm/benchmark/metrics/aci_bench_metrics.py +9 -29
- helm/benchmark/metrics/bias_word_lists.py +1 -1
- helm/benchmark/metrics/chw_care_plan_metrics.py +10 -30
- helm/benchmark/metrics/classification_metrics.py +3 -3
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/codeinsights_code_efficiency_metrics.py +186 -0
- helm/benchmark/metrics/codeinsights_code_evaluation_metrics.py +477 -0
- helm/benchmark/metrics/codeinsights_correct_code_metrics.py +366 -0
- helm/benchmark/metrics/codeinsights_edge_case_metrics.py +92 -0
- helm/benchmark/metrics/codeinsights_metric_specs.py +51 -0
- helm/benchmark/metrics/comet_metric.py +1 -1
- helm/benchmark/metrics/conv_fin_qa_calc_metrics.py +2 -2
- helm/benchmark/metrics/copyright_metrics.py +1 -1
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +1 -1
- helm/benchmark/metrics/dischargeme_metrics.py +9 -29
- helm/benchmark/metrics/efficiency_metrics.py +3 -3
- helm/benchmark/metrics/evaluate_reference_metrics.py +1 -1
- helm/benchmark/metrics/gpt4_audio_refusal_metrics.py +145 -0
- helm/benchmark/metrics/ifeval_metrics.py +2 -2
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +13 -2
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +1 -1
- helm/benchmark/metrics/kpi_edgar_metrics.py +121 -0
- helm/benchmark/metrics/llm_jury_metrics.py +46 -0
- helm/benchmark/metrics/lmkt_metric_specs.py +12 -0
- helm/benchmark/metrics/lmkt_metrics.py +47 -0
- helm/benchmark/metrics/med_dialog_metrics.py +9 -29
- helm/benchmark/metrics/medalign_metrics.py +9 -29
- helm/benchmark/metrics/medi_qa_metrics.py +9 -29
- helm/benchmark/metrics/medication_qa_metrics.py +10 -30
- helm/benchmark/metrics/melt_bias_metric.py +234 -0
- helm/benchmark/metrics/melt_bias_word_lists.py +1367 -0
- helm/benchmark/metrics/melt_metric_specs.py +43 -0
- helm/benchmark/metrics/melt_toxicity_metric.py +107 -0
- helm/benchmark/metrics/mental_health_metrics.py +9 -29
- helm/benchmark/metrics/metric_service.py +11 -11
- helm/benchmark/metrics/mimic_bhc_metrics.py +14 -0
- helm/benchmark/metrics/mimic_rrs_metrics.py +9 -29
- helm/benchmark/metrics/mtsamples_procedures_metrics.py +9 -29
- helm/benchmark/metrics/mtsamples_replicate_metrics.py +9 -29
- helm/benchmark/metrics/openai_mrcr_metrics.py +52 -0
- helm/benchmark/metrics/ruler_qa_metrics.py +34 -0
- helm/benchmark/metrics/starr_patient_instructions_metrics.py +9 -29
- helm/benchmark/metrics/summac/model_summac.py +2 -3
- helm/benchmark/metrics/summarization_metrics.py +2 -1
- helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +2 -2
- helm/benchmark/metrics/toxicity_metrics.py +2 -2
- helm/benchmark/metrics/unitxt_metrics.py +3 -4
- helm/benchmark/metrics/vision_language/emd_utils.py +4 -2
- helm/benchmark/metrics/vision_language/image_utils.py +2 -2
- helm/benchmark/model_deployment_registry.py +16 -26
- helm/benchmark/presentation/contamination.py +3 -3
- helm/benchmark/presentation/create_plots.py +43 -13
- helm/benchmark/presentation/run_display.py +13 -0
- helm/benchmark/presentation/schema.py +7 -1
- helm/benchmark/presentation/summarize.py +84 -61
- helm/benchmark/presentation/test_create_plots.py +4 -1
- helm/benchmark/reeval_run.py +3 -4
- helm/benchmark/reeval_runner.py +3 -3
- helm/benchmark/run.py +84 -73
- helm/benchmark/run_expander.py +12 -1
- helm/benchmark/run_spec_factory.py +7 -6
- helm/benchmark/run_specs/arabic_run_specs.py +73 -0
- helm/benchmark/run_specs/audio_run_specs.py +52 -8
- helm/benchmark/run_specs/bluex_run_specs.py +40 -0
- helm/benchmark/run_specs/classic_run_specs.py +0 -53
- helm/benchmark/run_specs/codeinsights_run_specs.py +192 -0
- helm/benchmark/run_specs/enterprise_run_specs.py +20 -0
- helm/benchmark/run_specs/experimental_run_specs.py +31 -1
- helm/benchmark/run_specs/healthqa_br_run_specs.py +40 -0
- helm/benchmark/run_specs/heim_run_specs.py +3 -1
- helm/benchmark/run_specs/lmkt_run_specs.py +144 -0
- helm/benchmark/run_specs/long_context_run_specs.py +114 -15
- helm/benchmark/run_specs/medhelm_run_specs.py +146 -41
- helm/benchmark/run_specs/melt_run_specs.py +783 -0
- helm/benchmark/run_specs/multilingual_run_specs.py +50 -0
- helm/benchmark/run_specs/speech_disorder_audio_run_specs.py +163 -0
- helm/benchmark/run_specs/vlm_run_specs.py +28 -0
- helm/benchmark/runner.py +5 -5
- helm/benchmark/scenarios/aci_bench_scenario.py +7 -1
- helm/benchmark/scenarios/alghafa_scenario.py +126 -0
- helm/benchmark/scenarios/arabic_mmlu_scenario.py +78 -0
- helm/benchmark/scenarios/aratrust_scenario.py +76 -0
- helm/benchmark/scenarios/audio_language/air_bench_chat_scenario.py +3 -1
- helm/benchmark/scenarios/audio_language/air_bench_foundation_scenario.py +5 -5
- helm/benchmark/scenarios/audio_language/casual_conversations2_scenario.py +1 -1
- helm/benchmark/scenarios/audio_language/corebench_scenario.py +77 -0
- helm/benchmark/scenarios/audio_language/mustard_scenario.py +1 -1
- helm/benchmark/scenarios/audio_language/ultra_suite_asr_classification_scenario.py +104 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_asr_transcription_scenario.py +99 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_classification_scenario.py +118 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_disorder_breakdown_scenario.py +86 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_disorder_symptoms_scenario.py +117 -0
- helm/benchmark/scenarios/audio_language/vocal_sound_scenario.py +15 -1
- helm/benchmark/scenarios/audio_language/voxceleb2_scenario.py +1 -2
- helm/benchmark/scenarios/autobencher_capabilities_scenario.py +2 -2
- helm/benchmark/scenarios/bluex_scenario.py +66 -0
- helm/benchmark/scenarios/chw_care_plan_scenario.py +14 -13
- helm/benchmark/scenarios/clear_scenario.py +11 -7
- helm/benchmark/scenarios/cleva_scenario.py +1 -1
- helm/benchmark/scenarios/codeinsights_code_efficiency_scenario.py +197 -0
- helm/benchmark/scenarios/codeinsights_correct_code_scenario.py +78 -0
- helm/benchmark/scenarios/codeinsights_edge_case_scenario.py +192 -0
- helm/benchmark/scenarios/codeinsights_student_coding_scenario.py +162 -0
- helm/benchmark/scenarios/codeinsights_student_mistake_scenario.py +188 -0
- helm/benchmark/scenarios/dischargeme_scenario.py +36 -21
- helm/benchmark/scenarios/ehr_sql_scenario.py +7 -1
- helm/benchmark/scenarios/ehrshot_scenario.py +28 -55
- helm/benchmark/scenarios/exams_multilingual_scenario.py +115 -0
- helm/benchmark/scenarios/grammar.py +2 -2
- helm/benchmark/scenarios/headqa_scenario.py +6 -1
- helm/benchmark/scenarios/healthqa_br_scenario.py +80 -0
- helm/benchmark/scenarios/infinite_bench_en_mc_scenario.py +90 -0
- helm/benchmark/scenarios/infinite_bench_en_qa_scenario.py +85 -0
- helm/benchmark/scenarios/{infinite_bench_sum_scenario.py → infinite_bench_en_sum_scenario.py} +10 -13
- helm/benchmark/scenarios/kpi_edgar_scenario.py +151 -0
- helm/benchmark/scenarios/lmkt_scenarios.py +288 -0
- helm/benchmark/scenarios/math_scenario.py +21 -20
- helm/benchmark/scenarios/med_dialog_scenario.py +6 -1
- helm/benchmark/scenarios/medalign_scenario.py +9 -3
- helm/benchmark/scenarios/medalign_scenario_helper.py +27 -130
- helm/benchmark/scenarios/medbullets_scenario.py +7 -2
- helm/benchmark/scenarios/medcalc_bench_scenario.py +4 -2
- helm/benchmark/scenarios/medec_scenario.py +6 -1
- helm/benchmark/scenarios/medhallu_scenario.py +7 -1
- helm/benchmark/scenarios/medi_qa_scenario.py +10 -4
- helm/benchmark/scenarios/medication_qa_scenario.py +7 -1
- helm/benchmark/scenarios/melt_ir_scenario.py +171 -0
- helm/benchmark/scenarios/melt_knowledge_scenario.py +246 -0
- helm/benchmark/scenarios/melt_lm_scenarios.py +252 -0
- helm/benchmark/scenarios/melt_scenarios.py +793 -0
- helm/benchmark/scenarios/melt_srn_scenario.py +342 -0
- helm/benchmark/scenarios/melt_synthetic_reasoning_scenario.py +222 -0
- helm/benchmark/scenarios/melt_translation_scenario.py +152 -0
- helm/benchmark/scenarios/mental_health_scenario.py +16 -5
- helm/benchmark/scenarios/mimic_bhc_scenario.py +13 -8
- helm/benchmark/scenarios/mimic_rrs_scenario.py +17 -8
- helm/benchmark/scenarios/mimiciv_billing_code_scenario.py +14 -8
- helm/benchmark/scenarios/mmlu_pro_scenario.py +1 -1
- helm/benchmark/scenarios/mmmlu_scenario.py +85 -0
- helm/benchmark/scenarios/mtsamples_procedures_scenario.py +5 -2
- helm/benchmark/scenarios/mtsamples_replicate_scenario.py +3 -2
- helm/benchmark/scenarios/n2c2_ct_matching_scenario.py +11 -5
- helm/benchmark/scenarios/openai_mrcr_scenario.py +79 -0
- helm/benchmark/scenarios/pubmed_qa_scenario.py +6 -1
- helm/benchmark/scenarios/race_based_med_scenario.py +18 -8
- helm/benchmark/scenarios/ruler_qa_scenario_helper.py +2 -2
- helm/benchmark/scenarios/ruler_qa_scenarios.py +2 -2
- helm/benchmark/scenarios/seahelm_scenario.py +2 -2
- helm/benchmark/scenarios/shc_bmt_scenario.py +12 -6
- helm/benchmark/scenarios/shc_cdi_scenario.py +11 -6
- helm/benchmark/scenarios/shc_conf_scenario.py +12 -6
- helm/benchmark/scenarios/shc_ent_scenario.py +11 -6
- helm/benchmark/scenarios/shc_gip_scenario.py +13 -5
- helm/benchmark/scenarios/shc_privacy_scenario.py +78 -0
- helm/benchmark/scenarios/shc_proxy_scenario.py +76 -0
- helm/benchmark/scenarios/shc_ptbm_scenario.py +12 -7
- helm/benchmark/scenarios/shc_sei_scenario.py +12 -7
- helm/benchmark/scenarios/shc_sequoia_scenario.py +13 -5
- helm/benchmark/scenarios/starr_patient_instructions_scenario.py +15 -8
- helm/benchmark/scenarios/test_alghafa_scenario.py +29 -0
- helm/benchmark/scenarios/test_aratrust_scenario.py +21 -0
- helm/benchmark/scenarios/test_bluex_scenario.py +59 -0
- helm/benchmark/scenarios/test_exams_multilingual_scenario.py +29 -0
- helm/benchmark/scenarios/test_healtha_br_scenario.py +57 -0
- helm/benchmark/scenarios/test_infinite_bench_en_qa_scenario.py +18 -0
- helm/benchmark/scenarios/test_infinite_bench_en_sum_scenario.py +31 -0
- helm/benchmark/scenarios/truthful_qa_scenario.py +2 -1
- helm/benchmark/scenarios/vision_language/msr_vtt_scenario.py +75 -0
- helm/benchmark/server.py +2 -1
- helm/benchmark/slurm_jobs.py +1 -2
- helm/benchmark/slurm_runner.py +8 -1
- helm/benchmark/static/schema_arabic.yaml +228 -0
- helm/benchmark/static/schema_audio.yaml +60 -49
- helm/benchmark/static/schema_classic.yaml +0 -17
- helm/benchmark/static/schema_enterprise.yaml +21 -0
- helm/benchmark/static/schema_long_context.yaml +81 -20
- helm/benchmark/static/schema_medhelm.yaml +272 -213
- helm/benchmark/static/schema_melt.yaml +1257 -0
- helm/benchmark/static/schema_slphelm.yaml +162 -0
- helm/benchmark/static/schema_vhelm.yaml +26 -26
- helm/benchmark/static/schema_video.yaml +219 -0
- helm/benchmark/static_build/assets/index-b9779128.css +1 -0
- helm/benchmark/static_build/assets/index-e439d5e1.js +10 -0
- helm/benchmark/static_build/assets/medhelm-overview-eac29843.png +0 -0
- helm/benchmark/static_build/assets/{tremor-9cefc3c5.js → tremor-38a10867.js} +1 -1
- helm/benchmark/static_build/index.html +4 -4
- helm/benchmark/window_services/encoder_decoder_window_service.py +3 -3
- helm/benchmark/window_services/image_generation/clip_window_service.py +1 -3
- helm/benchmark/window_services/test_utils.py +3 -4
- helm/benchmark/window_services/tokenizer_service.py +7 -8
- helm/clients/anthropic_client.py +69 -29
- helm/clients/audio_language/diva_llama_client.py +4 -2
- helm/clients/audio_language/qwen2_5_omni_client.py +209 -0
- helm/clients/audio_language/qwen2_audiolm_client.py +8 -6
- helm/clients/audio_language/qwen_audiolm_client.py +4 -2
- helm/clients/audio_language/test.py +62 -0
- helm/clients/bedrock_client.py +3 -1
- helm/clients/client.py +7 -7
- helm/clients/grok_client.py +36 -0
- helm/clients/huggingface_client.py +42 -3
- helm/clients/huggingface_pipeline_client.py +138 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +1 -1
- helm/clients/image_generation/dalle_mini/model/modeling.py +1 -1
- helm/clients/image_generation/dalle_mini/model/processor.py +1 -1
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +1 -1
- helm/clients/openai_client.py +102 -55
- helm/clients/openai_responses_client.py +176 -0
- helm/clients/palmyra_client.py +2 -5
- helm/clients/reka_client.py +2 -2
- helm/clients/test_huggingface_client.py +3 -3
- helm/clients/together_client.py +31 -6
- helm/clients/vertexai_client.py +17 -9
- helm/clients/vision_language/huggingface_vision2seq_client.py +6 -4
- helm/clients/vision_language/huggingface_vlm_client.py +2 -2
- helm/clients/vision_language/idefics_client.py +6 -2
- helm/clients/vision_language/paligemma_client.py +2 -2
- helm/clients/vision_language/qwen2_vlm_client.py +66 -53
- helm/clients/vision_language/qwen_vlm_client.py +7 -5
- helm/clients/vllm_client.py +43 -7
- helm/clients/vllm_granite_thinking_client.py +56 -0
- helm/clients/writer_client.py +102 -0
- helm/common/context.py +80 -0
- helm/common/credentials_utils.py +5 -5
- helm/common/critique_request.py +0 -1
- helm/common/general.py +9 -2
- helm/common/hierarchical_logger.py +104 -12
- helm/common/local_context.py +140 -0
- helm/common/object_spec.py +23 -8
- helm/common/remote_context.py +61 -0
- helm/common/request.py +8 -0
- helm/common/test_logging.py +94 -0
- helm/config/model_deployments.yaml +995 -45
- helm/config/model_metadata.yaml +780 -59
- helm/config/tokenizer_configs.yaml +224 -3
- helm/proxy/cli.py +4 -2
- helm/proxy/critique/mechanical_turk_utils.py +1 -1
- helm/proxy/retry.py +5 -0
- helm/proxy/services/server_service.py +21 -85
- helm/tokenizers/grok_tokenizer.py +55 -0
- helm/tokenizers/huggingface_tokenizer.py +1 -1
- helm/tokenizers/test_grok_tokenizer.py +33 -0
- helm/benchmark/metrics/numeracy_metrics.py +0 -72
- helm/benchmark/metrics/test_numeracy_metrics.py +0 -95
- helm/benchmark/scenarios/numeracy_scenario.py +0 -793
- helm/benchmark/scenarios/test_infinite_bench_sum_scenario.py +0 -46
- helm/benchmark/static_build/assets/index-262903c1.js +0 -10
- helm/benchmark/static_build/assets/index-42060d71.css +0 -1
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/licenses/LICENSE +0 -0
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/top_level.txt +0 -0
- /helm/benchmark/static_build/assets/{medhelm-overview-3ddfcd65.png → medhelm-v1-overview-3ddfcd65.png} +0 -0
helm/clients/vision_language/huggingface_vision2seq_client.py
CHANGED

@@ -95,8 +95,8 @@ class HuggingFaceVision2SeqClient(CachingClient):
 
         def do_it() -> Dict[str, Any]:
             messages = [{"role": "user", "content": multimodal_prompt}]
-            prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
-            inputs = processor(
+            prompt = processor.apply_chat_template(messages, add_generation_prompt=True)  # type: ignore
+            inputs = processor(  # type: ignore
                 text=[prompt] * request.num_completions,
                 images=[
                     [load_image(image_path) for image_path in image_paths]
@@ -107,8 +107,10 @@ class HuggingFaceVision2SeqClient(CachingClient):
             inputs = {k: v.to(self._device) for k, v in inputs.items()}
 
             # Generate
-            generated_ids = model.generate(**inputs, **generation_args)
-            generated_texts: List[str] = processor.batch_decode(generated_ids, skip_special_tokens=True)
+            generated_ids = model.generate(**inputs, **generation_args)  # type: ignore
+            generated_texts: List[str] = processor.batch_decode(  # type: ignore
+                generated_ids, skip_special_tokens=True
+            )
             return {"output": generated_texts}
 
         # Include the prompt and model name in the cache key
helm/clients/vision_language/huggingface_vlm_client.py
CHANGED

@@ -50,7 +50,7 @@ class HuggingFaceVLMClient(CachingClient):
         with self._models_lock:
             model_id: str = self._models_aliases.get(model_name, model_name)
             if model_id not in self._models:
-                self._models[model_id] = pipeline("image-to-text", model=model_id, device_map="auto")
+                self._models[model_id] = pipeline("image-to-text", model=model_id, device_map="auto")  # type: ignore
             return self._models[model_id]
 
     def make_request(self, request: Request) -> RequestResult:
@@ -80,7 +80,7 @@ class HuggingFaceVLMClient(CachingClient):
 
         def do_it() -> Dict[str, Any]:
             model: ImageToTextPipeline = self._get_model(request.model_deployment)
-            outputs = model(image, prompt=prompt, generate_kwargs=generation_args)
+            outputs = model(image, prompt=prompt, generate_kwargs=generation_args)  # type: ignore
             return outputs[0]
 
         cache_key = CachingClient.make_cache_key(
helm/clients/vision_language/idefics_client.py
CHANGED

@@ -89,14 +89,18 @@ class IDEFICSClient(CachingClient):
         input_args: Dict[str, Union[str, bool]] = {"return_tensors": "pt"}
         generation_args = {
             "max_new_tokens": request.max_tokens,
-            "bad_words_ids": processor.tokenizer(self.BAD_WORD_TOKENS, add_special_tokens=False).input_ids,
+            "bad_words_ids": processor.tokenizer(  # type: ignore
+                self.BAD_WORD_TOKENS, add_special_tokens=False
+            ).input_ids,
         }
 
         if self.END_OF_UTTERANCE_TOKEN in request.stop_sequences:
             # Following https://huggingface.co/HuggingFaceM4/idefics-80b-instruct,
             # specify <end_of_utterance> as an exit condition.
             input_args["add_end_of_utterance_token"] = False
-            exit_condition = processor.tokenizer(self.END_OF_UTTERANCE_TOKEN, add_special_tokens=False).input_ids
+            exit_condition = processor.tokenizer(  # type: ignore
+                self.END_OF_UTTERANCE_TOKEN, add_special_tokens=False
+            ).input_ids
             generation_args["eos_token_id"] = exit_condition
 
         multimodal_prompt: List[Union[str, Image.Image]] = []
helm/clients/vision_language/paligemma_client.py
CHANGED

@@ -93,7 +93,7 @@ class PaliGemmaClient(CachingClient):
             else:
                 raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
         prompt_text: str = "\n".join(prompt_pieces)
-        model_inputs = processor(text=prompt_text, images=images, return_tensors="pt").to(self._device)
+        model_inputs = processor(text=prompt_text, images=images, return_tensors="pt").to(self._device)  # type: ignore
         input_len = model_inputs["input_ids"].shape[-1]
 
         completions: List[GeneratedOutput] = []
@@ -109,7 +109,7 @@ class PaliGemmaClient(CachingClient):
             )[0]
             if not request.echo_prompt:
                 generation = generation[input_len:]
-            decoded = processor.decode(generation, skip_special_tokens=True)
+            decoded = processor.decode(generation, skip_special_tokens=True)  # type: ignore
             return {"output": decoded}
 
         # Include the prompt and model name in the cache key
helm/clients/vision_language/qwen2_vlm_client.py
CHANGED

@@ -2,7 +2,7 @@ from threading import Lock
 from typing import Any, Dict, List, Optional
 from dataclasses import dataclass
 
-from transformers import
+from transformers import AutoProcessor
 from qwen_vl_utils import process_vision_info
 import torch
 
@@ -16,15 +16,20 @@ from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt
 
 
 @dataclass(frozen=True)
-class
-    model:
+class LoadedModelProcessor:
+    model: Any
     processor: AutoProcessor
 
 
+# Global cache for all models
 _models_lock: Lock = Lock()
-_models: Dict[str, Optional[
+_models: Dict[str, Optional[LoadedModelProcessor]] = {
     "Qwen/Qwen2-VL-7B-Instruct": None,
     "Qwen/Qwen2-VL-72B-Instruct": None,
+    "Qwen/Qwen2.5-VL-3B-Instruct": None,
+    "Qwen/Qwen2.5-VL-7B-Instruct": None,
+    "Qwen/Qwen2.5-VL-32B-Instruct": None,
+    "Qwen/Qwen2.5-VL-72B-Instruct": None,
 }
 
 
@@ -38,50 +43,52 @@ class Qwen2VLMClient(CachingClient):
             return "Qwen/Qwen2-VL-7B-Instruct"
         elif helm_model_name == "qwen2-vl-72b-instruct":
             return "Qwen/Qwen2-VL-72B-Instruct"
+        elif helm_model_name == "qwen2.5-vl-3b-instruct":
+            return "Qwen/Qwen2.5-VL-3B-Instruct"
+        elif helm_model_name == "qwen2.5-vl-7b-instruct":
+            return "Qwen/Qwen2.5-VL-7B-Instruct"
+        elif helm_model_name == "qwen2.5-vl-32b-instruct":
+            return "Qwen/Qwen2.5-VL-32B-Instruct"
+        elif helm_model_name == "qwen2.5-vl-72b-instruct":
+            return "Qwen/Qwen2.5-VL-72B-Instruct"
         else:
            raise ValueError(f"Unhandled model name: {helm_model_name}")
 
-    def _get_model(self, helm_model_name: str) ->
-
-        global _models
+    def _get_model(self, helm_model_name: str) -> LoadedModelProcessor:
+        from transformers import Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration
 
-
+        global _models_lock, _models
 
+        model_name = self._get_model_name(helm_model_name)
        with _models_lock:
            loaded = _models[model_name]
            if loaded is None:
                hlog(f"Loading model {model_name} and caching in memory...")
-                #
-
-
-
-
-
-
+                # Use different loading routines depending on whether it's Qwen2.5 or Qwen2.
+                if "2.5" in model_name:
+                    # Qwen2.5: by default use torch_dtype="auto". You can enable flash_attention_2 if desired.
+                    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                        model_name,
+                        torch_dtype=torch.bfloat16,
+                        device_map="auto",
+                        attn_implementation="flash_attention_2",
+                    ).eval()
+                else:
+                    model = Qwen2VLForConditionalGeneration.from_pretrained(
+                        model_name,
+                        torch_dtype=torch.bfloat16,
+                        device_map="auto",
+                        attn_implementation="flash_attention_2",
+                    ).eval()
                processor = AutoProcessor.from_pretrained(model_name)
-                loaded =
+                loaded = LoadedModelProcessor(model=model, processor=processor)
                _models[model_name] = loaded
-
        return loaded
 
    def make_request(self, request: Request) -> RequestResult:
        assert request.multimodal_prompt is not None, "Multimodal prompt is required"
-
-
-        processor = loaded.processor
-
-        # Build Qwen2 messages
-        # We assume all media objects go into a single "user" message:
-        # messages = [
-        #     {
-        #         "role": "user",
-        #         "content": [
-        #             {"type": "image", "image": "file:///path/to/image1.jpg"},
-        #             {"type": "image", "image": "file:///path/to/image2.jpg"},
-        #             {"type": "text", "text": "Describe these images."}
-        #         ]
-        #     }
-        # ]
+
+        # Build messages by collating all media objects into a single "user" message.
        message_content = []
        for media_object in request.multimodal_prompt.media_objects:
            if media_object.is_type("image") and media_object.location:
@@ -95,18 +102,6 @@ class Qwen2VLMClient(CachingClient):
 
        messages = [{"role": "user", "content": message_content}]
 
-        # Prepare text and vision inputs
-        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        image_inputs, video_inputs = process_vision_info(messages)
-
-        inputs = processor(
-            text=[text],
-            images=image_inputs,
-            videos=video_inputs,
-            padding=True,
-            return_tensors="pt",
-        ).to(self._device)
-
        generation_args = {
            "max_new_tokens": request.max_tokens,
        }
@@ -116,23 +111,38 @@ class Qwen2VLMClient(CachingClient):
        request_datetime: Optional[int] = None
        all_cached: bool = True
 
-        with htrack_block(f"Generating for prompt: {text}"):
+        with htrack_block(f"Generating for prompt: {request.multimodal_prompt.text}"):
            for completion_index in range(request.num_completions):
                try:
 
                    def do_it() -> Dict[str, Any]:
+                        loaded = self._get_model(request.model_engine)
+                        model = loaded.model
+                        processor = loaded.processor
+
+                        # Prepare text and vision inputs.
+                        text = processor.apply_chat_template(  # type: ignore
+                            messages, tokenize=False, add_generation_prompt=True
+                        )
+                        image_inputs, video_inputs = process_vision_info(messages)
+                        inputs = processor(  # type: ignore
+                            text=[text],
+                            images=image_inputs,
+                            videos=video_inputs,
+                            padding=True,
+                            return_tensors="pt",
+                        ).to(self._device)
+
                        generated_ids = model.generate(**inputs, **generation_args)
-                        # Remove the input prefix from outputs
+                        # Remove the input prefix from outputs.
                        generated_ids_trimmed = [
                            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
                        ]
-                        output_text = processor.batch_decode(
+                        output_text = processor.batch_decode(  # type: ignore
                            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
                        )
-                        # There's only one batch element
-                        completion = output_text[0]
                        # For simplicity, we split tokens by whitespace.
-
+                        completion = output_text[0]
                        tokens = completion.split()
                        return {"output": (completion, tokens)}
 
@@ -148,7 +158,11 @@ class Qwen2VLMClient(CachingClient):
                    result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
                except RuntimeError as model_error:
                    return RequestResult(
-                        success=False, cached=False, error=str(model_error), completions=[], embedding=[]
+                        success=False,
+                        cached=False,
+                        error=str(model_error),
+                        completions=[],
+                        embedding=[],
                    )
 
                text_out, tokens = result["output"]
@@ -160,7 +174,6 @@ class Qwen2VLMClient(CachingClient):
                    )
                )
                hlog(f"Generated: {text_out}")
-
                request_time += result["request_time"]
                request_datetime = request_datetime or result.get("request_datetime")
                all_cached = all_cached and cached
helm/clients/vision_language/qwen_vlm_client.py
CHANGED

@@ -115,14 +115,16 @@ class QwenVLMClient(CachingClient):
 
        def do_it() -> Dict[str, Any]:
            if request.model_engine == "qwen-vl-chat":
-                completion, _ = model.chat(tokenizer, query=tokenizer.from_list_format(query), history=None)
+                completion, _ = model.chat(  # type: ignore
+                    tokenizer, query=tokenizer.from_list_format(query), history=None  # type: ignore
+                )
            else:
-                inputs = tokenizer(tokenizer.from_list_format(query), return_tensors="pt")
+                inputs = tokenizer(tokenizer.from_list_format(query), return_tensors="pt")  # type: ignore
                inputs = inputs.to(self._device)
-                pred = model.generate(**inputs, **generation_args)
-                completion = tokenizer.decode(pred.cpu()[0], skip_special_tokens=False)
+                pred = model.generate(**inputs, **generation_args)  # type: ignore
+                completion = tokenizer.decode(pred.cpu()[0], skip_special_tokens=False)  # type: ignore
 
-            tokens: List[str] = tokenizer.tokenize(completion)
+            tokens: List[str] = tokenizer.tokenize(completion)  # type: ignore
            return {"output": (completion, tokens)}
 
        # Include the prompt and model name in the cache key
helm/clients/vllm_client.py
CHANGED
@@ -2,7 +2,7 @@ from typing import Any, Dict, Optional
 
 from helm.common.cache import CacheConfig
 from helm.common.request import Request
-from helm.clients.openai_client import OpenAILegacyCompletionsClient
+from helm.clients.openai_client import OpenAIClient, OpenAILegacyCompletionsClient
 from helm.tokenizers.tokenizer import Tokenizer
 
 
@@ -19,6 +19,8 @@ class VLLMClient(OpenAILegacyCompletionsClient):
        tokenizer_name: str,
        cache_config: CacheConfig,
        base_url: Optional[str] = None,
+        vllm_model_name: Optional[str] = None,
+        **kwargs,
    ):
        super().__init__(
            tokenizer=tokenizer,
@@ -27,18 +29,52 @@ class VLLMClient(OpenAILegacyCompletionsClient):
            api_key="EMPTY",
            org_id=None,
            base_url=base_url,
+            openai_model_name=vllm_model_name,
+            **kwargs,
        )
        self.tokenizer = tokenizer
        self.tokenizer_name = tokenizer_name
-
-    def _get_model_for_request(self, request: Request) -> str:
-        # The `model` parameter for vLLM should be the whole model name including the creator organization,
-        # unlike OpenAI which only uses the model engine.
-        return request.model
+        self.vllm_model_name = vllm_model_name
 
    def _to_raw_completion_request(self, request: Request) -> Dict[str, Any]:
        raw_request = super()._to_raw_completion_request(request)
        # This avoids the error: best_of must be 1 when using greedy sampling
-        if
+        if (
+            "temperature" in raw_request
+            and raw_request["temperature"] == 0.0
+            and "best_of" in raw_request
+            and raw_request["best_of"] > 1
+        ):
            raw_request["best_of"] = 1
        return raw_request
+
+
+class VLLMChatClient(OpenAIClient):
+    """Sends request to a vLLM server using the OpenAI-compatible API.
+
+    Only uses the Chat Completions API.
+
+    See: https://docs.vllm.ai/en/latest/getting_started/quickstart.html#openai-compatible-server"""
+
+    def __init__(
+        self,
+        tokenizer: Tokenizer,
+        tokenizer_name: str,
+        cache_config: CacheConfig,
+        base_url: Optional[str] = None,
+        vllm_model_name: Optional[str] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            tokenizer=tokenizer,
+            tokenizer_name=tokenizer_name,
+            cache_config=cache_config,
+            api_key="EMPTY",
+            org_id=None,
+            base_url=base_url,
+            openai_model_name=vllm_model_name,
+            **kwargs,
+        )
+        self.tokenizer = tokenizer
+        self.tokenizer_name = tokenizer_name
+        self.vllm_model_name = vllm_model_name
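The greedy-sampling guard added to `VLLMClient._to_raw_completion_request` is easiest to see in isolation. The following is a minimal standalone sketch (not part of the package; the helper name is illustrative) that mirrors the same check on a plain request dictionary:

# Standalone sketch: vLLM rejects best_of > 1 when temperature == 0.0
# (greedy sampling), so the request is clamped before being sent.
from typing import Any, Dict

def clamp_best_of_for_greedy(raw_request: Dict[str, Any]) -> Dict[str, Any]:
    if (
        "temperature" in raw_request
        and raw_request["temperature"] == 0.0
        and "best_of" in raw_request
        and raw_request["best_of"] > 1
    ):
        raw_request = {**raw_request, "best_of": 1}
    return raw_request

print(clamp_best_of_for_greedy({"temperature": 0.0, "best_of": 4, "prompt": "Hi"}))
# {'temperature': 0.0, 'best_of': 1, 'prompt': 'Hi'}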
helm/clients/vllm_granite_thinking_client.py
ADDED

@@ -0,0 +1,56 @@
+from dataclasses import replace
+import re
+from typing import Any, Dict, List, Tuple
+
+from helm.clients.vllm_client import VLLMChatClient
+from helm.common.request import GeneratedOutput, Request, RequestResult, Thinking
+
+
+class VLLMGraniteThinkingClient(VLLMChatClient):
+    """Sends request to a Granite model on vLLM server with thinking enabled.
+
+    From vLLM documentation at
+    https://docs.vllm.ai/en/v0.9.1/features/reasoning_outputs.html
+
+    IBM Granite 3.2 reasoning is disabled by default;
+    to enable it, you must also pass thinking=True in your chat_template_kwargs.
+    """
+
+    def _make_chat_raw_request(self, request: Request) -> Dict[str, Any]:
+        raw_request = super()._make_chat_raw_request(request)
+        raw_request["extra_body"] = {"chat_template_kwargs": {"thinking": True}}
+        return raw_request
+
+    def _parse_thinking(self, input: str) -> Tuple[str, str]:
+        """Return a tuple of thinking text and output text."""
+        match = re.match(r"<think>(.*)</think>\s*<response>(.*)</response>", input, re.DOTALL)
+        if match:
+            return (match.group(1), match.group(2))
+
+        match = re.match(r"<think>(.*)</think>\s*<response>(.*)", input, re.DOTALL)
+        if match:
+            return (match.group(1), match.group(2))
+
+        match = re.match(r"<think>(.*)</think>\s*", input, re.DOTALL)
+        if match:
+            return (match.group(1), "")
+
+        match = re.match(r"<think>(.*)", input, re.DOTALL)
+        if match:
+            return (match.group(1), "")
+
+        return (input, "")
+
+    def _make_chat_request(self, request: Request) -> RequestResult:
+        request_result = super()._make_chat_request(request)
+        modified_completions: List[GeneratedOutput] = []
+        for completion in request_result.completions:
+            thinking, modified_text = self._parse_thinking(completion.text)
+            modified_completions.append(
+                replace(
+                    completion,
+                    text=modified_text,
+                    thinking=Thinking(text=thinking),
+                )
+            )
+        return replace(request_result, completions=modified_completions)
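Because the tag-parsing cascade above is self-contained, here is a hedged standalone sketch of the same <think>/<response> handling outside HELM; the `parse_thinking` helper and the sample strings are illustrative only:

# Standalone sketch: reproduces the fallback cascade of
# VLLMGraniteThinkingClient._parse_thinking for illustration.
import re
from typing import Tuple

def parse_thinking(text: str) -> Tuple[str, str]:
    # Try progressively more truncated shapes: closed think + closed response,
    # open response, closed think only, open think only.
    patterns = [
        (r"<think>(.*)</think>\s*<response>(.*)</response>", True),
        (r"<think>(.*)</think>\s*<response>(.*)", True),
        (r"<think>(.*)</think>\s*", False),
        (r"<think>(.*)", False),
    ]
    for pattern, has_response in patterns:
        match = re.match(pattern, text, re.DOTALL)
        if match:
            return (match.group(1), match.group(2) if has_response else "")
    return (text, "")

print(parse_thinking("<think>add 2 and 2</think>\n<response>4</response>"))  # ('add 2 and 2', '4')
print(parse_thinking("<think>cut off mid-thought"))  # ('cut off mid-thought', '')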
helm/clients/writer_client.py
ADDED

@@ -0,0 +1,102 @@
+from typing import Any, Dict, List, Mapping, Optional
+
+from helm.clients.client import CachingClient
+from helm.common.cache import CacheConfig
+from helm.common.optional_dependencies import handle_module_not_found_error
+from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput, Token
+
+try:
+    from writerai import Writer
+    from writerai.types.chat_completion import ChatCompletion
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["openai"])
+
+
+class WriterClient(CachingClient):
+    def __init__(self, cache_config: CacheConfig, api_key: Optional[str] = None):
+        super().__init__(cache_config=cache_config)
+        self._writer_client = Writer(api_key=api_key)
+
+    def _get_messages_from_request(self, request: Request) -> List[Dict]:
+        if request.prompt and request.messages:
+            raise ValueError(f"Only one of `prompt` and `messages` may be set in request: {request}")
+        if request.multimodal_prompt:
+            raise ValueError("`multimodal_prompt` is not supported by WriterClient")
+        if request.messages:
+            return [{"role": message["role"], "content": message["content"]} for message in request.messages]
+        else:
+            return [{"role": "user", "content": request.prompt}]
+
+    def _convert_chat_completion_to_generated_outputs(
+        self, chat_completion: ChatCompletion, request: Request
+    ) -> List[GeneratedOutput]:
+        generated_outputs: List[GeneratedOutput] = []
+        for choice in chat_completion.choices:
+            raw_completion_content = choice.message.content
+            # The Writer chat completion API doesn't support echo.
+            # If `echo_prompt` is true, combine the prompt and completion.
+            text: str = request.prompt + raw_completion_content if request.echo_prompt else raw_completion_content
+            tokens: List[Token] = []
+            if choice.logprobs and choice.logprobs.content:
+                tokens = [
+                    Token(text=choice_token.token, logprob=choice_token.logprob)
+                    for choice_token in choice.logprobs.content
+                ]
+            generated_output = GeneratedOutput(
+                text=text,
+                logprob=sum(token.logprob for token in tokens) if tokens else 0.0,
+                tokens=tokens,
+                finish_reason={"reason": choice.finish_reason},
+            )
+            generated_outputs.append(generated_output)
+        return generated_outputs
+
+    def _convert_request_to_raw_request(self, request: Request) -> Dict:
+        raw_request = {
+            "messages": self._get_messages_from_request(request),
+            "model": request.model.split("/")[-1],
+            "logprobs": bool(request.top_k_per_token),
+            "max_tokens": request.max_tokens,
+            "n": request.num_completions,
+            "stop": request.stop_sequences,
+            "temperature": request.temperature,
+            "top_p": request.top_p,
+        }
+        if request.response_format and request.response_format.json_schema:
+            raw_request["response_format"] = {
+                "type": "json_schema",
+                "json_schema": {
+                    "schema": request.response_format.json_schema,
+                },
+            }
+        return raw_request
+
+    def make_request(self, request: Request) -> RequestResult:
+        raw_request = self._convert_request_to_raw_request(request)
+        cache_key: Mapping = CachingClient.make_cache_key(raw_request, request)
+
+        def do_it() -> Dict[Any, Any]:
+            return self._writer_client.chat.chat(**raw_request).model_dump()
+
+        try:
+            raw_response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+            chat_completion: ChatCompletion = ChatCompletion.model_validate(raw_response)
+        except Exception as error:
+            return RequestResult(
+                success=False,
+                cached=False,
+                error=str(error),
+                completions=[],
+                embedding=[],
+            )
+
+        generated_outputs = self._convert_chat_completion_to_generated_outputs(chat_completion, request)
+
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=raw_response["request_time"],
+            request_datetime=raw_response["request_datetime"],
+            completions=generated_outputs,
+            embedding=[],
+        )
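For reference, the per-choice aggregation in `_convert_chat_completion_to_generated_outputs` reduces each choice to text plus a summed token logprob, optionally prepending the prompt when `echo_prompt` is set. A hedged standalone sketch using plain dicts in place of the writerai choice objects (all names here are illustrative):

# Standalone sketch (plain dicts instead of writerai types): shows how a chat
# completion choice is reduced to (text, total logprob), as WriterClient does.
from typing import Dict, Tuple

def summarize_choice(choice: Dict, prompt: str = "", echo_prompt: bool = False) -> Tuple[str, float]:
    content = choice["message"]["content"]
    # The Writer chat API does not support echo, so the prompt is prepended manually.
    text = prompt + content if echo_prompt else content
    token_logprobs = [t["logprob"] for t in (choice.get("logprobs") or {}).get("content") or []]
    return text, sum(token_logprobs) if token_logprobs else 0.0

choice = {
    "message": {"content": "Hello!"},
    "logprobs": {"content": [{"token": "Hello", "logprob": -0.5}, {"token": "!", "logprob": -0.25}]},
}
print(summarize_choice(choice))  # ('Hello!', -0.75)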
helm/common/context.py
ADDED
@@ -0,0 +1,80 @@
+from abc import ABC, abstractmethod
+
+from helm.common.critique_request import CritiqueRequest, CritiqueRequestResult
+from helm.common.clip_score_request import CLIPScoreRequest, CLIPScoreResult
+from helm.common.file_upload_request import FileUploadResult, FileUploadRequest
+from helm.common.nudity_check_request import NudityCheckRequest, NudityCheckResult
+from helm.common.perspective_api_request import PerspectiveAPIRequestResult, PerspectiveAPIRequest
+from helm.common.moderations_api_request import ModerationAPIRequest, ModerationAPIRequestResult
+from helm.common.tokenization_request import (
+    TokenizationRequest,
+    TokenizationRequestResult,
+    DecodeRequest,
+    DecodeRequestResult,
+)
+from helm.common.request import Request, RequestResult
+from helm.proxy.query import Query, QueryResult
+from helm.common.cache import CacheConfig
+from helm.proxy.services.service import GeneralInfo
+
+
+class Context(ABC):
+    @abstractmethod
+    def get_general_info(self) -> GeneralInfo:
+        """Get general info."""
+        pass
+
+    @abstractmethod
+    def expand_query(self, query: Query) -> QueryResult:
+        """Turn the `query` into requests."""
+        pass
+
+    @abstractmethod
+    def make_request(self, request: Request) -> RequestResult:
+        """Actually make a request to an API."""
+        pass
+
+    @abstractmethod
+    def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
+        """Tokenize via an API."""
+        pass
+
+    @abstractmethod
+    def decode(self, request: DecodeRequest) -> DecodeRequestResult:
+        """Decodes to text."""
+        pass
+
+    @abstractmethod
+    def upload(self, request: FileUploadRequest) -> FileUploadResult:
+        """Uploads a file to external storage."""
+        pass
+
+    @abstractmethod
+    def check_nudity(self, request: NudityCheckRequest) -> NudityCheckResult:
+        """Check for nudity for a batch of images."""
+        pass
+
+    @abstractmethod
+    def compute_clip_score(self, request: CLIPScoreRequest) -> CLIPScoreResult:
+        """Computes CLIPScore for a given caption and image."""
+        pass
+
+    @abstractmethod
+    def get_toxicity_scores(self, request: PerspectiveAPIRequest) -> PerspectiveAPIRequestResult:
+        """Get toxicity scores for a batch of text."""
+        pass
+
+    @abstractmethod
+    def get_moderation_results(self, request: ModerationAPIRequest) -> ModerationAPIRequestResult:
+        """Get OpenAI's moderation results for some text."""
+        pass
+
+    @abstractmethod
+    def make_critique_request(self, request: CritiqueRequest) -> CritiqueRequestResult:
+        """Get responses to a critique request."""
+        pass
+
+    @abstractmethod
+    def get_cache_config(self, shard_name: str) -> CacheConfig:
+        """Returns a CacheConfig"""
+        pass
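The new Context abstraction (together with local_context.py and remote_context.py in the file list above) lets benchmark code be written against a single interface regardless of backend. A minimal hedged sketch of such a caller, assuming only the abstract methods shown above; the helper name and argument values are illustrative:

# Standalone sketch: a helper that depends only on the Context interface, so it
# works with any concrete implementation (local or remote).
from helm.common.context import Context
from helm.common.request import Request, RequestResult

def run_prompt(context: Context, prompt: str, model: str, model_deployment: str) -> RequestResult:
    request = Request(
        model=model,
        model_deployment=model_deployment,
        prompt=prompt,
        max_tokens=64,
    )
    return context.make_request(request)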
helm/common/credentials_utils.py
CHANGED
@@ -2,7 +2,7 @@
 
 from typing import Any, Mapping, Optional
 
-from helm.common.hierarchical_logger import hlog
+from helm.common.hierarchical_logger import hlog, hwarn
 
 
 def provide_api_key(
@@ -13,16 +13,16 @@ def provide_api_key(
        hlog(f"Using host_organization api key defined in credentials.conf: {api_key_name}")
        return credentials[api_key_name]
    if "deployments" not in credentials:
-
-            "
+        hwarn(
+            "Could not find key 'deployments' in credentials.conf, "
            f"therefore the API key {api_key_name} should be specified."
        )
        return None
    deployment_api_keys = credentials["deployments"]
    if model is None:
-
+        hwarn(f"Could not find key '{host_organization}' in credentials.conf and no model provided")
        return None
    if model not in deployment_api_keys:
-
+        hwarn(f"Could not find key '{model}' under key 'deployments' in credentials.conf")
        return None
    return deployment_api_keys[model]
helm/common/critique_request.py
CHANGED
@@ -6,7 +6,6 @@ from helm.common.media_object import MediaObject
 class QuestionType:
     """String enum of question types."""
 
-    # TODO: Make this a StrEnum after upgrading to Python 3.11
     MULTIPLE_CHOICE: str = "multiple_choice"
     CHECKBOX: str = "checkbox"
     FREE_RESPONSE: str = "free_response"
|