crfm-helm 0.5.7__py3-none-any.whl → 0.5.8__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the registry.

Potentially problematic release: this version of crfm-helm might be problematic.

Files changed (243)
  1. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/METADATA +5 -77
  2. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/RECORD +228 -197
  3. helm/benchmark/adaptation/adapter_spec.py +5 -0
  4. helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +11 -3
  5. helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +11 -8
  6. helm/benchmark/annotation/aci_bench_annotator.py +11 -22
  7. helm/benchmark/annotation/alrage_annotator.py +90 -0
  8. helm/benchmark/annotation/chw_care_plan_annotator.py +10 -21
  9. helm/benchmark/annotation/dischargeme_annotator.py +11 -22
  10. helm/benchmark/annotation/med_dialog_annotator.py +11 -22
  11. helm/benchmark/annotation/medalign_annotator.py +11 -22
  12. helm/benchmark/annotation/medi_qa_annotator.py +11 -22
  13. helm/benchmark/annotation/medication_qa_annotator.py +11 -22
  14. helm/benchmark/annotation/mental_health_annotator.py +11 -22
  15. helm/benchmark/annotation/mimic_bhc_annotator.py +11 -22
  16. helm/benchmark/annotation/mimic_rrs_annotator.py +11 -22
  17. helm/benchmark/annotation/model_as_judge.py +23 -18
  18. helm/benchmark/annotation/mtsamples_procedures_annotator.py +11 -22
  19. helm/benchmark/annotation/mtsamples_replicate_annotator.py +11 -22
  20. helm/benchmark/annotation/starr_patient_instructions_annotator.py +11 -22
  21. helm/benchmark/metrics/air_bench_metrics.py +3157 -1
  22. helm/benchmark/metrics/alrage_metric.py +35 -0
  23. helm/benchmark/metrics/basic_metrics.py +267 -2
  24. helm/benchmark/metrics/classification_metrics.py +19 -1
  25. helm/benchmark/metrics/conv_fin_qa_calc_metrics.py +12 -1
  26. helm/benchmark/metrics/dry_run_metrics.py +30 -1
  27. helm/benchmark/metrics/efficiency_metrics.py +74 -0
  28. helm/benchmark/metrics/ehr_sql_metrics.py +57 -1
  29. helm/benchmark/metrics/evaluate_reference_metrics.py +299 -0
  30. helm/benchmark/metrics/gpqa_chain_of_thought_metric.py +13 -1
  31. helm/benchmark/metrics/helpdesk_call_summarization_metrics.py +13 -1
  32. helm/benchmark/metrics/ifeval_metrics.py +13 -1
  33. helm/benchmark/metrics/instruction_following_critique_metrics.py +41 -1
  34. helm/benchmark/metrics/kpi_edgar_metrics.py +21 -0
  35. helm/benchmark/metrics/language_modeling_metrics.py +13 -1
  36. helm/benchmark/metrics/live_qa_metrics.py +13 -1
  37. helm/benchmark/metrics/llm_jury_metrics.py +13 -1
  38. helm/benchmark/metrics/medcalc_bench_metrics.py +14 -1
  39. helm/benchmark/metrics/medec_metrics.py +25 -2
  40. helm/benchmark/metrics/metric.py +25 -0
  41. helm/benchmark/metrics/mimiciv_billing_code_metrics.py +32 -1
  42. helm/benchmark/metrics/omni_math_metrics.py +13 -1
  43. helm/benchmark/metrics/seahelm_metrics.py +14 -1
  44. helm/benchmark/metrics/summac/model_summac.py +2 -2
  45. helm/benchmark/metrics/summarization_metrics.py +129 -1
  46. helm/benchmark/metrics/toxicity_metrics.py +31 -1
  47. helm/benchmark/metrics/wildbench_metrics.py +21 -1
  48. helm/benchmark/presentation/schema.py +5 -22
  49. helm/benchmark/presentation/summarize.py +180 -11
  50. helm/benchmark/presentation/taxonomy_info.py +20 -0
  51. helm/benchmark/run_expander.py +4 -0
  52. helm/benchmark/run_specs/arabic_run_specs.py +134 -16
  53. helm/benchmark/run_specs/bluex_run_specs.py +1 -1
  54. helm/benchmark/run_specs/classic_run_specs.py +2 -2
  55. helm/benchmark/run_specs/long_context_run_specs.py +2 -2
  56. helm/benchmark/run_specs/medhelm/__init__.py +0 -0
  57. helm/benchmark/run_specs/medhelm/benchmark_config.py +219 -0
  58. helm/benchmark/run_specs/medhelm_run_specs.py +360 -50
  59. helm/benchmark/scenarios/aci_bench_scenario.py +23 -0
  60. helm/benchmark/scenarios/air_bench_scenario.py +21 -0
  61. helm/benchmark/scenarios/alrage_scenario.py +54 -0
  62. helm/benchmark/scenarios/anthropic_hh_rlhf_scenario.py +23 -1
  63. helm/benchmark/scenarios/arabic_exams_scenario.py +114 -0
  64. helm/benchmark/scenarios/arabic_mmlu_scenario.py +8 -4
  65. helm/benchmark/scenarios/aratrust_scenario.py +19 -0
  66. helm/benchmark/scenarios/babi_qa_scenario.py +15 -0
  67. helm/benchmark/scenarios/bbq_scenario.py +15 -0
  68. helm/benchmark/scenarios/best_chatgpt_prompts.yaml +473 -0
  69. helm/benchmark/scenarios/bluex_scenario.py +6 -2
  70. helm/benchmark/scenarios/bold_scenario.py +15 -0
  71. helm/benchmark/scenarios/boolq_scenario.py +20 -0
  72. helm/benchmark/scenarios/chw_care_plan_scenario.py +23 -0
  73. helm/benchmark/scenarios/civil_comments_scenario.py +13 -0
  74. helm/benchmark/scenarios/clear_scenario.py +23 -0
  75. helm/benchmark/scenarios/cleva_scenario.py +479 -0
  76. helm/benchmark/scenarios/code_scenario.py +28 -0
  77. helm/benchmark/scenarios/commonsense_scenario.py +26 -0
  78. helm/benchmark/scenarios/compositional_instructions.yaml +70 -0
  79. helm/benchmark/scenarios/conv_fin_qa_calc_scenario.py +21 -0
  80. helm/benchmark/scenarios/copyright_scenario.py +35 -1
  81. helm/benchmark/scenarios/cti_to_mitre_scenario.py +21 -0
  82. helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +22 -1
  83. helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +23 -1
  84. helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +22 -1
  85. helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +21 -1
  86. helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +13 -0
  87. helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +13 -1
  88. helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +13 -1
  89. helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +13 -1
  90. helm/benchmark/scenarios/dischargeme_scenario.py +24 -0
  91. helm/benchmark/scenarios/disinformation_scenario.py +22 -0
  92. helm/benchmark/scenarios/dyck_language_scenario.py +15 -0
  93. helm/benchmark/scenarios/ehrshot_scenario.py +22 -0
  94. helm/benchmark/scenarios/enem_challenge_scenario.py +19 -0
  95. helm/benchmark/scenarios/entity_data_imputation_scenario.py +14 -0
  96. helm/benchmark/scenarios/entity_matching_scenario.py +14 -0
  97. helm/benchmark/scenarios/financial_phrasebank_scenario.py +21 -0
  98. helm/benchmark/scenarios/gold_commodity_news_scenario.py +21 -0
  99. helm/benchmark/scenarios/gpqa_scenario.py +18 -0
  100. helm/benchmark/scenarios/grammar_scenario.py +20 -1
  101. helm/benchmark/scenarios/gsm_scenario.py +15 -0
  102. helm/benchmark/scenarios/headqa_scenario.py +22 -0
  103. helm/benchmark/scenarios/helpdesk_call_summarization_scenario.py +13 -0
  104. helm/benchmark/scenarios/ice_scenario.py +21 -1
  105. helm/benchmark/scenarios/ifeval_scenario.py +18 -0
  106. helm/benchmark/scenarios/imdb_scenario.py +15 -0
  107. helm/benchmark/scenarios/koala_scenario.py +21 -1
  108. helm/benchmark/scenarios/kpi_edgar_scenario.py +21 -0
  109. helm/benchmark/scenarios/legal_contract_summarization_scenario.py +20 -0
  110. helm/benchmark/scenarios/legal_summarization_scenario.py +50 -0
  111. helm/benchmark/scenarios/legal_support_scenario.py +13 -0
  112. helm/benchmark/scenarios/legalbench_scenario.py +20 -0
  113. helm/benchmark/scenarios/lex_glue_scenario.py +11 -0
  114. helm/benchmark/scenarios/lextreme_scenario.py +11 -0
  115. helm/benchmark/scenarios/lsat_qa_scenario.py +14 -0
  116. helm/benchmark/scenarios/madinah_qa_scenario.py +73 -0
  117. helm/benchmark/scenarios/math_scenario.py +26 -0
  118. helm/benchmark/scenarios/mbzuai_human_translated_arabic_mmlu.py +68 -0
  119. helm/benchmark/scenarios/med_dialog_scenario.py +32 -1
  120. helm/benchmark/scenarios/med_mcqa_scenario.py +14 -0
  121. helm/benchmark/scenarios/med_qa_scenario.py +14 -0
  122. helm/benchmark/scenarios/medalign_scenario.py +23 -0
  123. helm/benchmark/scenarios/medbullets_scenario.py +22 -0
  124. helm/benchmark/scenarios/medcalc_bench_scenario.py +22 -0
  125. helm/benchmark/scenarios/medec_scenario.py +23 -0
  126. helm/benchmark/scenarios/medhallu_scenario.py +23 -0
  127. helm/benchmark/scenarios/medhelm/__init__.py +0 -0
  128. helm/benchmark/scenarios/medhelm/judges.yaml +14 -0
  129. helm/benchmark/scenarios/medhelm_configurable_scenario.py +101 -0
  130. helm/benchmark/scenarios/medi_qa_scenario.py +23 -0
  131. helm/benchmark/scenarios/medication_qa_scenario.py +31 -1
  132. helm/benchmark/scenarios/mental_health_scenario.py +23 -0
  133. helm/benchmark/scenarios/mimic_bhc_scenario.py +24 -0
  134. helm/benchmark/scenarios/mimic_rrs_scenario.py +23 -0
  135. helm/benchmark/scenarios/mimiciv_billing_code_scenario.py +22 -0
  136. helm/benchmark/scenarios/mmlu_pro_scenario.py +18 -0
  137. helm/benchmark/scenarios/mmlu_scenario.py +15 -0
  138. helm/benchmark/scenarios/msmarco_scenario.py +30 -0
  139. helm/benchmark/scenarios/mtsamples_procedures_scenario.py +22 -0
  140. helm/benchmark/scenarios/mtsamples_replicate_scenario.py +22 -0
  141. helm/benchmark/scenarios/n2c2_ct_matching_scenario.py +20 -0
  142. helm/benchmark/scenarios/narrativeqa_scenario.py +20 -0
  143. helm/benchmark/scenarios/natural_qa_scenario.py +32 -0
  144. helm/benchmark/scenarios/omni_math_scenario.py +18 -0
  145. helm/benchmark/scenarios/open_assistant_scenario.py +22 -0
  146. helm/benchmark/scenarios/pubmed_qa_scenario.py +22 -0
  147. helm/benchmark/scenarios/quac_scenario.py +14 -0
  148. helm/benchmark/scenarios/race_based_med_scenario.py +23 -0
  149. helm/benchmark/scenarios/raft_scenario.py +15 -0
  150. helm/benchmark/scenarios/real_toxicity_prompts_scenario.py +14 -1
  151. helm/benchmark/scenarios/scenario.py +31 -0
  152. helm/benchmark/scenarios/seahelm_scenario.py +348 -0
  153. helm/benchmark/scenarios/self_instruct_scenario.py +29 -1
  154. helm/benchmark/scenarios/shc_bmt_scenario.py +22 -0
  155. helm/benchmark/scenarios/shc_cdi_scenario.py +20 -0
  156. helm/benchmark/scenarios/shc_conf_scenario.py +23 -0
  157. helm/benchmark/scenarios/shc_ent_scenario.py +21 -0
  158. helm/benchmark/scenarios/shc_gip_scenario.py +20 -0
  159. helm/benchmark/scenarios/shc_privacy_scenario.py +22 -0
  160. helm/benchmark/scenarios/shc_proxy_scenario.py +22 -0
  161. helm/benchmark/scenarios/shc_ptbm_scenario.py +23 -0
  162. helm/benchmark/scenarios/shc_sequoia_scenario.py +21 -0
  163. helm/benchmark/scenarios/situation_prompts.yaml +49 -0
  164. helm/benchmark/scenarios/starr_patient_instructions_scenario.py +22 -0
  165. helm/benchmark/scenarios/summarization_scenario.py +37 -0
  166. helm/benchmark/scenarios/synthetic_efficiency_scenario.py +22 -1
  167. helm/benchmark/scenarios/synthetic_reasoning_natural_scenario.py +13 -0
  168. helm/benchmark/scenarios/test_alrage_scenario.py +23 -0
  169. helm/benchmark/scenarios/test_arabic_exams_scenario.py +21 -0
  170. helm/benchmark/scenarios/test_aratrust_scenario.py +1 -1
  171. helm/benchmark/scenarios/test_bluex_scenario.py +2 -2
  172. helm/benchmark/scenarios/the_pile_scenario.py +13 -1
  173. helm/benchmark/scenarios/truthful_qa_scenario.py +14 -0
  174. helm/benchmark/scenarios/twitter_aae_scenario.py +20 -1
  175. helm/benchmark/scenarios/vicuna_scenario.py +21 -1
  176. helm/benchmark/scenarios/wikifact_scenario.py +20 -0
  177. helm/benchmark/scenarios/wildbench_scenario.py +18 -0
  178. helm/benchmark/scenarios/wmt_14_scenario.py +12 -0
  179. helm/benchmark/static/schema_arabic.yaml +55 -12
  180. helm/benchmark/static/schema_long_context.yaml +17 -17
  181. helm/benchmark/static/schema_medhelm.yaml +36 -0
  182. helm/benchmark/static/schema_slp.yaml +219 -0
  183. helm/benchmark/static_build/assets/index-671a5e06.js +10 -0
  184. helm/benchmark/static_build/assets/index-9352595e.css +1 -0
  185. helm/benchmark/static_build/index.html +2 -2
  186. helm/clients/audio_language/llama_omni/arguments.py +61 -0
  187. helm/clients/audio_language/llama_omni/constants.py +9 -0
  188. helm/clients/audio_language/llama_omni/conversation.py +213 -0
  189. helm/clients/audio_language/llama_omni/model/__init__.py +0 -0
  190. helm/clients/audio_language/llama_omni/model/builder.py +88 -0
  191. helm/clients/audio_language/llama_omni/model/language_model/omni_speech2s_llama.py +190 -0
  192. helm/clients/audio_language/llama_omni/model/language_model/omni_speech_llama.py +118 -0
  193. helm/clients/audio_language/llama_omni/model/omni_speech_arch.py +249 -0
  194. helm/clients/audio_language/llama_omni/model/speech_encoder/builder.py +9 -0
  195. helm/clients/audio_language/llama_omni/model/speech_encoder/speech_encoder.py +27 -0
  196. helm/clients/audio_language/llama_omni/model/speech_generator/builder.py +9 -0
  197. helm/clients/audio_language/llama_omni/model/speech_generator/generation.py +622 -0
  198. helm/clients/audio_language/llama_omni/model/speech_generator/speech_generator.py +104 -0
  199. helm/clients/audio_language/llama_omni/model/speech_projector/builder.py +9 -0
  200. helm/clients/audio_language/llama_omni/model/speech_projector/speech_projector.py +27 -0
  201. helm/clients/audio_language/llama_omni/preprocess.py +295 -0
  202. helm/clients/audio_language/llama_omni/utils.py +202 -0
  203. helm/clients/audio_language/qwen_omni/configuration_qwen2_5_omni.py +519 -0
  204. helm/clients/audio_language/qwen_omni/modeling_qwen2_5_omni.py +4308 -0
  205. helm/clients/audio_language/qwen_omni/processing_qwen2_5_omni.py +270 -0
  206. helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/__init__.py +0 -0
  207. helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/__init__.py +8 -0
  208. helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/audio_process.py +56 -0
  209. helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/vision_process.py +380 -0
  210. helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +1 -1
  211. helm/clients/image_generation/mindalle/models/stage1/layers.py +2 -2
  212. helm/clients/openai_client.py +31 -19
  213. helm/clients/openai_responses_client.py +27 -3
  214. helm/clients/openrouter_client.py +31 -0
  215. helm/clients/test_openrouter_client.py +69 -0
  216. helm/clients/together_client.py +48 -11
  217. helm/clients/vertexai_client.py +8 -2
  218. helm/config/model_deployments.yaml +75 -1
  219. helm/config/model_metadata.yaml +70 -2
  220. helm/config/tokenizer_configs.yaml +19 -1
  221. helm/proxy/example_queries.py +8 -8
  222. helm/proxy/server.py +2 -1
  223. helm/proxy/static/index.css +4 -0
  224. helm/proxy/static/index.js +7 -1
  225. helm/benchmark/metrics/aci_bench_metrics.py +0 -14
  226. helm/benchmark/metrics/chw_care_plan_metrics.py +0 -14
  227. helm/benchmark/metrics/dischargeme_metrics.py +0 -14
  228. helm/benchmark/metrics/med_dialog_metrics.py +0 -14
  229. helm/benchmark/metrics/medalign_metrics.py +0 -14
  230. helm/benchmark/metrics/medi_qa_metrics.py +0 -14
  231. helm/benchmark/metrics/medication_qa_metrics.py +0 -14
  232. helm/benchmark/metrics/mental_health_metrics.py +0 -14
  233. helm/benchmark/metrics/mimic_bhc_metrics.py +0 -14
  234. helm/benchmark/metrics/mimic_rrs_metrics.py +0 -14
  235. helm/benchmark/metrics/mtsamples_procedures_metrics.py +0 -14
  236. helm/benchmark/metrics/mtsamples_replicate_metrics.py +0 -14
  237. helm/benchmark/metrics/starr_patient_instructions_metrics.py +0 -14
  238. helm/benchmark/static_build/assets/index-b9779128.css +0 -1
  239. helm/benchmark/static_build/assets/index-e439d5e1.js +0 -10
  240. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/WHEEL +0 -0
  241. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/entry_points.txt +0 -0
  242. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/licenses/LICENSE +0 -0
  243. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/top_level.txt +0 -0

helm/clients/audio_language/llama_omni/model/language_model/omni_speech_llama.py
@@ -0,0 +1,118 @@
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+
+ from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
+
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.generation.utils import GenerateOutput
+
+ from helm.clients.audio_language.llama_omni.model.omni_speech_arch import OmniSpeechMetaModel, OmniSpeechMetaForCausalLM
+
+
+ class OmniSpeechConfig(LlamaConfig):
+     model_type = "omni_speech_llama"
+
+
+ class OmniSpeechLlamaModel(OmniSpeechMetaModel, LlamaModel):
+     config_class = OmniSpeechConfig
+
+     def __init__(self, config: LlamaConfig):
+         super(OmniSpeechLlamaModel, self).__init__(config)
+
+
+ class OmniSpeechLlamaForCausalLM(LlamaForCausalLM, OmniSpeechMetaForCausalLM):
+     config_class = OmniSpeechConfig
+
+     def __init__(self, config):
+         super(LlamaForCausalLM, self).__init__(config)
+         self.model = OmniSpeechLlamaModel(config)
+         self.pretraining_tp = config.pretraining_tp
+         self.vocab_size = config.vocab_size
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_model(self):
+         return self.model
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         speech: Optional[torch.FloatTensor] = None,
+         speech_lengths: Optional[torch.LongTensor] = None,
+         tgt_units: Optional[torch.LongTensor] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+         if inputs_embeds is None:
+             (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = (
+                 self.prepare_inputs_labels_for_speech_and_text(
+                     input_ids, position_ids, attention_mask, past_key_values, labels, speech, speech_lengths
+                 )
+             )
+
+         return super().forward(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             labels=labels,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+     @torch.no_grad()
+     def generate(
+         self,
+         inputs: Optional[torch.Tensor] = None,
+         speech: Optional[torch.Tensor] = None,
+         speech_lengths: Optional[torch.Tensor] = None,
+         **kwargs,
+     ) -> Union[GenerateOutput, torch.LongTensor]:
+         position_ids = kwargs.pop("position_ids", None)
+         attention_mask = kwargs.pop("attention_mask", None)
+         if "inputs_embeds" in kwargs:
+             raise NotImplementedError("`inputs_embeds` is not supported")
+
+         if speech is not None:
+             (inputs, position_ids, attention_mask, _, inputs_embeds, _) = (
+                 self.prepare_inputs_labels_for_speech_and_text(
+                     inputs, position_ids, attention_mask, None, None, speech, speech_lengths
+                 )
+             )
+         else:
+             inputs_embeds = self.get_model().embed_tokens(inputs)
+
+         return super().generate(
+             position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs
+         )
+
+     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+         speech = kwargs.pop("speech", None)
+         speech_lengths = kwargs.pop("speech_lengths", None)
+         inputs = super().prepare_inputs_for_generation(
+             input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
+         )
+         if speech is not None:
+             inputs["speech"] = speech
+             inputs["speech_lengths"] = speech_lengths
+         return inputs
+
+
+ AutoConfig.register("omni_speech_llama", OmniSpeechConfig)
+ AutoModelForCausalLM.register(OmniSpeechConfig, OmniSpeechLlamaForCausalLM)
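
The two register calls at the end of this file hook the custom model type into the Transformers auto classes, so any checkpoint whose config.json declares model_type "omni_speech_llama" loads through the standard auto API. A minimal sketch of that loading path; the checkpoint directory is hypothetical, not an artifact of this release:

# Sketch only: "path/to/llama-omni-checkpoint" is a hypothetical local directory
# whose config.json declares model_type "omni_speech_llama".
from transformers import AutoConfig, AutoModelForCausalLM

# Importing the module runs the register() calls above as a side effect.
import helm.clients.audio_language.llama_omni.model.language_model.omni_speech_llama  # noqa: F401

config = AutoConfig.from_pretrained("path/to/llama-omni-checkpoint")
assert config.model_type == "omni_speech_llama"
model = AutoModelForCausalLM.from_pretrained("path/to/llama-omni-checkpoint", config=config)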

helm/clients/audio_language/llama_omni/model/omni_speech_arch.py
@@ -0,0 +1,249 @@
+ from abc import ABC, abstractmethod
+
+ import torch
+ from torch import nn
+
+ from helm.clients.audio_language.llama_omni.model.speech_encoder.builder import build_speech_encoder
+ from helm.clients.audio_language.llama_omni.model.speech_projector.builder import build_speech_projector
+ from helm.clients.audio_language.llama_omni.constants import IGNORE_INDEX, SPEECH_TOKEN_INDEX
+
+
+ class OmniSpeechMetaModel(nn.Module):
+
+     def __init__(self, config):
+         super(OmniSpeechMetaModel, self).__init__(config)
+         self.config = config
+
+         if hasattr(config, "speech_encoder"):
+             self.speech_encoder = build_speech_encoder(config)
+             self.speech_projector = build_speech_projector(config)
+
+     def get_speech_encoder(self):
+         speech_encoder = getattr(self, "speech_encoder", None)
+         if type(speech_encoder) is list:
+             speech_encoder = speech_encoder[0]
+         return speech_encoder
+
+     def initialize_speech_modules(self, model_args, fsdp=None):
+         self.config.speech_encoder = getattr(model_args, "speech_encoder", None)
+         self.config.speech_encoder_type = getattr(model_args, "speech_encoder_type", None)
+         self.config.speech_projector_type = getattr(model_args, "speech_projector_type", "linear")
+         self.config.speech_encoder_ds_rate = getattr(model_args, "speech_encoder_ds_rate", 5)
+         self.config.speech_encoder_hidden_size = getattr(model_args, "speech_encoder_hidden_size", 1280)
+
+         if self.get_speech_encoder() is None:
+             speech_encoder = build_speech_encoder(self.config)
+             if fsdp is not None and len(fsdp) > 0:
+                 self.speech_encoder = [speech_encoder]
+             else:
+                 self.speech_encoder = speech_encoder
+
+         if getattr(self, "speech_projector", None) is None:
+             self.speech_projector = build_speech_projector(self.config)
+         else:
+             # In case it is frozen by LoRA
+             for p in self.speech_projector.parameters():
+                 p.requires_grad = True
+
+         if model_args.pretrain_speech_projector is not None:
+             pretrain_speech_projector_weights = torch.load(model_args.pretrain_speech_projector, map_location="cpu")
+
+             def get_w(weights, keyword):
+                 return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}
+
+             self.speech_projector.load_state_dict(get_w(pretrain_speech_projector_weights, "speech_projector"))
+
+
+ class OmniSpeechMetaForCausalLM(ABC):
+     def __init__(self, config):
+         self.config = config
+
+     @abstractmethod
+     def get_model(self):
+         pass
+
+     def get_speech_encoder(self):
+         return self.get_model().get_speech_encoder()
+
+     def get_speech_projector(self):
+         return self.get_model().speech_projector
+
+     def encode_speech(self, speech, speech_lengths):
+         speech_encoder_type = self.config.speech_encoder_type
+         speech_encoder = self.get_speech_encoder()
+         if "whisper" in speech_encoder_type.lower():
+             encoder_outs = speech_encoder(speech.permute(0, 2, 1))
+             speech_lengths = (speech_lengths + 1) // 2
+         else:
+             raise ValueError(f"Unknown speech encoder: {speech_encoder}")
+         speech_projector_type = self.config.speech_projector_type
+         speech_projector = self.get_speech_projector()
+         if speech_projector_type == "linear":
+             encoder_outs = speech_projector(encoder_outs)
+             speech_lengths = speech_lengths // speech_projector.k
+         else:
+             raise ValueError(f"Unknown speech projector: {speech_projector_type}")
+         speech_features = [encoder_outs[i, : speech_lengths[i]] for i in range(len(encoder_outs))]
+         return speech_features
+
+     def prepare_inputs_labels_for_speech_and_text(
+         self, input_ids, position_ids, attention_mask, past_key_values, labels, speech, speech_lengths
+     ):
+         # input_ids = input_ids.unsqueeze(0)
+         speech_encoder = self.get_speech_encoder()
+         if speech_encoder is None or speech is None or input_ids.shape[1] == 1:
+             return input_ids, position_ids, attention_mask, past_key_values, None, labels
+
+         speech_features = self.encode_speech(speech, speech_lengths)
+         # Let's just add dummy tensors if they do not exist,
+         # it is a headache to deal with None all the time.
+         # But it is not ideal, and if you have a better idea,
+         # please open an issue / submit a PR, thanks.
+         _labels = labels
+         _position_ids = position_ids
+         _attention_mask = attention_mask
+         if attention_mask is None:
+             attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+         else:
+             attention_mask = attention_mask.bool()
+         if position_ids is None:
+             position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
+         if labels is None:
+             labels = torch.full_like(input_ids, IGNORE_INDEX)
+
+         # remove the padding using attention_mask -- FIXME
+         # _input_ids = input_ids
+         input_ids = [
+             cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
+         ]
+         labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
+
+         new_input_embeds = []
+         new_labels = []
+         cur_speech_idx = 0
+         for batch_idx, cur_input_ids in enumerate(input_ids):
+             num_speech = (cur_input_ids == SPEECH_TOKEN_INDEX).sum()
+             if num_speech == 0:
+                 cur_speech_features = speech_features[cur_speech_idx]
+                 cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
+                 cur_input_embeds = torch.cat([cur_input_embeds_1, cur_speech_features[0:0]], dim=0)
+                 new_input_embeds.append(cur_input_embeds)
+                 new_labels.append(labels[batch_idx])
+                 cur_speech_idx += 1
+                 continue
+
+             speech_token_indices = (
+                 [-1] + torch.where(cur_input_ids == SPEECH_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
+             )
+             cur_input_ids_nospeech = []
+             cur_labels = labels[batch_idx]
+             cur_labels_nospeech = []
+             for i in range(len(speech_token_indices) - 1):
+                 cur_input_ids_nospeech.append(cur_input_ids[speech_token_indices[i] + 1 : speech_token_indices[i + 1]])
+                 cur_labels_nospeech.append(cur_labels[speech_token_indices[i] + 1 : speech_token_indices[i + 1]])
+             split_sizes = [x.shape[0] for x in cur_labels_nospeech]
+             cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_nospeech))
+             cur_input_embeds_no_speech = torch.split(cur_input_embeds, split_sizes, dim=0)
+             cur_new_input_embeds = []
+             cur_new_labels = []
+
+             for i in range(num_speech + 1):
+                 cur_new_input_embeds.append(cur_input_embeds_no_speech[i])
+                 cur_new_labels.append(cur_labels_nospeech[i])
+                 if i < num_speech:
+                     cur_speech_features = speech_features[cur_speech_idx]
+                     cur_speech_idx += 1
+                     cur_new_input_embeds.append(cur_speech_features)
+                     cur_new_labels.append(
+                         torch.full(
+                             (cur_speech_features.shape[0],),
+                             IGNORE_INDEX,
+                             device=cur_labels.device,
+                             dtype=cur_labels.dtype,
+                         )
+                     )
+
+             cur_new_input_embeds_stack = [x.to(input_ids[0].device) for x in cur_new_input_embeds]
+
+             cur_new_input_embeds_tensor = torch.cat(cur_new_input_embeds_stack)
+             cur_new_labels_tensor = torch.cat(cur_new_labels)
+
+             new_input_embeds.append(cur_new_input_embeds_tensor)
+             new_labels.append(cur_new_labels_tensor)
+
+         # Truncate sequences to max length as speech features can make the sequence longer
+         tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
+         if tokenizer_model_max_length is not None:
+             new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
+             new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
+
+         # Combine them
+         max_len = max(x.shape[0] for x in new_input_embeds)
+         batch_size = len(new_input_embeds)
+
+         new_input_embeds_padded = []
+         new_labels_padded = torch.full(
+             (batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device
+         )
+         attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
+         position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
+
+         for i, (cur_new_embed, cur_new_labels_loop) in enumerate(zip(new_input_embeds, new_labels)):
+             cur_len = cur_new_embed.shape[0]
+             if getattr(self.config, "tokenizer_padding_side", "right") == "left":
+                 new_input_embeds_padded.append(
+                     torch.cat(
+                         (
+                             torch.zeros(
+                                 (max_len - cur_len, cur_new_embed.shape[1]),
+                                 dtype=cur_new_embed.dtype,
+                                 device=cur_new_embed.device,
+                             ),
+                             cur_new_embed,
+                         ),
+                         dim=0,
+                     )
+                 )
+                 if cur_len > 0:
+                     new_labels_padded[i, -cur_len:] = cur_new_labels_loop
+                     attention_mask[i, -cur_len:] = True
+                     position_ids[i, -cur_len:] = torch.arange(
+                         0, cur_len, dtype=position_ids.dtype, device=position_ids.device
+                     )
+             else:
+                 new_input_embeds_padded.append(
+                     torch.cat(
+                         (
+                             cur_new_embed,
+                             torch.zeros(
+                                 (max_len - cur_len, cur_new_embed.shape[1]),
+                                 dtype=cur_new_embed.dtype,
+                                 device=cur_new_embed.device,
+                             ),
+                         ),
+                         dim=0,
+                     )
+                 )
+                 if cur_len > 0:
+                     new_labels_padded[i, :cur_len] = cur_new_labels_loop
+                     attention_mask[i, :cur_len] = True
+                     position_ids[i, :cur_len] = torch.arange(
+                         0, cur_len, dtype=position_ids.dtype, device=position_ids.device
+                     )
+
+         new_input_embeds_tensor = torch.stack(new_input_embeds_padded, dim=0)
+
+         if _labels is None:
+             new_labels_new = None
+         else:
+             new_labels_new = new_labels_padded
+
+         if _attention_mask is None:
+             attention_mask_new = None
+         else:
+             attention_mask_new = attention_mask.to(dtype=_attention_mask.dtype)
+
+         if _position_ids is None:
+             position_ids = None
+
+         return None, position_ids, attention_mask_new, past_key_values, new_input_embeds_tensor, new_labels_new
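
The heart of prepare_inputs_labels_for_speech_and_text is the splice loop: every SPEECH_TOKEN_INDEX placeholder in a token sequence is replaced by the projected speech features, and the matching label positions are filled with IGNORE_INDEX so the speech frames never contribute to the loss. A toy, single-sequence sketch of that splice; the -200 and -100 constant values are assumptions about llama_omni/constants.py, and the sizes are made up:

import torch

SPEECH_TOKEN_INDEX = -200  # assumed placeholder id (see llama_omni/constants.py)
IGNORE_INDEX = -100        # assumed ignore-label value (see llama_omni/constants.py)

# One sequence [tok, tok, <speech>, tok] plus 3 frames of projected speech features.
input_ids = torch.tensor([11, 12, SPEECH_TOKEN_INDEX, 13])
labels = input_ids.clone()
speech_features = torch.randn(3, 8)        # (frames, hidden): stand-in for the projector output
embed_tokens = torch.nn.Embedding(100, 8)  # stand-in for the LLM's embedding table

pos = (input_ids == SPEECH_TOKEN_INDEX).nonzero().item()
spliced_embeds = torch.cat([embed_tokens(input_ids[:pos]), speech_features, embed_tokens(input_ids[pos + 1:])])
spliced_labels = torch.cat([labels[:pos], torch.full((3,), IGNORE_INDEX), labels[pos + 1:]])
print(spliced_embeds.shape)  # torch.Size([6, 8]): 3 text embeddings + 3 speech frames
print(spliced_labels)        # tensor([  11,   12, -100, -100, -100,   13])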

helm/clients/audio_language/llama_omni/model/speech_encoder/builder.py
@@ -0,0 +1,9 @@
+ from helm.clients.audio_language.llama_omni.model.speech_encoder.speech_encoder import WhisperWrappedEncoder
+
+
+ def build_speech_encoder(config):
+     speech_encoder_type = getattr(config, "speech_encoder_type", "none")
+     if "whisper" in speech_encoder_type.lower():
+         return WhisperWrappedEncoder.load(config)
+
+     raise ValueError(f"Unknown speech encoder: {speech_encoder_type}")

helm/clients/audio_language/llama_omni/model/speech_encoder/speech_encoder.py
@@ -0,0 +1,27 @@
+ # Adopted from https://github.com/ddlBoJack/SLAM-LLM/blob/main/src/slam_llm/models/encoder.py
+ import torch.nn as nn
+ import whisper
+
+
+ class WhisperWrappedEncoder:
+
+     @classmethod
+     def load(cls, model_config):
+
+         def replace_layer_norm(module):
+             from whisper.model import LayerNorm
+
+             for name, child in module.named_children():
+                 if isinstance(child, LayerNorm):
+                     old_params = child.state_dict()
+                     new_layer_norm = nn.LayerNorm(
+                         child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine
+                     )
+                     new_layer_norm.load_state_dict(old_params)
+                     setattr(module, name, new_layer_norm)
+                 else:
+                     replace_layer_norm(child)
+
+         encoder = whisper.load_model(name="large-v3", device="cpu").encoder
+         replace_layer_norm(encoder)
+         return encoder
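
Here replace_layer_norm swaps whisper's own LayerNorm subclass, which forces a float32 forward pass, for a plain torch.nn.LayerNorm carrying the same weights, so the encoder follows the dtype of the surrounding model rather than upcasting. A small usage sketch, assuming the openai-whisper package is installed (load_model downloads the large-v3 weights on first use):

# Usage sketch under stated assumptions; not part of the diff above.
import torch

encoder = WhisperWrappedEncoder.load(model_config=None)  # load() ignores its argument
mel = torch.randn(1, 128, 3000)  # (batch, n_mels, frames); large-v3 expects 128 mel bins over 30 s
with torch.no_grad():
    features = encoder(mel)
print(features.shape)  # expected torch.Size([1, 1500, 1280]); the conv stem halves the frame count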

helm/clients/audio_language/llama_omni/model/speech_generator/builder.py
@@ -0,0 +1,9 @@
+ from helm.clients.audio_language.llama_omni.model.speech_generator.speech_generator import SpeechGeneratorCTC
+
+
+ def build_speech_generator(config):
+     generator_type = getattr(config, "speech_generator_type", "ctc")
+     if generator_type == "ctc":
+         return SpeechGeneratorCTC(config)
+
+     raise ValueError(f"Unknown generator type: {generator_type}")
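
Both builder modules follow the same tiny dispatch-on-config-attribute pattern. A sketch of exercising the error path with a stand-in config object; SimpleNamespace is illustrative only, as the real configs are transformers PretrainedConfig instances:

# Sketch only: demonstrates the builder's dispatch, not a supported configuration.
from types import SimpleNamespace

try:
    build_speech_generator(SimpleNamespace(speech_generator_type="hifigan"))
except ValueError as err:
    print(err)  # Unknown generator type: hifigan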