sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (256)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/token_dispatcher/deepep.py

@@ -5,13 +5,15 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
 
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
-from sglang.srt.layers.moe import DeepEPMode, get_deepep_config, is_tbo_enabled
-from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
+from sglang.srt.layers.moe.token_dispatcher.base import (
     BaseDispatcher,
     BaseDispatcherConfig,
+    CombineInput,
+    CombineInputFormat,
     DispatchOutput,
     DispatchOutputFormat,
 )
+from sglang.srt.layers.moe.utils import DeepEPMode, get_deepep_config, is_tbo_enabled
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.utils import (
     get_bool_env_var,
@@ -40,11 +42,6 @@ from enum import Enum, IntEnum, auto
 import torch
 import torch.distributed as dist
 
-from sglang.srt.layers.moe.ep_moe.kernels import (
-    deepep_permute_triton_kernel,
-    deepep_post_reorder_triton_kernel,
-    deepep_run_moe_deep_preprocess,
-)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
@@ -56,6 +53,7 @@ class DeepEPNormalOutput(NamedTuple):
     """DeepEP normal dispatch output."""
 
     hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
+    # hidden_states_scale
     topk_idx: torch.Tensor
     topk_weights: torch.Tensor
     num_recv_tokens_per_expert: List[int]
@@ -79,24 +77,32 @@ class DeepEPLLOutput(NamedTuple):
         return DispatchOutputFormat.DEEPEP_LL
 
 
-class AscendDeepEPLLOutput(NamedTuple):
-    """AscendDeepEP low latency dispatch output."""
+assert isinstance(DeepEPNormalOutput, DispatchOutput)
+assert isinstance(DeepEPLLOutput, DispatchOutput)
 
-    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
-    topk_idx: torch.Tensor
-    topk_weights: torch.Tensor
-    masked_m: torch.Tensor
-    seg_indptr: torch.Tensor
-    expected_m: int
+
+class DeepEPNormalCombineInput(NamedTuple):
+    """DeepEP normal combine input."""
+
+    pass
 
     @property
-    def format(self) -> DispatchOutputFormat:
-        return DispatchOutputFormat.ASCENT_LL
+    def format(self) -> CombineInputFormat:
+        return CombineInputFormat.DEEPEP_NORMAL
 
 
-assert isinstance(DeepEPNormalOutput, DispatchOutput)
-assert isinstance(DeepEPLLOutput, DispatchOutput)
-assert isinstance(AscendDeepEPLLOutput, DispatchOutput)
+class DeepEPLLCombineInput(NamedTuple):
+    """DeepEP low latency combine input."""
+
+    pass
+
+    @property
+    def format(self) -> CombineInputFormat:
+        return CombineInputFormat.DEEPEP_LL
+
+
+assert isinstance(DeepEPNormalCombineInput, CombineInput)
+assert isinstance(DeepEPLLCombineInput, CombineInput)
 
 
 class DeepEPDispatchMode(IntEnum):
@@ -272,6 +278,9 @@ class _DeepEPDispatcherImplBase:
         self.num_max_dispatch_tokens_per_rank = get_int_env_var(
             "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128
         )
+        # DeepEP internode_ll dispatch uses FINISHED_SUM_TAG=1024
+        # and the logic requires num-tokens-sent-from-one-rank-to-another-rank less than it
+        assert self.num_max_dispatch_tokens_per_rank <= 1024
 
         self.handle = None
 
@@ -409,7 +418,11 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
-        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter:
+        from sglang.srt.layers.moe.ep_moe.kernels import (
+            deepep_post_reorder_triton_kernel,
+        )
+
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
             output = hidden_states
         else:
             if hidden_states.shape[0] > 0:
@@ -523,23 +536,13 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             masked_m
         )
 
-        if _is_npu:
-            deepep_output = AscendDeepEPLLOutput(
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                masked_m,
-                self.handle[1],
-                expected_m,
-            )
-        else:
-            deepep_output = DeepEPLLOutput(
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                masked_m,
-                expected_m,
-            )
+        deepep_output = DeepEPLLOutput(
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            masked_m,
+            expected_m,
+        )
         return deepep_output
 
     def _dispatch_core(
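A note on the import-time checks above (assert isinstance(DeepEPNormalOutput, DispatchOutput) and friends): they are applied to the class objects themselves. base.py is not shown in this diff, but the pattern is consistent with DispatchOutput and CombineInput being runtime-checkable Protocols that require only a format property. A minimal sketch of that pattern, with DispatchOutputLike as a hypothetical stand-in:

    from typing import NamedTuple, Protocol, runtime_checkable

    @runtime_checkable
    class DispatchOutputLike(Protocol):  # hypothetical stand-in for DispatchOutput
        @property
        def format(self): ...

    class MyOutput(NamedTuple):
        hidden_states: object

        @property
        def format(self):
            return "standard"

    # The property is an attribute of the class object, so the class itself
    # satisfies the protocol and the assertion passes at import time.
    assert isinstance(MyOutput, DispatchOutputLike)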
sglang/srt/layers/moe/token_dispatcher/standard.py

@@ -1,19 +1,61 @@
 from __future__ import annotations
 
-from typing import NamedTuple
+from typing import TYPE_CHECKING, NamedTuple
 
-from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
+import torch
+
+from sglang.srt.layers.moe.token_dispatcher.base import (
+    BaseDispatcher,
+    CombineInput,
+    CombineInputFormat,
     DispatchOutput,
     DispatchOutputFormat,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
 
 class StandardDispatchOutput(NamedTuple):
     """Standard dispatch output."""
 
+    hidden_states: torch.Tensor
+    topk_output: TopKOutput
+
     @property
     def format(self) -> DispatchOutputFormat:
         return DispatchOutputFormat.STANDARD
 
 
 assert isinstance(StandardDispatchOutput, DispatchOutput)
+
+
+class StandardCombineInput(NamedTuple):
+    """Standard combine input."""
+
+    hidden_states: torch.Tensor
+
+    @property
+    def format(self) -> CombineInputFormat:
+        return CombineInputFormat.STANDARD
+
+
+assert isinstance(StandardCombineInput, CombineInput)
+
+
+class StandardDispatcher(BaseDispatcher):
+
+    def dispatch(
+        self, hidden_states: torch.Tensor, topk_output: TopKOutput
+    ) -> DispatchOutput:
+        return StandardDispatchOutput(
+            hidden_states=hidden_states, topk_output=topk_output
+        )
+
+    def combine(self, combine_input: CombineInput) -> torch.Tensor:
+        if isinstance(combine_input, StandardCombineInput):
+            return combine_input.hidden_states
+        else:
+            # TODO: this branch should be removed in the future
+            assert isinstance(combine_input, torch.Tensor)
+            return combine_input
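In the standard (single-rank) path, dispatch and combine reduce to wrapping and unwrapping, as the StandardDispatcher above makes explicit. The following self-contained sketch mirrors the contract without importing sglang; the tensor shape and the placeholder topk_output are illustrative only:

    from typing import Any, NamedTuple

    import torch

    class StandardDispatchOutput(NamedTuple):  # simplified mirror of the diff
        hidden_states: torch.Tensor
        topk_output: Any

    class StandardCombineInput(NamedTuple):
        hidden_states: torch.Tensor

    class StandardDispatcher:
        def dispatch(self, hidden_states, topk_output):
            # No cross-rank token movement here: just bundle the inputs.
            return StandardDispatchOutput(hidden_states, topk_output)

        def combine(self, combine_input):
            # Experts hand back a combine input; unwrap it to a plain tensor.
            return combine_input.hidden_states

    dispatcher = StandardDispatcher()
    x = torch.randn(4, 16)  # [num_tokens, hidden_size]
    out = dispatcher.dispatch(x, topk_output=None)
    y = dispatcher.combine(StandardCombineInput(out.hidden_states))
    assert torch.equal(x, y)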
sglang/srt/layers/moe/topk.py

@@ -304,12 +304,12 @@ class TopK(CustomOp):
         global_num_experts = router_logits.shape[-1]
 
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if global_num_experts == 256 and self.topk_config.renormalize is False:
+        if global_num_experts == 256:
 
             routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
             router_logits = router_logits.to(torch.float32)
 
-            return torch_npu.npu_moe_gating_top_k(
+            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
                 router_logits,
                 k=self.topk_config.top_k,
                 bias=self.topk_config.correction_bias.to(torch.float32),
@@ -321,6 +321,24 @@ class TopK(CustomOp):
                 routed_scaling_factor=routed_scaling_factor,
                 eps=float(1e-20),
             )
+
+            if self.topk_config.renormalize:
+                topk_weights_sum = (
+                    topk_weights.sum(dim=-1, keepdim=True)
+                    if self.topk_config.num_fused_shared_experts == 0
+                    else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+                )
+                topk_weights = topk_weights / topk_weights_sum
+
+            if expert_location_dispatch_info is not None:
+                topk_ids = topk_ids_logical_to_physical(
+                    topk_ids, expert_location_dispatch_info
+                )
+                get_global_expert_distribution_recorder().on_select_experts(
+                    topk_ids=topk_ids
+                )
+
+            return StandardTopKOutput(topk_weights, topk_ids, _)
         else:
             self.topk_config.torch_native = True
             return select_experts(
@@ -347,17 +365,28 @@ def fused_topk_torch_native(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
+    correction_bias: torch.Tensor = None,
 ):
-    assert (
-        hidden_states.shape[0] == gating_output.shape[0]
-    ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
-    M, _ = hidden_states.shape
-    topk_weights = torch.empty(
-        M, topk, dtype=torch.float32, device=hidden_states.device
-    )
-    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    topk_weights = F.softmax(gating_output.float(), dim=-1)
-    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+    if correction_bias is not None:
+        n_routed_experts = gating_output.shape[-1]
+        scores = gating_output.softmax(dim=-1)
+        scores_for_choice = scores.view(
+            -1, n_routed_experts
+        ) + correction_bias.unsqueeze(0)
+        topk_ids = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1]
+        topk_weights = scores.gather(1, topk_ids)
+    else:
+        assert (
+            hidden_states.shape[0] == gating_output.shape[0]
+        ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
+        M, _ = hidden_states.shape
+        topk_weights = torch.empty(
+            M, topk, dtype=torch.float32, device=hidden_states.device
+        )
+        topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+        topk_weights = F.softmax(gating_output.float(), dim=-1)
+        topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
 
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
     return topk_weights, topk_ids
@@ -370,6 +399,7 @@ def fused_topk_cpu(
     renormalize: bool,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    correction_bias: torch.Tensor = None,
 ):
     topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
         hidden_states=hidden_states,
@@ -815,6 +845,7 @@ def select_experts(
             gating_output=router_logits,
             topk=top_k,
             renormalize=renormalize,
+            correction_bias=correction_bias,
         )
     elif custom_routing_function is None:
         assert not apply_routed_scaling_factor_on_output, "Not implemented"
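The correction_bias branch added to fused_topk_torch_native resembles the bias-corrected routing used for auxiliary-loss-free load balancing: the bias shifts which experts are selected, while the returned weights still come from the unbiased softmax scores. A standalone sketch of that behavior:

    import torch

    def biased_topk(gating_output, correction_bias, topk):
        scores = gating_output.softmax(dim=-1)
        # The bias only affects the ranking used for expert selection.
        scores_for_choice = scores + correction_bias.unsqueeze(0)
        topk_ids = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1]
        # The weights are gathered from the unbiased scores.
        topk_weights = scores.gather(1, topk_ids)
        return topk_weights, topk_ids

    logits = torch.randn(2, 8)  # [num_tokens, num_experts]
    bias = torch.zeros(8)
    bias[3] = 10.0              # strongly favor expert 3
    weights, ids = biased_topk(logits, bias, topk=2)
    assert (ids == 3).any(dim=-1).all()  # expert 3 is always among the top-2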
sglang/srt/layers/moe/utils.py

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import importlib.util
+import logging
 from enum import Enum
 from functools import lru_cache
 from typing import TYPE_CHECKING, Optional
@@ -12,11 +13,12 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_size,
     is_dp_attention_enabled,
 )
-from sglang.srt.utils import logger
 
 if TYPE_CHECKING:
     from sglang.srt.server_args import ServerArgs
 
+logger = logging.getLogger(__name__)
+
 
 
 class MoeA2ABackend(Enum):
@@ -131,7 +133,7 @@ def get_moe_a2a_backend() -> MoeA2ABackend:
     global MOE_A2A_BACKEND
     if MOE_A2A_BACKEND is None:
         logger.warning("MOE_A2A_BACKEND is not initialized, using default backend")
-        MOE_A2A_BACKEND = MoeA2ABackend(None)
+        MOE_A2A_BACKEND = MoeA2ABackend.NONE
     return MOE_A2A_BACKEND
 
 
@@ -139,7 +141,7 @@ def get_moe_runner_backend() -> MoeRunnerBackend:
     global MOE_RUNNER_BACKEND
     if MOE_RUNNER_BACKEND is None:
         logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend")
-        MOE_RUNNER_BACKEND = MoeRunnerBackend("triton")
+        MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO
     return MOE_RUNNER_BACKEND
 
 
@@ -147,7 +149,7 @@ def get_deepep_mode() -> DeepEPMode:
     global DEEPEP_MODE
     if DEEPEP_MODE is None:
         logger.warning("DEEPEP_MODE is not initialized, using auto mode")
-        DEEPEP_MODE = DeepEPMode("auto")
+        DEEPEP_MODE = DeepEPMode.AUTO
     return DEEPEP_MODE
 
 
@@ -162,7 +164,6 @@ def get_deepep_config() -> str:
 def is_tbo_enabled() -> bool:
     global IS_TBO_ENABLED
     if IS_TBO_ENABLED is None:
-        logger.warning("IS_TBO_ENABLED is not initialized, using False")
        IS_TBO_ENABLED = False
     return IS_TBO_ENABLED
 
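The switch from value lookup (MoeA2ABackend(None), DeepEPMode("auto")) to member access (MoeA2ABackend.NONE, DeepEPMode.AUTO) removes the dependence on an enum member whose value happens to match the argument. A small illustration with a hypothetical backend enum:

    from enum import Enum

    class Backend(Enum):  # hypothetical; not the sglang enum
        NONE = "none"
        DEEPEP = "deepep"

    print(Backend.NONE)   # member access always works
    try:
        Backend(None)     # value lookup fails unless some member's value is None
    except ValueError as exc:
        print(exc)        # "None is not a valid Backend"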
sglang/srt/layers/quantization/awq.py

@@ -34,7 +34,10 @@ from sglang.srt.layers.quantization.utils import get_scalar_types, replace_param
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import StandardTopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        StandardDispatchOutput,
+        CombineInput,
+    )
 
 from sglang.srt.utils import is_cuda, is_hip
 
@@ -736,24 +739,32 @@ class AWQMoEMethod(FusedMoEMethodBase):
         )
         replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: StandardTopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
         # The input must currently be float16
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         orig_dtype = x.dtype
         x = x.half()
 
         topk_weights, topk_ids, router_logits = topk_output
 
-        return fused_marlin_moe(
+        output = fused_marlin_moe(
             x,
             layer.w13_qweight,
             layer.w2_qweight,
@@ -768,3 +779,4 @@ class AWQMoEMethod(FusedMoEMethodBase):
             w2_zeros=layer.w2_qzeros,
             num_bits=self.quant_config.weight_bits,
         ).to(orig_dtype)
+        return StandardCombineInput(hidden_states=output)
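Inside the reworked AWQMoEMethod.apply, the input is still cast to float16 for the fused Marlin kernel and restored to the caller's dtype before being wrapped in StandardCombineInput. A minimal sketch of that dtype round trip, with a multiply standing in for fused_marlin_moe:

    import torch

    def run_in_half(x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        x = x.half()               # kernel-side compute dtype
        out = x * 2.0              # stand-in for fused_marlin_moe(...)
        return out.to(orig_dtype)  # restore the caller's dtype

    y = run_in_half(torch.randn(4, 8, dtype=torch.bfloat16))
    assert y.dtype == torch.bfloat16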
sglang/srt/layers/quantization/base_config.py

@@ -3,6 +3,7 @@ from __future__ import annotations
 
 import inspect
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
 
 import torch
@@ -10,7 +11,7 @@ from torch import nn
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput
 
 
 class QuantizeMethodBase(ABC):
@@ -89,20 +90,24 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
         raise NotImplementedError
 
+    @abstractmethod
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        raise NotImplementedError
+
     @abstractmethod
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: DispatchOutput,
+    ) -> CombineInput:
         raise NotImplementedError
 
 
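The net effect on the quant-method interface: per-call configuration moves into a one-time create_moe_runner step, and apply now maps a dispatch output to a combine input. A minimal concrete sketch of the new contract, with mock types standing in for the sglang classes:

    from abc import ABC, abstractmethod

    class FusedMoEMethodSketch(ABC):  # mirrors the reshaped FusedMoEMethodBase
        @abstractmethod
        def create_moe_runner(self, layer, moe_runner_config): ...

        @abstractmethod
        def apply(self, layer, dispatch_output): ...

    class NoopMethod(FusedMoEMethodSketch):
        def create_moe_runner(self, layer, moe_runner_config):
            # Capture the config once at setup rather than on every call.
            self.moe_runner_config = moe_runner_config

        def apply(self, layer, dispatch_output):
            # A real method would run the experts; this one passes through.
            return dispatch_output

    method = NoopMethod()
    method.create_moe_runner(layer=None, moe_runner_config={"activation": "silu"})
    print(method.apply(layer=None, dispatch_output="hidden_states"))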
sglang/srt/layers/quantization/blockwise_int8.py

@@ -9,6 +9,8 @@ import torch
 from torch.nn import Module
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -22,8 +24,10 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import set_weight_attrs
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
@@ -257,7 +261,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
         layer: Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -273,25 +277,28 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
         )
         # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
         # Required by column parallel or enabling merged weights
-        if intermediate_size % block_n != 0:
+        if intermediate_size_per_partition % block_n != 0:
             raise ValueError(
                 f"The output_size of gate's and up's weight = "
-                f"{intermediate_size} is not divisible by "
+                f"{intermediate_size_per_partition} is not divisible by "
                 f"weight quantization block_n = {block_n}."
             )
         if tp_size > 1:
             # Required by row parallel
-            if intermediate_size % block_k != 0:
+            if intermediate_size_per_partition % block_k != 0:
                 raise ValueError(
                     f"The input_size of down's weight = "
-                    f"{intermediate_size} is not divisible by "
+                    f"{intermediate_size_per_partition} is not divisible by "
                     f"weight quantization block_k = {block_k}."
                 )
 
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=params_dtype,
             ),
             requires_grad=False,
         )
@@ -300,7 +307,10 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
 
         w2_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=params_dtype
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=params_dtype,
             ),
             requires_grad=False,
         )
@@ -311,7 +321,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
         w13_weight_scale = torch.nn.Parameter(
             torch.ones(
                 num_experts,
-                2 * ((intermediate_size + block_n - 1) // block_n),
+                2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                 (hidden_size + block_k - 1) // block_k,
                 dtype=torch.float32,
             ),
@@ -321,7 +331,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
             torch.ones(
                 num_experts,
                 (hidden_size + block_n - 1) // block_n,
-                (intermediate_size + block_k - 1) // block_k,
+                (intermediate_size_per_partition + block_k - 1) // block_k,
                 dtype=torch.float32,
             ),
             requires_grad=False,
@@ -344,26 +354,27 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
         # Block quant doesn't need to process weights after loading
         return
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
-        # Expert fusion with INT8 quantization
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_int8_w8a8=True,
-            w1_scale=(layer.w13_weight_scale_inv),
-            w2_scale=(layer.w2_weight_scale_inv),
-            a1_scale=layer.w13_input_scale,
+            w13_scale=layer.w13_weight_scale_inv,
+            w2_scale=layer.w2_weight_scale_inv,
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             block_shape=self.quant_config.weight_block_size,
         )
+
+        return self.runner.run(dispatch_output, quant_info)
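The BlockInt8MoEMethod rewrite is representative of the runner migration in this release: kernel arguments travel in a quant-info object, and the backend is chosen once in create_moe_runner instead of inside apply. A self-contained sketch of the shape of that API, with stand-in types rather than the real MoeRunner and TritonMoeQuantInfo:

    from dataclasses import dataclass
    from enum import Enum, auto
    from typing import List, Optional

    class RunnerBackend(Enum):  # stand-in for MoeRunnerBackend
        TRITON = auto()

    @dataclass
    class QuantInfo:            # stand-in for TritonMoeQuantInfo
        use_int8_w8a8: bool
        block_shape: Optional[List[int]] = None

    class Runner:               # stand-in for MoeRunner
        def __init__(self, backend: RunnerBackend, config: dict):
            self.backend, self.config = backend, config

        def run(self, dispatch_output, quant_info: QuantInfo):
            # A real runner would launch the backend kernel here.
            return dispatch_output, quant_info

    runner = Runner(RunnerBackend.TRITON, config={"activation": "silu"})
    print(runner.run("hidden_states", QuantInfo(True, [128, 128])))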