sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -48,18 +48,22 @@ from sglang.srt.model_executor.forward_batch_info import (
48
48
  PPProxyTensors,
49
49
  enable_num_token_non_padded,
50
50
  )
51
- from sglang.srt.patch_torch import monkey_patch_torch_compile
52
51
  from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
53
52
  from sglang.srt.utils import (
54
53
  empty_context,
55
54
  get_available_gpu_memory,
55
+ get_bool_env_var,
56
56
  get_device_memory_capacity,
57
+ is_hip,
57
58
  log_info_on_rank0,
58
59
  require_attn_tp_gather,
59
60
  require_gathered_buffer,
60
61
  require_mlp_sync,
61
62
  require_mlp_tp_gather,
62
63
  )
64
+ from sglang.srt.utils.patch_torch import monkey_patch_torch_compile
65
+
66
+ _is_hip = is_hip()
63
67
 
64
68
  logger = logging.getLogger(__name__)
65
69
 
@@ -100,6 +104,7 @@ def freeze_gc(enable_cudagraph_gc: bool):
100
104
  finally:
101
105
  if should_freeze:
102
106
  gc.unfreeze()
107
+ gc.collect()
103
108
 
104
109
 
105
110
  def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
@@ -136,7 +141,7 @@ def patch_model(
136
141
  mode=os.environ.get(
137
142
  "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
138
143
  ),
139
- dynamic=False,
144
+ dynamic=_is_hip and get_bool_env_var("SGLANG_TORCH_DYNAMIC_SHAPE"),
140
145
  )
141
146
  else:
142
147
  yield model.forward
@@ -166,29 +171,6 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
166
171
  server_args = model_runner.server_args
167
172
  capture_bs = server_args.cuda_graph_bs
168
173
 
169
- if capture_bs is None:
170
- if server_args.speculative_algorithm is None:
171
- if server_args.disable_cuda_graph_padding:
172
- capture_bs = list(range(1, 33)) + list(range(48, 161, 16))
173
- else:
174
- capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
175
- else:
176
- # Since speculative decoding requires more cuda graph memory, we
177
- # capture less.
178
- capture_bs = (
179
- list(range(1, 9))
180
- + list(range(10, 33, 2))
181
- + list(range(40, 64, 8))
182
- + list(range(80, 161, 16))
183
- )
184
-
185
- gpu_mem = get_device_memory_capacity()
186
- if gpu_mem is not None:
187
- if gpu_mem > 90 * 1024: # H200, H20
188
- capture_bs += list(range(160, 257, 8))
189
- if gpu_mem > 160 * 1000: # B200, MI300
190
- capture_bs += list(range(256, 513, 16))
191
-
192
174
  if max(capture_bs) > model_runner.req_to_token_pool.size:
193
175
  # In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
194
176
  # is very small. We add more values here to make sure we capture the maximum bs.
@@ -204,12 +186,6 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
204
186
 
205
187
  capture_bs = [bs for bs in capture_bs if bs % mul_base == 0]
206
188
 
207
- if server_args.cuda_graph_max_bs:
208
- capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
209
- if max(capture_bs) < server_args.cuda_graph_max_bs:
210
- capture_bs += list(
211
- range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16)
212
- )
213
189
  capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
214
190
  capture_bs = list(sorted(set(capture_bs)))
215
191
  assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"
@@ -274,6 +250,7 @@ class CudaGraphRunner:
274
250
  if (
275
251
  model_runner.spec_algorithm.is_eagle()
276
252
  or model_runner.spec_algorithm.is_standalone()
253
+ or model_runner.spec_algorithm.is_ngram()
277
254
  ):
278
255
  if self.model_runner.is_draft_worker:
279
256
  raise RuntimeError("This should not happen")
@@ -440,11 +417,21 @@ class CudaGraphRunner:
440
417
  forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
441
418
  )
442
419
 
420
+ is_ngram_supported = (
421
+ (
422
+ forward_batch.batch_size * self.num_tokens_per_bs
423
+ == forward_batch.input_ids.numel()
424
+ )
425
+ if self.model_runner.spec_algorithm.is_ngram()
426
+ else True
427
+ )
428
+
443
429
  return (
444
430
  is_bs_supported
445
431
  and is_encoder_lens_supported
446
432
  and is_tbo_supported
447
433
  and capture_hidden_mode_matches
434
+ and is_ngram_supported
448
435
  )
449
436
 
450
437
  def capture(self) -> None:
@@ -454,6 +441,7 @@ class CudaGraphRunner:
454
441
  activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
455
442
  record_shapes=True,
456
443
  )
444
+ torch.cuda.memory._record_memory_history()
457
445
 
458
446
  # Trigger CUDA graph capture for specific shapes.
459
447
  # Capture the large shapes first so that the smaller shapes
@@ -502,6 +490,8 @@ class CudaGraphRunner:
502
490
  save_gemlite_cache()
503
491
 
504
492
  if self.enable_profile_cuda_graph:
493
+ torch.cuda.memory._dump_snapshot(f"cuda_graph_runner_memory_usage.pickle")
494
+ torch.cuda.memory._record_memory_history(enabled=None)
505
495
  log_message = (
506
496
  "Sorted by CUDA Time:\n"
507
497
  + prof.key_averages(group_by_input_shape=True).table(
@@ -511,6 +501,7 @@ class CudaGraphRunner:
511
501
  + prof.key_averages(group_by_input_shape=True).table(
512
502
  sort_by="cpu_time_total", row_limit=10
513
503
  )
504
+ + "\n\nMemory Usage is saved to cuda_graph_runner_memory_usage.pickle\n"
514
505
  )
515
506
  logger.info(log_message)
516
507
 
@@ -531,6 +522,7 @@ class CudaGraphRunner:
531
522
  input_ids = self.input_ids[:num_tokens]
532
523
  req_pool_indices = self.req_pool_indices[:bs]
533
524
  seq_lens = self.seq_lens[:bs]
525
+ seq_lens_cpu = self.seq_lens_cpu[:bs]
534
526
  out_cache_loc = self.out_cache_loc[:num_tokens]
535
527
  positions = self.positions[:num_tokens]
536
528
  if self.is_encoder_decoder:
@@ -601,6 +593,7 @@ class CudaGraphRunner:
601
593
  input_ids=input_ids,
602
594
  req_pool_indices=req_pool_indices,
603
595
  seq_lens=seq_lens,
596
+ seq_lens_cpu=seq_lens_cpu,
604
597
  next_token_logits_buffer=next_token_logits_buffer,
605
598
  orig_seq_lens=seq_lens,
606
599
  req_to_token_pool=self.model_runner.req_to_token_pool,
@@ -834,7 +827,7 @@ class CudaGraphRunner:
834
827
  self.model_runner.spec_algorithm.is_eagle()
835
828
  or self.model_runner.spec_algorithm.is_standalone()
836
829
  ):
837
- from sglang.srt.speculative.eagle_utils import EagleVerifyInput
830
+ from sglang.srt.speculative.eagle_info import EagleVerifyInput
838
831
 
839
832
  if self.model_runner.is_draft_worker:
840
833
  raise RuntimeError("This should not happen.")
@@ -855,6 +848,20 @@ class CudaGraphRunner:
855
848
  seq_lens_cpu=None,
856
849
  )
857
850
 
851
+ elif self.model_runner.spec_algorithm.is_ngram():
852
+ from sglang.srt.speculative.ngram_utils import NgramVerifyInput
853
+
854
+ spec_info = NgramVerifyInput(
855
+ draft_token=None,
856
+ tree_mask=self.custom_mask,
857
+ positions=None,
858
+ retrive_index=None,
859
+ retrive_next_token=None,
860
+ retrive_next_sibling=None,
861
+ draft_token_num=self.num_tokens_per_bs,
862
+ )
863
+ spec_info.capture_hidden_mode = CaptureHiddenMode.NULL
864
+
858
865
  return spec_info
859
866
 
860
867
 
@@ -45,13 +45,7 @@ from sglang.srt.layers.dp_attention import (
45
45
  get_attention_tp_size,
46
46
  set_dp_buffer_len,
47
47
  )
48
- from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
49
- from sglang.srt.utils import (
50
- flatten_nested_list,
51
- get_compiler_backend,
52
- is_npu,
53
- support_triton,
54
- )
48
+ from sglang.srt.utils import get_compiler_backend, is_npu, support_triton
55
49
 
56
50
  if TYPE_CHECKING:
57
51
  from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -60,8 +54,7 @@ if TYPE_CHECKING:
60
54
  from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
61
55
  from sglang.srt.model_executor.model_runner import ModelRunner
62
56
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
63
- from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
64
- from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
57
+ from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm
65
58
 
66
59
  _is_npu = is_npu()
67
60
 
@@ -293,13 +286,14 @@ class ForwardBatch:
293
286
  global_forward_mode: Optional[ForwardMode] = None
294
287
 
295
288
  # Speculative decoding
296
- spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
289
+ spec_info: Optional[SpecInput] = None
297
290
  spec_algorithm: SpeculativeAlgorithm = None
298
291
  capture_hidden_mode: CaptureHiddenMode = None
299
292
 
300
293
  # For padding
301
294
  padded_static_len: int = -1 # -1 if not padded
302
295
  num_token_non_padded: Optional[torch.Tensor] = None # scalar tensor
296
+ num_token_non_padded_cpu: int = None
303
297
 
304
298
  # For Qwen2-VL
305
299
  mrope_positions: torch.Tensor = None
@@ -361,36 +355,18 @@ class ForwardBatch:
361
355
  ret.num_token_non_padded = torch.tensor(
362
356
  len(batch.input_ids), dtype=torch.int32
363
357
  ).to(device, non_blocking=True)
358
+ ret.num_token_non_padded_cpu = len(batch.input_ids)
364
359
 
365
360
  # For MLP sync
366
361
  if batch.global_num_tokens is not None:
367
- from sglang.srt.speculative.eagle_utils import (
368
- EagleDraftInput,
369
- EagleVerifyInput,
370
- )
371
-
372
362
  assert batch.global_num_tokens_for_logprob is not None
363
+
373
364
  # process global_num_tokens and global_num_tokens_for_logprob
374
365
  if batch.spec_info is not None:
375
- if isinstance(batch.spec_info, EagleDraftInput):
376
- global_num_tokens = [
377
- x * batch.spec_info.num_tokens_per_batch
378
- for x in batch.global_num_tokens
379
- ]
380
- global_num_tokens_for_logprob = [
381
- x * batch.spec_info.num_tokens_for_logprob_per_batch
382
- for x in batch.global_num_tokens_for_logprob
383
- ]
384
- else:
385
- assert isinstance(batch.spec_info, EagleVerifyInput)
386
- global_num_tokens = [
387
- x * batch.spec_info.draft_token_num
388
- for x in batch.global_num_tokens
389
- ]
390
- global_num_tokens_for_logprob = [
391
- x * batch.spec_info.draft_token_num
392
- for x in batch.global_num_tokens_for_logprob
393
- ]
366
+ spec_info: SpecInput = batch.spec_info
367
+ global_num_tokens, global_num_tokens_for_logprob = (
368
+ spec_info.get_spec_adjusted_global_num_tokens(batch)
369
+ )
394
370
  else:
395
371
  global_num_tokens = batch.global_num_tokens
396
372
  global_num_tokens_for_logprob = batch.global_num_tokens_for_logprob
@@ -669,9 +645,6 @@ class ForwardBatch:
669
645
  )
670
646
 
671
647
  def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
672
-
673
- from sglang.srt.speculative.eagle_utils import EagleDraftInput
674
-
675
648
  assert self.global_num_tokens_cpu is not None
676
649
  assert self.global_num_tokens_for_logprob_cpu is not None
677
650
 
@@ -768,7 +741,8 @@ class ForwardBatch:
768
741
  if self.extend_seq_lens is not None:
769
742
  self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
770
743
 
771
- if self.spec_info is not None and isinstance(self.spec_info, EagleDraftInput):
744
+ if self.spec_info is not None and self.spec_info.is_draft_input():
745
+ # FIXME(lsyin): remove this isinstance logic
772
746
  spec_info = self.spec_info
773
747
  self.output_cache_loc_backup = self.out_cache_loc
774
748
  self.hidden_states_backup = spec_info.hidden_states
@@ -928,6 +902,17 @@ class ForwardBatch:
928
902
  return self.tbo_split_seq_index is not None
929
903
 
930
904
 
905
+ @dataclass
906
+ class ForwardBatchOutput:
907
+ # FIXME(lsyin): unify the forward batch output between different spec and parallelism
908
+ # need to be more organized
909
+ logits_output: Optional[torch.Tensor] = None
910
+ next_token_ids: Optional[torch.Tensor] = None
911
+ num_accepted_tokens: Optional[int] = None
912
+ pp_proxy_tensors: Optional[PPProxyTensors] = None
913
+ can_run_cuda_graph: bool = False
914
+
915
+
931
916
  def enable_num_token_non_padded(server_args):
932
917
  return get_moe_expert_parallel_world_size() > 1
933
918