sglang 0.4.1__py3-none-any.whl → 0.4.1.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +1 -0
- sglang/bench_serving.py +11 -3
- sglang/lang/backend/openai.py +10 -0
- sglang/srt/configs/model_config.py +11 -2
- sglang/srt/constrained/xgrammar_backend.py +6 -0
- sglang/srt/layers/attention/__init__.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +54 -41
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -14
- sglang/srt/layers/logits_processor.py +30 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -30
- sglang/srt/layers/moe/topk.py +14 -0
- sglang/srt/layers/quantization/fp8.py +42 -2
- sglang/srt/layers/quantization/fp8_kernel.py +91 -18
- sglang/srt/layers/quantization/fp8_utils.py +8 -2
- sglang/srt/managers/io_struct.py +29 -8
- sglang/srt/managers/schedule_batch.py +22 -15
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +71 -34
- sglang/srt/managers/session_controller.py +102 -27
- sglang/srt/managers/tokenizer_manager.py +95 -55
- sglang/srt/managers/tp_worker.py +7 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -0
- sglang/srt/model_executor/forward_batch_info.py +42 -3
- sglang/srt/model_executor/model_runner.py +4 -6
- sglang/srt/model_loader/loader.py +22 -11
- sglang/srt/models/gemma2.py +19 -0
- sglang/srt/models/llama.py +13 -2
- sglang/srt/models/llama_eagle.py +132 -0
- sglang/srt/openai_api/adapter.py +79 -2
- sglang/srt/openai_api/protocol.py +50 -0
- sglang/srt/sampling/sampling_params.py +9 -2
- sglang/srt/server.py +45 -39
- sglang/srt/server_args.py +17 -30
- sglang/srt/speculative/spec_info.py +19 -0
- sglang/srt/utils.py +62 -0
- sglang/version.py +1 -1
- {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/METADATA +5 -5
- {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/RECORD +41 -39
- {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/top_level.txt +0 -0
sglang/bench_serving.py
CHANGED
@@ -897,6 +897,7 @@ async def benchmark(
     else:
         raise ValueError(f"Unknown backend: {backend}")
 
+    # Limit concurrency
     # From https://github.com/vllm-project/vllm/pull/9390
     semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
 
@@ -906,6 +907,7 @@ async def benchmark(
         async with semaphore:
             return await request_func(request_func_input=request_func_input, pbar=pbar)
 
+    # Warmup
     print("Starting initial single prompt test run...")
     test_prompt, test_prompt_len, test_output_len = input_requests[0]
     test_input = RequestFuncInput(
@@ -924,11 +926,15 @@
             f"are correctly specified. Error: {test_output.error}"
         )
     else:
-        requests.post(base_url + "/flush_cache")
         print("Initial test run completed. Starting main benchmark run...")
 
-
+    # Flush cache
+    if "sglang" in backend:
+        requests.post(base_url + "/flush_cache")
+
+    time.sleep(1.0)
 
+    # Start profiler
     if profile:
         print("Starting profiler...")
         profile_output = await async_request_profile(
@@ -939,6 +945,7 @@
 
     pbar = None if disable_tqdm else tqdm(total=len(input_requests))
 
+    # Run all requests
     benchmark_start_time = time.perf_counter()
     tasks: List[asyncio.Task] = []
     async for request in get_request(input_requests, request_rate):
@@ -959,6 +966,7 @@
     )
     outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
 
+    # Stop profiler
     if profile:
         print("Stopping profiler...")
         profile_output = await async_request_profile(api_url=base_url + "/stop_profile")
@@ -968,8 +976,8 @@
     if pbar is not None:
         pbar.close()
 
+    # Compute metrics and print results
     benchmark_duration = time.perf_counter() - benchmark_start_time
-
     metrics, output_lens = calculate_metrics(
         input_requests=input_requests,
         outputs=outputs,
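Note on the bench_serving.py change above: the cache flush now runs only for sglang backends and is followed by a short pause before the timed run starts. A minimal sketch of that step in isolation (the endpoint path and the `"sglang" in backend` guard come from the diff; the helper name and base URL are assumptions for illustration only):

import time

import requests


def flush_cache_before_benchmark(base_url: str, backend: str) -> None:
    # Only sglang servers expose /flush_cache; other backends are skipped.
    if "sglang" in backend:
        requests.post(base_url + "/flush_cache")
    # Brief pause so the flush settles before timing begins.
    time.sleep(1.0)


# Example (assumed local server address):
# flush_cache_before_benchmark("http://127.0.0.1:30000", backend="sglang")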
sglang/lang/backend/openai.py
CHANGED
@@ -366,6 +366,11 @@ class OpenAI(BaseBackend):
 def openai_completion(
     client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
 ):
+    # if "ebnf" is in kwargs, warn and remove
+    if "ebnf" in kwargs:
+        warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
+        del kwargs["ebnf"]
+
     for attempt in range(retries):
         try:
             if is_chat:
@@ -398,6 +403,11 @@ def openai_completion(
 def openai_completion_stream(
     client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
 ):
+    # if "ebnf" is in kwargs, warn and remove
+    if "ebnf" in kwargs:
+        warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
+        del kwargs["ebnf"]
+
     for attempt in range(retries):
         try:
             if is_chat:
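The same guard is added to both the blocking and streaming completion paths above. A standalone sketch of the behavior (the helper name below is hypothetical; sglang's real functions take additional parameters):

import warnings


def strip_unsupported_ebnf(**kwargs):
    # Hypothetical helper mirroring the added guard: OpenAI endpoints do not
    # accept an "ebnf" constraint, so warn and drop it before the API call.
    if "ebnf" in kwargs:
        warnings.warn(
            "EBNF is not officially supported by OpenAI endpoints. Ignoring."
        )
        del kwargs["ebnf"]
    return kwargs


# strip_unsupported_ebnf(prompt="hi", ebnf="root ::= 'a'")  ->  {"prompt": "hi"}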
sglang/srt/configs/model_config.py
CHANGED
@@ -15,7 +15,7 @@
 import json
 import logging
 from enum import IntEnum, auto
-from typing import List, Optional, Union
+from typing import List, Optional, Set, Union
 
 import torch
 from transformers import PretrainedConfig
@@ -47,6 +47,7 @@ class ModelConfig:
         self.model_path = model_path
         self.revision = revision
         self.quantization = quantization
+
         # Parse args
         self.model_override_args = json.loads(model_override_args)
         self.hf_config = get_config(
@@ -130,7 +131,8 @@ class ModelConfig:
         # Veirfy quantization
         self._verify_quantization()
 
-        #
+        # Cache attributes
+        self.hf_eos_token_id = self.get_hf_eos_token_id()
         self.image_token_id = getattr(self.hf_config, "image_token_id", None)
 
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
@@ -271,6 +273,13 @@ class ModelConfig:
             self.quantization,
         )
 
+    def get_hf_eos_token_id(self) -> Optional[Set[int]]:
+        eos_ids = getattr(self.hf_config, "eos_token_id", None)
+        if eos_ids:
+            # it can be either int or list of int
+            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
+        return eos_ids
+
 
 def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
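For reference, the new `get_hf_eos_token_id` normalizes the HF config's `eos_token_id`, which may be an int, a list, or absent, into an optional set. A small illustration with a stand-in config object (the `FakeConfig` class exists only for this example):

from typing import Optional, Set


class FakeConfig:
    # Stand-in for a transformers PretrainedConfig carrying eos_token_id.
    def __init__(self, eos_token_id):
        self.eos_token_id = eos_token_id


def get_hf_eos_token_id(hf_config) -> Optional[Set[int]]:
    eos_ids = getattr(hf_config, "eos_token_id", None)
    if eos_ids:
        # It can be either an int or a list of ints.
        eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
    return eos_ids


assert get_hf_eos_token_id(FakeConfig(2)) == {2}
assert get_hf_eos_token_id(FakeConfig([1, 2])) == {1, 2}
assert get_hf_eos_token_id(FakeConfig(None)) is None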
sglang/srt/constrained/xgrammar_backend.py
CHANGED
@@ -126,6 +126,12 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
                     f"Skip invalid json_schema: json_schema={key_string}, {e=}"
                 )
                 return None
+        elif key_type == "ebnf":
+            try:
+                ctx = self.grammar_compiler.compile_grammar(key_string)
+            except RuntimeError as e:
+                logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
+                return None
         elif key_type == "regex":
             logger.warning(
                 "regex hasn't been supported by xgrammar yet. This is skipped."
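The new branch lets callers pass an EBNF grammar string through the same key-type dispatch used for json_schema and regex. A compact sketch of that dispatch pattern with the compiler object abstracted behind a parameter (only the `compile_grammar` call is taken from the diff; the function name and everything else here is illustrative):

import logging


def compile_key(grammar_compiler, key_type: str, key_string: str):
    # Mirrors the added branch: try to compile an EBNF grammar and skip the
    # request (return None) if the compiler rejects it, instead of raising.
    if key_type == "ebnf":
        try:
            return grammar_compiler.compile_grammar(key_string)
        except RuntimeError as e:
            logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
            return None
    raise ValueError(f"unsupported key_type in this sketch: {key_type}")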
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -8,8 +8,9 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
 """
 
 import os
+from dataclasses import dataclass
 from enum import Enum, auto
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Union
 
 import torch
 import triton
@@ -38,12 +39,25 @@ class WrapperDispatch(Enum):
     CROSS_ATTENTION = auto()
 
 
+@dataclass
+class DecodeMetadata:
+    decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
+
+
+@dataclass
+class PrefillMetadata:
+    prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
+    use_ragged: bool
+    extend_no_prefix: bool
+
+
 class FlashInferAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""
 
     def __init__(self, model_runner: ModelRunner):
         super().__init__()
 
+        # Parse constants
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
             num_attention_heads=model_runner.model_config.num_attention_heads
@@ -52,7 +66,6 @@ class FlashInferAttnBackend(AttentionBackend):
                 model_runner.tp_size
             ),
         )
-
         self.max_context_len = model_runner.model_config.context_len
 
         assert not (
@@ -120,8 +133,8 @@ class FlashInferAttnBackend(AttentionBackend):
         )
 
         # Other metadata
-        self.forward_metadata = None
-        self.
+        self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
+        self.decode_cuda_graph_metadata = {}
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode():
@@ -129,10 +142,10 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 forward_batch.seq_lens_sum,
-                decode_wrappers=
+                decode_wrappers=self.decode_wrappers,
                 encoder_lens=forward_batch.encoder_lens,
             )
-            self.forward_metadata = (self.decode_wrappers
+            self.forward_metadata = DecodeMetadata(self.decode_wrappers)
         else:
             prefix_lens = forward_batch.extend_prefix_lens
 
@@ -149,11 +162,13 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.seq_lens,
                 forward_batch.seq_lens_sum,
                 prefix_lens,
+                prefill_wrappers=self.prefill_wrappers_paged,
                 use_ragged=use_ragged,
                 encoder_lens=forward_batch.encoder_lens,
             )
-
-
+            self.forward_metadata = PrefillMetadata(
+                self.prefill_wrappers_paged, use_ragged, extend_no_prefix
+            )
 
     def init_cuda_graph_state(self, max_bs: int):
         cuda_graph_kv_indices = torch.zeros(
@@ -194,8 +209,8 @@ class FlashInferAttnBackend(AttentionBackend):
             decode_wrappers=decode_wrappers,
             encoder_lens=encoder_lens,
         )
-        self.
-        self.forward_metadata = (decode_wrappers
+        self.decode_cuda_graph_metadata[bs] = decode_wrappers
+        self.forward_metadata = DecodeMetadata(decode_wrappers)
 
     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -209,7 +224,7 @@ class FlashInferAttnBackend(AttentionBackend):
             req_pool_indices[:bs],
             seq_lens[:bs],
             seq_lens_sum,
-            decode_wrappers=self.
+            decode_wrappers=self.decode_cuda_graph_metadata[bs],
             encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
         )
 
@@ -225,18 +240,16 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        prefill_wrapper_paged = self.
+        prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
             self._get_wrapper_idx(layer)
         ]
-
-        use_ragged, extend_no_prefix = self.forward_metadata
         cache_loc = (
             forward_batch.out_cache_loc
             if not layer.is_cross_attention
             else forward_batch.encoder_out_cache_loc
         )
 
-        if not use_ragged:
+        if not self.forward_metadata.use_ragged:
             if k is not None:
                 assert v is not None
                 if save_kv_cache:
@@ -260,7 +273,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 logits_soft_cap=layer.logit_cap,
             )
 
-            if extend_no_prefix:
+            if self.forward_metadata.extend_no_prefix:
                 o = o1
             else:
                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
@@ -287,7 +300,9 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        decode_wrapper = self.forward_metadata[
+        decode_wrapper = self.forward_metadata.decode_wrappers[
+            self._get_wrapper_idx(layer)
+        ]
         cache_loc = (
             forward_batch.out_cache_loc
             if not layer.is_cross_attention
@@ -322,7 +337,7 @@ class FlashInferAttnBackend(AttentionBackend):
 
 class FlashInferIndicesUpdaterDecode:
     def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
-        # Constants
+        # Parse Constants
         self.num_qo_heads = (
             model_runner.model_config.num_attention_heads // model_runner.tp_size
         )
@@ -340,9 +355,8 @@ class FlashInferIndicesUpdaterDecode:
         self.kv_indptr = attn_backend.kv_indptr
         self.kv_last_page_len = attn_backend.kv_last_page_len
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
-        self.decode_wrappers = attn_backend.decode_wrappers
 
-        # Dispatch
+        # Dispatch the update function
        if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
             self.update = self.update_sliding_window
         elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
@@ -356,7 +370,7 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List,
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: torch.Tensor,
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
@@ -367,7 +381,7 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List,
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: torch.Tensor,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
@@ -385,11 +399,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List,
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: torch.Tensor,
     ):
-        decode_wrappers = decode_wrappers or self.decode_wrappers
-
         for wrapper_id in range(2):
             if wrapper_id == 0:
                 # Sliding window attention
@@ -419,11 +431,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List,
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: torch.Tensor,
     ):
-        decode_wrappers = decode_wrappers or self.decode_wrappers
-
         for wrapper_id in range(2):
             if wrapper_id == 0:
                 # Normal attention
@@ -446,7 +456,7 @@ class FlashInferIndicesUpdaterDecode:
 
     def call_begin_forward(
         self,
-        wrapper,
+        wrapper: BatchDecodeWithPagedKVCacheWrapper,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
@@ -486,7 +496,7 @@ class FlashInferIndicesUpdaterDecode:
 
 class FlashInferIndicesUpdaterPrefill:
     def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
-        # Constants
+        # Parse Constants
         self.num_qo_heads = (
             model_runner.model_config.num_attention_heads // model_runner.tp_size
         )
@@ -505,10 +515,9 @@ class FlashInferIndicesUpdaterPrefill:
         self.kv_last_page_len = attn_backend.kv_last_page_len
         self.qo_indptr = attn_backend.qo_indptr
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
-        self.
-        self.wrappers_paged = attn_backend.prefill_wrappers_paged
+        self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
 
-        # Dispatch
+        # Dispatch the update function
         if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
             self.update = self.update_sliding_window
         elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
@@ -523,6 +532,7 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: torch.Tensor,
     ):
@@ -535,6 +545,7 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: torch.Tensor,
     ):
@@ -546,8 +557,8 @@ class FlashInferIndicesUpdaterPrefill:
             paged_kernel_lens_sum = seq_lens_sum
 
         self.call_begin_forward(
-            self.
-
+            self.prefill_wrapper_ragged,
+            prefill_wrappers[0],
             req_pool_indices,
             paged_kernel_lens,
             paged_kernel_lens_sum,
@@ -565,6 +576,7 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: torch.Tensor,
     ):
@@ -584,8 +596,8 @@ class FlashInferIndicesUpdaterPrefill:
             kv_start_idx = seq_lens - paged_kernel_lens
 
             self.call_begin_forward(
-                self.
-
+                self.prefill_wrapper_ragged,
+                prefill_wrappers[wrapper_id],
                 req_pool_indices,
                 paged_kernel_lens,
                 paged_kernel_lens_sum,
@@ -603,6 +615,7 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
        use_ragged: bool,
         encoder_lens: torch.Tensor,
     ):
@@ -619,8 +632,8 @@ class FlashInferIndicesUpdaterPrefill:
             paged_kernel_lens_sum = paged_kernel_lens.sum().item()
 
             self.call_begin_forward(
-                self.
-
+                self.prefill_wrapper_ragged,
+                prefill_wrappers[wrapper_id],
                 req_pool_indices,
                 paged_kernel_lens,
                 paged_kernel_lens_sum,
@@ -634,8 +647,8 @@ class FlashInferIndicesUpdaterPrefill:
 
     def call_begin_forward(
         self,
-        wrapper_ragged,
-        wrapper_paged,
+        wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
+        wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
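The net effect of the flashinfer_backend.py changes above is that `forward_metadata` is now a small dataclass instead of an ad-hoc tuple, so call sites read named fields rather than unpacking by position. A toy before/after comparison (wrapper types are replaced with plain objects for illustration; only the field names come from the diff):

from dataclasses import dataclass
from typing import Any, List


@dataclass
class PrefillMetadata:
    # Named fields replace positional tuple unpacking at the call sites.
    prefill_wrappers: List[Any]
    use_ragged: bool
    extend_no_prefix: bool


# Before (tuple-style): use_ragged, extend_no_prefix = forward_metadata
# After (dataclass-style):
metadata = PrefillMetadata(
    prefill_wrappers=[object(), object()],
    use_ragged=False,
    extend_no_prefix=True,
)
if not metadata.use_ragged:
    wrapper = metadata.prefill_wrappers[0]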
sglang/srt/layers/attention/triton_ops/extend_attention.py
CHANGED
@@ -292,27 +292,33 @@ def extend_attention_fwd(
     BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)
 
-    if
-
-
-
-            BLOCK_M, BLOCK_N = (32, 64)
-    elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
-        if Lq <= 128:
-            BLOCK_M, BLOCK_N = (128, 128)
-        elif Lq <= 256:
-            BLOCK_M, BLOCK_N = (64, 64)
-        else:
-            BLOCK_M, BLOCK_N = (32, 64)
+    if is_hip_:
+        BLOCK_M, BLOCK_N = (64, 64)
+        num_warps = 4
+
     else:
-
+        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+            if Lq <= 256:
+                BLOCK_M, BLOCK_N = (128, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 64)
+        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+            if Lq <= 128:
+                BLOCK_M, BLOCK_N = (128, 128)
+            elif Lq <= 256:
+                BLOCK_M, BLOCK_N = (64, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 64)
+        else:
+            BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
+
+        num_warps = 4 if Lk <= 64 else 8
 
     sm_scale = sm_scale or 1.0 / (Lq**0.5)
     batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
     kv_group_num = q_extend.shape[1] // k_extend.shape[1]
 
     grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
-    num_warps = 4 if Lk <= 64 else 8
     num_stages = 1
 
     extra_kargs = {}
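The tile selection above is now split by platform: a fixed (64, 64) tile with 4 warps on ROCm (`is_hip_`), and capability-dependent tiles plus the existing `num_warps` heuristic on CUDA. A standalone sketch of that selection logic (the inputs are passed as arguments rather than read from module globals, which is an assumption of this sketch):

def choose_extend_tiles(Lq: int, Lk: int, is_hip: bool, cuda_capability: int):
    """Return (BLOCK_M, BLOCK_N, num_warps) following the updated heuristic."""
    if is_hip:
        return 64, 64, 4

    if cuda_capability >= 9:
        block_m, block_n = (128, 64) if Lq <= 256 else (32, 64)
    elif cuda_capability >= 8:
        if Lq <= 128:
            block_m, block_n = 128, 128
        elif Lq <= 256:
            block_m, block_n = 64, 64
        else:
            block_m, block_n = 32, 64
    else:
        block_m, block_n = (64, 64) if Lq <= 128 else (32, 32)

    num_warps = 4 if Lk <= 64 else 8
    return block_m, block_n, num_warps


# e.g. choose_extend_tiles(Lq=128, Lk=128, is_hip=False, cuda_capability=9)
# -> (128, 64, 8)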
sglang/srt/layers/logits_processor.py
CHANGED
@@ -24,7 +24,11 @@ from vllm.distributed import (
 )
 
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
 
 
 @dataclasses.dataclass
@@ -46,6 +50,10 @@ class LogitsProcessorOutput:
     output_top_logprobs_val: List = None
     output_top_logprobs_idx: List = None
 
+    # Used by speculative decoding (EAGLE)
+    # The output of transformer layers
+    hidden_states: Optional[torch.Tensor] = None
+
 
 @dataclasses.dataclass
 class LogitsMetadata:
@@ -61,6 +69,8 @@ class LogitsMetadata:
     extend_logprob_start_lens_cpu: Optional[List[int]] = None
     extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
 
+    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
+
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
         extend_logprob_pruned_lens_cpu = None
@@ -78,6 +88,11 @@ class LogitsMetadata:
         else:
             return_top_logprob = False
 
+        if forward_batch.spec_info:
+            capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
+        else:
+            capture_hidden_mode = CaptureHiddenMode.NULL
+
         return cls(
             forward_mode=forward_batch.forward_mode,
             top_logprobs_nums=forward_batch.top_logprobs_nums,
@@ -87,6 +102,7 @@ class LogitsMetadata:
             extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
             extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
             extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
+            capture_hidden_mode=capture_hidden_mode,
         )
 
 
@@ -116,7 +132,10 @@ class LogitsProcessor(nn.Module):
         assert isinstance(logits_metadata, LogitsMetadata)
 
         # Get the last hidden states and last logits for the next token prediction
-        if
+        if (
+            logits_metadata.forward_mode.is_decode()
+            or logits_metadata.forward_mode.is_target_verify()
+        ):
             last_index = None
             last_hidden = hidden_states
         else:
@@ -137,6 +156,15 @@ class LogitsProcessor(nn.Module):
         if not logits_metadata.return_logprob:
             return LogitsProcessorOutput(
                 next_token_logits=last_logits,
+                hidden_states=(
+                    hidden_states
+                    if logits_metadata.capture_hidden_mode.is_full()
+                    else (
+                        last_hidden
+                        if logits_metadata.capture_hidden_mode.is_last()
+                        else None
+                    )
+                ),
             )
         else:
             last_logprobs = self.compute_temp_top_p_normalized_logprobs(
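In the logits_processor.py change above, the new `hidden_states` field is filled according to `capture_hidden_mode`: the full hidden states, only the last positions, or nothing. A small sketch of that three-way selection with the enum stubbed out (the real `CaptureHiddenMode` lives in forward_batch_info and exposes `is_full`/`is_last` per the diff; the stub and function below are only for illustration):

from enum import Enum, auto


class CaptureHiddenMode(Enum):
    # Stub mirroring the modes referenced in the diff.
    NULL = auto()
    FULL = auto()
    LAST = auto()

    def is_full(self):
        return self is CaptureHiddenMode.FULL

    def is_last(self):
        return self is CaptureHiddenMode.LAST


def select_hidden_states(hidden_states, last_hidden, mode: CaptureHiddenMode):
    # Same selection the logits processor performs before returning
    # LogitsProcessorOutput: full states, last-position states, or None.
    if mode.is_full():
        return hidden_states
    if mode.is_last():
        return last_hidden
    return None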