sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
- from typing import TYPE_CHECKING, Optional, Union
4
+ from contextlib import nullcontext
5
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
5
6
 
6
7
  import torch
8
+ import triton
9
+ import triton.language as tl
7
10
 
8
11
  from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
9
12
  from sglang.srt.layers.moe import (
@@ -29,13 +32,26 @@ from sglang.srt.layers.quantization.fp8_kernel import (
29
32
  is_fp8_fnuz,
30
33
  sglang_per_token_group_quant_fp8,
31
34
  )
35
+ from sglang.srt.layers.quantization.modelopt_quant import (
36
+ CUTEDSL_MOE_NVFP4_DISPATCH,
37
+ ModelOptNvFp4FusedMoEMethod,
38
+ )
32
39
  from sglang.srt.managers.schedule_batch import global_server_args_dict
33
40
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
34
- from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
41
+ from sglang.srt.offloader import get_offloader
42
+ from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
43
+ from sglang.srt.utils import (
44
+ ceil_div,
45
+ dispose_tensor,
46
+ get_bool_env_var,
47
+ get_int_env_var,
48
+ is_cuda,
49
+ is_hip,
50
+ is_npu,
51
+ )
35
52
 
36
53
  if TYPE_CHECKING:
37
54
  from sglang.srt.layers.moe.token_dispatcher import (
38
- AscendDeepEPLLOutput,
39
55
  DeepEPLLOutput,
40
56
  DeepEPNormalOutput,
41
57
  DispatchOutput,
@@ -444,9 +460,20 @@ class DeepEPMoE(EPMoE):
444
460
  topk_idx=topk_idx,
445
461
  topk_weights=topk_weights,
446
462
  forward_batch=forward_batch,
463
+ input_global_scale=(
464
+ self.w13_input_scale_quant
465
+ if isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
466
+ and self.quant_method.enable_flashinfer_cutedsl_moe
467
+ and CUTEDSL_MOE_NVFP4_DISPATCH
468
+ else None
469
+ ),
447
470
  )
448
471
 
449
- def moe_impl(self, dispatch_output: DispatchOutput):
472
+ def moe_impl(
473
+ self,
474
+ dispatch_output: DispatchOutput,
475
+ down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None,
476
+ ):
450
477
  from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
451
478
 
452
479
  if _use_aiter:
@@ -454,12 +481,16 @@ class DeepEPMoE(EPMoE):
454
481
  # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
455
482
  return self.forward_aiter(dispatch_output)
456
483
  if _is_npu:
457
- assert DispatchOutputChecker.format_is_ascent_ll(dispatch_output)
484
+ assert DispatchOutputChecker.format_is_deepep(dispatch_output)
458
485
  return self.forward_npu(dispatch_output)
459
486
  if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
460
487
  assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
461
488
  return self.forward_deepgemm_contiguous(dispatch_output)
462
489
  elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
490
+ if get_moe_runner_backend().is_flashinfer_cutedsl():
491
+ return self.forward_flashinfer_cutedsl(
492
+ dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
493
+ )
463
494
  assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
464
495
  return self.forward_deepgemm_masked(dispatch_output)
465
496
  else:
@@ -473,12 +504,14 @@ class DeepEPMoE(EPMoE):
473
504
  topk_idx: torch.Tensor,
474
505
  topk_weights: torch.Tensor,
475
506
  forward_batch: ForwardBatch,
507
+ overlap_args: Optional[Dict[str, Any]] = None,
476
508
  ):
477
509
  return self.deepep_dispatcher.combine(
478
510
  hidden_states=hidden_states,
479
511
  topk_idx=topk_idx,
480
512
  topk_weights=topk_weights,
481
513
  forward_batch=forward_batch,
514
+ overlap_args=overlap_args,
482
515
  )
483
516
 
484
517
  def forward_aiter(
@@ -534,6 +567,24 @@ class DeepEPMoE(EPMoE):
534
567
  N = self.w13_weight.size(1)
535
568
  scale_block_size = 128
536
569
 
570
+ # TODO also unify other branches (e.g. `EPMoE.forward_deepgemm` sets the field on forward pass)
571
+ w13_weight_fp8 = (
572
+ self.w13_weight,
573
+ (
574
+ self.w13_weight_scale_inv
575
+ if self.use_block_quant
576
+ else self.w13_weight_scale
577
+ ),
578
+ )
579
+ w2_weight_fp8 = (
580
+ self.w2_weight,
581
+ (
582
+ self.w2_weight_scale_inv
583
+ if self.use_block_quant
584
+ else self.w2_weight_scale
585
+ ),
586
+ )
587
+
537
588
  hidden_states_fp8_shape = hidden_states_fp8.shape
538
589
  hidden_states_fp8_device = hidden_states_fp8.device
539
590
  hidden_states_fp8_dtype = hidden_states_fp8.dtype
@@ -564,12 +615,17 @@ class DeepEPMoE(EPMoE):
564
615
  )
565
616
  output_index = torch.empty_like(topk_idx)
566
617
 
567
- num_recv_tokens_per_expert_gpu = torch.tensor(
568
- num_recv_tokens_per_expert,
569
- dtype=torch.int32,
570
- pin_memory=True,
571
- device="cpu",
572
- ).cuda(non_blocking=True)
618
+ if get_offloader().forbid_copy_engine_usage:
619
+ num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
620
+ num_recv_tokens_per_expert
621
+ )
622
+ else:
623
+ num_recv_tokens_per_expert_gpu = torch.tensor(
624
+ num_recv_tokens_per_expert,
625
+ dtype=torch.int32,
626
+ pin_memory=True,
627
+ device="cpu",
628
+ ).cuda(non_blocking=True)
573
629
  expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
574
630
 
575
631
  ep_scatter(
@@ -594,7 +650,7 @@ class DeepEPMoE(EPMoE):
594
650
  if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
595
651
  input_tensor[1] = tma_align_input_scale(input_tensor[1])
596
652
  deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
597
- input_tensor, self.w13_weight_fp8, gateup_output, m_indices
653
+ input_tensor, w13_weight_fp8, gateup_output, m_indices
598
654
  )
599
655
  del input_tensor
600
656
  down_input = torch.empty(
@@ -624,7 +680,7 @@ class DeepEPMoE(EPMoE):
624
680
  down_input_scale = tma_align_input_scale(down_input_scale)
625
681
  deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
626
682
  (down_input_fp8, down_input_scale),
627
- self.w2_weight_fp8,
683
+ w2_weight_fp8,
628
684
  down_output,
629
685
  m_indices,
630
686
  )
@@ -639,6 +695,24 @@ class DeepEPMoE(EPMoE):
639
695
 
640
696
  return gather_out
641
697
 
698
+ def forward_flashinfer_cutedsl(
699
+ self,
700
+ dispatch_output: DeepEPLLOutput,
701
+ down_gemm_overlap_args: Optional[DownGemmOverlapArgs],
702
+ ):
703
+ hidden_states, _, _, masked_m, _ = dispatch_output
704
+ assert self.quant_method is not None
705
+ assert self.moe_runner_config.activation == "silu"
706
+
707
+ output = self.quant_method.apply_without_routing_weights(
708
+ layer=self,
709
+ x=hidden_states,
710
+ masked_m=masked_m,
711
+ moe_runner_config=self.moe_runner_config,
712
+ down_gemm_overlap_args=down_gemm_overlap_args,
713
+ )
714
+ return output
715
+
642
716
  def forward_deepgemm_masked(
643
717
  self,
644
718
  dispatch_output: DeepEPLLOutput,
@@ -718,66 +792,176 @@ class DeepEPMoE(EPMoE):
718
792
 
719
793
  def forward_npu(
720
794
  self,
721
- dispatch_output: DeepEPLLOutput,
795
+ dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
722
796
  ):
723
- if TYPE_CHECKING:
724
- assert isinstance(dispatch_output, AscendDeepEPLLOutput)
725
- hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
726
797
  assert self.quant_method is not None
727
798
  assert self.moe_runner_config.activation == "silu"
728
799
 
800
+ import torch_npu
801
+
802
+ from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
803
+
729
804
  # NOTE: Ascend's Dispatch & Combine does not support FP16
730
805
  output_dtype = torch.bfloat16
806
+ group_list_type = 1
731
807
 
732
- pertoken_scale = hidden_states[1]
733
- hidden_states = hidden_states[0]
808
+ def _forward_normal(dispatch_output: DeepEPNormalOutput):
809
+ if TYPE_CHECKING:
810
+ assert isinstance(dispatch_output, DeepEPNormalOutput)
811
+ hidden_states, _, _, num_recv_tokens_per_expert = dispatch_output
734
812
 
735
- group_list_type = 1
736
- seg_indptr = seg_indptr.to(torch.int64)
813
+ if isinstance(hidden_states, tuple):
814
+ per_token_scale = hidden_states[1]
815
+ hidden_states = hidden_states[0]
737
816
 
738
- import torch_npu
817
+ group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
818
+ hidden_states.device
819
+ )
820
+ if self.w13_weight.dtype != torch.int8:
821
+ # gmm1: gate_up_proj
822
+ hidden_states = torch_npu.npu_grouped_matmul(
823
+ x=[hidden_states],
824
+ weight=[self.w13_weight.permute(0, 2, 1)],
825
+ # per_token_scale=[per_token_scale],
826
+ split_item=2,
827
+ group_list_type=group_list_type,
828
+ group_type=0,
829
+ group_list=group_list,
830
+ output_dtype=output_dtype,
831
+ )[0]
832
+ hidden_states = torch_npu.npu_swiglu(hidden_states)
833
+ # gmm2: down_proj
834
+ hidden_states = torch_npu.npu_grouped_matmul(
835
+ x=[hidden_states],
836
+ weight=[self.w2_weight.permute(0, 2, 1)],
837
+ split_item=2,
838
+ group_list_type=group_list_type,
839
+ group_type=0,
840
+ group_list=group_list,
841
+ output_dtype=output_dtype,
842
+ )[0]
843
+ else:
844
+ if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"):
845
+ hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
846
+ hidden_states
847
+ )
848
+ # gmm1: gate_up_proj
849
+ hidden_states = torch_npu.npu_grouped_matmul(
850
+ x=[hidden_states],
851
+ weight=[self.w13_weight],
852
+ scale=[self.w13_weight_scale.to(output_dtype)],
853
+ per_token_scale=[per_token_scale],
854
+ split_item=2,
855
+ group_list_type=group_list_type,
856
+ group_type=0,
857
+ group_list=group_list,
858
+ output_dtype=output_dtype,
859
+ )[0]
860
+
861
+ # act_fn: swiglu
862
+ hidden_states = torch_npu.npu_swiglu(hidden_states)
863
+ hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
864
+ hidden_states
865
+ )
739
866
 
740
- # gmm1: gate_up_proj
741
- hidden_states = torch_npu.npu_grouped_matmul(
742
- x=[hidden_states],
743
- weight=[self.w13_weight],
744
- split_item=2,
745
- group_list_type=group_list_type,
746
- group_type=0,
747
- group_list=seg_indptr,
748
- output_dtype=torch.int32,
749
- )[0]
750
-
751
- # act_fn: swiglu
752
- hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
753
- x=hidden_states,
754
- weight_scale=self.w13_weight_scale.to(torch.float32),
755
- activation_scale=pertoken_scale,
756
- bias=None,
757
- quant_scale=None,
758
- quant_offset=None,
759
- group_index=seg_indptr,
760
- activate_left=True,
761
- quant_mode=1,
762
- )
763
-
764
- # gmm2: down_proj
765
- hidden_states = torch_npu.npu_grouped_matmul(
766
- x=[hidden_states],
767
- weight=[self.w2_weight],
768
- scale=[self.w2_weight_scale.to(output_dtype)],
769
- per_token_scale=[swiglu_out_scale],
770
- split_item=2,
771
- group_list_type=group_list_type,
772
- group_type=0,
773
- group_list=seg_indptr,
774
- output_dtype=output_dtype,
775
- )[0]
867
+ # gmm2: down_proj
868
+ hidden_states = torch_npu.npu_grouped_matmul(
869
+ x=[hidden_states],
870
+ weight=[self.w2_weight],
871
+ scale=[self.w2_weight_scale.to(output_dtype)],
872
+ per_token_scale=[swiglu_out_scale],
873
+ split_item=2,
874
+ group_list_type=group_list_type,
875
+ group_type=0,
876
+ group_list=group_list,
877
+ output_dtype=output_dtype,
878
+ )[0]
776
879
 
777
- return hidden_states
880
+ return hidden_states
881
+
882
+ def _forward_ll(dispatch_output: DeepEPLLOutput):
883
+ if TYPE_CHECKING:
884
+ assert isinstance(dispatch_output, DeepEPLLOutput)
885
+ hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
886
+
887
+ if isinstance(hidden_states, tuple):
888
+ per_token_scale = hidden_states[1]
889
+ hidden_states = hidden_states[0]
890
+
891
+ group_list = group_list.to(torch.int64)
892
+
893
+ if self.w13_weight.dtype != torch.int8:
894
+ # gmm1: gate_up_proj
895
+ hidden_states = torch_npu.npu_grouped_matmul(
896
+ x=[hidden_states],
897
+ weight=[self.w13_weight.permute(0, 2, 1)],
898
+ # per_token_scale=[per_token_scale],
899
+ split_item=2,
900
+ group_list_type=group_list_type,
901
+ group_type=0,
902
+ group_list=group_list,
903
+ output_dtype=output_dtype,
904
+ )[0]
905
+ hidden_states = torch_npu.npu_swiglu(hidden_states)
906
+ # gmm2: down_proj
907
+ hidden_states = torch_npu.npu_grouped_matmul(
908
+ x=[hidden_states],
909
+ weight=[self.w2_weight.permute(0, 2, 1)],
910
+ split_item=2,
911
+ group_list_type=group_list_type,
912
+ group_type=0,
913
+ group_list=group_list,
914
+ output_dtype=output_dtype,
915
+ )[0]
916
+ else:
917
+ # gmm1: gate_up_proj
918
+ hidden_states = torch_npu.npu_grouped_matmul(
919
+ x=[hidden_states],
920
+ weight=[self.w13_weight],
921
+ split_item=2,
922
+ group_list_type=group_list_type,
923
+ group_type=0,
924
+ group_list=group_list,
925
+ output_dtype=torch.int32,
926
+ )[0]
927
+
928
+ # act_fn: swiglu
929
+ hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
930
+ x=hidden_states,
931
+ weight_scale=self.w13_weight_scale.to(torch.float32),
932
+ activation_scale=per_token_scale,
933
+ bias=None,
934
+ quant_scale=None,
935
+ quant_offset=None,
936
+ group_index=group_list,
937
+ activate_left=True,
938
+ quant_mode=1,
939
+ )
940
+
941
+ # gmm2: down_proj
942
+ hidden_states = torch_npu.npu_grouped_matmul(
943
+ x=[hidden_states],
944
+ weight=[self.w2_weight],
945
+ scale=[self.w2_weight_scale.to(output_dtype)],
946
+ per_token_scale=[swiglu_out_scale],
947
+ split_item=2,
948
+ group_list_type=group_list_type,
949
+ group_type=0,
950
+ group_list=group_list,
951
+ output_dtype=output_dtype,
952
+ )[0]
778
953
 
954
+ return hidden_states
779
955
 
780
- def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
956
+ if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
957
+ return _forward_normal(dispatch_output)
958
+ elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
959
+ return _forward_ll(dispatch_output)
960
+ else:
961
+ raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")
962
+
963
+
964
+ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
781
965
  if get_moe_a2a_backend().is_deepep():
782
966
  return DeepEPMoE
783
967
 
@@ -790,8 +974,7 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
790
974
  return FusedMoE
791
975
  try:
792
976
  # Check the quantization argument directly
793
- quantization = global_server_args_dict.get("quantization")
794
- if quantization == "modelopt_fp4":
977
+ if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
795
978
  from sglang.srt.layers.moe.fused_moe_triton.layer import (
796
979
  FlashInferFP4MoE,
797
980
  )
@@ -800,10 +983,20 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
800
983
  except:
801
984
  pass
802
985
 
803
- if should_use_flashinfer_trtllm_moe():
986
+ if should_use_flashinfer_trtllm_moe() and quant_config is not None:
987
+ # FIXME: FlashInferFusedMoE only supports fp8 quant now
804
988
  return FlashInferFusedMoE
805
989
  if get_moe_runner_backend().is_flashinfer_cutlass():
806
990
  return FusedMoE
807
991
  if get_moe_expert_parallel_world_size() > 1:
808
992
  return EPMoE
809
993
  return FusedMoE
994
+
995
+
996
def copy_list_to_gpu_no_ce(arr: List[int]):
    """Copy a list of ints into a new int32 tensor on the "cuda" device.

    The list is first turned into an int32 CPU tensor, a same-shaped tensor
    is allocated on "cuda", and sgl_kernel's ``copy_to_gpu_no_ce`` performs
    the transfer.

    Args:
        arr: Values to copy.

    Returns:
        The int32 CUDA tensor that received the values.
    """
    # Function-scope import, so sgl_kernel is only loaded when this is called.
    from sgl_kernel.elementwise import copy_to_gpu_no_ce

    host_values = torch.tensor(arr, dtype=torch.int32, device="cpu")
    device_values = torch.empty_like(host_values, device="cuda")
    copy_to_gpu_no_ce(host_values, device_values)
    return device_values
@@ -0,0 +1,183 @@
1
+ from typing import Any, Dict, Optional, Union
2
+
3
+ import torch
4
+ from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
5
+ from sgl_kernel.gemm import (
6
+ scaled_fp4_grouped_quant,
7
+ silu_and_mul_scaled_fp4_grouped_quant,
8
+ )
9
+
10
+
11
def get_cute_dtype(input: torch.Tensor) -> str:
    """Return the dtype name string for ``input``'s floating-point dtype.

    Supported dtypes are bfloat16, float16 and float32; the returned string
    is the dtype's name without the ``torch.`` prefix.

    Raises:
        ValueError: If ``input.dtype`` is not one of the supported dtypes.
    """
    names_by_dtype = {
        torch.bfloat16: "bfloat16",
        torch.float16: "float16",
        torch.float32: "float32",
    }
    cute_name = names_by_dtype.get(input.dtype)
    if cute_name is None:
        raise ValueError(f"Unsupported cute dtype {input.dtype}")
    return cute_name
20
+
21
+
22
+ def flashinfer_cutedsl_moe_masked(
23
+ hidden_states: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
24
+ input_global_scale: torch.Tensor,
25
+ w1: torch.Tensor,
26
+ w1_blockscale: torch.Tensor,
27
+ w1_alpha,
28
+ w2: torch.Tensor,
29
+ a2_global_scale: torch.Tensor,
30
+ w2_blockscale: torch.Tensor,
31
+ w2_alpha,
32
+ masked_m: torch.Tensor,
33
+ down_sm_count: Optional[int] = None,
34
+ down_signals: Optional[torch.Tensor] = None,
35
+ down_start_event: Optional[torch.cuda.Event] = None,
36
+ ):
37
+ """
38
+ Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
39
+ kernels.
40
+
41
+ Args:
42
+ hidden_states: Either of the following case
43
+ * torch.Tensor: [num_experts, m, k], bf16
44
+ * tuple[torch.Tensor, torch.Tensor]: [num_experts, m, k // 2], uint8, [num_experts, m, k // 16], float8_e4m3fn
45
+ input_global_scale (torch.Tensor): (l,)
46
+ w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
47
+ w1_blockscale (torch.Tensor): blockscale factors, e4m3,
48
+ w1_alpha (torch.Tensor): (l,)
49
+ w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8
50
+ a2_global_scale (torch.Tensor): (l,)
51
+ w2_blockscale (torch.Tensor): blockscale factors, e4m3,
52
+ w2_alpha (torch.Tensor): (l,)
53
+ masked_m (torch.Tensor): Masked dimension indices
54
+
55
+ Notes:
56
+ - Assumes max(masked_m) == m.
57
+ """
58
+
59
+ # === Assertions on dtypes ===
60
+ assert w1.dtype == torch.uint8, f"w1 must be uint8 (fp4 packed), got {w1.dtype}"
61
+ assert (
62
+ w1_blockscale.dtype == torch.float8_e4m3fn
63
+ ), f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
64
+ assert (
65
+ w1_alpha.dtype == torch.float32
66
+ ), f"w1_alpha must be float32, got {w1_alpha.dtype}"
67
+ assert w2.dtype == torch.uint8, f"w2 must be uint8 (fp4 packed), got {w2.dtype}"
68
+ assert (
69
+ a2_global_scale.dtype == torch.float32
70
+ ), f"a2_global_scale must be float32, got {a2_global_scale.dtype}"
71
+ assert (
72
+ w2_blockscale.dtype == torch.float8_e4m3fn
73
+ ), f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}"
74
+ assert (
75
+ w2_alpha.dtype == torch.float32
76
+ ), f"w2_alpha must be float32, got {w2_alpha.dtype}"
77
+
78
+ # === Assertions on shapes ===
79
+ n = w2.shape[-1] * 2 # intermediate dimension
80
+
81
+ if isinstance(hidden_states, tuple):
82
+ assert (
83
+ input_global_scale is None
84
+ ), "input_global_scale is needed when input needs quant"
85
+
86
+ a_q = hidden_states[0].view(torch.uint8)
87
+ a_q_sf = hidden_states[1].view(torch.float8_e4m3fn)
88
+ m, k_by_2, num_experts = a_q.shape
89
+ k = k_by_2 * 2
90
+ else:
91
+ num_experts, m, k = hidden_states.shape
92
+
93
+ assert (
94
+ input_global_scale.dtype == torch.float32
95
+ ), f"input_global_scale must be float32, got {input_global_scale.dtype}"
96
+ assert input_global_scale.shape == (
97
+ num_experts,
98
+ ), f"input_global_scale must be (l,), got {input_global_scale.shape}"
99
+
100
+ a_q, a_q_sf = scaled_fp4_grouped_quant(
101
+ hidden_states,
102
+ input_global_scale,
103
+ masked_m,
104
+ )
105
+
106
+ assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
107
+ assert (
108
+ w1.shape[-1] * 2 == k
109
+ ), f"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}"
110
+ assert w2.shape[-2:] == (
111
+ k,
112
+ n // 2,
113
+ ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n//2)}"
114
+ assert w1_alpha.shape == (
115
+ num_experts,
116
+ ), f"w1_alpha must be (l,), got {w1_alpha.shape}"
117
+ assert a2_global_scale.shape == (
118
+ num_experts,
119
+ ), f"a2_global_scale must be (l,), got {a2_global_scale.shape}"
120
+ assert w2_alpha.shape == (
121
+ num_experts,
122
+ ), f"w2_alpha must be (l,), got {w2_alpha.shape}"
123
+
124
+ # TODO(kaixih@nvidia): dtype should be based on inputs.
125
+ gateup_output = torch.empty(
126
+ (num_experts, m, n * 2), dtype=torch.bfloat16, device=a_q.device
127
+ )
128
+ gateup_output = gateup_output.permute(1, 2, 0) # requirement of kernel
129
+ sf_vec_size = 16
130
+ assert a_q_sf.dtype == torch.float8_e4m3fn
131
+ assert a_q.dtype == torch.uint8
132
+ ab_dtype = "float4_e2m1fn"
133
+ sf_dtype = "float8_e4m3fn"
134
+ c_dtype = "bfloat16"
135
+
136
+ # Gemm1
137
+ grouped_gemm_nt_masked(
138
+ (a_q, a_q_sf),
139
+ (w1.permute(1, 2, 0), w1_blockscale),
140
+ gateup_output,
141
+ masked_m,
142
+ ab_dtype=ab_dtype,
143
+ sf_dtype=sf_dtype,
144
+ c_dtype=c_dtype,
145
+ sf_vec_size=sf_vec_size,
146
+ alpha=w1_alpha.view(1, 1, num_experts),
147
+ alpha_dtype=get_cute_dtype(w1_alpha),
148
+ ) # in logical [m, n, l]
149
+
150
+ # SILU and quantization
151
+ diq, diq_sf = silu_and_mul_scaled_fp4_grouped_quant(
152
+ gateup_output.permute(2, 0, 1),
153
+ a2_global_scale,
154
+ masked_m,
155
+ )
156
+
157
+ if down_start_event is not None:
158
+ down_start_event.record()
159
+
160
+ # Gemm2
161
+ out = torch.empty((num_experts, m, k), dtype=torch.bfloat16, device=a_q.device)
162
+ out = out.permute(1, 2, 0) # requirement of kernel
163
+ grouped_gemm_nt_masked(
164
+ (diq, diq_sf),
165
+ (w2.permute(1, 2, 0), w2_blockscale),
166
+ out,
167
+ masked_m,
168
+ ab_dtype=ab_dtype,
169
+ sf_dtype=sf_dtype,
170
+ c_dtype=c_dtype,
171
+ sf_vec_size=sf_vec_size,
172
+ alpha=w2_alpha.view(1, 1, num_experts),
173
+ alpha_dtype=get_cute_dtype(w2_alpha),
174
+ **(
175
+ dict(
176
+ sm_count=down_sm_count,
177
+ dst_signals=down_signals,
178
+ )
179
+ if down_sm_count is not None or down_signals is not None
180
+ else {}
181
+ ),
182
+ ) # in logical [m, k, l]
183
+ return out.permute(2, 0, 1)
@@ -8,16 +8,18 @@ from torch.nn import functional as F
8
8
 
9
9
  from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
10
10
  from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
11
+ from sglang.srt.layers.moe.token_dispatcher import StandardDispatchOutput
11
12
  from sglang.srt.layers.moe.topk import StandardTopKOutput
12
13
 
13
14
 
14
15
  def fused_moe_forward_native(
15
16
  layer: torch.nn.Module,
16
- x: torch.Tensor,
17
- topk_output: StandardTopKOutput,
18
- moe_runner_config: MoeRunnerConfig,
17
+ dispatch_output: StandardDispatchOutput,
19
18
  ) -> torch.Tensor:
20
19
 
20
+ x, topk_output = dispatch_output
21
+ moe_runner_config = layer.moe_runner_config
22
+
21
23
  if moe_runner_config.apply_router_weight_on_input:
22
24
  raise NotImplementedError()
23
25