sglang 0.3.3.post1__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. sglang/bench_latency.py +28 -10
  2. sglang/bench_server_latency.py +21 -10
  3. sglang/bench_serving.py +101 -7
  4. sglang/global_config.py +0 -1
  5. sglang/srt/layers/attention/__init__.py +27 -5
  6. sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
  7. sglang/srt/layers/attention/flashinfer_backend.py +352 -83
  8. sglang/srt/layers/attention/triton_backend.py +6 -4
  9. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
  10. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
  11. sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
  12. sglang/srt/layers/sampler.py +6 -2
  13. sglang/srt/managers/detokenizer_manager.py +31 -10
  14. sglang/srt/managers/io_struct.py +4 -0
  15. sglang/srt/managers/schedule_batch.py +120 -43
  16. sglang/srt/managers/schedule_policy.py +2 -1
  17. sglang/srt/managers/scheduler.py +202 -140
  18. sglang/srt/managers/tokenizer_manager.py +5 -1
  19. sglang/srt/managers/tp_worker.py +111 -1
  20. sglang/srt/mem_cache/chunk_cache.py +8 -4
  21. sglang/srt/mem_cache/memory_pool.py +77 -4
  22. sglang/srt/mem_cache/radix_cache.py +15 -7
  23. sglang/srt/model_executor/cuda_graph_runner.py +4 -4
  24. sglang/srt/model_executor/forward_batch_info.py +16 -21
  25. sglang/srt/model_executor/model_runner.py +60 -1
  26. sglang/srt/models/baichuan.py +2 -3
  27. sglang/srt/models/chatglm.py +5 -6
  28. sglang/srt/models/commandr.py +1 -2
  29. sglang/srt/models/dbrx.py +1 -2
  30. sglang/srt/models/deepseek.py +4 -5
  31. sglang/srt/models/deepseek_v2.py +5 -6
  32. sglang/srt/models/exaone.py +1 -2
  33. sglang/srt/models/gemma.py +2 -2
  34. sglang/srt/models/gemma2.py +5 -5
  35. sglang/srt/models/gpt_bigcode.py +5 -5
  36. sglang/srt/models/grok.py +1 -2
  37. sglang/srt/models/internlm2.py +1 -2
  38. sglang/srt/models/llama.py +1 -2
  39. sglang/srt/models/llama_classification.py +1 -2
  40. sglang/srt/models/llama_reward.py +2 -3
  41. sglang/srt/models/llava.py +4 -8
  42. sglang/srt/models/llavavid.py +1 -2
  43. sglang/srt/models/minicpm.py +1 -2
  44. sglang/srt/models/minicpm3.py +5 -6
  45. sglang/srt/models/mixtral.py +1 -2
  46. sglang/srt/models/mixtral_quant.py +1 -2
  47. sglang/srt/models/olmo.py +352 -0
  48. sglang/srt/models/olmoe.py +1 -2
  49. sglang/srt/models/qwen.py +1 -2
  50. sglang/srt/models/qwen2.py +1 -2
  51. sglang/srt/models/qwen2_moe.py +4 -5
  52. sglang/srt/models/stablelm.py +1 -2
  53. sglang/srt/models/torch_native_llama.py +1 -2
  54. sglang/srt/models/xverse.py +1 -2
  55. sglang/srt/models/xverse_moe.py +4 -5
  56. sglang/srt/models/yivl.py +1 -2
  57. sglang/srt/openai_api/adapter.py +92 -49
  58. sglang/srt/openai_api/protocol.py +10 -2
  59. sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
  60. sglang/srt/sampling/sampling_batch_info.py +92 -58
  61. sglang/srt/sampling/sampling_params.py +2 -0
  62. sglang/srt/server.py +116 -17
  63. sglang/srt/server_args.py +121 -45
  64. sglang/srt/utils.py +11 -3
  65. sglang/test/few_shot_gsm8k.py +4 -1
  66. sglang/test/few_shot_gsm8k_engine.py +144 -0
  67. sglang/test/srt/sampling/penaltylib/utils.py +16 -12
  68. sglang/version.py +1 -1
  69. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/METADATA +72 -29
  70. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/RECORD +73 -70
  71. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
  72. sglang/srt/layers/attention/flashinfer_utils.py +0 -237
  73. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
  74. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/bench_latency.py CHANGED
@@ -232,17 +232,18 @@ def extend(reqs, model_runner):
232
232
  model_worker_batch = batch.get_model_worker_batch()
233
233
  forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
234
234
  logits_output = model_runner.forward(forward_batch)
235
- next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
235
+ next_token_ids = model_runner.sample(logits_output, forward_batch)
236
236
  return next_token_ids, logits_output.next_token_logits, batch
237
237
 
238
238
 
239
239
  @torch.inference_mode()
240
240
  def decode(input_token_ids, batch, model_runner):
241
- batch.prepare_for_decode(input_token_ids)
241
+ batch.output_ids = input_token_ids
242
+ batch.prepare_for_decode()
242
243
  model_worker_batch = batch.get_model_worker_batch()
243
244
  forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
244
245
  logits_output = model_runner.forward(forward_batch)
245
- next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
246
+ next_token_ids = model_runner.sample(logits_output, forward_batch)
246
247
  return next_token_ids, logits_output.next_token_logits
247
248
 
248
249
 
@@ -252,6 +253,7 @@ def correctness_test(
252
253
  bench_args,
253
254
  tp_rank,
254
255
  ):
256
+ configure_logger(server_args, prefix=f" TP{tp_rank}")
255
257
  rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
256
258
 
257
259
  # Load the model
@@ -279,8 +281,9 @@ def correctness_test(
279
281
  output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
280
282
  for _ in range(bench_args.output_len[0] - 1):
281
283
  next_token_ids, _ = decode(next_token_ids, batch, model_runner)
284
+ next_token_ids_list = next_token_ids.tolist()
282
285
  for i in range(len(reqs)):
283
- output_ids[i].append(next_token_ids[i])
286
+ output_ids[i].append(next_token_ids_list[i])
284
287
 
285
288
  # Print
286
289
  for i in range(len(reqs)):
@@ -288,8 +291,15 @@ def correctness_test(
288
291
  rank_print(tokenizer.decode(output_ids[i]), "\n")
289
292
 
290
293
 
294
+ def synchronize(device):
295
+ if device == "cuda":
296
+ torch.cuda.synchronize()
297
+ elif device == "xpu":
298
+ torch.xpu.synchronize()
299
+
300
+
291
301
  def latency_test_run_once(
292
- run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
302
+ run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device
293
303
  ):
294
304
  max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
295
305
  if batch_size > max_batch_size:
@@ -312,10 +322,10 @@ def latency_test_run_once(
312
322
  tot_latency = 0
313
323
 
314
324
  # Prefill
315
- torch.cuda.synchronize()
325
+ synchronize(device)
316
326
  tic = time.time()
317
327
  next_token_ids, _, batch = extend(reqs, model_runner)
318
- torch.cuda.synchronize()
328
+ synchronize(device)
319
329
  prefill_latency = time.time() - tic
320
330
  tot_latency += prefill_latency
321
331
  throughput = input_len * batch_size / prefill_latency
@@ -328,10 +338,10 @@ def latency_test_run_once(
328
338
  # Decode
329
339
  decode_latencies = []
330
340
  for i in range(output_len - 1):
331
- torch.cuda.synchronize()
341
+ synchronize(device)
332
342
  tic = time.time()
333
343
  next_token_ids, _ = decode(next_token_ids, batch, model_runner)
334
- torch.cuda.synchronize()
344
+ synchronize(device)
335
345
  latency = time.time() - tic
336
346
  tot_latency += latency
337
347
  throughput = batch_size / latency
@@ -387,6 +397,7 @@ def latency_test(
387
397
  bench_args.batch_size[0],
388
398
  bench_args.input_len[0],
389
399
  8, # shorter decoding to speed up the warmup
400
+ server_args.device,
390
401
  )
391
402
  rank_print("Benchmark ...")
392
403
 
@@ -397,7 +408,14 @@ def latency_test(
397
408
  ):
398
409
  reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
399
410
  ret = latency_test_run_once(
400
- bench_args.run_name, model_runner, rank_print, reqs, bs, il, ol
411
+ bench_args.run_name,
412
+ model_runner,
413
+ rank_print,
414
+ reqs,
415
+ bs,
416
+ il,
417
+ ol,
418
+ server_args.device,
401
419
  )
402
420
  if ret is not None:
403
421
  result_list.append(ret)
sglang/bench_server_latency.py CHANGED
@@ -6,6 +6,8 @@ It accepts arguments similar to those of launch_server.py.
6
6
  Usage:
7
7
 
8
8
  python3 -m sglang.bench_server_latency --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
9
+
10
+ python3 -m sglang.bench_server_latency --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
9
11
  """
10
12
 
11
13
  import argparse
@@ -32,6 +34,8 @@ class BenchArgs:
32
34
  input_len: Tuple[int] = (1024,)
33
35
  output_len: Tuple[int] = (16,)
34
36
  result_filename: str = "result.jsonl"
37
+ base_url: str = ""
38
+ skip_warmup: bool = False
35
39
 
36
40
  @staticmethod
37
41
  def add_cli_args(parser: argparse.ArgumentParser):
@@ -48,6 +52,8 @@ class BenchArgs:
48
52
  parser.add_argument(
49
53
  "--result-filename", type=str, default=BenchArgs.result_filename
50
54
  )
55
+ parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
56
+ parser.add_argument("--skip-warmup", action="store_true")
51
57
 
52
58
  @classmethod
53
59
  def from_cli_args(cls, args: argparse.Namespace):
@@ -139,17 +145,21 @@ def run_one_case(
139
145
 
140
146
 
141
147
  def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
142
- proc, base_url = launch_server_process(server_args)
148
+ if bench_args.base_url:
149
+ proc, base_url = None, bench_args.base_url
150
+ else:
151
+ proc, base_url = launch_server_process(server_args)
143
152
 
144
153
  # warmup
145
- run_one_case(
146
- base_url,
147
- batch_size=16,
148
- input_len=1024,
149
- output_len=16,
150
- run_name="",
151
- result_filename="",
152
- )
154
+ if not bench_args.skip_warmup:
155
+ run_one_case(
156
+ base_url,
157
+ batch_size=16,
158
+ input_len=1024,
159
+ output_len=16,
160
+ run_name="",
161
+ result_filename="",
162
+ )
153
163
 
154
164
  # benchmark
155
165
  try:
@@ -165,7 +175,8 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
165
175
  bench_args.result_filename,
166
176
  )
167
177
  finally:
168
- kill_child_process(proc.pid)
178
+ if proc:
179
+ kill_child_process(proc.pid)
169
180
 
170
181
  print(f"\nResults are saved to {bench_args.result_filename}")
171
182
 
sglang/bench_serving.py CHANGED
@@ -222,6 +222,85 @@ async def async_request_openai_completions(
222
222
  return output
223
223
 
224
224
 
225
+ async def async_request_sglang_generate(
226
+ request_func_input: RequestFuncInput,
227
+ pbar: Optional[tqdm] = None,
228
+ ) -> RequestFuncOutput:
229
+ api_url = request_func_input.api_url
230
+ prompt = request_func_input.prompt
231
+
232
+ async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
233
+ payload = {
234
+ "text": prompt,
235
+ "sampling_params": {
236
+ "temperature": 0.0,
237
+ "max_new_tokens": request_func_input.output_len,
238
+ "ignore_eos": not args.disable_ignore_eos,
239
+ },
240
+ "stream": not args.disable_stream,
241
+ **request_func_input.extra_request_body,
242
+ }
243
+ headers = {}
244
+
245
+ output = RequestFuncOutput()
246
+ output.prompt_len = request_func_input.prompt_len
247
+
248
+ generated_text = ""
249
+ ttft = 0.0
250
+ st = time.perf_counter()
251
+ most_recent_timestamp = st
252
+ try:
253
+ async with session.post(
254
+ url=api_url, json=payload, headers=headers
255
+ ) as response:
256
+ if response.status == 200:
257
+ async for chunk_bytes in response.content:
258
+ chunk_bytes = chunk_bytes.strip()
259
+ if not chunk_bytes:
260
+ continue
261
+ # print(chunk_bytes)
262
+
263
+ chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
264
+ latency = time.perf_counter() - st
265
+ if chunk == "[DONE]":
266
+ pass
267
+ else:
268
+ data = json.loads(chunk)
269
+
270
+ # NOTE: Some completion API might have a last
271
+ # usage summary response without a token so we
272
+ # want to check a token was generated
273
+ if data["text"]:
274
+ timestamp = time.perf_counter()
275
+ # First token
276
+ if ttft == 0.0:
277
+ ttft = time.perf_counter() - st
278
+ output.ttft = ttft
279
+
280
+ # Decoding phase
281
+ else:
282
+ output.itl.append(timestamp - most_recent_timestamp)
283
+
284
+ most_recent_timestamp = timestamp
285
+ generated_text = data["text"]
286
+
287
+ output.generated_text = generated_text
288
+ output.success = True
289
+ output.latency = latency
290
+ output.output_len = request_func_input.output_len
291
+ else:
292
+ output.error = response.reason or ""
293
+ output.success = False
294
+ except Exception:
295
+ output.success = False
296
+ exc_info = sys.exc_info()
297
+ output.error = "".join(traceback.format_exception(*exc_info))
298
+
299
+ if pbar:
300
+ pbar.update(1)
301
+ return output
302
+
303
+
225
304
  async def async_request_gserver(
226
305
  request_func_input: RequestFuncInput,
227
306
  pbar: Optional[tqdm] = None,
@@ -264,7 +343,9 @@ def get_tokenizer(
264
343
 
265
344
 
266
345
  ASYNC_REQUEST_FUNCS = {
267
- "sglang": async_request_openai_completions,
346
+ "sglang": async_request_sglang_generate,
347
+ "sglang-native": async_request_sglang_generate,
348
+ "sglang-oai": async_request_openai_completions,
268
349
  "vllm": async_request_openai_completions,
269
350
  "lmdeploy": async_request_openai_completions,
270
351
  "trt": async_request_trt_llm,
@@ -387,6 +468,8 @@ def sample_sharegpt_requests(
387
468
  continue
388
469
  filtered_dataset.append((prompt, prompt_len, output_len))
389
470
 
471
+ print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
472
+ print(f"#Output tokens: {np.sum([x[2] for x in filtered_dataset])}")
390
473
  return filtered_dataset
391
474
 
392
475
 
@@ -587,6 +670,8 @@ async def benchmark(
587
670
  else:
588
671
  print("Initial test run completed. Starting main benchmark run...")
589
672
 
673
+ time.sleep(1.5)
674
+
590
675
  pbar = None if disable_tqdm else tqdm(total=len(input_requests))
591
676
 
592
677
  benchmark_start_time = time.perf_counter()
@@ -782,24 +867,33 @@ def run_benchmark(args_: argparse.Namespace):
782
867
  if args.port is None:
783
868
  args.port = {
784
869
  "sglang": 30000,
870
+ "sglang-native": 30000,
871
+ "sglang-oai": 30000,
785
872
  "lmdeploy": 23333,
786
873
  "vllm": 8000,
787
874
  "trt": 8000,
788
875
  "gserver": 9988,
789
876
  }.get(args.backend, 30000)
790
877
 
791
- api_url = (
792
- f"{args.base_url}/v1/completions"
793
- if args.base_url
794
- else f"http://{args.host}:{args.port}/v1/completions"
795
- )
796
878
  model_url = (
797
879
  f"{args.base_url}/v1/models"
798
880
  if args.base_url
799
881
  else f"http://{args.host}:{args.port}/v1/models"
800
882
  )
801
883
 
802
- if args.backend == "trt":
884
+ if args.backend in ["sglang", "sglang-native"]:
885
+ api_url = (
886
+ f"{args.base_url}/generate"
887
+ if args.base_url
888
+ else f"http://{args.host}:{args.port}/generate"
889
+ )
890
+ elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
891
+ api_url = (
892
+ f"{args.base_url}/v1/completions"
893
+ if args.base_url
894
+ else f"http://{args.host}:{args.port}/v1/completions"
895
+ )
896
+ elif args.backend == "trt":
803
897
  api_url = (
804
898
  f"{args.base_url}/v2/models/ensemble/generate_stream"
805
899
  if args.base_url
sglang/global_config.py CHANGED
@@ -19,7 +19,6 @@ class GlobalConfig:
19
19
  self.new_token_ratio_decay = 0.001
20
20
 
21
21
  # Runtime constants: others
22
- self.num_continue_decode_steps = 10
23
22
  self.retract_decode_steps = 20
24
23
  self.flashinfer_workspace_size = os.environ.get(
25
24
  "FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
sglang/srt/layers/attention/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from abc import ABC, abstractmethod
2
2
 
3
+ import torch
3
4
  from torch import nn
4
5
 
5
6
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -18,13 +19,13 @@ class AttentionBackend(ABC):
18
19
  raise NotImplementedError()
19
20
 
20
21
  def init_forward_metadata_capture_cuda_graph(
21
- self, bs: int, req_pool_indices, seq_lens
22
+ self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
22
23
  ):
23
24
  """Init the metadata for a forward pass for capturing a cuda graph."""
24
25
  raise NotImplementedError()
25
26
 
26
27
  def init_forward_metadata_replay_cuda_graph(
27
- self, bs: int, req_pool_indices, seq_lens
28
+ self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
28
29
  ):
29
30
  """Init the metadata for a forward pass for replying a cuda graph."""
30
31
  raise NotImplementedError()
@@ -33,17 +34,38 @@ class AttentionBackend(ABC):
33
34
  """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
34
35
  raise NotImplementedError()
35
36
 
36
- def forward(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
37
+ def forward(
38
+ self,
39
+ q: torch.Tensor,
40
+ k: torch.Tensor,
41
+ v: torch.Tensor,
42
+ layer: nn.Module,
43
+ forward_batch: ForwardBatch,
44
+ ):
37
45
  """Run forward on an attention layer."""
38
46
  if forward_batch.forward_mode.is_decode():
39
47
  return self.forward_decode(q, k, v, layer, forward_batch)
40
48
  else:
41
49
  return self.forward_extend(q, k, v, layer, forward_batch)
42
50
 
43
- def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
51
+ def forward_decode(
52
+ self,
53
+ q: torch.Tensor,
54
+ k: torch.Tensor,
55
+ v: torch.Tensor,
56
+ layer: nn.Module,
57
+ forward_batch: ForwardBatch,
58
+ ):
44
59
  """Run a forward for decode."""
45
60
  raise NotImplementedError()
46
61
 
47
- def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
62
+ def forward_extend(
63
+ self,
64
+ q: torch.Tensor,
65
+ k: torch.Tensor,
66
+ v: torch.Tensor,
67
+ layer: nn.Module,
68
+ forward_batch: ForwardBatch,
69
+ ):
48
70
  """Run a forward for extend."""
49
71
  raise NotImplementedError()
sglang/srt/layers/attention/double_sparsity_backend.py ADDED
@@ -0,0 +1,281 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from sglang.srt.layers.attention import AttentionBackend
9
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
10
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
11
+
12
+ if TYPE_CHECKING:
13
+ from sglang.srt.model_executor.model_runner import ModelRunner
14
+
15
+
16
+ class DoubleSparseAttnBackend(AttentionBackend):
17
+ def __init__(self, model_runner: ModelRunner):
18
+ # Lazy import to avoid the initialization of cuda context
19
+ from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
20
+ flash_decode_attention_fwd,
21
+ flash_decode_sparse_attention_fwd,
22
+ )
23
+ from sglang.srt.layers.attention.triton_ops.extend_attention import (
24
+ extend_attention_fwd,
25
+ )
26
+
27
+ super().__init__()
28
+
29
+ self.decode_attention_fwd = flash_decode_attention_fwd
30
+ self.decode_sparse_attention_fwd = flash_decode_sparse_attention_fwd
31
+ self.extend_attention_fwd = extend_attention_fwd
32
+ self.num_head = model_runner.model_config.num_attention_heads
33
+ self.head_dim = model_runner.model_config.hidden_size // self.num_head
34
+ self.heavy_token_num = model_runner.server_args.ds_heavy_token_num
35
+
36
+ self.sorted_channels = model_runner.sorted_channels
37
+ self.sparse_decode_thresold = (
38
+ model_runner.server_args.ds_sparse_decode_threshold
39
+ )
40
+ self.att_out_approx: torch.Tensor = None
41
+ self.mid_out: torch.Tensor = None
42
+ self.mid_o_logexpsum: torch.Tensor = None
43
+
44
+ # TODO: Change the hard-coded block_seq_num
45
+ self.BLOCK_SEQ = 128
46
+
47
+ if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
48
+ self.reduce_dtype = torch.float32
49
+ else:
50
+ self.reduce_dtype = torch.float16
51
+
52
+ self.forward_metadata = None
53
+
54
+ self.cuda_graph_max_seq_len = model_runner.model_config.context_len
55
+
56
+ def init_forward_metadata(self, forward_batch: ForwardBatch):
57
+ """Init auxiliary variables for triton attention backend."""
58
+
59
+ if forward_batch.forward_mode.is_decode():
60
+ start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
61
+ start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
62
+
63
+ total_num_tokens = torch.sum(forward_batch.seq_lens).item()
64
+ attn_logits = torch.empty(
65
+ (self.num_head, total_num_tokens),
66
+ dtype=self.reduce_dtype,
67
+ device="cuda",
68
+ )
69
+
70
+ max_seq_len = torch.max(forward_batch.seq_lens).item()
71
+ min_seq_len = torch.min(forward_batch.seq_lens).item()
72
+ max_extend_len = None
73
+ # NOTE: Align sequence order with req_to_token order
74
+ ds_req_to_token = forward_batch.req_to_token_pool.req_to_token[
75
+ forward_batch.req_pool_indices
76
+ ]
77
+
78
+ bsz = forward_batch.seq_lens.shape[0]
79
+
80
+ att_out_approx = torch.empty(
81
+ [self.num_head, bsz, max_seq_len],
82
+ dtype=self.reduce_dtype,
83
+ device="cuda",
84
+ )
85
+
86
+ block_seq_num = (
87
+ self.heavy_token_num + self.BLOCK_SEQ - 1
88
+ ) // self.BLOCK_SEQ
89
+
90
+ mid_out = torch.empty(
91
+ [bsz, self.num_head, block_seq_num, self.head_dim],
92
+ dtype=torch.float32,
93
+ device="cuda",
94
+ )
95
+ mid_o_logexpsum = torch.empty(
96
+ [bsz, self.num_head, block_seq_num], dtype=torch.float32, device="cuda"
97
+ )
98
+ self.att_out_approx = att_out_approx
99
+ self.mid_out = mid_out
100
+ self.mid_o_logexpsum = mid_o_logexpsum
101
+
102
+ else:
103
+ start_loc = attn_logits = max_seq_len = min_seq_len = None
104
+ prefix_lens = forward_batch.extend_prefix_lens
105
+ max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()
106
+ ds_req_to_token = None
107
+
108
+ self.forward_metadata = (
109
+ start_loc,
110
+ attn_logits,
111
+ max_seq_len,
112
+ min_seq_len,
113
+ max_extend_len,
114
+ ds_req_to_token,
115
+ )
116
+
117
+ def init_cuda_graph_state(self, max_bs: int):
118
+ # TODO(Andy): Support CUDA graph for double sparse attention
119
+ raise ValueError(
120
+ "Double sparse attention does not support CUDA graph for now. Please --disable-cuda-graph"
121
+ )
122
+ self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
123
+
124
+ self.cuda_graph_start_loc = torch.zeros(
125
+ (max_bs,), dtype=torch.int32, device="cuda"
126
+ )
127
+ self.cuda_graph_attn_logits = torch.empty(
128
+ (
129
+ self.num_head,
130
+ self.cuda_graph_max_total_num_tokens,
131
+ ),
132
+ dtype=self.reduce_dtype,
133
+ device="cuda",
134
+ )
135
+
136
+ def init_forward_metadata_capture_cuda_graph(
137
+ self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
138
+ ):
139
+ self.forward_metadata = (
140
+ self.cuda_graph_start_loc,
141
+ self.cuda_graph_attn_logits,
142
+ self.cuda_graph_max_seq_len,
143
+ None,
144
+ )
145
+
146
+ def init_forward_metadata_replay_cuda_graph(
147
+ self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
148
+ ):
149
+ self.cuda_graph_start_loc.zero_()
150
+ self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
151
+
152
+ def get_cuda_graph_seq_len_fill_value(self):
153
+ return 1
154
+
155
+ def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
156
+ # TODO: reuse the buffer across layers
157
+ if layer.qk_head_dim != layer.v_head_dim:
158
+ o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
159
+ else:
160
+ o = torch.empty_like(q)
161
+
162
+ k_label = torch.gather(
163
+ k,
164
+ 2,
165
+ self.sorted_channels[layer.layer_id]
166
+ .unsqueeze(0)
167
+ .expand(k.shape[0], -1, -1),
168
+ )
169
+
170
+ forward_batch.token_to_kv_pool.set_kv_buffer(
171
+ layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
172
+ )
173
+
174
+ (
175
+ start_loc,
176
+ attn_logits,
177
+ max_seq_len,
178
+ min_seq_len,
179
+ max_extend_len,
180
+ ds_req_to_token,
181
+ ) = self.forward_metadata
182
+ self.extend_attention_fwd(
183
+ q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
184
+ k.contiguous(),
185
+ v.contiguous(),
186
+ o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
187
+ forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
188
+ forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
189
+ forward_batch.req_to_token_pool.req_to_token,
190
+ forward_batch.req_pool_indices,
191
+ forward_batch.seq_lens,
192
+ forward_batch.extend_seq_lens,
193
+ forward_batch.extend_start_loc,
194
+ max_extend_len,
195
+ layer.scaling,
196
+ layer.logit_cap,
197
+ )
198
+ return o
199
+
200
+ def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
201
+ # During torch.compile, there is a bug in rotary_emb that causes the
202
+ # output value to have a 3D tensor shape. This reshapes the output correctly.
203
+ q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
204
+
205
+ # TODO: reuse the buffer across layers
206
+ if layer.qk_head_dim != layer.v_head_dim:
207
+ o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
208
+ else:
209
+ o = torch.empty_like(q)
210
+
211
+ # TODO: Add min seqlen
212
+ (
213
+ start_loc,
214
+ attn_logits,
215
+ max_seq_len,
216
+ min_seq_len,
217
+ max_extend_len,
218
+ ds_req_to_token,
219
+ ) = self.forward_metadata
220
+
221
+ k_label = torch.gather(
222
+ k,
223
+ 2,
224
+ self.sorted_channels[layer.layer_id]
225
+ .unsqueeze(0)
226
+ .expand(k.shape[0], -1, -1),
227
+ )
228
+
229
+ forward_batch.token_to_kv_pool.set_kv_buffer(
230
+ layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
231
+ )
232
+
233
+ # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
234
+ # and set a minimum value for sparse_decode
235
+ if (
236
+ min_seq_len < self.heavy_token_num
237
+ or max_seq_len < self.sparse_decode_thresold
238
+ ):
239
+ self.decode_attention_fwd(
240
+ q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
241
+ forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
242
+ forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
243
+ o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
244
+ forward_batch.req_to_token_pool.req_to_token,
245
+ forward_batch.req_pool_indices,
246
+ start_loc,
247
+ forward_batch.seq_lens,
248
+ attn_logits,
249
+ max_seq_len,
250
+ layer.scaling,
251
+ layer.logit_cap,
252
+ )
253
+ else:
254
+ # TODO(Andy): indexing with torch.gather or torch.index_select or customized kernel
255
+ q_label = torch.gather(
256
+ q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
257
+ 2,
258
+ self.sorted_channels[layer.layer_id]
259
+ .unsqueeze(0)
260
+ .expand(q.shape[0], -1, -1),
261
+ )
262
+ self.decode_sparse_attention_fwd(
263
+ q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
264
+ forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
265
+ forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
266
+ o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
267
+ q_label,
268
+ forward_batch.token_to_kv_pool.get_label_buffer(layer.layer_id),
269
+ ds_req_to_token,
270
+ forward_batch.seq_lens,
271
+ max_seq_len,
272
+ layer.scaling,
273
+ layer.logit_cap,
274
+ self.heavy_token_num,
275
+ self.att_out_approx,
276
+ self.mid_out,
277
+ self.mid_o_logexpsum,
278
+ self.BLOCK_SEQ,
279
+ )
280
+
281
+ return o