sglang 0.1.21__py3-none-any.whl → 0.1.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. sglang/__init__.py +8 -8
  2. sglang/api.py +1 -1
  3. sglang/backend/vertexai.py +5 -4
  4. sglang/bench.py +627 -0
  5. sglang/bench_latency.py +22 -19
  6. sglang/bench_serving.py +976 -0
  7. sglang/check_env.py +171 -0
  8. sglang/global_config.py +3 -2
  9. sglang/lang/backend/__init__.py +0 -0
  10. sglang/lang/backend/anthropic.py +77 -0
  11. sglang/lang/backend/base_backend.py +80 -0
  12. sglang/lang/backend/litellm.py +90 -0
  13. sglang/lang/backend/openai.py +438 -0
  14. sglang/lang/backend/runtime_endpoint.py +283 -0
  15. sglang/lang/backend/vertexai.py +149 -0
  16. sglang/lang/interpreter.py +1 -0
  17. sglang/lang/tracer.py +1 -1
  18. sglang/launch_server.py +1 -1
  19. sglang/launch_server_llavavid.py +1 -4
  20. sglang/srt/conversation.py +1 -1
  21. sglang/srt/hf_transformers_utils.py +13 -1
  22. sglang/srt/layers/context_flashattention_nopad.py +0 -29
  23. sglang/srt/layers/extend_attention.py +0 -39
  24. sglang/srt/layers/linear.py +869 -0
  25. sglang/srt/layers/logits_processor.py +4 -5
  26. sglang/srt/layers/quantization/__init__.py +49 -0
  27. sglang/srt/layers/quantization/fp8.py +662 -0
  28. sglang/srt/layers/radix_attention.py +39 -24
  29. sglang/srt/layers/token_attention.py +1 -51
  30. sglang/srt/managers/controller/cuda_graph_runner.py +72 -28
  31. sglang/srt/managers/controller/infer_batch.py +90 -63
  32. sglang/srt/managers/controller/manager_multi.py +107 -100
  33. sglang/srt/managers/controller/manager_single.py +76 -96
  34. sglang/srt/managers/controller/model_runner.py +41 -26
  35. sglang/srt/managers/controller/schedule_heuristic.py +8 -3
  36. sglang/srt/managers/controller/tp_worker.py +136 -149
  37. sglang/srt/managers/detokenizer_manager.py +49 -5
  38. sglang/srt/managers/io_struct.py +36 -17
  39. sglang/srt/managers/tokenizer_manager.py +228 -125
  40. sglang/srt/memory_pool.py +32 -11
  41. sglang/srt/model_loader/model_loader.py +277 -0
  42. sglang/srt/model_loader/utils.py +260 -0
  43. sglang/srt/models/chatglm.py +1 -0
  44. sglang/srt/models/dbrx.py +1 -0
  45. sglang/srt/models/deepseek.py +430 -0
  46. sglang/srt/models/gpt_bigcode.py +282 -0
  47. sglang/srt/models/grok.py +1 -0
  48. sglang/srt/models/internlm2.py +317 -0
  49. sglang/srt/models/llama2.py +81 -23
  50. sglang/srt/models/llama_classification.py +1 -0
  51. sglang/srt/models/llava.py +1 -0
  52. sglang/srt/models/llavavid.py +1 -0
  53. sglang/srt/models/minicpm.py +1 -0
  54. sglang/srt/models/mixtral.py +1 -0
  55. sglang/srt/models/mixtral_quant.py +1 -0
  56. sglang/srt/models/qwen.py +1 -0
  57. sglang/srt/models/qwen2.py +6 -0
  58. sglang/srt/models/qwen2_moe.py +7 -4
  59. sglang/srt/models/stablelm.py +1 -0
  60. sglang/srt/openai_api/adapter.py +432 -0
  61. sglang/srt/openai_api/api_adapter.py +432 -0
  62. sglang/srt/openai_api/openai_api_adapter.py +431 -0
  63. sglang/srt/openai_api/openai_protocol.py +207 -0
  64. sglang/srt/openai_api/protocol.py +208 -0
  65. sglang/srt/openai_protocol.py +17 -0
  66. sglang/srt/sampling_params.py +2 -0
  67. sglang/srt/server.py +132 -84
  68. sglang/srt/server_args.py +35 -21
  69. sglang/srt/utils.py +65 -117
  70. sglang/test/test_conversation.py +1 -1
  71. sglang/test/test_openai_protocol.py +1 -1
  72. sglang/test/test_programs.py +1 -1
  73. sglang/test/test_utils.py +2 -2
  74. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/METADATA +162 -168
  75. sglang-0.1.24.dist-info/RECORD +105 -0
  76. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/WHEEL +1 -1
  77. sglang-0.1.21.dist-info/RECORD +0 -82
  78. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/LICENSE +0 -0
  79. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py CHANGED
@@ -29,11 +29,11 @@ class ServerArgs:
     max_prefill_tokens: Optional[int] = None
     max_running_requests: Optional[int] = None
     schedule_heuristic: str = "lpm"
-    schedule_conservativeness: float = 0.8
+    schedule_conservativeness: float = 1.0

     # Other runtime options
     tp_size: int = 1
-    stream_interval: int = 8
+    stream_interval: int = 1
     random_seed: Optional[int] = None

     # Logging
@@ -55,8 +55,10 @@ class ServerArgs:
     disable_regex_jump_forward: bool = False
     disable_cuda_graph: bool = False
     disable_disk_cache: bool = False
+    enable_torch_compile: bool = False
     attention_reduce_in_fp32: bool = False
     enable_p2p_check: bool = False
+    efficient_weight_load: bool = False

     # Distributed args
     nccl_init_addr: Optional[str] = None
@@ -68,15 +70,15 @@ class ServerArgs:
         self.tokenizer_path = self.model_path
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
-                self.mem_fraction_static = 0.74
+                self.mem_fraction_static = 0.80
             elif self.tp_size >= 8:
-                self.mem_fraction_static = 0.78
+                self.mem_fraction_static = 0.84
             elif self.tp_size >= 4:
-                self.mem_fraction_static = 0.82
+                self.mem_fraction_static = 0.86
             elif self.tp_size >= 2:
-                self.mem_fraction_static = 0.85
-            else:
                 self.mem_fraction_static = 0.88
+            else:
+                self.mem_fraction_static = 0.89
         if isinstance(self.additional_ports, int):
             self.additional_ports = [self.additional_ports]
         elif self.additional_ports is None:
@@ -166,6 +168,15 @@ class ServerArgs:
             "--quantization",
             type=str,
             default=ServerArgs.quantization,
+            choices=[
+                "awq",
+                "fp8",
+                "gptq",
+                "marlin",
+                "gptq_marlin",
+                "squeezellm",
+                "bitsandbytes",
+            ],
             help="The quantization method.",
         )
         parser.add_argument(
@@ -243,13 +254,13 @@ class ServerArgs:
         parser.add_argument(
             "--show-time-cost",
             action="store_true",
-            help="Show time cost of custom marks",
+            help="Show time cost of custom marks.",
         )
         parser.add_argument(
             "--api-key",
             type=str,
             default=ServerArgs.api_key,
-            help="Set API key of the server",
+            help="Set API key of the server.",
         )

         # Data parallelism
@@ -285,17 +296,17 @@ class ServerArgs:
         parser.add_argument(
             "--disable-flashinfer",
             action="store_true",
-            help="Disable flashinfer inference kernels",
+            help="Disable flashinfer inference kernels.",
         )
         parser.add_argument(
             "--disable-radix-cache",
             action="store_true",
-            help="Disable RadixAttention",
+            help="Disable RadixAttention for prefix caching.",
         )
         parser.add_argument(
             "--disable-regex-jump-forward",
             action="store_true",
-            help="Disable regex jump-forward",
+            help="Disable regex jump-forward.",
         )
         parser.add_argument(
             "--disable-cuda-graph",
@@ -307,6 +318,11 @@ class ServerArgs:
             action="store_true",
             help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
         )
+        parser.add_argument(
+            "--enable-torch-compile",
+            action="store_true",
+            help="Optimize the model with torch.compile, experimental feature.",
+        )
         parser.add_argument(
             "--attention-reduce-in-fp32",
             action="store_true",
@@ -318,6 +334,11 @@ class ServerArgs:
             action="store_true",
             help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
         )
+        parser.add_argument(
+            "--efficient-weight-load",
+            action="store_true",
+            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
+        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -337,16 +358,9 @@ class ServerArgs:
     )


-@dataclasses.dataclass
-class ModelPortArgs:
-    nccl_port: int
-    model_tp_ips: List[str]
-    model_tp_ports: List[int]
-
-
 @dataclasses.dataclass
 class PortArgs:
     tokenizer_port: int
-    router_port: int
+    controller_port: int
     detokenizer_port: int
-    model_port_args: List[ModelPortArgs]
+    nccl_ports: List[int]
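
Note: the changed defaults above affect out-of-the-box behavior, not just documentation. A minimal sketch of the new tiering, assuming the 0.1.24 wheel (the model path is a placeholder):

    from sglang.srt.server_args import ServerArgs

    # Under the 0.1.24 defaults, a tp_size=2 server reserves 88% of GPU
    # memory for static allocations (was 85%), streams detokenized output
    # every step (stream_interval 8 -> 1), and uses a neutral
    # schedule_conservativeness of 1.0 (was 0.8).
    args = ServerArgs(model_path="placeholder/model", tp_size=2)
    print(args.mem_fraction_static)  # 0.88
    print(args.stream_interval)      # 1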
sglang/srt/utils.py CHANGED
@@ -3,9 +3,9 @@
 import base64
 import fcntl
 import logging
-import multiprocessing
 import os
 import random
+import resource
 import socket
 import struct
 import time
@@ -16,12 +16,11 @@ from typing import List, Optional
 import numpy as np
 import psutil
 import requests
-import rpyc
 import torch
+import torch.distributed as dist
 import triton
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
-from rpyc.utils.server import ThreadedServer
 from starlette.middleware.base import BaseHTTPMiddleware

 logger = logging.getLogger(__name__)
@@ -148,7 +147,6 @@ def is_port_available(port):
 def allocate_init_ports(
     port: Optional[int] = None,
     additional_ports: Optional[List[int]] = None,
-    tp_size: int = 1,
     dp_size: int = 1,
 ):
     """Allocate ports for all connections."""
@@ -160,8 +158,8 @@ def allocate_init_ports(
     ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
     cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000

-    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size)
-    num_ports_needed = 4 + dp_size * (1 + tp_size)
+    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * 1 (nccl)
+    num_ports_needed = 4 + dp_size
     while len(ret_ports) < num_ports_needed:
         if cur_port not in ret_ports and is_port_available(cur_port):
             ret_ports.append(cur_port)
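
Note: read together with the rpyc removals below, the new port math reflects that tensor-parallel workers no longer listen on their own ports, so each data-parallel replica needs only a single NCCL port. For example, dp_size=2 now needs 4 + 2 = 6 ports, where the old formula with tp_size=8 would have needed 4 + 2 * (1 + 8) = 22.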
@@ -188,71 +186,6 @@ def get_int_token_logit_bias(tokenizer, vocab_size):
     return logit_bias


-def wrap_kernel_launcher(kernel):
-    """A faster launcher for triton kernels."""
-    if int(triton.__version__.split(".")[0]) >= 3:
-        return None
-
-    gpu_id = torch.cuda.current_device()
-    kernels = kernel.cache[gpu_id].values()
-    kernel = next(iter(kernels))
-
-    # Different trition versions use different low-level names
-    if hasattr(kernel, "cu_function"):
-        kfunction = kernel.cu_function
-    else:
-        kfunction = kernel.function
-
-    if hasattr(kernel, "c_wrapper"):
-        run = kernel.c_wrapper
-    else:
-        run = kernel.run
-
-    add_cluster_dim = True
-
-    def ret_func(grid, num_warps, *args):
-        nonlocal add_cluster_dim
-
-        try:
-            if add_cluster_dim:
-                run(
-                    grid[0],
-                    grid[1],
-                    grid[2],
-                    num_warps,
-                    1,
-                    1,
-                    1,
-                    1,
-                    kernel.shared,
-                    0,
-                    kfunction,
-                    None,
-                    None,
-                    kernel,
-                    *args,
-                )
-            else:
-                run(
-                    grid[0],
-                    grid[1],
-                    grid[2],
-                    num_warps,
-                    kernel.shared,
-                    0,
-                    kfunction,
-                    None,
-                    None,
-                    kernel,
-                    *args,
-                )
-        except TypeError:
-            add_cluster_dim = not add_cluster_dim
-            ret_func(grid, num_warps, *args)
-
-    return ret_func
-
-
 def is_multimodal_model(model):
     from sglang.srt.model_config import ModelConfig

@@ -371,49 +304,6 @@ def load_image(image_file):
     return image, image_size


-def connect_rpyc_service(host, port):
-    repeat_count = 0
-    while repeat_count < 20:
-        try:
-            con = rpyc.connect(
-                host,
-                port,
-                config={
-                    "allow_public_attrs": True,
-                    "allow_pickle": True,
-                    "sync_request_timeout": 3600,
-                },
-            )
-            break
-        except ConnectionRefusedError as e:
-            time.sleep(1)
-            repeat_count += 1
-    if repeat_count == 20:
-        raise RuntimeError(f"Connect rpyc error: {e}")
-
-    return con.root
-
-
-def start_rpyc_service(service: rpyc.Service, port: int):
-    t = ThreadedServer(
-        service=service,
-        port=port,
-        protocol_config={
-            "allow_public_attrs": True,
-            "allow_pickle": True,
-            "sync_request_timeout": 3600,
-        },
-    )
-    t.logger.setLevel(logging.WARN)
-    t.start()
-
-
-def start_rpyc_service_process(service: rpyc.Service, port: int):
-    proc = multiprocessing.Process(target=start_rpyc_service, args=(service, port))
-    proc.start()
-    return proc
-
-
 def suppress_other_loggers():
     from vllm.logger import logger as vllm_default_logger

@@ -422,6 +312,9 @@ def suppress_other_loggers():
     logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(
         logging.WARN
     )
+    logging.getLogger("vllm.distributed.device_communicators.shm_broadcast").setLevel(
+        logging.WARN
+    )
     logging.getLogger("vllm.selector").setLevel(logging.WARN)
     logging.getLogger("vllm.utils").setLevel(logging.WARN)
@@ -445,7 +338,7 @@ def kill_parent_process():
     """Kill the parent process and all children of the parent process."""
     current_process = psutil.Process()
     parent_process = current_process.parent()
-    children = current_process.children(recursive=True)
+    children = parent_process.children(recursive=True)
     for child in children:
         if child.pid != current_process.pid:
             os.kill(child.pid, 9)
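
Note: the kill_parent_process change fixes which subtree is reaped: it now kills the children of the parent process (i.e. the current process's siblings, such as other workers), not just the current process's own descendants.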
@@ -521,6 +414,52 @@ def monkey_patch_vllm_dummy_weight_loader():
     setattr(DummyModelLoader, "load_model", load_model)


+vllm_all_gather_backup = None
+
+
+def monkey_patch_vllm_all_gather(reverse: bool = False):
+    """Monkey patch all-gather to remove in-place operations."""
+    from torch.distributed import _functional_collectives as funcol
+    from vllm.distributed.parallel_state import GroupCoordinator
+
+    global vllm_all_gather_backup
+    if vllm_all_gather_backup is None:
+        vllm_all_gather_backup = GroupCoordinator.all_gather
+
+    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        world_size = self.world_size
+        # Bypass the function if we are using only 1 GPU.
+        if world_size == 1:
+            return input_
+        assert (
+            -input_.dim() <= dim < input_.dim()
+        ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
+        if dim < 0:
+            # Convert negative dim to positive.
+            dim += input_.dim()
+        input_size = input_.size()
+        # Allocate output tensor.
+        output_tensor = torch.empty(
+            (world_size,) + input_size, dtype=input_.dtype, device=input_.device
+        )
+
+        output_tensor = funcol.all_gather_tensor(
+            input_, gather_dim=0, group=self.device_group
+        ).view((world_size,) + input_size)
+
+        # Reshape
+        output_tensor = output_tensor.movedim(0, dim)
+        output_tensor = output_tensor.reshape(
+            input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
+        )
+        return output_tensor
+
+    if reverse:
+        setattr(GroupCoordinator, "all_gather", vllm_all_gather_backup)
+    else:
+        setattr(GroupCoordinator, "all_gather", all_gather)
+
+
 API_KEY_HEADER_NAME = "X-API-Key"
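
Note: funcol.all_gather_tensor returns a fresh tensor instead of filling a caller-provided buffer, which is how the patch above avoids the in-place collective (the torch.empty allocation it keeps is immediately overwritten and is effectively dead code). A standalone sketch of the same pattern, assuming an already-initialized process group:

    import torch
    import torch.distributed as dist
    from torch.distributed import _functional_collectives as funcol

    def functional_all_gather(x: torch.Tensor, group) -> torch.Tensor:
        # Gather shards from all ranks along a new leading dimension,
        # returning a new tensor rather than mutating a preallocated one.
        world_size = dist.get_world_size(group)
        out = funcol.all_gather_tensor(x, gather_dim=0, group=group)
        return out.view((world_size,) + tuple(x.size()))

Out-of-place collectives are traceable by torch.compile, which is presumably why this patch lands alongside the new --enable-torch-compile flag.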
@@ -559,7 +498,6 @@ def get_ip_address(ifname):

 def send_addrs_to_rank_0(model_port_args, server_args):
     assert server_args.node_rank != 0 and server_args.dp_size == 1
-    import torch.distributed as dist

     ifname = os.environ.get(
         "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
@@ -591,7 +529,6 @@ def send_addrs_to_rank_0(model_port_args, server_args):

 def receive_addrs(model_port_args, server_args):
     assert server_args.node_rank == 0 and server_args.dp_size == 1
-    import torch.distributed as dist

     ifname = os.environ.get(
         "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
@@ -624,3 +561,14 @@ def receive_addrs(model_port_args, server_args):

     dist.barrier()
     dist.destroy_process_group()
+
+
+def set_ulimit(target_soft_limit=65535):
+    resource_type = resource.RLIMIT_NOFILE
+    current_soft, current_hard = resource.getrlimit(resource_type)
+
+    if current_soft < target_soft_limit:
+        try:
+            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
+        except ValueError as e:
+            logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")
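
Note: the new set_ulimit raises only the soft RLIMIT_NOFILE limit; if the hard limit is below the target, setrlimit raises ValueError and the failure is merely logged. A usage sketch (printed values are illustrative):

    import resource
    from sglang.srt.utils import set_ulimit

    print(resource.getrlimit(resource.RLIMIT_NOFILE))  # e.g. (1024, 1048576)
    set_ulimit()  # target soft limit defaults to 65535
    print(resource.getrlimit(resource.RLIMIT_NOFILE))  # e.g. (65535, 1048576)

A high descriptor limit matters for a server multiplexing many concurrent streaming connections.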
sglang/test/test_conversation.py CHANGED
@@ -1,5 +1,5 @@
 from sglang.srt.conversation import generate_chat_conv
-from sglang.srt.managers.openai_protocol import (
+from sglang.srt.managers.openai_api.protocol import (
     ChatCompletionMessageContentImagePart,
     ChatCompletionMessageContentImageURL,
     ChatCompletionMessageContentTextPart,
sglang/test/test_openai_protocol.py CHANGED
@@ -1,4 +1,4 @@
-from sglang.srt.managers.openai_protocol import (
+from sglang.srt.managers.openai_api.protocol import (
     ChatCompletionMessageContentImagePart,
     ChatCompletionMessageContentImageURL,
     ChatCompletionMessageContentTextPart,
sglang/test/test_programs.py CHANGED
@@ -306,7 +306,7 @@ def test_image_qa():
     assert (
         "taxi" in state.messages()[-1]["content"]
         or "car" in state.messages()[-1]["content"]
-    )
+    ), f"{state.messages()[-1]['content']}"


 def test_stream():
sglang/test/test_utils.py CHANGED
@@ -6,9 +6,9 @@ from functools import partial
 import numpy as np
 import requests

-from sglang.backend.openai import OpenAI
-from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.global_config import global_config
+from sglang.lang.backend.openai import OpenAI
+from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.utils import get_exception_traceback
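
Downstream code needs the same one-line change as these tests: the backend modules moved from sglang.backend.* to sglang.lang.backend.*. A minimal migration sketch, assuming a server is already running at the placeholder URL:

    # 0.1.21:
    #   from sglang.backend.runtime_endpoint import RuntimeEndpoint
    # 0.1.24:
    from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint

    backend = RuntimeEndpoint("http://127.0.0.1:30000")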