sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +7 -0
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +16 -1
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mooncake/conn.py +16 -0
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/http_server.py +13 -1
- sglang/srt/entrypoints/openai/protocol.py +3 -1
- sglang/srt/entrypoints/openai/serving_base.py +5 -2
- sglang/srt/entrypoints/openai/serving_chat.py +132 -79
- sglang/srt/function_call/ebnf_composer.py +10 -3
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/qwen3_coder_detector.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +14 -3
- sglang/srt/layers/moe/ep_moe/layer.py +323 -242
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
- sglang/srt/layers/moe/topk.py +90 -24
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +27 -10
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/lora/lora_registry.py +93 -29
- sglang/srt/managers/cache_controller.py +9 -7
- sglang/srt/managers/data_parallel_controller.py +4 -0
- sglang/srt/managers/io_struct.py +12 -0
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +14 -8
- sglang/srt/managers/scheduler.py +64 -1
- sglang/srt/managers/scheduler_input_blocker.py +106 -0
- sglang/srt/managers/tokenizer_manager.py +80 -15
- sglang/srt/managers/tp_worker.py +8 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -2
- sglang/srt/model_executor/model_runner.py +83 -27
- sglang/srt/models/deepseek_v2.py +75 -84
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/qwen2_moe.py +2 -2
- sglang/srt/models/qwen3_moe.py +17 -71
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/poll_based_barrier.py +31 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +65 -6
- sglang/srt/two_batch_overlap.py +8 -3
- sglang/srt/utils.py +96 -1
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_utils.py +118 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/METADATA +5 -4
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/RECORD +97 -80
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/top_level.txt +0 -0
sglang/srt/utils.py
CHANGED
@@ -15,6 +15,7 @@
|
|
15
15
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
|
+
import asyncio
|
18
19
|
import builtins
|
19
20
|
import ctypes
|
20
21
|
import dataclasses
|
@@ -85,6 +86,8 @@ from torch.profiler import ProfilerActivity, profile, record_function
|
|
85
86
|
from torch.utils._contextlib import _DecoratorContextManager
|
86
87
|
from triton.runtime.cache import FileCacheManager
|
87
88
|
|
89
|
+
from sglang.srt.metrics.func_timer import enable_func_timer
|
90
|
+
|
88
91
|
logger = logging.getLogger(__name__)
|
89
92
|
|
90
93
|
show_time_cost = False
|
@@ -2049,7 +2052,7 @@ def rank0_log(msg: str):
|
|
2049
2052
|
logger.info(msg)
|
2050
2053
|
|
2051
2054
|
|
2052
|
-
def launch_dummy_health_check_server(host, port):
|
2055
|
+
def launch_dummy_health_check_server(host, port, enable_metrics):
|
2053
2056
|
import asyncio
|
2054
2057
|
|
2055
2058
|
import uvicorn
|
@@ -2067,6 +2070,11 @@ def launch_dummy_health_check_server(host, port):
|
|
2067
2070
|
"""Check the health of the http server."""
|
2068
2071
|
return Response(status_code=200)
|
2069
2072
|
|
2073
|
+
# Add prometheus middleware
|
2074
|
+
if enable_metrics:
|
2075
|
+
add_prometheus_middleware(app)
|
2076
|
+
enable_func_timer()
|
2077
|
+
|
2070
2078
|
config = uvicorn.Config(
|
2071
2079
|
app,
|
2072
2080
|
host=host,
|
@@ -2335,6 +2343,7 @@ def is_fa3_default_architecture(hf_config):
|
|
2335
2343
|
"Gemma3ForConditionalGeneration",
|
2336
2344
|
"Qwen3ForCausalLM",
|
2337
2345
|
"Qwen3MoeForCausalLM",
|
2346
|
+
"Glm4MoeForCausalLM",
|
2338
2347
|
}
|
2339
2348
|
return architectures[0] in default_archs
|
2340
2349
|
|
@@ -2855,3 +2864,89 @@ SUPPORTED_LORA_TARGET_MODULES = [
|
|
2855
2864
|
]
|
2856
2865
|
|
2857
2866
|
LORA_TARGET_ALL_MODULES = "all"
|
2867
|
+
|
2868
|
+
|
2869
|
+
class ConcurrentCounter:
    """
    An asynchronous counter for managing concurrent tasks that need
    coordinated increments, decrements, and waiting until the count reaches zero.

    This class is useful for scenarios like tracking the number of in-flight tasks
    and waiting for them to complete.

    All mutations and waits go through a single ``asyncio.Condition``. Like other
    asyncio primitives, it is meant for coroutines on one event loop, not for
    use across threads.
    """

    def __init__(self, initial: int = 0):
        """
        Initialize the counter with an optional initial value.

        Args:
            initial (int): The initial value of the counter. Default is 0.
        """
        self._count = initial
        self._condition = asyncio.Condition()

    def value(self) -> int:
        """
        Return the current value of the counter.

        Note:
            This method is not synchronized. It may return a stale value
            if other coroutines are concurrently modifying the counter.

        Returns:
            int: The current counter value.
        """
        return self._count

    def __repr__(self) -> str:
        """Return an informative string representation of the counter."""
        return f"<ConcurrentCounter value={self.value()}>"

    async def increment(self, n: int = 1, notify_all: bool = True):
        """
        Atomically increment the counter by a given amount and notify all waiters.

        Args:
            n (int): The amount to increment the counter by. Default is 1.
            notify_all (bool): Whether to notify all waiters after incrementing.
                Default is True. When False, waiters are not woken by this change
                and only re-check their condition on a later notification.
        """
        async with self._condition:
            self._count += n
            if notify_all:
                self._condition.notify_all()

    async def decrement(self, n: int = 1, notify_all: bool = True):
        """
        Atomically decrement the counter by a given amount and notify all waiters.

        Args:
            n (int): The amount to decrement the counter by. Default is 1.
            notify_all (bool): Whether to notify all waiters after decrementing.
                Default is True. When False, waiters are not woken by this change
                and only re-check their condition on a later notification.
        """
        async with self._condition:
            self._count -= n
            if notify_all:
                self._condition.notify_all()

    async def wait_for(self, condition: Callable[[int], bool]):
        """
        Asynchronously wait until the counter satisfies a given condition.

        This suspends the calling coroutine without blocking the thread, allowing
        other tasks to run while waiting. When the condition is met, the coroutine
        resumes. If the condition already holds, it returns without suspending.

        Args:
            condition (Callable[[int], bool]): A function that takes the current counter value
                and returns True when the condition is satisfied.
        """
        async with self._condition:
            await self._condition.wait_for(lambda: condition(self._count))

    async def wait_for_zero(self):
        """
        Asynchronously wait until the counter reaches zero.

        This suspends the calling coroutine without blocking the thread, allowing
        other tasks to run while waiting. When the counter becomes zero, the coroutine resumes.
        """
        # wait_for is a coroutine function: without ``await`` the call only
        # creates a coroutine object and this method would return immediately.
        await self.wait_for(lambda count: count == 0)
|
@@ -0,0 +1,119 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
|
3
|
+
import torch
|
4
|
+
import torch.distributed as dist
|
5
|
+
from torch.distributed.device_mesh import DeviceMesh
|
6
|
+
from torch.distributed.tensor import DTensor
|
7
|
+
|
8
|
+
from sglang.srt.entrypoints.engine import Engine
|
9
|
+
from sglang.srt.managers.tokenizer_manager import UpdateWeightsFromTensorReqInput
|
10
|
+
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
|
11
|
+
from sglang.srt.utils import MultiprocessingSerializer
|
12
|
+
|
13
|
+
|
14
|
+
async def update_weights(
    engine: Engine,
    params_batch: list[tuple[str, torch.Tensor]],
    device_mesh_key: str,
    device_mesh: DeviceMesh,
    load_format: Optional[str] = None,
):
    """
    Update weights for the inference engine.

    This function is designed to be stateless, so that the caller process can keep the stateful engine.

    Example use case:
        - Multiple producer processes call this function in an SPMD style.

    Every rank of ``device_mesh[device_mesh_key]`` must call this function: it
    performs a collective ``dist.gather_object`` over that mesh dimension. Only
    the rank whose local rank is 0 forwards the gathered tensors to the engine.

    Args:
        engine: The inference engine created by the caller process.
        params_batch: A list of (name, tensor) tuples. Tensors are batched to avoid the overhead of per-tensor CPU calls.
        device_mesh_key: The key of the device mesh. Typically "tp" or "infer_tp".
        device_mesh: The device mesh.
        load_format: The format of the weights.

    Returns:
        On local rank 0, the result of ``engine.update_weights_from_tensor``;
        ``None`` on every other rank.
    """
    tp_mesh = device_mesh[device_mesh_key]
    infer_tp_size = tp_mesh.mesh.size()[0]
    infer_tp_rank = tp_mesh.get_local_rank()
    from sglang.srt.patch_torch import monkey_patch_torch_reductions

    # Patch torch's multiprocessing reductions before any tensor is serialized below.
    monkey_patch_torch_reductions()

    # This rank's contribution:
    # [
    #   (name0, ipc_tensor0_tp0),
    #   (name1, ipc_tensor1_tp0),
    # ]
    named_tensors_batch = [
        (
            name,
            MultiprocessingSerializer.serialize(
                _preprocess_tensor_for_update_weights(tensor)
            ),
        )
        for name, tensor in params_batch
    ]

    # Only the destination rank provides a receive list to gather_object.
    if infer_tp_rank == 0:
        gathered_serialized_batches = [None for _ in range(infer_tp_size)]
    else:
        gathered_serialized_batches = None

    # After the gather, rank 0 holds:
    # [
    #   [ (name0, ipc_tensor0_tp0), (name1, ipc_tensor1_tp0) ],
    #   [ (name0, ipc_tensor0_tp1), (name1, ipc_tensor1_tp1) ],
    # ]
    dist.gather_object(
        obj=named_tensors_batch,
        object_gather_list=gathered_serialized_batches,
        # dst is a global rank: the first rank of this mesh dimension.
        dst=tp_mesh.mesh.tolist()[0],
        group=tp_mesh.get_group(),
    )

    if infer_tp_rank != 0:
        return None

    # Use zip(*) to "transpose" the data structure into one group per parameter:
    # [
    #   ( (name0, ipc_tensor0_tp0), (name0, ipc_tensor0_tp1) ),
    #   ( (name1, ipc_tensor1_tp0), (name1, ipc_tensor1_tp1) ),
    # ]
    logical_tensors = zip(*gathered_serialized_batches, strict=True)

    # [
    #   (name0, LocalSerializedTensor(values=[ipc_tensor0_tp0, ipc_tensor0_tp1])),
    #   (name1, LocalSerializedTensor(values=[ipc_tensor1_tp0, ipc_tensor1_tp1])),
    # ]
    named_tensors = [
        (
            tensor_group[0][0],
            LocalSerializedTensor(
                values=[rank_part[1] for rank_part in tensor_group]
            ),
        )
        for tensor_group in logical_tensors
    ]

    # Every TP rank receives the same payload, so serialize it once and share
    # the result instead of re-pickling identical data infer_tp_size times.
    # NOTE(review): assumes serialize() returns an immutable bytes/str object,
    # so sharing one reference across list slots is safe — confirm.
    serialized_named_tensors = MultiprocessingSerializer.serialize(named_tensors)
    update_weights_request = UpdateWeightsFromTensorReqInput(
        serialized_named_tensors=[serialized_named_tensors] * infer_tp_size,
        load_format=load_format,
    )

    return await engine.update_weights_from_tensor(update_weights_request)
|
102
|
+
|
103
|
+
|
104
|
+
def _preprocess_tensor_for_update_weights(tensor: torch.Tensor):
|
105
|
+
"""
|
106
|
+
Preprocess the tensor for update weights.
|
107
|
+
Example Use Case:
|
108
|
+
- FSDP: we gather tensor by calling full_tensor in _preprocess_tensor_for_update_weights
|
109
|
+
- Megatron: we do nothing here, assuming it is gathered when feed into this func
|
110
|
+
|
111
|
+
Args:
|
112
|
+
tensor: The tensor to be preprocessed.
|
113
|
+
|
114
|
+
Returns:
|
115
|
+
The full tensor if it is a DTensor, otherwise the original tensor.
|
116
|
+
"""
|
117
|
+
if isinstance(tensor, DTensor):
|
118
|
+
return tensor.full_tensor()
|
119
|
+
return tensor
|
sglang/test/runners.py
CHANGED
@@ -491,6 +491,8 @@ class SRTRunner:
|
|
491
491
|
lora_paths: List[str] = None,
|
492
492
|
max_loras_per_batch: int = 4,
|
493
493
|
attention_backend: Optional[str] = None,
|
494
|
+
prefill_attention_backend: Optional[str] = None,
|
495
|
+
decode_attention_backend: Optional[str] = None,
|
494
496
|
lora_backend: str = "triton",
|
495
497
|
disable_cuda_graph: bool = False,
|
496
498
|
disable_radix_cache: bool = False,
|
@@ -540,6 +542,8 @@ class SRTRunner:
|
|
540
542
|
max_loras_per_batch=max_loras_per_batch,
|
541
543
|
lora_backend=lora_backend,
|
542
544
|
attention_backend=attention_backend,
|
545
|
+
prefill_attention_backend=prefill_attention_backend,
|
546
|
+
decode_attention_backend=decode_attention_backend,
|
543
547
|
disable_cuda_graph=disable_cuda_graph,
|
544
548
|
disable_radix_cache=disable_radix_cache,
|
545
549
|
chunked_prefill_size=chunked_prefill_size,
|
sglang/test/test_utils.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
"""Common utilities for testing and benchmarking"""
|
2
2
|
|
3
3
|
import argparse
|
4
|
+
import asyncio
|
4
5
|
import copy
|
5
6
|
import json
|
6
7
|
import logging
|
@@ -14,9 +15,11 @@ import unittest
|
|
14
15
|
from concurrent.futures import ThreadPoolExecutor
|
15
16
|
from dataclasses import dataclass
|
16
17
|
from functools import partial
|
18
|
+
from pathlib import Path
|
17
19
|
from types import SimpleNamespace
|
18
|
-
from typing import Callable, List, Optional, Tuple
|
20
|
+
from typing import Awaitable, Callable, List, Optional, Tuple
|
19
21
|
|
22
|
+
import aiohttp
|
20
23
|
import numpy as np
|
21
24
|
import requests
|
22
25
|
import torch
|
@@ -26,6 +29,7 @@ from sglang.bench_serving import run_benchmark
|
|
26
29
|
from sglang.global_config import global_config
|
27
30
|
from sglang.lang.backend.openai import OpenAI
|
28
31
|
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
32
|
+
from sglang.lang.interpreter import ProgramState
|
29
33
|
from sglang.srt.utils import (
|
30
34
|
get_bool_env_var,
|
31
35
|
get_device,
|
@@ -347,6 +351,7 @@ def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):
|
|
347
351
|
help="Device type (auto/cuda/rocm/cpu). Auto will detect available platforms",
|
348
352
|
)
|
349
353
|
parser.add_argument("--result-file", type=str, default="result.jsonl")
|
354
|
+
parser.add_argument("--raw-result-file", type=str)
|
350
355
|
args = parser.parse_args()
|
351
356
|
|
352
357
|
return args
|
@@ -714,6 +719,7 @@ def get_benchmark_args(
|
|
714
719
|
seed: int = 0,
|
715
720
|
device="auto",
|
716
721
|
pd_separated: bool = False,
|
722
|
+
lora_name=None,
|
717
723
|
):
|
718
724
|
return SimpleNamespace(
|
719
725
|
backend="sglang",
|
@@ -741,7 +747,7 @@ def get_benchmark_args(
|
|
741
747
|
extra_request_body=None,
|
742
748
|
apply_chat_template=False,
|
743
749
|
profile=None,
|
744
|
-
lora_name=
|
750
|
+
lora_name=lora_name,
|
745
751
|
prompt_suffix="",
|
746
752
|
device=device,
|
747
753
|
pd_separated=pd_separated,
|
@@ -764,6 +770,8 @@ def run_bench_serving(
|
|
764
770
|
need_warmup=False,
|
765
771
|
seed: int = 0,
|
766
772
|
device="auto",
|
773
|
+
background_task: Optional[Callable[[str, asyncio.Event], Awaitable[None]]] = None,
|
774
|
+
lora_name: Optional[str] = None,
|
767
775
|
):
|
768
776
|
if device == "auto":
|
769
777
|
device = auto_config_device()
|
@@ -791,14 +799,35 @@ def run_bench_serving(
|
|
791
799
|
disable_ignore_eos=disable_ignore_eos,
|
792
800
|
seed=seed,
|
793
801
|
device=device,
|
802
|
+
lora_name=lora_name,
|
794
803
|
)
|
795
804
|
|
796
|
-
|
805
|
+
async def _run():
|
797
806
|
if need_warmup:
|
798
807
|
warmup_args = copy.deepcopy(args)
|
799
808
|
warmup_args.num_prompts = 16
|
800
|
-
run_benchmark
|
801
|
-
|
809
|
+
await asyncio.to_thread(run_benchmark, warmup_args)
|
810
|
+
|
811
|
+
start_event = asyncio.Event()
|
812
|
+
stop_event = asyncio.Event()
|
813
|
+
task_handle = (
|
814
|
+
asyncio.create_task(background_task(base_url, start_event, stop_event))
|
815
|
+
if background_task
|
816
|
+
else None
|
817
|
+
)
|
818
|
+
|
819
|
+
try:
|
820
|
+
start_event.set()
|
821
|
+
result = await asyncio.to_thread(run_benchmark, args)
|
822
|
+
finally:
|
823
|
+
if task_handle:
|
824
|
+
stop_event.set()
|
825
|
+
await task_handle
|
826
|
+
|
827
|
+
return result
|
828
|
+
|
829
|
+
try:
|
830
|
+
res = asyncio.run(_run())
|
802
831
|
finally:
|
803
832
|
kill_process_tree(process.pid)
|
804
833
|
|
@@ -1275,6 +1304,58 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
|
|
1275
1304
|
raise
|
1276
1305
|
|
1277
1306
|
|
1307
|
+
def send_generate_requests(base_url: str, num_requests: int) -> List[int]:
    """Send ``num_requests`` generate requests one after another.

    Requests are issued serially, so at most one request is in flight at a time.

    Args:
        base_url: Server root URL; requests are POSTed to ``{base_url}/generate``.
        num_requests: Number of requests to send.

    Returns:
        The HTTP status code of each response (an ``int``), in request order.
    """
    prompt = """
System: You are a helpful assistant.
User: What is the capital of France?
Assistant: The capital of France is
"""
    # The request body is identical for every call, so build it once.
    payload = {
        "text": prompt,
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 50,
        },
    }
    url = f"{base_url}/generate"
    return [
        requests.post(url, json=payload).status_code for _ in range(num_requests)
    ]
|
1329
|
+
|
1330
|
+
|
1331
|
+
async def send_concurrent_generate_requests(
    base_url: str, num_requests: int
) -> List[int]:
    """Send ``num_requests`` generate requests concurrently.

    All requests are started at once, so max concurrency is ``num_requests``.

    Args:
        base_url: Server root URL; requests are POSTed to ``{base_url}/generate``.
        num_requests: Number of requests to send.

    Returns:
        The HTTP status code of each response (an ``int``), in request order.
    """
    prompt = """
System: You are a helpful assistant.
User: What is the capital of France?
Assistant: The capital of France is
"""
    payload = {
        "text": prompt,
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 50,
        },
    }
    url = f"{base_url}/generate"

    async def async_generate(session: aiohttp.ClientSession) -> int:
        async with session.post(url, json=payload) as response:
            return response.status

    # Reuse a single session for all requests. limit=0 removes the connector's
    # default cap on simultaneous connections, so every request can still be
    # in flight at once.
    async with aiohttp.ClientSession(
        connector=aiohttp.TCPConnector(limit=0)
    ) as session:
        tasks = [
            asyncio.create_task(async_generate(session)) for _ in range(num_requests)
        ]
        return await asyncio.gather(*tasks)
|
1357
|
+
|
1358
|
+
|
1278
1359
|
class CustomTestCase(unittest.TestCase):
|
1279
1360
|
def _callTestMethod(self, method):
|
1280
1361
|
max_retry = int(
|
@@ -1284,3 +1365,35 @@ class CustomTestCase(unittest.TestCase):
|
|
1284
1365
|
lambda: super(CustomTestCase, self)._callTestMethod(method),
|
1285
1366
|
max_retry=max_retry,
|
1286
1367
|
)
|
1368
|
+
|
1369
|
+
|
1370
|
+
def dump_bench_raw_result(
    path: str,
    states,
    preds,
    labels,
):
    """Write per-prompt benchmark results to ``path`` as JSON Lines.

    Nothing is written when ``path`` is empty or None. Each record holds:
      - ``prompt_id``: the index of the state;
      - ``prompt``: the state's full text with ``state["answer"]`` stripped from the end;
      - ``output``: ``state["answer"]``;
      - ``correct``: whether ``preds[i] == labels[i]``.
    Records are joined with newlines (no trailing newline).
    """
    if path:
        records = []
        for idx, st in enumerate(states):
            answer = st["answer"]
            records.append(
                dict(
                    prompt_id=idx,
                    prompt=_ensure_remove_suffix(st.text(), answer),
                    output=answer,
                    correct=bool(preds[idx] == labels[idx]),
                )
            )

        print(f"BenchRawResultDumper save results to {path}")
        serialized_lines = (json.dumps(record) for record in records)
        Path(path).write_text("\n".join(serialized_lines))
|
1395
|
+
|
1396
|
+
|
1397
|
+
def _ensure_remove_suffix(text: str, suffix: str):
|
1398
|
+
assert text.endswith(suffix)
|
1399
|
+
return text.removesuffix(suffix)
|
sglang/utils.py
CHANGED
@@ -14,6 +14,7 @@ import traceback
|
|
14
14
|
import urllib.request
|
15
15
|
import weakref
|
16
16
|
from concurrent.futures import ThreadPoolExecutor
|
17
|
+
from functools import wraps
|
17
18
|
from io import BytesIO
|
18
19
|
from json import dumps
|
19
20
|
from typing import Any, Callable, List, Optional, Tuple, Type, Union
|
@@ -28,6 +29,24 @@ from tqdm import tqdm
|
|
28
29
|
logger = logging.getLogger(__name__)
|
29
30
|
|
30
31
|
|
32
|
+
def execute_once(func):
    """Decorator that runs ``func`` only until its first successful call.

    Once a call to ``func`` returns normally, every later call to the wrapper
    is a no-op. If ``func`` raises, the flag is not set, so the next call tries
    again. The wrapper always returns ``None``, discarding ``func``'s result.
    The flag is a plain closure variable with no lock around it.
    """
    done = False

    @wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal done
        if done:
            return
        func(*args, **kwargs)
        # Only mark as done after func completed without raising.
        done = True

    return wrapper
|
43
|
+
|
44
|
+
|
45
|
+
@execute_once
def info_once(message: str):
    """Log ``message`` at INFO level, on the first successful call only.

    ``execute_once`` keeps a single flag for this function, so every later call
    is ignored, including calls that pass a different ``message``.
    NOTE(review): if per-message de-duplication was intended, this needs a set
    of already-logged messages instead of ``execute_once``.
    """
    logger.info(message)
|
48
|
+
|
49
|
+
|
31
50
|
def convert_json_schema_to_str(json_schema: Union[dict, str, Type[BaseModel]]) -> str:
|
32
51
|
"""Convert a JSON schema to a string.
|
33
52
|
Parameters
|
sglang/version.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.4.9.
|
1
|
+
__version__ = "0.4.9.post6"
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: sglang
|
3
|
-
Version: 0.4.9.post4
|
3
|
+
Version: 0.4.9.post6
|
4
4
|
Summary: SGLang is yet another fast serving framework for large language models and vision language models.
|
5
5
|
License: Apache License
|
6
6
|
Version 2.0, January 2004
|
@@ -246,7 +246,7 @@ Requires-Dist: sentencepiece; extra == "runtime-common"
|
|
246
246
|
Requires-Dist: soundfile==0.13.1; extra == "runtime-common"
|
247
247
|
Requires-Dist: scipy; extra == "runtime-common"
|
248
248
|
Requires-Dist: torchao==0.9.0; extra == "runtime-common"
|
249
|
-
Requires-Dist: transformers==4.
|
249
|
+
Requires-Dist: transformers==4.54.0; extra == "runtime-common"
|
250
250
|
Requires-Dist: timm==1.0.16; extra == "runtime-common"
|
251
251
|
Requires-Dist: uvicorn; extra == "runtime-common"
|
252
252
|
Requires-Dist: uvloop; extra == "runtime-common"
|
@@ -259,7 +259,7 @@ Requires-Dist: torchaudio==2.7.1; extra == "srt"
|
|
259
259
|
Requires-Dist: torchvision==0.22.1; extra == "srt"
|
260
260
|
Requires-Dist: cuda-python; extra == "srt"
|
261
261
|
Requires-Dist: einops; extra == "srt"
|
262
|
-
Requires-Dist: flashinfer_python==0.2.
|
262
|
+
Requires-Dist: flashinfer_python==0.2.9rc2; extra == "srt"
|
263
263
|
Provides-Extra: blackwell
|
264
264
|
Requires-Dist: sglang[runtime_common]; extra == "blackwell"
|
265
265
|
Requires-Dist: sgl-kernel; extra == "blackwell"
|
@@ -268,7 +268,8 @@ Requires-Dist: torchaudio==2.7.1; extra == "blackwell"
|
|
268
268
|
Requires-Dist: torchvision==0.22.1; extra == "blackwell"
|
269
269
|
Requires-Dist: cuda-python; extra == "blackwell"
|
270
270
|
Requires-Dist: einops; extra == "blackwell"
|
271
|
-
Requires-Dist: flashinfer_python==0.2.
|
271
|
+
Requires-Dist: flashinfer_python==0.2.9rc2; extra == "blackwell"
|
272
|
+
Requires-Dist: tiktoken; extra == "blackwell"
|
272
273
|
Provides-Extra: srt-hip
|
273
274
|
Requires-Dist: sglang[runtime_common]; extra == "srt-hip"
|
274
275
|
Requires-Dist: torch; extra == "srt-hip"
|