sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/_custom_ops.py +29 -1
  3. sglang/srt/configs/internvl.py +3 -0
  4. sglang/srt/configs/model_config.py +5 -1
  5. sglang/srt/constrained/base_grammar_backend.py +10 -2
  6. sglang/srt/constrained/xgrammar_backend.py +7 -5
  7. sglang/srt/conversation.py +17 -2
  8. sglang/srt/debug_utils/__init__.py +0 -0
  9. sglang/srt/debug_utils/dump_comparator.py +131 -0
  10. sglang/srt/debug_utils/dumper.py +108 -0
  11. sglang/srt/debug_utils/text_comparator.py +172 -0
  12. sglang/srt/disaggregation/common/conn.py +34 -6
  13. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  14. sglang/srt/disaggregation/mini_lb.py +3 -2
  15. sglang/srt/disaggregation/mooncake/conn.py +65 -20
  16. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  17. sglang/srt/disaggregation/nixl/conn.py +17 -13
  18. sglang/srt/disaggregation/prefill.py +13 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  21. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  23. sglang/srt/distributed/parallel_state.py +70 -15
  24. sglang/srt/entrypoints/engine.py +5 -9
  25. sglang/srt/entrypoints/http_server.py +20 -32
  26. sglang/srt/entrypoints/openai/protocol.py +3 -3
  27. sglang/srt/entrypoints/openai/serving_chat.py +148 -72
  28. sglang/srt/function_call/base_format_detector.py +74 -12
  29. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  30. sglang/srt/function_call/ebnf_composer.py +105 -66
  31. sglang/srt/function_call/function_call_parser.py +6 -4
  32. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  33. sglang/srt/function_call/kimik2_detector.py +41 -16
  34. sglang/srt/function_call/llama32_detector.py +6 -3
  35. sglang/srt/function_call/mistral_detector.py +11 -3
  36. sglang/srt/function_call/pythonic_detector.py +16 -14
  37. sglang/srt/function_call/qwen25_detector.py +12 -3
  38. sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
  39. sglang/srt/layers/activation.py +11 -3
  40. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  41. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  42. sglang/srt/layers/attention/vision.py +56 -8
  43. sglang/srt/layers/communicator.py +12 -12
  44. sglang/srt/layers/dp_attention.py +72 -24
  45. sglang/srt/layers/layernorm.py +26 -1
  46. sglang/srt/layers/logits_processor.py +46 -25
  47. sglang/srt/layers/moe/ep_moe/layer.py +172 -206
  48. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
  51. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  52. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  53. sglang/srt/layers/moe/topk.py +88 -34
  54. sglang/srt/layers/multimodal.py +11 -8
  55. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
  56. sglang/srt/layers/quantization/fp8.py +25 -247
  57. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  58. sglang/srt/layers/quantization/modelopt_quant.py +33 -14
  59. sglang/srt/layers/quantization/unquant.py +24 -76
  60. sglang/srt/layers/quantization/utils.py +0 -9
  61. sglang/srt/layers/quantization/w4afp8.py +68 -17
  62. sglang/srt/layers/radix_attention.py +5 -3
  63. sglang/srt/lora/lora_manager.py +133 -169
  64. sglang/srt/lora/lora_registry.py +188 -0
  65. sglang/srt/lora/mem_pool.py +2 -2
  66. sglang/srt/managers/cache_controller.py +62 -13
  67. sglang/srt/managers/io_struct.py +19 -1
  68. sglang/srt/managers/mm_utils.py +154 -35
  69. sglang/srt/managers/multimodal_processor.py +3 -14
  70. sglang/srt/managers/schedule_batch.py +27 -11
  71. sglang/srt/managers/scheduler.py +48 -26
  72. sglang/srt/managers/tokenizer_manager.py +62 -28
  73. sglang/srt/managers/tp_worker.py +5 -4
  74. sglang/srt/mem_cache/allocator.py +67 -7
  75. sglang/srt/mem_cache/hicache_storage.py +17 -1
  76. sglang/srt/mem_cache/hiradix_cache.py +35 -18
  77. sglang/srt/mem_cache/memory_pool_host.py +3 -0
  78. sglang/srt/model_executor/cuda_graph_runner.py +61 -25
  79. sglang/srt/model_executor/forward_batch_info.py +201 -29
  80. sglang/srt/model_executor/model_runner.py +109 -37
  81. sglang/srt/models/deepseek_v2.py +63 -30
  82. sglang/srt/models/glm4_moe.py +1035 -0
  83. sglang/srt/models/glm4_moe_nextn.py +167 -0
  84. sglang/srt/models/interns1.py +328 -0
  85. sglang/srt/models/internvl.py +143 -47
  86. sglang/srt/models/llava.py +9 -5
  87. sglang/srt/models/minicpmo.py +4 -1
  88. sglang/srt/models/mllama4.py +10 -3
  89. sglang/srt/models/qwen2_moe.py +2 -6
  90. sglang/srt/models/qwen3_moe.py +6 -8
  91. sglang/srt/multimodal/processors/base_processor.py +20 -6
  92. sglang/srt/multimodal/processors/clip.py +2 -2
  93. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  94. sglang/srt/multimodal/processors/gemma3.py +2 -2
  95. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  96. sglang/srt/multimodal/processors/internvl.py +21 -8
  97. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  98. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  99. sglang/srt/multimodal/processors/llava.py +4 -4
  100. sglang/srt/multimodal/processors/minicpm.py +2 -3
  101. sglang/srt/multimodal/processors/mlama.py +2 -2
  102. sglang/srt/multimodal/processors/mllama4.py +18 -111
  103. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  104. sglang/srt/multimodal/processors/pixtral.py +2 -2
  105. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  106. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  107. sglang/srt/multimodal/processors/vila.py +3 -1
  108. sglang/srt/reasoning_parser.py +48 -5
  109. sglang/srt/sampling/sampling_batch_info.py +6 -5
  110. sglang/srt/server_args.py +132 -60
  111. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  112. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
  113. sglang/srt/speculative/eagle_utils.py +51 -23
  114. sglang/srt/speculative/eagle_worker.py +59 -44
  115. sglang/srt/two_batch_overlap.py +9 -5
  116. sglang/srt/utils.py +113 -69
  117. sglang/srt/weight_sync/utils.py +119 -0
  118. sglang/test/runners.py +4 -0
  119. sglang/test/test_activation.py +50 -1
  120. sglang/test/test_utils.py +65 -5
  121. sglang/utils.py +19 -0
  122. sglang/version.py +1 -1
  123. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
  124. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
  125. sglang/srt/debug_utils.py +0 -74
  126. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
  127. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
  128. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/utils.py CHANGED
@@ -15,6 +15,7 @@
 
 from __future__ import annotations
 
+import asyncio
 import builtins
 import ctypes
 import dataclasses
@@ -85,6 +86,8 @@ from torch.profiler import ProfilerActivity, profile, record_function
 from torch.utils._contextlib import _DecoratorContextManager
 from triton.runtime.cache import FileCacheManager
 
+from sglang.srt.metrics.func_timer import enable_func_timer
+
 logger = logging.getLogger(__name__)
 
 show_time_cost = False
@@ -744,9 +747,13 @@ def load_image(
         image = Image.open(BytesIO(image_file))
     elif image_file.startswith("http://") or image_file.startswith("https://"):
         timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
-        response = requests.get(image_file, stream=True, timeout=timeout).raw
-        image = Image.open(response)
-        response.close()
+        response = requests.get(image_file, stream=True, timeout=timeout)
+        try:
+            response.raise_for_status()
+            image = Image.open(response.raw)
+            image.load()  # Force loading to avoid issues after closing the stream
+        finally:
+            response.close()
     elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
         image = Image.open(image_file)
     elif image_file.startswith("data:"):
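Note on the `load_image` change above: `PIL.Image.open` is lazy and only parses the header, so closing the HTTP stream before the pixel data is decoded can fail later. A minimal standalone sketch of the same pattern (the `fetch_image` helper is illustrative, not part of sglang):

```python
import requests
from PIL import Image


def fetch_image(url: str, timeout: int = 3) -> Image.Image:
    # Stream the body so the whole file is not buffered up front.
    response = requests.get(url, stream=True, timeout=timeout)
    try:
        response.raise_for_status()       # Surface HTTP errors early.
        image = Image.open(response.raw)  # Lazy: reads only the header.
        image.load()                      # Decode now, while the stream is still open.
    finally:
        response.close()                  # Safe: pixels are already in memory.
    return image
```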
@@ -933,71 +940,6 @@ def monkey_patch_vllm_gguf_config():
     setattr(GGUFConfig, "get_quant_method", get_quant_method_with_embedding_replaced)
 
 
-def maybe_set_triton_cache_manager() -> None:
-    """Set environment variable to tell Triton to use a
-    custom cache manager"""
-    cache_manger = os.environ.get("TRITON_CACHE_MANAGER", None)
-    if cache_manger is None:
-        manager = "sglang.srt.utils:CustomCacheManager"
-        logger.debug("Setting Triton cache manager to: %s", manager)
-        os.environ["TRITON_CACHE_MANAGER"] = manager
-
-
-class CustomCacheManager(FileCacheManager):
-    # Adapted from: https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py
-    def __init__(self, key, override=False, dump=False):
-        from sglang.srt.distributed.parallel_state import get_tp_group
-
-        self.key = key
-        self.lock_path = None
-
-        try:
-            module_path = "triton.runtime.cache"
-            cache_module = importlib.import_module(module_path)
-
-            default_cache_dir = getattr(cache_module, "default_cache_dir", None)
-            default_dump_dir = getattr(cache_module, "default_dump_dir", None)
-            default_override_dir = getattr(cache_module, "default_override_dir", None)
-        except (ModuleNotFoundError, AttributeError) as e:
-            default_cache_dir = None
-            default_dump_dir = None
-            default_override_dir = None
-
-        if dump:
-            self.cache_dir = (
-                default_dump_dir()
-                if default_dump_dir is not None
-                else os.path.join(Path.home(), ".triton", "dump")
-            )
-            self.cache_dir = os.path.join(self.cache_dir, self.key)
-            self.lock_path = os.path.join(self.cache_dir, "lock")
-            os.makedirs(self.cache_dir, exist_ok=True)
-        elif override:
-            self.cache_dir = (
-                default_override_dir()
-                if default_override_dir is not None
-                else os.path.join(Path.home(), ".triton", "override")
-            )
-            self.cache_dir = os.path.join(self.cache_dir, self.key)
-        else:
-            # create cache directory if it doesn't exist
-            self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or (
-                default_cache_dir()
-                if default_cache_dir is not None
-                else os.path.join(Path.home(), ".triton", "cache")
-            )
-            if self.cache_dir:
-                try:
-                    self.cache_dir = f"{self.cache_dir}_{get_tp_group().local_rank}"
-                except:
-                    self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
-                self.cache_dir = os.path.join(self.cache_dir, self.key)
-                self.lock_path = os.path.join(self.cache_dir, "lock")
-                os.makedirs(self.cache_dir, exist_ok=True)
-            else:
-                raise RuntimeError("Could not create or locate cache dir")
-
-
 def set_ulimit(target_soft_limit=65535):
     # number of open files
     resource_type = resource.RLIMIT_NOFILE
@@ -2061,6 +2003,16 @@ def is_valid_ipv6_address(address: str) -> bool:
         return False
 
 
+def maybe_wrap_ipv6_address(address: str) -> str:
+    if is_valid_ipv6_address(address):
+        return f"[{address}]"
+    return address
+
+
+def format_tcp_address(ip: str, port: int) -> str:
+    return f"tcp://{maybe_wrap_ipv6_address(ip)}:{port}"
+
+
 def configure_ipv6(dist_init_addr):
     addr = dist_init_addr
     end = addr.find("]")
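The two helpers above implement the standard URL bracket convention for IPv6 literals (RFC 3986). A quick usage sketch:

```python
# IPv6 literals must be bracketed inside an authority; IPv4 and hostnames pass through.
format_tcp_address("::1", 30000)       # -> "tcp://[::1]:30000"
format_tcp_address("10.0.0.1", 30000)  # -> "tcp://10.0.0.1:30000"
```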
@@ -2100,7 +2052,7 @@ def rank0_log(msg: str):
         logger.info(msg)
 
 
-def launch_dummy_health_check_server(host, port):
+def launch_dummy_health_check_server(host, port, enable_metrics):
     import asyncio
 
     import uvicorn
@@ -2118,6 +2070,11 @@ def launch_dummy_health_check_server(host, port):
         """Check the health of the http server."""
         return Response(status_code=200)
 
+    # Add prometheus middleware
+    if enable_metrics:
+        add_prometheus_middleware(app)
+        enable_func_timer()
+
     config = uvicorn.Config(
         app,
         host=host,
@@ -2386,6 +2343,7 @@ def is_fa3_default_architecture(hf_config):
         "Gemma3ForConditionalGeneration",
         "Qwen3ForCausalLM",
         "Qwen3MoeForCausalLM",
+        "Glm4MoeForCausalLM",
     }
     return architectures[0] in default_archs
 
@@ -2906,3 +2864,89 @@ SUPPORTED_LORA_TARGET_MODULES = [
 ]
 
 LORA_TARGET_ALL_MODULES = "all"
+
+
+class ConcurrentCounter:
+    """
+    An asynchronous counter for managing concurrent tasks that need
+    coordinated increments, decrements, and waiting until the count reaches zero.
+
+    This class is useful for scenarios like tracking the number of in-flight tasks
+    and waiting for them to complete.
+    """
+
+    def __init__(self, initial: int = 0):
+        """
+        Initialize the counter with an optional initial value.
+
+        Args:
+            initial (int): The initial value of the counter. Default is 0.
+        """
+        self._count = initial
+        self._condition = asyncio.Condition()
+
+    def value(self) -> int:
+        """
+        Return the current value of the counter.
+
+        Note:
+            This method is not synchronized. It may return a stale value
+            if other coroutines are concurrently modifying the counter.
+
+        Returns:
+            int: The current counter value.
+        """
+        return self._count
+
+    def __repr__(self) -> str:
+        """Return an informative string representation of the counter."""
+        return f"<ConcurrentCounter value={self.value()}>"
+
+    async def increment(self, n: int = 1, notify_all: bool = True):
+        """
+        Atomically increment the counter by a given amount and notify all waiters.
+
+        Args:
+            n (int): The amount to increment the counter by. Default is 1.
+            notify_all (bool): Whether to notify all waiters after incrementing. Default is True.
+        """
+        async with self._condition:
+            self._count += n
+            if notify_all:
+                self._condition.notify_all()
+
+    async def decrement(self, n: int = 1, notify_all: bool = True):
+        """
+        Atomically decrement the counter by a given amount and notify all waiters.
+
+        Args:
+            n (int): The amount to decrement the counter by. Default is 1.
+            notify_all (bool): Whether to notify all waiters after decrementing. Default is True.
+        """
+        async with self._condition:
+            self._count -= n
+            if notify_all:
+                self._condition.notify_all()
+
+    async def wait_for(self, condition: Callable[[int], bool]):
+        """
+        Asynchronously wait until the counter satisfies a given condition.
+
+        This suspends the calling coroutine without blocking the thread, allowing
+        other tasks to run while waiting. When the condition is met, the coroutine resumes.
+
+        Args:
+            condition (Callable[[int], bool]): A function that takes the current counter value
+                and returns True when the condition is satisfied.
+        """
+        async with self._condition:
+            await self._condition.wait_for(lambda: condition(self._count))
+
+    async def wait_for_zero(self):
+        """
+        Asynchronously wait until the counter reaches zero.
+
+        This suspends the calling coroutine without blocking the thread, allowing
+        other tasks to run while waiting. When the counter becomes zero, the coroutine resumes.
+        """
+        await self.wait_for(lambda count: count == 0)
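A minimal usage sketch for `ConcurrentCounter` (the `worker` coroutine is hypothetical): each task increments on entry and decrements on exit, and a coordinator awaits zero before proceeding:

```python
import asyncio


async def worker(counter: ConcurrentCounter):
    await counter.increment()
    try:
        await asyncio.sleep(0.1)  # Stand-in for real work.
    finally:
        await counter.decrement()


async def main():
    counter = ConcurrentCounter()
    tasks = [asyncio.create_task(worker(counter)) for _ in range(8)]
    await asyncio.gather(*tasks)
    await counter.wait_for_zero()  # Returns once every decrement has landed.


asyncio.run(main())
```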
sglang/srt/weight_sync/utils.py ADDED
@@ -0,0 +1,119 @@
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor import DTensor
+
+from sglang.srt.entrypoints.engine import Engine
+from sglang.srt.managers.tokenizer_manager import UpdateWeightsFromTensorReqInput
+from sglang.srt.model_executor.model_runner import LocalSerializedTensor
+from sglang.srt.utils import MultiprocessingSerializer
+
+
+async def update_weights(
+    engine: Engine,
+    params_batch: list[tuple[str, torch.Tensor]],
+    device_mesh_key: str,
+    device_mesh: DeviceMesh,
+    load_format: Optional[str] = None,
+):
+    """
+    Update weights for the inference engine.
+    This function is designed to be stateless, so that the caller process can keep the stateful engine.
+
+    Example use case:
+        - Multiple producer processes call this function in an SPMD style.
+
+    Args:
+        engine: The inference engine created by the caller process.
+        params_batch: A list of (name, tensor) tuples. The tensors are batched to avoid per-tensor CPU call overhead.
+        device_mesh_key: The key of the device mesh. Typically "tp" or "infer_tp".
+        device_mesh: The device mesh.
+        load_format: The format of the weights.
+    """
+    infer_tp_size = device_mesh[device_mesh_key].mesh.size()[0]
+    infer_tp_rank = device_mesh[device_mesh_key].get_local_rank()
+    from sglang.srt.patch_torch import monkey_patch_torch_reductions
+
+    monkey_patch_torch_reductions()
+
+    # [
+    #     (name0, ipc_tensor0_tp0),
+    #     (name1, ipc_tensor1_tp0),
+    # ]
+    named_tensors_batch = [
+        (
+            name,
+            MultiprocessingSerializer.serialize(
+                _preprocess_tensor_for_update_weights(tensor)
+            ),
+        )
+        for name, tensor in params_batch
+    ]
+
+    if infer_tp_rank == 0:
+        gathered_serialized_batches = [None for _ in range(infer_tp_size)]
+    else:
+        gathered_serialized_batches = None
+
+    # [
+    #     [ (name0, ipc_tensor0_tp0), (name1, ipc_tensor1_tp0) ],
+    #     [ (name0, ipc_tensor0_tp1), (name1, ipc_tensor1_tp1) ],
+    # ]
+    dist.gather_object(
+        obj=named_tensors_batch,
+        object_gather_list=gathered_serialized_batches,
+        dst=device_mesh[device_mesh_key].mesh.tolist()[0],
+        group=device_mesh[device_mesh_key].get_group(),
+    )
+
+    if infer_tp_rank == 0:
+        # Use zip(*) to "transpose" the data structure.
+        # After the transpose, the data structure looks like:
+        # [
+        #     ( (name0, ipc_tensor0_tp0), (name0, ipc_tensor0_tp1) ),
+        #     ( (name1, ipc_tensor1_tp0), (name1, ipc_tensor1_tp1) ),
+        # ]
+        logical_tensors = zip(*gathered_serialized_batches, strict=True)
+
+        named_tensors = [
+            # [
+            #     (name0, LocalSerializedTensor(values=[ipc_tensor0_tp0, ipc_tensor0_tp1])),
+            #     (name1, LocalSerializedTensor(values=[ipc_tensor1_tp0, ipc_tensor1_tp1])),
+            # ]
+            (
+                tensor_group[0][0],
+                LocalSerializedTensor(
+                    values=[rank_part[1] for rank_part in tensor_group]
+                ),
+            )
+            for tensor_group in logical_tensors
+        ]
+
+        update_weights_request = UpdateWeightsFromTensorReqInput(
+            serialized_named_tensors=[
+                MultiprocessingSerializer.serialize(named_tensors)
+                for _ in range(infer_tp_size)
+            ],
+            load_format=load_format,
+        )
+
+        return await engine.update_weights_from_tensor(update_weights_request)
+
+
+def _preprocess_tensor_for_update_weights(tensor: torch.Tensor):
+    """
+    Preprocess the tensor for the weight update.
+
+    Example use case:
+        - FSDP: gather the tensor by calling full_tensor here.
+        - Megatron: do nothing, assuming the tensor is already gathered before it is fed into this function.
+
+    Args:
+        tensor: The tensor to be preprocessed.
+
+    Returns:
+        The full tensor if it is a DTensor, otherwise the original tensor.
+    """
+    if isinstance(tensor, DTensor):
+        return tensor.full_tensor()
+    return tensor
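A hedged sketch of how a trainer process might drive the new `update_weights` helper; the mesh dimension name `"infer_tp"`, the 4-GPU mesh shape, and `push_weights` itself are illustrative assumptions, not part of the module:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh

from sglang.srt.entrypoints.engine import Engine
from sglang.srt.weight_sync.utils import update_weights


async def push_weights(engine: Engine, state_dict: dict[str, torch.Tensor]):
    # One flat 4-GPU mesh; the dimension name must match device_mesh_key below.
    mesh = init_device_mesh("cuda", (4,), mesh_dim_names=("infer_tp",))
    await update_weights(
        engine=engine,
        params_batch=list(state_dict.items()),
        device_mesh_key="infer_tp",
        device_mesh=mesh,
    )
```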
sglang/test/runners.py CHANGED
@@ -491,6 +491,8 @@ class SRTRunner:
         lora_paths: List[str] = None,
         max_loras_per_batch: int = 4,
         attention_backend: Optional[str] = None,
+        prefill_attention_backend: Optional[str] = None,
+        decode_attention_backend: Optional[str] = None,
         lora_backend: str = "triton",
         disable_cuda_graph: bool = False,
         disable_radix_cache: bool = False,
@@ -540,6 +542,8 @@ class SRTRunner:
             max_loras_per_batch=max_loras_per_batch,
             lora_backend=lora_backend,
             attention_backend=attention_backend,
+            prefill_attention_backend=prefill_attention_backend,
+            decode_attention_backend=decode_attention_backend,
             disable_cuda_graph=disable_cuda_graph,
             disable_radix_cache=disable_radix_cache,
             chunked_prefill_size=chunked_prefill_size,
sglang/test/test_activation.py CHANGED
@@ -3,9 +3,12 @@ import unittest
 
 import torch
 
-from sglang.srt.layers.activation import GeluAndMul
+from sglang.srt.layers.activation import GeluAndMul, QuickGELU
+from sglang.srt.utils import is_hip
 from sglang.test.test_utils import CustomTestCase
 
+_is_hip = is_hip()
+
 
 class TestGeluAndMul(CustomTestCase):
     DTYPES = [torch.half, torch.bfloat16]
@@ -52,5 +55,51 @@ class TestGeluAndMul(CustomTestCase):
                 self._run_gelu_and_mul_test(*params)
 
 
+class TestQuickGELU(CustomTestCase):
+    DTYPES = [torch.half, torch.bfloat16]
+    NUM_TOKENS = [7, 83, 2048]  # batch = sequence length
+    DIMS = [512, 4096, 5120, 13824]  # all multiples of 16 bytes
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _run_gelu_quick_test(self, n_tok: int, dim: int, dtype: torch.dtype, seed: int):
+        torch.manual_seed(seed)
+
+        layer = QuickGELU().to(dtype=dtype)
+
+        x = torch.randn(n_tok, dim, dtype=dtype, device="cuda")
+
+        with torch.inference_mode():
+            ref = layer.forward_native(x)  # x * sigmoid(1.702 * x), fp32 math
+            if _is_hip:
+                out = layer.forward_hip(x)  # 128-bit vectorised kernel from sgl-kernel
+            else:
+                out = layer.forward_cuda(x)
+
+        tol = 1e-2 if dtype is torch.bfloat16 else 1e-3
+        self.assertTrue(
+            torch.allclose(out, ref, atol=tol, rtol=tol),
+            msg=f"Mismatch @ B={n_tok}, D={dim}, dtype={dtype}",
+        )
+        print(f"Match @ B={n_tok}, D={dim}, dtype={dtype}")
+
+    def test_quick_gelu(self):
+        for params in itertools.product(
+            self.NUM_TOKENS, self.DIMS, self.DTYPES, self.SEEDS
+        ):
+            with self.subTest(
+                num_tokens=params[0],
+                dim=params[1],
+                dtype=params[2],
+                seed=params[3],
+            ):
+                self._run_gelu_quick_test(*params)
+
+
 if __name__ == "__main__":
     unittest.main(verbosity=2)
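For reference, `QuickGELU` is the sigmoid approximation of GELU popularized by CLIP-style vision towers. A plain-PyTorch sketch of the reference math this test compares against (assuming `forward_native` evaluates in fp32, as the inline comment above states):

```python
import torch


def quick_gelu_ref(x: torch.Tensor) -> torch.Tensor:
    # x * sigmoid(1.702 * x), evaluated in fp32 for a stable reference.
    x32 = x.float()
    return (x32 * torch.sigmoid(1.702 * x32)).to(x.dtype)
```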
sglang/test/test_utils.py CHANGED
@@ -1,6 +1,7 @@
 """Common utilities for testing and benchmarking"""
 
 import argparse
+import asyncio
 import copy
 import json
 import logging
@@ -14,8 +15,9 @@ import unittest
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
 from functools import partial
+from pathlib import Path
 from types import SimpleNamespace
-from typing import Callable, List, Optional, Tuple
+from typing import Awaitable, Callable, List, Optional, Tuple
 
 import numpy as np
 import requests
@@ -26,6 +28,7 @@ from sglang.bench_serving import run_benchmark
 from sglang.global_config import global_config
 from sglang.lang.backend.openai import OpenAI
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
+from sglang.lang.interpreter import ProgramState
 from sglang.srt.utils import (
     get_bool_env_var,
     get_device,
@@ -347,6 +350,7 @@ def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):
         help="Device type (auto/cuda/rocm/cpu). Auto will detect available platforms",
     )
     parser.add_argument("--result-file", type=str, default="result.jsonl")
+    parser.add_argument("--raw-result-file", type=str)
     args = parser.parse_args()
 
     return args
@@ -714,6 +718,7 @@ def get_benchmark_args(
     seed: int = 0,
     device="auto",
     pd_separated: bool = False,
+    lora_name=None,
 ):
     return SimpleNamespace(
         backend="sglang",
@@ -741,7 +746,7 @@ def get_benchmark_args(
         extra_request_body=None,
         apply_chat_template=False,
         profile=None,
-        lora_name=None,
+        lora_name=lora_name,
         prompt_suffix="",
         device=device,
         pd_separated=pd_separated,
@@ -764,6 +769,8 @@ def run_bench_serving(
     need_warmup=False,
     seed: int = 0,
     device="auto",
+    background_task: Optional[Callable[[str, asyncio.Event, asyncio.Event], Awaitable[None]]] = None,
+    lora_name: Optional[str] = None,
 ):
     if device == "auto":
         device = auto_config_device()
@@ -791,14 +798,35 @@
         disable_ignore_eos=disable_ignore_eos,
         seed=seed,
         device=device,
+        lora_name=lora_name,
     )
 
-    try:
+    async def _run():
         if need_warmup:
             warmup_args = copy.deepcopy(args)
             warmup_args.num_prompts = 16
-            run_benchmark(warmup_args)
-        res = run_benchmark(args)
+            await asyncio.to_thread(run_benchmark, warmup_args)
+
+        start_event = asyncio.Event()
+        stop_event = asyncio.Event()
+        task_handle = (
+            asyncio.create_task(background_task(base_url, start_event, stop_event))
+            if background_task
+            else None
+        )
+
+        try:
+            start_event.set()
+            result = await asyncio.to_thread(run_benchmark, args)
+        finally:
+            if task_handle:
+                stop_event.set()
+                await task_handle
+
+        return result
+
+    try:
+        res = asyncio.run(_run())
     finally:
         kill_process_tree(process.pid)
 
@@ -1284,3 +1312,35 @@ class CustomTestCase(unittest.TestCase):
             lambda: super(CustomTestCase, self)._callTestMethod(method),
             max_retry=max_retry,
         )
+
+
+def dump_bench_raw_result(
+    path: str,
+    states,
+    preds,
+    labels,
+):
+    if not path:
+        return
+
+    rows = []
+    for i in range(len(states)):
+        state = states[i]
+        output = state["answer"]
+        prompt = _ensure_remove_suffix(state.text(), output)
+        rows.append(
+            dict(
+                prompt_id=i,
+                prompt=prompt,
+                output=output,
+                correct=bool(preds[i] == labels[i]),
+            )
+        )
+
+    print(f"BenchRawResultDumper save results to {path}")
+    Path(path).write_text("\n".join(json.dumps(row) for row in rows))
+
+
+def _ensure_remove_suffix(text: str, suffix: str):
+    assert text.endswith(suffix)
+    return text.removesuffix(suffix)
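`dump_bench_raw_result` writes one JSON object per line, so reading it back takes two lines (the file name is whatever was passed to `--raw-result-file`; `raw_result.jsonl` below is a placeholder):

```python
import json
from pathlib import Path

rows = [json.loads(line) for line in Path("raw_result.jsonl").read_text().splitlines()]
accuracy = sum(row["correct"] for row in rows) / len(rows)
```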
sglang/utils.py CHANGED
@@ -14,6 +14,7 @@ import traceback
 import urllib.request
 import weakref
 from concurrent.futures import ThreadPoolExecutor
+from functools import wraps
 from io import BytesIO
 from json import dumps
 from typing import Any, Callable, List, Optional, Tuple, Type, Union
@@ -28,6 +29,24 @@ from tqdm import tqdm
 logger = logging.getLogger(__name__)
 
 
+def execute_once(func):
+    has_run = None
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        nonlocal has_run
+        if not has_run:
+            func(*args, **kwargs)
+            has_run = True
+
+    return wrapper
+
+
+@execute_once
+def info_once(message: str):
+    logger.info(message)
+
+
 def convert_json_schema_to_str(json_schema: Union[dict, str, Type[BaseModel]]) -> str:
     """Convert a JSON schema to a string.
     Parameters
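`execute_once` turns a function into a run-at-most-once function, which `info_once` uses to de-duplicate log lines. Note the guard is per-function, not per-message: after the first call, later calls are skipped even if the message differs. A quick sketch (the message string is illustrative):

```python
from sglang.utils import info_once

for _ in range(3):
    info_once("attention backend selected")  # Logged only on the first iteration.
```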
sglang/version.py CHANGED
@@ -1 +1 @@
-__version__ = "0.4.9.post3"
+__version__ = "0.4.9.post5"
{sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sglang
-Version: 0.4.9.post3
+Version: 0.4.9.post5
 Summary: SGLang is yet another fast serving framework for large language models and vision language models.
 License: Apache License
                                  Version 2.0, January 2004
@@ -246,20 +246,20 @@ Requires-Dist: sentencepiece; extra == "runtime-common"
 Requires-Dist: soundfile==0.13.1; extra == "runtime-common"
 Requires-Dist: scipy; extra == "runtime-common"
 Requires-Dist: torchao==0.9.0; extra == "runtime-common"
-Requires-Dist: transformers==4.53.2; extra == "runtime-common"
+Requires-Dist: transformers==4.54.0; extra == "runtime-common"
 Requires-Dist: timm==1.0.16; extra == "runtime-common"
 Requires-Dist: uvicorn; extra == "runtime-common"
 Requires-Dist: uvloop; extra == "runtime-common"
 Requires-Dist: xgrammar==0.1.21; extra == "runtime-common"
 Provides-Extra: srt
 Requires-Dist: sglang[runtime_common]; extra == "srt"
-Requires-Dist: sgl-kernel==0.2.6.post1; extra == "srt"
+Requires-Dist: sgl-kernel==0.2.7; extra == "srt"
 Requires-Dist: torch==2.7.1; extra == "srt"
 Requires-Dist: torchaudio==2.7.1; extra == "srt"
 Requires-Dist: torchvision==0.22.1; extra == "srt"
 Requires-Dist: cuda-python; extra == "srt"
 Requires-Dist: einops; extra == "srt"
-Requires-Dist: flashinfer_python==0.2.7.post1; extra == "srt"
+Requires-Dist: flashinfer_python==0.2.9rc2; extra == "srt"
 Provides-Extra: blackwell
 Requires-Dist: sglang[runtime_common]; extra == "blackwell"
 Requires-Dist: sgl-kernel; extra == "blackwell"
@@ -268,11 +268,11 @@ Requires-Dist: torchaudio==2.7.1; extra == "blackwell"
 Requires-Dist: torchvision==0.22.1; extra == "blackwell"
 Requires-Dist: cuda-python; extra == "blackwell"
 Requires-Dist: einops; extra == "blackwell"
-Requires-Dist: flashinfer_python==0.2.7.post1; extra == "blackwell"
+Requires-Dist: flashinfer_python==0.2.9rc2; extra == "blackwell"
 Provides-Extra: srt-hip
 Requires-Dist: sglang[runtime_common]; extra == "srt-hip"
 Requires-Dist: torch; extra == "srt-hip"
-Requires-Dist: petit_kernel; extra == "srt-hip"
+Requires-Dist: petit_kernel==0.0.2; extra == "srt-hip"
 Provides-Extra: srt-xpu
 Requires-Dist: sglang[runtime_common]; extra == "srt-xpu"
 Provides-Extra: srt-hpu