sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +302 -414
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +13 -8
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +144 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +773 -334
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +225 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +68 -37
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +102 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +56 -31
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +280 -81
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +135 -60
  181. sglang/srt/speculative/build_eagle_tree.py +8 -9
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
  183. sglang/srt/speculative/eagle_utils.py +92 -57
  184. sglang/srt/speculative/eagle_worker.py +238 -111
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_utils.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-import dataclasses
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, List
 
 import torch
@@ -8,9 +8,10 @@ import torch.nn.functional as F
 import triton
 import triton.language as tl
 
-from sglang.srt.layers.attention.flashinfer_backend import (
-    create_flashinfer_kv_indices_triton,
-)
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.speculative.build_eagle_tree import (
     build_tree_kernel,
@@ -25,7 +26,7 @@ if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
 
 
-@dataclasses.dataclass
+@dataclass
 class EagleDraftInput:
     # The inputs for decode
     # shape: (b, topk)
@@ -46,57 +47,47 @@ class EagleDraftInput:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
 
+    # indices of unfinished requests during extend-after-decode
+    # e.g. [0, 2, 3, 4] if only the 1st request is finished
+    keep_indices: List[int] = None
+
     def prepare_for_extend(self, batch: ScheduleBatch):
-        req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
-        out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
-        batch.out_cache_loc = out_cache_loc
+        assert batch.input_ids.numel() == batch.out_cache_loc.shape[0]
+        # Prefill only generate 1 token.
+        assert len(self.verified_id) == len(batch.seq_lens)
 
         pt = 0
-        for i, req in enumerate(batch.reqs):
-            req.req_pool_idx = req_pool_indices[i]
-            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
-            assert seq_len - pre_len == req.extend_input_len
-
-            if pre_len > 0:
-                batch.req_to_token_pool.req_to_token[req.req_pool_idx][
-                    :pre_len
-                ] = req.prefix_indices
-
-            batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
-                out_cache_loc[pt : pt + req.extend_input_len]
+        for i, extend_len in enumerate(batch.extend_lens):
+            input_ids = batch.input_ids[pt : pt + extend_len]
+            batch.input_ids[pt : pt + extend_len] = torch.concat(
+                (input_ids[1:], self.verified_id[i].reshape(1))
             )
-
-            pt += req.extend_input_len
-
-        # TODO: support batching inputs
-        assert len(batch.extend_lens) == 1
-        batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
+            pt += extend_len
 
     def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
-        batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
+        assert self.verified_id.numel() == batch.out_cache_loc.shape[0]
         accept_length_cpu = batch.spec_info.accept_length_cpu
         batch.extend_lens = [x + 1 for x in accept_length_cpu]
+        batch.extend_num_tokens = sum(batch.extend_lens)
         batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
-        batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
         seq_lens_cpu = batch.seq_lens.tolist()
+        assert len(batch.req_pool_indices) == len(batch.reqs)
 
         pt = 0
         i = 0
-        for req in batch.reqs:
+        self.keep_indices = []
+        for idx, req in enumerate(batch.reqs):
             if req.finished():
                 continue
+            self.keep_indices.append(idx)
             # assert seq_len - pre_len == req.extend_input_len
             input_len = batch.extend_lens[i]
             seq_len = seq_lens_cpu[i]
-            batch.req_to_token_pool.req_to_token[req.req_pool_idx][
-                seq_len - input_len : seq_len
-            ] = batch.out_cache_loc[pt : pt + input_len]
             pt += input_len
             i += 1
-        assert pt == batch.out_cache_loc.shape[0]
 
-        self.positions = torch.empty_like(self.verified_id)
-        new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
+        self.positions = torch.empty_like(self.verified_id, dtype=torch.long)
+        new_verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
         self.accept_length.add_(1)
 
         create_extend_spec_info[(self.accept_length.numel(),)](
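
Note: the rewritten prepare_for_extend above removes the old single-request limitation (the deleted "TODO: support batching inputs"). Each request's slice of input_ids is shifted left by one token and its verified token appended in place. A minimal sketch of that shift with made-up tensors, not a real ScheduleBatch:

    import torch

    # Hypothetical stand-ins for batch.input_ids, batch.extend_lens, self.verified_id.
    input_ids = torch.tensor([10, 11, 12, 20, 21])  # two requests, flattened
    extend_lens = [3, 2]                            # per-request segment lengths
    verified_id = torch.tensor([13, 22])            # one verified token per request

    pt = 0
    for i, extend_len in enumerate(extend_lens):
        seg = input_ids[pt : pt + extend_len]
        # Drop the segment's first token, append the verified token at the end.
        input_ids[pt : pt + extend_len] = torch.concat((seg[1:], verified_id[i].reshape(1)))
        pt += extend_len

    print(input_ids)  # tensor([11, 12, 13, 21, 22])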
@@ -117,14 +108,22 @@ class EagleDraftInput:
         self,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
+        paged_kernel_lens_sum: int,
         req_to_token: torch.Tensor,
     ):
         bs = self.accept_length.numel()
+        keep_indices = torch.tensor(self.keep_indices, device=req_pool_indices.device)
+        req_pool_indices = req_pool_indices[keep_indices]
+        assert req_pool_indices.shape[0] == bs
+        assert req_pool_indices.shape[0] == paged_kernel_lens.shape[0]
+
         qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
 
         cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
         cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+
+        # TODO: replace cum_kv_seq_len[-1] with paged_kernel_lens_sum to avoid the device sync.
         kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
 
         create_flashinfer_kv_indices_triton[(bs,)](
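
Note: the new keep_indices field recorded in prepare_extend_after_decode is consumed here: finished requests are dropped from req_pool_indices before the flashinfer metadata is built, so the remaining rows line up with accept_length. A small sketch of the gather (assumed toy values, CPU tensors for illustration):

    import torch

    keep_indices = torch.tensor([0, 2, 3])          # request 1 finished
    req_pool_indices = torch.tensor([7, 8, 9, 10])  # one pool slot per original request

    # Only unfinished requests take part in the draft-extend forward pass.
    req_pool_indices = req_pool_indices[keep_indices]
    print(req_pool_indices)  # tensor([ 7,  9, 10])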
@@ -162,7 +161,21 @@ class EagleDraftInput:
         self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
 
 
-@dataclasses.dataclass
+@dataclass
+class EagleVerifyOutput:
+    # Draft input batch
+    draft_input: EagleDraftInput
+    # Logit outputs from target worker
+    logits_output: LogitsProcessorOutput
+    # Accepeted token ids including the bonus token
+    verified_id: torch.Tensor
+    # Accepeted token length per sequence in a batch in CPU.
+    accept_length_per_req_cpu: List[int]
+    # Accepeted indices from logits_output.next_token_logits
+    accepeted_indices_cpu: List[int]
+
+
+@dataclass
 class EagleVerifyInput:
     draft_token: torch.Tensor
     custom_mask: torch.Tensor
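
Note: EagleVerifyOutput gives verify a named return type in place of the positional five-tuple it used to return (see the replaced return statement in a later hunk), so call sites read fields by name and survive reordering. A runnable sketch of the pattern with a simplified stand-in dataclass:

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class MiniVerifyOutput:  # simplified stand-in for EagleVerifyOutput
        verified_id: List[int]
        accept_length_per_req_cpu: List[int]

    out = MiniVerifyOutput(verified_id=[13, 22], accept_length_per_req_cpu=[2, 1])
    # Named access is self-documenting; tuple unpacking breaks silently on reorder.
    print(out.verified_id, out.accept_length_per_req_cpu)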
@@ -267,6 +280,7 @@
         self,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
+        paged_kernel_lens_sum: int,
         req_to_token: torch.Tensor,
     ):
         batch_size = len(req_pool_indices)
@@ -285,7 +299,11 @@
         paged_kernel_lens = paged_kernel_lens + self.draft_token_num
         cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
 
-        kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum + self.draft_token_num * batch_size,
+            dtype=torch.int32,
+            device="cuda",
+        )
 
         create_flashinfer_kv_indices_triton[(batch_size,)](
             req_to_token,
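
Note: this hunk is the companion to the TODO added earlier. Sizing kv_indices from the host-side int paged_kernel_lens_sum (plus the fixed draft budget per request) instead of cum_kv_seq_len[-1] removes a GPU-to-CPU synchronization, because using a 0-d CUDA tensor as a size forces the host to wait for the device. A minimal sketch of the two patterns (toy sizes; requires a CUDA device):

    import torch

    lens = torch.tensor([4, 6, 5], device="cuda")  # per-request KV lengths
    lens_sum = 15                                  # same total, tracked on the host
    draft_token_num, batch_size = 4, 3

    # Synchronizing: the cumsum result lives on the GPU, so using it as a size
    # argument stalls the host until the kernel finishes.
    n_sync = torch.cumsum(lens + draft_token_num, dim=0)[-1]
    kv_slow = torch.empty(n_sync, dtype=torch.int32, device="cuda")

    # Non-synchronizing: the size is already a Python int on the host.
    kv_fast = torch.empty(
        lens_sum + draft_token_num * batch_size, dtype=torch.int32, device="cuda"
    )
    assert kv_slow.numel() == kv_fast.numel() == 27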
@@ -298,7 +316,21 @@
         )
         return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask
 
-    def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor:
+    def verify(
+        self,
+        batch: ScheduleBatch,
+        logits_output: torch.Tensor,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+    ) -> torch.Tensor:
+        """WARNING: This API in-place modifies the states of logits_output
+
+        Verify and find accepted tokens based on logits output and batch
+        (which contains spec decoding information).
+
+        This API updates values inside logits_output based on the accepted
+        tokens. I.e., logits_output.next_token_logits only contains
+        accepeted token logits.
+        """
         draft_token = torch.cat(
             [self.draft_token, torch.full([1], -1, dtype=torch.int32, device="cuda")],
             dim=-1,
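
Note: the docstring's warning is about row filtering: verify narrows logits_output.next_token_logits down to the accepted rows, mutating the object the caller passed in. Roughly this, shown on a toy tensor rather than a real LogitsProcessorOutput:

    import torch

    next_token_logits = torch.randn(5, 32)  # one row per draft token
    accept_index = torch.tensor([0, 2, 3])  # rows that survived verification

    # After verify, the caller's object only holds accepted-token logits.
    next_token_logits = next_token_logits[accept_index]
    print(next_token_logits.shape)  # torch.Size([3, 32])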
@@ -367,7 +399,6 @@
 
         new_accept_index = []
         unfinished_index = []
-        finished_extend_len = {}  # {rid:accept_length + 1}
         accept_index_cpu = accept_index.tolist()
         predict_cpu = predict.tolist()
         has_finished = False
@@ -382,7 +413,6 @@
                 id = predict_cpu[idx]
                 # if not found_finished:
                 req.output_ids.append(id)
-                finished_extend_len[req.rid] = j + 1
                 req.check_finished()
                 if req.finished():
                     has_finished = True
@@ -400,11 +430,10 @@
             accept_index = accept_index[accept_index != -1]
         accept_length_cpu = accept_length.tolist()
         verified_id = predict[accept_index]
-
         evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
         evict_mask[accept_index] = False
         mem_need_free_idx = batch.out_cache_loc[evict_mask]
-        batch.token_to_kv_pool.free(mem_need_free_idx)
+        token_to_kv_pool_allocator.free(mem_need_free_idx)
         assign_req_to_token_pool[(bs,)](
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
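
Note: rejected draft tokens must hand their KV-cache slots back, and the allocator is now passed in explicitly instead of being read off the batch. The mask starts all-True, accepted positions are cleared, and whatever remains indexes the slots to free. A toy version of the mask logic:

    import torch

    out_cache_loc = torch.tensor([100, 101, 102, 103, 104])  # KV slots of draft tokens
    accept_index = torch.tensor([0, 2])                      # accepted positions

    evict_mask = torch.full((5,), True, dtype=torch.bool)
    evict_mask[accept_index] = False
    mem_need_free_idx = out_cache_loc[evict_mask]
    print(mem_need_free_idx)  # tensor([101, 103, 104]) -> handed back to the allocator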
@@ -427,20 +456,16 @@
             ]
             if has_finished:
                 draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
-                draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
-                    unfinished_index
-                ]
             else:
                 draft_input.seq_lens_for_draft_extend = batch.seq_lens
-                draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
-
-        logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
-        return (
-            draft_input,
-            logits_output,
-            verified_id,
-            finished_extend_len,
-            accept_length_cpu,
+            batch.out_cache_loc = batch.out_cache_loc[new_accept_index]
+
+        return EagleVerifyOutput(
+            draft_input=draft_input,
+            logits_output=logits_output,
+            verified_id=verified_id,
+            accept_length_per_req_cpu=accept_length_cpu,
+            accepeted_indices_cpu=accept_index,
         )
 
 
@@ -456,6 +481,18 @@ def eagle_verify_retrive(
     draft_token_num: tl.constexpr,
     max_len_upper: tl.constexpr,
 ):
+    """
+    Args:
+        retrive_index: Pointer to indices of draft tokens
+        accept_mask: Mask indicating which tokens were accepted
+        retrive_cum_len: Cumulative lengths of token sequences in a batch
+        accept_index (out): Accept token indices
+        accept_length (out): Length of accepted tokens per sequence in a batch
+        extract_index (out): Index for last accepted tokens
+        max_len: Maximum length in a batch
+        draft_token_num: Number of tokens speculatively generated
+        max_len_upper: An upper bound for token sequence length
+    """
     pid = tl.program_id(axis=0)
 
     retrive_end = tl.load(retrive_cum_len + pid + 1)
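
Note: the new docstring documents retrive_cum_len as cumulative sequence lengths: each Triton program (one per request, indexed by pid) reads its [start, end) window from adjacent entries. A plain-PyTorch sketch of that indexing convention (toy data, not the kernel itself):

    import torch

    lens = torch.tensor([3, 2, 4])
    retrive_cum_len = torch.zeros(4, dtype=torch.int64)
    retrive_cum_len[1:] = torch.cumsum(lens, dim=0)  # [0, 3, 5, 9]

    flat = torch.arange(9)  # flattened per-request data
    for pid in range(3):    # pid plays the role of tl.program_id(axis=0)
        start, end = retrive_cum_len[pid], retrive_cum_len[pid + 1]
        print(pid, flat[start:end].tolist())
    # 0 [0, 1, 2] / 1 [3, 4] / 2 [5, 6, 7, 8]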
@@ -649,7 +686,7 @@ def generate_draft_decode_kv_indices(
         tl.store(kv_indptr + zid, base + zid * iters)
 
 
-@torch.compile
+@torch.compile(dynamic=True)
 def select_top_k_tokens(
     i: int,
     topk_p: torch.Tensor,
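
Note: dynamic=True asks torch.compile for shape-polymorphic code up front instead of specializing on the first observed shapes and recompiling when they change; that matters for select_top_k_tokens, whose batch dimension varies between scheduling steps. A minimal illustration (the function body here is a hypothetical stand-in):

    import torch

    @torch.compile(dynamic=True)  # compile once with a symbolic batch dimension
    def scale_scores(scores: torch.Tensor) -> torch.Tensor:
        return scores * 2.0

    for b in (1, 3, 7):  # varying batch sizes reuse the same compiled graph
        scale_scores(torch.randn(b, 8))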
@@ -671,13 +708,11 @@
             .unsqueeze(0)
             .repeat(topk_p.shape[0], 1),  # shape: (b, topk + 1)
         )
-
     else:
         # The later decode steps
         expand_scores = torch.mul(
             scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
         )  # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
-
         topk_cs_p, topk_cs_index = fast_topk(
             expand_scores.flatten(start_dim=1), topk, dim=-1
         )  # (b, topk)
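
Note: the broadcast in the later decode steps multiplies each of the topk parent path scores by its topk children, producing topk*topk candidates per sequence before fast_topk keeps the best topk. The shape bookkeeping can be checked in isolation (toy tensors; torch.topk stands in for fast_topk):

    import torch

    b, topk = 2, 4
    scores = torch.rand(b, topk)         # parent path scores
    topk_p = torch.rand(b, topk * topk)  # child probabilities, flattened

    expand_scores = torch.mul(
        scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
    )  # (b, topk, 1) x (b, topk, topk) -> (b, topk, topk)
    assert expand_scores.shape == (b, topk, topk)

    topk_cs_p, topk_cs_index = torch.topk(expand_scores.flatten(start_dim=1), topk, dim=-1)
    assert topk_cs_p.shape == (b, topk)  # best topk of the topk*topk candidates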