sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419) hide show
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -10,19 +10,21 @@ from typing import TYPE_CHECKING, Optional, Union
10
10
 
11
11
  import torch
12
12
  import triton
13
+ import triton.language as tl
13
14
 
14
15
  from sglang.srt.layers.attention.flashinfer_mla_backend import (
15
16
  FlashInferMLAAttnBackend,
16
17
  FlashInferMLAMultiStepDraftBackend,
17
18
  )
18
19
  from sglang.srt.layers.attention.utils import (
19
- TRITON_PAD_NUM_PAGE_PER_BLOCK,
20
20
  create_flashmla_kv_indices_triton,
21
+ get_num_page_per_block_flashmla,
21
22
  )
22
23
  from sglang.srt.layers.dp_attention import get_attention_tp_size
23
- from sglang.srt.managers.schedule_batch import global_server_args_dict
24
24
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
25
+ from sglang.srt.server_args import get_global_server_args
25
26
  from sglang.srt.utils import is_cuda, is_flashinfer_available
27
+ from sglang.srt.utils.common import cached_triton_kernel
26
28
 
27
29
  if is_flashinfer_available():
28
30
  import flashinfer
@@ -48,6 +50,153 @@ DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB
48
50
  # compute the LCM with other padding constraints.
49
51
  TRTLLM_BLOCK_CONSTRAINT = 128
50
52
 
53
+
54
+ @cached_triton_kernel(lambda _, kwargs: (kwargs["BLOCK_SIZE"]))
55
+ @triton.jit
56
+ def pad_draft_extend_query_kernel(
57
+ q_ptr, # Input query tensor [total_seq_len, num_heads, head_dim]
58
+ padded_q_ptr, # Output padded query tensor [batch_size, max_seq_len, num_heads, head_dim]
59
+ seq_lens_q_ptr, # Sequence lengths for each sequence [batch_size]
60
+ cumsum_ptr, # Cumulative sum of accept lengths [batch_size + 1]
61
+ batch_size,
62
+ max_seq_len,
63
+ num_heads,
64
+ head_dim,
65
+ BLOCK_SIZE: tl.constexpr,
66
+ ):
67
+ """Triton kernel for padding draft extended query tensor with parallelized head and dim processing."""
68
+ # Use 3D program IDs: (batch_seq, head_block, dim_block)
69
+ batch_seq_pid = tl.program_id(0)
70
+ head_pid = tl.program_id(1)
71
+ dim_pid = tl.program_id(2)
72
+
73
+ batch_id = batch_seq_pid // max_seq_len
74
+ seq_pos = batch_seq_pid % max_seq_len
75
+
76
+ if batch_id >= batch_size:
77
+ return
78
+
79
+ # Load accept length for this batch
80
+ seq_len = tl.load(seq_lens_q_ptr + batch_id)
81
+
82
+ if seq_pos >= seq_len:
83
+ return
84
+
85
+ # Load cumulative sum to get start position in input tensor
86
+ input_start = tl.load(cumsum_ptr + batch_id)
87
+ input_pos = input_start + seq_pos
88
+
89
+ # Calculate head and dim block ranges
90
+ head_start = head_pid * BLOCK_SIZE
91
+ head_end = tl.minimum(head_start + BLOCK_SIZE, num_heads)
92
+ head_mask = tl.arange(0, BLOCK_SIZE) < (head_end - head_start)
93
+
94
+ dim_start = dim_pid * BLOCK_SIZE
95
+ dim_end = tl.minimum(dim_start + BLOCK_SIZE, head_dim)
96
+ dim_mask = tl.arange(0, BLOCK_SIZE) < (dim_end - dim_start)
97
+
98
+ # Calculate input offset
99
+ input_offset = (
100
+ input_pos * num_heads * head_dim
101
+ + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * head_dim
102
+ + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]
103
+ )
104
+
105
+ # Load data
106
+ data = tl.load(
107
+ q_ptr + input_offset,
108
+ mask=head_mask[:, None] & dim_mask[None, :],
109
+ other=0.0,
110
+ )
111
+
112
+ # Calculate output offset
113
+ output_offset = (
114
+ batch_id * max_seq_len * num_heads * head_dim
115
+ + seq_pos * num_heads * head_dim
116
+ + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * head_dim
117
+ + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]
118
+ )
119
+
120
+ # Store data
121
+ tl.store(
122
+ padded_q_ptr + output_offset,
123
+ data,
124
+ mask=head_mask[:, None] & dim_mask[None, :],
125
+ )
126
+
127
+
128
+ @cached_triton_kernel(lambda _, kwargs: (kwargs["BLOCK_SIZE"]))
129
+ @triton.jit
130
+ def unpad_draft_extend_output_kernel(
131
+ raw_out_ptr, # Input raw output tensor (batch_size, token_per_batch, tp_q_head_num, v_head_dim)
132
+ output_ptr, # Output tensor (-1, tp_q_head_num, v_head_dim)
133
+ accept_length_ptr, # Accept lengths for each sequence [batch_size]
134
+ cumsum_ptr, # Cumulative sum of accept lengths [batch_size + 1]
135
+ batch_size,
136
+ token_per_batch,
137
+ tp_q_head_num,
138
+ v_head_dim,
139
+ BLOCK_SIZE: tl.constexpr,
140
+ ):
141
+ """Triton kernel for unpadding draft extended output tensor with parallelized head and dim processing."""
142
+ batch_seq_pid = tl.program_id(0)
143
+ head_pid = tl.program_id(1)
144
+ dim_pid = tl.program_id(2)
145
+
146
+ batch_id = batch_seq_pid // token_per_batch
147
+ seq_pos = batch_seq_pid % token_per_batch
148
+
149
+ if batch_id >= batch_size:
150
+ return
151
+
152
+ # Load accept length for this batch
153
+ accept_len = tl.load(accept_length_ptr + batch_id)
154
+
155
+ if seq_pos >= accept_len:
156
+ return
157
+
158
+ # Load cumulative sum to get start position in output tensor
159
+ output_start = tl.load(cumsum_ptr + batch_id)
160
+ output_pos = output_start + seq_pos
161
+
162
+ # Calculate head and dim block ranges
163
+ head_start = head_pid * BLOCK_SIZE
164
+ head_end = tl.minimum(head_start + BLOCK_SIZE, tp_q_head_num)
165
+ head_mask = tl.arange(0, BLOCK_SIZE) < (head_end - head_start)
166
+
167
+ dim_start = dim_pid * BLOCK_SIZE
168
+ dim_end = tl.minimum(dim_start + BLOCK_SIZE, v_head_dim)
169
+ dim_mask = tl.arange(0, BLOCK_SIZE) < (dim_end - dim_start)
170
+
171
+ # Calculate input offset: (batch_id, seq_pos, head_id, dim_id)
172
+ input_offset = (
173
+ batch_id * token_per_batch * tp_q_head_num * v_head_dim
174
+ + seq_pos * tp_q_head_num * v_head_dim
175
+ + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * v_head_dim
176
+ + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]
177
+ )
178
+
179
+ # Load data
180
+ data = tl.load(
181
+ raw_out_ptr + input_offset,
182
+ mask=head_mask[:, None] & dim_mask[None, :],
183
+ other=0.0,
184
+ )
185
+
186
+ output_offset = (
187
+ output_pos * tp_q_head_num * v_head_dim
188
+ + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * v_head_dim
189
+ + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]
190
+ )
191
+
192
+ # Store data
193
+ tl.store(
194
+ output_ptr + output_offset,
195
+ data,
196
+ mask=head_mask[:, None] & dim_mask[None, :],
197
+ )
198
+
199
+
51
200
  global_zero_init_workspace_buffer = None
52
201
 
53
202
 
@@ -65,7 +214,11 @@ class TRTLLMMLADecodeMetadata:
65
214
  """Metadata for TRTLLM MLA decode operations."""
66
215
 
67
216
  block_kv_indices: Optional[torch.Tensor] = None
68
- max_seq_len: Optional[int] = None
217
+ max_seq_len_k: Optional[int] = None
218
+ max_seq_len_q: Optional[int] = None
219
+ sum_seq_lens_q: Optional[int] = None
220
+ cu_seqlens_q: Optional[torch.Tensor] = None
221
+ seq_lens_q: Optional[torch.Tensor] = None
69
222
 
70
223
 
71
224
  class TRTLLMMLABackend(FlashInferMLAAttnBackend):
@@ -120,12 +273,14 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
120
273
  # CUDA graph state
121
274
  self.decode_cuda_graph_metadata = {}
122
275
  self.decode_cuda_graph_kv_indices = None
276
+ self.padded_q_buffer = None
277
+ self.unpad_output_buffer = None
123
278
  self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
124
279
  self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
125
280
 
126
- self.disable_chunked_prefix_cache = global_server_args_dict[
127
- "disable_chunked_prefix_cache"
128
- ]
281
+ self.disable_chunked_prefix_cache = (
282
+ get_global_server_args().disable_chunked_prefix_cache
283
+ )
129
284
 
130
285
  self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
131
286
 
@@ -143,9 +298,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
143
298
 
144
299
  # Apply dual constraints (take LCM to satisfy both):
145
300
  # 1. TRT-LLM: block_num % (128 / page_size) == 0
146
- # 2. Triton: page table builder uses 64-index bursts, needs multiple of 64
301
+ # 2. Triton: number of pages per block
147
302
  trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size
148
- constraint_lcm = math.lcm(trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK)
303
+ triton_constraint = get_num_page_per_block_flashmla(self.page_size)
304
+ constraint_lcm = math.lcm(trtllm_constraint, triton_constraint)
149
305
 
150
306
  if blocks % constraint_lcm != 0:
151
307
  blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm
@@ -184,7 +340,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
184
340
  block_kv_indices,
185
341
  self.req_to_token.stride(0),
186
342
  max_blocks,
187
- NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
188
343
  PAGED_SIZE=self.page_size,
189
344
  )
190
345
 
@@ -203,6 +358,21 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
203
358
  self.decode_cuda_graph_kv_indices = torch.full(
204
359
  (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
205
360
  )
361
+ num_tokens_per_bs = max_num_tokens // max_bs
362
+
363
+ # Buffer for padded query: (max_bs, max_draft_tokens, num_q_heads, v_head_dim)
364
+ self.padded_q_buffer = torch.zeros(
365
+ (max_bs, num_tokens_per_bs, self.num_q_heads, self.kv_cache_dim),
366
+ dtype=self.data_type,
367
+ device=self.device,
368
+ )
369
+
370
+ # Buffer for unpadded output: (max_num_tokens, num_q_heads, v_head_dim)
371
+ self.unpad_output_buffer = torch.zeros(
372
+ (max_num_tokens, self.num_q_heads, 512),
373
+ dtype=self.data_type,
374
+ device=self.device,
375
+ )
206
376
 
207
377
  super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
208
378
 
@@ -219,7 +389,11 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
219
389
  """Initialize metadata for CUDA graph capture."""
220
390
 
221
391
  # Delegate to parent for non-decode modes.
222
- if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
392
+ if (
393
+ not forward_mode.is_decode_or_idle()
394
+ and not forward_mode.is_target_verify()
395
+ and not forward_mode.is_draft_extend(include_v2=True)
396
+ ):
223
397
  return super().init_forward_metadata_capture_cuda_graph(
224
398
  bs,
225
399
  num_tokens,
@@ -246,19 +420,27 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
246
420
  block_kv_indices,
247
421
  self.req_to_token.stride(0),
248
422
  max_blocks_per_seq,
249
- NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
250
423
  PAGED_SIZE=self.page_size,
251
424
  )
252
425
 
253
- # Record the true maximum sequence length for this capture batch so that
254
- # the kernel launch path (which requires an int not a tensor) can reuse
255
- # it safely during both capture and replay.
256
- max_seq_len_val = int(seq_lens.max().item())
257
-
258
426
  metadata = TRTLLMMLADecodeMetadata(
259
427
  block_kv_indices,
260
- max_seq_len_val,
428
+ self.max_context_len,
261
429
  )
430
+ if forward_mode.is_draft_extend(include_v2=True):
431
+ num_tokens_per_bs = num_tokens // bs
432
+ metadata.max_seq_len_q = num_tokens_per_bs + 1
433
+ metadata.sum_seq_lens_q = num_tokens_per_bs * bs
434
+ metadata.cu_seqlens_q = torch.arange(
435
+ 0,
436
+ bs * num_tokens_per_bs + 1,
437
+ num_tokens_per_bs,
438
+ dtype=torch.int32,
439
+ device=seq_lens.device,
440
+ )
441
+ metadata.seq_lens_q = torch.full(
442
+ (bs,), num_tokens_per_bs, dtype=torch.int32, device=seq_lens.device
443
+ )
262
444
  self.decode_cuda_graph_metadata[bs] = metadata
263
445
  self.forward_decode_metadata = metadata
264
446
 
@@ -275,7 +457,11 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
275
457
  ):
276
458
  """Replay CUDA graph with new inputs."""
277
459
  # Delegate to parent for non-decode modes.
278
- if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
460
+ if (
461
+ not forward_mode.is_decode_or_idle()
462
+ and not forward_mode.is_target_verify()
463
+ and not forward_mode.is_draft_extend(include_v2=True)
464
+ ):
279
465
  return super().init_forward_metadata_replay_cuda_graph(
280
466
  bs,
281
467
  req_pool_indices,
@@ -293,6 +479,19 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
293
479
 
294
480
  metadata = self.decode_cuda_graph_metadata[bs]
295
481
 
482
+ if forward_mode.is_draft_extend(include_v2=True):
483
+ accept_length = spec_info.accept_length[:bs]
484
+ if spec_info.accept_length_cpu:
485
+ metadata.max_seq_len_q = max(spec_info.accept_length_cpu[:bs])
486
+ metadata.sum_seq_lens_q = sum(spec_info.accept_length_cpu[:bs])
487
+ else:
488
+ metadata.max_seq_len_q = 1
489
+ metadata.sum_seq_lens_q = bs
490
+ metadata.cu_seqlens_q[1:].copy_(
491
+ torch.cumsum(accept_length, dim=0, dtype=torch.int32)
492
+ )
493
+ metadata.seq_lens_q.copy_(accept_length)
494
+
296
495
  # Update block indices for new sequences.
297
496
  create_flashmla_kv_indices_triton[(bs,)](
298
497
  self.req_to_token,
@@ -302,17 +501,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
302
501
  metadata.block_kv_indices,
303
502
  self.req_to_token.stride(0),
304
503
  metadata.block_kv_indices.shape[1],
305
- NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
306
504
  PAGED_SIZE=self.page_size,
307
505
  )
308
506
 
309
- # Update stored max_seq_len so subsequent kernel calls use the correct value
310
- # Prefer CPU tensor to avoid GPU synchronization when available.
311
- if seq_lens_cpu is not None:
312
- metadata.max_seq_len = int(seq_lens_cpu.max().item())
313
- else:
314
- metadata.max_seq_len = int(seq_lens.max().item())
315
-
316
507
  def get_cuda_graph_seq_len_fill_value(self) -> int:
317
508
  """Get the fill value for sequence lengths in CUDA graph."""
318
509
  return 1
@@ -323,7 +514,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
323
514
  if (
324
515
  forward_batch.forward_mode.is_extend()
325
516
  and not forward_batch.forward_mode.is_target_verify()
326
- and not forward_batch.forward_mode.is_draft_extend()
517
+ and not forward_batch.forward_mode.is_draft_extend(include_v2=True)
327
518
  ):
328
519
  if self.disable_chunked_prefix_cache:
329
520
  super().init_forward_metadata(forward_batch)
@@ -344,6 +535,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
344
535
  elif (
345
536
  forward_batch.forward_mode.is_decode_or_idle()
346
537
  or forward_batch.forward_mode.is_target_verify()
538
+ or forward_batch.forward_mode.is_draft_extend(include_v2=True)
347
539
  ):
348
540
  bs = forward_batch.batch_size
349
541
 
@@ -372,6 +564,23 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
372
564
  self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
373
565
  block_kv_indices, max_seq_len_val
374
566
  )
567
+ if forward_batch.forward_mode.is_draft_extend(include_v2=True):
568
+ max_seq = forward_batch.seq_lens_cpu.max().item()
569
+
570
+ sum_seq_lens_q = sum(forward_batch.extend_seq_lens_cpu)
571
+ max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
572
+ cu_seqlens_q = torch.nn.functional.pad(
573
+ torch.cumsum(
574
+ forward_batch.extend_seq_lens, dim=0, dtype=torch.int32
575
+ ),
576
+ (1, 0),
577
+ )
578
+
579
+ self.forward_decode_metadata.max_seq_len_q = max_seq_len_q
580
+ self.forward_decode_metadata.sum_seq_lens_q = sum_seq_lens_q
581
+ self.forward_decode_metadata.cu_seqlens_q = cu_seqlens_q
582
+ self.forward_decode_metadata.seq_lens_q = forward_batch.extend_seq_lens
583
+
375
584
  forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
376
585
  else:
377
586
  return super().init_forward_metadata(forward_batch)
@@ -457,6 +666,86 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
457
666
 
458
667
  return q_out, k_nope_out, k_rope_out
459
668
 
669
+ def pad_draft_extend_query(
670
+ self,
671
+ q: torch.Tensor,
672
+ padded_q: torch.Tensor,
673
+ seq_lens_q: torch.Tensor,
674
+ cu_seqlens_q: torch.Tensor,
675
+ ) -> torch.Tensor:
676
+ """Pad draft extended query using Triton kernel."""
677
+ batch_size = cu_seqlens_q.shape[0] - 1
678
+ max_seq_len_q = padded_q.shape[1]
679
+ num_heads = padded_q.shape[2]
680
+ head_dim = padded_q.shape[3]
681
+
682
+ # Launch Triton kernel with 3D grid for parallelized head and dim processing
683
+ BLOCK_SIZE = 64
684
+ num_head_blocks = triton.cdiv(num_heads, BLOCK_SIZE)
685
+ num_dim_blocks = triton.cdiv(head_dim, BLOCK_SIZE)
686
+ grid = (batch_size * max_seq_len_q, num_head_blocks, num_dim_blocks)
687
+
688
+ pad_draft_extend_query_kernel[grid](
689
+ q_ptr=q,
690
+ padded_q_ptr=padded_q,
691
+ seq_lens_q_ptr=seq_lens_q,
692
+ cumsum_ptr=cu_seqlens_q,
693
+ batch_size=batch_size,
694
+ max_seq_len=max_seq_len_q,
695
+ num_heads=num_heads,
696
+ head_dim=head_dim,
697
+ BLOCK_SIZE=BLOCK_SIZE,
698
+ )
699
+ return padded_q
700
+
701
+ def unpad_draft_extend_output(
702
+ self,
703
+ raw_out: torch.Tensor,
704
+ cu_seqlens_q: torch.Tensor,
705
+ seq_lens_q: torch.Tensor,
706
+ sum_seq_lens_q: int,
707
+ ) -> torch.Tensor:
708
+ """Unpad draft extended output using Triton kernel."""
709
+ # raw_out: (batch_size, token_per_batch, layer.tp_q_head_num, layer.v_head_dim)
710
+ batch_size = seq_lens_q.shape[0]
711
+ token_per_batch = raw_out.shape[1] # max_seq_len
712
+ tp_q_head_num = raw_out.shape[2] # num_heads
713
+ v_head_dim = raw_out.shape[3] # head_dim
714
+ total_tokens = sum_seq_lens_q
715
+
716
+ # Check if we're in CUDA graph mode (buffers are pre-allocated)
717
+ if self.unpad_output_buffer is not None:
718
+ # Use pre-allocated buffer for CUDA graph compatibility
719
+ output = self.unpad_output_buffer[:total_tokens, :, :].to(
720
+ dtype=raw_out.dtype
721
+ )
722
+ else:
723
+ # Dynamic allocation for non-CUDA graph mode
724
+ output = torch.empty(
725
+ (total_tokens, tp_q_head_num, v_head_dim),
726
+ dtype=raw_out.dtype,
727
+ device=raw_out.device,
728
+ )
729
+
730
+ # Launch Triton kernel with 3D grid for parallelized head and dim processing
731
+ BLOCK_SIZE = 64
732
+ num_head_blocks = triton.cdiv(tp_q_head_num, BLOCK_SIZE)
733
+ num_dim_blocks = triton.cdiv(v_head_dim, BLOCK_SIZE)
734
+ grid = (batch_size * token_per_batch, num_head_blocks, num_dim_blocks)
735
+
736
+ unpad_draft_extend_output_kernel[grid](
737
+ raw_out_ptr=raw_out,
738
+ output_ptr=output,
739
+ accept_length_ptr=seq_lens_q,
740
+ cumsum_ptr=cu_seqlens_q,
741
+ batch_size=batch_size,
742
+ token_per_batch=token_per_batch,
743
+ tp_q_head_num=tp_q_head_num,
744
+ v_head_dim=v_head_dim,
745
+ BLOCK_SIZE=BLOCK_SIZE,
746
+ )
747
+ return output[:total_tokens, :, :]
748
+
460
749
  def forward_decode(
461
750
  self,
462
751
  q: torch.Tensor, # q_nope
@@ -550,7 +839,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
550
839
  qk_rope_head_dim=self.qk_rope_head_dim,
551
840
  block_tables=metadata.block_kv_indices,
552
841
  seq_lens=forward_batch.seq_lens.to(torch.int32),
553
- max_seq_len=metadata.max_seq_len,
842
+ max_seq_len=metadata.max_seq_len_k,
554
843
  bmm1_scale=bmm1_scale,
555
844
  )
556
845
 
@@ -571,11 +860,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
571
860
  cos_sin_cache: Optional[torch.Tensor] = None,
572
861
  is_neox: Optional[bool] = False,
573
862
  ) -> torch.Tensor:
574
- if forward_batch.forward_mode.is_draft_extend():
575
- return super().forward_extend(
576
- q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
577
- )
578
-
579
863
  # TODO refactor to avoid code duplication
580
864
  merge_query = q_rope is not None
581
865
  if (
@@ -627,7 +911,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
627
911
 
628
912
  v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
629
913
 
630
- if forward_batch.forward_mode.is_target_verify():
914
+ if (
915
+ forward_batch.forward_mode.is_target_verify()
916
+ or forward_batch.forward_mode.is_draft_extend(include_v2=True)
917
+ ):
631
918
  metadata = (
632
919
  getattr(forward_batch, "decode_trtllm_mla_metadata", None)
633
920
  or self.forward_decode_metadata
@@ -635,7 +922,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
635
922
 
636
923
  # Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim]
637
924
  bs = forward_batch.batch_size
638
- q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
639
925
 
640
926
  k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
641
927
  kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
@@ -646,17 +932,42 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
646
932
  if getattr(layer, "k_scale_float", None) is not None
647
933
  else 1.0
648
934
  )
935
+ q = q.to(self.data_type)
649
936
 
650
937
  bmm1_scale = q_scale * k_scale * layer.scaling
651
-
652
- seq_lens = (
653
- forward_batch.seq_lens.to(torch.int32)
654
- + forward_batch.spec_info.draft_token_num
655
- )
656
- max_seq_len = metadata.max_seq_len + forward_batch.spec_info.draft_token_num
938
+ if forward_batch.forward_mode.is_target_verify():
939
+ seq_lens = (
940
+ forward_batch.seq_lens.to(torch.int32)
941
+ + forward_batch.spec_info.draft_token_num
942
+ )
943
+ max_seq_len = (
944
+ metadata.max_seq_len_k + forward_batch.spec_info.draft_token_num
945
+ )
946
+ else:
947
+ seq_lens = forward_batch.seq_lens.to(torch.int32)
948
+ max_seq_len = metadata.max_seq_len_k
949
+ # Check if we're in CUDA graph mode (buffers are pre-allocated)
950
+ if self.padded_q_buffer is not None:
951
+ # Use pre-allocated buffer for CUDA graph compatibility
952
+ padded_q = self.padded_q_buffer[
953
+ :bs, : metadata.max_seq_len_q, :, :
954
+ ].to(dtype=q.dtype)
955
+ else:
956
+ # Dynamic allocation for non-CUDA graph mode
957
+ padded_q = torch.zeros(
958
+ bs,
959
+ metadata.max_seq_len_q,
960
+ layer.tp_q_head_num,
961
+ layer.head_dim,
962
+ dtype=q.dtype,
963
+ device=q.device,
964
+ )
965
+ q = self.pad_draft_extend_query(
966
+ q, padded_q, metadata.seq_lens_q, metadata.cu_seqlens_q
967
+ )
657
968
 
658
969
  # TODO may use `mla_rope_quantize_fp8` fusion
659
- q = q.to(self.data_type)
970
+ q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
660
971
  assert kv_cache.dtype == self.data_type
661
972
 
662
973
  raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
@@ -673,6 +984,14 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
673
984
  )
674
985
 
675
986
  # Reshape output directly without slicing
987
+
988
+ if forward_batch.forward_mode.is_draft_extend(include_v2=True):
989
+ raw_out = self.unpad_draft_extend_output(
990
+ raw_out,
991
+ metadata.cu_seqlens_q,
992
+ metadata.seq_lens_q,
993
+ metadata.sum_seq_lens_q,
994
+ )
676
995
  output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
677
996
  return output
678
997
 
@@ -735,7 +1054,7 @@ class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
735
1054
  ):
736
1055
  super().__init__(model_runner, topk, speculative_num_steps)
737
1056
 
738
- for i in range(self.speculative_num_steps):
1057
+ for i in range(self.speculative_num_steps - 1):
739
1058
  self.attn_backends[i] = TRTLLMMLABackend(
740
1059
  model_runner,
741
1060
  skip_prefill=True,