sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/layers/attention/trtllm_mla_backend.py
+++ b/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -10,19 +10,21 @@ from typing import TYPE_CHECKING, Optional, Union
 
 import torch
 import triton
+import triton.language as tl
 
 from sglang.srt.layers.attention.flashinfer_mla_backend import (
     FlashInferMLAAttnBackend,
     FlashInferMLAMultiStepDraftBackend,
 )
 from sglang.srt.layers.attention.utils import (
-    TRITON_PAD_NUM_PAGE_PER_BLOCK,
     create_flashmla_kv_indices_triton,
+    get_num_page_per_block_flashmla,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.server_args import get_global_server_args
+from sglang.srt.utils import is_cuda, is_flashinfer_available
+from sglang.srt.utils.common import cached_triton_kernel
 
 if is_flashinfer_available():
     import flashinfer
@@ -30,7 +32,12 @@ if is_flashinfer_available():
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
+    from sglang.srt.speculative.spec_info import SpecInput
+
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sgl_kernel import concat_mla_absorb_q
 
 # Constants
 DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
@@ -43,6 +50,153 @@ DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
 # compute the LCM with other padding constraints.
 TRTLLM_BLOCK_CONSTRAINT = 128
 
+
+@cached_triton_kernel(lambda _, kwargs: (kwargs["BLOCK_SIZE"]))
+@triton.jit
+def pad_draft_extend_query_kernel(
+    q_ptr,  # Input query tensor [total_seq_len, num_heads, head_dim]
+    padded_q_ptr,  # Output padded query tensor [batch_size, max_seq_len, num_heads, head_dim]
+    seq_lens_q_ptr,  # Sequence lengths for each sequence [batch_size]
+    cumsum_ptr,  # Cumulative sum of accept lengths [batch_size + 1]
+    batch_size,
+    max_seq_len,
+    num_heads,
+    head_dim,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """Triton kernel for padding draft extended query tensor with parallelized head and dim processing."""
+    # Use 3D program IDs: (batch_seq, head_block, dim_block)
+    batch_seq_pid = tl.program_id(0)
+    head_pid = tl.program_id(1)
+    dim_pid = tl.program_id(2)
+
+    batch_id = batch_seq_pid // max_seq_len
+    seq_pos = batch_seq_pid % max_seq_len
+
+    if batch_id >= batch_size:
+        return
+
+    # Load accept length for this batch
+    seq_len = tl.load(seq_lens_q_ptr + batch_id)
+
+    if seq_pos >= seq_len:
+        return
+
+    # Load cumulative sum to get start position in input tensor
+    input_start = tl.load(cumsum_ptr + batch_id)
+    input_pos = input_start + seq_pos
+
+    # Calculate head and dim block ranges
+    head_start = head_pid * BLOCK_SIZE
+    head_end = tl.minimum(head_start + BLOCK_SIZE, num_heads)
+    head_mask = tl.arange(0, BLOCK_SIZE) < (head_end - head_start)
+
+    dim_start = dim_pid * BLOCK_SIZE
+    dim_end = tl.minimum(dim_start + BLOCK_SIZE, head_dim)
+    dim_mask = tl.arange(0, BLOCK_SIZE) < (dim_end - dim_start)
+
+    # Calculate input offset
+    input_offset = (
+        input_pos * num_heads * head_dim
+        + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * head_dim
+        + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]
+    )
+
+    # Load data
+    data = tl.load(
+        q_ptr + input_offset,
+        mask=head_mask[:, None] & dim_mask[None, :],
+        other=0.0,
+    )
+
+    # Calculate output offset
+    output_offset = (
+        batch_id * max_seq_len * num_heads * head_dim
+        + seq_pos * num_heads * head_dim
+        + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * head_dim
+        + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]
+    )
+
+    # Store data
+    tl.store(
+        padded_q_ptr + output_offset,
+        data,
+        mask=head_mask[:, None] & dim_mask[None, :],
+    )
+
+
+@cached_triton_kernel(lambda _, kwargs: (kwargs["BLOCK_SIZE"]))
+@triton.jit
+def unpad_draft_extend_output_kernel(
+    raw_out_ptr,  # Input raw output tensor (batch_size, token_per_batch, tp_q_head_num, v_head_dim)
+    output_ptr,  # Output tensor (-1, tp_q_head_num, v_head_dim)
+    accept_length_ptr,  # Accept lengths for each sequence [batch_size]
+    cumsum_ptr,  # Cumulative sum of accept lengths [batch_size + 1]
+    batch_size,
+    token_per_batch,
+    tp_q_head_num,
+    v_head_dim,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """Triton kernel for unpadding draft extended output tensor with parallelized head and dim processing."""
+    batch_seq_pid = tl.program_id(0)
+    head_pid = tl.program_id(1)
+    dim_pid = tl.program_id(2)
+
+    batch_id = batch_seq_pid // token_per_batch
+    seq_pos = batch_seq_pid % token_per_batch
+
+    if batch_id >= batch_size:
+        return
+
+    # Load accept length for this batch
+    accept_len = tl.load(accept_length_ptr + batch_id)
+
+    if seq_pos >= accept_len:
+        return
+
+    # Load cumulative sum to get start position in output tensor
+    output_start = tl.load(cumsum_ptr + batch_id)
+    output_pos = output_start + seq_pos
+
+    # Calculate head and dim block ranges
+    head_start = head_pid * BLOCK_SIZE
+    head_end = tl.minimum(head_start + BLOCK_SIZE, tp_q_head_num)
+    head_mask = tl.arange(0, BLOCK_SIZE) < (head_end - head_start)
+
+    dim_start = dim_pid * BLOCK_SIZE
+    dim_end = tl.minimum(dim_start + BLOCK_SIZE, v_head_dim)
+    dim_mask = tl.arange(0, BLOCK_SIZE) < (dim_end - dim_start)
+
+    # Calculate input offset: (batch_id, seq_pos, head_id, dim_id)
+    input_offset = (
+        batch_id * token_per_batch * tp_q_head_num * v_head_dim
+        + seq_pos * tp_q_head_num * v_head_dim
+        + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * v_head_dim
+        + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]
+    )
+
+    # Load data
+    data = tl.load(
+        raw_out_ptr + input_offset,
+        mask=head_mask[:, None] & dim_mask[None, :],
+        other=0.0,
+    )
+
+    output_offset = (
+        output_pos * tp_q_head_num * v_head_dim
+        + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * v_head_dim
+        + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]
+    )
+
+    # Store data
+    tl.store(
+        output_ptr + output_offset,
+        data,
+        mask=head_mask[:, None] & dim_mask[None, :],
+    )
+
+
 global_zero_init_workspace_buffer = None
 
 
@@ -60,7 +214,11 @@ class TRTLLMMLADecodeMetadata:
     """Metadata for TRTLLM MLA decode operations."""
 
     block_kv_indices: Optional[torch.Tensor] = None
-    max_seq_len: Optional[int] = None
+    max_seq_len_k: Optional[int] = None
+    max_seq_len_q: Optional[int] = None
+    sum_seq_lens_q: Optional[int] = None
+    cu_seqlens_q: Optional[torch.Tensor] = None
+    seq_lens_q: Optional[torch.Tensor] = None
 
 
 class TRTLLMMLABackend(FlashInferMLAAttnBackend):
@@ -115,12 +273,16 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
         self.decode_cuda_graph_kv_indices = None
+        self.padded_q_buffer = None
+        self.unpad_output_buffer = None
         self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
         self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
 
-        self.disable_chunked_prefix_cache = global_server_args_dict[
-            "disable_chunked_prefix_cache"
-        ]
+        self.disable_chunked_prefix_cache = (
+            get_global_server_args().disable_chunked_prefix_cache
+        )
+
+        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
 
     def _calc_padded_blocks(self, max_seq_len: int) -> int:
         """
@@ -136,9 +298,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
 
         # Apply dual constraints (take LCM to satisfy both):
         # 1. TRT-LLM: block_num % (128 / page_size) == 0
-        # 2. Triton: page table builder uses 64-index bursts, needs multiple of 64
+        # 2. Triton: number of pages per block
         trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size
-        constraint_lcm = math.lcm(trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK)
+        triton_constraint = get_num_page_per_block_flashmla(self.page_size)
+        constraint_lcm = math.lcm(trtllm_constraint, triton_constraint)
 
         if blocks % constraint_lcm != 0:
             blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm
@@ -177,7 +340,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             block_kv_indices,
             self.req_to_token.stride(0),
             max_blocks,
-            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
             PAGED_SIZE=self.page_size,
         )
 
@@ -196,6 +358,21 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         self.decode_cuda_graph_kv_indices = torch.full(
             (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
         )
+        num_tokens_per_bs = max_num_tokens // max_bs
+
+        # Buffer for padded query: (max_bs, max_draft_tokens, num_q_heads, v_head_dim)
+        self.padded_q_buffer = torch.zeros(
+            (max_bs, num_tokens_per_bs, self.num_q_heads, self.kv_cache_dim),
+            dtype=self.data_type,
+            device=self.device,
+        )
+
+        # Buffer for unpadded output: (max_num_tokens, num_q_heads, v_head_dim)
+        self.unpad_output_buffer = torch.zeros(
+            (max_num_tokens, self.num_q_heads, 512),
+            dtype=self.data_type,
+            device=self.device,
+        )
 
         super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
 
@@ -207,12 +384,16 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         """Initialize metadata for CUDA graph capture."""
 
         # Delegate to parent for non-decode modes.
-        if not forward_mode.is_decode_or_idle():
+        if (
+            not forward_mode.is_decode_or_idle()
+            and not forward_mode.is_target_verify()
+            and not forward_mode.is_draft_extend(include_v2=True)
+        ):
             return super().init_forward_metadata_capture_cuda_graph(
                 bs,
                 num_tokens,
@@ -223,6 +404,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 spec_info,
             )
 
+        if forward_mode.is_target_verify():
+            seq_lens = seq_lens + self.num_draft_tokens
+
         # Custom fast-path for decode/idle.
         # Capture with full width so future longer sequences are safe during replay
         max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
@@ -236,7 +420,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             block_kv_indices,
             self.req_to_token.stride(0),
             max_blocks_per_seq,
-            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
             PAGED_SIZE=self.page_size,
         )
 
@@ -249,6 +432,20 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             block_kv_indices,
             max_seq_len_val,
         )
+        if forward_mode.is_draft_extend(include_v2=True):
+            num_tokens_per_bs = num_tokens // bs
+            metadata.max_seq_len_q = num_tokens_per_bs + 1
+            metadata.sum_seq_lens_q = num_tokens_per_bs * bs
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                bs * num_tokens_per_bs + 1,
+                num_tokens_per_bs,
+                dtype=torch.int32,
+                device=seq_lens.device,
+            )
+            metadata.seq_lens_q = torch.full(
+                (bs,), num_tokens_per_bs, dtype=torch.int32, device=seq_lens.device
+            )
         self.decode_cuda_graph_metadata[bs] = metadata
         self.forward_decode_metadata = metadata
 
@@ -260,12 +457,16 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
    ):
         """Replay CUDA graph with new inputs."""
         # Delegate to parent for non-decode modes.
-        if not forward_mode.is_decode_or_idle():
+        if (
+            not forward_mode.is_decode_or_idle()
+            and not forward_mode.is_target_verify()
+            and not forward_mode.is_draft_extend(include_v2=True)
+        ):
             return super().init_forward_metadata_replay_cuda_graph(
                 bs,
                 req_pool_indices,
@@ -277,8 +478,25 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 seq_lens_cpu,
             )
 
+        if forward_mode.is_target_verify():
+            seq_lens = seq_lens + self.num_draft_tokens
+            del seq_lens_sum  # does not account for num_draft_tokens, but it is unused here
+
         metadata = self.decode_cuda_graph_metadata[bs]
 
+        if forward_mode.is_draft_extend(include_v2=True):
+            accept_length = spec_info.accept_length[:bs]
+            if spec_info.accept_length_cpu:
+                metadata.max_seq_len_q = max(spec_info.accept_length_cpu[:bs])
+                metadata.sum_seq_lens_q = sum(spec_info.accept_length_cpu[:bs])
+            else:
+                metadata.max_seq_len_q = 1
+                metadata.sum_seq_lens_q = bs
+            metadata.cu_seqlens_q[1:].copy_(
+                torch.cumsum(accept_length, dim=0, dtype=torch.int32)
+            )
+            metadata.seq_lens_q.copy_(accept_length)
+
         # Update block indices for new sequences.
         create_flashmla_kv_indices_triton[(bs,)](
             self.req_to_token,
@@ -288,7 +506,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             metadata.block_kv_indices,
             self.req_to_token.stride(0),
             metadata.block_kv_indices.shape[1],
-            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
             PAGED_SIZE=self.page_size,
         )
 
@@ -309,7 +526,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         if (
             forward_batch.forward_mode.is_extend()
             and not forward_batch.forward_mode.is_target_verify()
-            and not forward_batch.forward_mode.is_draft_extend()
+            and not forward_batch.forward_mode.is_draft_extend(include_v2=True)
         ):
             if self.disable_chunked_prefix_cache:
                 super().init_forward_metadata(forward_batch)
@@ -327,7 +544,11 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 cum_seq_lens_q,
                 seq_lens,
             )
-        elif forward_batch.forward_mode.is_decode_or_idle():
+        elif (
+            forward_batch.forward_mode.is_decode_or_idle()
+            or forward_batch.forward_mode.is_target_verify()
+            or forward_batch.forward_mode.is_draft_extend(include_v2=True)
+        ):
             bs = forward_batch.batch_size
 
             # Get maximum sequence length.
@@ -336,19 +557,42 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             else:
                 max_seq = forward_batch.seq_lens.max().item()
 
+            seq_lens = forward_batch.seq_lens
+
+            if forward_batch.forward_mode.is_target_verify():
+                max_seq = max_seq + self.num_draft_tokens
+                seq_lens = seq_lens + self.num_draft_tokens
+
             max_seqlen_pad = self._calc_padded_blocks(max_seq)
             block_kv_indices = self._create_block_kv_indices(
                 bs,
                 max_seqlen_pad,
                 forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                forward_batch.seq_lens.device,
+                seq_lens,
+                seq_lens.device,
             )
 
             max_seq_len_val = int(max_seq)
             self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
                 block_kv_indices, max_seq_len_val
             )
+            if forward_batch.forward_mode.is_draft_extend(include_v2=True):
+                max_seq = forward_batch.seq_lens_cpu.max().item()
+
+                sum_seq_lens_q = sum(forward_batch.extend_seq_lens_cpu)
+                max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
+                cu_seqlens_q = torch.nn.functional.pad(
+                    torch.cumsum(
+                        forward_batch.extend_seq_lens, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+
+                self.forward_decode_metadata.max_seq_len_q = max_seq_len_q
+                self.forward_decode_metadata.sum_seq_lens_q = sum_seq_lens_q
+                self.forward_decode_metadata.cu_seqlens_q = cu_seqlens_q
+                self.forward_decode_metadata.seq_lens_q = forward_batch.extend_seq_lens
+
             forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
         else:
             return super().init_forward_metadata(forward_batch)
@@ -434,6 +678,86 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
 
         return q_out, k_nope_out, k_rope_out
 
+    def pad_draft_extend_query(
+        self,
+        q: torch.Tensor,
+        padded_q: torch.Tensor,
+        seq_lens_q: torch.Tensor,
+        cu_seqlens_q: torch.Tensor,
+    ) -> torch.Tensor:
+        """Pad draft extended query using Triton kernel."""
+        batch_size = cu_seqlens_q.shape[0] - 1
+        max_seq_len_q = padded_q.shape[1]
+        num_heads = padded_q.shape[2]
+        head_dim = padded_q.shape[3]
+
+        # Launch Triton kernel with 3D grid for parallelized head and dim processing
+        BLOCK_SIZE = 64
+        num_head_blocks = triton.cdiv(num_heads, BLOCK_SIZE)
+        num_dim_blocks = triton.cdiv(head_dim, BLOCK_SIZE)
+        grid = (batch_size * max_seq_len_q, num_head_blocks, num_dim_blocks)
+
+        pad_draft_extend_query_kernel[grid](
+            q_ptr=q,
+            padded_q_ptr=padded_q,
+            seq_lens_q_ptr=seq_lens_q,
+            cumsum_ptr=cu_seqlens_q,
+            batch_size=batch_size,
+            max_seq_len=max_seq_len_q,
+            num_heads=num_heads,
+            head_dim=head_dim,
+            BLOCK_SIZE=BLOCK_SIZE,
+        )
+        return padded_q
+
+    def unpad_draft_extend_output(
+        self,
+        raw_out: torch.Tensor,
+        cu_seqlens_q: torch.Tensor,
+        seq_lens_q: torch.Tensor,
+        sum_seq_lens_q: int,
+    ) -> torch.Tensor:
+        """Unpad draft extended output using Triton kernel."""
+        # raw_out: (batch_size, token_per_batch, layer.tp_q_head_num, layer.v_head_dim)
+        batch_size = seq_lens_q.shape[0]
+        token_per_batch = raw_out.shape[1]  # max_seq_len
+        tp_q_head_num = raw_out.shape[2]  # num_heads
+        v_head_dim = raw_out.shape[3]  # head_dim
+        total_tokens = sum_seq_lens_q
+
+        # Check if we're in CUDA graph mode (buffers are pre-allocated)
+        if self.unpad_output_buffer is not None:
+            # Use pre-allocated buffer for CUDA graph compatibility
+            output = self.unpad_output_buffer[:total_tokens, :, :].to(
+                dtype=raw_out.dtype
+            )
+        else:
+            # Dynamic allocation for non-CUDA graph mode
+            output = torch.empty(
+                (total_tokens, tp_q_head_num, v_head_dim),
+                dtype=raw_out.dtype,
+                device=raw_out.device,
+            )
+
+        # Launch Triton kernel with 3D grid for parallelized head and dim processing
+        BLOCK_SIZE = 64
+        num_head_blocks = triton.cdiv(tp_q_head_num, BLOCK_SIZE)
+        num_dim_blocks = triton.cdiv(v_head_dim, BLOCK_SIZE)
+        grid = (batch_size * token_per_batch, num_head_blocks, num_dim_blocks)
+
+        unpad_draft_extend_output_kernel[grid](
+            raw_out_ptr=raw_out,
+            output_ptr=output,
+            accept_length_ptr=seq_lens_q,
+            cumsum_ptr=cu_seqlens_q,
+            batch_size=batch_size,
+            token_per_batch=token_per_batch,
+            tp_q_head_num=tp_q_head_num,
+            v_head_dim=v_head_dim,
+            BLOCK_SIZE=BLOCK_SIZE,
+        )
+        return output[:total_tokens, :, :]
+
     def forward_decode(
         self,
         q: torch.Tensor,  # q_nope
@@ -482,7 +806,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             q_rope_reshaped = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
-            query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
+            query = _concat_mla_absorb_q_general(q_nope, q_rope_reshaped)
         else:
             # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
             query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
@@ -527,7 +851,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             qk_rope_head_dim=self.qk_rope_head_dim,
             block_tables=metadata.block_kv_indices,
             seq_lens=forward_batch.seq_lens.to(torch.int32),
-            max_seq_len=metadata.max_seq_len,
+            max_seq_len=metadata.max_seq_len_k,
             bmm1_scale=bmm1_scale,
         )
 
@@ -545,49 +869,193 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         save_kv_cache: bool = True,
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        cos_sin_cache: Optional[torch.Tensor] = None,
+        is_neox: Optional[bool] = False,
     ) -> torch.Tensor:
+        # TODO refactor to avoid code duplication
+        merge_query = q_rope is not None
+        if (
+            self.data_type == torch.float8_e4m3fn
+        ) and forward_batch.forward_mode.is_target_verify():
+            # For FP8 path, we quantize the query and rope parts and merge them into a single tensor
+            # Note: rope application in deepseek_v2.py:forward_absorb_prepare is skipped for FP8 decode path of this trtllm_mla backend
+            assert all(
+                x is not None for x in [q_rope, k_rope, cos_sin_cache]
+            ), "For FP8 path and using flashinfer.rope.mla_rope_quantize we need all of q_rope, k_rope and cos_sin_cache to be not None."
+            q, k, k_rope = self.quantize_and_rope_for_fp8(
+                q,
+                q_rope,
+                k.squeeze(1),
+                k_rope.squeeze(1),
+                forward_batch,
+                cos_sin_cache,
+                is_neox,
+            )
+            merge_query = False
+
+        # Save KV cache if requested
+        if save_kv_cache:
+            assert (
+                k is not None and k_rope is not None
+            ), "For populating trtllm_mla kv cache, both k_nope and k_rope should be not None."
+            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, k_rope
+            )
+
+        # TODO refactor to avoid code duplication
+        # Prepare query tensor inline
+        if merge_query:
+            # For FP16 path, we merge the query and rope parts into a single tensor
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope_reshaped = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+            q = _concat_mla_absorb_q_general(q_nope, q_rope_reshaped)
+        else:
+            # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
+            q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+
+        q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+
+        if k_rope is not None:
+            k = torch.cat([k, k_rope], dim=-1)
+        k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
+
+        v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
+
         if (
             forward_batch.forward_mode.is_target_verify()
-            or forward_batch.forward_mode.is_draft_extend()
+            or forward_batch.forward_mode.is_draft_extend(include_v2=True)
         ):
-            return super().forward_extend(
-                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            metadata = (
+                getattr(forward_batch, "decode_trtllm_mla_metadata", None)
+                or self.forward_decode_metadata
             )
-        # chunked prefix cache is not enabled, use Flashinfer MLA prefill kernel
-        if forward_batch.attn_attend_prefix_cache is None:
-            return super().forward_extend(
-                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+
+            # Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim]
+            bs = forward_batch.batch_size
+
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
+
+            q_scale = 1.0
+            k_scale = (
+                layer.k_scale_float
+                if getattr(layer, "k_scale_float", None) is not None
+                else 1.0
             )
+            q = q.to(self.data_type)
 
-        if not forward_batch.attn_attend_prefix_cache:
-            q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-            k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
-            v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
-            output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
+            bmm1_scale = q_scale * k_scale * layer.scaling
+            if forward_batch.forward_mode.is_target_verify():
+                seq_lens = (
+                    forward_batch.seq_lens.to(torch.int32)
+                    + forward_batch.spec_info.draft_token_num
+                )
+                max_seq_len = (
+                    metadata.max_seq_len_k + forward_batch.spec_info.draft_token_num
+                )
+            else:
+                seq_lens = forward_batch.seq_lens.to(torch.int32)
+                max_seq_len = metadata.max_seq_len_k
+                # Check if we're in CUDA graph mode (buffers are pre-allocated)
+                if self.padded_q_buffer is not None:
+                    # Use pre-allocated buffer for CUDA graph compatibility
+                    padded_q = self.padded_q_buffer[
+                        :bs, : metadata.max_seq_len_q, :, :
+                    ].to(dtype=q.dtype)
+                else:
+                    # Dynamic allocation for non-CUDA graph mode
+                    padded_q = torch.zeros(
+                        bs,
+                        metadata.max_seq_len_q,
+                        layer.tp_q_head_num,
+                        layer.head_dim,
+                        dtype=q.dtype,
+                        device=q.device,
+                    )
+                q = self.pad_draft_extend_query(
+                    q, padded_q, metadata.seq_lens_q, metadata.cu_seqlens_q
+                )
+
+            # TODO may use `mla_rope_quantize_fp8` fusion
+            q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
+            assert kv_cache.dtype == self.data_type
+
+            raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
+                query=q,
+                kv_cache=kv_cache,
+                workspace_buffer=self.workspace_buffer,
+                qk_nope_head_dim=self.qk_nope_head_dim,
+                kv_lora_rank=self.kv_lora_rank,
+                qk_rope_head_dim=self.qk_rope_head_dim,
+                block_tables=metadata.block_kv_indices,
+                seq_lens=seq_lens,
+                max_seq_len=max_seq_len,
+                bmm1_scale=bmm1_scale,
+            )
+
+            # Reshape output directly without slicing
+
+            if forward_batch.forward_mode.is_draft_extend(include_v2=True):
+                raw_out = self.unpad_draft_extend_output(
+                    raw_out,
+                    metadata.cu_seqlens_q,
+                    metadata.seq_lens_q,
+                    metadata.sum_seq_lens_q,
+                )
+            output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+            return output
+
+        if forward_batch.attn_attend_prefix_cache:
+            # MHA for chunked prefix kv cache when running model with MLA
+            assert forward_batch.prefix_chunk_idx is not None
+            assert forward_batch.prefix_chunk_cu_seq_lens is not None
+            assert q_rope is None
+            assert k_rope is None
+            chunk_idx = forward_batch.prefix_chunk_idx
+
+            output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
+            return flashinfer.prefill.trtllm_ragged_attention_deepseek(
                 query=q,
                 key=k,
                 value=v,
                 workspace_buffer=self.workspace_buffer,
-                seq_lens=self.forward_prefill_metadata.seq_lens,
+                seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
                 max_q_len=self.forward_prefill_metadata.max_seq_len,
-                max_kv_len=self.forward_prefill_metadata.max_seq_len,
+                max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
                 bmm1_scale=layer.scaling,
                 bmm2_scale=1.0,
-                o_sf_scale=1.0,
+                o_sf_scale=-1.0,
                 batch_size=forward_batch.batch_size,
                 window_left=-1,
                 cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
-                cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
+                cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
                 enable_pdl=False,
-                is_causal=True,
-                return_lse=forward_batch.mha_return_lse,
-            )
-        else:
-            # replace with trtllm ragged attention once accuracy is resolved.
-            output = super().forward_extend(
-                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+                is_causal=False,
+                return_lse=True,
+                out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
             )
-        return output
+
+        return flashinfer.prefill.trtllm_ragged_attention_deepseek(
+            query=q,
+            key=k,
+            value=v,
+            workspace_buffer=self.workspace_buffer,
+            seq_lens=self.forward_prefill_metadata.seq_lens,
+            max_q_len=self.forward_prefill_metadata.max_seq_len,
+            max_kv_len=self.forward_prefill_metadata.max_seq_len,
+            bmm1_scale=layer.scaling,
+            bmm2_scale=1.0,
+            o_sf_scale=1.0,
+            batch_size=forward_batch.batch_size,
+            window_left=-1,
+            cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
+            cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
            enable_pdl=False,
            is_causal=True,
            return_lse=forward_batch.mha_return_lse,
        )
 
 
 class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
@@ -598,10 +1066,17 @@ class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
     ):
         super().__init__(model_runner, topk, speculative_num_steps)
 
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i] = TRTLLMMLABackend(
                 model_runner,
                 skip_prefill=True,
                 kv_indptr_buf=self.kv_indptr[i],
                 q_indptr_decode_buf=self.q_indptr_decode,
             )
+
+
+def _concat_mla_absorb_q_general(q_nope, q_rope):
+    if _is_cuda and q_nope.shape[-1] == 512 and q_rope.shape[-1] == 64:
+        return concat_mla_absorb_q(q_nope, q_rope)
+    else:
+        return torch.cat([q_nope, q_rope], dim=-1)
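For orientation, the two Triton kernels added in this diff form a pad/unpad pair: pad_draft_extend_query_kernel scatters a ragged batch of draft-extend queries (per-sequence lengths seq_lens_q, offsets cu_seqlens_q) into a fixed-shape buffer that trtllm_batch_decode_with_kv_cache_mla and CUDA graph replay can consume, and unpad_draft_extend_output_kernel gathers the padded attention output back into the ragged layout. The sketch below restates that semantics in plain PyTorch as a reference; it is illustrative only and not part of the wheel, and the helper names pad_ragged and unpad_ragged are hypothetical.

import torch

def pad_ragged(q, seq_lens_q, cu_seqlens_q, max_seq_len_q):
    # q: (total_tokens, num_heads, head_dim) -> (batch, max_seq_len_q, num_heads, head_dim)
    bs = seq_lens_q.numel()
    padded = q.new_zeros(bs, max_seq_len_q, *q.shape[1:])
    for b in range(bs):
        n, start = int(seq_lens_q[b]), int(cu_seqlens_q[b])
        padded[b, :n] = q[start : start + n]  # copy this sequence's rows; rest stays zero
    return padded

def unpad_ragged(raw_out, seq_lens_q, cu_seqlens_q, sum_seq_lens_q):
    # raw_out: (batch, max_seq_len_q, heads, dim) -> (sum_seq_lens_q, heads, dim)
    out = raw_out.new_empty(sum_seq_lens_q, *raw_out.shape[2:])
    for b in range(seq_lens_q.numel()):
        n, start = int(seq_lens_q[b]), int(cu_seqlens_q[b])
        out[start : start + n] = raw_out[b, :n]  # drop the padding rows
    return out

# Round trip over a ragged batch of 3 sequences (lengths 3, 1, 2):
seq_lens_q = torch.tensor([3, 1, 2])
cu_seqlens_q = torch.tensor([0, 3, 4, 6])  # exclusive prefix sum, as in the metadata above
q = torch.randn(6, 16, 576)  # (total_tokens, tp_q_head_num, head_dim)
padded = pad_ragged(q, seq_lens_q, cu_seqlens_q, max_seq_len_q=4)
assert torch.equal(unpad_ragged(padded, seq_lens_q, cu_seqlens_q, 6), q)

The Triton versions do the same copies, but parallelize over a 3D grid of (batch position, head block, dim block) with BLOCK_SIZE=64 tiles and write into the pre-allocated padded_q_buffer / unpad_output_buffer so shapes stay static under CUDA graph capture.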