sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377) hide show
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,110 @@
1
+ #pragma once
2
+
3
+ #include <cstddef>
4
+ #include <cstdint>
5
+ #include <functional>
6
+ #include <list>
7
+ #include <mutex>
8
+ #include <set>
9
+ #include <sstream>
10
+ #include <thread>
11
+ #include <tuple>
12
+ #include <unordered_map>
13
+ #include <vector>
14
+
15
+ #include "param.h"
16
+ #include "queue.h"
17
+
18
+ namespace ngram {
19
+
20
// One node of the n-gram trie. Nodes reference their parent and children
// through raw, non-owning pointers.
struct TrieNode {
  // Child nodes keyed by an int32_t
  // (presumably the child's token id -- TODO confirm against ngram.cpp).
  std::unordered_map<int32_t, TrieNode*> child;
  // Positions inside std::list<TrieNode*> containers. Judging by the names,
  // these locate this node in a global LRU list and in its parent's `lru`
  // list -- NOTE(review): confirm against ngram.cpp.
  std::list<TrieNode*>::const_iterator global_lru_pos;
  std::list<TrieNode*>::const_iterator parent_lru_pos;
  // Explicit defaults so that a default-initialized node (e.g. `TrieNode n;`)
  // never carries indeterminate values in these two members.
  int32_t token = 0;
  TrieNode* parent = nullptr;
  std::list<TrieNode*> lru;
  int32_t freq = 0;

  // Strict weak ordering used by `sorted_children`:
  // higher `freq` first, then smaller `token`, then pointer value.
  struct CompareByFreq {
    bool operator()(TrieNode* a, TrieNode* b) const {
      // `b->freq` sits in the left-hand tuple on purpose: swapping the operands
      // for that component makes frequency compare in descending order.
      return std::tie(b->freq, a->token, a) < std::tie(a->freq, b->token, b);
    }
  };
  // Because the comparator falls back to the pointer value, two distinct nodes
  // never compare equivalent.
  std::multiset<TrieNode*, CompareByFreq> sorted_children;
};
36
+
37
+ class Ngram {
38
+ std::vector<TrieNode> nodes_;
39
+ std::vector<TrieNode*> node_pool_;
40
+ size_t free_node_count_;
41
+ std::list<TrieNode*> global_lru_;
42
+ TrieNode* root_;
43
+ std::vector<TrieNode*> path_;
44
+ Param param_;
45
+
46
+ std::vector<std::pair<TrieNode*, int32_t>> match(const std::vector<int32_t>& tokens, size_t batch_size) const;
47
+
48
+ void squeeze(size_t count);
49
+
50
+ TrieNode* getNode() {
51
+ auto node = node_pool_[--free_node_count_];
52
+ node->~TrieNode();
53
+ new (node) TrieNode();
54
+ return node;
55
+ }
56
+
57
+ mutable std::mutex mutex_;
58
+ bool quit_flag_;
59
+ utils::Queue<std::vector<int32_t>> insert_queue_;
60
+ std::thread insert_worker_;
61
+ std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> match_tmp_data_;
62
+
63
+ public:
64
+ Ngram(size_t capacity, const Param& param);
65
+ Ngram() = default;
66
+ ~Ngram();
67
+
68
+ static Ngram& instance() {
69
+ static Ngram instance;
70
+ return instance;
71
+ }
72
+
73
+ void synchronize() const;
74
+
75
+ void asyncInsert(std::vector<std::vector<int32_t>>&& tokens);
76
+
77
+ struct Result {
78
+ std::vector<int32_t> token;
79
+ std::vector<uint8_t> mask;
80
+
81
+ void truncate(size_t n);
82
+ };
83
+
84
+ Result batchMatch(const std::vector<std::vector<int32_t>>& tokens) const;
85
+
86
+ void reset() {
87
+ std::unique_lock<std::mutex> lock(mutex_);
88
+
89
+ global_lru_.clear();
90
+ path_.clear();
91
+ node_pool_.clear();
92
+ for (auto& node : nodes_) {
93
+ node_pool_.emplace_back(&node);
94
+ }
95
+ free_node_count_ = node_pool_.size();
96
+ root_ = getNode();
97
+ }
98
+
99
+ const Param& param() const {
100
+ return param_;
101
+ }
102
+
103
+ private:
104
+ Result matchBFS(const std::vector<int32_t>& tokens, size_t batch_size) const;
105
+ Result matchProb(const std::vector<int32_t>& tokens, size_t batch_size) const;
106
+
107
+ void insert();
108
+ };
109
+
110
+ } // namespace ngram
@@ -0,0 +1,138 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import logging
4
+ import os
5
+ from typing import List, Tuple
6
+
7
+ import numpy as np
8
+ from torch.utils.cpp_extension import load
9
+
10
# Module-level logger for this file.
logger = logging.getLogger(__name__)

# Directory that contains this module; the C++ sources are resolved against it.
_abs_path = os.path.dirname(os.path.abspath(__file__))

# The extension is built/loaded as a side effect of importing this module.
ngram_cache_cpp = load(
    name="ngram_cache_cpp",
    sources=[
        f"{_abs_path}/{source_file}"
        for source_file in ("ngram_cache_binding.cpp", "ngram.cpp")
    ],
    extra_cflags=["-O3", "-std=c++20"],
)
21
+
22
+
23
class NgramCache:
    """Python wrapper around the ``ngram_cache_cpp.Ngram`` extension object."""

    def __init__(
        self,
        branch_length=18,
        min_match_window_size=1,
        max_match_window_size=10,
        min_bfs_breadth=1,
        max_bfs_breadth=8,
        draft_token_num=8,
        match_type="BFS",
        capacity=1000000,
    ):
        """Create the C++ cache from a ``Param`` filled with the given settings."""
        param = ngram_cache_cpp.Param()
        # Copy every setting onto the Param object, in the same order as the
        # constructor arguments that feed it.
        settings = {
            "branch_length": branch_length,
            "min_match_window_size": min_match_window_size,
            "max_match_window_size": max_match_window_size,
            "min_bfs_breadth": min_bfs_breadth,
            "max_bfs_breadth": max_bfs_breadth,
            "draft_token_num": draft_token_num,
            "match_type": match_type,
        }
        for field_name, field_value in settings.items():
            setattr(param, field_name, field_value)
        self.cache = ngram_cache_cpp.Ngram(capacity, param)

        self.default_mask = np.ones((1, 1), dtype=np.int64)
        self.draft_token_num = draft_token_num

    def batch_put(self, batch_tokens: List[List[int]]):
        """Hand ``batch_tokens`` to the C++ cache's ``asyncInsert``."""
        self.cache.asyncInsert(batch_tokens)

    def synchronize(self):
        """Call ``synchronize`` on the C++ cache."""
        self.cache.synchronize()

    def reset(self):
        """Call ``reset`` on the C++ cache."""
        self.cache.reset()

    def batch_get(self, batch_tokens: List[List[int]]) -> Tuple[np.ndarray, np.ndarray]:
        """Run ``batchMatch`` and return its ``token`` and ``mask`` as arrays."""
        matched = self.cache.batchMatch(batch_tokens)
        token_array = np.array(matched.token)
        mask_array = np.array(matched.mask)
        return token_array, mask_array

    def leaf_paths_from_mask(
        self, tokens: List[int], tree_mask: List[List[int]]
    ) -> List[List[int]]:
        """
        Extract the leaf paths encoded by a binary ``tree_mask``.

        A row is a leaf path when the set of columns it marks with 1 is not
        contained in the column set of any already-kept later row.

        Args:
            tokens    : List[int]        # token list corresponding to columns
            tree_mask : List[List[int]]  # nxn binary matrix

        Returns:
            List[List[int]]  # token lists of only the leaf paths, preserving their order of appearance
        """
        column_sets = [
            {col for col, flag in enumerate(row) if flag == 1} for row in tree_mask
        ]

        kept_sets = []
        kept_rows = []
        # Walk the rows from last to first so each row is compared only
        # against rows that come after it.
        for row_idx in range(len(column_sets) - 1, -1, -1):
            candidate = column_sets[row_idx]
            covered = False
            for kept in kept_sets:
                if candidate <= kept:
                    covered = True
                    break
            if not covered:
                kept_sets.append(candidate)
                kept_rows.append(row_idx)

        # kept_rows was collected back-to-front; emit in original row order.
        return [
            [tok for col, tok in enumerate(tokens) if tree_mask[row_idx][col] == 1]
            for row_idx in reversed(kept_rows)
        ]

    def debug_result(
        self, decoding_ids: np.ndarray, decoding_masks: np.ndarray, tokenizer=None
    ):
        """Log the draft ids/masks and the leaf paths of every batch entry.

        When ``tokenizer`` is given, each leaf path is also logged together with
        its decoded text.
        """
        width = self.draft_token_num
        decoding_ids = decoding_ids.reshape(-1, width)
        decoding_masks = decoding_masks.reshape(-1, width, width)
        logger.info(f"\n{decoding_ids=}\n{decoding_masks=}")
        for i, ids_row in enumerate(decoding_ids):
            leaf_paths = self.leaf_paths_from_mask(
                ids_row.tolist(), decoding_masks[i].tolist()
            )
            if tokenizer is None:
                logger.info(f"draft path {i}: {leaf_paths}")
                continue
            logger.info(f"result {i}:")
            for leaf_path in leaf_paths:
                logger.info(
                    f"draft path {i}: {leaf_path} -> {tokenizer.decode(leaf_path, ensure_ascii=False)}"
                )
+
117
+
118
+ # main function
119
if __name__ == "__main__":
    # Small manual demo: insert two token sequences, query three prefixes and
    # log the resulting draft paths.
    # Named `log_format` so the builtin `format` is not shadowed; a plain string
    # is enough because the %-placeholders are filled in by logging itself.
    log_format = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(
        level=logging.DEBUG,
        format=log_format,
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
    )

    token_ids = [
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        [1, 2, 3, 44, 55, 66, 77, 88, 99, 100],
    ]
    cache = NgramCache(branch_length=12, draft_token_num=8)
    cache.batch_put(token_ids)

    # batch_put goes through asyncInsert; presumably synchronize() waits for
    # that insertion to finish before we query -- verify against ngram.cpp.
    cache.synchronize()
    decoding_ids, decoding_masks = cache.batch_get([[1, 2, 3], [3, 44], [3, 6, 999]])

    cache.debug_result(decoding_ids, decoding_masks)
@@ -0,0 +1,43 @@
1
+ #include <pybind11/pybind11.h>
2
+ #include <pybind11/stl.h>
3
+
4
+ #include "ngram.h"
5
+
6
+ PYBIND11_MODULE(ngram_cache_cpp, m) {
7
+ using namespace ngram;
8
+ namespace py = pybind11;
9
+ m.doc() = "";
10
+
11
+ py::class_<Ngram>(m, "Ngram")
12
+ .def(py::init<size_t, const Param&>(), py::arg("capacity"), py::arg("param"))
13
+ .def("asyncInsert", &Ngram::asyncInsert, "")
14
+ .def("batchMatch", &Ngram::batchMatch, "")
15
+ .def("reset", &Ngram::reset, "")
16
+ .def("synchronize", &Ngram::synchronize, "");
17
+
18
+ py::class_<Param>(m, "Param")
19
+ .def(py::init<>())
20
+ .def_readwrite("enable", &Param::enable)
21
+ .def_readwrite("enable_router_mode", &Param::enable_router_mode)
22
+ .def_readwrite("min_bfs_breadth", &Param::min_bfs_breadth)
23
+ .def_readwrite("max_bfs_breadth", &Param::max_bfs_breadth)
24
+ .def_readwrite("min_match_window_size", &Param::min_match_window_size)
25
+ .def_readwrite("max_match_window_size", &Param::max_match_window_size)
26
+ .def_readwrite("branch_length", &Param::branch_length)
27
+ .def_readwrite("draft_token_num", &Param::draft_token_num)
28
+ .def_readwrite("match_type", &Param::match_type)
29
+ .def_readwrite("batch_min_match_window_size", &Param::batch_min_match_window_size)
30
+ .def_readwrite("batch_draft_token_num", &Param::batch_draft_token_num)
31
+ .def("get_draft_token_num", &Param::get_draft_token_num, "")
32
+ .def("get_min_match_window_size", &Param::get_min_match_window_size, "")
33
+ .def("parse", &Param::parse, "")
34
+ .def("resetBatchMinMatchWindowSize", &Param::resetBatchMinMatchWindowSize, "")
35
+ .def("resetBatchReturnTokenNum", &Param::resetBatchReturnTokenNum, "")
36
+ .def("detail", &Param::detail, "");
37
+
38
+ py::class_<Ngram::Result>(m, "Result")
39
+ .def(py::init<>())
40
+ .def_readwrite("token", &Ngram::Result::token)
41
+ .def_readwrite("mask", &Ngram::Result::mask)
42
+ .def("truncate", &Ngram::Result::truncate);
43
+ }
@@ -0,0 +1,125 @@
1
+ #pragma once
2
+
3
+ #include <cstddef>
4
+ #include <iostream>
5
+ #include <limits>
6
+ #include <regex>
7
+ #include <sstream>
8
+ #include <stdexcept>
9
+ #include <string>
10
+ #include <vector>
11
+
12
+ namespace ngram {
13
+
14
// Tunable parameters of the ngram cache, plus optional per-batch-size overrides that
// are configured from strings such as "0-1|10,2-3|20,".
struct Param {
  // Scalars carry default initializers so that a default-constructed Param (which the
  // Python bindings expose) never holds indeterminate values.
  bool enable = false;
  bool enable_router_mode = false;
  size_t min_bfs_breadth = 0;
  size_t max_bfs_breadth = 0;
  size_t min_match_window_size = 0;
  size_t max_match_window_size = 0;
  size_t branch_length = 0;
  size_t draft_token_num = 0;
  std::string match_type;

  // Overrides indexed by batch size. An entry equal to unset() means "no override for
  // this batch size"; indexes past the end of the vector are not overridden either.
  std::vector<size_t> batch_min_match_window_size;
  std::vector<size_t> batch_draft_token_num;

  // Sentinel stored in the override vectors for batch sizes without an override.
  static constexpr size_t unset() {
    return std::numeric_limits<size_t>::max();
  }

  // Returns the override for `batch_size` when one is configured, otherwise
  // `draft_token_num - 1` (clamped at 0 so an unset/zero draft_token_num cannot wrap
  // around to a huge value).
  size_t get_draft_token_num(size_t batch_size) const {
    if (batch_size < batch_draft_token_num.size() && batch_draft_token_num[batch_size] != unset()) {
      return batch_draft_token_num[batch_size];
    }
    return draft_token_num > 0 ? draft_token_num - 1 : 0;
  }

  // Returns the override for `batch_size` when one is configured, otherwise
  // `min_match_window_size`.
  size_t get_min_match_window_size(size_t batch_size) const {
    if (batch_size < batch_min_match_window_size.size() && batch_min_match_window_size[batch_size] != unset()) {
      return batch_min_match_window_size[batch_size];
    }
    return min_match_window_size;
  }

  // Parses a comma separated list of "L-R|V" segments (for example "0-1|10,2-3|20,")
  // into a vector in which every index in [L, R] holds V and every index that no
  // segment covers holds unset(). An empty string yields an empty vector.
  //
  // Throws std::runtime_error when a segment is not of the form "L-R|V", when a number
  // is not a plain non-negative decimal, when L > R or R > 128, or when two segments
  // cover the same index.
  std::vector<size_t> parse(const std::string& value) const {
    std::vector<size_t> result;
    if (value.empty()) {
      return result;
    }

    // Compile the delimiters once instead of once per segment.
    const std::regex comma_re(",");
    const std::regex pipe_re("\\|");
    const std::regex dash_re("-");

    std::vector<bool> assigned;  // assigned[i] is true once some segment has set index i
    for (const std::string& seg : split(value, comma_re)) {
      const std::vector<std::string> part = split(seg, pipe_re);
      if (part.size() != 2) {
        throw std::runtime_error(
            "failed to get config, invalid config: " + seg + ", part's size = " + std::to_string(part.size()));
      }
      const std::vector<std::string> range = split(part[0], dash_re);
      if (range.size() != 2) {
        throw std::runtime_error("failed to get range, invalid config: " + value);
      }
      const size_t left = to_number(range[0], value);
      const size_t right = to_number(range[1], value);
      if (left > right || right > 128) {
        throw std::runtime_error("invalid range, config: " + value);
      }
      if (right >= result.size()) {
        result.resize(right + 1, unset());
        assigned.resize(result.size(), false);
      }
      const size_t config = to_number(part[1], value);
      for (size_t pos = left; pos <= right; ++pos) {
        if (assigned[pos]) {
          throw std::runtime_error("repeated position " + std::to_string(pos) + ", config : " + value);
        }
        assigned[pos] = true;
        result[pos] = config;
      }
    }
    return result;
  }

  // Replaces the per-batch-size minimum match window overrides; see parse().
  void resetBatchMinMatchWindowSize(const std::string& value) {
    batch_min_match_window_size = parse(value);
  }

  // Replaces the per-batch-size draft token count overrides; see parse().
  void resetBatchReturnTokenNum(const std::string& value) {
    batch_draft_token_num = parse(value);
  }

  // Renders every field, including both override tables, as one human-readable line.
  std::string detail() const {
    std::stringstream ss;
    ss << "enable = " << enable << ", enable_router_mode = " << enable_router_mode
       << ", min_bfs_breadth = " << min_bfs_breadth << ", max_bfs_breadth = " << max_bfs_breadth
       << ", min_match_window_size = " << min_match_window_size << ", max_match_window_size = " << max_match_window_size
       << ", branch_length = " << branch_length << ", draft_token_num = " << draft_token_num
       << ", match_type = " << match_type;
    ss << ", batch_min_match_window_size(" << batch_min_match_window_size.size() << ") = ";
    for (size_t i = 0; i < batch_min_match_window_size.size(); ++i) {
      ss << i << "|" << batch_min_match_window_size[i] << ",";
    }
    ss << ", batch_draft_token_num(" << batch_draft_token_num.size() << ") = ";
    for (size_t i = 0; i < batch_draft_token_num.size(); ++i) {
      ss << i << "|" << batch_draft_token_num[i] << ",";
    }
    return ss.str();
  }

 private:
  // Splits `text` on `delimiter`. Tokens between delimiters are kept even when empty,
  // but an empty token after a trailing delimiter is dropped, so "a,b," gives {"a","b"}.
  static std::vector<std::string> split(const std::string& text, const std::regex& delimiter) {
    return std::vector<std::string>(
        std::sregex_token_iterator(text.begin(), text.end(), delimiter, -1), std::sregex_token_iterator());
  }

  // Converts `text` (optionally surrounded by spaces or tabs) to a size_t. Unlike
  // std::atoi this rejects empty, negative, non-numeric and overflowing input rather
  // than silently producing 0 or a wrapped-around value. `config` is the whole
  // configuration string and is only used in the error message.
  static size_t to_number(const std::string& text, const std::string& config) {
    const std::string::size_type begin = text.find_first_not_of(" \t");
    if (begin == std::string::npos) {
      throw std::runtime_error("invalid number '" + text + "', config: " + config);
    }
    const std::string::size_type end = text.find_last_not_of(" \t");
    size_t number = 0;
    for (std::string::size_type i = begin; i <= end; ++i) {
      const char c = text[i];
      if (c < '0' || c > '9') {
        throw std::runtime_error("invalid number '" + text + "', config: " + config);
      }
      const size_t digit = static_cast<size_t>(c - '0');
      if (number > (std::numeric_limits<size_t>::max() - digit) / 10) {
        throw std::runtime_error("number out of range '" + text + "', config: " + config);
      }
      number = number * 10 + digit;
    }
    return number;
  }
};
124
+
125
+ } // namespace ngram
@@ -0,0 +1,71 @@
1
+ #pragma once
2
+
3
+ #include <condition_variable>
4
+ #include <queue>
5
+
6
+ namespace utils {
7
+
8
// Mutex-protected FIFO with blocking dequeue and a one-way close() switch.
template <typename T>
class Queue {
 public:
  // Appends `rhs` by move. Returns false, leaving the queue untouched, once closed.
  bool enqueue(T&& rhs) {
    return push(std::move(rhs));
  }

  // Appends a copy of `rhs`. Returns false, leaving the queue untouched, once closed.
  bool enqueue(const T& rhs) {
    return push(rhs);
  }

  // Blocks until an element is available or the queue has been closed. On success the
  // front element is moved into `rhs` and true is returned. Returns false whenever the
  // queue is closed, even if elements are still stored in it.
  bool dequeue(T& rhs) {
    std::unique_lock<std::mutex> guard(mutex_);
    while (queue_.empty() && !closed_) {
      cv_.wait(guard);
    }
    if (closed_) {
      return false;
    }
    rhs = std::move(queue_.front());
    queue_.pop();
    return true;
  }

  // Number of elements currently stored.
  size_t size() const {
    std::lock_guard<std::mutex> guard(mutex_);
    return queue_.size();
  }

  // True when no element is currently stored.
  bool empty() const {
    std::lock_guard<std::mutex> guard(mutex_);
    return queue_.empty();
  }

  // Marks the queue closed and wakes every thread blocked in dequeue().
  void close() {
    {
      std::lock_guard<std::mutex> guard(mutex_);
      closed_ = true;
    }
    cv_.notify_all();
  }

 private:
  // Shared body of both enqueue overloads: stores the forwarded item under the lock,
  // then notifies one waiter after the lock has been released.
  template <typename U>
  bool push(U&& item) {
    std::unique_lock<std::mutex> guard(mutex_);
    if (closed_) {
      return false;
    }
    queue_.emplace(std::forward<U>(item));
    guard.unlock();
    cv_.notify_one();
    return true;
  }

  std::queue<T> queue_;
  mutable std::mutex mutex_;
  std::condition_variable cv_;
  bool closed_{false};
};
70
+
71
+ } // namespace utils
@@ -20,7 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import (
20
20
  ForwardBatch,
21
21
  ForwardMode,
22
22
  )
23
- from sglang.srt.speculative.eagle_utils import EagleDraftInput
23
+ from sglang.srt.speculative.eagle_info import EagleDraftInput
24
24
  from sglang.srt.utils import (
25
25
  require_attn_tp_gather,
26
26
  require_gathered_buffer,
@@ -91,6 +91,9 @@ class EAGLEDraftCudaGraphRunner:
91
91
  (self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
92
92
  )
93
93
  self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
94
+ self.mrope_positions = torch.zeros(
95
+ (3, self.max_num_token), dtype=torch.int64
96
+ )
94
97
  self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
95
98
  self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
96
99
  self.hidden_states = torch.zeros(
@@ -159,6 +162,7 @@ class EAGLEDraftCudaGraphRunner:
159
162
  seq_lens = self.seq_lens[:num_seqs]
160
163
  out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
161
164
  positions = self.positions[:num_tokens]
165
+ mrope_positions = self.mrope_positions[:, :num_tokens]
162
166
  topk_p = self.topk_p[:num_seqs]
163
167
  topk_index = self.topk_index[:num_seqs]
164
168
  hidden_states = self.hidden_states[:num_seqs]
@@ -224,6 +228,7 @@ class EAGLEDraftCudaGraphRunner:
224
228
  seq_lens_sum=seq_lens.sum().item(),
225
229
  return_logprob=False,
226
230
  positions=positions,
231
+ mrope_positions=mrope_positions,
227
232
  global_num_tokens_gpu=global_num_tokens,
228
233
  dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
229
234
  global_dp_buffer_len=global_dp_buffer_len,
@@ -297,6 +302,7 @@ class EAGLEDraftCudaGraphRunner:
297
302
  if bs != raw_bs:
298
303
  self.seq_lens.fill_(self.seq_len_fill_value)
299
304
  self.out_cache_loc.zero_()
305
+ self.positions.zero_()
300
306
 
301
307
  num_tokens = bs * self.num_tokens_per_bs
302
308
 
@@ -21,7 +21,8 @@ from sglang.srt.model_executor.forward_batch_info import (
21
21
  ForwardBatch,
22
22
  ForwardMode,
23
23
  )
24
- from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
24
+ from sglang.srt.speculative.eagle_info import EagleDraftInput
25
+ from sglang.srt.speculative.spec_utils import fast_topk
25
26
  from sglang.srt.utils import (
26
27
  require_attn_tp_gather,
27
28
  require_gathered_buffer,
@@ -80,6 +81,9 @@ class EAGLEDraftExtendCudaGraphRunner:
80
81
  self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
81
82
  self.out_cache_loc = torch.ones((self.max_num_token,), dtype=torch.int64)
82
83
  self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
84
+ self.mrope_positions = torch.zeros(
85
+ (3, self.max_num_token), dtype=torch.int64
86
+ )
83
87
 
84
88
  if self.eagle_worker.speculative_algorithm.is_eagle3():
85
89
  self.hidden_states = torch.zeros(
@@ -189,6 +193,7 @@ class EAGLEDraftExtendCudaGraphRunner:
189
193
  accept_length = self.accept_length[:bs]
190
194
  out_cache_loc = self.out_cache_loc[:num_tokens]
191
195
  positions = self.positions[:num_tokens]
196
+ mrope_positions = self.mrope_positions[:, :num_tokens]
192
197
  hidden_states = self.hidden_states[:num_tokens]
193
198
  next_token_logits_buffer = self.next_token_logits_buffer[:bs]
194
199
 
@@ -247,6 +252,7 @@ class EAGLEDraftExtendCudaGraphRunner:
247
252
  seq_lens_sum=seq_lens.sum().item(),
248
253
  return_logprob=False,
249
254
  positions=positions,
255
+ mrope_positions=mrope_positions,
250
256
  global_num_tokens_gpu=self.global_num_tokens_gpu,
251
257
  global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
252
258
  dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
@@ -326,6 +332,7 @@ class EAGLEDraftExtendCudaGraphRunner:
326
332
  if bs * self.num_tokens_per_bs != num_tokens:
327
333
  self.seq_lens.fill_(self.seq_len_fill_value)
328
334
  self.out_cache_loc.zero_()
335
+ self.positions.zero_()
329
336
  self.accept_length.fill_(1)
330
337
  self.extend_seq_lens.fill_(1)
331
338
 
@@ -336,7 +343,11 @@ class EAGLEDraftExtendCudaGraphRunner:
336
343
  self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
337
344
  self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
338
345
  self.positions[:num_tokens].copy_(forward_batch.positions)
339
- self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
346
+ if (
347
+ forward_batch.spec_info.hidden_states.shape[1]
348
+ == self.hidden_states.shape[1]
349
+ ):
350
+ self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
340
351
  if forward_batch.spec_info.accept_length is not None:
341
352
  self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
342
353
  self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)