sglang 0.4.1.post6__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +41 -27
  4. sglang/bench_one_batch.py +60 -4
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +83 -71
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +46 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/load_config.py +1 -0
  13. sglang/srt/configs/model_config.py +1 -0
  14. sglang/srt/constrained/base_grammar_backend.py +21 -0
  15. sglang/srt/constrained/xgrammar_backend.py +8 -4
  16. sglang/srt/conversation.py +14 -1
  17. sglang/srt/distributed/__init__.py +3 -3
  18. sglang/srt/distributed/communication_op.py +2 -1
  19. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
  21. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  22. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  23. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  24. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  25. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  26. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  27. sglang/srt/distributed/parallel_state.py +1 -1
  28. sglang/srt/distributed/utils.py +2 -1
  29. sglang/srt/entrypoints/engine.py +452 -0
  30. sglang/srt/entrypoints/http_server.py +603 -0
  31. sglang/srt/function_call_parser.py +494 -0
  32. sglang/srt/layers/activation.py +8 -8
  33. sglang/srt/layers/attention/flashinfer_backend.py +10 -9
  34. sglang/srt/layers/attention/triton_backend.py +4 -6
  35. sglang/srt/layers/attention/vision.py +204 -0
  36. sglang/srt/layers/dp_attention.py +71 -0
  37. sglang/srt/layers/layernorm.py +5 -5
  38. sglang/srt/layers/linear.py +65 -14
  39. sglang/srt/layers/logits_processor.py +49 -64
  40. sglang/srt/layers/moe/ep_moe/layer.py +24 -16
  41. sglang/srt/layers/moe/fused_moe_native.py +84 -1
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
  45. sglang/srt/layers/parameter.py +18 -8
  46. sglang/srt/layers/quantization/__init__.py +20 -23
  47. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  49. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  50. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  51. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  52. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  53. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  54. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  55. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  56. sglang/srt/layers/quantization/fp8.py +10 -4
  57. sglang/srt/layers/quantization/modelopt_quant.py +1 -2
  58. sglang/srt/layers/quantization/w8a8_int8.py +1 -1
  59. sglang/srt/layers/radix_attention.py +2 -2
  60. sglang/srt/layers/rotary_embedding.py +1184 -31
  61. sglang/srt/layers/sampler.py +64 -6
  62. sglang/srt/layers/torchao_utils.py +12 -6
  63. sglang/srt/layers/vocab_parallel_embedding.py +2 -2
  64. sglang/srt/lora/lora.py +1 -9
  65. sglang/srt/managers/configure_logging.py +3 -0
  66. sglang/srt/managers/data_parallel_controller.py +79 -72
  67. sglang/srt/managers/detokenizer_manager.py +24 -6
  68. sglang/srt/managers/image_processor.py +158 -2
  69. sglang/srt/managers/io_struct.py +57 -3
  70. sglang/srt/managers/schedule_batch.py +78 -45
  71. sglang/srt/managers/schedule_policy.py +26 -12
  72. sglang/srt/managers/scheduler.py +326 -201
  73. sglang/srt/managers/session_controller.py +1 -0
  74. sglang/srt/managers/tokenizer_manager.py +210 -121
  75. sglang/srt/managers/tp_worker.py +6 -4
  76. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  77. sglang/srt/managers/utils.py +44 -0
  78. sglang/srt/mem_cache/memory_pool.py +10 -32
  79. sglang/srt/metrics/collector.py +15 -6
  80. sglang/srt/model_executor/cuda_graph_runner.py +26 -30
  81. sglang/srt/model_executor/forward_batch_info.py +5 -7
  82. sglang/srt/model_executor/model_runner.py +44 -19
  83. sglang/srt/model_loader/loader.py +83 -6
  84. sglang/srt/model_loader/weight_utils.py +145 -6
  85. sglang/srt/models/baichuan.py +6 -6
  86. sglang/srt/models/chatglm.py +2 -2
  87. sglang/srt/models/commandr.py +17 -5
  88. sglang/srt/models/dbrx.py +13 -5
  89. sglang/srt/models/deepseek.py +3 -3
  90. sglang/srt/models/deepseek_v2.py +11 -11
  91. sglang/srt/models/exaone.py +2 -2
  92. sglang/srt/models/gemma.py +2 -2
  93. sglang/srt/models/gemma2.py +15 -25
  94. sglang/srt/models/gpt2.py +3 -5
  95. sglang/srt/models/gpt_bigcode.py +1 -1
  96. sglang/srt/models/granite.py +2 -2
  97. sglang/srt/models/grok.py +4 -3
  98. sglang/srt/models/internlm2.py +2 -2
  99. sglang/srt/models/llama.py +7 -5
  100. sglang/srt/models/minicpm.py +2 -2
  101. sglang/srt/models/minicpm3.py +9 -9
  102. sglang/srt/models/minicpmv.py +1238 -0
  103. sglang/srt/models/mixtral.py +3 -3
  104. sglang/srt/models/mixtral_quant.py +3 -3
  105. sglang/srt/models/mllama.py +2 -2
  106. sglang/srt/models/olmo.py +3 -3
  107. sglang/srt/models/olmo2.py +4 -4
  108. sglang/srt/models/olmoe.py +7 -13
  109. sglang/srt/models/phi3_small.py +2 -2
  110. sglang/srt/models/qwen.py +2 -2
  111. sglang/srt/models/qwen2.py +41 -4
  112. sglang/srt/models/qwen2_moe.py +3 -3
  113. sglang/srt/models/qwen2_vl.py +22 -122
  114. sglang/srt/models/stablelm.py +2 -2
  115. sglang/srt/models/torch_native_llama.py +20 -7
  116. sglang/srt/models/xverse.py +6 -6
  117. sglang/srt/models/xverse_moe.py +6 -6
  118. sglang/srt/openai_api/adapter.py +139 -37
  119. sglang/srt/openai_api/protocol.py +7 -4
  120. sglang/srt/sampling/custom_logit_processor.py +38 -0
  121. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  122. sglang/srt/sampling/sampling_batch_info.py +143 -18
  123. sglang/srt/sampling/sampling_params.py +3 -1
  124. sglang/srt/server.py +4 -1090
  125. sglang/srt/server_args.py +77 -15
  126. sglang/srt/speculative/eagle_utils.py +37 -15
  127. sglang/srt/speculative/eagle_worker.py +11 -13
  128. sglang/srt/utils.py +164 -129
  129. sglang/test/runners.py +8 -13
  130. sglang/test/test_programs.py +2 -1
  131. sglang/test/test_utils.py +83 -22
  132. sglang/utils.py +12 -2
  133. sglang/version.py +1 -1
  134. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
  135. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
  136. sglang/launch_server_llavavid.py +0 -25
  137. sglang/srt/constrained/__init__.py +0 -16
  138. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  139. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
  140. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
  141. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/utils.py CHANGED
@@ -14,6 +14,7 @@
 """Common utilities."""

 import base64
+import ctypes
 import dataclasses
 import io
 import ipaddress
@@ -29,6 +30,7 @@ import shutil
 import signal
 import socket
 import subprocess
+import sys
 import tempfile
 import time
 import warnings
@@ -72,7 +74,7 @@ def is_hip() -> bool:


 def is_cuda():
-    return hasattr(torch, "cuda") and torch.cuda.is_available()
+    return hasattr(torch, "cuda") and torch.version.cuda is not None


 def is_cuda_alike():
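
Note on the `is_cuda` change above: `torch.cuda.is_available()` also returns True on ROCm builds, while `torch.version.cuda` is only set when PyTorch was built against CUDA, so the new check distinguishes CUDA from HIP (compare `is_hip` in the same file). A quick illustration of the attributes involved (values are typical examples, not taken from the diff):

    import torch

    torch.cuda.is_available()  # True on both CUDA and ROCm builds with a visible GPU
    torch.version.cuda         # e.g. "12.1" on a CUDA build; None on a ROCm build
    torch.version.hip          # None on a CUDA build; set on a ROCm build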
@@ -101,14 +103,6 @@ def is_cuda_available():
     return torch.cuda.is_available() and torch.version.cuda


-def is_ipv6(address):
-    try:
-        ipaddress.IPv6Address(address)
-        return True
-    except ipaddress.AddressValueError:
-        return False
-
-
 def enable_show_time_cost():
     global show_time_cost
     show_time_cost = True
@@ -222,6 +216,10 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True

         free_gpu_memory, total_gpu_memory = torch.hpu.mem_get_info()

+    elif device == "cpu":
+        # TODO: rename the variables in the current function to be not GPU specific
+        free_gpu_memory = psutil.virtual_memory().available
+
     if distributed:
         tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
             torch.device(device, gpu_id)
@@ -446,6 +444,8 @@ def load_image(image_file: Union[str, bytes]):
     else:
         raise ValueError(f"Invalid image: {image}")

+    # if image_size is None:
+    #     image_size = image.size
     return image, image_size

@@ -511,76 +511,32 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N
         pass


-def monkey_patch_vllm_p2p_access_check(gpu_id: int):
+def monkey_patch_p2p_access_check():
     """
-    Monkey patch the slow p2p access check in vllm.
+    Monkey patch the slow p2p access check.
     NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
     """

-    import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
+    import sglang.srt.distributed.device_communicators.custom_all_reduce_utils as tgt

     setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)

     # Suppress the warnings from this delete function when using sglang.bench_one_batch
-    from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
+    from sglang.srt.distributed.device_communicators.custom_all_reduce import (
+        CustomAllreduce,
+    )

     setattr(CustomAllreduce, "__del__", lambda *args, **kwargs: None)


-vllm_all_gather_backup = None
-
-
-def monkey_patch_vllm_all_gather(reverse: bool = False):
-    """Monkey patch all-gather to remove in-place operations."""
-    from torch.distributed import _functional_collectives as funcol
-    from vllm.distributed.parallel_state import GroupCoordinator
-
-    global vllm_all_gather_backup
-    if vllm_all_gather_backup is None:
-        vllm_all_gather_backup = GroupCoordinator.all_gather
-
-    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
-        world_size = self.world_size
-        # Bypass the function if we are using only 1 GPU.
-        if world_size == 1:
-            return input_
-        assert (
-            -input_.dim() <= dim < input_.dim()
-        ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
-        if dim < 0:
-            # Convert negative dim to positive.
-            dim += input_.dim()
-        input_size = input_.size()
-        # Allocate output tensor.
-        output_tensor = torch.empty(
-            (world_size,) + input_size, dtype=input_.dtype, device=input_.device
-        )
-
-        output_tensor = funcol.all_gather_tensor(
-            input_, gather_dim=0, group=self.device_group
-        ).view((world_size,) + input_size)
-
-        # Reshape
-        output_tensor = output_tensor.movedim(0, dim)
-        output_tensor = output_tensor.reshape(
-            input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
-        )
-        return output_tensor
-
-    if reverse:
-        setattr(GroupCoordinator, "all_gather", vllm_all_gather_backup)
-    else:
-        setattr(GroupCoordinator, "all_gather", all_gather)
-
-
 def monkey_patch_vllm_gguf_config():
-    from vllm.model_executor.layers.linear import LinearBase
     from vllm.model_executor.layers.quantization.gguf import (
         GGUFConfig,
         GGUFEmbeddingMethod,
         GGUFLinearMethod,
     )

+    from sglang.srt.layers.linear import LinearBase
     from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding

     def get_quant_method_with_embedding_replaced(
@@ -788,7 +744,9 @@ def first_rank_print(*args, **kwargs):
     pass


-def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: str):
+def get_zmq_socket(
+    context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
+):
     mem = psutil.virtual_memory()
     total_mem = mem.total / 1024**3
     available_mem = mem.available / 1024**3
@@ -801,19 +759,22 @@ def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint:
     if socket_type == zmq.PUSH:
         socket.setsockopt(zmq.SNDHWM, 0)
         socket.setsockopt(zmq.SNDBUF, buf_size)
-        socket.connect(f"ipc://{endpoint}")
     elif socket_type == zmq.PULL:
         socket.setsockopt(zmq.RCVHWM, 0)
         socket.setsockopt(zmq.RCVBUF, buf_size)
-        socket.bind(f"ipc://{endpoint}")
     else:
         raise ValueError(f"Unsupported socket type: {socket_type}")

+    if bind:
+        socket.bind(endpoint)
+    else:
+        socket.connect(endpoint)
+
     return socket


 def dump_to_file(dirpath, name, value):
-    from vllm.distributed import get_tensor_model_parallel_rank
+    from sglang.srt.distributed import get_tensor_model_parallel_rank

     if get_tensor_model_parallel_rank() != 0:
         return
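
The `get_zmq_socket` rework above moves two decisions to the caller: the endpoint must now include its transport scheme (the old code hard-coded `ipc://`), and `bind` controls whether the socket binds or connects, independent of whether it is PUSH or PULL. A minimal usage sketch under those assumptions (the ipc path is illustrative):

    import zmq
    from sglang.srt.utils import get_zmq_socket

    ctx = zmq.Context()
    # Consumer binds the PULL end of the pipe ...
    recv_sock = get_zmq_socket(ctx, zmq.PULL, "ipc:///tmp/sglang_demo", bind=True)
    # ... and any number of producers connect PUSH ends to it.
    send_sock = get_zmq_socket(ctx, zmq.PUSH, "ipc:///tmp/sglang_demo", bind=False)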
@@ -1250,9 +1211,9 @@ def dataclass_to_string_truncated(data, max_length=2048):
     if isinstance(data, str):
         if len(data) > max_length:
             half_length = max_length // 2
-            return f'"{data[:half_length]} ... {data[-half_length:]}"'
+            return f"{repr(data[:half_length])} ... {repr(data[-half_length:])}"
         else:
-            return f'"{data}"'
+            return f"{repr(data)}"
     elif isinstance(data, (list, tuple)):
         if len(data) > max_length:
             half_length = max_length // 2
@@ -1263,7 +1224,7 @@ def dataclass_to_string_truncated(data, max_length=2048):
         return (
             "{"
             + ", ".join(
-                f"{k}: {dataclass_to_string_truncated(v, max_length)}"
+                f"'{k}': {dataclass_to_string_truncated(v, max_length)}"
                 for k, v in data.items()
             )
             + "}"
@@ -1282,68 +1243,6 @@ def dataclass_to_string_truncated(data, max_length=2048):
     return str(data)


-TOOLS_TAG_LIST = ["<|plugin|>", "<function=", "<tool_call>", "<|python_tag|>"]
-
-
-def parse_tool_response(text, tools, **kwargs):
-    """Parse model response containing tool information.
-
-    Args:
-        text(str): model response in string format
-        tools(List): tools from user request
-    """
-    if "<|plugin|>" in text:  # internlm2
-        text, action = text.split("<|action_start|><|plugin|>")
-        action = action.split("<|action_end|>".strip())[0]
-        action = action[action.find("{") :]
-        action = json.loads(action)
-        name, parameters = action["name"], json.dumps(
-            action.get("parameters", action.get("arguments", {})), ensure_ascii=False
-        )
-        call_info_list = [(name, parameters)]
-    elif "<function=" in text:  # llama3.1
-        action, _ = text.split("</function>")
-        parameters = action[action.find("{") :]
-        name = action.split("<function=")[1].split(">{")[0]
-        call_info_list = [(name, parameters)]
-    elif "<tool_call>" in text and "</tool_call>" in text:  # qwen2.5
-        # get tool_call in text
-        pattern = r"<tool_call>(.*?)</tool_call>"
-        match_result_list = re.findall(pattern, text, re.DOTALL)
-        call_info_list = []
-        for match_result in match_result_list:
-            action = json.loads(match_result)
-            call_info_list.append(
-                (action["name"], json.dumps(action["arguments"], ensure_ascii=False))
-            )
-        # get text outside of tags
-        if not text.startswith("<tool_call>"):
-            text = text[: text.find("<tool_call>")]
-        elif not text.endswith("</tool_call>"):
-            text = text[text.rfind("</tool_call>") + len("</tool_call>") :]
-        else:
-            text = ""
-    elif "<|python_tag|>" in text:  # llama3.2
-        _, action = text.split("<|python_tag|>")
-        action = json.loads(action)
-        name, parameters = action["name"], json.dumps(
-            action.get("parameters", action.get("arguments", {})), ensure_ascii=False
-        )
-        call_info_list = [(name, parameters)]
-    else:
-        raise RuntimeError(f"Unexpected model response: {text}")
-
-    call_info_list = [
-        (
-            [tool.function.name for tool in tools].index(call_info[0]),
-            call_info[0],
-            call_info[1],
-        )
-        for call_info in call_info_list
-    ]
-    return text, call_info_list
-
-
 def permute_weight(x: torch.Tensor) -> torch.Tensor:
     b_ = x.shape[0]
     n_ = x.shape[1]
@@ -1404,3 +1303,139 @@ def nullable_str(val: str):
     if not val or val == "None":
         return None
     return val
+
+
+def pyspy_dump_schedulers():
+    """py-spy dump on all scheduler in a local node."""
+    try:
+        pid = psutil.Process().pid
+        # Command to run py-spy with the PID
+        cmd = f"py-spy dump --pid {pid}"
+        result = subprocess.run(
+            cmd, shell=True, capture_output=True, text=True, check=True
+        )
+        logger.info(f"Profile for PID {pid}:\n{result.stdout}")
+    except subprocess.CalledProcessError as e:
+        logger.info(f"Failed to profile PID {pid}. Error: {e.stderr}")
+
+
+def kill_itself_when_parent_died():
+    if sys.platform == "linux":
+        # sigkill this process when parent worker manager dies
+        PR_SET_PDEATHSIG = 1
+        libc = ctypes.CDLL("libc.so.6")
+        libc.prctl(PR_SET_PDEATHSIG, signal.SIGKILL)
+    else:
+        logger.warninig("kill_itself_when_parent_died is only supported in linux.")
+
+
+def set_uvicorn_logging_configs():
+    from uvicorn.config import LOGGING_CONFIG
+
+    LOGGING_CONFIG["formatters"]["default"][
+        "fmt"
+    ] = "[%(asctime)s] %(levelprefix)s %(message)s"
+    LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+    LOGGING_CONFIG["formatters"]["access"][
+        "fmt"
+    ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
+    LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+
+
+def get_ip() -> str:
+    # SGLANG_HOST_IP env can be ignore
+    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
+    if host_ip:
+        return host_ip
+
+    # IP is not set, try to get it from the network interface
+
+    # try ipv4
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    try:
+        s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
+        return s.getsockname()[0]
+    except Exception:
+        pass
+
+    # try ipv6
+    try:
+        s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
+        # Google's public DNS server, see
+        # https://developers.google.com/speed/public-dns/docs/using#addresses
+        s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
+        return s.getsockname()[0]
+    except Exception:
+        pass
+
+    warnings.warn(
+        "Failed to get the IP address, using 0.0.0.0 by default."
+        "The value can be set by the environment variable"
+        " SGLANG_HOST_IP or HOST_IP.",
+        stacklevel=2,
+    )
+    return "0.0.0.0"
+
+
+def get_open_port() -> int:
+
+    port = os.getenv("SGLANG_PORT")
+    if port is not None:
+        while True:
+            try:
+                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                    s.bind(("", port))
+                    return port
+            except OSError:
+                port += 1  # Increment port number if already in use
+                logger.info("Port %d is already in use, trying port %d", port - 1, port)
+    # try ipv4
+    try:
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
+    except OSError:
+        # try ipv6
+        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
+
+
+def is_valid_ipv6_address(address: str) -> bool:
+    try:
+        ipaddress.IPv6Address(address)
+        return True
+    except ValueError:
+        return False
+
+
+def rank0_print(msg: str):
+    from sglang.srt.distributed import get_tensor_model_parallel_rank
+
+    if get_tensor_model_parallel_rank() == 0:
+        print(msg, flush=True)
+
+
+def launch_dummy_health_check_server(host, port):
+    import uvicorn
+    from fastapi import FastAPI, Response
+
+    app = FastAPI()
+
+    @app.get("/health")
+    async def health():
+        """Check the health of the http server."""
+        return Response(status_code=200)
+
+    @app.get("/health_generate")
+    async def health_generate():
+        """Check the health of the http server."""
+        return Response(status_code=200)
+
+    uvicorn.run(
+        app,
+        host=host,
+        port=port,
+        timeout_keep_alive=5,
+        loop="uvloop",
+    )
sglang/test/runners.py CHANGED
@@ -12,7 +12,6 @@
 # limitations under the License.
 # ==============================================================================

-import json
 import multiprocessing as mp
 import os
 from dataclasses import dataclass
@@ -22,8 +21,8 @@ import torch
 import torch.nn.functional as F
 from transformers import AutoModelForCausalLM

+from sglang.srt.entrypoints.engine import Engine
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.server import Runtime
 from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER

 DEFAULT_PROMPTS = [
@@ -278,7 +277,7 @@ class SRTRunner:
     ):
         self.model_type = model_type
         self.is_generation = model_type == "generation"
-        self.runtime = Runtime(
+        self.engine = Engine(
             model_path=model_path,
             tp_size=tp_size,
             dtype=get_dtype_str(torch_dtype),
@@ -306,7 +305,7 @@
         top_output_logprobs = []
         sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
         for i, prompt in enumerate(prompts):
-            response = self.runtime.generate(
+            response = self.engine.generate(
                 prompt,
                 lora_path=lora_paths[i] if lora_paths else None,
                 sampling_params=sampling_params,
@@ -314,7 +313,6 @@
                 logprob_start_len=0,
                 top_logprobs_num=NUM_TOP_LOGPROBS,
             )
-            response = json.loads(response)
            output_strs.append(response["text"])
            top_input_logprobs.append(
                [
@@ -343,8 +341,7 @@
                 top_output_logprobs=top_output_logprobs,
             )
         else:
-            response = self.runtime.encode(prompts)
-            response = json.loads(response)
+            response = self.engine.encode(prompts)
             if self.model_type == "embedding":
                 logits = [x["embedding"] for x in response]
                 return ModelOutput(embed_logits=logits)
@@ -366,20 +363,18 @@
             # the return value contains logprobs from prefill
             output_strs = []
             sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
-            response = self.runtime.generate(
+            response = self.engine.generate(
                 prompts,
                 lora_path=lora_paths if lora_paths else None,
                 sampling_params=sampling_params,
             )
-            response = json.loads(response)
             output_strs = [r["text"] for r in response]

             return ModelOutput(
                 output_strs=output_strs,
             )
         else:
-            response = self.runtime.encode(prompts)
-            response = json.loads(response)
+            response = self.engine.encode(prompts)
             if self.model_type == "embedding":
                 logits = [x["embedding"] for x in response]
                 return ModelOutput(embed_logits=logits)
@@ -391,8 +386,8 @@
         return self

     def __exit__(self, exc_type, exc_value, traceback):
-        self.runtime.shutdown()
-        del self.runtime
+        self.engine.shutdown()
+        del self.engine


 def monkey_patch_gemma2_sdpa():
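
The `Runtime` to `Engine` migration also explains the deleted `json.loads` calls: `Runtime` proxied requests over HTTP and returned JSON strings, while the in-process `Engine` (from `sglang/srt/entrypoints/engine.py`, added in this release) returns Python dicts directly. A minimal sketch of the pattern `SRTRunner` now relies on (model path is a placeholder):

    from sglang.srt.entrypoints.engine import Engine

    engine = Engine(model_path="meta-llama/Llama-3.1-8B-Instruct", tp_size=1)
    out = engine.generate(
        "Hello", sampling_params={"max_new_tokens": 8, "temperature": 0}
    )
    print(out["text"])  # already a dict; no json.loads round-trip
    engine.shutdown()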
sglang/test/test_programs.py CHANGED
@@ -535,7 +535,8 @@ def test_hellaswag_select():

     # Compute accuracy
     accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
-    assert np.abs(accuracy_gen - accuracy) < 0.01
+    print(f"{accuracy=}, {accuracy_gen=}")
+    assert np.abs(accuracy_gen - accuracy) < 0.05
     assert np.abs(latency_gen - latency) < 1

     return accuracy, latency
sglang/test/test_utils.py CHANGED
@@ -34,12 +34,16 @@ DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
 DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
 DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
 DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
-DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600
+DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 1000
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1 = "meta-llama/Llama-3.1-8B-Instruct,mistralai/Mistral-7B-Instruct-v0.3,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,google/gemma-2-27b-it"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"
+DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"
+
+DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
+DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmzheng/sglang-EAGLE-llama2-chat-7B"


 def is_in_ci():
@@ -131,10 +135,6 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
     return pred


-def call_generate_gserver(prompt, temperature, max_tokens, stop=None, url=None):
-    raise NotImplementedError()
-
-
 def call_generate_guidance(
     prompt, temperature, max_tokens, stop=None, n=1, regex=None, model=None
 ):
@@ -405,7 +405,7 @@ def popen_launch_server(
     base_url: str,
     timeout: float,
     api_key: Optional[str] = None,
-    other_args: tuple = (),
+    other_args: list[str] = (),
     env: Optional[dict] = None,
     return_stdout_stderr: Optional[tuple] = None,
 ):
@@ -526,6 +526,48 @@ def get_similarities(vec1, vec2):
     return F.cosine_similarity(torch.tensor(vec1), torch.tensor(vec2), dim=0)


+def get_benchmark_args(
+    base_url="",
+    dataset_name="",
+    dataset_path="",
+    tokenizer="",
+    num_prompts=500,
+    random_input_len=4096,
+    random_output_len=2048,
+    request_rate=float("inf"),
+    disable_stream=False,
+    disable_ignore_eos=False,
+):
+    return SimpleNamespace(
+        backend="sglang",
+        base_url=base_url,
+        host=None,
+        port=None,
+        dataset_name=dataset_name,
+        dataset_path=dataset_path,
+        model=None,
+        tokenizer=tokenizer,
+        num_prompts=num_prompts,
+        sharegpt_output_len=None,
+        sharegpt_context_len=None,
+        random_input_len=random_input_len,
+        random_output_len=random_output_len,
+        random_range_ratio=0.0,
+        request_rate=request_rate,
+        multi=None,
+        output_file=None,
+        disable_tqdm=False,
+        disable_stream=disable_stream,
+        return_logprob=False,
+        seed=0,
+        disable_ignore_eos=disable_ignore_eos,
+        extra_request_body=None,
+        apply_chat_template=False,
+        profile=None,
+        lora_name=None,
+    )
+
+
 def run_bench_serving(
     model,
     num_prompts,
@@ -537,6 +579,7 @@ def run_bench_serving(
     random_input_len=4096,
     random_output_len=2048,
     disable_stream=False,
+    disable_ignore_eos=False,
     need_warmup=False,
 ):
     # Launch the server
@@ -549,31 +592,17 @@ def run_bench_serving(
     )

     # Run benchmark
-    args = SimpleNamespace(
-        backend="sglang",
+    args = get_benchmark_args(
         base_url=base_url,
-        host=None,
-        port=None,
         dataset_name=dataset_name,
         dataset_path=dataset_path,
-        model=None,
         tokenizer=tokenizer,
         num_prompts=num_prompts,
-        sharegpt_output_len=None,
         random_input_len=random_input_len,
         random_output_len=random_output_len,
-        random_range_ratio=0.0,
         request_rate=request_rate,
-        multi=None,
-        seed=0,
-        output_file=None,
-        disable_tqdm=False,
         disable_stream=disable_stream,
-        disable_ignore_eos=False,
-        return_logprob=False,
-        lora_name=None,
-        extra_request_body=None,
-        profile=None,
+        disable_ignore_eos=disable_ignore_eos,
     )

     try:
@@ -589,6 +618,38 @@
     return res


+def run_bench_serving_multi(
+    model,
+    base_url,
+    other_server_args,
+    benchmark_args,
+    need_warmup=False,
+):
+    # Launch the server
+    process = popen_launch_server(
+        model,
+        base_url,
+        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+        other_args=other_server_args,
+    )
+
+    # run benchmark for all
+    res_l = []
+    try:
+        for args in benchmark_args:
+            if need_warmup:
+                warmup_args = copy.deepcopy(args)
+                warmup_args.num_prompts = 16
+                run_benchmark(warmup_args)
+
+            res = run_benchmark(args)
+            res_l.append((args, res))
+    finally:
+        kill_process_tree(process.pid)
+
+    return res_l
+
+
 def run_bench_one_batch(model, other_args):
     command = [
         "python3",
sglang/utils.py CHANGED
@@ -1,7 +1,6 @@
 """Common utilities"""

 import base64
-import gc
 import importlib
 import json
 import logging
@@ -15,7 +14,7 @@ import urllib.request
 from concurrent.futures import ThreadPoolExecutor
 from io import BytesIO
 from json import dumps
-from typing import Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Type, Union

 import numpy as np
 import requests
@@ -363,3 +362,14 @@ def terminate_process(process):
 def print_highlight(html_content: str):
     html_content = str(html_content).replace("\n", "<br>")
     display(HTML(f"<strong style='color: #00008B;'>{html_content}</strong>"))
+
+
+class TypeBasedDispatcher:
+    def __init__(self, mapping: List[Tuple[Type, Callable]]):
+        self._mapping = mapping
+
+    def __call__(self, obj: Any):
+        for ty, fn in self._mapping:
+            if isinstance(obj, ty):
+                return fn(obj)
+        raise ValueError(f"Invalid object: {obj}")