sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -478,7 +478,7 @@ class FlashMLAMultiStepDraftBackend:
478
478
  )
479
479
 
480
480
  self.attn_backends = []
481
- for i in range(self.speculative_num_steps):
481
+ for i in range(self.speculative_num_steps - 1):
482
482
  self.attn_backends.append(
483
483
  FlashMLABackend(
484
484
  model_runner,
@@ -506,7 +506,7 @@ class FlashMLAMultiStepDraftBackend:
506
506
  self.common_template(forward_batch, call_fn)
507
507
 
508
508
  def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
509
- for i in range(self.speculative_num_steps):
509
+ for i in range(self.speculative_num_steps - 1):
510
510
  self.attn_backends[i].init_cuda_graph_state(
511
511
  max_bs, max_num_tokens, block_kv_indices=None
512
512
  )
@@ -1,4 +1,4 @@
1
- from typing import Optional, Union
1
+ from typing import Optional
2
2
 
3
3
  import torch
4
4
 
@@ -1,9 +1,6 @@
1
- from dataclasses import astuple, dataclass
2
- from functools import lru_cache
3
1
  from typing import Optional, Union
4
2
 
5
3
  import torch
6
- import torch.nn.functional as F
7
4
 
8
5
  from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
9
6
  from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule
@@ -14,14 +11,21 @@ from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
14
11
  fused_sigmoid_gating_delta_rule_update,
15
12
  )
16
13
  from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
14
+ PAD_SLOT_ID,
17
15
  causal_conv1d_fn,
18
16
  causal_conv1d_update,
19
17
  )
18
+ from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
19
+ from sglang.srt.layers.attention.mamba.mamba2_metadata import (
20
+ ForwardMetadata,
21
+ Mamba2Metadata,
22
+ )
20
23
  from sglang.srt.layers.radix_attention import RadixAttention
21
- from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
24
+ from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, MambaPool
22
25
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
23
26
  from sglang.srt.model_executor.model_runner import ModelRunner
24
27
  from sglang.srt.models.qwen3_next import fused_gdn_gating
28
+ from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
25
29
  from sglang.srt.speculative.spec_info import SpecInput
26
30
  from sglang.srt.utils import is_cuda, is_npu
27
31
 
@@ -47,18 +51,10 @@ elif is_npu():
47
51
  causal_conv1d_update = causal_conv1d_update_npu
48
52
 
49
53
 
50
- @dataclass
51
- class ForwardMetadata:
52
- query_start_loc: Optional[torch.Tensor]
53
- mamba_cache_indices: torch.Tensor
54
-
55
-
56
- class MambaAttnBackend(AttentionBackend):
57
- """Attention backend using Mamba kernel."""
58
-
54
+ class MambaAttnBackendBase(AttentionBackend):
59
55
  def __init__(self, model_runner: ModelRunner):
60
56
  super().__init__()
61
- self.pad_slot_id = -1 # Default pad slot id
57
+ self.pad_slot_id = PAD_SLOT_ID
62
58
  self.device = model_runner.device
63
59
  self.req_to_token_pool: HybridReqToTokenPool = model_runner.req_to_token_pool
64
60
  self.forward_metadata: ForwardMetadata = None
@@ -67,7 +63,7 @@ class MambaAttnBackend(AttentionBackend):
67
63
  self.cached_cuda_graph_decode_query_start_loc: torch.Tensor = None
68
64
  self.cached_cuda_graph_verify_query_start_loc: torch.Tensor = None
69
65
 
70
- def init_forward_metadata(self, forward_batch: ForwardBatch):
66
+ def _forward_metadata(self, forward_batch: ForwardBatch):
71
67
  bs = forward_batch.batch_size
72
68
 
73
69
  if forward_batch.forward_mode.is_decode_or_idle():
@@ -97,11 +93,43 @@ class MambaAttnBackend(AttentionBackend):
97
93
  mamba_cache_indices = self.req_to_token_pool.get_mamba_indices(
98
94
  forward_batch.req_pool_indices
99
95
  )
100
- self.forward_metadata = ForwardMetadata(
96
+ return ForwardMetadata(
101
97
  query_start_loc=query_start_loc,
102
98
  mamba_cache_indices=mamba_cache_indices,
103
99
  )
104
100
 
101
+ def init_forward_metadata(self, forward_batch: ForwardBatch):
102
+ self.forward_metadata = self._forward_metadata(forward_batch)
103
+
104
+ def init_forward_metadata_capture_cuda_graph(
105
+ self,
106
+ bs: int,
107
+ num_tokens: int,
108
+ req_pool_indices: torch.Tensor,
109
+ seq_lens: torch.Tensor,
110
+ encoder_lens: Optional[torch.Tensor],
111
+ forward_mode: ForwardMode,
112
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
113
+ ):
114
+ self.forward_metadata = self._capture_metadata(
115
+ bs, req_pool_indices, forward_mode
116
+ )
117
+
118
+ def init_forward_metadata_replay_cuda_graph(
119
+ self,
120
+ bs: int,
121
+ req_pool_indices: torch.Tensor,
122
+ seq_lens: torch.Tensor,
123
+ seq_lens_sum: int,
124
+ encoder_lens: Optional[torch.Tensor],
125
+ forward_mode: ForwardMode,
126
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
127
+ seq_lens_cpu: Optional[torch.Tensor],
128
+ ):
129
+ self.forward_metadata = self._replay_metadata(
130
+ bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu
131
+ )
132
+
105
133
  def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
106
134
  assert (
107
135
  max_num_tokens % max_bs == 0
@@ -127,15 +155,8 @@ class MambaAttnBackend(AttentionBackend):
127
155
  device=self.device,
128
156
  )
129
157
 
130
- def init_forward_metadata_capture_cuda_graph(
131
- self,
132
- bs: int,
133
- num_tokens: int,
134
- req_pool_indices: torch.Tensor,
135
- seq_lens: torch.Tensor,
136
- encoder_lens: Optional[torch.Tensor],
137
- forward_mode: ForwardMode,
138
- spec_info: Optional[SpecInput],
158
+ def _capture_metadata(
159
+ self, bs: int, req_pool_indices: torch.Tensor, forward_mode: ForwardMode
139
160
  ):
140
161
  if forward_mode.is_decode_or_idle():
141
162
  self.query_start_loc_list[bs - 1].copy_(
@@ -149,18 +170,15 @@ class MambaAttnBackend(AttentionBackend):
149
170
  raise ValueError(f"Invalid forward mode: {forward_mode=}")
150
171
  mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices)
151
172
  self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
152
- self.forward_metadata = ForwardMetadata(
173
+ return ForwardMetadata(
153
174
  query_start_loc=self.query_start_loc_list[bs - 1],
154
175
  mamba_cache_indices=self.state_indices_list[bs - 1],
155
176
  )
156
177
 
157
- def init_forward_metadata_replay_cuda_graph(
178
+ def _replay_metadata(
158
179
  self,
159
180
  bs: int,
160
181
  req_pool_indices: torch.Tensor,
161
- seq_lens: torch.Tensor,
162
- seq_lens_sum: int,
163
- encoder_lens: Optional[torch.Tensor],
164
182
  forward_mode: ForwardMode,
165
183
  spec_info: Optional[SpecInput],
166
184
  seq_lens_cpu: Optional[torch.Tensor],
@@ -200,7 +218,7 @@ class MambaAttnBackend(AttentionBackend):
200
218
  else:
201
219
  raise ValueError(f"Invalid forward mode: {forward_mode=}")
202
220
 
203
- self.forward_metadata = ForwardMetadata(
221
+ return ForwardMetadata(
204
222
  query_start_loc=self.query_start_loc_list[bs - 1],
205
223
  mamba_cache_indices=self.state_indices_list[bs - 1],
206
224
  )
@@ -208,6 +226,10 @@ class MambaAttnBackend(AttentionBackend):
208
226
  def get_cuda_graph_seq_len_fill_value(self):
209
227
  return 1 # Mamba attn does not use seq lens to index kv cache
210
228
 
229
+
230
+ class GDNAttnBackend(MambaAttnBackendBase):
231
+ """Attention backend using Mamba kernel."""
232
+
211
233
  def forward_decode(
212
234
  self,
213
235
  q: torch.Tensor,
@@ -233,9 +255,9 @@ class MambaAttnBackend(AttentionBackend):
233
255
  dt_bias = kwargs["dt_bias"]
234
256
  layer_id = kwargs["layer_id"]
235
257
 
236
- conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
237
- layer_id
238
- )
258
+ layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
259
+ conv_states = layer_cache.conv
260
+ ssm_states = layer_cache.temporal
239
261
  query_start_loc = self.forward_metadata.query_start_loc
240
262
  cache_indices = self.forward_metadata.mamba_cache_indices
241
263
 
@@ -313,13 +335,13 @@ class MambaAttnBackend(AttentionBackend):
313
335
  query_start_loc = self.forward_metadata.query_start_loc
314
336
  cache_indices = self.forward_metadata.mamba_cache_indices
315
337
 
338
+ mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id)
339
+ conv_states = mamba_cache_params.conv
340
+ ssm_states = mamba_cache_params.temporal
316
341
  if is_target_verify:
317
- (
318
- conv_states,
319
- ssm_states,
320
- intermediate_state_cache,
321
- intermediate_conv_window_cache,
322
- ) = self.req_to_token_pool.get_mamba_params(layer_id)
342
+ assert isinstance(mamba_cache_params, MambaPool.SpeculativeState)
343
+ intermediate_state_cache = mamba_cache_params.intermediate_ssm
344
+ intermediate_conv_window_cache = mamba_cache_params.intermediate_conv_window
323
345
  has_initial_states = torch.ones(
324
346
  seq_len // forward_batch.spec_info.draft_token_num,
325
347
  dtype=torch.bool,
@@ -327,9 +349,6 @@ class MambaAttnBackend(AttentionBackend):
327
349
  )
328
350
  conv_states_to_use = conv_states.clone()
329
351
  else:
330
- conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
331
- layer_id
332
- )
333
352
  has_initial_states = forward_batch.extend_prefix_lens > 0
334
353
  conv_states_to_use = conv_states
335
354
 
@@ -424,16 +443,100 @@ class MambaAttnBackend(AttentionBackend):
424
443
  return core_attn_out
425
444
 
426
445
 
446
+ class Mamba2AttnBackend(MambaAttnBackendBase):
447
+ """Attention backend wrapper for Mamba2Mixer kernels."""
448
+
449
+ def __init__(self, model_runner: ModelRunner):
450
+ super().__init__(model_runner)
451
+ config = model_runner.mamba2_config
452
+ assert config is not None
453
+ self.mamba_chunk_size = config.mamba_chunk_size
454
+
455
+ def init_forward_metadata(self, forward_batch: ForwardBatch):
456
+ metadata = self._forward_metadata(forward_batch)
457
+ self.forward_metadata = Mamba2Metadata.prepare_mixed(
458
+ metadata.query_start_loc,
459
+ metadata.mamba_cache_indices,
460
+ self.mamba_chunk_size,
461
+ forward_batch,
462
+ )
463
+
464
+ def init_forward_metadata_capture_cuda_graph(
465
+ self,
466
+ bs: int,
467
+ num_tokens: int,
468
+ req_pool_indices: torch.Tensor,
469
+ seq_lens: torch.Tensor,
470
+ encoder_lens: Optional[torch.Tensor],
471
+ forward_mode: ForwardMode,
472
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
473
+ ):
474
+ metadata = self._capture_metadata(bs, req_pool_indices, forward_mode)
475
+ self.forward_metadata = Mamba2Metadata.prepare_decode(
476
+ metadata.query_start_loc, metadata.mamba_cache_indices, seq_lens
477
+ )
478
+
479
+ def init_forward_metadata_replay_cuda_graph(
480
+ self,
481
+ bs: int,
482
+ req_pool_indices: torch.Tensor,
483
+ seq_lens: torch.Tensor,
484
+ seq_lens_sum: int,
485
+ encoder_lens: Optional[torch.Tensor],
486
+ forward_mode: ForwardMode,
487
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
488
+ seq_lens_cpu: Optional[torch.Tensor],
489
+ ):
490
+ metadata = self._replay_metadata(
491
+ bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu
492
+ )
493
+ self.forward_metadata = Mamba2Metadata.prepare_decode(
494
+ metadata.query_start_loc, metadata.mamba_cache_indices, seq_lens
495
+ )
496
+
497
+ def forward(
498
+ self,
499
+ mixer: MambaMixer2,
500
+ hidden_states: torch.Tensor,
501
+ output: torch.Tensor,
502
+ layer_id: int,
503
+ mup_vector: Optional[torch.Tensor] = None,
504
+ use_triton_causal_conv: bool = False,
505
+ ):
506
+ assert isinstance(self.forward_metadata, Mamba2Metadata)
507
+ layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
508
+ return mixer.forward(
509
+ hidden_states=hidden_states,
510
+ output=output,
511
+ layer_cache=layer_cache,
512
+ metadata=self.forward_metadata,
513
+ mup_vector=mup_vector,
514
+ use_triton_causal_conv=use_triton_causal_conv,
515
+ )
516
+
517
+ def forward_decode(self, *args, **kwargs):
518
+ raise NotImplementedError(
519
+ "Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode"
520
+ )
521
+
522
+ def forward_extend(self, *args, **kwargs):
523
+ raise NotImplementedError(
524
+ "Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode"
525
+ )
526
+
527
+
427
528
  class HybridLinearAttnBackend(AttentionBackend):
428
- """Support different backends for prefill and decode."""
529
+ """Manages a full and linear attention backend"""
429
530
 
430
531
  def __init__(
431
532
  self,
432
533
  full_attn_backend: AttentionBackend,
433
- linear_attn_backend: AttentionBackend,
534
+ linear_attn_backend: MambaAttnBackendBase,
434
535
  full_attn_layers: list[int],
435
536
  ):
436
537
  self.full_attn_layers = full_attn_layers
538
+ self.full_attn_backend = full_attn_backend
539
+ self.linear_attn_backend = linear_attn_backend
437
540
  self.attn_backend_list = [full_attn_backend, linear_attn_backend]
438
541
 
439
542
  def init_forward_metadata(self, forward_batch: ForwardBatch):
@@ -489,7 +592,7 @@ class HybridLinearAttnBackend(AttentionBackend):
489
592
  )
490
593
 
491
594
  def get_cuda_graph_seq_len_fill_value(self):
492
- return self.attn_backend_list[0].get_cuda_graph_seq_len_fill_value()
595
+ return self.full_attn_backend.get_cuda_graph_seq_len_fill_value()
493
596
 
494
597
  def forward_decode(
495
598
  self,
@@ -503,10 +606,10 @@ class HybridLinearAttnBackend(AttentionBackend):
503
606
  ):
504
607
  layer_id = layer.layer_id if layer else kwargs["layer_id"]
505
608
  if layer_id in self.full_attn_layers:
506
- return self.attn_backend_list[0].forward_decode(
609
+ return self.full_attn_backend.forward_decode(
507
610
  q, k, v, layer, forward_batch, save_kv_cache, **kwargs
508
611
  )
509
- return self.attn_backend_list[1].forward_decode(
612
+ return self.linear_attn_backend.forward_decode(
510
613
  q, k, v, layer, forward_batch, save_kv_cache, **kwargs
511
614
  )
512
615
 
@@ -522,10 +625,10 @@ class HybridLinearAttnBackend(AttentionBackend):
522
625
  ):
523
626
  layer_id = layer.layer_id if layer else kwargs["layer_id"]
524
627
  if layer_id in self.full_attn_layers:
525
- return self.attn_backend_list[0].forward_extend(
628
+ return self.full_attn_backend.forward_extend(
526
629
  q, k, v, layer, forward_batch, save_kv_cache, **kwargs
527
630
  )
528
- return self.attn_backend_list[1].forward_extend(
631
+ return self.linear_attn_backend.forward_extend(
529
632
  q, k, v, layer, forward_batch, save_kv_cache, **kwargs
530
633
  )
531
634
 
@@ -568,20 +671,20 @@ class HybridLinearAttnBackend(AttentionBackend):
568
671
  def update_mamba_state_after_mtp_verify(self, accepted_length, model):
569
672
  request_number = accepted_length.shape[0]
570
673
 
571
- state_indices_tensor = self.attn_backend_list[
572
- 1
573
- ].forward_metadata.mamba_cache_indices[:request_number]
674
+ state_indices_tensor = (
675
+ self.linear_attn_backend.forward_metadata.mamba_cache_indices[
676
+ :request_number
677
+ ]
678
+ )
574
679
 
575
- mamba_caches = self.attn_backend_list[
576
- 1
577
- ].req_to_token_pool.get_mamba_params_all_layers()
680
+ mamba_caches = (
681
+ self.linear_attn_backend.req_to_token_pool.get_speculative_mamba2_params_all_layers()
682
+ )
578
683
 
579
- (
580
- conv_states,
581
- ssm_states,
582
- intermediate_state_cache,
583
- intermediate_conv_window_cache,
584
- ) = mamba_caches
684
+ conv_states = mamba_caches.conv
685
+ ssm_states = mamba_caches.temporal
686
+ intermediate_state_cache = mamba_caches.intermediate_ssm
687
+ intermediate_conv_window_cache = mamba_caches.intermediate_conv_window
585
688
 
586
689
  # SSM state updates (chunked to reduce peak memory)
587
690
  valid_mask = accepted_length > 0
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
14
14
 
15
15
  class IntelAMXAttnBackend(AttentionBackend):
16
16
  def __init__(self, model_runner: ModelRunner):
17
- import sgl_kernel
17
+ import sgl_kernel # noqa: F401
18
18
 
19
19
  super().__init__()
20
20
  self.forward_metadata = None
@@ -10,7 +10,7 @@ import torch
10
10
  from sgl_kernel import causal_conv1d_fwd
11
11
  from sgl_kernel import causal_conv1d_update as causal_conv1d_update_kernel
12
12
 
13
- PAD_SLOT_ID = -1
13
+ from .causal_conv1d_triton import PAD_SLOT_ID
14
14
 
15
15
 
16
16
  def causal_conv1d_fn(
@@ -4,13 +4,12 @@
4
4
 
5
5
  from typing import List, Optional, Union
6
6
 
7
- import numpy as np
8
7
  import torch
9
-
10
- PAD_SLOT_ID = -1
11
8
  import triton
12
9
  import triton.language as tl
13
10
 
11
+ PAD_SLOT_ID = -1
12
+
14
13
 
15
14
  @triton.jit()
16
15
  def _causal_conv1d_fwd_kernel( # continuous batching
@@ -672,7 +671,9 @@ def _causal_conv1d_update_kernel(
672
671
  + (conv_state_batch_coord * stride_conv_state_seq)
673
672
  + conv_state_token_offset * stride_conv_state_tok
674
673
  + (idx_feats * stride_conv_state_dim)[None, :]
675
- + ((idx_tokens + 1) * stride_conv_state_tok)[:, None]
674
+ + ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[
675
+ :, None
676
+ ]
676
677
  ) # [BLOCK_M, BLOCK_N]
677
678
  mask = (
678
679
  (conv_state_batch_coord < num_cache_lines)
@@ -897,7 +898,10 @@ def causal_conv1d_update(
897
898
  stride_state_indices = (
898
899
  conv_state_indices.stride(0) if conv_state_indices is not None else 0
899
900
  )
900
- state_len = width - 1 + (seqlen - 1) # effective state_len needed
901
+ if num_accepted_tokens is not None:
902
+ state_len = width - 1 + (seqlen - 1) # effective state_len needed
903
+ else:
904
+ state_len = width - 1
901
905
  np2_statelen = triton.next_power_of_2(state_len)
902
906
 
903
907
  def grid(META):