sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282) hide show
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -14,7 +14,6 @@ from sglang.srt.distributed import (
14
14
  )
15
15
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
16
16
  from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
17
- from sglang.srt.managers.mm_utils import embed_mm_inputs
18
17
  from sglang.srt.managers.schedule_batch import (
19
18
  ScheduleBatch,
20
19
  get_last_loc,
@@ -24,6 +23,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker
24
23
  from sglang.srt.model_executor.forward_batch_info import (
25
24
  CaptureHiddenMode,
26
25
  ForwardBatch,
26
+ ForwardBatchOutput,
27
27
  ForwardMode,
28
28
  )
29
29
  from sglang.srt.server_args import ServerArgs
@@ -34,16 +34,18 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
34
34
  from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import (
35
35
  EAGLEDraftExtendCudaGraphRunner,
36
36
  )
37
- from sglang.srt.speculative.eagle_utils import (
37
+ from sglang.srt.speculative.eagle_info import (
38
38
  EagleDraftInput,
39
39
  EagleVerifyInput,
40
40
  EagleVerifyOutput,
41
+ )
42
+ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
43
+ from sglang.srt.speculative.spec_utils import (
41
44
  assign_draft_cache_locs,
42
45
  fast_topk,
43
46
  generate_token_bitmask,
44
47
  select_top_k_tokens,
45
48
  )
46
- from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
47
49
  from sglang.srt.utils import (
48
50
  empty_context,
49
51
  get_available_gpu_memory,
@@ -242,6 +244,7 @@ class EAGLEWorker(TpModelWorker):
242
244
  if not is_blackwell()
243
245
  else self._create_triton_prefill_backend
244
246
  ),
247
+ "flashmla": self._create_flashmla_prefill_backend,
245
248
  "trtllm_mha": self._create_trtllm_mha_prefill_backend,
246
249
  "trtllm_mla": self._create_trtllm_mla_prefill_backend,
247
250
  }
@@ -381,6 +384,12 @@ class EAGLEWorker(TpModelWorker):
381
384
 
382
385
  return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)
383
386
 
387
+ def _create_flashmla_prefill_backend(self):
388
+ logger.warning(
389
+ "flashmla prefill backend is not yet supported for draft extend."
390
+ )
391
+ return None
392
+
384
393
  def init_cuda_graphs(self):
385
394
  """Capture cuda graphs."""
386
395
  self.cuda_graph_runner = None
@@ -420,9 +429,7 @@ class EAGLEWorker(TpModelWorker):
420
429
  def draft_model_runner(self):
421
430
  return self.model_runner
422
431
 
423
- def forward_batch_speculative_generation(
424
- self, batch: ScheduleBatch
425
- ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]:
432
+ def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput:
426
433
  """Run speculative decoding forward.
427
434
 
428
435
  NOTE: Many states of batch is modified as you go through. It is not guaranteed that
@@ -435,14 +442,19 @@ class EAGLEWorker(TpModelWorker):
435
442
  the batch id (used for overlap schedule), and number of accepted tokens.
436
443
  """
437
444
  if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
438
- logits_output, next_token_ids, bid, seq_lens_cpu = (
439
- self.forward_target_extend(batch)
445
+ logits_output, next_token_ids, seq_lens_cpu = self.forward_target_extend(
446
+ batch
440
447
  )
441
448
  with self.draft_tp_context(self.draft_model_runner.tp_group):
442
449
  self.forward_draft_extend(
443
450
  batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
444
451
  )
445
- return logits_output, next_token_ids, bid, 0, False
452
+ return ForwardBatchOutput(
453
+ logits_output=logits_output,
454
+ next_token_ids=next_token_ids,
455
+ num_accepted_tokens=0,
456
+ can_run_cuda_graph=False,
457
+ )
446
458
  else:
447
459
  with self.draft_tp_context(self.draft_model_runner.tp_group):
448
460
  spec_info = self.draft(batch)
@@ -460,12 +472,11 @@ class EAGLEWorker(TpModelWorker):
460
472
  # decode is not finished
461
473
  self.forward_draft_extend_after_decode(batch)
462
474
 
463
- return (
464
- logits_output,
465
- verify_output.verified_id,
466
- model_worker_batch.bid,
467
- sum(verify_output.accept_length_per_req_cpu),
468
- can_run_cuda_graph,
475
+ return ForwardBatchOutput(
476
+ logits_output=logits_output,
477
+ next_token_ids=verify_output.verified_id,
478
+ num_accepted_tokens=sum(verify_output.accept_length_per_req_cpu),
479
+ can_run_cuda_graph=can_run_cuda_graph,
469
480
  )
470
481
 
471
482
  def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
@@ -497,19 +508,21 @@ class EAGLEWorker(TpModelWorker):
497
508
  Returns:
498
509
  logits_output: The output of logits. It will contain the full hidden states.
499
510
  next_token_ids: Next token ids generated.
500
- bid: The model batch ID. Used for overlap schedule.
501
511
  """
502
512
  # Forward with the target model and get hidden states.
503
513
  # We need the full hidden states to prefill the KV cache of the draft model.
504
514
  model_worker_batch = batch.get_model_worker_batch()
505
515
  model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
506
- logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
516
+ forward_batch_output = self.target_worker.forward_batch_generation(
507
517
  model_worker_batch
508
518
  )
519
+ logits_output, next_token_ids = (
520
+ forward_batch_output.logits_output,
521
+ forward_batch_output.next_token_ids,
522
+ )
509
523
  return (
510
524
  logits_output,
511
525
  next_token_ids,
512
- model_worker_batch.bid,
513
526
  model_worker_batch.seq_lens_cpu,
514
527
  )
515
528
 
@@ -541,6 +554,8 @@ class EAGLEWorker(TpModelWorker):
541
554
  batch.seq_lens,
542
555
  self.speculative_num_steps,
543
556
  )
557
+ prefix_lens_cpu = batch.seq_lens_cpu
558
+ seq_lens_cpu = batch.seq_lens_cpu + self.speculative_num_steps
544
559
  extend_num_tokens = num_seqs * self.speculative_num_steps
545
560
  else:
546
561
  # In this case, the last partial page needs to be duplicated.
@@ -576,14 +591,23 @@ class EAGLEWorker(TpModelWorker):
576
591
  self.topk,
577
592
  self.page_size,
578
593
  )
579
-
580
- # TODO(lmzheng): remove this device sync
581
- extend_num_tokens = torch.sum(self.extend_lens).item()
594
+ prefix_lens_cpu = batch.seq_lens_cpu
595
+ last_page_lens = prefix_lens_cpu % self.page_size
596
+ num_new_pages_per_topk = (
597
+ last_page_lens + self.speculative_num_steps + self.page_size - 1
598
+ ) // self.page_size
599
+ seq_lens_cpu = (
600
+ prefix_lens_cpu // self.page_size * self.page_size
601
+ + num_new_pages_per_topk * (self.page_size * self.topk)
602
+ )
603
+ extend_num_tokens = torch.sum((seq_lens_cpu - prefix_lens_cpu)).item()
582
604
 
583
605
  out_cache_loc, token_to_kv_pool_state_backup = (
584
606
  batch.alloc_paged_token_slots_extend(
585
607
  prefix_lens,
608
+ prefix_lens_cpu,
586
609
  seq_lens,
610
+ seq_lens_cpu,
587
611
  last_loc,
588
612
  extend_num_tokens,
589
613
  backup_state=True,
@@ -771,6 +795,10 @@ class EAGLEWorker(TpModelWorker):
771
795
 
772
796
  return score_list, token_list, parents_list
773
797
 
798
+ def clear_cache_pool(self):
799
+ self.model_runner.req_to_token_pool.clear()
800
+ self.model_runner.token_to_kv_pool_allocator.clear()
801
+
774
802
  def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
775
803
  spec_info.prepare_for_verify(batch, self.page_size)
776
804
  batch.return_hidden_states = False
@@ -794,10 +822,12 @@ class EAGLEWorker(TpModelWorker):
794
822
  ).cpu()
795
823
 
796
824
  # Forward
797
- logits_output, _, can_run_cuda_graph = (
798
- self.target_worker.forward_batch_generation(
799
- model_worker_batch, skip_sample=True
800
- )
825
+ forward_batch_output = self.target_worker.forward_batch_generation(
826
+ model_worker_batch, is_verify=True
827
+ )
828
+ logits_output, can_run_cuda_graph = (
829
+ forward_batch_output.logits_output,
830
+ forward_batch_output.can_run_cuda_graph,
801
831
  )
802
832
 
803
833
  vocab_mask = None
@@ -997,6 +1027,7 @@ class EAGLEWorker(TpModelWorker):
997
1027
  assert isinstance(batch.spec_info, EagleDraftInput)
998
1028
  # Backup fields that will be modified in-place
999
1029
  seq_lens_backup = batch.seq_lens.clone()
1030
+ seq_lens_cpu_backup = batch.seq_lens_cpu.clone()
1000
1031
  req_pool_indices_backup = batch.req_pool_indices
1001
1032
  accept_length_backup = batch.spec_info.accept_length
1002
1033
  return_logprob_backup = batch.return_logprob
@@ -1075,6 +1106,7 @@ class EAGLEWorker(TpModelWorker):
1075
1106
  ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
1076
1107
  )
1077
1108
  batch.seq_lens = seq_lens_backup
1109
+ batch.seq_lens_cpu = seq_lens_cpu_backup
1078
1110
  batch.req_pool_indices = req_pool_indices_backup
1079
1111
  batch.spec_info.accept_length = accept_length_backup
1080
1112
  batch.return_logprob = return_logprob_backup
@@ -0,0 +1,428 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import logging
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import triton
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ from dataclasses import dataclass
13
+
14
+ import torch.nn.functional as F
15
+
16
+ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
17
+ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
18
+ from sglang.srt.layers.sampler import apply_custom_logit_processor
19
+ from sglang.srt.managers.schedule_batch import (
20
+ ScheduleBatch,
21
+ get_last_loc,
22
+ global_server_args_dict,
23
+ )
24
+ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
25
+ from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
26
+ from sglang.srt.speculative.spec_utils import (
27
+ TREE_SPEC_KERNEL_AVAILABLE,
28
+ assign_req_to_token_pool,
29
+ get_src_tgt_cache_loc,
30
+ get_target_cache_loc,
31
+ )
32
+ from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
33
+
34
+ if is_cuda():
35
+ from sgl_kernel import (
36
+ top_k_renorm_prob,
37
+ top_p_renorm_prob,
38
+ tree_speculative_sampling_target_only,
39
+ verify_tree_greedy,
40
+ )
41
+ elif is_hip():
42
+ from sgl_kernel import verify_tree_greedy
43
+
44
+
45
+ @dataclass
46
+ class NgramVerifyInput(SpecInput):
47
+ def __init__(
48
+ self,
49
+ draft_token: torch.Tensor,
50
+ tree_mask: torch.Tensor,
51
+ positions: torch.Tensor,
52
+ retrive_index: torch.Tensor,
53
+ retrive_next_token: torch.Tensor,
54
+ retrive_next_sibling: torch.Tensor,
55
+ draft_token_num: int,
56
+ ):
57
+ super().__init__(SpecInputType.NGRAM_VERIFY)
58
+ self.draft_token = draft_token
59
+ self.custom_mask = tree_mask
60
+ self.positions = positions
61
+ self.retrive_index = retrive_index
62
+ self.retrive_next_token = retrive_next_token
63
+ self.retrive_next_sibling = retrive_next_sibling
64
+ self.draft_token_num = draft_token_num
65
+ self.device = self.custom_mask.device
66
+
67
+ def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
68
+ return self.draft_token_num, self.draft_token_num
69
+
70
+ def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
71
+ if batch.forward_mode.is_idle():
72
+ return
73
+
74
+ batch.input_ids = self.draft_token
75
+
76
+ if page_size == 1:
77
+ batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
78
+ end_offset = batch.seq_lens + self.draft_token_num
79
+ else:
80
+ # TODO(lsyin): add prefix lens cpu here to support page size > 1
81
+ prefix_lens = batch.seq_lens
82
+ prefix_lens_cpu = batch.seq_lens_cpu
83
+ end_offset = prefix_lens + self.draft_token_num
84
+ end_offset_cpu = prefix_lens_cpu + self.draft_token_num
85
+ last_loc = get_last_loc(
86
+ batch.req_to_token_pool.req_to_token,
87
+ batch.req_pool_indices,
88
+ prefix_lens,
89
+ )
90
+ batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
91
+ prefix_lens,
92
+ prefix_lens_cpu,
93
+ end_offset,
94
+ end_offset_cpu,
95
+ last_loc,
96
+ len(batch.input_ids),
97
+ )
98
+ self.last_loc = last_loc
99
+
100
+ bs = batch.batch_size()
101
+ assign_req_to_token_pool[(bs,)](
102
+ batch.req_pool_indices,
103
+ batch.req_to_token_pool.req_to_token,
104
+ batch.seq_lens,
105
+ end_offset,
106
+ batch.out_cache_loc,
107
+ batch.req_to_token_pool.req_to_token.shape[1],
108
+ triton.next_power_of_2(bs),
109
+ )
110
+
111
+ def generate_attn_arg_prefill(
112
+ self,
113
+ req_pool_indices: torch.Tensor,
114
+ paged_kernel_lens: torch.Tensor,
115
+ paged_kernel_lens_sum: int,
116
+ req_to_token: torch.Tensor,
117
+ ):
118
+ bs = len(req_pool_indices)
119
+
120
+ cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
121
+
122
+ paged_kernel_lens = paged_kernel_lens + self.draft_token_num
123
+ cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
124
+
125
+ self.qo_indptr = (
126
+ torch.arange(0, bs + 1, dtype=torch.int32, device=self.device)
127
+ * self.draft_token_num
128
+ )
129
+
130
+ kv_indices = torch.empty(
131
+ cum_kv_seq_len[-1], dtype=torch.int32, device=self.device
132
+ )
133
+
134
+ create_flashinfer_kv_indices_triton[(bs,)](
135
+ req_to_token,
136
+ req_pool_indices,
137
+ paged_kernel_lens,
138
+ cum_kv_seq_len,
139
+ None,
140
+ kv_indices,
141
+ req_to_token.size(1),
142
+ )
143
+ return kv_indices, cum_kv_seq_len, self.qo_indptr, self.custom_mask
144
+
145
+ def _fill_requests(
146
+ self,
147
+ batch: ScheduleBatch,
148
+ logits_output: torch.Tensor,
149
+ ):
150
+ accept_index_cpu = self.accept_index.tolist()
151
+ predict_cpu = self.predict.tolist()
152
+ has_finished = False
153
+
154
+ # Iterate every accepted token and check if req has finished after append the token
155
+ # should be checked BEFORE free kv cache slots
156
+ for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
157
+ for j, idx in enumerate(accept_index_row):
158
+ if idx == -1:
159
+ break
160
+ id = predict_cpu[idx]
161
+ req.output_ids.append(id)
162
+ req.check_finished()
163
+ if req.finished():
164
+ has_finished = True
165
+ # set all tokens after finished token to -1 and break
166
+ self.accept_index[i, j + 1 :] = -1
167
+ break
168
+ else:
169
+ if req.grammar is not None:
170
+ try:
171
+ req.grammar.accept_token(id)
172
+ except ValueError as e:
173
+ logger.info(
174
+ f"{i=}, {req=}\n"
175
+ f"{self.accept_index=}\n"
176
+ f"{self.predict=}\n"
177
+ )
178
+ raise e
179
+ req.spec_verify_ct += 1
180
+ if has_finished:
181
+ self.accept_length = (self.accept_index != -1).sum(dim=1) - 1
182
+ self.accept_index = self.accept_index[self.accept_index != -1]
183
+
184
+ logits_output.next_token_logits = logits_output.next_token_logits[
185
+ self.accept_index
186
+ ]
187
+ if logits_output.hidden_states:
188
+ logits_output.hidden_states = logits_output.hidden_states[self.accept_index]
189
+ self.verified_id = self.predict[self.accept_index]
190
+
191
def _free_cache(self, batch: ScheduleBatch, page_size: int):
    """Release the KV-cache slots held by draft tokens that were rejected.

    Must run after finished-request checks (see the BEFORE-free note in the
    accept loop above) and before ``batch.seq_lens`` is advanced.

    Args:
        batch: The running schedule batch; ``out_cache_loc`` is rewritten in
            place to keep only the accepted tokens' slots.
        page_size: KV-cache page size. For ``page_size == 1`` slots are freed
            individually; otherwise accepted tokens are compacted to the front
            of each row and only the page-aligned tail is freed.
    """
    bs = batch.batch_size()
    # Free the KV cache for unaccepted tokens
    if page_size == 1:
        # TODO: boolean array index leads to a device sync. Remove it.
        evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
        evict_mask[self.accept_index] = False
        batch.token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
        batch.out_cache_loc = batch.out_cache_loc[self.accept_index]
    else:
        # Shift the accepted tokens to the beginning.
        # Only evict the last part
        src_cache_loc, tgt_cache_loc, to_free_num_slots = get_src_tgt_cache_loc(
            batch.seq_lens,
            batch.out_cache_loc,
            self.accept_index,
            self.accept_length,
            self.draft_token_num,
            page_size,
        )
        # Flat buffer for every slot to free; .item() syncs once to size it.
        to_free_slots = torch.empty(
            (to_free_num_slots.sum().item(),),
            dtype=torch.int64,
            device=to_free_num_slots.device,
        )

        # out_cache_loc: [0 1 2, 3 4 5, 6 7 8]
        # accept_index: [0 -1 2, 3 4 -1, 6 -1 -1]
        # tgt_cache_loc: [0 1 , 3 4 , 6 ]
        # to_free_slots: [ 2, 5, 7 8]
        # to_free_slots also needs to be page-aligned without the first partial page
        #
        # split each row of out_cache_loc into two parts.
        # 1. the first part goes to tgt_cache_loc. length = accept_length[i] + 1
        # 2. the second part goes to to_free_slots.
        get_target_cache_loc[(bs,)](
            tgt_cache_loc,
            to_free_slots,
            self.accept_length,
            to_free_num_slots,
            batch.out_cache_loc,
            self.draft_token_num,
            next_power_of_2(self.draft_token_num),
            next_power_of_2(bs),
        )

        # Free the kv cache
        batch.token_to_kv_pool_allocator.free(to_free_slots)

        # Copy the kv cache
        batch.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
            tgt_cache_loc, src_cache_loc
        )
        batch.out_cache_loc = tgt_cache_loc

    # Point req_to_token at the surviving slots for the newly accepted range
    # [seq_lens, seq_lens + accept_length + 1).
    assign_req_to_token_pool[(bs,)](
        batch.req_pool_indices,
        batch.req_to_token_pool.req_to_token,
        batch.seq_lens,
        batch.seq_lens + self.accept_length + 1,
        batch.out_cache_loc,
        batch.req_to_token_pool.req_to_token.shape[1],
        triton.next_power_of_2(bs),
    )
255
+
256
def _greedy_verify(
    self,
    batch: ScheduleBatch,
    logits_output: LogitsProcessorOutput,
):
    """Verify drafted tokens against the target model's argmax choices.

    Writes the results into ``self.predict``, ``self.accept_index`` and
    ``self.accept_length`` in place via the ``verify_tree_greedy`` kernel.
    """
    batch_size = batch.batch_size()
    target_logits = logits_output.next_token_logits

    # The token the target model would pick greedily at every draft position.
    greedy_targets = torch.argmax(target_logits, dim=-1).reshape(
        batch_size, self.draft_token_num
    )
    draft_candidates = self.draft_token.reshape(batch_size, self.draft_token_num)

    # Output buffers. NOTE(review): predict carries one extra entry beyond
    # the flattened draft tokens — presumably the bonus token slot; the
    # kernel fills it. accept_index defaults to -1 (= not accepted).
    out_shape = list(target_logits.shape[:-1])
    out_shape[-1] = out_shape[-1] + 1
    self.predict = torch.empty(out_shape, dtype=torch.int32, device=self.device)
    self.accept_index = torch.full(
        (batch_size, self.draft_token_num),
        -1,
        dtype=torch.int32,
        device=self.device,
    )
    self.accept_length = torch.empty(
        (batch_size,), dtype=torch.int32, device=self.device
    )

    verify_tree_greedy(
        predicts=self.predict,  # mutable
        accept_index=self.accept_index,  # mutable
        accept_token_num=self.accept_length,  # mutable
        candidates=draft_candidates,
        retrive_index=self.retrive_index,
        retrive_next_token=self.retrive_next_token,
        retrive_next_sibling=self.retrive_next_sibling,
        target_predict=greedy_targets,
    )
284
+
285
def _sampling_verify(
    self,
    batch: ScheduleBatch,
    logits_output: LogitsProcessorOutput,
    sampling_info: SamplingBatchInfo,
):
    """Verify drafted tokens with rejection sampling against target probs.

    Currently not invoked by ``verify()`` (see the NOTE there): it benchmarks
    slower than greedy verification, dominated by the top-k/top-p
    renormalization below. Fills ``self.predict``, ``self.accept_index`` and
    ``self.accept_length`` in place via
    ``tree_speculative_sampling_target_only``.
    """
    bs = batch.batch_size()
    candidates = self.draft_token.reshape(bs, self.draft_token_num)
    predict_shape = list(logits_output.next_token_logits.shape)[:-1]
    # One extra output slot per row beyond the draft tokens (bonus token).
    predict_shape[-1] += 1
    self.predict = torch.empty(predict_shape, dtype=torch.int32, device=self.device)
    self.accept_index = torch.full(
        (bs, self.draft_token_num), -1, dtype=torch.int32, device=self.device
    )
    self.accept_length = torch.empty((bs,), dtype=torch.int32, device=self.device)
    # apply temperature and get target probs
    expanded_temperature = torch.repeat_interleave(
        sampling_info.temperatures, self.draft_token_num, dim=0
    )  # (bs * draft_token_num, 1)

    target_probs = F.softmax(
        logits_output.next_token_logits / expanded_temperature, dim=-1
    )  # (bs * draft_token_num, vocab_size)

    # NOTE: The test shows that top_p_renorm_prob and top_k_renorm_prob are the key factors
    # contributing to the poor performance of _sampling_verify.
    target_probs = top_k_renorm_prob(
        target_probs,
        torch.repeat_interleave(sampling_info.top_ks, self.draft_token_num, dim=0),
    )  # (bs * draft_token_num, vocab_size)

    if sampling_info.need_top_p_sampling:
        # logger.info("Using top-p sampling in speculative decoding verification.")
        target_probs = top_p_renorm_prob(
            target_probs,
            torch.repeat_interleave(
                sampling_info.top_ps, self.draft_token_num, dim=0
            ),
        )

    target_probs = target_probs.reshape(bs, self.draft_token_num, -1)
    # NOTE(review): draft_probs is all-zero — the ngram drafter has no real
    # draft distribution; the kernel treats it as target-only sampling.
    draft_probs = torch.zeros(
        target_probs.shape, dtype=torch.float32, device=self.device
    )

    # coins for rejection sampling
    coins = torch.rand_like(candidates, dtype=torch.float32, device=self.device)
    # coins for final sampling
    coins_for_final_sampling = torch.rand(
        (bs,), dtype=torch.float32, device=self.device
    )
    tree_speculative_sampling_target_only(
        predicts=self.predict,  # mutable
        accept_index=self.accept_index,  # mutable
        accept_token_num=self.accept_length,  # mutable
        candidates=candidates.to(torch.int64),
        retrive_index=self.retrive_index.to(torch.int64),
        retrive_next_token=self.retrive_next_token.to(torch.int64),
        retrive_next_sibling=self.retrive_next_sibling.to(torch.int64),
        uniform_samples=coins,
        uniform_samples_for_final_sampling=coins_for_final_sampling,
        target_probs=target_probs,
        draft_probs=draft_probs,
        threshold_single=global_server_args_dict[
            "speculative_accept_threshold_single"
        ],
        threshold_acc=global_server_args_dict["speculative_accept_threshold_acc"],
        deterministic=True,
    )
354
+
355
def verify(
    self,
    batch: ScheduleBatch,
    logits_output: LogitsProcessorOutput,
    page_size: int,
    vocab_mask: Optional[torch.Tensor] = None,  # For grammar
) -> torch.Tensor:
    """Verify drafted ngram tokens against the target model's logits.

    Adjusts the target logits (custom logit processors, relaxed penalties,
    grammar vocab mask), runs verification — currently always the greedy
    path, see the NOTE below — then updates per-request state, frees
    rejected KV-cache slots, and advances sequence lengths.

    Returns:
        ``(logits_output, verified_id, num_accepted_tokens)``.
        NOTE(review): the annotated return type is ``torch.Tensor`` but a
        3-tuple is actually returned — annotation looks stale.
    """
    bs = self.retrive_index.shape[0]
    sampling_info = batch.sampling_info

    # If some requests were filtered out of sampling_info elsewhere, rebuild
    # a copy restricted to the requests this verify input still covers.
    if bs != len(sampling_info):
        sampling_info = copy.deepcopy(sampling_info)
        # NOTE: retrive_index are the indices of the requests that are kept.
        sampling_info.filter_batch(self.retrive_index.tolist(), self.retrive_index)

    # Apply the custom logit processors if registered in the sampling info.
    if sampling_info.has_custom_logit_processor:
        apply_custom_logit_processor(
            logits_output.next_token_logits,
            sampling_info,
            num_tokens_in_batch=self.draft_token_num,
        )

    # Apply penalty
    if sampling_info.penalizer_orchestrator.is_required:
        # This is a relaxed version of penalties for speculative decoding.
        linear_penalty = torch.zeros(
            (bs, logits_output.next_token_logits.shape[1]),
            dtype=torch.float32,
            device=self.device,
        )
        sampling_info.apply_logits_bias(linear_penalty)
        # Same per-request penalty row is applied to all its draft positions.
        logits_output.next_token_logits.add_(
            torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
        )

    # Apply grammar mask
    if vocab_mask is not None:
        assert self.grammar is not None
        self.grammar.apply_vocab_mask(
            logits=logits_output.next_token_logits, vocab_mask=vocab_mask
        )

    # Sample tokens. Force greedy sampling on AMD
    is_all_greedy = sampling_info.is_all_greedy
    if (not is_all_greedy) and (not TREE_SPEC_KERNEL_AVAILABLE):
        logger.warning(
            "Tree speculative sampling kernel unavailable (likely AMD/HIP build). "
            "Falling back to greedy verification."
        )

    if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE:
        self._greedy_verify(batch, logits_output)
    else:
        # NOTE: Compared with greedy_verify, the performance of _sampling_verify is relatively poor.
        # Both branches intentionally use greedy verification for now.
        self._greedy_verify(batch, logits_output)
        # self._sampling_verify(batch, logits_output, sampling_info)

    self._fill_requests(batch, logits_output)
    self._free_cache(batch, page_size)

    # One host sync to get per-request accepted counts.
    accept_length_cpu = self.accept_length.cpu()
    num_accepted_tokens = accept_length_cpu.sum().item()

    # +1 accounts for the bonus token appended after the accepted run.
    batch.seq_lens.add_(self.accept_length + 1)
    batch.seq_lens_cpu.add_(accept_length_cpu + 1)

    return logits_output, self.verified_id, num_accepted_tokens
423
+
424
def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
    """Intentionally a no-op: this verify input keeps no state that needs
    filtering when requests are removed from the batch."""
426
+
427
def merge_batch(self, spec_info: NgramVerifyInput):
    """Intentionally a no-op: merging another batch requires no updates to
    this verify input's state."""