sglang 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. sglang/__init__.py +5 -1
  2. sglang/api.py +8 -3
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +148 -12
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/global_config.py +11 -1
  8. sglang/lang/chat_template.py +9 -2
  9. sglang/lang/interpreter.py +161 -81
  10. sglang/lang/ir.py +29 -11
  11. sglang/lang/tracer.py +1 -1
  12. sglang/launch_server.py +1 -2
  13. sglang/launch_server_llavavid.py +31 -0
  14. sglang/srt/constrained/fsm_cache.py +3 -0
  15. sglang/srt/flush_cache.py +16 -0
  16. sglang/srt/hf_transformers_utils.py +83 -2
  17. sglang/srt/layers/extend_attention.py +17 -0
  18. sglang/srt/layers/fused_moe.py +485 -0
  19. sglang/srt/layers/logits_processor.py +12 -7
  20. sglang/srt/layers/radix_attention.py +10 -3
  21. sglang/srt/layers/token_attention.py +16 -1
  22. sglang/srt/managers/controller/dp_worker.py +110 -0
  23. sglang/srt/managers/controller/infer_batch.py +619 -0
  24. sglang/srt/managers/controller/manager_multi.py +191 -0
  25. sglang/srt/managers/controller/manager_single.py +97 -0
  26. sglang/srt/managers/controller/model_runner.py +462 -0
  27. sglang/srt/managers/controller/radix_cache.py +267 -0
  28. sglang/srt/managers/controller/schedule_heuristic.py +59 -0
  29. sglang/srt/managers/controller/tp_worker.py +791 -0
  30. sglang/srt/managers/detokenizer_manager.py +45 -45
  31. sglang/srt/managers/io_struct.py +26 -10
  32. sglang/srt/managers/router/infer_batch.py +130 -74
  33. sglang/srt/managers/router/manager.py +7 -9
  34. sglang/srt/managers/router/model_rpc.py +224 -135
  35. sglang/srt/managers/router/model_runner.py +94 -107
  36. sglang/srt/managers/router/radix_cache.py +54 -18
  37. sglang/srt/managers/router/scheduler.py +23 -34
  38. sglang/srt/managers/tokenizer_manager.py +183 -88
  39. sglang/srt/model_config.py +5 -2
  40. sglang/srt/models/commandr.py +15 -22
  41. sglang/srt/models/dbrx.py +22 -29
  42. sglang/srt/models/gemma.py +14 -24
  43. sglang/srt/models/grok.py +671 -0
  44. sglang/srt/models/llama2.py +24 -23
  45. sglang/srt/models/llava.py +85 -25
  46. sglang/srt/models/llavavid.py +298 -0
  47. sglang/srt/models/mixtral.py +254 -130
  48. sglang/srt/models/mixtral_quant.py +373 -0
  49. sglang/srt/models/qwen.py +28 -25
  50. sglang/srt/models/qwen2.py +17 -22
  51. sglang/srt/models/stablelm.py +21 -26
  52. sglang/srt/models/yivl.py +17 -25
  53. sglang/srt/openai_api_adapter.py +140 -95
  54. sglang/srt/openai_protocol.py +10 -1
  55. sglang/srt/server.py +101 -52
  56. sglang/srt/server_args.py +59 -11
  57. sglang/srt/utils.py +242 -75
  58. sglang/test/test_programs.py +44 -0
  59. sglang/test/test_utils.py +32 -1
  60. sglang/utils.py +95 -26
  61. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/METADATA +23 -13
  62. sglang-0.1.17.dist-info/RECORD +81 -0
  63. sglang/srt/backend_config.py +0 -13
  64. sglang/srt/models/dbrx_config.py +0 -281
  65. sglang/srt/weight_utils.py +0 -402
  66. sglang-0.1.15.dist-info/RECORD +0 -69
  67. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
  68. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
  69. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/server.py CHANGED
@@ -9,7 +9,8 @@ import os
 import sys
 import threading
 import time
-from typing import List, Optional, Union
+from http import HTTPStatus
+from typing import Optional
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -20,26 +21,30 @@ import requests
 import uvicorn
 import uvloop
 from fastapi import FastAPI, Request
-from fastapi.responses import Response, StreamingResponse
+from fastapi.responses import JSONResponse, Response, StreamingResponse
 
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.constrained import disable_cache
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
-from sglang.srt.managers.router.manager import start_router_process
+from sglang.srt.managers.controller.manager_single import start_controller_process as start_controller_process_single
+from sglang.srt.managers.controller.manager_multi import start_controller_process as start_controller_process_multi
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api_adapter import (
-    v1_completions, v1_chat_completions, load_chat_template_for_openai_api)
-from sglang.srt.server_args import PortArgs, ServerArgs
+    load_chat_template_for_openai_api,
+    v1_chat_completions,
+    v1_completions,
+)
+from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
 from sglang.srt.utils import (
+    API_KEY_HEADER_NAME,
+    APIKeyValidatorMiddleware,
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
-    get_exception_traceback,
-    API_KEY_HEADER_NAME,
-    APIKeyValidatorMiddleware
 )
+from sglang.utils import get_exception_traceback
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -69,7 +74,7 @@ async def get_server_args():
 
 @app.get("/flush_cache")
 async def flush_cache():
-    await tokenizer_manager.flush_cache()
+    tokenizer_manager.flush_cache()
     return Response(
         content="Cache flushed.\nPlease check backend logs for more details. "
         "(When there are running or waiting requests, the operation will not be performed.)\n",
@@ -77,21 +82,32 @@ async def flush_cache():
     )
 
 
-@app.post("/generate")
-async def generate_request(obj: GenerateReqInput):
-    obj.post_init()
-
+async def generate_request(obj: GenerateReqInput, request: Request):
     if obj.stream:
 
         async def stream_results():
-            async for out in tokenizer_manager.generate_request(obj):
+            try:
+                async for out in tokenizer_manager.generate_request(obj, request):
+                    yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+            except ValueError as e:
+                out = {"error": {"message": str(e)}}
                 yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
             yield "data: [DONE]\n\n"
 
-        return StreamingResponse(stream_results(), media_type="text/event-stream")
+        return StreamingResponse(stream_results(), media_type="text/event-stream",
+                                 background=tokenizer_manager.create_abort_task(obj))
+    else:
+        try:
+            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+            return ret
+        except ValueError as e:
+            return JSONResponse(
+                {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+            )
+
 
-    ret = await tokenizer_manager.generate_request(obj).__anext__()
-    return ret
+app.post("/generate")(generate_request)
+app.put("/generate")(generate_request)
 
 
 @app.post("/v1/completions")
@@ -104,7 +120,7 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)
 
 
-def launch_server(server_args: ServerArgs, pipe_finish_writer):
+def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
     global tokenizer_manager
 
     logging.basicConfig(
@@ -126,28 +142,42 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
 
     # Allocate ports
     server_args.port, server_args.additional_ports = allocate_init_ports(
-        server_args.port, server_args.additional_ports, server_args.tp_size
+        server_args.port,
+        server_args.additional_ports,
+        server_args.tp_size,
+        server_args.dp_size,
     )
+
+    # Init local models port args
+    ports = server_args.additional_ports
+    tp = server_args.tp_size
+    model_port_args = []
+    for i in range(server_args.dp_size):
+        model_port_args.append(
+            ModelPortArgs(
+                nccl_port=ports[3 + i * (tp + 1)],
+                model_tp_ports=ports[3 + i * (tp + 1) + 1 : 3 + (i + 1) * (tp + 1)],
+            )
+        )
     port_args = PortArgs(
-        tokenizer_port=server_args.additional_ports[0],
-        router_port=server_args.additional_ports[1],
-        detokenizer_port=server_args.additional_ports[2],
-        nccl_port=server_args.additional_ports[3],
-        model_rpc_ports=server_args.additional_ports[4:],
+        tokenizer_port=ports[0],
+        router_port=ports[1],
+        detokenizer_port=ports[2],
+        model_port_args=model_port_args,
     )
 
     # Launch processes
-    tokenizer_manager = TokenizerManager(server_args, port_args)
+    tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
    pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
    pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
 
+    if server_args.dp_size == 1:
+        start_process = start_controller_process_single
+    else:
+        start_process = start_controller_process_multi
     proc_router = mp.Process(
-        target=start_router_process,
-        args=(
-            server_args,
-            port_args,
-            pipe_router_writer,
-        ),
+        target=start_process,
+        args=(server_args, port_args, pipe_router_writer, model_overide_args),
     )
     proc_router.start()
     proc_detoken = mp.Process(
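Port bookkeeping is reworked here for data parallelism: allocate_init_ports now also receives dp_size, and the flat nccl_port/model_rpc_ports pair is replaced by one ModelPortArgs per data-parallel worker, carved out of additional_ports with a stride of tp_size + 1. A small sketch mirroring the slicing above (the port numbers are hypothetical; only the index arithmetic comes from the diff):

from dataclasses import dataclass
from typing import List


@dataclass
class ModelPortArgs:  # mirrors sglang.srt.server_args.ModelPortArgs
    nccl_port: int
    model_tp_ports: List[int]


tp, dp = 2, 2
# ports[0:3] stay with the tokenizer/router/detokenizer; each of the dp workers
# then gets 1 NCCL port plus tp model ports, so 3 + dp * (tp + 1) ports total.
ports = list(range(30001, 30001 + 3 + dp * (tp + 1)))  # hypothetical numbers

model_port_args = []
for i in range(dp):
    base = 3 + i * (tp + 1)
    model_port_args.append(
        ModelPortArgs(
            nccl_port=ports[base],
            model_tp_ports=ports[base + 1 : base + 1 + tp],
        )
    )

print(ports[:3])        # tokenizer, router, detokenizer ports
for args in model_port_args:
    print(args)         # ModelPortArgs(nccl_port=..., model_tp_ports=[...])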
@@ -167,14 +197,20 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
     if router_init_state != "init ok" or detoken_init_state != "init ok":
         proc_router.kill()
         proc_detoken.kill()
-        print(f"Initialization failed. router_init_state: {router_init_state}", flush=True)
-        print(f"Initialization failed. detoken_init_state: {detoken_init_state}", flush=True)
+        print(
+            f"Initialization failed. router_init_state: {router_init_state}", flush=True
+        )
+        print(
+            f"Initialization failed. detoken_init_state: {detoken_init_state}",
+            flush=True,
+        )
         sys.exit(1)
     assert proc_router.is_alive() and proc_detoken.is_alive()
 
     if server_args.api_key and server_args.api_key != "":
         app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
 
+    # Send a warmup request
     def _wait_and_warmup():
         headers = {}
         url = server_args.url()
@@ -192,20 +228,21 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
 
         # Send a warmup request
         try:
-            res = requests.post(
-                url + "/generate",
-                json={
-                    "text": "Say this is a warmup request.",
-                    "sampling_params": {
-                        "temperature": 0,
-                        "max_new_tokens": 16,
+            for _ in range(server_args.dp_size):
+                res = requests.post(
+                    url + "/generate",
+                    json={
+                        "text": "The capital city of France is",
+                        "sampling_params": {
+                            "temperature": 0,
+                            "max_new_tokens": 16,
+                        },
                     },
-                },
-                headers=headers,
-                timeout=60,
-            )
-            assert res.status_code == 200
-        except Exception as e:
+                    headers=headers,
+                    timeout=600,
+                )
+                assert res.status_code == 200
+        except Exception:
             if pipe_finish_writer is not None:
                 pipe_finish_writer.send(get_exception_traceback())
             print(f"Initialization failed. warmup error: {e}")
@@ -216,6 +253,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
 
     t = threading.Thread(target=_wait_and_warmup)
     t.start()
+
+    # Listen for requests
     try:
         uvicorn.run(
             app,
@@ -232,16 +271,21 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
232
271
  class Runtime:
233
272
  def __init__(
234
273
  self,
235
- log_evel="error",
274
+ log_level: str = "error",
275
+ model_overide_args: Optional[dict] = None,
236
276
  *args,
237
277
  **kwargs,
238
278
  ):
239
279
  """See the arguments in server_args.py::ServerArgs"""
240
- self.server_args = ServerArgs(*args, log_level=log_evel, **kwargs)
280
+ self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
241
281
 
242
282
  # Pre-allocate ports
243
283
  self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
244
- self.server_args.port, self.server_args.additional_ports, self.server_args.tp_size)
284
+ self.server_args.port,
285
+ self.server_args.additional_ports,
286
+ self.server_args.tp_size,
287
+ self.server_args.dp_size,
288
+ )
245
289
 
246
290
  self.url = self.server_args.url()
247
291
  self.generate_url = (
@@ -250,7 +294,10 @@ class Runtime:
 
         self.pid = None
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
-        proc = mp.Process(target=launch_server, args=(self.server_args, pipe_writer))
+        proc = mp.Process(
+            target=launch_server,
+            args=(self.server_args, pipe_writer, model_overide_args),
+        )
         proc.start()
         pipe_writer.close()
         self.pid = proc.pid
@@ -262,7 +309,9 @@ class Runtime:
 
         if init_state != "init ok":
             self.shutdown()
-            raise RuntimeError("Initialization failed. Please see the error messages above.")
+            raise RuntimeError(
+                "Initialization failed. Please see the error messages above."
+            )
 
         self.endpoint = RuntimeEndpoint(self.url)
 
@@ -314,4 +363,4 @@ class Runtime:
             pos += len(cur)
 
     def __del__(self):
-        self.shutdown()
+        self.shutdown()
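Both launch_server and Runtime now take an optional model_overide_args dict (spelled as in the source) and thread it through to TokenizerManager and the controller process, and the old log_evel typo in Runtime.__init__ becomes log_level. A hedged sketch of the new Runtime surface; the model path and override key below are placeholders, since the set of keys the loader honors is not visible in this diff:

from sglang.srt.server import Runtime

runtime = Runtime(
    model_path="meta-llama/Llama-2-7b-chat-hf",   # placeholder model path
    log_level="error",                            # renamed from the old log_evel
    model_overide_args={"context_length": 4096},  # hypothetical override key
)
print(runtime.url)   # e.g. http://127.0.0.1:30000
runtime.shutdown()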
sglang/srt/server_args.py CHANGED
@@ -2,6 +2,7 @@
 
 import argparse
 import dataclasses
+import random
 from typing import List, Optional, Union
 
 
@@ -15,6 +16,7 @@ class ServerArgs:
     chat_template: Optional[str] = None
     trust_remote_code: bool = True
     context_length: Optional[int] = None
+    quantization: Optional[str] = None
 
     # Port
     host: str = "127.0.0.1"
@@ -23,14 +25,15 @@
 
     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
-    max_prefill_num_token: Optional[int] = None
+    max_prefill_tokens: Optional[int] = None
+    max_running_requests: Optional[int] = None
     schedule_heuristic: str = "lpm"
     schedule_conservativeness: float = 1.0
 
     # Other runtime options
     tp_size: int = 1
     stream_interval: int = 8
-    random_seed: int = 42
+    random_seed: Optional[int] = None
 
     # Logging
     log_level: str = "info"
@@ -42,6 +45,10 @@
     # Other
     api_key: str = ""
 
+    # Data parallelism
+    dp_size: int = 1
+    load_balance_method: str = "round_robin"
+
     # Optimization/debug options
     enable_flashinfer: bool = False
     attention_reduce_in_fp32: bool = False
@@ -66,6 +73,9 @@
         elif self.additional_ports is None:
             self.additional_ports = []
 
+        if self.random_seed is None:
+            self.random_seed = random.randint(0, 1 << 30)
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         parser.add_argument(
@@ -80,10 +90,12 @@
             default=ServerArgs.tokenizer_path,
             help="The path of the tokenizer.",
         )
-        parser.add_argument("--host", type=str, default=ServerArgs.host,
-                            help="The host of the server.")
-        parser.add_argument("--port", type=int, default=ServerArgs.port,
-                            help="The port of the server.")
+        parser.add_argument(
+            "--host", type=str, default=ServerArgs.host, help="The host of the server."
+        )
+        parser.add_argument(
+            "--port", type=int, default=ServerArgs.port, help="The port of the server."
+        )
         parser.add_argument(
             "--additional-ports",
             type=int,
@@ -133,6 +145,12 @@
             default=ServerArgs.context_length,
             help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
         )
+        parser.add_argument(
+            "--quantization",
+            type=str,
+            default=ServerArgs.quantization,
+            help="The quantization method.",
+        )
         parser.add_argument(
             "--mem-fraction-static",
             type=float,
@@ -140,16 +158,23 @@
             help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
         )
         parser.add_argument(
-            "--max-prefill-num-token",
+            "--max-prefill-tokens",
             type=int,
-            default=ServerArgs.max_prefill_num_token,
+            default=ServerArgs.max_prefill_tokens,
             help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
         )
+        parser.add_argument(
+            "--max-running-requests",
+            type=int,
+            default=ServerArgs.max_running_requests,
+            help="The maximum number of running requests.",
+        )
         parser.add_argument(
             "--schedule-heuristic",
             type=str,
             default=ServerArgs.schedule_heuristic,
-            help="Schudule mode: [lpm, weight, random, fcfs]",
+            choices=["lpm", "random", "fcfs", "dfs-weight"],
+            help="Scheduling Heuristic.",
         )
         parser.add_argument(
             "--schedule-conservativeness",
@@ -209,6 +234,24 @@
             help="Set API key of the server",
         )
 
+        # Data parallelism
+        parser.add_argument(
+            "--dp-size",
+            type=int,
+            default=ServerArgs.dp_size,
+            help="Data parallelism size.",
+        )
+        parser.add_argument(
+            "--load-balance-method",
+            type=str,
+            default=ServerArgs.load_balance_method,
+            help="Load balancing strategy for data parallelism.",
+            choices=[
+                "round_robin",
+                "shortest_queue",
+            ],
+        )
+
         # Optimization/debug options
         parser.add_argument(
             "--enable-flashinfer",
@@ -254,10 +297,15 @@
         )
 
 
+@dataclasses.dataclass
+class ModelPortArgs:
+    nccl_port: int
+    model_tp_ports: List[int]
+
+
 @dataclasses.dataclass
 class PortArgs:
     tokenizer_port: int
     router_port: int
     detokenizer_port: int
-    nccl_port: int
-    model_rpc_ports: List[int]
+    model_port_args: List[ModelPortArgs]
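Taken together, the server_args.py changes add quantization, rename max_prefill_num_token to max_prefill_tokens, add a max_running_requests cap, make random_seed default to a random value in __post_init__, and introduce the data-parallelism fields dp_size and load_balance_method backed by the new ModelPortArgs/PortArgs layout. A sketch constructing ServerArgs with the new fields (the field names come from this diff; the model path and numeric values are placeholders):

from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="meta-llama/Llama-2-7b-chat-hf",  # placeholder model path
    quantization=None,                   # new: optional quantization method
    max_prefill_tokens=8192,             # renamed from max_prefill_num_token
    max_running_requests=64,             # new: cap on concurrently running requests
    random_seed=None,                    # now filled with a random value when None
    dp_size=2,                           # new: data parallelism size
    load_balance_method="round_robin",   # or "shortest_queue"
)
print(server_args.random_seed)  # set by __post_init__ when left as None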