sglang 0.4.8__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +7 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +289 -48
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +94 -46
- sglang/srt/disaggregation/prefill.py +3 -2
- sglang/srt/disaggregation/utils.py +12 -11
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/openai/protocol.py +47 -4
- sglang/srt/entrypoints/openai/serving_chat.py +52 -76
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/layers/activation.py +7 -0
- sglang/srt/layers/attention/flashattention_backend.py +24 -14
- sglang/srt/layers/layernorm.py +15 -0
- sglang/srt/layers/linear.py +18 -1
- sglang/srt/layers/logits_processor.py +12 -3
- sglang/srt/layers/moe/ep_moe/layer.py +79 -12
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +73 -14
- sglang/srt/layers/moe/topk.py +26 -0
- sglang/srt/layers/quantization/fp8_utils.py +5 -4
- sglang/srt/layers/rotary_embedding.py +103 -11
- sglang/srt/layers/vocab_parallel_embedding.py +14 -1
- sglang/srt/managers/expert_distribution.py +21 -0
- sglang/srt/managers/io_struct.py +10 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
- sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
- sglang/srt/managers/schedule_batch.py +9 -1
- sglang/srt/managers/scheduler.py +42 -6
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/model_runner.py +5 -2
- sglang/srt/model_loader/loader.py +45 -10
- sglang/srt/model_loader/weight_utils.py +89 -0
- sglang/srt/models/deepseek_nextn.py +7 -4
- sglang/srt/models/deepseek_v2.py +147 -4
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1009 -0
- sglang/srt/models/gemma3n_mm.py +511 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/server_args.py +16 -2
- sglang/srt/two_batch_overlap.py +4 -1
- sglang/srt/utils.py +71 -0
- sglang/version.py +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +54 -49
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
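The headline addition in this release is the new HunYuan model implementation (sglang/srt/models/hunyuan.py, shown in full below). As a quick orientation, here is a minimal sketch of loading such a model through sglang's offline Engine entrypoint (also touched in this release); the checkpoint path, tensor-parallel size, and sampling parameters are illustrative assumptions, not values taken from this diff.

    # Hypothetical usage sketch; adjust model_path / tp_size to your checkpoint and hardware.
    import sglang as sgl

    if __name__ == "__main__":
        llm = sgl.Engine(model_path="tencent/Hunyuan-A13B-Instruct", tp_size=4)
        outputs = llm.generate(
            ["Briefly explain cross-layer attention."],
            {"temperature": 0.7, "max_new_tokens": 64},
        )
        print(outputs[0]["text"])
        llm.shutdown()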
sglang/srt/models/hunyuan.py (added)
@@ -0,0 +1,771 @@
# coding=utf-8
# Copyright 2024 The HunYuan team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only HunYuan model compatible with HuggingFace weights."""
import logging
import re
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import PretrainedConfig

from sglang.srt.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE,
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    kv_cache_scales_loader,
    maybe_remap_kv_scale_name,
)
from sglang.srt.utils import add_prefix, is_hip

expert_distribution_recorder = ExpertDistributionRecorder()


def _is_moe(config: PretrainedConfig) -> bool:
    if getattr(config, "num_experts", None) and (
        (isinstance(config.num_experts, int) and config.num_experts > 1)
        or (isinstance(config.num_experts, list) and max(config.num_experts) > 1)
    ):
        return True
    else:
        return False


def _get_cla_factor(config: PretrainedConfig) -> int:
    if not getattr(config, "use_cla", False):
        return 1
    return getattr(config, "cla_share_factor", 1)


class HunYuanMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
            reduce_results=reduce_results,
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class HunYuanSparseMoeBlock(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        layer_id: int = -1,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )

        # Get layer_id topk if config.moe_topk is a list
        if isinstance(config.moe_topk, list):
            assert layer_id >= 0
            assert len(config.moe_topk) > layer_id
            top_k = config.moe_topk[layer_id]
        else:
            top_k = config.moe_topk

        # If it is moe, moe_intermediate_size is preferred
        intermediate_size = config.intermediate_size
        if config.moe_intermediate_size is not None:
            intermediate_size = (
                config.moe_intermediate_size
                if isinstance(config.moe_intermediate_size, int)
                else config.moe_intermediate_size[layer_id]
            )

        self.experts = FusedMoE(
            num_experts=config.num_experts,
            top_k=top_k,
            hidden_size=config.hidden_size,
            intermediate_size=intermediate_size,
            reduce_results=False,
            renormalize=True if top_k > 1 else False,
            quant_config=quant_config,
        )

        self.gate = ReplicatedLinear(
            config.hidden_size, config.num_experts, bias=False, quant_config=None
        )
        if config.use_mixed_mlp_moe > 0:
            # Get layer_id num_shared_expert if config.num_shared_expert is a list
            if isinstance(config.num_shared_expert, list):
                assert layer_id >= 0
                assert len(config.num_shared_expert) > layer_id
                num_shared_expert = config.num_shared_expert[layer_id]
            else:
                num_shared_expert = config.num_shared_expert

            self.shared_mlp = HunYuanMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size * num_shared_expert,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
            )
        else:
            self.shared_mlp = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)
        shared_output = None
        if self.shared_mlp is not None:
            shared_output = self.shared_mlp(hidden_states)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

        return final_hidden_states.view(orig_shape)


class HunYuanAttention(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
        attention_type: str = "self",
        layer_id: int = -1,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        self.head_dim = getattr(
            config, "head_dim", self.hidden_size // self.total_num_heads
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.use_qk_norm = getattr(config, "use_qk_norm", False)
        self.attention_type = attention_type
        self.layer_id = layer_id

        if attention_type == "self":
            self.qkv_proj = QKVParallelLinear(
                hidden_size=hidden_size,
                head_size=self.head_dim,
                total_num_heads=self.total_num_heads,
                total_num_kv_heads=self.total_num_kv_heads,
                bias=bias,
                quant_config=quant_config,
                prefix=f"{prefix}.qkv_proj",
            )
        elif attention_type == "cross":
            self.q_proj = ColumnParallelLinear(
                hidden_size,
                hidden_size,
                bias=bias,
                quant_config=quant_config,
                prefix=f"{prefix}.q_proj",
            )
        else:
            raise RuntimeError("Unsupported attention type")

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        is_neox_style = True
        if quant_config is not None and quant_config.get_name() == "gguf":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            prefix=f"{prefix}.attn",
        )

        if self.use_qk_norm:
            self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        kv_states: Optional[Tuple[torch.Tensor]] = None,
    ) -> torch.Tensor:
        if self.attention_type == "self":
            qkv, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
            q, k = self.rotary_emb(positions, q, k)
            ori_k = k
            if self.use_qk_norm:
                # q = self.query_layernorm(q.view(-1, self.num_heads, self.head_dim).contiguous())
                # k = self.key_layernorm(k.view(-1, self.num_kv_heads, self.head_dim).contiguous())
                q = self.query_layernorm(q.reshape(-1, self.head_dim).contiguous())
                k = self.key_layernorm(k.reshape(-1, self.head_dim).contiguous())
        elif self.attention_type == "cross":
            assert kv_states is not None
            ori_k, v = kv_states  # use last layer kv
            k = ori_k
            q, _ = self.q_proj(hidden_states)
            k_tmp = torch.empty_like(k)  # TODO: redundant rotary embedding
            q, _ = self.rotary_emb(positions, q, k_tmp)
            if self.use_qk_norm:
                q = self.query_layernorm(
                    q.view(-1, self.num_heads, self.head_dim).contiguous()
                )
                k = self.key_layernorm(
                    k.view(-1, self.num_kv_heads, self.head_dim).contiguous()
                )
        else:
            raise RuntimeError("Unsupported attention type")

        attn_output = self.attn(q, k, v, forward_batch)
        output, _ = self.o_proj(attn_output)
        return output, (ori_k, v)


class HunYuanDecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        layer_id: int = -1,
    ) -> None:
        super().__init__()
        assert layer_id >= 0
        self.layer_id = layer_id
        self.hidden_size = config.hidden_size
        self.intermediate_size = (
            config.intermediate_size
            if isinstance(config.intermediate_size, int)
            else config.intermediate_size[layer_id]
        )
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
            config, "original_max_position_embeddings", None
        ):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings
            )
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False
        )
        cla_factor = _get_cla_factor(config)
        attention_type = (
            "cross" if layer_id >= 0 and layer_id % cla_factor != 0 else "self"
        )
        self.self_attn = HunYuanAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(
                config, "num_key_value_heads", config.num_attention_heads
            ),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            prefix=f"{prefix}.self_attn",
            attention_type=attention_type,
            layer_id=layer_id,
        )
        if _is_moe(config):
            self.mlp = HunYuanSparseMoeBlock(
                config=config,
                quant_config=quant_config,
                layer_id=layer_id,
            )
        else:
            self.mlp = HunYuanMLP(
                hidden_size=self.hidden_size,
                intermediate_size=self.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                bias=getattr(config, "mlp_bias", False),
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        kv_states: Optional[Tuple[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states, ori_kv_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            kv_states=kv_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual, ori_kv_states


class HunYuanModel(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )

        self.layers = nn.ModuleList(
            [
                HunYuanDecoderLayer(
                    config=config,
                    layer_id=layer_id,
                    quant_config=quant_config,
                    # prefix=prefix
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if input_embeds is not None:
            hidden_states = input_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None

        cla_factor = _get_cla_factor(self.config)
        prev_kv_states = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual, kv_states = layer(
                positions,
                hidden_states,
                forward_batch,
                residual,
                prev_kv_states,
            )

            if False:  # (i - self.start_layer) % cla_factor == 0:
                prev_kv_states = kv_states
            else:
                prev_kv_states = None

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class HunYuanMoEV1ForCausalLM(nn.Module):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()

        self.config = config

        self.model = HunYuanModel(config, quant_config, prefix="model")
        self.unpadded_vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(config, logit_scale=logit_scale)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def _split_qkv_weight(self, qkv: torch.Tensor):
        num_attention_heads = self.config.num_attention_heads
        num_kv_heads = getattr(
            self.config, "num_key_value_heads", self.config.num_attention_heads
        )
        num_key_value_groups = num_attention_heads // num_kv_heads
        hidden_size = self.config.hidden_size
        attention_head_dim = self.config.hidden_size // num_attention_heads

        qkv = qkv.reshape(
            num_kv_heads, num_key_value_groups + 2, attention_head_dim, hidden_size
        )
        q, k, v = torch.split(qkv, (num_key_value_groups, 1, 1), dim=1)
        q = q.reshape(-1, hidden_size)
        k = k.reshape(-1, hidden_size)
        v = v.reshape(-1, hidden_size)
        return torch.concat((q, k, v))
        # return qkv.reshape((num_kv_heads, num_key_value_groups+2 , attention_head_dim, hidden_size)).permute((1,0,2,3)).reshape((-1, hidden_size)),

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        cla_factor = _get_cla_factor(self.config)
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]

        num_attention_heads = self.config.num_attention_heads
        num_kv_heads = getattr(
            self.config, "num_key_value_heads", self.config.num_attention_heads
        )
        split_params_mapping = [
            (".gate_up_proj", ".gate_and_up_proj", 2, [(1, 1), (0, 1)], None),
            (
                ".qkv_proj",
                ".qkv_proj",
                num_attention_heads + num_kv_heads * 2,
                [("q", num_attention_heads), ("k", num_kv_heads), ("v", num_kv_heads)],
                self._split_qkv_weight,
            ),
        ]

        if _is_moe(self.config):
            # Params for weights, fp8 weight scales, fp8 activation scales
            # (param_name, weight_name, expert_id, shard_id)
            expert_params_mapping = FusedMoE.make_expert_params_mapping(
                ckpt_gate_proj_name="gate_proj",
                ckpt_down_proj_name="down_proj",
                ckpt_up_proj_name="up_proj",
                num_experts=self.config.num_experts,
            )
        else:
            expert_params_mapping = {}

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "gate_proj_bias" in name:
                name = name.replace("gate_proj_bias", "gate_proj.bias")
            if "up_proj_bias" in name:
                name = name.replace("up_proj_bias", "up_proj.bias")
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            # With tie_word_embeddings, we can skip lm_head.weight
            # The weight might appear unnecessarily in the files if the model is
            # processed with quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            is_found = False
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if "mlp.experts" in name:
                    continue
                # cross layer only have q_proj, skip qkv pack
                if weight_name == ".q_proj":
                    match = re.search(r"layers\.\d+", name)
                    if match:
                        layer_id = int(match.group(0).split(".")[-1])
                        if cla_factor > 1 and layer_id % cla_factor != 0:
                            continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)

                is_found = True
                break
            if is_found:
                continue

            for param_name, weight_name, den, split_param, func in split_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                assert loaded_weight.shape[0] % den == 0
                units = loaded_weight.shape[0] // den

                param = params_dict[name]
                weight_loader = param.weight_loader
                offset = 0
                for shard_id, num in split_param:
                    new_offset = offset + num * units
                    if func:
                        weight_loader(
                            param, func(loaded_weight)[offset:new_offset], shard_id
                        )
                    else:
                        weight_loader(param, loaded_weight[offset:new_offset], shard_id)
                    offset = new_offset

                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    if "mlp.gate.wg." in name:
                        name = name.replace("wg.", "")

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
            quantization_param_path,
            tp_rank,
            tp_size,
            self.config.num_hidden_layers,
            self.config.__class__.model_type,
        ):
            if not isinstance(self.model.layers[layer_idx], nn.Identity):
                layer_self_attn = self.model.layers[layer_idx].self_attn

            if is_hip():
                # The scaling factor convention we are assuming is
                # quantized_value * scaling_factor ~= true_value
                # which is consistent with the practice of setting
                # scaling_factor = tensor_amax / FPtype_max
                scaling_factor *= 2
            if hasattr(layer_self_attn, "kv_scale"):
                layer_self_attn.attn._kv_scale = scaling_factor
            else:
                raise RuntimeError(
                    "Self attention has no KV cache scaling factor attribute!"
                )


EntryClass = HunYuanMoEV1ForCausalLM
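One detail worth calling out from the file above is the cross-layer attention (CLA) wiring: _get_cla_factor returns cla_share_factor only when use_cla is set, and HunYuanDecoderLayer builds a "cross" attention block for every layer whose index is not a multiple of that factor, so those layers reuse KV states instead of producing their own. A small standalone sketch of that layer-type rule (the config values below are toy assumptions, not taken from the diff):

    # Sketch of the CLA layer-type assignment used in HunYuanDecoderLayer above.
    def layer_attention_types(num_layers, use_cla, cla_share_factor=1):
        cla_factor = cla_share_factor if use_cla else 1
        return [
            "cross" if layer_id % cla_factor != 0 else "self"
            for layer_id in range(num_layers)
        ]

    print(layer_attention_types(6, use_cla=True, cla_share_factor=2))
    # -> ['self', 'cross', 'self', 'cross', 'self', 'cross']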