crfm-helm 0.5.7__py3-none-any.whl → 0.5.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of crfm-helm might be problematic according to the registry's automated checks.
- {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/METADATA +5 -77
- {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/RECORD +228 -197
- helm/benchmark/adaptation/adapter_spec.py +5 -0
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +11 -3
- helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +11 -8
- helm/benchmark/annotation/aci_bench_annotator.py +11 -22
- helm/benchmark/annotation/alrage_annotator.py +90 -0
- helm/benchmark/annotation/chw_care_plan_annotator.py +10 -21
- helm/benchmark/annotation/dischargeme_annotator.py +11 -22
- helm/benchmark/annotation/med_dialog_annotator.py +11 -22
- helm/benchmark/annotation/medalign_annotator.py +11 -22
- helm/benchmark/annotation/medi_qa_annotator.py +11 -22
- helm/benchmark/annotation/medication_qa_annotator.py +11 -22
- helm/benchmark/annotation/mental_health_annotator.py +11 -22
- helm/benchmark/annotation/mimic_bhc_annotator.py +11 -22
- helm/benchmark/annotation/mimic_rrs_annotator.py +11 -22
- helm/benchmark/annotation/model_as_judge.py +23 -18
- helm/benchmark/annotation/mtsamples_procedures_annotator.py +11 -22
- helm/benchmark/annotation/mtsamples_replicate_annotator.py +11 -22
- helm/benchmark/annotation/starr_patient_instructions_annotator.py +11 -22
- helm/benchmark/metrics/air_bench_metrics.py +3157 -1
- helm/benchmark/metrics/alrage_metric.py +35 -0
- helm/benchmark/metrics/basic_metrics.py +267 -2
- helm/benchmark/metrics/classification_metrics.py +19 -1
- helm/benchmark/metrics/conv_fin_qa_calc_metrics.py +12 -1
- helm/benchmark/metrics/dry_run_metrics.py +30 -1
- helm/benchmark/metrics/efficiency_metrics.py +74 -0
- helm/benchmark/metrics/ehr_sql_metrics.py +57 -1
- helm/benchmark/metrics/evaluate_reference_metrics.py +299 -0
- helm/benchmark/metrics/gpqa_chain_of_thought_metric.py +13 -1
- helm/benchmark/metrics/helpdesk_call_summarization_metrics.py +13 -1
- helm/benchmark/metrics/ifeval_metrics.py +13 -1
- helm/benchmark/metrics/instruction_following_critique_metrics.py +41 -1
- helm/benchmark/metrics/kpi_edgar_metrics.py +21 -0
- helm/benchmark/metrics/language_modeling_metrics.py +13 -1
- helm/benchmark/metrics/live_qa_metrics.py +13 -1
- helm/benchmark/metrics/llm_jury_metrics.py +13 -1
- helm/benchmark/metrics/medcalc_bench_metrics.py +14 -1
- helm/benchmark/metrics/medec_metrics.py +25 -2
- helm/benchmark/metrics/metric.py +25 -0
- helm/benchmark/metrics/mimiciv_billing_code_metrics.py +32 -1
- helm/benchmark/metrics/omni_math_metrics.py +13 -1
- helm/benchmark/metrics/seahelm_metrics.py +14 -1
- helm/benchmark/metrics/summac/model_summac.py +2 -2
- helm/benchmark/metrics/summarization_metrics.py +129 -1
- helm/benchmark/metrics/toxicity_metrics.py +31 -1
- helm/benchmark/metrics/wildbench_metrics.py +21 -1
- helm/benchmark/presentation/schema.py +5 -22
- helm/benchmark/presentation/summarize.py +180 -11
- helm/benchmark/presentation/taxonomy_info.py +20 -0
- helm/benchmark/run_expander.py +4 -0
- helm/benchmark/run_specs/arabic_run_specs.py +134 -16
- helm/benchmark/run_specs/bluex_run_specs.py +1 -1
- helm/benchmark/run_specs/classic_run_specs.py +2 -2
- helm/benchmark/run_specs/long_context_run_specs.py +2 -2
- helm/benchmark/run_specs/medhelm/__init__.py +0 -0
- helm/benchmark/run_specs/medhelm/benchmark_config.py +219 -0
- helm/benchmark/run_specs/medhelm_run_specs.py +360 -50
- helm/benchmark/scenarios/aci_bench_scenario.py +23 -0
- helm/benchmark/scenarios/air_bench_scenario.py +21 -0
- helm/benchmark/scenarios/alrage_scenario.py +54 -0
- helm/benchmark/scenarios/anthropic_hh_rlhf_scenario.py +23 -1
- helm/benchmark/scenarios/arabic_exams_scenario.py +114 -0
- helm/benchmark/scenarios/arabic_mmlu_scenario.py +8 -4
- helm/benchmark/scenarios/aratrust_scenario.py +19 -0
- helm/benchmark/scenarios/babi_qa_scenario.py +15 -0
- helm/benchmark/scenarios/bbq_scenario.py +15 -0
- helm/benchmark/scenarios/best_chatgpt_prompts.yaml +473 -0
- helm/benchmark/scenarios/bluex_scenario.py +6 -2
- helm/benchmark/scenarios/bold_scenario.py +15 -0
- helm/benchmark/scenarios/boolq_scenario.py +20 -0
- helm/benchmark/scenarios/chw_care_plan_scenario.py +23 -0
- helm/benchmark/scenarios/civil_comments_scenario.py +13 -0
- helm/benchmark/scenarios/clear_scenario.py +23 -0
- helm/benchmark/scenarios/cleva_scenario.py +479 -0
- helm/benchmark/scenarios/code_scenario.py +28 -0
- helm/benchmark/scenarios/commonsense_scenario.py +26 -0
- helm/benchmark/scenarios/compositional_instructions.yaml +70 -0
- helm/benchmark/scenarios/conv_fin_qa_calc_scenario.py +21 -0
- helm/benchmark/scenarios/copyright_scenario.py +35 -1
- helm/benchmark/scenarios/cti_to_mitre_scenario.py +21 -0
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +22 -1
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +23 -1
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +22 -1
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +21 -1
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +13 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +13 -1
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +13 -1
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +13 -1
- helm/benchmark/scenarios/dischargeme_scenario.py +24 -0
- helm/benchmark/scenarios/disinformation_scenario.py +22 -0
- helm/benchmark/scenarios/dyck_language_scenario.py +15 -0
- helm/benchmark/scenarios/ehrshot_scenario.py +22 -0
- helm/benchmark/scenarios/enem_challenge_scenario.py +19 -0
- helm/benchmark/scenarios/entity_data_imputation_scenario.py +14 -0
- helm/benchmark/scenarios/entity_matching_scenario.py +14 -0
- helm/benchmark/scenarios/financial_phrasebank_scenario.py +21 -0
- helm/benchmark/scenarios/gold_commodity_news_scenario.py +21 -0
- helm/benchmark/scenarios/gpqa_scenario.py +18 -0
- helm/benchmark/scenarios/grammar_scenario.py +20 -1
- helm/benchmark/scenarios/gsm_scenario.py +15 -0
- helm/benchmark/scenarios/headqa_scenario.py +22 -0
- helm/benchmark/scenarios/helpdesk_call_summarization_scenario.py +13 -0
- helm/benchmark/scenarios/ice_scenario.py +21 -1
- helm/benchmark/scenarios/ifeval_scenario.py +18 -0
- helm/benchmark/scenarios/imdb_scenario.py +15 -0
- helm/benchmark/scenarios/koala_scenario.py +21 -1
- helm/benchmark/scenarios/kpi_edgar_scenario.py +21 -0
- helm/benchmark/scenarios/legal_contract_summarization_scenario.py +20 -0
- helm/benchmark/scenarios/legal_summarization_scenario.py +50 -0
- helm/benchmark/scenarios/legal_support_scenario.py +13 -0
- helm/benchmark/scenarios/legalbench_scenario.py +20 -0
- helm/benchmark/scenarios/lex_glue_scenario.py +11 -0
- helm/benchmark/scenarios/lextreme_scenario.py +11 -0
- helm/benchmark/scenarios/lsat_qa_scenario.py +14 -0
- helm/benchmark/scenarios/madinah_qa_scenario.py +73 -0
- helm/benchmark/scenarios/math_scenario.py +26 -0
- helm/benchmark/scenarios/mbzuai_human_translated_arabic_mmlu.py +68 -0
- helm/benchmark/scenarios/med_dialog_scenario.py +32 -1
- helm/benchmark/scenarios/med_mcqa_scenario.py +14 -0
- helm/benchmark/scenarios/med_qa_scenario.py +14 -0
- helm/benchmark/scenarios/medalign_scenario.py +23 -0
- helm/benchmark/scenarios/medbullets_scenario.py +22 -0
- helm/benchmark/scenarios/medcalc_bench_scenario.py +22 -0
- helm/benchmark/scenarios/medec_scenario.py +23 -0
- helm/benchmark/scenarios/medhallu_scenario.py +23 -0
- helm/benchmark/scenarios/medhelm/__init__.py +0 -0
- helm/benchmark/scenarios/medhelm/judges.yaml +14 -0
- helm/benchmark/scenarios/medhelm_configurable_scenario.py +101 -0
- helm/benchmark/scenarios/medi_qa_scenario.py +23 -0
- helm/benchmark/scenarios/medication_qa_scenario.py +31 -1
- helm/benchmark/scenarios/mental_health_scenario.py +23 -0
- helm/benchmark/scenarios/mimic_bhc_scenario.py +24 -0
- helm/benchmark/scenarios/mimic_rrs_scenario.py +23 -0
- helm/benchmark/scenarios/mimiciv_billing_code_scenario.py +22 -0
- helm/benchmark/scenarios/mmlu_pro_scenario.py +18 -0
- helm/benchmark/scenarios/mmlu_scenario.py +15 -0
- helm/benchmark/scenarios/msmarco_scenario.py +30 -0
- helm/benchmark/scenarios/mtsamples_procedures_scenario.py +22 -0
- helm/benchmark/scenarios/mtsamples_replicate_scenario.py +22 -0
- helm/benchmark/scenarios/n2c2_ct_matching_scenario.py +20 -0
- helm/benchmark/scenarios/narrativeqa_scenario.py +20 -0
- helm/benchmark/scenarios/natural_qa_scenario.py +32 -0
- helm/benchmark/scenarios/omni_math_scenario.py +18 -0
- helm/benchmark/scenarios/open_assistant_scenario.py +22 -0
- helm/benchmark/scenarios/pubmed_qa_scenario.py +22 -0
- helm/benchmark/scenarios/quac_scenario.py +14 -0
- helm/benchmark/scenarios/race_based_med_scenario.py +23 -0
- helm/benchmark/scenarios/raft_scenario.py +15 -0
- helm/benchmark/scenarios/real_toxicity_prompts_scenario.py +14 -1
- helm/benchmark/scenarios/scenario.py +31 -0
- helm/benchmark/scenarios/seahelm_scenario.py +348 -0
- helm/benchmark/scenarios/self_instruct_scenario.py +29 -1
- helm/benchmark/scenarios/shc_bmt_scenario.py +22 -0
- helm/benchmark/scenarios/shc_cdi_scenario.py +20 -0
- helm/benchmark/scenarios/shc_conf_scenario.py +23 -0
- helm/benchmark/scenarios/shc_ent_scenario.py +21 -0
- helm/benchmark/scenarios/shc_gip_scenario.py +20 -0
- helm/benchmark/scenarios/shc_privacy_scenario.py +22 -0
- helm/benchmark/scenarios/shc_proxy_scenario.py +22 -0
- helm/benchmark/scenarios/shc_ptbm_scenario.py +23 -0
- helm/benchmark/scenarios/shc_sequoia_scenario.py +21 -0
- helm/benchmark/scenarios/situation_prompts.yaml +49 -0
- helm/benchmark/scenarios/starr_patient_instructions_scenario.py +22 -0
- helm/benchmark/scenarios/summarization_scenario.py +37 -0
- helm/benchmark/scenarios/synthetic_efficiency_scenario.py +22 -1
- helm/benchmark/scenarios/synthetic_reasoning_natural_scenario.py +13 -0
- helm/benchmark/scenarios/test_alrage_scenario.py +23 -0
- helm/benchmark/scenarios/test_arabic_exams_scenario.py +21 -0
- helm/benchmark/scenarios/test_aratrust_scenario.py +1 -1
- helm/benchmark/scenarios/test_bluex_scenario.py +2 -2
- helm/benchmark/scenarios/the_pile_scenario.py +13 -1
- helm/benchmark/scenarios/truthful_qa_scenario.py +14 -0
- helm/benchmark/scenarios/twitter_aae_scenario.py +20 -1
- helm/benchmark/scenarios/vicuna_scenario.py +21 -1
- helm/benchmark/scenarios/wikifact_scenario.py +20 -0
- helm/benchmark/scenarios/wildbench_scenario.py +18 -0
- helm/benchmark/scenarios/wmt_14_scenario.py +12 -0
- helm/benchmark/static/schema_arabic.yaml +55 -12
- helm/benchmark/static/schema_long_context.yaml +17 -17
- helm/benchmark/static/schema_medhelm.yaml +36 -0
- helm/benchmark/static/schema_slp.yaml +219 -0
- helm/benchmark/static_build/assets/index-671a5e06.js +10 -0
- helm/benchmark/static_build/assets/index-9352595e.css +1 -0
- helm/benchmark/static_build/index.html +2 -2
- helm/clients/audio_language/llama_omni/arguments.py +61 -0
- helm/clients/audio_language/llama_omni/constants.py +9 -0
- helm/clients/audio_language/llama_omni/conversation.py +213 -0
- helm/clients/audio_language/llama_omni/model/__init__.py +0 -0
- helm/clients/audio_language/llama_omni/model/builder.py +88 -0
- helm/clients/audio_language/llama_omni/model/language_model/omni_speech2s_llama.py +190 -0
- helm/clients/audio_language/llama_omni/model/language_model/omni_speech_llama.py +118 -0
- helm/clients/audio_language/llama_omni/model/omni_speech_arch.py +249 -0
- helm/clients/audio_language/llama_omni/model/speech_encoder/builder.py +9 -0
- helm/clients/audio_language/llama_omni/model/speech_encoder/speech_encoder.py +27 -0
- helm/clients/audio_language/llama_omni/model/speech_generator/builder.py +9 -0
- helm/clients/audio_language/llama_omni/model/speech_generator/generation.py +622 -0
- helm/clients/audio_language/llama_omni/model/speech_generator/speech_generator.py +104 -0
- helm/clients/audio_language/llama_omni/model/speech_projector/builder.py +9 -0
- helm/clients/audio_language/llama_omni/model/speech_projector/speech_projector.py +27 -0
- helm/clients/audio_language/llama_omni/preprocess.py +295 -0
- helm/clients/audio_language/llama_omni/utils.py +202 -0
- helm/clients/audio_language/qwen_omni/configuration_qwen2_5_omni.py +519 -0
- helm/clients/audio_language/qwen_omni/modeling_qwen2_5_omni.py +4308 -0
- helm/clients/audio_language/qwen_omni/processing_qwen2_5_omni.py +270 -0
- helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/__init__.py +0 -0
- helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/__init__.py +8 -0
- helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/audio_process.py +56 -0
- helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/vision_process.py +380 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +1 -1
- helm/clients/image_generation/mindalle/models/stage1/layers.py +2 -2
- helm/clients/openai_client.py +31 -19
- helm/clients/openai_responses_client.py +27 -3
- helm/clients/openrouter_client.py +31 -0
- helm/clients/test_openrouter_client.py +69 -0
- helm/clients/together_client.py +48 -11
- helm/clients/vertexai_client.py +8 -2
- helm/config/model_deployments.yaml +75 -1
- helm/config/model_metadata.yaml +70 -2
- helm/config/tokenizer_configs.yaml +19 -1
- helm/proxy/example_queries.py +8 -8
- helm/proxy/server.py +2 -1
- helm/proxy/static/index.css +4 -0
- helm/proxy/static/index.js +7 -1
- helm/benchmark/metrics/aci_bench_metrics.py +0 -14
- helm/benchmark/metrics/chw_care_plan_metrics.py +0 -14
- helm/benchmark/metrics/dischargeme_metrics.py +0 -14
- helm/benchmark/metrics/med_dialog_metrics.py +0 -14
- helm/benchmark/metrics/medalign_metrics.py +0 -14
- helm/benchmark/metrics/medi_qa_metrics.py +0 -14
- helm/benchmark/metrics/medication_qa_metrics.py +0 -14
- helm/benchmark/metrics/mental_health_metrics.py +0 -14
- helm/benchmark/metrics/mimic_bhc_metrics.py +0 -14
- helm/benchmark/metrics/mimic_rrs_metrics.py +0 -14
- helm/benchmark/metrics/mtsamples_procedures_metrics.py +0 -14
- helm/benchmark/metrics/mtsamples_replicate_metrics.py +0 -14
- helm/benchmark/metrics/starr_patient_instructions_metrics.py +0 -14
- helm/benchmark/static_build/assets/index-b9779128.css +0 -1
- helm/benchmark/static_build/assets/index-e439d5e1.js +0 -10
- {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/WHEEL +0 -0
- {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/licenses/LICENSE +0 -0
- {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/top_level.txt +0 -0

helm/benchmark/static_build/index.html
@@ -7,11 +7,11 @@
     <title>Holistic Evaluation of Language Models (HELM)</title>
     <meta name="description" content="The Holistic Evaluation of Language Models (HELM) serves as a living benchmark for transparency in language models. Providing broad coverage and recognizing incompleteness, multi-metric measurements, and standardization. All data and analysis are freely accessible on the website for exploration and study." />
     <script type="text/javascript" src="./config.js"></script>
-    <script type="module" crossorigin src="./assets/index-e439d5e1.js"></script>
+    <script type="module" crossorigin src="./assets/index-671a5e06.js"></script>
     <link rel="modulepreload" crossorigin href="./assets/react-f82877fd.js">
     <link rel="modulepreload" crossorigin href="./assets/recharts-4037aff0.js">
     <link rel="modulepreload" crossorigin href="./assets/tremor-38a10867.js">
-    <link rel="stylesheet" href="./assets/index-b9779128.css">
+    <link rel="stylesheet" href="./assets/index-9352595e.css">
   </head>
   <body class="block">
     <div id="root"></div>

helm/clients/audio_language/llama_omni/arguments.py
@@ -0,0 +1,61 @@
+import transformers
+
+from dataclasses import dataclass, field
+from typing import Optional
+
+
+@dataclass
+class ModelArguments:
+    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+    version: Optional[str] = field(default="v0")
+    freeze_backbone: bool = field(default=False)
+    tune_speech_projector: bool = field(default=False)
+    tune_speech_encoder: bool = field(default=False)
+    tune_speech_generator_only: bool = field(default=False)
+    speech_encoder_type: Optional[str] = field(default=None)
+    speech_encoder: Optional[str] = field(default=None)
+    pretrain_speech_projector: Optional[str] = field(default=None)
+    speech_projector_type: Optional[str] = field(default="linear")
+    speech_generator_type: Optional[str] = field(default="ctc")
+    ctc_decoder_config: str = "(2,4096,32,11008)"
+    ctc_upsample_factor: int = 1
+    ctc_loss_weight: float = 1.0
+    unit_vocab_size: int = 1000
+    speech_encoder_ds_rate: int = 5
+    speech_encoder_hidden_size: int = 1280
+
+
+@dataclass
+class DataArguments:
+    data_path: str = field(default="", metadata={"help": "Path to the training data."})
+    is_multimodal: bool = False
+    input_type: str = field(default="mel")
+    speech_normalize: bool = False
+    mel_size: int = 128
+    has_tgt_units: bool = False
+
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+    cache_dir: Optional[str] = field(default=None)
+    optim: str = field(default="adamw_torch")
+    freeze_speech_projector: bool = field(default=False)
+    model_max_length: int = field(
+        default=512,
+        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
+    )
+    double_quant: bool = field(
+        default=True, metadata={"help": "Compress the quantization statistics through double quantization."}
+    )
+    quant_type: str = field(
+        default="nf4", metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
+    )
+    bits: int = field(default=16, metadata={"help": "How many bits to use."})
+    lora_enable: bool = False
+    lora_r: int = 64
+    lora_alpha: int = 16
+    lora_dropout: float = 0.05
+    lora_weight_path: str = ""
+    lora_bias: str = "none"
+    speech_projector_lr: Optional[float] = None
+    group_by_modality_length: bool = field(default=False)
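For orientation only (not part of the diff): these dataclasses follow the usual transformers argument-parsing pattern, so they would typically be consumed with HfArgumentParser. A minimal sketch, assuming the package is installed; the CLI values below (data path, output dir, max length) are hypothetical.

# Minimal sketch: parsing the LLaMA-Omni argument dataclasses with HfArgumentParser.
from transformers import HfArgumentParser

from helm.clients.audio_language.llama_omni.arguments import (
    ModelArguments,
    DataArguments,
    TrainingArguments,
)

parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
# Hypothetical CLI values; any transformers.TrainingArguments flag also works here.
model_args, data_args, training_args = parser.parse_args_into_dataclasses(
    args=[
        "--model_name_or_path", "facebook/opt-125m",
        "--data_path", "data/train.json",
        "--output_dir", "out",            # required by transformers.TrainingArguments
        "--model_max_length", "2048",
    ]
)
print(model_args.speech_generator_type)   # "ctc" by default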

helm/clients/audio_language/llama_omni/conversation.py
@@ -0,0 +1,213 @@
+# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+from enum import auto, Enum
+from typing import List, Any, Union, Optional
+
+
+class SeparatorStyle(Enum):
+    """Different separator style."""
+
+    TWO = auto()
+    PLAIN = auto()
+    LLAMA_2 = auto()
+    LLAMA_3 = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+    """A class that keeps all conversation history."""
+
+    system: str
+    roles: List[str]
+    messages: List[List[str]]
+    offset: int
+    sep_style: SeparatorStyle = SeparatorStyle.PLAIN
+    sep: str = "###"
+    sep2: str = ""
+    version: str = "Unknown"
+
+    tokenizer_id: str = ""
+    tokenizer: Any = None
+    # Stop criteria (the default one is EOS token)
+    stop_str: Optional[Union[str, List[str]]] = None
+    # Stops generation if meeting any token in this list
+    stop_token_ids: Optional[List[int]] = None
+
+    skip_next: bool = False
+
+    def get_prompt(self):
+        messages = self.messages
+
+        if self.sep_style == SeparatorStyle.TWO:
+            seps = [self.sep, self.sep2]
+            ret = self.system + seps[0]
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message = message[0]
+                    ret += role + ": " + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+        elif self.sep_style == SeparatorStyle.LLAMA_3:
+            wrap_sys = lambda msg: (
+                f"<|start_header_id|>system<|end_header_id|>\n\n{msg}<|eot_id|>" if len(msg) > 0 else msg
+            )
+            ret = "<|begin_of_text|>" + wrap_sys(self.system)
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message = message[0]
+                    ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
+                    ret += message.strip() + self.sep2
+                else:
+                    ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
+            return ret
+        elif self.sep_style == SeparatorStyle.LLAMA_2:
+            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
+            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
+            ret = ""
+
+            for i, (role, message) in enumerate(messages):
+                if i == 0:
+                    assert message, "first message should not be none"
+                    assert role == self.roles[0], "first message should come from user"
+                if message:
+                    if type(message) is tuple:
+                        message = message[0]
+                    if i == 0:
+                        message = wrap_sys(self.system) + message
+                    if i % 2 == 0:
+                        message = wrap_inst(message)
+                        ret += self.sep + message
+                    else:
+                        ret += " " + message + " " + self.sep2
+                else:
+                    ret += ""
+            ret = ret.lstrip(self.sep)
+        elif self.sep_style == SeparatorStyle.PLAIN:
+            seps = [self.sep, self.sep2]
+            ret = self.system
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message = message[0]
+                    ret += message + seps[i % 2]
+                else:
+                    ret += ""
+        else:
+            raise ValueError(f"Invalid style: {self.sep_style}")
+
+        return ret
+
+    def append_message(self, role, message):
+        self.messages.append([role, message])
+
+    def to_gradio_chatbot(self):
+        ret = []
+        for i, (role, msg) in enumerate(self.messages[self.offset :]):
+            if i % 2 == 0:
+                if type(msg) is tuple:
+                    msg = msg[0]
+                    ret.append([msg, None])
+                else:
+                    ret.append([msg, None])
+            else:
+                ret[-1][-1] = msg
+        return ret
+
+    def copy(self):
+        return Conversation(
+            system=self.system,
+            roles=self.roles,
+            messages=[[x, y] for x, y in self.messages],
+            offset=self.offset,
+            sep_style=self.sep_style,
+            sep=self.sep,
+            sep2=self.sep2,
+            version=self.version,
+        )
+
+    def dict(self):
+        return {
+            "system": self.system,
+            "roles": self.roles,
+            "messages": self.messages,
+            "offset": self.offset,
+            "sep": self.sep,
+            "sep2": self.sep2,
+        }
+
+
+conv_vicuna_v1 = Conversation(
+    system="A chat between a curious user and an artificial intelligence assistant. "
+    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+    roles=["USER", "ASSISTANT"],
+    version="v1",
+    messages=[],
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="</s>",
+)
+
+conv_llama_2 = Conversation(
+    system="You are a helpful language and speech assistant. "
+    "You are able to understand the speech content that the user provides, "
+    "and assist the user with a variety of tasks using natural language.",
+    roles=["USER", "ASSISTANT"],
+    version="llama_v2",
+    messages=[],
+    offset=0,
+    sep_style=SeparatorStyle.LLAMA_2,
+    sep="<s>",
+    sep2="</s>",
+)
+
+conv_llama_3 = Conversation(
+    system="You are a helpful language and speech assistant. "
+    "You are able to understand the speech content that the user provides, "
+    "and assist the user with a variety of tasks using natural language.",
+    roles=["user", "assistant"],
+    version="llama_v3",
+    messages=[],
+    offset=0,
+    sep_style=SeparatorStyle.LLAMA_3,
+    sep="",
+    sep2="<|eot_id|>",
+)
+
+conv_plain = Conversation(
+    system="",
+    roles=["", ""],
+    messages=[],
+    offset=0,
+    sep_style=SeparatorStyle.PLAIN,
+    sep="</s>",
+)
+
+
+default_conversation = conv_llama_3
+conv_templates = {
+    "v1": conv_vicuna_v1,
+    "plain": conv_plain,
+    "llama_2": conv_llama_2,
+    "llama_3": conv_llama_3,
+}
+
+
+if __name__ == "__main__":
+    print(default_conversation.get_prompt())
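For orientation only (not part of the diff): a minimal usage sketch of the templates defined above — copy a template, append turns, render the prompt. The "<speech>" placeholder text is illustrative and follows LLaMA-Omni's convention, not something defined in this file.

# Minimal sketch: building a Llama-3-style prompt from the conversation templates.
from helm.clients.audio_language.llama_omni.conversation import conv_templates

conv = conv_templates["llama_3"].copy()
conv.append_message(conv.roles[0], "<speech>\nPlease answer the question in the user's speech.")
conv.append_message(conv.roles[1], None)  # leave the assistant slot open for generation
prompt = conv.get_prompt()
# The rendered string starts with "<|begin_of_text|>", wraps the system prompt in
# system headers, closes the user turn with "<|eot_id|>", and ends with an open
# "<|start_header_id|>assistant<|end_header_id|>" header ready for decoding.
print(prompt)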
File without changes

helm/clients/audio_language/llama_omni/model/builder.py
@@ -0,0 +1,88 @@
+import os
+
+from transformers import AutoTokenizer, AutoConfig, BitsAndBytesConfig
+import torch
+from helm.clients.audio_language.llama_omni.model.language_model.omni_speech_llama import OmniSpeechLlamaForCausalLM
+from helm.clients.audio_language.llama_omni.model.language_model.omni_speech2s_llama import OmniSpeech2SLlamaForCausalLM
+from helm.clients.audio_language.llama_omni.model.speech_encoder.builder import build_speech_encoder
+
+
+def load_pretrained_model(
+    model_path,
+    model_base,
+    is_lora=False,
+    s2s=False,
+    load_8bit=False,
+    load_4bit=False,
+    device="cuda",
+    use_flash_attn=False,
+    **kwargs,
+):
+    if load_8bit:
+        kwargs["load_in_8bit"] = True
+    elif load_4bit:
+        kwargs["load_in_4bit"] = True
+        kwargs["quantization_config"] = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
+        )
+    else:
+        kwargs["torch_dtype"] = torch.float16
+
+    if use_flash_attn:
+        kwargs["attn_implementation"] = "flash_attention_2"
+
+    model_cls = OmniSpeech2SLlamaForCausalLM if s2s else OmniSpeechLlamaForCausalLM
+
+    # Load OmniSpeech model
+    if is_lora:
+        assert model_base is not None, "model_base is required for LoRA models."
+        from language_model.omni_speech_llama import OmniSpeechConfig
+
+        lora_cfg_pretrained = OmniSpeechConfig.from_pretrained(model_path)
+        tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+        print("Loading OmniSpeech from base model...")
+        model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=lora_cfg_pretrained, **kwargs)
+        print("Loading additional OmniSpeech weights...")
+        if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
+            non_lora_trainables = torch.load(os.path.join(model_path, "non_lora_trainables.bin"), map_location="cpu")
+            non_lora_trainables = {
+                (k[11:] if k.startswith("base_model.") else k): v for k, v in non_lora_trainables.items()
+            }
+            if any(k.startswith("model.model.") for k in non_lora_trainables):
+                non_lora_trainables = {(k[6:] if k.startswith("model.") else k): v for k, v in non_lora_trainables.items()}
+            model.load_state_dict(non_lora_trainables, strict=False)
+
+        from peft import PeftModel
+
+        print("Loading LoRA weights...")
+        model = PeftModel.from_pretrained(model, model_path)
+        print("Merging LoRA weights...")
+        model = model.merge_and_unload()
+        print("Model is loaded...")
+    elif model_base is not None:
+        print("Loading OmniSpeech from base model...")
+        tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+        cfg_pretrained = AutoConfig.from_pretrained(model_path)
+        model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=cfg_pretrained, **kwargs)
+
+        speech_projector_weights = torch.load(os.path.join(model_path, "speech_projector.bin"), map_location="cpu")
+        speech_projector_weights = {k: v.to(torch.float16) for k, v in speech_projector_weights.items()}
+        model.load_state_dict(speech_projector_weights, strict=False)
+        model = model.to(device=device)
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+        model = model_cls.from_pretrained(model_path, low_cpu_mem_usage=False, **kwargs)
+        model = model.to(device=device)
+
+    model.get_model().speech_encoder = build_speech_encoder(model.config)
+    model.get_model().speech_encoder.to(device=device, dtype=torch.float16)
+
+    if hasattr(model.config, "max_sequence_length"):
+        context_len = model.config.max_sequence_length
+    else:
+        context_len = 2048
+
+    return tokenizer, model, context_len
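For orientation only (not part of the diff): a hedged sketch of how load_pretrained_model would be called. The checkpoint name below is the upstream LLaMA-Omni release and is used purely as an illustrative model_path; any local path with the expected weights would work the same way.

# Minimal sketch: loading the speech-to-speech variant of the model.
from helm.clients.audio_language.llama_omni.model.builder import load_pretrained_model

tokenizer, model, context_len = load_pretrained_model(
    model_path="ICTNLP/Llama-3.1-8B-Omni",  # illustrative checkpoint name
    model_base=None,                        # full (non-LoRA) checkpoint, so no base model needed
    s2s=True,                               # selects OmniSpeech2SLlamaForCausalLM instead of OmniSpeechLlamaForCausalLM
    load_4bit=False,
    device="cuda",
)
print(type(model).__name__, context_len)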

helm/clients/audio_language/llama_omni/model/language_model/omni_speech2s_llama.py
@@ -0,0 +1,190 @@
+from typing import List, Optional, Tuple, Union, Callable
+
+import torch
+
+from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig
+
+from transformers import PreTrainedModel
+from transformers.generation.streamers import BaseStreamer
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.generation.utils import (
+    GenerationConfig,
+    LogitsProcessorList,
+    StoppingCriteriaList,
+)
+
+from helm.clients.audio_language.llama_omni.model.language_model.omni_speech_llama import OmniSpeechLlamaForCausalLM
+from helm.clients.audio_language.llama_omni.model.speech_generator.builder import build_speech_generator
+from helm.clients.audio_language.llama_omni.model.speech_generator.generation import GenerationWithCTC
+
+
+class OmniSpeech2SConfig(LlamaConfig):
+    model_type = "omni_speech2s_llama"
+
+
+class OmniSpeech2SLlamaForCausalLM(OmniSpeechLlamaForCausalLM, GenerationWithCTC):
+    config_class = OmniSpeech2SConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+        if hasattr(config, "speech_generator_type"):
+            self.speech_generator = build_speech_generator(config)
+
+    def initialize_speech_generator(self, model_args):
+        self.config.speech_generator_type = getattr(model_args, "speech_generator_type", "ctc")
+        self.config.ctc_decoder_config = getattr(model_args, "ctc_decoder_config", "(4,4096,32,11008)")
+        self.config.ctc_upsample_factor = getattr(model_args, "ctc_upsample_factor", 1)
+        self.config.ctc_loss_weight = getattr(model_args, "ctc_loss_weight", 1.0)
+        self.config.unit_vocab_size = getattr(model_args, "unit_vocab_size", 1000)
+        self.tune_speech_generator_only = getattr(model_args, "tune_speech_generator_only", False)
+        if getattr(self, "speech_generator", None) is None:
+            self.speech_generator = build_speech_generator(self.config)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        speech: Optional[torch.FloatTensor] = None,
+        speech_lengths: Optional[torch.LongTensor] = None,
+        tgt_units: Optional[torch.LongTensor] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+        if inputs_embeds is None:
+            (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = (
+                self.prepare_inputs_labels_for_speech_and_text(
+                    input_ids, position_ids, attention_mask, past_key_values, labels, speech, speech_lengths
+                )
+            )
+
+        if self.training:
+            if self.tune_speech_generator_only:
+                with torch.no_grad():
+                    llama_output = super(OmniSpeechLlamaForCausalLM, self).forward(
+                        input_ids=input_ids,
+                        attention_mask=attention_mask,
+                        position_ids=position_ids,
+                        past_key_values=past_key_values,
+                        inputs_embeds=inputs_embeds,
+                        labels=labels,
+                        use_cache=use_cache,
+                        output_attentions=output_attentions,
+                        output_hidden_states=True,
+                        return_dict=return_dict,
+                    )
+                loss = self.speech_generator(llama_output["hidden_states"][-1], labels, tgt_units)
+            else:
+                llama_output = super(OmniSpeechLlamaForCausalLM, self).forward(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_values=past_key_values,
+                    inputs_embeds=inputs_embeds,
+                    labels=labels,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                    output_hidden_states=True,
+                    return_dict=return_dict,
+                )
+                lm_loss = llama_output.loss
+                ctc_loss = self.speech_generator(llama_output["hidden_states"][-1], labels, tgt_units)
+                loss = lm_loss + ctc_loss * self.config.ctc_loss_weight
+        else:
+            llama_output = super(OmniSpeechLlamaForCausalLM, self).forward(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                labels=labels,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=True,
+                return_dict=return_dict,
+            )
+            loss = llama_output.loss
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=llama_output.logits,
+            past_key_values=llama_output.past_key_values,
+            hidden_states=llama_output.hidden_states,
+            attentions=llama_output.attentions,
+        )
+
+    @torch.no_grad()
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        speech: Optional[torch.Tensor] = None,
+        speech_lengths: Optional[torch.Tensor] = None,
+        generation_config: Optional[GenerationConfig] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+        synced_gpus: Optional[bool] = None,
+        assistant_model: Optional["PreTrainedModel"] = None,
+        streamer: Optional["BaseStreamer"] = None,
+        streamer_unit: Optional["BaseStreamer"] = None,
+        streaming_unit_gen=False,
+        negative_prompt_ids: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        position_ids = kwargs.pop("position_ids", None)
+        attention_mask = kwargs.pop("attention_mask", None)
+        if "inputs_embeds" in kwargs:
+            raise NotImplementedError("`inputs_embeds` is not supported")
+
+        if speech is not None:
+            (inputs, position_ids, attention_mask, _, inputs_embeds, _) = (
+                self.prepare_inputs_labels_for_speech_and_text(
+                    inputs, position_ids, attention_mask, None, None, speech, speech_lengths
+                )
+            )
+        else:
+            inputs_embeds = self.get_model().embed_tokens(inputs)
+        outputs = GenerationWithCTC.generate(
+            self,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            output_hidden_states=True,
+            return_dict_in_generate=True,
+            streaming_unit_gen=streaming_unit_gen,
+            **kwargs,
+        )
+
+        hidden_states = outputs["hidden_states"]
+        hidden_states = torch.cat(
+            [hidden_states[0][-1][:, -1:, :]] + [hidden_states[i][-1] for i in range(1, len(hidden_states))], dim=1
+        )
+        ctc_pred = self.speech_generator.predict(hidden_states.squeeze(0))
+
+        return outputs.sequences, ctc_pred
+
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+        speech = kwargs.pop("speech", None)
+        speech_lengths = kwargs.pop("speech_lengths", None)
+        inputs = super().prepare_inputs_for_generation(
+            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
+        )
+        if speech is not None:
+            inputs["speech"] = speech
+            inputs["speech_lengths"] = speech_lengths
+        return inputs
+
+
+AutoConfig.register("omni_speech2s_llama", OmniSpeech2SConfig)
+AutoModelForCausalLM.register(OmniSpeech2SConfig, OmniSpeech2SLlamaForCausalLM)
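For orientation only (not part of the diff): the two register calls at the bottom of the file let the standard Auto classes resolve the custom "omni_speech2s_llama" model_type once this module has been imported. A minimal sketch, assuming the package and its dependencies are installed; it also notes the generate() return shape shown above.

# Minimal sketch: resolving the registered config via the Auto API.
from transformers import AutoConfig
import helm.clients.audio_language.llama_omni.model.language_model.omni_speech2s_llama  # noqa: F401  (runs the register calls)

config = AutoConfig.for_model("omni_speech2s_llama")  # resolves to OmniSpeech2SConfig
print(type(config).__name__)

# generate() on OmniSpeech2SLlamaForCausalLM returns both the text token ids and the
# CTC-predicted discrete speech units; illustrative call, assuming a loaded model:
# sequences, ctc_pred = model.generate(input_ids, speech=speech, speech_lengths=lengths, max_new_tokens=256)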