sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/srt/configs/model_config.py +35 -0
  3. sglang/srt/conversation.py +9 -5
  4. sglang/srt/disaggregation/base/conn.py +5 -2
  5. sglang/srt/disaggregation/decode.py +6 -1
  6. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  7. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  8. sglang/srt/disaggregation/prefill.py +2 -0
  9. sglang/srt/distributed/parallel_state.py +11 -9
  10. sglang/srt/entrypoints/context.py +244 -0
  11. sglang/srt/entrypoints/engine.py +4 -3
  12. sglang/srt/entrypoints/harmony_utils.py +370 -0
  13. sglang/srt/entrypoints/http_server.py +71 -0
  14. sglang/srt/entrypoints/openai/protocol.py +227 -1
  15. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  16. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  18. sglang/srt/entrypoints/tool.py +87 -0
  19. sglang/srt/eplb/expert_location.py +5 -1
  20. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  21. sglang/srt/hf_transformers_utils.py +30 -3
  22. sglang/srt/jinja_template_utils.py +8 -1
  23. sglang/srt/layers/attention/aiter_backend.py +5 -8
  24. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  25. sglang/srt/layers/attention/triton_backend.py +85 -14
  26. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  28. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  29. sglang/srt/layers/attention/vision.py +13 -5
  30. sglang/srt/layers/communicator.py +21 -4
  31. sglang/srt/layers/dp_attention.py +12 -0
  32. sglang/srt/layers/linear.py +2 -7
  33. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  34. sglang/srt/layers/moe/ep_moe/layer.py +77 -73
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
  37. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  38. sglang/srt/layers/moe/topk.py +12 -3
  39. sglang/srt/layers/moe/utils.py +16 -0
  40. sglang/srt/layers/quantization/__init__.py +22 -0
  41. sglang/srt/layers/quantization/fp4.py +557 -0
  42. sglang/srt/layers/quantization/fp8.py +3 -6
  43. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  44. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  45. sglang/srt/layers/quantization/mxfp4.py +651 -0
  46. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  47. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  48. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  49. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  50. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  51. sglang/srt/layers/quantization/quark/utils.py +107 -0
  52. sglang/srt/layers/quantization/unquant.py +60 -6
  53. sglang/srt/layers/quantization/w4afp8.py +1 -1
  54. sglang/srt/layers/rotary_embedding.py +225 -1
  55. sglang/srt/layers/utils.py +9 -0
  56. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  57. sglang/srt/lora/lora_manager.py +70 -14
  58. sglang/srt/lora/lora_registry.py +3 -2
  59. sglang/srt/lora/mem_pool.py +43 -5
  60. sglang/srt/managers/cache_controller.py +55 -30
  61. sglang/srt/managers/detokenizer_manager.py +1 -1
  62. sglang/srt/managers/io_struct.py +15 -3
  63. sglang/srt/managers/mm_utils.py +5 -11
  64. sglang/srt/managers/schedule_batch.py +28 -7
  65. sglang/srt/managers/scheduler.py +26 -12
  66. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  67. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  68. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  69. sglang/srt/managers/template_manager.py +35 -1
  70. sglang/srt/managers/tokenizer_manager.py +24 -6
  71. sglang/srt/managers/tp_worker.py +3 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  73. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  74. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  75. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  76. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  77. sglang/srt/model_executor/cuda_graph_runner.py +7 -6
  78. sglang/srt/model_executor/forward_batch_info.py +35 -14
  79. sglang/srt/model_executor/model_runner.py +19 -2
  80. sglang/srt/model_loader/weight_utils.py +10 -0
  81. sglang/srt/models/bailing_moe.py +425 -0
  82. sglang/srt/models/deepseek_v2.py +72 -33
  83. sglang/srt/models/ernie4.py +426 -0
  84. sglang/srt/models/ernie4_eagle.py +203 -0
  85. sglang/srt/models/gemma3n_mm.py +39 -0
  86. sglang/srt/models/glm4_moe.py +24 -12
  87. sglang/srt/models/gpt_oss.py +1134 -0
  88. sglang/srt/models/qwen2.py +6 -0
  89. sglang/srt/models/qwen2_moe.py +6 -0
  90. sglang/srt/models/qwen3_moe.py +32 -6
  91. sglang/srt/models/step3_vl.py +9 -0
  92. sglang/srt/models/transformers.py +2 -5
  93. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  94. sglang/srt/reasoning_parser.py +18 -39
  95. sglang/srt/server_args.py +142 -7
  96. sglang/srt/two_batch_overlap.py +157 -5
  97. sglang/srt/utils.py +38 -2
  98. sglang/test/runners.py +2 -2
  99. sglang/test/test_utils.py +1 -1
  100. sglang/version.py +1 -1
  101. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
  102. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
  103. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  104. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  105. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/models/bailing_moe.py (new file)
@@ -0,0 +1,425 @@
+# Copyright 2023-2024 SGLang Team
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/bailing_moe.py
+
+from collections.abc import Iterable
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers.configuration_utils import PretrainedConfig
+
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix, make_layers
+
+
+class BailingAttention(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+
+        self.total_num_heads = config.num_attention_heads
+        self.total_num_kv_heads = config.num_key_value_heads
+
+        assert self.total_num_heads % tp_size == 0
+        assert self.total_num_kv_heads % tp_size == 0
+
+        self.num_heads = self.total_num_heads // tp_size
+        self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads)
+        self.q_size = self.num_heads * self.head_dim
+
+        self.num_kv_heads = self.total_num_kv_heads // tp_size
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scale = self.head_dim**-0.5
+
+        self.query_key_value = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=(config.use_bias or config.use_qkv_bias),
+            quant_config=quant_config,
+            prefix=add_prefix("query_key_value", prefix),
+        )
+
+        self.dense = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            self.hidden_size,
+            bias=config.use_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("dense", prefix),
+        )
+
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scale,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=config.max_position_embeddings,
+            base=config.rope_theta,
+            is_neox_style=True,
+            rope_scaling=config.rope_scaling,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.query_key_value(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        q, k = self.rotary_emb(position_ids, q, k)
+        context_layer = self.attn(q, k, v, forward_batch)
+        attn_output, _ = self.dense(context_layer)
+        return attn_output
+
+
+class BailingMLP(nn.Module):
+    def __init__(
+        self,
+        intermediate_size: int,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        reduce_results: Optional[bool] = True,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            config.hidden_size,
+            [intermediate_size] * 2,
+            bias=config.use_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            config.hidden_size,
+            bias=config.use_bias,
+            quant_config=quant_config,
+            reduce_results=reduce_results,
+            prefix=add_prefix("down_proj", prefix),
+        )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        x, _ = self.gate_up_proj(x)
+        x = self.act_fn(x)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class BailingMoE(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_experts = config.num_experts
+        self.top_k = config.num_experts_per_tok
+        self.hidden_size = config.hidden_size
+        self.num_shared_experts = config.num_shared_experts
+        self.norm_expert_prob = config.norm_topk_prob
+        self.moe_intermediate_size = config.moe_intermediate_size
+
+        self.gate = ReplicatedLinear(
+            self.hidden_size, self.num_experts, bias=False, quant_config=None
+        )
+
+        self.topk = TopK(top_k=self.top_k, renormalize=self.norm_expert_prob)
+
+        self.experts = FusedMoE(
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            layer_id=layer_id,
+            hidden_size=self.hidden_size,
+            intermediate_size=self.moe_intermediate_size,
+            reduce_results=False,
+            quant_config=quant_config,
+            prefix=add_prefix("experts", prefix),
+        )
+
+        if self.num_shared_experts > 0:
+            shared_intermediate_size = (
+                self.moe_intermediate_size * self.num_shared_experts
+            )
+            self.shared_experts = BailingMLP(
+                intermediate_size=shared_intermediate_size,
+                config=config,
+                quant_config=quant_config,
+                reduce_results=False,
+                prefix=add_prefix("shared_experts", prefix),
+            )
+        else:
+            self.shared_experts = None
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        orig_shape = hidden_states.shape
+        hidden_states_flat = hidden_states.view(-1, self.hidden_size)
+
+        shared_output = None
+        if self.shared_experts is not None:
+            shared_output = self.shared_experts(hidden_states_flat)
+
+        router_logits, _ = self.gate(hidden_states_flat)
+        topk_output = self.topk(hidden_states_flat, router_logits)
+        final_hidden_states = self.experts(hidden_states_flat, topk_output)
+
+        if shared_output is not None:
+            final_hidden_states = final_hidden_states + shared_output
+
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states.view(orig_shape)
+
+
+class BailingMoeBlock(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.attention = BailingAttention(
+            config, layer_id, quant_config, prefix=add_prefix("attention", prefix)
+        )
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.mlp = BailingMoE(
+            config=config,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        forward_batch: ForwardBatch,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Pre-normalization and residual connection for the attention block
+        if residual is None:
+            residual = hidden_states
+            normed_hidden_states = self.input_layernorm(hidden_states)
+        else:
+            normed_hidden_states, residual = self.input_layernorm(
+                hidden_states, residual
+            )
+
+        attn_output = self.attention(
+            hidden_states=normed_hidden_states,
+            position_ids=position_ids,
+            forward_batch=forward_batch,
+        )
+
+        # Pre-normalization and residual connection for the MLP block
+        normed_hidden_states, residual = self.post_attention_layernorm(
+            attn_output, residual
+        )
+        mlp_output = self.mlp(normed_hidden_states)
+
+        return mlp_output, residual
+
+
+class BailingMoeModel(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.embed_dim = config.hidden_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
+        )
+        self.embedding_dropout = torch.nn.Dropout(config.embedding_dropout)
+
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: BailingMoeBlock(
+                config=config,
+                layer_id=idx,
+                quant_config=quant_config,
+                prefix=prefix,
+            ),
+            prefix=add_prefix("layers", prefix),
+        )
+
+        self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+
+        residual = None
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                hidden_states,
+                position_ids,
+                residual,
+                forward_batch,
+            )
+
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class BailingMoeForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.model = BailingMoeModel(config=config, quant_config=quant_config)
+        self.lm_head = ParallelLMHead(
+            num_embeddings=config.vocab_size,
+            embedding_dim=config.hidden_size,
+            quant_config=quant_config,
+        )
+        if config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
+
+        self.logits_processor = LogitsProcessor(config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+
+        stacked_params_mapping = [
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+        )
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+
+            if (
+                hasattr(self.config, "norm_head")
+                and self.config.norm_head
+                and "lm_head.weight" in name
+            ):
+                loaded_weight = F.normalize(loaded_weight, dim=0, p=2, eps=1e-7)
+
+            if "model.word_embeddings.weight" == name:
+                name = "model.embed_tokens.weight"
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name in name and "mlp.experts" not in name:
+                    full_param_name = name.replace(weight_name, param_name)
+                    param = params_dict[full_param_name]
+                    param.weight_loader(param, loaded_weight, shard_id)
+                    break
+            else:
+                for p_name, w_name, e_id, s_id in expert_params_mapping:
+                    if w_name in name and "mlp.experts" in name:
+                        full_param_name = name.replace(w_name, p_name)
+                        param = params_dict[full_param_name]
+                        param.weight_loader(
+                            param,
+                            loaded_weight,
+                            full_param_name,
+                            shard_id=s_id,
+                            expert_id=e_id,
+                        )
+                        break
+                else:
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+
+EntryClass = BailingMoeForCausalLM
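
For orientation, BailingMoE.forward above delegates routing to TopK and the expert GEMMs to FusedMoE, adds the shared-expert output, and all-reduces across tensor-parallel ranks. The following is a minimal, dependency-free sketch of the same arithmetic; the per-token Python loop stands in for the fused Triton kernel, and experts / shared_expert are plain callables rather than sglang objects.

import torch
import torch.nn.functional as F

# Illustrative sketch only, not sglang's TopK/FusedMoE implementation.
def reference_moe_forward(
    hidden_states: torch.Tensor,   # [num_tokens, hidden_size]
    router_logits: torch.Tensor,   # [num_tokens, num_experts]
    experts: list,                 # per-expert callables: [hidden_size] -> [hidden_size]
    top_k: int,
    renormalize: bool,
    shared_expert=None,
) -> torch.Tensor:
    # Softmax over experts, then keep the top-k weights per token (TopK's job).
    probs = F.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = probs.topk(top_k, dim=-1)
    if renormalize:
        # norm_topk_prob: make the selected weights sum to 1 for each token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    # Weighted sum of the selected experts' outputs (FusedMoE's job, done here
    # with a slow per-token loop purely for clarity).
    out = torch.zeros_like(hidden_states)
    for t in range(hidden_states.shape[0]):
        for s in range(top_k):
            e = int(topk_ids[t, s])
            out[t] += topk_weights[t, s] * experts[e](hidden_states[t])

    # Shared experts bypass routing and are simply added on top.
    if shared_expert is not None:
        out = out + shared_expert(hidden_states)
    return out

In the real layer, FusedMoE is constructed with reduce_results=False, which is why BailingMoE.forward ends with an explicit tensor_model_parallel_all_reduce when tp_size > 1.
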
sglang/srt/models/deepseek_v2.py
@@ -60,12 +60,9 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import (
-    DeepEPMoE,
-    get_moe_impl_class,
-    should_use_flashinfer_trtllm_moe,
-)
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
@@ -211,13 +208,21 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()

-    def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce=False):
+    def forward(
+        self,
+        x,
+        forward_batch=None,
+        can_fuse_mlp_allreduce: bool = False,
+        use_reduce_scatter: bool = False,
+    ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x

         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x, can_fuse_mlp_allreduce=can_fuse_mlp_allreduce)
+        x, _ = self.down_proj(
+            x, skip_all_reduce=can_fuse_mlp_allreduce or use_reduce_scatter
+        )
         return x

@@ -307,19 +312,15 @@ class DeepseekV2MoE(nn.Module):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )

-        self.topk = (
-            TopK(
-                top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
-                renormalize=config.norm_topk_prob,
-                use_grouped_topk=True,
-                num_expert_group=config.n_group,
-                num_fused_shared_experts=self.num_fused_shared_experts,
-                topk_group=config.topk_group,
-                correction_bias=self.gate.e_score_correction_bias,
-                routed_scaling_factor=self.routed_scaling_factor,
-            )
-            if not should_use_flashinfer_trtllm_moe()
-            else None
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            renormalize=config.norm_topk_prob,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
         )

         self.experts = get_moe_impl_class()(
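
The TopK above is now always constructed for grouped routing: experts are partitioned into n_group groups, only the best topk_group groups survive, and the final top-k is taken among the surviving experts. A simplified standalone sketch of that selection, which ignores correction_bias, num_fused_shared_experts, and routed_scaling_factor and assumes scores are already activated router scores:

import torch

# Illustrative sketch only, not sglang's TopK implementation.
def grouped_topk(
    scores: torch.Tensor,   # [num_tokens, num_experts], activated router scores
    top_k: int,
    num_expert_group: int,  # config.n_group
    topk_group: int,        # config.topk_group
    renormalize: bool = True,
):
    num_tokens, num_experts = scores.shape
    assert num_experts % num_expert_group == 0
    group_size = num_experts // num_expert_group

    # Score each group by its best expert and keep the best topk_group groups.
    group_scores = scores.view(num_tokens, num_expert_group, group_size).max(dim=-1).values
    group_idx = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)

    # Expand the group mask to experts and drop experts in discarded groups.
    expert_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_tokens, num_expert_group, group_size)
        .reshape(num_tokens, num_experts)
    )
    masked_scores = scores.masked_fill(expert_mask == 0, float("-inf"))

    topk_weights, topk_ids = masked_scores.topk(top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
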
@@ -448,6 +449,7 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: Optional[ForwardBatch] = None,
         can_fuse_mlp_allreduce: bool = False,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024
@@ -457,15 +459,20 @@ class DeepseekV2MoE(nn.Module):
                 and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
             ):
                 return self.forward_normal_dual_stream(
-                    hidden_states, can_fuse_mlp_allreduce
+                    hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
                 )
             else:
-                return self.forward_normal(hidden_states, can_fuse_mlp_allreduce)
+                return self.forward_normal(
+                    hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
+                )
         else:
             return self.forward_deepep(hidden_states, forward_batch)

     def forward_normal_dual_stream(
-        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+        self,
+        hidden_states: torch.Tensor,
+        can_fuse_mlp_allreduce: bool = False,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:

         current_stream = torch.cuda.current_stream()
@@ -476,10 +483,14 @@ class DeepseekV2MoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         kwargs = {"hidden_states": hidden_states}
-        if self.topk is not None:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
+        # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
+        # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
+        if should_use_flashinfer_trtllm_moe():
+            kwargs["topk_output"] = (self.topk, router_logits)
         else:
-            kwargs["router_logits"] = router_logits
+            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
         final_hidden_states = self.experts(**kwargs)
         if not _is_cuda:
             final_hidden_states *= self.routed_scaling_factor
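
The two branches above hand the experts module differently shaped topk_output values: the regular FusedMoE path materializes routing immediately, while the FlashInfer TRTLLM path defers routing into the fused kernel and only needs the TopK module plus raw logits. Schematically (the function names below are illustrative, not sglang APIs):

# Illustrative sketch of the two calling conventions; not sglang code.
def materialized_topk(topk_module, hidden_states, router_logits):
    # Regular FusedMoE path: routing runs here, so the experts kernel receives
    # ready-made per-token (weights, expert ids), i.e. a StandardTopKOutput.
    return topk_module(hidden_states, router_logits)

def deferred_topk(topk_module, router_logits):
    # FlashInfer TRTLLM path: the fused kernel routes internally, so it only
    # needs the routing configuration plus the raw logits.
    return (topk_module, router_logits)
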
@@ -489,12 +500,15 @@ class DeepseekV2MoE(nn.Module):
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
-        if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+        if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

     def forward_normal(
-        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+        self,
+        hidden_states: torch.Tensor,
+        can_fuse_mlp_allreduce: bool = False,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
@@ -505,10 +519,14 @@ class DeepseekV2MoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         kwargs = {"hidden_states": hidden_states}
-        if self.topk is not None:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
+        # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
+        # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
+        if should_use_flashinfer_trtllm_moe():
+            kwargs["topk_output"] = (self.topk, router_logits)
         else:
-            kwargs["router_logits"] = router_logits
+            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
         final_hidden_states = self.experts(**kwargs)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
@@ -519,7 +537,7 @@ class DeepseekV2MoE(nn.Module):
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
-        if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+        if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

@@ -1821,6 +1839,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )

     def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
@@ -1883,7 +1902,13 @@ class DeepseekV2DecoderLayer(nn.Module):
             and not self.is_nextn
         )

-        hidden_states = self.mlp(hidden_states, forward_batch, can_fuse_mlp_allreduce)
+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+        hidden_states = self.mlp(
+            hidden_states, forward_batch, can_fuse_mlp_allreduce, use_reduce_scatter
+        )

         if can_fuse_mlp_allreduce:
             hidden_states._sglang_needs_allreduce_fusion = True
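
The reduce-scatter option above relies on the data-parallel batch being padded so every rank holds the same number of token rows and only needs its own slice after the MLP. A sketch of the collective choice, assuming an initialized process group and a first dimension divisible by the world size (not sglang's LayerCommunicator):

import torch
import torch.distributed as dist

# Illustrative sketch only.
def reduce_partial_mlp_output(partial: torch.Tensor, use_reduce_scatter: bool) -> torch.Tensor:
    world_size = dist.get_world_size()
    if world_size == 1:
        return partial
    if use_reduce_scatter:
        # Each rank only needs its own slice of the padded token dimension, so
        # a reduce-scatter moves 1/world_size of the bytes an all-reduce would.
        out = torch.empty(
            (partial.shape[0] // world_size, *partial.shape[1:]),
            dtype=partial.dtype,
            device=partial.device,
        )
        dist.reduce_scatter_tensor(out, partial)
        return out
    dist.all_reduce(partial)
    return partial
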
@@ -2060,6 +2085,8 @@ class DeepseekV2Model(nn.Module):


 class DeepseekV2ForCausalLM(nn.Module):
+    # for quark model load
+    packed_modules_mapping = {}

     def __init__(
         self,
@@ -2068,6 +2095,18 @@ class DeepseekV2ForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+
+        # for quark model load
+        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
+        self.fuse_qkv_a_proj = (
+            hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
+        )
+        if self.fuse_qkv_a_proj:
+            self.packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [
+                "q_a_proj",
+                "kv_a_proj_with_mqa",
+            ]
+
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
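
packed_modules_mapping tells a quantization-aware loader (here, for Quark checkpoints) that the checkpoint's separate q_a_proj and kv_a_proj_with_mqa tensors feed a single fused fused_qkv_a_proj_with_mqa parameter. A hypothetical fragment showing how such a mapping is commonly consumed during weight loading (not sglang's actual loader):

# Illustrative sketch only.
packed_modules_mapping = {
    "fused_qkv_a_proj_with_mqa": ["q_a_proj", "kv_a_proj_with_mqa"],
}

def remap_checkpoint_name(name: str):
    # Return (new_name, shard_id) when name targets part of a fused module,
    # so the shard can be copied into the right slice of the fused parameter.
    for fused_name, shard_names in packed_modules_mapping.items():
        for shard_id, shard_name in enumerate(shard_names):
            if shard_name in name:
                return name.replace(shard_name, fused_name), shard_id
    return name, None

# e.g. "model.layers.0.self_attn.q_a_proj.weight"
#   -> ("model.layers.0.self_attn.fused_qkv_a_proj_with_mqa.weight", 0)
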