sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/bench_serving.py +18 -1
  3. sglang/lang/interpreter.py +71 -1
  4. sglang/lang/ir.py +2 -0
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/chatglm.py +78 -0
  7. sglang/srt/configs/dbrx.py +279 -0
  8. sglang/srt/configs/model_config.py +1 -1
  9. sglang/srt/hf_transformers_utils.py +9 -14
  10. sglang/srt/layers/attention/__init__.py +22 -6
  11. sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
  12. sglang/srt/layers/attention/flashinfer_backend.py +215 -83
  13. sglang/srt/layers/attention/torch_native_backend.py +1 -38
  14. sglang/srt/layers/attention/triton_backend.py +20 -11
  15. sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
  16. sglang/srt/layers/linear.py +159 -55
  17. sglang/srt/layers/logits_processor.py +170 -215
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +198 -29
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -7
  41. sglang/srt/layers/parameter.py +431 -0
  42. sglang/srt/layers/quantization/__init__.py +3 -2
  43. sglang/srt/layers/quantization/fp8.py +3 -3
  44. sglang/srt/layers/quantization/modelopt_quant.py +174 -0
  45. sglang/srt/layers/sampler.py +57 -21
  46. sglang/srt/layers/torchao_utils.py +17 -3
  47. sglang/srt/layers/vocab_parallel_embedding.py +1 -1
  48. sglang/srt/managers/cache_controller.py +307 -0
  49. sglang/srt/managers/data_parallel_controller.py +2 -0
  50. sglang/srt/managers/io_struct.py +1 -2
  51. sglang/srt/managers/schedule_batch.py +33 -3
  52. sglang/srt/managers/schedule_policy.py +159 -90
  53. sglang/srt/managers/scheduler.py +68 -28
  54. sglang/srt/managers/session_controller.py +1 -1
  55. sglang/srt/managers/tokenizer_manager.py +27 -21
  56. sglang/srt/managers/tp_worker.py +16 -4
  57. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  58. sglang/srt/mem_cache/memory_pool.py +206 -1
  59. sglang/srt/metrics/collector.py +22 -30
  60. sglang/srt/model_executor/cuda_graph_runner.py +129 -77
  61. sglang/srt/model_executor/forward_batch_info.py +51 -21
  62. sglang/srt/model_executor/model_runner.py +72 -64
  63. sglang/srt/models/chatglm.py +1 -1
  64. sglang/srt/models/dbrx.py +1 -1
  65. sglang/srt/models/deepseek_v2.py +34 -7
  66. sglang/srt/models/grok.py +109 -29
  67. sglang/srt/models/llama.py +9 -2
  68. sglang/srt/openai_api/adapter.py +0 -17
  69. sglang/srt/openai_api/protocol.py +3 -3
  70. sglang/srt/sampling/sampling_batch_info.py +22 -0
  71. sglang/srt/sampling/sampling_params.py +9 -1
  72. sglang/srt/server.py +20 -13
  73. sglang/srt/server_args.py +120 -58
  74. sglang/srt/speculative/build_eagle_tree.py +347 -0
  75. sglang/srt/speculative/eagle_utils.py +626 -0
  76. sglang/srt/speculative/eagle_worker.py +184 -0
  77. sglang/srt/speculative/spec_info.py +5 -0
  78. sglang/srt/utils.py +47 -7
  79. sglang/test/test_programs.py +23 -1
  80. sglang/test/test_utils.py +36 -7
  81. sglang/version.py +1 -1
  82. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +12 -12
  83. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +86 -57
  84. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
  85. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
  86. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_utils.py (new file)
@@ -0,0 +1,626 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, List
+
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.layers.attention.flashinfer_backend import (
+    create_flashinfer_kv_indices_triton,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
+from sglang.srt.speculative.spec_info import SpecInfo
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import ScheduleBatch
+    from sglang.srt.server_args import ServerArgs
+
+
+@triton.jit
+def eagle_verify_retrive(
+    retrive_index,
+    accept_mask,
+    retrive_cum_len,
+    accept_index,
+    accept_length,
+    extract_index,
+    max_len: tl.constexpr,
+    draft_token_num: tl.constexpr,
+    max_len_upper: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+
+    retrive_end = tl.load(retrive_cum_len + pid + 1)
+    retrive_start = tl.load(retrive_cum_len + pid)
+    retrive_len = retrive_end - retrive_start
+    accept_ptr = accept_mask + retrive_start
+    accept_offset = tl.arange(0, draft_token_num)
+    accept_load_mask = accept_offset < retrive_len
+    accept_len_list = tl.load(
+        accept_ptr + accept_offset, mask=accept_load_mask, other=-1
+    )
+
+    accept_len = tl.max(accept_len_list)
+    max_index = tl.argmax(accept_len_list, axis=0, tie_break_left=True)
+    # Triton does not support argmax with tie_break_right, so emulate it below.
+    mask_max = accept_len_list == accept_len
+
+    count_mask = tl.full(shape=[draft_token_num], value=0, dtype=tl.int32)
+    count = tl.sum(tl.where(mask_max, 1, count_mask))
+    if count > 1:
+        index = tl.arange(0, draft_token_num)
+        mask_left = index != max_index
+        remained_index = tl.where(mask_max and mask_left, index, 0)
+        max_index = tl.max(remained_index)
+
+    tl.store(accept_length + pid, accept_len)
+    retrive_index_ptr = retrive_index + (retrive_start + max_index) * max_len
+    retrive_offset = tl.arange(0, max_len_upper)
+    retrive_load_mask = retrive_offset < accept_len + 1
+    data = tl.load(retrive_index_ptr + retrive_offset, mask=retrive_load_mask)
+
+    tl.store(
+        accept_index + pid * max_len + retrive_offset, data, mask=retrive_load_mask
+    )
+
+    extract_load_ptr = accept_index + pid * max_len + accept_len
+    if accept_len == max_len - 1:
+        extract_data = tl.load(extract_load_ptr - 1)
+        tl.store(extract_index + pid * 2, extract_data)
+        extract_data = tl.load(extract_load_ptr)
+        tl.store(extract_index + pid * 2 + 1, extract_data)
+
+    else:
+        extract_data = tl.load(extract_load_ptr)
+        tl.store(extract_index + pid * 2, extract_data)
+
+
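Annotation (not part of the diff): eagle_verify_retrive selects, for each request, the tree path with the longest accepted prefix, breaking ties toward the rightmost path, and copies that path's first accept_len + 1 node indices into accept_index. A minimal pure-PyTorch restatement of that selection under those assumptions; the helper name is ours, and the extract_index bookkeeping is omitted:

import torch

def verify_retrieve_reference(retrive_index, accept_mask, retrive_cum_len, max_len):
    # retrive_index: (total_paths, max_len) node indices per tree path, -1 padded
    # accept_mask: (total_paths,) accepted-prefix length of each path
    # retrive_cum_len: (bs + 1,) prefix sums giving each request's path range
    bs = retrive_cum_len.numel() - 1
    accept_index = torch.full((bs, max_len), -1, dtype=torch.long)
    accept_length = torch.empty(bs, dtype=torch.long)
    for pid in range(bs):
        start, end = retrive_cum_len[pid].item(), retrive_cum_len[pid + 1].item()
        lens = accept_mask[start:end]
        accept_len = lens.max().item()
        # rightmost path among those achieving the maximum (matches the kernel's tie-break)
        row = (lens == accept_len).nonzero()[-1].item()
        accept_length[pid] = accept_len
        accept_index[pid, : accept_len + 1] = retrive_index[start + row, : accept_len + 1]
    return accept_index, accept_length
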
+@triton.jit
+def create_extend_spec_info(
+    verified_id,
+    seq_len,
+    accept_len,
+    accept_len_cum,
+    positions,
+    new_verified_id,
+    accept_len_upper: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    offset = 0 if pid == 0 else tl.load(accept_len_cum + pid - 1)
+    seq_length = tl.load(seq_len + pid)
+    accept_length = tl.load(accept_len + pid)
+    positions_ptr = positions + offset
+    data = tl.arange(0, accept_len_upper)
+    mask = data < accept_length
+    tl.store(positions_ptr + data, seq_length - accept_length + data, mask)
+
+    offset = tl.load(accept_len_cum + pid) - 1
+    verified_id_data = tl.load(verified_id + offset)
+    tl.store(new_verified_id + pid, verified_id_data)
+
+
+@triton.jit
+def assign_req_to_token_pool(
+    req_pool_indices,
+    req_to_token,
+    start_offset,
+    end_offset,
+    out_cache_loc,
+    pool_len: tl.constexpr,
+    bs_upper: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 32
+    pid = tl.program_id(axis=0)
+    kv_start = tl.load(start_offset + pid)
+    kv_end = tl.load(end_offset + pid)
+    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
+
+    length_offset = tl.arange(0, bs_upper)
+    start = tl.load(start_offset + length_offset, mask=length_offset < pid)
+    end = tl.load(end_offset + length_offset, mask=length_offset < pid)
+    out_offset = tl.sum(end - start, axis=0)
+
+    out_cache_ptr = out_cache_loc + out_offset
+
+    save_offset = tl.arange(0, BLOCK_SIZE) + kv_start
+    load_offset = tl.arange(0, BLOCK_SIZE)
+
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for _ in range(num_loop):
+        mask = save_offset < kv_end
+        data = tl.load(out_cache_ptr + load_offset, mask=mask)
+        tl.store(token_pool + save_offset, data, mask=mask)
+        save_offset += BLOCK_SIZE
+        load_offset += BLOCK_SIZE
+
+
+@triton.jit
+def generate_draft_decode_kv_indices(
+    req_pool_indices,
+    req_to_token,
+    paged_kernel_lens,
+    kv_indices,
+    iters: tl.constexpr,
+    topk: tl.constexpr,
+    pool_len: tl.constexpr,
+    bs_upper: tl.constexpr,
+    iter_upper: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 128
+    bid = tl.program_id(axis=0)
+    topk_id = tl.program_id(axis=1)
+
+    load_offset = tl.arange(0, bs_upper)
+    seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid)
+    seq_len = tl.load(paged_kernel_lens + bid)
+    cum_seq_len = tl.sum(seq_lens)
+
+    kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
+    kv_ptr = kv_indices + kv_offset
+    token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len
+
+    kv_offset = tl.arange(0, BLOCK_SIZE)
+    num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
+    for _ in range(num_loop):
+        mask = kv_offset < seq_len
+        data = tl.load(token_pool_ptr + kv_offset, mask=mask)
+        tl.store(kv_ptr + kv_offset, data, mask=mask)
+        kv_offset += BLOCK_SIZE
+
+    extend_offset = tl.arange(0, iter_upper)
+    extend_data = tl.load(
+        token_pool_ptr + seq_len + tl.arange(0, iter_upper) * topk + topk_id,
+        mask=extend_offset < iters,
+    )
+    tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)
+
+
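Annotation (not part of the diff): generate_draft_decode_kv_indices packs one contiguous slice of KV indices per (request, top-k branch) pair, request by request. A small sketch of that offset arithmetic with made-up lengths (two requests, topk=2, one prior draft step):

import torch

topk, iters = 2, 1
seq_lens = torch.tensor([5, 7])

slices = []
for bid in range(seq_lens.numel()):
    cum_seq_len = seq_lens[:bid].sum().item()          # tokens of earlier requests
    for topk_id in range(topk):
        base = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_lens[bid].item() + iters)
        slices.append((bid, topk_id, base, base + seq_lens[bid].item() + iters))

print(slices)
# [(0, 0, 0, 6), (0, 1, 6, 12), (1, 0, 12, 20), (1, 1, 20, 28)]
# total length == seq_lens.sum() * topk + len(seq_lens) * iters * topk == 28,
# matching the kv_indices allocation in generate_attn_arg_decode below.
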
+class EAGLEDraftInput(SpecInfo):
+    def __init__(self):
+        self.prev_mode = ForwardMode.DECODE
+        self.sample_output = None
+
+        self.scores: torch.Tensor = None
+        self.score_list: List[torch.Tensor] = []
+        self.token_list: List[torch.Tensor] = []
+        self.origin_score_list: List[torch.Tensor] = []  # used for sampling
+        self.parents_list: List[torch.Tensor] = []
+        self.cache_list: List[torch.Tensor] = []
+        self.iter = 0
+
+        self.hidden_states: torch.Tensor = None
+        self.verified_id: torch.Tensor = None
+        self.positions: torch.Tensor = None
+        self.accept_length: torch.Tensor = None
+        self.has_finished: bool = False
+        self.unfinished_index: List[int] = None
+
+    def load_server_args(self, server_args: ServerArgs):
+        self.topk: int = server_args.speculative_eagle_topk
+        self.num_verify_token: int = server_args.speculative_num_draft_tokens
+        self.spec_steps = server_args.speculative_num_steps
+
+    def prepare_for_extend(self, batch: ScheduleBatch):
+        req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
+        out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
+        batch.out_cache_loc = out_cache_loc
+
+        pt = 0
+        for i, req in enumerate(batch.reqs):
+            req.req_pool_idx = req_pool_indices[i]
+            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
+            assert seq_len - pre_len == req.extend_input_len
+
+            if pre_len > 0:
+                batch.req_to_token_pool.req_to_token[req.req_pool_idx][
+                    :pre_len
+                ] = req.prefix_indices
+
+            batch.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
+                out_cache_loc[pt : pt + req.extend_input_len]
+            )
+
+            pt += req.extend_input_len
+
+        # TODO: support batching inputs
+        assert len(batch.extend_lens) == 1
+        batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
+
+    def prepare_for_decode(self, batch: ScheduleBatch):
+        prob = self.sample_output  # shape: (b * top_k, vocab) or (b, vocab)
+        top = torch.topk(prob, self.topk, dim=-1)
+        topk_index, topk_p = (
+            top.indices,
+            top.values,
+        )  # shape: (b * top_k, top_k) or (b, top_k)
+
+        if self.prev_mode.is_decode():
+            scores = torch.mul(
+                self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk)
+            )  # (b, topk, 1) x (b, topk, topk) -> (b, topk, topk)
+            topk_cs = torch.topk(
+                scores.flatten(start_dim=1), self.topk, dim=-1
+            )  # (b, topk)
+            topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values
+
+            selected_input_index = topk_cs_index.flatten() // self.topk + torch.arange(
+                0, batch.batch_size() * self.topk, step=self.topk, device="cuda"
+            ).repeat_interleave(self.topk)
+
+            batch.spec_info.hidden_states = batch.spec_info.hidden_states[
+                selected_input_index, :
+            ]
+
+            topk_index = topk_index.reshape(-1, self.topk**2)
+            batch.input_ids = torch.gather(
+                topk_index, index=topk_cs_index, dim=1
+            ).flatten()
+            batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
+
+            self.scores = topk_cs_p
+            self.score_list.append(scores)  # (b, topk, topk)
+            self.token_list.append(topk_index)  # (b, topk * topk)
+            self.origin_score_list.append(topk_p.reshape(topk_index.shape))
+            self.parents_list.append(
+                topk_cs_index + (self.topk**2 * (self.iter - 1) + self.topk)
+            )  # shape: (b, topk)
+        else:
+            # ForwardMode.EXTEND or ForwardMode.DRAFT_EXTEND
+            batch.spec_info.hidden_states = (
+                batch.spec_info.hidden_states.repeat_interleave(self.topk, dim=0)
+            )
+
+            batch.input_ids = topk_index.flatten()
+            batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel())
+
+            self.scores = topk_p  # shape: (b, topk)
+            self.score_list.append(topk_p.unsqueeze(1))  # shape: (b, 1, topk)
+            self.token_list.append(topk_index)  # shape: (b, topk)
+            self.origin_score_list.append(topk_p)
+            self.parents_list.append(
+                torch.arange(-1, self.topk, dtype=torch.long, device="cuda")
+                .unsqueeze(0)
+                .repeat(self.scores.shape[0], 1)
+            )  # shape: (b, topk + 1)
+        self.cache_list.append(batch.out_cache_loc)
+        self.positions = (
+            batch.seq_lens[:, None]
+            + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter
+        ).flatten()
+
+        bs = len(batch.seq_lens)
+        assign_req_to_token_pool[(bs,)](
+            batch.req_pool_indices,
+            batch.req_to_token_pool.req_to_token,
+            batch.seq_lens + self.topk * self.iter,
+            batch.seq_lens + self.topk * (self.iter + 1),
+            batch.out_cache_loc,
+            batch.req_to_token_pool.req_to_token.shape[1],
+            triton.next_power_of_2(bs),
+        )
+        self.iter += 1
+
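Annotation (not part of the diff): in the decode branch above, the draft tree grows one level per call by multiplying cumulative parent scores with per-parent child probabilities and keeping the top-k joint scores; integer division of the flat index by topk recovers each winner's parent. A minimal sketch with made-up numbers (batch size 1, topk=2):

import torch

topk = 2
parent_scores = torch.tensor([[0.6, 0.4]])              # (b, topk) cumulative frontier scores
child_p = torch.tensor([[[0.7, 0.3], [0.9, 0.1]]])      # (b, topk, topk) per-parent child probs

joint = parent_scores.unsqueeze(2) * child_p            # (b, topk, topk) joint path scores
top = torch.topk(joint.flatten(start_dim=1), topk, dim=-1)
print(top.values)            # [[0.42, 0.36]] (approx.)
print(top.indices)           # [[0, 2]] flat indices into the topk*topk candidates
print(top.indices // topk)   # [[0, 1]] -> parent slot of each surviving branch
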
+    def prepare_extend_after_decode(self, batch: ScheduleBatch):
+        batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
+        batch.extend_lens = (self.accept_length + 1).tolist()
+
+        pt = 0
+        seq_lens = batch.seq_lens.tolist()
+
+        i = 0
+
+        for req in batch.reqs:
+            if req.finished():
+                continue
+            # assert seq_len - pre_len == req.extend_input_len
+            input_len = self.accept_length[i] + 1
+            seq_len = seq_lens[i]
+            batch.req_to_token_pool.req_to_token[req.req_pool_idx][
+                seq_len - input_len : seq_len
+            ] = batch.out_cache_loc[pt : pt + input_len]
+            pt += input_len
+            i += 1
+
+        self.positions = torch.empty_like(self.verified_id)
+        new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
+        self.accept_length.add_(1)
+
+        create_extend_spec_info[(self.accept_length.numel(),)](
+            self.verified_id,
+            batch.seq_lens,
+            self.accept_length,
+            torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
+            self.positions,
+            new_verified_id,
+            triton.next_power_of_2(self.spec_steps + 1),
+        )
+
+        batch.seq_lens_sum = sum(batch.seq_lens)
+        batch.input_ids = self.verified_id
+        self.verified_id = new_verified_id
+
+    def prepare_for_verify(self, batch: ScheduleBatch):
+        score_list = torch.cat(self.score_list, dim=1).flatten(
+            1
+        )  # b, n, topk; n = 1 + (self.iter - 1) * self.topk
+        ss_token_list = torch.cat(
+            self.token_list, dim=1
+        )  # b, (self.topk + (self.iter - 1) * self.topk)
+        origin_token_list = torch.cat(self.origin_score_list, dim=1)
+        top_scores = torch.topk(score_list, self.num_verify_token - 1, dim=-1)
+        top_scores_index = top_scores.indices
+        top_scores_index = torch.sort(top_scores_index).values
+
+        draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
+        scores = torch.gather(origin_token_list, index=top_scores_index, dim=1)
+        draft_tokens = torch.cat((self.verified_id.unsqueeze(1), draft_tokens), dim=1)
+        parent_list = torch.cat(self.parents_list[:-1], dim=1)
+
+        tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel(
+            parent_list,
+            top_scores_index,
+            batch.seq_lens,
+            self.topk,
+            self.iter - 1,
+            self.num_verify_token,
+        )
+
+        return EagleVerifyInput(
+            draft_tokens.flatten(),
+            scores.flatten(),
+            tree_mask,
+            position,
+            retrive_index,
+            retrive_cum_len,
+            self.num_verify_token,
+        )
+
+    def generate_attn_arg_decode(
+        self,
+        req_pool_indices: torch.Tensor,
+        paged_kernel_lens: torch.Tensor,
+        req_to_token: torch.Tensor,
+    ):
+        seq_num = req_pool_indices.numel()
+        bs = self.topk * req_pool_indices.numel()
+        seq_len = self.positions.reshape(-1).contiguous()
+
+        cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
+        cum_kv_seq_len[1:] = torch.cumsum(seq_len + 1, dim=0)
+        total_len = torch.sum(paged_kernel_lens).item()
+
+        kv_indices = torch.empty(
+            (total_len * self.topk + seq_num * self.iter * self.topk,),
+            dtype=torch.int32,
+            device="cuda",
+        )
+
+        generate_draft_decode_kv_indices[(req_pool_indices.numel(), self.topk)](
+            req_pool_indices,
+            req_to_token,
+            paged_kernel_lens,
+            kv_indices,
+            self.iter,
+            self.topk,
+            req_to_token.shape[1],
+            triton.next_power_of_2(seq_num),
+            triton.next_power_of_2(self.spec_steps),
+        )
+        return bs, kv_indices, cum_kv_seq_len
+
+    def clear_draft_cache(self, batch):
+        draft_cache = torch.cat(self.cache_list, dim=0)
+        batch.token_to_kv_pool.free(draft_cache)
+
+    def generate_attn_arg_prefill(
+        self,
+        req_pool_indices: torch.Tensor,
+        paged_kernel_lens: torch.Tensor,
+        req_to_token: torch.Tensor,
+    ):
+        bs = self.accept_length.numel()
+        qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
+        qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
+
+        cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
+        cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
+
+        create_flashinfer_kv_indices_triton[(bs,)](
+            req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            cum_kv_seq_len,
+            None,
+            kv_indices,
+            req_to_token.size(1),
+        )
+
+        return kv_indices, cum_kv_seq_len, qo_indptr, None
+
+    def merge_batch(self, spec_info: EAGLEDraftInput):
+        if self.hidden_states is None:
+            self.hidden_states = spec_info.hidden_states
+            self.verified_id = spec_info.verified_id
+            self.sample_output = spec_info.sample_output
+            self.prev_mode = spec_info.prev_mode
+            return
+        if spec_info.hidden_states is None:
+            return
+        self.hidden_states = torch.cat(
+            [self.hidden_states, spec_info.hidden_states], axis=0
+        )
+        self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
+        self.sample_output = torch.cat([self.sample_output, spec_info.sample_output])
+
+
+class EagleVerifyInput(SpecInfo):
+    def __init__(
+        self,
+        draft_token: torch.Tensor,
+        draft_score: torch.Tensor,
+        tree_mask: torch.Tensor,
+        positions: torch.Tensor,
+        retrive_index: torch.Tensor,
+        retrive_cum_len: torch.Tensor,
+        draft_token_num: int,
+    ):
+        self.draft_token = draft_token
+        self.draft_score = draft_score
+        self.custom_mask = tree_mask
+        self.positions = positions
+        self.retrive_index = retrive_index
+        self.retrive_cum_len = retrive_cum_len
+        self.draft_token_num = draft_token_num
+
+    def prepare_for_verify(self, batch: ScheduleBatch):
+        batch.input_ids = self.draft_token
+        batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
+        bs = batch.seq_lens.numel()
+        assign_req_to_token_pool[(bs,)](
+            batch.req_pool_indices,
+            batch.req_to_token_pool.req_to_token,
+            batch.seq_lens,
+            batch.seq_lens + self.draft_token_num,
+            batch.out_cache_loc,
+            batch.req_to_token_pool.req_to_token.shape[1],
+            triton.next_power_of_2(bs),
+        )
+
+    def generate_attn_arg_prefill(
+        self,
+        req_pool_indices: torch.Tensor,
+        paged_kernel_lens: torch.Tensor,
+        req_to_token: torch.Tensor,
+    ):
+        batch_size = len(req_pool_indices)
+        qo_indptr = torch.arange(
+            0,
+            (1 + batch_size) * self.draft_token_num,
+            step=self.draft_token_num,
+            dtype=torch.int32,
+            device="cuda",
+        )
+
+        cum_kv_seq_len = torch.zeros(
+            (batch_size + 1,), dtype=torch.int32, device="cuda"
+        )
+
+        paged_kernel_lens = paged_kernel_lens + self.draft_token_num
+        cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+
+        kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
+
+        create_flashinfer_kv_indices_triton[(batch_size,)](
+            req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            cum_kv_seq_len,
+            None,
+            kv_indices,
+            req_to_token.size(1),
+        )
+        return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask
+
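Annotation (not part of the diff): generate_attn_arg_prefill above builds flashinfer-style ragged indices for the verification pass; each request contributes draft_token_num query tokens and its KV length grows by the same amount, so both index arrays are simple cumulative sums. A small self-contained sketch with made-up lengths:

import torch

draft_token_num = 4
seq_lens = torch.tensor([5, 7], dtype=torch.int32)   # per-request KV lengths before verification
batch_size = seq_lens.numel()

# one query segment of draft_token_num tokens per request
qo_indptr = torch.arange(
    0, (1 + batch_size) * draft_token_num, step=draft_token_num, dtype=torch.int32
)
# each request's KV length grows by the draft tokens being verified
cum_kv_seq_len = torch.zeros(batch_size + 1, dtype=torch.int32)
cum_kv_seq_len[1:] = torch.cumsum(seq_lens + draft_token_num, dim=0)

print(qo_indptr)        # -> 0, 4, 8
print(cum_kv_seq_len)   # -> 0, 9, 20
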
+    def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor:
+        predict = torch.argmax(logits_output.next_token_logits, dim=-1)
+        predict = torch.cat(
+            [predict, torch.full([1], -1, dtype=torch.long, device="cuda")], dim=-1
+        )
+        draft_token = torch.cat(
+            [self.draft_token, torch.full([1], -1, dtype=torch.long, device="cuda")],
+            dim=-1,
+        )
+        target_predict = predict[self.retrive_index]
+        candidates = draft_token[self.retrive_index]
+        # logits = logits_output.next_token_logits[self.retrive_index]
+        # target_predict = torch.argmax(logits[:, :-1], dim=-1)
+        accept_mask = candidates[:, 1:] == target_predict[:, :-1]
+        accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1)
+        bs = self.retrive_cum_len.numel() - 1
+
+        max_draft_len = self.retrive_index.shape[-1]
+        accept_index = torch.full(
+            (bs, max_draft_len), -1, dtype=torch.long, device="cuda"
+        )
+        accept_length = torch.empty((bs,), dtype=torch.int, device="cuda")
+        extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda")
+        eagle_verify_retrive[(bs,)](
+            self.retrive_index.contiguous(),
+            accept_mask.contiguous(),
+            self.retrive_cum_len,
+            accept_index,
+            accept_length,
+            extract_index,
+            max_draft_len,
+            self.draft_token_num,
+            triton.next_power_of_2(max_draft_len),
+        )
+
+        draft_input = EAGLEDraftInput()
+        new_accept_index = []
+        unfinished_index = []
+        finished_extend_len = {}  # {rid: accept_length + 1}
+        accept_index_cpu = accept_index.tolist()
+        predict_cpu = predict.tolist()
+        # Iterate over every accepted token and check whether the request finishes
+        # after appending it; this must be checked BEFORE freeing KV cache slots.
+        for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
+            new_accept_index_ = []
+            for j, idx in enumerate(accept_index_row):
+                if idx == -1:
+                    break
+                id = predict_cpu[idx]
+                # if not found_finished:
+                req.output_ids.append(id)
+                finished_extend_len[req.rid] = j + 1
+                req.check_finished()
+                if req.finished():
+                    draft_input.has_finished = True
+                    # set all tokens after the finished token to -1 and break
+                    accept_index[i, j + 1 :] = -1
+                    break
+                else:
+                    new_accept_index_.append(idx)
+            if not req.finished():
+                new_accept_index.extend(new_accept_index_)
+                unfinished_index.append(i)
+        accept_length = (accept_index != -1).sum(dim=1) - 1
+
+        accept_index = accept_index[accept_index != -1]
+        accept_length_cpu = accept_length.tolist()
+        verified_id = predict[accept_index]
+        verified_id_cpu = verified_id.tolist()
+
+        evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
+        evict_mask[accept_index] = False
+        mem_need_free_idx = batch.out_cache_loc[evict_mask]
+        batch.token_to_kv_pool.free(mem_need_free_idx)
+        assign_req_to_token_pool[(bs,)](
+            batch.req_pool_indices,
+            batch.req_to_token_pool.req_to_token,
+            batch.seq_lens,
+            batch.seq_lens + accept_length + 1,
+            batch.out_cache_loc[accept_index],
+            batch.req_to_token_pool.req_to_token.shape[1],
+            triton.next_power_of_2(bs),
+        )
+        batch.seq_lens.add_(accept_length + 1)
+
+        if len(new_accept_index) > 0:
+            new_accept_index = torch.tensor(new_accept_index, device="cuda")
+            draft_input.verified_id = predict[new_accept_index]
+            draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
+            draft_input.accept_length = accept_length[unfinished_index]
+            draft_input.unfinished_index = unfinished_index
+
+        logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
+        return (
+            draft_input,
+            logits_output,
+            verified_id,
+            finished_extend_len,
+            accept_length_cpu,
+        )
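
Annotation (not part of the diff): verify() applies a greedy acceptance rule, comparing each draft token along a tree path with the target model's argmax prediction for the preceding position and accepting the longest matching prefix. A self-contained sketch of that rule with made-up token ids (a single path, three draft tokens after the root):

import torch

candidates = torch.tensor([[10, 11, 12, 13]])     # [root, d1, d2, d3] along one tree path
target_predict = torch.tensor([[11, 12, 99, 7]])  # target argmax at each path position

# d_k is accepted only if it equals the target's prediction after d_{k-1}
accept = candidates[:, 1:] == target_predict[:, :-1]
accept_len = torch.cumprod(accept.int(), dim=1).sum(dim=1)
print(accept_len)  # -> 2: d1 and d2 accepted, d3 rejected
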