sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py

@@ -15,14 +15,13 @@
  from __future__ import annotations
  
  import logging
- import threading
- from typing import TYPE_CHECKING, Optional, Tuple, Union
+ from abc import ABC, abstractmethod
+ from typing import TYPE_CHECKING, Optional
  
  import torch
  
  from sglang.srt.configs.model_config import ModelConfig
  from sglang.srt.distributed import get_pp_group, get_world_group
- from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.managers.io_struct import (
      DestroyWeightsUpdateGroupReqInput,
      GetWeightsByNameReqInput,
@@ -33,16 +32,14 @@ from sglang.srt.managers.io_struct import (
      UnloadLoRAAdapterReqInput,
      UpdateWeightFromDiskReqInput,
      UpdateWeightsFromDistributedReqInput,
+     UpdateWeightsFromIPCReqInput,
      UpdateWeightsFromTensorReqInput,
  )
- from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
+ from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+ from sglang.srt.managers.scheduler import GenerationBatchResult
  from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
- from sglang.srt.model_executor.forward_batch_info import (
-     ForwardBatch,
-     ForwardBatchOutput,
-     PPProxyTensors,
- )
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
  from sglang.srt.model_executor.model_runner import ModelRunner
  from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
@@ -59,7 +56,145 @@ if TYPE_CHECKING:
  logger = logging.getLogger(__name__)
  
  
- class TpModelWorker:
+ class BaseTpWorker(ABC):
+     @abstractmethod
+     def forward_batch_generation(self, forward_batch: ForwardBatch):
+         pass
+
+     @property
+     @abstractmethod
+     def model_runner(self) -> ModelRunner:
+         pass
+
+     @property
+     def sliding_window_size(self) -> Optional[int]:
+         return self.model_runner.sliding_window_size
+
+     @property
+     def is_hybrid(self) -> bool:
+         return self.model_runner.is_hybrid is not None
+
+     def get_tokens_per_layer_info(self):
+         return (
+             self.model_runner.full_max_total_num_tokens,
+             self.model_runner.swa_max_total_num_tokens,
+         )
+
+     def get_pad_input_ids_func(self):
+         return getattr(self.model_runner.model, "pad_input_ids", None)
+
+     def get_tp_group(self):
+         return self.model_runner.tp_group
+
+     def get_attention_tp_group(self):
+         return self.model_runner.attention_tp_group
+
+     def get_attention_tp_cpu_group(self):
+         return getattr(self.model_runner.attention_tp_group, "cpu_group", None)
+
+     def get_memory_pool(self):
+         return (
+             self.model_runner.req_to_token_pool,
+             self.model_runner.token_to_kv_pool_allocator,
+         )
+
+     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
+         success, message = self.model_runner.update_weights_from_disk(
+             recv_req.model_path, recv_req.load_format
+         )
+         return success, message
+
+     def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
+         success, message = self.model_runner.init_weights_update_group(
+             recv_req.master_address,
+             recv_req.master_port,
+             recv_req.rank_offset,
+             recv_req.world_size,
+             recv_req.group_name,
+             recv_req.backend,
+         )
+         return success, message
+
+     def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
+         success, message = self.model_runner.destroy_weights_update_group(
+             recv_req.group_name,
+         )
+         return success, message
+
+     def init_weights_send_group_for_remote_instance(
+         self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
+     ):
+         success, message = (
+             self.model_runner.init_weights_send_group_for_remote_instance(
+                 recv_req.master_address,
+                 recv_req.ports,
+                 recv_req.group_rank,
+                 recv_req.world_size,
+                 recv_req.group_name,
+                 recv_req.backend,
+             )
+         )
+         return success, message
+
+     def send_weights_to_remote_instance(
+         self, recv_req: SendWeightsToRemoteInstanceReqInput
+     ):
+         success, message = self.model_runner.send_weights_to_remote_instance(
+             recv_req.master_address,
+             recv_req.ports,
+             recv_req.group_name,
+         )
+         return success, message
+
+     def update_weights_from_distributed(
+         self, recv_req: UpdateWeightsFromDistributedReqInput
+     ):
+         success, message = self.model_runner.update_weights_from_distributed(
+             recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
+         )
+         return success, message
+
+     def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+
+         monkey_patch_torch_reductions()
+         success, message = self.model_runner.update_weights_from_tensor(
+             named_tensors=MultiprocessingSerializer.deserialize(
+                 recv_req.serialized_named_tensors[self.tp_rank]
+             ),
+             load_format=recv_req.load_format,
+         )
+         return success, message
+
+     def update_weights_from_ipc(self, recv_req: UpdateWeightsFromIPCReqInput):
+         """Update weights from IPC for checkpoint-engine integration."""
+         success, message = self.model_runner.update_weights_from_ipc(recv_req)
+         return success, message
+
+     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+         parameter = self.model_runner.get_weights_by_name(
+             recv_req.name, recv_req.truncate_size
+         )
+         return parameter
+
+     def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
+         result = self.model_runner.load_lora_adapter(recv_req.to_ref())
+         return result
+
+     def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
+         result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
+         return result
+
+     def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
+         return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
+
+     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
+         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+         logits_output, _ = self.model_runner.forward(forward_batch)
+         embeddings = logits_output.embeddings
+         return embeddings
+
+
+ class TpModelWorker(BaseTpWorker):
      """A tensor parallel model worker."""
  
      def __init__(
@@ -97,7 +232,7 @@ class TpModelWorker:
              is_draft_model=is_draft_worker,
          )
  
-         self.model_runner = ModelRunner(
+         self._model_runner = ModelRunner(
              model_config=self.model_config,
              mem_fraction_static=server_args.mem_fraction_static,
              gpu_id=gpu_id,
@@ -173,11 +308,13 @@ class TpModelWorker:
          )[0]
          set_random_seed(self.random_seed)
  
-         # A reference make this class has the same member as TpModelWorkerClient
-         self.worker = self
-
+         self.enable_overlap = not server_args.disable_overlap_schedule
          self.hicache_layer_transfer_counter = None
  
+     @property
+     def model_runner(self) -> ModelRunner:
+         return self._model_runner
+
      def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
          self.hicache_layer_transfer_counter = counter
  
@@ -195,54 +332,29 @@ class TpModelWorker:
              self.max_req_input_len,
              self.random_seed,
              self.device,
-             global_server_args_dict,
              self.model_runner.req_to_token_pool.size,
              self.model_runner.req_to_token_pool.max_context_len,
              self.model_runner.token_to_kv_pool.size,
          )
  
-     @property
-     def sliding_window_size(self) -> Optional[int]:
-         return self.model_runner.sliding_window_size
-
-     @property
-     def is_hybrid(self) -> bool:
-         return self.model_runner.is_hybrid is not None
-
-     def get_tokens_per_layer_info(self):
-         return (
-             self.model_runner.full_max_total_num_tokens,
-             self.model_runner.swa_max_total_num_tokens,
-         )
-
-     def get_pad_input_ids_func(self):
-         return getattr(self.model_runner.model, "pad_input_ids", None)
-
-     def get_tp_group(self):
-         return self.model_runner.tp_group
-
-     def get_attention_tp_group(self):
-         return self.model_runner.attention_tp_group
-
-     def get_attention_tp_cpu_group(self):
-         return getattr(self.model_runner.attention_tp_group, "cpu_group", None)
-
-     def get_memory_pool(self):
-         return (
-             self.model_runner.req_to_token_pool,
-             self.model_runner.token_to_kv_pool_allocator,
-         )
-
      def forward_batch_generation(
          self,
          model_worker_batch: ModelWorkerBatch,
-         launch_done: Optional[threading.Event] = None,
+         forward_batch: Optional[ForwardBatch] = None,
          is_verify: bool = False,
-     ) -> ForwardBatchOutput:
-         # update the consumer index of hicache to the running batch
-         self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
+         skip_attn_backend_init=False,
+     ) -> GenerationBatchResult:
+         # FIXME(lsyin): maybe remove skip_attn_backend_init in forward_batch_generation,
+         # which requires preparing replay to always be in this function
  
-         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+         if model_worker_batch is not None:
+             # update the consumer index of hicache to the running batch
+             self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
+
+             forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+         else:
+             # FIXME(lsyin): unify the interface of forward_batch
+             assert forward_batch is not None
  
          pp_proxy_tensors = None
          if not self.pp_group.is_first_rank:
@@ -254,123 +366,62 @@ class TpModelWorker:
  
          if self.pp_group.is_last_rank:
              logits_output, can_run_cuda_graph = self.model_runner.forward(
-                 forward_batch, pp_proxy_tensors=pp_proxy_tensors
+                 forward_batch,
+                 pp_proxy_tensors=pp_proxy_tensors,
+                 skip_attn_backend_init=skip_attn_backend_init,
              )
-             if launch_done is not None:
-                 launch_done.set()
-
-             skip_sample = is_verify or model_worker_batch.is_prefill_only
-             next_token_ids = None
-
-             if not skip_sample:
-                 next_token_ids = self.model_runner.sample(logits_output, forward_batch)
-             elif model_worker_batch.return_logprob and not is_verify:
-                 # NOTE: Compute logprobs without full sampling
-                 self.model_runner.compute_logprobs_only(
-                     logits_output, model_worker_batch
-                 )
-
-             return ForwardBatchOutput(
+             batch_result = GenerationBatchResult(
                  logits_output=logits_output,
-                 next_token_ids=next_token_ids,
                  can_run_cuda_graph=can_run_cuda_graph,
              )
+
+             if is_verify:
+                 # Skip sampling and return logits for target forward
+                 return batch_result
+
+             if (
+                 self.enable_overlap
+                 and model_worker_batch.sampling_info.grammars is not None
+             ):
+
+                 def sample_batch_func():
+                     batch_result.next_token_ids = self.model_runner.sample(
+                         logits_output, forward_batch
+                     )
+                     return batch_result
+
+                 batch_result.delay_sample_func = sample_batch_func
+                 return batch_result
+
+             if model_worker_batch.is_prefill_only:
+                 # For prefill-only requests, create dummy token IDs on CPU
+                 # The size should match the batch size (number of sequences), not total tokens
+                 batch_result.next_token_ids = torch.zeros(
+                     len(model_worker_batch.seq_lens),
+                     dtype=torch.long,
+                     device=model_worker_batch.input_ids.device,
+                 )
+                 if (
+                     model_worker_batch.return_logprob
+                     and logits_output.next_token_logits is not None
+                 ):
+                     # NOTE: Compute logprobs without full sampling
+                     self.model_runner.compute_logprobs_only(
+                         logits_output, model_worker_batch
+                     )
+             else:
+                 batch_result.next_token_ids = self.model_runner.sample(
+                     logits_output, forward_batch
+                 )
+
+             return batch_result
          else:
              pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
                  forward_batch,
                  pp_proxy_tensors=pp_proxy_tensors,
+                 skip_attn_backend_init=skip_attn_backend_init,
              )
-             return ForwardBatchOutput(
-                 pp_proxy_tensors=pp_proxy_tensors,
+             return GenerationBatchResult(
+                 pp_hidden_states_proxy_tensors=pp_proxy_tensors,
                  can_run_cuda_graph=can_run_cuda_graph,
              )
-
-     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
-         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-         logits_output, _ = self.model_runner.forward(forward_batch)
-         embeddings = logits_output.embeddings
-         return embeddings
-
-     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
-         success, message = self.model_runner.update_weights_from_disk(
-             recv_req.model_path, recv_req.load_format
-         )
-         return success, message
-
-     def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
-         success, message = self.model_runner.init_weights_update_group(
-             recv_req.master_address,
-             recv_req.master_port,
-             recv_req.rank_offset,
-             recv_req.world_size,
-             recv_req.group_name,
-             recv_req.backend,
-         )
-         return success, message
-
-     def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
-         success, message = self.model_runner.destroy_weights_update_group(
-             recv_req.group_name,
-         )
-         return success, message
-
-     def init_weights_send_group_for_remote_instance(
-         self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
-     ):
-         success, message = (
-             self.model_runner.init_weights_send_group_for_remote_instance(
-                 recv_req.master_address,
-                 recv_req.ports,
-                 recv_req.group_rank,
-                 recv_req.world_size,
-                 recv_req.group_name,
-                 recv_req.backend,
-             )
-         )
-         return success, message
-
-     def send_weights_to_remote_instance(
-         self, recv_req: SendWeightsToRemoteInstanceReqInput
-     ):
-         success, message = self.model_runner.send_weights_to_remote_instance(
-             recv_req.master_address,
-             recv_req.ports,
-             recv_req.group_name,
-         )
-         return success, message
-
-     def update_weights_from_distributed(
-         self, recv_req: UpdateWeightsFromDistributedReqInput
-     ):
-         success, message = self.model_runner.update_weights_from_distributed(
-             recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
-         )
-         return success, message
-
-     def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
-
-         monkey_patch_torch_reductions()
-         success, message = self.model_runner.update_weights_from_tensor(
-             named_tensors=MultiprocessingSerializer.deserialize(
-                 recv_req.serialized_named_tensors[self.tp_rank]
-             ),
-             load_format=recv_req.load_format,
-         )
-         return success, message
-
-     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
-         parameter = self.model_runner.get_weights_by_name(
-             recv_req.name, recv_req.truncate_size
-         )
-         return parameter
-
-     def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
-         result = self.model_runner.load_lora_adapter(recv_req.to_ref())
-         return result
-
-     def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
-         result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
-         return result
-
-     def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
-         return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
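Note: with this refactor, every weight-update, LoRA, and introspection helper lives on the new BaseTpWorker ABC and resolves the runner through its abstract model_runner property, so a concrete worker only has to supply that property plus forward_batch_generation. A minimal sketch of a custom worker against this interface (the class name and constructor are hypothetical, and GenerationBatchResult is the dataclass added in sglang/srt/managers/utils.py below):

    from sglang.srt.managers.tp_worker import BaseTpWorker
    from sglang.srt.managers.utils import GenerationBatchResult
    from sglang.srt.model_executor.forward_batch_info import ForwardBatch
    from sglang.srt.model_executor.model_runner import ModelRunner


    class MinimalTpWorker(BaseTpWorker):
        """Hypothetical: wraps an existing ModelRunner and inherits all
        BaseTpWorker helpers (get_memory_pool, update_weights_from_*, ...)."""

        def __init__(self, model_runner: ModelRunner):
            self._model_runner = model_runner

        @property
        def model_runner(self) -> ModelRunner:
            # Every inherited helper routes through this property.
            return self._model_runner

        def forward_batch_generation(self, forward_batch: ForwardBatch):
            # Simplest non-overlap path: forward, then sample immediately.
            logits_output, can_run_cuda_graph = self.model_runner.forward(forward_batch)
            next_token_ids = self.model_runner.sample(logits_output, forward_batch)
            return GenerationBatchResult(
                logits_output=logits_output,
                next_token_ids=next_token_ids,
                can_run_cuda_graph=can_run_cuda_graph,
            )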
sglang/srt/managers/utils.py

@@ -1,19 +1,95 @@
  from __future__ import annotations
  
+ import dataclasses
  import logging
- import multiprocessing as mp
- from typing import TYPE_CHECKING, Dict, List, Optional
+ from typing import TYPE_CHECKING, List, Optional
+
+ import torch
  
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+ from sglang.srt.managers.overlap_utils import FutureIndices
  from sglang.srt.managers.schedule_batch import Req
  from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
  
  if TYPE_CHECKING:
      from sglang.srt.managers.scheduler import GenerationBatchResult
+     from sglang.srt.speculative.eagle_info import EagleDraftInput
+
  
  logger = logging.getLogger(__name__)
  
  
+ @dataclasses.dataclass
+ class GenerationBatchResult:
+     logits_output: Optional[LogitsProcessorOutput] = None
+     pp_hidden_states_proxy_tensors: Optional[PPProxyTensors] = None
+     next_token_ids: Optional[torch.Tensor] = None
+     num_accepted_tokens: Optional[int] = None
+     can_run_cuda_graph: bool = False
+
+     # For output processing
+     extend_input_len_per_req: Optional[List[int]] = None
+     extend_logprob_start_len_per_req: Optional[List[int]] = None
+
+     # For overlap scheduling
+     copy_done: Optional[torch.cuda.Event] = None
+     delay_sample_func: Optional[callable] = None
+     future_indices: Optional[FutureIndices] = None
+
+     # FIXME(lsyin): maybe move to a better place?
+     # sync path: forward stream -> output processor
+     accept_lens: Optional[torch.Tensor] = None
+     allocate_lens: Optional[torch.Tensor] = None
+
+     # relay path: forward stream -> next step forward
+     next_draft_input: Optional[EagleDraftInput] = None
+
+     def copy_to_cpu(self, return_logprob: bool = False):
+         """Copy tensors to CPU in overlap scheduling.
+         Only the tensors which are needed for processing results are copied,
+         e.g., next_token_ids, logits outputs
+         """
+         if return_logprob:
+             if self.logits_output.next_token_logits is not None:
+                 self.logits_output.next_token_logits = (
+                     self.logits_output.next_token_logits.to("cpu", non_blocking=True)
+                 )
+             if self.logits_output.input_token_logprobs is not None:
+                 self.logits_output.input_token_logprobs = (
+                     self.logits_output.input_token_logprobs.to("cpu", non_blocking=True)
+                 )
+         if self.logits_output.hidden_states is not None:
+             self.logits_output.hidden_states = self.logits_output.hidden_states.to(
+                 "cpu", non_blocking=True
+             )
+         self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True)
+
+         if self.accept_lens is not None:
+             self.accept_lens = self.accept_lens.to("cpu", non_blocking=True)
+
+         if self.allocate_lens is not None:
+             self.allocate_lens = self.allocate_lens.to("cpu", non_blocking=True)
+
+         self.copy_done.record()
+
+     @classmethod
+     def from_pp_proxy(
+         cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
+     ):
+         # TODO(lsyin): refactor PP and avoid using dict
+         proxy_dict = next_pp_outputs.tensors
+         return cls(
+             logits_output=logits_output,
+             pp_hidden_states_proxy_tensors=None,
+             next_token_ids=next_pp_outputs["next_token_ids"],
+             extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
+             extend_logprob_start_len_per_req=proxy_dict.get(
+                 "extend_logprob_start_len_per_req", None
+             ),
+             can_run_cuda_graph=can_run_cuda_graph,
+         )
+
+
  def validate_input_length(
      req: Req, max_req_input_len: int, allow_auto_truncate: bool
  ) -> Optional[str]:
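Note: in the overlap path, copy_to_cpu only enqueues non_blocking device-to-host copies and records the copy_done CUDA event; nothing is guaranteed to be CPU-resident until that event completes. A hedged sketch of the intended produce/consume pattern (the helper functions are illustrative, not the actual scheduler wiring):

    import torch

    def launch_copy(batch_result, return_logprob: bool) -> None:
        # Producer side (forward stream): allocate the event, then enqueue the
        # async copies; copy_to_cpu() records copy_done at the end.
        batch_result.copy_done = torch.cuda.Event()
        batch_result.copy_to_cpu(return_logprob=return_logprob)

    def consume(batch_result) -> list:
        # Consumer side (output processor): block until the copies land,
        # after which the CPU-side tensors are safe to read.
        batch_result.copy_done.synchronize()
        return batch_result.next_token_ids.tolist()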
sglang/srt/mem_cache/allocator.py

@@ -274,10 +274,15 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
          self.full_to_swa_index_mapping[free_index] = 0
  
      def backup_state(self):
-         raise NotImplementedError
+         return [
+             self.full_attn_allocator.backup_state(),
+             self.swa_attn_allocator.backup_state(),
+         ]
  
      def restore_state(self, state):
-         raise NotImplementedError
+         assert len(state) == 2
+         self.full_attn_allocator.restore_state(state[0])
+         self.swa_attn_allocator.restore_state(state[1])
  
      def clear(self):
          self.swa_attn_allocator.clear()
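Note: backup_state/restore_state on the SWA allocator previously raised NotImplementedError; they now snapshot the full-attention and sliding-window sub-allocators as a pair, so a trial allocation can be rolled back atomically. A hedged round-trip sketch (swa_alloc stands for an already-constructed SWATokenToKVPoolAllocator):

    state = swa_alloc.backup_state()    # [full_attn_state, swa_attn_state]
    out = swa_alloc.alloc(64)           # trial allocation touching both pools
    swa_alloc.restore_state(state)      # rolls both sub-allocators back together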
sglang/srt/mem_cache/allocator_ascend.py

@@ -92,7 +92,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
          )
  
          if num_new_pages_item < 200:
-             import sgl_kernel_npu
+             import sgl_kernel_npu  # noqa: F401
  
              torch.ops.npu.alloc_extend(
                  prefix_lens,
@@ -119,7 +119,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
          assert len(torch.unique(out_indices)) == len(out_indices)
  
          self.free_pages = self.free_pages[num_new_pages_item:]
-         return out_indices
+         return out_indices.int()
  
      def alloc_decode(
          self,
sglang/srt/mem_cache/base_prefix_cache.py

@@ -1,5 +1,5 @@
  from abc import ABC, abstractmethod
- from typing import TYPE_CHECKING, Any, List, NamedTuple, Optional, Tuple
+ from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple
  
  import torch
  
@@ -40,7 +40,7 @@ class BasePrefixCache(ABC):
          pass
  
      @abstractmethod
-     def cache_finished_req(self, req: Req, **kwargs):
+     def cache_finished_req(self, req: Req, is_insert: bool = True, **kwargs):
          pass
  
      @abstractmethod
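Note: the abstract cache_finished_req now carries an explicit is_insert flag (default True), matching the ChunkCache change below, so subclasses must accept it. A sketch of a conforming override (partial subclass for illustration only; both helper methods are hypothetical placeholders):

    from sglang.srt.managers.schedule_batch import Req
    from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache

    class MyPrefixCache(BasePrefixCache):
        def cache_finished_req(self, req: Req, is_insert: bool = True, **kwargs):
            if is_insert:
                self._insert_prefix(req)  # hypothetical: publish the KV to the cache
            self._release(req)            # hypothetical: free per-request bookkeeping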
sglang/srt/mem_cache/chunk_cache.py

@@ -27,6 +27,12 @@ class ChunkCache(BasePrefixCache):
          self.req_to_token_pool = req_to_token_pool
          self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
          self.page_size = page_size
+         if self.token_to_kv_pool_allocator:
+             self.device = self.token_to_kv_pool_allocator.device
+         else:
+             self.device = torch.device("cpu")
+
+         self.protected_size_ = 0
  
          # NOTE (csy): this is to determine if a cache has prefix matching feature.
          # Chunk cache always return True to indicate no prefix matching.
@@ -45,7 +51,7 @@ class ChunkCache(BasePrefixCache):
              last_host_node=None,
          )
  
-     def cache_finished_req(self, req: Req, insert: bool = True):
+     def cache_finished_req(self, req: Req, is_insert: bool = True):
          kv_indices = self.req_to_token_pool.req_to_token[
              req.req_pool_idx,
              # For decode server: if req.output_ids is empty, we want to free all req.origin_input_ids
@@ -53,14 +59,16 @@ class ChunkCache(BasePrefixCache):
          ]
          self.req_to_token_pool.free(req.req_pool_idx)
          self.token_to_kv_pool_allocator.free(kv_indices)
+         self.protected_size_ -= len(req.prefix_indices)
  
      def cache_unfinished_req(self, req: Req, chunked=False):
          kv_indices = self.req_to_token_pool.req_to_token[
              req.req_pool_idx, : len(req.fill_ids)
          ]
+         self.protected_size_ += len(kv_indices) - len(req.prefix_indices)
  
          # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
-         req.prefix_indices = kv_indices
+         req.prefix_indices = kv_indices.to(dtype=torch.int64, copy=True)
  
      def evict(self, num_tokens: int):
          pass
@@ -71,6 +79,9 @@ class ChunkCache(BasePrefixCache):
      def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None):
          return 0
  
+     def protected_size(self):
+         return self.protected_size_
+
      def pretty_print(self):
          return ""
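Note: ChunkCache now mirrors the radix caches by reporting a protected token count: cache_unfinished_req pins the newly filled tokens of an in-flight chunked request, and cache_finished_req unpins its prefix, so protected_size() tells the scheduler how many KV tokens are currently not evictable. A toy walk-through of the counter (plain arithmetic, not sglang code):

    # One request prefills in two 4-token chunks, then finishes.
    protected, prefix_len = 0, 0

    # chunk 1 cached: protected_size_ += len(kv_indices) - len(req.prefix_indices)
    protected += 4 - prefix_len    # 0 -> 4
    prefix_len = 4

    # chunk 2 cached: 8 tokens are now filled
    protected += 8 - prefix_len    # 4 -> 8
    prefix_len = 8

    # request finishes: protected_size_ -= len(req.prefix_indices)
    protected -= prefix_len        # 8 -> 0
    assert protected == 0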