sglang 0.3.5.post1__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. sglang/bench_latency.py +1 -553
  2. sglang/bench_offline_throughput.py +337 -0
  3. sglang/bench_one_batch.py +474 -0
  4. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  5. sglang/bench_serving.py +115 -31
  6. sglang/check_env.py +3 -6
  7. sglang/srt/constrained/base_grammar_backend.py +4 -3
  8. sglang/srt/constrained/outlines_backend.py +39 -26
  9. sglang/srt/constrained/xgrammar_backend.py +58 -14
  10. sglang/srt/layers/activation.py +3 -0
  11. sglang/srt/layers/attention/flashinfer_backend.py +93 -48
  12. sglang/srt/layers/attention/triton_backend.py +9 -7
  13. sglang/srt/layers/custom_op_util.py +26 -0
  14. sglang/srt/layers/fused_moe/fused_moe.py +11 -4
  15. sglang/srt/layers/fused_moe/patch.py +4 -2
  16. sglang/srt/layers/layernorm.py +4 -0
  17. sglang/srt/layers/logits_processor.py +10 -10
  18. sglang/srt/layers/sampler.py +4 -8
  19. sglang/srt/layers/torchao_utils.py +2 -0
  20. sglang/srt/managers/data_parallel_controller.py +74 -9
  21. sglang/srt/managers/detokenizer_manager.py +1 -14
  22. sglang/srt/managers/io_struct.py +27 -0
  23. sglang/srt/managers/schedule_batch.py +104 -38
  24. sglang/srt/managers/schedule_policy.py +5 -1
  25. sglang/srt/managers/scheduler.py +210 -56
  26. sglang/srt/managers/session_controller.py +62 -0
  27. sglang/srt/managers/tokenizer_manager.py +38 -0
  28. sglang/srt/managers/tp_worker.py +12 -1
  29. sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
  30. sglang/srt/model_executor/cuda_graph_runner.py +43 -6
  31. sglang/srt/model_executor/forward_batch_info.py +109 -15
  32. sglang/srt/model_executor/model_runner.py +102 -43
  33. sglang/srt/model_parallel.py +98 -0
  34. sglang/srt/models/deepseek_v2.py +147 -44
  35. sglang/srt/models/gemma2.py +9 -8
  36. sglang/srt/models/llava.py +1 -1
  37. sglang/srt/models/llavavid.py +1 -1
  38. sglang/srt/models/olmo.py +3 -3
  39. sglang/srt/models/phi3_small.py +447 -0
  40. sglang/srt/models/qwen2_vl.py +13 -6
  41. sglang/srt/models/torch_native_llama.py +94 -78
  42. sglang/srt/openai_api/adapter.py +11 -4
  43. sglang/srt/openai_api/protocol.py +30 -27
  44. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  45. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  46. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  47. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  48. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  49. sglang/srt/sampling/sampling_batch_info.py +58 -57
  50. sglang/srt/sampling/sampling_params.py +3 -3
  51. sglang/srt/server.py +29 -2
  52. sglang/srt/server_args.py +97 -60
  53. sglang/srt/utils.py +103 -51
  54. sglang/test/runners.py +25 -6
  55. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  56. sglang/test/test_utils.py +33 -22
  57. sglang/version.py +1 -1
  58. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
  59. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/RECORD +62 -56
  60. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
  61. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
  62. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
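The renames in this list change the benchmark entry points: sglang.bench_latency becomes sglang.bench_one_batch (a single static batch, no server) and sglang.bench_server_latency becomes sglang.bench_one_batch_server (launches a server and uses the HTTP interface). Assuming the CLI flags carry over unchanged from the docstrings shown in the hunks below, the new invocations look like:

python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8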
@@ -0,0 +1,474 @@
+"""
+Benchmark the latency of running a single static batch without a server.
+
+This script does not launch a server and uses the low-level APIs.
+It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).
+
+# Usage (latency test)
+## with dummy weights:
+python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
+## sweep through multiple data points and store (append) the results in a jsonl file:
+python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run
+
+# Usage (correctness test):
+python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
+
+## Reference output (of the correctness test above, can be gpu dependent):
+input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]]
+
+prefill logits (first half): tensor([[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
+[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
+[ -9.1875, -10.2500, 2.7129, ..., -4.3359, -4.0664, -4.1328]],
+device='cuda:0')
+
+prefill logits (final): tensor([[-8.3125, -7.1172, 3.3457, ..., -4.9570, -4.1328, -3.4141],
+[-8.9141, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0781],
+[-9.6328, -9.0547, 4.0195, ..., -5.3047, -4.7148, -4.4570]],
+device='cuda:0')
+
+========== Prompt 0 ==========
+<s> The capital of France is Paris.
+The capital of the United States is Washington, D.C.
+
+
+========== Prompt 1 ==========
+<s> The capital of the United Kindom is London.
+The capital of the United Kingdom is London.
+The capital of the
+
+========== Prompt 2 ==========
+<s> Today is a sunny day and I like to go for a walk in the park.
+I'm going to the park
+"""
+
+import argparse
+import dataclasses
+import itertools
+import json
+import logging
+import multiprocessing
+import time
+from typing import Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.model_runner import ModelRunner
+from sglang.srt.sampling.sampling_params import SamplingParams
+from sglang.srt.server import _set_envs_and_config
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import (
+    configure_logger,
+    kill_child_process,
+    suppress_other_loggers,
+)
+
+
+@dataclasses.dataclass
+class BenchArgs:
+    run_name: str = "default"
+    batch_size: Tuple[int] = (1,)
+    input_len: Tuple[int] = (1024,)
+    output_len: Tuple[int] = (16,)
+    result_filename: str = "result.jsonl"
+    correctness_test: bool = False
+    # This is only used for correctness test
+    cut_len: int = 4
+
+    @staticmethod
+    def add_cli_args(parser: argparse.ArgumentParser):
+        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
+        parser.add_argument(
+            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
+        )
+        parser.add_argument(
+            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
+        )
+        parser.add_argument(
+            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
+        )
+        parser.add_argument(
+            "--result-filename", type=str, default=BenchArgs.result_filename
+        )
+        parser.add_argument("--correctness-test", action="store_true")
+        parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace):
+        # use the default value's type to case the args into correct types.
+        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
+        return cls(
+            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
+        )
+
+
+def load_model(server_args, port_args, tp_rank):
+    suppress_other_loggers()
+    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+    model_config = ModelConfig(
+        server_args.model_path,
+        trust_remote_code=server_args.trust_remote_code,
+        context_length=server_args.context_length,
+        model_override_args=server_args.json_model_override_args,
+    )
+    model_runner = ModelRunner(
+        model_config=model_config,
+        mem_fraction_static=server_args.mem_fraction_static,
+        gpu_id=tp_rank,
+        tp_rank=tp_rank,
+        tp_size=server_args.tp_size,
+        nccl_port=port_args.nccl_port,
+        server_args=server_args,
+    )
+    rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
+    tokenizer = get_tokenizer(
+        server_args.tokenizer_path,
+        tokenizer_mode=server_args.tokenizer_mode,
+        trust_remote_code=server_args.trust_remote_code,
+    )
+    if server_args.tp_size > 1:
+        dist.barrier()
+    return model_runner, tokenizer
+
+
+def prepare_inputs_for_correctness_test(bench_args, tokenizer):
+    prompts = [
+        "The capital of France is",
+        "The capital of the United Kindom is",
+        "Today is a sunny day and I like",
+    ]
+    input_ids = [tokenizer.encode(p) for p in prompts]
+    sampling_params = SamplingParams(
+        temperature=0,
+        max_new_tokens=BenchArgs.output_len,
+    )
+
+    reqs = []
+    for i in range(len(prompts)):
+        assert len(input_ids[i]) > bench_args.cut_len
+
+        tmp_input_ids = input_ids[i][: bench_args.cut_len]
+        req = Req(
+            rid=i,
+            origin_input_text=prompts[i],
+            origin_input_ids=tmp_input_ids,
+            sampling_params=sampling_params,
+        )
+        req.prefix_indices = []
+        req.fill_ids = req.origin_input_ids
+        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        reqs.append(req)
+
+    return input_ids, reqs
+
+
+def prepare_extend_inputs_for_correctness_test(
+    bench_args, input_ids, reqs, model_runner
+):
+    for i in range(len(reqs)):
+        req = reqs[i]
+        req.fill_ids += input_ids[i][bench_args.cut_len :]
+        req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
+            i, : bench_args.cut_len
+        ]
+        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+    return reqs
+
+
+def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
+    input_ids = np.ones((batch_size, input_len), dtype=np.int32)
+    sampling_params = SamplingParams(
+        temperature=0,
+        max_new_tokens=BenchArgs.output_len,
+    )
+
+    reqs = []
+    for i in range(len(input_ids)):
+        req = Req(
+            rid=i,
+            origin_input_text="",
+            origin_input_ids=list(input_ids[i]),
+            sampling_params=sampling_params,
+        )
+        req.prefix_indices = []
+        req.fill_ids = req.origin_input_ids
+        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        reqs.append(req)
+
+    return reqs
+
+
+@torch.no_grad
+def extend(reqs, model_runner):
+    batch = ScheduleBatch.init_new(
+        reqs=reqs,
+        req_to_token_pool=model_runner.req_to_token_pool,
+        token_to_kv_pool=model_runner.token_to_kv_pool,
+        tree_cache=None,
+        model_config=model_runner.model_config,
+    )
+    batch.prepare_for_extend()
+    model_worker_batch = batch.get_model_worker_batch()
+    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
+    logits_output = model_runner.forward(forward_batch)
+    next_token_ids = model_runner.sample(logits_output, forward_batch)
+    return next_token_ids, logits_output.next_token_logits, batch
+
+
+@torch.no_grad
+def decode(input_token_ids, batch, model_runner):
+    batch.output_ids = input_token_ids
+    batch.prepare_for_decode()
+    model_worker_batch = batch.get_model_worker_batch()
+    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
+    logits_output = model_runner.forward(forward_batch)
+    next_token_ids = model_runner.sample(logits_output, forward_batch)
+    return next_token_ids, logits_output.next_token_logits
+
+
+def correctness_test(
+    server_args,
+    port_args,
+    bench_args,
+    tp_rank,
+):
+    # Configure the logger
+    configure_logger(server_args, prefix=f" TP{tp_rank}")
+    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+    # Load the model
+    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
+
+    # Prepare inputs
+    input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
+    rank_print(f"\n{input_ids=}\n")
+
+    if bench_args.cut_len > 0:
+        # Prefill
+        next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
+        rank_print(f"prefill logits (first half): {next_token_logits} \n")
+
+    # Prepare extend inputs
+    reqs = prepare_extend_inputs_for_correctness_test(
+        bench_args, input_ids, reqs, model_runner
+    )
+
+    # Extend (prefill w/ KV cache)
+    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
+    rank_print(f"prefill logits (final): {next_token_logits} \n")
+
+    # Decode
+    output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
+    for _ in range(bench_args.output_len[0] - 1):
+        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+        next_token_ids_list = next_token_ids.tolist()
+        for i in range(len(reqs)):
+            output_ids[i].append(next_token_ids_list[i])
+
+    # Print output texts
+    for i in range(len(reqs)):
+        rank_print(f"========== Prompt {i} ==========")
+        rank_print(tokenizer.decode(output_ids[i]), "\n")
+
+
+def synchronize(device):
+    if device == "cuda":
+        torch.cuda.synchronize()
+    elif device == "xpu":
+        torch.xpu.synchronize()
+
+
+def latency_test_run_once(
+    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device
+):
+    max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
+    if batch_size > max_batch_size:
+        rank_print(
+            f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
+        )
+        return
+
+    # Clear the pools.
+    model_runner.req_to_token_pool.clear()
+    model_runner.token_to_kv_pool.clear()
+
+    measurement_results = {
+        "run_name": run_name,
+        "batch_size": batch_size,
+        "input_len": input_len,
+        "output_len": output_len,
+    }
+
+    tot_latency = 0
+
+    # Prefill
+    synchronize(device)
+    tic = time.time()
+    next_token_ids, _, batch = extend(reqs, model_runner)
+    synchronize(device)
+    prefill_latency = time.time() - tic
+    tot_latency += prefill_latency
+    throughput = input_len * batch_size / prefill_latency
+    rank_print(
+        f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+    )
+    measurement_results["prefill_latency"] = prefill_latency
+    measurement_results["prefill_throughput"] = throughput
+
+    # Decode
+    decode_latencies = []
+    for i in range(output_len - 1):
+        synchronize(device)
+        tic = time.time()
+        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+        synchronize(device)
+        latency = time.time() - tic
+        tot_latency += latency
+        throughput = batch_size / latency
+        decode_latencies.append(latency)
+        if i < 5:
+            rank_print(
+                f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+            )
+
+    # Record decode timing from 2nd output
+    if output_len > 1:
+        med_decode_latency = np.median(decode_latencies)
+        med_decode_throughput = batch_size / med_decode_latency
+        rank_print(
+            f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
+        )
+        measurement_results["median_decode_latency"] = med_decode_latency
+        measurement_results["median_decode_throughput"] = med_decode_throughput
+
+    throughput = (input_len + output_len) * batch_size / tot_latency
+    rank_print(
+        f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
+    )
+    measurement_results["total_latency"] = tot_latency
+    measurement_results["overall_throughput"] = throughput
+    return measurement_results
+
+
+def latency_test(
+    server_args,
+    port_args,
+    bench_args,
+    tp_rank,
+):
+    # Configure the logger
+    configure_logger(server_args, prefix=f" TP{tp_rank}")
+    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+    # Load the model
+    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
+
+    # Prepare inputs for warm up
+    reqs = prepare_synthetic_inputs_for_latency_test(
+        bench_args.batch_size[0], bench_args.input_len[0]
+    )
+
+    # Warm up
+    rank_print("Warmup ...")
+    latency_test_run_once(
+        bench_args.run_name,
+        model_runner,
+        rank_print,
+        reqs,
+        bench_args.batch_size[0],
+        bench_args.input_len[0],
+        8,  # shorter decoding to speed up the warmup
+        server_args.device,
+    )
+    rank_print("Benchmark ...")
+
+    # Run the sweep
+    result_list = []
+    for bs, il, ol in itertools.product(
+        bench_args.batch_size, bench_args.input_len, bench_args.output_len
+    ):
+        reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
+        ret = latency_test_run_once(
+            bench_args.run_name,
+            model_runner,
+            rank_print,
+            reqs,
+            bs,
+            il,
+            ol,
+            server_args.device,
+        )
+        if ret is not None:
+            result_list.append(ret)
+
+    # Write results in jsonlines format on rank 0.
+    if tp_rank == 0 and bench_args.result_filename:
+        with open(bench_args.result_filename, "a") as fout:
+            for result in result_list:
+                fout.write(json.dumps(result) + "\n")
+
+
+def main(server_args, bench_args):
+    _set_envs_and_config(server_args)
+
+    if server_args.model_path:
+        if bench_args.correctness_test:
+            work_func = correctness_test
+        else:
+            work_func = latency_test
+    else:
+        raise ValueError(
+            "Provide --model-path for running the tests or "
+            "provide --result-filename for plotting the results"
+        )
+
+    port_args = PortArgs.init_new(server_args)
+
+    if server_args.tp_size == 1:
+        work_func(server_args, port_args, bench_args, 0)
+    else:
+        workers = []
+        for tp_rank in range(server_args.tp_size):
+            proc = multiprocessing.Process(
+                target=work_func,
+                args=(
+                    server_args,
+                    port_args,
+                    bench_args,
+                    tp_rank,
+                ),
+            )
+            proc.start()
+            workers.append(proc)

+        for proc in workers:
+            proc.join()
+
+        proc.terminate()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    BenchArgs.add_cli_args(parser)
+    args = parser.parse_args()
+    server_args = ServerArgs.from_cli_args(args)
+    bench_args = BenchArgs.from_cli_args(args)
+
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format="%(message)s",
+    )
+
+    try:
+        main(server_args, bench_args)
+    except Exception as e:
+        raise e
+    finally:
+        kill_child_process()
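For reference, a minimal, self-contained sketch of the timing pattern that latency_test_run_once above implements: time one prefill step, then time each decode step and report the median. The run_once function and its step-callback arguments below are hypothetical stand-ins used only for illustration, not sglang APIs.

import time

import numpy as np


def run_once(batch_size, input_len, output_len, prefill_step, decode_step, sync=lambda: None):
    # Time the prefill once (sync is a placeholder for a device synchronize hook).
    sync()
    tic = time.time()
    state = prefill_step()
    sync()
    prefill_latency = time.time() - tic

    # Time each decode step separately so a median can be reported.
    decode_latencies = []
    for _ in range(output_len - 1):
        sync()
        tic = time.time()
        state = decode_step(state)
        sync()
        decode_latencies.append(time.time() - tic)

    total = prefill_latency + sum(decode_latencies)
    return {
        "prefill_throughput": input_len * batch_size / prefill_latency,
        "median_decode_latency": float(np.median(decode_latencies)) if decode_latencies else None,
        "overall_throughput": (input_len + output_len) * batch_size / total,
    }


if __name__ == "__main__":
    # Dummy step functions that just sleep, to exercise the bookkeeping.
    print(run_once(1, 1024, 8, lambda: time.sleep(0.01), lambda s: time.sleep(0.001)))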
@@ -1,10 +1,10 @@
 """
-Benchmark the latency of serving a single batch with a real server.
+Benchmark the latency of running a single batch with a server.
+
 This script launches a server and uses the HTTP interface.
-It accepts arguments similar to those of launch_server.py.
+It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).
 
 Usage:
-
 python3 -m sglang.bench_server_latency --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
 
 python3 -m sglang.bench_server_latency --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8