sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1028 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional

import torch

from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.flashattention_backend import (
    FlashAttentionMetadata,
    make_local_attention_virtual_batches,
    merge_state_v2_wrapper,
    prepare_swa_spec_page_table_triton,
)
from sglang.srt.managers.schedule_batch import get_global_server_args
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner

from sgl_kernel import merge_state_v2
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache


class XPUAttentionBackend(AttentionBackend):
    """XPU FlashAttention backend, currently based on FlashAttentionBackend, will be refactored later.

    TODO:
    - Prefill and Decode disaggregation, currently only chunked prefill is supported
    - Speculative Decoding support
    - XPU Graph support, see https://github.com/pytorch/pytorch/issues/162143
    - MLA support
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        speculative_step_id=0,
        topk=0,
        speculative_num_steps=0,
    ):
        super().__init__()

        assert not (
            model_runner.sliding_window_size is not None
            and model_runner.model_config.is_encoder_decoder
        ), "Sliding window and cross attention are not supported together"

        self.forward_metadata: FlashAttentionMetadata = None
        # extra metadata for handling speculative decoding topk > 1, extended draft decode and verify
        self.forward_metadata_spec_decode_expand: FlashAttentionMetadata = None
        self.max_context_len = model_runner.model_config.context_len
        self.device = model_runner.device
        self.decode_cuda_graph_metadata = {}
        self.target_verify_metadata = {}
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.kv_cache_dtype = model_runner.kv_cache_dtype
        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
        self.page_size = model_runner.page_size
        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
        assert (
            self.use_mla is False
        ), "XPUAttentionBackend doesn't support MLA yet, please use --attention-backend triton instead."
        self.skip_prefill = skip_prefill
        self.is_hybrid = model_runner.is_hybrid
        if self.is_hybrid:
            self.full_to_swa_index_mapping = (
                model_runner.token_to_kv_pool.full_to_swa_index_mapping
            )
        self.topk = model_runner.server_args.speculative_eagle_topk or 0
        self.speculative_num_steps = speculative_num_steps
        self.speculative_num_draft_tokens = (
            model_runner.server_args.speculative_num_draft_tokens
        )
        self.speculative_step_id = speculative_step_id

        # Local attention settings
        self.attention_chunk_size = (
            model_runner.attention_chunk_size
            if hasattr(model_runner, "attention_chunk_size")
            else None
        )

        # For each layer, the sliding_window_size can be different. This is only used for preparing SWA metadata.
        # We use `layer.sliding_window_size` to decide whether to use SWA for each layer.
        self.sliding_window_size = model_runner.sliding_window_size
        self.has_swa = (
            self.sliding_window_size is not None and self.sliding_window_size > -1
        )

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Initialize forward metadata hence all layers in the forward pass can reuse it."""
        metadata = FlashAttentionMetadata()
        seqlens_in_batch = forward_batch.seq_lens
        batch_size = forward_batch.batch_size
        device = seqlens_in_batch.device

        if forward_batch.forward_mode.is_decode_or_idle():
            # Draft Decode
            if forward_batch.spec_info is not None:
                assert (
                    False
                ), "XPUAttentionBackend doesn't support speculative decoding yet, please use --attention-backend triton instead."
                if self.topk <= 1:
                    metadata.cache_seqlens_int32 = (
                        seqlens_in_batch + (self.speculative_step_id + 1)
                    ).to(torch.int32)
                    metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
                        self.speculative_step_id + 1
                    )
                    metadata.cu_seqlens_q = torch.arange(
                        0, batch_size + 1, dtype=torch.int32, device=device
                    )
                    metadata.cu_seqlens_k = torch.nn.functional.pad(
                        torch.cumsum(
                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                        ),
                        (1, 0),
                    )
                    metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                        forward_batch.req_pool_indices, : metadata.max_seq_len_k
                    ]
                else:
                    metadata.cache_seqlens_int32 = (seqlens_in_batch).to(torch.int32)
                    metadata.max_seq_len_q = self.topk
                    metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
                    metadata.cu_seqlens_q = torch.arange(
                        0,
                        batch_size * self.topk + 1,
                        step=self.topk,
                        dtype=torch.int32,
                        device=device,
                    )
                    metadata.cu_seqlens_k = torch.nn.functional.pad(
                        torch.cumsum(
                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                        ),
                        (1, 0),
                    )
                    metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                        forward_batch.req_pool_indices, : metadata.max_seq_len_k
                    ]

                    metadata_expand = FlashAttentionMetadata()
                    decode_length = self.speculative_step_id + 1
                    metadata_expand.cache_seqlens_int32 = torch.full(
                        (seqlens_in_batch.numel() * self.topk,),
                        decode_length,
                        device=device,
                        dtype=torch.int32,
                    )
                    metadata_expand.max_seq_len_q = 1
                    metadata_expand.cu_seqlens_q = torch.arange(
                        0,
                        metadata_expand.cache_seqlens_int32.numel() + 1,
                        dtype=torch.int32,
                        device=device,
                    )
                    metadata_expand.cu_seqlens_k = torch.arange(
                        0,
                        metadata_expand.cache_seqlens_int32.numel() * decode_length + 1,
                        step=decode_length,
                        dtype=torch.int32,
                        device=device,
                    )
                    # shape: [bs, num_steps, topk] -> [bs x topk, num_steps]
                    cache_loc = forward_batch.out_cache_loc.view(
                        -1, self.speculative_num_steps
                    )
                    metadata_expand.page_table = (
                        cache_loc[:, :decode_length].contiguous().to(torch.int32)
                    )
                    self.forward_metadata_spec_decode_expand = metadata_expand
            else:
                # Normal Decode
                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
                metadata.cu_seqlens_q = torch.arange(
                    0, batch_size + 1, dtype=torch.int32, device=device
                )
                metadata.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
                )
                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
                ]
            # TODO: we need to test this part for llama 4 eagle case
            self._init_local_attn_metadata(forward_batch, metadata, device)
        elif forward_batch.forward_mode.is_target_verify():
            if self.topk <= 1:
                metadata.cache_seqlens_int32 = (
                    forward_batch.seq_lens + self.speculative_num_draft_tokens
                ).to(torch.int32)
                metadata.max_seq_len_q = self.speculative_num_draft_tokens
                metadata.max_seq_len_k = (
                    forward_batch.seq_lens_cpu.max().item()
                    + self.speculative_num_draft_tokens
                )
                metadata.cu_seqlens_q = torch.arange(
                    0,
                    batch_size * self.speculative_num_draft_tokens + 1,
                    self.speculative_num_draft_tokens,
                    dtype=torch.int32,
                    device=device,
                )
                metadata.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(
                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                    ),
                    (1, 0),
                )
                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
                ]

                self._init_local_attn_metadata(forward_batch, metadata, device)
            else:
                metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
                metadata.max_seq_len_q = self.speculative_num_draft_tokens
                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
                metadata.cu_seqlens_q = torch.arange(
                    0,
                    batch_size * self.speculative_num_draft_tokens + 1,
                    step=self.speculative_num_draft_tokens,
                    dtype=torch.int32,
                    device=device,
                )
                metadata.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(
                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                    ),
                    (1, 0),
                )
                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
                ]

                metadata_expand = FlashAttentionMetadata()

                metadata_expand.max_seq_len_q = 1
                metadata_expand.cu_seqlens_q = torch.arange(
                    0,
                    forward_batch.seq_lens.numel() * self.speculative_num_draft_tokens
                    + 1,
                    dtype=torch.int32,
                    device=device,
                )

                # create expand page table
                offsets = torch.arange(
                    self.speculative_num_draft_tokens, device=device
                ).unsqueeze(
                    0
                )  # shape: (1, self.speculative_num_draft_tokens)
                cols = offsets.expand(
                    forward_batch.seq_lens.numel(), -1
                ) + forward_batch.seq_lens.unsqueeze(1)
                cum_len = torch.nn.functional.pad(
                    torch.cumsum(
                        (
                            forward_batch.seq_lens + self.speculative_num_draft_tokens
                        ).repeat_interleave(self.speculative_num_draft_tokens),
                        dim=0,
                    ),
                    (1, 0),
                )[:-1]
                mask_extraction_indices = (
                    cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
                    + cum_len[:, None]
                ).view(1, -1)
                mask = forward_batch.spec_info.custom_mask[
                    mask_extraction_indices
                ].view(
                    -1, self.speculative_num_draft_tokens
                )  # (bsz * draft_num, draft_num)

                # shift table indices to avoid padding
                # non_masked_page_table [[8, 9, 10], mask (display with int format) [[1, 0, 0],
                #                        [8, 9, 10],                                 [1, 1, 0],
                #                        [8, 9, 10]]                                 [1, 0, 1]]
                # if masked with padding [[8, 0, 0], our mask without padding [[8, 9, 10],
                #                         [8, 9, 0],                           [8, 9, 10],
                #                         [8, 0, 10]]                          [8, 10, 9]]
                # note here cache_seqlens_int32 is [1, 2, 2] so extra page indices will be ignored in each row
                col_indices = offsets.expand(
                    mask.shape[0], self.speculative_num_draft_tokens
                )
                # Build keys: if an entry is valid (mask==True), keep its original index;
                # if not, add self.speculative_num_draft_tokens so that it sorts after all valid entries.
                keys = torch.where(
                    mask, col_indices, col_indices + self.speculative_num_draft_tokens
                )
                _, sort_order = torch.sort(keys, dim=1)
                non_masked_page_table = (
                    forward_batch.req_to_token_pool.req_to_token[
                        forward_batch.req_pool_indices, :
                    ]
                    .gather(1, cols)
                    .repeat_interleave(self.speculative_num_draft_tokens, dim=0)
                )  # (bsz, draft_num)
                metadata_expand.page_table = non_masked_page_table.gather(1, sort_order)
                metadata_expand.cache_seqlens_int32 = mask.sum(dim=1).to(torch.int32)
                metadata_expand.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(
                        metadata_expand.cache_seqlens_int32, dim=0, dtype=torch.int32
                    ),
                    (1, 0),
                )
                self.forward_metadata_spec_decode_expand = metadata_expand

                if self.has_swa:
                    self._init_sliding_window_attn_spec_metadata(
                        metadata, metadata_expand
                    )

        elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
            metadata.cu_seqlens_k = torch.nn.functional.pad(
                torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
            )
            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices, : metadata.max_seq_len_k
            ]

            if (
                any(forward_batch.extend_prefix_lens_cpu)
                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
            ):
                extend_seq_lens = forward_batch.extend_seq_lens
                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                metadata.cu_seqlens_q = torch.nn.functional.pad(
                    torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
                )
            else:
                metadata.max_seq_len_q = metadata.max_seq_len_k
                metadata.cu_seqlens_q = metadata.cu_seqlens_k

            # Setup local attention if enabled
            if forward_batch.forward_mode == ForwardMode.EXTEND:
                self._init_local_attn_metadata(forward_batch, metadata, device)

        # Encoder metadata for cross attention
        if forward_batch.encoder_lens is not None:
            assert (
                forward_batch.encoder_lens.numel() == 1
            ), "Only encoder size 1 is supported for now"

            metadata.encoder_lens_int32 = forward_batch.encoder_lens.to(torch.int32)
            metadata.encoder_cu_seqlens_k = torch.nn.functional.pad(
                torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
                (1, 0),
            )
            metadata.encoder_max_seq_len_k = metadata.encoder_lens_int32.max().item()
            metadata.encoder_page_table = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices, : metadata.encoder_max_seq_len_k
            ]

            # Currently only support forward_batch.encoder_lens.numel() == 1
            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices,
                metadata.encoder_max_seq_len_k : (
                    metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
                ),
            ]

        # Convert the page table to a strided format which is needed by FA3 API
        if self.page_size > 1:
            self.strided_indices = torch.arange(
                0, metadata.page_table.shape[1], self.page_size, device=self.device
            )
            metadata.page_table = (
                metadata.page_table[:, self.strided_indices] // self.page_size
            )

        self.forward_metadata = metadata

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
        # For multi-head latent attention
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
        sinks: Optional[torch.Tensor] = None,
    ):
        if k is not None:
            assert v is not None
            if save_kv_cache:
                cache_loc = (
                    forward_batch.out_cache_loc
                    if not layer.is_cross_attention
                    else forward_batch.encoder_out_cache_loc
                )
                if not self.use_mla:
                    forward_batch.token_to_kv_pool.set_kv_buffer(
                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                    )
                else:
                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                        layer,
                        cache_loc,
                        k,
                        k_rope,
                    )

        # Use precomputed metadata across all layers
        metadata = self.forward_metadata

        # Calculate window size (can be moved to metadata if layer properties don't change)
        # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
        # here is two side inclusive
        is_swa = (
            layer.sliding_window_size is not None and layer.sliding_window_size > -1
        )
        window_size = (layer.sliding_window_size, 0) if is_swa else (-1, -1)

        # currently no FP8 KV cache supported
        k_descale, v_descale = None, None
        # # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
        # # has corresponding quantization method so that layer.k_scale is not None,
        # # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
        # if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
        #     if layer.k_scale is not None:
        #         descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
        #         k_descale = layer.k_scale.expand(descale_shape)
        #         v_descale = layer.v_scale.expand(descale_shape)
        #     q = q.to(self.kv_cache_dtype)
        #     q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
        #     k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
        causal = not layer.is_cross_attention

        # Check if we should use local attention
        use_local_attn = (
            self.attention_chunk_size is not None
            and metadata.local_attn_metadata is not None
            and (hasattr(layer, "use_irope") and layer.use_irope)
        )

        # We do cascade attention for Target Verify with topk > 1
        # We don't use cascade attention for Sliding Window Attention:
        # - Different window sizes should be passed in for each q in the first stage of cascade attention, but FA3 interface doesn't support pass in a list of window sizes.
        # - The overhead of duplicated computation of the common prefix part is small for sliding window layers (seq_len <= window_size), so we can just expand it.
        use_cascade_attn = (
            forward_batch.forward_mode.is_target_verify()
            and self.topk > 1
            and not is_swa
        )

        # For fa3 interface version compatibility, we put new fields into conditional keyword args
        kwargs = {}
        if sinks is not None:
            kwargs["sinks"] = sinks

        # Get the appropriate page table based on whether we're using local attention
        if use_local_attn:
            local_metadata = metadata.local_attn_metadata
            page_table = local_metadata.local_block_table
            cu_seqlens_q = local_metadata.local_query_start_loc
            cache_seqlens = local_metadata.local_seqused_k
            max_seqlen_q = local_metadata.local_max_query_len
        elif is_swa and metadata.swa_spec_metadata is not None:
            swa_spec_metadata = metadata.swa_spec_metadata
            page_table = swa_spec_metadata.page_table
            cu_seqlens_q = swa_spec_metadata.cu_seqlens_q
            cache_seqlens = swa_spec_metadata.cache_seqlens_int32
            max_seqlen_q = swa_spec_metadata.max_seq_len_q
            cu_seqlens_k = swa_spec_metadata.cu_seqlens_k
        else:
            page_table = metadata.page_table
            cu_seqlens_q = metadata.cu_seqlens_q
            cache_seqlens = metadata.cache_seqlens_int32
            max_seqlen_q = metadata.max_seq_len_q
            cu_seqlens_k = metadata.cu_seqlens_k

        # Use Flash Attention for prefill
        if not self.use_mla:
            # Do multi-head attention
            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                layer.layer_id
            )
            key_cache = key_cache.view(
                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
            )
            value_cache = value_cache.view(
                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
            )
            if layer.is_cross_attention:
                page_table = metadata.encoder_page_table
                cache_seqlens = metadata.encoder_lens_int32
                cu_seqlens_k = metadata.encoder_cu_seqlens_k
                window_size = (-1, -1)

            result = flash_attn_with_kvcache(
                q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                k_cache=key_cache,
                v_cache=value_cache,
                page_table=page_table,
                cache_seqlens=cache_seqlens,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
                max_seqlen_q=max_seqlen_q,
                softmax_scale=layer.scaling,
                causal=False if use_cascade_attn else causal,
                window_size=window_size,
                softcap=layer.logit_cap,
                k_descale=k_descale,
                v_descale=v_descale,
                return_softmax_lse=use_cascade_attn,
                **kwargs,
            )

            if use_cascade_attn:
                o, softmax_lse, *rest = result
                o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                    k_cache=key_cache,
                    v_cache=value_cache,
                    page_table=self.forward_metadata_spec_decode_expand.page_table,
                    cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
                    cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
                    cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
                    max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
                    softmax_scale=layer.scaling,
                    causal=False,
                    window_size=window_size,
                    softcap=layer.logit_cap,
                    k_descale=k_descale,
                    v_descale=v_descale,
                    return_softmax_lse=True,
                    **kwargs,
                )
                o, _ = merge_state_v2_wrapper(
                    o,
                    softmax_lse.T.contiguous(),
                    o_expand,
                    softmax_lse_expand.T.contiguous(),
                )
            else:
                o = result
        else:
            if (
                forward_batch.attn_attend_prefix_cache is not None
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
            ):
                # Do multi-head attention with chunked prefix cache
                if forward_batch.attn_attend_prefix_cache:
                    assert not get_global_server_args().disable_chunked_prefix_cache
                    # MHA for chunked prefix kv cache when running model with MLA
                    assert forward_batch.prefix_chunk_idx is not None
                    assert forward_batch.prefix_chunk_cu_seq_lens is not None
                    assert forward_batch.prefix_chunk_max_seq_lens is not None

                    chunk_idx = forward_batch.prefix_chunk_idx
                    assert chunk_idx >= 0

                    assert forward_batch.mha_return_lse
                    output = flash_attn_varlen_func(
                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                        cu_seqlens_q=metadata.cu_seqlens_q,
                        cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
                        max_seqlen_q=metadata.max_seq_len_q,
                        max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
                        softmax_scale=layer.scaling,
                        causal=False,
                        return_softmax_lse=True,
                    )
                else:
                    # MHA for extend part of sequence without attending prefix kv cache
                    output = flash_attn_varlen_func(
                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                        cu_seqlens_q=metadata.cu_seqlens_q,
                        cu_seqlens_k=metadata.cu_seqlens_q,
                        max_seqlen_q=metadata.max_seq_len_q,
                        max_seqlen_k=metadata.max_seq_len_q,
                        softmax_scale=layer.scaling,
                        causal=True,
                        return_softmax_lse=forward_batch.mha_return_lse,
                    )
                if forward_batch.mha_return_lse:
                    output, lse, *rest = output
                    lse = torch.transpose(lse, 0, 1).contiguous()
                    return output, lse
                return output
            else:
                # Do absorbed multi-latent attention
                kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
                    layer.layer_id
                ).to(q.dtype)
                k_rope = kv_cache[:, :, layer.v_head_dim :]
                c_kv = kv_cache[:, :, : layer.v_head_dim]
                k_rope_cache = k_rope.view(
                    -1,
                    self.page_size,
                    layer.tp_k_head_num,
                    layer.head_dim - layer.v_head_dim,
                )
                c_kv_cache = c_kv.view(
                    -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
                )
                if q_rope is not None:
                    q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
                    q_rope = q_rope.view(
                        -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
                    )
                else:
                    q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
                    q_nope = q_all[:, :, : layer.v_head_dim]
                    q_rope = q_all[:, :, layer.v_head_dim :]

                result = flash_attn_with_kvcache(
                    q=q_rope,
                    k_cache=k_rope_cache,
                    v_cache=c_kv_cache,
                    qv=q_nope,
                    page_table=page_table,
                    cache_seqlens=cache_seqlens,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
                    max_seqlen_q=max_seqlen_q,
                    softmax_scale=layer.scaling,
                    causal=False if use_cascade_attn else causal,
                    softcap=layer.logit_cap,
                    k_descale=k_descale,
                    v_descale=v_descale,
                    return_softmax_lse=use_cascade_attn,
                )
                if use_cascade_attn:
                    o, softmax_lse, *rest = result
                    o_expand, softmax_lse_expand, *rest_expand = (
                        flash_attn_with_kvcache(
                            q=q_rope,
                            k_cache=k_rope_cache,
                            v_cache=c_kv_cache,
                            qv=q_nope,
                            page_table=self.forward_metadata_spec_decode_expand.page_table,
                            cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
                            cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
                            cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
                            max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
                            softmax_scale=layer.scaling,
                            causal=False,
                            window_size=window_size,
                            softcap=layer.logit_cap,
                            k_descale=k_descale,
                            v_descale=v_descale,
                            return_softmax_lse=True,
                        )
                    )
                    o, _ = merge_state_v2_wrapper(
                        o,
                        softmax_lse.T.contiguous(),
                        o_expand,
                        softmax_lse_expand.T.contiguous(),
                    )
                else:
                    o = result

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
        # For multi-head latent attention
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
        sinks: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if k is not None:
            assert v is not None
            if save_kv_cache:
                cache_loc = (
                    forward_batch.out_cache_loc
                    if not layer.is_cross_attention
                    else forward_batch.encoder_out_cache_loc
                )
                if not self.use_mla:
                    forward_batch.token_to_kv_pool.set_kv_buffer(
                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                    )
                else:
                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                        layer,
                        cache_loc,
                        k,
                        k_rope,
                    )

        # Use precomputed metadata across all layers
        metadata = self.forward_metadata
        local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
        use_local_attn = (
            self.attention_chunk_size is not None
            and local_attn_metadata is not None
            and (hasattr(layer, "use_irope") and layer.use_irope)
        )

        # When Spec Decode enabled, forward_decode would be called with two mode:
        # 1. DRAFT_DECODE: we enable cascade attention when top_k > 1
        # 2. IDLE: we don't need cascade attention, spec_info will be none in this case
        use_cascade_attn = forward_batch.spec_info is not None and self.topk > 1

        # Calculate window size (can be moved to metadata if layer properties don't change)
        # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
        # here is two side inclusive
        window_size = (
            (layer.sliding_window_size, 0)
            if layer.sliding_window_size is not None and layer.sliding_window_size > -1
            else (-1, -1)
        )
        causal = not layer.is_cross_attention

        # For fa3 interface version compatibility, we put new fields into conditional keyword args
        kwargs = {}
        if sinks is not None:
            kwargs["sinks"] = sinks

        k_descale, v_descale = None, None
        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
        # has corresponding quantization method so that layer.k_scale is not None,
        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
        if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
            if layer.k_scale is not None:
                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
                k_descale = layer.k_scale.expand(descale_shape)
                v_descale = layer.v_scale.expand(descale_shape)
            q = q.to(self.kv_cache_dtype)
            q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
            k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
        if not self.use_mla:
            # Do multi-head attention

            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                layer.layer_id
            )
            key_cache = key_cache.view(
                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
            )
            value_cache = value_cache.view(
                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
            )

            if layer.is_cross_attention:
                # Always use non-chunked logic for cross-attention
                o = flash_attn_with_kvcache(
                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                    k_cache=key_cache,
                    v_cache=value_cache,
                    page_table=metadata.encoder_page_table,
                    cache_seqlens=metadata.encoder_lens_int32,
                    cu_seqlens_q=metadata.cu_seqlens_q,
                    cu_seqlens_k_new=metadata.encoder_cu_seqlens_k,
                    max_seqlen_q=1,
                    softmax_scale=layer.scaling,
                    causal=False,
                    window_size=(-1, -1),
                    softcap=layer.logit_cap,
                    k_descale=k_descale,
                    v_descale=v_descale,
                    **kwargs,
                )
            elif use_local_attn:
                # Use chunked (local) attention batching for self-attention
                o = flash_attn_with_kvcache(
                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                    k_cache=key_cache,
                    v_cache=value_cache,
                    page_table=local_attn_metadata.local_block_table,
                    cache_seqlens=local_attn_metadata.local_seqused_k,
                    cu_seqlens_q=local_attn_metadata.local_query_start_loc,
                    cu_seqlens_k_new=None,
                    max_seqlen_q=local_attn_metadata.local_max_query_len,
                    softmax_scale=layer.scaling,
                    causal=True,
                    window_size=(-1, -1),
                    softcap=layer.logit_cap,
                    k_descale=k_descale,
                    v_descale=v_descale,
                    **kwargs,
                )
            else:
                page_table = metadata.page_table
                cache_seqlens = metadata.cache_seqlens_int32
                cu_seqlens_k = metadata.cu_seqlens_k
                max_seqlen_q = metadata.max_seq_len_q
                q_reshaped = q.contiguous().view(
                    -1, layer.tp_q_head_num, layer.head_dim
                )

                # Default: single-token self-attention
                result = flash_attn_with_kvcache(
                    q=q_reshaped,
                    k_cache=key_cache,
                    v_cache=value_cache,
                    page_table=page_table,
                    cache_seqlens=cache_seqlens,
                    cu_seqlens_q=metadata.cu_seqlens_q,
                    cu_seqlens_k_new=cu_seqlens_k,
                    max_seqlen_q=max_seqlen_q,
                    softmax_scale=layer.scaling,
                    causal=False if use_cascade_attn else causal,
                    window_size=window_size,
                    softcap=layer.logit_cap,
                    k_descale=k_descale,
                    v_descale=v_descale,
                    return_softmax_lse=use_cascade_attn,
                    **kwargs,
                )
                if use_cascade_attn:
                    o, softmax_lse, *rest = result
                    o_expand, softmax_lse_expand, *rest_expand = (
                        flash_attn_with_kvcache(
                            q=q_reshaped,
                            k_cache=key_cache,
                            v_cache=value_cache,
                            page_table=self.forward_metadata_spec_decode_expand.page_table,
                            cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
                            cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
                            cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
                            max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
                            softmax_scale=layer.scaling,
                            causal=False,
                            window_size=window_size,
                            softcap=layer.logit_cap,
                            k_descale=k_descale,
                            v_descale=v_descale,
                            return_softmax_lse=True,
                            **kwargs,
                        )
                    )
                    o, _ = merge_state_v2(
                        o,
                        softmax_lse.T.contiguous(),
                        o_expand,
                        softmax_lse_expand.T.contiguous(),
                    )
                else:
                    o = result
        else:
            # Do absorbed multi-latent attention
            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
                q.dtype
            )
            k_rope = kv_cache[:, :, layer.v_head_dim :]
            c_kv = kv_cache[:, :, : layer.v_head_dim]
            k_rope_cache = k_rope.view(
                -1,
                self.page_size,
                layer.tp_k_head_num,
                layer.head_dim - layer.v_head_dim,
            )
            c_kv_cache = c_kv.view(
                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
            )

            if q_rope is not None:
                q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
                q_rope = q_rope.view(
                    -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
                )
            else:
                q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
                q_nope = q_all[:, :, : layer.v_head_dim]
                q_rope = q_all[:, :, layer.v_head_dim :]
            max_seqlen_q = metadata.max_seq_len_q

            result = flash_attn_with_kvcache(
                q=q_rope,
                k_cache=k_rope_cache,
                v_cache=c_kv_cache,
                qv=q_nope,
                page_table=metadata.page_table,
                cache_seqlens=metadata.cache_seqlens_int32,
                cu_seqlens_q=metadata.cu_seqlens_q,
                cu_seqlens_k_new=metadata.cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                softmax_scale=layer.scaling,
                causal=False if use_cascade_attn else causal,
                softcap=layer.logit_cap,
                k_descale=k_descale,
                v_descale=v_descale,
                return_softmax_lse=use_cascade_attn,  # softmax_lse is needed for merge states
            )
            if use_cascade_attn:
                o, softmax_lse, *rest = result
                o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
                    q=q_rope,
                    k_cache=k_rope_cache,
                    v_cache=c_kv_cache,
                    qv=q_nope,
                    page_table=self.forward_metadata_spec_decode_expand.page_table,
                    cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
                    cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
                    cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
                    max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
                    softmax_scale=layer.scaling,
                    causal=False,
                    window_size=window_size,
                    softcap=layer.logit_cap,
                    k_descale=k_descale,
                    v_descale=v_descale,
                    return_softmax_lse=True,
                )
                o, _ = merge_state_v2(
                    o,
                    softmax_lse.T.contiguous(),
                    o_expand,
                    softmax_lse_expand.T.contiguous(),
                )
            else:
                o = result

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

    def get_cuda_graph_seq_len_fill_value(self):
        """Get the fill value for sequence length in CUDA graph."""
        return 1

    def _init_local_attn_metadata(
        self, forwardbatch: ForwardBatch, metadata: FlashAttentionMetadata, device
    ):
        """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
        if self.attention_chunk_size is None:
            metadata.local_attn_metadata = None
            return

        cu_seqlens_q = metadata.cu_seqlens_q
        cache_seqlens_int32 = metadata.cache_seqlens_int32
        if self.is_hybrid:
            page_table = self.full_to_swa_index_mapping[metadata.page_table].to(
                torch.int32
            )
        else:
            page_table = metadata.page_table
        if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:
            metadata.local_attn_metadata = None
            return

        cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
        seq_lens_np = cache_seqlens_int32.cpu().numpy()
        (
            seqlens_q_local_np,
            cu_seqlens_q_local_np,
            seqlens_k_local_np,
            block_table_local,
        ) = make_local_attention_virtual_batches(
            self.attention_chunk_size,
            cu_seqlens_q_np,
            seq_lens_np,
            page_table,
            self.page_size,
        )

        local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
            local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
            local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
            local_block_table=block_table_local.to(device),
            local_max_query_len=int(seqlens_q_local_np.max()),
            local_max_seq_len=int(seqlens_k_local_np.max()),
        )
        metadata.local_attn_metadata = local_metadata

    def _init_sliding_window_attn_spec_metadata(
        self,
        metadata: FlashAttentionMetadata,
        metadata_expand: FlashAttentionMetadata,
        metadata_swa: Optional[FlashAttentionMetadata] = None,
    ):
        # TODO: support page_size > 1 for swa spec
        assert (
            self.page_size == 1
        ), "FlashAttention backend doesn't support topk > 1 speculative decoding with page size > 1 sliding window attention"

        cache_seqlens_int32 = (
            metadata.cache_seqlens_int32.repeat_interleave(
                self.speculative_num_draft_tokens
            )
            + metadata_expand.cache_seqlens_int32
        )
        cu_seqlens_k = torch.nn.functional.pad(
            torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32), (1, 0)
        )
        bs = cache_seqlens_int32.shape[0]
        page_table = (
            metadata.page_table.new_zeros(
                (bs, metadata.max_seq_len_k + metadata_expand.page_table.shape[1])
            )
            if metadata_swa is None
            else metadata_swa.page_table
        )

        prepare_swa_spec_page_table_triton(
            page_table,
            metadata.page_table,
            metadata_expand.page_table,
            metadata.cache_seqlens_int32,
            metadata_expand.cache_seqlens_int32,
            self.speculative_num_draft_tokens,
        )

        if metadata_swa is None:
            metadata_swa = FlashAttentionMetadata()
            metadata_swa.max_seq_len_q = 1
            metadata_swa.cu_seqlens_q = metadata_expand.cu_seqlens_q
            metadata_swa.cache_seqlens_int32 = cache_seqlens_int32
            metadata_swa.cu_seqlens_k = cu_seqlens_k
            metadata_swa.page_table = page_table
        else:
            metadata_swa.cache_seqlens_int32.copy_(cache_seqlens_int32)
            metadata_swa.cu_seqlens_k.copy_(cu_seqlens_k)

        metadata.swa_spec_metadata = metadata_swa
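The hunk above corresponds to the new XPU FlashAttention backend added in this release (entry 142 in the file list, sglang/srt/layers/attention/xpu_backend.py). Two patterns recur throughout it and are worth seeing in isolation. First, every branch of init_forward_metadata builds its cumulative sequence-length tensors by padding a cumulative sum with a leading zero; a minimal standalone sketch with made-up lengths:

    import torch

    # Hypothetical per-request KV lengths for a batch of three requests.
    cache_seqlens = torch.tensor([5, 3, 7], dtype=torch.int32)

    # Same pattern as in init_forward_metadata: an exclusive prefix sum with a
    # leading 0, which is the cu_seqlens layout varlen FlashAttention kernels expect.
    cu_seqlens_k = torch.nn.functional.pad(
        torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32), (1, 0)
    )
    print(cu_seqlens_k)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)

Second, the cascade-attention paths combine two partial attention outputs with merge_state_v2, using the per-query log-sum-exp statistics returned by the kernels. The sketch below shows the standard merge rule conceptually; it is an illustrative reimplementation, not the sgl_kernel code:

    import torch

    def merge_attention_states(o1, lse1, o2, lse2):
        """Merge attention outputs computed over two disjoint key sets,
        weighted by their per-query log-sum-exp normalizers."""
        lse = torch.logaddexp(lse1, lse2)          # combined normalizer in log space
        w1 = torch.exp(lse1 - lse).unsqueeze(-1)   # weight of the first partial output
        w2 = torch.exp(lse2 - lse).unsqueeze(-1)
        return o1 * w1 + o2 * w2, lse

    # Toy shapes: [tokens, heads, head_dim] outputs with [tokens, heads] LSEs.
    o1, lse1 = torch.randn(4, 2, 8), torch.randn(4, 2)
    o2, lse2 = torch.randn(4, 2, 8), torch.randn(4, 2)
    o, lse = merge_attention_states(o1, lse1, o2, lse2)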