sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +26 -4
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +676 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +49 -8
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +42 -8
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +78 -13
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +133 -55
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +434 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +41 -19
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +60 -20
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +80 -53
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +25 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -19
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +78 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +87 -33
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +67 -30
- sglang/srt/lora/mem_pool.py +117 -52
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +18 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +43 -5
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -30
- sglang/srt/managers/scheduler.py +290 -31
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -24
- sglang/srt/managers/tp_worker.py +4 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +36 -21
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +75 -8
- sglang/srt/model_loader/loader.py +171 -3
- sglang/srt/model_loader/weight_utils.py +51 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +31 -88
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +329 -73
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +694 -0
- sglang/srt/models/gemma3_mm.py +468 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +201 -104
- sglang/srt/openai_api/protocol.py +33 -7
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +114 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +140 -54
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +215 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +29 -2
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +56 -5
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1995 @@
# Copied and adapted from: https://huggingface.co/openbmb/MiniCPM-o-2_6/blob/main/modeling_minicpmo.py

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only MiniCPM-o model compatible with HuggingFace weights."""

import math
from dataclasses import dataclass
from typing import Any, Iterable, List, Literal, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
import torch.nn.utils.parametrize as P
import torch.types
from torch import nn
from torch.nn.utils import weight_norm
from tqdm import tqdm
from transformers import LlamaConfig, LlamaModel, PretrainedConfig, PreTrainedModel
from transformers.activations import ACT2FN
from transformers.cache_utils import DynamicCache, EncoderDecoderCache
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
from transformers.models.whisper.modeling_whisper import (
    WHISPER_ATTENTION_CLASSES,
    WhisperConfig,
    WhisperEncoder,
)

from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.managers.mm_utils import (
    MultiModalityDataPaddingPatternTokenPairs,
    embed_mm_inputs,
    get_multimodal_data_bounds,
)
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.minicpmv import (
    Idefics2VisionTransformer,
    MiniCPMVBaseModel,
    Resampler2_5,
)
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.utils import logger

try:
    from transformers import LogitsWarper
    from vector_quantize_pytorch import GroupedResidualFSQ
    from vocos import Vocos
    from vocos.pretrained import instantiate_class

    _tts_deps = True
except:
    LogitsWarper = None
    _tts_deps = False


def apply_spk_emb(
    input_ids: torch.Tensor = None,
    spk_emb: torch.Tensor = None,
    input_embeds: torch.Tensor = None,
    spk_emb_token_id: int = 0,
    num_spk_embs: int = 1,
):
    """
    Replace consecutive `num_spk_embs` speaker embedding placeholders in input_embeds with pre-prepared speaker embeddings. This is an in-place replacement, no new tensor is created, so no value is returned.

    Args:
        input_ids (torch.Tensor): Input ID tensor, shape [batch_size, seq_len_max]
        spk_emb (torch.Tensor): Speaker embedding tensor, shape [batch_size, num_spk_emb, hidden_dim]
        input_embeds (torch.Tensor): Input embedding tensor, shape [batch_size, seq_len_max, hidden_dim]
        spk_emb_token_id (int): ID of the speaker embedding token
        num_spk_embs (int): Number of speaker embeddings

    Returns:
        None
    """

    batch_size = input_ids.shape[0]

    for idx in range(batch_size):
        input_ids_ = input_ids[idx]  # [seq_len_max]
        spk_emb_ = spk_emb[idx]  # [num_spk_emb]
        mask_ = input_ids_ == spk_emb_token_id  # [batch_size, seq_len_max]
        nonzero_position_idx = mask_.nonzero(as_tuple=False)  # [num_spk_emb, 1]
        assert nonzero_position_idx.shape[0] == num_spk_embs
        begin_idx = nonzero_position_idx.min()
        end_idx = nonzero_position_idx.max()
        input_embeds[idx, begin_idx : end_idx + 1, :] = spk_emb_

    return

@dataclass
class ConditionalChatTTSGenerationOutput(ModelOutput):
    """
    Output class for ConditionalChatTTS generation.

    Args:
        new_ids (torch.LongTensor): Newly generated audio code sequence, shape (batch_size, sequence_length, num_vq).
        audio_input_ids (torch.LongTensor): Updated input IDs including condition and generated audio codes, shape (batch_size, full_sequence_length, num_vq).
        past_key_values (Tuple[Tuple[torch.FloatTensor]]): Tuple containing pre-computed keys and values used for attention mechanism. Each element has shape (batch_size, num_heads, sequence_length, embed_size_per_head).
        finished (bool): Boolean indicating whether generation is complete.

    """

    new_ids: torch.LongTensor = None
    audio_input_ids: torch.LongTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    finished: bool = None


def make_streaming_chunk_mask_generation(
    inputs_embeds: torch.Tensor,
    past_seen_tokens: int,
    streaming_tts_text_mask: torch.Tensor,
    streaming_reserved_length: int = 300,
    streaming_audio_chunk_size: int = 50,
    streaming_text_chunk_size: int = 10,
    num_spk_emb: int = 1,
    use_spk_emb: bool = True,
) -> torch.Tensor:
    """
    In streaming audio generation, determine which `text` positions the TTS model can attend to when generating each chunk of `audio` tokens.

    This function creates a mask that allows the model to attend to a specific chunk of text
    tokens when generating each chunk of audio tokens, enabling streaming TTS generation.

    Args:
        inputs_embeds (torch.Tensor): Input embeddings tensor.
        past_seen_tokens (int): Number of tokens already seen by the model.
        streaming_tts_text_mask (torch.Tensor): Mask for the text tokens.
        streaming_reserved_length (int, optional): Number of reserved tokens for streaming. Defaults to 300.
        streaming_text_chunk_size (int, optional): Size of each text chunk. Defaults to 7.

    Returns:
        torch.Tensor: Causal mask for streaming TTS generation, shape is [batch_size=1, 1, seq_len=1, past_seen_tokens+1]

    Raises:
        AssertionError: If the batch size is not 1 (only supports batch size of 1 for inference).
    """
    assert inputs_embeds.shape[0] == 1

    dtype = inputs_embeds.dtype
    device = inputs_embeds.device
    min_dtype = torch.finfo(dtype).min

    # Add `1` to the past seen tokens to account for new `tokens` during `generate`
    causal_mask = torch.full(
        (1, past_seen_tokens + inputs_embeds.shape[1]),
        fill_value=0,
        dtype=dtype,
        device=device,
    )

    # Calculate the start of invisible text tokens
    invisible_text_tokens_start = (
        min(
            math.ceil(
                (past_seen_tokens - streaming_reserved_length)
                / streaming_audio_chunk_size
            )
            * streaming_text_chunk_size,
            streaming_reserved_length,
        )
        + 1
        + num_spk_emb * use_spk_emb
    )  # Add 1 for [Stts] and N for [spk_emb] tokens if `use_spk_emb` is True

    invisible_text_tokens_end = (
        streaming_reserved_length + 1 + num_spk_emb * use_spk_emb + 1
    )  # Add 1 for [Ptts] (aka `audio_bos_token_id`)

    # Set invisible text tokens to min_dtype (effectively -inf)
    causal_mask[0, invisible_text_tokens_start:invisible_text_tokens_end] = min_dtype

    # Mask padding positions in the text mask
    causal_mask[
        0, 0 : 1 + num_spk_emb * use_spk_emb + streaming_reserved_length + 1
    ].masked_fill_(streaming_tts_text_mask == 0, min_dtype)

    # Add extra dimensions for batch and heads
    causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)

    return causal_mask
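

# [Editor's note] The helper below is an illustrative sketch added for this diff
# review, not part of the upstream file. It isolates the chunk-visibility
# arithmetic used above: with the default streaming_reserved_length=300,
# streaming_audio_chunk_size=50 and streaming_text_chunk_size=10, a model that
# has already seen 352 tokens gets ceil((352 - 300) / 50) * 10 = 20 reserved
# text positions visible; the remaining reserved positions stay masked until
# more audio chunks have been generated.
def _editor_example_visible_text_len(
    past_seen_tokens: int,
    streaming_reserved_length: int = 300,
    streaming_audio_chunk_size: int = 50,
    streaming_text_chunk_size: int = 10,
) -> int:
    # Mirrors the `invisible_text_tokens_start` computation above, before the
    # [Stts] / [spk_emb] offsets are added.
    return min(
        math.ceil(
            (past_seen_tokens - streaming_reserved_length) / streaming_audio_chunk_size
        )
        * streaming_text_chunk_size,
        streaming_reserved_length,
    )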


# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py`
class ConvNeXtBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        kernel: int,
        dilation: int,
        layer_scale_init_value: float = 1e-6,
    ):
        # ConvNeXt Block copied from Vocos.
        super().__init__()
        self.dwconv = nn.Conv1d(
            dim,
            dim,
            kernel_size=kernel,
            padding=dilation * (kernel // 2),
            dilation=dilation,
            groups=dim,
        )

        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        self.coef = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )

    def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor:
        residual = x

        y = self.dwconv(x)
        y.transpose_(1, 2)  # (B, C, T) -> (B, T, C)
        x = self.norm(y)
        del y
        y = self.pwconv1(x)
        del x
        x = self.act(y)
        del y
        y = self.pwconv2(x)
        del x
        if self.coef is not None:
            y *= self.coef
        y.transpose_(1, 2)  # (B, T, C) -> (B, C, T)

        x = y + residual
        del y

        return x


# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py`
class DVAEDecoder(nn.Module):
    def __init__(
        self,
        idim: int,
        odim: int,
        n_layer=12,
        bn_dim=64,
        hidden=256,
        kernel=7,
        dilation=2,
        up=False,
    ):
        super().__init__()
        self.up = up
        self.conv_in = nn.Sequential(
            nn.Conv1d(idim, bn_dim, 3, 1, 1),
            nn.GELU(),
            nn.Conv1d(bn_dim, hidden, 3, 1, 1),
        )
        self.decoder_block = nn.ModuleList(
            [
                ConvNeXtBlock(
                    hidden,
                    hidden * 4,
                    kernel,
                    dilation,
                )
                for _ in range(n_layer)
            ]
        )
        self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)

    def forward(self, x: torch.Tensor, conditioning=None) -> torch.Tensor:
        # B, C, T
        y = self.conv_in(x)
        del x
        for f in self.decoder_block:
            y = f(y, conditioning)

        x = self.conv_out(y)
        del y
        return x

# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py`
class GFSQ(nn.Module):
    def __init__(
        self,
        dim: int,
        levels: List[int],
        G: int,
        R: int,
        eps=1e-5,
        transpose=True,
    ):
        super(GFSQ, self).__init__()
        self.quantizer = GroupedResidualFSQ(
            dim=dim,
            levels=list(levels),
            num_quantizers=R,
            groups=G,
        )
        self.n_ind = math.prod(levels)
        self.eps = eps
        self.transpose = transpose
        self.G = G
        self.R = R

    def _embed(self, x: torch.Tensor):
        if self.transpose:
            x = x.transpose(1, 2)
        x = x.view(x.size(0), x.size(1), self.G, self.R).permute(2, 0, 1, 3)
        feat = self.quantizer.get_output_from_indices(x)
        return feat.transpose_(1, 2) if self.transpose else feat

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return super().__call__(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.transpose:
            x.transpose_(1, 2)
        _, ind = self.quantizer(x)
        ind = ind.permute(1, 2, 0, 3).contiguous()
        ind = ind.view(ind.size(0), ind.size(1), -1)
        return ind.transpose_(1, 2) if self.transpose else ind


# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py`
class DVAE(nn.Module):
    def __init__(
        self,
    ):
        super().__init__()

        coef = torch.rand(100)
        self.coef = nn.Parameter(coef.unsqueeze(0).unsqueeze_(2))

        self.downsample_conv = nn.Sequential(
            nn.Conv1d(100, 512, 3, 1, 1),
            nn.GELU(),
            nn.Conv1d(512, 512, 4, 2, 1),
            nn.GELU(),
        )

        self.encoder = DVAEDecoder(
            idim=512,
            odim=1024,
            hidden=256,
            n_layer=12,
            bn_dim=128,
        )

        self.decoder = DVAEDecoder(
            idim=512,
            odim=512,
            hidden=256,
            n_layer=12,
            bn_dim=128,
        )

        self.out_conv = nn.Conv1d(512, 100, 3, 1, 1, bias=False)

        self.vq_layer = GFSQ(
            dim=1024,
            levels=(5, 5, 5, 5),
            G=2,
            R=2,
        )

    @torch.inference_mode()
    def forward(
        self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode"
    ) -> torch.Tensor:
        if mode == "encode" and hasattr(self, "encoder") and self.vq_layer is not None:
            mel = inp.clone()
            x: torch.Tensor = self.downsample_conv(
                torch.div(mel, self.coef.view(100, 1).expand(mel.shape), out=mel),
            ).unsqueeze_(0)
            del mel
            x = self.encoder(x)
            ind = self.vq_layer(x)
            del x
            return ind

        if self.vq_layer is not None:
            vq_feats = self.vq_layer._embed(inp)
        else:
            vq_feats = inp

        vq_feats = (
            vq_feats.view(
                (vq_feats.size(0), 2, vq_feats.size(1) // 2, vq_feats.size(2)),
            )
            .permute(0, 2, 3, 1)
            .flatten(2)
        )

        dec_out = self.out_conv(
            self.decoder(
                x=vq_feats,
            ),
        )

        del vq_feats

        return torch.mul(dec_out, self.coef, out=dec_out)
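

# [Editor's note] The following is an illustrative sketch added for this diff
# review, not part of the upstream file. It records the shape contract of the
# DVAE above as used later by `decode_to_mel_specs`: 100 mel bins, 2x temporal
# downsampling on encode, and G * R = 4 codebooks from GFSQ. The helper name
# and the `mel` argument are hypothetical.
def _editor_example_dvae_roundtrip(dvae: "DVAE", mel: torch.Tensor) -> torch.Tensor:
    # mel: [100, T] log-mel spectrogram (unbatched, T even).
    codes = dvae(mel, mode="encode")  # -> [1, 4, T // 2] discrete audio codes
    # Feeding the codes back through the default "decode" path upsamples 2x in
    # time and projects back to 100 mel bins.
    return dvae(codes, mode="decode")  # -> [1, 100, T] reconstructed mel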


# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/processors.py`
class CustomRepetitionPenaltyLogitsProcessorRepeat:
    def __init__(self, penalty: float, max_input_ids: int, past_window: int):
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(
                f"`penalty` has to be a strictly positive float, but is {penalty}"
            )

        self.penalty = penalty
        self.max_input_ids = max_input_ids
        self.past_window = past_window

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        if input_ids.size(1) > self.past_window:
            input_ids = input_ids.narrow(1, -self.past_window, self.past_window)
        freq = F.one_hot(input_ids, scores.size(1)).sum(1)
        if freq.size(0) > self.max_input_ids:
            freq.narrow(
                0, self.max_input_ids, freq.size(0) - self.max_input_ids
            ).zero_()
        alpha = torch.pow(self.penalty, freq)
        scores = scores.contiguous()
        inp = scores.multiply(alpha)
        oth = scores.divide(alpha)
        con = scores < 0
        out = torch.where(con, inp, oth)
        del inp, oth, scores, con, alpha
        return out
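

# [Editor's note] Illustrative sketch added for this diff review, not part of
# the upstream file: the processor above pushes already-emitted audio codes
# toward lower probability by scaling their logits with alpha = penalty ** freq.
# The penalty, code ids, and logit values below are hypothetical.
def _editor_example_repetition_penalty() -> torch.Tensor:
    processor = CustomRepetitionPenaltyLogitsProcessorRepeat(
        penalty=1.05, max_input_ids=626, past_window=16
    )
    input_ids = torch.tensor([[2, 2, 5]])  # previously generated codes in the window
    scores = torch.zeros(1, 8)
    scores[0, 2], scores[0, 5] = 4.0, -1.0  # logits before the penalty
    # Code 2 (positive logit, emitted twice) becomes 4.0 / 1.05**2; code 5
    # (negative logit, emitted once) becomes -1.0 * 1.05. Both move toward
    # "less likely".
    return processor(input_ids, scores)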


class ConditionalChatTTS(PreTrainedModel):
    """A conditional text-to-speech model that can generate speech from text with speaker conditioning.

    This model extends PreTrainedModel to provide text-to-speech capabilities with:
    - LLM hidden state conditioning
    - Streaming generation

    The model uses a transformer architecture with LLM hidden states and can operate in both
    streaming and non-streaming modes for flexible deployment.

    The model processes sequences in the following format:
    | text bos token | LLM embedding projected to tts embedding space | text tokens (fixed length, reserved for future tokens) | audio bos token | audio tokens (audio token length is not fixed) | audio eos token |

    The format is designed to support LLM-conditioned streaming audio generation.

    Usage:
    To support streaming generation, two global variables should be maintained outside of the model.
    1. `audio_input_ids`: stores *discrete* audio codes. It is a tensor with shape [1, sequence length+1, num_vq].
    2. `past_key_values`: stores the KV cache for both text tokens and audio codes. It is a list of tuples, each tuple contains two tensors with shape [1, num_attention_heads, sequence length, hidden_size // num_attention_heads]

    where `num_vq` is the number of audio codebooks, in default setting, it is `4`.

    1. Create an empty `past_key_values` with
    ```python
    initial_kv_cache_length = 1 + model.num_spk_embs + model.streaming_text_reserved_len # where `1` denotes the `bos` token
    dtype = model.emb_text.weight.dtype
    device = model.emb_text.weight.device
    past_key_values = [
        (
            torch.zeros(1, model.config.num_attention_heads, initial_kv_cache_length, model.config.hidden_size // model.config.num_attention_heads, dtype=dtype, device=device),
            torch.zeros(1, model.config.num_attention_heads, initial_kv_cache_length, model.config.hidden_size // model.config.num_attention_heads, dtype=dtype, device=device)
        )
        for _ in range(model.config.num_hidden_layers)
    ]
    ```

    2. At the same time, create an empty `audio_input_ids` with shape [1, sequence length, num_vq], `num_vq` denotes multiple layer audio codebooks. But here we also include text tokens in the sequence, but they will be zeros, and will not be used, just a placeholder.

    ```python
    initial_audio_input_ids_length = 1 + model.num_spk_embs + model.streaming_text_reserved_len + 1
    # [bos token, speaker embeddings, text tokens, audio bos token]
    audio_input_ids = torch.zeros(batch_size=1, initial_audio_input_ids_length, model.num_vq)
    ```

    2. Prefill some text tokens to the TTS model (for example, 10 tokens) using the `prefill_text` method.

    ```python
    outputs = llm.generate(**kwargs)
    llm_tokens = some_function_to_extract_llm_tokens(outputs)
    lm_spk_emb_last_hidden_states = some_function_to_extract_lm_spk_emb_last_hidden_states(outputs)
    tts_text_input_ids = tts_tokenizer.encode(llm_tokenizer.decode(llm_tokens))
    # here assume we are prefilling text token 0 to text token 9 (included), totally 10 tokens.
    begin = 0
    end = 9+1
    position_ids = torch.arange(begin, end, dtype=torch.long, device=device)

    past_key_values = model.prefill_text(
        input_ids=tts_text_input_ids,
        position_ids=position_ids,
        past_key_values=past_key_values,
        lm_spk_emb_last_hidden_states=lm_spk_emb_last_hidden_states,
    )
    ```

    3. Make a `streaming_tts_text_mask` to denote which position contains valid text tokens, similar to `attention_mask` in standard causal attention.

    ```python
    streaming_tts_text_mask = torch.zeros(model.streaming_reserved_length)
    streaming_tts_text_mask[0:end] = 1 # denotes these post
    ```

    3. Generate audio codes using the `generate` method.

    ```python
    outputs = model.generate(
        input_ids=audio_input_ids,
        past_key_values=past_key_values,
        streaming_tts_text_mask=streaming_tts_text_mask,
        max_new_token=50,
    )

    # update past_key_values and input_ids
    past_key_values = outputs.past_key_values
    audio_input_ids = outputs.input_ids
    ```

    The `past_key_values` is extended by `max_new_token=50`, and `audio_input_ids` is also extended by `max_new_token=50` after calling `generate`.

    4. Notice that after prefilling `10` text tokens, the model can generate up to `50` audio tokens, if you want to generate more audio tokens, you need to prefill the next `10` text tokens. And it is okay to only generate `25` audio tokens for faster initial response.

    5. Repeat steps `2,3,4` as needed in your streaming audio generation cases, but ensure usage complies with the guidelines discussed above.
    """

    config_class = PretrainedConfig
    _no_split_modules = []

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)

        self.use_speaker_embedding = config.use_speaker_embedding
        self.use_llm_hidden_state = config.use_llm_hidden_state
        self.num_spk_embs = config.num_spk_embs
        self.spk_emb_token_id = config.spk_emb_token_id

        self.use_text = config.use_text
        self.streaming = config.streaming
        self.streaming_text_chunk_size = config.streaming_text_chunk_size
        self.streaming_audio_chunk_size = config.streaming_audio_chunk_size
        self.streaming_text_reserved_len = config.streaming_text_reserved_len
        self.audio_bos_token_id = config.audio_bos_token_id
        self.num_mel_bins = config.num_mel_bins
        self.num_vq = config.num_vq
        self.num_audio_tokens = config.num_audio_tokens

        self.top_p = config.top_p
        self.top_k = config.top_k
        self.repetition_penalty = config.repetition_penalty

        if self.config.use_mlp:
            self.projector = MultiModalProjector(config.llm_dim, config.hidden_size)
        else:
            self.projector = nn.Linear(config.llm_dim, config.hidden_size, bias=False)
        self.emb_code = nn.ModuleList(
            [
                nn.Embedding(config.num_audio_tokens, config.hidden_size)
                for _ in range(config.num_vq)
            ]
        )
        self.emb_text = nn.Embedding(config.num_text_tokens, config.hidden_size)
        self.head_code = nn.ModuleList(
            [
                weight_norm(
                    nn.Linear(config.hidden_size, config.num_audio_tokens, bias=False),
                    name="weight",
                )
                for _ in range(config.num_vq)
            ]
        )

        dvae = DVAE()
        self.dvae = dvae

        model_config = LlamaConfig(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            num_attention_heads=config.num_attention_heads,
            num_hidden_layers=config.num_hidden_layers,
            max_position_embeddings=config.max_position_embeddings,
            attn_implementation=config.attn_implementation,
        )

        model = LlamaModel(model_config)
        self.model = model

    @torch.inference_mode()
    def merge_inputs_embeds(
        self,
        input_ids: torch.Tensor,
        lm_spk_emb_last_hidden_states: Optional[torch.Tensor] = None,
    ):
        """Merge `input_ids` and `lm_spk_emb_last_hidden_states` to `inputs_embeds`.

        Args:
            input_ids (torch.Tensor): Input token IDs.
            lm_spk_emb_last_hidden_states (Optional[torch.Tensor], optional): Last hidden states of speaker embeddings from the language model. Defaults to None.

        Raises:
            NotImplementedError: If speaker embedding is not used and language model hidden states are not implemented.

        Returns:
            torch.Tensor: Prepared input embeddings for the model.
        """
        assert input_ids.shape[0] == 1

        # Embed input_ids to input_embeds
        inputs_embeds = self.emb_text(input_ids)

        # Inject speaker embedding to input_embeds if it exists
        if self.use_speaker_embedding:
            spk_emb_mask = input_ids == self.spk_emb_token_id
            if spk_emb_mask.any():
                assert lm_spk_emb_last_hidden_states is not None
                # Project spk emb to tts hidden size first, [batch_size, num_spk_emb, llm_dim] -> [batch_size, num_spk_emb, self.hidden_size]
                lm_spk_emb_last_hidden_states = lm_spk_emb_last_hidden_states.to(
                    self.projector.linear1.weight.dtype
                )
                projected_spk_emb = self.projector(lm_spk_emb_last_hidden_states)
                projected_spk_emb = F.normalize(projected_spk_emb, p=2, dim=-1)
                apply_spk_emb(
                    input_ids=input_ids,
                    spk_emb=projected_spk_emb,
                    input_embeds=inputs_embeds,
                    spk_emb_token_id=self.spk_emb_token_id,
                    num_spk_embs=self.num_spk_embs,
                )
        else:
            raise NotImplementedError

        return inputs_embeds

    @torch.inference_mode()
    def prefill_text(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.LongTensor,
        past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
        lm_spk_emb_last_hidden_states: Optional[torch.Tensor] = None,
    ):
        """Prefill a chunk of new text tokens in streaming setting.
        Specifically speaking, update `past_key_values` using new text tokens, then the model will read the new text tokens.

        Args:
            input_ids (Tensor): Tensor of shape [batch_size, seq_len]
            position_ids (LongTensor): Tensor of shape [batch_size, seq_len]
            past_key_values (List[Tuple[Tensor]]): KV Cache of all layers, each layer is a tuple (Tensor, Tensor) denoting keys and values. Each tensor is of seq_len = `self.streaming_text_reserved_len`. `past_key_values` will be updated.
            lm_spk_emb_last_hidden_states (Tensor, optional): Tensor of shape [batch_size, num_spk_emb, llm_dim]. Defaults to None.

        Note that all `batch_size` should be `1`.
        """
        assert input_ids.shape[0] == 1
        assert past_key_values is not None

        # Merge text and LLM embeddings
        inputs_embeds = self.merge_inputs_embeds(
            input_ids=input_ids,
            lm_spk_emb_last_hidden_states=lm_spk_emb_last_hidden_states,
        )

        # Clone KV Cache
        past_key_values_for_prefill = []
        for i in range(len(past_key_values)):
            past_key_values_for_prefill.append(
                (
                    past_key_values[i][0][:, :, : position_ids[:, 0], :].clone(),
                    past_key_values[i][1][:, :, : position_ids[:, 0], :].clone(),
                )
            )

        # ModelMiniCPMVBaseModel
        outputs_prefill: BaseModelOutputWithPast = self.model(
            attention_mask=None,  # because for text, it is standard causal attention mask, do nothing
            position_ids=position_ids,  # position_ids denotes the position of new text tokens in the sequence
            past_key_values=past_key_values_for_prefill,  # `past_key_values` will be updated by the model
            inputs_embeds=inputs_embeds,  # contains text and language model embedding
            use_cache=True,
            output_attentions=False,
            cache_position=position_ids,  # which new positions will use this cache, basically the same as position_ids
        )

        # Get model updated KV Cache
        past_key_values_for_prefill_updated = outputs_prefill.past_key_values

        # Update generated KV Cache to input `past_key_values`
        for layer_idx in range(len(past_key_values)):
            # Update keys
            past_key_values[layer_idx][0][
                :, :, position_ids[:, 0] : position_ids[:, -1] + 1, :
            ] = past_key_values_for_prefill_updated[layer_idx][0][
                :, :, position_ids[:, 0] : position_ids[:, -1] + 1
            ].clone()
            # Update values
            past_key_values[layer_idx][1][
                :, :, position_ids[:, 0] : position_ids[:, -1] + 1, :
            ] = past_key_values_for_prefill_updated[layer_idx][1][
                :, :, position_ids[:, 0] : position_ids[:, -1] + 1
            ].clone()

        # TODO: del past_key_values_for_prefill_updated recursively
        # TODO: del outputs_prefill recursively

        return past_key_values

    @torch.inference_mode()
    def prefill_audio_ids(
        self,
        input_ids: torch.Tensor,
        past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
        streaming_tts_text_mask=None,
        add_audio_bos: bool = True,
    ):
        """Prefill a chunk of audio ids to the model. Used in sliding-window long audio generation.
        Specifically, prefill many audio ids (typically from last window) to the model in the new window.

        Args:
            input_ids (torch.Tensor): (1, seq_len, num_vq) Audio input token ids.
            past_key_values (List[Tuple[torch.Tensor, torch.Tensor]]): Past key values for attention mechanism.
        """
        assert input_ids.shape[0] == 1
        assert past_key_values is not None

        code_emb = [self.emb_code[i](input_ids[:, :, i]) for i in range(self.num_vq)]
        inputs_embeds = torch.stack(code_emb, 3).sum(3)  # [1,seq_len,768]
        input_len = input_ids.shape[1]

        if add_audio_bos:
            narrowed_input_ids = torch.tensor(
                [[self.audio_bos_token_id]], dtype=torch.long, device=self.device
            )
            bos_inputs_embeds = self.emb_text(narrowed_input_ids)
            inputs_embeds = torch.cat([bos_inputs_embeds, inputs_embeds], dim=1)
            input_len += 1

        past_key_values_length = past_key_values[0][0].shape[2]
        position_ids = torch.arange(
            past_key_values_length,
            past_key_values_length + input_len,
            dtype=torch.long,
            device=self.device,
        ).unsqueeze(0)

        cache_position = position_ids.clone()
        causal_mask = make_streaming_chunk_mask_generation(
            inputs_embeds=inputs_embeds,
            past_seen_tokens=past_key_values[0][0].shape[2],
            streaming_tts_text_mask=streaming_tts_text_mask,
            streaming_reserved_length=self.streaming_text_reserved_len,
            streaming_text_chunk_size=self.streaming_text_chunk_size,
        )  # [1, 1, 1, past_key_values_length + input_len]

        # Model forward
        outputs: BaseModelOutputWithPast = self.model(
            attention_mask=causal_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=True,
            output_attentions=False,
            cache_position=cache_position,
        )
        past_key_values = outputs.past_key_values
        return past_key_values

    @torch.inference_mode()
    def generate(
        self,
        input_ids: torch.Tensor,
        past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
        temperature: torch.Tensor,
        eos_token: Union[int, torch.Tensor],
        streaming_tts_text_mask=None,
        force_no_stop=False,
        min_new_token=10,
        max_new_token=50,
        logits_warpers: List[LogitsWarper] = [],
        logits_processors: List[CustomRepetitionPenaltyLogitsProcessorRepeat] = [],
        show_tqdm=False,
    ):
        """Generate audio codes in streaming setting or non-streaming setting.
        Specifically speaking, generate audio codes when not all text tokens are prefilled.

        Always pass a valid `past_key_values` to the method. The method does not do `prefill` by itself. It relies on `prefill_text` method to provide valid `past_key_values`. Please refer to docstring of this class for more details.

        In this method, we borrowed a lot of codes from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/gpt.py`.

        Args:
            input_ids (torch.Tensor): Input token ids.
            past_key_values (List[Tuple[torch.Tensor, torch.Tensor]]): Past key values for attention mechanism.
            temperature (torch.Tensor): Temperature for sampling.
            eos_token (Union[int, torch.Tensor]): End of sequence token.
            streaming_tts_text_mask (Optional[torch.Tensor], optional): Mask for streaming TTS text. Defaults to None.
            max_new_token (int, optional): Maximum number of new tokens to generate. Defaults to 50.
            logits_warpers (List[LogitsWarper], optional): List of logits warpers. Defaults to [].
            logits_processors (List[CustomRepetitionPenaltyLogitsProcessorRepeat], optional): List of logits processors. Defaults to [].
            show_tqdm (bool, optional): Whether to show progress bar. Defaults to True.

        Returns:
            GenerationOutputs: Generation outputs.
        """

        # We only support batch size `1` for now
        assert input_ids.shape[0] == 1
        assert past_key_values is not None

        # fix: this should not be `input_ids.shape[1]`
        # start_idx = input_ids.shape[1]
        start_idx = (
            1
            + self.num_spk_embs * self.use_speaker_embedding
            + self.streaming_text_reserved_len
            + 1
        )

        finish = torch.zeros(input_ids.shape[0], device=input_ids.device).bool()

        temperature = (
            temperature.unsqueeze(0)
            .expand(input_ids.shape[0], -1)
            .contiguous()
            .view(-1, 1)
        )

        progress = input_ids.shape[1]

        # Pre-allocate input_ids, shape is [batch_size=1, max_possible_seq_len, self.num_vqs]
        input_ids_buf = torch.zeros(
            input_ids.shape[0],  # batch_size
            progress
            + max_new_token,  # max_possible_seq_len = input_ids.shape[1] + max_new_token
            input_ids.shape[2],  # self.num_vqs
            dtype=input_ids.dtype,
            device=input_ids.device,
        )

        # Copy existing `input_ids` to `input_ids_buf`
        input_ids_buf.narrow(1, 0, progress).copy_(input_ids)

        del input_ids
        input_ids = input_ids_buf.narrow(1, 0, progress)

        pbar: Optional[tqdm] = None
        if show_tqdm:
            pbar = tqdm(
                total=max_new_token,
                desc="code",
                bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]",
            )

        condition_length = (
            1
            + self.num_spk_embs * self.use_speaker_embedding
            + self.streaming_text_reserved_len
            + 1
        )

        for i in range(max_new_token):
            # Prepare generation inputs
            audio_bos = False

            # If this is the first audio token, the case is SPECIAL
            if progress == condition_length:
                audio_bos = True

            assert progress == (
                past_key_values[0][0].shape[2] + 1
            )  # If you are using according to the guidelines, this should be passed.

            if audio_bos:
                # Generate the first token, activate the model with `self.audio_bos_token_id`, the model will predict
                # a new audio token. This is a special case because without the `audio bos token`, it is impossible
                # to generate the first audio token in our streaming setting.
                narrowed_input_ids = torch.tensor(
                    [[self.audio_bos_token_id]], dtype=torch.long, device=self.device
                )
                inputs_embeds = self.emb_text(narrowed_input_ids)
                del narrowed_input_ids
            else:
                # Generate the following audio tokens, it is applicable to all other cases, including second and the
                # following calling of `generate`.
                narrowed_input_ids = input_ids.narrow(
                    dim=1, start=input_ids.shape[1] - 1, length=1
                )
                code_emb = [
                    self.emb_code[i](narrowed_input_ids[:, :, i])
                    for i in range(self.num_vq)
                ]
                inputs_embeds = torch.stack(code_emb, 3).sum(3)

            position_ids = torch.tensor(
                [past_key_values[0][0].shape[2]], dtype=torch.long, device=self.device
            ).unsqueeze(0)

            cache_position = position_ids.clone()

            # Make causal mask
            causal_mask = make_streaming_chunk_mask_generation(
                inputs_embeds=inputs_embeds,
                past_seen_tokens=past_key_values[0][0].shape[2],
                streaming_tts_text_mask=streaming_tts_text_mask,
                streaming_reserved_length=self.streaming_text_reserved_len,
                streaming_text_chunk_size=self.streaming_text_chunk_size,
            )

            # Model forward
            outputs: BaseModelOutputWithPast = self.model(
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=True,
                output_attentions=False,
                cache_position=cache_position,
            )

            del position_ids
            del inputs_embeds
            del cache_position
            del causal_mask

            hidden_states = outputs.last_hidden_state
            past_key_values = outputs.past_key_values

            with P.cached():
                logits = torch.empty(
                    hidden_states.size(0),
                    hidden_states.size(1),
                    self.num_audio_tokens,
                    self.num_vq,
                    dtype=torch.float,
                    device=self.device,
                )
                for num_vq_iter in range(self.num_vq):
                    x: torch.Tensor = self.head_code[num_vq_iter](hidden_states)
                    logits[..., num_vq_iter] = x
                    del x

            del hidden_states

            # logits = logits[:, -1].float()
            logits = logits.narrow(1, -1, 1).squeeze_(1).float()

            # logits = rearrange(logits, "b c n -> (b n) c")
            logits = logits.permute(0, 2, 1)
            logits = logits.reshape(-1, logits.size(2))
            # logits_token = rearrange(input_ids[:, start_idx:], "b c n -> (b n) c")
            input_ids_sliced = input_ids.narrow(
                1,
                start_idx,
                input_ids.size(1) - start_idx,
            ).permute(0, 2, 1)
            logits_token = input_ids_sliced.reshape(
                input_ids_sliced.size(0) * input_ids_sliced.size(1),
                -1,
            ).to(self.device)
            del input_ids_sliced

            logits /= temperature

            if not audio_bos:
                for logitsProcessors in logits_processors:
                    logits = logitsProcessors(logits_token, logits)
            if not audio_bos:
                for logitsWarpers in logits_warpers:
                    logits = logitsWarpers(logits_token, logits)

            del logits_token

            if i < min_new_token:
                logits[:, eos_token] = -torch.inf

            if force_no_stop:
                logits[:, eos_token] = -torch.inf

            scores = F.softmax(logits, dim=-1)

            del logits
            idx_next = torch.multinomial(scores, num_samples=1)  # .to(finish.device)

            del scores

            # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
            idx_next = idx_next.view(-1, self.num_vq)
            finish_or = idx_next.eq(eos_token).any(1)
            finish.logical_or_(finish_or)

            del finish_or
            # Store new `token` into `input_ids_buf`
            input_ids_buf.narrow(1, progress, 1).copy_(idx_next.unsqueeze_(1))

            if i == 0 and finish.any():
                # raise Exception
                break

            del idx_next
            progress += 1
            input_ids = input_ids_buf.narrow(1, 0, progress)

            if finish.all():
                break

            if pbar is not None:
                pbar.update(1)

        if pbar is not None:
            pbar.close()

        if not finish.all():
            if show_tqdm:
                logger.info(f"incomplete result. hit max_new_token: {max_new_token}")

        del input_ids_buf

        if finish.all():
            # the last may contains eos token
            genrated_input_ids = input_ids[:, condition_length:-1, :]
        else:
            # there is no eos token
            genrated_input_ids = input_ids[:, condition_length:, :]

        return ConditionalChatTTSGenerationOutput(
            new_ids=genrated_input_ids,
            audio_input_ids=input_ids,  # for update purpose
            past_key_values=past_key_values,  # for update purpose
            finished=finish.all(),
        )

    @torch.inference_mode()
    def decode_to_mel_specs(
        self,
        result_list: List[torch.Tensor],
    ):
        """Decode discrete audio codes to mel spectrograms.

        Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/core.py`

        Args:
            result_list (List[torch.Tensor]): Audio codes output from `generate`.

        Returns:
            torch.Tensor: Mel spectrograms.
        """

        decoder = self.dvae
        max_x_len = -1
        if len(result_list) == 0:
            return np.array([], dtype=np.float32)
        for result in result_list:
            if result.size(0) > max_x_len:
                max_x_len = result.size(0)
        batch_result = torch.zeros(
            (len(result_list), result_list[0].size(1), max_x_len),
            dtype=result_list[0].dtype,
            device=result_list[0].device,
        )
        for i in range(len(result_list)):
            src = result_list[i]
            batch_result[i].narrow(1, 0, src.size(0)).copy_(src.permute(1, 0))
            del src

        mel_specs = decoder(batch_result)
        del batch_result
        return mel_specs
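

# [Editor's note] The sketch below is added for this diff review and is not part
# of the upstream file. It consolidates the streaming protocol described in the
# ConditionalChatTTS class docstring into one loop; `tts`, `text_chunks`,
# `spk_hidden`, `temperature`, and `eos_token` are hypothetical inputs, and the
# KV-cache, audio_input_ids, and text-mask initialization (see the class
# docstring) is assumed to have been done already.
def _editor_example_streaming_tts_loop(
    tts, text_chunks, audio_input_ids, past_key_values, streaming_tts_text_mask,
    spk_hidden, temperature, eos_token
):
    mel_chunks, prefilled = [], 0
    for chunk in text_chunks:  # each chunk: [1, streaming_text_chunk_size] text ids
        # Prefill the next text chunk into the reserved text region of the KV cache.
        position_ids = torch.arange(
            prefilled, prefilled + chunk.shape[1], dtype=torch.long, device=chunk.device
        ).unsqueeze(0)
        past_key_values = tts.prefill_text(
            input_ids=chunk,
            position_ids=position_ids,
            past_key_values=past_key_values,
            lm_spk_emb_last_hidden_states=spk_hidden if prefilled == 0 else None,
        )
        streaming_tts_text_mask[prefilled : prefilled + chunk.shape[1]] = 1
        prefilled += chunk.shape[1]

        # Generate one audio chunk conditioned on the text prefilled so far.
        outputs = tts.generate(
            input_ids=audio_input_ids,
            past_key_values=past_key_values,
            temperature=temperature,
            eos_token=eos_token,
            streaming_tts_text_mask=streaming_tts_text_mask,
            max_new_token=tts.streaming_audio_chunk_size,
        )
        audio_input_ids = outputs.audio_input_ids
        past_key_values = outputs.past_key_values

        # Decode the codes generated for this chunk into mel spectrogram frames.
        mel_chunks.append(tts.decode_to_mel_specs([outputs.new_ids[0]]))
        if outputs.finished:
            break
    return mel_chunks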


# Copied from transformers.models.whisper.modeling_whisper.WhisperEncoderLayer and add use_cache for streaming inference
class MiniCPMWhisperEncoderLayer(nn.Module):
    def __init__(self, config: WhisperConfig, layer_idx: int = None):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
            layer_idx=layer_idx,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_head_mask: torch.Tensor,
        output_attentions: bool = False,
        past_key_values: Optional[EncoderDecoderCache] = None,
        use_cache: Optional[bool] = False,
    ) -> torch.Tensor:
        r"""
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, embed_dim)`):
                Hidden states to be fed into the encoder layer.
            attention_mask (`torch.FloatTensor` of shape `(batch_size, 1, tgt_len, src_len)`):
                Attention mask where padding elements are indicated by large negative values.
            layer_head_mask (`torch.FloatTensor` of shape `(encoder_attention_heads,)`):
                Mask to nullify selected heads of the attention modules.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attention weights.
            past_key_values (`EncoderDecoderCache`, *optional*):
                Past key-value pairs used for incremental decoding.
            use_cache (`bool`, *optional*):
                Whether or not to return updated `past_key_values` for caching.

        Returns:
            A tuple of shape `(hidden_states, optional(attn_weights), optional(past_key_values))`.
        """
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, attn_weights, past_key_values = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
            past_key_value=past_key_values,
        )
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=False
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.activation_dropout, training=False
        )
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=False
        )
        hidden_states = residual + hidden_states

        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(
                hidden_states, min=-clamp_value, max=clamp_value
            )

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        if use_cache:
            outputs += (past_key_values,)

        return outputs


# Copied from transformers.models.whisper.modeling_whisper.WhisperEncoder and add use_cache for streaming inference
class MiniCPMWhisperEncoder(WhisperEncoder):

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [
                MiniCPMWhisperEncoderLayer(config, layer_idx=i)
                for i in range(config.encoder_layers)
            ]
        )

    def forward(
        self,
        input_features,
        attention_mask=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        past_key_values: Optional[EncoderDecoderCache] = None,
        use_cache: Optional[bool] = None,
    ):
        r"""
        Forward pass of the Whisper encoder.

        Args:
            input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
                Float values of log-mel features extracted from the raw audio waveform. Typically generated
                by a feature extractor (e.g., `WhisperFeatureExtractor`) that processes `.flac` or `.wav`
                files into padded 2D mel spectrogram frames. These features are projected via convolution layers
                (`conv1` and `conv2`) and then transformed into embeddings for the encoder.

            attention_mask (`torch.Tensor`, *optional*):
                Not used by Whisper for masking `input_features`, but included for API compatibility with
                other models. If provided, it is simply ignored within the model. By default, Whisper
                effectively ignores silence in the input log-mel spectrogram.

            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected attention heads. The elements should be either 1 or 0, where:
                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked** (i.e., the attention head is dropped).

            output_attentions (`bool`, *optional*):
                Whether or not to return the attention tensors of all encoder layers. If set to `True`, the
                returned tuple (or `BaseModelOutputWithPast`) will contain an additional element with
                attention weights for each encoder layer.

            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. If set to `True`, the returned
                tuple (or `BaseModelOutputWithPast`) will contain a tuple of hidden states, including the
|
1230
|
+
initial embedding output as well as the outputs of each layer.
|
1231
|
+
|
1232
|
+
return_dict (`bool`, *optional*):
|
1233
|
+
Whether or not to return a `BaseModelOutputWithPast` (a subclass of `ModelOutput`) instead
|
1234
|
+
of a plain tuple. If set to `True`, the output will be a `BaseModelOutputWithPast` object,
|
1235
|
+
otherwise it will be a tuple.
|
1236
|
+
|
1237
|
+
past_key_values (`EncoderDecoderCache`, *optional*):
|
1238
|
+
When using caching for faster inference, this is an object that stores the key-value pairs
|
1239
|
+
for attention states. If provided, the model will append new states to the existing cache
|
1240
|
+
and return the updated cache. This speeds up sequential decoding or chunked inference.
|
1241
|
+
|
1242
|
+
- If `past_key_values` is `None`, no past states are used or returned.
|
1243
|
+
- If `past_key_values` is not `None` and `use_cache=True`, the model will use the provided
|
1244
|
+
cache and return the updated cache (as `next_encoder_cache`).
|
1245
|
+
|
1246
|
+
use_cache (`bool`, *optional*):
|
1247
|
+
Whether or not the model should use caching (`past_key_values`) to speed up processing
|
1248
|
+
during inference. When set to `True`, the model will:
|
1249
|
+
- Inspect and use `past_key_values` if provided.
|
1250
|
+
- Return updated `past_key_values` (under the name `next_encoder_cache` in
|
1251
|
+
`BaseModelOutputWithPast`).
|
1252
|
+
|
1253
|
+
Returns:
|
1254
|
+
`BaseModelOutputWithPast` or `tuple` (depending on `return_dict`):
|
1255
|
+
If `return_dict=True`, a `BaseModelOutputWithPast` is returned, which contains:
|
1256
|
+
- **last_hidden_state** (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
1257
|
+
The output of the final encoder layer.
|
1258
|
+
- **hidden_states** (`tuple(torch.FloatTensor)`, *optional*, returned if `output_hidden_states=True`):
|
1259
|
+
Hidden states of the model at each layer (including the initial projection).
|
1260
|
+
- **attentions** (`tuple(torch.FloatTensor)`, *optional*, returned if `output_attentions=True`):
|
1261
|
+
Attention weights from each encoder layer.
|
1262
|
+
- **past_key_values** (an object of type `EncoderDecoderCache` or `None`, *optional*):
|
1263
|
+
Updated cache of key-value pairs if `use_cache=True`.
|
1264
|
+
|
1265
|
+
If `return_dict=False`, a tuple is returned, where the format is:
|
1266
|
+
`(last_hidden_state, hidden_states, attentions)`, with `hidden_states` and `attentions`
|
1267
|
+
only present if their respective `output_*` arguments are set to `True`.
|
1268
|
+
|
1269
|
+
"""
|
1270
|
+
output_attentions = (
|
1271
|
+
output_attentions
|
1272
|
+
if output_attentions is not None
|
1273
|
+
else self.config.output_attentions
|
1274
|
+
)
|
1275
|
+
output_hidden_states = (
|
1276
|
+
output_hidden_states
|
1277
|
+
if output_hidden_states is not None
|
1278
|
+
else self.config.output_hidden_states
|
1279
|
+
)
|
1280
|
+
return_dict = (
|
1281
|
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1282
|
+
)
|
1283
|
+
|
1284
|
+
# Ignore copy
|
1285
|
+
input_features = input_features.to(
|
1286
|
+
dtype=self.conv1.weight.dtype, device=self.conv1.weight.device
|
1287
|
+
)
|
1288
|
+
|
1289
|
+
inputs_embeds = nn.functional.gelu(self.conv1(input_features))
|
1290
|
+
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
|
1291
|
+
|
1292
|
+
inputs_embeds = inputs_embeds.permute(0, 2, 1)
|
1293
|
+
|
1294
|
+
embed_pos = self.embed_positions.weight
|
1295
|
+
past_key_values_length = 0
|
1296
|
+
if use_cache:
|
1297
|
+
if past_key_values is None:
|
1298
|
+
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
|
1299
|
+
elif isinstance(past_key_values, list):
|
1300
|
+
past_key_values = EncoderDecoderCache(
|
1301
|
+
DynamicCache.from_legacy_cache(past_key_values), DynamicCache()
|
1302
|
+
)
|
1303
|
+
elif isinstance(past_key_values, DynamicCache):
|
1304
|
+
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
|
1305
|
+
else:
|
1306
|
+
pass
|
1307
|
+
past_key_values_length = (
|
1308
|
+
past_key_values.self_attention_cache.get_usable_length(
|
1309
|
+
inputs_embeds.shape[1]
|
1310
|
+
)
|
1311
|
+
)
|
1312
|
+
if inputs_embeds.shape[1] + past_key_values_length > embed_pos.shape[0]:
|
1313
|
+
logger.warning(
|
1314
|
+
"seems the audio is longer than 30s. repeating the last part of the audio"
|
1315
|
+
)
|
1316
|
+
embed_pos_front = embed_pos[past_key_values_length:, :]
|
1317
|
+
embed_pos = torch.cat(
|
1318
|
+
(
|
1319
|
+
embed_pos_front,
|
1320
|
+
torch.repeat_interleave(
|
1321
|
+
embed_pos[-1, :].unsqueeze(0),
|
1322
|
+
inputs_embeds.shape[1]
|
1323
|
+
- embed_pos.shape[0]
|
1324
|
+
+ past_key_values_length,
|
1325
|
+
dim=0,
|
1326
|
+
),
|
1327
|
+
)
|
1328
|
+
)
|
1329
|
+
else:
|
1330
|
+
embed_pos = embed_pos[
|
1331
|
+
past_key_values_length : inputs_embeds.shape[1]
|
1332
|
+
+ past_key_values_length,
|
1333
|
+
:,
|
1334
|
+
]
|
1335
|
+
else:
|
1336
|
+
embed_pos = embed_pos[: inputs_embeds.shape[1], :]
|
1337
|
+
|
1338
|
+
hidden_states = inputs_embeds + embed_pos
|
1339
|
+
hidden_states = nn.functional.dropout(
|
1340
|
+
hidden_states, p=self.dropout, training=False
|
1341
|
+
)
|
1342
|
+
|
1343
|
+
encoder_states = () if output_hidden_states else None
|
1344
|
+
all_attentions = () if output_attentions else None
|
1345
|
+
|
1346
|
+
# check if head_mask has a correct number of layers specified if desired
|
1347
|
+
if head_mask is not None:
|
1348
|
+
assert head_mask.size()[0] == (
|
1349
|
+
len(self.layers)
|
1350
|
+
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
1351
|
+
|
1352
|
+
for idx, encoder_layer in enumerate(self.layers):
|
1353
|
+
if output_hidden_states:
|
1354
|
+
encoder_states = encoder_states + (hidden_states,)
|
1355
|
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
1356
|
+
to_drop = False
|
1357
|
+
|
1358
|
+
# Ignore copy
|
1359
|
+
if to_drop:
|
1360
|
+
layer_outputs = (None, None)
|
1361
|
+
else:
|
1362
|
+
layer_outputs = encoder_layer(
|
1363
|
+
hidden_states,
|
1364
|
+
attention_mask,
|
1365
|
+
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
1366
|
+
output_attentions=output_attentions,
|
1367
|
+
past_key_values=past_key_values,
|
1368
|
+
use_cache=use_cache,
|
1369
|
+
)
|
1370
|
+
|
1371
|
+
hidden_states = layer_outputs[0]
|
1372
|
+
|
1373
|
+
if use_cache:
|
1374
|
+
next_encoder_cache = layer_outputs[2 if output_attentions else 1]
|
1375
|
+
else:
|
1376
|
+
next_encoder_cache = None
|
1377
|
+
|
1378
|
+
if output_attentions:
|
1379
|
+
all_attentions = all_attentions + (layer_outputs[1],)
|
1380
|
+
|
1381
|
+
hidden_states = self.layer_norm(hidden_states)
|
1382
|
+
if output_hidden_states:
|
1383
|
+
encoder_states = encoder_states + (hidden_states,)
|
1384
|
+
|
1385
|
+
if not return_dict:
|
1386
|
+
return tuple(
|
1387
|
+
v
|
1388
|
+
for v in [hidden_states, encoder_states, all_attentions]
|
1389
|
+
if v is not None
|
1390
|
+
)
|
1391
|
+
return BaseModelOutputWithPast(
|
1392
|
+
last_hidden_state=hidden_states,
|
1393
|
+
hidden_states=encoder_states,
|
1394
|
+
attentions=all_attentions,
|
1395
|
+
past_key_values=next_encoder_cache,
|
1396
|
+
)
|
1397
|
+
|
1398
|
+
|
1399
|
+
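# NOTE (editor's illustrative sketch, not part of the upstream diff): the cache
# plumbing above is what enables chunked / streaming encoding. Assuming a loaded
# `MiniCPMWhisperEncoder` instance `encoder` and two consecutive mel chunks
# `chunk_a`, `chunk_b` of shape (1, 80, T) -- all hypothetical names -- usage
# might look like:
#
#     cache = None
#     for chunk in (chunk_a, chunk_b):
#         out = encoder(chunk, past_key_values=cache, use_cache=True)
#         cache = out.past_key_values      # fed back in on the next chunk
#         feats = out.last_hidden_state
#
# Because the positional embeddings for the second call start at the cached
# length, the two calls approximate a single pass over the concatenated audio.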
class MultiModalProjector(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear1 = nn.Linear(in_features=in_dim, out_features=out_dim, bias=True)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(in_features=out_dim, out_features=out_dim, bias=True)

    def forward(self, audio_features):
        hidden_states = self.relu(self.linear1(audio_features))
        hidden_states = self.linear2(hidden_states)
        return hidden_states


class MiniCPMO(MiniCPMVBaseModel):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__(config=config, quant_config=quant_config)

        self.llm = self.init_llm(config=config, quant_config=quant_config)

        self.embed_dim = self.llm.config.hidden_size

        # init vision module
        if self.config.init_vision:
            # print("vision-understanding enabled")
            self.vpm = self.init_vision_module(config=config, quant_config=quant_config)
            self.vision_dim = self.vpm.embed_dim
            self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)

        # init audio module
        self.config.init_audio = True
        if self.config.init_audio:
            # print("audio-understanding enabled")
            self.apm = self.init_audio_module()
            audio_output_dim = int(self.apm.config.encoder_ffn_dim // 4)
            self.audio_avg_pooler = nn.AvgPool1d(
                self.config.audio_pool_step, stride=self.config.audio_pool_step
            )
            self.audio_projection_layer = MultiModalProjector(
                in_dim=audio_output_dim, out_dim=self.embed_dim
            )
            self.audio_encoder_layer = -1

        # init tts module
        self.config.init_tts = False
        logger.info("TTS is disabled for now")
        if self.config.init_tts:
            # print("tts enabled")
            assert (
                _tts_deps
            ), "please make sure vector_quantize_pytorch and vocos are installed."
            self.tts = self.init_tts_module()

    def init_tts_module(self):
        model = ConditionalChatTTS(self.config.tts_config)
        return model

    def init_audio_module(self):
        model = MiniCPMWhisperEncoder(self.config.audio_config)
        return model

    def init_llm(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> nn.Module:
        return Qwen2ForCausalLM(config=config, quant_config=quant_config, prefix=prefix)

    def init_vision_module(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ):
        if self.config._attn_implementation == "flash_attention_2":
            self.config.vision_config._attn_implementation = "flash_attention_2"
        else:
            self.config.vision_config._attn_implementation = "eager"
        model = Idefics2VisionTransformer(
            config=config.vision_config, quant_config=quant_config, prefix=prefix
        )
        if self.config.drop_vision_last_layer:
            model.encoder.layers = model.encoder.layers[:-1]

        setattr(model, "embed_dim", model.embeddings.embed_dim)
        setattr(model, "patch_size", model.embeddings.patch_size)

        return model

    def init_resampler(
        self,
        embed_dim: int,
        vision_dim: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> nn.Module:
        with set_default_torch_dtype(torch.float16):
            # The resampler in 2.6 remains consistent with the one in 2.5.
            resampler = Resampler2_5(
                num_queries=self.config.query_num,
                embed_dim=embed_dim,
                num_heads=embed_dim // 128,
                kv_dim=vision_dim,
                quant_config=quant_config,
                prefix=prefix,
            )

        return resampler.to(device="cuda", dtype=torch.get_default_dtype())

    def pad_input_ids(self, input_ids: List[int], mm_input: MultimodalInputs):
        # Get all special token IDs
        im_start_id: int = mm_input.im_start_id
        im_end_id: int = mm_input.im_end_id
        slice_start_id: int = mm_input.slice_start_id
        slice_end_id: int = mm_input.slice_end_id

        media_token_pairs = [
            (im_start_id, im_end_id),
            (slice_start_id, slice_end_id),
            (mm_input.audio_start_id, mm_input.audio_end_id),
        ]
        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)

        return pattern.pad_input_tokens(input_ids, mm_input)

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Computes the output length of the convolutional layers and the output length of the audio encoder
        """
        input_lengths_after_cnn = (input_lengths - 1) // 2 + 1
        input_lengths_after_pooling = (
            input_lengths_after_cnn - self.config.audio_pool_step
        ) // self.config.audio_pool_step + 1
        input_lengths_after_pooling = input_lengths_after_pooling.to(dtype=torch.int32)

        return input_lengths_after_cnn, input_lengths_after_pooling

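    # NOTE (editor's illustrative arithmetic, not part of the upstream diff): for a
    # hypothetical segment of 100 mel frames and audio_pool_step = 2 (the real value
    # comes from the model config), the formulas above give
    #     after_cnn     = (100 - 1) // 2 + 1 = 50   # stride-2 conv halves the length
    #     after_pooling = (50 - 2) // 2 + 1  = 25   # AvgPool1d with kernel = stride = audio_pool_step
    # so roughly 1 s of audio (100 mel frames at Whisper's 10 ms hop) would map to
    # 25 audio tokens in the LLM input under these assumptions.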
    def get_audio_embedding_streaming(self, multimodal_input: MultimodalInputs):
        r"""
        Extract audio embeddings in a streaming manner using cached key-value pairs.

        This method processes incoming audio features incrementally and stores/updates `past_key_values`
        for faster inference on subsequent audio frames. It only supports batch_size=1 and is intended
        for streaming scenarios.

        Args:
            multimodal_input (dict):
                - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
                - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.

        Returns:
            List[List[torch.Tensor]]: audio embeddings
        """
        # print("audio embedding")

        wavforms = (
            []
            if multimodal_input.audio_features is None
            else multimodal_input.audio_features
        )
        # list, [[x1, x2], [y1], [z1]]
        audio_feature_lens_raw = (
            []
            if multimodal_input.audio_feature_lens is None
            else multimodal_input.audio_feature_lens
        )

        # audio exists
        if len(wavforms) > 0:
            audio_feature_lens = torch.hstack(audio_feature_lens_raw)
            batch_size, _, max_mel_seq_len = wavforms.shape
            assert batch_size == 1
            max_seq_len = (max_mel_seq_len - 1) // 2 + 1

            if self.audio_past_key_values is not None:
                cache_length = self.audio_past_key_values[0][0].shape[2]
                apm_max_len = self.apm.embed_positions.weight.shape[0]
                if cache_length + max_seq_len >= apm_max_len:
                    logger.warning(
                        f"audio_past_key_values length {cache_length + max_seq_len} exceeds {apm_max_len}, reset."
                    )
                    self.audio_past_key_values = None

            audio_outputs = self.apm(
                wavforms, past_key_values=self.audio_past_key_values, use_cache=True
            )
            audio_states = (
                audio_outputs.last_hidden_state
            )  # [:, :audio_feat_lengths, :]
            self.audio_past_key_values = audio_outputs.past_key_values

            audio_embeds = self.audio_projection_layer(audio_states)

            audio_embeds = audio_embeds.transpose(1, 2)
            audio_embeds = self.audio_avg_pooler(audio_embeds)
            audio_embeds = audio_embeds.transpose(1, 2)

            _, feature_lens_after_pooling = self._get_feat_extract_output_lengths(
                audio_feature_lens
            )

            num_audio_tokens = feature_lens_after_pooling

            final_audio_embeds = []
            idx = 0
            for i in range(len(audio_feature_lens_raw)):
                target_audio_embeds = []
                for _ in range(len(audio_feature_lens_raw[i])):
                    target_audio_embeds.append(
                        audio_embeds[idx, : num_audio_tokens[idx], :]
                    )
                    idx += 1
                final_audio_embeds.append(target_audio_embeds)
            return final_audio_embeds
        else:
            return []

    def subsequent_chunk_mask(
        self,
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
        num_lookhead: int = 0,
    ) -> torch.Tensor:
        """Create a mask for subsequent steps (size, size) with a given chunk size;
        this is used for the streaming encoder.

        Args:
            size (int): size of mask
            chunk_size (int): size of chunk
            num_left_chunks (int): number of left chunks
                <0: use full chunk
                >=0: use num_left_chunks
            device (torch.device): "cpu" or "cuda" or torch.Tensor.device

        Returns:
            torch.Tensor: mask
        """
        ret = torch.zeros(size, size, device=device, dtype=torch.bool)
        for i in range(size):
            if num_left_chunks < 0:
                start = 0
            else:
                start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
            ending = min((i // chunk_size + 1) * chunk_size + num_lookhead, size)
            ret[i, start:ending] = True
        return ret

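    # NOTE (editor's illustrative example, not part of the upstream diff): for
    # size=4, chunk_size=2, num_left_chunks=-1, the mask produced above is
    #     [[True, True, False, False],
    #      [True, True, False, False],
    #      [True, True, True,  True ],
    #      [True, True, True,  True ]]
    # i.e. every frame attends to its own chunk and to all earlier chunks, which
    # is what lets chunk-based attention approximate streaming behaviour.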
    def get_audio_embedding(self, multimodal_input: MultimodalInputs, chunk_length=-1):
        r"""
        Extract full audio embeddings with optional chunk-based attention.

        This method computes embeddings for all audio frames at once, either using full attention (when
        `chunk_length` is -1) or chunk-based attention (when `chunk_length` is a positive number). It does
        not use key-value caching and is suitable for non-streaming inference.

        Args:
            multimodal_input (dict):
                - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
                - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.
            chunk_length (int, optional): Determines whether to use full attention (-1) or chunk-based
                attention (>0) during embedding computation.

        Returns:
            List[List[torch.Tensor]]: audio embeddings
        """
        # print("audio embedding")
        # (bs, 80, frames) or []; multiple audios need to be filled in advance
        wavforms = (
            []
            if multimodal_input.audio_features is None
            else multimodal_input.audio_features
        )
        # list, [[x1, x2], [y1], [z1]]
        audio_feature_lens_raw = (
            []
            if multimodal_input.audio_feature_lens is None
            else multimodal_input.audio_feature_lens
        )

        final_audio_embeds = []

        # audio exists
        for wavform in wavforms:
            if len(wavform) > 0:
                audio_feature_lens = torch.hstack(audio_feature_lens_raw)
                batch_size, _, max_mel_seq_len = wavform.shape
                max_seq_len = (max_mel_seq_len - 1) // 2 + 1

                # Create a sequence tensor of shape (batch_size, max_seq_len)
                seq_range = (
                    torch.arange(
                        0,
                        max_seq_len,
                        dtype=audio_feature_lens.dtype,
                        device=audio_feature_lens.device,
                    )
                    .unsqueeze(0)
                    .expand(batch_size, max_seq_len)
                )
                lengths_expand = audio_feature_lens.unsqueeze(1).expand(
                    batch_size, max_seq_len
                )
                # Create mask
                padding_mask = seq_range >= lengths_expand  # 1 for padded values

                audio_attention_mask_ = padding_mask.view(
                    batch_size, 1, 1, max_seq_len
                ).expand(batch_size, 1, max_seq_len, max_seq_len)
                audio_attention_mask = audio_attention_mask_.to(
                    dtype=self.apm.conv1.weight.dtype,
                    device=self.apm.conv1.weight.device,
                )

                if chunk_length > 0:
                    chunk_num_frame = int(chunk_length * 50)
                    chunk_mask = self.subsequent_chunk_mask(
                        size=max_seq_len,
                        chunk_size=chunk_num_frame,
                        num_left_chunks=-1,
                        device=audio_attention_mask_.device,
                    )
                    audio_attention_mask_ = torch.logical_or(
                        audio_attention_mask_, torch.logical_not(chunk_mask)
                    )

                audio_attention_mask[audio_attention_mask_] = float("-inf")
                audio_states = self.apm(
                    wavform,
                    output_hidden_states=True,
                    attention_mask=audio_attention_mask,
                ).hidden_states[self.audio_encoder_layer]
                audio_embeds = self.audio_projection_layer(audio_states)

                audio_embeds = audio_embeds.transpose(1, 2)
                audio_embeds = self.audio_avg_pooler(audio_embeds)
                audio_embeds = audio_embeds.transpose(1, 2)

                _, feature_lens_after_pooling = self._get_feat_extract_output_lengths(
                    audio_feature_lens
                )

                num_audio_tokens = feature_lens_after_pooling

                idx = 0
                for i in range(len(audio_feature_lens_raw)):
                    target_audio_embeds = []
                    for _ in range(len(audio_feature_lens_raw[i])):
                        target_audio_embeds.append(
                            audio_embeds[idx, : num_audio_tokens[idx], :]
                        )
                        idx += 1
                    final_audio_embeds.append(target_audio_embeds)
        return final_audio_embeds

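    # NOTE (editor's clarification, an interpretation rather than part of the
    # upstream diff): `chunk_length` appears to be expressed in seconds. The factor
    # of 50 in `chunk_num_frame = int(chunk_length * 50)` would then come from
    # Whisper's feature rate: 100 mel frames per second, halved by the stride-2
    # conv front end, leaves ~50 encoder frames per second, so a chunk_length of
    # 1.0 corresponds to a 50-frame attention chunk.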
    def get_omni_embedding(
        self,
        input_ids,
        multimodal_input: MultimodalInputs,
        input_embeds: torch.Tensor,
        forward_mode: ForwardMode,
        chunk_length=-1,
        stream_input=False,
    ):
        """
        Args:
            multimodal_input:
            input_embeds:
            chunk_length: whether Whisper uses full attention or chunk attention
            stream_input: use streaming audio embedding
        Returns:
            final embeddings with audio features
        """
        input_embeds = input_embeds.unsqueeze(0)
        if not forward_mode.is_decode() and multimodal_input.contains_audio_inputs():
            audio_bounds = get_multimodal_data_bounds(
                input_ids=input_ids,
                pad_values=multimodal_input.pad_values,
                token_pairs=[
                    (multimodal_input.audio_start_id, multimodal_input.audio_end_id)
                ],
            )
            if audio_bounds.numel() == 0:
                input_embeds = input_embeds.squeeze(0)
                # TODO
                logger.warning("Unimplemented logic. Please try disabling chunked prefill")
                return input_embeds
            audio_bounds = audio_bounds.unsqueeze(0)
            bs = len(input_embeds)

            if stream_input:
                audio_embeddings = self.get_audio_embedding_streaming(multimodal_input)
            else:
                audio_embeddings = self.get_audio_embedding(
                    multimodal_input, chunk_length
                )
            # batch size
            assert len(audio_embeddings) == len(input_embeds)
            if len(audio_embeddings) > 0:
                if self.config.chunk_input:
                    for i in range(bs):
                        audio_embs = torch.cat(audio_embeddings[i], dim=0).to(
                            device=input_embeds.device, dtype=input_embeds.dtype
                        )
                        audio_start_pos = 0
                        for bound in audio_bounds[i]:
                            audio_len = bound[1] - bound[0] + 1
                            input_embeds[0, bound[0] : bound[1] + 1] = audio_embs[
                                audio_start_pos : audio_start_pos + audio_len, :
                            ]
                            audio_start_pos += audio_len
                else:
                    for i in range(bs):
                        audio_embs = audio_embeddings[i]
                        bounds = audio_bounds[i]
                        for embs, bound in zip(audio_embs, bounds):
                            audio_indices = torch.arange(
                                bound[0], bound[1], dtype=torch.long
                            ).to(input_embeds.device)

                            if embs.shape[0] != len(audio_indices):
                                raise ValueError(
                                    f"Shape mismatch: Trying to assign embeddings of shape {embs.shape} "
                                    f"to input indices of length {len(audio_indices)}"
                                )
                            input_embeds[i, audio_indices] = embs.to(input_embeds.dtype)
        input_embeds = input_embeds.squeeze(0)
        return input_embeds

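    # NOTE (editor's illustrative sketch, not part of the upstream diff): the bounds
    # returned by get_multimodal_data_bounds are (start, end) token positions of the
    # audio placeholder spans in input_ids. For a single request whose placeholder
    # span covers positions 10..34 inclusive (25 audio tokens, a hypothetical
    # layout), the chunk_input branch effectively does
    #     input_embeds[0, 10:35] = audio_embs[0:25, :]
    # i.e. the pooled Whisper features overwrite the placeholder embeddings before
    # the sequence is handed to the language model.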
    def get_image_features(
        self,
        image_inputs: MultimodalInputs,
    ) -> torch.Tensor:
        pixel_values = image_inputs.pixel_values
        tgt_sizes = image_inputs.tgt_sizes
        device = self.vpm.embeddings.position_embedding.weight.device
        dtype = self.vpm.embeddings.position_embedding.weight.dtype
        all_pixel_values_lst = [
            i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
        ]

        max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
        assert isinstance(max_patches, int)

        all_pixel_values = torch.nn.utils.rnn.pad_sequence(
            all_pixel_values_lst, batch_first=True, padding_value=0.0
        )
        B, L, _ = all_pixel_values.shape
        all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
        patch_attn_mask = torch.zeros(
            (B, 1, max_patches), dtype=torch.bool, device=device
        )

        tgt_sizes_tensor = tgt_sizes.clone().to(device=patch_attn_mask.device)
        mask_shapes = tgt_sizes_tensor[:, 0] * tgt_sizes_tensor[:, 1]
        patch_attn_mask[:, 0, :] = torch.arange(
            patch_attn_mask.size(2), device=patch_attn_mask.device
        ).unsqueeze(0) < mask_shapes.unsqueeze(1)

        vision_embedding = self.vpm(
            all_pixel_values.type(dtype),
            patch_attention_mask=patch_attn_mask,
            tgt_sizes=tgt_sizes,
        )
        return self.resampler(vision_embedding, tgt_sizes)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        **kwargs: Any,
    ) -> torch.Tensor:
        inputs_embeds = None
        # TODO(mick): optimize the logic here: clamp, merge and embedding should happen at most once
        if (
            not forward_batch.forward_mode.is_decode()
            and forward_batch.contains_image_inputs()
        ):
            mm_inputs = forward_batch.merge_mm_inputs()
            inputs_embeds = embed_mm_inputs(
                mm_input=mm_inputs,
                input_ids=input_ids,
                input_embedding=self.get_input_embeddings(),
                mm_data_embedding_func=self.get_image_features,
                placeholder_token_ids=[mm_inputs.im_token_id] + mm_inputs.pad_values,
            )

        input_ids = input_ids.clamp(
            min=0, max=self.get_input_embeddings().num_embeddings - 1
        )
        if inputs_embeds is None:
            inputs_embeds = self.llm.get_input_embeddings(input_ids)
        if (
            not forward_batch.forward_mode.is_decode()
            and self.config.init_audio
            and forward_batch.contains_audio_inputs()
        ):
            mm_input = forward_batch.merge_mm_inputs()
            inputs_embeds = self.get_omni_embedding(
                input_ids=input_ids,
                multimodal_input=mm_input,
                input_embeds=inputs_embeds,
                forward_mode=forward_batch.forward_mode,
                chunk_length=self.config.audio_chunk_length,
                stream_input=False,
            )

        forward_batch.mm_inputs = None

        hidden_states = self.llm.model(
            input_ids=None,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=inputs_embeds,
        )

        return self.logits_processor(
            input_ids, hidden_states, self.llm.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:

            if "rotary_emb.inv_freq~" in name or "projector" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue

            # adapt to parametrization
            if self.config.init_tts and "tts" in name:
                name = name.replace(".parametrizations", "")
                name = name.replace(".weight.original0", ".weight_g")
                name = name.replace(".weight.original1", ".weight_v")

            # adapt to VisionAttention
            if "vpm" in name:
                name = name.replace(r"self_attn.out_proj", r"self_attn.proj")

            if not self.config.init_tts and "tts" in name:
                continue
            if not self.config.init_audio and ("apm" in name or "audio" in name):
                continue
            if not self.config.init_vision and "vpm" in name:
                continue

            if (
                "sampler" in name
                or "apm" in name
                or ("tts" in name and "self_attn" in name)
                or ("tts.model.layers" in name and ".mlp" in name)
            ):
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                # replace the name and load with the customized loader
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
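    # NOTE (editor's illustrative example, not part of the upstream diff): the
    # stacked_params_mapping above folds per-projection checkpoint weights into the
    # fused parameters used at runtime. For instance, a checkpoint entry such as
    #     "llm.model.layers.0.self_attn.q_proj.weight"
    # would be renamed to
    #     "llm.model.layers.0.self_attn.qkv_proj.weight"
    # and loaded into the "q" shard of the fused QKV projection, with k_proj and
    # v_proj landing in the same parameter under their own shard ids. (The exact
    # layer path is hypothetical; only the q_proj -> qkv_proj renaming is taken
    # from the mapping in the code.)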


EntryClass = [MiniCPMO]