sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +220 -378
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +143 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +681 -259
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +224 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +44 -18
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +94 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +208 -28
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +136 -52
  181. sglang/srt/speculative/build_eagle_tree.py +2 -8
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  183. sglang/srt/speculative/eagle_utils.py +92 -58
  184. sglang/srt/speculative/eagle_worker.py +186 -94
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_utils.py

@@ -1,16 +1,17 @@
 from __future__ import annotations
 
-import dataclasses
-from typing import TYPE_CHECKING, List
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List
 
 import torch
 import torch.nn.functional as F
 import triton
 import triton.language as tl
 
-from sglang.srt.layers.attention.flashinfer_backend import (
-    create_flashinfer_kv_indices_triton,
-)
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.speculative.build_eagle_tree import (
     build_tree_kernel,
@@ -25,7 +26,7 @@ if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
 
 
-@dataclasses.dataclass
+@dataclass
 class EagleDraftInput:
     # The inputs for decode
     # shape: (b, topk)
@@ -46,57 +47,46 @@ class EagleDraftInput:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
 
+    # indices of unfinished requests during extend-after-decode
+    # e.g. [0, 2, 3, 4] if only the 1st request is finished
+    keep_indices: List[int] = None
+
     def prepare_for_extend(self, batch: ScheduleBatch):
-        req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
-        out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
-        batch.out_cache_loc = out_cache_loc
+        assert batch.input_ids.numel() == batch.out_cache_loc.shape[0]
+        # Prefill only generate 1 token.
+        assert len(self.verified_id) == len(batch.seq_lens)
 
         pt = 0
-        for i, req in enumerate(batch.reqs):
-            req.req_pool_idx = req_pool_indices[i]
-            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
-            assert seq_len - pre_len == req.extend_input_len
-
-            if pre_len > 0:
-                batch.req_to_token_pool.req_to_token[req.req_pool_idx][
-                    :pre_len
-                ] = req.prefix_indices
-
-            batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
-                out_cache_loc[pt : pt + req.extend_input_len]
+        for i, extend_len in enumerate(batch.extend_lens):
+            input_ids = batch.input_ids[pt : pt + extend_len]
+            batch.input_ids[pt : pt + extend_len] = torch.concat(
+                (input_ids[1:], self.verified_id[i].reshape(1))
             )
 
-            pt += req.extend_input_len
-
-        # TODO: support batching inputs
-        assert len(batch.extend_lens) == 1
-        batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
-
     def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
-        batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
+        assert self.verified_id.numel() == batch.out_cache_loc.shape[0]
         accept_length_cpu = batch.spec_info.accept_length_cpu
         batch.extend_lens = [x + 1 for x in accept_length_cpu]
+        batch.extend_num_tokens = sum(batch.extend_lens)
         batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
-        batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
         seq_lens_cpu = batch.seq_lens.tolist()
+        assert len(batch.req_pool_indices) == len(batch.reqs)
 
         pt = 0
         i = 0
-        for req in batch.reqs:
+        self.keep_indices = []
+        for idx, req in enumerate(batch.reqs):
             if req.finished():
                 continue
+            self.keep_indices.append(idx)
             # assert seq_len - pre_len == req.extend_input_len
             input_len = batch.extend_lens[i]
             seq_len = seq_lens_cpu[i]
-            batch.req_to_token_pool.req_to_token[req.req_pool_idx][
-                seq_len - input_len : seq_len
-            ] = batch.out_cache_loc[pt : pt + input_len]
             pt += input_len
             i += 1
-        assert pt == batch.out_cache_loc.shape[0]
 
-        self.positions = torch.empty_like(self.verified_id)
-        new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
+        self.positions = torch.empty_like(self.verified_id, dtype=torch.long)
+        new_verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
         self.accept_length.add_(1)
 
         create_extend_spec_info[(self.accept_length.numel(),)](
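
Note on the hunk above (editor's illustration, not part of the published diff): the rewritten prepare_for_extend no longer allocates request and token slots itself; it asserts that the batch was pre-allocated, then shifts each request's extend segment left by one token and appends that request's verified token, where the old code only handled a single request (assert len(batch.extend_lens) == 1). A minimal sketch of the per-request shift, with toy values:

    import torch

    # One request's extend segment and its verified (bonus) token; toy values.
    input_ids = torch.tensor([10, 11, 12, 13])
    verified_id = torch.tensor(99)

    # Drop the first token and append the verified token, as the new
    # prepare_for_extend does for each request in the batch.
    shifted = torch.concat((input_ids[1:], verified_id.reshape(1)))
    print(shifted)  # tensor([11, 12, 13, 99])
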
@@ -117,14 +107,22 @@ class EagleDraftInput:
         self,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
+        paged_kernel_lens_sum: int,
         req_to_token: torch.Tensor,
     ):
         bs = self.accept_length.numel()
+        keep_indices = torch.tensor(self.keep_indices, device=req_pool_indices.device)
+        req_pool_indices = req_pool_indices[keep_indices]
+        assert req_pool_indices.shape[0] == bs
+        assert req_pool_indices.shape[0] == paged_kernel_lens.shape[0]
+
         qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
 
         cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
         cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+
+        # TODO: replace cum_kv_seq_len[-1] with paged_kernel_lens_sum to avoid the device sync.
         kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
 
         create_flashinfer_kv_indices_triton[(bs,)](
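
Note on the TODO above (editor's illustration, not part of the published diff): cum_kv_seq_len[-1] is a 0-d CUDA tensor, so using it as an allocation size forces a device-to-host copy of the scalar, while the newly threaded paged_kernel_lens_sum is a host-side Python int; the EagleVerifyInput path further down already sizes kv_indices from it without that sync. A minimal sketch of the difference, with toy values:

    import torch

    lens = torch.tensor([3, 5, 2], device="cuda")  # per-request KV lengths
    lens_sum = 10                                  # same total, tracked on the host

    # Sizing from the device tensor blocks until the scalar reaches the CPU:
    total = torch.cumsum(lens, dim=0)[-1]          # 0-d CUDA tensor
    a = torch.empty(total, dtype=torch.int32, device="cuda")  # implicit sync

    # Sizing from the host-side int issues no device read at all:
    b = torch.empty(lens_sum, dtype=torch.int32, device="cuda")
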
@@ -162,7 +160,21 @@ class EagleDraftInput:
         self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
 
 
-@dataclasses.dataclass
+@dataclass
+class EagleVerifyOutput:
+    # Draft input batch
+    draft_input: EagleDraftInput
+    # Logit outputs from target worker
+    logits_output: LogitsProcessorOutput
+    # Accepeted token ids including the bonus token
+    verified_id: torch.Tensor
+    # Accepeted token length per sequence in a batch in CPU.
+    accept_length_per_req_cpu: List[int]
+    # Accepeted indices from logits_output.next_token_logits
+    accepeted_indices_cpu: List[int]
+
+
+@dataclass
 class EagleVerifyInput:
     draft_token: torch.Tensor
     custom_mask: torch.Tensor
@@ -267,6 +279,7 @@ class EagleVerifyInput:
         self,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
+        paged_kernel_lens_sum: int,
         req_to_token: torch.Tensor,
     ):
         batch_size = len(req_pool_indices)
@@ -285,7 +298,11 @@ class EagleVerifyInput:
         paged_kernel_lens = paged_kernel_lens + self.draft_token_num
         cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
 
-        kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum + self.draft_token_num * batch_size,
+            dtype=torch.int32,
+            device="cuda",
+        )
 
         create_flashinfer_kv_indices_triton[(batch_size,)](
             req_to_token,
@@ -298,7 +315,21 @@
         )
         return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask
 
-    def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor:
+    def verify(
+        self,
+        batch: ScheduleBatch,
+        logits_output: torch.Tensor,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+    ) -> torch.Tensor:
+        """WARNING: This API in-place modifies the states of logits_output
+
+        Verify and find accepted tokens based on logits output and batch
+        (which contains spec decoding information).
+
+        This API updates values inside logits_output based on the accepted
+        tokens. I.e., logits_output.next_token_logits only contains
+        accepeted token logits.
+        """
         draft_token = torch.cat(
             [self.draft_token, torch.full([1], -1, dtype=torch.int32, device="cuda")],
             dim=-1,
@@ -367,7 +398,6 @@
 
         new_accept_index = []
         unfinished_index = []
-        finished_extend_len = {}  # {rid:accept_length + 1}
         accept_index_cpu = accept_index.tolist()
         predict_cpu = predict.tolist()
         has_finished = False
@@ -382,7 +412,6 @@
                 id = predict_cpu[idx]
                 # if not found_finished:
                 req.output_ids.append(id)
-                finished_extend_len[req.rid] = j + 1
                 req.check_finished()
                 if req.finished():
                     has_finished = True
@@ -400,11 +429,10 @@
             accept_index = accept_index[accept_index != -1]
         accept_length_cpu = accept_length.tolist()
         verified_id = predict[accept_index]
-
         evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
         evict_mask[accept_index] = False
         mem_need_free_idx = batch.out_cache_loc[evict_mask]
-        batch.token_to_kv_pool.free(mem_need_free_idx)
+        token_to_kv_pool_allocator.free(mem_need_free_idx)
         assign_req_to_token_pool[(bs,)](
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
@@ -427,20 +455,16 @@
         ]
         if has_finished:
             draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
-            draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
-                unfinished_index
-            ]
         else:
             draft_input.seq_lens_for_draft_extend = batch.seq_lens
-            draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
-
-        logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
-        return (
-            draft_input,
-            logits_output,
-            verified_id,
-            finished_extend_len,
-            accept_length_cpu,
+        batch.out_cache_loc = batch.out_cache_loc[new_accept_index]
+
+        return EagleVerifyOutput(
+            draft_input=draft_input,
+            logits_output=logits_output,
+            verified_id=verified_id,
+            accept_length_per_req_cpu=accept_length_cpu,
+            accepeted_indices_cpu=accept_index,
         )
 
 
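Note on the hunk above (editor's illustration, not part of the published diff): verify now returns the EagleVerifyOutput dataclass introduced earlier instead of a five-element tuple, drops finished_extend_len, and takes the KV-pool allocator explicitly. A hypothetical call-site sketch; the variable names are illustrative, though eagle_worker.py (also changed in this release) is the real caller:

    # Before (0.4.3.post2): positional unpacking
    # (draft_input, logits_output, verified_id,
    #  finished_extend_len, accept_length_cpu) = spec_info.verify(batch, logits_output)

    # After (0.4.3.post3): named fields plus the explicit allocator argument
    res = spec_info.verify(batch, logits_output, token_to_kv_pool_allocator)
    draft_input = res.draft_input
    verified_id = res.verified_id
    accept_length_cpu = res.accept_length_per_req_cpu
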
@@ -456,6 +480,18 @@ def eagle_verify_retrive(
     draft_token_num: tl.constexpr,
     max_len_upper: tl.constexpr,
 ):
+    """
+    Args:
+        retrive_index: Pointer to indices of draft tokens
+        accept_mask: Mask indicating which tokens were accepted
+        retrive_cum_len: Cumulative lengths of token sequences in a batch
+        accept_index (out): Accept token indices
+        accept_length (out): Length of accepted tokens per sequence in a batch
+        extract_index (out): Index for last accepted tokens
+        max_len: Maximum length in a batch
+        draft_token_num: Number of tokens speculatively generated
+        max_len_upper: An upper bound for token sequence length
+    """
     pid = tl.program_id(axis=0)
 
     retrive_end = tl.load(retrive_cum_len + pid + 1)
@@ -649,7 +685,7 @@ def generate_draft_decode_kv_indices(
         tl.store(kv_indptr + zid, base + zid * iters)
 
 
-@torch.compile
+@torch.compile(dynamic=True)
 def select_top_k_tokens(
     i: int,
     topk_p: torch.Tensor,
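
Note on the hunk above (editor's illustration, not part of the published diff): dynamic=True makes torch.compile trace with symbolic shapes up front rather than specializing on the first shapes it sees, which avoids recompilation as the batch size varies across decode steps. A minimal sketch with an unrelated toy function:

    import torch

    @torch.compile(dynamic=True)
    def double(x: torch.Tensor) -> torch.Tensor:
        return x * 2.0

    double(torch.randn(4, 8))   # compiled once with a symbolic batch dimension
    double(torch.randn(16, 8))  # new batch size reuses the same compiled graph
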
@@ -671,13 +707,11 @@
             .unsqueeze(0)
             .repeat(topk_p.shape[0], 1),  # shape: (b, topk + 1)
         )
-
     else:
         # The later decode steps
         expand_scores = torch.mul(
             scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
         )  # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
-
         topk_cs_p, topk_cs_index = fast_topk(
             expand_scores.flatten(start_dim=1), topk, dim=-1
         )  # (b, topk)