crfm-helm 0.5.6__py3-none-any.whl → 0.5.8__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of crfm-helm might be problematic.
- {crfm_helm-0.5.6.dist-info → crfm_helm-0.5.8.dist-info}/METADATA +60 -125
- {crfm_helm-0.5.6.dist-info → crfm_helm-0.5.8.dist-info}/RECORD +293 -229
- helm/benchmark/adaptation/adapter_spec.py +5 -0
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +11 -3
- helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +11 -8
- helm/benchmark/annotation/aci_bench_annotator.py +11 -22
- helm/benchmark/annotation/air_bench_annotator.py +1 -1
- helm/benchmark/annotation/alrage_annotator.py +90 -0
- helm/benchmark/annotation/chw_care_plan_annotator.py +10 -21
- helm/benchmark/annotation/dischargeme_annotator.py +11 -22
- helm/benchmark/annotation/live_qa_annotator.py +1 -1
- helm/benchmark/annotation/med_dialog_annotator.py +11 -22
- helm/benchmark/annotation/medalign_annotator.py +11 -22
- helm/benchmark/annotation/medi_qa_annotator.py +11 -22
- helm/benchmark/annotation/medication_qa_annotator.py +11 -22
- helm/benchmark/annotation/mental_health_annotator.py +11 -22
- helm/benchmark/annotation/mimic_bhc_annotator.py +11 -22
- helm/benchmark/annotation/mimic_rrs_annotator.py +11 -22
- helm/benchmark/annotation/model_as_judge.py +23 -18
- helm/benchmark/annotation/mtsamples_procedures_annotator.py +11 -22
- helm/benchmark/annotation/mtsamples_replicate_annotator.py +11 -22
- helm/benchmark/annotation/starr_patient_instructions_annotator.py +11 -22
- helm/benchmark/metrics/air_bench_metrics.py +3157 -1
- helm/benchmark/metrics/alrage_metric.py +35 -0
- helm/benchmark/metrics/basic_metrics.py +267 -2
- helm/benchmark/metrics/classification_metrics.py +19 -1
- helm/benchmark/metrics/codeinsights_code_efficiency_metrics.py +186 -0
- helm/benchmark/metrics/codeinsights_code_evaluation_metrics.py +477 -0
- helm/benchmark/metrics/codeinsights_correct_code_metrics.py +366 -0
- helm/benchmark/metrics/codeinsights_edge_case_metrics.py +92 -0
- helm/benchmark/metrics/codeinsights_metric_specs.py +51 -0
- helm/benchmark/metrics/comet_metric.py +1 -1
- helm/benchmark/metrics/conv_fin_qa_calc_metrics.py +12 -1
- helm/benchmark/metrics/copyright_metrics.py +1 -1
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +1 -1
- helm/benchmark/metrics/dry_run_metrics.py +30 -1
- helm/benchmark/metrics/efficiency_metrics.py +74 -0
- helm/benchmark/metrics/ehr_sql_metrics.py +57 -1
- helm/benchmark/metrics/evaluate_reference_metrics.py +300 -1
- helm/benchmark/metrics/gpqa_chain_of_thought_metric.py +13 -1
- helm/benchmark/metrics/helpdesk_call_summarization_metrics.py +13 -1
- helm/benchmark/metrics/ifeval_metrics.py +13 -1
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +13 -2
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +1 -1
- helm/benchmark/metrics/instruction_following_critique_metrics.py +41 -1
- helm/benchmark/metrics/kpi_edgar_metrics.py +21 -0
- helm/benchmark/metrics/language_modeling_metrics.py +13 -1
- helm/benchmark/metrics/live_qa_metrics.py +13 -1
- helm/benchmark/metrics/llm_jury_metrics.py +13 -1
- helm/benchmark/metrics/lmkt_metric_specs.py +12 -0
- helm/benchmark/metrics/lmkt_metrics.py +47 -0
- helm/benchmark/metrics/medcalc_bench_metrics.py +14 -1
- helm/benchmark/metrics/medec_metrics.py +25 -2
- helm/benchmark/metrics/melt_toxicity_metric.py +1 -1
- helm/benchmark/metrics/metric.py +25 -0
- helm/benchmark/metrics/mimiciv_billing_code_metrics.py +32 -1
- helm/benchmark/metrics/omni_math_metrics.py +13 -1
- helm/benchmark/metrics/seahelm_metrics.py +14 -1
- helm/benchmark/metrics/summac/model_summac.py +3 -3
- helm/benchmark/metrics/summarization_metrics.py +129 -1
- helm/benchmark/metrics/toxicity_metrics.py +31 -1
- helm/benchmark/metrics/wildbench_metrics.py +21 -1
- helm/benchmark/model_deployment_registry.py +11 -19
- helm/benchmark/presentation/create_plots.py +11 -2
- helm/benchmark/presentation/schema.py +10 -22
- helm/benchmark/presentation/summarize.py +189 -14
- helm/benchmark/presentation/taxonomy_info.py +20 -0
- helm/benchmark/presentation/test_create_plots.py +4 -1
- helm/benchmark/run.py +7 -1
- helm/benchmark/run_expander.py +4 -0
- helm/benchmark/run_specs/arabic_run_specs.py +191 -0
- helm/benchmark/run_specs/bluex_run_specs.py +40 -0
- helm/benchmark/run_specs/classic_run_specs.py +2 -55
- helm/benchmark/run_specs/codeinsights_run_specs.py +192 -0
- helm/benchmark/run_specs/healthqa_br_run_specs.py +40 -0
- helm/benchmark/run_specs/heim_run_specs.py +3 -1
- helm/benchmark/run_specs/lmkt_run_specs.py +144 -0
- helm/benchmark/run_specs/long_context_run_specs.py +48 -1
- helm/benchmark/run_specs/medhelm/__init__.py +0 -0
- helm/benchmark/run_specs/medhelm/benchmark_config.py +219 -0
- helm/benchmark/run_specs/medhelm_run_specs.py +360 -50
- helm/benchmark/run_specs/multilingual_run_specs.py +50 -0
- helm/benchmark/run_specs/speech_disorder_audio_run_specs.py +5 -11
- helm/benchmark/scenarios/aci_bench_scenario.py +23 -0
- helm/benchmark/scenarios/air_bench_scenario.py +21 -0
- helm/benchmark/scenarios/alghafa_scenario.py +126 -0
- helm/benchmark/scenarios/alrage_scenario.py +54 -0
- helm/benchmark/scenarios/anthropic_hh_rlhf_scenario.py +23 -1
- helm/benchmark/scenarios/arabic_exams_scenario.py +114 -0
- helm/benchmark/scenarios/arabic_mmlu_scenario.py +82 -0
- helm/benchmark/scenarios/aratrust_scenario.py +95 -0
- helm/benchmark/scenarios/audio_language/casual_conversations2_scenario.py +1 -1
- helm/benchmark/scenarios/audio_language/mustard_scenario.py +1 -1
- helm/benchmark/scenarios/audio_language/{ultra_suite_asr_classification.py → ultra_suite_asr_classification_scenario.py} +9 -8
- helm/benchmark/scenarios/audio_language/ultra_suite_asr_transcription_scenario.py +99 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_classification_scenario.py +13 -5
- helm/benchmark/scenarios/audio_language/ultra_suite_disorder_breakdown_scenario.py +13 -5
- helm/benchmark/scenarios/audio_language/ultra_suite_disorder_symptoms_scenario.py +13 -5
- helm/benchmark/scenarios/babi_qa_scenario.py +15 -0
- helm/benchmark/scenarios/bbq_scenario.py +15 -0
- helm/benchmark/scenarios/best_chatgpt_prompts.yaml +473 -0
- helm/benchmark/scenarios/bluex_scenario.py +70 -0
- helm/benchmark/scenarios/bold_scenario.py +15 -0
- helm/benchmark/scenarios/boolq_scenario.py +20 -0
- helm/benchmark/scenarios/chw_care_plan_scenario.py +23 -0
- helm/benchmark/scenarios/civil_comments_scenario.py +13 -0
- helm/benchmark/scenarios/clear_scenario.py +23 -0
- helm/benchmark/scenarios/cleva_scenario.py +480 -1
- helm/benchmark/scenarios/code_scenario.py +28 -0
- helm/benchmark/scenarios/codeinsights_code_efficiency_scenario.py +197 -0
- helm/benchmark/scenarios/codeinsights_correct_code_scenario.py +78 -0
- helm/benchmark/scenarios/codeinsights_edge_case_scenario.py +192 -0
- helm/benchmark/scenarios/codeinsights_student_coding_scenario.py +162 -0
- helm/benchmark/scenarios/codeinsights_student_mistake_scenario.py +188 -0
- helm/benchmark/scenarios/commonsense_scenario.py +26 -0
- helm/benchmark/scenarios/compositional_instructions.yaml +70 -0
- helm/benchmark/scenarios/conv_fin_qa_calc_scenario.py +21 -0
- helm/benchmark/scenarios/copyright_scenario.py +35 -1
- helm/benchmark/scenarios/cti_to_mitre_scenario.py +21 -0
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +22 -1
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +23 -1
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +22 -1
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +21 -1
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +13 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +13 -1
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +13 -1
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +13 -1
- helm/benchmark/scenarios/dischargeme_scenario.py +24 -0
- helm/benchmark/scenarios/disinformation_scenario.py +22 -0
- helm/benchmark/scenarios/dyck_language_scenario.py +15 -0
- helm/benchmark/scenarios/ehrshot_scenario.py +22 -0
- helm/benchmark/scenarios/enem_challenge_scenario.py +19 -0
- helm/benchmark/scenarios/entity_data_imputation_scenario.py +14 -0
- helm/benchmark/scenarios/entity_matching_scenario.py +14 -0
- helm/benchmark/scenarios/exams_multilingual_scenario.py +115 -0
- helm/benchmark/scenarios/financial_phrasebank_scenario.py +21 -0
- helm/benchmark/scenarios/gold_commodity_news_scenario.py +21 -0
- helm/benchmark/scenarios/gpqa_scenario.py +18 -0
- helm/benchmark/scenarios/grammar_scenario.py +20 -1
- helm/benchmark/scenarios/gsm_scenario.py +15 -0
- helm/benchmark/scenarios/headqa_scenario.py +22 -0
- helm/benchmark/scenarios/healthqa_br_scenario.py +80 -0
- helm/benchmark/scenarios/helpdesk_call_summarization_scenario.py +13 -0
- helm/benchmark/scenarios/ice_scenario.py +21 -1
- helm/benchmark/scenarios/ifeval_scenario.py +18 -0
- helm/benchmark/scenarios/imdb_scenario.py +15 -0
- helm/benchmark/scenarios/infinite_bench_en_mc_scenario.py +90 -0
- helm/benchmark/scenarios/infinite_bench_en_qa_scenario.py +1 -1
- helm/benchmark/scenarios/koala_scenario.py +21 -1
- helm/benchmark/scenarios/kpi_edgar_scenario.py +21 -0
- helm/benchmark/scenarios/legal_contract_summarization_scenario.py +20 -0
- helm/benchmark/scenarios/legal_summarization_scenario.py +50 -0
- helm/benchmark/scenarios/legal_support_scenario.py +13 -0
- helm/benchmark/scenarios/legalbench_scenario.py +20 -0
- helm/benchmark/scenarios/lex_glue_scenario.py +11 -0
- helm/benchmark/scenarios/lextreme_scenario.py +11 -0
- helm/benchmark/scenarios/lmkt_scenarios.py +288 -0
- helm/benchmark/scenarios/lsat_qa_scenario.py +14 -0
- helm/benchmark/scenarios/madinah_qa_scenario.py +73 -0
- helm/benchmark/scenarios/math_scenario.py +47 -20
- helm/benchmark/scenarios/mbzuai_human_translated_arabic_mmlu.py +68 -0
- helm/benchmark/scenarios/med_dialog_scenario.py +32 -1
- helm/benchmark/scenarios/med_mcqa_scenario.py +14 -0
- helm/benchmark/scenarios/med_qa_scenario.py +14 -0
- helm/benchmark/scenarios/medalign_scenario.py +23 -0
- helm/benchmark/scenarios/medalign_scenario_helper.py +19 -125
- helm/benchmark/scenarios/medbullets_scenario.py +22 -0
- helm/benchmark/scenarios/medcalc_bench_scenario.py +22 -0
- helm/benchmark/scenarios/medec_scenario.py +23 -0
- helm/benchmark/scenarios/medhallu_scenario.py +23 -0
- helm/benchmark/scenarios/medhelm/__init__.py +0 -0
- helm/benchmark/scenarios/medhelm/judges.yaml +14 -0
- helm/benchmark/scenarios/medhelm_configurable_scenario.py +101 -0
- helm/benchmark/scenarios/medi_qa_scenario.py +23 -0
- helm/benchmark/scenarios/medication_qa_scenario.py +31 -1
- helm/benchmark/scenarios/melt_scenarios.py +2 -2
- helm/benchmark/scenarios/mental_health_scenario.py +23 -0
- helm/benchmark/scenarios/mimic_bhc_scenario.py +25 -1
- helm/benchmark/scenarios/mimic_rrs_scenario.py +23 -0
- helm/benchmark/scenarios/mimiciv_billing_code_scenario.py +22 -0
- helm/benchmark/scenarios/mmlu_pro_scenario.py +18 -0
- helm/benchmark/scenarios/mmlu_scenario.py +15 -0
- helm/benchmark/scenarios/mmmlu_scenario.py +85 -0
- helm/benchmark/scenarios/msmarco_scenario.py +30 -0
- helm/benchmark/scenarios/mtsamples_procedures_scenario.py +22 -0
- helm/benchmark/scenarios/mtsamples_replicate_scenario.py +22 -0
- helm/benchmark/scenarios/n2c2_ct_matching_scenario.py +20 -0
- helm/benchmark/scenarios/narrativeqa_scenario.py +20 -0
- helm/benchmark/scenarios/natural_qa_scenario.py +32 -0
- helm/benchmark/scenarios/omni_math_scenario.py +18 -0
- helm/benchmark/scenarios/open_assistant_scenario.py +22 -0
- helm/benchmark/scenarios/pubmed_qa_scenario.py +22 -0
- helm/benchmark/scenarios/quac_scenario.py +14 -0
- helm/benchmark/scenarios/race_based_med_scenario.py +23 -0
- helm/benchmark/scenarios/raft_scenario.py +15 -0
- helm/benchmark/scenarios/real_toxicity_prompts_scenario.py +14 -1
- helm/benchmark/scenarios/scenario.py +31 -0
- helm/benchmark/scenarios/seahelm_scenario.py +350 -2
- helm/benchmark/scenarios/self_instruct_scenario.py +29 -1
- helm/benchmark/scenarios/shc_bmt_scenario.py +22 -0
- helm/benchmark/scenarios/shc_cdi_scenario.py +20 -0
- helm/benchmark/scenarios/shc_conf_scenario.py +23 -0
- helm/benchmark/scenarios/shc_ent_scenario.py +21 -0
- helm/benchmark/scenarios/shc_gip_scenario.py +20 -0
- helm/benchmark/scenarios/shc_privacy_scenario.py +22 -0
- helm/benchmark/scenarios/shc_proxy_scenario.py +22 -0
- helm/benchmark/scenarios/shc_ptbm_scenario.py +23 -0
- helm/benchmark/scenarios/shc_sequoia_scenario.py +21 -0
- helm/benchmark/scenarios/situation_prompts.yaml +49 -0
- helm/benchmark/scenarios/starr_patient_instructions_scenario.py +22 -0
- helm/benchmark/scenarios/summarization_scenario.py +37 -0
- helm/benchmark/scenarios/synthetic_efficiency_scenario.py +22 -1
- helm/benchmark/scenarios/synthetic_reasoning_natural_scenario.py +13 -0
- helm/benchmark/scenarios/test_alghafa_scenario.py +29 -0
- helm/benchmark/scenarios/test_alrage_scenario.py +23 -0
- helm/benchmark/scenarios/test_arabic_exams_scenario.py +21 -0
- helm/benchmark/scenarios/test_aratrust_scenario.py +21 -0
- helm/benchmark/scenarios/test_bluex_scenario.py +59 -0
- helm/benchmark/scenarios/test_exams_multilingual_scenario.py +29 -0
- helm/benchmark/scenarios/test_healtha_br_scenario.py +57 -0
- helm/benchmark/scenarios/the_pile_scenario.py +13 -1
- helm/benchmark/scenarios/truthful_qa_scenario.py +14 -0
- helm/benchmark/scenarios/twitter_aae_scenario.py +20 -1
- helm/benchmark/scenarios/vicuna_scenario.py +21 -1
- helm/benchmark/scenarios/wikifact_scenario.py +20 -0
- helm/benchmark/scenarios/wildbench_scenario.py +18 -0
- helm/benchmark/scenarios/wmt_14_scenario.py +12 -0
- helm/benchmark/slurm_jobs.py +1 -2
- helm/benchmark/slurm_runner.py +8 -1
- helm/benchmark/static/schema_arabic.yaml +271 -0
- helm/benchmark/static/schema_classic.yaml +0 -17
- helm/benchmark/static/schema_long_context.yaml +24 -6
- helm/benchmark/static/schema_medhelm.yaml +36 -0
- helm/benchmark/static/schema_slp.yaml +219 -0
- helm/benchmark/static_build/assets/index-671a5e06.js +10 -0
- helm/benchmark/static_build/assets/index-9352595e.css +1 -0
- helm/benchmark/static_build/index.html +2 -2
- helm/benchmark/window_services/image_generation/clip_window_service.py +1 -3
- helm/clients/audio_language/llama_omni/arguments.py +61 -0
- helm/clients/audio_language/llama_omni/constants.py +9 -0
- helm/clients/audio_language/llama_omni/conversation.py +213 -0
- helm/clients/audio_language/llama_omni/model/__init__.py +0 -0
- helm/clients/audio_language/llama_omni/model/builder.py +88 -0
- helm/clients/audio_language/llama_omni/model/language_model/omni_speech2s_llama.py +190 -0
- helm/clients/audio_language/llama_omni/model/language_model/omni_speech_llama.py +118 -0
- helm/clients/audio_language/llama_omni/model/omni_speech_arch.py +249 -0
- helm/clients/audio_language/llama_omni/model/speech_encoder/builder.py +9 -0
- helm/clients/audio_language/llama_omni/model/speech_encoder/speech_encoder.py +27 -0
- helm/clients/audio_language/llama_omni/model/speech_generator/builder.py +9 -0
- helm/clients/audio_language/llama_omni/model/speech_generator/generation.py +622 -0
- helm/clients/audio_language/llama_omni/model/speech_generator/speech_generator.py +104 -0
- helm/clients/audio_language/llama_omni/model/speech_projector/builder.py +9 -0
- helm/clients/audio_language/llama_omni/model/speech_projector/speech_projector.py +27 -0
- helm/clients/audio_language/llama_omni/preprocess.py +295 -0
- helm/clients/audio_language/llama_omni/utils.py +202 -0
- helm/clients/audio_language/qwen2_5_omni_client.py +19 -7
- helm/clients/audio_language/qwen_omni/configuration_qwen2_5_omni.py +519 -0
- helm/clients/audio_language/qwen_omni/modeling_qwen2_5_omni.py +4308 -0
- helm/clients/audio_language/qwen_omni/processing_qwen2_5_omni.py +270 -0
- helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/__init__.py +0 -0
- helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/__init__.py +8 -0
- helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/audio_process.py +56 -0
- helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/vision_process.py +380 -0
- helm/clients/huggingface_client.py +2 -2
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +1 -1
- helm/clients/image_generation/mindalle/models/stage1/layers.py +2 -2
- helm/clients/openai_client.py +33 -20
- helm/clients/openai_responses_client.py +34 -8
- helm/clients/openrouter_client.py +31 -0
- helm/clients/test_huggingface_client.py +3 -3
- helm/clients/test_openrouter_client.py +69 -0
- helm/clients/together_client.py +48 -13
- helm/clients/vertexai_client.py +19 -11
- helm/clients/vllm_client.py +43 -7
- helm/clients/vllm_granite_thinking_client.py +56 -0
- helm/common/critique_request.py +0 -1
- helm/common/hierarchical_logger.py +83 -34
- helm/common/object_spec.py +23 -8
- helm/common/test_logging.py +94 -0
- helm/config/model_deployments.yaml +525 -172
- helm/config/model_metadata.yaml +185 -10
- helm/config/tokenizer_configs.yaml +100 -2
- helm/proxy/cli.py +1 -1
- helm/proxy/example_queries.py +8 -8
- helm/proxy/retry.py +5 -0
- helm/proxy/server.py +2 -1
- helm/proxy/static/index.css +4 -0
- helm/proxy/static/index.js +7 -1
- helm/tokenizers/grok_tokenizer.py +2 -0
- helm/benchmark/metrics/aci_bench_metrics.py +0 -14
- helm/benchmark/metrics/chw_care_plan_metrics.py +0 -14
- helm/benchmark/metrics/dischargeme_metrics.py +0 -14
- helm/benchmark/metrics/med_dialog_metrics.py +0 -14
- helm/benchmark/metrics/medalign_metrics.py +0 -14
- helm/benchmark/metrics/medi_qa_metrics.py +0 -14
- helm/benchmark/metrics/medication_qa_metrics.py +0 -14
- helm/benchmark/metrics/mental_health_metrics.py +0 -14
- helm/benchmark/metrics/mimic_bhc_metrics.py +0 -14
- helm/benchmark/metrics/mimic_rrs_metrics.py +0 -14
- helm/benchmark/metrics/mtsamples_procedures_metrics.py +0 -14
- helm/benchmark/metrics/mtsamples_replicate_metrics.py +0 -14
- helm/benchmark/metrics/numeracy_metrics.py +0 -72
- helm/benchmark/metrics/starr_patient_instructions_metrics.py +0 -14
- helm/benchmark/metrics/test_numeracy_metrics.py +0 -95
- helm/benchmark/scenarios/numeracy_scenario.py +0 -794
- helm/benchmark/static_build/assets/index-94295e78.js +0 -10
- helm/benchmark/static_build/assets/index-b9779128.css +0 -1
- {crfm_helm-0.5.6.dist-info → crfm_helm-0.5.8.dist-info}/WHEEL +0 -0
- {crfm_helm-0.5.6.dist-info → crfm_helm-0.5.8.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.5.6.dist-info → crfm_helm-0.5.8.dist-info}/licenses/LICENSE +0 -0
- {crfm_helm-0.5.6.dist-info → crfm_helm-0.5.8.dist-info}/top_level.txt +0 -0
helm/clients/audio_language/llama_omni/model/language_model/omni_speech_llama.py
@@ -0,0 +1,118 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
+
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.generation.utils import GenerateOutput
+
+from helm.clients.audio_language.llama_omni.model.omni_speech_arch import OmniSpeechMetaModel, OmniSpeechMetaForCausalLM
+
+
+class OmniSpeechConfig(LlamaConfig):
+    model_type = "omni_speech_llama"
+
+
+class OmniSpeechLlamaModel(OmniSpeechMetaModel, LlamaModel):
+    config_class = OmniSpeechConfig
+
+    def __init__(self, config: LlamaConfig):
+        super(OmniSpeechLlamaModel, self).__init__(config)
+
+
+class OmniSpeechLlamaForCausalLM(LlamaForCausalLM, OmniSpeechMetaForCausalLM):
+    config_class = OmniSpeechConfig
+
+    def __init__(self, config):
+        super(LlamaForCausalLM, self).__init__(config)
+        self.model = OmniSpeechLlamaModel(config)
+        self.pretraining_tp = config.pretraining_tp
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_model(self):
+        return self.model
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        speech: Optional[torch.FloatTensor] = None,
+        speech_lengths: Optional[torch.LongTensor] = None,
+        tgt_units: Optional[torch.LongTensor] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+        if inputs_embeds is None:
+            (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = (
+                self.prepare_inputs_labels_for_speech_and_text(
+                    input_ids, position_ids, attention_mask, past_key_values, labels, speech, speech_lengths
+                )
+            )
+
+        return super().forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            labels=labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+    @torch.no_grad()
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        speech: Optional[torch.Tensor] = None,
+        speech_lengths: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[GenerateOutput, torch.LongTensor]:
+        position_ids = kwargs.pop("position_ids", None)
+        attention_mask = kwargs.pop("attention_mask", None)
+        if "inputs_embeds" in kwargs:
+            raise NotImplementedError("`inputs_embeds` is not supported")
+
+        if speech is not None:
+            (inputs, position_ids, attention_mask, _, inputs_embeds, _) = (
+                self.prepare_inputs_labels_for_speech_and_text(
+                    inputs, position_ids, attention_mask, None, None, speech, speech_lengths
+                )
+            )
+        else:
+            inputs_embeds = self.get_model().embed_tokens(inputs)
+
+        return super().generate(
+            position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs
+        )
+
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+        speech = kwargs.pop("speech", None)
+        speech_lengths = kwargs.pop("speech_lengths", None)
+        inputs = super().prepare_inputs_for_generation(
+            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
+        )
+        if speech is not None:
+            inputs["speech"] = speech
+            inputs["speech_lengths"] = speech_lengths
+        return inputs
+
+
+AutoConfig.register("omni_speech_llama", OmniSpeechConfig)
+AutoModelForCausalLM.register(OmniSpeechConfig, OmniSpeechLlamaForCausalLM)
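The two register() calls at the end of this file hook the custom model into the transformers Auto-class machinery, so a checkpoint whose config.json declares model_type "omni_speech_llama" loads through the generic API. A minimal sketch (the checkpoint path is hypothetical, not part of this diff):

# Sketch only: assumes a local LLaMA-Omni checkpoint at a hypothetical path.
from transformers import AutoConfig, AutoModelForCausalLM

# Importing the module executes the AutoConfig/AutoModelForCausalLM
# register() calls shown above.
import helm.clients.audio_language.llama_omni.model.language_model.omni_speech_llama  # noqa: F401

checkpoint = "/path/to/llama-omni-checkpoint"
config = AutoConfig.from_pretrained(checkpoint)  # resolves to OmniSpeechConfig
model = AutoModelForCausalLM.from_pretrained(checkpoint, config=config)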
helm/clients/audio_language/llama_omni/model/omni_speech_arch.py
@@ -0,0 +1,249 @@
+from abc import ABC, abstractmethod
+
+import torch
+from torch import nn
+
+from helm.clients.audio_language.llama_omni.model.speech_encoder.builder import build_speech_encoder
+from helm.clients.audio_language.llama_omni.model.speech_projector.builder import build_speech_projector
+from helm.clients.audio_language.llama_omni.constants import IGNORE_INDEX, SPEECH_TOKEN_INDEX
+
+
+class OmniSpeechMetaModel(nn.Module):
+
+    def __init__(self, config):
+        super(OmniSpeechMetaModel, self).__init__(config)
+        self.config = config
+
+        if hasattr(config, "speech_encoder"):
+            self.speech_encoder = build_speech_encoder(config)
+            self.speech_projector = build_speech_projector(config)
+
+    def get_speech_encoder(self):
+        speech_encoder = getattr(self, "speech_encoder", None)
+        if type(speech_encoder) is list:
+            speech_encoder = speech_encoder[0]
+        return speech_encoder
+
+    def initialize_speech_modules(self, model_args, fsdp=None):
+        self.config.speech_encoder = getattr(model_args, "speech_encoder", None)
+        self.config.speech_encoder_type = getattr(model_args, "speech_encoder_type", None)
+        self.config.speech_projector_type = getattr(model_args, "speech_projector_type", "linear")
+        self.config.speech_encoder_ds_rate = getattr(model_args, "speech_encoder_ds_rate", 5)
+        self.config.speech_encoder_hidden_size = getattr(model_args, "speech_encoder_hidden_size", 1280)
+
+        if self.get_speech_encoder() is None:
+            speech_encoder = build_speech_encoder(self.config)
+            if fsdp is not None and len(fsdp) > 0:
+                self.speech_encoder = [speech_encoder]
+            else:
+                self.speech_encoder = speech_encoder
+
+        if getattr(self, "speech_projector", None) is None:
+            self.speech_projector = build_speech_projector(self.config)
+        else:
+            # In case it is frozen by LoRA
+            for p in self.speech_projector.parameters():
+                p.requires_grad = True
+
+        if model_args.pretrain_speech_projector is not None:
+            pretrain_speech_projector_weights = torch.load(model_args.pretrain_speech_projector, map_location="cpu")
+
+            def get_w(weights, keyword):
+                return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}
+
+            self.speech_projector.load_state_dict(get_w(pretrain_speech_projector_weights, "speech_projector"))
+
+
+class OmniSpeechMetaForCausalLM(ABC):
+    def __init__(self, config):
+        self.config = config
+
+    @abstractmethod
+    def get_model(self):
+        pass
+
+    def get_speech_encoder(self):
+        return self.get_model().get_speech_encoder()
+
+    def get_speech_projector(self):
+        return self.get_model().speech_projector
+
+    def encode_speech(self, speech, speech_lengths):
+        speech_encoder_type = self.config.speech_encoder_type
+        speech_encoder = self.get_speech_encoder()
+        if "whisper" in speech_encoder_type.lower():
+            encoder_outs = speech_encoder(speech.permute(0, 2, 1))
+            speech_lengths = (speech_lengths + 1) // 2
+        else:
+            raise ValueError(f"Unknown speech encoder: {speech_encoder}")
+        speech_projector_type = self.config.speech_projector_type
+        speech_projector = self.get_speech_projector()
+        if speech_projector_type == "linear":
+            encoder_outs = speech_projector(encoder_outs)
+            speech_lengths = speech_lengths // speech_projector.k
+        else:
+            raise ValueError(f"Unknown speech projector: {speech_projector_type}")
+        speech_features = [encoder_outs[i, : speech_lengths[i]] for i in range(len(encoder_outs))]
+        return speech_features
+
+    def prepare_inputs_labels_for_speech_and_text(
+        self, input_ids, position_ids, attention_mask, past_key_values, labels, speech, speech_lengths
+    ):
+        # input_ids = input_ids.unsqueeze(0)
+        speech_encoder = self.get_speech_encoder()
+        if speech_encoder is None or speech is None or input_ids.shape[1] == 1:
+            return input_ids, position_ids, attention_mask, past_key_values, None, labels
+
+        speech_features = self.encode_speech(speech, speech_lengths)
+        # Let's just add dummy tensors if they do not exist,
+        # it is a headache to deal with None all the time.
+        # But it is not ideal, and if you have a better idea,
+        # please open an issue / submit a PR, thanks.
+        _labels = labels
+        _position_ids = position_ids
+        _attention_mask = attention_mask
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+        else:
+            attention_mask = attention_mask.bool()
+        if position_ids is None:
+            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
+        if labels is None:
+            labels = torch.full_like(input_ids, IGNORE_INDEX)
+
+        # remove the padding using attention_mask -- FIXME
+        # _input_ids = input_ids
+        input_ids = [
+            cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
+        ]
+        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
+
+        new_input_embeds = []
+        new_labels = []
+        cur_speech_idx = 0
+        for batch_idx, cur_input_ids in enumerate(input_ids):
+            num_speech = (cur_input_ids == SPEECH_TOKEN_INDEX).sum()
+            if num_speech == 0:
+                cur_speech_features = speech_features[cur_speech_idx]
+                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
+                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_speech_features[0:0]], dim=0)
+                new_input_embeds.append(cur_input_embeds)
+                new_labels.append(labels[batch_idx])
+                cur_speech_idx += 1
+                continue
+
+            speech_token_indices = (
+                [-1] + torch.where(cur_input_ids == SPEECH_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
+            )
+            cur_input_ids_nospeech = []
+            cur_labels = labels[batch_idx]
+            cur_labels_nospeech = []
+            for i in range(len(speech_token_indices) - 1):
+                cur_input_ids_nospeech.append(cur_input_ids[speech_token_indices[i] + 1 : speech_token_indices[i + 1]])
+                cur_labels_nospeech.append(cur_labels[speech_token_indices[i] + 1 : speech_token_indices[i + 1]])
+            split_sizes = [x.shape[0] for x in cur_labels_nospeech]
+            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_nospeech))
+            cur_input_embeds_no_speech = torch.split(cur_input_embeds, split_sizes, dim=0)
+            cur_new_input_embeds = []
+            cur_new_labels = []
+
+            for i in range(num_speech + 1):
+                cur_new_input_embeds.append(cur_input_embeds_no_speech[i])
+                cur_new_labels.append(cur_labels_nospeech[i])
+                if i < num_speech:
+                    cur_speech_features = speech_features[cur_speech_idx]
+                    cur_speech_idx += 1
+                    cur_new_input_embeds.append(cur_speech_features)
+                    cur_new_labels.append(
+                        torch.full(
+                            (cur_speech_features.shape[0],),
+                            IGNORE_INDEX,
+                            device=cur_labels.device,
+                            dtype=cur_labels.dtype,
+                        )
+                    )
+
+            cur_new_input_embeds_stack = [x.to(input_ids[0].device) for x in cur_new_input_embeds]
+
+            cur_new_input_embeds_tensor = torch.cat(cur_new_input_embeds_stack)
+            cur_new_labels_tensor = torch.cat(cur_new_labels)
+
+            new_input_embeds.append(cur_new_input_embeds_tensor)
+            new_labels.append(cur_new_labels_tensor)
+
+        # Truncate sequences to max length as speech features can make the sequence longer
+        tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
+        if tokenizer_model_max_length is not None:
+            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
+            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
+
+        # Combine them
+        max_len = max(x.shape[0] for x in new_input_embeds)
+        batch_size = len(new_input_embeds)
+
+        new_input_embeds_padded = []
+        new_labels_padded = torch.full(
+            (batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device
+        )
+        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
+        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
+
+        for i, (cur_new_embed, cur_new_labels_loop) in enumerate(zip(new_input_embeds, new_labels)):
+            cur_len = cur_new_embed.shape[0]
+            if getattr(self.config, "tokenizer_padding_side", "right") == "left":
+                new_input_embeds_padded.append(
+                    torch.cat(
+                        (
+                            torch.zeros(
+                                (max_len - cur_len, cur_new_embed.shape[1]),
+                                dtype=cur_new_embed.dtype,
+                                device=cur_new_embed.device,
+                            ),
+                            cur_new_embed,
+                        ),
+                        dim=0,
+                    )
+                )
+                if cur_len > 0:
+                    new_labels_padded[i, -cur_len:] = cur_new_labels_loop
+                    attention_mask[i, -cur_len:] = True
+                    position_ids[i, -cur_len:] = torch.arange(
+                        0, cur_len, dtype=position_ids.dtype, device=position_ids.device
+                    )
+            else:
+                new_input_embeds_padded.append(
+                    torch.cat(
+                        (
+                            cur_new_embed,
+                            torch.zeros(
+                                (max_len - cur_len, cur_new_embed.shape[1]),
+                                dtype=cur_new_embed.dtype,
+                                device=cur_new_embed.device,
+                            ),
+                        ),
+                        dim=0,
+                    )
+                )
+                if cur_len > 0:
+                    new_labels_padded[i, :cur_len] = cur_new_labels_loop
+                    attention_mask[i, :cur_len] = True
+                    position_ids[i, :cur_len] = torch.arange(
+                        0, cur_len, dtype=position_ids.dtype, device=position_ids.device
+                    )
+
+        new_input_embeds_tensor = torch.stack(new_input_embeds_padded, dim=0)
+
+        if _labels is None:
+            new_labels_new = None
+        else:
+            new_labels_new = new_labels_padded
+
+        if _attention_mask is None:
+            attention_mask_new = None
+        else:
+            attention_mask_new = attention_mask.to(dtype=_attention_mask.dtype)
+
+        if _position_ids is None:
+            position_ids = None
+
+        return None, position_ids, attention_mask_new, past_key_values, new_input_embeds_tensor, new_labels_new
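The core of prepare_inputs_labels_for_speech_and_text above is a splice: each SPEECH_TOKEN_INDEX placeholder in input_ids is cut out, the projected speech frames are inserted in its place, and the labels over that span are masked so the speech positions carry no loss. A toy illustration for a single sequence (not HELM code; the constant values are assumed to match llama_omni/constants.py):

# Toy splice, not HELM code: one placeholder token becomes three speech frames.
import torch

IGNORE_INDEX = -100        # assumed value from llama_omni/constants.py
SPEECH_TOKEN_INDEX = -200  # assumed value from llama_omni/constants.py

input_ids = torch.tensor([5, 6, SPEECH_TOKEN_INDEX, 7])
speech_features = torch.randn(3, 8)    # 3 projected speech frames, hidden size 8
embed = torch.nn.Embedding(100, 8)     # stand-in for the LM's embed_tokens

# The placeholder itself is never embedded (its id is negative); the text on
# either side is embedded and the speech frames are spliced into the gap.
inputs_embeds = torch.cat([embed(input_ids[:2]), speech_features, embed(input_ids[3:])])
labels = torch.cat([input_ids[:2], torch.full((3,), IGNORE_INDEX), input_ids[3:]])
assert inputs_embeds.shape == (6, 8)   # 2 text + 3 speech + 1 text positions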
helm/clients/audio_language/llama_omni/model/speech_encoder/builder.py
@@ -0,0 +1,9 @@
+from helm.clients.audio_language.llama_omni.model.speech_encoder.speech_encoder import WhisperWrappedEncoder
+
+
+def build_speech_encoder(config):
+    speech_encoder_type = getattr(config, "speech_encoder_type", "none")
+    if "whisper" in speech_encoder_type.lower():
+        return WhisperWrappedEncoder.load(config)
+
+    raise ValueError(f"Unknown speech encoder: {speech_encoder_type}")
helm/clients/audio_language/llama_omni/model/speech_encoder/speech_encoder.py
@@ -0,0 +1,27 @@
+# Adopted from https://github.com/ddlBoJack/SLAM-LLM/blob/main/src/slam_llm/models/encoder.py
+import torch.nn as nn
+import whisper
+
+
+class WhisperWrappedEncoder:
+
+    @classmethod
+    def load(cls, model_config):
+
+        def replace_layer_norm(module):
+            from whisper.model import LayerNorm
+
+            for name, child in module.named_children():
+                if isinstance(child, LayerNorm):
+                    old_params = child.state_dict()
+                    new_layer_norm = nn.LayerNorm(
+                        child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine
+                    )
+                    new_layer_norm.load_state_dict(old_params)
+                    setattr(module, name, new_layer_norm)
+                else:
+                    replace_layer_norm(child)
+
+        encoder = whisper.load_model(name="large-v3", device="cpu").encoder
+        replace_layer_norm(encoder)
+        return encoder
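WhisperWrappedEncoder.load swaps Whisper's own LayerNorm for the stock nn.LayerNorm (weights preserved) so the encoder can run outside Whisper's inference path. A usage sketch with openai-whisper's preprocessing (the audio file name is hypothetical; n_mels=128 matches the large-v3 checkpoint loaded above):

# Sketch, not HELM code: push a 30-second mel window through the wrapped encoder.
import torch
import whisper

from helm.clients.audio_language.llama_omni.model.speech_encoder.speech_encoder import WhisperWrappedEncoder

encoder = WhisperWrappedEncoder.load(model_config=None)        # load() ignores its argument
audio = whisper.pad_or_trim(whisper.load_audio("sample.wav"))  # hypothetical file
mel = whisper.log_mel_spectrogram(audio, n_mels=128)           # shape (128, 3000)
with torch.no_grad():
    features = encoder(mel.unsqueeze(0))                       # shape (1, 1500, 1280)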
helm/clients/audio_language/llama_omni/model/speech_generator/builder.py
@@ -0,0 +1,9 @@
+from helm.clients.audio_language.llama_omni.model.speech_generator.speech_generator import SpeechGeneratorCTC
+
+
+def build_speech_generator(config):
+    generator_type = getattr(config, "speech_generator_type", "ctc")
+    if generator_type == "ctc":
+        return SpeechGeneratorCTC(config)
+
+    raise ValueError(f"Unknown generator type: {generator_type}")