sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -48,18 +48,22 @@ from sglang.srt.model_executor.forward_batch_info import (
48
48
  PPProxyTensors,
49
49
  enable_num_token_non_padded,
50
50
  )
51
- from sglang.srt.patch_torch import monkey_patch_torch_compile
52
51
  from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
53
52
  from sglang.srt.utils import (
54
53
  empty_context,
55
54
  get_available_gpu_memory,
55
+ get_bool_env_var,
56
56
  get_device_memory_capacity,
57
+ is_hip,
57
58
  log_info_on_rank0,
58
59
  require_attn_tp_gather,
59
60
  require_gathered_buffer,
60
61
  require_mlp_sync,
61
62
  require_mlp_tp_gather,
62
63
  )
64
+ from sglang.srt.utils.patch_torch import monkey_patch_torch_compile
65
+
66
+ _is_hip = is_hip()
63
67
 
64
68
  logger = logging.getLogger(__name__)
65
69
 
@@ -100,6 +104,7 @@ def freeze_gc(enable_cudagraph_gc: bool):
100
104
  finally:
101
105
  if should_freeze:
102
106
  gc.unfreeze()
107
+ gc.collect()
103
108
 
104
109
 
105
110
  def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
@@ -136,7 +141,7 @@ def patch_model(
136
141
  mode=os.environ.get(
137
142
  "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
138
143
  ),
139
- dynamic=False,
144
+ dynamic=_is_hip and get_bool_env_var("SGLANG_TORCH_DYNAMIC_SHAPE"),
140
145
  )
141
146
  else:
142
147
  yield model.forward
@@ -166,29 +171,6 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
166
171
  server_args = model_runner.server_args
167
172
  capture_bs = server_args.cuda_graph_bs
168
173
 
169
- if capture_bs is None:
170
- if server_args.speculative_algorithm is None:
171
- if server_args.disable_cuda_graph_padding:
172
- capture_bs = list(range(1, 33)) + list(range(48, 161, 16))
173
- else:
174
- capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
175
- else:
176
- # Since speculative decoding requires more cuda graph memory, we
177
- # capture less.
178
- capture_bs = (
179
- list(range(1, 9))
180
- + list(range(10, 33, 2))
181
- + list(range(40, 64, 8))
182
- + list(range(80, 161, 16))
183
- )
184
-
185
- gpu_mem = get_device_memory_capacity()
186
- if gpu_mem is not None:
187
- if gpu_mem > 90 * 1024: # H200, H20
188
- capture_bs += list(range(160, 257, 8))
189
- if gpu_mem > 160 * 1000: # B200, MI300
190
- capture_bs += list(range(256, 513, 16))
191
-
192
174
  if max(capture_bs) > model_runner.req_to_token_pool.size:
193
175
  # In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
194
176
  # is very small. We add more values here to make sure we capture the maximum bs.
@@ -204,12 +186,6 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
204
186
 
205
187
  capture_bs = [bs for bs in capture_bs if bs % mul_base == 0]
206
188
 
207
- if server_args.cuda_graph_max_bs:
208
- capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
209
- if max(capture_bs) < server_args.cuda_graph_max_bs:
210
- capture_bs += list(
211
- range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16)
212
- )
213
189
  capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
214
190
  capture_bs = list(sorted(set(capture_bs)))
215
191
  assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"
@@ -271,7 +247,11 @@ class CudaGraphRunner:
271
247
  self.capture_forward_mode = ForwardMode.DECODE
272
248
  self.capture_hidden_mode = CaptureHiddenMode.NULL
273
249
  self.num_tokens_per_bs = 1
274
- if model_runner.spec_algorithm.is_eagle():
250
+ if (
251
+ model_runner.spec_algorithm.is_eagle()
252
+ or model_runner.spec_algorithm.is_standalone()
253
+ or model_runner.spec_algorithm.is_ngram()
254
+ ):
275
255
  if self.model_runner.is_draft_worker:
276
256
  raise RuntimeError("This should not happen")
277
257
  else:
@@ -317,7 +297,9 @@ class CudaGraphRunner:
317
297
  (self.max_num_token,), dtype=self._cache_loc_dtype()
318
298
  )
319
299
  self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
320
- self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
300
+ self.mrope_positions = torch.zeros(
301
+ (3, self.max_num_token), dtype=torch.int64
302
+ )
321
303
  self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
322
304
  self.tbo_plugin = TboCudaGraphRunnerPlugin()
323
305
 
@@ -435,11 +417,21 @@ class CudaGraphRunner:
435
417
  forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
436
418
  )
437
419
 
420
+ is_ngram_supported = (
421
+ (
422
+ forward_batch.batch_size * self.num_tokens_per_bs
423
+ == forward_batch.input_ids.numel()
424
+ )
425
+ if self.model_runner.spec_algorithm.is_ngram()
426
+ else True
427
+ )
428
+
438
429
  return (
439
430
  is_bs_supported
440
431
  and is_encoder_lens_supported
441
432
  and is_tbo_supported
442
433
  and capture_hidden_mode_matches
434
+ and is_ngram_supported
443
435
  )
444
436
 
445
437
  def capture(self) -> None:
@@ -449,6 +441,7 @@ class CudaGraphRunner:
449
441
  activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
450
442
  record_shapes=True,
451
443
  )
444
+ torch.cuda.memory._record_memory_history()
452
445
 
453
446
  # Trigger CUDA graph capture for specific shapes.
454
447
  # Capture the large shapes first so that the smaller shapes
@@ -497,6 +490,8 @@ class CudaGraphRunner:
497
490
  save_gemlite_cache()
498
491
 
499
492
  if self.enable_profile_cuda_graph:
493
+ torch.cuda.memory._dump_snapshot(f"cuda_graph_runner_memory_usage.pickle")
494
+ torch.cuda.memory._record_memory_history(enabled=None)
500
495
  log_message = (
501
496
  "Sorted by CUDA Time:\n"
502
497
  + prof.key_averages(group_by_input_shape=True).table(
@@ -506,6 +501,7 @@ class CudaGraphRunner:
506
501
  + prof.key_averages(group_by_input_shape=True).table(
507
502
  sort_by="cpu_time_total", row_limit=10
508
503
  )
504
+ + "\n\nMemory Usage is saved to cuda_graph_runner_memory_usage.pickle\n"
509
505
  )
510
506
  logger.info(log_message)
511
507
 
@@ -526,13 +522,14 @@ class CudaGraphRunner:
526
522
  input_ids = self.input_ids[:num_tokens]
527
523
  req_pool_indices = self.req_pool_indices[:bs]
528
524
  seq_lens = self.seq_lens[:bs]
525
+ seq_lens_cpu = self.seq_lens_cpu[:bs]
529
526
  out_cache_loc = self.out_cache_loc[:num_tokens]
530
527
  positions = self.positions[:num_tokens]
531
528
  if self.is_encoder_decoder:
532
529
  encoder_lens = self.encoder_lens[:bs]
533
530
  else:
534
531
  encoder_lens = None
535
- mrope_positions = self.mrope_positions[:, :bs]
532
+ mrope_positions = self.mrope_positions[:, :num_tokens]
536
533
  next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
537
534
  self.num_token_non_padded[...] = num_tokens
538
535
 
@@ -596,6 +593,7 @@ class CudaGraphRunner:
596
593
  input_ids=input_ids,
597
594
  req_pool_indices=req_pool_indices,
598
595
  seq_lens=seq_lens,
596
+ seq_lens_cpu=seq_lens_cpu,
599
597
  next_token_logits_buffer=next_token_logits_buffer,
600
598
  orig_seq_lens=seq_lens,
601
599
  req_to_token_pool=self.model_runner.req_to_token_pool,
@@ -751,7 +749,7 @@ class CudaGraphRunner:
751
749
  if self.is_encoder_decoder:
752
750
  self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
753
751
  if forward_batch.mrope_positions is not None:
754
- self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
752
+ self.mrope_positions[:, :raw_num_token].copy_(forward_batch.mrope_positions)
755
753
  if self.require_gathered_buffer:
756
754
  self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
757
755
  self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
@@ -825,8 +823,11 @@ class CudaGraphRunner:
825
823
 
826
824
  def get_spec_info(self, num_tokens: int):
827
825
  spec_info = None
828
- if self.model_runner.spec_algorithm.is_eagle():
829
- from sglang.srt.speculative.eagle_utils import EagleVerifyInput
826
+ if (
827
+ self.model_runner.spec_algorithm.is_eagle()
828
+ or self.model_runner.spec_algorithm.is_standalone()
829
+ ):
830
+ from sglang.srt.speculative.eagle_info import EagleVerifyInput
830
831
 
831
832
  if self.model_runner.is_draft_worker:
832
833
  raise RuntimeError("This should not happen.")
@@ -847,6 +848,20 @@ class CudaGraphRunner:
847
848
  seq_lens_cpu=None,
848
849
  )
849
850
 
851
+ elif self.model_runner.spec_algorithm.is_ngram():
852
+ from sglang.srt.speculative.ngram_utils import NgramVerifyInput
853
+
854
+ spec_info = NgramVerifyInput(
855
+ draft_token=None,
856
+ tree_mask=self.custom_mask,
857
+ positions=None,
858
+ retrive_index=None,
859
+ retrive_next_token=None,
860
+ retrive_next_sibling=None,
861
+ draft_token_num=self.num_tokens_per_bs,
862
+ )
863
+ spec_info.capture_hidden_mode = CaptureHiddenMode.NULL
864
+
850
865
  return spec_info
851
866
 
852
867
 
@@ -45,13 +45,7 @@ from sglang.srt.layers.dp_attention import (
45
45
  get_attention_tp_size,
46
46
  set_dp_buffer_len,
47
47
  )
48
- from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
49
- from sglang.srt.utils import (
50
- flatten_nested_list,
51
- get_compiler_backend,
52
- is_npu,
53
- support_triton,
54
- )
48
+ from sglang.srt.utils import get_compiler_backend, is_npu, support_triton
55
49
 
56
50
  if TYPE_CHECKING:
57
51
  from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -60,8 +54,7 @@ if TYPE_CHECKING:
60
54
  from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
61
55
  from sglang.srt.model_executor.model_runner import ModelRunner
62
56
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
63
- from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
64
- from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
57
+ from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm
65
58
 
66
59
  _is_npu = is_npu()
67
60
 
@@ -132,6 +125,9 @@ class ForwardMode(IntEnum):
132
125
  or self == ForwardMode.IDLE
133
126
  )
134
127
 
128
+ def is_cpu_graph(self):
129
+ return self == ForwardMode.DECODE
130
+
135
131
  def is_dummy_first(self):
136
132
  return self == ForwardMode.DUMMY_FIRST
137
133
 
@@ -290,13 +286,14 @@ class ForwardBatch:
290
286
  global_forward_mode: Optional[ForwardMode] = None
291
287
 
292
288
  # Speculative decoding
293
- spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
289
+ spec_info: Optional[SpecInput] = None
294
290
  spec_algorithm: SpeculativeAlgorithm = None
295
291
  capture_hidden_mode: CaptureHiddenMode = None
296
292
 
297
293
  # For padding
298
294
  padded_static_len: int = -1 # -1 if not padded
299
295
  num_token_non_padded: Optional[torch.Tensor] = None # scalar tensor
296
+ num_token_non_padded_cpu: int = None
300
297
 
301
298
  # For Qwen2-VL
302
299
  mrope_positions: torch.Tensor = None
@@ -358,36 +355,18 @@ class ForwardBatch:
358
355
  ret.num_token_non_padded = torch.tensor(
359
356
  len(batch.input_ids), dtype=torch.int32
360
357
  ).to(device, non_blocking=True)
358
+ ret.num_token_non_padded_cpu = len(batch.input_ids)
361
359
 
362
360
  # For MLP sync
363
361
  if batch.global_num_tokens is not None:
364
- from sglang.srt.speculative.eagle_utils import (
365
- EagleDraftInput,
366
- EagleVerifyInput,
367
- )
368
-
369
362
  assert batch.global_num_tokens_for_logprob is not None
363
+
370
364
  # process global_num_tokens and global_num_tokens_for_logprob
371
365
  if batch.spec_info is not None:
372
- if isinstance(batch.spec_info, EagleDraftInput):
373
- global_num_tokens = [
374
- x * batch.spec_info.num_tokens_per_batch
375
- for x in batch.global_num_tokens
376
- ]
377
- global_num_tokens_for_logprob = [
378
- x * batch.spec_info.num_tokens_for_logprob_per_batch
379
- for x in batch.global_num_tokens_for_logprob
380
- ]
381
- else:
382
- assert isinstance(batch.spec_info, EagleVerifyInput)
383
- global_num_tokens = [
384
- x * batch.spec_info.draft_token_num
385
- for x in batch.global_num_tokens
386
- ]
387
- global_num_tokens_for_logprob = [
388
- x * batch.spec_info.draft_token_num
389
- for x in batch.global_num_tokens_for_logprob
390
- ]
366
+ spec_info: SpecInput = batch.spec_info
367
+ global_num_tokens, global_num_tokens_for_logprob = (
368
+ spec_info.get_spec_adjusted_global_num_tokens(batch)
369
+ )
391
370
  else:
392
371
  global_num_tokens = batch.global_num_tokens
393
372
  global_num_tokens_for_logprob = batch.global_num_tokens_for_logprob
@@ -441,7 +420,13 @@ class ForwardBatch:
441
420
  ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
442
421
 
443
422
  if model_runner.model_is_mrope:
444
- ret._compute_mrope_positions(model_runner, batch)
423
+ if (
424
+ ret.spec_info is not None
425
+ and getattr(ret.spec_info, "positions", None) is not None
426
+ ):
427
+ ret._compute_spec_mrope_positions(model_runner, batch)
428
+ else:
429
+ ret._compute_mrope_positions(model_runner, batch)
445
430
 
446
431
  # Init lora information
447
432
  if model_runner.server_args.enable_lora:
@@ -507,6 +492,52 @@ class ForwardBatch:
507
492
  or self.contains_image_inputs()
508
493
  )
509
494
 
495
+ def _compute_spec_mrope_positions(
496
+ self, model_runner: ModelRunner, batch: ModelWorkerBatch
497
+ ):
498
+ # TODO support batched deltas
499
+ batch_size = self.seq_lens.shape[0]
500
+ device = model_runner.device
501
+ mm_inputs = batch.multimodal_inputs
502
+
503
+ if batch.forward_mode.is_draft_extend(): # draft_extend_after_decode
504
+ mrope_deltas = []
505
+ extend_lens = []
506
+ for batch_idx in range(batch_size):
507
+ extend_seq_len = batch.extend_seq_lens[batch_idx]
508
+ extend_lens.append(extend_seq_len)
509
+ mrope_delta = (
510
+ torch.zeros(1, dtype=torch.int64)
511
+ if mm_inputs[batch_idx] is None
512
+ else mm_inputs[batch_idx].mrope_position_delta.squeeze(0)
513
+ )
514
+ mrope_deltas.append(mrope_delta.to(device=device))
515
+ position_chunks = torch.split(batch.spec_info.positions, extend_lens)
516
+ mrope_positions_list = [
517
+ pos_chunk + delta
518
+ for pos_chunk, delta in zip(position_chunks, mrope_deltas)
519
+ ]
520
+ next_input_positions = (
521
+ torch.cat(mrope_positions_list, dim=0).unsqueeze(0).repeat(3, 1)
522
+ )
523
+
524
+ else: # target_verify or draft_decode
525
+ seq_positions = batch.spec_info.positions.view(batch_size, -1)
526
+ mrope_deltas = [
527
+ (
528
+ torch.tensor([0], dtype=torch.int64)
529
+ if mm_inputs[i] is None
530
+ else mm_inputs[i].mrope_position_delta.squeeze(0)
531
+ )
532
+ for i in range(batch_size)
533
+ ]
534
+ mrope_delta_tensor = torch.stack(mrope_deltas, dim=0).to(device=device)
535
+ next_input_positions = (
536
+ (seq_positions + mrope_delta_tensor).flatten().unsqueeze(0).repeat(3, 1)
537
+ )
538
+
539
+ self.mrope_positions = next_input_positions
540
+
510
541
  def _compute_mrope_positions(
511
542
  self, model_runner: ModelRunner, batch: ModelWorkerBatch
512
543
  ):
@@ -614,9 +645,6 @@ class ForwardBatch:
614
645
  )
615
646
 
616
647
  def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
617
-
618
- from sglang.srt.speculative.eagle_utils import EagleDraftInput
619
-
620
648
  assert self.global_num_tokens_cpu is not None
621
649
  assert self.global_num_tokens_for_logprob_cpu is not None
622
650
 
@@ -631,7 +659,9 @@ class ForwardBatch:
631
659
  (global_num_tokens[i] - 1) // attn_tp_size + 1
632
660
  ) * attn_tp_size
633
661
 
634
- dp_padding_mode = DpPaddingMode.get_dp_padding_mode(global_num_tokens)
662
+ dp_padding_mode = DpPaddingMode.get_dp_padding_mode(
663
+ self.is_extend_in_batch, global_num_tokens
664
+ )
635
665
  self.dp_padding_mode = dp_padding_mode
636
666
 
637
667
  if dp_padding_mode.is_max_len():
@@ -711,7 +741,8 @@ class ForwardBatch:
711
741
  if self.extend_seq_lens is not None:
712
742
  self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
713
743
 
714
- if self.spec_info is not None and isinstance(self.spec_info, EagleDraftInput):
744
+ if self.spec_info is not None and self.spec_info.is_draft_input():
745
+ # FIXME(lsyin): remove this isinstance logic
715
746
  spec_info = self.spec_info
716
747
  self.output_cache_loc_backup = self.out_cache_loc
717
748
  self.hidden_states_backup = spec_info.hidden_states
@@ -871,6 +902,17 @@ class ForwardBatch:
871
902
  return self.tbo_split_seq_index is not None
872
903
 
873
904
 
905
+ @dataclass
906
+ class ForwardBatchOutput:
907
+ # FIXME(lsyin): unify the forward batch output between different spec and parallelism
908
+ # need to be more organized
909
+ logits_output: Optional[torch.Tensor] = None
910
+ next_token_ids: Optional[torch.Tensor] = None
911
+ num_accepted_tokens: Optional[int] = None
912
+ pp_proxy_tensors: Optional[PPProxyTensors] = None
913
+ can_run_cuda_graph: bool = False
914
+
915
+
874
916
  def enable_num_token_non_padded(server_args):
875
917
  return get_moe_expert_parallel_world_size() > 1
876
918