sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/bench_offline_throughput.py CHANGED
@@ -17,11 +17,12 @@ import logging
 import os
 import random
 import time
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional
 
 import numpy as np
 
 from sglang.bench_serving import (
+    DatasetRow,
     get_dataset,
     get_tokenizer,
     sample_random_requests,
@@ -194,7 +195,7 @@ class BenchArgs:
 def throughput_test_once(
     backend_name: str,
     backend,
-    reqs: List[Tuple[str, int, int]],
+    reqs: List[DatasetRow],
     ignore_eos: bool,
     extra_request_body: Dict,
     profile: bool,
@@ -203,7 +204,7 @@
         "backend": backend_name,
         "successful_requests": len(reqs),
         "total_latency": -1,
-        "total_input_tokens": sum(r[1] for r in reqs),
+        "total_input_tokens": sum(r.prompt_len for r in reqs),
         "total_output_tokens": -1,
         "request_throughput": -1,
         "input_throughput": -1,
@@ -211,11 +212,11 @@
         "total_throughput": -1,
     }
 
-    prompt = [r[0] for r in reqs]
+    prompt = [r.prompt for r in reqs]
     sampling_params = [
         {
             "temperature": 0,
-            "max_new_tokens": r[2],
+            "max_new_tokens": r.output_len,
             "ignore_eos": ignore_eos,
             **extra_request_body,
         }
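
Note: the three hunks above migrate throughput_test_once from positional (prompt, prompt_len, output_len) tuples to the DatasetRow records now imported from sglang.bench_serving. Going only by the attributes this diff touches, the record can be pictured as the minimal sketch below (the real definition lives in sglang/bench_serving.py and may carry more fields):

    from dataclasses import dataclass

    @dataclass
    class DatasetRow:
        prompt: str      # prompt text, or token ids when return_text=False
        prompt_len: int  # input token count; summed into total_input_tokens
        output_len: int  # generation budget; passed on as max_new_tokens

    # tuple access r[0], r[1], r[2] becomes r.prompt, r.prompt_len, r.output_len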
@@ -259,13 +260,14 @@
         measurement_results["total_input_tokens"]
         + measurement_results["total_output_tokens"]
     ) / latency
-    measurement_results["last_gen_throughput"] = server_info["last_gen_throughput"]
+    measurement_results["last_gen_throughput"] = server_info["internal_states"][0][
+        "last_gen_throughput"
+    ]
 
     return measurement_results
 
 
 def monitor_trace_file(directory, interval=1):
-
     print(f"Monitoring {directory} for new trace files...")
 
     known_files = set(os.listdir(directory))
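
Note: the hunk above tracks a change in the /get_server_info response: last_gen_throughput is no longer a top-level key but lives inside an internal_states list, of which only index 0 is read here. A minimal client-side sketch, assuming only the keys visible in this diff:

    import requests

    def read_scheduler_state(base_url: str) -> dict:
        # Assumed response shape: {"internal_states": [{"last_gen_throughput": ...}]}
        info = requests.get(base_url + "/get_server_info").json()
        state = info["internal_states"][0]
        return {
            "last_gen_throughput": state["last_gen_throughput"],
            # only present when speculative decoding is enabled
            "avg_spec_accept_length": state.get("avg_spec_accept_length"),
        }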
@@ -315,7 +317,7 @@
     tokenizer_id = server_args.tokenizer_path or server_args.model_path
     tokenizer = get_tokenizer(tokenizer_id)
 
-    # Set global environmnets
+    # Set global environments
     set_ulimit()
     random.seed(bench_args.seed)
     np.random.seed(bench_args.seed)
sglang/bench_one_batch.py CHANGED
@@ -246,7 +246,7 @@ def extend(reqs, model_runner):
     _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
-    logits_output = model_runner.forward(forward_batch)
+    logits_output, _ = model_runner.forward(forward_batch)
     next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits, batch
 
@@ -258,7 +258,7 @@ def decode(input_token_ids, batch, model_runner):
     _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
-    logits_output = model_runner.forward(forward_batch)
+    logits_output, _ = model_runner.forward(forward_batch)
     next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits
 
@@ -269,6 +269,7 @@ def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
         batch,
         dp_size=model_runner.server_args.dp_size,
         attn_tp_size=1,
+        moe_dense_tp_size=model_runner.server_args.moe_dense_tp_size,
         tp_cpu_group=model_runner.tp_group.cpu_group,
         get_idle_batch=None,
         disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
@@ -372,10 +373,10 @@
 
     # Prefill
     synchronize(device)
-    tic = time.time()
+    tic = time.perf_counter()
     next_token_ids, _, batch = extend(reqs, model_runner)
     synchronize(device)
-    prefill_latency = time.time() - tic
+    prefill_latency = time.perf_counter() - tic
     tot_latency += prefill_latency
     throughput = input_len * batch_size / prefill_latency
     rank_print(
@@ -388,10 +389,10 @@
     decode_latencies = []
     for i in range(output_len - 1):
         synchronize(device)
-        tic = time.time()
+        tic = time.perf_counter()
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
         synchronize(device)
-        latency = time.time() - tic
+        latency = time.perf_counter() - tic
         tot_latency += latency
         throughput = batch_size / latency
         decode_latencies.append(latency)
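
Note: every interval measurement in this diff moves from time.time() to time.perf_counter(). perf_counter() is monotonic and has the highest available resolution, so an interval can never come out negative or be skewed by an NTP clock step, whereas time.time() reads the adjustable wall clock. A minimal illustration:

    import time

    def work() -> None:
        sum(i * i for i in range(100_000))  # stand-in workload

    t0 = time.perf_counter()  # monotonic start mark; its absolute value is meaningless
    work()
    print(f"elapsed: {time.perf_counter() - t0:.6f} s")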
sglang/bench_one_batch_server.py CHANGED
@@ -22,9 +22,11 @@ from typing import Tuple
 import numpy as np
 import requests
 
+from sglang.bench_serving import get_tokenizer, sample_random_requests
 from sglang.srt.entrypoints.http_server import launch_server
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import is_in_ci, write_github_step_summary
 
 
 @dataclasses.dataclass
@@ -33,9 +35,13 @@ class BenchArgs:
     batch_size: Tuple[int] = (1,)
     input_len: Tuple[int] = (1024,)
     output_len: Tuple[int] = (16,)
+    temperature: float = 0.0
+    return_logprob: bool = False
+    input_len_step_percentage: float = 0.0
     result_filename: str = "result.jsonl"
     base_url: str = ""
     skip_warmup: bool = False
+    show_report: bool = False
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -49,11 +55,19 @@ class BenchArgs:
         parser.add_argument(
             "--output-len", type=int, nargs="+", default=BenchArgs.output_len
         )
+        parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
+        parser.add_argument("--return-logprob", action="store_true")
+        parser.add_argument(
+            "--input-len-step-percentage",
+            type=float,
+            default=BenchArgs.input_len_step_percentage,
+        )
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
         parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
         parser.add_argument("--skip-warmup", action="store_true")
+        parser.add_argument("--show-report", action="store_true")
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -79,8 +93,8 @@ def launch_server_process(server_args: ServerArgs):
     base_url = f"http://{server_args.host}:{server_args.port}"
     timeout = 600
 
-    start_time = time.time()
-    while time.time() - start_time < timeout:
+    start_time = time.perf_counter()
+    while time.perf_counter() - start_time < timeout:
         try:
             headers = {
                 "Content-Type": "application/json; charset=utf-8",
@@ -99,36 +113,91 @@ run_one_case(
     batch_size: int,
     input_len: int,
     output_len: int,
+    temperature: float,
+    return_logprob: bool,
+    input_len_step_percentage: float,
     run_name: str,
     result_filename: str,
+    tokenizer,
 ):
-    input_ids = [
-        [int(x) for x in np.random.randint(0, high=16384, size=(input_len,))]
-        for _ in range(batch_size)
-    ]
+    requests.post(url + "/flush_cache")
+    input_requests = sample_random_requests(
+        input_len=input_len,
+        output_len=output_len,
+        num_prompts=batch_size,
+        range_ratio=1.0,
+        tokenizer=tokenizer,
+        dataset_path="",
+        random_sample=True,
+        return_text=False,
+    )
+
+    use_structured_outputs = False
+    if use_structured_outputs:
+        texts = []
+        for _ in range(batch_size):
+            texts.append(
+                "Human: What is the capital city of france? can you give as many trivial information as possible about that city? answer in json.\n"
+                * 50
+                + "Assistant:"
+            )
+        json_schema = "$$ANY$$"
+    else:
+        json_schema = None
 
-    tic = time.time()
+    tic = time.perf_counter()
     response = requests.post(
         url + "/generate",
         json={
-            "input_ids": input_ids,
+            "input_ids": [req.prompt for req in input_requests],
             "sampling_params": {
-                "temperature": 0,
+                "temperature": temperature,
                 "max_new_tokens": output_len,
                 "ignore_eos": True,
+                "json_schema": json_schema,
             },
+            "return_logprob": return_logprob,
+            "stream": True,
         },
+        stream=True,
     )
-    latency = time.time() - tic
 
-    _ = response.json()
-    output_throughput = batch_size * output_len / latency
+    # The TTFT of the last request in the batch
+    ttft = 0.0
+    for chunk in response.iter_lines(decode_unicode=False):
+        chunk = chunk.decode("utf-8")
+        if chunk and chunk.startswith("data:"):
+            if chunk == "data: [DONE]":
+                break
+            data = json.loads(chunk[5:].strip("\n"))
+            if "error" in data:
+                raise RuntimeError(f"Request has failed. {data}.")
+
+            assert (
+                data["meta_info"]["finish_reason"] is None
+                or data["meta_info"]["finish_reason"]["type"] == "length"
+            )
+            if data["meta_info"]["completion_tokens"] == 1:
+                ttft = time.perf_counter() - tic
+
+    latency = time.perf_counter() - tic
+    input_throughput = batch_size * input_len / ttft
+    output_throughput = batch_size * output_len / (latency - ttft)
     overall_throughput = batch_size * (input_len + output_len) / latency
 
+    server_info = requests.get(url + "/get_server_info").json()
+    acc_length = server_info["internal_states"][0].get("avg_spec_accept_length", None)
+    last_gen_throughput = server_info["internal_states"][0]["last_gen_throughput"]
+
     print(f"batch size: {batch_size}")
+    print(f"input_len: {input_len}")
+    print(f"output_len: {output_len}")
     print(f"latency: {latency:.2f} s")
-    print(f"output throughput: {output_throughput:.2f} token/s")
-    print(f"(input + output) throughput: {overall_throughput:.2f} token/s")
+    print(f"ttft: {ttft:.2f} s")
+    print(f"Last generation throughput: {last_gen_throughput:.2f} tok/s")
+    print(f"Input throughput: {input_throughput:.2f} tok/s")
+    if output_len != 1:
+        print(f"output throughput: {output_throughput:.2f} tok/s")
 
     if result_filename:
         with open(result_filename, "a") as fout:
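
Note: run_one_case now streams the generation and stamps TTFT when the stream first reports completion_tokens == 1, which lets prefill and decode be measured separately: input_throughput = batch_size * input_len / ttft and output_throughput = batch_size * output_len / (latency - ttft). For example, with batch_size=16, input_len=1024, output_len=16, ttft=0.5 s, and latency=1.5 s, this gives 16 * 1024 / 0.5 = 32768 input tok/s and 16 * 16 / (1.5 - 0.5) = 256 output tok/s.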
@@ -140,9 +209,21 @@ def run_one_case(
                 "latency": round(latency, 4),
                 "output_throughput": round(output_throughput, 2),
                 "overall_throughput": round(overall_throughput, 2),
+                "last_gen_throughput": round(last_gen_throughput, 2),
             }
             fout.write(json.dumps(res) + "\n")
 
+    return (
+        batch_size,
+        latency,
+        ttft,
+        input_throughput,
+        output_throughput,
+        overall_throughput,
+        last_gen_throughput,
+        acc_length,
+    )
+
 
 def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     if bench_args.base_url:
@@ -150,29 +231,45 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     else:
         proc, base_url = launch_server_process(server_args)
 
+    tokenizer_id = server_args.tokenizer_path or server_args.model_path
+    tokenizer = get_tokenizer(tokenizer_id)
+
     # warmup
     if not bench_args.skip_warmup:
+        print("=" * 8 + " Warmup Begin " + "=" * 8)
         run_one_case(
             base_url,
             batch_size=16,
             input_len=1024,
             output_len=16,
+            temperature=bench_args.temperature,
+            return_logprob=bench_args.return_logprob,
+            input_len_step_percentage=bench_args.input_len_step_percentage,
             run_name="",
             result_filename="",
+            tokenizer=tokenizer,
         )
+        print("=" * 8 + " Warmup End " + "=" * 8 + "\n")
 
     # benchmark
+    result = []
     try:
        for bs, il, ol in itertools.product(
             bench_args.batch_size, bench_args.input_len, bench_args.output_len
         ):
-            run_one_case(
-                base_url,
-                bs,
-                il,
-                ol,
-                bench_args.run_name,
-                bench_args.result_filename,
+            result.append(
+                run_one_case(
+                    base_url,
+                    bs,
+                    il,
+                    ol,
+                    temperature=bench_args.temperature,
+                    return_logprob=bench_args.return_logprob,
+                    input_len_step_percentage=bench_args.input_len_step_percentage,
+                    run_name=bench_args.run_name,
+                    result_filename=bench_args.result_filename,
+                    tokenizer=tokenizer,
+                )
             )
     finally:
         if proc:
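
Note: with the new flags wired through run_benchmark, a run that exercises the report path could look like the sketch below (module-style launch inferred from the __main__ guard at the end of this file; the model path is a placeholder, and the remaining flags are the ones defined in add_cli_args above):

    python -m sglang.bench_one_batch_server --model-path <your-model> \
        --batch-size 1 16 64 --input-len 1024 --output-len 16 \
        --temperature 0.0 --show-report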
@@ -180,6 +277,45 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
 
     print(f"\nResults are saved to {bench_args.result_filename}")
 
+    if not bench_args.show_report:
+        return
+
+    summary = " | batch size | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input price ($/1M) | output price ($/1M) |\n"
+    summary += "| ---------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ------------------ | ------------------- |\n"
+
+    for (
+        batch_size,
+        latency,
+        ttft,
+        input_throughput,
+        output_throughput,
+        overall_throughput,
+        last_gen_throughput,
+        acc_length,
+    ) in result:
+        hourly_cost = 2 * server_args.tp_size  # $2/hour for one H100
+        input_util = 0.7
+        accept_length = round(acc_length, 2) if acc_length is not None else "n/a"
+        line = (
+            f"| {batch_size} | "
+            f"{latency:.2f} | "
+            f"{input_throughput:.2f} | "
+            f"{output_throughput:.2f} | "
+            f"{accept_length} | "
+            f"{1 / (output_throughput/batch_size) * 1000:.2f} | "
+            f"{1e6 / (input_throughput * input_util) / 3600 * hourly_cost:.2f} | "
+            f"{1e6 / output_throughput / 3600 * hourly_cost:.2f} |\n"
+        )
+        summary += line
+
+    # print metrics table
+    print(summary)
+
+    if is_in_ci():
+        write_github_step_summary(
+            f"### Test Nightly Benchmark (bench_one_batch) \n{summary}"
+        )
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
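
Note: the report's ITL and price columns amortize an assumed $2/hour-per-GPU rate (hourly_cost = 2 * tp_size) over measured throughput, discounting input throughput by a fixed 0.7 utilization factor. A quick sanity check of the same formulas on illustrative numbers:

    # Illustrative numbers only; the formulas are copied from the report loop above.
    tp_size = 1
    hourly_cost = 2 * tp_size  # $2/hour for one H100 (the script's assumption)
    input_util = 0.7
    batch_size, input_throughput, output_throughput = 16, 32768.0, 256.0

    itl_ms = 1 / (output_throughput / batch_size) * 1000                      # 62.5 ms per token per request
    input_price = 1e6 / (input_throughput * input_util) / 3600 * hourly_cost  # ~$0.024 per 1M input tokens
    output_price = 1e6 / output_throughput / 3600 * hourly_cost               # ~$2.17 per 1M output tokens
    print(f"ITL {itl_ms:.1f} ms | input ${input_price:.3f}/1M | output ${output_price:.2f}/1M")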