sglang 0.4.2__py3-none-any.whl → 0.4.2.post2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (85)
  1. sglang/srt/constrained/outlines_backend.py +9 -1
  2. sglang/srt/custom_op.py +40 -0
  3. sglang/srt/entrypoints/engine.py +2 -2
  4. sglang/srt/layers/activation.py +10 -5
  5. sglang/srt/layers/attention/flashinfer_backend.py +284 -39
  6. sglang/srt/layers/attention/triton_backend.py +71 -7
  7. sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
  8. sglang/srt/layers/attention/triton_ops/prefill_attention.py +6 -0
  9. sglang/srt/layers/attention/vision.py +243 -40
  10. sglang/srt/layers/layernorm.py +1 -5
  11. sglang/srt/layers/moe/ep_moe/layer.py +1 -3
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
  20. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -11
  21. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
  22. sglang/srt/layers/moe/topk.py +4 -0
  23. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  24. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  26. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  30. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  32. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  33. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  34. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  36. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  40. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  42. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  44. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/fp8.py +7 -0
  46. sglang/srt/layers/quantization/fp8_kernel.py +140 -2
  47. sglang/srt/layers/rotary_embedding.py +29 -15
  48. sglang/srt/layers/sampler.py +9 -6
  49. sglang/srt/lora/backend/__init__.py +8 -0
  50. sglang/srt/lora/backend/base_backend.py +95 -0
  51. sglang/srt/lora/backend/flashinfer_backend.py +91 -0
  52. sglang/srt/lora/backend/triton_backend.py +61 -0
  53. sglang/srt/lora/lora.py +127 -112
  54. sglang/srt/lora/lora_manager.py +50 -18
  55. sglang/srt/lora/triton_ops/__init__.py +5 -0
  56. sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
  57. sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
  58. sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
  59. sglang/srt/managers/image_processor.py +77 -38
  60. sglang/srt/managers/scheduler.py +17 -3
  61. sglang/srt/mem_cache/base_prefix_cache.py +4 -0
  62. sglang/srt/mem_cache/chunk_cache.py +3 -0
  63. sglang/srt/mem_cache/radix_cache.py +30 -1
  64. sglang/srt/model_executor/cuda_graph_runner.py +77 -80
  65. sglang/srt/model_executor/forward_batch_info.py +58 -59
  66. sglang/srt/model_executor/model_runner.py +2 -2
  67. sglang/srt/models/minicpmv.py +129 -76
  68. sglang/srt/models/mllama.py +16 -56
  69. sglang/srt/models/qwen2.py +4 -1
  70. sglang/srt/models/qwen2_vl.py +19 -9
  71. sglang/srt/server_args.py +19 -2
  72. sglang/srt/speculative/build_eagle_tree.py +4 -2
  73. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
  74. sglang/srt/speculative/eagle_utils.py +361 -372
  75. sglang/srt/speculative/eagle_worker.py +177 -45
  76. sglang/srt/utils.py +7 -2
  77. sglang/test/runners.py +2 -0
  78. sglang/utils.py +42 -0
  79. sglang/version.py +1 -1
  80. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/METADATA +16 -7
  81. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/RECORD +84 -45
  82. sglang/srt/layers/custom_op_util.py +0 -25
  83. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/LICENSE +0 -0
  84. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/WHEEL +0 -0
  85. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,6 @@
  from __future__ import annotations
 
+ import dataclasses
  from typing import TYPE_CHECKING, List
 
  import torch
@@ -9,201 +10,33 @@ import triton.language as tl
  from sglang.srt.layers.attention.flashinfer_backend import (
      create_flashinfer_kv_indices_triton,
  )
- from sglang.srt.model_executor.forward_batch_info import ForwardMode
+ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
  from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
- from sglang.srt.speculative.spec_info import SpecInfo
 
  if TYPE_CHECKING:
      from sglang.srt.managers.schedule_batch import ScheduleBatch
-     from sglang.srt.server_args import ServerArgs
 
 
- @triton.jit
- def eagle_verify_retrive(
-     retrive_index,
-     accept_mask,
-     retrive_cum_len,
-     accept_index,
-     accept_length,
-     extract_index,
-     max_len: tl.constexpr,
-     draft_token_num: tl.constexpr,
-     max_len_upper: tl.constexpr,
- ):
-     pid = tl.program_id(axis=0)
-
-     retrive_end = tl.load(retrive_cum_len + pid + 1)
-     retrive_start = tl.load(retrive_cum_len + pid)
-     retrive_len = retrive_end - retrive_start
-     accept_ptr = accept_mask + retrive_start
-     accept_offset = tl.arange(0, draft_token_num)
-     accept_load_mask = accept_offset < retrive_len
-     accept_len_list = tl.load(
-         accept_ptr + accept_offset, mask=accept_load_mask, other=-1
-     )
-
-     accept_len = tl.max(accept_len_list)
-     max_index = tl.argmax(accept_len_list, axis=0, tie_break_left=True)
-     # triton is not support argmax with tie_break_right, so I need implement it by some way
-     mask_max = accept_len_list == accept_len
-
-     count_mask = tl.full(shape=[draft_token_num], value=0, dtype=tl.int32)
-     count = tl.sum(tl.where(mask_max, 1, count_mask))
-     if count > 1:
-         index = tl.arange(0, draft_token_num)
-         mask_left = index != max_index
-         remained_index = tl.where(mask_max and mask_left, index, 0)
-         max_index = tl.max(remained_index)
-
-     tl.store(accept_length + pid, accept_len)
-     retrive_index_ptr = retrive_index + (retrive_start + max_index) * max_len
-     retrive_offset = tl.arange(0, max_len_upper)
-     retrive_load_mask = retrive_offset < accept_len + 1
-     data = tl.load(retrive_index_ptr + retrive_offset, mask=retrive_load_mask)
-
-     tl.store(
-         accept_index + pid * max_len + retrive_offset, data, mask=retrive_load_mask
-     )
-
-     extract_load_ptr = accept_index + pid * max_len + accept_len
-     if accept_len == max_len - 1:
-         extract_data = tl.load(extract_load_ptr - 1)
-         tl.store(extract_index + pid * 2, extract_data)
-         extract_data = tl.load(extract_load_ptr)
-         tl.store(extract_index + pid * 2 + 1, extract_data)
-
-     else:
-         extract_data = tl.load(extract_load_ptr)
-         tl.store(extract_index + pid * 2, extract_data)
-
-
- @triton.jit
- def create_extend_spec_info(
-     verified_id,
-     seq_len,
-     accept_len,
-     accept_len_cum,
-     positions,
-     new_verified_id,
-     accept_len_upper: tl.constexpr,
- ):
-     pid = tl.program_id(axis=0)
-     offset = 0 if pid == 0 else tl.load(accept_len_cum + pid - 1)
-     seq_length = tl.load(seq_len + pid)
-     accept_length = tl.load(accept_len + pid)
-     positions_ptr = positions + offset
-     data = tl.arange(0, accept_len_upper)
-     mask = data < accept_length
-     tl.store(positions_ptr + data, seq_length - accept_length + data, mask)
-
-     offset = tl.load(accept_len_cum + pid) - 1
-     verified_id_data = tl.load(verified_id + offset)
-     tl.store(new_verified_id + pid, verified_id_data)
-
-
- @triton.jit
- def assign_req_to_token_pool(
-     req_pool_indices,
-     req_to_token,
-     start_offset,
-     end_offset,
-     out_cache_loc,
-     pool_len: tl.constexpr,
-     bs_upper: tl.constexpr,
- ):
-     BLOCK_SIZE: tl.constexpr = 32
-     pid = tl.program_id(axis=0)
-     kv_start = tl.load(start_offset + pid)
-     kv_end = tl.load(end_offset + pid)
-     token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
-
-     length_offset = tl.arange(0, bs_upper)
-     start = tl.load(start_offset + length_offset, mask=length_offset < pid)
-     end = tl.load(end_offset + length_offset, mask=length_offset < pid)
-     out_offset = tl.sum(end - start, axis=0)
-
-     out_cache_ptr = out_cache_loc + out_offset
-
-     save_offset = tl.arange(0, BLOCK_SIZE) + kv_start
-     load_offset = tl.arange(0, BLOCK_SIZE)
-
-     num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-     for _ in range(num_loop):
-         mask = save_offset < kv_end
-         data = tl.load(out_cache_ptr + load_offset, mask=mask)
-         tl.store(token_pool + save_offset, data, mask=mask)
-         save_offset += BLOCK_SIZE
-         load_offset += BLOCK_SIZE
+ @dataclasses.dataclass
+ class EagleDraftInput:
+     # The inputs for decode
+     # shape: (b, topk)
+     topk_p: torch.Tensor = None
+     topk_index: torch.Tensor = None
+     # shape: (b, hidden_size)
+     hidden_states: torch.Tensor = None
+     capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL
 
+     # Inputs for extend
+     # shape: (b,)
+     verified_id: torch.Tensor = None
+     accept_length: torch.Tensor = None
+     accept_length_cpu: List[int] = None
 
- @triton.jit
- def generate_draft_decode_kv_indices(
-     req_pool_indices,
-     req_to_token,
-     paged_kernel_lens,
-     kv_indices,
-     iters: tl.constexpr,
-     topk: tl.constexpr,
-     pool_len: tl.constexpr,
-     bs_upper: tl.constexpr,
-     iter_upper: tl.constexpr,
- ):
-     BLOCK_SIZE: tl.constexpr = 128
-     bid = tl.program_id(axis=0)
-     topk_id = tl.program_id(axis=1)
-
-     load_offset = tl.arange(0, bs_upper)
-     seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid)
-     seq_len = tl.load(paged_kernel_lens + bid)
-     cum_seq_len = tl.sum(seq_lens)
-
-     kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
-     kv_ptr = kv_indices + kv_offset
-     token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len
-
-     kv_offset = tl.arange(0, BLOCK_SIZE)
-     num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
-     for _ in range(num_loop):
-         mask = kv_offset < seq_len
-         data = tl.load(token_pool_ptr + kv_offset, mask=mask)
-         tl.store(kv_ptr + kv_offset, data, mask=mask)
-         kv_offset += BLOCK_SIZE
-
-     extend_offset = tl.arange(0, iter_upper)
-     extend_data = tl.load(
-         token_pool_ptr + seq_len + tl.arange(0, iter_upper) * topk + topk_id,
-         mask=extend_offset < iters,
-     )
-     tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)
-
-
- class EAGLEDraftInput(SpecInfo):
-     def __init__(self):
-         self.prev_mode = ForwardMode.DECODE
-
-         self.scores: torch.Tensor = None
-         self.score_list: List[torch.Tensor] = []
-         self.token_list: List[torch.Tensor] = []
-         self.origin_score_list: List[torch.Tensor] = []  # used for sampling
-         self.parents_list: List[torch.Tensor] = []
-         self.cache_list: List[torch.Tenor] = []
-         self.iter = 0
-
-         # shape: (b, hidden_size)
-         self.hidden_states: torch.Tensor = None
-         # shape: (b,)
-         self.verified_id: torch.Tensor = None
-         # shape: (b, vocab_size)
-         self.sample_output: torch.Tensor = None
-
-         self.positions: torch.Tensor = None
-         self.accept_length: torch.Tensor = None
-         self.accept_length_cpu: List[int] = None
-
-     def load_server_args(self, server_args: ServerArgs):
-         self.topk: int = server_args.speculative_eagle_topk
-         self.num_verify_token: int = server_args.speculative_num_draft_tokens
-         self.spec_steps = server_args.speculative_num_steps
+     # Inputs for the attention backends
+     # shape: (b + 1,)
+     kv_indptr: torch.Tensor = None
+     kv_indices: torch.Tensor = None
 
      def prepare_for_extend(self, batch: ScheduleBatch):
          req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
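The hunk above replaces the stateful `EAGLEDraftInput(SpecInfo)` class, whose fields were accumulated across `__init__` and `load_server_args`, with a plain `dataclasses.dataclass` whose tensor fields default to `None` and are filled in at different stages of the draft/verify loop. A minimal sketch of that pattern, using hypothetical names rather than sglang's actual API:

import dataclasses
from typing import List, Optional

import torch


@dataclasses.dataclass
class DraftState:
    # Per-request tensors, populated at different stages of the draft loop.
    topk_p: Optional[torch.Tensor] = None  # shape: (b, topk)
    verified_id: Optional[torch.Tensor] = None  # shape: (b,)
    accept_length_cpu: Optional[List[int]] = None

    def filter(self, keep: int) -> None:
        # Mirror filter_batch: keep the first `keep` requests by slicing
        # the leading (batch) dimension of every populated field.
        self.topk_p = self.topk_p[:keep]
        self.verified_id = self.verified_id[:keep]


state = DraftState(topk_p=torch.rand(4, 8), verified_id=torch.arange(4))
state.filter(2)
assert state.topk_p.shape == (2, 8) and state.verified_id.tolist() == [0, 1]

Defaulting every field to `None` keeps construction cheap and lets each pipeline stage populate only the fields it owns.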
@@ -231,95 +64,12 @@ class EAGLEDraftInput(SpecInfo):
          assert len(batch.extend_lens) == 1
          batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
 
-     def filter_batch(
-         self,
-         new_indices: torch.Tensor,
-     ):
-         self.sample_output = self.sample_output[: len(new_indices)]
-         self.hidden_states = self.hidden_states[: len(new_indices)]
-         self.verified_id = self.verified_id[: len(new_indices)]
-
-     def prepare_for_decode(self, batch: ScheduleBatch):
-         prob = self.sample_output  # shape: (b * top_k, vocab) or (b, vocab)
-         top = torch.topk(prob, self.topk, dim=-1)
-         topk_index, topk_p = (
-             top.indices,
-             top.values,
-         )  # shape: (b * top_k, top_k) or (b, top_k)
-
-         if self.prev_mode.is_decode():
-             scores = torch.mul(
-                 self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk)
-             )  # (b, topk, 1) x (b, topk, topk) -> (b, topk, topk)
-             topk_cs = torch.topk(
-                 scores.flatten(start_dim=1), self.topk, dim=-1
-             )  # (b, topk)
-             topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values
-
-             selected_input_index = topk_cs_index.flatten() // self.topk + torch.arange(
-                 0, batch.batch_size() * self.topk, step=self.topk, device="cuda"
-             ).repeat_interleave(self.topk)
-
-             batch.spec_info.hidden_states = batch.spec_info.hidden_states[
-                 selected_input_index, :
-             ]
-
-             topk_index = topk_index.reshape(-1, self.topk**2)
-             batch.input_ids = torch.gather(
-                 topk_index, index=topk_cs_index, dim=1
-             ).flatten()
-             batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
-
-             self.scores = topk_cs_p
-             self.score_list.append(scores)  # (b, topk, topk)
-             self.token_list.append(topk_index)  # (b, topk * topk)
-             self.origin_score_list.append(topk_p.reshape(topk_index.shape))
-             self.parents_list.append(
-                 topk_cs_index + (self.topk**2 * (self.iter - 1) + self.topk)
-             )  # shape: (b, topk)
-         else:
-             # ForwardMode.EXTEND or ForwardMode.DRAFT_EXTEND
-             batch.spec_info.hidden_states = (
-                 batch.spec_info.hidden_states.repeat_interleave(self.topk, dim=0)
-             )
-
-             batch.input_ids = topk_index.flatten()
-             batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel())
-
-             self.scores = topk_p  # shape: (b, topk)
-             self.score_list.append(topk_p.unsqueeze(1))  # shape: (b, 1, topk)
-             self.token_list.append(topk_index)  # shape: (b, topk)
-             self.origin_score_list.append(topk_p)
-             self.parents_list.append(
-                 torch.arange(-1, self.topk, dtype=torch.long, device="cuda")
-                 .unsqueeze(0)
-                 .repeat(self.scores.shape[0], 1)
-             )  # shape: (b, topk + 1)
-         self.cache_list.append(batch.out_cache_loc)
-         self.positions = (
-             batch.seq_lens[:, None]
-             + torch.full(
-                 [1, self.topk], fill_value=self.iter, device="cuda", dtype=torch.long
-             )
-         ).flatten()
-
-         bs = len(batch.seq_lens)
-         assign_req_to_token_pool[(bs,)](
-             batch.req_pool_indices,
-             batch.req_to_token_pool.req_to_token,
-             batch.seq_lens + self.topk * self.iter,
-             batch.seq_lens + self.topk * (self.iter + 1),
-             batch.out_cache_loc,
-             batch.req_to_token_pool.req_to_token.shape[1],
-             triton.next_power_of_2(bs),
-         )
-         self.iter += 1
-
-     def prepare_extend_after_decode(self, batch: ScheduleBatch):
+     def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
          batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
          accept_length_cpu = batch.spec_info.accept_length_cpu
          batch.extend_lens = [x + 1 for x in accept_length_cpu]
          batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
+         batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
          seq_lens_cpu = batch.seq_lens.tolist()
 
          pt = 0
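In this refactor, `prepare_extend_after_decode` receives `speculative_num_steps` explicitly rather than reading the value that `load_server_args` used to cache on the object. It still launches `create_extend_spec_info` (relocated to the bottom of the file in this diff), where each program derives its write offset into the flattened `positions` buffer from an exclusive prefix sum over accepted lengths. The same indexing in plain PyTorch (illustrative values only):

import torch

# Per-request accepted draft tokens and current sequence lengths.
accept_length = torch.tensor([2, 1, 3])
seq_lens = torch.tensor([10, 7, 12])
cum = torch.cumsum(accept_length, dim=0)

positions = torch.empty(int(cum[-1]), dtype=torch.long)
for pid in range(len(accept_length)):
    # Exclusive prefix sum gives each request's write offset, as in
    # `offset = 0 if pid == 0 else tl.load(accept_len_cum + pid - 1)`.
    offset = 0 if pid == 0 else int(cum[pid - 1])
    n, s = int(accept_length[pid]), int(seq_lens[pid])
    # Each request fills in the last n positions of its sequence.
    positions[offset : offset + n] = torch.arange(s - n, s)

assert positions.tolist() == [8, 9, 6, 9, 10, 11]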
@@ -348,86 +98,13 @@ class EAGLEDraftInput(SpecInfo):
              torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
              self.positions,
              new_verified_id,
-             triton.next_power_of_2(self.spec_steps + 1),
+             triton.next_power_of_2(speculative_num_steps + 1),
          )
 
          batch.seq_lens_sum = sum(seq_lens_cpu)
          batch.input_ids = self.verified_id
          self.verified_id = new_verified_id
 
-     def prepare_for_verify(self, batch: ScheduleBatch):
-         score_list = torch.cat(self.score_list, dim=1).flatten(
-             1
-         )  # b, n, topk; n = 1 + (self.iter - 1) * self.topk
-         ss_token_list = torch.cat(
-             self.token_list, dim=1
-         )  # b, (self.topk + (self.iter - 1) * self.topk)
-         origin_token_list = torch.cat(self.origin_score_list, dim=1)
-         top_scores = torch.topk(score_list, self.num_verify_token - 1, dim=-1)
-         top_scores_index = top_scores.indices
-         top_scores_index = torch.sort(top_scores_index).values
-
-         draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
-         scores = torch.gather(origin_token_list, index=top_scores_index, dim=1)
-         draft_tokens = torch.cat((self.verified_id.unsqueeze(1), draft_tokens), dim=1)
-         parent_list = torch.cat(self.parents_list[:-1], dim=1)
-
-         tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel(
-             parent_list,
-             top_scores_index,
-             batch.seq_lens,
-             self.topk,
-             self.iter - 1,
-             self.num_verify_token,
-         )
-
-         return EagleVerifyInput(
-             draft_tokens.flatten(),
-             scores.flatten(),
-             tree_mask,
-             position,
-             retrive_index,
-             retrive_cum_len,
-             self.num_verify_token,
-         )
-
-     def generate_attn_arg_decode(
-         self,
-         req_pool_indices: torch.Tensor,
-         paged_kernel_lens: torch.Tensor,
-         req_to_token: torch.Tensor,
-     ):
-         seq_num = req_pool_indices.numel()
-         bs = self.topk * req_pool_indices.numel()
-         seq_len = self.positions.reshape(-1).contiguous()
-
-         cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
-         cum_kv_seq_len[1:] = torch.cumsum(seq_len + 1, dim=0)
-         total_len = torch.sum(paged_kernel_lens).item()
-
-         kv_indices = torch.empty(
-             (total_len * self.topk + seq_num * self.iter * self.topk,),
-             dtype=torch.int32,
-             device="cuda",
-         )
-
-         generate_draft_decode_kv_indices[(req_pool_indices.numel(), self.topk)](
-             req_pool_indices,
-             req_to_token,
-             paged_kernel_lens,
-             kv_indices,
-             self.iter,
-             self.topk,
-             req_to_token.shape[1],
-             triton.next_power_of_2(seq_num),
-             triton.next_power_of_2(self.spec_steps),
-         )
-         return bs, kv_indices, cum_kv_seq_len
-
-     def clear_draft_cache(self, batch):
-         draft_cache = torch.cat(self.cache_list, dim=0)
-         batch.token_to_kv_pool.free(draft_cache)
-
      def generate_attn_arg_prefill(
          self,
          req_pool_indices: torch.Tensor,
@@ -454,12 +131,18 @@ class EAGLEDraftInput(SpecInfo):
 
          return kv_indices, cum_kv_seq_len, qo_indptr, None
 
-     def merge_batch(self, spec_info: EAGLEDraftInput):
+     def filter_batch(self, new_indices: torch.Tensor):
+         self.topk_p = self.topk_p[: len(new_indices)]
+         self.topk_index = self.topk_index[: len(new_indices)]
+         self.hidden_states = self.hidden_states[: len(new_indices)]
+         self.verified_id = self.verified_id[: len(new_indices)]
+
+     def merge_batch(self, spec_info: EagleDraftInput):
          if self.hidden_states is None:
              self.hidden_states = spec_info.hidden_states
              self.verified_id = spec_info.verified_id
-             self.sample_output = spec_info.sample_output
-             self.prev_mode = spec_info.prev_mode
+             self.topk_p = spec_info.topk_p
+             self.topk_index = spec_info.topk_index
              return
          if spec_info.hidden_states is None:
              return
@@ -467,32 +150,68 @@ class EAGLEDraftInput(SpecInfo):
              [self.hidden_states, spec_info.hidden_states], axis=0
          )
          self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
-         self.sample_output = torch.cat([self.sample_output, spec_info.sample_output])
-
-
- class EagleVerifyInput(SpecInfo):
-     def __init__(
-         self,
-         draft_token: torch.Tensor,
-         draft_score: torch.Tensor,
-         tree_mask: torch.Tensor,
-         positions: torch.Tensor,
-         retrive_index: torch.Tensor,
-         retrive_cum_len: torch.Tensor,
-         draft_token_num: int,
+         self.topk_p = torch.cat([self.topk_p, spec_info.topk_p])
+         self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
+
+
+ @dataclasses.dataclass
+ class EagleVerifyInput:
+     draft_token: torch.Tensor
+     custom_mask: torch.Tensor
+     positions: torch.Tensor
+     retrive_index: torch.Tensor
+     retrive_cum_len: torch.Tensor
+     draft_token_num: int
+     capture_hidden_mode: CaptureHiddenMode
+
+     @classmethod
+     def create(
+         cls,
+         verified_id: torch.Tensor,
+         score_list: List[torch.Tensor],
+         token_list: List[torch.Tensor],
+         parents_list: List[torch.Tensor],
+         seq_lens: torch.Tensor,
+         seq_lens_sum: int,
+         topk: int,
+         spec_steps: int,
+         num_verify_token: int,
      ):
-         self.draft_token = draft_token
-         self.draft_score = draft_score
-         self.custom_mask = tree_mask
-         self.positions = positions
-         self.retrive_index = retrive_index
-         self.retrive_cum_len = retrive_cum_len
-         self.draft_token_num = draft_token_num
+         score_list = torch.cat(score_list, dim=1).flatten(
+             1
+         )  # b, n, topk; n = 1 + (num_steps - 1) * self.topk
+         ss_token_list = torch.cat(
+             token_list, dim=1
+         )  # b, (self.topk + (num_steps - 1) * self.topk)
+         top_scores = torch.topk(score_list, num_verify_token - 1, dim=-1)
+         top_scores_index = top_scores.indices
+         top_scores_index = torch.sort(top_scores_index).values
+         draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
+         draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1)
+         parent_list = torch.cat(parents_list[:-1], dim=1)
+         tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel(
+             parent_list,
+             top_scores_index,
+             seq_lens,
+             seq_lens_sum,
+             topk,
+             spec_steps,
+             num_verify_token,
+         )
+         return cls(
+             draft_tokens.flatten(),
+             tree_mask,
+             position,
+             retrive_index,
+             retrive_cum_len,
+             num_verify_token,
+             CaptureHiddenMode.FULL,
+         )
 
      def prepare_for_verify(self, batch: ScheduleBatch):
          batch.input_ids = self.draft_token
          batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
-         bs = batch.seq_lens.numel()
+         bs = batch.batch_size()
          assign_req_to_token_pool[(bs,)](
              batch.req_pool_indices,
              batch.req_to_token_pool.req_to_token,
@@ -573,7 +292,6 @@ class EagleVerifyInput(SpecInfo):
              triton.next_power_of_2(max_draft_len),
          )
 
-         draft_input = EAGLEDraftInput()
          new_accept_index = []
          unfinished_index = []
          finished_extend_len = {}  # {rid: accept_length + 1}
@@ -625,18 +343,23 @@ class EagleVerifyInput(SpecInfo):
          )
          batch.seq_lens.add_(accept_length + 1)
 
+         draft_input = EagleDraftInput()
          if len(new_accept_index) > 0:
              new_accept_index = torch.tensor(new_accept_index, device="cuda")
-             draft_input.verified_id = predict[new_accept_index]
              draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
+             draft_input.verified_id = predict[new_accept_index]
              draft_input.accept_length = accept_length[unfinished_index]
              draft_input.accept_length_cpu = [
                  accept_length_cpu[i] for i in unfinished_index
              ]
              if has_finished:
                  draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
+                 draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
+                     unfinished_index
+                 ]
              else:
                  draft_input.seq_lens_for_draft_extend = batch.seq_lens
+                 draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
 
          logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
          return (
@@ -646,3 +369,269 @@ class EagleVerifyInput(SpecInfo):
              finished_extend_len,
              accept_length_cpu,
          )
+
+
+ @triton.jit
+ def eagle_verify_retrive(
+     retrive_index,
+     accept_mask,
+     retrive_cum_len,
+     accept_index,
+     accept_length,
+     extract_index,
+     max_len: tl.constexpr,
+     draft_token_num: tl.constexpr,
+     max_len_upper: tl.constexpr,
+ ):
+     pid = tl.program_id(axis=0)
+
+     retrive_end = tl.load(retrive_cum_len + pid + 1)
+     retrive_start = tl.load(retrive_cum_len + pid)
+     retrive_len = retrive_end - retrive_start
+     accept_ptr = accept_mask + retrive_start
+     accept_offset = tl.arange(0, draft_token_num)
+     accept_load_mask = accept_offset < retrive_len
+     accept_len_list = tl.load(
+         accept_ptr + accept_offset, mask=accept_load_mask, other=-1
+     )
+
+     accept_len = tl.max(accept_len_list)
+     max_index = tl.argmax(accept_len_list, axis=0, tie_break_left=True)
+     # triton is not support argmax with tie_break_right, so I need implement it by some way
+     mask_max = accept_len_list == accept_len
+
+     count_mask = tl.full(shape=[draft_token_num], value=0, dtype=tl.int32)
+     count = tl.sum(tl.where(mask_max, 1, count_mask))
+     if count > 1:
+         index = tl.arange(0, draft_token_num)
+         mask_left = index != max_index
+         remained_index = tl.where(mask_max and mask_left, index, 0)
+         max_index = tl.max(remained_index)
+
+     tl.store(accept_length + pid, accept_len)
+     retrive_index_ptr = retrive_index + (retrive_start + max_index) * max_len
+     retrive_offset = tl.arange(0, max_len_upper)
+     retrive_load_mask = retrive_offset < accept_len + 1
+     data = tl.load(retrive_index_ptr + retrive_offset, mask=retrive_load_mask)
+
+     tl.store(
+         accept_index + pid * max_len + retrive_offset, data, mask=retrive_load_mask
+     )
+
+     extract_load_ptr = accept_index + pid * max_len + accept_len
+     if accept_len == max_len - 1:
+         extract_data = tl.load(extract_load_ptr - 1)
+         tl.store(extract_index + pid * 2, extract_data)
+         extract_data = tl.load(extract_load_ptr)
+         tl.store(extract_index + pid * 2 + 1, extract_data)
+
+     else:
+         extract_data = tl.load(extract_load_ptr)
+         tl.store(extract_index + pid * 2, extract_data)
+
+
+ @triton.jit
+ def create_extend_spec_info(
+     verified_id,
+     seq_len,
+     accept_len,
+     accept_len_cum,
+     positions,
+     new_verified_id,
+     accept_len_upper: tl.constexpr,
+ ):
+     pid = tl.program_id(axis=0)
+     offset = 0 if pid == 0 else tl.load(accept_len_cum + pid - 1)
+     seq_length = tl.load(seq_len + pid)
+     accept_length = tl.load(accept_len + pid)
+     positions_ptr = positions + offset
+     data = tl.arange(0, accept_len_upper)
+     mask = data < accept_length
+     tl.store(positions_ptr + data, seq_length - accept_length + data, mask)
+
+     offset = tl.load(accept_len_cum + pid) - 1
+     verified_id_data = tl.load(verified_id + offset)
+     tl.store(new_verified_id + pid, verified_id_data)
+
+
+ @triton.jit
+ def assign_req_to_token_pool(
+     req_pool_indices,
+     req_to_token,
+     start_offset,
+     end_offset,
+     out_cache_loc,
+     pool_len: tl.constexpr,
+     bs_upper: tl.constexpr,
+ ):
+     BLOCK_SIZE: tl.constexpr = 32
+     pid = tl.program_id(axis=0)
+     kv_start = tl.load(start_offset + pid)
+     kv_end = tl.load(end_offset + pid)
+     token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
+
+     length_offset = tl.arange(0, bs_upper)
+     start = tl.load(start_offset + length_offset, mask=length_offset < pid)
+     end = tl.load(end_offset + length_offset, mask=length_offset < pid)
+     out_offset = tl.sum(end - start, axis=0)
+
+     out_cache_ptr = out_cache_loc + out_offset
+
+     save_offset = tl.arange(0, BLOCK_SIZE) + kv_start
+     load_offset = tl.arange(0, BLOCK_SIZE)
+
+     num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+     for _ in range(num_loop):
+         mask = save_offset < kv_end
+         data = tl.load(out_cache_ptr + load_offset, mask=mask)
+         tl.store(token_pool + save_offset, data, mask=mask)
+         save_offset += BLOCK_SIZE
+         load_offset += BLOCK_SIZE
+
+
+ @triton.jit
+ def assign_draft_cache_locs(
+     req_pool_indices,
+     req_to_token,
+     seq_lens,
+     out_cache_loc,
+     pool_len: tl.constexpr,
+     topk: tl.constexpr,
+     speculative_num_steps: tl.constexpr,
+ ):
+     BLOCK_SIZE: tl.constexpr = 32
+     pid = tl.program_id(axis=0)
+     kv_start = tl.load(seq_lens + pid)
+     kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
+     token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
+     out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
+
+     num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE)
+     for i in range(num_loop):
+         save_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + kv_start
+         load_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+         mask = save_offset < kv_end
+         data = tl.load(out_cache_ptr + load_offset, mask=mask)
+         tl.store(token_pool + save_offset, data, mask=mask)
+
+
+ @triton.jit
+ def generate_draft_decode_kv_indices(
+     req_pool_indices,
+     req_to_token,
+     paged_kernel_lens,
+     kv_indices,
+     kv_indptr,
+     positions,
+     num_seqs: tl.constexpr,
+     topk: tl.constexpr,
+     pool_len: tl.constexpr,
+     kv_indices_stride: tl.constexpr,
+     kv_indptr_stride: tl.constexpr,
+     bs_upper: tl.constexpr,
+     iter_upper: tl.constexpr,
+     num_tokens_upper: tl.constexpr,
+ ):
+     BLOCK_SIZE: tl.constexpr = 128
+     iters = tl.program_id(axis=0)
+     bid = tl.program_id(axis=1)
+     topk_id = tl.program_id(axis=2)
+
+     kv_indices += kv_indices_stride * iters
+     kv_indptr += kv_indptr_stride * iters
+     iters += 1
+
+     load_offset = tl.arange(0, bs_upper)
+     seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid)
+     seq_len = tl.load(paged_kernel_lens + bid)
+     cum_seq_len = tl.sum(seq_lens)
+
+     kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
+     kv_ptr = kv_indices + kv_offset
+     token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len
+
+     kv_offset = tl.arange(0, BLOCK_SIZE)
+     num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
+     for _ in range(num_loop):
+         mask = kv_offset < seq_len
+         data = tl.load(token_pool_ptr + kv_offset, mask=mask)
+         tl.store(kv_ptr + kv_offset, data, mask=mask)
+         kv_offset += BLOCK_SIZE
+
+     extend_offset = tl.arange(0, iter_upper)
+     extend_data = tl.load(
+         token_pool_ptr + seq_len + tl.arange(0, iter_upper) * topk + topk_id,
+         mask=extend_offset < iters,
+     )
+     tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)
+
+     # Update kv_indptr
+     bs_offset = tl.arange(0, num_tokens_upper)
+
+     zid = bid * topk + topk_id
+     if zid == 0:
+         zid = num_seqs * topk
+     positions = tl.load(positions + bs_offset, mask=bs_offset < zid)
+     base = tl.sum(positions)
+     tl.store(kv_indptr + zid, base + zid * iters)
+
+
+ @torch.compile
+ def select_top_k_tokens(
+     i: int,
+     topk_p: torch.Tensor,
+     topk_index: torch.Tensor,
+     hidden_states: torch.Tensor,
+     scores: torch.Tensor,
+     topk: int,
+ ):
+     if i == 0:
+         # The first step after extend
+         input_ids = topk_index.flatten()
+         hidden_states = hidden_states.repeat_interleave(topk, dim=0)
+         scores = topk_p  # shape: (b, topk)
+
+         tree_info = (
+             topk_p.unsqueeze(1),  # shape: (b, 1, topk)
+             topk_index,  # shape: (b, topk)
+             torch.arange(-1, topk, dtype=torch.long, device="cuda")
+             .unsqueeze(0)
+             .repeat(topk_p.shape[0], 1),  # shape: (b, topk + 1)
+         )
+
+     else:
+         # The later decode steps
+         expand_scores = torch.mul(
+             scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
+         )  # (b, topk, 1) x (b, topk, topk) -> (b, topk, topk)
+
+         topk_cs_p, topk_cs_index = fast_topk(
+             expand_scores.flatten(start_dim=1), topk, dim=-1
+         )  # (b, topk)
+         scores = topk_cs_p  # shape: (b, topk)
+
+         topk_index = topk_index.reshape(-1, topk**2)
+         input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()
+
+         selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
+             0, hidden_states.shape[0], step=topk, device="cuda"
+         ).repeat_interleave(topk)
+         hidden_states = hidden_states[selected_input_index, :]
+
+         tree_info = (
+             expand_scores,  # shape: (b, topk, topk)
+             topk_index,  # shape: (b, topk * topk)
+             topk_cs_index + (topk**2 * (i - 1) + topk),  # shape: (b, topk)
+         )
+
+     return input_ids, hidden_states, scores, tree_info
+
+
+ def fast_topk(values, topk, dim):
+     if topk == 1:
+         # Use max along the specified dimension to get both value and index
+         max_value, max_index = torch.max(values, dim=dim)
+         return max_value.unsqueeze(1), max_index.unsqueeze(1)
+     else:
+         # Use topk for efficiency with larger k values
+         return torch.topk(values, topk, dim=dim)
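`select_top_k_tokens` above performs one step of EAGLE's top-k tree expansion: at step i > 0, the running (b, topk) path scores are multiplied into the (b, topk, topk) per-node candidate probabilities, and a single top-k over all topk * topk children selects the surviving paths; integer division of the flat winner index by topk recovers each winner's parent path. `fast_topk` simply special-cases k == 1 with `torch.max`. A small self-contained check of both facts (the values are made up):

import torch

values = torch.tensor([[0.1, 0.7, 0.2], [0.5, 0.3, 0.9]])

# k == 1: torch.max returns the same value/index pair as torch.topk(..., 1).
max_v, max_i = torch.max(values, dim=-1)
top = torch.topk(values, 1, dim=-1)
assert torch.equal(max_v.unsqueeze(1), top.values)
assert torch.equal(max_i.unsqueeze(1), top.indices)

# One expansion step for topk = 2: joint path scores over topk * topk children.
scores = torch.tensor([[0.6, 0.4]])                # (b, topk) path scores
topk_p = torch.tensor([[[0.7, 0.3], [0.9, 0.1]]])  # (b, topk, topk) child probs
expand = scores.unsqueeze(2) * topk_p              # (b, topk, topk)
best_p, best_idx = torch.topk(expand.flatten(1), 2, dim=-1)
parents = best_idx // 2                            # which path each winner extends
assert torch.allclose(best_p, torch.tensor([[0.42, 0.36]]))
assert parents.tolist() == [[0, 1]]                # winners extend paths 0 and 1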