crfm-helm 0.5.4__py3-none-any.whl → 0.5.6__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. The information in this diff is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of crfm-helm might be problematic. See the release details below for more information.
- crfm_helm-0.5.6.dist-info/METADATA +427 -0
- crfm_helm-0.5.6.dist-info/RECORD +941 -0
- {crfm_helm-0.5.4.dist-info → crfm_helm-0.5.6.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +13 -1
- helm/benchmark/adaptation/adapters/adapter_factory.py +15 -1
- helm/benchmark/adaptation/adapters/binary_ranking_adapter.py +1 -1
- helm/benchmark/adaptation/adapters/chat_adapter.py +49 -0
- helm/benchmark/adaptation/adapters/ehr_instruction_adapter.py +108 -0
- helm/benchmark/adaptation/adapters/generation_adapter.py +1 -1
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +4 -4
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +1 -1
- helm/benchmark/adaptation/adapters/multimodal/generation_multimodal_adapter.py +4 -2
- helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -1
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +1 -1
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +4 -2
- helm/benchmark/adaptation/adapters/multimodal/test_multimodal_prompt.py +1 -1
- helm/benchmark/adaptation/adapters/multiple_choice_calibrated_adapter.py +1 -1
- helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +2 -2
- helm/benchmark/adaptation/adapters/multiple_choice_joint_chain_of_thought_adapter.py +87 -0
- helm/benchmark/adaptation/adapters/multiple_choice_separate_adapter.py +1 -1
- helm/benchmark/adaptation/adapters/test_adapter.py +4 -4
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +3 -3
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +2 -2
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +2 -2
- helm/benchmark/adaptation/common_adapter_specs.py +69 -4
- helm/benchmark/adaptation/prompt.py +1 -1
- helm/benchmark/annotation/aci_bench_annotator.py +95 -0
- helm/benchmark/annotation/air_bench_annotator.py +21 -6
- helm/benchmark/annotation/annotator.py +5 -0
- helm/benchmark/annotation/annotator_factory.py +3 -20
- helm/benchmark/annotation/autobencher_capabilities_annotator.py +107 -0
- helm/benchmark/annotation/autobencher_safety_annotator.py +98 -0
- helm/benchmark/annotation/bigcodebench_annotator.py +108 -0
- helm/benchmark/annotation/bird_sql_annotator.py +58 -0
- helm/benchmark/annotation/chw_care_plan_annotator.py +93 -0
- helm/benchmark/annotation/czech_bank_qa_annotator.py +78 -0
- helm/benchmark/annotation/dischargeme_annotator.py +107 -0
- helm/benchmark/annotation/ehr_sql_annotator.py +87 -0
- helm/benchmark/annotation/helpdesk_call_summarization_annotator.py +131 -0
- helm/benchmark/annotation/image2struct/image_compiler_annotator.py +6 -1
- helm/benchmark/annotation/live_qa_annotator.py +1 -1
- helm/benchmark/annotation/med_dialog_annotator.py +99 -0
- helm/benchmark/annotation/medalign_annotator.py +100 -0
- helm/benchmark/annotation/medi_qa_annotator.py +98 -0
- helm/benchmark/annotation/medication_qa_annotator.py +87 -63
- helm/benchmark/annotation/mental_health_annotator.py +98 -0
- helm/benchmark/annotation/mimic_bhc_annotator.py +100 -0
- helm/benchmark/annotation/mimic_rrs_annotator.py +100 -0
- helm/benchmark/annotation/model_as_judge.py +214 -6
- helm/benchmark/annotation/mtsamples_procedures_annotator.py +98 -0
- helm/benchmark/annotation/mtsamples_replicate_annotator.py +101 -0
- helm/benchmark/annotation/omni_math/gpt_evaluation_template.txt +152 -0
- helm/benchmark/annotation/omni_math/gpt_evaluation_zero_shot_template.txt +36 -0
- helm/benchmark/annotation/omni_math_annotator.py +131 -0
- helm/benchmark/annotation/spider_annotator.py +18 -0
- helm/benchmark/annotation/starr_patient_instructions_annotator.py +98 -0
- helm/benchmark/annotation/wildbench/eval_template.pairwise.v2.md +75 -0
- helm/benchmark/annotation/wildbench/eval_template.score.v2.md +66 -0
- helm/benchmark/annotation/wildbench_annotator.py +119 -0
- helm/benchmark/annotation_executor.py +35 -15
- helm/benchmark/augmentations/cleva_perturbation.py +9 -8
- helm/benchmark/augmentations/contraction_expansion_perturbation.py +2 -2
- helm/benchmark/augmentations/contrast_sets_perturbation.py +2 -2
- helm/benchmark/augmentations/dialect_perturbation.py +4 -5
- helm/benchmark/augmentations/extra_space_perturbation.py +2 -2
- helm/benchmark/augmentations/filler_words_perturbation.py +2 -2
- helm/benchmark/augmentations/gender_perturbation.py +2 -2
- helm/benchmark/augmentations/lowercase_perturbation.py +2 -2
- helm/benchmark/augmentations/mild_mix_perturbation.py +6 -6
- helm/benchmark/augmentations/misspelling_perturbation.py +2 -2
- helm/benchmark/augmentations/person_name_perturbation.py +4 -5
- helm/benchmark/augmentations/perturbation.py +1 -1
- helm/benchmark/augmentations/space_perturbation.py +2 -2
- helm/benchmark/augmentations/suffix_perturbation.py +2 -2
- helm/benchmark/augmentations/synonym_perturbation.py +4 -3
- helm/benchmark/augmentations/test_perturbation.py +16 -13
- helm/benchmark/augmentations/translate_perturbation.py +2 -2
- helm/benchmark/augmentations/typos_perturbation.py +2 -2
- helm/benchmark/data_preprocessor.py +2 -2
- helm/benchmark/executor.py +11 -12
- helm/benchmark/huggingface_registration.py +2 -7
- helm/benchmark/metrics/aci_bench_metrics.py +14 -0
- helm/benchmark/metrics/basic_metrics.py +6 -6
- helm/benchmark/metrics/bbq_metrics.py +2 -2
- helm/benchmark/metrics/bias_metrics.py +12 -3
- helm/benchmark/metrics/bias_word_lists.py +1 -1
- helm/benchmark/metrics/bigcodebench_metrics.py +25 -0
- helm/benchmark/metrics/bird_sql_metrics.py +28 -0
- helm/benchmark/metrics/chw_care_plan_metrics.py +14 -0
- helm/benchmark/metrics/classification_metrics.py +76 -12
- helm/benchmark/metrics/cleva_harms_metrics.py +10 -9
- helm/benchmark/metrics/code_metrics.py +5 -5
- helm/benchmark/metrics/comet_metric.py +125 -0
- helm/benchmark/metrics/common_metric_specs.py +9 -2
- helm/benchmark/metrics/conv_fin_qa_calc_metrics.py +72 -0
- helm/benchmark/metrics/copyright_metrics.py +4 -4
- helm/benchmark/metrics/czech_bank_qa_metrics.py +29 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +2 -2
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +2 -2
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +2 -2
- helm/benchmark/metrics/dischargeme_metrics.py +14 -0
- helm/benchmark/metrics/disinformation_metrics.py +4 -4
- helm/benchmark/metrics/dry_run_metrics.py +5 -5
- helm/benchmark/metrics/efficiency_metrics.py +6 -6
- helm/benchmark/metrics/ehr_sql_metrics.py +103 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +3 -3
- helm/benchmark/metrics/evaluate_reference_metrics.py +144 -16
- helm/benchmark/metrics/gpqa_chain_of_thought_metric.py +103 -0
- helm/benchmark/metrics/gpt4_audio_critique_metrics.py +167 -0
- helm/benchmark/metrics/gpt4_audio_refusal_metrics.py +145 -0
- helm/benchmark/metrics/helpdesk_call_summarization_metrics.py +36 -0
- helm/benchmark/metrics/ifeval/__init__.py +0 -0
- helm/benchmark/metrics/ifeval/instructions.py +1574 -0
- helm/benchmark/metrics/ifeval/instructions_registry.py +182 -0
- helm/benchmark/metrics/ifeval/instructions_registry.pyi +3 -0
- helm/benchmark/metrics/ifeval/instructions_util.py +153 -0
- helm/benchmark/metrics/ifeval_metrics.py +55 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +1 -1
- helm/benchmark/metrics/image_generation/detection_metrics.py +1 -1
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +1 -1
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +1 -1
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +1 -1
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +1 -1
- helm/benchmark/metrics/image_generation/q16/test_q16.py +3 -1
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +1 -1
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +2 -2
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +1 -1
- helm/benchmark/metrics/image_generation/watermark_metrics.py +1 -1
- helm/benchmark/metrics/instruction_following_critique_metrics.py +4 -4
- helm/benchmark/metrics/kpi_edgar_metrics.py +121 -0
- helm/benchmark/metrics/language_modeling_metrics.py +4 -4
- helm/benchmark/metrics/llm_jury_metrics.py +46 -0
- helm/benchmark/metrics/machine_translation_metrics.py +2 -2
- helm/benchmark/metrics/med_dialog_metrics.py +14 -0
- helm/benchmark/metrics/medalign_metrics.py +14 -0
- helm/benchmark/metrics/medcalc_bench_metrics.py +124 -0
- helm/benchmark/metrics/medec_metrics.py +101 -0
- helm/benchmark/metrics/medi_qa_metrics.py +14 -0
- helm/benchmark/metrics/medication_qa_metrics.py +10 -19
- helm/benchmark/metrics/melt_bias_metric.py +234 -0
- helm/benchmark/metrics/melt_bias_word_lists.py +1367 -0
- helm/benchmark/metrics/melt_metric_specs.py +43 -0
- helm/benchmark/metrics/melt_toxicity_metric.py +107 -0
- helm/benchmark/metrics/mental_health_metrics.py +14 -0
- helm/benchmark/metrics/metric.py +3 -3
- helm/benchmark/metrics/metric_service.py +11 -11
- helm/benchmark/metrics/mimic_bhc_metrics.py +14 -0
- helm/benchmark/metrics/mimic_rrs_metrics.py +14 -0
- helm/benchmark/metrics/mimiciv_billing_code_metrics.py +96 -0
- helm/benchmark/metrics/mtsamples_procedures_metrics.py +14 -0
- helm/benchmark/metrics/mtsamples_replicate_metrics.py +14 -0
- helm/benchmark/metrics/nltk_helper.py +32 -0
- helm/benchmark/metrics/numeracy_metrics.py +4 -4
- helm/benchmark/metrics/omni_math_metrics.py +32 -0
- helm/benchmark/metrics/openai_mrcr_metrics.py +52 -0
- helm/benchmark/metrics/output_processing_metric.py +60 -0
- helm/benchmark/metrics/output_processors.py +15 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +2 -2
- helm/benchmark/metrics/ranking_metrics.py +3 -3
- helm/benchmark/metrics/reference_metric.py +3 -3
- helm/benchmark/metrics/ruler_qa_metrics.py +34 -0
- helm/benchmark/metrics/{bhasa_metrics.py → seahelm_metrics.py} +3 -3
- helm/benchmark/metrics/seahelm_metrics_specs.py +10 -0
- helm/benchmark/metrics/spider_metrics.py +7 -0
- helm/benchmark/metrics/starr_patient_instructions_metrics.py +14 -0
- helm/benchmark/metrics/statistic.py +1 -1
- helm/benchmark/metrics/summac/model_summac.py +2 -3
- helm/benchmark/metrics/summarization_critique_metrics.py +4 -4
- helm/benchmark/metrics/summarization_metrics.py +20 -9
- helm/benchmark/metrics/test_bias_metrics.py +5 -1
- helm/benchmark/metrics/test_classification_metrics.py +140 -68
- helm/benchmark/metrics/test_evaluate_reference_metrics.py +15 -0
- helm/benchmark/metrics/test_metric.py +1 -1
- helm/benchmark/metrics/test_statistic.py +2 -2
- helm/benchmark/metrics/tokens/ai21_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/auto_token_cost_estimator.py +6 -6
- helm/benchmark/metrics/tokens/cohere_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/free_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/openai_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +3 -3
- helm/benchmark/metrics/toxicity_metrics.py +6 -6
- helm/benchmark/metrics/unitxt_metrics.py +7 -5
- helm/benchmark/metrics/vision_language/emd_utils.py +4 -2
- helm/benchmark/metrics/vision_language/image_metrics.py +1 -1
- helm/benchmark/metrics/vision_language/image_utils.py +2 -2
- helm/benchmark/metrics/wildbench_metrics.py +34 -0
- helm/benchmark/model_deployment_registry.py +6 -8
- helm/benchmark/model_metadata_registry.py +16 -0
- helm/benchmark/presentation/contamination.py +3 -3
- helm/benchmark/presentation/create_plots.py +33 -12
- helm/benchmark/presentation/run_display.py +13 -0
- helm/benchmark/presentation/schema.py +2 -1
- helm/benchmark/presentation/summarize.py +97 -67
- helm/benchmark/presentation/torr_robustness_summarizer.py +178 -0
- helm/benchmark/reeval_run.py +202 -0
- helm/benchmark/reeval_runner.py +355 -0
- helm/benchmark/run.py +86 -90
- helm/benchmark/run_expander.py +90 -9
- helm/benchmark/run_spec_factory.py +13 -0
- helm/benchmark/run_specs/air_bench_run_specs.py +21 -3
- helm/benchmark/run_specs/audio_run_specs.py +657 -0
- helm/benchmark/run_specs/call_center_run_specs.py +49 -0
- helm/benchmark/run_specs/capabilities_run_specs.py +308 -0
- helm/benchmark/run_specs/classic_run_specs.py +1 -69
- helm/benchmark/run_specs/enem_challenge_specs.py +31 -0
- helm/benchmark/run_specs/enterprise_run_specs.py +280 -0
- helm/benchmark/run_specs/experimental_run_specs.py +142 -3
- helm/benchmark/run_specs/imdb_ptbr_run_specs.py +30 -0
- helm/benchmark/run_specs/lite_run_specs.py +2 -2
- helm/benchmark/run_specs/long_context_run_specs.py +141 -0
- helm/benchmark/run_specs/medhelm_run_specs.py +1260 -0
- helm/benchmark/run_specs/melt_run_specs.py +783 -0
- helm/benchmark/run_specs/mmlu_clinical_afr_run_specs.py +49 -0
- helm/benchmark/run_specs/oab_exams_specs.py +32 -0
- helm/benchmark/run_specs/safety_run_specs.py +37 -0
- helm/benchmark/run_specs/{bhasa_run_specs.py → seahelm_run_specs.py} +44 -44
- helm/benchmark/run_specs/speech_disorder_audio_run_specs.py +169 -0
- helm/benchmark/run_specs/sql_run_specs.py +54 -0
- helm/benchmark/run_specs/tweetsentbr_run_specs.py +32 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +14 -5
- helm/benchmark/run_specs/vlm_run_specs.py +103 -2
- helm/benchmark/run_specs/winogrande_afr_run_specs.py +47 -0
- helm/benchmark/runner.py +5 -5
- helm/benchmark/scenarios/aci_bench_scenario.py +126 -0
- helm/benchmark/scenarios/air_bench_scenario.py +6 -1
- helm/benchmark/scenarios/anthropic_hh_rlhf_scenario.py +5 -3
- helm/benchmark/scenarios/anthropic_red_team_scenario.py +1 -1
- helm/benchmark/scenarios/audio_language/__init__.py +0 -0
- helm/benchmark/scenarios/audio_language/air_bench_chat_scenario.py +130 -0
- helm/benchmark/scenarios/audio_language/air_bench_foundation_scenario.py +154 -0
- helm/benchmark/scenarios/audio_language/ami_scenario.py +96 -0
- helm/benchmark/scenarios/audio_language/audio_mnist_scenario.py +62 -0
- helm/benchmark/scenarios/audio_language/audio_pairs_scenario.py +62 -0
- helm/benchmark/scenarios/audio_language/audiocaps_scenario.py +59 -0
- helm/benchmark/scenarios/audio_language/casual_conversations2_scenario.py +152 -0
- helm/benchmark/scenarios/audio_language/common_voice_15_scenario.py +99 -0
- helm/benchmark/scenarios/audio_language/corebench_scenario.py +77 -0
- helm/benchmark/scenarios/audio_language/covost2_scenario.py +163 -0
- helm/benchmark/scenarios/audio_language/fleurs_fairness_scenario.py +83 -0
- helm/benchmark/scenarios/audio_language/fleurs_scenario.py +312 -0
- helm/benchmark/scenarios/audio_language/iemocap_audio_scenario.py +83 -0
- helm/benchmark/scenarios/audio_language/librispeech_fairness_scenario.py +96 -0
- helm/benchmark/scenarios/audio_language/librispeech_scenario.py +80 -0
- helm/benchmark/scenarios/audio_language/meld_audio_scenario.py +113 -0
- helm/benchmark/scenarios/audio_language/multilingual_librispeech_scenario.py +80 -0
- helm/benchmark/scenarios/audio_language/mustard_scenario.py +142 -0
- helm/benchmark/scenarios/audio_language/mutox_scenario.py +254 -0
- helm/benchmark/scenarios/audio_language/parade_scenario.py +97 -0
- helm/benchmark/scenarios/audio_language/speech_robust_bench_scenario.py +124 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_asr_classification.py +103 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_classification_scenario.py +110 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_disorder_breakdown_scenario.py +78 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_disorder_symptoms_scenario.py +109 -0
- helm/benchmark/scenarios/audio_language/vocal_sound_scenario.py +83 -0
- helm/benchmark/scenarios/audio_language/voice_jailbreak_attacks_scenario.py +87 -0
- helm/benchmark/scenarios/audio_language/voxceleb2_scenario.py +105 -0
- helm/benchmark/scenarios/autobencher_capabilities_scenario.py +68 -0
- helm/benchmark/scenarios/autobencher_safety_scenario.py +51 -0
- helm/benchmark/scenarios/babi_qa_scenario.py +1 -1
- helm/benchmark/scenarios/banking77_scenario.py +6 -1
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/big_bench_scenario.py +11 -1
- helm/benchmark/scenarios/bigcodebench_scenario.py +58 -0
- helm/benchmark/scenarios/bird_sql_scenario.py +94 -0
- helm/benchmark/scenarios/bird_sql_scenario_helper.py +118 -0
- helm/benchmark/scenarios/blimp_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +1 -1
- helm/benchmark/scenarios/boolq_scenario.py +1 -1
- helm/benchmark/scenarios/casehold_scenario.py +79 -0
- helm/benchmark/scenarios/chw_care_plan_scenario.py +106 -0
- helm/benchmark/scenarios/civil_comments_scenario.py +1 -1
- helm/benchmark/scenarios/clear_scenario.py +157 -0
- helm/benchmark/scenarios/cleva_scenario.py +2 -2
- helm/benchmark/scenarios/code_scenario.py +17 -4
- helm/benchmark/scenarios/commonsense_scenario.py +1 -1
- helm/benchmark/scenarios/conv_fin_qa_calc_scenario.py +97 -0
- helm/benchmark/scenarios/copyright_scenario.py +1 -1
- helm/benchmark/scenarios/covid_dialog_scenario.py +10 -1
- helm/benchmark/scenarios/cti_to_mitre_scenario.py +240 -0
- helm/benchmark/scenarios/custom_mcqa_scenario.py +1 -1
- helm/benchmark/scenarios/czech_bank_qa_scenario.py +130 -0
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +1 -1
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +1 -1
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +1 -1
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +1 -1
- helm/benchmark/scenarios/dialogue_scenarios.py +13 -2
- helm/benchmark/scenarios/dischargeme_scenario.py +172 -0
- helm/benchmark/scenarios/disinformation_scenario.py +10 -1
- helm/benchmark/scenarios/dyck_language_scenario.py +10 -1
- helm/benchmark/scenarios/echr_judgment_classification_scenario.py +113 -0
- helm/benchmark/scenarios/ehr_sql_scenario.py +137 -0
- helm/benchmark/scenarios/ehrshot_scenario.py +1519 -0
- helm/benchmark/scenarios/enem_challenge_scenario.py +58 -0
- helm/benchmark/scenarios/entity_data_imputation_scenario.py +11 -1
- helm/benchmark/scenarios/entity_matching_scenario.py +12 -2
- helm/benchmark/scenarios/financial_phrasebank_scenario.py +94 -0
- helm/benchmark/scenarios/gold_commodity_news_scenario.py +124 -0
- helm/benchmark/scenarios/gpqa_scenario.py +80 -0
- helm/benchmark/scenarios/grammar.py +2 -2
- helm/benchmark/scenarios/grammar_scenario.py +2 -2
- helm/benchmark/scenarios/gsm_scenario.py +10 -1
- helm/benchmark/scenarios/harm_bench_gcg_transfer_scenario.py +50 -0
- helm/benchmark/scenarios/harm_bench_scenario.py +1 -1
- helm/benchmark/scenarios/headqa_scenario.py +136 -0
- helm/benchmark/scenarios/helpdesk_call_summarization_scenario.py +37 -0
- helm/benchmark/scenarios/ice_scenario.py +8 -4
- helm/benchmark/scenarios/ifeval_scenario.py +53 -0
- helm/benchmark/scenarios/imdb_ptbr_scenario.py +60 -0
- helm/benchmark/scenarios/imdb_scenario.py +11 -2
- helm/benchmark/scenarios/infinite_bench_en_qa_scenario.py +85 -0
- helm/benchmark/scenarios/infinite_bench_en_sum_scenario.py +79 -0
- helm/benchmark/scenarios/interactive_qa_mmlu_scenario.py +2 -2
- helm/benchmark/scenarios/koala_scenario.py +1 -1
- helm/benchmark/scenarios/kpi_edgar_scenario.py +151 -0
- helm/benchmark/scenarios/legal_contract_summarization_scenario.py +129 -0
- helm/benchmark/scenarios/legal_opinion_sentiment_classification_scenario.py +77 -0
- helm/benchmark/scenarios/legal_summarization_scenario.py +11 -1
- helm/benchmark/scenarios/legal_support_scenario.py +11 -1
- helm/benchmark/scenarios/legalbench_scenario.py +22 -3
- helm/benchmark/scenarios/lex_glue_scenario.py +12 -2
- helm/benchmark/scenarios/lextreme_scenario.py +11 -1
- helm/benchmark/scenarios/live_qa_scenario.py +1 -1
- helm/benchmark/scenarios/lm_entry_scenario.py +1 -1
- helm/benchmark/scenarios/lsat_qa_scenario.py +1 -1
- helm/benchmark/scenarios/math_scenario.py +9 -1
- helm/benchmark/scenarios/me_q_sum_scenario.py +10 -1
- helm/benchmark/scenarios/med_dialog_scenario.py +25 -22
- helm/benchmark/scenarios/med_mcqa_scenario.py +10 -1
- helm/benchmark/scenarios/med_paragraph_simplification_scenario.py +10 -1
- helm/benchmark/scenarios/med_qa_scenario.py +10 -1
- helm/benchmark/scenarios/medalign_scenario.py +94 -0
- helm/benchmark/scenarios/medalign_scenario_helper.py +432 -0
- helm/benchmark/scenarios/medbullets_scenario.py +145 -0
- helm/benchmark/scenarios/medcalc_bench_scenario.py +127 -0
- helm/benchmark/scenarios/medec_scenario.py +125 -0
- helm/benchmark/scenarios/medhallu_scenario.py +72 -0
- helm/benchmark/scenarios/medi_qa_scenario.py +111 -0
- helm/benchmark/scenarios/medication_qa_scenario.py +8 -2
- helm/benchmark/scenarios/melt_ir_scenario.py +171 -0
- helm/benchmark/scenarios/melt_knowledge_scenario.py +246 -0
- helm/benchmark/scenarios/melt_lm_scenarios.py +252 -0
- helm/benchmark/scenarios/melt_scenarios.py +793 -0
- helm/benchmark/scenarios/melt_srn_scenario.py +342 -0
- helm/benchmark/scenarios/melt_synthetic_reasoning_scenario.py +222 -0
- helm/benchmark/scenarios/melt_translation_scenario.py +152 -0
- helm/benchmark/scenarios/mental_health_scenario.py +123 -0
- helm/benchmark/scenarios/mimic_bhc_scenario.py +103 -0
- helm/benchmark/scenarios/mimic_rrs_scenario.py +98 -0
- helm/benchmark/scenarios/mimiciv_billing_code_scenario.py +77 -0
- helm/benchmark/scenarios/mmlu_clinical_afr_scenario.py +74 -0
- helm/benchmark/scenarios/mmlu_pro_scenario.py +95 -0
- helm/benchmark/scenarios/mmlu_scenario.py +11 -1
- helm/benchmark/scenarios/msmarco_scenario.py +1 -1
- helm/benchmark/scenarios/mtsamples_procedures_scenario.py +144 -0
- helm/benchmark/scenarios/mtsamples_replicate_scenario.py +142 -0
- helm/benchmark/scenarios/n2c2_ct_matching_scenario.py +277 -0
- helm/benchmark/scenarios/narrativeqa_scenario.py +1 -1
- helm/benchmark/scenarios/natural_qa_scenario.py +1 -1
- helm/benchmark/scenarios/newsqa_scenario.py +1 -1
- helm/benchmark/scenarios/numeracy_scenario.py +12 -2
- helm/benchmark/scenarios/oab_exams_scenario.py +57 -0
- helm/benchmark/scenarios/omni_math_scenario.py +53 -0
- helm/benchmark/scenarios/open_assistant_scenario.py +11 -2
- helm/benchmark/scenarios/openai_mrcr_scenario.py +79 -0
- helm/benchmark/scenarios/opinions_qa_scenario.py +1 -1
- helm/benchmark/scenarios/pubmed_qa_scenario.py +59 -43
- helm/benchmark/scenarios/quac_scenario.py +10 -1
- helm/benchmark/scenarios/race_based_med_scenario.py +152 -0
- helm/benchmark/scenarios/raft_scenario.py +17 -2
- helm/benchmark/scenarios/real_toxicity_prompts_scenario.py +1 -1
- helm/benchmark/scenarios/ruler_qa_scenario_helper.py +171 -0
- helm/benchmark/scenarios/ruler_qa_scenarios.py +88 -0
- helm/benchmark/scenarios/scenario.py +9 -1
- helm/benchmark/scenarios/{bhasa_scenario.py → seahelm_scenario.py} +7 -2
- helm/benchmark/scenarios/self_instruct_scenario.py +1 -1
- helm/benchmark/scenarios/shc_bmt_scenario.py +75 -0
- helm/benchmark/scenarios/shc_cdi_scenario.py +75 -0
- helm/benchmark/scenarios/shc_conf_scenario.py +76 -0
- helm/benchmark/scenarios/shc_ent_scenario.py +77 -0
- helm/benchmark/scenarios/shc_gip_scenario.py +74 -0
- helm/benchmark/scenarios/shc_privacy_scenario.py +78 -0
- helm/benchmark/scenarios/shc_proxy_scenario.py +76 -0
- helm/benchmark/scenarios/shc_ptbm_scenario.py +81 -0
- helm/benchmark/scenarios/shc_sei_scenario.py +94 -0
- helm/benchmark/scenarios/shc_sequoia_scenario.py +77 -0
- helm/benchmark/scenarios/simple_safety_tests_scenario.py +1 -1
- helm/benchmark/scenarios/spider_scenario.py +91 -0
- helm/benchmark/scenarios/starr_patient_instructions_scenario.py +97 -0
- helm/benchmark/scenarios/summarization_scenario.py +11 -1
- helm/benchmark/scenarios/sumosum_scenario.py +157 -0
- helm/benchmark/scenarios/synthetic_efficiency_scenario.py +1 -1
- helm/benchmark/scenarios/synthetic_reasoning_natural_scenario.py +11 -1
- helm/benchmark/scenarios/synthetic_reasoning_scenario.py +11 -1
- helm/benchmark/scenarios/test_bigcodebench_scenario.py +26 -0
- helm/benchmark/scenarios/test_czech_bank_qa_scenario.py +18 -0
- helm/benchmark/scenarios/test_enem_challenge_scenario.py +53 -0
- helm/benchmark/scenarios/test_ewok_scenario.py +6 -2
- helm/benchmark/scenarios/test_gold_commodity_news_scenario.py +18 -0
- helm/benchmark/scenarios/test_gpqa_scenario.py +44 -0
- helm/benchmark/scenarios/test_ifeval_scenario.py +36 -0
- helm/benchmark/scenarios/test_imdb_ptbr_scenario.py +27 -0
- helm/benchmark/scenarios/test_infinite_bench_en_qa_scenario.py +18 -0
- helm/benchmark/scenarios/test_infinite_bench_en_sum_scenario.py +31 -0
- helm/benchmark/scenarios/test_math_scenario.py +1 -0
- helm/benchmark/scenarios/test_mmlu_clinical_afr_scenario.py +21 -0
- helm/benchmark/scenarios/test_mmlu_pro_scenario.py +53 -0
- helm/benchmark/scenarios/test_oab_exams_scenario.py +51 -0
- helm/benchmark/scenarios/test_omni_math_scenario.py +27 -0
- helm/benchmark/scenarios/test_tweetsentbr_scenario.py +24 -0
- helm/benchmark/scenarios/test_wildbench_scenario.py +15 -0
- helm/benchmark/scenarios/test_winogrande_afr_scenario.py +19 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +10 -1
- helm/benchmark/scenarios/the_pile_scenario.py +1 -1
- helm/benchmark/scenarios/truthful_qa_scenario.py +12 -2
- helm/benchmark/scenarios/tweetsentbr_scenario.py +66 -0
- helm/benchmark/scenarios/twitter_aae_scenario.py +1 -1
- helm/benchmark/scenarios/unitxt_scenario.py +8 -2
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +1 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/blink_scenario.py +140 -0
- helm/benchmark/scenarios/vision_language/mm_star_scenario.py +95 -0
- helm/benchmark/scenarios/vision_language/msr_vtt_scenario.py +75 -0
- helm/benchmark/scenarios/vision_language/vqa_rad_scenario.py +88 -0
- helm/benchmark/scenarios/wikifact_scenario.py +11 -1
- helm/benchmark/scenarios/wikitext_103_scenario.py +1 -1
- helm/benchmark/scenarios/wildbench_scenario.py +83 -0
- helm/benchmark/scenarios/winogrande_afr_scenario.py +78 -0
- helm/benchmark/scenarios/wmt_14_scenario.py +14 -2
- helm/benchmark/scenarios/xstest_scenario.py +1 -1
- helm/benchmark/server.py +13 -1
- helm/benchmark/slurm_runner.py +1 -1
- helm/benchmark/static/schema_audio.yaml +763 -0
- helm/benchmark/static/schema_autobencher.yaml +150 -0
- helm/benchmark/static/schema_call_center.yaml +97 -60
- helm/benchmark/static/{schema_medical.yaml → schema_capabilities.yaml} +100 -101
- helm/benchmark/static/schema_czech_bank.yaml +148 -0
- helm/benchmark/static/schema_enem_challenge.yaml +146 -0
- helm/benchmark/static/schema_enterprise.yaml +319 -0
- helm/benchmark/static/schema_finance.yaml +14 -12
- helm/benchmark/static/schema_heim.yaml +1389 -0
- helm/benchmark/static/schema_long_context.yaml +283 -0
- helm/benchmark/static/schema_medhelm.yaml +1140 -0
- helm/benchmark/static/schema_melt.yaml +1257 -0
- helm/benchmark/static/schema_mmlu_winogrande_afr.yaml +1045 -0
- helm/benchmark/static/schema_safety.yaml +18 -1
- helm/benchmark/static/{schema_bhasa.yaml → schema_seahelm.yaml} +30 -16
- helm/benchmark/static/schema_slphelm.yaml +162 -0
- helm/benchmark/static/schema_social_audio.yaml +224 -0
- helm/benchmark/static/schema_sql.yaml +171 -0
- helm/benchmark/static/{schema_tables.yaml → schema_torr.yaml} +169 -36
- helm/benchmark/static/schema_tweetsentbr.yaml +146 -0
- helm/benchmark/static/schema_vhelm.yaml +129 -56
- helm/benchmark/static/schema_video.yaml +219 -0
- helm/benchmark/static_build/assets/helm-safety-2907a7b6.png +0 -0
- helm/benchmark/static_build/assets/index-94295e78.js +10 -0
- helm/benchmark/static_build/assets/index-b9779128.css +1 -0
- helm/benchmark/static_build/assets/medhelm-overview-eac29843.png +0 -0
- helm/benchmark/static_build/assets/medhelm-v1-overview-3ddfcd65.png +0 -0
- helm/benchmark/static_build/assets/{react-d4a0b69b.js → react-f82877fd.js} +1 -1
- helm/benchmark/static_build/assets/{recharts-6d337683.js → recharts-4037aff0.js} +1 -1
- helm/benchmark/static_build/assets/{tremor-54a99cc4.js → tremor-38a10867.js} +2 -2
- helm/benchmark/static_build/config.js +1 -1
- helm/benchmark/static_build/index.html +6 -6
- helm/benchmark/window_services/default_window_service.py +1 -1
- helm/benchmark/window_services/encoder_decoder_window_service.py +4 -4
- helm/benchmark/window_services/ice_window_service.py +1 -1
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +1 -1
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +1 -1
- helm/benchmark/window_services/local_window_service.py +2 -2
- helm/benchmark/window_services/test_anthropic_window_service.py +3 -3
- helm/benchmark/window_services/test_bloom_window_service.py +3 -3
- helm/benchmark/window_services/test_gpt2_window_service.py +7 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +8 -3
- helm/benchmark/window_services/test_gptj_window_service.py +8 -3
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -3
- helm/benchmark/window_services/test_openai_window_service.py +8 -3
- helm/benchmark/window_services/test_opt_window_service.py +3 -3
- helm/benchmark/window_services/test_palmyra_window_service.py +3 -3
- helm/benchmark/window_services/test_t0pp_window_service.py +3 -3
- helm/benchmark/window_services/test_t511b_window_service.py +3 -3
- helm/benchmark/window_services/test_ul2_window_service.py +3 -3
- helm/benchmark/window_services/test_utils.py +4 -5
- helm/benchmark/window_services/test_yalm_window_service.py +3 -3
- helm/benchmark/window_services/tokenizer_service.py +7 -8
- helm/benchmark/window_services/yalm_window_service.py +1 -1
- helm/clients/ai21_client.py +3 -3
- helm/clients/aleph_alpha_client.py +1 -1
- helm/clients/anthropic_client.py +69 -29
- helm/clients/audio_language/__init__.py +0 -0
- helm/clients/audio_language/diva_llama_client.py +120 -0
- helm/clients/audio_language/llama_omni_client.py +198 -0
- helm/clients/audio_language/qwen2_5_omni_client.py +197 -0
- helm/clients/audio_language/qwen2_audiolm_client.py +190 -0
- helm/clients/audio_language/qwen_audiolm_client.py +152 -0
- helm/clients/audio_language/test.py +62 -0
- helm/clients/auto_client.py +4 -2
- helm/clients/azure_openai_client.py +55 -0
- helm/clients/bedrock_client.py +203 -7
- helm/clients/bedrock_utils.py +33 -0
- helm/clients/client.py +7 -7
- helm/clients/clip_scorers/clip_scorer.py +1 -1
- helm/clients/clip_scorers/multilingual_clip_scorer.py +1 -1
- helm/clients/cohere_client.py +3 -3
- helm/clients/google_client.py +1 -1
- helm/clients/grok_client.py +36 -0
- helm/clients/http_model_client.py +1 -1
- helm/clients/huggingface_client.py +52 -21
- helm/clients/huggingface_pipeline_client.py +138 -0
- helm/clients/ibm_client.py +267 -0
- helm/clients/image_generation/adobe_vision_client.py +1 -1
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +1 -1
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +3 -3
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +5 -2
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +5 -2
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +2 -2
- helm/clients/image_generation/cogview2_client.py +1 -1
- helm/clients/image_generation/dalle2_client.py +1 -1
- helm/clients/image_generation/dalle3_client.py +2 -2
- helm/clients/image_generation/dalle_mini/__init__.py +1 -1
- helm/clients/image_generation/dalle_mini/data.py +1 -1
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -5
- helm/clients/image_generation/dalle_mini/model/configuration.py +2 -2
- helm/clients/image_generation/dalle_mini/model/modeling.py +3 -3
- helm/clients/image_generation/dalle_mini/model/processor.py +5 -5
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +2 -2
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -1
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +2 -2
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +1 -1
- helm/clients/image_generation/dalle_mini_client.py +1 -1
- helm/clients/image_generation/deep_floyd_client.py +1 -1
- helm/clients/image_generation/huggingface_diffusers_client.py +1 -1
- helm/clients/image_generation/lexica_client.py +1 -1
- helm/clients/image_generation/mindalle/models/__init__.py +6 -6
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +1 -1
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +1 -1
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -3
- helm/clients/image_generation/mindalle_client.py +1 -1
- helm/clients/image_generation/together_image_generation_client.py +1 -1
- helm/clients/lit_gpt_client.py +2 -2
- helm/clients/mistral_client.py +62 -18
- helm/clients/nvidia_nim_client.py +0 -3
- helm/clients/openai_client.py +308 -43
- helm/clients/openai_responses_client.py +174 -0
- helm/clients/palmyra_client.py +3 -9
- helm/clients/reka_client.py +3 -3
- helm/clients/stanfordhealthcare_azure_openai_client.py +58 -0
- helm/clients/stanfordhealthcare_claude_client.py +31 -0
- helm/clients/stanfordhealthcare_google_client.py +43 -0
- helm/clients/stanfordhealthcare_http_model_client.py +93 -0
- helm/clients/stanfordhealthcare_openai_client.py +62 -0
- helm/clients/stanfordhealthcare_shc_openai_client.py +42 -0
- helm/clients/test_client.py +1 -1
- helm/clients/test_together_client.py +6 -1
- helm/clients/together_client.py +76 -9
- helm/clients/upstage_client.py +23 -0
- helm/clients/vertexai_client.py +45 -13
- helm/clients/vision_language/huggingface_vision2seq_client.py +6 -4
- helm/clients/vision_language/huggingface_vlm_client.py +2 -2
- helm/clients/vision_language/idefics_client.py +6 -2
- helm/clients/vision_language/open_flamingo/__init__.py +2 -2
- helm/clients/vision_language/open_flamingo/src/factory.py +3 -3
- helm/clients/vision_language/open_flamingo/src/flamingo.py +2 -2
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +2 -2
- helm/clients/vision_language/paligemma_client.py +2 -2
- helm/clients/vision_language/qwen2_vlm_client.py +188 -0
- helm/clients/vision_language/qwen_vlm_client.py +7 -5
- helm/clients/vllm_client.py +4 -6
- helm/clients/writer_client.py +102 -0
- helm/clients/yi_client.py +0 -3
- helm/common/audio_utils.py +111 -0
- helm/common/context.py +80 -0
- helm/common/credentials_utils.py +5 -5
- helm/common/file_caches/local_file_cache.py +1 -1
- helm/common/file_caches/test_local_file_cache.py +1 -1
- helm/common/general.py +9 -2
- helm/common/hierarchical_logger.py +46 -3
- helm/common/images_utils.py +2 -2
- helm/common/local_context.py +140 -0
- helm/common/media_object.py +2 -2
- helm/common/multimodal_request_utils.py +26 -0
- helm/common/reeval_parameters.py +12 -0
- helm/common/remote_context.py +61 -0
- helm/common/request.py +14 -2
- helm/common/response_format.py +18 -0
- helm/common/test_media_object.py +1 -1
- helm/config/model_deployments.yaml +1792 -28
- helm/config/model_metadata.yaml +1606 -51
- helm/config/tokenizer_configs.yaml +521 -4
- helm/proxy/cli.py +5 -3
- helm/proxy/critique/mechanical_turk_utils.py +1 -1
- helm/proxy/example_queries.py +1 -1
- helm/proxy/server.py +11 -4
- helm/proxy/services/remote_service.py +1 -1
- helm/proxy/services/server_service.py +22 -86
- helm/proxy/services/test_remote_service.py +2 -2
- helm/proxy/services/test_service.py +1 -1
- helm/proxy/static/general.js +122 -0
- helm/proxy/static/help.html +99 -0
- helm/proxy/static/index.css +57 -0
- helm/proxy/static/index.html +40 -0
- helm/proxy/static/index.js +456 -0
- helm/proxy/static/info-icon.png +0 -0
- helm/proxy/test_retry.py +1 -1
- helm/proxy/token_counters/auto_token_counter.py +1 -1
- helm/tokenizers/aleph_alpha_tokenizer.py +1 -1
- helm/tokenizers/caching_tokenizer.py +2 -30
- helm/tokenizers/grok_tokenizer.py +53 -0
- helm/tokenizers/http_model_tokenizer.py +1 -1
- helm/tokenizers/huggingface_tokenizer.py +3 -3
- helm/tokenizers/lit_gpt_tokenizer.py +1 -1
- helm/tokenizers/test_anthropic_tokenizer.py +6 -2
- helm/tokenizers/test_grok_tokenizer.py +33 -0
- helm/tokenizers/test_huggingface_tokenizer.py +1 -1
- helm/tokenizers/test_yalm_tokenizer.py +1 -1
- helm/tokenizers/tiktoken_tokenizer.py +1 -1
- helm/tokenizers/tokenizer.py +3 -1
- helm/tokenizers/yalm_tokenizer.py +3 -3
- helm/tokenizers/yalm_tokenizer_data/test_yalm_tokenizer.py +1 -1
- crfm_helm-0.5.4.dist-info/METADATA +0 -350
- crfm_helm-0.5.4.dist-info/RECORD +0 -697
- helm/benchmark/metrics/bhasa_metrics_specs.py +0 -10
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/accenture-6f97eeda.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aisingapore-6dfc9acf.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/cresta-9e22b983.png +0 -0
- helm/benchmark/static_build/assets/cuhk-8c5631e9.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/index-05c76bb1.css +0 -1
- helm/benchmark/static_build/assets/index-3ee38b3d.js +0 -10
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/scb10x-204bd786.png +0 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/wellsfargo-a86a6c4a.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/tokenizers/anthropic_tokenizer.py +0 -52
- {crfm_helm-0.5.4.dist-info → crfm_helm-0.5.6.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.5.4.dist-info → crfm_helm-0.5.6.dist-info/licenses}/LICENSE +0 -0
- {crfm_helm-0.5.4.dist-info → crfm_helm-0.5.6.dist-info}/top_level.txt +0 -0
helm/clients/vertexai_client.py
CHANGED
|
@@ -4,6 +4,7 @@ from threading import Lock
|
|
|
4
4
|
from typing import Any, Dict, Mapping, Optional, List, Union
|
|
5
5
|
|
|
6
6
|
from helm.common.cache import CacheConfig
|
|
7
|
+
from helm.common.multimodal_request_utils import get_contents_as_bytes
|
|
7
8
|
from helm.common.media_object import TEXT_TYPE
|
|
8
9
|
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
9
10
|
from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput, ErrorFlags
|
|
@@ -12,7 +13,14 @@ from helm.clients.client import CachingClient, truncate_sequence, generate_uid_f
|
|
|
12
13
|
try:
|
|
13
14
|
import vertexai
|
|
14
15
|
from vertexai.language_models import TextGenerationModel, TextGenerationResponse # PaLM2
|
|
15
|
-
from vertexai.preview.generative_models import
|
|
16
|
+
from vertexai.preview.generative_models import (
|
|
17
|
+
GenerativeModel,
|
|
18
|
+
GenerationResponse,
|
|
19
|
+
Candidate,
|
|
20
|
+
Content,
|
|
21
|
+
Part,
|
|
22
|
+
Image,
|
|
23
|
+
) # Gemini
|
|
16
24
|
from google.cloud.aiplatform_v1beta1.types import SafetySetting, HarmCategory
|
|
17
25
|
except ModuleNotFoundError as e:
|
|
18
26
|
handle_module_not_found_error(e, ["google"])
|
|
@@ -48,17 +56,16 @@ def _get_safety_settings_for_preset(
|
|
|
48
56
|
raise ValueError(f"Unknown safety_settings_preset: {safety_settings_preset}")
|
|
49
57
|
|
|
50
58
|
|
|
51
|
-
def _get_model_name_for_request(request: Request) -> str:
|
|
52
|
-
# We have to strip "-safety-" suffixes from model names because they are not part of the Vertex AI model name
|
|
53
|
-
# TODO: Clean up this hack
|
|
54
|
-
return request.model_engine.split("-safety-")[0]
|
|
55
|
-
|
|
56
|
-
|
|
57
59
|
class VertexAIClient(CachingClient, ABC):
|
|
58
60
|
"""Client for Vertex AI models"""
|
|
59
61
|
|
|
60
62
|
def __init__(
|
|
61
|
-
self,
|
|
63
|
+
self,
|
|
64
|
+
cache_config: CacheConfig,
|
|
65
|
+
project_id: str,
|
|
66
|
+
location: str,
|
|
67
|
+
safety_settings_preset: Optional[str] = None,
|
|
68
|
+
vertexai_model: Optional[str] = None,
|
|
62
69
|
) -> None:
|
|
63
70
|
super().__init__(cache_config=cache_config)
|
|
64
71
|
self.project_id = project_id
|
|
@@ -67,8 +74,15 @@ class VertexAIClient(CachingClient, ABC):
|
|
|
67
74
|
self.safety_settings_preset = safety_settings_preset
|
|
68
75
|
self.safety_settings = _get_safety_settings_for_preset(safety_settings_preset)
|
|
69
76
|
|
|
77
|
+
self.vertexai_model = vertexai_model
|
|
78
|
+
|
|
70
79
|
vertexai.init(project=self.project_id, location=self.location)
|
|
71
80
|
|
|
81
|
+
def _get_model_name_for_request(self, request: Request) -> str:
|
|
82
|
+
if self.vertexai_model is not None:
|
|
83
|
+
return self.vertexai_model
|
|
84
|
+
return request.model_engine
|
|
85
|
+
|
|
72
86
|
def make_cache_key_with_safety_settings_preset(self, raw_request: Mapping, request: Request) -> Mapping:
|
|
73
87
|
"""Construct the key for the cache using the raw request.
|
|
74
88
|
|
|
@@ -111,7 +125,7 @@ class VertexAITextClient(VertexAIClient):
|
|
|
111
125
|
}
|
|
112
126
|
|
|
113
127
|
completions: List[GeneratedOutput] = []
|
|
114
|
-
model_name: str = _get_model_name_for_request(request)
|
|
128
|
+
model_name: str = self._get_model_name_for_request(request)
|
|
115
129
|
|
|
116
130
|
try:
|
|
117
131
|
|
|
@@ -193,12 +207,20 @@ class VertexAIChatClient(VertexAIClient):
|
|
|
193
207
|
|
|
194
208
|
def make_request(self, request: Request) -> RequestResult:
|
|
195
209
|
"""Make a request"""
|
|
196
|
-
contents
|
|
210
|
+
contents = [request.prompt]
|
|
197
211
|
|
|
198
212
|
# For the multimodal case, build up the content with the media objects of `request.multimodal_prompt`
|
|
199
213
|
if request.multimodal_prompt is not None:
|
|
200
214
|
return self._make_multimodal_request(request)
|
|
201
215
|
|
|
216
|
+
if request.messages is not None:
|
|
217
|
+
contents = []
|
|
218
|
+
role_mapping = {"user": "user", "assistant": "model"}
|
|
219
|
+
for msg in request.messages:
|
|
220
|
+
contents.append(
|
|
221
|
+
Content(role=role_mapping.get(msg["role"], "user"), parts=[Part.from_text(msg["content"])])
|
|
222
|
+
)
|
|
223
|
+
|
|
202
224
|
parameters = {
|
|
203
225
|
"temperature": request.temperature,
|
|
204
226
|
"max_output_tokens": request.max_tokens,
|
|
@@ -217,7 +239,7 @@ class VertexAIChatClient(VertexAIClient):
|
|
|
217
239
|
}
|
|
218
240
|
|
|
219
241
|
completions: List[GeneratedOutput] = []
|
|
220
|
-
model_name: str = _get_model_name_for_request(request)
|
|
242
|
+
model_name: str = self._get_model_name_for_request(request)
|
|
221
243
|
model = self.get_model(model_name)
|
|
222
244
|
|
|
223
245
|
try:
|
|
@@ -263,7 +285,7 @@ class VertexAIChatClient(VertexAIClient):
|
|
|
263
285
|
cache_key = self.make_cache_key_with_safety_settings_preset(
|
|
264
286
|
{
|
|
265
287
|
"model_name": model_name,
|
|
266
|
-
"prompt": request.prompt,
|
|
288
|
+
"prompt": request.messages or request.prompt,
|
|
267
289
|
**parameters,
|
|
268
290
|
},
|
|
269
291
|
request,
|
|
@@ -338,6 +360,16 @@ class VertexAIChatClient(VertexAIClient):
|
|
|
338
360
|
for media_object in request.multimodal_prompt.media_objects:
|
|
339
361
|
if media_object.is_type("image") and media_object.location:
|
|
340
362
|
contents.append(Part.from_image(Image.load_from_file(media_object.location)))
|
|
363
|
+
elif media_object.is_type("video") and media_object.location:
|
|
364
|
+
# Following this example
|
|
365
|
+
# https://cloud.google.com/vertex-ai/generative-ai/docs/samples/googlegenaisdk-textgen-with-local-video
|
|
366
|
+
with open(media_object.location, "rb") as fp:
|
|
367
|
+
video_content = fp.read()
|
|
368
|
+
contents.append(Part.from_data(data=video_content, mime_type=media_object.content_type))
|
|
369
|
+
elif media_object.is_type("audio") and media_object.location:
|
|
370
|
+
contents.append(
|
|
371
|
+
Part.from_data(get_contents_as_bytes(media_object.location), mime_type=media_object.content_type)
|
|
372
|
+
)
|
|
341
373
|
elif media_object.is_type(TEXT_TYPE):
|
|
342
374
|
if media_object.text is None:
|
|
343
375
|
raise ValueError("MediaObject of text type has missing text field value")
|
|
@@ -355,7 +387,7 @@ class VertexAIChatClient(VertexAIClient):
|
|
|
355
387
|
}
|
|
356
388
|
|
|
357
389
|
completions: List[GeneratedOutput] = []
|
|
358
|
-
model_name: str = _get_model_name_for_request(request)
|
|
390
|
+
model_name: str = self._get_model_name_for_request(request)
|
|
359
391
|
model = self.get_model(model_name)
|
|
360
392
|
|
|
361
393
|
request_time = 0
|
|
@@ -95,8 +95,8 @@ class HuggingFaceVision2SeqClient(CachingClient):
|
|
|
95
95
|
|
|
96
96
|
def do_it() -> Dict[str, Any]:
|
|
97
97
|
messages = [{"role": "user", "content": multimodal_prompt}]
|
|
98
|
-
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
|
|
99
|
-
inputs = processor(
|
|
98
|
+
prompt = processor.apply_chat_template(messages, add_generation_prompt=True) # type: ignore
|
|
99
|
+
inputs = processor( # type: ignore
|
|
100
100
|
text=[prompt] * request.num_completions,
|
|
101
101
|
images=[
|
|
102
102
|
[load_image(image_path) for image_path in image_paths]
|
|
@@ -107,8 +107,10 @@ class HuggingFaceVision2SeqClient(CachingClient):
|
|
|
107
107
|
inputs = {k: v.to(self._device) for k, v in inputs.items()}
|
|
108
108
|
|
|
109
109
|
# Generate
|
|
110
|
-
generated_ids = model.generate(**inputs, **generation_args)
|
|
111
|
-
generated_texts: List[str] = processor.batch_decode(
|
|
110
|
+
generated_ids = model.generate(**inputs, **generation_args) # type: ignore
|
|
111
|
+
generated_texts: List[str] = processor.batch_decode( # type: ignore
|
|
112
|
+
generated_ids, skip_special_tokens=True
|
|
113
|
+
)
|
|
112
114
|
return {"output": generated_texts}
|
|
113
115
|
|
|
114
116
|
# Include the prompt and model name in the cache key
|
|
@@ -50,7 +50,7 @@ class HuggingFaceVLMClient(CachingClient):
|
|
|
50
50
|
with self._models_lock:
|
|
51
51
|
model_id: str = self._models_aliases.get(model_name, model_name)
|
|
52
52
|
if model_id not in self._models:
|
|
53
|
-
self._models[model_id] = pipeline("image-to-text", model=model_id, device_map="auto")
|
|
53
|
+
self._models[model_id] = pipeline("image-to-text", model=model_id, device_map="auto") # type: ignore
|
|
54
54
|
return self._models[model_id]
|
|
55
55
|
|
|
56
56
|
def make_request(self, request: Request) -> RequestResult:
|
|
@@ -80,7 +80,7 @@ class HuggingFaceVLMClient(CachingClient):
|
|
|
80
80
|
|
|
81
81
|
def do_it() -> Dict[str, Any]:
|
|
82
82
|
model: ImageToTextPipeline = self._get_model(request.model_deployment)
|
|
83
|
-
outputs = model(image, prompt=prompt, generate_kwargs=generation_args)
|
|
83
|
+
outputs = model(image, prompt=prompt, generate_kwargs=generation_args) # type: ignore
|
|
84
84
|
return outputs[0]
|
|
85
85
|
|
|
86
86
|
cache_key = CachingClient.make_cache_key(
|
|
@@ -89,14 +89,18 @@ class IDEFICSClient(CachingClient):
|
|
|
89
89
|
input_args: Dict[str, Union[str, bool]] = {"return_tensors": "pt"}
|
|
90
90
|
generation_args = {
|
|
91
91
|
"max_new_tokens": request.max_tokens,
|
|
92
|
-
"bad_words_ids": processor.tokenizer(
|
|
92
|
+
"bad_words_ids": processor.tokenizer( # type: ignore
|
|
93
|
+
self.BAD_WORD_TOKENS, add_special_tokens=False
|
|
94
|
+
).input_ids,
|
|
93
95
|
}
|
|
94
96
|
|
|
95
97
|
if self.END_OF_UTTERANCE_TOKEN in request.stop_sequences:
|
|
96
98
|
# Following https://huggingface.co/HuggingFaceM4/idefics-80b-instruct,
|
|
97
99
|
# specify <end_of_utterance> as an exit condition.
|
|
98
100
|
input_args["add_end_of_utterance_token"] = False
|
|
99
|
-
exit_condition = processor.tokenizer(
|
|
101
|
+
exit_condition = processor.tokenizer( # type: ignore
|
|
102
|
+
self.END_OF_UTTERANCE_TOKEN, add_special_tokens=False
|
|
103
|
+
).input_ids
|
|
100
104
|
generation_args["eos_token_id"] = exit_condition
|
|
101
105
|
|
|
102
106
|
multimodal_prompt: List[Union[str, Image.Image]] = []
|
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
from .src.flamingo import Flamingo
|
|
2
|
-
from .src.factory import create_model_and_transforms
|
|
1
|
+
from helm.clients.vision_language.open_flamingo.src.flamingo import Flamingo
|
|
2
|
+
from helm.clients.vision_language.open_flamingo.src.factory import create_model_and_transforms
|
|
@@ -7,9 +7,9 @@ from typing import Optional
|
|
|
7
7
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
8
8
|
|
|
9
9
|
from helm.common.general import handle_module_not_found_error
|
|
10
|
-
from .flamingo import Flamingo
|
|
11
|
-
from .flamingo_lm import FlamingoLMMixin
|
|
12
|
-
from .utils import extend_instance
|
|
10
|
+
from helm.clients.vision_language.open_flamingo.src.flamingo import Flamingo
|
|
11
|
+
from helm.clients.vision_language.open_flamingo.src.flamingo_lm import FlamingoLMMixin
|
|
12
|
+
from helm.clients.vision_language.open_flamingo.src.utils import extend_instance
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
def create_model_and_transforms(
|
|
@@ -5,7 +5,7 @@ Source: https://github.com/mlfoundations/open_flamingo
|
|
|
5
5
|
import torch
|
|
6
6
|
from einops import rearrange
|
|
7
7
|
from torch import nn
|
|
8
|
-
from .helpers import PerceiverResampler
|
|
8
|
+
from helm.clients.vision_language.open_flamingo.src.helpers import PerceiverResampler
|
|
9
9
|
from torch.distributed.fsdp.wrap import (
|
|
10
10
|
enable_wrap,
|
|
11
11
|
wrap,
|
|
@@ -15,7 +15,7 @@ from torch.distributed.fsdp import (
|
|
|
15
15
|
FullyShardedDataParallel as FSDP,
|
|
16
16
|
)
|
|
17
17
|
|
|
18
|
-
from .utils import apply_with_stopping_condition
|
|
18
|
+
from helm.clients.vision_language.open_flamingo.src.utils import apply_with_stopping_condition
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
class Flamingo(nn.Module):
|
|
@@ -3,8 +3,8 @@ Source: https://github.com/mlfoundations/open_flamingo
|
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
5
|
import torch.nn as nn
|
|
6
|
-
from .helpers import GatedCrossAttentionBlock
|
|
7
|
-
from .utils import getattr_recursive, setattr_recursive
|
|
6
|
+
from helm.clients.vision_language.open_flamingo.src.helpers import GatedCrossAttentionBlock
|
|
7
|
+
from helm.clients.vision_language.open_flamingo.src.utils import getattr_recursive, setattr_recursive
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
class FlamingoLayer(nn.Module):
|
|
@@ -93,7 +93,7 @@ class PaliGemmaClient(CachingClient):
|
|
|
93
93
|
else:
|
|
94
94
|
raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
|
|
95
95
|
prompt_text: str = "\n".join(prompt_pieces)
|
|
96
|
-
model_inputs = processor(text=prompt_text, images=images, return_tensors="pt").to(self._device)
|
|
96
|
+
model_inputs = processor(text=prompt_text, images=images, return_tensors="pt").to(self._device) # type: ignore
|
|
97
97
|
input_len = model_inputs["input_ids"].shape[-1]
|
|
98
98
|
|
|
99
99
|
completions: List[GeneratedOutput] = []
|
|
@@ -109,7 +109,7 @@ class PaliGemmaClient(CachingClient):
|
|
|
109
109
|
)[0]
|
|
110
110
|
if not request.echo_prompt:
|
|
111
111
|
generation = generation[input_len:]
|
|
112
|
-
decoded = processor.decode(generation, skip_special_tokens=True)
|
|
112
|
+
decoded = processor.decode(generation, skip_special_tokens=True) # type: ignore
|
|
113
113
|
return {"output": decoded}
|
|
114
114
|
|
|
115
115
|
# Include the prompt and model name in the cache key
|
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
from threading import Lock
|
|
2
|
+
from typing import Any, Dict, List, Optional
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
|
|
5
|
+
from transformers import AutoProcessor
|
|
6
|
+
from qwen_vl_utils import process_vision_info
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from helm.common.cache import CacheConfig
|
|
10
|
+
from helm.common.gpu_utils import get_torch_device_name
|
|
11
|
+
from helm.common.hierarchical_logger import hlog, htrack_block
|
|
12
|
+
from helm.common.media_object import TEXT_TYPE
|
|
13
|
+
from helm.common.request import Request, RequestResult, GeneratedOutput, Token
|
|
14
|
+
from helm.common.request import wrap_request_time
|
|
15
|
+
from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass(frozen=True)
|
|
19
|
+
class LoadedModelProcessor:
|
|
20
|
+
model: Any
|
|
21
|
+
processor: AutoProcessor
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Global cache for all models
|
|
25
|
+
_models_lock: Lock = Lock()
|
|
26
|
+
_models: Dict[str, Optional[LoadedModelProcessor]] = {
|
|
27
|
+
"Qwen/Qwen2-VL-7B-Instruct": None,
|
|
28
|
+
"Qwen/Qwen2-VL-72B-Instruct": None,
|
|
29
|
+
"Qwen/Qwen2.5-VL-3B-Instruct": None,
|
|
30
|
+
"Qwen/Qwen2.5-VL-7B-Instruct": None,
|
|
31
|
+
"Qwen/Qwen2.5-VL-32B-Instruct": None,
|
|
32
|
+
"Qwen/Qwen2.5-VL-72B-Instruct": None,
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class Qwen2VLMClient(CachingClient):
|
|
37
|
+
def __init__(self, cache_config: CacheConfig):
|
|
38
|
+
super().__init__(cache_config=cache_config)
|
|
39
|
+
self._device: str = get_torch_device_name()
|
|
40
|
+
|
|
41
|
+
def _get_model_name(self, helm_model_name: str) -> str:
|
|
42
|
+
if helm_model_name == "qwen2-vl-7b-instruct":
|
|
43
|
+
return "Qwen/Qwen2-VL-7B-Instruct"
|
|
44
|
+
elif helm_model_name == "qwen2-vl-72b-instruct":
|
|
45
|
+
return "Qwen/Qwen2-VL-72B-Instruct"
|
|
46
|
+
elif helm_model_name == "qwen2.5-vl-3b-instruct":
|
|
47
|
+
return "Qwen/Qwen2.5-VL-3B-Instruct"
|
|
48
|
+
elif helm_model_name == "qwen2.5-vl-7b-instruct":
|
|
49
|
+
return "Qwen/Qwen2.5-VL-7B-Instruct"
|
|
50
|
+
elif helm_model_name == "qwen2.5-vl-32b-instruct":
|
|
51
|
+
return "Qwen/Qwen2.5-VL-32B-Instruct"
|
|
52
|
+
elif helm_model_name == "qwen2.5-vl-72b-instruct":
|
|
53
|
+
return "Qwen/Qwen2.5-VL-72B-Instruct"
|
|
54
|
+
else:
|
|
55
|
+
raise ValueError(f"Unhandled model name: {helm_model_name}")
|
|
56
|
+
|
|
57
|
+
def _get_model(self, helm_model_name: str) -> LoadedModelProcessor:
|
|
58
|
+
from transformers import Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration
|
|
59
|
+
|
|
60
|
+
global _models_lock, _models
|
|
61
|
+
|
|
62
|
+
model_name = self._get_model_name(helm_model_name)
|
|
63
|
+
with _models_lock:
|
|
64
|
+
loaded = _models[model_name]
|
|
65
|
+
if loaded is None:
|
|
66
|
+
hlog(f"Loading model {model_name} and caching in memory...")
|
|
67
|
+
# Use different loading routines depending on whether it's Qwen2.5 or Qwen2.
|
|
68
|
+
if "2.5" in model_name:
|
|
69
|
+
# Qwen2.5: by default use torch_dtype="auto". You can enable flash_attention_2 if desired.
|
|
70
|
+
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
|
71
|
+
model_name,
|
|
72
|
+
torch_dtype=torch.bfloat16,
|
|
73
|
+
device_map="auto",
|
|
74
|
+
attn_implementation="flash_attention_2",
|
|
75
|
+
).eval()
|
|
76
|
+
else:
|
|
77
|
+
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
|
78
|
+
model_name,
|
|
79
|
+
torch_dtype=torch.bfloat16,
|
|
80
|
+
device_map="auto",
|
|
81
|
+
attn_implementation="flash_attention_2",
|
|
82
|
+
).eval()
|
|
83
|
+
processor = AutoProcessor.from_pretrained(model_name)
|
|
84
|
+
loaded = LoadedModelProcessor(model=model, processor=processor)
|
|
85
|
+
_models[model_name] = loaded
|
|
86
|
+
return loaded
|
|
87
|
+
|
|
88
|
+
def make_request(self, request: Request) -> RequestResult:
|
|
89
|
+
assert request.multimodal_prompt is not None, "Multimodal prompt is required"
|
|
90
|
+
|
|
91
|
+
# Build messages by collating all media objects into a single "user" message.
|
|
92
|
+
message_content = []
|
|
93
|
+
for media_object in request.multimodal_prompt.media_objects:
|
|
94
|
+
if media_object.is_type("image") and media_object.location:
|
|
95
|
+
message_content.append({"type": "image", "image": media_object.location})
|
|
96
|
+
elif media_object.is_type(TEXT_TYPE):
|
|
97
|
+
if media_object.text is None:
|
|
98
|
+
raise ValueError("MediaObject of text type has missing text field value")
|
|
99
|
+
message_content.append({"type": "text", "text": media_object.text})
|
|
100
|
+
else:
|
|
101
|
+
raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
|
|
102
|
+
|
|
103
|
+
messages = [{"role": "user", "content": message_content}]
|
|
104
|
+
|
|
105
|
+
generation_args = {
|
|
106
|
+
"max_new_tokens": request.max_tokens,
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
completions: List[GeneratedOutput] = []
|
|
110
|
+
request_time: float = 0
|
|
111
|
+
request_datetime: Optional[int] = None
|
|
112
|
+
all_cached: bool = True
|
|
113
|
+
|
|
114
|
+
with htrack_block(f"Generating for prompt: {request.multimodal_prompt.text}"):
|
|
115
|
+
for completion_index in range(request.num_completions):
|
|
116
|
+
try:
|
|
117
|
+
|
|
118
|
+
def do_it() -> Dict[str, Any]:
|
|
119
|
+
loaded = self._get_model(request.model_engine)
|
|
120
|
+
model = loaded.model
|
|
121
|
+
processor = loaded.processor
|
|
122
|
+
|
|
123
|
+
# Prepare text and vision inputs.
|
|
124
|
+
text = processor.apply_chat_template( # type: ignore
|
|
125
|
+
messages, tokenize=False, add_generation_prompt=True
|
|
126
|
+
)
|
|
127
|
+
image_inputs, video_inputs = process_vision_info(messages)
|
|
128
|
+
inputs = processor( # type: ignore
|
|
129
|
+
text=[text],
|
|
130
|
+
images=image_inputs,
|
|
131
|
+
videos=video_inputs,
|
|
132
|
+
padding=True,
|
|
133
|
+
return_tensors="pt",
|
|
134
|
+
).to(self._device)
|
|
135
|
+
|
|
136
|
+
generated_ids = model.generate(**inputs, **generation_args)
|
|
137
|
+
# Remove the input prefix from outputs.
|
|
138
|
+
generated_ids_trimmed = [
|
|
139
|
+
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
|
140
|
+
]
|
|
141
|
+
output_text = processor.batch_decode( # type: ignore
|
|
142
|
+
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
|
143
|
+
)
|
|
144
|
+
# For simplicity, we split tokens by whitespace.
|
|
145
|
+
completion = output_text[0]
|
|
146
|
+
tokens = completion.split()
|
|
147
|
+
return {"output": (completion, tokens)}
|
|
148
|
+
|
|
149
|
+
cache_key = CachingClient.make_cache_key(
|
|
150
|
+
raw_request={
|
|
151
|
+
"completion_index": completion_index,
|
|
152
|
+
"model": request.model,
|
|
153
|
+
"prompt": generate_uid_for_multimodal_prompt(request.multimodal_prompt),
|
|
154
|
+
**generation_args,
|
|
155
|
+
},
|
|
156
|
+
request=request,
|
|
157
|
+
)
|
|
158
|
+
result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
|
|
159
|
+
except RuntimeError as model_error:
|
|
160
|
+
return RequestResult(
|
|
161
|
+
success=False,
|
|
162
|
+
cached=False,
|
|
163
|
+
error=str(model_error),
|
|
164
|
+
completions=[],
|
|
165
|
+
embedding=[],
|
|
166
|
+
)
|
|
167
|
+
|
|
168
|
+
text_out, tokens = result["output"]
|
|
169
|
+
completions.append(
|
|
170
|
+
GeneratedOutput(
|
|
171
|
+
text=text_out,
|
|
172
|
+
logprob=0,
|
|
173
|
+
tokens=[Token(text=str(token), logprob=0) for token in tokens],
|
|
174
|
+
)
|
|
175
|
+
)
|
|
176
|
+
hlog(f"Generated: {text_out}")
|
|
177
|
+
request_time += result["request_time"]
|
|
178
|
+
request_datetime = request_datetime or result.get("request_datetime")
|
|
179
|
+
all_cached = all_cached and cached
|
|
180
|
+
|
|
181
|
+
return RequestResult(
|
|
182
|
+
success=True,
|
|
183
|
+
cached=all_cached,
|
|
184
|
+
request_time=request_time,
|
|
185
|
+
request_datetime=request_datetime,
|
|
186
|
+
completions=completions,
|
|
187
|
+
embedding=[],
|
|
188
|
+
)
|
|
@@ -115,14 +115,16 @@ class QwenVLMClient(CachingClient):
|
|
|
115
115
|
|
|
116
116
|
def do_it() -> Dict[str, Any]:
|
|
117
117
|
if request.model_engine == "qwen-vl-chat":
|
|
118
|
-
completion, _ = model.chat(
|
|
118
|
+
completion, _ = model.chat( # type: ignore
|
|
119
|
+
tokenizer, query=tokenizer.from_list_format(query), history=None # type: ignore
|
|
120
|
+
)
|
|
119
121
|
else:
|
|
120
|
-
inputs = tokenizer(tokenizer.from_list_format(query), return_tensors="pt")
|
|
122
|
+
inputs = tokenizer(tokenizer.from_list_format(query), return_tensors="pt") # type: ignore
|
|
121
123
|
inputs = inputs.to(self._device)
|
|
122
|
-
pred = model.generate(**inputs, **generation_args)
|
|
123
|
-
completion = tokenizer.decode(pred.cpu()[0], skip_special_tokens=False)
|
|
124
|
+
pred = model.generate(**inputs, **generation_args) # type: ignore
|
|
125
|
+
completion = tokenizer.decode(pred.cpu()[0], skip_special_tokens=False) # type: ignore
|
|
124
126
|
|
|
125
|
-
tokens: List[str] = tokenizer.tokenize(completion)
|
|
127
|
+
tokens: List[str] = tokenizer.tokenize(completion) # type: ignore
|
|
126
128
|
return {"output": (completion, tokens)}
|
|
127
129
|
|
|
128
130
|
# Include the prompt and model name in the cache key
|
helm/clients/vllm_client.py
CHANGED
|
@@ -2,13 +2,15 @@ from typing import Any, Dict, Optional
|
|
|
2
2
|
|
|
3
3
|
from helm.common.cache import CacheConfig
|
|
4
4
|
from helm.common.request import Request
|
|
5
|
-
from helm.clients.openai_client import
|
|
5
|
+
from helm.clients.openai_client import OpenAILegacyCompletionsClient
|
|
6
6
|
from helm.tokenizers.tokenizer import Tokenizer
|
|
7
7
|
|
|
8
8
|
|
|
9
|
-
class VLLMClient(
|
|
9
|
+
class VLLMClient(OpenAILegacyCompletionsClient):
|
|
10
10
|
"""Sends request to a vLLM server using the OpenAI-compatible API.
|
|
11
11
|
|
|
12
|
+
Only supports the legacy Text Completions API, rather than the Chat Completions API.
|
|
13
|
+
|
|
12
14
|
See: https://docs.vllm.ai/en/latest/getting_started/quickstart.html#openai-compatible-server"""
|
|
13
15
|
|
|
14
16
|
def __init__(
|
|
@@ -29,10 +31,6 @@ class VLLMClient(OpenAIClient):
|
|
|
29
31
|
self.tokenizer = tokenizer
|
|
30
32
|
self.tokenizer_name = tokenizer_name
|
|
31
33
|
|
|
32
|
-
def _is_chat_model_engine(self, model_engine: str) -> bool:
|
|
33
|
-
# Only support vLLM completion models for now.
|
|
34
|
-
return False
|
|
35
|
-
|
|
36
34
|
def _get_model_for_request(self, request: Request) -> str:
|
|
37
35
|
# The `model` parameter for vLLM should be the whole model name including the creator organization,
|
|
38
36
|
# unlike OpenAI which only uses the model engine.
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Mapping, Optional
|
|
2
|
+
|
|
3
|
+
from helm.clients.client import CachingClient
|
|
4
|
+
from helm.common.cache import CacheConfig
|
|
5
|
+
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
6
|
+
from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput, Token
|
|
7
|
+
|
|
8
|
+
try:
|
|
9
|
+
from writerai import Writer
|
|
10
|
+
from writerai.types.chat_completion import ChatCompletion
|
|
11
|
+
except ModuleNotFoundError as e:
|
|
12
|
+
handle_module_not_found_error(e, ["openai"])
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class WriterClient(CachingClient):
    """Client for the Writer chat completions API.

    Sends requests through the official `writerai` SDK and caches raw
    responses via `CachingClient`. Only the chat endpoint is used; prompts
    without explicit messages are wrapped in a single user message.
    """

    def __init__(self, cache_config: CacheConfig, api_key: Optional[str] = None):
        super().__init__(cache_config=cache_config)
        # If api_key is None, the Writer SDK falls back to its own
        # environment-based credential discovery.
        self._writer_client = Writer(api_key=api_key)

    def _get_messages_from_request(self, request: Request) -> List[Dict]:
        """Build the chat `messages` payload from a HELM request.

        Raises:
            ValueError: if both `prompt` and `messages` are set, or if a
                multimodal prompt is provided (not supported by this client).
        """
        if request.prompt and request.messages:
            raise ValueError(f"Only one of `prompt` and `messages` may be set in request: {request}")
        if request.multimodal_prompt:
            raise ValueError("`multimodal_prompt` is not supported by WriterClient")
        if request.messages:
            # Pass through only the fields the Writer API expects.
            return [{"role": message["role"], "content": message["content"]} for message in request.messages]
        else:
            # A bare prompt becomes a single-turn user message.
            return [{"role": "user", "content": request.prompt}]

    def _convert_chat_completion_to_generated_outputs(
        self, chat_completion: ChatCompletion, request: Request
    ) -> List[GeneratedOutput]:
        """Convert a Writer `ChatCompletion` into HELM `GeneratedOutput`s.

        One output is produced per choice. Per-token logprobs are included
        when the API returned them; otherwise `tokens` is empty and the
        sequence logprob defaults to 0.0.
        """
        generated_outputs: List[GeneratedOutput] = []
        for choice in chat_completion.choices:
            # `message.content` is Optional in OpenAI-style SDK models;
            # coalesce None to "" so concatenation below cannot raise.
            raw_completion_content = choice.message.content or ""
            # The Writer chat completion API doesn't support echo.
            # If `echo_prompt` is true, combine the prompt and completion.
            text: str = request.prompt + raw_completion_content if request.echo_prompt else raw_completion_content
            tokens: List[Token] = []
            if choice.logprobs and choice.logprobs.content:
                tokens = [
                    Token(text=choice_token.token, logprob=choice_token.logprob)
                    for choice_token in choice.logprobs.content
                ]
            generated_output = GeneratedOutput(
                text=text,
                # Sequence logprob is the sum of token logprobs when available.
                logprob=sum(token.logprob for token in tokens) if tokens else 0.0,
                tokens=tokens,
                finish_reason={"reason": choice.finish_reason},
            )
            generated_outputs.append(generated_output)
        return generated_outputs

    def _convert_request_to_raw_request(self, request: Request) -> Dict:
        """Map a HELM `Request` onto the Writer chat API's keyword arguments."""
        raw_request: Dict[str, Any] = {
            "messages": self._get_messages_from_request(request),
            # The Writer API expects the bare model engine, without the
            # "creator_organization/" prefix used by HELM model names.
            "model": request.model.split("/")[-1],
            # Request logprobs whenever the caller asked for any top-k tokens.
            "logprobs": bool(request.top_k_per_token),
            "max_tokens": request.max_tokens,
            "n": request.num_completions,
            "stop": request.stop_sequences,
            "temperature": request.temperature,
            "top_p": request.top_p,
        }
        if request.response_format and request.response_format.json_schema:
            raw_request["response_format"] = {
                "type": "json_schema",
                "json_schema": {
                    "schema": request.response_format.json_schema,
                },
            }
        return raw_request

    def make_request(self, request: Request) -> RequestResult:
        """Send the request (or serve it from cache) and wrap the result.

        Any exception during the API call, cache access, or response
        validation is captured and reported as an unsuccessful
        `RequestResult` rather than propagated.
        """
        raw_request = self._convert_request_to_raw_request(request)
        cache_key: Mapping = CachingClient.make_cache_key(raw_request, request)

        def do_it() -> Dict[Any, Any]:
            # Serialize the SDK response to a plain dict so it can be cached.
            return self._writer_client.chat.chat(**raw_request).model_dump()

        try:
            raw_response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
            # Re-validate the cached dict back into the SDK's response model.
            chat_completion: ChatCompletion = ChatCompletion.model_validate(raw_response)
        except Exception as error:
            return RequestResult(
                success=False,
                cached=False,
                error=str(error),
                completions=[],
                embedding=[],
            )

        generated_outputs = self._convert_chat_completion_to_generated_outputs(chat_completion, request)

        return RequestResult(
            success=True,
            cached=cached,
            request_time=raw_response["request_time"],
            request_datetime=raw_response["request_datetime"],
            completions=generated_outputs,
            embedding=[],
        )
|