sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff covers the content of publicly available package versions as released to their public registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +26 -4
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +676 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +49 -8
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +42 -8
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +78 -13
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +133 -55
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +434 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +41 -19
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +60 -20
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +80 -53
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +25 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -19
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +78 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +87 -33
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +67 -30
- sglang/srt/lora/mem_pool.py +117 -52
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +18 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +43 -5
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -30
- sglang/srt/managers/scheduler.py +290 -31
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -24
- sglang/srt/managers/tp_worker.py +4 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +36 -21
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +75 -8
- sglang/srt/model_loader/loader.py +171 -3
- sglang/srt/model_loader/weight_utils.py +51 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +31 -88
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +329 -73
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +694 -0
- sglang/srt/models/gemma3_mm.py +468 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +201 -104
- sglang/srt/openai_api/protocol.py +33 -7
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +114 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +140 -54
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +215 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +29 -2
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +56 -5
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/build_eagle_tree.py  +7 -347

@@ -3,8 +3,13 @@
 from typing import List
 
 import torch
-from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
-from sgl_kernel import build_tree_kernel_efficient as sgl_build_tree_kernel_efficient
+
+from sglang.srt.utils import is_cuda_available, is_hip
+
+if is_cuda_available() or is_hip():
+    from sgl_kernel import (
+        build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
+    )
 
 
 def build_tree_kernel_efficient_preprocess(
@@ -23,7 +28,6 @@ def build_tree_kernel_efficient_preprocess(
     top_scores = torch.topk(score_list, num_verify_tokens - 1, dim=-1)
     top_scores_index = top_scores.indices
     top_scores_index = torch.sort(top_scores_index).values
-
     draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
     draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
 
@@ -108,296 +112,6 @@ def build_tree_kernel_efficient(
     )
 
 
-def build_tree_kernel(
-    verified_id: torch.Tensor,
-    score_list: List[torch.Tensor],
-    token_list: List[torch.Tensor],
-    parents_list: List[torch.Tensor],
-    seq_lens: torch.Tensor,
-    seq_lens_sum: int,
-    topk: int,
-    spec_steps: int,
-    num_verify_tokens: int,
-):
-    parent_list, top_scores_index, draft_tokens = (
-        build_tree_kernel_efficient_preprocess(
-            verified_id,
-            score_list,
-            token_list,
-            parents_list,
-            num_verify_tokens,
-        )
-    )
-
-    bs = seq_lens.numel()
-    device = seq_lens.device
-
-    tree_mask = torch.full(
-        (
-            seq_lens_sum * num_verify_tokens
-            + num_verify_tokens * num_verify_tokens * bs,
-        ),
-        True,
-        device=device,
-    )
-    retrive_index = torch.full(
-        (bs, num_verify_tokens, spec_steps + 2), -1, device=device, dtype=torch.long
-    )
-    positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
-
-    sgl_build_tree_kernel(
-        parent_list,
-        top_scores_index,
-        seq_lens.to(torch.int32),
-        tree_mask,
-        positions,
-        retrive_index,
-        topk,
-        spec_steps,
-        num_verify_tokens,
-    )
-
-    index = retrive_index.sum(dim=-1) != -spec_steps - 2
-    cum_len = torch.cumsum(torch.sum(index, dim=-1), dim=-1)
-    retrive_cum_len = torch.zeros(
-        (cum_len.numel() + 1,), dtype=torch.int32, device="cuda"
-    )
-    retrive_cum_len[1:] = cum_len
-    # TODO: this indexing cause a synchronization, optimize this
-    retrive_index = retrive_index[index]
-    return tree_mask, positions, retrive_index, retrive_cum_len, draft_tokens
-
-
-def test_build_tree_kernel():
-    def findp(p_i, index, parent_list):
-        pos = index // 10
-        index_list = index.tolist()
-        parent_list = parent_list.tolist()
-        res = [p_i]
-        while True:
-            p = pos[p_i]
-            if p == 0:
-                break
-            token_idx = parent_list[p]
-            p_i = index_list.index(token_idx)
-            res.append(p_i)
-        return res
-
-    def create_mask(seq_len, draft_token, index, parent_list, max_depth):
-        mask = []
-        positions = []
-        retrive_index = []
-        for i, lens in enumerate(seq_len.tolist()):
-            first_mask = torch.full((lens + draft_token,), True)
-            first_mask[-(draft_token - 1) :] = False
-            positions.append(lens)
-            mask.append(first_mask)
-            seq_order = []
-            first_index = torch.Tensor([0] + [-1] * (depth + 1)).cuda().to(torch.long)
-            r_index = [first_index]
-            for j in range(draft_token - 1):
-                mask.append(torch.full((lens + 1,), True))
-                idx = findp(j, index, parent_list)
-
-                seq_order.append(idx)
-                positions.append(len(idx) + seq_len)
-                t = torch.full((draft_token - 1,), False)
-                t[idx] = True
-                mask.append(t)
-
-            for i in range(1, draft_token - 1):
-                is_leaf = 0
-                for j in range(draft_token - 1):
-                    if i in seq_order[j]:
-                        is_leaf += 1
-
-                if is_leaf == 1:
-                    order_list = [0] + [x + 1 for x in seq_order[i][::-1]]
-                    for _ in range(max_depth + 1 - len(seq_order[i])):
-                        order_list.append(-1)
-                    order = torch.Tensor(order_list).cuda().to(torch.long)
-                    r_index.append(order)
-            retrive_index.append(torch.stack(r_index))
-
-        return (
-            torch.cat(mask).cuda(),
-            torch.Tensor(positions).cuda().to(torch.long),
-            torch.stack(retrive_index),
-        )
-
-    index = (
-        torch.Tensor(
-            [
-                0,
-                1,
-                2,
-                3,
-                10,
-                11,
-                12,
-                13,
-                20,
-                21,
-                22,
-                30,
-                110,
-                130,
-                150,
-                160,
-                210,
-                211,
-                212,
-                213,
-                214,
-                215,
-                216,
-                217,
-                218,
-                219,
-                220,
-                230,
-                310,
-                311,
-                312,
-                313,
-                314,
-                315,
-                316,
-                317,
-                320,
-                321,
-                322,
-                330,
-                360,
-                380,
-                390,
-                410,
-                411,
-                412,
-                413,
-                414,
-                415,
-                416,
-                417,
-                418,
-                419,
-                420,
-                421,
-                422,
-                423,
-                430,
-                431,
-                440,
-                441,
-                460,
-                470,
-            ]
-        )
-        .to(torch.long)
-        .cuda()
-    )
-
-    parent_list = (
-        torch.Tensor(
-            [
-                -1,
-                0,
-                1,
-                2,
-                3,
-                4,
-                5,
-                6,
-                7,
-                8,
-                9,
-                10,
-                11,
-                12,
-                20,
-                30,
-                21,
-                13,
-                22,
-                40,
-                23,
-                110,
-                130,
-                160,
-                150,
-                190,
-                120,
-                111,
-                121,
-                200,
-                180,
-                210,
-                211,
-                212,
-                213,
-                214,
-                215,
-                216,
-                220,
-                230,
-                217,
-                310,
-                311,
-                312,
-                313,
-                320,
-                314,
-                321,
-                315,
-                316,
-                317,
-            ]
-        )
-        .to(torch.long)
-        .cuda()
-    )
-
-    verified_seq_len = torch.Tensor([47]).to(torch.long).cuda()
-    bs = verified_seq_len.shape[0]
-    topk = 10
-    depth = 5  # depth <= 10
-    num_draft_token = 64
-
-    tree_mask = torch.full(
-        (
-            torch.sum(verified_seq_len).item() * num_draft_token
-            + num_draft_token * num_draft_token * bs,
-        ),
-        True,
-    ).cuda()
-    retrive_index = torch.full(
-        (bs, num_draft_token, depth + 2), -1, device="cuda", dtype=torch.long
-    )
-    positions = torch.empty((bs * num_draft_token,), device="cuda", dtype=torch.long)
-
-    sgl_build_tree_kernel(
-        parent_list.unsqueeze(0),
-        index.unsqueeze(0),
-        verified_seq_len,
-        tree_mask,
-        positions,
-        retrive_index,
-        topk,
-        depth,
-        num_draft_token,
-    )
-
-    retrive_index = retrive_index[retrive_index.sum(dim=-1) != -depth - 2]
-
-    c_mask, c_positions, c_retive_index = create_mask(
-        verified_seq_len, num_draft_token, index, parent_list, depth
-    )
-
-    assert torch.allclose(tree_mask, c_mask), "tree mask has error."
-    assert torch.allclose(positions, c_positions), "positions has error."
-    assert torch.allclose(retrive_index, c_retive_index), "retrive_index has error."
-
-
 def test_build_tree_kernel_efficient():
     verified_id = torch.tensor([29974, 13], device="cuda", dtype=torch.int32)
     score_list = [
@@ -611,59 +325,6 @@ def test_build_tree_kernel_efficient():
     depth = 4
     num_draft_token = 8
 
-    tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
-        build_tree_kernel(
-            verified_id=verified_id,
-            score_list=score_list,
-            token_list=token_list,
-            parents_list=parents_list,
-            seq_lens=seq_lens,
-            seq_lens_sum=torch.sum(seq_lens).item(),
-            topk=topk,
-            spec_steps=depth,
-            num_verify_tokens=num_draft_token,
-        )
-    )
-
-    from sglang.srt.utils import first_rank_print
-
-    first_rank_print("=========== build tree kernel ==========")
-    # first_rank_print(f"{tree_mask=}", flush=True)
-    first_rank_print(f"{position=}", flush=True)
-    first_rank_print(f"{retrive_index=}", flush=True)
-    first_rank_print(f"{retrive_cum_len=}", flush=True)
-    first_rank_print(f"{draft_tokens=}", flush=True)
-    assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
-    assert retrive_index.tolist() == [
-        [0, -1, -1, -1, -1, -1],
-        [0, 2, 4, 6, -1, -1],
-        [0, 1, 3, 5, 7, -1],
-        [8, -1, -1, -1, -1, -1],
-        [8, 9, 10, -1, -1, -1],
-        [8, 9, 12, -1, -1, -1],
-        [8, 9, 13, -1, -1, -1],
-        [8, 9, 11, 14, 15, -1],
-    ]
-    assert retrive_cum_len.tolist() == [0, 3, 8]
-    assert draft_tokens.tolist() == [
-        29974,
-        29896,
-        29906,
-        29889,
-        29974,
-        29946,
-        29896,
-        29946,
-        13,
-        13,
-        22550,
-        4136,
-        16492,
-        8439,
-        29871,
-        29941,
-    ]
-
     (
         tree_mask,
         position,
@@ -725,4 +386,3 @@ def test_build_tree_kernel_efficient():
 
 if __name__ == "__main__":
     test_build_tree_kernel_efficient()
-    test_build_tree_kernel()
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py  +41 -5

@@ -22,6 +22,10 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_worker import EAGLEWorker
 
+import logging
+
+logger = logging.getLogger(__name__)
+
 
 class EAGLEDraftCudaGraphRunner:
     def __init__(self, eagle_worker: EAGLEWorker):
@@ -33,13 +37,10 @@ class EAGLEDraftCudaGraphRunner:
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.tp_size = self.model_runner.tp_size
-        self.dp_size = model_runner.server_args.dp_size
         self.topk = model_runner.server_args.speculative_eagle_topk
         self.speculative_num_steps = model_runner.server_args.speculative_num_steps
         server_args = model_runner.server_args
 
-        assert self.disable_padding
-
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         self.num_tokens_per_bs = server_args.speculative_eagle_topk
@@ -51,6 +52,9 @@ class EAGLEDraftCudaGraphRunner:
         self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
             0
         ].get_cuda_graph_seq_len_fill_value()
+        self.seq_lens_cpu = torch.full(
+            (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
+        )
 
         if self.enable_torch_compile:
             set_torch_compile_config()
@@ -169,6 +173,13 @@ class EAGLEDraftCudaGraphRunner:
         set_global_graph_memory_pool(graph.pool())
         return graph, out
 
+    def _postprocess_output_to_raw_bs(self, out, raw_bs):
+        score_list, token_list, parents_list = out
+        score_list = [x[:raw_bs] for x in score_list]
+        token_list = [x[:raw_bs] for x in token_list]
+        parents_list = [x[:raw_bs] for x in parents_list]
+        return (score_list, token_list, parents_list)
+
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
         raw_bs = forward_batch.batch_size
@@ -180,6 +191,9 @@ class EAGLEDraftCudaGraphRunner:
         if bs != raw_bs:
             self.seq_lens.fill_(1)
             self.out_cache_loc.zero_()
+            self.positions.zero_()
+
+        num_tokens = bs * self.num_tokens_per_bs
 
         # Common inputs
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
@@ -193,11 +207,33 @@ class EAGLEDraftCudaGraphRunner:
         self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
 
         # Attention backend
+        if bs != raw_bs:
+            forward_batch.batch_size = bs
+            forward_batch.seq_lens = self.seq_lens[:bs]
+            forward_batch.req_pool_indices = self.req_pool_indices[:bs]
+            forward_batch.positions = self.positions[:num_tokens]
+
+        # Special handle for seq_len_cpu used when flashinfer mla is used
+        if (forward_batch.decode_seq_lens_cpu is not None) and (bs != raw_bs):
+            self.seq_lens_cpu.fill_(1)
+            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
+            forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:bs]
+
         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
-            forward_batch,
+            forward_batch, bs
         )
 
         # Replay
         self.graphs[bs].replay()
+        out = self.output_buffers[bs]
 
-        return self.output_buffers[bs]
+        if bs != raw_bs:
+            out = self._postprocess_output_to_raw_bs(out, raw_bs)
+            forward_batch.batch_size = raw_bs
+            forward_batch.positions = self.positions[:raw_num_token]
+            forward_batch.seq_lens = self.seq_lens[:raw_bs]
+            forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs]
+            if forward_batch.decode_seq_lens_cpu is not None:
+                forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:raw_bs]
+
+        return out
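
The replay() changes above enable padded CUDA-graph replay for the EAGLE draft pass (note the removed assert self.disable_padding): a live batch smaller than any captured graph is padded up to the nearest captured batch size, the graph is replayed, and _postprocess_output_to_raw_bs slices every per-step output back to the real batch. A minimal sketch of that pad-then-slice pattern, under illustrative names rather than the sglang API:

from typing import List, Tuple

import torch


def pick_captured_bs(raw_bs: int, captured: List[int]) -> int:
    # Smallest captured batch size that can hold the live batch.
    return next(bs for bs in sorted(captured) if bs >= raw_bs)


def slice_to_raw_bs(
    out: Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]],
    raw_bs: int,
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
    # Mirrors _postprocess_output_to_raw_bs: drop the padded rows from every
    # per-step (score, token, parent) tensor before handing results back.
    score_list, token_list, parents_list = out
    return (
        [x[:raw_bs] for x in score_list],
        [x[:raw_bs] for x in token_list],
        [x[:raw_bs] for x in parents_list],
    )


# Example: a live batch of 3 replays through the graph captured for bs=4;
# the padded row never escapes to callers.
bs = pick_captured_bs(3, [1, 2, 4, 8])  # -> 4
out = (
    [torch.zeros(bs, 2)],                    # scores per draft step
    [torch.zeros(bs, 2, dtype=torch.long)],  # tokens per draft step
    [torch.zeros(bs, 2, dtype=torch.long)],  # parents per draft step
)
scores, tokens, parents = slice_to_raw_bs(out, raw_bs=3)
assert all(t.shape[0] == 3 for t in scores + tokens + parents)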