sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +164 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +62 -23
  49. sglang/srt/layers/elementwise.py +411 -0
  50. sglang/srt/layers/layernorm.py +24 -2
  51. sglang/srt/layers/linear.py +17 -5
  52. sglang/srt/layers/logits_processor.py +26 -7
  53. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  54. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  55. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  56. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  63. sglang/srt/layers/moe/router.py +342 -0
  64. sglang/srt/layers/moe/topk.py +31 -18
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +184 -126
  67. sglang/srt/layers/quantization/base_config.py +5 -0
  68. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  69. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  70. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  75. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  76. sglang/srt/layers/quantization/fp8.py +76 -34
  77. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  78. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  79. sglang/srt/layers/quantization/gptq.py +36 -9
  80. sglang/srt/layers/quantization/kv_cache.py +98 -0
  81. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  82. sglang/srt/layers/quantization/utils.py +153 -0
  83. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  84. sglang/srt/layers/rotary_embedding.py +66 -87
  85. sglang/srt/layers/sampler.py +1 -1
  86. sglang/srt/lora/layers.py +68 -0
  87. sglang/srt/lora/lora.py +2 -22
  88. sglang/srt/lora/lora_manager.py +47 -23
  89. sglang/srt/lora/mem_pool.py +110 -51
  90. sglang/srt/lora/utils.py +12 -1
  91. sglang/srt/managers/cache_controller.py +4 -5
  92. sglang/srt/managers/data_parallel_controller.py +31 -9
  93. sglang/srt/managers/expert_distribution.py +81 -0
  94. sglang/srt/managers/io_struct.py +39 -3
  95. sglang/srt/managers/mm_utils.py +373 -0
  96. sglang/srt/managers/multimodal_processor.py +68 -0
  97. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  98. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  99. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  100. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  101. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  102. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  103. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  104. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  105. sglang/srt/managers/schedule_batch.py +134 -31
  106. sglang/srt/managers/scheduler.py +325 -38
  107. sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
  108. sglang/srt/managers/session_controller.py +1 -1
  109. sglang/srt/managers/tokenizer_manager.py +59 -23
  110. sglang/srt/managers/tp_worker.py +1 -1
  111. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  112. sglang/srt/managers/utils.py +6 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +27 -8
  114. sglang/srt/mem_cache/memory_pool.py +258 -98
  115. sglang/srt/mem_cache/paged_allocator.py +2 -2
  116. sglang/srt/mem_cache/radix_cache.py +4 -4
  117. sglang/srt/model_executor/cuda_graph_runner.py +85 -28
  118. sglang/srt/model_executor/forward_batch_info.py +81 -15
  119. sglang/srt/model_executor/model_runner.py +70 -6
  120. sglang/srt/model_loader/loader.py +160 -2
  121. sglang/srt/model_loader/weight_utils.py +45 -0
  122. sglang/srt/models/deepseek_janus_pro.py +29 -86
  123. sglang/srt/models/deepseek_nextn.py +22 -10
  124. sglang/srt/models/deepseek_v2.py +326 -192
  125. sglang/srt/models/deepseek_vl2.py +358 -0
  126. sglang/srt/models/gemma3_causal.py +684 -0
  127. sglang/srt/models/gemma3_mm.py +462 -0
  128. sglang/srt/models/grok.py +374 -119
  129. sglang/srt/models/llama.py +47 -7
  130. sglang/srt/models/llama_eagle.py +1 -0
  131. sglang/srt/models/llama_eagle3.py +196 -0
  132. sglang/srt/models/llava.py +3 -3
  133. sglang/srt/models/llavavid.py +3 -3
  134. sglang/srt/models/minicpmo.py +1995 -0
  135. sglang/srt/models/minicpmv.py +62 -137
  136. sglang/srt/models/mllama.py +4 -4
  137. sglang/srt/models/phi3_small.py +1 -1
  138. sglang/srt/models/qwen2.py +3 -0
  139. sglang/srt/models/qwen2_5_vl.py +68 -146
  140. sglang/srt/models/qwen2_classification.py +75 -0
  141. sglang/srt/models/qwen2_moe.py +9 -1
  142. sglang/srt/models/qwen2_vl.py +25 -63
  143. sglang/srt/openai_api/adapter.py +145 -47
  144. sglang/srt/openai_api/protocol.py +23 -2
  145. sglang/srt/sampling/sampling_batch_info.py +1 -1
  146. sglang/srt/sampling/sampling_params.py +6 -6
  147. sglang/srt/server_args.py +104 -14
  148. sglang/srt/speculative/build_eagle_tree.py +7 -347
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  150. sglang/srt/speculative/eagle_utils.py +208 -252
  151. sglang/srt/speculative/eagle_worker.py +139 -53
  152. sglang/srt/speculative/spec_info.py +6 -1
  153. sglang/srt/torch_memory_saver_adapter.py +22 -0
  154. sglang/srt/utils.py +182 -21
  155. sglang/test/__init__.py +0 -0
  156. sglang/test/attention/__init__.py +0 -0
  157. sglang/test/attention/test_flashattn_backend.py +312 -0
  158. sglang/test/runners.py +2 -0
  159. sglang/test/test_activation.py +2 -1
  160. sglang/test/test_block_fp8.py +5 -4
  161. sglang/test/test_block_fp8_ep.py +2 -1
  162. sglang/test/test_dynamic_grad_mode.py +58 -0
  163. sglang/test/test_layernorm.py +3 -2
  164. sglang/test/test_utils.py +55 -4
  165. sglang/utils.py +31 -0
  166. sglang/version.py +1 -1
  167. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  168. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
  169. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  170. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  171. sglang/srt/managers/image_processor.py +0 -55
  172. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  173. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  174. sglang/srt/managers/multi_modality_padding.py +0 -134
  175. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  176. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1995 @@
1
+ # Copied and adapted from: https://huggingface.co/openbmb/MiniCPM-o-2_6/blob/main/modeling_minicpmo.py
2
+
3
+ # Copyright 2023-2024 SGLang Team
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+ """Inference-only MiniCPM-o model compatible with HuggingFace weights."""
17
+
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import Any, Iterable, List, Literal, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+ import torch.nn.functional as F
25
+ import torch.nn.utils.parametrize as P
26
+ import torch.types
27
+ from torch import nn
28
+ from torch.nn.utils import weight_norm
29
+ from tqdm import tqdm
30
+ from transformers import LlamaConfig, LlamaModel, PretrainedConfig, PreTrainedModel
31
+ from transformers.activations import ACT2FN
32
+ from transformers.cache_utils import DynamicCache, EncoderDecoderCache
33
+ from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
34
+ from transformers.models.whisper.modeling_whisper import (
35
+ WHISPER_ATTENTION_CLASSES,
36
+ WhisperConfig,
37
+ WhisperEncoder,
38
+ )
39
+
40
+ from sglang.srt.layers.quantization import QuantizationConfig
41
+ from sglang.srt.managers.mm_utils import (
42
+ MultiModalityDataPaddingPatternTokenPairs,
43
+ embed_mm_inputs,
44
+ get_multimodal_data_bounds,
45
+ )
46
+ from sglang.srt.managers.schedule_batch import MultimodalInputs
47
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
48
+ from sglang.srt.model_loader.utils import set_default_torch_dtype
49
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
50
+ from sglang.srt.models.minicpmv import (
51
+ Idefics2VisionTransformer,
52
+ MiniCPMVBaseModel,
53
+ Resampler2_5,
54
+ )
55
+ from sglang.srt.models.qwen2 import Qwen2ForCausalLM
56
+ from sglang.srt.utils import logger
57
+
58
+ try:
59
+ from transformers import LogitsWarper
60
+ from vector_quantize_pytorch import GroupedResidualFSQ
61
+ from vocos import Vocos
62
+ from vocos.pretrained import instantiate_class
63
+
64
+ _tts_deps = True
65
+ except ImportError:
66
+ LogitsWarper = None
67
+ _tts_deps = False
68
+
69
+
70
+ def apply_spk_emb(
71
+ input_ids: torch.Tensor = None,
72
+ spk_emb: torch.Tensor = None,
73
+ input_embeds: torch.Tensor = None,
74
+ spk_emb_token_id: int = 0,
75
+ num_spk_embs: int = 1,
76
+ ):
77
+ """
78
+ Replace consecutive `num_spk_embs` speaker embedding placeholders in `input_embeds` with pre-prepared speaker embeddings. The replacement happens in place; no new tensor is created and nothing is returned.
79
+
80
+ Args:
81
+ input_ids (torch.Tensor): Input ID tensor, shape [batch_size, seq_len_max]
82
+ spk_emb (torch.Tensor): Speaker embedding tensor, shape [batch_size, num_spk_emb, hidden_dim]
83
+ input_embeds (torch.Tensor): Input embedding tensor, shape [batch_size, seq_len_max, hidden_dim]
84
+ spk_emb_token_id (int): ID of the speaker embedding token
85
+ num_spk_embs (int): Number of speaker embeddings
86
+
87
+ Returns:
88
+ None
89
+ """
90
+
91
+ batch_size = input_ids.shape[0]
92
+
93
+ for idx in range(batch_size):
94
+ input_ids_ = input_ids[idx] # [seq_len_max]
95
+ spk_emb_ = spk_emb[idx] # [num_spk_emb, hidden_dim]
96
+ mask_ = input_ids_ == spk_emb_token_id # [seq_len_max]
97
+ nonzero_position_idx = mask_.nonzero(as_tuple=False) # [num_spk_emb, 1]
98
+ assert nonzero_position_idx.shape[0] == num_spk_embs
99
+ begin_idx = nonzero_position_idx.min()
100
+ end_idx = nonzero_position_idx.max()
101
+ input_embeds[idx, begin_idx : end_idx + 1, :] = spk_emb_
102
+
103
+ return
104
+
105
+
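The snippet below is an editorial illustration, not part of the diff: a toy call to `apply_spk_emb` with made-up shapes and a made-up placeholder token id, assuming sglang (with this new module) is installed.

```python
# Hypothetical toy example (values invented for illustration).
import torch

from sglang.srt.models.minicpmo import apply_spk_emb

hidden = 8
spk_emb_token_id = 101                            # assumed placeholder id
input_ids = torch.tensor([[1, 101, 101, 2, 3]])   # [batch=1, seq_len=5]
input_embeds = torch.zeros(1, 5, hidden)          # [batch, seq_len, hidden]
spk_emb = torch.randn(1, 2, hidden)               # [batch, num_spk_emb=2, hidden]

# In-place: the two placeholder positions (1 and 2) are overwritten with spk_emb.
apply_spk_emb(
    input_ids=input_ids,
    spk_emb=spk_emb,
    input_embeds=input_embeds,
    spk_emb_token_id=spk_emb_token_id,
    num_spk_embs=2,
)
assert torch.equal(input_embeds[0, 1:3], spk_emb[0])
```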
106
+ @dataclass
107
+ class ConditionalChatTTSGenerationOutput(ModelOutput):
108
+ """
109
+ Output class for ConditionalChatTTS generation.
110
+
111
+ Args:
112
+ new_ids (torch.LongTensor): Newly generated audio code sequence, shape (batch_size, sequence_length, num_vq).
113
+ audio_input_ids (torch.LongTensor): Updated input IDs including condition and generated audio codes, shape (batch_size, full_sequence_length, num_vq).
114
+ past_key_values (Tuple[Tuple[torch.FloatTensor]]): Tuple containing pre-computed keys and values used for attention mechanism. Each element has shape (batch_size, num_heads, sequence_length, embed_size_per_head).
115
+ finished (bool): Boolean indicating whether generation is complete.
116
+
117
+ """
118
+
119
+ new_ids: torch.LongTensor = None
120
+ audio_input_ids: torch.LongTensor = None
121
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
122
+ finished: bool = None
123
+
124
+
125
+ def make_streaming_chunk_mask_generation(
126
+ inputs_embeds: torch.Tensor,
127
+ past_seen_tokens: int,
128
+ streaming_tts_text_mask: torch.Tensor,
129
+ streaming_reserved_length: int = 300,
130
+ streaming_audio_chunk_size: int = 50,
131
+ streaming_text_chunk_size: int = 10,
132
+ num_spk_emb: int = 1,
133
+ use_spk_emb: bool = True,
134
+ ) -> torch.Tensor:
135
+ """
136
+ In streaming audio generation, determine which `text` positions the TTS model can attend to when generating each chunk of `audio` tokens.
137
+
138
+ This function creates a mask that allows the model to attend to a specific chunk of text
139
+ tokens when generating each chunk of audio tokens, enabling streaming TTS generation.
140
+
141
+ Args:
142
+ inputs_embeds (torch.Tensor): Input embeddings tensor.
143
+ past_seen_tokens (int): Number of tokens already seen by the model.
144
+ streaming_tts_text_mask (torch.Tensor): Mask for the text tokens.
145
+ streaming_reserved_length (int, optional): Number of reserved tokens for streaming. Defaults to 300.
146
+ streaming_text_chunk_size (int, optional): Size of each text chunk. Defaults to 10.
147
+
148
+ Returns:
149
+ torch.Tensor: Causal mask for streaming TTS generation, shape is [batch_size=1, 1, seq_len=1, past_seen_tokens+1]
150
+
151
+ Raises:
152
+ AssertionError: If the batch size is not 1 (only supports batch size of 1 for inference).
153
+ """
154
+ assert inputs_embeds.shape[0] == 1
155
+
156
+ dtype = inputs_embeds.dtype
157
+ device = inputs_embeds.device
158
+ min_dtype = torch.finfo(dtype).min
159
+
160
+ # Add `1` to the past seen tokens to account for new `tokens` during `generate`
161
+ causal_mask = torch.full(
162
+ (1, past_seen_tokens + inputs_embeds.shape[1]),
163
+ fill_value=0,
164
+ dtype=dtype,
165
+ device=device,
166
+ )
167
+
168
+ # Calculate the start of invisible text tokens
169
+ invisible_text_tokens_start = (
170
+ min(
171
+ math.ceil(
172
+ (past_seen_tokens - streaming_reserved_length)
173
+ / streaming_audio_chunk_size
174
+ )
175
+ * streaming_text_chunk_size,
176
+ streaming_reserved_length,
177
+ )
178
+ + 1
179
+ + num_spk_emb * use_spk_emb
180
+ ) # Add 1 for [Stts] and N for [spk_emb] tokens if `use_spk_emb` is True
181
+
182
+ invisible_text_tokens_end = (
183
+ streaming_reserved_length + 1 + num_spk_emb * use_spk_emb + 1
184
+ ) # Add 1 for [Ptts] (aka `audio_bos_token_id`)
185
+
186
+ # Set invisible text tokens to min_dtype (effectively -inf)
187
+ causal_mask[0, invisible_text_tokens_start:invisible_text_tokens_end] = min_dtype
188
+
189
+ # Mask padding positions in the text mask
190
+ causal_mask[
191
+ 0, 0 : 1 + num_spk_emb * use_spk_emb + streaming_reserved_length + 1
192
+ ].masked_fill_(streaming_tts_text_mask == 0, min_dtype)
193
+
194
+ # Add extra dimensions for batch and heads
195
+ causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
196
+
197
+ return causal_mask
198
+
199
+
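The snippet below is an editorial illustration, not part of the diff: it reproduces the chunk-visibility arithmetic of `make_streaming_chunk_mask_generation` with assumed values (300 reserved text positions, 50-token audio chunks, 10-token text chunks, one speaker embedding).

```python
import math

# Assumed configuration, matching the function defaults above.
streaming_reserved_length = 300
streaming_audio_chunk_size = 50
streaming_text_chunk_size = 10
num_spk_emb, use_spk_emb = 1, True

# Suppose the prompt region is 1 [Stts] + 1 [spk_emb] + 300 reserved text + 1 [Ptts]
# = 303 tokens, and 72 audio tokens have been generated so far.
past_seen_tokens = 303 + 72

invisible_text_tokens_start = (
    min(
        math.ceil(
            (past_seen_tokens - streaming_reserved_length) / streaming_audio_chunk_size
        )
        * streaming_text_chunk_size,
        streaming_reserved_length,
    )
    + 1
    + num_spk_emb * use_spk_emb
)  # = min(2 * 10, 300) + 1 + 1 = 22
invisible_text_tokens_end = (
    streaming_reserved_length + 1 + num_spk_emb * use_spk_emb + 1
)  # = 303

# Positions in [22, 303) are masked out for the new audio token;
# earlier text (and the prompt prefix) stays visible.
print(invisible_text_tokens_start, invisible_text_tokens_end)
```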
200
+ # Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py`
201
+ class ConvNeXtBlock(nn.Module):
202
+ def __init__(
203
+ self,
204
+ dim: int,
205
+ intermediate_dim: int,
206
+ kernel: int,
207
+ dilation: int,
208
+ layer_scale_init_value: float = 1e-6,
209
+ ):
210
+ # ConvNeXt Block copied from Vocos.
211
+ super().__init__()
212
+ self.dwconv = nn.Conv1d(
213
+ dim,
214
+ dim,
215
+ kernel_size=kernel,
216
+ padding=dilation * (kernel // 2),
217
+ dilation=dilation,
218
+ groups=dim,
219
+ )
220
+
221
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
222
+ self.pwconv1 = nn.Linear(dim, intermediate_dim)
223
+ self.act = nn.GELU()
224
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
225
+ self.coef = (
226
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
227
+ if layer_scale_init_value > 0
228
+ else None
229
+ )
230
+
231
+ def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor:
232
+ residual = x
233
+
234
+ y = self.dwconv(x)
235
+ y.transpose_(1, 2) # (B, C, T) -> (B, T, C)
236
+ x = self.norm(y)
237
+ del y
238
+ y = self.pwconv1(x)
239
+ del x
240
+ x = self.act(y)
241
+ del y
242
+ y = self.pwconv2(x)
243
+ del x
244
+ if self.coef is not None:
245
+ y *= self.coef
246
+ y.transpose_(1, 2) # (B, T, C) -> (B, C, T)
247
+
248
+ x = y + residual
249
+ del y
250
+
251
+ return x
252
+
253
+
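The snippet below is an editorial illustration, not part of the diff: a quick shape check showing that `ConvNeXtBlock` maps a `(B, C, T)` tensor back to the same shape, since the depthwise convolution is padded and the pointwise layers act per time step.

```python
# Hypothetical usage; assumes sglang (with this new module) is installed.
import torch

from sglang.srt.models.minicpmo import ConvNeXtBlock

block = ConvNeXtBlock(dim=64, intermediate_dim=256, kernel=7, dilation=1)
x = torch.randn(2, 64, 100)        # (B, C, T)
with torch.no_grad():
    y = block(x)
assert y.shape == (2, 64, 100)     # shape-preserving residual block
```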
254
+ # Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py`
255
+ class DVAEDecoder(nn.Module):
256
+ def __init__(
257
+ self,
258
+ idim: int,
259
+ odim: int,
260
+ n_layer=12,
261
+ bn_dim=64,
262
+ hidden=256,
263
+ kernel=7,
264
+ dilation=2,
265
+ up=False,
266
+ ):
267
+ super().__init__()
268
+ self.up = up
269
+ self.conv_in = nn.Sequential(
270
+ nn.Conv1d(idim, bn_dim, 3, 1, 1),
271
+ nn.GELU(),
272
+ nn.Conv1d(bn_dim, hidden, 3, 1, 1),
273
+ )
274
+ self.decoder_block = nn.ModuleList(
275
+ [
276
+ ConvNeXtBlock(
277
+ hidden,
278
+ hidden * 4,
279
+ kernel,
280
+ dilation,
281
+ )
282
+ for _ in range(n_layer)
283
+ ]
284
+ )
285
+ self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)
286
+
287
+ def forward(self, x: torch.Tensor, conditioning=None) -> torch.Tensor:
288
+ # B, C, T
289
+ y = self.conv_in(x)
290
+ del x
291
+ for f in self.decoder_block:
292
+ y = f(y, conditioning)
293
+
294
+ x = self.conv_out(y)
295
+ del y
296
+ return x
297
+
298
+
299
+ # Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py`
300
+ class GFSQ(nn.Module):
301
+ def __init__(
302
+ self,
303
+ dim: int,
304
+ levels: List[int],
305
+ G: int,
306
+ R: int,
307
+ eps=1e-5,
308
+ transpose=True,
309
+ ):
310
+ super(GFSQ, self).__init__()
311
+ self.quantizer = GroupedResidualFSQ(
312
+ dim=dim,
313
+ levels=list(levels),
314
+ num_quantizers=R,
315
+ groups=G,
316
+ )
317
+ self.n_ind = math.prod(levels)
318
+ self.eps = eps
319
+ self.transpose = transpose
320
+ self.G = G
321
+ self.R = R
322
+
323
+ def _embed(self, x: torch.Tensor):
324
+ if self.transpose:
325
+ x = x.transpose(1, 2)
326
+ x = x.view(x.size(0), x.size(1), self.G, self.R).permute(2, 0, 1, 3)
327
+ feat = self.quantizer.get_output_from_indices(x)
328
+ return feat.transpose_(1, 2) if self.transpose else feat
329
+
330
+ def __call__(self, x: torch.Tensor) -> torch.Tensor:
331
+ return super().__call__(x)
332
+
333
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
334
+ if self.transpose:
335
+ x.transpose_(1, 2)
336
+ _, ind = self.quantizer(x)
337
+ ind = ind.permute(1, 2, 0, 3).contiguous()
338
+ ind = ind.view(ind.size(0), ind.size(1), -1)
339
+ return ind.transpose_(1, 2) if self.transpose else ind
340
+
341
+
342
+ # Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py`
343
+ class DVAE(nn.Module):
344
+ def __init__(
345
+ self,
346
+ ):
347
+ super().__init__()
348
+
349
+ coef = torch.rand(100)
350
+ self.coef = nn.Parameter(coef.unsqueeze(0).unsqueeze_(2))
351
+
352
+ self.downsample_conv = nn.Sequential(
353
+ nn.Conv1d(100, 512, 3, 1, 1),
354
+ nn.GELU(),
355
+ nn.Conv1d(512, 512, 4, 2, 1),
356
+ nn.GELU(),
357
+ )
358
+
359
+ self.encoder = DVAEDecoder(
360
+ idim=512,
361
+ odim=1024,
362
+ hidden=256,
363
+ n_layer=12,
364
+ bn_dim=128,
365
+ )
366
+
367
+ self.decoder = DVAEDecoder(
368
+ idim=512,
369
+ odim=512,
370
+ hidden=256,
371
+ n_layer=12,
372
+ bn_dim=128,
373
+ )
374
+
375
+ self.out_conv = nn.Conv1d(512, 100, 3, 1, 1, bias=False)
376
+
377
+ self.vq_layer = GFSQ(
378
+ dim=1024,
379
+ levels=(5, 5, 5, 5),
380
+ G=2,
381
+ R=2,
382
+ )
383
+
384
+ @torch.inference_mode()
385
+ def forward(
386
+ self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode"
387
+ ) -> torch.Tensor:
388
+ if mode == "encode" and hasattr(self, "encoder") and self.vq_layer is not None:
389
+ mel = inp.clone()
390
+ x: torch.Tensor = self.downsample_conv(
391
+ torch.div(mel, self.coef.view(100, 1).expand(mel.shape), out=mel),
392
+ ).unsqueeze_(0)
393
+ del mel
394
+ x = self.encoder(x)
395
+ ind = self.vq_layer(x)
396
+ del x
397
+ return ind
398
+
399
+ if self.vq_layer is not None:
400
+ vq_feats = self.vq_layer._embed(inp)
401
+ else:
402
+ vq_feats = inp
403
+
404
+ vq_feats = (
405
+ vq_feats.view(
406
+ (vq_feats.size(0), 2, vq_feats.size(1) // 2, vq_feats.size(2)),
407
+ )
408
+ .permute(0, 2, 3, 1)
409
+ .flatten(2)
410
+ )
411
+
412
+ dec_out = self.out_conv(
413
+ self.decoder(
414
+ x=vq_feats,
415
+ ),
416
+ )
417
+
418
+ del vq_feats
419
+
420
+ return torch.mul(dec_out, self.coef, out=dec_out)
421
+
422
+
423
+ # Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/processors.py`
424
+ class CustomRepetitionPenaltyLogitsProcessorRepeat:
425
+ def __init__(self, penalty: float, max_input_ids: int, past_window: int):
426
+ if not isinstance(penalty, float) or not (penalty > 0):
427
+ raise ValueError(
428
+ f"`penalty` has to be a strictly positive float, but is {penalty}"
429
+ )
430
+
431
+ self.penalty = penalty
432
+ self.max_input_ids = max_input_ids
433
+ self.past_window = past_window
434
+
435
+ def __call__(
436
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor
437
+ ) -> torch.FloatTensor:
438
+ if input_ids.size(1) > self.past_window:
439
+ input_ids = input_ids.narrow(1, -self.past_window, self.past_window)
440
+ freq = F.one_hot(input_ids, scores.size(1)).sum(1)
441
+ if freq.size(0) > self.max_input_ids:
442
+ freq.narrow(
443
+ 0, self.max_input_ids, freq.size(0) - self.max_input_ids
444
+ ).zero_()
445
+ alpha = torch.pow(self.penalty, freq)
446
+ scores = scores.contiguous()
447
+ inp = scores.multiply(alpha)
448
+ oth = scores.divide(alpha)
449
+ con = scores < 0
450
+ out = torch.where(con, inp, oth)
451
+ del inp, oth, scores, con, alpha
452
+ return out
453
+
454
+
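The snippet below is an editorial illustration, not part of the diff (values invented): it shows how the processor above reweights logits, dividing positive logits of recently generated tokens by `penalty ** frequency` and multiplying negative ones by it.

```python
# Hypothetical usage; assumes sglang (with this new module) is installed.
import torch

from sglang.srt.models.minicpmo import CustomRepetitionPenaltyLogitsProcessorRepeat

vocab_size = 6
processor = CustomRepetitionPenaltyLogitsProcessorRepeat(
    penalty=1.5, max_input_ids=vocab_size, past_window=16
)

generated = torch.tensor([[2, 2, 3]])   # token 2 appeared twice, token 3 once
scores = torch.ones(1, vocab_size)      # uniform positive logits
penalized = processor(generated, scores)
# token 2 -> 1 / 1.5**2, token 3 -> 1 / 1.5, unseen tokens unchanged
print(penalized)
```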
455
+ class ConditionalChatTTS(PreTrainedModel):
456
+ """A conditional text-to-speech model that can generate speech from text with speaker conditioning.
457
+
458
+ This model extends PreTrainedModel to provide text-to-speech capabilities with:
459
+ - LLM hidden state conditioning
460
+ - Streaming generation
461
+
462
+ The model uses a transformer architecture with LLM hidden states and can operate in both
463
+ streaming and non-streaming modes for flexible deployment.
464
+
465
+ The model processes sequences in the following format:
466
+ | text bos token | LLM embedding projected to tts embedding space | text tokens (fixed length, reserved for future tokens) | audio bos token | audio tokens (audio token length is not fixed)| audio eos token |
467
+
468
+ The format is designed to support LLM-conditioned streaming audio generation.
469
+
470
+ Usage:
471
+ To support streaming generation, two global variables should be maintained outside of the model.
472
+ 1. `audio_input_ids`: stores *discrete* audio codes. It is a tensor with shape [1, sequence length+1, num_vq].
473
+ 2. `past_key_values`: stores the KV cache for both text tokens and audio codes. It is a list of tuples, each tuple contains two tensors with shape [1, num_attention_heads, sequence length, hidden_size // num_attention_heads]
474
+
475
+ where `num_vq` is the number of audio codebooks; in the default setting it is `4`.
476
+
477
+ 1. Create an empty `past_key_values` with
478
+ ```python
479
+ initial_kv_cache_length = 1 + model.num_spk_embs + model.streaming_text_reserved_len # where `1` denotes the `bos` token
480
+ dtype = model.emb_text.weight.dtype
481
+ device = model.emb_text.weight.device
482
+ past_key_values = [
483
+ (
484
+ torch.zeros(1, model.config.num_attention_heads, initial_kv_cache_length, model.config.hidden_size // model.config.num_attention_heads, dtype=dtype, device=device),
485
+ torch.zeros(1, model.config.num_attention_heads, initial_kv_cache_length, model.config.hidden_size // model.config.num_attention_heads, dtype=dtype, device=device)
486
+ )
487
+ for _ in range(model.config.num_hidden_layers)
488
+ ]
489
+
490
+ 2. At the same time, create an empty `audio_input_ids` with shape [1, sequence length, num_vq], where `num_vq` denotes the number of audio codebook layers. Text token positions are also included in this sequence, but they are left as zeros and are never used; they are just placeholders.
491
+
492
+ ```python
493
+ initial_audio_input_ids_length = 1 + model.num_spk_embs + model.streaming_text_reserved_len + 1
494
+ # [bos token, speaker embeddings, text tokens, audio bos token]
495
+ audio_input_ids = torch.zeros(1, initial_audio_input_ids_length, model.num_vq)  # batch_size = 1
496
+ ```
497
+
498
+ 3. Prefill some text tokens to the TTS model (for example, 10 tokens) using the `prefill_text` method.
499
+
500
+ ```python
501
+ outputs = llm.generate(**kwargs)
502
+ llm_tokens = some_function_to_extract_llm_tokens(outputs)
503
+ lm_spk_emb_last_hidden_states = some_function_to_extract_lm_spk_emb_last_hidden_states(outputs)
504
+ tts_text_input_ids = tts_tokenizer.encode(llm_tokenizer.decode(llm_tokens))
505
+ # here assume we are prefilling text token 0 to text token 9 (included), totally 10 tokens.
506
+ begin = 0
507
+ end = 9+1
508
+ position_ids = torch.arange(begin, end, dtype=torch.long, device=device)
509
+
510
+ past_key_values = model.prefill_text(
511
+ input_ids=tts_text_input_ids,
512
+ position_ids=position_ids,
513
+ past_key_values=past_key_values,
514
+ lm_spk_emb_last_hidden_states=lm_spk_emb_last_hidden_states,
515
+ )
516
+ ```
517
+
518
+ 4. Make a `streaming_tts_text_mask` to denote which positions contain valid text tokens, similar to `attention_mask` in standard causal attention.
519
+
520
+ ```python
521
+ streaming_tts_text_mask = torch.zeros(model.streaming_text_reserved_len)
522
+ streaming_tts_text_mask[0:end] = 1 # denotes these positions hold valid text tokens
523
+ ```
524
+
525
+ 5. Generate audio codes using the `generate` method.
526
+
527
+ ```python
528
+ outputs = model.generate(
529
+ input_ids=audio_input_ids,
530
+ past_key_values=past_key_values,
531
+ streaming_tts_text_mask=streaming_tts_text_mask,
532
+ max_new_token=50,
533
+ )
534
+
535
+ # update past_key_values and input_ids
536
+ past_key_values = outputs.past_key_values
537
+ audio_input_ids = outputs.audio_input_ids
538
+ ```
539
+
540
+ After calling `generate`, both `past_key_values` and `audio_input_ids` are extended by `max_new_token=50`.
541
+
542
+ 6. Note that after prefilling `10` text tokens, the model can generate up to `50` audio tokens; if you want to generate more audio tokens, you need to prefill the next `10` text tokens. It is also okay to generate only `25` audio tokens for a faster initial response.
543
+
544
+ 7. Repeat steps `3`, `4`, and `5` as needed for your streaming audio generation use case, making sure usage complies with the guidelines discussed above.
545
+ """
546
+
547
+ config_class = PretrainedConfig
548
+ _no_split_modules = []
549
+
550
+ def __init__(self, config: PretrainedConfig):
551
+ super().__init__(config)
552
+
553
+ self.use_speaker_embedding = config.use_speaker_embedding
554
+ self.use_llm_hidden_state = config.use_llm_hidden_state
555
+ self.num_spk_embs = config.num_spk_embs
556
+ self.spk_emb_token_id = config.spk_emb_token_id
557
+
558
+ self.use_text = config.use_text
559
+ self.streaming = config.streaming
560
+ self.streaming_text_chunk_size = config.streaming_text_chunk_size
561
+ self.streaming_audio_chunk_size = config.streaming_audio_chunk_size
562
+ self.streaming_text_reserved_len = config.streaming_text_reserved_len
563
+ self.audio_bos_token_id = config.audio_bos_token_id
564
+ self.num_mel_bins = config.num_mel_bins
565
+ self.num_vq = config.num_vq
566
+ self.num_audio_tokens = config.num_audio_tokens
567
+
568
+ self.top_p = config.top_p
569
+ self.top_k = config.top_k
570
+ self.repetition_penalty = config.repetition_penalty
571
+
572
+ if self.config.use_mlp:
573
+ self.projector = MultiModalProjector(config.llm_dim, config.hidden_size)
574
+ else:
575
+ self.projector = nn.Linear(config.llm_dim, config.hidden_size, bias=False)
576
+ self.emb_code = nn.ModuleList(
577
+ [
578
+ nn.Embedding(config.num_audio_tokens, config.hidden_size)
579
+ for _ in range(config.num_vq)
580
+ ]
581
+ )
582
+ self.emb_text = nn.Embedding(config.num_text_tokens, config.hidden_size)
583
+ self.head_code = nn.ModuleList(
584
+ [
585
+ weight_norm(
586
+ nn.Linear(config.hidden_size, config.num_audio_tokens, bias=False),
587
+ name="weight",
588
+ )
589
+ for _ in range(config.num_vq)
590
+ ]
591
+ )
592
+
593
+ dvae = DVAE()
594
+ self.dvae = dvae
595
+
596
+ model_config = LlamaConfig(
597
+ hidden_size=config.hidden_size,
598
+ intermediate_size=config.intermediate_size,
599
+ num_attention_heads=config.num_attention_heads,
600
+ num_hidden_layers=config.num_hidden_layers,
601
+ max_position_embeddings=config.max_position_embeddings,
602
+ attn_implementation=config.attn_implementation,
603
+ )
604
+
605
+ model = LlamaModel(model_config)
606
+ self.model = model
607
+
608
+ @torch.inference_mode()
609
+ def merge_inputs_embeds(
610
+ self,
611
+ input_ids: torch.Tensor,
612
+ lm_spk_emb_last_hidden_states: Optional[torch.Tensor] = None,
613
+ ):
614
+ """Merge `input_ids` and `lm_spk_emb_last_hidden_states` to `inputs_embeds`.
615
+
616
+ Args:
617
+ input_ids (torch.Tensor): Input token IDs.
618
+ lm_spk_emb_last_hidden_states (Optional[torch.Tensor], optional): Last hidden states of speaker embeddings from the language model. Defaults to None.
619
+
620
+ Raises:
621
+ NotImplementedError: If speaker embedding is not used and language model hidden states are not implemented.
622
+
623
+ Returns:
624
+ torch.Tensor: Prepared input embeddings for the model.
625
+ """
626
+ assert input_ids.shape[0] == 1
627
+
628
+ # Embed input_ids to input_embeds
629
+ inputs_embeds = self.emb_text(input_ids)
630
+
631
+ # Inject speaker embedding to input_embeds if it exists
632
+ if self.use_speaker_embedding:
633
+ spk_emb_mask = input_ids == self.spk_emb_token_id
634
+ if spk_emb_mask.any():
635
+ assert lm_spk_emb_last_hidden_states is not None
636
+ # Project spk emb to tts hidden size first, [batch_size, num_spk_emb, llm_dim] -> [batch_size, num_spk_emb, self.hidden_size]
637
+ lm_spk_emb_last_hidden_states = lm_spk_emb_last_hidden_states.to(
638
+ self.projector.linear1.weight.dtype
639
+ )
640
+ projected_spk_emb = self.projector(lm_spk_emb_last_hidden_states)
641
+ projected_spk_emb = F.normalize(projected_spk_emb, p=2, dim=-1)
642
+ apply_spk_emb(
643
+ input_ids=input_ids,
644
+ spk_emb=projected_spk_emb,
645
+ input_embeds=inputs_embeds,
646
+ spk_emb_token_id=self.spk_emb_token_id,
647
+ num_spk_embs=self.num_spk_embs,
648
+ )
649
+ else:
650
+ raise NotImplementedError
651
+
652
+ return inputs_embeds
653
+
654
+ @torch.inference_mode()
655
+ def prefill_text(
656
+ self,
657
+ input_ids: torch.Tensor,
658
+ position_ids: torch.LongTensor,
659
+ past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
660
+ lm_spk_emb_last_hidden_states: Optional[torch.Tensor] = None,
661
+ ):
662
+ """Prefill a chunk of new text tokens in streaming setting.
663
+ Specifically, update `past_key_values` with the new text tokens so that the model can attend to them in later steps.
664
+
665
+ Args:
666
+ input_ids (Tensor): Tensor of shape [batch_size, seq_len]
667
+ position_ids (LongTensor): Tensor of shape [batch_size, seq_len]
668
+ past_key_values (List[Tuple[Tensor]]): KV Cache of all layers, each layer is a tuple (Tensor, Tensor) denoting keys and values. Each tensor is of seq_len = `self.streaming_text_reserved_len`. `past_key_values` will be updated.
669
+ lm_spk_emb_last_hidden_states (Tensor, optional): Tensor of shape [batch_size, num_spk_emb, llm_dim]. Defaults to None.
670
+
671
+ Note that all `batch_size` should be `1`.
672
+ """
673
+ assert input_ids.shape[0] == 1
674
+ assert past_key_values is not None
675
+
676
+ # Merge text and LLM embeddings
677
+ inputs_embeds = self.merge_inputs_embeds(
678
+ input_ids=input_ids,
679
+ lm_spk_emb_last_hidden_states=lm_spk_emb_last_hidden_states,
680
+ )
681
+
682
+ # Clone KV Cache
683
+ past_key_values_for_prefill = []
684
+ for i in range(len(past_key_values)):
685
+ past_key_values_for_prefill.append(
686
+ (
687
+ past_key_values[i][0][:, :, : position_ids[:, 0], :].clone(),
688
+ past_key_values[i][1][:, :, : position_ids[:, 0], :].clone(),
689
+ )
690
+ )
691
+
692
+ # ModelMiniCPMVBaseModel
693
+ outputs_prefill: BaseModelOutputWithPast = self.model(
694
+ attention_mask=None, # because for text, it is standard causal attention mask, do nothing
695
+ position_ids=position_ids, # position_ids denotes the position of new text tokens in the sequence
696
+ past_key_values=past_key_values_for_prefill, # `past_key_values` will be updated by the model
697
+ inputs_embeds=inputs_embeds, # contains text and language model embedding
698
+ use_cache=True,
699
+ output_attentions=False,
700
+ cache_position=position_ids, # which new positions will use this cache, basically the same as position_ids
701
+ )
702
+
703
+ # Get model updated KV Cache
704
+ past_key_values_for_prefill_updated = outputs_prefill.past_key_values
705
+
706
+ # Update generated KV Cache to input `past_key_values`
707
+ for layer_idx in range(len(past_key_values)):
708
+ # Update keys
709
+ past_key_values[layer_idx][0][
710
+ :, :, position_ids[:, 0] : position_ids[:, -1] + 1, :
711
+ ] = past_key_values_for_prefill_updated[layer_idx][0][
712
+ :, :, position_ids[:, 0] : position_ids[:, -1] + 1
713
+ ].clone()
714
+ # Update values
715
+ past_key_values[layer_idx][1][
716
+ :, :, position_ids[:, 0] : position_ids[:, -1] + 1, :
717
+ ] = past_key_values_for_prefill_updated[layer_idx][1][
718
+ :, :, position_ids[:, 0] : position_ids[:, -1] + 1
719
+ ].clone()
720
+
721
+ # TODO: del past_key_values_for_prefill_updated recursively
722
+ # TODO: del outputs_prefill recursively
723
+
724
+ return past_key_values
725
+
726
+ @torch.inference_mode()
727
+ def prefill_audio_ids(
728
+ self,
729
+ input_ids: torch.Tensor,
730
+ past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
731
+ streaming_tts_text_mask=None,
732
+ add_audio_bos: bool = True,
733
+ ):
734
+ """Prefill a chunk of audio ids to the model. Used in sliding-window long audio generation.
735
+ Specifically, prefill many audio ids (typically from last window) to the model in the new window.
736
+
737
+ Args:
738
+ input_ids (torch.Tensor): (1, seq_len, num_vq) Audio input token ids.
739
+ past_key_values (List[Tuple[torch.Tensor, torch.Tensor]]): Past key values for attention mechanism.
740
+ """
741
+ assert input_ids.shape[0] == 1
742
+ assert past_key_values is not None
743
+
744
+ code_emb = [self.emb_code[i](input_ids[:, :, i]) for i in range(self.num_vq)]
745
+ inputs_embeds = torch.stack(code_emb, 3).sum(3) # [1,seq_len,768]
746
+ input_len = input_ids.shape[1]
747
+
748
+ if add_audio_bos:
749
+ narrowed_input_ids = torch.tensor(
750
+ [[self.audio_bos_token_id]], dtype=torch.long, device=self.device
751
+ )
752
+ bos_inputs_embeds = self.emb_text(narrowed_input_ids)
753
+ inputs_embeds = torch.cat([bos_inputs_embeds, inputs_embeds], dim=1)
754
+ input_len += 1
755
+
756
+ past_key_values_length = past_key_values[0][0].shape[2]
757
+ position_ids = torch.arange(
758
+ past_key_values_length,
759
+ past_key_values_length + input_len,
760
+ dtype=torch.long,
761
+ device=self.device,
762
+ ).unsqueeze(0)
763
+
764
+ cache_position = position_ids.clone()
765
+ causal_mask = make_streaming_chunk_mask_generation(
766
+ inputs_embeds=inputs_embeds,
767
+ past_seen_tokens=past_key_values[0][0].shape[2],
768
+ streaming_tts_text_mask=streaming_tts_text_mask,
769
+ streaming_reserved_length=self.streaming_text_reserved_len,
770
+ streaming_text_chunk_size=self.streaming_text_chunk_size,
771
+ ) # [1, 1, 1, past_key_values_length + input_len]
772
+
773
+ # Model forward
774
+ outputs: BaseModelOutputWithPast = self.model(
775
+ attention_mask=causal_mask,
776
+ position_ids=position_ids,
777
+ past_key_values=past_key_values,
778
+ inputs_embeds=inputs_embeds,
779
+ use_cache=True,
780
+ output_attentions=False,
781
+ cache_position=cache_position,
782
+ )
783
+ past_key_values = outputs.past_key_values
784
+ return past_key_values
785
+
786
+ @torch.inference_mode()
787
+ def generate(
788
+ self,
789
+ input_ids: torch.Tensor,
790
+ past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
791
+ temperature: torch.Tensor,
792
+ eos_token: Union[int, torch.Tensor],
793
+ streaming_tts_text_mask=None,
794
+ force_no_stop=False,
795
+ min_new_token=10,
796
+ max_new_token=50,
797
+ logits_warpers: List[LogitsWarper] = [],
798
+ logits_processors: List[CustomRepetitionPenaltyLogitsProcessorRepeat] = [],
799
+ show_tqdm=False,
800
+ ):
801
+ """Generate audio codes in streaming setting or non-streaming setting.
802
+ Specifically, generate audio codes even when not all text tokens have been prefilled.
803
+
804
+ Always pass a valid `past_key_values` to this method. The method does not do prefill by itself; it relies on the `prefill_text` method to provide a valid `past_key_values`. Please refer to the class docstring for more details.
805
+
806
+ In this method, we borrowed a lot of code from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/gpt.py`.
807
+
808
+ Args:
809
+ input_ids (torch.Tensor): Input token ids.
810
+ past_key_values (List[Tuple[torch.Tensor, torch.Tensor]]): Past key values for attention mechanism.
811
+ temperature (torch.Tensor): Temperature for sampling.
812
+ eos_token (Union[int, torch.Tensor]): End of sequence token.
813
+ streaming_tts_text_mask (Optional[torch.Tensor], optional): Mask for streaming TTS text. Defaults to None.
814
+ max_new_token (int, optional): Maximum number of new tokens to generate. Defaults to 50.
815
+ logits_warpers (List[LogitsWarper], optional): List of logits warpers. Defaults to [].
816
+ logits_processors (List[CustomRepetitionPenaltyLogitsProcessorRepeat], optional): List of logits processors. Defaults to [].
817
+ show_tqdm (bool, optional): Whether to show a progress bar. Defaults to False.
818
+
819
+ Returns:
820
+ GenerationOutputs: Generation outputs.
821
+ """
822
+
823
+ # We only support batch size `1` for now
824
+ assert input_ids.shape[0] == 1
825
+ assert past_key_values is not None
826
+
827
+ # fix: this should not be `input_ids.shape[1]`
828
+ # start_idx = input_ids.shape[1]
829
+ start_idx = (
830
+ 1
831
+ + self.num_spk_embs * self.use_speaker_embedding
832
+ + self.streaming_text_reserved_len
833
+ + 1
834
+ )
835
+
836
+ finish = torch.zeros(input_ids.shape[0], device=input_ids.device).bool()
837
+
838
+ temperature = (
839
+ temperature.unsqueeze(0)
840
+ .expand(input_ids.shape[0], -1)
841
+ .contiguous()
842
+ .view(-1, 1)
843
+ )
844
+
845
+ progress = input_ids.shape[1]
846
+
847
+ # Pre-allocate input_ids, shape is [batch_size=1, max_possible_seq_len, self.num_vqs]
848
+ input_ids_buf = torch.zeros(
849
+ input_ids.shape[0], # batch_size
850
+ progress
851
+ + max_new_token, # max_possible_seq_len = input_ids.shape[1] + max_new_token
852
+ input_ids.shape[2], # self.num_vqs
853
+ dtype=input_ids.dtype,
854
+ device=input_ids.device,
855
+ )
856
+
857
+ # Copy existing `input_ids` to `input_ids_buf`
858
+ input_ids_buf.narrow(1, 0, progress).copy_(input_ids)
859
+
860
+ del input_ids
861
+ input_ids = input_ids_buf.narrow(1, 0, progress)
862
+
863
+ pbar: Optional[tqdm] = None
864
+ if show_tqdm:
865
+ pbar = tqdm(
866
+ total=max_new_token,
867
+ desc="code",
868
+ bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]",
869
+ )
870
+
871
+ condition_length = (
872
+ 1
873
+ + self.num_spk_embs * self.use_speaker_embedding
874
+ + self.streaming_text_reserved_len
875
+ + 1
876
+ )
877
+
878
+ for i in range(max_new_token):
879
+ # Prepare generation inputs
880
+ audio_bos = False
881
+
882
+ # If this is the first audio token, the case is SPECIAL
883
+ if progress == condition_length:
884
+ audio_bos = True
885
+
886
+ assert progress == (
887
+ past_key_values[0][0].shape[2] + 1
888
+ ) # If you are following the usage guidelines, this assertion should pass.
889
+
890
+ if audio_bos:
891
+ # Generate the first token, activate the model with `self.audio_bos_token_id`, the model will predict
892
+ # a new audio token. This is a special case because without the `audio bos token`, it is impossible
893
+ # to generate the first audio token in our streaming setting.
894
+ narrowed_input_ids = torch.tensor(
895
+ [[self.audio_bos_token_id]], dtype=torch.long, device=self.device
896
+ )
897
+ inputs_embeds = self.emb_text(narrowed_input_ids)
898
+ del narrowed_input_ids
899
+ else:
900
+ # Generate the following audio tokens, it is applicable to all other cases, including second and the
901
+ # following calling of `generate`.
902
+ narrowed_input_ids = input_ids.narrow(
903
+ dim=1, start=input_ids.shape[1] - 1, length=1
904
+ )
905
+ code_emb = [
906
+ self.emb_code[i](narrowed_input_ids[:, :, i])
907
+ for i in range(self.num_vq)
908
+ ]
909
+ inputs_embeds = torch.stack(code_emb, 3).sum(3)
910
+
911
+ position_ids = torch.tensor(
912
+ [past_key_values[0][0].shape[2]], dtype=torch.long, device=self.device
913
+ ).unsqueeze(0)
914
+
915
+ cache_position = position_ids.clone()
916
+
917
+ # Make causal mask
918
+ causal_mask = make_streaming_chunk_mask_generation(
919
+ inputs_embeds=inputs_embeds,
920
+ past_seen_tokens=past_key_values[0][0].shape[2],
921
+ streaming_tts_text_mask=streaming_tts_text_mask,
922
+ streaming_reserved_length=self.streaming_text_reserved_len,
923
+ streaming_text_chunk_size=self.streaming_text_chunk_size,
924
+ )
925
+
926
+ # Model forward
927
+ outputs: BaseModelOutputWithPast = self.model(
928
+ attention_mask=causal_mask,
929
+ position_ids=position_ids,
930
+ past_key_values=past_key_values,
931
+ inputs_embeds=inputs_embeds,
932
+ use_cache=True,
933
+ output_attentions=False,
934
+ cache_position=cache_position,
935
+ )
936
+
937
+ del position_ids
938
+ del inputs_embeds
939
+ del cache_position
940
+ del causal_mask
941
+
942
+ hidden_states = outputs.last_hidden_state
943
+ past_key_values = outputs.past_key_values
944
+
945
+ with P.cached():
946
+ logits = torch.empty(
947
+ hidden_states.size(0),
948
+ hidden_states.size(1),
949
+ self.num_audio_tokens,
950
+ self.num_vq,
951
+ dtype=torch.float,
952
+ device=self.device,
953
+ )
954
+ for num_vq_iter in range(self.num_vq):
955
+ x: torch.Tensor = self.head_code[num_vq_iter](hidden_states)
956
+ logits[..., num_vq_iter] = x
957
+ del x
958
+
959
+ del hidden_states
960
+
961
+ # logits = logits[:, -1].float()
962
+ logits = logits.narrow(1, -1, 1).squeeze_(1).float()
963
+
964
+ # logits = rearrange(logits, "b c n -> (b n) c")
965
+ logits = logits.permute(0, 2, 1)
966
+ logits = logits.reshape(-1, logits.size(2))
967
+ # logits_token = rearrange(input_ids[:, start_idx:], "b c n -> (b n) c")
968
+ input_ids_sliced = input_ids.narrow(
969
+ 1,
970
+ start_idx,
971
+ input_ids.size(1) - start_idx,
972
+ ).permute(0, 2, 1)
973
+ logits_token = input_ids_sliced.reshape(
974
+ input_ids_sliced.size(0) * input_ids_sliced.size(1),
975
+ -1,
976
+ ).to(self.device)
977
+ del input_ids_sliced
978
+
979
+ logits /= temperature
980
+
981
+ if not audio_bos:
982
+ for logitsProcessors in logits_processors:
983
+ logits = logitsProcessors(logits_token, logits)
984
+ if not audio_bos:
985
+ for logitsWarpers in logits_warpers:
986
+ logits = logitsWarpers(logits_token, logits)
987
+
988
+ del logits_token
989
+
990
+ if i < min_new_token:
991
+ logits[:, eos_token] = -torch.inf
992
+
993
+ if force_no_stop:
994
+ logits[:, eos_token] = -torch.inf
995
+
996
+ scores = F.softmax(logits, dim=-1)
997
+
998
+ del logits
999
+ idx_next = torch.multinomial(scores, num_samples=1) # .to(finish.device)
1000
+
1001
+ del scores
1002
+
1003
+ # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
1004
+ idx_next = idx_next.view(-1, self.num_vq)
1005
+ finish_or = idx_next.eq(eos_token).any(1)
1006
+ finish.logical_or_(finish_or)
1007
+
1008
+ del finish_or
1009
+ # Store new `token` into `input_ids_buf`
1010
+ input_ids_buf.narrow(1, progress, 1).copy_(idx_next.unsqueeze_(1))
1011
+
1012
+ if i == 0 and finish.any():
1013
+ # raise Exception
1014
+ break
1015
+
1016
+ del idx_next
1017
+ progress += 1
1018
+ input_ids = input_ids_buf.narrow(1, 0, progress)
1019
+
1020
+ if finish.all():
1021
+ break
1022
+
1023
+ if pbar is not None:
1024
+ pbar.update(1)
1025
+
1026
+ if pbar is not None:
1027
+ pbar.close()
1028
+
1029
+ if not finish.all():
1030
+ if show_tqdm:
1031
+ logger.info(f"incomplete result. hit max_new_token: {max_new_token}")
1032
+
1033
+ del input_ids_buf
1034
+
1035
+ if finish.all():
1036
+ # the last token may contain the eos token
1037
+ generated_input_ids = input_ids[:, condition_length:-1, :]
1038
+ else:
1039
+ # there is no eos token
1040
+ generated_input_ids = input_ids[:, condition_length:, :]
1041
+
1042
+ return ConditionalChatTTSGenerationOutput(
1043
+ new_ids=generated_input_ids,
1044
+ audio_input_ids=input_ids, # for update purpose
1045
+ past_key_values=past_key_values, # for update purpose
1046
+ finished=finish.all(),
1047
+ )
1048
+
1049
+ @torch.inference_mode()
1050
+ def decode_to_mel_specs(
1051
+ self,
1052
+ result_list: List[torch.Tensor],
1053
+ ):
1054
+ """Decode discrete audio codes to mel spectrograms.
1055
+
1056
+ Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/core.py`
1057
+
1058
+ Args:
1059
+ result_list (List[torch.Tensor]): Audio codes output from `generate`.
1060
+
1061
+ Returns:
1062
+ torch.Tensor: Mel spectrograms.
1063
+ """
1064
+
1065
+ decoder = self.dvae
1066
+ max_x_len = -1
1067
+ if len(result_list) == 0:
1068
+ return np.array([], dtype=np.float32)
1069
+ for result in result_list:
1070
+ if result.size(0) > max_x_len:
1071
+ max_x_len = result.size(0)
1072
+ batch_result = torch.zeros(
1073
+ (len(result_list), result_list[0].size(1), max_x_len),
1074
+ dtype=result_list[0].dtype,
1075
+ device=result_list[0].device,
1076
+ )
1077
+ for i in range(len(result_list)):
1078
+ src = result_list[i]
1079
+ batch_result[i].narrow(1, 0, src.size(0)).copy_(src.permute(1, 0))
1080
+ del src
1081
+
1082
+ mel_specs = decoder(batch_result)
1083
+ del batch_result
1084
+ return mel_specs
1085
+
1086
+
1087
+ # Copied from transformers.models.whisper.modeling_whisper.WhisperEncoderLayer, with use_cache added for streaming inference
1088
+ class MiniCPMWhisperEncoderLayer(nn.Module):
1089
+ def __init__(self, config: WhisperConfig, layer_idx: int = None):
1090
+ super().__init__()
1091
+ self.embed_dim = config.d_model
1092
+ self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
1093
+ embed_dim=self.embed_dim,
1094
+ num_heads=config.encoder_attention_heads,
1095
+ dropout=config.attention_dropout,
1096
+ config=config,
1097
+ layer_idx=layer_idx,
1098
+ )
1099
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
1100
+ self.dropout = config.dropout
1101
+ self.activation_fn = ACT2FN[config.activation_function]
1102
+ self.activation_dropout = config.activation_dropout
1103
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
1104
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
1105
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
1106
+
1107
+ def forward(
1108
+ self,
1109
+ hidden_states: torch.Tensor,
1110
+ attention_mask: torch.Tensor,
1111
+ layer_head_mask: torch.Tensor,
1112
+ output_attentions: bool = False,
1113
+ past_key_values: Optional[EncoderDecoderCache] = None,
1114
+ use_cache: Optional[bool] = False,
1115
+ ) -> torch.Tensor:
1116
+ r"""
1117
+ Args:
1118
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, embed_dim)`):
1119
+ Hidden states to be fed into the encoder layer.
1120
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, 1, tgt_len, src_len)`):
1121
+ Attention mask where padding elements are indicated by large negative values.
1122
+ layer_head_mask (`torch.FloatTensor` of shape `(encoder_attention_heads,)`):
1123
+ Mask to nullify selected heads of the attention modules.
1124
+ output_attentions (`bool`, *optional*):
1125
+ Whether or not to return the attention weights.
1126
+ past_key_values (`EncoderDecoderCache`, *optional*):
1127
+ Past key-value pairs used for incremental decoding.
1128
+ use_cache (`bool`, *optional*):
1129
+ Whether or not to return updated `past_key_values` for caching.
1130
+
1131
+ Returns:
1132
+ A tuple of shape `(hidden_states, optional(attn_weights), optional(past_key_values))`.
1133
+ """
1134
+ residual = hidden_states
1135
+ hidden_states = self.self_attn_layer_norm(hidden_states)
1136
+ hidden_states, attn_weights, past_key_values = self.self_attn(
1137
+ hidden_states=hidden_states,
1138
+ attention_mask=attention_mask,
1139
+ layer_head_mask=layer_head_mask,
1140
+ output_attentions=output_attentions,
1141
+ past_key_value=past_key_values,
1142
+ )
1143
+ hidden_states = nn.functional.dropout(
1144
+ hidden_states, p=self.dropout, training=False
1145
+ )
1146
+ hidden_states = residual + hidden_states
1147
+
1148
+ residual = hidden_states
1149
+ hidden_states = self.final_layer_norm(hidden_states)
1150
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
1151
+ hidden_states = nn.functional.dropout(
1152
+ hidden_states, p=self.activation_dropout, training=False
1153
+ )
1154
+ hidden_states = self.fc2(hidden_states)
1155
+ hidden_states = nn.functional.dropout(
1156
+ hidden_states, p=self.dropout, training=False
1157
+ )
1158
+ hidden_states = residual + hidden_states
1159
+
1160
+ if hidden_states.dtype == torch.float16 and (
1161
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
1162
+ ):
1163
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
1164
+ hidden_states = torch.clamp(
1165
+ hidden_states, min=-clamp_value, max=clamp_value
1166
+ )
1167
+
1168
+ outputs = (hidden_states,)
1169
+
1170
+ if output_attentions:
1171
+ outputs += (attn_weights,)
1172
+
1173
+ if use_cache:
1174
+ outputs += (past_key_values,)
1175
+
1176
+ return outputs
1177
+
1178
+
1179
+ # Copied from transformers.models.whisper.modeling_whisper.WhisperEncoder, with use_cache added for streaming inference
1180
+ class MiniCPMWhisperEncoder(WhisperEncoder):
+
+     def __init__(self, config: WhisperConfig):
+         super().__init__(config)
+         self.layers = nn.ModuleList(
+             [
+                 MiniCPMWhisperEncoderLayer(config, layer_idx=i)
+                 for i in range(config.encoder_layers)
+             ]
+         )
+
+     def forward(
+         self,
+         input_features,
+         attention_mask=None,
+         head_mask=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+         past_key_values: Optional[EncoderDecoderCache] = None,
+         use_cache: Optional[bool] = None,
+     ):
+         r"""
+         Forward pass of the Whisper encoder.
+
+         Args:
+             input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
+                 Float values of log-mel features extracted from the raw audio waveform. Typically generated
+                 by a feature extractor (e.g., `WhisperFeatureExtractor`) that processes `.flac` or `.wav`
+                 files into padded 2D mel spectrogram frames. These features are projected via convolution layers
+                 (`conv1` and `conv2`) and then transformed into embeddings for the encoder.
+
+             attention_mask (`torch.Tensor`, *optional*):
+                 Not used by Whisper for masking `input_features`, but included for API compatibility with
+                 other models. If provided, it is simply ignored within the model. By default, Whisper
+                 effectively ignores silence in the input log-mel spectrogram.
+
+             head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+                 Mask to nullify selected attention heads. The elements should be either 1 or 0, where:
+                 - 1 indicates the head is **not masked**,
+                 - 0 indicates the head is **masked** (i.e., the attention head is dropped).
+
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attention tensors of all encoder layers. If set to `True`, the
+                 returned tuple (or `BaseModelOutputWithPast`) will contain an additional element with
+                 attention weights for each encoder layer.
+
+             output_hidden_states (`bool`, *optional*):
+                 Whether or not to return the hidden states of all layers. If set to `True`, the returned
+                 tuple (or `BaseModelOutputWithPast`) will contain a tuple of hidden states, including the
+                 initial embedding output as well as the outputs of each layer.
+
+             return_dict (`bool`, *optional*):
+                 Whether or not to return a `BaseModelOutputWithPast` (a subclass of `ModelOutput`) instead
+                 of a plain tuple. If set to `True`, the output will be a `BaseModelOutputWithPast` object,
+                 otherwise it will be a tuple.
+
+             past_key_values (`EncoderDecoderCache`, *optional*):
+                 When using caching for faster inference, this is an object that stores the key-value pairs
+                 for attention states. If provided, the model will append new states to the existing cache
+                 and return the updated cache. This speeds up sequential decoding or chunked inference.
+
+                 - If `past_key_values` is `None`, no past states are used or returned.
+                 - If `past_key_values` is not `None` and `use_cache=True`, the model will use the provided
+                   cache and return the updated cache (as `next_encoder_cache`).
+
+             use_cache (`bool`, *optional*):
+                 Whether or not the model should use caching (`past_key_values`) to speed up processing
+                 during inference. When set to `True`, the model will:
+                 - Inspect and use `past_key_values` if provided.
+                 - Return updated `past_key_values` (under the name `next_encoder_cache` in
+                   `BaseModelOutputWithPast`).
+
+         Returns:
+             `BaseModelOutputWithPast` or `tuple` (depending on `return_dict`):
+                 If `return_dict=True`, a `BaseModelOutputWithPast` is returned, which contains:
+                 - **last_hidden_state** (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                   The output of the final encoder layer.
+                 - **hidden_states** (`tuple(torch.FloatTensor)`, *optional*, returned if `output_hidden_states=True`):
+                   Hidden states of the model at each layer (including the initial projection).
+                 - **attentions** (`tuple(torch.FloatTensor)`, *optional*, returned if `output_attentions=True`):
+                   Attention weights from each encoder layer.
+                 - **past_key_values** (an object of type `EncoderDecoderCache` or `None`, *optional*):
+                   Updated cache of key-value pairs if `use_cache=True`.
+
+             If `return_dict=False`, a tuple is returned, where the format is:
+             `(last_hidden_state, hidden_states, attentions)`, with `hidden_states` and `attentions`
+             only present if their respective `output_*` arguments are set to `True`.
+
+         """
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         # Ignore copy
+         input_features = input_features.to(
+             dtype=self.conv1.weight.dtype, device=self.conv1.weight.device
+         )
+
+         inputs_embeds = nn.functional.gelu(self.conv1(input_features))
+         inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
+
+         inputs_embeds = inputs_embeds.permute(0, 2, 1)
+
+         embed_pos = self.embed_positions.weight
+         past_key_values_length = 0
+         if use_cache:
+             if past_key_values is None:
+                 past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
+             elif isinstance(past_key_values, list):
+                 past_key_values = EncoderDecoderCache(
+                     DynamicCache.from_legacy_cache(past_key_values), DynamicCache()
+                 )
+             elif isinstance(past_key_values, DynamicCache):
+                 past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
+             else:
+                 pass
+             past_key_values_length = (
+                 past_key_values.self_attention_cache.get_usable_length(
+                     inputs_embeds.shape[1]
+                 )
+             )
+             if inputs_embeds.shape[1] + past_key_values_length > embed_pos.shape[0]:
+                 logger.warning(
+                     "seems the audio is longer than 30s. repeating the last part of the audio"
+                 )
+                 embed_pos_front = embed_pos[past_key_values_length:, :]
+                 embed_pos = torch.cat(
+                     (
+                         embed_pos_front,
+                         torch.repeat_interleave(
+                             embed_pos[-1, :].unsqueeze(0),
+                             inputs_embeds.shape[1]
+                             - embed_pos.shape[0]
+                             + past_key_values_length,
+                             dim=0,
+                         ),
+                     )
+                 )
+             else:
+                 embed_pos = embed_pos[
+                     past_key_values_length : inputs_embeds.shape[1]
+                     + past_key_values_length,
+                     :,
+                 ]
+         else:
+             embed_pos = embed_pos[: inputs_embeds.shape[1], :]
+
+         hidden_states = inputs_embeds + embed_pos
+         hidden_states = nn.functional.dropout(
+             hidden_states, p=self.dropout, training=False
+         )
+
+         encoder_states = () if output_hidden_states else None
+         all_attentions = () if output_attentions else None
+
+         # check if head_mask has a correct number of layers specified if desired
+         if head_mask is not None:
+             assert head_mask.size()[0] == (
+                 len(self.layers)
+             ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+
+         for idx, encoder_layer in enumerate(self.layers):
+             if output_hidden_states:
+                 encoder_states = encoder_states + (hidden_states,)
+             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+             to_drop = False
+
+             # Ignore copy
+             if to_drop:
+                 layer_outputs = (None, None)
+             else:
+                 layer_outputs = encoder_layer(
+                     hidden_states,
+                     attention_mask,
+                     layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                     output_attentions=output_attentions,
+                     past_key_values=past_key_values,
+                     use_cache=use_cache,
+                 )
+
+                 hidden_states = layer_outputs[0]
+
+             if use_cache:
+                 next_encoder_cache = layer_outputs[2 if output_attentions else 1]
+             else:
+                 next_encoder_cache = None
+
+             if output_attentions:
+                 all_attentions = all_attentions + (layer_outputs[1],)
+
+         hidden_states = self.layer_norm(hidden_states)
+         if output_hidden_states:
+             encoder_states = encoder_states + (hidden_states,)
+
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [hidden_states, encoder_states, all_attentions]
+                 if v is not None
+             )
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             hidden_states=encoder_states,
+             attentions=all_attentions,
+             past_key_values=next_encoder_cache,
+         )
+
+
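For readers skimming the diff, the caching path above boils down to offsetting Whisper's positional table by however many frames are already cached, and reusing the last position once the 30s table runs out. A small self-contained sketch of that bookkeeping (the table size and dimensions below are made up, not taken from this file):

import torch

# Toy positional table; real Whisper uses 1500 positions (30s of post-conv frames).
max_source_positions, d_model = 10, 4
embed_pos_table = torch.randn(max_source_positions, d_model)

def positions_for_chunk(chunk_len: int, past_len: int) -> torch.Tensor:
    """Return the positional rows a chunk starting at offset `past_len` would receive."""
    if chunk_len + past_len > embed_pos_table.shape[0]:
        # Audio longer than the table: keep what is left, then reuse the last position.
        front = embed_pos_table[past_len:, :]
        overflow = chunk_len - embed_pos_table.shape[0] + past_len
        tail = embed_pos_table[-1, :].unsqueeze(0).repeat(overflow, 1)
        return torch.cat((front, tail), dim=0)
    return embed_pos_table[past_len : past_len + chunk_len, :]

print(positions_for_chunk(chunk_len=4, past_len=0).shape)  # torch.Size([4, 4])
print(positions_for_chunk(chunk_len=4, past_len=8).shape)  # torch.Size([4, 4]); last position reused twice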
+ class MultiModalProjector(nn.Module):
+     def __init__(self, in_dim, out_dim):
+         super().__init__()
+         self.linear1 = nn.Linear(in_features=in_dim, out_features=out_dim, bias=True)
+         self.relu = nn.ReLU()
+         self.linear2 = nn.Linear(in_features=out_dim, out_features=out_dim, bias=True)
+
+     def forward(self, audio_features):
+         hidden_states = self.relu(self.linear1(audio_features))
+         hidden_states = self.linear2(hidden_states)
+         return hidden_states
+
+
+ class MiniCPMO(MiniCPMVBaseModel):
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__(config=config, quant_config=quant_config)
+
+         self.llm = self.init_llm(config=config, quant_config=quant_config)
+
+         self.embed_dim = self.llm.config.hidden_size
+
+         # init vision module
+         if self.config.init_vision:
+             # print("vision-understanding enabled")
+             self.vpm = self.init_vision_module(config=config, quant_config=quant_config)
+             self.vision_dim = self.vpm.embed_dim
+             self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
+
+         # init audio module
+         self.config.init_audio = True
+         if self.config.init_audio:
+             # print("audio-understanding enabled")
+             self.apm = self.init_audio_module()
+             audio_output_dim = int(self.apm.config.encoder_ffn_dim // 4)
+             self.audio_avg_pooler = nn.AvgPool1d(
+                 self.config.audio_pool_step, stride=self.config.audio_pool_step
+             )
+             self.audio_projection_layer = MultiModalProjector(
+                 in_dim=audio_output_dim, out_dim=self.embed_dim
+             )
+             self.audio_encoder_layer = -1
+
+         # init tts module
+         self.config.init_tts = False
+         logger.info("TTS is disabled for now")
+         if self.config.init_tts:
+             # print("tts enabled")
+             assert (
+                 _tts_deps
+             ), "please make sure vector_quantize_pytorch and vocos are installed."
+             self.tts = self.init_tts_module()
+
+     def init_tts_module(self):
+         model = ConditionalChatTTS(self.config.tts_config)
+         return model
+
+     def init_audio_module(self):
+         model = MiniCPMWhisperEncoder(self.config.audio_config)
+         return model
+
+     def init_llm(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> nn.Module:
+         return Qwen2ForCausalLM(config=config, quant_config=quant_config, prefix=prefix)
+
+     def init_vision_module(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig],
+         prefix: str = "",
+     ):
+         if self.config._attn_implementation == "flash_attention_2":
+             self.config.vision_config._attn_implementation = "flash_attention_2"
+         else:
+             self.config.vision_config._attn_implementation = "eager"
+         model = Idefics2VisionTransformer(
+             config=config.vision_config, quant_config=quant_config, prefix=prefix
+         )
+         if self.config.drop_vision_last_layer:
+             model.encoder.layers = model.encoder.layers[:-1]
+
+         setattr(model, "embed_dim", model.embeddings.embed_dim)
+         setattr(model, "patch_size", model.embeddings.patch_size)
+
+         return model
+
+     def init_resampler(
+         self,
+         embed_dim: int,
+         vision_dim: int,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> nn.Module:
+         with set_default_torch_dtype(torch.float16):
+             # The resampler in 2.6 remains consistent with the one in 2.5.
+             resampler = Resampler2_5(
+                 num_queries=self.config.query_num,
+                 embed_dim=embed_dim,
+                 num_heads=embed_dim // 128,
+                 kv_dim=vision_dim,
+                 quant_config=quant_config,
+                 prefix=prefix,
+             )
+
+         return resampler.to(device="cuda", dtype=torch.get_default_dtype())
+
+     def pad_input_ids(self, input_ids: List[int], mm_input: MultimodalInputs):
+         # Get all special token IDs
+         im_start_id: int = mm_input.im_start_id
+         im_end_id: int = mm_input.im_end_id
+         slice_start_id: int = mm_input.slice_start_id
+         slice_end_id: int = mm_input.slice_end_id
+
+         media_token_pairs = [
+             (im_start_id, im_end_id),
+             (slice_start_id, slice_end_id),
+             (mm_input.audio_start_id, mm_input.audio_end_id),
+         ]
+         pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+
+         return pattern.pad_input_tokens(input_ids, mm_input)
+
+     def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
+         """
+         Computes the output length of the convolutional layers and the output length of the audio encoder
+         """
+         input_lengths_after_cnn = (input_lengths - 1) // 2 + 1
+         input_lengths_after_pooling = (
+             input_lengths_after_cnn - self.config.audio_pool_step
+         ) // self.config.audio_pool_step + 1
+         input_lengths_after_pooling = input_lengths_after_pooling.to(dtype=torch.int32)
+
+         return input_lengths_after_cnn, input_lengths_after_pooling
+
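The length arithmetic above is easy to sanity-check by hand: Whisper's second conv has stride 2, and the average pooler then reduces by `audio_pool_step`. A small illustration, assuming `audio_pool_step` is 2 (the real value comes from the model config):

import torch

audio_pool_step = 2
input_lengths = torch.tensor([3000, 998])        # mel frames per segment
after_cnn = (input_lengths - 1) // 2 + 1         # tensor([1500, 499])
after_pool = (after_cnn - audio_pool_step) // audio_pool_step + 1
print(after_cnn.tolist(), after_pool.tolist())   # [1500, 499] [750, 249]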
+     def get_audio_embedding_streaming(self, multimodal_input: MultimodalInputs):
+         r"""
+         Extract audio embeddings in a streaming manner using cached key-value pairs.
+
+         This method processes incoming audio features incrementally and stores/updates `past_key_values`
+         for faster inference on subsequent audio frames. It only supports batch_size=1 and is intended
+         for streaming scenarios.
+
+         Args:
+             multimodal_input (dict):
+                 - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
+                 - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.
+
+         Returns:
+             List[List[torch.Tensor]]: audio embeddings
+         """
+         # print("audio embedding")
+
+         wavforms = (
+             []
+             if multimodal_input.audio_features is None
+             else multimodal_input.audio_features
+         )
+         # list, [[x1, x2], [y1], [z1]]
+         audio_feature_lens_raw = (
+             []
+             if multimodal_input.audio_feature_lens is None
+             else multimodal_input.audio_feature_lens
+         )
+
+         # audio exists
+         if len(wavforms) > 0:
+             audio_feature_lens = torch.hstack(audio_feature_lens_raw)
+             batch_size, _, max_mel_seq_len = wavforms.shape
+             assert batch_size == 1
+             max_seq_len = (max_mel_seq_len - 1) // 2 + 1
+
+             if self.audio_past_key_values is not None:
+                 cache_length = self.audio_past_key_values[0][0].shape[2]
+                 apm_max_len = self.apm.embed_positions.weight.shape[0]
+                 if cache_length + max_seq_len >= apm_max_len:
+                     logger.warning(
+                         f"audio_past_key_values length {cache_length + max_seq_len} exceed {apm_max_len}, reset."
+                     )
+                     self.audio_past_key_values = None
+
+             audio_outputs = self.apm(
+                 wavforms, past_key_values=self.audio_past_key_values, use_cache=True
+             )
+             audio_states = (
+                 audio_outputs.last_hidden_state
+             )  # [:, :audio_feat_lengths, :]
+             self.audio_past_key_values = audio_outputs.past_key_values
+
+             audio_embeds = self.audio_projection_layer(audio_states)
+
+             audio_embeds = audio_embeds.transpose(1, 2)
+             audio_embeds = self.audio_avg_pooler(audio_embeds)
+             audio_embeds = audio_embeds.transpose(1, 2)
+
+             _, feature_lens_after_pooling = self._get_feat_extract_output_lengths(
+                 audio_feature_lens
+             )
+
+             num_audio_tokens = feature_lens_after_pooling
+
+             final_audio_embeds = []
+             idx = 0
+             for i in range(len(audio_feature_lens_raw)):
+                 target_audio_embeds = []
+                 for _ in range(len(audio_feature_lens_raw[i])):
+                     target_audio_embeds.append(
+                         audio_embeds[idx, : num_audio_tokens[idx], :]
+                     )
+                     idx += 1
+                 final_audio_embeds.append(target_audio_embeds)
+             return final_audio_embeds
+         else:
+             return []
+
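The tail of the method regroups the flattened segment embeddings back into the nested `[[x1, x2], [y1]]` layout of `audio_feature_lens`. A toy illustration of that regrouping (it skips the conv/pooling length reduction and uses made-up sizes):

import torch

audio_feature_lens_raw = [torch.tensor([6, 4]), torch.tensor([5])]   # 2 items, 3 segments total
num_audio_tokens = torch.hstack(audio_feature_lens_raw)              # tensor([6, 4, 5])
audio_embeds = torch.randn(3, int(num_audio_tokens.max()), 8)        # (segments, max_tokens, dim)

final_audio_embeds, idx = [], 0
for lens in audio_feature_lens_raw:
    target = []
    for _ in range(len(lens)):
        target.append(audio_embeds[idx, : num_audio_tokens[idx], :])  # trim per-segment padding
        idx += 1
    final_audio_embeds.append(target)

print([[t.shape[0] for t in item] for item in final_audio_embeds])    # [[6, 4], [5]]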
+     def subsequent_chunk_mask(
+         self,
+         size: int,
+         chunk_size: int,
+         num_left_chunks: int = -1,
+         device: torch.device = torch.device("cpu"),
+         num_lookhead: int = 0,
+     ) -> torch.Tensor:
+         """Create mask for subsequent steps (size, size) with chunk size,
+         this is for streaming encoder
+
+         Args:
+             size (int): size of mask
+             chunk_size (int): size of chunk
+             num_left_chunks (int): number of left chunks
+                 <0: use full chunk
+                 >=0: use num_left_chunks
+             device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+
+         Returns:
+             torch.Tensor: mask
+
+         """
+         ret = torch.zeros(size, size, device=device, dtype=torch.bool)
+         for i in range(size):
+             if num_left_chunks < 0:
+                 start = 0
+             else:
+                 start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
+             ending = min((i // chunk_size + 1) * chunk_size + num_lookhead, size)
+             ret[i, start:ending] = True
+         return ret
+
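For intuition, here is what the mask looks like for a tiny case; the loop below re-derives it inline with `num_left_chunks=-1` and no lookahead rather than calling the method, which needs a model instance:

import torch

size, chunk_size = 6, 2
mask = torch.zeros(size, size, dtype=torch.bool)
for i in range(size):
    ending = min((i // chunk_size + 1) * chunk_size, size)
    mask[i, :ending] = True   # row i may attend up to the end of its own chunk
print(mask.int())
# tensor([[1, 1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1, 1]])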
+     def get_audio_embedding(self, multimodal_input: MultimodalInputs, chunk_length=-1):
+         r"""
+         Extract full audio embeddings with optional chunk-based attention.
+
+         This method computes embeddings for all audio frames at once, either using full attention (when
+         `chunk_length` is -1) or chunk-based attention (when `chunk_length` is a positive number). It does
+         not use key-value caching and is suitable for non-streaming inference.
+
+         Args:
+             multimodal_input (dict):
+                 - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
+                 - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.
+             chunk_length (int, optional): Determines whether to use full attention (-1) or chunk-based
+                 attention (>0) during embedding computation.
+
+         Returns:
+             List[List[torch.Tensor]]: audio embeddings
+         """
+         # print("audio embedding")
+         # (bs, 80, frames) or []; multiple audios need to be filled in advance
+         wavforms = (
+             []
+             if multimodal_input.audio_features is None
+             else multimodal_input.audio_features
+         )
+         # list, [[x1, x2], [y1], [z1]]
+         audio_feature_lens_raw = (
+             []
+             if multimodal_input.audio_feature_lens is None
+             else multimodal_input.audio_feature_lens
+         )
+
+         final_audio_embeds = []
+
+         # audio exists
+         for wavform in wavforms:
+             if len(wavform) > 0:
+                 audio_feature_lens = torch.hstack(audio_feature_lens_raw)
+                 batch_size, _, max_mel_seq_len = wavform.shape
+                 max_seq_len = (max_mel_seq_len - 1) // 2 + 1
+
+                 # Create a sequence tensor of shape (batch_size, max_seq_len)
+                 seq_range = (
+                     torch.arange(
+                         0,
+                         max_seq_len,
+                         dtype=audio_feature_lens.dtype,
+                         device=audio_feature_lens.device,
+                     )
+                     .unsqueeze(0)
+                     .expand(batch_size, max_seq_len)
+                 )
+                 lengths_expand = audio_feature_lens.unsqueeze(1).expand(
+                     batch_size, max_seq_len
+                 )
+                 # Create mask
+                 padding_mask = seq_range >= lengths_expand  # 1 for padded values
+
+                 audio_attention_mask_ = padding_mask.view(
+                     batch_size, 1, 1, max_seq_len
+                 ).expand(batch_size, 1, max_seq_len, max_seq_len)
+                 audio_attention_mask = audio_attention_mask_.to(
+                     dtype=self.apm.conv1.weight.dtype,
+                     device=self.apm.conv1.weight.device,
+                 )
+
+                 if chunk_length > 0:
+                     chunk_num_frame = int(chunk_length * 50)
+                     chunk_mask = self.subsequent_chunk_mask(
+                         size=max_seq_len,
+                         chunk_size=chunk_num_frame,
+                         num_left_chunks=-1,
+                         device=audio_attention_mask_.device,
+                     )
+                     audio_attention_mask_ = torch.logical_or(
+                         audio_attention_mask_, torch.logical_not(chunk_mask)
+                     )
+
+                 audio_attention_mask[audio_attention_mask_] = float("-inf")
+                 audio_states = self.apm(
+                     wavform,
+                     output_hidden_states=True,
+                     attention_mask=audio_attention_mask,
+                 ).hidden_states[self.audio_encoder_layer]
+                 audio_embeds = self.audio_projection_layer(audio_states)
+
+                 audio_embeds = audio_embeds.transpose(1, 2)
+                 audio_embeds = self.audio_avg_pooler(audio_embeds)
+                 audio_embeds = audio_embeds.transpose(1, 2)
+
+                 _, feature_lens_after_pooling = self._get_feat_extract_output_lengths(
+                     audio_feature_lens
+                 )
+
+                 num_audio_tokens = feature_lens_after_pooling
+
+                 idx = 0
+                 for i in range(len(audio_feature_lens_raw)):
+                     target_audio_embeds = []
+                     for _ in range(len(audio_feature_lens_raw[i])):
+                         target_audio_embeds.append(
+                             audio_embeds[idx, : num_audio_tokens[idx], :]
+                         )
+                         idx += 1
+                     final_audio_embeds.append(target_audio_embeds)
+         return final_audio_embeds
+
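Two details above are worth spelling out: `chunk_length` is given in seconds and becomes `chunk_length * 50` post-conv frames (the stride-2 conv halves the 100 frames/s mel rate), and the boolean padding and chunk masks are OR-ed together before being turned into an additive `-inf` bias. A minimal sketch with toy sizes (all numbers below are invented):

import torch

chunk_length_s = 1
chunk_num_frame = int(chunk_length_s * 50)   # seconds -> post-conv frames, as in the code above
print(chunk_num_frame)                       # 50

# 4 post-conv frames, of which only the first 3 are real audio; chunk size 2.
max_seq_len, valid_len = 4, 3
padding = (torch.arange(max_seq_len) >= valid_len).view(1, 1, 1, -1)   # True = padded
chunk_mask = torch.tensor(
    [[1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1]], dtype=torch.bool
)
blocked = torch.logical_or(padding, ~chunk_mask)   # True = not attendable

attn_bias = torch.zeros(1, 1, max_seq_len, max_seq_len)
attn_bias[blocked] = float("-inf")                 # additive bias handed to the encoder
print(attn_bias[0, 0])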
+     def get_omni_embedding(
+         self,
+         input_ids,
+         multimodal_input: MultimodalInputs,
+         input_embeds: torch.Tensor,
+         forward_mode: ForwardMode,
+         chunk_length=-1,
+         stream_input=False,
+     ):
+         """
+         Args:
+             multimodal_input:
+             input_embeds:
+             chunk_length: whether whisper uses full attention or chunk attention
+             stream_input: use streaming audio embedding
+         Returns:
+             final embeddings with audio features
+         """
+         input_embeds = input_embeds.unsqueeze(0)
+         if not forward_mode.is_decode() and multimodal_input.contains_audio_inputs():
+             audio_bounds = get_multimodal_data_bounds(
+                 input_ids=input_ids,
+                 pad_values=multimodal_input.pad_values,
+                 token_pairs=[
+                     (multimodal_input.audio_start_id, multimodal_input.audio_end_id)
+                 ],
+             )
+             if audio_bounds.numel() == 0:
+                 input_embeds = input_embeds.squeeze(0)
+                 # TODO
+                 logger.warning("Unimplemented logic. Please try disabling chunked prefill")
+                 return input_embeds
+             audio_bounds = audio_bounds.unsqueeze(0)
+             bs = len(input_embeds)
+
+             if stream_input:
+                 audio_embeddings = self.get_audio_embedding_streaming(multimodal_input)
+             else:
+                 audio_embeddings = self.get_audio_embedding(
+                     multimodal_input, chunk_length
+                 )
+             # batch size
+             assert len(audio_embeddings) == len(input_embeds)
+             if len(audio_embeddings) > 0:
+                 if self.config.chunk_input:
+                     for i in range(bs):
+                         audio_embs = torch.cat(audio_embeddings[i], dim=0).to(
+                             device=input_embeds.device, dtype=input_embeds.dtype
+                         )
+                         audio_start_pos = 0
+                         for bound in audio_bounds[i]:
+                             audio_len = bound[1] - bound[0] + 1
+                             input_embeds[0, bound[0] : bound[1] + 1] = audio_embs[
+                                 audio_start_pos : audio_start_pos + audio_len, :
+                             ]
+                             audio_start_pos += audio_len
+                 else:
+                     for i in range(bs):
+                         audio_embs = audio_embeddings[i]
+                         bounds = audio_bounds[i]
+                         for embs, bound in zip(audio_embs, bounds):
+                             audio_indices = torch.arange(
+                                 bound[0], bound[1], dtype=torch.long
+                             ).to(input_embeds.device)
+
+                             if embs.shape[0] != len(audio_indices):
+                                 raise ValueError(
+                                     f"Shape mismatch: Trying to assign embeddings of shape {embs.shape} "
+                                     f"to input indices of length {len(audio_indices)}"
+                                 )
+                             input_embeds[i, audio_indices] = embs.to(input_embeds.dtype)
+         input_embeds = input_embeds.squeeze(0)
+         return input_embeds
+
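The `chunk_input` branch above simply walks the audio embeddings forward and copies them into each inclusive `(start, end)` span of placeholder tokens. A toy version of that scatter with made-up sizes:

import torch

hidden = 4
input_embeds = torch.zeros(1, 10, hidden)                    # a batch of one 10-token sequence
audio_embs = torch.arange(5 * hidden, dtype=torch.float32).view(5, hidden)
audio_bounds = torch.tensor([[2, 4], [7, 8]])                # inclusive spans: 3 + 2 = 5 slots

audio_start_pos = 0
for bound in audio_bounds:
    audio_len = bound[1] - bound[0] + 1
    input_embeds[0, bound[0] : bound[1] + 1] = audio_embs[
        audio_start_pos : audio_start_pos + audio_len, :
    ]
    audio_start_pos += audio_len

print(input_embeds[0, :, 0])   # positions 2-4 and 7-8 now hold audio features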
+     def get_image_features(
+         self,
+         image_inputs: MultimodalInputs,
+     ) -> torch.Tensor:
+         pixel_values = image_inputs.pixel_values
+         tgt_sizes = image_inputs.tgt_sizes
+         device = self.vpm.embeddings.position_embedding.weight.device
+         dtype = self.vpm.embeddings.position_embedding.weight.dtype
+         all_pixel_values_lst = [
+             i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
+         ]
+
+         max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
+         assert isinstance(max_patches, int)
+
+         all_pixel_values = torch.nn.utils.rnn.pad_sequence(
+             all_pixel_values_lst, batch_first=True, padding_value=0.0
+         )
+         B, L, _ = all_pixel_values.shape
+         all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
+         patch_attn_mask = torch.zeros(
+             (B, 1, max_patches), dtype=torch.bool, device=device
+         )
+
+         tgt_sizes_tensor = tgt_sizes.clone().to(device=patch_attn_mask.device)
+         mask_shapes = tgt_sizes_tensor[:, 0] * tgt_sizes_tensor[:, 1]
+         patch_attn_mask[:, 0, :] = torch.arange(
+             patch_attn_mask.size(2), device=patch_attn_mask.device
+         ).unsqueeze(0) < mask_shapes.unsqueeze(1)
+
+         vision_embedding = self.vpm(
+             all_pixel_values.type(dtype),
+             patch_attention_mask=patch_attn_mask,
+             tgt_sizes=tgt_sizes,
+         )
+         return self.resampler(vision_embedding, tgt_sizes)
+
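The patch mask built above marks the first `tgt_h * tgt_w` positions of each padded image as valid and everything beyond that as padding. A small standalone illustration (the grid sizes are invented):

import torch

tgt_sizes = torch.tensor([[2, 3], [1, 2]])                    # (h, w) patch grids per image
max_patches = int((tgt_sizes[:, 0] * tgt_sizes[:, 1]).max())  # 6
B = tgt_sizes.shape[0]

patch_attn_mask = torch.zeros(B, 1, max_patches, dtype=torch.bool)
mask_shapes = tgt_sizes[:, 0] * tgt_sizes[:, 1]               # tensor([6, 2])
patch_attn_mask[:, 0, :] = torch.arange(max_patches).unsqueeze(0) < mask_shapes.unsqueeze(1)

print(patch_attn_mask.int().squeeze(1))
# tensor([[1, 1, 1, 1, 1, 1],
#         [1, 1, 0, 0, 0, 0]])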
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         **kwargs: Any,
+     ) -> torch.Tensor:
+         inputs_embeds = None
+         # TODO(mick): optimize the logic here: clamp, merge and embedding should happen at most once
+         if (
+             not forward_batch.forward_mode.is_decode()
+             and forward_batch.contains_image_inputs()
+         ):
+             mm_inputs = forward_batch.merge_mm_inputs()
+             inputs_embeds = embed_mm_inputs(
+                 mm_input=mm_inputs,
+                 input_ids=input_ids,
+                 input_embedding=self.get_input_embeddings(),
+                 mm_data_embedding_func=self.get_image_features,
+                 placeholder_token_ids=[mm_inputs.im_token_id] + mm_inputs.pad_values,
+             )
+
+         input_ids = input_ids.clamp(
+             min=0, max=self.get_input_embeddings().num_embeddings - 1
+         )
+         if inputs_embeds is None:
+             inputs_embeds = self.llm.get_input_embeddings(input_ids)
+         if (
+             not forward_batch.forward_mode.is_decode()
+             and self.config.init_audio
+             and forward_batch.contains_audio_inputs()
+         ):
+             mm_input = forward_batch.merge_mm_inputs()
+             inputs_embeds = self.get_omni_embedding(
+                 input_ids=input_ids,
+                 multimodal_input=mm_input,
+                 input_embeds=inputs_embeds,
+                 forward_mode=forward_batch.forward_mode,
+                 chunk_length=self.config.audio_chunk_length,
+                 stream_input=False,
+             )
+
+         forward_batch.mm_inputs = None
+
+         hidden_states = self.llm.model(
+             input_ids=None,
+             positions=positions,
+             forward_batch=forward_batch,
+             input_embeds=inputs_embeds,
+         )
+
+         return self.logits_processor(
+             input_ids, hidden_states, self.llm.lm_head, forward_batch
+         )
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+
+         params_dict = dict(self.named_parameters())
+         for name, loaded_weight in weights:
+
+             if "rotary_emb.inv_freq~" in name or "projector" in name:
+                 continue
+             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                 # Models trained using ColossalAI may include these tensors in
+                 # the checkpoint. Skip them.
+                 continue
+
+             # adapt to parametrization
+             if self.config.init_tts and "tts" in name:
+                 name = name.replace(".parametrizations", "")
+                 name = name.replace(".weight.original0", ".weight_g")
+                 name = name.replace(".weight.original1", ".weight_v")
+
+             # adapt to VisionAttention
+             if "vpm" in name:
+                 name = name.replace(r"self_attn.out_proj", r"self_attn.proj")
+
+             if not self.config.init_tts and "tts" in name:
+                 continue
+             if not self.config.init_audio and ("apm" in name or "audio" in name):
+                 continue
+             if not self.config.init_vision and "vpm" in name:
+                 continue
+
+             if (
+                 "sampler" in name
+                 or "apm" in name
+                 or ("tts" in name and "self_attn" in name)
+                 or ("tts.model.layers" in name and ".mlp" in name)
+             ):
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                 weight_loader(param, loaded_weight)
+                 continue
+
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 # replace the name and load with customized loader
+                 if weight_name not in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                 weight_loader(param, loaded_weight)
+
+
+ EntryClass = [MiniCPMO]
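As a closing note on `load_weights`: the `stacked_params_mapping` table routes per-projection checkpoint tensors into the fused `qkv_proj` / `gate_up_proj` parameters under a shard id. A tiny sketch of the name translation (the sample weight name below is made up for illustration):

stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

name = "llm.model.layers.0.self_attn.k_proj.weight"
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in name:
        print(name.replace(weight_name, param_name), "-> shard", shard_id)
        break
# llm.model.layers.0.self_attn.qkv_proj.weight -> shard k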