sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0

sglang/srt/layers/moe/token_dispatcher/deepep.py

@@ -5,13 +5,15 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
 
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
-from sglang.srt.layers.moe import DeepEPMode, get_deepep_config, is_tbo_enabled
-from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
+from sglang.srt.layers.moe.token_dispatcher.base import (
     BaseDispatcher,
     BaseDispatcherConfig,
+    CombineInput,
+    CombineInputFormat,
     DispatchOutput,
     DispatchOutputFormat,
 )
+from sglang.srt.layers.moe.utils import DeepEPMode, get_deepep_config, is_tbo_enabled
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.utils import (
     get_bool_env_var,
@@ -40,11 +42,6 @@ from enum import Enum, IntEnum, auto
 import torch
 import torch.distributed as dist
 
-from sglang.srt.layers.moe.ep_moe.kernels import (
-    deepep_permute_triton_kernel,
-    deepep_post_reorder_triton_kernel,
-    deepep_run_moe_deep_preprocess,
-)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
@@ -56,6 +53,7 @@ class DeepEPNormalOutput(NamedTuple):
     """DeepEP normal dispatch output."""
 
     hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
+    # hidden_states_scale
     topk_idx: torch.Tensor
     topk_weights: torch.Tensor
     num_recv_tokens_per_expert: List[int]
@@ -79,24 +77,32 @@ class DeepEPLLOutput(NamedTuple):
         return DispatchOutputFormat.DEEPEP_LL
 
 
-class AscendDeepEPLLOutput(NamedTuple):
-    """AscendDeepEP low latency dispatch output."""
+assert isinstance(DeepEPNormalOutput, DispatchOutput)
+assert isinstance(DeepEPLLOutput, DispatchOutput)
 
-    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
-    topk_idx: torch.Tensor
-    topk_weights: torch.Tensor
-    masked_m: torch.Tensor
-    seg_indptr: torch.Tensor
-    expected_m: int
+
+class DeepEPNormalCombineInput(NamedTuple):
+    """DeepEP normal combine input."""
+
+    pass
 
     @property
-    def format(self) -> DispatchOutputFormat:
-        return DispatchOutputFormat.ASCENT_LL
+    def format(self) -> CombineInputFormat:
+        return CombineInputFormat.DEEPEP_NORMAL
 
 
-assert isinstance(DeepEPNormalOutput, DispatchOutput)
-assert isinstance(DeepEPLLOutput, DispatchOutput)
-assert isinstance(AscendDeepEPLLOutput, DispatchOutput)
+class DeepEPLLCombineInput(NamedTuple):
+    """DeepEP low latency combine input."""
+
+    pass
+
+    @property
+    def format(self) -> CombineInputFormat:
+        return CombineInputFormat.DEEPEP_LL
+
+
+assert isinstance(DeepEPNormalCombineInput, CombineInput)
+assert isinstance(DeepEPLLCombineInput, CombineInput)
 
 
 class DeepEPDispatchMode(IntEnum):
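
Note that the `assert isinstance(...)` lines above pass the classes themselves, not instances. This works when `DispatchOutput` and `CombineInput` are `runtime_checkable` Protocols, because `isinstance` then only tests that the `format` attribute exists, and a property descriptor already lives on the class object. A minimal self-contained sketch of the pattern (the real definitions live in `token_dispatcher/base.py` and may differ):

    from enum import Enum, auto
    from typing import NamedTuple, Protocol, runtime_checkable

    class CombineInputFormat(Enum):
        STANDARD = auto()
        DEEPEP_NORMAL = auto()
        DEEPEP_LL = auto()

    @runtime_checkable
    class CombineInput(Protocol):
        @property
        def format(self) -> CombineInputFormat: ...

    class MyCombineInput(NamedTuple):
        @property
        def format(self) -> CombineInputFormat:
            return CombineInputFormat.STANDARD

    # Checked once at import time: a class missing `format` fails loudly here,
    # without ever instantiating the NamedTuple.
    assert isinstance(MyCombineInput, CombineInput)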
@@ -272,6 +278,9 @@ class _DeepEPDispatcherImplBase:
         self.num_max_dispatch_tokens_per_rank = get_int_env_var(
             "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128
         )
+        # DeepEP internode_ll dispatch uses FINISHED_SUM_TAG=1024 and requires
+        # the number of tokens sent from one rank to another to stay below it
+        assert self.num_max_dispatch_tokens_per_rank <= 1024
 
         self.handle = None
 
@@ -409,7 +418,11 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
-        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter:
+        from sglang.srt.layers.moe.ep_moe.kernels import (
+            deepep_post_reorder_triton_kernel,
+        )
+
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
             output = hidden_states
         else:
             if hidden_states.shape[0] > 0:
@@ -495,7 +508,8 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         hidden_states, masked_m, event, hook = self._dispatch_core(
             hidden_states,
             topk_idx,
-            use_fp8=True,
+            # TODO(shuw): pending https://github.com/deepseek-ai/DeepEP/pull/341
+            use_fp8=not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"),
         )
         return (
             hidden_states,
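
Both DeepEP knobs touched above are read from the environment, so they can be set before the engine imports this module. A sketch (the variable names come from the diff; the values are only illustrative):

    import os

    # Cap on tokens one rank may send to another per dispatch; the new assert
    # rejects values above 1024 (DeepEP's FINISHED_SUM_TAG).
    os.environ["SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK"] = "256"

    # New in this release: opt out of FP8 low-latency dispatch and send BF16,
    # i.e. use_fp8=not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH").
    os.environ["SGLANG_DEEPEP_BF16_DISPATCH"] = "1"
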
@@ -523,23 +537,13 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             masked_m
         )
 
-        if _is_npu:
-            deepep_output = AscendDeepEPLLOutput(
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                masked_m,
-                self.handle[1],
-                expected_m,
-            )
-        else:
-            deepep_output = DeepEPLLOutput(
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                masked_m,
-                expected_m,
-            )
+        deepep_output = DeepEPLLOutput(
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            masked_m,
+            expected_m,
+        )
         return deepep_output
 
     def _dispatch_core(

sglang/srt/layers/moe/token_dispatcher/standard.py

@@ -1,19 +1,61 @@
 from __future__ import annotations
 
-from typing import NamedTuple
+from typing import TYPE_CHECKING, NamedTuple
 
-from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
+import torch
+
+from sglang.srt.layers.moe.token_dispatcher.base import (
+    BaseDispatcher,
+    CombineInput,
+    CombineInputFormat,
     DispatchOutput,
     DispatchOutputFormat,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
 
 class StandardDispatchOutput(NamedTuple):
     """Standard dispatch output."""
 
+    hidden_states: torch.Tensor
+    topk_output: TopKOutput
+
     @property
     def format(self) -> DispatchOutputFormat:
         return DispatchOutputFormat.STANDARD
 
 
 assert isinstance(StandardDispatchOutput, DispatchOutput)
+
+
+class StandardCombineInput(NamedTuple):
+    """Standard combine input."""
+
+    hidden_states: torch.Tensor
+
+    @property
+    def format(self) -> CombineInputFormat:
+        return CombineInputFormat.STANDARD
+
+
+assert isinstance(StandardCombineInput, CombineInput)
+
+
+class StandardDispatcher(BaseDispatcher):
+
+    def dispatch(
+        self, hidden_states: torch.Tensor, topk_output: TopKOutput
+    ) -> DispatchOutput:
+        return StandardDispatchOutput(
+            hidden_states=hidden_states, topk_output=topk_output
+        )
+
+    def combine(self, combine_input: CombineInput) -> torch.Tensor:
+        if isinstance(combine_input, StandardCombineInput):
+            return combine_input.hidden_states
+        else:
+            # TODO: this branch should be removed in the future
+            assert isinstance(combine_input, torch.Tensor)
+            return combine_input
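
The new `StandardDispatcher` is deliberately a pass-through: `dispatch` pairs the hidden states with the routing result, and `combine` unwraps them again (with a temporary escape hatch for quant methods that still return raw tensors). A rough usage sketch, assuming `BaseDispatcher` needs no constructor arguments and with a placeholder standing in for a real `TopKOutput`:

    import torch

    from sglang.srt.layers.moe.token_dispatcher.standard import (
        StandardCombineInput,
        StandardDispatcher,
    )

    dispatcher = StandardDispatcher()
    hidden_states = torch.randn(8, 4096)
    topk_output = ...  # placeholder: a TopKOutput produced by TopK.forward

    dispatch_output = dispatcher.dispatch(hidden_states, topk_output)
    # A quant method's apply() would consume dispatch_output and hand back a
    # CombineInput; combine() then unwraps it to a plain tensor.
    out = dispatcher.combine(StandardCombineInput(hidden_states=hidden_states))
    assert torch.equal(out, hidden_states)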

sglang/srt/layers/moe/topk.py

@@ -19,6 +19,7 @@ import math
 from dataclasses import dataclass
 from enum import Enum, auto
 from typing import (
+    TYPE_CHECKING,
     Callable,
     NamedTuple,
     Optional,
@@ -51,6 +52,9 @@ from sglang.srt.utils import (
     is_npu,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.quantization import QuantizationConfig
+
 try:
     from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
 except ImportError:
@@ -94,6 +98,7 @@ class TopKConfig:
     torch_native: bool = False
     routed_scaling_factor: Optional[float] = None
     apply_routed_scaling_factor_on_output: bool = False
+    output_format: Optional[TopKOutputFormat] = None
 
 
 # -------------------------------- TopKOutput ---------------------------------------
@@ -196,9 +201,10 @@ class TopK(CustomOp):
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         routed_scaling_factor: Optional[float] = None,
         apply_routed_scaling_factor_on_output: Optional[bool] = False,
-        force_topk: bool = False,
+        output_format: Optional[TopKOutputFormat] = None,
     ):
         # NOTE: scoring_func is not used for now, but we keep it for future use
         # see https://github.com/sgl-project/sglang/pull/4505 for more details
@@ -218,11 +224,9 @@
             correction_bias=correction_bias,
             routed_scaling_factor=routed_scaling_factor,
             apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
+            output_format=output_format,
         )
 
-        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
-        self.force_topk = force_topk
-
     def forward_native(
         self,
         hidden_states: torch.Tensor,
@@ -248,7 +252,19 @@
         num_token_non_padded: Optional[torch.Tensor] = None,
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     ) -> TopKOutput:
-        if self.use_triton_kernels:
+        if self.topk_config.output_format is not None:
+            output_format = self.topk_config.output_format
+        elif get_moe_runner_backend().is_triton_kernel():
+            output_format = TopKOutputFormat.TRITON_KERNEL
+        elif (
+            should_use_flashinfer_trtllm_moe()
+            or get_moe_runner_backend().is_flashinfer_mxfp4()
+        ):
+            output_format = TopKOutputFormat.BYPASSED
+        else:
+            output_format = TopKOutputFormat.STANDARD
+
+        if output_format == TopKOutputFormat.TRITON_KERNEL:
             # renormalize=True is equivalent to sm_first=False
             routing_data, gather_idx, scatter_idx = routing(
                 router_logits,
@@ -256,10 +272,7 @@
                 sm_first=not self.topk_config.renormalize,
             )
             return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
-        elif not self.force_topk and (
-            should_use_flashinfer_trtllm_moe()
-            or get_moe_runner_backend().is_flashinfer_mxfp4()
-        ):
+        elif output_format == TopKOutputFormat.BYPASSED:
             return BypassedTopKOutput(
                 hidden_states=hidden_states,
                 router_logits=router_logits,
@@ -330,6 +343,14 @@ class TopK(CustomOp):
             )
             topk_weights = topk_weights / topk_weights_sum
 
+            if expert_location_dispatch_info is not None:
+                topk_ids = topk_ids_logical_to_physical(
+                    topk_ids, expert_location_dispatch_info
+                )
+                get_global_expert_distribution_recorder().on_select_experts(
+                    topk_ids=topk_ids
+                )
+
             return StandardTopKOutput(topk_weights, topk_ids, _)
         else:
             self.topk_config.torch_native = True
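
With `force_topk` removed, the output format is now resolved in `forward_cuda` in a fixed order: an explicit `TopKConfig.output_format` wins, then the triton-kernel backend, then the flashinfer bypass, and finally the standard path. Callers that previously passed `force_topk=True` would now pin the format explicitly; a hedged sketch (constructor arguments other than `output_format` are illustrative):

    from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat

    # Force the standard (topk_weights, topk_ids) output regardless of the
    # configured MoE runner backend -- the replacement for force_topk=True.
    topk = TopK(
        top_k=2,
        renormalize=True,
        output_format=TopKOutputFormat.STANDARD,
    )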

sglang/srt/layers/moe/utils.py

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import importlib.util
+import logging
 from enum import Enum
 from functools import lru_cache
 from typing import TYPE_CHECKING, Optional
@@ -12,11 +13,12 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_size,
     is_dp_attention_enabled,
 )
-from sglang.srt.utils import logger
 
 if TYPE_CHECKING:
     from sglang.srt.server_args import ServerArgs
 
+logger = logging.getLogger(__name__)
+
 
 class MoeA2ABackend(Enum):
 
@@ -44,9 +46,10 @@ class MoeRunnerBackend(Enum):
     AUTO = "auto"
     TRITON = "triton"
     TRITON_KERNEL = "triton_kernel"
-    FLASHINFER = "flashinfer_trtllm"
+    FLASHINFER_TRTLLM = "flashinfer_trtllm"
     FLASHINFER_CUTLASS = "flashinfer_cutlass"
     FLASHINFER_MXFP4 = "flashinfer_mxfp4"
+    FLASHINFER_CUTEDSL = "flashinfer_cutedsl"
 
     def is_auto(self):
         return self == MoeRunnerBackend.AUTO
@@ -58,11 +61,14 @@
         return self == MoeRunnerBackend.TRITON_KERNEL
 
     def is_flashinfer_trtllm(self):
-        return self == MoeRunnerBackend.FLASHINFER
+        return self == MoeRunnerBackend.FLASHINFER_TRTLLM
 
     def is_flashinfer_cutlass(self):
         return self == MoeRunnerBackend.FLASHINFER_CUTLASS
 
+    def is_flashinfer_cutedsl(self):
+        return self == MoeRunnerBackend.FLASHINFER_CUTEDSL
+
     def is_flashinfer_mxfp4(self):
         return self == MoeRunnerBackend.FLASHINFER_MXFP4
 
@@ -131,7 +137,7 @@ def get_moe_a2a_backend() -> MoeA2ABackend:
     global MOE_A2A_BACKEND
     if MOE_A2A_BACKEND is None:
         logger.warning("MOE_A2A_BACKEND is not initialized, using default backend")
-        MOE_A2A_BACKEND = MoeA2ABackend(None)
+        MOE_A2A_BACKEND = MoeA2ABackend.NONE
     return MOE_A2A_BACKEND
 
 
@@ -139,7 +145,7 @@ def get_moe_runner_backend() -> MoeRunnerBackend:
     global MOE_RUNNER_BACKEND
     if MOE_RUNNER_BACKEND is None:
         logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend")
-        MOE_RUNNER_BACKEND = MoeRunnerBackend("triton")
+        MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO
    return MOE_RUNNER_BACKEND
 
 
@@ -147,7 +153,7 @@ def get_deepep_mode() -> DeepEPMode:
     global DEEPEP_MODE
     if DEEPEP_MODE is None:
         logger.warning("DEEPEP_MODE is not initialized, using auto mode")
-        DEEPEP_MODE = DeepEPMode("auto")
+        DEEPEP_MODE = DeepEPMode.AUTO
     return DEEPEP_MODE
 
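The three getter fixes above replace value lookups with direct member access. The distinction matters most for `MoeA2ABackend(None)`: calling an `Enum` looks the argument up by value and raises `ValueError` unless the class maps it (for example via a `_missing_` hook), whereas `MoeA2ABackend.NONE` never touches the value machinery. A self-contained illustration with a hypothetical enum:

    from enum import Enum

    class Backend(Enum):
        NONE = "none"
        DEEPEP = "deepep"

    assert Backend("none") is Backend.NONE  # lookup by value: works
    assert Backend.NONE.value == "none"     # member access: always safe

    try:
        Backend(None)  # value lookup for None
    except ValueError as exc:
        print(exc)     # None is not a valid Backend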

sglang/srt/layers/quantization/awq.py

@@ -34,7 +34,10 @@ from sglang.srt.layers.quantization.utils import get_scalar_types, replace_param
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import StandardTopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        StandardDispatchOutput,
+        CombineInput,
+    )
 
 from sglang.srt.utils import is_cuda, is_hip
 
@@ -736,24 +739,32 @@ class AWQMoEMethod(FusedMoEMethodBase):
         )
         replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: StandardTopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
        ), "Only SiLU activation is supported."
 
         # The input must currently be float16
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         orig_dtype = x.dtype
         x = x.half()
 
         topk_weights, topk_ids, router_logits = topk_output
 
-        return fused_marlin_moe(
+        output = fused_marlin_moe(
             x,
             layer.w13_qweight,
             layer.w2_qweight,
@@ -768,3 +779,4 @@
             w2_zeros=layer.w2_qzeros,
             num_bits=self.quant_config.weight_bits,
         ).to(orig_dtype)
+        return StandardCombineInput(hidden_states=output)

sglang/srt/layers/quantization/base_config.py

@@ -3,6 +3,7 @@ from __future__ import annotations
 
 import inspect
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
 
 import torch
@@ -10,7 +11,7 @@ from torch import nn
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput
 
 
 class QuantizeMethodBase(ABC):
@@ -89,20 +90,24 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
         raise NotImplementedError
 
+    @abstractmethod
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        raise NotImplementedError
+
     @abstractmethod
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: DispatchOutput,
+    ) -> CombineInput:
         raise NotImplementedError
 
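The base class now splits configuration from execution: `create_moe_runner` is called once with the `MoeRunnerConfig`, and `apply` consumes a `DispatchOutput` and returns a `CombineInput` instead of a raw tensor. A minimal, hypothetical subclass shape, mirroring how the AWQ and block-int8 methods in this diff implement it:

    from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase

    # Hypothetical sketch of a concrete method under the new contract; a real
    # implementation must also provide create_weights() and friends.
    class MyMoEMethod(FusedMoEMethodBase):
        def create_moe_runner(self, layer, moe_runner_config):
            # Phase 1: stash the config (and optionally build a MoeRunner).
            self.moe_runner_config = moe_runner_config

        def apply(self, layer, dispatch_output):
            # Phase 2: DispatchOutput in, CombineInput out.
            from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

            x = dispatch_output.hidden_states
            out = x  # placeholder for the actual expert computation
            return StandardCombineInput(hidden_states=out)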

sglang/srt/layers/quantization/blockwise_int8.py

@@ -9,6 +9,8 @@ import torch
 from torch.nn import Module
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -22,8 +24,10 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import set_weight_attrs
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
@@ -257,7 +261,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
         layer: Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -273,25 +277,28 @@
         )
         # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
         # Required by column parallel or enabling merged weights
-        if intermediate_size % block_n != 0:
+        if intermediate_size_per_partition % block_n != 0:
             raise ValueError(
                 f"The output_size of gate's and up's weight = "
-                f"{intermediate_size} is not divisible by "
+                f"{intermediate_size_per_partition} is not divisible by "
                 f"weight quantization block_n = {block_n}."
             )
         if tp_size > 1:
             # Required by row parallel
-            if intermediate_size % block_k != 0:
+            if intermediate_size_per_partition % block_k != 0:
                 raise ValueError(
                     f"The input_size of down's weight = "
-                    f"{intermediate_size} is not divisible by "
+                    f"{intermediate_size_per_partition} is not divisible by "
                     f"weight quantization block_k = {block_k}."
                 )
 
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=params_dtype,
             ),
             requires_grad=False,
         )
@@ -300,7 +307,10 @@
 
         w2_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=params_dtype
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=params_dtype,
             ),
             requires_grad=False,
         )
@@ -311,7 +321,7 @@
         w13_weight_scale = torch.nn.Parameter(
             torch.ones(
                 num_experts,
-                2 * ((intermediate_size + block_n - 1) // block_n),
+                2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                 (hidden_size + block_k - 1) // block_k,
                 dtype=torch.float32,
             ),
@@ -321,7 +331,7 @@
             torch.ones(
                 num_experts,
                 (hidden_size + block_n - 1) // block_n,
-                (intermediate_size + block_k - 1) // block_k,
+                (intermediate_size_per_partition + block_k - 1) // block_k,
                 dtype=torch.float32,
             ),
             requires_grad=False,
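
The scale shapes above use ceil division, `(x + block - 1) // block`, so a dimension that is not an exact multiple of the block size still gets a scale entry for its ragged tail. A quick worked check with illustrative sizes:

    # Ceil-division arithmetic behind the block-wise scale shapes.
    intermediate_size_per_partition = 1536
    hidden_size = 4096
    block_n = block_k = 128

    rows = 2 * ((intermediate_size_per_partition + block_n - 1) // block_n)
    cols = (hidden_size + block_k - 1) // block_k
    print(rows, cols)  # 24 32 -> w13_weight_scale is (num_experts, 24, 32)

    # A ragged dimension still gets a block for the remainder:
    assert (130 + 128 - 1) // 128 == 2
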
@@ -344,26 +354,27 @@
         # Block quant doesn't need to process weights after loading
         return
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
-        # Expert fusion with INT8 quantization
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_int8_w8a8=True,
-            w1_scale=(layer.w13_weight_scale_inv),
-            w2_scale=(layer.w2_weight_scale_inv),
-            a1_scale=layer.w13_input_scale,
+            w13_scale=layer.w13_weight_scale_inv,
+            w2_scale=layer.w2_weight_scale_inv,
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             block_shape=self.quant_config.weight_block_size,
         )
+
+        return self.runner.run(dispatch_output, quant_info)
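
End to end, the refactor turns a fused-MoE forward into dispatch, apply, combine: the dispatcher produces a `DispatchOutput`, the quant method's `apply` returns a `CombineInput`, and the dispatcher's `combine` turns it back into a tensor. A hedged sketch of how a MoE layer might now drive these pieces (the function and argument names here are assumptions, not the exact layer code):

    import torch

    def moe_forward(layer, dispatcher, quant_method, hidden_states, topk_output):
        # 1. Pair hidden states with routing results
        #    (StandardDispatcher is a no-op pairing).
        dispatch_output = dispatcher.dispatch(hidden_states, topk_output)
        # 2. Run the quantized expert computation; returns a CombineInput,
        #    e.g. StandardCombineInput(hidden_states=...) as in the hunks above.
        combine_input = quant_method.apply(layer, dispatch_output)
        # 3. Reduce/unwrap back to a plain tensor.
        return dispatcher.combine(combine_input)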