sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (63)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/srt/layers/attention/__init__.py +14 -5
  3. sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
  4. sglang/srt/layers/attention/flashinfer_backend.py +211 -81
  5. sglang/srt/layers/attention/torch_native_backend.py +1 -38
  6. sglang/srt/layers/attention/triton_backend.py +20 -11
  7. sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
  8. sglang/srt/layers/logits_processor.py +167 -212
  9. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  10. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
  11. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +187 -29
  31. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -6
  32. sglang/srt/layers/quantization/fp8.py +2 -2
  33. sglang/srt/layers/sampler.py +57 -21
  34. sglang/srt/layers/torchao_utils.py +17 -3
  35. sglang/srt/managers/io_struct.py +1 -2
  36. sglang/srt/managers/schedule_batch.py +26 -2
  37. sglang/srt/managers/schedule_policy.py +159 -90
  38. sglang/srt/managers/scheduler.py +62 -26
  39. sglang/srt/managers/tokenizer_manager.py +22 -20
  40. sglang/srt/managers/tp_worker.py +16 -4
  41. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  42. sglang/srt/model_executor/cuda_graph_runner.py +118 -73
  43. sglang/srt/model_executor/forward_batch_info.py +33 -8
  44. sglang/srt/model_executor/model_runner.py +63 -61
  45. sglang/srt/models/deepseek_v2.py +34 -7
  46. sglang/srt/models/grok.py +97 -26
  47. sglang/srt/openai_api/adapter.py +0 -17
  48. sglang/srt/openai_api/protocol.py +3 -3
  49. sglang/srt/sampling/sampling_batch_info.py +21 -0
  50. sglang/srt/sampling/sampling_params.py +9 -1
  51. sglang/srt/server.py +9 -5
  52. sglang/srt/server_args.py +108 -57
  53. sglang/srt/speculative/build_eagle_tree.py +347 -0
  54. sglang/srt/speculative/eagle_utils.py +618 -0 (new file; its full diff and a usage sketch follow this list)
  55. sglang/srt/speculative/eagle_worker.py +170 -0
  56. sglang/srt/speculative/spec_info.py +5 -0
  57. sglang/srt/utils.py +15 -2
  58. sglang/version.py +1 -1
  59. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/METADATA +9 -8
  60. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/RECORD +63 -39
  61. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/WHEEL +1 -1
  62. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/LICENSE +0 -0
  63. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/top_level.txt +0 -0
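The headline change in this release is the new speculative decoding path: the EAGLE files under sglang/srt/speculative/ plus the expanded server_args.py. The diff shown below is the new eagle_utils.py. As a minimal launch sketch (not part of this diff): the speculative_* keyword names below mirror the ServerArgs fields that eagle_utils.py reads (speculative_eagle_topk, speculative_num_draft_tokens, speculative_num_steps), while speculative_algorithm, speculative_draft_model_path, and the model choices are assumptions to be checked against the server_args.py changes in this release.

    # Hypothetical usage sketch; keyword names inferred from the ServerArgs
    # fields referenced in eagle_utils.py -- verify against server_args.py.
    import sglang as sgl

    runtime = sgl.Runtime(
        model_path="meta-llama/Llama-2-7b-chat-hf",  # example target model
        speculative_algorithm="EAGLE",  # assumed flag name
        speculative_draft_model_path="yuhuili/EAGLE-llama2-chat-7B",  # assumed
        speculative_num_steps=5,  # draft autoregressive steps per round
        speculative_eagle_topk=4,  # branching factor; EAGLEDraftInput asserts <= 10
        speculative_num_draft_tokens=8,  # tokens verified by the target model per round
    )
    runtime.shutdown()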
sglang/srt/speculative/eagle_utils.py (new file)
@@ -0,0 +1,618 @@
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, List
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from sglang.srt.layers.attention.flashinfer_backend import (
+     create_flashinfer_kv_indices_triton,
+ )
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+ from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
+ from sglang.srt.speculative.spec_info import SpecInfo
+
+ if TYPE_CHECKING:
+     from sglang.srt.layers.sampler import SampleOutput
+     from sglang.srt.managers.schedule_batch import ScheduleBatch
+     from sglang.srt.server_args import ServerArgs
+
+
+ # Per request, pick the draft path with the longest accepted prefix and write
+ # its token indices into accept_index / accept_length / extract_index.
+ @triton.jit
+ def eagle_verify_retrive(
+     retrive_index,
+     accept_mask,
+     retrive_cum_len,
+     accept_index,
+     accept_length,
+     extract_index,
+     max_len: tl.constexpr,
+     draft_token_num: tl.constexpr,
+     max_len_upper: tl.constexpr,
+ ):
+     pid = tl.program_id(axis=0)
+
+     retrive_end = tl.load(retrive_cum_len + pid + 1)
+     retrive_start = tl.load(retrive_cum_len + pid)
+     retrive_len = retrive_end - retrive_start
+     accept_ptr = accept_mask + retrive_start
+     accept_offset = tl.arange(0, draft_token_num)
+     accept_load_mask = accept_offset < retrive_len
+     accept_len_list = tl.load(
+         accept_ptr + accept_offset, mask=accept_load_mask, other=-1
+     )
+
+     accept_len = tl.max(accept_len_list)
+     max_index = tl.argmax(accept_len_list, axis=0, tie_break_left=True)
+     # Triton does not support argmax with tie_break_right, so emulate it below.
+     mask_max = accept_len_list == accept_len
+
+     count_mask = tl.full(shape=[draft_token_num], value=0, dtype=tl.int32)
+     count = tl.sum(tl.where(mask_max, 1, count_mask))
+     if count > 1:
+         index = tl.arange(0, draft_token_num)
+         mask_left = index != max_index
+         remained_index = tl.where(mask_max and mask_left, index, 0)
+         max_index = tl.max(remained_index)
+
+     tl.store(accept_length + pid, accept_len)
+     retrive_index_ptr = retrive_index + (retrive_start + max_index) * max_len
+     retrive_offset = tl.arange(0, max_len_upper)
+     retrive_load_mask = retrive_offset < accept_len + 1
+     data = tl.load(retrive_index_ptr + retrive_offset, mask=retrive_load_mask)
+
+     tl.store(
+         accept_index + pid * max_len + retrive_offset, data, mask=retrive_load_mask
+     )
+
+     extract_load_ptr = accept_index + pid * max_len + accept_len
+     if accept_len == max_len - 1:
+         extract_data = tl.load(extract_load_ptr - 1)
+         tl.store(extract_index + pid * 2, extract_data)
+         extract_data = tl.load(extract_load_ptr)
+         tl.store(extract_index + pid * 2 + 1, extract_data)
+
+     else:
+         extract_data = tl.load(extract_load_ptr)
+         tl.store(extract_index + pid * 2, extract_data)
+
+
+ # Fill the per-token positions for each request's accepted tokens and gather
+ # the last verified token id of each request into new_verified_id.
+ @triton.jit
+ def create_extend_spec_info(
+     verified_id,
+     seq_len,
+     accept_len,
+     accept_len_cum,
+     positions,
+     new_verified_id,
+     accept_len_upper: tl.constexpr,
+ ):
+     pid = tl.program_id(axis=0)
+     offset = 0 if pid == 0 else tl.load(accept_len_cum + pid - 1)
+     seq_length = tl.load(seq_len + pid)
+     accept_length = tl.load(accept_len + pid)
+     positions_ptr = positions + offset
+     data = tl.arange(0, accept_len_upper)
+     mask = data < accept_length
+     tl.store(positions_ptr + data, seq_length - accept_length + data, mask)
+
+     offset = tl.load(accept_len_cum + pid) - 1
+     verified_id_data = tl.load(verified_id + offset)
+     tl.store(new_verified_id + pid, verified_id_data)
+
+
+ # Copy each request's slice of out_cache_loc into its req_to_token row over
+ # the range [start_offset, end_offset).
+ @triton.jit
+ def assign_req_to_token_pool(
+     req_pool_indices,
+     req_to_token,
+     start_offset,
+     end_offset,
+     out_cache_loc,
+     pool_len: tl.constexpr,
+     bs_upper: tl.constexpr,
+ ):
+     BLOCK_SIZE: tl.constexpr = 32
+     pid = tl.program_id(axis=0)
+     kv_start = tl.load(start_offset + pid)
+     kv_end = tl.load(end_offset + pid)
+     token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
+
+     length_offset = tl.arange(0, bs_upper)
+     start = tl.load(start_offset + length_offset, mask=length_offset < pid)
+     end = tl.load(end_offset + length_offset, mask=length_offset < pid)
+     out_offset = tl.sum(end - start, axis=0)
+
+     out_cache_ptr = out_cache_loc + out_offset
+
+     save_offset = tl.arange(0, BLOCK_SIZE) + kv_start
+     load_offset = tl.arange(0, BLOCK_SIZE)
+
+     num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+     for _ in range(num_loop):
+         mask = save_offset < kv_end
+         data = tl.load(out_cache_ptr + load_offset, mask=mask)
+         tl.store(token_pool + save_offset, data, mask=mask)
+         save_offset += BLOCK_SIZE
+         load_offset += BLOCK_SIZE
+
+
+ # For each (request, top-k branch) pair, emit the request's existing KV
+ # indices followed by that branch's per-iteration draft slots.
+ @triton.jit
+ def generate_draft_decode_kv_indices(
+     req_pool_indices,
+     req_to_token,
+     paged_kernel_lens,
+     kv_indices,
+     iters: tl.constexpr,
+     topk: tl.constexpr,
+     pool_len: tl.constexpr,
+     bs_upper: tl.constexpr,
+     iter_upper: tl.constexpr,
+ ):
+     BLOCK_SIZE: tl.constexpr = 128
+     bid = tl.program_id(axis=0)
+     topk_id = tl.program_id(axis=1)
+
+     load_offset = tl.arange(0, bs_upper)
+     seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid)
+     seq_len = tl.load(paged_kernel_lens + bid)
+     cum_seq_len = tl.sum(seq_lens)
+
+     kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
+     kv_ptr = kv_indices + kv_offset
+     token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len
+
+     kv_offset = tl.arange(0, BLOCK_SIZE)
+     num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
+     for _ in range(num_loop):
+         mask = kv_offset < seq_len
+         data = tl.load(token_pool_ptr + kv_offset, mask=mask)
+         tl.store(kv_ptr + kv_offset, data, mask=mask)
+         kv_offset += BLOCK_SIZE
+
+     extend_offset = tl.arange(0, iter_upper)
+     extend_data = tl.load(
+         token_pool_ptr + seq_len + tl.arange(0, iter_upper) * topk + topk_id,
+         mask=extend_offset < iters,
+     )
+     tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)
+
+
+ class EAGLEDraftInput(SpecInfo):
+     hidden_states: torch.Tensor = None
+     verified_id: torch.Tensor = None
+     positions: torch.Tensor = None
+     accept_length: torch.Tensor = None
+     has_finished: bool = False
+     unfinished_index: List[int] = None
+
+     def init(self, server_args: ServerArgs):
+         self.prev_mode = ForwardMode.DECODE
+         self.sample_output = None
+         self.topk: int = server_args.speculative_eagle_topk
+         self.num_verify_token: int = server_args.speculative_num_draft_tokens
+         self.spec_steps = server_args.speculative_num_steps
+
+         self.scores: torch.Tensor = None
+         self.score_list: List[torch.Tensor] = []
+         self.token_list: List[torch.Tensor] = []
+         self.origin_score_list: List[torch.Tensor] = []  # used for sampling
+         self.parents_list: List[torch.Tensor] = []
+         self.cache_list: List[torch.Tensor] = []
+         self.iter = 0
+         self.root_token: int = None
+
+         assert self.topk <= 10, "topk should be <= 10"
+
+     def prepare_for_extend(self, batch: ForwardBatch):
+         req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
+         out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
+         batch.out_cache_loc = out_cache_loc
+
+         pt = 0
+         for i, req in enumerate(batch.reqs):
+             req.req_pool_idx = req_pool_indices[i]
+             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
+             assert seq_len - pre_len == req.extend_input_len
+
+             if pre_len > 0:
+                 batch.req_to_token_pool.req_to_token[req.req_pool_idx][
+                     :pre_len
+                 ] = req.prefix_indices
+
+             batch.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
+                 out_cache_loc[pt : pt + req.extend_input_len]
+             )
+
+             pt += req.extend_input_len
+
+         seq_lens = [0] + batch.extend_lens
+         input_ids = batch.input_ids.tolist()
+         verified_id = batch.spec_info.verified_id.tolist()
+         model_input_ids = []
+         for i in range(len(seq_lens) - 1):
+             model_input_ids.extend(
+                 input_ids[seq_lens[i] + 1 : seq_lens[i + 1]] + [verified_id[i]]
+             )
+         batch.input_ids = torch.tensor(
+             model_input_ids, dtype=torch.int32, device="cuda"
+         )
+
+     def capture_for_decode(
+         self,
+         sample_output: SampleOutput,
+         hidden_states: torch.Tensor,
+         prev_mode: ForwardMode,
+     ):
+         self.sample_output = sample_output
+         self.prev_mode = prev_mode
+         self.hidden_states = hidden_states
+
+     def prepare_for_decode(self, batch: ScheduleBatch):
+         prob = self.sample_output  # b * (1/topk), vocab
+         top = torch.topk(prob, self.topk, dim=-1)
+         topk_index, topk_p = top.indices, top.values  # b * (1/topk), topk
+         if self.prev_mode == ForwardMode.DECODE:
+             scores = torch.mul(
+                 self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk)
+             )  # (b, topk) mul (b * topk, topk) -> b, topk, topk
+             topk_cs = torch.topk(
+                 scores.flatten(start_dim=1), self.topk, dim=-1
+             )  # (b, topk)
+             topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values
+             self.scores = topk_cs_p
+
+             selected_input_index = topk_cs_index.flatten() // self.topk  # b * topk
+
+             batch.spec_info.hidden_states = batch.spec_info.hidden_states[
+                 selected_input_index, :
+             ]
+             topk_index = topk_index.reshape(-1, self.topk**2)
+             batch.input_ids = torch.gather(
+                 topk_index, index=topk_cs_index, dim=1
+             ).flatten()
+             batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
+             self.score_list.append(scores)  # b, topk, topk
+             self.token_list.append(topk_index)  # b, topk * topk
+             self.origin_score_list.append(topk_p.reshape(topk_index.shape))
+             self.parents_list.append(
+                 topk_cs_index + (self.topk**2 * (self.iter - 1) + self.topk)
+             )  # b, topk
+
+         elif self.prev_mode in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND):
+             self.scores = topk_p  # b, topk
+             self.score_list.append(topk_p.unsqueeze(1))
+             self.token_list.append(topk_index)
+             self.origin_score_list.append(topk_p)
+             batch.spec_info.hidden_states = (
+                 batch.spec_info.hidden_states.repeat_interleave(self.topk, 0)
+             )
+             batch.input_ids = topk_index.flatten()
+             batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel())
+             self.parents_list.append(
+                 torch.arange(-1, self.topk, dtype=torch.long, device="cuda")
+                 .unsqueeze(0)
+                 .repeat(self.scores.shape[0], 1)
+             )  # b, topk + 1
+         self.cache_list.append(batch.out_cache_loc)
+         self.positions = (
+             batch.seq_lens[:, None]
+             + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter
+         ).flatten()
+
+         bs = batch.seq_lens.numel()
+         assign_req_to_token_pool[(bs,)](
+             batch.req_pool_indices,
+             batch.req_to_token_pool.req_to_token,
+             batch.seq_lens + self.topk * self.iter,
+             batch.seq_lens + self.topk * (self.iter + 1),
+             batch.out_cache_loc,
+             batch.req_to_token_pool.req_to_token.shape[1],
+             triton.next_power_of_2(bs),
+         )
+         self.iter += 1
+
+     def prepare_extend_after_decode(self, batch: ScheduleBatch):
+         batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
+         batch.extend_lens = (self.accept_length + 1).tolist()
+
+         pt = 0
+         seq_lens = batch.seq_lens.tolist()
+
+         i = 0
+
+         for req in batch.reqs:
+             if req.finished():
+                 continue
+             # assert seq_len - pre_len == req.extend_input_len
+             input_len = self.accept_length[i] + 1
+             seq_len = seq_lens[i]
+             batch.req_to_token_pool.req_to_token[req.req_pool_idx][
+                 seq_len - input_len : seq_len
+             ] = batch.out_cache_loc[pt : pt + input_len]
+             pt += input_len
+             i += 1
+
+         self.positions = torch.empty_like(self.verified_id)
+         new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
+         self.accept_length.add_(1)
+
+         create_extend_spec_info[(self.accept_length.numel(),)](
+             self.verified_id,
+             batch.seq_lens,
+             self.accept_length,
+             torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
+             self.positions,
+             new_verified_id,
+             triton.next_power_of_2(self.spec_steps + 1),
+         )
+
+         batch.input_ids = self.verified_id
+         self.verified_id = new_verified_id
+
+     def prepare_for_verify(self, batch: ScheduleBatch):
+         score_list = torch.cat(self.score_list, dim=1).flatten(
+             1
+         )  # b, n, topk; n = 1 + (self.iter - 1) * self.topk
+         ss_token_list = torch.cat(
+             self.token_list, dim=1
+         )  # b, (self.topk + (self.iter - 1) * self.topk)
+         origin_token_list = torch.cat(self.origin_score_list, dim=1)
+         top_scores = torch.topk(score_list, self.num_verify_token - 1, dim=-1)
+         top_scores_index = top_scores.indices
+         top_scores_index = torch.sort(top_scores_index).values
+
+         draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
+         scores = torch.gather(origin_token_list, index=top_scores_index, dim=1)
+         draft_tokens = torch.cat((self.verified_id.unsqueeze(1), draft_tokens), dim=1)
+         parent_list = torch.cat(self.parents_list[:-1], dim=1)
+
+         tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel(
+             parent_list,
+             top_scores_index,
+             batch.seq_lens,
+             self.topk,
+             self.iter - 1,
+             self.num_verify_token,
+         )
+
+         return EagleVerifyInput(
+             draft_tokens.flatten(),
+             scores.flatten(),
+             tree_mask,
+             position,
+             retrive_index,
+             retrive_cum_len,
+             self.num_verify_token,
+         )
+
+     def generate_attn_arg_decode(
+         self,
+         req_pool_indices: torch.Tensor,
+         paged_kernel_lens: torch.Tensor,
+         req_to_token: torch.Tensor,
+     ):
+         seq_num = req_pool_indices.numel()
+         bs = self.topk * req_pool_indices.numel()
+         seq_len = self.positions.reshape(-1).contiguous()
+
+         cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
+         cum_kv_seq_len[1:] = torch.cumsum(seq_len + 1, dim=0)
+         total_len = torch.sum(paged_kernel_lens).item()
+
+         kv_indices = torch.empty(
+             (total_len * self.topk + seq_num * self.iter * self.topk,),
+             dtype=torch.int32,
+             device="cuda",
+         )
+
+         generate_draft_decode_kv_indices[(req_pool_indices.numel(), self.topk)](
+             req_pool_indices,
+             req_to_token,
+             paged_kernel_lens,
+             kv_indices,
+             self.iter,
+             self.topk,
+             req_to_token.shape[1],
+             triton.next_power_of_2(seq_num),
+             triton.next_power_of_2(self.spec_steps),
+         )
+         return bs, kv_indices, cum_kv_seq_len
+
+     def clear(self):
+         self.iter = 0
+         self.score_list.clear()
+         self.positions = None
+
+     def clear_draft_cache(self, batch):
+         draft_cache = torch.cat(self.cache_list, dim=0)
+         batch.token_to_kv_pool.free(draft_cache)
+
+     def generate_attn_arg_prefill(
+         self,
+         req_pool_indices: torch.Tensor,
+         paged_kernel_lens: torch.Tensor,
+         req_to_token: torch.Tensor,
+     ):
+         bs = self.accept_length.numel()
+         qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
+         qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
+
+         cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
+         cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+         kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
+
+         create_flashinfer_kv_indices_triton[(bs,)](
+             req_to_token,
+             req_pool_indices,
+             paged_kernel_lens,
+             cum_kv_seq_len,
+             None,
+             kv_indices,
+             req_to_token.size(1),
+         )
+
+         return kv_indices, cum_kv_seq_len, qo_indptr, None
+
+     def merge_batch(self, spec_info: EAGLEDraftInput):
+         self.hidden_states = torch.cat(
+             [self.hidden_states, spec_info.hidden_states], axis=0
+         )
+         self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
+         # self.positions = torch.cat([self.positions, spec_info.positions], axis=0)
+         self.sample_output = torch.cat([self.sample_output, spec_info.sample_output])
+
+
+ class EagleVerifyInput(SpecInfo):
+     def __init__(
+         self,
+         draft_token: torch.Tensor,
+         draft_score: torch.Tensor,
+         tree_mask: torch.Tensor,
+         positions: torch.Tensor,
+         retrive_index: torch.Tensor,
+         retrive_cum_len: torch.Tensor,
+         draft_token_num: int,
+     ):
+         self.draft_token = draft_token
+         self.draft_score = draft_score
+         self.custom_mask = tree_mask
+         self.positions = positions
+         self.retrive_index = retrive_index
+         self.retrive_cum_len = retrive_cum_len
+         self.draft_token_num = draft_token_num
+
+     def prepare_for_verify(self, batch: ScheduleBatch):
+         batch.input_ids = self.draft_token
+         batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
+         bs = batch.seq_lens.numel()
+         assign_req_to_token_pool[(bs,)](
+             batch.req_pool_indices,
+             batch.req_to_token_pool.req_to_token,
+             batch.seq_lens,
+             batch.seq_lens + self.draft_token_num,
+             batch.out_cache_loc,
+             batch.req_to_token_pool.req_to_token.shape[1],
+             triton.next_power_of_2(bs),
+         )
+
+     def generate_attn_arg_prefill(
+         self,
+         req_pool_indices: torch.Tensor,
+         paged_kernel_lens: torch.Tensor,
+         req_to_token: torch.Tensor,
+     ):
+         batch_size = len(req_pool_indices)
+         qo_indptr = torch.arange(
+             0,
+             (1 + batch_size) * self.draft_token_num,
+             step=self.draft_token_num,
+             dtype=torch.int32,
+             device="cuda",
+         )
+
+         cum_kv_seq_len = torch.zeros(
+             (batch_size + 1,), dtype=torch.int32, device="cuda"
+         )
+
+         paged_kernel_lens = paged_kernel_lens + self.draft_token_num
+         cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+
+         kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
+
+         create_flashinfer_kv_indices_triton[(batch_size,)](
+             req_to_token,
+             req_pool_indices,
+             paged_kernel_lens,
+             cum_kv_seq_len,
+             None,
+             kv_indices,
+             req_to_token.size(1),
+         )
+         return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask
+
+     def verify(self, batch: ScheduleBatch, logits_output):
+         predict = torch.argmax(logits_output.next_token_logits, dim=-1)
+         predict = torch.cat(
+             [predict, torch.full([1], -1, dtype=torch.long, device="cuda")], dim=-1
+         )
+         draft_token = torch.cat(
+             [self.draft_token, torch.full([1], -1, dtype=torch.long, device="cuda")],
+             dim=-1,
+         )
+         target_predict = predict[self.retrive_index]
+         candidates = draft_token[self.retrive_index]
+         # logits = logits_output.next_token_logits[self.retrive_index]
+         # target_predict = torch.argmax(logits[:, :-1], dim=-1)
+         accept_mask = candidates[:, 1:] == target_predict[:, :-1]
+         accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1)
+         bs = self.retrive_cum_len.numel() - 1
+
+         max_draft_len = self.retrive_index.shape[-1]
+         accept_index = torch.full(
+             (bs, max_draft_len), -1, dtype=torch.long, device="cuda"
+         )
+         accept_length = torch.empty((bs,), dtype=torch.int, device="cuda")
+         extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda")
+         eagle_verify_retrive[(bs,)](
+             self.retrive_index.contiguous(),
+             accept_mask.contiguous(),
+             self.retrive_cum_len,
+             accept_index,
+             accept_length,
+             extract_index,
+             max_draft_len,
+             self.draft_token_num,
+             triton.next_power_of_2(max_draft_len),
+         )
+
+         accept_index = accept_index[accept_index != -1]
+         # extract_index = extract_index[extract_index != 0]
+
+         draft_input = EAGLEDraftInput()
+
+         accept_length_cpu = accept_length.tolist()
+         verified_id = predict[accept_index]
+         verified_id_cpu = verified_id.tolist()
+
+         evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
+         evict_mask[accept_index] = False
+         mem_need_free_idx = batch.out_cache_loc[evict_mask]
+         batch.token_to_kv_pool.free(mem_need_free_idx)
+         assign_req_to_token_pool[(bs,)](
+             batch.req_pool_indices,
+             batch.req_to_token_pool.req_to_token,
+             batch.seq_lens,
+             batch.seq_lens + accept_length + 1,
+             batch.out_cache_loc[accept_index],
+             batch.req_to_token_pool.req_to_token.shape[1],
+             triton.next_power_of_2(bs),
+         )
+         batch.seq_lens.add_(accept_length + 1)
+         new_accept_index = []
+         unfinished_index = []
+         finished_extend_len = {}  # {rid: accept_length + 1}
+         # retracted_reqs, new_token_ratio = batch.retract_decode()
+
+         low = 0
+         for i, (req, verified_len) in enumerate(zip(batch.reqs, accept_length_cpu)):
+             req.output_ids.extend(verified_id_cpu[low : low + verified_len + 1])
+             req.check_finished()
+             if req.finished():
+                 draft_input.has_finished = True
+             else:
+                 new_accept_index.append(accept_index[low : low + verified_len + 1])
+                 unfinished_index.append(i)
+             low += verified_len + 1
+             finished_extend_len[req.rid] = verified_len + 1
+
+         if len(new_accept_index) > 0:
+             new_accept_index = torch.cat(new_accept_index, dim=0)
+             draft_input.verified_id = predict[new_accept_index]
+             draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
+             draft_input.accept_length = accept_length[unfinished_index]
+             draft_input.unfinished_index = unfinished_index
+
+         logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
+         return draft_input, logits_output, verified_id, finished_extend_len
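
A note for readers skimming verify() above: acceptance in this version is greedy, i.e. a draft token is kept only while it matches the target model's argmax prediction; there is no sampling-based rejection step. A self-contained toy illustration of the prefix-match rule (the same two tensor expressions as the source; the tree layout and kernel bookkeeping are omitted):

    import torch

    # Each row of `candidates` is one root-to-leaf draft path; position 0 is the
    # already-verified token. `target_predict` is the target model's argmax
    # prediction at each of those positions.
    candidates = torch.tensor([[7, 3, 9, 4],
                               [7, 3, 5, 2]])
    target_predict = torch.tensor([[3, 9, 1, 8],
                                   [3, 5, 2, 6]])

    # Draft token j+1 is accepted iff it equals the target's prediction after j.
    accept_mask = candidates[:, 1:] == target_predict[:, :-1]
    # cumprod zeroes everything after the first mismatch, so the sum counts the
    # contiguous accepted prefix of each path.
    accept_len = torch.cumprod(accept_mask, dim=1).sum(dim=1)
    print(accept_len)  # tensor([2, 3])

eagle_verify_retrive then selects, per request, the path with the longest accepted prefix (ties break toward the last such path), and verify() frees the KV cache slots of every rejected draft token.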