sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +133 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +32 -21
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +133 -30
- sglang/srt/managers/scheduler.py +273 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +27 -13
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +208 -77
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +124 -28
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +99 -9
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +167 -123
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py

```diff
@@ -22,6 +22,10 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_worker import EAGLEWorker
 
+import logging
+
+logger = logging.getLogger(__name__)
+
 
 class EAGLEDraftCudaGraphRunner:
     def __init__(self, eagle_worker: EAGLEWorker):
@@ -33,13 +37,10 @@ class EAGLEDraftCudaGraphRunner:
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.tp_size = self.model_runner.tp_size
-        self.dp_size = model_runner.server_args.dp_size
         self.topk = model_runner.server_args.speculative_eagle_topk
         self.speculative_num_steps = model_runner.server_args.speculative_num_steps
         server_args = model_runner.server_args
 
-        assert self.disable_padding
-
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         self.num_tokens_per_bs = server_args.speculative_eagle_topk
@@ -51,6 +52,9 @@ class EAGLEDraftCudaGraphRunner:
         self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
             0
         ].get_cuda_graph_seq_len_fill_value()
+        self.seq_lens_cpu = torch.full(
+            (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
+        )
 
         if self.enable_torch_compile:
             set_torch_compile_config()
@@ -169,6 +173,13 @@ class EAGLEDraftCudaGraphRunner:
         set_global_graph_memory_pool(graph.pool())
         return graph, out
 
+    def _postprocess_output_to_raw_bs(self, out, raw_bs):
+        score_list, token_list, parents_list = out
+        score_list = [x[:raw_bs] for x in score_list]
+        token_list = [x[:raw_bs] for x in token_list]
+        parents_list = [x[:raw_bs] for x in parents_list]
+        return (score_list, token_list, parents_list)
+
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
         raw_bs = forward_batch.batch_size
@@ -180,6 +191,9 @@ class EAGLEDraftCudaGraphRunner:
         if bs != raw_bs:
             self.seq_lens.fill_(1)
             self.out_cache_loc.zero_()
+            self.positions.zero_()
+
+        num_tokens = bs * self.num_tokens_per_bs
 
         # Common inputs
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
@@ -193,11 +207,33 @@ class EAGLEDraftCudaGraphRunner:
         self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
 
         # Attention backend
+        if bs != raw_bs:
+            forward_batch.batch_size = bs
+            forward_batch.seq_lens = self.seq_lens[:bs]
+            forward_batch.req_pool_indices = self.req_pool_indices[:bs]
+            forward_batch.positions = self.positions[:num_tokens]
+
+        # Special handle for seq_len_cpu used when flashinfer mla is used
+        if (forward_batch.decode_seq_lens_cpu is not None) and (bs != raw_bs):
+            self.seq_lens_cpu.fill_(1)
+            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
+            forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:bs]
+
         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
-            forward_batch,
+            forward_batch, bs
         )
 
         # Replay
         self.graphs[bs].replay()
+        out = self.output_buffers[bs]
 
-        return self.output_buffers[bs]
+        if bs != raw_bs:
+            out = self._postprocess_output_to_raw_bs(out, raw_bs)
+            forward_batch.batch_size = raw_bs
+            forward_batch.positions = self.positions[:raw_num_token]
+            forward_batch.seq_lens = self.seq_lens[:raw_bs]
+            forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs]
+            if forward_batch.decode_seq_lens_cpu is not None:
+                forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:raw_bs]
+
+        return out
```
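The replay changes above implement CUDA-graph padding for the draft pass: a real batch is padded up to the nearest captured batch size, replayed at that size, and the outputs (and the mutated `forward_batch` fields) are sliced back down to the raw batch size. A minimal, self-contained sketch of that pattern, with illustrative names only (not sglang's actual classes):

```python
import bisect

import torch

# Batch sizes we pretend to have captured graphs for (illustrative values).
CAPTURE_BS = [1, 2, 4, 8]


def pick_padded_bs(raw_bs: int) -> int:
    """Smallest captured batch size that can hold the real batch."""
    return CAPTURE_BS[bisect.bisect_left(CAPTURE_BS, raw_bs)]


def replay_padded(raw_inputs: torch.Tensor, static_buf: torch.Tensor) -> torch.Tensor:
    """Pad into a fixed-size buffer, "replay", then trim the output.

    Mirrors the diff: neutral-fill the static buffer (cf. seq_lens.fill_(1)),
    copy the raw batch into its prefix, run at the padded size, and slice the
    result back to raw_bs (cf. _postprocess_output_to_raw_bs).
    """
    raw_bs = raw_inputs.shape[0]
    bs = pick_padded_bs(raw_bs)
    if bs != raw_bs:
        static_buf[:bs].fill_(1)           # neutral padding values
    static_buf[:raw_bs].copy_(raw_inputs)  # real data lives in the prefix
    out = static_buf[:bs] * 2              # stand-in for graph.replay()
    return out[:raw_bs]                    # drop the padded tail


buf = torch.zeros(8, dtype=torch.int32)
print(replay_padded(torch.tensor([3, 5, 7], dtype=torch.int32), buf))
# tensor([ 6, 10, 14], dtype=torch.int32)
```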
sglang/srt/speculative/eagle_utils.py

```diff
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 import torch.nn.functional as F
@@ -13,18 +13,26 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
-from sglang.srt.speculative.build_eagle_tree import (
-    build_tree_kernel,
-    build_tree_kernel_efficient,
-)
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
+from sglang.srt.utils import is_cuda_available, is_hip
 
 if is_cuda_available():
-    from sgl_kernel import tree_speculative_sampling_target_only
+    from sgl_kernel import (
+        top_k_renorm_prob,
+        top_p_renorm_prob,
+        tree_speculative_sampling_target_only,
+        verify_tree_greedy,
+    )
+elif is_hip():
+    from sgl_kernel import verify_tree_greedy
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
 
+import logging
+
+logger = logging.getLogger(__name__)
+
 
 @dataclass
 class EagleDraftInput:
@@ -47,44 +55,32 @@ class EagleDraftInput:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
 
-
-    # e.g. [0, 2, 3, 4] if only the 1st request is finished
-    keep_indices: List[int] = None
+    all_padding_lens: Optional[torch.Tensor] = None
 
     def prepare_for_extend(self, batch: ScheduleBatch):
-        assert batch.input_ids.numel() == batch.out_cache_loc.shape[0]
         # Prefill only generate 1 token.
         assert len(self.verified_id) == len(batch.seq_lens)
 
         pt = 0
         for i, extend_len in enumerate(batch.extend_lens):
             input_ids = batch.input_ids[pt : pt + extend_len]
-            batch.input_ids[pt : pt + extend_len] = torch.concat(
+            batch.input_ids[pt : pt + extend_len] = torch.cat(
                 (input_ids[1:], self.verified_id[i].reshape(1))
             )
             pt += extend_len
 
-    def prepare_extend_after_decode(self, batch: ScheduleBatch):
-
+    def prepare_extend_after_decode(
+        self,
+        batch: ScheduleBatch,
+        speculative_num_steps: int,
+    ):
+        assert len(self.verified_id) == len(batch.out_cache_loc)
         accept_length_cpu = batch.spec_info.accept_length_cpu
         batch.extend_lens = [x + 1 for x in accept_length_cpu]
         batch.extend_num_tokens = sum(batch.extend_lens)
         batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
+        batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
         seq_lens_cpu = batch.seq_lens.tolist()
-        assert len(batch.req_pool_indices) == len(batch.reqs)
-
-        pt = 0
-        i = 0
-        self.keep_indices = []
-        for idx, req in enumerate(batch.reqs):
-            if req.finished():
-                continue
-            self.keep_indices.append(idx)
-            # assert seq_len - pre_len == req.extend_input_len
-            input_len = batch.extend_lens[i]
-            seq_len = seq_lens_cpu[i]
-            pt += input_len
-            i += 1
 
         self.positions = torch.empty_like(self.verified_id, dtype=torch.long)
         new_verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
@@ -112,10 +108,6 @@ class EagleDraftInput:
         req_to_token: torch.Tensor,
     ):
         bs = self.accept_length.numel()
-        keep_indices = torch.tensor(self.keep_indices, device=req_pool_indices.device)
-        req_pool_indices = req_pool_indices[keep_indices]
-        assert req_pool_indices.shape[0] == bs
-        assert req_pool_indices.shape[0] == paged_kernel_lens.shape[0]
 
         qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
@@ -172,7 +164,7 @@ class EagleVerifyOutput:
     # Accepeted token length per sequence in a batch in CPU.
     accept_length_per_req_cpu: List[int]
    # Accepeted indices from logits_output.next_token_logits
-    accepeted_indices_cpu: torch.Tensor
+    accepeted_indices: torch.Tensor
 
 
 @dataclass
@@ -200,67 +192,38 @@ class EagleVerifyInput:
         topk: int,
         spec_steps: int,
         num_verify_tokens: int,
-        is_all_greedy: bool,
     ):
-        if is_all_greedy:
-            tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
-                build_tree_kernel(
-                    verified_id,
-                    score_list,
-                    token_list,
-                    parents_list,
-                    seq_lens,
-                    seq_lens_sum,
-                    topk,
-                    spec_steps,
-                    num_verify_tokens,
-                )
-            )
-
-            return cls(
-                draft_tokens,
-                tree_mask,
-                position,
-                retrive_index,
-                None,
-                None,
-                retrive_cum_len,
-                num_verify_tokens,
-                spec_steps,
-                CaptureHiddenMode.FULL,
-            )
-        else:
-            (
-                tree_mask,
-                position,
-                retrive_index,
-                retrive_next_token,
-                retrive_next_sibling,
-                draft_tokens,
-            ) = build_tree_kernel_efficient(
-                verified_id,
-                score_list,
-                token_list,
-                parents_list,
-                seq_lens,
-                seq_lens_sum,
-                topk,
-                spec_steps,
-                num_verify_tokens,
-            )
+        (
+            tree_mask,
+            position,
+            retrive_index,
+            retrive_next_token,
+            retrive_next_sibling,
+            draft_tokens,
+        ) = build_tree_kernel_efficient(
+            verified_id,
+            score_list,
+            token_list,
+            parents_list,
+            seq_lens,
+            seq_lens_sum,
+            topk,
+            spec_steps,
+            num_verify_tokens,
+        )
 
-            return cls(
-                draft_tokens,
-                tree_mask,
-                position,
-                retrive_index,
-                retrive_next_token,
-                retrive_next_sibling,
-                None,
-                num_verify_tokens,
-                spec_steps,
-                CaptureHiddenMode.FULL,
-            )
+        return cls(
+            draft_tokens,
+            tree_mask,
+            position,
+            retrive_index,
+            retrive_next_token,
+            retrive_next_sibling,
+            None,
+            num_verify_tokens,
+            spec_steps,
+            CaptureHiddenMode.FULL,
+        )
 
     def prepare_for_verify(self, batch: ScheduleBatch):
         batch.input_ids = self.draft_token
```
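With this change `create` always builds the verify tree via `build_tree_kernel_efficient`, whose `retrive_next_token` and `retrive_next_sibling` outputs appear to encode the draft tree in first-child/next-sibling form. A toy illustration of traversing that encoding (the exact layout is an assumption made for illustration; `-1` means "none"):

```python
import torch

# Assumed first-child / next-sibling encoding of a small draft tree:
#
#        0
#       / \
#      1   2
#      |
#      3
next_token = torch.tensor([1, 3, -1, -1])    # first child of each node
next_sibling = torch.tensor([-1, 2, -1, -1])  # next sibling of each node


def children(node: int) -> list[int]:
    """Enumerate a node's children: follow the first-child pointer, then siblings."""
    out, child = [], int(next_token[node])
    while child != -1:
        out.append(child)
        child = int(next_sibling[child])
    return out


print(children(0))  # [1, 2]
print(children(1))  # [3]
```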
sglang/srt/speculative/eagle_utils.py (continued)

```diff
@@ -291,7 +254,6 @@ class EagleVerifyInput:
             dtype=torch.int32,
             device="cuda",
         )
-
         cum_kv_seq_len = torch.zeros(
             (batch_size + 1,), dtype=torch.int32, device="cuda"
         )
@@ -304,7 +266,6 @@ class EagleVerifyInput:
             dtype=torch.int32,
             device="cuda",
         )
-
        create_flashinfer_kv_indices_triton[(batch_size,)](
             req_to_token,
             req_pool_indices,
@@ -322,65 +283,79 @@ class EagleVerifyInput:
         logits_output: torch.Tensor,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
     ) -> torch.Tensor:
-        """
-
+        """
         Verify and find accepted tokens based on logits output and batch
         (which contains spec decoding information).
 
+        WARNING: This API in-place modifies the states of logits_output
+
         This API updates values inside logits_output based on the accepted
         tokens. I.e., logits_output.next_token_logits only contains
         accepeted token logits.
         """
-        [… 3 deleted lines unrecoverable in this view …]
+        bs = self.retrive_index.shape[0]
+        candidates = self.draft_token.reshape(bs, self.draft_token_num)
+        sampling_info = batch.sampling_info
+
+        predict_shape = list(logits_output.next_token_logits.shape)[:-1]
+        predict_shape[-1] += 1
+        predict = torch.empty(predict_shape, dtype=torch.int32, device="cuda")
+        accept_index = torch.full(
+            (bs, self.spec_steps + 1), -1, dtype=torch.int32, device="cuda"
         )
-        [… 7 deleted lines unrecoverable in this view …]
+        accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
+
+        if sampling_info.penalizer_orchestrator.is_required:
+            # This is a relaxed version of penalties for speculative decoding.
+            linear_penalty = torch.zeros(
+                (bs, logits_output.next_token_logits.shape[1]),
+                dtype=torch.float32,
+                device="cuda",
             )
-        [… 3 deleted lines unrecoverable in this view …]
-            accept_mask = candidates[:, 1:] == target_predict[:, :-1]
-
-            accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1)
-            max_draft_len = self.retrive_index.shape[-1]
-            accept_index = torch.full(
-                (bs, max_draft_len), -1, dtype=torch.int32, device="cuda"
+            sampling_info.apply_logits_bias(linear_penalty)
+            logits_output.next_token_logits.add_(
+                torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
             )
-        [… 12 deleted lines unrecoverable in this view …]
+
+        if batch.sampling_info.is_all_greedy:
+            target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
+            target_predict = target_predict.reshape(bs, self.draft_token_num)
+
+            verify_tree_greedy(
+                predicts=predict,  # mutable
+                accept_index=accept_index,  # mutable
+                accept_token_num=accept_length,  # mutable
+                candidates=candidates.to(torch.int32),
+                retrive_index=self.retrive_index.to(torch.int32),
+                retrive_next_token=self.retrive_next_token.to(torch.int32),
+                retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
+                target_predict=target_predict.to(torch.int32),
             )
         else:
-            [… 8 deleted lines unrecoverable in this view …]
+            # apply temperature and get target probs
+            expanded_temperature = torch.repeat_interleave(
+                sampling_info.temperatures, self.draft_token_num, dim=0
+            )  # (bs * draft_token_num, 1)
+
+            target_probs = F.softmax(
+                logits_output.next_token_logits / expanded_temperature, dim=-1
+            )  # (bs * draft_token_num, vocab_size)
+            target_probs = top_k_renorm_prob(
+                target_probs,
+                torch.repeat_interleave(
+                    sampling_info.top_ks, self.draft_token_num, dim=0
+                ),
+            )  # (bs * draft_token_num, vocab_size)
+            target_probs = top_p_renorm_prob(
+                target_probs,
+                torch.repeat_interleave(
+                    sampling_info.top_ps, self.draft_token_num, dim=0
+                ),
             )
-            [… 3 deleted lines unrecoverable in this view …]
-            draft_probs = torch.full_like(
-                target_probs, 0, dtype=torch.float32, device="cuda"
+            target_probs = target_probs.reshape(bs, self.draft_token_num, -1)
+
+            draft_probs = torch.zeros(
+                target_probs.shape, dtype=torch.float32, device="cuda"
             )
             coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
             tree_speculative_sampling_target_only(
```
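For contrast, the deleted lines still visible above show how the old pure-PyTorch path computed greedy acceptance: a draft token is accepted only while every earlier draft token matched the target model's argmax, which a `cumprod` over the match mask expresses in one shot. A toy version of that trick for a linear draft chain (made-up token ids):

```python
import torch

# Per request: draft token i is accepted iff all draft tokens before it
# matched the target argmax. cumprod zeroes everything after the first miss.
candidates = torch.tensor([[7, 3, 9, 4],      # drafts (position 0 = verified root)
                           [5, 8, 8, 2]])
target_predict = torch.tensor([[3, 9, 1, 6],  # target argmax at each position
                               [8, 0, 2, 2]])

match = candidates[:, 1:] == target_predict[:, :-1]        # draft follows target?
accept_len = torch.cumprod(match.int(), dim=1).sum(dim=1)  # accepted per request
print(accept_len)  # tensor([2, 1])
```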
sglang/srt/speculative/eagle_utils.py (continued)

```diff
@@ -394,6 +369,12 @@ class EagleVerifyInput:
                 uniform_samples=coins,
                 target_probs=target_probs,
                 draft_probs=draft_probs,
+                threshold_single=global_server_args_dict[
+                    "speculative_accept_threshold_single"
+                ],
+                threshold_acc=global_server_args_dict[
+                    "speculative_accept_threshold_acc"
+                ],
                 deterministic=True,
             )
@@ -425,119 +406,94 @@ class EagleVerifyInput:
                     new_accept_index.extend(new_accept_index_)
                     unfinished_index.append(i)
             req.spec_verify_ct += 1
-        accept_length = (accept_index != -1).sum(dim=1) - 1
-
-        accept_index = accept_index[accept_index != -1]
-        accept_length_cpu = accept_length.tolist()
-        verified_id = predict[accept_index]
-        evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
-        evict_mask[accept_index] = False
-        mem_need_free_idx = batch.out_cache_loc[evict_mask]
-        token_to_kv_pool_allocator.free(mem_need_free_idx)
-        assign_req_to_token_pool[(bs,)](
-            batch.req_pool_indices,
-            batch.req_to_token_pool.req_to_token,
-            batch.seq_lens,
-            batch.seq_lens + accept_length + 1,
-            batch.out_cache_loc[accept_index],
-            batch.req_to_token_pool.req_to_token.shape[1],
-            triton.next_power_of_2(bs),
-        )
-        batch.seq_lens.add_(accept_length + 1)
-
-        draft_input = EagleDraftInput()
-        if len(new_accept_index) > 0:
-            new_accept_index = torch.tensor(new_accept_index, device="cuda")
-            draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
-            draft_input.verified_id = predict[new_accept_index]
-            draft_input.accept_length = accept_length[unfinished_index]
-            draft_input.accept_length_cpu = [
-                accept_length_cpu[i] for i in unfinished_index
-            ]
-            if has_finished:
-                draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
-            else:
-                draft_input.seq_lens_for_draft_extend = batch.seq_lens
-            batch.out_cache_loc = batch.out_cache_loc[new_accept_index]
-
-        return EagleVerifyOutput(
-            draft_input=draft_input,
-            logits_output=logits_output,
-            verified_id=verified_id,
-            accept_length_per_req_cpu=accept_length_cpu,
-            accepeted_indices_cpu=accept_index,
-        )
 
-        [… 70 deleted lines unrecoverable in this view …]
+        if not has_finished:
+            accept_index = accept_index[accept_index != -1]
+            verified_id = predict[accept_index]
+            evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
+            evict_mask[accept_index] = False
+            mem_need_free_idx = batch.out_cache_loc[evict_mask]
+            token_to_kv_pool_allocator.free(mem_need_free_idx)
+            batch.out_cache_loc = batch.out_cache_loc[accept_index]
+            assign_req_to_token_pool[(bs,)](
+                batch.req_pool_indices,
+                batch.req_to_token_pool.req_to_token,
+                batch.seq_lens,
+                batch.seq_lens + accept_length + 1,
+                batch.out_cache_loc,
+                batch.req_to_token_pool.req_to_token.shape[1],
+                triton.next_power_of_2(bs),
+            )
+            batch.seq_lens.add_(accept_length + 1)
+            accept_length_cpu = accept_length.tolist()
+
+            draft_input = EagleDraftInput()
+            draft_input.hidden_states = batch.spec_info.hidden_states[accept_index]
+            draft_input.verified_id = verified_id
+            draft_input.accept_length = accept_length
+            draft_input.accept_length_cpu = accept_length_cpu
+            draft_input.seq_lens_for_draft_extend = batch.seq_lens
+            draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
+
+            return EagleVerifyOutput(
+                draft_input=draft_input,
+                logits_output=logits_output,
+                verified_id=verified_id,
+                accept_length_per_req_cpu=accept_length_cpu,
+                accepeted_indices=accept_index,
+            )
+        else:
+            accept_length = (accept_index != -1).sum(dim=1) - 1
+            accept_index = accept_index[accept_index != -1]
+            verified_id = predict[accept_index]
+            evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
+            evict_mask[accept_index] = False
+            mem_need_free_idx = batch.out_cache_loc[evict_mask]
+            token_to_kv_pool_allocator.free(mem_need_free_idx)
+            assign_req_to_token_pool[(bs,)](
+                batch.req_pool_indices,
+                batch.req_to_token_pool.req_to_token,
+                batch.seq_lens,
+                batch.seq_lens + accept_length + 1,
+                batch.out_cache_loc[accept_index],
+                batch.req_to_token_pool.req_to_token.shape[1],
+                triton.next_power_of_2(bs),
+            )
+            batch.seq_lens.add_(accept_length + 1)
+            accept_length_cpu = accept_length.tolist()
+
+            draft_input = EagleDraftInput()
+            if len(new_accept_index) > 0:
+                new_accept_index = torch.tensor(new_accept_index, device="cuda")
+                draft_input.hidden_states = batch.spec_info.hidden_states[
+                    new_accept_index
+                ]
+                draft_input.verified_id = predict[new_accept_index]
+                draft_input.accept_length = accept_length[unfinished_index]
+                draft_input.accept_length_cpu = [
+                    accept_length_cpu[i] for i in unfinished_index
+                ]
+                if has_finished:
+                    draft_input.seq_lens_for_draft_extend = batch.seq_lens[
+                        unfinished_index
+                    ]
+                    draft_input.req_pool_indices_for_draft_extend = (
+                        batch.req_pool_indices[unfinished_index]
+                    )
+                else:
+                    draft_input.seq_lens_for_draft_extend = batch.seq_lens
+                    draft_input.req_pool_indices_for_draft_extend = (
+                        batch.req_pool_indices
+                    )
+                batch.out_cache_loc = batch.out_cache_loc[new_accept_index]
+
+            return EagleVerifyOutput(
+                draft_input=draft_input,
+                logits_output=logits_output,
+                verified_id=verified_id,
+                accept_length_per_req_cpu=accept_length_cpu,
+                accepeted_indices=accept_index,
+            )
 
 
 @triton.jit
```
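In both branches of the rewritten tail above, the KV-cache slots of rejected draft tokens are returned to the allocator: start from an all-True evict mask over every draft position, clear the accepted positions, and free what remains. A small self-contained sketch of that bookkeeping (made-up slot numbers and indices):

```python
import torch

# One KV-cache slot was allocated per draft token; only accepted tokens keep theirs.
out_cache_loc = torch.tensor([100, 101, 102, 103, 104, 105])  # slot per draft token
accept_index = torch.tensor([0, 1, 4])                        # accepted flat positions

evict_mask = torch.full((out_cache_loc.numel(),), True, dtype=torch.bool)
evict_mask[accept_index] = False                 # accepted slots are kept
mem_need_free_idx = out_cache_loc[evict_mask]    # everything else is released

print(mem_need_free_idx)            # tensor([102, 103, 105]) -> back to the allocator
print(out_cache_loc[accept_index])  # tensor([100, 101, 104]) -> retained
```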