sglang 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +30 -4
  3. sglang/backend/litellm.py +2 -2
  4. sglang/backend/openai.py +26 -15
  5. sglang/backend/runtime_endpoint.py +18 -14
  6. sglang/bench_latency.py +317 -0
  7. sglang/global_config.py +5 -1
  8. sglang/lang/chat_template.py +41 -6
  9. sglang/lang/compiler.py +2 -2
  10. sglang/lang/interpreter.py +6 -2
  11. sglang/lang/ir.py +74 -28
  12. sglang/launch_server.py +4 -1
  13. sglang/launch_server_llavavid.py +2 -1
  14. sglang/srt/constrained/__init__.py +14 -6
  15. sglang/srt/constrained/fsm_cache.py +6 -3
  16. sglang/srt/constrained/jump_forward.py +113 -25
  17. sglang/srt/conversation.py +2 -0
  18. sglang/srt/flush_cache.py +2 -0
  19. sglang/srt/hf_transformers_utils.py +68 -9
  20. sglang/srt/layers/extend_attention.py +2 -1
  21. sglang/srt/layers/fused_moe.py +280 -169
  22. sglang/srt/layers/logits_processor.py +106 -42
  23. sglang/srt/layers/radix_attention.py +53 -29
  24. sglang/srt/layers/token_attention.py +4 -1
  25. sglang/srt/managers/controller/dp_worker.py +6 -3
  26. sglang/srt/managers/controller/infer_batch.py +144 -69
  27. sglang/srt/managers/controller/manager_multi.py +5 -5
  28. sglang/srt/managers/controller/manager_single.py +9 -4
  29. sglang/srt/managers/controller/model_runner.py +167 -55
  30. sglang/srt/managers/controller/radix_cache.py +4 -0
  31. sglang/srt/managers/controller/schedule_heuristic.py +2 -0
  32. sglang/srt/managers/controller/tp_worker.py +156 -134
  33. sglang/srt/managers/detokenizer_manager.py +19 -21
  34. sglang/srt/managers/io_struct.py +11 -5
  35. sglang/srt/managers/tokenizer_manager.py +16 -14
  36. sglang/srt/model_config.py +89 -4
  37. sglang/srt/models/chatglm.py +399 -0
  38. sglang/srt/models/commandr.py +2 -2
  39. sglang/srt/models/dbrx.py +1 -1
  40. sglang/srt/models/gemma.py +5 -1
  41. sglang/srt/models/gemma2.py +436 -0
  42. sglang/srt/models/grok.py +204 -137
  43. sglang/srt/models/llama2.py +12 -5
  44. sglang/srt/models/llama_classification.py +107 -0
  45. sglang/srt/models/llava.py +11 -8
  46. sglang/srt/models/llavavid.py +1 -1
  47. sglang/srt/models/minicpm.py +373 -0
  48. sglang/srt/models/mixtral.py +164 -115
  49. sglang/srt/models/mixtral_quant.py +0 -1
  50. sglang/srt/models/qwen.py +1 -1
  51. sglang/srt/models/qwen2.py +1 -1
  52. sglang/srt/models/qwen2_moe.py +454 -0
  53. sglang/srt/models/stablelm.py +1 -1
  54. sglang/srt/models/yivl.py +2 -2
  55. sglang/srt/openai_api_adapter.py +35 -25
  56. sglang/srt/openai_protocol.py +2 -2
  57. sglang/srt/server.py +69 -19
  58. sglang/srt/server_args.py +76 -43
  59. sglang/srt/utils.py +177 -35
  60. sglang/test/test_programs.py +28 -10
  61. sglang/utils.py +4 -3
  62. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/METADATA +44 -31
  63. sglang-0.1.19.dist-info/RECORD +81 -0
  64. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/WHEEL +1 -1
  65. sglang/srt/managers/router/infer_batch.py +0 -596
  66. sglang/srt/managers/router/manager.py +0 -82
  67. sglang/srt/managers/router/model_rpc.py +0 -818
  68. sglang/srt/managers/router/model_runner.py +0 -445
  69. sglang/srt/managers/router/radix_cache.py +0 -267
  70. sglang/srt/managers/router/scheduler.py +0 -59
  71. sglang-0.1.17.dist-info/RECORD +0 -81
  72. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
  73. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
sglang/srt/server.py CHANGED
@@ -1,4 +1,7 @@
- """SRT: SGLang Runtime"""
+ """
+ The entry point of inference server.
+ SRT = SGLang Runtime.
+ """

  import asyncio
  import dataclasses
@@ -10,7 +13,7 @@ import sys
  import threading
  import time
  from http import HTTPStatus
- from typing import Optional
+ from typing import Dict, Optional

  # Fix a bug of Python threading
  setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -26,10 +29,15 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
  from sglang.backend.runtime_endpoint import RuntimeEndpoint
  from sglang.srt.constrained import disable_cache
  from sglang.srt.hf_transformers_utils import get_tokenizer
+ from sglang.srt.managers.controller.manager_multi import (
+ start_controller_process as start_controller_process_multi,
+ )
+ from sglang.srt.managers.controller.manager_single import (
+ start_controller_process as start_controller_process_single,
+ )
+ from sglang.srt.managers.controller.tp_worker import ModelTpService
  from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
  from sglang.srt.managers.io_struct import GenerateReqInput
- from sglang.srt.managers.controller.manager_single import start_controller_process as start_controller_process_single
- from sglang.srt.managers.controller.manager_multi import start_controller_process as start_controller_process_multi
  from sglang.srt.managers.tokenizer_manager import TokenizerManager
  from sglang.srt.openai_api_adapter import (
  load_chat_template_for_openai_api,
@@ -43,9 +51,14 @@ from sglang.srt.utils import (
  allocate_init_ports,
  assert_pkg_version,
  enable_show_time_cost,
+ receive_addrs,
+ send_addrs_to_rank_0,
+ start_rpyc_service_process,
  )
  from sglang.utils import get_exception_traceback

+ logger = logging.getLogger(__name__)
+
  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -94,8 +107,11 @@ async def generate_request(obj: GenerateReqInput, request: Request):
  yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
  yield "data: [DONE]\n\n"

- return StreamingResponse(stream_results(), media_type="text/event-stream",
- background=tokenizer_manager.create_abort_task(obj))
+ return StreamingResponse(
+ stream_results(),
+ media_type="text/event-stream",
+ background=tokenizer_manager.create_abort_task(obj),
+ )
  else:
  try:
  ret = await tokenizer_manager.generate_request(obj, request).__anext__()
@@ -134,29 +150,38 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
  enable_show_time_cost()
  if server_args.disable_disk_cache:
  disable_cache()
- if server_args.enable_flashinfer:
- assert_pkg_version("flashinfer", "0.0.4")
+ if not server_args.disable_flashinfer:
+ assert_pkg_version(
+ "flashinfer",
+ "0.0.8",
+ "Please uninstall the old version and "
+ "reinstall the latest version by following the instructions "
+ "at https://docs.flashinfer.ai/installation.html.",
+ )
  if server_args.chat_template:
  # TODO: replace this with huggingface transformers template
  load_chat_template_for_openai_api(server_args.chat_template)

  # Allocate ports
+ assert server_args.tp_size % server_args.nnodes == 0
+ tp_size_local = server_args.tp_size // server_args.nnodes
  server_args.port, server_args.additional_ports = allocate_init_ports(
  server_args.port,
  server_args.additional_ports,
- server_args.tp_size,
+ tp_size_local,
  server_args.dp_size,
  )

- # Init local models port args
  ports = server_args.additional_ports
- tp = server_args.tp_size
  model_port_args = []
  for i in range(server_args.dp_size):
  model_port_args.append(
  ModelPortArgs(
- nccl_port=ports[3 + i * (tp + 1)],
- model_tp_ports=ports[3 + i * (tp + 1) + 1 : 3 + (i + 1) * (tp + 1)],
+ nccl_port=ports[3 + i * (tp_size_local + 1)],
+ model_tp_ips=[None] * tp_size_local,
+ model_tp_ports=ports[
+ 3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)
+ ],
  )
  )
  port_args = PortArgs(
@@ -166,6 +191,24 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
  model_port_args=model_port_args,
  )

+ # TODO multi-node dp is not supported
+ assert not (server_args.dp_size > 1 and server_args.node_rank is not None)
+ if server_args.nnodes > 1:
+ if server_args.node_rank != 0:
+ send_addrs_to_rank_0(model_port_args[0], server_args)
+ else:
+ receive_addrs(model_port_args[0], server_args)
+ for i in range(tp_size_local):
+ start_rpyc_service_process(
+ ModelTpService, model_port_args[0].model_tp_ports[i]
+ )
+ if server_args.node_rank != 0:
+ logger.info(
+ f"[node_rank={server_args.node_rank}]: Listen for connections..."
+ )
+ while True:
+ pass
+
  # Launch processes
  tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
  pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
@@ -223,7 +266,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
  try:
  requests.get(url + "/get_model_info", timeout=5, headers=headers)
  break
- except requests.exceptions.RequestException as e:
+ except requests.exceptions.RequestException:
  pass

  # Send a warmup request
@@ -235,19 +278,20 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
  "text": "The capital city of France is",
  "sampling_params": {
  "temperature": 0,
- "max_new_tokens": 16,
+ "max_new_tokens": 8,
  },
  },
  headers=headers,
  timeout=600,
  )
  assert res.status_code == 200
- except Exception:
+ except Exception as e:
  if pipe_finish_writer is not None:
  pipe_finish_writer.send(get_exception_traceback())
- print(f"Initialization failed. warmup error: {e}")
+ print(f"Initialization failed. warmup error: {e}", flush=True)
  raise e

+ logger.info("The server is fired up and ready to roll!")
  if pipe_finish_writer is not None:
  pipe_finish_writer.send("init ok")

@@ -260,7 +304,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
  app,
  host=server_args.host,
  port=server_args.port,
- log_level=server_args.log_level,
+ log_level=server_args.log_level_http or server_args.log_level,
  timeout_keep_alive=5,
  loop="uvloop",
  )
@@ -269,6 +313,12 @@


  class Runtime:
+ """
+ A wrapper for the server.
+ This is used for launching the server in a python program without
+ using the commond line interface.
+ """
+
  def __init__(
  self,
  log_level: str = "error",
@@ -339,7 +389,7 @@ class Runtime:
  async def add_request(
  self,
  prompt: str,
- sampling_params,
+ sampling_params: Dict,
  ):
  json_data = {
  "text": prompt,
sglang/srt/server_args.py CHANGED
@@ -11,12 +11,13 @@ class ServerArgs:
  # Model and tokenizer
  model_path: str
  tokenizer_path: Optional[str] = None
- load_format: str = "auto"
  tokenizer_mode: str = "auto"
- chat_template: Optional[str] = None
+ load_format: str = "auto"
+ dtype: str = "auto"
  trust_remote_code: bool = True
  context_length: Optional[int] = None
  quantization: Optional[str] = None
+ chat_template: Optional[str] = None

  # Port
  host: str = "127.0.0.1"
@@ -37,9 +38,8 @@

  # Logging
  log_level: str = "info"
+ log_level_http: Optional[str] = None
  log_requests: bool = False
- disable_log_stats: bool = False
- log_stats_interval: int = 10
  show_time_cost: bool = False

  # Other
@@ -50,11 +50,17 @@
  load_balance_method: str = "round_robin"

  # Optimization/debug options
- enable_flashinfer: bool = False
- attention_reduce_in_fp32: bool = False
+ disable_flashinfer: bool = False
  disable_radix_cache: bool = False
  disable_regex_jump_forward: bool = False
  disable_disk_cache: bool = False
+ attention_reduce_in_fp32: bool = False
+ enable_p2p_check: bool = False
+
+ # Distributed args
+ nccl_init_addr: Optional[str] = None
+ nnodes: int = 1
+ node_rank: Optional[int] = None

  def __post_init__(self):
  if self.tokenizer_path is None:
@@ -101,7 +107,16 @@
  type=int,
  nargs="*",
  default=[],
- help="Additional ports specified for the server.",
+ help="The additional ports specified for the server.",
+ )
+ parser.add_argument(
+ "--tokenizer-mode",
+ type=str,
+ default=ServerArgs.tokenizer_mode,
+ choices=["auto", "slow"],
+ help="Tokenizer mode. 'auto' will use the fast "
+ "tokenizer if available, and 'slow' will "
+ "always use the slow tokenizer.",
  )
  parser.add_argument(
  "--load-format",
@@ -120,19 +135,18 @@
  "which is mainly for profiling.",
  )
  parser.add_argument(
- "--tokenizer-mode",
+ "--dtype",
  type=str,
- default=ServerArgs.tokenizer_mode,
- choices=["auto", "slow"],
- help="Tokenizer mode. 'auto' will use the fast "
- "tokenizer if available, and 'slow' will "
- "always use the slow tokenizer.",
- )
- parser.add_argument(
- "--chat-template",
- type=str,
- default=ServerArgs.chat_template,
- help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server",
+ default=ServerArgs.dtype,
+ choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
+ help="Data type for model weights and activations.\n\n"
+ '* "auto" will use FP16 precision for FP32 and FP16 models, and '
+ "BF16 precision for BF16 models.\n"
+ '* "half" for FP16. Recommended for AWQ quantization.\n'
+ '* "float16" is the same as "half".\n'
+ '* "bfloat16" for a balance between precision and range.\n'
+ '* "float" is shorthand for FP32 precision.\n'
+ '* "float32" for FP32 precision.',
  )
  parser.add_argument(
  "--trust-remote-code",
@@ -151,6 +165,12 @@
  default=ServerArgs.quantization,
  help="The quantization method.",
  )
+ parser.add_argument(
+ "--chat-template",
+ type=str,
+ default=ServerArgs.chat_template,
+ help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
+ )
  parser.add_argument(
  "--mem-fraction-static",
  type=float,
@@ -174,7 +194,7 @@
  type=str,
  default=ServerArgs.schedule_heuristic,
  choices=["lpm", "random", "fcfs", "dfs-weight"],
- help="Scheduling Heuristic.",
+ help="The scheduling heuristic.",
  )
  parser.add_argument(
  "--schedule-conservativeness",
@@ -186,7 +206,7 @@
  "--tp-size",
  type=int,
  default=ServerArgs.tp_size,
- help="Tensor parallelism size.",
+ help="The tensor parallelism size.",
  )
  parser.add_argument(
  "--stream-interval",
@@ -198,29 +218,24 @@
  "--random-seed",
  type=int,
  default=ServerArgs.random_seed,
- help="Random seed.",
+ help="The random seed.",
  )
  parser.add_argument(
  "--log-level",
  type=str,
  default=ServerArgs.log_level,
- help="Logging level",
+ help="The logging level of all loggers.",
  )
  parser.add_argument(
- "--log-requests",
- action="store_true",
- help="Log all requests",
+ "--log-level-http",
+ type=str,
+ default=ServerArgs.log_level_http,
+ help="The logging level of HTTP server. If not set, reuse --log-level by default.",
  )
  parser.add_argument(
- "--disable-log-stats",
+ "--log-requests",
  action="store_true",
- help="Disable logging throughput stats.",
- )
- parser.add_argument(
- "--log-stats-interval",
- type=int,
- default=ServerArgs.log_stats_interval,
- help="Log stats interval in second.",
+ help="Log the inputs and outputs of all requests.",
  )
  parser.add_argument(
  "--show-time-cost",
@@ -239,29 +254,35 @@
  "--dp-size",
  type=int,
  default=ServerArgs.dp_size,
- help="Data parallelism size.",
+ help="The data parallelism size.",
  )
  parser.add_argument(
  "--load-balance-method",
  type=str,
  default=ServerArgs.load_balance_method,
- help="Load balancing strategy for data parallelism.",
+ help="The load balancing strategy for data parallelism.",
  choices=[
  "round_robin",
  "shortest_queue",
  ],
  )

- # Optimization/debug options
+ # Multi-node distributed serving args
  parser.add_argument(
- "--enable-flashinfer",
- action="store_true",
- help="Enable flashinfer inference kernels",
+ "--nccl-init-addr",
+ type=str,
+ help="The nccl init address of multi-node server.",
  )
  parser.add_argument(
- "--attention-reduce-in-fp32",
+ "--nnodes", type=int, default=1, help="The number of nodes."
+ )
+ parser.add_argument("--node-rank", type=int, help="The node rank.")
+
+ # Optimization/debug options
+ parser.add_argument(
+ "--disable-flashinfer",
  action="store_true",
- help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
+ help="Disable flashinfer inference kernels",
  )
  parser.add_argument(
  "--disable-radix-cache",
@@ -278,6 +299,17 @@
  action="store_true",
  help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
  )
+ parser.add_argument(
+ "--attention-reduce-in-fp32",
+ action="store_true",
+ help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
+ "This only affects Triton attention kernels",
+ )
+ parser.add_argument(
+ "--enable-p2p-check",
+ action="store_true",
+ help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
+ )

  @classmethod
  def from_cli_args(cls, args: argparse.Namespace):
@@ -289,7 +321,7 @@

  def print_mode_args(self):
  return (
- f"enable_flashinfer={self.enable_flashinfer}, "
+ f"disable_flashinfer={self.disable_flashinfer}, "
  f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
  f"disable_radix_cache={self.disable_radix_cache}, "
  f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
@@ -300,6 +332,7 @@
  @dataclasses.dataclass
  class ModelPortArgs:
  nccl_port: int
+ model_tp_ips: List[str]
  model_tp_ports: List[int]

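
To see how the renamed and newly added options fit together, here is a small sketch that builds the ServerArgs dataclass directly from Python instead of the command line. The field names come from the diff above; the model path, addresses, and the two-node layout are placeholder values for illustration, not a recommended configuration.

import dataclasses

from sglang.srt.server_args import ServerArgs

# Node 0 of a hypothetical 2-node, tp_size=4 deployment
# (tp_size % nnodes == 0, so each node runs tp_size_local = 2 workers).
args_node0 = ServerArgs(
    model_path="meta-llama/Llama-2-7b-hf",  # placeholder model
    dtype="bfloat16",                       # new --dtype option
    disable_flashinfer=False,               # replaces the old enable_flashinfer flag
    tp_size=4,
    nnodes=2,                               # new distributed args
    node_rank=0,
    nccl_init_addr="192.168.0.1:20000",     # placeholder address
    log_level="info",
    log_level_http="warning",               # new: separate HTTP log level
)

# Node 1 differs only in its rank; each ServerArgs would then be passed to
# sglang.srt.server.launch_server on its respective machine.
args_node1 = dataclasses.replace(args_node0, node_rank=1)
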