sglang 0.1.20__py3-none-any.whl → 0.1.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. sglang/__init__.py +8 -8
  2. sglang/api.py +1 -1
  3. sglang/backend/runtime_endpoint.py +14 -4
  4. sglang/backend/vertexai.py +5 -4
  5. sglang/bench.py +627 -0
  6. sglang/bench_latency.py +22 -20
  7. sglang/bench_serving.py +758 -0
  8. sglang/check_env.py +171 -0
  9. sglang/global_config.py +3 -1
  10. sglang/lang/backend/__init__.py +0 -0
  11. sglang/lang/backend/anthropic.py +77 -0
  12. sglang/lang/backend/base_backend.py +80 -0
  13. sglang/lang/backend/litellm.py +90 -0
  14. sglang/lang/backend/openai.py +438 -0
  15. sglang/lang/backend/runtime_endpoint.py +283 -0
  16. sglang/lang/backend/vertexai.py +149 -0
  17. sglang/lang/chat_template.py +2 -2
  18. sglang/lang/ir.py +3 -3
  19. sglang/lang/tracer.py +1 -1
  20. sglang/launch_server.py +1 -1
  21. sglang/launch_server_llavavid.py +1 -4
  22. sglang/srt/conversation.py +1 -1
  23. sglang/srt/layers/context_flashattention_nopad.py +0 -29
  24. sglang/srt/layers/extend_attention.py +0 -39
  25. sglang/srt/layers/linear.py +869 -0
  26. sglang/srt/layers/quantization/__init__.py +49 -0
  27. sglang/srt/layers/quantization/fp8.py +662 -0
  28. sglang/srt/layers/radix_attention.py +31 -5
  29. sglang/srt/layers/token_attention.py +1 -51
  30. sglang/srt/managers/controller/cuda_graph_runner.py +44 -18
  31. sglang/srt/managers/controller/infer_batch.py +76 -72
  32. sglang/srt/managers/controller/manager_multi.py +109 -98
  33. sglang/srt/managers/controller/manager_single.py +105 -50
  34. sglang/srt/managers/controller/model_runner.py +42 -18
  35. sglang/srt/managers/controller/radix_cache.py +4 -3
  36. sglang/srt/managers/controller/schedule_heuristic.py +4 -0
  37. sglang/srt/managers/controller/tp_worker.py +143 -156
  38. sglang/srt/managers/detokenizer_manager.py +49 -5
  39. sglang/srt/managers/io_struct.py +36 -17
  40. sglang/srt/managers/tokenizer_manager.py +228 -125
  41. sglang/srt/memory_pool.py +46 -58
  42. sglang/srt/model_loader/model_loader.py +277 -0
  43. sglang/srt/model_loader/utils.py +260 -0
  44. sglang/srt/models/chatglm.py +1 -0
  45. sglang/srt/models/dbrx.py +1 -0
  46. sglang/srt/models/grok.py +1 -0
  47. sglang/srt/models/internlm2.py +317 -0
  48. sglang/srt/models/llama2.py +65 -16
  49. sglang/srt/models/llama_classification.py +1 -0
  50. sglang/srt/models/llava.py +1 -0
  51. sglang/srt/models/llavavid.py +1 -0
  52. sglang/srt/models/minicpm.py +2 -8
  53. sglang/srt/models/mixtral.py +1 -0
  54. sglang/srt/models/mixtral_quant.py +1 -0
  55. sglang/srt/models/qwen.py +1 -0
  56. sglang/srt/models/qwen2.py +6 -0
  57. sglang/srt/models/qwen2_moe.py +130 -108
  58. sglang/srt/models/stablelm.py +1 -0
  59. sglang/srt/openai_api/adapter.py +432 -0
  60. sglang/srt/openai_api/api_adapter.py +432 -0
  61. sglang/srt/openai_api/openai_api_adapter.py +431 -0
  62. sglang/srt/openai_api/openai_protocol.py +207 -0
  63. sglang/srt/openai_api/protocol.py +208 -0
  64. sglang/srt/openai_protocol.py +17 -0
  65. sglang/srt/sampling_params.py +2 -0
  66. sglang/srt/server.py +114 -90
  67. sglang/srt/server_args.py +27 -17
  68. sglang/srt/utils.py +17 -118
  69. sglang/test/test_conversation.py +1 -1
  70. sglang/test/test_openai_protocol.py +1 -1
  71. sglang/test/test_programs.py +1 -1
  72. sglang/test/test_utils.py +2 -2
  73. {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/METADATA +157 -159
  74. sglang-0.1.22.dist-info/RECORD +103 -0
  75. {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/WHEEL +1 -1
  76. sglang-0.1.20.dist-info/RECORD +0 -82
  77. {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/LICENSE +0 -0
  78. {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py CHANGED
@@ -33,7 +33,7 @@ class ServerArgs:
 
     # Other runtime options
     tp_size: int = 1
-    stream_interval: int = 8
+    stream_interval: int = 1
     random_seed: Optional[int] = None
 
     # Logging
@@ -57,6 +57,7 @@ class ServerArgs:
     disable_disk_cache: bool = False
     attention_reduce_in_fp32: bool = False
     enable_p2p_check: bool = False
+    efficient_weight_load: bool = False
 
     # Distributed args
     nccl_init_addr: Optional[str] = None
@@ -67,10 +68,12 @@ class ServerArgs:
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
         if self.mem_fraction_static is None:
-            if self.tp_size >= 8:
+            if self.tp_size >= 16:
+                self.mem_fraction_static = 0.74
+            elif self.tp_size >= 8:
                 self.mem_fraction_static = 0.78
             elif self.tp_size >= 4:
-                self.mem_fraction_static = 0.80
+                self.mem_fraction_static = 0.82
             elif self.tp_size >= 2:
                 self.mem_fraction_static = 0.85
             else:
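(The retuned defaults above can be read as a small lookup. A sketch with a hypothetical helper that mirrors the __post_init__ logic; the single-GPU else-branch is truncated in this diff, so its value is left unspecified here.)

    def default_mem_fraction_static(tp_size: int):
        # Tiers introduced in 0.1.22: larger tensor-parallel groups
        # reserve a smaller static memory fraction.
        if tp_size >= 16:
            return 0.74
        if tp_size >= 8:
            return 0.78
        if tp_size >= 4:
            return 0.82
        if tp_size >= 2:
            return 0.85
        return None  # single-GPU default not visible in this excerpt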
@@ -164,6 +167,15 @@ class ServerArgs:
             "--quantization",
             type=str,
             default=ServerArgs.quantization,
+            choices=[
+                "awq",
+                "fp8",
+                "gptq",
+                "marlin",
+                "gptq_marlin",
+                "squeezellm",
+                "bitsandbytes",
+            ],
             help="The quantization method.",
         )
         parser.add_argument(
@@ -241,13 +253,13 @@ class ServerArgs:
         parser.add_argument(
             "--show-time-cost",
             action="store_true",
-            help="Show time cost of custom marks",
+            help="Show time cost of custom marks.",
         )
         parser.add_argument(
             "--api-key",
             type=str,
             default=ServerArgs.api_key,
-            help="Set API key of the server",
+            help="Set API key of the server.",
         )
 
         # Data parallelism
@@ -283,17 +295,17 @@ class ServerArgs:
         parser.add_argument(
             "--disable-flashinfer",
             action="store_true",
-            help="Disable flashinfer inference kernels",
+            help="Disable flashinfer inference kernels.",
         )
         parser.add_argument(
             "--disable-radix-cache",
             action="store_true",
-            help="Disable RadixAttention",
+            help="Disable RadixAttention for prefix caching.",
         )
         parser.add_argument(
             "--disable-regex-jump-forward",
             action="store_true",
-            help="Disable regex jump-forward",
+            help="Disable regex jump-forward.",
         )
         parser.add_argument(
             "--disable-cuda-graph",
@@ -316,6 +328,11 @@ class ServerArgs:
             action="store_true",
             help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
         )
+        parser.add_argument(
+            "--efficient-weight-load",
+            action="store_true",
+            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
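(A minimal sketch of driving the new flag and the --quantization choices from Python; it assumes ServerArgs.add_cli_args registers the arguments shown in these hunks, and the model path is a placeholder.)

    import argparse

    from sglang.srt.server_args import ServerArgs

    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    args = parser.parse_args(
        ["--model-path", "your-model", "--quantization", "fp8", "--efficient-weight-load"]
    )
    server_args = ServerArgs.from_cli_args(args)
    assert server_args.efficient_weight_load and server_args.quantization == "fp8"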
@@ -335,16 +352,9 @@ class ServerArgs:
         )
 
 
-@dataclasses.dataclass
-class ModelPortArgs:
-    nccl_port: int
-    model_tp_ips: List[str]
-    model_tp_ports: List[int]
-
-
 @dataclasses.dataclass
 class PortArgs:
     tokenizer_port: int
-    router_port: int
+    controller_port: int
     detokenizer_port: int
-    model_port_args: List[ModelPortArgs]
+    nccl_ports: List[int]
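(Net effect of this last hunk: ModelPortArgs is gone, router_port is renamed to controller_port, and each data-parallel replica needs only one NCCL port. A construction sketch with illustrative port values:)

    ports = PortArgs(
        tokenizer_port=30001,
        controller_port=30002,
        detokenizer_port=30003,
        nccl_ports=[30004],  # one entry per data-parallel replica
    )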
sglang/srt/utils.py CHANGED
@@ -3,9 +3,9 @@
 import base64
 import fcntl
 import logging
-import multiprocessing
 import os
 import random
+import resource
 import socket
 import struct
 import time
@@ -16,12 +16,11 @@ from typing import List, Optional
 import numpy as np
 import psutil
 import requests
-import rpyc
 import torch
+import torch.distributed as dist
 import triton
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
-from rpyc.utils.server import ThreadedServer
 from starlette.middleware.base import BaseHTTPMiddleware
 
 logger = logging.getLogger(__name__)
@@ -148,7 +147,6 @@ def is_port_available(port):
 def allocate_init_ports(
     port: Optional[int] = None,
     additional_ports: Optional[List[int]] = None,
-    tp_size: int = 1,
     dp_size: int = 1,
 ):
     """Allocate ports for all connections."""
@@ -160,8 +158,8 @@ def allocate_init_ports(
     ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
     cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
 
-    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size)
-    num_ports_needed = 4 + dp_size * (1 + tp_size)
+    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * 1 (nccl)
+    num_ports_needed = 4 + dp_size
     while len(ret_ports) < num_ports_needed:
         if cur_port not in ret_ports and is_port_available(cur_port):
             ret_ports.append(cur_port)
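(The port budget therefore drops from 4 + dp_size * (1 + tp_size) to 4 + dp_size. A quick check of the arithmetic, with a hypothetical helper:)

    def ports_needed(dp_size, tp_size=1, per_tp_ports=False):
        # HTTP + tokenizer + controller + detokenizer, plus per-replica ports.
        if per_tp_ports:  # 0.1.20 layout: one nccl port plus one port per TP worker
            return 4 + dp_size * (1 + tp_size)
        return 4 + dp_size  # 0.1.22 layout: a single nccl port per replica

    assert ports_needed(dp_size=2, tp_size=8, per_tp_ports=True) == 22
    assert ports_needed(dp_size=2) == 6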
@@ -188,71 +186,6 @@ def get_int_token_logit_bias(tokenizer, vocab_size):
     return logit_bias
 
 
-def wrap_kernel_launcher(kernel):
-    """A faster launcher for triton kernels."""
-    if int(triton.__version__.split(".")[0]) >= 3:
-        return None
-
-    gpu_id = torch.cuda.current_device()
-    kernels = kernel.cache[gpu_id].values()
-    kernel = next(iter(kernels))
-
-    # Different triton versions use different low-level names
-    if hasattr(kernel, "cu_function"):
-        kfunction = kernel.cu_function
-    else:
-        kfunction = kernel.function
-
-    if hasattr(kernel, "c_wrapper"):
-        run = kernel.c_wrapper
-    else:
-        run = kernel.run
-
-    add_cluster_dim = True
-
-    def ret_func(grid, num_warps, *args):
-        nonlocal add_cluster_dim
-
-        try:
-            if add_cluster_dim:
-                run(
-                    grid[0],
-                    grid[1],
-                    grid[2],
-                    num_warps,
-                    1,
-                    1,
-                    1,
-                    1,
-                    kernel.shared,
-                    0,
-                    kfunction,
-                    None,
-                    None,
-                    kernel,
-                    *args,
-                )
-            else:
-                run(
-                    grid[0],
-                    grid[1],
-                    grid[2],
-                    num_warps,
-                    kernel.shared,
-                    0,
-                    kfunction,
-                    None,
-                    None,
-                    kernel,
-                    *args,
-                )
-        except TypeError:
-            add_cluster_dim = not add_cluster_dim
-            ret_func(grid, num_warps, *args)
-
-    return ret_func
-
-
 def is_multimodal_model(model):
     from sglang.srt.model_config import ModelConfig
 
@@ -371,49 +304,6 @@ def load_image(image_file):
     return image, image_size
 
 
-def connect_rpyc_service(host, port):
-    repeat_count = 0
-    while repeat_count < 20:
-        try:
-            con = rpyc.connect(
-                host,
-                port,
-                config={
-                    "allow_public_attrs": True,
-                    "allow_pickle": True,
-                    "sync_request_timeout": 3600,
-                },
-            )
-            break
-        except ConnectionRefusedError as e:
-            time.sleep(1)
-            repeat_count += 1
-            if repeat_count == 20:
-                raise RuntimeError(f"Connect rpyc error: {e}")
-
-    return con.root
-
-
-def start_rpyc_service(service: rpyc.Service, port: int):
-    t = ThreadedServer(
-        service=service,
-        port=port,
-        protocol_config={
-            "allow_public_attrs": True,
-            "allow_pickle": True,
-            "sync_request_timeout": 3600,
-        },
-    )
-    t.logger.setLevel(logging.WARN)
-    t.start()
-
-
-def start_rpyc_service_process(service: rpyc.Service, port: int):
-    proc = multiprocessing.Process(target=start_rpyc_service, args=(service, port))
-    proc.start()
-    return proc
-
-
 def suppress_other_loggers():
     from vllm.logger import logger as vllm_default_logger
 
@@ -445,7 +335,7 @@ def kill_parent_process():
     """Kill the parent process and all children of the parent process."""
     current_process = psutil.Process()
     parent_process = current_process.parent()
-    children = current_process.children(recursive=True)
+    children = parent_process.children(recursive=True)
     for child in children:
         if child.pid != current_process.pid:
             os.kill(child.pid, 9)
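(The one-word change matters: current_process.children(recursive=True) misses sibling processes spawned by the same parent, so they survived shutdown; walking parent_process.children(recursive=True) covers them. A sketch of the corrected traversal, with the SIGKILL elided:)

    import psutil

    current = psutil.Process()
    parent = current.parent()
    # The parent's subtree includes the current process's siblings.
    victims = [c for c in parent.children(recursive=True) if c.pid != current.pid]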
@@ -474,9 +364,9 @@ def monkey_patch_vllm_dummy_weight_loader():
         DummyModelLoader,
         LoRAConfig,
         ModelConfig,
+        MultiModalConfig,
         ParallelConfig,
         SchedulerConfig,
-        MultiModalConfig,
         _initialize_model,
         initialize_dummy_weights,
         nn,
@@ -559,7 +449,6 @@ def get_ip_address(ifname):
 
 def send_addrs_to_rank_0(model_port_args, server_args):
     assert server_args.node_rank != 0 and server_args.dp_size == 1
-    import torch.distributed as dist
 
     ifname = os.environ.get(
         "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
@@ -591,7 +480,6 @@ def send_addrs_to_rank_0(model_port_args, server_args):
 
 def receive_addrs(model_port_args, server_args):
     assert server_args.node_rank == 0 and server_args.dp_size == 1
-    import torch.distributed as dist
 
     ifname = os.environ.get(
         "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
@@ -624,3 +512,14 @@ def receive_addrs(model_port_args, server_args):
 
     dist.barrier()
     dist.destroy_process_group()
+
+
+def set_ulimit(target_soft_limit=65535):
+    resource_type = resource.RLIMIT_NOFILE
+    current_soft, current_hard = resource.getrlimit(resource_type)
+
+    if current_soft < target_soft_limit:
+        try:
+            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
+        except ValueError as e:
+            logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")
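(The new set_ulimit helper raises the soft open-file limit, which a server holding many concurrent connections needs. A usage sketch, assuming it is called once at startup; the actual call site is not shown in this diff:)

    import resource

    from sglang.srt.utils import set_ulimit

    set_ulimit()  # best effort: lifts the soft RLIMIT_NOFILE to 65535 if lower
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    print(f"RLIMIT_NOFILE: soft={soft}, hard={hard}")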
sglang/test/test_conversation.py CHANGED
@@ -1,5 +1,5 @@
 from sglang.srt.conversation import generate_chat_conv
-from sglang.srt.managers.openai_protocol import (
+from sglang.srt.managers.openai_api.protocol import (
     ChatCompletionMessageContentImagePart,
     ChatCompletionMessageContentImageURL,
     ChatCompletionMessageContentTextPart,
sglang/test/test_openai_protocol.py CHANGED
@@ -1,4 +1,4 @@
-from sglang.srt.managers.openai_protocol import (
+from sglang.srt.managers.openai_api.protocol import (
     ChatCompletionMessageContentImagePart,
     ChatCompletionMessageContentImageURL,
     ChatCompletionMessageContentTextPart,
sglang/test/test_programs.py CHANGED
@@ -306,7 +306,7 @@ def test_image_qa():
     assert (
         "taxi" in state.messages()[-1]["content"]
         or "car" in state.messages()[-1]["content"]
-    )
+    ), f"{state.messages()[-1]['content']}"
 
 
 def test_stream():
sglang/test/test_utils.py CHANGED
@@ -6,9 +6,9 @@ from functools import partial
 import numpy as np
 import requests
 
-from sglang.backend.openai import OpenAI
-from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.global_config import global_config
+from sglang.lang.backend.openai import OpenAI
+from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.utils import get_exception_traceback
 