sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -3,7 +3,6 @@ import datetime
3
3
  from google.protobuf import timestamp_pb2 as _timestamp_pb2
4
4
  from google.protobuf import struct_pb2 as _struct_pb2
5
5
  from google.protobuf.internal import containers as _containers
6
- from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
7
6
  from google.protobuf import descriptor as _descriptor
8
7
  from google.protobuf import message as _message
9
8
  from collections.abc import Iterable as _Iterable, Mapping as _Mapping
@@ -12,7 +11,7 @@ from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
12
11
  DESCRIPTOR: _descriptor.FileDescriptor
13
12
 
14
13
  class SamplingParams(_message.Message):
15
- __slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "structural_tag", "custom_params")
14
+ __slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "structural_tag", "n", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "custom_params")
16
15
  class LogitBiasEntry(_message.Message):
17
16
  __slots__ = ("key", "value")
18
17
  KEY_FIELD_NUMBER: _ClassVar[int]
@@ -35,15 +34,13 @@ class SamplingParams(_message.Message):
35
34
  REGEX_FIELD_NUMBER: _ClassVar[int]
36
35
  JSON_SCHEMA_FIELD_NUMBER: _ClassVar[int]
37
36
  EBNF_GRAMMAR_FIELD_NUMBER: _ClassVar[int]
38
- LORA_PATH_FIELD_NUMBER: _ClassVar[int]
37
+ STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
39
38
  N_FIELD_NUMBER: _ClassVar[int]
40
- TOKEN_HEALING_FIELD_NUMBER: _ClassVar[int]
41
39
  MIN_NEW_TOKENS_FIELD_NUMBER: _ClassVar[int]
42
40
  IGNORE_EOS_FIELD_NUMBER: _ClassVar[int]
43
41
  NO_STOP_TRIM_FIELD_NUMBER: _ClassVar[int]
44
42
  STREAM_INTERVAL_FIELD_NUMBER: _ClassVar[int]
45
43
  LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int]
46
- STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
47
44
  CUSTOM_PARAMS_FIELD_NUMBER: _ClassVar[int]
48
45
  temperature: float
49
46
  top_p: float
@@ -60,17 +57,15 @@ class SamplingParams(_message.Message):
60
57
  regex: str
61
58
  json_schema: str
62
59
  ebnf_grammar: str
63
- lora_path: str
60
+ structural_tag: str
64
61
  n: int
65
- token_healing: bool
66
62
  min_new_tokens: int
67
63
  ignore_eos: bool
68
64
  no_stop_trim: bool
69
65
  stream_interval: int
70
66
  logit_bias: _containers.ScalarMap[str, float]
71
- structural_tag: str
72
67
  custom_params: _struct_pb2.Struct
73
- def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., structural_tag: _Optional[str] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
68
+ def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., structural_tag: _Optional[str] = ..., n: _Optional[int] = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
74
69
 
75
70
  class DisaggregatedParams(_message.Message):
76
71
  __slots__ = ("bootstrap_host", "bootstrap_port", "bootstrap_room")
@@ -83,7 +78,7 @@ class DisaggregatedParams(_message.Message):
83
78
  def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ...
84
79
 
85
80
  class GenerateRequest(_message.Message):
86
- __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id")
81
+ __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "stream")
87
82
  REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
88
83
  TOKENIZED_FIELD_NUMBER: _ClassVar[int]
89
84
  MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
@@ -100,7 +95,7 @@ class GenerateRequest(_message.Message):
100
95
  INPUT_EMBEDS_FIELD_NUMBER: _ClassVar[int]
101
96
  LORA_ID_FIELD_NUMBER: _ClassVar[int]
102
97
  DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
103
- DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int]
98
+ STREAM_FIELD_NUMBER: _ClassVar[int]
104
99
  request_id: str
105
100
  tokenized: TokenizedInput
106
101
  mm_inputs: MultimodalInputs
@@ -117,8 +112,8 @@ class GenerateRequest(_message.Message):
117
112
  input_embeds: _containers.RepeatedScalarFieldContainer[float]
118
113
  lora_id: str
119
114
  data_parallel_rank: int
120
- dp_balance_id: int
121
- def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ...) -> None: ...
115
+ stream: bool
116
+ def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., stream: bool = ...) -> None: ...
122
117
 
123
118
  class TokenizedInput(_message.Message):
124
119
  __slots__ = ("original_text", "input_ids")
@@ -161,52 +156,50 @@ class GenerateResponse(_message.Message):
161
156
  def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...
162
157
 
163
158
  class GenerateStreamChunk(_message.Message):
164
- __slots__ = ("token_id", "text", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states", "generation_time", "queue_time")
165
- TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
166
- TEXT_FIELD_NUMBER: _ClassVar[int]
159
+ __slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs", "index")
160
+ TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
167
161
  PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
168
162
  COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
169
163
  CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
170
- LOGPROBS_FIELD_NUMBER: _ClassVar[int]
164
+ OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
171
165
  HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
172
- GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
173
- QUEUE_TIME_FIELD_NUMBER: _ClassVar[int]
174
- token_id: int
175
- text: str
166
+ INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
167
+ INDEX_FIELD_NUMBER: _ClassVar[int]
168
+ token_ids: _containers.RepeatedScalarFieldContainer[int]
176
169
  prompt_tokens: int
177
170
  completion_tokens: int
178
171
  cached_tokens: int
179
- logprobs: LogProbs
172
+ output_logprobs: OutputLogProbs
180
173
  hidden_states: _containers.RepeatedScalarFieldContainer[float]
181
- generation_time: float
182
- queue_time: int
183
- def __init__(self, token_id: _Optional[int] = ..., text: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., generation_time: _Optional[float] = ..., queue_time: _Optional[int] = ...) -> None: ...
174
+ input_logprobs: InputLogProbs
175
+ index: int
176
+ def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ..., index: _Optional[int] = ...) -> None: ...
184
177
 
185
178
  class GenerateComplete(_message.Message):
186
- __slots__ = ("output_ids", "output_text", "finish_reason", "all_logprobs", "all_hidden_states")
187
- class FinishReason(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
188
- __slots__ = ()
189
- STOP: _ClassVar[GenerateComplete.FinishReason]
190
- LENGTH: _ClassVar[GenerateComplete.FinishReason]
191
- EOS_TOKEN: _ClassVar[GenerateComplete.FinishReason]
192
- STOP_STR: _ClassVar[GenerateComplete.FinishReason]
193
- ABORT: _ClassVar[GenerateComplete.FinishReason]
194
- STOP: GenerateComplete.FinishReason
195
- LENGTH: GenerateComplete.FinishReason
196
- EOS_TOKEN: GenerateComplete.FinishReason
197
- STOP_STR: GenerateComplete.FinishReason
198
- ABORT: GenerateComplete.FinishReason
179
+ __slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs", "index")
199
180
  OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
200
- OUTPUT_TEXT_FIELD_NUMBER: _ClassVar[int]
201
181
  FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
202
- ALL_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
182
+ PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
183
+ COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
184
+ CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
185
+ OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
203
186
  ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
187
+ MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
188
+ MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int]
189
+ INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
190
+ INDEX_FIELD_NUMBER: _ClassVar[int]
204
191
  output_ids: _containers.RepeatedScalarFieldContainer[int]
205
- output_text: str
206
- finish_reason: GenerateComplete.FinishReason
207
- all_logprobs: _containers.RepeatedCompositeFieldContainer[LogProbs]
192
+ finish_reason: str
193
+ prompt_tokens: int
194
+ completion_tokens: int
195
+ cached_tokens: int
196
+ output_logprobs: OutputLogProbs
208
197
  all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates]
209
- def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., output_text: _Optional[str] = ..., finish_reason: _Optional[_Union[GenerateComplete.FinishReason, str]] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ...) -> None: ...
198
+ matched_token_id: int
199
+ matched_stop_str: str
200
+ input_logprobs: InputLogProbs
201
+ index: int
202
+ def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ..., index: _Optional[int] = ...) -> None: ...
210
203
 
211
204
  class GenerateError(_message.Message):
212
205
  __slots__ = ("message", "http_status_code", "details")
@@ -218,27 +211,39 @@ class GenerateError(_message.Message):
218
211
  details: str
219
212
  def __init__(self, message: _Optional[str] = ..., http_status_code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ...
220
213
 
221
- class LogProbs(_message.Message):
222
- __slots__ = ("token_logprobs", "token_ids", "top_logprobs", "token_texts")
214
+ class OutputLogProbs(_message.Message):
215
+ __slots__ = ("token_logprobs", "token_ids", "top_logprobs")
223
216
  TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
224
217
  TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
225
218
  TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
226
- TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
227
219
  token_logprobs: _containers.RepeatedScalarFieldContainer[float]
228
220
  token_ids: _containers.RepeatedScalarFieldContainer[int]
229
221
  top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs]
230
- token_texts: _containers.RepeatedScalarFieldContainer[str]
231
- def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
222
+ def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ...) -> None: ...
223
+
224
+ class InputLogProbs(_message.Message):
225
+ __slots__ = ("token_logprobs", "token_ids", "top_logprobs")
226
+ TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
227
+ TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
228
+ TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
229
+ token_logprobs: _containers.RepeatedCompositeFieldContainer[InputTokenLogProb]
230
+ token_ids: _containers.RepeatedScalarFieldContainer[int]
231
+ top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs]
232
+ def __init__(self, token_logprobs: _Optional[_Iterable[_Union[InputTokenLogProb, _Mapping]]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ...) -> None: ...
233
+
234
+ class InputTokenLogProb(_message.Message):
235
+ __slots__ = ("value",)
236
+ VALUE_FIELD_NUMBER: _ClassVar[int]
237
+ value: float
238
+ def __init__(self, value: _Optional[float] = ...) -> None: ...
232
239
 
233
240
  class TopLogProbs(_message.Message):
234
- __slots__ = ("values", "token_ids", "token_texts")
241
+ __slots__ = ("values", "token_ids")
235
242
  VALUES_FIELD_NUMBER: _ClassVar[int]
236
243
  TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
237
- TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
238
244
  values: _containers.RepeatedScalarFieldContainer[float]
239
245
  token_ids: _containers.RepeatedScalarFieldContainer[int]
240
- token_texts: _containers.RepeatedScalarFieldContainer[str]
241
- def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
246
+ def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ...) -> None: ...
242
247
 
243
248
  class HiddenStates(_message.Message):
244
249
  __slots__ = ("values", "layer", "position")
@@ -283,20 +288,18 @@ class EmbedResponse(_message.Message):
283
288
  def __init__(self, request_id: _Optional[str] = ..., complete: _Optional[_Union[EmbedComplete, _Mapping]] = ..., error: _Optional[_Union[EmbedError, _Mapping]] = ...) -> None: ...
284
289
 
285
290
  class EmbedComplete(_message.Message):
286
- __slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "generation_time", "batch_embeddings")
291
+ __slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "batch_embeddings")
287
292
  EMBEDDING_FIELD_NUMBER: _ClassVar[int]
288
293
  PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
289
294
  CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
290
295
  EMBEDDING_DIM_FIELD_NUMBER: _ClassVar[int]
291
- GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
292
296
  BATCH_EMBEDDINGS_FIELD_NUMBER: _ClassVar[int]
293
297
  embedding: _containers.RepeatedScalarFieldContainer[float]
294
298
  prompt_tokens: int
295
299
  cached_tokens: int
296
300
  embedding_dim: int
297
- generation_time: float
298
301
  batch_embeddings: _containers.RepeatedCompositeFieldContainer[Embedding]
299
- def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., generation_time: _Optional[float] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ...
302
+ def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ...
300
303
 
301
304
  class Embedding(_message.Message):
302
305
  __slots__ = ("values", "index")
@@ -317,10 +320,8 @@ class EmbedError(_message.Message):
317
320
  def __init__(self, message: _Optional[str] = ..., code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ...
318
321
 
319
322
  class HealthCheckRequest(_message.Message):
320
- __slots__ = ("tokenized",)
321
- TOKENIZED_FIELD_NUMBER: _ClassVar[int]
322
- tokenized: TokenizedInput
323
- def __init__(self, tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ...) -> None: ...
323
+ __slots__ = ()
324
+ def __init__(self) -> None: ...
324
325
 
325
326
  class HealthCheckResponse(_message.Message):
326
327
  __slots__ = ("healthy", "message")
@@ -425,3 +426,65 @@ class SetInternalStateResponse(_message.Message):
425
426
  success: bool
426
427
  message: str
427
428
  def __init__(self, success: bool = ..., message: _Optional[str] = ...) -> None: ...
429
+
430
+ class GetModelInfoRequest(_message.Message):
431
+ __slots__ = ()
432
+ def __init__(self) -> None: ...
433
+
434
+ class GetModelInfoResponse(_message.Message):
435
+ __slots__ = ("model_path", "tokenizer_path", "is_generation", "preferred_sampling_params", "weight_version", "served_model_name", "max_context_length", "vocab_size", "supports_vision", "model_type", "eos_token_ids", "pad_token_id", "bos_token_id", "max_req_input_len")
436
+ MODEL_PATH_FIELD_NUMBER: _ClassVar[int]
437
+ TOKENIZER_PATH_FIELD_NUMBER: _ClassVar[int]
438
+ IS_GENERATION_FIELD_NUMBER: _ClassVar[int]
439
+ PREFERRED_SAMPLING_PARAMS_FIELD_NUMBER: _ClassVar[int]
440
+ WEIGHT_VERSION_FIELD_NUMBER: _ClassVar[int]
441
+ SERVED_MODEL_NAME_FIELD_NUMBER: _ClassVar[int]
442
+ MAX_CONTEXT_LENGTH_FIELD_NUMBER: _ClassVar[int]
443
+ VOCAB_SIZE_FIELD_NUMBER: _ClassVar[int]
444
+ SUPPORTS_VISION_FIELD_NUMBER: _ClassVar[int]
445
+ MODEL_TYPE_FIELD_NUMBER: _ClassVar[int]
446
+ EOS_TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
447
+ PAD_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
448
+ BOS_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
449
+ MAX_REQ_INPUT_LEN_FIELD_NUMBER: _ClassVar[int]
450
+ model_path: str
451
+ tokenizer_path: str
452
+ is_generation: bool
453
+ preferred_sampling_params: str
454
+ weight_version: str
455
+ served_model_name: str
456
+ max_context_length: int
457
+ vocab_size: int
458
+ supports_vision: bool
459
+ model_type: str
460
+ eos_token_ids: _containers.RepeatedScalarFieldContainer[int]
461
+ pad_token_id: int
462
+ bos_token_id: int
463
+ max_req_input_len: int
464
+ def __init__(self, model_path: _Optional[str] = ..., tokenizer_path: _Optional[str] = ..., is_generation: bool = ..., preferred_sampling_params: _Optional[str] = ..., weight_version: _Optional[str] = ..., served_model_name: _Optional[str] = ..., max_context_length: _Optional[int] = ..., vocab_size: _Optional[int] = ..., supports_vision: bool = ..., model_type: _Optional[str] = ..., eos_token_ids: _Optional[_Iterable[int]] = ..., pad_token_id: _Optional[int] = ..., bos_token_id: _Optional[int] = ..., max_req_input_len: _Optional[int] = ...) -> None: ...
465
+
466
+ class GetServerInfoRequest(_message.Message):
467
+ __slots__ = ()
468
+ def __init__(self) -> None: ...
469
+
470
+ class GetServerInfoResponse(_message.Message):
471
+ __slots__ = ("server_args", "scheduler_info", "active_requests", "is_paused", "last_receive_timestamp", "uptime_seconds", "sglang_version", "server_type", "start_time")
472
+ SERVER_ARGS_FIELD_NUMBER: _ClassVar[int]
473
+ SCHEDULER_INFO_FIELD_NUMBER: _ClassVar[int]
474
+ ACTIVE_REQUESTS_FIELD_NUMBER: _ClassVar[int]
475
+ IS_PAUSED_FIELD_NUMBER: _ClassVar[int]
476
+ LAST_RECEIVE_TIMESTAMP_FIELD_NUMBER: _ClassVar[int]
477
+ UPTIME_SECONDS_FIELD_NUMBER: _ClassVar[int]
478
+ SGLANG_VERSION_FIELD_NUMBER: _ClassVar[int]
479
+ SERVER_TYPE_FIELD_NUMBER: _ClassVar[int]
480
+ START_TIME_FIELD_NUMBER: _ClassVar[int]
481
+ server_args: _struct_pb2.Struct
482
+ scheduler_info: _struct_pb2.Struct
483
+ active_requests: int
484
+ is_paused: bool
485
+ last_receive_timestamp: float
486
+ uptime_seconds: float
487
+ sglang_version: str
488
+ server_type: str
489
+ start_time: _timestamp_pb2.Timestamp
490
+ def __init__(self, server_args: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., scheduler_info: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., active_requests: _Optional[int] = ..., is_paused: bool = ..., last_receive_timestamp: _Optional[float] = ..., uptime_seconds: _Optional[float] = ..., sglang_version: _Optional[str] = ..., server_type: _Optional[str] = ..., start_time: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ...) -> None: ...
@@ -1,3 +1,6 @@
1
+ # This file is auto-generated. Do not edit manually.
2
+ # Regenerate with: python compile_proto.py
3
+
1
4
  # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
2
5
  """Client and server classes corresponding to protobuf-defined services."""
3
6
  import grpc
@@ -5,7 +8,7 @@ import warnings
5
8
 
6
9
  from . import sglang_scheduler_pb2 as sglang__scheduler__pb2
7
10
 
8
- GRPC_GENERATED_VERSION = '1.74.0'
11
+ GRPC_GENERATED_VERSION = '1.75.1'
9
12
  GRPC_VERSION = grpc.__version__
10
13
  _version_not_supported = False
11
14
 
@@ -56,6 +59,16 @@ class SglangSchedulerStub(object):
56
59
  request_serializer=sglang__scheduler__pb2.AbortRequest.SerializeToString,
57
60
  response_deserializer=sglang__scheduler__pb2.AbortResponse.FromString,
58
61
  _registered_method=True)
62
+ self.GetModelInfo = channel.unary_unary(
63
+ '/sglang.grpc.scheduler.SglangScheduler/GetModelInfo',
64
+ request_serializer=sglang__scheduler__pb2.GetModelInfoRequest.SerializeToString,
65
+ response_deserializer=sglang__scheduler__pb2.GetModelInfoResponse.FromString,
66
+ _registered_method=True)
67
+ self.GetServerInfo = channel.unary_unary(
68
+ '/sglang.grpc.scheduler.SglangScheduler/GetServerInfo',
69
+ request_serializer=sglang__scheduler__pb2.GetServerInfoRequest.SerializeToString,
70
+ response_deserializer=sglang__scheduler__pb2.GetServerInfoResponse.FromString,
71
+ _registered_method=True)
59
72
 
60
73
 
61
74
  class SglangSchedulerServicer(object):
@@ -91,6 +104,20 @@ class SglangSchedulerServicer(object):
91
104
  context.set_details('Method not implemented!')
92
105
  raise NotImplementedError('Method not implemented!')
93
106
 
107
+ def GetModelInfo(self, request, context):
108
+ """Get model information
109
+ """
110
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
111
+ context.set_details('Method not implemented!')
112
+ raise NotImplementedError('Method not implemented!')
113
+
114
+ def GetServerInfo(self, request, context):
115
+ """Get server information
116
+ """
117
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
118
+ context.set_details('Method not implemented!')
119
+ raise NotImplementedError('Method not implemented!')
120
+
94
121
 
95
122
  def add_SglangSchedulerServicer_to_server(servicer, server):
96
123
  rpc_method_handlers = {
@@ -114,6 +141,16 @@ def add_SglangSchedulerServicer_to_server(servicer, server):
114
141
  request_deserializer=sglang__scheduler__pb2.AbortRequest.FromString,
115
142
  response_serializer=sglang__scheduler__pb2.AbortResponse.SerializeToString,
116
143
  ),
144
+ 'GetModelInfo': grpc.unary_unary_rpc_method_handler(
145
+ servicer.GetModelInfo,
146
+ request_deserializer=sglang__scheduler__pb2.GetModelInfoRequest.FromString,
147
+ response_serializer=sglang__scheduler__pb2.GetModelInfoResponse.SerializeToString,
148
+ ),
149
+ 'GetServerInfo': grpc.unary_unary_rpc_method_handler(
150
+ servicer.GetServerInfo,
151
+ request_deserializer=sglang__scheduler__pb2.GetServerInfoRequest.FromString,
152
+ response_serializer=sglang__scheduler__pb2.GetServerInfoResponse.SerializeToString,
153
+ ),
117
154
  }
118
155
  generic_handler = grpc.method_handlers_generic_handler(
119
156
  'sglang.grpc.scheduler.SglangScheduler', rpc_method_handlers)
@@ -234,3 +271,57 @@ class SglangScheduler(object):
234
271
  timeout,
235
272
  metadata,
236
273
  _registered_method=True)
274
+
275
+ @staticmethod
276
+ def GetModelInfo(request,
277
+ target,
278
+ options=(),
279
+ channel_credentials=None,
280
+ call_credentials=None,
281
+ insecure=False,
282
+ compression=None,
283
+ wait_for_ready=None,
284
+ timeout=None,
285
+ metadata=None):
286
+ return grpc.experimental.unary_unary(
287
+ request,
288
+ target,
289
+ '/sglang.grpc.scheduler.SglangScheduler/GetModelInfo',
290
+ sglang__scheduler__pb2.GetModelInfoRequest.SerializeToString,
291
+ sglang__scheduler__pb2.GetModelInfoResponse.FromString,
292
+ options,
293
+ channel_credentials,
294
+ insecure,
295
+ call_credentials,
296
+ compression,
297
+ wait_for_ready,
298
+ timeout,
299
+ metadata,
300
+ _registered_method=True)
301
+
302
+ @staticmethod
303
+ def GetServerInfo(request,
304
+ target,
305
+ options=(),
306
+ channel_credentials=None,
307
+ call_credentials=None,
308
+ insecure=False,
309
+ compression=None,
310
+ wait_for_ready=None,
311
+ timeout=None,
312
+ metadata=None):
313
+ return grpc.experimental.unary_unary(
314
+ request,
315
+ target,
316
+ '/sglang.grpc.scheduler.SglangScheduler/GetServerInfo',
317
+ sglang__scheduler__pb2.GetServerInfoRequest.SerializeToString,
318
+ sglang__scheduler__pb2.GetServerInfoResponse.FromString,
319
+ options,
320
+ channel_credentials,
321
+ insecure,
322
+ call_credentials,
323
+ compression,
324
+ wait_for_ready,
325
+ timeout,
326
+ metadata,
327
+ _registered_method=True)
@@ -224,12 +224,13 @@ class XIELU(CustomOp):
224
224
  self._xielu_cuda_fn = self._xielu_cuda
225
225
  logger.warning_once(msg)
226
226
  except Exception as err:
227
- logger.warning_once(
228
- "CUDA-fused xIELU not available (%s) –"
229
- " falling back to a Python version.\n"
230
- "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
231
- str(err),
232
- )
227
+ pass
228
+ # logger.warning_once(
229
+ # "CUDA-fused xIELU not available (%s) "
230
+ # " falling back to a Python version.\n"
231
+ # "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
232
+ # str(err),
233
+ # )
233
234
 
234
235
  def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
235
236
  alpha_p = nn.functional.softplus(self.alpha_p)
@@ -379,4 +380,7 @@ if not (
379
380
  logger.info(
380
381
  "sgl-kernel is not available on Non-NV, Non-AMD platforms or Non-AMX CPUs. Fallback to other kernel libraries."
381
382
  )
382
- from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
383
+ from vllm.model_executor.layers.activation import ( # noqa: F401
384
+ GeluAndMul,
385
+ SiluAndMul,
386
+ )
@@ -4,18 +4,13 @@ from __future__ import annotations
4
4
  end to end attention solution with aiter kernels
5
5
  """
6
6
 
7
- import math
8
- import os
9
7
  from dataclasses import dataclass
10
8
  from enum import Enum, auto
11
- from functools import partial
12
- from typing import TYPE_CHECKING, List, Optional, Union
9
+ from typing import TYPE_CHECKING, Optional
13
10
 
14
11
  import torch
15
12
  import triton
16
- import triton.language as tl
17
13
 
18
- from sglang.global_config import global_config
19
14
  from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
20
15
  from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
21
16
  from sglang.srt.layers.dp_attention import (
@@ -27,7 +22,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
27
22
  if TYPE_CHECKING:
28
23
  from sglang.srt.layers.radix_attention import RadixAttention
29
24
  from sglang.srt.model_executor.model_runner import ModelRunner
30
- from sglang.srt.speculative.spec_info import SpecInfo
25
+ from sglang.srt.speculative.spec_info import SpecInput
31
26
 
32
27
  try:
33
28
  from aiter import (
@@ -374,7 +369,7 @@ class AiterAttnBackend(AttentionBackend):
374
369
  seq_lens: torch.Tensor,
375
370
  encoder_lens: Optional[torch.Tensor],
376
371
  forward_mode: ForwardMode,
377
- spec_info: Optional[SpecInfo],
372
+ spec_info: Optional[SpecInput],
378
373
  ):
379
374
  if forward_mode.is_decode_or_idle():
380
375
  qo_indptr = None
@@ -509,7 +504,7 @@ class AiterAttnBackend(AttentionBackend):
509
504
  seq_lens_sum: int,
510
505
  encoder_lens: Optional[torch.Tensor],
511
506
  forward_mode: ForwardMode,
512
- spec_info: Optional[SpecInfo],
507
+ spec_info: Optional[SpecInput],
513
508
  seq_lens_cpu: Optional[torch.Tensor],
514
509
  ):
515
510
  if forward_mode.is_decode_or_idle():
@@ -619,7 +614,11 @@ class AiterAttnBackend(AttentionBackend):
619
614
  assert len(k.shape) == 3
620
615
  assert len(v.shape) == 3
621
616
 
622
- if forward_batch.forward_mode.is_extend():
617
+ if (
618
+ forward_batch.forward_mode.is_extend()
619
+ and not forward_batch.forward_mode.is_target_verify()
620
+ and not forward_batch.forward_mode.is_draft_extend()
621
+ ):
623
622
  if kv_indices.shape[0] == 0:
624
623
  o = flash_attn_varlen_func(
625
624
  q,
@@ -884,7 +883,7 @@ class AiterIndicesUpdaterPrefill:
884
883
  seq_lens_sum: int,
885
884
  prefix_lens: torch.Tensor,
886
885
  encoder_lens: Optional[torch.Tensor],
887
- spec_info: Optional[SpecInfo],
886
+ spec_info: Optional[SpecInput],
888
887
  ):
889
888
  # Keep the signature for type checking. It will be assigned during runtime.
890
889
  raise NotImplementedError()
@@ -896,7 +895,7 @@ class AiterIndicesUpdaterPrefill:
896
895
  seq_lens_sum: int,
897
896
  prefix_lens: torch.Tensor,
898
897
  encoder_lens: Optional[torch.Tensor],
899
- spec_info: Optional[SpecInfo],
898
+ spec_info: Optional[SpecInput],
900
899
  ):
901
900
 
902
901
  kv_start_idx = None
@@ -980,7 +979,7 @@ class AiterMlaIndicesUpdaterPrefill:
980
979
  extend_lens: torch.Tensor,
981
980
  max_q_len: int,
982
981
  max_kv_len: int,
983
- spec_info: Optional[SpecInfo],
982
+ spec_info: Optional[SpecInput],
984
983
  ):
985
984
  # Keep the signature for type checking. It will be assigned during runtime.
986
985
  raise NotImplementedError()
@@ -993,7 +992,7 @@ class AiterMlaIndicesUpdaterPrefill:
993
992
  extend_lens: torch.Tensor,
994
993
  max_q_len: int,
995
994
  max_kv_len: int,
996
- spec_info: Optional[SpecInfo],
995
+ spec_info: Optional[SpecInput],
997
996
  ):
998
997
  bs = len(req_pool_indices)
999
998
 
@@ -1050,7 +1049,7 @@ class AiterMultiStepDraftBackend:
1050
1049
  topk: int,
1051
1050
  speculative_num_steps: int,
1052
1051
  ):
1053
- from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
1052
+ from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
1054
1053
 
1055
1054
  self.topk = topk
1056
1055
  self.speculative_num_steps = speculative_num_steps
@@ -1065,7 +1064,7 @@ class AiterMultiStepDraftBackend:
1065
1064
  device=model_runner.device,
1066
1065
  )
1067
1066
  self.attn_backends = []
1068
- for i in range(self.speculative_num_steps):
1067
+ for i in range(self.speculative_num_steps - 1):
1069
1068
  self.attn_backends.append(
1070
1069
  AiterAttnBackend(
1071
1070
  model_runner,
@@ -1108,7 +1107,7 @@ class AiterMultiStepDraftBackend:
1108
1107
  self.page_size,
1109
1108
  )
1110
1109
 
1111
- for i in range(self.speculative_num_steps):
1110
+ for i in range(self.speculative_num_steps - 1):
1112
1111
  forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
1113
1112
  forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
1114
1113
  : seq_lens_sum * self.topk + bs * (i + 1)
@@ -1142,7 +1141,7 @@ class AiterMultiStepDraftBackend:
1142
1141
  dtype=torch.int32,
1143
1142
  device=self.device,
1144
1143
  )
1145
- for i in range(self.speculative_num_steps):
1144
+ for i in range(self.speculative_num_steps - 1):
1146
1145
  self.attn_backends[i].init_cuda_graph_state(
1147
1146
  max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
1148
1147
  )