sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/entrypoints/grpc_request_manager.py (new file)
@@ -0,0 +1,855 @@
+ """
+ gRPC Request Manager - Orchestrates request lifecycle without tokenization.
+ Mimics TokenizerManager's state management and ZMQ communication patterns.
+ """
+
+ import asyncio
+ import copy
+ import dataclasses
+ import logging
+ import os
+ import signal
+ import sys
+ import threading
+ import time
+ import uuid
+ from typing import Any, AsyncGenerator, Dict, List, Optional, Union
+
+ import grpc
+ import zmq
+ import zmq.asyncio
+
+ from sglang.srt.managers.io_struct import (
+     AbortReq,
+     BatchEmbeddingOutput,
+     BatchTokenIDOutput,
+     HealthCheckOutput,
+     TokenizedEmbeddingReqInput,
+     TokenizedGenerateReqInput,
+ )
+ from sglang.srt.server_args import PortArgs, ServerArgs
+ from sglang.srt.utils import get_zmq_socket, kill_process_tree
+ from sglang.utils import get_exception_traceback
+
+ logger = logging.getLogger(__name__)
+
+
+ class GrpcSignalHandler:
+     """Minimal signal handler for gRPC server - delegates real crash handling to scheduler."""
+
+     def __init__(self, grpc_manager):
+         self.grpc_manager = grpc_manager
+
+     def sigterm_handler(self, signum=None, frame=None):
+         """Handle SIGTERM by gracefully shutting down gRPC server."""
+         logger.warning(
+             f"SIGTERM received. {signum=} {frame=}. Shutting down gRPC server..."
+         )
+         self.grpc_manager.gracefully_exit = True
+
+     def running_phase_sigquit_handler(self, signum=None, frame=None):
+         """Handle SIGQUIT from failed scheduler process."""
+         logger.error(
+             "Received SIGQUIT from scheduler process. Scheduler failed, shutting down gRPC server."
+         )
+         logger.info(
+             "Note: Crash dumps are handled by the scheduler process, not the gRPC server."
+         )
+         # Just exit cleanly - the scheduler handles crash dumps
+         kill_process_tree(os.getpid(), include_parent=True)
+
+
+ @dataclasses.dataclass
+ class GrpcReqState:
+     """State tracking for a gRPC request."""
+
+     # Request identification
+     request_id: str
+     grpc_context: Optional[grpc.aio.ServicerContext]
+
+     # Communication
+     out_queue: asyncio.Queue
+     finished: bool
+     event: asyncio.Event
+     obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]
+
+     # Metrics (same as TokenizerManager's ReqState)
+     created_time: float
+     finished_time: float = 0.0
+     first_token_time: float = 0.0
+     last_time: float = 0.0
+     last_completion_tokens: int = 1
+
+     # Streaming state
+     stream_finished: bool = False
+     input_logprobs_sent: bool = False  # Track if input logprobs were sent in streaming
+
+     # Token accumulation (for non-streaming)
+     output_ids: List[int] = dataclasses.field(default_factory=list)
+     input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
+     input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
+     output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
+     output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
+     input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
+     input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
+     output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
+     output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
+
+     # Session state
+     session_id: Optional[str] = None
+     is_session_request: bool = False
+
+
+ class GrpcRequestManager:
+     """
+     Manages gRPC request lifecycle, mimicking TokenizerManager's orchestration
+     behaviors without tokenization.
+     """
+
+     def __init__(
+         self,
+         server_args: ServerArgs,
+         port_args: PortArgs,
+         bootstrap_server=None,
+     ):
+         """Initialize the gRPC request manager."""
+         self.server_args = server_args
+         self.port_args = port_args
+
+         # ZMQ Communication Setup (same pattern as TokenizerManager)
+         self.context = zmq.asyncio.Context(2)
+
+         # Socket for receiving outputs from scheduler
+         self.recv_from_scheduler = get_zmq_socket(
+             self.context, zmq.PULL, port_args.detokenizer_ipc_name, bind=True
+         )
+
+         # Socket for sending requests to scheduler
+         self.send_to_scheduler = get_zmq_socket(
+             self.context, zmq.PUSH, port_args.scheduler_input_ipc_name, bind=True
+         )
+
+         # State Management (from TokenizerManager)
+         self.rid_to_state: Dict[str, GrpcReqState] = {}
+         self.asyncio_tasks: set = set()
+         self.gracefully_exit = False
+         self.no_create_loop = False
+         self.event_loop = None
+
+         # Pause/Resume Control
+         self.is_pause = False
+         self.is_pause_cond = asyncio.Condition()
+
+         # Metrics
+         self.last_receive_tstamp = time.time()
+
+         # Crash dump for debugging
+         self.crash_dump_request_list = []
+         self.crash_dump_performed = False
+
+         # Bootstrap server (passed from serve_grpc, not started here)
+         self.bootstrap_server = bootstrap_server
+
+         logger.info(
+             f"GrpcRequestManager initialized with ZMQ IPC: "
+             f"recv={port_args.detokenizer_ipc_name}, "
+             f"send={port_args.scheduler_input_ipc_name}"
+         )
+         if self.bootstrap_server:
+             logger.info(
+                 f"Bootstrap server initialized for disaggregation mode: "
+                 f"{server_args.disaggregation_mode}"
+             )
+
+     async def generate_request(
+         self,
+         obj: TokenizedGenerateReqInput,
+         request_id: Optional[str] = None,
+         grpc_context: Optional[grpc.aio.ServicerContext] = None,
+     ) -> AsyncGenerator[Union[Dict, List[Dict]], None]:
+         """
+         Submit a generation request to the scheduler with n>1 parallel sampling support.
+
+         This method implements the same two-phase approach as tokenizer_manager.py:
+         1. Phase 1: Send prefix caching request (max_new_tokens=0)
+         2. Phase 2: Send n generation requests that reuse the cached prefix
+
+         Yields individual responses for streaming, or aggregated responses for non-streaming.
+         """
+         n = getattr(obj.sampling_params, "n", 1)
+
+         if n <= 1:
+             async for response in self._handle_single_request(
+                 obj, request_id, grpc_context
+             ):
+                 yield response
+             return
+
+         # N>1 handling - two-phase approach
+         logger.debug(f"Multiple sampling request (n={n}), using two-phase approach")
+
+         # Generate base request ID if not provided
+         if request_id is None:
+             base_request_id = f"grpc-{uuid.uuid4().hex}"
+         else:
+             base_request_id = request_id
+
+         # Phase 1: Cache the common prefix
+         logger.debug(f"Phase 1: Caching prefix for request {base_request_id}")
+         prefix_obj = copy.copy(obj)
+         prefix_obj.sampling_params = copy.copy(obj.sampling_params)
+         prefix_obj.sampling_params.max_new_tokens = 0  # Prefill-only
+         prefix_obj.sampling_params.n = 1  # Don't replicate prefix request
+
+         # Send prefix caching request and consume response
+         async for _ in self._handle_single_request(
+             prefix_obj, f"{base_request_id}-prefix", grpc_context
+         ):
+             # Consume prefix response (usually just one chunk with finish_reason)
+             pass
+
+         logger.debug(f"Phase 1 completed: Prefix cached for {base_request_id}")
+
+         # Phase 2: Generate n parallel requests
+         logger.debug(f"Phase 2: Generating {n} parallel requests")
+         generators = []
+         request_ids = []
+
+         for i in range(n):
+             # Create individual generation request
+             gen_obj = copy.copy(obj)
+             gen_obj.sampling_params = copy.copy(obj.sampling_params)
+             gen_obj.sampling_params.n = 1  # Each request generates 1 response
+
+             gen_request_id = f"{base_request_id}-{i}"
+             request_ids.append(gen_request_id)
+
+             # Start generation request
+             generators.append(
+                 self._handle_single_request(gen_obj, gen_request_id, grpc_context)
+             )
+
+         # Handle response aggregation
+         is_stream = getattr(obj, "stream", False)
+
+         if not is_stream:
+             # Non-streaming: collect all responses and return as batch
+             logger.debug(f"Non-streaming mode: collecting {n} responses")
+             responses = []
+             for generator in generators:
+                 async for response in generator:
+                     responses.append(response)
+             yield responses  # Return all responses as a batch
+         else:
+             # Streaming mode: multiplex responses with index for ordering
+             logger.debug(f"Streaming mode: multiplexing {n} streams")
+             rid_to_index = {rid: i for i, rid in enumerate(request_ids)}
+
+             # Create async tasks for all generators
+             task_map = {}
+             for generator in generators:
+                 task = asyncio.create_task(generator.__anext__())
+                 task_map[task] = generator
+
+             # Process responses as they arrive
+             while task_map:
+                 done, _ = await asyncio.wait(
+                     task_map.keys(), return_when=asyncio.FIRST_COMPLETED
+                 )
+
+                 for task in done:
+                     generator = task_map.pop(task)
+                     try:
+                         response = await task
+
+                         # Add index for client-side ordering
+                         if isinstance(response, dict) and "meta_info" in response:
+                             response_rid = response["meta_info"].get("id", "")
+                             if response_rid in rid_to_index:
+                                 response["index"] = rid_to_index[response_rid]
+
+                         yield response
+
+                         # Create next task for this generator
+                         next_task = asyncio.create_task(generator.__anext__())
+                         task_map[next_task] = generator
+
+                     except StopAsyncIteration:
+                         # This generator is finished
+                         pass
+
+     async def _handle_single_request(
+         self,
+         obj: TokenizedGenerateReqInput,
+         request_id: Optional[str] = None,
+         grpc_context: Optional[grpc.aio.ServicerContext] = None,
+     ):
+         """Handle a single request - core implementation without n>1 logic."""
+         # Generate request ID if not provided
+         if request_id is None:
+             request_id = f"grpc-{uuid.uuid4().hex}"
+
+         obj.rid = request_id
+
+         # Create and register request state
+         # TODO: support log_request
+         state = GrpcReqState(
+             request_id=request_id,
+             grpc_context=grpc_context,
+             out_queue=asyncio.Queue(),
+             finished=False,
+             event=asyncio.Event(),
+             obj=obj,
+             created_time=time.time(),
+         )
+
+         # Track session if needed
+         if hasattr(obj, "session_params") and obj.session_params:
+             state.session_id = obj.session_params.session_id
+             state.is_session_request = True
+
+         self.rid_to_state[request_id] = state
+         self.record_request_for_crash_dump(obj)
+
+         try:
+             # Send to scheduler - let exceptions bubble up to grpc_server.py
+             await self._send_to_scheduler(obj)
+
+             is_stream = getattr(obj, "stream", False)
+
+             while True:
+                 # Client cancelled - notify scheduler and exit
+                 if grpc_context and grpc_context.cancelled():
+                     await self.abort_request(request_id)
+                     return
+
+                 try:
+                     response = await asyncio.wait_for(state.out_queue.get(), timeout=4)
+
+                     if is_stream:
+                         yield response
+
+                     # Non-streaming: yield final response with accumulated tokens from state
+                     if isinstance(response, dict) and response.get("finished", False):
+                         if not is_stream:
+                             final_response = response.copy()
+                             final_response["token_ids"] = state.output_ids
+                             yield final_response
+                         break
+
+                 except asyncio.TimeoutError:
+                     # Timeout waiting for response - abort and cleanup
+                     logger.warning(
+                         f"Timeout waiting for response for request {request_id}"
+                     )
+                     await self.abort_request(request_id)
+                     return
+
+         finally:
+             # Always clean up request state when exiting
+             self._cleanup_request_state(request_id)
+
+     def _cleanup_request_state(self, request_id: str):
+         """Clean up local request state (does not notify scheduler)."""
+         if request_id in self.rid_to_state:
+             del self.rid_to_state[request_id]
+
+     async def embedding_request(
+         self,
+         obj: TokenizedEmbeddingReqInput,
+         request_id: Optional[str] = None,
+     ) -> asyncio.Future:
+         """
+         Submit an embedding request to the scheduler.
+         Returns a future that will contain the embedding result.
+         """
+         # Generate request ID if not provided
+         if request_id is None:
+             request_id = f"grpc-embed-{uuid.uuid4().hex}"
+
+         obj.rid = request_id
+
+         # Create request state
+         state = GrpcReqState(
+             request_id=request_id,
+             grpc_context=None,
+             out_queue=asyncio.Queue(),
+             finished=False,
+             event=asyncio.Event(),
+             obj=obj,
+             created_time=time.time(),
+         )
+
+         # Register state
+         self.rid_to_state[request_id] = state
+
+         # Create future for result
+         future = asyncio.Future()
+
+         # Send to scheduler
+         try:
+             await self._send_to_scheduler(obj)
+         except Exception as e:
+             del self.rid_to_state[request_id]
+             future.set_exception(e)
+             return future
+
+         # Wait for result in background
+         async def wait_for_result():
+             try:
+                 # Wait for completion
+                 await state.event.wait()
+                 # Get result from queue
+                 result = await state.out_queue.get()
+                 future.set_result(result)
+             except Exception as e:
+                 future.set_exception(e)
+             finally:
+                 # Clean up
+                 if request_id in self.rid_to_state:
+                     del self.rid_to_state[request_id]
+
+         asyncio.create_task(wait_for_result())
+         return future
+
+     async def abort_request(self, request_id: str) -> bool:
+         """Abort a running request."""
+         if request_id not in self.rid_to_state:
+             return False
+
+         # Send abort to scheduler
+         abort_req = AbortReq(rid=request_id)
+         try:
+             await self._send_to_scheduler(abort_req)
+         except Exception as e:
+             logger.error(f"Failed to send abort request: {e}")
+             return False
+
+         # Mark as finished
+         state = self.rid_to_state.get(request_id)
+         if state:
+             state.finished = True
+             state.stream_finished = True
+             state.event.set()
+
+             # Send abort notification to output queue
+             await state.out_queue.put({"error": "Request aborted", "abort": True})
+
+         return True
+
+     async def pause_generation(self):
+         """Pause generation processing."""
+         async with self.is_pause_cond:
+             self.is_pause = True
+             logger.info("Generation paused")
+
+     async def resume_generation(self):
+         """Resume generation processing."""
+         async with self.is_pause_cond:
+             self.is_pause = False
+             self.is_pause_cond.notify_all()
+             logger.info("Generation resumed")
+
+     async def handle_loop(self):
+         """
+         Main event loop - processes outputs from scheduler.
+         Mimics TokenizerManager's handle_loop.
+         """
+         while not self.gracefully_exit:
+             try:
+                 # Receive from scheduler
+                 recv_obj = await self.recv_from_scheduler.recv_pyobj()
+                 self.last_receive_tstamp = time.time()
+
+                 # Check for pause
+                 async with self.is_pause_cond:
+                     while self.is_pause:
+                         await self.is_pause_cond.wait()
+
+                 # Handle different output types
+                 if isinstance(recv_obj, BatchTokenIDOutput):
+                     await self._handle_batch_output(recv_obj)
+                 elif isinstance(recv_obj, BatchEmbeddingOutput):
+                     await self._handle_embedding_output(recv_obj)
+                 elif isinstance(recv_obj, HealthCheckOutput):
+                     await self._handle_health_check_output(recv_obj)
+                 else:
+                     logger.warning(f"Unknown output type: {type(recv_obj)}")
+
+             except zmq.error.Again:
+                 # Timeout, check if we should exit
+                 if self.gracefully_exit:
+                     break
+                 continue
+             except zmq.error.ZMQError as e:
+                 # Socket closed or other ZMQ error - exit cleanly if shutting down
+                 if self.gracefully_exit:
+                     logger.debug(f"ZMQ recv interrupted during shutdown: {e}")
+                     break
+                 logger.error(
+                     f"ZMQ error in handle loop: {e}\n{get_exception_traceback()}"
+                 )
+                 break
+             except Exception as e:
+                 logger.error(f"Handle loop error: {e}\n{get_exception_traceback()}")
+                 if self.gracefully_exit:
+                     break
+
+     def _convert_logprob_style(
+         self,
+         state: GrpcReqState,
+         batch_out: BatchTokenIDOutput,
+         batch_index: int,
+     ):
+         """
+         Convert and accumulate logprobs from batch output to state.
+         Follows the same logic as tokenizer_manager.convert_logprob_style.
+         """
+         # Early exit if no input logprobs at all
+         if batch_out.input_token_logprobs_val is None:
+             return
+
+         # Accumulate input token logprobs (only if list is non-empty)
+         if len(batch_out.input_token_logprobs_val) > 0:
+             state.input_token_logprobs_val.extend(
+                 batch_out.input_token_logprobs_val[batch_index]
+             )
+             state.input_token_logprobs_idx.extend(
+                 batch_out.input_token_logprobs_idx[batch_index]
+             )
+
+         # Always accumulate output token logprobs
+         state.output_token_logprobs_val.extend(
+             batch_out.output_token_logprobs_val[batch_index]
+         )
+         state.output_token_logprobs_idx.extend(
+             batch_out.output_token_logprobs_idx[batch_index]
+         )
+
+         # Handle top logprobs if requested
+         if state.obj.top_logprobs_num > 0:
+             # Accumulate input top logprobs (only if list is non-empty)
+             if len(batch_out.input_top_logprobs_val) > 0:
+                 state.input_top_logprobs_val.extend(
+                     batch_out.input_top_logprobs_val[batch_index]
+                 )
+                 state.input_top_logprobs_idx.extend(
+                     batch_out.input_top_logprobs_idx[batch_index]
+                 )
+
+             # Always accumulate output top logprobs
+             state.output_top_logprobs_val.extend(
+                 batch_out.output_top_logprobs_val[batch_index]
+             )
+             state.output_top_logprobs_idx.extend(
+                 batch_out.output_top_logprobs_idx[batch_index]
+             )
+
+     async def _handle_batch_output(self, batch_out: BatchTokenIDOutput):
+         """Handle batch generation output from scheduler."""
+         # Process each request in the batch
+         for i, rid in enumerate(batch_out.rids):
+             if rid not in self.rid_to_state:
+                 continue
+
+             state = self.rid_to_state[rid]
+
+             # Update metrics
+             now = time.time()
+             if state.first_token_time == 0.0:
+                 state.first_token_time = now
+             state.last_time = now
+
+             # Extract output for this request
+             output_data = {
+                 "request_id": rid,
+                 "token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
+                 "finished": batch_out.finished_reasons[i] is not None,
+                 "meta_info": {
+                     "prompt_tokens": (
+                         batch_out.prompt_tokens[i] if batch_out.prompt_tokens else 0
+                     ),
+                     "completion_tokens": (
+                         batch_out.completion_tokens[i]
+                         if batch_out.completion_tokens
+                         else 0
+                     ),
+                     "cached_tokens": (
+                         batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
+                     ),
+                     "finish_reason": (
+                         batch_out.finished_reasons[i]
+                         if batch_out.finished_reasons[i]
+                         else None
+                     ),
+                 },
+             }
+
+             # Accumulate logprobs (following tokenizer_manager pattern)
+             if state.obj.return_logprob:
+                 self._convert_logprob_style(state, batch_out, i)
+
+             # Send input logprobs if available
+             if (
+                 state.obj.return_logprob
+                 and state.obj.logprob_start_len >= 0
+                 and state.input_token_logprobs_val
+             ):
+                 if state.obj.stream and not state.input_logprobs_sent:
+                     # Streaming: send input logprobs once in first chunk that has them
+                     output_data["input_logprobs"] = {
+                         "token_logprobs_val": state.input_token_logprobs_val,
+                         "token_logprobs_idx": state.input_token_logprobs_idx,
+                         "top_logprobs_val": state.input_top_logprobs_val,
+                         "top_logprobs_idx": state.input_top_logprobs_idx,
+                     }
+                     state.input_logprobs_sent = True
+                 elif not state.obj.stream and output_data["finished"]:
+                     # Non-streaming: send input logprobs in final chunk
+                     output_data["input_logprobs"] = {
+                         "token_logprobs_val": state.input_token_logprobs_val,
+                         "token_logprobs_idx": state.input_token_logprobs_idx,
+                         "top_logprobs_val": state.input_top_logprobs_val,
+                         "top_logprobs_idx": state.input_top_logprobs_idx,
+                     }
+
+             # Send output logprobs if available
+             if (
+                 state.obj.return_logprob
+                 and batch_out.output_token_logprobs_val
+                 and i < len(batch_out.output_token_logprobs_val)
+             ):
+                 if state.obj.stream:
+                     # For streaming: send incremental logprobs (only new tokens in this chunk)
+                     # NOTE: this is different than TokenizerManager, which always accumulates
+                     def get_part(attr_name):
+                         source_list = getattr(batch_out, attr_name, None)
+                         return (
+                             source_list[i]
+                             if source_list and i < len(source_list)
+                             else []
+                         )
+
+                     output_data["output_logprobs"] = {
+                         "token_logprobs_val": batch_out.output_token_logprobs_val[i],
+                         "token_logprobs_idx": get_part("output_token_logprobs_idx"),
+                         "top_logprobs_val": get_part("output_top_logprobs_val"),
+                         "top_logprobs_idx": get_part("output_top_logprobs_idx"),
+                     }
+                 elif output_data["finished"]:
+                     # Non-streaming: send cumulative output logprobs in final chunk
+                     output_data["output_logprobs"] = {
+                         "token_logprobs_val": state.output_token_logprobs_val,
+                         "token_logprobs_idx": state.output_token_logprobs_idx,
+                         "top_logprobs_val": state.output_top_logprobs_val,
+                         "top_logprobs_idx": state.output_top_logprobs_idx,
+                     }
+
+             # Update state for accumulation
+             if output_data["token_ids"]:
+                 state.output_ids.extend(output_data["token_ids"])
+
+             await state.out_queue.put(output_data)
+
+             # Handle completion
+             if output_data["finished"]:
+                 state.finished = True
+                 state.finished_time = now
+                 state.stream_finished = True
+                 state.event.set()
+
+                 # Remove from tracking after a delay
+                 async def cleanup():
+                     await asyncio.sleep(5.0)
+                     if rid in self.rid_to_state:
+                         del self.rid_to_state[rid]
+
+                 asyncio.create_task(cleanup())
+
+     async def _handle_embedding_output(self, batch_out: BatchEmbeddingOutput):
+         """Handle batch embedding output from scheduler."""
+         for i, rid in enumerate(batch_out.rids):
+             if rid not in self.rid_to_state:
+                 continue
+
+             state = self.rid_to_state[rid]
+
+             # Create result
+             result = {
+                 "request_id": rid,
+                 "embedding": batch_out.embeddings[i],
+                 "prompt_tokens": (
+                     batch_out.prompt_tokens[i] if batch_out.prompt_tokens else 0
+                 ),
+                 "finish_reason": (
+                     batch_out.finish_reason[i] if batch_out.finish_reason else None
+                 ),
+             }
+
+             # Send result
+             await state.out_queue.put(result)
+
+             # Mark as finished
+             state.finished = True
+             state.finished_time = time.time()
+             state.event.set()
+
+     async def _handle_health_check_output(self, health_out: HealthCheckOutput):
+         """Handle health check output from scheduler."""
+         rid = health_out.rid
+
+         if rid not in self.rid_to_state:
+             logger.warning(f"Health check output for unknown request: {rid}")
+             return
+
+         state = self.rid_to_state[rid]
+
+         # Create health check result
+         result = {
+             "request_id": rid,
+             "healthy": True,  # If we got a response, scheduler is healthy
+             "output_text": (
+                 health_out.output_str if hasattr(health_out, "output_str") else ""
+             ),
+             "finish_reason": (
+                 health_out.finish_reason
+                 if hasattr(health_out, "finish_reason")
+                 else "stop"
+             ),
+         }
+
+         # Send result
+         await state.out_queue.put(result)
+
+         # Mark as finished
+         state.finished = True
+         state.finished_time = time.time()
+         state.event.set()
+
+     async def _send_to_scheduler(self, obj):
+         """Send an object to the scheduler via ZMQ."""
+         try:
+             self.send_to_scheduler.send_pyobj(obj)
+         except Exception as e:
+             logger.error(f"Failed to send to scheduler: {e}")
+             raise
+
+     def record_request_for_crash_dump(self, obj):
+         """Record request for potential crash dump."""
+         if len(self.crash_dump_request_list) < 100:
+             self.crash_dump_request_list.append(
+                 {
+                     "time": time.time(),
+                     "request_id": getattr(obj, "rid", "unknown"),
+                     "type": type(obj).__name__,
+                 }
+             )
+
+     async def shutdown(self):
+         """Gracefully shutdown the request manager."""
+         logger.info("Shutting down GrpcRequestManager")
+         self.gracefully_exit = True
+
+         # Cancel all asyncio tasks FIRST - this will interrupt blocked recv() calls
+         for task in list(self.asyncio_tasks):
+             if not task.done():
+                 task.cancel()
+
+         # Give tasks a moment to process cancellation
+         if self.asyncio_tasks:
+             await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True)
+
+         # Cancel all pending requests
+         for rid, state in list(self.rid_to_state.items()):
+             if not state.finished:
+                 await state.out_queue.put(
+                     {"error": "Server shutting down", "shutdown": True}
+                 )
+                 state.finished = True
+                 state.event.set()
+
+         # Wait for tasks to complete
+         if self.asyncio_tasks:
+             await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True)
+
+         # Shutdown bootstrap server if running
+         if self.bootstrap_server:
+             logger.info("Shutting down bootstrap server")
+             try:
+                 if hasattr(self.bootstrap_server, "shutdown"):
+                     if asyncio.iscoroutinefunction(self.bootstrap_server.shutdown):
+                         await self.bootstrap_server.shutdown()
+                     else:
+                         self.bootstrap_server.shutdown()
+             except Exception as e:
+                 logger.warning(f"Error shutting down bootstrap server: {e}")
+
+         # Close ZMQ sockets
+         self.recv_from_scheduler.close()
+         self.send_to_scheduler.close()
+
+         # Terminate the ZMQ context - this is critical for asyncio loop to exit cleanly
+         self.context.term()
+
+         logger.info("GrpcRequestManager shutdown complete")
+
+     def get_server_info(self) -> Dict[str, Any]:
+         """Get server information for health checks."""
+         return {
+             "active_requests": len(self.rid_to_state),
+             "paused": self.is_pause,
+             "last_receive_time": self.last_receive_tstamp,
+         }
+
+     def auto_create_handle_loop(self):
+         """Automatically create and start the handle_loop task, matching TokenizerManager pattern."""
+         if self.no_create_loop:
+             return
+
+         self.no_create_loop = True
+         loop = asyncio.get_event_loop()
+         self.asyncio_tasks.add(
+             loop.create_task(print_exception_wrapper(self.handle_loop))
+         )
+
+         self.event_loop = loop
+
+         # We cannot add signal handler when the grpc manager is not in
+         # the main thread due to the CPython limitation.
+         if threading.current_thread() is threading.main_thread():
+             signal_handler = GrpcSignalHandler(self)
+             loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
+             # Update the signal handler for the process. It overrides the sigquit handler in the launch phase.
+             loop.add_signal_handler(
+                 signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
+             )
+         else:
+             logger.warning(
+                 "Signal handler is not added because the grpc request manager is "
+                 "not in the main thread. This disables graceful shutdown of the "
+                 "grpc request manager when SIGTERM is received."
+             )
+         self.asyncio_tasks.add(
+             loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
+         )
+
+     async def sigterm_watchdog(self):
+         """Watchdog to handle SIGTERM gracefully, matching TokenizerManager pattern."""
+         while not self.gracefully_exit:
+             await asyncio.sleep(1.0)
+
+
+ async def print_exception_wrapper(func):
+     """
+     Sometimes an asyncio function does not print its exception.
+     Wrap the call so the exception is logged before the process exits.
+     """
+     try:
+         await func()
+     except Exception:
+         traceback = get_exception_traceback()
+         logger.error(f"GrpcRequestManager hit an exception: {traceback}")
+         if hasattr(func, "__self__") and isinstance(func.__self__, GrpcRequestManager):
+             func.__self__.dump_requests_before_crash()
+         kill_process_tree(os.getpid(), include_parent=True)
+         sys.exit(1)
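
For orientation, here is a minimal sketch of how a caller might drive the new manager end to end. It is a hedged illustration, not release code: make_tokenized_request is a hypothetical helper standing in for however the caller builds a TokenizedGenerateReqInput (its constructor is not shown in this diff), and the sketch assumes a scheduler process is already bound to the matching ZMQ endpoints, which in 0.5.3 is the job of serve_grpc in grpc_server.py.

from sglang.srt.entrypoints.grpc_request_manager import GrpcRequestManager
from sglang.srt.server_args import PortArgs, ServerArgs


async def run_one_request(server_args: ServerArgs, port_args: PortArgs):
    # Assumes the scheduler side of the ZMQ socket pair is already up.
    manager = GrpcRequestManager(server_args, port_args)
    manager.auto_create_handle_loop()  # starts handle_loop + sigterm_watchdog

    # Hypothetical helper (not part of this diff): builds a pre-tokenized
    # request with sampling_params.n = 2 and stream = True, so
    # generate_request takes the two-phase prefix-caching path.
    obj = make_tokenized_request(n=2, stream=True)

    try:
        # With stream=True and n>1, chunks from the sub-requests are
        # multiplexed; the added "index" field tells the client which
        # sample each chunk belongs to.
        async for chunk in manager.generate_request(obj):
            print(chunk.get("index"), chunk.get("token_ids"), chunk.get("finished"))
    finally:
        await manager.shutdown()

Non-streaming callers instead receive one aggregated list of final responses, and embeddings go through embedding_request, which resolves an asyncio.Future rather than yielding chunks.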