sglang-0.1.17-py3-none-any.whl → sglang-0.1.18-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +4 -4
  3. sglang/backend/litellm.py +2 -2
  4. sglang/backend/openai.py +26 -15
  5. sglang/bench_latency.py +299 -0
  6. sglang/global_config.py +4 -1
  7. sglang/lang/compiler.py +2 -2
  8. sglang/lang/interpreter.py +1 -1
  9. sglang/lang/ir.py +15 -5
  10. sglang/launch_server.py +4 -1
  11. sglang/launch_server_llavavid.py +2 -1
  12. sglang/srt/constrained/__init__.py +13 -6
  13. sglang/srt/constrained/fsm_cache.py +6 -3
  14. sglang/srt/constrained/jump_forward.py +113 -25
  15. sglang/srt/conversation.py +2 -0
  16. sglang/srt/flush_cache.py +2 -0
  17. sglang/srt/hf_transformers_utils.py +64 -9
  18. sglang/srt/layers/fused_moe.py +186 -89
  19. sglang/srt/layers/logits_processor.py +53 -25
  20. sglang/srt/layers/radix_attention.py +34 -7
  21. sglang/srt/managers/controller/dp_worker.py +6 -3
  22. sglang/srt/managers/controller/infer_batch.py +142 -67
  23. sglang/srt/managers/controller/manager_multi.py +5 -5
  24. sglang/srt/managers/controller/manager_single.py +8 -3
  25. sglang/srt/managers/controller/model_runner.py +154 -54
  26. sglang/srt/managers/controller/radix_cache.py +4 -0
  27. sglang/srt/managers/controller/schedule_heuristic.py +2 -0
  28. sglang/srt/managers/controller/tp_worker.py +140 -135
  29. sglang/srt/managers/detokenizer_manager.py +15 -19
  30. sglang/srt/managers/io_struct.py +10 -4
  31. sglang/srt/managers/tokenizer_manager.py +14 -13
  32. sglang/srt/model_config.py +83 -4
  33. sglang/srt/models/chatglm.py +399 -0
  34. sglang/srt/models/commandr.py +2 -2
  35. sglang/srt/models/dbrx.py +1 -1
  36. sglang/srt/models/gemma.py +5 -1
  37. sglang/srt/models/grok.py +204 -137
  38. sglang/srt/models/llama2.py +11 -4
  39. sglang/srt/models/llama_classification.py +104 -0
  40. sglang/srt/models/llava.py +11 -8
  41. sglang/srt/models/llavavid.py +1 -1
  42. sglang/srt/models/mixtral.py +164 -115
  43. sglang/srt/models/mixtral_quant.py +0 -1
  44. sglang/srt/models/qwen.py +1 -1
  45. sglang/srt/models/qwen2.py +1 -1
  46. sglang/srt/models/stablelm.py +1 -1
  47. sglang/srt/models/yivl.py +2 -2
  48. sglang/srt/openai_api_adapter.py +33 -23
  49. sglang/srt/openai_protocol.py +1 -1
  50. sglang/srt/server.py +60 -19
  51. sglang/srt/server_args.py +79 -44
  52. sglang/srt/utils.py +146 -37
  53. sglang/test/test_programs.py +28 -10
  54. sglang/utils.py +4 -3
  55. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/METADATA +29 -22
  56. sglang-0.1.18.dist-info/RECORD +78 -0
  57. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
  58. sglang/srt/managers/router/infer_batch.py +0 -596
  59. sglang/srt/managers/router/manager.py +0 -82
  60. sglang/srt/managers/router/model_rpc.py +0 -818
  61. sglang/srt/managers/router/model_runner.py +0 -445
  62. sglang/srt/managers/router/radix_cache.py +0 -267
  63. sglang/srt/managers/router/scheduler.py +0 -59
  64. sglang-0.1.17.dist-info/RECORD +0 -81
  65. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
  66. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/srt/server.py CHANGED
@@ -1,4 +1,7 @@
-"""SRT: SGLang Runtime"""
+"""
+The entry point of inference server.
+SRT = SGLang Runtime.
+"""
 
 import asyncio
 import dataclasses
@@ -10,7 +13,7 @@ import sys
 import threading
 import time
 from http import HTTPStatus
-from typing import Optional
+from typing import Dict, Optional
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -26,10 +29,15 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.constrained import disable_cache
 from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.controller.manager_multi import (
+    start_controller_process as start_controller_process_multi,
+)
+from sglang.srt.managers.controller.manager_single import (
+    start_controller_process as start_controller_process_single,
+)
+from sglang.srt.managers.controller.tp_worker import ModelTpService
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
-from sglang.srt.managers.controller.manager_single import start_controller_process as start_controller_process_single
-from sglang.srt.managers.controller.manager_multi import start_controller_process as start_controller_process_multi
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api_adapter import (
     load_chat_template_for_openai_api,
@@ -43,9 +51,15 @@ from sglang.srt.utils import (
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
+    send_addrs_to_rank_0,
+    receive_addrs,
+    start_rpyc_service_process,
 )
 from sglang.utils import get_exception_traceback
 
+
+logger = logging.getLogger(__name__)
+
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
 
@@ -94,8 +108,11 @@ async def generate_request(obj: GenerateReqInput, request: Request):
                 yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
             yield "data: [DONE]\n\n"
 
-        return StreamingResponse(stream_results(), media_type="text/event-stream",
-                                 background=tokenizer_manager.create_abort_task(obj))
+        return StreamingResponse(
+            stream_results(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(obj),
+        )
     else:
         try:
             ret = await tokenizer_manager.generate_request(obj, request).__anext__()
@@ -134,29 +151,32 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
         enable_show_time_cost()
     if server_args.disable_disk_cache:
         disable_cache()
-    if server_args.enable_flashinfer:
-        assert_pkg_version("flashinfer", "0.0.4")
+    if not server_args.disable_flashinfer:
+        assert_pkg_version("flashinfer", "0.0.8", "Please uninstall the old version and "
+                           "reinstall the latest version by following the instructions "
+                           "at https://docs.flashinfer.ai/installation.html.")
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
        load_chat_template_for_openai_api(server_args.chat_template)
 
     # Allocate ports
+    assert server_args.tp_size % server_args.nnodes == 0
+    tp_size_local = server_args.tp_size // server_args.nnodes
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
         server_args.additional_ports,
-        server_args.tp_size,
+        tp_size_local,
         server_args.dp_size,
     )
 
-    # Init local models port args
     ports = server_args.additional_ports
-    tp = server_args.tp_size
     model_port_args = []
     for i in range(server_args.dp_size):
         model_port_args.append(
             ModelPortArgs(
-                nccl_port=ports[3 + i * (tp + 1)],
-                model_tp_ports=ports[3 + i * (tp + 1) + 1 : 3 + (i + 1) * (tp + 1)],
+                nccl_port=ports[3 + i * (tp_size_local + 1)],
+                model_tp_ips=[None] * tp_size_local,
+                model_tp_ports=ports[3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)],
             )
         )
     port_args = PortArgs(
@@ -166,6 +186,20 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
         model_port_args=model_port_args,
     )
 
+    # TODO multi-node dp is not supported
+    assert not (server_args.dp_size > 1 and server_args.node_rank is not None)
+    if server_args.nnodes > 1:
+        if server_args.node_rank != 0:
+            send_addrs_to_rank_0(model_port_args[0], server_args)
+        else:
+            receive_addrs(model_port_args[0], server_args)
+        for i in range(tp_size_local):
+            start_rpyc_service_process(ModelTpService, model_port_args[0].model_tp_ports[i])
+        if server_args.node_rank != 0:
+            logger.info(f"[node_rank={server_args.node_rank}]: Listen for connections...")
+            while True:
+                pass
+
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
     pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
@@ -223,7 +257,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
         try:
             requests.get(url + "/get_model_info", timeout=5, headers=headers)
             break
-        except requests.exceptions.RequestException as e:
+        except requests.exceptions.RequestException:
            pass
 
     # Send a warmup request
@@ -235,19 +269,20 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
                 "text": "The capital city of France is",
                 "sampling_params": {
                     "temperature": 0,
-                    "max_new_tokens": 16,
+                    "max_new_tokens": 8,
                 },
             },
             headers=headers,
             timeout=600,
         )
         assert res.status_code == 200
-    except Exception:
+    except Exception as e:
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(get_exception_traceback())
-        print(f"Initialization failed. warmup error: {e}")
+        print(f"Initialization failed. warmup error: {e}", flush=True)
         raise e
 
+    logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("init ok")
 
@@ -260,7 +295,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
             app,
             host=server_args.host,
            port=server_args.port,
-            log_level=server_args.log_level,
+            log_level=server_args.log_level_http or server_args.log_level,
             timeout_keep_alive=5,
             loop="uvloop",
         )
@@ -269,6 +304,12 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
 
 
 class Runtime:
+    """
+    A wrapper for the server.
+    This is used for launching the server in a python program without
+    using the commond line interface.
+    """
+
     def __init__(
         self,
         log_level: str = "error",
@@ -339,7 +380,7 @@ class Runtime:
     async def add_request(
         self,
        prompt: str,
-        sampling_params,
+        sampling_params: Dict,
     ):
         json_data = {
             "text": prompt,
sglang/srt/server_args.py CHANGED
@@ -11,12 +11,13 @@ class ServerArgs:
     # Model and tokenizer
     model_path: str
     tokenizer_path: Optional[str] = None
-    load_format: str = "auto"
     tokenizer_mode: str = "auto"
-    chat_template: Optional[str] = None
+    load_format: str = "auto"
+    dtype: str = "auto"
     trust_remote_code: bool = True
     context_length: Optional[int] = None
     quantization: Optional[str] = None
+    chat_template: Optional[str] = None
 
     # Port
     host: str = "127.0.0.1"
@@ -37,9 +38,8 @@
 
     # Logging
     log_level: str = "info"
+    log_level_http: Optional[str] = None
     log_requests: bool = False
-    disable_log_stats: bool = False
-    log_stats_interval: int = 10
     show_time_cost: bool = False
 
     # Other
@@ -50,11 +50,16 @@
     load_balance_method: str = "round_robin"
 
     # Optimization/debug options
-    enable_flashinfer: bool = False
-    attention_reduce_in_fp32: bool = False
+    disable_flashinfer: bool = False
     disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
     disable_disk_cache: bool = False
+    attention_reduce_in_fp32: bool = False
+
+    # Distributed args
+    nccl_init_addr: Optional[str] = None
+    nnodes: int = 1
+    node_rank: Optional[int] = None
 
     def __post_init__(self):
         if self.tokenizer_path is None:
@@ -101,7 +106,16 @@
             type=int,
             nargs="*",
             default=[],
-            help="Additional ports specified for the server.",
+            help="The additional ports specified for the server.",
+        )
+        parser.add_argument(
+            "--tokenizer-mode",
+            type=str,
+            default=ServerArgs.tokenizer_mode,
+            choices=["auto", "slow"],
+            help="Tokenizer mode. 'auto' will use the fast "
+            "tokenizer if available, and 'slow' will "
+            "always use the slow tokenizer.",
         )
         parser.add_argument(
             "--load-format",
@@ -120,20 +134,20 @@
             "which is mainly for profiling.",
         )
         parser.add_argument(
-            "--tokenizer-mode",
-            type=str,
-            default=ServerArgs.tokenizer_mode,
-            choices=["auto", "slow"],
-            help="Tokenizer mode. 'auto' will use the fast "
-            "tokenizer if available, and 'slow' will "
-            "always use the slow tokenizer.",
-        )
-        parser.add_argument(
-            "--chat-template",
+            "--dtype",
             type=str,
-            default=ServerArgs.chat_template,
-            help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server",
-        )
+            default=ServerArgs.dtype,
+            choices=[
+                "auto", "half", "float16", "bfloat16", "float", "float32"
+            ],
+            help='Data type for model weights and activations.\n\n'
+            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
+            'BF16 precision for BF16 models.\n'
+            '* "half" for FP16. Recommended for AWQ quantization.\n'
+            '* "float16" is the same as "half".\n'
+            '* "bfloat16" for a balance between precision and range.\n'
+            '* "float" is shorthand for FP32 precision.\n'
+            '* "float32" for FP32 precision.')
         parser.add_argument(
             "--trust-remote-code",
             action="store_true",
@@ -151,6 +165,12 @@
             default=ServerArgs.quantization,
             help="The quantization method.",
         )
+        parser.add_argument(
+            "--chat-template",
+            type=str,
+            default=ServerArgs.chat_template,
+            help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
+        )
         parser.add_argument(
             "--mem-fraction-static",
             type=float,
@@ -174,7 +194,7 @@
             type=str,
             default=ServerArgs.schedule_heuristic,
             choices=["lpm", "random", "fcfs", "dfs-weight"],
-            help="Scheduling Heuristic.",
+            help="The scheduling heuristic.",
         )
         parser.add_argument(
             "--schedule-conservativeness",
@@ -186,7 +206,7 @@
             "--tp-size",
             type=int,
             default=ServerArgs.tp_size,
-            help="Tensor parallelism size.",
+            help="The tensor parallelism size.",
         )
         parser.add_argument(
             "--stream-interval",
@@ -198,29 +218,24 @@
             "--random-seed",
             type=int,
             default=ServerArgs.random_seed,
-            help="Random seed.",
+            help="The random seed.",
         )
         parser.add_argument(
             "--log-level",
             type=str,
             default=ServerArgs.log_level,
-            help="Logging level",
+            help="The logging level of all loggers.",
         )
         parser.add_argument(
-            "--log-requests",
-            action="store_true",
-            help="Log all requests",
+            "--log-level-http",
+            type=str,
+            default=ServerArgs.log_level_http,
+            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
        )
         parser.add_argument(
-            "--disable-log-stats",
+            "--log-requests",
             action="store_true",
-            help="Disable logging throughput stats.",
-        )
-        parser.add_argument(
-            "--log-stats-interval",
-            type=int,
-            default=ServerArgs.log_stats_interval,
-            help="Log stats interval in second.",
+            help="Log the inputs and outputs of all requests.",
         )
         parser.add_argument(
             "--show-time-cost",
@@ -239,29 +254,42 @@
             "--dp-size",
             type=int,
             default=ServerArgs.dp_size,
-            help="Data parallelism size.",
+            help="The data parallelism size.",
         )
         parser.add_argument(
             "--load-balance-method",
             type=str,
             default=ServerArgs.load_balance_method,
-            help="Load balancing strategy for data parallelism.",
+            help="The load balancing strategy for data parallelism.",
             choices=[
                 "round_robin",
                 "shortest_queue",
             ],
         )
 
-        # Optimization/debug options
+        # Multi-node distributed serving args
         parser.add_argument(
-            "--enable-flashinfer",
-            action="store_true",
-            help="Enable flashinfer inference kernels",
+            "--nccl-init-addr",
+            type=str,
+            help="The nccl init address of multi-node server."
         )
         parser.add_argument(
-            "--attention-reduce-in-fp32",
+            "--nnodes",
+            type=int,
+            default=1,
+            help="The number of nodes."
+        )
+        parser.add_argument(
+            "--node-rank",
+            type=int,
+            help="The node rank."
+        )
+
+        # Optimization/debug options
+        parser.add_argument(
+            "--disable-flashinfer",
             action="store_true",
-            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
+            help="Disable flashinfer inference kernels",
         )
         parser.add_argument(
             "--disable-radix-cache",
@@ -278,6 +306,12 @@
             action="store_true",
             help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
         )
+        parser.add_argument(
+            "--attention-reduce-in-fp32",
+            action="store_true",
+            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
+            "This only affects Triton attention kernels",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -289,7 +323,7 @@
 
     def print_mode_args(self):
         return (
-            f"enable_flashinfer={self.enable_flashinfer}, "
+            f"disable_flashinfer={self.disable_flashinfer}, "
             f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
             f"disable_radix_cache={self.disable_radix_cache}, "
             f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
@@ -300,6 +334,7 @@
 @dataclasses.dataclass
 class ModelPortArgs:
     nccl_port: int
+    model_tp_ips: List[str]
     model_tp_ports: List[int]
 
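
For orientation, the renamed and newly added options above correspond one-to-one to ServerArgs fields. The snippet below is only an illustrative construction of those fields in Python, not a recommended configuration: the model path is a placeholder, and the host:port format assumed for nccl_init_addr is not spelled out in this diff.

# Illustrative only: the renamed/added ServerArgs fields from this release.
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="meta-llama/Llama-2-7b-chat-hf",  # placeholder model
    dtype="bfloat16",                            # new --dtype option
    disable_flashinfer=False,                    # replaces the old enable_flashinfer flag
    tp_size=8,
    nnodes=2,                                    # new multi-node distributed args
    node_rank=0,
    nccl_init_addr="192.168.0.1:17000",          # assumed host:port format
)
print(args.print_mode_args())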