crfm-helm 0.5.7__py3-none-any.whl → 0.5.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of crfm-helm might be problematic. Click here for more details.
- {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/METADATA +5 -77
- {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/RECORD +228 -197
- helm/benchmark/adaptation/adapter_spec.py +5 -0
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +11 -3
- helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +11 -8
- helm/benchmark/annotation/aci_bench_annotator.py +11 -22
- helm/benchmark/annotation/alrage_annotator.py +90 -0
- helm/benchmark/annotation/chw_care_plan_annotator.py +10 -21
- helm/benchmark/annotation/dischargeme_annotator.py +11 -22
- helm/benchmark/annotation/med_dialog_annotator.py +11 -22
- helm/benchmark/annotation/medalign_annotator.py +11 -22
- helm/benchmark/annotation/medi_qa_annotator.py +11 -22
- helm/benchmark/annotation/medication_qa_annotator.py +11 -22
- helm/benchmark/annotation/mental_health_annotator.py +11 -22
- helm/benchmark/annotation/mimic_bhc_annotator.py +11 -22
- helm/benchmark/annotation/mimic_rrs_annotator.py +11 -22
- helm/benchmark/annotation/model_as_judge.py +23 -18
- helm/benchmark/annotation/mtsamples_procedures_annotator.py +11 -22
- helm/benchmark/annotation/mtsamples_replicate_annotator.py +11 -22
- helm/benchmark/annotation/starr_patient_instructions_annotator.py +11 -22
- helm/benchmark/metrics/air_bench_metrics.py +3157 -1
- helm/benchmark/metrics/alrage_metric.py +35 -0
- helm/benchmark/metrics/basic_metrics.py +267 -2
- helm/benchmark/metrics/classification_metrics.py +19 -1
- helm/benchmark/metrics/conv_fin_qa_calc_metrics.py +12 -1
- helm/benchmark/metrics/dry_run_metrics.py +30 -1
- helm/benchmark/metrics/efficiency_metrics.py +74 -0
- helm/benchmark/metrics/ehr_sql_metrics.py +57 -1
- helm/benchmark/metrics/evaluate_reference_metrics.py +299 -0
- helm/benchmark/metrics/gpqa_chain_of_thought_metric.py +13 -1
- helm/benchmark/metrics/helpdesk_call_summarization_metrics.py +13 -1
- helm/benchmark/metrics/ifeval_metrics.py +13 -1
- helm/benchmark/metrics/instruction_following_critique_metrics.py +41 -1
- helm/benchmark/metrics/kpi_edgar_metrics.py +21 -0
- helm/benchmark/metrics/language_modeling_metrics.py +13 -1
- helm/benchmark/metrics/live_qa_metrics.py +13 -1
- helm/benchmark/metrics/llm_jury_metrics.py +13 -1
- helm/benchmark/metrics/medcalc_bench_metrics.py +14 -1
- helm/benchmark/metrics/medec_metrics.py +25 -2
- helm/benchmark/metrics/metric.py +25 -0
- helm/benchmark/metrics/mimiciv_billing_code_metrics.py +32 -1
- helm/benchmark/metrics/omni_math_metrics.py +13 -1
- helm/benchmark/metrics/seahelm_metrics.py +14 -1
- helm/benchmark/metrics/summac/model_summac.py +2 -2
- helm/benchmark/metrics/summarization_metrics.py +129 -1
- helm/benchmark/metrics/toxicity_metrics.py +31 -1
- helm/benchmark/metrics/wildbench_metrics.py +21 -1
- helm/benchmark/presentation/schema.py +5 -22
- helm/benchmark/presentation/summarize.py +180 -11
- helm/benchmark/presentation/taxonomy_info.py +20 -0
- helm/benchmark/run_expander.py +4 -0
- helm/benchmark/run_specs/arabic_run_specs.py +134 -16
- helm/benchmark/run_specs/bluex_run_specs.py +1 -1
- helm/benchmark/run_specs/classic_run_specs.py +2 -2
- helm/benchmark/run_specs/long_context_run_specs.py +2 -2
- helm/benchmark/run_specs/medhelm/__init__.py +0 -0
- helm/benchmark/run_specs/medhelm/benchmark_config.py +219 -0
- helm/benchmark/run_specs/medhelm_run_specs.py +360 -50
- helm/benchmark/scenarios/aci_bench_scenario.py +23 -0
- helm/benchmark/scenarios/air_bench_scenario.py +21 -0
- helm/benchmark/scenarios/alrage_scenario.py +54 -0
- helm/benchmark/scenarios/anthropic_hh_rlhf_scenario.py +23 -1
- helm/benchmark/scenarios/arabic_exams_scenario.py +114 -0
- helm/benchmark/scenarios/arabic_mmlu_scenario.py +8 -4
- helm/benchmark/scenarios/aratrust_scenario.py +19 -0
- helm/benchmark/scenarios/babi_qa_scenario.py +15 -0
- helm/benchmark/scenarios/bbq_scenario.py +15 -0
- helm/benchmark/scenarios/best_chatgpt_prompts.yaml +473 -0
- helm/benchmark/scenarios/bluex_scenario.py +6 -2
- helm/benchmark/scenarios/bold_scenario.py +15 -0
- helm/benchmark/scenarios/boolq_scenario.py +20 -0
- helm/benchmark/scenarios/chw_care_plan_scenario.py +23 -0
- helm/benchmark/scenarios/civil_comments_scenario.py +13 -0
- helm/benchmark/scenarios/clear_scenario.py +23 -0
- helm/benchmark/scenarios/cleva_scenario.py +479 -0
- helm/benchmark/scenarios/code_scenario.py +28 -0
- helm/benchmark/scenarios/commonsense_scenario.py +26 -0
- helm/benchmark/scenarios/compositional_instructions.yaml +70 -0
- helm/benchmark/scenarios/conv_fin_qa_calc_scenario.py +21 -0
- helm/benchmark/scenarios/copyright_scenario.py +35 -1
- helm/benchmark/scenarios/cti_to_mitre_scenario.py +21 -0
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +22 -1
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +23 -1
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +22 -1
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +21 -1
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +13 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +13 -1
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +13 -1
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +13 -1
- helm/benchmark/scenarios/dischargeme_scenario.py +24 -0
- helm/benchmark/scenarios/disinformation_scenario.py +22 -0
- helm/benchmark/scenarios/dyck_language_scenario.py +15 -0
- helm/benchmark/scenarios/ehrshot_scenario.py +22 -0
- helm/benchmark/scenarios/enem_challenge_scenario.py +19 -0
- helm/benchmark/scenarios/entity_data_imputation_scenario.py +14 -0
- helm/benchmark/scenarios/entity_matching_scenario.py +14 -0
- helm/benchmark/scenarios/financial_phrasebank_scenario.py +21 -0
- helm/benchmark/scenarios/gold_commodity_news_scenario.py +21 -0
- helm/benchmark/scenarios/gpqa_scenario.py +18 -0
- helm/benchmark/scenarios/grammar_scenario.py +20 -1
- helm/benchmark/scenarios/gsm_scenario.py +15 -0
- helm/benchmark/scenarios/headqa_scenario.py +22 -0
- helm/benchmark/scenarios/helpdesk_call_summarization_scenario.py +13 -0
- helm/benchmark/scenarios/ice_scenario.py +21 -1
- helm/benchmark/scenarios/ifeval_scenario.py +18 -0
- helm/benchmark/scenarios/imdb_scenario.py +15 -0
- helm/benchmark/scenarios/koala_scenario.py +21 -1
- helm/benchmark/scenarios/kpi_edgar_scenario.py +21 -0
- helm/benchmark/scenarios/legal_contract_summarization_scenario.py +20 -0
- helm/benchmark/scenarios/legal_summarization_scenario.py +50 -0
- helm/benchmark/scenarios/legal_support_scenario.py +13 -0
- helm/benchmark/scenarios/legalbench_scenario.py +20 -0
- helm/benchmark/scenarios/lex_glue_scenario.py +11 -0
- helm/benchmark/scenarios/lextreme_scenario.py +11 -0
- helm/benchmark/scenarios/lsat_qa_scenario.py +14 -0
- helm/benchmark/scenarios/madinah_qa_scenario.py +73 -0
- helm/benchmark/scenarios/math_scenario.py +26 -0
- helm/benchmark/scenarios/mbzuai_human_translated_arabic_mmlu.py +68 -0
- helm/benchmark/scenarios/med_dialog_scenario.py +32 -1
- helm/benchmark/scenarios/med_mcqa_scenario.py +14 -0
- helm/benchmark/scenarios/med_qa_scenario.py +14 -0
- helm/benchmark/scenarios/medalign_scenario.py +23 -0
- helm/benchmark/scenarios/medbullets_scenario.py +22 -0
- helm/benchmark/scenarios/medcalc_bench_scenario.py +22 -0
- helm/benchmark/scenarios/medec_scenario.py +23 -0
- helm/benchmark/scenarios/medhallu_scenario.py +23 -0
- helm/benchmark/scenarios/medhelm/__init__.py +0 -0
- helm/benchmark/scenarios/medhelm/judges.yaml +14 -0
- helm/benchmark/scenarios/medhelm_configurable_scenario.py +101 -0
- helm/benchmark/scenarios/medi_qa_scenario.py +23 -0
- helm/benchmark/scenarios/medication_qa_scenario.py +31 -1
- helm/benchmark/scenarios/mental_health_scenario.py +23 -0
- helm/benchmark/scenarios/mimic_bhc_scenario.py +24 -0
- helm/benchmark/scenarios/mimic_rrs_scenario.py +23 -0
- helm/benchmark/scenarios/mimiciv_billing_code_scenario.py +22 -0
- helm/benchmark/scenarios/mmlu_pro_scenario.py +18 -0
- helm/benchmark/scenarios/mmlu_scenario.py +15 -0
- helm/benchmark/scenarios/msmarco_scenario.py +30 -0
- helm/benchmark/scenarios/mtsamples_procedures_scenario.py +22 -0
- helm/benchmark/scenarios/mtsamples_replicate_scenario.py +22 -0
- helm/benchmark/scenarios/n2c2_ct_matching_scenario.py +20 -0
- helm/benchmark/scenarios/narrativeqa_scenario.py +20 -0
- helm/benchmark/scenarios/natural_qa_scenario.py +32 -0
- helm/benchmark/scenarios/omni_math_scenario.py +18 -0
- helm/benchmark/scenarios/open_assistant_scenario.py +22 -0
- helm/benchmark/scenarios/pubmed_qa_scenario.py +22 -0
- helm/benchmark/scenarios/quac_scenario.py +14 -0
- helm/benchmark/scenarios/race_based_med_scenario.py +23 -0
- helm/benchmark/scenarios/raft_scenario.py +15 -0
- helm/benchmark/scenarios/real_toxicity_prompts_scenario.py +14 -1
- helm/benchmark/scenarios/scenario.py +31 -0
- helm/benchmark/scenarios/seahelm_scenario.py +348 -0
- helm/benchmark/scenarios/self_instruct_scenario.py +29 -1
- helm/benchmark/scenarios/shc_bmt_scenario.py +22 -0
- helm/benchmark/scenarios/shc_cdi_scenario.py +20 -0
- helm/benchmark/scenarios/shc_conf_scenario.py +23 -0
- helm/benchmark/scenarios/shc_ent_scenario.py +21 -0
- helm/benchmark/scenarios/shc_gip_scenario.py +20 -0
- helm/benchmark/scenarios/shc_privacy_scenario.py +22 -0
- helm/benchmark/scenarios/shc_proxy_scenario.py +22 -0
- helm/benchmark/scenarios/shc_ptbm_scenario.py +23 -0
- helm/benchmark/scenarios/shc_sequoia_scenario.py +21 -0
- helm/benchmark/scenarios/situation_prompts.yaml +49 -0
- helm/benchmark/scenarios/starr_patient_instructions_scenario.py +22 -0
- helm/benchmark/scenarios/summarization_scenario.py +37 -0
- helm/benchmark/scenarios/synthetic_efficiency_scenario.py +22 -1
- helm/benchmark/scenarios/synthetic_reasoning_natural_scenario.py +13 -0
- helm/benchmark/scenarios/test_alrage_scenario.py +23 -0
- helm/benchmark/scenarios/test_arabic_exams_scenario.py +21 -0
- helm/benchmark/scenarios/test_aratrust_scenario.py +1 -1
- helm/benchmark/scenarios/test_bluex_scenario.py +2 -2
- helm/benchmark/scenarios/the_pile_scenario.py +13 -1
- helm/benchmark/scenarios/truthful_qa_scenario.py +14 -0
- helm/benchmark/scenarios/twitter_aae_scenario.py +20 -1
- helm/benchmark/scenarios/vicuna_scenario.py +21 -1
- helm/benchmark/scenarios/wikifact_scenario.py +20 -0
- helm/benchmark/scenarios/wildbench_scenario.py +18 -0
- helm/benchmark/scenarios/wmt_14_scenario.py +12 -0
- helm/benchmark/static/schema_arabic.yaml +55 -12
- helm/benchmark/static/schema_long_context.yaml +17 -17
- helm/benchmark/static/schema_medhelm.yaml +36 -0
- helm/benchmark/static/schema_slp.yaml +219 -0
- helm/benchmark/static_build/assets/index-671a5e06.js +10 -0
- helm/benchmark/static_build/assets/index-9352595e.css +1 -0
- helm/benchmark/static_build/index.html +2 -2
- helm/clients/audio_language/llama_omni/arguments.py +61 -0
- helm/clients/audio_language/llama_omni/constants.py +9 -0
- helm/clients/audio_language/llama_omni/conversation.py +213 -0
- helm/clients/audio_language/llama_omni/model/__init__.py +0 -0
- helm/clients/audio_language/llama_omni/model/builder.py +88 -0
- helm/clients/audio_language/llama_omni/model/language_model/omni_speech2s_llama.py +190 -0
- helm/clients/audio_language/llama_omni/model/language_model/omni_speech_llama.py +118 -0
- helm/clients/audio_language/llama_omni/model/omni_speech_arch.py +249 -0
- helm/clients/audio_language/llama_omni/model/speech_encoder/builder.py +9 -0
- helm/clients/audio_language/llama_omni/model/speech_encoder/speech_encoder.py +27 -0
- helm/clients/audio_language/llama_omni/model/speech_generator/builder.py +9 -0
- helm/clients/audio_language/llama_omni/model/speech_generator/generation.py +622 -0
- helm/clients/audio_language/llama_omni/model/speech_generator/speech_generator.py +104 -0
- helm/clients/audio_language/llama_omni/model/speech_projector/builder.py +9 -0
- helm/clients/audio_language/llama_omni/model/speech_projector/speech_projector.py +27 -0
- helm/clients/audio_language/llama_omni/preprocess.py +295 -0
- helm/clients/audio_language/llama_omni/utils.py +202 -0
- helm/clients/audio_language/qwen_omni/configuration_qwen2_5_omni.py +519 -0
- helm/clients/audio_language/qwen_omni/modeling_qwen2_5_omni.py +4308 -0
- helm/clients/audio_language/qwen_omni/processing_qwen2_5_omni.py +270 -0
- helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/__init__.py +0 -0
- helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/__init__.py +8 -0
- helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/audio_process.py +56 -0
- helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/vision_process.py +380 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +1 -1
- helm/clients/image_generation/mindalle/models/stage1/layers.py +2 -2
- helm/clients/openai_client.py +31 -19
- helm/clients/openai_responses_client.py +27 -3
- helm/clients/openrouter_client.py +31 -0
- helm/clients/test_openrouter_client.py +69 -0
- helm/clients/together_client.py +48 -11
- helm/clients/vertexai_client.py +8 -2
- helm/config/model_deployments.yaml +75 -1
- helm/config/model_metadata.yaml +70 -2
- helm/config/tokenizer_configs.yaml +19 -1
- helm/proxy/example_queries.py +8 -8
- helm/proxy/server.py +2 -1
- helm/proxy/static/index.css +4 -0
- helm/proxy/static/index.js +7 -1
- helm/benchmark/metrics/aci_bench_metrics.py +0 -14
- helm/benchmark/metrics/chw_care_plan_metrics.py +0 -14
- helm/benchmark/metrics/dischargeme_metrics.py +0 -14
- helm/benchmark/metrics/med_dialog_metrics.py +0 -14
- helm/benchmark/metrics/medalign_metrics.py +0 -14
- helm/benchmark/metrics/medi_qa_metrics.py +0 -14
- helm/benchmark/metrics/medication_qa_metrics.py +0 -14
- helm/benchmark/metrics/mental_health_metrics.py +0 -14
- helm/benchmark/metrics/mimic_bhc_metrics.py +0 -14
- helm/benchmark/metrics/mimic_rrs_metrics.py +0 -14
- helm/benchmark/metrics/mtsamples_procedures_metrics.py +0 -14
- helm/benchmark/metrics/mtsamples_replicate_metrics.py +0 -14
- helm/benchmark/metrics/starr_patient_instructions_metrics.py +0 -14
- helm/benchmark/static_build/assets/index-b9779128.css +0 -1
- helm/benchmark/static_build/assets/index-e439d5e1.js +0 -10
- {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/WHEEL +0 -0
- {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/licenses/LICENSE +0 -0
- {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.8.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,622 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import inspect
|
|
3
|
+
import warnings
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
from typing import Optional, Union, List, Callable
|
|
6
|
+
import torch.distributed as dist
|
|
7
|
+
|
|
8
|
+
from transformers import PreTrainedModel
|
|
9
|
+
from transformers.generation.streamers import BaseStreamer
|
|
10
|
+
from transformers.generation.utils import (
|
|
11
|
+
GenerationConfig,
|
|
12
|
+
GenerationMode,
|
|
13
|
+
LogitsProcessorList,
|
|
14
|
+
StoppingCriteriaList,
|
|
15
|
+
GenerationMixin,
|
|
16
|
+
GenerateEncoderDecoderOutput,
|
|
17
|
+
GenerateDecoderOnlyOutput,
|
|
18
|
+
GenerateNonBeamOutput,
|
|
19
|
+
is_deepspeed_zero3_enabled,
|
|
20
|
+
is_torchdynamo_compiling,
|
|
21
|
+
NEED_SETUP_CACHE_CLASSES_MAPPING,
|
|
22
|
+
QUANT_BACKEND_CLASSES_MAPPING,
|
|
23
|
+
is_hqq_available,
|
|
24
|
+
QuantizedCacheConfig,
|
|
25
|
+
is_quanto_available,
|
|
26
|
+
DynamicCache,
|
|
27
|
+
EncoderDecoderCache,
|
|
28
|
+
logging,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
logger = logging.get_logger(__name__)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class GenerationWithCTC(GenerationMixin):
|
|
35
|
+
|
|
36
|
+
@torch.no_grad()
def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    speech: Optional[torch.Tensor] = None,
    speech_lengths: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    streamer: Optional["BaseStreamer"] = None,
    streamer_unit: Optional["BaseStreamer"] = None,
    streaming_unit_gen=False,
    negative_prompt_ids: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    """Generate tokens with greedy search or sampling, optionally streaming units.

    The preparation steps reuse the private helpers inherited from
    ``GenerationMixin`` (``_prepare_generation_config``, ``_prepare_model_inputs``,
    ``_get_logits_processor``, ``_get_stopping_criteria``, ...). Only
    ``GenerationMode.GREEDY_SEARCH`` and ``GenerationMode.SAMPLE`` are supported.
    The decoding loop itself is delegated to ``self._sample``, or to
    ``self._sample_streaming_unit`` when ``streaming_unit_gen`` is true.

    Args:
        inputs: Prompt tensor handed to ``self._prepare_model_inputs``. May be
            ``None`` when the model inputs arrive through ``kwargs`` instead.
        speech: Accepted for signature compatibility; not read by this method.
        speech_lengths: Accepted for signature compatibility; not read by this
            method.
        generation_config: Base generation config; ``kwargs`` may override fields.
        logits_processor: Extra logits processors (defaults to an empty list).
        stopping_criteria: Extra stopping criteria (defaults to an empty list).
        prefix_allowed_tokens_fn: Forwarded to ``self._get_logits_processor``.
        synced_gpus: When ``None``, becomes true only if DeepSpeed ZeRO-3 is
            enabled and the ``torch.distributed`` world size is greater than 1.
        assistant_model: Only used to resolve the generation mode. Any resolved
            mode other than greedy search or sampling raises
            ``NotImplementedError``.
        streamer: Receives the prompt ids before decoding and is forwarded to the
            selected sampling routine.
        streamer_unit: Forwarded to ``self._sample_streaming_unit`` only. It is
            ignored when ``streaming_unit_gen`` is false.
        streaming_unit_gen: Selects ``self._sample_streaming_unit`` instead of
            ``self._sample``.
        negative_prompt_ids: Forwarded to ``self._get_logits_processor``.
        negative_prompt_attention_mask: Forwarded to ``self._get_logits_processor``.
        **kwargs: Generation-config overrides and model kwargs. A ``tokenizer``
            entry is popped first and used only for token healing and stopping
            criteria.

    Returns:
        Whatever the selected sampling routine returns (``self._sample`` is
        annotated as ``GenerateNonBeamOutput`` or ``torch.LongTensor``).

    Raises:
        ValueError: If both ``cache_implementation`` and an explicit cache object
            are given, if the requested static/quantized cache is not supported by
            the model, or if a streamer is combined with ``num_beams > 1``.
        ImportError: If the quantized-cache backend (``quanto`` or ``HQQ``) is
            not installed.
        NotImplementedError: If the resolved generation mode is neither greedy
            search nor sampling.

    Other exceptions may be raised by the inherited validation helpers.
    """
    # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
    self._validate_model_class()
    # Pull this out first: it is only used for token healing and stopping criteria.
    tokenizer = kwargs.pop("tokenizer", None)
    generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
    self._validate_model_kwargs(model_kwargs.copy())
    self._validate_assistant(assistant_model)

    # 2. Set generation parameters if not already defined
    if synced_gpus is None:
        # Short-circuit: the world size is only queried when ZeRO-3 is enabled.
        synced_gpus = bool(is_deepspeed_zero3_enabled() and dist.get_world_size() > 1)

    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

    accepts_attention_mask = "attention_mask" in inspect.signature(self.forward).parameters
    requires_attention_mask = "encoder_outputs" not in model_kwargs
    kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

    # 3. Define model inputs
    inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
        inputs, generation_config.bos_token_id, model_kwargs
    )
    batch_size = inputs_tensor.shape[0]

    device = inputs_tensor.device
    self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)

    # decoder-only models must use left-padding for batched generation.
    if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
        # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
        # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
        if (
            generation_config._pad_token_tensor is not None
            and batch_size > 1
            and len(inputs_tensor.shape) == 2
            and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
        ):
            logger.warning(
                "A decoder-only architecture is being used, but right-padding was detected! For correct "
                "generation results, please set `padding_side='left'` when initializing the tokenizer."
            )

    # 4. Define other model kwargs
    # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
    # generating the first new token or not, and we only want to use the embeddings for the first new token)
    if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
        model_kwargs["use_cache"] = True
    else:
        model_kwargs["use_cache"] = generation_config.use_cache

    if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
        model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
            inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
        )

    if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
        # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
        model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
            inputs_tensor, model_kwargs, model_input_name, generation_config
        )

    # 5. Prepare `input_ids` which will be used for auto-regressive generation
    if self.config.is_encoder_decoder:
        input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
            batch_size=batch_size,
            model_input_name=model_input_name,
            model_kwargs=model_kwargs,
            decoder_start_token_id=generation_config._decoder_start_token_tensor,
            device=inputs_tensor.device,
        )
    else:
        input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

    if generation_config.token_healing:
        input_ids = self.heal_tokens(input_ids, tokenizer)

    if streamer is not None:
        streamer.put(input_ids.cpu())

    # 6. Prepare `max_length` depending on other stopping criteria.
    input_ids_length = input_ids.shape[-1]
    has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
    has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
    generation_config = self._prepare_generated_length(
        generation_config=generation_config,
        has_default_max_length=has_default_max_length,
        has_default_min_length=has_default_min_length,
        model_input_name=model_input_name,
        inputs_tensor=inputs_tensor,
        input_ids_length=input_ids_length,
    )

    # Mamba-style models keep their recurrent state under a different kwarg name.
    if "mamba" in self.__class__.__name__.lower():
        cache_name = "cache_params"
    else:
        cache_name = "past_key_values"
    if generation_config.cache_implementation is not None and (model_kwargs.get(cache_name) is not None):
        raise ValueError(
            f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a "
            "Cache object) is unsupported. Please use only one of the two."
        )
    elif generation_config.cache_implementation is not None:
        if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
            if generation_config.cache_implementation == "static" and not self._supports_static_cache:
                raise ValueError(
                    "This model does not support `cache_implementation='static'`. Please check the following "
                    "issue: https://github.com/huggingface/transformers/issues/28981"
                )
            model_kwargs[cache_name] = self._get_cache(
                generation_config.cache_implementation,
                getattr(generation_config, "num_beams", 1) * batch_size,
                generation_config.max_length,
                model_kwargs,
            )
        elif generation_config.cache_implementation == "quantized":
            if not self._supports_quantized_cache:
                raise ValueError(
                    "This model does not support the quantized cache. If you want your model to support quantized "
                    "cache, please open an issue."
                )

            cache_config = (
                generation_config.cache_config
                if generation_config.cache_config is not None
                else QuantizedCacheConfig()
            )
            cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]

            if cache_config.backend == "quanto" and not is_quanto_available():
                raise ImportError(
                    "You need to install `quanto` in order to use KV cache quantization with quanto backend. "
                    "Please install it via with `pip install quanto`"
                )
            elif cache_config.backend == "HQQ" and not is_hqq_available():
                raise ImportError(
                    "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
                    "Please install it via with `pip install hqq`"
                )

            model_kwargs[cache_name] = cache_class(cache_config)
    # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
    # keeps copying the cache thus using much more memory
    elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
        past = model_kwargs.get(cache_name, None)
        requires_cross_attention_cache = (
            self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
        )
        if past is None:
            model_kwargs[cache_name] = (
                DynamicCache()
                if not requires_cross_attention_cache
                else EncoderDecoderCache(DynamicCache(), DynamicCache())
            )
        elif isinstance(past, tuple):
            # Legacy tuple caches are upgraded to Cache objects before decoding.
            model_kwargs[cache_name] = (
                DynamicCache.from_legacy_cache(past)
                if not requires_cross_attention_cache
                else EncoderDecoderCache.from_legacy_cache(past)
            )
    self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

    # 7. determine generation mode
    generation_mode = generation_config.get_generation_mode(assistant_model)

    if (streamer is not None or streamer_unit is not None) and (generation_config.num_beams > 1):
        raise ValueError(
            "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
        )

    if self.device.type != input_ids.device.type:
        warnings.warn(
            "You are calling .generate() with the `input_ids` being on a device type different"
            f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
            f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
            " Please make sure that you have put `input_ids` to the"
            f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
            " running `.generate()`.",
            UserWarning,
        )

    # 8. prepare distribution pre_processing samplers
    prepared_logits_processor = self._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_length,
        encoder_input_ids=inputs_tensor,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
        device=inputs_tensor.device,
        model_kwargs=model_kwargs,
        negative_prompt_ids=negative_prompt_ids,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
    )

    # 9. prepare stopping criteria
    prepared_stopping_criteria = self._get_stopping_criteria(
        generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
    )

    # 10. only greedy search and sampling are implemented for this model
    if generation_mode not in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
        raise NotImplementedError(
            f"{type(self).__name__}.generate only supports greedy search and sampling; "
            f"got generation mode {generation_mode}."
        )

    # 11. prepare logits warper (only needed when sampling)
    prepared_logits_warper = (
        self._get_logits_warper(generation_config, device=input_ids.device)
        if generation_config.do_sample
        else None
    )

    # 12. expand input_ids with `num_return_sequences` additional sequences per batch
    input_ids, model_kwargs = self._expand_inputs_for_generation(
        input_ids=input_ids,
        expand_size=generation_config.num_return_sequences,
        is_encoder_decoder=self.config.is_encoder_decoder,
        **model_kwargs,
    )

    # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
    if streaming_unit_gen:
        return self._sample_streaming_unit(
            input_ids,
            logits_processor=prepared_logits_processor,
            logits_warper=prepared_logits_warper,
            stopping_criteria=prepared_stopping_criteria,
            generation_config=generation_config,
            synced_gpus=synced_gpus,
            streamer=streamer,
            streamer_unit=streamer_unit,
            **model_kwargs,
        )
    return self._sample(
        input_ids,
        logits_processor=prepared_logits_processor,
        logits_warper=prepared_logits_warper,
        stopping_criteria=prepared_stopping_criteria,
        generation_config=generation_config,
        synced_gpus=synced_gpus,
        streamer=streamer,
        **model_kwargs,
    )
|
|
303
|
+
|
|
304
|
+
def _sample(
    self,
    input_ids: torch.Tensor,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool,
    streamer: Optional["BaseStreamer"],
    logits_warper: Optional[LogitsProcessorList],
    **model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
    """Decode by multinomial sampling (``do_sample=True``) or greedy argmax until every row stops.

    Each iteration runs the model on the running sequence and takes the logits at the last
    position. It applies ``logits_processor`` (and ``logits_warper`` when sampling), picks one
    token per row, appends it and re-evaluates ``stopping_criteria``.

    Args:
        input_ids: Prompt ids; ``input_ids.shape[0]`` is the batch size.
        logits_processor: Processors applied to the raw next-token logits.
        stopping_criteria: Criteria evaluated after every step; any criterion exposing
            ``eos_token_id`` makes finished rows receive the pad token.
        generation_config: Supplies ``do_sample``, the ``output_*`` flags,
            ``return_dict_in_generate`` and the pad-token tensor.
        synced_gpus: Forwarded to ``_has_unfinished_sequences``. When set, a finished peer
            keeps running forward passes but skips the rest of the step.
        streamer: Optional streamer; receives each new token batch (on CPU) and ``end()``.
        logits_warper: Applied after ``logits_processor`` when sampling. It must be a
            ``LogitsProcessorList`` if ``do_sample is True``.
        **model_kwargs: Extra model inputs, updated after each step.

    Returns:
        ``GenerateEncoderDecoderOutput`` or ``GenerateDecoderOnlyOutput`` when
        ``return_dict_in_generate`` is set, otherwise the final ``input_ids``.
        The collected ``scores``/``logits``/attention/hidden-state fields are tuples
        (empty when not requested).

    Raises:
        ValueError: ``do_sample`` is ``True`` but ``logits_warper`` is not a ``LogitsProcessorList``.
    """
    cfg = generation_config
    pad_token_id = cfg._pad_token_tensor
    want_attentions = cfg.output_attentions
    want_hidden = cfg.output_hidden_states
    want_scores = cfg.output_scores
    want_logits = cfg.output_logits
    as_dict = cfg.return_dict_in_generate
    pads_finished_rows = any(hasattr(criterion, "eos_token_id") for criterion in stopping_criteria)
    do_sample = cfg.do_sample
    if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
        raise ValueError(
            "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
            f"{logits_warper})."
        )

    is_enc_dec = self.config.is_encoder_decoder

    # Per-step records; kept as (possibly empty) tuples rather than None.
    scores: tuple = ()
    raw_logits: tuple = ()
    decoder_attentions: tuple = ()
    cross_attentions: tuple = ()
    decoder_hidden_states: tuple = ()

    # Encoder-side outputs are fixed for the whole generation; read them once.
    if as_dict and is_enc_dec:
        encoder_attentions = None
        encoder_hidden_states = None
        if want_attentions:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions")
        if want_hidden:
            encoder_hidden_states = model_kwargs["encoder_outputs"].get("hidden_states")

    device = input_ids.device
    peer_done = False
    still_running = torch.ones(input_ids.shape[0], dtype=torch.long, device=device)
    model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

    while self._has_unfinished_sequences(peer_done, synced_gpus, device=device):
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        # Only pass output controls that were requested; some models reject unknown ones.
        if want_attentions:
            model_inputs.update({"output_attentions": want_attentions})
        if want_hidden:
            model_inputs.update({"output_hidden_states": want_hidden})

        outputs = self(**model_inputs, return_dict=True)

        if synced_gpus and peer_done:
            # Forward pass kept only so the other peers stay in sync.
            continue

        # Clone so the (potentially huge, first-step) full logits tensor is not kept alive.
        step_logits = outputs.logits[:, -1, :].clone()
        step_scores = logits_processor(input_ids, step_logits)
        if do_sample and logits_warper is not None:
            step_scores = logits_warper(input_ids, step_scores)

        if as_dict:
            if want_scores:
                scores += (step_scores,)
            if want_logits:
                raw_logits += (step_logits,)
            if want_attentions:
                step_attn = outputs.decoder_attentions if is_enc_dec else outputs.attentions
                decoder_attentions += (step_attn,)
                if is_enc_dec:
                    cross_attentions += (outputs.cross_attentions,)
            if want_hidden:
                step_hidden = outputs.decoder_hidden_states if is_enc_dec else outputs.hidden_states
                decoder_hidden_states += (step_hidden,)

        if not do_sample:
            chosen = torch.argmax(step_scores, dim=-1)
        else:
            probabilities = nn.functional.softmax(step_scores, dim=-1)
            chosen = torch.multinomial(probabilities, num_samples=1).squeeze(1)

        # Rows that already finished keep emitting the pad token.
        if pads_finished_rows:
            chosen = chosen * still_running + pad_token_id * (1 - still_running)

        input_ids = torch.cat([input_ids, chosen[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(chosen.cpu())
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=is_enc_dec,
        )

        still_running = still_running & ~stopping_criteria(input_ids, scores)
        peer_done = int(still_running.max()) == 0

        # Drop the reference so outputs.logits can be freed before the next forward pass.
        del outputs

    if streamer is not None:
        streamer.end()

    if not as_dict:
        return input_ids

    cache = model_kwargs.get("past_key_values")
    if is_enc_dec:
        return GenerateEncoderDecoderOutput(
            sequences=input_ids,
            scores=scores,
            logits=raw_logits,
            encoder_attentions=encoder_attentions,
            encoder_hidden_states=encoder_hidden_states,
            decoder_attentions=decoder_attentions,
            cross_attentions=cross_attentions,
            decoder_hidden_states=decoder_hidden_states,
            past_key_values=cache,
        )
    return GenerateDecoderOnlyOutput(
        sequences=input_ids,
        scores=scores,
        logits=raw_logits,
        attentions=decoder_attentions,
        hidden_states=decoder_hidden_states,
        past_key_values=cache,
    )
|
|
452
|
+
|
|
453
|
+
def _sample_streaming_unit(
    self,
    input_ids: torch.Tensor,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool,
    streamer: Optional["BaseStreamer"],
    streamer_unit: Optional["BaseStreamer"],
    logits_warper: Optional[LogitsProcessorList],
    **model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
    """Sample/greedy-decode text while streaming speech units derived from the hidden states.

    Text decoding works exactly like plain sampling/greedy search. When ``streamer_unit`` is
    given, each step also concatenates the last-layer hidden states recorded so far. It runs
    them through ``self.speech_generator.predict`` and collapses the prediction with
    ``ctc_postprocess`` (blank id ``self.model.config.unit_vocab_size``). Units not yet
    emitted are then pushed to ``streamer_unit`` one at a time.

    Args:
        input_ids: Prompt ids; ``input_ids.shape[0]`` is the batch size.
        logits_processor: Processors applied to the raw next-token logits.
        stopping_criteria: Criteria evaluated after every step.
        generation_config: Supplies ``do_sample``, the ``output_*`` flags,
            ``return_dict_in_generate`` and the pad-token tensor.
        synced_gpus: Forwarded to ``_has_unfinished_sequences``. When set, a finished peer
            keeps running forward passes but skips the rest of the step.
        streamer: Optional text streamer; receives each new token batch (on CPU) and ``end()``.
        streamer_unit: Optional unit streamer. It receives each new unit as a 1-element tensor
            and ``end()`` once generation finishes.
        logits_warper: Applied after ``logits_processor`` when sampling. It must be a
            ``LogitsProcessorList`` if ``do_sample is True``.
        **model_kwargs: Extra model inputs, updated after each step.

    Returns:
        ``GenerateEncoderDecoderOutput`` or ``GenerateDecoderOnlyOutput`` when
        ``return_dict_in_generate`` is set, otherwise the final ``input_ids``.

    Raises:
        ValueError: ``do_sample`` is ``True`` but ``logits_warper`` is not a
            ``LogitsProcessorList``. Also raised when ``streamer_unit`` is given but
            ``return_dict_in_generate`` and ``output_hidden_states`` are not both enabled,
            because the per-step hidden states needed for unit prediction would not be recorded.
    """
    # init values
    pad_token_id = generation_config._pad_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate
    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
    do_sample = generation_config.do_sample
    if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
        raise ValueError(
            "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
            f"{logits_warper})."
        )

    # Unit prediction reads decoder_hidden_states, which is only filled below when both flags
    # are on. Fail early with a clear message rather than an IndexError inside the loop.
    if streamer_unit is not None and not (return_dict_in_generate and output_hidden_states):
        raise ValueError(
            "Streaming speech units requires `return_dict_in_generate=True` and "
            "`output_hidden_states=True` in the generation config."
        )

    # Per-step records; kept as (possibly empty) tuples rather than None.
    scores: tuple = ()
    raw_logits: tuple = ()
    decoder_attentions: tuple = ()
    cross_attentions: tuple = ()
    decoder_hidden_states: tuple = ()

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # keep track of which sequences are already finished
    batch_size = input_ids.shape[0]
    this_peer_finished = False
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
    model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

    # Number of units already pushed to streamer_unit.
    num_units_streamed = 0
    while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        # prepare model inputs
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # prepare variable output controls (note: some models won't accept all output controls)
        model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
        model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

        # forward pass to get next token
        outputs = self(**model_inputs, return_dict=True)

        if synced_gpus and this_peer_finished:
            continue  # don't waste resources running the code we don't need

        # Clone is needed to avoid keeping a hanging ref to outputs.logits
        # which may be very large for first iteration (the clone itself is always small)
        next_token_logits = outputs.logits[:, -1, :].clone()

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)
        if do_sample and logits_warper is not None:
            next_token_scores = logits_warper(input_ids, next_token_scores)

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_logits:
                raw_logits += (next_token_logits,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,)
                )

        # token selection
        if do_sample:
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(next_token_scores, dim=-1)

        # finished sentences should have their next token be a padding token
        if has_eos_stopping_criteria:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())

        # Speech generation: units are only computed when someone consumes them.
        if streamer_unit is not None:
            # Step 0 contributes only the last prompt position of its last layer. Later
            # steps contribute their last-layer hidden states as returned.
            hidden_states = torch.cat(
                [decoder_hidden_states[0][-1][:, -1:, :]]
                + [step_states[-1] for step_states in decoder_hidden_states[1:]],
                dim=1,
            )
            # NOTE(review): squeeze(0) presumes a batch size of 1 — confirm with callers.
            ctc_pred = self.speech_generator.predict(hidden_states.squeeze(0))
            cur_units = ctc_postprocess(ctc_pred, blank=self.model.config.unit_vocab_size)
            # The units are re-predicted from scratch every step; emit only the new tail.
            for unit in cur_units[num_units_streamed:]:
                streamer_unit.put(unit.unsqueeze(0))
            num_units_streamed = len(cur_units)

        model_kwargs = self._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=self.config.is_encoder_decoder,
        )

        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = bool(int(unfinished_sequences.max()) == 0)

        # This is needed to properly delete outputs.logits which may be very large for first iteration
        # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
        del outputs

    if streamer is not None:
        streamer.end()
    # Signal end-of-stream to the unit consumer as well, mirroring the text streamer.
    if streamer_unit is not None:
        streamer_unit.end()

    if return_dict_in_generate:
        if self.config.is_encoder_decoder:
            return GenerateEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
    else:
        return input_ids
|
|
616
|
+
|
|
617
|
+
|
|
618
|
+
def ctc_postprocess(tokens, blank):
    """Collapse a greedy CTC prediction into a sequence of unit ids.

    Runs of identical consecutive ids are merged first, then every ``blank`` id is dropped.
    For example, ``[[5, 5, 0, 3, 3, 0, 3]]`` with ``blank=0`` becomes ``[5, 3, 3]``.

    Args:
        tokens: Tensor of predicted ids; a leading dimension of size 1 is squeezed away.
        blank: Id of the CTC blank symbol.

    Returns:
        A CPU ``torch.long`` tensor of the collapsed ids. It may be empty, e.g. for an
        all-blank prediction.
    """
    # squeeze(0) turns a single-element 1-D tensor into a 0-d tensor whose tolist() is a bare
    # int (not iterable), so promote it back to at least 1-D.
    ids = torch.atleast_1d(tokens.squeeze(0)).tolist()
    deduplicated = [v for i, v in enumerate(ids) if i == 0 or v != ids[i - 1]]
    # Explicit dtype: torch.tensor([]) defaults to float32, so an all-blank prediction used to
    # come back as a float tensor while non-empty results were int64.
    return torch.tensor([v for v in deduplicated if v != blank], dtype=torch.long)
|