sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -1,10 +1,16 @@
1
+ from abc import ABC, abstractmethod
1
2
  from enum import IntEnum, auto
3
+ from typing import List, Tuple
4
+
5
+ from sglang.srt.managers.schedule_batch import ModelWorkerBatch
2
6
 
3
7
 
4
8
  class SpeculativeAlgorithm(IntEnum):
5
9
  NONE = auto()
6
10
  EAGLE = auto()
7
11
  EAGLE3 = auto()
12
+ STANDALONE = auto()
13
+ NGRAM = auto()
8
14
 
9
15
  def is_none(self):
10
16
  return self == SpeculativeAlgorithm.NONE
@@ -15,13 +21,59 @@ class SpeculativeAlgorithm(IntEnum):
15
21
  def is_eagle3(self):
16
22
  return self == SpeculativeAlgorithm.EAGLE3
17
23
 
24
+ def is_standalone(self):
25
+ return self == SpeculativeAlgorithm.STANDALONE
26
+
27
+ def is_ngram(self):
28
+ return self == SpeculativeAlgorithm.NGRAM
29
+
18
30
  @staticmethod
19
31
  def from_string(name: str):
20
32
  name_map = {
21
33
  "EAGLE": SpeculativeAlgorithm.EAGLE,
22
34
  "EAGLE3": SpeculativeAlgorithm.EAGLE3,
35
+ "STANDALONE": SpeculativeAlgorithm.STANDALONE,
36
+ "NGRAM": SpeculativeAlgorithm.NGRAM,
23
37
  None: SpeculativeAlgorithm.NONE,
24
38
  }
25
39
  if name is not None:
26
40
  name = name.upper()
27
41
  return name_map[name]
42
+
43
+
44
+ class SpecInputType(IntEnum):
45
+ # NOTE: introduce this to distinguish the SpecInput types of multiple algorithms when asserting in attention backends.
46
+ # If all algorithms can share the same datastrucutre of draft_input and verify_input, consider simplify it
47
+ EAGLE_DRAFT = auto()
48
+ EAGLE_VERIFY = auto()
49
+ NGRAM_VERIFY = auto()
50
+
51
+
52
+ class SpecInput(ABC):
53
+ def __init__(self, spec_input_type: SpecInputType):
54
+ self.spec_input_type = spec_input_type
55
+
56
+ def is_draft_input(self) -> bool:
57
+ # FIXME: remove this function which is only used for assertion
58
+ # or use another variable name like `draft_input` to substitute `spec_info`
59
+ return self.spec_input_type == SpecInputType.EAGLE_DRAFT
60
+
61
+ def is_verify_input(self) -> bool:
62
+ return self.spec_input_type in {
63
+ SpecInputType.EAGLE_VERIFY,
64
+ SpecInputType.NGRAM_VERIFY,
65
+ }
66
+
67
+ @abstractmethod
68
+ def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
69
+ pass
70
+
71
+ def get_spec_adjusted_global_num_tokens(
72
+ self, forward_batch: ModelWorkerBatch
73
+ ) -> Tuple[List[int], List[int]]:
74
+ c1, c2 = self.get_spec_adjust_token_coefficient()
75
+ global_num_tokens = [x * c1 for x in forward_batch.global_num_tokens]
76
+ global_num_tokens_for_logprob = [
77
+ x * c2 for x in forward_batch.global_num_tokens_for_logprob
78
+ ]
79
+ return global_num_tokens, global_num_tokens_for_logprob
@@ -0,0 +1,606 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import os
5
+ import time
6
+ from typing import TYPE_CHECKING, List
7
+
8
+ import torch
9
+ import triton
10
+ import triton.language as tl
11
+
12
+ from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
13
+ from sglang.srt.environ import envs
14
+ from sglang.srt.managers.schedule_batch import Req
15
+ from sglang.srt.utils import is_cuda, is_hip
16
+
17
+ if is_cuda():
18
+ from sgl_kernel import fast_topk
19
+ elif is_hip():
20
+ from sgl_kernel import fast_topk
21
+
22
+ if TYPE_CHECKING:
23
+ from sglang.srt.speculative.eagle_info import EagleVerifyInput
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ # Simulate acceptance length for benchmarking purposes
29
+ SIMULATE_ACC_LEN = envs.SGLANG_SIMULATE_ACC_LEN.get() # turn off if < 0
30
+ SIMULATE_ACC_METHOD = envs.SGLANG_SIMULATE_ACC_METHOD.get()
31
+
32
+ TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly
33
+ TREE_SPEC_KERNEL_AVAILABLE = is_cuda() # This kernel is only available for CUDA now
34
+
35
+
36
+ @triton.jit
37
+ def create_extend_after_decode_spec_info(
38
+ verified_id,
39
+ seq_lens,
40
+ accept_lens,
41
+ positions,
42
+ new_verified_id,
43
+ bs_upper: tl.constexpr,
44
+ ):
45
+ pid = tl.program_id(axis=0)
46
+ offsets = tl.arange(0, bs_upper)
47
+ seq_length = tl.load(seq_lens + pid)
48
+ accept_length = tl.load(accept_lens + pid)
49
+
50
+ accept_len_cumsum = tl.sum(
51
+ tl.load(accept_lens + offsets, mask=offsets < pid, other=0)
52
+ )
53
+ positions_ptr = positions + accept_len_cumsum
54
+ mask = offsets < accept_length
55
+ tl.store(positions_ptr + offsets, seq_length - accept_length + offsets, mask)
56
+
57
+ accept_len_cumsum += accept_length - 1
58
+ verified_id_data = tl.load(verified_id + accept_len_cumsum)
59
+ tl.store(new_verified_id + pid, verified_id_data)
60
+
61
+
62
+ @triton.jit
63
+ def assign_req_to_token_pool(
64
+ req_pool_indices,
65
+ req_to_token,
66
+ start_offset,
67
+ end_offset,
68
+ out_cache_loc,
69
+ pool_len: tl.constexpr,
70
+ bs_upper: tl.constexpr,
71
+ ):
72
+ BLOCK_SIZE: tl.constexpr = 32
73
+ pid = tl.program_id(axis=0)
74
+ kv_start = tl.load(start_offset + pid)
75
+ kv_end = tl.load(end_offset + pid)
76
+ token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
77
+
78
+ length_offset = tl.arange(0, bs_upper)
79
+ start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0)
80
+ end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0)
81
+ out_offset = tl.sum(end - start, axis=0)
82
+
83
+ out_cache_ptr = out_cache_loc + out_offset
84
+
85
+ save_offset = tl.arange(0, BLOCK_SIZE) + kv_start
86
+ load_offset = tl.arange(0, BLOCK_SIZE)
87
+
88
+ num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
89
+ for _ in range(num_loop):
90
+ mask = save_offset < kv_end
91
+ data = tl.load(out_cache_ptr + load_offset, mask=mask)
92
+ tl.store(token_pool + save_offset, data, mask=mask)
93
+ save_offset += BLOCK_SIZE
94
+ load_offset += BLOCK_SIZE
95
+
96
+
97
+ @triton.jit
98
+ def assign_draft_cache_locs(
99
+ req_pool_indices,
100
+ req_to_token,
101
+ seq_lens,
102
+ extend_lens,
103
+ num_new_pages_per_topk,
104
+ out_cache_loc,
105
+ pool_len: tl.constexpr,
106
+ topk: tl.constexpr,
107
+ speculative_num_steps: tl.constexpr,
108
+ page_size: tl.constexpr,
109
+ bs_upper: tl.constexpr,
110
+ iter_upper: tl.constexpr,
111
+ ):
112
+ BLOCK_SIZE: tl.constexpr = 128
113
+ pid = tl.program_id(axis=0)
114
+
115
+ if page_size == 1 or topk == 1:
116
+ copy_len = topk * speculative_num_steps
117
+ out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
118
+ else:
119
+ bs_offset = tl.arange(0, bs_upper)
120
+ copy_len = tl.load(extend_lens + pid)
121
+ cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid))
122
+ out_cache_ptr = out_cache_loc + cum_copy_len
123
+
124
+ # Part 1: Copy from out_cache_loc to req_to_token
125
+ kv_start = tl.load(seq_lens + pid)
126
+ token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
127
+ num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
128
+ for i in range(num_loop):
129
+ copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
130
+ mask = copy_offset < copy_len
131
+ data = tl.load(out_cache_ptr + copy_offset, mask=mask)
132
+ tl.store(token_pool + kv_start + copy_offset, data, mask=mask)
133
+
134
+ if page_size == 1 or topk == 1:
135
+ return
136
+
137
+ # Part 2: Copy the indices for the last partial page
138
+ prefix_len = tl.load(seq_lens + pid)
139
+ last_page_len = prefix_len % page_size
140
+ offsets = tl.arange(0, page_size)
141
+ mask = offsets < last_page_len
142
+ num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid)
143
+ prefix_base = token_pool + prefix_len - last_page_len
144
+
145
+ for topk_id in range(topk):
146
+ value = tl.load(prefix_base + offsets, mask=mask)
147
+ tl.store(
148
+ prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets,
149
+ value,
150
+ mask=mask,
151
+ )
152
+
153
+ # Part 3: Remove the padding in out_cache_loc
154
+ iter_offest = tl.arange(0, iter_upper)
155
+ for topk_id in range(topk):
156
+ indices = tl.load(
157
+ prefix_base
158
+ + topk_id * num_new_pages_per_topk_ * page_size
159
+ + last_page_len
160
+ + iter_offest,
161
+ mask=iter_offest < speculative_num_steps,
162
+ )
163
+ tl.store(
164
+ out_cache_loc
165
+ + pid * topk * speculative_num_steps
166
+ + topk_id * speculative_num_steps
167
+ + iter_offest,
168
+ indices,
169
+ mask=iter_offest < speculative_num_steps,
170
+ )
171
+
172
+
173
@triton.jit
def generate_draft_decode_kv_indices(
    req_pool_indices,  # per-sequence row index into the req_to_token pool
    req_to_token,  # token pool: one row of kv-cache locations per request
    paged_kernel_lens,  # per-sequence prefix kv length
    kv_indices,  # output: flattened kv indices, one `kv_indices_stride` slice per draft step
    kv_indptr,  # output: per-(seq, topk) offsets into kv_indices, one slice per draft step
    positions,  # per-token positions; prefix-summed to build kv_indptr entries
    pool_len: tl.constexpr,  # row stride of req_to_token
    kv_indices_stride: tl.constexpr,  # distance between consecutive draft steps in kv_indices
    kv_indptr_stride: tl.constexpr,  # distance between consecutive draft steps in kv_indptr
    bs_upper: tl.constexpr,  # upper bound (tl.arange extent) on batch size
    iter_upper: tl.constexpr,  # upper bound on the number of draft steps
    num_tokens_upper: tl.constexpr,  # upper bound on the total number of tokens
    page_size: tl.constexpr,  # kv-cache page size
):
    """Build per-step kv_indices / kv_indptr buffers for EAGLE draft decoding.

    Grid: (num_steps, num_seqs, topk) — one program per
    (draft step, sequence, top-k path) triple. Each program copies the
    sequence's shared prefix kv locations plus the kv locations of the draft
    tokens visible at this step into that step's kv_indices slice, and writes
    the matching kv_indptr entry.
    """
    BLOCK_SIZE: tl.constexpr = 128
    iters = tl.program_id(axis=0)
    bid = tl.program_id(axis=1)
    topk_id = tl.program_id(axis=2)

    num_steps = tl.num_programs(axis=0)
    num_seqs = tl.num_programs(axis=1)
    topk = tl.num_programs(axis=2)

    # Each draft step owns its own slice of the output buffers.
    kv_indices += kv_indices_stride * iters
    kv_indptr += kv_indptr_stride * iters
    iters += 1  # at step i, (i + 1) draft tokens are visible

    # Prefix sum of kv lengths of all sequences before this one
    # (masked load yields 0 for offsets >= bid).
    load_offset = tl.arange(0, bs_upper)
    seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid, other=0)
    seq_len = tl.load(paged_kernel_lens + bid)
    cum_seq_len = tl.sum(seq_lens)

    # Update kv_indices
    kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
    kv_ptr = kv_indices + kv_offset
    token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len

    # Copy the shared prefix kv locations in BLOCK_SIZE-sized chunks.
    kv_offset = tl.arange(0, BLOCK_SIZE)
    num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
    for _ in range(num_loop):
        mask = kv_offset < seq_len
        data = tl.load(token_pool_ptr + kv_offset, mask=mask)
        tl.store(kv_ptr + kv_offset, data, mask=mask)
        kv_offset += BLOCK_SIZE

    # Append the kv locations of the draft tokens produced so far.
    extend_offset = tl.arange(0, iter_upper)
    if page_size == 1 or topk == 1:
        # Contiguous layout: each topk path's draft tokens sit directly after
        # the prefix, num_steps entries apart per path.
        extend_data = tl.load(
            token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper),
            mask=extend_offset < iters,
        )
    else:
        # Page-aligned layout: each topk path owns num_new_pages_per_topk pages
        # starting at the sequence's last page boundary.
        # NOTE(review): assumed to mirror the layout written by the draft
        # cache-allocation kernel earlier in this file — confirm.
        prefix_len = seq_len
        last_page_len = prefix_len % page_size
        num_new_pages_per_topk = (
            last_page_len + num_steps + page_size - 1
        ) // page_size
        prefix_base = seq_len // page_size * page_size
        start = (
            prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len
        )
        extend_data = tl.load(
            token_pool_ptr + start + extend_offset,
            mask=extend_offset < iters,
        )

    tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)

    # Update kv_indptr
    bs_offset = tl.arange(0, num_tokens_upper)

    zid = bid * topk + topk_id
    if zid == 0:
        # Entry 0 of kv_indptr is always 0, so the (bid=0, topk_id=0) program
        # writes the final entry (num_seqs * topk) instead.
        zid = num_seqs * topk
    positions = tl.load(positions + bs_offset, mask=bs_offset < zid, other=0)
    base = tl.sum(positions)
    tl.store(kv_indptr + zid, base + zid * iters)
252
+
253
+
254
@triton.jit
def align_evict_mask_to_page_size(
    seq_lens,  # per-sequence kv length before the draft-token region
    evict_mask,  # (bs, num_draft_tokens) flags; truthy = slot will be evicted
    page_size: tl.constexpr,  # kv-cache page size
    num_draft_tokens: tl.constexpr,  # draft tokens per sequence
    BLOCK_SIZE: tl.constexpr,  # tl.arange extent; upper bound on num_draft_tokens
):
    """Force the evict flags inside the last partially-kept page to False.

    One program per sequence. The row's kept slots are its leading
    non-evicted entries; the remaining slots of the page containing the last
    kept token must also stay allocated so the kv cache stays page-aligned,
    so their evict flags are cleared here.
    """
    t_range = tl.arange(0, BLOCK_SIZE)

    bid = tl.program_id(axis=0)
    seq_len = tl.load(seq_lens + bid)
    io_mask = t_range < num_draft_tokens
    mask_row = tl.load(
        evict_mask + bid * num_draft_tokens + t_range, mask=io_mask, other=0
    )

    # num_false = number of kept (non-evicted) draft slots in this row.
    num_trues = tl.sum(mask_row)
    num_false = num_draft_tokens - num_trues

    # Offset (relative to the draft-token region) of the page boundary at or
    # before the last kept token; clear up to one full page from there.
    start = (seq_len + num_false - 1) // page_size * page_size - seq_len
    for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
        tl.store(evict_mask + bid * num_draft_tokens + i, False)
277
+
278
+
279
@triton.jit
def get_target_cache_loc(
    tgt_cache_loc,  # output: cache locations of accepted tokens, densely packed
    to_free_slots,  # output: cache locations to release, densely packed
    accept_length,  # per-sequence number of accepted draft tokens
    to_free_num_slots,  # per-sequence number of trailing slots to free
    out_cache_loc,  # (bs, num_verify_tokens) cache locations from the verify pass
    num_verify_tokens: tl.constexpr,  # row stride of out_cache_loc
    num_verify_tokens_upper: tl.constexpr,  # tl.arange extent for per-row copies
    bs_upper: tl.constexpr,  # tl.arange extent; upper bound on batch size
):
    """Split each sequence's verify-pass cache slots into kept and freed sets.

    One program per sequence. The first (accept_length + 1) slots of the row
    go to tgt_cache_loc; the trailing to_free_num_slots slots go to
    to_free_slots. Output start offsets are prefix sums over the batch, so
    both output buffers are densely packed.
    """
    bid = tl.program_id(axis=0)
    offset = tl.arange(0, num_verify_tokens_upper)
    bs_offset = tl.arange(0, bs_upper)

    # write the first part to tgt_cache_loc
    # (+ bid accounts for the one bonus token each preceding sequence keeps)
    accept_len_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
    tgt_cache_loc_start = tl.sum(accept_len_all) + bid
    copy_len = tl.load(accept_length + bid) + 1
    out_cache_loc_row = tl.load(
        out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len
    )
    tl.store(
        tgt_cache_loc + tgt_cache_loc_start + offset,
        out_cache_loc_row,
        mask=offset < copy_len,
    )

    # write the second part to to_free_num_pages
    # (the slots to free are the last to_free_num_slots_cur entries of the row)
    to_free_num_slots_all = tl.load(to_free_num_slots + bs_offset, mask=bs_offset < bid)
    to_free_num_slots_cur = tl.load(to_free_num_slots + bid)
    out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur
    to_free_slots_start = tl.sum(to_free_num_slots_all)

    copy_len = to_free_num_slots_cur
    out_cache_loc_row = tl.load(
        out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset,
        mask=offset < copy_len,
    )
    tl.store(
        to_free_slots + to_free_slots_start + offset,
        out_cache_loc_row,
        mask=offset < copy_len,
    )
323
+
324
+
325
+ @torch.compile(dynamic=True)
326
+ def get_src_tgt_cache_loc(
327
+ seq_lens: torch.Tensor,
328
+ out_cache_loc: torch.Tensor,
329
+ accept_index: torch.Tensor,
330
+ accept_length: torch.Tensor,
331
+ draft_token_num: int,
332
+ page_size: int,
333
+ ):
334
+ src_cache_loc = out_cache_loc[accept_index]
335
+ tgt_cache_loc = torch.empty_like(src_cache_loc)
336
+ extended_len = seq_lens + draft_token_num
337
+ keep_len = torch.minimum(
338
+ (seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size,
339
+ extended_len,
340
+ )
341
+ to_free_num_slots = extended_len - keep_len
342
+ return src_cache_loc, tgt_cache_loc, to_free_num_slots
343
+
344
+
345
@triton.jit
def filter_finished_cache_loc_kernel(
    out_cache_loc,  # output: compacted cache locations (finished seqs dropped)
    tgt_cache_loc,  # input: cache locations packed for all sequences
    accept_length,  # per-sequence accepted length (all sequences)
    accept_length_filter,  # accept_length + 1 for unfinished seqs, 0 for finished
    bs_upper: tl.constexpr,  # tl.arange extent; upper bound on batch size
    num_verify_tokens_upper: tl.constexpr,  # tl.arange extent for the row copy
):
    """Compact tgt_cache_loc into out_cache_loc, skipping finished sequences.

    One program per sequence. old_start is the sequence's offset in
    tgt_cache_loc (prefix sum of accept_length plus one bonus token per
    preceding sequence); new_start is its offset in the compacted output
    (prefix sum of accept_length_filter). Finished sequences have a zero
    filter entry, so nothing is copied for them.
    """
    bid = tl.program_id(0)
    bs_offset = tl.arange(0, bs_upper)

    accept_length_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
    old_start = tl.sum(accept_length_all) + bid

    accept_length_filter_all = tl.load(
        accept_length_filter + bs_offset, mask=bs_offset < bid
    )
    new_start = tl.sum(accept_length_filter_all)

    copy_len = tl.load(accept_length_filter + bid)
    copy_offset = tl.arange(0, num_verify_tokens_upper)
    value = tl.load(
        tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len
    )
    tl.store(
        out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len
    )
373
+
374
+
375
+ @torch.compile(dynamic=True)
376
+ def create_accept_length_filter(
377
+ accept_length: torch.Tensor,
378
+ unfinished_index_device: torch.Tensor,
379
+ seq_lens: torch.Tensor,
380
+ ):
381
+ accept_length_filter = torch.zeros_like(accept_length)
382
+ accept_length_filter[unfinished_index_device] = (
383
+ accept_length[unfinished_index_device] + 1
384
+ )
385
+ seq_lens.add_(accept_length + 1)
386
+ return accept_length_filter
387
+
388
+
389
@torch.compile(dynamic=True)
def select_top_k_tokens(
    i: int,
    topk_p: torch.Tensor,
    topk_index: torch.Tensor,
    hidden_states: torch.Tensor,
    scores: torch.Tensor,
    topk: int,
):
    """Select the next batch of draft tokens for step *i* of tree drafting.

    Returns (input_ids, hidden_states, scores, tree_info) where tree_info
    carries the per-step scores, candidate token ids, and parent positions
    used later to reconstruct the draft tree.
    """
    if i == 0:
        # First step after extend: every top-k candidate of each sequence
        # becomes a draft input, and hidden states are duplicated per path.
        input_ids = topk_index.flatten()
        hidden_states = hidden_states.repeat_interleave(topk, dim=0)
        scores = topk_p  # (b, topk)

        batch = topk_p.shape[0]
        # Parent ids: -1 for the verified root, then 0..topk-1.
        parent_ids = (
            torch.arange(-1, topk, dtype=torch.long, device="cuda")
            .unsqueeze(0)
            .repeat(batch, 1)
        )  # (b, topk + 1)
        tree_info = (
            topk_p.unsqueeze(1),  # (b, 1, topk)
            topk_index,  # (b, topk)
            parent_ids,
        )
    else:
        # Later steps: combine each parent path's score with its children's
        # probabilities and keep the best topk paths per sequence.
        expand_scores = torch.mul(
            scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
        )  # (b, topk, 1) x (b, topk, topk) -> (b, topk, topk)
        topk_cs_p, topk_cs_index = fast_topk(
            expand_scores.flatten(start_dim=1), topk, dim=-1
        )  # (b, topk)
        scores = topk_cs_p  # (b, topk)

        topk_index = topk_index.reshape(-1, topk**2)
        input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()

        if hidden_states.shape[0] > 0:
            # Map each selected path back to its parent's hidden-state row.
            parent_rows = topk_cs_index.flatten() // topk + torch.arange(
                0, hidden_states.shape[0], step=topk, device="cuda"
            ).repeat_interleave(topk)
            hidden_states = hidden_states[parent_rows, :]

        tree_info = (
            expand_scores,  # (b, topk, topk)
            topk_index,  # (b, topk * topk)
            topk_cs_index + (topk**2 * (i - 1) + topk),  # (b, topk)
        )

    return input_ids, hidden_states, scores, tree_info
437
+
438
+
439
def _generate_simulated_accept_index(
    accept_index,
    predict,
    accept_length,
    bs,
    spec_steps,
    simulate_acc_len: float = SIMULATE_ACC_LEN,
    simulate_acc_method: str = SIMULATE_ACC_METHOD,
):
    """Overwrite real acceptance results with a simulated accepted length.

    Used to benchmark speculative decoding under a controlled acceptance rate.

    Args:
        accept_index: (bs, spec_steps + 1) tensor; only column 0 is read, as
            the per-row anchor for the simulated indices.
        predict: predicted-token tensor; filled in place with a placeholder id.
        accept_length: per-sequence accepted length; overwritten in place
            with (simulate_acc_len - 1).
        bs: batch size.
        spec_steps: number of speculative steps (max accepted length is
            spec_steps + 1).
        simulate_acc_len: target (possibly fractional) accepted length; must
            be positive.
        simulate_acc_method: "multinomial" (normal sampling around the
            target) or "match-expected" (floor/ceil sampling whose
            expectation matches the target).

    Returns:
        A (bs, spec_steps + 1) int32 CUDA tensor of simulated accept indices;
        unused slots are -1.

    Raises:
        ValueError: if simulate_acc_method is not recognized.
    """
    assert simulate_acc_len > 0.0

    if simulate_acc_method == "multinomial":
        simulated_values = torch.normal(
            mean=simulate_acc_len,
            std=1.0,
            size=(1,),
            device="cpu",
        )
        # clamp simulated values to be between 1 and self.spec_steps
        simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
        simulate_acc_len = int(simulated_values.round().item())
    elif simulate_acc_method == "match-expected":
        # multinomial sampling does not match the expected length
        # we keep it for the sake of compatibility of existing tests
        # but it's better to use "match-expected" for the cases that need to
        # match the expected length, One caveat is that this will only sample
        # either round down or round up of the expected length
        simulate_acc_len = max(1.0, min(spec_steps + 1, simulate_acc_len))
        lower = int(simulate_acc_len // 1)
        upper = lower + 1 if lower < spec_steps + 1 else lower
        if lower == upper:
            simulate_acc_len = lower
        else:
            # Bernoulli choice between floor and ceil, weighted so the
            # expected value equals the requested fractional length.
            weight_upper = simulate_acc_len - lower
            weight_lower = 1.0 - weight_upper
            probs = torch.tensor([weight_lower, weight_upper], device="cpu")
            sampled_index = torch.multinomial(probs, num_samples=1)
            simulate_acc_len = lower if sampled_index == 0 else upper
    else:
        # Bug fix: report the method actually passed in. The original message
        # interpolated the module-level default SIMULATE_ACC_METHOD, which is
        # misleading when the caller overrides the argument.
        raise ValueError(f"Invalid simulate_acc_method: {simulate_acc_method}")

    accept_index_first_col = accept_index[:, 0].view(-1, 1)
    sim_accept_index = torch.full(
        (bs, spec_steps + 1), -1, dtype=torch.int32, device="cuda"
    )
    sim_accept_index[:, :simulate_acc_len] = accept_index_first_col + torch.arange(
        simulate_acc_len, device=accept_index.device
    )
    accept_length.fill_(simulate_acc_len - 1)
    predict.fill_(100)  # some legit token id
    return sim_accept_index
490
+
491
+
492
def traverse_tree(
    retrieve_next_token: torch.Tensor,
    retrieve_next_sibling: torch.Tensor,
    draft_tokens: torch.Tensor,
    grammar: BaseGrammarObject,
    allocate_token_bitmask: torch.Tensor,
):
    """
    Traverse the tree constructed by the draft model to generate the logits mask.

    The draft tree is given in first-child (retrieve_next_token) /
    next-sibling (retrieve_next_sibling) form, with -1 meaning "none".
    For every draft token the grammar accepts, the grammar state is advanced
    and the token's slice of allocate_token_bitmask is filled; each accepted
    token is rolled back on backtracking, so `grammar` is left in its
    original state on return.
    """
    assert (
        retrieve_next_token.shape == retrieve_next_sibling.shape == draft_tokens.shape
    )

    allocate_token_bitmask.fill_(0)

    def dfs(
        curr: int,  # position of the current node in the flattened tree
        retrieve_next_token: torch.Tensor,  # first-child links
        retrieve_next_sibling: torch.Tensor,  # next-sibling links
        parent_pos: int,  # position of the parent node (-1 for the root)
    ):
        # Determine whether the grammar accepts the token at `curr` by
        # testing its bit in the parent's vocab bitmask.
        if curr == 0:
            # the first token generated by the target model, and thus it is always
            # accepted from the previous iteration
            accepted = True
        else:
            parent_bitmask = allocate_token_bitmask[parent_pos]
            curr_token_id = draft_tokens[curr]
            # 32 boolean bitmask values are packed into 32-bit integers
            accepted = (
                parent_bitmask[curr_token_id // 32] & (1 << (curr_token_id % 32))
            ) != 0

        if accepted:
            if curr != 0:
                # Accept the current token
                grammar.accept_token(draft_tokens[curr])
            if not grammar.is_terminated():
                # Generate the bitmask for the current token
                grammar.fill_vocab_mask(allocate_token_bitmask, curr)
                if retrieve_next_token[curr] != -1:
                    # Visit the child node
                    dfs(
                        retrieve_next_token[curr],
                        retrieve_next_token,
                        retrieve_next_sibling,
                        curr,
                    )

            if curr != 0:
                # Rollback the current token
                grammar.rollback(1)

        # Siblings are visited regardless of whether `curr` was accepted,
        # sharing the same parent position.
        if retrieve_next_sibling[curr] != -1:
            # Visit the sibling node
            dfs(
                retrieve_next_sibling[curr],
                retrieve_next_token,
                retrieve_next_sibling,
                parent_pos,
            )

    dfs(0, retrieve_next_token, retrieve_next_sibling, -1)
556
+
557
+
558
def generate_token_bitmask(
    reqs: List[Req],
    verify_input: EagleVerifyInput,
    retrieve_next_token_cpu: torch.Tensor,
    retrieve_next_sibling_cpu: torch.Tensor,
    draft_tokens_cpu: torch.Tensor,
    vocab_size: int,
):
    """
    Generate the logit mask for structured output.
    Draft model's token can be either valid or invalid with respect to the grammar.
    We need to perform DFS to
    1. figure out which tokens are accepted by the grammar.
    2. if so, what is the corresponding logit mask.
    """
    num_draft_tokens = draft_tokens_cpu.shape[-1]
    assert len(reqs) == retrieve_next_token_cpu.shape[0]

    allocate_token_bitmask = None
    grammar = None
    for idx, req in enumerate(reqs):
        if req.grammar is None:
            continue

        if allocate_token_bitmask is None:
            # Allocate lazily so batches without any grammar pay nothing.
            allocate_token_bitmask = req.grammar.allocate_vocab_mask(
                vocab_size=vocab_size,
                batch_size=draft_tokens_cpu.numel(),
                device="cpu",
            )
        grammar = req.grammar

        started_at = time.perf_counter()
        traverse_tree(
            retrieve_next_token_cpu[idx],
            retrieve_next_sibling_cpu[idx],
            draft_tokens_cpu[idx],
            req.grammar,
            allocate_token_bitmask[
                idx * num_draft_tokens : (idx + 1) * num_draft_tokens
            ],
        )
        tree_traverse_time = time.perf_counter() - started_at
        if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD:
            logger.warning(
                f"Bit mask generation took {tree_traverse_time} seconds with "
                f"grammar: {req.grammar}"
            )

    # Expose the last grammar seen so the verify step can apply the mask.
    verify_input.grammar = grammar
    return allocate_token_bitmask