sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/bench_serving.py CHANGED
@@ -24,6 +24,7 @@ import warnings
  from argparse import ArgumentParser
  from dataclasses import dataclass, field
  from datetime import datetime
+ from json import JSONDecodeError
  from pathlib import Path
  from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
 
@@ -73,6 +74,12 @@ class RequestFuncOutput:
  error: str = ""
  output_len: int = 0
 
+ @staticmethod
+ def init_new(request_func_input: RequestFuncInput):
+ output = RequestFuncOutput()
+ output.prompt_len = request_func_input.prompt_len
+ return output
+
 
  def remove_prefix(text: str, prefix: str) -> str:
  return text[len(prefix) :] if text.startswith(prefix) else text
@@ -114,8 +121,7 @@ async def async_request_trt_llm(
  if args.disable_ignore_eos:
  del payload["min_length"]
  del payload["end_id"]
- output = RequestFuncOutput()
- output.prompt_len = request_func_input.prompt_len
+ output = RequestFuncOutput.init_new(request_func_input)
 
  ttft = 0.0
  st = time.perf_counter()
@@ -186,8 +192,7 @@ async def async_request_openai_completions(
  }
  headers = get_auth_headers()
 
- output = RequestFuncOutput()
- output.prompt_len = request_func_input.prompt_len
+ output = RequestFuncOutput.init_new(request_func_input)
 
  generated_text = ""
  output_len = request_func_input.output_len
@@ -269,8 +274,7 @@ async def async_request_truss(
  }
  headers = get_auth_headers()
 
- output = RequestFuncOutput()
- output.prompt_len = request_func_input.prompt_len
+ output = RequestFuncOutput.init_new(request_func_input)
 
  generated_text = ""
  ttft = 0.0
@@ -355,8 +359,7 @@ async def async_request_sglang_generate(
 
  headers = get_auth_headers()
 
- output = RequestFuncOutput()
- output.prompt_len = request_func_input.prompt_len
+ output = RequestFuncOutput.init_new(request_func_input)
 
  generated_text = ""
  output_len = request_func_input.output_len
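
All four request backends above now build their output through a single factory helper instead of repeating the two-line setup. A minimal, self-contained sketch of the pattern, with the field lists trimmed down (only init_new mirrors the real code; the other fields shown here are placeholders):

    from dataclasses import dataclass, field
    from typing import List


    @dataclass
    class RequestFuncInput:  # simplified stand-in for the benchmark's input record
        prompt: str = ""
        prompt_len: int = 0
        output_len: int = 0


    @dataclass
    class RequestFuncOutput:
        generated_text: str = ""
        prompt_len: int = 0
        itl: List[float] = field(default_factory=list)

        @staticmethod
        def init_new(request_func_input: RequestFuncInput) -> "RequestFuncOutput":
            # Every backend starts from the same state: the prompt length is copied exactly once here.
            output = RequestFuncOutput()
            output.prompt_len = request_func_input.prompt_len
            return output


    # inside each backend coroutine:
    output = RequestFuncOutput.init_new(RequestFuncInput(prompt="hi", prompt_len=1, output_len=8))
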
@@ -469,6 +472,10 @@ def get_model(pretrained_model_name_or_path: str) -> str:
  def get_tokenizer(
  pretrained_model_name_or_path: str,
  ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+ assert (
+ pretrained_model_name_or_path is not None
+ and pretrained_model_name_or_path != ""
+ )
  if pretrained_model_name_or_path.endswith(
  ".json"
  ) or pretrained_model_name_or_path.endswith(".model"):
@@ -582,7 +589,7 @@ def download_and_cache_file(url: str, filename: Optional[str] = None):
  filename = os.path.join("/tmp", url.split("/")[-1])
 
  # Check if the cache file already exists
- if os.path.exists(filename):
+ if is_file_valid_json(filename):
  return filename
 
  print(f"Downloading from {url} to {filename}")
@@ -610,12 +617,35 @@ def download_and_cache_file(url: str, filename: Optional[str] = None):
  return filename
 
 
+ def is_file_valid_json(path):
+ if not os.path.isfile(path):
+ return False
+
+ # TODO can fuse into the real file open later
+ try:
+ with open(path) as f:
+ json.load(f)
+ return True
+ except JSONDecodeError as e:
+ print(
+ f"{path} exists but json loading fails ({e=}), thus treat as invalid file"
+ )
+ return False
+
+
+ @dataclass
+ class DatasetRow:
+ prompt: str
+ prompt_len: int
+ output_len: int
+
+
  def sample_mmmu_requests(
  num_requests: int,
  tokenizer: PreTrainedTokenizerBase,
  fixed_output_len: Optional[int] = None,
  random_sample: bool = True,
- ) -> List[Tuple[str, int, int]]:
+ ) -> List[DatasetRow]:
  """
  Sample requests from the MMMU dataset using HuggingFace datasets.
 
@@ -716,7 +746,11 @@ def sample_mmmu_requests(
 
  output_len = fixed_output_len if fixed_output_len is not None else 256
 
- filtered_dataset.append((prompt, prompt_len, output_len))
+ filtered_dataset.append(
+ DatasetRow(
+ prompt=prompt, prompt_len=prompt_len, output_len=output_len
+ )
+ )
 
  except Exception as e:
  print(f"Error processing example {i}: {e}")
@@ -733,12 +767,12 @@ def sample_sharegpt_requests(
  context_len: Optional[int] = None,
  prompt_suffix: Optional[str] = "",
  apply_chat_template=False,
- ) -> List[Tuple[str, int, int]]:
+ ) -> List[DatasetRow]:
  if fixed_output_len is not None and fixed_output_len < 4:
  raise ValueError("output_len too small")
 
  # Download sharegpt if necessary
- if not os.path.isfile(dataset_path) and dataset_path == "":
+ if not is_file_valid_json(dataset_path) and dataset_path == "":
  dataset_path = download_and_cache_file(SHAREGPT_URL)
 
  # Load the dataset.
@@ -764,7 +798,7 @@ def sample_sharegpt_requests(
  random.shuffle(dataset)
 
  # Filter out sequences that are too long or too short
- filtered_dataset: List[Tuple[str, int, int]] = []
+ filtered_dataset: List[DatasetRow] = []
  for i in range(len(dataset)):
  if len(filtered_dataset) == num_requests:
  break
@@ -802,10 +836,12 @@ def sample_sharegpt_requests(
  # Prune too long sequences.
  continue
 
- filtered_dataset.append((prompt, prompt_len, output_len))
+ filtered_dataset.append(
+ DatasetRow(prompt=prompt, prompt_len=prompt_len, output_len=output_len)
+ )
 
- print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
- print(f"#Output tokens: {np.sum([x[2] for x in filtered_dataset])}")
+ print(f"#Input tokens: {np.sum([x.prompt_len for x in filtered_dataset])}")
+ print(f"#Output tokens: {np.sum([x.output_len for x in filtered_dataset])}")
  return filtered_dataset
 
 
@@ -817,7 +853,8 @@ def sample_random_requests(
  tokenizer: PreTrainedTokenizerBase,
  dataset_path: str,
  random_sample: bool = True,
- ) -> List[Tuple[str, int, int]]:
+ return_text: bool = True,
+ ) -> List[DatasetRow]:
  input_lens = np.random.randint(
  max(int(input_len * range_ratio), 1),
  input_len + 1,
@@ -833,7 +870,7 @@ def sample_random_requests(
  # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
 
  # Download sharegpt if necessary
- if not os.path.isfile(dataset_path):
+ if not is_file_valid_json(dataset_path):
  dataset_path = download_and_cache_file(SHAREGPT_URL)
 
  # Load the dataset.
@@ -857,7 +894,7 @@ def sample_random_requests(
  random.shuffle(dataset)
 
  # Filter out sequences that are too long or too short
- input_requests: List[Tuple[str, int, int]] = []
+ input_requests: List[DatasetRow] = []
  for data in dataset:
  i = len(input_requests)
  if i == num_prompts:
@@ -877,20 +914,34 @@ def sample_random_requests(
  else:
  ratio = (input_lens[i] + prompt_len - 1) // prompt_len
  input_ids = (prompt_token_ids * ratio)[: input_lens[i]]
- prompt = tokenizer.decode(input_ids)
- input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
+ input_content = input_ids
+ if return_text:
+ input_content = tokenizer.decode(input_content)
+ input_requests.append(
+ DatasetRow(
+ prompt=input_content,
+ prompt_len=int(input_lens[i]),
+ output_len=int(output_lens[i]),
+ )
+ )
  else:
  # Sample token ids from random integers. This can cause some NaN issues.
  offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
  input_requests = []
  for i in range(num_prompts):
- prompt = tokenizer.decode(
- [
- (offsets[i] + i + j) % tokenizer.vocab_size
- for j in range(input_lens[i])
- ]
+ input_content = [
+ (offsets[i] + i + j) % tokenizer.vocab_size
+ for j in range(input_lens[i])
+ ]
+ if return_text:
+ input_content = tokenizer.decode(input_content)
+ input_requests.append(
+ DatasetRow(
+ prompt=input_content,
+ prompt_len=int(input_lens[i]),
+ output_len=int(output_lens[i]),
+ )
  )
- input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
 
  print(f"#Input tokens: {np.sum(input_lens)}")
  print(f"#Output tokens: {np.sum(output_lens)}")
@@ -925,7 +976,7 @@ def sample_generated_shared_prefix_requests(
  output_len: int,
  tokenizer: PreTrainedTokenizerBase,
  args: argparse.Namespace,
- ) -> List[Tuple[str, int, int]]:
+ ) -> List[DatasetRow]:
  """Generate benchmark requests with shared system prompts using random tokens and caching."""
  cache_path = get_gen_prefix_cache_path(args, tokenizer)
 
@@ -963,7 +1014,11 @@ def sample_generated_shared_prefix_requests(
  full_prompt = f"{system_prompt}\n\n{question}"
  prompt_len = len(tokenizer.encode(full_prompt))
 
- input_requests.append((full_prompt, prompt_len, output_len))
+ input_requests.append(
+ DatasetRow(
+ prompt=full_prompt, prompt_len=prompt_len, output_len=output_len
+ )
+ )
  total_input_tokens += prompt_len
  total_output_tokens += output_len
 
@@ -994,9 +1049,9 @@ def sample_generated_shared_prefix_requests(
 
 
  async def get_request(
- input_requests: List[Tuple[str, int, int]],
+ input_requests: List[DatasetRow],
  request_rate: float,
- ) -> AsyncGenerator[Tuple[str, int, int], None]:
+ ) -> AsyncGenerator[DatasetRow, None]:
  input_requests = iter(input_requests)
  for request in input_requests:
  yield request
@@ -1012,7 +1067,7 @@ async def get_request(
 
 
  def calculate_metrics(
- input_requests: List[Tuple[str, int, int]],
+ input_requests: List[DatasetRow],
  outputs: List[RequestFuncOutput],
  dur_s: float,
  tokenizer: PreTrainedTokenizerBase,
@@ -1034,7 +1089,7 @@ def calculate_metrics(
  tokenizer.encode(outputs[i].generated_text, add_special_tokens=False)
  )
  retokenized_output_lens.append(retokenized_output_len)
- total_input += input_requests[i][1]
+ total_input += input_requests[i].prompt_len
  if output_len > 1:
  tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
  itls += outputs[i].itl
@@ -1096,14 +1151,14 @@ async def benchmark(
  base_url: str,
  model_id: str,
  tokenizer: PreTrainedTokenizerBase,
- input_requests: List[Tuple[str, int, int]],
+ input_requests: List[DatasetRow],
  request_rate: float,
  max_concurrency: Optional[int],
  disable_tqdm: bool,
  lora_names: List[str],
  extra_request_body: Dict[str, Any],
  profile: bool,
- pd_seperated: bool = False,
+ pd_separated: bool = False,
  flush_cache: bool = False,
  warmup_requests: int = 1,
  ):
@@ -1126,7 +1181,12 @@ async def benchmark(
  print(f"Starting warmup with {warmup_requests} sequences...")
 
  # Use the first request for all warmup iterations
- test_prompt, test_prompt_len, test_output_len = input_requests[0]
+ test_request = input_requests[0]
+ test_prompt, test_prompt_len, test_output_len = (
+ test_request.prompt,
+ test_request.prompt_len,
+ test_request.output_len,
+ )
  if lora_names is not None and len(lora_names) != 0:
  lora_name = lora_names[0]
  else:
@@ -1194,7 +1254,11 @@ async def benchmark(
  benchmark_start_time = time.perf_counter()
  tasks: List[asyncio.Task] = []
  async for request in get_request(input_requests, request_rate):
- prompt, prompt_len, output_len = request
+ prompt, prompt_len, output_len = (
+ request.prompt,
+ request.prompt_len,
+ request.output_len,
+ )
  if lora_names is not None and len(lora_names) != 0:
  idx = random.randint(0, len(lora_names) - 1)
  lora_name = lora_names[idx]
@@ -1239,12 +1303,17 @@ async def benchmark(
 
  if "sglang" in backend:
  server_info = requests.get(base_url + "/get_server_info")
- if pd_seperated:
- accept_length = server_info.json()["decode"][0].get(
- "avg_spec_accept_length", None
- )
+ if server_info.status_code == 200:
+ if pd_separated:
+ accept_length = server_info.json()["decode"][0]["internal_states"][
+ 0
+ ].get("avg_spec_accept_length", None)
+ else:
+ accept_length = server_info.json()["internal_states"][0].get(
+ "avg_spec_accept_length", None
+ )
  else:
- accept_length = server_info.json().get("avg_spec_accept_length", None)
+ accept_length = None
  else:
  accept_length = None
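
The speculative-decoding accept length is now read from the server's internal_states list, and only when /get_server_info answers with HTTP 200. A hedged, stand-alone sketch of the same parsing (the function name get_accept_length is invented; the key names come from the hunk above):

    import requests


    def get_accept_length(base_url: str, pd_separated: bool = False):
        resp = requests.get(base_url + "/get_server_info")
        if resp.status_code != 200:
            return None
        info = resp.json()
        if pd_separated:
            # PD-disaggregated deployments report per-role blocks; use the first decode worker.
            info = info["decode"][0]
        return info["internal_states"][0].get("avg_spec_accept_length", None)
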
 
@@ -1263,7 +1332,7 @@ async def benchmark(
  print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
  print(
  "{:<40} {:<10}".format(
- "Max reqeuest concurrency:",
+ "Max request concurrency:",
  max_concurrency if max_concurrency else "not set",
  )
  )
@@ -1378,21 +1447,24 @@ async def benchmark(
  else:
  output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl"
 
+ result_details = {
+ "input_lens": [output.prompt_len for output in outputs],
+ "output_lens": output_lens,
+ "ttfts": [output.ttft for output in outputs],
+ "itls": [output.itl for output in outputs],
+ "generated_texts": [output.generated_text for output in outputs],
+ "errors": [output.error for output in outputs],
+ }
+
  # Append results to a JSONL file
  with open(output_file_name, "a") as file:
- file.write(json.dumps(result) + "\n")
-
- result.update(
- {
- "input_lens": [output.prompt_len for output in outputs],
- "output_lens": output_lens,
- "ttfts": [output.ttft for output in outputs],
- "itls": [output.itl for output in outputs],
- "generated_texts": [output.generated_text for output in outputs],
- "errors": [output.error for output in outputs],
- }
- )
- return result
+ if args.output_details:
+ result_for_dump = result | result_details
+ else:
+ result_for_dump = result
+ file.write(json.dumps(result_for_dump) + "\n")
+
+ return result | result_details
 
 
  def check_chat_template(model_path):
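
The per-request detail lists are now merged into the summary on demand with the dict union operator (Python 3.9+), so result itself is no longer mutated before the JSONL dump. A tiny illustration of the merge semantics with invented values:

    result = {"backend": "sglang", "completed": 2}
    result_details = {"ttfts": [0.12, 0.09], "errors": ["", ""]}

    merged = result | result_details  # builds a new dict; result is left untouched
    print(sorted(merged))             # ['backend', 'completed', 'errors', 'ttfts']
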
@@ -1422,6 +1494,9 @@ def run_benchmark(args_: argparse.Namespace):
  if not hasattr(args, "warmup_requests"):
  args.warmup_requests = 1
 
+ if not hasattr(args, "output_details"):
+ args.output_details = False
+
  print(f"benchmark_args={args}")
 
  # Set global environments
@@ -1541,7 +1616,7 @@ def run_benchmark(args_: argparse.Namespace):
  lora_names=args.lora_name,
  extra_request_body=extra_request_body,
  profile=args.profile,
- pd_seperated=args.pd_seperated,
+ pd_separated=args.pd_separated,
  flush_cache=args.flush_cache,
  )
  )
@@ -1666,6 +1741,9 @@ if __name__ == "__main__":
  "if the server is not processing requests fast enough to keep up.",
  )
  parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
+ parser.add_argument(
+ "--output-details", action="store_true", help="Output details of benchmarking."
+ )
  parser.add_argument(
  "--disable-tqdm",
  action="store_true",
@@ -1720,7 +1798,7 @@
  help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
  )
  parser.add_argument(
- "--pd-seperated",
+ "--pd-separated",
  action="store_true",
  help="Benchmark PD disaggregation server",
  )
sglang/compile_deep_gemm.py CHANGED
@@ -82,8 +82,8 @@ def launch_server_process_and_send_one_request(
  base_url = f"http://{server_args.host}:{server_args.port}"
  timeout = compile_args.timeout
 
- start_time = time.time()
- while time.time() - start_time < timeout:
+ start_time = time.perf_counter()
+ while time.perf_counter() - start_time < timeout:
  try:
  headers = {
  "Content-Type": "application/json; charset=utf-8",
@@ -112,9 +112,9 @@ def launch_server_process_and_send_one_request(
  raise RuntimeError(f"Sync request failed: {error}")
  # Other nodes should wait for the exit signal from Rank-0 node.
  else:
- start_time_waiting = time.time()
+ start_time_waiting = time.perf_counter()
  while proc.is_alive():
- if time.time() - start_time_waiting < timeout:
+ if time.perf_counter() - start_time_waiting < timeout:
  time.sleep(10)
  else:
  raise TimeoutError("Waiting for main node timeout!")
@@ -129,7 +129,7 @@ def launch_server_process_and_send_one_request(
 
 
  def refine_server_args(server_args: ServerArgs, compile_args: CompileArgs):
- # Disbale cuda graph and torch compile to save time
+ # Disable cuda graph and torch compile to save time
  server_args.disable_cuda_graph = True
  server_args.enable_torch_compile = False
  print(f"Disable CUDA Graph and Torch Compile to save time...")
sglang/eval/loogle_eval.py ADDED
@@ -0,0 +1,157 @@
+ import argparse
+ import asyncio
+ import os
+ import pickle
+ from pathlib import Path
+ from typing import List
+
+ import openai
+ import torch
+ from bert_score import BERTScorer
+ from datasets import load_dataset
+ from tqdm import tqdm
+
+
+ def get_client(api_url: str) -> openai.AsyncOpenAI:
+ if os.getenv("OPENAI_API_KEY") is None:
+ os.environ["OPENAI_API_KEY"] = "EMPTY"
+ return openai.AsyncOpenAI(base_url=api_url)
+
+
+ def get_dataset():
+ return load_dataset("bigai-nlco/LooGLE", "longdep_qa", split="test")
+
+
+ async def fetch_response(
+ client: openai.AsyncOpenAI,
+ context: str,
+ question: str,
+ semaphore: asyncio.Semaphore,
+ index: int,
+ model: str,
+ output_dir: Path,
+ ):
+ output_file = output_dir / f"response_{index}.pkl"
+ if output_file.exists():
+ return
+
+ prompt = (
+ "Please answer the question based on the long texts below.\n"
+ f"{context}\n"
+ f"Question: {question}\n"
+ "Answer:"
+ )
+ messages = [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": prompt},
+ ]
+
+ async with semaphore:
+ try:
+ response = await client.chat.completions.create(
+ model=model,
+ messages=messages,
+ temperature=0.0,
+ max_tokens=512,
+ )
+ except openai.BadRequestError as e:
+ with open(output_file, "wb") as f:
+ pickle.dump({"error": str(e)}, f)
+ return
+
+ with open(output_file, "wb") as f:
+ pickle.dump(response, f)
+
+
+ async def benchmark(args):
+ dataset = get_dataset()
+ output_dir = Path(args.output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ client = get_client(args.api_url)
+ semaphore = asyncio.Semaphore(args.max_concurrency)
+
+ tasks: List[asyncio.Task] = []
+ for idx, ex in enumerate(dataset):
+ tasks.append(
+ asyncio.create_task(
+ fetch_response(
+ client,
+ ex["context"],
+ ex["question"],
+ semaphore,
+ idx,
+ args.model,
+ output_dir,
+ )
+ )
+ )
+
+ for _ in tqdm(
+ asyncio.as_completed(tasks), total=len(tasks), desc="Running benchmark"
+ ):
+ await _
+
+
+ def analyse(args):
+ dataset = get_dataset()
+ output_dir = Path(args.output_dir)
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ scorer = BERTScorer(lang="en", device=device)
+
+ hyps: List[str] = []
+ refs: List[str] = []
+ for idx, ex in enumerate(tqdm(dataset, desc="Loading responses")):
+ pkl_file = output_dir / f"response_{idx}.pkl"
+ if not pkl_file.exists():
+ raise FileNotFoundError(pkl_file)
+
+ response = pickle.load(open(pkl_file, "rb"))
+ if isinstance(response, dict) and "error" in response:
+ continue
+
+ hyps.append(response.choices[0].message.content.strip())
+ refs.append(ex["answer"])
+
+ if not hyps:
+ print("No valid responses to score!")
+ return
+
+ batch_size = 64
+ all_f1: List[float] = []
+ for i in tqdm(range(0, len(hyps), batch_size), desc="Scoring batches"):
+ h_batch = hyps[i : i + batch_size]
+ r_batch = refs[i : i + batch_size]
+ _, _, f1_scores = scorer.score(h_batch, r_batch, verbose=False)
+ all_f1.extend([float(x) for x in f1_scores])
+
+ avg = sum(all_f1) / len(all_f1)
+ print(f"Average BERTScore (F1): {avg:.2%}")
+
+
+ if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Run benchmark and evaluation in one go."
+ )
+ parser.add_argument(
+ "--api-url",
+ default="http://127.0.0.1:30000/v1",
+ help="OpenAI‑compatible API base URL",
+ )
+ parser.add_argument(
+ "--model",
+ default="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
+ help="Model name or ID, only used for model name",
+ )
+ parser.add_argument(
+ "--max-concurrency", type=int, default=144, help="Maximum concurrent requests"
+ )
+ parser.add_argument(
+ "--output-dir", default="tmp-output-dir", help="Directory for cached responses"
+ )
+ args = parser.parse_args()
+
+ asyncio.run(benchmark(args))
+
+ analyse(args)
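
fetch_response writes one response_{index}.pkl per example and returns early when that file already exists, so an interrupted run can be restarted with the same --output-dir and only the missing responses are fetched. A small sketch for checking how far a previous run got (the directory name is the script's default):

    from pathlib import Path

    output_dir = Path("tmp-output-dir")
    cached = sorted(int(p.stem.split("_")[1]) for p in output_dir.glob("response_*.pkl"))
    print(f"{len(cached)} responses cached; rerun the script to fetch the rest.")
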