sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/trtllm_mla_backend.py

@@ -20,8 +20,9 @@ from sglang.srt.layers.attention.utils import (
      create_flashmla_kv_indices_triton,
  )
  from sglang.srt.layers.dp_attention import get_attention_tp_size
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
- from sglang.srt.utils import is_flashinfer_available
+ from sglang.srt.utils import is_cuda, is_flashinfer_available

  if is_flashinfer_available():
      import flashinfer
@@ -29,7 +30,12 @@ if is_flashinfer_available():
  if TYPE_CHECKING:
      from sglang.srt.layers.radix_attention import RadixAttention
      from sglang.srt.model_executor.model_runner import ModelRunner
-     from sglang.srt.speculative.spec_info import SpecInfo
+     from sglang.srt.speculative.spec_info import SpecInput
+
+ _is_cuda = is_cuda()
+
+ if _is_cuda:
+     from sgl_kernel import concat_mla_absorb_q

  # Constants
  DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
@@ -45,11 +51,19 @@ TRTLLM_BLOCK_CONSTRAINT = 128
  global_zero_init_workspace_buffer = None


+ @dataclass
+ class TRTLLMMLAPrefillMetadata:
+     """Metadata for TRTLLM MLA prefill operations."""
+
+     max_seq_len: int
+     cum_seq_lens: torch.Tensor
+     seq_lens: torch.Tensor
+
+
  @dataclass
  class TRTLLMMLADecodeMetadata:
      """Metadata for TRTLLM MLA decode operations."""

-     workspace: Optional[torch.Tensor] = None
      block_kv_indices: Optional[torch.Tensor] = None
      max_seq_len: Optional[int] = None

@@ -64,7 +78,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
          kv_indptr_buf: Optional[torch.Tensor] = None,
          q_indptr_decode_buf: Optional[torch.Tensor] = None,
      ):
-         super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
+         super().__init__(
+             model_runner,
+             skip_prefill,
+             kv_indptr_buf,
+             q_indptr_decode_buf,
+         )

          config = model_runner.model_config

@@ -101,7 +120,14 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
          # CUDA graph state
          self.decode_cuda_graph_metadata = {}
          self.decode_cuda_graph_kv_indices = None
-         self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
+         self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
+         self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
+
+         self.disable_chunked_prefix_cache = global_server_args_dict[
+             "disable_chunked_prefix_cache"
+         ]
+
+         self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens

      def _calc_padded_blocks(self, max_seq_len: int) -> int:
          """
@@ -177,9 +203,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
          self.decode_cuda_graph_kv_indices = torch.full(
              (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
          )
-         self.decode_cuda_graph_workspace = torch.empty(
-             self.workspace_size, dtype=torch.int8, device=self.device
-         )

          super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)

@@ -191,12 +214,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
          seq_lens: torch.Tensor,
          encoder_lens: Optional[torch.Tensor],
          forward_mode: ForwardMode,
-         spec_info: Optional[SpecInfo],
+         spec_info: Optional[SpecInput],
      ):
          """Initialize metadata for CUDA graph capture."""

          # Delegate to parent for non-decode modes.
-         if not forward_mode.is_decode_or_idle():
+         if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
              return super().init_forward_metadata_capture_cuda_graph(
                  bs,
                  num_tokens,
@@ -207,6 +230,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                  spec_info,
              )

+         if forward_mode.is_target_verify():
+             seq_lens = seq_lens + self.num_draft_tokens
+
          # Custom fast-path for decode/idle.
          # Capture with full width so future longer sequences are safe during replay
          max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
@@ -230,12 +256,11 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
          max_seq_len_val = int(seq_lens.max().item())

          metadata = TRTLLMMLADecodeMetadata(
-             self.decode_cuda_graph_workspace,
              block_kv_indices,
              max_seq_len_val,
          )
          self.decode_cuda_graph_metadata[bs] = metadata
-         self.forward_metadata = metadata
+         self.forward_decode_metadata = metadata

      def init_forward_metadata_replay_cuda_graph(
          self,
@@ -245,12 +270,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
          seq_lens_sum: int,
          encoder_lens: Optional[torch.Tensor],
          forward_mode: ForwardMode,
-         spec_info: Optional[SpecInfo],
+         spec_info: Optional[SpecInput],
          seq_lens_cpu: Optional[torch.Tensor],
      ):
          """Replay CUDA graph with new inputs."""
          # Delegate to parent for non-decode modes.
-         if not forward_mode.is_decode_or_idle():
+         if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
              return super().init_forward_metadata_replay_cuda_graph(
                  bs,
                  req_pool_indices,
@@ -262,6 +287,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                  seq_lens_cpu,
              )

+         if forward_mode.is_target_verify():
+             seq_lens = seq_lens + self.num_draft_tokens
+             del seq_lens_sum  # not handle "num_draft_tokens" but we do not need it
+
          metadata = self.decode_cuda_graph_metadata[bs]

          # Update block indices for new sequences.
@@ -291,31 +320,64 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
      def init_forward_metadata(self, forward_batch: ForwardBatch):
          """Initialize the metadata for a forward pass."""
          # Delegate to parent for non-decode modes.
-         if not forward_batch.forward_mode.is_decode_or_idle():
-             return super().init_forward_metadata(forward_batch)
-
-         bs = forward_batch.batch_size
+         if (
+             forward_batch.forward_mode.is_extend()
+             and not forward_batch.forward_mode.is_target_verify()
+             and not forward_batch.forward_mode.is_draft_extend()
+         ):
+             if self.disable_chunked_prefix_cache:
+                 super().init_forward_metadata(forward_batch)
+
+             seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens
+             cum_seq_lens_q = torch.cat(
+                 (
+                     torch.tensor([0], device=forward_batch.seq_lens.device),
+                     torch.cumsum(seq_lens, dim=0),
+                 )
+             ).int()
+             max_seq_len = max(forward_batch.extend_seq_lens_cpu)
+             self.forward_prefill_metadata = TRTLLMMLAPrefillMetadata(
+                 max_seq_len,
+                 cum_seq_lens_q,
+                 seq_lens,
+             )
+         elif (
+             forward_batch.forward_mode.is_decode_or_idle()
+             or forward_batch.forward_mode.is_target_verify()
+         ):
+             bs = forward_batch.batch_size
+
+             # Get maximum sequence length.
+             if getattr(forward_batch, "seq_lens_cpu", None) is not None:
+                 max_seq = forward_batch.seq_lens_cpu.max().item()
+             else:
+                 max_seq = forward_batch.seq_lens.max().item()
+
+             seq_lens = forward_batch.seq_lens
+
+             if forward_batch.forward_mode.is_target_verify():
+                 max_seq = max_seq + self.num_draft_tokens
+                 seq_lens = seq_lens + self.num_draft_tokens
+
+             max_seqlen_pad = self._calc_padded_blocks(max_seq)
+             block_kv_indices = self._create_block_kv_indices(
+                 bs,
+                 max_seqlen_pad,
+                 forward_batch.req_pool_indices,
+                 seq_lens,
+                 seq_lens.device,
+             )

-         # Get maximum sequence length.
-         if getattr(forward_batch, "seq_lens_cpu", None) is not None:
-             max_seq = forward_batch.seq_lens_cpu.max().item()
+             max_seq_len_val = int(max_seq)
+             self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
+                 block_kv_indices, max_seq_len_val
+             )
+             forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
          else:
-             max_seq = forward_batch.seq_lens.max().item()
-
-         max_seqlen_pad = self._calc_padded_blocks(max_seq)
-         block_kv_indices = self._create_block_kv_indices(
-             bs,
-             max_seqlen_pad,
-             forward_batch.req_pool_indices,
-             forward_batch.seq_lens,
-             forward_batch.seq_lens.device,
-         )
+             return super().init_forward_metadata(forward_batch)

-         max_seq_len_val = int(max_seq)
-         self.forward_metadata = TRTLLMMLADecodeMetadata(
-             self.workspace_buffer, block_kv_indices, max_seq_len_val
-         )
-         forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
+     def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
+         super().init_mha_chunk_metadata(forward_batch, disable_flashinfer_ragged=True)

      def quantize_and_rope_for_fp8(
          self,
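The prefill branch added in the hunk above builds cum_seq_lens_q by prepending a zero to the running sum of the per-request extend lengths, and max_seq_len is the largest of those lengths. A minimal, standalone sketch of that construction in plain PyTorch (the tensor values are illustrative only):

import torch

# Per-request extend (new-token) lengths for a toy batch of three requests.
extend_seq_lens = torch.tensor([5, 3, 7])

# Prepend 0 and take the running sum, mirroring the cum_seq_lens_q
# construction in init_forward_metadata above.
cum_seq_lens_q = torch.cat(
    (torch.tensor([0]), torch.cumsum(extend_seq_lens, dim=0))
).int()

print(cum_seq_lens_q)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)
max_seq_len = int(extend_seq_lens.max())  # 7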
@@ -443,7 +505,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
              q_rope_reshaped = q_rope.view(
                  -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
              )
-             query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
+             query = _concat_mla_absorb_q_general(q_nope, q_rope_reshaped)
          else:
              # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
              query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
@@ -459,7 +521,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
          # Get metadata
          metadata = (
              getattr(forward_batch, "decode_trtllm_mla_metadata", None)
-             or self.forward_metadata
+             or self.forward_decode_metadata
          )

          # Scale computation for TRTLLM MLA kernel BMM1 operation:
@@ -482,7 +544,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
          raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
              query=query,
              kv_cache=kv_cache,
-             workspace_buffer=metadata.workspace,
+             workspace_buffer=self.workspace_buffer,
              qk_nope_head_dim=self.qk_nope_head_dim,
              kv_lora_rank=self.kv_lora_rank,
              qk_rope_head_dim=self.qk_rope_head_dim,
@@ -496,6 +558,174 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
          output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
          return output

+     def forward_extend(
+         self,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         layer: RadixAttention,
+         forward_batch: ForwardBatch,
+         save_kv_cache: bool = True,
+         q_rope: Optional[torch.Tensor] = None,
+         k_rope: Optional[torch.Tensor] = None,
+         cos_sin_cache: Optional[torch.Tensor] = None,
+         is_neox: Optional[bool] = False,
+     ) -> torch.Tensor:
+         if forward_batch.forward_mode.is_draft_extend():
+             return super().forward_extend(
+                 q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+             )
+
+         # TODO refactor to avoid code duplication
+         merge_query = q_rope is not None
+         if (
+             self.data_type == torch.float8_e4m3fn
+         ) and forward_batch.forward_mode.is_target_verify():
+             # For FP8 path, we quantize the query and rope parts and merge them into a single tensor
+             # Note: rope application in deepseek_v2.py:forward_absorb_prepare is skipped for FP8 decode path of this trtllm_mla backend
+             assert all(
+                 x is not None for x in [q_rope, k_rope, cos_sin_cache]
+             ), "For FP8 path and using flashinfer.rope.mla_rope_quantize we need all of q_rope, k_rope and cos_sin_cache to be not None."
+             q, k, k_rope = self.quantize_and_rope_for_fp8(
+                 q,
+                 q_rope,
+                 k.squeeze(1),
+                 k_rope.squeeze(1),
+                 forward_batch,
+                 cos_sin_cache,
+                 is_neox,
+             )
+             merge_query = False
+
+         # Save KV cache if requested
+         if save_kv_cache:
+             assert (
+                 k is not None and k_rope is not None
+             ), "For populating trtllm_mla kv cache, both k_nope and k_rope should be not None."
+             forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                 layer, forward_batch.out_cache_loc, k, k_rope
+             )
+
+         # TODO refactor to avoid code duplication
+         # Prepare query tensor inline
+         if merge_query:
+             # For FP16 path, we merge the query and rope parts into a single tensor
+             q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+             q_rope_reshaped = q_rope.view(
+                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+             )
+             q = _concat_mla_absorb_q_general(q_nope, q_rope_reshaped)
+         else:
+             # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
+             q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+
+         q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+
+         if k_rope is not None:
+             k = torch.cat([k, k_rope], dim=-1)
+         k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
+
+         v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
+
+         if forward_batch.forward_mode.is_target_verify():
+             metadata = (
+                 getattr(forward_batch, "decode_trtllm_mla_metadata", None)
+                 or self.forward_decode_metadata
+             )
+
+             # Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim]
+             bs = forward_batch.batch_size
+             q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
+
+             k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+             kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
+
+             q_scale = 1.0
+             k_scale = (
+                 layer.k_scale_float
+                 if getattr(layer, "k_scale_float", None) is not None
+                 else 1.0
+             )
+
+             bmm1_scale = q_scale * k_scale * layer.scaling
+
+             seq_lens = (
+                 forward_batch.seq_lens.to(torch.int32)
+                 + forward_batch.spec_info.draft_token_num
+             )
+             max_seq_len = metadata.max_seq_len + forward_batch.spec_info.draft_token_num
+
+             # TODO may use `mla_rope_quantize_fp8` fusion
+             q = q.to(self.data_type)
+             assert kv_cache.dtype == self.data_type
+
+             raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
+                 query=q,
+                 kv_cache=kv_cache,
+                 workspace_buffer=self.workspace_buffer,
+                 qk_nope_head_dim=self.qk_nope_head_dim,
+                 kv_lora_rank=self.kv_lora_rank,
+                 qk_rope_head_dim=self.qk_rope_head_dim,
+                 block_tables=metadata.block_kv_indices,
+                 seq_lens=seq_lens,
+                 max_seq_len=max_seq_len,
+                 bmm1_scale=bmm1_scale,
+             )
+
+             # Reshape output directly without slicing
+             output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+             return output
+
+         if forward_batch.attn_attend_prefix_cache:
+             # MHA for chunked prefix kv cache when running model with MLA
+             assert forward_batch.prefix_chunk_idx is not None
+             assert forward_batch.prefix_chunk_cu_seq_lens is not None
+             assert q_rope is None
+             assert k_rope is None
+             chunk_idx = forward_batch.prefix_chunk_idx
+
+             output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
+             return flashinfer.prefill.trtllm_ragged_attention_deepseek(
+                 query=q,
+                 key=k,
+                 value=v,
+                 workspace_buffer=self.workspace_buffer,
+                 seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
+                 max_q_len=self.forward_prefill_metadata.max_seq_len,
+                 max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
+                 bmm1_scale=layer.scaling,
+                 bmm2_scale=1.0,
+                 o_sf_scale=-1.0,
+                 batch_size=forward_batch.batch_size,
+                 window_left=-1,
+                 cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
+                 cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
+                 enable_pdl=False,
+                 is_causal=False,
+                 return_lse=True,
+                 out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
+             )
+
+         return flashinfer.prefill.trtllm_ragged_attention_deepseek(
+             query=q,
+             key=k,
+             value=v,
+             workspace_buffer=self.workspace_buffer,
+             seq_lens=self.forward_prefill_metadata.seq_lens,
+             max_q_len=self.forward_prefill_metadata.max_seq_len,
+             max_kv_len=self.forward_prefill_metadata.max_seq_len,
+             bmm1_scale=layer.scaling,
+             bmm2_scale=1.0,
+             o_sf_scale=1.0,
+             batch_size=forward_batch.batch_size,
+             window_left=-1,
+             cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
+             cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
+             enable_pdl=False,
+             is_causal=True,
+             return_lse=forward_batch.mha_return_lse,
+         )
+

  class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
      """Multi-step draft backend for TRT-LLM MLA used by EAGLE."""
@@ -512,3 +742,10 @@ class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
                  kv_indptr_buf=self.kv_indptr[i],
                  q_indptr_decode_buf=self.q_indptr_decode,
              )
+
+
+ def _concat_mla_absorb_q_general(q_nope, q_rope):
+     if _is_cuda and q_nope.shape[-1] == 512 and q_rope.shape[-1] == 64:
+         return concat_mla_absorb_q(q_nope, q_rope)
+     else:
+         return torch.cat([q_nope, q_rope], dim=-1)
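The _concat_mla_absorb_q_general helper introduced above only dispatches to the fused sgl_kernel.concat_mla_absorb_q kernel when running on CUDA and when the last dimensions match the MLA absorb layout (512 for the nope part, 64 for the rope part); otherwise it falls back to torch.cat. A standalone sketch of the same dispatch pattern, written so it also runs without sgl_kernel installed; the ImportError guard and the per-tensor is_cuda check are assumptions of this sketch, while the diff itself keys off a module-level _is_cuda flag:

import torch

try:
    # Fused CUDA kernel shipped with sgl_kernel; optional in this sketch.
    from sgl_kernel import concat_mla_absorb_q
    _HAS_SGL_KERNEL = True
except ImportError:
    _HAS_SGL_KERNEL = False


def concat_mla_absorb_q_general(q_nope: torch.Tensor, q_rope: torch.Tensor) -> torch.Tensor:
    # Fused path only for the MLA absorb layout (512-dim nope, 64-dim rope) on CUDA.
    if (
        _HAS_SGL_KERNEL
        and q_nope.is_cuda
        and q_nope.shape[-1] == 512
        and q_rope.shape[-1] == 64
    ):
        return concat_mla_absorb_q(q_nope, q_rope)
    # Generic fallback: concatenate along the last (head) dimension.
    return torch.cat([q_nope, q_rope], dim=-1)


# CPU example with toy shapes; this takes the torch.cat fallback.
q_nope = torch.randn(4, 16, 512)
q_rope = torch.randn(4, 16, 64)
print(concat_mla_absorb_q_general(q_nope, q_rope).shape)  # torch.Size([4, 16, 576])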
sglang/srt/layers/attention/vision.py

@@ -16,14 +16,19 @@ from sglang.srt.utils import (
      get_device_capability,
      is_blackwell,
      is_cuda,
+     is_npu,
      print_info_once,
  )

  _is_cuda = is_cuda()
+ _is_npu = is_npu()

  if _is_cuda:
      from sgl_kernel.flash_attn import flash_attn_varlen_func

+ if _is_npu:
+     import torch_npu
+
  from sglang.srt.distributed import (
      split_tensor_along_last_dim,
      tensor_model_parallel_all_gather,
@@ -331,10 +336,63 @@ class VisionFlash3Attention(nn.Module):
          return output


+ class VisionAscendAttention(nn.Module):
+
+     def __init__(
+         self,
+         **kwargs,
+     ):
+         if not _is_npu:
+             raise Exception("VisionAscendAttention is only available for ascend npu")
+         super().__init__()
+
+     def forward(
+         self,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
+         bsz: int,
+         seq_len: int,
+         **kwargs,
+     ) -> torch.Tensor:
+         r"""
+         Args:
+             cu_seqlens: [b]
+         Returns:
+             [b * s, h, head_size]
+         """
+         if cu_seqlens is None:
+             cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
+
+         seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
+         if seq_lens.is_npu:
+             # cu_seqlens must be on cpu because of operator restriction
+             seq_lens = seq_lens.to("cpu")
+         _, num_heads, head_size = q.shape
+         num_kv_heads = k.shape[1]
+         output = torch.empty_like(q)
+
+         # operator requires pta version >= 2.5.1
+         torch_npu._npu_flash_attention_unpad(
+             query=q,
+             key=k,
+             value=v,
+             seq_len=seq_lens.to(torch.int32),
+             scale_value=head_size**-0.5,
+             num_heads=num_heads,
+             num_kv_heads=num_kv_heads,
+             out=output,
+         )
+
+         return output
+
+
  QKV_BACKEND_IMPL = {
      "triton_attn": VisionTritonAttention,
      "sdpa": VisionSdpaAttention,
      "fa3": VisionFlash3Attention,
+     "ascend_attn": VisionAscendAttention,
  }

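VisionAscendAttention recovers the individual sequence lengths from the cumulative cu_seqlens tensor with an adjacent difference, then moves them to the CPU before calling the NPU operator. The conversion itself is plain PyTorch; a small sketch with made-up lengths:

import torch

# Cumulative sequence lengths for three packed sequences of lengths 4, 2, and 6.
cu_seqlens = torch.tensor([0, 4, 6, 12], dtype=torch.int32)

# Adjacent difference recovers the per-sequence lengths,
# as done in VisionAscendAttention.forward above.
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
print(seq_lens)  # tensor([4, 2, 6], dtype=torch.int32)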
sglang/srt/layers/attention/wave_backend.py

@@ -2,7 +2,7 @@ from __future__ import annotations

  import logging
  from dataclasses import dataclass
- from typing import TYPE_CHECKING, Optional, Union
+ from typing import TYPE_CHECKING, Optional

  import torch
  import triton
@@ -17,7 +17,7 @@ from sglang.srt.utils import get_bool_env_var, get_device_core_count
  if TYPE_CHECKING:
      from sglang.srt.layers.radix_attention import RadixAttention
      from sglang.srt.model_executor.model_runner import ModelRunner
-     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+     from sglang.srt.speculative.spec_info import SpecInput

  logger = logging.getLogger(__name__)

@@ -393,7 +393,7 @@ class WaveAttnBackend(AttentionBackend):
          seq_lens: torch.Tensor,
          encoder_lens: Optional[torch.Tensor],
          forward_mode: ForwardMode,
-         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+         spec_info: Optional[SpecInput],
      ):
          assert encoder_lens is None, "Not supported"

@@ -477,7 +477,7 @@ class WaveAttnBackend(AttentionBackend):
          seq_lens_sum: int,
          encoder_lens: Optional[torch.Tensor],
          forward_mode: ForwardMode,
-         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+         spec_info: Optional[SpecInput],
          seq_lens_cpu: Optional[torch.Tensor],
      ):
          # NOTE: encoder_lens expected to be zeros or None
sglang/srt/layers/attention/wave_ops/decode_attention.py

@@ -64,8 +64,7 @@ def get_wave_kernel(
          subs=hyperparams_0,
          canonicalize=True,
          run_bench=False,
-         use_buffer_load_ops=True,
-         use_buffer_store_ops=True,
+         use_buffer_ops=True,
          waves_per_eu=2,
          dynamic_symbols=dynamic_symbols_0,
          wave_runtime=True,
@@ -77,8 +76,7 @@ def get_wave_kernel(
          subs=hyperparams_1,
          canonicalize=True,
          run_bench=False,
-         use_buffer_load_ops=False,
-         use_buffer_store_ops=False,
+         use_buffer_ops=False,
          waves_per_eu=4,
          dynamic_symbols=dynamic_symbols_1,
          wave_runtime=True,
sglang/srt/layers/attention/wave_ops/extend_attention.py

@@ -67,11 +67,9 @@ def get_wave_kernel(
          schedule=SchedulingType.NONE,
          use_scheduling_barriers=False,
          dynamic_symbols=dynamic_symbols,
-         use_buffer_load_ops=True,
-         use_buffer_store_ops=True,
+         use_buffer_ops=True,
          waves_per_eu=2,
          denorm_fp_math_f32="preserve-sign",
-         gpu_native_math_precision=True,
          wave_runtime=True,
      )
      options = set_default_run_config(options)
sglang/srt/layers/communicator.py

@@ -50,6 +50,7 @@ from sglang.srt.utils import (
      is_hip,
      is_sm90_supported,
      is_sm100_supported,
+     prepare_weight_cache,
  )

  _is_flashinfer_available = is_flashinfer_available()
@@ -275,7 +276,11 @@ class LayerCommunicator:
          hidden_states: torch.Tensor,
          residual: torch.Tensor,
          forward_batch: ForwardBatch,
+         cache=None,
      ):
+         if cache is not None:
+             self._context.cache = cache
+
          return self._communicate_with_all_reduce_and_layer_norm_fn(
              hidden_states=hidden_states,
              residual=residual,
@@ -349,6 +354,7 @@ class CommunicateContext:
      attn_tp_size: int
      attn_dp_size: int
      tp_size: int
+     cache = None

      def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
          return self.process_group_sizes[a] == self.process_group_sizes[b]
@@ -533,6 +539,8 @@ class CommunicateWithAllReduceAndLayerNormFn:
              )
          else:
              hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+             if context.cache is not None:
+                 _ = prepare_weight_cache(hidden_states, context.cache)
              hidden_states, residual = layernorm(hidden_states, residual)
          return hidden_states, residual
546