sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff shows the changes between two package versions as published to a supported public registry. It is provided for informational purposes only.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  156. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/entrypoints/grpc_request_manager.py
@@ -0,0 +1,842 @@
+"""
+gRPC Request Manager - Orchestrates request lifecycle without tokenization.
+Mimics TokenizerManager's state management and ZMQ communication patterns.
+"""
+
+import asyncio
+import copy
+import dataclasses
+import logging
+import os
+import signal
+import sys
+import threading
+import time
+import uuid
+from typing import Any, AsyncGenerator, Dict, List, Optional, Union
+
+import grpc
+import zmq
+import zmq.asyncio
+
+from sglang.srt.managers.io_struct import (
+    AbortReq,
+    BatchEmbeddingOutput,
+    BatchTokenIDOutput,
+    HealthCheckOutput,
+    TokenizedEmbeddingReqInput,
+    TokenizedGenerateReqInput,
+)
+from sglang.srt.managers.scheduler import is_health_check_generate_req
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import get_zmq_socket, kill_process_tree
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger(__name__)
+
+
+class GrpcSignalHandler:
+    """Minimal signal handler for gRPC server - delegates real crash handling to scheduler."""
+
+    def __init__(self, grpc_manager):
+        self.grpc_manager = grpc_manager
+
+    def sigterm_handler(self, signum=None, frame=None):
+        """Handle SIGTERM by gracefully shutting down the gRPC server."""
+        logger.warning(
+            f"SIGTERM received. {signum=} {frame=}. Shutting down gRPC server..."
+        )
+        self.grpc_manager.gracefully_exit = True
+
+    def running_phase_sigquit_handler(self, signum=None, frame=None):
+        """Handle SIGQUIT from a failed scheduler process."""
+        logger.error(
+            "Received SIGQUIT from scheduler process. Scheduler failed, shutting down gRPC server."
+        )
+        logger.info(
+            "Note: Crash dumps are handled by the scheduler process, not the gRPC server."
+        )
+        # Just exit cleanly - the scheduler handles crash dumps
+        kill_process_tree(os.getpid(), include_parent=True)
+
+
+@dataclasses.dataclass
+class GrpcReqState:
+    """State tracking for a gRPC request."""
+
+    # Request identification
+    request_id: str
+    grpc_context: Optional[grpc.aio.ServicerContext]
+
+    # Communication
+    out_queue: asyncio.Queue
+    finished: bool
+    event: asyncio.Event
+    obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]
+
+    # Metrics (same as TokenizerManager's ReqState)
+    created_time: float
+    finished_time: float = 0.0
+    first_token_time: float = 0.0
+    last_time: float = 0.0
+    last_completion_tokens: int = 1
+
+    # Streaming state
+    stream_finished: bool = False
+    input_logprobs_sent: bool = False  # Track if input logprobs were sent in streaming
+
+    # Token accumulation (for non-streaming)
+    output_ids: List[int] = dataclasses.field(default_factory=list)
+    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
+    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
+    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
+    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
+    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
+    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
+    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
+    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
+
+    # Session state
+    session_id: Optional[str] = None
+    is_session_request: bool = False
+
+
+class GrpcRequestManager:
+    """
+    Manages gRPC request lifecycle, mimicking TokenizerManager's orchestration
+    behaviors without tokenization.
+    """
+
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        bootstrap_server=None,
+    ):
+        """Initialize the gRPC request manager."""
+        self.server_args = server_args
+        self.port_args = port_args
+
+        # ZMQ Communication Setup (same pattern as TokenizerManager)
+        self.context = zmq.asyncio.Context(2)
+
+        # Socket for receiving outputs from scheduler
+        self.recv_from_scheduler = get_zmq_socket(
+            self.context, zmq.PULL, port_args.detokenizer_ipc_name, bind=True
+        )
+
+        # Socket for sending requests to scheduler
+        self.send_to_scheduler = get_zmq_socket(
+            self.context, zmq.PUSH, port_args.scheduler_input_ipc_name, bind=True
+        )
+
+        # State Management (from TokenizerManager)
+        self.rid_to_state: Dict[str, GrpcReqState] = {}
+        self.asyncio_tasks: set = set()
+        self.gracefully_exit = False
+        self.no_create_loop = False
+        self.event_loop = None
+
+        # Pause/Resume Control
+        self.is_pause = False
+        self.is_pause_cond = asyncio.Condition()
+
+        # Metrics
+        self.last_receive_tstamp = time.time()
+
+        # Crash dump for debugging
+        self.crash_dump_request_list = []
+        self.crash_dump_performed = False
+
+        # Bootstrap server (passed from serve_grpc, not started here)
+        self.bootstrap_server = bootstrap_server
+
+        logger.info(
+            f"GrpcRequestManager initialized with ZMQ IPC: "
+            f"recv={port_args.detokenizer_ipc_name}, "
+            f"send={port_args.scheduler_input_ipc_name}"
+        )
+        if self.bootstrap_server:
+            logger.info(
+                f"Bootstrap server initialized for disaggregation mode: "
+                f"{server_args.disaggregation_mode}"
+            )
+
+    async def generate_request(
+        self,
+        obj: TokenizedGenerateReqInput,
+        request_id: Optional[str] = None,
+        grpc_context: Optional[grpc.aio.ServicerContext] = None,
+    ) -> AsyncGenerator[Union[Dict, List[Dict]], None]:
+        """
+        Submit a generation request to the scheduler with n>1 parallel sampling support.
+
+        This method implements the same two-phase approach as tokenizer_manager.py:
+        1. Phase 1: Send prefix caching request (max_new_tokens=0)
+        2. Phase 2: Send n generation requests that reuse the cached prefix
+
+        Yields individual responses for streaming, or aggregated responses for non-streaming.
+        """
+        n = getattr(obj.sampling_params, "n", 1)
+
+        if n <= 1:
+            async for response in self._handle_single_request(
+                obj, request_id, grpc_context
+            ):
+                yield response
+            return
+
+        # N>1 handling - two-phase approach
+        logger.debug(f"Multiple sampling request (n={n}), using two-phase approach")
+
+        # Generate base request ID if not provided
+        if request_id is None:
+            base_request_id = f"grpc-{uuid.uuid4().hex}"
+        else:
+            base_request_id = request_id
+
+        # Phase 1: Cache the common prefix
+        logger.debug(f"Phase 1: Caching prefix for request {base_request_id}")
+        prefix_obj = copy.copy(obj)
+        prefix_obj.sampling_params = copy.copy(obj.sampling_params)
+        prefix_obj.sampling_params.max_new_tokens = 0  # Prefill-only
+        prefix_obj.sampling_params.n = 1  # Don't replicate prefix request
+
+        # Send prefix caching request and consume response
+        async for _ in self._handle_single_request(
+            prefix_obj, f"{base_request_id}-prefix", grpc_context
+        ):
+            # Consume prefix response (usually just one chunk with finish_reason)
+            pass
+
+        logger.debug(f"Phase 1 completed: Prefix cached for {base_request_id}")
+
+        # Phase 2: Generate n parallel requests
+        logger.debug(f"Phase 2: Generating {n} parallel requests")
+        generators = []
+        request_ids = []
+
+        for i in range(n):
+            # Create individual generation request
+            gen_obj = copy.copy(obj)
+            gen_obj.sampling_params = copy.copy(obj.sampling_params)
+            gen_obj.sampling_params.n = 1  # Each request generates 1 response
+
+            gen_request_id = f"{base_request_id}-{i}"
+            request_ids.append(gen_request_id)
+
+            # Start generation request
+            generators.append(
+                self._handle_single_request(gen_obj, gen_request_id, grpc_context)
+            )
+
+        # Handle response aggregation
+        is_stream = getattr(obj, "stream", False)
+
+        if not is_stream:
+            # Non-streaming: collect all responses and return as batch
+            logger.debug(f"Non-streaming mode: collecting {n} responses")
+            responses = []
+            for generator in generators:
+                async for response in generator:
+                    responses.append(response)
+            yield responses  # Return all responses as a batch
+        else:
+            # Streaming mode: multiplex responses with index for ordering
+            logger.debug(f"Streaming mode: multiplexing {n} streams")
+            rid_to_index = {rid: i for i, rid in enumerate(request_ids)}
+
+            # Create async tasks for all generators
+            task_map = {}
+            for generator in generators:
+                task = asyncio.create_task(generator.__anext__())
+                task_map[task] = generator
+
+            # Process responses as they arrive
+            while task_map:
+                done, _ = await asyncio.wait(
+                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
+                )
+
+                for task in done:
+                    generator = task_map.pop(task)
+                    try:
+                        response = await task
+
+                        # Add index for client-side ordering
+                        if isinstance(response, dict):
+                            response_rid = response.get("request_id", "")
+                            if response_rid in rid_to_index:
+                                response["index"] = rid_to_index[response_rid]
+
+                        yield response
+
+                        # Create next task for this generator
+                        next_task = asyncio.create_task(generator.__anext__())
+                        task_map[next_task] = generator
+
+                    except StopAsyncIteration:
+                        # This generator is finished
+                        pass
+
+    async def _handle_single_request(
+        self,
+        obj: TokenizedGenerateReqInput,
+        request_id: Optional[str] = None,
+        grpc_context: Optional[grpc.aio.ServicerContext] = None,
+    ):
+        """Handle a single request - core implementation without n>1 logic."""
+        # Generate request ID if not provided
+        if request_id is None:
+            request_id = f"grpc-{uuid.uuid4().hex}"
+
+        obj.rid = request_id
+
+        # Create and register request state
+        # TODO: support log_request
+        state = GrpcReqState(
+            request_id=request_id,
+            grpc_context=grpc_context,
+            out_queue=asyncio.Queue(),
+            finished=False,
+            event=asyncio.Event(),
+            obj=obj,
+            created_time=time.time(),
+        )
+
+        # Track session if needed
+        if hasattr(obj, "session_params") and obj.session_params:
+            state.session_id = obj.session_params.session_id
+            state.is_session_request = True
+
+        self.rid_to_state[request_id] = state
+        self.record_request_for_crash_dump(obj)
+
+        try:
+            # Send to scheduler - let exceptions bubble up to grpc_server.py
+            await self._send_to_scheduler(obj)
+
+            is_stream = getattr(obj, "stream", False)
+
+            while True:
+                # Client cancelled - notify scheduler and exit
+                if grpc_context and grpc_context.cancelled():
+                    await self.abort_request(request_id)
+                    return
+
+                try:
+                    response = await asyncio.wait_for(state.out_queue.get(), timeout=4)
+
+                    if is_stream:
+                        yield response
+
+                    # Non-streaming: yield final response with accumulated tokens from state
+                    if isinstance(response, dict) and response.get("finished", False):
+                        if not is_stream:
+                            final_response = response.copy()
+                            final_response["token_ids"] = state.output_ids
+                            yield final_response
+                        break
+
+                except asyncio.TimeoutError:
+                    # Timeout is for periodic client cancellation check
+                    # Continue waiting for scheduler response
+                    continue
+
+        finally:
+            # Always clean up request state when exiting
+            self._cleanup_request_state(request_id)
+
+    def _cleanup_request_state(self, request_id: str):
+        """Clean up local request state (does not notify scheduler)."""
+        if request_id in self.rid_to_state:
+            del self.rid_to_state[request_id]
+
+    async def embedding_request(
+        self,
+        obj: TokenizedEmbeddingReqInput,
+        request_id: Optional[str] = None,
+    ) -> asyncio.Future:
+        """
+        Submit an embedding request to the scheduler.
+        Returns a future that will contain the embedding result.
+        """
+        # Generate request ID if not provided
+        if request_id is None:
+            request_id = f"grpc-embed-{uuid.uuid4().hex}"
+
+        obj.rid = request_id
+
+        # Create request state
+        state = GrpcReqState(
+            request_id=request_id,
+            grpc_context=None,
+            out_queue=asyncio.Queue(),
+            finished=False,
+            event=asyncio.Event(),
+            obj=obj,
+            created_time=time.time(),
+        )
+
+        # Register state
+        self.rid_to_state[request_id] = state
+
+        # Create future for result
+        future = asyncio.Future()
+
+        # Send to scheduler
+        try:
+            await self._send_to_scheduler(obj)
+        except Exception as e:
+            del self.rid_to_state[request_id]
+            future.set_exception(e)
+            return future
+
+        # Wait for result in background
+        async def wait_for_result():
+            try:
+                await state.event.wait()
+                result = await state.out_queue.get()
+                future.set_result(result)
+            except Exception as e:
+                future.set_exception(e)
+            finally:
+                # Clean up
+                if request_id in self.rid_to_state:
+                    del self.rid_to_state[request_id]
+
+        asyncio.create_task(wait_for_result())
+        return future
+
+    async def abort_request(self, request_id: str) -> bool:
+        """Abort a running request."""
+        # Skip aborting health check requests (they clean themselves up)
+        if request_id.startswith("HEALTH_CHECK"):
+            return False
+
+        if request_id not in self.rid_to_state:
+            return False
+
+        # Send abort to scheduler
+        abort_req = AbortReq(rid=request_id)
+        try:
+            await self._send_to_scheduler(abort_req)
+        except Exception as e:
+            logger.error(f"Failed to send abort request: {e}")
+            return False
+
+        # Mark as finished
+        state = self.rid_to_state.get(request_id)
+        if state:
+            state.finished = True
+            state.stream_finished = True
+            state.event.set()
+
+            # Send abort notification to output queue
+            await state.out_queue.put({"error": "Request aborted", "abort": True})
+
+        return True
+
+    async def handle_loop(self):
+        """
+        Main event loop - processes outputs from scheduler.
+        Mimics TokenizerManager's handle_loop.
+        """
+        while not self.gracefully_exit:
+            try:
+                # Receive from scheduler
+                recv_obj = await self.recv_from_scheduler.recv_pyobj()
+                self.last_receive_tstamp = time.time()
+
+                # Check for pause
+                async with self.is_pause_cond:
+                    while self.is_pause:
+                        await self.is_pause_cond.wait()
+
+                # Handle different output types
+                if isinstance(recv_obj, BatchTokenIDOutput):
+                    await self._handle_batch_output(recv_obj)
+                elif isinstance(recv_obj, BatchEmbeddingOutput):
+                    await self._handle_embedding_output(recv_obj)
+                elif isinstance(recv_obj, HealthCheckOutput):
+                    await self._handle_health_check_output(recv_obj)
+                else:
+                    logger.warning(f"Unknown output type: {type(recv_obj)}")
+
+            except zmq.error.Again:
+                # Timeout, check if we should exit
+                if self.gracefully_exit:
+                    break
+                continue
+            except zmq.error.ZMQError as e:
+                # Socket closed or other ZMQ error - exit cleanly if shutting down
+                if self.gracefully_exit:
+                    logger.debug(f"ZMQ recv interrupted during shutdown: {e}")
+                    break
+                logger.error(
+                    f"ZMQ error in handle loop: {e}\n{get_exception_traceback()}"
+                )
+                break
+            except Exception as e:
+                logger.error(f"Handle loop error: {e}\n{get_exception_traceback()}")
+                if self.gracefully_exit:
+                    break
+
+    def _convert_logprob_style(
+        self,
+        state: GrpcReqState,
+        batch_out: BatchTokenIDOutput,
+        batch_index: int,
+    ):
+        """
+        Convert and accumulate logprobs from batch output to state.
+        Follows the same logic as tokenizer_manager.convert_logprob_style.
+        """
+        # Early exit if no input logprobs at all
+        if batch_out.input_token_logprobs_val is None:
+            return
+
+        # Accumulate input token logprobs (only if list is non-empty)
+        if len(batch_out.input_token_logprobs_val) > 0:
+            state.input_token_logprobs_val.extend(
+                batch_out.input_token_logprobs_val[batch_index]
+            )
+            state.input_token_logprobs_idx.extend(
+                batch_out.input_token_logprobs_idx[batch_index]
+            )
+
+        # Always accumulate output token logprobs
+        state.output_token_logprobs_val.extend(
+            batch_out.output_token_logprobs_val[batch_index]
+        )
+        state.output_token_logprobs_idx.extend(
+            batch_out.output_token_logprobs_idx[batch_index]
+        )
+
+        # Handle top logprobs if requested
+        if state.obj.top_logprobs_num > 0:
+            # Accumulate input top logprobs (only if list is non-empty)
+            if len(batch_out.input_top_logprobs_val) > 0:
+                state.input_top_logprobs_val.extend(
+                    batch_out.input_top_logprobs_val[batch_index]
+                )
+                state.input_top_logprobs_idx.extend(
+                    batch_out.input_top_logprobs_idx[batch_index]
+                )
+
+            # Always accumulate output top logprobs
+            state.output_top_logprobs_val.extend(
+                batch_out.output_top_logprobs_val[batch_index]
+            )
+            state.output_top_logprobs_idx.extend(
+                batch_out.output_top_logprobs_idx[batch_index]
+            )
+
+    async def _handle_batch_output(self, batch_out: BatchTokenIDOutput):
+        """Handle batch generation output from scheduler."""
+        # Process each request in the batch
+        for i, rid in enumerate(batch_out.rids):
+            if rid not in self.rid_to_state:
+                continue
+
+            state = self.rid_to_state[rid]
+
+            # Update metrics
+            now = time.time()
+            if state.first_token_time == 0.0:
+                state.first_token_time = now
+            state.last_time = now
+
+            # Extract output for this request
+            output_data = {
+                "request_id": rid,
+                "token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
+                "finished": batch_out.finished_reasons[i] is not None,
+                "meta_info": {
+                    "prompt_tokens": (
+                        batch_out.prompt_tokens[i] if batch_out.prompt_tokens else 0
+                    ),
+                    "completion_tokens": (
+                        batch_out.completion_tokens[i]
+                        if batch_out.completion_tokens
+                        else 0
+                    ),
+                    "cached_tokens": (
+                        batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
+                    ),
+                    "finish_reason": (
+                        batch_out.finished_reasons[i]
+                        if batch_out.finished_reasons[i]
+                        else None
+                    ),
+                },
+            }
+
+            # Accumulate logprobs (following tokenizer_manager pattern)
+            if state.obj.return_logprob:
+                self._convert_logprob_style(state, batch_out, i)
+
+            # Send input logprobs if available
+            if (
+                state.obj.return_logprob
+                and state.obj.logprob_start_len >= 0
+                and state.input_token_logprobs_val
+            ):
+                if state.obj.stream and not state.input_logprobs_sent:
+                    # Streaming: send input logprobs once in first chunk that has them
+                    output_data["input_logprobs"] = {
+                        "token_logprobs_val": state.input_token_logprobs_val,
+                        "token_logprobs_idx": state.input_token_logprobs_idx,
+                        "top_logprobs_val": state.input_top_logprobs_val,
+                        "top_logprobs_idx": state.input_top_logprobs_idx,
+                    }
+                    state.input_logprobs_sent = True
+                elif not state.obj.stream and output_data["finished"]:
+                    # Non-streaming: send input logprobs in final chunk
+                    output_data["input_logprobs"] = {
+                        "token_logprobs_val": state.input_token_logprobs_val,
+                        "token_logprobs_idx": state.input_token_logprobs_idx,
+                        "top_logprobs_val": state.input_top_logprobs_val,
+                        "top_logprobs_idx": state.input_top_logprobs_idx,
+                    }
+
+            # Send output logprobs if available
+            if (
+                state.obj.return_logprob
+                and batch_out.output_token_logprobs_val
+                and i < len(batch_out.output_token_logprobs_val)
+            ):
+                if state.obj.stream:
+                    # For streaming: send incremental logprobs (only new tokens in this chunk)
+                    # NOTE: this is different from TokenizerManager, which always accumulates
+                    def get_part(attr_name):
+                        source_list = getattr(batch_out, attr_name, None)
+                        return (
+                            source_list[i]
+                            if source_list and i < len(source_list)
+                            else []
+                        )
+
+                    output_data["output_logprobs"] = {
+                        "token_logprobs_val": batch_out.output_token_logprobs_val[i],
+                        "token_logprobs_idx": get_part("output_token_logprobs_idx"),
+                        "top_logprobs_val": get_part("output_top_logprobs_val"),
+                        "top_logprobs_idx": get_part("output_top_logprobs_idx"),
+                    }
+                elif output_data["finished"]:
+                    # Non-streaming: send cumulative output logprobs in final chunk
+                    output_data["output_logprobs"] = {
+                        "token_logprobs_val": state.output_token_logprobs_val,
+                        "token_logprobs_idx": state.output_token_logprobs_idx,
+                        "top_logprobs_val": state.output_top_logprobs_val,
+                        "top_logprobs_idx": state.output_top_logprobs_idx,
+                    }
+
+            # Update state for accumulation
+            if output_data["token_ids"]:
+                state.output_ids.extend(output_data["token_ids"])
+
+            await state.out_queue.put(output_data)
+
+            # Handle completion
+            if output_data["finished"]:
+                state.finished = True
+                state.finished_time = now
+                state.stream_finished = True
+                state.event.set()
+
+                # Remove from tracking after a delay
+                async def cleanup():
+                    await asyncio.sleep(5.0)
+                    if rid in self.rid_to_state:
+                        del self.rid_to_state[rid]
+
+                asyncio.create_task(cleanup())
+
+    async def _handle_embedding_output(self, batch_out: BatchEmbeddingOutput):
+        """Handle batch embedding output from scheduler."""
+        for i, rid in enumerate(batch_out.rids):
+            if rid not in self.rid_to_state:
+                continue
+
+            state = self.rid_to_state[rid]
+
+            # Create result
+            result = {
+                "request_id": rid,
+                "embedding": batch_out.embeddings[i],
+                "prompt_tokens": (
+                    batch_out.prompt_tokens[i] if batch_out.prompt_tokens else 0
+                ),
+                "finish_reason": (
+                    batch_out.finish_reason[i] if batch_out.finish_reason else None
+                ),
+            }
+
+            # Send result
+            await state.out_queue.put(result)
+
+            # Mark as finished
+            state.finished = True
+            state.finished_time = time.time()
+            state.event.set()
+
+    async def _handle_health_check_output(self, health_out: HealthCheckOutput):
+        """Handle health check output from scheduler."""
+        rid = health_out.rid
+
+        if rid not in self.rid_to_state:
+            logger.warning(f"Health check output for unknown request: {rid}")
+            return
+
+        state = self.rid_to_state[rid]
+
+        # Create health check result
+        result = {
+            "request_id": rid,
+            "healthy": True,  # If we got a response, scheduler is healthy
+            "output_text": (
+                health_out.output_str if hasattr(health_out, "output_str") else ""
+            ),
+            "finish_reason": (
+                health_out.finish_reason
+                if hasattr(health_out, "finish_reason")
+                else "stop"
+            ),
+        }
+
+        # Send result
+        await state.out_queue.put(result)
+
+        # Mark as finished
+        state.finished = True
+        state.finished_time = time.time()
+        state.event.set()
+
+    async def _send_to_scheduler(self, obj):
+        """Send an object to the scheduler via ZMQ."""
+        try:
+            self.send_to_scheduler.send_pyobj(obj)
+        except Exception as e:
+            logger.error(f"Failed to send to scheduler: {e}")
+            raise
+
+    def record_request_for_crash_dump(self, obj):
+        """Record request for potential crash dump."""
+        if len(self.crash_dump_request_list) < 100:
+            self.crash_dump_request_list.append(
+                {
+                    "time": time.time(),
+                    "request_id": getattr(obj, "rid", "unknown"),
+                    "type": type(obj).__name__,
+                }
+            )
+
+    async def shutdown(self):
+        """Gracefully shut down the request manager."""
+        logger.info("Shutting down GrpcRequestManager")
+        self.gracefully_exit = True
+
+        # Cancel all asyncio tasks FIRST - this will interrupt blocked recv() calls
+        for task in list(self.asyncio_tasks):
+            if not task.done():
+                task.cancel()
+
+        # Give tasks a moment to process cancellation
+        if self.asyncio_tasks:
+            await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True)
+
+        # Cancel all pending requests
+        for rid, state in list(self.rid_to_state.items()):
+            if not state.finished:
+                await state.out_queue.put(
+                    {"error": "Server shutting down", "shutdown": True}
+                )
+                state.finished = True
+                state.event.set()
+
+        # Wait for tasks to complete
+        if self.asyncio_tasks:
+            await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True)
+
+        # Shutdown bootstrap server if running
+        if self.bootstrap_server:
+            logger.info("Shutting down bootstrap server")
+            try:
+                if hasattr(self.bootstrap_server, "shutdown"):
+                    if asyncio.iscoroutinefunction(self.bootstrap_server.shutdown):
+                        await self.bootstrap_server.shutdown()
+                    else:
+                        self.bootstrap_server.shutdown()
+            except Exception as e:
+                logger.warning(f"Error shutting down bootstrap server: {e}")
+
+        # Close ZMQ sockets
+        self.recv_from_scheduler.close()
+        self.send_to_scheduler.close()
+
+        # Terminate the ZMQ context - this is critical for asyncio loop to exit cleanly
+        self.context.term()
+
+        logger.info("GrpcRequestManager shutdown complete")
+
+    def get_server_info(self) -> Dict[str, Any]:
+        """Get server information for health checks."""
+        return {
+            "active_requests": len(self.rid_to_state),
+            "paused": self.is_pause,
+            "last_receive_time": self.last_receive_tstamp,
+        }
+
+    def auto_create_handle_loop(self):
+        """Automatically create and start the handle_loop task, matching TokenizerManager pattern."""
+        if self.no_create_loop:
+            return
+
+        self.no_create_loop = True
+        loop = asyncio.get_event_loop()
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.handle_loop))
+        )
+
+        self.event_loop = loop
+
+        # We cannot add a signal handler when the grpc manager is not in
+        # the main thread due to the CPython limitation.
+        if threading.current_thread() is threading.main_thread():
+            signal_handler = GrpcSignalHandler(self)
+            loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
+            # Update the signal handler for the process. It overrides the sigquit handler in the launch phase.
+            loop.add_signal_handler(
+                signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
+            )
+        else:
+            logger.warning(
+                "Signal handler is not added because the grpc request manager is "
+                "not in the main thread. This disables graceful shutdown of the "
+                "grpc request manager when SIGTERM is received."
+            )
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
+        )
+
+    async def sigterm_watchdog(self):
+        """Watchdog to handle SIGTERM gracefully, matching TokenizerManager pattern."""
+        while not self.gracefully_exit:
+            await asyncio.sleep(1.0)
+
+
+async def print_exception_wrapper(func):
+    """
+    Sometimes an asyncio function does not print its exception.
+    We wrap it again so the exception is logged and handled.
+    """
+    try:
+        await func()
+    except Exception:
+        traceback = get_exception_traceback()
+        logger.error(f"GrpcRequestManager hit an exception: {traceback}")
+        if hasattr(func, "__self__") and isinstance(func.__self__, GrpcRequestManager):
+            func.__self__.dump_requests_before_crash()
+        kill_process_tree(os.getpid(), include_parent=True)
+        sys.exit(1)
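
For orientation, here is a minimal sketch of how the new GrpcRequestManager could be driven directly. This is an editor's illustration, not part of the diff: it assumes a scheduler process is already wired to the ZMQ IPC endpoints in port_args (normally arranged by the gRPC server entry point), that PortArgs.init_new(server_args) is available as elsewhere in sglang, and that req is a fully populated TokenizedGenerateReqInput, whose exact field list is version-specific and therefore left to the caller.

    # Editor's sketch, not part of the package diff. Assumes a scheduler is
    # already serving the ZMQ IPC endpoints named in `port_args`, and that
    # `req` is a fully populated TokenizedGenerateReqInput (pre-tokenized
    # input plus sampling params); its fields vary by version, so it is
    # taken as a parameter here rather than constructed.
    import asyncio

    from sglang.srt.entrypoints.grpc_request_manager import GrpcRequestManager
    from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
    from sglang.srt.server_args import PortArgs, ServerArgs


    async def run_one(server_args: ServerArgs, req: TokenizedGenerateReqInput) -> None:
        # Assumed helper: PortArgs.init_new allocates the IPC names, as used
        # elsewhere in sglang.
        port_args = PortArgs.init_new(server_args)

        manager = GrpcRequestManager(server_args, port_args)
        manager.auto_create_handle_loop()  # spawns handle_loop + sigterm watchdog

        # generate_request is an async generator: with stream=True it yields one
        # dict per chunk (plus an "index" key when sampling n > 1); with
        # stream=False it yields a single final dict whose "token_ids" are the
        # accumulated state.output_ids.
        async for chunk in manager.generate_request(req):
            if isinstance(chunk, list):  # non-streaming n > 1: one aggregated batch
                print([c["meta_info"]["completion_tokens"] for c in chunk])
            else:
                print(chunk["token_ids"], chunk["meta_info"].get("finish_reason"))

        await manager.shutdown()

One behavior worth noting from the diff itself: on the streaming path, output_logprobs are incremental per chunk (unlike TokenizerManager, which always accumulates), so a client that wants cumulative logprobs must concatenate the chunks on its side.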