sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482) hide show
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,915 @@
1
+ """
2
+ gRPC Request Manager - Orchestrates request lifecycle without tokenization.
3
+ Mimics TokenizerManager's state management and ZMQ communication patterns.
4
+ """
5
+
6
+ import asyncio
7
+ import copy
8
+ import dataclasses
9
+ import logging
10
+ import os
11
+ import signal
12
+ import sys
13
+ import threading
14
+ import time
15
+ import uuid
16
+ from typing import Any, AsyncGenerator, Dict, List, Optional, Union
17
+
18
+ import grpc
19
+ import zmq
20
+ import zmq.asyncio
21
+
22
+ from sglang.srt.managers.io_struct import (
23
+ AbortReq,
24
+ BatchEmbeddingOutput,
25
+ BatchTokenIDOutput,
26
+ HealthCheckOutput,
27
+ TokenizedEmbeddingReqInput,
28
+ TokenizedGenerateReqInput,
29
+ )
30
+ from sglang.srt.server_args import PortArgs, ServerArgs
31
+ from sglang.srt.utils import get_zmq_socket, kill_process_tree
32
+ from sglang.utils import get_exception_traceback
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
class GrpcSignalHandler:
    """Signal callbacks for the gRPC server process.

    Crash dumps are left to the scheduler process; the callbacks here only
    flag a graceful exit (SIGTERM) or tear the process tree down (SIGQUIT).
    """

    def __init__(self, grpc_manager):
        # Manager whose ``gracefully_exit`` flag is flipped on SIGTERM.
        self.grpc_manager = grpc_manager

    def sigterm_handler(self, signum=None, frame=None):
        """Log the SIGTERM and ask the manager to exit gracefully."""
        message = f"SIGTERM received. {signum=} {frame=}. Shutting down gRPC server..."
        logger.warning(message)
        self.grpc_manager.gracefully_exit = True

    def running_phase_sigquit_handler(self, signum=None, frame=None):
        """Log the scheduler failure, then kill this process and its children."""
        logger.error(
            "Received SIGQUIT from scheduler process. Scheduler failed, shutting down gRPC server."
        )
        logger.info(
            "Note: Crash dumps are handled by the scheduler process, not the gRPC server."
        )
        # No dump is produced here; the process tree (parent included) is killed.
        own_pid = os.getpid()
        kill_process_tree(own_pid, include_parent=True)
60
+
61
+
62
@dataclasses.dataclass
class GrpcReqState:
    """Bookkeeping record the request manager keeps for one in-flight request."""

    # -- Identity ---------------------------------------------------------
    request_id: str
    grpc_context: Optional[grpc.aio.ServicerContext]

    # -- Hand-off to the consumer -----------------------------------------
    # Output dicts are pushed onto ``out_queue``; ``event`` is set once the
    # request reaches a terminal state.
    out_queue: asyncio.Queue
    finished: bool
    event: asyncio.Event
    obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]

    # -- Timing metrics (mirrors TokenizerManager's ReqState) -------------
    created_time: float
    finished_time: float = 0.0
    first_token_time: float = 0.0
    last_time: float = 0.0
    last_completion_tokens: int = 1

    # -- Streaming flags --------------------------------------------------
    stream_finished: bool = False
    # True once input logprobs have been emitted in a streamed chunk.
    input_logprobs_sent: bool = False

    # -- Accumulated tokens / logprobs (used for non-streaming replies) ---
    output_ids: List[int] = dataclasses.field(default_factory=list)
    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)

    # -- Session tracking -------------------------------------------------
    session_id: Optional[str] = None
    is_session_request: bool = False
101
+
102
+
103
+ class GrpcRequestManager:
104
+ """
105
+ Manages gRPC request lifecycle, mimicking TokenizerManager's orchestration
106
+ behaviors without tokenization.
107
+ """
108
+
109
def __init__(
    self,
    server_args: ServerArgs,
    port_args: PortArgs,
    bootstrap_server=None,
):
    """Create the ZMQ sockets and the in-memory state used by the manager.

    Args:
        server_args: Server configuration; stored as-is.
        port_args: Supplies the IPC names both sockets bind to.
        bootstrap_server: Optional, already-started bootstrap server that
            this manager only stores (and later shuts down).
    """
    self.server_args = server_args
    self.port_args = port_args

    # ZMQ plumbing, following the TokenizerManager pattern.
    zmq_context = zmq.asyncio.Context(2)
    self.context = zmq_context
    recv_ipc_name = port_args.detokenizer_ipc_name
    send_ipc_name = port_args.scheduler_input_ipc_name

    # PULL side: outputs coming back from the scheduler.
    self.recv_from_scheduler = get_zmq_socket(
        zmq_context, zmq.PULL, recv_ipc_name, bind=True
    )
    # PUSH side: requests going to the scheduler.
    self.send_to_scheduler = get_zmq_socket(
        zmq_context, zmq.PUSH, send_ipc_name, bind=True
    )

    # Request bookkeeping and lifecycle flags.
    self.rid_to_state: Dict[str, GrpcReqState] = {}
    self.asyncio_tasks: set = set()
    self.gracefully_exit = False
    self.no_create_loop = False
    self.event_loop = None

    # Pause / resume gate consulted by the handle loop.
    self.is_pause = False
    self.is_pause_cond = asyncio.Condition()

    # Timestamp of the most recent message received from the scheduler.
    self.last_receive_tstamp = time.time()

    # Crash-dump bookkeeping.
    self.crash_dump_request_list = []
    self.crash_dump_performed = False

    # Provided by the caller; not started here.
    self.bootstrap_server = bootstrap_server

    logger.info(
        f"GrpcRequestManager initialized with ZMQ IPC: "
        f"recv={port_args.detokenizer_ipc_name}, "
        f"send={port_args.scheduler_input_ipc_name}"
    )
    if self.bootstrap_server:
        logger.info(
            f"Bootstrap server initialized for disaggregation mode: "
            f"{server_args.disaggregation_mode}"
        )
163
+
164
async def generate_request(
    self,
    obj: TokenizedGenerateReqInput,
    request_id: Optional[str] = None,
    grpc_context: Optional[grpc.aio.ServicerContext] = None,
) -> AsyncGenerator[Union[Dict, List[Dict]], None]:
    """Submit a generation request, with support for ``n > 1`` sampling.

    For ``n <= 1`` the request is forwarded unchanged to
    ``_handle_single_request``. For ``n > 1`` a two-phase scheme is used:

    1. A prefix request (``max_new_tokens=0``, ``n=1``) is sent and fully
       consumed so the shared prompt prefix can be cached.
    2. ``n`` single-sample requests with ids ``<base>-0`` .. ``<base>-(n-1)``
       are issued and run concurrently.

    Args:
        obj: The tokenized request. It is shallow-copied (together with its
            sampling params) for every derived sub-request.
        request_id: Optional id; a ``grpc-<uuid>`` id is generated for the
            ``n > 1`` path when omitted.
        grpc_context: Passed through to every sub-request.

    Yields:
        Streaming: each response dict as it arrives, tagged with an
        ``"index"`` key identifying the sample it belongs to.
        Non-streaming with ``n > 1``: a single list holding all responses,
        ordered by sample index.
    """
    n = getattr(obj.sampling_params, "n", 1)

    if n <= 1:
        async for response in self._handle_single_request(
            obj, request_id, grpc_context
        ):
            yield response
        return

    # N>1 handling - two-phase approach
    logger.debug(f"Multiple sampling request (n={n}), using two-phase approach")

    base_request_id = (
        request_id if request_id is not None else f"grpc-{uuid.uuid4().hex}"
    )

    # Phase 1: cache the common prefix with a prefill-only request.
    logger.debug(f"Phase 1: Caching prefix for request {base_request_id}")
    prefix_obj = copy.copy(obj)
    prefix_obj.sampling_params = copy.copy(obj.sampling_params)
    prefix_obj.sampling_params.max_new_tokens = 0  # Prefill-only
    prefix_obj.sampling_params.n = 1  # Don't replicate prefix request

    async for _ in self._handle_single_request(
        prefix_obj, f"{base_request_id}-prefix", grpc_context
    ):
        # The prefix response carries nothing the caller needs.
        pass

    logger.debug(f"Phase 1 completed: Prefix cached for {base_request_id}")

    # Phase 2: one single-sample request per requested completion.
    logger.debug(f"Phase 2: Generating {n} parallel requests")
    request_ids = [f"{base_request_id}-{i}" for i in range(n)]
    generators = []
    for gen_request_id in request_ids:
        gen_obj = copy.copy(obj)
        gen_obj.sampling_params = copy.copy(obj.sampling_params)
        gen_obj.sampling_params.n = 1  # Each request generates 1 response
        generators.append(
            self._handle_single_request(gen_obj, gen_request_id, grpc_context)
        )

    is_stream = getattr(obj, "stream", False)

    if not is_stream:
        # Non-streaming: collect every response and return them as one batch.
        logger.debug(f"Non-streaming mode: collecting {n} responses")

        async def drain(generator):
            return [response async for response in generator]

        # A sub-request is only submitted when its generator is first
        # advanced, so the generators must be driven concurrently for the
        # n samples to actually run in parallel.
        drain_tasks = [asyncio.create_task(drain(g)) for g in generators]
        try:
            per_request = await asyncio.gather(*drain_tasks)
        except BaseException:
            # One sample failed or we were cancelled: stop the siblings too.
            for drain_task in drain_tasks:
                drain_task.cancel()
            await asyncio.gather(*drain_tasks, return_exceptions=True)
            raise
        # gather() preserves task order, so the batch is ordered by index.
        yield [response for chunk in per_request for response in chunk]
        return

    # Streaming mode: multiplex responses with index for ordering
    logger.debug(f"Streaming mode: multiplexing {n} streams")
    rid_to_index = {rid: i for i, rid in enumerate(request_ids)}

    # One outstanding __anext__() task per still-active generator.
    task_map = {
        asyncio.create_task(generator.__anext__()): generator
        for generator in generators
    }

    try:
        while task_map:
            done, _ = await asyncio.wait(
                task_map.keys(), return_when=asyncio.FIRST_COMPLETED
            )

            for task in done:
                generator = task_map.pop(task)
                try:
                    response = task.result()
                except StopAsyncIteration:
                    # This generator is finished.
                    continue

                # Add index for client-side ordering
                if isinstance(response, dict):
                    response_rid = response.get("request_id", "")
                    if response_rid in rid_to_index:
                        response["index"] = rid_to_index[response_rid]

                yield response

                # Ask the same generator for its next chunk.
                task_map[asyncio.create_task(generator.__anext__())] = generator
    finally:
        # Reached on normal completion, on error, and when the consumer
        # stops early. Cancel whatever is still in flight so sibling
        # requests do not keep running unobserved, then close the
        # generators so their own cleanup code runs.
        leftovers = list(task_map)
        for task in leftovers:
            task.cancel()
        if leftovers:
            await asyncio.gather(*leftovers, return_exceptions=True)
        for generator in generators:
            try:
                await generator.aclose()
            except Exception as e:
                logger.debug(f"Error closing sub-request generator: {e}")
280
+
281
async def _handle_single_request(
    self,
    obj: TokenizedGenerateReqInput,
    request_id: Optional[str] = None,
    grpc_context: Optional[grpc.aio.ServicerContext] = None,
):
    """Run one request end to end (no ``n > 1`` handling).

    Registers a ``GrpcReqState`` for the request, sends the request to the
    scheduler, then relays what arrives on the state's output queue.

    Args:
        obj: The request; its ``rid`` is overwritten with ``request_id``.
        request_id: Optional id; ``grpc-<uuid>`` is generated when omitted.
        grpc_context: Stored on the request state.

    Yields:
        Streaming: every queued response as it arrives.
        Non-streaming: only the terminal response, as a copy whose
        ``"token_ids"`` is replaced by all tokens accumulated on the state.

    A response is terminal when it is a dict with a truthy ``"finished"``
    or ``"shutdown"`` entry. The latter is what ``shutdown()`` enqueues; it
    has no ``"finished"`` key, so without this check the loop would wait
    forever during server shutdown.

    Raises:
        asyncio.CancelledError: Re-raised after an abort has been sent to
            the scheduler.
        Exception: Whatever ``_send_to_scheduler`` raises.
    """
    # Generate request ID if not provided
    if request_id is None:
        request_id = f"grpc-{uuid.uuid4().hex}"

    obj.rid = request_id

    # Create and register request state
    # TODO: support log_request
    state = GrpcReqState(
        request_id=request_id,
        grpc_context=grpc_context,
        out_queue=asyncio.Queue(),
        finished=False,
        event=asyncio.Event(),
        obj=obj,
        created_time=time.time(),
    )

    # Track session if needed
    if hasattr(obj, "session_params") and obj.session_params:
        state.session_id = obj.session_params.session_id
        state.is_session_request = True

    self.rid_to_state[request_id] = state
    self.record_request_for_crash_dump(obj)

    try:
        # Send to scheduler - let exceptions bubble up to the caller.
        await self._send_to_scheduler(obj)

        is_stream = getattr(obj, "stream", False)

        while True:
            try:
                response = await state.out_queue.get()

                if is_stream:
                    yield response

                is_terminal = isinstance(response, dict) and bool(
                    response.get("finished", False)
                    or response.get("shutdown", False)
                )
                if is_terminal:
                    if not is_stream:
                        # Non-streaming callers get a single response that
                        # carries every token accumulated so far.
                        final_response = response.copy()
                        final_response["token_ids"] = state.output_ids
                        yield final_response
                    break

            except asyncio.CancelledError:
                # Task was cancelled by gRPC framework when client disconnected
                logger.info(f"Request {request_id} cancelled by client")
                await self.abort_request(request_id)
                raise  # Re-raise to let gRPC server handle cleanup

    finally:
        # Always clean up request state when exiting
        self._cleanup_request_state(request_id)
344
+
345
+ def _cleanup_request_state(self, request_id: str):
346
+ """Clean up local request state (does not notify scheduler)."""
347
+ if request_id in self.rid_to_state:
348
+ del self.rid_to_state[request_id]
349
+
350
async def embedding_request(
    self,
    obj: TokenizedEmbeddingReqInput,
    request_id: Optional[str] = None,
) -> asyncio.Future:
    """Submit an embedding request to the scheduler.

    Args:
        obj: The request; its ``rid`` is overwritten with ``request_id``.
        request_id: Optional id; ``grpc-embed-<uuid>`` is generated when
            omitted.

    Returns:
        A future that resolves to the first item placed on the request's
        output queue after its event is set. If sending to the scheduler
        fails, the returned future already holds that exception.
    """
    # Generate request ID if not provided
    if request_id is None:
        request_id = f"grpc-embed-{uuid.uuid4().hex}"

    obj.rid = request_id

    state = GrpcReqState(
        request_id=request_id,
        grpc_context=None,
        out_queue=asyncio.Queue(),
        finished=False,
        event=asyncio.Event(),
        obj=obj,
        created_time=time.time(),
    )
    self.rid_to_state[request_id] = state

    future = asyncio.get_running_loop().create_future()

    def forget_state():
        # Only drop the entry if it is still ours; the id may have been
        # re-registered by a newer request in the meantime.
        if self.rid_to_state.get(request_id) is state:
            del self.rid_to_state[request_id]

    try:
        await self._send_to_scheduler(obj)
    except Exception as e:
        forget_state()
        future.set_exception(e)
        return future

    async def wait_for_result():
        try:
            await state.event.wait()
            result = await state.out_queue.get()
            # The caller may have cancelled the future while we waited;
            # setting a result on a done future raises InvalidStateError.
            if not future.done():
                future.set_result(result)
        except Exception as e:
            if not future.done():
                future.set_exception(e)
        finally:
            forget_state()

    # Keep a strong reference to the background task until it completes,
    # as the asyncio documentation recommends for fire-and-forget tasks.
    waiters = getattr(self, "_embedding_wait_tasks", None)
    if waiters is None:
        waiters = self._embedding_wait_tasks = set()
    waiter = asyncio.create_task(wait_for_result())
    waiters.add(waiter)
    waiter.add_done_callback(waiters.discard)
    return future
405
+
406
async def abort_request(self, request_id: str) -> bool:
    """Abort a running request.

    The local state (if any) is flagged as finished right away so that
    later scheduler outputs for it are skipped, then an ``AbortReq`` is
    sent to the scheduler.

    Returns:
        ``False`` for ids starting with ``HEALTH_CHECK`` (never aborted) or
        when the abort could not be sent; ``True`` otherwise.
    """
    # Health-check requests are left alone.
    if request_id.startswith("HEALTH_CHECK"):
        return False

    tracked = self.rid_to_state.get(request_id)
    if tracked:
        # Flag first, so outputs arriving from now on are ignored.
        tracked.finished = True
        tracked.stream_finished = True
        logger.debug(f"Marked request {request_id} as aborted locally")

    # The scheduler echoes an AbortReq back; see _handle_abort_req.
    abort_req = AbortReq(rid=request_id)
    try:
        await self._send_to_scheduler(abort_req)
    except Exception as e:
        logger.error(f"Failed to send abort request to scheduler: {e}")
        return False

    logger.debug(f"Sent abort to scheduler for request {request_id}")
    return True
434
+
435
async def handle_loop(self):
    """Receive scheduler outputs and dispatch them until asked to exit.

    Runs while ``gracefully_exit`` is false. A ZMQ error ends the loop;
    any other exception is logged and the loop carries on.
    """
    while not self.gracefully_exit:
        try:
            recv_obj = await self.recv_from_scheduler.recv_pyobj()
            self.last_receive_tstamp = time.time()

            # Cheap flag check first; the condition is only taken when paused.
            if self.is_pause:
                async with self.is_pause_cond:
                    while self.is_pause:
                        await self.is_pause_cond.wait()

            # First matching type wins, in this order.
            handlers = (
                (BatchTokenIDOutput, self._handle_batch_output),
                (BatchEmbeddingOutput, self._handle_embedding_output),
                (HealthCheckOutput, self._handle_health_check_output),
                (AbortReq, self._handle_abort_req),
            )
            for message_type, handler in handlers:
                if isinstance(recv_obj, message_type):
                    await handler(recv_obj)
                    break
            else:
                logger.warning(f"Unknown output type: {type(recv_obj)}")

        except zmq.error.Again:
            # Timeout: go back to the loop condition, which checks the exit flag.
            continue
        except zmq.error.ZMQError as e:
            # Socket closed or other ZMQ error: always leave the loop, but
            # only report it as an error when we are not shutting down.
            if self.gracefully_exit:
                logger.debug(f"ZMQ recv interrupted during shutdown: {e}")
            else:
                logger.error(
                    f"ZMQ error in handle loop: {e}\n{get_exception_traceback()}"
                )
            break
        except Exception as e:
            logger.error(f"Handle loop error: {e}\n{get_exception_traceback()}")
            if self.gracefully_exit:
                break
482
+
483
def _convert_logprob_style(
    self,
    state: GrpcReqState,
    batch_out: BatchTokenIDOutput,
    batch_index: int,
):
    """Append this batch entry's logprobs to the lists kept on ``state``.

    Nothing happens when ``batch_out.input_token_logprobs_val`` is ``None``.
    Input-side lists are appended only when the batch's corresponding list
    is non-empty; output-side lists are always appended. Top-logprob lists
    are handled only when ``state.obj.top_logprobs_num > 0``.
    """
    if batch_out.input_token_logprobs_val is None:
        return

    def accumulate(*field_names):
        # state.<name> and batch_out.<name> share the same attribute names.
        for field_name in field_names:
            target = getattr(state, field_name)
            target.extend(getattr(batch_out, field_name)[batch_index])

    if len(batch_out.input_token_logprobs_val) > 0:
        accumulate("input_token_logprobs_val", "input_token_logprobs_idx")
    accumulate("output_token_logprobs_val", "output_token_logprobs_idx")

    if state.obj.top_logprobs_num <= 0:
        return

    if len(batch_out.input_top_logprobs_val) > 0:
        accumulate("input_top_logprobs_val", "input_top_logprobs_idx")
    accumulate("output_top_logprobs_val", "output_top_logprobs_idx")
532
+
533
async def _handle_batch_output(self, batch_out: BatchTokenIDOutput):
    """Turn one scheduler batch into per-request output dicts.

    For every rid in the batch that has a live (not finished) local state,
    this updates timing fields, accumulates logprobs and token ids on the
    state, and queues an output dict on the state's ``out_queue``. When the
    entry has a finish reason the state is marked finished, its event is
    set, and a delayed removal from ``rid_to_state`` is scheduled.

    Streaming requests get per-chunk output logprobs and, once, the input
    logprobs; non-streaming requests get cumulative logprobs in the final
    dict only.
    """
    put_tasks = []
    now = time.time()

    def cumulative_logprobs(req_state, side):
        # ``side`` is "input" or "output"; the lists are shared, not copied.
        return {
            "token_logprobs_val": getattr(req_state, f"{side}_token_logprobs_val"),
            "token_logprobs_idx": getattr(req_state, f"{side}_token_logprobs_idx"),
            "top_logprobs_val": getattr(req_state, f"{side}_top_logprobs_val"),
            "top_logprobs_idx": getattr(req_state, f"{side}_top_logprobs_idx"),
        }

    def chunk_part(attr_name, index):
        # This chunk's slice of a per-request list, or [] when absent.
        source_list = getattr(batch_out, attr_name, None)
        return source_list[index] if source_list and index < len(source_list) else []

    async def delayed_cleanup(request_id, finished_state):
        await asyncio.sleep(5.0)
        # Only remove the entry if it is still the state that finished; the
        # id may have been re-registered by a newer request meanwhile.
        if self.rid_to_state.get(request_id) is finished_state:
            del self.rid_to_state[request_id]

    # Strong references to the delayed-cleanup tasks until they complete.
    cleanup_tasks = getattr(self, "_delayed_cleanup_tasks", None)
    if cleanup_tasks is None:
        cleanup_tasks = self._delayed_cleanup_tasks = set()

    try:
        for i, rid in enumerate(batch_out.rids):
            state = self.rid_to_state.get(rid)
            if state is None:
                continue

            # Skip if already aborted/finished locally (client cancelled)
            if state.finished:
                logger.debug(f"Skipping output for aborted request {rid}")
                continue

            # Update metrics
            if state.first_token_time == 0.0:
                state.first_token_time = now
            state.last_time = now

            finish_reason = batch_out.finished_reasons[i]
            is_finished = finish_reason is not None

            output_data = {
                "request_id": rid,
                "token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
                "finished": is_finished,
                "meta_info": {
                    "prompt_tokens": (
                        batch_out.prompt_tokens[i] if batch_out.prompt_tokens else 0
                    ),
                    "completion_tokens": (
                        batch_out.completion_tokens[i]
                        if batch_out.completion_tokens
                        else 0
                    ),
                    "cached_tokens": (
                        batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
                    ),
                    "finish_reason": finish_reason if finish_reason else None,
                },
            }

            wants_logprobs = state.obj.return_logprob

            # Accumulate logprobs (following tokenizer_manager pattern)
            if wants_logprobs:
                self._convert_logprob_style(state, batch_out, i)

            # Input logprobs, when available.
            if (
                wants_logprobs
                and state.obj.logprob_start_len >= 0
                and state.input_token_logprobs_val
            ):
                if state.obj.stream:
                    # Streaming: sent once, in the first chunk that has them.
                    if not state.input_logprobs_sent:
                        output_data["input_logprobs"] = cumulative_logprobs(
                            state, "input"
                        )
                        state.input_logprobs_sent = True
                elif is_finished:
                    # Non-streaming: sent with the final chunk.
                    output_data["input_logprobs"] = cumulative_logprobs(
                        state, "input"
                    )

            # Output logprobs, when available.
            if (
                wants_logprobs
                and batch_out.output_token_logprobs_val
                and i < len(batch_out.output_token_logprobs_val)
            ):
                if state.obj.stream:
                    # Streaming: only the logprobs of this chunk's new tokens.
                    # NOTE: TokenizerManager always accumulates instead.
                    output_data["output_logprobs"] = {
                        "token_logprobs_val": batch_out.output_token_logprobs_val[i],
                        "token_logprobs_idx": chunk_part("output_token_logprobs_idx", i),
                        "top_logprobs_val": chunk_part("output_top_logprobs_val", i),
                        "top_logprobs_idx": chunk_part("output_top_logprobs_idx", i),
                    }
                elif is_finished:
                    # Non-streaming: cumulative logprobs with the final chunk.
                    output_data["output_logprobs"] = cumulative_logprobs(
                        state, "output"
                    )

            # Update state for accumulation
            if output_data["token_ids"]:
                state.output_ids.extend(output_data["token_ids"])

            # Queue the delivery; all puts are awaited together below.
            put_tasks.append(state.out_queue.put(output_data))

            if is_finished:
                state.finished = True
                state.finished_time = now
                state.stream_finished = True
                state.event.set()

                # Remove from tracking after a delay
                cleanup_task = asyncio.create_task(delayed_cleanup(rid, state))
                cleanup_tasks.add(cleanup_task)
                cleanup_task.add_done_callback(cleanup_tasks.discard)
    finally:
        # Deliver what was queued even if a later batch entry raised, so a
        # request already marked finished above still gets its output.
        if put_tasks:
            await asyncio.gather(*put_tasks, return_exceptions=True)
667
+
668
async def _handle_embedding_output(self, batch_out: BatchEmbeddingOutput):
    """Deliver each embedding in the batch to its request's output queue.

    Entries whose rid has no local state are ignored. After queuing the
    result the state is marked finished and its event is set.
    """
    has_prompt_tokens = bool(batch_out.prompt_tokens)
    has_finish_reason = bool(batch_out.finish_reason)

    for position, rid in enumerate(batch_out.rids):
        state = self.rid_to_state.get(rid)
        if state is None:
            continue

        prompt_tokens = batch_out.prompt_tokens[position] if has_prompt_tokens else 0
        finish_reason = (
            batch_out.finish_reason[position] if has_finish_reason else None
        )
        result = {
            "request_id": rid,
            "embedding": batch_out.embeddings[position],
            "prompt_tokens": prompt_tokens,
            "finish_reason": finish_reason,
        }

        await state.out_queue.put(result)

        # Terminal for this request: flag it and wake any waiter.
        state.finished = True
        state.finished_time = time.time()
        state.event.set()
695
+
696
async def _handle_health_check_output(self, health_out: HealthCheckOutput):
    """Deliver a health-check reply to the request that is waiting for it.

    A reply for an unknown rid is logged and dropped.
    """
    rid = health_out.rid
    state = self.rid_to_state.get(rid)
    if state is None:
        logger.warning(f"Health check output for unknown request: {rid}")
        return

    # Both attributes are optional on the incoming object.
    output_text = getattr(health_out, "output_str", "")
    finish_reason = getattr(health_out, "finish_reason", "stop")

    result = {
        "request_id": rid,
        "healthy": True,  # Receiving any reply means the scheduler answered.
        "output_text": output_text,
        "finish_reason": finish_reason,
    }
    await state.out_queue.put(result)

    # Terminal for this request: flag it and wake any waiter.
    state.finished = True
    state.finished_time = time.time()
    state.event.set()
727
+
728
async def _handle_abort_req(self, recv_obj: AbortReq):
    """Handle an ``AbortReq`` received from the scheduler.

    The scheduler sends this back after an explicit ``abort_request()`` or
    when it aborts a request on its own (priority preemption, queue full,
    KV cache pressure, etc). Health-check ids and ids without local state
    are ignored; otherwise the state is marked finished, an abort response
    is queued, and the state's event is set.
    """
    rid = recv_obj.rid

    if rid.startswith("HEALTH_CHECK"):
        return

    state = self.rid_to_state.get(rid)
    if state is None:
        logger.debug(
            f"Abort request for {recv_obj.rid} not in local state (may have already finished or not started yet)"
        )
        return

    state.finished = True
    state.stream_finished = True

    reason = recv_obj.finished_reason
    if reason:
        # Scheduler supplied its own finish reason; pass it through.
        error_message = reason.get("message", "Request aborted")
        meta_info = {
            "id": rid,
            "finish_reason": reason,
        }
    else:
        # Generic abort (e.g., explicit abort_request call)
        error_message = "Request aborted"
        meta_info = {
            "id": rid,
            "finish_reason": {
                "type": "abort",
                "message": "Abort before prefill",
            },
            "prompt_tokens": 0,
            "completion_tokens": 0,
        }

    abort_response = {
        "request_id": rid,
        "error": error_message,
        "finished": True,
        "meta_info": meta_info,
    }

    # Notify the consumer, then wake anything waiting on the event.
    await state.out_queue.put(abort_response)
    state.event.set()

    logger.debug(f"Handled abort request for {recv_obj.rid}")
788
+
789
+ async def _send_to_scheduler(self, obj):
790
+ """Send an object to the scheduler via ZMQ."""
791
+ try:
792
+ self.send_to_scheduler.send_pyobj(obj)
793
+ except Exception as e:
794
+ logger.error(f"Failed to send to scheduler: {e}")
795
+ raise
796
+
797
def record_request_for_crash_dump(self, obj):
    """Append a small summary of ``obj`` to the crash-dump list.

    Nothing is recorded once the list already holds 100 entries.
    """
    history = self.crash_dump_request_list
    if len(history) >= 100:
        return

    entry = {
        "time": time.time(),
        "request_id": getattr(obj, "rid", "unknown"),
        "type": type(obj).__name__,
    }
    history.append(entry)
807
+
808
async def shutdown(self):
    """Gracefully shut the request manager down.

    Steps, in order: set ``gracefully_exit``; cancel and await the tracked
    background tasks; push a shutdown notice to every unfinished request;
    shut down the bootstrap server if one was provided; close both ZMQ
    sockets and terminate the ZMQ context.
    """
    logger.info("Shutting down GrpcRequestManager")
    self.gracefully_exit = True

    # Cancel all asyncio tasks FIRST - this will interrupt blocked recv() calls
    for task in list(self.asyncio_tasks):
        if not task.done():
            task.cancel()

    # Give tasks a moment to process cancellation
    if self.asyncio_tasks:
        await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True)

    # Tell every still-pending request that the server is going away.
    for rid, state in list(self.rid_to_state.items()):
        if not state.finished:
            await state.out_queue.put(
                {"error": "Server shutting down", "shutdown": True}
            )
            state.finished = True
            state.event.set()

    # Wait for tasks to complete
    if self.asyncio_tasks:
        await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True)

    # Shutdown bootstrap server if running
    if self.bootstrap_server:
        logger.info("Shutting down bootstrap server")
        try:
            if hasattr(self.bootstrap_server, "shutdown"):
                if asyncio.iscoroutinefunction(self.bootstrap_server.shutdown):
                    await self.bootstrap_server.shutdown()
                else:
                    self.bootstrap_server.shutdown()
        except Exception as e:
            logger.warning(f"Error shutting down bootstrap server: {e}")

    # Close ZMQ sockets. linger=0 discards messages that could not be
    # delivered (e.g. the scheduler is already gone); with a non-zero
    # linger, context.term() below can block waiting for them.
    self.recv_from_scheduler.close(linger=0)
    self.send_to_scheduler.close(linger=0)

    # Terminate the ZMQ context - this is critical for asyncio loop to exit cleanly
    self.context.term()

    logger.info("GrpcRequestManager shutdown complete")
855
+
856
def get_server_info(self) -> Dict[str, Any]:
    """Return a small status snapshot: request count, pause flag, last receive time."""
    info: Dict[str, Any] = {}
    info["active_requests"] = len(self.rid_to_state)
    info["paused"] = self.is_pause
    info["last_receive_time"] = self.last_receive_tstamp
    return info
863
+
864
def auto_create_handle_loop(self):
    """Start the background tasks once; later calls are no-ops.

    Schedules ``handle_loop`` and ``sigterm_watchdog`` on the current event
    loop and, when running in the main thread, installs the SIGTERM and
    SIGQUIT handlers on that loop.
    """
    if self.no_create_loop:
        return
    self.no_create_loop = True

    loop = asyncio.get_event_loop()
    handle_loop_task = loop.create_task(print_exception_wrapper(self.handle_loop))
    self.asyncio_tasks.add(handle_loop_task)
    self.event_loop = loop

    # Signal handlers can only be installed from the main thread (CPython
    # limitation).
    in_main_thread = threading.current_thread() is threading.main_thread()
    if not in_main_thread:
        logger.warning(
            "Signal handler is not added because the grpc request manager is "
            "not in the main thread. This disables graceful shutdown of the "
            "grpc request manager when SIGTERM is received."
        )
    else:
        signal_handler = GrpcSignalHandler(self)
        loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
        # Replaces the SIGQUIT handler that was installed for the launch phase.
        loop.add_signal_handler(
            signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
        )

    watchdog_task = loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
    self.asyncio_tasks.add(watchdog_task)
895
+
896
async def sigterm_watchdog(self):
    """Poll once per second and return as soon as ``gracefully_exit`` is set."""
    while True:
        if self.gracefully_exit:
            return
        await asyncio.sleep(1.0)
900
+
901
+
902
async def print_exception_wrapper(func):
    """Await ``func()`` and make sure a failure is logged and fatal.

    Some asyncio tasks swallow their exceptions silently. This wrapper logs
    the traceback, gives the owning ``GrpcRequestManager`` (if ``func`` is
    one of its bound methods) a chance to dump its requests, then kills the
    process tree and exits with status 1.

    Args:
        func: A zero-argument callable returning an awaitable.
    """
    try:
        await func()
    except Exception:
        traceback_text = get_exception_traceback()
        logger.error(f"GrpcRequestManager hit an exception: {traceback_text}")

        owner = getattr(func, "__self__", None)
        if isinstance(owner, GrpcRequestManager):
            # The dump hook is optional: if the manager does not provide it,
            # or it fails, we must still reach the process teardown below.
            dump = getattr(owner, "dump_requests_before_crash", None)
            if callable(dump):
                try:
                    dump()
                except Exception as dump_error:
                    logger.error(f"Failed to dump requests before crash: {dump_error}")

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)