sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396) hide show
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -48,18 +48,22 @@ from sglang.srt.model_executor.forward_batch_info import (
48
48
  PPProxyTensors,
49
49
  enable_num_token_non_padded,
50
50
  )
51
- from sglang.srt.patch_torch import monkey_patch_torch_compile
52
51
  from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
53
52
  from sglang.srt.utils import (
54
53
  empty_context,
55
54
  get_available_gpu_memory,
55
+ get_bool_env_var,
56
56
  get_device_memory_capacity,
57
+ is_hip,
57
58
  log_info_on_rank0,
58
59
  require_attn_tp_gather,
59
60
  require_gathered_buffer,
60
61
  require_mlp_sync,
61
62
  require_mlp_tp_gather,
62
63
  )
64
+ from sglang.srt.utils.patch_torch import monkey_patch_torch_compile
65
+
66
+ _is_hip = is_hip()
63
67
 
64
68
  logger = logging.getLogger(__name__)
65
69
 
@@ -100,6 +104,7 @@ def freeze_gc(enable_cudagraph_gc: bool):
100
104
  finally:
101
105
  if should_freeze:
102
106
  gc.unfreeze()
107
+ gc.collect()
103
108
 
104
109
 
105
110
  def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
@@ -136,7 +141,7 @@ def patch_model(
136
141
  mode=os.environ.get(
137
142
  "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
138
143
  ),
139
- dynamic=False,
144
+ dynamic=_is_hip and get_bool_env_var("SGLANG_TORCH_DYNAMIC_SHAPE"),
140
145
  )
141
146
  else:
142
147
  yield model.forward
@@ -166,29 +171,6 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
166
171
  server_args = model_runner.server_args
167
172
  capture_bs = server_args.cuda_graph_bs
168
173
 
169
- if capture_bs is None:
170
- if server_args.speculative_algorithm is None:
171
- if server_args.disable_cuda_graph_padding:
172
- capture_bs = list(range(1, 33)) + list(range(48, 161, 16))
173
- else:
174
- capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
175
- else:
176
- # Since speculative decoding requires more cuda graph memory, we
177
- # capture less.
178
- capture_bs = (
179
- list(range(1, 9))
180
- + list(range(10, 33, 2))
181
- + list(range(40, 64, 8))
182
- + list(range(80, 161, 16))
183
- )
184
-
185
- gpu_mem = get_device_memory_capacity()
186
- if gpu_mem is not None:
187
- if gpu_mem > 90 * 1024: # H200, H20
188
- capture_bs += list(range(160, 257, 8))
189
- if gpu_mem > 160 * 1000: # B200, MI300
190
- capture_bs += list(range(256, 513, 16))
191
-
192
174
  if max(capture_bs) > model_runner.req_to_token_pool.size:
193
175
  # In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
194
176
  # is very small. We add more values here to make sure we capture the maximum bs.
@@ -204,12 +186,6 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
204
186
 
205
187
  capture_bs = [bs for bs in capture_bs if bs % mul_base == 0]
206
188
 
207
- if server_args.cuda_graph_max_bs:
208
- capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
209
- if max(capture_bs) < server_args.cuda_graph_max_bs:
210
- capture_bs += list(
211
- range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16)
212
- )
213
189
  capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
214
190
  capture_bs = list(sorted(set(capture_bs)))
215
191
  assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"
@@ -271,7 +247,11 @@ class CudaGraphRunner:
271
247
  self.capture_forward_mode = ForwardMode.DECODE
272
248
  self.capture_hidden_mode = CaptureHiddenMode.NULL
273
249
  self.num_tokens_per_bs = 1
274
- if model_runner.spec_algorithm.is_eagle():
250
+ if (
251
+ model_runner.spec_algorithm.is_eagle()
252
+ or model_runner.spec_algorithm.is_standalone()
253
+ or model_runner.spec_algorithm.is_ngram()
254
+ ):
275
255
  if self.model_runner.is_draft_worker:
276
256
  raise RuntimeError("This should not happen")
277
257
  else:
@@ -317,7 +297,9 @@ class CudaGraphRunner:
317
297
  (self.max_num_token,), dtype=self._cache_loc_dtype()
318
298
  )
319
299
  self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
320
- self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
300
+ self.mrope_positions = torch.zeros(
301
+ (3, self.max_num_token), dtype=torch.int64
302
+ )
321
303
  self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
322
304
  self.tbo_plugin = TboCudaGraphRunnerPlugin()
323
305
 
@@ -435,11 +417,21 @@ class CudaGraphRunner:
435
417
  forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
436
418
  )
437
419
 
420
+ is_ngram_supported = (
421
+ (
422
+ forward_batch.batch_size * self.num_tokens_per_bs
423
+ == forward_batch.input_ids.numel()
424
+ )
425
+ if self.model_runner.spec_algorithm.is_ngram()
426
+ else True
427
+ )
428
+
438
429
  return (
439
430
  is_bs_supported
440
431
  and is_encoder_lens_supported
441
432
  and is_tbo_supported
442
433
  and capture_hidden_mode_matches
434
+ and is_ngram_supported
443
435
  )
444
436
 
445
437
  def capture(self) -> None:
@@ -449,6 +441,7 @@ class CudaGraphRunner:
449
441
  activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
450
442
  record_shapes=True,
451
443
  )
444
+ torch.cuda.memory._record_memory_history()
452
445
 
453
446
  # Trigger CUDA graph capture for specific shapes.
454
447
  # Capture the large shapes first so that the smaller shapes
@@ -497,6 +490,8 @@ class CudaGraphRunner:
497
490
  save_gemlite_cache()
498
491
 
499
492
  if self.enable_profile_cuda_graph:
493
+ torch.cuda.memory._dump_snapshot(f"cuda_graph_runner_memory_usage.pickle")
494
+ torch.cuda.memory._record_memory_history(enabled=None)
500
495
  log_message = (
501
496
  "Sorted by CUDA Time:\n"
502
497
  + prof.key_averages(group_by_input_shape=True).table(
@@ -506,6 +501,7 @@ class CudaGraphRunner:
506
501
  + prof.key_averages(group_by_input_shape=True).table(
507
502
  sort_by="cpu_time_total", row_limit=10
508
503
  )
504
+ + "\n\nMemory Usage is saved to cuda_graph_runner_memory_usage.pickle\n"
509
505
  )
510
506
  logger.info(log_message)
511
507
 
@@ -526,13 +522,14 @@ class CudaGraphRunner:
526
522
  input_ids = self.input_ids[:num_tokens]
527
523
  req_pool_indices = self.req_pool_indices[:bs]
528
524
  seq_lens = self.seq_lens[:bs]
525
+ seq_lens_cpu = self.seq_lens_cpu[:bs]
529
526
  out_cache_loc = self.out_cache_loc[:num_tokens]
530
527
  positions = self.positions[:num_tokens]
531
528
  if self.is_encoder_decoder:
532
529
  encoder_lens = self.encoder_lens[:bs]
533
530
  else:
534
531
  encoder_lens = None
535
- mrope_positions = self.mrope_positions[:, :bs]
532
+ mrope_positions = self.mrope_positions[:, :num_tokens]
536
533
  next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
537
534
  self.num_token_non_padded[...] = num_tokens
538
535
 
@@ -596,6 +593,7 @@ class CudaGraphRunner:
596
593
  input_ids=input_ids,
597
594
  req_pool_indices=req_pool_indices,
598
595
  seq_lens=seq_lens,
596
+ seq_lens_cpu=seq_lens_cpu,
599
597
  next_token_logits_buffer=next_token_logits_buffer,
600
598
  orig_seq_lens=seq_lens,
601
599
  req_to_token_pool=self.model_runner.req_to_token_pool,
@@ -751,7 +749,7 @@ class CudaGraphRunner:
751
749
  if self.is_encoder_decoder:
752
750
  self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
753
751
  if forward_batch.mrope_positions is not None:
754
- self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
752
+ self.mrope_positions[:, :raw_num_token].copy_(forward_batch.mrope_positions)
755
753
  if self.require_gathered_buffer:
756
754
  self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
757
755
  self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
@@ -825,8 +823,11 @@ class CudaGraphRunner:
825
823
 
826
824
  def get_spec_info(self, num_tokens: int):
827
825
  spec_info = None
828
- if self.model_runner.spec_algorithm.is_eagle():
829
- from sglang.srt.speculative.eagle_utils import EagleVerifyInput
826
+ if (
827
+ self.model_runner.spec_algorithm.is_eagle()
828
+ or self.model_runner.spec_algorithm.is_standalone()
829
+ ):
830
+ from sglang.srt.speculative.eagle_info import EagleVerifyInput
830
831
 
831
832
  if self.model_runner.is_draft_worker:
832
833
  raise RuntimeError("This should not happen.")
@@ -847,6 +848,20 @@ class CudaGraphRunner:
847
848
  seq_lens_cpu=None,
848
849
  )
849
850
 
851
+ elif self.model_runner.spec_algorithm.is_ngram():
852
+ from sglang.srt.speculative.ngram_info import NgramVerifyInput
853
+
854
+ spec_info = NgramVerifyInput(
855
+ draft_token=None,
856
+ tree_mask=self.custom_mask,
857
+ positions=None,
858
+ retrive_index=None,
859
+ retrive_next_token=None,
860
+ retrive_next_sibling=None,
861
+ draft_token_num=self.num_tokens_per_bs,
862
+ )
863
+ spec_info.capture_hidden_mode = CaptureHiddenMode.NULL
864
+
850
865
  return spec_info
851
866
 
852
867
 
@@ -45,13 +45,7 @@ from sglang.srt.layers.dp_attention import (
45
45
  get_attention_tp_size,
46
46
  set_dp_buffer_len,
47
47
  )
48
- from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
49
- from sglang.srt.utils import (
50
- flatten_nested_list,
51
- get_compiler_backend,
52
- is_npu,
53
- support_triton,
54
- )
48
+ from sglang.srt.utils import get_compiler_backend, is_npu, support_triton
55
49
 
56
50
  if TYPE_CHECKING:
57
51
  from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -60,8 +54,7 @@ if TYPE_CHECKING:
60
54
  from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
61
55
  from sglang.srt.model_executor.model_runner import ModelRunner
62
56
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
63
- from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
64
- from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
57
+ from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm
65
58
 
66
59
  _is_npu = is_npu()
67
60
 
@@ -82,10 +75,6 @@ class ForwardMode(IntEnum):
82
75
  # Used in speculative decoding: extend a batch in the draft model.
83
76
  DRAFT_EXTEND = auto()
84
77
 
85
- # A dummy first batch to start the pipeline for overlap scheduler.
86
- # It is now used for triggering the sampling_info_done event for the first prefill batch.
87
- DUMMY_FIRST = auto()
88
-
89
78
  # Split Prefill for PD multiplexing
90
79
  SPLIT_PREFILL = auto()
91
80
 
@@ -132,8 +121,8 @@ class ForwardMode(IntEnum):
132
121
  or self == ForwardMode.IDLE
133
122
  )
134
123
 
135
- def is_dummy_first(self):
136
- return self == ForwardMode.DUMMY_FIRST
124
+ def is_cpu_graph(self):
125
+ return self == ForwardMode.DECODE
137
126
 
138
127
  def is_split_prefill(self):
139
128
  return self == ForwardMode.SPLIT_PREFILL
@@ -289,14 +278,18 @@ class ForwardBatch:
289
278
  can_run_dp_cuda_graph: bool = False
290
279
  global_forward_mode: Optional[ForwardMode] = None
291
280
 
281
+ # Whether this batch is prefill-only (no token generation needed)
282
+ is_prefill_only: bool = False
283
+
292
284
  # Speculative decoding
293
- spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
285
+ spec_info: Optional[SpecInput] = None
294
286
  spec_algorithm: SpeculativeAlgorithm = None
295
287
  capture_hidden_mode: CaptureHiddenMode = None
296
288
 
297
289
  # For padding
298
290
  padded_static_len: int = -1 # -1 if not padded
299
291
  num_token_non_padded: Optional[torch.Tensor] = None # scalar tensor
292
+ num_token_non_padded_cpu: int = None
300
293
 
301
294
  # For Qwen2-VL
302
295
  mrope_positions: torch.Tensor = None
@@ -335,6 +328,7 @@ class ForwardBatch:
335
328
  is_extend_in_batch=batch.is_extend_in_batch,
336
329
  can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
337
330
  global_forward_mode=batch.global_forward_mode,
331
+ is_prefill_only=batch.is_prefill_only,
338
332
  lora_ids=batch.lora_ids,
339
333
  sampling_info=batch.sampling_info,
340
334
  req_to_token_pool=model_runner.req_to_token_pool,
@@ -358,36 +352,18 @@ class ForwardBatch:
358
352
  ret.num_token_non_padded = torch.tensor(
359
353
  len(batch.input_ids), dtype=torch.int32
360
354
  ).to(device, non_blocking=True)
355
+ ret.num_token_non_padded_cpu = len(batch.input_ids)
361
356
 
362
357
  # For MLP sync
363
358
  if batch.global_num_tokens is not None:
364
- from sglang.srt.speculative.eagle_utils import (
365
- EagleDraftInput,
366
- EagleVerifyInput,
367
- )
368
-
369
359
  assert batch.global_num_tokens_for_logprob is not None
360
+
370
361
  # process global_num_tokens and global_num_tokens_for_logprob
371
362
  if batch.spec_info is not None:
372
- if isinstance(batch.spec_info, EagleDraftInput):
373
- global_num_tokens = [
374
- x * batch.spec_info.num_tokens_per_batch
375
- for x in batch.global_num_tokens
376
- ]
377
- global_num_tokens_for_logprob = [
378
- x * batch.spec_info.num_tokens_for_logprob_per_batch
379
- for x in batch.global_num_tokens_for_logprob
380
- ]
381
- else:
382
- assert isinstance(batch.spec_info, EagleVerifyInput)
383
- global_num_tokens = [
384
- x * batch.spec_info.draft_token_num
385
- for x in batch.global_num_tokens
386
- ]
387
- global_num_tokens_for_logprob = [
388
- x * batch.spec_info.draft_token_num
389
- for x in batch.global_num_tokens_for_logprob
390
- ]
363
+ spec_info: SpecInput = batch.spec_info
364
+ global_num_tokens, global_num_tokens_for_logprob = (
365
+ spec_info.get_spec_adjusted_global_num_tokens(batch)
366
+ )
391
367
  else:
392
368
  global_num_tokens = batch.global_num_tokens
393
369
  global_num_tokens_for_logprob = batch.global_num_tokens_for_logprob
@@ -441,7 +417,13 @@ class ForwardBatch:
441
417
  ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
442
418
 
443
419
  if model_runner.model_is_mrope:
444
- ret._compute_mrope_positions(model_runner, batch)
420
+ if (
421
+ ret.spec_info is not None
422
+ and getattr(ret.spec_info, "positions", None) is not None
423
+ ):
424
+ ret._compute_spec_mrope_positions(model_runner, batch)
425
+ else:
426
+ ret._compute_mrope_positions(model_runner, batch)
445
427
 
446
428
  # Init lora information
447
429
  if model_runner.server_args.enable_lora:
@@ -507,6 +489,52 @@ class ForwardBatch:
507
489
  or self.contains_image_inputs()
508
490
  )
509
491
 
492
+ def _compute_spec_mrope_positions(
493
+ self, model_runner: ModelRunner, batch: ModelWorkerBatch
494
+ ):
495
+ # TODO support batched deltas
496
+ batch_size = self.seq_lens.shape[0]
497
+ device = model_runner.device
498
+ mm_inputs = batch.multimodal_inputs
499
+
500
+ if batch.forward_mode.is_draft_extend(): # draft_extend_after_decode
501
+ mrope_deltas = []
502
+ extend_lens = []
503
+ for batch_idx in range(batch_size):
504
+ extend_seq_len = batch.extend_seq_lens[batch_idx]
505
+ extend_lens.append(extend_seq_len)
506
+ mrope_delta = (
507
+ torch.zeros(1, dtype=torch.int64)
508
+ if mm_inputs[batch_idx] is None
509
+ else mm_inputs[batch_idx].mrope_position_delta.squeeze(0)
510
+ )
511
+ mrope_deltas.append(mrope_delta.to(device=device))
512
+ position_chunks = torch.split(batch.spec_info.positions, extend_lens)
513
+ mrope_positions_list = [
514
+ pos_chunk + delta
515
+ for pos_chunk, delta in zip(position_chunks, mrope_deltas)
516
+ ]
517
+ next_input_positions = (
518
+ torch.cat(mrope_positions_list, dim=0).unsqueeze(0).repeat(3, 1)
519
+ )
520
+
521
+ else: # target_verify or draft_decode
522
+ seq_positions = batch.spec_info.positions.view(batch_size, -1)
523
+ mrope_deltas = [
524
+ (
525
+ torch.tensor([0], dtype=torch.int64)
526
+ if mm_inputs[i] is None
527
+ else mm_inputs[i].mrope_position_delta.squeeze(0)
528
+ )
529
+ for i in range(batch_size)
530
+ ]
531
+ mrope_delta_tensor = torch.stack(mrope_deltas, dim=0).to(device=device)
532
+ next_input_positions = (
533
+ (seq_positions + mrope_delta_tensor).flatten().unsqueeze(0).repeat(3, 1)
534
+ )
535
+
536
+ self.mrope_positions = next_input_positions
537
+
510
538
  def _compute_mrope_positions(
511
539
  self, model_runner: ModelRunner, batch: ModelWorkerBatch
512
540
  ):
@@ -614,9 +642,6 @@ class ForwardBatch:
614
642
  )
615
643
 
616
644
  def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
617
-
618
- from sglang.srt.speculative.eagle_utils import EagleDraftInput
619
-
620
645
  assert self.global_num_tokens_cpu is not None
621
646
  assert self.global_num_tokens_for_logprob_cpu is not None
622
647
 
@@ -631,7 +656,9 @@ class ForwardBatch:
631
656
  (global_num_tokens[i] - 1) // attn_tp_size + 1
632
657
  ) * attn_tp_size
633
658
 
634
- dp_padding_mode = DpPaddingMode.get_dp_padding_mode(global_num_tokens)
659
+ dp_padding_mode = DpPaddingMode.get_dp_padding_mode(
660
+ self.is_extend_in_batch, global_num_tokens
661
+ )
635
662
  self.dp_padding_mode = dp_padding_mode
636
663
 
637
664
  if dp_padding_mode.is_max_len():
@@ -711,7 +738,8 @@ class ForwardBatch:
711
738
  if self.extend_seq_lens is not None:
712
739
  self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
713
740
 
714
- if self.spec_info is not None and isinstance(self.spec_info, EagleDraftInput):
741
+ if self.spec_info is not None and self.spec_info.is_draft_input():
742
+ # FIXME(lsyin): remove this isinstance logic
715
743
  spec_info = self.spec_info
716
744
  self.output_cache_loc_backup = self.out_cache_loc
717
745
  self.hidden_states_backup = spec_info.hidden_states