sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0

sglang/srt/entrypoints/grpc_server.py
@@ -3,149 +3,44 @@ Standalone gRPC Server for SGLang - Fully separated from HTTP server.
  Uses GrpcRequestManager for orchestration without tokenization.
  """

- import argparse
  import asyncio
+ import dataclasses
  import logging
  import multiprocessing as mp
  import os
  import signal
+ import threading
  import time
  from concurrent import futures
- from typing import AsyncIterator, Dict, Optional, Tuple
+ from typing import AsyncIterator, Dict, Optional

  import grpc
+ from google.protobuf.json_format import MessageToDict
+ from google.protobuf.struct_pb2 import Struct
+ from google.protobuf.timestamp_pb2 import Timestamp
+ from grpc_health.v1 import health_pb2_grpc
  from grpc_reflection.v1alpha import reflection

- from sglang.srt.entrypoints.grpc_request_manager import GrpcRequestManager
+ import sglang
+ from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
  from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc
- from sglang.srt.managers.data_parallel_controller import (
-     run_data_parallel_controller_process,
- )
+ from sglang.srt.grpc.grpc_request_manager import GrpcRequestManager
+ from sglang.srt.grpc.health_servicer import SGLangHealthServicer
+ from sglang.srt.grpc.scheduler_launcher import launch_scheduler_process_only
+ from sglang.srt.managers.disagg_service import start_disagg_service
  from sglang.srt.managers.io_struct import (
      TokenizedEmbeddingReqInput,
      TokenizedGenerateReqInput,
  )
- from sglang.srt.managers.scheduler import run_scheduler_process
  from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams
- from sglang.srt.server_args import PortArgs, ServerArgs
- from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
- from sglang.srt.utils import configure_logger, prepare_model_and_tokenizer
+ from sglang.srt.server_args import ServerArgs
+ from sglang.srt.utils import kill_process_tree
  from sglang.utils import get_exception_traceback

  logger = logging.getLogger(__name__)
  HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))


- def _launch_scheduler_process_only(
-     server_args: ServerArgs,
-     port_args: Optional[PortArgs] = None,
- ) -> Tuple[Dict, PortArgs, list]:
-     """
-     Launch only the scheduler process(es) without tokenizer/detokenizer.
-     Returns scheduler info, port args, and list of scheduler processes.
-     """
-     # Configure global environment
-     configure_logger(server_args)
-     server_args.check_server_args()
-
-     # Allocate ports for inter-process communications
-     if port_args is None:
-         port_args = PortArgs.init_new(server_args)
-     logger.info(f"{server_args=}")
-
-     # Prepare model and tokenizer paths
-     server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
-         server_args.model_path, server_args.tokenizer_path
-     )
-
-     scheduler_procs = []
-     if server_args.dp_size == 1:
-         memory_saver_adapter = TorchMemorySaverAdapter.create(
-             enable=server_args.enable_memory_saver
-         )
-         scheduler_pipe_readers = []
-
-         nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
-         tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
-         tp_rank_range = range(
-             tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
-             tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
-         )
-
-         pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
-         pp_rank_range = range(
-             pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
-             pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
-         )
-
-         for pp_rank in pp_rank_range:
-             for tp_rank in tp_rank_range:
-                 reader, writer = mp.Pipe(duplex=False)
-                 gpu_id = (
-                     server_args.base_gpu_id
-                     + ((pp_rank % pp_size_per_node) * tp_size_per_node)
-                     + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
-                 )
-                 moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
-                 proc = mp.Process(
-                     target=run_scheduler_process,
-                     args=(
-                         server_args,
-                         port_args,
-                         gpu_id,
-                         tp_rank,
-                         moe_ep_rank,
-                         pp_rank,
-                         None,
-                         writer,
-                         None,
-                     ),
-                 )
-
-                 with memory_saver_adapter.configure_subprocess():
-                     proc.start()
-                 scheduler_procs.append(proc)
-                 scheduler_pipe_readers.append(reader)
-     else:
-         # Launch the data parallel controller
-         reader, writer = mp.Pipe(duplex=False)
-         scheduler_pipe_readers = [reader]
-         proc = mp.Process(
-             target=run_data_parallel_controller_process,
-             args=(server_args, port_args, writer),
-         )
-         proc.start()
-         scheduler_procs.append(proc)
-
-     # TODO(CatherineSue): handle cases for multi-node
-
-     # Wait for all scheduler processes to be ready
-     scheduler_infos = []
-     for i, reader in enumerate(scheduler_pipe_readers):
-         try:
-             data = reader.recv()
-         except EOFError:
-             logger.error(
-                 f"Rank {i} scheduler is dead. Please check if there are relevant logs."
-             )
-             scheduler_procs[i].join()
-             logger.error(f"Exit code: {scheduler_procs[i].exitcode}")
-             raise RuntimeError(f"Failed to initialize scheduler rank {i}")
-
-         if data.get("status") != "ready":
-             raise RuntimeError(
-                 f"Scheduler rank {i} initialization failed: {data.get('error', 'Unknown error')}"
-             )
-         scheduler_infos.append(data)
-
-     logger.info(
-         f"All {len(scheduler_procs)} scheduler process(es) initialized successfully"
-     )
-
-     # Return the first scheduler's info (they should all be the same)
-     return scheduler_infos[0], port_args, scheduler_procs
-
-
  class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer):
      """
      Standalone gRPC service implementation using GrpcRequestManager.
@@ -157,17 +52,21 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
          request_manager: GrpcRequestManager,
          server_args: ServerArgs,
          model_info: Dict,
+         scheduler_info: Dict,
+         health_servicer: Optional[SGLangHealthServicer] = None,
      ):
          """Initialize the standalone gRPC service."""
          self.request_manager = request_manager
          self.server_args = server_args
          self.model_info = model_info
+         self.scheduler_info = scheduler_info
          self.start_time = time.time()
+         self.health_servicer = health_servicer

          # Start the request manager's event loop using auto_create_handle_loop
          self.request_manager.auto_create_handle_loop()

-         logger.info("Standalone gRPC scheduler service initialized")
+         logger.info("gRPC scheduler servicer initialized")

      async def Generate(
          self,
@@ -175,26 +74,40 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
          context: grpc.aio.ServicerContext,
      ) -> AsyncIterator[sglang_scheduler_pb2.GenerateResponse]:
          """Handle generation requests with streaming responses."""
-         logger.info(f"Generation request: {request.request_id}")
+         logger.info(f"Receive generation request: {request.request_id}")

          try:
              # Convert gRPC request to internal format
              tokenized_req = self._convert_generate_request(request)

-             # Submit to request manager
-             output_queue = await self.request_manager.generate_request(
+             # Submit to request manager (automatically handles n>1)
+             response_generator = self.request_manager.generate_request(
                  obj=tokenized_req,
                  request_id=request.request_id,
                  grpc_context=context,
              )

-             # Stream outputs
-             while True:
-                 try:
-                     # Get output with timeout
-                     output = await asyncio.wait_for(output_queue.get(), timeout=4)
-
-                     # Check for errors
+             async for output in response_generator:
+                 # Handle batch responses (for n>1 non-streaming)
+                 if isinstance(output, list):
+                     for batch_output in output:
+                         if "error" in batch_output:
+                             yield sglang_scheduler_pb2.GenerateResponse(
+                                 request_id=request.request_id,
+                                 error=sglang_scheduler_pb2.GenerateError(
+                                     message=batch_output["error"],
+                                     http_status_code=(
+                                         "500" if "abort" not in batch_output else "499"
+                                     ),
+                                 ),
+                             )
+                         else:
+                             # All non-error batch outputs are final responses
+                             yield self._create_completion_response(
+                                 request.request_id, batch_output
+                             )
+                 else:
+                     # Handle single response (for streaming or n=1 non-streaming)
                      if "error" in output:
                          yield sglang_scheduler_pb2.GenerateResponse(
                              request_id=request.request_id,
@@ -205,29 +118,18 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
                                  ),
                              ),
                          )
-                         break
-
-                     # Check if finished
-                     if output.get("finished", False):
-                         # Send completion
+                     elif output.get("finished", False):
                          yield self._create_completion_response(
                              request.request_id, output
                          )
-                         break
                      else:
-                         # Send chunk
                          yield self._create_chunk_response(request.request_id, output)

-                 except asyncio.TimeoutError:
-                     # Check if context is still active
-                     if context.cancelled():
-                         # Abort the request
-                         await self.request_manager.abort_request(request.request_id)
-                         break
-                     continue
-
          except Exception as e:
-             logger.error(f"Generate failed: {e}\n{get_exception_traceback()}")
+             logger.error(
+                 f"Generate failed for request {request.request_id}: {e}\n"
+                 f"{get_exception_traceback()}"
+             )
              yield sglang_scheduler_pb2.GenerateResponse(
                  request_id=request.request_id,
                  error=sglang_scheduler_pb2.GenerateError(
@@ -240,10 +142,10 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
      async def Embed(
          self,
          request: sglang_scheduler_pb2.EmbedRequest,
-         context: grpc.aio.ServicerContext,
+         _context: grpc.aio.ServicerContext,
      ) -> sglang_scheduler_pb2.EmbedResponse:
          """Handle embedding requests."""
-         logger.info(f"Embedding request: {request.request_id}")
+         logger.info(f"Receive embedding request: {request.request_id}")

          try:
              # Convert request
@@ -266,12 +168,14 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
                      prompt_tokens=result.get("prompt_tokens", 0),
                      cached_tokens=0,
                      embedding_dim=len(result["embedding"]),
-                     generation_time=time.time() - self.start_time,
                  ),
              )

          except Exception as e:
-             logger.error(f"Embed failed: {e}\n{get_exception_traceback()}")
+             logger.error(
+                 f"Embed failed for request {request.request_id}: {e}\n"
+                 f"{get_exception_traceback()}"
+             )
              return sglang_scheduler_pb2.EmbedResponse(
                  request_id=request.request_id,
                  error=sglang_scheduler_pb2.EmbedError(
@@ -286,82 +190,95 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
          request: sglang_scheduler_pb2.HealthCheckRequest,
          context: grpc.aio.ServicerContext,
      ) -> sglang_scheduler_pb2.HealthCheckResponse:
-         """Health check by generating from client input."""
-         try:
-             # Check if request manager is shutting down
-             if self.request_manager.gracefully_exit:
-                 return sglang_scheduler_pb2.HealthCheckResponse(
-                     healthy=False, message="Server shutting down"
-                 )
-
-             # Extract tokenized input from request
-             if not request.HasField("tokenized"):
-                 return sglang_scheduler_pb2.HealthCheckResponse(
-                     healthy=False, message="Tokenized input required for health check"
-                 )
-
-             input_text = request.tokenized.original_text
-             input_ids = list(request.tokenized.input_ids)
+         """
+         Check the health of the inference server by sending a special request to generate one token.
+         Similar to HTTP server's /health endpoint.
+         """
+         rid = f"HEALTH_CHECK_{time.time()}"
+         logger.info(f"Receive health check request: {rid}")
+
+         if self.request_manager.gracefully_exit:
+             logger.info(
+                 "Health check request received during shutdown. Returning unhealthy."
+             )
+             return sglang_scheduler_pb2.HealthCheckResponse(
+                 healthy=False, message="Server is shutting down"
+             )

-             # Create health check request
-             rid = f"HEALTH_CHECK_GRPC_{time.time()}"
+         # Create a special health check request
+         sampling_params = SGLSamplingParams(max_new_tokens=1, temperature=0.0)
+         sampling_params.normalize(tokenizer=None)

-             health_request = TokenizedGenerateReqInput(
+         # Create health check request
+         is_generation = self.scheduler_info.get("is_generation", True)
+         if is_generation:
+             health_req = TokenizedGenerateReqInput(
                  rid=rid,
-                 input_text=input_text,
-                 input_ids=input_ids,
-                 sampling_params=SGLSamplingParams(max_new_tokens=1, temperature=0.0),
-                 stream=False,
-                 mm_inputs=None,
+                 input_text="",
+                 input_ids=[0],
+                 sampling_params=sampling_params,
                  return_logprob=False,
                  logprob_start_len=-1,
                  top_logprobs_num=0,
+                 stream=False,
+                 mm_inputs=None,
                  token_ids_logprob=None,
              )
-
-             logger.info(f"Sending health check request to request manager...")
-
-             # Submit and wait for response
-             output_queue = await self.request_manager.generate_request(
-                 health_request, request_id=rid
+             # Set disaggregation params if needed
+             if self.server_args.disaggregation_mode != DisaggregationMode.NULL:
+                 health_req.bootstrap_host = FAKE_BOOTSTRAP_HOST
+                 health_req.bootstrap_room = 0
+         else:
+             health_req = TokenizedEmbeddingReqInput(
+                 rid=rid,
+                 input_text="",
+                 input_ids=[0],
              )

+         # Submit health check request
+         async def run_health_check():
              try:
-                 # Wait for response with configurable timeout
-                 response = await asyncio.wait_for(
-                     output_queue.get(), timeout=HEALTH_CHECK_TIMEOUT
-                 )
-
-                 # Clean up
-                 if rid in self.request_manager.rid_to_state:
-                     del self.request_manager.rid_to_state[rid]
-
+                 async for _ in self.request_manager.generate_request(
+                     obj=health_req,
+                     request_id=rid,
+                 ):
+                     # Got at least one response, server is healthy
+                     return True
+             except Exception as e:
+                 logger.warning(f"Health check failed: {e}")
+                 return False
+             return False
+
+         task = asyncio.create_task(run_health_check())
+
+         # Wait for response with timeout
+         tic = time.time()
+         while time.time() < tic + HEALTH_CHECK_TIMEOUT:
+             await asyncio.sleep(1)
+             # Check if we got a response from scheduler
+             if self.request_manager.last_receive_tstamp > tic:
+                 task.cancel()
+                 # Clean up health check state
+                 self.request_manager._cleanup_request_state(rid)
                  return sglang_scheduler_pb2.HealthCheckResponse(
                      healthy=True, message="Health check passed"
                  )

-             except asyncio.TimeoutError:
-                 # Clean up on timeout
-                 if rid in self.request_manager.rid_to_state:
-                     del self.request_manager.rid_to_state[rid]
-
-                 return sglang_scheduler_pb2.HealthCheckResponse(
-                     healthy=False, message="Health check timeout"
-                 )
-
-         except Exception as e:
-             logger.error(f"Health check failed: {e}")
-             return sglang_scheduler_pb2.HealthCheckResponse(
-                 healthy=False, message=f"Health check error: {str(e)}"
-             )
+         # Timeout - server not responding
+         task.cancel()
+         self.request_manager._cleanup_request_state(rid)
+         logger.warning(f"Health check timeout after {HEALTH_CHECK_TIMEOUT}s")
+         return sglang_scheduler_pb2.HealthCheckResponse(
+             healthy=False, message=f"Health check timeout after {HEALTH_CHECK_TIMEOUT}s"
+         )

      async def Abort(
          self,
          request: sglang_scheduler_pb2.AbortRequest,
-         context: grpc.aio.ServicerContext,
+         _context: grpc.aio.ServicerContext,
      ) -> sglang_scheduler_pb2.AbortResponse:
          """Abort an ongoing request."""
-         logger.info(f"Aborting request: {request.request_id}")
+         logger.info(f"Receive abort request: {request.request_id}")

          try:
              success = await self.request_manager.abort_request(request.request_id)
@@ -371,12 +288,98 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
                  message=f"Request {request.request_id} {'aborted' if success else 'not found'}",
              )
          except Exception as e:
-             logger.error(f"Abort failed: {e}")
+             logger.error(
+                 f"Abort failed for request {request.request_id}: {e}\n"
+                 f"{get_exception_traceback()}"
+             )
              return sglang_scheduler_pb2.AbortResponse(
                  success=False,
                  message=str(e),
              )

+     async def GetModelInfo(
+         self,
+         _request: sglang_scheduler_pb2.GetModelInfoRequest,
+         _context: grpc.aio.ServicerContext,
+     ) -> sglang_scheduler_pb2.GetModelInfoResponse:
+         """Get model information."""
+         logger.debug("Receive model info request")
+
+         is_generation = self.scheduler_info.get("is_generation")
+         if is_generation is None:
+             is_generation = not self.server_args.is_embedding
+
+         return sglang_scheduler_pb2.GetModelInfoResponse(
+             model_path=self.server_args.model_path,
+             tokenizer_path=self.server_args.tokenizer_path or "",
+             is_generation=is_generation,
+             preferred_sampling_params=(
+                 self.server_args.preferred_sampling_params or ""
+             ),
+             weight_version=self.server_args.weight_version or "",
+             served_model_name=self.server_args.served_model_name,
+             max_context_length=self.model_info["max_context_length"],
+             vocab_size=self.model_info["vocab_size"],
+             supports_vision=self.model_info["supports_vision"],
+             model_type=self.model_info["model_type"],
+             eos_token_ids=self.model_info["eos_token_ids"],
+             pad_token_id=self.model_info["pad_token_id"],
+             bos_token_id=self.model_info["bos_token_id"],
+             max_req_input_len=self.model_info["max_req_input_len"],
+         )
+
+     async def GetServerInfo(
+         self,
+         _request: sglang_scheduler_pb2.GetServerInfoRequest,
+         _context: grpc.aio.ServicerContext,
+     ) -> sglang_scheduler_pb2.GetServerInfoResponse:
+         """Get server information."""
+         logger.debug("Receive server info request")
+
+         server_args_dict = dataclasses.asdict(self.server_args)
+         server_args_struct = Struct()
+
+         def make_serializable(obj):
+             if obj is None:
+                 return None
+             elif isinstance(obj, (str, int, float, bool)):
+                 return obj
+             elif isinstance(obj, (list, tuple, set)):
+                 return [make_serializable(item) for item in obj]
+             elif isinstance(obj, dict):
+                 return {k: make_serializable(v) for k, v in obj.items()}
+             else:
+                 return str(obj)
+
+         serializable_args = make_serializable(server_args_dict)
+         server_args_struct.update(serializable_args)
+
+         # Convert scheduler_info to Struct
+         scheduler_info_struct = Struct()
+         scheduler_info_struct.update(self.scheduler_info)
+
+         # Get runtime state from request manager
+         manager_state = self.request_manager.get_server_info()
+
+         # Calculate uptime
+         uptime = time.time() - self.start_time
+
+         # Create timestamp
+         start_timestamp = Timestamp()
+         start_timestamp.FromSeconds(int(self.start_time))
+
+         return sglang_scheduler_pb2.GetServerInfoResponse(
+             server_args=server_args_struct,
+             scheduler_info=scheduler_info_struct,
+             active_requests=manager_state["active_requests"],
+             is_paused=manager_state["paused"],
+             last_receive_timestamp=manager_state["last_receive_time"],
+             uptime_seconds=uptime,
+             sglang_version=sglang.__version__,
+             server_type="grpc",
+             start_time=start_timestamp,
+         )
+
      # Helper methods for request/response conversion

      def _convert_generate_request(
@@ -393,6 +396,27 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)

          # Convert sampling params
          sampling_params = self._convert_sampling_params(grpc_req.sampling_params)
+         sampling_params.normalize(tokenizer=None)
+
+         # Extract disaggregated params if present
+         bootstrap_host = None
+         bootstrap_port = None
+         bootstrap_room = None
+         if grpc_req.HasField("disaggregated_params"):
+             # Don't use 'or None' as it treats 0 as falsy
+             bootstrap_host = (
+                 grpc_req.disaggregated_params.bootstrap_host
+                 if grpc_req.disaggregated_params.bootstrap_host
+                 else None
+             )
+             bootstrap_port = (
+                 grpc_req.disaggregated_params.bootstrap_port
+                 if grpc_req.disaggregated_params.bootstrap_port
+                 else None
+             )
+             bootstrap_room = (
+                 grpc_req.disaggregated_params.bootstrap_room
+             )  # Can be 0, don't use 'or None'

          # Create request
          return TokenizedGenerateReqInput(
@@ -402,13 +426,20 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
              mm_inputs=None,  # TODO: implement mm support
              sampling_params=sampling_params,
              return_logprob=grpc_req.return_logprob,
-             logprob_start_len=grpc_req.logprob_start_len or -1,
+             logprob_start_len=(
+                 grpc_req.logprob_start_len
+                 if grpc_req.logprob_start_len is not None
+                 else -1
+             ),
              top_logprobs_num=grpc_req.top_logprobs_num or 0,
-             stream=True,  # Always stream for gRPC
-             lora_path=grpc_req.lora_id if grpc_req.lora_id else None,
+             stream=grpc_req.stream or False,
+             lora_id=grpc_req.lora_id if grpc_req.lora_id else None,
              token_ids_logprob=(
                  list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
              ),
+             bootstrap_host=bootstrap_host,
+             bootstrap_port=bootstrap_port,
+             bootstrap_room=bootstrap_room,
          )

      def _convert_embed_request(
@@ -438,6 +469,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
          regex = None
          json_schema = None
          ebnf_grammar = None
+         structural_tag = None

          if grpc_params.HasField("regex"):
              regex = grpc_params.regex
@@ -445,44 +477,151 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
445
477
  json_schema = grpc_params.json_schema
446
478
  elif grpc_params.HasField("ebnf_grammar"):
447
479
  ebnf_grammar = grpc_params.ebnf_grammar
480
+ elif grpc_params.HasField("structural_tag"):
481
+ structural_tag = grpc_params.structural_tag
482
+
483
+ # Handle optional parameters conversion
484
+ custom_params = (
485
+ MessageToDict(grpc_params.custom_params)
486
+ if grpc_params.HasField("custom_params")
487
+ else None
488
+ )
489
+ max_new_tokens = (
490
+ grpc_params.max_new_tokens
491
+ if grpc_params.HasField("max_new_tokens")
492
+ else None
493
+ )
494
+ stream_interval = (
495
+ grpc_params.stream_interval
496
+ if grpc_params.HasField("stream_interval")
497
+ else None
498
+ )
499
+ logit_bias = dict(grpc_params.logit_bias) if grpc_params.logit_bias else None
500
+ stop = list(grpc_params.stop) if grpc_params.stop else None
501
+ stop_token_ids = (
502
+ list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else None
503
+ )
448
504
 
449
505
  return SGLSamplingParams(
450
- temperature=grpc_params.temperature or 1.0,
451
- top_p=grpc_params.top_p or 1.0,
452
- top_k=grpc_params.top_k or -1,
453
- min_p=grpc_params.min_p or 0.0,
454
- frequency_penalty=grpc_params.frequency_penalty or 0.0,
455
- presence_penalty=grpc_params.presence_penalty or 0.0,
456
- repetition_penalty=grpc_params.repetition_penalty or 1.0,
457
- max_new_tokens=grpc_params.max_new_tokens or 128,
458
- min_new_tokens=grpc_params.min_new_tokens or 0,
459
- stop=list(grpc_params.stop) if grpc_params.stop else None,
460
- stop_token_ids=(
461
- list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else None
462
- ),
506
+ temperature=grpc_params.temperature,
507
+ top_p=grpc_params.top_p,
508
+ top_k=grpc_params.top_k,
509
+ min_p=grpc_params.min_p,
510
+ frequency_penalty=grpc_params.frequency_penalty,
511
+ presence_penalty=grpc_params.presence_penalty,
512
+ repetition_penalty=grpc_params.repetition_penalty,
513
+ max_new_tokens=max_new_tokens,
514
+ min_new_tokens=grpc_params.min_new_tokens,
515
+ stop=stop,
516
+ stop_token_ids=stop_token_ids,
463
517
  skip_special_tokens=grpc_params.skip_special_tokens,
464
518
  spaces_between_special_tokens=grpc_params.spaces_between_special_tokens,
519
+ no_stop_trim=grpc_params.no_stop_trim,
465
520
  regex=regex,
466
521
  json_schema=json_schema,
467
522
  ebnf=ebnf_grammar,
468
- n=grpc_params.n or 1,
523
+ structural_tag=structural_tag,
524
+ n=grpc_params.n,
469
525
  ignore_eos=grpc_params.ignore_eos,
526
+ stream_interval=stream_interval,
527
+ logit_bias=logit_bias,
528
+ custom_params=custom_params,
529
+ )
530
+
531
+ def _convert_output_logprobs_to_proto(
532
+ self, logprobs_data: Dict
533
+ ) -> Optional[sglang_scheduler_pb2.OutputLogProbs]:
534
+ """Convert output logprobs dict to proto (no None values, plain floats)."""
535
+ if not logprobs_data:
536
+ return None
537
+
538
+ token_logprobs_val = logprobs_data.get("token_logprobs_val", [])
539
+ token_logprobs_idx = logprobs_data.get("token_logprobs_idx", [])
540
+ top_logprobs_val = logprobs_data.get("top_logprobs_val", [])
541
+ top_logprobs_idx = logprobs_data.get("top_logprobs_idx", [])
542
+
543
+ # Build TopLogProbs entries
544
+ top_logprobs_proto = []
545
+ if top_logprobs_val and top_logprobs_idx:
546
+ for val_list, idx_list in zip(top_logprobs_val, top_logprobs_idx):
547
+ top_logprobs_proto.append(
548
+ sglang_scheduler_pb2.TopLogProbs(
549
+ values=val_list,
550
+ token_ids=idx_list,
551
+ )
552
+ )
553
+
554
+ return sglang_scheduler_pb2.OutputLogProbs(
555
+ token_logprobs=token_logprobs_val, # Plain float array
556
+ token_ids=token_logprobs_idx,
557
+ top_logprobs=top_logprobs_proto,
558
+ )
559
+
560
+ def _convert_input_logprobs_to_proto(
561
+ self, logprobs_data: Dict
562
+ ) -> Optional[sglang_scheduler_pb2.InputLogProbs]:
563
+ """Convert input logprobs dict to proto (first token is None, wrapped in InputTokenLogProb)."""
564
+ if not logprobs_data:
565
+ return None
566
+
567
+ token_logprobs_val = logprobs_data.get("token_logprobs_val", [])
568
+ token_logprobs_idx = logprobs_data.get("token_logprobs_idx", [])
569
+ top_logprobs_val = logprobs_data.get("top_logprobs_val", [])
570
+ top_logprobs_idx = logprobs_data.get("top_logprobs_idx", [])
571
+
572
+ # Wrap values in InputTokenLogProb (None for first token, value for others)
573
+ token_logprobs_wrapped = [
574
+ (
575
+ sglang_scheduler_pb2.InputTokenLogProb()
576
+ if x is None
577
+ else sglang_scheduler_pb2.InputTokenLogProb(value=x)
578
+ )
579
+ for x in token_logprobs_val
580
+ ]
581
+
582
+ # Build TopLogProbs entries
583
+ top_logprobs_proto = []
584
+ if top_logprobs_val and top_logprobs_idx:
585
+ for val_list, idx_list in zip(top_logprobs_val, top_logprobs_idx):
586
+ top_logprobs_proto.append(
587
+ sglang_scheduler_pb2.TopLogProbs(
588
+ values=val_list,
589
+ token_ids=idx_list,
590
+ )
591
+ )
592
+
593
+ return sglang_scheduler_pb2.InputLogProbs(
594
+ token_logprobs=token_logprobs_wrapped,
595
+ token_ids=token_logprobs_idx,
596
+ top_logprobs=top_logprobs_proto,
470
597
  )
471
598
 
472
599
     def _create_chunk_response(
         self, request_id: str, output: Dict
     ) -> sglang_scheduler_pb2.GenerateResponse:
         """Create a streaming chunk response."""
+        meta_info = output.get("meta_info", {})
+
+        # Convert output logprobs if present
+        output_logprobs_proto = self._convert_output_logprobs_to_proto(
+            output.get("output_logprobs")
+        )
+
+        # Convert input logprobs if present (only in first chunk)
+        input_logprobs_proto = self._convert_input_logprobs_to_proto(
+            output.get("input_logprobs")
+        )
+
         return sglang_scheduler_pb2.GenerateResponse(
             request_id=request_id,
             chunk=sglang_scheduler_pb2.GenerateStreamChunk(
-                token_id=output["token_ids"][-1] if output.get("token_ids") else 0,
-                text=output.get("text", ""),
-                prompt_tokens=0,
-                completion_tokens=len(output.get("token_ids", [])),
-                cached_tokens=0,
-                generation_time=time.time() - self.start_time,
-                queue_time=0.0,
+                token_ids=output.get("token_ids", []),
+                prompt_tokens=meta_info.get("prompt_tokens", 0),
+                completion_tokens=meta_info.get("completion_tokens", 0),
+                cached_tokens=meta_info.get("cached_tokens", 0),
+                output_logprobs=output_logprobs_proto,
+                input_logprobs=input_logprobs_proto,
+                index=output.get("index", 0),
             ),
         )
 
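With this change a stream chunk carries a batch of token_ids plus real counters from meta_info, instead of a single token_id, decoded text, and hardcoded zeros. A hedged sketch of how a client might fold the stream together, assuming stub is an already-connected SglangSchedulerStub and request is a GenerateRequest (both defined by this package's generated stubs; the helper name here is illustrative):

    def collect_stream(stub, request):
        # Generate is a server-streaming RPC; each response holds one of
        # chunk / complete / error (a proto oneof, hence HasField below).
        token_ids = []
        for resp in stub.Generate(request):
            if resp.HasField("chunk"):
                token_ids.extend(resp.chunk.token_ids)
            elif resp.HasField("complete"):
                return list(resp.complete.output_ids)  # authoritative final IDs
            elif resp.HasField("error"):
                raise RuntimeError(resp.error.message)
        return token_ids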
@@ -491,20 +630,57 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
     ) -> sglang_scheduler_pb2.GenerateResponse:
         """Create a completion response."""
 
-        # Determine finish reason
-        finish_reason = sglang_scheduler_pb2.GenerateComplete.STOP
+        # Extract meta info and finish reason details
         meta_info = output.get("meta_info", {})
-        if meta_info.get("finish_reason") == "length":
-            finish_reason = sglang_scheduler_pb2.GenerateComplete.LENGTH
-        elif meta_info.get("finish_reason") == "eos_token":
-            finish_reason = sglang_scheduler_pb2.GenerateComplete.EOS_TOKEN
+        finish_reason_data = meta_info.get("finish_reason")
+
+        # Determine finish reason, default is stop
+        finish_reason = "stop"
+        if finish_reason_data:
+            if isinstance(finish_reason_data, dict):
+                finish_reason_type = finish_reason_data.get("type")
+            else:
+                # Handle legacy string format
+                finish_reason_type = finish_reason_data
+
+            if finish_reason_type == "length":
+                finish_reason = "length"
+            elif finish_reason_type == "abort":
+                finish_reason = "abort"
+
+        # Extract matched_stop information
+        matched_stop_kwargs = {}
+        if isinstance(finish_reason_data, dict) and "matched" in finish_reason_data:
+            matched = finish_reason_data["matched"]
+            if isinstance(matched, int):
+                matched_stop_kwargs["matched_token_id"] = matched
+            elif isinstance(matched, str):
+                matched_stop_kwargs["matched_stop_str"] = matched
+
+        # Convert output logprobs if present
+        output_logprobs_proto = self._convert_output_logprobs_to_proto(
+            output.get("output_logprobs")
+        )
+
+        # Convert input logprobs if present
+        input_logprobs_proto = self._convert_input_logprobs_to_proto(
+            output.get("input_logprobs")
+        )
 
         return sglang_scheduler_pb2.GenerateResponse(
             request_id=request_id,
             complete=sglang_scheduler_pb2.GenerateComplete(
                 output_ids=output.get("token_ids", []),
-                output_text=output.get("text", ""),
                 finish_reason=finish_reason,
+                prompt_tokens=meta_info.get("prompt_tokens", 0),
+                completion_tokens=meta_info.get(
+                    "completion_tokens", len(output.get("token_ids", []))
+                ),
+                cached_tokens=meta_info.get("cached_tokens", 0),
+                output_logprobs=output_logprobs_proto,
+                input_logprobs=input_logprobs_proto,
+                index=output.get("index", 0),
+                **matched_stop_kwargs,
             ),
         )
 
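The completion path now normalizes finish_reason from either the newer dict payload ({"type": ..., "matched": ...}) or the legacy bare string, and forwards matched-stop details as matched_token_id / matched_stop_str. The same mapping as a standalone, runnable sketch (function name is illustrative, not part of the package):

    from typing import Any, Dict, Tuple

    def normalize_finish_reason(data: Any) -> Tuple[str, Dict[str, Any]]:
        reason = "stop"  # default, as in the servicer code above
        matched_kwargs: Dict[str, Any] = {}
        if data:
            # dict payload carries a "type"; legacy payload is the bare string
            rtype = data.get("type") if isinstance(data, dict) else data
            if rtype == "length":
                reason = "length"
            elif rtype == "abort":
                reason = "abort"
        if isinstance(data, dict) and "matched" in data:
            matched = data["matched"]
            if isinstance(matched, int):
                matched_kwargs["matched_token_id"] = matched
            elif isinstance(matched, str):
                matched_kwargs["matched_stop_str"] = matched
        return reason, matched_kwargs

    assert normalize_finish_reason("length") == ("length", {})
    assert normalize_finish_reason({"type": "stop", "matched": 128009}) == (
        "stop",
        {"matched_token_id": 128009},
    )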
@@ -512,6 +688,10 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
         """Shutdown the service."""
         logger.info("Shutting down gRPC service")
 
+        # Mark health service as NOT_SERVING before shutdown
+        if self.health_servicer:
+            self.health_servicer.set_not_serving()
+
         # Shutdown request manager (handles its own tasks)
         await self.request_manager.shutdown()
 
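SGLangHealthServicer's set_serving()/set_not_serving() pair (flipped here on shutdown and in the warmup path later in this diff) follows the standard gRPC health-checking protocol. A minimal sketch of the same lifecycle using the stock servicer from the grpcio-health-checking package, which exposes the equivalent set() call:

    from grpc_health.v1 import health, health_pb2

    servicer = health.HealthServicer()
    # "" is the conventional server-wide (blank) service name
    servicer.set("", health_pb2.HealthCheckResponse.NOT_SERVING)  # before warmup / on shutdown
    servicer.set("", health_pb2.HealthCheckResponse.SERVING)      # once warmup completes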
@@ -522,9 +702,19 @@ async def serve_grpc(
 ):
     """Start the standalone gRPC server with integrated scheduler."""
 
+    # Start bootstrap server BEFORE launching scheduler processes (only in PREFILL mode)
+    # This ensures the bootstrap server is ready when prefill schedulers try to register
+    bootstrap_server = None
+    if server_args.disaggregation_mode == "prefill":
+        bootstrap_server = start_disagg_service(server_args)
+        if bootstrap_server:
+            logger.info(
+                f"Bootstrap server started for disaggregation mode on {server_args.host}:{server_args.disaggregation_bootstrap_port}"
+            )
+
     # Launch only the scheduler process(es) (no tokenizer/detokenizer needed for gRPC)
     logger.info("Launching scheduler process(es)...")
-    scheduler_info, port_args, scheduler_procs = _launch_scheduler_process_only(
+    scheduler_info, port_args, scheduler_procs = launch_scheduler_process_only(
         server_args=server_args,
     )
 
@@ -545,9 +735,11 @@ async def serve_grpc(
     }
 
     # Create request manager with the correct port args
+    # Note: the bootstrap server was already started above (prefill mode, else None); hand it to the manager
     request_manager = GrpcRequestManager(
         server_args=server_args,
         port_args=port_args,
+        bootstrap_server=bootstrap_server,
     )
 
     # Create gRPC server
@@ -559,17 +751,27 @@
         ],
     )
 
-    # Add service
+    # Create standard health service (for Kubernetes probes)
+    health_servicer = SGLangHealthServicer(
+        request_manager=request_manager,
+        scheduler_info=scheduler_info,
+    )
+    health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server)
+
+    # Add SGLang service
     servicer = SGLangSchedulerServicer(
         request_manager=request_manager,
         server_args=server_args,
         model_info=model_info,
+        scheduler_info=scheduler_info,
+        health_servicer=health_servicer,
     )
     sglang_scheduler_pb2_grpc.add_SglangSchedulerServicer_to_server(servicer, server)
 
     # Enable reflection
     SERVICE_NAMES = (
         sglang_scheduler_pb2.DESCRIPTOR.services_by_name["SglangScheduler"].full_name,
+        "grpc.health.v1.Health",
         reflection.SERVICE_NAME,
     )
     reflection.enable_server_reflection(SERVICE_NAMES, server)
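Because the health service is registered under the well-known name grpc.health.v1.Health (and exposed via reflection above), a Kubernetes probe or any stock gRPC client can query readiness without the SGLang stubs. A sketch, assuming the server listens on localhost:30000 and the servicer reports under the blank (server-wide) service name:

    import grpc
    from grpc_health.v1 import health_pb2, health_pb2_grpc

    channel = grpc.insecure_channel("localhost:30000")
    stub = health_pb2_grpc.HealthStub(channel)
    # Blank service name checks overall server health
    resp = stub.Check(health_pb2.HealthCheckRequest(service=""), timeout=5)
    print(health_pb2.HealthCheckResponse.ServingStatus.Name(resp.status))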
@@ -578,9 +780,15 @@
     listen_addr = f"{server_args.host}:{server_args.port}"
     server.add_insecure_port(listen_addr)
 
-    logger.info(f"Starting standalone gRPC server on {listen_addr}")
-
     await server.start()
+    logger.info(f"gRPC server listening on {listen_addr}")
+
+    # Start warmup in a separate thread
+    warmup_thread = threading.Thread(
+        target=_wait_and_warmup_grpc,
+        args=(server_args, None, health_servicer),
+    )
+    warmup_thread.start()
 
     # Handle shutdown signals
     loop = asyncio.get_running_loop()
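The warmup helper defined later in this file is synchronous (blocking channel creation, time.sleep polling), so it is dispatched on a plain thread to keep the asyncio event loop free for the running gRPC server. An equivalent loop-friendly formulation, shown only as a sketch of the design trade-off (not what the package does):

    import asyncio

    async def warmup_off_loop(server_args, health_servicer):
        # asyncio.to_thread (Python 3.9+) runs the blocking warmup in the
        # default executor without stalling the running event loop
        await asyncio.to_thread(
            _wait_and_warmup_grpc, server_args, None, health_servicer
        )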
@@ -597,84 +805,203 @@
         await stop_event.wait()
     finally:
         logger.info("Shutting down gRPC server")
+
+        # Shutdown request manager first - this closes ZMQ sockets and stops background tasks
         await servicer.shutdown()
+
+        # Stop the gRPC server
         await server.stop(5.0)
 
-        # Terminate scheduler processes
+        # Wait for warmup thread to finish
+        if warmup_thread.is_alive():
+            logger.info("Waiting for warmup thread to finish...")
+            warmup_thread.join(timeout=5.0)
+
+        # Terminate scheduler processes before exiting to avoid atexit hang
+        # The scheduler processes have SIGINT ignored, so they won't get KeyboardInterrupt
         for i, proc in enumerate(scheduler_procs):
-            if proc and proc.is_alive():
+            if proc.is_alive():
                 logger.info(f"Terminating scheduler process {i}...")
                 proc.terminate()
-                proc.join(timeout=5.0)
+                proc.join(timeout=2.0)
                 if proc.is_alive():
-                    logger.warning(f"Force killing scheduler process {i}...")
+                    logger.warning(
+                        f"Scheduler process {i} did not terminate, killing..."
+                    )
                     proc.kill()
-                    proc.join()
+                    proc.join(timeout=1.0)
 
+        logger.info("All scheduler processes terminated")
 
-def main():
-    """Main entry point for standalone gRPC server."""
-    # Fix CUDA multiprocessing issues - must be called before any CUDA operations
-    mp.set_start_method("spawn", force=True)
 
-    parser = argparse.ArgumentParser(description="SGLang Standalone gRPC Server")
+def _execute_grpc_server_warmup(
+    server_args: ServerArgs,
+    pipe_finish_writer: Optional[mp.connection.Connection],
+):
+    """Execute warmup for gRPC server by checking health and sending test request."""
+    try:
+        # Connect to the gRPC server
+        grpc_url = f"{server_args.host}:{server_args.port}"
+        channel = grpc.insecure_channel(
+            grpc_url,
+            options=[
+                ("grpc.max_send_message_length", 1024 * 1024 * 256),
+                ("grpc.max_receive_message_length", 1024 * 1024 * 256),
+            ],
+        )
+        stub = sglang_scheduler_pb2_grpc.SglangSchedulerStub(channel)
 
-    # Server arguments
-    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
-    parser.add_argument("--port", type=int, default=30000, help="gRPC server port")
+        # Wait until the server is launched (poll GetModelInfo)
+        success = False
+        last_error = None
+        for _ in range(120):
+            time.sleep(1)
+            try:
+                request = sglang_scheduler_pb2.GetModelInfoRequest()
+                response = stub.GetModelInfo(request, timeout=5)
+                success = True
+                break
+            except Exception as e:
+                last_error = str(e)
+
+        if not success:
+            error_msg = f"gRPC server warmup failed: Could not connect to server after 120 seconds. Last error: {last_error}"
+            logger.error(error_msg)
+            if pipe_finish_writer is not None:
+                pipe_finish_writer.send(error_msg)
+            channel.close()
+            kill_process_tree(os.getpid())
+            return False
+
+        # Get model info to determine if it's generation or embedding
+        is_generation = response.is_generation
+
+        # Send a warmup request
+        logger.info("Sending warmup request to gRPC server...")
+        max_new_tokens = 8 if is_generation else 1
+
+        if is_generation:
+            warmup_request_kwargs = {
+                "request_id": f"WARMUP_{time.time()}",
+                "tokenized": sglang_scheduler_pb2.TokenizedInput(
+                    input_ids=[
+                        123,
+                        456,
+                        789,
+                        234,
+                        567,
+                        890,
+                        345,
+                    ],  # Random-looking but safe token IDs
+                    original_text="warmup request",
+                ),
+                "sampling_params": sglang_scheduler_pb2.SamplingParams(
+                    temperature=0.0,
+                    max_new_tokens=max_new_tokens,
+                ),
+                "stream": False,
+            }
+
+            # Set disaggregation params if needed
+            if server_args.disaggregation_mode != DisaggregationMode.NULL:
+                warmup_request_kwargs["disaggregated_params"] = (
+                    sglang_scheduler_pb2.DisaggregatedParams(
+                        bootstrap_host=FAKE_BOOTSTRAP_HOST,
+                        bootstrap_room=0,
+                    )
+                )
 
-    # Model arguments
-    parser.add_argument("--model-path", type=str, required=True, help="Model path")
-    parser.add_argument("--tokenizer-path", type=str, help="Tokenizer path")
-    parser.add_argument("--context-length", type=int, help="Context length")
-    parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size")
-    parser.add_argument("--dp-size", type=int, default=1, help="Data parallel size")
+            warmup_request = sglang_scheduler_pb2.GenerateRequest(
+                **warmup_request_kwargs
+            )
 
-    # Runtime arguments
-    parser.add_argument(
-        "--max-running-requests", type=int, default=2048, help="Max concurrent requests"
-    )
-    parser.add_argument(
-        "--max-total-tokens", type=int, default=1000000, help="Max total tokens"
-    )
-    parser.add_argument(
-        "--max-prefill-tokens", type=int, default=16384, help="Max prefill tokens"
-    )
-    parser.add_argument(
-        "--attention-backend", type=str, default="flashinfer", help="Attention backend"
-    )
-    parser.add_argument("--lora-paths", type=str, help="LoRA adapter paths")
-
-    # Logging
-    parser.add_argument("--log-level", type=str, default="INFO", help="Logging level")
-
-    args = parser.parse_args()
-
-    # Convert to ServerArgs with gRPC host and port
-    server_args = ServerArgs(
-        model_path=args.model_path,
-        tokenizer_path=args.tokenizer_path or args.model_path,
-        context_length=args.context_length,
-        tp_size=args.tp_size,
-        dp_size=args.dp_size,
-        max_running_requests=args.max_running_requests,
-        max_total_tokens=args.max_total_tokens,
-        max_prefill_tokens=args.max_prefill_tokens,
-        attention_backend=args.attention_backend,
-        lora_paths=args.lora_paths.split(",") if args.lora_paths else None,
-        log_level=args.log_level,
-        # Override with gRPC server host and port
-        host=args.host,
-        port=args.port,
-    )
+            # Send the warmup request
+            try:
+                responses = list(stub.Generate(warmup_request, timeout=600))
+                # Check if we got a valid response
+                if responses and not responses[-1].HasField("error"):
+                    logger.info("gRPC warmup request completed successfully")
+                    success = True
+                else:
+                    error_msg = (
+                        responses[-1].error.message if responses else "No response"
+                    )
+                    logger.warning(f"gRPC warmup request returned error: {error_msg}")
+                    success = False
+            except Exception as e:
+                error_msg = f"gRPC warmup request failed: {e}"
+                logger.error(error_msg)
+                if pipe_finish_writer is not None:
+                    pipe_finish_writer.send(error_msg)
+                channel.close()
+                kill_process_tree(os.getpid())
+                return False
+        else:
+            # For embedding models
+            warmup_request = sglang_scheduler_pb2.EmbedRequest(
+                request_id=f"WARMUP_{time.time()}",
+                tokenized=sglang_scheduler_pb2.TokenizedInput(
+                    input_ids=[10, 11, 12],
+                    original_text="test embedding",
+                ),
+            )
 
-    # Run server
-    asyncio.run(
-        serve_grpc(
-            server_args=server_args,
+            try:
+                response = stub.Embed(warmup_request, timeout=600)
+                if not response.HasField("error"):
+                    logger.info("gRPC warmup request completed successfully")
+                    success = True
+                else:
+                    logger.warning(
+                        f"gRPC warmup request returned error: {response.error.message}"
+                    )
+                    success = False
+            except Exception as e:
+                error_msg = f"gRPC warmup request failed: {e}"
+                logger.error(error_msg)
+                if pipe_finish_writer is not None:
+                    pipe_finish_writer.send(error_msg)
+                channel.close()
+                kill_process_tree(os.getpid())
+                return False
+
+        channel.close()
+        return success
+
+    except Exception as e:
+        error_msg = (
+            f"gRPC warmup failed with exception: {e}\n{get_exception_traceback()}"
         )
-    )
+        logger.error(error_msg)
+        if pipe_finish_writer is not None:
+            pipe_finish_writer.send(error_msg)
+        try:
+            channel.close()
+        except Exception:
+            pass
+        kill_process_tree(os.getpid())
+        return False
+
+
+def _wait_and_warmup_grpc(
+    server_args: ServerArgs,
+    pipe_finish_writer: Optional[mp.connection.Connection],
+    health_servicer: Optional[SGLangHealthServicer] = None,
+):
+    """Wait for gRPC server to be ready and execute warmup."""
+    if not server_args.skip_server_warmup:
+        if not _execute_grpc_server_warmup(server_args, pipe_finish_writer):
+            return
+    else:
+        logger.info("Skipping gRPC server warmup (skip_server_warmup=True)")
+
+    # Mark health service as SERVING after warmup completes
+    if health_servicer:
+        health_servicer.set_serving()
+        logger.info("Health service marked as SERVING")
 
+    logger.info("The server is fired up and ready to roll!")
 
-if __name__ == "__main__":
-    main()
+    if pipe_finish_writer is not None:
+        pipe_finish_writer.send("ready")
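
Taken together, readiness now flows: server starts, the warmup thread polls GetModelInfo and issues a tiny generate/embed request, the health service flips to SERVING, and the pipe (if any) receives "ready". A client-side sketch of waiting on that signal with only grpcio installed, mirroring the server's own 120-second poll (host, port, and the polling budget are assumptions):

    import time

    import grpc
    from grpc_health.v1 import health_pb2, health_pb2_grpc

    channel = grpc.insecure_channel("localhost:30000")
    # Block until the channel is connectable (server up, possibly not warmed up)
    grpc.channel_ready_future(channel).result(timeout=120)

    stub = health_pb2_grpc.HealthStub(channel)
    for _ in range(120):
        # NOT_SERVING -> SERVING marks the end of warmup
        resp = stub.Check(health_pb2.HealthCheckRequest(service=""), timeout=5)
        if resp.status == health_pb2.HealthCheckResponse.SERVING:
            print("server warmed up and serving")
            break
        time.sleep(1)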