sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (395)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/lang/interpreter.py +1 -1
  7. sglang/launch_server.py +14 -0
  8. sglang/profiler.py +2 -2
  9. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  10. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  11. sglang/srt/configs/__init__.py +8 -0
  12. sglang/srt/configs/device_config.py +3 -1
  13. sglang/srt/configs/dots_ocr.py +64 -0
  14. sglang/srt/configs/dots_vlm.py +139 -0
  15. sglang/srt/configs/falcon_h1.py +360 -0
  16. sglang/srt/configs/internvl.py +6 -0
  17. sglang/srt/configs/load_config.py +9 -0
  18. sglang/srt/configs/model_config.py +181 -82
  19. sglang/srt/configs/qwen3_next.py +326 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +71 -19
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +326 -53
  44. sglang/srt/disaggregation/prefill.py +36 -17
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +192 -113
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  52. sglang/srt/entrypoints/grpc_server.py +810 -0
  53. sglang/srt/entrypoints/http_server.py +132 -57
  54. sglang/srt/entrypoints/openai/protocol.py +115 -7
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +207 -58
  57. sglang/srt/entrypoints/openai/serving_completions.py +17 -4
  58. sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +49 -4
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/environ.py +285 -0
  63. sglang/srt/eplb/eplb_manager.py +2 -2
  64. sglang/srt/eplb/expert_distribution.py +26 -13
  65. sglang/srt/eplb/expert_location.py +38 -8
  66. sglang/srt/eplb/expert_location_updater.py +1 -1
  67. sglang/srt/function_call/base_format_detector.py +3 -6
  68. sglang/srt/function_call/ebnf_composer.py +11 -9
  69. sglang/srt/function_call/function_call_parser.py +9 -2
  70. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  71. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  72. sglang/srt/function_call/json_array_parser.py +63 -0
  73. sglang/srt/function_call/kimik2_detector.py +17 -4
  74. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  75. sglang/srt/function_call/utils.py +96 -5
  76. sglang/srt/grpc/__init__.py +1 -0
  77. sglang/srt/grpc/compile_proto.py +245 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  81. sglang/srt/layers/activation.py +143 -9
  82. sglang/srt/layers/attention/aiter_backend.py +106 -82
  83. sglang/srt/layers/attention/ascend_backend.py +115 -9
  84. sglang/srt/layers/attention/attention_registry.py +206 -0
  85. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  86. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  87. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  88. sglang/srt/layers/attention/fla/chunk.py +242 -0
  89. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  90. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  91. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  92. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  93. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  94. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  95. sglang/srt/layers/attention/fla/index.py +37 -0
  96. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  97. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  98. sglang/srt/layers/attention/fla/op.py +66 -0
  99. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  100. sglang/srt/layers/attention/fla/utils.py +331 -0
  101. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  102. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  103. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  104. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  105. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  106. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  107. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  108. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  109. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  111. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  112. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  113. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  114. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  115. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  121. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  122. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  123. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  124. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  125. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  126. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  127. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  128. sglang/srt/layers/attention/nsa/utils.py +24 -0
  129. sglang/srt/layers/attention/nsa_backend.py +887 -0
  130. sglang/srt/layers/attention/tbo_backend.py +6 -6
  131. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  132. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  133. sglang/srt/layers/attention/triton_backend.py +57 -7
  134. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  135. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  136. sglang/srt/layers/attention/vision.py +58 -0
  137. sglang/srt/layers/attention/wave_backend.py +4 -4
  138. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  139. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  140. sglang/srt/layers/communicator.py +53 -7
  141. sglang/srt/layers/dp_attention.py +41 -2
  142. sglang/srt/layers/elementwise.py +3 -1
  143. sglang/srt/layers/layernorm.py +34 -15
  144. sglang/srt/layers/linear.py +55 -7
  145. sglang/srt/layers/logits_processor.py +44 -12
  146. sglang/srt/layers/moe/__init__.py +2 -1
  147. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  148. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  149. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  150. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  151. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  167. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  169. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  170. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  171. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  172. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  173. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  174. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  175. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  176. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  177. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  178. sglang/srt/layers/moe/topk.py +30 -9
  179. sglang/srt/layers/moe/utils.py +22 -7
  180. sglang/srt/layers/parameter.py +23 -6
  181. sglang/srt/layers/quantization/awq.py +19 -7
  182. sglang/srt/layers/quantization/base_config.py +11 -6
  183. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  184. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  185. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  186. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  187. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  188. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  189. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  190. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  191. sglang/srt/layers/quantization/fp8.py +78 -49
  192. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  193. sglang/srt/layers/quantization/gptq.py +25 -17
  194. sglang/srt/layers/quantization/modelopt_quant.py +225 -57
  195. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  196. sglang/srt/layers/quantization/mxfp4.py +77 -42
  197. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  199. sglang/srt/layers/quantization/quark/utils.py +97 -0
  200. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  201. sglang/srt/layers/quantization/unquant.py +135 -47
  202. sglang/srt/layers/quantization/w4afp8.py +26 -17
  203. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  204. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  205. sglang/srt/layers/rocm_linear_utils.py +44 -0
  206. sglang/srt/layers/rotary_embedding.py +78 -49
  207. sglang/srt/layers/sampler.py +213 -21
  208. sglang/srt/layers/utils.py +23 -0
  209. sglang/srt/lora/backend/base_backend.py +50 -8
  210. sglang/srt/lora/backend/chunked_backend.py +348 -0
  211. sglang/srt/lora/backend/triton_backend.py +99 -5
  212. sglang/srt/lora/layers.py +32 -0
  213. sglang/srt/lora/lora.py +8 -3
  214. sglang/srt/lora/lora_manager.py +52 -118
  215. sglang/srt/lora/mem_pool.py +25 -11
  216. sglang/srt/lora/triton_ops/__init__.py +4 -0
  217. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  219. sglang/srt/lora/utils.py +22 -11
  220. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  221. sglang/srt/managers/cache_controller.py +215 -314
  222. sglang/srt/managers/data_parallel_controller.py +115 -80
  223. sglang/srt/managers/detokenizer_manager.py +19 -15
  224. sglang/srt/managers/disagg_service.py +46 -0
  225. sglang/srt/managers/io_struct.py +340 -109
  226. sglang/srt/managers/mm_utils.py +44 -6
  227. sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
  228. sglang/srt/managers/multimodal_processor.py +1 -2
  229. sglang/srt/managers/overlap_utils.py +53 -0
  230. sglang/srt/managers/schedule_batch.py +240 -138
  231. sglang/srt/managers/schedule_policy.py +147 -19
  232. sglang/srt/managers/scheduler.py +501 -304
  233. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  234. sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
  235. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  236. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  237. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  238. sglang/srt/managers/template_manager.py +3 -3
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +321 -632
  241. sglang/srt/managers/tp_worker.py +81 -22
  242. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  243. sglang/srt/managers/utils.py +1 -45
  244. sglang/srt/mem_cache/allocator.py +15 -21
  245. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  246. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  247. sglang/srt/mem_cache/chunk_cache.py +8 -1
  248. sglang/srt/mem_cache/evict_policy.py +23 -0
  249. sglang/srt/mem_cache/hicache_storage.py +58 -34
  250. sglang/srt/mem_cache/hiradix_cache.py +227 -80
  251. sglang/srt/mem_cache/memory_pool.py +535 -58
  252. sglang/srt/mem_cache/memory_pool_host.py +239 -223
  253. sglang/srt/mem_cache/radix_cache.py +222 -73
  254. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  255. sglang/srt/mem_cache/storage/__init__.py +10 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  257. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  258. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  259. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  260. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  261. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  262. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  263. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
  264. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  265. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  266. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
  267. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  268. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  269. sglang/srt/metrics/collector.py +519 -132
  270. sglang/srt/metrics/func_timer.py +2 -7
  271. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  272. sglang/srt/metrics/utils.py +55 -0
  273. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  274. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  275. sglang/srt/model_executor/forward_batch_info.py +98 -57
  276. sglang/srt/model_executor/model_runner.py +433 -158
  277. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  278. sglang/srt/model_loader/__init__.py +9 -3
  279. sglang/srt/model_loader/loader.py +133 -5
  280. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  281. sglang/srt/model_loader/weight_utils.py +158 -3
  282. sglang/srt/models/apertus.py +686 -0
  283. sglang/srt/models/bailing_moe.py +820 -217
  284. sglang/srt/models/bailing_moe_nextn.py +168 -0
  285. sglang/srt/models/deepseek_nextn.py +6 -1
  286. sglang/srt/models/deepseek_v2.py +833 -152
  287. sglang/srt/models/dots_ocr.py +173 -0
  288. sglang/srt/models/dots_vlm.py +174 -0
  289. sglang/srt/models/dots_vlm_vit.py +337 -0
  290. sglang/srt/models/ernie4.py +1 -1
  291. sglang/srt/models/falcon_h1.py +576 -0
  292. sglang/srt/models/gemma3_causal.py +0 -2
  293. sglang/srt/models/gemma3_mm.py +1 -1
  294. sglang/srt/models/gemma3n_mm.py +2 -2
  295. sglang/srt/models/glm4_moe.py +14 -5
  296. sglang/srt/models/glm4_moe_nextn.py +2 -2
  297. sglang/srt/models/glm4v.py +5 -3
  298. sglang/srt/models/glm4v_moe.py +4 -1
  299. sglang/srt/models/gpt_oss.py +8 -31
  300. sglang/srt/models/internvl.py +28 -0
  301. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  302. sglang/srt/models/llama.py +4 -0
  303. sglang/srt/models/llama4.py +9 -0
  304. sglang/srt/models/llama_eagle3.py +13 -0
  305. sglang/srt/models/longcat_flash.py +3 -3
  306. sglang/srt/models/longcat_flash_nextn.py +1 -1
  307. sglang/srt/models/minicpmv.py +165 -3
  308. sglang/srt/models/mllama4.py +40 -4
  309. sglang/srt/models/opt.py +637 -0
  310. sglang/srt/models/qwen2_5_vl.py +29 -5
  311. sglang/srt/models/qwen2_audio.py +1 -1
  312. sglang/srt/models/qwen2_moe.py +124 -14
  313. sglang/srt/models/qwen2_vl.py +1 -1
  314. sglang/srt/models/qwen3.py +26 -5
  315. sglang/srt/models/qwen3_moe.py +71 -12
  316. sglang/srt/models/qwen3_next.py +1069 -0
  317. sglang/srt/models/qwen3_next_mtp.py +112 -0
  318. sglang/srt/models/qwen3_vl.py +787 -0
  319. sglang/srt/models/qwen3_vl_moe.py +471 -0
  320. sglang/srt/models/registry.py +15 -3
  321. sglang/srt/models/sarashina2_vision.py +269 -0
  322. sglang/srt/models/solar.py +505 -0
  323. sglang/srt/models/starcoder2.py +357 -0
  324. sglang/srt/models/step3_vl.py +1 -1
  325. sglang/srt/models/torch_native_llama.py +10 -3
  326. sglang/srt/models/utils.py +51 -0
  327. sglang/srt/multimodal/processors/base_processor.py +15 -7
  328. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  329. sglang/srt/multimodal/processors/glm4v.py +9 -9
  330. sglang/srt/multimodal/processors/internvl.py +153 -129
  331. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  332. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  333. sglang/srt/offloader.py +27 -3
  334. sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
  335. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  336. sglang/srt/sampling/sampling_batch_info.py +38 -17
  337. sglang/srt/sampling/sampling_params.py +7 -0
  338. sglang/srt/server_args.py +1030 -254
  339. sglang/srt/server_args_config_parser.py +146 -0
  340. sglang/srt/single_batch_overlap.py +151 -0
  341. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  342. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  343. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  344. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  345. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  346. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  347. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  348. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  349. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  350. sglang/srt/speculative/eagle_worker.py +253 -136
  351. sglang/srt/speculative/ngram_utils.py +428 -0
  352. sglang/srt/speculative/ngram_worker.py +245 -0
  353. sglang/srt/speculative/spec_info.py +52 -0
  354. sglang/srt/speculative/spec_utils.py +606 -0
  355. sglang/srt/speculative/standalone_worker.py +109 -0
  356. sglang/srt/torch_memory_saver_adapter.py +5 -7
  357. sglang/srt/tracing/trace.py +578 -0
  358. sglang/srt/two_batch_overlap.py +8 -5
  359. sglang/srt/utils/__init__.py +2 -0
  360. sglang/srt/{utils.py → utils/common.py} +445 -77
  361. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  362. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  363. sglang/srt/utils/rpd_utils.py +452 -0
  364. sglang/srt/utils/slow_rank_detector.py +71 -0
  365. sglang/srt/warmup.py +8 -4
  366. sglang/srt/weight_sync/utils.py +2 -2
  367. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  368. sglang/test/few_shot_gsm8k.py +1 -0
  369. sglang/test/get_logits_ut.py +57 -0
  370. sglang/test/run_eval.py +79 -11
  371. sglang/test/runners.py +5 -1
  372. sglang/test/simple_eval_common.py +5 -2
  373. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  374. sglang/test/test_block_fp8.py +2 -2
  375. sglang/test/test_cutlass_moe.py +24 -6
  376. sglang/test/test_deterministic.py +297 -0
  377. sglang/test/test_disaggregation_utils.py +77 -0
  378. sglang/test/test_fp4_moe.py +370 -1
  379. sglang/test/test_programs.py +1 -1
  380. sglang/test/test_utils.py +383 -5
  381. sglang/utils.py +22 -1
  382. sglang/version.py +1 -1
  383. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  384. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
  385. sglang/srt/disaggregation/launch_lb.py +0 -118
  386. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  387. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  388. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  389. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  390. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  391. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  392. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  393. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  394. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  395. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_worker.py
@@ -23,6 +23,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
+    ForwardBatchOutput,
     ForwardMode,
 )
 from sglang.srt.server_args import ServerArgs
@@ -33,20 +34,23 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
 from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import (
     EAGLEDraftExtendCudaGraphRunner,
 )
-from sglang.srt.speculative.eagle_utils import (
+from sglang.srt.speculative.eagle_info import (
     EagleDraftInput,
     EagleVerifyInput,
     EagleVerifyOutput,
+)
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.speculative.spec_utils import (
     assign_draft_cache_locs,
     fast_topk,
     generate_token_bitmask,
     select_top_k_tokens,
 )
-from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     empty_context,
     get_available_gpu_memory,
     get_bool_env_var,
+    is_blackwell,
     is_cuda,
     next_power_of_2,
 )
@@ -187,137 +191,204 @@ class EAGLEWorker(TpModelWorker):
         self.has_prefill_wrapper_verify = False
         self.draft_extend_attn_backend = None
 
-        if self.server_args.attention_backend == "flashinfer":
-            if not global_server_args_dict["use_mla_backend"]:
-                from sglang.srt.layers.attention.flashinfer_backend import (
-                    FlashInferAttnBackend,
-                    FlashInferMultiStepDraftBackend,
-                )
+        # Initialize decode attention backend
+        self.draft_attn_backend = self._create_decode_backend()
 
-                self.draft_attn_backend = FlashInferMultiStepDraftBackend(
-                    self.draft_model_runner,
-                    self.topk,
-                    self.speculative_num_steps,
-                )
-                self.draft_extend_attn_backend = FlashInferAttnBackend(
-                    self.draft_model_runner,
-                    skip_prefill=False,
-                )
-            else:
-                from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                    FlashInferMLAAttnBackend,
-                    FlashInferMLAMultiStepDraftBackend,
-                )
+        # Initialize draft extend attention backend (respects speculative_attention_mode setting)
+        self.draft_extend_attn_backend = self._create_draft_extend_backend()
 
-                self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
-                    self.draft_model_runner,
-                    self.topk,
-                    self.speculative_num_steps,
-                )
-                self.draft_extend_attn_backend = FlashInferMLAAttnBackend(
-                    self.draft_model_runner,
-                    skip_prefill=False,
-                )
-            self.has_prefill_wrapper_verify = True
-        elif self.server_args.attention_backend == "triton":
-            from sglang.srt.layers.attention.triton_backend import (
-                TritonAttnBackend,
-                TritonMultiStepDraftBackend,
-            )
+        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
 
-            self.draft_attn_backend = TritonMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = TritonAttnBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
-        elif self.server_args.attention_backend == "aiter":
-            from sglang.srt.layers.attention.aiter_backend import (
-                AiterAttnBackend,
-                AiterMultiStepDraftBackend,
-            )
+    def _create_backend(
+        self, backend_name: str, backend_map: dict, error_template: str
+    ):
+        backend_type = getattr(self.server_args, backend_name)
+        if backend_type is None:
+            backend_type = self.server_args.attention_backend
+
+        if backend_type not in backend_map:
+            raise ValueError(error_template.format(backend_type=backend_type))
+
+        return backend_map[backend_type]()
+
+    def _create_decode_backend(self):
+        backend_map = {
+            "flashinfer": self._create_flashinfer_decode_backend,
+            "triton": self._create_triton_decode_backend,
+            "aiter": self._create_aiter_decode_backend,
+            "fa3": self._create_fa3_decode_backend,
+            "hybrid_linear_attn": (
+                self._create_fa3_decode_backend
+                if not is_blackwell()
+                else self._create_triton_decode_backend
+            ),
+            "flashmla": self._create_flashmla_decode_backend,
+            "trtllm_mha": self._create_trtllm_mha_decode_backend,
+            "trtllm_mla": self._create_trtllm_mla_decode_backend,
+        }
+
+        return self._create_backend(
+            "decode_attention_backend",
+            backend_map,
+            "EAGLE is not supported in decode attention backend {backend_type}",
+        )
 
-            self.draft_attn_backend = AiterMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = AiterAttnBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
-            self.has_prefill_wrapper_verify = False
-        elif self.server_args.attention_backend == "fa3":
-            from sglang.srt.layers.attention.flashattention_backend import (
-                FlashAttentionBackend,
-                FlashAttentionMultiStepBackend,
-            )
+    def _create_draft_extend_backend(self):
+        backend_map = {
+            "flashinfer": self._create_flashinfer_prefill_backend,
+            "triton": self._create_triton_prefill_backend,
+            "aiter": self._create_aiter_prefill_backend,
+            "fa3": self._create_fa3_prefill_backend,
+            "hybrid_linear_attn": (
+                self._create_fa3_prefill_backend
+                if not is_blackwell()
+                else self._create_triton_prefill_backend
+            ),
+            "flashmla": self._create_flashmla_prefill_backend,
+            "trtllm_mha": self._create_trtllm_mha_prefill_backend,
+            "trtllm_mla": self._create_trtllm_mla_prefill_backend,
+        }
+        backend_name = (
+            "decode_attention_backend"
+            if self.server_args.speculative_attention_mode == "decode"
+            else "prefill_attention_backend"
+        )
+        return self._create_backend(
+            backend_name,
+            backend_map,
+            "EAGLE is not supported in attention backend {backend_type}",
+        )
 
-            self.draft_attn_backend = FlashAttentionMultiStepBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = FlashAttentionBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
-        elif self.server_args.attention_backend == "flashmla":
-            from sglang.srt.layers.attention.flashmla_backend import (
-                FlashMLAMultiStepDraftBackend,
+    def _create_flashinfer_decode_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferMultiStepDraftBackend,
             )
 
-            self.draft_attn_backend = FlashMLAMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
+            self.has_prefill_wrapper_verify = True
+            return FlashInferMultiStepDraftBackend(
+                self.draft_model_runner, self.topk, self.speculative_num_steps
             )
-        elif self.server_args.attention_backend == "trtllm_mha":
-            from sglang.srt.layers.attention.trtllm_mha_backend import (
-                TRTLLMHAAttnBackend,
-                TRTLLMHAAttnMultiStepDraftBackend,
+        else:
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAMultiStepDraftBackend,
             )
 
-            self.draft_attn_backend = TRTLLMHAAttnMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = TRTLLMHAAttnBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
             self.has_prefill_wrapper_verify = True
-        elif self.server_args.attention_backend == "trtllm_mla":
-            if not global_server_args_dict["use_mla_backend"]:
-                raise ValueError(
-                    "trtllm_mla backend requires MLA model (use_mla_backend=True)."
-                )
-
-            from sglang.srt.layers.attention.trtllm_mla_backend import (
-                TRTLLMMLABackend,
-                TRTLLMMLAMultiStepDraftBackend,
+            return FlashInferMLAMultiStepDraftBackend(
+                self.draft_model_runner, self.topk, self.speculative_num_steps
             )
 
-            self.draft_attn_backend = TRTLLMMLAMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
+    def _create_triton_decode_backend(self):
+        from sglang.srt.layers.attention.triton_backend import (
+            TritonMultiStepDraftBackend,
+        )
+
+        return TritonMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_aiter_decode_backend(self):
+        from sglang.srt.layers.attention.aiter_backend import AiterMultiStepDraftBackend
+
+        return AiterMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_fa3_decode_backend(self):
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionMultiStepBackend,
+        )
+
+        return FlashAttentionMultiStepBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_flashmla_decode_backend(self):
+        from sglang.srt.layers.attention.flashmla_backend import (
+            FlashMLAMultiStepDraftBackend,
+        )
+
+        return FlashMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_trtllm_mha_decode_backend(self):
+        from sglang.srt.layers.attention.trtllm_mha_backend import (
+            TRTLLMHAAttnMultiStepDraftBackend,
+        )
+
+        self.has_prefill_wrapper_verify = True
+        return TRTLLMHAAttnMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_trtllm_mla_decode_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
+            raise ValueError(
+                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
             )
-            self.draft_extend_attn_backend = TRTLLMMLABackend(
-                self.draft_model_runner,
-                skip_prefill=False,
+
+        from sglang.srt.layers.attention.trtllm_mla_backend import (
+            TRTLLMMLAMultiStepDraftBackend,
+        )
+
+        self.has_prefill_wrapper_verify = True
+        return TRTLLMMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_flashinfer_prefill_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferAttnBackend,
             )
-            self.has_prefill_wrapper_verify = True
+
+            return FlashInferAttnBackend(self.draft_model_runner, skip_prefill=False)
         else:
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAAttnBackend,
+            )
+
+            return FlashInferMLAAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_triton_prefill_backend(self):
+        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
+        return TritonAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_aiter_prefill_backend(self):
+        from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
+
+        return AiterAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_fa3_prefill_backend(self):
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionBackend,
+        )
+
+        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_trtllm_mha_prefill_backend(self):
+        from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend
+
+        return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_trtllm_mla_prefill_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
             raise ValueError(
-                f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
+                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
             )
 
-        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
+        from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
+
+        return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_flashmla_prefill_backend(self):
+        logger.warning(
+            "flashmla prefill backend is not yet supported for draft extend."
+        )
+        return None
 
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
@@ -358,9 +429,7 @@
     def draft_model_runner(self):
         return self.model_runner
 
-    def forward_batch_speculative_generation(
-        self, batch: ScheduleBatch
-    ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]:
+    def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput:
         """Run speculative decoding forward.
 
         NOTE: Many states of batch is modified as you go through. It is not guaranteed that
@@ -373,14 +442,19 @@
             the batch id (used for overlap schedule), and number of accepted tokens.
         """
         if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
-            logits_output, next_token_ids, bid, seq_lens_cpu = (
-                self.forward_target_extend(batch)
+            logits_output, next_token_ids, seq_lens_cpu = self.forward_target_extend(
+                batch
             )
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 self.forward_draft_extend(
                     batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
                 )
-            return logits_output, next_token_ids, bid, 0, False
+            return ForwardBatchOutput(
+                logits_output=logits_output,
+                next_token_ids=next_token_ids,
+                num_accepted_tokens=0,
+                can_run_cuda_graph=False,
+            )
         else:
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 spec_info = self.draft(batch)
@@ -398,12 +472,11 @@
                 # decode is not finished
                 self.forward_draft_extend_after_decode(batch)
 
-            return (
-                logits_output,
-                verify_output.verified_id,
-                model_worker_batch.bid,
-                sum(verify_output.accept_length_per_req_cpu),
-                can_run_cuda_graph,
+            return ForwardBatchOutput(
+                logits_output=logits_output,
+                next_token_ids=verify_output.verified_id,
+                num_accepted_tokens=sum(verify_output.accept_length_per_req_cpu),
+                can_run_cuda_graph=can_run_cuda_graph,
             )
 
     def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
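
The two returns above hand back a single ForwardBatchOutput instead of the old positional 5-tuple (which also carried the batch id for the overlap scheduler). Roughly the shape involved, as a hedged sketch; the field list is inferred from this diff, not the actual class definition in forward_batch_info:

    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class ForwardBatchOutputSketch:
        # Fields used by eagle_worker.py in this diff; defaults are assumptions.
        logits_output: Optional[Any] = None
        next_token_ids: Optional[Any] = None
        num_accepted_tokens: int = 0
        can_run_cuda_graph: bool = False

    # Callers unpack by name instead of by position:
    out = ForwardBatchOutputSketch(num_accepted_tokens=3, can_run_cuda_graph=True)
    print(out.num_accepted_tokens, out.can_run_cuda_graph)  # 3 True

Named fields make it possible to add or drop outputs (as done with bid here) without breaking every call site that unpacks the result.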
@@ -435,19 +508,21 @@
         Returns:
             logits_output: The output of logits. It will contain the full hidden states.
             next_token_ids: Next token ids generated.
-            bid: The model batch ID. Used for overlap schedule.
         """
         # Forward with the target model and get hidden states.
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-        logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
+        forward_batch_output = self.target_worker.forward_batch_generation(
             model_worker_batch
         )
+        logits_output, next_token_ids = (
+            forward_batch_output.logits_output,
+            forward_batch_output.next_token_ids,
+        )
         return (
             logits_output,
             next_token_ids,
-            model_worker_batch.bid,
             model_worker_batch.seq_lens_cpu,
         )
 
@@ -479,6 +554,8 @@
                 batch.seq_lens,
                 self.speculative_num_steps,
             )
+            prefix_lens_cpu = batch.seq_lens_cpu
+            seq_lens_cpu = batch.seq_lens_cpu + self.speculative_num_steps
             extend_num_tokens = num_seqs * self.speculative_num_steps
         else:
             # In this case, the last partial page needs to be duplicated.
@@ -514,14 +591,23 @@
                 self.topk,
                 self.page_size,
             )
-
-            # TODO(lmzheng): remove this device sync
-            extend_num_tokens = torch.sum(self.extend_lens).item()
+            prefix_lens_cpu = batch.seq_lens_cpu
+            last_page_lens = prefix_lens_cpu % self.page_size
+            num_new_pages_per_topk = (
+                last_page_lens + self.speculative_num_steps + self.page_size - 1
+            ) // self.page_size
+            seq_lens_cpu = (
+                prefix_lens_cpu // self.page_size * self.page_size
+                + num_new_pages_per_topk * (self.page_size * self.topk)
+            )
+            extend_num_tokens = torch.sum((seq_lens_cpu - prefix_lens_cpu)).item()
 
         out_cache_loc, token_to_kv_pool_state_backup = (
             batch.alloc_paged_token_slots_extend(
                 prefix_lens,
+                prefix_lens_cpu,
                 seq_lens,
+                seq_lens_cpu,
                 last_loc,
                 extend_num_tokens,
                 backup_state=True,
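
The arithmetic above replaces a GPU reduction (torch.sum(self.extend_lens).item(), a device sync) with a CPU-side computation: the last partial page is rounded up to whole pages, each of the topk draft branches gets its own copy of those pages, and the allocation is sized from the difference. A worked example with assumed values (page_size=16, speculative_num_steps=4, topk=2, one request with prefix length 35):

    # Worked example of the page-aligned allocation above; values are assumptions.
    page_size, num_steps, topk = 16, 4, 2
    prefix_len = 35
    last_page_len = prefix_len % page_size                                         # 3 tokens on the last page
    new_pages_per_topk = (last_page_len + num_steps + page_size - 1) // page_size  # ceil(7/16) = 1
    seq_len = prefix_len // page_size * page_size + new_pages_per_topk * (page_size * topk)
    print(seq_len)               # 32 + 1 * 32 = 64
    print(seq_len - prefix_len)  # 29 KV slots allocated for this request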
@@ -683,6 +769,14 @@
 
             # Set inputs
             forward_batch.input_ids = input_ids
+            # This is a temporary fix for the case that the user is using standalone
+            # speculative decoding and the draft model architecture is gpt-oss. gpt-oss
+            # rope kernel needs cache_loc to be contiguous.
+            if (
+                self.server_args.speculative_algorithm == "STANDALONE"
+                and self.model_config.hf_config.architectures[0] == "GptOssForCausalLM"
+            ):
+                out_cache_loc = out_cache_loc.contiguous()
             forward_batch.out_cache_loc = out_cache_loc[i]
             forward_batch.positions.add_(1)
             forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
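
On the contiguity fix above: indexing and transposing can produce tensor views whose elements are not laid out contiguously in memory, which some custom kernels (here, per the comment, the gpt-oss rope kernel) reject. A small torch illustration, unrelated to the actual out_cache_loc shapes:

    import torch

    t = torch.arange(12).reshape(3, 4).t()  # transpose returns a non-contiguous view
    print(t.is_contiguous())                # False
    c = t.contiguous()                      # copies into a fresh contiguous buffer
    print(c.is_contiguous())                # True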
@@ -701,6 +795,10 @@
 
         return score_list, token_list, parents_list
 
+    def clear_cache_pool(self):
+        self.model_runner.req_to_token_pool.clear()
+        self.model_runner.token_to_kv_pool_allocator.clear()
+
     def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
         spec_info.prepare_for_verify(batch, self.page_size)
         batch.return_hidden_states = False
@@ -724,10 +822,12 @@
         ).cpu()
 
         # Forward
-        logits_output, _, can_run_cuda_graph = (
-            self.target_worker.forward_batch_generation(
-                model_worker_batch, skip_sample=True
-            )
+        forward_batch_output = self.target_worker.forward_batch_generation(
+            model_worker_batch, is_verify=True
+        )
+        logits_output, can_run_cuda_graph = (
+            forward_batch_output.logits_output,
+            forward_batch_output.can_run_cuda_graph,
         )
 
         vocab_mask = None
@@ -767,6 +867,21 @@
         ]
         logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
 
+        # QQ: can be optimized
+        if self.target_worker.model_runner.is_hybrid_gdn:
+            # res.draft_input.accept_length is on GPU but may be empty for last verify?
+            accepted_length = (
+                torch.tensor(
+                    res.accept_length_per_req_cpu,
+                    device=logits_output.hidden_states.device,
+                    dtype=torch.int32,
+                )
+                + 1
+            )
+            self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify(
+                accepted_length, self.target_worker.model_runner.model
+            )
+
         if batch.return_logprob:
             self.add_logprob_values(batch, res, logits_output)
 
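
For the hybrid-GDN branch above: the per-request accepted lengths live in a Python list on the CPU, so they are materialized as a device tensor before the Mamba state update; the + 1 appears to account for one extra token committed beyond the accepted draft tokens (an inference from the surrounding code, not a documented invariant):

    import torch

    # Assumed per-request accepted draft-token counts from verification.
    accept_length_per_req_cpu = [0, 2, 1]
    accepted_length = torch.tensor(accept_length_per_req_cpu, dtype=torch.int32) + 1
    print(accepted_length)  # tensor([1, 3, 2], dtype=torch.int32)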
@@ -912,6 +1027,7 @@
         assert isinstance(batch.spec_info, EagleDraftInput)
         # Backup fields that will be modified in-place
         seq_lens_backup = batch.seq_lens.clone()
+        seq_lens_cpu_backup = batch.seq_lens_cpu.clone()
        req_pool_indices_backup = batch.req_pool_indices
        accept_length_backup = batch.spec_info.accept_length
        return_logprob_backup = batch.return_logprob
@@ -990,6 +1106,7 @@
             ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
         )
         batch.seq_lens = seq_lens_backup
+        batch.seq_lens_cpu = seq_lens_cpu_backup
         batch.req_pool_indices = req_pool_indices_backup
         batch.spec_info.accept_length = accept_length_backup
         batch.return_logprob = return_logprob_backup