sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +5 -1
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +17 -2
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +65 -20
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +70 -15
- sglang/srt/entrypoints/engine.py +5 -9
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +148 -72
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +105 -66
- sglang/srt/function_call/function_call_parser.py +6 -4
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
- sglang/srt/layers/activation.py +11 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +46 -25
- sglang/srt/layers/moe/ep_moe/layer.py +172 -206
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/topk.py +88 -34
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +33 -14
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/utils.py +0 -9
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/lora/lora_manager.py +133 -169
- sglang/srt/lora/lora_registry.py +188 -0
- sglang/srt/lora/mem_pool.py +2 -2
- sglang/srt/managers/cache_controller.py +62 -13
- sglang/srt/managers/io_struct.py +19 -1
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +27 -11
- sglang/srt/managers/scheduler.py +48 -26
- sglang/srt/managers/tokenizer_manager.py +62 -28
- sglang/srt/managers/tp_worker.py +5 -4
- sglang/srt/mem_cache/allocator.py +67 -7
- sglang/srt/mem_cache/hicache_storage.py +17 -1
- sglang/srt/mem_cache/hiradix_cache.py +35 -18
- sglang/srt/mem_cache/memory_pool_host.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +61 -25
- sglang/srt/model_executor/forward_batch_info.py +201 -29
- sglang/srt/model_executor/model_runner.py +109 -37
- sglang/srt/models/deepseek_v2.py +63 -30
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/mllama4.py +10 -3
- sglang/srt/models/qwen2_moe.py +2 -6
- sglang/srt/models/qwen3_moe.py +6 -8
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/reasoning_parser.py +48 -5
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/server_args.py +132 -60
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils.py +113 -69
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_activation.py +50 -1
- sglang/test/test_utils.py +65 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/utils.py
CHANGED
@@ -15,6 +15,7 @@

 from __future__ import annotations

+import asyncio
 import builtins
 import ctypes
 import dataclasses
@@ -85,6 +86,8 @@ from torch.profiler import ProfilerActivity, profile, record_function
 from torch.utils._contextlib import _DecoratorContextManager
 from triton.runtime.cache import FileCacheManager

+from sglang.srt.metrics.func_timer import enable_func_timer
+
 logger = logging.getLogger(__name__)

 show_time_cost = False
@@ -744,9 +747,13 @@ def load_image(
         image = Image.open(BytesIO(image_file))
     elif image_file.startswith("http://") or image_file.startswith("https://"):
         timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
-        response = requests.get(image_file, stream=True, timeout=timeout)
-        response.raise_for_status()
-        image = Image.open(response.raw)
+        response = requests.get(image_file, stream=True, timeout=timeout)
+        try:
+            response.raise_for_status()
+            image = Image.open(response.raw)
+            image.load()  # Force loading to avoid issues after closing the stream
+        finally:
+            response.close()
     elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
         image = Image.open(image_file)
     elif image_file.startswith("data:"):
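Note: the try/finally plus image.load() is needed because PIL decodes lazily. Image.open only parses the header, so without an explicit load() the pixel data would be read from response.raw after the stream is closed. A minimal illustration of the failure mode this guards against (hypothetical URL, not package code):

    import requests
    from PIL import Image

    resp = requests.get("https://example.com/cat.png", stream=True, timeout=3)
    img = Image.open(resp.raw)  # lazy: only the image header is read here
    resp.close()
    img.convert("RGB")  # may fail: pixel data is fetched from a closed stream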
@@ -933,71 +940,6 @@ def monkey_patch_vllm_gguf_config():
     setattr(GGUFConfig, "get_quant_method", get_quant_method_with_embedding_replaced)


-def maybe_set_triton_cache_manager() -> None:
-    """Set environment variable to tell Triton to use a
-    custom cache manager"""
-    cache_manger = os.environ.get("TRITON_CACHE_MANAGER", None)
-    if cache_manger is None:
-        manager = "sglang.srt.utils:CustomCacheManager"
-        logger.debug("Setting Triton cache manager to: %s", manager)
-        os.environ["TRITON_CACHE_MANAGER"] = manager
-
-
-class CustomCacheManager(FileCacheManager):
-    # Adapted from: https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py
-    def __init__(self, key, override=False, dump=False):
-        from sglang.srt.distributed.parallel_state import get_tp_group
-
-        self.key = key
-        self.lock_path = None
-
-        try:
-            module_path = "triton.runtime.cache"
-            cache_module = importlib.import_module(module_path)
-
-            default_cache_dir = getattr(cache_module, "default_cache_dir", None)
-            default_dump_dir = getattr(cache_module, "default_dump_dir", None)
-            default_override_dir = getattr(cache_module, "default_override_dir", None)
-        except (ModuleNotFoundError, AttributeError) as e:
-            default_cache_dir = None
-            default_dump_dir = None
-            default_override_dir = None
-
-        if dump:
-            self.cache_dir = (
-                default_dump_dir()
-                if default_dump_dir is not None
-                else os.path.join(Path.home(), ".triton", "dump")
-            )
-            self.cache_dir = os.path.join(self.cache_dir, self.key)
-            self.lock_path = os.path.join(self.cache_dir, "lock")
-            os.makedirs(self.cache_dir, exist_ok=True)
-        elif override:
-            self.cache_dir = (
-                default_override_dir()
-                if default_override_dir is not None
-                else os.path.join(Path.home(), ".triton", "override")
-            )
-            self.cache_dir = os.path.join(self.cache_dir, self.key)
-        else:
-            # create cache directory if it doesn't exist
-            self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or (
-                default_cache_dir()
-                if default_cache_dir is not None
-                else os.path.join(Path.home(), ".triton", "cache")
-            )
-            if self.cache_dir:
-                try:
-                    self.cache_dir = f"{self.cache_dir}_{get_tp_group().local_rank}"
-                except:
-                    self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
-                self.cache_dir = os.path.join(self.cache_dir, self.key)
-                self.lock_path = os.path.join(self.cache_dir, "lock")
-                os.makedirs(self.cache_dir, exist_ok=True)
-            else:
-                raise RuntimeError("Could not create or locate cache dir")
-
-
 def set_ulimit(target_soft_limit=65535):
     # number of open files
     resource_type = resource.RLIMIT_NOFILE
@@ -2061,6 +2003,16 @@ def is_valid_ipv6_address(address: str) -> bool:
         return False


+def maybe_wrap_ipv6_address(address: str) -> str:
+    if is_valid_ipv6_address(address):
+        return f"[{address}]"
+    return address
+
+
+def format_tcp_address(ip: str, port: int) -> str:
+    return f"tcp://{maybe_wrap_ipv6_address(ip)}:{port}"
+
+
 def configure_ipv6(dist_init_addr):
     addr = dist_init_addr
     end = addr.find("]")
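The two helpers above keep TCP endpoint strings valid when the host is an IPv6 literal, which must be bracket-wrapped inside a URL. Expected behavior, following the definitions in the hunk (illustrative values):

    from sglang.srt.utils import format_tcp_address, maybe_wrap_ipv6_address

    maybe_wrap_ipv6_address("10.0.0.1")  # -> "10.0.0.1" (IPv4 passes through)
    maybe_wrap_ipv6_address("fe80::1")   # -> "[fe80::1]"
    format_tcp_address("::1", 30000)     # -> "tcp://[::1]:30000"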
@@ -2100,7 +2052,7 @@ def rank0_log(msg: str):
         logger.info(msg)


-def launch_dummy_health_check_server(host, port):
+def launch_dummy_health_check_server(host, port, enable_metrics):
     import asyncio

     import uvicorn
@@ -2118,6 +2070,11 @@ def launch_dummy_health_check_server(host, port):
         """Check the health of the http server."""
         return Response(status_code=200)

+    # Add prometheus middleware
+    if enable_metrics:
+        add_prometheus_middleware(app)
+        enable_func_timer()
+
     config = uvicorn.Config(
         app,
         host=host,
@@ -2386,6 +2343,7 @@ def is_fa3_default_architecture(hf_config):
         "Gemma3ForConditionalGeneration",
         "Qwen3ForCausalLM",
         "Qwen3MoeForCausalLM",
+        "Glm4MoeForCausalLM",
     }
     return architectures[0] in default_archs

@@ -2906,3 +2864,89 @@ SUPPORTED_LORA_TARGET_MODULES = [
 ]

 LORA_TARGET_ALL_MODULES = "all"
+
+
+class ConcurrentCounter:
+    """
+    An asynchronous counter for managing concurrent tasks that need
+    coordinated increments, decrements, and waiting until the count reaches zero.
+
+    This class is useful for scenarios like tracking the number of in-flight tasks
+    and waiting for them to complete.
+    """
+
+    def __init__(self, initial: int = 0):
+        """
+        Initialize the counter with an optional initial value.
+
+        Args:
+            initial (int): The initial value of the counter. Default is 0.
+        """
+        self._count = initial
+        self._condition = asyncio.Condition()
+
+    def value(self) -> int:
+        """
+        Return the current value of the counter.
+
+        Note:
+            This method is not synchronized. It may return a stale value
+            if other coroutines are concurrently modifying the counter.
+
+        Returns:
+            int: The current counter value.
+        """
+        return self._count
+
+    def __repr__(self) -> str:
+        """Return an informative string representation of the counter."""
+        return f"<ConcurrentCounter value={self.value()}>"
+
+    async def increment(self, n: int = 1, notify_all: bool = True):
+        """
+        Atomically increment the counter by a given amount and notify all waiters.
+
+        Args:
+            n (int): The amount to increment the counter by. Default is 1.
+            notify_all (bool): Whether to notify all waiters after incrementing. Default is True.
+        """
+        async with self._condition:
+            self._count += n
+            if notify_all:
+                self._condition.notify_all()
+
+    async def decrement(self, n: int = 1, notify_all: bool = True):
+        """
+        Atomically decrement the counter by a given amount and notify all waiters.
+
+        Args:
+            n (int): The amount to decrement the counter by. Default is 1.
+            notify_all (bool): Whether to notify all waiters after decrementing. Default is True.
+        """
+        async with self._condition:
+            self._count -= n
+            if notify_all:
+                self._condition.notify_all()
+
+    async def wait_for(self, condition: Callable[[int], bool]):
+        """
+        Asynchronously wait until the counter satisfies a given condition.
+
+        This suspends the calling coroutine without blocking the thread, allowing
+        other tasks to run while waiting. When the condition is met, the coroutine resumes.
+
+        Args:
+            condition (Callable[[int], bool]): A function that takes the current counter value
+                and returns True when the condition is satisfied.
+        """
+        async with self._condition:
+            await self._condition.wait_for(lambda: condition(self._count))
+
+    async def wait_for_zero(self):
+        """
+        Asynchronously wait until the counter reaches zero.

+        This suspends the calling coroutine without blocking the thread, allowing
+        other tasks to run while waiting. When the counter becomes zero, the coroutine resumes.
+        """
+        await self.wait_for(lambda count: count == 0)
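A minimal usage sketch for the new ConcurrentCounter (the worker/coordinator names are illustrative, not part of the package). Reserving the count up front keeps wait_for_zero from waking early, since the predicate is already true at an initial count of zero:

    import asyncio

    async def worker(counter: ConcurrentCounter):
        try:
            await asyncio.sleep(0.1)  # simulated in-flight work
        finally:
            await counter.decrement()  # release one slot when done

    async def main():
        counter = ConcurrentCounter()
        await counter.increment(8)  # reserve slots before spawning workers
        tasks = [asyncio.create_task(worker(counter)) for _ in range(8)]
        await counter.wait_for_zero()  # resumes once all eight have decremented
        await asyncio.gather(*tasks)

    asyncio.run(main())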
sglang/srt/weight_sync/utils.py
ADDED
@@ -0,0 +1,119 @@
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor import DTensor
+
+from sglang.srt.entrypoints.engine import Engine
+from sglang.srt.managers.tokenizer_manager import UpdateWeightsFromTensorReqInput
+from sglang.srt.model_executor.model_runner import LocalSerializedTensor
+from sglang.srt.utils import MultiprocessingSerializer
+
+
+async def update_weights(
+    engine: Engine,
+    params_batch: list[tuple[str, torch.Tensor]],
+    device_mesh_key: str,
+    device_mesh: DeviceMesh,
+    load_format: Optional[str] = None,
+):
+    """
+    Update weights for the inference engine.
+    This function is designed to be stateless, so that the caller process could keep the stateful engine.
+    Example Use Case:
+        - Multiple Producer Process will call this function in a SPMD style
+
+    Args:
+        engine: The inference engine created by the caller process.
+        params_batch: A list of (name, tensor) tuples. We batched the tensors to avoid the overhead of cpu call.
+        device_mesh_key: The key of the device mesh. Typically "tp" or "infer_tp"
+        device_mesh: The device mesh.
+        load_format: The format of the weights.
+    """
+    infer_tp_size = device_mesh[device_mesh_key].mesh.size()[0]
+    infer_tp_rank = device_mesh[device_mesh_key].get_local_rank()
+    from sglang.srt.patch_torch import monkey_patch_torch_reductions
+
+    monkey_patch_torch_reductions()
+
+    # [
+    #   (name0, ipc_tensor0_tp0),
+    #   (name1, ipc_tensor1_tp0),
+    # ]
+    named_tensors_batch = [
+        (
+            name,
+            MultiprocessingSerializer.serialize(
+                _preprocess_tensor_for_update_weights(tensor)
+            ),
+        )
+        for name, tensor in params_batch
+    ]
+
+    if infer_tp_rank == 0:
+        gathered_serialized_batches = [None for _ in range(infer_tp_size)]
+    else:
+        gathered_serialized_batches = None
+
+    # [
+    #   [ (name0, ipc_tensor0_tp0), (name1, ipc_tensor1_tp0) ],
+    #   [ (name0, ipc_tensor0_tp1), (name1, ipc_tensor1_tp1) ],
+    # ]
+    dist.gather_object(
+        obj=named_tensors_batch,
+        object_gather_list=gathered_serialized_batches,
+        dst=device_mesh[device_mesh_key].mesh.tolist()[0],
+        group=device_mesh[device_mesh_key].get_group(),
+    )
+
+    if infer_tp_rank == 0:
+        # Use zip(*) to "transpose" the data structure.
+        # After transpose, the data structure is like:
+        # [
+        #   ( (name0, ipc_tensor0_tp0), (name0, ipc_tensor0_tp1) ),
+        #   ( (name1, ipc_tensor1_tp0), (name1, ipc_tensor1_tp1) ),
+        # ]
+        logical_tensors = zip(*gathered_serialized_batches, strict=True)
+
+        named_tensors = [
+            # [
+            #   (name0, LocalSerializedTensor(values=[ipc_tensor0_tp0, ipc_tensor0_tp1])),
+            #   (name1, LocalSerializedTensor(values=[ipc_tensor1_tp0, ipc_tensor1_tp1])),
+            # ]
+            (
+                tensor_group[0][0],
+                LocalSerializedTensor(
+                    values=[rank_part[1] for rank_part in tensor_group]
+                ),
+            )
+            for tensor_group in logical_tensors
+        ]
+
+        update_weights_request = UpdateWeightsFromTensorReqInput(
+            serialized_named_tensors=[
+                MultiprocessingSerializer.serialize(named_tensors)
+                for _ in range(infer_tp_size)
+            ],
+            load_format=load_format,
+        )
+
+        return await engine.update_weights_from_tensor(update_weights_request)
+
+
+def _preprocess_tensor_for_update_weights(tensor: torch.Tensor):
+    """
+    Preprocess the tensor for update weights.
+    Example Use Case:
+        - FSDP: we gather tensor by calling full_tensor in _preprocess_tensor_for_update_weights
+        - Megatron: we do nothing here, assuming it is gathered when feed into this func
+
+    Args:
+        tensor: The tensor to be preprocessed.
+
+    Returns:
+        The full tensor if it is a DTensor, otherwise the original tensor.
+    """
+    if isinstance(tensor, DTensor):
+        return tensor.full_tensor()
+    return tensor
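A sketch of how a trainer process might drive update_weights in SPMD style (the mesh dimension name "infer_tp" and the surrounding setup are assumptions for illustration; only update_weights itself comes from the file above):

    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh

    from sglang.srt.weight_sync.utils import update_weights

    async def push_weights(engine, model):
        # One flat mesh over all ranks; every rank calls update_weights collectively.
        mesh = init_device_mesh(
            "cuda", (dist.get_world_size(),), mesh_dim_names=("infer_tp",)
        )
        params_batch = list(model.named_parameters())  # DTensors are gathered internally
        await update_weights(engine, params_batch, "infer_tp", mesh)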
sglang/test/runners.py
CHANGED
@@ -491,6 +491,8 @@ class SRTRunner:
         lora_paths: List[str] = None,
         max_loras_per_batch: int = 4,
         attention_backend: Optional[str] = None,
+        prefill_attention_backend: Optional[str] = None,
+        decode_attention_backend: Optional[str] = None,
         lora_backend: str = "triton",
         disable_cuda_graph: bool = False,
         disable_radix_cache: bool = False,
@@ -540,6 +542,8 @@ class SRTRunner:
             max_loras_per_batch=max_loras_per_batch,
             lora_backend=lora_backend,
             attention_backend=attention_backend,
+            prefill_attention_backend=prefill_attention_backend,
+            decode_attention_backend=decode_attention_backend,
             disable_cuda_graph=disable_cuda_graph,
             disable_radix_cache=disable_radix_cache,
             chunked_prefill_size=chunked_prefill_size,
sglang/test/test_activation.py
CHANGED
@@ -3,9 +3,12 @@ import unittest

 import torch

-from sglang.srt.layers.activation import GeluAndMul
+from sglang.srt.layers.activation import GeluAndMul, QuickGELU
+from sglang.srt.utils import is_hip
 from sglang.test.test_utils import CustomTestCase

+_is_hip = is_hip()
+

 class TestGeluAndMul(CustomTestCase):
     DTYPES = [torch.half, torch.bfloat16]
@@ -52,5 +55,51 @@ class TestGeluAndMul(CustomTestCase):
             self._run_gelu_and_mul_test(*params)


+class TestQuickGELU(CustomTestCase):
+    DTYPES = [torch.half, torch.bfloat16]
+    NUM_TOKENS = [7, 83, 2048]  # batch = sequence length
+    DIMS = [512, 4096, 5120, 13824]  # all multiples of 16 bytes
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _run_gelu_quick_test(self, n_tok: int, dim: int, dtype: torch.dtype, seed: int):
+        torch.manual_seed(seed)
+
+        layer = QuickGELU().to(dtype=dtype)
+
+        x = torch.randn(n_tok, dim, dtype=dtype, device="cuda")
+
+        with torch.inference_mode():
+            ref = layer.forward_native(x)  # x * sigmoid(1.702 * x), fp32 math
+            if _is_hip:
+                out = layer.forward_hip(x)  # 128-bit vectorised kernel from sgl-kernel
+            else:
+                out = layer.forward_cuda(x)
+
+        tol = 1e-2 if dtype is torch.bfloat16 else 1e-3
+        self.assertTrue(
+            torch.allclose(out, ref, atol=tol, rtol=tol),
+            msg=f"Mismatch @ B={n_tok}, D={dim}, dtype={dtype}",
+        )
+        print(f"Match @ B={n_tok}, D={dim}, dtype={dtype}")
+
+    def test_quick_gelu(self):
+        for params in itertools.product(
+            self.NUM_TOKENS, self.DIMS, self.DTYPES, self.SEEDS
+        ):
+            with self.subTest(
+                num_tokens=params[0],
+                dim=params[1],
+                dtype=params[2],
+                seed=params[3],
+            ):
+                self._run_gelu_quick_test(*params)
+
+
 if __name__ == "__main__":
     unittest.main(verbosity=2)
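For reference, QuickGELU approximates GELU as x * sigmoid(1.702 * x); the test compares the fused kernel against forward_native, which evaluates this in fp32. A standalone restatement of that reference (a sketch, not package code):

    import torch

    def quick_gelu_ref(x: torch.Tensor) -> torch.Tensor:
        # Compute in fp32 for a stable reference, then cast back to the input dtype
        return (x.float() * torch.sigmoid(1.702 * x.float())).to(x.dtype)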
sglang/test/test_utils.py
CHANGED
@@ -1,6 +1,7 @@
 """Common utilities for testing and benchmarking"""

 import argparse
+import asyncio
 import copy
 import json
 import logging
@@ -14,8 +15,9 @@ import unittest
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
 from functools import partial
+from pathlib import Path
 from types import SimpleNamespace
-from typing import Callable, List, Optional, Tuple
+from typing import Awaitable, Callable, List, Optional, Tuple

 import numpy as np
 import requests
@@ -26,6 +28,7 @@ from sglang.bench_serving import run_benchmark
 from sglang.global_config import global_config
 from sglang.lang.backend.openai import OpenAI
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
+from sglang.lang.interpreter import ProgramState
 from sglang.srt.utils import (
     get_bool_env_var,
     get_device,
@@ -347,6 +350,7 @@ def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):
         help="Device type (auto/cuda/rocm/cpu). Auto will detect available platforms",
     )
     parser.add_argument("--result-file", type=str, default="result.jsonl")
+    parser.add_argument("--raw-result-file", type=str)
     args = parser.parse_args()

     return args
@@ -714,6 +718,7 @@ def get_benchmark_args(
     seed: int = 0,
     device="auto",
     pd_separated: bool = False,
+    lora_name=None,
 ):
     return SimpleNamespace(
         backend="sglang",
@@ -741,7 +746,7 @@ def get_benchmark_args(
         extra_request_body=None,
         apply_chat_template=False,
         profile=None,
-        lora_name=None,
+        lora_name=lora_name,
         prompt_suffix="",
         device=device,
         pd_separated=pd_separated,
@@ -764,6 +769,8 @@ def run_bench_serving(
     need_warmup=False,
     seed: int = 0,
     device="auto",
+    background_task: Optional[Callable[[str, asyncio.Event, asyncio.Event], Awaitable[None]]] = None,
+    lora_name: Optional[str] = None,
 ):
     if device == "auto":
         device = auto_config_device()
@@ -791,14 +798,35 @@ def run_bench_serving(
         disable_ignore_eos=disable_ignore_eos,
         seed=seed,
         device=device,
+        lora_name=lora_name,
     )

-    try:
+    async def _run():
         if need_warmup:
             warmup_args = copy.deepcopy(args)
             warmup_args.num_prompts = 16
-            run_benchmark(warmup_args)
-        res = run_benchmark(args)
+            await asyncio.to_thread(run_benchmark, warmup_args)
+
+        start_event = asyncio.Event()
+        stop_event = asyncio.Event()
+        task_handle = (
+            asyncio.create_task(background_task(base_url, start_event, stop_event))
+            if background_task
+            else None
+        )
+
+        try:
+            start_event.set()
+            result = await asyncio.to_thread(run_benchmark, args)
+        finally:
+            if task_handle:
+                stop_event.set()
+                await task_handle
+
+        return result
+
+    try:
+        res = asyncio.run(_run())
     finally:
         kill_process_tree(process.pid)

@@ -1284,3 +1312,35 @@ class CustomTestCase(unittest.TestCase):
             lambda: super(CustomTestCase, self)._callTestMethod(method),
             max_retry=max_retry,
         )
+
+
+def dump_bench_raw_result(
+    path: str,
+    states,
+    preds,
+    labels,
+):
+    if not path:
+        return
+
+    rows = []
+    for i in range(len(states)):
+        state = states[i]
+        output = state["answer"]
+        prompt = _ensure_remove_suffix(state.text(), output)
+        rows.append(
+            dict(
+                prompt_id=i,
+                prompt=prompt,
+                output=output,
+                correct=bool(preds[i] == labels[i]),
+            )
+        )
+
+    print(f"BenchRawResultDumper save results to {path}")
+    Path(path).write_text("\n".join(json.dumps(row) for row in rows))
+
+
+def _ensure_remove_suffix(text: str, suffix: str):
+    assert text.endswith(suffix)
+    return text.removesuffix(suffix)
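The new background_task hook receives the server base URL plus start/stop events and runs concurrently with the measured benchmark: run_bench_serving sets start_event before the run, then sets stop_event and awaits the task afterwards. A hypothetical task that polls the /health endpoint during the run (cadence and endpoint choice are illustrative):

    import asyncio
    import requests

    async def poll_health(base_url: str, start_event: asyncio.Event, stop_event: asyncio.Event):
        await start_event.wait()
        while not stop_event.is_set():
            # requests is blocking, so hop to a thread to keep the event loop free
            await asyncio.to_thread(requests.get, f"{base_url}/health", timeout=5)
            await asyncio.sleep(1.0)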
sglang/utils.py
CHANGED
@@ -14,6 +14,7 @@ import traceback
 import urllib.request
 import weakref
 from concurrent.futures import ThreadPoolExecutor
+from functools import wraps
 from io import BytesIO
 from json import dumps
 from typing import Any, Callable, List, Optional, Tuple, Type, Union
@@ -28,6 +29,24 @@ from tqdm import tqdm
 logger = logging.getLogger(__name__)


+def execute_once(func):
+    has_run = None
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        nonlocal has_run
+        if not has_run:
+            func(*args, **kwargs)
+            has_run = True
+
+    return wrapper
+
+
+@execute_once
+def info_once(message: str):
+    logger.info(message)
+
+
 def convert_json_schema_to_str(json_schema: Union[dict, str, Type[BaseModel]]) -> str:
     """Convert a JSON schema to a string.
     Parameters
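Because execute_once latches on the first call, info_once suppresses every later invocation regardless of the message text; the latch is per-function, not per-message. For example:

    from sglang.utils import info_once

    info_once("falling back to the triton attention backend")  # logged
    info_once("falling back to the triton attention backend")  # no-op
    info_once("a different message")  # also a no-op: the same wrapper already ran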
sglang/version.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.4.9.post3"
+__version__ = "0.4.9.post5"
{sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sglang
-Version: 0.4.9.post3
+Version: 0.4.9.post5
 Summary: SGLang is yet another fast serving framework for large language models and vision language models.
 License: Apache License
          Version 2.0, January 2004
@@ -246,20 +246,20 @@ Requires-Dist: sentencepiece; extra == "runtime-common"
 Requires-Dist: soundfile==0.13.1; extra == "runtime-common"
 Requires-Dist: scipy; extra == "runtime-common"
 Requires-Dist: torchao==0.9.0; extra == "runtime-common"
-Requires-Dist: transformers==4.
+Requires-Dist: transformers==4.54.0; extra == "runtime-common"
 Requires-Dist: timm==1.0.16; extra == "runtime-common"
 Requires-Dist: uvicorn; extra == "runtime-common"
 Requires-Dist: uvloop; extra == "runtime-common"
 Requires-Dist: xgrammar==0.1.21; extra == "runtime-common"
 Provides-Extra: srt
 Requires-Dist: sglang[runtime_common]; extra == "srt"
-Requires-Dist: sgl-kernel==0.2.
+Requires-Dist: sgl-kernel==0.2.7; extra == "srt"
 Requires-Dist: torch==2.7.1; extra == "srt"
 Requires-Dist: torchaudio==2.7.1; extra == "srt"
 Requires-Dist: torchvision==0.22.1; extra == "srt"
 Requires-Dist: cuda-python; extra == "srt"
 Requires-Dist: einops; extra == "srt"
-Requires-Dist: flashinfer_python==0.2.
+Requires-Dist: flashinfer_python==0.2.9rc2; extra == "srt"
 Provides-Extra: blackwell
 Requires-Dist: sglang[runtime_common]; extra == "blackwell"
 Requires-Dist: sgl-kernel; extra == "blackwell"
@@ -268,11 +268,11 @@ Requires-Dist: torchaudio==2.7.1; extra == "blackwell"
 Requires-Dist: torchvision==0.22.1; extra == "blackwell"
 Requires-Dist: cuda-python; extra == "blackwell"
 Requires-Dist: einops; extra == "blackwell"
-Requires-Dist: flashinfer_python==0.2.
+Requires-Dist: flashinfer_python==0.2.9rc2; extra == "blackwell"
 Provides-Extra: srt-hip
 Requires-Dist: sglang[runtime_common]; extra == "srt-hip"
 Requires-Dist: torch; extra == "srt-hip"
-Requires-Dist: petit_kernel; extra == "srt-hip"
+Requires-Dist: petit_kernel==0.0.2; extra == "srt-hip"
 Provides-Extra: srt-xpu
 Requires-Dist: sglang[runtime_common]; extra == "srt-xpu"
 Provides-Extra: srt-hpu