sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/moe_runner/deep_gemm.py (new file)
@@ -0,0 +1,569 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, List, Optional
+
+ import torch
+
+ from sglang.srt.layers import deep_gemm_wrapper
+ from sglang.srt.layers.moe.moe_runner.base import (
+     MoeQuantInfo,
+     MoeRunnerConfig,
+     MoeRunnerCore,
+     RunnerInput,
+     RunnerOutput,
+     register_post_permute,
+     register_pre_permute,
+ )
+ from sglang.srt.layers.moe.utils import MoeRunnerBackend
+ from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
+ from sglang.srt.utils.offloader import get_offloader
+
+ if TYPE_CHECKING:
+     from sglang.srt.layers.moe.token_dispatcher.deepep import (
+         DeepEPLLCombineInput,
+         DeepEPLLDispatchOutput,
+         DeepEPNormalCombineInput,
+         DeepEPNormalDispatchOutput,
+     )
+     from sglang.srt.layers.moe.token_dispatcher.standard import (
+         StandardCombineInput,
+         StandardDispatchOutput,
+     )
+
+ _is_hip = is_hip()
+ _is_npu = is_npu()
+ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+
+ if not (_is_npu or _is_hip):
+     from sgl_kernel import silu_and_mul
+
+
+ # TODO(kaixih@nvidia): ideally we should merge this logic into
+ # `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
+ @torch.compile
+ def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
+     temp = x.to(torch.float32).view(torch.int32)
+     exp = torch.bitwise_right_shift(temp, 23)
+     mant = torch.bitwise_and(temp, 0x7FFFFF)
+     is_ru = torch.logical_and(
+         torch.logical_and((mant > 0), (exp != 0xFE)),
+         ~torch.logical_and((exp == 0), (mant <= 0x400000)),
+     )
+     exp = torch.where(is_ru, exp + 1, exp)
+     new_x = exp.to(torch.uint8).view(torch.int)
+     return new_x.transpose(1, 2).contiguous().transpose(1, 2)
+
+
+ def copy_list_to_gpu_no_ce(arr: List[int]):
+     from sgl_kernel.elementwise import copy_to_gpu_no_ce
+
+     tensor_cpu = torch.tensor(arr, dtype=torch.int32, device="cpu")
+     tensor_gpu = torch.empty_like(tensor_cpu, device="cuda")
+     copy_to_gpu_no_ce(tensor_cpu, tensor_gpu)
+     return tensor_gpu
+
+
+ @dataclass
+ class DeepGemmRunnerInput(RunnerInput):
+     hidden_states: torch.Tensor
+     hidden_states_scale: torch.Tensor
+     use_masked_gemm: bool
+     masked_m: Optional[torch.Tensor] = None
+     expected_m: Optional[int] = None
+     m_indices: Optional[torch.Tensor] = None
+
+     @property
+     def runner_backend(self) -> MoeRunnerBackend:
+         return MoeRunnerBackend.DEEP_GEMM
+
+
+ @dataclass
+ class DeepGemmRunnerOutput(RunnerOutput):
+     hidden_states: torch.Tensor
+
+     @property
+     def runner_backend(self) -> MoeRunnerBackend:
+         return MoeRunnerBackend.DEEP_GEMM
+
+
+ @dataclass
+ class DeepGemmMoeQuantInfo(MoeQuantInfo):
+     w13_weight: torch.Tensor
+     w2_weight: torch.Tensor
+     use_fp8: bool
+     w13_scale: Optional[torch.Tensor] = None
+     w2_scale: Optional[torch.Tensor] = None
+     block_shape: Optional[List[int]] = None
+
+
+ class DeepGemmRunnerCore(MoeRunnerCore):
+     def __init__(self, config: MoeRunnerConfig):
+         super().__init__(config)
+         assert self.config.activation == "silu"
+
+     def run(
+         self,
+         runner_input: DeepGemmRunnerInput,
+         quant_info: DeepGemmMoeQuantInfo,
+         running_state: dict,
+     ) -> DeepGemmRunnerOutput:
+
+         if not runner_input.use_masked_gemm:
+             hidden_states = self._run_contiguous_gemm(
+                 runner_input, quant_info, running_state
+             )
+         else:
+             hidden_states = self._run_masked_gemm(
+                 runner_input, quant_info, running_state
+             )
+         return DeepGemmRunnerOutput(hidden_states=hidden_states)
+
+     def _run_contiguous_gemm(
+         self,
+         runner_input: DeepGemmRunnerInput,
+         quant_info: DeepGemmMoeQuantInfo,
+         running_state: dict,
+     ) -> torch.Tensor:
+
+         from sglang.srt.layers.moe.ep_moe.kernels import tma_align_input_scale
+         from sglang.srt.layers.quantization.fp8_kernel import (
+             sglang_per_token_group_quant_fp8,
+         )
+
+         hidden_states = runner_input.hidden_states
+         hidden_states_scale = runner_input.hidden_states_scale
+         all_tokens = running_state["all_tokens"]
+         hidden_states_device = running_state["hidden_states_device"]
+         hidden_states_dtype = running_state["hidden_states_dtype"]
+         hidden_states_shape = running_state["hidden_states_shape"]
+         m_indices = runner_input.m_indices
+
+         N = quant_info.w13_weight.size(1)
+         K = hidden_states_shape[1]
+         scale_block_size = 128
+
+         w13_weight_fp8 = (
+             quant_info.w13_weight,
+             quant_info.w13_scale,
+         )
+         w2_weight_fp8 = (quant_info.w2_weight, quant_info.w2_scale)
+
+         gateup_output = torch.empty(
+             (all_tokens, N),
+             device=hidden_states_device,
+             dtype=torch.bfloat16,
+         )
+         if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+             hidden_states_scale = tma_align_input_scale(hidden_states_scale)
+         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
+             (hidden_states, hidden_states_scale),
+             w13_weight_fp8,
+             gateup_output,
+             m_indices,
+         )
+
+         dispose_tensor(hidden_states)
+         dispose_tensor(hidden_states_scale)
+
+         down_input = torch.empty(
+             (
+                 all_tokens,
+                 N // 2,
+             ),
+             device=gateup_output.device,
+             dtype=torch.bfloat16,
+         )
+         silu_and_mul(gateup_output.view(-1, N), down_input)
+         del gateup_output
+
+         down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
+             down_input,
+             scale_block_size,
+             column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+             scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+             scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+         )
+         del down_input
+
+         down_output = torch.empty(
+             (all_tokens, K),
+             device=hidden_states_device,
+             dtype=torch.bfloat16,
+         )
+         if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+             down_input_scale = tma_align_input_scale(down_input_scale)
+
+         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
+             (down_input_fp8, down_input_scale),
+             w2_weight_fp8,
+             down_output,
+             m_indices,
+         )
+
+         return down_output
+
+     def _run_masked_gemm(
+         self,
+         runner_input: DeepGemmRunnerInput,
+         quant_info: DeepGemmMoeQuantInfo,
+         running_state: dict,
+     ) -> torch.Tensor:
+
+         from sglang.srt.layers import deep_gemm_wrapper
+         from sglang.srt.layers.moe.ep_moe.kernels import (
+             silu_and_mul_masked_post_quant_fwd,
+         )
+
+         hidden_states = runner_input.hidden_states
+         hidden_states_scale = runner_input.hidden_states_scale
+         masked_m = runner_input.masked_m
+         expected_m = runner_input.expected_m
+
+         w13_weight = quant_info.w13_weight
+         w2_weight = quant_info.w2_weight
+         w13_scale = quant_info.w13_scale
+         w2_scale = quant_info.w2_scale
+
+         hidden_states_device = running_state["hidden_states_device"]
+
+         if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+             b, s_mn, s_k = hidden_states_scale.shape
+             assert (
+                 s_mn % 4 == 0 and s_k % 4 == 0
+             ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
+
+         # GroupGemm-0
+         if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+             hidden_states_scale = _cast_to_e8m0_with_rounding_up(hidden_states_scale)
+         else:
+             hidden_states_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
+                 hidden_states_scale
+             )
+
+         num_groups, m, k = hidden_states.shape
+         n = w13_weight.size(1)
+         gateup_output = torch.empty(
+             (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+         )
+         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+             (hidden_states, hidden_states_scale),
+             (w13_weight, w13_scale),
+             gateup_output,
+             masked_m,
+             expected_m,
+         )
+         dispose_tensor(hidden_states)
+         dispose_tensor(hidden_states_scale)
+
+         # Act
+         down_input = torch.empty(
+             (
+                 gateup_output.shape[0],
+                 gateup_output.shape[1],
+                 gateup_output.shape[2] // 2,
+             ),
+             device=hidden_states_device,
+             dtype=torch.float8_e4m3fn,
+         )
+         scale_block_size = 128
+         down_input_scale = torch.empty(
+             (
+                 gateup_output.shape[0],
+                 gateup_output.shape[1],
+                 gateup_output.shape[2] // 2 // scale_block_size,
+             ),
+             device=hidden_states_device,
+             dtype=torch.float32,
+         )
+         silu_and_mul_masked_post_quant_fwd(
+             gateup_output,
+             down_input,
+             down_input_scale,
+             scale_block_size,
+             masked_m,
+             scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+         )
+         del gateup_output
+
+         # GroupGemm-1
+         n = w2_weight.shape[1]
+
+         if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+             down_input_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
+                 down_input_scale
+             )
+
+         down_output = torch.empty(
+             (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+         )
+         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+             (down_input, down_input_scale),
+             (w2_weight, w2_scale),
+             down_output,
+             masked_m,
+             expected_m,
+         )
+
+         return down_output
+
+     @property
+     def runner_backend(self) -> MoeRunnerBackend:
+         return MoeRunnerBackend.DEEP_GEMM
+
+
+ @register_pre_permute("standard", "deep_gemm")
+ def pre_permute_standard_to_deep_gemm(
+     dispatch_output: StandardDispatchOutput,
+     quant_info: DeepGemmMoeQuantInfo,
+     runner_config: MoeRunnerConfig,
+     running_state: dict,
+ ) -> DeepGemmRunnerInput:
+
+     from sglang.srt.layers.moe.ep_moe.kernels import moe_ep_deepgemm_preprocess
+
+     hidden_states, topk_output = dispatch_output
+     topk_weights, topk_ids, _ = topk_output
+
+     hidden_states_shape = hidden_states.shape
+     hidden_states_dtype = hidden_states.dtype
+     hidden_states_device = hidden_states.device
+     hidden_states_ref = hidden_states
+
+     topk_weights, topk_ids = topk_weights, topk_ids
+
+     # PreReorder
+     masked_m, expected_m, src2dst, hidden_states, hidden_states_scale = (
+         moe_ep_deepgemm_preprocess(
+             topk_ids,
+             runner_config.num_local_experts,
+             hidden_states,
+             runner_config.top_k,
+             quant_info.block_shape,
+         )
+     )
+
+     dispose_tensor(hidden_states_ref)
+
+     running_state["topk_ids"] = topk_ids
+     running_state["topk_weights"] = topk_weights
+     running_state["hidden_states_shape"] = hidden_states_shape
+     running_state["hidden_states_dtype"] = hidden_states_dtype
+     running_state["hidden_states_device"] = hidden_states_device
+     running_state["src2dst"] = src2dst
+
+     return DeepGemmRunnerInput(
+         hidden_states=hidden_states,
+         hidden_states_scale=hidden_states_scale,
+         use_masked_gemm=True,
+         masked_m=masked_m,
+         expected_m=expected_m,
+     )
+
+
+ @register_post_permute("deep_gemm", "standard")
+ def post_permute_deep_gemm_to_standard(
+     runner_output: DeepGemmRunnerOutput,
+     quant_info: DeepGemmMoeQuantInfo,
+     runner_config: MoeRunnerConfig,
+     running_state: dict,
+ ) -> StandardCombineInput:
+     from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel
+     from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
+
+     hidden_states_shape = running_state["hidden_states_shape"]
+     hidden_states_dtype = running_state["hidden_states_dtype"]
+     hidden_states_device = running_state["hidden_states_device"]
+     src2dst = running_state["src2dst"]
+     topk_ids = running_state["topk_ids"]
+     topk_weights = running_state["topk_weights"]
+
+     output = torch.empty(
+         hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
+     )
+     post_reorder_triton_kernel[(hidden_states_shape[0],)](
+         runner_output.hidden_states,
+         output,
+         src2dst,
+         topk_ids,
+         topk_weights,
+         runner_config.top_k,
+         hidden_states_shape[1],
+         BLOCK_SIZE=512,
+     )
+
+     dispose_tensor(runner_output.hidden_states)
+
+     if runner_config.routed_scaling_factor is not None:
+         output *= runner_config.routed_scaling_factor
+
+     return StandardCombineInput(
+         hidden_states=output,
+     )
+
+
+ @register_pre_permute("deepep_ll", "deep_gemm")
+ def pre_permute_deepep_ll_to_deep_gemm(
+     dispatch_output: DeepEPLLDispatchOutput,
+     quant_info: DeepGemmMoeQuantInfo,
+     runner_config: MoeRunnerConfig,
+     running_state: dict,
+ ) -> DeepGemmRunnerInput:
+
+     hidden_states, hidden_states_scale, topk_ids, topk_weights, masked_m, expected_m = (
+         dispatch_output
+     )
+
+     running_state["topk_ids"] = topk_ids
+     running_state["topk_weights"] = topk_weights
+     running_state["hidden_states_shape"] = hidden_states.shape
+     running_state["hidden_states_dtype"] = hidden_states.dtype
+     running_state["hidden_states_device"] = hidden_states.device
+
+     return DeepGemmRunnerInput(
+         hidden_states=hidden_states,
+         hidden_states_scale=hidden_states_scale,
+         use_masked_gemm=True,
+         masked_m=masked_m,
+         expected_m=expected_m,
+     )
+
+
+ @register_post_permute("deep_gemm", "deepep_ll")
+ def post_permute_deep_gemm_to_deepep_ll(
+     runner_output: DeepGemmRunnerOutput,
+     quant_info: DeepGemmMoeQuantInfo,
+     runner_config: MoeRunnerConfig,
+     running_state: dict,
+ ) -> DeepEPLLCombineInput:
+
+     from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPLLCombineInput
+
+     return DeepEPLLCombineInput(
+         hidden_states=runner_output.hidden_states,
+         topk_ids=running_state["topk_ids"],
+         topk_weights=running_state["topk_weights"],
+     )
+
+
+ @register_pre_permute("deepep_normal", "deep_gemm")
+ def pre_permute_deepep_normal_to_deep_gemm(
+     dispatch_output: DeepEPNormalDispatchOutput,
+     quant_info: DeepGemmMoeQuantInfo,
+     runner_config: MoeRunnerConfig,
+     running_state: dict,
+ ) -> DeepGemmRunnerInput:
+
+     from sglang.srt.layers.moe.ep_moe.kernels import ep_scatter
+
+     (
+         hidden_states,
+         hidden_states_scale,
+         topk_ids,
+         topk_weights,
+         num_recv_tokens_per_expert,
+     ) = dispatch_output
+     assert runner_config.activation == "silu"
+
+     all_tokens = sum(num_recv_tokens_per_expert)
+     running_state["all_tokens"] = all_tokens
+
+     K = hidden_states.shape[1]
+
+     hidden_states_shape = hidden_states.shape
+     hidden_states_device = hidden_states.device
+     hidden_states_dtype = hidden_states.dtype
+
+     running_state["hidden_states_shape"] = hidden_states_shape
+     running_state["hidden_states_device"] = hidden_states_device
+     running_state["hidden_states_dtype"] = hidden_states_dtype
+     running_state["topk_ids"] = topk_ids
+     running_state["topk_weights"] = topk_weights
+
+     input_tensor = torch.empty(
+         (all_tokens, K),
+         device=hidden_states.device,
+         dtype=hidden_states.dtype,
+     )
+     if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+         # TODO check whether need `zeros`
+         input_tensor_scale = torch.zeros(
+             (ceil_div(K // 128, 4), all_tokens),
+             device=hidden_states.device,
+             dtype=torch.int,
+         ).transpose(0, 1)
+     else:
+         input_tensor_scale = torch.empty(
+             (all_tokens, K // 128),
+             device=hidden_states.device,
+             dtype=torch.float32,
+         )
+     m_indices = torch.empty(all_tokens, device=hidden_states.device, dtype=torch.int32)
+     output_index = torch.empty_like(topk_ids)
+
+     if get_offloader().forbid_copy_engine_usage:
+         num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
+             num_recv_tokens_per_expert
+         )
+     else:
+         num_recv_tokens_per_expert_gpu = torch.tensor(
+             num_recv_tokens_per_expert,
+             dtype=torch.int32,
+             pin_memory=True,
+             device="cpu",
+         ).cuda(non_blocking=True)
+     expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
+
+     ep_scatter(
+         hidden_states,
+         hidden_states_scale,
+         topk_ids,
+         num_recv_tokens_per_expert_gpu,
+         expert_start_loc,
+         input_tensor,
+         input_tensor_scale,
+         m_indices,
+         output_index,
+         scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+     )
+     dispose_tensor(hidden_states)
+     dispose_tensor(hidden_states_scale)
+
+     running_state["output_index"] = output_index
+
+     return DeepGemmRunnerInput(
+         hidden_states=input_tensor,
+         hidden_states_scale=input_tensor_scale,
+         use_masked_gemm=False,
+         m_indices=m_indices,
+     )
+
+
+ @register_post_permute("deep_gemm", "deepep_normal")
+ def post_permute_deep_gemm_to_deepep_normal(
+     runner_output: DeepGemmRunnerOutput,
+     quant_info: DeepGemmMoeQuantInfo,
+     runner_config: MoeRunnerConfig,
+     running_state: dict,
+ ) -> DeepEPNormalCombineInput:
+
+     from sglang.srt.layers.moe.ep_moe.kernels import ep_gather
+     from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPNormalCombineInput
+
+     hidden_states = runner_output.hidden_states
+     topk_ids = running_state["topk_ids"]
+     topk_weights = running_state["topk_weights"]
+     output_index = running_state["output_index"]
+
+     gather_out = torch.empty(
+         running_state["hidden_states_shape"],
+         device=running_state["hidden_states_device"],
+         dtype=torch.bfloat16,
+     )
+     ep_gather(hidden_states, topk_ids, topk_weights, output_index, gather_out)
+
+     return DeepEPNormalCombineInput(
+         hidden_states=gather_out,
+         topk_ids=running_state["topk_ids"],
+         topk_weights=running_state["topk_weights"],
+     )
sglang/srt/layers/moe/moe_runner/runner.py
@@ -9,7 +9,9 @@ from sglang.srt.layers.moe.moe_runner.base import (
      MoeRunnerConfig,
      PermuteMethodPool,
  )
+ from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmRunnerCore
  from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
+ from sglang.srt.layers.moe.moe_runner.triton_kernels import TritonKernelsRunnerCore
  from sglang.srt.layers.moe.utils import get_moe_a2a_backend

  if TYPE_CHECKING:
@@ -30,6 +32,10 @@ class MoeRunner:

          if runner_backend.is_triton():
              self.runner_core = TritonRunnerCore(config)
+         elif runner_backend.is_triton_kernels():
+             self.runner_core = TritonKernelsRunnerCore(config)
+         elif runner_backend.is_deep_gemm():
+             self.runner_core = DeepGemmRunnerCore(config)
          else:
              raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")

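These two hunks wire the new core into the generic MoE runner: `MoeRunner` now instantiates `DeepGemmRunnerCore` (or `TritonKernelsRunnerCore`) from the backend enum, while the `@register_pre_permute` / `@register_post_permute` decorators in `deep_gemm.py` populate a pool keyed by (dispatch format, backend) string pairs. A minimal sketch of how the three stages compose, with simplified stand-in signatures, since the real plumbing lives in `moe_runner/base.py`, which this diff does not show:

# Sketch only: simplified stand-ins for the MoeRunner plumbing in
# sglang.srt.layers.moe.moe_runner.base. Keys are plain strings,
# mirroring the decorator arguments used in deep_gemm.py.
from typing import Any, Callable, Dict, Tuple

_PRE_PERMUTE: Dict[Tuple[str, str], Callable] = {}   # (dispatch_format, backend)
_POST_PERMUTE: Dict[Tuple[str, str], Callable] = {}  # (backend, combine_format)

def register_pre_permute(dispatch_format: str, backend: str):
    def decorator(fn: Callable) -> Callable:
        _PRE_PERMUTE[(dispatch_format, backend)] = fn
        return fn
    return decorator

def register_post_permute(backend: str, combine_format: str):
    def decorator(fn: Callable) -> Callable:
        _POST_PERMUTE[(backend, combine_format)] = fn
        return fn
    return decorator

def run_moe(core, dispatch_output, quant_info, runner_config, fmt: str) -> Any:
    # running_state is scratch space shared by all three stages.
    running_state: Dict[str, Any] = {}
    runner_input = _PRE_PERMUTE[(fmt, "deep_gemm")](
        dispatch_output, quant_info, runner_config, running_state
    )
    runner_output = core.run(runner_input, quant_info, running_state)
    return _POST_PERMUTE[("deep_gemm", fmt)](
        runner_output, quant_info, runner_config, running_state
    )

The `running_state` dict is the glue: the pre-permute hook stashes shapes, top-k ids/weights, and scatter indices there so the matching post-permute hook can restore the original token order after the grouped GEMMs.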
@@ -51,7 +51,9 @@ elif _is_hip:


  if _is_cuda or _is_hip:
-     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+     from sgl_kernel import (  # noqa: F401
+         moe_align_block_size as sgl_moe_align_block_size,
+     )


  @dataclass
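
A small lint-driven cleanup closes out the diff: the `moe_align_block_size` import is wrapped in parentheses and tagged `# noqa: F401`, pyflakes' "imported but unused" code, presumably because the `sgl_moe_align_block_size` alias is re-exported or referenced only on some code paths. The pattern in general (placeholder names):

# Illustration of why F401 needs silencing: flake8/pyflakes flags an import
# that is never referenced in the module, even when it is kept deliberately
# so downstream code can import the alias from here.
from math import sqrt as _reexported_sqrt  # noqa: F401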