sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (81)
  1. sglang/__init__.py +59 -2
  2. sglang/api.py +40 -11
  3. sglang/backend/anthropic.py +17 -3
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +160 -12
  6. sglang/backend/runtime_endpoint.py +62 -27
  7. sglang/backend/vertexai.py +1 -0
  8. sglang/bench_latency.py +320 -0
  9. sglang/global_config.py +24 -3
  10. sglang/lang/chat_template.py +122 -6
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +206 -98
  13. sglang/lang/ir.py +98 -34
  14. sglang/lang/tracer.py +6 -4
  15. sglang/launch_server.py +4 -1
  16. sglang/launch_server_llavavid.py +32 -0
  17. sglang/srt/constrained/__init__.py +14 -6
  18. sglang/srt/constrained/fsm_cache.py +9 -2
  19. sglang/srt/constrained/jump_forward.py +113 -24
  20. sglang/srt/conversation.py +4 -2
  21. sglang/srt/flush_cache.py +18 -0
  22. sglang/srt/hf_transformers_utils.py +144 -3
  23. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  24. sglang/srt/layers/extend_attention.py +20 -1
  25. sglang/srt/layers/fused_moe.py +596 -0
  26. sglang/srt/layers/logits_processor.py +190 -61
  27. sglang/srt/layers/radix_attention.py +62 -53
  28. sglang/srt/layers/token_attention.py +21 -9
  29. sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
  30. sglang/srt/managers/controller/dp_worker.py +113 -0
  31. sglang/srt/managers/controller/infer_batch.py +908 -0
  32. sglang/srt/managers/controller/manager_multi.py +195 -0
  33. sglang/srt/managers/controller/manager_single.py +177 -0
  34. sglang/srt/managers/controller/model_runner.py +359 -0
  35. sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
  36. sglang/srt/managers/controller/schedule_heuristic.py +65 -0
  37. sglang/srt/managers/controller/tp_worker.py +813 -0
  38. sglang/srt/managers/detokenizer_manager.py +42 -40
  39. sglang/srt/managers/io_struct.py +44 -10
  40. sglang/srt/managers/tokenizer_manager.py +224 -82
  41. sglang/srt/memory_pool.py +52 -59
  42. sglang/srt/model_config.py +97 -2
  43. sglang/srt/models/chatglm.py +399 -0
  44. sglang/srt/models/commandr.py +369 -0
  45. sglang/srt/models/dbrx.py +406 -0
  46. sglang/srt/models/gemma.py +34 -38
  47. sglang/srt/models/gemma2.py +436 -0
  48. sglang/srt/models/grok.py +738 -0
  49. sglang/srt/models/llama2.py +47 -37
  50. sglang/srt/models/llama_classification.py +107 -0
  51. sglang/srt/models/llava.py +92 -27
  52. sglang/srt/models/llavavid.py +298 -0
  53. sglang/srt/models/minicpm.py +366 -0
  54. sglang/srt/models/mixtral.py +302 -127
  55. sglang/srt/models/mixtral_quant.py +372 -0
  56. sglang/srt/models/qwen.py +40 -35
  57. sglang/srt/models/qwen2.py +33 -36
  58. sglang/srt/models/qwen2_moe.py +473 -0
  59. sglang/srt/models/stablelm.py +33 -39
  60. sglang/srt/models/yivl.py +19 -26
  61. sglang/srt/openai_api_adapter.py +411 -0
  62. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
  63. sglang/srt/sampling_params.py +2 -0
  64. sglang/srt/server.py +197 -481
  65. sglang/srt/server_args.py +190 -74
  66. sglang/srt/utils.py +460 -95
  67. sglang/test/test_programs.py +73 -10
  68. sglang/test/test_utils.py +226 -7
  69. sglang/utils.py +97 -27
  70. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
  71. sglang-0.1.21.dist-info/RECORD +82 -0
  72. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
  73. sglang/srt/backend_config.py +0 -13
  74. sglang/srt/managers/router/infer_batch.py +0 -503
  75. sglang/srt/managers/router/manager.py +0 -79
  76. sglang/srt/managers/router/model_rpc.py +0 -686
  77. sglang/srt/managers/router/model_runner.py +0 -514
  78. sglang/srt/managers/router/scheduler.py +0 -70
  79. sglang-0.1.14.dist-info/RECORD +0 -64
  80. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
  81. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/backend/runtime_endpoint.py CHANGED
@@ -1,14 +1,14 @@
 import json
-from typing import Callable, List, Optional, Union
+from typing import List, Optional
 
 import numpy as np
-import requests
+
 from sglang.backend.base_backend import BaseBackend
 from sglang.global_config import global_config
 from sglang.lang.chat_template import get_chat_template_by_model_path
 from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SglArgument, SglSamplingParams
-from sglang.utils import encode_image_base64, find_printable_text, http_request
+from sglang.lang.ir import SglSamplingParams
+from sglang.utils import http_request
 
 
 class RuntimeEndpoint(BaseBackend):
@@ -33,7 +33,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
         self.model_info = res.json()
 
         self.chat_template = get_chat_template_by_model_path(
@@ -49,7 +49,7 @@ class RuntimeEndpoint(BaseBackend):
             auth_token=self.auth_token,
             verify=self.verify,
         )
-        return res.status_code == 200
+        self._assert_success(res)
 
     def get_server_args(self):
         res = http_request(
@@ -57,6 +57,7 @@ class RuntimeEndpoint(BaseBackend):
             auth_token=self.auth_token,
             verify=self.verify,
         )
+        self._assert_success(res)
         return res.json()
 
     def get_chat_template(self):
@@ -70,17 +71,19 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
 
     def commit_lazy_operations(self, s: StreamExecutor):
+        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
+        self._add_images(s, data)
         res = http_request(
             self.base_url + "/generate",
-            json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
+            json=data,
             auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
 
     def fill_image(self, s: StreamExecutor):
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
@@ -92,7 +95,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
 
     def generate(
         self,
@@ -104,6 +107,7 @@ class RuntimeEndpoint(BaseBackend):
                 "text": s.text_,
                 "sampling_params": {
                     "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                    "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                     **sampling_params.to_srt_kwargs(),
                 },
             }
@@ -112,6 +116,7 @@ class RuntimeEndpoint(BaseBackend):
                 "text": s.text_,
                 "sampling_params": {
                     "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                    "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                     "dtype": "int",
                     **sampling_params.to_srt_kwargs(),
                 },
@@ -119,6 +124,16 @@ class RuntimeEndpoint(BaseBackend):
         else:
             raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
 
+        for item in [
+            "return_logprob",
+            "logprob_start_len",
+            "top_logprobs_num",
+            "return_text_in_logprobs",
+        ]:
+            value = getattr(sampling_params, item, None)
+            if value is not None:
+                data[item] = value
+
         self._add_images(s, data)
 
         res = http_request(
@@ -128,6 +143,8 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
+        self._assert_success(res)
+
         obj = res.json()
         comp = obj["text"]
         return comp, obj["meta_info"]
@@ -142,6 +159,7 @@ class RuntimeEndpoint(BaseBackend):
                 "text": s.text_,
                 "sampling_params": {
                     "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                    "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                     **sampling_params.to_srt_kwargs(),
                 },
             }
@@ -150,6 +168,7 @@ class RuntimeEndpoint(BaseBackend):
                 "text": s.text_,
                 "sampling_params": {
                     "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                    "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                     "dtype": "int",
                     **sampling_params.to_srt_kwargs(),
                 },
@@ -157,10 +176,20 @@ class RuntimeEndpoint(BaseBackend):
         else:
             raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
 
+        for item in [
+            "return_logprob",
+            "logprob_start_len",
+            "top_logprobs_num",
+            "return_text_in_logprobs",
+        ]:
+            value = getattr(sampling_params, item, None)
+            if value is not None:
+                data[item] = value
+
         data["stream"] = True
         self._add_images(s, data)
 
-        response = http_request(
+        res = http_request(
            self.base_url + "/generate",
            json=data,
            stream=True,
@@ -168,23 +197,19 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
+        self._assert_success(res)
         pos = 0
 
-        incomplete_text = ""
-        for chunk in response.iter_lines(decode_unicode=False):
+        for chunk in res.iter_lines(decode_unicode=False):
             chunk = chunk.decode("utf-8")
             if chunk and chunk.startswith("data:"):
                 if chunk == "data: [DONE]":
                     break
                 data = json.loads(chunk[5:].strip("\n"))
-                text = find_printable_text(data["text"][pos:])
+                chunk_text = data["text"][pos:]
                 meta_info = data["meta_info"]
-                pos += len(text)
-                incomplete_text = data["text"][pos:]
-                yield text, meta_info
-
-        if len(incomplete_text) > 0:
-            yield incomplete_text, meta_info
+                pos += len(chunk_text)
+                yield chunk_text, meta_info
 
     def select(
         self,
@@ -204,7 +229,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
         prompt_len = res.json()["meta_info"]["prompt_tokens"]
 
         # Compute logprob
@@ -222,15 +247,21 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
         obj = res.json()
-        normalized_prompt_logprob = [
+        normalized_prompt_logprobs = [
             r["meta_info"]["normalized_prompt_logprob"] for r in obj
         ]
-        prompt_logprob = [r["meta_info"]["prompt_logprob"] for r in obj]
+        decision = choices[np.argmax(normalized_prompt_logprobs)]
+        prefill_token_logprobs = [r["meta_info"]["prefill_token_logprobs"] for r in obj]
+        decode_token_logprobs = [r["meta_info"]["decode_token_logprobs"] for r in obj]
 
-        decision = choices[np.argmax(normalized_prompt_logprob)]
-        return decision, normalized_prompt_logprob, prompt_logprob
+        return (
+            decision,
+            normalized_prompt_logprobs,
+            prefill_token_logprobs,
+            decode_token_logprobs,
+        )
 
     def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
         res = http_request(
@@ -240,9 +271,13 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
 
     def _add_images(self, s: StreamExecutor, data):
         if s.images_:
             assert len(s.images_) == 1, "Only support one image."
             data["image_data"] = s.images_[0][1]
+
+    def _assert_success(self, res):
+        if res.status_code != 200:
+            raise RuntimeError(res.json())
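
Two patterns recur through this file. Every bare "assert res.status_code == 200" is replaced by the new _assert_success helper, which raises RuntimeError carrying the server's JSON error body (and, unlike assert, is not stripped under python -O). In addition, generate and generate_stream now copy the optional logprob settings (return_logprob, logprob_start_len, top_logprobs_num, return_text_in_logprobs) from the sampling params into the top level of the request payload when they are set. A minimal sketch of the new failure behavior from the caller's side, assuming a running server at a hypothetical local address:

    from sglang import RuntimeEndpoint

    backend = RuntimeEndpoint("http://localhost:30000")  # hypothetical address

    try:
        backend.get_server_args()
    except RuntimeError as e:
        # _assert_success raises RuntimeError(res.json()) for any non-200
        # response, so the server's JSON error body is available here instead
        # of a message-less AssertionError.
        print("request failed:", e)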
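
The streaming path is also simpler: the old generate_stream buffered output through find_printable_text and flushed a trailing incomplete_text after the loop, while the new code yields every server-sent delta as soon as it arrives. A sketch of consuming a stream through the frontend API under the new contract, again assuming a running server at a hypothetical address:

    import sglang as sgl

    sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

    @sgl.function
    def qa(s, question):
        s += question
        s += sgl.gen("answer", max_tokens=64)

    state = qa.run(question="The capital of France is", stream=True)

    # Each delta is now the raw text chunk exactly as the server sent it;
    # no printability buffering or trailing flush happens on the client.
    for delta in state.text_iter("answer"):
        print(delta, end="", flush=True)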
sglang/backend/vertexai.py CHANGED
@@ -3,6 +3,7 @@ import warnings
 from typing import List, Optional, Union
 
 import numpy as np
+
 from sglang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template
 from sglang.lang.interpreter import StreamExecutor
sglang/bench_latency.py ADDED
@@ -0,0 +1,320 @@
+"""
+Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
+
+# Usage (latency test):
+python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
+
+# Usage (correctness test):
+python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
+
+### Reference output:
+prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
+[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
+[ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]],
+device='cuda:0', dtype=torch.float16)
+prefill logits (final) tensor([[-8.3203, -7.1211, 3.3379, ..., -4.9570, -4.1328, -3.4141],
+[-8.9062, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0742],
+[-9.6328, -9.0547, 4.0117, ..., -5.3047, -4.7148, -4.4609]],
+device='cuda:0', dtype=torch.float16)
+<s> The capital of France is.
+The capital of the United States is Washington, D.C.
+
+<s> The capital of the United Kindom is.
+The capital of the United Kingdom is London.
+The capital of the
+<s> Today is a sunny day and I like go for a walk in the park.
+I'm going to the park
+"""
+
+import argparse
+import dataclasses
+import logging
+import multiprocessing
+import time
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req
+from sglang.srt.managers.controller.model_runner import ModelRunner
+from sglang.srt.model_config import ModelConfig
+from sglang.srt.sampling_params import SamplingParams
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import suppress_other_loggers
+
+
+@dataclasses.dataclass
+class BenchArgs:
+    batch_size: int = 1
+    input_len: int = 1024
+    output_len: int = 4
+    correctness_test: bool = False
+    # This is only used for correctness test
+    cut_len: int = 4
+
+    @staticmethod
+    def add_cli_args(parser: argparse.ArgumentParser):
+        parser.add_argument("--batch-size", type=int, default=BenchArgs.batch_size)
+        parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
+        parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
+        parser.add_argument("--correctness-test", action="store_true")
+        parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace):
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        return cls(**{attr: getattr(args, attr) for attr in attrs})
+
+
+def load_model(server_args, tp_rank):
+    suppress_other_loggers()
+    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+    model_config = ModelConfig(path=server_args.model_path)
+    model_runner = ModelRunner(
+        model_config=model_config,
+        mem_fraction_static=server_args.mem_fraction_static,
+        gpu_id=tp_rank,
+        tp_rank=tp_rank,
+        tp_size=server_args.tp_size,
+        nccl_port=28888,
+        server_args=server_args,
+    )
+    rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
+    tokenizer = get_tokenizer(
+        server_args.tokenizer_path,
+        tokenizer_mode=server_args.tokenizer_mode,
+        trust_remote_code=server_args.trust_remote_code,
+    )
+    if server_args.tp_size > 1:
+        dist.barrier()
+    return model_runner, tokenizer
+
+
+def prepare_inputs(bench_args, tokenizer):
+    prompts = [
+        "The capital of France is",
+        "The capital of the United Kindom is",
+        "Today is a sunny day and I like",
+    ]
+    input_ids = [tokenizer.encode(p) for p in prompts]
+    sampling_params = SamplingParams(
+        temperature=0,
+        max_new_tokens=BenchArgs.output_len,
+    )
+
+    reqs = []
+    for i in range(len(prompts)):
+        assert len(input_ids[i]) > bench_args.cut_len
+
+        tmp_input_ids = input_ids[i][: bench_args.cut_len]
+        req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
+        req.prefix_indices = []
+        req.sampling_params = sampling_params
+        req.input_ids = req.origin_input_ids
+        reqs.append(req)
+
+    return input_ids, reqs
+
+
+def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
+    for i in range(len(reqs)):
+        req = reqs[i]
+        req.input_ids += input_ids[i][bench_args.cut_len :]
+        req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
+            i, : bench_args.cut_len
+        ]
+    return reqs
+
+
+def prepare_synthetic_inputs(bench_args, tokenizer):
+    input_ids = np.ones((bench_args.batch_size, bench_args.input_len), dtype=np.int32)
+    sampling_params = SamplingParams(
+        temperature=0,
+        max_new_tokens=BenchArgs.output_len,
+    )
+
+    reqs = []
+    for i in range(len(input_ids)):
+        req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
+        req.prefix_indices = []
+        req.sampling_params = sampling_params
+        req.input_ids = req.origin_input_ids
+        reqs.append(req)
+
+    return reqs
+
+
+def extend(reqs, model_runner):
+    batch = Batch.init_new(
+        reqs=reqs,
+        req_to_token_pool=model_runner.req_to_token_pool,
+        token_to_kv_pool=model_runner.token_to_kv_pool,
+        tree_cache=None,
+    )
+    batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
+    output = model_runner.forward(batch, ForwardMode.EXTEND)
+    next_token_ids, _ = batch.sample(output.next_token_logits)
+    return next_token_ids, output.next_token_logits, batch
+
+
+def decode(input_token_ids, batch, model_runner):
+    batch.prepare_for_decode(input_token_ids.cpu().numpy())
+    output = model_runner.forward(batch, ForwardMode.DECODE)
+    next_token_ids, _ = batch.sample(output.next_token_logits)
+    return next_token_ids, output.next_token_logits
+
+
+@torch.inference_mode()
+def correctness_test(
+    server_args,
+    bench_args,
+    tp_rank,
+):
+    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+    # Load the model
+    model_runner, tokenizer = load_model(server_args, tp_rank)
+
+    # Prepare inputs
+    input_ids, reqs = prepare_inputs(bench_args, tokenizer)
+
+    if bench_args.cut_len > 0:
+        # Prefill
+        next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
+        rank_print("prefill logits (first half)", next_token_logits)
+
+    # Prepare extend inputs
+    reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner)
+
+    # Extend
+    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
+    rank_print("prefill logits (final)", next_token_logits)
+
+    # Decode
+    output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
+    for _ in range(bench_args.output_len):
+        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+        for i in range(len(reqs)):
+            output_ids[i].append(next_token_ids[i])
+
+    # Print
+    for i in range(len(reqs)):
+        rank_print(tokenizer.decode(output_ids[i]))
+
+
+def latency_test(
+    server_args,
+    bench_args,
+    tp_rank,
+):
+    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+    # Load the model
+    model_runner, tokenizer = load_model(server_args, tp_rank)
+    rank_print(
+        f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
+    )
+
+    # Prepare inputs
+    reqs = prepare_synthetic_inputs(bench_args, tokenizer)
+
+    def clear():
+        model_runner.req_to_token_pool.clear()
+        model_runner.token_to_kv_pool.clear()
+
+    @torch.inference_mode()
+    def run_once(output_len):
+        # Prefill
+        torch.cuda.synchronize()
+        tot_latency = 0
+        tic = time.time()
+        next_token_ids, _, batch = extend(reqs, model_runner)
+        torch.cuda.synchronize()
+        prefill_latency = time.time() - tic
+        tot_latency += prefill_latency
+        throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
+        rank_print(
+            f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+        )
+
+        # Decode
+        for i in range(output_len):
+            torch.cuda.synchronize()
+            tic = time.time()
+            next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+            torch.cuda.synchronize()
+            latency = time.time() - tic
+            tot_latency += latency
+            throughput = bench_args.batch_size / latency
+            if i < 5:
+                rank_print(
+                    f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+                )
+        avg_decode_latency = (tot_latency - prefill_latency) / output_len
+        avg_decode_throughput = bench_args.batch_size / avg_decode_latency
+        rank_print(
+            f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
+        )
+
+        throughput = (
+            (bench_args.input_len + bench_args.output_len)
+            * bench_args.batch_size
+            / tot_latency
+        )
+        rank_print(
+            f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
+        )
+
+    # Warm up
+    run_once(4)
+    clear()
+
+    # Run again
+    run_once(bench_args.output_len)
+
+
+def main(server_args, bench_args):
+    print(bench_args)
+
+    if bench_args.correctness_test:
+        work_func = correctness_test
+    else:
+        work_func = latency_test
+
+    workers = []
+    for tp_rank in range(server_args.tp_size):
+        proc = multiprocessing.Process(
+            target=work_func,
+            args=(
+                server_args,
+                bench_args,
+                tp_rank,
+            ),
+        )
+        proc.start()
+        workers.append(proc)
+
+    for proc in workers:
+        proc.join()
+
+        proc.terminate()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    BenchArgs.add_cli_args(parser)
+    args = parser.parse_args()
+
+    server_args = ServerArgs.from_cli_args(args)
+    bench_args = BenchArgs.from_cli_args(args)
+
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format="%(message)s",
+    )
+
+    main(server_args, bench_args)
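
One detail worth noting in bench_latency.py is how BenchArgs wires a dataclass to argparse: each benchmark knob is declared once as a field, add_cli_args registers a flag per field using the class defaults, and from_cli_args rebuilds the dataclass from the parsed namespace via dataclasses.fields. A standalone sketch of the same pattern, using an invented DemoArgs class that is not part of sglang:

    import argparse
    import dataclasses


    @dataclasses.dataclass
    class DemoArgs:  # invented for illustration
        batch_size: int = 1
        input_len: int = 1024

        @staticmethod
        def add_cli_args(parser: argparse.ArgumentParser):
            # Defaults are read off the class, so each knob lives in one place.
            parser.add_argument("--batch-size", type=int, default=DemoArgs.batch_size)
            parser.add_argument("--input-len", type=int, default=DemoArgs.input_len)

        @classmethod
        def from_cli_args(cls, args: argparse.Namespace):
            # argparse stores --batch-size as args.batch_size, matching the field.
            attrs = [f.name for f in dataclasses.fields(cls)]
            return cls(**{a: getattr(args, a) for a in attrs})


    parser = argparse.ArgumentParser()
    DemoArgs.add_cli_args(parser)
    print(DemoArgs.from_cli_args(parser.parse_args(["--input-len", "2048"])))
    # -> DemoArgs(batch_size=1, input_len=2048)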
sglang/global_config.py CHANGED
@@ -8,17 +8,38 @@ class GlobalConfig:
         # 2: output final text after every run
         self.verbosity = 0
 
+        # Default backend of the language
         self.default_backend = None
 
-        # Output configs
+        # Runtime constants: Request dependency time due to network delay
+        self.request_dependency_delay = 0.02
+        self.wait_for_new_request_delay = 0.0006
+
+        # Runtime constants: New generation token ratio estimation
+        self.base_new_token_ratio = 0.4
+        self.base_min_new_token_ratio = 0.2
+        self.new_token_ratio_decay = 0.0001
+        self.new_token_ratio_recovery = 0.05
+
+        # Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
+        # This can improve the speed for large batch sizes during prefill.
+        self.layer_sync_threshold = 8192
+
+        # Runtime constants: others
+        self.num_continue_decode_steps = 10
+        self.flashinfer_workspace_size = 192 * 1024 * 1024
+
+        # Output tokenization configs
         self.skip_special_tokens_in_output = True
+        self.spaces_between_special_tokens_in_out = True
 
-        # Optimization configs
+        # Interpreter optimization configs
         self.eager_fill_image = False
-        self.enable_prefix_sharing = True
+        self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True
         self.enable_parallel_decoding = True
 
+        # Deprecated
         # Choices: ["no_adjust", "adjust_cache"]
         # no_adjust: Do not adjust the position embedding of KV cache.
         # adjust_cache: Adjust the position embedding of KV cache.
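
The new spaces_between_special_tokens_in_out flag complements skip_special_tokens_in_output; as the runtime_endpoint.py hunks above show, RuntimeEndpoint forwards it on every /generate call as the spaces_between_special_tokens sampling parameter. A minimal sketch of overriding it before running a program, with a hypothetical server address:

    import sglang as sgl
    from sglang.global_config import global_config

    # Detokenize without inserting spaces between consecutive special tokens;
    # subsequent /generate requests from RuntimeEndpoint will carry
    # "spaces_between_special_tokens": False in their sampling_params.
    global_config.spaces_between_special_tokens_in_out = False

    sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))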