sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +24 -16
  4. sglang/bench_one_batch.py +51 -3
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +37 -28
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +15 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/model_config.py +16 -6
  13. sglang/srt/constrained/base_grammar_backend.py +21 -0
  14. sglang/srt/constrained/xgrammar_backend.py +8 -4
  15. sglang/srt/conversation.py +14 -1
  16. sglang/srt/distributed/__init__.py +3 -3
  17. sglang/srt/distributed/communication_op.py +2 -1
  18. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  21. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  22. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  23. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  24. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  25. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  26. sglang/srt/distributed/parallel_state.py +1 -1
  27. sglang/srt/distributed/utils.py +2 -1
  28. sglang/srt/entrypoints/engine.py +449 -0
  29. sglang/srt/entrypoints/http_server.py +579 -0
  30. sglang/srt/layers/activation.py +3 -3
  31. sglang/srt/layers/attention/flashinfer_backend.py +27 -12
  32. sglang/srt/layers/attention/triton_backend.py +4 -6
  33. sglang/srt/layers/attention/vision.py +204 -0
  34. sglang/srt/layers/dp_attention.py +69 -0
  35. sglang/srt/layers/linear.py +76 -102
  36. sglang/srt/layers/logits_processor.py +48 -63
  37. sglang/srt/layers/moe/ep_moe/layer.py +4 -4
  38. sglang/srt/layers/moe/fused_moe_native.py +69 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -14
  41. sglang/srt/layers/moe/topk.py +4 -2
  42. sglang/srt/layers/parameter.py +26 -17
  43. sglang/srt/layers/quantization/__init__.py +22 -23
  44. sglang/srt/layers/quantization/fp8.py +112 -55
  45. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  46. sglang/srt/layers/quantization/int8_kernel.py +54 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +2 -3
  48. sglang/srt/layers/quantization/w8a8_int8.py +117 -0
  49. sglang/srt/layers/radix_attention.py +2 -0
  50. sglang/srt/layers/rotary_embedding.py +1179 -31
  51. sglang/srt/layers/sampler.py +39 -1
  52. sglang/srt/layers/vocab_parallel_embedding.py +17 -4
  53. sglang/srt/lora/lora.py +1 -9
  54. sglang/srt/managers/configure_logging.py +46 -0
  55. sglang/srt/managers/data_parallel_controller.py +79 -72
  56. sglang/srt/managers/detokenizer_manager.py +23 -8
  57. sglang/srt/managers/image_processor.py +158 -2
  58. sglang/srt/managers/io_struct.py +54 -15
  59. sglang/srt/managers/schedule_batch.py +49 -22
  60. sglang/srt/managers/schedule_policy.py +26 -12
  61. sglang/srt/managers/scheduler.py +319 -181
  62. sglang/srt/managers/session_controller.py +1 -0
  63. sglang/srt/managers/tokenizer_manager.py +303 -158
  64. sglang/srt/managers/tp_worker.py +6 -4
  65. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  66. sglang/srt/managers/utils.py +44 -0
  67. sglang/srt/mem_cache/memory_pool.py +110 -77
  68. sglang/srt/metrics/collector.py +25 -11
  69. sglang/srt/model_executor/cuda_graph_runner.py +4 -6
  70. sglang/srt/model_executor/model_runner.py +80 -21
  71. sglang/srt/model_loader/loader.py +8 -6
  72. sglang/srt/model_loader/weight_utils.py +55 -2
  73. sglang/srt/models/baichuan.py +6 -6
  74. sglang/srt/models/chatglm.py +2 -2
  75. sglang/srt/models/commandr.py +3 -3
  76. sglang/srt/models/dbrx.py +4 -4
  77. sglang/srt/models/deepseek.py +3 -3
  78. sglang/srt/models/deepseek_v2.py +8 -8
  79. sglang/srt/models/exaone.py +2 -2
  80. sglang/srt/models/gemma.py +2 -2
  81. sglang/srt/models/gemma2.py +6 -24
  82. sglang/srt/models/gpt2.py +3 -5
  83. sglang/srt/models/gpt_bigcode.py +1 -1
  84. sglang/srt/models/granite.py +2 -2
  85. sglang/srt/models/grok.py +3 -3
  86. sglang/srt/models/internlm2.py +2 -2
  87. sglang/srt/models/llama.py +41 -4
  88. sglang/srt/models/minicpm.py +2 -2
  89. sglang/srt/models/minicpm3.py +6 -6
  90. sglang/srt/models/minicpmv.py +1238 -0
  91. sglang/srt/models/mixtral.py +3 -3
  92. sglang/srt/models/mixtral_quant.py +3 -3
  93. sglang/srt/models/mllama.py +2 -2
  94. sglang/srt/models/olmo.py +3 -3
  95. sglang/srt/models/olmo2.py +4 -4
  96. sglang/srt/models/olmoe.py +7 -13
  97. sglang/srt/models/phi3_small.py +2 -2
  98. sglang/srt/models/qwen.py +2 -2
  99. sglang/srt/models/qwen2.py +52 -4
  100. sglang/srt/models/qwen2_eagle.py +131 -0
  101. sglang/srt/models/qwen2_moe.py +3 -3
  102. sglang/srt/models/qwen2_vl.py +22 -122
  103. sglang/srt/models/stablelm.py +2 -2
  104. sglang/srt/models/torch_native_llama.py +3 -3
  105. sglang/srt/models/xverse.py +6 -6
  106. sglang/srt/models/xverse_moe.py +6 -6
  107. sglang/srt/openai_api/protocol.py +2 -0
  108. sglang/srt/sampling/custom_logit_processor.py +38 -0
  109. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
  110. sglang/srt/sampling/sampling_batch_info.py +153 -9
  111. sglang/srt/sampling/sampling_params.py +4 -2
  112. sglang/srt/server.py +4 -1037
  113. sglang/srt/server_args.py +84 -32
  114. sglang/srt/speculative/eagle_worker.py +1 -0
  115. sglang/srt/torch_memory_saver_adapter.py +59 -0
  116. sglang/srt/utils.py +130 -63
  117. sglang/test/runners.py +8 -13
  118. sglang/test/test_programs.py +1 -1
  119. sglang/test/test_utils.py +3 -1
  120. sglang/utils.py +12 -2
  121. sglang/version.py +1 -1
  122. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +26 -13
  123. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +126 -117
  124. sglang/launch_server_llavavid.py +0 -25
  125. sglang/srt/constrained/__init__.py +0 -16
  126. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  127. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
  128. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
  129. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py CHANGED
@@ -23,15 +23,15 @@ from typing import List, Optional
 import torch
 
 from sglang.srt.hf_transformers_utils import check_gguf_file
-from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
     get_hpu_memory_capacity,
     get_nvgpu_memory_capacity,
     is_flashinfer_available,
     is_hip,
-    is_ipv6,
     is_port_available,
+    is_valid_ipv6_address,
+    nullable_str,
 )
 
 logger = logging.getLogger(__name__)
@@ -47,6 +47,7 @@ class ServerArgs:
     trust_remote_code: bool = True
     dtype: str = "auto"
     kv_cache_dtype: str = "auto"
+    quantization_param_path: nullable_str = None
     quantization: Optional[str] = None
     context_length: Optional[int] = None
     device: str = "cuda"
@@ -55,7 +56,6 @@ class ServerArgs:
     is_embedding: bool = False
     revision: Optional[str] = None
     skip_tokenizer_init: bool = False
-    return_token_ids: bool = False
 
     # Port for the HTTP server
     host: str = "127.0.0.1"
@@ -91,7 +91,7 @@ class ServerArgs:
 
     # API related
     api_key: Optional[str] = None
-    file_storage_pth: str = "SGLang_storage"
+    file_storage_pth: str = "sglang_storage"
     enable_cache_report: bool = False
 
     # Data parallelism
@@ -156,6 +156,11 @@ class ServerArgs:
     triton_attention_num_kv_splits: int = 8
     num_continuous_decode_steps: int = 1
     delete_ckpt_after_loading: bool = False
+    enable_memory_saver: bool = False
+    allow_auto_truncate: bool = False
+
+    # Custom logit processor
+    enable_custom_logit_processor: bool = False
 
     def __post_init__(self):
         # Set missing default values
@@ -239,14 +244,13 @@ class ServerArgs:
         # Others
         if self.enable_dp_attention:
             self.dp_size = self.tp_size
+            assert self.tp_size % self.dp_size == 0
             self.chunked_prefill_size = self.chunked_prefill_size // 2
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
-            self.disable_overlap_schedule = True
             logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                 f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
                 "Data parallel size is adjusted to be the same as tensor parallel size. "
-                "Overlap scheduler is disabled."
             )
 
         # Speculative Decoding
@@ -296,6 +300,11 @@ class ServerArgs:
             "tokenizer if available, and 'slow' will "
             "always use the slow tokenizer.",
         )
+        parser.add_argument(
+            "--skip-tokenizer-init",
+            action="store_true",
+            help="If set, skip init tokenizer and pass input_ids in generate request",
+        )
         parser.add_argument(
             "--load-format",
             type=str,
@@ -346,8 +355,17 @@ class ServerArgs:
             "--kv-cache-dtype",
             type=str,
             default=ServerArgs.kv_cache_dtype,
-            choices=["auto", "fp8_e5m2"],
-            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
+            choices=["auto", "fp8_e5m2", "fp8_e4m3"],
+            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
+        )
+        parser.add_argument(
+            "--quantization-param-path",
+            type=nullable_str,
+            default=None,
+            help="Path to the JSON file containing the KV cache "
+            "scaling factors. This should generally be supplied, when "
+            "KV cache dtype is FP8. Otherwise, KV cache scaling factors "
+            "default to 1.0, which may cause accuracy issues. ",
         )
         parser.add_argument(
             "--quantization",
@@ -363,6 +381,7 @@ class ServerArgs:
                 "bitsandbytes",
                 "gguf",
                 "modelopt",
+                "w8a8_int8",
             ],
             help="The quantization method.",
         )
@@ -376,7 +395,7 @@ class ServerArgs:
             "--device",
             type=str,
             default="cuda",
-            choices=["cuda", "xpu", "hpu"],
+            choices=["cuda", "xpu", "hpu", "cpu"],
             help="The device type.",
         )
         parser.add_argument(
@@ -404,18 +423,6 @@ class ServerArgs:
             "name, a tag name, or a commit id. If unspecified, will use "
             "the default version.",
         )
-        parser.add_argument(
-            "--skip-tokenizer-init",
-            action="store_true",
-            help="If set, skip init tokenizer and pass input_ids in generate request",
-        )
-        parser.add_argument(
-            "--return-token-ids",
-            action="store_true",
-            default=ServerArgs.return_token_ids,
-            help="Whether to return token IDs in the output, this may introduce additional overhead.",
-        )
-
         # Memory and scheduling
         parser.add_argument(
             "--mem-fraction-static",
@@ -551,7 +558,7 @@ class ServerArgs:
             "--decode-log-interval",
             type=int,
             default=ServerArgs.decode_log_interval,
-            help="The log interval of decode batch",
+            help="The log interval of decode batch.",
         )
 
         # API related
@@ -851,6 +858,21 @@ class ServerArgs:
             action="store_true",
             help="Delete the model checkpoint after loading the model.",
         )
+        parser.add_argument(
+            "--enable-memory-saver",
+            action="store_true",
+            help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
+        )
+        parser.add_argument(
+            "--allow-auto-truncate",
+            action="store_true",
+            help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
+        )
+        parser.add_argument(
+            "--enable-custom-logit-processor",
+            action="store_true",
+            help="Enable users to pass custom logit processors to the server (disabled by default for security)",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
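As a reference for the three new toggles added above, here is a minimal sketch of setting them programmatically; the model path is illustrative, and only the fields and url() visible in this diff are assumed:

# Minimal sketch, not from the package: exercising the new ServerArgs fields.
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="dummy/model",            # hypothetical model path
    enable_memory_saver=True,            # pairs with release/resume_memory_occupation
    allow_auto_truncate=True,            # truncate over-long prompts instead of erroring
    enable_custom_logit_processor=True,  # opt-in; disabled by default for security
)
print(args.url())  # e.g. http://127.0.0.1:30000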
@@ -861,7 +883,7 @@ class ServerArgs:
         return cls(**{attr: getattr(args, attr) for attr in attrs})
 
     def url(self):
-        if is_ipv6(self.host):
+        if is_valid_ipv6_address(self.host):
             return f"http://[{self.host}]:{self.port}"
         else:
             return f"http://{self.host}:{self.port}"
@@ -871,8 +893,8 @@ class ServerArgs:
             self.tp_size % self.nnodes == 0
         ), "tp_size must be divisible by number of nodes"
         assert not (
-            self.dp_size > 1 and self.nnodes != 1
-        ), "multi-node data parallel is not supported"
+            self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
+        ), "multi-node data parallel is not supported unless dp attention!"
         assert (
             self.max_loras_per_batch > 0
             # FIXME
@@ -910,6 +932,9 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
     return server_args
 
 
+ZMQ_TCP_PORT_DELTA = 233
+
+
 @dataclasses.dataclass
 class PortArgs:
     # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
@@ -923,7 +948,7 @@ class PortArgs:
     nccl_port: int
 
     @staticmethod
-    def init_new(server_args) -> "PortArgs":
+    def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
         port = server_args.port + random.randint(100, 1000)
         while True:
             if is_port_available(port):
@@ -933,12 +958,39 @@ class PortArgs:
             else:
                 port -= 43
 
-        return PortArgs(
-            tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
-            scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
-            detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
-            nccl_port=port,
-        )
+        if not server_args.enable_dp_attention:
+            # Normal case, use IPC within a single node
+            return PortArgs(
+                tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
+                scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
+                detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
+                nccl_port=port,
+            )
+        else:
+            # DP attention. Use TCP + port to handle both single-node and multi-node.
+            if server_args.nnodes == 1 and server_args.dist_init_addr is None:
+                dist_init_addr = ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA)
+            else:
+                dist_init_addr = server_args.dist_init_addr.split(":")
+            assert (
+                len(dist_init_addr) == 2
+            ), "please provide --dist-init-addr as host:port of head node"
+
+            dist_init_host, dist_init_port = dist_init_addr
+            port_base = int(dist_init_port) + 1
+            if dp_rank is None:
+                scheduler_input_port = (
+                    port_base + 2
+                )  # TokenizerManager to DataParallelController
+            else:
+                scheduler_input_port = port_base + 2 + 1 + dp_rank
+
+            return PortArgs(
+                tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
+                scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
+                detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
+                nccl_port=port,
+            )
 
 
 class LoRAPathAction(argparse.Action):
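The DP-attention branch above derives every ZMQ endpoint from one base port. A small sketch of the resulting layout under the single-node default (the helper name is ours; the arithmetic mirrors init_new):

# Port-layout sketch for the TCP path of PortArgs.init_new (illustrative helper).
ZMQ_TCP_PORT_DELTA = 233

def sketch_dp_port_layout(server_port: int, dp_size: int) -> dict:
    # Single node, no --dist-init-addr: the base TCP port is server_port + 233.
    port_base = (server_port + ZMQ_TCP_PORT_DELTA) + 1
    layout = {
        "tokenizer": port_base,             # detokenizer -> TokenizerManager
        "detokenizer": port_base + 1,       # scheduler -> detokenizer
        "controller_input": port_base + 2,  # TokenizerManager -> DataParallelController
    }
    for dp_rank in range(dp_size):          # per-rank scheduler inputs
        layout[f"scheduler_input_dp{dp_rank}"] = port_base + 2 + 1 + dp_rank
    return layout

print(sketch_dp_port_layout(30000, dp_size=2))
# {'tokenizer': 30234, 'detokenizer': 30235, 'controller_input': 30236,
#  'scheduler_input_dp0': 30237, 'scheduler_input_dp1': 30238}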
sglang/srt/speculative/eagle_worker.py CHANGED
@@ -40,6 +40,7 @@ class EAGLEWorker(TpModelWorker):
         )
         self.target_worker = target_worker
         self.server_args = server_args
+        self.finish_extend_len = []
 
         # Share the embedding and lm_head
         embed, head = self.target_worker.model_runner.model.get_embed_and_head()
sglang/srt/torch_memory_saver_adapter.py ADDED
@@ -0,0 +1,59 @@
+from abc import ABC
+from contextlib import contextmanager
+
+try:
+    import torch_memory_saver
+
+    _primary_memory_saver = torch_memory_saver.TorchMemorySaver()
+except ImportError:
+    pass
+
+
+class TorchMemorySaverAdapter(ABC):
+    @staticmethod
+    def create(enable: bool):
+        return (
+            _TorchMemorySaverAdapterReal() if enable else _TorchMemorySaverAdapterNoop()
+        )
+
+    def configure_subprocess(self):
+        raise NotImplementedError
+
+    def region(self):
+        raise NotImplementedError
+
+    def pause(self):
+        raise NotImplementedError
+
+    def resume(self):
+        raise NotImplementedError
+
+
+class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
+    def configure_subprocess(self):
+        return torch_memory_saver.configure_subprocess()
+
+    def region(self):
+        return _primary_memory_saver.region()
+
+    def pause(self):
+        return _primary_memory_saver.pause()
+
+    def resume(self):
+        return _primary_memory_saver.resume()
+
+
+class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
+    @contextmanager
+    def configure_subprocess(self):
+        yield
+
+    @contextmanager
+    def region(self):
+        yield
+
+    def pause(self):
+        pass
+
+    def resume(self):
+        pass
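The adapter gives call sites one interface whether or not the optional torch_memory_saver package is installed; with enable=False the no-op variant is returned, so the following sketch runs without the dependency (buffer name and size are illustrative):

# Usage sketch for the new adapter.
import torch
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

adapter = TorchMemorySaverAdapter.create(enable=False)  # True requires torch_memory_saver

with adapter.region():          # allocations to manage go inside the region
    kv_buffer = torch.empty(1024)

adapter.pause()   # real variant releases the region's device memory
adapter.resume()  # real variant restores it; both are no-ops here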
sglang/srt/utils.py CHANGED
@@ -59,6 +59,7 @@ from triton.runtime.cache import (
     default_dump_dir,
     default_override_dir,
 )
+from uvicorn.config import LOGGING_CONFIG
 
 logger = logging.getLogger(__name__)
 
@@ -97,12 +98,8 @@ def is_flashinfer_available():
     return torch.cuda.is_available() and torch.version.cuda
 
 
-def is_ipv6(address):
-    try:
-        ipaddress.IPv6Address(address)
-        return True
-    except ipaddress.AddressValueError:
-        return False
+def is_cuda_available():
+    return torch.cuda.is_available() and torch.version.cuda
 
 
 def enable_show_time_cost():
@@ -218,6 +215,10 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
 
         free_gpu_memory, total_gpu_memory = torch.hpu.mem_get_info()
 
+    elif device == "cpu":
+        # TODO: rename the variables in the current function to be not GPU specific
+        free_gpu_memory = psutil.virtual_memory().available
+
     if distributed:
         tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
             torch.device(device, gpu_id)
@@ -442,6 +443,8 @@ def load_image(image_file: Union[str, bytes]):
     else:
         raise ValueError(f"Invalid image: {image}")
 
+    # if image_size is None:
+    #     image_size = image.size
     return image, image_size
 
 
@@ -507,76 +510,32 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N
         pass
 
 
-def monkey_patch_vllm_p2p_access_check(gpu_id: int):
+def monkey_patch_p2p_access_check():
     """
-    Monkey patch the slow p2p access check in vllm.
+    Monkey patch the slow p2p access check.
    NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
     """
 
-    import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
+    import sglang.srt.distributed.device_communicators.custom_all_reduce_utils as tgt
 
     setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
 
     # Suppress the warnings from this delete function when using sglang.bench_one_batch
-    from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
+    from sglang.srt.distributed.device_communicators.custom_all_reduce import (
+        CustomAllreduce,
+    )
 
     setattr(CustomAllreduce, "__del__", lambda *args, **kwargs: None)
 
 
-vllm_all_gather_backup = None
-
-
-def monkey_patch_vllm_all_gather(reverse: bool = False):
-    """Monkey patch all-gather to remove in-place operations."""
-    from torch.distributed import _functional_collectives as funcol
-    from vllm.distributed.parallel_state import GroupCoordinator
-
-    global vllm_all_gather_backup
-    if vllm_all_gather_backup is None:
-        vllm_all_gather_backup = GroupCoordinator.all_gather
-
-    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
-        world_size = self.world_size
-        # Bypass the function if we are using only 1 GPU.
-        if world_size == 1:
-            return input_
-        assert (
-            -input_.dim() <= dim < input_.dim()
-        ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
-        if dim < 0:
-            # Convert negative dim to positive.
-            dim += input_.dim()
-        input_size = input_.size()
-        # Allocate output tensor.
-        output_tensor = torch.empty(
-            (world_size,) + input_size, dtype=input_.dtype, device=input_.device
-        )
-
-        output_tensor = funcol.all_gather_tensor(
-            input_, gather_dim=0, group=self.device_group
-        ).view((world_size,) + input_size)
-
-        # Reshape
-        output_tensor = output_tensor.movedim(0, dim)
-        output_tensor = output_tensor.reshape(
-            input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
-        )
-        return output_tensor
-
-    if reverse:
-        setattr(GroupCoordinator, "all_gather", vllm_all_gather_backup)
-    else:
-        setattr(GroupCoordinator, "all_gather", all_gather)
-
-
 def monkey_patch_vllm_gguf_config():
-    from vllm.model_executor.layers.linear import LinearBase
     from vllm.model_executor.layers.quantization.gguf import (
         GGUFConfig,
         GGUFEmbeddingMethod,
         GGUFLinearMethod,
     )
 
+    from sglang.srt.layers.linear import LinearBase
     from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 
     def get_quant_method_with_embedding_replaced(
@@ -784,7 +743,9 @@ def first_rank_print(*args, **kwargs):
         pass
 
 
-def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: str):
+def get_zmq_socket(
+    context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
+):
     mem = psutil.virtual_memory()
     total_mem = mem.total / 1024**3
     available_mem = mem.available / 1024**3
@@ -797,14 +758,17 @@ def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint:
     if socket_type == zmq.PUSH:
         socket.setsockopt(zmq.SNDHWM, 0)
         socket.setsockopt(zmq.SNDBUF, buf_size)
-        socket.connect(f"ipc://{endpoint}")
     elif socket_type == zmq.PULL:
         socket.setsockopt(zmq.RCVHWM, 0)
         socket.setsockopt(zmq.RCVBUF, buf_size)
-        socket.bind(f"ipc://{endpoint}")
     else:
         raise ValueError(f"Unsupported socket type: {socket_type}")
 
+    if bind:
+        socket.bind(endpoint)
+    else:
+        socket.connect(endpoint)
+
     return socket
 
 
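After these two hunks, endpoints passed to get_zmq_socket carry their own transport prefix and the new bind flag picks the socket's role, so one PUSH/PULL pair now looks roughly like this (the TCP port and payload are arbitrary):

# Sketch of the new calling convention.
import zmq
from sglang.srt.utils import get_zmq_socket

ctx = zmq.Context()
pull = get_zmq_socket(ctx, zmq.PULL, "tcp://127.0.0.1:5701", bind=True)   # receiver binds
push = get_zmq_socket(ctx, zmq.PUSH, "tcp://127.0.0.1:5701", bind=False)  # sender connects

push.send_pyobj({"req": 1})
print(pull.recv_pyobj())  # {'req': 1}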
@@ -1246,9 +1210,9 @@ def dataclass_to_string_truncated(data, max_length=2048):
     if isinstance(data, str):
         if len(data) > max_length:
             half_length = max_length // 2
-            return f'"{data[:half_length]} ... {data[-half_length:]}"'
+            return f"{repr(data[:half_length])} ... {repr(data[-half_length:])}"
         else:
-            return f'"{data}"'
+            return f"{repr(data)}"
     elif isinstance(data, (list, tuple)):
         if len(data) > max_length:
             half_length = max_length // 2
@@ -1259,7 +1223,7 @@ def dataclass_to_string_truncated(data, max_length=2048):
         return (
             "{"
             + ", ".join(
-                f"{k}: {dataclass_to_string_truncated(v, max_length)}"
+                f"'{k}': {dataclass_to_string_truncated(v, max_length)}"
                 for k, v in data.items()
             )
             + "}"
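The switch to repr() in these two hunks keeps control characters escaped and quotes dict keys, so truncated log lines stay copy-pasteable; roughly:

# Illustration only; exact output depends on max_length and the full helper.
from sglang.srt.utils import dataclass_to_string_truncated

print(dataclass_to_string_truncated({"text": 'a\nb "c"'}, max_length=2048))
# before: {text: "a
# b "c""}                      <- raw newline, ambiguous quoting
# after:  {'text': 'a\nb "c"'}  <- repr() escapes the newline, keys are quoted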
@@ -1340,6 +1304,25 @@ def parse_tool_response(text, tools, **kwargs):
     return text, call_info_list
 
 
+def permute_weight(x: torch.Tensor) -> torch.Tensor:
+    b_ = x.shape[0]
+    n_ = x.shape[1]
+    k_ = x.shape[2]
+
+    x_ = x
+    if x.dtype == torch.bfloat16 or x.dtype == torch.float16:
+        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 32), 4, 8)
+    elif x.dtype == torch.float8_e4m3fnuz or x.dtype == torch.int8:
+        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 64), 4, 16)
+    else:
+        return x_
+
+    x_ = x_.permute(0, 1, 3, 4, 2, 5)
+    x_ = x_.contiguous()
+    x_ = x_.view(*x.shape)
+    return x_
+
+
 class MultiprocessingSerializer:
     @staticmethod
     def serialize(obj):
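permute_weight re-tiles the (N, K) plane per dtype while preserving the overall tensor shape (the float8_e4m3fnuz branch suggests ROCm kernels as the consumer, though the diff does not say). A quick shape sanity sketch, assuming the helper is importable from sglang.srt.utils:

# Shape sanity sketch; sizes are illustrative (N % 16 == 0, K % 32 == 0 for fp16).
import torch
from sglang.srt.utils import permute_weight

w = torch.randn(2, 32, 64, dtype=torch.float16)   # (B, N, K)
assert permute_weight(w).shape == w.shape         # same shape, re-tiled element order

w_f32 = torch.zeros(2, 32, 64)                    # unsupported dtype (float32)
assert permute_weight(w_f32) is w_f32             # passes through unchanged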
@@ -1375,3 +1358,87 @@ def debug_timing(func):
             return func(*args, **kwargs)
 
     return wrapper
+
+
+def nullable_str(val: str):
+    if not val or val == "None":
+        return None
+    return val
+
+
+def set_uvicorn_logging_configs():
+    LOGGING_CONFIG["formatters"]["default"][
+        "fmt"
+    ] = "[%(asctime)s] %(levelprefix)s %(message)s"
+    LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+    LOGGING_CONFIG["formatters"]["access"][
+        "fmt"
+    ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
+    LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+
+
+def get_ip() -> str:
+    # SGLANG_HOST_IP env can be ignore
+    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
+    if host_ip:
+        return host_ip
+
+    # IP is not set, try to get it from the network interface
+
+    # try ipv4
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    try:
+        s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
+        return s.getsockname()[0]
+    except Exception:
+        pass
+
+    # try ipv6
+    try:
+        s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
+        # Google's public DNS server, see
+        # https://developers.google.com/speed/public-dns/docs/using#addresses
+        s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
+        return s.getsockname()[0]
+    except Exception:
+        pass
+
+    warnings.warn(
+        "Failed to get the IP address, using 0.0.0.0 by default."
+        "The value can be set by the environment variable"
+        " SGLANG_HOST_IP or HOST_IP.",
+        stacklevel=2,
+    )
+    return "0.0.0.0"
+
+
+def get_open_port() -> int:
+
+    port = os.getenv("SGLANG_PORT")
+    if port is not None:
+        while True:
+            try:
+                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                    s.bind(("", port))
+                    return port
+            except OSError:
+                port += 1  # Increment port number if already in use
+                logger.info("Port %d is already in use, trying port %d", port - 1, port)
+    # try ipv4
+    try:
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
+    except OSError:
+        # try ipv6
+        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
+
+
+def is_valid_ipv6_address(address: str) -> bool:
+    try:
+        ipaddress.IPv6Address(address)
+        return True
+    except ValueError:
+        return False
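Two of these helpers back changes earlier in the diff: nullable_str is the argparse type of --quantization-param-path (mapping the literal string "None" and empty strings to None), and is_valid_ipv6_address replaces the old is_ipv6 in ServerArgs.url(). Expected behavior (the JSON filename is hypothetical):

# Doctest-style sketch of the two helpers as added above.
from sglang.srt.utils import is_valid_ipv6_address, nullable_str

assert nullable_str("") is None
assert nullable_str("None") is None
assert nullable_str("kv_scales.json") == "kv_scales.json"

assert is_valid_ipv6_address("::1")
assert not is_valid_ipv6_address("127.0.0.1")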
sglang/test/runners.py CHANGED
@@ -12,7 +12,6 @@
 # limitations under the License.
 # ==============================================================================
 
-import json
 import multiprocessing as mp
 import os
 from dataclasses import dataclass
@@ -22,8 +21,8 @@ import torch
 import torch.nn.functional as F
 from transformers import AutoModelForCausalLM
 
+from sglang.srt.entrypoints.engine import Engine
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.server import Runtime
 from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER
 
 DEFAULT_PROMPTS = [
@@ -278,7 +277,7 @@
     ):
         self.model_type = model_type
         self.is_generation = model_type == "generation"
-        self.runtime = Runtime(
+        self.engine = Engine(
             model_path=model_path,
             tp_size=tp_size,
             dtype=get_dtype_str(torch_dtype),
@@ -306,7 +305,7 @@
         top_output_logprobs = []
         sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
         for i, prompt in enumerate(prompts):
-            response = self.runtime.generate(
+            response = self.engine.generate(
                 prompt,
                 lora_path=lora_paths[i] if lora_paths else None,
                 sampling_params=sampling_params,
@@ -314,7 +313,6 @@
                 logprob_start_len=0,
                 top_logprobs_num=NUM_TOP_LOGPROBS,
             )
-            response = json.loads(response)
             output_strs.append(response["text"])
             top_input_logprobs.append(
                 [
@@ -343,8 +341,7 @@
                 top_output_logprobs=top_output_logprobs,
             )
         else:
-            response = self.runtime.encode(prompts)
-            response = json.loads(response)
+            response = self.engine.encode(prompts)
             if self.model_type == "embedding":
                 logits = [x["embedding"] for x in response]
                 return ModelOutput(embed_logits=logits)
@@ -366,20 +363,18 @@
             # the return value contains logprobs from prefill
             output_strs = []
             sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
-            response = self.runtime.generate(
+            response = self.engine.generate(
                 prompts,
                 lora_path=lora_paths if lora_paths else None,
                 sampling_params=sampling_params,
             )
-            response = json.loads(response)
             output_strs = [r["text"] for r in response]
 
             return ModelOutput(
                 output_strs=output_strs,
             )
         else:
-            response = self.runtime.encode(prompts)
-            response = json.loads(response)
+            response = self.engine.encode(prompts)
             if self.model_type == "embedding":
                 logits = [x["embedding"] for x in response]
                 return ModelOutput(embed_logits=logits)
@@ -391,8 +386,8 @@
         return self
 
     def __exit__(self, exc_type, exc_value, traceback):
-        self.runtime.shutdown()
-        del self.runtime
+        self.engine.shutdown()
+        del self.engine
 
 
 def monkey_patch_gemma2_sdpa():
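The test runner now drives sglang.srt.entrypoints.engine.Engine directly; generate() and encode() return Python objects, which is why every json.loads round-trip above was dropped. A minimal sketch of the same pattern outside the test harness (model path and sampling params are illustrative):

# Minimal Engine sketch mirroring SRTRunner's usage above.
from sglang.srt.entrypoints.engine import Engine

engine = Engine(model_path="dummy/model", tp_size=1, dtype="auto")
response = engine.generate(
    "The capital of France is",
    sampling_params={"max_new_tokens": 8, "temperature": 0},
)
print(response["text"])  # already a dict; no json.loads needed
engine.shutdown()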