sglang 0.3.5__py3-none-any.whl → 0.3.5.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +309 -0
- sglang/bench_serving.py +148 -24
- sglang/srt/configs/model_config.py +5 -2
- sglang/srt/constrained/__init__.py +2 -66
- sglang/srt/constrained/base_grammar_backend.py +73 -0
- sglang/srt/constrained/outlines_backend.py +165 -0
- sglang/srt/constrained/outlines_jump_forward.py +182 -0
- sglang/srt/constrained/xgrammar_backend.py +150 -0
- sglang/srt/layers/attention/triton_ops/decode_attention.py +7 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +6 -0
- sglang/srt/layers/fused_moe/fused_moe.py +23 -7
- sglang/srt/layers/fused_moe/patch.py +4 -2
- sglang/srt/layers/quantization/base_config.py +4 -6
- sglang/srt/layers/vocab_parallel_embedding.py +216 -150
- sglang/srt/managers/detokenizer_manager.py +0 -14
- sglang/srt/managers/io_struct.py +5 -3
- sglang/srt/managers/schedule_batch.py +14 -20
- sglang/srt/managers/scheduler.py +159 -96
- sglang/srt/managers/tokenizer_manager.py +81 -17
- sglang/srt/metrics/collector.py +211 -0
- sglang/srt/metrics/func_timer.py +108 -0
- sglang/srt/mm_utils.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +2 -2
- sglang/srt/model_executor/forward_batch_info.py +7 -3
- sglang/srt/model_executor/model_runner.py +6 -2
- sglang/srt/models/gemma2_reward.py +69 -0
- sglang/srt/models/gpt2.py +31 -37
- sglang/srt/models/internlm2_reward.py +62 -0
- sglang/srt/models/llama.py +11 -6
- sglang/srt/models/llama_reward.py +5 -26
- sglang/srt/models/qwen2_vl.py +5 -7
- sglang/srt/openai_api/adapter.py +11 -4
- sglang/srt/openai_api/protocol.py +29 -26
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/sampling/sampling_params.py +2 -16
- sglang/srt/server.py +60 -17
- sglang/srt/server_args.py +66 -25
- sglang/srt/utils.py +120 -0
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_mgsm.py +2 -2
- sglang/test/test_utils.py +21 -7
- sglang/utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/METADATA +12 -8
- {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/RECORD +49 -45
- {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/WHEEL +1 -1
- sglang/srt/constrained/base_tool_cache.py +0 -65
- sglang/srt/constrained/bnf_cache.py +0 -61
- sglang/srt/constrained/fsm_cache.py +0 -95
- sglang/srt/constrained/grammar.py +0 -190
- sglang/srt/constrained/jump_forward.py +0 -203
- {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/LICENSE +0 -0
- {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/server.py
CHANGED
@@ -30,12 +30,11 @@ import time
 from http import HTTPStatus
 from typing import AsyncIterator, Dict, List, Optional, Union
 
-import orjson
-
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 
 import aiohttp
+import orjson
 import requests
 import uvicorn
 import uvloop
@@ -57,6 +56,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
+from sglang.srt.metrics.func_timer import enable_func_timer, time_func_latency
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
     v1_batches,
@@ -74,12 +74,15 @@ from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     add_api_key_middleware,
+    add_prometheus_middleware,
     assert_pkg_version,
     configure_logger,
+    delete_directory,
     is_port_available,
     kill_child_process,
     maybe_set_triton_cache_manager,
     prepare_model_and_tokenizer,
+    set_prometheus_multiproc_dir,
     set_ulimit,
 )
 from sglang.utils import get_exception_traceback
@@ -90,8 +93,6 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
 
 app = FastAPI()
-tokenizer_manager: TokenizerManager = None
-
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -100,6 +101,10 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+tokenizer_manager: TokenizerManager = None
+
+##### Native API endpoints #####
+
 
 @app.get("/health")
 async def health() -> Response:
@@ -110,9 +115,16 @@ async def health() -> Response:
 @app.get("/health_generate")
 async def health_generate(request: Request) -> Response:
     """Check the health of the inference server by generating one token."""
-
-
-
+
+    if tokenizer_manager.is_generation:
+        gri = GenerateReqInput(
+            input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
+        )
+    else:
+        gri = EmbeddingReqInput(
+            input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
+        )
+
     try:
         async for _ in tokenizer_manager.generate_request(gri, request):
             break
@@ -127,6 +139,7 @@ async def get_model_info():
     """Get the model information."""
     result = {
         "model_path": tokenizer_manager.model_path,
+        "tokenizer_path": tokenizer_manager.server_args.tokenizer_path,
         "is_generation": tokenizer_manager.is_generation,
     }
     return result
@@ -185,6 +198,7 @@ async def get_memory_pool_size():
 
 
 @app.post("/update_weights")
+@time_func_latency
 async def update_weights(obj: UpdateWeightReqInput, request: Request):
     """Update the weights inplace without re-launching the server."""
     success, message = await tokenizer_manager.update_weights(obj, request)
@@ -201,7 +215,7 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
     )
 
 
-
+@time_func_latency
 async def generate_request(obj: GenerateReqInput, request: Request):
     """Handle a generate request."""
     if obj.stream:
@@ -234,10 +248,12 @@ async def generate_request(obj: GenerateReqInput, request: Request):
     )
 
 
+# fastapi implicitly converts json in the request to obj (dataclass)
 app.post("/generate")(generate_request)
 app.put("/generate")(generate_request)
 
 
+@time_func_latency
 async def encode_request(obj: EmbeddingReqInput, request: Request):
     """Handle an embedding request."""
     try:
@@ -253,7 +269,8 @@ app.post("/encode")(encode_request)
 app.put("/encode")(encode_request)
 
 
-
+@time_func_latency
+async def classify_request(obj: EmbeddingReqInput, request: Request):
     """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
     try:
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
@@ -264,21 +281,27 @@ async def judge_request(obj: EmbeddingReqInput, request: Request):
     )
 
 
-app.post("/
-app.put("/
+app.post("/classify")(classify_request)
+app.put("/classify")(classify_request)
+
+
+##### OpenAI-compatible API endpoints #####
 
 
 @app.post("/v1/completions")
+@time_func_latency
 async def openai_v1_completions(raw_request: Request):
     return await v1_completions(tokenizer_manager, raw_request)
 
 
 @app.post("/v1/chat/completions")
+@time_func_latency
 async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)
 
 
 @app.post("/v1/embeddings", response_class=ORJSONResponse)
+@time_func_latency
 async def openai_v1_embeddings(raw_request: Request):
     response = await v1_embeddings(tokenizer_manager, raw_request)
     return response
@@ -432,13 +455,17 @@ def launch_server(
     1. The HTTP server and Tokenizer Manager both run in the main process.
     2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library.
     """
-
     launch_engine(server_args=server_args)
 
     # Add api key authorization
     if server_args.api_key:
         add_api_key_middleware(app, server_args.api_key)
 
+    # add prometheus middleware
+    if server_args.enable_metrics:
+        add_prometheus_middleware(app)
+        enable_func_timer()
+
     # Send a warmup request
     t = threading.Thread(
         target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
@@ -475,6 +502,10 @@ def _set_envs_and_config(server_args: ServerArgs):
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
     os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
 
+    # Set prometheus env vars
+    if server_args.enable_metrics:
+        set_prometheus_multiproc_dir()
+
     # Set ulimit
     set_ulimit()
 
@@ -523,6 +554,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         return
 
     model_info = res.json()
+
     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
     max_new_tokens = 8 if model_info["is_generation"] else 1
@@ -560,6 +592,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("ready")
 
+    if server_args.delete_ckpt_after_loading:
+        delete_directory(server_args.model_path)
+
 
 class Runtime:
     """
@@ -720,12 +755,12 @@ class Engine:
 
         # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
         atexit.register(self.shutdown)
-
+
         # runtime server default log level is log
         # offline engine works in scripts, so we set it to error
 
-        if
-        kwargs[
+        if "log_level" not in kwargs:
+            kwargs["log_level"] = "error"
 
         server_args = ServerArgs(*args, **kwargs)
         launch_engine(server_args=server_args)
@@ -734,7 +769,7 @@ class Engine:
         self,
         # The input prompt. It can be a single prompt or a batch of prompts.
         prompt: Optional[Union[List[str], str]] = None,
-        sampling_params: Optional[Dict] = None,
+        sampling_params: Optional[Union[List[Dict], Dict]] = None,
         # The token ids for text; one can either specify text or input_ids.
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
@@ -840,4 +875,12 @@ class Engine:
         else:
             return tokenizer_manager.tokenizer
 
-
+    def encode(
+        self,
+        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
+    ):
+        obj = EmbeddingReqInput(text=prompt)
+
+        # get the current event loop
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(encode_request(obj, None))
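Note: the hunks above add a synchronous `encode()` entry point to the offline `Engine` and register the new `/classify` endpoints. A minimal usage sketch of the new `Engine.encode()` is shown below; the model path and the `is_embedding` argument are illustrative assumptions, not values taken from this diff.

    # Hypothetical sketch of driving the new Engine.encode() from a script.
    # Assumptions: the model path is a placeholder and is_embedding is an
    # existing ServerArgs field; neither is defined by this diff.
    from sglang.srt.server import Engine

    if __name__ == "__main__":
        engine = Engine(
            model_path="Alibaba-NLP/gte-Qwen2-1.5B-instruct",  # placeholder model
            is_embedding=True,
        )

        # encode() wraps the prompt in an EmbeddingReqInput and runs
        # encode_request() to completion on the current event loop.
        result = engine.encode("SGLang is a fast serving framework.")
        print(result)

        engine.shutdown()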
sglang/srt/server_args.py
CHANGED
@@ -22,7 +22,12 @@ import random
 import tempfile
 from typing import List, Optional
 
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    get_gpu_memory_capacity,
+    is_flashinfer_available,
+    is_ipv6,
+    is_port_available,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -63,25 +68,27 @@ class ServerArgs:
     stream_interval: int = 1
     random_seed: Optional[int] = None
     constrained_json_whitespace_pattern: Optional[str] = None
-
+    watchdog_timeout: float = 300
+    download_dir: Optional[str] = None
 
     # Logging
     log_level: str = "info"
     log_level_http: Optional[str] = None
     log_requests: bool = False
     show_time_cost: bool = False
+    enable_metrics: bool = False
+    decode_log_interval: int = 40
 
-    #
+    # API related
     api_key: Optional[str] = None
     file_storage_pth: str = "SGLang_storage"
     enable_cache_report: bool = False
-    watchdog_timeout: float = 600
 
     # Data parallelism
     dp_size: int = 1
     load_balance_method: str = "round_robin"
 
-    #
+    # Multi-node distributed serving
     dist_init_addr: Optional[str] = None
     nnodes: int = 1
     node_rank: int = 0
@@ -110,7 +117,7 @@ class ServerArgs:
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
     disable_radix_cache: bool = False
-
+    disable_jump_forward: bool = False
     disable_cuda_graph: bool = False
     disable_cuda_graph_padding: bool = False
     disable_disk_cache: bool = False
@@ -127,6 +134,7 @@ class ServerArgs:
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
     num_continuous_decode_steps: int = 1
+    delete_ckpt_after_loading: bool = False
 
     def __post_init__(self):
         # Set missing default values
@@ -140,6 +148,9 @@ class ServerArgs:
             # Disable chunked prefill
             self.chunked_prefill_size = None
 
+        if self.random_seed is None:
+            self.random_seed = random.randint(0, 1 << 30)
+
         # Mem fraction depends on the tensor parallelism size
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
@@ -153,8 +164,14 @@ class ServerArgs:
             else:
                 self.mem_fraction_static = 0.88
 
-
-
+        # Adjust for GPUs with small memory capacities
+        gpu_mem = get_gpu_memory_capacity()
+        if gpu_mem < 25000:
+            logger.warning(
+                "Automatically adjust --chunked-prefill-size for small GPUs."
+            )
+            self.chunked_prefill_size //= 4  # make it 2048
+            self.cuda_graph_max_bs = 4
 
         # Deprecation warnings
         if self.disable_flashinfer:
@@ -204,6 +221,7 @@ class ServerArgs:
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
+        # Model and port args
         parser.add_argument(
             "--model-path",
             type=str,
@@ -323,6 +341,8 @@ class ServerArgs:
             action="store_true",
             help="Whether to use a CausalLM as an embedding model.",
         )
+
+        # Memory and scheduling
         parser.add_argument(
             "--mem-fraction-static",
             type=float,
@@ -367,6 +387,8 @@ class ServerArgs:
             default=ServerArgs.schedule_conservativeness,
             help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
         )
+
+        # Other runtime options
         parser.add_argument(
             "--tensor-parallel-size",
             "--tp-size",
@@ -392,6 +414,20 @@ class ServerArgs:
             default=ServerArgs.constrained_json_whitespace_pattern,
             help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
         )
+        parser.add_argument(
+            "--watchdog-timeout",
+            type=float,
+            default=ServerArgs.watchdog_timeout,
+            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
+        )
+        parser.add_argument(
+            "--download-dir",
+            type=str,
+            default=ServerArgs.download_dir,
+            help="Model download directory.",
+        )
+
+        # Logging
         parser.add_argument(
             "--log-level",
             type=str,
@@ -414,6 +450,19 @@ class ServerArgs:
             action="store_true",
             help="Show time cost of custom marks.",
         )
+        parser.add_argument(
+            "--enable-metrics",
+            action="store_true",
+            help="Enable log prometheus metrics.",
+        )
+        parser.add_argument(
+            "--decode-log-interval",
+            type=int,
+            default=ServerArgs.decode_log_interval,
+            help="The log interval of decode batch",
+        )
+
+        # API related
         parser.add_argument(
             "--api-key",
             type=str,
@@ -431,18 +480,6 @@ class ServerArgs:
             action="store_true",
             help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
         )
-        parser.add_argument(
-            "--watchdog-timeout",
-            type=float,
-            default=ServerArgs.watchdog_timeout,
-            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
-        )
-        parser.add_argument(
-            "--decode-log-interval",
-            type=int,
-            default=ServerArgs.decode_log_interval,
-            help="The log interval of decode batch"
-        )
 
         # Data parallelism
         parser.add_argument(
@@ -463,7 +500,7 @@ class ServerArgs:
             ],
         )
 
-        # Multi-node distributed serving
+        # Multi-node distributed serving
         parser.add_argument(
             "--dist-init-addr",
             "--nccl-init-addr",  # For backward compatbility. This will be removed in the future.
@@ -558,7 +595,7 @@ class ServerArgs:
             type=str,
             choices=["xgrammar", "outlines"],
             default=ServerArgs.grammar_backend,
-            help="Choose the backend for
+            help="Choose the backend for grammar-guided decoding.",
         )
 
         # Optimization/debug options
@@ -578,9 +615,9 @@ class ServerArgs:
             help="Disable RadixAttention for prefix caching.",
         )
         parser.add_argument(
-            "--disable-
+            "--disable-jump-forward",
             action="store_true",
-            help="Disable
+            help="Disable jump-forward for grammar-guided decoding.",
         )
         parser.add_argument(
             "--disable-cuda-graph",
@@ -600,7 +637,6 @@ class ServerArgs:
         parser.add_argument(
             "--disable-custom-all-reduce",
             action="store_true",
-            default=False,
             help="Disable the custom all-reduce kernel and fall back to NCCL.",
         )
         parser.add_argument(
@@ -670,6 +706,11 @@ class ServerArgs:
             "This can potentially increase throughput but may also increase time-to-first-token latency. "
             "The default value is 1, meaning only run one decoding step at a time.",
         )
+        parser.add_argument(
+            "--delete-ckpt-after-loading",
+            action="store_true",
+            help="Delete the model checkpoint after loading the model.",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
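Note: the hunks above add `--enable-metrics`, `--download-dir`, `--decode-log-interval`, and `--delete-ckpt-after-loading`, move `--watchdog-timeout` (with a new 300 s default), and rename the jump-forward flag. A minimal sketch of parsing the reworked CLI surface is shown below; the model path and download directory are placeholders, not values from this diff.

    # Sketch of building the CLI parser defined above and parsing the new flags.
    import argparse

    from sglang.srt.server_args import ServerArgs

    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)

    args = parser.parse_args(
        [
            "--model-path", "meta-llama/Llama-3.1-8B-Instruct",  # placeholder
            "--download-dir", "/tmp/sglang-models",              # new flag
            "--enable-metrics",                                  # new flag
            "--watchdog-timeout", "300",                         # moved; default now 300s
            "--disable-jump-forward",                            # renamed flag
        ]
    )

    # Note: __post_init__ now calls get_gpu_memory_capacity(), which shells out
    # to nvidia-smi, so constructing ServerArgs requires a visible NVIDIA GPU.
    server_args = ServerArgs.from_cli_args(args)
    print(server_args.enable_metrics, server_args.watchdog_timeout)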
sglang/srt/utils.py
CHANGED
@@ -22,8 +22,13 @@ import logging
 import os
 import pickle
 import random
+import re
 import resource
+import shutil
+import signal
 import socket
+import subprocess
+import tempfile
 import time
 import warnings
 from importlib.metadata import PackageNotFoundError, version
@@ -35,9 +40,11 @@ import psutil
 import requests
 import torch
 import torch.distributed as dist
+import triton
 import zmq
 from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
+from starlette.routing import Mount
 from torch import nn
 from torch.profiler import ProfilerActivity, profile, record_function
 from triton.runtime.cache import (
@@ -379,6 +386,10 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
     if include_self:
         try:
             itself.kill()
+
+            # Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes),
+            # so we send an additional signal to kill them.
+            itself.send_signal(signal.SIGINT)
         except psutil.NoSuchProcess:
             pass
 
@@ -704,3 +715,112 @@ def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint:
         raise ValueError(f"Unsupported socket type: {socket_type}")
 
     return socket
+
+
+def dump_to_file(dirpath, name, value):
+    from vllm.distributed import get_tensor_model_parallel_rank
+
+    if get_tensor_model_parallel_rank() != 0:
+        return
+
+    os.makedirs(dirpath, exist_ok=True)
+    if value.dtype is torch.bfloat16:
+        value = value.float()
+    value = value.cpu().numpy()
+    output_filename = os.path.join(dirpath, f"pytorch_dump_{name}.npy")
+    logger.info(f"Dump a tensor to {output_filename}. Shape = {value.shape}")
+    np.save(output_filename, value)
+
+
+def is_triton_3():
+    return triton.__version__.startswith("3.")
+
+
+def maybe_torch_compile(*args, **kwargs):
+    """
+    torch.compile does not work for triton 2.2.0, which is needed in xlm1's jax.
+    Therefore, we disable it here.
+    """
+
+    def decorator(func):
+        if is_triton_3():
+            return torch.compile(*args, **kwargs)(func)
+        return func
+
+    return decorator
+
+
+def delete_directory(dirpath):
+    try:
+        # This will remove the directory and all its contents
+        shutil.rmtree(dirpath)
+    except OSError as e:
+        print(f"Warning: {dirpath} : {e.strerror}")
+
+
+# Temporary directory for prometheus multiprocess mode
+# Cleaned up automatically when this object is garbage collected
+prometheus_multiproc_dir: tempfile.TemporaryDirectory
+
+
+def set_prometheus_multiproc_dir():
+    # Set prometheus multiprocess directory
+    # sglang uses prometheus multiprocess mode
+    # we need to set this before importing prometheus_client
+    # https://prometheus.github.io/client_python/multiprocess/
+    global prometheus_multiproc_dir
+
+    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
+        logger.debug("User set PROMETHEUS_MULTIPROC_DIR detected.")
+        prometheus_multiproc_dir = tempfile.TemporaryDirectory(
+            dir=os.environ["PROMETHEUS_MULTIPROC_DIR"]
+        )
+    else:
+        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
+        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
+    logger.debug(f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")
+
+
+def add_prometheus_middleware(app):
+    # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
+    from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
+
+    registry = CollectorRegistry()
+    multiprocess.MultiProcessCollector(registry)
+    metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
+
+    # Workaround for 307 Redirect for /metrics
+    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
+    app.routes.append(metrics_route)
+
+
+def get_gpu_memory_capacity():
+    try:
+        # Run nvidia-smi and capture the output
+        result = subprocess.run(
+            ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            text=True,
+        )
+
+        if result.returncode != 0:
+            raise RuntimeError(f"nvidia-smi error: {result.stderr.strip()}")
+
+        # Parse the output to extract memory values
+        memory_values = [
+            float(mem)
+            for mem in result.stdout.strip().split("\n")
+            if re.match(r"^\d+(\.\d+)?$", mem.strip())
+        ]
+
+        if not memory_values:
+            raise ValueError("No GPU memory values found.")
+
+        # Return the minimum memory value
+        return min(memory_values)
+
+    except FileNotFoundError:
+        raise RuntimeError(
+            "nvidia-smi not found. Ensure NVIDIA drivers are installed and accessible."
+        )
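Note: the new `set_prometheus_multiproc_dir()` and `add_prometheus_middleware()` helpers above are wired into the server when `--enable-metrics` is set. A minimal sketch of combining them on a bare FastAPI app, outside sglang's own launch path, is shown below; the port in the trailing comment is a placeholder.

    # Sketch: the multiproc dir must be set before prometheus_client is imported,
    # per the helper's own comment, so call it first.
    from fastapi import FastAPI

    from sglang.srt.utils import add_prometheus_middleware, set_prometheus_multiproc_dir

    set_prometheus_multiproc_dir()  # sets PROMETHEUS_MULTIPROC_DIR for multiprocess mode

    app = FastAPI()
    add_prometheus_middleware(app)  # mounts /metrics backed by a MultiProcessCollector

    # e.g. uvicorn.run(app, port=30000), then scrape http://localhost:30000/metrics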
sglang/test/simple_eval_common.py
CHANGED
@@ -320,7 +320,7 @@ jinja_env = jinja2.Environment(
 _message_template = """
 <div class="message {{ role }}">
     <div class="role">
-    {{ role }}
+    {{ role }}
     {% if variant %}<span class="variant">({{ variant }})</span>{% endif %}
     </div>
     <div class="content">
sglang/test/simple_eval_humaneval.py
CHANGED
@@ -2,8 +2,8 @@
 
 """
 HumanEval: Evaluating Large Language Models Trained on Code
-Mark Chen and Jerry Tworek and Heewoo Jun and Qiming Yuan and Henrique Ponde de Oliveira Pinto and Jared Kaplan and Harri Edwards and Yuri Burda and Nicholas Joseph and Greg Brockman and Alex Ray and Raul Puri and Gretchen Krueger and Michael Petrov and Heidy Khlaaf and Girish Sastry and Pamela Mishkin and Brooke Chan and Scott Gray and Nick Ryder and Mikhail Pavlov and Alethea Power and Lukasz Kaiser and Mohammad Bavarian and Clemens Winter and Philippe Tillet and Felipe Petroski Such and Dave Cummings and Matthias Plappert and Fotios Chantzis and Elizabeth Barnes and Ariel Herbert-Voss and William Hebgen Guss and Alex Nichol and Alex Paino and Nikolas Tezak and Jie Tang and Igor Babuschkin and Suchir Balaji and Shantanu Jain and William Saunders and Christopher Hesse and Andrew N. Carr and Jan Leike and Josh Achiam and Vedant Misra and Evan Morikawa and Alec Radford and Matthew Knight and Miles Brundage and Mira Murati and Katie Mayer and Peter Welinder and Bob McGrew and Dario Amodei and Sam McCandlish and Ilya Sutskever and Wojciech Zaremba
-https://arxiv.org/abs/2107.03374 https://github.com/openai/human-eval/
+Mark Chen and Jerry Tworek and Heewoo Jun and Qiming Yuan and Henrique Ponde de Oliveira Pinto and Jared Kaplan and Harri Edwards and Yuri Burda and Nicholas Joseph and Greg Brockman and Alex Ray and Raul Puri and Gretchen Krueger and Michael Petrov and Heidy Khlaaf and Girish Sastry and Pamela Mishkin and Brooke Chan and Scott Gray and Nick Ryder and Mikhail Pavlov and Alethea Power and Lukasz Kaiser and Mohammad Bavarian and Clemens Winter and Philippe Tillet and Felipe Petroski Such and Dave Cummings and Matthias Plappert and Fotios Chantzis and Elizabeth Barnes and Ariel Herbert-Voss and William Hebgen Guss and Alex Nichol and Alex Paino and Nikolas Tezak and Jie Tang and Igor Babuschkin and Suchir Balaji and Shantanu Jain and William Saunders and Christopher Hesse and Andrew N. Carr and Jan Leike and Josh Achiam and Vedant Misra and Evan Morikawa and Alec Radford and Matthew Knight and Miles Brundage and Mira Murati and Katie Mayer and Peter Welinder and Bob McGrew and Dario Amodei and Sam McCandlish and Ilya Sutskever and Wojciech Zaremba
+https://arxiv.org/abs/2107.03374 https://github.com/openai/human-eval/
 """
 
 import random
|
sglang/test/simple_eval_mgsm.py
CHANGED
@@ -1,10 +1,10 @@
 # Adapted from https://github.com/openai/simple-evals/
 
 """
-MGSM: Multilingual Grade School Math Benchmark (MGSM) is a benchmark of grade-school math problems.
+MGSM: Multilingual Grade School Math Benchmark (MGSM) is a benchmark of grade-school math problems.
 Language Models are Multilingual Chain-of-Thought Reasoners
 Freda Shi, Mirac Suzgun, Markus Freitag, Xuezhi Wang, Suraj Srivats, Soroush Vosoughi, Hyung Won Chung, Yi Tay, Sebastian Ruder, Denny Zhou, Dipanjan Das, Jason Wei
-https://arxiv.org/abs/2210.03057 reference: https://github.com/google-research/url-nlp
+https://arxiv.org/abs/2210.03057 reference: https://github.com/google-research/url-nlp
 """
 
 import re