sglang 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +30 -4
  3. sglang/backend/litellm.py +2 -2
  4. sglang/backend/openai.py +26 -15
  5. sglang/backend/runtime_endpoint.py +18 -14
  6. sglang/bench_latency.py +317 -0
  7. sglang/global_config.py +5 -1
  8. sglang/lang/chat_template.py +41 -6
  9. sglang/lang/compiler.py +2 -2
  10. sglang/lang/interpreter.py +6 -2
  11. sglang/lang/ir.py +74 -28
  12. sglang/launch_server.py +4 -1
  13. sglang/launch_server_llavavid.py +2 -1
  14. sglang/srt/constrained/__init__.py +14 -6
  15. sglang/srt/constrained/fsm_cache.py +6 -3
  16. sglang/srt/constrained/jump_forward.py +113 -25
  17. sglang/srt/conversation.py +2 -0
  18. sglang/srt/flush_cache.py +2 -0
  19. sglang/srt/hf_transformers_utils.py +68 -9
  20. sglang/srt/layers/extend_attention.py +2 -1
  21. sglang/srt/layers/fused_moe.py +280 -169
  22. sglang/srt/layers/logits_processor.py +106 -42
  23. sglang/srt/layers/radix_attention.py +53 -29
  24. sglang/srt/layers/token_attention.py +4 -1
  25. sglang/srt/managers/controller/dp_worker.py +6 -3
  26. sglang/srt/managers/controller/infer_batch.py +144 -69
  27. sglang/srt/managers/controller/manager_multi.py +5 -5
  28. sglang/srt/managers/controller/manager_single.py +9 -4
  29. sglang/srt/managers/controller/model_runner.py +167 -55
  30. sglang/srt/managers/controller/radix_cache.py +4 -0
  31. sglang/srt/managers/controller/schedule_heuristic.py +2 -0
  32. sglang/srt/managers/controller/tp_worker.py +156 -134
  33. sglang/srt/managers/detokenizer_manager.py +19 -21
  34. sglang/srt/managers/io_struct.py +11 -5
  35. sglang/srt/managers/tokenizer_manager.py +16 -14
  36. sglang/srt/model_config.py +89 -4
  37. sglang/srt/models/chatglm.py +399 -0
  38. sglang/srt/models/commandr.py +2 -2
  39. sglang/srt/models/dbrx.py +1 -1
  40. sglang/srt/models/gemma.py +5 -1
  41. sglang/srt/models/gemma2.py +436 -0
  42. sglang/srt/models/grok.py +204 -137
  43. sglang/srt/models/llama2.py +12 -5
  44. sglang/srt/models/llama_classification.py +107 -0
  45. sglang/srt/models/llava.py +11 -8
  46. sglang/srt/models/llavavid.py +1 -1
  47. sglang/srt/models/minicpm.py +373 -0
  48. sglang/srt/models/mixtral.py +164 -115
  49. sglang/srt/models/mixtral_quant.py +0 -1
  50. sglang/srt/models/qwen.py +1 -1
  51. sglang/srt/models/qwen2.py +1 -1
  52. sglang/srt/models/qwen2_moe.py +454 -0
  53. sglang/srt/models/stablelm.py +1 -1
  54. sglang/srt/models/yivl.py +2 -2
  55. sglang/srt/openai_api_adapter.py +35 -25
  56. sglang/srt/openai_protocol.py +2 -2
  57. sglang/srt/server.py +69 -19
  58. sglang/srt/server_args.py +76 -43
  59. sglang/srt/utils.py +177 -35
  60. sglang/test/test_programs.py +28 -10
  61. sglang/utils.py +4 -3
  62. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/METADATA +44 -31
  63. sglang-0.1.19.dist-info/RECORD +81 -0
  64. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/WHEEL +1 -1
  65. sglang/srt/managers/router/infer_batch.py +0 -596
  66. sglang/srt/managers/router/manager.py +0 -82
  67. sglang/srt/managers/router/model_rpc.py +0 -818
  68. sglang/srt/managers/router/model_runner.py +0 -445
  69. sglang/srt/managers/router/radix_cache.py +0 -267
  70. sglang/srt/managers/router/scheduler.py +0 -59
  71. sglang-0.1.17.dist-info/RECORD +0 -81
  72. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
  73. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
sglang/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "0.1.17"
+ __version__ = "0.1.19"
 
  # SGL API Components
  from sglang.api import (
@@ -24,10 +24,10 @@ from sglang.api import (
 
  # SGL Backends
  from sglang.backend.anthropic import Anthropic
+ from sglang.backend.litellm import LiteLLM
  from sglang.backend.openai import OpenAI
  from sglang.backend.runtime_endpoint import RuntimeEndpoint
  from sglang.backend.vertexai import VertexAI
- from sglang.backend.litellm import LiteLLM
 
  # Global Configurations
  from sglang.global_config import global_config
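
Note: with this change, LiteLLM is re-exported from the top-level sglang package alongside the other backends. A minimal sketch of selecting it as the default backend (the model name is illustrative; the optional litellm dependency and matching API credentials are assumed to be configured):

import sglang as sgl

# Assumes `pip install litellm` and provider credentials in the environment.
sgl.set_default_backend(sgl.LiteLLM("gpt-3.5-turbo"))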
sglang/api.py CHANGED
@@ -1,4 +1,4 @@
- """Some Public API Definitions"""
+ """Public APIs of the language."""
 
  import os
  import re
@@ -43,14 +43,14 @@ def set_default_backend(backend: BaseBackend):
  global_config.default_backend = backend
 
 
- def flush_cache(backend: BaseBackend = None):
+ def flush_cache(backend: Optional[BaseBackend] = None):
  backend = backend or global_config.default_backend
  if backend is None:
  return False
  return backend.flush_cache()
 
 
- def get_server_args(backend: BaseBackend = None):
+ def get_server_args(backend: Optional[BaseBackend] = None):
  backend = backend or global_config.default_backend
  if backend is None:
  return None
@@ -67,10 +67,16 @@ def gen(
  frequency_penalty: Optional[float] = None,
  presence_penalty: Optional[float] = None,
  ignore_eos: Optional[bool] = None,
+ return_logprob: Optional[bool] = None,
+ logprob_start_len: Optional[int] = None,
+ top_logprobs_num: Optional[int] = None,
+ return_text_in_logprobs: Optional[bool] = None,
  dtype: Optional[type] = None,
  choices: Optional[List[str]] = None,
  regex: Optional[str] = None,
  ):
+ """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
+
  if choices:
  return SglSelect(name, choices, 0.0 if temperature is None else temperature)
 
@@ -91,6 +97,10 @@ def gen(
  frequency_penalty,
  presence_penalty,
  ignore_eos,
+ return_logprob,
+ logprob_start_len,
+ top_logprobs_num,
+ return_text_in_logprobs,
  dtype,
  regex,
  )
@@ -106,6 +116,10 @@ def gen_int(
  frequency_penalty: Optional[float] = None,
  presence_penalty: Optional[float] = None,
  ignore_eos: Optional[bool] = None,
+ return_logprob: Optional[bool] = None,
+ logprob_start_len: Optional[int] = None,
+ top_logprobs_num: Optional[int] = None,
+ return_text_in_logprobs: Optional[bool] = None,
  ):
  return SglGen(
  name,
@@ -117,6 +131,10 @@ def gen_int(
  frequency_penalty,
  presence_penalty,
  ignore_eos,
+ return_logprob,
+ logprob_start_len,
+ top_logprobs_num,
+ return_text_in_logprobs,
  int,
  None,
  )
@@ -132,6 +150,10 @@ def gen_string(
  frequency_penalty: Optional[float] = None,
  presence_penalty: Optional[float] = None,
  ignore_eos: Optional[bool] = None,
+ return_logprob: Optional[bool] = None,
+ logprob_start_len: Optional[int] = None,
+ top_logprobs_num: Optional[int] = None,
+ return_text_in_logprobs: Optional[bool] = None,
  ):
  return SglGen(
  name,
@@ -143,6 +165,10 @@ def gen_string(
  frequency_penalty,
  presence_penalty,
  ignore_eos,
+ return_logprob,
+ logprob_start_len,
+ top_logprobs_num,
+ return_text_in_logprobs,
  str,
  None,
  )
@@ -158,7 +184,7 @@ def video(path: str, num_frames: int):
 
  def select(
  name: Optional[str] = None,
- choices: List[str] = None,
+ choices: Optional[List[str]] = None,
  temperature: float = 0.0,
  ):
  assert choices is not None
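
Note: the new return_logprob, logprob_start_len, top_logprobs_num, and return_text_in_logprobs keyword arguments are plain parameters of gen()/gen_int()/gen_string() that are forwarded into SglGen, as shown above. A minimal sketch of how a program might request log probabilities (the function body and prompt are illustrative; how the returned logprobs are read back depends on the backend's meta info):

import sglang as sgl

@sgl.function
def qa(s, question):
    s += "Q: " + question + "\n"
    s += "A:" + sgl.gen(
        "answer",
        max_tokens=32,
        return_logprob=True,           # also return log probabilities
        top_logprobs_num=5,            # top-k alternatives per position
        return_text_in_logprobs=True,  # include decoded text alongside token ids
    )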
sglang/backend/litellm.py CHANGED
@@ -13,7 +13,6 @@ except ImportError as e:
 
 
  class LiteLLM(BaseBackend):
-
  def __init__(
  self,
  model_name,
@@ -33,7 +32,8 @@ class LiteLLM(BaseBackend):
  self.model_name = model_name
 
  self.chat_template = chat_template or get_chat_template_by_model_path(
- model_name)
+ model_name
+ )
 
  self.client_params = {
  "api_key": api_key,
sglang/backend/openai.py CHANGED
@@ -1,7 +1,7 @@
+ import dataclasses
  import logging
  import time
  import warnings
- import dataclasses
  from typing import Callable, List, Optional, Union
 
  import numpy as np
@@ -105,14 +105,16 @@ class OpenAI(BaseBackend):
  def get_chat_template(self):
  return self.chat_template
 
- def _prepare_spec_execution(self, sampling_params: SglSamplingParams,
- num_api_spec_tokens: int, spec_var_name: str):
+ def _prepare_spec_execution(
+ self,
+ sampling_params: SglSamplingParams,
+ num_api_spec_tokens: int,
+ spec_var_name: str,
+ ):
  if "max_tokens" not in self.spec_kwargs:
  self.spec_kwargs["max_tokens"] = num_api_spec_tokens
  else:
- assert (
- self.spec_kwargs["max_tokens"] == num_api_spec_tokens
- )
+ assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens
 
  params = sampling_params.to_openai_kwargs()
  for key, value in params.items():
@@ -151,8 +153,9 @@ class OpenAI(BaseBackend):
  )
  prompt = s.messages_
  else:
- return self._prepare_spec_execution(sampling_params,
- s.num_api_spec_tokens, spec_var_name)
+ return self._prepare_spec_execution(
+ sampling_params, s.num_api_spec_tokens, spec_var_name
+ )
  else:
  prompt = s.text_
 
@@ -325,7 +328,7 @@ class OpenAI(BaseBackend):
  ret_str = ret.choices[0].text
  ret_token = self.tokenizer.encode(ret_str)[0]
  self.token_usage.prompt_tokens += ret.usage.prompt_tokens
- self.token_usage.completion_tokens= ret.usage.completion_tokens
+ self.token_usage.completion_tokens = ret.usage.completion_tokens
 
  # TODO:
  # 1. return logits as the scores
@@ -355,7 +358,9 @@ class OpenAI(BaseBackend):
  return decision, scores, None, None
 
 
- def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
+ def openai_completion(
+ client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
+ ):
  for attempt in range(retries):
  try:
  if is_chat:
@@ -385,15 +390,19 @@ def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None,
  return comp
 
 
- def openai_completion_stream(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
+ def openai_completion_stream(
+ client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
+ ):
  for attempt in range(retries):
  try:
  if is_chat:
  if "stop" in kwargs and kwargs["stop"] is None:
  kwargs.pop("stop")
  generator = client.chat.completions.create(
- messages=prompt, stream=True, stream_options={"include_usage": True},
- **kwargs
+ messages=prompt,
+ stream=True,
+ stream_options={"include_usage": True},
+ **kwargs,
  )
  for ret in generator:
  if len(ret.choices) == 0:
@@ -405,8 +414,10 @@ def openai_completion_stream(client, token_usage, is_chat=None, retries=3, promp
  yield content or "", {}
  else:
  generator = client.completions.create(
- prompt=prompt, stream=True, stream_options={"include_usage": True},
- **kwargs
+ prompt=prompt,
+ stream=True,
+ stream_options={"include_usage": True},
+ **kwargs,
  )
  for ret in generator:
  if len(ret.choices) == 0:
sglang/backend/runtime_endpoint.py CHANGED
@@ -1,18 +1,18 @@
  import json
- from typing import Callable, List, Optional, Union
+ from typing import List, Optional
 
  import numpy as np
- import requests
 
  from sglang.backend.base_backend import BaseBackend
  from sglang.global_config import global_config
  from sglang.lang.chat_template import get_chat_template_by_model_path
  from sglang.lang.interpreter import StreamExecutor
- from sglang.lang.ir import SglArgument, SglSamplingParams
- from sglang.utils import encode_image_base64, find_printable_text, http_request
+ from sglang.lang.ir import SglSamplingParams
+ from sglang.utils import http_request
 
 
  class RuntimeEndpoint(BaseBackend):
+
  def __init__(
  self,
  base_url: str,
@@ -38,8 +38,7 @@ class RuntimeEndpoint(BaseBackend):
  self.model_info = res.json()
 
  self.chat_template = get_chat_template_by_model_path(
- self.model_info["model_path"]
- )
+ self.model_info["model_path"])
 
  def get_model_name(self):
  return self.model_info["model_path"]
@@ -125,6 +124,11 @@ class RuntimeEndpoint(BaseBackend):
  else:
  raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
 
+ for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
+ value = getattr(sampling_params, item, None)
+ if value is not None:
+ data[item] = value
+
  self._add_images(s, data)
 
  res = http_request(
@@ -167,6 +171,11 @@ class RuntimeEndpoint(BaseBackend):
  else:
  raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
 
+ for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
+ value = getattr(sampling_params, item, None)
+ if value is not None:
+ data[item] = value
+
  data["stream"] = True
  self._add_images(s, data)
 
@@ -181,21 +190,16 @@ class RuntimeEndpoint(BaseBackend):
  self._assert_success(res)
  pos = 0
 
- incomplete_text = ""
  for chunk in res.iter_lines(decode_unicode=False):
  chunk = chunk.decode("utf-8")
  if chunk and chunk.startswith("data:"):
  if chunk == "data: [DONE]":
  break
  data = json.loads(chunk[5:].strip("\n"))
- text = find_printable_text(data["text"][pos:])
+ chunk_text = data["text"][pos:]
  meta_info = data["meta_info"]
- pos += len(text)
- incomplete_text = data["text"][pos:]
- yield text, meta_info
-
- if len(incomplete_text) > 0:
- yield incomplete_text, meta_info
+ pos += len(chunk_text)
+ yield chunk_text, meta_info
 
  def select(
  self,
sglang/bench_latency.py ADDED
@@ -0,0 +1,317 @@
+ """
+ Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
+
+ # Usage (latency test):
+ python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
+
+ # Usage (correctness test):
+ python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
+
+ ### Reference output:
+ prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
+ [-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
+ [ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]],
+ device='cuda:0', dtype=torch.float16)
+ prefill logits (final) tensor([[-8.3203, -7.1211, 3.3379, ..., -4.9570, -4.1328, -3.4141],
+ [-8.9062, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0742],
+ [-9.6328, -9.0547, 4.0117, ..., -5.3047, -4.7148, -4.4609]],
+ device='cuda:0', dtype=torch.float16)
+ <s> The capital of France is.
+ The capital of the United States is Washington, D.C.
+
+ <s> The capital of the United Kindom is.
+ The capital of the United Kingdom is London.
+ The capital of the
+ <s> Today is a sunny day and I like go for a walk in the park.
+ I'm going to the park
+ """
+
+ import argparse
+ import dataclasses
+ import logging
+ import multiprocessing
+ import time
+
+ import numpy as np
+ import torch
+ import torch.distributed as dist
+
+ from sglang.srt.hf_transformers_utils import get_tokenizer
+ from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req
+ from sglang.srt.managers.controller.model_runner import ModelRunner
+ from sglang.srt.model_config import ModelConfig
+ from sglang.srt.sampling_params import SamplingParams
+ from sglang.srt.server_args import ServerArgs
+ from sglang.srt.utils import suppress_other_loggers
+
+
+ @dataclasses.dataclass
+ class BenchArgs:
+ batch_size: int = 1
+ input_len: int = 1024
+ output_len: int = 4
+ correctness_test: bool = False
+ # This is only used for correctness test
+ cut_len: int = 4
+
+ @staticmethod
+ def add_cli_args(parser: argparse.ArgumentParser):
+ parser.add_argument("--batch-size", type=int, default=BenchArgs.batch_size)
+ parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
+ parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
+ parser.add_argument("--correctness-test", action="store_true")
+ parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
+
+ @classmethod
+ def from_cli_args(cls, args: argparse.Namespace):
+ attrs = [attr.name for attr in dataclasses.fields(cls)]
+ return cls(**{attr: getattr(args, attr) for attr in attrs})
+
+
+ def load_model(server_args, tp_rank):
+ suppress_other_loggers()
+
+ model_config = ModelConfig(path=server_args.model_path)
+ model_runner = ModelRunner(
+ model_config=model_config,
+ mem_fraction_static=server_args.mem_fraction_static,
+ gpu_id=tp_rank,
+ tp_rank=tp_rank,
+ tp_size=server_args.tp_size,
+ nccl_port=28888,
+ server_args=server_args,
+ )
+ print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
+ tokenizer = get_tokenizer(
+ server_args.tokenizer_path,
+ tokenizer_mode=server_args.tokenizer_mode,
+ trust_remote_code=server_args.trust_remote_code,
+ )
+ if server_args.tp_size > 1:
+ dist.barrier()
+ return model_runner, tokenizer
+
+
+ def prepare_inputs(bench_args, tokenizer):
+ prompts = [
+ "The capital of France is",
+ "The capital of the United Kindom is",
+ "Today is a sunny day and I like",
+ ]
+ input_ids = [tokenizer.encode(p) for p in prompts]
+ sampling_params = SamplingParams(
+ temperature=0,
+ max_new_tokens=BenchArgs.output_len,
+ )
+
+ reqs = []
+ for i in range(len(prompts)):
+ assert len(input_ids[i]) > bench_args.cut_len
+
+ tmp_input_ids = input_ids[i][: bench_args.cut_len]
+ req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
+ req.prefix_indices = []
+ req.sampling_params = sampling_params
+ req.input_ids = req.origin_input_ids
+ reqs.append(req)
+
+ return input_ids, reqs
+
+
+ def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
+ for i in range(len(reqs)):
+ req = reqs[i]
+ req.input_ids += input_ids[i][bench_args.cut_len :]
+ req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
+ i, : bench_args.cut_len
+ ]
+ return reqs
+
+
+ def prepare_synthetic_inputs(bench_args, tokenizer):
+ input_ids = np.ones((bench_args.batch_size, bench_args.input_len), dtype=np.int32)
+ sampling_params = SamplingParams(
+ temperature=0,
+ max_new_tokens=BenchArgs.output_len,
+ )
+
+ reqs = []
+ for i in range(len(input_ids)):
+ req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
+ req.prefix_indices = []
+ req.sampling_params = sampling_params
+ req.input_ids = req.origin_input_ids
+ reqs.append(req)
+
+ return reqs
+
+
+ def extend(reqs, model_runner):
+ batch = Batch.init_new(
+ reqs=reqs,
+ req_to_token_pool=model_runner.req_to_token_pool,
+ token_to_kv_pool=model_runner.token_to_kv_pool,
+ tree_cache=None,
+ )
+ batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
+ output = model_runner.forward(batch, ForwardMode.EXTEND)
+ next_token_ids, _ = batch.sample(output.next_token_logits)
+ return next_token_ids, output.next_token_logits, batch
+
+
+ def decode(input_token_ids, batch, model_runner):
+ batch.prepare_for_decode(input_token_ids.cpu().numpy())
+ output = model_runner.forward(batch, ForwardMode.DECODE)
+ next_token_ids, _ = batch.sample(output.next_token_logits)
+ return next_token_ids, output.next_token_logits
+
+
+ @torch.inference_mode()
+ def correctness_test(
+ server_args,
+ bench_args,
+ tp_rank,
+ ):
+ rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+ # Load the model
+ model_runner, tokenizer = load_model(server_args, tp_rank)
+
+ # Prepare inputs
+ input_ids, reqs = prepare_inputs(bench_args, tokenizer)
+
+ if bench_args.cut_len > 0:
+ # Prefill
+ next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
+ rank_print("prefill logits (first half)", next_token_logits)
+
+ # Prepare extend inputs
+ reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner)
+
+ # Extend
+ next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
+ rank_print("prefill logits (final)", next_token_logits)
+
+ # Decode
+ output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
+ for _ in range(bench_args.output_len):
+ next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+ for i in range(len(reqs)):
+ output_ids[i].append(next_token_ids[i])
+
+ # Print
+ for i in range(len(reqs)):
+ print(tokenizer.decode(output_ids[i]))
+
+
+ def latency_test(
+ server_args,
+ bench_args,
+ tp_rank,
+ ):
+ rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+ # Load the model
+ model_runner, tokenizer = load_model(server_args, tp_rank)
+ print(
+ f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
+ )
+
+ # Prepare inputs
+ reqs = prepare_synthetic_inputs(bench_args, tokenizer)
+
+ def clear():
+ model_runner.req_to_token_pool.clear()
+ model_runner.token_to_kv_pool.clear()
+
+ @torch.inference_mode()
+ def run_once(output_len):
+ # Prefill
+ torch.cuda.synchronize()
+ tot_latency = 0
+ tic = time.time()
+ next_token_ids, _, batch = extend(reqs, model_runner)
+ torch.cuda.synchronize()
+ prefill_latency = time.time() - tic
+ tot_latency += prefill_latency
+ throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
+ rank_print(
+ f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+ )
+
+ # Decode
+ for i in range(output_len):
+ torch.cuda.synchronize()
+ tic = time.time()
+ next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+ torch.cuda.synchronize()
+ latency = time.time() - tic
+ tot_latency += latency
+ throughput = bench_args.batch_size / latency
+ if i < 5:
+ rank_print(
+ f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+ )
+ avg_decode_latency = (tot_latency - prefill_latency) / output_len
+ avg_decode_throughput = bench_args.batch_size / avg_decode_latency
+ rank_print(
+ f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
+ )
+
+ throughput = (
+ (bench_args.input_len + bench_args.output_len)
+ * bench_args.batch_size
+ / tot_latency
+ )
+ rank_print(
+ f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
+ )
+
+ # Warm up
+ run_once(4)
+ clear()
+
+ # Run again
+ run_once(bench_args.output_len)
+
+
+ def main(server_args, bench_args):
+ print(bench_args)
+
+ if bench_args.correctness_test:
+ work_func = correctness_test
+ else:
+ work_func = latency_test
+
+ workers = []
+ for tp_rank in range(server_args.tp_size):
+ proc = multiprocessing.Process(
+ target=work_func,
+ args=(
+ server_args,
+ bench_args,
+ tp_rank,
+ ),
+ )
+ proc.start()
+ workers.append(proc)
+
+ for proc in workers:
+ proc.join()
+
+
+ if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ ServerArgs.add_cli_args(parser)
+ BenchArgs.add_cli_args(parser)
+ args = parser.parse_args()
+
+ server_args = ServerArgs.from_cli_args(args)
+ bench_args = BenchArgs.from_cli_args(args)
+
+ logging.basicConfig(
+ level=getattr(logging, server_args.log_level.upper()),
+ format="%(message)s",
+ )
+
+ main(server_args, bench_args)
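
Usage note: BenchArgs.add_cli_args() is registered on the same parser as ServerArgs.add_cli_args(), so the server and benchmark argument groups compose on one command line. A hypothetical latency run with a larger synthetic batch (model path and sizes are placeholders):

python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy --batch-size 32 --input-len 1024 --output-len 128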
sglang/global_config.py CHANGED
@@ -27,7 +27,7 @@ class GlobalConfig:
 
  # Request dependency time due to network delay
  self.request_dependency_delay = 0.02
- self.wait_for_new_request_delay = 0.0004
+ self.wait_for_new_request_delay = 0.0006
 
  # New generation token ratio estimation
  self.base_new_token_ratio = 0.4
@@ -35,5 +35,9 @@ class GlobalConfig:
  self.new_token_ratio_decay = 0.0001
  self.new_token_ratio_recovery = 0.05
 
+ # The threshold (number of tokens) to trigger layer-wise cuda sync.
+ # This can improve the speed for large batch sizes during prefill.
+ self.layer_sync_threshold = 8192
+
 
  global_config = GlobalConfig()
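
Usage note: these knobs live on the module-level global_config instance created above, so they can be adjusted from user code in the process that runs the runtime. A minimal sketch (the override values are illustrative, not recommendations):

from sglang.global_config import global_config

# Illustrative overrides of the defaults introduced in this release (0.0006 and 8192).
global_config.wait_for_new_request_delay = 0.001
global_config.layer_sync_threshold = 4096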