sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +5 -4
  3. sglang/bench_one_batch_server.py +23 -15
  4. sglang/bench_serving.py +133 -57
  5. sglang/compile_deep_gemm.py +4 -4
  6. sglang/srt/configs/model_config.py +39 -28
  7. sglang/srt/conversation.py +1 -1
  8. sglang/srt/disaggregation/decode.py +122 -133
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  10. sglang/srt/disaggregation/fake/conn.py +3 -13
  11. sglang/srt/disaggregation/kv_events.py +357 -0
  12. sglang/srt/disaggregation/mini_lb.py +57 -24
  13. sglang/srt/disaggregation/mooncake/conn.py +11 -2
  14. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  15. sglang/srt/disaggregation/nixl/conn.py +9 -19
  16. sglang/srt/disaggregation/prefill.py +126 -44
  17. sglang/srt/disaggregation/utils.py +116 -5
  18. sglang/srt/distributed/utils.py +3 -3
  19. sglang/srt/entrypoints/EngineBase.py +5 -0
  20. sglang/srt/entrypoints/engine.py +28 -8
  21. sglang/srt/entrypoints/http_server.py +6 -4
  22. sglang/srt/entrypoints/http_server_engine.py +5 -2
  23. sglang/srt/function_call/base_format_detector.py +250 -0
  24. sglang/srt/function_call/core_types.py +34 -0
  25. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  26. sglang/srt/function_call/ebnf_composer.py +234 -0
  27. sglang/srt/function_call/function_call_parser.py +175 -0
  28. sglang/srt/function_call/llama32_detector.py +74 -0
  29. sglang/srt/function_call/mistral_detector.py +84 -0
  30. sglang/srt/function_call/pythonic_detector.py +163 -0
  31. sglang/srt/function_call/qwen25_detector.py +67 -0
  32. sglang/srt/function_call/utils.py +35 -0
  33. sglang/srt/hf_transformers_utils.py +46 -7
  34. sglang/srt/layers/attention/aiter_backend.py +513 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +63 -17
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  37. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  38. sglang/srt/layers/attention/triton_backend.py +3 -0
  39. sglang/srt/layers/attention/utils.py +2 -2
  40. sglang/srt/layers/attention/vision.py +1 -1
  41. sglang/srt/layers/communicator.py +451 -0
  42. sglang/srt/layers/dp_attention.py +0 -10
  43. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  44. sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
  45. sglang/srt/layers/moe/ep_moe/layer.py +104 -50
  46. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  48. sglang/srt/layers/moe/topk.py +66 -9
  49. sglang/srt/layers/multimodal.py +70 -0
  50. sglang/srt/layers/quantization/__init__.py +7 -2
  51. sglang/srt/layers/quantization/deep_gemm.py +5 -3
  52. sglang/srt/layers/quantization/fp8.py +90 -0
  53. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  54. sglang/srt/layers/quantization/gptq.py +298 -6
  55. sglang/srt/layers/quantization/int8_kernel.py +18 -5
  56. sglang/srt/layers/quantization/qoq.py +244 -0
  57. sglang/srt/lora/lora_manager.py +1 -3
  58. sglang/srt/managers/deepseek_eplb.py +278 -0
  59. sglang/srt/managers/eplb_manager.py +55 -0
  60. sglang/srt/managers/expert_distribution.py +704 -56
  61. sglang/srt/managers/expert_location.py +394 -0
  62. sglang/srt/managers/expert_location_dispatch.py +91 -0
  63. sglang/srt/managers/io_struct.py +16 -3
  64. sglang/srt/managers/mm_utils.py +293 -139
  65. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  66. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  67. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  68. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  69. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  70. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  71. sglang/srt/managers/multimodal_processors/llava.py +3 -3
  72. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  73. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  74. sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
  75. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  76. sglang/srt/managers/schedule_batch.py +49 -21
  77. sglang/srt/managers/schedule_policy.py +4 -5
  78. sglang/srt/managers/scheduler.py +92 -50
  79. sglang/srt/managers/session_controller.py +1 -1
  80. sglang/srt/managers/tokenizer_manager.py +99 -24
  81. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  82. sglang/srt/mem_cache/chunk_cache.py +3 -1
  83. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  84. sglang/srt/mem_cache/memory_pool.py +74 -52
  85. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  86. sglang/srt/mem_cache/radix_cache.py +58 -5
  87. sglang/srt/metrics/collector.py +2 -2
  88. sglang/srt/mm_utils.py +10 -0
  89. sglang/srt/model_executor/cuda_graph_runner.py +20 -9
  90. sglang/srt/model_executor/expert_location_updater.py +422 -0
  91. sglang/srt/model_executor/forward_batch_info.py +4 -0
  92. sglang/srt/model_executor/model_runner.py +144 -54
  93. sglang/srt/model_loader/loader.py +10 -6
  94. sglang/srt/models/clip.py +5 -1
  95. sglang/srt/models/deepseek_v2.py +297 -343
  96. sglang/srt/models/exaone.py +8 -3
  97. sglang/srt/models/gemma3_mm.py +70 -33
  98. sglang/srt/models/llama4.py +10 -2
  99. sglang/srt/models/llava.py +26 -18
  100. sglang/srt/models/mimo_mtp.py +220 -0
  101. sglang/srt/models/minicpmo.py +5 -12
  102. sglang/srt/models/mistral.py +71 -1
  103. sglang/srt/models/mllama.py +3 -3
  104. sglang/srt/models/qwen2.py +95 -26
  105. sglang/srt/models/qwen2_5_vl.py +8 -0
  106. sglang/srt/models/qwen2_moe.py +330 -60
  107. sglang/srt/models/qwen2_vl.py +6 -0
  108. sglang/srt/models/qwen3.py +52 -10
  109. sglang/srt/models/qwen3_moe.py +411 -48
  110. sglang/srt/models/siglip.py +294 -0
  111. sglang/srt/openai_api/adapter.py +28 -16
  112. sglang/srt/openai_api/protocol.py +6 -0
  113. sglang/srt/operations.py +154 -0
  114. sglang/srt/operations_strategy.py +31 -0
  115. sglang/srt/server_args.py +134 -24
  116. sglang/srt/speculative/eagle_utils.py +131 -0
  117. sglang/srt/speculative/eagle_worker.py +47 -2
  118. sglang/srt/utils.py +68 -12
  119. sglang/test/test_cutlass_moe.py +278 -0
  120. sglang/test/test_utils.py +2 -36
  121. sglang/utils.py +2 -2
  122. sglang/version.py +1 -1
  123. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
  124. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
  125. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  126. sglang/srt/function_call_parser.py +0 -858
  127. sglang/srt/platforms/interface.py +0 -371
  128. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  129. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/bench_offline_throughput.py CHANGED
@@ -17,11 +17,12 @@ import logging
 import os
 import random
 import time
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional

 import numpy as np

 from sglang.bench_serving import (
+    DatasetRow,
     get_dataset,
     get_tokenizer,
     sample_random_requests,
@@ -194,7 +195,7 @@ class BenchArgs:
 def throughput_test_once(
     backend_name: str,
     backend,
-    reqs: List[Tuple[str, int, int]],
+    reqs: List[DatasetRow],
     ignore_eos: bool,
     extra_request_body: Dict,
     profile: bool,
@@ -203,7 +204,7 @@ def throughput_test_once(
         "backend": backend_name,
         "successful_requests": len(reqs),
         "total_latency": -1,
-        "total_input_tokens": sum(r[1] for r in reqs),
+        "total_input_tokens": sum(r.prompt_len for r in reqs),
         "total_output_tokens": -1,
         "request_throughput": -1,
         "input_throughput": -1,
@@ -211,11 +212,11 @@ def throughput_test_once(
         "total_throughput": -1,
     }

-    prompt = [r[0] for r in reqs]
+    prompt = [r.prompt for r in reqs]
     sampling_params = [
         {
             "temperature": 0,
-            "max_new_tokens": r[2],
+            "max_new_tokens": r.output_len,
             "ignore_eos": ignore_eos,
             **extra_request_body,
         }
@@ -267,7 +268,6 @@ def throughput_test_once(


 def monitor_trace_file(directory, interval=1):
-
     print(f"Monitoring {directory} for new trace files...")

     known_files = set(os.listdir(directory))
sglang/bench_one_batch.py CHANGED
@@ -269,6 +269,7 @@ def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
             batch,
             dp_size=model_runner.server_args.dp_size,
             attn_tp_size=1,
+            moe_dense_tp_size=model_runner.server_args.moe_dense_tp_size,
             tp_cpu_group=model_runner.tp_group.cpu_group,
             get_idle_batch=None,
             disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
@@ -372,10 +373,10 @@ def latency_test_run_once(

     # Prefill
     synchronize(device)
-    tic = time.time()
+    tic = time.perf_counter()
     next_token_ids, _, batch = extend(reqs, model_runner)
     synchronize(device)
-    prefill_latency = time.time() - tic
+    prefill_latency = time.perf_counter() - tic
     tot_latency += prefill_latency
     throughput = input_len * batch_size / prefill_latency
     rank_print(
@@ -388,10 +389,10 @@ def latency_test_run_once(
     decode_latencies = []
     for i in range(output_len - 1):
         synchronize(device)
-        tic = time.time()
+        tic = time.perf_counter()
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
         synchronize(device)
-        latency = time.time() - tic
+        latency = time.perf_counter() - tic
         tot_latency += latency
         throughput = batch_size / latency
         decode_latencies.append(latency)
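
Note: the time.time() → time.perf_counter() swaps above (and the matching ones in the other benchmark scripts below) move latency measurement onto a monotonic, high-resolution clock that is not affected by system clock adjustments. A minimal sketch of the timing pattern, with do_work() as a hypothetical stand-in for the timed call:

    import time

    def do_work() -> None:
        # placeholder workload; the benchmarks time extend/decode calls instead
        time.sleep(0.01)

    start = time.perf_counter()  # monotonic, high-resolution start mark
    do_work()
    elapsed = time.perf_counter() - start
    print(f"latency: {elapsed:.6f} s")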
sglang/bench_one_batch_server.py CHANGED
@@ -22,6 +22,7 @@ from typing import Tuple
 import numpy as np
 import requests

+from sglang.bench_serving import get_tokenizer, sample_random_requests
 from sglang.srt.entrypoints.http_server import launch_server
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import kill_process_tree
@@ -92,8 +93,8 @@ def launch_server_process(server_args: ServerArgs):
     base_url = f"http://{server_args.host}:{server_args.port}"
     timeout = 600

-    start_time = time.time()
-    while time.time() - start_time < timeout:
+    start_time = time.perf_counter()
+    while time.perf_counter() - start_time < timeout:
         try:
             headers = {
                 "Content-Type": "application/json; charset=utf-8",
@@ -117,16 +118,19 @@ def run_one_case(
     input_len_step_percentage: float,
     run_name: str,
     result_filename: str,
+    tokenizer,
 ):
     requests.post(url + "/flush_cache")
-    input_lens = [
-        int(input_len * (1 + (i - (batch_size - 1) / 2) * input_len_step_percentage))
-        for i in range(batch_size)
-    ]
-    input_ids = [
-        [int(x) for x in np.random.randint(0, high=16384, size=(input_lens[i],))]
-        for i in range(batch_size)
-    ]
+    input_requests = sample_random_requests(
+        input_len=input_len,
+        output_len=output_len,
+        num_prompts=batch_size,
+        range_ratio=1.0,
+        tokenizer=tokenizer,
+        dataset_path="",
+        random_sample=True,
+        return_text=False,
+    )

     use_structured_outputs = False
     if use_structured_outputs:
@@ -141,12 +145,11 @@ def run_one_case(
     else:
         json_schema = None

-    tic = time.time()
+    tic = time.perf_counter()
     response = requests.post(
         url + "/generate",
         json={
-            # "text": texts,
-            "input_ids": input_ids,
+            "input_ids": [req.prompt for req in input_requests],
             "sampling_params": {
                 "temperature": temperature,
                 "max_new_tokens": output_len,
@@ -175,9 +178,9 @@ def run_one_case(
                 or data["meta_info"]["finish_reason"]["type"] == "length"
             )
             if data["meta_info"]["completion_tokens"] == 1:
-                ttft = time.time() - tic
+                ttft = time.perf_counter() - tic

-    latency = time.time() - tic
+    latency = time.perf_counter() - tic
     input_throughput = batch_size * input_len / ttft
     output_throughput = batch_size * output_len / (latency - ttft)
     overall_throughput = batch_size * (input_len + output_len) / latency
@@ -228,6 +231,9 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     else:
         proc, base_url = launch_server_process(server_args)

+    tokenizer_id = server_args.tokenizer_path or server_args.model_path
+    tokenizer = get_tokenizer(tokenizer_id)
+
     # warmup
     if not bench_args.skip_warmup:
         print("=" * 8 + " Warmup Begin " + "=" * 8)
@@ -241,6 +247,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
             input_len_step_percentage=bench_args.input_len_step_percentage,
             run_name="",
             result_filename="",
+            tokenizer=tokenizer,
         )
         print("=" * 8 + " Warmup End " + "=" * 8 + "\n")

@@ -261,6 +268,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                 input_len_step_percentage=bench_args.input_len_step_percentage,
                 run_name=bench_args.run_name,
                 result_filename=bench_args.result_filename,
+                tokenizer=tokenizer,
             )
         )
     finally:
sglang/bench_serving.py CHANGED
@@ -24,6 +24,7 @@ import warnings
 from argparse import ArgumentParser
 from dataclasses import dataclass, field
 from datetime import datetime
+from json import JSONDecodeError
 from pathlib import Path
 from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

@@ -73,6 +74,12 @@ class RequestFuncOutput:
     error: str = ""
     output_len: int = 0

+    @staticmethod
+    def init_new(request_func_input: RequestFuncInput):
+        output = RequestFuncOutput()
+        output.prompt_len = request_func_input.prompt_len
+        return output
+

 def remove_prefix(text: str, prefix: str) -> str:
     return text[len(prefix) :] if text.startswith(prefix) else text
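
The RequestFuncOutput.init_new helper added above centralizes the output-object setup that each async_request_* function previously repeated. A before/after sketch, assuming the RequestFuncInput/RequestFuncOutput dataclasses from this file:

    # Before: repeated in every backend-specific request function
    output = RequestFuncOutput()
    output.prompt_len = request_func_input.prompt_len

    # After: a single factory call
    output = RequestFuncOutput.init_new(request_func_input)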
@@ -114,8 +121,7 @@ async def async_request_trt_llm(
     if args.disable_ignore_eos:
         del payload["min_length"]
         del payload["end_id"]
-    output = RequestFuncOutput()
-    output.prompt_len = request_func_input.prompt_len
+    output = RequestFuncOutput.init_new(request_func_input)

     ttft = 0.0
     st = time.perf_counter()
@@ -186,8 +192,7 @@ async def async_request_openai_completions(
     }
     headers = get_auth_headers()

-    output = RequestFuncOutput()
-    output.prompt_len = request_func_input.prompt_len
+    output = RequestFuncOutput.init_new(request_func_input)

     generated_text = ""
     output_len = request_func_input.output_len
@@ -269,8 +274,7 @@ async def async_request_truss(
     }
     headers = get_auth_headers()

-    output = RequestFuncOutput()
-    output.prompt_len = request_func_input.prompt_len
+    output = RequestFuncOutput.init_new(request_func_input)

     generated_text = ""
     ttft = 0.0
@@ -355,8 +359,7 @@ async def async_request_sglang_generate(

     headers = get_auth_headers()

-    output = RequestFuncOutput()
-    output.prompt_len = request_func_input.prompt_len
+    output = RequestFuncOutput.init_new(request_func_input)

     generated_text = ""
     output_len = request_func_input.output_len
@@ -469,6 +472,10 @@ def get_model(pretrained_model_name_or_path: str) -> str:
 def get_tokenizer(
     pretrained_model_name_or_path: str,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+    assert (
+        pretrained_model_name_or_path is not None
+        and pretrained_model_name_or_path != ""
+    )
     if pretrained_model_name_or_path.endswith(
         ".json"
     ) or pretrained_model_name_or_path.endswith(".model"):
@@ -582,7 +589,7 @@ def download_and_cache_file(url: str, filename: Optional[str] = None):
         filename = os.path.join("/tmp", url.split("/")[-1])

     # Check if the cache file already exists
-    if os.path.exists(filename):
+    if is_file_valid_json(filename):
         return filename

     print(f"Downloading from {url} to {filename}")
@@ -610,12 +617,35 @@ def download_and_cache_file(url: str, filename: Optional[str] = None):
     return filename


+def is_file_valid_json(path):
+    if not os.path.isfile(path):
+        return False
+
+    # TODO can fuse into the real file open later
+    try:
+        with open(path) as f:
+            json.load(f)
+        return True
+    except JSONDecodeError as e:
+        print(
+            f"{path} exists but json loading fails ({e=}), thus treat as invalid file"
+        )
+        return False
+
+
+@dataclass
+class DatasetRow:
+    prompt: str
+    prompt_len: int
+    output_len: int
+
+
 def sample_mmmu_requests(
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
     fixed_output_len: Optional[int] = None,
     random_sample: bool = True,
-) -> List[Tuple[str, int, int]]:
+) -> List[DatasetRow]:
     """
     Sample requests from the MMMU dataset using HuggingFace datasets.

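
The DatasetRow dataclass introduced above replaces the bare (prompt, prompt_len, output_len) tuples that the sampling helpers returned and the benchmarks indexed positionally. A small illustrative sketch, assuming the DatasetRow from this file (values are made up):

    row = DatasetRow(prompt="What is 2 + 2?", prompt_len=7, output_len=16)

    # Call sites move from positional indexing to named fields, e.g.
    # sum(r[1] for r in reqs)  ->  sum(r.prompt_len for r in reqs)
    total_input_tokens = sum(r.prompt_len for r in [row])
    max_new_tokens = row.output_len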
@@ -716,7 +746,11 @@ def sample_mmmu_requests(

             output_len = fixed_output_len if fixed_output_len is not None else 256

-            filtered_dataset.append((prompt, prompt_len, output_len))
+            filtered_dataset.append(
+                DatasetRow(
+                    prompt=prompt, prompt_len=prompt_len, output_len=output_len
+                )
+            )

         except Exception as e:
             print(f"Error processing example {i}: {e}")
@@ -733,12 +767,12 @@ def sample_sharegpt_requests(
     context_len: Optional[int] = None,
     prompt_suffix: Optional[str] = "",
     apply_chat_template=False,
-) -> List[Tuple[str, int, int]]:
+) -> List[DatasetRow]:
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")

     # Download sharegpt if necessary
-    if not os.path.isfile(dataset_path) and dataset_path == "":
+    if not is_file_valid_json(dataset_path) and dataset_path == "":
         dataset_path = download_and_cache_file(SHAREGPT_URL)

     # Load the dataset.
@@ -764,7 +798,7 @@ def sample_sharegpt_requests(
     random.shuffle(dataset)

     # Filter out sequences that are too long or too short
-    filtered_dataset: List[Tuple[str, int, int]] = []
+    filtered_dataset: List[DatasetRow] = []
     for i in range(len(dataset)):
         if len(filtered_dataset) == num_requests:
             break
@@ -802,10 +836,12 @@ def sample_sharegpt_requests(
             # Prune too long sequences.
             continue

-        filtered_dataset.append((prompt, prompt_len, output_len))
+        filtered_dataset.append(
+            DatasetRow(prompt=prompt, prompt_len=prompt_len, output_len=output_len)
+        )

-    print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
-    print(f"#Output tokens: {np.sum([x[2] for x in filtered_dataset])}")
+    print(f"#Input tokens: {np.sum([x.prompt_len for x in filtered_dataset])}")
+    print(f"#Output tokens: {np.sum([x.output_len for x in filtered_dataset])}")
     return filtered_dataset


@@ -817,7 +853,8 @@ def sample_random_requests(
     tokenizer: PreTrainedTokenizerBase,
     dataset_path: str,
     random_sample: bool = True,
-) -> List[Tuple[str, int, int]]:
+    return_text: bool = True,
+) -> List[DatasetRow]:
     input_lens = np.random.randint(
         max(int(input_len * range_ratio), 1),
         input_len + 1,
@@ -833,7 +870,7 @@ def sample_random_requests(
     # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens

     # Download sharegpt if necessary
-    if not os.path.isfile(dataset_path):
+    if not is_file_valid_json(dataset_path):
         dataset_path = download_and_cache_file(SHAREGPT_URL)

     # Load the dataset.
@@ -857,7 +894,7 @@ def sample_random_requests(
     random.shuffle(dataset)

     # Filter out sequences that are too long or too short
-    input_requests: List[Tuple[str, int, int]] = []
+    input_requests: List[DatasetRow] = []
     for data in dataset:
         i = len(input_requests)
         if i == num_prompts:
@@ -877,20 +914,34 @@ def sample_random_requests(
             else:
                 ratio = (input_lens[i] + prompt_len - 1) // prompt_len
                 input_ids = (prompt_token_ids * ratio)[: input_lens[i]]
-            prompt = tokenizer.decode(input_ids)
-            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
+            input_content = input_ids
+            if return_text:
+                input_content = tokenizer.decode(input_content)
+            input_requests.append(
+                DatasetRow(
+                    prompt=input_content,
+                    prompt_len=int(input_lens[i]),
+                    output_len=int(output_lens[i]),
+                )
+            )
     else:
         # Sample token ids from random integers. This can cause some NaN issues.
         offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
         input_requests = []
         for i in range(num_prompts):
-            prompt = tokenizer.decode(
-                [
-                    (offsets[i] + i + j) % tokenizer.vocab_size
-                    for j in range(input_lens[i])
-                ]
+            input_content = [
+                (offsets[i] + i + j) % tokenizer.vocab_size
+                for j in range(input_lens[i])
+            ]
+            if return_text:
+                input_content = tokenizer.decode(input_content)
+            input_requests.append(
+                DatasetRow(
+                    prompt=input_content,
+                    prompt_len=int(input_lens[i]),
+                    output_len=int(output_lens[i]),
+                )
             )
-            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))

     print(f"#Input tokens: {np.sum(input_lens)}")
     print(f"#Output tokens: {np.sum(output_lens)}")
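
The new return_text parameter lets sample_random_requests skip detokenization and return raw token ids in DatasetRow.prompt; the updated bench_one_batch_server.py uses this to post input_ids directly to /generate. A usage sketch, assuming a tokenizer obtained from get_tokenizer as elsewhere in this file:

    rows = sample_random_requests(
        input_len=128,
        output_len=32,
        num_prompts=4,
        range_ratio=1.0,
        tokenizer=tokenizer,
        dataset_path="",
        random_sample=True,
        return_text=False,  # prompt holds token ids instead of decoded text
    )
    payload = {"input_ids": [row.prompt for row in rows]}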
@@ -925,7 +976,7 @@ def sample_generated_shared_prefix_requests(
     output_len: int,
     tokenizer: PreTrainedTokenizerBase,
     args: argparse.Namespace,
-) -> List[Tuple[str, int, int]]:
+) -> List[DatasetRow]:
     """Generate benchmark requests with shared system prompts using random tokens and caching."""
     cache_path = get_gen_prefix_cache_path(args, tokenizer)

@@ -963,7 +1014,11 @@ def sample_generated_shared_prefix_requests(
             full_prompt = f"{system_prompt}\n\n{question}"
             prompt_len = len(tokenizer.encode(full_prompt))

-            input_requests.append((full_prompt, prompt_len, output_len))
+            input_requests.append(
+                DatasetRow(
+                    prompt=full_prompt, prompt_len=prompt_len, output_len=output_len
+                )
+            )
             total_input_tokens += prompt_len
             total_output_tokens += output_len

@@ -994,9 +1049,9 @@ def sample_generated_shared_prefix_requests(


 async def get_request(
-    input_requests: List[Tuple[str, int, int]],
+    input_requests: List[DatasetRow],
     request_rate: float,
-) -> AsyncGenerator[Tuple[str, int, int], None]:
+) -> AsyncGenerator[DatasetRow, None]:
     input_requests = iter(input_requests)
     for request in input_requests:
         yield request
@@ -1012,7 +1067,7 @@ async def get_request(


 def calculate_metrics(
-    input_requests: List[Tuple[str, int, int]],
+    input_requests: List[DatasetRow],
     outputs: List[RequestFuncOutput],
     dur_s: float,
     tokenizer: PreTrainedTokenizerBase,
@@ -1034,7 +1089,7 @@ def calculate_metrics(
             tokenizer.encode(outputs[i].generated_text, add_special_tokens=False)
         )
         retokenized_output_lens.append(retokenized_output_len)
-        total_input += input_requests[i][1]
+        total_input += input_requests[i].prompt_len
         if output_len > 1:
             tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
             itls += outputs[i].itl
@@ -1096,7 +1151,7 @@ async def benchmark(
     base_url: str,
     model_id: str,
     tokenizer: PreTrainedTokenizerBase,
-    input_requests: List[Tuple[str, int, int]],
+    input_requests: List[DatasetRow],
     request_rate: float,
     max_concurrency: Optional[int],
     disable_tqdm: bool,
@@ -1126,7 +1181,12 @@ async def benchmark(
     print(f"Starting warmup with {warmup_requests} sequences...")

     # Use the first request for all warmup iterations
-    test_prompt, test_prompt_len, test_output_len = input_requests[0]
+    test_request = input_requests[0]
+    test_prompt, test_prompt_len, test_output_len = (
+        test_request.prompt,
+        test_request.prompt_len,
+        test_request.output_len,
+    )
     if lora_names is not None and len(lora_names) != 0:
         lora_name = lora_names[0]
     else:
@@ -1194,7 +1254,11 @@ async def benchmark(
     benchmark_start_time = time.perf_counter()
     tasks: List[asyncio.Task] = []
     async for request in get_request(input_requests, request_rate):
-        prompt, prompt_len, output_len = request
+        prompt, prompt_len, output_len = (
+            request.prompt,
+            request.prompt_len,
+            request.output_len,
+        )
         if lora_names is not None and len(lora_names) != 0:
             idx = random.randint(0, len(lora_names) - 1)
             lora_name = lora_names[idx]
@@ -1239,14 +1303,17 @@ async def benchmark(

     if "sglang" in backend:
         server_info = requests.get(base_url + "/get_server_info")
-        if pd_separated:
-            accept_length = server_info.json()["decode"][0]["internal_states"][0].get(
-                "avg_spec_accept_length", None
-            )
+        if server_info.status_code == 200:
+            if pd_separated:
+                accept_length = server_info.json()["decode"][0]["internal_states"][
+                    0
+                ].get("avg_spec_accept_length", None)
+            else:
+                accept_length = server_info.json()["internal_states"][0].get(
+                    "avg_spec_accept_length", None
+                )
         else:
-            accept_length = server_info.json()["internal_states"][0].get(
-                "avg_spec_accept_length", None
-            )
+            accept_length = None
     else:
         accept_length = None

@@ -1380,21 +1447,24 @@ async def benchmark(
     else:
         output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl"

+    result_details = {
+        "input_lens": [output.prompt_len for output in outputs],
+        "output_lens": output_lens,
+        "ttfts": [output.ttft for output in outputs],
+        "itls": [output.itl for output in outputs],
+        "generated_texts": [output.generated_text for output in outputs],
+        "errors": [output.error for output in outputs],
+    }
+
     # Append results to a JSONL file
     with open(output_file_name, "a") as file:
-        file.write(json.dumps(result) + "\n")
-
-    result.update(
-        {
-            "input_lens": [output.prompt_len for output in outputs],
-            "output_lens": output_lens,
-            "ttfts": [output.ttft for output in outputs],
-            "itls": [output.itl for output in outputs],
-            "generated_texts": [output.generated_text for output in outputs],
-            "errors": [output.error for output in outputs],
-        }
-    )
-    return result
+        if args.output_details:
+            result_for_dump = result | result_details
+        else:
+            result_for_dump = result
+        file.write(json.dumps(result_for_dump) + "\n")
+
+    return result | result_details


 def check_chat_template(model_path):
@@ -1424,6 +1494,9 @@ def run_benchmark(args_: argparse.Namespace):
     if not hasattr(args, "warmup_requests"):
         args.warmup_requests = 1

+    if not hasattr(args, "output_details"):
+        args.output_details = False
+
     print(f"benchmark_args={args}")

     # Set global environments
@@ -1668,6 +1741,9 @@ if __name__ == "__main__":
         "if the server is not processing requests fast enough to keep up.",
     )
     parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
+    parser.add_argument(
+        "--output-details", action="store_true", help="Output details of benchmarking."
+    )
     parser.add_argument(
         "--disable-tqdm",
         action="store_true",
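
With the new --output-details flag, the per-request arrays collected in result_details (input_lens, ttfts, itls, generated_texts, errors) are written to the output JSONL only when explicitly requested, keeping the default file small; the in-memory result returned by benchmark() always includes them. A small sketch for reading such a file afterwards ("benchmark.jsonl" is a placeholder for the --output-file path):

    import json

    with open("benchmark.jsonl") as f:
        for line in f:
            record = json.loads(line)
            # per-request fields are present only when --output-details was passed
            if "ttfts" in record:
                print(len(record["ttfts"]), "TTFT samples in this run")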
sglang/compile_deep_gemm.py CHANGED
@@ -82,8 +82,8 @@ def launch_server_process_and_send_one_request(
     base_url = f"http://{server_args.host}:{server_args.port}"
     timeout = compile_args.timeout

-    start_time = time.time()
-    while time.time() - start_time < timeout:
+    start_time = time.perf_counter()
+    while time.perf_counter() - start_time < timeout:
         try:
             headers = {
                 "Content-Type": "application/json; charset=utf-8",
@@ -112,9 +112,9 @@ def launch_server_process_and_send_one_request(
            raise RuntimeError(f"Sync request failed: {error}")
    # Other nodes should wait for the exit signal from Rank-0 node.
    else:
-        start_time_waiting = time.time()
+        start_time_waiting = time.perf_counter()
         while proc.is_alive():
-            if time.time() - start_time_waiting < timeout:
+            if time.perf_counter() - start_time_waiting < timeout:
                 time.sleep(10)
             else:
                 raise TimeoutError("Waiting for main node timeout!")