sglang 0.2.9.post1__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sglang/bench_latency.py CHANGED
@@ -1,13 +1,13 @@
 """
 Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.

-# Usage (latency test):
+# Usage (latency test) with dummy weights:
 python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy

 # Usage (correctness test):
 python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct

-### Reference output:
+### Reference output (of the correctness test above, can be gpu dependent):
 prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
 [-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
 [ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]],
@@ -31,7 +31,9 @@ import dataclasses
 import logging
 import multiprocessing
 import time
+from typing import Tuple

+import jsonlines
 import numpy as np
 import torch
 import torch.distributed as dist
@@ -47,25 +49,34 @@ from sglang.srt.utils import suppress_other_loggers

 @dataclasses.dataclass
 class BenchArgs:
-    batch_size: int = 1
+    batch_size: Tuple[int] = (1,)
     input_len: int = 1024
     output_len: int = 4
+    result_filename: str = ""
     correctness_test: bool = False
     # This is only used for correctness test
     cut_len: int = 4

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
-        parser.add_argument("--batch-size", type=int, default=BenchArgs.batch_size)
+        parser.add_argument(
+            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
+        )
         parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
         parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
+        parser.add_argument(
+            "--result-filename", type=str, default=BenchArgs.result_filename
+        )
         parser.add_argument("--correctness-test", action="store_true")
         parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
-        attrs = [attr.name for attr in dataclasses.fields(cls)]
-        return cls(**{attr: getattr(args, attr) for attr in attrs})
+        # use the default value's type to case the args into correct types.
+        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
+        return cls(
+            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
+        )


 def load_model(server_args, tp_rank):
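Note: per the hunk above, `--batch-size` now accepts one or more values (`nargs="+"`) and `from_cli_args` casts each parsed argument to the type of the corresponding dataclass default. A minimal standalone sketch of that casting pattern follows; `DemoArgs` is an illustrative name, not part of sglang.

```python
import argparse
import dataclasses
from typing import Tuple


@dataclasses.dataclass
class DemoArgs:
    # Mirrors BenchArgs: the default's type (tuple) is reused to cast the parsed list.
    batch_size: Tuple[int] = (1,)
    result_filename: str = ""

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        # Cast each argparse value to the type of the field's default value.
        fields = [(f.name, type(f.default)) for f in dataclasses.fields(cls)]
        return cls(**{name: typ(getattr(args, name)) for name, typ in fields})


parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, nargs="+", default=DemoArgs.batch_size)
parser.add_argument("--result-filename", type=str, default=DemoArgs.result_filename)
args = parser.parse_args(["--batch-size", "1", "8", "16"])
print(DemoArgs.from_cli_args(args))  # DemoArgs(batch_size=(1, 8, 16), result_filename='')
```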
@@ -93,7 +104,7 @@ def load_model(server_args, tp_rank):
     return model_runner, tokenizer


-def prepare_inputs(bench_args, tokenizer):
+def prepare_inputs_for_correctness_test(bench_args, tokenizer):
     prompts = [
         "The capital of France is",
         "The capital of the United Kindom is",
@@ -119,7 +130,9 @@ def prepare_inputs(bench_args, tokenizer):
     return input_ids, reqs


-def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
+def prepare_extend_inputs_for_correctness_test(
+    bench_args, input_ids, reqs, model_runner
+):
     for i in range(len(reqs)):
         req = reqs[i]
         req.input_ids += input_ids[i][bench_args.cut_len :]
@@ -129,8 +142,8 @@ def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
     return reqs


-def prepare_synthetic_inputs(bench_args, tokenizer):
-    input_ids = np.ones((bench_args.batch_size, bench_args.input_len), dtype=np.int32)
+def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
+    input_ids = np.ones((batch_size, input_len), dtype=np.int32)
     sampling_params = SamplingParams(
         temperature=0,
         max_new_tokens=BenchArgs.output_len,
@@ -179,7 +192,7 @@ def correctness_test(
     model_runner, tokenizer = load_model(server_args, tp_rank)

     # Prepare inputs
-    input_ids, reqs = prepare_inputs(bench_args, tokenizer)
+    input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)

     if bench_args.cut_len > 0:
         # Prefill
@@ -187,7 +200,9 @@ def correctness_test(
         rank_print("prefill logits (first half)", next_token_logits)

         # Prepare extend inputs
-        reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner)
+        reqs = prepare_extend_inputs_for_correctness_test(
+            bench_args, input_ids, reqs, model_runner
+        )

     # Extend
     next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
@@ -205,6 +220,68 @@ def correctness_test(
         rank_print(tokenizer.decode(output_ids[i]))


+@torch.inference_mode()
+def latency_test_run_once(
+    model_runner, rank_print, reqs, batch_size, input_len, output_len
+):
+
+    # Clear the pools.
+    model_runner.req_to_token_pool.clear()
+    model_runner.token_to_kv_pool.clear()
+
+    measurement_results = {
+        "run_name": "before",
+        "batch_size": batch_size,
+        "input_len": input_len,
+        "output_len": output_len,
+    }
+
+    tot_latency = 0
+
+    # Prefill
+    torch.cuda.synchronize()
+    tic = time.time()
+    next_token_ids, _, batch = extend(reqs, model_runner)
+    torch.cuda.synchronize()
+    prefill_latency = time.time() - tic
+    tot_latency += prefill_latency
+    throughput = input_len * batch_size / prefill_latency
+    rank_print(
+        f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+    )
+    measurement_results["prefill_latency"] = prefill_latency
+    measurement_results["prefill_throughput"] = throughput
+
+    # Decode
+    for i in range(output_len):
+        torch.cuda.synchronize()
+        tic = time.time()
+        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+        torch.cuda.synchronize()
+        latency = time.time() - tic
+        tot_latency += latency
+        throughput = batch_size / latency
+        if i < 5:
+            rank_print(
+                f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+            )
+    avg_decode_latency = (tot_latency - prefill_latency) / output_len
+    avg_decode_throughput = batch_size / avg_decode_latency
+    rank_print(
+        f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
+    )
+    measurement_results["avg_decode_latency"] = avg_decode_latency
+    measurement_results["avg_decode_throughput"] = avg_decode_throughput
+
+    throughput = (input_len + output_len) * batch_size / tot_latency
+    rank_print(
+        f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
+    )
+    measurement_results["total_latency"] = tot_latency
+    measurement_results["total_throughput"] = throughput
+    return measurement_results
+
+
 def latency_test(
     server_args,
     bench_args,
@@ -218,62 +295,36 @@ def latency_test(
         f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
     )

-    # Prepare inputs
-    reqs = prepare_synthetic_inputs(bench_args, tokenizer)
-
-    def clear():
-        model_runner.req_to_token_pool.clear()
-        model_runner.token_to_kv_pool.clear()
-
-    @torch.inference_mode()
-    def run_once(output_len):
-        # Prefill
-        torch.cuda.synchronize()
-        tot_latency = 0
-        tic = time.time()
-        next_token_ids, _, batch = extend(reqs, model_runner)
-        torch.cuda.synchronize()
-        prefill_latency = time.time() - tic
-        tot_latency += prefill_latency
-        throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
-        rank_print(
-            f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
-        )
+    # To make this PR easier to review, for now, only do the first element in batch_size tuple.
+    bench_args.batch_size = bench_args.batch_size[0]

-        # Decode
-        for i in range(output_len):
-            torch.cuda.synchronize()
-            tic = time.time()
-            next_token_ids, _ = decode(next_token_ids, batch, model_runner)
-            torch.cuda.synchronize()
-            latency = time.time() - tic
-            tot_latency += latency
-            throughput = bench_args.batch_size / latency
-            if i < 5:
-                rank_print(
-                    f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
-                )
-        avg_decode_latency = (tot_latency - prefill_latency) / output_len
-        avg_decode_throughput = bench_args.batch_size / avg_decode_latency
-        rank_print(
-            f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
-        )
-
-        throughput = (
-            (bench_args.input_len + bench_args.output_len)
-            * bench_args.batch_size
-            / tot_latency
-        )
-        rank_print(
-            f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
-        )
+    # Prepare inputs
+    reqs = prepare_synthetic_inputs_for_latency_test(
+        bench_args.batch_size, bench_args.input_len
+    )

     # Warm up
-    run_once(4)
-    clear()
+    latency_test_run_once(
+        model_runner, rank_print, reqs, bench_args.batch_size, bench_args.input_len, 4
+    )

     # Run again
-    run_once(bench_args.output_len)
+    result_list = []
+    result_list.append(
+        latency_test_run_once(
+            model_runner,
+            rank_print,
+            reqs,
+            bench_args.batch_size,
+            bench_args.input_len,
+            bench_args.output_len,
+        )
+    )
+
+    # Write results in jsonlines format.
+    if bench_args.result_filename:
+        with jsonlines.open(bench_args.result_filename, "a") as f:
+            f.write_all(result_list)


 def main(server_args, bench_args):
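Note: per the hunks above, each benchmark run now appends one JSON object per configuration to `--result-filename`. A hedged sketch of reading those records back; the filename is an example and the keys follow the `measurement_results` dict in the diff.

```python
import jsonlines

# Read records appended by
# `python -m sglang.bench_latency ... --result-filename results.jsonl`.
with jsonlines.open("results.jsonl") as reader:
    for record in reader:
        print(
            f"bs={record['batch_size']} "
            f"prefill={record['prefill_latency']:.4f}s "
            f"decode_avg={record['avg_decode_latency']:.4f}s "
            f"total_tput={record['total_throughput']:.1f} tok/s"
        )
```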
sglang/check_env.py CHANGED
@@ -13,6 +13,7 @@ import torch
 PACKAGE_LIST = [
     "sglang",
     "flashinfer",
+    "triton",
     "requests",
     "tqdm",
     "numpy",
@@ -15,7 +15,6 @@ class RuntimeEndpoint(BaseBackend):
     def __init__(
         self,
         base_url: str,
-        auth_token: Optional[str] = None,
         api_key: Optional[str] = None,
         verify: Optional[str] = None,
     ):
@@ -23,13 +22,11 @@ class RuntimeEndpoint(BaseBackend):
         self.support_concate_and_append = True

         self.base_url = base_url
-        self.auth_token = auth_token
         self.api_key = api_key
         self.verify = verify

         res = http_request(
             self.base_url + "/get_model_info",
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -67,7 +64,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -79,7 +75,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -91,7 +86,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -139,7 +133,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -193,7 +186,6 @@ class RuntimeEndpoint(BaseBackend):
             self.base_url + "/generate",
             json=data,
             stream=True,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -225,7 +217,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -243,7 +234,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -267,7 +257,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/concate_and_append_request",
             json={"src_rids": src_rids, "dst_rid": dst_rid},
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
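Note: the hunks above remove the `auth_token` parameter and all of its plumbing from the `RuntimeEndpoint` backend, leaving `api_key` as the only credential passed to `http_request`. A hedged sketch of constructing the endpoint after this change; the URL and key are placeholders, and the top-level import path is an assumption.

```python
from sglang import RuntimeEndpoint  # assumed top-level export

# After 0.2.10, RuntimeEndpoint no longer accepts auth_token; pass api_key instead.
backend = RuntimeEndpoint("http://localhost:30000", api_key="my-api-key")
```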
@@ -19,7 +19,7 @@ import functools
 import json
 import os
 import warnings
-from typing import AbstractSet, Collection, Dict, Literal, Optional, Type, Union
+from typing import AbstractSet, Collection, Dict, List, Literal, Optional, Type, Union

 from huggingface_hub import snapshot_download
 from transformers import (
@@ -259,7 +259,7 @@ class TiktokenTokenizer:
                 Literal["all"], AbstractSet[str]
             ] = set(),  # noqa: B006
             disallowed_special: Union[Literal["all"], Collection[str]] = "all",
-        ) -> list[int]:
+        ) -> List[int]:
             if isinstance(allowed_special, set):
                 allowed_special |= self._default_allowed_special
             return tiktoken.Encoding.encode(
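Note: the return annotation changes from the built-in `list[int]` to `typing.List[int]`, presumably to keep the module importable on Python 3.8, where subscripting the built-in `list` in an eagerly evaluated annotation raises a TypeError. A tiny illustrative sketch (the `encode` function here is a toy, not sglang's):

```python
from typing import List


# `-> list[int]` only works at definition time on Python >= 3.9 (or with
# `from __future__ import annotations`); `List[int]` also works on 3.8.
def encode(text: str) -> List[int]:
    return [ord(ch) for ch in text]
```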
@@ -57,6 +57,8 @@ def _fwd_kernel(
     stride_buf_vh,
     stride_req_to_tokens_b,
     BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
+    BLOCK_DV: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
@@ -75,8 +77,10 @@ def _fwd_kernel(
     cur_batch_req_idx = tl.load(B_req_idx + cur_seq)

     offs_d = tl.arange(0, BLOCK_DMODEL)
+    offs_dv = tl.arange(0, BLOCK_DV)
     offs_m = tl.arange(0, BLOCK_M)
     mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
+
     offs_q = (
         (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_qbs
@@ -85,10 +89,20 @@ def _fwd_kernel(
     )
     q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0)

+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        offs_qpe = (
+            (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+            * stride_qbs
+            + cur_head * stride_qh
+            + offs_dpe[None, :]
+        )
+        qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
+
     # stage1: compute scores with prefix
     offs_n = tl.arange(0, BLOCK_N)

-    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
     deno = tl.zeros([BLOCK_M], dtype=tl.float32)
     e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")

@@ -110,6 +124,18 @@ def _fwd_kernel(

         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                offs_kv_loc[None, :] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe, kpe)
         qk *= sm_scale

         if logit_cap > 0:
@@ -125,7 +151,7 @@ def _fwd_kernel(
         offs_buf_v = (
             offs_kv_loc[:, None] * stride_buf_vbs
             + cur_kv_head * stride_buf_vh
-            + offs_d[None, :]
+            + offs_dv[None, :]
         )
         v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0)
         p = p.to(v.dtype)
@@ -150,6 +176,21 @@ def _fwd_kernel(

         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
+
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
+                * stride_kbs
+                + cur_kv_head * stride_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Extend + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe, kpe)
+
         qk *= sm_scale

         if logit_cap > 0:
@@ -169,7 +210,7 @@ def _fwd_kernel(
         offs_v = (
             (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
             + cur_kv_head * stride_vh
-            + offs_d[None, :]
+            + offs_dv[None, :]
         )
         v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0)
         p = p.to(v.dtype)
@@ -181,7 +222,7 @@ def _fwd_kernel(
         (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_obs
         + cur_head * stride_oh
-        + offs_d[None, :]
+        + offs_dv[None, :]
     )
     tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])

@@ -217,8 +258,17 @@ def extend_attention_fwd(
         o_extend.shape[-1],
     )

-    assert Lq == Lk and Lk == Lv and Lv == Lo
-    assert Lq in {16, 32, 64, 128, 256}
+    assert Lq == Lk and Lv == Lo
+    assert Lq in {16, 32, 64, 128, 256, 576}
+    assert Lv in {16, 32, 64, 128, 256, 512}
+
+    if Lq == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = Lq
+        BLOCK_DPE = 0
+    BLOCK_DV = Lv

     if CUDA_CAPABILITY[0] >= 8:
         BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
@@ -260,7 +310,9 @@ def extend_attention_fwd(
         v_buffer.stride(0),
         v_buffer.stride(1),
         req_to_tokens.stride(0),
-        BLOCK_DMODEL=Lq,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_DV=BLOCK_DV,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
         num_warps=num_warps,
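Note: per the hunks above, the extend-attention kernel now accepts a 576-wide q/k head by splitting it into a 512-wide `BLOCK_DMODEL` part and a 64-wide `BLOCK_DPE` positional-encoding part that gets its own dot product, while `BLOCK_DV` tracks the (possibly narrower) value width. A hedged sketch of just that dispatch logic, mirroring the asserts in the diff; the function name is illustrative.

```python
def pick_block_sizes(Lq: int, Lv: int):
    """Mirror the head-dim dispatch in extend_attention_fwd / _token_att_m_fwd.

    A 576-wide q/k head is treated as 512 "model" dims plus 64 trailing
    positional-encoding dims; the value head keeps its own width.
    """
    assert Lq in {16, 32, 64, 128, 256, 576}
    assert Lv in {16, 32, 64, 128, 256, 512}
    if Lq == 576:
        block_dmodel, block_dpe = 512, 64
    else:
        block_dmodel, block_dpe = Lq, 0
    return block_dmodel, block_dpe, Lv


print(pick_block_sizes(576, 512))  # (512, 64, 512)
print(pick_block_sizes(128, 128))  # (128, 0, 128)
```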
@@ -38,16 +38,22 @@ class RadixAttention(nn.Module):
         num_kv_heads: int,
         layer_id: int,
         logit_cap: int = -1,
+        v_head_dim: int = -1,
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
+        self.qk_head_dim = head_dim
+        self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id

-        if not global_server_args_dict.get("disable_flashinfer", False):
+        if (
+            not global_server_args_dict.get("disable_flashinfer", False)
+            and self.qk_head_dim == self.v_head_dim
+        ):
             self.extend_forward = self.extend_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
         else:
@@ -57,13 +63,17 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0

     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        o = torch.empty_like(q)
+        if self.qk_head_dim != self.v_head_dim:
+            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
         self.store_kv_cache(k, v, input_metadata)
         extend_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.head_dim),
+            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
             k.contiguous(),
             v.contiguous(),
-            o.view(-1, self.tp_q_head_num, self.head_dim),
+            o.view(-1, self.tp_q_head_num, self.v_head_dim),
             input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
             input_metadata.req_to_token_pool.req_to_token,
@@ -82,14 +92,17 @@ class RadixAttention(nn.Module):
         return o

     def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        o = torch.empty_like(q)
+        if self.qk_head_dim != self.v_head_dim:
+            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
+        else:
+            o = torch.empty_like(q)
         self.store_kv_cache(k, v, input_metadata)

         token_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.head_dim),
+            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
             input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
-            o.view(-1, self.tp_q_head_num, self.head_dim),
+            o.view(-1, self.tp_q_head_num, self.v_head_dim),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
             input_metadata.triton_start_loc,
@@ -160,8 +173,8 @@ class RadixAttention(nn.Module):
         return o.view(-1, self.tp_q_head_num * self.head_dim)

     def forward(self, q, k, v, input_metadata: InputMetadata):
-        k = k.view(-1, self.tp_k_head_num, self.head_dim)
-        v = v.view(-1, self.tp_v_head_num, self.head_dim)
+        k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+        v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

         if input_metadata.forward_mode == ForwardMode.EXTEND:
             return self.extend_forward(q, k, v, input_metadata)
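Note: per the hunks above, `RadixAttention` now carries a separate `v_head_dim`, only takes the flashinfer path when the q/k and value head widths match, and sizes the Triton output buffer by the value width. A hedged standalone sketch of just the output-allocation rule; the helper name is illustrative.

```python
import torch


def allocate_attention_output(q, num_heads, qk_head_dim, v_head_dim):
    # Mirrors extend_forward_triton / decode_forward_triton: when the value
    # head is narrower than the q/k head (e.g. 512 vs 576), the output buffer
    # uses the value width instead of q's shape.
    if qk_head_dim != v_head_dim:
        return q.new_empty((q.shape[0], num_heads * v_head_dim))
    return torch.empty_like(q)


q = torch.randn(4, 2 * 576)  # 4 tokens, 2 heads, 576-dim q/k heads
o = allocate_attention_output(q, num_heads=2, qk_head_dim=576, v_head_dim=512)
print(o.shape)  # torch.Size([4, 1024])
```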
@@ -54,6 +54,7 @@ def _fwd_kernel_stage1(
     att_stride_h,
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
 ):
@@ -73,6 +74,10 @@ def _fwd_kernel_stage1(

     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d

+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        off_qpe = cur_batch * stride_qbs + cur_head * stride_qh + offs_dpe
+
     offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)

     block_stard_index = start_n * BLOCK_N
@@ -97,6 +102,19 @@ def _fwd_kernel_stage1(
             other=0.0,
         ).to(REDUCE_TRITON_TYPE)
         att_value = tl.sum(q[None, :] * k, 1)
+        if BLOCK_DPE > 0:
+            qpe = tl.load(Q + off_qpe + start_mark).to(REDUCE_TRITON_TYPE)
+            offs_buf_kpe = (
+                k_loc[:, None] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[None, :]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_buf_kpe,
+                mask=offs_n_new[:, None] < cur_batch_end_index,
+                other=0.0,
+            ).to(REDUCE_TRITON_TYPE)
+            att_value += tl.sum(qpe[None, :] * kpe, 1)
         att_value *= sm_scale

         if logit_cap > 0:
@@ -192,7 +210,14 @@ def _token_att_m_fwd(
     # shape constraints
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
     assert Lq == Lk
-    assert Lk in {16, 32, 64, 128, 256}
+    assert Lk in {16, 32, 64, 128, 256, 576}
+
+    if Lk == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = Lk
+        BLOCK_DPE = 0

     batch, head_num = B_req_idx.shape[0], q.shape[1]

@@ -220,7 +245,8 @@ def _token_att_m_fwd(
         k_buffer.stride(1),
         att_out.stride(0),
         kv_group_num=kv_group_num,
-        BLOCK_DMODEL=Lk,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
         BLOCK_N=BLOCK,
         logit_cap=logit_cap,
         num_warps=num_warps,
@@ -92,7 +92,7 @@ class GenerateReqInput:
                 for element in parallel_sample_num_list
             )
             if parallel_sample_num > 1 and (not all_equal):
-                ## TODO cope with the case that the parallel_sample_num is different for different samples
+                # TODO cope with the case that the parallel_sample_num is different for different samples
                 raise ValueError(
                     "The parallel_sample_num should be the same for all samples in sample params."
                 )
@@ -103,14 +103,19 @@ class GenerateReqInput:
         if parallel_sample_num != 1:
             # parallel sampling +1 represents the original prefill stage
             num = parallel_sample_num + 1
-            if isinstance(self.text, List):
-                ## suppot batch operation
+            if isinstance(self.text, list):
+                # suppot batch operation
                 self.batch_size = len(self.text)
                 num = num * len(self.text)
+            elif isinstance(self.input_ids, list) and isinstance(
+                self.input_ids[0], list
+            ):
+                self.batch_size = len(self.input_ids)
+                num = num * len(self.input_ids)
             else:
                 self.batch_size = 1
         else:
-            ## support select operation
+            # support select operation
             num = len(self.text) if self.text is not None else len(self.input_ids)
             self.batch_size = num
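Note: per the hunk above, the parallel-sampling bookkeeping in `GenerateReqInput` now also recognizes a batch passed as a list of token-id lists (`input_ids`), not only a list of strings (`text`). A hedged sketch of that detection logic in isolation; the function name is illustrative.

```python
def detect_batch_size(text=None, input_ids=None):
    # Mirrors the new branch in GenerateReqInput: a list of strings or a
    # list of token-id lists both count as a batch; anything else is size 1.
    if isinstance(text, list):
        return len(text)
    if isinstance(input_ids, list) and isinstance(input_ids[0], list):
        return len(input_ids)
    return 1


print(detect_batch_size(text=["Hello", "Hi"]))           # 2
print(detect_batch_size(input_ids=[[1, 2, 3], [4, 5]]))  # 2
print(detect_batch_size(input_ids=[1, 2, 3]))            # 1
```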