sglang 0.2.9__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +114 -63
- sglang/check_env.py +2 -0
- sglang/lang/backend/runtime_endpoint.py +0 -11
- sglang/srt/hf_transformers_utils.py +2 -2
- sglang/srt/layers/extend_attention.py +59 -7
- sglang/srt/layers/radix_attention.py +22 -9
- sglang/srt/layers/token_attention.py +28 -2
- sglang/srt/managers/io_struct.py +9 -4
- sglang/srt/managers/schedule_batch.py +15 -11
- sglang/srt/managers/tokenizer_manager.py +28 -13
- sglang/srt/mem_cache/memory_pool.py +65 -24
- sglang/srt/model_config.py +11 -0
- sglang/srt/model_executor/model_runner.py +52 -21
- sglang/srt/models/deepseek_v2.py +198 -16
- sglang/srt/openai_api/adapter.py +120 -20
- sglang/srt/openai_api/protocol.py +1 -1
- sglang/srt/server.py +87 -78
- sglang/srt/server_args.py +8 -2
- sglang/srt/utils.py +25 -20
- sglang/test/run_eval.py +21 -10
- sglang/test/runners.py +237 -0
- sglang/test/simple_eval_common.py +12 -12
- sglang/test/simple_eval_gpqa.py +92 -0
- sglang/test/simple_eval_humaneval.py +5 -5
- sglang/test/simple_eval_math.py +72 -0
- sglang/test/test_utils.py +94 -13
- sglang/utils.py +15 -37
- sglang/version.py +1 -1
- {sglang-0.2.9.dist-info → sglang-0.2.10.dist-info}/METADATA +29 -27
- {sglang-0.2.9.dist-info → sglang-0.2.10.dist-info}/RECORD +33 -30
- {sglang-0.2.9.dist-info → sglang-0.2.10.dist-info}/LICENSE +0 -0
- {sglang-0.2.9.dist-info → sglang-0.2.10.dist-info}/WHEEL +0 -0
- {sglang-0.2.9.dist-info → sglang-0.2.10.dist-info}/top_level.txt +0 -0
sglang/bench_latency.py
CHANGED
@@ -1,13 +1,13 @@
 """
 Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
 
-# Usage (latency test):
+# Usage (latency test) with dummy weights:
 python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
 
 # Usage (correctness test):
 python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
 
-### Reference output:
+### Reference output (of the correctness test above, can be gpu dependent):
 prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
 [-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
 [ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]],
@@ -31,7 +31,9 @@ import dataclasses
 import logging
 import multiprocessing
 import time
+from typing import Tuple
 
+import jsonlines
 import numpy as np
 import torch
 import torch.distributed as dist
@@ -47,25 +49,34 @@ from sglang.srt.utils import suppress_other_loggers
 
 @dataclasses.dataclass
 class BenchArgs:
-    batch_size: int = 1
+    batch_size: Tuple[int] = (1,)
     input_len: int = 1024
     output_len: int = 4
+    result_filename: str = ""
     correctness_test: bool = False
     # This is only used for correctness test
     cut_len: int = 4
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
-        parser.add_argument("--batch-size", type=int, default=BenchArgs.batch_size)
+        parser.add_argument(
+            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
+        )
         parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
         parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
+        parser.add_argument(
+            "--result-filename", type=str, default=BenchArgs.result_filename
+        )
        parser.add_argument("--correctness-test", action="store_true")
        parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
-        attrs = [attr.name for attr in dataclasses.fields(cls)]
-        return cls(**{attr: getattr(args, attr) for attr in attrs})
+        # use the default value's type to case the args into correct types.
+        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
+        return cls(
+            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
+        )
 
 
 def load_model(server_args, tp_rank):
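Note: the reworked `BenchArgs.from_cli_args` casts each argparse value back into the type of the field's default, which is how the `nargs="+"` list from `--batch-size` becomes a tuple. A standalone sketch of that behavior (trimmed to two fields; the values are illustrative, not from the package):

```python
import argparse
import dataclasses
from typing import Tuple


@dataclasses.dataclass
class BenchArgs:
    batch_size: Tuple[int] = (1,)
    input_len: int = 1024

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        # type(attr.default) is tuple for batch_size and int for input_len,
        # so each argparse value is cast back into its declared field type.
        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
        return cls(**{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs})


parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, nargs="+", default=BenchArgs.batch_size)
parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
args = parser.parse_args(["--batch-size", "1", "8", "32"])
print(BenchArgs.from_cli_args(args))  # BenchArgs(batch_size=(1, 8, 32), input_len=1024)
```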
@@ -93,7 +104,7 @@ def load_model(server_args, tp_rank):
     return model_runner, tokenizer
 
 
-def prepare_inputs(bench_args, tokenizer):
+def prepare_inputs_for_correctness_test(bench_args, tokenizer):
     prompts = [
         "The capital of France is",
         "The capital of the United Kindom is",
@@ -119,7 +130,9 @@ def prepare_inputs(bench_args, tokenizer):
     return input_ids, reqs
 
 
-def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
+def prepare_extend_inputs_for_correctness_test(
+    bench_args, input_ids, reqs, model_runner
+):
     for i in range(len(reqs)):
         req = reqs[i]
         req.input_ids += input_ids[i][bench_args.cut_len :]
@@ -129,8 +142,8 @@ def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
     return reqs
 
 
-def
-    input_ids = np.ones((
+def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
+    input_ids = np.ones((batch_size, input_len), dtype=np.int32)
     sampling_params = SamplingParams(
         temperature=0,
         max_new_tokens=BenchArgs.output_len,
@@ -179,7 +192,7 @@ def correctness_test(
     model_runner, tokenizer = load_model(server_args, tp_rank)
 
     # Prepare inputs
-    input_ids, reqs = prepare_inputs(bench_args, tokenizer)
+    input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
 
     if bench_args.cut_len > 0:
         # Prefill
@@ -187,7 +200,9 @@ def correctness_test(
         rank_print("prefill logits (first half)", next_token_logits)
 
     # Prepare extend inputs
-    reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner)
+    reqs = prepare_extend_inputs_for_correctness_test(
+        bench_args, input_ids, reqs, model_runner
+    )
 
     # Extend
     next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
@@ -205,6 +220,68 @@ def correctness_test(
         rank_print(tokenizer.decode(output_ids[i]))
 
 
+@torch.inference_mode()
+def latency_test_run_once(
+    model_runner, rank_print, reqs, batch_size, input_len, output_len
+):
+
+    # Clear the pools.
+    model_runner.req_to_token_pool.clear()
+    model_runner.token_to_kv_pool.clear()
+
+    measurement_results = {
+        "run_name": "before",
+        "batch_size": batch_size,
+        "input_len": input_len,
+        "output_len": output_len,
+    }
+
+    tot_latency = 0
+
+    # Prefill
+    torch.cuda.synchronize()
+    tic = time.time()
+    next_token_ids, _, batch = extend(reqs, model_runner)
+    torch.cuda.synchronize()
+    prefill_latency = time.time() - tic
+    tot_latency += prefill_latency
+    throughput = input_len * batch_size / prefill_latency
+    rank_print(
+        f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+    )
+    measurement_results["prefill_latency"] = prefill_latency
+    measurement_results["prefill_throughput"] = throughput
+
+    # Decode
+    for i in range(output_len):
+        torch.cuda.synchronize()
+        tic = time.time()
+        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+        torch.cuda.synchronize()
+        latency = time.time() - tic
+        tot_latency += latency
+        throughput = batch_size / latency
+        if i < 5:
+            rank_print(
+                f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+            )
+    avg_decode_latency = (tot_latency - prefill_latency) / output_len
+    avg_decode_throughput = batch_size / avg_decode_latency
+    rank_print(
+        f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
+    )
+    measurement_results["avg_decode_latency"] = avg_decode_latency
+    measurement_results["avg_decode_throughput"] = avg_decode_throughput
+
+    throughput = (input_len + output_len) * batch_size / tot_latency
+    rank_print(
+        f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
+    )
+    measurement_results["total_latency"] = tot_latency
+    measurement_results["total_throughput"] = throughput
+    return measurement_results
+
+
 def latency_test(
     server_args,
     bench_args,
@@ -218,62 +295,36 @@ def latency_test(
         f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
     )
 
-    #
-
-
-    def clear():
-        model_runner.req_to_token_pool.clear()
-        model_runner.token_to_kv_pool.clear()
-
-    @torch.inference_mode()
-    def run_once(output_len):
-        # Prefill
-        torch.cuda.synchronize()
-        tot_latency = 0
-        tic = time.time()
-        next_token_ids, _, batch = extend(reqs, model_runner)
-        torch.cuda.synchronize()
-        prefill_latency = time.time() - tic
-        tot_latency += prefill_latency
-        throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
-        rank_print(
-            f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
-        )
+    # To make this PR easier to review, for now, only do the first element in batch_size tuple.
+    bench_args.batch_size = bench_args.batch_size[0]
 
-
-
-
-
-            next_token_ids, _ = decode(next_token_ids, batch, model_runner)
-            torch.cuda.synchronize()
-            latency = time.time() - tic
-            tot_latency += latency
-            throughput = bench_args.batch_size / latency
-            if i < 5:
-                rank_print(
-                    f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
-                )
-        avg_decode_latency = (tot_latency - prefill_latency) / output_len
-        avg_decode_throughput = bench_args.batch_size / avg_decode_latency
-        rank_print(
-            f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
-        )
-
-        throughput = (
-            (bench_args.input_len + bench_args.output_len)
-            * bench_args.batch_size
-            / tot_latency
-        )
-        rank_print(
-            f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
-        )
+    # Prepare inputs
+    reqs = prepare_synthetic_inputs_for_latency_test(
+        bench_args.batch_size, bench_args.input_len
+    )
 
     # Warm up
-
-
+    latency_test_run_once(
+        model_runner, rank_print, reqs, bench_args.batch_size, bench_args.input_len, 4
+    )
 
     # Run again
-
+    result_list = []
+    result_list.append(
+        latency_test_run_once(
+            model_runner,
+            rank_print,
+            reqs,
+            bench_args.batch_size,
+            bench_args.input_len,
+            bench_args.output_len,
+        )
+    )
+
+    # Write results in jsonlines format.
+    if bench_args.result_filename:
+        with jsonlines.open(bench_args.result_filename, "a") as f:
+            f.write_all(result_list)
 
 
 def main(server_args, bench_args):
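Note: with the new `--result-filename` flag, each benchmark run appends one JSON object per line via `jsonlines`. A minimal sketch of reading those records back (the filename below is a placeholder, not from the package):

```python
import jsonlines

# Keys match what latency_test_run_once stores in measurement_results.
with jsonlines.open("results.jsonl") as reader:  # placeholder filename
    for record in reader:
        print(
            f"{record['run_name']}: batch_size={record['batch_size']}, "
            f"prefill={record['prefill_latency']:.5f}s, "
            f"avg_decode={record['avg_decode_latency']:.5f}s, "
            f"total_throughput={record['total_throughput']:.2f} token/s"
        )
```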
sglang/check_env.py
CHANGED
@@ -13,6 +13,7 @@ import torch
 PACKAGE_LIST = [
     "sglang",
     "flashinfer",
+    "triton",
     "requests",
     "tqdm",
     "numpy",
@@ -30,6 +31,7 @@ PACKAGE_LIST = [
     "zmq",
     "vllm",
     "outlines",
+    "multipart",
     "openai",
     "tiktoken",
     "anthropic",
sglang/lang/backend/runtime_endpoint.py
CHANGED
@@ -15,7 +15,6 @@ class RuntimeEndpoint(BaseBackend):
     def __init__(
         self,
         base_url: str,
-        auth_token: Optional[str] = None,
         api_key: Optional[str] = None,
         verify: Optional[str] = None,
     ):
@@ -23,13 +22,11 @@ class RuntimeEndpoint(BaseBackend):
         self.support_concate_and_append = True
 
         self.base_url = base_url
-        self.auth_token = auth_token
         self.api_key = api_key
         self.verify = verify
 
         res = http_request(
             self.base_url + "/get_model_info",
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -67,7 +64,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -79,7 +75,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -91,7 +86,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -139,7 +133,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -193,7 +186,6 @@ class RuntimeEndpoint(BaseBackend):
             self.base_url + "/generate",
             json=data,
             stream=True,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -225,7 +217,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -243,7 +234,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -267,7 +257,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/concate_and_append_request",
             json={"src_rids": src_rids, "dst_rid": dst_rid},
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
sglang/srt/hf_transformers_utils.py
CHANGED
@@ -19,7 +19,7 @@ import functools
 import json
 import os
 import warnings
-from typing import AbstractSet, Collection, Dict, Literal, Optional, Type, Union
+from typing import AbstractSet, Collection, Dict, List, Literal, Optional, Type, Union
 
 from huggingface_hub import snapshot_download
 from transformers import (
@@ -259,7 +259,7 @@ class TiktokenTokenizer:
             Literal["all"], AbstractSet[str]
         ] = set(),  # noqa: B006
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
-    ) ->
+    ) -> List[int]:
        if isinstance(allowed_special, set):
            allowed_special |= self._default_allowed_special
        return tiktoken.Encoding.encode(
sglang/srt/layers/extend_attention.py
CHANGED
@@ -57,6 +57,8 @@ def _fwd_kernel(
     stride_buf_vh,
     stride_req_to_tokens_b,
     BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
+    BLOCK_DV: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
@@ -75,8 +77,10 @@ def _fwd_kernel(
     cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
 
     offs_d = tl.arange(0, BLOCK_DMODEL)
+    offs_dv = tl.arange(0, BLOCK_DV)
     offs_m = tl.arange(0, BLOCK_M)
     mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
+
     offs_q = (
         (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_qbs
@@ -85,10 +89,20 @@ def _fwd_kernel(
     )
     q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0)
 
+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        offs_qpe = (
+            (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+            * stride_qbs
+            + cur_head * stride_qh
+            + offs_dpe[None, :]
+        )
+        qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
+
     # stage1: compute scores with prefix
     offs_n = tl.arange(0, BLOCK_N)
 
-    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
     deno = tl.zeros([BLOCK_M], dtype=tl.float32)
     e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
 
@@ -110,6 +124,18 @@ def _fwd_kernel(
 
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                offs_kv_loc[None, :] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe, kpe)
         qk *= sm_scale
 
         if logit_cap > 0:
@@ -125,7 +151,7 @@ def _fwd_kernel(
         offs_buf_v = (
             offs_kv_loc[:, None] * stride_buf_vbs
             + cur_kv_head * stride_buf_vh
-            + offs_d[None, :]
+            + offs_dv[None, :]
         )
         v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0)
         p = p.to(v.dtype)
@@ -150,6 +176,21 @@ def _fwd_kernel(
 
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
+
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
+                * stride_kbs
+                + cur_kv_head * stride_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Extend + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe, kpe)
+
         qk *= sm_scale
 
         if logit_cap > 0:
@@ -169,7 +210,7 @@ def _fwd_kernel(
         offs_v = (
             (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
             + cur_kv_head * stride_vh
-            + offs_d[None, :]
+            + offs_dv[None, :]
         )
         v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0)
         p = p.to(v.dtype)
@@ -181,7 +222,7 @@ def _fwd_kernel(
         (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_obs
         + cur_head * stride_oh
-        + offs_d[None, :]
+        + offs_dv[None, :]
     )
     tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
 
@@ -217,8 +258,17 @@ def extend_attention_fwd(
         o_extend.shape[-1],
     )
 
-    assert Lq == Lk and
-    assert Lq in {16, 32, 64, 128, 256}
+    assert Lq == Lk and Lv == Lo
+    assert Lq in {16, 32, 64, 128, 256, 576}
+    assert Lv in {16, 32, 64, 128, 256, 512}
+
+    if Lq == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = Lq
+        BLOCK_DPE = 0
+    BLOCK_DV = Lv
 
     if CUDA_CAPABILITY[0] >= 8:
         BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
@@ -260,7 +310,9 @@ def extend_attention_fwd(
         v_buffer.stride(0),
         v_buffer.stride(1),
         req_to_tokens.stride(0),
-        BLOCK_DMODEL=Lq,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_DV=BLOCK_DV,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
         num_warps=num_warps,
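Note: the new `BLOCK_DPE` path handles query/key heads of width 576, which matches the MLA attention in the deepseek_v2.py changes listed above (a 512-dim no-position slice plus a 64-dim rotary slice, as the 512/64 split suggests). The kernel loads the two slices separately and sums their partial dot products; a small NumPy sketch of that equivalence (the shapes are illustrative assumptions):

```python
import numpy as np

BLOCK_DMODEL, BLOCK_DPE = 512, 64  # the split chosen for Lq == 576
q = np.random.randn(576).astype(np.float32)
k = np.random.randn(576).astype(np.float32)

# One full dot product equals the sum of the two partial dot products
# (qk += tl.dot(q, k); qk += tl.dot(qpe, kpe) in the kernel).
full = q @ k
split = q[:BLOCK_DMODEL] @ k[:BLOCK_DMODEL] + q[BLOCK_DMODEL:] @ k[BLOCK_DMODEL:]
assert np.allclose(full, split, rtol=1e-5)
```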
sglang/srt/layers/radix_attention.py
CHANGED
@@ -38,16 +38,22 @@ class RadixAttention(nn.Module):
         num_kv_heads: int,
         layer_id: int,
         logit_cap: int = -1,
+        v_head_dim: int = -1,
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
+        self.qk_head_dim = head_dim
+        self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
 
-        if not global_server_args_dict.get("disable_flashinfer", False):
+        if (
+            not global_server_args_dict.get("disable_flashinfer", False)
+            and self.qk_head_dim == self.v_head_dim
+        ):
             self.extend_forward = self.extend_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
         else:
@@ -57,13 +63,17 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
 
     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        o = torch.empty_like(q)
+        if self.qk_head_dim != self.v_head_dim:
+            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
         self.store_kv_cache(k, v, input_metadata)
         extend_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.head_dim),
+            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
             k.contiguous(),
             v.contiguous(),
-            o.view(-1, self.tp_q_head_num, self.head_dim),
+            o.view(-1, self.tp_q_head_num, self.v_head_dim),
             input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
             input_metadata.req_to_token_pool.req_to_token,
@@ -82,14 +92,17 @@ class RadixAttention(nn.Module):
         return o
 
     def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        o = torch.empty_like(q)
+        if self.qk_head_dim != self.v_head_dim:
+            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
+        else:
+            o = torch.empty_like(q)
         self.store_kv_cache(k, v, input_metadata)
 
         token_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.head_dim),
+            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
             input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
-            o.view(-1, self.tp_q_head_num, self.head_dim),
+            o.view(-1, self.tp_q_head_num, self.v_head_dim),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
             input_metadata.triton_start_loc,
@@ -160,8 +173,8 @@ class RadixAttention(nn.Module):
         return o.view(-1, self.tp_q_head_num * self.head_dim)
 
     def forward(self, q, k, v, input_metadata: InputMetadata):
-        k = k.view(-1, self.tp_k_head_num, self.head_dim)
-        v = v.view(-1, self.tp_v_head_num, self.head_dim)
+        k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+        v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
 
         if input_metadata.forward_mode == ForwardMode.EXTEND:
             return self.extend_forward(q, k, v, input_metadata)
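Note: when `v_head_dim` differs from `qk_head_dim`, the Triton path can no longer reuse q's shape for the output, hence the `q.new_empty(...)` allocation above. A toy check of the shape bookkeeping (the sizes are illustrative assumptions):

```python
import torch

tokens, num_heads = 8, 16
qk_head_dim, v_head_dim = 576, 512  # asymmetric head dims, e.g. MLA-style

q = torch.randn(tokens, num_heads * qk_head_dim)
if qk_head_dim != v_head_dim:
    o = q.new_empty((q.shape[0], num_heads * v_head_dim))
else:
    o = torch.empty_like(q)

print(q.view(-1, num_heads, qk_head_dim).shape)  # torch.Size([8, 16, 576])
print(o.view(-1, num_heads, v_head_dim).shape)   # torch.Size([8, 16, 512])
```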
sglang/srt/layers/token_attention.py
CHANGED
@@ -54,6 +54,7 @@ def _fwd_kernel_stage1(
     att_stride_h,
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
 ):
@@ -73,6 +74,10 @@ def _fwd_kernel_stage1(
 
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
 
+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        off_qpe = cur_batch * stride_qbs + cur_head * stride_qh + offs_dpe
+
     offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
 
     block_stard_index = start_n * BLOCK_N
@@ -97,6 +102,19 @@ def _fwd_kernel_stage1(
             other=0.0,
         ).to(REDUCE_TRITON_TYPE)
         att_value = tl.sum(q[None, :] * k, 1)
+        if BLOCK_DPE > 0:
+            qpe = tl.load(Q + off_qpe + start_mark).to(REDUCE_TRITON_TYPE)
+            offs_buf_kpe = (
+                k_loc[:, None] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[None, :]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_buf_kpe,
+                mask=offs_n_new[:, None] < cur_batch_end_index,
+                other=0.0,
+            ).to(REDUCE_TRITON_TYPE)
+            att_value += tl.sum(qpe[None, :] * kpe, 1)
         att_value *= sm_scale
 
         if logit_cap > 0:
@@ -192,7 +210,14 @@ def _token_att_m_fwd(
     # shape constraints
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
     assert Lq == Lk
-    assert Lk in {16, 32, 64, 128, 256}
+    assert Lk in {16, 32, 64, 128, 256, 576}
+
+    if Lk == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = Lk
+        BLOCK_DPE = 0
 
     batch, head_num = B_req_idx.shape[0], q.shape[1]
 
@@ -220,7 +245,8 @@ def _token_att_m_fwd(
         k_buffer.stride(1),
         att_out.stride(0),
         kv_group_num=kv_group_num,
-        BLOCK_DMODEL=Lk,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
         BLOCK_N=BLOCK,
         logit_cap=logit_cap,
         num_warps=num_warps,