sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -1,10 +1,16 @@
+from abc import ABC, abstractmethod
 from enum import IntEnum, auto
+from typing import List, Tuple
+
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 
 
 class SpeculativeAlgorithm(IntEnum):
     NONE = auto()
     EAGLE = auto()
     EAGLE3 = auto()
+    STANDALONE = auto()
+    NGRAM = auto()
 
     def is_none(self):
         return self == SpeculativeAlgorithm.NONE
@@ -15,13 +21,59 @@ class SpeculativeAlgorithm(IntEnum):
     def is_eagle3(self):
         return self == SpeculativeAlgorithm.EAGLE3
 
+    def is_standalone(self):
+        return self == SpeculativeAlgorithm.STANDALONE
+
+    def is_ngram(self):
+        return self == SpeculativeAlgorithm.NGRAM
+
     @staticmethod
     def from_string(name: str):
         name_map = {
             "EAGLE": SpeculativeAlgorithm.EAGLE,
             "EAGLE3": SpeculativeAlgorithm.EAGLE3,
+            "STANDALONE": SpeculativeAlgorithm.STANDALONE,
+            "NGRAM": SpeculativeAlgorithm.NGRAM,
             None: SpeculativeAlgorithm.NONE,
         }
         if name is not None:
             name = name.upper()
         return name_map[name]
+
+
+class SpecInputType(IntEnum):
+    # NOTE: introduce this to distinguish the SpecInput types of multiple algorithms when asserting in attention backends.
+    # If all algorithms can share the same datastrucutre of draft_input and verify_input, consider simplify it
+    EAGLE_DRAFT = auto()
+    EAGLE_VERIFY = auto()
+    NGRAM_VERIFY = auto()
+
+
+class SpecInput(ABC):
+    def __init__(self, spec_input_type: SpecInputType):
+        self.spec_input_type = spec_input_type
+
+    def is_draft_input(self) -> bool:
+        # FIXME: remove this function which is only used for assertion
+        # or use another variable name like `draft_input` to substitute `spec_info`
+        return self.spec_input_type == SpecInputType.EAGLE_DRAFT
+
+    def is_verify_input(self) -> bool:
+        return self.spec_input_type in {
+            SpecInputType.EAGLE_VERIFY,
+            SpecInputType.NGRAM_VERIFY,
+        }
+
+    @abstractmethod
+    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
+        pass
+
+    def get_spec_adjusted_global_num_tokens(
+        self, forward_batch: ModelWorkerBatch
+    ) -> Tuple[List[int], List[int]]:
+        c1, c2 = self.get_spec_adjust_token_coefficient()
+        global_num_tokens = [x * c1 for x in forward_batch.global_num_tokens]
+        global_num_tokens_for_logprob = [
+            x * c2 for x in forward_batch.global_num_tokens_for_logprob
+        ]
+        return global_num_tokens, global_num_tokens_for_logprob
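The hunks above (apparently sglang/srt/speculative/spec_info.py, matching the +52 -0 entry in the file list) give every speculative algorithm one hook, get_spec_adjust_token_coefficient(), returning a pair of multipliers, and get_spec_adjusted_global_num_tokens() scales the batch's global_num_tokens and global_num_tokens_for_logprob lists by those coefficients. Below is a minimal sketch of that contract; the toy class and the coefficient value 4 are illustrative assumptions, not the shipped EAGLE or NGRAM inputs.

from typing import Tuple


class ToyVerifyInput:
    """Illustrative stand-in for a SpecInput subclass (not from the diff)."""

    def __init__(self, draft_token_num: int):
        self.draft_token_num = draft_token_num

    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
        # Verification runs the target model over every draft token, so both
        # the forward and logprob token counts scale by draft_token_num.
        return self.draft_token_num, self.draft_token_num


# With 4 draft tokens per request, per-rank decode token counts of [8, 16]
# are adjusted to [32, 64] before scheduling the verify forward pass.
toy = ToyVerifyInput(draft_token_num=4)
c1, c2 = toy.get_spec_adjust_token_coefficient()
print([x * c1 for x in [8, 16]])  # [32, 64]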
@@ -0,0 +1,605 @@
+from __future__ import annotations
+
+import logging
+import time
+from typing import TYPE_CHECKING, List
+
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
+from sglang.srt.environ import envs
+from sglang.srt.managers.schedule_batch import Req
+from sglang.srt.utils import is_cuda, is_hip
+
+if is_cuda():
+    from sgl_kernel import fast_topk
+elif is_hip():
+    from sgl_kernel import fast_topk
+
+if TYPE_CHECKING:
+    from sglang.srt.speculative.eagle_info import EagleVerifyInput
+
+logger = logging.getLogger(__name__)
+
+
+# Simulate acceptance length for benchmarking purposes
+SIMULATE_ACC_LEN = envs.SGLANG_SIMULATE_ACC_LEN.get()  # turn off if < 0
+SIMULATE_ACC_METHOD = envs.SGLANG_SIMULATE_ACC_METHOD.get()
+
+TREE_TRAVERSE_TIME_THRESHOLD = 1  # TODO: set this properly
+TREE_SPEC_KERNEL_AVAILABLE = is_cuda()  # This kernel is only available for CUDA now
+
+
+@triton.jit
+def create_extend_after_decode_spec_info(
+    verified_id,
+    seq_lens,
+    accept_lens,
+    positions,
+    new_verified_id,
+    bs_upper: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    offsets = tl.arange(0, bs_upper)
+    seq_length = tl.load(seq_lens + pid)
+    accept_length = tl.load(accept_lens + pid)
+
+    accept_len_cumsum = tl.sum(
+        tl.load(accept_lens + offsets, mask=offsets < pid, other=0)
+    )
+    positions_ptr = positions + accept_len_cumsum
+    mask = offsets < accept_length
+    tl.store(positions_ptr + offsets, seq_length - accept_length + offsets, mask)
+
+    accept_len_cumsum += accept_length - 1
+    verified_id_data = tl.load(verified_id + accept_len_cumsum)
+    tl.store(new_verified_id + pid, verified_id_data)
+
+
+@triton.jit
+def assign_req_to_token_pool(
+    req_pool_indices,
+    req_to_token,
+    start_offset,
+    end_offset,
+    out_cache_loc,
+    pool_len: tl.constexpr,
+    bs_upper: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 32
+    pid = tl.program_id(axis=0)
+    kv_start = tl.load(start_offset + pid)
+    kv_end = tl.load(end_offset + pid)
+    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
+
+    length_offset = tl.arange(0, bs_upper)
+    start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0)
+    end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0)
+    out_offset = tl.sum(end - start, axis=0)
+
+    out_cache_ptr = out_cache_loc + out_offset
+
+    save_offset = tl.arange(0, BLOCK_SIZE) + kv_start
+    load_offset = tl.arange(0, BLOCK_SIZE)
+
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for _ in range(num_loop):
+        mask = save_offset < kv_end
+        data = tl.load(out_cache_ptr + load_offset, mask=mask)
+        tl.store(token_pool + save_offset, data, mask=mask)
+        save_offset += BLOCK_SIZE
+        load_offset += BLOCK_SIZE
+
+
+@triton.jit
+def assign_draft_cache_locs(
+    req_pool_indices,
+    req_to_token,
+    seq_lens,
+    extend_lens,
+    num_new_pages_per_topk,
+    out_cache_loc,
+    pool_len: tl.constexpr,
+    topk: tl.constexpr,
+    speculative_num_steps: tl.constexpr,
+    page_size: tl.constexpr,
+    bs_upper: tl.constexpr,
+    iter_upper: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 128
+    pid = tl.program_id(axis=0)
+
+    if page_size == 1 or topk == 1:
+        copy_len = topk * speculative_num_steps
+        out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
+    else:
+        bs_offset = tl.arange(0, bs_upper)
+        copy_len = tl.load(extend_lens + pid)
+        cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid))
+        out_cache_ptr = out_cache_loc + cum_copy_len
+
+    # Part 1: Copy from out_cache_loc to req_to_token
+    kv_start = tl.load(seq_lens + pid)
+    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
+    num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
+    for i in range(num_loop):
+        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = copy_offset < copy_len
+        data = tl.load(out_cache_ptr + copy_offset, mask=mask)
+        tl.store(token_pool + kv_start + copy_offset, data, mask=mask)
+
+    if page_size == 1 or topk == 1:
+        return
+
+    # Part 2: Copy the indices for the last partial page
+    prefix_len = tl.load(seq_lens + pid)
+    last_page_len = prefix_len % page_size
+    offsets = tl.arange(0, page_size)
+    mask = offsets < last_page_len
+    num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid)
+    prefix_base = token_pool + prefix_len - last_page_len
+
+    for topk_id in range(topk):
+        value = tl.load(prefix_base + offsets, mask=mask)
+        tl.store(
+            prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets,
+            value,
+            mask=mask,
+        )
+
+    # Part 3: Remove the padding in out_cache_loc
+    iter_offest = tl.arange(0, iter_upper)
+    for topk_id in range(topk):
+        indices = tl.load(
+            prefix_base
+            + topk_id * num_new_pages_per_topk_ * page_size
+            + last_page_len
+            + iter_offest,
+            mask=iter_offest < speculative_num_steps,
+        )
+        tl.store(
+            out_cache_loc
+            + pid * topk * speculative_num_steps
+            + topk_id * speculative_num_steps
+            + iter_offest,
+            indices,
+            mask=iter_offest < speculative_num_steps,
+        )
+
+
+@triton.jit
+def generate_draft_decode_kv_indices(
+    req_pool_indices,
+    req_to_token,
+    paged_kernel_lens,
+    kv_indices,
+    kv_indptr,
+    positions,
+    pool_len: tl.constexpr,
+    kv_indices_stride: tl.constexpr,
+    kv_indptr_stride: tl.constexpr,
+    bs_upper: tl.constexpr,
+    iter_upper: tl.constexpr,
+    num_tokens_upper: tl.constexpr,
+    page_size: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 128
+    iters = tl.program_id(axis=0)
+    bid = tl.program_id(axis=1)
+    topk_id = tl.program_id(axis=2)
+
+    num_steps = tl.num_programs(axis=0)
+    num_seqs = tl.num_programs(axis=1)
+    topk = tl.num_programs(axis=2)
+
+    kv_indices += kv_indices_stride * iters
+    kv_indptr += kv_indptr_stride * iters
+    iters += 1
+
+    load_offset = tl.arange(0, bs_upper)
+    seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid, other=0)
+    seq_len = tl.load(paged_kernel_lens + bid)
+    cum_seq_len = tl.sum(seq_lens)
+
+    # Update kv_indices
+    kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
+    kv_ptr = kv_indices + kv_offset
+    token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len
+
+    kv_offset = tl.arange(0, BLOCK_SIZE)
+    num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
+    for _ in range(num_loop):
+        mask = kv_offset < seq_len
+        data = tl.load(token_pool_ptr + kv_offset, mask=mask)
+        tl.store(kv_ptr + kv_offset, data, mask=mask)
+        kv_offset += BLOCK_SIZE
+
+    extend_offset = tl.arange(0, iter_upper)
+    if page_size == 1 or topk == 1:
+        extend_data = tl.load(
+            token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper),
+            mask=extend_offset < iters,
+        )
+    else:
+        prefix_len = seq_len
+        last_page_len = prefix_len % page_size
+        num_new_pages_per_topk = (
+            last_page_len + num_steps + page_size - 1
+        ) // page_size
+        prefix_base = seq_len // page_size * page_size
+        start = (
+            prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len
+        )
+        extend_data = tl.load(
+            token_pool_ptr + start + extend_offset,
+            mask=extend_offset < iters,
+        )
+
+    tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)
+
+    # Update kv_indptr
+    bs_offset = tl.arange(0, num_tokens_upper)
+
+    zid = bid * topk + topk_id
+    if zid == 0:
+        zid = num_seqs * topk
+    positions = tl.load(positions + bs_offset, mask=bs_offset < zid, other=0)
+    base = tl.sum(positions)
+    tl.store(kv_indptr + zid, base + zid * iters)
+
+
+@triton.jit
+def align_evict_mask_to_page_size(
+    seq_lens,
+    evict_mask,
+    page_size: tl.constexpr,
+    num_draft_tokens: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    t_range = tl.arange(0, BLOCK_SIZE)
+
+    bid = tl.program_id(axis=0)
+    seq_len = tl.load(seq_lens + bid)
+    io_mask = t_range < num_draft_tokens
+    mask_row = tl.load(
+        evict_mask + bid * num_draft_tokens + t_range, mask=io_mask, other=0
+    )
+
+    num_trues = tl.sum(mask_row)
+    num_false = num_draft_tokens - num_trues
+
+    start = (seq_len + num_false - 1) // page_size * page_size - seq_len
+    for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
+        tl.store(evict_mask + bid * num_draft_tokens + i, False)
+
+
+@triton.jit
+def get_target_cache_loc(
+    tgt_cache_loc,
+    to_free_slots,
+    accept_length,
+    to_free_num_slots,
+    out_cache_loc,
+    num_verify_tokens: tl.constexpr,
+    num_verify_tokens_upper: tl.constexpr,
+    bs_upper: tl.constexpr,
+):
+    bid = tl.program_id(axis=0)
+    offset = tl.arange(0, num_verify_tokens_upper)
+    bs_offset = tl.arange(0, bs_upper)
+
+    # write the first part to tgt_cache_loc
+    accept_len_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
+    tgt_cache_loc_start = tl.sum(accept_len_all) + bid
+    copy_len = tl.load(accept_length + bid) + 1
+    out_cache_loc_row = tl.load(
+        out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len
+    )
+    tl.store(
+        tgt_cache_loc + tgt_cache_loc_start + offset,
+        out_cache_loc_row,
+        mask=offset < copy_len,
+    )
+
+    # write the second part to to_free_num_pages
+    to_free_num_slots_all = tl.load(to_free_num_slots + bs_offset, mask=bs_offset < bid)
+    to_free_num_slots_cur = tl.load(to_free_num_slots + bid)
+    out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur
+    to_free_slots_start = tl.sum(to_free_num_slots_all)
+
+    copy_len = to_free_num_slots_cur
+    out_cache_loc_row = tl.load(
+        out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset,
+        mask=offset < copy_len,
+    )
+    tl.store(
+        to_free_slots + to_free_slots_start + offset,
+        out_cache_loc_row,
+        mask=offset < copy_len,
+    )
+
+
+@torch.compile(dynamic=True)
+def get_src_tgt_cache_loc(
+    seq_lens: torch.Tensor,
+    out_cache_loc: torch.Tensor,
+    accept_index: torch.Tensor,
+    accept_length: torch.Tensor,
+    draft_token_num: int,
+    page_size: int,
+):
+    src_cache_loc = out_cache_loc[accept_index]
+    tgt_cache_loc = torch.empty_like(src_cache_loc)
+    extended_len = seq_lens + draft_token_num
+    keep_len = torch.minimum(
+        (seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size,
+        extended_len,
+    )
+    to_free_num_slots = extended_len - keep_len
+    return src_cache_loc, tgt_cache_loc, to_free_num_slots
+
+
+@triton.jit
+def filter_finished_cache_loc_kernel(
+    out_cache_loc,
+    tgt_cache_loc,
+    accept_length,
+    accept_length_filter,
+    bs_upper: tl.constexpr,
+    num_verify_tokens_upper: tl.constexpr,
+):
+    bid = tl.program_id(0)
+    bs_offset = tl.arange(0, bs_upper)
+
+    accept_length_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
+    old_start = tl.sum(accept_length_all) + bid
+
+    accept_length_filter_all = tl.load(
+        accept_length_filter + bs_offset, mask=bs_offset < bid
+    )
+    new_start = tl.sum(accept_length_filter_all)
+
+    copy_len = tl.load(accept_length_filter + bid)
+    copy_offset = tl.arange(0, num_verify_tokens_upper)
+    value = tl.load(
+        tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len
+    )
+    tl.store(
+        out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len
+    )
+
+
+@torch.compile(dynamic=True)
+def create_accept_length_filter(
+    accept_length: torch.Tensor,
+    unfinished_index_device: torch.Tensor,
+    seq_lens: torch.Tensor,
+):
+    accept_length_filter = torch.zeros_like(accept_length)
+    accept_length_filter[unfinished_index_device] = (
+        accept_length[unfinished_index_device] + 1
+    )
+    seq_lens.add_(accept_length + 1)
+    return accept_length_filter
+
+
+@torch.compile(dynamic=True)
+def select_top_k_tokens(
+    i: int,
+    topk_p: torch.Tensor,
+    topk_index: torch.Tensor,
+    hidden_states: torch.Tensor,
+    scores: torch.Tensor,
+    topk: int,
+):
+    if i == 0:
+        # The first step after extend
+        input_ids = topk_index.flatten()
+        hidden_states = hidden_states.repeat_interleave(topk, dim=0)
+        scores = topk_p  # shape: (b, topk)
+
+        tree_info = (
+            topk_p.unsqueeze(1),  # shape: (b, 1, topk)
+            topk_index,  # shape: (b, topk)
+            torch.arange(-1, topk, dtype=torch.long, device="cuda")
+            .unsqueeze(0)
+            .repeat(topk_p.shape[0], 1),  # shape: (b, topk + 1)
+        )
+    else:
+        # The later decode steps
+        expand_scores = torch.mul(
+            scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
+        )  # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
+        topk_cs_p, topk_cs_index = fast_topk(
+            expand_scores.flatten(start_dim=1), topk, dim=-1
+        )  # (b, topk)
+        scores = topk_cs_p  # shape: (b, topk)
+
+        topk_index = topk_index.reshape(-1, topk**2)
+        input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()
+
+        if hidden_states.shape[0] > 0:
+            selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
+                0, hidden_states.shape[0], step=topk, device="cuda"
+            ).repeat_interleave(topk)
+            hidden_states = hidden_states[selected_input_index, :]

+        tree_info = (
+            expand_scores,  # shape: (b, topk, topk)
+            topk_index,  # shape: (b, topk * topk)
+            topk_cs_index + (topk**2 * (i - 1) + topk),  # shape: (b, topk)
+        )
+
+    return input_ids, hidden_states, scores, tree_info
+
+
+def _generate_simulated_accept_index(
+    accept_index,
+    predict,
+    accept_length,
+    bs,
+    spec_steps,
+    simulate_acc_len: float = SIMULATE_ACC_LEN,
+    simulate_acc_method: str = SIMULATE_ACC_METHOD,
+):
+    assert simulate_acc_len > 0.0
+
+    if simulate_acc_method == "multinomial":
+        simulated_values = torch.normal(
+            mean=simulate_acc_len,
+            std=1.0,
+            size=(1,),
+            device="cpu",
+        )
+        # clamp simulated values to be between 1 and self.spec_steps
+        simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
+        simulate_acc_len = int(simulated_values.round().item())
+    elif simulate_acc_method == "match-expected":
+        # multinomial sampling does not match the expected length
+        # we keep it for the sake of compatibility of existing tests
+        # but it's better to use "match-expected" for the cases that need to
+        # match the expected length, One caveat is that this will only sample
+        # either round down or round up of the expected length
+        simulate_acc_len = max(1.0, min(spec_steps + 1, simulate_acc_len))
+        lower = int(simulate_acc_len // 1)
+        upper = lower + 1 if lower < spec_steps + 1 else lower
+        if lower == upper:
+            simulate_acc_len = lower
+        else:
+            weight_upper = simulate_acc_len - lower
+            weight_lower = 1.0 - weight_upper
+            probs = torch.tensor([weight_lower, weight_upper], device="cpu")
+            sampled_index = torch.multinomial(probs, num_samples=1)
+            simulate_acc_len = lower if sampled_index == 0 else upper
+    else:
+        raise ValueError(f"Invalid simulate_acc_method: {SIMULATE_ACC_METHOD}")
+
+    accept_indx_first_col = accept_index[:, 0].view(-1, 1)
+    sim_accept_index = torch.full(
+        (bs, spec_steps + 1), -1, dtype=torch.int32, device="cuda"
+    )
+    sim_accept_index[:, :simulate_acc_len] = accept_indx_first_col + torch.arange(
+        simulate_acc_len, device=accept_index.device
+    )
+    accept_length.fill_(simulate_acc_len - 1)
+    predict.fill_(100)  # some legit token id
+    return sim_accept_index
+
+
+def traverse_tree(
+    retrieve_next_token: torch.Tensor,
+    retrieve_next_sibling: torch.Tensor,
+    draft_tokens: torch.Tensor,
+    grammar: BaseGrammarObject,
+    allocate_token_bitmask: torch.Tensor,
+):
+    """
+    Traverse the tree constructed by the draft model to generate the logits mask.
+    """
+    assert (
+        retrieve_next_token.shape == retrieve_next_sibling.shape == draft_tokens.shape
+    )
+
+    allocate_token_bitmask.fill_(0)
+
+    def dfs(
+        curr: int,
+        retrieve_next_token: torch.Tensor,
+        retrieve_next_sibling: torch.Tensor,
+        parent_pos: int,
+    ):
+        if curr == 0:
+            # the first token generated by the target model, and thus it is always
+            # accepted from the previous iteration
+            accepted = True
+        else:
+            parent_bitmask = allocate_token_bitmask[parent_pos]
+            curr_token_id = draft_tokens[curr]
+            # 32 boolean bitmask values are packed into 32-bit integers
+            accepted = (
+                parent_bitmask[curr_token_id // 32] & (1 << (curr_token_id % 32))
+            ) != 0
+
+        if accepted:
+            if curr != 0:
+                # Accept the current token
+                grammar.accept_token(draft_tokens[curr])
+            if not grammar.is_terminated():
+                # Generate the bitmask for the current token
+                grammar.fill_vocab_mask(allocate_token_bitmask, curr)
+                if retrieve_next_token[curr] != -1:
+                    # Visit the child node
+                    dfs(
+                        retrieve_next_token[curr],
+                        retrieve_next_token,
+                        retrieve_next_sibling,
+                        curr,
+                    )
+
+            if curr != 0:
+                # Rollback the current token
+                grammar.rollback(1)
+
+        if retrieve_next_sibling[curr] != -1:
+            # Visit the sibling node
+            dfs(
+                retrieve_next_sibling[curr],
+                retrieve_next_token,
+                retrieve_next_sibling,
+                parent_pos,
+            )
+
+    dfs(0, retrieve_next_token, retrieve_next_sibling, -1)
+
+
+def generate_token_bitmask(
+    reqs: List[Req],
+    verify_input: EagleVerifyInput,
+    retrieve_next_token_cpu: torch.Tensor,
+    retrieve_next_sibling_cpu: torch.Tensor,
+    draft_tokens_cpu: torch.Tensor,
+    vocab_size: int,
+):
+    """
+    Generate the logit mask for structured output.
+    Draft model's token can be either valid or invalid with respect to the grammar.
+    We need to perform DFS to
+    1. figure out which tokens are accepted by the grammar.
+    2. if so, what is the corresponding logit mask.
+    """
+
+    num_draft_tokens = draft_tokens_cpu.shape[-1]
+
+    allocate_token_bitmask = None
+    assert len(reqs) == retrieve_next_token_cpu.shape[0]
+    grammar = None
+    for i, req in enumerate(reqs):
+        if req.grammar is not None:
+            if allocate_token_bitmask is None:
+                allocate_token_bitmask = req.grammar.allocate_vocab_mask(
+                    vocab_size=vocab_size,
+                    batch_size=draft_tokens_cpu.numel(),
+                    device="cpu",
+                )
+            grammar = req.grammar
+            s = time.perf_counter()
+            traverse_tree(
+                retrieve_next_token_cpu[i],
+                retrieve_next_sibling_cpu[i],
+                draft_tokens_cpu[i],
+                req.grammar,
+                allocate_token_bitmask[
+                    i * num_draft_tokens : (i + 1) * num_draft_tokens
+                ],
+            )
+            tree_traverse_time = time.perf_counter() - s
+            if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD:
+                logger.warning(
+                    f"Bit mask generation took {tree_traverse_time} seconds with "
+                    f"grammar: {req.grammar}"
+                )
+
+    verify_input.grammar = grammar
+    return allocate_token_bitmask
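For benchmarking, _generate_simulated_accept_index above (apparently part of the new sglang/srt/speculative/spec_utils.py, matching the +605 -0 entry in the file list) can replace real verification results with a simulated acceptance length read from SGLANG_SIMULATE_ACC_LEN. Its "match-expected" mode clamps the target to [1, spec_steps + 1] and then samples either the floor or the ceiling of the target, weighted so the expected value matches. A standalone sketch of that sampling rule follows; the helper name is hypothetical and not part of the diff.

import random


def sample_match_expected(target: float, spec_steps: int) -> int:
    # Clamp to the feasible range: at least 1 token, at most spec_steps + 1.
    target = max(1.0, min(spec_steps + 1, target))
    lower = int(target)
    upper = min(lower + 1, spec_steps + 1)
    if lower == upper:
        return lower
    # P(upper) is the fractional part, so E[sample] == target.
    weight_upper = target - lower  # e.g. target = 2.4 -> P(3) = 0.4, P(2) = 0.6
    return upper if random.random() < weight_upper else lower


# Averaging many draws approaches the target acceptance length (~2.4 here).
draws = [sample_match_expected(2.4, spec_steps=4) for _ in range(10_000)]
print(sum(draws) / len(draws))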