sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (408) hide show
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -10,19 +10,21 @@ from typing import TYPE_CHECKING, Optional, Union
10
10
 
11
11
  import torch
12
12
  import triton
13
+ import triton.language as tl
13
14
 
14
15
  from sglang.srt.layers.attention.flashinfer_mla_backend import (
15
16
  FlashInferMLAAttnBackend,
16
17
  FlashInferMLAMultiStepDraftBackend,
17
18
  )
18
19
  from sglang.srt.layers.attention.utils import (
19
- TRITON_PAD_NUM_PAGE_PER_BLOCK,
20
20
  create_flashmla_kv_indices_triton,
21
+ get_num_page_per_block_flashmla,
21
22
  )
22
23
  from sglang.srt.layers.dp_attention import get_attention_tp_size
23
- from sglang.srt.managers.schedule_batch import global_server_args_dict
24
24
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
25
+ from sglang.srt.server_args import get_global_server_args
25
26
  from sglang.srt.utils import is_cuda, is_flashinfer_available
27
+ from sglang.srt.utils.common import cached_triton_kernel
26
28
 
27
29
  if is_flashinfer_available():
28
30
  import flashinfer
@@ -48,6 +50,153 @@ DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB
48
50
  # compute the LCM with other padding constraints.
49
51
  TRTLLM_BLOCK_CONSTRAINT = 128
50
52
 
53
+
54
+ @cached_triton_kernel(lambda _, kwargs: (kwargs["BLOCK_SIZE"]))
55
+ @triton.jit
56
+ def pad_draft_extend_query_kernel(
57
+ q_ptr, # Input query tensor [total_seq_len, num_heads, head_dim]
58
+ padded_q_ptr, # Output padded query tensor [batch_size, max_seq_len, num_heads, head_dim]
59
+ seq_lens_q_ptr, # Sequence lengths for each sequence [batch_size]
60
+ cumsum_ptr, # Cumulative sum of accept lengths [batch_size + 1]
61
+ batch_size,
62
+ max_seq_len,
63
+ num_heads,
64
+ head_dim,
65
+ BLOCK_SIZE: tl.constexpr,
66
+ ):
67
+ """Triton kernel for padding draft extended query tensor with parallelized head and dim processing."""
68
+ # Use 3D program IDs: (batch_seq, head_block, dim_block)
69
+ batch_seq_pid = tl.program_id(0)
70
+ head_pid = tl.program_id(1)
71
+ dim_pid = tl.program_id(2)
72
+
73
+ batch_id = batch_seq_pid // max_seq_len
74
+ seq_pos = batch_seq_pid % max_seq_len
75
+
76
+ if batch_id >= batch_size:
77
+ return
78
+
79
+ # Load accept length for this batch
80
+ seq_len = tl.load(seq_lens_q_ptr + batch_id)
81
+
82
+ if seq_pos >= seq_len:
83
+ return
84
+
85
+ # Load cumulative sum to get start position in input tensor
86
+ input_start = tl.load(cumsum_ptr + batch_id)
87
+ input_pos = input_start + seq_pos
88
+
89
+ # Calculate head and dim block ranges
90
+ head_start = head_pid * BLOCK_SIZE
91
+ head_end = tl.minimum(head_start + BLOCK_SIZE, num_heads)
92
+ head_mask = tl.arange(0, BLOCK_SIZE) < (head_end - head_start)
93
+
94
+ dim_start = dim_pid * BLOCK_SIZE
95
+ dim_end = tl.minimum(dim_start + BLOCK_SIZE, head_dim)
96
+ dim_mask = tl.arange(0, BLOCK_SIZE) < (dim_end - dim_start)
97
+
98
+ # Calculate input offset
99
+ input_offset = (
100
+ input_pos * num_heads * head_dim
101
+ + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * head_dim
102
+ + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]
103
+ )
104
+
105
+ # Load data
106
+ data = tl.load(
107
+ q_ptr + input_offset,
108
+ mask=head_mask[:, None] & dim_mask[None, :],
109
+ other=0.0,
110
+ )
111
+
112
+ # Calculate output offset
113
+ output_offset = (
114
+ batch_id * max_seq_len * num_heads * head_dim
115
+ + seq_pos * num_heads * head_dim
116
+ + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * head_dim
117
+ + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]
118
+ )
119
+
120
+ # Store data
121
+ tl.store(
122
+ padded_q_ptr + output_offset,
123
+ data,
124
+ mask=head_mask[:, None] & dim_mask[None, :],
125
+ )
126
+
127
+
128
+ @cached_triton_kernel(lambda _, kwargs: (kwargs["BLOCK_SIZE"]))
129
+ @triton.jit
130
+ def unpad_draft_extend_output_kernel(
131
+ raw_out_ptr, # Input raw output tensor (batch_size, token_per_batch, tp_q_head_num, v_head_dim)
132
+ output_ptr, # Output tensor (-1, tp_q_head_num, v_head_dim)
133
+ accept_length_ptr, # Accept lengths for each sequence [batch_size]
134
+ cumsum_ptr, # Cumulative sum of accept lengths [batch_size + 1]
135
+ batch_size,
136
+ token_per_batch,
137
+ tp_q_head_num,
138
+ v_head_dim,
139
+ BLOCK_SIZE: tl.constexpr,
140
+ ):
141
+ """Triton kernel for unpadding draft extended output tensor with parallelized head and dim processing."""
142
+ batch_seq_pid = tl.program_id(0)
143
+ head_pid = tl.program_id(1)
144
+ dim_pid = tl.program_id(2)
145
+
146
+ batch_id = batch_seq_pid // token_per_batch
147
+ seq_pos = batch_seq_pid % token_per_batch
148
+
149
+ if batch_id >= batch_size:
150
+ return
151
+
152
+ # Load accept length for this batch
153
+ accept_len = tl.load(accept_length_ptr + batch_id)
154
+
155
+ if seq_pos >= accept_len:
156
+ return
157
+
158
+ # Load cumulative sum to get start position in output tensor
159
+ output_start = tl.load(cumsum_ptr + batch_id)
160
+ output_pos = output_start + seq_pos
161
+
162
+ # Calculate head and dim block ranges
163
+ head_start = head_pid * BLOCK_SIZE
164
+ head_end = tl.minimum(head_start + BLOCK_SIZE, tp_q_head_num)
165
+ head_mask = tl.arange(0, BLOCK_SIZE) < (head_end - head_start)
166
+
167
+ dim_start = dim_pid * BLOCK_SIZE
168
+ dim_end = tl.minimum(dim_start + BLOCK_SIZE, v_head_dim)
169
+ dim_mask = tl.arange(0, BLOCK_SIZE) < (dim_end - dim_start)
170
+
171
+ # Calculate input offset: (batch_id, seq_pos, head_id, dim_id)
172
+ input_offset = (
173
+ batch_id * token_per_batch * tp_q_head_num * v_head_dim
174
+ + seq_pos * tp_q_head_num * v_head_dim
175
+ + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * v_head_dim
176
+ + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]
177
+ )
178
+
179
+ # Load data
180
+ data = tl.load(
181
+ raw_out_ptr + input_offset,
182
+ mask=head_mask[:, None] & dim_mask[None, :],
183
+ other=0.0,
184
+ )
185
+
186
+ output_offset = (
187
+ output_pos * tp_q_head_num * v_head_dim
188
+ + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * v_head_dim
189
+ + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]
190
+ )
191
+
192
+ # Store data
193
+ tl.store(
194
+ output_ptr + output_offset,
195
+ data,
196
+ mask=head_mask[:, None] & dim_mask[None, :],
197
+ )
198
+
199
+
51
200
  global_zero_init_workspace_buffer = None
52
201
 
53
202
 
@@ -65,7 +214,11 @@ class TRTLLMMLADecodeMetadata:
65
214
  """Metadata for TRTLLM MLA decode operations."""
66
215
 
67
216
  block_kv_indices: Optional[torch.Tensor] = None
68
- max_seq_len: Optional[int] = None
217
+ max_seq_len_k: Optional[int] = None
218
+ max_seq_len_q: Optional[int] = None
219
+ sum_seq_lens_q: Optional[int] = None
220
+ cu_seqlens_q: Optional[torch.Tensor] = None
221
+ seq_lens_q: Optional[torch.Tensor] = None
69
222
 
70
223
 
71
224
  class TRTLLMMLABackend(FlashInferMLAAttnBackend):
@@ -120,12 +273,14 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
120
273
  # CUDA graph state
121
274
  self.decode_cuda_graph_metadata = {}
122
275
  self.decode_cuda_graph_kv_indices = None
276
+ self.padded_q_buffer = None
277
+ self.unpad_output_buffer = None
123
278
  self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
124
279
  self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
125
280
 
126
- self.disable_chunked_prefix_cache = global_server_args_dict[
127
- "disable_chunked_prefix_cache"
128
- ]
281
+ self.disable_chunked_prefix_cache = (
282
+ get_global_server_args().disable_chunked_prefix_cache
283
+ )
129
284
 
130
285
  self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
131
286
 
@@ -143,9 +298,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
143
298
 
144
299
  # Apply dual constraints (take LCM to satisfy both):
145
300
  # 1. TRT-LLM: block_num % (128 / page_size) == 0
146
- # 2. Triton: page table builder uses 64-index bursts, needs multiple of 64
301
+ # 2. Triton: number of pages per block
147
302
  trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size
148
- constraint_lcm = math.lcm(trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK)
303
+ triton_constraint = get_num_page_per_block_flashmla(self.page_size)
304
+ constraint_lcm = math.lcm(trtllm_constraint, triton_constraint)
149
305
 
150
306
  if blocks % constraint_lcm != 0:
151
307
  blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm
@@ -184,7 +340,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
184
340
  block_kv_indices,
185
341
  self.req_to_token.stride(0),
186
342
  max_blocks,
187
- NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
188
343
  PAGED_SIZE=self.page_size,
189
344
  )
190
345
 
@@ -203,6 +358,21 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
203
358
  self.decode_cuda_graph_kv_indices = torch.full(
204
359
  (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
205
360
  )
361
+ num_tokens_per_bs = max_num_tokens // max_bs
362
+
363
+ # Buffer for padded query: (max_bs, max_draft_tokens, num_q_heads, v_head_dim)
364
+ self.padded_q_buffer = torch.zeros(
365
+ (max_bs, num_tokens_per_bs, self.num_q_heads, self.kv_cache_dim),
366
+ dtype=self.data_type,
367
+ device=self.device,
368
+ )
369
+
370
+ # Buffer for unpadded output: (max_num_tokens, num_q_heads, v_head_dim)
371
+ self.unpad_output_buffer = torch.zeros(
372
+ (max_num_tokens, self.num_q_heads, 512),
373
+ dtype=self.data_type,
374
+ device=self.device,
375
+ )
206
376
 
207
377
  super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
208
378
 
@@ -219,7 +389,11 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
219
389
  """Initialize metadata for CUDA graph capture."""
220
390
 
221
391
  # Delegate to parent for non-decode modes.
222
- if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
392
+ if (
393
+ not forward_mode.is_decode_or_idle()
394
+ and not forward_mode.is_target_verify()
395
+ and not forward_mode.is_draft_extend(include_v2=True)
396
+ ):
223
397
  return super().init_forward_metadata_capture_cuda_graph(
224
398
  bs,
225
399
  num_tokens,
@@ -246,7 +420,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
246
420
  block_kv_indices,
247
421
  self.req_to_token.stride(0),
248
422
  max_blocks_per_seq,
249
- NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
250
423
  PAGED_SIZE=self.page_size,
251
424
  )
252
425
 
@@ -259,6 +432,20 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
259
432
  block_kv_indices,
260
433
  max_seq_len_val,
261
434
  )
435
+ if forward_mode.is_draft_extend(include_v2=True):
436
+ num_tokens_per_bs = num_tokens // bs
437
+ metadata.max_seq_len_q = num_tokens_per_bs + 1
438
+ metadata.sum_seq_lens_q = num_tokens_per_bs * bs
439
+ metadata.cu_seqlens_q = torch.arange(
440
+ 0,
441
+ bs * num_tokens_per_bs + 1,
442
+ num_tokens_per_bs,
443
+ dtype=torch.int32,
444
+ device=seq_lens.device,
445
+ )
446
+ metadata.seq_lens_q = torch.full(
447
+ (bs,), num_tokens_per_bs, dtype=torch.int32, device=seq_lens.device
448
+ )
262
449
  self.decode_cuda_graph_metadata[bs] = metadata
263
450
  self.forward_decode_metadata = metadata
264
451
 
@@ -275,7 +462,11 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
275
462
  ):
276
463
  """Replay CUDA graph with new inputs."""
277
464
  # Delegate to parent for non-decode modes.
278
- if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
465
+ if (
466
+ not forward_mode.is_decode_or_idle()
467
+ and not forward_mode.is_target_verify()
468
+ and not forward_mode.is_draft_extend(include_v2=True)
469
+ ):
279
470
  return super().init_forward_metadata_replay_cuda_graph(
280
471
  bs,
281
472
  req_pool_indices,
@@ -293,6 +484,19 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
293
484
 
294
485
  metadata = self.decode_cuda_graph_metadata[bs]
295
486
 
487
+ if forward_mode.is_draft_extend(include_v2=True):
488
+ accept_length = spec_info.accept_length[:bs]
489
+ if spec_info.accept_length_cpu:
490
+ metadata.max_seq_len_q = max(spec_info.accept_length_cpu[:bs])
491
+ metadata.sum_seq_lens_q = sum(spec_info.accept_length_cpu[:bs])
492
+ else:
493
+ metadata.max_seq_len_q = 1
494
+ metadata.sum_seq_lens_q = bs
495
+ metadata.cu_seqlens_q[1:].copy_(
496
+ torch.cumsum(accept_length, dim=0, dtype=torch.int32)
497
+ )
498
+ metadata.seq_lens_q.copy_(accept_length)
499
+
296
500
  # Update block indices for new sequences.
297
501
  create_flashmla_kv_indices_triton[(bs,)](
298
502
  self.req_to_token,
@@ -302,7 +506,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
302
506
  metadata.block_kv_indices,
303
507
  self.req_to_token.stride(0),
304
508
  metadata.block_kv_indices.shape[1],
305
- NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
306
509
  PAGED_SIZE=self.page_size,
307
510
  )
308
511
 
@@ -323,7 +526,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
323
526
  if (
324
527
  forward_batch.forward_mode.is_extend()
325
528
  and not forward_batch.forward_mode.is_target_verify()
326
- and not forward_batch.forward_mode.is_draft_extend()
529
+ and not forward_batch.forward_mode.is_draft_extend(include_v2=True)
327
530
  ):
328
531
  if self.disable_chunked_prefix_cache:
329
532
  super().init_forward_metadata(forward_batch)
@@ -344,6 +547,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
344
547
  elif (
345
548
  forward_batch.forward_mode.is_decode_or_idle()
346
549
  or forward_batch.forward_mode.is_target_verify()
550
+ or forward_batch.forward_mode.is_draft_extend(include_v2=True)
347
551
  ):
348
552
  bs = forward_batch.batch_size
349
553
 
@@ -372,6 +576,23 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
372
576
  self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
373
577
  block_kv_indices, max_seq_len_val
374
578
  )
579
+ if forward_batch.forward_mode.is_draft_extend(include_v2=True):
580
+ max_seq = forward_batch.seq_lens_cpu.max().item()
581
+
582
+ sum_seq_lens_q = sum(forward_batch.extend_seq_lens_cpu)
583
+ max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
584
+ cu_seqlens_q = torch.nn.functional.pad(
585
+ torch.cumsum(
586
+ forward_batch.extend_seq_lens, dim=0, dtype=torch.int32
587
+ ),
588
+ (1, 0),
589
+ )
590
+
591
+ self.forward_decode_metadata.max_seq_len_q = max_seq_len_q
592
+ self.forward_decode_metadata.sum_seq_lens_q = sum_seq_lens_q
593
+ self.forward_decode_metadata.cu_seqlens_q = cu_seqlens_q
594
+ self.forward_decode_metadata.seq_lens_q = forward_batch.extend_seq_lens
595
+
375
596
  forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
376
597
  else:
377
598
  return super().init_forward_metadata(forward_batch)
@@ -457,6 +678,86 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
457
678
 
458
679
  return q_out, k_nope_out, k_rope_out
459
680
 
681
def pad_draft_extend_query(
    self,
    q: torch.Tensor,
    padded_q: torch.Tensor,
    seq_lens_q: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
) -> torch.Tensor:
    """Copy a ragged draft-extend query into a padded per-request tensor.

    The copy is done by ``pad_draft_extend_query_kernel``. The launch grid has
    one program per (request, padded position) pair on axis 0 and tiles the
    head and head-dim axes (64 wide) on axes 1 and 2.

    Args:
        q: Source query; presumably the packed (unpadded) draft-extend tokens
            — TODO confirm against the kernel.
        padded_q: Destination; dims 1-3 are read as (max query length,
            num heads, head dim). Presumably written in place by the kernel.
        seq_lens_q: Per-request query lengths, forwarded to the kernel.
        cu_seqlens_q: Cumulative query lengths; ``len(cu_seqlens_q) - 1`` is
            taken as the number of requests.

    Returns:
        The same ``padded_q`` tensor object that was passed in.
    """
    num_reqs = cu_seqlens_q.shape[0] - 1
    dims = padded_q.shape
    padded_len, n_heads, dim = dims[1], dims[2], dims[3]

    tile = 64  # tile width shared by the head axis and the head-dim axis
    launch_grid = (
        num_reqs * padded_len,
        triton.cdiv(n_heads, tile),
        triton.cdiv(dim, tile),
    )

    pad_draft_extend_query_kernel[launch_grid](
        q_ptr=q,
        padded_q_ptr=padded_q,
        seq_lens_q_ptr=seq_lens_q,
        cumsum_ptr=cu_seqlens_q,
        batch_size=num_reqs,
        max_seq_len=padded_len,
        num_heads=n_heads,
        head_dim=dim,
        BLOCK_SIZE=tile,
    )
    return padded_q
712
+
713
def unpad_draft_extend_output(
    self,
    raw_out: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    seq_lens_q: torch.Tensor,
    sum_seq_lens_q: int,
) -> torch.Tensor:
    """Compact a padded draft-extend attention output into packed token rows.

    ``raw_out`` dims 1-3 are read as (padded tokens per request, heads,
    v_head_dim). ``unpad_draft_extend_output_kernel`` receives the
    per-request lengths (``seq_lens_q``) and offsets (``cu_seqlens_q``).
    It presumably copies each request's valid rows into a
    ``(sum_seq_lens_q, heads, v_head_dim)`` output — TODO confirm against
    the kernel.

    The pre-allocated ``self.unpad_output_buffer`` is reused only when it can
    hold the result. The kernel is given the head and v_head_dim sizes but no
    output strides or row bound. A slice with too few rows would therefore be
    written out of bounds, and one with different trailing dims would be laid
    out wrongly. This can happen, for example, with an eager batch larger than
    the buffer or with a ``v_head_dim`` that differs from the buffer's. In those
    cases a fresh tensor is allocated instead.

    Args:
        raw_out: Padded attention output.
        cu_seqlens_q: Cumulative query lengths per request.
        seq_lens_q: Query length per request; its length is the batch size.
        sum_seq_lens_q: Total number of valid tokens (rows of the result).

    Returns:
        Tensor with ``sum_seq_lens_q`` rows and ``raw_out``'s dtype.
    """
    batch_size = seq_lens_q.shape[0]
    token_per_batch = raw_out.shape[1]  # padded length per request
    tp_q_head_num = raw_out.shape[2]
    v_head_dim = raw_out.shape[3]
    total_tokens = sum_seq_lens_q

    # The buffer is non-None once init_cuda_graph_state has run, which does
    # not imply we are capturing/replaying a graph, so check it actually fits.
    buf = self.unpad_output_buffer
    if (
        buf is not None
        and buf.shape[0] >= total_tokens
        and buf.shape[1:] == (tp_q_head_num, v_head_dim)
    ):
        # Use pre-allocated buffer for CUDA graph compatibility.
        output = buf[:total_tokens].to(dtype=raw_out.dtype)
    else:
        output = torch.empty(
            (total_tokens, tp_q_head_num, v_head_dim),
            dtype=raw_out.dtype,
            device=raw_out.device,
        )

    # 3D grid: (request, padded position) x head tiles x head-dim tiles.
    BLOCK_SIZE = 64
    grid = (
        batch_size * token_per_batch,
        triton.cdiv(tp_q_head_num, BLOCK_SIZE),
        triton.cdiv(v_head_dim, BLOCK_SIZE),
    )

    unpad_draft_extend_output_kernel[grid](
        raw_out_ptr=raw_out,
        output_ptr=output,
        accept_length_ptr=seq_lens_q,
        cumsum_ptr=cu_seqlens_q,
        batch_size=batch_size,
        token_per_batch=token_per_batch,
        tp_q_head_num=tp_q_head_num,
        v_head_dim=v_head_dim,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    # Both branches already produce exactly ``total_tokens`` rows.
    return output
760
+
460
761
  def forward_decode(
461
762
  self,
462
763
  q: torch.Tensor, # q_nope
@@ -550,7 +851,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
550
851
  qk_rope_head_dim=self.qk_rope_head_dim,
551
852
  block_tables=metadata.block_kv_indices,
552
853
  seq_lens=forward_batch.seq_lens.to(torch.int32),
553
- max_seq_len=metadata.max_seq_len,
854
+ max_seq_len=metadata.max_seq_len_k,
554
855
  bmm1_scale=bmm1_scale,
555
856
  )
556
857
 
@@ -571,11 +872,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
571
872
  cos_sin_cache: Optional[torch.Tensor] = None,
572
873
  is_neox: Optional[bool] = False,
573
874
  ) -> torch.Tensor:
574
- if forward_batch.forward_mode.is_draft_extend():
575
- return super().forward_extend(
576
- q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
577
- )
578
-
579
875
  # TODO refactor to avoid code duplication
580
876
  merge_query = q_rope is not None
581
877
  if (
@@ -627,7 +923,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
627
923
 
628
924
  v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
629
925
 
630
- if forward_batch.forward_mode.is_target_verify():
926
+ if (
927
+ forward_batch.forward_mode.is_target_verify()
928
+ or forward_batch.forward_mode.is_draft_extend(include_v2=True)
929
+ ):
631
930
  metadata = (
632
931
  getattr(forward_batch, "decode_trtllm_mla_metadata", None)
633
932
  or self.forward_decode_metadata
@@ -635,7 +934,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
635
934
 
636
935
  # Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim]
637
936
  bs = forward_batch.batch_size
638
- q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
639
937
 
640
938
  k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
641
939
  kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
@@ -646,17 +944,42 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
646
944
  if getattr(layer, "k_scale_float", None) is not None
647
945
  else 1.0
648
946
  )
947
+ q = q.to(self.data_type)
649
948
 
650
949
  bmm1_scale = q_scale * k_scale * layer.scaling
651
-
652
- seq_lens = (
653
- forward_batch.seq_lens.to(torch.int32)
654
- + forward_batch.spec_info.draft_token_num
655
- )
656
- max_seq_len = metadata.max_seq_len + forward_batch.spec_info.draft_token_num
950
+ if forward_batch.forward_mode.is_target_verify():
951
+ seq_lens = (
952
+ forward_batch.seq_lens.to(torch.int32)
953
+ + forward_batch.spec_info.draft_token_num
954
+ )
955
+ max_seq_len = (
956
+ metadata.max_seq_len_k + forward_batch.spec_info.draft_token_num
957
+ )
958
+ else:
959
+ seq_lens = forward_batch.seq_lens.to(torch.int32)
960
+ max_seq_len = metadata.max_seq_len_k
961
+ # Check if we're in CUDA graph mode (buffers are pre-allocated)
962
+ if self.padded_q_buffer is not None:
963
+ # Use pre-allocated buffer for CUDA graph compatibility
964
+ padded_q = self.padded_q_buffer[
965
+ :bs, : metadata.max_seq_len_q, :, :
966
+ ].to(dtype=q.dtype)
967
+ else:
968
+ # Dynamic allocation for non-CUDA graph mode
969
+ padded_q = torch.zeros(
970
+ bs,
971
+ metadata.max_seq_len_q,
972
+ layer.tp_q_head_num,
973
+ layer.head_dim,
974
+ dtype=q.dtype,
975
+ device=q.device,
976
+ )
977
+ q = self.pad_draft_extend_query(
978
+ q, padded_q, metadata.seq_lens_q, metadata.cu_seqlens_q
979
+ )
657
980
 
658
981
  # TODO may use `mla_rope_quantize_fp8` fusion
659
- q = q.to(self.data_type)
982
+ q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
660
983
  assert kv_cache.dtype == self.data_type
661
984
 
662
985
  raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
@@ -673,6 +996,14 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
673
996
  )
674
997
 
675
998
  # Reshape output directly without slicing
999
+
1000
+ if forward_batch.forward_mode.is_draft_extend(include_v2=True):
1001
+ raw_out = self.unpad_draft_extend_output(
1002
+ raw_out,
1003
+ metadata.cu_seqlens_q,
1004
+ metadata.seq_lens_q,
1005
+ metadata.sum_seq_lens_q,
1006
+ )
676
1007
  output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
677
1008
  return output
678
1009
 
@@ -735,7 +1066,7 @@ class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
735
1066
  ):
736
1067
  super().__init__(model_runner, topk, speculative_num_steps)
737
1068
 
738
- for i in range(self.speculative_num_steps):
1069
+ for i in range(self.speculative_num_steps - 1):
739
1070
  self.attn_backends[i] = TRTLLMMLABackend(
740
1071
  model_runner,
741
1072
  skip_prefill=True,
@@ -1,10 +1,8 @@
1
1
  import triton
2
2
  import triton.language as tl
3
3
 
4
- # Keep this in sync with the Triton kernel inside `create_flashmla_kv_indices_triton`.
5
- # Number of pages that the kernel writes per iteration.
6
- # Exposed here so other Python modules can import it instead of hard-coding 64.
7
- TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
4
# Number of KV entries (presumably tokens — confirm) that
# `create_flashmla_kv_indices_triton` handles per loop iteration; the kernel
# derives its pages-per-iteration as this value // PAGED_SIZE.
# Plain-int copy, used on the Python side by get_num_page_per_block_flashmla.
_FLASHMLA_CREATE_KV_BLOCK_SIZE = 4096
# Same value wrapped as a Triton compile-time constant for use inside jitted
# kernels; derived from the int above so the two cannot drift apart.
FLASHMLA_CREATE_KV_BLOCK_SIZE_TRITON = tl.constexpr(_FLASHMLA_CREATE_KV_BLOCK_SIZE)
8
6
 
9
7
 
10
8
  @triton.jit
@@ -46,6 +44,11 @@ def create_flashinfer_kv_indices_triton(
46
44
  tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
47
45
 
48
46
 
47
def get_num_page_per_block_flashmla(page_size: int = 64) -> int:
    """Return the number of KV pages handled per block by the FlashMLA indexer.

    This mirrors ``NUM_PAGE_PER_BLOCK`` inside
    ``create_flashmla_kv_indices_triton``
    (``_FLASHMLA_CREATE_KV_BLOCK_SIZE // PAGED_SIZE``). Callers round their
    block-table width up to a multiple of it, via ``math.lcm``.

    Args:
        page_size: KV cache page size in tokens.

    Returns:
        ``_FLASHMLA_CREATE_KV_BLOCK_SIZE // page_size`` (always >= 1).

    Raises:
        ValueError: if ``page_size`` is not positive or exceeds the block
            size. The result would then be 0 or undefined, and a caller's
            ``blocks % math.lcm(..., 0)`` would fail with ZeroDivisionError.
    """
    if page_size <= 0 or page_size > _FLASHMLA_CREATE_KV_BLOCK_SIZE:
        raise ValueError(
            f"page_size must be in [1, {_FLASHMLA_CREATE_KV_BLOCK_SIZE}], "
            f"got {page_size}"
        )
    return _FLASHMLA_CREATE_KV_BLOCK_SIZE // page_size
50
+
51
+
49
52
  @triton.jit
50
53
  def create_flashmla_kv_indices_triton(
51
54
  req_to_token_ptr, # [max_batch, max_context_len]
@@ -55,10 +58,11 @@ def create_flashmla_kv_indices_triton(
55
58
  kv_indices_ptr,
56
59
  req_to_token_ptr_stride: tl.constexpr,
57
60
  kv_indices_ptr_stride: tl.constexpr,
58
- NUM_PAGE_PER_BLOCK: tl.constexpr = TRITON_PAD_NUM_PAGE_PER_BLOCK,
59
61
  PAGED_SIZE: tl.constexpr = 64,
60
62
  ):
61
- BLOCK_SIZE: tl.constexpr = 4096
63
+ NUM_PAGE_PER_BLOCK: tl.constexpr = (
64
+ FLASHMLA_CREATE_KV_BLOCK_SIZE_TRITON // PAGED_SIZE
65
+ )
62
66
  pid = tl.program_id(axis=0)
63
67
 
64
68
  # find the req pool idx, this is for batch to token
@@ -73,7 +77,7 @@ def create_flashmla_kv_indices_triton(
73
77
  kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
74
78
 
75
79
  num_paged = tl.cdiv(kv_end - kv_start, PAGED_SIZE)
76
- num_pages_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
80
+ num_pages_loop = tl.cdiv(kv_end - kv_start, FLASHMLA_CREATE_KV_BLOCK_SIZE_TRITON)
77
81
 
78
82
  for i in range(num_pages_loop):
79
83
  # index into req_to_token_ptr needs to be int64
@@ -45,7 +45,7 @@ from sglang.srt.layers.linear import (
45
45
  )
46
46
  from sglang.srt.layers.quantization import QuantizationConfig
47
47
  from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
48
- from sglang.srt.managers.schedule_batch import global_server_args_dict
48
+ from sglang.srt.server_args import get_global_server_args
49
49
  from sglang.srt.utils import add_prefix
50
50
 
51
51
  ROTARY_EMBED_CLASSES = {
@@ -468,7 +468,7 @@ class VisionAttention(nn.Module):
468
468
  _passed_backend = qkv_backend
469
469
  qkv_backend = self._determine_attention_backend(_passed_backend)
470
470
  if (
471
- global_server_args_dict["mm_attention_backend"] is None
471
+ get_global_server_args().mm_attention_backend is None
472
472
  and _passed_backend is None
473
473
  ):
474
474
  print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
@@ -528,7 +528,7 @@ class VisionAttention(nn.Module):
528
528
  - CUDA: "triton_attn"
529
529
  - Non-CUDA: "sdpa"
530
530
  """
531
- override_backend = global_server_args_dict["mm_attention_backend"]
531
+ override_backend = get_global_server_args().mm_attention_backend
532
532
  if override_backend is not None:
533
533
  backend = override_backend
534
534
  elif passed_backend is not None: