sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/model_executor/npu_graph_runner.py
+++ b/sglang/srt/model_executor/npu_graph_runner.py
@@ -19,10 +19,9 @@ import logging
 import threading
 from typing import TYPE_CHECKING, Optional, Union
 
-import numpy as np
 import torch
 
-from sglang.srt.configs.model_config import AttentionArch
+from sglang.srt.configs.model_config import is_deepseek_nsa
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 
 logger = logging.getLogger(__name__)
@@ -75,7 +74,7 @@ class NPUGraphRunner(CudaGraphRunner):
         self.positions[: self.raw_num_token].copy_(forward_batch.positions)
 
         # Replay
-        if self.model_runner.model_config.index_head_dim is None:
+        if not is_deepseek_nsa(self.model_runner.model_config.hf_config):
             seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (
                 self.bs - self.raw_bs
             )
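
Note: the two `npu_graph_runner.py` hunks above replace a direct probe of `model_config.index_head_dim` with the new `is_deepseek_nsa` helper from `sglang.srt.configs.model_config`. A minimal sketch of the intent, assuming the helper still keys off the same `index_head_dim` field the old inline check read (the shipped body in 0.5.4.post1 may consult more of the HF config):

```python
# Hedged sketch, not the shipped implementation: DeepSeek NSA (Native
# Sparse Attention) checkpoints are assumed to be the ones whose HF config
# carries `index_head_dim`, the attribute the old inline check inspected.
def is_deepseek_nsa_sketch(hf_config) -> bool:
    return getattr(hf_config, "index_head_dim", None) is not None


class _NsaConfigStub:
    index_head_dim = 128  # illustrative value for an NSA indexer head dim


assert is_deepseek_nsa_sketch(_NsaConfigStub())  # NSA-style config -> True
assert not is_deepseek_nsa_sketch(object())      # plain config -> False
```

Centralizing the check behind a named predicate keeps NSA detection consistent across backends instead of scattering attribute probes through each graph runner.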
--- /dev/null
+++ b/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
@@ -0,0 +1,549 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Run the model with cuda graph and torch.compile."""
+
+from __future__ import annotations
+
+import bisect
+import gc
+import logging
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Union
+
+import torch
+import tqdm
+
+from sglang.srt.compilation.compilation_config import CompilationConfig
+from sglang.srt.compilation.compile import install_torch_compiled, set_compiled
+from sglang.srt.compilation.piecewise_context_manager import set_forward_context
+from sglang.srt.custom_op import CustomOp
+from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    set_graph_pool_id,
+)
+from sglang.srt.layers.dp_attention import (
+    DpPaddingMode,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    set_dp_buffer_len,
+    set_is_extend_in_batch,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.torchao_utils import save_gemlite_cache
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
+from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
+from sglang.srt.utils import get_available_gpu_memory, log_info_on_rank0
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+# Detect whether the current forward pass is in capture mode
+is_capture_mode = False
+
+
+def get_is_capture_mode():
+    return is_capture_mode
+
+
+@contextmanager
+def model_capture_mode():
+    global is_capture_mode
+    is_capture_mode = True
+
+    yield
+
+    is_capture_mode = False
+
+
+@contextmanager
+def freeze_gc(enable_cudagraph_gc: bool):
+    """
+    Optimize garbage collection during CUDA graph capture.
+    Clean up, then freeze all remaining objects from being included
+    in future collections if GC is disabled during capture.
+    """
+    gc.collect()
+    should_freeze = not enable_cudagraph_gc
+    if should_freeze:
+        gc.freeze()
+    try:
+        yield
+    finally:
+        if should_freeze:
+            gc.unfreeze()
+
+
+def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
+    for sub in model._modules.values():
+        if isinstance(sub, CustomOp):
+            if reverse:
+                sub.leave_torch_compile()
+            else:
+                sub.enter_torch_compile(num_tokens=num_tokens)
+        if isinstance(sub, torch.nn.Module):
+            _to_torch(sub, reverse, num_tokens)
+
+
+@contextmanager
+def patch_model(model: torch.nn.Module, compiler: str):
+    try:
+        if compiler != "eager":
+            _to_torch(model, reverse=False, num_tokens=16)
+        yield model
+    finally:
+        _to_torch(model, reverse=True, num_tokens=16)
+
+
+# Reuse this memory pool across all cuda graph runners.
+global_graph_memory_pool = None
+
+
+def get_global_graph_memory_pool():
+    return global_graph_memory_pool
+
+
+def set_global_graph_memory_pool(val):
+    global global_graph_memory_pool
+    global_graph_memory_pool = val
+
+
+class PiecewiseCudaGraphRunner:
+    """A PiecewiseCudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
+
+    def __init__(self, model_runner: ModelRunner):
+        # Parse args
+        self.model_runner = model_runner
+        self.device = model_runner.device
+        self.device_module = torch.get_device_module(self.device)
+        self.graphs = {}
+        self.output_buffers = {}
+        self.tp_size = model_runner.server_args.tp_size
+        self.dp_size = model_runner.server_args.dp_size
+        self.pp_size = model_runner.server_args.pp_size
+
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
+
+        assert (
+            self.model_runner.server_args.piecewise_cuda_graph_tokens is not None
+        ), "piecewise_cuda_graph_tokens is not set"
+        assert self.model_runner.server_args.piecewise_cuda_graph_compiler in [
+            "eager",
+            "inductor",
+        ], "By now, only eager and inductor are supported for piecewise cuda graph compiler."
+        self.compile_config = CompilationConfig(
+            self.model_runner.server_args.piecewise_cuda_graph_tokens,
+            self.model_runner.server_args.piecewise_cuda_graph_compiler,
+        )
+
+        # Batch sizes to capture
+        self.capture_num_tokens = self.compile_config.get_capture_sizes()
+        log_info_on_rank0(
+            logger, f"Capture cuda graph num tokens {self.capture_num_tokens}"
+        )
+        self.capture_forward_mode = ForwardMode.EXTEND
+        self.capture_hidden_mode = CaptureHiddenMode.NULL
+
+        # If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
+        if model_runner.server_args.enable_return_hidden_states:
+            self.capture_hidden_mode = CaptureHiddenMode.FULL
+
+        # Attention backend
+        self.max_num_tokens = max(self.capture_num_tokens)
+
+        # Graph inputs
+        with torch.device(self.device):
+            self.input_ids = torch.zeros((self.max_num_tokens,), dtype=torch.int64)
+            self.out_cache_loc = torch.zeros(
+                (self.max_num_tokens,), dtype=self._cache_loc_dtype()
+            )
+            self.positions = torch.zeros((self.max_num_tokens,), dtype=torch.int64)
+            self.tbo_plugin = TboCudaGraphRunnerPlugin()
+
+        self.attention_layers = self.model_runner.attention_layers
+
+        if get_global_graph_memory_pool() is None:
+            set_global_graph_memory_pool(self.device_module.graph_pool_handle())
+        # Set graph pool id globally to be able to use symmetric memory
+        set_graph_pool_id(get_global_graph_memory_pool())
+
+        with patch_model(
+            self.model_runner.model.model, self.compile_config.compiler
+        ) as patched_model:
+            install_torch_compiled(
+                patched_model,
+                fullgraph=True,
+                dynamic_arg_dims=None,
+                compile_config=self.compile_config,
+                graph_pool=get_global_graph_memory_pool(),
+            )
+
+            with set_compiled(True):
+                self.warmup_and_capture()
+
+        # Capture
+        try:
+            with model_capture_mode():
+                self.capture()
+        except RuntimeError as e:
+            raise Exception(
+                f"Capture cuda graph failed: {e}\n{PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG}"
+            )
+
+        self.raw_num_tokens = 0
+
+    def warmup_and_capture(self):
+        num_tokens = 2
+        with torch.device(self.device):
+            forward_batch = ForwardBatch(
+                forward_mode=ForwardMode.EXTEND,
+                batch_size=1,
+                input_ids=torch.randint(0, 100, (num_tokens,), device=self.device),
+                req_pool_indices=torch.arange(1, device=self.device),
+                seq_lens=torch.tensor([num_tokens], device=self.device),
+                next_token_logits_buffer=None,
+                orig_seq_lens=torch.tensor([num_tokens], device=self.device),
+                seq_lens_cpu=torch.tensor([num_tokens]),
+                req_to_token_pool=self.model_runner.req_to_token_pool,
+                token_to_kv_pool=self.model_runner.token_to_kv_pool,
+                attn_backend=self.model_runner.attn_backend,
+                out_cache_loc=torch.randint(0, 100, (num_tokens,), device=self.device),
+                seq_lens_sum=num_tokens,
+                encoder_lens=None,
+                return_logprob=False,
+                extend_seq_lens=torch.tensor([num_tokens], device=self.device),
+                extend_prefix_lens=torch.tensor([num_tokens], device=self.device),
+                extend_start_loc=torch.tensor([0], device=self.device),
+                extend_prefix_lens_cpu=torch.tensor([num_tokens]),
+                extend_seq_lens_cpu=torch.tensor([num_tokens]),
+                extend_logprob_start_lens_cpu=torch.tensor([num_tokens]),
+                positions=torch.arange(num_tokens, device=self.device),
+                global_num_tokens_gpu=None,
+                global_num_tokens_for_logprob_gpu=None,
+                dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+                global_dp_buffer_len=None,
+                mrope_positions=None,
+                spec_algorithm=None,
+                spec_info=None,
+                capture_hidden_mode=CaptureHiddenMode.NULL,
+                num_token_non_padded=None,
+                global_forward_mode=ForwardMode.EXTEND,
+                lora_ids=None,
+            )
+
+        # Attention backend
+        self.model_runner.attn_backend.init_forward_metadata(forward_batch)
+
+        with set_forward_context(forward_batch, self.attention_layers):
+            _ = self.model_runner.model.forward(
+                forward_batch.input_ids,
+                forward_batch.positions,
+                forward_batch,
+            )
+
+    def _cache_loc_dtype(self):
+        return torch.int64
+
+    def can_run(self, forward_batch: ForwardBatch):
+        num_tokens = len(forward_batch.input_ids)
+        # TODO(yuwei): support return input_ids' logprob
+        if forward_batch.return_logprob:
+            for start_len, seq_len in zip(
+                forward_batch.extend_logprob_start_lens_cpu,
+                forward_batch.extend_seq_lens_cpu,
+            ):
+                if start_len is not None and start_len < seq_len:
+                    return False
+        if num_tokens <= self.max_num_tokens:
+            return True
+        return False
+
+    def capture(self) -> None:
+        # Trigger CUDA graph capture for specific shapes.
+        # Capture the large shapes first so that the smaller shapes
+        # can reuse the memory pool allocated for the large shapes.
+        with freeze_gc(self.model_runner.server_args.enable_cudagraph_gc):
+            if self.model_runner.tp_group.ca_comm is not None:
+                old_ca_disable = self.model_runner.tp_group.ca_comm.disabled
+                self.model_runner.tp_group.ca_comm.disabled = True
+            avail_mem = get_available_gpu_memory(
+                self.model_runner.device,
+                self.model_runner.gpu_id,
+                empty_cache=False,
+            )
+            # Reverse the order to enable better memory sharing across cuda graphs.
+            capture_range = (
+                tqdm.tqdm(list(reversed(self.capture_num_tokens)))
+                if get_tensor_model_parallel_rank() == 0
+                else reversed(self.capture_num_tokens)
+            )
+            for i, num_tokens in enumerate(capture_range):
+                if get_tensor_model_parallel_rank() == 0:
+                    avail_mem = get_available_gpu_memory(
+                        self.model_runner.device,
+                        self.model_runner.gpu_id,
+                        empty_cache=False,
+                    )
+                    capture_range.set_description(
+                        f"Capturing num tokens ({num_tokens=} {avail_mem=:.2f} GB)"
+                    )
+
+                with set_compiled(True):
+                    self.capture_one_batch_size(num_tokens)
+
+                # Save gemlite cache after each capture
+                save_gemlite_cache()
+            if self.model_runner.tp_group.ca_comm is not None:
+                self.model_runner.tp_group.ca_comm.disabled = old_ca_disable
+
+    def capture_one_batch_size(self, num_tokens: int):
+        bs = 1
+
+        # Graph inputs
+        input_ids = self.input_ids[:num_tokens]
+        out_cache_loc = self.out_cache_loc[:num_tokens]
+        positions = self.positions[:num_tokens]
+
+        # pipeline parallelism
+        if self.pp_size > 1:
+            pp_proxy_tensors = PPProxyTensors(
+                {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
+            )
+
+        global_dp_buffer_len = None
+
+        if self.model_runner.server_args.enable_lora:
+            # It is safe to capture CUDA graph using empty LoRA id, as the LoRA kernels will always be launched whenever
+            # `--enable-lora` is set to True (and return immediately if the LoRA id is empty for perf optimization).
+            lora_ids = [None] * bs
+        else:
+            lora_ids = None
+
+        with torch.device(self.device):
+            forward_batch = ForwardBatch(
+                forward_mode=ForwardMode.EXTEND,
+                batch_size=bs,
+                input_ids=input_ids,
+                req_pool_indices=torch.arange(bs, device=self.device),
+                seq_lens=torch.tensor([num_tokens], device=self.device),
+                next_token_logits_buffer=None,
+                orig_seq_lens=torch.tensor([num_tokens], device=self.device),
+                seq_lens_cpu=torch.tensor([num_tokens]),
+                req_to_token_pool=self.model_runner.req_to_token_pool,
+                token_to_kv_pool=self.model_runner.token_to_kv_pool,
+                attn_backend=self.model_runner.attn_backend,
+                out_cache_loc=out_cache_loc,
+                seq_lens_sum=num_tokens,
+                encoder_lens=None,
+                return_logprob=False,
+                extend_seq_lens=torch.tensor([num_tokens], device=self.device),
+                extend_prefix_lens=torch.tensor([num_tokens], device=self.device),
+                extend_start_loc=torch.tensor([0], device=self.device),
+                extend_prefix_lens_cpu=torch.tensor([num_tokens]),
+                extend_seq_lens_cpu=torch.tensor([num_tokens]),
+                extend_logprob_start_lens_cpu=torch.tensor([num_tokens]),
+                positions=positions,
+                global_num_tokens_gpu=None,
+                global_num_tokens_for_logprob_gpu=None,
+                dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+                global_dp_buffer_len=None,
+                mrope_positions=None,
+                spec_algorithm=None,
+                spec_info=None,
+                capture_hidden_mode=CaptureHiddenMode.NULL,
+                num_token_non_padded=None,
+                global_forward_mode=ForwardMode.EXTEND,
+                lora_ids=None,
+            )
+        self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)
+
+        if lora_ids is not None:
+            self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
+
+        # Run and capture
+        def run_once():
+            # Clean intermediate result cache for DP attention
+            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
+            # FIXME: the implementation is hacky. `is_extend_in_batch` is for determining the deepep mode.
+            # It is True in this context but we need to set it to use low latency deepep mode.
+            set_is_extend_in_batch(False)
+
+            kwargs = {}
+            with set_forward_context(forward_batch, self.attention_layers):
+                self.model_runner.model.forward(
+                    forward_batch.input_ids,
+                    forward_batch.positions,
+                    forward_batch,
+                    **kwargs,
+                )
+            return
+
+        for _ in range(2):
+            self.device_module.synchronize()
+            self.model_runner.tp_group.barrier()
+            run_once()
+
+        return
+
+    def replay_prepare(
+        self,
+        forward_batch: ForwardBatch,
+        **kwargs,
+    ):
+        num_tokens = len(forward_batch.input_ids)
+        index = bisect.bisect_left(self.capture_num_tokens, num_tokens)
+        static_num_tokens = self.capture_num_tokens[index]
+        self.raw_num_tokens = num_tokens
+        if static_num_tokens != num_tokens:
+            self.out_cache_loc.zero_()
+        bs = forward_batch.batch_size
+
+        self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
+        self.positions[:num_tokens].copy_(forward_batch.positions)
+        self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
+
+        input_ids = self.input_ids[:static_num_tokens]
+        positions = self.positions[:static_num_tokens]
+        out_cache_loc = self.out_cache_loc[:static_num_tokens]
+
+        next_token_logits_buffer = None
+        mrope_positions = None
+
+        static_forward_batch = ForwardBatch(
+            forward_mode=forward_batch.forward_mode,
+            batch_size=bs,
+            input_ids=input_ids,
+            req_pool_indices=forward_batch.req_pool_indices,
+            seq_lens=forward_batch.seq_lens,
+            next_token_logits_buffer=next_token_logits_buffer,
+            orig_seq_lens=forward_batch.orig_seq_lens,
+            seq_lens_cpu=forward_batch.seq_lens_cpu,
+            req_to_token_pool=self.model_runner.req_to_token_pool,
+            token_to_kv_pool=self.model_runner.token_to_kv_pool,
+            attn_backend=self.model_runner.attn_backend,
+            out_cache_loc=out_cache_loc,
+            seq_lens_sum=forward_batch.seq_lens_sum,
+            encoder_lens=forward_batch.encoder_lens,
+            return_logprob=False,
+            extend_seq_lens=forward_batch.extend_seq_lens,
+            extend_prefix_lens=forward_batch.extend_prefix_lens,
+            extend_start_loc=forward_batch.extend_start_loc,
+            extend_prefix_lens_cpu=forward_batch.extend_prefix_lens_cpu,
+            extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
+            extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
+            extend_num_tokens=forward_batch.extend_num_tokens,
+            extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
+            positions=positions,
+            global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
+            global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
+            dp_padding_mode=forward_batch.dp_padding_mode,
+            global_dp_buffer_len=forward_batch.global_dp_buffer_len,
+            mrope_positions=mrope_positions,
+            spec_algorithm=forward_batch.spec_algorithm,
+            spec_info=forward_batch.spec_info,
+            capture_hidden_mode=forward_batch.capture_hidden_mode,
+            num_token_non_padded=forward_batch.num_token_non_padded,
+            global_forward_mode=forward_batch.global_forward_mode,
+            lora_ids=forward_batch.lora_ids,
+            sampling_info=forward_batch.sampling_info,
+            mm_inputs=forward_batch.mm_inputs,
+            temp_scaled_logprobs=forward_batch.temp_scaled_logprobs,
+            temperature=forward_batch.temperature,
+            top_p_normalized_logprobs=forward_batch.top_p_normalized_logprobs,
+            top_p=forward_batch.top_p,
+        )
+
+        return static_forward_batch
+
+    def replay(
+        self,
+        forward_batch: ForwardBatch,
+        **kwargs,
+    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+        if self.model_runner.tp_group.ca_comm is not None:
+            old_ca_disable = self.model_runner.tp_group.ca_comm.disabled
+            self.model_runner.tp_group.ca_comm.disabled = True
+        static_forward_batch = self.replay_prepare(forward_batch, **kwargs)
+        # Replay
+        with set_forward_context(static_forward_batch, self.attention_layers):
+            with set_compiled(True):
+                output = self.model_runner.model.forward(
+                    static_forward_batch.input_ids,
+                    static_forward_batch.positions,
+                    static_forward_batch,
+                    **kwargs,
+                )
+        if isinstance(output, LogitsProcessorOutput):
+            return LogitsProcessorOutput(
+                next_token_logits=output.next_token_logits[: self.raw_num_tokens],
+                hidden_states=(
+                    output.hidden_states[: self.raw_num_tokens]
+                    if output.hidden_states is not None
+                    else None
+                ),
+            )
+        else:
+            assert isinstance(output, PPProxyTensors)
+            # TODO(Yuwei): support PP Support
+            raise NotImplementedError(
+                "PPProxyTensors is not supported in PiecewiseCudaGraphRunner yet."
+            )
+        if self.model_runner.tp_group.ca_comm is not None:
+            self.model_runner.tp_group.ca_comm.disabled = old_ca_disable
+
+    def get_spec_info(self, num_tokens: int):
+        spec_info = None
+        if (
+            self.model_runner.spec_algorithm.is_eagle()
+            or self.model_runner.spec_algorithm.is_standalone()
+        ):
+            from sglang.srt.speculative.eagle_utils import EagleVerifyInput
+
+            if self.model_runner.is_draft_worker:
+                raise RuntimeError("This should not happen.")
+            else:
+                spec_info = EagleVerifyInput(
+                    draft_token=None,
+                    custom_mask=self.custom_mask,
+                    positions=None,
+                    retrive_index=None,
+                    retrive_next_token=None,
+                    retrive_next_sibling=None,
+                    retrive_cum_len=None,
+                    spec_steps=self.model_runner.server_args.speculative_num_steps,
+                    topk=self.model_runner.server_args.speculative_eagle_topk,
+                    draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
+                    capture_hidden_mode=CaptureHiddenMode.FULL,
+                    seq_lens_sum=None,
+                    seq_lens_cpu=None,
+                )
+
+        return spec_info
+
+
+PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG = (
+    "Possible solutions:\n"
+    "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
+    "2. set --piecewise-cuda-graph-max-tokens to a smaller value (e.g., 512)\n"
+    "3. disable Piecewise CUDA graph by unset --enable-piecewise-cuda-graph\n"
+    "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
+)
--- a/sglang/srt/model_loader/__init__.py
+++ b/sglang/srt/model_loader/__init__.py
@@ -24,7 +24,7 @@ def get_model(
     load_config: LoadConfig,
     device_config: DeviceConfig,
 ) -> nn.Module:
-    loader = get_model_loader(load_config)
+    loader = get_model_loader(load_config, model_config)
     return loader.load_model(
         model_config=model_config,
         device_config=device_config,
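
The `get_model_loader` signature change above pairs with the loader work elsewhere in this release (`model_loader/loader.py` +424, new `layers/quantization/gguf.py`): passing `model_config` lets the factory choose a loader based on the model itself rather than on `load_config` alone. A sketch of that pattern with stand-in classes; the branch condition and loader names are illustrative, not the shipped selection logic:

```python
from dataclasses import dataclass


@dataclass
class _BaseLoaderStub:
    load_config: object


class DefaultLoaderStub(_BaseLoaderStub):  # stand-in, not sglang's class
    pass


class GGUFLoaderStub(_BaseLoaderStub):  # stand-in, not sglang's class
    pass


def get_model_loader_sketch(load_config, model_config):
    # With model_config available, format-specific loaders (e.g., GGUF
    # checkpoints) can be selected inside the factory instead of by callers.
    if getattr(model_config, "quantization", None) == "gguf":
        return GGUFLoaderStub(load_config)
    return DefaultLoaderStub(load_config)
```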