sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
@@ -21,10 +21,11 @@ Life cycle of a request in the decode server
21
21
  from __future__ import annotations
22
22
 
23
23
  import logging
24
+ import time
24
25
  from collections import deque
25
26
  from dataclasses import dataclass
26
27
  from http import HTTPStatus
27
- from typing import TYPE_CHECKING, List, Optional, Tuple, Union
28
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
28
29
 
29
30
  import torch
30
31
  from torch.distributed import ProcessGroup
@@ -45,7 +46,7 @@ from sglang.srt.disaggregation.utils import (
45
46
  prepare_abort,
46
47
  )
47
48
  from sglang.srt.layers.dp_attention import get_attention_tp_size
48
- from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
49
+ from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
49
50
  from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
50
51
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
51
52
  from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
@@ -218,8 +219,10 @@ class DecodePreallocQueue:
218
219
 
219
220
  kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
220
221
  kv_args.gpu_id = self.scheduler.gpu_id
221
- kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
222
- kv_manager = kv_manager_class(
222
+ kv_manager_class: Type[BaseKVManager] = get_kv_class(
223
+ self.transfer_backend, KVClassType.MANAGER
224
+ )
225
+ kv_manager: BaseKVManager = kv_manager_class(
223
226
  kv_args,
224
227
  DisaggregationMode.DECODE,
225
228
  self.scheduler.server_args,
@@ -248,9 +251,10 @@ class DecodePreallocQueue:
248
251
  mgr=self.kv_manager,
249
252
  bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
250
253
  bootstrap_room=req.bootstrap_room,
251
- data_parallel_rank=req.data_parallel_rank,
254
+ prefill_dp_rank=req.data_parallel_rank,
252
255
  )
253
256
 
257
+ req.add_latency(RequestStage.DECODE_PREPARE)
254
258
  self.queue.append(
255
259
  DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
256
260
  )
@@ -419,8 +423,13 @@ class DecodePreallocQueue:
419
423
  kv_indices, self.token_to_kv_pool_allocator.page_size
420
424
  )
421
425
  decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
426
+
422
427
  preallocated_reqs.append(decode_req)
423
428
  indices_to_remove.add(i)
429
+ decode_req.req.time_stats.decode_transfer_queue_entry_time = (
430
+ time.perf_counter()
431
+ )
432
+ decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
424
433
 
425
434
  self.queue = [
426
435
  entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
@@ -514,11 +523,19 @@ class DecodePreallocQueue:
514
523
  dtype=torch.int64,
515
524
  device=self.token_to_kv_pool_allocator.device,
516
525
  ),
526
+ prefix_lens_cpu=torch.tensor(
527
+ [0],
528
+ dtype=torch.int64,
529
+ ),
517
530
  seq_lens=torch.tensor(
518
531
  [num_tokens],
519
532
  dtype=torch.int64,
520
533
  device=self.token_to_kv_pool_allocator.device,
521
534
  ),
535
+ seq_lens_cpu=torch.tensor(
536
+ [num_tokens],
537
+ dtype=torch.int64,
538
+ ),
522
539
  last_loc=torch.tensor(
523
540
  [-1],
524
541
  dtype=torch.int64,
@@ -605,16 +622,23 @@ class DecodeTransferQueue:
605
622
  idx = decode_req.metadata_buffer_index
606
623
  (
607
624
  output_id,
625
+ cached_tokens,
608
626
  output_token_logprobs_val,
609
627
  output_token_logprobs_idx,
610
628
  output_top_logprobs_val,
611
629
  output_top_logprobs_idx,
630
+ output_topk_p,
631
+ output_topk_index,
612
632
  output_hidden_states,
613
633
  ) = self.metadata_buffers.get_buf(idx)
614
634
 
615
635
  decode_req.req.output_ids.append(output_id[0].item())
636
+ decode_req.req.cached_tokens = cached_tokens[0].item()
616
637
  if not self.spec_algorithm.is_none():
638
+ decode_req.req.output_topk_p = output_topk_p
639
+ decode_req.req.output_topk_index = output_topk_index
617
640
  decode_req.req.hidden_states_tensor = output_hidden_states
641
+
618
642
  if decode_req.req.return_logprob:
619
643
  decode_req.req.output_token_logprobs_val.append(
620
644
  output_token_logprobs_val[0].item()
@@ -635,10 +659,17 @@ class DecodeTransferQueue:
635
659
 
636
660
  if hasattr(decode_req.kv_receiver, "clear"):
637
661
  decode_req.kv_receiver.clear()
662
+ decode_req.kv_receiver = None
663
+
664
+ indices_to_remove.add(i)
665
+ decode_req.req.time_stats.wait_queue_entry_time = time.perf_counter()
638
666
 
639
667
  # special handling for sampling_params.max_new_tokens == 1
640
668
  if decode_req.req.sampling_params.max_new_tokens == 1:
641
669
  # finish immediately
670
+ decode_req.req.time_stats.forward_entry_time = (
671
+ decode_req.req.time_stats.completion_time
672
+ ) = time.perf_counter()
642
673
  decode_req.req.check_finished()
643
674
  self.scheduler.stream_output(
644
675
  [decode_req.req], decode_req.req.return_logprob
@@ -646,8 +677,6 @@ class DecodeTransferQueue:
646
677
  self.tree_cache.cache_finished_req(decode_req.req)
647
678
  else:
648
679
  transferred_reqs.append(decode_req.req)
649
-
650
- indices_to_remove.add(i)
651
680
  elif poll in [
652
681
  KVPoll.Bootstrapping,
653
682
  KVPoll.WaitingForInput,
@@ -660,6 +689,7 @@ class DecodeTransferQueue:
660
689
  for i in indices_to_remove:
661
690
  idx = self.queue[i].metadata_buffer_index
662
691
  assert idx != -1
692
+ self.queue[i].req.add_latency(RequestStage.DECODE_TRANSFERRED)
663
693
  self.req_to_metadata_buffer_idx_allocator.free(idx)
664
694
 
665
695
  self.queue = [
@@ -702,12 +732,15 @@ class SchedulerDisaggregationDecodeMixin:
702
732
  elif prepare_mlp_sync_flag:
703
733
  batch, _ = self._prepare_idle_batch_and_run(None)
704
734
 
705
- if batch is None and (
735
+ queue_size = (
706
736
  len(self.waiting_queue)
707
737
  + len(self.disagg_decode_transfer_queue.queue)
708
738
  + len(self.disagg_decode_prealloc_queue.queue)
709
- == 0
710
- ):
739
+ )
740
+ if self.server_args.disaggregation_decode_enable_offload_kvcache:
741
+ queue_size += len(self.decode_offload_manager.ongoing_offload)
742
+
743
+ if batch is None and queue_size == 0:
711
744
  self.self_check_during_idle()
712
745
 
713
746
  self.last_batch = batch
@@ -776,12 +809,15 @@ class SchedulerDisaggregationDecodeMixin:
776
809
  )
777
810
  self.process_batch_result(tmp_batch, tmp_result)
778
811
 
779
- if batch is None and (
812
+ queue_size = (
780
813
  len(self.waiting_queue)
781
814
  + len(self.disagg_decode_transfer_queue.queue)
782
815
  + len(self.disagg_decode_prealloc_queue.queue)
783
- == 0
784
- ):
816
+ )
817
+ if self.server_args.disaggregation_decode_enable_offload_kvcache:
818
+ queue_size += len(self.decode_offload_manager.ongoing_offload)
819
+
820
+ if batch is None and queue_size == 0:
785
821
  self.self_check_during_idle()
786
822
 
787
823
  self.last_batch = batch
@@ -851,6 +887,7 @@ class SchedulerDisaggregationDecodeMixin:
851
887
  # we can only add at least `num_not_used_batch` new batch to the running queue
852
888
  if i < num_not_used_batch:
853
889
  can_run_list.append(req)
890
+ req.add_latency(RequestStage.DECODE_WAITING)
854
891
  req.init_next_round_input(self.tree_cache)
855
892
  else:
856
893
  waiting_queue.append(req)
@@ -859,6 +896,9 @@ class SchedulerDisaggregationDecodeMixin:
859
896
  if len(can_run_list) == 0:
860
897
  return None
861
898
 
899
+ for req in can_run_list:
900
+ req.time_stats.forward_entry_time = time.perf_counter()
901
+
862
902
  # construct a schedule batch with those requests and mark as decode
863
903
  new_batch = ScheduleBatch.init_new(
864
904
  can_run_list,
@@ -884,9 +924,21 @@ class SchedulerDisaggregationDecodeMixin:
884
924
  # if there are still retracted requests, we do not allocate new requests
885
925
  return
886
926
 
887
- req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
888
- self.disagg_decode_transfer_queue.extend(req_conns)
889
- alloc_reqs = (
890
- self.disagg_decode_transfer_queue.pop_transferred()
891
- ) # the requests which kv has arrived
892
- self.waiting_queue.extend(alloc_reqs)
927
+ if not hasattr(self, "polling_count"):
928
+ self.polling_count = 0
929
+ self.polling_interval = (
930
+ self.server_args.disaggregation_decode_polling_interval
931
+ )
932
+
933
+ self.polling_count = (self.polling_count + 1) % self.polling_interval
934
+
935
+ if self.polling_count % self.polling_interval == 0:
936
+ req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
937
+ self.disagg_decode_transfer_queue.extend(req_conns)
938
+ alloc_reqs = (
939
+ self.disagg_decode_transfer_queue.pop_transferred()
940
+ ) # the requests which kv has arrived
941
+ self.waiting_queue.extend(alloc_reqs)
942
+
943
+ if self.server_args.disaggregation_decode_enable_offload_kvcache:
944
+ self.decode_offload_manager.check_offload_progress()
@@ -0,0 +1,185 @@
1
+ import logging
2
+ import threading
3
+ import time
4
+
5
+ import torch
6
+
7
+ from sglang import ServerArgs
8
+ from sglang.srt.managers.cache_controller import HiCacheController
9
+ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
10
+ from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
11
+ from sglang.srt.mem_cache.memory_pool import (
12
+ MHATokenToKVPool,
13
+ MLATokenToKVPool,
14
+ ReqToTokenPool,
15
+ )
16
+ from sglang.srt.mem_cache.memory_pool_host import (
17
+ MHATokenToKVPoolHost,
18
+ MLATokenToKVPoolHost,
19
+ )
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
class DecodeKVCacheOffloadManager:
    """Manage decode-side KV cache offloading lifecycle and operations.

    Lifecycle of one finished request:

    1. ``offload_kv_cache`` schedules an asynchronous device->host copy of the
       request's page-aligned KV cache through a ``HiCacheController``.
    2. ``check_offload_progress`` picks up completed copies. For each one it
       releases the request's device KV via ``tree_cache.cache_finished_req``
       and starts an asynchronous host->storage backup.
    3. Once a backup is acknowledged, the host memory it used is freed.
    """

    def __init__(
        self,
        req_to_token_pool: "ReqToTokenPool",
        token_to_kv_pool_allocator: "BaseTokenToKVPoolAllocator",
        tp_group: torch.distributed.ProcessGroup,
        tree_cache: "BasePrefixCache",
        server_args: "ServerArgs",
    ) -> None:
        """Build the host memory pool and the cache controller used for offload.

        Raises:
            ValueError: if the device KV pool is neither ``MHATokenToKVPool``
                nor ``MLATokenToKVPool``.
        """
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.page_size = server_args.page_size
        self.server_args = server_args
        # Incremented per offload attempt; its value is used as the ack id of
        # the device->host write.
        self.request_counter = 0
        self.tree_cache = tree_cache

        # The host pool type must mirror the device KV pool layout.
        kv_cache = self.token_to_kv_pool_allocator.get_kvcache()
        if isinstance(kv_cache, MHATokenToKVPool):
            host_pool_cls = MHATokenToKVPoolHost
        elif isinstance(kv_cache, MLATokenToKVPool):
            host_pool_cls = MLATokenToKVPoolHost
        else:
            raise ValueError("Unsupported KV cache type for decode offload")
        self.decode_host_mem_pool = host_pool_cls(
            kv_cache,
            server_args.hicache_ratio,
            server_args.hicache_size,
            self.page_size,
            server_args.hicache_mem_layout,
        )

        self.tp_group = tp_group
        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
        self.cache_controller = HiCacheController(
            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            mem_pool_host=self.decode_host_mem_pool,
            page_size=self.page_size,
            tp_group=tp_group,
            io_backend=server_args.hicache_io_backend,
            load_cache_event=threading.Event(),
            storage_backend=server_args.hicache_storage_backend,
            model_name=server_args.served_model_name,
            storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
        )

        # ack_id -> (req, host_indices, tokens, start_time): device->host
        # copies that have been issued but not yet consumed.
        self.ongoing_offload = {}
        # ack_id -> (rid, host_indices, start_time): host->storage backups
        # that have been issued but not yet acknowledged.
        self.ongoing_backup = {}
        logger.info("Enable offload kv cache for decode side")

    def offload_kv_cache(self, req) -> bool:
        """Schedule an asynchronous device->host copy of a finished request's KV cache.

        Only whole pages are copied. The last token of ``req.output_ids`` is
        excluded: it is the most recently sampled token, has not been through a
        forward pass, and so has no populated KV slot. Including it would copy a
        stale ``req_to_token`` entry and hash that token into the stored prefix.
        NOTE(review): this matches the ``[:-1]`` convention used when caching
        finished requests in the radix cache; confirm for other tree caches.

        Args:
            req: the finished request. ``req_pool_idx``, ``origin_input_ids``,
                ``output_ids`` and ``rid`` are read.

        Returns:
            True if a copy was scheduled (tracked in ``self.ongoing_offload``).
            False if there is no controller or host pool, the request has no
            pool slot (``req_pool_idx == -1``), its token-index row is empty,
            fewer than one full page has computed KV, or host memory could not
            be allocated.
        """
        if self.cache_controller is None or self.decode_host_mem_pool is None:
            return False

        if req.req_pool_idx == -1:
            return False

        token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx]
        if token_indices.dim() == 0 or token_indices.numel() == 0:
            logger.debug(
                "Request %s has invalid token_indices: %s", req.rid, token_indices
            )
            return False

        # Drop the freshly sampled last token: its KV has not been computed yet.
        tokens = (req.origin_input_ids + req.output_ids)[:-1]
        aligned_len = len(tokens) - len(tokens) % self.page_size
        if aligned_len == 0:
            return False

        token_indices = token_indices[:aligned_len]
        tokens = tokens[:aligned_len]

        # Asynchronously offload KV cache from device to host by cache controller.
        self.request_counter += 1
        ack_id = self.request_counter
        host_indices = self.cache_controller.write(
            device_indices=token_indices.long(),
            node_id=ack_id,
        )
        if host_indices is None:
            logger.error("Not enough host memory for request %s", req.rid)
            return False

        self.ongoing_offload[ack_id] = (
            req,
            host_indices,
            tokens,
            time.perf_counter(),
        )
        return True

    def check_offload_progress(self):
        """Check the progress of offload from device to host and backup from host to storage.

        The local counts of completed writes and backups are all-reduced with
        ``MIN`` across the TP group when it has more than one rank. Every rank
        therefore processes the same number of completions in this call.
        """
        cc = self.cache_controller

        qsizes = torch.tensor(
            [
                len(cc.ack_write_queue),
                cc.ack_backup_queue.qsize(),
            ],
            dtype=torch.int,
        )
        if self.tp_world_size > 1:
            torch.distributed.all_reduce(
                qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
            )

        n_write, n_backup = map(int, qsizes.tolist())
        self._check_offload_progress(n_write)
        self._check_backup_progress(n_backup)

    def _check_offload_progress(self, finish_count):
        """Consume ``finish_count`` completed device->host write acks.

        For each ack, wait on its finish event, then for every ack id it
        carries:

        * release the request's device KV via ``tree_cache.cache_finished_req``;
        * start the host->storage backup.
        """
        for _ in range(finish_count):
            _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
            finish_event.synchronize()
            for ack_id in ack_list:
                req, host_indices, tokens, start_time = self.ongoing_offload.pop(ack_id)

                # Release device
                self.tree_cache.cache_finished_req(req)

                # Trigger async backup from host to storage by cache controller
                self._trigger_backup(req.rid, host_indices, tokens, start_time)

    def _check_backup_progress(self, finish_count):
        """Consume ``finish_count`` storage-backup acks and free their host memory.

        ``finish_count`` is bounded by the observed ``ack_backup_queue.qsize()``.
        Each ``get()`` should therefore find an item already queued (assuming
        the queue is not drained elsewhere).
        """
        for _ in range(finish_count):
            storage_operation = self.cache_controller.ack_backup_queue.get()
            ack_id = storage_operation.id
            req_id, host_indices, start_time = self.ongoing_backup.pop(ack_id)

            # Release host memory
            self.decode_host_mem_pool.free(host_indices)

            logger.debug(
                "Finished backup request %s, free host memory, len:%d, "
                "cost time:%.2f seconds.",
                req_id,
                len(host_indices),
                time.perf_counter() - start_time,
            )

    def _trigger_backup(self, req_id, host_indices, tokens, start_time):
        """Start an async host->storage write of ``tokens``' KV, keyed by prefix page hashes."""
        page_hashes = self._compute_prefix_hash(tokens)
        ack_id = self.cache_controller.write_storage(
            host_indices,
            tokens,
            hash_value=page_hashes,
        )
        self.ongoing_backup[ack_id] = (req_id, host_indices, start_time)

    def _compute_prefix_hash(self, tokens):
        """Return one chained hash per ``page_size`` chunk of ``tokens``.

        Each page's hash is computed from that page's tokens and the previous
        page's hash (empty string for the first page), so hash ``i`` identifies
        the whole prefix up to and including page ``i``.
        """
        prev_hash = ""
        page_hashes = []
        for offset in range(0, len(tokens), self.page_size):
            page_tokens = tokens[offset : offset + self.page_size]
            prev_hash = self.cache_controller.get_hash_str(page_tokens, prev_hash)
            page_hashes.append(prev_hash)
        return page_hashes
@@ -76,6 +76,7 @@ class ScheduleBatchDisaggregationDecodeMixin:
76
76
  req_pool_indices, dtype=torch.int64, device=self.device
77
77
  )
78
78
  self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
79
+ self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
79
80
  self.orig_seq_lens = torch.tensor(
80
81
  seq_lens, dtype=torch.int32, device=self.device
81
82
  )
@@ -110,7 +111,10 @@ class ScheduleBatchDisaggregationDecodeMixin:
110
111
  if req.grammar is not None:
111
112
  # FIXME: this try-except block is for handling unexpected xgrammar issue.
112
113
  try:
113
- req.grammar.accept_token(req.output_ids[-1])
114
+ # if it is not None, then the grammar is from a retracted request, and we should not
115
+ # accept the token as it's already accepted
116
+ if req.grammar.current_token is None:
117
+ req.grammar.accept_token(req.output_ids[-1])
114
118
  except ValueError as e:
115
119
  # Grammar accept_token can raise ValueError if the token is not in the grammar.
116
120
  # This can happen if the grammar is not set correctly or the token is invalid.
@@ -122,31 +126,39 @@ class ScheduleBatchDisaggregationDecodeMixin:
122
126
  req.grammar.finished = req.finished()
123
127
  self.output_ids = torch.tensor(self.output_ids, device=self.device)
124
128
 
125
- # Simulate the eagle run. We add mock data to hidden states for the
126
- # ease of implementation now meaning the first token will have acc rate
127
- # of 0.
128
- if not self.spec_algorithm.is_none():
129
+ # Simulate the eagle run.
130
+ if self.spec_algorithm.is_eagle():
129
131
 
130
132
  b = len(self.reqs)
131
- topk_p = torch.arange(
132
- b * server_args.speculative_eagle_topk,
133
- 0,
134
- -1,
135
- device=self.device,
136
- dtype=torch.float32,
133
+ topk = server_args.speculative_eagle_topk
134
+ topk_p = torch.stack(
135
+ [
136
+ torch.as_tensor(
137
+ req.output_topk_p[:topk],
138
+ device=self.device,
139
+ dtype=torch.float32,
140
+ )
141
+ for req in self.reqs
142
+ ],
143
+ dim=0,
137
144
  )
138
- topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk)
139
- topk_p /= b * server_args.speculative_eagle_topk
140
- topk_index = torch.arange(
141
- b * server_args.speculative_eagle_topk, device=self.device
145
+ topk_index = torch.stack(
146
+ [
147
+ torch.as_tensor(
148
+ req.output_topk_index[:topk],
149
+ device=self.device,
150
+ dtype=torch.int64,
151
+ )
152
+ for req in self.reqs
153
+ ],
154
+ dim=0,
142
155
  )
143
- topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
144
156
 
145
157
  hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
146
158
  hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
147
159
 
148
160
  # local import to avoid circular import
149
- from sglang.srt.speculative.eagle_utils import EagleDraftInput
161
+ from sglang.srt.speculative.eagle_info import EagleDraftInput
150
162
 
151
163
  spec_info = EagleDraftInput(
152
164
  topk_p=topk_p,
@@ -62,7 +62,7 @@ class FakeKVReceiver(BaseKVReceiver):
62
62
  mgr: BaseKVManager,
63
63
  bootstrap_addr: str,
64
64
  bootstrap_room: Optional[int] = None,
65
- data_parallel_rank: Optional[int] = None,
65
+ prefill_dp_rank: Optional[int] = None,
66
66
  ):
67
67
  self.has_init = False
68
68