sglang 0.3.4.post1__py3-none-any.whl → 0.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_latency.py +3 -3
- sglang/bench_server_latency.py +2 -3
- sglang/bench_serving.py +92 -0
- sglang/global_config.py +9 -3
- sglang/lang/chat_template.py +50 -25
- sglang/lang/interpreter.py +9 -1
- sglang/lang/ir.py +11 -2
- sglang/launch_server.py +1 -1
- sglang/srt/configs/model_config.py +76 -15
- sglang/srt/constrained/__init__.py +18 -0
- sglang/srt/constrained/bnf_cache.py +61 -0
- sglang/srt/constrained/fsm_cache.py +10 -3
- sglang/srt/constrained/grammar.py +190 -0
- sglang/srt/hf_transformers_utils.py +20 -5
- sglang/srt/layers/attention/flashinfer_backend.py +5 -5
- sglang/srt/layers/attention/triton_ops/decode_attention.py +110 -30
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
- sglang/srt/layers/fused_moe/fused_moe.py +4 -3
- sglang/srt/layers/fused_moe/layer.py +28 -0
- sglang/srt/layers/logits_processor.py +5 -5
- sglang/srt/layers/quantization/base_config.py +16 -1
- sglang/srt/layers/rotary_embedding.py +15 -48
- sglang/srt/layers/sampler.py +51 -39
- sglang/srt/layers/vocab_parallel_embedding.py +486 -0
- sglang/srt/managers/data_parallel_controller.py +8 -7
- sglang/srt/managers/detokenizer_manager.py +11 -9
- sglang/srt/managers/image_processor.py +4 -3
- sglang/srt/managers/io_struct.py +80 -78
- sglang/srt/managers/schedule_batch.py +46 -52
- sglang/srt/managers/schedule_policy.py +24 -13
- sglang/srt/managers/scheduler.py +145 -82
- sglang/srt/managers/tokenizer_manager.py +236 -334
- sglang/srt/managers/tp_worker.py +5 -5
- sglang/srt/managers/tp_worker_overlap_thread.py +58 -21
- sglang/srt/mem_cache/flush_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +10 -3
- sglang/srt/model_executor/cuda_graph_runner.py +34 -23
- sglang/srt/model_executor/forward_batch_info.py +6 -9
- sglang/srt/model_executor/model_runner.py +10 -19
- sglang/srt/models/baichuan.py +4 -4
- sglang/srt/models/chatglm.py +4 -4
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +5 -5
- sglang/srt/models/deepseek.py +4 -4
- sglang/srt/models/deepseek_v2.py +4 -4
- sglang/srt/models/exaone.py +4 -4
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -1
- sglang/srt/models/gpt2.py +287 -0
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +4 -4
- sglang/srt/models/internlm2.py +4 -4
- sglang/srt/models/llama.py +15 -7
- sglang/srt/models/llama_embedding.py +2 -10
- sglang/srt/models/llama_reward.py +5 -0
- sglang/srt/models/minicpm.py +4 -4
- sglang/srt/models/minicpm3.py +4 -4
- sglang/srt/models/mixtral.py +7 -5
- sglang/srt/models/mixtral_quant.py +4 -4
- sglang/srt/models/mllama.py +5 -5
- sglang/srt/models/olmo.py +4 -4
- sglang/srt/models/olmoe.py +4 -4
- sglang/srt/models/qwen.py +4 -4
- sglang/srt/models/qwen2.py +4 -4
- sglang/srt/models/qwen2_moe.py +4 -4
- sglang/srt/models/qwen2_vl.py +4 -8
- sglang/srt/models/stablelm.py +4 -4
- sglang/srt/models/torch_native_llama.py +4 -4
- sglang/srt/models/xverse.py +4 -4
- sglang/srt/models/xverse_moe.py +4 -4
- sglang/srt/openai_api/adapter.py +52 -66
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -13
- sglang/srt/sampling/sampling_params.py +5 -7
- sglang/srt/server.py +41 -33
- sglang/srt/server_args.py +34 -5
- sglang/srt/utils.py +40 -56
- sglang/test/run_eval.py +2 -0
- sglang/test/runners.py +2 -1
- sglang/test/srt/sampling/penaltylib/utils.py +1 -0
- sglang/test/test_utils.py +151 -6
- sglang/utils.py +62 -1
- sglang/version.py +1 -1
- sglang-0.3.5.dist-info/METADATA +344 -0
- sglang-0.3.5.dist-info/RECORD +152 -0
- {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/WHEEL +1 -1
- sglang-0.3.4.post1.dist-info/METADATA +0 -900
- sglang-0.3.4.post1.dist-info/RECORD +0 -148
- {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/LICENSE +0 -0
- {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/top_level.txt +0 -0
sglang/api.py
CHANGED
@@ -99,7 +99,7 @@ def gen(
     regex: Optional[str] = None,
     json_schema: Optional[str] = None,
 ):
-    """Call the model to generate. See the meaning of the arguments in docs/
+    """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""

     if choices:
         return SglSelect(
sglang/bench_latency.py
CHANGED
@@ -129,9 +129,9 @@ def load_model(server_args, port_args, tp_rank):

     model_config = ModelConfig(
         server_args.model_path,
-        server_args.trust_remote_code,
+        trust_remote_code=server_args.trust_remote_code,
         context_length=server_args.context_length,
-        model_override_args=
+        model_override_args=server_args.json_model_override_args,
     )
     model_runner = ModelRunner(
         model_config=model_config,
@@ -550,4 +550,4 @@ if __name__ == "__main__":
     except Exception as e:
         raise e
     finally:
-        kill_child_process(
+        kill_child_process()
sglang/bench_server_latency.py
CHANGED
@@ -15,7 +15,6 @@ import dataclasses
 import itertools
 import json
 import multiprocessing
-import os
 import time
 from typing import Tuple

@@ -70,7 +69,7 @@ def launch_server_internal(server_args):
     except Exception as e:
         raise e
     finally:
-        kill_child_process(
+        kill_child_process()


 def launch_server_process(server_args: ServerArgs):
@@ -176,7 +175,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
         )
     finally:
         if proc:
-            kill_child_process(proc.pid)
+            kill_child_process(proc.pid, include_self=True)

     print(f"\nResults are saved to {bench_args.result_filename}")

sglang/bench_serving.py
CHANGED
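The hunks below add a `truss` backend: a new `async_request_truss` request function, a default port of 8080, and the API path `/v1/models/model:predict`. For orientation, here is a minimal sketch of an equivalent one-off (non-streaming) request outside the benchmark; the host, port, and model name are placeholders and the server is assumed to already be running:

```python
import os

import requests

# Same endpoint shape and payload keys as async_request_truss below.
api_url = "http://localhost:8080/v1/models/model:predict"
payload = {
    "model": "my-model",         # placeholder; the benchmark requires --model for this backend
    "prompt": "Hello, my name is",
    "temperature": 0.0,
    "best_of": 1,
    "max_tokens": 32,
    "stream": False,             # the benchmark streams and parses "data: ..." chunks instead
    "ignore_eos": True,
}
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}

resp = requests.post(api_url, json=payload, headers=headers)
print(resp.status_code, resp.text)
```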
@@ -222,6 +222,85 @@ async def async_request_openai_completions(
     return output


+async def async_request_truss(
+    request_func_input: RequestFuncInput,
+    pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+    api_url = request_func_input.api_url
+
+    prompt = request_func_input.prompt
+
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+        payload = {
+            "model": request_func_input.model,
+            "prompt": prompt,
+            "temperature": 0.0,
+            "best_of": 1,
+            "max_tokens": request_func_input.output_len,
+            "stream": not args.disable_stream,
+            "ignore_eos": not args.disable_ignore_eos,
+            **request_func_input.extra_request_body,
+        }
+        headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
+
+        output = RequestFuncOutput()
+        output.prompt_len = request_func_input.prompt_len
+
+        generated_text = ""
+        ttft = 0.0
+        st = time.perf_counter()
+        most_recent_timestamp = st
+        try:
+            async with session.post(
+                url=api_url, json=payload, headers=headers
+            ) as response:
+                if response.status == 200:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
+                            continue
+
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
+                        latency = time.perf_counter() - st
+                        if chunk == "[DONE]":
+                            pass
+                        else:
+                            data = json.loads(chunk)
+
+                            # NOTE: Some completion API might have a last
+                            # usage summary response without a token so we
+                            # want to check a token was generated
+                            if data["choices"][0]["delta"]["content"]:
+                                timestamp = time.perf_counter()
+                                # First token
+                                if ttft == 0.0:
+                                    ttft = time.perf_counter() - st
+                                    output.ttft = ttft
+
+                                # Decoding phase
+                                else:
+                                    output.itl.append(timestamp - most_recent_timestamp)
+
+                                most_recent_timestamp = timestamp
+                                generated_text += data["choices"][0]["delta"]["content"]
+
+                    output.generated_text = generated_text
+                    output.success = True
+                    output.latency = latency
+                    output.output_len = request_func_input.output_len
+                else:
+                    output.error = response.reason or ""
+                    output.success = False
+        except Exception:
+            output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
+
+    if pbar:
+        pbar.update(1)
+    return output
+
+
 async def async_request_sglang_generate(
     request_func_input: RequestFuncInput,
     pbar: Optional[tqdm] = None,
@@ -350,6 +429,7 @@ ASYNC_REQUEST_FUNCS = {
     "lmdeploy": async_request_openai_completions,
     "trt": async_request_trt_llm,
     "gserver": async_request_gserver,
+    "truss": async_request_truss,
 }


@@ -873,6 +953,7 @@ def run_benchmark(args_: argparse.Namespace):
         "vllm": 8000,
         "trt": 8000,
         "gserver": 9988,
+        "truss": 8080,
     }.get(args.backend, 30000)

     model_url = (
@@ -905,9 +986,20 @@ def run_benchmark(args_: argparse.Namespace):
     elif args.backend == "gserver":
         api_url = args.base_url if args.base_url else f"{args.host}:{args.port}"
         args.model = args.model or "default"
+    elif args.backend == "truss":
+        api_url = (
+            f"{args.base_url}/v1/models/model:predict"
+            if args.base_url
+            else f"http://{args.host}:{args.port}/v1/models/model:predict"
+        )

     # Get model name
     if args.model is None:
+        if args.backend == "truss":
+            print(
+                "Please provide a model with `--model` when using truss backend. e.g. --model meta-llama/Llama-3.1-8B-Instruct"
+            )
+            sys.exit(1)
         try:
             response = requests.get(model_url)
             model_list = response.json().get("data", [])
sglang/global_config.py
CHANGED
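The hunk below replaces hard-coded new-token-ratio constants with values read from environment variables. A minimal sketch of overriding them, assuming the variables are set before sglang is imported so the module-level `global_config` singleton picks them up:

```python
import os

# Defaults are 0.7, 0.14, and 600 when the variables are unset.
os.environ["SGLANG_INIT_NEW_TOKEN_RATIO"] = "0.6"
os.environ["SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR"] = "0.2"
os.environ["SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS"] = "400"

from sglang.global_config import global_config

print(global_config.default_init_new_token_ratio)  # 0.6
```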
@@ -14,9 +14,15 @@ class GlobalConfig:
         self.default_backend = None

         # Runtime constants: New generation token ratio estimation
-        self.
-
-
+        self.default_init_new_token_ratio = float(
+            os.environ.get("SGLANG_INIT_NEW_TOKEN_RATIO", 0.7)
+        )
+        self.default_min_new_token_ratio_factor = float(
+            os.environ.get("SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR", 0.14)
+        )
+        self.default_new_token_ratio_decay_steps = float(
+            os.environ.get("SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS", 600)
+        )

         # Runtime constants: others
         self.retract_decode_steps = 20
sglang/lang/chat_template.py
CHANGED
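The hunks below mainly rename and reorder template registrations (`chatml-llava`, `qwen`, `qwen2-vl`, `llama-2-chat`, `llama-3-instruct`, `llama-3-instruct-llava`, `yi-1.5`), fix one-element `stop_str` tuples, and adjust `image_token` values. For reference, a sketch of registering a custom template with the same fields used in these hunks; the template name here is made up:

```python
from sglang.lang.chat_template import (
    ChatTemplate,
    ChatTemplateStyle,
    register_chat_template,
)

register_chat_template(
    ChatTemplate(
        name="my-chatml-variant",  # hypothetical name
        default_system_prompt="You are a helpful assistant.",
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
        },
        style=ChatTemplateStyle.PLAIN,
        stop_str=("<|im_end|>",),
        image_token="<image>\n",
    )
)
```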
@@ -116,12 +116,10 @@ register_chat_template(
     )
 )

-
-# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
-# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+
 register_chat_template(
     ChatTemplate(
-        name="
+        name="chatml-llava",
         default_system_prompt="You are a helpful assistant.",
         role_prefix_and_suffix={
             "system": ("<|im_start|>system\n", "<|im_end|>\n"),
@@ -130,13 +128,17 @@ register_chat_template(
         },
         style=ChatTemplateStyle.PLAIN,
         stop_str=("<|im_end|>",),
+        image_token="<image>\n",
     )
 )

-
+
+# There is default system prompt for qwen
+# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
+# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
 register_chat_template(
     ChatTemplate(
-        name="
+        name="qwen",
         default_system_prompt="You are a helpful assistant.",
         role_prefix_and_suffix={
             "system": ("<|im_start|>system\n", "<|im_end|>\n"),
@@ -144,15 +146,14 @@ register_chat_template(
             "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
         },
         style=ChatTemplateStyle.PLAIN,
-        stop_str=("<|im_end|>"),
-        image_token="<|vision_start|><|image_pad|><|vision_end|>",
+        stop_str=("<|im_end|>",),
     )
 )

-
+# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
 register_chat_template(
     ChatTemplate(
-        name="
+        name="qwen2-vl",
         default_system_prompt="You are a helpful assistant.",
         role_prefix_and_suffix={
             "system": ("<|im_start|>system\n", "<|im_end|>\n"),
@@ -161,7 +162,7 @@ register_chat_template(
         },
         style=ChatTemplateStyle.PLAIN,
         stop_str=("<|im_end|>",),
-        image_token="
+        image_token="<|vision_start|><|image_pad|><|vision_end|>",
     )
 )

@@ -182,37 +183,46 @@ register_chat_template(
     )
 )

-# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
 register_chat_template(
     ChatTemplate(
-        name="
+        name="llama-2-chat",
         default_system_prompt=None,
         role_prefix_and_suffix={
-            "system": ("", ""),
-            "user": ("
-            "assistant": ("", "
+            "system": ("<<SYS>>\n", "\n<</SYS>>\n\n"),
+            "user": ("[INST] ", " [/INST]"),
+            "assistant": ("", " </s><s>"),
         },
-        style=ChatTemplateStyle.
-        stop_str=("<|im_end|>",),
+        style=ChatTemplateStyle.LLAMA2,
     )
 )

 register_chat_template(
     ChatTemplate(
-        name="llama-
+        name="llama-3-instruct",
         default_system_prompt=None,
         role_prefix_and_suffix={
-            "system": (
-
-
+            "system": (
+                "<|start_header_id|>system<|end_header_id|>\n\n",
+                "<|eot_id|>",
+            ),
+            "user": (
+                "<|start_header_id|>user<|end_header_id|>\n\n",
+                "<|eot_id|>",
+            ),
+            "assistant": (
+                "<|start_header_id|>assistant<|end_header_id|>\n\n",
+                "<|eot_id|>",
+            ),
         },
-
+        stop_str=("<|eot_id|>",),
+        image_token="<|image|>",
     )
 )

+# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
 register_chat_template(
     ChatTemplate(
-        name="llama-3-instruct",
+        name="llama-3-instruct-llava",
         default_system_prompt=None,
         role_prefix_and_suffix={
             "system": (
@@ -229,7 +239,22 @@ register_chat_template(
             ),
         },
         stop_str=("<|eot_id|>",),
-        image_token="
+        image_token="<image>\n",
+    )
+)
+
+# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
+register_chat_template(
+    ChatTemplate(
+        name="yi-1.5",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": ("", ""),
+            "user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
+            "assistant": ("", "<|im_end|>\n"),
+        },
+        style=ChatTemplateStyle.PLAIN,
+        stop_str=("<|im_end|>",),
     )
 )

sglang/lang/interpreter.py
CHANGED
@@ -54,7 +54,14 @@ def run_internal(state, program, func_args, func_kwargs, sync):


 def run_program(
-    program,
+    program,
+    backend,
+    func_args,
+    func_kwargs,
+    default_sampling_para,
+    stream,
+    sync=False,
+    use_thread=True,
 ):
     if hasattr(backend, "endpoint"):
         backend = backend.endpoint
@@ -67,6 +74,7 @@ def run_program(
         chat_template=None,
         stream=stream,
         num_api_spec_tokens=program.num_api_spec_tokens,
+        use_thread=use_thread,
     )
     state = ProgramState(stream_executor)

sglang/lang/ir.py
CHANGED
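Together with the `run_program` change in interpreter.py above, `SglFunction.run` gains a `use_thread` flag (hunks below) that is forwarded to the stream executor. A minimal usage sketch, assuming an SGLang server is already running on localhost:30000:

```python
import sglang as sgl

sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

@sgl.function
def qa(s, question):
    s += "Q: " + question + "\n"
    s += "A: " + sgl.gen("answer", max_tokens=64)

# use_thread=False executes the program in the calling thread
# instead of a background worker thread.
state = qa.run(question="What does SGLang do?", use_thread=False)
print(state["answer"])
```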
@@ -168,6 +168,7 @@ class SglFunction:
         return_text_in_logprobs: Optional[bool] = None,
         stream: bool = False,
         backend=None,
+        use_thread: bool = True,
         **kwargs,
     ):
         from sglang.lang.interpreter import run_program
@@ -195,7 +196,15 @@ class SglFunction:
             return_text_in_logprobs=return_text_in_logprobs,
         )
         backend = backend or global_config.default_backend
-        return run_program(
+        return run_program(
+            self,
+            backend,
+            args,
+            kwargs,
+            default_sampling_para,
+            stream,
+            use_thread=use_thread,
+        )

     def run_batch(
         self,
@@ -445,7 +454,7 @@ class SglGen(SglExpr):
         regex: Optional[str] = None,
         json_schema: Optional[str] = None,
     ):
-        """Call the model to generate. See the meaning of the arguments in docs/
+        """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
         super().__init__()
         self.name = name
         self.sampling_params = SglSamplingParams(
sglang/launch_server.py
CHANGED
sglang/srt/configs/model_config.py
CHANGED
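In the hunks below, `ModelConfig` now parses `model_override_args` from a JSON string, derives `is_generation`/`is_multimodal`/`is_encoder_decoder` from the architecture list, and validates a user-specified context length against the derived one (overridable with `SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1`). A sketch of direct construction, with a placeholder model path:

```python
import json

from sglang.srt.configs.model_config import ModelConfig

config = ModelConfig(
    "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model path
    trust_remote_code=False,
    context_length=None,                 # None -> use the context length derived from the HF config
    model_override_args=json.dumps({}),  # now passed as a JSON string
    is_embedding=False,
)
print(config.context_len, config.is_generation, config.is_multimodal)
```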
@@ -13,13 +13,18 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+import json
+import logging
+import os
 from enum import IntEnum, auto
-from typing import Optional
+from typing import List, Optional

 from transformers import PretrainedConfig

 from sglang.srt.hf_transformers_utils import get_config, get_context_length

+logger = logging.getLogger(__name__)
+

 class AttentionArch(IntEnum):
     MLA = auto()
@@ -34,22 +39,47 @@ class ModelConfig:
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
         model_override_args: Optional[dict] = None,
+        is_embedding: Optional[bool] = None
     ) -> None:
-
-        self.
-        self.revision = revision
-        self.model_override_args = model_override_args
+        # Parse args
+        self.model_override_args = json.loads(model_override_args)
         self.hf_config = get_config(
-
-            trust_remote_code,
-            revision,
-            model_override_args=model_override_args,
+            path,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+            model_override_args=self.model_override_args,
         )
         self.hf_text_config = get_hf_text_config(self.hf_config)
+
+        # Check model type
+        self.is_generation = is_generation_model(self.hf_config.architectures, is_embedding)
+        self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
+        self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
+
+        # Derive context length
+        derived_context_len = get_context_length(self.hf_text_config)
+        allow_long_context = os.environ.get(
+            "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None
+        )
+
         if context_length is not None:
-
+            if context_length > derived_context_len:
+                if allow_long_context:
+                    logger.warning(
+                        f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
+                        f"This may lead to incorrect model outputs or CUDA errors."
+                    )
+                    self.context_len = context_length
+                else:
+                    raise ValueError(
+                        f"User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
+                        f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config. "
+                        f"To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
+                    )
+            else:
+                self.context_len = context_length
         else:
-            self.context_len =
+            self.context_len = derived_context_len

         # Unify the config keys for hf_text_config
         self.head_dim = getattr(
@@ -58,7 +88,7 @@ class ModelConfig:
             self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
         )

-        # FIXME: temporary special judge for
+        # FIXME: temporary special judge for MLA architecture
         if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
             self.head_dim = 256
             self.attention_arch = AttentionArch.MLA
@@ -89,8 +119,6 @@ class ModelConfig:
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
         self.vocab_size = self.hf_text_config.vocab_size

-        self.is_encoder_decoder = self.hf_config.model_type in ["mllama"]
-
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
@@ -140,7 +168,6 @@ class ModelConfig:
         # equal to the number of attention heads.
         return self.hf_text_config.num_attention_heads

-    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L328
     def get_num_kv_heads(self, tensor_parallel_size) -> int:
         """Returns the number of KV heads per GPU."""
         total_num_kv_heads = self.get_total_num_kv_heads()
@@ -169,3 +196,37 @@ def get_hf_text_config(config: PretrainedConfig):
         return config.text_config
     else:
         return config
+
+
+def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
+    # We have two ways to determine whether a model is a generative model.
+    # 1. Check the model architectue
+    # 2. check the `is_embedding` server args
+
+    if (
+        "LlamaEmbeddingModel" in model_architectures
+        or "MistralModel" in model_architectures
+        or "LlamaForSequenceClassification" in model_architectures
+        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
+    ):
+        return False
+    else:
+        return not is_embedding
+
+
+def is_multimodal_model(model_architectures: List[str]):
+    if (
+        "LlavaLlamaForCausalLM" in model_architectures
+        or "LlavaQwenForCausalLM" in model_architectures
+        or "LlavaMistralForCausalLM" in model_architectures
+        or "LlavaVidForCausalLM" in model_architectures
+        or "MllamaForConditionalGeneration" in model_architectures
+        or "Qwen2VLForConditionalGeneration" in model_architectures
+    ):
+        return True
+    else:
+        return False
+
+
+def is_encoder_decoder_model(model_architectures: List[str]):
+    return "MllamaForConditionalGeneration" in model_architectures
sglang/srt/constrained/__init__.py
CHANGED
@@ -51,6 +51,21 @@ except ImportError:
         return build_regex_from_schema(schema, whitespace_pattern)


+try:
+    from xgrammar import (
+        GrammarMatcher,
+        GrammarMatcherInitContext,
+        GrammarMatcherInitContextCache,
+    )
+except ImportError as e:
+
+    class Dummy:
+        pass
+
+    GrammarMatcher = Dummy
+    GrammarMatcherInitContext = Dummy
+    GrammarMatcherInitContextCache = Dummy
+
 __all__ = [
     "RegexGuide",
     "FSMInfo",
@@ -60,4 +75,7 @@ __all__ = [
     "disk_cache",
     "disable_cache",
     "make_byte_level_fsm",
+    "GrammarMatcher",
+    "GrammarMatcherInitContext",
+    "GrammarMatcherInitContextCache",
 ]
sglang/srt/constrained/bnf_cache.py
ADDED
@@ -0,0 +1,61 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""Cache for the compressed finite state machine."""
+
+from typing import Tuple
+
+from transformers import AutoTokenizer
+
+from sglang.srt.constrained import (
+    GrammarMatcher,
+    GrammarMatcherInitContext,
+    GrammarMatcherInitContextCache,
+)
+
+MAX_ROLLBACK_TOKENS = 10
+
+
+class BNFCache:
+    grammar_cache: GrammarMatcherInitContextCache
+
+    def __init__(
+        self,
+        tokenizer_path,
+        tokenizer_args_dict,
+        skip_tokenizer_init=False,
+        whitespace_patterns=None,
+    ):
+        # TODO(dark): how to deal with whitespace_patterns and skip_tokenizer_init
+        if skip_tokenizer_init:
+            return
+
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
+        self.grammar_cache = GrammarMatcherInitContextCache(
+            tokenizer_or_vocab=tokenizer
+        )
+
+    def get_context(self, key: Tuple[str, str]) -> GrammarMatcherInitContext:
+        key_type, key_string = key
+        if key_type == "json":
+            return self.grammar_cache.get_init_context_for_json_schema(key_string)
+        elif key_type == "regex":
+            raise ValueError(f"regex hasn't been supported by xgrammar yet")
+        else:
+            raise ValueError(f"Invalid key_type: {key_type}")
+
+    def query(self, key: Tuple[str, str], vocab_size: int) -> GrammarMatcher:
+        ctx = self.get_context(key)
+        return GrammarMatcher(
+            ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, mask_vocab_size=vocab_size
+        )
sglang/srt/constrained/fsm_cache.py
CHANGED
@@ -73,9 +73,16 @@ class FSMCache(BaseToolCache):
     def init_value(self, key):
         key_type, key_string = key
         if key_type == "json":
-
-
-
+            try:
+                regex = build_regex_from_schema(
+                    key_string,
+                    whitespace_pattern=self.constrained_json_whitespace_pattern,
+                )
+            except NotImplementedError as e:
+                logger.warning(
+                    f"skip invalid json schema: json_schema={key_string}, {e=}"
+                )
+                return None, key_string
         elif key_type == "regex":
             regex = key_string
         else: