crfm-helm 0.5.4__py3-none-any.whl → 0.5.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of crfm-helm might be problematic. Click here for more details.
- crfm_helm-0.5.6.dist-info/METADATA +427 -0
- crfm_helm-0.5.6.dist-info/RECORD +941 -0
- {crfm_helm-0.5.4.dist-info → crfm_helm-0.5.6.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +13 -1
- helm/benchmark/adaptation/adapters/adapter_factory.py +15 -1
- helm/benchmark/adaptation/adapters/binary_ranking_adapter.py +1 -1
- helm/benchmark/adaptation/adapters/chat_adapter.py +49 -0
- helm/benchmark/adaptation/adapters/ehr_instruction_adapter.py +108 -0
- helm/benchmark/adaptation/adapters/generation_adapter.py +1 -1
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +4 -4
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +1 -1
- helm/benchmark/adaptation/adapters/multimodal/generation_multimodal_adapter.py +4 -2
- helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -1
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +1 -1
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +4 -2
- helm/benchmark/adaptation/adapters/multimodal/test_multimodal_prompt.py +1 -1
- helm/benchmark/adaptation/adapters/multiple_choice_calibrated_adapter.py +1 -1
- helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +2 -2
- helm/benchmark/adaptation/adapters/multiple_choice_joint_chain_of_thought_adapter.py +87 -0
- helm/benchmark/adaptation/adapters/multiple_choice_separate_adapter.py +1 -1
- helm/benchmark/adaptation/adapters/test_adapter.py +4 -4
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +3 -3
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +2 -2
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +2 -2
- helm/benchmark/adaptation/common_adapter_specs.py +69 -4
- helm/benchmark/adaptation/prompt.py +1 -1
- helm/benchmark/annotation/aci_bench_annotator.py +95 -0
- helm/benchmark/annotation/air_bench_annotator.py +21 -6
- helm/benchmark/annotation/annotator.py +5 -0
- helm/benchmark/annotation/annotator_factory.py +3 -20
- helm/benchmark/annotation/autobencher_capabilities_annotator.py +107 -0
- helm/benchmark/annotation/autobencher_safety_annotator.py +98 -0
- helm/benchmark/annotation/bigcodebench_annotator.py +108 -0
- helm/benchmark/annotation/bird_sql_annotator.py +58 -0
- helm/benchmark/annotation/chw_care_plan_annotator.py +93 -0
- helm/benchmark/annotation/czech_bank_qa_annotator.py +78 -0
- helm/benchmark/annotation/dischargeme_annotator.py +107 -0
- helm/benchmark/annotation/ehr_sql_annotator.py +87 -0
- helm/benchmark/annotation/helpdesk_call_summarization_annotator.py +131 -0
- helm/benchmark/annotation/image2struct/image_compiler_annotator.py +6 -1
- helm/benchmark/annotation/live_qa_annotator.py +1 -1
- helm/benchmark/annotation/med_dialog_annotator.py +99 -0
- helm/benchmark/annotation/medalign_annotator.py +100 -0
- helm/benchmark/annotation/medi_qa_annotator.py +98 -0
- helm/benchmark/annotation/medication_qa_annotator.py +87 -63
- helm/benchmark/annotation/mental_health_annotator.py +98 -0
- helm/benchmark/annotation/mimic_bhc_annotator.py +100 -0
- helm/benchmark/annotation/mimic_rrs_annotator.py +100 -0
- helm/benchmark/annotation/model_as_judge.py +214 -6
- helm/benchmark/annotation/mtsamples_procedures_annotator.py +98 -0
- helm/benchmark/annotation/mtsamples_replicate_annotator.py +101 -0
- helm/benchmark/annotation/omni_math/gpt_evaluation_template.txt +152 -0
- helm/benchmark/annotation/omni_math/gpt_evaluation_zero_shot_template.txt +36 -0
- helm/benchmark/annotation/omni_math_annotator.py +131 -0
- helm/benchmark/annotation/spider_annotator.py +18 -0
- helm/benchmark/annotation/starr_patient_instructions_annotator.py +98 -0
- helm/benchmark/annotation/wildbench/eval_template.pairwise.v2.md +75 -0
- helm/benchmark/annotation/wildbench/eval_template.score.v2.md +66 -0
- helm/benchmark/annotation/wildbench_annotator.py +119 -0
- helm/benchmark/annotation_executor.py +35 -15
- helm/benchmark/augmentations/cleva_perturbation.py +9 -8
- helm/benchmark/augmentations/contraction_expansion_perturbation.py +2 -2
- helm/benchmark/augmentations/contrast_sets_perturbation.py +2 -2
- helm/benchmark/augmentations/dialect_perturbation.py +4 -5
- helm/benchmark/augmentations/extra_space_perturbation.py +2 -2
- helm/benchmark/augmentations/filler_words_perturbation.py +2 -2
- helm/benchmark/augmentations/gender_perturbation.py +2 -2
- helm/benchmark/augmentations/lowercase_perturbation.py +2 -2
- helm/benchmark/augmentations/mild_mix_perturbation.py +6 -6
- helm/benchmark/augmentations/misspelling_perturbation.py +2 -2
- helm/benchmark/augmentations/person_name_perturbation.py +4 -5
- helm/benchmark/augmentations/perturbation.py +1 -1
- helm/benchmark/augmentations/space_perturbation.py +2 -2
- helm/benchmark/augmentations/suffix_perturbation.py +2 -2
- helm/benchmark/augmentations/synonym_perturbation.py +4 -3
- helm/benchmark/augmentations/test_perturbation.py +16 -13
- helm/benchmark/augmentations/translate_perturbation.py +2 -2
- helm/benchmark/augmentations/typos_perturbation.py +2 -2
- helm/benchmark/data_preprocessor.py +2 -2
- helm/benchmark/executor.py +11 -12
- helm/benchmark/huggingface_registration.py +2 -7
- helm/benchmark/metrics/aci_bench_metrics.py +14 -0
- helm/benchmark/metrics/basic_metrics.py +6 -6
- helm/benchmark/metrics/bbq_metrics.py +2 -2
- helm/benchmark/metrics/bias_metrics.py +12 -3
- helm/benchmark/metrics/bias_word_lists.py +1 -1
- helm/benchmark/metrics/bigcodebench_metrics.py +25 -0
- helm/benchmark/metrics/bird_sql_metrics.py +28 -0
- helm/benchmark/metrics/chw_care_plan_metrics.py +14 -0
- helm/benchmark/metrics/classification_metrics.py +76 -12
- helm/benchmark/metrics/cleva_harms_metrics.py +10 -9
- helm/benchmark/metrics/code_metrics.py +5 -5
- helm/benchmark/metrics/comet_metric.py +125 -0
- helm/benchmark/metrics/common_metric_specs.py +9 -2
- helm/benchmark/metrics/conv_fin_qa_calc_metrics.py +72 -0
- helm/benchmark/metrics/copyright_metrics.py +4 -4
- helm/benchmark/metrics/czech_bank_qa_metrics.py +29 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +2 -2
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +2 -2
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +2 -2
- helm/benchmark/metrics/dischargeme_metrics.py +14 -0
- helm/benchmark/metrics/disinformation_metrics.py +4 -4
- helm/benchmark/metrics/dry_run_metrics.py +5 -5
- helm/benchmark/metrics/efficiency_metrics.py +6 -6
- helm/benchmark/metrics/ehr_sql_metrics.py +103 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +3 -3
- helm/benchmark/metrics/evaluate_reference_metrics.py +144 -16
- helm/benchmark/metrics/gpqa_chain_of_thought_metric.py +103 -0
- helm/benchmark/metrics/gpt4_audio_critique_metrics.py +167 -0
- helm/benchmark/metrics/gpt4_audio_refusal_metrics.py +145 -0
- helm/benchmark/metrics/helpdesk_call_summarization_metrics.py +36 -0
- helm/benchmark/metrics/ifeval/__init__.py +0 -0
- helm/benchmark/metrics/ifeval/instructions.py +1574 -0
- helm/benchmark/metrics/ifeval/instructions_registry.py +182 -0
- helm/benchmark/metrics/ifeval/instructions_registry.pyi +3 -0
- helm/benchmark/metrics/ifeval/instructions_util.py +153 -0
- helm/benchmark/metrics/ifeval_metrics.py +55 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +1 -1
- helm/benchmark/metrics/image_generation/detection_metrics.py +1 -1
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +1 -1
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +1 -1
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +1 -1
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +1 -1
- helm/benchmark/metrics/image_generation/q16/test_q16.py +3 -1
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +1 -1
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +2 -2
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +1 -1
- helm/benchmark/metrics/image_generation/watermark_metrics.py +1 -1
- helm/benchmark/metrics/instruction_following_critique_metrics.py +4 -4
- helm/benchmark/metrics/kpi_edgar_metrics.py +121 -0
- helm/benchmark/metrics/language_modeling_metrics.py +4 -4
- helm/benchmark/metrics/llm_jury_metrics.py +46 -0
- helm/benchmark/metrics/machine_translation_metrics.py +2 -2
- helm/benchmark/metrics/med_dialog_metrics.py +14 -0
- helm/benchmark/metrics/medalign_metrics.py +14 -0
- helm/benchmark/metrics/medcalc_bench_metrics.py +124 -0
- helm/benchmark/metrics/medec_metrics.py +101 -0
- helm/benchmark/metrics/medi_qa_metrics.py +14 -0
- helm/benchmark/metrics/medication_qa_metrics.py +10 -19
- helm/benchmark/metrics/melt_bias_metric.py +234 -0
- helm/benchmark/metrics/melt_bias_word_lists.py +1367 -0
- helm/benchmark/metrics/melt_metric_specs.py +43 -0
- helm/benchmark/metrics/melt_toxicity_metric.py +107 -0
- helm/benchmark/metrics/mental_health_metrics.py +14 -0
- helm/benchmark/metrics/metric.py +3 -3
- helm/benchmark/metrics/metric_service.py +11 -11
- helm/benchmark/metrics/mimic_bhc_metrics.py +14 -0
- helm/benchmark/metrics/mimic_rrs_metrics.py +14 -0
- helm/benchmark/metrics/mimiciv_billing_code_metrics.py +96 -0
- helm/benchmark/metrics/mtsamples_procedures_metrics.py +14 -0
- helm/benchmark/metrics/mtsamples_replicate_metrics.py +14 -0
- helm/benchmark/metrics/nltk_helper.py +32 -0
- helm/benchmark/metrics/numeracy_metrics.py +4 -4
- helm/benchmark/metrics/omni_math_metrics.py +32 -0
- helm/benchmark/metrics/openai_mrcr_metrics.py +52 -0
- helm/benchmark/metrics/output_processing_metric.py +60 -0
- helm/benchmark/metrics/output_processors.py +15 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +2 -2
- helm/benchmark/metrics/ranking_metrics.py +3 -3
- helm/benchmark/metrics/reference_metric.py +3 -3
- helm/benchmark/metrics/ruler_qa_metrics.py +34 -0
- helm/benchmark/metrics/{bhasa_metrics.py → seahelm_metrics.py} +3 -3
- helm/benchmark/metrics/seahelm_metrics_specs.py +10 -0
- helm/benchmark/metrics/spider_metrics.py +7 -0
- helm/benchmark/metrics/starr_patient_instructions_metrics.py +14 -0
- helm/benchmark/metrics/statistic.py +1 -1
- helm/benchmark/metrics/summac/model_summac.py +2 -3
- helm/benchmark/metrics/summarization_critique_metrics.py +4 -4
- helm/benchmark/metrics/summarization_metrics.py +20 -9
- helm/benchmark/metrics/test_bias_metrics.py +5 -1
- helm/benchmark/metrics/test_classification_metrics.py +140 -68
- helm/benchmark/metrics/test_evaluate_reference_metrics.py +15 -0
- helm/benchmark/metrics/test_metric.py +1 -1
- helm/benchmark/metrics/test_statistic.py +2 -2
- helm/benchmark/metrics/tokens/ai21_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/auto_token_cost_estimator.py +6 -6
- helm/benchmark/metrics/tokens/cohere_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/free_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/openai_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +3 -3
- helm/benchmark/metrics/toxicity_metrics.py +6 -6
- helm/benchmark/metrics/unitxt_metrics.py +7 -5
- helm/benchmark/metrics/vision_language/emd_utils.py +4 -2
- helm/benchmark/metrics/vision_language/image_metrics.py +1 -1
- helm/benchmark/metrics/vision_language/image_utils.py +2 -2
- helm/benchmark/metrics/wildbench_metrics.py +34 -0
- helm/benchmark/model_deployment_registry.py +6 -8
- helm/benchmark/model_metadata_registry.py +16 -0
- helm/benchmark/presentation/contamination.py +3 -3
- helm/benchmark/presentation/create_plots.py +33 -12
- helm/benchmark/presentation/run_display.py +13 -0
- helm/benchmark/presentation/schema.py +2 -1
- helm/benchmark/presentation/summarize.py +97 -67
- helm/benchmark/presentation/torr_robustness_summarizer.py +178 -0
- helm/benchmark/reeval_run.py +202 -0
- helm/benchmark/reeval_runner.py +355 -0
- helm/benchmark/run.py +86 -90
- helm/benchmark/run_expander.py +90 -9
- helm/benchmark/run_spec_factory.py +13 -0
- helm/benchmark/run_specs/air_bench_run_specs.py +21 -3
- helm/benchmark/run_specs/audio_run_specs.py +657 -0
- helm/benchmark/run_specs/call_center_run_specs.py +49 -0
- helm/benchmark/run_specs/capabilities_run_specs.py +308 -0
- helm/benchmark/run_specs/classic_run_specs.py +1 -69
- helm/benchmark/run_specs/enem_challenge_specs.py +31 -0
- helm/benchmark/run_specs/enterprise_run_specs.py +280 -0
- helm/benchmark/run_specs/experimental_run_specs.py +142 -3
- helm/benchmark/run_specs/imdb_ptbr_run_specs.py +30 -0
- helm/benchmark/run_specs/lite_run_specs.py +2 -2
- helm/benchmark/run_specs/long_context_run_specs.py +141 -0
- helm/benchmark/run_specs/medhelm_run_specs.py +1260 -0
- helm/benchmark/run_specs/melt_run_specs.py +783 -0
- helm/benchmark/run_specs/mmlu_clinical_afr_run_specs.py +49 -0
- helm/benchmark/run_specs/oab_exams_specs.py +32 -0
- helm/benchmark/run_specs/safety_run_specs.py +37 -0
- helm/benchmark/run_specs/{bhasa_run_specs.py → seahelm_run_specs.py} +44 -44
- helm/benchmark/run_specs/speech_disorder_audio_run_specs.py +169 -0
- helm/benchmark/run_specs/sql_run_specs.py +54 -0
- helm/benchmark/run_specs/tweetsentbr_run_specs.py +32 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +14 -5
- helm/benchmark/run_specs/vlm_run_specs.py +103 -2
- helm/benchmark/run_specs/winogrande_afr_run_specs.py +47 -0
- helm/benchmark/runner.py +5 -5
- helm/benchmark/scenarios/aci_bench_scenario.py +126 -0
- helm/benchmark/scenarios/air_bench_scenario.py +6 -1
- helm/benchmark/scenarios/anthropic_hh_rlhf_scenario.py +5 -3
- helm/benchmark/scenarios/anthropic_red_team_scenario.py +1 -1
- helm/benchmark/scenarios/audio_language/__init__.py +0 -0
- helm/benchmark/scenarios/audio_language/air_bench_chat_scenario.py +130 -0
- helm/benchmark/scenarios/audio_language/air_bench_foundation_scenario.py +154 -0
- helm/benchmark/scenarios/audio_language/ami_scenario.py +96 -0
- helm/benchmark/scenarios/audio_language/audio_mnist_scenario.py +62 -0
- helm/benchmark/scenarios/audio_language/audio_pairs_scenario.py +62 -0
- helm/benchmark/scenarios/audio_language/audiocaps_scenario.py +59 -0
- helm/benchmark/scenarios/audio_language/casual_conversations2_scenario.py +152 -0
- helm/benchmark/scenarios/audio_language/common_voice_15_scenario.py +99 -0
- helm/benchmark/scenarios/audio_language/corebench_scenario.py +77 -0
- helm/benchmark/scenarios/audio_language/covost2_scenario.py +163 -0
- helm/benchmark/scenarios/audio_language/fleurs_fairness_scenario.py +83 -0
- helm/benchmark/scenarios/audio_language/fleurs_scenario.py +312 -0
- helm/benchmark/scenarios/audio_language/iemocap_audio_scenario.py +83 -0
- helm/benchmark/scenarios/audio_language/librispeech_fairness_scenario.py +96 -0
- helm/benchmark/scenarios/audio_language/librispeech_scenario.py +80 -0
- helm/benchmark/scenarios/audio_language/meld_audio_scenario.py +113 -0
- helm/benchmark/scenarios/audio_language/multilingual_librispeech_scenario.py +80 -0
- helm/benchmark/scenarios/audio_language/mustard_scenario.py +142 -0
- helm/benchmark/scenarios/audio_language/mutox_scenario.py +254 -0
- helm/benchmark/scenarios/audio_language/parade_scenario.py +97 -0
- helm/benchmark/scenarios/audio_language/speech_robust_bench_scenario.py +124 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_asr_classification.py +103 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_classification_scenario.py +110 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_disorder_breakdown_scenario.py +78 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_disorder_symptoms_scenario.py +109 -0
- helm/benchmark/scenarios/audio_language/vocal_sound_scenario.py +83 -0
- helm/benchmark/scenarios/audio_language/voice_jailbreak_attacks_scenario.py +87 -0
- helm/benchmark/scenarios/audio_language/voxceleb2_scenario.py +105 -0
- helm/benchmark/scenarios/autobencher_capabilities_scenario.py +68 -0
- helm/benchmark/scenarios/autobencher_safety_scenario.py +51 -0
- helm/benchmark/scenarios/babi_qa_scenario.py +1 -1
- helm/benchmark/scenarios/banking77_scenario.py +6 -1
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/big_bench_scenario.py +11 -1
- helm/benchmark/scenarios/bigcodebench_scenario.py +58 -0
- helm/benchmark/scenarios/bird_sql_scenario.py +94 -0
- helm/benchmark/scenarios/bird_sql_scenario_helper.py +118 -0
- helm/benchmark/scenarios/blimp_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +1 -1
- helm/benchmark/scenarios/boolq_scenario.py +1 -1
- helm/benchmark/scenarios/casehold_scenario.py +79 -0
- helm/benchmark/scenarios/chw_care_plan_scenario.py +106 -0
- helm/benchmark/scenarios/civil_comments_scenario.py +1 -1
- helm/benchmark/scenarios/clear_scenario.py +157 -0
- helm/benchmark/scenarios/cleva_scenario.py +2 -2
- helm/benchmark/scenarios/code_scenario.py +17 -4
- helm/benchmark/scenarios/commonsense_scenario.py +1 -1
- helm/benchmark/scenarios/conv_fin_qa_calc_scenario.py +97 -0
- helm/benchmark/scenarios/copyright_scenario.py +1 -1
- helm/benchmark/scenarios/covid_dialog_scenario.py +10 -1
- helm/benchmark/scenarios/cti_to_mitre_scenario.py +240 -0
- helm/benchmark/scenarios/custom_mcqa_scenario.py +1 -1
- helm/benchmark/scenarios/czech_bank_qa_scenario.py +130 -0
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +1 -1
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +1 -1
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +1 -1
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +1 -1
- helm/benchmark/scenarios/dialogue_scenarios.py +13 -2
- helm/benchmark/scenarios/dischargeme_scenario.py +172 -0
- helm/benchmark/scenarios/disinformation_scenario.py +10 -1
- helm/benchmark/scenarios/dyck_language_scenario.py +10 -1
- helm/benchmark/scenarios/echr_judgment_classification_scenario.py +113 -0
- helm/benchmark/scenarios/ehr_sql_scenario.py +137 -0
- helm/benchmark/scenarios/ehrshot_scenario.py +1519 -0
- helm/benchmark/scenarios/enem_challenge_scenario.py +58 -0
- helm/benchmark/scenarios/entity_data_imputation_scenario.py +11 -1
- helm/benchmark/scenarios/entity_matching_scenario.py +12 -2
- helm/benchmark/scenarios/financial_phrasebank_scenario.py +94 -0
- helm/benchmark/scenarios/gold_commodity_news_scenario.py +124 -0
- helm/benchmark/scenarios/gpqa_scenario.py +80 -0
- helm/benchmark/scenarios/grammar.py +2 -2
- helm/benchmark/scenarios/grammar_scenario.py +2 -2
- helm/benchmark/scenarios/gsm_scenario.py +10 -1
- helm/benchmark/scenarios/harm_bench_gcg_transfer_scenario.py +50 -0
- helm/benchmark/scenarios/harm_bench_scenario.py +1 -1
- helm/benchmark/scenarios/headqa_scenario.py +136 -0
- helm/benchmark/scenarios/helpdesk_call_summarization_scenario.py +37 -0
- helm/benchmark/scenarios/ice_scenario.py +8 -4
- helm/benchmark/scenarios/ifeval_scenario.py +53 -0
- helm/benchmark/scenarios/imdb_ptbr_scenario.py +60 -0
- helm/benchmark/scenarios/imdb_scenario.py +11 -2
- helm/benchmark/scenarios/infinite_bench_en_qa_scenario.py +85 -0
- helm/benchmark/scenarios/infinite_bench_en_sum_scenario.py +79 -0
- helm/benchmark/scenarios/interactive_qa_mmlu_scenario.py +2 -2
- helm/benchmark/scenarios/koala_scenario.py +1 -1
- helm/benchmark/scenarios/kpi_edgar_scenario.py +151 -0
- helm/benchmark/scenarios/legal_contract_summarization_scenario.py +129 -0
- helm/benchmark/scenarios/legal_opinion_sentiment_classification_scenario.py +77 -0
- helm/benchmark/scenarios/legal_summarization_scenario.py +11 -1
- helm/benchmark/scenarios/legal_support_scenario.py +11 -1
- helm/benchmark/scenarios/legalbench_scenario.py +22 -3
- helm/benchmark/scenarios/lex_glue_scenario.py +12 -2
- helm/benchmark/scenarios/lextreme_scenario.py +11 -1
- helm/benchmark/scenarios/live_qa_scenario.py +1 -1
- helm/benchmark/scenarios/lm_entry_scenario.py +1 -1
- helm/benchmark/scenarios/lsat_qa_scenario.py +1 -1
- helm/benchmark/scenarios/math_scenario.py +9 -1
- helm/benchmark/scenarios/me_q_sum_scenario.py +10 -1
- helm/benchmark/scenarios/med_dialog_scenario.py +25 -22
- helm/benchmark/scenarios/med_mcqa_scenario.py +10 -1
- helm/benchmark/scenarios/med_paragraph_simplification_scenario.py +10 -1
- helm/benchmark/scenarios/med_qa_scenario.py +10 -1
- helm/benchmark/scenarios/medalign_scenario.py +94 -0
- helm/benchmark/scenarios/medalign_scenario_helper.py +432 -0
- helm/benchmark/scenarios/medbullets_scenario.py +145 -0
- helm/benchmark/scenarios/medcalc_bench_scenario.py +127 -0
- helm/benchmark/scenarios/medec_scenario.py +125 -0
- helm/benchmark/scenarios/medhallu_scenario.py +72 -0
- helm/benchmark/scenarios/medi_qa_scenario.py +111 -0
- helm/benchmark/scenarios/medication_qa_scenario.py +8 -2
- helm/benchmark/scenarios/melt_ir_scenario.py +171 -0
- helm/benchmark/scenarios/melt_knowledge_scenario.py +246 -0
- helm/benchmark/scenarios/melt_lm_scenarios.py +252 -0
- helm/benchmark/scenarios/melt_scenarios.py +793 -0
- helm/benchmark/scenarios/melt_srn_scenario.py +342 -0
- helm/benchmark/scenarios/melt_synthetic_reasoning_scenario.py +222 -0
- helm/benchmark/scenarios/melt_translation_scenario.py +152 -0
- helm/benchmark/scenarios/mental_health_scenario.py +123 -0
- helm/benchmark/scenarios/mimic_bhc_scenario.py +103 -0
- helm/benchmark/scenarios/mimic_rrs_scenario.py +98 -0
- helm/benchmark/scenarios/mimiciv_billing_code_scenario.py +77 -0
- helm/benchmark/scenarios/mmlu_clinical_afr_scenario.py +74 -0
- helm/benchmark/scenarios/mmlu_pro_scenario.py +95 -0
- helm/benchmark/scenarios/mmlu_scenario.py +11 -1
- helm/benchmark/scenarios/msmarco_scenario.py +1 -1
- helm/benchmark/scenarios/mtsamples_procedures_scenario.py +144 -0
- helm/benchmark/scenarios/mtsamples_replicate_scenario.py +142 -0
- helm/benchmark/scenarios/n2c2_ct_matching_scenario.py +277 -0
- helm/benchmark/scenarios/narrativeqa_scenario.py +1 -1
- helm/benchmark/scenarios/natural_qa_scenario.py +1 -1
- helm/benchmark/scenarios/newsqa_scenario.py +1 -1
- helm/benchmark/scenarios/numeracy_scenario.py +12 -2
- helm/benchmark/scenarios/oab_exams_scenario.py +57 -0
- helm/benchmark/scenarios/omni_math_scenario.py +53 -0
- helm/benchmark/scenarios/open_assistant_scenario.py +11 -2
- helm/benchmark/scenarios/openai_mrcr_scenario.py +79 -0
- helm/benchmark/scenarios/opinions_qa_scenario.py +1 -1
- helm/benchmark/scenarios/pubmed_qa_scenario.py +59 -43
- helm/benchmark/scenarios/quac_scenario.py +10 -1
- helm/benchmark/scenarios/race_based_med_scenario.py +152 -0
- helm/benchmark/scenarios/raft_scenario.py +17 -2
- helm/benchmark/scenarios/real_toxicity_prompts_scenario.py +1 -1
- helm/benchmark/scenarios/ruler_qa_scenario_helper.py +171 -0
- helm/benchmark/scenarios/ruler_qa_scenarios.py +88 -0
- helm/benchmark/scenarios/scenario.py +9 -1
- helm/benchmark/scenarios/{bhasa_scenario.py → seahelm_scenario.py} +7 -2
- helm/benchmark/scenarios/self_instruct_scenario.py +1 -1
- helm/benchmark/scenarios/shc_bmt_scenario.py +75 -0
- helm/benchmark/scenarios/shc_cdi_scenario.py +75 -0
- helm/benchmark/scenarios/shc_conf_scenario.py +76 -0
- helm/benchmark/scenarios/shc_ent_scenario.py +77 -0
- helm/benchmark/scenarios/shc_gip_scenario.py +74 -0
- helm/benchmark/scenarios/shc_privacy_scenario.py +78 -0
- helm/benchmark/scenarios/shc_proxy_scenario.py +76 -0
- helm/benchmark/scenarios/shc_ptbm_scenario.py +81 -0
- helm/benchmark/scenarios/shc_sei_scenario.py +94 -0
- helm/benchmark/scenarios/shc_sequoia_scenario.py +77 -0
- helm/benchmark/scenarios/simple_safety_tests_scenario.py +1 -1
- helm/benchmark/scenarios/spider_scenario.py +91 -0
- helm/benchmark/scenarios/starr_patient_instructions_scenario.py +97 -0
- helm/benchmark/scenarios/summarization_scenario.py +11 -1
- helm/benchmark/scenarios/sumosum_scenario.py +157 -0
- helm/benchmark/scenarios/synthetic_efficiency_scenario.py +1 -1
- helm/benchmark/scenarios/synthetic_reasoning_natural_scenario.py +11 -1
- helm/benchmark/scenarios/synthetic_reasoning_scenario.py +11 -1
- helm/benchmark/scenarios/test_bigcodebench_scenario.py +26 -0
- helm/benchmark/scenarios/test_czech_bank_qa_scenario.py +18 -0
- helm/benchmark/scenarios/test_enem_challenge_scenario.py +53 -0
- helm/benchmark/scenarios/test_ewok_scenario.py +6 -2
- helm/benchmark/scenarios/test_gold_commodity_news_scenario.py +18 -0
- helm/benchmark/scenarios/test_gpqa_scenario.py +44 -0
- helm/benchmark/scenarios/test_ifeval_scenario.py +36 -0
- helm/benchmark/scenarios/test_imdb_ptbr_scenario.py +27 -0
- helm/benchmark/scenarios/test_infinite_bench_en_qa_scenario.py +18 -0
- helm/benchmark/scenarios/test_infinite_bench_en_sum_scenario.py +31 -0
- helm/benchmark/scenarios/test_math_scenario.py +1 -0
- helm/benchmark/scenarios/test_mmlu_clinical_afr_scenario.py +21 -0
- helm/benchmark/scenarios/test_mmlu_pro_scenario.py +53 -0
- helm/benchmark/scenarios/test_oab_exams_scenario.py +51 -0
- helm/benchmark/scenarios/test_omni_math_scenario.py +27 -0
- helm/benchmark/scenarios/test_tweetsentbr_scenario.py +24 -0
- helm/benchmark/scenarios/test_wildbench_scenario.py +15 -0
- helm/benchmark/scenarios/test_winogrande_afr_scenario.py +19 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +10 -1
- helm/benchmark/scenarios/the_pile_scenario.py +1 -1
- helm/benchmark/scenarios/truthful_qa_scenario.py +12 -2
- helm/benchmark/scenarios/tweetsentbr_scenario.py +66 -0
- helm/benchmark/scenarios/twitter_aae_scenario.py +1 -1
- helm/benchmark/scenarios/unitxt_scenario.py +8 -2
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +1 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/blink_scenario.py +140 -0
- helm/benchmark/scenarios/vision_language/mm_star_scenario.py +95 -0
- helm/benchmark/scenarios/vision_language/msr_vtt_scenario.py +75 -0
- helm/benchmark/scenarios/vision_language/vqa_rad_scenario.py +88 -0
- helm/benchmark/scenarios/wikifact_scenario.py +11 -1
- helm/benchmark/scenarios/wikitext_103_scenario.py +1 -1
- helm/benchmark/scenarios/wildbench_scenario.py +83 -0
- helm/benchmark/scenarios/winogrande_afr_scenario.py +78 -0
- helm/benchmark/scenarios/wmt_14_scenario.py +14 -2
- helm/benchmark/scenarios/xstest_scenario.py +1 -1
- helm/benchmark/server.py +13 -1
- helm/benchmark/slurm_runner.py +1 -1
- helm/benchmark/static/schema_audio.yaml +763 -0
- helm/benchmark/static/schema_autobencher.yaml +150 -0
- helm/benchmark/static/schema_call_center.yaml +97 -60
- helm/benchmark/static/{schema_medical.yaml → schema_capabilities.yaml} +100 -101
- helm/benchmark/static/schema_czech_bank.yaml +148 -0
- helm/benchmark/static/schema_enem_challenge.yaml +146 -0
- helm/benchmark/static/schema_enterprise.yaml +319 -0
- helm/benchmark/static/schema_finance.yaml +14 -12
- helm/benchmark/static/schema_heim.yaml +1389 -0
- helm/benchmark/static/schema_long_context.yaml +283 -0
- helm/benchmark/static/schema_medhelm.yaml +1140 -0
- helm/benchmark/static/schema_melt.yaml +1257 -0
- helm/benchmark/static/schema_mmlu_winogrande_afr.yaml +1045 -0
- helm/benchmark/static/schema_safety.yaml +18 -1
- helm/benchmark/static/{schema_bhasa.yaml → schema_seahelm.yaml} +30 -16
- helm/benchmark/static/schema_slphelm.yaml +162 -0
- helm/benchmark/static/schema_social_audio.yaml +224 -0
- helm/benchmark/static/schema_sql.yaml +171 -0
- helm/benchmark/static/{schema_tables.yaml → schema_torr.yaml} +169 -36
- helm/benchmark/static/schema_tweetsentbr.yaml +146 -0
- helm/benchmark/static/schema_vhelm.yaml +129 -56
- helm/benchmark/static/schema_video.yaml +219 -0
- helm/benchmark/static_build/assets/helm-safety-2907a7b6.png +0 -0
- helm/benchmark/static_build/assets/index-94295e78.js +10 -0
- helm/benchmark/static_build/assets/index-b9779128.css +1 -0
- helm/benchmark/static_build/assets/medhelm-overview-eac29843.png +0 -0
- helm/benchmark/static_build/assets/medhelm-v1-overview-3ddfcd65.png +0 -0
- helm/benchmark/static_build/assets/{react-d4a0b69b.js → react-f82877fd.js} +1 -1
- helm/benchmark/static_build/assets/{recharts-6d337683.js → recharts-4037aff0.js} +1 -1
- helm/benchmark/static_build/assets/{tremor-54a99cc4.js → tremor-38a10867.js} +2 -2
- helm/benchmark/static_build/config.js +1 -1
- helm/benchmark/static_build/index.html +6 -6
- helm/benchmark/window_services/default_window_service.py +1 -1
- helm/benchmark/window_services/encoder_decoder_window_service.py +4 -4
- helm/benchmark/window_services/ice_window_service.py +1 -1
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +1 -1
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +1 -1
- helm/benchmark/window_services/local_window_service.py +2 -2
- helm/benchmark/window_services/test_anthropic_window_service.py +3 -3
- helm/benchmark/window_services/test_bloom_window_service.py +3 -3
- helm/benchmark/window_services/test_gpt2_window_service.py +7 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +8 -3
- helm/benchmark/window_services/test_gptj_window_service.py +8 -3
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -3
- helm/benchmark/window_services/test_openai_window_service.py +8 -3
- helm/benchmark/window_services/test_opt_window_service.py +3 -3
- helm/benchmark/window_services/test_palmyra_window_service.py +3 -3
- helm/benchmark/window_services/test_t0pp_window_service.py +3 -3
- helm/benchmark/window_services/test_t511b_window_service.py +3 -3
- helm/benchmark/window_services/test_ul2_window_service.py +3 -3
- helm/benchmark/window_services/test_utils.py +4 -5
- helm/benchmark/window_services/test_yalm_window_service.py +3 -3
- helm/benchmark/window_services/tokenizer_service.py +7 -8
- helm/benchmark/window_services/yalm_window_service.py +1 -1
- helm/clients/ai21_client.py +3 -3
- helm/clients/aleph_alpha_client.py +1 -1
- helm/clients/anthropic_client.py +69 -29
- helm/clients/audio_language/__init__.py +0 -0
- helm/clients/audio_language/diva_llama_client.py +120 -0
- helm/clients/audio_language/llama_omni_client.py +198 -0
- helm/clients/audio_language/qwen2_5_omni_client.py +197 -0
- helm/clients/audio_language/qwen2_audiolm_client.py +190 -0
- helm/clients/audio_language/qwen_audiolm_client.py +152 -0
- helm/clients/audio_language/test.py +62 -0
- helm/clients/auto_client.py +4 -2
- helm/clients/azure_openai_client.py +55 -0
- helm/clients/bedrock_client.py +203 -7
- helm/clients/bedrock_utils.py +33 -0
- helm/clients/client.py +7 -7
- helm/clients/clip_scorers/clip_scorer.py +1 -1
- helm/clients/clip_scorers/multilingual_clip_scorer.py +1 -1
- helm/clients/cohere_client.py +3 -3
- helm/clients/google_client.py +1 -1
- helm/clients/grok_client.py +36 -0
- helm/clients/http_model_client.py +1 -1
- helm/clients/huggingface_client.py +52 -21
- helm/clients/huggingface_pipeline_client.py +138 -0
- helm/clients/ibm_client.py +267 -0
- helm/clients/image_generation/adobe_vision_client.py +1 -1
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +1 -1
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +3 -3
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +5 -2
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +5 -2
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +2 -2
- helm/clients/image_generation/cogview2_client.py +1 -1
- helm/clients/image_generation/dalle2_client.py +1 -1
- helm/clients/image_generation/dalle3_client.py +2 -2
- helm/clients/image_generation/dalle_mini/__init__.py +1 -1
- helm/clients/image_generation/dalle_mini/data.py +1 -1
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -5
- helm/clients/image_generation/dalle_mini/model/configuration.py +2 -2
- helm/clients/image_generation/dalle_mini/model/modeling.py +3 -3
- helm/clients/image_generation/dalle_mini/model/processor.py +5 -5
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +2 -2
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -1
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +2 -2
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +1 -1
- helm/clients/image_generation/dalle_mini_client.py +1 -1
- helm/clients/image_generation/deep_floyd_client.py +1 -1
- helm/clients/image_generation/huggingface_diffusers_client.py +1 -1
- helm/clients/image_generation/lexica_client.py +1 -1
- helm/clients/image_generation/mindalle/models/__init__.py +6 -6
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +1 -1
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +1 -1
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -3
- helm/clients/image_generation/mindalle_client.py +1 -1
- helm/clients/image_generation/together_image_generation_client.py +1 -1
- helm/clients/lit_gpt_client.py +2 -2
- helm/clients/mistral_client.py +62 -18
- helm/clients/nvidia_nim_client.py +0 -3
- helm/clients/openai_client.py +308 -43
- helm/clients/openai_responses_client.py +174 -0
- helm/clients/palmyra_client.py +3 -9
- helm/clients/reka_client.py +3 -3
- helm/clients/stanfordhealthcare_azure_openai_client.py +58 -0
- helm/clients/stanfordhealthcare_claude_client.py +31 -0
- helm/clients/stanfordhealthcare_google_client.py +43 -0
- helm/clients/stanfordhealthcare_http_model_client.py +93 -0
- helm/clients/stanfordhealthcare_openai_client.py +62 -0
- helm/clients/stanfordhealthcare_shc_openai_client.py +42 -0
- helm/clients/test_client.py +1 -1
- helm/clients/test_together_client.py +6 -1
- helm/clients/together_client.py +76 -9
- helm/clients/upstage_client.py +23 -0
- helm/clients/vertexai_client.py +45 -13
- helm/clients/vision_language/huggingface_vision2seq_client.py +6 -4
- helm/clients/vision_language/huggingface_vlm_client.py +2 -2
- helm/clients/vision_language/idefics_client.py +6 -2
- helm/clients/vision_language/open_flamingo/__init__.py +2 -2
- helm/clients/vision_language/open_flamingo/src/factory.py +3 -3
- helm/clients/vision_language/open_flamingo/src/flamingo.py +2 -2
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +2 -2
- helm/clients/vision_language/paligemma_client.py +2 -2
- helm/clients/vision_language/qwen2_vlm_client.py +188 -0
- helm/clients/vision_language/qwen_vlm_client.py +7 -5
- helm/clients/vllm_client.py +4 -6
- helm/clients/writer_client.py +102 -0
- helm/clients/yi_client.py +0 -3
- helm/common/audio_utils.py +111 -0
- helm/common/context.py +80 -0
- helm/common/credentials_utils.py +5 -5
- helm/common/file_caches/local_file_cache.py +1 -1
- helm/common/file_caches/test_local_file_cache.py +1 -1
- helm/common/general.py +9 -2
- helm/common/hierarchical_logger.py +46 -3
- helm/common/images_utils.py +2 -2
- helm/common/local_context.py +140 -0
- helm/common/media_object.py +2 -2
- helm/common/multimodal_request_utils.py +26 -0
- helm/common/reeval_parameters.py +12 -0
- helm/common/remote_context.py +61 -0
- helm/common/request.py +14 -2
- helm/common/response_format.py +18 -0
- helm/common/test_media_object.py +1 -1
- helm/config/model_deployments.yaml +1792 -28
- helm/config/model_metadata.yaml +1606 -51
- helm/config/tokenizer_configs.yaml +521 -4
- helm/proxy/cli.py +5 -3
- helm/proxy/critique/mechanical_turk_utils.py +1 -1
- helm/proxy/example_queries.py +1 -1
- helm/proxy/server.py +11 -4
- helm/proxy/services/remote_service.py +1 -1
- helm/proxy/services/server_service.py +22 -86
- helm/proxy/services/test_remote_service.py +2 -2
- helm/proxy/services/test_service.py +1 -1
- helm/proxy/static/general.js +122 -0
- helm/proxy/static/help.html +99 -0
- helm/proxy/static/index.css +57 -0
- helm/proxy/static/index.html +40 -0
- helm/proxy/static/index.js +456 -0
- helm/proxy/static/info-icon.png +0 -0
- helm/proxy/test_retry.py +1 -1
- helm/proxy/token_counters/auto_token_counter.py +1 -1
- helm/tokenizers/aleph_alpha_tokenizer.py +1 -1
- helm/tokenizers/caching_tokenizer.py +2 -30
- helm/tokenizers/grok_tokenizer.py +53 -0
- helm/tokenizers/http_model_tokenizer.py +1 -1
- helm/tokenizers/huggingface_tokenizer.py +3 -3
- helm/tokenizers/lit_gpt_tokenizer.py +1 -1
- helm/tokenizers/test_anthropic_tokenizer.py +6 -2
- helm/tokenizers/test_grok_tokenizer.py +33 -0
- helm/tokenizers/test_huggingface_tokenizer.py +1 -1
- helm/tokenizers/test_yalm_tokenizer.py +1 -1
- helm/tokenizers/tiktoken_tokenizer.py +1 -1
- helm/tokenizers/tokenizer.py +3 -1
- helm/tokenizers/yalm_tokenizer.py +3 -3
- helm/tokenizers/yalm_tokenizer_data/test_yalm_tokenizer.py +1 -1
- crfm_helm-0.5.4.dist-info/METADATA +0 -350
- crfm_helm-0.5.4.dist-info/RECORD +0 -697
- helm/benchmark/metrics/bhasa_metrics_specs.py +0 -10
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/accenture-6f97eeda.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aisingapore-6dfc9acf.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/cresta-9e22b983.png +0 -0
- helm/benchmark/static_build/assets/cuhk-8c5631e9.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/index-05c76bb1.css +0 -1
- helm/benchmark/static_build/assets/index-3ee38b3d.js +0 -10
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/scb10x-204bd786.png +0 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/wellsfargo-a86a6c4a.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/tokenizers/anthropic_tokenizer.py +0 -52
- {crfm_helm-0.5.4.dist-info → crfm_helm-0.5.6.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.5.4.dist-info → crfm_helm-0.5.6.dist-info/licenses}/LICENSE +0 -0
- {crfm_helm-0.5.4.dist-info → crfm_helm-0.5.6.dist-info}/top_level.txt +0 -0
helm/clients/bedrock_client.py
CHANGED
|
@@ -2,12 +2,13 @@ from abc import abstractmethod
|
|
|
2
2
|
from copy import deepcopy
|
|
3
3
|
import json
|
|
4
4
|
import os
|
|
5
|
-
from typing import Any, Dict, List, Mapping, Optional
|
|
5
|
+
from typing import Any, Dict, List, Mapping, Optional, TypedDict
|
|
6
|
+
from datetime import datetime
|
|
6
7
|
|
|
7
8
|
from helm.common.cache import CacheConfig
|
|
8
9
|
from helm.clients.client import CachingClient, truncate_and_tokenize_response_text
|
|
9
10
|
from helm.common.request import Request, RequestResult, GeneratedOutput, wrap_request_time
|
|
10
|
-
from helm.clients.bedrock_utils import get_bedrock_client
|
|
11
|
+
from helm.clients.bedrock_utils import get_bedrock_client, get_bedrock_client_v1
|
|
11
12
|
from helm.tokenizers.tokenizer import Tokenizer
|
|
12
13
|
|
|
13
14
|
|
|
@@ -23,27 +24,41 @@ class BedrockClient(CachingClient):
|
|
|
23
24
|
def convert_raw_response_to_completions(self, response: Dict, request: Request) -> List[GeneratedOutput]:
|
|
24
25
|
raise NotImplementedError()
|
|
25
26
|
|
|
27
|
+
"""
|
|
28
|
+
Amazon Bedrock is a fully managed service that provides s selection of leading foundation models (FMs) from Amazon
|
|
29
|
+
and other partner model providers.
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
@property
|
|
33
|
+
@abstractmethod
|
|
34
|
+
def model_provider(self) -> str:
|
|
35
|
+
raise NotImplementedError()
|
|
36
|
+
|
|
26
37
|
def __init__(
|
|
27
38
|
self,
|
|
28
39
|
cache_config: CacheConfig,
|
|
29
40
|
tokenizer: Tokenizer,
|
|
30
41
|
tokenizer_name: str,
|
|
31
|
-
bedrock_model_id: Optional[str] = None,
|
|
32
42
|
assumed_role: Optional[str] = None,
|
|
33
43
|
region: Optional[str] = None,
|
|
34
44
|
):
|
|
35
45
|
super().__init__(cache_config=cache_config)
|
|
36
46
|
self.tokenizer = tokenizer
|
|
37
47
|
self.tokenizer_name = tokenizer_name
|
|
38
|
-
self.bedrock_model_id = bedrock_model_id
|
|
39
48
|
self.bedrock_client = get_bedrock_client(
|
|
40
49
|
assumed_role=assumed_role or os.environ.get("BEDROCK_ASSUME_ROLE", None),
|
|
41
|
-
region=region
|
|
50
|
+
region=region,
|
|
42
51
|
)
|
|
43
52
|
|
|
44
53
|
def make_request(self, request: Request) -> RequestResult:
|
|
45
|
-
# model_id should be something like "amazon.titan-tg1-large"
|
|
46
|
-
|
|
54
|
+
# model_id should be something like "amazon.titan-tg1-large", replace amazon- prefix with model creator name
|
|
55
|
+
model_name = request.model.split("/")[-1]
|
|
56
|
+
# check if model_name starts with "amazon-"
|
|
57
|
+
if self.model_provider == "amazon":
|
|
58
|
+
model_id = f"{self.model_provider}.{model_name}"
|
|
59
|
+
else:
|
|
60
|
+
model_id = model_name.replace("amazon-", f"{self.model_provider}.")
|
|
61
|
+
|
|
47
62
|
raw_request = self.convert_request_to_raw_request(request)
|
|
48
63
|
|
|
49
64
|
# modelId isn't part of raw_request, so it must be explicitly passed into the input to
|
|
@@ -58,6 +73,7 @@ class BedrockClient(CachingClient):
|
|
|
58
73
|
|
|
59
74
|
try:
|
|
60
75
|
response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
|
|
76
|
+
|
|
61
77
|
except Exception as error:
|
|
62
78
|
return RequestResult(
|
|
63
79
|
success=False,
|
|
@@ -79,12 +95,111 @@ class BedrockClient(CachingClient):
|
|
|
79
95
|
)
|
|
80
96
|
|
|
81
97
|
|
|
98
|
+
class _ContentBlock(TypedDict):
|
|
99
|
+
text: str
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class _Message(TypedDict):
|
|
103
|
+
role: str
|
|
104
|
+
content: List[_ContentBlock]
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class BedrockNovaClient(CachingClient):
|
|
108
|
+
"""
|
|
109
|
+
Amazon Bedrock is a fully managed service that provides s selection of leading foundation models (FMs) from Amazon
|
|
110
|
+
and other partner model providers.
|
|
111
|
+
"""
|
|
112
|
+
|
|
113
|
+
def __init__(
|
|
114
|
+
self,
|
|
115
|
+
cache_config: CacheConfig,
|
|
116
|
+
tokenizer: Tokenizer,
|
|
117
|
+
tokenizer_name: str,
|
|
118
|
+
assumed_role: Optional[str] = None,
|
|
119
|
+
region: Optional[str] = None,
|
|
120
|
+
bedrock_model_id: Optional[str] = None,
|
|
121
|
+
):
|
|
122
|
+
super().__init__(cache_config=cache_config)
|
|
123
|
+
self.tokenizer = tokenizer
|
|
124
|
+
self.tokenizer_name = tokenizer_name
|
|
125
|
+
self.bedrock_model_id = bedrock_model_id
|
|
126
|
+
self.bedrock_client = get_bedrock_client_v1(
|
|
127
|
+
assumed_role=assumed_role or os.environ.get("BEDROCK_ASSUME_ROLE", None),
|
|
128
|
+
region=region,
|
|
129
|
+
)
|
|
130
|
+
|
|
131
|
+
def _get_messages_from_request(self, request: Request) -> List[_Message]:
|
|
132
|
+
if request.prompt and request.messages:
|
|
133
|
+
raise ValueError(f"Only one of `prompt` and `messages` may be set in request: {request}")
|
|
134
|
+
if request.multimodal_prompt:
|
|
135
|
+
raise ValueError(f"`multimodal_prompt` is not supported in request: {request}")
|
|
136
|
+
|
|
137
|
+
if request.messages:
|
|
138
|
+
return [
|
|
139
|
+
{"role": message["role"], "content": [{"text": message["content"]}]} for message in request.messages
|
|
140
|
+
]
|
|
141
|
+
else:
|
|
142
|
+
return [{"role": "user", "content": [{"text": request.prompt}]}]
|
|
143
|
+
|
|
144
|
+
def convert_request_to_raw_request(self, request: Request) -> Dict:
|
|
145
|
+
model_id = request.model.replace("/", ".")
|
|
146
|
+
messages = self._get_messages_from_request(request)
|
|
147
|
+
|
|
148
|
+
return {
|
|
149
|
+
"modelId": self.bedrock_model_id or model_id,
|
|
150
|
+
"inferenceConfig": {
|
|
151
|
+
"temperature": request.temperature,
|
|
152
|
+
"maxTokens": request.max_tokens,
|
|
153
|
+
"topP": request.top_p,
|
|
154
|
+
},
|
|
155
|
+
"messages": messages,
|
|
156
|
+
}
|
|
157
|
+
|
|
158
|
+
def make_request(self, request: Request) -> RequestResult:
|
|
159
|
+
raw_request = self.convert_request_to_raw_request(request)
|
|
160
|
+
cache_key = CachingClient.make_cache_key(raw_request, request)
|
|
161
|
+
|
|
162
|
+
def do_it() -> Dict[Any, Any]:
|
|
163
|
+
return self.bedrock_client.converse(**raw_request)
|
|
164
|
+
|
|
165
|
+
response, cached = self.cache.get(cache_key, do_it)
|
|
166
|
+
|
|
167
|
+
completions = self.convert_raw_response_to_completions(response, request)
|
|
168
|
+
dt = datetime.strptime(response["ResponseMetadata"]["HTTPHeaders"]["date"], "%a, %d %b %Y %H:%M:%S GMT")
|
|
169
|
+
# Use API reported latency rather than client measured latency
|
|
170
|
+
request_time = response["metrics"]["latencyMs"] / 1000
|
|
171
|
+
|
|
172
|
+
return RequestResult(
|
|
173
|
+
success=True,
|
|
174
|
+
cached=cached,
|
|
175
|
+
request_time=request_time,
|
|
176
|
+
request_datetime=int(dt.timestamp()),
|
|
177
|
+
completions=completions,
|
|
178
|
+
embedding=[],
|
|
179
|
+
)
|
|
180
|
+
|
|
181
|
+
def convert_raw_response_to_completions(self, response: Dict, request: Request) -> List[GeneratedOutput]:
|
|
182
|
+
completions: List[GeneratedOutput] = []
|
|
183
|
+
raw_completion = response["output"]
|
|
184
|
+
output_text = raw_completion["message"]["content"][0]["text"]
|
|
185
|
+
finish_reason = response["stopReason"]
|
|
186
|
+
completion = truncate_and_tokenize_response_text(
|
|
187
|
+
output_text.lstrip(), request, self.tokenizer, self.tokenizer_name, finish_reason
|
|
188
|
+
)
|
|
189
|
+
completions.append(completion)
|
|
190
|
+
return completions
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
# Amazon Bedrock Client for Titan Models
|
|
82
194
|
class BedrockTitanClient(BedrockClient):
|
|
83
195
|
_COMPLETION_REASON_TO_FINISH_REASON = {
|
|
84
196
|
"LENGTH": "length",
|
|
85
197
|
"FINISH": "endoftext",
|
|
86
198
|
}
|
|
87
199
|
|
|
200
|
+
# creator org for titan
|
|
201
|
+
model_provider = "amazon"
|
|
202
|
+
|
|
88
203
|
def convert_request_to_raw_request(self, request: Request) -> Dict:
|
|
89
204
|
# TODO: Support the following:
|
|
90
205
|
# - top_k_per_token
|
|
@@ -115,6 +230,7 @@ class BedrockTitanClient(BedrockClient):
|
|
|
115
230
|
# - tokens
|
|
116
231
|
# - logprob
|
|
117
232
|
completions: List[GeneratedOutput] = []
|
|
233
|
+
|
|
118
234
|
for raw_completion in response["results"]:
|
|
119
235
|
output_text = raw_completion["outputText"]
|
|
120
236
|
# Call lstrip() Titan has the tendency to emit "\n" as the first token in the generated text output.
|
|
@@ -126,3 +242,83 @@ class BedrockTitanClient(BedrockClient):
|
|
|
126
242
|
)
|
|
127
243
|
completions.append(completion)
|
|
128
244
|
return completions
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
# Amazon Bedrock Client for Mistral Models
|
|
248
|
+
class BedrockMistralClient(BedrockClient):
|
|
249
|
+
_COMPLETION_REASON_TO_FINISH_REASON = {
|
|
250
|
+
"length": "length",
|
|
251
|
+
"stop": "endoftext",
|
|
252
|
+
}
|
|
253
|
+
|
|
254
|
+
model_provider = "mistral"
|
|
255
|
+
|
|
256
|
+
def convert_request_to_raw_request(self, request: Request) -> Dict:
|
|
257
|
+
# TODO: Support the following:
|
|
258
|
+
# - top_k_per_token
|
|
259
|
+
# - echo_prompt
|
|
260
|
+
# - num_completions
|
|
261
|
+
return {
|
|
262
|
+
"prompt": f"[INST]{request.prompt}[/INST]",
|
|
263
|
+
"temperature": request.temperature,
|
|
264
|
+
"top_p": request.top_p,
|
|
265
|
+
"max_tokens": request.max_tokens,
|
|
266
|
+
}
|
|
267
|
+
|
|
268
|
+
def convert_raw_response_to_completions(self, response: Dict, request: Request) -> List[GeneratedOutput]:
|
|
269
|
+
# - logprob
|
|
270
|
+
completions: List[GeneratedOutput] = []
|
|
271
|
+
|
|
272
|
+
for raw_completion in response["outputs"]:
|
|
273
|
+
output_text = raw_completion["text"]
|
|
274
|
+
|
|
275
|
+
finish_reason = BedrockMistralClient._COMPLETION_REASON_TO_FINISH_REASON.get(
|
|
276
|
+
raw_completion["stop_reason"], raw_completion["stop_reason"].lower()
|
|
277
|
+
)
|
|
278
|
+
# Work around generated outputs with leading whitespace due to issue #2467
|
|
279
|
+
# TODO(#2467): Remove workaround
|
|
280
|
+
completion = truncate_and_tokenize_response_text(
|
|
281
|
+
output_text.lstrip(), request, self.tokenizer, self.tokenizer_name, finish_reason
|
|
282
|
+
)
|
|
283
|
+
completions.append(completion)
|
|
284
|
+
|
|
285
|
+
return completions
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
# Amazon Bedrock Client for LLAMA Models
|
|
289
|
+
class BedrockLlamaClient(BedrockClient):
|
|
290
|
+
_COMPLETION_REASON_TO_FINISH_REASON = {
|
|
291
|
+
"length": "length",
|
|
292
|
+
"stop": "endoftext",
|
|
293
|
+
}
|
|
294
|
+
|
|
295
|
+
model_provider = "meta"
|
|
296
|
+
|
|
297
|
+
def convert_request_to_raw_request(self, request: Request) -> Dict:
|
|
298
|
+
# TODO: Support the following:
|
|
299
|
+
# - top_k_per_token
|
|
300
|
+
# - echo_prompt
|
|
301
|
+
# - num_completions
|
|
302
|
+
return {
|
|
303
|
+
"prompt": f"[INST]{request.prompt}[/INST]",
|
|
304
|
+
"temperature": request.temperature,
|
|
305
|
+
"top_p": request.top_p,
|
|
306
|
+
"max_gen_len": request.max_tokens,
|
|
307
|
+
}
|
|
308
|
+
|
|
309
|
+
def convert_raw_response_to_completions(self, response: Dict, request: Request) -> List[GeneratedOutput]:
|
|
310
|
+
# - logprob
|
|
311
|
+
completions: List[GeneratedOutput] = []
|
|
312
|
+
output_text = response["generation"]
|
|
313
|
+
|
|
314
|
+
finish_reason = BedrockLlamaClient._COMPLETION_REASON_TO_FINISH_REASON.get(
|
|
315
|
+
response["stop_reason"], response["stop_reason"].lower()
|
|
316
|
+
)
|
|
317
|
+
# Work around generated outputs with leading whitespace due to issue #2467
|
|
318
|
+
# TODO(#2467): Remove workaround
|
|
319
|
+
completion = truncate_and_tokenize_response_text(
|
|
320
|
+
output_text.lstrip(), request, self.tokenizer, self.tokenizer_name, finish_reason
|
|
321
|
+
)
|
|
322
|
+
completions.append(completion)
|
|
323
|
+
|
|
324
|
+
return completions
|
helm/clients/bedrock_utils.py
CHANGED
|
@@ -8,6 +8,7 @@ from helm.common.optional_dependencies import handle_module_not_found_error
|
|
|
8
8
|
|
|
9
9
|
try:
|
|
10
10
|
import boto3
|
|
11
|
+
from boto3 import Session
|
|
11
12
|
from botocore.config import Config
|
|
12
13
|
except ModuleNotFoundError as e:
|
|
13
14
|
handle_module_not_found_error(e, ["aws"])
|
|
@@ -70,3 +71,35 @@ def get_bedrock_client(
|
|
|
70
71
|
|
|
71
72
|
hlog(f"Amazon Bedrock client successfully created with endpoint {bedrock_client._endpoint}")
|
|
72
73
|
return bedrock_client
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def get_bedrock_client_v1(
|
|
77
|
+
region: Optional[str] = None,
|
|
78
|
+
service_name: str = "bedrock-runtime",
|
|
79
|
+
assumed_role: Optional[str] = None,
|
|
80
|
+
read_timeout: int = 5000,
|
|
81
|
+
connect_timeout: int = 5000,
|
|
82
|
+
max_attempts: int = 10,
|
|
83
|
+
):
|
|
84
|
+
boto_config = Config(
|
|
85
|
+
read_timeout=read_timeout, connect_timeout=connect_timeout, retries={"max_attempts": max_attempts}
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
if assumed_role:
|
|
89
|
+
session = boto3.Session(region_name=region)
|
|
90
|
+
# Assume role and get credentials
|
|
91
|
+
sts = session.client("sts")
|
|
92
|
+
creds = sts.assume_role(RoleArn=str(assumed_role), RoleSessionName="crfm-helm")["Credentials"]
|
|
93
|
+
session = Session(
|
|
94
|
+
aws_access_key_id=creds["AccessKeyId"],
|
|
95
|
+
aws_secret_access_key=creds["SecretAccessKey"],
|
|
96
|
+
aws_session_token=creds["SessionToken"],
|
|
97
|
+
)
|
|
98
|
+
return session.client(
|
|
99
|
+
service_name=service_name,
|
|
100
|
+
region_name=region,
|
|
101
|
+
config=boto_config,
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
# default to instance role to get the aws credentials or aws configured credentials
|
|
105
|
+
return boto3.client(service_name=service_name, region_name=region, config=boto_config)
|
helm/clients/client.py
CHANGED
|
@@ -2,7 +2,7 @@ import json
|
|
|
2
2
|
from abc import ABC, abstractmethod
|
|
3
3
|
from typing import List, Mapping, Optional, cast
|
|
4
4
|
|
|
5
|
-
from helm.common.hierarchical_logger import
|
|
5
|
+
from helm.common.hierarchical_logger import hwarn
|
|
6
6
|
from helm.common.media_object import MultimediaObject, TEXT_TYPE
|
|
7
7
|
from helm.common.request import Request, RequestResult, GeneratedOutput, Token
|
|
8
8
|
from helm.common.cache import Cache, CacheConfig
|
|
@@ -65,7 +65,7 @@ def truncate_sequence(
|
|
|
65
65
|
# where max_tokens = 0, so there's nothing to truncate.
|
|
66
66
|
if request.echo_prompt:
|
|
67
67
|
if request.max_tokens != 0:
|
|
68
|
-
|
|
68
|
+
hwarn("don't know how to handle echo_prompt and max_tokens > 0, not truncating")
|
|
69
69
|
return sequence
|
|
70
70
|
|
|
71
71
|
if end_of_text_token:
|
|
@@ -90,8 +90,8 @@ def truncate_sequence(
|
|
|
90
90
|
new_tokens.append(token)
|
|
91
91
|
|
|
92
92
|
if len(new_text) < len(sequence.text) and len(new_tokens) == len(sequence.tokens):
|
|
93
|
-
|
|
94
|
-
f"
|
|
93
|
+
hwarn(
|
|
94
|
+
f"Stripped characters from text ({len(sequence.text)} -> {len(new_text)}), "
|
|
95
95
|
f"but wasn't able to strip the tokens"
|
|
96
96
|
)
|
|
97
97
|
|
|
@@ -99,14 +99,14 @@ def truncate_sequence(
|
|
|
99
99
|
new_logprob = sum(token.logprob for token in new_tokens)
|
|
100
100
|
|
|
101
101
|
if print_warning:
|
|
102
|
-
|
|
102
|
+
hwarn(f"truncate_sequence needs to strip {json.dumps(stop)}")
|
|
103
103
|
|
|
104
104
|
sequence = GeneratedOutput(text=new_text, logprob=new_logprob, tokens=new_tokens)
|
|
105
105
|
|
|
106
106
|
# Truncate based on the max number of tokens.
|
|
107
107
|
if len(sequence.tokens) > request.max_tokens:
|
|
108
108
|
if print_warning:
|
|
109
|
-
|
|
109
|
+
hwarn(f"truncate_sequence needs to truncate {len(sequence.tokens)} down to {request.max_tokens}")
|
|
110
110
|
new_tokens = sequence.tokens[: request.max_tokens]
|
|
111
111
|
|
|
112
112
|
# This is imperfect stitching together of tokens, so just to make sure this is okay
|
|
@@ -114,7 +114,7 @@ def truncate_sequence(
|
|
|
114
114
|
# Usually, in our benchmark, max_tokens is active when it's 1, so hopefully this isn't an issue.
|
|
115
115
|
new_text = "".join(token.text for token in new_tokens)
|
|
116
116
|
if not sequence.text.startswith(new_text):
|
|
117
|
-
|
|
117
|
+
hwarn(f"{json.dumps(sequence.text)} does not start with truncated text {json.dumps(new_text)}")
|
|
118
118
|
|
|
119
119
|
new_logprob = sum(token.logprob for token in new_tokens)
|
|
120
120
|
|
|
@@ -6,7 +6,7 @@ import torch
|
|
|
6
6
|
from helm.common.gpu_utils import get_torch_device
|
|
7
7
|
from helm.common.images_utils import open_image
|
|
8
8
|
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
9
|
-
from .base_clip_scorer import BaseCLIPScorer
|
|
9
|
+
from helm.clients.clip_scorers.base_clip_scorer import BaseCLIPScorer
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
_ = torch.manual_seed(42)
|
|
@@ -4,7 +4,7 @@ import transformers
|
|
|
4
4
|
from helm.common.gpu_utils import get_torch_device, get_torch_device_name
|
|
5
5
|
from helm.common.images_utils import open_image
|
|
6
6
|
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
7
|
-
from .base_clip_scorer import BaseCLIPScorer
|
|
7
|
+
from helm.clients.clip_scorers.base_clip_scorer import BaseCLIPScorer
|
|
8
8
|
|
|
9
9
|
_ = torch.manual_seed(42)
|
|
10
10
|
|
helm/clients/cohere_client.py
CHANGED
|
@@ -164,12 +164,12 @@ class CohereRawChatRequest(TypedDict):
|
|
|
164
164
|
message: str
|
|
165
165
|
model: Optional[str]
|
|
166
166
|
preamble: Optional[str]
|
|
167
|
-
chat_history: Optional[Sequence[cohere.
|
|
167
|
+
chat_history: Optional[Sequence[cohere.ChatbotMessage]]
|
|
168
168
|
temperature: Optional[float]
|
|
169
169
|
max_tokens: Optional[int]
|
|
170
170
|
k: Optional[int]
|
|
171
171
|
p: Optional[float]
|
|
172
|
-
seed: Optional[
|
|
172
|
+
seed: Optional[int]
|
|
173
173
|
stop_sequences: Optional[Sequence[str]]
|
|
174
174
|
frequency_penalty: Optional[float]
|
|
175
175
|
presence_penalty: Optional[float]
|
|
@@ -188,7 +188,7 @@ def convert_to_raw_chat_request(request: Request) -> CohereRawChatRequest:
|
|
|
188
188
|
"k": request.top_k_per_token,
|
|
189
189
|
"p": request.top_p,
|
|
190
190
|
"stop_sequences": request.stop_sequences,
|
|
191
|
-
"seed":
|
|
191
|
+
"seed": int(request.random) if request.random is not None else None,
|
|
192
192
|
"frequency_penalty": request.frequency_penalty,
|
|
193
193
|
"presence_penalty": request.presence_penalty,
|
|
194
194
|
}
|
helm/clients/google_client.py
CHANGED
|
@@ -2,7 +2,7 @@ from typing import List, Dict
|
|
|
2
2
|
|
|
3
3
|
from helm.common.cache import CacheConfig
|
|
4
4
|
from helm.common.request import Request, RequestResult, GeneratedOutput, Token
|
|
5
|
-
from .client import CachingClient, truncate_sequence
|
|
5
|
+
from helm.clients.client import CachingClient, truncate_sequence
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
class GoogleClient(CachingClient):
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from typing import Any, Dict, Optional
|
|
2
|
+
|
|
3
|
+
from helm.clients.openai_client import OpenAIClient
|
|
4
|
+
from helm.common.cache import CacheConfig
|
|
5
|
+
from helm.common.request import Request
|
|
6
|
+
from helm.tokenizers.tokenizer import Tokenizer
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class GrokChatClient(OpenAIClient):
|
|
10
|
+
|
|
11
|
+
BASE_URL = "https://api.x.ai/v1"
|
|
12
|
+
|
|
13
|
+
_UNSUPPORTED_ARGUMENTS = ["presence_penalty", "frequency_penalty"]
|
|
14
|
+
|
|
15
|
+
def __init__(
|
|
16
|
+
self,
|
|
17
|
+
tokenizer: Tokenizer,
|
|
18
|
+
tokenizer_name: str,
|
|
19
|
+
cache_config: CacheConfig,
|
|
20
|
+
api_key: Optional[str] = None,
|
|
21
|
+
):
|
|
22
|
+
super().__init__(
|
|
23
|
+
tokenizer=tokenizer,
|
|
24
|
+
tokenizer_name=tokenizer_name,
|
|
25
|
+
cache_config=cache_config,
|
|
26
|
+
api_key=api_key,
|
|
27
|
+
org_id=None,
|
|
28
|
+
base_url="https://api.x.ai/v1",
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
def _make_chat_raw_request(self, request: Request) -> Dict[str, Any]:
|
|
32
|
+
raw_request = super()._make_chat_raw_request(request)
|
|
33
|
+
for unsupported_argument in self._UNSUPPORTED_ARGUMENTS:
|
|
34
|
+
if unsupported_argument in raw_request:
|
|
35
|
+
del raw_request[unsupported_argument]
|
|
36
|
+
return raw_request
|
|
@@ -8,7 +8,7 @@ from transformers.generation.stopping_criteria import (
|
|
|
8
8
|
from typing import Any, Dict, List, Optional, TypedDict
|
|
9
9
|
|
|
10
10
|
from helm.common.cache import CacheConfig
|
|
11
|
-
from helm.common.hierarchical_logger import htrack_block, hlog
|
|
11
|
+
from helm.common.hierarchical_logger import htrack_block, hlog, hwarn
|
|
12
12
|
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
13
13
|
from helm.common.request import (
|
|
14
14
|
wrap_request_time,
|
|
@@ -18,8 +18,9 @@ from helm.common.request import (
|
|
|
18
18
|
GeneratedOutput,
|
|
19
19
|
Token,
|
|
20
20
|
)
|
|
21
|
+
from helm.proxy.retry import NonRetriableException
|
|
21
22
|
from helm.tokenizers.tokenizer import Tokenizer
|
|
22
|
-
from .client import CachingClient, truncate_sequence
|
|
23
|
+
from helm.clients.client import CachingClient, truncate_sequence
|
|
23
24
|
from helm.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer, WrappedPreTrainedTokenizer
|
|
24
25
|
from threading import Lock
|
|
25
26
|
|
|
@@ -59,17 +60,23 @@ class HuggingFaceServer:
|
|
|
59
60
|
self,
|
|
60
61
|
pretrained_model_name_or_path: str,
|
|
61
62
|
wrapped_tokenizer: WrappedPreTrainedTokenizer,
|
|
62
|
-
openvino: bool = False,
|
|
63
63
|
**kwargs,
|
|
64
64
|
):
|
|
65
65
|
self.device: Optional[str]
|
|
66
66
|
if "device_map" in kwargs:
|
|
67
|
+
if "device" in kwargs:
|
|
68
|
+
raise ValueError("At most one of one of `device` and `device_map` may be specified.")
|
|
67
69
|
try:
|
|
68
70
|
import accelerate # noqa: F401
|
|
69
71
|
except ModuleNotFoundError as e:
|
|
70
72
|
handle_module_not_found_error(e, ["accelerate"])
|
|
71
|
-
hlog(f'Hugging Face device_map set to "{kwargs["device_map"]}".')
|
|
73
|
+
hlog(f'Hugging Face device_map set to "{kwargs["device_map"]}" from kwargs.')
|
|
72
74
|
self.device = None
|
|
75
|
+
elif "device" in kwargs:
|
|
76
|
+
if "device_map" in kwargs:
|
|
77
|
+
raise ValueError("At most one of one of `device` and `device_map` may be specified.")
|
|
78
|
+
hlog(f'Hugging Face device set to "{kwargs["device"]}" from kwargs.')
|
|
79
|
+
self.device = kwargs.pop("device")
|
|
73
80
|
elif torch.cuda.is_available():
|
|
74
81
|
hlog('Hugging Face device set to "cuda:0" because CUDA is available.')
|
|
75
82
|
self.device = "cuda:0"
|
|
@@ -85,20 +92,7 @@ class HuggingFaceServer:
|
|
|
85
92
|
|
|
86
93
|
with htrack_block(f"Loading Hugging Face model {pretrained_model_name_or_path}"):
|
|
87
94
|
# WARNING this may fail if your GPU does not have enough memory
|
|
88
|
-
if
|
|
89
|
-
# Optimum Intel provides a simple interface to optimize Transformer models and convert them to \
|
|
90
|
-
# OpenVINO™ Intermediate Representation (IR) format to accelerate end-to-end pipelines on \
|
|
91
|
-
# Intel® architectures using OpenVINO™ runtime.
|
|
92
|
-
try:
|
|
93
|
-
from optimum.intel.openvino import OVModelForCausalLM
|
|
94
|
-
except ModuleNotFoundError as e:
|
|
95
|
-
handle_module_not_found_error(e, ["openvino"])
|
|
96
|
-
|
|
97
|
-
self.device = "cpu"
|
|
98
|
-
self.model = OVModelForCausalLM.from_pretrained(
|
|
99
|
-
pretrained_model_name_or_path, export=True, **kwargs
|
|
100
|
-
).to(self.device)
|
|
101
|
-
elif self.device is None:
|
|
95
|
+
if self.device is None:
|
|
102
96
|
# kwargs contains device_map=auto
|
|
103
97
|
# Do not call to() because accelerate will take care of model device placement.
|
|
104
98
|
self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
|
@@ -113,7 +107,6 @@ class HuggingFaceServer:
|
|
|
113
107
|
encoded_input = tokenizer(raw_request["prompt"], return_tensors="pt", return_token_type_ids=False).to(
|
|
114
108
|
0 if self.device is None else self.device
|
|
115
109
|
)
|
|
116
|
-
|
|
117
110
|
stopping_criteria: Optional[StoppingCriteriaList] = None
|
|
118
111
|
optional_args = {}
|
|
119
112
|
if len(raw_request["stop_sequences"]) > 0:
|
|
@@ -264,6 +257,7 @@ class HuggingFaceClient(CachingClient):
|
|
|
264
257
|
tokenizer: Tokenizer,
|
|
265
258
|
pretrained_model_name_or_path: Optional[str] = None,
|
|
266
259
|
end_of_text_token: Optional[str] = None,
|
|
260
|
+
apply_chat_template: Optional[bool] = None,
|
|
267
261
|
**kwargs,
|
|
268
262
|
):
|
|
269
263
|
super().__init__(cache_config=cache_config)
|
|
@@ -274,9 +268,46 @@ class HuggingFaceClient(CachingClient):
|
|
|
274
268
|
"but instead it is {tokenizer}"
|
|
275
269
|
)
|
|
276
270
|
self._wrapped_tokenizer: WrappedPreTrainedTokenizer = tokenizer.get_wrapped_tokenizer()
|
|
277
|
-
self._tokenizer = tokenizer
|
|
278
271
|
self._kwargs = _process_huggingface_client_kwargs(kwargs)
|
|
279
272
|
self._end_of_text_token = end_of_text_token
|
|
273
|
+
# If the user did not explicitly configure whether the model is a chat model with `apply_chat_template` arg,
|
|
274
|
+
# auto-infer if the model is a chat model based on whether the tokenizer has a chat template.
|
|
275
|
+
# Note: Auto-inference is incorrect for some non-chat models that still have chat templates
|
|
276
|
+
# e.g. Qwen2, Qwen 2.5.
|
|
277
|
+
# For these models, the `apply_chat_template` arg should be explicitly set to false.
|
|
278
|
+
if apply_chat_template is not None:
|
|
279
|
+
self._apply_chat_template = apply_chat_template
|
|
280
|
+
else:
|
|
281
|
+
with self._wrapped_tokenizer as hf_tokenizer:
|
|
282
|
+
self._apply_chat_template = bool(hf_tokenizer.chat_template)
|
|
283
|
+
hwarn(
|
|
284
|
+
f"Automatically set `apply_chat_template` to {self._apply_chat_template} based on "
|
|
285
|
+
"whether the tokenizer has a chat template. "
|
|
286
|
+
"If this is incorrect, please explicitly set `apply_chat_template`."
|
|
287
|
+
)
|
|
288
|
+
|
|
289
|
+
def get_prompt(self, request: Request) -> str:
|
|
290
|
+
if request.prompt and request.messages:
|
|
291
|
+
raise NonRetriableException(f"More than one of `prompt` and `messages` was set in request: {request}")
|
|
292
|
+
# Chat model expects a list of messages as input
|
|
293
|
+
if self._apply_chat_template:
|
|
294
|
+
with self._wrapped_tokenizer as tokenizer:
|
|
295
|
+
if request.messages:
|
|
296
|
+
prompt = tokenizer.apply_chat_template(request.messages, tokenize=False)
|
|
297
|
+
assert isinstance(prompt, str)
|
|
298
|
+
return prompt
|
|
299
|
+
else:
|
|
300
|
+
prompt = tokenizer.apply_chat_template(
|
|
301
|
+
[{"role": "user", "content": request.prompt}], tokenize=False
|
|
302
|
+
)
|
|
303
|
+
assert isinstance(prompt, str)
|
|
304
|
+
return prompt
|
|
305
|
+
# Base non-chat model expects a string as input
|
|
306
|
+
else:
|
|
307
|
+
if request.messages:
|
|
308
|
+
raise NonRetriableException("Chat mesages not supported by non-chat model")
|
|
309
|
+
else:
|
|
310
|
+
return request.prompt
|
|
280
311
|
|
|
281
312
|
def make_request(self, request: Request) -> RequestResult:
|
|
282
313
|
# Embedding not supported for this model
|
|
@@ -285,7 +316,7 @@ class HuggingFaceClient(CachingClient):
|
|
|
285
316
|
|
|
286
317
|
raw_request: HuggingFaceRequest = {
|
|
287
318
|
"engine": request.model_engine,
|
|
288
|
-
"prompt": request
|
|
319
|
+
"prompt": self.get_prompt(request),
|
|
289
320
|
"temperature": 1e-7 if request.temperature == 0 else request.temperature,
|
|
290
321
|
"num_return_sequences": request.num_completions,
|
|
291
322
|
"max_new_tokens": request.max_tokens,
|