sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482) hide show
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,6 @@
1
- from dataclasses import astuple, dataclass
2
- from functools import lru_cache
3
1
  from typing import Optional, Union
4
2
 
5
3
  import torch
6
- import torch.nn.functional as F
7
4
 
8
5
  from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
9
6
  from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule
@@ -14,18 +11,31 @@ from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
14
11
  fused_sigmoid_gating_delta_rule_update,
15
12
  )
16
13
  from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
14
+ PAD_SLOT_ID,
17
15
  causal_conv1d_fn,
18
16
  causal_conv1d_update,
19
17
  )
18
+ from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
19
+ from sglang.srt.layers.attention.mamba.mamba2_metadata import (
20
+ ForwardMetadata,
21
+ Mamba2Metadata,
22
+ )
20
23
  from sglang.srt.layers.radix_attention import RadixAttention
21
- from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
24
+ from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, MambaPool
22
25
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
23
26
  from sglang.srt.model_executor.model_runner import ModelRunner
24
- from sglang.srt.models.qwen3_next import Qwen3HybridLinearDecoderLayer, fused_gdn_gating
25
- from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
26
- from sglang.srt.utils import is_npu
27
+ from sglang.srt.models.qwen3_next import fused_gdn_gating
28
+ from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
29
+ from sglang.srt.speculative.spec_info import SpecInput
30
+ from sglang.srt.utils import is_cuda, is_npu
31
+
32
+ if is_cuda():
33
+ from sglang.srt.layers.attention.mamba.causal_conv1d import (
34
+ causal_conv1d_fn as causal_conv1d_fn_cuda,
35
+ )
27
36
 
28
- if is_npu():
37
+ causal_conv1d_fn = causal_conv1d_fn_cuda
38
+ elif is_npu():
29
39
  from sgl_kernel_npu.fla.chunk import chunk_gated_delta_rule_npu
30
40
  from sgl_kernel_npu.fla.fused_sigmoid_gating_recurrent import (
31
41
  fused_sigmoid_gating_delta_rule_update_npu,
@@ -41,35 +51,25 @@ if is_npu():
41
51
  causal_conv1d_update = causal_conv1d_update_npu
42
52
 
43
53
 
44
- @dataclass
45
- class ForwardMetadata:
46
- query_start_loc: Optional[torch.Tensor]
47
- mamba_cache_indices: torch.Tensor
48
-
49
-
50
- class MambaAttnBackend(AttentionBackend):
51
- """Attention backend using Mamba kernel."""
52
-
54
+ class MambaAttnBackendBase(AttentionBackend):
53
55
  def __init__(self, model_runner: ModelRunner):
54
56
  super().__init__()
55
- self.pad_slot_id = -1 # Default pad slot id
57
+ self.pad_slot_id = PAD_SLOT_ID
56
58
  self.device = model_runner.device
57
59
  self.req_to_token_pool: HybridReqToTokenPool = model_runner.req_to_token_pool
58
60
  self.forward_metadata: ForwardMetadata = None
59
61
  self.state_indices_list = []
60
62
  self.query_start_loc_list = []
63
+ self.cached_cuda_graph_decode_query_start_loc: torch.Tensor = None
64
+ self.cached_cuda_graph_verify_query_start_loc: torch.Tensor = None
61
65
 
62
- @classmethod
63
- @lru_cache(maxsize=128)
64
- def _get_cached_arange(cls, bs: int, device_str: str) -> torch.Tensor:
65
- """Cache torch.arange tensors for common batch sizes to avoid repeated allocation."""
66
- device = torch.device(device_str)
67
- return torch.arange(0, bs + 1, dtype=torch.int32, device=device)
68
-
69
- def init_forward_metadata(self, forward_batch: ForwardBatch):
66
+ def _forward_metadata(self, forward_batch: ForwardBatch):
70
67
  bs = forward_batch.batch_size
68
+
71
69
  if forward_batch.forward_mode.is_decode_or_idle():
72
- query_start_loc = self._get_cached_arange(bs, str(self.device))
70
+ query_start_loc = torch.arange(
71
+ 0, bs + 1, dtype=torch.int32, device=self.device
72
+ )
73
73
  elif forward_batch.forward_mode.is_extend():
74
74
  if forward_batch.forward_mode.is_target_verify():
75
75
  query_start_loc = torch.arange(
@@ -93,12 +93,48 @@ class MambaAttnBackend(AttentionBackend):
93
93
  mamba_cache_indices = self.req_to_token_pool.get_mamba_indices(
94
94
  forward_batch.req_pool_indices
95
95
  )
96
- self.forward_metadata = ForwardMetadata(
96
+ return ForwardMetadata(
97
97
  query_start_loc=query_start_loc,
98
98
  mamba_cache_indices=mamba_cache_indices,
99
99
  )
100
100
 
101
+ def init_forward_metadata(self, forward_batch: ForwardBatch):
102
+ self.forward_metadata = self._forward_metadata(forward_batch)
103
+
104
+ def init_forward_metadata_capture_cuda_graph(
105
+ self,
106
+ bs: int,
107
+ num_tokens: int,
108
+ req_pool_indices: torch.Tensor,
109
+ seq_lens: torch.Tensor,
110
+ encoder_lens: Optional[torch.Tensor],
111
+ forward_mode: ForwardMode,
112
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
113
+ ):
114
+ self.forward_metadata = self._capture_metadata(
115
+ bs, req_pool_indices, forward_mode
116
+ )
117
+
118
+ def init_forward_metadata_replay_cuda_graph(
119
+ self,
120
+ bs: int,
121
+ req_pool_indices: torch.Tensor,
122
+ seq_lens: torch.Tensor,
123
+ seq_lens_sum: int,
124
+ encoder_lens: Optional[torch.Tensor],
125
+ forward_mode: ForwardMode,
126
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
127
+ seq_lens_cpu: Optional[torch.Tensor],
128
+ ):
129
+ self.forward_metadata = self._replay_metadata(
130
+ bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu
131
+ )
132
+
101
133
  def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
134
+ assert (
135
+ max_num_tokens % max_bs == 0
136
+ ), f"max_num_tokens={max_num_tokens} must be divisible by max_bs={max_bs}"
137
+ verify_step = max_num_tokens / max_bs
102
138
  for i in range(max_bs):
103
139
  self.state_indices_list.append(
104
140
  torch.full(
@@ -108,47 +144,43 @@ class MambaAttnBackend(AttentionBackend):
108
144
  self.query_start_loc_list.append(
109
145
  torch.empty((i + 2,), dtype=torch.int32, device=self.device)
110
146
  )
147
+ self.cached_cuda_graph_decode_query_start_loc = torch.arange(
148
+ 0, max_bs + 1, dtype=torch.int32, device=self.device
149
+ )
150
+ self.cached_cuda_graph_verify_query_start_loc = torch.arange(
151
+ 0,
152
+ max_bs * verify_step + 1,
153
+ step=verify_step,
154
+ dtype=torch.int32,
155
+ device=self.device,
156
+ )
111
157
 
112
- def init_forward_metadata_capture_cuda_graph(
113
- self,
114
- bs: int,
115
- num_tokens: int,
116
- req_pool_indices: torch.Tensor,
117
- seq_lens: torch.Tensor,
118
- encoder_lens: Optional[torch.Tensor],
119
- forward_mode: ForwardMode,
120
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
158
+ def _capture_metadata(
159
+ self, bs: int, req_pool_indices: torch.Tensor, forward_mode: ForwardMode
121
160
  ):
122
161
  if forward_mode.is_decode_or_idle():
123
- self.query_start_loc_list[bs - 1].copy_(self._get_cached_arange(bs, "cuda"))
162
+ self.query_start_loc_list[bs - 1].copy_(
163
+ self.cached_cuda_graph_decode_query_start_loc[: bs + 1]
164
+ )
124
165
  elif forward_mode.is_target_verify():
125
166
  self.query_start_loc_list[bs - 1].copy_(
126
- torch.arange(
127
- 0,
128
- bs * spec_info.draft_token_num + 1,
129
- step=spec_info.draft_token_num,
130
- dtype=torch.int32,
131
- device=self.device,
132
- )
167
+ self.cached_cuda_graph_verify_query_start_loc[: bs + 1]
133
168
  )
134
169
  else:
135
170
  raise ValueError(f"Invalid forward mode: {forward_mode=}")
136
171
  mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices)
137
172
  self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
138
- self.forward_metadata = ForwardMetadata(
173
+ return ForwardMetadata(
139
174
  query_start_loc=self.query_start_loc_list[bs - 1],
140
175
  mamba_cache_indices=self.state_indices_list[bs - 1],
141
176
  )
142
177
 
143
- def init_forward_metadata_replay_cuda_graph(
178
+ def _replay_metadata(
144
179
  self,
145
180
  bs: int,
146
181
  req_pool_indices: torch.Tensor,
147
- seq_lens: torch.Tensor,
148
- seq_lens_sum: int,
149
- encoder_lens: Optional[torch.Tensor],
150
182
  forward_mode: ForwardMode,
151
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
183
+ spec_info: Optional[SpecInput],
152
184
  seq_lens_cpu: Optional[torch.Tensor],
153
185
  ):
154
186
  num_padding = torch.count_nonzero(
@@ -160,27 +192,33 @@ class MambaAttnBackend(AttentionBackend):
160
192
  mamba_indices[bs - num_padding :] = -1
161
193
  self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
162
194
  if forward_mode.is_decode_or_idle():
163
- self.query_start_loc_list[bs - 1].copy_(self._get_cached_arange(bs, "cuda"))
164
- if num_padding > 0:
165
- self.query_start_loc_list[bs - 1][bs - num_padding :] = bs - num_padding
166
- elif forward_mode.is_target_verify():
167
- self.query_start_loc_list[bs - 1].copy_(
168
- torch.arange(
169
- 0,
170
- bs * spec_info.draft_token_num + 1,
171
- step=spec_info.draft_token_num,
172
- dtype=torch.int32,
173
- device=self.device,
195
+ if num_padding == 0:
196
+ self.query_start_loc_list[bs - 1].copy_(
197
+ self.cached_cuda_graph_decode_query_start_loc[: bs + 1]
174
198
  )
175
- )
176
- if num_padding > 0:
177
- self.query_start_loc_list[bs - 1][bs - num_padding :] = (
199
+ else:
200
+ self.query_start_loc_list[bs - 1][: bs - num_padding].copy_(
201
+ self.cached_cuda_graph_decode_query_start_loc[: bs - num_padding]
202
+ )
203
+ self.query_start_loc_list[bs - 1][bs - num_padding :].copy_(
178
204
  bs - num_padding
179
- ) * spec_info.draft_token_num
205
+ )
206
+ elif forward_mode.is_target_verify():
207
+ if num_padding == 0:
208
+ self.query_start_loc_list[bs - 1].copy_(
209
+ self.cached_cuda_graph_verify_query_start_loc[: bs + 1]
210
+ )
211
+ else:
212
+ self.query_start_loc_list[bs - 1][: bs - num_padding].copy_(
213
+ self.cached_cuda_graph_verify_query_start_loc[: bs - num_padding]
214
+ )
215
+ self.query_start_loc_list[bs - 1][bs - num_padding :].copy_(
216
+ (bs - num_padding) * spec_info.draft_token_num
217
+ )
180
218
  else:
181
219
  raise ValueError(f"Invalid forward mode: {forward_mode=}")
182
220
 
183
- self.forward_metadata = ForwardMetadata(
221
+ return ForwardMetadata(
184
222
  query_start_loc=self.query_start_loc_list[bs - 1],
185
223
  mamba_cache_indices=self.state_indices_list[bs - 1],
186
224
  )
@@ -188,6 +226,10 @@ class MambaAttnBackend(AttentionBackend):
188
226
  def get_cuda_graph_seq_len_fill_value(self):
189
227
  return 1 # Mamba attn does not use seq lens to index kv cache
190
228
 
229
+
230
+ class GDNAttnBackend(MambaAttnBackendBase):
231
+ """Attention backend using Mamba kernel."""
232
+
191
233
  def forward_decode(
192
234
  self,
193
235
  q: torch.Tensor,
@@ -213,9 +255,9 @@ class MambaAttnBackend(AttentionBackend):
213
255
  dt_bias = kwargs["dt_bias"]
214
256
  layer_id = kwargs["layer_id"]
215
257
 
216
- conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
217
- layer_id
218
- )
258
+ layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
259
+ conv_states = layer_cache.conv
260
+ ssm_states = layer_cache.temporal
219
261
  query_start_loc = self.forward_metadata.query_start_loc
220
262
  cache_indices = self.forward_metadata.mamba_cache_indices
221
263
 
@@ -293,13 +335,13 @@ class MambaAttnBackend(AttentionBackend):
293
335
  query_start_loc = self.forward_metadata.query_start_loc
294
336
  cache_indices = self.forward_metadata.mamba_cache_indices
295
337
 
338
+ mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id)
339
+ conv_states = mamba_cache_params.conv
340
+ ssm_states = mamba_cache_params.temporal
296
341
  if is_target_verify:
297
- (
298
- conv_states,
299
- ssm_states,
300
- intermediate_state_cache,
301
- intermediate_conv_window_cache,
302
- ) = self.req_to_token_pool.get_mamba_params(layer_id)
342
+ assert isinstance(mamba_cache_params, MambaPool.SpeculativeState)
343
+ intermediate_state_cache = mamba_cache_params.intermediate_ssm
344
+ intermediate_conv_window_cache = mamba_cache_params.intermediate_conv_window
303
345
  has_initial_states = torch.ones(
304
346
  seq_len // forward_batch.spec_info.draft_token_num,
305
347
  dtype=torch.bool,
@@ -307,9 +349,6 @@ class MambaAttnBackend(AttentionBackend):
307
349
  )
308
350
  conv_states_to_use = conv_states.clone()
309
351
  else:
310
- conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
311
- layer_id
312
- )
313
352
  has_initial_states = forward_batch.extend_prefix_lens > 0
314
353
  conv_states_to_use = conv_states
315
354
 
@@ -343,6 +382,7 @@ class MambaAttnBackend(AttentionBackend):
343
382
  has_initial_state=has_initial_states,
344
383
  cache_indices=cache_indices,
345
384
  query_start_loc=query_start_loc,
385
+ seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
346
386
  ).transpose(0, 1)[:seq_len]
347
387
 
348
388
  key_split_dim = key_dim // attn_tp_size
@@ -403,16 +443,100 @@ class MambaAttnBackend(AttentionBackend):
403
443
  return core_attn_out
404
444
 
405
445
 
446
+ class Mamba2AttnBackend(MambaAttnBackendBase):
447
+ """Attention backend wrapper for Mamba2Mixer kernels."""
448
+
449
+ def __init__(self, model_runner: ModelRunner):
450
+ super().__init__(model_runner)
451
+ config = model_runner.mamba2_config
452
+ assert config is not None
453
+ self.mamba_chunk_size = config.mamba_chunk_size
454
+
455
+ def init_forward_metadata(self, forward_batch: ForwardBatch):
456
+ metadata = self._forward_metadata(forward_batch)
457
+ self.forward_metadata = Mamba2Metadata.prepare_mixed(
458
+ metadata.query_start_loc,
459
+ metadata.mamba_cache_indices,
460
+ self.mamba_chunk_size,
461
+ forward_batch,
462
+ )
463
+
464
+ def init_forward_metadata_capture_cuda_graph(
465
+ self,
466
+ bs: int,
467
+ num_tokens: int,
468
+ req_pool_indices: torch.Tensor,
469
+ seq_lens: torch.Tensor,
470
+ encoder_lens: Optional[torch.Tensor],
471
+ forward_mode: ForwardMode,
472
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
473
+ ):
474
+ metadata = self._capture_metadata(bs, req_pool_indices, forward_mode)
475
+ self.forward_metadata = Mamba2Metadata.prepare_decode(
476
+ metadata.query_start_loc, metadata.mamba_cache_indices, seq_lens
477
+ )
478
+
479
+ def init_forward_metadata_replay_cuda_graph(
480
+ self,
481
+ bs: int,
482
+ req_pool_indices: torch.Tensor,
483
+ seq_lens: torch.Tensor,
484
+ seq_lens_sum: int,
485
+ encoder_lens: Optional[torch.Tensor],
486
+ forward_mode: ForwardMode,
487
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
488
+ seq_lens_cpu: Optional[torch.Tensor],
489
+ ):
490
+ metadata = self._replay_metadata(
491
+ bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu
492
+ )
493
+ self.forward_metadata = Mamba2Metadata.prepare_decode(
494
+ metadata.query_start_loc, metadata.mamba_cache_indices, seq_lens
495
+ )
496
+
497
+ def forward(
498
+ self,
499
+ mixer: MambaMixer2,
500
+ hidden_states: torch.Tensor,
501
+ output: torch.Tensor,
502
+ layer_id: int,
503
+ mup_vector: Optional[torch.Tensor] = None,
504
+ use_triton_causal_conv: bool = False,
505
+ ):
506
+ assert isinstance(self.forward_metadata, Mamba2Metadata)
507
+ layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
508
+ return mixer.forward(
509
+ hidden_states=hidden_states,
510
+ output=output,
511
+ layer_cache=layer_cache,
512
+ metadata=self.forward_metadata,
513
+ mup_vector=mup_vector,
514
+ use_triton_causal_conv=use_triton_causal_conv,
515
+ )
516
+
517
+ def forward_decode(self, *args, **kwargs):
518
+ raise NotImplementedError(
519
+ "Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode"
520
+ )
521
+
522
+ def forward_extend(self, *args, **kwargs):
523
+ raise NotImplementedError(
524
+ "Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode"
525
+ )
526
+
527
+
406
528
  class HybridLinearAttnBackend(AttentionBackend):
407
- """Support different backends for prefill and decode."""
529
+ """Manages a full and linear attention backend"""
408
530
 
409
531
  def __init__(
410
532
  self,
411
533
  full_attn_backend: AttentionBackend,
412
- linear_attn_backend: AttentionBackend,
534
+ linear_attn_backend: MambaAttnBackendBase,
413
535
  full_attn_layers: list[int],
414
536
  ):
415
537
  self.full_attn_layers = full_attn_layers
538
+ self.full_attn_backend = full_attn_backend
539
+ self.linear_attn_backend = linear_attn_backend
416
540
  self.attn_backend_list = [full_attn_backend, linear_attn_backend]
417
541
 
418
542
  def init_forward_metadata(self, forward_batch: ForwardBatch):
@@ -431,7 +555,7 @@ class HybridLinearAttnBackend(AttentionBackend):
431
555
  seq_lens: torch.Tensor,
432
556
  encoder_lens: Optional[torch.Tensor],
433
557
  forward_mode: ForwardMode,
434
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
558
+ spec_info: Optional[SpecInput],
435
559
  ):
436
560
  for attn_backend in self.attn_backend_list:
437
561
  attn_backend.init_forward_metadata_capture_cuda_graph(
@@ -452,7 +576,7 @@ class HybridLinearAttnBackend(AttentionBackend):
452
576
  seq_lens_sum: int,
453
577
  encoder_lens: Optional[torch.Tensor],
454
578
  forward_mode: ForwardMode,
455
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
579
+ spec_info: Optional[SpecInput],
456
580
  seq_lens_cpu: Optional[torch.Tensor],
457
581
  ):
458
582
  for attn_backend in self.attn_backend_list:
@@ -468,7 +592,7 @@ class HybridLinearAttnBackend(AttentionBackend):
468
592
  )
469
593
 
470
594
  def get_cuda_graph_seq_len_fill_value(self):
471
- return self.attn_backend_list[0].get_cuda_graph_seq_len_fill_value()
595
+ return self.full_attn_backend.get_cuda_graph_seq_len_fill_value()
472
596
 
473
597
  def forward_decode(
474
598
  self,
@@ -482,10 +606,10 @@ class HybridLinearAttnBackend(AttentionBackend):
482
606
  ):
483
607
  layer_id = layer.layer_id if layer else kwargs["layer_id"]
484
608
  if layer_id in self.full_attn_layers:
485
- return self.attn_backend_list[0].forward_decode(
609
+ return self.full_attn_backend.forward_decode(
486
610
  q, k, v, layer, forward_batch, save_kv_cache, **kwargs
487
611
  )
488
- return self.attn_backend_list[1].forward_decode(
612
+ return self.linear_attn_backend.forward_decode(
489
613
  q, k, v, layer, forward_batch, save_kv_cache, **kwargs
490
614
  )
491
615
 
@@ -501,10 +625,10 @@ class HybridLinearAttnBackend(AttentionBackend):
501
625
  ):
502
626
  layer_id = layer.layer_id if layer else kwargs["layer_id"]
503
627
  if layer_id in self.full_attn_layers:
504
- return self.attn_backend_list[0].forward_extend(
628
+ return self.full_attn_backend.forward_extend(
505
629
  q, k, v, layer, forward_batch, save_kv_cache, **kwargs
506
630
  )
507
- return self.attn_backend_list[1].forward_extend(
631
+ return self.linear_attn_backend.forward_extend(
508
632
  q, k, v, layer, forward_batch, save_kv_cache, **kwargs
509
633
  )
510
634
 
@@ -547,56 +671,35 @@ class HybridLinearAttnBackend(AttentionBackend):
547
671
  def update_mamba_state_after_mtp_verify(self, accepted_length, model):
548
672
  request_number = accepted_length.shape[0]
549
673
 
550
- state_indices_tensor = self.attn_backend_list[
551
- 1
552
- ].forward_metadata.mamba_cache_indices[:request_number]
674
+ state_indices_tensor = (
675
+ self.linear_attn_backend.forward_metadata.mamba_cache_indices[
676
+ :request_number
677
+ ]
678
+ )
553
679
 
554
- mamba_caches = self.attn_backend_list[
555
- 1
556
- ].req_to_token_pool.get_mamba_params_all_layers()
680
+ mamba_caches = (
681
+ self.linear_attn_backend.req_to_token_pool.get_speculative_mamba2_params_all_layers()
682
+ )
557
683
 
558
- (
559
- conv_states,
560
- ssm_states,
561
- intermediate_state_cache,
562
- intermediate_conv_window_cache,
563
- ) = mamba_caches
684
+ conv_states = mamba_caches.conv
685
+ ssm_states = mamba_caches.temporal
686
+ intermediate_state_cache = mamba_caches.intermediate_ssm
687
+ intermediate_conv_window_cache = mamba_caches.intermediate_conv_window
564
688
 
565
689
  # SSM state updates (chunked to reduce peak memory)
566
690
  valid_mask = accepted_length > 0
567
691
 
568
692
  # Compute common indices once to avoid duplication
569
693
  last_steps_all = (accepted_length - 1).to(torch.int64)
570
- valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64)
571
- last_steps = last_steps_all[valid_mask].to(torch.int64)
572
-
573
- if valid_state_indices.numel() > 0:
574
- chunk = 256
575
- num_valid = valid_state_indices.numel()
576
-
577
- # SSM state updates
578
- for i in range(0, num_valid, chunk):
579
- idx = valid_state_indices[i : i + chunk]
580
- steps = last_steps[i : i + chunk]
581
- # per (cache line, step)
582
- for j in range(idx.numel()):
583
- ci = idx[j].item()
584
- st = steps[j].item()
585
- ssm_states[:, ci, :].copy_(
586
- intermediate_state_cache[:, ci, st].to(
587
- ssm_states.dtype, copy=False
588
- )
589
- )
590
-
591
- # Conv window updates
592
- for i in range(0, num_valid, chunk):
593
- idx = valid_state_indices[i : i + chunk]
594
- steps = last_steps[i : i + chunk]
595
- for j in range(idx.numel()):
596
- ci = idx[j].item()
597
- st = steps[j].item()
598
- conv_states[:, ci, :, :].copy_(
599
- intermediate_conv_window_cache[:, ci, st].to(
600
- conv_states.dtype, copy=False
601
- )
602
- )
694
+ valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64) # [N]
695
+ last_steps = last_steps_all[valid_mask].to(torch.int64) # [N]
696
+
697
+ # scatter into ssm_states at the chosen cache lines
698
+ ssm_states[:, valid_state_indices, :] = intermediate_state_cache[
699
+ :, valid_state_indices, last_steps
700
+ ].to(ssm_states.dtype, copy=False)
701
+
702
+ # Scatter into conv_states at the chosen cache lines
703
+ conv_states[:, valid_state_indices, :, :] = intermediate_conv_window_cache[
704
+ :, valid_state_indices, last_steps
705
+ ].to(conv_states.dtype, copy=False)
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
14
14
 
15
15
  class IntelAMXAttnBackend(AttentionBackend):
16
16
  def __init__(self, model_runner: ModelRunner):
17
- import sgl_kernel
17
+ import sgl_kernel # noqa: F401
18
18
 
19
19
  super().__init__()
20
20
  self.forward_metadata = None
@@ -10,7 +10,7 @@ import torch
10
10
  from sgl_kernel import causal_conv1d_fwd
11
11
  from sgl_kernel import causal_conv1d_update as causal_conv1d_update_kernel
12
12
 
13
- PAD_SLOT_ID = -1
13
+ from .causal_conv1d_triton import PAD_SLOT_ID
14
14
 
15
15
 
16
16
  def causal_conv1d_fn(
@@ -23,6 +23,7 @@ def causal_conv1d_fn(
23
23
  conv_states: Optional[torch.Tensor] = None,
24
24
  activation: Optional[str] = "silu",
25
25
  pad_slot_id: int = PAD_SLOT_ID,
26
+ **kwargs,
26
27
  ):
27
28
  """
28
29
  x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen