sglang 0.1.21__py3-none-any.whl → 0.1.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. sglang/__init__.py +8 -8
  2. sglang/api.py +1 -1
  3. sglang/backend/vertexai.py +5 -4
  4. sglang/bench.py +627 -0
  5. sglang/bench_latency.py +22 -19
  6. sglang/bench_serving.py +976 -0
  7. sglang/check_env.py +171 -0
  8. sglang/global_config.py +3 -2
  9. sglang/lang/backend/__init__.py +0 -0
  10. sglang/lang/backend/anthropic.py +77 -0
  11. sglang/lang/backend/base_backend.py +80 -0
  12. sglang/lang/backend/litellm.py +90 -0
  13. sglang/lang/backend/openai.py +438 -0
  14. sglang/lang/backend/runtime_endpoint.py +283 -0
  15. sglang/lang/backend/vertexai.py +149 -0
  16. sglang/lang/interpreter.py +1 -0
  17. sglang/lang/tracer.py +1 -1
  18. sglang/launch_server.py +1 -1
  19. sglang/launch_server_llavavid.py +1 -4
  20. sglang/srt/conversation.py +1 -1
  21. sglang/srt/hf_transformers_utils.py +13 -1
  22. sglang/srt/layers/context_flashattention_nopad.py +0 -29
  23. sglang/srt/layers/extend_attention.py +0 -39
  24. sglang/srt/layers/linear.py +869 -0
  25. sglang/srt/layers/logits_processor.py +4 -5
  26. sglang/srt/layers/quantization/__init__.py +49 -0
  27. sglang/srt/layers/quantization/fp8.py +662 -0
  28. sglang/srt/layers/radix_attention.py +39 -24
  29. sglang/srt/layers/token_attention.py +1 -51
  30. sglang/srt/managers/controller/cuda_graph_runner.py +72 -28
  31. sglang/srt/managers/controller/infer_batch.py +90 -63
  32. sglang/srt/managers/controller/manager_multi.py +107 -100
  33. sglang/srt/managers/controller/manager_single.py +76 -96
  34. sglang/srt/managers/controller/model_runner.py +41 -26
  35. sglang/srt/managers/controller/schedule_heuristic.py +8 -3
  36. sglang/srt/managers/controller/tp_worker.py +136 -149
  37. sglang/srt/managers/detokenizer_manager.py +49 -5
  38. sglang/srt/managers/io_struct.py +36 -17
  39. sglang/srt/managers/tokenizer_manager.py +228 -125
  40. sglang/srt/memory_pool.py +32 -11
  41. sglang/srt/model_loader/model_loader.py +277 -0
  42. sglang/srt/model_loader/utils.py +260 -0
  43. sglang/srt/models/chatglm.py +1 -0
  44. sglang/srt/models/dbrx.py +1 -0
  45. sglang/srt/models/deepseek.py +430 -0
  46. sglang/srt/models/gpt_bigcode.py +282 -0
  47. sglang/srt/models/grok.py +1 -0
  48. sglang/srt/models/internlm2.py +317 -0
  49. sglang/srt/models/llama2.py +81 -23
  50. sglang/srt/models/llama_classification.py +1 -0
  51. sglang/srt/models/llava.py +1 -0
  52. sglang/srt/models/llavavid.py +1 -0
  53. sglang/srt/models/minicpm.py +1 -0
  54. sglang/srt/models/mixtral.py +1 -0
  55. sglang/srt/models/mixtral_quant.py +1 -0
  56. sglang/srt/models/qwen.py +1 -0
  57. sglang/srt/models/qwen2.py +6 -0
  58. sglang/srt/models/qwen2_moe.py +7 -4
  59. sglang/srt/models/stablelm.py +1 -0
  60. sglang/srt/openai_api/adapter.py +432 -0
  61. sglang/srt/openai_api/api_adapter.py +432 -0
  62. sglang/srt/openai_api/openai_api_adapter.py +431 -0
  63. sglang/srt/openai_api/openai_protocol.py +207 -0
  64. sglang/srt/openai_api/protocol.py +208 -0
  65. sglang/srt/openai_protocol.py +17 -0
  66. sglang/srt/sampling_params.py +2 -0
  67. sglang/srt/server.py +132 -84
  68. sglang/srt/server_args.py +35 -21
  69. sglang/srt/utils.py +65 -117
  70. sglang/test/test_conversation.py +1 -1
  71. sglang/test/test_openai_protocol.py +1 -1
  72. sglang/test/test_programs.py +1 -1
  73. sglang/test/test_utils.py +2 -2
  74. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/METADATA +162 -168
  75. sglang-0.1.24.dist-info/RECORD +105 -0
  76. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/WHEEL +1 -1
  77. sglang-0.1.21.dist-info/RECORD +0 -82
  78. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/LICENSE +0 -0
  79. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/top_level.txt +0 -0
sglang/bench_serving.py (new file)
@@ -0,0 +1,976 @@
+ # Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py
+ # Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py
+ """
+ Benchmark online serving.
+
+ Usage:
+ python3 -m sglang.bench_serving --backend sglang --num-prompt 10
+
+ python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5
+ python3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi
+ """
+
+ import argparse
+ import asyncio
+ import json
+ import os
+ import random
+ import resource
+ import sys
+ import time
+ import traceback
+ import warnings
+ from argparse import ArgumentParser as FlexibleArgumentParser
+ from dataclasses import dataclass, field
+ from datetime import datetime
+ from typing import AsyncGenerator, List, Optional, Tuple, Union
+
+ import aiohttp
+ import numpy as np
+ import requests
+ from tqdm.asyncio import tqdm
+ from transformers import (
+     AutoTokenizer,
+     PreTrainedTokenizer,
+     PreTrainedTokenizerBase,
+     PreTrainedTokenizerFast,
+ )
+
+ AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
+
+
+ @dataclass
+ class RequestFuncInput:
+     prompt: str
+     api_url: str
+     prompt_len: int
+     output_len: int
+     model: str
+
+
+ @dataclass
+ class RequestFuncOutput:
+     generated_text: str = ""
+     success: bool = False
+     latency: float = 0.0
+     ttft: float = 0.0  # Time to first token
+     itl: List[float] = field(default_factory=list)  # List of inter-token latencies
+     prompt_len: int = 0
+     error: str = ""
+     output_len: int = 0
+
+
+ def remove_prefix(text: str, prefix: str) -> str:
+     return text[len(prefix) :] if text.startswith(prefix) else text
+
+
67
+ # trt llm not support ignore_eos
68
+ # https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
69
+ async def async_request_trt_llm(
70
+ request_func_input: RequestFuncInput,
71
+ pbar: Optional[tqdm] = None,
72
+ ) -> RequestFuncOutput:
73
+ api_url = request_func_input.api_url
74
+ assert api_url.endswith("generate_stream")
75
+
76
+ async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
77
+ payload = {
78
+ "accumulate_tokens": True,
79
+ "text_input": request_func_input.prompt,
80
+ "temperature": 0.000001,
81
+ "top_p": 1.0,
82
+ "max_tokens": request_func_input.output_len,
83
+ "stream": True,
84
+ "min_length": request_func_input.output_len,
85
+ "end_id": 1048576,
86
+ }
87
+ output = RequestFuncOutput()
88
+ output.prompt_len = request_func_input.prompt_len
89
+
90
+ ttft = 0.0
91
+ st = time.perf_counter()
92
+ most_recent_timestamp = st
93
+ try:
94
+ async with session.post(url=api_url, json=payload) as response:
95
+ if response.status == 200:
96
+ async for chunk_bytes in response.content:
97
+ chunk_bytes = chunk_bytes.strip()
98
+ if not chunk_bytes:
99
+ continue
100
+
101
+ chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:")
102
+
103
+ data = json.loads(chunk)
104
+ output.generated_text += data["text_output"]
105
+ timestamp = time.perf_counter()
106
+ # First token
107
+ if ttft == 0.0:
108
+ ttft = time.perf_counter() - st
109
+ output.ttft = ttft
110
+
111
+ # Decoding phase
112
+ else:
113
+ output.itl.append(timestamp - most_recent_timestamp)
114
+
115
+ most_recent_timestamp = timestamp
116
+
117
+ output.latency = most_recent_timestamp - st
118
+ output.success = True
119
+ output.output_len = request_func_input.output_len
120
+
121
+ else:
122
+ output.error = response.reason or ""
123
+ output.success = False
124
+ except Exception:
125
+ output.success = False
126
+ exc_info = sys.exc_info()
127
+ output.error = "".join(traceback.format_exception(*exc_info))
128
+
129
+ if pbar:
130
+ pbar.update(1)
131
+ return output
132
+
+
+ # set ignore_eos True by default
+ async def async_request_openai_completions(
+     request_func_input: RequestFuncInput,
+     pbar: Optional[tqdm] = None,
+ ) -> RequestFuncOutput:
+     api_url = request_func_input.api_url
+     assert api_url.endswith(
+         "completions"
+     ), "OpenAI Completions API URL must end with 'completions'."
+
+     async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+         payload = {
+             "model": request_func_input.model,
+             "prompt": request_func_input.prompt,
+             "temperature": 0.0,
+             "best_of": 1,
+             "max_tokens": request_func_input.output_len,
+             "stream": not args.disable_stream,
+             "ignore_eos": True,
+         }
+         headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
+
+         output = RequestFuncOutput()
+         output.prompt_len = request_func_input.prompt_len
+
+         generated_text = ""
+         ttft = 0.0
+         st = time.perf_counter()
+         most_recent_timestamp = st
+         try:
+             async with session.post(
+                 url=api_url, json=payload, headers=headers
+             ) as response:
+                 if response.status == 200:
+                     async for chunk_bytes in response.content:
+                         chunk_bytes = chunk_bytes.strip()
+                         if not chunk_bytes:
+                             continue
+
+                         chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
+                         latency = time.perf_counter() - st
+                         if chunk == "[DONE]":
+                             pass
+                         else:
+                             data = json.loads(chunk)
+
+                             # NOTE: Some completion API might have a last
+                             # usage summary response without a token so we
+                             # want to check a token was generated
+                             if data["choices"][0]["text"]:
+                                 timestamp = time.perf_counter()
+                                 # First token
+                                 if ttft == 0.0:
+                                     ttft = time.perf_counter() - st
+                                     output.ttft = ttft
+
+                                 # Decoding phase
+                                 output.itl.append(timestamp - most_recent_timestamp)
+
+                                 most_recent_timestamp = timestamp
+                                 generated_text += data["choices"][0]["text"]
+
+                     output.generated_text = generated_text
+                     output.success = True
+                     output.latency = latency
+                     output.output_len = request_func_input.output_len
+                 else:
+                     output.error = response.reason or ""
+                     output.success = False
+         except Exception:
+             output.success = False
+             exc_info = sys.exc_info()
+             output.error = "".join(traceback.format_exception(*exc_info))
+
+     if pbar:
+         pbar.update(1)
+     return output
+
+
+ def get_model(pretrained_model_name_or_path: str) -> str:
+     if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true":
+         import huggingface_hub.constants
+         from modelscope import snapshot_download
+
+         model_path = snapshot_download(
+             model_id=pretrained_model_name_or_path,
+             local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
+             ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
+         )
+
+         return model_path
+     return pretrained_model_name_or_path
+
+
+ def get_tokenizer(
+     pretrained_model_name_or_path: str,
+ ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+     if pretrained_model_name_or_path is not None and not os.path.exists(
+         pretrained_model_name_or_path
+     ):
+         pretrained_model_name_or_path = get_model(pretrained_model_name_or_path)
+     return AutoTokenizer.from_pretrained(
+         pretrained_model_name_or_path, trust_remote_code=True
+     )
+
+
+ ASYNC_REQUEST_FUNCS = {
+     "sglang": async_request_openai_completions,
+     "vllm": async_request_openai_completions,
+     "lmdeploy": async_request_openai_completions,
+     "trt": async_request_trt_llm,
+ }
+
+
+ @dataclass
+ class BenchmarkMetrics:
+     completed: int
+     total_input: int
+     total_output: int
+     total_output_retokenized: int
+     request_throughput: float
+     input_throughput: float
+     output_throughput: float
+     output_throughput_retokenized: float
+     mean_ttft_ms: float
+     median_ttft_ms: float
+     std_ttft_ms: float
+     p99_ttft_ms: float
+     mean_tpot_ms: float
+     median_tpot_ms: float
+     std_tpot_ms: float
+     p99_tpot_ms: float
+     mean_itl_ms: float
+     median_itl_ms: float
+     std_itl_ms: float
+     p99_itl_ms: float
+     mean_e2e_latency_ms: float
+     median_e2e_latency_ms: float
+
+
+ default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json"
+
+
+ def download_sharegpt_dataset(path):
+     url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
+
+     print(f"Downloading dataset from {url}")
+     try:
+         response = requests.get(url, stream=True)
+         response.raise_for_status()
+
+         total_size = int(response.headers.get("content-length", 0))
+         block_size = 8192
+
+         with open(path, "wb") as f, tqdm(
+             desc="Downloading",
+             total=total_size,
+             unit="iB",
+             unit_scale=True,
+             unit_divisor=1024,
+         ) as progress_bar:
+             for data in response.iter_content(block_size):
+                 size = f.write(data)
+                 progress_bar.update(size)
+
+         print(f"Dataset downloaded and saved to {path}")
+     except requests.RequestException as e:
+         raise Exception(f"Failed to download dataset: {e}")
+
+
+ def sample_sharegpt_requests(
+     dataset_path: str,
+     num_requests: int,
+     tokenizer: PreTrainedTokenizerBase,
+     fixed_output_len: Optional[int] = None,
+ ) -> List[Tuple[str, int, int]]:
+     if fixed_output_len is not None and fixed_output_len < 4:
+         raise ValueError("output_len too small")
+
+     # Download sharegpt if necessary
+     if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path):
+         download_sharegpt_dataset(default_sharegpt_path)
+         dataset_path = default_sharegpt_path
+     else:
+         dataset_path = (
+             dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
+         )
+
+     # Load the dataset.
+     with open(dataset_path) as f:
+         dataset = json.load(f)
+     # Filter out the conversations with less than 2 turns.
+     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
+     # Only keep the first two turns of each conversation.
+     dataset = [
+         (data["conversations"][0]["value"], data["conversations"][1]["value"])
+         for data in dataset
+     ]
+
+     # Shuffle the dataset.
+     random.shuffle(dataset)
+
+     # Filter out sequences that are too long or too short
+     filtered_dataset: List[Tuple[str, int, int]] = []
+     for i in range(len(dataset)):
+         if len(filtered_dataset) == num_requests:
+             break
+
+         # Tokenize the prompts and completions.
+         prompt = dataset[i][0]
+         prompt_token_ids = tokenizer(prompt).input_ids
+         completion = dataset[i][1]
+         completion_token_ids = tokenizer(completion).input_ids
+         prompt_len = len(prompt_token_ids)
+         output_len = (
+             len(completion_token_ids) if fixed_output_len is None else fixed_output_len
+         )
+         if prompt_len < 4 or output_len < 4:
+             # Prune too short sequences.
+             continue
+         if prompt_len > 1024 or prompt_len + output_len > 2048:
+             # Prune too long sequences.
+             continue
+         filtered_dataset.append((prompt, prompt_len, output_len))
+
+     return filtered_dataset
+
+
+ def sample_random_requests(
+     input_len: int,
+     output_len: int,
+     num_prompts: int,
+     range_ratio: float,
+     tokenizer: PreTrainedTokenizerBase,
+     dataset_path: str,
+ ) -> List[Tuple[str, int, int]]:
+
+     input_lens = np.random.randint(
+         int(input_len * range_ratio),
+         input_len + 1,
+         size=num_prompts,
+     )
+     output_lens = np.random.randint(
+         int(output_len * range_ratio),
+         output_len + 1,
+         size=num_prompts,
+     )
+
+     if True:
+         # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
+
+         # Download sharegpt if necessary
+         if not os.path.isfile(dataset_path) and not os.path.isfile(
+             default_sharegpt_path
+         ):
+             download_sharegpt_dataset(default_sharegpt_path)
+             dataset_path = default_sharegpt_path
+         else:
+             dataset_path = (
+                 dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
+             )
+
+         # Load the dataset.
+         with open(dataset_path) as f:
+             dataset = json.load(f)
+         # Filter out the conversations with less than 2 turns.
+         dataset = [data for data in dataset if len(data["conversations"]) >= 2]
+         # Only keep the first two turns of each conversation.
+         dataset = [
+             (data["conversations"][0]["value"], data["conversations"][1]["value"])
+             for data in dataset
+         ]
+
+         # Shuffle the dataset.
+         random.shuffle(dataset)
+
+         # Filter out sequences that are too long or too short
+         input_requests: List[Tuple[str, int, int]] = []
+         for i in range(num_prompts):
+             # Tokenize the prompts and completions.
+             prompt = dataset[i][0]
+             prompt_token_ids = tokenizer(prompt).input_ids
+             prompt_len = len(prompt_token_ids)
+
+             if prompt_len <= input_lens[i]:
+                 input_ids = prompt_token_ids[: input_lens[i]]
+             else:
+                 ratio = (input_lens[i] + prompt_len - 1) // prompt_len
+                 input_ids = (prompt_token_ids * ratio)[: input_lens[i]]
+             prompt = tokenizer.decode(input_ids)
+             input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
+     else:
+         # Sample token ids from random integers. This can cause some NaN issues.
+         offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
+         input_requests = []
+         for i in range(num_prompts):
+             prompt = tokenizer.decode(
+                 [
+                     (offsets[i] + i + j) % tokenizer.vocab_size
+                     for j in range(input_lens[i])
+                 ]
+             )
+             input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
+
+     print(f"#Input tokens: {np.sum(input_lens)}")
+     print(f"#Output tokens: {np.sum(output_lens)}")
+     return input_requests
+
+
+ async def get_request(
+     input_requests: List[Tuple[str, int, int]],
+     request_rate: float,
+ ) -> AsyncGenerator[Tuple[str, int, int], None]:
+     input_requests = iter(input_requests)
+     for request in input_requests:
+         yield request
+
+         if request_rate == float("inf"):
+             # If the request rate is infinity, then we don't need to wait.
+             continue
+
+         # Sample the request interval from the exponential distribution.
+         interval = np.random.exponential(1.0 / request_rate)
+         # The next request will be sent after the interval.
+         await asyncio.sleep(interval)
+
+
+ def calculate_metrics(
+     input_requests: List[Tuple[str, int, int]],
+     outputs: List[RequestFuncOutput],
+     dur_s: float,
+     tokenizer: PreTrainedTokenizerBase,
+     backend: str,
+ ) -> Tuple[BenchmarkMetrics, List[int]]:
+     output_lens: List[int] = []
+     retokenized_output_lens: List[int] = []
+     total_input = 0
+     completed = 0
+     itls: List[float] = []
+     tpots: List[float] = []
+     ttfts: List[float] = []
+     e2e_latencies: List[float] = []
+     for i in range(len(outputs)):
+         if outputs[i].success:
+             output_len = outputs[i].output_len
+             output_lens.append(output_len)
+             retokenized_output_len = len(
+                 tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids
+             )
+             retokenized_output_lens.append(retokenized_output_len)
+             total_input += input_requests[i][1]
+             if output_len > 1:
+                 tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
+             itls += outputs[i].itl
+             ttfts.append(outputs[i].ttft)
+
+             e2e_latencies.append(outputs[i].latency)
+
+             completed += 1
+         else:
+             output_lens.append(0)
+             retokenized_output_lens.append(0)
+
+     if completed == 0:
+         warnings.warn(
+             "All requests failed. This is likely due to a misconfiguration "
+             "on the benchmark arguments.",
+             stacklevel=2,
+         )
+     metrics = BenchmarkMetrics(
+         completed=completed,
+         total_input=total_input,
+         total_output=sum(output_lens),
+         total_output_retokenized=sum(retokenized_output_lens),
+         request_throughput=completed / dur_s,
+         input_throughput=total_input / dur_s,
+         output_throughput=sum(output_lens) / dur_s,
+         output_throughput_retokenized=sum(retokenized_output_lens) / dur_s,
+         mean_ttft_ms=np.mean(ttfts or 0)
+         * 1000,  # ttfts is empty if streaming is not supported by backend
+         median_ttft_ms=np.median(ttfts or 0) * 1000,
+         std_ttft_ms=np.std(ttfts or 0) * 1000,
+         p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
+         mean_tpot_ms=np.mean(tpots or 0) * 1000,
+         median_tpot_ms=np.median(tpots or 0) * 1000,
+         std_tpot_ms=np.std(tpots or 0) * 1000,
+         p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
+         mean_itl_ms=np.mean(itls or 0) * 1000,
+         median_itl_ms=np.median(itls or 0) * 1000,
+         std_itl_ms=np.std(itls or 0) * 1000,
+         p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
+         mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000,
+         median_e2e_latency_ms=np.median(e2e_latencies) * 1000,
+     )
+
+     return metrics, output_lens
+
+
+ async def benchmark(
+     backend: str,
+     api_url: str,
+     model_id: str,
+     tokenizer: PreTrainedTokenizerBase,
+     input_requests: List[Tuple[str, int, int]],
+     request_rate: float,
+     disable_tqdm: bool,
+     enable_multi: bool,
+ ):
+     if backend in ASYNC_REQUEST_FUNCS:
+         request_func = ASYNC_REQUEST_FUNCS[backend]
+     else:
+         raise ValueError(f"Unknown backend: {backend}")
+
+     print("Starting initial single prompt test run...")
+     test_prompt, test_prompt_len, test_output_len = input_requests[0]
+     test_input = RequestFuncInput(
+         model=model_id,
+         prompt=test_prompt,
+         api_url=api_url,
+         prompt_len=test_prompt_len,
+         output_len=test_output_len,
+     )
+     test_output = await request_func(request_func_input=test_input)
+     if not test_output.success:
+         raise ValueError(
+             "Initial test run failed - Please make sure benchmark arguments "
+             f"are correctly specified. Error: {test_output.error}"
+         )
+     else:
+         print("Initial test run completed. Starting main benchmark run...")
+
+     pbar = None if disable_tqdm else tqdm(total=len(input_requests))
+
+     benchmark_start_time = time.perf_counter()
+     tasks: List[asyncio.Task] = []
+     async for request in get_request(input_requests, request_rate):
+         prompt, prompt_len, output_len = request
+         request_func_input = RequestFuncInput(
+             model=model_id,
+             prompt=prompt,
+             api_url=api_url,
+             prompt_len=prompt_len,
+             output_len=output_len,
+         )
+         tasks.append(
+             asyncio.create_task(
+                 request_func(request_func_input=request_func_input, pbar=pbar)
+             )
+         )
+     outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
+
+     if pbar is not None:
+         pbar.close()
+
+     benchmark_duration = time.perf_counter() - benchmark_start_time
+
+     metrics, output_lens = calculate_metrics(
+         input_requests=input_requests,
+         outputs=outputs,
+         dur_s=benchmark_duration,
+         tokenizer=tokenizer,
+         backend=backend,
+     )
+
+     print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
+     print("{:<40} {:<10}".format("Backend:", backend))
+     print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
+     print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
+     print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
+     print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
+     print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
+     print(
+         "{:<40} {:<10}".format(
+             "Total generated tokens (retokenized):", metrics.total_output_retokenized
+         )
+     )
+     print(
+         "{:<40} {:<10.2f}".format(
+             "Request throughput (req/s):", metrics.request_throughput
+         )
+     )
+     print(
+         "{:<40} {:<10.2f}".format(
+             "Input token throughput (tok/s):", metrics.input_throughput
+         )
+     )
+     print(
+         "{:<40} {:<10.2f}".format(
+             "Output token throughput (tok/s):", metrics.output_throughput
+         )
+     )
+     print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
+     print(
+         "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
+     )
+     print(
+         "{:<40} {:<10.2f}".format(
+             "Median E2E Latency (ms):", metrics.median_e2e_latency_ms
+         )
+     )
+     print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-"))
+     print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
+     print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
+     print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
+     print(
+         "{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-")
+     )
+     print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
+     print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
+     print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
+     print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-"))
+     print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
+     print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
+     print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
+     print("=" * 50)
+
+     if (
+         metrics.median_ttft_ms is not None
+         and metrics.mean_itl_ms is not None
+         and metrics.output_throughput is not None
+     ):
+         result = {
+             "backend": args.backend,
+             "dataset_name": args.dataset_name,
+             "request_rate": request_rate,
+             "total_input": metrics.total_input,
+             "total_output": metrics.total_output,
+             "total_output_retokenized": metrics.total_output_retokenized,
+             "mean_e2e_latency": metrics.mean_e2e_latency_ms,
+             "median_e2e_latency": metrics.median_e2e_latency_ms,
+             "median_ttft": metrics.median_ttft_ms,
+             "median_itl": metrics.median_itl_ms,
+             "output_token_throughput": metrics.output_throughput,
+             "sharegpt_output_len": args.sharegpt_output_len,
+             "random_input_len": args.random_input_len,
+             "random_output_len": args.random_output_len,
+             "random_range_ratio": args.random_range_ratio,
+             "benchmark_duration": benchmark_duration,
+         }
+     else:
+         print(f"Error running benchmark for request rate: {request_rate}")
+         print("-" * 30)
+
+     # Determine output file name
+     if args.output_file:
+         output_file_name = args.output_file
+     else:
+         now = datetime.now().strftime("%m%d")
+         if args.dataset_name == "random":
+             output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
+         else:
+             output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl"
+
+     # Append results to a JSONL file
+     with open(output_file_name, "a") as file:
+         file.write(json.dumps(result) + "\n")
+
+     result = {
+         "duration": benchmark_duration,
+         "completed": metrics.completed,
+         "total_input_tokens": metrics.total_input,
+         "total_output_tokens": metrics.total_output,
+         "total_output_tokens_retokenized": metrics.total_output_retokenized,
+         "request_throughput": metrics.request_throughput,
+         "input_throughput": metrics.input_throughput,
+         "output_throughput": metrics.output_throughput,
+         "mean_ttft_ms": metrics.mean_ttft_ms,
+         "median_ttft_ms": metrics.median_ttft_ms,
+         "std_ttft_ms": metrics.std_ttft_ms,
+         "p99_ttft_ms": metrics.p99_ttft_ms,
+         "mean_tpot_ms": metrics.mean_tpot_ms,
+         "median_tpot_ms": metrics.median_tpot_ms,
+         "std_tpot_ms": metrics.std_tpot_ms,
+         "p99_tpot_ms": metrics.p99_tpot_ms,
+         "mean_itl_ms": metrics.mean_itl_ms,
+         "median_itl_ms": metrics.median_itl_ms,
+         "std_itl_ms": metrics.std_itl_ms,
+         "p99_itl_ms": metrics.p99_itl_ms,
+         "input_lens": [output.prompt_len for output in outputs],
+         "output_lens": output_lens,
+         "ttfts": [output.ttft for output in outputs],
+         "itls": [output.itl for output in outputs],
+         "generated_texts": [output.generated_text for output in outputs],
+         "errors": [output.error for output in outputs],
+         "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
+         "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
+     }
+     return result
+
+
+ def parse_request_rate_range(request_rate_range):
+     if len(request_rate_range.split(",")) == 3:
+         start, stop, step = map(int, request_rate_range.split(","))
+         return list(range(start, stop, step))
+     else:
+         return list(map(int, request_rate_range.split(",")))
+
+
+ def check_chat_template(model_path):
+     try:
+         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+         return "chat_template" in tokenizer.init_kwargs
+     except Exception as e:
+         print(f"Fail to load tokenizer config with error={e}")
+         return False
+
+
+ def fire(args: argparse.Namespace):
+     random.seed(args.seed)
+     np.random.seed(args.seed)
+
+     if args.port is None:
+         args.port = {
+             "sglang": 30000,
+             "lmdeploy": 23333,
+             "vllm": 8000,
+             "trt": 8000,
+         }.get(args.backend, 30000)
+
+     api_url = (
+         f"{args.base_url}/v1/completions"
+         if args.base_url
+         else f"http://{args.host}:{args.port}/v1/completions"
+     )
+     model_url = (
+         f"{args.base_url}/v1/models"
+         if args.base_url
+         else f"http://{args.host}:{args.port}/v1/models"
+     )
+
+     if args.backend == "trt":
+         api_url = (
+             f"{args.base_url}/v2/models/ensemble/generate_stream"
+             if args.base_url
+             else f"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream"
+         )
+         if args.model is None:
+             print("Please provide a model using `--model` when using `trt` backend.")
+             sys.exit(1)
+
+     if args.model is None:
+         try:
+             response = requests.get(model_url)
+             model_list = response.json().get("data", [])
+             args.model = model_list[0]["id"] if model_list else None
+         except Exception as e:
+             print(f"Failed to fetch model from {model_url}. Error: {e}")
+             print(
+                 "Please specify the correct host and port using `--host` and `--port`."
+             )
+             sys.exit(1)
+
+     if args.model is None:
+         print("No model specified or found. Please provide a model using `--model`.")
+         sys.exit(1)
+
+     if not check_chat_template(args.model):
+         print(
+             "\nWARNING It is recommended to use the `Chat` or `Instruct` model for benchmarking.\n"
+             "Because when the tokenizer counts the output tokens, if there is gibberish, it might count incorrectly.\n"
+         )
+
+     print(f"{args}\n")
+
+     backend = args.backend
+     model_id = args.model
+     tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
+
+     tokenizer = get_tokenizer(tokenizer_id)
+
+     if args.dataset_name == "sharegpt":
+         input_requests = sample_sharegpt_requests(
+             dataset_path=args.dataset_path,
+             num_requests=args.num_prompts,
+             tokenizer=tokenizer,
+             fixed_output_len=args.sharegpt_output_len,
+         )
+     elif args.dataset_name == "random":
+         input_requests = sample_random_requests(
+             input_len=args.random_input_len,
+             output_len=args.random_output_len,
+             num_prompts=args.num_prompts,
+             range_ratio=args.random_range_ratio,
+             tokenizer=tokenizer,
+             dataset_path=args.dataset_path,
+         )
+     else:
+         raise ValueError(f"Unknown dataset: {args.dataset_name}")
+
+     if args.multi:
+         request_rates = parse_request_rate_range(args.request_rate_range)
+
+         for rate in request_rates:
+             asyncio.run(
+                 benchmark(
+                     backend=backend,
+                     api_url=api_url,
+                     model_id=model_id,
+                     tokenizer=tokenizer,
+                     input_requests=input_requests,
+                     request_rate=rate,
+                     disable_tqdm=args.disable_tqdm,
+                     enable_multi=args.multi,
+                 )
+             )
+     else:
+         asyncio.run(
+             benchmark(
+                 backend=backend,
+                 api_url=api_url,
+                 model_id=model_id,
+                 tokenizer=tokenizer,
+                 input_requests=input_requests,
+                 request_rate=args.request_rate,
+                 disable_tqdm=args.disable_tqdm,
+                 enable_multi=args.multi,
+             )
+         )
+
+
+ # to avoid relying on SGLang's components
+ def set_ulimit(target_soft_limit=65535):
+     resource_type = resource.RLIMIT_NOFILE
+     current_soft, current_hard = resource.getrlimit(resource_type)
+
+     if current_soft < target_soft_limit:
+         try:
+             resource.setrlimit(resource_type, (target_soft_limit, current_hard))
+         except ValueError as e:
+             print(f"Fail to set RLIMIT_NOFILE: {e}")
+
+
+ if __name__ == "__main__":
+     parser = FlexibleArgumentParser(
+         description="Benchmark the online serving throughput."
+     )
+     parser.add_argument(
+         "--backend",
+         type=str,
+         required=True,
+         choices=list(ASYNC_REQUEST_FUNCS.keys()),
+         help="Must specify a backend, depending on the LLM Inference Engine.",
+     )
+     parser.add_argument(
+         "--base-url",
+         type=str,
+         default=None,
+         help="Server or API base url if not using http host and port.",
+     )
+     parser.add_argument(
+         "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
+     )
+     parser.add_argument(
+         "--port",
+         type=int,
+         help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
+     )
+     parser.add_argument(
+         "--dataset-name",
+         type=str,
+         default="sharegpt",
+         choices=["sharegpt", "random"],
+         help="Name of the dataset to benchmark on.",
+     )
+     parser.add_argument(
+         "--dataset-path", type=str, default="", help="Path to the dataset."
+     )
+     parser.add_argument(
+         "--model",
+         type=str,
+         help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
+     )
+     parser.add_argument(
+         "--tokenizer",
+         type=str,
+         help="Name or path of the tokenizer. If not set, using the model conf.",
+     )
+     parser.add_argument(
+         "--num-prompts",
+         type=int,
+         default=1000,
+         help="Number of prompts to process. Default is 1000.",
+     )
+     parser.add_argument(
+         "--sharegpt-output-len",
+         type=int,
+         default=None,
+         help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
+     )
+     parser.add_argument(
+         "--random-input-len",
+         type=int,
+         default=1024,
+         help="Number of input tokens per request, used only for random dataset.",
+     )
+     parser.add_argument(
+         "--random-output-len",
+         type=int,
+         default=128,
+         help="Number of output tokens per request, used only for random dataset.",
+     )
+     parser.add_argument(
+         "--random-range-ratio",
+         type=float,
+         default=1.0,
+         help="Range of sampled ratio of input/output length, "
+         "used only for random dataset.",
+     )
+     parser.add_argument(
+         "--request-rate",
+         type=float,
+         default=float("inf"),
+         help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
+         "Otherwise, we use Poisson process to synthesize the request arrival times. Default is 128.0.",
+     )
+     parser.add_argument("--seed", type=int, default=0, help="Default is 0.")
+     parser.add_argument(
+         "--disable-tqdm",
+         action="store_true",
+         help="Specify to disable tqdm progress bar.",
+     )
+     parser.add_argument(
+         "--multi",
+         action="store_true",
+         help="Use request rate range rather than single value.",
+     )
+     parser.add_argument(
+         "--request-rate-range",
+         type=str,
+         default="2,34,2",
+         help="Range of request rates in the format start,stop,step. Default is 2,34,2. It also supports a list of request rates, requiring the parameters to not equal three.",
+     )
+     parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
+     parser.add_argument(
+         "--disable-stream",
+         action="store_true",
+         help="Disable streaming mode.",
+     )
+
+     set_ulimit()
+
+     args = parser.parse_args()
+     fire(args)
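
Note on the output format: each benchmark run above appends one JSON object per request rate to a .jsonl file, using the keys of the per-rate result dict built inside benchmark() (request_rate, median_ttft, median_itl, output_token_throughput, and so on). As a minimal sketch only, not part of the packaged code and using a hypothetical file name, those records could be read back like this:

import json

# Hypothetical file name; substitute your --output-file or the
# auto-generated "<backend>_<date>_<num_prompts>_..." name from the script.
results_path = "sglang_0801_1000_sharegpt.jsonl"

with open(results_path) as f:
    # One JSON object per line, as written by benchmark() in append mode.
    records = [json.loads(line) for line in f if line.strip()]

for r in records:
    # Key names mirror the per-rate record written to the JSONL file.
    print(
        f"request_rate={r['request_rate']} "
        f"median_ttft_ms={r['median_ttft']:.2f} "
        f"median_itl_ms={r['median_itl']:.2f} "
        f"output_tok_per_s={r['output_token_throughput']:.2f}"
    )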