sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
@@ -17,7 +17,7 @@ from enum import Enum, auto
 from typing import Any, List, Optional
 
 from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType
-from sglang.srt.poll_based_barrier import PollBasedBarrier
+from sglang.srt.utils.poll_based_barrier import PollBasedBarrier
 
 logger = logging.getLogger(__name__)
 
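This release folds several top-level `sglang.srt` modules into the `sglang.srt.utils` package (see entries 348-350 and 374 in the file list above), which is what the import rewrite in this hunk tracks. A minimal compatibility sketch for code that must run against both wheels; the module paths are taken from this diff, the try/except pattern itself is illustrative:

```python
# Prefer the 0.5.3 location, fall back to the 0.5.2rc2 one.
try:
    from sglang.srt.utils.poll_based_barrier import PollBasedBarrier  # sglang >= 0.5.3
except ImportError:
    from sglang.srt.poll_based_barrier import PollBasedBarrier  # sglang <= 0.5.2rc2
```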
@@ -12,7 +12,6 @@ from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
 from sglang.srt.managers.schedule_policy import PrefillAdder
 from sglang.srt.managers.scheduler import Req, ScheduleBatch
-from sglang.srt.managers.utils import DPBalanceMeta
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.utils import get_bool_env_var
 
@@ -47,8 +46,11 @@ class SchedulerMetricsMixin:
         self.spec_num_total_forward_ct = 0
         self.cum_spec_accept_length = 0
         self.cum_spec_accept_count = 0
-        self.total_retracted_reqs = 0
+        self.kv_transfer_speed_gb_s: float = 0.0
+        self.kv_transfer_latency_ms: float = 0.0
+
         self.stats = SchedulerStats()
+
         if self.enable_metrics:
             engine_type = "unified"
             labels = {
@@ -61,33 +63,30 @@
                 labels["dp_rank"] = dp_rank
             self.metrics_collector = SchedulerMetricsCollector(labels=labels)
 
-    def init_dp_balance(self: Scheduler, dp_balance_meta: Optional[DPBalanceMeta]):
-        self.balance_meta = dp_balance_meta
-        if (
-            self.server_args.enable_dp_attention
-            and self.server_args.load_balance_method == "minimum_tokens"
-        ):
-            assert dp_balance_meta is not None
-
-            self.recv_dp_balance_id_this_term = []
-
     def init_kv_events(self: Scheduler, kv_events_config: Optional[str]):
         if self.enable_kv_cache_events:
             self.kv_event_publisher = EventPublisherFactory.create(
                 kv_events_config, self.attn_dp_rank
             )
 
+    def udpate_spec_metrics(self, bs: int, num_accepted_tokens: int):
+        self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
+        self.spec_num_total_forward_ct += bs
+        self.num_generated_tokens += num_accepted_tokens
+
     def log_prefill_stats(
         self: Scheduler,
         adder: PrefillAdder,
         can_run_list: List[Req],
         running_bs: int,
+        running_bs_offline_batch: int,
     ):
         gap_latency = time.perf_counter() - self.last_prefill_stats_tic
         self.last_prefill_stats_tic = time.perf_counter()
         self.last_input_throughput = self.last_prefill_tokens / gap_latency
         self.last_prefill_tokens = adder.log_input_tokens
 
+        # TODO: generalize this for various memory pools
         if self.is_hybrid:
             (
                 full_num_used,
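The new `udpate_spec_metrics` helper (the name appears with this spelling in the release) folds one verify step of speculative decoding into the running counters: every request in the batch emits its one verified token plus its accepted draft tokens, hence the `num_accepted_tokens + bs` increment. A standalone sketch of the arithmetic with made-up numbers; the counter names mirror the diff, everything else is illustrative:

```python
# Running counters, as in SchedulerMetricsMixin.
spec_num_total_accepted_tokens = 0
spec_num_total_forward_ct = 0

def update_spec_metrics(bs: int, num_accepted_tokens: int) -> None:
    global spec_num_total_accepted_tokens, spec_num_total_forward_ct
    # Each of the bs requests yields 1 verified token + its accepted draft tokens.
    spec_num_total_accepted_tokens += num_accepted_tokens + bs
    spec_num_total_forward_ct += bs

update_spec_metrics(bs=8, num_accepted_tokens=16)
# "accept len" as logged in the decode stats: tokens per request per verify step.
print(spec_num_total_accepted_tokens / spec_num_total_forward_ct)  # 3.0
```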
@@ -101,51 +100,53 @@
             ) = self._get_swa_token_info()
             num_used = max(full_num_used, swa_num_used)
             token_usage = max(full_token_usage, swa_token_usage)
-            token_msg = (
+            token_usage_msg = (
                 f"full token usage: {full_token_usage:.2f}, "
                 f"swa token usage: {swa_token_usage:.2f}, "
             )
         else:
             num_used, token_usage, _, _ = self._get_token_info()
-            token_msg = f"token usage: {token_usage:.2f}, "
+            token_usage_msg = f"token usage: {token_usage:.2f}, "
 
-        num_new_seq = len(can_run_list)
         f = (
             f"Prefill batch. "
-            f"#new-seq: {num_new_seq}, "
+            f"#new-seq: {len(can_run_list)}, "
             f"#new-token: {adder.log_input_tokens}, "
             f"#cached-token: {adder.log_hit_tokens}, "
-            f"{token_msg}"
+            f"{token_usage_msg}"
+            f"#running-req: {running_bs}, "
+            f"#queue-req: {len(self.waiting_queue)}, "
         )
 
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
-            f += f"#queue-req: {len(self.waiting_queue)}, "
-            f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
-            f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
-        else:
-            f += f"#running-req: {running_bs}, "
-            f += f"#queue-req: {len(self.waiting_queue)}, "
+            f += f"#prealloc-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
+            f += f"#inflight-req: {len(self.disagg_prefill_inflight_queue)}, "
 
         logger.info(f)
 
         if self.enable_metrics:
+            # Basics
             total_tokens = adder.log_input_tokens + adder.log_hit_tokens
-
             cache_hit_rate = (
                 adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
             )
+
             self.stats.num_running_reqs = running_bs
+            self.stats.num_running_reqs_offline_batch = running_bs_offline_batch
             self.stats.num_used_tokens = num_used
-            self.stats.token_usage = round(token_usage, 2)
+            self.stats.token_usage = token_usage
+            if self.is_hybrid:
+                self.stats.swa_token_usage = swa_token_usage
             self.stats.num_queue_reqs = len(self.waiting_queue)
+            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
             self.stats.cache_hit_rate = cache_hit_rate
 
-            total_queue_latency = 0
-            for req in can_run_list:
-                total_queue_latency += req.queue_time_end - req.queue_time_start
-            self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
+            # Retract
+            self.stats.num_retracted_reqs = self.num_retracted_reqs
+            self.stats.num_paused_reqs = self.num_paused_reqs
+            self.num_retracted_reqs = self.num_paused_reqs = 0
 
+            # PD disaggregation
             if self.disaggregation_mode == DisaggregationMode.PREFILL:
                 self.stats.num_prefill_prealloc_queue_reqs = len(
                     self.disagg_prefill_bootstrap_queue.queue
@@ -153,7 +154,18 @@
                 self.stats.num_prefill_inflight_queue_reqs = len(
                     self.disagg_prefill_inflight_queue
                 )
+                self.stats.kv_transfer_speed_gb_s = self.kv_transfer_speed_gb_s
+                self.stats.kv_transfer_latency_ms = self.kv_transfer_latency_ms
+            elif self.disaggregation_mode == DisaggregationMode.DECODE:
+                self.stats.num_decode_prealloc_queue_reqs = len(
+                    self.disagg_decode_prealloc_queue.queue
+                )
+                self.stats.num_decode_transfer_queue_reqs = len(
+                    self.disagg_decode_transfer_queue.queue
+                )
 
+            # Others
+            self.calculate_utilization()
             self.metrics_collector.log_stats(self.stats)
             self._emit_kv_metrics()
             self._publish_kv_events()
@@ -166,8 +178,12 @@
         gap_latency = time.perf_counter() - self.last_decode_stats_tic
         self.last_decode_stats_tic = time.perf_counter()
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
+
         self.num_generated_tokens = 0
         num_running_reqs = len(batch.reqs)
+        num_running_reqs_offline_batch = 0
+
+        # TODO: generalize this for various memory pools
         if self.is_hybrid:
             (
                 full_num_used,
@@ -181,7 +197,7 @@
             ) = self._get_swa_token_info()
             num_used = max(full_num_used, swa_num_used)
             token_usage = max(full_token_usage, swa_token_usage)
-            token_msg = (
+            token_usage_msg = (
                 f"#full token: {full_num_used}, "
                 f"full token usage: {full_token_usage:.2f}, "
                 f"#swa token: {swa_num_used}, "
@@ -189,14 +205,14 @@
             )
         else:
             num_used, token_usage, _, _ = self._get_token_info()
-            token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
+            token_usage_msg = f"#token: {num_used}, token usage: {token_usage:.2f}, "
 
         if RECORD_STEP_TIME:
             self.step_time_dict[num_running_reqs].append(
                 gap_latency / self.server_args.decode_log_interval
             )
 
-        msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
+        msg = f"Decode batch. #running-req: {num_running_reqs}, {token_usage_msg}"
 
         if self.spec_algorithm.is_none():
             spec_accept_length = 0
@@ -208,40 +224,66 @@
             self.cum_spec_accept_count += self.spec_num_total_forward_ct
             self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
             msg += f"accept len: {spec_accept_length:.2f}, "
+        cache_hit_rate = 0.0
 
         if self.disaggregation_mode == DisaggregationMode.DECODE:
             msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
+            msg += f"#prealloc-req: {len(self.disagg_decode_prealloc_queue.queue)}, "
+            msg += f"#transfer-req: {len(self.disagg_decode_transfer_queue.queue)}, "
             msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
 
         msg += (
-            f"cuda graph: {can_run_cuda_graph}, "
+            f"{'cuda graph' if self.device == 'cuda' else 'cpu graph'}: {can_run_cuda_graph}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
             f"#queue-req: {len(self.waiting_queue)}, "
         )
 
         logger.info(msg)
         if self.enable_metrics:
+            # Basics
             self.stats.num_running_reqs = num_running_reqs
+            self.stats.num_running_reqs_offline_batch = num_running_reqs_offline_batch
             self.stats.num_used_tokens = num_used
-            self.stats.token_usage = round(token_usage, 2)
-            self.stats.cache_hit_rate = 0.0
+            self.stats.token_usage = token_usage
+            if self.is_hybrid:
+                self.stats.swa_token_usage = swa_token_usage
             self.stats.gen_throughput = self.last_gen_throughput
             self.stats.num_queue_reqs = len(self.waiting_queue)
             self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
+            self.stats.cache_hit_rate = cache_hit_rate
             self.stats.spec_accept_length = spec_accept_length
-            self.stats.total_retracted_reqs = self.total_retracted_reqs
-            self.metrics_collector.log_stats(self.stats)
-            if self.disaggregation_mode == DisaggregationMode.DECODE:
+
+            # Retract
+            self.stats.num_retracted_reqs = self.num_retracted_reqs
+            self.stats.num_paused_reqs = self.num_paused_reqs
+            self.num_retracted_reqs = self.num_paused_reqs = 0
+
+            # PD disaggregation
+            if self.disaggregation_mode == DisaggregationMode.PREFILL:
+                self.stats.num_prefill_prealloc_queue_reqs = len(
+                    self.disagg_prefill_bootstrap_queue.queue
+                )
+                self.stats.num_prefill_inflight_queue_reqs = len(
+                    self.disagg_prefill_inflight_queue
+                )
+            elif self.disaggregation_mode == DisaggregationMode.DECODE:
                 self.stats.num_decode_prealloc_queue_reqs = len(
                     self.disagg_decode_prealloc_queue.queue
                 )
                 self.stats.num_decode_transfer_queue_reqs = len(
                     self.disagg_decode_transfer_queue.queue
                 )
+
+            # Others
+            self.calculate_utilization()
+            self.metrics_collector.log_stats(self.stats)
             self._emit_kv_metrics()
             self._publish_kv_events()
 
     def _emit_kv_metrics(self: Scheduler):
+        if not self.enable_kv_cache_events:
+            return
+
         kv_metrics = KvMetrics()
         kv_metrics.request_active_slots = self.stats.num_running_reqs
         kv_metrics.request_total_slots = self.max_running_requests
@@ -258,93 +300,24 @@
         self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
 
     def _publish_kv_events(self: Scheduler):
-        if self.enable_kv_cache_events:
-            events = self.tree_cache.take_events()
-            if events:
-                batch = KVEventBatch(ts=time.time(), events=events)
-                self.kv_event_publisher.publish(batch)
-
-    def maybe_update_dp_balance_data(
-        self: Scheduler, recv_req: TokenizedGenerateReqInput
-    ):
-        if (
-            self.server_args.enable_dp_attention
-            and self.server_args.load_balance_method == "minimum_tokens"
-        ):
-            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
-
-    def maybe_handle_dp_balance_data(self: Scheduler):
-        if (
-            self.server_args.load_balance_method == "minimum_tokens"
-            and self.forward_ct % 40 == 0
-        ):
-            holding_tokens = self.get_load()
-
-            new_recv_dp_balance_id_list, holding_token_list = (
-                self.gather_dp_balance_info(holding_tokens)
-            )
+        if not self.enable_kv_cache_events:
+            return
 
-            self.recv_dp_balance_id_this_term.clear()
-            if self.tp_rank == 0:  # only first worker write info
-                self.write_shared_dp_balance_info(
-                    new_recv_dp_balance_id_list, holding_token_list
-                )
+        events = self.tree_cache.take_events()
+        if events:
+            batch = KVEventBatch(ts=time.time(), events=events)
+            self.kv_event_publisher.publish(batch)
 
-    def gather_dp_balance_info(
-        self: Scheduler, holding_tokens_list
-    ) -> Union[None, List[List[int]]]:
-        """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
-        recv_list = self.recv_dp_balance_id_this_term
-        assert len(recv_list) <= 511, (
-            "The number of requests received this round is too large. "
-            "Please increase gather_tensor_size and onfly_info_size."
-        )
-        # The maximum size of the tensor used for gathering data from all workers.
-        gather_tensor_size = 512
-
-        # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
-        recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
-        recv_tensor[0] = holding_tokens_list
-        recv_tensor[1] = len(recv_list)  # The first element is the length of the list.
-        recv_tensor[2 : len(recv_list) + 2] = torch.tensor(recv_list, dtype=torch.int32)
-
-        if self.tp_rank == 0:
-            gathered_list = [
-                torch.zeros(gather_tensor_size, dtype=torch.int32)
-                for _ in range(self.balance_meta.num_workers)
-            ]
+    def calculate_utilization(self):
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self.stats.utilization = -1
         else:
-            gathered_list = None
-
-        torch.distributed.gather(recv_tensor, gathered_list, group=self.tp_cpu_group)
-
-        gathered_id_list_per_worker = None
-        if self.tp_rank == 0:
-            gathered_id_list_per_worker = []
-            holding_tokens_list = []
-            for tensor in gathered_list:
-                holding_tokens_list.append(tensor[0].item())
-                list_length = tensor[1].item()
-                gathered_id_list_per_worker.append(tensor[2 : list_length + 2].tolist())
-
-        return gathered_id_list_per_worker, holding_tokens_list
-
-    def write_shared_dp_balance_info(self: Scheduler, new_recv_rid_lists, local_tokens):
-        meta = self.balance_meta
-
-        with meta.mutex:
-            onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
-            assert len(new_recv_rid_lists) == len(onfly_list), "num_worker not equal"
-            # 1.Check if the rid received by each worker this round is present in onfly.
-            # If it is, remove the corresponding onfly item.
-            worker_id = 0
-            for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
-                for new_recv_rid in new_recv_rids:
-                    assert (
-                        new_recv_rid in on_fly_reqs
-                    ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
-                    del on_fly_reqs[new_recv_rid]
-                worker_id += 1
-            # 2. Atomically write local_tokens and onfly into shm under the mutex
-            meta.set_shared_onfly_info(onfly_list)
-            meta.set_shared_local_tokens(local_tokens)
+            if (
+                self.stats.max_running_requests_under_SLO is not None
+                and self.stats.max_running_requests_under_SLO > 0
+            ):
+                self.stats.utilization = max(
+                    self.stats.num_running_reqs
+                    / self.stats.max_running_requests_under_SLO,
+                    self.stats.token_usage / 0.9,
+                )
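The new `calculate_utilization` reports -1 on prefill workers (utilization is not meaningful there) and otherwise takes the maximum of two pressure signals: running requests against the SLO-derived request cap, and KV token usage normalized by a 0.9 target. A worked example with illustrative numbers:

```python
# Worked example of the utilization formula above; all values are made up.
num_running_reqs = 48
max_running_requests_under_SLO = 64
token_usage = 0.72

utilization = max(
    num_running_reqs / max_running_requests_under_SLO,  # 0.75: request pressure
    token_usage / 0.9,                                  # 0.80: KV-token pressure
)
print(utilization)  # 0.8
```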
@@ -5,9 +5,15 @@ import threading
 import time
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
+import torch
+
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.io_struct import AbortReq, BatchEmbeddingOut, BatchTokenIDOut
+from sglang.srt.managers.io_struct import (
+    AbortReq,
+    BatchEmbeddingOutput,
+    BatchTokenIDOutput,
+)
 from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
 
 if TYPE_CHECKING:
@@ -71,6 +77,7 @@ class SchedulerOutputProcessorMixin:
 
         # Check finish conditions
         logprob_pt = 0
+
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             if req.is_retracted:
                 continue
@@ -88,7 +95,7 @@
 
                 if req.finished():
                     self.tree_cache.cache_finished_req(req)
-                    req.time_stats.completion_time = time.time()
+                    req.time_stats.completion_time = time.perf_counter()
                 elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                     # This updates radix so others can match
                     self.tree_cache.cache_unfinished_req(req)
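`completion_time` switches from `time.time()` to `time.perf_counter()` here (and in the decode path below). `perf_counter` is a monotonic, high-resolution clock, so durations computed from it cannot go negative or jump when the wall clock is adjusted (e.g., by NTP); its values are only meaningful as differences. A minimal illustration:

```python
import time

t0 = time.perf_counter()
time.sleep(0.01)
elapsed = time.perf_counter() - t0  # always >= 0, immune to wall-clock changes
print(f"{elapsed:.4f}s")
```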
@@ -99,6 +106,7 @@
                 extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                 extend_input_len = extend_input_len_per_req[i]
                 num_input_logprobs = extend_input_len - extend_logprob_start_len
+
                 if req.return_logprob:
                     self.add_logprob_return_values(
                         i,
@@ -136,7 +144,7 @@
                         logger.error(
                             f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
                         )
-                        self.abort_request(AbortReq(req.rid))
+                        self.abort_request(AbortReq(rid=req.rid))
                     req.grammar.finished = req.finished()
             else:
                 # being chunked reqs' prefill is not finished
@@ -169,8 +177,7 @@
             self.set_next_batch_sampling_info_done(batch)
 
         else:  # embedding or reward model
-            embeddings, bid = result.embeddings, result.bid
-            embeddings = embeddings.tolist()
+            embeddings = result.embeddings.tolist()
 
             # Check finish conditions
             for i, req in enumerate(batch.reqs):
@@ -246,8 +253,14 @@
 
             req.check_finished()
             if req.finished():
-                self.tree_cache.cache_finished_req(req)
-                req.time_stats.completion_time = time.time()
+                if self.server_args.disaggregation_decode_enable_offload_kvcache:
+                    # Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes
+                    if not self.decode_offload_manager.offload_kv_cache(req):
+                        self.tree_cache.cache_finished_req(req)
+                else:
+                    self.tree_cache.cache_finished_req(req)
+
+                req.time_stats.completion_time = time.perf_counter()
 
             if req.return_logprob and batch.spec_algorithm.is_none():
                 # speculative worker handles logprob in speculative decoding
@@ -283,7 +296,7 @@
                         logger.error(
                             f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
                         )
-                        self.abort_request(AbortReq(req.rid))
+                        self.abort_request(AbortReq(rid=req.rid))
                     req.grammar.finished = req.finished()
 
             self.set_next_batch_sampling_info_done(batch)
@@ -441,27 +454,59 @@
         output: LogitsProcessorOutput,
     ):
         """Attach logprobs to the return values."""
-        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
-        req.output_token_logprobs_idx.append(next_token_ids[i])
-
-        self.add_input_logprob_return_values(
-            i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
-        )
+        if output.next_token_logprobs is not None:
+            req.output_token_logprobs_val.append(output.next_token_logprobs[i])
+            req.output_token_logprobs_idx.append(next_token_ids[i])
+
+        # Only add input logprobs if there are input tokens to process
+        # Note: For prefill-only requests with default logprob_start_len, this will be 0,
+        # meaning we only compute output logprobs (which is the intended behavior)
+        if num_input_logprobs > 0:
+            self.add_input_logprob_return_values(
+                i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
+            )
+        else:
+            self._initialize_empty_logprob_containers(req)
 
         if req.top_logprobs_num > 0:
             req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
             req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
 
-        if req.token_ids_logprob is not None:
-            req.output_token_ids_logprobs_val.append(
-                output.next_token_token_ids_logprobs_val[i]
-            )
+        if (
+            req.token_ids_logprob is not None
+            and output.next_token_token_ids_logprobs_val is not None
+        ):
+            # Convert GPU tensor to list if needed
+            logprobs_val = output.next_token_token_ids_logprobs_val[i]
+            if isinstance(logprobs_val, torch.Tensor):
+                logprobs_val = logprobs_val.tolist()
+            req.output_token_ids_logprobs_val.append(logprobs_val)
             req.output_token_ids_logprobs_idx.append(
                 output.next_token_token_ids_logprobs_idx[i]
            )
 
         return num_input_logprobs
 
+    def _initialize_empty_logprob_containers(self, req: Req) -> None:
+        """
+        Initialize logprob fields to empty lists if unset.
+
+        This is needed for prefill-only requests where the normal initialization
+        flow might be bypassed, but downstream code expects these fields to be lists.
+        """
+        if req.input_token_logprobs_val is None:
+            req.input_token_logprobs_val = []
+        if req.input_token_logprobs_idx is None:
+            req.input_token_logprobs_idx = []
+        if req.input_top_logprobs_val is None:
+            req.input_top_logprobs_val = []
+        if req.input_top_logprobs_idx is None:
+            req.input_top_logprobs_idx = []
+        if req.input_token_ids_logprobs_val is None:
+            req.input_token_ids_logprobs_val = []
+        if req.input_token_ids_logprobs_idx is None:
+            req.input_token_ids_logprobs_idx = []
+
     def stream_output(
         self: Scheduler,
         reqs: List[Req],
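The token-ids logprob path above now tolerates `next_token_token_ids_logprobs_val` being absent and normalizes per-request entries from GPU tensors to plain lists before appending them. The conversion pattern in isolation; `to_plain_list` is a hypothetical name, the guard mirrors the diff:

```python
import torch

def to_plain_list(logprobs_val):
    # Tensors become plain Python lists so downstream serialization
    # (e.g. pickling to the detokenizer) never sees torch objects.
    if isinstance(logprobs_val, torch.Tensor):
        return logprobs_val.tolist()
    return logprobs_val

print(to_plain_list(torch.tensor([-0.1, -2.3])))  # plain floats (float32 precision)
print(to_plain_list([-0.1, -2.3]))                # lists pass through unchanged
```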
@@ -673,8 +718,7 @@
                 return
 
         self.send_to_detokenizer.send_pyobj(
-            BatchTokenIDOut(
-                rids,
+            BatchTokenIDOutput(
                 finished_reasons,
                 decoded_texts,
                 decode_ids_list,
@@ -700,6 +744,9 @@
                 output_token_ids_logprobs_val,
                 output_token_ids_logprobs_idx,
                 output_hidden_states,
+                rids=rids,
+                placeholder_tokens_idx=None,
+                placeholder_tokens_val=None,
             )
         )
 
@@ -718,7 +765,13 @@
             prompt_tokens.append(len(req.origin_input_ids))
             cached_tokens.append(req.cached_tokens)
         self.send_to_detokenizer.send_pyobj(
-            BatchEmbeddingOut(
-                rids, finished_reasons, embeddings, prompt_tokens, cached_tokens
+            BatchEmbeddingOutput(
+                finished_reasons,
+                embeddings,
+                prompt_tokens,
+                cached_tokens,
+                rids=rids,
+                placeholder_tokens_idx=None,
+                placeholder_tokens_val=None,
             )
         )
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
 
 class SchedulerProfilerMixin:
 
-    def init_profier(self):
+    def init_profiler(self):
         self.torch_profiler = None
         self.torch_profiler_output_dir: Optional[str] = None
         self.profiler_activities: Optional[List[str]] = None
@@ -97,7 +97,7 @@
     def start_profile(
         self, stage: Optional[ForwardMode] = None
     ) -> ProfileReqOutput | None:
-        stage_str = f" for {stage.__str__()}" if stage else ""
+        stage_str = f" for {stage.name}" if stage else ""
         logger.info(
             f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
         )
@@ -181,7 +181,7 @@
         if not Path(self.torch_profiler_output_dir).exists():
             Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
 
-        stage_suffix = f"-{stage.__str__()}" if stage else ""
+        stage_suffix = f"-{stage.name}" if stage else ""
         logger.info("Stop profiling" + stage_suffix + "...")
         if self.torch_profiler is not None:
             self.torch_profiler.stop()
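Both profiler hunks replace `stage.__str__()` with `stage.name`. For a Python `Enum`, `str()` includes the class name, which produced noisy suffixes like `-ForwardMode.EXTEND`; `.name` yields just the member name. A quick illustration with a stand-in enum (not sglang's actual `ForwardMode` definition):

```python
from enum import Enum, auto

class ForwardMode(Enum):  # stand-in for sglang's ForwardMode
    EXTEND = auto()
    DECODE = auto()

print(str(ForwardMode.EXTEND))  # "ForwardMode.EXTEND"
print(ForwardMode.EXTEND.name)  # "EXTEND"
```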
@@ -204,7 +204,7 @@
 
         torch.distributed.barrier(self.tp_cpu_group)
         if self.tp_rank == 0:
-            from sglang.srt.utils import rpd_to_chrome_trace
+            from sglang.srt.utils.rpd_utils import rpd_to_chrome_trace
 
             rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
         self.rpd_profiler = None
@@ -247,7 +247,7 @@
         if self.profiler_decode_ct == 0:
             if self.profile_in_progress:
                 # force trace flush
-                self.stop_profile(ForwardMode.EXTEND)
+                self.stop_profile(stage=ForwardMode.EXTEND)
             self.start_profile(batch.forward_mode)
         self.profiler_decode_ct += 1
         if self.profiler_decode_ct > self.profiler_target_decode_ct:
@@ -294,6 +294,6 @@
                 recv_req.profile_by_stage,
                 recv_req.profile_id,
             )
-            return self.start_profile(True)
+            return self.start_profile()
         else:
             return self.stop_profile()
@@ -5,6 +5,8 @@ import torch
 
 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.managers.io_struct import (
+    DestroyWeightsUpdateGroupReqInput,
+    DestroyWeightsUpdateGroupReqOutput,
     GetWeightsByNameReqInput,
     GetWeightsByNameReqOutput,
     InitWeightsUpdateGroupReqInput,
@@ -41,6 +43,11 @@
         success, message = self.tp_worker.init_weights_update_group(recv_req)
         return InitWeightsUpdateGroupReqOutput(success, message)
 
+    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
+        """Destroy the online model parameter update group."""
+        success, message = self.tp_worker.destroy_weights_update_group(recv_req)
+        return DestroyWeightsUpdateGroupReqOutput(success, message)
+
     def update_weights_from_distributed(
         self,
         recv_req: UpdateWeightsFromDistributedReqInput,
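The new `destroy_weights_update_group` handler is the teardown counterpart of `init_weights_update_group`, so trainers that create a process group for online weight updates can release it instead of leaking it. A hypothetical stand-in showing the lifecycle; the real `tp_worker` methods take the typed request objects from this diff and return a `(success, message)` pair:

```python
class DummyTpWorker:
    """Hypothetical stand-in for the tp_worker side of the mixin."""

    def init_weights_update_group(self, recv_req):
        return True, "weights update group created"

    def destroy_weights_update_group(self, recv_req):
        return True, "weights update group destroyed"

worker = DummyTpWorker()
ok, msg = worker.init_weights_update_group(recv_req=None)
assert ok, msg
# ... trainer pushes updated weights through the group ...
ok, msg = worker.destroy_weights_update_group(recv_req=None)
assert ok, msg
```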