sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -1,234 +1,52 @@
1
- from __future__ import annotations
2
-
3
- import copy
4
1
  import logging
5
- import os
6
- import time
2
+ from copy import copy
7
3
  from dataclasses import dataclass
8
- from typing import List, Optional
4
+ from typing import List, Optional, Tuple
9
5
 
10
6
  import torch
11
7
  import torch.nn.functional as F
12
- import triton
13
- import triton.language as tl
14
8
 
15
9
  from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
16
10
  from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
17
11
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
18
12
  from sglang.srt.layers.sampler import apply_custom_logit_processor
19
13
  from sglang.srt.managers.schedule_batch import (
20
- Req,
21
14
  ScheduleBatch,
22
15
  get_last_loc,
23
16
  global_server_args_dict,
24
17
  )
25
18
  from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
26
- from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
19
+ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
20
+ from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
21
+ from sglang.srt.speculative.spec_utils import (
22
+ SIMULATE_ACC_LEN,
23
+ TREE_SPEC_KERNEL_AVAILABLE,
24
+ _generate_simulated_accept_index,
25
+ align_evict_mask_to_page_size,
26
+ assign_req_to_token_pool,
27
+ create_accept_length_filter,
28
+ create_extend_after_decode_spec_info,
29
+ filter_finished_cache_loc_kernel,
30
+ get_src_tgt_cache_loc,
31
+ get_target_cache_loc,
32
+ )
27
33
  from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
28
34
 
29
35
  if is_cuda():
30
36
  from sgl_kernel import (
31
- fast_topk,
32
37
  top_k_renorm_prob,
33
38
  top_p_renorm_prob,
34
39
  tree_speculative_sampling_target_only,
35
40
  verify_tree_greedy,
36
41
  )
37
42
  elif is_hip():
38
- from sgl_kernel import fast_topk, verify_tree_greedy
39
-
43
+ from sgl_kernel import verify_tree_greedy
40
44
 
41
45
  logger = logging.getLogger(__name__)
42
46
 
43
47
 
44
- # Simulate acceptance length for benchmarking purposes
45
- SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
46
- SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial")
47
-
48
- TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly
49
-
50
- TREE_SPEC_KERNEL_AVAILABLE = "tree_speculative_sampling_target_only" in globals()
51
-
52
-
53
48
  @dataclass
54
- class EagleDraftInput:
55
- # The inputs for decode
56
- # shape: (b, topk)
57
- topk_p: torch.Tensor = None
58
- topk_index: torch.Tensor = None
59
- # shape: (b, hidden_size)
60
- hidden_states: torch.Tensor = None
61
- capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL
62
-
63
- # Inputs for extend
64
- # shape: (b,)
65
- verified_id: torch.Tensor = None
66
- accept_length: torch.Tensor = None
67
- accept_length_cpu: List[int] = None
68
-
69
- # Inputs for the attention backends
70
- # shape: (b + 1,)
71
- kv_indptr: torch.Tensor = None
72
- kv_indices: torch.Tensor = None
73
-
74
- # Shape info for padding
75
- num_tokens_per_batch: int = -1
76
- num_tokens_for_logprob_per_batch: int = -1
77
-
78
- # Inputs for draft extend
79
- # shape: (b,)
80
- seq_lens_for_draft_extend: torch.Tensor = None
81
- req_pool_indices_for_draft_extend: torch.Tensor = None
82
-
83
- def prepare_for_extend(self, batch: ScheduleBatch):
84
-
85
- if batch.forward_mode.is_idle():
86
- return
87
-
88
- # Prefill only generate 1 token.
89
- assert len(self.verified_id) == len(batch.seq_lens)
90
-
91
- pt = 0
92
- for i, extend_len in enumerate(batch.extend_lens):
93
- input_ids = batch.input_ids[pt : pt + extend_len]
94
- batch.input_ids[pt : pt + extend_len] = torch.cat(
95
- (input_ids[1:], self.verified_id[i].reshape(1))
96
- )
97
- pt += extend_len
98
-
99
- @classmethod
100
- def create_idle_input(
101
- cls,
102
- device: torch.device,
103
- hidden_size: int,
104
- dtype: torch.dtype,
105
- topk: int,
106
- capture_hidden_mode: CaptureHiddenMode,
107
- ):
108
- return cls(
109
- verified_id=torch.empty((0,), device=device, dtype=torch.int32),
110
- hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
111
- topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
112
- topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
113
- capture_hidden_mode=capture_hidden_mode,
114
- accept_length=torch.empty((0,), device=device, dtype=torch.int32),
115
- accept_length_cpu=[],
116
- )
117
-
118
- def prepare_extend_after_decode(
119
- self,
120
- batch: ScheduleBatch,
121
- speculative_num_steps: int,
122
- ):
123
-
124
- if batch.forward_mode.is_idle():
125
- return
126
-
127
- batch.input_ids = self.verified_id
128
- batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
129
- batch.extend_num_tokens = sum(batch.extend_lens)
130
- batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
131
- batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
132
- batch.return_logprob = False
133
- batch.return_hidden_states = False
134
-
135
- self.capture_hidden_mode = CaptureHiddenMode.LAST
136
- self.accept_length.add_(1)
137
- self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
138
- self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
139
-
140
- create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
141
- batch.input_ids,
142
- batch.seq_lens,
143
- self.accept_length,
144
- self.positions,
145
- self.verified_id,
146
- next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
147
- )
148
-
149
- def generate_attn_arg_prefill(
150
- self,
151
- req_pool_indices: torch.Tensor,
152
- paged_kernel_lens: torch.Tensor,
153
- paged_kernel_lens_sum: int,
154
- req_to_token: torch.Tensor,
155
- ):
156
- bs = self.accept_length.numel()
157
- qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
158
- qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
159
- cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
160
- cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
161
-
162
- if paged_kernel_lens_sum is None:
163
- paged_kernel_lens_sum = cum_kv_seq_len[-1]
164
-
165
- kv_indices = torch.empty(
166
- paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
167
- )
168
-
169
- create_flashinfer_kv_indices_triton[(bs,)](
170
- req_to_token,
171
- req_pool_indices,
172
- paged_kernel_lens,
173
- cum_kv_seq_len,
174
- None,
175
- kv_indices,
176
- req_to_token.size(1),
177
- )
178
- return kv_indices, cum_kv_seq_len, qo_indptr, None
179
-
180
- def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
181
- if has_been_filtered:
182
- # in eagle_utils.py:verify, we have already filtered the batch by `unfinished_index`
183
- # therefore, we don't need to filter the batch again in scheduler
184
- if len(new_indices) != len(self.topk_p):
185
- logger.warning(
186
- f"length of new_indices: {len(new_indices)} != length of topk_p: {len(self.topk_p)}, this should not happen"
187
- )
188
- self.topk_p = self.topk_p[: len(new_indices)]
189
- self.topk_index = self.topk_index[: len(new_indices)]
190
- self.hidden_states = self.hidden_states[: len(new_indices)]
191
- self.verified_id = self.verified_id[: len(new_indices)]
192
- else:
193
- # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
194
- self.topk_p = self.topk_p[new_indices]
195
- self.topk_index = self.topk_index[new_indices]
196
- self.hidden_states = self.hidden_states[new_indices]
197
- self.verified_id = self.verified_id[new_indices]
198
-
199
- def merge_batch(self, spec_info: EagleDraftInput):
200
- if self.hidden_states is None:
201
- self.hidden_states = spec_info.hidden_states
202
- self.verified_id = spec_info.verified_id
203
- self.topk_p = spec_info.topk_p
204
- self.topk_index = spec_info.topk_index
205
- return
206
- if spec_info.hidden_states is None:
207
- return
208
- self.hidden_states = torch.cat(
209
- [self.hidden_states, spec_info.hidden_states], axis=0
210
- )
211
- self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
212
- self.topk_p = torch.cat([self.topk_p, spec_info.topk_p])
213
- self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
214
-
215
-
216
- @dataclass
217
- class EagleVerifyOutput:
218
- # Draft input batch
219
- draft_input: EagleDraftInput
220
- # Logit outputs from target worker
221
- logits_output: LogitsProcessorOutput
222
- # Accepted token ids including the bonus token
223
- verified_id: torch.Tensor
224
- # Accepted token length per sequence in a batch in CPU.
225
- accept_length_per_req_cpu: List[int]
226
- # Accepted indices from logits_output.next_token_logits
227
- accepted_indices: torch.Tensor
228
-
229
-
230
- @dataclass
231
- class EagleVerifyInput:
49
+ class EagleVerifyInput(SpecInput):
232
50
  draft_token: torch.Tensor
233
51
  custom_mask: torch.Tensor
234
52
  positions: torch.Tensor
@@ -244,6 +62,12 @@ class EagleVerifyInput:
244
62
  seq_lens_cpu: torch.Tensor
245
63
  grammar: BaseGrammarObject = None
246
64
 
65
+ def __post_init__(self):
66
+ super().__init__(SpecInputType.EAGLE_VERIFY)
67
+
68
+ def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
69
+ return self.draft_token_num, self.draft_token_num
70
+
247
71
  @classmethod
248
72
  def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int):
249
73
  return cls(
@@ -280,14 +104,21 @@ class EagleVerifyInput:
280
104
  end_offset = batch.seq_lens + self.draft_token_num
281
105
  else:
282
106
  prefix_lens = batch.seq_lens
107
+ prefix_lens_cpu = batch.seq_lens_cpu
283
108
  end_offset = prefix_lens + self.draft_token_num
109
+ end_offset_cpu = prefix_lens_cpu + self.draft_token_num
284
110
  last_loc = get_last_loc(
285
111
  batch.req_to_token_pool.req_to_token,
286
112
  batch.req_pool_indices,
287
113
  prefix_lens,
288
114
  )
289
115
  batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
290
- prefix_lens, end_offset, last_loc, len(batch.input_ids)
116
+ prefix_lens,
117
+ prefix_lens_cpu,
118
+ end_offset,
119
+ end_offset_cpu,
120
+ last_loc,
121
+ len(batch.input_ids),
291
122
  )
292
123
  self.last_loc = last_loc
293
124
 
@@ -500,13 +331,12 @@ class EagleVerifyInput:
500
331
  deterministic=True,
501
332
  )
502
333
 
503
- if SIMULATE_ACC_LEN:
334
+ if SIMULATE_ACC_LEN > 0.0:
504
335
  # Do simulation
505
336
  accept_index = _generate_simulated_accept_index(
506
337
  accept_index=accept_index,
507
338
  predict=predict, # mutable
508
339
  accept_length=accept_length, # mutable
509
- simulate_acc_len=SIMULATE_ACC_LEN,
510
340
  bs=bs,
511
341
  spec_steps=self.spec_steps,
512
342
  )
@@ -557,6 +387,10 @@ class EagleVerifyInput:
557
387
  verified_id = predict[accept_index]
558
388
  evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
559
389
  evict_mask[accept_index] = False
390
+ accept_length_cpu = accept_length.cpu()
391
+ # FIXME: this `tolist()` fixes the numerical calculation consistency
392
+ # try to unify the tensor representation and list representation
393
+ accept_length_list = accept_length_cpu.tolist()
560
394
 
561
395
  if page_size == 1:
562
396
  # TODO: boolean array index leads to a device sync. Remove it.
@@ -633,13 +467,15 @@ class EagleVerifyInput:
633
467
  else:
634
468
  batch.out_cache_loc = tgt_cache_loc
635
469
  batch.seq_lens.add_(accept_length + 1)
470
+ batch.seq_lens_cpu.add_(accept_length_cpu + 1)
636
471
 
637
472
  draft_input = EagleDraftInput(
638
473
  hidden_states=batch.spec_info.hidden_states[accept_index],
639
474
  verified_id=verified_id,
640
475
  accept_length=accept_length,
641
- accept_length_cpu=accept_length.tolist(),
476
+ accept_length_cpu=accept_length_list,
642
477
  seq_lens_for_draft_extend=batch.seq_lens,
478
+ seq_lens_for_draft_extend_cpu=batch.seq_lens_cpu,
643
479
  req_pool_indices_for_draft_extend=batch.req_pool_indices,
644
480
  )
645
481
 
@@ -662,15 +498,15 @@ class EagleVerifyInput:
662
498
  next_power_of_2(bs),
663
499
  )
664
500
  batch.seq_lens.add_(accept_length + 1)
501
+ batch.seq_lens_cpu.add_(accept_length_cpu + 1)
665
502
 
666
- accept_length_cpu = accept_length.tolist()
667
503
  if len(unfinished_accept_index) > 0:
668
504
  unfinished_accept_index = torch.cat(unfinished_accept_index)
669
505
  unfinished_index_device = torch.tensor(
670
506
  unfinished_index, dtype=torch.int64, device=predict.device
671
507
  )
672
508
  draft_input_accept_length_cpu = [
673
- accept_length_cpu[i] for i in unfinished_index
509
+ accept_length_list[i] for i in unfinished_index
674
510
  ]
675
511
  if page_size == 1 or self.topk == 1:
676
512
  batch.out_cache_loc = batch.out_cache_loc[unfinished_accept_index]
@@ -685,6 +521,7 @@ class EagleVerifyInput:
685
521
  unfinished_index_device,
686
522
  batch.seq_lens,
687
523
  )
524
+ batch.seq_lens_cpu.add_(accept_length_cpu + 1)
688
525
  filter_finished_cache_loc_kernel[(bs,)](
689
526
  batch.out_cache_loc,
690
527
  tgt_cache_loc,
@@ -702,6 +539,7 @@ class EagleVerifyInput:
702
539
  accept_length_cpu=draft_input_accept_length_cpu,
703
540
  accept_length=accept_length[unfinished_index_device],
704
541
  seq_lens_for_draft_extend=batch.seq_lens[unfinished_index_device],
542
+ seq_lens_for_draft_extend_cpu=batch.seq_lens_cpu[unfinished_index],
705
543
  req_pool_indices_for_draft_extend=batch.req_pool_indices[
706
544
  unfinished_index_device
707
545
  ],
@@ -719,577 +557,191 @@ class EagleVerifyInput:
719
557
  draft_input=draft_input,
720
558
  logits_output=logits_output,
721
559
  verified_id=verified_id,
722
- accept_length_per_req_cpu=accept_length_cpu,
560
+ accept_length_per_req_cpu=accept_length_list,
723
561
  accepted_indices=accept_index,
724
562
  )
725
563
 
726
564
 
727
- @triton.jit
728
- def create_extend_after_decode_spec_info(
729
- verified_id,
730
- seq_lens,
731
- accept_lens,
732
- positions,
733
- new_verified_id,
734
- bs_upper: tl.constexpr,
735
- ):
736
- pid = tl.program_id(axis=0)
737
- offsets = tl.arange(0, bs_upper)
738
- seq_length = tl.load(seq_lens + pid)
739
- accept_length = tl.load(accept_lens + pid)
740
-
741
- accept_len_cumsum = tl.sum(
742
- tl.load(accept_lens + offsets, mask=offsets < pid, other=0)
743
- )
744
- positions_ptr = positions + accept_len_cumsum
745
- mask = offsets < accept_length
746
- tl.store(positions_ptr + offsets, seq_length - accept_length + offsets, mask)
747
-
748
- accept_len_cumsum += accept_length - 1
749
- verified_id_data = tl.load(verified_id + accept_len_cumsum)
750
- tl.store(new_verified_id + pid, verified_id_data)
751
-
752
-
753
- @triton.jit
754
- def assign_req_to_token_pool(
755
- req_pool_indices,
756
- req_to_token,
757
- start_offset,
758
- end_offset,
759
- out_cache_loc,
760
- pool_len: tl.constexpr,
761
- bs_upper: tl.constexpr,
762
- ):
763
- BLOCK_SIZE: tl.constexpr = 32
764
- pid = tl.program_id(axis=0)
765
- kv_start = tl.load(start_offset + pid)
766
- kv_end = tl.load(end_offset + pid)
767
- token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
768
-
769
- length_offset = tl.arange(0, bs_upper)
770
- start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0)
771
- end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0)
772
- out_offset = tl.sum(end - start, axis=0)
773
-
774
- out_cache_ptr = out_cache_loc + out_offset
775
-
776
- save_offset = tl.arange(0, BLOCK_SIZE) + kv_start
777
- load_offset = tl.arange(0, BLOCK_SIZE)
778
-
779
- num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
780
- for _ in range(num_loop):
781
- mask = save_offset < kv_end
782
- data = tl.load(out_cache_ptr + load_offset, mask=mask)
783
- tl.store(token_pool + save_offset, data, mask=mask)
784
- save_offset += BLOCK_SIZE
785
- load_offset += BLOCK_SIZE
786
-
787
-
788
- @triton.jit
789
- def assign_draft_cache_locs(
790
- req_pool_indices,
791
- req_to_token,
792
- seq_lens,
793
- extend_lens,
794
- num_new_pages_per_topk,
795
- out_cache_loc,
796
- pool_len: tl.constexpr,
797
- topk: tl.constexpr,
798
- speculative_num_steps: tl.constexpr,
799
- page_size: tl.constexpr,
800
- bs_upper: tl.constexpr,
801
- iter_upper: tl.constexpr,
802
- ):
803
- BLOCK_SIZE: tl.constexpr = 128
804
- pid = tl.program_id(axis=0)
805
-
806
- if page_size == 1 or topk == 1:
807
- copy_len = topk * speculative_num_steps
808
- out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
809
- else:
810
- bs_offset = tl.arange(0, bs_upper)
811
- copy_len = tl.load(extend_lens + pid)
812
- cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid))
813
- out_cache_ptr = out_cache_loc + cum_copy_len
814
-
815
- # Part 1: Copy from out_cache_loc to req_to_token
816
- kv_start = tl.load(seq_lens + pid)
817
- token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
818
- num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
819
- for i in range(num_loop):
820
- copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
821
- mask = copy_offset < copy_len
822
- data = tl.load(out_cache_ptr + copy_offset, mask=mask)
823
- tl.store(token_pool + kv_start + copy_offset, data, mask=mask)
824
-
825
- if page_size == 1 or topk == 1:
826
- return
827
-
828
- # Part 2: Copy the indices for the last partial page
829
- prefix_len = tl.load(seq_lens + pid)
830
- last_page_len = prefix_len % page_size
831
- offsets = tl.arange(0, page_size)
832
- mask = offsets < last_page_len
833
- num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid)
834
- prefix_base = token_pool + prefix_len - last_page_len
835
-
836
- for topk_id in range(topk):
837
- value = tl.load(prefix_base + offsets, mask=mask)
838
- tl.store(
839
- prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets,
840
- value,
841
- mask=mask,
842
- )
843
-
844
- # Part 3: Remove the padding in out_cache_loc
845
- iter_offest = tl.arange(0, iter_upper)
846
- for topk_id in range(topk):
847
- indices = tl.load(
848
- prefix_base
849
- + topk_id * num_new_pages_per_topk_ * page_size
850
- + last_page_len
851
- + iter_offest,
852
- mask=iter_offest < speculative_num_steps,
853
- )
854
- tl.store(
855
- out_cache_loc
856
- + pid * topk * speculative_num_steps
857
- + topk_id * speculative_num_steps
858
- + iter_offest,
859
- indices,
860
- mask=iter_offest < speculative_num_steps,
861
- )
565
+ @dataclass
566
+ class EagleDraftInput(SpecInput):
567
+ # The inputs for decode
568
+ # shape: (b, topk)
569
+ topk_p: torch.Tensor = None
570
+ topk_index: torch.Tensor = None
571
+ # shape: (b, hidden_size)
572
+ hidden_states: torch.Tensor = None
573
+ capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL
862
574
 
575
+ # Inputs for extend
576
+ # shape: (b,)
577
+ verified_id: torch.Tensor = None
578
+ accept_length: torch.Tensor = None
579
+ accept_length_cpu: List[int] = None
863
580
 
864
- @triton.jit
865
- def generate_draft_decode_kv_indices(
866
- req_pool_indices,
867
- req_to_token,
868
- paged_kernel_lens,
869
- kv_indices,
870
- kv_indptr,
871
- positions,
872
- pool_len: tl.constexpr,
873
- kv_indices_stride: tl.constexpr,
874
- kv_indptr_stride: tl.constexpr,
875
- bs_upper: tl.constexpr,
876
- iter_upper: tl.constexpr,
877
- num_tokens_upper: tl.constexpr,
878
- page_size: tl.constexpr,
879
- ):
880
- BLOCK_SIZE: tl.constexpr = 128
881
- iters = tl.program_id(axis=0)
882
- bid = tl.program_id(axis=1)
883
- topk_id = tl.program_id(axis=2)
884
-
885
- num_steps = tl.num_programs(axis=0)
886
- num_seqs = tl.num_programs(axis=1)
887
- topk = tl.num_programs(axis=2)
888
-
889
- kv_indices += kv_indices_stride * iters
890
- kv_indptr += kv_indptr_stride * iters
891
- iters += 1
892
-
893
- load_offset = tl.arange(0, bs_upper)
894
- seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid, other=0)
895
- seq_len = tl.load(paged_kernel_lens + bid)
896
- cum_seq_len = tl.sum(seq_lens)
897
-
898
- # Update kv_indices
899
- kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
900
- kv_ptr = kv_indices + kv_offset
901
- token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len
902
-
903
- kv_offset = tl.arange(0, BLOCK_SIZE)
904
- num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
905
- for _ in range(num_loop):
906
- mask = kv_offset < seq_len
907
- data = tl.load(token_pool_ptr + kv_offset, mask=mask)
908
- tl.store(kv_ptr + kv_offset, data, mask=mask)
909
- kv_offset += BLOCK_SIZE
910
-
911
- extend_offset = tl.arange(0, iter_upper)
912
- if page_size == 1 or topk == 1:
913
- extend_data = tl.load(
914
- token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper),
915
- mask=extend_offset < iters,
916
- )
917
- else:
918
- prefix_len = seq_len
919
- last_page_len = prefix_len % page_size
920
- num_new_pages_per_topk = (
921
- last_page_len + num_steps + page_size - 1
922
- ) // page_size
923
- prefix_base = seq_len // page_size * page_size
924
- start = (
925
- prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len
926
- )
927
- extend_data = tl.load(
928
- token_pool_ptr + start + extend_offset,
929
- mask=extend_offset < iters,
930
- )
581
+ # Inputs for the attention backends
582
+ # shape: (b + 1,)
583
+ kv_indptr: torch.Tensor = None
584
+ kv_indices: torch.Tensor = None
931
585
 
932
- tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)
933
-
934
- # Update kv_indptr
935
- bs_offset = tl.arange(0, num_tokens_upper)
936
-
937
- zid = bid * topk + topk_id
938
- if zid == 0:
939
- zid = num_seqs * topk
940
- positions = tl.load(positions + bs_offset, mask=bs_offset < zid, other=0)
941
- base = tl.sum(positions)
942
- tl.store(kv_indptr + zid, base + zid * iters)
943
-
944
-
945
- @triton.jit
946
- def align_evict_mask_to_page_size(
947
- seq_lens,
948
- evict_mask,
949
- page_size: tl.constexpr,
950
- num_draft_tokens: tl.constexpr,
951
- BLOCK_SIZE: tl.constexpr,
952
- ):
953
- t_range = tl.arange(0, BLOCK_SIZE)
954
-
955
- bid = tl.program_id(axis=0)
956
- seq_len = tl.load(seq_lens + bid)
957
- io_mask = t_range < num_draft_tokens
958
- mask_row = tl.load(
959
- evict_mask + bid * num_draft_tokens + t_range, mask=io_mask, other=0
960
- )
586
+ # Shape info for padding
587
+ num_tokens_per_batch: int = -1
588
+ num_tokens_for_logprob_per_batch: int = -1
961
589
 
962
- num_trues = tl.sum(mask_row)
963
- num_false = num_draft_tokens - num_trues
964
-
965
- start = (seq_len + num_false - 1) // page_size * page_size - seq_len
966
- for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
967
- tl.store(evict_mask + bid * num_draft_tokens + i, False)
968
-
969
-
970
- @triton.jit
971
- def get_target_cache_loc(
972
- tgt_cache_loc,
973
- to_free_slots,
974
- accept_length,
975
- to_free_num_slots,
976
- out_cache_loc,
977
- num_verify_tokens: tl.constexpr,
978
- num_verify_tokens_upper: tl.constexpr,
979
- bs_upper: tl.constexpr,
980
- ):
981
- bid = tl.program_id(axis=0)
982
- offset = tl.arange(0, num_verify_tokens_upper)
983
- bs_offset = tl.arange(0, bs_upper)
984
-
985
- # write the first part to tgt_cache_loc
986
- accept_len_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
987
- tgt_cache_loc_start = tl.sum(accept_len_all) + bid
988
- copy_len = tl.load(accept_length + bid) + 1
989
- out_cache_loc_row = tl.load(
990
- out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len
991
- )
992
- tl.store(
993
- tgt_cache_loc + tgt_cache_loc_start + offset,
994
- out_cache_loc_row,
995
- mask=offset < copy_len,
996
- )
590
+ # Inputs for draft extend
591
+ # shape: (b,)
592
+ seq_lens_for_draft_extend: torch.Tensor = None
593
+ seq_lens_for_draft_extend_cpu: torch.Tensor = None
594
+ req_pool_indices_for_draft_extend: torch.Tensor = None
997
595
 
998
- # write the second part to to_free_num_pages
999
- to_free_num_slots_all = tl.load(to_free_num_slots + bs_offset, mask=bs_offset < bid)
1000
- to_free_num_slots_cur = tl.load(to_free_num_slots + bid)
1001
- out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur
1002
- to_free_slots_start = tl.sum(to_free_num_slots_all)
596
+ def __post_init__(self):
597
+ super().__init__(SpecInputType.EAGLE_DRAFT)
1003
598
 
1004
- copy_len = to_free_num_slots_cur
1005
- out_cache_loc_row = tl.load(
1006
- out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset,
1007
- mask=offset < copy_len,
1008
- )
1009
- tl.store(
1010
- to_free_slots + to_free_slots_start + offset,
1011
- out_cache_loc_row,
1012
- mask=offset < copy_len,
1013
- )
599
+ def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
600
+ return self.num_tokens_per_batch, self.num_tokens_for_logprob_per_batch
1014
601
 
602
+ def prepare_for_extend(self, batch: ScheduleBatch):
1015
603
 
1016
- @torch.compile(dynamic=True)
1017
- def get_src_tgt_cache_loc(
1018
- seq_lens: torch.Tensor,
1019
- out_cache_loc: torch.Tensor,
1020
- accept_index: torch.Tensor,
1021
- accept_length: torch.Tensor,
1022
- draft_token_num: int,
1023
- page_size: int,
1024
- ):
1025
- src_cache_loc = out_cache_loc[accept_index]
1026
- tgt_cache_loc = torch.empty_like(src_cache_loc)
1027
- extended_len = seq_lens + draft_token_num
1028
- keep_len = torch.minimum(
1029
- (seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size,
1030
- extended_len,
1031
- )
1032
- to_free_num_slots = extended_len - keep_len
1033
- return src_cache_loc, tgt_cache_loc, to_free_num_slots
1034
-
1035
-
1036
- @triton.jit
1037
- def filter_finished_cache_loc_kernel(
1038
- out_cache_loc,
1039
- tgt_cache_loc,
1040
- accept_length,
1041
- accept_length_filter,
1042
- bs_upper: tl.constexpr,
1043
- num_verify_tokens_upper: tl.constexpr,
1044
- ):
1045
- bid = tl.program_id(0)
1046
- bs_offset = tl.arange(0, bs_upper)
1047
-
1048
- accept_length_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
1049
- old_start = tl.sum(accept_length_all) + bid
1050
-
1051
- accept_length_filter_all = tl.load(
1052
- accept_length_filter + bs_offset, mask=bs_offset < bid
1053
- )
1054
- new_start = tl.sum(accept_length_filter_all)
604
+ if batch.forward_mode.is_idle():
605
+ return
1055
606
 
1056
- copy_len = tl.load(accept_length_filter + bid)
1057
- copy_offset = tl.arange(0, num_verify_tokens_upper)
1058
- value = tl.load(
1059
- tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len
1060
- )
1061
- tl.store(
1062
- out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len
1063
- )
607
+ # Prefill only generate 1 token.
608
+ assert len(self.verified_id) == len(batch.seq_lens)
1064
609
 
610
+ pt = 0
611
+ for i, extend_len in enumerate(batch.extend_lens):
612
+ input_ids = batch.input_ids[pt : pt + extend_len]
613
+ batch.input_ids[pt : pt + extend_len] = torch.cat(
614
+ (input_ids[1:], self.verified_id[i].reshape(1))
615
+ )
616
+ pt += extend_len
1065
617
 
1066
- @torch.compile(dynamic=True)
1067
- def create_accept_length_filter(
1068
- accept_length: torch.Tensor,
1069
- unfinished_index_device: torch.Tensor,
1070
- seq_lens: torch.Tensor,
1071
- ):
1072
- accept_length_filter = torch.zeros_like(accept_length)
1073
- accept_length_filter[unfinished_index_device] = (
1074
- accept_length[unfinished_index_device] + 1
1075
- )
1076
- seq_lens.add_(accept_length + 1)
1077
- return accept_length_filter
1078
-
1079
-
1080
- @torch.compile(dynamic=True)
1081
- def select_top_k_tokens(
1082
- i: int,
1083
- topk_p: torch.Tensor,
1084
- topk_index: torch.Tensor,
1085
- hidden_states: torch.Tensor,
1086
- scores: torch.Tensor,
1087
- topk: int,
1088
- ):
1089
- if i == 0:
1090
- # The first step after extend
1091
- input_ids = topk_index.flatten()
1092
- hidden_states = hidden_states.repeat_interleave(topk, dim=0)
1093
- scores = topk_p # shape: (b, topk)
1094
-
1095
- tree_info = (
1096
- topk_p.unsqueeze(1), # shape: (b, 1, topk)
1097
- topk_index, # shape: (b, topk)
1098
- torch.arange(-1, topk, dtype=torch.long, device="cuda")
1099
- .unsqueeze(0)
1100
- .repeat(topk_p.shape[0], 1), # shape: (b, topk + 1)
1101
- )
1102
- else:
1103
- # The later decode steps
1104
- expand_scores = torch.mul(
1105
- scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
1106
- ) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
1107
- topk_cs_p, topk_cs_index = fast_topk(
1108
- expand_scores.flatten(start_dim=1), topk, dim=-1
1109
- ) # (b, topk)
1110
- scores = topk_cs_p # shape: (b, topk)
1111
-
1112
- topk_index = topk_index.reshape(-1, topk**2)
1113
- input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()
1114
-
1115
- if hidden_states.shape[0] > 0:
1116
- selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
1117
- 0, hidden_states.shape[0], step=topk, device="cuda"
1118
- ).repeat_interleave(topk)
1119
- hidden_states = hidden_states[selected_input_index, :]
1120
-
1121
- tree_info = (
1122
- expand_scores, # shape: (b, topk, topk)
1123
- topk_index, # shape: (b, topk * topk)
1124
- topk_cs_index + (topk**2 * (i - 1) + topk), # shape: (b, topk)
618
+ @classmethod
619
+ def create_idle_input(
620
+ cls,
621
+ device: torch.device,
622
+ hidden_size: int,
623
+ dtype: torch.dtype,
624
+ topk: int,
625
+ capture_hidden_mode: CaptureHiddenMode,
626
+ ):
627
+ return cls(
628
+ verified_id=torch.empty((0,), device=device, dtype=torch.int32),
629
+ hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
630
+ topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
631
+ topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
632
+ capture_hidden_mode=capture_hidden_mode,
633
+ accept_length=torch.empty((0,), device=device, dtype=torch.int32),
634
+ accept_length_cpu=[],
1125
635
  )
1126
636
 
1127
- return input_ids, hidden_states, scores, tree_info
1128
-
1129
-
1130
- def _generate_simulated_accept_index(
1131
- accept_index,
1132
- predict,
1133
- accept_length,
1134
- simulate_acc_len,
1135
- bs,
1136
- spec_steps,
1137
- ):
1138
- simulate_acc_len_float = float(simulate_acc_len)
1139
- if SIMULATE_ACC_METHOD == "multinomial":
1140
- simulated_values = torch.normal(
1141
- mean=simulate_acc_len_float,
1142
- std=1.0,
1143
- size=(1,),
1144
- device="cpu",
1145
- )
1146
- # clamp simulated values to be between 1 and self.spec_steps
1147
- simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
1148
- simulate_acc_len = int(simulated_values.round().item())
1149
- elif SIMULATE_ACC_METHOD == "match-expected":
1150
- # multinomial sampling does not match the expected length
1151
- # we keep it for the sake of compatibility of existing tests
1152
- # but it's better to use "match-expected" for the cases that need to
1153
- # match the expected length, One caveat is that this will only sample
1154
- # either round down or round up of the expected length
1155
- simulate_acc_len_float = max(1.0, min(spec_steps + 1, simulate_acc_len_float))
1156
- lower = int(simulate_acc_len_float // 1)
1157
- upper = lower + 1 if lower < spec_steps + 1 else lower
1158
- if lower == upper:
1159
- simulate_acc_len = lower
1160
- else:
1161
- weight_upper = simulate_acc_len_float - lower
1162
- weight_lower = 1.0 - weight_upper
1163
- probs = torch.tensor([weight_lower, weight_upper], device="cpu")
1164
- sampled_index = torch.multinomial(probs, num_samples=1)
1165
- simulate_acc_len = lower if sampled_index == 0 else upper
1166
- else:
1167
- raise ValueError(f"Invalid simulate_acc_method: {SIMULATE_ACC_METHOD}")
1168
-
1169
- accept_indx_first_col = accept_index[:, 0].view(-1, 1)
1170
- sim_accept_index = torch.full(
1171
- (bs, spec_steps + 1), -1, dtype=torch.int32, device="cuda"
1172
- )
1173
- sim_accept_index[:, :simulate_acc_len] = accept_indx_first_col + torch.arange(
1174
- simulate_acc_len, device=accept_index.device
1175
- )
1176
- accept_length.fill_(simulate_acc_len - 1)
1177
- predict.fill_(100) # some legit token id
1178
- return sim_accept_index
1179
-
1180
-
1181
- def traverse_tree(
1182
- retrieve_next_token: torch.Tensor,
1183
- retrieve_next_sibling: torch.Tensor,
1184
- draft_tokens: torch.Tensor,
1185
- grammar: BaseGrammarObject,
1186
- allocate_token_bitmask: torch.Tensor,
1187
- ):
1188
- """
1189
- Traverse the tree constructed by the draft model to generate the logits mask.
1190
- """
1191
- assert (
1192
- retrieve_next_token.shape == retrieve_next_sibling.shape == draft_tokens.shape
1193
- )
637
+ def prepare_extend_after_decode(
638
+ self,
639
+ batch: ScheduleBatch,
640
+ speculative_num_steps: int,
641
+ ):
642
+
643
+ if batch.forward_mode.is_idle():
644
+ return
645
+
646
+ batch.input_ids = self.verified_id
647
+ batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
648
+ batch.extend_num_tokens = sum(batch.extend_lens)
649
+ batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
650
+ batch.seq_lens_cpu = batch.spec_info.seq_lens_for_draft_extend_cpu
651
+ batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
652
+ batch.return_logprob = False
653
+ batch.return_hidden_states = False
1194
654
 
1195
- allocate_token_bitmask.fill_(0)
655
+ self.capture_hidden_mode = CaptureHiddenMode.LAST
656
+ self.accept_length.add_(1)
657
+ self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
658
+ self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
1196
659
 
1197
- def dfs(
1198
- curr: int,
1199
- retrieve_next_token: torch.Tensor,
1200
- retrieve_next_sibling: torch.Tensor,
1201
- parent_pos: int,
660
+ create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
661
+ batch.input_ids,
662
+ batch.seq_lens,
663
+ self.accept_length,
664
+ self.positions,
665
+ self.verified_id,
666
+ next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
667
+ )
668
+
669
+ def generate_attn_arg_prefill(
670
+ self,
671
+ req_pool_indices: torch.Tensor,
672
+ paged_kernel_lens: torch.Tensor,
673
+ paged_kernel_lens_sum: int,
674
+ req_to_token: torch.Tensor,
1202
675
  ):
1203
- if curr == 0:
1204
- # the first token generated by the target model, and thus it is always
1205
- # accepted from the previous iteration
1206
- accepted = True
1207
- else:
1208
- parent_bitmask = allocate_token_bitmask[parent_pos]
1209
- curr_token_id = draft_tokens[curr]
1210
- # 32 boolean bitmask values are packed into 32-bit integers
1211
- accepted = (
1212
- parent_bitmask[curr_token_id // 32] & (1 << (curr_token_id % 32))
1213
- ) != 0
1214
-
1215
- if accepted:
1216
- if curr != 0:
1217
- # Accept the current token
1218
- grammar.accept_token(draft_tokens[curr])
1219
- if not grammar.is_terminated():
1220
- # Generate the bitmask for the current token
1221
- grammar.fill_vocab_mask(allocate_token_bitmask, curr)
1222
- if retrieve_next_token[curr] != -1:
1223
- # Visit the child node
1224
- dfs(
1225
- retrieve_next_token[curr],
1226
- retrieve_next_token,
1227
- retrieve_next_sibling,
1228
- curr,
1229
- )
676
+ bs = self.accept_length.numel()
677
+ qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
678
+ qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
679
+ cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
680
+ cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
1230
681
 
1231
- if curr != 0:
1232
- # Rollback the current token
1233
- grammar.rollback(1)
1234
-
1235
- if retrieve_next_sibling[curr] != -1:
1236
- # Visit the sibling node
1237
- dfs(
1238
- retrieve_next_sibling[curr],
1239
- retrieve_next_token,
1240
- retrieve_next_sibling,
1241
- parent_pos,
1242
- )
682
+ if paged_kernel_lens_sum is None:
683
+ paged_kernel_lens_sum = cum_kv_seq_len[-1]
1243
684
 
1244
- dfs(0, retrieve_next_token, retrieve_next_sibling, -1)
1245
-
1246
-
1247
- def generate_token_bitmask(
1248
- reqs: List[Req],
1249
- verify_input: EagleVerifyInput,
1250
- retrieve_next_token_cpu: torch.Tensor,
1251
- retrieve_next_sibling_cpu: torch.Tensor,
1252
- draft_tokens_cpu: torch.Tensor,
1253
- vocab_size: int,
1254
- ):
1255
- """
1256
- Generate the logit mask for structured output.
1257
- Draft model's token can be either valid or invalid with respect to the grammar.
1258
- We need to perform DFS to
1259
- 1. figure out which tokens are accepted by the grammar.
1260
- 2. if so, what is the corresponding logit mask.
1261
- """
1262
-
1263
- num_draft_tokens = draft_tokens_cpu.shape[-1]
1264
-
1265
- allocate_token_bitmask = None
1266
- assert len(reqs) == retrieve_next_token_cpu.shape[0]
1267
- grammar = None
1268
- for i, req in enumerate(reqs):
1269
- if req.grammar is not None:
1270
- if allocate_token_bitmask is None:
1271
- allocate_token_bitmask = req.grammar.allocate_vocab_mask(
1272
- vocab_size=vocab_size,
1273
- batch_size=draft_tokens_cpu.numel(),
1274
- device="cpu",
1275
- )
1276
- grammar = req.grammar
1277
- s = time.perf_counter()
1278
- traverse_tree(
1279
- retrieve_next_token_cpu[i],
1280
- retrieve_next_sibling_cpu[i],
1281
- draft_tokens_cpu[i],
1282
- req.grammar,
1283
- allocate_token_bitmask[
1284
- i * num_draft_tokens : (i + 1) * num_draft_tokens
1285
- ],
1286
- )
1287
- tree_traverse_time = time.perf_counter() - s
1288
- if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD:
685
+ kv_indices = torch.empty(
686
+ paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
687
+ )
688
+
689
+ create_flashinfer_kv_indices_triton[(bs,)](
690
+ req_to_token,
691
+ req_pool_indices,
692
+ paged_kernel_lens,
693
+ cum_kv_seq_len,
694
+ None,
695
+ kv_indices,
696
+ req_to_token.size(1),
697
+ )
698
+ return kv_indices, cum_kv_seq_len, qo_indptr, None
699
+
700
+ def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
701
+ if has_been_filtered:
702
+ # in eagle_utils.py:verify, we have already filtered the batch by `unfinished_index`
703
+ # therefore, we don't need to filter the batch again in scheduler
704
+ if len(new_indices) != len(self.topk_p):
1289
705
  logger.warning(
1290
- f"Bit mask generation took {tree_traverse_time} seconds with "
1291
- f"grammar: {req.grammar}"
706
+ f"length of new_indices: {len(new_indices)} != length of topk_p: {len(self.topk_p)}, this should not happen"
1292
707
  )
708
+ self.topk_p = self.topk_p[: len(new_indices)]
709
+ self.topk_index = self.topk_index[: len(new_indices)]
710
+ self.hidden_states = self.hidden_states[: len(new_indices)]
711
+ self.verified_id = self.verified_id[: len(new_indices)]
712
+ else:
713
+ # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
714
+ self.topk_p = self.topk_p[new_indices]
715
+ self.topk_index = self.topk_index[new_indices]
716
+ self.hidden_states = self.hidden_states[new_indices]
717
+ self.verified_id = self.verified_id[new_indices]
718
+
719
+ def merge_batch(self, spec_info: "EagleDraftInput"):
720
+ if self.hidden_states is None:
721
+ self.hidden_states = spec_info.hidden_states
722
+ self.verified_id = spec_info.verified_id
723
+ self.topk_p = spec_info.topk_p
724
+ self.topk_index = spec_info.topk_index
725
+ return
726
+ if spec_info.hidden_states is None:
727
+ return
728
+ self.hidden_states = torch.cat(
729
+ [self.hidden_states, spec_info.hidden_states], axis=0
730
+ )
731
+ self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
732
+ self.topk_p = torch.cat([self.topk_p, spec_info.topk_p])
733
+ self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
734
+
1293
735
 
1294
- verify_input.grammar = grammar
1295
- return allocate_token_bitmask
736
+ @dataclass
737
+ class EagleVerifyOutput:
738
+ # Draft input batch
739
+ draft_input: EagleDraftInput
740
+ # Logit outputs from target worker
741
+ logits_output: LogitsProcessorOutput
742
+ # Accepted token ids including the bonus token
743
+ verified_id: torch.Tensor
744
+ # Accepted token length per sequence in a batch in CPU.
745
+ accept_length_per_req_cpu: List[int]
746
+ # Accepted indices from logits_output.next_token_logits
747
+ accepted_indices: torch.Tensor