sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -5,9 +5,15 @@ import threading
5
5
  import time
6
6
  from typing import TYPE_CHECKING, List, Optional, Tuple, Union
7
7
 
8
+ import torch
9
+
8
10
  from sglang.srt.disaggregation.utils import DisaggregationMode
9
11
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
10
- from sglang.srt.managers.io_struct import AbortReq, BatchEmbeddingOut, BatchTokenIDOut
12
+ from sglang.srt.managers.io_struct import (
13
+ AbortReq,
14
+ BatchEmbeddingOutput,
15
+ BatchTokenIDOutput,
16
+ )
11
17
  from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
12
18
 
13
19
  if TYPE_CHECKING:
@@ -33,7 +39,6 @@ class SchedulerOutputProcessorMixin:
33
39
  self: Scheduler,
34
40
  batch: ScheduleBatch,
35
41
  result: Union[GenerationBatchResult, EmbeddingBatchResult],
36
- launch_done: Optional[threading.Event] = None,
37
42
  ):
38
43
  skip_stream_req = None
39
44
 
@@ -43,34 +48,35 @@ class SchedulerOutputProcessorMixin:
43
48
  next_token_ids,
44
49
  extend_input_len_per_req,
45
50
  extend_logprob_start_len_per_req,
51
+ copy_done,
46
52
  ) = (
47
53
  result.logits_output,
48
54
  result.next_token_ids,
49
55
  result.extend_input_len_per_req,
50
56
  result.extend_logprob_start_len_per_req,
57
+ result.copy_done,
51
58
  )
52
59
 
53
- if self.enable_overlap:
54
- logits_output, next_token_ids, _ = (
55
- self.tp_worker.resolve_last_batch_result(launch_done)
56
- )
57
- else:
58
- # Move next_token_ids and logprobs to cpu
59
- next_token_ids = next_token_ids.tolist()
60
- if batch.return_logprob:
61
- if logits_output.next_token_logprobs is not None:
62
- logits_output.next_token_logprobs = (
63
- logits_output.next_token_logprobs.tolist()
64
- )
65
- if logits_output.input_token_logprobs is not None:
66
- logits_output.input_token_logprobs = tuple(
67
- logits_output.input_token_logprobs.tolist()
68
- )
60
+ if copy_done is not None:
61
+ copy_done.synchronize()
62
+
63
+ # Move next_token_ids and logprobs to cpu
64
+ next_token_ids = next_token_ids.tolist()
65
+ if batch.return_logprob:
66
+ if logits_output.next_token_logprobs is not None:
67
+ logits_output.next_token_logprobs = (
68
+ logits_output.next_token_logprobs.tolist()
69
+ )
70
+ if logits_output.input_token_logprobs is not None:
71
+ logits_output.input_token_logprobs = tuple(
72
+ logits_output.input_token_logprobs.tolist()
73
+ )
69
74
 
70
75
  hidden_state_offset = 0
71
76
 
72
77
  # Check finish conditions
73
78
  logprob_pt = 0
79
+
74
80
  for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
75
81
  if req.is_retracted:
76
82
  continue
@@ -88,7 +94,7 @@ class SchedulerOutputProcessorMixin:
88
94
 
89
95
  if req.finished():
90
96
  self.tree_cache.cache_finished_req(req)
91
- req.time_stats.completion_time = time.time()
97
+ req.time_stats.completion_time = time.perf_counter()
92
98
  elif not batch.decoding_reqs or req not in batch.decoding_reqs:
93
99
  # This updates radix so others can match
94
100
  self.tree_cache.cache_unfinished_req(req)
@@ -98,7 +104,11 @@ class SchedulerOutputProcessorMixin:
98
104
  assert extend_input_len_per_req is not None
99
105
  extend_logprob_start_len = extend_logprob_start_len_per_req[i]
100
106
  extend_input_len = extend_input_len_per_req[i]
101
- num_input_logprobs = extend_input_len - extend_logprob_start_len
107
+
108
+ num_input_logprobs = self._calculate_num_input_logprobs(
109
+ req, extend_input_len, extend_logprob_start_len
110
+ )
111
+
102
112
  if req.return_logprob:
103
113
  self.add_logprob_return_values(
104
114
  i,
@@ -136,7 +146,7 @@ class SchedulerOutputProcessorMixin:
136
146
  logger.error(
137
147
  f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
138
148
  )
139
- self.abort_request(AbortReq(req.rid))
149
+ self.abort_request(AbortReq(rid=req.rid))
140
150
  req.grammar.finished = req.finished()
141
151
  else:
142
152
  # being chunked reqs' prefill is not finished
@@ -152,8 +162,8 @@ class SchedulerOutputProcessorMixin:
152
162
  extend_input_len = extend_input_len_per_req[i]
153
163
  if extend_logprob_start_len < extend_input_len:
154
164
  # Update input logprobs.
155
- num_input_logprobs = (
156
- extend_input_len - extend_logprob_start_len
165
+ num_input_logprobs = self._calculate_num_input_logprobs(
166
+ req, extend_input_len, extend_logprob_start_len
157
167
  )
158
168
  if req.return_logprob:
159
169
  self.add_input_logprob_return_values(
@@ -166,11 +176,8 @@ class SchedulerOutputProcessorMixin:
166
176
  )
167
177
  logprob_pt += num_input_logprobs
168
178
 
169
- self.set_next_batch_sampling_info_done(batch)
170
-
171
179
  else: # embedding or reward model
172
- embeddings, bid = result.embeddings, result.bid
173
- embeddings = embeddings.tolist()
180
+ embeddings = result.embeddings.tolist()
174
181
 
175
182
  # Check finish conditions
176
183
  for i, req in enumerate(batch.reqs):
@@ -197,22 +204,19 @@ class SchedulerOutputProcessorMixin:
197
204
  self: Scheduler,
198
205
  batch: ScheduleBatch,
199
206
  result: GenerationBatchResult,
200
- launch_done: Optional[threading.Event] = None,
201
207
  ):
202
- logits_output, next_token_ids, can_run_cuda_graph = (
208
+ logits_output, next_token_ids, can_run_cuda_graph, copy_done = (
203
209
  result.logits_output,
204
210
  result.next_token_ids,
205
211
  result.can_run_cuda_graph,
212
+ result.copy_done,
206
213
  )
207
214
  self.num_generated_tokens += len(batch.reqs)
208
215
 
209
- if self.enable_overlap:
210
- logits_output, next_token_ids, can_run_cuda_graph = (
211
- self.tp_worker.resolve_last_batch_result(launch_done)
212
- )
213
- next_token_logprobs = logits_output.next_token_logprobs
214
- elif batch.spec_algorithm.is_none():
215
- # spec decoding handles output logprobs inside verify process.
216
+ if copy_done is not None:
217
+ copy_done.synchronize()
218
+
219
+ if batch.spec_algorithm.is_none():
216
220
  next_token_ids = next_token_ids.tolist()
217
221
  if batch.return_logprob:
218
222
  next_token_logprobs = logits_output.next_token_logprobs.tolist()
@@ -246,8 +250,14 @@ class SchedulerOutputProcessorMixin:
246
250
 
247
251
  req.check_finished()
248
252
  if req.finished():
249
- self.tree_cache.cache_finished_req(req)
250
- req.time_stats.completion_time = time.time()
253
+ if self.server_args.disaggregation_decode_enable_offload_kvcache:
254
+ # Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes
255
+ if not self.decode_offload_manager.offload_kv_cache(req):
256
+ self.tree_cache.cache_finished_req(req)
257
+ else:
258
+ self.tree_cache.cache_finished_req(req)
259
+
260
+ req.time_stats.completion_time = time.perf_counter()
251
261
 
252
262
  if req.return_logprob and batch.spec_algorithm.is_none():
253
263
  # speculative worker handles logprob in speculative decoding
@@ -283,10 +293,9 @@ class SchedulerOutputProcessorMixin:
283
293
  logger.error(
284
294
  f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
285
295
  )
286
- self.abort_request(AbortReq(req.rid))
296
+ self.abort_request(AbortReq(rid=req.rid))
287
297
  req.grammar.finished = req.finished()
288
298
 
289
- self.set_next_batch_sampling_info_done(batch)
290
299
  self.stream_output(batch.reqs, batch.return_logprob)
291
300
  self.token_to_kv_pool_allocator.free_group_end()
292
301
 
@@ -297,6 +306,153 @@ class SchedulerOutputProcessorMixin:
297
306
  ):
298
307
  self.log_decode_stats(can_run_cuda_graph, running_batch=batch)
299
308
 
309
def _process_input_token_logprobs(
    self, req: Req, input_token_logprobs: List
) -> None:
    """Store the input-token logprob values and their token ids on ``req``.

    Sets ``req.input_token_logprobs_val`` and ``req.input_token_logprobs_idx``.

    * Multi-item scoring: the logprobs are stored unchanged and only the
      delimiter tokens found from ``req.logprob_start_len`` onwards are kept
      as indices.
    * Regular request: a leading ``None`` is added, the final logprob (the
      one belonging to the sampling token) is dropped, and every token from
      ``req.logprob_start_len`` onwards is kept as an index.
    """
    if self._is_multi_item_scoring(req):
        # The caller's list is stored as-is (no copy is made).
        req.input_token_logprobs_val = input_token_logprobs
        delimiter = self.server_args.multi_item_scoring_delimiter
        token_ids = [
            tok
            for tok in req.origin_input_ids[req.logprob_start_len :]
            if tok == delimiter
        ]
    else:
        req.input_token_logprobs_val = [None] + input_token_logprobs[:-1]
        token_ids = req.origin_input_ids[req.logprob_start_len :]

    # Padded hash values from image tokens are replaced by 0 so that
    # detokenization does not fail on them.
    req.input_token_logprobs_idx = [
        0 if tok >= self.model_config.vocab_size - 1 else tok for tok in token_ids
    ]
341
+
342
+ def _process_input_top_logprobs(self, req: Req) -> None:
343
+ """Process input top logprobs."""
344
+ if req.top_logprobs_num <= 0:
345
+ return
346
+
347
+ is_multi_item_scoring = self._is_multi_item_scoring(req)
348
+
349
+ # Initialize arrays - multi-item scoring starts empty, others start with None
350
+ req.input_top_logprobs_val = [] if is_multi_item_scoring else [None]
351
+ req.input_top_logprobs_idx = [] if is_multi_item_scoring else [None]
352
+
353
+ # Extend arrays with temp values
354
+ for val, idx in zip(
355
+ req.temp_input_top_logprobs_val,
356
+ req.temp_input_top_logprobs_idx,
357
+ strict=True,
358
+ ):
359
+ req.input_top_logprobs_val.extend(val)
360
+ req.input_top_logprobs_idx.extend(idx)
361
+
362
+ # Remove last token (sampling token) for non multi-item scoring requests
363
+ if not is_multi_item_scoring:
364
+ req.input_top_logprobs_val.pop()
365
+ req.input_top_logprobs_idx.pop()
366
+
367
+ # Clean up temp storage
368
+ req.temp_input_top_logprobs_idx = None
369
+ req.temp_input_top_logprobs_val = None
370
+
371
+ def _process_input_token_ids_logprobs(self, req: Req) -> None:
372
+ """Process input token IDs logprobs."""
373
+ if req.token_ids_logprob is None:
374
+ return
375
+
376
+ is_multi_item_scoring = self._is_multi_item_scoring(req)
377
+
378
+ # Initialize arrays - multi-item scoring starts empty, others start with None
379
+ req.input_token_ids_logprobs_val = [] if is_multi_item_scoring else [None]
380
+ req.input_token_ids_logprobs_idx = [] if is_multi_item_scoring else [None]
381
+
382
+ # Process temp values - convert tensors to lists and extend arrays
383
+ for val, idx in zip(
384
+ req.temp_input_token_ids_logprobs_val,
385
+ req.temp_input_token_ids_logprobs_idx,
386
+ strict=True,
387
+ ):
388
+ val_list = val.tolist() if isinstance(val, torch.Tensor) else val
389
+ req.input_token_ids_logprobs_val.extend(
390
+ val_list if isinstance(val_list, list) else [val_list]
391
+ )
392
+ req.input_token_ids_logprobs_idx.extend(idx)
393
+
394
+ # Remove last token (sampling token) for non multi-item scoring requests
395
+ if not is_multi_item_scoring:
396
+ req.input_token_ids_logprobs_val.pop()
397
+ req.input_token_ids_logprobs_idx.pop()
398
+
399
+ # Clean up temp storage
400
+ req.temp_input_token_ids_logprobs_idx = None
401
+ req.temp_input_token_ids_logprobs_val = None
402
+
403
def _calculate_relevant_tokens_len(self, req: Req) -> int:
    """Return the expected length of the input logprob arrays for ``req``.

    Multi-item scoring counts only the delimiter tokens found from
    ``req.logprob_start_len`` onwards; a regular request counts every token
    from ``req.logprob_start_len`` onwards.
    """
    if not self._is_multi_item_scoring(req):
        return len(req.origin_input_ids) - req.logprob_start_len

    delimiter = self.server_args.multi_item_scoring_delimiter
    delimiter_hits = [
        tok for tok in req.origin_input_ids[req.logprob_start_len :] if tok == delimiter
    ]
    return len(delimiter_hits)
422
+
423
def _calculate_num_input_logprobs(
    self, req: Req, extend_input_len: int, extend_logprob_start_len: int
) -> int:
    """Return how many input logprobs the given extend range produces.

    Multi-item scoring counts only the delimiter tokens inside
    ``req.origin_input_ids[extend_logprob_start_len:extend_input_len]``;
    a regular request counts every position in that range.
    """
    if not self._is_multi_item_scoring(req):
        return extend_input_len - extend_logprob_start_len

    # NOTE(review): the two extend_* arguments are used directly as slice
    # bounds into origin_input_ids; confirm that callers pass offsets that
    # are valid for the full prompt.
    delimiter = self.server_args.multi_item_scoring_delimiter
    window = req.origin_input_ids[extend_logprob_start_len:extend_input_len]
    return len([tok for tok in window if tok == delimiter])
446
+
447
def _is_multi_item_scoring(self, req: Req) -> bool:
    """Check if request uses multi-item scoring.

    Multi-item scoring applies to prefill-only requests when a delimiter
    token is configured. In this mode, only positions containing the
    delimiter token receive logprobs.

    Returns:
        ``True`` when ``req.is_prefill_only`` and
        ``self.server_args.multi_item_scoring_delimiter`` are both truthy,
        ``False`` otherwise.
    """
    # A bare `a and b` returns one of its operands (None or the delimiter
    # int here); coerce so the value matches the declared bool return type.
    # NOTE(review): truthiness is kept as-is, so a delimiter token id of 0
    # is treated as "not configured" - confirm that this is intended.
    return bool(
        req.is_prefill_only and self.server_args.multi_item_scoring_delimiter
    )
455
+
300
456
  def add_input_logprob_return_values(
301
457
  self: Scheduler,
302
458
  i: int,
@@ -365,63 +521,14 @@ class SchedulerOutputProcessorMixin:
365
521
  assert req.input_top_logprobs_val is None
366
522
  assert req.input_top_logprobs_idx is None
367
523
 
368
- # Compute input_token_logprobs_val
369
- # Always pad the first one with None.
370
- req.input_token_logprobs_val = [None]
371
- req.input_token_logprobs_val.extend(input_token_logprobs)
372
- # The last input logprob is for sampling, so just pop it out.
373
- req.input_token_logprobs_val.pop()
374
-
375
- # Compute input_token_logprobs_idx
376
- input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
377
- # Clip the padded hash values from image tokens.
378
- # Otherwise, it will lead to detokenization errors.
379
- input_token_logprobs_idx = [
380
- x if x < self.model_config.vocab_size - 1 else 0
381
- for x in input_token_logprobs_idx
382
- ]
383
- req.input_token_logprobs_idx = input_token_logprobs_idx
524
+ # Process all input logprob types using helper functions
525
+ self._process_input_token_logprobs(req, input_token_logprobs)
526
+ self._process_input_top_logprobs(req)
384
527
 
385
- if req.top_logprobs_num > 0:
386
- req.input_top_logprobs_val = [None]
387
- req.input_top_logprobs_idx = [None]
388
- assert len(req.temp_input_token_ids_logprobs_val) == len(
389
- req.temp_input_token_ids_logprobs_idx
390
- )
391
- for val, idx in zip(
392
- req.temp_input_top_logprobs_val,
393
- req.temp_input_top_logprobs_idx,
394
- strict=True,
395
- ):
396
- req.input_top_logprobs_val.extend(val)
397
- req.input_top_logprobs_idx.extend(idx)
398
-
399
- # Last token is a sample token.
400
- req.input_top_logprobs_val.pop()
401
- req.input_top_logprobs_idx.pop()
402
- req.temp_input_top_logprobs_idx = None
403
- req.temp_input_top_logprobs_val = None
404
-
405
- if req.token_ids_logprob is not None:
406
- req.input_token_ids_logprobs_val = [None]
407
- req.input_token_ids_logprobs_idx = [None]
408
-
409
- for val, idx in zip(
410
- req.temp_input_token_ids_logprobs_val,
411
- req.temp_input_token_ids_logprobs_idx,
412
- strict=True,
413
- ):
414
- req.input_token_ids_logprobs_val.extend(val)
415
- req.input_token_ids_logprobs_idx.extend(idx)
416
-
417
- # Last token is a sample token.
418
- req.input_token_ids_logprobs_val.pop()
419
- req.input_token_ids_logprobs_idx.pop()
420
- req.temp_input_token_ids_logprobs_idx = None
421
- req.temp_input_token_ids_logprobs_val = None
528
+ self._process_input_token_ids_logprobs(req)
422
529
 
423
530
  if req.return_logprob:
424
- relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
531
+ relevant_tokens_len = self._calculate_relevant_tokens_len(req)
425
532
  assert len(req.input_token_logprobs_val) == relevant_tokens_len
426
533
  assert len(req.input_token_logprobs_idx) == relevant_tokens_len
427
534
  if req.top_logprobs_num > 0:
@@ -441,27 +548,59 @@ class SchedulerOutputProcessorMixin:
441
548
  output: LogitsProcessorOutput,
442
549
  ):
443
550
  """Attach logprobs to the return values."""
444
- req.output_token_logprobs_val.append(output.next_token_logprobs[i])
445
- req.output_token_logprobs_idx.append(next_token_ids[i])
446
-
447
- self.add_input_logprob_return_values(
448
- i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
449
- )
551
+ if output.next_token_logprobs is not None:
552
+ req.output_token_logprobs_val.append(output.next_token_logprobs[i])
553
+ req.output_token_logprobs_idx.append(next_token_ids[i])
554
+
555
+ # Only add input logprobs if there are input tokens to process
556
+ # Note: For prefill-only requests with default logprob_start_len, this will be 0,
557
+ # meaning we only compute output logprobs (which is the intended behavior)
558
+ if num_input_logprobs > 0:
559
+ self.add_input_logprob_return_values(
560
+ i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
561
+ )
562
+ else:
563
+ self._initialize_empty_logprob_containers(req)
450
564
 
451
565
  if req.top_logprobs_num > 0:
452
566
  req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
453
567
  req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
454
568
 
455
- if req.token_ids_logprob is not None:
456
- req.output_token_ids_logprobs_val.append(
457
- output.next_token_token_ids_logprobs_val[i]
458
- )
569
+ if (
570
+ req.token_ids_logprob is not None
571
+ and output.next_token_token_ids_logprobs_val is not None
572
+ ):
573
+ # Convert GPU tensor to list if needed
574
+ logprobs_val = output.next_token_token_ids_logprobs_val[i]
575
+ if isinstance(logprobs_val, torch.Tensor):
576
+ logprobs_val = logprobs_val.tolist()
577
+ req.output_token_ids_logprobs_val.append(logprobs_val)
459
578
  req.output_token_ids_logprobs_idx.append(
460
579
  output.next_token_token_ids_logprobs_idx[i]
461
580
  )
462
581
 
463
582
  return num_input_logprobs
464
583
 
584
def _initialize_empty_logprob_containers(self, req: Req) -> None:
    """
    Replace every unset (``None``) input-logprob field on ``req`` with ``[]``.

    This is needed for prefill-only requests where the normal initialization
    flow might be bypassed, but downstream code expects these fields to be lists.
    Fields that already hold a value are left untouched.
    """
    field_names = (
        "input_token_logprobs_val",
        "input_token_logprobs_idx",
        "input_top_logprobs_val",
        "input_top_logprobs_idx",
        "input_token_ids_logprobs_val",
        "input_token_ids_logprobs_idx",
    )
    for field_name in field_names:
        if getattr(req, field_name) is None:
            # Each field gets its own fresh list.
            setattr(req, field_name, [])
603
+
465
604
  def stream_output(
466
605
  self: Scheduler,
467
606
  reqs: List[Req],
@@ -673,8 +812,7 @@ class SchedulerOutputProcessorMixin:
673
812
  return
674
813
 
675
814
  self.send_to_detokenizer.send_pyobj(
676
- BatchTokenIDOut(
677
- rids,
815
+ BatchTokenIDOutput(
678
816
  finished_reasons,
679
817
  decoded_texts,
680
818
  decode_ids_list,
@@ -700,6 +838,9 @@ class SchedulerOutputProcessorMixin:
700
838
  output_token_ids_logprobs_val,
701
839
  output_token_ids_logprobs_idx,
702
840
  output_hidden_states,
841
+ rids=rids,
842
+ placeholder_tokens_idx=None,
843
+ placeholder_tokens_val=None,
703
844
  )
704
845
  )
705
846
 
@@ -718,7 +859,13 @@ class SchedulerOutputProcessorMixin:
718
859
  prompt_tokens.append(len(req.origin_input_ids))
719
860
  cached_tokens.append(req.cached_tokens)
720
861
  self.send_to_detokenizer.send_pyobj(
721
- BatchEmbeddingOut(
722
- rids, finished_reasons, embeddings, prompt_tokens, cached_tokens
862
+ BatchEmbeddingOutput(
863
+ finished_reasons,
864
+ embeddings,
865
+ prompt_tokens,
866
+ cached_tokens,
867
+ rids=rids,
868
+ placeholder_tokens_idx=None,
869
+ placeholder_tokens_val=None,
723
870
  )
724
871
  )
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
26
26
 
27
27
  class SchedulerProfilerMixin:
28
28
 
29
- def init_profier(self):
29
+ def init_profiler(self):
30
30
  self.torch_profiler = None
31
31
  self.torch_profiler_output_dir: Optional[str] = None
32
32
  self.profiler_activities: Optional[List[str]] = None
@@ -97,7 +97,7 @@ class SchedulerProfilerMixin:
97
97
  def start_profile(
98
98
  self, stage: Optional[ForwardMode] = None
99
99
  ) -> ProfileReqOutput | None:
100
- stage_str = f" for {stage.__str__()}" if stage else ""
100
+ stage_str = f" for {stage.name}" if stage else ""
101
101
  logger.info(
102
102
  f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
103
103
  )
@@ -181,7 +181,7 @@ class SchedulerProfilerMixin:
181
181
  if not Path(self.torch_profiler_output_dir).exists():
182
182
  Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
183
183
 
184
- stage_suffix = f"-{stage.__str__()}" if stage else ""
184
+ stage_suffix = f"-{stage.name}" if stage else ""
185
185
  logger.info("Stop profiling" + stage_suffix + "...")
186
186
  if self.torch_profiler is not None:
187
187
  self.torch_profiler.stop()
@@ -204,7 +204,7 @@ class SchedulerProfilerMixin:
204
204
 
205
205
  torch.distributed.barrier(self.tp_cpu_group)
206
206
  if self.tp_rank == 0:
207
- from sglang.srt.utils import rpd_to_chrome_trace
207
+ from sglang.srt.utils.rpd_utils import rpd_to_chrome_trace
208
208
 
209
209
  rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
210
210
  self.rpd_profiler = None
@@ -247,7 +247,7 @@ class SchedulerProfilerMixin:
247
247
  if self.profiler_decode_ct == 0:
248
248
  if self.profile_in_progress:
249
249
  # force trace flush
250
- self.stop_profile(ForwardMode.EXTEND)
250
+ self.stop_profile(stage=ForwardMode.EXTEND)
251
251
  self.start_profile(batch.forward_mode)
252
252
  self.profiler_decode_ct += 1
253
253
  if self.profiler_decode_ct > self.profiler_target_decode_ct:
@@ -294,6 +294,6 @@ class SchedulerProfilerMixin:
294
294
  recv_req.profile_by_stage,
295
295
  recv_req.profile_id,
296
296
  )
297
- return self.start_profile(True)
297
+ return self.start_profile()
298
298
  else:
299
299
  return self.stop_profile()
@@ -5,6 +5,8 @@ import torch
5
5
 
6
6
  from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
7
7
  from sglang.srt.managers.io_struct import (
8
+ DestroyWeightsUpdateGroupReqInput,
9
+ DestroyWeightsUpdateGroupReqOutput,
8
10
  GetWeightsByNameReqInput,
9
11
  GetWeightsByNameReqOutput,
10
12
  InitWeightsUpdateGroupReqInput,
@@ -41,6 +43,11 @@ class SchedulerUpdateWeightsMixin:
41
43
  success, message = self.tp_worker.init_weights_update_group(recv_req)
42
44
  return InitWeightsUpdateGroupReqOutput(success, message)
43
45
 
46
def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
    """Destroy the online model parameter update group.

    Forwards ``recv_req`` to ``self.tp_worker.destroy_weights_update_group``
    and wraps the returned ``(success, message)`` pair in a
    ``DestroyWeightsUpdateGroupReqOutput``.
    """
    ok, detail = self.tp_worker.destroy_weights_update_group(recv_req)
    return DestroyWeightsUpdateGroupReqOutput(ok, detail)
50
+
44
51
  def update_weights_from_distributed(
45
52
  self,
46
53
  recv_req: UpdateWeightsFromDistributedReqInput,