sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,9 @@
1
+ from abc import ABC, abstractmethod
1
2
  from enum import IntEnum, auto
3
+ from functools import lru_cache
4
+ from typing import List, Tuple
5
+
6
+ from sglang.srt.managers.schedule_batch import ModelWorkerBatch
2
7
 
3
8
 
4
9
  class SpeculativeAlgorithm(IntEnum):
@@ -6,6 +11,7 @@ class SpeculativeAlgorithm(IntEnum):
6
11
  EAGLE = auto()
7
12
  EAGLE3 = auto()
8
13
  STANDALONE = auto()
14
+ NGRAM = auto()
9
15
 
10
16
  def is_none(self):
11
17
  return self == SpeculativeAlgorithm.NONE
@@ -19,14 +25,57 @@ class SpeculativeAlgorithm(IntEnum):
19
25
  def is_standalone(self):
20
26
  return self == SpeculativeAlgorithm.STANDALONE
21
27
 
28
+ def is_ngram(self):
29
+ return self == SpeculativeAlgorithm.NGRAM
30
+
31
@staticmethod
@lru_cache(maxsize=None)
def from_string(name: str):
    """Map an algorithm name to its ``SpeculativeAlgorithm`` member.

    The lookup is case-insensitive. ``None`` is accepted and maps to
    ``SpeculativeAlgorithm.NONE``. Results are memoized per distinct ``name``.

    Raises:
        KeyError: if ``name`` is not one of the known algorithm names.
    """
    # Decorator order matters: ``@staticmethod`` must be outermost so that
    # ``lru_cache`` wraps the plain function. With the order reversed, the
    # cache wraps a staticmethod object (not callable before Python 3.10),
    # and the cache wrapper's own descriptor binding would pass the enum
    # member as the first argument when called through an instance.
    name_map = {
        "EAGLE": SpeculativeAlgorithm.EAGLE,
        "EAGLE3": SpeculativeAlgorithm.EAGLE3,
        "STANDALONE": SpeculativeAlgorithm.STANDALONE,
        "NGRAM": SpeculativeAlgorithm.NGRAM,
        None: SpeculativeAlgorithm.NONE,
    }
    if name is not None:
        name = name.upper()
    return name_map[name]
44
+
45
+
46
class SpecInputType(IntEnum):
    """Tag identifying which kind of speculative input an object is.

    Introduced to distinguish the SpecInput types of multiple algorithms
    when asserting in attention backends. If all algorithms can share the
    same data structure for draft_input and verify_input, consider
    simplifying this.
    """

    EAGLE_DRAFT = auto()
    EAGLE_VERIFY = auto()
    NGRAM_VERIFY = auto()
52
+
53
+
54
class SpecInput(ABC):
    """Abstract base for speculative inputs, tagged with a ``SpecInputType``."""

    def __init__(self, spec_input_type: SpecInputType):
        # Record which kind of speculative input this object represents.
        self.spec_input_type = spec_input_type

    def is_draft_input(self) -> bool:
        """Return True when this input is tagged ``EAGLE_DRAFT``."""
        # FIXME: remove this function which is only used for assertion
        # or use another variable name like `draft_input` to substitute `spec_info`
        kind = self.spec_input_type
        return kind == SpecInputType.EAGLE_DRAFT

    def is_verify_input(self) -> bool:
        """Return True when this input is tagged as one of the verify types."""
        verify_kinds = {
            SpecInputType.EAGLE_VERIFY,
            SpecInputType.NGRAM_VERIFY,
        }
        return self.spec_input_type in verify_kinds

    @abstractmethod
    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
        """Return the two multipliers used to scale the global token counts.

        The first applies to ``global_num_tokens`` and the second to
        ``global_num_tokens_for_logprob``; subclasses must provide them.
        """

    def get_spec_adjusted_global_num_tokens(
        self, forward_batch: ModelWorkerBatch
    ) -> Tuple[List[int], List[int]]:
        """Scale the batch's global token counts by this input's coefficients.

        Returns a pair of new lists: the scaled ``global_num_tokens`` and the
        scaled ``global_num_tokens_for_logprob``. The batch is not modified.
        """
        token_coef, logprob_coef = self.get_spec_adjust_token_coefficient()
        scaled_tokens = [
            count * token_coef for count in forward_batch.global_num_tokens
        ]
        scaled_logprob_tokens = [
            count * logprob_coef
            for count in forward_batch.global_num_tokens_for_logprob
        ]
        return scaled_tokens, scaled_logprob_tokens
@@ -0,0 +1,641 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import os
5
+ import time
6
+ from contextlib import contextmanager
7
+ from typing import TYPE_CHECKING, List
8
+
9
+ import torch
10
+ import triton
11
+ import triton.language as tl
12
+ from huggingface_hub import snapshot_download
13
+
14
+ from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
15
+ from sglang.srt.distributed.parallel_state import (
16
+ GroupCoordinator,
17
+ patch_tensor_parallel_group,
18
+ )
19
+ from sglang.srt.environ import envs
20
+ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
21
+ from sglang.srt.managers.schedule_batch import Req
22
+ from sglang.srt.utils import is_cuda, is_hip
23
+
24
+ if TYPE_CHECKING:
25
+ from sglang.srt.speculative.eagle_info import EagleVerifyInput
26
+
27
+
28
+ if is_cuda():
29
+ from sgl_kernel import fast_topk
30
+ elif is_hip():
31
+ from sgl_kernel import fast_topk
32
+
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
# Simulate acceptance length for benchmarking purposes
SIMULATE_ACC_LEN = envs.SGLANG_SIMULATE_ACC_LEN.get()  # turn off if < 0
# Sampling strategy for the simulated length; generate_simulated_accept_index
# accepts "multinomial" and "match-expected".
SIMULATE_ACC_METHOD = envs.SGLANG_SIMULATE_ACC_METHOD.get()

# Seconds; generate_token_bitmask logs a warning when one tree traversal takes longer.
TREE_TRAVERSE_TIME_THRESHOLD = 1  # TODO: set this properly
TREE_SPEC_KERNEL_AVAILABLE = is_cuda()  # This kernel is only available for CUDA now
43
+
44
+
45
@triton.jit
def create_extend_after_decode_spec_info(
    verified_id,
    seq_lens,
    accept_lens,
    positions,
    new_verified_id,
    bs_upper: tl.constexpr,
):
    # Triton kernel; each program instance (`pid` on axis 0) handles one entry
    # of `seq_lens` / `accept_lens`. For that entry it writes `accept_length`
    # consecutive values `seq_length - accept_length + k` into `positions`, and
    # copies the last `verified_id` element of the entry's segment into
    # `new_verified_id[pid]`.
    pid = tl.program_id(axis=0)
    offsets = tl.arange(0, bs_upper)
    seq_length = tl.load(seq_lens + pid)
    accept_length = tl.load(accept_lens + pid)

    # Exclusive prefix sum of accept_lens: where this entry's segment starts
    # in the flattened `positions` / `verified_id` buffers.
    accept_len_cumsum = tl.sum(
        tl.load(accept_lens + offsets, mask=offsets < pid, other=0)
    )
    positions_ptr = positions + accept_len_cumsum
    # NOTE(review): `offsets` spans [0, bs_upper), so at most `bs_upper`
    # positions are written per entry; this appears to assume
    # accept_length <= bs_upper — confirm against the launch site.
    mask = offsets < accept_length
    tl.store(positions_ptr + offsets, seq_length - accept_length + offsets, mask)

    # Move to the index of the last element of this entry's segment.
    accept_len_cumsum += accept_length - 1
    verified_id_data = tl.load(verified_id + accept_len_cumsum)
    tl.store(new_verified_id + pid, verified_id_data)
69
+
70
+
71
@triton.jit
def assign_req_to_token_pool(
    req_pool_indices,
    req_to_token,
    start_offset,
    end_offset,
    out_cache_loc,
    pool_len: tl.constexpr,
    bs_upper: tl.constexpr,
):
    # Triton kernel; program `pid` copies a contiguous slice of `out_cache_loc`
    # into columns [kv_start, kv_end) of the `req_to_token` row selected by
    # `req_pool_indices[pid]` (rows are `pool_len` elements apart).
    BLOCK_SIZE: tl.constexpr = 32
    pid = tl.program_id(axis=0)
    kv_start = tl.load(start_offset + pid)
    kv_end = tl.load(end_offset + pid)
    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len

    # Exclusive prefix sum of (end - start) over earlier programs gives this
    # program's read offset into `out_cache_loc`.
    length_offset = tl.arange(0, bs_upper)
    start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0)
    end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0)
    out_offset = tl.sum(end - start, axis=0)

    out_cache_ptr = out_cache_loc + out_offset

    # `save_offset` indexes the destination row, `load_offset` the source slice.
    save_offset = tl.arange(0, BLOCK_SIZE) + kv_start
    load_offset = tl.arange(0, BLOCK_SIZE)

    # Copy in chunks of BLOCK_SIZE; the mask trims the final partial chunk.
    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
    for _ in range(num_loop):
        mask = save_offset < kv_end
        data = tl.load(out_cache_ptr + load_offset, mask=mask)
        tl.store(token_pool + save_offset, data, mask=mask)
        save_offset += BLOCK_SIZE
        load_offset += BLOCK_SIZE
104
+
105
+
106
@triton.jit
def assign_draft_cache_locs(
    req_pool_indices,
    req_to_token,
    seq_lens,
    extend_lens,
    num_new_pages_per_topk,
    out_cache_loc,
    pool_len: tl.constexpr,
    topk: tl.constexpr,
    speculative_num_steps: tl.constexpr,
    page_size: tl.constexpr,
    bs_upper: tl.constexpr,
    iter_upper: tl.constexpr,
):
    # Triton kernel; program `pid` handles one request.
    #   Part 1 copies this request's slice of `out_cache_loc` into its
    #          `req_to_token` row, starting at column `seq_lens[pid]`.
    #   Parts 2 and 3 run only when page_size > 1 and topk > 1: the last partial
    #          page of the prefix is replicated for every topk branch, then the
    #          per-branch slots are written back to `out_cache_loc` in a dense
    #          (pid, topk_id, step) layout.
    BLOCK_SIZE: tl.constexpr = 128
    pid = tl.program_id(axis=0)

    if page_size == 1 or topk == 1:
        # Fixed-size slice per request.
        copy_len = topk * speculative_num_steps
        out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
    else:
        # Variable-size slice: offset is the exclusive prefix sum of extend_lens.
        bs_offset = tl.arange(0, bs_upper)
        copy_len = tl.load(extend_lens + pid)
        # `other=0` is required: masked-out lanes are otherwise undefined and
        # would corrupt the sum.
        cum_copy_len = tl.sum(
            tl.load(extend_lens + bs_offset, mask=bs_offset < pid, other=0)
        )
        out_cache_ptr = out_cache_loc + cum_copy_len

    # Part 1: Copy from out_cache_loc to req_to_token
    kv_start = tl.load(seq_lens + pid)
    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
    num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
    for i in range(num_loop):
        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = copy_offset < copy_len
        data = tl.load(out_cache_ptr + copy_offset, mask=mask)
        tl.store(token_pool + kv_start + copy_offset, data, mask=mask)

    if page_size == 1 or topk == 1:
        return

    # Part 2: Copy the indices for the last partial page
    prefix_len = tl.load(seq_lens + pid)
    last_page_len = prefix_len % page_size
    offsets = tl.arange(0, page_size)
    mask = offsets < last_page_len
    num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid)
    # Start of the last (possibly partial) prefix page inside the request row.
    prefix_base = token_pool + prefix_len - last_page_len

    for topk_id in range(topk):
        value = tl.load(prefix_base + offsets, mask=mask)
        tl.store(
            prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets,
            value,
            mask=mask,
        )

    # Part 3: Remove the padding in out_cache_loc
    iter_offset = tl.arange(0, iter_upper)
    for topk_id in range(topk):
        indices = tl.load(
            prefix_base
            + topk_id * num_new_pages_per_topk_ * page_size
            + last_page_len
            + iter_offset,
            mask=iter_offset < speculative_num_steps,
        )
        tl.store(
            out_cache_loc
            + pid * topk * speculative_num_steps
            + topk_id * speculative_num_steps
            + iter_offset,
            indices,
            mask=iter_offset < speculative_num_steps,
        )
180
+
181
+
182
@triton.jit
def generate_draft_decode_kv_indices(
    req_pool_indices,
    req_to_token,
    paged_kernel_lens,
    kv_indices,
    kv_indptr,
    positions,
    pool_len: tl.constexpr,
    kv_indices_stride: tl.constexpr,
    kv_indptr_stride: tl.constexpr,
    bs_upper: tl.constexpr,
    iter_upper: tl.constexpr,
    num_tokens_upper: tl.constexpr,
    page_size: tl.constexpr,
):
    # Triton kernel launched on a 3-D grid: axis 0 = draft step, axis 1 =
    # sequence in the batch, axis 2 = topk branch. Each program fills the
    # `kv_indices` segment and one `kv_indptr` entry for its
    # (step, sequence, branch) triple.
    BLOCK_SIZE: tl.constexpr = 128
    iters = tl.program_id(axis=0)
    bid = tl.program_id(axis=1)
    topk_id = tl.program_id(axis=2)

    num_steps = tl.num_programs(axis=0)
    num_seqs = tl.num_programs(axis=1)
    topk = tl.num_programs(axis=2)

    # Select this step's output slab, then turn `iters` from a 0-based step id
    # into the number of draft tokens appended so far (step id + 1).
    kv_indices += kv_indices_stride * iters
    kv_indptr += kv_indptr_stride * iters
    iters += 1

    # Exclusive prefix sum of sequence lengths over earlier sequences.
    load_offset = tl.arange(0, bs_upper)
    seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid, other=0)
    seq_len = tl.load(paged_kernel_lens + bid)
    cum_seq_len = tl.sum(seq_lens)

    # Update kv_indices
    # Each (sequence, branch) pair owns `seq_len + iters` consecutive slots.
    kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
    kv_ptr = kv_indices + kv_offset
    token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len

    # Copy the first `seq_len` entries of the request row in BLOCK_SIZE chunks.
    # `kv_offset` is reused here as the running in-row offset.
    kv_offset = tl.arange(0, BLOCK_SIZE)
    num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
    for _ in range(num_loop):
        mask = kv_offset < seq_len
        data = tl.load(token_pool_ptr + kv_offset, mask=mask)
        tl.store(kv_ptr + kv_offset, data, mask=mask)
        kv_offset += BLOCK_SIZE

    # Load the `iters` draft-token slots of this branch.
    extend_offset = tl.arange(0, iter_upper)
    if page_size == 1 or topk == 1:
        # Branch slots sit right after the prefix, `num_steps` slots per branch.
        extend_data = tl.load(
            token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper),
            mask=extend_offset < iters,
        )
    else:
        # Each branch starts `num_new_pages_per_topk` pages after the previous
        # one, counted from the start of the last prefix page, and skips the
        # `last_page_len` prefix entries at the head of its first page.
        prefix_len = seq_len
        last_page_len = prefix_len % page_size
        num_new_pages_per_topk = (
            last_page_len + num_steps + page_size - 1
        ) // page_size
        prefix_base = seq_len // page_size * page_size
        start = (
            prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len
        )
        extend_data = tl.load(
            token_pool_ptr + start + extend_offset,
            mask=extend_offset < iters,
        )

    tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)

    # Update kv_indptr
    bs_offset = tl.arange(0, num_tokens_upper)

    # Program (bid=0, topk_id=0) writes the final entry (index num_seqs * topk)
    # instead of entry 0; every other program writes entry `zid`.
    # NOTE(review): entry 0 is never written here — presumably pre-initialised
    # by the caller; confirm.
    zid = bid * topk + topk_id
    if zid == 0:
        zid = num_seqs * topk
    positions = tl.load(positions + bs_offset, mask=bs_offset < zid, other=0)
    base = tl.sum(positions)
    tl.store(kv_indptr + zid, base + zid * iters)
261
+
262
+
263
@triton.jit
def align_evict_mask_to_page_size(
    seq_lens,
    evict_mask,
    page_size: tl.constexpr,
    num_draft_tokens: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Triton kernel; program `bid` handles one row of `evict_mask`
    # (`num_draft_tokens` entries per row) and clears (sets to False) the
    # entries that fall inside one page-aligned window computed below.
    t_range = tl.arange(0, BLOCK_SIZE)

    bid = tl.program_id(axis=0)
    seq_len = tl.load(seq_lens + bid)
    io_mask = t_range < num_draft_tokens
    mask_row = tl.load(
        evict_mask + bid * num_draft_tokens + t_range, mask=io_mask, other=0
    )

    # Count set / unset entries in this row.
    num_trues = tl.sum(mask_row)
    num_false = num_draft_tokens - num_trues

    # Row-relative offset of the start of the page containing absolute
    # position `seq_len + num_false - 1`; may be negative when that page
    # begins before this row's first entry.
    start = (seq_len + num_false - 1) // page_size * page_size - seq_len
    # Clear every row entry within that page, clamped to the row bounds.
    for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
        tl.store(evict_mask + bid * num_draft_tokens + i, False)
286
+
287
+
288
@triton.jit
def get_target_cache_loc(
    tgt_cache_loc,
    to_free_slots,
    accept_length,
    to_free_num_slots,
    out_cache_loc,
    num_verify_tokens: tl.constexpr,
    num_verify_tokens_upper: tl.constexpr,
    bs_upper: tl.constexpr,
):
    # Triton kernel; program `bid` handles one row of `out_cache_loc`
    # (`num_verify_tokens` entries per row) and splits it into two outputs:
    #   * the first `accept_length[bid] + 1` entries go to `tgt_cache_loc`,
    #   * the last `to_free_num_slots[bid]` entries go to `to_free_slots`.
    # Both outputs are packed densely across rows using exclusive prefix sums.
    bid = tl.program_id(axis=0)
    offset = tl.arange(0, num_verify_tokens_upper)
    bs_offset = tl.arange(0, bs_upper)

    # write the first part to tgt_cache_loc
    # `other=0` is required: masked-out lanes are otherwise undefined and would
    # corrupt the prefix sum.
    accept_len_all = tl.load(
        accept_length + bs_offset, mask=bs_offset < bid, other=0
    )
    # Each earlier row contributed accept_length + 1 entries, hence `+ bid`.
    tgt_cache_loc_start = tl.sum(accept_len_all) + bid
    copy_len = tl.load(accept_length + bid) + 1
    out_cache_loc_row = tl.load(
        out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len
    )
    tl.store(
        tgt_cache_loc + tgt_cache_loc_start + offset,
        out_cache_loc_row,
        mask=offset < copy_len,
    )

    # write the second part to to_free_slots
    to_free_num_slots_all = tl.load(
        to_free_num_slots + bs_offset, mask=bs_offset < bid, other=0
    )
    to_free_num_slots_cur = tl.load(to_free_num_slots + bid)
    # The slots to free are the tail of this row.
    out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur
    to_free_slots_start = tl.sum(to_free_num_slots_all)

    copy_len = to_free_num_slots_cur
    out_cache_loc_row = tl.load(
        out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset,
        mask=offset < copy_len,
    )
    tl.store(
        to_free_slots + to_free_slots_start + offset,
        out_cache_loc_row,
        mask=offset < copy_len,
    )
332
+
333
+
334
@torch.compile(dynamic=True)
def get_src_tgt_cache_loc(
    seq_lens: torch.Tensor,
    out_cache_loc: torch.Tensor,
    accept_index: torch.Tensor,
    accept_length: torch.Tensor,
    draft_token_num: int,
    page_size: int,
):
    """Gather accepted cache locations and count the slots that can be freed.

    Returns a 3-tuple:
      * ``out_cache_loc`` indexed by ``accept_index``,
      * an uninitialised tensor of the same shape/dtype as that gather,
      * per-sequence ``(seq_lens + draft_token_num) - keep_len``, where
        ``keep_len`` is ``seq_lens + accept_length + 1`` rounded up to a
        multiple of ``page_size`` and capped at ``seq_lens + draft_token_num``.
    """
    # Sequence length if every draft token were kept.
    len_with_drafts = seq_lens + draft_token_num

    # Round the kept length up to a whole number of pages.
    page_aligned_len = (
        (seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size
    )
    kept_len = torch.minimum(page_aligned_len, len_with_drafts)

    gathered = out_cache_loc[accept_index]
    placeholder = torch.empty_like(gathered)
    return gathered, placeholder, len_with_drafts - kept_len
352
+
353
+
354
@triton.jit
def filter_finished_cache_loc_kernel(
    out_cache_loc,
    tgt_cache_loc,
    accept_length,
    accept_length_filter,
    bs_upper: tl.constexpr,
    num_verify_tokens_upper: tl.constexpr,
):
    # Triton kernel; program `bid` copies `accept_length_filter[bid]` entries
    # from its segment of `tgt_cache_loc` into a densely re-packed
    # `out_cache_loc`. Source segments are `accept_length + 1` long; destination
    # segments are `accept_length_filter` long, so rows whose filter length is 0
    # contribute nothing.
    bid = tl.program_id(0)
    bs_offset = tl.arange(0, bs_upper)

    # `other=0` is required on both loads: masked-out lanes are otherwise
    # undefined and would corrupt the prefix sums.
    accept_length_all = tl.load(
        accept_length + bs_offset, mask=bs_offset < bid, other=0
    )
    # Each earlier row occupies accept_length + 1 source entries, hence `+ bid`.
    old_start = tl.sum(accept_length_all) + bid

    accept_length_filter_all = tl.load(
        accept_length_filter + bs_offset, mask=bs_offset < bid, other=0
    )
    new_start = tl.sum(accept_length_filter_all)

    copy_len = tl.load(accept_length_filter + bid)
    copy_offset = tl.arange(0, num_verify_tokens_upper)
    value = tl.load(
        tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len
    )
    tl.store(
        out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len
    )
382
+
383
+
384
@torch.compile(dynamic=True)
def create_accept_length_filter(
    accept_length: torch.Tensor,
    unfinished_index_device: torch.Tensor,
    seq_lens: torch.Tensor,
):
    """Build the per-request filter lengths and advance ``seq_lens`` in place.

    The returned tensor holds ``accept_length + 1`` at the positions listed in
    ``unfinished_index_device`` and 0 everywhere else. As a side effect,
    ``seq_lens`` is incremented in place by ``accept_length + 1`` for every
    request.
    """
    num_new_tokens = accept_length + 1

    filtered = torch.zeros_like(accept_length)
    filtered[unfinished_index_device] = num_new_tokens[unfinished_index_device]

    seq_lens.add_(num_new_tokens)
    return filtered
396
+
397
+
398
@torch.compile(dynamic=True)
def select_top_k_tokens(
    i: int,
    topk_p: torch.Tensor,
    topk_index: torch.Tensor,
    hidden_states: torch.Tensor,
    scores: torch.Tensor,
    topk: int,
):
    """Pick the ``topk`` draft tokens to expand at draft step ``i``.

    Args:
        i: Draft step index; ``0`` is the first step after extend.
        topk_p: Top-k probabilities from the current step.
        topk_index: Token ids matching ``topk_p``.
        hidden_states: Hidden states of the tokens fed to this step.
        scores: Cumulative branch scores from the previous step (unused when
            ``i == 0``).
        topk: Number of branches kept per request.

    Returns:
        ``(input_ids, hidden_states, scores, tree_info)`` where ``tree_info`` is
        a 3-tuple of (scores, token ids, parent indices) for this step.

    Helper index tensors are created on the device of the input tensors rather
    than on a hard-coded ``"cuda"`` device.
    """
    if i == 0:
        # The first step after extend
        input_ids = topk_index.flatten()
        hidden_states = hidden_states.repeat_interleave(topk, dim=0)
        scores = topk_p  # shape: (b, topk)

        tree_info = (
            topk_p.unsqueeze(1),  # shape: (b, 1, topk)
            topk_index,  # shape: (b, topk)
            torch.arange(-1, topk, dtype=torch.long, device=topk_p.device)
            .unsqueeze(0)
            .repeat(topk_p.shape[0], 1),  # shape: (b, topk + 1)
        )
    else:
        # The later decode steps
        expand_scores = torch.mul(
            scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
        )  # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
        topk_cs_p, topk_cs_index = fast_topk(
            expand_scores.flatten(start_dim=1), topk, dim=-1
        )  # (b, topk)
        scores = topk_cs_p  # shape: (b, topk)

        topk_index = topk_index.reshape(-1, topk**2)
        input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()

        if hidden_states.shape[0] > 0:
            # `topk_cs_index // topk` is the parent branch within a request;
            # the arange adds each request's row offset into `hidden_states`.
            selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
                0, hidden_states.shape[0], step=topk, device=topk_cs_index.device
            ).repeat_interleave(topk)
            hidden_states = hidden_states[selected_input_index, :]

        tree_info = (
            expand_scores,  # shape: (b, topk, topk)
            topk_index,  # shape: (b, topk * topk)
            topk_cs_index + (topk**2 * (i - 1) + topk),  # shape: (b, topk)
        )

    return input_ids, hidden_states, scores, tree_info
446
+
447
+
448
def generate_simulated_accept_index(
    accept_index,
    predict,
    accept_length,
    bs,
    spec_steps,
    simulate_acc_len: float = SIMULATE_ACC_LEN,
    simulate_acc_method: str = SIMULATE_ACC_METHOD,
):
    """Fabricate an accept-index matrix with a simulated acceptance length.

    Used for benchmarking: instead of the real verification result, every
    request in the batch accepts the same sampled number of tokens.

    Args:
        accept_index: Real accept-index tensor; only its first column (and its
            device) is used.
        predict: Filled in place with a dummy token id (100).
        accept_length: Filled in place with the sampled length minus one.
        bs: Batch size (rows of the returned tensor).
        spec_steps: Number of speculative steps; the returned tensor has
            ``spec_steps + 1`` columns.
        simulate_acc_len: Target acceptance length; must be > 0.
        simulate_acc_method: ``"multinomial"`` or ``"match-expected"``.

    Returns:
        An int32 tensor of shape ``(bs, spec_steps + 1)`` on
        ``accept_index.device``. The first sampled-length columns hold
        consecutive indices starting at ``accept_index[:, 0]``; the rest are -1.

    Raises:
        ValueError: if ``simulate_acc_method`` is not a supported method.
    """
    assert simulate_acc_len > 0.0

    if simulate_acc_method == "multinomial":
        # Despite the name, this draws from a unit-variance normal centred on
        # the requested length.
        simulated_values = torch.normal(
            mean=simulate_acc_len,
            std=1.0,
            size=(1,),
            device="cpu",
        )
        # clamp simulated values to be between 1 and spec_steps + 1
        simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
        simulate_acc_len = int(simulated_values.round().item())
    elif simulate_acc_method == "match-expected":
        # multinomial sampling does not match the expected length
        # we keep it for the sake of compatibility of existing tests
        # but it's better to use "match-expected" for the cases that need to
        # match the expected length, One caveat is that this will only sample
        # either round down or round up of the expected length
        simulate_acc_len = max(1.0, min(spec_steps + 1, simulate_acc_len))
        lower = int(simulate_acc_len // 1)
        upper = lower + 1 if lower < spec_steps + 1 else lower
        if lower == upper:
            simulate_acc_len = lower
        else:
            # Choose floor/ceil with probabilities that make the expectation
            # equal to the requested (fractional) length.
            weight_upper = simulate_acc_len - lower
            weight_lower = 1.0 - weight_upper
            probs = torch.tensor([weight_lower, weight_upper], device="cpu")
            sampled_index = torch.multinomial(probs, num_samples=1)
            simulate_acc_len = lower if sampled_index == 0 else upper
    else:
        # Report the value that was actually passed in, not the module default.
        raise ValueError(f"Invalid simulate_acc_method: {simulate_acc_method}")

    accept_indx_first_col = accept_index[:, 0].view(-1, 1)
    # Allocate on the same device as `accept_index` so the slice assignment
    # below never mixes devices.
    sim_accept_index = torch.full(
        (bs, spec_steps + 1), -1, dtype=torch.int32, device=accept_index.device
    )
    sim_accept_index[:, :simulate_acc_len] = accept_indx_first_col + torch.arange(
        simulate_acc_len, device=accept_index.device
    )
    accept_length.fill_(simulate_acc_len - 1)
    predict.fill_(100)  # some legit token id
    return sim_accept_index
499
+
500
+
501
def traverse_tree(
    retrieve_next_token: torch.Tensor,
    retrieve_next_sibling: torch.Tensor,
    draft_tokens: torch.Tensor,
    grammar: BaseGrammarObject,
    allocate_token_bitmask: torch.Tensor,
):
    """
    Walk the draft-token tree depth-first and fill the per-node logits bitmask.

    ``retrieve_next_token[n]`` is the first child of node ``n`` and
    ``retrieve_next_sibling[n]`` its next sibling (-1 when absent). A node is
    entered only if its token is allowed by its parent's bitmask; the grammar
    is advanced on entry and rolled back by one token on exit.
    ``allocate_token_bitmask`` is zeroed first and then written in place.
    """
    assert (
        retrieve_next_token.shape == retrieve_next_sibling.shape == draft_tokens.shape
    )

    allocate_token_bitmask.fill_(0)

    def visit(node, parent_pos):
        # Handles `node` and then each of its following siblings in turn; all
        # of them share the same parent.
        while True:
            is_root = node == 0
            if is_root:
                # the first token generated by the target model, and thus it is
                # always accepted from the previous iteration
                allowed = True
            else:
                parent_row = allocate_token_bitmask[parent_pos]
                token_id = draft_tokens[node]
                # 32 boolean bitmask values are packed into 32-bit integers
                word = parent_row[token_id // 32]
                bit = 1 << (token_id % 32)
                allowed = (word & bit) != 0

            if allowed:
                if node != 0:
                    # Advance the grammar by this token
                    grammar.accept_token(draft_tokens[node])
                if not grammar.is_terminated():
                    # Record which tokens may follow this node
                    grammar.fill_vocab_mask(allocate_token_bitmask, node)
                    first_child = retrieve_next_token[node]
                    if first_child != -1:
                        # Descend into the subtree
                        visit(first_child, node)

                if node != 0:
                    # Undo the token accepted above
                    grammar.rollback(1)

            next_sibling = retrieve_next_sibling[node]
            if next_sibling == -1:
                return
            node = next_sibling

    visit(0, -1)
565
+
566
+
567
def generate_token_bitmask(
    reqs: List[Req],
    verify_input: EagleVerifyInput,
    retrieve_next_token_cpu: torch.Tensor,
    retrieve_next_sibling_cpu: torch.Tensor,
    draft_tokens_cpu: torch.Tensor,
    vocab_size: int,
):
    """
    Build the structured-output logit mask for a batch of draft trees.

    Draft tokens may or may not be valid under a request's grammar, so each
    request that has a grammar gets a DFS over its draft tree to
    1. find which tokens the grammar accepts, and
    2. record the logit mask that applies after each accepted token.

    The mask is allocated lazily on the first request with a grammar and is
    ``None`` when no request has one. ``verify_input.grammar`` is set to the
    grammar of the last such request (or ``None``).
    """

    tokens_per_req = draft_tokens_cpu.shape[-1]

    bitmask = None
    assert len(reqs) == retrieve_next_token_cpu.shape[0]
    last_grammar = None
    for idx, req in enumerate(reqs):
        if req.grammar is None:
            continue

        if bitmask is None:
            # One mask row per draft token across the whole batch.
            bitmask = req.grammar.allocate_vocab_mask(
                vocab_size=vocab_size,
                batch_size=draft_tokens_cpu.numel(),
                device="cpu",
            )
        last_grammar = req.grammar

        started = time.perf_counter()
        row_begin = idx * tokens_per_req
        row_end = (idx + 1) * tokens_per_req
        traverse_tree(
            retrieve_next_token_cpu[idx],
            retrieve_next_sibling_cpu[idx],
            draft_tokens_cpu[idx],
            req.grammar,
            bitmask[row_begin:row_end],
        )
        tree_traverse_time = time.perf_counter() - started
        if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD:
            logger.warning(
                f"Bit mask generation took {tree_traverse_time} seconds with "
                f"grammar: {req.grammar}"
            )

    verify_input.grammar = last_grammar
    return bitmask
616
+
617
+
618
def load_token_map(token_map_path: str) -> torch.Tensor:
    """Load a hot-token id map and return it as an int64 tensor.

    If ``token_map_path`` does not exist locally, its directory part is treated
    as a Hugging Face repo id: the repo is downloaded (skipping ``*.bin`` and
    ``*.safetensors`` files) and the file of the same basename is loaded from
    the downloaded snapshot.

    Args:
        token_map_path: Local file path, or ``<repo_id>/<file_name>``.

    Returns:
        A new ``torch.int64`` tensor with the loaded token ids.
    """
    if not os.path.exists(token_map_path):
        cache_dir = snapshot_download(
            os.path.dirname(token_map_path),
            ignore_patterns=["*.bin", "*.safetensors"],
        )
        token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
    hot_token_id = torch.load(token_map_path, weights_only=True)
    if isinstance(hot_token_id, torch.Tensor):
        # `torch.tensor(tensor)` emits a copy-construct UserWarning; clone
        # explicitly instead. The result is still a fresh int64 copy.
        return hot_token_id.detach().clone().to(dtype=torch.int64)
    return torch.tensor(hot_token_id, dtype=torch.int64)
627
+
628
+
629
@contextmanager
def draft_tp_context(tp_group: GroupCoordinator):
    """Run the enclosed block with ``patch_tensor_parallel_group(tp_group)`` active."""
    # Draft model doesn't use dp and has its own tp group.
    # We disable mscclpp now because it doesn't support 2 comm groups.
    # NOTE(review): nothing in this function disables mscclpp itself;
    # presumably that happens inside patch_tensor_parallel_group — confirm.
    with patch_tensor_parallel_group(tp_group):
        yield
635
+
636
+
637
def detect_nan(logits_output: LogitsProcessorOutput):
    """Log an error and raise ``ValueError`` if the next-token logits contain NaN.

    Returns ``None`` when no NaN is present.
    """
    has_nan = torch.isnan(logits_output.next_token_logits).any()
    if not has_nan:
        return

    message = "Detected errors during sampling! NaN in the logits."
    logger.error(message)
    raise ValueError(message)