sglang 0.1.22__py3-none-any.whl → 0.1.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sglang/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "0.1.22"
+ __version__ = "0.1.24"

  # SGL API Components
  from sglang.api import (
sglang/bench_serving.py CHANGED
@@ -5,6 +5,9 @@ Benchmark online serving.

  Usage:
  python3 -m sglang.bench_serving --backend sglang --num-prompt 10
+
+ python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5
+ python3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi
  """

  import argparse
@@ -19,6 +22,7 @@ import traceback
  import warnings
  from argparse import ArgumentParser as FlexibleArgumentParser
  from dataclasses import dataclass, field
+ from datetime import datetime
  from typing import AsyncGenerator, List, Optional, Tuple, Union

  import aiohttp
@@ -53,12 +57,80 @@ class RequestFuncOutput:
      itl: List[float] = field(default_factory=list)  # List of inter-token latencies
      prompt_len: int = 0
      error: str = ""
+     output_len: int = 0


  def remove_prefix(text: str, prefix: str) -> str:
      return text[len(prefix) :] if text.startswith(prefix) else text


+ # trt llm not support ignore_eos
+ # https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
+ async def async_request_trt_llm(
+     request_func_input: RequestFuncInput,
+     pbar: Optional[tqdm] = None,
+ ) -> RequestFuncOutput:
+     api_url = request_func_input.api_url
+     assert api_url.endswith("generate_stream")
+
+     async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+         payload = {
+             "accumulate_tokens": True,
+             "text_input": request_func_input.prompt,
+             "temperature": 0.000001,
+             "top_p": 1.0,
+             "max_tokens": request_func_input.output_len,
+             "stream": True,
+             "min_length": request_func_input.output_len,
+             "end_id": 1048576,
+         }
+         output = RequestFuncOutput()
+         output.prompt_len = request_func_input.prompt_len
+
+         ttft = 0.0
+         st = time.perf_counter()
+         most_recent_timestamp = st
+         try:
+             async with session.post(url=api_url, json=payload) as response:
+                 if response.status == 200:
+                     async for chunk_bytes in response.content:
+                         chunk_bytes = chunk_bytes.strip()
+                         if not chunk_bytes:
+                             continue
+
+                         chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:")
+
+                         data = json.loads(chunk)
+                         output.generated_text += data["text_output"]
+                         timestamp = time.perf_counter()
+                         # First token
+                         if ttft == 0.0:
+                             ttft = time.perf_counter() - st
+                             output.ttft = ttft
+
+                         # Decoding phase
+                         else:
+                             output.itl.append(timestamp - most_recent_timestamp)
+
+                         most_recent_timestamp = timestamp
+
+                     output.latency = most_recent_timestamp - st
+                     output.success = True
+                     output.output_len = request_func_input.output_len
+
+                 else:
+                     output.error = response.reason or ""
+                     output.success = False
+         except Exception:
+             output.success = False
+             exc_info = sys.exc_info()
+             output.error = "".join(traceback.format_exception(*exc_info))
+
+     if pbar:
+         pbar.update(1)
+     return output
+
+
  # set ignore_eos True by default
  async def async_request_openai_completions(
      request_func_input: RequestFuncInput,
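
For context, the endpoint and payload used by async_request_trt_llm above can be exercised directly. A minimal standalone sketch (editor's illustration, not part of the package; the host, port, and prompt are placeholders, and the payload keys simply mirror the ones in the diff, where min_length/end_id force a fixed output length because ignore_eos is not supported):

    import asyncio
    import json

    import aiohttp

    async def probe_trt_llm(base_url: str = "http://127.0.0.1:8000") -> None:
        # Mirrors the payload built by async_request_trt_llm above.
        payload = {
            "accumulate_tokens": True,
            "text_input": "Hello, my name is",
            "temperature": 0.000001,
            "top_p": 1.0,
            "max_tokens": 32,
            "stream": True,
            "min_length": 32,
            "end_id": 1048576,
        }
        url = f"{base_url}/v2/models/ensemble/generate_stream"
        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=payload) as resp:
                async for chunk in resp.content:
                    line = chunk.strip()
                    if line:
                        data = json.loads(line.decode("utf-8").removeprefix("data:"))
                        print(data.get("text_output", ""), end="", flush=True)

    # asyncio.run(probe_trt_llm())
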
@@ -76,7 +148,7 @@ async def async_request_openai_completions(
              "temperature": 0.0,
              "best_of": 1,
              "max_tokens": request_func_input.output_len,
-             "stream": True,
+             "stream": not args.disable_stream,
              "ignore_eos": True,
          }
          headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
@@ -99,8 +171,9 @@ async def async_request_openai_completions(
                              continue

                          chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
+                         latency = time.perf_counter() - st
                          if chunk == "[DONE]":
-                             latency = time.perf_counter() - st
+                             pass
                          else:
                              data = json.loads(chunk)

@@ -123,6 +196,7 @@ async def async_request_openai_completions(
                      output.generated_text = generated_text
                      output.success = True
                      output.latency = latency
+                     output.output_len = request_func_input.output_len
                  else:
                      output.error = response.reason or ""
                      output.success = False
@@ -167,6 +241,7 @@ ASYNC_REQUEST_FUNCS = {
      "sglang": async_request_openai_completions,
      "vllm": async_request_openai_completions,
      "lmdeploy": async_request_openai_completions,
+     "trt": async_request_trt_llm,
  }


@@ -175,9 +250,11 @@ class BenchmarkMetrics:
      completed: int
      total_input: int
      total_output: int
+     total_output_retokenized: int
      request_throughput: float
      input_throughput: float
      output_throughput: float
+     output_throughput_retokenized: float
      mean_ttft_ms: float
      median_ttft_ms: float
      std_ttft_ms: float
@@ -190,6 +267,8 @@ class BenchmarkMetrics:
      median_itl_ms: float
      std_itl_ms: float
      p99_itl_ms: float
+     mean_e2e_latency_ms: float
+     median_e2e_latency_ms: float


  default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json"
@@ -384,31 +463,36 @@ def calculate_metrics(
      outputs: List[RequestFuncOutput],
      dur_s: float,
      tokenizer: PreTrainedTokenizerBase,
+     backend: str,
  ) -> Tuple[BenchmarkMetrics, List[int]]:
-     actual_output_lens: List[int] = []
+     output_lens: List[int] = []
+     retokenized_output_lens: List[int] = []
      total_input = 0
      completed = 0
      itls: List[float] = []
      tpots: List[float] = []
      ttfts: List[float] = []
+     e2e_latencies: List[float] = []
      for i in range(len(outputs)):
          if outputs[i].success:
-             # We use the tokenizer to count the number of output tokens for all
-             # serving backends instead of looking at len(outputs[i].itl) since
-             # multiple output tokens may be bundled together
-             # Note : this may inflate the output token count slightly
-             output_len = len(
+             output_len = outputs[i].output_len
+             output_lens.append(output_len)
+             retokenized_output_len = len(
                  tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids
              )
-             actual_output_lens.append(output_len)
+             retokenized_output_lens.append(retokenized_output_len)
              total_input += input_requests[i][1]
              if output_len > 1:
                  tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
              itls += outputs[i].itl
              ttfts.append(outputs[i].ttft)
+
+             e2e_latencies.append(outputs[i].latency)
+
              completed += 1
          else:
-             actual_output_lens.append(0)
+             output_lens.append(0)
+             retokenized_output_lens.append(0)

      if completed == 0:
          warnings.warn(
@@ -419,10 +503,12 @@ def calculate_metrics(
      metrics = BenchmarkMetrics(
          completed=completed,
          total_input=total_input,
-         total_output=sum(actual_output_lens),
+         total_output=sum(output_lens),
+         total_output_retokenized=sum(retokenized_output_lens),
          request_throughput=completed / dur_s,
          input_throughput=total_input / dur_s,
-         output_throughput=sum(actual_output_lens) / dur_s,
+         output_throughput=sum(output_lens) / dur_s,
+         output_throughput_retokenized=sum(retokenized_output_lens) / dur_s,
          mean_ttft_ms=np.mean(ttfts or 0)
          * 1000,  # ttfts is empty if streaming is not supported by backend
          median_ttft_ms=np.median(ttfts or 0) * 1000,
@@ -436,9 +522,11 @@ def calculate_metrics(
          median_itl_ms=np.median(itls or 0) * 1000,
          std_itl_ms=np.std(itls or 0) * 1000,
          p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
+         mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000,
+         median_e2e_latency_ms=np.median(e2e_latencies) * 1000,
      )

-     return metrics, actual_output_lens
+     return metrics, output_lens


  async def benchmark(
@@ -449,6 +537,7 @@ async def benchmark(
      input_requests: List[Tuple[str, int, int]],
      request_rate: float,
      disable_tqdm: bool,
+     enable_multi: bool,
  ):
      if backend in ASYNC_REQUEST_FUNCS:
          request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -498,19 +587,26 @@ async def benchmark(

      benchmark_duration = time.perf_counter() - benchmark_start_time

-     metrics, actual_output_lens = calculate_metrics(
+     metrics, output_lens = calculate_metrics(
          input_requests=input_requests,
          outputs=outputs,
          dur_s=benchmark_duration,
          tokenizer=tokenizer,
+         backend=backend,
      )

      print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
+     print("{:<40} {:<10}".format("Backend:", backend))
      print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
      print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
      print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
      print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
      print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
+     print(
+         "{:<40} {:<10}".format(
+             "Total generated tokens (retokenized):", metrics.total_output_retokenized
+         )
+     )
      print(
          "{:<40} {:<10.2f}".format(
              "Request throughput (req/s):", metrics.request_throughput
@@ -526,6 +622,15 @@ async def benchmark(
              "Output token throughput (tok/s):", metrics.output_throughput
          )
      )
+     print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
+     print(
+         "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
+     )
+     print(
+         "{:<40} {:<10.2f}".format(
+             "Median E2E Latency (ms):", metrics.median_e2e_latency_ms
+         )
+     )
      print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-"))
      print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
      print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
@@ -542,11 +647,53 @@ async def benchmark(
      print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
      print("=" * 50)

+     if (
+         metrics.median_ttft_ms is not None
+         and metrics.mean_itl_ms is not None
+         and metrics.output_throughput is not None
+     ):
+         result = {
+             "backend": args.backend,
+             "dataset_name": args.dataset_name,
+             "request_rate": request_rate,
+             "total_input": metrics.total_input,
+             "total_output": metrics.total_output,
+             "total_output_retokenized": metrics.total_output_retokenized,
+             "mean_e2e_latency": metrics.mean_e2e_latency_ms,
+             "median_e2e_latency": metrics.median_e2e_latency_ms,
+             "median_ttft": metrics.median_ttft_ms,
+             "median_itl": metrics.median_itl_ms,
+             "output_token_throughput": metrics.output_throughput,
+             "sharegpt_output_len": args.sharegpt_output_len,
+             "random_input_len": args.random_input_len,
+             "random_output_len": args.random_output_len,
+             "random_range_ratio": args.random_range_ratio,
+             "benchmark_duration": benchmark_duration,
+         }
+     else:
+         print(f"Error running benchmark for request rate: {request_rate}")
+         print("-" * 30)
+
+     # Determine output file name
+     if args.output_file:
+         output_file_name = args.output_file
+     else:
+         now = datetime.now().strftime("%m%d")
+         if args.dataset_name == "random":
+             output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
+         else:
+             output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl"
+
+     # Append results to a JSONL file
+     with open(output_file_name, "a") as file:
+         file.write(json.dumps(result) + "\n")
+
      result = {
          "duration": benchmark_duration,
          "completed": metrics.completed,
          "total_input_tokens": metrics.total_input,
          "total_output_tokens": metrics.total_output,
+         "total_output_tokens_retokenized": metrics.total_output_retokenized,
          "request_throughput": metrics.request_throughput,
          "input_throughput": metrics.input_throughput,
          "output_throughput": metrics.output_throughput,
@@ -563,15 +710,34 @@ async def benchmark(
          "std_itl_ms": metrics.std_itl_ms,
          "p99_itl_ms": metrics.p99_itl_ms,
          "input_lens": [output.prompt_len for output in outputs],
-         "output_lens": actual_output_lens,
+         "output_lens": output_lens,
          "ttfts": [output.ttft for output in outputs],
          "itls": [output.itl for output in outputs],
          "generated_texts": [output.generated_text for output in outputs],
          "errors": [output.error for output in outputs],
+         "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
+         "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
      }
      return result


+ def parse_request_rate_range(request_rate_range):
+     if len(request_rate_range.split(",")) == 3:
+         start, stop, step = map(int, request_rate_range.split(","))
+         return list(range(start, stop, step))
+     else:
+         return list(map(int, request_rate_range.split(",")))
+
+
+ def check_chat_template(model_path):
+     try:
+         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+         return "chat_template" in tokenizer.init_kwargs
+     except Exception as e:
+         print(f"Fail to load tokenizer config with error={e}")
+         return False
+
+
  def fire(args: argparse.Namespace):
      random.seed(args.seed)
      np.random.seed(args.seed)
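
For reference, the new parse_request_rate_range helper treats exactly three comma-separated values as start,stop,step and any other count as an explicit list of rates, which is why the --request-rate-range help text warns against passing a plain list of exactly three rates. A standalone illustration (editor's sketch, duplicating the helper so it runs on its own):

    def parse_request_rate_range(request_rate_range):
        # Three values -> range(start, stop, step); any other count -> explicit list.
        if len(request_rate_range.split(",")) == 3:
            start, stop, step = map(int, request_rate_range.split(","))
            return list(range(start, stop, step))
        else:
            return list(map(int, request_rate_range.split(",")))

    print(parse_request_rate_range("2,34,2"))         # default sweep: [2, 4, 6, ..., 32]
    print(parse_request_rate_range("1,2,4,8,16,32"))  # explicit list of request rates
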
@@ -581,6 +747,7 @@ def fire(args: argparse.Namespace):
          "sglang": 30000,
          "lmdeploy": 23333,
          "vllm": 8000,
+         "trt": 8000,
      }.get(args.backend, 30000)

      api_url = (
@@ -594,6 +761,16 @@ def fire(args: argparse.Namespace):
          else f"http://{args.host}:{args.port}/v1/models"
      )

+     if args.backend == "trt":
+         api_url = (
+             f"{args.base_url}/v2/models/ensemble/generate_stream"
+             if args.base_url
+             else f"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream"
+         )
+         if args.model is None:
+             print("Please provide a model using `--model` when using `trt` backend.")
+             sys.exit(1)
+
      if args.model is None:
          try:
              response = requests.get(model_url)
@@ -610,6 +787,12 @@ def fire(args: argparse.Namespace):
          print("No model specified or found. Please provide a model using `--model`.")
          sys.exit(1)

+     if not check_chat_template(args.model):
+         print(
+             "\nWARNING It is recommended to use the `Chat` or `Instruct` model for benchmarking.\n"
+             "Because when the tokenizer counts the output tokens, if there is gibberish, it might count incorrectly.\n"
+         )
+
      print(f"{args}\n")

      backend = args.backend
@@ -637,17 +820,35 @@ def fire(args: argparse.Namespace):
      else:
          raise ValueError(f"Unknown dataset: {args.dataset_name}")

-     asyncio.run(
-         benchmark(
-             backend=backend,
-             api_url=api_url,
-             model_id=model_id,
-             tokenizer=tokenizer,
-             input_requests=input_requests,
-             request_rate=args.request_rate,
-             disable_tqdm=args.disable_tqdm,
+     if args.multi:
+         request_rates = parse_request_rate_range(args.request_rate_range)
+
+         for rate in request_rates:
+             asyncio.run(
+                 benchmark(
+                     backend=backend,
+                     api_url=api_url,
+                     model_id=model_id,
+                     tokenizer=tokenizer,
+                     input_requests=input_requests,
+                     request_rate=rate,
+                     disable_tqdm=args.disable_tqdm,
+                     enable_multi=args.multi,
+                 )
+             )
+     else:
+         asyncio.run(
+             benchmark(
+                 backend=backend,
+                 api_url=api_url,
+                 model_id=model_id,
+                 tokenizer=tokenizer,
+                 input_requests=input_requests,
+                 request_rate=args.request_rate,
+                 disable_tqdm=args.disable_tqdm,
+                 enable_multi=args.multi,
+             )
          )
-     )


  # to avoid relying on SGLang's components
@@ -751,6 +952,23 @@ if __name__ == "__main__":
          action="store_true",
          help="Specify to disable tqdm progress bar.",
      )
+     parser.add_argument(
+         "--multi",
+         action="store_true",
+         help="Use request rate range rather than single value.",
+     )
+     parser.add_argument(
+         "--request-rate-range",
+         type=str,
+         default="2,34,2",
+         help="Range of request rates in the format start,stop,step. Default is 2,34,2. It also supports a list of request rates, requiring the parameters to not equal three.",
+     )
+     parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
+     parser.add_argument(
+         "--disable-stream",
+         action="store_true",
+         help="Disable streaming mode.",
+     )

      set_ulimit()

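Each benchmark run appends one JSON object per request rate to the JSONL file named by the logic above, so a sweep started with --multi can be summarized afterwards with a few lines of Python (editor's sketch; the file name below is a hypothetical instance of the generated pattern, and the keys come from the result dict written in this diff):

    import json

    # pattern: <backend>_<MMDD>_<num_prompts>_<random_input_len>_<random_output_len>.jsonl
    with open("sglang_0801_3000_1024_1024.jsonl") as f:
        for line in f:
            record = json.loads(line)
            print(
                record["request_rate"],
                record["median_ttft"],
                record["median_itl"],
                record["output_token_throughput"],
            )
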
sglang/global_config.py CHANGED
@@ -16,9 +16,9 @@ class GlobalConfig:
          self.wait_for_new_request_delay = 0.0006

          # Runtime constants: New generation token ratio estimation
-         self.base_new_token_ratio = 0.4
+         self.init_new_token_ratio = 0.7
          self.base_min_new_token_ratio = 0.2
-         self.new_token_ratio_decay = 0.0001
+         self.new_token_ratio_decay = 0.001
          self.new_token_ratio_recovery = 0.05

          # Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
@@ -27,6 +27,7 @@ class GlobalConfig:

          # Runtime constants: others
          self.num_continue_decode_steps = 10
+         self.retract_decode_steps = 20
          self.flashinfer_workspace_size = 192 * 1024 * 1024

          # Output tokenization configs
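
The retuned ratio constants are easier to read with a rough back-of-the-envelope check. The sketch below is an editor's illustration only: it assumes the scheduler decays the ratio additively by new_token_ratio_decay per step until it reaches base_min_new_token_ratio, which this diff does not itself show.

    init_new_token_ratio = 0.7      # was base_new_token_ratio = 0.4
    base_min_new_token_ratio = 0.2
    new_token_ratio_decay = 0.001   # was 0.0001

    # Under the additive-decay assumption, the estimate relaxes from the more
    # conservative starting value to the floor in about 500 steps instead of ~2000:
    steps_to_floor = (init_new_token_ratio - base_min_new_token_ratio) / new_token_ratio_decay
    print(steps_to_floor)  # 500.0
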
@@ -288,6 +288,7 @@ class StreamExecutor:
              exes[i].text_ = str(self.text_)
              exes[i].messages_ = list(self.messages_)
              exes[i].cur_role = self.cur_role
+             exes[i].cur_role_begin_pos = self.cur_role_begin_pos
              exes[i].fork_start_text_pos = len(self.text_)
              exes[i].images_ = list(self.images_)

@@ -4,19 +4,26 @@ import functools
  import json
  import os
  import warnings
- from typing import AbstractSet, Collection, Literal, Optional, Union
+ from typing import AbstractSet, Collection, Dict, Literal, Optional, Type, Union

  from huggingface_hub import snapshot_download
  from transformers import (
      AutoConfig,
      AutoProcessor,
      AutoTokenizer,
+     PretrainedConfig,
      PreTrainedTokenizer,
      PreTrainedTokenizerFast,
  )
+ from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig

  from sglang.srt.utils import is_multimodal_model

+ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
+     ChatGLMConfig.model_type: ChatGLMConfig,
+     DbrxConfig.model_type: DbrxConfig,
+ }
+

  def download_from_hf(model_path: str):
      if os.path.exists(model_path):
@@ -40,6 +47,9 @@ def get_config(
      config = AutoConfig.from_pretrained(
          model, trust_remote_code=trust_remote_code, revision=revision
      )
+     if config.model_type in _CONFIG_REGISTRY:
+         config_class = _CONFIG_REGISTRY[config.model_type]
+         config = config_class.from_pretrained(model, revision=revision)
      if model_overide_args:
          config.update(model_overide_args)
      return config
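
The effect of the registry added above: for model types with a registered class, get_config reloads the config with that class instead of keeping the generic (possibly remote-code) AutoConfig result. A hedged usage sketch (editor's illustration; it assumes the function lives in sglang.srt.hf_transformers_utils, that transformers and vllm are installed, and the checkpoint name is just an example):

    from sglang.srt.hf_transformers_utils import get_config  # assumed module path

    config = get_config("THUDM/chatglm3-6b", trust_remote_code=True)
    # With the registry in place, a ChatGLM checkpoint should resolve to
    # vllm.transformers_utils.configs.ChatGLMConfig rather than the remote-code class.
    print(type(config))
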
@@ -63,6 +73,8 @@ def get_context_length(config):
      rope_scaling = getattr(config, "rope_scaling", None)
      if rope_scaling:
          rope_scaling_factor = config.rope_scaling["factor"]
+         if config.rope_scaling["rope_type"] == "llama3":
+             rope_scaling_factor = 1
      else:
          rope_scaling_factor = 1

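The new branch matters for Llama-3.1-style checkpoints, whose rope_scaling block carries rope_type "llama3" and a factor that is already reflected in max_position_embeddings; multiplying by the factor again would overstate the usable context length. A simplified standalone illustration (editor's sketch; the real get_context_length inspects more config fields than this):

    def effective_context_len(max_position_embeddings, rope_scaling=None):
        if rope_scaling:
            factor = rope_scaling["factor"]
            if rope_scaling.get("rope_type") == "llama3":
                factor = 1  # llama3-style scaling: do not multiply again
            return int(factor * max_position_embeddings)
        return max_position_embeddings

    print(effective_context_len(131072, {"rope_type": "llama3", "factor": 8.0}))  # 131072
    print(effective_context_len(4096, {"rope_type": "linear", "factor": 4.0}))    # 16384
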
@@ -34,12 +34,11 @@ class LogitProcessorOutput:
  @dataclasses.dataclass
  class LogitsMetadata:
      forward_mode: ForwardMode
-     extend_seq_lens: torch.Tensor
-     extend_start_loc: torch.Tensor
-
-     # For logprobs
      return_logprob: bool
-     top_logprobs_nums: List[int]
+
+     extend_seq_lens: torch.Tensor = None
+     extend_start_loc: torch.Tensor = None
+     top_logprobs_nums: List[int] = None

      @classmethod
      def from_input_metadata(cls, input_metadata: InputMetadata):
@@ -85,32 +85,47 @@ class RadixAttention(nn.Module):
          return o

      def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-         o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
-             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-             k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
-             v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
-             causal=True,
-             sm_scale=self.scaling,
-             logits_soft_cap=self.logit_cap,
-         )
+         if not input_metadata.use_ragged:
+             self.store_kv_cache(k, v, input_metadata)

-         if input_metadata.extend_no_prefix:
-             o = o1
-         else:
-             o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
+             o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
                  q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                 input_metadata.token_to_kv_pool.kv_data[self.layer_id],
-                 causal=False,
+                 input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
+                 causal=True,
                  sm_scale=self.scaling,
                  logits_soft_cap=self.logit_cap,
              )
+         else:
+             o1, s1 = (
+                 input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
+                     q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                     k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
+                     v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
+                     causal=True,
+                     sm_scale=self.scaling,
+                     logits_soft_cap=self.logit_cap,
+                 )
+             )

-             o, _ = merge_state(o1, s1, o2, s2)
+             if input_metadata.extend_no_prefix:
+                 o = o1
+             else:
+                 o2, s2 = (
+                     input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
+                         q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                         input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
+                         causal=False,
+                         sm_scale=self.scaling,
+                         logits_soft_cap=self.logit_cap,
+                     )
+                 )

-             self.store_kv_cache(k, v, input_metadata)
+                 o, _ = merge_state(o1, s1, o2, s2)
+
+             self.store_kv_cache(k, v, input_metadata)

-         if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
-             torch.cuda.synchronize()
+             if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
+                 torch.cuda.synchronize()

          return o.view(-1, self.tp_q_head_num * self.head_dim)

@@ -119,7 +134,7 @@ class RadixAttention(nn.Module):

          o = input_metadata.flashinfer_decode_wrapper.forward(
              q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+             input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
              sm_scale=self.scaling,
              logits_soft_cap=self.logit_cap,
          )
@@ -136,33 +151,7 @@ class RadixAttention(nn.Module):
          return self.decode_forward(q, k, v, input_metadata)

      def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-         kv_cache = input_metadata.token_to_kv_pool.kv_data[self.layer_id]
-         _store_kv_cache(cache_k, cache_v, kv_cache, input_metadata.out_cache_loc)
-
-
- try:
-
-     @torch.library.custom_op("mylib::store_kv_cache", mutates_args={"kv_cache"})
-     def _store_kv_cache(
-         k: torch.Tensor,
-         v: torch.Tensor,
-         kv_cache: torch.Tensor,
-         cache_loc: torch.Tensor,
-     ) -> None:
-         kv_cache[cache_loc, 0] = k
-         kv_cache[cache_loc, 1] = v
-
-     @_store_kv_cache.register_fake
-     def _(k, v, kv_cache, cache_loc):
-         pass
-
- except:
-
-     def _store_kv_cache(
-         k: torch.Tensor,
-         v: torch.Tensor,
-         kv_cache: torch.Tensor,
-         cache_loc: torch.Tensor,
-     ) -> None:
-         kv_cache[cache_loc, 0] = k
-         kv_cache[cache_loc, 1] = v
+         k_cache = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
+         v_cache = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
+         k_cache[input_metadata.out_cache_loc] = cache_k
+         v_cache[input_metadata.out_cache_loc] = cache_v
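
The rewritten store_kv_cache writes through separate per-layer key and value buffers (get_key_buffer / get_value_buffer) instead of indexing a combined kv_data tensor, which also removes the need for the old torch.library.custom_op workaround. A standalone sketch of the advanced-indexing pattern it relies on (editor's illustration; the shapes are made up):

    import torch

    num_slots, num_heads, head_dim = 16, 2, 4
    k_cache = torch.zeros(num_slots, num_heads, head_dim)
    v_cache = torch.zeros(num_slots, num_heads, head_dim)

    # out_cache_loc plays the role of input_metadata.out_cache_loc: one
    # pre-allocated cache slot index per new token in the batch.
    out_cache_loc = torch.tensor([3, 7, 8])
    cache_k = torch.randn(3, num_heads, head_dim)
    cache_v = torch.randn(3, num_heads, head_dim)

    # Same pattern as `k_cache[input_metadata.out_cache_loc] = cache_k` above:
    # advanced indexing scatters each token's K/V into its assigned slot in place.
    k_cache[out_cache_loc] = cache_k
    v_cache[out_cache_loc] = cache_v
    assert torch.equal(k_cache[7], cache_k[1])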