sglang 0.3.3.post1__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +28 -10
- sglang/bench_server_latency.py +21 -10
- sglang/bench_serving.py +101 -7
- sglang/global_config.py +0 -1
- sglang/srt/layers/attention/__init__.py +27 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
- sglang/srt/layers/attention/flashinfer_backend.py +352 -83
- sglang/srt/layers/attention/triton_backend.py +6 -4
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
- sglang/srt/layers/sampler.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +31 -10
- sglang/srt/managers/io_struct.py +4 -0
- sglang/srt/managers/schedule_batch.py +120 -43
- sglang/srt/managers/schedule_policy.py +2 -1
- sglang/srt/managers/scheduler.py +202 -140
- sglang/srt/managers/tokenizer_manager.py +5 -1
- sglang/srt/managers/tp_worker.py +111 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -4
- sglang/srt/mem_cache/memory_pool.py +77 -4
- sglang/srt/mem_cache/radix_cache.py +15 -7
- sglang/srt/model_executor/cuda_graph_runner.py +4 -4
- sglang/srt/model_executor/forward_batch_info.py +16 -21
- sglang/srt/model_executor/model_runner.py +60 -1
- sglang/srt/models/baichuan.py +2 -3
- sglang/srt/models/chatglm.py +5 -6
- sglang/srt/models/commandr.py +1 -2
- sglang/srt/models/dbrx.py +1 -2
- sglang/srt/models/deepseek.py +4 -5
- sglang/srt/models/deepseek_v2.py +5 -6
- sglang/srt/models/exaone.py +1 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +5 -5
- sglang/srt/models/gpt_bigcode.py +5 -5
- sglang/srt/models/grok.py +1 -2
- sglang/srt/models/internlm2.py +1 -2
- sglang/srt/models/llama.py +1 -2
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +4 -8
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -2
- sglang/srt/models/minicpm3.py +5 -6
- sglang/srt/models/mixtral.py +1 -2
- sglang/srt/models/mixtral_quant.py +1 -2
- sglang/srt/models/olmo.py +352 -0
- sglang/srt/models/olmoe.py +1 -2
- sglang/srt/models/qwen.py +1 -2
- sglang/srt/models/qwen2.py +1 -2
- sglang/srt/models/qwen2_moe.py +4 -5
- sglang/srt/models/stablelm.py +1 -2
- sglang/srt/models/torch_native_llama.py +1 -2
- sglang/srt/models/xverse.py +1 -2
- sglang/srt/models/xverse_moe.py +4 -5
- sglang/srt/models/yivl.py +1 -2
- sglang/srt/openai_api/adapter.py +92 -49
- sglang/srt/openai_api/protocol.py +10 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
- sglang/srt/sampling/sampling_batch_info.py +92 -58
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server.py +116 -17
- sglang/srt/server_args.py +121 -45
- sglang/srt/utils.py +11 -3
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/few_shot_gsm8k_engine.py +144 -0
- sglang/test/srt/sampling/penaltylib/utils.py +16 -12
- sglang/version.py +1 -1
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/METADATA +72 -29
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/RECORD +73 -70
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
- sglang/srt/layers/attention/flashinfer_utils.py +0 -237
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/bench_latency.py
CHANGED
@@ -232,17 +232,18 @@ def extend(reqs, model_runner):
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
-    next_token_ids = model_runner.sample(logits_output, forward_batch)
+    next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits, batch


 @torch.inference_mode()
 def decode(input_token_ids, batch, model_runner):
-    batch.
+    batch.output_ids = input_token_ids
+    batch.prepare_for_decode()
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
-    next_token_ids = model_runner.sample(logits_output, forward_batch)
+    next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits


@@ -252,6 +253,7 @@ def correctness_test(
     bench_args,
     tp_rank,
 ):
+    configure_logger(server_args, prefix=f" TP{tp_rank}")
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

     # Load the model
@@ -279,8 +281,9 @@ def correctness_test(
     output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
     for _ in range(bench_args.output_len[0] - 1):
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+        next_token_ids_list = next_token_ids.tolist()
         for i in range(len(reqs)):
-            output_ids[i].append(
+            output_ids[i].append(next_token_ids_list[i])

     # Print
     for i in range(len(reqs)):
@@ -288,8 +291,15 @@ def correctness_test(
         rank_print(tokenizer.decode(output_ids[i]), "\n")


+def synchronize(device):
+    if device == "cuda":
+        torch.cuda.synchronize()
+    elif device == "xpu":
+        torch.xpu.synchronize()
+
+
 def latency_test_run_once(
-    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
+    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device
 ):
     max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
     if batch_size > max_batch_size:
@@ -312,10 +322,10 @@ def latency_test_run_once(
     tot_latency = 0

     # Prefill
-
+    synchronize(device)
     tic = time.time()
     next_token_ids, _, batch = extend(reqs, model_runner)
-
+    synchronize(device)
     prefill_latency = time.time() - tic
     tot_latency += prefill_latency
     throughput = input_len * batch_size / prefill_latency
@@ -328,10 +338,10 @@ def latency_test_run_once(
     # Decode
     decode_latencies = []
     for i in range(output_len - 1):
-
+        synchronize(device)
         tic = time.time()
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
-
+        synchronize(device)
         latency = time.time() - tic
         tot_latency += latency
         throughput = batch_size / latency
@@ -387,6 +397,7 @@ def latency_test(
         bench_args.batch_size[0],
         bench_args.input_len[0],
         8,  # shorter decoding to speed up the warmup
+        server_args.device,
     )
     rank_print("Benchmark ...")

@@ -397,7 +408,14 @@ def latency_test(
     ):
         reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
         ret = latency_test_run_once(
-            bench_args.run_name,
+            bench_args.run_name,
+            model_runner,
+            rank_print,
+            reqs,
+            bs,
+            il,
+            ol,
+            server_args.device,
         )
         if ret is not None:
             result_list.append(ret)
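The `synchronize(device)` helper added above makes the timing code device-agnostic: pending kernels on the active device are drained before the wall clock is read, so prefill and decode latencies are not under-measured on asynchronous backends. A minimal sketch of the same pattern outside of sglang (the matmul below is a hypothetical stand-in for a prefill or decode step):

```python
import time

import torch


def synchronize(device: str) -> None:
    # Drain queued kernels so time.time() measures completed work.
    if device == "cuda":
        torch.cuda.synchronize()
    elif device == "xpu":
        torch.xpu.synchronize()  # assumes a PyTorch build with XPU support


def time_matmul(device: str) -> float:
    x = torch.randn(2048, 2048, device=device)
    synchronize(device)
    tic = time.time()
    _ = x @ x
    synchronize(device)
    return time.time() - tic


if __name__ == "__main__":
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"{dev} matmul latency: {time_matmul(dev):.4f} s")
```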
sglang/bench_server_latency.py
CHANGED
@@ -6,6 +6,8 @@ It accepts arguments similar to those of launch_server.py.
 Usage:

 python3 -m sglang.bench_server_latency --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
+
+python3 -m sglang.bench_server_latency --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
 """

 import argparse
@@ -32,6 +34,8 @@ class BenchArgs:
     input_len: Tuple[int] = (1024,)
     output_len: Tuple[int] = (16,)
     result_filename: str = "result.jsonl"
+    base_url: str = ""
+    skip_warmup: bool = False

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -48,6 +52,8 @@ class BenchArgs:
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
+        parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
+        parser.add_argument("--skip-warmup", action="store_true")

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -139,17 +145,21 @@ def run_one_case(


 def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
-
+    if bench_args.base_url:
+        proc, base_url = None, bench_args.base_url
+    else:
+        proc, base_url = launch_server_process(server_args)

     # warmup
-
-
-
-
-
-
-
-
+    if not bench_args.skip_warmup:
+        run_one_case(
+            base_url,
+            batch_size=16,
+            input_len=1024,
+            output_len=16,
+            run_name="",
+            result_filename="",
+        )

     # benchmark
     try:
@@ -165,7 +175,8 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
             bench_args.result_filename,
         )
     finally:
-
+        if proc:
+            kill_child_process(proc.pid)

     print(f"\nResults are saved to {bench_args.result_filename}")

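With the new `--base-url` and `--skip-warmup` flags, the benchmark can attach to a server that is already running instead of always launching its own, and the `finally` block only kills a process the script itself spawned. A minimal sketch of that attach-or-launch ownership pattern (the launch command and port below are illustrative placeholders, not the script's exact behavior):

```python
import subprocess
from typing import List, Optional, Tuple


def attach_or_launch(
    base_url: str, launch_cmd: List[str]
) -> Tuple[Optional[subprocess.Popen], str]:
    # Reuse an existing server when a base URL is given; otherwise start one
    # and keep the process handle so only that process is cleaned up later.
    if base_url:
        return None, base_url
    proc = subprocess.Popen(launch_cmd)
    return proc, "http://127.0.0.1:30000"


def run(base_url: str = "") -> None:
    proc, url = attach_or_launch(
        base_url, ["python3", "-m", "sglang.launch_server", "--port", "30000"]
    )
    try:
        print(f"benchmarking against {url}")
        # ... send requests to the server here ...
    finally:
        if proc:  # never terminate a server this script did not start
            proc.terminate()


if __name__ == "__main__":
    run(base_url="http://localhost:30000")
```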
sglang/bench_serving.py
CHANGED
@@ -222,6 +222,85 @@ async def async_request_openai_completions(
     return output


+async def async_request_sglang_generate(
+    request_func_input: RequestFuncInput,
+    pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+    api_url = request_func_input.api_url
+    prompt = request_func_input.prompt
+
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+        payload = {
+            "text": prompt,
+            "sampling_params": {
+                "temperature": 0.0,
+                "max_new_tokens": request_func_input.output_len,
+                "ignore_eos": not args.disable_ignore_eos,
+            },
+            "stream": not args.disable_stream,
+            **request_func_input.extra_request_body,
+        }
+        headers = {}
+
+        output = RequestFuncOutput()
+        output.prompt_len = request_func_input.prompt_len
+
+        generated_text = ""
+        ttft = 0.0
+        st = time.perf_counter()
+        most_recent_timestamp = st
+        try:
+            async with session.post(
+                url=api_url, json=payload, headers=headers
+            ) as response:
+                if response.status == 200:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
+                            continue
+                        # print(chunk_bytes)
+
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
+                        latency = time.perf_counter() - st
+                        if chunk == "[DONE]":
+                            pass
+                        else:
+                            data = json.loads(chunk)
+
+                            # NOTE: Some completion API might have a last
+                            # usage summary response without a token so we
+                            # want to check a token was generated
+                            if data["text"]:
+                                timestamp = time.perf_counter()
+                                # First token
+                                if ttft == 0.0:
+                                    ttft = time.perf_counter() - st
+                                    output.ttft = ttft
+
+                                # Decoding phase
+                                else:
+                                    output.itl.append(timestamp - most_recent_timestamp)
+
+                                most_recent_timestamp = timestamp
+                                generated_text = data["text"]
+
+                    output.generated_text = generated_text
+                    output.success = True
+                    output.latency = latency
+                    output.output_len = request_func_input.output_len
+                else:
+                    output.error = response.reason or ""
+                    output.success = False
+        except Exception:
+            output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
+
+    if pbar:
+        pbar.update(1)
+    return output
+
+
 async def async_request_gserver(
     request_func_input: RequestFuncInput,
     pbar: Optional[tqdm] = None,
@@ -264,7 +343,9 @@ def get_tokenizer(


 ASYNC_REQUEST_FUNCS = {
-    "sglang":
+    "sglang": async_request_sglang_generate,
+    "sglang-native": async_request_sglang_generate,
+    "sglang-oai": async_request_openai_completions,
     "vllm": async_request_openai_completions,
     "lmdeploy": async_request_openai_completions,
     "trt": async_request_trt_llm,
@@ -387,6 +468,8 @@ def sample_sharegpt_requests(
             continue
         filtered_dataset.append((prompt, prompt_len, output_len))

+    print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
+    print(f"#Output tokens: {np.sum([x[2] for x in filtered_dataset])}")
    return filtered_dataset


@@ -587,6 +670,8 @@ async def benchmark(
     else:
         print("Initial test run completed. Starting main benchmark run...")

+    time.sleep(1.5)
+
     pbar = None if disable_tqdm else tqdm(total=len(input_requests))

     benchmark_start_time = time.perf_counter()
@@ -782,24 +867,33 @@ def run_benchmark(args_: argparse.Namespace):
     if args.port is None:
         args.port = {
             "sglang": 30000,
+            "sglang-native": 30000,
+            "sglang-oai": 30000,
             "lmdeploy": 23333,
             "vllm": 8000,
             "trt": 8000,
             "gserver": 9988,
         }.get(args.backend, 30000)

-    api_url = (
-        f"{args.base_url}/v1/completions"
-        if args.base_url
-        else f"http://{args.host}:{args.port}/v1/completions"
-    )
     model_url = (
         f"{args.base_url}/v1/models"
         if args.base_url
         else f"http://{args.host}:{args.port}/v1/models"
     )

-    if args.backend
+    if args.backend in ["sglang", "sglang-native"]:
+        api_url = (
+            f"{args.base_url}/generate"
+            if args.base_url
+            else f"http://{args.host}:{args.port}/generate"
+        )
+    elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
+        api_url = (
+            f"{args.base_url}/v1/completions"
+            if args.base_url
+            else f"http://{args.host}:{args.port}/v1/completions"
+        )
+    elif args.backend == "trt":
         api_url = (
             f"{args.base_url}/v2/models/ensemble/generate_stream"
             if args.base_url
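The new `async_request_sglang_generate` client drives sglang's native `/generate` endpoint (now selected by the `sglang` and `sglang-native` backends) and parses the streamed `data: ...` chunks to record TTFT and inter-token latencies; each chunk's `"text"` field holds the full generation so far. A minimal synchronous sketch of the same request shape, assuming a server is already listening on port 30000 and that the `requests` package is available:

```python
import json

import requests

payload = {
    "text": "The capital of France is",
    "sampling_params": {"temperature": 0.0, "max_new_tokens": 16},
    "stream": True,
}

with requests.post(
    "http://localhost:30000/generate", json=payload, stream=True
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line:
            continue
        chunk = line.removeprefix("data: ")
        if chunk == "[DONE]":
            break
        # "text" is cumulative: the full generation produced so far.
        print(json.loads(chunk)["text"])
```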
sglang/global_config.py
CHANGED
@@ -19,7 +19,6 @@ class GlobalConfig:
         self.new_token_ratio_decay = 0.001

         # Runtime constants: others
-        self.num_continue_decode_steps = 10
         self.retract_decode_steps = 20
         self.flashinfer_workspace_size = os.environ.get(
             "FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
sglang/srt/layers/attention/__init__.py
CHANGED
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod

+import torch
 from torch import nn

 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -18,13 +19,13 @@ class AttentionBackend(ABC):
         raise NotImplementedError()

     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()

     def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         """Init the metadata for a forward pass for replying a cuda graph."""
         raise NotImplementedError()
@@ -33,17 +34,38 @@ class AttentionBackend(ABC):
         """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
         raise NotImplementedError()

-    def forward(
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: nn.Module,
+        forward_batch: ForwardBatch,
+    ):
         """Run forward on an attention layer."""
         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(q, k, v, layer, forward_batch)
         else:
             return self.forward_extend(q, k, v, layer, forward_batch)

-    def forward_decode(
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: nn.Module,
+        forward_batch: ForwardBatch,
+    ):
         """Run a forward for decode."""
         raise NotImplementedError()

-    def forward_extend(
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: nn.Module,
+        forward_batch: ForwardBatch,
+    ):
         """Run a forward for extend."""
         raise NotImplementedError()
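The `AttentionBackend` base class now carries explicit `torch.Tensor` annotations, and its `forward()` routes to `forward_decode()` or `forward_extend()` depending on the batch's forward mode. A toy sketch of that dispatch with hypothetical stand-ins for the sglang types (the real `ForwardBatch` comes from `sglang.srt.model_executor.forward_batch_info`):

```python
from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class FakeForwardMode:
    # Hypothetical stand-in for sglang's ForwardMode.
    decode: bool

    def is_decode(self) -> bool:
        return self.decode


@dataclass
class FakeForwardBatch:
    # Hypothetical stand-in for sglang's ForwardBatch.
    forward_mode: FakeForwardMode


class EchoBackend:
    """Toy backend mirroring AttentionBackend's decode/extend dispatch."""

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                layer: nn.Module, forward_batch: FakeForwardBatch) -> torch.Tensor:
        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(q, k, v, layer, forward_batch)
        return self.forward_extend(q, k, v, layer, forward_batch)

    def forward_decode(self, q, k, v, layer, forward_batch):
        return torch.zeros_like(q)  # placeholder, not a real attention kernel

    def forward_extend(self, q, k, v, layer, forward_batch):
        return torch.zeros_like(q)  # placeholder, not a real attention kernel


q = k = v = torch.randn(4, 8)
batch = FakeForwardBatch(FakeForwardMode(decode=True))
print(EchoBackend().forward(q, k, v, nn.Identity(), batch).shape)
```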
sglang/srt/layers/attention/double_sparsity_backend.py
ADDED
@@ -0,0 +1,281 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+import torch.nn as nn
+
+from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+
+class DoubleSparseAttnBackend(AttentionBackend):
+    def __init__(self, model_runner: ModelRunner):
+        # Lazy import to avoid the initialization of cuda context
+        from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
+            flash_decode_attention_fwd,
+            flash_decode_sparse_attention_fwd,
+        )
+        from sglang.srt.layers.attention.triton_ops.extend_attention import (
+            extend_attention_fwd,
+        )
+
+        super().__init__()
+
+        self.decode_attention_fwd = flash_decode_attention_fwd
+        self.decode_sparse_attention_fwd = flash_decode_sparse_attention_fwd
+        self.extend_attention_fwd = extend_attention_fwd
+        self.num_head = model_runner.model_config.num_attention_heads
+        self.head_dim = model_runner.model_config.hidden_size // self.num_head
+        self.heavy_token_num = model_runner.server_args.ds_heavy_token_num
+
+        self.sorted_channels = model_runner.sorted_channels
+        self.sparse_decode_thresold = (
+            model_runner.server_args.ds_sparse_decode_threshold
+        )
+        self.att_out_approx: torch.Tensor = None
+        self.mid_out: torch.Tensor = None
+        self.mid_o_logexpsum: torch.Tensor = None
+
+        # TODO: Change the hard-coded block_seq_num
+        self.BLOCK_SEQ = 128
+
+        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
+            self.reduce_dtype = torch.float32
+        else:
+            self.reduce_dtype = torch.float16
+
+        self.forward_metadata = None
+
+        self.cuda_graph_max_seq_len = model_runner.model_config.context_len
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Init auxiliary variables for triton attention backend."""
+
+        if forward_batch.forward_mode.is_decode():
+            start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
+            start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
+
+            total_num_tokens = torch.sum(forward_batch.seq_lens).item()
+            attn_logits = torch.empty(
+                (self.num_head, total_num_tokens),
+                dtype=self.reduce_dtype,
+                device="cuda",
+            )
+
+            max_seq_len = torch.max(forward_batch.seq_lens).item()
+            min_seq_len = torch.min(forward_batch.seq_lens).item()
+            max_extend_len = None
+            # NOTE: Align sequence order with req_to_token order
+            ds_req_to_token = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices
+            ]
+
+            bsz = forward_batch.seq_lens.shape[0]
+
+            att_out_approx = torch.empty(
+                [self.num_head, bsz, max_seq_len],
+                dtype=self.reduce_dtype,
+                device="cuda",
+            )
+
+            block_seq_num = (
+                self.heavy_token_num + self.BLOCK_SEQ - 1
+            ) // self.BLOCK_SEQ
+
+            mid_out = torch.empty(
+                [bsz, self.num_head, block_seq_num, self.head_dim],
+                dtype=torch.float32,
+                device="cuda",
+            )
+            mid_o_logexpsum = torch.empty(
+                [bsz, self.num_head, block_seq_num], dtype=torch.float32, device="cuda"
+            )
+            self.att_out_approx = att_out_approx
+            self.mid_out = mid_out
+            self.mid_o_logexpsum = mid_o_logexpsum
+
+        else:
+            start_loc = attn_logits = max_seq_len = min_seq_len = None
+            prefix_lens = forward_batch.extend_prefix_lens
+            max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()
+            ds_req_to_token = None
+
+        self.forward_metadata = (
+            start_loc,
+            attn_logits,
+            max_seq_len,
+            min_seq_len,
+            max_extend_len,
+            ds_req_to_token,
+        )
+
+    def init_cuda_graph_state(self, max_bs: int):
+        # TODO(Andy): Support CUDA graph for double sparse attention
+        raise ValueError(
+            "Double sparse attention does not support CUDA graph for now. Please --disable-cuda-graph"
+        )
+        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
+
+        self.cuda_graph_start_loc = torch.zeros(
+            (max_bs,), dtype=torch.int32, device="cuda"
+        )
+        self.cuda_graph_attn_logits = torch.empty(
+            (
+                self.num_head,
+                self.cuda_graph_max_total_num_tokens,
+            ),
+            dtype=self.reduce_dtype,
+            device="cuda",
+        )
+
+    def init_forward_metadata_capture_cuda_graph(
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
+    ):
+        self.forward_metadata = (
+            self.cuda_graph_start_loc,
+            self.cuda_graph_attn_logits,
+            self.cuda_graph_max_seq_len,
+            None,
+        )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
+    ):
+        self.cuda_graph_start_loc.zero_()
+        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
+    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+        # TODO: reuse the buffer across layers
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        k_label = torch.gather(
+            k,
+            2,
+            self.sorted_channels[layer.layer_id]
+            .unsqueeze(0)
+            .expand(k.shape[0], -1, -1),
+        )
+
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
+        )
+
+        (
+            start_loc,
+            attn_logits,
+            max_seq_len,
+            min_seq_len,
+            max_extend_len,
+            ds_req_to_token,
+        ) = self.forward_metadata
+        self.extend_attention_fwd(
+            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+            k.contiguous(),
+            v.contiguous(),
+            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            forward_batch.extend_seq_lens,
+            forward_batch.extend_start_loc,
+            max_extend_len,
+            layer.scaling,
+            layer.logit_cap,
+        )
+        return o
+
+    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+        # During torch.compile, there is a bug in rotary_emb that causes the
+        # output value to have a 3D tensor shape. This reshapes the output correctly.
+        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+
+        # TODO: reuse the buffer across layers
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        # TODO: Add min seqlen
+        (
+            start_loc,
+            attn_logits,
+            max_seq_len,
+            min_seq_len,
+            max_extend_len,
+            ds_req_to_token,
+        ) = self.forward_metadata
+
+        k_label = torch.gather(
+            k,
+            2,
+            self.sorted_channels[layer.layer_id]
+            .unsqueeze(0)
+            .expand(k.shape[0], -1, -1),
+        )
+
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
+        )
+
+        # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
+        # and set a minimum value for sparse_decode
+        if (
+            min_seq_len < self.heavy_token_num
+            or max_seq_len < self.sparse_decode_thresold
+        ):
+            self.decode_attention_fwd(
+                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+                forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+                o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+                forward_batch.req_to_token_pool.req_to_token,
+                forward_batch.req_pool_indices,
+                start_loc,
+                forward_batch.seq_lens,
+                attn_logits,
+                max_seq_len,
+                layer.scaling,
+                layer.logit_cap,
+            )
+        else:
+            # TODO(Andy): indexing with torch.gather or torch.index_select or customized kernel
+            q_label = torch.gather(
+                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                2,
+                self.sorted_channels[layer.layer_id]
+                .unsqueeze(0)
+                .expand(q.shape[0], -1, -1),
+            )
+            self.decode_sparse_attention_fwd(
+                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+                forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+                o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                q_label,
+                forward_batch.token_to_kv_pool.get_label_buffer(layer.layer_id),
+                ds_req_to_token,
+                forward_batch.seq_lens,
+                max_seq_len,
+                layer.scaling,
+                layer.logit_cap,
+                self.heavy_token_num,
+                self.att_out_approx,
+                self.mid_out,
+                self.mid_o_logexpsum,
+                self.BLOCK_SEQ,
+            )
+
+        return o
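The sparse decode path above relies on a per-layer channel ranking: `torch.gather` along the head dimension extracts the "label" channels that approximate attention scores and select heavy tokens. A small self-contained illustration of that gather (shapes and the ranking heuristic here are hypothetical; in the backend, `sorted_channels[layer_id]` comes from offline channel statistics loaded by the model runner):

```python
import torch

num_tokens, num_heads, head_dim, num_label_channels = 4, 2, 8, 3

# One new key vector per token, per head.
k = torch.randn(num_tokens, num_heads, head_dim)

# Hypothetical per-head ranking of the most informative channels.
sorted_channels = torch.argsort(k.abs().mean(dim=0), dim=-1, descending=True)
top_channels = sorted_channels[:, :num_label_channels]  # (num_heads, C)

# Same indexing pattern as the backend: expand over tokens, gather on dim 2.
k_label = torch.gather(k, 2, top_channels.unsqueeze(0).expand(num_tokens, -1, -1))
print(k_label.shape)  # torch.Size([4, 2, 3])
```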