sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl
This diff compares two publicly available versions of the package as released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
- sglang/bench_one_batch.py +2 -0
- sglang/bench_serving.py +18 -1
- sglang/lang/interpreter.py +71 -1
- sglang/lang/ir.py +2 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/chatglm.py +78 -0
- sglang/srt/configs/dbrx.py +279 -0
- sglang/srt/configs/model_config.py +1 -1
- sglang/srt/hf_transformers_utils.py +9 -14
- sglang/srt/layers/attention/__init__.py +22 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
- sglang/srt/layers/attention/flashinfer_backend.py +215 -83
- sglang/srt/layers/attention/torch_native_backend.py +1 -38
- sglang/srt/layers/attention/triton_backend.py +20 -11
- sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
- sglang/srt/layers/linear.py +159 -55
- sglang/srt/layers/logits_processor.py +170 -215
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +198 -29
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -7
- sglang/srt/layers/parameter.py +431 -0
- sglang/srt/layers/quantization/__init__.py +3 -2
- sglang/srt/layers/quantization/fp8.py +3 -3
- sglang/srt/layers/quantization/modelopt_quant.py +174 -0
- sglang/srt/layers/sampler.py +57 -21
- sglang/srt/layers/torchao_utils.py +17 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -1
- sglang/srt/managers/cache_controller.py +307 -0
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/io_struct.py +1 -2
- sglang/srt/managers/schedule_batch.py +33 -3
- sglang/srt/managers/schedule_policy.py +159 -90
- sglang/srt/managers/scheduler.py +68 -28
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +27 -21
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/memory_pool.py +206 -1
- sglang/srt/metrics/collector.py +22 -30
- sglang/srt/model_executor/cuda_graph_runner.py +129 -77
- sglang/srt/model_executor/forward_batch_info.py +51 -21
- sglang/srt/model_executor/model_runner.py +72 -64
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek_v2.py +34 -7
- sglang/srt/models/grok.py +109 -29
- sglang/srt/models/llama.py +9 -2
- sglang/srt/openai_api/adapter.py +0 -17
- sglang/srt/openai_api/protocol.py +3 -3
- sglang/srt/sampling/sampling_batch_info.py +22 -0
- sglang/srt/sampling/sampling_params.py +9 -1
- sglang/srt/server.py +20 -13
- sglang/srt/server_args.py +120 -58
- sglang/srt/speculative/build_eagle_tree.py +347 -0
- sglang/srt/speculative/eagle_utils.py +626 -0
- sglang/srt/speculative/eagle_worker.py +184 -0
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/utils.py +47 -7
- sglang/test/test_programs.py +23 -1
- sglang/test/test_utils.py +36 -7
- sglang/version.py +1 -1
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +12 -12
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +86 -57
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/sglang/srt/model_executor/cuda_graph_runner.py
@@ -25,14 +25,15 @@ from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
-from sglang.srt.layers.logits_processor import (
-    LogitsMetadata,
-    LogitsProcessor,
-    LogitsProcessorOutput,
-)
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
-from sglang.srt.
-from sglang.srt.
+from sglang.srt.layers.torchao_utils import save_gemlite_cache
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
+from sglang.srt.utils import monkey_patch_vllm_all_gather
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -105,11 +106,6 @@ def set_torch_compile_config():
     torch._dynamo.config.cache_size_limit = 1024
 
 
-@maybe_torch_compile(dynamic=True)
-def clamp_position(seq_lens):
-    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
-
-
 class CudaGraphRunner:
     """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
 
@@ -128,10 +124,12 @@ class CudaGraphRunner:
         self.tp_size = self.model_runner.tp_size
 
         # Batch sizes to capture
-
-
-
-
+        self.capture_bs = self.model_runner.server_args.cuda_graph_bs
+        if self.capture_bs is None:
+            if model_runner.server_args.disable_cuda_graph_padding:
+                self.capture_bs = list(range(1, 33)) + [64, 128]
+            else:
+                self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
 
         if max(self.capture_bs) > model_runner.req_to_token_pool.size:
            # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
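The two default schedules above trade graph count against padding waste: without padding, every batch size up to 32 needs its own graph, while with padding a sparser grid suffices because a live batch can be padded up to the next captured size. A standalone sketch (plain Python, no sglang imports) of what the two branches evaluate to:

```python
# What the two default capture schedules above evaluate to.
dense = list(range(1, 33)) + [64, 128]              # padding disabled
sparse = [1, 2, 4] + [i * 8 for i in range(1, 21)]  # padding enabled

print(dense[:5], dense[-2:])   # [1, 2, 3, 4, 5] [64, 128]
print(sparse[:5], sparse[-1])  # [1, 2, 4, 8, 16] 160
```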
@@ -152,6 +150,21 @@ class CudaGraphRunner:
             if bs <= model_runner.req_to_token_pool.size
             and bs <= model_runner.server_args.cuda_graph_max_bs
         ]
+
+        self.capture_forward_mode = ForwardMode.DECODE
+        self.num_tokens_per_bs = 1
+
+        if model_runner.spec_algorithm.is_eagle():
+            if self.model_runner.is_draft_worker:
+                self.num_tokens_per_bs = (
+                    self.model_runner.server_args.speculative_eagle_topk
+                )
+            else:
+                self.capture_forward_mode = ForwardMode.TARGET_VERIFY
+                self.num_tokens_per_bs = (
+                    self.model_runner.server_args.speculative_num_draft_tokens
+                )
+
         self.compile_bs = (
             [
                 bs
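With EAGLE speculative decoding, each request contributes more than one token per step, so every captured shape scales by num_tokens_per_bs. A hypothetical sketch of the resulting token counts; the topk/draft-token values below are illustrative assumptions, not sglang defaults:

```python
# Illustrative values only; the real ones come from server_args.
speculative_eagle_topk = 4         # draft worker: branches expanded per request
speculative_num_draft_tokens = 8   # target worker: draft tokens verified per request

def graph_num_tokens(bs: int, is_draft_worker: bool) -> int:
    num_tokens_per_bs = (
        speculative_eagle_topk if is_draft_worker else speculative_num_draft_tokens
    )
    return bs * num_tokens_per_bs

print(graph_num_tokens(16, is_draft_worker=True))   # 64
print(graph_num_tokens(16, is_draft_worker=False))  # 128
```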
@@ -164,8 +177,8 @@ class CudaGraphRunner:
 
         # Attention backend
         self.max_bs = max(self.capture_bs)
-        self.
-
+        self.max_num_token = self.max_bs * self.num_tokens_per_bs
+        self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
@@ -178,14 +191,22 @@ class CudaGraphRunner:
 
         # Common inputs
         with torch.device("cuda"):
-            self.input_ids = torch.zeros((self.
+            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int32)
             self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
             self.seq_lens = torch.full(
                 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
             )
-            self.out_cache_loc = torch.zeros((self.
+            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int32)
+            self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)
 
+            # Speculative_inference
+            if model_runner.spec_algorithm.is_eagle():
+                self.hidden_states = torch.zeros(
+                    (self.max_num_token, self.model_runner.model_config.hidden_size),
+                    dtype=self.model_runner.dtype,
+                )
+
 
         if self.is_encoder_decoder:
             # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
             self.encoder_lens = torch.full(
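CUDA graphs replay fixed kernel arguments, so every input tensor must live at a stable address: the buffers above are allocated once at max_num_token, and each captured graph works on a slice of them. A minimal CPU-only sketch of that slice-of-a-static-buffer pattern:

```python
import torch

# Allocate once at the maximum size; captured graphs then read fixed storage.
max_num_token = 128 * 8
input_ids = torch.zeros((max_num_token,), dtype=torch.int32)

# A per-batch "input" is a view into the shared buffer, not a fresh tensor.
num_tokens = 4 * 8
ids_view = input_ids[:num_tokens]
ids_view.fill_(7)
assert input_ids[0].item() == 7  # the write landed in the static buffer
```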
@@ -257,12 +278,12 @@ class CudaGraphRunner:
     def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
-
+            capture_range = (
                 tqdm.tqdm(self.capture_bs)
                 if get_tensor_model_parallel_rank() == 0
                 else self.capture_bs
             )
-            for bs in
+            for bs in capture_range:
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,
@@ -276,21 +297,24 @@ class CudaGraphRunner:
                     self.graphs[bs] = graph
                     self.output_buffers[bs] = output_buffers
 
+                # Save gemlite cache after each capture
+                save_gemlite_cache()
+
     def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
         stream = self.stream
+        num_tokens = bs * self.num_tokens_per_bs
 
         # Common inputs
-        input_ids = self.input_ids[:
+        input_ids = self.input_ids[:num_tokens]
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
-        out_cache_loc = self.out_cache_loc[:
+        out_cache_loc = self.out_cache_loc[:num_tokens]
+        positions = self.positions[:num_tokens]
         if self.is_encoder_decoder:
             encoder_lens = self.encoder_lens[:bs]
         else:
             encoder_lens = None
-
-        seq_lens_sum = seq_lens.sum().item()
         mrope_positions = self.mrope_positions[:, :bs]
 
         if self.enable_dp_attention:
@@ -300,37 +324,48 @@ class CudaGraphRunner:
             global_num_tokens = None
             gathered_buffer = None
 
+        spec_info = self.get_spec_info(num_tokens, positions)
+
+        forward_batch = ForwardBatch(
+            forward_mode=self.capture_forward_mode,
+            batch_size=bs,
+            input_ids=input_ids,
+            req_pool_indices=req_pool_indices,
+            seq_lens=seq_lens,
+            req_to_token_pool=self.model_runner.req_to_token_pool,
+            token_to_kv_pool=self.model_runner.token_to_kv_pool,
+            attn_backend=self.model_runner.attn_backend,
+            out_cache_loc=out_cache_loc,
+            seq_lens_sum=seq_lens.sum(),
+            encoder_lens=encoder_lens,
+            return_logprob=False,
+            top_logprobs_nums=[0] * bs,
+            positions=positions,
+            global_num_tokens=global_num_tokens,
+            gathered_buffer=gathered_buffer,
+            mrope_positions=mrope_positions,
+            spec_algorithm=self.model_runner.spec_algorithm,
+            spec_info=spec_info,
+            capture_hidden_mode=(
+                spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
+            ),
+        )
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
             bs,
+            num_tokens,
             req_pool_indices,
             seq_lens,
             encoder_lens,
+            forward_batch.forward_mode,
+            forward_batch.spec_info,
         )
 
         # Run and capture
         def run_once():
-            forward_batch = ForwardBatch(
-                forward_mode=ForwardMode.DECODE,
-                batch_size=bs,
-                input_ids=input_ids,
-                req_pool_indices=req_pool_indices,
-                seq_lens=seq_lens,
-                req_to_token_pool=self.model_runner.req_to_token_pool,
-                token_to_kv_pool=self.model_runner.token_to_kv_pool,
-                attn_backend=self.model_runner.attn_backend,
-                out_cache_loc=out_cache_loc,
-                seq_lens_sum=seq_lens_sum,
-                encoder_lens=encoder_lens,
-                return_logprob=False,
-                top_logprobs_nums=[0] * bs,
-                positions=clamp_position(seq_lens),
-                mrope_positions=mrope_positions,
-                global_num_tokens=global_num_tokens,
-                gathered_buffer=gathered_buffer,
-            )
             logits_output = forward(input_ids, forward_batch.positions, forward_batch)
-            return logits_output.next_token_logits
+            return logits_output.next_token_logits, logits_output.hidden_states
 
         for _ in range(2):
             torch.cuda.synchronize()
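The synchronized warm-up iterations before capture follow the standard torch.cuda.CUDAGraph recipe. A generic, self-contained version of that recipe, guarded so it only runs on a CUDA machine and unrelated to sglang's actual model forward:

```python
import torch

if torch.cuda.is_available():
    static_x = torch.zeros(8, device="cuda")

    def forward():
        return static_x * 2 + 1

    # Warm up on a side stream, as capture requires.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(2):
            static_y = forward()
    torch.cuda.current_stream().wait_stream(s)

    # Capture once, then replay with new inputs written in place.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_y = forward()

    static_x.copy_(torch.arange(8.0, device="cuda"))
    g.replay()
    print(static_y)  # tensor([ 1.,  3.,  5., ..., 15.], device='cuda:0')
```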
@@ -356,6 +391,7 @@ class CudaGraphRunner:
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
         raw_bs = forward_batch.batch_size
+        raw_num_token = raw_bs * self.num_tokens_per_bs
 
         # Pad
         if self.enable_dp_attention:
@@ -370,15 +406,20 @@ class CudaGraphRunner:
             self.out_cache_loc.zero_()
 
         # Common inputs
-        self.input_ids[:
+        self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
-        self.out_cache_loc[:
+        self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
+        self.positions[:raw_num_token].copy_(forward_batch.positions)
+
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
 
+        if hasattr(forward_batch.spec_info, "hidden_states"):
+            self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
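During replay, the live batch (raw_bs, raw_num_token) is usually smaller than the captured shape, so the real inputs are copied into the head of each static buffer while the padded tail keeps its fill value. A small sketch of that, with an assumed fill value of 1 for seq_lens (the real value comes from the attention backend):

```python
import torch

capture_bs = [1, 2, 4, 8]
raw_bs = 3
bs = next(b for b in capture_bs if b >= raw_bs)  # replay the padded graph

seq_len_fill_value = 1  # assumption; backend-specific in sglang
seq_lens = torch.full((bs,), seq_len_fill_value, dtype=torch.int32)
seq_lens[:raw_bs].copy_(torch.tensor([5, 9, 2], dtype=torch.int32))
print(seq_lens)  # tensor([5, 9, 2, 1], dtype=torch.int32)
# This padding is also why the metadata call below adds (bs - raw_bs) to seq_lens_sum.
```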
@@ -386,40 +427,51 @@ class CudaGraphRunner:
             self.seq_lens,
             forward_batch.seq_lens_sum + (bs - raw_bs),
             self.encoder_lens,
+            forward_batch.forward_mode,
+            forward_batch.spec_info,
         )
 
         # Replay
         self.graphs[bs].replay()
-        next_token_logits = self.output_buffers[bs]
+        next_token_logits, hidden_states = self.output_buffers[bs]
+
+        logits_output = LogitsProcessorOutput(
+            next_token_logits=next_token_logits[:raw_num_token],
+            hidden_states=(
+                hidden_states[:raw_num_token] if hidden_states is not None else None
+            ),
+        )
+        return logits_output
 
-
-
-
-
-
+    def get_spec_info(self, num_tokens: int, positions: torch.Tensor):
+        spec_info = None
+        if self.model_runner.spec_algorithm.is_eagle():
+            from sglang.srt.speculative.eagle_utils import (
+                EAGLEDraftInput,
+                EagleVerifyInput,
             )
-
-
-
+
+            if self.model_runner.is_draft_worker:
+                spec_info = EAGLEDraftInput()
+                spec_info.load_server_args(self.model_runner.server_args)
+                spec_info.hidden_states = self.hidden_states[:num_tokens]
+                spec_info.positions = positions
+                spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
+            else:
+                spec_info = EagleVerifyInput(
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    self.model_runner.server_args.speculative_num_draft_tokens,
                )
-
-
-
-
-
-
-        if return_top_logprob:
-            (
-                logits_output.output_top_logprobs_val,
-                logits_output.output_top_logprobs_idx,
-            ) = LogitsProcessor.get_top_logprobs(
-                next_token_logprobs, logits_metadata
-            )[
-                2:4
-            ]
-        else:
-            logits_output = LogitsProcessorOutput(
-                next_token_logits=next_token_logits,
-            )
+            spec_info.custom_mask = torch.zeros(
+                (num_tokens * self.model_runner.model_config.context_len),
+                dtype=torch.bool,
+                device="cuda",
+            )
+            spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
 
-        return
+        return spec_info
--- a/sglang/srt/model_executor/forward_batch_info.py
+++ b/sglang/srt/model_executor/forward_batch_info.py
@@ -38,6 +38,7 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
+from sglang.srt.utils import maybe_torch_compile
 
 if TYPE_CHECKING:
     from sglang.srt.layers.attention import AttentionBackend
@@ -96,11 +97,33 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.DRAFT_EXTEND
 
     def is_cuda_graph(self):
-        return
+        return (
+            self == ForwardMode.DECODE
+            or self == ForwardMode.TARGET_VERIFY
+            or self == ForwardMode.IDLE
+        )
 
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
 
+    def is_decode_or_idle(self):
+        return self == ForwardMode.DECODE or self == ForwardMode.IDLE
+
+
+class CaptureHiddenMode(IntEnum):
+    NULL = auto()
+    FULL = auto()
+    LAST = auto()
+
+    def need_capture(self):
+        return self != CaptureHiddenMode.NULL
+
+    def is_full(self):
+        return self == CaptureHiddenMode.FULL
+
+    def is_last(self):
+        return self == CaptureHiddenMode.LAST
+
 
 @dataclass
 class ForwardBatch:
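The new enum is plain Python and can be exercised on its own; a self-contained copy with its main predicate, to show how callers decide whether hidden states must be kept:

```python
from enum import IntEnum, auto

class CaptureHiddenMode(IntEnum):
    NULL = auto()  # don't keep hidden states
    FULL = auto()  # keep hidden states for every token
    LAST = auto()  # keep only the last token's hidden state

    def need_capture(self):
        return self != CaptureHiddenMode.NULL

print(CaptureHiddenMode.FULL.need_capture())  # True
print(CaptureHiddenMode.NULL.need_capture())  # False
```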
@@ -161,15 +184,16 @@ class ForwardBatch:
     token_to_kv_pool: BaseTokenToKVPool = None
     attn_backend: AttentionBackend = None
 
-    # Speculative decoding
-    spec_info: SpecInfo = None
-    spec_algorithm: SpeculativeAlgorithm = None
-
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
     gathered_buffer: Optional[torch.Tensor] = None
     can_run_dp_cuda_graph: bool = False
 
+    # Speculative decoding
+    spec_info: SpecInfo = None
+    spec_algorithm: SpeculativeAlgorithm = None
+    capture_hidden_mode: CaptureHiddenMode = None
+
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None
 
@@ -258,6 +282,9 @@ class ForwardBatch:
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
+            spec_algorithm=batch.spec_algorithm,
+            spec_info=batch.spec_info,
+            capture_hidden_mode=batch.capture_hidden_mode,
             input_embeds=batch.input_embeds,
         )
 
@@ -270,10 +297,21 @@ class ForwardBatch:
         )
 
         if ret.forward_mode.is_idle():
+            ret.positions = torch.empty((0,), device=device)
             return ret
 
+        # Override the positions with spec_info
+        if (
+            ret.spec_info is not None
+            and getattr(ret.spec_info, "positions", None) is not None
+        ):
+            ret.positions = ret.spec_info.positions
+
         # Init position information
-        if
+        if ret.forward_mode.is_decode():
+            if ret.positions is None:
+                ret.positions = clamp_position(batch.seq_lens)
+        else:
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
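The resolution order above is: positions supplied by spec_info win; otherwise a decode batch falls back to clamp(seq_lens - 1). A plain-tensor sketch of that fallback:

```python
import torch

def clamp_position(seq_lens):
    # next-token position for each request; the clamp guards empty sequences
    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)

spec_positions = None  # stand-in for spec_info.positions
seq_lens = torch.tensor([1, 4, 9], dtype=torch.int32)

positions = spec_positions if spec_positions is not None else clamp_position(seq_lens)
print(positions)  # tensor([0, 3, 8])
```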
@@ -282,13 +320,15 @@ class ForwardBatch:
             ).to(device, non_blocking=True)
             if model_runner.server_args.attention_backend != "torch_native":
                 ret.extend_num_tokens = batch.extend_num_tokens
-
+                positions, ret.extend_start_loc = compute_position_triton(
                     ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
                 )
             else:
-
+                positions, ret.extend_start_loc = compute_position_torch(
                     ret.extend_prefix_lens, ret.extend_seq_lens
                 )
+            if ret.positions is None:
+                ret.positions = positions
             ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
             ret.extend_seq_lens_cpu = batch.extend_seq_lens
             ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
@@ -377,16 +417,6 @@ def compute_position_torch(
     return positions.to(torch.int64), extend_start_loc
 
 
-
-
-
-    LAST = auto()
-
-    def need_capture(self):
-        return self != CaptureHiddenMode.NULL
-
-    def is_full(self):
-        return self == CaptureHiddenMode.FULL
-
-    def is_last(self):
-        return self == CaptureHiddenMode.LAST
+@maybe_torch_compile(dynamic=True)
+def clamp_position(seq_lens):
+    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)