sglang-0.5.3rc2-py3-none-any.whl → sglang-0.5.4-py3-none-any.whl

This diff compares the contents of publicly released package versions as published to their respective public registries. It is provided for informational purposes only.
Files changed (408)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/model_executor/npu_graph_runner.py
+++ b/sglang/srt/model_executor/npu_graph_runner.py
@@ -19,10 +19,9 @@ import logging
 import threading
 from typing import TYPE_CHECKING, Optional, Union
 
-import numpy as np
 import torch
 
-from sglang.srt.configs.model_config import AttentionArch
+from sglang.srt.configs.model_config import is_deepseek_nsa
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 
 logger = logging.getLogger(__name__)
@@ -75,7 +74,7 @@ class NPUGraphRunner(CudaGraphRunner):
         self.positions[: self.raw_num_token].copy_(forward_batch.positions)
 
         # Replay
-        if self.model_runner.model_config.index_head_dim is None:
+        if not is_deepseek_nsa(self.model_runner.model_config.hf_config):
             seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (
                 self.bs - self.raw_bs
             )
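The replay guard above now asks a model-level question (is this a DeepSeek NSA model?) instead of probing index_head_dim directly. A minimal sketch of what such a predicate could look like, assuming NSA checkpoints are exactly the ones that carry an indexer head dimension on their HF config; the real is_deepseek_nsa lives in sglang/srt/configs/model_config.py and may check more than this:

    # Hedged sketch, not the actual sglang implementation.
    def is_deepseek_nsa(hf_config) -> bool:
        # Assumption: DeepSeek NSA (sparse-attention indexer) configs expose
        # an indexer head dim such as index_head_dim; dense models do not.
        return getattr(hf_config, "index_head_dim", None) is not None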
--- /dev/null
+++ b/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
@@ -0,0 +1,539 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Run the model with cuda graph and torch.compile."""
+
+from __future__ import annotations
+
+import bisect
+import gc
+import logging
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Union
+
+import torch
+import tqdm
+
+from sglang.srt.compilation.compilation_config import CompilationConfig
+from sglang.srt.compilation.compile import install_torch_compiled, set_compiled
+from sglang.srt.compilation.piecewise_context_manager import set_forward_context
+from sglang.srt.custom_op import CustomOp
+from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    set_graph_pool_id,
+)
+from sglang.srt.distributed.parallel_state import graph_capture
+from sglang.srt.layers.dp_attention import (
+    DpPaddingMode,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    set_dp_buffer_len,
+    set_is_extend_in_batch,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.torchao_utils import save_gemlite_cache
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
+from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
+from sglang.srt.utils import get_available_gpu_memory, log_info_on_rank0
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+# Detect whether the current forward pass is in capture mode
+is_capture_mode = False
+
+
+def get_is_capture_mode():
+    return is_capture_mode
+
+
+@contextmanager
+def model_capture_mode():
+    global is_capture_mode
+    is_capture_mode = True
+
+    yield
+
+    is_capture_mode = False
+
+
+@contextmanager
+def freeze_gc(enable_cudagraph_gc: bool):
+    """
+    Optimize garbage collection during CUDA graph capture.
+    Clean up, then freeze all remaining objects from being included
+    in future collections if GC is disabled during capture.
+    """
+    gc.collect()
+    should_freeze = not enable_cudagraph_gc
+    if should_freeze:
+        gc.freeze()
+    try:
+        yield
+    finally:
+        if should_freeze:
+            gc.unfreeze()
+
+
+def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
+    for sub in model._modules.values():
+        if isinstance(sub, CustomOp):
+            if reverse:
+                sub.leave_torch_compile()
+            else:
+                sub.enter_torch_compile(num_tokens=num_tokens)
+        if isinstance(sub, torch.nn.Module):
+            _to_torch(sub, reverse, num_tokens)
+
+
+@contextmanager
+def patch_model(model: torch.nn.Module, compiler: str):
+    try:
+        if compiler != "eager":
+            _to_torch(model, reverse=False, num_tokens=16)
+        yield model
+    finally:
+        _to_torch(model, reverse=True, num_tokens=16)
+
+
+# Reuse this memory pool across all cuda graph runners.
+global_graph_memory_pool = None
+
+
+def get_global_graph_memory_pool():
+    return global_graph_memory_pool
+
+
+def set_global_graph_memory_pool(val):
+    global global_graph_memory_pool
+    global_graph_memory_pool = val
+
+
+class PiecewiseCudaGraphRunner:
+    """A PiecewiseCudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
+
+    def __init__(self, model_runner: ModelRunner):
+        # Parse args
+        self.model_runner = model_runner
+        self.device = model_runner.device
+        self.device_module = torch.get_device_module(self.device)
+        self.graphs = {}
+        self.output_buffers = {}
+        self.tp_size = model_runner.server_args.tp_size
+        self.dp_size = model_runner.server_args.dp_size
+        self.pp_size = model_runner.server_args.pp_size
+
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
+
+        assert (
+            self.model_runner.server_args.piecewise_cuda_graph_tokens is not None
+        ), "piecewise_cuda_graph_tokens is not set"
+        assert self.model_runner.server_args.piecewise_cuda_graph_compiler in [
+            "eager",
+            "inductor",
+        ], "By now, only eager and inductor are supported for piecewise cuda graph compiler."
+        self.compile_config = CompilationConfig(
+            self.model_runner.server_args.piecewise_cuda_graph_tokens,
+            self.model_runner.server_args.piecewise_cuda_graph_compiler,
+        )
+
+        # Batch sizes to capture
+        self.capture_num_tokens = self.compile_config.get_capture_sizes()
+        log_info_on_rank0(
+            logger, f"Capture cuda graph num tokens {self.capture_num_tokens}"
+        )
+        self.capture_forward_mode = ForwardMode.EXTEND
+        self.capture_hidden_mode = CaptureHiddenMode.NULL
+
+        # If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
+        if model_runner.server_args.enable_return_hidden_states:
+            self.capture_hidden_mode = CaptureHiddenMode.FULL
+
+        # Attention backend
+        self.max_num_tokens = max(self.capture_num_tokens)
+
+        # Graph inputs
+        with torch.device(self.device):
+            self.input_ids = torch.zeros((self.max_num_tokens,), dtype=torch.int64)
+            self.out_cache_loc = torch.zeros(
+                (self.max_num_tokens,), dtype=self._cache_loc_dtype()
+            )
+            self.positions = torch.zeros((self.max_num_tokens,), dtype=torch.int64)
+            self.tbo_plugin = TboCudaGraphRunnerPlugin()
+
+        self.attention_layers = self.model_runner.attention_layers
+
+        if get_global_graph_memory_pool() is None:
+            set_global_graph_memory_pool(self.device_module.graph_pool_handle())
+        # Set graph pool id globally to be able to use symmetric memory
+        set_graph_pool_id(get_global_graph_memory_pool())
+
+        with patch_model(
+            self.model_runner.model.model, self.compile_config.compiler
+        ) as patched_model:
+            install_torch_compiled(
+                patched_model,
+                fullgraph=True,
+                dynamic_arg_dims=None,
+                compile_config=self.compile_config,
+                graph_pool=get_global_graph_memory_pool(),
+            )
+
+            with set_compiled(True):
+                self.warmup_and_capture()
+
+        # Capture
+        try:
+            with model_capture_mode():
+                self.capture()
+        except RuntimeError as e:
+            raise Exception(
+                f"Capture cuda graph failed: {e}\n{PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG}"
+            )
+
+        self.raw_num_tokens = 0
+
+    def warmup_and_capture(self):
+        num_tokens = 2
+        with torch.device(self.device):
+            forward_batch = ForwardBatch(
+                forward_mode=ForwardMode.EXTEND,
+                batch_size=1,
+                input_ids=torch.randint(0, 100, (num_tokens,), device=self.device),
+                req_pool_indices=torch.arange(1, device=self.device),
+                seq_lens=torch.tensor([num_tokens], device=self.device),
+                next_token_logits_buffer=None,
+                orig_seq_lens=torch.tensor([num_tokens], device=self.device),
+                seq_lens_cpu=torch.tensor([num_tokens]),
+                req_to_token_pool=self.model_runner.req_to_token_pool,
+                token_to_kv_pool=self.model_runner.token_to_kv_pool,
+                attn_backend=self.model_runner.attn_backend,
+                out_cache_loc=torch.randint(0, 100, (num_tokens,), device=self.device),
+                seq_lens_sum=num_tokens,
+                encoder_lens=None,
+                return_logprob=False,
+                extend_seq_lens=torch.tensor([num_tokens], device=self.device),
+                extend_prefix_lens=torch.tensor([num_tokens], device=self.device),
+                extend_start_loc=torch.tensor([0], device=self.device),
+                extend_prefix_lens_cpu=torch.tensor([num_tokens]),
+                extend_seq_lens_cpu=torch.tensor([num_tokens]),
+                extend_logprob_start_lens_cpu=torch.tensor([num_tokens]),
+                positions=torch.arange(num_tokens, device=self.device),
+                global_num_tokens_gpu=None,
+                global_num_tokens_for_logprob_gpu=None,
+                dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+                global_dp_buffer_len=None,
+                mrope_positions=None,
+                spec_algorithm=None,
+                spec_info=None,
+                capture_hidden_mode=CaptureHiddenMode.NULL,
+                num_token_non_padded=None,
+                global_forward_mode=ForwardMode.EXTEND,
+                lora_ids=None,
+            )
+
+        with set_forward_context(forward_batch, self.attention_layers):
+            _ = self.model_runner.model.forward(
+                forward_batch.input_ids,
+                forward_batch.positions,
+                forward_batch,
+            )
+
+    def _cache_loc_dtype(self):
+        return torch.int64
+
+    def can_run(self, forward_batch: ForwardBatch):
+        num_tokens = len(forward_batch.input_ids)
+        # TODO(yuwei): support return logprob
+        if forward_batch.return_logprob:
+            return False
+        if num_tokens <= self.max_num_tokens:
+            return True
+        return False
+
+    def capture(self) -> None:
+        # Trigger CUDA graph capture for specific shapes.
+        # Capture the large shapes first so that the smaller shapes
+        # can reuse the memory pool allocated for the large shapes.
+        with freeze_gc(
+            self.model_runner.server_args.enable_cudagraph_gc
+        ), graph_capture() as graph_capture_context:
+            self.stream = graph_capture_context.stream
+            avail_mem = get_available_gpu_memory(
+                self.model_runner.device,
+                self.model_runner.gpu_id,
+                empty_cache=False,
+            )
+            # Reverse the order to enable better memory sharing across cuda graphs.
+            capture_range = (
+                tqdm.tqdm(list(reversed(self.capture_num_tokens)))
+                if get_tensor_model_parallel_rank() == 0
+                else reversed(self.capture_num_tokens)
+            )
+            for i, num_tokens in enumerate(capture_range):
+                if get_tensor_model_parallel_rank() == 0:
+                    avail_mem = get_available_gpu_memory(
+                        self.model_runner.device,
+                        self.model_runner.gpu_id,
+                        empty_cache=False,
+                    )
+                    capture_range.set_description(
+                        f"Capturing num tokens ({num_tokens=} {avail_mem=:.2f} GB)"
+                    )
+
+                with set_compiled(True):
+                    self.capture_one_batch_size(num_tokens)
+
+                # Save gemlite cache after each capture
+                save_gemlite_cache()
+
+    def capture_one_batch_size(self, num_tokens: int):
+        stream = self.stream
+        bs = 1
+
+        # Graph inputs
+        input_ids = self.input_ids[:num_tokens]
+        out_cache_loc = self.out_cache_loc[:num_tokens]
+        positions = self.positions[:num_tokens]
+
+        # pipeline parallelism
+        if self.pp_size > 1:
+            pp_proxy_tensors = PPProxyTensors(
+                {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
+            )
+
+        global_dp_buffer_len = None
+
+        if self.model_runner.server_args.enable_lora:
+            # It is safe to capture CUDA graph using empty LoRA id, as the LoRA kernels will always be launched whenever
+            # `--enable-lora` is set to True (and return immediately if the LoRA id is empty for perf optimization).
+            lora_ids = [None] * bs
+        else:
+            lora_ids = None
+
+        with torch.device(self.device):
+            forward_batch = ForwardBatch(
+                forward_mode=ForwardMode.EXTEND,
+                batch_size=bs,
+                input_ids=input_ids,
+                req_pool_indices=torch.arange(bs, device=self.device),
+                seq_lens=torch.tensor([num_tokens], device=self.device),
+                next_token_logits_buffer=None,
+                orig_seq_lens=torch.tensor([num_tokens], device=self.device),
+                seq_lens_cpu=torch.tensor([num_tokens]),
+                req_to_token_pool=self.model_runner.req_to_token_pool,
+                token_to_kv_pool=self.model_runner.token_to_kv_pool,
+                attn_backend=self.model_runner.attn_backend,
+                out_cache_loc=out_cache_loc,
+                seq_lens_sum=num_tokens,
+                encoder_lens=None,
+                return_logprob=False,
+                extend_seq_lens=torch.tensor([num_tokens], device=self.device),
+                extend_prefix_lens=torch.tensor([num_tokens], device=self.device),
+                extend_start_loc=torch.tensor([0], device=self.device),
+                extend_prefix_lens_cpu=torch.tensor([num_tokens]),
+                extend_seq_lens_cpu=torch.tensor([num_tokens]),
+                extend_logprob_start_lens_cpu=torch.tensor([num_tokens]),
+                positions=positions,
+                global_num_tokens_gpu=None,
+                global_num_tokens_for_logprob_gpu=None,
+                dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+                global_dp_buffer_len=None,
+                mrope_positions=None,
+                spec_algorithm=None,
+                spec_info=None,
+                capture_hidden_mode=CaptureHiddenMode.NULL,
+                num_token_non_padded=None,
+                global_forward_mode=ForwardMode.EXTEND,
+                lora_ids=None,
+            )
+        self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)
+
+        if lora_ids is not None:
+            self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
+
+        # # Attention backend
+        self.model_runner.attn_backend.init_forward_metadata(forward_batch)
+
+        # Run and capture
+        def run_once():
+            # Clean intermediate result cache for DP attention
+            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
+            # FIXME: the implementation is hacky. `is_extend_in_batch`` is for determining the deepep mode.
+            # It is True in this context but we need to set it to use low latency deepep mode.
+            set_is_extend_in_batch(False)
+
+            kwargs = {}
+            with set_forward_context(forward_batch, self.attention_layers):
+                self.model_runner.model.forward(
+                    forward_batch.input_ids,
+                    forward_batch.positions,
+                    forward_batch,
+                    **kwargs,
+                )
+            return
+
+        for _ in range(2):
+            self.device_module.synchronize()
+            self.model_runner.tp_group.barrier()
+            run_once()
+
+        return
+
+    def replay_prepare(
+        self,
+        forward_batch: ForwardBatch,
+        **kwargs,
+    ):
+        num_tokens = len(forward_batch.input_ids)
+        index = bisect.bisect_left(self.capture_num_tokens, num_tokens)
+        static_num_tokens = self.capture_num_tokens[index]
+        self.raw_num_tokens = num_tokens
+        if static_num_tokens != num_tokens:
+            self.out_cache_loc.zero_()
+        bs = forward_batch.batch_size
+
+        self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
+        self.positions[:num_tokens].copy_(forward_batch.positions)
+        self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
+
+        input_ids = self.input_ids[:static_num_tokens]
+        positions = self.positions[:static_num_tokens]
+        out_cache_loc = self.out_cache_loc[:static_num_tokens]
+
+        next_token_logits_buffer = None
+        mrope_positions = None
+
+        static_forward_batch = ForwardBatch(
+            forward_mode=forward_batch.forward_mode,
+            batch_size=bs,
+            input_ids=input_ids,
+            req_pool_indices=forward_batch.req_pool_indices,
+            seq_lens=forward_batch.seq_lens,
+            next_token_logits_buffer=next_token_logits_buffer,
+            orig_seq_lens=forward_batch.orig_seq_lens,
+            seq_lens_cpu=forward_batch.seq_lens_cpu,
+            req_to_token_pool=self.model_runner.req_to_token_pool,
+            token_to_kv_pool=self.model_runner.token_to_kv_pool,
+            attn_backend=self.model_runner.attn_backend,
+            out_cache_loc=out_cache_loc,
+            seq_lens_sum=forward_batch.seq_lens_sum,
+            encoder_lens=forward_batch.encoder_lens,
+            return_logprob=forward_batch.return_logprob,
+            extend_seq_lens=forward_batch.extend_seq_lens,
+            extend_prefix_lens=forward_batch.extend_prefix_lens,
+            extend_start_loc=forward_batch.extend_start_loc,
+            extend_prefix_lens_cpu=forward_batch.extend_prefix_lens_cpu,
+            extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
+            extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
+            extend_num_tokens=forward_batch.extend_num_tokens,
+            extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
+            positions=positions,
+            global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
+            global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
+            dp_padding_mode=forward_batch.dp_padding_mode,
+            global_dp_buffer_len=forward_batch.global_dp_buffer_len,
+            mrope_positions=mrope_positions,
+            spec_algorithm=forward_batch.spec_algorithm,
+            spec_info=forward_batch.spec_info,
+            capture_hidden_mode=forward_batch.capture_hidden_mode,
+            num_token_non_padded=forward_batch.num_token_non_padded,
+            global_forward_mode=forward_batch.global_forward_mode,
+            lora_ids=forward_batch.lora_ids,
+            sampling_info=forward_batch.sampling_info,
+            mm_inputs=forward_batch.mm_inputs,
+            temp_scaled_logprobs=forward_batch.temp_scaled_logprobs,
+            temperature=forward_batch.temperature,
+            top_p_normalized_logprobs=forward_batch.top_p_normalized_logprobs,
+            top_p=forward_batch.top_p,
+        )
+
+        return static_forward_batch
+
+    def replay(
+        self,
+        forward_batch: ForwardBatch,
+        **kwargs,
+    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+        static_forward_batch = self.replay_prepare(forward_batch, **kwargs)
+        # Replay
+        with set_forward_context(static_forward_batch, self.attention_layers):
+            with set_compiled(True):
+                output = self.model_runner.model.forward(
+                    static_forward_batch.input_ids,
+                    static_forward_batch.positions,
+                    static_forward_batch,
+                    **kwargs,
+                )
+        if isinstance(output, LogitsProcessorOutput):
+            return LogitsProcessorOutput(
+                next_token_logits=output.next_token_logits[: self.raw_num_tokens],
+                hidden_states=(
+                    output.hidden_states[: self.raw_num_tokens]
+                    if output.hidden_states is not None
+                    else None
+                ),
+            )
+        else:
+            assert isinstance(output, PPProxyTensors)
+            # TODO(Yuwei): support PP Support
+            raise NotImplementedError(
+                "PPProxyTensors is not supported in PiecewiseCudaGraphRunner yet."
+            )
+
+    def get_spec_info(self, num_tokens: int):
+        spec_info = None
+        if (
+            self.model_runner.spec_algorithm.is_eagle()
+            or self.model_runner.spec_algorithm.is_standalone()
+        ):
+            from sglang.srt.speculative.eagle_utils import EagleVerifyInput
+
+            if self.model_runner.is_draft_worker:
+                raise RuntimeError("This should not happen.")
+            else:
+                spec_info = EagleVerifyInput(
+                    draft_token=None,
+                    custom_mask=self.custom_mask,
+                    positions=None,
+                    retrive_index=None,
+                    retrive_next_token=None,
+                    retrive_next_sibling=None,
+                    retrive_cum_len=None,
+                    spec_steps=self.model_runner.server_args.speculative_num_steps,
+                    topk=self.model_runner.server_args.speculative_eagle_topk,
+                    draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
+                    capture_hidden_mode=CaptureHiddenMode.FULL,
+                    seq_lens_sum=None,
+                    seq_lens_cpu=None,
+                )
+
+        return spec_info
+
+
+PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG = (
+    "Possible solutions:\n"
+    "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
+    "2. set --piecewise-cuda-graph-max-tokens to a smaller value (e.g., 512)\n"
+    "3. disable Piecewise CUDA graph by unset --enable-piecewise-cuda-graph\n"
+    "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
+)
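The new runner keeps the can_run/replay contract of the existing CudaGraphRunner: can_run rejects batches that return logprobs or exceed the captured token budget, and replay copies the live batch into padded static buffers, replays the piecewise graphs, then slices outputs back to raw_num_tokens. A sketch of how a model runner might dispatch to it; only can_run and replay come from the file above, the surrounding wiring is assumed:

    # Illustrative dispatch; model_runner and the fallback path are assumed context.
    runner = PiecewiseCudaGraphRunner(model_runner)

    def forward_extend(forward_batch):
        if runner.can_run(forward_batch):
            # Padded piecewise-graph replay; replay() trims outputs back
            # to the real token count before returning.
            return runner.replay(forward_batch)
        # Fall back to the normal forward for oversized batches or
        # requests that need logprobs (can_run returns False for those).
        return model_runner.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )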
--- a/sglang/srt/model_loader/__init__.py
+++ b/sglang/srt/model_loader/__init__.py
@@ -24,7 +24,7 @@ def get_model(
     load_config: LoadConfig,
     device_config: DeviceConfig,
 ) -> nn.Module:
-    loader = get_model_loader(load_config)
+    loader = get_model_loader(load_config, model_config)
     return loader.load_model(
         model_config=model_config,
        device_config=device_config,
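get_model now threads the model config into the loader factory, so loader selection can depend on model-level properties (for example quantization or checkpoint layout) rather than the LoadConfig alone. A hypothetical factory shape consistent with the new call; the helper and loader names below are placeholders, not sglang's actual implementation:

    # Placeholder sketch; the real factory lives in sglang/srt/model_loader/loader.py.
    def get_model_loader(load_config: LoadConfig, model_config: ModelConfig):
        if needs_special_loader(model_config):      # hypothetical predicate
            return SpecialModelLoader(load_config)  # hypothetical loader
        return DefaultModelLoader(load_config)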