sglang 0.4.8__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (54)
  1. sglang/srt/configs/model_config.py +1 -0
  2. sglang/srt/conversation.py +1 -0
  3. sglang/srt/custom_op.py +7 -1
  4. sglang/srt/disaggregation/base/conn.py +2 -0
  5. sglang/srt/disaggregation/decode.py +1 -1
  6. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  7. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  8. sglang/srt/disaggregation/nixl/conn.py +94 -46
  9. sglang/srt/disaggregation/prefill.py +3 -2
  10. sglang/srt/disaggregation/utils.py +12 -11
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/openai/protocol.py +47 -4
  13. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  14. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  15. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  16. sglang/srt/layers/activation.py +7 -0
  17. sglang/srt/layers/attention/flashattention_backend.py +24 -14
  18. sglang/srt/layers/layernorm.py +15 -0
  19. sglang/srt/layers/linear.py +18 -1
  20. sglang/srt/layers/logits_processor.py +12 -3
  21. sglang/srt/layers/moe/ep_moe/layer.py +79 -12
  22. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  23. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  24. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -2
  25. sglang/srt/layers/moe/fused_moe_triton/layer.py +73 -14
  26. sglang/srt/layers/moe/topk.py +26 -0
  27. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  28. sglang/srt/layers/rotary_embedding.py +103 -11
  29. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  30. sglang/srt/managers/expert_distribution.py +21 -0
  31. sglang/srt/managers/io_struct.py +10 -2
  32. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  33. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  34. sglang/srt/managers/schedule_batch.py +9 -1
  35. sglang/srt/managers/scheduler.py +42 -6
  36. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  37. sglang/srt/model_executor/model_runner.py +5 -2
  38. sglang/srt/model_loader/loader.py +45 -10
  39. sglang/srt/model_loader/weight_utils.py +89 -0
  40. sglang/srt/models/deepseek_nextn.py +7 -4
  41. sglang/srt/models/deepseek_v2.py +147 -4
  42. sglang/srt/models/gemma3n_audio.py +949 -0
  43. sglang/srt/models/gemma3n_causal.py +1009 -0
  44. sglang/srt/models/gemma3n_mm.py +511 -0
  45. sglang/srt/models/hunyuan.py +771 -0
  46. sglang/srt/server_args.py +16 -2
  47. sglang/srt/two_batch_overlap.py +4 -1
  48. sglang/srt/utils.py +71 -0
  49. sglang/version.py +1 -1
  50. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +1 -1
  51. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +54 -49
  52. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  53. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  54. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
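
The bulk of this release is the new Gemma 3n and HunYuan model files plus incremental fixes; the version bump itself is a one-line change in sglang/version.py. As a quick sanity check after upgrading, you can read back the exported version string; a minimal sketch, assuming sglang is importable in the current environment:

import sglang

# sglang re-exports the string defined in sglang/version.py;
# after this upgrade it should read "0.4.8.post1".
print(sglang.__version__)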
--- /dev/null
+++ b/sglang/srt/models/hunyuan.py
@@ -0,0 +1,771 @@
+# coding=utf-8
+# Copyright 2024 The HunYuan team.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only HunYuan model compatible with HuggingFace weights."""
+import logging
+import re
+from dataclasses import dataclass
+from enum import Enum, auto
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from sglang.srt.distributed import (
+    get_pp_group,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.sampler import Sampler
+from sglang.srt.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    kv_cache_scales_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.utils import add_prefix, is_hip
+
+expert_distribution_recorder = ExpertDistributionRecorder()
+
+
+def _is_moe(config: PretrainedConfig) -> bool:
+    if getattr(config, "num_experts", None) and (
+        (isinstance(config.num_experts, int) and config.num_experts > 1)
+        or (isinstance(config.num_experts, list) and max(config.num_experts) > 1)
+    ):
+        return True
+    else:
+        return False
+
+
+def _get_cla_factor(config: PretrainedConfig) -> int:
+    if not getattr(config, "use_cla", False):
+        return 1
+    return getattr(config, "cla_share_factor", 1)
+
+
+class HunYuanMLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        prefix: str = "",
+        reduce_results: bool = True,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=hidden_size,
+            output_sizes=[intermediate_size] * 2,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+        self.down_proj = RowParallelLinear(
+            input_size=intermediate_size,
+            output_size=hidden_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+            reduce_results=reduce_results,
+        )
+        if hidden_act != "silu":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Only silu is supported for now."
+            )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class HunYuanSparseMoeBlock(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        layer_id: int = -1,
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+        if self.tp_size > config.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.num_experts}."
+            )
+
+        # Get layer_id topk if config.moe_topk is a list
+        if isinstance(config.moe_topk, list):
+            assert layer_id >= 0
+            assert len(config.moe_topk) > layer_id
+            top_k = config.moe_topk[layer_id]
+        else:
+            top_k = config.moe_topk
+
+        # If it is moe, moe_intermediate_size is preferred
+        intermediate_size = config.intermediate_size
+        if config.moe_intermediate_size is not None:
+            intermediate_size = (
+                config.moe_intermediate_size
+                if isinstance(config.moe_intermediate_size, int)
+                else config.moe_intermediate_size[layer_id]
+            )
+
+        self.experts = FusedMoE(
+            num_experts=config.num_experts,
+            top_k=top_k,
+            hidden_size=config.hidden_size,
+            intermediate_size=intermediate_size,
+            reduce_results=False,
+            renormalize=True if top_k > 1 else False,
+            quant_config=quant_config,
+        )
+
+        self.gate = ReplicatedLinear(
+            config.hidden_size, config.num_experts, bias=False, quant_config=None
+        )
+        if config.use_mixed_mlp_moe > 0:
+            # Get layer_id num_shared_expert if config.num_shared_expert is a list
+            if isinstance(config.num_shared_expert, list):
+                assert layer_id >= 0
+                assert len(config.num_shared_expert) > layer_id
+                num_shared_expert = config.num_shared_expert[layer_id]
+            else:
+                num_shared_expert = config.num_shared_expert
+
+            self.shared_mlp = HunYuanMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=config.intermediate_size * num_shared_expert,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                reduce_results=False,
+            )
+        else:
+            self.shared_mlp = None
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
+        hidden_dim = hidden_states.shape[-1]
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        shared_output = None
+        if self.shared_mlp is not None:
+            shared_output = self.shared_mlp(hidden_states)
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states, router_logits=router_logits
+        )
+        if shared_output is not None:
+            final_hidden_states = final_hidden_states + shared_output
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states.view(orig_shape)
+
+
+class HunYuanAttention(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        prefix: str = "",
+        attention_type: str = "self",
+        layer_id: int = -1,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
+        self.head_dim = getattr(
+            config, "head_dim", self.hidden_size // self.total_num_heads
+        )
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+        self.use_qk_norm = getattr(config, "use_qk_norm", False)
+        self.attention_type = attention_type
+        self.layer_id = layer_id
+
+        if attention_type == "self":
+            self.qkv_proj = QKVParallelLinear(
+                hidden_size=hidden_size,
+                head_size=self.head_dim,
+                total_num_heads=self.total_num_heads,
+                total_num_kv_heads=self.total_num_kv_heads,
+                bias=bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv_proj",
+            )
+        elif attention_type == "cross":
+            self.q_proj = ColumnParallelLinear(
+                hidden_size,
+                hidden_size,
+                bias=bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.q_proj",
+            )
+        else:
+            raise RuntimeError("Unsupported attention type")
+
+        self.o_proj = RowParallelLinear(
+            input_size=self.total_num_heads * self.head_dim,
+            output_size=hidden_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )
+
+        is_neox_style = True
+        if quant_config is not None and quant_config.get_name() == "gguf":
+            is_neox_style = False
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=is_neox_style,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            prefix=f"{prefix}.attn",
+        )
+
+        if self.use_qk_norm:
+            self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+            self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        kv_states: Optional[Tuple[torch.Tensor]] = None,
+    ) -> torch.Tensor:
+        if self.attention_type == "self":
+            qkv, _ = self.qkv_proj(hidden_states)
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+            q, k = self.rotary_emb(positions, q, k)
+            ori_k = k
+            if self.use_qk_norm:
+                # q = self.query_layernorm(q.view(-1, self.num_heads, self.head_dim).contiguous())
+                # k = self.key_layernorm(k.view(-1, self.num_kv_heads, self.head_dim).contiguous())
+                q = self.query_layernorm(q.reshape(-1, self.head_dim).contiguous())
+                k = self.key_layernorm(k.reshape(-1, self.head_dim).contiguous())
+        elif self.attention_type == "cross":
+            assert kv_states is not None
+            ori_k, v = kv_states  # use the last layer's KV
+            k = ori_k
+            q, _ = self.q_proj(hidden_states)
+            k_tmp = torch.empty_like(k)  # TODO: redundant rotary embedding
+            q, _ = self.rotary_emb(positions, q, k_tmp)
+            if self.use_qk_norm:
+                q = self.query_layernorm(
+                    q.view(-1, self.num_heads, self.head_dim).contiguous()
+                )
+                k = self.key_layernorm(
+                    k.view(-1, self.num_kv_heads, self.head_dim).contiguous()
+                )
+        else:
+            raise RuntimeError("Unsupported attention type")
+
+        attn_output = self.attn(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output, (ori_k, v)
+
+
+class HunYuanDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        layer_id: int = -1,
+    ) -> None:
+        super().__init__()
+        assert layer_id >= 0
+        self.layer_id = layer_id
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = (
+            config.intermediate_size
+            if isinstance(config.intermediate_size, int)
+            else config.intermediate_size[layer_id]
+        )
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(
+            config, "original_max_position_embeddings", None
+        ):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        # Support abacusai/Smaug-72B-v0.1 with attention_bias
+        # Support internlm/internlm-7b with bias
+        attention_bias = getattr(config, "attention_bias", False) or getattr(
+            config, "bias", False
+        )
+        cla_factor = _get_cla_factor(config)
+        attention_type = (
+            "cross" if layer_id >= 0 and layer_id % cla_factor != 0 else "self"
+        )
+        self.self_attn = HunYuanAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=getattr(
+                config, "num_key_value_heads", config.num_attention_heads
+            ),
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            bias=attention_bias,
+            prefix=f"{prefix}.self_attn",
+            attention_type=attention_type,
+            layer_id=layer_id,
+        )
+        if _is_moe(config):
+            self.mlp = HunYuanSparseMoeBlock(
+                config=config,
+                quant_config=quant_config,
+                layer_id=layer_id,
+            )
+        else:
+            self.mlp = HunYuanMLP(
+                hidden_size=self.hidden_size,
+                intermediate_size=self.intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                bias=getattr(config, "mlp_bias", False),
+                prefix=f"{prefix}.mlp",
+            )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+        kv_states: Optional[Tuple[torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states, ori_kv_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+            kv_states=kv_states,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual, ori_kv_states
+
+
+class HunYuanModel(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.org_vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            self.vocab_size,
+            config.hidden_size,
+        )
+
+        self.layers = nn.ModuleList(
+            [
+                HunYuanDecoderLayer(
+                    config=config,
+                    layer_id=layer_id,
+                    quant_config=quant_config,
+                    # prefix=prefix
+                )
+                for layer_id in range(config.num_hidden_layers)
+            ]
+        )
+
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if input_embeds is not None:
+            hidden_states = input_embeds
+        else:
+            hidden_states = self.get_input_embeddings(input_ids)
+        residual = None
+
+        cla_factor = _get_cla_factor(self.config)
+        prev_kv_states = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual, kv_states = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+                prev_kv_states,
+            )
+
+            if False:  # (i - self.start_layer) % cla_factor == 0:
+                prev_kv_states = kv_states
+            else:
+                prev_kv_states = None
+
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class HunYuanMoEV1ForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head"]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+
+        self.model = HunYuanModel(config, quant_config, prefix="model")
+        self.unpadded_vocab_size = config.vocab_size
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+        )
+        if config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
+
+        logit_scale = getattr(config, "logit_scale", 1.0)
+        self.logits_processor = LogitsProcessor(config, logit_scale=logit_scale)
+        self.sampler = Sampler()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def _split_qkv_weight(self, qkv: torch.Tensor):
+        num_attention_heads = self.config.num_attention_heads
+        num_kv_heads = getattr(
+            self.config, "num_key_value_heads", self.config.num_attention_heads
+        )
+        num_key_value_groups = num_attention_heads // num_kv_heads
+        hidden_size = self.config.hidden_size
+        attention_head_dim = self.config.hidden_size // num_attention_heads
+
+        qkv = qkv.reshape(
+            num_kv_heads, num_key_value_groups + 2, attention_head_dim, hidden_size
+        )
+        q, k, v = torch.split(qkv, (num_key_value_groups, 1, 1), dim=1)
+        q = q.reshape(-1, hidden_size)
+        k = k.reshape(-1, hidden_size)
+        v = v.reshape(-1, hidden_size)
+        return torch.concat((q, k, v))
+        # return qkv.reshape((num_kv_heads, num_key_value_groups+2 , attention_head_dim, hidden_size)).permute((1,0,2,3)).reshape((-1, hidden_size)),
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        cla_factor = _get_cla_factor(self.config)
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+
+        num_attention_heads = self.config.num_attention_heads
+        num_kv_heads = getattr(
+            self.config, "num_key_value_heads", self.config.num_attention_heads
+        )
+        split_params_mapping = [
+            (".gate_up_proj", ".gate_and_up_proj", 2, [(1, 1), (0, 1)], None),
+            (
+                ".qkv_proj",
+                ".qkv_proj",
+                num_attention_heads + num_kv_heads * 2,
+                [("q", num_attention_heads), ("k", num_kv_heads), ("v", num_kv_heads)],
+                self._split_qkv_weight,
+            ),
+        ]
+
+        if _is_moe(self.config):
+            # Params for weights, fp8 weight scales, fp8 activation scales
+            # (param_name, weight_name, expert_id, shard_id)
+            expert_params_mapping = FusedMoE.make_expert_params_mapping(
+                ckpt_gate_proj_name="gate_proj",
+                ckpt_down_proj_name="down_proj",
+                ckpt_up_proj_name="up_proj",
+                num_experts=self.config.num_experts,
+            )
+        else:
+            expert_params_mapping = {}
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if "gate_proj_bias" in name:
+                name = name.replace("gate_proj_bias", "gate_proj.bias")
+            if "up_proj_bias" in name:
+                name = name.replace("up_proj_bias", "up_proj.bias")
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            # With tie_word_embeddings, we can skip lm_head.weight
+            # The weight might appear unnecessarily in the files if the model is
+            # processed with quantization, LoRA, fine-tuning, etc.
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+
+            is_found = False
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                if "mlp.experts" in name:
+                    continue
+                # cross layer only have q_proj, skip qkv pack
+                if weight_name == ".q_proj":
+                    match = re.search(r"layers\.\d+", name)
+                    if match:
+                        layer_id = int(match.group(0).split(".")[-1])
+                        if cla_factor > 1 and layer_id % cla_factor != 0:
+                            continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+
+                is_found = True
+                break
+            if is_found:
+                continue
+
+            for param_name, weight_name, den, split_param, func in split_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                assert loaded_weight.shape[0] % den == 0
+                units = loaded_weight.shape[0] // den
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                offset = 0
+                for shard_id, num in split_param:
+                    new_offset = offset + num * units
+                    if func:
+                        weight_loader(
+                            param, func(loaded_weight)[offset:new_offset], shard_id
+                        )
+                    else:
+                        weight_loader(param, loaded_weight[offset:new_offset], shard_id)
+                    offset = new_offset
+
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    # Skip layers on other devices.
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+                    break
+                else:
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+
+                    if "mlp.gate.wg." in name:
+                        name = name.replace("wg.", "")
+
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+    # If this function is called, it should always initialize KV cache scale
+    # factors (or else raise an exception). Thus, handled exceptions should
+    # make sure to leave KV cache scale factors in a known good (dummy) state
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            if not isinstance(self.model.layers[layer_idx], nn.Identity):
+                layer_self_attn = self.model.layers[layer_idx].self_attn
+
+            if is_hip():
+                # The scaling factor convention we are assuming is
+                # quantized_value * scaling_factor ~= true_value
+                # which is consistent with the practice of setting
+                # scaling_factor = tensor_amax / FPtype_max
+                scaling_factor *= 2
+            if hasattr(layer_self_attn, "kv_scale"):
+                layer_self_attn.attn._kv_scale = scaling_factor
+            else:
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling " "factor attribute!"
+                )
+
+
+EntryClass = HunYuanMoEV1ForCausalLM
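
The EntryClass registration above is what makes the new architecture visible to sglang's model loader, so a HunYuan MoE checkpoint can be served like any other supported model. A minimal sketch of exercising it through the offline engine API; the model path is a hypothetical placeholder, and the call pattern follows sglang's documented offline-engine usage rather than anything in this diff:

import sglang as sgl

# Hypothetical checkpoint path; any HuggingFace repo whose architecture maps to
# HunYuanMoEV1ForCausalLM should be routed to the model file added above.
llm = sgl.Engine(model_path="tencent/Hunyuan-A13B-Instruct")

prompts = ["Briefly introduce the HunYuan MoE architecture."]
outputs = llm.generate(prompts, {"temperature": 0.7, "max_new_tokens": 64})
print(outputs[0]["text"])

llm.shutdown()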