sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
@@ -5,6 +5,9 @@ import torch
  from torch.nn import Module

  from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+ from sglang.srt.managers.expert_location import get_global_expert_location_metadata
+ from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
+ from sglang.srt.managers.schedule_batch import global_server_args_dict

  try:
  from deep_gemm import (
@@ -40,7 +43,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
  tma_align_input_scale,
  )
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
- from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
+ from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase
  from sglang.srt.layers.moe.topk import select_experts
  from sglang.srt.layers.quantization.base_config import (
  QuantizationConfig,
@@ -49,7 +52,7 @@ from sglang.srt.layers.quantization.base_config import (
  from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
  from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
  from sglang.srt.model_executor.forward_batch_info import ForwardMode
- from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs
+ from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs

  _is_hip = is_hip()

@@ -92,6 +95,7 @@ class GroupedGemmRunner(torch.nn.Module):
  scale_a: torch.Tensor = None,
  scale_b: torch.Tensor = None,
  block_shape: Optional[List[int]] = None,
+ c_dtype=None,
  ):
  if self.use_flashinfer:
  # TODO: flashinfer
@@ -119,6 +123,7 @@ class GroupedGemmRunner(torch.nn.Module):
  scale_a,
  scale_b,
  block_shape=block_shape,
+ c_dtype=c_dtype,
  )
  return c

@@ -136,6 +141,7 @@ class EPMoE(torch.nn.Module):
  top_k: int,
  hidden_size: int,
  intermediate_size: int,
+ layer_id: int,
  params_dtype: Optional[torch.dtype] = None,
  renormalize: bool = True,
  use_grouped_topk: bool = False,
@@ -159,6 +165,7 @@ class EPMoE(torch.nn.Module):
  )
  self.tp_rank = get_tensor_model_parallel_rank()

+ self.layer_id = layer_id
  self.num_experts = num_experts
  assert self.num_experts % self.tp_size == 0
  self.num_experts_per_partition = self.num_experts // self.tp_size
@@ -210,6 +217,10 @@ class EPMoE(torch.nn.Module):
  self.grouped_gemm_runner = None

  def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+ hidden_states_shape = hidden_states.shape
+ hidden_states_dtype = hidden_states.dtype
+ hidden_states_device = hidden_states.device
+
  assert self.quant_method is not None

  if self.grouped_gemm_runner is None:
@@ -229,6 +240,9 @@ class EPMoE(torch.nn.Module):
  correction_bias=self.correction_bias,
  custom_routing_function=self.custom_routing_function,
  routed_scaling_factor=self.routed_scaling_factor,
+ expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+ layer_id=self.layer_id,
+ ),
  )

  reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
@@ -265,25 +279,21 @@ class EPMoE(torch.nn.Module):
  hidden_states.shape[1],
  BLOCK_SIZE=512,
  )
+ dispose_tensor(hidden_states)

  seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
  weight_indices_cur_rank = torch.arange(
  0,
  self.num_experts_per_partition,
- device=hidden_states.device,
+ device=hidden_states_device,
  dtype=torch.int64,
  )
  # GroupGemm-0
- gateup_output = torch.empty(
- gateup_input.shape[0],
- self.w13_weight.shape[1],
- device=hidden_states.device,
- dtype=hidden_states.dtype,
- )
  gateup_output = self.grouped_gemm_runner(
  a=gateup_input,
  b=self.w13_weight,
- c=gateup_output,
+ c=None,
+ c_dtype=hidden_states_dtype,
  batch_size=self.num_experts_per_partition,
  weight_column_major=True,
  seg_indptr=seg_indptr_cur_rank,
@@ -297,6 +307,7 @@ class EPMoE(torch.nn.Module):
  ),
  block_shape=self.block_shape,
  )
+ del gateup_input

  # Act
  down_input = torch.empty(
@@ -306,14 +317,14 @@ class EPMoE(torch.nn.Module):
  dtype=(
  self.fp8_dtype
  if (self.use_fp8_w8a8 and not self.use_block_quant)
- else hidden_states.dtype
+ else hidden_states_dtype
  ),
  )
  if self.w2_input_scale is None and not self.use_block_quant:
  self.w2_input_scale = torch.ones(
  self.num_experts_per_partition,
  dtype=torch.float32,
- device=hidden_states.device,
+ device=hidden_states_device,
  )

  if self.activation == "silu":
@@ -340,13 +351,14 @@ class EPMoE(torch.nn.Module):
  )
  else:
  raise ValueError(f"Unsupported activation: {self.activation=}")
+ del gateup_output

  # GroupGemm-1
  down_output = torch.empty(
  down_input.shape[0],
  self.w2_weight.shape[1],
- device=hidden_states.device,
- dtype=hidden_states.dtype,
+ device=hidden_states_device,
+ dtype=hidden_states_dtype,
  )
  down_output = self.grouped_gemm_runner(
  a=down_input,
@@ -365,10 +377,13 @@ class EPMoE(torch.nn.Module):
  ),
  block_shape=self.block_shape,
  )
+ del down_input

  # PostReorder
- output = torch.empty_like(hidden_states)
- post_reorder_triton_kernel[(hidden_states.size(0),)](
+ output = torch.empty(
+ hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
+ )
+ post_reorder_triton_kernel[(hidden_states_shape[0],)](
  down_output,
  output,
  src2dst,
@@ -377,7 +392,7 @@ class EPMoE(torch.nn.Module):
  self.start_expert_id,
  self.end_expert_id,
  self.top_k,
- hidden_states.size(1),
+ hidden_states_shape[1],
  BLOCK_SIZE=512,
  )
  return output
@@ -417,6 +432,28 @@ class EPMoE(torch.nn.Module):
  weight_name: str,
  shard_id: str,
  expert_id: int,
+ ) -> None:
+ physical_expert_ids = (
+ get_global_expert_location_metadata().logical_to_all_physical(
+ self.layer_id, expert_id
+ )
+ )
+ for physical_expert_id in physical_expert_ids:
+ self._weight_loader_physical(
+ param=param,
+ loaded_weight=loaded_weight,
+ weight_name=weight_name,
+ shard_id=shard_id,
+ expert_id=physical_expert_id,
+ )
+
+ def _weight_loader_physical(
+ self,
+ param: torch.nn.Parameter,
+ loaded_weight: torch.Tensor,
+ weight_name: str,
+ shard_id: str,
+ expert_id: int,
  ) -> None:
  if expert_id < self.start_expert_id or expert_id > self.end_expert_id:
  return
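
The weight-loading change above is easiest to read as a logical-to-physical fan-out: with the new expert-location metadata, one logical expert may be mapped to (or replicated on) several physical expert slots, and weight_loader now loads the same checkpoint tensor once per physical slot via _weight_loader_physical. A self-contained sketch of that idea, using a toy mapping instead of the real get_global_expert_location_metadata() object:

    # Sketch only: a hypothetical logical -> physical map standing in for the
    # real expert-location metadata. Logical expert 7 is replicated on two slots.
    replication = {7: [7, 71]}

    def logical_to_all_physical(layer_id: int, logical_expert_id: int) -> list:
        # Fall back to the identity mapping when an expert is not replicated.
        return replication.get(logical_expert_id, [logical_expert_id])

    for physical_expert_id in logical_to_all_physical(layer_id=3, logical_expert_id=7):
        # The real code calls self._weight_loader_physical(...) here; each call
        # still returns early unless the id lands in this rank's expert range.
        print("load weights into physical expert", physical_expert_id)
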
@@ -460,7 +497,8 @@ class EPMoE(torch.nn.Module):
  # Input scales can be loaded directly and should be equal.
  if "input_scale" in weight_name:
  if (
- param_data[expert_id] != 1
+ (shard_id == "w1" or shard_id == "w3")
+ and param_data[expert_id] != 1
  and (param_data[expert_id] - loaded_weight).abs() > 1e-5
  ):
  raise ValueError(
@@ -534,13 +572,10 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
  set_weight_attrs(w2_weight, extra_weight_attrs)

  # scale
+ layer.register_parameter("w13_input_scale", None)
+ layer.register_parameter("w13_weight_scale", None)
+
  ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
- w13_input_scale = torch.nn.Parameter(
- ones_tensor,
- requires_grad=False,
- )
- layer.register_parameter("w13_input_scale", w13_input_scale)
- set_weight_attrs(w13_input_scale, extra_weight_attrs)

  w2_input_scale = torch.nn.Parameter(
  ones_tensor,
@@ -549,13 +584,6 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
  layer.register_parameter("w2_input_scale", w2_input_scale)
  set_weight_attrs(w2_input_scale, extra_weight_attrs)

- w13_weight_scale = torch.nn.Parameter(
- ones_tensor,
- requires_grad=False,
- )
- layer.register_parameter("w13_weight_scale", w13_weight_scale)
- set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-
  w2_weight_scale = torch.nn.Parameter(
  ones_tensor,
  requires_grad=False,
@@ -611,7 +639,7 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
  self.quant_config.weight_block_size[1],
  )
  # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
- # Required by collum parallel or enabling merged weights
+ # Required by column parallel or enabling merged weights
  if intermediate_size % block_n != 0:
  raise ValueError(
  f"The output_size of gate's and up's weight = "
@@ -802,6 +830,7 @@ class DeepEPMoE(EPMoE):
  top_k: int,
  hidden_size: int,
  intermediate_size: int,
+ layer_id: int,
  params_dtype: Optional[torch.dtype] = None,
  renormalize: bool = True,
  use_grouped_topk: bool = False,
@@ -821,6 +850,7 @@ class DeepEPMoE(EPMoE):
  top_k,
  hidden_size,
  intermediate_size,
+ layer_id,
  params_dtype,
  renormalize,
  use_grouped_topk,
@@ -881,6 +911,9 @@ class DeepEPMoE(EPMoE):
  reorder_topk_ids: torch.Tensor,
  seg_indptr: torch.Tensor,
  ):
+ hidden_states_dtype = hidden_states.dtype
+ hidden_states_device = hidden_states.device
+
  assert self.quant_method is not None
  assert self.activation == "silu"
  if self.grouped_gemm_runner is None:
@@ -903,18 +936,12 @@ class DeepEPMoE(EPMoE):
  )

  # GroupGemm-0
- gateup_output = torch.empty(
- hidden_states.shape[0],
- self.w13_weight.shape[1],
- device=hidden_states.device,
- dtype=hidden_states.dtype,
- )
-
  if hidden_states.shape[0] > 0:
  gateup_output = self.grouped_gemm_runner(
  a=hidden_states,
  b=self.w13_weight,
- c=gateup_output,
+ c=None,
+ c_dtype=hidden_states.dtype,
  batch_size=self.num_experts_per_partition,
  weight_column_major=True,
  seg_indptr=seg_indptr,
@@ -928,6 +955,13 @@ class DeepEPMoE(EPMoE):
  ),
  block_shape=self.block_shape,
  )
+ else:
+ gateup_output = torch.empty(
+ hidden_states.shape[0],
+ self.w13_weight.shape[1],
+ device=hidden_states.device,
+ dtype=hidden_states.dtype,
+ )

  # Act
  down_input = torch.empty(
@@ -937,14 +971,14 @@ class DeepEPMoE(EPMoE):
  dtype=(
  self.fp8_dtype
  if (self.use_fp8_w8a8 and not self.use_block_quant)
- else hidden_states.dtype
+ else hidden_states_dtype
  ),
  )
  if self.w2_input_scale is None and not self.use_block_quant:
  self.w2_input_scale = torch.ones(
  self.num_experts_per_partition,
  dtype=torch.float32,
- device=hidden_states.device,
+ device=hidden_states_device,
  )

  if self.activation == "silu":
@@ -961,12 +995,14 @@ class DeepEPMoE(EPMoE):
  else:
  raise ValueError(f"Unsupported activation: {self.activation=}")

+ del gateup_output
+
  # GroupGemm-1
  down_output = torch.empty(
  down_input.shape[0],
  self.w2_weight.shape[1],
- device=hidden_states.device,
- dtype=hidden_states.dtype,
+ device=hidden_states_device,
+ dtype=hidden_states_dtype,
  )
  if down_input.shape[0] > 0:
  down_output = self.grouped_gemm_runner(
@@ -1007,11 +1043,9 @@ class DeepEPMoE(EPMoE):
  N = self.w13_weight.size(1)
  scale_block_size = 128

- gather_out = torch.empty_like(
- hidden_states_fp8,
- device=hidden_states_fp8.device,
- dtype=torch.bfloat16,
- )
+ hidden_states_fp8_shape = hidden_states_fp8.shape
+ hidden_states_fp8_device = hidden_states_fp8.device
+ hidden_states_fp8_dtype = hidden_states_fp8.dtype

  input_tensor = [
  torch.empty(
@@ -1049,16 +1083,18 @@ class DeepEPMoE(EPMoE):
  m_indices,
  output_index,
  )
+ dispose_tensor(hidden_states_fp8)

  gateup_output = torch.empty(
  (all_tokens, N),
- device=hidden_states_fp8.device,
+ device=hidden_states_fp8_device,
  dtype=torch.bfloat16,
  )
  input_tensor[1] = tma_align_input_scale(input_tensor[1])
  m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
  input_tensor, self.w13_weight_fp8, gateup_output, m_indices
  )
+ del input_tensor
  down_input = torch.empty(
  (
  all_tokens,
@@ -1068,14 +1104,16 @@ class DeepEPMoE(EPMoE):
  dtype=torch.bfloat16,
  )
  silu_and_mul(gateup_output.view(-1, N), down_input)
+ del gateup_output
  down_output = torch.empty(
  (all_tokens, K),
- device=hidden_states_fp8.device,
+ device=hidden_states_fp8_device,
  dtype=torch.bfloat16,
  )
  down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
  down_input, scale_block_size
  )
+ del down_input
  down_input_scale = tma_align_input_scale(down_input_scale)
  m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
  (down_input_fp8, down_input_scale),
@@ -1083,7 +1121,13 @@ class DeepEPMoE(EPMoE):
  down_output,
  m_indices,
  )
+ del down_input_fp8, down_input_scale

+ gather_out = torch.empty(
+ hidden_states_fp8_shape,
+ device=hidden_states_fp8_device,
+ dtype=torch.bfloat16,
+ )
  ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)

  return gather_out
@@ -1107,6 +1151,7 @@ class DeepEPMoE(EPMoE):
  m_grouped_gemm_fp8_fp8_bf16_nt_masked(
  hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
  )
+ dispose_tensor(hidden_states_fp8[0])

  # Act
  down_input = torch.empty(
@@ -1135,6 +1180,7 @@ class DeepEPMoE(EPMoE):
  scale_block_size,
  masked_m,
  )
+ del gateup_output

  # GroupGemm-1
  n = self.w2_weight.size(1)
@@ -1150,3 +1196,11 @@ class DeepEPMoE(EPMoE):
  )

  return down_output
+
+
+ def get_moe_impl_class():
+ if global_server_args_dict["enable_deepep_moe"]:
+ return DeepEPMoE
+ if global_server_args_dict["enable_ep_moe"]:
+ return EPMoE
+ return FusedMoE
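
The new module-level get_moe_impl_class() lets model code choose among DeepEPMoE, EPMoE, and FusedMoE from the server arguments instead of hard-coding one class. A sketch of a call site, assuming the caller already has the usual MoE hyperparameters; everything except get_moe_impl_class and the layer_id argument is illustrative:

    # Sketch of a caller; parameter values would come from the model's config.
    from typing import Optional

    import torch

    from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class

    def build_experts(
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        layer_id: int,
        params_dtype: Optional[torch.dtype] = None,
    ):
        MoEImpl = get_moe_impl_class()  # DeepEPMoE, EPMoE, or FusedMoE
        return MoEImpl(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            layer_id=layer_id,  # now accepted by all three implementations
            params_dtype=params_dtype,
        )
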
@@ -1,8 +1,15 @@
+ import logging
+ from dataclasses import dataclass
+
  from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
- from sglang.srt.utils import DeepEPMode
+ from sglang.srt.managers.expert_distribution import (
+ get_global_expert_distribution_recorder,
+ )
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
+ from sglang.srt.utils import DeepEPMode, load_json_config

  try:
- from deep_ep import Buffer
+ from deep_ep import Buffer, Config

  from sglang.srt.layers.quantization.fp8_kernel import (
  sglang_per_token_group_quant_fp8,
@@ -12,7 +19,7 @@ try:
  except ImportError:
  use_deepep = False

- from enum import IntEnum, auto
+ from enum import Enum, IntEnum, auto
  from typing import Optional, Tuple, Union

  import torch
@@ -25,6 +32,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardMode

+ logger = logging.getLogger(__name__)
+

  class DeepEPDispatchMode(IntEnum):
  NORMAL = auto()
@@ -32,7 +41,6 @@ class DeepEPDispatchMode(IntEnum):


  class DeepEPBuffer:
-
  _buffer = None
  _dispatch_mode: Optional[DeepEPDispatchMode] = None
  _hidden_size: Optional[int] = None
@@ -60,8 +68,10 @@ class DeepEPBuffer:
  if deepep_mode.enable_normal():
  hidden_bytes = hidden_size * param_bytes
  for config in (
- Buffer.get_dispatch_config(group.size()),
- Buffer.get_combine_config(group.size()),
+ DeepEPConfig.get_instance().normal_dispatch_config
+ or Buffer.get_dispatch_config(group.size()),
+ DeepEPConfig.get_instance().normal_combine_config
+ or Buffer.get_combine_config(group.size()),
  ):
  num_nvl_bytes = max(
  config.get_nvl_buffer_size_hint(hidden_bytes, group.size()),
@@ -88,7 +98,12 @@ class DeepEPBuffer:
  num_nvl_bytes,
  num_rdma_bytes,
  low_latency_mode=deepep_mode.enable_low_latency(),
- num_qps_per_rank=(max(num_experts // group.size(), Buffer.num_sms // 2)),
+ num_qps_per_rank=(
+ max(
+ num_experts // group.size(),
+ DeepEPConfig.get_instance().num_sms // 2,
+ )
+ ),
  )
  return cls._buffer

@@ -113,6 +128,35 @@ class DeepEPBuffer:
  cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY


+ class DeepEPConfig:
+ _instance = None
+
+ def __init__(self):
+ config_str = global_server_args_dict["deepep_config"]
+ if config_str:
+ config_parsed = load_json_config(config_str)
+ if torch.distributed.get_rank() == 0:
+ logger.info(f"Use DeepEP Config: {config_parsed}")
+ config_dispatch = config_parsed["normal_dispatch"]
+ config_combine = config_parsed["normal_combine"]
+
+ self.normal_dispatch_config = Config(**config_dispatch)
+ self.normal_combine_config = Config(**config_combine)
+
+ assert config_dispatch["num_sms"] == config_combine["num_sms"]
+ self.num_sms = config_dispatch["num_sms"]
+ else:
+ self.normal_dispatch_config = None
+ self.normal_combine_config = None
+ self.num_sms = Buffer.num_sms
+
+ @classmethod
+ def get_instance(cls):
+ if cls._instance is None:
+ cls._instance = DeepEPConfig()
+ return cls._instance
+
+
  class _DeepEPDispatcherImplBase:
  def __init__(
  self,
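
For reference, a sketch of the JSON payload the new DeepEPConfig reads through load_json_config; only the keys visible above (normal_dispatch, normal_combine, num_sms) are shown, the value 24 is illustrative, and any additional per-section fields would have to be valid deep_ep.Config keyword arguments:

    # Sketch only; each section is passed to deep_ep.Config(**section).
    deepep_config_payload = """
    {
        "normal_dispatch": {"num_sms": 24},
        "normal_combine": {"num_sms": 24}
    }
    """
    # DeepEPConfig asserts the two num_sms values match and then uses that value
    # in place of Buffer.num_sms when sizing num_qps_per_rank.
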
@@ -295,6 +339,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
  async_finish=self.async_finish,
  allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
  expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
+ config=DeepEPConfig.get_instance().normal_dispatch_config,
+ )
+
+ get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
+ num_recv_tokens_per_expert_list,
+ num_tokens_per_rank=num_tokens_per_rank,
+ num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+ num_tokens_per_expert=num_tokens_per_expert,
  )

  return (
@@ -394,6 +446,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
  async_finish=self.async_finish,
  previous_event=previous_event,
  allocate_on_comm_stream=previous_event is not None,
+ config=DeepEPConfig.get_instance().normal_combine_config,
  )
  return combined_x, event

@@ -459,6 +512,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
  ):
  hook() if self.return_recv_hook else event.current_stream_wait()

+ get_global_expert_distribution_recorder().on_deepep_dispatch_low_latency(
+ masked_m
+ )
+
  reorder_topk_ids = seg_indptr = None

  return (
@@ -571,6 +628,14 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
  )


+ @dataclass
+ class _Stage(Enum):
+ INITIAL = auto()
+ AFTER_DISPATCH_A = auto()
+ AFTER_DISPATCH_B = auto()
+ AFTER_COMBINE_A = auto()
+
+
  class DeepEPDispatcher:
  def __init__(
  self,
@@ -609,6 +674,8 @@ class DeepEPDispatcher:
  **common_kwargs,
  )

+ self._stage = _Stage.INITIAL
+
  def dispatch(self, *args, **kwargs) -> Tuple:
  self.dispatch_a(*args, **kwargs)
  ret = self.dispatch_b()
@@ -621,6 +688,7 @@ class DeepEPDispatcher:
  topk_weights: torch.Tensor,
  forward_mode: ForwardMode = None,
  ):
+ self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
  inner_state = self._get_impl(forward_mode).dispatch_a(
  hidden_states=hidden_states,
  topk_idx=topk_idx,
@@ -629,6 +697,7 @@ class DeepEPDispatcher:
  self._dispatch_intermediate_state = forward_mode, inner_state

  def dispatch_b(self):
+ self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
  forward_mode, inner_state = self._dispatch_intermediate_state
  del self._dispatch_intermediate_state
  return self._get_impl(forward_mode).dispatch_b(*inner_state)
@@ -645,6 +714,7 @@ class DeepEPDispatcher:
  topk_weights: torch.Tensor,
  forward_mode: ForwardMode,
  ):
+ self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
  inner_state = self._get_impl(forward_mode).combine_a(
  hidden_states=hidden_states,
  topk_idx=topk_idx,
@@ -653,6 +723,7 @@ class DeepEPDispatcher:
  self._combine_intermediate_state = forward_mode, inner_state

  def combine_b(self):
+ self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
  forward_mode, inner_state = self._combine_intermediate_state
  del self._combine_intermediate_state
  return self._get_impl(forward_mode).combine_b(*inner_state)
@@ -665,3 +736,7 @@ class DeepEPDispatcher:
  return self._low_latency_dispatcher
  else:
  raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
+
+ def _update_stage(self, old_stage, new_stage):
+ assert self._stage == old_stage
+ self._stage = new_stage
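
The _Stage enum and _update_stage checks above turn DeepEPDispatcher into a small state machine: dispatch_a, then dispatch_b, then combine_a, then combine_b, and back to the initial state. A self-contained sketch that mirrors only the stage bookkeeping (not the real dispatcher):

    from enum import Enum, auto

    class Stage(Enum):
        INITIAL = auto()
        AFTER_DISPATCH_A = auto()
        AFTER_DISPATCH_B = auto()
        AFTER_COMBINE_A = auto()

    class StageDemo:
        # Toy stand-in for DeepEPDispatcher's stage checks.
        def __init__(self):
            self._stage = Stage.INITIAL

        def _update_stage(self, old, new):
            assert self._stage == old, f"expected {old}, currently {self._stage}"
            self._stage = new

        def dispatch_a(self):
            self._update_stage(Stage.INITIAL, Stage.AFTER_DISPATCH_A)

        def dispatch_b(self):
            self._update_stage(Stage.AFTER_DISPATCH_A, Stage.AFTER_DISPATCH_B)

        def combine_a(self):
            self._update_stage(Stage.AFTER_DISPATCH_B, Stage.AFTER_COMBINE_A)

        def combine_b(self):
            self._update_stage(Stage.AFTER_COMBINE_A, Stage.INITIAL)

    d = StageDemo()
    d.dispatch_a(); d.dispatch_b(); d.combine_a(); d.combine_b()  # valid order
    # Calling combine_a() first would fail the assertion, as in the real class.
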
@@ -994,7 +994,7 @@ def get_default_config(
  "num_stages": 2 if _is_hip else 4,
  }
  else:
- # Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
+ # Block-wise quant: BLOCK_SIZE_K must be divisible by block_shape[1]
  config = {
  "BLOCK_SIZE_M": 64,
  "BLOCK_SIZE_N": block_shape[0],
@@ -186,6 +186,19 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):

  if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
  assert not no_combine, "unsupported"
+ if apply_router_weight_on_input:
+ assert (
+ topk_weights.dim() == 2
+ ), "`topk_weights` should be in shape (num_tokens, topk)"
+ _, topk = topk_weights.shape
+ assert (
+ topk == 1
+ ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+ x = x * topk_weights.to(x.dtype)
+ topk_weights = torch.ones_like(
+ topk_weights, dtype=torch.float32
+ ) # topk_weights must be FP32 (float32)
+
  return ck_moe_2stages(
  x,
  layer.w13_weight,
@@ -270,6 +283,7 @@ class FusedMoE(torch.nn.Module):
  top_k: int,
  hidden_size: int,
  intermediate_size: int,
+ layer_id: Optional[int] = None,
  params_dtype: Optional[torch.dtype] = None,
  reduce_results: bool = False,
  renormalize: bool = True,