crfm-helm 0.5.6__py3-none-any.whl → 0.5.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of crfm-helm might be problematic.

Files changed (394)
  1. {crfm_helm-0.5.6.dist-info → crfm_helm-0.5.10.dist-info}/METADATA +72 -130
  2. {crfm_helm-0.5.6.dist-info → crfm_helm-0.5.10.dist-info}/RECORD +372 -305
  3. helm/benchmark/adaptation/adapter_spec.py +10 -0
  4. helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +11 -3
  5. helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +11 -8
  6. helm/benchmark/annotation/aci_bench_annotator.py +11 -22
  7. helm/benchmark/annotation/air_bench_annotator.py +1 -1
  8. helm/benchmark/annotation/alrage_annotator.py +90 -0
  9. helm/benchmark/annotation/chw_care_plan_annotator.py +10 -21
  10. helm/benchmark/annotation/dischargeme_annotator.py +11 -22
  11. helm/benchmark/annotation/live_qa_annotator.py +1 -1
  12. helm/benchmark/annotation/med_dialog_annotator.py +11 -22
  13. helm/benchmark/annotation/medalign_annotator.py +11 -22
  14. helm/benchmark/annotation/medi_qa_annotator.py +11 -22
  15. helm/benchmark/annotation/medication_qa_annotator.py +11 -22
  16. helm/benchmark/annotation/mental_health_annotator.py +11 -22
  17. helm/benchmark/annotation/mimic_bhc_annotator.py +11 -22
  18. helm/benchmark/annotation/mimic_rrs_annotator.py +11 -22
  19. helm/benchmark/annotation/model_as_judge.py +23 -18
  20. helm/benchmark/annotation/mtsamples_procedures_annotator.py +11 -22
  21. helm/benchmark/annotation/mtsamples_replicate_annotator.py +11 -22
  22. helm/benchmark/annotation/starr_patient_instructions_annotator.py +11 -22
  23. helm/benchmark/metrics/air_bench_metrics.py +3157 -1
  24. helm/benchmark/metrics/alrage_metric.py +35 -0
  25. helm/benchmark/metrics/basic_metrics.py +267 -2
  26. helm/benchmark/metrics/bbq_metrics.py +12 -0
  27. helm/benchmark/metrics/classification_metrics.py +19 -1
  28. helm/benchmark/metrics/codeinsights_code_efficiency_metrics.py +186 -0
  29. helm/benchmark/metrics/codeinsights_code_evaluation_metrics.py +477 -0
  30. helm/benchmark/metrics/codeinsights_correct_code_metrics.py +366 -0
  31. helm/benchmark/metrics/codeinsights_edge_case_metrics.py +92 -0
  32. helm/benchmark/metrics/codeinsights_metric_specs.py +51 -0
  33. helm/benchmark/metrics/comet_metric.py +1 -1
  34. helm/benchmark/metrics/conv_fin_qa_calc_metrics.py +12 -1
  35. helm/benchmark/metrics/copyright_metrics.py +1 -1
  36. helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +1 -1
  37. helm/benchmark/metrics/dry_run_metrics.py +30 -1
  38. helm/benchmark/metrics/efficiency_metrics.py +74 -0
  39. helm/benchmark/metrics/ehr_sql_metrics.py +57 -1
  40. helm/benchmark/metrics/evaluate_reference_metrics.py +312 -1
  41. helm/benchmark/metrics/gpqa_chain_of_thought_metric.py +13 -1
  42. helm/benchmark/metrics/helpdesk_call_summarization_metrics.py +13 -1
  43. helm/benchmark/metrics/ifeval_metrics.py +13 -1
  44. helm/benchmark/metrics/image_generation/clip_score_metrics.py +13 -2
  45. helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +1 -1
  46. helm/benchmark/metrics/instruction_following_critique_metrics.py +41 -1
  47. helm/benchmark/metrics/kpi_edgar_metrics.py +21 -0
  48. helm/benchmark/metrics/language_modeling_metrics.py +13 -1
  49. helm/benchmark/metrics/live_qa_metrics.py +13 -1
  50. helm/benchmark/metrics/llm_jury_metrics.py +13 -1
  51. helm/benchmark/metrics/lmkt_metric_specs.py +12 -0
  52. helm/benchmark/metrics/lmkt_metrics.py +47 -0
  53. helm/benchmark/metrics/medcalc_bench_metrics.py +14 -1
  54. helm/benchmark/metrics/medec_metrics.py +25 -2
  55. helm/benchmark/metrics/melt_toxicity_metric.py +1 -1
  56. helm/benchmark/metrics/metric.py +25 -0
  57. helm/benchmark/metrics/mimiciv_billing_code_metrics.py +32 -1
  58. helm/benchmark/metrics/omni_math_metrics.py +13 -1
  59. helm/benchmark/metrics/safety_metrics.py +13 -1
  60. helm/benchmark/metrics/seahelm_metrics.py +14 -1
  61. helm/benchmark/metrics/summac/model_summac.py +3 -3
  62. helm/benchmark/metrics/summarization_metrics.py +129 -1
  63. helm/benchmark/metrics/toxicity_metrics.py +31 -1
  64. helm/benchmark/metrics/ultra_suite_asr_classification_metrics.py +52 -0
  65. helm/benchmark/metrics/wildbench_metrics.py +21 -1
  66. helm/benchmark/model_deployment_registry.py +11 -19
  67. helm/benchmark/presentation/create_plots.py +11 -2
  68. helm/benchmark/presentation/run_display.py +13 -3
  69. helm/benchmark/presentation/run_entry.py +2 -2
  70. helm/benchmark/presentation/schema.py +10 -22
  71. helm/benchmark/presentation/summarize.py +189 -14
  72. helm/benchmark/presentation/taxonomy_info.py +20 -0
  73. helm/benchmark/presentation/test_create_plots.py +4 -1
  74. helm/benchmark/run.py +15 -4
  75. helm/benchmark/run_expander.py +4 -0
  76. helm/benchmark/run_specs/arabic_run_specs.py +197 -0
  77. helm/benchmark/run_specs/bluex_run_specs.py +40 -0
  78. helm/benchmark/run_specs/classic_run_specs.py +2 -55
  79. helm/benchmark/run_specs/codeinsights_run_specs.py +192 -0
  80. helm/benchmark/run_specs/healthqa_br_run_specs.py +40 -0
  81. helm/benchmark/run_specs/heim_run_specs.py +3 -1
  82. helm/benchmark/run_specs/lmkt_run_specs.py +144 -0
  83. helm/benchmark/run_specs/long_context_run_specs.py +48 -1
  84. helm/benchmark/run_specs/medhelm/__init__.py +0 -0
  85. helm/benchmark/run_specs/medhelm/benchmark_config.py +219 -0
  86. helm/benchmark/run_specs/medhelm_run_specs.py +363 -53
  87. helm/benchmark/run_specs/multilingual_run_specs.py +50 -0
  88. helm/benchmark/run_specs/speech_disorder_audio_run_specs.py +11 -13
  89. helm/benchmark/runner.py +7 -0
  90. helm/benchmark/scenarios/aci_bench_scenario.py +23 -0
  91. helm/benchmark/scenarios/air_bench_scenario.py +21 -0
  92. helm/benchmark/scenarios/alghafa_scenario.py +126 -0
  93. helm/benchmark/scenarios/alrage_scenario.py +54 -0
  94. helm/benchmark/scenarios/anthropic_hh_rlhf_scenario.py +23 -1
  95. helm/benchmark/scenarios/anthropic_red_team_scenario.py +12 -1
  96. helm/benchmark/scenarios/arabic_exams_scenario.py +114 -0
  97. helm/benchmark/scenarios/arabic_mmlu_scenario.py +82 -0
  98. helm/benchmark/scenarios/aratrust_scenario.py +95 -0
  99. helm/benchmark/scenarios/audio_language/casual_conversations2_scenario.py +1 -1
  100. helm/benchmark/scenarios/audio_language/mustard_scenario.py +1 -1
  101. helm/benchmark/scenarios/audio_language/ultra_suite_asr_classification_scenario.py +74 -0
  102. helm/benchmark/scenarios/audio_language/ultra_suite_asr_transcription_scenario.py +70 -0
  103. helm/benchmark/scenarios/audio_language/ultra_suite_classification_scenario.py +22 -53
  104. helm/benchmark/scenarios/audio_language/ultra_suite_disorder_breakdown_scenario.py +21 -21
  105. helm/benchmark/scenarios/audio_language/ultra_suite_disorder_symptoms_scenario.py +21 -52
  106. helm/benchmark/scenarios/babi_qa_scenario.py +15 -0
  107. helm/benchmark/scenarios/banking77_scenario.py +21 -0
  108. helm/benchmark/scenarios/bbq_scenario.py +15 -0
  109. helm/benchmark/scenarios/best_chatgpt_prompts.yaml +473 -0
  110. helm/benchmark/scenarios/bird_sql_scenario.py +18 -0
  111. helm/benchmark/scenarios/bluex_scenario.py +70 -0
  112. helm/benchmark/scenarios/bold_scenario.py +15 -0
  113. helm/benchmark/scenarios/boolq_scenario.py +20 -0
  114. helm/benchmark/scenarios/chw_care_plan_scenario.py +23 -0
  115. helm/benchmark/scenarios/civil_comments_scenario.py +13 -0
  116. helm/benchmark/scenarios/clear_scenario.py +23 -0
  117. helm/benchmark/scenarios/cleva_scenario.py +480 -1
  118. helm/benchmark/scenarios/code_scenario.py +28 -0
  119. helm/benchmark/scenarios/codeinsights_code_efficiency_scenario.py +197 -0
  120. helm/benchmark/scenarios/codeinsights_correct_code_scenario.py +78 -0
  121. helm/benchmark/scenarios/codeinsights_edge_case_scenario.py +192 -0
  122. helm/benchmark/scenarios/codeinsights_student_coding_scenario.py +162 -0
  123. helm/benchmark/scenarios/codeinsights_student_mistake_scenario.py +188 -0
  124. helm/benchmark/scenarios/commonsense_scenario.py +32 -0
  125. helm/benchmark/scenarios/compositional_instructions.yaml +70 -0
  126. helm/benchmark/scenarios/conv_fin_qa_calc_scenario.py +21 -0
  127. helm/benchmark/scenarios/copyright_scenario.py +35 -1
  128. helm/benchmark/scenarios/cti_to_mitre_scenario.py +21 -0
  129. helm/benchmark/scenarios/czech_bank_qa_scenario.py +18 -0
  130. helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +22 -1
  131. helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +23 -1
  132. helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +22 -1
  133. helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +21 -1
  134. helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +13 -0
  135. helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +13 -1
  136. helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +13 -1
  137. helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +13 -1
  138. helm/benchmark/scenarios/dischargeme_scenario.py +24 -0
  139. helm/benchmark/scenarios/disinformation_scenario.py +22 -0
  140. helm/benchmark/scenarios/dyck_language_scenario.py +15 -0
  141. helm/benchmark/scenarios/ehrshot_scenario.py +22 -0
  142. helm/benchmark/scenarios/enem_challenge_scenario.py +19 -0
  143. helm/benchmark/scenarios/entity_data_imputation_scenario.py +14 -0
  144. helm/benchmark/scenarios/entity_matching_scenario.py +14 -0
  145. helm/benchmark/scenarios/exams_multilingual_scenario.py +115 -0
  146. helm/benchmark/scenarios/fin_qa_scenario.py +20 -0
  147. helm/benchmark/scenarios/financebench_scenario.py +21 -0
  148. helm/benchmark/scenarios/financial_phrasebank_scenario.py +21 -0
  149. helm/benchmark/scenarios/gold_commodity_news_scenario.py +21 -0
  150. helm/benchmark/scenarios/gpqa_scenario.py +18 -0
  151. helm/benchmark/scenarios/grammar_scenario.py +20 -1
  152. helm/benchmark/scenarios/gsm_scenario.py +21 -0
  153. helm/benchmark/scenarios/harm_bench_gcg_transfer_scenario.py +12 -1
  154. helm/benchmark/scenarios/harm_bench_scenario.py +12 -1
  155. helm/benchmark/scenarios/headqa_scenario.py +22 -0
  156. helm/benchmark/scenarios/healthqa_br_scenario.py +80 -0
  157. helm/benchmark/scenarios/helpdesk_call_summarization_scenario.py +13 -0
  158. helm/benchmark/scenarios/ice_scenario.py +21 -1
  159. helm/benchmark/scenarios/ifeval_scenario.py +18 -0
  160. helm/benchmark/scenarios/imdb_scenario.py +15 -0
  161. helm/benchmark/scenarios/infinite_bench_en_mc_scenario.py +111 -0
  162. helm/benchmark/scenarios/infinite_bench_en_qa_scenario.py +1 -1
  163. helm/benchmark/scenarios/infinite_bench_en_sum_scenario.py +19 -0
  164. helm/benchmark/scenarios/koala_scenario.py +21 -1
  165. helm/benchmark/scenarios/kpi_edgar_scenario.py +21 -0
  166. helm/benchmark/scenarios/legal_contract_summarization_scenario.py +20 -0
  167. helm/benchmark/scenarios/legal_summarization_scenario.py +50 -0
  168. helm/benchmark/scenarios/legal_support_scenario.py +13 -0
  169. helm/benchmark/scenarios/legalbench_scenario.py +19 -0
  170. helm/benchmark/scenarios/lex_glue_scenario.py +11 -0
  171. helm/benchmark/scenarios/lextreme_scenario.py +11 -0
  172. helm/benchmark/scenarios/lmkt_scenarios.py +288 -0
  173. helm/benchmark/scenarios/lsat_qa_scenario.py +14 -0
  174. helm/benchmark/scenarios/madinah_qa_scenario.py +73 -0
  175. helm/benchmark/scenarios/math_scenario.py +54 -20
  176. helm/benchmark/scenarios/mbzuai_human_translated_arabic_mmlu.py +68 -0
  177. helm/benchmark/scenarios/med_dialog_scenario.py +32 -1
  178. helm/benchmark/scenarios/med_mcqa_scenario.py +14 -0
  179. helm/benchmark/scenarios/med_qa_scenario.py +20 -0
  180. helm/benchmark/scenarios/medalign_scenario.py +23 -0
  181. helm/benchmark/scenarios/medalign_scenario_helper.py +19 -125
  182. helm/benchmark/scenarios/medbullets_scenario.py +22 -0
  183. helm/benchmark/scenarios/medcalc_bench_scenario.py +22 -0
  184. helm/benchmark/scenarios/medec_scenario.py +23 -0
  185. helm/benchmark/scenarios/medhallu_scenario.py +23 -0
  186. helm/benchmark/scenarios/medhelm/__init__.py +0 -0
  187. helm/benchmark/scenarios/medhelm/judges.yaml +14 -0
  188. helm/benchmark/scenarios/medhelm_configurable_scenario.py +101 -0
  189. helm/benchmark/scenarios/medi_qa_scenario.py +24 -1
  190. helm/benchmark/scenarios/medication_qa_scenario.py +31 -1
  191. helm/benchmark/scenarios/melt_scenarios.py +2 -2
  192. helm/benchmark/scenarios/mental_health_scenario.py +23 -0
  193. helm/benchmark/scenarios/mimic_bhc_scenario.py +25 -1
  194. helm/benchmark/scenarios/mimic_rrs_scenario.py +23 -0
  195. helm/benchmark/scenarios/mimiciv_billing_code_scenario.py +22 -0
  196. helm/benchmark/scenarios/mmlu_pro_scenario.py +18 -0
  197. helm/benchmark/scenarios/mmlu_scenario.py +21 -0
  198. helm/benchmark/scenarios/mmmlu_scenario.py +85 -0
  199. helm/benchmark/scenarios/msmarco_scenario.py +30 -0
  200. helm/benchmark/scenarios/mtsamples_procedures_scenario.py +22 -0
  201. helm/benchmark/scenarios/mtsamples_replicate_scenario.py +22 -0
  202. helm/benchmark/scenarios/n2c2_ct_matching_scenario.py +20 -0
  203. helm/benchmark/scenarios/narrativeqa_scenario.py +19 -0
  204. helm/benchmark/scenarios/natural_qa_scenario.py +32 -0
  205. helm/benchmark/scenarios/omni_math_scenario.py +18 -0
  206. helm/benchmark/scenarios/open_assistant_scenario.py +22 -0
  207. helm/benchmark/scenarios/openai_mrcr_scenario.py +15 -0
  208. helm/benchmark/scenarios/pubmed_qa_scenario.py +22 -0
  209. helm/benchmark/scenarios/quac_scenario.py +14 -0
  210. helm/benchmark/scenarios/race_based_med_scenario.py +23 -0
  211. helm/benchmark/scenarios/raft_scenario.py +15 -0
  212. helm/benchmark/scenarios/real_toxicity_prompts_scenario.py +14 -1
  213. helm/benchmark/scenarios/ruler_qa_scenarios.py +40 -0
  214. helm/benchmark/scenarios/scenario.py +31 -0
  215. helm/benchmark/scenarios/seahelm_scenario.py +350 -2
  216. helm/benchmark/scenarios/self_instruct_scenario.py +29 -1
  217. helm/benchmark/scenarios/shc_bmt_scenario.py +22 -0
  218. helm/benchmark/scenarios/shc_cdi_scenario.py +20 -0
  219. helm/benchmark/scenarios/shc_conf_scenario.py +23 -0
  220. helm/benchmark/scenarios/shc_ent_scenario.py +21 -0
  221. helm/benchmark/scenarios/shc_gip_scenario.py +20 -0
  222. helm/benchmark/scenarios/shc_privacy_scenario.py +22 -0
  223. helm/benchmark/scenarios/shc_proxy_scenario.py +23 -1
  224. helm/benchmark/scenarios/shc_ptbm_scenario.py +23 -0
  225. helm/benchmark/scenarios/shc_sequoia_scenario.py +21 -0
  226. helm/benchmark/scenarios/simple_safety_tests_scenario.py +12 -1
  227. helm/benchmark/scenarios/situation_prompts.yaml +49 -0
  228. helm/benchmark/scenarios/spider_scenario.py +18 -0
  229. helm/benchmark/scenarios/starr_patient_instructions_scenario.py +22 -0
  230. helm/benchmark/scenarios/summarization_scenario.py +37 -0
  231. helm/benchmark/scenarios/synthetic_efficiency_scenario.py +22 -1
  232. helm/benchmark/scenarios/synthetic_reasoning_natural_scenario.py +13 -0
  233. helm/benchmark/scenarios/test_alghafa_scenario.py +29 -0
  234. helm/benchmark/scenarios/test_alrage_scenario.py +23 -0
  235. helm/benchmark/scenarios/test_arabic_exams_scenario.py +21 -0
  236. helm/benchmark/scenarios/test_aratrust_scenario.py +21 -0
  237. helm/benchmark/scenarios/test_bluex_scenario.py +59 -0
  238. helm/benchmark/scenarios/test_exams_multilingual_scenario.py +29 -0
  239. helm/benchmark/scenarios/test_healtha_br_scenario.py +57 -0
  240. helm/benchmark/scenarios/thai_exam_scenario.py +95 -0
  241. helm/benchmark/scenarios/the_pile_scenario.py +13 -1
  242. helm/benchmark/scenarios/truthful_qa_scenario.py +14 -0
  243. helm/benchmark/scenarios/twitter_aae_scenario.py +20 -1
  244. helm/benchmark/scenarios/vicuna_scenario.py +21 -1
  245. helm/benchmark/scenarios/wikifact_scenario.py +20 -0
  246. helm/benchmark/scenarios/wildbench_scenario.py +18 -0
  247. helm/benchmark/scenarios/wmt_14_scenario.py +19 -0
  248. helm/benchmark/slurm_jobs.py +1 -2
  249. helm/benchmark/slurm_runner.py +8 -1
  250. helm/benchmark/static/schema_arabic.yaml +271 -0
  251. helm/benchmark/static/schema_classic.yaml +0 -17
  252. helm/benchmark/static/schema_long_context.yaml +17 -18
  253. helm/benchmark/static/schema_medhelm.yaml +36 -0
  254. helm/benchmark/static/schema_slp.yaml +219 -0
  255. helm/benchmark/static_build/assets/audio-table-Dn5NMMeJ.png +0 -0
  256. helm/benchmark/static_build/assets/index-oIeiQW2g.css +1 -0
  257. helm/benchmark/static_build/assets/index-qOFpOyHb.js +10 -0
  258. helm/benchmark/static_build/assets/react-BteFIppM.js +85 -0
  259. helm/benchmark/static_build/assets/recharts-DxuQtTOs.js +97 -0
  260. helm/benchmark/static_build/assets/tremor-DR4fE7ko.js +10 -0
  261. helm/benchmark/static_build/index.html +5 -6
  262. helm/benchmark/window_services/image_generation/clip_window_service.py +1 -3
  263. helm/clients/ai21_client.py +2 -0
  264. helm/clients/aleph_alpha_client.py +2 -0
  265. helm/clients/anthropic_client.py +7 -1
  266. helm/clients/audio_language/diva_llama_client.py +2 -0
  267. helm/clients/audio_language/llama_omni/arguments.py +61 -0
  268. helm/clients/audio_language/llama_omni/constants.py +9 -0
  269. helm/clients/audio_language/llama_omni/conversation.py +213 -0
  270. helm/clients/audio_language/llama_omni/model/__init__.py +0 -0
  271. helm/clients/audio_language/llama_omni/model/builder.py +88 -0
  272. helm/clients/audio_language/llama_omni/model/language_model/omni_speech2s_llama.py +190 -0
  273. helm/clients/audio_language/llama_omni/model/language_model/omni_speech_llama.py +118 -0
  274. helm/clients/audio_language/llama_omni/model/omni_speech_arch.py +249 -0
  275. helm/clients/audio_language/llama_omni/model/speech_encoder/builder.py +9 -0
  276. helm/clients/audio_language/llama_omni/model/speech_encoder/speech_encoder.py +27 -0
  277. helm/clients/audio_language/llama_omni/model/speech_generator/builder.py +9 -0
  278. helm/clients/audio_language/llama_omni/model/speech_generator/generation.py +622 -0
  279. helm/clients/audio_language/llama_omni/model/speech_generator/speech_generator.py +104 -0
  280. helm/clients/audio_language/llama_omni/model/speech_projector/builder.py +9 -0
  281. helm/clients/audio_language/llama_omni/model/speech_projector/speech_projector.py +27 -0
  282. helm/clients/audio_language/llama_omni/preprocess.py +295 -0
  283. helm/clients/audio_language/llama_omni/utils.py +202 -0
  284. helm/clients/audio_language/llama_omni_client.py +2 -1
  285. helm/clients/audio_language/qwen2_5_omni_client.py +21 -8
  286. helm/clients/audio_language/qwen2_audiolm_client.py +2 -1
  287. helm/clients/audio_language/qwen_audiolm_client.py +2 -1
  288. helm/clients/audio_language/qwen_omni/configuration_qwen2_5_omni.py +519 -0
  289. helm/clients/audio_language/qwen_omni/modeling_qwen2_5_omni.py +4308 -0
  290. helm/clients/audio_language/qwen_omni/processing_qwen2_5_omni.py +270 -0
  291. helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/__init__.py +0 -0
  292. helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/__init__.py +8 -0
  293. helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/audio_process.py +56 -0
  294. helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/vision_process.py +380 -0
  295. helm/clients/bedrock_client.py +63 -6
  296. helm/clients/cohere_client.py +3 -0
  297. helm/clients/dspy_client.py +135 -0
  298. helm/clients/google_client.py +2 -0
  299. helm/clients/http_model_client.py +2 -0
  300. helm/clients/huggingface_client.py +4 -3
  301. helm/clients/ibm_client.py +3 -1
  302. helm/clients/image_generation/adobe_vision_client.py +2 -0
  303. helm/clients/image_generation/aleph_alpha_image_generation_client.py +2 -0
  304. helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +1 -1
  305. helm/clients/image_generation/cogview2_client.py +2 -1
  306. helm/clients/image_generation/dalle2_client.py +2 -0
  307. helm/clients/image_generation/dalle_mini_client.py +2 -1
  308. helm/clients/image_generation/deep_floyd_client.py +2 -0
  309. helm/clients/image_generation/huggingface_diffusers_client.py +2 -1
  310. helm/clients/image_generation/lexica_client.py +2 -0
  311. helm/clients/image_generation/mindalle/models/stage1/layers.py +2 -2
  312. helm/clients/image_generation/mindalle_client.py +2 -1
  313. helm/clients/image_generation/together_image_generation_client.py +2 -0
  314. helm/clients/megatron_client.py +2 -0
  315. helm/clients/mistral_client.py +2 -0
  316. helm/clients/moderation_api_client.py +2 -0
  317. helm/clients/openai_client.py +38 -21
  318. helm/clients/openai_responses_client.py +34 -8
  319. helm/clients/openrouter_client.py +31 -0
  320. helm/clients/palmyra_client.py +2 -1
  321. helm/clients/reka_client.py +2 -1
  322. helm/clients/stanfordhealthcare_azure_openai_client.py +2 -2
  323. helm/clients/stanfordhealthcare_http_model_client.py +2 -0
  324. helm/clients/test_huggingface_client.py +3 -3
  325. helm/clients/test_openrouter_client.py +69 -0
  326. helm/clients/together_client.py +52 -13
  327. helm/clients/vertexai_client.py +23 -11
  328. helm/clients/vision_language/huggingface_vision2seq_client.py +2 -1
  329. helm/clients/vision_language/huggingface_vlm_client.py +2 -0
  330. helm/clients/vision_language/idefics_client.py +2 -1
  331. helm/clients/vision_language/open_flamingo_client.py +2 -1
  332. helm/clients/vision_language/paligemma_client.py +2 -1
  333. helm/clients/vision_language/palmyra_vision_client.py +2 -0
  334. helm/clients/vision_language/qwen2_vlm_client.py +2 -1
  335. helm/clients/vision_language/qwen_vlm_client.py +2 -1
  336. helm/clients/vllm_client.py +43 -7
  337. helm/clients/vllm_granite_thinking_client.py +56 -0
  338. helm/clients/writer_client.py +5 -2
  339. helm/common/critique_request.py +0 -1
  340. helm/common/hierarchical_logger.py +103 -34
  341. helm/common/object_spec.py +23 -8
  342. helm/common/optional_dependencies.py +1 -1
  343. helm/common/test_general.py +4 -0
  344. helm/common/test_logging.py +94 -0
  345. helm/config/model_deployments.yaml +1001 -187
  346. helm/config/model_metadata.yaml +602 -18
  347. helm/config/tokenizer_configs.yaml +202 -5
  348. helm/proxy/cli.py +1 -1
  349. helm/proxy/example_queries.py +8 -8
  350. helm/proxy/retry.py +5 -0
  351. helm/proxy/server.py +2 -1
  352. helm/proxy/static/index.css +4 -0
  353. helm/proxy/static/index.js +7 -1
  354. helm/tokenizers/auto_tokenizer.py +2 -2
  355. helm/tokenizers/grok_tokenizer.py +2 -0
  356. helm/benchmark/metrics/aci_bench_metrics.py +0 -14
  357. helm/benchmark/metrics/chw_care_plan_metrics.py +0 -14
  358. helm/benchmark/metrics/dischargeme_metrics.py +0 -14
  359. helm/benchmark/metrics/med_dialog_metrics.py +0 -14
  360. helm/benchmark/metrics/medalign_metrics.py +0 -14
  361. helm/benchmark/metrics/medi_qa_metrics.py +0 -14
  362. helm/benchmark/metrics/medication_qa_metrics.py +0 -14
  363. helm/benchmark/metrics/mental_health_metrics.py +0 -14
  364. helm/benchmark/metrics/mimic_bhc_metrics.py +0 -14
  365. helm/benchmark/metrics/mimic_rrs_metrics.py +0 -14
  366. helm/benchmark/metrics/mtsamples_procedures_metrics.py +0 -14
  367. helm/benchmark/metrics/mtsamples_replicate_metrics.py +0 -14
  368. helm/benchmark/metrics/numeracy_metrics.py +0 -72
  369. helm/benchmark/metrics/starr_patient_instructions_metrics.py +0 -14
  370. helm/benchmark/metrics/test_numeracy_metrics.py +0 -95
  371. helm/benchmark/scenarios/audio_language/ultra_suite_asr_classification.py +0 -103
  372. helm/benchmark/scenarios/numeracy_scenario.py +0 -794
  373. helm/benchmark/static_build/assets/index-94295e78.js +0 -10
  374. helm/benchmark/static_build/assets/index-b9779128.css +0 -1
  375. helm/benchmark/static_build/assets/react-f82877fd.js +0 -85
  376. helm/benchmark/static_build/assets/recharts-4037aff0.js +0 -97
  377. helm/benchmark/static_build/assets/tremor-38a10867.js +0 -10
  378. {crfm_helm-0.5.6.dist-info → crfm_helm-0.5.10.dist-info}/WHEEL +0 -0
  379. {crfm_helm-0.5.6.dist-info → crfm_helm-0.5.10.dist-info}/entry_points.txt +0 -0
  380. {crfm_helm-0.5.6.dist-info → crfm_helm-0.5.10.dist-info}/licenses/LICENSE +0 -0
  381. {crfm_helm-0.5.6.dist-info → crfm_helm-0.5.10.dist-info}/top_level.txt +0 -0
  382. /helm/benchmark/static_build/assets/{air-overview-d2e6c49f.png → air-overview-DpBbyagA.png} +0 -0
  383. /helm/benchmark/static_build/assets/{crfm-logo-74391ab8.png → crfm-logo-Du4T1uWZ.png} +0 -0
  384. /helm/benchmark/static_build/assets/{heim-logo-3e5e3aa4.png → heim-logo-BJtQlEbV.png} +0 -0
  385. /helm/benchmark/static_build/assets/{helm-logo-simple-2ed5400b.png → helm-logo-simple-DzOhNN41.png} +0 -0
  386. /helm/benchmark/static_build/assets/{helm-safety-2907a7b6.png → helm-safety-COfndXuS.png} +0 -0
  387. /helm/benchmark/static_build/assets/{helmhero-28e90f4d.png → helmhero-D9TvmJsp.png} +0 -0
  388. /helm/benchmark/static_build/assets/{medhelm-overview-eac29843.png → medhelm-overview-CND0EIsy.png} +0 -0
  389. /helm/benchmark/static_build/assets/{medhelm-v1-overview-3ddfcd65.png → medhelm-v1-overview-Cu2tphBB.png} +0 -0
  390. /helm/benchmark/static_build/assets/{overview-74aea3d8.png → overview-BwypNWnk.png} +0 -0
  391. /helm/benchmark/static_build/assets/{process-flow-bd2eba96.png → process-flow-DWDJC733.png} +0 -0
  392. /helm/benchmark/static_build/assets/{vhelm-aspects-1437d673.png → vhelm-aspects-NiDQofvP.png} +0 -0
  393. /helm/benchmark/static_build/assets/{vhelm-framework-a1ca3f3f.png → vhelm-framework-NxJE4fdA.png} +0 -0
  394. /helm/benchmark/static_build/assets/{vhelm-model-8afb7616.png → vhelm-model-ypCL5Yvq.png} +0 -0
helm/benchmark/static_build/index.html
@@ -7,14 +7,13 @@
  <title>Holistic Evaluation of Language Models (HELM)</title>
  <meta name="description" content="The Holistic Evaluation of Language Models (HELM) serves as a living benchmark for transparency in language models. Providing broad coverage and recognizing incompleteness, multi-metric measurements, and standardization. All data and analysis are freely accessible on the website for exploration and study." />
  <script type="text/javascript" src="./config.js"></script>
- <script type="module" crossorigin src="./assets/index-94295e78.js"></script>
- <link rel="modulepreload" crossorigin href="./assets/react-f82877fd.js">
- <link rel="modulepreload" crossorigin href="./assets/recharts-4037aff0.js">
- <link rel="modulepreload" crossorigin href="./assets/tremor-38a10867.js">
- <link rel="stylesheet" href="./assets/index-b9779128.css">
+ <script type="module" crossorigin src="./assets/index-qOFpOyHb.js"></script>
+ <link rel="modulepreload" crossorigin href="./assets/react-BteFIppM.js">
+ <link rel="modulepreload" crossorigin href="./assets/recharts-DxuQtTOs.js">
+ <link rel="modulepreload" crossorigin href="./assets/tremor-DR4fE7ko.js">
+ <link rel="stylesheet" crossorigin href="./assets/index-oIeiQW2g.css">
  </head>
  <body class="block">
  <div id="root"></div>
-
  </body>
  </html>
helm/benchmark/window_services/image_generation/clip_window_service.py
@@ -1,9 +1,7 @@
- from abc import ABC
-
  from helm.benchmark.window_services.local_window_service import LocalWindowService


- class CLIPWindowService(LocalWindowService, ABC):
+ class CLIPWindowService(LocalWindowService):
      def truncate_from_right(self, text: str, expected_completion_token_length: int = 0) -> str:
          result: str = self.decode(self.encode(text, truncation=True, max_length=self.max_request_length).tokens)

helm/clients/ai21_client.py
@@ -2,6 +2,7 @@ from typing import Dict, List, Optional, TypedDict
  import requests

  from helm.common.cache import CacheConfig
+ from helm.common.hierarchical_logger import hexception
  from helm.common.optional_dependencies import handle_module_not_found_error
  from helm.common.request import (
      wrap_request_time,
@@ -76,6 +77,7 @@ class AI21Client(CachingClient):
              cache_key = CachingClient.make_cache_key({"engine": request.model_engine, **raw_request}, request)
              response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
          except AI21RequestError as e:
+             hexception(e)
              return RequestResult(success=False, cached=False, error=str(e), completions=[], embedding=[])

          def fix_text(x: str, first: bool) -> str:
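The same change recurs in the Aleph Alpha, Anthropic, and DIVA client diffs below: an hexception(e) call is added just before each failed request is turned into an unsuccessful RequestResult. A minimal sketch of the pattern, assuming hexception logs the exception through the hierarchical logger (its exact behavior lives in helm/common/hierarchical_logger.py, which also changed in this release, and is not shown here); the call_with_logging wrapper and do_it argument are hypothetical illustrations, not HELM APIs:

from helm.common.hierarchical_logger import hexception
from helm.common.request import RequestResult


def call_with_logging(do_it) -> RequestResult:
    # `do_it` stands in for each client's request closure (hypothetical name).
    try:
        return do_it()
    except Exception as e:
        hexception(e)  # assumed: log the exception via the hierarchical logger before giving up
        return RequestResult(success=False, cached=False, error=str(e), completions=[], embedding=[])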
helm/clients/aleph_alpha_client.py
@@ -1,6 +1,7 @@
  from typing import List

  from helm.common.cache import CacheConfig
+ from helm.common.hierarchical_logger import hexception
  from helm.common.media_object import TEXT_TYPE
  from helm.common.optional_dependencies import handle_module_not_found_error
  from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput, Token
@@ -76,6 +77,7 @@ class AlephAlphaClient(CachingClient):
              cache_key = CachingClient.make_cache_key({"model": model, "prompt": prompt_key, **parameters}, request)
              response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
          except Exception as e:
+             hexception(e)
              error: str = f"AlephAlphaClient error: {e}"
              return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])

helm/clients/anthropic_client.py
@@ -8,7 +8,7 @@ import time
  import urllib.parse

  from helm.common.cache import CacheConfig
- from helm.common.hierarchical_logger import htrack_block, hlog, hwarn
+ from helm.common.hierarchical_logger import hexception, htrack_block, hlog, hwarn
  from helm.common.media_object import IMAGE_TYPE, TEXT_TYPE
  from helm.common.optional_dependencies import handle_module_not_found_error
  from helm.common.request import (
@@ -184,6 +184,7 @@
                      embedding=[],
                      error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
                  )
+             hexception(error)
              return RequestResult(success=False, cached=False, error=str(error), completions=[], embedding=[])

          # Post process the completion.
@@ -385,6 +386,10 @@
              # Avoid error:
              # `top_k` must be unset when thinking is enabled. Please consult our documentation at https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking  # noqa: E501
              del raw_request["top_k"]
+         if raw_request["model"].startswith("claude-sonnet-4-5"):
+             # Avoid error:
+             # `temperature` and `top_p` cannot both be specified for this model. Please use only one.
+             del raw_request["top_p"]

          completions: List[GeneratedOutput] = []
@@ -696,6 +701,7 @@
              )
              response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
          except AnthropicRequestError as error:
+             hexception(error)
              return RequestResult(success=False, cached=False, error=str(error), completions=[], embedding=[])

          sequence_logprob: float = 0
helm/clients/audio_language/diva_llama_client.py
@@ -6,6 +6,7 @@ from transformers import AutoModel, PreTrainedModel

  from helm.clients.client import CachingClient
  from helm.common.cache import CacheConfig
+ from helm.common.hierarchical_logger import hexception
  from helm.common.media_object import TEXT_TYPE
  from helm.common.request import (
      GeneratedOutput,
@@ -105,6 +106,7 @@ class DivaLlamaClient(CachingClient):
              cache_key = CachingClient.make_cache_key(raw_request, request)
              response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
          except Exception as e:  # Do something if error is encountered.
+             hexception(e)
              error: str = f"HuggingFace error: {e}"
              return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])

helm/clients/audio_language/llama_omni/arguments.py (new file)
@@ -0,0 +1,61 @@
+ import transformers
+
+ from dataclasses import dataclass, field
+ from typing import Optional
+
+
+ @dataclass
+ class ModelArguments:
+     model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+     version: Optional[str] = field(default="v0")
+     freeze_backbone: bool = field(default=False)
+     tune_speech_projector: bool = field(default=False)
+     tune_speech_encoder: bool = field(default=False)
+     tune_speech_generator_only: bool = field(default=False)
+     speech_encoder_type: Optional[str] = field(default=None)
+     speech_encoder: Optional[str] = field(default=None)
+     pretrain_speech_projector: Optional[str] = field(default=None)
+     speech_projector_type: Optional[str] = field(default="linear")
+     speech_generator_type: Optional[str] = field(default="ctc")
+     ctc_decoder_config: str = "(2,4096,32,11008)"
+     ctc_upsample_factor: int = 1
+     ctc_loss_weight: float = 1.0
+     unit_vocab_size: int = 1000
+     speech_encoder_ds_rate: int = 5
+     speech_encoder_hidden_size: int = 1280
+
+
+ @dataclass
+ class DataArguments:
+     data_path: str = field(default="", metadata={"help": "Path to the training data."})
+     is_multimodal: bool = False
+     input_type: str = field(default="mel")
+     speech_normalize: bool = False
+     mel_size: int = 128
+     has_tgt_units: bool = False
+
+
+ @dataclass
+ class TrainingArguments(transformers.TrainingArguments):
+     cache_dir: Optional[str] = field(default=None)
+     optim: str = field(default="adamw_torch")
+     freeze_speech_projector: bool = field(default=False)
+     model_max_length: int = field(
+         default=512,
+         metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
+     )
+     double_quant: bool = field(
+         default=True, metadata={"help": "Compress the quantization statistics through double quantization."}
+     )
+     quant_type: str = field(
+         default="nf4", metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
+     )
+     bits: int = field(default=16, metadata={"help": "How many bits to use."})
+     lora_enable: bool = False
+     lora_r: int = 64
+     lora_alpha: int = 16
+     lora_dropout: float = 0.05
+     lora_weight_path: str = ""
+     lora_bias: str = "none"
+     speech_projector_lr: Optional[float] = None
+     group_by_modality_length: bool = field(default=False)
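These dataclasses follow the standard transformers argument-dataclass pattern, so they would typically be parsed with HfArgumentParser. The snippet below is an illustrative sketch, not part of the diff; the --output_dir value is a placeholder:

import transformers

from helm.clients.audio_language.llama_omni.arguments import DataArguments, ModelArguments, TrainingArguments

# Parse flags such as --model_name_or_path, --bits, and --lora_enable into the three dataclasses.
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses(
    args=["--output_dir", "/tmp/llama_omni_runs"]  # placeholder output directory
)
print(model_args.speech_encoder_type, data_args.input_type, training_args.bits)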
helm/clients/audio_language/llama_omni/constants.py (new file)
@@ -0,0 +1,9 @@
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
+ WORKER_HEART_BEAT_INTERVAL = 15
+
+ LOGDIR = "."
+
+ # Model Constants
+ IGNORE_INDEX = -100
+ SPEECH_TOKEN_INDEX = -200
+ DEFAULT_SPEECH_TOKEN = "<speech>"
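For context, a usage sketch based on the usual LLaVA-style preprocessing convention (an assumption; this hunk does not show how helm/clients/audio_language/llama_omni/preprocess.py actually uses these constants): SPEECH_TOKEN_INDEX marks where speech features are spliced into the token sequence, and IGNORE_INDEX masks positions out of the training loss.

import torch

from helm.clients.audio_language.llama_omni.constants import IGNORE_INDEX, SPEECH_TOKEN_INDEX

# Hypothetical token sequence: a speech placeholder followed by ordinary text tokens.
input_ids = torch.tensor([SPEECH_TOKEN_INDEX, 128000, 9906, 1917])
labels = input_ids.clone()
labels[input_ids == SPEECH_TOKEN_INDEX] = IGNORE_INDEX  # exclude the placeholder from the loss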
helm/clients/audio_language/llama_omni/conversation.py (new file)
@@ -0,0 +1,213 @@
+ # Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
+ # Copyright 2023 Haotian Liu
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import dataclasses
+ from enum import auto, Enum
+ from typing import List, Any, Union, Optional
+
+
+ class SeparatorStyle(Enum):
+     """Different separator style."""
+
+     TWO = auto()
+     PLAIN = auto()
+     LLAMA_2 = auto()
+     LLAMA_3 = auto()
+
+
+ @dataclasses.dataclass
+ class Conversation:
+     """A class that keeps all conversation history."""
+
+     system: str
+     roles: List[str]
+     messages: List[List[str]]
+     offset: int
+     sep_style: SeparatorStyle = SeparatorStyle.PLAIN
+     sep: str = "###"
+     sep2: str = ""
+     version: str = "Unknown"
+
+     tokenizer_id: str = ""
+     tokenizer: Any = None
+     # Stop criteria (the default one is EOS token)
+     stop_str: Optional[Union[str, List[str]]] = None
+     # Stops generation if meeting any token in this list
+     stop_token_ids: Optional[List[int]] = None
+
+     skip_next: bool = False
+
+     def get_prompt(self):
+         messages = self.messages
+
+         if self.sep_style == SeparatorStyle.TWO:
+             seps = [self.sep, self.sep2]
+             ret = self.system + seps[0]
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message = message[0]
+                     ret += role + ": " + message + seps[i % 2]
+                 else:
+                     ret += role + ":"
+         elif self.sep_style == SeparatorStyle.LLAMA_3:
+             wrap_sys = lambda msg: (
+                 f"<|start_header_id|>system<|end_header_id|>\n\n{msg}<|eot_id|>" if len(msg) > 0 else msg
+             )
+             ret = "<|begin_of_text|>" + wrap_sys(self.system)
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message = message[0]
+                     ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
+                     ret += message.strip() + self.sep2
+                 else:
+                     ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
+             return ret
+         elif self.sep_style == SeparatorStyle.LLAMA_2:
+             wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
+             wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
+             ret = ""
+
+             for i, (role, message) in enumerate(messages):
+                 if i == 0:
+                     assert message, "first message should not be none"
+                     assert role == self.roles[0], "first message should come from user"
+                 if message:
+                     if type(message) is tuple:
+                         message = message[0]
+                     if i == 0:
+                         message = wrap_sys(self.system) + message
+                     if i % 2 == 0:
+                         message = wrap_inst(message)
+                         ret += self.sep + message
+                     else:
+                         ret += " " + message + " " + self.sep2
+                 else:
+                     ret += ""
+             ret = ret.lstrip(self.sep)
+         elif self.sep_style == SeparatorStyle.PLAIN:
+             seps = [self.sep, self.sep2]
+             ret = self.system
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message = message[0]
+                     ret += message + seps[i % 2]
+                 else:
+                     ret += ""
+         else:
+             raise ValueError(f"Invalid style: {self.sep_style}")
+
+         return ret
+
+     def append_message(self, role, message):
+         self.messages.append([role, message])
+
+     def to_gradio_chatbot(self):
+         ret = []
+         for i, (role, msg) in enumerate(self.messages[self.offset :]):
+             if i % 2 == 0:
+                 if type(msg) is tuple:
+                     msg = msg[0]
+                     ret.append([msg, None])
+                 else:
+                     ret.append([msg, None])
+             else:
+                 ret[-1][-1] = msg
+         return ret
+
+     def copy(self):
+         return Conversation(
+             system=self.system,
+             roles=self.roles,
+             messages=[[x, y] for x, y in self.messages],
+             offset=self.offset,
+             sep_style=self.sep_style,
+             sep=self.sep,
+             sep2=self.sep2,
+             version=self.version,
+         )
+
+     def dict(self):
+         return {
+             "system": self.system,
+             "roles": self.roles,
+             "messages": self.messages,
+             "offset": self.offset,
+             "sep": self.sep,
+             "sep2": self.sep2,
+         }
+
+
+ conv_vicuna_v1 = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+     "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+     roles=["USER", "ASSISTANT"],
+     version="v1",
+     messages=[],
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+ )
+
+ conv_llama_2 = Conversation(
+     system="You are a helpful language and speech assistant. "
+     "You are able to understand the speech content that the user provides, "
+     "and assist the user with a variety of tasks using natural language.",
+     roles=["USER", "ASSISTANT"],
+     version="llama_v2",
+     messages=[],
+     offset=0,
+     sep_style=SeparatorStyle.LLAMA_2,
+     sep="<s>",
+     sep2="</s>",
+ )
+
+ conv_llama_3 = Conversation(
+     system="You are a helpful language and speech assistant. "
+     "You are able to understand the speech content that the user provides, "
+     "and assist the user with a variety of tasks using natural language.",
+     roles=["user", "assistant"],
+     version="llama_v3",
+     messages=[],
+     offset=0,
+     sep_style=SeparatorStyle.LLAMA_3,
+     sep="",
+     sep2="<|eot_id|>",
+ )
+
+ conv_plain = Conversation(
+     system="",
+     roles=["", ""],
+     messages=[],
+     offset=0,
+     sep_style=SeparatorStyle.PLAIN,
+     sep="</s>",
+ )
+
+
+ default_conversation = conv_llama_3
+ conv_templates = {
+     "v1": conv_vicuna_v1,
+     "plain": conv_plain,
+     "llama_2": conv_llama_2,
+     "llama_3": conv_llama_3,
+ }
+
+
+ if __name__ == "__main__":
+     print(default_conversation.get_prompt())
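A short usage sketch for these templates. This is illustrative only: the exact prompt wiring used by llama_omni_client.py is not shown in this hunk, so the placement of the speech placeholder token is an assumption.

from helm.clients.audio_language.llama_omni.constants import DEFAULT_SPEECH_TOKEN
from helm.clients.audio_language.llama_omni.conversation import conv_templates

conv = conv_templates["llama_3"].copy()  # start from the Llama 3 chat template
conv.append_message(conv.roles[0], f"{DEFAULT_SPEECH_TOKEN}\nPlease transcribe the audio clip.")
conv.append_message(conv.roles[1], None)  # leave the assistant turn open for generation
prompt = conv.get_prompt()  # ends with "<|start_header_id|>assistant<|end_header_id|>\n\n"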
helm/clients/audio_language/llama_omni/model/builder.py (new file)
@@ -0,0 +1,88 @@
+ import os
+
+ from transformers import AutoTokenizer, AutoConfig, BitsAndBytesConfig
+ import torch
+ from helm.clients.audio_language.llama_omni.model.language_model.omni_speech_llama import OmniSpeechLlamaForCausalLM
+ from helm.clients.audio_language.llama_omni.model.language_model.omni_speech2s_llama import OmniSpeech2SLlamaForCausalLM
+ from helm.clients.audio_language.llama_omni.model.speech_encoder.builder import build_speech_encoder
+
+
+ def load_pretrained_model(
+     model_path,
+     model_base,
+     is_lora=False,
+     s2s=False,
+     load_8bit=False,
+     load_4bit=False,
+     device="cuda",
+     use_flash_attn=False,
+     **kwargs,
+ ):
+     if load_8bit:
+         kwargs["load_in_8bit"] = True
+     elif load_4bit:
+         kwargs["load_in_4bit"] = True
+         kwargs["quantization_config"] = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_compute_dtype=torch.float16,
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_quant_type="nf4",
+         )
+     else:
+         kwargs["torch_dtype"] = torch.float16
+
+     if use_flash_attn:
+         kwargs["attn_implementation"] = "flash_attention_2"
+
+     model_cls = OmniSpeech2SLlamaForCausalLM if s2s else OmniSpeechLlamaForCausalLM
+
+     # Load OmniSpeech model
+     if is_lora:
+         assert model_base is not None, "model_base is required for LoRA models."
+         from language_model.omni_speech_llama import OmniSpeechConfig
+
+         lora_cfg_pretrained = OmniSpeechConfig.from_pretrained(model_path)
+         tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+         print("Loading OmniSpeech from base model...")
+         model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=lora_cfg_pretrained, **kwargs)
+         print("Loading additional OmniSpeech weights...")
+         if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
+             non_lora_trainables = torch.load(os.path.join(model_path, "non_lora_trainables.bin"), map_location="cpu")
+             non_lora_trainables = {
+                 (k[11:] if k.startswith("base_model.") else k): v for k, v in non_lora_trainables.items()
+             }
+             if any(k.startswith("model.model.") for k in non_lora_trainables):
+                 non_lora_trainables = {(k[6:] if k.startswith("model.") else k): v for k, v in non_lora_trainables.items()}
+             model.load_state_dict(non_lora_trainables, strict=False)
+
+         from peft import PeftModel
+
+         print("Loading LoRA weights...")
+         model = PeftModel.from_pretrained(model, model_path)
+         print("Merging LoRA weights...")
+         model = model.merge_and_unload()
+         print("Model is loaded...")
+     elif model_base is not None:
+         print("Loading OmniSpeech from base model...")
+         tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+         cfg_pretrained = AutoConfig.from_pretrained(model_path)
+         model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=cfg_pretrained, **kwargs)
+
+         speech_projector_weights = torch.load(os.path.join(model_path, "speech_projector.bin"), map_location="cpu")
+         speech_projector_weights = {k: v.to(torch.float16) for k, v in speech_projector_weights.items()}
+         model.load_state_dict(speech_projector_weights, strict=False)
+         model = model.to(device=device)
+     else:
+         tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+         model = model_cls.from_pretrained(model_path, low_cpu_mem_usage=False, **kwargs)
+         model = model.to(device=device)
+
+     model.get_model().speech_encoder = build_speech_encoder(model.config)
+     model.get_model().speech_encoder.to(device=device, dtype=torch.float16)
+
+     if hasattr(model.config, "max_sequence_length"):
+         context_len = model.config.max_sequence_length
+     else:
+         context_len = 2048
+
+     return tokenizer, model, context_len
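An illustrative call follows. The checkpoint name is a placeholder assumption, not taken from this diff; the helper returns the tokenizer, the speech-aware model with its speech encoder attached, and the usable context length.

from helm.clients.audio_language.llama_omni.model.builder import load_pretrained_model

tokenizer, model, context_len = load_pretrained_model(
    model_path="ICTNLP/Llama-3.1-8B-Omni",  # placeholder checkpoint path; an assumption
    model_base=None,
    s2s=True,  # load the speech-to-speech variant (OmniSpeech2SLlamaForCausalLM)
    device="cuda",
)
print(context_len)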