crfm-helm 0.5.7__py3-none-any.whl → 0.5.8__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of crfm-helm might be problematic.

Files changed (243)
  1. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/METADATA +5 -77
  2. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/RECORD +228 -197
  3. helm/benchmark/adaptation/adapter_spec.py +5 -0
  4. helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +11 -3
  5. helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +11 -8
  6. helm/benchmark/annotation/aci_bench_annotator.py +11 -22
  7. helm/benchmark/annotation/alrage_annotator.py +90 -0
  8. helm/benchmark/annotation/chw_care_plan_annotator.py +10 -21
  9. helm/benchmark/annotation/dischargeme_annotator.py +11 -22
  10. helm/benchmark/annotation/med_dialog_annotator.py +11 -22
  11. helm/benchmark/annotation/medalign_annotator.py +11 -22
  12. helm/benchmark/annotation/medi_qa_annotator.py +11 -22
  13. helm/benchmark/annotation/medication_qa_annotator.py +11 -22
  14. helm/benchmark/annotation/mental_health_annotator.py +11 -22
  15. helm/benchmark/annotation/mimic_bhc_annotator.py +11 -22
  16. helm/benchmark/annotation/mimic_rrs_annotator.py +11 -22
  17. helm/benchmark/annotation/model_as_judge.py +23 -18
  18. helm/benchmark/annotation/mtsamples_procedures_annotator.py +11 -22
  19. helm/benchmark/annotation/mtsamples_replicate_annotator.py +11 -22
  20. helm/benchmark/annotation/starr_patient_instructions_annotator.py +11 -22
  21. helm/benchmark/metrics/air_bench_metrics.py +3157 -1
  22. helm/benchmark/metrics/alrage_metric.py +35 -0
  23. helm/benchmark/metrics/basic_metrics.py +267 -2
  24. helm/benchmark/metrics/classification_metrics.py +19 -1
  25. helm/benchmark/metrics/conv_fin_qa_calc_metrics.py +12 -1
  26. helm/benchmark/metrics/dry_run_metrics.py +30 -1
  27. helm/benchmark/metrics/efficiency_metrics.py +74 -0
  28. helm/benchmark/metrics/ehr_sql_metrics.py +57 -1
  29. helm/benchmark/metrics/evaluate_reference_metrics.py +299 -0
  30. helm/benchmark/metrics/gpqa_chain_of_thought_metric.py +13 -1
  31. helm/benchmark/metrics/helpdesk_call_summarization_metrics.py +13 -1
  32. helm/benchmark/metrics/ifeval_metrics.py +13 -1
  33. helm/benchmark/metrics/instruction_following_critique_metrics.py +41 -1
  34. helm/benchmark/metrics/kpi_edgar_metrics.py +21 -0
  35. helm/benchmark/metrics/language_modeling_metrics.py +13 -1
  36. helm/benchmark/metrics/live_qa_metrics.py +13 -1
  37. helm/benchmark/metrics/llm_jury_metrics.py +13 -1
  38. helm/benchmark/metrics/medcalc_bench_metrics.py +14 -1
  39. helm/benchmark/metrics/medec_metrics.py +25 -2
  40. helm/benchmark/metrics/metric.py +25 -0
  41. helm/benchmark/metrics/mimiciv_billing_code_metrics.py +32 -1
  42. helm/benchmark/metrics/omni_math_metrics.py +13 -1
  43. helm/benchmark/metrics/seahelm_metrics.py +14 -1
  44. helm/benchmark/metrics/summac/model_summac.py +2 -2
  45. helm/benchmark/metrics/summarization_metrics.py +129 -1
  46. helm/benchmark/metrics/toxicity_metrics.py +31 -1
  47. helm/benchmark/metrics/wildbench_metrics.py +21 -1
  48. helm/benchmark/presentation/schema.py +5 -22
  49. helm/benchmark/presentation/summarize.py +180 -11
  50. helm/benchmark/presentation/taxonomy_info.py +20 -0
  51. helm/benchmark/run_expander.py +4 -0
  52. helm/benchmark/run_specs/arabic_run_specs.py +134 -16
  53. helm/benchmark/run_specs/bluex_run_specs.py +1 -1
  54. helm/benchmark/run_specs/classic_run_specs.py +2 -2
  55. helm/benchmark/run_specs/long_context_run_specs.py +2 -2
  56. helm/benchmark/run_specs/medhelm/__init__.py +0 -0
  57. helm/benchmark/run_specs/medhelm/benchmark_config.py +219 -0
  58. helm/benchmark/run_specs/medhelm_run_specs.py +360 -50
  59. helm/benchmark/scenarios/aci_bench_scenario.py +23 -0
  60. helm/benchmark/scenarios/air_bench_scenario.py +21 -0
  61. helm/benchmark/scenarios/alrage_scenario.py +54 -0
  62. helm/benchmark/scenarios/anthropic_hh_rlhf_scenario.py +23 -1
  63. helm/benchmark/scenarios/arabic_exams_scenario.py +114 -0
  64. helm/benchmark/scenarios/arabic_mmlu_scenario.py +8 -4
  65. helm/benchmark/scenarios/aratrust_scenario.py +19 -0
  66. helm/benchmark/scenarios/babi_qa_scenario.py +15 -0
  67. helm/benchmark/scenarios/bbq_scenario.py +15 -0
  68. helm/benchmark/scenarios/best_chatgpt_prompts.yaml +473 -0
  69. helm/benchmark/scenarios/bluex_scenario.py +6 -2
  70. helm/benchmark/scenarios/bold_scenario.py +15 -0
  71. helm/benchmark/scenarios/boolq_scenario.py +20 -0
  72. helm/benchmark/scenarios/chw_care_plan_scenario.py +23 -0
  73. helm/benchmark/scenarios/civil_comments_scenario.py +13 -0
  74. helm/benchmark/scenarios/clear_scenario.py +23 -0
  75. helm/benchmark/scenarios/cleva_scenario.py +479 -0
  76. helm/benchmark/scenarios/code_scenario.py +28 -0
  77. helm/benchmark/scenarios/commonsense_scenario.py +26 -0
  78. helm/benchmark/scenarios/compositional_instructions.yaml +70 -0
  79. helm/benchmark/scenarios/conv_fin_qa_calc_scenario.py +21 -0
  80. helm/benchmark/scenarios/copyright_scenario.py +35 -1
  81. helm/benchmark/scenarios/cti_to_mitre_scenario.py +21 -0
  82. helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +22 -1
  83. helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +23 -1
  84. helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +22 -1
  85. helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +21 -1
  86. helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +13 -0
  87. helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +13 -1
  88. helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +13 -1
  89. helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +13 -1
  90. helm/benchmark/scenarios/dischargeme_scenario.py +24 -0
  91. helm/benchmark/scenarios/disinformation_scenario.py +22 -0
  92. helm/benchmark/scenarios/dyck_language_scenario.py +15 -0
  93. helm/benchmark/scenarios/ehrshot_scenario.py +22 -0
  94. helm/benchmark/scenarios/enem_challenge_scenario.py +19 -0
  95. helm/benchmark/scenarios/entity_data_imputation_scenario.py +14 -0
  96. helm/benchmark/scenarios/entity_matching_scenario.py +14 -0
  97. helm/benchmark/scenarios/financial_phrasebank_scenario.py +21 -0
  98. helm/benchmark/scenarios/gold_commodity_news_scenario.py +21 -0
  99. helm/benchmark/scenarios/gpqa_scenario.py +18 -0
  100. helm/benchmark/scenarios/grammar_scenario.py +20 -1
  101. helm/benchmark/scenarios/gsm_scenario.py +15 -0
  102. helm/benchmark/scenarios/headqa_scenario.py +22 -0
  103. helm/benchmark/scenarios/helpdesk_call_summarization_scenario.py +13 -0
  104. helm/benchmark/scenarios/ice_scenario.py +21 -1
  105. helm/benchmark/scenarios/ifeval_scenario.py +18 -0
  106. helm/benchmark/scenarios/imdb_scenario.py +15 -0
  107. helm/benchmark/scenarios/koala_scenario.py +21 -1
  108. helm/benchmark/scenarios/kpi_edgar_scenario.py +21 -0
  109. helm/benchmark/scenarios/legal_contract_summarization_scenario.py +20 -0
  110. helm/benchmark/scenarios/legal_summarization_scenario.py +50 -0
  111. helm/benchmark/scenarios/legal_support_scenario.py +13 -0
  112. helm/benchmark/scenarios/legalbench_scenario.py +20 -0
  113. helm/benchmark/scenarios/lex_glue_scenario.py +11 -0
  114. helm/benchmark/scenarios/lextreme_scenario.py +11 -0
  115. helm/benchmark/scenarios/lsat_qa_scenario.py +14 -0
  116. helm/benchmark/scenarios/madinah_qa_scenario.py +73 -0
  117. helm/benchmark/scenarios/math_scenario.py +26 -0
  118. helm/benchmark/scenarios/mbzuai_human_translated_arabic_mmlu.py +68 -0
  119. helm/benchmark/scenarios/med_dialog_scenario.py +32 -1
  120. helm/benchmark/scenarios/med_mcqa_scenario.py +14 -0
  121. helm/benchmark/scenarios/med_qa_scenario.py +14 -0
  122. helm/benchmark/scenarios/medalign_scenario.py +23 -0
  123. helm/benchmark/scenarios/medbullets_scenario.py +22 -0
  124. helm/benchmark/scenarios/medcalc_bench_scenario.py +22 -0
  125. helm/benchmark/scenarios/medec_scenario.py +23 -0
  126. helm/benchmark/scenarios/medhallu_scenario.py +23 -0
  127. helm/benchmark/scenarios/medhelm/__init__.py +0 -0
  128. helm/benchmark/scenarios/medhelm/judges.yaml +14 -0
  129. helm/benchmark/scenarios/medhelm_configurable_scenario.py +101 -0
  130. helm/benchmark/scenarios/medi_qa_scenario.py +23 -0
  131. helm/benchmark/scenarios/medication_qa_scenario.py +31 -1
  132. helm/benchmark/scenarios/mental_health_scenario.py +23 -0
  133. helm/benchmark/scenarios/mimic_bhc_scenario.py +24 -0
  134. helm/benchmark/scenarios/mimic_rrs_scenario.py +23 -0
  135. helm/benchmark/scenarios/mimiciv_billing_code_scenario.py +22 -0
  136. helm/benchmark/scenarios/mmlu_pro_scenario.py +18 -0
  137. helm/benchmark/scenarios/mmlu_scenario.py +15 -0
  138. helm/benchmark/scenarios/msmarco_scenario.py +30 -0
  139. helm/benchmark/scenarios/mtsamples_procedures_scenario.py +22 -0
  140. helm/benchmark/scenarios/mtsamples_replicate_scenario.py +22 -0
  141. helm/benchmark/scenarios/n2c2_ct_matching_scenario.py +20 -0
  142. helm/benchmark/scenarios/narrativeqa_scenario.py +20 -0
  143. helm/benchmark/scenarios/natural_qa_scenario.py +32 -0
  144. helm/benchmark/scenarios/omni_math_scenario.py +18 -0
  145. helm/benchmark/scenarios/open_assistant_scenario.py +22 -0
  146. helm/benchmark/scenarios/pubmed_qa_scenario.py +22 -0
  147. helm/benchmark/scenarios/quac_scenario.py +14 -0
  148. helm/benchmark/scenarios/race_based_med_scenario.py +23 -0
  149. helm/benchmark/scenarios/raft_scenario.py +15 -0
  150. helm/benchmark/scenarios/real_toxicity_prompts_scenario.py +14 -1
  151. helm/benchmark/scenarios/scenario.py +31 -0
  152. helm/benchmark/scenarios/seahelm_scenario.py +348 -0
  153. helm/benchmark/scenarios/self_instruct_scenario.py +29 -1
  154. helm/benchmark/scenarios/shc_bmt_scenario.py +22 -0
  155. helm/benchmark/scenarios/shc_cdi_scenario.py +20 -0
  156. helm/benchmark/scenarios/shc_conf_scenario.py +23 -0
  157. helm/benchmark/scenarios/shc_ent_scenario.py +21 -0
  158. helm/benchmark/scenarios/shc_gip_scenario.py +20 -0
  159. helm/benchmark/scenarios/shc_privacy_scenario.py +22 -0
  160. helm/benchmark/scenarios/shc_proxy_scenario.py +22 -0
  161. helm/benchmark/scenarios/shc_ptbm_scenario.py +23 -0
  162. helm/benchmark/scenarios/shc_sequoia_scenario.py +21 -0
  163. helm/benchmark/scenarios/situation_prompts.yaml +49 -0
  164. helm/benchmark/scenarios/starr_patient_instructions_scenario.py +22 -0
  165. helm/benchmark/scenarios/summarization_scenario.py +37 -0
  166. helm/benchmark/scenarios/synthetic_efficiency_scenario.py +22 -1
  167. helm/benchmark/scenarios/synthetic_reasoning_natural_scenario.py +13 -0
  168. helm/benchmark/scenarios/test_alrage_scenario.py +23 -0
  169. helm/benchmark/scenarios/test_arabic_exams_scenario.py +21 -0
  170. helm/benchmark/scenarios/test_aratrust_scenario.py +1 -1
  171. helm/benchmark/scenarios/test_bluex_scenario.py +2 -2
  172. helm/benchmark/scenarios/the_pile_scenario.py +13 -1
  173. helm/benchmark/scenarios/truthful_qa_scenario.py +14 -0
  174. helm/benchmark/scenarios/twitter_aae_scenario.py +20 -1
  175. helm/benchmark/scenarios/vicuna_scenario.py +21 -1
  176. helm/benchmark/scenarios/wikifact_scenario.py +20 -0
  177. helm/benchmark/scenarios/wildbench_scenario.py +18 -0
  178. helm/benchmark/scenarios/wmt_14_scenario.py +12 -0
  179. helm/benchmark/static/schema_arabic.yaml +55 -12
  180. helm/benchmark/static/schema_long_context.yaml +17 -17
  181. helm/benchmark/static/schema_medhelm.yaml +36 -0
  182. helm/benchmark/static/schema_slp.yaml +219 -0
  183. helm/benchmark/static_build/assets/index-671a5e06.js +10 -0
  184. helm/benchmark/static_build/assets/index-9352595e.css +1 -0
  185. helm/benchmark/static_build/index.html +2 -2
  186. helm/clients/audio_language/llama_omni/arguments.py +61 -0
  187. helm/clients/audio_language/llama_omni/constants.py +9 -0
  188. helm/clients/audio_language/llama_omni/conversation.py +213 -0
  189. helm/clients/audio_language/llama_omni/model/__init__.py +0 -0
  190. helm/clients/audio_language/llama_omni/model/builder.py +88 -0
  191. helm/clients/audio_language/llama_omni/model/language_model/omni_speech2s_llama.py +190 -0
  192. helm/clients/audio_language/llama_omni/model/language_model/omni_speech_llama.py +118 -0
  193. helm/clients/audio_language/llama_omni/model/omni_speech_arch.py +249 -0
  194. helm/clients/audio_language/llama_omni/model/speech_encoder/builder.py +9 -0
  195. helm/clients/audio_language/llama_omni/model/speech_encoder/speech_encoder.py +27 -0
  196. helm/clients/audio_language/llama_omni/model/speech_generator/builder.py +9 -0
  197. helm/clients/audio_language/llama_omni/model/speech_generator/generation.py +622 -0
  198. helm/clients/audio_language/llama_omni/model/speech_generator/speech_generator.py +104 -0
  199. helm/clients/audio_language/llama_omni/model/speech_projector/builder.py +9 -0
  200. helm/clients/audio_language/llama_omni/model/speech_projector/speech_projector.py +27 -0
  201. helm/clients/audio_language/llama_omni/preprocess.py +295 -0
  202. helm/clients/audio_language/llama_omni/utils.py +202 -0
  203. helm/clients/audio_language/qwen_omni/configuration_qwen2_5_omni.py +519 -0
  204. helm/clients/audio_language/qwen_omni/modeling_qwen2_5_omni.py +4308 -0
  205. helm/clients/audio_language/qwen_omni/processing_qwen2_5_omni.py +270 -0
  206. helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/__init__.py +0 -0
  207. helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/__init__.py +8 -0
  208. helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/audio_process.py +56 -0
  209. helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/vision_process.py +380 -0
  210. helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +1 -1
  211. helm/clients/image_generation/mindalle/models/stage1/layers.py +2 -2
  212. helm/clients/openai_client.py +31 -19
  213. helm/clients/openai_responses_client.py +27 -3
  214. helm/clients/openrouter_client.py +31 -0
  215. helm/clients/test_openrouter_client.py +69 -0
  216. helm/clients/together_client.py +48 -11
  217. helm/clients/vertexai_client.py +8 -2
  218. helm/config/model_deployments.yaml +75 -1
  219. helm/config/model_metadata.yaml +70 -2
  220. helm/config/tokenizer_configs.yaml +19 -1
  221. helm/proxy/example_queries.py +8 -8
  222. helm/proxy/server.py +2 -1
  223. helm/proxy/static/index.css +4 -0
  224. helm/proxy/static/index.js +7 -1
  225. helm/benchmark/metrics/aci_bench_metrics.py +0 -14
  226. helm/benchmark/metrics/chw_care_plan_metrics.py +0 -14
  227. helm/benchmark/metrics/dischargeme_metrics.py +0 -14
  228. helm/benchmark/metrics/med_dialog_metrics.py +0 -14
  229. helm/benchmark/metrics/medalign_metrics.py +0 -14
  230. helm/benchmark/metrics/medi_qa_metrics.py +0 -14
  231. helm/benchmark/metrics/medication_qa_metrics.py +0 -14
  232. helm/benchmark/metrics/mental_health_metrics.py +0 -14
  233. helm/benchmark/metrics/mimic_bhc_metrics.py +0 -14
  234. helm/benchmark/metrics/mimic_rrs_metrics.py +0 -14
  235. helm/benchmark/metrics/mtsamples_procedures_metrics.py +0 -14
  236. helm/benchmark/metrics/mtsamples_replicate_metrics.py +0 -14
  237. helm/benchmark/metrics/starr_patient_instructions_metrics.py +0 -14
  238. helm/benchmark/static_build/assets/index-b9779128.css +0 -1
  239. helm/benchmark/static_build/assets/index-e439d5e1.js +0 -10
  240. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/WHEEL +0 -0
  241. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/entry_points.txt +0 -0
  242. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/licenses/LICENSE +0 -0
  243. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/top_level.txt +0 -0
helm/benchmark/static_build/index.html
@@ -7,11 +7,11 @@
  <title>Holistic Evaluation of Language Models (HELM)</title>
  <meta name="description" content="The Holistic Evaluation of Language Models (HELM) serves as a living benchmark for transparency in language models. Providing broad coverage and recognizing incompleteness, multi-metric measurements, and standardization. All data and analysis are freely accessible on the website for exploration and study." />
  <script type="text/javascript" src="./config.js"></script>
- <script type="module" crossorigin src="./assets/index-e439d5e1.js"></script>
+ <script type="module" crossorigin src="./assets/index-671a5e06.js"></script>
  <link rel="modulepreload" crossorigin href="./assets/react-f82877fd.js">
  <link rel="modulepreload" crossorigin href="./assets/recharts-4037aff0.js">
  <link rel="modulepreload" crossorigin href="./assets/tremor-38a10867.js">
- <link rel="stylesheet" href="./assets/index-b9779128.css">
+ <link rel="stylesheet" href="./assets/index-9352595e.css">
  </head>
  <body class="block">
  <div id="root"></div>
helm/clients/audio_language/llama_omni/arguments.py
@@ -0,0 +1,61 @@
+ import transformers
+
+ from dataclasses import dataclass, field
+ from typing import Optional
+
+
+ @dataclass
+ class ModelArguments:
+     model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+     version: Optional[str] = field(default="v0")
+     freeze_backbone: bool = field(default=False)
+     tune_speech_projector: bool = field(default=False)
+     tune_speech_encoder: bool = field(default=False)
+     tune_speech_generator_only: bool = field(default=False)
+     speech_encoder_type: Optional[str] = field(default=None)
+     speech_encoder: Optional[str] = field(default=None)
+     pretrain_speech_projector: Optional[str] = field(default=None)
+     speech_projector_type: Optional[str] = field(default="linear")
+     speech_generator_type: Optional[str] = field(default="ctc")
+     ctc_decoder_config: str = "(2,4096,32,11008)"
+     ctc_upsample_factor: int = 1
+     ctc_loss_weight: float = 1.0
+     unit_vocab_size: int = 1000
+     speech_encoder_ds_rate: int = 5
+     speech_encoder_hidden_size: int = 1280
+
+
+ @dataclass
+ class DataArguments:
+     data_path: str = field(default="", metadata={"help": "Path to the training data."})
+     is_multimodal: bool = False
+     input_type: str = field(default="mel")
+     speech_normalize: bool = False
+     mel_size: int = 128
+     has_tgt_units: bool = False
+
+
+ @dataclass
+ class TrainingArguments(transformers.TrainingArguments):
+     cache_dir: Optional[str] = field(default=None)
+     optim: str = field(default="adamw_torch")
+     freeze_speech_projector: bool = field(default=False)
+     model_max_length: int = field(
+         default=512,
+         metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
+     )
+     double_quant: bool = field(
+         default=True, metadata={"help": "Compress the quantization statistics through double quantization."}
+     )
+     quant_type: str = field(
+         default="nf4", metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
+     )
+     bits: int = field(default=16, metadata={"help": "How many bits to use."})
+     lora_enable: bool = False
+     lora_r: int = 64
+     lora_alpha: int = 16
+     lora_dropout: float = 0.05
+     lora_weight_path: str = ""
+     lora_bias: str = "none"
+     speech_projector_lr: Optional[float] = None
+     group_by_modality_length: bool = field(default=False)
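
These are standard transformers-style argument dataclasses. A minimal sketch of how such dataclasses are typically parsed with HfArgumentParser (the flags in the comment are illustrative, not flags HELM itself documents):

from transformers import HfArgumentParser

from helm.clients.audio_language.llama_omni.arguments import (
    ModelArguments,
    DataArguments,
    TrainingArguments,
)

# Parses e.g. `--model_name_or_path facebook/opt-125m --output_dir ./out`
# from sys.argv into the three dataclasses defined above.
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()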
helm/clients/audio_language/llama_omni/constants.py
@@ -0,0 +1,9 @@
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
+ WORKER_HEART_BEAT_INTERVAL = 15
+
+ LOGDIR = "."
+
+ # Model Constants
+ IGNORE_INDEX = -100
+ SPEECH_TOKEN_INDEX = -200
+ DEFAULT_SPEECH_TOKEN = "<speech>"
helm/clients/audio_language/llama_omni/conversation.py
@@ -0,0 +1,213 @@
+ # Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
+ # Copyright 2023 Haotian Liu
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import dataclasses
+ from enum import auto, Enum
+ from typing import List, Any, Union, Optional
+
+
+ class SeparatorStyle(Enum):
+     """Different separator style."""
+
+     TWO = auto()
+     PLAIN = auto()
+     LLAMA_2 = auto()
+     LLAMA_3 = auto()
+
+
+ @dataclasses.dataclass
+ class Conversation:
+     """A class that keeps all conversation history."""
+
+     system: str
+     roles: List[str]
+     messages: List[List[str]]
+     offset: int
+     sep_style: SeparatorStyle = SeparatorStyle.PLAIN
+     sep: str = "###"
+     sep2: str = ""
+     version: str = "Unknown"
+
+     tokenizer_id: str = ""
+     tokenizer: Any = None
+     # Stop criteria (the default one is EOS token)
+     stop_str: Optional[Union[str, List[str]]] = None
+     # Stops generation if meeting any token in this list
+     stop_token_ids: Optional[List[int]] = None
+
+     skip_next: bool = False
+
+     def get_prompt(self):
+         messages = self.messages
+
+         if self.sep_style == SeparatorStyle.TWO:
+             seps = [self.sep, self.sep2]
+             ret = self.system + seps[0]
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message = message[0]
+                     ret += role + ": " + message + seps[i % 2]
+                 else:
+                     ret += role + ":"
+         elif self.sep_style == SeparatorStyle.LLAMA_3:
+             wrap_sys = lambda msg: (
+                 f"<|start_header_id|>system<|end_header_id|>\n\n{msg}<|eot_id|>" if len(msg) > 0 else msg
+             )
+             ret = "<|begin_of_text|>" + wrap_sys(self.system)
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message = message[0]
+                     ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
+                     ret += message.strip() + self.sep2
+                 else:
+                     ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
+             return ret
+         elif self.sep_style == SeparatorStyle.LLAMA_2:
+             wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
+             wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
+             ret = ""
+
+             for i, (role, message) in enumerate(messages):
+                 if i == 0:
+                     assert message, "first message should not be none"
+                     assert role == self.roles[0], "first message should come from user"
+                 if message:
+                     if type(message) is tuple:
+                         message = message[0]
+                     if i == 0:
+                         message = wrap_sys(self.system) + message
+                     if i % 2 == 0:
+                         message = wrap_inst(message)
+                         ret += self.sep + message
+                     else:
+                         ret += " " + message + " " + self.sep2
+                 else:
+                     ret += ""
+             ret = ret.lstrip(self.sep)
+         elif self.sep_style == SeparatorStyle.PLAIN:
+             seps = [self.sep, self.sep2]
+             ret = self.system
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message = message[0]
+                     ret += message + seps[i % 2]
+                 else:
+                     ret += ""
+         else:
+             raise ValueError(f"Invalid style: {self.sep_style}")
+
+         return ret
+
+     def append_message(self, role, message):
+         self.messages.append([role, message])
+
+     def to_gradio_chatbot(self):
+         ret = []
+         for i, (role, msg) in enumerate(self.messages[self.offset :]):
+             if i % 2 == 0:
+                 if type(msg) is tuple:
+                     msg = msg[0]
+                     ret.append([msg, None])
+                 else:
+                     ret.append([msg, None])
+             else:
+                 ret[-1][-1] = msg
+         return ret
+
+     def copy(self):
+         return Conversation(
+             system=self.system,
+             roles=self.roles,
+             messages=[[x, y] for x, y in self.messages],
+             offset=self.offset,
+             sep_style=self.sep_style,
+             sep=self.sep,
+             sep2=self.sep2,
+             version=self.version,
+         )
+
+     def dict(self):
+         return {
+             "system": self.system,
+             "roles": self.roles,
+             "messages": self.messages,
+             "offset": self.offset,
+             "sep": self.sep,
+             "sep2": self.sep2,
+         }
+
+
+ conv_vicuna_v1 = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+     "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+     roles=["USER", "ASSISTANT"],
+     version="v1",
+     messages=[],
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+ )
+
+ conv_llama_2 = Conversation(
+     system="You are a helpful language and speech assistant. "
+     "You are able to understand the speech content that the user provides, "
+     "and assist the user with a variety of tasks using natural language.",
+     roles=["USER", "ASSISTANT"],
+     version="llama_v2",
+     messages=[],
+     offset=0,
+     sep_style=SeparatorStyle.LLAMA_2,
+     sep="<s>",
+     sep2="</s>",
+ )
+
+ conv_llama_3 = Conversation(
+     system="You are a helpful language and speech assistant. "
+     "You are able to understand the speech content that the user provides, "
+     "and assist the user with a variety of tasks using natural language.",
+     roles=["user", "assistant"],
+     version="llama_v3",
+     messages=[],
+     offset=0,
+     sep_style=SeparatorStyle.LLAMA_3,
+     sep="",
+     sep2="<|eot_id|>",
+ )
+
+ conv_plain = Conversation(
+     system="",
+     roles=["", ""],
+     messages=[],
+     offset=0,
+     sep_style=SeparatorStyle.PLAIN,
+     sep="</s>",
+ )
+
+
+ default_conversation = conv_llama_3
+ conv_templates = {
+     "v1": conv_vicuna_v1,
+     "plain": conv_plain,
+     "llama_2": conv_llama_2,
+     "llama_3": conv_llama_3,
+ }
+
+
+ if __name__ == "__main__":
+     print(default_conversation.get_prompt())
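
For context, a minimal sketch of how these templates assemble a Llama 3 style prompt; the user-turn text here is invented for illustration:

from helm.clients.audio_language.llama_omni.conversation import conv_templates

# Copy the template so the module-level singleton is not mutated.
conv = conv_templates["llama_3"].copy()
conv.append_message(conv.roles[0], "<speech>\nPlease transcribe the audio.")
conv.append_message(conv.roles[1], None)  # empty assistant slot: prompt ends at the assistant header
prompt = conv.get_prompt()
# prompt begins "<|begin_of_text|><|start_header_id|>system<|end_header_id|>..." and
# ends with the open assistant header, ready for generation.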
helm/clients/audio_language/llama_omni/model/builder.py
@@ -0,0 +1,88 @@
+ import os
+
+ from transformers import AutoTokenizer, AutoConfig, BitsAndBytesConfig
+ import torch
+ from helm.clients.audio_language.llama_omni.model.language_model.omni_speech_llama import OmniSpeechLlamaForCausalLM
+ from helm.clients.audio_language.llama_omni.model.language_model.omni_speech2s_llama import OmniSpeech2SLlamaForCausalLM
+ from helm.clients.audio_language.llama_omni.model.speech_encoder.builder import build_speech_encoder
+
+
+ def load_pretrained_model(
+     model_path,
+     model_base,
+     is_lora=False,
+     s2s=False,
+     load_8bit=False,
+     load_4bit=False,
+     device="cuda",
+     use_flash_attn=False,
+     **kwargs,
+ ):
+     if load_8bit:
+         kwargs["load_in_8bit"] = True
+     elif load_4bit:
+         kwargs["load_in_4bit"] = True
+         kwargs["quantization_config"] = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_compute_dtype=torch.float16,
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_quant_type="nf4",
+         )
+     else:
+         kwargs["torch_dtype"] = torch.float16
+
+     if use_flash_attn:
+         kwargs["attn_implementation"] = "flash_attention_2"
+
+     model_cls = OmniSpeech2SLlamaForCausalLM if s2s else OmniSpeechLlamaForCausalLM
+
+     # Load OmniSpeech model
+     if is_lora:
+         assert model_base is not None, "model_base is required for LoRA models."
+         from language_model.omni_speech_llama import OmniSpeechConfig
+
+         lora_cfg_pretrained = OmniSpeechConfig.from_pretrained(model_path)
+         tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+         print("Loading OmniSpeech from base model...")
+         model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=lora_cfg_pretrained, **kwargs)
+         print("Loading additional OmniSpeech weights...")
+         if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
+             non_lora_trainables = torch.load(os.path.join(model_path, "non_lora_trainables.bin"), map_location="cpu")
+             non_lora_trainables = {
+                 (k[11:] if k.startswith("base_model.") else k): v for k, v in non_lora_trainables.items()
+             }
+             if any(k.startswith("model.model.") for k in non_lora_trainables):
+                 non_lora_trainables = {(k[6:] if k.startswith("model.") else k): v for k, v in non_lora_trainables.items()}
+             model.load_state_dict(non_lora_trainables, strict=False)
+
+         from peft import PeftModel
+
+         print("Loading LoRA weights...")
+         model = PeftModel.from_pretrained(model, model_path)
+         print("Merging LoRA weights...")
+         model = model.merge_and_unload()
+         print("Model is loaded...")
+     elif model_base is not None:
+         print("Loading OmniSpeech from base model...")
+         tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+         cfg_pretrained = AutoConfig.from_pretrained(model_path)
+         model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=cfg_pretrained, **kwargs)
+
+         speech_projector_weights = torch.load(os.path.join(model_path, "speech_projector.bin"), map_location="cpu")
+         speech_projector_weights = {k: v.to(torch.float16) for k, v in speech_projector_weights.items()}
+         model.load_state_dict(speech_projector_weights, strict=False)
+         model = model.to(device=device)
+     else:
+         tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+         model = model_cls.from_pretrained(model_path, low_cpu_mem_usage=False, **kwargs)
+         model = model.to(device=device)
+
+     model.get_model().speech_encoder = build_speech_encoder(model.config)
+     model.get_model().speech_encoder.to(device=device, dtype=torch.float16)
+
+     if hasattr(model.config, "max_sequence_length"):
+         context_len = model.config.max_sequence_length
+     else:
+         context_len = 2048
+
+     return tokenizer, model, context_len
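
A minimal sketch of calling this loader; the checkpoint name is hypothetical and the call assumes a full (non-LoRA) LLaMA-Omni checkpoint:

from helm.clients.audio_language.llama_omni.model.builder import load_pretrained_model

tokenizer, model, context_len = load_pretrained_model(
    model_path="ictnlp/Llama-3.1-8B-Omni",  # hypothetical checkpoint path
    model_base=None,  # full checkpoint, so no separate base model to merge
    s2s=True,  # load OmniSpeech2SLlamaForCausalLM, the speech-to-speech variant
    device="cuda",
)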
helm/clients/audio_language/llama_omni/model/language_model/omni_speech2s_llama.py
@@ -0,0 +1,190 @@
+ from typing import List, Optional, Tuple, Union, Callable
+
+ import torch
+
+ from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig
+
+ from transformers import PreTrainedModel
+ from transformers.generation.streamers import BaseStreamer
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.generation.utils import (
+     GenerationConfig,
+     LogitsProcessorList,
+     StoppingCriteriaList,
+ )
+
+ from helm.clients.audio_language.llama_omni.model.language_model.omni_speech_llama import OmniSpeechLlamaForCausalLM
+ from helm.clients.audio_language.llama_omni.model.speech_generator.builder import build_speech_generator
+ from helm.clients.audio_language.llama_omni.model.speech_generator.generation import GenerationWithCTC
+
+
+ class OmniSpeech2SConfig(LlamaConfig):
+     model_type = "omni_speech2s_llama"
+
+
+ class OmniSpeech2SLlamaForCausalLM(OmniSpeechLlamaForCausalLM, GenerationWithCTC):
+     config_class = OmniSpeech2SConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+         if hasattr(config, "speech_generator_type"):
+             self.speech_generator = build_speech_generator(config)
+
+     def initialize_speech_generator(self, model_args):
+         self.config.speech_generator_type = getattr(model_args, "speech_generator_type", "ctc")
+         self.config.ctc_decoder_config = getattr(model_args, "ctc_decoder_config", "(4,4096,32,11008)")
+         self.config.ctc_upsample_factor = getattr(model_args, "ctc_upsample_factor", 1)
+         self.config.ctc_loss_weight = getattr(model_args, "ctc_loss_weight", 1.0)
+         self.config.unit_vocab_size = getattr(model_args, "unit_vocab_size", 1000)
+         self.tune_speech_generator_only = getattr(model_args, "tune_speech_generator_only", False)
+         if getattr(self, "speech_generator", None) is None:
+             self.speech_generator = build_speech_generator(self.config)
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         speech: Optional[torch.FloatTensor] = None,
+         speech_lengths: Optional[torch.LongTensor] = None,
+         tgt_units: Optional[torch.LongTensor] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+         if inputs_embeds is None:
+             (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = (
+                 self.prepare_inputs_labels_for_speech_and_text(
+                     input_ids, position_ids, attention_mask, past_key_values, labels, speech, speech_lengths
+                 )
+             )
+
+         if self.training:
+             if self.tune_speech_generator_only:
+                 with torch.no_grad():
+                     llama_output = super(OmniSpeechLlamaForCausalLM, self).forward(
+                         input_ids=input_ids,
+                         attention_mask=attention_mask,
+                         position_ids=position_ids,
+                         past_key_values=past_key_values,
+                         inputs_embeds=inputs_embeds,
+                         labels=labels,
+                         use_cache=use_cache,
+                         output_attentions=output_attentions,
+                         output_hidden_states=True,
+                         return_dict=return_dict,
+                     )
+                 loss = self.speech_generator(llama_output["hidden_states"][-1], labels, tgt_units)
+             else:
+                 llama_output = super(OmniSpeechLlamaForCausalLM, self).forward(
+                     input_ids=input_ids,
+                     attention_mask=attention_mask,
+                     position_ids=position_ids,
+                     past_key_values=past_key_values,
+                     inputs_embeds=inputs_embeds,
+                     labels=labels,
+                     use_cache=use_cache,
+                     output_attentions=output_attentions,
+                     output_hidden_states=True,
+                     return_dict=return_dict,
+                 )
+                 lm_loss = llama_output.loss
+                 ctc_loss = self.speech_generator(llama_output["hidden_states"][-1], labels, tgt_units)
+                 loss = lm_loss + ctc_loss * self.config.ctc_loss_weight
+         else:
+             llama_output = super(OmniSpeechLlamaForCausalLM, self).forward(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 position_ids=position_ids,
+                 past_key_values=past_key_values,
+                 inputs_embeds=inputs_embeds,
+                 labels=labels,
+                 use_cache=use_cache,
+                 output_attentions=output_attentions,
+                 output_hidden_states=True,
+                 return_dict=return_dict,
+             )
+             loss = llama_output.loss
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=llama_output.logits,
+             past_key_values=llama_output.past_key_values,
+             hidden_states=llama_output.hidden_states,
+             attentions=llama_output.attentions,
+         )
+
+     @torch.no_grad()
+     def generate(
+         self,
+         inputs: Optional[torch.Tensor] = None,
+         speech: Optional[torch.Tensor] = None,
+         speech_lengths: Optional[torch.Tensor] = None,
+         generation_config: Optional[GenerationConfig] = None,
+         logits_processor: Optional[LogitsProcessorList] = None,
+         stopping_criteria: Optional[StoppingCriteriaList] = None,
+         prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+         synced_gpus: Optional[bool] = None,
+         assistant_model: Optional["PreTrainedModel"] = None,
+         streamer: Optional["BaseStreamer"] = None,
+         streamer_unit: Optional["BaseStreamer"] = None,
+         streaming_unit_gen=False,
+         negative_prompt_ids: Optional[torch.Tensor] = None,
+         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+         **kwargs,
+     ):
+         position_ids = kwargs.pop("position_ids", None)
+         attention_mask = kwargs.pop("attention_mask", None)
+         if "inputs_embeds" in kwargs:
+             raise NotImplementedError("`inputs_embeds` is not supported")
+
+         if speech is not None:
+             (inputs, position_ids, attention_mask, _, inputs_embeds, _) = (
+                 self.prepare_inputs_labels_for_speech_and_text(
+                     inputs, position_ids, attention_mask, None, None, speech, speech_lengths
+                 )
+             )
+         else:
+             inputs_embeds = self.get_model().embed_tokens(inputs)
+         outputs = GenerationWithCTC.generate(
+             self,
+             position_ids=position_ids,
+             attention_mask=attention_mask,
+             inputs_embeds=inputs_embeds,
+             output_hidden_states=True,
+             return_dict_in_generate=True,
+             streaming_unit_gen=streaming_unit_gen,
+             **kwargs,
+         )
+
+         hidden_states = outputs["hidden_states"]
+         hidden_states = torch.cat(
+             [hidden_states[0][-1][:, -1:, :]] + [hidden_states[i][-1] for i in range(1, len(hidden_states))], dim=1
+         )
+         ctc_pred = self.speech_generator.predict(hidden_states.squeeze(0))
+
+         return outputs.sequences, ctc_pred
+
+     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+         speech = kwargs.pop("speech", None)
+         speech_lengths = kwargs.pop("speech_lengths", None)
+         inputs = super().prepare_inputs_for_generation(
+             input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
+         )
+         if speech is not None:
+             inputs["speech"] = speech
+             inputs["speech_lengths"] = speech_lengths
+         return inputs
+
+
+ AutoConfig.register("omni_speech2s_llama", OmniSpeech2SConfig)
+ AutoModelForCausalLM.register(OmniSpeech2SConfig, OmniSpeech2SLlamaForCausalLM)
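
Putting the pieces together, a rough end-to-end sketch continuing from the loader call above. The placeholder splicing mirrors the LLaVA-style convention these files are adapted from, and the zero mel features are stand-ins, so treat this as an assumption-laden illustration rather than HELM's actual client code:

import torch

from helm.clients.audio_language.llama_omni.constants import DEFAULT_SPEECH_TOKEN, SPEECH_TOKEN_INDEX
from helm.clients.audio_language.llama_omni.conversation import conv_templates

# `tokenizer` and `model` come from the load_pretrained_model sketch above.
conv = conv_templates["llama_3"].copy()
conv.append_message(conv.roles[0], DEFAULT_SPEECH_TOKEN + "\nPlease answer the spoken question.")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

# Splice the sentinel speech id (-200) into the token stream; the model swaps it
# for encoded audio features inside prepare_inputs_labels_for_speech_and_text.
before, after = prompt.split(DEFAULT_SPEECH_TOKEN)
input_ids = torch.tensor(
    [tokenizer(before).input_ids + [SPEECH_TOKEN_INDEX] + tokenizer(after, add_special_tokens=False).input_ids],
    device="cuda",
)
speech = torch.zeros(1, 3000, 128, dtype=torch.float16, device="cuda")  # placeholder mel features (mel_size=128)
speech_lengths = torch.tensor([3000], device="cuda")

sequences, ctc_pred = model.generate(
    inputs=input_ids, speech=speech, speech_lengths=speech_lengths, max_new_tokens=256
)
print(tokenizer.decode(sequences[0], skip_special_tokens=True))  # the text reply
# ctc_pred holds the predicted discrete speech units, intended for a separate unit vocoder.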