sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -0
- sglang/srt/layers/attention/__init__.py +14 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
- sglang/srt/layers/attention/flashinfer_backend.py +211 -81
- sglang/srt/layers/attention/torch_native_backend.py +1 -38
- sglang/srt/layers/attention/triton_backend.py +20 -11
- sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
- sglang/srt/layers/logits_processor.py +167 -212
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +187 -29
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -6
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/sampler.py +57 -21
- sglang/srt/layers/torchao_utils.py +17 -3
- sglang/srt/managers/io_struct.py +1 -2
- sglang/srt/managers/schedule_batch.py +26 -2
- sglang/srt/managers/schedule_policy.py +159 -90
- sglang/srt/managers/scheduler.py +62 -26
- sglang/srt/managers/tokenizer_manager.py +22 -20
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/model_executor/cuda_graph_runner.py +118 -73
- sglang/srt/model_executor/forward_batch_info.py +33 -8
- sglang/srt/model_executor/model_runner.py +63 -61
- sglang/srt/models/deepseek_v2.py +34 -7
- sglang/srt/models/grok.py +97 -26
- sglang/srt/openai_api/adapter.py +0 -17
- sglang/srt/openai_api/protocol.py +3 -3
- sglang/srt/sampling/sampling_batch_info.py +21 -0
- sglang/srt/sampling/sampling_params.py +9 -1
- sglang/srt/server.py +9 -5
- sglang/srt/server_args.py +108 -57
- sglang/srt/speculative/build_eagle_tree.py +347 -0
- sglang/srt/speculative/eagle_utils.py +618 -0
- sglang/srt/speculative/eagle_worker.py +170 -0
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/utils.py +15 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/METADATA +9 -8
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/RECORD +63 -39
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/WHEEL +1 -1
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -222,10 +222,8 @@ class TokenizerManager:
         is_single = obj.is_single
         if is_single:
             tokenized_obj = await self._tokenize_one_request(obj)
-            self.send_to_scheduler.send_pyobj(tokenized_obj)
-            async for response in self._wait_one_response(
-                obj, request, created_time
-            ):
+            self._send_one_request(obj, tokenized_obj, created_time)
+            async for response in self._wait_one_response(obj, request):
                 yield response
         else:
             async for response in self._handle_batch_request(
@@ -306,16 +304,24 @@ class TokenizerManager:
 
         return tokenized_obj
 
-    async def _wait_one_response(
+    def _send_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
-        request: Optional[fastapi.Request] = None,
+        tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
         created_time: Optional[float] = None,
     ):
-        """Wait for the response of one request."""
         event = asyncio.Event()
         state = ReqState([], False, event, obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
+        self.send_to_scheduler.send_pyobj(tokenized_obj)
+
+    async def _wait_one_response(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        request: Optional[fastapi.Request] = None,
+    ):
+        """Wait for the response of one request."""
+        state = self.rid_to_state[obj.rid]
 
         while True:
             try:
@@ -361,10 +367,8 @@
             for i in range(batch_size):
                 tmp_obj = obj[i]
                 tokenized_obj = await self._tokenize_one_request(tmp_obj)
-                self.send_to_scheduler.send_pyobj(tokenized_obj)
-                generators.append(
-                    self._wait_one_response(tmp_obj, request, created_time)
-                )
+                self._send_one_request(tmp_obj, tokenized_obj, created_time)
+                generators.append(self._wait_one_response(tmp_obj, request))
                 rids.append(tmp_obj.rid)
         else:
             # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
@@ -389,10 +393,8 @@
                 tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
                 tokenized_obj.sampling_params.max_new_tokens = 0
                 tokenized_obj.stream = False
-                self.send_to_scheduler.send_pyobj(tokenized_obj)
-                await self._wait_one_response(
-                    tmp_obj, request, created_time
-                ).__anext__()
+                self._send_one_request(tmp_obj, tokenized_obj, created_time)
+                await self._wait_one_response(tmp_obj, request).__anext__()
 
             # Expand requests, assign new rids for them, and send them
             for i in range(batch_size):
@@ -400,10 +402,8 @@
                     tmp_obj = copy.copy(objs[i])
                     tokenized_obj = copy.copy(tokenized_objs[i])
                     tokenized_obj.rid = tmp_obj.regenerate_rid()
-                    self.send_to_scheduler.send_pyobj(tokenized_obj)
-                    generators.append(
-                        self._wait_one_response(tmp_obj, request, created_time)
-                    )
+                    self._send_one_request(tmp_obj, tokenized_obj, created_time)
+                    generators.append(self._wait_one_response(tmp_obj, request))
                     rids.append(tmp_obj.rid)
 
             # Wait for all requests
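
Note: the four hunks above apply one refactor. Submitting a request (_send_one_request, which registers the ReqState and pushes the tokenized request to the scheduler) is now separate from waiting for its output (_wait_one_response, which only reads rid_to_state), so a whole batch can be enqueued before any generator is awaited. Below is a rough sketch of the resulting call pattern; the helper name submit_then_stream and its arguments are illustrative, only the two underscored methods come from the diff.

# Sketch only, assuming `tm` is a TokenizerManager and `objs` are request inputs.
async def submit_then_stream(tm, objs, request=None, created_time=None):
    generators = []
    for obj in objs:
        tokenized_obj = await tm._tokenize_one_request(obj)
        tm._send_one_request(obj, tokenized_obj, created_time)   # enqueue immediately
        generators.append(tm._wait_one_response(obj, request))   # defer the waiting
    for gen in generators:            # consume only after everything was submitted
        async for response in gen:
            yield response
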
@@ -699,6 +699,7 @@
             )
         else:
             if completion_tokens >= 2:
+                # Compute time_per_output_token for the streaming case
                 self.metrics_collector.observe_time_per_output_token(
                     (time.time() - state.first_token_time)
                     / (completion_tokens - 1)
@@ -714,7 +715,8 @@
             self.metrics_collector.observe_e2e_request_latency(
                 time.time() - state.created_time
             )
-
+            # Compute time_per_output_token for the non-streaming case
+            if not state.obj.stream and completion_tokens >= 1:
                 self.metrics_collector.observe_time_per_output_token(
                     (time.time() - state.created_time)
                     / completion_tokens
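
The two metrics hunks distinguish how time_per_output_token is observed. A small worked example of the two formulas with invented timings (the attribute names mirror the ReqState fields used above):

# Invented numbers, only to illustrate the two formulas above.
created_time = 0.0          # request accepted
first_token_time = 0.5      # first token produced (known only when streaming)
observe_time = 2.5          # moment the metric is observed
completion_tokens = 21

# Streaming: exclude the first token, average over the 20 inter-token gaps.
tpot_streaming = (observe_time - first_token_time) / (completion_tokens - 1)   # 0.1 s/token

# Non-streaming: only end-to-end latency is known, average over all 21 tokens.
tpot_non_streaming = (observe_time - created_time) / completion_tokens         # ~0.119 s/token
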
sglang/srt/managers/tp_worker.py
CHANGED
@@ -30,7 +30,7 @@ from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_a
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import broadcast_pyobj, set_random_seed
+from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
 
 logger = logging.getLogger(__name__)
 
@@ -45,13 +45,18 @@ class TpModelWorker:
         tp_rank: int,
         dp_rank: Optional[int],
         nccl_port: int,
+        is_draft_worker: bool = False,
     ):
         # Parse args
         self.tp_rank = tp_rank
 
         # Init model and tokenizer
         self.model_config = ModelConfig(
-            server_args.model_path,
+            (
+                server_args.model_path
+                if not is_draft_worker
+                else server_args.speculative_draft_model_path
+            ),
             trust_remote_code=server_args.trust_remote_code,
             revision=server_args.revision,
             context_length=server_args.context_length,
@@ -68,6 +73,7 @@ class TpModelWorker:
             tp_size=server_args.tp_size,
             nccl_port=nccl_port,
             server_args=server_args,
+            is_draft_worker=is_draft_worker,
         )
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
@@ -150,12 +156,18 @@ class TpModelWorker:
         self,
         model_worker_batch: ModelWorkerBatch,
         launch_done: Optional[threading.Event] = None,
+        skip_sample: bool = False,
     ):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
         if launch_done:
             launch_done.set()
-        next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
+
+        if skip_sample:
+            next_token_ids = None
+        else:
+            next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
+
         return logits_output, next_token_ids
 
     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
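
forward_batch_generation can now return logits without sampling, which the new speculative (EAGLE) code paths can use when token selection happens elsewhere. A hedged sketch of a caller follows; target_worker is a hypothetical TpModelWorker instance, and its draft counterpart would be constructed with is_draft_worker=True so that ModelConfig reads server_args.speculative_draft_model_path.

# Sketch only; `target_worker` and `batch` are assumed to exist in the caller.
def run_forward_without_sampling(target_worker, batch):
    # Run the forward pass but leave token selection to other logic.
    logits_output, next_token_ids = target_worker.forward_batch_generation(
        batch, skip_sample=True
    )
    assert next_token_ids is None      # nothing was sampled
    return logits_output
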
@@ -191,7 +203,7 @@ class TpModelWorker:
 
     def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
         success, message = self.model_runner.update_weights_from_tensor(
-
+            MultiprocessingSerializer.deserialize(recv_req.serialized_named_tensors)
         )
         return success, message
 
sglang/srt/managers/tp_worker_overlap_thread.py
CHANGED
@@ -144,10 +144,9 @@ class TpModelWorkerClient:
 
         # Copy results to the CPU
         if model_worker_batch.return_logprob:
-            logits_output.next_token_logprobs =
-
-
-            ].to("cpu", non_blocking=True)
+            logits_output.next_token_logprobs = (
+                logits_output.next_token_logprobs.to("cpu", non_blocking=True)
+            )
             if logits_output.input_token_logprobs is not None:
                 logits_output.input_token_logprobs = (
                     logits_output.input_token_logprobs.to("cpu", non_blocking=True)
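
The logits_processor.py and sampler.py rewrites listed above appear to move per-token logprob gathering out of this worker, so next_token_logprobs is copied to the CPU directly. As a general reminder (not sglang-specific), a non_blocking device-to-host copy is only safe to read after the producing stream has been synchronized:

import torch

# Minimal sketch with illustrative tensors.
if torch.cuda.is_available():
    gpu_logprobs = torch.randn(8, device="cuda")
    cpu_logprobs = gpu_logprobs.to("cpu", non_blocking=True)
    torch.cuda.current_stream().synchronize()   # wait for the async copy
    print(cpu_logprobs)
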
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -25,14 +25,15 @@ from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
-from sglang.srt.layers.logits_processor import (
-    LogitsMetadata,
-    LogitsProcessor,
-    LogitsProcessorOutput,
-)
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather
+from sglang.srt.layers.torchao_utils import save_gemlite_cache
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
+from sglang.srt.utils import monkey_patch_vllm_all_gather
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -105,11 +106,6 @@ def set_torch_compile_config():
     torch._dynamo.config.cache_size_limit = 1024
 
 
-@maybe_torch_compile(dynamic=True)
-def clamp_position(seq_lens):
-    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
-
-
 class CudaGraphRunner:
     """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
 
@@ -152,6 +148,21 @@ class CudaGraphRunner:
             if bs <= model_runner.req_to_token_pool.size
             and bs <= model_runner.server_args.cuda_graph_max_bs
         ]
+
+        self.capture_forward_mode = ForwardMode.DECODE
+        self.num_tokens_per_bs = 1
+
+        if model_runner.spec_algorithm.is_eagle():
+            if self.model_runner.is_draft_worker:
+                self.num_tokens_per_bs = (
+                    self.model_runner.server_args.speculative_eagle_topk
+                )
+            else:
+                self.capture_forward_mode = ForwardMode.TARGET_VERIFY
+                self.num_tokens_per_bs = (
+                    self.model_runner.server_args.speculative_num_draft_tokens
+                )
+
         self.compile_bs = (
             [
                 bs
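
This hunk sizes CUDA graph capture in tokens rather than requests when EAGLE speculative decoding is enabled. A sketch of the arithmetic with invented parameter values:

# Invented example values for the server arguments referenced above.
speculative_eagle_topk = 4          # tokens per request on the draft worker
speculative_num_draft_tokens = 8    # tokens per request during TARGET_VERIFY
max_bs = 160

num_tokens_per_bs = speculative_eagle_topk           # draft worker capture (DECODE)
# num_tokens_per_bs = speculative_num_draft_tokens   # target worker capture (TARGET_VERIFY)

max_num_token = max_bs * num_tokens_per_bs           # buffer length for input_ids, positions, ...
print(max_num_token)                                 # 640
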
@@ -164,8 +175,8 @@ class CudaGraphRunner:
 
         # Attention backend
         self.max_bs = max(self.capture_bs)
-        self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
-
+        self.max_num_token = self.max_bs * self.num_tokens_per_bs
+        self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
@@ -178,14 +189,22 @@ class CudaGraphRunner:
 
         # Common inputs
         with torch.device("cuda"):
-            self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32)
+            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int32)
             self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
             self.seq_lens = torch.full(
                 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
             )
-            self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
+            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int32)
+            self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)
 
+            # Speculative_inference
+            if model_runner.spec_algorithm.is_eagle():
+                self.hidden_states = torch.zeros(
+                    (self.max_num_token, self.model_runner.model_config.hidden_size),
+                    dtype=self.model_runner.dtype,
+                )
+
             if self.is_encoder_decoder:
                 # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
                 self.encoder_lens = torch.full(
@@ -257,12 +276,12 @@ class CudaGraphRunner:
     def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
-            capture_bs = (
+            capture_range = (
                 tqdm.tqdm(self.capture_bs)
                 if get_tensor_model_parallel_rank() == 0
                 else self.capture_bs
             )
-            for bs in capture_bs:
+            for bs in capture_range:
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,
@@ -276,21 +295,24 @@ class CudaGraphRunner:
                     self.graphs[bs] = graph
                     self.output_buffers[bs] = output_buffers
 
+                # Save gemlite cache after each capture
+                save_gemlite_cache()
+
     def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
         stream = self.stream
+        num_tokens = bs * self.num_tokens_per_bs
 
         # Common inputs
-        input_ids = self.input_ids[:bs]
+        input_ids = self.input_ids[:num_tokens]
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
-        out_cache_loc = self.out_cache_loc[:bs]
+        out_cache_loc = self.out_cache_loc[:num_tokens]
+        positions = self.positions[:num_tokens]
         if self.is_encoder_decoder:
             encoder_lens = self.encoder_lens[:bs]
         else:
             encoder_lens = None
-
-        seq_lens_sum = seq_lens.sum().item()
         mrope_positions = self.mrope_positions[:, :bs]
 
         if self.enable_dp_attention:
@@ -300,37 +322,43 @@ class CudaGraphRunner:
             global_num_tokens = None
             gathered_buffer = None
 
+        forward_batch = ForwardBatch(
+            forward_mode=self.capture_forward_mode,
+            batch_size=bs,
+            input_ids=input_ids,
+            req_pool_indices=req_pool_indices,
+            seq_lens=seq_lens,
+            req_to_token_pool=self.model_runner.req_to_token_pool,
+            token_to_kv_pool=self.model_runner.token_to_kv_pool,
+            attn_backend=self.model_runner.attn_backend,
+            out_cache_loc=out_cache_loc,
+            seq_lens_sum=seq_lens.sum(),
+            encoder_lens=encoder_lens,
+            return_logprob=False,
+            top_logprobs_nums=[0] * bs,
+            positions=positions,
+            global_num_tokens=global_num_tokens,
+            mrope_positions=mrope_positions,
+            gathered_buffer=gathered_buffer,
+            spec_algorithm=self.model_runner.spec_algorithm,
+            spec_info=self.get_spec_info(num_tokens, positions),
+        )
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
             bs,
+            num_tokens,
             req_pool_indices,
             seq_lens,
             encoder_lens,
+            forward_batch.forward_mode,
+            forward_batch.spec_info,
         )
 
         # Run and capture
         def run_once():
-            forward_batch = ForwardBatch(
-                forward_mode=ForwardMode.DECODE,
-                batch_size=bs,
-                input_ids=input_ids,
-                req_pool_indices=req_pool_indices,
-                seq_lens=seq_lens,
-                req_to_token_pool=self.model_runner.req_to_token_pool,
-                token_to_kv_pool=self.model_runner.token_to_kv_pool,
-                attn_backend=self.model_runner.attn_backend,
-                out_cache_loc=out_cache_loc,
-                seq_lens_sum=seq_lens_sum,
-                encoder_lens=encoder_lens,
-                return_logprob=False,
-                top_logprobs_nums=[0] * bs,
-                positions=clamp_position(seq_lens),
-                mrope_positions=mrope_positions,
-                global_num_tokens=global_num_tokens,
-                gathered_buffer=gathered_buffer,
-            )
             logits_output = forward(input_ids, forward_batch.positions, forward_batch)
-            return logits_output.next_token_logits
+            return logits_output.next_token_logits, logits_output.hidden_states
 
         for _ in range(2):
             torch.cuda.synchronize()
@@ -356,6 +384,7 @@ class CudaGraphRunner:
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
         raw_bs = forward_batch.batch_size
+        raw_num_token = raw_bs * self.num_tokens_per_bs
 
         # Pad
         if self.enable_dp_attention:
@@ -370,15 +399,20 @@
             self.out_cache_loc.zero_()
 
         # Common inputs
-        self.input_ids[:raw_bs].copy_(forward_batch.input_ids)
+        self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
-        self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
+        self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
+        self.positions[:raw_num_token].copy_(forward_batch.positions)
+
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
 
+        if hasattr(forward_batch.spec_info, "hidden_states"):
+            self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
@@ -386,40 +420,51 @@ class CudaGraphRunner:
             self.seq_lens,
             forward_batch.seq_lens_sum + (bs - raw_bs),
             self.encoder_lens,
+            forward_batch.forward_mode,
+            forward_batch.spec_info,
         )
 
         # Replay
         self.graphs[bs].replay()
-        next_token_logits = self.output_buffers[bs]
+        next_token_logits, hidden_states = self.output_buffers[bs]
+
+        logits_output = LogitsProcessorOutput(
+            next_token_logits=next_token_logits[:raw_num_token],
+            hidden_states=(
+                hidden_states[:raw_num_token] if hidden_states is not None else None
+            ),
+        )
+        return logits_output
 
-
-
-
-
-
+    def get_spec_info(self, num_tokens: int, positions: torch.Tensor):
+        spec_info = None
+        if self.model_runner.spec_algorithm.is_eagle():
+            from sglang.srt.speculative.eagle_utils import (
+                EAGLEDraftInput,
+                EagleVerifyInput,
             )
-
-
-
+
+            if self.model_runner.is_draft_worker:
+                spec_info = EAGLEDraftInput()
+                spec_info.hidden_states = self.hidden_states[:num_tokens]
+                spec_info.positions = positions
+                spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
+                spec_info.init(self.model_runner.server_args)
+            else:
+                spec_info = EagleVerifyInput(
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    self.model_runner.server_args.speculative_num_draft_tokens,
                 )
-
-
-
-
-
-
-            if return_top_logprob:
-                (
-                    logits_output.output_top_logprobs_val,
-                    logits_output.output_top_logprobs_idx,
-                ) = LogitsProcessor.get_top_logprobs(
-                    next_token_logprobs, logits_metadata
-                )[
-                    2:4
-                ]
-        else:
-            logits_output = LogitsProcessorOutput(
-                next_token_logits=next_token_logits,
-            )
+            spec_info.custom_mask = torch.zeros(
+                (num_tokens * self.model_runner.model_config.context_len),
+                dtype=torch.bool,
+                device="cuda",
+            )
+            spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
 
-        return logits_output
+        return spec_info
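
During replay, padding is now handled per token: the graph always produces bs * num_tokens_per_bs rows, and only the first raw_num_token rows are real. A tiny sketch of the sizes involved (numbers invented):

# Invented sizes, mirroring the slicing in replay() above.
raw_bs = 3                 # actual requests in this step
bs = 4                     # padded batch size the graph was captured for
num_tokens_per_bs = 4      # e.g. speculative_eagle_topk on the draft worker

raw_num_token = raw_bs * num_tokens_per_bs   # 12 rows of input_ids / logits are real
captured_rows = bs * num_tokens_per_bs       # 16 rows come out of the graph
# LogitsProcessorOutput is then built from next_token_logits[:raw_num_token].
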
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -38,6 +38,7 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
+from sglang.srt.utils import maybe_torch_compile
 
 if TYPE_CHECKING:
     from sglang.srt.layers.attention import AttentionBackend
@@ -96,7 +97,11 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.DRAFT_EXTEND
 
     def is_cuda_graph(self):
-        return self == ForwardMode.DECODE or self == ForwardMode.IDLE
+        return (
+            self == ForwardMode.DECODE
+            or self == ForwardMode.TARGET_VERIFY
+            or self == ForwardMode.IDLE
+        )
 
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
@@ -161,15 +166,15 @@
     token_to_kv_pool: BaseTokenToKVPool = None
     attn_backend: AttentionBackend = None
 
-    # Speculative decoding
-    spec_info: SpecInfo = None
-    spec_algorithm: SpeculativeAlgorithm = None
-
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
     gathered_buffer: Optional[torch.Tensor] = None
     can_run_dp_cuda_graph: bool = False
 
+    # Speculative decoding
+    spec_info: SpecInfo = None
+    spec_algorithm: SpeculativeAlgorithm = None
+
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None
 
@@ -258,6 +263,8 @@
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
+            spec_algorithm=batch.spec_algorithm,
+            spec_info=batch.spec_info,
             input_embeds=batch.input_embeds,
         )
 
@@ -270,10 +277,21 @@
         )
 
         if ret.forward_mode.is_idle():
+            ret.positions = torch.empty((0,), device=device)
             return ret
 
+        # Override the positions with spec_info
+        if (
+            ret.spec_info is not None
+            and getattr(ret.spec_info, "positions", None) is not None
+        ):
+            ret.positions = ret.spec_info.positions
+
         # Init position information
-        if not ret.forward_mode.is_decode():
+        if ret.forward_mode.is_decode():
+            if ret.positions is None:
+                ret.positions = clamp_position(batch.seq_lens)
+        else:
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
@@ -282,13 +300,15 @@
             ).to(device, non_blocking=True)
             if model_runner.server_args.attention_backend != "torch_native":
                 ret.extend_num_tokens = batch.extend_num_tokens
-                ret.positions, ret.extend_start_loc = compute_position_triton(
+                positions, ret.extend_start_loc = compute_position_triton(
                     ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
                 )
             else:
-                ret.positions, ret.extend_start_loc = compute_position_torch(
+                positions, ret.extend_start_loc = compute_position_torch(
                     ret.extend_prefix_lens, ret.extend_seq_lens
                 )
+            if ret.positions is None:
+                ret.positions = positions
             ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
             ret.extend_seq_lens_cpu = batch.extend_seq_lens
             ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
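
The decode branch above falls back to clamp_position, which the final hunk below relocates from cuda_graph_runner.py into this file. Its behavior on a toy tensor (decorator omitted, function body taken verbatim from the diff):

import torch

def clamp_position(seq_lens):
    # Position of the token being decoded: seq_len - 1, floored at 0.
    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)

print(clamp_position(torch.tensor([1, 5, 0])))   # tensor([0, 4, 0])
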
@@ -377,6 +397,11 @@ def compute_position_torch(
     return positions.to(torch.int64), extend_start_loc
 
 
+@maybe_torch_compile(dynamic=True)
+def clamp_position(seq_lens):
+    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
+
+
 class CaptureHiddenMode(IntEnum):
     NULL = auto()
     FULL = auto()