sglang 0.1.20__py3-none-any.whl → 0.1.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. sglang/__init__.py +8 -8
  2. sglang/api.py +1 -1
  3. sglang/backend/runtime_endpoint.py +14 -4
  4. sglang/backend/vertexai.py +5 -4
  5. sglang/bench.py +627 -0
  6. sglang/bench_latency.py +22 -20
  7. sglang/bench_serving.py +758 -0
  8. sglang/check_env.py +171 -0
  9. sglang/global_config.py +3 -1
  10. sglang/lang/backend/__init__.py +0 -0
  11. sglang/lang/backend/anthropic.py +77 -0
  12. sglang/lang/backend/base_backend.py +80 -0
  13. sglang/lang/backend/litellm.py +90 -0
  14. sglang/lang/backend/openai.py +438 -0
  15. sglang/lang/backend/runtime_endpoint.py +283 -0
  16. sglang/lang/backend/vertexai.py +149 -0
  17. sglang/lang/chat_template.py +2 -2
  18. sglang/lang/ir.py +3 -3
  19. sglang/lang/tracer.py +1 -1
  20. sglang/launch_server.py +1 -1
  21. sglang/launch_server_llavavid.py +1 -4
  22. sglang/srt/conversation.py +1 -1
  23. sglang/srt/layers/context_flashattention_nopad.py +0 -29
  24. sglang/srt/layers/extend_attention.py +0 -39
  25. sglang/srt/layers/linear.py +869 -0
  26. sglang/srt/layers/quantization/__init__.py +49 -0
  27. sglang/srt/layers/quantization/fp8.py +662 -0
  28. sglang/srt/layers/radix_attention.py +31 -5
  29. sglang/srt/layers/token_attention.py +1 -51
  30. sglang/srt/managers/controller/cuda_graph_runner.py +44 -18
  31. sglang/srt/managers/controller/infer_batch.py +76 -72
  32. sglang/srt/managers/controller/manager_multi.py +109 -98
  33. sglang/srt/managers/controller/manager_single.py +105 -50
  34. sglang/srt/managers/controller/model_runner.py +42 -18
  35. sglang/srt/managers/controller/radix_cache.py +4 -3
  36. sglang/srt/managers/controller/schedule_heuristic.py +4 -0
  37. sglang/srt/managers/controller/tp_worker.py +143 -156
  38. sglang/srt/managers/detokenizer_manager.py +49 -5
  39. sglang/srt/managers/io_struct.py +36 -17
  40. sglang/srt/managers/tokenizer_manager.py +228 -125
  41. sglang/srt/memory_pool.py +46 -58
  42. sglang/srt/model_loader/model_loader.py +277 -0
  43. sglang/srt/model_loader/utils.py +260 -0
  44. sglang/srt/models/chatglm.py +1 -0
  45. sglang/srt/models/dbrx.py +1 -0
  46. sglang/srt/models/grok.py +1 -0
  47. sglang/srt/models/internlm2.py +317 -0
  48. sglang/srt/models/llama2.py +65 -16
  49. sglang/srt/models/llama_classification.py +1 -0
  50. sglang/srt/models/llava.py +1 -0
  51. sglang/srt/models/llavavid.py +1 -0
  52. sglang/srt/models/minicpm.py +2 -8
  53. sglang/srt/models/mixtral.py +1 -0
  54. sglang/srt/models/mixtral_quant.py +1 -0
  55. sglang/srt/models/qwen.py +1 -0
  56. sglang/srt/models/qwen2.py +6 -0
  57. sglang/srt/models/qwen2_moe.py +130 -108
  58. sglang/srt/models/stablelm.py +1 -0
  59. sglang/srt/openai_api/adapter.py +432 -0
  60. sglang/srt/openai_api/api_adapter.py +432 -0
  61. sglang/srt/openai_api/openai_api_adapter.py +431 -0
  62. sglang/srt/openai_api/openai_protocol.py +207 -0
  63. sglang/srt/openai_api/protocol.py +208 -0
  64. sglang/srt/openai_protocol.py +17 -0
  65. sglang/srt/sampling_params.py +2 -0
  66. sglang/srt/server.py +114 -90
  67. sglang/srt/server_args.py +27 -17
  68. sglang/srt/utils.py +17 -118
  69. sglang/test/test_conversation.py +1 -1
  70. sglang/test/test_openai_protocol.py +1 -1
  71. sglang/test/test_programs.py +1 -1
  72. sglang/test/test_utils.py +2 -2
  73. {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/METADATA +157 -159
  74. sglang-0.1.22.dist-info/RECORD +103 -0
  75. {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/WHEEL +1 -1
  76. sglang-0.1.20.dist-info/RECORD +0 -82
  77. {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/LICENSE +0 -0
  78. {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/top_level.txt +0 -0
sglang/bench_serving.py (new file)
@@ -0,0 +1,758 @@
+ # Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py
+ # Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py
+ """
+ Benchmark online serving.
+
+ Usage:
+ python3 -m sglang.bench_serving --backend sglang --num-prompts 10
+ """
+
+ import argparse
+ import asyncio
+ import json
+ import os
+ import random
+ import resource
+ import sys
+ import time
+ import traceback
+ import warnings
+ from argparse import ArgumentParser as FlexibleArgumentParser
+ from dataclasses import dataclass, field
+ from typing import AsyncGenerator, List, Optional, Tuple, Union
+
+ import aiohttp
+ import numpy as np
+ import requests
+ from tqdm.asyncio import tqdm
+ from transformers import (
+     AutoTokenizer,
+     PreTrainedTokenizer,
+     PreTrainedTokenizerBase,
+     PreTrainedTokenizerFast,
+ )
+
+ AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
+
+
+ @dataclass
+ class RequestFuncInput:
+     prompt: str
+     api_url: str
+     prompt_len: int
+     output_len: int
+     model: str
+
+
+ @dataclass
+ class RequestFuncOutput:
+     generated_text: str = ""
+     success: bool = False
+     latency: float = 0.0
+     ttft: float = 0.0  # Time to first token
+     itl: List[float] = field(default_factory=list)  # List of inter-token latencies
+     prompt_len: int = 0
+     error: str = ""
+
+
+ def remove_prefix(text: str, prefix: str) -> str:
+     return text[len(prefix) :] if text.startswith(prefix) else text
+
+
+ # set ignore_eos True by default
+ async def async_request_openai_completions(
+     request_func_input: RequestFuncInput,
+     pbar: Optional[tqdm] = None,
+ ) -> RequestFuncOutput:
+     api_url = request_func_input.api_url
+     assert api_url.endswith(
+         "completions"
+     ), "OpenAI Completions API URL must end with 'completions'."
+
+     async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+         payload = {
+             "model": request_func_input.model,
+             "prompt": request_func_input.prompt,
+             "temperature": 0.0,
+             "best_of": 1,
+             "max_tokens": request_func_input.output_len,
+             "stream": True,
+             "ignore_eos": True,
+         }
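+         # Streaming the response lets the client time the first token (TTFT) and each later chunk (ITL).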
+         headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
+
+         output = RequestFuncOutput()
+         output.prompt_len = request_func_input.prompt_len
+
+         generated_text = ""
+         ttft = 0.0
+         st = time.perf_counter()
+         most_recent_timestamp = st
+         try:
+             async with session.post(
+                 url=api_url, json=payload, headers=headers
+             ) as response:
+                 if response.status == 200:
+                     async for chunk_bytes in response.content:
+                         chunk_bytes = chunk_bytes.strip()
+                         if not chunk_bytes:
+                             continue
+
+                         chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
+                         if chunk == "[DONE]":
+                             latency = time.perf_counter() - st
+                         else:
+                             data = json.loads(chunk)
+
+                             # NOTE: Some completion API might have a last
+                             # usage summary response without a token so we
+                             # want to check a token was generated
+                             if data["choices"][0]["text"]:
+                                 timestamp = time.perf_counter()
+                                 # First token
+                                 if ttft == 0.0:
+                                     ttft = time.perf_counter() - st
+                                     output.ttft = ttft
+
+                                 # Decoding phase
+                                 output.itl.append(timestamp - most_recent_timestamp)
+
+                                 most_recent_timestamp = timestamp
+                                 generated_text += data["choices"][0]["text"]
+
+                     output.generated_text = generated_text
+                     output.success = True
+                     output.latency = latency
+                 else:
+                     output.error = response.reason or ""
+                     output.success = False
+         except Exception:
+             output.success = False
+             exc_info = sys.exc_info()
+             output.error = "".join(traceback.format_exception(*exc_info))
+
+     if pbar:
+         pbar.update(1)
+     return output
+
+
+ def get_model(pretrained_model_name_or_path: str) -> str:
+     if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true":
+         import huggingface_hub.constants
+         from modelscope import snapshot_download
+
+         model_path = snapshot_download(
+             model_id=pretrained_model_name_or_path,
+             local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
+             ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
+         )
+
+         return model_path
+     return pretrained_model_name_or_path
+
+
+ def get_tokenizer(
+     pretrained_model_name_or_path: str,
+ ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+     if pretrained_model_name_or_path is not None and not os.path.exists(
+         pretrained_model_name_or_path
+     ):
+         pretrained_model_name_or_path = get_model(pretrained_model_name_or_path)
+     return AutoTokenizer.from_pretrained(
+         pretrained_model_name_or_path, trust_remote_code=True
+     )
+
+
+ ASYNC_REQUEST_FUNCS = {
+     "sglang": async_request_openai_completions,
+     "vllm": async_request_openai_completions,
+     "lmdeploy": async_request_openai_completions,
+ }
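+ # All three engines expose an OpenAI-compatible completions endpoint, so the
+ # same streaming request function is reused for every backend.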
+
+
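+ # Reported metrics: TTFT = time to first token, TPOT = time per output token
+ # (excluding the first token), ITL = latency between consecutive streamed chunks.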
+ @dataclass
+ class BenchmarkMetrics:
+     completed: int
+     total_input: int
+     total_output: int
+     request_throughput: float
+     input_throughput: float
+     output_throughput: float
+     mean_ttft_ms: float
+     median_ttft_ms: float
+     std_ttft_ms: float
+     p99_ttft_ms: float
+     mean_tpot_ms: float
+     median_tpot_ms: float
+     std_tpot_ms: float
+     p99_tpot_ms: float
+     mean_itl_ms: float
+     median_itl_ms: float
+     std_itl_ms: float
+     p99_itl_ms: float
+
+
+ default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json"
+
+
+ def download_sharegpt_dataset(path):
+     url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
+
+     print(f"Downloading dataset from {url}")
+     try:
+         response = requests.get(url, stream=True)
+         response.raise_for_status()
+
+         total_size = int(response.headers.get("content-length", 0))
+         block_size = 8192
+
+         with open(path, "wb") as f, tqdm(
+             desc="Downloading",
+             total=total_size,
+             unit="iB",
+             unit_scale=True,
+             unit_divisor=1024,
+         ) as progress_bar:
+             for data in response.iter_content(block_size):
+                 size = f.write(data)
+                 progress_bar.update(size)
+
+         print(f"Dataset downloaded and saved to {path}")
+     except requests.RequestException as e:
+         raise Exception(f"Failed to download dataset: {e}")
+
+
+ def sample_sharegpt_requests(
+     dataset_path: str,
+     num_requests: int,
+     tokenizer: PreTrainedTokenizerBase,
+     fixed_output_len: Optional[int] = None,
+ ) -> List[Tuple[str, int, int]]:
+     if fixed_output_len is not None and fixed_output_len < 4:
+         raise ValueError("output_len too small")
+
+     # Download sharegpt if necessary
+     if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path):
+         download_sharegpt_dataset(default_sharegpt_path)
+         dataset_path = default_sharegpt_path
+     else:
+         dataset_path = (
+             dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
+         )
+
+     # Load the dataset.
+     with open(dataset_path) as f:
+         dataset = json.load(f)
+     # Filter out the conversations with less than 2 turns.
+     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
+     # Only keep the first two turns of each conversation.
+     dataset = [
+         (data["conversations"][0]["value"], data["conversations"][1]["value"])
+         for data in dataset
+     ]
+
+     # Shuffle the dataset.
+     random.shuffle(dataset)
+
+     # Filter out sequences that are too long or too short
+     filtered_dataset: List[Tuple[str, int, int]] = []
+     for i in range(len(dataset)):
+         if len(filtered_dataset) == num_requests:
+             break
+
+         # Tokenize the prompts and completions.
+         prompt = dataset[i][0]
+         prompt_token_ids = tokenizer(prompt).input_ids
+         completion = dataset[i][1]
+         completion_token_ids = tokenizer(completion).input_ids
+         prompt_len = len(prompt_token_ids)
+         output_len = (
+             len(completion_token_ids) if fixed_output_len is None else fixed_output_len
+         )
+         if prompt_len < 4 or output_len < 4:
+             # Prune too short sequences.
+             continue
+         if prompt_len > 1024 or prompt_len + output_len > 2048:
+             # Prune too long sequences.
+             continue
+         filtered_dataset.append((prompt, prompt_len, output_len))
+
+     return filtered_dataset
+
+
+ def sample_random_requests(
+     input_len: int,
+     output_len: int,
+     num_prompts: int,
+     range_ratio: float,
+     tokenizer: PreTrainedTokenizerBase,
+     dataset_path: str,
+ ) -> List[Tuple[str, int, int]]:
+
+     input_lens = np.random.randint(
+         int(input_len * range_ratio),
+         input_len + 1,
+         size=num_prompts,
+     )
+     output_lens = np.random.randint(
+         int(output_len * range_ratio),
+         output_len + 1,
+         size=num_prompts,
+     )
+
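+     # NOTE: the condition below is hard-coded to True, so the random-token
+     # branch in the else arm is currently never executed.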
+     if True:
+         # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
+
+         # Download sharegpt if necessary
+         if not os.path.isfile(dataset_path) and not os.path.isfile(
+             default_sharegpt_path
+         ):
+             download_sharegpt_dataset(default_sharegpt_path)
+             dataset_path = default_sharegpt_path
+         else:
+             dataset_path = (
+                 dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
+             )
+
+         # Load the dataset.
+         with open(dataset_path) as f:
+             dataset = json.load(f)
+         # Filter out the conversations with less than 2 turns.
+         dataset = [data for data in dataset if len(data["conversations"]) >= 2]
+         # Only keep the first two turns of each conversation.
+         dataset = [
+             (data["conversations"][0]["value"], data["conversations"][1]["value"])
+             for data in dataset
+         ]
+
+         # Shuffle the dataset.
+         random.shuffle(dataset)
+
+         # Filter out sequences that are too long or too short
+         input_requests: List[Tuple[str, int, int]] = []
+         for i in range(num_prompts):
+             # Tokenize the prompts and completions.
+             prompt = dataset[i][0]
+             prompt_token_ids = tokenizer(prompt).input_ids
+             prompt_len = len(prompt_token_ids)
+
+             if prompt_len <= input_lens[i]:
+                 input_ids = prompt_token_ids[: input_lens[i]]
+             else:
+                 ratio = (input_lens[i] + prompt_len - 1) // prompt_len
+                 input_ids = (prompt_token_ids * ratio)[: input_lens[i]]
+             prompt = tokenizer.decode(input_ids)
+             input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
+     else:
+         # Sample token ids from random integers. This can cause some NaN issues.
+         offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
+         input_requests = []
+         for i in range(num_prompts):
+             prompt = tokenizer.decode(
+                 [
+                     (offsets[i] + i + j) % tokenizer.vocab_size
+                     for j in range(input_lens[i])
+                 ]
+             )
+             input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
+
+     print(f"#Input tokens: {np.sum(input_lens)}")
+     print(f"#Output tokens: {np.sum(output_lens)}")
+     return input_requests
+
+
+ async def get_request(
+     input_requests: List[Tuple[str, int, int]],
+     request_rate: float,
+ ) -> AsyncGenerator[Tuple[str, int, int], None]:
+     input_requests = iter(input_requests)
+     for request in input_requests:
+         yield request
+
+         if request_rate == float("inf"):
+             # If the request rate is infinity, then we don't need to wait.
+             continue
+
+         # Sample the request interval from the exponential distribution.
+         interval = np.random.exponential(1.0 / request_rate)
+         # The next request will be sent after the interval.
+         await asyncio.sleep(interval)
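+         # For example, with --request-rate 4 the sampled gaps average 1/4 = 0.25 s.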
+
+
+ def calculate_metrics(
+     input_requests: List[Tuple[str, int, int]],
+     outputs: List[RequestFuncOutput],
+     dur_s: float,
+     tokenizer: PreTrainedTokenizerBase,
+ ) -> Tuple[BenchmarkMetrics, List[int]]:
+     actual_output_lens: List[int] = []
+     total_input = 0
+     completed = 0
+     itls: List[float] = []
+     tpots: List[float] = []
+     ttfts: List[float] = []
+     for i in range(len(outputs)):
+         if outputs[i].success:
+             # We use the tokenizer to count the number of output tokens for all
+             # serving backends instead of looking at len(outputs[i].itl) since
+             # multiple output tokens may be bundled together
+             # Note : this may inflate the output token count slightly
+             output_len = len(
+                 tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids
+             )
+             actual_output_lens.append(output_len)
+             total_input += input_requests[i][1]
+             if output_len > 1:
+                 tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
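+                 # e.g. latency 2.0 s, TTFT 0.5 s, 16 output tokens -> TPOT = 1.5 / 15 = 0.1 s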
+             itls += outputs[i].itl
+             ttfts.append(outputs[i].ttft)
+             completed += 1
+         else:
+             actual_output_lens.append(0)
+
+     if completed == 0:
+         warnings.warn(
+             "All requests failed. This is likely due to a misconfiguration "
+             "on the benchmark arguments.",
+             stacklevel=2,
+         )
+     metrics = BenchmarkMetrics(
+         completed=completed,
+         total_input=total_input,
+         total_output=sum(actual_output_lens),
+         request_throughput=completed / dur_s,
+         input_throughput=total_input / dur_s,
+         output_throughput=sum(actual_output_lens) / dur_s,
+         mean_ttft_ms=np.mean(ttfts or 0)
+         * 1000,  # ttfts is empty if streaming is not supported by backend
+         median_ttft_ms=np.median(ttfts or 0) * 1000,
+         std_ttft_ms=np.std(ttfts or 0) * 1000,
+         p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
+         mean_tpot_ms=np.mean(tpots or 0) * 1000,
+         median_tpot_ms=np.median(tpots or 0) * 1000,
+         std_tpot_ms=np.std(tpots or 0) * 1000,
+         p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
+         mean_itl_ms=np.mean(itls or 0) * 1000,
+         median_itl_ms=np.median(itls or 0) * 1000,
+         std_itl_ms=np.std(itls or 0) * 1000,
+         p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
+     )
+
+     return metrics, actual_output_lens
+
+
+ async def benchmark(
+     backend: str,
+     api_url: str,
+     model_id: str,
+     tokenizer: PreTrainedTokenizerBase,
+     input_requests: List[Tuple[str, int, int]],
+     request_rate: float,
+     disable_tqdm: bool,
+ ):
+     if backend in ASYNC_REQUEST_FUNCS:
+         request_func = ASYNC_REQUEST_FUNCS[backend]
+     else:
+         raise ValueError(f"Unknown backend: {backend}")
+
+     print("Starting initial single prompt test run...")
+     test_prompt, test_prompt_len, test_output_len = input_requests[0]
+     test_input = RequestFuncInput(
+         model=model_id,
+         prompt=test_prompt,
+         api_url=api_url,
+         prompt_len=test_prompt_len,
+         output_len=test_output_len,
+     )
+     test_output = await request_func(request_func_input=test_input)
+     if not test_output.success:
+         raise ValueError(
+             "Initial test run failed - Please make sure benchmark arguments "
+             f"are correctly specified. Error: {test_output.error}"
+         )
+     else:
+         print("Initial test run completed. Starting main benchmark run...")
+
+     pbar = None if disable_tqdm else tqdm(total=len(input_requests))
+
+     benchmark_start_time = time.perf_counter()
+     tasks: List[asyncio.Task] = []
+     async for request in get_request(input_requests, request_rate):
+         prompt, prompt_len, output_len = request
+         request_func_input = RequestFuncInput(
+             model=model_id,
+             prompt=prompt,
+             api_url=api_url,
+             prompt_len=prompt_len,
+             output_len=output_len,
+         )
+         tasks.append(
+             asyncio.create_task(
+                 request_func(request_func_input=request_func_input, pbar=pbar)
+             )
+         )
+     outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
+
+     if pbar is not None:
+         pbar.close()
+
+     benchmark_duration = time.perf_counter() - benchmark_start_time
+
+     metrics, actual_output_lens = calculate_metrics(
+         input_requests=input_requests,
+         outputs=outputs,
+         dur_s=benchmark_duration,
+         tokenizer=tokenizer,
+     )
+
+     print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
+     print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
+     print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
+     print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
+     print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
+     print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
+     print(
+         "{:<40} {:<10.2f}".format(
+             "Request throughput (req/s):", metrics.request_throughput
+         )
+     )
+     print(
+         "{:<40} {:<10.2f}".format(
+             "Input token throughput (tok/s):", metrics.input_throughput
+         )
+     )
+     print(
+         "{:<40} {:<10.2f}".format(
+             "Output token throughput (tok/s):", metrics.output_throughput
+         )
+     )
+     print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-"))
+     print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
+     print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
+     print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
+     print(
+         "{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-")
+     )
+     print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
+     print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
+     print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
+     print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-"))
+     print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
+     print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
+     print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
+     print("=" * 50)
+
+     result = {
+         "duration": benchmark_duration,
+         "completed": metrics.completed,
+         "total_input_tokens": metrics.total_input,
+         "total_output_tokens": metrics.total_output,
+         "request_throughput": metrics.request_throughput,
+         "input_throughput": metrics.input_throughput,
+         "output_throughput": metrics.output_throughput,
+         "mean_ttft_ms": metrics.mean_ttft_ms,
+         "median_ttft_ms": metrics.median_ttft_ms,
+         "std_ttft_ms": metrics.std_ttft_ms,
+         "p99_ttft_ms": metrics.p99_ttft_ms,
+         "mean_tpot_ms": metrics.mean_tpot_ms,
+         "median_tpot_ms": metrics.median_tpot_ms,
+         "std_tpot_ms": metrics.std_tpot_ms,
+         "p99_tpot_ms": metrics.p99_tpot_ms,
+         "mean_itl_ms": metrics.mean_itl_ms,
+         "median_itl_ms": metrics.median_itl_ms,
+         "std_itl_ms": metrics.std_itl_ms,
+         "p99_itl_ms": metrics.p99_itl_ms,
+         "input_lens": [output.prompt_len for output in outputs],
+         "output_lens": actual_output_lens,
+         "ttfts": [output.ttft for output in outputs],
+         "itls": [output.itl for output in outputs],
+         "generated_texts": [output.generated_text for output in outputs],
+         "errors": [output.error for output in outputs],
+     }
+     return result
+
+
+ def fire(args: argparse.Namespace):
+     random.seed(args.seed)
+     np.random.seed(args.seed)
+
+     if args.port is None:
+         args.port = {
+             "sglang": 30000,
+             "lmdeploy": 23333,
+             "vllm": 8000,
+         }.get(args.backend, 30000)
+
+     api_url = (
+         f"{args.base_url}/v1/completions"
+         if args.base_url
+         else f"http://{args.host}:{args.port}/v1/completions"
+     )
+     model_url = (
+         f"{args.base_url}/v1/models"
+         if args.base_url
+         else f"http://{args.host}:{args.port}/v1/models"
+     )
+
+     if args.model is None:
+         try:
+             response = requests.get(model_url)
+             model_list = response.json().get("data", [])
+             args.model = model_list[0]["id"] if model_list else None
+         except Exception as e:
+             print(f"Failed to fetch model from {model_url}. Error: {e}")
+             print(
+                 "Please specify the correct host and port using `--host` and `--port`."
+             )
+             sys.exit(1)
+
+     if args.model is None:
+         print("No model specified or found. Please provide a model using `--model`.")
+         sys.exit(1)
+
+     print(f"{args}\n")
+
+     backend = args.backend
+     model_id = args.model
+     tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
+
+     tokenizer = get_tokenizer(tokenizer_id)
+
+     if args.dataset_name == "sharegpt":
+         input_requests = sample_sharegpt_requests(
+             dataset_path=args.dataset_path,
+             num_requests=args.num_prompts,
+             tokenizer=tokenizer,
+             fixed_output_len=args.sharegpt_output_len,
+         )
+     elif args.dataset_name == "random":
+         input_requests = sample_random_requests(
+             input_len=args.random_input_len,
+             output_len=args.random_output_len,
+             num_prompts=args.num_prompts,
+             range_ratio=args.random_range_ratio,
+             tokenizer=tokenizer,
+             dataset_path=args.dataset_path,
+         )
+     else:
+         raise ValueError(f"Unknown dataset: {args.dataset_name}")
+
+     asyncio.run(
+         benchmark(
+             backend=backend,
+             api_url=api_url,
+             model_id=model_id,
+             tokenizer=tokenizer,
+             input_requests=input_requests,
+             request_rate=args.request_rate,
+             disable_tqdm=args.disable_tqdm,
+         )
+     )
+
+
+ # to avoid relying on SGLang's components
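+ # (Each in-flight request holds an open socket, so a large --num-prompts run
+ # can exceed the default open-file limit.)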
+ def set_ulimit(target_soft_limit=65535):
+     resource_type = resource.RLIMIT_NOFILE
+     current_soft, current_hard = resource.getrlimit(resource_type)
+
+     if current_soft < target_soft_limit:
+         try:
+             resource.setrlimit(resource_type, (target_soft_limit, current_hard))
+         except ValueError as e:
+             print(f"Failed to set RLIMIT_NOFILE: {e}")
+
+
+ if __name__ == "__main__":
+     parser = FlexibleArgumentParser(
+         description="Benchmark the online serving throughput."
+     )
+     parser.add_argument(
+         "--backend",
+         type=str,
+         required=True,
+         choices=list(ASYNC_REQUEST_FUNCS.keys()),
+         help="The LLM inference engine to benchmark against.",
+     )
+     parser.add_argument(
+         "--base-url",
+         type=str,
+         default=None,
+         help="Server or API base url if not using http host and port.",
+     )
+     parser.add_argument(
+         "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
+     )
+     parser.add_argument(
+         "--port",
+         type=int,
+         help="If not set, the port defaults to the standard port of the chosen inference engine.",
+     )
+     parser.add_argument(
+         "--dataset-name",
+         type=str,
+         default="sharegpt",
+         choices=["sharegpt", "random"],
+         help="Name of the dataset to benchmark on.",
+     )
+     parser.add_argument(
+         "--dataset-path", type=str, default="", help="Path to the dataset."
+     )
+     parser.add_argument(
+         "--model",
+         type=str,
+         help="Name or path of the model. If not set, it is fetched from the server's /v1/models endpoint.",
+     )
+     parser.add_argument(
+         "--tokenizer",
+         type=str,
+         help="Name or path of the tokenizer. If not set, the model name or path is used.",
+     )
+     parser.add_argument(
+         "--num-prompts",
+         type=int,
+         default=1000,
+         help="Number of prompts to process. Default is 1000.",
+     )
+     parser.add_argument(
+         "--sharegpt-output-len",
+         type=int,
+         default=None,
+         help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
+     )
+     parser.add_argument(
+         "--random-input-len",
+         type=int,
+         default=1024,
+         help="Number of input tokens per request, used only for random dataset.",
+     )
+     parser.add_argument(
+         "--random-output-len",
+         type=int,
+         default=128,
+         help="Number of output tokens per request, used only for random dataset.",
+     )
+     parser.add_argument(
+         "--random-range-ratio",
+         type=float,
+         default=1.0,
+         help="Range of sampled ratio of input/output length, "
+         "used only for random dataset.",
+     )
+     parser.add_argument(
+         "--request-rate",
+         type=float,
+         default=float("inf"),
+         help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
+         "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.",
+     )
+     parser.add_argument("--seed", type=int, default=0, help="Default is 0.")
+     parser.add_argument(
+         "--disable-tqdm",
+         action="store_true",
+         help="Specify to disable tqdm progress bar.",
+     )
+
+     set_ulimit()
+
+     args = parser.parse_args()
+     fire(args)
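
For orientation, here is a minimal sketch (not part of the diff) of driving the new module directly from Python instead of the CLI shown in its docstring. It assumes an sglang server is already listening on http://127.0.0.1:30000 and that the server accepts the model name you pass (the name "my-model" below is a placeholder):

    import asyncio
    from sglang.bench_serving import RequestFuncInput, async_request_openai_completions

    async def smoke_test():
        # Send one streaming completion request and print basic timings.
        out = await async_request_openai_completions(
            RequestFuncInput(
                prompt="Hello",
                api_url="http://127.0.0.1:30000/v1/completions",
                prompt_len=1,        # bookkeeping only; not sent to the server
                output_len=8,        # becomes max_tokens in the request payload
                model="my-model",    # placeholder; use the name your server reports
            )
        )
        print(out.success, out.ttft, out.generated_text)

    asyncio.run(smoke_test())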