sglang 0.4.2.post1__py3-none-any.whl → 0.4.2.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. sglang/srt/constrained/outlines_backend.py +9 -1
  2. sglang/srt/custom_op.py +40 -0
  3. sglang/srt/entrypoints/engine.py +2 -2
  4. sglang/srt/function_call_parser.py +96 -69
  5. sglang/srt/layers/activation.py +10 -5
  6. sglang/srt/layers/attention/double_sparsity_backend.py +1 -3
  7. sglang/srt/layers/attention/flashinfer_backend.py +284 -39
  8. sglang/srt/layers/attention/triton_backend.py +124 -12
  9. sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
  10. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +337 -3
  11. sglang/srt/layers/attention/triton_ops/extend_attention.py +70 -42
  12. sglang/srt/layers/layernorm.py +1 -5
  13. sglang/srt/layers/moe/ep_moe/layer.py +1 -3
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
  22. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -13
  23. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
  24. sglang/srt/layers/moe/topk.py +4 -0
  25. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  26. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  28. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  32. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  33. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  34. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  36. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  38. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  40. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  42. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  44. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  46. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/fp8_kernel.py +173 -2
  48. sglang/srt/layers/rotary_embedding.py +1 -3
  49. sglang/srt/layers/sampler.py +4 -4
  50. sglang/srt/lora/backend/__init__.py +8 -0
  51. sglang/srt/lora/backend/base_backend.py +95 -0
  52. sglang/srt/lora/backend/flashinfer_backend.py +91 -0
  53. sglang/srt/lora/backend/triton_backend.py +61 -0
  54. sglang/srt/lora/lora.py +127 -112
  55. sglang/srt/lora/lora_manager.py +50 -18
  56. sglang/srt/lora/triton_ops/__init__.py +5 -0
  57. sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
  58. sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
  59. sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
  60. sglang/srt/model_executor/cuda_graph_runner.py +77 -80
  61. sglang/srt/model_executor/forward_batch_info.py +58 -59
  62. sglang/srt/model_executor/model_runner.py +2 -2
  63. sglang/srt/models/llama.py +8 -3
  64. sglang/srt/models/qwen2_vl.py +1 -1
  65. sglang/srt/server_args.py +13 -2
  66. sglang/srt/speculative/build_eagle_tree.py +486 -104
  67. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
  68. sglang/srt/speculative/eagle_utils.py +420 -401
  69. sglang/srt/speculative/eagle_worker.py +177 -45
  70. sglang/srt/utils.py +7 -0
  71. sglang/test/runners.py +2 -0
  72. sglang/version.py +1 -1
  73. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/METADATA +15 -6
  74. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/RECORD +77 -38
  75. sglang/srt/layers/custom_op_util.py +0 -25
  76. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/LICENSE +0 -0
  77. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/WHEEL +0 -0
  78. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/top_level.txt +0 -0
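The detailed hunks below all come from a single file, sglang/srt/speculative/eagle_utils.py (entry 68 above, +420 -401); the remaining files change only as summarized in the list.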
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import dataclasses
 from typing import TYPE_CHECKING, List
 
 import torch
@@ -9,201 +10,33 @@ import triton.language as tl
 from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
-from sglang.srt.speculative.spec_info import SpecInfo
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
-    from sglang.srt.server_args import ServerArgs
 
 
-@triton.jit
-def eagle_verify_retrive(
-    retrive_index,
-    accept_mask,
-    retrive_cum_len,
-    accept_index,
-    accept_length,
-    extract_index,
-    max_len: tl.constexpr,
-    draft_token_num: tl.constexpr,
-    max_len_upper: tl.constexpr,
-):
-    pid = tl.program_id(axis=0)
-
-    retrive_end = tl.load(retrive_cum_len + pid + 1)
-    retrive_start = tl.load(retrive_cum_len + pid)
-    retrive_len = retrive_end - retrive_start
-    accept_ptr = accept_mask + retrive_start
-    accept_offset = tl.arange(0, draft_token_num)
-    accept_load_mask = accept_offset < retrive_len
-    accept_len_list = tl.load(
-        accept_ptr + accept_offset, mask=accept_load_mask, other=-1
-    )
-
-    accept_len = tl.max(accept_len_list)
-    max_index = tl.argmax(accept_len_list, axis=0, tie_break_left=True)
-    # triton is not support argmax with tie_break_right, so I need implement it by some way
-    mask_max = accept_len_list == accept_len
-
-    count_mask = tl.full(shape=[draft_token_num], value=0, dtype=tl.int32)
-    count = tl.sum(tl.where(mask_max, 1, count_mask))
-    if count > 1:
-        index = tl.arange(0, draft_token_num)
-        mask_left = index != max_index
-        remained_index = tl.where(mask_max and mask_left, index, 0)
-        max_index = tl.max(remained_index)
-
-    tl.store(accept_length + pid, accept_len)
-    retrive_index_ptr = retrive_index + (retrive_start + max_index) * max_len
-    retrive_offset = tl.arange(0, max_len_upper)
-    retrive_load_mask = retrive_offset < accept_len + 1
-    data = tl.load(retrive_index_ptr + retrive_offset, mask=retrive_load_mask)
-
-    tl.store(
-        accept_index + pid * max_len + retrive_offset, data, mask=retrive_load_mask
-    )
-
-    extract_load_ptr = accept_index + pid * max_len + accept_len
-    if accept_len == max_len - 1:
-        extract_data = tl.load(extract_load_ptr - 1)
-        tl.store(extract_index + pid * 2, extract_data)
-        extract_data = tl.load(extract_load_ptr)
-        tl.store(extract_index + pid * 2 + 1, extract_data)
-
-    else:
-        extract_data = tl.load(extract_load_ptr)
-        tl.store(extract_index + pid * 2, extract_data)
-
-
-@triton.jit
-def create_extend_spec_info(
-    verified_id,
-    seq_len,
-    accept_len,
-    accept_len_cum,
-    positions,
-    new_verified_id,
-    accept_len_upper: tl.constexpr,
-):
-    pid = tl.program_id(axis=0)
-    offset = 0 if pid == 0 else tl.load(accept_len_cum + pid - 1)
-    seq_length = tl.load(seq_len + pid)
-    accept_length = tl.load(accept_len + pid)
-    positions_ptr = positions + offset
-    data = tl.arange(0, accept_len_upper)
-    mask = data < accept_length
-    tl.store(positions_ptr + data, seq_length - accept_length + data, mask)
-
-    offset = tl.load(accept_len_cum + pid) - 1
-    verified_id_data = tl.load(verified_id + offset)
-    tl.store(new_verified_id + pid, verified_id_data)
-
-
-@triton.jit
-def assign_req_to_token_pool(
-    req_pool_indices,
-    req_to_token,
-    start_offset,
-    end_offset,
-    out_cache_loc,
-    pool_len: tl.constexpr,
-    bs_upper: tl.constexpr,
-):
-    BLOCK_SIZE: tl.constexpr = 32
-    pid = tl.program_id(axis=0)
-    kv_start = tl.load(start_offset + pid)
-    kv_end = tl.load(end_offset + pid)
-    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
-
-    length_offset = tl.arange(0, bs_upper)
-    start = tl.load(start_offset + length_offset, mask=length_offset < pid)
-    end = tl.load(end_offset + length_offset, mask=length_offset < pid)
-    out_offset = tl.sum(end - start, axis=0)
-
-    out_cache_ptr = out_cache_loc + out_offset
-
-    save_offset = tl.arange(0, BLOCK_SIZE) + kv_start
-    load_offset = tl.arange(0, BLOCK_SIZE)
-
-    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-    for _ in range(num_loop):
-        mask = save_offset < kv_end
-        data = tl.load(out_cache_ptr + load_offset, mask=mask)
-        tl.store(token_pool + save_offset, data, mask=mask)
-        save_offset += BLOCK_SIZE
-        load_offset += BLOCK_SIZE
-
-
-@triton.jit
-def generate_draft_decode_kv_indices(
-    req_pool_indices,
-    req_to_token,
-    paged_kernel_lens,
-    kv_indices,
-    iters: tl.constexpr,
-    topk: tl.constexpr,
-    pool_len: tl.constexpr,
-    bs_upper: tl.constexpr,
-    iter_upper: tl.constexpr,
-):
-    BLOCK_SIZE: tl.constexpr = 128
-    bid = tl.program_id(axis=0)
-    topk_id = tl.program_id(axis=1)
-
-    load_offset = tl.arange(0, bs_upper)
-    seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid)
-    seq_len = tl.load(paged_kernel_lens + bid)
-    cum_seq_len = tl.sum(seq_lens)
-
-    kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
-    kv_ptr = kv_indices + kv_offset
-    token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len
-
-    kv_offset = tl.arange(0, BLOCK_SIZE)
-    num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
-    for _ in range(num_loop):
-        mask = kv_offset < seq_len
-        data = tl.load(token_pool_ptr + kv_offset, mask=mask)
-        tl.store(kv_ptr + kv_offset, data, mask=mask)
-        kv_offset += BLOCK_SIZE
-
-    extend_offset = tl.arange(0, iter_upper)
-    extend_data = tl.load(
-        token_pool_ptr + seq_len + tl.arange(0, iter_upper) * topk + topk_id,
-        mask=extend_offset < iters,
-    )
-    tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)
-
-
-class EAGLEDraftInput(SpecInfo):
-    def __init__(self):
-        self.prev_mode = ForwardMode.DECODE
+@dataclasses.dataclass
+class EagleDraftInput:
+    # The inputs for decode
+    # shape: (b, topk)
+    topk_p: torch.Tensor = None
+    topk_index: torch.Tensor = None
+    # shape: (b, hidden_size)
+    hidden_states: torch.Tensor = None
+    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL
 
-        self.scores: torch.Tensor = None
-        self.score_list: List[torch.Tensor] = []
-        self.token_list: List[torch.Tensor] = []
-        self.origin_score_list: List[torch.Tensor] = []  # used for sampling
-        self.parents_list: List[torch.Tensor] = []
-        self.cache_list: List[torch.Tenor] = []
-        self.iter = 0
+    # Inputs for extend
+    # shape: (b,)
+    verified_id: torch.Tensor = None
+    accept_length: torch.Tensor = None
+    accept_length_cpu: List[int] = None
 
-        # shape: (b, hidden_size)
-        self.hidden_states: torch.Tensor = None
-        # shape: (b,)
-        self.verified_id: torch.Tensor = None
-        # shape: (b, vocab_size)
-        self.sample_output: torch.Tensor = None
-
-        self.positions: torch.Tensor = None
-        self.accept_length: torch.Tensor = None
-        self.accept_length_cpu: List[int] = None
-
-    def load_server_args(self, server_args: ServerArgs):
-        self.topk: int = server_args.speculative_eagle_topk
-        self.num_verify_token: int = server_args.speculative_num_draft_tokens
-        self.spec_steps = server_args.speculative_num_steps
+    # Inputs for the attention backends
+    # shape: (b + 1,)
+    kv_indptr: torch.Tensor = None
+    kv_indices: torch.Tensor = None
 
     def prepare_for_extend(self, batch: ScheduleBatch):
         req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
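The headline change in the hunk above replaces the stateful `EAGLEDraftInput(SpecInfo)` class with a plain `dataclasses.dataclass`. A minimal sketch of the pattern is below; it uses field names from the diff, but the toy type and its dimensions are illustrative, not the real sglang class:

```python
import dataclasses
from typing import Optional

import torch


@dataclasses.dataclass
class DraftInputSketch:
    # Field names mirror the new EagleDraftInput; shapes as noted in the diff.
    topk_p: Optional[torch.Tensor] = None         # (b, topk)
    topk_index: Optional[torch.Tensor] = None     # (b, topk)
    hidden_states: Optional[torch.Tensor] = None  # (b, hidden_size)
    verified_id: Optional[torch.Tensor] = None    # (b,)

    def filter_batch(self, new_indices: torch.Tensor):
        # Same truncation rule as the new filter_batch later in the diff:
        # keep only the first len(new_indices) rows of every field.
        keep = len(new_indices)
        self.topk_p = self.topk_p[:keep]
        self.topk_index = self.topk_index[:keep]
        self.hidden_states = self.hidden_states[:keep]
        self.verified_id = self.verified_id[:keep]


info = DraftInputSketch(
    topk_p=torch.rand(4, 8),
    topk_index=torch.randint(0, 32000, (4, 8)),
    hidden_states=torch.rand(4, 16),
    verified_id=torch.arange(4),
)
info.filter_batch(torch.tensor([0, 1]))
assert info.topk_p.shape == (2, 8)
```

Compared with the old `__init__`-heavy class, the dataclass makes the per-batch speculative state explicit and lets callers construct partially filled instances with defaults.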
@@ -231,95 +64,12 @@ class EAGLEDraftInput(SpecInfo):
         assert len(batch.extend_lens) == 1
         batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
 
-    def filter_batch(
-        self,
-        new_indices: torch.Tensor,
-    ):
-        self.sample_output = self.sample_output[: len(new_indices)]
-        self.hidden_states = self.hidden_states[: len(new_indices)]
-        self.verified_id = self.verified_id[: len(new_indices)]
-
-    def prepare_for_decode(self, batch: ScheduleBatch):
-        prob = self.sample_output  # shape: (b * top_k, vocab) or (b, vocab)
-        top = torch.topk(prob, self.topk, dim=-1)
-        topk_index, topk_p = (
-            top.indices,
-            top.values,
-        )  # shape: (b * top_k, top_k) or (b, top_k)
-
-        if self.prev_mode.is_decode():
-            scores = torch.mul(
-                self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk)
-            )  # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
-            topk_cs = torch.topk(
-                scores.flatten(start_dim=1), self.topk, dim=-1
-            )  # (b, topk)
-            topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values
-
-            selected_input_index = topk_cs_index.flatten() // self.topk + torch.arange(
-                0, batch.batch_size() * self.topk, step=self.topk, device="cuda"
-            ).repeat_interleave(self.topk)
-
-            batch.spec_info.hidden_states = batch.spec_info.hidden_states[
-                selected_input_index, :
-            ]
-
-            topk_index = topk_index.reshape(-1, self.topk**2)
-            batch.input_ids = torch.gather(
-                topk_index, index=topk_cs_index, dim=1
-            ).flatten()
-            batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
-
-            self.scores = topk_cs_p
-            self.score_list.append(scores)  # (b, topk, topk)
-            self.token_list.append(topk_index)  # (b, topk * topk)
-            self.origin_score_list.append(topk_p.reshape(topk_index.shape))
-            self.parents_list.append(
-                topk_cs_index + (self.topk**2 * (self.iter - 1) + self.topk)
-            )  # shape: (b, topk)
-        else:
-            # ForwardMode.EXTEND or ForwardMode.DRAFT_EXTEND
-            batch.spec_info.hidden_states = (
-                batch.spec_info.hidden_states.repeat_interleave(self.topk, dim=0)
-            )
-
-            batch.input_ids = topk_index.flatten()
-            batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel())
-
-            self.scores = topk_p  # shape: (b, topk)
-            self.score_list.append(topk_p.unsqueeze(1))  # shape: (b, 1, topk)
-            self.token_list.append(topk_index)  # shape: (b, topk)
-            self.origin_score_list.append(topk_p)
-            self.parents_list.append(
-                torch.arange(-1, self.topk, dtype=torch.long, device="cuda")
-                .unsqueeze(0)
-                .repeat(self.scores.shape[0], 1)
-            )  # shape: (b, topk + 1)
-        self.cache_list.append(batch.out_cache_loc)
-        self.positions = (
-            batch.seq_lens[:, None]
-            + torch.full(
-                [1, self.topk], fill_value=self.iter, device="cuda", dtype=torch.long
-            )
-        ).flatten()
-
-        bs = len(batch.seq_lens)
-        assign_req_to_token_pool[(bs,)](
-            batch.req_pool_indices,
-            batch.req_to_token_pool.req_to_token,
-            batch.seq_lens + self.topk * self.iter,
-            batch.seq_lens + self.topk * (self.iter + 1),
-            batch.out_cache_loc,
-            batch.req_to_token_pool.req_to_token.shape[1],
-            triton.next_power_of_2(bs),
-        )
-        self.iter += 1
-
-    def prepare_extend_after_decode(self, batch: ScheduleBatch):
+    def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
         batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
         accept_length_cpu = batch.spec_info.accept_length_cpu
         batch.extend_lens = [x + 1 for x in accept_length_cpu]
         batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
+        batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
         seq_lens_cpu = batch.seq_lens.tolist()
 
         pt = 0
@@ -348,86 +98,13 @@ class EAGLEDraftInput(SpecInfo):
             torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
             self.positions,
             new_verified_id,
-            triton.next_power_of_2(self.spec_steps + 1),
+            triton.next_power_of_2(speculative_num_steps + 1),
         )
 
         batch.seq_lens_sum = sum(seq_lens_cpu)
         batch.input_ids = self.verified_id
         self.verified_id = new_verified_id
 
-    def prepare_for_verify(self, batch: ScheduleBatch):
-        score_list = torch.cat(self.score_list, dim=1).flatten(
-            1
-        )  # b, n, topk; n= 1+(self.iter-1)*self.topk
-        ss_token_list = torch.cat(
-            self.token_list, dim=1
-        )  # b, (self.topk+(self.iter-1)*self.topk)
-        origin_token_list = torch.cat(self.origin_score_list, dim=1)
-        top_scores = torch.topk(score_list, self.num_verify_token - 1, dim=-1)
-        top_scores_index = top_scores.indices
-        top_scores_index = torch.sort(top_scores_index).values
-
-        draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
-        scores = torch.gather(origin_token_list, index=top_scores_index, dim=1)
-        draft_tokens = torch.cat((self.verified_id.unsqueeze(1), draft_tokens), dim=1)
-        parent_list = torch.cat(self.parents_list[:-1], dim=1)
-
-        tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel(
-            parent_list,
-            top_scores_index,
-            batch.seq_lens,
-            self.topk,
-            self.iter - 1,
-            self.num_verify_token,
-        )
-
-        return EagleVerifyInput(
-            draft_tokens.flatten(),
-            scores.flatten(),
-            tree_mask,
-            position,
-            retrive_index,
-            retrive_cum_len,
-            self.num_verify_token,
-        )
-
-    def generate_attn_arg_decode(
-        self,
-        req_pool_indices: torch.Tensor,
-        paged_kernel_lens: torch.Tensor,
-        req_to_token: torch.Tensor,
-    ):
-        seq_num = req_pool_indices.numel()
-        bs = self.topk * req_pool_indices.numel()
-        seq_len = self.positions.reshape(-1).contiguous()
-
-        cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
-        cum_kv_seq_len[1:] = torch.cumsum(seq_len + 1, dim=0)
-        total_len = torch.sum(paged_kernel_lens).item()
-
-        kv_indices = torch.empty(
-            (total_len * self.topk + seq_num * self.iter * self.topk,),
-            dtype=torch.int32,
-            device="cuda",
-        )
-
-        generate_draft_decode_kv_indices[(req_pool_indices.numel(), self.topk)](
-            req_pool_indices,
-            req_to_token,
-            paged_kernel_lens,
-            kv_indices,
-            self.iter,
-            self.topk,
-            req_to_token.shape[1],
-            triton.next_power_of_2(seq_num),
-            triton.next_power_of_2(self.spec_steps),
-        )
-        return bs, kv_indices, cum_kv_seq_len
-
-    def clear_draft_cache(self, batch):
-        draft_cache = torch.cat(self.cache_list, dim=0)
-        batch.token_to_kv_pool.free(draft_cache)
-
 
     def generate_attn_arg_prefill(
         self,
         req_pool_indices: torch.Tensor,
@@ -454,12 +131,18 @@ class EAGLEDraftInput(SpecInfo):
 
         return kv_indices, cum_kv_seq_len, qo_indptr, None
 
-    def merge_batch(self, spec_info: EAGLEDraftInput):
+    def filter_batch(self, new_indices: torch.Tensor):
+        self.topk_p = self.topk_p[: len(new_indices)]
+        self.topk_index = self.topk_index[: len(new_indices)]
+        self.hidden_states = self.hidden_states[: len(new_indices)]
+        self.verified_id = self.verified_id[: len(new_indices)]
+
+    def merge_batch(self, spec_info: EagleDraftInput):
         if self.hidden_states is None:
             self.hidden_states = spec_info.hidden_states
             self.verified_id = spec_info.verified_id
-            self.sample_output = spec_info.sample_output
-            self.prev_mode = spec_info.prev_mode
+            self.topk_p = spec_info.topk_p
+            self.topk_index = spec_info.topk_index
             return
         if spec_info.hidden_states is None:
             return
@@ -467,32 +150,60 @@ class EAGLEDraftInput(SpecInfo):
             [self.hidden_states, spec_info.hidden_states], axis=0
         )
         self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
-        self.sample_output = torch.cat([self.sample_output, spec_info.sample_output])
-
-
-class EagleVerifyInput(SpecInfo):
-    def __init__(
-        self,
-        draft_token: torch.Tensor,
-        draft_score: torch.Tensor,
-        tree_mask: torch.Tensor,
-        positions: torch.Tensor,
-        retrive_index: torch.Tensor,
-        retrive_cum_len: torch.Tensor,
-        draft_token_num: int,
+        self.topk_p = torch.cat([self.topk_p, spec_info.topk_p])
+        self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
+
+
+@dataclasses.dataclass
+class EagleVerifyInput:
+    draft_token: torch.Tensor
+    custom_mask: torch.Tensor
+    positions: torch.Tensor
+    retrive_index: torch.Tensor
+    retrive_cum_len: torch.Tensor
+    draft_token_num: int
+    capture_hidden_mode: CaptureHiddenMode
+
+    @classmethod
+    def create(
+        cls,
+        verified_id: torch.Tensor,
+        score_list: List[torch.Tensor],
+        token_list: List[torch.Tensor],
+        parents_list: List[torch.Tensor],
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        topk: int,
+        spec_steps: int,
+        num_verify_token: int,
     ):
-        self.draft_token = draft_token
-        self.draft_score = draft_score
-        self.custom_mask = tree_mask
-        self.positions = positions
-        self.retrive_index = retrive_index
-        self.retrive_cum_len = retrive_cum_len
-        self.draft_token_num = draft_token_num
+        tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
+            build_tree_kernel(
+                verified_id,
+                score_list,
+                token_list,
+                parents_list,
+                seq_lens,
+                seq_lens_sum,
+                topk,
+                spec_steps,
+                num_verify_token,
+            )
+        )
+        return cls(
+            draft_tokens,
+            tree_mask,
+            position,
+            retrive_index,
+            retrive_cum_len,
+            num_verify_token,
+            CaptureHiddenMode.FULL,
+        )
 
     def prepare_for_verify(self, batch: ScheduleBatch):
         batch.input_ids = self.draft_token
         batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
-        bs = batch.seq_lens.numel()
+        bs = batch.batch_size()
         assign_req_to_token_pool[(bs,)](
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
@@ -539,41 +250,78 @@ class EagleVerifyInput(SpecInfo):
         return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask
 
     def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor:
-        predict = torch.argmax(logits_output.next_token_logits, dim=-1)
-        predict = torch.cat(
-            [predict, torch.full([1], -1, dtype=torch.long, device="cuda")], dim=-1
-        )
         draft_token = torch.cat(
-            [self.draft_token, torch.full([1], -1, dtype=torch.long, device="cuda")],
+            [self.draft_token, torch.full([1], -1, dtype=torch.int32, device="cuda")],
             dim=-1,
         )
-        target_predict = predict[self.retrive_index]
         candidates = draft_token[self.retrive_index]
-        # logits = logits_output.next_token_logits[self.retrive_index]
-        # target_predict = torch.argmax(logits[:, :-1], dim=-1)
-        accept_mask = candidates[:, 1:] == target_predict[:, :-1]
-        accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1)
-        bs = self.retrive_cum_len.numel() - 1
-
-        max_draft_len = self.retrive_index.shape[-1]
-        accept_index = torch.full(
-            (bs, max_draft_len), -1, dtype=torch.long, device="cuda"
-        )
-        accept_length = torch.empty((bs,), dtype=torch.int, device="cuda")
-        extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda")
-        eagle_verify_retrive[(bs,)](
-            self.retrive_index.contiguous(),
-            accept_mask.contiguous(),
-            self.retrive_cum_len,
-            accept_index,
-            accept_length,
-            extract_index,
-            max_draft_len,
-            self.draft_token_num,
-            triton.next_power_of_2(max_draft_len),
-        )
+        if batch.sampling_info.is_all_greedy:
+            # temp == 0
+            bs = self.retrive_cum_len.numel() - 1
+            predict = torch.argmax(logits_output.next_token_logits, dim=-1)
+            predict = torch.cat(
+                [predict, torch.full([1], -1, dtype=torch.int32, device="cuda")], dim=-1
+            )
+            target_predict = predict[self.retrive_index]
+            # logits = logits_output.next_token_logits[self.retrive_index]
+            # target_predict = torch.argmax(logits[:, :-1], dim=-1)
+            accept_mask = candidates[:, 1:] == target_predict[:, :-1]
+
+            accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1)
+            max_draft_len = self.retrive_index.shape[-1]
+            accept_index = torch.full(
+                (bs, max_draft_len), -1, dtype=torch.int32, device="cuda"
+            )
+            accept_length = torch.empty((bs,), dtype=torch.int, device="cuda")
+            extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda")
+            eagle_verify_retrive[(bs,)](
+                self.retrive_index.contiguous(),
+                accept_mask.contiguous(),
+                self.retrive_cum_len,
+                accept_index,
+                accept_length,
+                extract_index,
+                max_draft_len,
+                self.draft_token_num,
+                triton.next_power_of_2(max_draft_len),
+            )
+        else:
+            # temp > 0
+            bs = self.retrive_index.shape[0]
+            predict_shape = list(logits_output.next_token_logits.shape)[:-1]
+            predict_shape[-1] += 1
+            target_logits = logits_output.next_token_logits[self.retrive_index]
+            predict = torch.full(predict_shape, -1, dtype=torch.int32, device="cuda")
+            accept_index = torch.full(
+                (bs, self.spec_steps + 1), -1, dtype=torch.int32, device="cuda"
+            )
+            accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
+            expanded_temperature = batch.sampling_info.temperatures.unsqueeze(1)
+            target_probs = F.softmax(target_logits / expanded_temperature, dim=-1)
+            draft_probs = torch.full_like(
+                target_probs, 0, dtype=torch.float32, device="cuda"
+            )
+            coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
+            tree_speculative_sampling_target_only(
+                predicts=predict,  # mutable
+                accept_index=accept_index,  # mutable
+                accept_token_num=accept_length,  # mutable
+                candidates=candidates.to(torch.int32),
+                retrive_index=self.retrive_index.to(torch.int32),
+                retrive_next_token=self.retrive_next_token.to(torch.int32),
+                retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
+                uniform_samples=coins,
+                target_probs=target_probs,
+                draft_probs=draft_probs,
+                threshold_single=global_server_args_dict[
+                    "speculative_accept_threshold_single"
+                ],
+                threshold_acc=global_server_args_dict[
+                    "speculative_accept_threshold_acc"
+                ],
+                deterministic=True,
+            )
 
-        draft_input = EAGLEDraftInput()
         new_accept_index = []
         unfinished_index = []
         finished_extend_len = {}  # {rid:accept_length + 1}
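For intuition, the greedy branch above accepts, along each draft path, the longest prefix of draft tokens that matches the target model's argmax predictions; the `cumprod(...).sum(...)` idiom computes that prefix length. A toy reproduction with invented values (the `.int()` cast is added here for clarity, the diff applies `cumprod` to the boolean mask directly):

```python
import torch

# One draft path: the verified token followed by three draft tokens,
# and the target model's argmax prediction at each of those positions.
candidates = torch.tensor([[7, 3, 5, 9]])
target_predict = torch.tensor([[3, 5, 2, 8]])

# Draft token i+1 is a match if it equals the target's prediction at position i.
accept_mask = candidates[:, 1:] == target_predict[:, :-1]  # [[True, True, False]]

# cumprod zeroes out everything after the first mismatch, so the sum is the
# length of the longest accepted prefix (2 extra tokens here).
accept_len = torch.cumprod(accept_mask.int(), dim=1).sum(dim=1)
print(accept_len)  # tensor([2])
```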
@@ -625,18 +373,23 @@ class EagleVerifyInput(SpecInfo):
         )
         batch.seq_lens.add_(accept_length + 1)
 
+        draft_input = EagleDraftInput()
         if len(new_accept_index) > 0:
             new_accept_index = torch.tensor(new_accept_index, device="cuda")
-            draft_input.verified_id = predict[new_accept_index]
             draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
+            draft_input.verified_id = predict[new_accept_index]
             draft_input.accept_length = accept_length[unfinished_index]
             draft_input.accept_length_cpu = [
                 accept_length_cpu[i] for i in unfinished_index
             ]
             if has_finished:
                 draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
+                draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
+                    unfinished_index
+                ]
             else:
                 draft_input.seq_lens_for_draft_extend = batch.seq_lens
+                draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
 
         logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
         return (
@@ -646,3 +399,269 @@ class EagleVerifyInput(SpecInfo):
             finished_extend_len,
             accept_length_cpu,
         )
+
+
+@triton.jit
+def eagle_verify_retrive(
+    retrive_index,
+    accept_mask,
+    retrive_cum_len,
+    accept_index,
+    accept_length,
+    extract_index,
+    max_len: tl.constexpr,
+    draft_token_num: tl.constexpr,
+    max_len_upper: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+
+    retrive_end = tl.load(retrive_cum_len + pid + 1)
+    retrive_start = tl.load(retrive_cum_len + pid)
+    retrive_len = retrive_end - retrive_start
+    accept_ptr = accept_mask + retrive_start
+    accept_offset = tl.arange(0, draft_token_num)
+    accept_load_mask = accept_offset < retrive_len
+    accept_len_list = tl.load(
+        accept_ptr + accept_offset, mask=accept_load_mask, other=-1
+    )
+
+    accept_len = tl.max(accept_len_list)
+    max_index = tl.argmax(accept_len_list, axis=0, tie_break_left=True)
+    # triton is not support argmax with tie_break_right, so I need implement it by some way
+    mask_max = accept_len_list == accept_len
+
+    count_mask = tl.full(shape=[draft_token_num], value=0, dtype=tl.int32)
+    count = tl.sum(tl.where(mask_max, 1, count_mask))
+    if count > 1:
+        index = tl.arange(0, draft_token_num)
+        mask_left = index != max_index
+        remained_index = tl.where(mask_max and mask_left, index, 0)
+        max_index = tl.max(remained_index)
+
+    tl.store(accept_length + pid, accept_len)
+    retrive_index_ptr = retrive_index + (retrive_start + max_index) * max_len
+    retrive_offset = tl.arange(0, max_len_upper)
+    retrive_load_mask = retrive_offset < accept_len + 1
+    data = tl.load(retrive_index_ptr + retrive_offset, mask=retrive_load_mask)
+
+    tl.store(
+        accept_index + pid * max_len + retrive_offset, data, mask=retrive_load_mask
+    )
+
+    extract_load_ptr = accept_index + pid * max_len + accept_len
+    if accept_len == max_len - 1:
+        extract_data = tl.load(extract_load_ptr - 1)
+        tl.store(extract_index + pid * 2, extract_data)
+        extract_data = tl.load(extract_load_ptr)
+        tl.store(extract_index + pid * 2 + 1, extract_data)
+
+    else:
+        extract_data = tl.load(extract_load_ptr)
+        tl.store(extract_index + pid * 2, extract_data)
+
+
+@triton.jit
+def create_extend_spec_info(
+    verified_id,
+    seq_len,
+    accept_len,
+    accept_len_cum,
+    positions,
+    new_verified_id,
+    accept_len_upper: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    offset = 0 if pid == 0 else tl.load(accept_len_cum + pid - 1)
+    seq_length = tl.load(seq_len + pid)
+    accept_length = tl.load(accept_len + pid)
+    positions_ptr = positions + offset
+    data = tl.arange(0, accept_len_upper)
+    mask = data < accept_length
+    tl.store(positions_ptr + data, seq_length - accept_length + data, mask)
+
+    offset = tl.load(accept_len_cum + pid) - 1
+    verified_id_data = tl.load(verified_id + offset)
+    tl.store(new_verified_id + pid, verified_id_data)
+
+
+@triton.jit
+def assign_req_to_token_pool(
+    req_pool_indices,
+    req_to_token,
+    start_offset,
+    end_offset,
+    out_cache_loc,
+    pool_len: tl.constexpr,
+    bs_upper: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 32
+    pid = tl.program_id(axis=0)
+    kv_start = tl.load(start_offset + pid)
+    kv_end = tl.load(end_offset + pid)
+    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
+
+    length_offset = tl.arange(0, bs_upper)
+    start = tl.load(start_offset + length_offset, mask=length_offset < pid)
+    end = tl.load(end_offset + length_offset, mask=length_offset < pid)
+    out_offset = tl.sum(end - start, axis=0)
+
+    out_cache_ptr = out_cache_loc + out_offset
+
+    save_offset = tl.arange(0, BLOCK_SIZE) + kv_start
+    load_offset = tl.arange(0, BLOCK_SIZE)
+
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for _ in range(num_loop):
+        mask = save_offset < kv_end
+        data = tl.load(out_cache_ptr + load_offset, mask=mask)
+        tl.store(token_pool + save_offset, data, mask=mask)
+        save_offset += BLOCK_SIZE
+        load_offset += BLOCK_SIZE
+
+
+@triton.jit
+def assign_draft_cache_locs(
+    req_pool_indices,
+    req_to_token,
+    seq_lens,
+    out_cache_loc,
+    pool_len: tl.constexpr,
+    topk: tl.constexpr,
+    speculative_num_steps: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 32
+    pid = tl.program_id(axis=0)
+    kv_start = tl.load(seq_lens + pid)
+    kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
+    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
+    out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
+
+    num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE)
+    for i in range(num_loop):
+        save_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + kv_start
+        load_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = save_offset < kv_end
+        data = tl.load(out_cache_ptr + load_offset, mask=mask)
+        tl.store(token_pool + save_offset, data, mask=mask)
+
+
+@triton.jit
+def generate_draft_decode_kv_indices(
+    req_pool_indices,
+    req_to_token,
+    paged_kernel_lens,
+    kv_indices,
+    kv_indptr,
+    positions,
+    num_seqs: tl.constexpr,
+    topk: tl.constexpr,
+    pool_len: tl.constexpr,
+    kv_indices_stride: tl.constexpr,
+    kv_indptr_stride: tl.constexpr,
+    bs_upper: tl.constexpr,
+    iter_upper: tl.constexpr,
+    num_tokens_upper: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 128
+    iters = tl.program_id(axis=0)
+    bid = tl.program_id(axis=1)
+    topk_id = tl.program_id(axis=2)
+
+    kv_indices += kv_indices_stride * iters
+    kv_indptr += kv_indptr_stride * iters
+    iters += 1
+
+    load_offset = tl.arange(0, bs_upper)
+    seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid)
+    seq_len = tl.load(paged_kernel_lens + bid)
+    cum_seq_len = tl.sum(seq_lens)
+
+    kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
+    kv_ptr = kv_indices + kv_offset
+    token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len
+
+    kv_offset = tl.arange(0, BLOCK_SIZE)
+    num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
+    for _ in range(num_loop):
+        mask = kv_offset < seq_len
+        data = tl.load(token_pool_ptr + kv_offset, mask=mask)
+        tl.store(kv_ptr + kv_offset, data, mask=mask)
+        kv_offset += BLOCK_SIZE
+
+    extend_offset = tl.arange(0, iter_upper)
+    extend_data = tl.load(
+        token_pool_ptr + seq_len + tl.arange(0, iter_upper) * topk + topk_id,
+        mask=extend_offset < iters,
+    )
+    tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)
+
+    # Update kv_indptr
+    bs_offset = tl.arange(0, num_tokens_upper)
+
+    zid = bid * topk + topk_id
+    if zid == 0:
+        zid = num_seqs * topk
+    positions = tl.load(positions + bs_offset, mask=bs_offset < zid)
+    base = tl.sum(positions)
+    tl.store(kv_indptr + zid, base + zid * iters)
+
+
+@torch.compile
+def select_top_k_tokens(
+    i: int,
+    topk_p: torch.Tensor,
+    topk_index: torch.Tensor,
+    hidden_states: torch.Tensor,
+    scores: torch.Tensor,
+    topk: int,
+):
+    if i == 0:
+        # The first step after extend
+        input_ids = topk_index.flatten()
+        hidden_states = hidden_states.repeat_interleave(topk, dim=0)
+        scores = topk_p  # shape: (b, topk)
+
+        tree_info = (
+            topk_p.unsqueeze(1),  # shape: (b, 1, topk)
+            topk_index,  # shape: (b, topk)
+            torch.arange(-1, topk, dtype=torch.long, device="cuda")
+            .unsqueeze(0)
+            .repeat(topk_p.shape[0], 1),  # shape: (b, topk + 1)
+        )
+
+
+    else:
+        # The later decode steps
+        expand_scores = torch.mul(
+            scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
+        )  # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
+
+        topk_cs_p, topk_cs_index = fast_topk(
+            expand_scores.flatten(start_dim=1), topk, dim=-1
+        )  # (b, topk)
+        scores = topk_cs_p  # shape: (b, topk)
+
+        topk_index = topk_index.reshape(-1, topk**2)
+        input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()
+
+        selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
+            0, hidden_states.shape[0], step=topk, device="cuda"
+        ).repeat_interleave(topk)
+        hidden_states = hidden_states[selected_input_index, :]
+
+        tree_info = (
+            expand_scores,  # shape: (b, topk, topk)
+            topk_index,  # shape: (b, topk * topk)
+            topk_cs_index + (topk**2 * (i - 1) + topk),  # shape: (b, topk)
+        )
+
+    return input_ids, hidden_states, scores, tree_info
+
+
+def fast_topk(values, topk, dim):
+    if topk == 1:
+        # Use max along the specified dimension to get both value and index
+        max_value, max_index = torch.max(values, dim=dim)
+        return max_value.unsqueeze(1), max_index.unsqueeze(1)
+    else:
+        # Use topk for efficiency with larger k values
+        return torch.topk(values, topk, dim=dim)
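The new `fast_topk` helper at the end of the file special-cases k == 1 with `torch.max`, which for a 2-D input and `dim=-1` returns the same (values, indices) pair as `torch.topk`, without the general top-k machinery. A quick standalone check of that equivalence (the helper body is copied from the diff into a toy harness; this is not sglang code):

```python
import torch


def fast_topk(values, topk, dim):
    # Same logic as the helper added in this diff.
    if topk == 1:
        max_value, max_index = torch.max(values, dim=dim)
        return max_value.unsqueeze(1), max_index.unsqueeze(1)
    return torch.topk(values, topk, dim=dim)


scores = torch.rand(4, 8)
v_fast, i_fast = fast_topk(scores, 1, dim=-1)
v_ref, i_ref = torch.topk(scores, 1, dim=-1)
assert torch.equal(v_fast, v_ref) and torch.equal(i_fast, i_ref)
```

This matters in `select_top_k_tokens`, where the top-k selection runs once per speculative decode step and `topk == 1` is a common configuration.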