sglang 0.1.17__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +4 -4
- sglang/backend/litellm.py +2 -2
- sglang/backend/openai.py +26 -15
- sglang/bench_latency.py +299 -0
- sglang/global_config.py +4 -1
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +1 -1
- sglang/lang/ir.py +15 -5
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +13 -6
- sglang/srt/constrained/fsm_cache.py +6 -3
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +2 -0
- sglang/srt/hf_transformers_utils.py +64 -9
- sglang/srt/layers/fused_moe.py +186 -89
- sglang/srt/layers/logits_processor.py +53 -25
- sglang/srt/layers/radix_attention.py +34 -7
- sglang/srt/managers/controller/dp_worker.py +6 -3
- sglang/srt/managers/controller/infer_batch.py +142 -67
- sglang/srt/managers/controller/manager_multi.py +5 -5
- sglang/srt/managers/controller/manager_single.py +8 -3
- sglang/srt/managers/controller/model_runner.py +154 -54
- sglang/srt/managers/controller/radix_cache.py +4 -0
- sglang/srt/managers/controller/schedule_heuristic.py +2 -0
- sglang/srt/managers/controller/tp_worker.py +140 -135
- sglang/srt/managers/detokenizer_manager.py +15 -19
- sglang/srt/managers/io_struct.py +10 -4
- sglang/srt/managers/tokenizer_manager.py +14 -13
- sglang/srt/model_config.py +83 -4
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/grok.py +204 -137
- sglang/srt/models/llama2.py +11 -4
- sglang/srt/models/llama_classification.py +104 -0
- sglang/srt/models/llava.py +11 -8
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/mixtral.py +164 -115
- sglang/srt/models/mixtral_quant.py +0 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/models/yivl.py +2 -2
- sglang/srt/openai_api_adapter.py +33 -23
- sglang/srt/openai_protocol.py +1 -1
- sglang/srt/server.py +60 -19
- sglang/srt/server_args.py +79 -44
- sglang/srt/utils.py +146 -37
- sglang/test/test_programs.py +28 -10
- sglang/utils.py +4 -3
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/METADATA +29 -22
- sglang-0.1.18.dist-info/RECORD +78 -0
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
- sglang/srt/managers/router/infer_batch.py +0 -596
- sglang/srt/managers/router/manager.py +0 -82
- sglang/srt/managers/router/model_rpc.py +0 -818
- sglang/srt/managers/router/model_runner.py +0 -445
- sglang/srt/managers/router/radix_cache.py +0 -267
- sglang/srt/managers/router/scheduler.py +0 -59
- sglang-0.1.17.dist-info/RECORD +0 -81
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/__init__.py
CHANGED
@@ -1,4 +1,4 @@
-__version__ = "0.1.17"
+__version__ = "0.1.18"
 
 # SGL API Components
 from sglang.api import (
@@ -24,10 +24,10 @@ from sglang.api import (
 
 # SGL Backends
 from sglang.backend.anthropic import Anthropic
+from sglang.backend.litellm import LiteLLM
 from sglang.backend.openai import OpenAI
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.backend.vertexai import VertexAI
-from sglang.backend.litellm import LiteLLM
 
 # Global Configurations
 from sglang.global_config import global_config
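Besides the version bump, 0.1.18 keeps the LiteLLM backend exported from the package root (the import only moves up next to the other backends). A minimal usage sketch, assuming `set_default_backend` is among the `sglang.api` re-exports truncated above and using a placeholder model name:

```python
import sglang as sgl

# model_name is the first positional argument of the LiteLLM backend (see backend/litellm.py below).
backend = sgl.LiteLLM("gpt-3.5-turbo")
sgl.set_default_backend(backend)

print(sgl.__version__)  # "0.1.18"
```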
sglang/api.py
CHANGED
@@ -1,4 +1,4 @@
-"""
+"""Public APIs of the language."""
 
 import os
 import re
@@ -43,14 +43,14 @@ def set_default_backend(backend: BaseBackend):
     global_config.default_backend = backend
 
 
-def flush_cache(backend: BaseBackend = None):
+def flush_cache(backend: Optional[BaseBackend] = None):
     backend = backend or global_config.default_backend
     if backend is None:
         return False
     return backend.flush_cache()
 
 
-def get_server_args(backend: BaseBackend = None):
+def get_server_args(backend: Optional[BaseBackend] = None):
     backend = backend or global_config.default_backend
     if backend is None:
         return None
@@ -158,7 +158,7 @@ def video(path: str, num_frames: int):
 
 def select(
     name: Optional[str] = None,
-    choices: List[str] = None,
+    choices: Optional[List[str]] = None,
     temperature: float = 0.0,
 ):
     assert choices is not None
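The `Optional[...]` annotations only make the existing optional-backend behaviour explicit; both helpers keep falling back to `global_config.default_backend`. A short sketch, with a placeholder endpoint URL and assuming a running SGLang server:

```python
import sglang as sgl
from sglang.api import flush_cache, get_server_args

# With an explicit backend, the calls are forwarded to that backend.
runtime = sgl.RuntimeEndpoint("http://localhost:30000")  # placeholder URL
flush_cache(runtime)
print(get_server_args(runtime))

# With no argument, the helpers use global_config.default_backend and
# return False / None when no default backend has been set.
flush_cache()
```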
sglang/backend/litellm.py
CHANGED
@@ -13,7 +13,6 @@ except ImportError as e:
 
 
 class LiteLLM(BaseBackend):
-
     def __init__(
         self,
         model_name,
@@ -33,7 +32,8 @@ class LiteLLM(BaseBackend):
         self.model_name = model_name
 
         self.chat_template = chat_template or get_chat_template_by_model_path(
-            model_name
+            model_name
+        )
 
         self.client_params = {
             "api_key": api_key,
sglang/backend/openai.py
CHANGED
@@ -1,7 +1,7 @@
+import dataclasses
 import logging
 import time
 import warnings
-import dataclasses
 from typing import Callable, List, Optional, Union
 
 import numpy as np
@@ -105,14 +105,16 @@ class OpenAI(BaseBackend):
     def get_chat_template(self):
         return self.chat_template
 
-    def _prepare_spec_execution(
-
+    def _prepare_spec_execution(
+        self,
+        sampling_params: SglSamplingParams,
+        num_api_spec_tokens: int,
+        spec_var_name: str,
+    ):
         if "max_tokens" not in self.spec_kwargs:
             self.spec_kwargs["max_tokens"] = num_api_spec_tokens
         else:
-            assert (
-                self.spec_kwargs["max_tokens"] == num_api_spec_tokens
-            )
+            assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens
 
         params = sampling_params.to_openai_kwargs()
         for key, value in params.items():
@@ -151,8 +153,9 @@ class OpenAI(BaseBackend):
                     )
                     prompt = s.messages_
                 else:
-                    return self._prepare_spec_execution(
-                        s.num_api_spec_tokens, spec_var_name
+                    return self._prepare_spec_execution(
+                        sampling_params, s.num_api_spec_tokens, spec_var_name
+                    )
             else:
                 prompt = s.text_
 
@@ -325,7 +328,7 @@ class OpenAI(BaseBackend):
         ret_str = ret.choices[0].text
         ret_token = self.tokenizer.encode(ret_str)[0]
         self.token_usage.prompt_tokens += ret.usage.prompt_tokens
-        self.token_usage.completion_tokens= ret.usage.completion_tokens
+        self.token_usage.completion_tokens = ret.usage.completion_tokens
 
         # TODO:
         # 1. return logits as the scores
@@ -355,7 +358,9 @@ class OpenAI(BaseBackend):
         return decision, scores, None, None
 
 
-def openai_completion(
+def openai_completion(
+    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
+):
     for attempt in range(retries):
         try:
             if is_chat:
@@ -385,15 +390,19 @@ def openai_completion_stream(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
     return comp
 
 
-def openai_completion_stream(
+def openai_completion_stream(
+    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
+):
     for attempt in range(retries):
         try:
             if is_chat:
                 if "stop" in kwargs and kwargs["stop"] is None:
                     kwargs.pop("stop")
                 generator = client.chat.completions.create(
-                    messages=prompt,
-
+                    messages=prompt,
+                    stream=True,
+                    stream_options={"include_usage": True},
+                    **kwargs,
                 )
                 for ret in generator:
                     if len(ret.choices) == 0:
@@ -405,8 +414,10 @@ def openai_completion_stream(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
                 yield content or "", {}
             else:
                 generator = client.completions.create(
-                    prompt=prompt,
-
+                    prompt=prompt,
+                    stream=True,
+                    stream_options={"include_usage": True},
+                    **kwargs,
                 )
                 for ret in generator:
                     if len(ret.choices) == 0:
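The streaming helpers now pass `stream_options={"include_usage": True}`, which makes the API emit a final chunk with no choices but a populated `usage` field (hence the `len(ret.choices) == 0` guards above). A standalone sketch against the OpenAI Python SDK v1 client, with a placeholder model and prompt:

```python
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment
stream = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
    stream_options={"include_usage": True},
)

for chunk in stream:
    if len(chunk.choices) == 0:
        # Usage-only chunk at the end of the stream.
        print(chunk.usage)
        continue
    print(chunk.choices[0].delta.content or "", end="")
```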
sglang/bench_latency.py
ADDED
@@ -0,0 +1,299 @@
+"""
+Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
+
+# Usage (latency test):
+python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
+
+# Usage (correctness test):
+python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
+
+### Reference output:
+prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
+[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
+[ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]],
+device='cuda:0', dtype=torch.float16)
+prefill logits (final) tensor([[-8.3203, -7.1211, 3.3379, ..., -4.9570, -4.1328, -3.4141],
+[-8.9062, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0742],
+[-9.6328, -9.0547, 4.0117, ..., -5.3047, -4.7148, -4.4609]],
+device='cuda:0', dtype=torch.float16)
+<s> The capital of France is.
+The capital of the United States is Washington, D.C.
+
+<s> The capital of the United Kindom is.
+The capital of the United Kingdom is London.
+The capital of the
+<s> Today is a sunny day and I like go for a walk in the park.
+I'm going to the park
+"""
+
+import argparse
+import dataclasses
+import logging
+import multiprocessing
+import time
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req
+from sglang.srt.managers.controller.model_runner import ModelRunner
+from sglang.srt.model_config import ModelConfig
+from sglang.srt.sampling_params import SamplingParams
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import suppress_other_loggers
+
+
+@dataclasses.dataclass
+class BenchArgs:
+    batch_size: int = 1
+    input_len: int = 1024
+    output_len: int = 4
+    correctness_test: bool = False
+    # This is only used for correctness test
+    cut_len: int = 4
+
+    @staticmethod
+    def add_cli_args(parser: argparse.ArgumentParser):
+        parser.add_argument("--batch-size", type=int, default=BenchArgs.batch_size)
+        parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
+        parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
+        parser.add_argument("--correctness-test", action="store_true")
+        parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace):
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        return cls(**{attr: getattr(args, attr) for attr in attrs})
+
+
+def load_model(server_args, tp_rank):
+    suppress_other_loggers()
+
+    model_config = ModelConfig(path=server_args.model_path)
+    model_runner = ModelRunner(
+        model_config=model_config,
+        mem_fraction_static=server_args.mem_fraction_static,
+        gpu_id=tp_rank,
+        tp_rank=tp_rank,
+        tp_size=server_args.tp_size,
+        nccl_port=28888,
+        server_args=server_args,
+    )
+    print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
+    tokenizer = get_tokenizer(
+        server_args.tokenizer_path,
+        tokenizer_mode=server_args.tokenizer_mode,
+        trust_remote_code=server_args.trust_remote_code,
+    )
+    if server_args.tp_size > 1:
+        dist.barrier()
+    return model_runner, tokenizer
+
+
+def prepare_inputs(bench_args, tokenizer):
+    prompts = [
+        "The capital of France is",
+        "The capital of the United Kindom is",
+        "Today is a sunny day and I like",
+    ]
+    input_ids = [tokenizer.encode(p) for p in prompts]
+    sampling_params = SamplingParams(
+        temperature=0,
+        max_new_tokens=BenchArgs.output_len,
+    )
+
+    reqs = []
+    for i in range(len(prompts)):
+        assert len(input_ids[i]) > bench_args.cut_len
+
+        tmp_input_ids = input_ids[i][:bench_args.cut_len]
+        req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
+        req.prefix_indices = []
+        req.sampling_params = sampling_params
+        req.input_ids = req.origin_input_ids
+        reqs.append(req)
+
+    return input_ids, reqs
+
+
+def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
+    for i in range(len(reqs)):
+        req = reqs[i]
+        req.input_ids += input_ids[i][bench_args.cut_len:]
+        req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
+            i, :bench_args.cut_len
+        ]
+    return reqs
+
+
+def prepare_synthetic_inputs(bench_args, tokenizer):
+    input_ids = np.ones((bench_args.batch_size, bench_args.input_len), dtype=np.int32)
+    sampling_params = SamplingParams(
+        temperature=0,
+        max_new_tokens=BenchArgs.output_len,
+    )
+
+    reqs = []
+    for i in range(len(input_ids)):
+        req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
+        req.prefix_indices = []
+        req.sampling_params = sampling_params
+        req.input_ids = req.origin_input_ids
+        reqs.append(req)
+
+    return reqs
+
+
+def extend(reqs, model_runner):
+    batch = Batch.init_new(
+        reqs=reqs,
+        req_to_token_pool=model_runner.req_to_token_pool,
+        token_to_kv_pool=model_runner.token_to_kv_pool,
+        tree_cache=None)
+    batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
+    output = model_runner.forward(batch, ForwardMode.EXTEND)
+    next_token_ids, _ = batch.sample(output.next_token_logits)
+    return next_token_ids, output.next_token_logits, batch
+
+
+def decode(input_token_ids, batch, model_runner):
+    batch.prepare_for_decode(input_token_ids.cpu().numpy())
+    output = model_runner.forward(batch, ForwardMode.DECODE)
+    next_token_ids, _ = batch.sample(output.next_token_logits)
+    return next_token_ids, output.next_token_logits
+
+
+def correctness_test(
+    server_args,
+    bench_args,
+    tp_rank,
+):
+    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+    # Load the model
+    model_runner, tokenizer = load_model(server_args, tp_rank)
+
+    # Prepare inputs
+    input_ids, reqs = prepare_inputs(bench_args, tokenizer)
+
+    # Prefill
+    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
+    rank_print("prefill logits (first half)", next_token_logits)
+
+    # Prepare extend inputs
+    reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner)
+
+    # Extend
+    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
+    rank_print("prefill logits (final)", next_token_logits)
+
+    # Decode
+    output_ids = [list(req.input_ids) for req in reqs]
+    for _ in range(bench_args.output_len):
+        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+        for i in range(len(reqs)):
+            output_ids[i].append(next_token_ids[i])
+
+    # Print
+    for i in range(len(reqs)):
+        print(tokenizer.decode(output_ids[i]))
+
+
+def latency_test(
+    server_args,
+    bench_args,
+    tp_rank,
+):
+    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+    # Load the model
+    model_runner, tokenizer = load_model(server_args, tp_rank)
+    print(f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}")
+
+    # Prepare inputs
+    reqs = prepare_synthetic_inputs(bench_args, tokenizer)
+
+    def clear():
+        model_runner.req_to_token_pool.clear()
+        model_runner.token_to_kv_pool.clear()
+
+    @torch.inference_mode()
+    def run_once(output_len):
+        # Prefill
+        torch.cuda.synchronize()
+        tot_latency = 0
+        tic = time.time()
+        next_token_ids, _, batch = extend(reqs, model_runner)
+        torch.cuda.synchronize()
+        prefill_latency = time.time() - tic
+        tot_latency += prefill_latency
+        throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
+        rank_print(f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s")
+
+        # Decode
+        for i in range(output_len):
+            torch.cuda.synchronize()
+            tic = time.time()
+            next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+            torch.cuda.synchronize()
+            latency = time.time() - tic
+            tot_latency += latency
+            throughput = bench_args.batch_size / latency
+            if i < 5: rank_print(f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s")
+        avg_decode_latency = (tot_latency - prefill_latency) / output_len
+        avg_decode_throughput = bench_args.batch_size / avg_decode_latency
+        rank_print(f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s")
+
+        throughput = (bench_args.input_len + bench_args.output_len) * bench_args.batch_size / tot_latency
+        rank_print(f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s")
+
+    # Warm up
+    run_once(4)
+    clear()
+
+    # Run again
+    run_once(bench_args.output_len)
+
+
+def main(server_args, bench_args):
+    print(bench_args)
+
+    if bench_args.correctness_test:
+        work_func = correctness_test
+    else:
+        work_func = latency_test
+
+    workers = []
+    for tp_rank in range(server_args.tp_size):
+        proc = multiprocessing.Process(
+            target=work_func,
+            args=(
+                server_args,
+                bench_args,
+                tp_rank,
+            ),
+        )
+        proc.start()
+        workers.append(proc)
+
+    for proc in workers:
+        proc.join()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    BenchArgs.add_cli_args(parser)
+    args = parser.parse_args()
+
+    server_args = ServerArgs.from_cli_args(args)
+    bench_args = BenchArgs.from_cli_args(args)
+
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format="%(message)s",
+    )
+
+    main(server_args, bench_args)
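A sketch of driving the new benchmark programmatically instead of via `python -m sglang.bench_latency`; it mirrors the `__main__` block above, and the flag values are illustrative only:

```python
import argparse

from sglang.bench_latency import BenchArgs, main
from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
BenchArgs.add_cli_args(parser)
args = parser.parse_args([
    "--model-path", "meta-llama/Meta-Llama-3-8B-Instruct",
    "--load-format", "dummy",   # random weights, as in the docstring's latency example
    "--batch-size", "8",
    "--input-len", "512",
    "--output-len", "32",
])

main(ServerArgs.from_cli_args(args), BenchArgs.from_cli_args(args))
```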
sglang/global_config.py
CHANGED
@@ -27,7 +27,7 @@ class GlobalConfig:
 
         # Request dependency time due to network delay
         self.request_dependency_delay = 0.02
-        self.wait_for_new_request_delay = 0.
+        self.wait_for_new_request_delay = 0.0006
 
         # New generation token ratio estimation
         self.base_new_token_ratio = 0.4
@@ -35,5 +35,8 @@ class GlobalConfig:
         self.new_token_ratio_decay = 0.0001
         self.new_token_ratio_recovery = 0.05
 
+        # The threshold (number of tokens) to trigger layer-wise cuda sync.
+        # This can improve the speed for large batch sizes during prefill.
+        self.layer_sync_threshold = 8192
 
 global_config = GlobalConfig()
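The new `layer_sync_threshold` lives on the shared `global_config` singleton. A hedged sketch of the kind of check a model runner would presumably perform with it during prefill; `num_prefill_tokens` is illustrative and not taken from this diff:

```python
import torch

from sglang.global_config import global_config

num_prefill_tokens = 16384  # illustrative: total tokens in the current prefill batch

# Per the comment above, very large prefill batches benefit from a layer-wise CUDA sync.
if num_prefill_tokens > global_config.layer_sync_threshold:
    torch.cuda.synchronize()
```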
sglang/lang/compiler.py
CHANGED
@@ -4,7 +4,7 @@ from queue import Queue
 from typing import List, Union
 
 from sglang.global_config import global_config
-from sglang.lang.interpreter import ProgramState, StreamExecutor,
+from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
 from sglang.lang.ir import (
     SglArgument,
     SglConstantText,
@@ -184,7 +184,7 @@ class CompiledFunction:
 
         # Extract prefix by tracing and cache it
         if len(batch_kwargs) > 1:
-
+            cache_program(self.function, backend)
 
         # Run all programs
         if num_threads == "auto":
sglang/lang/interpreter.py
CHANGED
@@ -507,7 +507,7 @@ class StreamExecutor:
             )
             return
 
-        else:
+        else:  # Speculative execution on models with completion interface
             comp, meta_info = self._spec_gen(sampling_params)
 
             self.text_ += comp
sglang/lang/ir.py
CHANGED
@@ -81,12 +81,10 @@ class SglSamplingParams:
             "top_p": self.top_p,
             "top_k": self.top_k,
         }
-
+
     def to_litellm_kwargs(self):
         if self.regex is not None:
-            warnings.warn(
-                "Regular expression is not supported in the LiteLLM backend."
-            )
+            warnings.warn("Regular expression is not supported in the LiteLLM backend.")
         return {
             "max_tokens": self.max_new_tokens,
             "stop": self.stop or None,
@@ -122,6 +120,7 @@ class SglFunction:
         argspec = inspect.getfullargspec(func)
         assert argspec.args[0] == "s", 'The first argument must be "s"'
         self.arg_names = argspec.args[1:]
+        self.arg_defaults = argspec.defaults if argspec.defaults is not None else []
 
     def bind(self, **kwargs):
         assert all(key in self.arg_names for key in kwargs)
@@ -180,7 +179,18 @@ class SglFunction:
         assert isinstance(batch_kwargs, (list, tuple))
         if len(batch_kwargs) == 0:
             return []
-
+        if not isinstance(batch_kwargs[0], dict):
+            num_programs = len(batch_kwargs)
+            # change the list of argument values to dict of arg_name -> arg_value
+            batch_kwargs = [
+                {self.arg_names[i]: v for i, v in enumerate(arg_values)}
+                for arg_values in batch_kwargs
+                if isinstance(arg_values, (list, tuple)) and
+                len(self.arg_names) - len(self.arg_defaults) <= len(arg_values) <= len(self.arg_names)
+            ]
+            # Ensure to raise an exception if the number of arguments mismatch
+            if len(batch_kwargs) != num_programs:
+                raise Exception("Given arguments mismatch the SGL function signature")
 
         default_sampling_para = SglSamplingParams(
             max_new_tokens=max_new_tokens,
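The new branch above lets `run_batch` accept plain positional argument lists or tuples in addition to keyword dicts: values are mapped onto the captured `arg_names`, defaults are honored via `arg_defaults`, and a count mismatch now raises. A minimal sketch, assuming a default backend has already been set:

```python
import sglang as sgl


@sgl.function
def qa(s, question, hint="none"):
    s += "Q: " + question + " (hint: " + hint + ")\n"
    s += "A: " + sgl.gen("answer", max_tokens=32)


# Keyword-dict style (unchanged) and the new positional style are both accepted:
states = qa.run_batch([{"question": "What is 1 + 1?"}, {"question": "Name a prime."}])
states = qa.run_batch([("What is 1 + 1?",), ("Name a prime.", "it is odd")])

# Passing more values than the signature allows now raises
# "Given arguments mismatch the SGL function signature".
```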
sglang/launch_server.py
CHANGED
@@ -1,6 +1,9 @@
+"""Launch the inference server."""
+
 import argparse
 
-from sglang.srt.server import
+from sglang.srt.server import launch_server
+from sglang.srt.server_args import ServerArgs
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
sglang/launch_server_llavavid.py
CHANGED
@@ -1,10 +1,11 @@
+"""Launch the inference server for Llava-video model."""
+
 import argparse
 import multiprocessing as mp
 
 from sglang.srt.server import ServerArgs, launch_server
 
 if __name__ == "__main__":
-
     model_overide_args = {}
 
     model_overide_args["mm_spatial_pool_stride"] = 2
sglang/srt/constrained/__init__.py
CHANGED
@@ -1,13 +1,19 @@
 import json
 from typing import Dict, Optional, Union
 
-from outlines.caching import cache as disk_cache
-from outlines.caching import disable_cache
-from outlines.fsm.fsm import RegexFSM
-from outlines.fsm.regex import FSMInfo, make_deterministic_fsm
-from outlines.models.transformers import TransformerTokenizer
 from pydantic import BaseModel
 
+try:
+    from outlines.caching import cache as disk_cache
+    from outlines.fsm.guide import RegexGuide
+    from outlines.caching import disable_cache
+    from outlines.fsm.guide import RegexGuide
+    from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
+    from outlines.models.transformers import TransformerTokenizer
+except ImportError as e:
+    print(f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n')
+    raise
+
 try:
     from outlines.fsm.json_schema import build_regex_from_object
 except ImportError:
@@ -28,11 +34,12 @@ except ImportError:
 
 
 __all__ = [
-    "RegexFSM",
+    "RegexGuide",
     "FSMInfo",
     "make_deterministic_fsm",
     "build_regex_from_object",
     "TransformerTokenizer",
     "disk_cache",
     "disable_cache",
+    "make_byte_level_fsm",
 ]
sglang/srt/constrained/fsm_cache.py
CHANGED
@@ -1,4 +1,6 @@
-
+"""Cache for the compressed finite state machine."""
+
+from sglang.srt.constrained import RegexGuide, TransformerTokenizer
 from sglang.srt.constrained.base_cache import BaseCache
 
 
@@ -6,7 +8,8 @@ class FSMCache(BaseCache):
     def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
         super().__init__(enable=enable)
 
-        if tokenizer_path.endswith(".json"):
+        if tokenizer_path.endswith(".json") or tokenizer_path.endswith(".model"):
+            # Do not support TiktokenTokenizer or SentencePieceTokenizer
             return
 
         from importlib.metadata import version
@@ -25,4 +28,4 @@
         )
 
     def init_value(self, regex):
-        return
+        return RegexGuide(regex, self.outlines_tokenizer)
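Taken together, the constrained-decoding path now builds an outlines `RegexGuide` instead of the removed `RegexFSM`. A hedged sketch of exercising the cache directly, assuming `outlines>=0.0.44` is installed; the tokenizer path is a placeholder and the exact keys expected in `tokenizer_args_dict` depend on the caller:

```python
from sglang.srt.constrained.fsm_cache import FSMCache

cache = FSMCache(
    "TinyLlama/TinyLlama-1.1B-Chat-v0.4",  # tokenizer_path; .json/.model paths are skipped above
    {"trust_remote_code": False},           # tokenizer_args_dict (placeholder keys)
    enable=True,
)

# init_value compiles the regex into a RegexGuide used for constrained decoding.
guide = cache.init_value(r"(yes|no)")
```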