sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,708 @@
+from dataclasses import astuple, dataclass
+from functools import lru_cache
+from typing import Optional, Union
+
+import torch
+import torch.nn.functional as F
+
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule
+from sglang.srt.layers.attention.fla.fused_recurrent import (
+    fused_recurrent_gated_delta_rule_update,
+)
+from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
+    fused_sigmoid_gating_delta_rule_update,
+)
+from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
+    PAD_SLOT_ID,
+    causal_conv1d_fn,
+    causal_conv1d_update,
+)
+from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
+from sglang.srt.layers.attention.mamba.mamba2_metadata import (
+    ForwardMetadata,
+    Mamba2Metadata,
+)
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, MambaPool
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.model_runner import ModelRunner
+from sglang.srt.models.qwen3_next import fused_gdn_gating
+from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput
+from sglang.srt.utils import is_cuda, is_npu
+
+if is_cuda():
+    from sglang.srt.layers.attention.mamba.causal_conv1d import (
+        causal_conv1d_fn as causal_conv1d_fn_cuda,
+    )
+
+    causal_conv1d_fn = causal_conv1d_fn_cuda
+elif is_npu():
+    from sgl_kernel_npu.fla.chunk import chunk_gated_delta_rule_npu
+    from sgl_kernel_npu.fla.fused_sigmoid_gating_recurrent import (
+        fused_sigmoid_gating_delta_rule_update_npu,
+    )
+    from sgl_kernel_npu.mamba.causal_conv1d import (
+        causal_conv1d_fn_npu,
+        causal_conv1d_update_npu,
+    )
+
+    chunk_gated_delta_rule = chunk_gated_delta_rule_npu
+    fused_sigmoid_gating_delta_rule_update = fused_sigmoid_gating_delta_rule_update_npu
+    causal_conv1d_fn = causal_conv1d_fn_npu
+    causal_conv1d_update = causal_conv1d_update_npu
+
+
+class MambaAttnBackendBase(AttentionBackend):
+    def __init__(self, model_runner: ModelRunner):
+        super().__init__()
+        self.pad_slot_id = PAD_SLOT_ID
+        self.device = model_runner.device
+        self.req_to_token_pool: HybridReqToTokenPool = model_runner.req_to_token_pool
+        self.forward_metadata: ForwardMetadata = None
+        self.state_indices_list = []
+        self.query_start_loc_list = []
+        self.cached_cuda_graph_decode_query_start_loc: torch.Tensor = None
+        self.cached_cuda_graph_verify_query_start_loc: torch.Tensor = None
+
+    def _forward_metadata(self, forward_batch: ForwardBatch):
+        bs = forward_batch.batch_size
+
+        if forward_batch.forward_mode.is_decode_or_idle():
+            query_start_loc = torch.arange(
+                0, bs + 1, dtype=torch.int32, device=self.device
+            )
+        elif forward_batch.forward_mode.is_extend():
+            if forward_batch.forward_mode.is_target_verify():
+                query_start_loc = torch.arange(
+                    0,
+                    forward_batch.input_ids.shape[0] + 1,
+                    step=forward_batch.spec_info.draft_token_num,
+                    dtype=torch.int32,
+                    device=forward_batch.input_ids.device,
+                )
+            else:
+                query_start_loc = torch.empty(
+                    (bs + 1,), dtype=torch.int32, device=self.device
+                )
+                query_start_loc[:bs] = forward_batch.extend_start_loc
+                query_start_loc[bs] = (
+                    forward_batch.extend_start_loc[-1]
+                    + forward_batch.extend_seq_lens[-1]
+                )
+        else:
+            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode=}")
+        mamba_cache_indices = self.req_to_token_pool.get_mamba_indices(
+            forward_batch.req_pool_indices
+        )
+        return ForwardMetadata(
+            query_start_loc=query_start_loc,
+            mamba_cache_indices=mamba_cache_indices,
+        )
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        self.forward_metadata = self._forward_metadata(forward_batch)
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+    ):
+        self.forward_metadata = self._capture_metadata(
+            bs, req_pool_indices, forward_mode
+        )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        self.forward_metadata = self._replay_metadata(
+            bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu
+        )
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        assert (
+            max_num_tokens % max_bs == 0
+        ), f"max_num_tokens={max_num_tokens} must be divisible by max_bs={max_bs}"
+        verify_step = max_num_tokens / max_bs
+        for i in range(max_bs):
+            self.state_indices_list.append(
+                torch.full(
+                    (i + 1,), self.pad_slot_id, dtype=torch.int32, device=self.device
+                )
+            )
+            self.query_start_loc_list.append(
+                torch.empty((i + 2,), dtype=torch.int32, device=self.device)
+            )
+        self.cached_cuda_graph_decode_query_start_loc = torch.arange(
+            0, max_bs + 1, dtype=torch.int32, device=self.device
+        )
+        self.cached_cuda_graph_verify_query_start_loc = torch.arange(
+            0,
+            max_bs * verify_step + 1,
+            step=verify_step,
+            dtype=torch.int32,
+            device=self.device,
+        )
+
+    def _capture_metadata(
+        self, bs: int, req_pool_indices: torch.Tensor, forward_mode: ForwardMode
+    ):
+        if forward_mode.is_decode_or_idle():
+            self.query_start_loc_list[bs - 1].copy_(
+                self.cached_cuda_graph_decode_query_start_loc[: bs + 1]
+            )
+        elif forward_mode.is_target_verify():
+            self.query_start_loc_list[bs - 1].copy_(
+                self.cached_cuda_graph_verify_query_start_loc[: bs + 1]
+            )
+        else:
+            raise ValueError(f"Invalid forward mode: {forward_mode=}")
+        mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices)
+        self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
+        return ForwardMetadata(
+            query_start_loc=self.query_start_loc_list[bs - 1],
+            mamba_cache_indices=self.state_indices_list[bs - 1],
+        )
+
+    def _replay_metadata(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInput],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        num_padding = torch.count_nonzero(
+            seq_lens_cpu == self.get_cuda_graph_seq_len_fill_value()
+        )
+        # Make sure forward metadata is correctly handled for padding reqs
+        req_pool_indices[bs - num_padding :] = 0
+        mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices)
+        mamba_indices[bs - num_padding :] = -1
+        self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
+        if forward_mode.is_decode_or_idle():
+            if num_padding == 0:
+                self.query_start_loc_list[bs - 1].copy_(
+                    self.cached_cuda_graph_decode_query_start_loc[: bs + 1]
+                )
+            else:
+                self.query_start_loc_list[bs - 1][: bs - num_padding].copy_(
+                    self.cached_cuda_graph_decode_query_start_loc[: bs - num_padding]
+                )
+                self.query_start_loc_list[bs - 1][bs - num_padding :].copy_(
+                    bs - num_padding
+                )
+        elif forward_mode.is_target_verify():
+            if num_padding == 0:
+                self.query_start_loc_list[bs - 1].copy_(
+                    self.cached_cuda_graph_verify_query_start_loc[: bs + 1]
+                )
+            else:
+                self.query_start_loc_list[bs - 1][: bs - num_padding].copy_(
+                    self.cached_cuda_graph_verify_query_start_loc[: bs - num_padding]
+                )
+                self.query_start_loc_list[bs - 1][bs - num_padding :].copy_(
+                    (bs - num_padding) * spec_info.draft_token_num
+                )
+        else:
+            raise ValueError(f"Invalid forward mode: {forward_mode=}")
+
+        return ForwardMetadata(
+            query_start_loc=self.query_start_loc_list[bs - 1],
+            mamba_cache_indices=self.state_indices_list[bs - 1],
+        )
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1  # Mamba attn does not use seq lens to index kv cache
+
+
+class GDNAttnBackend(MambaAttnBackendBase):
+    """Attention backend using Mamba kernel."""
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        **kwargs,
+    ):
+        mixed_qkv = kwargs["mixed_qkv"]
+        conv_weights = kwargs["conv_weights"]
+        bias = kwargs["bias"]
+        activation = kwargs["activation"]
+        key_dim = kwargs["key_dim"]
+        value_dim = kwargs["value_dim"]
+        attn_tp_size = kwargs["attention_tp_size"]
+        head_k_dim = kwargs["head_k_dim"]
+        head_v_dim = kwargs["head_v_dim"]
+        a = kwargs["a"]
+        b = kwargs["b"]
+        A_log = kwargs["A_log"]
+        dt_bias = kwargs["dt_bias"]
+        layer_id = kwargs["layer_id"]
+
+        layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
+        conv_states = layer_cache.conv
+        ssm_states = layer_cache.temporal
+        query_start_loc = self.forward_metadata.query_start_loc
+        cache_indices = self.forward_metadata.mamba_cache_indices
+
+        mixed_qkv = causal_conv1d_update(
+            mixed_qkv,
+            conv_states,
+            conv_weights,
+            bias,
+            activation,
+            conv_state_indices=cache_indices,
+        )
+
+        query, key, value = torch.split(
+            mixed_qkv,
+            [
+                key_dim // attn_tp_size,
+                key_dim // attn_tp_size,
+                value_dim // attn_tp_size,
+            ],
+            dim=-1,
+        )
+        # Reshape from [l, h*d] to [1, l, h, d]
+        seq_len = query.shape[0]
+        num_heads = query.shape[1] // head_k_dim
+        query = query.view(1, seq_len, num_heads, head_k_dim)
+        key = key.view(1, seq_len, num_heads, head_k_dim)
+        value = value.view(1, seq_len, value.shape[1] // head_v_dim, head_v_dim)
+
+        core_attn_out = fused_sigmoid_gating_delta_rule_update(
+            A_log=A_log,
+            dt_bias=dt_bias,
+            q=query,
+            k=key,
+            v=value,
+            a=a,
+            b=b,
+            initial_state_source=ssm_states,
+            initial_state_indices=cache_indices,
+            cu_seqlens=query_start_loc,
+            use_qk_l2norm_in_kernel=True,
+            softplus_beta=1.0,
+            softplus_threshold=20.0,
+        )
+
+        return core_attn_out
+
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        **kwargs,
+    ):
+        mixed_qkv = kwargs["mixed_qkv"]
+        conv_weights = kwargs["conv_weights"]
+        bias = kwargs["bias"]
+        activation = kwargs["activation"]
+        key_dim = kwargs["key_dim"]
+        value_dim = kwargs["value_dim"]
+        attn_tp_size = kwargs["attention_tp_size"]
+        head_k_dim = kwargs["head_k_dim"]
+        head_v_dim = kwargs["head_v_dim"]
+        a = kwargs["a"]
+        b = kwargs["b"]
+        A_log = kwargs["A_log"]
+        dt_bias = kwargs["dt_bias"]
+        layer_id = kwargs["layer_id"]
+        seq_len = kwargs["seq_len"]
+
+        is_target_verify = forward_batch.forward_mode.is_target_verify()
+
+        query_start_loc = self.forward_metadata.query_start_loc
+        cache_indices = self.forward_metadata.mamba_cache_indices
+
+        mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id)
+        conv_states = mamba_cache_params.conv
+        ssm_states = mamba_cache_params.temporal
+        if is_target_verify:
+            assert isinstance(mamba_cache_params, MambaPool.SpeculativeState)
+            intermediate_state_cache = mamba_cache_params.intermediate_ssm
+            intermediate_conv_window_cache = mamba_cache_params.intermediate_conv_window
+            has_initial_states = torch.ones(
+                seq_len // forward_batch.spec_info.draft_token_num,
+                dtype=torch.bool,
+                device=forward_batch.input_ids.device,
+            )
+            conv_states_to_use = conv_states.clone()
+        else:
+            has_initial_states = forward_batch.extend_prefix_lens > 0
+            conv_states_to_use = conv_states
+
+        if is_target_verify:
+            batch_size = seq_len // forward_batch.spec_info.draft_token_num
+            draft_token_num = forward_batch.spec_info.draft_token_num
+            mixed_qkv_reshaped = (
+                mixed_qkv.view(batch_size, draft_token_num, -1)
+                .transpose(1, 2)
+                .contiguous()
+            )
+            mixed_qkv_processed = causal_conv1d_update(
+                mixed_qkv_reshaped,
+                conv_states_to_use,
+                conv_weights,
+                bias,
+                activation,
+                conv_state_indices=cache_indices[:batch_size],
+                intermediate_conv_window=intermediate_conv_window_cache,
+            )
+            mixed_qkv = (
+                mixed_qkv_processed.transpose(1, 2).contiguous().view(seq_len, -1)
+            )
+        else:
+            mixed_qkv = causal_conv1d_fn(
+                mixed_qkv.transpose(0, 1),
+                conv_weights,
+                bias,
+                activation=activation,
+                conv_states=conv_states_to_use,
+                has_initial_state=has_initial_states,
+                cache_indices=cache_indices,
+                query_start_loc=query_start_loc,
+                seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
+            ).transpose(0, 1)[:seq_len]
+
+        key_split_dim = key_dim // attn_tp_size
+        value_split_dim = value_dim // attn_tp_size
+
+        query, key, value = torch.split(
+            mixed_qkv,
+            [key_split_dim, key_split_dim, value_split_dim],
+            dim=-1,
+        )
+
+        actual_seq_len = query.shape[0]
+        num_heads = query.shape[1] // head_k_dim
+        num_value_heads = value.shape[1] // head_v_dim
+
+        query = query.view(1, actual_seq_len, num_heads, head_k_dim)
+        key = key.view(1, actual_seq_len, num_heads, head_k_dim)
+        value = value.view(1, actual_seq_len, num_value_heads, head_v_dim)
+
+        beta = b.sigmoid()
+        g = fused_gdn_gating(A_log, a, dt_bias)
+
+        g = g.unsqueeze(0)
+        beta = beta.unsqueeze(0)
+
+        if is_target_verify:
+            core_attn_out = fused_recurrent_gated_delta_rule_update(
+                q=query,
+                k=key,
+                v=value,
+                g=g,
+                beta=beta,
+                initial_state_source=ssm_states,
+                initial_state_indices=cache_indices,
+                cu_seqlens=query_start_loc,
+                use_qk_l2norm_in_kernel=True,
+                disable_state_update=True,
+                intermediate_states_buffer=intermediate_state_cache,
+                cache_steps=forward_batch.spec_info.draft_token_num,
+            )
+        else:
+            recurrent_state = ssm_states[cache_indices]
+            core_attn_out, last_recurrent_state = chunk_gated_delta_rule(
+                q=query,
+                k=key,
+                v=value,
+                g=g,
+                beta=beta,
+                initial_state=recurrent_state,
+                output_final_state=True,
+                cu_seqlens=query_start_loc,
+                head_first=False,
+                use_qk_l2norm_in_kernel=True,
+            )
+            last_recurrent_state = last_recurrent_state.to(ssm_states.dtype, copy=False)
+            ssm_states[cache_indices] = last_recurrent_state
+
+        return core_attn_out
+
+
+class Mamba2AttnBackend(MambaAttnBackendBase):
+    """Attention backend wrapper for Mamba2Mixer kernels."""
+
+    def __init__(self, model_runner: ModelRunner):
+        super().__init__(model_runner)
+        config = model_runner.mamba2_config
+        assert config is not None
+        self.mamba_chunk_size = config.mamba_chunk_size
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        metadata = self._forward_metadata(forward_batch)
+        self.forward_metadata = Mamba2Metadata.prepare_mixed(
+            metadata.query_start_loc,
+            metadata.mamba_cache_indices,
+            self.mamba_chunk_size,
+            forward_batch,
+        )
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+    ):
+        metadata = self._capture_metadata(bs, req_pool_indices, forward_mode)
+        self.forward_metadata = Mamba2Metadata.prepare_decode(
+            metadata.query_start_loc, metadata.mamba_cache_indices, seq_lens
+        )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        metadata = self._replay_metadata(
+            bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu
+        )
+        self.forward_metadata = Mamba2Metadata.prepare_decode(
+            metadata.query_start_loc, metadata.mamba_cache_indices, seq_lens
+        )
+
+    def forward(
+        self,
+        mixer: MambaMixer2,
+        hidden_states: torch.Tensor,
+        output: torch.Tensor,
+        layer_id: int,
+        mup_vector: Optional[torch.Tensor] = None,
+        use_triton_causal_conv: bool = False,
+    ):
+        assert isinstance(self.forward_metadata, Mamba2Metadata)
+        layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
+        return mixer.forward(
+            hidden_states=hidden_states,
+            output=output,
+            layer_cache=layer_cache,
+            metadata=self.forward_metadata,
+            mup_vector=mup_vector,
+            use_triton_causal_conv=use_triton_causal_conv,
+        )
+
+    def forward_decode(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode"
+        )
+
+    def forward_extend(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode"
+        )
+
+
+class HybridLinearAttnBackend(AttentionBackend):
+    """Manages a full and linear attention backend"""
+
+    def __init__(
+        self,
+        full_attn_backend: AttentionBackend,
+        linear_attn_backend: MambaAttnBackendBase,
+        full_attn_layers: list[int],
+    ):
+        self.full_attn_layers = full_attn_layers
+        self.full_attn_backend = full_attn_backend
+        self.linear_attn_backend = linear_attn_backend
+        self.attn_backend_list = [full_attn_backend, linear_attn_backend]
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        for attn_backend in self.attn_backend_list:
+            attn_backend.init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        for attn_backend in self.attn_backend_list:
+            attn_backend.init_cuda_graph_state(max_bs, max_num_tokens)
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInput],
+    ):
+        for attn_backend in self.attn_backend_list:
+            attn_backend.init_forward_metadata_capture_cuda_graph(
+                bs,
+                num_tokens,
+                req_pool_indices,
+                seq_lens,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+            )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInput],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        for attn_backend in self.attn_backend_list:
+            attn_backend.init_forward_metadata_replay_cuda_graph(
+                bs,
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+                seq_lens_cpu,
+            )
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return self.full_attn_backend.get_cuda_graph_seq_len_fill_value()
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        **kwargs,
+    ):
+        layer_id = layer.layer_id if layer else kwargs["layer_id"]
+        if layer_id in self.full_attn_layers:
+            return self.full_attn_backend.forward_decode(
+                q, k, v, layer, forward_batch, save_kv_cache, **kwargs
+            )
+        return self.linear_attn_backend.forward_decode(
+            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
+        )
+
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        **kwargs,
+    ):
+        layer_id = layer.layer_id if layer else kwargs["layer_id"]
+        if layer_id in self.full_attn_layers:
+            return self.full_attn_backend.forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, **kwargs
+            )
+        return self.linear_attn_backend.forward_extend(
+            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
+        )
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        **kwargs,
+    ):
+        """Run forward on an attention layer."""
+        if forward_batch.forward_mode.is_idle():
+            if layer is None:
+                return torch.empty_like(kwargs["z"])
+            return q.new_empty(q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+        elif forward_batch.forward_mode.is_decode():
+            return self.forward_decode(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache=save_kv_cache,
+                **kwargs,
+            )
+        else:
+            return self.forward_extend(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache=save_kv_cache,
+                **kwargs,
+            )
+
+    def update_mamba_state_after_mtp_verify(self, accepted_length, model):
+        request_number = accepted_length.shape[0]
+
+        state_indices_tensor = (
+            self.linear_attn_backend.forward_metadata.mamba_cache_indices[
+                :request_number
+            ]
+        )
+
+        mamba_caches = (
+            self.linear_attn_backend.req_to_token_pool.get_speculative_mamba2_params_all_layers()
+        )
+
+        conv_states = mamba_caches.conv
+        ssm_states = mamba_caches.temporal
+        intermediate_state_cache = mamba_caches.intermediate_ssm
+        intermediate_conv_window_cache = mamba_caches.intermediate_conv_window
+
+        # SSM state updates (chunked to reduce peak memory)
+        valid_mask = accepted_length > 0
+
+        # Compute common indices once to avoid duplication
+        last_steps_all = (accepted_length - 1).to(torch.int64)
+        valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64)  # [N]
+        last_steps = last_steps_all[valid_mask].to(torch.int64)  # [N]
+
+        # scatter into ssm_states at the chosen cache lines
+        ssm_states[:, valid_state_indices, :] = intermediate_state_cache[
+            :, valid_state_indices, last_steps
+        ].to(ssm_states.dtype, copy=False)
+
+        # Scatter into conv_states at the chosen cache lines
+        conv_states[:, valid_state_indices, :, :] = intermediate_conv_window_cache[
+            :, valid_state_indices, last_steps
+        ].to(conv_states.dtype, copy=False)
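
For orientation, the sketch below illustrates the per-layer routing idea behind the new HybridLinearAttnBackend shown above: each layer id is dispatched either to a full-attention backend or to the linear (Mamba/GDN) backend. This is a minimal stand-alone sketch; the Dummy* classes and the example layer split are invented for illustration and are not part of sglang.

# Minimal sketch of HybridLinearAttnBackend-style per-layer dispatch.
# DummyFullAttnBackend / DummyLinearAttnBackend / DummyHybridBackend are
# hypothetical stand-ins used only for this illustration.

class DummyFullAttnBackend:
    def forward_decode(self, layer_id: int) -> str:
        return f"layer {layer_id}: full attention"


class DummyLinearAttnBackend:
    def forward_decode(self, layer_id: int) -> str:
        return f"layer {layer_id}: linear (Mamba/GDN) attention"


class DummyHybridBackend:
    """Routes each layer to the full or the linear backend by layer id."""

    def __init__(self, full_attn_layers):
        self.full_attn_layers = set(full_attn_layers)
        self.full = DummyFullAttnBackend()
        self.linear = DummyLinearAttnBackend()

    def forward_decode(self, layer_id: int) -> str:
        # Same routing rule as HybridLinearAttnBackend.forward_decode above:
        # full attention for layers in full_attn_layers, linear otherwise.
        backend = self.full if layer_id in self.full_attn_layers else self.linear
        return backend.forward_decode(layer_id)


if __name__ == "__main__":
    # Example: a hybrid model where every 4th layer keeps full attention.
    hybrid = DummyHybridBackend(full_attn_layers=range(3, 48, 4))
    for layer_id in (0, 3, 7):
        print(hybrid.forward_decode(layer_id))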