sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203)
  1. sglang/bench_one_batch.py +0 -7
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +25 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -2
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +29 -4
  24. sglang/srt/entrypoints/http_server.py +76 -0
  25. sglang/srt/entrypoints/openai/protocol.py +4 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +23 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +10 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +14 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
  37. sglang/srt/layers/attention/triton_backend.py +109 -73
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +58 -10
  46. sglang/srt/layers/dp_attention.py +137 -27
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +16 -18
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  71. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  72. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  73. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  75. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  76. sglang/srt/layers/moe/router.py +15 -9
  77. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  78. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  79. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  80. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  81. sglang/srt/layers/moe/topk.py +167 -83
  82. sglang/srt/layers/moe/utils.py +159 -18
  83. sglang/srt/layers/multimodal.py +156 -40
  84. sglang/srt/layers/quantization/__init__.py +18 -46
  85. sglang/srt/layers/quantization/awq.py +22 -23
  86. sglang/srt/layers/quantization/base_config.py +2 -6
  87. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  88. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
  89. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  90. sglang/srt/layers/quantization/fp8.py +127 -119
  91. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  92. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  93. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  94. sglang/srt/layers/quantization/gptq.py +17 -21
  95. sglang/srt/layers/quantization/marlin_utils.py +26 -8
  96. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  97. sglang/srt/layers/quantization/modelopt_quant.py +217 -98
  98. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  99. sglang/srt/layers/quantization/mxfp4.py +222 -39
  100. sglang/srt/layers/quantization/quark/quark.py +390 -0
  101. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  102. sglang/srt/layers/quantization/unquant.py +34 -70
  103. sglang/srt/layers/quantization/utils.py +77 -2
  104. sglang/srt/layers/quantization/w4afp8.py +7 -8
  105. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  106. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  107. sglang/srt/layers/radix_attention.py +6 -0
  108. sglang/srt/layers/rotary_embedding.py +1 -0
  109. sglang/srt/layers/sampler.py +5 -2
  110. sglang/srt/lora/layers.py +6 -2
  111. sglang/srt/lora/lora_manager.py +21 -22
  112. sglang/srt/lora/lora_registry.py +3 -3
  113. sglang/srt/lora/mem_pool.py +26 -24
  114. sglang/srt/lora/utils.py +10 -12
  115. sglang/srt/managers/cache_controller.py +80 -19
  116. sglang/srt/managers/detokenizer_manager.py +10 -2
  117. sglang/srt/managers/io_struct.py +23 -0
  118. sglang/srt/managers/mm_utils.py +1 -1
  119. sglang/srt/managers/schedule_batch.py +22 -48
  120. sglang/srt/managers/scheduler.py +28 -20
  121. sglang/srt/managers/session_controller.py +1 -1
  122. sglang/srt/managers/template_manager.py +7 -5
  123. sglang/srt/managers/tokenizer_manager.py +88 -39
  124. sglang/srt/managers/tp_worker.py +1 -0
  125. sglang/srt/managers/utils.py +59 -1
  126. sglang/srt/mem_cache/allocator.py +10 -157
  127. sglang/srt/mem_cache/allocator_ascend.py +147 -0
  128. sglang/srt/mem_cache/chunk_cache.py +1 -1
  129. sglang/srt/mem_cache/hicache_storage.py +14 -4
  130. sglang/srt/mem_cache/memory_pool.py +3 -3
  131. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  132. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  133. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  134. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  135. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  136. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  137. sglang/srt/model_executor/cuda_graph_runner.py +33 -33
  138. sglang/srt/model_executor/forward_batch_info.py +11 -10
  139. sglang/srt/model_executor/model_runner.py +93 -78
  140. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  141. sglang/srt/model_loader/loader.py +24 -6
  142. sglang/srt/models/dbrx.py +12 -6
  143. sglang/srt/models/deepseek.py +2 -1
  144. sglang/srt/models/deepseek_nextn.py +5 -2
  145. sglang/srt/models/deepseek_v2.py +226 -223
  146. sglang/srt/models/ernie4.py +2 -2
  147. sglang/srt/models/glm4_moe.py +27 -65
  148. sglang/srt/models/glm4_moe_nextn.py +2 -1
  149. sglang/srt/models/glm4v.py +52 -1
  150. sglang/srt/models/glm4v_moe.py +8 -11
  151. sglang/srt/models/gpt_oss.py +41 -76
  152. sglang/srt/models/granitemoe.py +0 -1
  153. sglang/srt/models/grok.py +376 -48
  154. sglang/srt/models/interns1.py +12 -47
  155. sglang/srt/models/internvl.py +6 -51
  156. sglang/srt/models/llama.py +10 -2
  157. sglang/srt/models/llama4.py +18 -7
  158. sglang/srt/models/minicpm3.py +0 -1
  159. sglang/srt/models/mixtral.py +0 -2
  160. sglang/srt/models/nemotron_nas.py +435 -0
  161. sglang/srt/models/olmoe.py +0 -1
  162. sglang/srt/models/phi4mm.py +3 -21
  163. sglang/srt/models/qwen2.py +2 -2
  164. sglang/srt/models/qwen2_5_vl.py +2 -0
  165. sglang/srt/models/qwen2_moe.py +23 -23
  166. sglang/srt/models/qwen3.py +2 -2
  167. sglang/srt/models/qwen3_classification.py +84 -0
  168. sglang/srt/models/qwen3_moe.py +27 -43
  169. sglang/srt/models/step3_vl.py +8 -3
  170. sglang/srt/models/xverse_moe.py +11 -5
  171. sglang/srt/multimodal/processors/base_processor.py +3 -3
  172. sglang/srt/multimodal/processors/internvl.py +7 -2
  173. sglang/srt/multimodal/processors/llava.py +11 -7
  174. sglang/srt/offloader.py +433 -0
  175. sglang/srt/operations.py +22 -2
  176. sglang/srt/reasoning_parser.py +4 -3
  177. sglang/srt/sampling/sampling_batch_info.py +7 -4
  178. sglang/srt/server_args.py +264 -105
  179. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
  180. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  181. sglang/srt/speculative/eagle_utils.py +36 -13
  182. sglang/srt/speculative/eagle_worker.py +56 -3
  183. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  184. sglang/srt/two_batch_overlap.py +20 -19
  185. sglang/srt/utils.py +68 -70
  186. sglang/test/runners.py +8 -5
  187. sglang/test/test_block_fp8.py +5 -6
  188. sglang/test/test_block_fp8_ep.py +13 -19
  189. sglang/test/test_cutlass_moe.py +4 -6
  190. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  191. sglang/test/test_fp4_moe.py +4 -3
  192. sglang/test/test_marlin_moe.py +1 -1
  193. sglang/test/test_marlin_utils.py +1 -1
  194. sglang/test/test_utils.py +7 -0
  195. sglang/utils.py +0 -1
  196. sglang/version.py +1 -1
  197. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
  198. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
  199. sglang/srt/layers/quantization/fp4.py +0 -557
  200. sglang/srt/layers/quantization/scalar_type.py +0 -352
  201. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  202. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  203. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/elementwise.py
@@ -486,3 +486,97 @@ def gelu_and_mul_triton(
         return out_hidden_states, out_scales
     else:
         return out_hidden_states, None
+
+
+# silu on first half of vector
+@triton.jit
+def silu_and_mul_kernel(
+    out_hidden_states_ptr,  # (bs, hidden_dim)
+    out_scales_ptr,  # (bs,)
+    hidden_states_ptr,  # (bs, hidden_dim * 2)
+    quant_max: tl.constexpr,
+    static_scale: tl.constexpr,
+    hidden_dim: tl.constexpr,  # the output hidden_dim
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+
+    input_start = pid * hidden_dim * 2
+    output_start = pid * hidden_dim
+
+    input1_offs = tl.arange(0, BLOCK_SIZE)
+    mask = tl.arange(0, BLOCK_SIZE) < hidden_dim  # shared for input1, input3, output
+    input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE)
+    output_offs = tl.arange(0, BLOCK_SIZE)
+
+    x1 = tl.load(
+        hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0
+    ).to(tl.float32)
+    x3 = tl.load(
+        hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0
+    ).to(tl.float32)
+
+    # silu
+    # cast down before mul to better match training?
+    silu_x1 = x1 * tl.sigmoid(x1)
+    out = x3 * silu_x1.to(hidden_states_ptr.dtype.element_ty)
+
+    if quant_max is not None:
+        raise NotImplementedError()
+
+    tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask)
+
+
+def silu_and_mul_triton(
+    hidden_states,
+    scales=None,
+    quantize=None,  # dtype to quantize to
+    out=None,
+):
+    bs, in_hidden_dim = hidden_states.shape
+    hidden_dim = in_hidden_dim // 2
+
+    if out is None:
+        out_hidden_states = torch.empty(
+            (bs, hidden_dim),
+            dtype=quantize or hidden_states.dtype,
+            device=hidden_states.device,
+        )
+    else:
+        assert out.shape == (bs, hidden_dim)
+        assert out.dtype == (quantize or hidden_states.dtype)
+        out_hidden_states = out
+    out_scales = None
+    static_scale = False
+    if quantize is not None:
+        if scales is None:
+            out_scales = torch.empty(
+                (bs,), dtype=torch.float32, device=hidden_states.device
+            )
+        else:
+            out_scales = scales
+            static_scale = True
+
+    max_warps = 16 if _is_hip else 32
+    config = {
+        # 8 ele per thread (not tuned)
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4
+        ),
+    }
+
+    silu_and_mul_kernel[(bs,)](
+        out_hidden_states,
+        out_scales,
+        hidden_states,
+        quant_max=torch.finfo(quantize).max if quantize is not None else None,
+        static_scale=static_scale,
+        hidden_dim=hidden_dim,
+        BLOCK_SIZE=triton.next_power_of_2(hidden_dim),
+        **config,
+    )
+
+    if quantize is not None:
+        return out_hidden_states, out_scales
+    else:
+        return out_hidden_states, None
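The hunk above adds a Triton SiLU-and-mul helper alongside the existing gelu_and_mul_triton. A minimal usage sketch of the unquantized path (shapes and dtypes are illustrative, not from the diff): the gate and up projections are packed along the last dimension as (bs, 2 * hidden_dim), and the helper returns the (bs, hidden_dim) activation plus a scales value that is None when no quantize dtype is requested.

import torch

from sglang.srt.layers.elementwise import silu_and_mul_triton

# Gate/up projections packed along the last dimension: (bs, 2 * hidden_dim).
hidden_states = torch.randn(8, 2 * 4096, dtype=torch.bfloat16, device="cuda")

# Unquantized path: SiLU on the first half, multiplied by the second half.
out, scales = silu_and_mul_triton(hidden_states)
assert out.shape == (8, 4096) and scales is None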
sglang/srt/layers/flashinfer_comm_fusion.py
@@ -5,7 +5,11 @@ import torch
 import torch.distributed as dist
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import (
+    direct_register_custom_op,
+    is_flashinfer_available,
+    supports_custom_op,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -196,6 +200,30 @@ def flashinfer_allreduce_residual_rmsnorm(
     return norm_out, residual_out
 
 
+def fake_flashinfer_allreduce_residual_rmsnorm(
+    input_tensor: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    max_token_num: int = 2048,
+    use_oneshot: Optional[bool] = None,
+    trigger_completion_at_end: bool = False,
+    fp32_acc: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    residual_out = torch.empty_like(residual)
+    norm_out = torch.empty_like(input_tensor)
+    return norm_out, residual_out
+
+
+if supports_custom_op():
+    direct_register_custom_op(
+        "flashinfer_allreduce_residual_rmsnorm",
+        flashinfer_allreduce_residual_rmsnorm,
+        mutates_args=["input_tensor", "residual", "weight"],
+        fake_impl=fake_flashinfer_allreduce_residual_rmsnorm,
+    )
+
+
 def cleanup_flashinfer_workspace():
     global _workspace_manager
     if _workspace_manager is not None:
sglang/srt/layers/layernorm.py
@@ -27,6 +27,7 @@ from sglang.srt.utils import (
     is_cuda,
     is_hip,
     is_npu,
+    supports_custom_op,
 )
 
 _is_cuda = is_cuda()
@@ -202,8 +203,14 @@ class RMSNorm(CustomOp):
                 flashinfer_allreduce_residual_rmsnorm,
             )
 
+            fused_op = (
+                torch.ops.sglang.flashinfer_allreduce_residual_rmsnorm
+                if supports_custom_op()
+                else flashinfer_allreduce_residual_rmsnorm
+            )
+
             if get_tensor_model_parallel_world_size() > 1:
-                fused_result = flashinfer_allreduce_residual_rmsnorm(
+                fused_result = fused_op(
                     input_tensor=x,
                     residual=residual,
                     weight=self.weight,
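Together, the flashinfer_comm_fusion.py and layernorm.py hunks above follow one pattern: register the fused kernel as a torch custom op with a shape-only fake implementation when supports_custom_op() is true, then dispatch through torch.ops.sglang.* so the op stays visible to torch.compile. A minimal sketch of the same pattern with a hypothetical op name (my_fused_op is not part of sglang; only the two helpers from sglang.srt.utils are taken from the diff):

import torch

from sglang.srt.utils import direct_register_custom_op, supports_custom_op


def my_fused_op(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical eager implementation.
    return x * 2


def fake_my_fused_op(x: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only implementation used during tracing; no real compute.
    return torch.empty_like(x)


if supports_custom_op():
    direct_register_custom_op(
        "my_fused_op", my_fused_op, mutates_args=[], fake_impl=fake_my_fused_op
    )

# Call through the registered op when available, otherwise the Python function.
op = torch.ops.sglang.my_fused_op if supports_custom_op() else my_fused_op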
sglang/srt/layers/linear.py
@@ -110,6 +110,20 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
     return param[shard_id], loaded_weight
 
 
+def adjust_shard_offsets(shard_offsets, loaded_weight, dim):
+    actual_weight_size = loaded_weight.size(dim)
+    target_weight_size = shard_offsets[-1][-1] + shard_offsets[-1][-2]
+    if actual_weight_size != target_weight_size:
+        new_shard_offsets = []
+        new_offset = 0
+        for shard_id, shard_offset, shard_size in shard_offsets:
+            actual_shard_size = actual_weight_size * shard_size // target_weight_size
+            new_shard_offsets.append((shard_id, new_offset, actual_shard_size))
+            new_offset += actual_shard_size
+        return new_shard_offsets
+    return shard_offsets
+
+
 class LinearBase(torch.nn.Module):
     """Base linear layer.
 
@@ -535,6 +549,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             packed_dim = getattr(param, "packed_dim", None)
 
             use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+            if _is_cpu:
+                shard_offsets = adjust_shard_offsets(
+                    shard_offsets, loaded_weight, output_dim
+                )
+
             for shard_id, shard_offset, shard_size in shard_offsets:
                 # Special case for Quantization.
                 # If quantized, we need to adjust the offset and size to account
@@ -977,6 +996,11 @@ class QKVParallelLinear(ColumnParallelLinear):
             use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
 
             packed_dim = getattr(param, "packed_dim", None)
+            if _is_cpu:
+                shard_offsets = adjust_shard_offsets(
+                    shard_offsets, loaded_weight, output_dim
+                )
+
             for shard_id, shard_offset, shard_size in shard_offsets:
                 # Special case for Quantized Weights.
                 # If quantized, we need to adjust the offset and size to account
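On the CPU path, adjust_shard_offsets rescales each (shard_id, shard_offset, shard_size) entry proportionally when the loaded checkpoint tensor is smaller along the merge dimension than the offsets expect. A small illustrative example with made-up sizes, assuming the new helper is importable from sglang.srt.layers.linear:

import torch

from sglang.srt.layers.linear import adjust_shard_offsets

# Two merged shards expected to total 2048 rows along dim 0.
shard_offsets = [(0, 0, 1024), (1, 1024, 1024)]
# Checkpoint weight that only has 1024 rows, i.e. half the expected size.
loaded_weight = torch.empty(1024, 512)

print(adjust_shard_offsets(shard_offsets, loaded_weight, 0))
# -> [(0, 0, 512), (1, 512, 512)]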
sglang/srt/layers/logits_processor.py
@@ -27,7 +27,7 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
-    DPPaddingMode,
+    DpPaddingMode,
     attn_tp_all_gather,
     attn_tp_all_gather_into_tensor,
     dp_gather_replicate,
@@ -35,7 +35,9 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_rank,
     get_attention_dp_size,
     get_attention_tp_size,
+    get_global_dp_buffer,
     get_local_attention_dp_size,
+    set_dp_buffer_len,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -108,14 +110,12 @@ class LogitsMetadata:
     # The start position of local hidden states.
     dp_local_start_pos: Optional[torch.Tensor] = None
     dp_local_num_tokens: Optional[torch.Tensor] = None
-    gathered_buffer: Optional[torch.Tensor] = None
-    # Buffer to gather logits from all ranks.
-    forward_batch_gathered_buffer: Optional[torch.Tensor] = None
+    global_dp_buffer_len: Optional[int] = None
     # Number of tokens to sample per DP rank
     global_num_tokens_for_logprob_cpu: Optional[torch.Tensor] = None
     global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
     # The gather mode for DP attention
-    dp_padding_mode: Optional[DPPaddingMode] = None
+    dp_padding_mode: Optional[DpPaddingMode] = None
     # for padding
     padded_static_len: int = -1
 
@@ -164,11 +164,10 @@ class LogitsMetadata:
             global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
             dp_local_start_pos=forward_batch.dp_local_start_pos,
             dp_local_num_tokens=forward_batch.dp_local_num_tokens,
-            gathered_buffer=forward_batch.gathered_buffer,
-            forward_batch_gathered_buffer=forward_batch.gathered_buffer,
+            global_dp_buffer_len=forward_batch.global_dp_buffer_len,
            global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
            global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
-            dp_padding_mode=DPPaddingMode.SUM_LEN,
+            dp_padding_mode=DpPaddingMode.SUM_LEN,
         )
 
     def compute_dp_attention_metadata(self):
@@ -188,16 +187,15 @@ class LogitsMetadata:
 
         if self.global_num_tokens_for_logprob_cpu is not None:
             # create a smaller buffer to reduce peak memory usage
-            self.gathered_buffer = torch.empty(
-                (
-                    sum(self.global_num_tokens_for_logprob_cpu),
-                    self.gathered_buffer.shape[1],
-                ),
-                dtype=self.gathered_buffer.dtype,
-                device=self.gathered_buffer.device,
-            )
+            self.global_dp_buffer_len = sum(self.global_num_tokens_for_logprob_cpu)
         else:
-            self.gathered_buffer = torch.empty_like(self.gathered_buffer)
+            self.global_dp_buffer_len = self.global_dp_buffer_len
+
+        set_dp_buffer_len(
+            self.global_dp_buffer_len,
+            self.dp_local_num_tokens,
+            self.global_num_tokens_for_logprob_cpu,
+        )
 
 
 class LogitsProcessor(nn.Module):
@@ -443,7 +441,7 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata()
             hidden_states, local_hidden_states = (
-                logits_metadata.gathered_buffer,
+                get_global_dp_buffer(),
                 hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
sglang/srt/layers/moe/__init__.py
@@ -0,0 +1,31 @@
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.utils import (
+    DeepEPMode,
+    MoeA2ABackend,
+    MoeRunnerBackend,
+    get_deepep_config,
+    get_deepep_mode,
+    get_moe_a2a_backend,
+    get_moe_runner_backend,
+    get_tbo_token_distribution_threshold,
+    initialize_moe_config,
+    is_tbo_enabled,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+    should_use_flashinfer_trtllm_moe,
+)
+
+__all__ = [
+    "DeepEPMode",
+    "MoeA2ABackend",
+    "MoeRunnerConfig",
+    "MoeRunnerBackend",
+    "initialize_moe_config",
+    "get_moe_a2a_backend",
+    "get_moe_runner_backend",
+    "get_deepep_mode",
+    "should_use_flashinfer_trtllm_moe",
+    "should_use_flashinfer_cutlass_moe_fp4_allgather",
+    "is_tbo_enabled",
+    "get_tbo_token_distribution_threshold",
+    "get_deepep_config",
+]
sglang/srt/layers/moe/ep_moe/layer.py
@@ -1,11 +1,17 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Union
 
 import torch
 
 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
+from sglang.srt.layers.moe import (
+    get_deepep_mode,
+    get_moe_a2a_backend,
+    get_moe_runner_backend,
+    should_use_flashinfer_trtllm_moe,
+)
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,
@@ -16,14 +22,9 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
 )
 from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
 from sglang.srt.layers.moe.topk import TopKOutput
-from sglang.srt.layers.moe.utils import DeepEPMode, should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.fp8 import (
-    Fp8Config,
-    Fp8MoEMethod,
-    get_tile_tokens_dim,
-)
+from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
@@ -51,7 +52,6 @@ if not (_is_npu or _is_hip):
 if _use_aiter:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
-    from aiter.ops.shuffle import shuffle_weight
 
 logger = logging.getLogger(__name__)
 
@@ -89,12 +89,11 @@ class EPMoE(FusedMoE):
         num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
-        tp_size: Optional[int] = None,
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
-        activation_alpha: Optional[float] = None,
-        swiglu_limit: Optional[float] = None,
+        gemm1_alpha: Optional[float] = None,
+        gemm1_clamp_limit: Optional[float] = None,
         with_bias: bool = False,
     ):
         super().__init__(
@@ -106,13 +105,12 @@ class EPMoE(FusedMoE):
             top_k=top_k,
             params_dtype=params_dtype,
             quant_config=quant_config,
-            tp_size=tp_size,
             prefix=prefix,
             activation=activation,
             # apply_router_weight_on_input=apply_router_weight_on_input,
             routed_scaling_factor=routed_scaling_factor,
-            activation_alpha=activation_alpha,
-            swiglu_limit=swiglu_limit,
+            gemm1_alpha=gemm1_alpha,
+            gemm1_clamp_limit=gemm1_clamp_limit,
             with_bias=with_bias,
         )
 
@@ -163,7 +161,8 @@ class EPMoE(FusedMoE):
         )
 
         assert self.quant_method is not None
-        assert self.activation == "silu"
+        assert self.moe_runner_config.activation == "silu"
+
         hidden_states_shape = hidden_states.shape
         hidden_states_dtype = hidden_states.dtype
         hidden_states_device = hidden_states.device
@@ -327,8 +326,8 @@ class EPMoE(FusedMoE):
             m_max * self.start_expert_id,
             BLOCK_SIZE=512,
         )
-        if self.routed_scaling_factor is not None:
-            output *= self.routed_scaling_factor
+        if self.moe_runner_config.routed_scaling_factor is not None:
+            output *= self.moe_runner_config.routed_scaling_factor
         return output
 
 
@@ -349,11 +348,9 @@ class DeepEPMoE(EPMoE):
         num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
-        tp_size: Optional[int] = None,
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
-        deepep_mode: DeepEPMode = DeepEPMode.AUTO,
     ):
         super().__init__(
             num_experts=num_experts,
@@ -364,12 +361,11 @@ class DeepEPMoE(EPMoE):
             num_fused_shared_experts=num_fused_shared_experts,
             params_dtype=params_dtype,
             quant_config=quant_config,
-            tp_size=tp_size,
             prefix=prefix,
             activation=activation,
             routed_scaling_factor=routed_scaling_factor,
         )
-        self.deepep_mode = deepep_mode
+        self.deepep_mode = get_deepep_mode()
 
         # TODO: move to the beginning of the file
         from sglang.srt.distributed.parallel_state import get_tp_group
@@ -383,7 +379,7 @@ class DeepEPMoE(EPMoE):
             num_local_experts=self.num_local_experts,
             hidden_size=hidden_size,
             params_dtype=params_dtype,
-            deepep_mode=deepep_mode,
+            deepep_mode=self.deepep_mode,
             async_finish=True,  # TODO
             return_recv_hook=True,
         )
@@ -458,15 +454,19 @@ class DeepEPMoE(EPMoE):
         )
 
     def moe_impl(self, dispatch_output: DispatchOutput):
+        from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
+
         if _use_aiter:
+            assert DispatchOutputChecker.format_is_deepep(dispatch_output)
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(dispatch_output)
         if _is_npu:
+            assert DispatchOutputChecker.format_is_ascent_ll(dispatch_output)
             return self.forward_npu(dispatch_output)
-        if dispatch_output.format.is_deepep_normal():
+        if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_contiguous(dispatch_output)
-        elif dispatch_output.format.is_deepep_ll():
+        elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_masked(dispatch_output)
         else:
@@ -490,7 +490,7 @@ class DeepEPMoE(EPMoE):
 
     def forward_aiter(
         self,
-        dispatch_output: DeepEPNormalOutput,
+        dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
     ):
         hidden_states, topk_idx, topk_weights = (
             dispatch_output.hidden_states,
@@ -516,7 +516,7 @@ class DeepEPMoE(EPMoE):
             quant_type=QuantType.per_128x128,
             activation=(
                 ActivationType.Silu
-                if self.activation == "silu"
+                if self.moe_runner_config.activation == "silu"
                 else ActivationType.Gelu
             ),
             expert_mask=self.expert_mask,
@@ -531,7 +531,7 @@ class DeepEPMoE(EPMoE):
         )
         hidden_states_fp8, hidden_states_scale = hidden_states_fp8
         assert self.quant_method is not None
-        assert self.activation == "silu"
+        assert self.moe_runner_config.activation == "silu"
         if num_recv_tokens_per_expert is None:
             return hidden_states_fp8.bfloat16()
         all_tokens = sum(num_recv_tokens_per_expert)
@@ -652,7 +652,7 @@ class DeepEPMoE(EPMoE):
     ):
         hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
         assert self.quant_method is not None
-        assert self.activation == "silu"
+        assert self.moe_runner_config.activation == "silu"
 
         # GroupGemm-0
         num_groups, m, k = hidden_states_fp8[0].size()
@@ -735,7 +735,7 @@ class DeepEPMoE(EPMoE):
         assert isinstance(dispatch_output, AscendDeepEPLLOutput)
         hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
         assert self.quant_method is not None
-        assert self.activation == "silu"
+        assert self.moe_runner_config.activation == "silu"
 
         # NOTE: Ascend's Dispatch & Combine does not support FP16
         output_dtype = torch.bfloat16
@@ -782,13 +782,17 @@ class DeepEPMoE(EPMoE):
         return hidden_states
 
 
-def get_moe_impl_class():
-    if global_server_args_dict["moe_a2a_backend"].is_deepep():
+def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
+    if get_moe_a2a_backend().is_deepep():
         return DeepEPMoE
 
     # NEW: Direct FP4 detection (bypasses EP requirements)
     # Check for FP4 quantization with TRTLLM flag, regardless of EP
-    if global_server_args_dict.get("enable_flashinfer_trtllm_moe", False):
+    if get_moe_runner_backend().is_flashinfer_trtllm():
+        # FlashInferFP4MoE must be paired with ModelOptNvFp4FusedMoEMethod.
+        # If UnquantizedFusedMoEMethod is detected, fall back to FusedMoE instead.
+        if quant_config is None:
+            return FusedMoE
         try:
             # Check the quantization argument directly
             quantization = global_server_args_dict.get("quantization")
@@ -803,7 +807,7 @@ def get_moe_impl_class():
 
     if should_use_flashinfer_trtllm_moe():
         return FlashInferFusedMoE
-    if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
+    if get_moe_runner_backend().is_flashinfer_cutlass():
         return FusedMoE
     if get_moe_expert_parallel_world_size() > 1:
         return EPMoE
sglang/srt/layers/moe/fused_moe_native.py
@@ -3,28 +3,22 @@ Torch-native implementation for FusedMoE. This is used for torch.compile.
 It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
 """
 
-from typing import Callable, Optional
-
 import torch
 from torch.nn import functional as F
 
 from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
-from sglang.srt.layers.moe.topk import TopKOutput
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import StandardTopKOutput
 
 
 def fused_moe_forward_native(
     layer: torch.nn.Module,
     x: torch.Tensor,
-    topk_output: TopKOutput,
-    *,
-    activation: str = "silu",
-    apply_router_weight_on_input: bool = False,
-    inplace: bool = True,
-    no_combine: bool = False,
-    routed_scaling_factor: Optional[float] = None,
+    topk_output: StandardTopKOutput,
+    moe_runner_config: MoeRunnerConfig,
 ) -> torch.Tensor:
 
-    if apply_router_weight_on_input:
+    if moe_runner_config.apply_router_weight_on_input:
         raise NotImplementedError()
 
     topk_weights, topk_ids, _ = topk_output
@@ -33,12 +27,12 @@ def fused_moe_forward_native(
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = layer.w2_weight[topk_ids]
     x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
-    if activation == "silu":
+    if moe_runner_config.activation == "silu":
         x1 = F.silu(x1)
-    elif activation == "gelu":
+    elif moe_runner_config.activation == "gelu":
         x1 = F.gelu(x1)
     else:
-        raise ValueError(f"Unsupported activation: {activation=}")
+        raise ValueError(f"Unsupported activation: {moe_runner_config.activation=}")
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
@@ -47,16 +41,11 @@ def fused_moe_forward_native(
 def moe_forward_native(
     layer: torch.nn.Module,
     x: torch.Tensor,
-    topk_output: TopKOutput,
-    *,
-    activation: str = "silu",
-    apply_router_weight_on_input: bool = False,
-    inplace: bool = True,
-    no_combine: bool = False,
-    routed_scaling_factor: Optional[float] = None,
+    topk_output: StandardTopKOutput,
+    moe_runner_config: MoeRunnerConfig,
 ) -> torch.Tensor:
 
-    if apply_router_weight_on_input:
+    if moe_runner_config.apply_router_weight_on_input:
         raise NotImplementedError()
 
     topk_weights, topk_ids, _ = topk_output
@@ -72,12 +61,12 @@ def moe_forward_native(
     sorted_tokens = x[idxs // topk_ids.shape[1]]
     tokens_per_expert = tokens_per_expert.cpu().numpy()
 
-    if activation == "silu":
+    if moe_runner_config.activation == "silu":
         act = SiluAndMul()
-    elif activation == "gelu":
+    elif moe_runner_config.activation == "gelu":
         act = GeluAndMul()
     else:
-        raise ValueError(f"Unsupported activation: {activation=}")
+        raise ValueError(f"Unsupported activation: {moe_runner_config.activation=}")
 
     outputs = []
     start_idx = 0
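With these changes the torch-native MoE entry points take a StandardTopKOutput and a MoeRunnerConfig instead of loose keyword arguments. A rough, self-contained sketch of the new call shape for fused_moe_forward_native follows; the MoeRunnerConfig constructor arguments and the stand-in layer and weights are assumptions for illustration, while the fields exercised (activation, apply_router_weight_on_input, routed_scaling_factor) are the ones read in the hunks above.

import torch

from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig

num_experts, hidden, inter, top_k, tokens = 4, 64, 128, 2, 3

# Stand-in module carrying the fused expert weights the native path reads.
layer = torch.nn.Module()
layer.w13_weight = torch.randn(num_experts, 2 * inter, hidden)
layer.w2_weight = torch.randn(num_experts, hidden, inter)

x = torch.randn(tokens, hidden)
topk_weights = torch.softmax(torch.randn(tokens, top_k), dim=-1)
topk_ids = torch.randint(0, num_experts, (tokens, top_k))

# Assumed keyword construction; only these fields are used by this code path.
config = MoeRunnerConfig(
    activation="silu",
    apply_router_weight_on_input=False,
    routed_scaling_factor=None,
)

out = fused_moe_forward_native(layer, x, (topk_weights, topk_ids, None), config)
print(out.shape)  # torch.Size([3, 64])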