sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +3 -13
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +158 -8
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +119 -75
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +5 -2
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/internvl.py +696 -0
  13. sglang/srt/configs/janus_pro.py +3 -0
  14. sglang/srt/configs/model_config.py +18 -0
  15. sglang/srt/constrained/base_grammar_backend.py +55 -72
  16. sglang/srt/constrained/llguidance_backend.py +25 -21
  17. sglang/srt/constrained/outlines_backend.py +27 -26
  18. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  19. sglang/srt/constrained/xgrammar_backend.py +71 -53
  20. sglang/srt/conversation.py +78 -46
  21. sglang/srt/disaggregation/base/conn.py +1 -0
  22. sglang/srt/disaggregation/decode.py +11 -3
  23. sglang/srt/disaggregation/fake/conn.py +1 -1
  24. sglang/srt/disaggregation/mini_lb.py +74 -23
  25. sglang/srt/disaggregation/mooncake/conn.py +236 -138
  26. sglang/srt/disaggregation/nixl/conn.py +242 -71
  27. sglang/srt/disaggregation/prefill.py +7 -4
  28. sglang/srt/disaggregation/utils.py +51 -2
  29. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  30. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  31. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  32. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  33. sglang/srt/distributed/parallel_state.py +22 -1
  34. sglang/srt/entrypoints/engine.py +31 -4
  35. sglang/srt/entrypoints/http_server.py +45 -3
  36. sglang/srt/entrypoints/verl_engine.py +3 -2
  37. sglang/srt/function_call_parser.py +2 -2
  38. sglang/srt/hf_transformers_utils.py +20 -1
  39. sglang/srt/layers/attention/flashattention_backend.py +147 -51
  40. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  41. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  42. sglang/srt/layers/attention/merge_state.py +46 -0
  43. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  44. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  45. sglang/srt/layers/attention/utils.py +4 -2
  46. sglang/srt/layers/attention/vision.py +290 -163
  47. sglang/srt/layers/dp_attention.py +71 -21
  48. sglang/srt/layers/layernorm.py +1 -1
  49. sglang/srt/layers/logits_processor.py +46 -11
  50. sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
  51. sglang/srt/layers/moe/ep_moe/layer.py +121 -2
  52. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  56. sglang/srt/layers/moe/topk.py +1 -1
  57. sglang/srt/layers/quantization/__init__.py +1 -1
  58. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  59. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  60. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  61. sglang/srt/layers/quantization/deep_gemm.py +77 -71
  62. sglang/srt/layers/quantization/fp8.py +110 -97
  63. sglang/srt/layers/quantization/fp8_kernel.py +81 -62
  64. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  65. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  66. sglang/srt/layers/quantization/kv_cache.py +3 -10
  67. sglang/srt/layers/quantization/utils.py +0 -5
  68. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  69. sglang/srt/layers/sampler.py +0 -4
  70. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  71. sglang/srt/lora/lora_manager.py +11 -14
  72. sglang/srt/lora/mem_pool.py +4 -4
  73. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  74. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  75. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  76. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  77. sglang/srt/lora/utils.py +1 -1
  78. sglang/srt/managers/cache_controller.py +115 -119
  79. sglang/srt/managers/data_parallel_controller.py +3 -3
  80. sglang/srt/managers/detokenizer_manager.py +21 -8
  81. sglang/srt/managers/io_struct.py +13 -1
  82. sglang/srt/managers/mm_utils.py +1 -1
  83. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  84. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  85. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  86. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  87. sglang/srt/managers/schedule_batch.py +93 -23
  88. sglang/srt/managers/schedule_policy.py +11 -8
  89. sglang/srt/managers/scheduler.py +140 -100
  90. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  91. sglang/srt/managers/tokenizer_manager.py +157 -47
  92. sglang/srt/managers/tp_worker.py +21 -21
  93. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  94. sglang/srt/mem_cache/chunk_cache.py +2 -0
  95. sglang/srt/mem_cache/memory_pool.py +4 -2
  96. sglang/srt/metrics/collector.py +312 -37
  97. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  98. sglang/srt/model_executor/forward_batch_info.py +1 -1
  99. sglang/srt/model_executor/model_runner.py +57 -41
  100. sglang/srt/model_loader/loader.py +18 -11
  101. sglang/srt/models/clip.py +4 -4
  102. sglang/srt/models/deepseek_janus_pro.py +3 -3
  103. sglang/srt/models/deepseek_nextn.py +1 -20
  104. sglang/srt/models/deepseek_v2.py +77 -39
  105. sglang/srt/models/gemma3_mm.py +1 -1
  106. sglang/srt/models/internlm2.py +3 -0
  107. sglang/srt/models/internvl.py +670 -0
  108. sglang/srt/models/llama.py +3 -1
  109. sglang/srt/models/llama4.py +58 -13
  110. sglang/srt/models/llava.py +248 -5
  111. sglang/srt/models/minicpmv.py +1 -1
  112. sglang/srt/models/mixtral.py +98 -34
  113. sglang/srt/models/mllama.py +1 -1
  114. sglang/srt/models/phi3_small.py +16 -2
  115. sglang/srt/models/pixtral.py +467 -0
  116. sglang/srt/models/qwen2_5_vl.py +8 -4
  117. sglang/srt/models/qwen2_vl.py +4 -4
  118. sglang/srt/models/roberta.py +1 -1
  119. sglang/srt/models/torch_native_llama.py +1 -1
  120. sglang/srt/models/xiaomi_mimo.py +171 -0
  121. sglang/srt/openai_api/adapter.py +52 -42
  122. sglang/srt/openai_api/protocol.py +20 -16
  123. sglang/srt/reasoning_parser.py +1 -1
  124. sglang/srt/sampling/custom_logit_processor.py +18 -3
  125. sglang/srt/sampling/sampling_batch_info.py +2 -2
  126. sglang/srt/sampling/sampling_params.py +2 -0
  127. sglang/srt/server_args.py +64 -10
  128. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  129. sglang/srt/speculative/eagle_utils.py +7 -7
  130. sglang/srt/speculative/eagle_worker.py +22 -19
  131. sglang/srt/utils.py +41 -6
  132. sglang/test/few_shot_gsm8k.py +2 -2
  133. sglang/test/few_shot_gsm8k_engine.py +2 -2
  134. sglang/test/run_eval.py +2 -2
  135. sglang/test/runners.py +8 -1
  136. sglang/test/send_one.py +13 -3
  137. sglang/test/simple_eval_common.py +1 -1
  138. sglang/test/simple_eval_humaneval.py +1 -1
  139. sglang/test/test_block_fp8.py +2 -2
  140. sglang/test/test_deepep_utils.py +219 -0
  141. sglang/test/test_programs.py +5 -5
  142. sglang/test/test_utils.py +92 -15
  143. sglang/utils.py +1 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
  146. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
  147. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
  148. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  149. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/bench_offline_throughput.py CHANGED
@@ -259,7 +259,9 @@ def throughput_test_once(
         measurement_results["total_input_tokens"]
         + measurement_results["total_output_tokens"]
     ) / latency
-    measurement_results["last_gen_throughput"] = server_info["last_gen_throughput"]
+    measurement_results["last_gen_throughput"] = server_info["internal_states"][0][
+        "last_gen_throughput"
+    ]

     return measurement_results

@@ -315,7 +317,7 @@ def throughput_test(
     tokenizer_id = server_args.tokenizer_path or server_args.model_path
     tokenizer = get_tokenizer(tokenizer_id)

-    # Set global environmnets
+    # Set global environments
     set_ulimit()
     random.seed(bench_args.seed)
     np.random.seed(bench_args.seed)
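Note on the hunk above: /get_server_info no longer exposes last_gen_throughput at the top level; per-scheduler values are nested under internal_states. A minimal client-side sketch of reading the new layout (only the URL path and the two field names come from the diff; the helper itself is illustrative):

import requests

def get_last_gen_throughput(base_url: str) -> float:
    # Per-scheduler metrics moved from the top level of the response
    # into the internal_states list; index 0 is the first scheduler.
    server_info = requests.get(base_url + "/get_server_info").json()
    return server_info["internal_states"][0]["last_gen_throughput"]

# Usage (assuming a locally running server):
# print(get_last_gen_throughput("http://127.0.0.1:30000"))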
sglang/bench_one_batch.py CHANGED
@@ -137,17 +137,7 @@ def load_model(server_args, port_args, tp_rank):
     suppress_other_loggers()
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

-    model_config = ModelConfig(
-        server_args.model_path,
-        trust_remote_code=server_args.trust_remote_code,
-        revision=server_args.revision,
-        context_length=server_args.context_length,
-        model_override_args=server_args.json_model_override_args,
-        is_embedding=server_args.is_embedding,
-        enable_multimodal=server_args.enable_multimodal,
-        dtype=server_args.dtype,
-        quantization=server_args.quantization,
-    )
+    model_config = ModelConfig.from_server_args(server_args)
     model_runner = ModelRunner(
         model_config=model_config,
         mem_fraction_static=server_args.mem_fraction_static,
@@ -256,7 +246,7 @@ def extend(reqs, model_runner):
     _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
-    logits_output = model_runner.forward(forward_batch)
+    logits_output, _ = model_runner.forward(forward_batch)
     next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits, batch

@@ -268,7 +258,7 @@ def decode(input_token_ids, batch, model_runner):
     _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
-    logits_output = model_runner.forward(forward_batch)
+    logits_output, _ = model_runner.forward(forward_batch)
     next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits

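The load_model hunk above replaces the field-by-field ModelConfig construction with a single ModelConfig.from_server_args(server_args) call. A sketch of what such a classmethod could look like, inferred only from the keyword arguments visible in the removed lines (illustrative, not the actual sglang implementation):

from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelConfigSketch:
    # Hypothetical stand-in for sglang.srt.configs.model_config.ModelConfig.
    model_path: str
    trust_remote_code: bool = False
    revision: Optional[str] = None
    context_length: Optional[int] = None
    model_override_args: Optional[str] = None
    is_embedding: bool = False
    enable_multimodal: Optional[bool] = None
    dtype: str = "auto"
    quantization: Optional[str] = None

    @classmethod
    def from_server_args(cls, server_args) -> "ModelConfigSketch":
        # Pulls the same fields the old call site passed explicitly.
        return cls(
            model_path=server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            enable_multimodal=server_args.enable_multimodal,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )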
sglang/bench_one_batch_server.py CHANGED
@@ -25,6 +25,7 @@ import requests
 from sglang.srt.entrypoints.http_server import launch_server
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import is_in_ci, write_github_step_summary


 @dataclasses.dataclass
@@ -33,9 +34,13 @@ class BenchArgs:
     batch_size: Tuple[int] = (1,)
     input_len: Tuple[int] = (1024,)
     output_len: Tuple[int] = (16,)
+    temperature: float = 0.0
+    return_logprob: bool = False
+    input_len_step_percentage: float = 0.0
     result_filename: str = "result.jsonl"
     base_url: str = ""
     skip_warmup: bool = False
+    show_report: bool = False

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -49,11 +54,19 @@ class BenchArgs:
         parser.add_argument(
             "--output-len", type=int, nargs="+", default=BenchArgs.output_len
         )
+        parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
+        parser.add_argument("--return-logprob", action="store_true")
+        parser.add_argument(
+            "--input-len-step-percentage",
+            type=float,
+            default=BenchArgs.input_len_step_percentage,
+        )
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
         parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
         parser.add_argument("--skip-warmup", action="store_true")
+        parser.add_argument("--show-report", action="store_true")

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -99,36 +112,89 @@ def run_one_case(
     batch_size: int,
     input_len: int,
     output_len: int,
+    temperature: float,
+    return_logprob: bool,
+    input_len_step_percentage: float,
     run_name: str,
     result_filename: str,
 ):
+    requests.post(url + "/flush_cache")
+    input_lens = [
+        int(input_len * (1 + (i - (batch_size - 1) / 2) * input_len_step_percentage))
+        for i in range(batch_size)
+    ]
     input_ids = [
-        [int(x) for x in np.random.randint(0, high=16384, size=(input_len,))]
-        for _ in range(batch_size)
+        [int(x) for x in np.random.randint(0, high=16384, size=(input_lens[i],))]
+        for i in range(batch_size)
     ]

+    use_structured_outputs = False
+    if use_structured_outputs:
+        texts = []
+        for _ in range(batch_size):
+            texts.append(
+                "Human: What is the capital city of france? can you give as many trivial information as possible about that city? answer in json.\n"
+                * 50
+                + "Assistant:"
+            )
+        json_schema = "$$ANY$$"
+    else:
+        json_schema = None
+
     tic = time.time()
     response = requests.post(
         url + "/generate",
         json={
+            # "text": texts,
             "input_ids": input_ids,
             "sampling_params": {
-                "temperature": 0,
+                "temperature": temperature,
                 "max_new_tokens": output_len,
                 "ignore_eos": True,
+                "json_schema": json_schema,
             },
+            "return_logprob": return_logprob,
+            "stream": True,
         },
+        stream=True,
     )
-    latency = time.time() - tic

-    _ = response.json()
-    output_throughput = batch_size * output_len / latency
+    # The TTFT of the last request in the batch
+    ttft = 0.0
+    for chunk in response.iter_lines(decode_unicode=False):
+        chunk = chunk.decode("utf-8")
+        if chunk and chunk.startswith("data:"):
+            if chunk == "data: [DONE]":
+                break
+            data = json.loads(chunk[5:].strip("\n"))
+            if "error" in data:
+                raise RuntimeError(f"Request has failed. {data}.")
+
+            assert (
+                data["meta_info"]["finish_reason"] is None
+                or data["meta_info"]["finish_reason"]["type"] == "length"
+            )
+            if data["meta_info"]["completion_tokens"] == 1:
+                ttft = time.time() - tic
+
+    latency = time.time() - tic
+    input_throughput = batch_size * input_len / ttft
+    output_throughput = batch_size * output_len / (latency - ttft)
     overall_throughput = batch_size * (input_len + output_len) / latency

+    server_info = requests.get(url + "/get_server_info").json()
+    acc_length = server_info["internal_states"][0].get("avg_spec_accept_length", None)
+    last_gen_throughput = server_info["internal_states"][0]["last_gen_throughput"]
+
     print(f"batch size: {batch_size}")
+    print(f"input_len: {input_len}")
+    print(f"output_len: {output_len}")
     print(f"latency: {latency:.2f} s")
-    print(f"output throughput: {output_throughput:.2f} token/s")
-    print(f"(input + output) throughput: {overall_throughput:.2f} token/s")
+    print(f"ttft: {ttft:.2f} s")
+    print(f"Last generation throughput: {last_gen_throughput:.2f} tok/s")
+    print(f"Input throughput: {input_throughput:.2f} tok/s")
+    if output_len != 1:
+        print(f"output throughput: {output_throughput:.2f} tok/s")

     if result_filename:
         with open(result_filename, "a") as fout:
@@ -140,9 +206,21 @@ def run_one_case(
                 "latency": round(latency, 4),
                 "output_throughput": round(output_throughput, 2),
                 "overall_throughput": round(overall_throughput, 2),
+                "last_gen_throughput": round(last_gen_throughput, 2),
             }
             fout.write(json.dumps(res) + "\n")

+    return (
+        batch_size,
+        latency,
+        ttft,
+        input_throughput,
+        output_throughput,
+        overall_throughput,
+        last_gen_throughput,
+        acc_length,
+    )
+

 def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     if bench_args.base_url:
@@ -152,27 +230,38 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):

     # warmup
     if not bench_args.skip_warmup:
+        print("=" * 8 + " Warmup Begin " + "=" * 8)
         run_one_case(
             base_url,
             batch_size=16,
             input_len=1024,
             output_len=16,
+            temperature=bench_args.temperature,
+            return_logprob=bench_args.return_logprob,
+            input_len_step_percentage=bench_args.input_len_step_percentage,
             run_name="",
             result_filename="",
         )
+        print("=" * 8 + " Warmup End " + "=" * 8 + "\n")

     # benchmark
+    result = []
     try:
         for bs, il, ol in itertools.product(
             bench_args.batch_size, bench_args.input_len, bench_args.output_len
         ):
-            run_one_case(
-                base_url,
-                bs,
-                il,
-                ol,
-                bench_args.run_name,
-                bench_args.result_filename,
+            result.append(
+                run_one_case(
+                    base_url,
+                    bs,
+                    il,
+                    ol,
+                    temperature=bench_args.temperature,
+                    return_logprob=bench_args.return_logprob,
+                    input_len_step_percentage=bench_args.input_len_step_percentage,
+                    run_name=bench_args.run_name,
+                    result_filename=bench_args.result_filename,
+                )
             )
     finally:
         if proc:
@@ -180,6 +269,45 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):

     print(f"\nResults are saved to {bench_args.result_filename}")

+    if not bench_args.show_report:
+        return
+
+    summary = " | batch size | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input price ($/1M) | output price ($/1M) |\n"
+    summary += "| ---------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ------------------ | ------------------- |\n"
+
+    for (
+        batch_size,
+        latency,
+        ttft,
+        input_throughput,
+        output_throughput,
+        overall_throughput,
+        last_gen_throughput,
+        acc_length,
+    ) in result:
+        hourly_cost = 2 * server_args.tp_size  # $2/hour for one H100
+        input_util = 0.7
+        accept_length = round(acc_length, 2) if acc_length is not None else "n/a"
+        line = (
+            f"| {batch_size} | "
+            f"{latency:.2f} | "
+            f"{input_throughput:.2f} | "
+            f"{output_throughput:.2f} | "
+            f"{accept_length} | "
+            f"{1 / (output_throughput/batch_size) * 1000:.2f} | "
+            f"{1e6 / (input_throughput * input_util) / 3600 * hourly_cost:.2f} | "
+            f"{1e6 / output_throughput / 3600 * hourly_cost:.2f} |\n"
+        )
+        summary += line
+
+    # print metrics table
+    print(summary)
+
+    if is_in_ci():
+        write_github_step_summary(
+            f"### Test Nightly Benchmark (bench_one_batch) \n{summary}"
+        )
+

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
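The report in the hunk above derives inter-token latency (ITL) and a rough $/1M-token price from the measured throughput. Restated as standalone arithmetic with example numbers (the $2/hour-per-GPU rate and the 0.7 input-utilization factor are the assumptions hard-coded in the script; the throughput values below are made up):

# Example re-derivation of the report columns.
batch_size = 16
input_throughput = 20000.0   # tok/s, measured via TTFT
output_throughput = 1600.0   # tok/s, measured after TTFT
tp_size = 1

hourly_cost = 2 * tp_size    # $2/hour per H100, as assumed above
input_util = 0.7             # assumed usable fraction of input throughput

# One token arrives every batch_size / output_throughput seconds per request.
itl_ms = 1 / (output_throughput / batch_size) * 1000                        # -> 10.00 ms
# Price per 1M tokens: seconds to process 1M tokens, converted to GPU-hours.
input_price = 1e6 / (input_throughput * input_util) / 3600 * hourly_cost    # -> ~$0.04
output_price = 1e6 / output_throughput / 3600 * hourly_cost                 # -> ~$0.35

print(f"ITL: {itl_ms:.2f} ms, input ${input_price:.2f}/1M, output ${output_price:.2f}/1M")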
sglang/bench_serving.py CHANGED
@@ -58,6 +58,7 @@ class RequestFuncInput:
     output_len: int
     model: str
     lora_name: str
+    image_data: str
     extra_request_body: Dict[str, Any]


@@ -347,6 +348,11 @@ async def async_request_sglang_generate(
         "logprob_start_len": -1,
         **request_func_input.extra_request_body,
     }
+
+    # Add image data if available
+    if request_func_input.image_data:
+        payload["image_data"] = request_func_input.image_data
+
     headers = get_auth_headers()

     output = RequestFuncOutput()
@@ -510,6 +516,13 @@ def get_dataset(args, tokenizer):
             tokenizer=tokenizer,
             args=args,
         )
+    elif args.dataset_name == "mmmu":
+        input_requests = sample_mmmu_requests(
+            num_requests=args.num_prompts,
+            tokenizer=tokenizer,
+            fixed_output_len=args.random_output_len,
+            random_sample=True,
+        )
     else:
         raise ValueError(f"Unknown dataset: {args.dataset_name}")
     return input_requests
@@ -597,6 +610,121 @@ def download_and_cache_file(url: str, filename: Optional[str] = None):
     return filename


+def sample_mmmu_requests(
+    num_requests: int,
+    tokenizer: PreTrainedTokenizerBase,
+    fixed_output_len: Optional[int] = None,
+    random_sample: bool = True,
+) -> List[Tuple[str, int, int]]:
+    """
+    Sample requests from the MMMU dataset using HuggingFace datasets.
+
+    Args:
+        num_requests: Number of requests to sample.
+        tokenizer: Tokenizer to use for token counting.
+        fixed_output_len: If provided, use this fixed output length for all requests.
+        random_sample: Whether to randomly sample or take the first N.
+
+    Returns:
+        List of tuples (prompt, prompt_token_len, output_token_len).
+    """
+    try:
+        import base64
+        import io
+
+        from datasets import load_dataset
+    except ImportError:
+        raise ImportError("Please install datasets: pip install datasets")
+
+    print("Loading MMMU dataset from HuggingFace...")
+
+    try:
+        print("Attempting to load MMMU Math dataset...")
+        mmmu_dataset = load_dataset("MMMU/MMMU", "Math", split="test")
+        print(
+            f"Successfully loaded MMMU Math dataset from HuggingFace with {len(mmmu_dataset)} examples"
+        )
+    except Exception as e:
+        print(f"Failed to load MMMU Math dataset: {e}")
+        raise ValueError(f"Failed to load MMMU dataset: {e}")
+
+    # Sample from the dataset
+    if len(mmmu_dataset) > num_requests:
+        if random_sample:
+            # Random sample
+            indices = random.sample(range(len(mmmu_dataset)), num_requests)
+            sample_dataset = mmmu_dataset.select(indices)
+        else:
+            # Take first N
+            sample_dataset = mmmu_dataset.select(
+                range(min(num_requests, len(mmmu_dataset)))
+            )
+    else:
+        print(f"Dataset has less than {num_requests} examples, using all examples")
+        sample_dataset = mmmu_dataset
+
+    print(f"Selected {len(sample_dataset)} examples for benchmarking")
+
+    # Create prompts
+    filtered_dataset = []
+
+    for i, example in enumerate(sample_dataset):
+        try:
+            # Extract image_1
+            image = example.get("image_1")
+
+            if image is not None:
+                if hasattr(image, "save"):
+                    # Convert RGBA images to RGB before encoding
+                    if image.mode == "RGBA":
+                        image = image.convert("RGB")
+
+                    # Encode image to base64
+                    buffered = io.BytesIO()
+                    image.save(buffered, format="JPEG")
+                    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+                    image_path = f"data:image/jpeg;base64,{img_str}"
+                else:
+                    continue
+
+                # Extract the question
+                question = example.get("question")
+
+                # Create the prompt with image, question
+                prompt = f"Question: {question}\n\nAnswer: "
+                prompt = tokenizer.apply_chat_template(
+                    [
+                        {
+                            "role": "user",
+                            "content": [
+                                {"type": "image_url", "image_url": {"url": image_path}},
+                                {"type": "text", "text": prompt},
+                            ],
+                        }
+                    ],
+                    add_generation_prompt=True,
+                    tokenize=False,
+                )
+                prompt = f"<image>{image_path}</image>{prompt}"
+
+                # Calculate token lengths
+                # Note: This is approximate since we're not rendering the actual image tokens
+                prompt_token_ids = tokenizer.encode(prompt)
+                prompt_len = (
+                    len(prompt_token_ids) + 512
+                )  # Add estimate for image tokens
+
+                output_len = fixed_output_len if fixed_output_len is not None else 256
+
+                filtered_dataset.append((prompt, prompt_len, output_len))
+
+        except Exception as e:
+            print(f"Error processing example {i}: {e}")
+
+    print(f"\nCreated {len(filtered_dataset)} MMMU prompts")
+    return filtered_dataset
+
+
 def sample_sharegpt_requests(
     dataset_path: str,
     num_requests: int,
@@ -975,7 +1103,7 @@ async def benchmark(
     lora_names: List[str],
     extra_request_body: Dict[str, Any],
     profile: bool,
-    pd_seperated: bool = False,
+    pd_separated: bool = False,
     flush_cache: bool = False,
     warmup_requests: int = 1,
 ):
@@ -1004,6 +1132,15 @@ async def benchmark(
     else:
         lora_name = None

+    if "<image>" in test_prompt:
+        import re
+
+        image_match = re.search(r"<image>(.*?)</image>(.*)", test_prompt)
+        image_data = image_match.group(1) if image_match else None
+        test_prompt = image_match.group(2) if image_match else test_prompt
+    else:
+        image_data = None
+
     # Create the test input once
     test_input = RequestFuncInput(
         model=model_id,
@@ -1012,6 +1149,7 @@ async def benchmark(
         prompt_len=test_prompt_len,
         output_len=min(test_output_len, 32),
         lora_name=lora_name,
+        image_data=image_data,
         extra_request_body=extra_request_body,
     )

@@ -1063,6 +1201,15 @@ async def benchmark(
         else:
             lora_name = None

+        if "<image>" in prompt:
+            import re
+
+            image_match = re.search(r"<image>(.*?)</image>(.*)", prompt)
+            image_data = image_match.group(1) if image_match else None
+            prompt = image_match.group(2) if image_match else prompt
+        else:
+            image_data = None
+
         request_func_input = RequestFuncInput(
             model=model_id,
             prompt=prompt,
@@ -1070,6 +1217,7 @@ async def benchmark(
             prompt_len=prompt_len,
             output_len=output_len,
             lora_name=lora_name,
+            image_data=image_data,
             extra_request_body=extra_request_body,
         )
         tasks.append(
@@ -1091,12 +1239,14 @@

     if "sglang" in backend:
         server_info = requests.get(base_url + "/get_server_info")
-        if pd_seperated:
-            accept_length = server_info.json()["decode"][0].get(
+        if pd_separated:
+            accept_length = server_info.json()["decode"][0]["internal_states"][0].get(
                 "avg_spec_accept_length", None
             )
         else:
-            accept_length = server_info.json().get("avg_spec_accept_length", None)
+            accept_length = server_info.json()["internal_states"][0].get(
+                "avg_spec_accept_length", None
+            )
     else:
         accept_length = None

@@ -1115,7 +1265,7 @@ async def benchmark(
     print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
     print(
         "{:<40} {:<10}".format(
-            "Max reqeuest concurrency:",
+            "Max request concurrency:",
             max_concurrency if max_concurrency else "not set",
         )
     )
@@ -1393,7 +1543,7 @@ def run_benchmark(args_: argparse.Namespace):
             lora_names=args.lora_name,
             extra_request_body=extra_request_body,
             profile=args.profile,
-            pd_seperated=args.pd_seperated,
+            pd_separated=args.pd_separated,
             flush_cache=args.flush_cache,
         )
     )
@@ -1444,7 +1594,7 @@ if __name__ == "__main__":
         "--dataset-name",
         type=str,
         default="sharegpt",
-        choices=["sharegpt", "random", "random-ids", "generated-shared-prefix"],
+        choices=["sharegpt", "random", "random-ids", "generated-shared-prefix", "mmmu"],
         help="Name of the dataset to benchmark on.",
     )
     parser.add_argument(
1572
1722
  help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
1573
1723
  )
1574
1724
  parser.add_argument(
1575
- "--pd-seperated",
1725
+ "--pd-separated",
1576
1726
  action="store_true",
1577
1727
  help="Benchmark PD disaggregation server",
1578
1728
  )
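Several bench_serving.py hunks above pass images by embedding a base64 data URL inside <image>...</image> markers in the prompt and splitting it back out right before the request is built. The split step, pulled into a standalone helper for clarity (the regex is the one used in the diff; the function name is illustrative):

import re
from typing import Optional, Tuple


def split_image_prompt(prompt: str) -> Tuple[Optional[str], str]:
    # Returns (image_data, text_prompt); image_data is None when no marker is present.
    image_match = re.search(r"<image>(.*?)</image>(.*)", prompt)
    if image_match:
        return image_match.group(1), image_match.group(2)
    return None, prompt


# Example with a prompt shaped like the ones sample_mmmu_requests() produces:
# image_data, text = split_image_prompt("<image>data:image/jpeg;base64,...</image>Question: ...")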
sglang/compile_deep_gemm.py CHANGED
@@ -129,7 +129,7 @@ def launch_server_process_and_send_one_request(


 def refine_server_args(server_args: ServerArgs, compile_args: CompileArgs):
-    # Disbale cuda graph and torch compile to save time
+    # Disable cuda graph and torch compile to save time
     server_args.disable_cuda_graph = True
     server_args.enable_torch_compile = False
     print(f"Disable CUDA Graph and Torch Compile to save time...")