sglang 0.1.14__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. sglang/__init__.py +55 -2
  2. sglang/api.py +3 -5
  3. sglang/backend/anthropic.py +18 -4
  4. sglang/backend/openai.py +2 -1
  5. sglang/backend/runtime_endpoint.py +18 -5
  6. sglang/backend/vertexai.py +1 -0
  7. sglang/global_config.py +1 -0
  8. sglang/lang/chat_template.py +74 -0
  9. sglang/lang/interpreter.py +40 -16
  10. sglang/lang/tracer.py +6 -4
  11. sglang/launch_server.py +2 -1
  12. sglang/srt/constrained/fsm_cache.py +1 -0
  13. sglang/srt/constrained/jump_forward.py +1 -0
  14. sglang/srt/conversation.py +2 -2
  15. sglang/srt/hf_transformers_utils.py +2 -1
  16. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  17. sglang/srt/layers/extend_attention.py +1 -0
  18. sglang/srt/layers/logits_processor.py +114 -54
  19. sglang/srt/layers/radix_attention.py +2 -1
  20. sglang/srt/layers/token_attention.py +1 -0
  21. sglang/srt/managers/detokenizer_manager.py +5 -1
  22. sglang/srt/managers/io_struct.py +12 -0
  23. sglang/srt/managers/router/infer_batch.py +70 -33
  24. sglang/srt/managers/router/manager.py +7 -2
  25. sglang/srt/managers/router/model_rpc.py +116 -73
  26. sglang/srt/managers/router/model_runner.py +111 -167
  27. sglang/srt/managers/router/radix_cache.py +46 -38
  28. sglang/srt/managers/tokenizer_manager.py +56 -11
  29. sglang/srt/memory_pool.py +5 -14
  30. sglang/srt/model_config.py +7 -0
  31. sglang/srt/models/commandr.py +376 -0
  32. sglang/srt/models/dbrx.py +413 -0
  33. sglang/srt/models/dbrx_config.py +281 -0
  34. sglang/srt/models/gemma.py +22 -20
  35. sglang/srt/models/llama2.py +23 -21
  36. sglang/srt/models/llava.py +12 -10
  37. sglang/srt/models/mixtral.py +27 -25
  38. sglang/srt/models/qwen.py +23 -21
  39. sglang/srt/models/qwen2.py +23 -21
  40. sglang/srt/models/stablelm.py +20 -21
  41. sglang/srt/models/yivl.py +6 -5
  42. sglang/srt/openai_api_adapter.py +356 -0
  43. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
  44. sglang/srt/sampling_params.py +2 -0
  45. sglang/srt/server.py +68 -447
  46. sglang/srt/server_args.py +76 -49
  47. sglang/srt/utils.py +88 -32
  48. sglang/srt/weight_utils.py +402 -0
  49. sglang/test/test_programs.py +8 -7
  50. sglang/test/test_utils.py +195 -7
  51. {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/METADATA +12 -14
  52. sglang-0.1.15.dist-info/RECORD +69 -0
  53. sglang-0.1.14.dist-info/RECORD +0 -64
  54. {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/LICENSE +0 -0
  55. {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/WHEEL +0 -0
  56. {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py CHANGED
@@ -1,3 +1,5 @@
+"""The arguments of the server."""
+
 import argparse
 import dataclasses
 from typing import List, Optional, Union
@@ -5,34 +7,47 @@ from typing import List, Optional, Union

 @dataclasses.dataclass
 class ServerArgs:
+    # Model and tokenizer
     model_path: str
     tokenizer_path: Optional[str] = None
-    host: str = "127.0.0.1"
-    port: int = 30000
-    additional_ports: Optional[Union[List[int], int]] = None
     load_format: str = "auto"
     tokenizer_mode: str = "auto"
     chat_template: Optional[str] = None
     trust_remote_code: bool = True
+    context_length: Optional[int] = None
+
+    # Port
+    host: str = "127.0.0.1"
+    port: int = 30000
+    additional_ports: Optional[Union[List[int], int]] = None
+
+    # Memory and scheduling
     mem_fraction_static: Optional[float] = None
     max_prefill_num_token: Optional[int] = None
-    context_length: Optional[int] = None
-    tp_size: int = 1
     schedule_heuristic: str = "lpm"
     schedule_conservativeness: float = 1.0
-    attention_reduce_in_fp32: bool = False
-    random_seed: int = 42
+
+    # Other runtime options
+    tp_size: int = 1
     stream_interval: int = 8
+    random_seed: int = 42
+
+    # Logging
+    log_level: str = "info"
+    log_requests: bool = False
     disable_log_stats: bool = False
     log_stats_interval: int = 10
-    log_level: str = "info"
+    show_time_cost: bool = False

-    # optional modes
-    disable_radix_cache: bool = False
+    # Other
+    api_key: str = ""
+
+    # Optimization/debug options
     enable_flashinfer: bool = False
+    attention_reduce_in_fp32: bool = False
+    disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
     disable_disk_cache: bool = False
-    api_key: str = ""

     def __post_init__(self):
         if self.tokenizer_path is None:
@@ -65,15 +80,16 @@ class ServerArgs:
             default=ServerArgs.tokenizer_path,
             help="The path of the tokenizer.",
         )
-        parser.add_argument("--host", type=str, default=ServerArgs.host)
-        parser.add_argument("--port", type=int, default=ServerArgs.port)
-        # we want to be able to pass a list of ports
+        parser.add_argument("--host", type=str, default=ServerArgs.host,
+                            help="The host of the server.")
+        parser.add_argument("--port", type=int, default=ServerArgs.port,
+                            help="The port of the server.")
         parser.add_argument(
             "--additional-ports",
             type=int,
             nargs="*",
             default=[],
-            help="Additional ports specified for launching server.",
+            help="Additional ports specified for the server.",
         )
         parser.add_argument(
             "--load-format",
@@ -111,6 +127,12 @@ class ServerArgs:
             action="store_true",
             help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
         )
+        parser.add_argument(
+            "--context-length",
+            type=int,
+            default=ServerArgs.context_length,
+            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
+        )
         parser.add_argument(
             "--mem-fraction-static",
             type=float,
@@ -123,18 +145,6 @@ class ServerArgs:
             default=ServerArgs.max_prefill_num_token,
             help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
         )
-        parser.add_argument(
-            "--context-length",
-            type=int,
-            default=ServerArgs.context_length,
-            help="The model's maximum context length. Use this to reduce the context length to save memory. Defaults to None (will use the value from the model's config.json instead).",
-        )
-        parser.add_argument(
-            "--tp-size",
-            type=int,
-            default=ServerArgs.tp_size,
-            help="Tensor parallelism degree.",
-        )
         parser.add_argument(
             "--schedule-heuristic",
             type=str,
@@ -148,15 +158,10 @@ class ServerArgs:
             help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
         )
         parser.add_argument(
-            "--random-seed",
+            "--tp-size",
             type=int,
-            default=ServerArgs.random_seed,
-            help="Random seed.",
-        )
-        parser.add_argument(
-            "--attention-reduce-in-fp32",
-            action="store_true",
-            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
+            default=ServerArgs.tp_size,
+            help="Tensor parallelism size.",
         )
         parser.add_argument(
             "--stream-interval",
@@ -164,11 +169,22 @@ class ServerArgs:
             default=ServerArgs.stream_interval,
             help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
         )
+        parser.add_argument(
+            "--random-seed",
+            type=int,
+            default=ServerArgs.random_seed,
+            help="Random seed.",
+        )
         parser.add_argument(
             "--log-level",
             type=str,
             default=ServerArgs.log_level,
-            help="Log level",
+            help="Logging level",
+        )
+        parser.add_argument(
+            "--log-requests",
+            action="store_true",
+            help="Log all requests",
         )
         parser.add_argument(
             "--disable-log-stats",
@@ -181,17 +197,34 @@ class ServerArgs:
             default=ServerArgs.log_stats_interval,
             help="Log stats interval in second.",
         )
-        # optional modes
         parser.add_argument(
-            "--disable-radix-cache",
+            "--show-time-cost",
             action="store_true",
-            help="Disable RadixAttention",
+            help="Show time cost of custom marks",
+        )
+        parser.add_argument(
+            "--api-key",
+            type=str,
+            default=ServerArgs.api_key,
+            help="Set API key of the server",
         )
+
+        # Optimization/debug options
         parser.add_argument(
             "--enable-flashinfer",
             action="store_true",
             help="Enable flashinfer inference kernels",
         )
+        parser.add_argument(
+            "--attention-reduce-in-fp32",
+            action="store_true",
+            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
+        )
+        parser.add_argument(
+            "--disable-radix-cache",
+            action="store_true",
+            help="Disable RadixAttention",
+        )
         parser.add_argument(
             "--disable-regex-jump-forward",
             action="store_true",
@@ -202,12 +235,6 @@ class ServerArgs:
             action="store_true",
             help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
         )
-        parser.add_argument(
-            "--api-key",
-            type=str,
-            default=ServerArgs.api_key,
-            help="Set API Key",
-        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -217,13 +244,13 @@ class ServerArgs:
     def url(self):
         return f"http://{self.host}:{self.port}"

-    def get_optional_modes_logging(self):
+    def print_mode_args(self):
         return (
-            f"disable_radix_cache={self.disable_radix_cache}, "
             f"enable_flashinfer={self.enable_flashinfer}, "
+            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
+            f"disable_radix_cache={self.disable_radix_cache}, "
             f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
             f"disable_disk_cache={self.disable_disk_cache}, "
-            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}"
         )


@@ -233,4 +260,4 @@ class PortArgs:
     router_port: int
     detokenizer_port: int
     nccl_port: int
-    model_rpc_ports: List[int]
+    model_rpc_ports: List[int]
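The ServerArgs change is a reorganization into labeled groups plus four new options (context_length, log_requests, show_time_cost, api_key); construction and the CLI round-trip are otherwise unchanged. A minimal sketch of driving the new flags through the parser, assuming the add_argument calls above live in an add_cli_args(parser) method (the method name sits outside the hunks shown, so it is an assumption here):

    import argparse

    from sglang.srt.server_args import ServerArgs

    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)  # assumed method name; the diff only shows its body
    args = parser.parse_args(
        [
            "--model-path", "my-org/my-model",  # hypothetical model path
            "--context-length", "2048",         # new: override config.json's value
            "--log-requests",                   # new: log all requests
            "--api-key", "my-secret",           # new: checked by the API-key middleware
        ]
    )
    server_args = ServerArgs.from_cli_args(args)
    print(server_args.url())              # http://127.0.0.1:30000
    print(server_args.print_mode_args())  # renamed from get_optional_modes_logging()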
sglang/srt/utils.py CHANGED
@@ -1,3 +1,5 @@
+"""Common utilities."""
+
 import base64
 import os
 import random
@@ -5,54 +7,68 @@ import socket
 import sys
 import time
 import traceback
+from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
 from typing import List, Optional

 import numpy as np
+import pydantic
 import requests
 import torch
-import torch.distributed as dist
+from fastapi.responses import JSONResponse
+from packaging import version as pkg_version
+from pydantic import BaseModel
+from starlette.middleware.base import BaseHTTPMiddleware

-is_show_cost_time = False
+show_time_cost = False
+time_infos = {}


-def mark_cost_time(func_name):
-    def inner_func(func):
-        def time_func(*args, **kwargs):
-            if dist.get_rank() in [0, 1] and is_show_cost_time:
-                torch.cuda.synchronize()
-                start_time = time.time()
-                ans = func(*args, **kwargs)
-                torch.cuda.synchronize()
-                print(func_name, "cost time:", (time.time() - start_time) * 1000)
-                return ans
-            else:
-                torch.cuda.synchronize()
-                ans = func(*args, **kwargs)
-                torch.cuda.synchronize()
-                return ans
+def enable_show_time_cost():
+    global show_time_cost
+    show_time_cost = True

-        return time_func

-    return inner_func
+class TimeInfo:
+    def __init__(self, name, interval=0.1, color=0, indent=0):
+        self.name = name
+        self.interval = interval
+        self.color = color
+        self.indent = indent

+        self.acc_time = 0
+        self.last_acc_time = 0
+
+    def check(self):
+        if self.acc_time - self.last_acc_time > self.interval:
+            self.last_acc_time = self.acc_time
+            return True
+        return False

-time_mark = {}
+    def pretty_print(self):
+        print(f"\x1b[{self.color}m", end="")
+        print("-" * self.indent * 2, end="")
+        print(f"{self.name}: {self.acc_time:.3f}s\x1b[0m")


-def mark_start(key):
+def mark_start(name, interval=0.1, color=0, indent=0):
+    global time_infos, show_time_cost
+    if not show_time_cost:
+        return
     torch.cuda.synchronize()
-    global time_mark
-    time_mark[key] = time.time()
-    return
+    if time_infos.get(name, None) is None:
+        time_infos[name] = TimeInfo(name, interval, color, indent)
+    time_infos[name].acc_time -= time.time()


-def mark_end(key, print_min_cost=0.0):
+def mark_end(name):
+    global time_infos, show_time_cost
+    if not show_time_cost:
+        return
     torch.cuda.synchronize()
-    global time_mark
-    cost_time = (time.time() - time_mark[key]) * 1000
-    if cost_time > print_min_cost:
-        print(f"cost {key}:", cost_time)
+    time_infos[name].acc_time += time.time()
+    if time_infos[name].check():
+        time_infos[name].pretty_print()


 def calculate_time(show=False, min_cost_ms=0.0):
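The old mark_cost_time decorator and time_mark dict give way to named accumulators: mark_start/mark_end bracket a region, accumulate its elapsed time into a per-name TimeInfo, and pretty-print (with ANSI color and indentation) only once more than interval seconds have accrued since the last print. Both are no-ops unless enable_show_time_cost() has been called, which is presumably what the new --show-time-cost server flag wires up. A usage sketch, assuming a CUDA device (both functions call torch.cuda.synchronize()) and a hypothetical run_step() workload:

    from sglang.srt.utils import enable_show_time_cost, mark_end, mark_start

    enable_show_time_cost()  # leave this out and the marks cost nothing

    for _ in range(1000):
        mark_start("decode_step", interval=0.1, color=32, indent=1)
        run_step()  # hypothetical GPU workload being profiled
        mark_end("decode_step")  # prints "--decode_step: <total>s" in green
                                 # each time another 0.1s has accumulated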
@@ -110,7 +126,7 @@ def check_port(port):
         return False


-def handle_port_init(
+def allocate_init_ports(
     port: Optional[int] = None,
     additional_ports: Optional[List[int]] = None,
     tp_size: int = 1,
@@ -149,8 +165,6 @@ def get_exception_traceback():


 def get_int_token_logit_bias(tokenizer, vocab_size):
-    from transformers import LlamaTokenizer, LlamaTokenizerFast
-
     # a bug when model's vocab size > tokenizer.vocab_size
     vocab_size = tokenizer.vocab_size
     logit_bias = np.zeros(vocab_size, dtype=np.float32)
@@ -259,3 +273,45 @@ def load_image(image_file):
         image = Image.open(BytesIO(base64.b64decode(image_file)))

     return image
+
+
+def assert_pkg_version(pkg: str, min_version: str):
+    try:
+        installed_version = version(pkg)
+        if pkg_version.parse(installed_version) < pkg_version.parse(min_version):
+            raise Exception(
+                f"{pkg} is installed with version {installed_version} which "
+                f"is less than the minimum required version {min_version}"
+            )
+    except PackageNotFoundError:
+        raise Exception(f"{pkg} with minimum required version {min_version} is not installed")
+
+
+API_KEY_HEADER_NAME = "X-API-Key"
+
+
+class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
+    def __init__(self, app, api_key: str):
+        super().__init__(app)
+        self.api_key = api_key
+
+    async def dispatch(self, request, call_next):
+        # extract API key from the request headers
+        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
+        if not api_key_header or api_key_header != self.api_key:
+            return JSONResponse(
+                status_code=403,
+                content={"detail": "Invalid API Key"},
+            )
+        response = await call_next(request)
+        return response
+
+
+# FIXME: Remove this once we drop support for pydantic 1.x
+IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
+
+
+def jsonify_pydantic_model(obj: BaseModel):
+    if IS_PYDANTIC_1:
+        return obj.json(ensure_ascii=False)
+    return obj.model_dump_json()
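These additions back the new server-level features: assert_pkg_version guards dependency versions at startup, APIKeyValidatorMiddleware enforces the --api-key option by checking the X-API-Key header, and jsonify_pydantic_model papers over the pydantic 1.x/2.x serialization split. A sketch of how they might compose in a FastAPI app; the app wiring and the chosen package/version are illustrative, not taken from this diff:

    from fastapi import FastAPI
    from pydantic import BaseModel

    from sglang.srt.utils import (
        APIKeyValidatorMiddleware,
        assert_pkg_version,
        jsonify_pydantic_model,
    )

    assert_pkg_version("requests", "2.0.0")  # raises if missing or too old

    app = FastAPI()
    api_key = "my-secret"
    if api_key:
        # Any request without an "X-API-Key: my-secret" header now gets a 403.
        app.add_middleware(APIKeyValidatorMiddleware, api_key=api_key)

    class Ping(BaseModel):
        message: str

    # A JSON string under either pydantic major: .json(ensure_ascii=False)
    # on 1.x, .model_dump_json() on 2.x.
    payload = jsonify_pydantic_model(Ping(message="héllo"))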