crfm-helm 0.5.6__py3-none-any.whl → 0.5.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of crfm-helm might be problematic.
- {crfm_helm-0.5.6.dist-info → crfm_helm-0.5.10.dist-info}/METADATA +72 -130
- {crfm_helm-0.5.6.dist-info → crfm_helm-0.5.10.dist-info}/RECORD +372 -305
- helm/benchmark/adaptation/adapter_spec.py +10 -0
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +11 -3
- helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +11 -8
- helm/benchmark/annotation/aci_bench_annotator.py +11 -22
- helm/benchmark/annotation/air_bench_annotator.py +1 -1
- helm/benchmark/annotation/alrage_annotator.py +90 -0
- helm/benchmark/annotation/chw_care_plan_annotator.py +10 -21
- helm/benchmark/annotation/dischargeme_annotator.py +11 -22
- helm/benchmark/annotation/live_qa_annotator.py +1 -1
- helm/benchmark/annotation/med_dialog_annotator.py +11 -22
- helm/benchmark/annotation/medalign_annotator.py +11 -22
- helm/benchmark/annotation/medi_qa_annotator.py +11 -22
- helm/benchmark/annotation/medication_qa_annotator.py +11 -22
- helm/benchmark/annotation/mental_health_annotator.py +11 -22
- helm/benchmark/annotation/mimic_bhc_annotator.py +11 -22
- helm/benchmark/annotation/mimic_rrs_annotator.py +11 -22
- helm/benchmark/annotation/model_as_judge.py +23 -18
- helm/benchmark/annotation/mtsamples_procedures_annotator.py +11 -22
- helm/benchmark/annotation/mtsamples_replicate_annotator.py +11 -22
- helm/benchmark/annotation/starr_patient_instructions_annotator.py +11 -22
- helm/benchmark/metrics/air_bench_metrics.py +3157 -1
- helm/benchmark/metrics/alrage_metric.py +35 -0
- helm/benchmark/metrics/basic_metrics.py +267 -2
- helm/benchmark/metrics/bbq_metrics.py +12 -0
- helm/benchmark/metrics/classification_metrics.py +19 -1
- helm/benchmark/metrics/codeinsights_code_efficiency_metrics.py +186 -0
- helm/benchmark/metrics/codeinsights_code_evaluation_metrics.py +477 -0
- helm/benchmark/metrics/codeinsights_correct_code_metrics.py +366 -0
- helm/benchmark/metrics/codeinsights_edge_case_metrics.py +92 -0
- helm/benchmark/metrics/codeinsights_metric_specs.py +51 -0
- helm/benchmark/metrics/comet_metric.py +1 -1
- helm/benchmark/metrics/conv_fin_qa_calc_metrics.py +12 -1
- helm/benchmark/metrics/copyright_metrics.py +1 -1
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +1 -1
- helm/benchmark/metrics/dry_run_metrics.py +30 -1
- helm/benchmark/metrics/efficiency_metrics.py +74 -0
- helm/benchmark/metrics/ehr_sql_metrics.py +57 -1
- helm/benchmark/metrics/evaluate_reference_metrics.py +312 -1
- helm/benchmark/metrics/gpqa_chain_of_thought_metric.py +13 -1
- helm/benchmark/metrics/helpdesk_call_summarization_metrics.py +13 -1
- helm/benchmark/metrics/ifeval_metrics.py +13 -1
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +13 -2
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +1 -1
- helm/benchmark/metrics/instruction_following_critique_metrics.py +41 -1
- helm/benchmark/metrics/kpi_edgar_metrics.py +21 -0
- helm/benchmark/metrics/language_modeling_metrics.py +13 -1
- helm/benchmark/metrics/live_qa_metrics.py +13 -1
- helm/benchmark/metrics/llm_jury_metrics.py +13 -1
- helm/benchmark/metrics/lmkt_metric_specs.py +12 -0
- helm/benchmark/metrics/lmkt_metrics.py +47 -0
- helm/benchmark/metrics/medcalc_bench_metrics.py +14 -1
- helm/benchmark/metrics/medec_metrics.py +25 -2
- helm/benchmark/metrics/melt_toxicity_metric.py +1 -1
- helm/benchmark/metrics/metric.py +25 -0
- helm/benchmark/metrics/mimiciv_billing_code_metrics.py +32 -1
- helm/benchmark/metrics/omni_math_metrics.py +13 -1
- helm/benchmark/metrics/safety_metrics.py +13 -1
- helm/benchmark/metrics/seahelm_metrics.py +14 -1
- helm/benchmark/metrics/summac/model_summac.py +3 -3
- helm/benchmark/metrics/summarization_metrics.py +129 -1
- helm/benchmark/metrics/toxicity_metrics.py +31 -1
- helm/benchmark/metrics/ultra_suite_asr_classification_metrics.py +52 -0
- helm/benchmark/metrics/wildbench_metrics.py +21 -1
- helm/benchmark/model_deployment_registry.py +11 -19
- helm/benchmark/presentation/create_plots.py +11 -2
- helm/benchmark/presentation/run_display.py +13 -3
- helm/benchmark/presentation/run_entry.py +2 -2
- helm/benchmark/presentation/schema.py +10 -22
- helm/benchmark/presentation/summarize.py +189 -14
- helm/benchmark/presentation/taxonomy_info.py +20 -0
- helm/benchmark/presentation/test_create_plots.py +4 -1
- helm/benchmark/run.py +15 -4
- helm/benchmark/run_expander.py +4 -0
- helm/benchmark/run_specs/arabic_run_specs.py +197 -0
- helm/benchmark/run_specs/bluex_run_specs.py +40 -0
- helm/benchmark/run_specs/classic_run_specs.py +2 -55
- helm/benchmark/run_specs/codeinsights_run_specs.py +192 -0
- helm/benchmark/run_specs/healthqa_br_run_specs.py +40 -0
- helm/benchmark/run_specs/heim_run_specs.py +3 -1
- helm/benchmark/run_specs/lmkt_run_specs.py +144 -0
- helm/benchmark/run_specs/long_context_run_specs.py +48 -1
- helm/benchmark/run_specs/medhelm/__init__.py +0 -0
- helm/benchmark/run_specs/medhelm/benchmark_config.py +219 -0
- helm/benchmark/run_specs/medhelm_run_specs.py +363 -53
- helm/benchmark/run_specs/multilingual_run_specs.py +50 -0
- helm/benchmark/run_specs/speech_disorder_audio_run_specs.py +11 -13
- helm/benchmark/runner.py +7 -0
- helm/benchmark/scenarios/aci_bench_scenario.py +23 -0
- helm/benchmark/scenarios/air_bench_scenario.py +21 -0
- helm/benchmark/scenarios/alghafa_scenario.py +126 -0
- helm/benchmark/scenarios/alrage_scenario.py +54 -0
- helm/benchmark/scenarios/anthropic_hh_rlhf_scenario.py +23 -1
- helm/benchmark/scenarios/anthropic_red_team_scenario.py +12 -1
- helm/benchmark/scenarios/arabic_exams_scenario.py +114 -0
- helm/benchmark/scenarios/arabic_mmlu_scenario.py +82 -0
- helm/benchmark/scenarios/aratrust_scenario.py +95 -0
- helm/benchmark/scenarios/audio_language/casual_conversations2_scenario.py +1 -1
- helm/benchmark/scenarios/audio_language/mustard_scenario.py +1 -1
- helm/benchmark/scenarios/audio_language/ultra_suite_asr_classification_scenario.py +74 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_asr_transcription_scenario.py +70 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_classification_scenario.py +22 -53
- helm/benchmark/scenarios/audio_language/ultra_suite_disorder_breakdown_scenario.py +21 -21
- helm/benchmark/scenarios/audio_language/ultra_suite_disorder_symptoms_scenario.py +21 -52
- helm/benchmark/scenarios/babi_qa_scenario.py +15 -0
- helm/benchmark/scenarios/banking77_scenario.py +21 -0
- helm/benchmark/scenarios/bbq_scenario.py +15 -0
- helm/benchmark/scenarios/best_chatgpt_prompts.yaml +473 -0
- helm/benchmark/scenarios/bird_sql_scenario.py +18 -0
- helm/benchmark/scenarios/bluex_scenario.py +70 -0
- helm/benchmark/scenarios/bold_scenario.py +15 -0
- helm/benchmark/scenarios/boolq_scenario.py +20 -0
- helm/benchmark/scenarios/chw_care_plan_scenario.py +23 -0
- helm/benchmark/scenarios/civil_comments_scenario.py +13 -0
- helm/benchmark/scenarios/clear_scenario.py +23 -0
- helm/benchmark/scenarios/cleva_scenario.py +480 -1
- helm/benchmark/scenarios/code_scenario.py +28 -0
- helm/benchmark/scenarios/codeinsights_code_efficiency_scenario.py +197 -0
- helm/benchmark/scenarios/codeinsights_correct_code_scenario.py +78 -0
- helm/benchmark/scenarios/codeinsights_edge_case_scenario.py +192 -0
- helm/benchmark/scenarios/codeinsights_student_coding_scenario.py +162 -0
- helm/benchmark/scenarios/codeinsights_student_mistake_scenario.py +188 -0
- helm/benchmark/scenarios/commonsense_scenario.py +32 -0
- helm/benchmark/scenarios/compositional_instructions.yaml +70 -0
- helm/benchmark/scenarios/conv_fin_qa_calc_scenario.py +21 -0
- helm/benchmark/scenarios/copyright_scenario.py +35 -1
- helm/benchmark/scenarios/cti_to_mitre_scenario.py +21 -0
- helm/benchmark/scenarios/czech_bank_qa_scenario.py +18 -0
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +22 -1
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +23 -1
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +22 -1
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +21 -1
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +13 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +13 -1
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +13 -1
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +13 -1
- helm/benchmark/scenarios/dischargeme_scenario.py +24 -0
- helm/benchmark/scenarios/disinformation_scenario.py +22 -0
- helm/benchmark/scenarios/dyck_language_scenario.py +15 -0
- helm/benchmark/scenarios/ehrshot_scenario.py +22 -0
- helm/benchmark/scenarios/enem_challenge_scenario.py +19 -0
- helm/benchmark/scenarios/entity_data_imputation_scenario.py +14 -0
- helm/benchmark/scenarios/entity_matching_scenario.py +14 -0
- helm/benchmark/scenarios/exams_multilingual_scenario.py +115 -0
- helm/benchmark/scenarios/fin_qa_scenario.py +20 -0
- helm/benchmark/scenarios/financebench_scenario.py +21 -0
- helm/benchmark/scenarios/financial_phrasebank_scenario.py +21 -0
- helm/benchmark/scenarios/gold_commodity_news_scenario.py +21 -0
- helm/benchmark/scenarios/gpqa_scenario.py +18 -0
- helm/benchmark/scenarios/grammar_scenario.py +20 -1
- helm/benchmark/scenarios/gsm_scenario.py +21 -0
- helm/benchmark/scenarios/harm_bench_gcg_transfer_scenario.py +12 -1
- helm/benchmark/scenarios/harm_bench_scenario.py +12 -1
- helm/benchmark/scenarios/headqa_scenario.py +22 -0
- helm/benchmark/scenarios/healthqa_br_scenario.py +80 -0
- helm/benchmark/scenarios/helpdesk_call_summarization_scenario.py +13 -0
- helm/benchmark/scenarios/ice_scenario.py +21 -1
- helm/benchmark/scenarios/ifeval_scenario.py +18 -0
- helm/benchmark/scenarios/imdb_scenario.py +15 -0
- helm/benchmark/scenarios/infinite_bench_en_mc_scenario.py +111 -0
- helm/benchmark/scenarios/infinite_bench_en_qa_scenario.py +1 -1
- helm/benchmark/scenarios/infinite_bench_en_sum_scenario.py +19 -0
- helm/benchmark/scenarios/koala_scenario.py +21 -1
- helm/benchmark/scenarios/kpi_edgar_scenario.py +21 -0
- helm/benchmark/scenarios/legal_contract_summarization_scenario.py +20 -0
- helm/benchmark/scenarios/legal_summarization_scenario.py +50 -0
- helm/benchmark/scenarios/legal_support_scenario.py +13 -0
- helm/benchmark/scenarios/legalbench_scenario.py +19 -0
- helm/benchmark/scenarios/lex_glue_scenario.py +11 -0
- helm/benchmark/scenarios/lextreme_scenario.py +11 -0
- helm/benchmark/scenarios/lmkt_scenarios.py +288 -0
- helm/benchmark/scenarios/lsat_qa_scenario.py +14 -0
- helm/benchmark/scenarios/madinah_qa_scenario.py +73 -0
- helm/benchmark/scenarios/math_scenario.py +54 -20
- helm/benchmark/scenarios/mbzuai_human_translated_arabic_mmlu.py +68 -0
- helm/benchmark/scenarios/med_dialog_scenario.py +32 -1
- helm/benchmark/scenarios/med_mcqa_scenario.py +14 -0
- helm/benchmark/scenarios/med_qa_scenario.py +20 -0
- helm/benchmark/scenarios/medalign_scenario.py +23 -0
- helm/benchmark/scenarios/medalign_scenario_helper.py +19 -125
- helm/benchmark/scenarios/medbullets_scenario.py +22 -0
- helm/benchmark/scenarios/medcalc_bench_scenario.py +22 -0
- helm/benchmark/scenarios/medec_scenario.py +23 -0
- helm/benchmark/scenarios/medhallu_scenario.py +23 -0
- helm/benchmark/scenarios/medhelm/__init__.py +0 -0
- helm/benchmark/scenarios/medhelm/judges.yaml +14 -0
- helm/benchmark/scenarios/medhelm_configurable_scenario.py +101 -0
- helm/benchmark/scenarios/medi_qa_scenario.py +24 -1
- helm/benchmark/scenarios/medication_qa_scenario.py +31 -1
- helm/benchmark/scenarios/melt_scenarios.py +2 -2
- helm/benchmark/scenarios/mental_health_scenario.py +23 -0
- helm/benchmark/scenarios/mimic_bhc_scenario.py +25 -1
- helm/benchmark/scenarios/mimic_rrs_scenario.py +23 -0
- helm/benchmark/scenarios/mimiciv_billing_code_scenario.py +22 -0
- helm/benchmark/scenarios/mmlu_pro_scenario.py +18 -0
- helm/benchmark/scenarios/mmlu_scenario.py +21 -0
- helm/benchmark/scenarios/mmmlu_scenario.py +85 -0
- helm/benchmark/scenarios/msmarco_scenario.py +30 -0
- helm/benchmark/scenarios/mtsamples_procedures_scenario.py +22 -0
- helm/benchmark/scenarios/mtsamples_replicate_scenario.py +22 -0
- helm/benchmark/scenarios/n2c2_ct_matching_scenario.py +20 -0
- helm/benchmark/scenarios/narrativeqa_scenario.py +19 -0
- helm/benchmark/scenarios/natural_qa_scenario.py +32 -0
- helm/benchmark/scenarios/omni_math_scenario.py +18 -0
- helm/benchmark/scenarios/open_assistant_scenario.py +22 -0
- helm/benchmark/scenarios/openai_mrcr_scenario.py +15 -0
- helm/benchmark/scenarios/pubmed_qa_scenario.py +22 -0
- helm/benchmark/scenarios/quac_scenario.py +14 -0
- helm/benchmark/scenarios/race_based_med_scenario.py +23 -0
- helm/benchmark/scenarios/raft_scenario.py +15 -0
- helm/benchmark/scenarios/real_toxicity_prompts_scenario.py +14 -1
- helm/benchmark/scenarios/ruler_qa_scenarios.py +40 -0
- helm/benchmark/scenarios/scenario.py +31 -0
- helm/benchmark/scenarios/seahelm_scenario.py +350 -2
- helm/benchmark/scenarios/self_instruct_scenario.py +29 -1
- helm/benchmark/scenarios/shc_bmt_scenario.py +22 -0
- helm/benchmark/scenarios/shc_cdi_scenario.py +20 -0
- helm/benchmark/scenarios/shc_conf_scenario.py +23 -0
- helm/benchmark/scenarios/shc_ent_scenario.py +21 -0
- helm/benchmark/scenarios/shc_gip_scenario.py +20 -0
- helm/benchmark/scenarios/shc_privacy_scenario.py +22 -0
- helm/benchmark/scenarios/shc_proxy_scenario.py +23 -1
- helm/benchmark/scenarios/shc_ptbm_scenario.py +23 -0
- helm/benchmark/scenarios/shc_sequoia_scenario.py +21 -0
- helm/benchmark/scenarios/simple_safety_tests_scenario.py +12 -1
- helm/benchmark/scenarios/situation_prompts.yaml +49 -0
- helm/benchmark/scenarios/spider_scenario.py +18 -0
- helm/benchmark/scenarios/starr_patient_instructions_scenario.py +22 -0
- helm/benchmark/scenarios/summarization_scenario.py +37 -0
- helm/benchmark/scenarios/synthetic_efficiency_scenario.py +22 -1
- helm/benchmark/scenarios/synthetic_reasoning_natural_scenario.py +13 -0
- helm/benchmark/scenarios/test_alghafa_scenario.py +29 -0
- helm/benchmark/scenarios/test_alrage_scenario.py +23 -0
- helm/benchmark/scenarios/test_arabic_exams_scenario.py +21 -0
- helm/benchmark/scenarios/test_aratrust_scenario.py +21 -0
- helm/benchmark/scenarios/test_bluex_scenario.py +59 -0
- helm/benchmark/scenarios/test_exams_multilingual_scenario.py +29 -0
- helm/benchmark/scenarios/test_healtha_br_scenario.py +57 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +95 -0
- helm/benchmark/scenarios/the_pile_scenario.py +13 -1
- helm/benchmark/scenarios/truthful_qa_scenario.py +14 -0
- helm/benchmark/scenarios/twitter_aae_scenario.py +20 -1
- helm/benchmark/scenarios/vicuna_scenario.py +21 -1
- helm/benchmark/scenarios/wikifact_scenario.py +20 -0
- helm/benchmark/scenarios/wildbench_scenario.py +18 -0
- helm/benchmark/scenarios/wmt_14_scenario.py +19 -0
- helm/benchmark/slurm_jobs.py +1 -2
- helm/benchmark/slurm_runner.py +8 -1
- helm/benchmark/static/schema_arabic.yaml +271 -0
- helm/benchmark/static/schema_classic.yaml +0 -17
- helm/benchmark/static/schema_long_context.yaml +17 -18
- helm/benchmark/static/schema_medhelm.yaml +36 -0
- helm/benchmark/static/schema_slp.yaml +219 -0
- helm/benchmark/static_build/assets/audio-table-Dn5NMMeJ.png +0 -0
- helm/benchmark/static_build/assets/index-oIeiQW2g.css +1 -0
- helm/benchmark/static_build/assets/index-qOFpOyHb.js +10 -0
- helm/benchmark/static_build/assets/react-BteFIppM.js +85 -0
- helm/benchmark/static_build/assets/recharts-DxuQtTOs.js +97 -0
- helm/benchmark/static_build/assets/tremor-DR4fE7ko.js +10 -0
- helm/benchmark/static_build/index.html +5 -6
- helm/benchmark/window_services/image_generation/clip_window_service.py +1 -3
- helm/clients/ai21_client.py +2 -0
- helm/clients/aleph_alpha_client.py +2 -0
- helm/clients/anthropic_client.py +7 -1
- helm/clients/audio_language/diva_llama_client.py +2 -0
- helm/clients/audio_language/llama_omni/arguments.py +61 -0
- helm/clients/audio_language/llama_omni/constants.py +9 -0
- helm/clients/audio_language/llama_omni/conversation.py +213 -0
- helm/clients/audio_language/llama_omni/model/__init__.py +0 -0
- helm/clients/audio_language/llama_omni/model/builder.py +88 -0
- helm/clients/audio_language/llama_omni/model/language_model/omni_speech2s_llama.py +190 -0
- helm/clients/audio_language/llama_omni/model/language_model/omni_speech_llama.py +118 -0
- helm/clients/audio_language/llama_omni/model/omni_speech_arch.py +249 -0
- helm/clients/audio_language/llama_omni/model/speech_encoder/builder.py +9 -0
- helm/clients/audio_language/llama_omni/model/speech_encoder/speech_encoder.py +27 -0
- helm/clients/audio_language/llama_omni/model/speech_generator/builder.py +9 -0
- helm/clients/audio_language/llama_omni/model/speech_generator/generation.py +622 -0
- helm/clients/audio_language/llama_omni/model/speech_generator/speech_generator.py +104 -0
- helm/clients/audio_language/llama_omni/model/speech_projector/builder.py +9 -0
- helm/clients/audio_language/llama_omni/model/speech_projector/speech_projector.py +27 -0
- helm/clients/audio_language/llama_omni/preprocess.py +295 -0
- helm/clients/audio_language/llama_omni/utils.py +202 -0
- helm/clients/audio_language/llama_omni_client.py +2 -1
- helm/clients/audio_language/qwen2_5_omni_client.py +21 -8
- helm/clients/audio_language/qwen2_audiolm_client.py +2 -1
- helm/clients/audio_language/qwen_audiolm_client.py +2 -1
- helm/clients/audio_language/qwen_omni/configuration_qwen2_5_omni.py +519 -0
- helm/clients/audio_language/qwen_omni/modeling_qwen2_5_omni.py +4308 -0
- helm/clients/audio_language/qwen_omni/processing_qwen2_5_omni.py +270 -0
- helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/__init__.py +0 -0
- helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/__init__.py +8 -0
- helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/audio_process.py +56 -0
- helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/vision_process.py +380 -0
- helm/clients/bedrock_client.py +63 -6
- helm/clients/cohere_client.py +3 -0
- helm/clients/dspy_client.py +135 -0
- helm/clients/google_client.py +2 -0
- helm/clients/http_model_client.py +2 -0
- helm/clients/huggingface_client.py +4 -3
- helm/clients/ibm_client.py +3 -1
- helm/clients/image_generation/adobe_vision_client.py +2 -0
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +2 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +1 -1
- helm/clients/image_generation/cogview2_client.py +2 -1
- helm/clients/image_generation/dalle2_client.py +2 -0
- helm/clients/image_generation/dalle_mini_client.py +2 -1
- helm/clients/image_generation/deep_floyd_client.py +2 -0
- helm/clients/image_generation/huggingface_diffusers_client.py +2 -1
- helm/clients/image_generation/lexica_client.py +2 -0
- helm/clients/image_generation/mindalle/models/stage1/layers.py +2 -2
- helm/clients/image_generation/mindalle_client.py +2 -1
- helm/clients/image_generation/together_image_generation_client.py +2 -0
- helm/clients/megatron_client.py +2 -0
- helm/clients/mistral_client.py +2 -0
- helm/clients/moderation_api_client.py +2 -0
- helm/clients/openai_client.py +38 -21
- helm/clients/openai_responses_client.py +34 -8
- helm/clients/openrouter_client.py +31 -0
- helm/clients/palmyra_client.py +2 -1
- helm/clients/reka_client.py +2 -1
- helm/clients/stanfordhealthcare_azure_openai_client.py +2 -2
- helm/clients/stanfordhealthcare_http_model_client.py +2 -0
- helm/clients/test_huggingface_client.py +3 -3
- helm/clients/test_openrouter_client.py +69 -0
- helm/clients/together_client.py +52 -13
- helm/clients/vertexai_client.py +23 -11
- helm/clients/vision_language/huggingface_vision2seq_client.py +2 -1
- helm/clients/vision_language/huggingface_vlm_client.py +2 -0
- helm/clients/vision_language/idefics_client.py +2 -1
- helm/clients/vision_language/open_flamingo_client.py +2 -1
- helm/clients/vision_language/paligemma_client.py +2 -1
- helm/clients/vision_language/palmyra_vision_client.py +2 -0
- helm/clients/vision_language/qwen2_vlm_client.py +2 -1
- helm/clients/vision_language/qwen_vlm_client.py +2 -1
- helm/clients/vllm_client.py +43 -7
- helm/clients/vllm_granite_thinking_client.py +56 -0
- helm/clients/writer_client.py +5 -2
- helm/common/critique_request.py +0 -1
- helm/common/hierarchical_logger.py +103 -34
- helm/common/object_spec.py +23 -8
- helm/common/optional_dependencies.py +1 -1
- helm/common/test_general.py +4 -0
- helm/common/test_logging.py +94 -0
- helm/config/model_deployments.yaml +1001 -187
- helm/config/model_metadata.yaml +602 -18
- helm/config/tokenizer_configs.yaml +202 -5
- helm/proxy/cli.py +1 -1
- helm/proxy/example_queries.py +8 -8
- helm/proxy/retry.py +5 -0
- helm/proxy/server.py +2 -1
- helm/proxy/static/index.css +4 -0
- helm/proxy/static/index.js +7 -1
- helm/tokenizers/auto_tokenizer.py +2 -2
- helm/tokenizers/grok_tokenizer.py +2 -0
- helm/benchmark/metrics/aci_bench_metrics.py +0 -14
- helm/benchmark/metrics/chw_care_plan_metrics.py +0 -14
- helm/benchmark/metrics/dischargeme_metrics.py +0 -14
- helm/benchmark/metrics/med_dialog_metrics.py +0 -14
- helm/benchmark/metrics/medalign_metrics.py +0 -14
- helm/benchmark/metrics/medi_qa_metrics.py +0 -14
- helm/benchmark/metrics/medication_qa_metrics.py +0 -14
- helm/benchmark/metrics/mental_health_metrics.py +0 -14
- helm/benchmark/metrics/mimic_bhc_metrics.py +0 -14
- helm/benchmark/metrics/mimic_rrs_metrics.py +0 -14
- helm/benchmark/metrics/mtsamples_procedures_metrics.py +0 -14
- helm/benchmark/metrics/mtsamples_replicate_metrics.py +0 -14
- helm/benchmark/metrics/numeracy_metrics.py +0 -72
- helm/benchmark/metrics/starr_patient_instructions_metrics.py +0 -14
- helm/benchmark/metrics/test_numeracy_metrics.py +0 -95
- helm/benchmark/scenarios/audio_language/ultra_suite_asr_classification.py +0 -103
- helm/benchmark/scenarios/numeracy_scenario.py +0 -794
- helm/benchmark/static_build/assets/index-94295e78.js +0 -10
- helm/benchmark/static_build/assets/index-b9779128.css +0 -1
- helm/benchmark/static_build/assets/react-f82877fd.js +0 -85
- helm/benchmark/static_build/assets/recharts-4037aff0.js +0 -97
- helm/benchmark/static_build/assets/tremor-38a10867.js +0 -10
- {crfm_helm-0.5.6.dist-info → crfm_helm-0.5.10.dist-info}/WHEEL +0 -0
- {crfm_helm-0.5.6.dist-info → crfm_helm-0.5.10.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.5.6.dist-info → crfm_helm-0.5.10.dist-info}/licenses/LICENSE +0 -0
- {crfm_helm-0.5.6.dist-info → crfm_helm-0.5.10.dist-info}/top_level.txt +0 -0
- /helm/benchmark/static_build/assets/{air-overview-d2e6c49f.png → air-overview-DpBbyagA.png} +0 -0
- /helm/benchmark/static_build/assets/{crfm-logo-74391ab8.png → crfm-logo-Du4T1uWZ.png} +0 -0
- /helm/benchmark/static_build/assets/{heim-logo-3e5e3aa4.png → heim-logo-BJtQlEbV.png} +0 -0
- /helm/benchmark/static_build/assets/{helm-logo-simple-2ed5400b.png → helm-logo-simple-DzOhNN41.png} +0 -0
- /helm/benchmark/static_build/assets/{helm-safety-2907a7b6.png → helm-safety-COfndXuS.png} +0 -0
- /helm/benchmark/static_build/assets/{helmhero-28e90f4d.png → helmhero-D9TvmJsp.png} +0 -0
- /helm/benchmark/static_build/assets/{medhelm-overview-eac29843.png → medhelm-overview-CND0EIsy.png} +0 -0
- /helm/benchmark/static_build/assets/{medhelm-v1-overview-3ddfcd65.png → medhelm-v1-overview-Cu2tphBB.png} +0 -0
- /helm/benchmark/static_build/assets/{overview-74aea3d8.png → overview-BwypNWnk.png} +0 -0
- /helm/benchmark/static_build/assets/{process-flow-bd2eba96.png → process-flow-DWDJC733.png} +0 -0
- /helm/benchmark/static_build/assets/{vhelm-aspects-1437d673.png → vhelm-aspects-NiDQofvP.png} +0 -0
- /helm/benchmark/static_build/assets/{vhelm-framework-a1ca3f3f.png → vhelm-framework-NxJE4fdA.png} +0 -0
- /helm/benchmark/static_build/assets/{vhelm-model-8afb7616.png → vhelm-model-ypCL5Yvq.png} +0 -0
helm/benchmark/static_build/index.html
CHANGED
@@ -7,14 +7,13 @@
     <title>Holistic Evaluation of Language Models (HELM)</title>
     <meta name="description" content="The Holistic Evaluation of Language Models (HELM) serves as a living benchmark for transparency in language models. Providing broad coverage and recognizing incompleteness, multi-metric measurements, and standardization. All data and analysis are freely accessible on the website for exploration and study." />
     <script type="text/javascript" src="./config.js"></script>
-    <script type="module" crossorigin src="./assets/index-94295e78.js"></script>
-    <link rel="modulepreload" crossorigin href="./assets/react-f82877fd.js">
-    <link rel="modulepreload" crossorigin href="./assets/recharts-4037aff0.js">
-    <link rel="modulepreload" crossorigin href="./assets/tremor-38a10867.js">
-    <link rel="stylesheet" href="./assets/index-b9779128.css">
+    <script type="module" crossorigin src="./assets/index-qOFpOyHb.js"></script>
+    <link rel="modulepreload" crossorigin href="./assets/react-BteFIppM.js">
+    <link rel="modulepreload" crossorigin href="./assets/recharts-DxuQtTOs.js">
+    <link rel="modulepreload" crossorigin href="./assets/tremor-DR4fE7ko.js">
+    <link rel="stylesheet" crossorigin href="./assets/index-oIeiQW2g.css">
   </head>
   <body class="block">
     <div id="root"></div>
-
   </body>
 </html>
helm/benchmark/window_services/image_generation/clip_window_service.py
CHANGED
@@ -1,9 +1,7 @@
-from abc import ABC
-
 from helm.benchmark.window_services.local_window_service import LocalWindowService
 
 
-class CLIPWindowService(LocalWindowService, ABC):
+class CLIPWindowService(LocalWindowService):
     def truncate_from_right(self, text: str, expected_completion_token_length: int = 0) -> str:
         result: str = self.decode(self.encode(text, truncation=True, max_length=self.max_request_length).tokens)
 
helm/clients/ai21_client.py
CHANGED
@@ -2,6 +2,7 @@ from typing import Dict, List, Optional, TypedDict
 import requests
 
 from helm.common.cache import CacheConfig
+from helm.common.hierarchical_logger import hexception
 from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.common.request import (
     wrap_request_time,
@@ -76,6 +77,7 @@ class AI21Client(CachingClient):
             cache_key = CachingClient.make_cache_key({"engine": request.model_engine, **raw_request}, request)
             response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
         except AI21RequestError as e:
+            hexception(e)
             return RequestResult(success=False, cached=False, error=str(e), completions=[], embedding=[])
 
         def fix_text(x: str, first: bool) -> str:
helm/clients/aleph_alpha_client.py
CHANGED
@@ -1,6 +1,7 @@
 from typing import List
 
 from helm.common.cache import CacheConfig
+from helm.common.hierarchical_logger import hexception
 from helm.common.media_object import TEXT_TYPE
 from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput, Token
@@ -76,6 +77,7 @@ class AlephAlphaClient(CachingClient):
             cache_key = CachingClient.make_cache_key({"model": model, "prompt": prompt_key, **parameters}, request)
             response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
         except Exception as e:
+            hexception(e)
             error: str = f"AlephAlphaClient error: {e}"
             return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
 
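Note: the ai21 and aleph_alpha hunks above, and several client hunks below, make the same two-line change: import hexception and call it on the caught error before returning a failed RequestResult, so the traceback is recorded in the hierarchical log instead of being silently swallowed. A minimal standalone sketch of the pattern; safe_request is a hypothetical helper, while hexception and RequestResult are the actual HELM symbols used in these diffs:

from typing import Callable

from helm.common.hierarchical_logger import hexception
from helm.common.request import RequestResult


def safe_request(do_it: Callable[[], dict]) -> RequestResult:
    # Hypothetical helper illustrating the pattern these hunks add to each client.
    try:
        do_it()
    except Exception as e:
        hexception(e)  # log the traceback before converting the error into a result
        return RequestResult(success=False, cached=False, error=str(e), completions=[], embedding=[])
    return RequestResult(success=True, cached=False, completions=[], embedding=[])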
helm/clients/anthropic_client.py
CHANGED
@@ -8,7 +8,7 @@ import time
 import urllib.parse
 
 from helm.common.cache import CacheConfig
-from helm.common.hierarchical_logger import htrack_block, hlog, hwarn
+from helm.common.hierarchical_logger import hexception, htrack_block, hlog, hwarn
 from helm.common.media_object import IMAGE_TYPE, TEXT_TYPE
 from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.common.request import (
@@ -184,6 +184,7 @@ class AnthropicClient(CachingClient):
                     embedding=[],
                     error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
                 )
+            hexception(error)
             return RequestResult(success=False, cached=False, error=str(error), completions=[], embedding=[])
 
         # Post process the completion.
@@ -385,6 +386,10 @@ class AnthropicMessagesClient(CachingClient):
             # Avoid error:
             # `top_k` must be unset when thinking is enabled. Please consult our documentation at https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking  # noqa: E501
             del raw_request["top_k"]
+        if raw_request["model"].startswith("claude-sonnet-4-5"):
+            # Avoid error:
+            # `temperature` and `top_p` cannot both be specified for this model. Please use only one.
+            del raw_request["top_p"]
 
         completions: List[GeneratedOutput] = []
 
@@ -696,6 +701,7 @@ class AnthropicLegacyClient(CachingClient):
             )
             response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
         except AnthropicRequestError as error:
+            hexception(error)
             return RequestResult(success=False, cached=False, error=str(error), completions=[], embedding=[])
 
         sequence_logprob: float = 0
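The new claude-sonnet-4-5 guard in AnthropicMessagesClient mirrors the existing extended-thinking guard: drop a sampling parameter the API would reject before sending the request. A standalone sketch of that sanitization step; the function name and the thinking_enabled flag are illustrative, while the two rules are quoted from the hunk above:

from typing import Any, Dict


def sanitize_sampling_params(raw_request: Dict[str, Any], thinking_enabled: bool) -> Dict[str, Any]:
    # Illustrative restatement of the parameter guards shown in the hunk above.
    raw_request = dict(raw_request)  # avoid mutating the caller's request
    if thinking_enabled:
        # `top_k` must be unset when thinking is enabled.
        raw_request.pop("top_k", None)
    if raw_request.get("model", "").startswith("claude-sonnet-4-5"):
        # `temperature` and `top_p` cannot both be specified for this model.
        raw_request.pop("top_p", None)
    return raw_request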
helm/clients/audio_language/diva_llama_client.py
CHANGED
@@ -6,6 +6,7 @@ from transformers import AutoModel, PreTrainedModel
 
 from helm.clients.client import CachingClient
 from helm.common.cache import CacheConfig
+from helm.common.hierarchical_logger import hexception
 from helm.common.media_object import TEXT_TYPE
 from helm.common.request import (
     GeneratedOutput,
@@ -105,6 +106,7 @@ class DivaLlamaClient(CachingClient):
             cache_key = CachingClient.make_cache_key(raw_request, request)
             response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
         except Exception as e:  # Do something if error is encountered.
+            hexception(e)
             error: str = f"HuggingFace error: {e}"
             return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
 
helm/clients/audio_language/llama_omni/arguments.py
ADDED
@@ -0,0 +1,61 @@
+import transformers
+
+from dataclasses import dataclass, field
+from typing import Optional
+
+
+@dataclass
+class ModelArguments:
+    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+    version: Optional[str] = field(default="v0")
+    freeze_backbone: bool = field(default=False)
+    tune_speech_projector: bool = field(default=False)
+    tune_speech_encoder: bool = field(default=False)
+    tune_speech_generator_only: bool = field(default=False)
+    speech_encoder_type: Optional[str] = field(default=None)
+    speech_encoder: Optional[str] = field(default=None)
+    pretrain_speech_projector: Optional[str] = field(default=None)
+    speech_projector_type: Optional[str] = field(default="linear")
+    speech_generator_type: Optional[str] = field(default="ctc")
+    ctc_decoder_config: str = "(2,4096,32,11008)"
+    ctc_upsample_factor: int = 1
+    ctc_loss_weight: float = 1.0
+    unit_vocab_size: int = 1000
+    speech_encoder_ds_rate: int = 5
+    speech_encoder_hidden_size: int = 1280
+
+
+@dataclass
+class DataArguments:
+    data_path: str = field(default="", metadata={"help": "Path to the training data."})
+    is_multimodal: bool = False
+    input_type: str = field(default="mel")
+    speech_normalize: bool = False
+    mel_size: int = 128
+    has_tgt_units: bool = False
+
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+    cache_dir: Optional[str] = field(default=None)
+    optim: str = field(default="adamw_torch")
+    freeze_speech_projector: bool = field(default=False)
+    model_max_length: int = field(
+        default=512,
+        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
+    )
+    double_quant: bool = field(
+        default=True, metadata={"help": "Compress the quantization statistics through double quantization."}
+    )
+    quant_type: str = field(
+        default="nf4", metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
+    )
+    bits: int = field(default=16, metadata={"help": "How many bits to use."})
+    lora_enable: bool = False
+    lora_r: int = 64
+    lora_alpha: int = 16
+    lora_dropout: float = 0.05
+    lora_weight_path: str = ""
+    lora_bias: str = "none"
+    speech_projector_lr: Optional[float] = None
+    group_by_modality_length: bool = field(default=False)
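These dataclasses follow the standard transformers argument-parsing pattern. A minimal sketch of how they would typically be consumed; the flag values shown are illustrative, not defaults used by HELM:

import transformers

from helm.clients.audio_language.llama_omni.arguments import DataArguments, ModelArguments, TrainingArguments

# HfArgumentParser maps each dataclass field to a --flag of the same name.
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses(
    ["--model_name_or_path", "facebook/opt-125m", "--output_dir", "/tmp/llama_omni"]
)
print(model_args.speech_projector_type, data_args.input_type, training_args.bits)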
helm/clients/audio_language/llama_omni/conversation.py
ADDED
@@ -0,0 +1,213 @@
+# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+from enum import auto, Enum
+from typing import List, Any, Union, Optional
+
+
+class SeparatorStyle(Enum):
+    """Different separator style."""
+
+    TWO = auto()
+    PLAIN = auto()
+    LLAMA_2 = auto()
+    LLAMA_3 = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+    """A class that keeps all conversation history."""
+
+    system: str
+    roles: List[str]
+    messages: List[List[str]]
+    offset: int
+    sep_style: SeparatorStyle = SeparatorStyle.PLAIN
+    sep: str = "###"
+    sep2: str = ""
+    version: str = "Unknown"
+
+    tokenizer_id: str = ""
+    tokenizer: Any = None
+    # Stop criteria (the default one is EOS token)
+    stop_str: Optional[Union[str, List[str]]] = None
+    # Stops generation if meeting any token in this list
+    stop_token_ids: Optional[List[int]] = None
+
+    skip_next: bool = False
+
+    def get_prompt(self):
+        messages = self.messages
+
+        if self.sep_style == SeparatorStyle.TWO:
+            seps = [self.sep, self.sep2]
+            ret = self.system + seps[0]
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message = message[0]
+                    ret += role + ": " + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+        elif self.sep_style == SeparatorStyle.LLAMA_3:
+            wrap_sys = lambda msg: (
+                f"<|start_header_id|>system<|end_header_id|>\n\n{msg}<|eot_id|>" if len(msg) > 0 else msg
+            )
+            ret = "<|begin_of_text|>" + wrap_sys(self.system)
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message = message[0]
+                    ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
+                    ret += message.strip() + self.sep2
+                else:
+                    ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
+            return ret
+        elif self.sep_style == SeparatorStyle.LLAMA_2:
+            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
+            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
+            ret = ""
+
+            for i, (role, message) in enumerate(messages):
+                if i == 0:
+                    assert message, "first message should not be none"
+                    assert role == self.roles[0], "first message should come from user"
+                if message:
+                    if type(message) is tuple:
+                        message = message[0]
+                    if i == 0:
+                        message = wrap_sys(self.system) + message
+                    if i % 2 == 0:
+                        message = wrap_inst(message)
+                        ret += self.sep + message
+                    else:
+                        ret += " " + message + " " + self.sep2
+                else:
+                    ret += ""
+            ret = ret.lstrip(self.sep)
+        elif self.sep_style == SeparatorStyle.PLAIN:
+            seps = [self.sep, self.sep2]
+            ret = self.system
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message = message[0]
+                    ret += message + seps[i % 2]
+                else:
+                    ret += ""
+        else:
+            raise ValueError(f"Invalid style: {self.sep_style}")
+
+        return ret
+
+    def append_message(self, role, message):
+        self.messages.append([role, message])
+
+    def to_gradio_chatbot(self):
+        ret = []
+        for i, (role, msg) in enumerate(self.messages[self.offset :]):
+            if i % 2 == 0:
+                if type(msg) is tuple:
+                    msg = msg[0]
+                    ret.append([msg, None])
+                else:
+                    ret.append([msg, None])
+            else:
+                ret[-1][-1] = msg
+        return ret
+
+    def copy(self):
+        return Conversation(
+            system=self.system,
+            roles=self.roles,
+            messages=[[x, y] for x, y in self.messages],
+            offset=self.offset,
+            sep_style=self.sep_style,
+            sep=self.sep,
+            sep2=self.sep2,
+            version=self.version,
+        )
+
+    def dict(self):
+        return {
+            "system": self.system,
+            "roles": self.roles,
+            "messages": self.messages,
+            "offset": self.offset,
+            "sep": self.sep,
+            "sep2": self.sep2,
+        }
+
+
+conv_vicuna_v1 = Conversation(
+    system="A chat between a curious user and an artificial intelligence assistant. "
+    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+    roles=["USER", "ASSISTANT"],
+    version="v1",
+    messages=[],
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="</s>",
+)
+
+conv_llama_2 = Conversation(
+    system="You are a helpful language and speech assistant. "
+    "You are able to understand the speech content that the user provides, "
+    "and assist the user with a variety of tasks using natural language.",
+    roles=["USER", "ASSISTANT"],
+    version="llama_v2",
+    messages=[],
+    offset=0,
+    sep_style=SeparatorStyle.LLAMA_2,
+    sep="<s>",
+    sep2="</s>",
+)
+
+conv_llama_3 = Conversation(
+    system="You are a helpful language and speech assistant. "
+    "You are able to understand the speech content that the user provides, "
+    "and assist the user with a variety of tasks using natural language.",
+    roles=["user", "assistant"],
+    version="llama_v3",
+    messages=[],
+    offset=0,
+    sep_style=SeparatorStyle.LLAMA_3,
+    sep="",
+    sep2="<|eot_id|>",
+)
+
+conv_plain = Conversation(
+    system="",
+    roles=["", ""],
+    messages=[],
+    offset=0,
+    sep_style=SeparatorStyle.PLAIN,
+    sep="</s>",
+)
+
+
+default_conversation = conv_llama_3
+conv_templates = {
+    "v1": conv_vicuna_v1,
+    "plain": conv_plain,
+    "llama_2": conv_llama_2,
+    "llama_3": conv_llama_3,
+}
+
+
+if __name__ == "__main__":
+    print(default_conversation.get_prompt())
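A short usage sketch for the templates above: copy() a template so the shared module-level object keeps an empty message list, append the user turn plus an empty assistant slot, then render with get_prompt(); the question text is illustrative:

from helm.clients.audio_language.llama_omni.conversation import conv_templates

conv = conv_templates["llama_3"].copy()  # fresh Conversation with messages=[]
conv.append_message(conv.roles[0], "What is said in this audio clip?")
conv.append_message(conv.roles[1], None)  # None leaves the assistant header open for generation
print(conv.get_prompt())  # <|begin_of_text|>...<|start_header_id|>assistant<|end_header_id|>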
helm/clients/audio_language/llama_omni/model/__init__.py
ADDED
File without changes
helm/clients/audio_language/llama_omni/model/builder.py
ADDED
@@ -0,0 +1,88 @@
+import os
+
+from transformers import AutoTokenizer, AutoConfig, BitsAndBytesConfig
+import torch
+from helm.clients.audio_language.llama_omni.model.language_model.omni_speech_llama import OmniSpeechLlamaForCausalLM
+from helm.clients.audio_language.llama_omni.model.language_model.omni_speech2s_llama import OmniSpeech2SLlamaForCausalLM
+from helm.clients.audio_language.llama_omni.model.speech_encoder.builder import build_speech_encoder
+
+
+def load_pretrained_model(
+    model_path,
+    model_base,
+    is_lora=False,
+    s2s=False,
+    load_8bit=False,
+    load_4bit=False,
+    device="cuda",
+    use_flash_attn=False,
+    **kwargs,
+):
+    if load_8bit:
+        kwargs["load_in_8bit"] = True
+    elif load_4bit:
+        kwargs["load_in_4bit"] = True
+        kwargs["quantization_config"] = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
+        )
+    else:
+        kwargs["torch_dtype"] = torch.float16
+
+    if use_flash_attn:
+        kwargs["attn_implementation"] = "flash_attention_2"
+
+    model_cls = OmniSpeech2SLlamaForCausalLM if s2s else OmniSpeechLlamaForCausalLM
+
+    # Load OmniSpeech model
+    if is_lora:
+        assert model_base is not None, "model_base is required for LoRA models."
+        from language_model.omni_speech_llama import OmniSpeechConfig
+
+        lora_cfg_pretrained = OmniSpeechConfig.from_pretrained(model_path)
+        tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+        print("Loading OmniSpeech from base model...")
+        model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=lora_cfg_pretrained, **kwargs)
+        print("Loading additional OmniSpeech weights...")
+        if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
+            non_lora_trainables = torch.load(os.path.join(model_path, "non_lora_trainables.bin"), map_location="cpu")
+            non_lora_trainables = {
+                (k[11:] if k.startswith("base_model.") else k): v for k, v in non_lora_trainables.items()
+            }
+            if any(k.startswith("model.model.") for k in non_lora_trainables):
+                non_lora_trainables = {(k[6:] if k.startswith("model.") else k): v for k, v in non_lora_trainables.items()}
+            model.load_state_dict(non_lora_trainables, strict=False)
+
+        from peft import PeftModel
+
+        print("Loading LoRA weights...")
+        model = PeftModel.from_pretrained(model, model_path)
+        print("Merging LoRA weights...")
+        model = model.merge_and_unload()
+        print("Model is loaded...")
+    elif model_base is not None:
+        print("Loading OmniSpeech from base model...")
+        tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+        cfg_pretrained = AutoConfig.from_pretrained(model_path)
+        model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=cfg_pretrained, **kwargs)
+
+        speech_projector_weights = torch.load(os.path.join(model_path, "speech_projector.bin"), map_location="cpu")
+        speech_projector_weights = {k: v.to(torch.float16) for k, v in speech_projector_weights.items()}
+        model.load_state_dict(speech_projector_weights, strict=False)
+        model = model.to(device=device)
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+        model = model_cls.from_pretrained(model_path, low_cpu_mem_usage=False, **kwargs)
+        model = model.to(device=device)
+
+    model.get_model().speech_encoder = build_speech_encoder(model.config)
+    model.get_model().speech_encoder.to(device=device, dtype=torch.float16)
+
+    if hasattr(model.config, "max_sequence_length"):
+        context_len = model.config.max_sequence_length
+    else:
+        context_len = 2048
+
+    return tokenizer, model, context_len
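A usage sketch for load_pretrained_model; the checkpoint path is illustrative, and model_base=None selects the final (non-LoRA, non-base) branch above:

from helm.clients.audio_language.llama_omni.model.builder import load_pretrained_model

tokenizer, model, context_len = load_pretrained_model(
    model_path="/path/to/llama-omni-checkpoint",  # illustrative local path
    model_base=None,  # full checkpoint: no LoRA merge, no base-model split
    s2s=True,  # choose OmniSpeech2SLlamaForCausalLM for speech-to-speech
)
print(type(model).__name__, context_len)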