sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registries. It is provided for informational purposes only.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
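
The largest new module in this release is the CPU graph runner (item 264 above, sglang/srt/model_executor/cpu_graph_runner.py, +640 lines), whose full diff is reproduced below. As a rough sketch of how that path would be exercised, the offline engine could be launched with torch compile enabled on a CPU device. This is an illustrative invocation only: the model path is a placeholder, and the keyword arguments are assumed to map onto the server flags that the new file itself references (--enable-torch-compile, --torch-compile-max-bs):

    import sglang as sgl

    # Illustrative sketch, not a verified recipe: assumes sgl.Engine forwards
    # these keyword arguments to ServerArgs. The flags mirror the options named
    # in CPU_GRAPH_CAPTURE_FAILED_MSG at the bottom of the new file.
    llm = sgl.Engine(
        model_path="path/to/model",  # placeholder
        device="cpu",
        enable_torch_compile=True,
        torch_compile_max_bs=8,  # the runner captures small decode batches (1-16)
    )
    print(llm.generate("Hello", {"max_new_tokens": 16}))
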
--- /dev/null
+++ sglang/srt/model_executor/cpu_graph_runner.py
@@ -0,0 +1,640 @@
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Run the model with cpu torch compile."""
+
+ # The implementation of CPUGraphRunner follows the CudaGraphRunner
+
+ from __future__ import annotations
+
+ import logging
+ from contextlib import contextmanager
+ from typing import TYPE_CHECKING, Callable, Optional, Union
+
+ import psutil
+ import torch
+ import tqdm
+
+ from sglang.srt.distributed import get_tensor_model_parallel_rank
+ from sglang.srt.distributed.parallel_state import GroupCoordinator
+ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+ from sglang.srt.model_executor.forward_batch_info import (
+     CaptureHiddenMode,
+     ForwardBatch,
+     ForwardMode,
+     PPProxyTensors,
+ )
+ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+ from sglang.srt.utils import (
+     log_info_on_rank0,
+     require_attn_tp_gather,
+     require_gathered_buffer,
+     require_mlp_sync,
+     require_mlp_tp_gather,
+ )
+ from sglang.srt.utils.patch_torch import monkey_patch_torch_compile
+
+ logger = logging.getLogger(__name__)
+
+ if TYPE_CHECKING:
+     from sglang.srt.model_executor.model_runner import ModelRunner
+
+
+ @contextmanager
+ def patch_model(
+     model: torch.nn.Module,
+     enable_compile: bool,
+     num_tokens: int,
+     tp_group: GroupCoordinator,
+ ):
+     """Patch the model to make it compatible with torch.compile"""
+     backup_ca_comm = None
+
+     try:
+         if enable_compile:
+             backup_ca_comm = tp_group.ca_comm
+             # Use custom-allreduce here.
+             # We found the custom allreduce is much faster than the built-in allreduce in torch,
+             # even with ENABLE_INTRA_NODE_COMM=1.
+             # tp_group.ca_comm = None
+             yield torch.compile(
+                 torch.no_grad()(model.forward),
+                 dynamic=False,
+             )
+         else:
+             yield model.forward
+     finally:
+         if enable_compile:
+             tp_group.ca_comm = backup_ca_comm
+
+
+ def set_torch_compile_config():
+     import torch._dynamo.config
+     import torch._inductor.config
+
+     torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
+     torch._inductor.config.freezing = True
+     torch._dynamo.config.accumulated_cache_size_limit = 1024
+     if hasattr(torch._dynamo.config, "cache_size_limit"):
+         torch._dynamo.config.cache_size_limit = 1024
+     monkey_patch_torch_compile()
+
+
+ def get_batch_sizes_to_capture(model_runner: ModelRunner):
+     server_args = model_runner.server_args
+     # cpu torch compile only speeds up decoding by
+     # reducing python overhead when bs is small
+     capture_bs = list(range(1, 17))
+     capture_bs = [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
+     capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
+     capture_bs = list(sorted(set(capture_bs)))
+     assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"
+     return capture_bs
+
+
+ def register_fake_ops():
+     """
+     Registers fake/meta implementations for all custom sgl_kernel CPU operators
+     using torch.library.register_fake to support torch.compile
+     """
+
+     none_return_ops = [
+         "shm_allreduce",
+         "bmm_cpu",
+         "fused_add_rmsnorm_cpu",
+         "decode_attention_cpu",
+         "extend_attention_cpu",
+     ]
+     for op in none_return_ops:
+
+         @torch.library.register_fake(f"sgl_kernel::{op}")
+         def _(*args, **kwargs):
+             return
+
+     for op in [
+         "rmsnorm_cpu",
+         "l2norm_cpu",
+         "fused_experts_cpu",
+         "shared_expert_cpu",
+     ]:
+
+         @torch.library.register_fake(f"sgl_kernel::{op}")
+         def _(input, *args, **kwargs):
+             return torch.empty_like(input)
+
+     @torch.library.register_fake("sgl_kernel::qkv_proj_with_rope")
+     def _(
+         hidden_states,
+         q_a_proj_weight,
+         q_b_proj_weight,
+         kv_a_proj_weight,
+         w_kc,
+         q_a_layernorm_weight,
+         kv_a_layernorm_weight,
+         positions,
+         cos_sin_cache,
+         eps,
+         use_int8_w8a8,
+         use_fp8_w8a16,
+         q_a_proj_scale,
+         q_b_proj_scale,
+         kv_a_proj_scale,
+         is_vnni,
+         block_size,
+     ):
+         num_seqs = hidden_states.shape[0]
+         num_heads = w_kc.shape[0]
+         kv_lora_rank = w_kc.shape[1]
+         qk_rope_head_dim = kv_a_proj_weight.shape[0] - kv_lora_rank
+         q_input = torch.empty(
+             num_seqs,
+             num_heads,
+             kv_lora_rank + qk_rope_head_dim,
+             dtype=hidden_states.dtype,
+             device=hidden_states.device,
+         )
+         k_input = torch.empty(
+             num_seqs,
+             1,
+             kv_lora_rank + qk_rope_head_dim,
+             dtype=hidden_states.dtype,
+             device=hidden_states.device,
+         )
+         v_input = k_input.narrow(-1, 0, kv_lora_rank)
+         return q_input, k_input, v_input
+
+     @torch.library.register_fake("sgl_kernel::rotary_embedding_cpu")
+     def _(positions, query, key, head_size, cos_sin_cache, is_neox):
+         if query.ndim == 2:
+             return query, key
+         else:
+             return torch.empty_like(query), torch.empty_like(key)
+
+     @torch.library.register_fake("sgl_kernel::qkv_proj_with_rope_fused_weight")
+     def _(
+         hidden_states,
+         q_a_proj_weight,
+         q_b_proj_weight,
+         w_kc,
+         q_a_layernorm_weight,
+         kv_a_layernorm_weight,
+         positions,
+         cos_sin_cache,
+         eps,
+         use_int8_w8a8,
+         use_fp8_w8a16,
+         qkv_a_proj_scale,
+         q_b_proj_scale,
+         is_vnni,
+         block_size,
+         q_lora_rank,
+         kv_lora_rank,
+         qk_rope_head_dim,
+     ):
+         num_seqs = hidden_states.shape[0]
+         num_heads = w_kc.shape[0]
+         kv_lora_rank = w_kc.shape[1]
+         weight_chunks = torch.split(
+             q_a_proj_weight, [q_lora_rank, kv_lora_rank + qk_rope_head_dim], dim=0
+         )
+         qk_rope_head_dim = weight_chunks[1].shape[0] - kv_lora_rank
+         q_input = torch.empty(
+             num_seqs,
+             num_heads,
+             kv_lora_rank + qk_rope_head_dim,
+             dtype=hidden_states.dtype,
+             device=hidden_states.device,
+         )
+         k_input = torch.empty(
+             num_seqs,
+             1,
+             kv_lora_rank + qk_rope_head_dim,
+             dtype=hidden_states.dtype,
+             device=hidden_states.device,
+         )
+         v_input = k_input.narrow(-1, 0, kv_lora_rank)
+         return q_input, k_input, v_input
+
+     @torch.library.register_fake("sgl_kernel::weight_packed_linear")
+     def _(x, weight, bias, is_vnni):
+         return x.new_empty(x.shape[0], weight.shape[0])
+
+     @torch.library.register_fake("sgl_kernel::per_token_quant_int8_cpu")
+     def _(input):
+         M = input.shape[0]
+         K = input.shape[1]
+         Aq = input.new_empty(M, K, dtype=torch.int8)
+         As = input.new_empty(M, dtype=torch.float32)
+         return Aq, As
+
+     @torch.library.register_fake("sgl_kernel::int8_scaled_mm_cpu")
+     def _(mat1, mat2, scales1, scales2, bias, out_dtype, is_vnni):
+         M = mat1.shape[0]
+         N = mat2.shape[0]
+         out = mat1.new_empty(M, N, dtype=out_dtype)
+         return out
+
+     @torch.library.register_fake("sgl_kernel::grouped_topk_cpu")
+     def _(
+         hidden_states,
+         gating_output,
+         topk,
+         renormalize,
+         num_expert_group,
+         topk_group,
+         num_fused_shared_experts,
+         routed_scaling_factor,
+         num_token_non_padded,
+     ):
+         num_tokens = hidden_states.shape[0]
+         shape = (num_tokens, topk)
+         device = hidden_states.device
+         topk_weights = torch.empty(shape, device=device, dtype=torch.float32)
+         topk_ids = torch.empty(shape, device=device, dtype=torch.int)
+         return topk_weights, topk_ids
+
+     @torch.library.register_fake("sgl_kernel::biased_grouped_topk_cpu")
+     def _(
+         hidden_states,
+         gating_output,
+         correction_bias,
+         topk,
+         renormalize,
+         num_expert_group,
+         topk_group,
+         num_fused_shared_experts,
+         routed_scaling_factor,
+         num_token_non_padded,
+     ):
+         num_tokens = hidden_states.shape[0]
+         shape = (num_tokens, topk)
+         device = hidden_states.device
+         topk_weights = torch.empty(shape, device=device, dtype=torch.float32)
+         topk_ids = torch.empty(shape, device=device, dtype=torch.int)
+         return topk_weights, topk_ids
+
+     @torch.library.register_fake("sgl_kernel::topk_sigmoid_cpu")
+     def _(hidden_states, gating_output, topk, renormalize):
+         num_tokens = hidden_states.shape[0]
+         shape = (num_tokens, topk)
+         return (
+             torch.empty(shape, device=hidden_states.device, dtype=torch.float),
+             torch.empty(shape, device=hidden_states.device, dtype=torch.int),
+         )
+
+     @torch.library.register_fake("sgl_kernel::topk_softmax_cpu")
+     def _(
+         hidden_states,
+         gating_output,
+         topk,
+         renormalize,
+     ):
+         num_tokens = hidden_states.shape[0]
+         shape = (num_tokens, topk)
+         return (
+             torch.empty(shape, device=hidden_states.device, dtype=torch.float),
+             torch.empty(shape, device=hidden_states.device, dtype=torch.int),
+         )
+
+     @torch.library.register_fake("sgl_kernel::silu_and_mul_cpu")
+     def _(input):
+         return input.new_empty(input.shape[0], input.shape[1] // 2)
+
+     @torch.library.register_fake("sgl_kernel::int8_scaled_mm_with_quant")
+     def _(
+         mat1,
+         mat2,
+         scales2,
+         bias,
+         out_dtype,
+         is_vnni,
+     ):
+         M = mat1.shape[0]
+         N = mat2.shape[0]
+         return mat1.new_empty(M, N, dtype=out_dtype)
+
+     @torch.library.register_fake("sgl_kernel::fp8_scaled_mm_cpu")
+     def _(
+         mat1,
+         mat2,
+         scales2,
+         block_size,
+         bias,
+         out_dtype,
+         is_vnni,
+     ):
+         M = mat1.shape[0]
+         N = mat2.shape[0]
+         return mat1.new_empty(M, N, dtype=out_dtype)
+
+
+ # TODO Remove unnecessary settings for CPUGraphRunner.
+ # Re-abstract the graph runner and restructure CPUGraphRunner to reuse the same logic.
+ class CPUGraphRunner:
+     """A CPUGraphRunner runs the forward pass of a model with cpu torch.compile."""
+
+     def __init__(self, model_runner: ModelRunner):
+         # Parse args
+         self.model_runner = model_runner
+         self.device = model_runner.device
+         self.graphs = {}
+         self.output_buffers = {}
+         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
+         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
+         self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
+         self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+         self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
+         self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
+         self.enable_two_batch_overlap = (
+             model_runner.server_args.enable_two_batch_overlap
+         )
+         self.speculative_algorithm = model_runner.server_args.speculative_algorithm
+         self.enable_profile_cuda_graph = (
+             model_runner.server_args.enable_profile_cuda_graph
+         )
+         self.tp_size = model_runner.server_args.tp_size
+         self.dp_size = model_runner.server_args.dp_size
+         self.pp_size = model_runner.server_args.pp_size
+
+         self.capture_forward_mode = ForwardMode.DECODE
+         self.capture_hidden_mode = CaptureHiddenMode.NULL
+         self.num_tokens_per_bs = 1
+
+         # If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
+         if model_runner.server_args.enable_return_hidden_states:
+             self.capture_hidden_mode = CaptureHiddenMode.FULL
+
+         assert (
+             not self.model_runner.server_args.enable_lora
+         ), "CPUGraphRunner does not support LoRA yet."
+         assert (
+             not self.enable_two_batch_overlap
+         ), "CPUGraphRunner does not support two batch overlap yet."
+         assert (
+             not self.require_mlp_tp_gather
+         ), "CPUGraphRunner does not support MLP TP gather yet."
+         assert (
+             not self.require_mlp_sync
+         ), "CPUGraphRunner does not support MLP sync yet."
+         assert (
+             not self.require_gathered_buffer
+         ), "CPUGraphRunner does not support gathered buffer yet."
+         assert (
+             model_runner.spec_algorithm == SpeculativeAlgorithm.NONE
+         ), "CPUGraphRunner does not support speculative inference yet."
+         # TODO add compile support for encoder-decoder models
+         assert (
+             not self.is_encoder_decoder
+         ), "CPUGraphRunner does not support encoder-decoder models yet."
+         assert self.dp_size == 1, "CPUGraphRunner does not support DP yet."
+         assert self.pp_size == 1, "CPUGraphRunner does not support PP yet."
+
+         # Batch sizes to capture
+         self.capture_bs = get_batch_sizes_to_capture(model_runner)
+         log_info_on_rank0(logger, f"Capture cpu graph bs {self.capture_bs}")
+         # Attention backend
+         self.max_bs = max(self.capture_bs)
+         self.max_num_token = self.max_bs * self.num_tokens_per_bs
+
+         self.seq_len_fill_value = (
+             self.model_runner.attn_backend.get_graph_seq_len_fill_value()
+         )
+
+         if self.enable_torch_compile:
+             register_fake_ops()
+             set_torch_compile_config()
+
+         # Graph inputs
+         with torch.device(self.device):
+             self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
+             self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int64)
+             self.seq_lens = torch.full(
+                 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int64
+             )
+             self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
+             self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
+             self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
+             self.num_token_non_padded = torch.zeros((1,), dtype=torch.int64)
+             self.custom_mask = torch.ones(
+                 (
+                     (self.seq_lens.sum().item() + self.max_num_token)
+                     * self.num_tokens_per_bs
+                 ),
+                 dtype=torch.bool,
+                 device=self.device,
+             )
+
+         # Capture
+         try:
+             self.capture()
+         except RuntimeError as e:
+             raise Exception(
+                 f"Capture CPU graph failed: {e}\n{CPU_GRAPH_CAPTURE_FAILED_MSG}"
+             )
+
+     def can_run(self, forward_batch: ForwardBatch):
+         is_bs_supported = forward_batch.batch_size in self.graphs
+
+         requested_capture_hidden_mode = max(
+             forward_batch.capture_hidden_mode,
+             (
+                 forward_batch.spec_info.capture_hidden_mode
+                 if getattr(forward_batch.spec_info, "capture_hidden_mode", None)
+                 is not None
+                 else CaptureHiddenMode.NULL
+             ),
+         )
+         capture_hidden_mode_matches = (
+             requested_capture_hidden_mode == CaptureHiddenMode.NULL
+             or requested_capture_hidden_mode == self.capture_hidden_mode
+         )
+
+         return is_bs_supported and capture_hidden_mode_matches
+
+     def capture(self) -> None:
+         capture_range = (
+             tqdm.tqdm(list(reversed(self.capture_bs)))
+             if get_tensor_model_parallel_rank() == 0
+             else reversed(self.capture_bs)
+         )
+         for bs in capture_range:
+             if get_tensor_model_parallel_rank() == 0:
+                 avail_mem = psutil.virtual_memory().available / (1 << 30)
+                 capture_range.set_description(
+                     f"Capturing batches ({bs=} {avail_mem=:.2f} GB)"
+                 )
+
+             with patch_model(
+                 self.model_runner.model,
+                 bs in self.capture_bs,
+                 num_tokens=bs * self.num_tokens_per_bs,
+                 tp_group=self.model_runner.tp_group,
+             ) as forward:
+                 (
+                     graph,
+                     output_buffers,
+                 ) = self.capture_one_batch_size(bs, forward)
+                 self.graphs[bs] = graph
+                 self.output_buffers[bs] = output_buffers
+
+     def capture_one_batch_size(self, bs: int, forward: Callable):
+         num_tokens = bs * self.num_tokens_per_bs
+
+         # Graph inputs
+         input_ids = self.input_ids[:num_tokens]
+         req_pool_indices = self.req_pool_indices[:bs]
+         seq_lens = self.seq_lens[:bs]
+         out_cache_loc = self.out_cache_loc[:num_tokens]
+         positions = self.positions[:num_tokens]
+         mrope_positions = self.mrope_positions[:, :bs]
+         self.num_token_non_padded[...] = num_tokens
+
+         spec_info = self.get_spec_info(num_tokens)
+         if self.capture_hidden_mode != CaptureHiddenMode.FULL:
+             self.capture_hidden_mode = (
+                 spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
+             )
+
+         forward_batch = ForwardBatch(
+             forward_mode=self.capture_forward_mode,
+             batch_size=bs,
+             input_ids=input_ids,
+             req_pool_indices=req_pool_indices,
+             seq_lens=seq_lens,
+             req_to_token_pool=self.model_runner.req_to_token_pool,
+             token_to_kv_pool=self.model_runner.token_to_kv_pool,
+             attn_backend=self.model_runner.attn_backend,
+             out_cache_loc=out_cache_loc,
+             seq_lens_sum=seq_lens.sum().item(),
+             return_logprob=False,
+             positions=positions,
+             mrope_positions=mrope_positions,
+             spec_algorithm=self.model_runner.spec_algorithm,
+             spec_info=spec_info,
+             capture_hidden_mode=self.capture_hidden_mode,
+             num_token_non_padded=self.num_token_non_padded,
+             global_forward_mode=self.capture_forward_mode,
+         )
+
+         # Attention backend
+         self.model_runner.attn_backend.init_forward_metadata(forward_batch)
+         # Do inference once to avoid setting attrs at runtime, e.g.,
+         # self.attn_mha.kv_b_proj = self.kv_b_proj, for full graph compile on CPU
+         self.model_runner.model.forward(
+             forward_batch.input_ids,
+             forward_batch.positions,
+             forward_batch,
+         )
+
+         # Run and capture
+         def run_once():
+             # Clean intermediate result cache for DP attention
+             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+             logits_output_or_pp_proxy_tensors = forward(
+                 input_ids,
+                 forward_batch.positions,
+                 forward_batch,
+             )
+             return logits_output_or_pp_proxy_tensors
+
+         with torch.no_grad():
+             for _ in range(2):
+                 self.model_runner.tp_group.barrier()
+                 out = run_once()
+         return forward, out
+
+     def recapture_if_needed(self, forward_batch: ForwardBatch):
+
+         # If the required capture_hidden_mode changes, we need to recapture the graph
+
+         # These are the different factors that can influence the capture_hidden_mode
+         capture_hidden_mode_required_by_forward_batch = (
+             forward_batch.capture_hidden_mode
+         )
+         capture_hidden_mode_required_by_spec_info = getattr(
+             forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
+         )
+         capture_hidden_mode_required_for_returning_hidden_states = (
+             CaptureHiddenMode.FULL
+             if self.model_runner.server_args.enable_return_hidden_states
+             else CaptureHiddenMode.NULL
+         )
+
+         # Determine the highest capture_hidden_mode required
+         # (If we have FULL, we can emulate LAST or NULL)
+         # (If we have LAST, we can emulate NULL)
+         required_capture_hidden_mode = max(
+             capture_hidden_mode_required_by_forward_batch,
+             capture_hidden_mode_required_by_spec_info,
+             capture_hidden_mode_required_for_returning_hidden_states,
+         )
+
+         # If the current hidden mode is no longer aligned with the required hidden mode, we need to set it to what is required and re-capture
+         if self.capture_hidden_mode != required_capture_hidden_mode:
+             self.capture_hidden_mode = required_capture_hidden_mode
+             self.capture()
+
+     # TODO add padding support for CPUGraphRunner
+     def replay(
+         self,
+         forward_batch: ForwardBatch,
+         skip_attn_backend_init: bool = False,
+         pp_proxy_tensors: Optional[PPProxyTensors] = None,
+     ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+         assert (
+             pp_proxy_tensors is None
+         ), "PPProxyTensors is not supported in CPUGraphRunner yet."
+         self.recapture_if_needed(forward_batch)
+         self.model_runner.attn_backend.init_forward_metadata(forward_batch)
+         output = self.graphs[forward_batch.batch_size](
+             forward_batch.input_ids,
+             forward_batch.positions,
+             forward_batch,
+         )
+         return output
+
+     def get_spec_info(self, num_tokens: int):
+         spec_info = None
+         if self.model_runner.spec_algorithm.is_eagle():
+             from sglang.srt.speculative.eagle_info import EagleVerifyInput
+
+             if self.model_runner.is_draft_worker:
+                 raise RuntimeError("This should not happen.")
+             else:
+                 spec_info = EagleVerifyInput(
+                     draft_token=None,
+                     custom_mask=self.custom_mask,
+                     positions=None,
+                     retrive_index=None,
+                     retrive_next_token=None,
+                     retrive_next_sibling=None,
+                     retrive_cum_len=None,
+                     spec_steps=self.model_runner.server_args.speculative_num_steps,
+                     topk=self.model_runner.server_args.speculative_eagle_topk,
+                     draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
+                     capture_hidden_mode=CaptureHiddenMode.FULL,
+                     seq_lens_sum=None,
+                     seq_lens_cpu=None,
+                 )
+
+         return spec_info
+
+
+ CPU_GRAPH_CAPTURE_FAILED_MSG = (
+     "Possible solutions:\n"
+     "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
+     "2. set --torch-compile-max-bs to a smaller value (e.g., 8)\n"
+     "3. disable torch compile by not using --enable-torch-compile\n"
+     "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
+ )
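
For context on the mechanism register_fake_ops relies on: torch.compile cannot trace through opaque custom kernels unless each operator has a fake (meta) implementation describing its output shapes and dtypes without doing any real computation. Below is a minimal self-contained sketch of the same pattern (requires PyTorch 2.4+), using a hypothetical demo:: operator in place of the real sgl_kernel kernels:

    import torch

    # Hypothetical stand-in for a sgl_kernel CPU kernel.
    @torch.library.custom_op("demo::silu_and_mul", mutates_args=())
    def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
        gate, up = x.chunk(2, dim=-1)
        return torch.nn.functional.silu(gate) * up

    # Fake/meta implementation: returns an empty tensor of the right shape and
    # dtype, mirroring the sgl_kernel::silu_and_mul_cpu fake in the diff above.
    @torch.library.register_fake("demo::silu_and_mul")
    def _(x):
        return x.new_empty(x.shape[0], x.shape[1] // 2)

    def forward(x):
        return silu_and_mul(x)

    compiled = torch.compile(forward, dynamic=False)
    print(compiled(torch.randn(4, 16)).shape)  # torch.Size([4, 8])

This is the contract each fake in the new file satisfies (correctly shaped empty outputs), which lets Dynamo/Inductor compile the surrounding decode graph while the real kernels run only at execution time.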