sglang 0.4.5.post3__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +19 -3
- sglang/bench_serving.py +8 -9
- sglang/compile_deep_gemm.py +45 -4
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +1 -1
- sglang/srt/configs/model_config.py +9 -3
- sglang/srt/constrained/llguidance_backend.py +78 -61
- sglang/srt/conversation.py +34 -1
- sglang/srt/disaggregation/decode.py +59 -11
- sglang/srt/disaggregation/mini_lb.py +45 -8
- sglang/srt/disaggregation/mooncake/conn.py +198 -31
- sglang/srt/disaggregation/prefill.py +24 -9
- sglang/srt/entrypoints/http_server.py +8 -2
- sglang/srt/function_call_parser.py +77 -5
- sglang/srt/layers/attention/base_attn_backend.py +3 -0
- sglang/srt/layers/attention/flashattention_backend.py +28 -10
- sglang/srt/layers/attention/flashmla_backend.py +8 -11
- sglang/srt/layers/attention/vision.py +2 -0
- sglang/srt/layers/layernorm.py +38 -16
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/fused_moe_native.py +2 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
- sglang/srt/layers/pooler.py +6 -0
- sglang/srt/layers/quantization/awq.py +5 -1
- sglang/srt/layers/quantization/deep_gemm.py +17 -10
- sglang/srt/layers/quantization/int8_kernel.py +32 -1
- sglang/srt/layers/radix_attention.py +13 -3
- sglang/srt/layers/rotary_embedding.py +170 -126
- sglang/srt/managers/data_parallel_controller.py +10 -3
- sglang/srt/managers/io_struct.py +7 -0
- sglang/srt/managers/mm_utils.py +85 -28
- sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
- sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
- sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
- sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
- sglang/srt/managers/schedule_batch.py +29 -12
- sglang/srt/managers/scheduler.py +31 -20
- sglang/srt/managers/tokenizer_manager.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +87 -0
- sglang/srt/model_executor/cuda_graph_runner.py +4 -3
- sglang/srt/model_executor/forward_batch_info.py +51 -95
- sglang/srt/model_executor/model_runner.py +11 -24
- sglang/srt/models/deepseek.py +12 -2
- sglang/srt/models/deepseek_nextn.py +101 -6
- sglang/srt/models/deepseek_v2.py +144 -70
- sglang/srt/models/deepseek_vl2.py +9 -4
- sglang/srt/models/gemma3_causal.py +1 -1
- sglang/srt/models/llama4.py +0 -1
- sglang/srt/models/minicpmo.py +5 -1
- sglang/srt/models/mllama4.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +3 -6
- sglang/srt/models/qwen2_vl.py +3 -7
- sglang/srt/models/roberta.py +178 -0
- sglang/srt/openai_api/adapter.py +18 -8
- sglang/srt/server_args.py +15 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/torch_memory_saver_adapter.py +10 -1
- sglang/srt/utils.py +2 -1
- sglang/test/runners.py +6 -13
- sglang/test/test_utils.py +36 -18
- sglang/version.py +1 -1
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/METADATA +4 -5
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/RECORD +70 -68
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py
CHANGED
@@ -57,6 +57,7 @@ import torch
 import torch.distributed as dist

 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.distributed.parallel_state import destroy_distributed_environment
 from sglang.srt.entrypoints.engine import _set_envs_and_config
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
@@ -85,6 +86,7 @@ class BenchArgs:
     correctness_test: bool = False
     # This is only used for correctness test
     cut_len: int = 4
+    log_decode_step: int = 0
     profile: bool = False
     profile_filename_prefix: str = "profile"

@@ -105,6 +107,12 @@ class BenchArgs:
         )
         parser.add_argument("--correctness-test", action="store_true")
         parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
+        parser.add_argument(
+            "--log-decode-step",
+            type=int,
+            default=BenchArgs.log_decode_step,
+            help="Log decode latency by step, default is set to zero to disable.",
+        )
         parser.add_argument(
             "--profile", action="store_true", help="Use Torch Profiler."
         )
@@ -335,6 +343,7 @@ def latency_test_run_once(
     input_len,
     output_len,
     device,
+    log_decode_step,
     profile,
     profile_filename_prefix,
 ):
@@ -394,9 +403,9 @@ def latency_test_run_once(
         tot_latency += latency
         throughput = batch_size / latency
         decode_latencies.append(latency)
-        if i < 5:
+        if i < 5 or (log_decode_step > 0 and i % log_decode_step == 0):
             rank_print(
-                f"Decode. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+                f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
             )

     if profile:
@@ -457,8 +466,9 @@ def latency_test(
         reqs,
         bench_args.batch_size[0],
         bench_args.input_len[0],
-
+        min(32, bench_args.output_len[0]),  # shorter decoding to speed up the warmup
         server_args.device,
+        log_decode_step=0,
         profile=False,
         profile_filename_prefix="",  # not used
     )
@@ -480,6 +490,7 @@ def latency_test(
             il,
             ol,
             server_args.device,
+            bench_args.log_decode_step,
             bench_args.profile if tp_rank == 0 else None,
             bench_args.profile_filename_prefix,
         )
@@ -492,8 +503,13 @@ def latency_test(
         for result in result_list:
             fout.write(json.dumps(result) + "\n")

+    if server_args.tp_size > 1:
+        destroy_distributed_environment()
+

 def main(server_args, bench_args):
+    server_args.cuda_graph_max_bs = max(bench_args.batch_size)
+
     _set_envs_and_config(server_args)

     if server_args.model_path:

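Note: the new --log-decode-step flag only widens which decode iterations get a latency line. A minimal sketch of the selection rule, outside of bench_one_batch.py (the logged_steps helper is hypothetical, for illustration only):

    # Hypothetical helper: which decode steps get a latency line under the new rule.
    # The first 5 steps are always logged; a positive log_decode_step adds every N-th step.
    def logged_steps(num_decode_steps: int, log_decode_step: int) -> list:
        return [
            i
            for i in range(num_decode_steps)
            if i < 5 or (log_decode_step > 0 and i % log_decode_step == 0)
        ]

    print(logged_steps(12, 4))  # [0, 1, 2, 3, 4, 8]
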
sglang/bench_serving.py
CHANGED
@@ -295,7 +295,7 @@ async def async_request_truss(
                     # NOTE: Some completion API might have a last
                     # usage summary response without a token so we
                     # want to check a token was generated
-                    if data["choices"][0]["
+                    if data["choices"][0]["text"]:
                         timestamp = time.perf_counter()
                         # First token
                         if ttft == 0.0:
@@ -307,7 +307,7 @@ async def async_request_truss(
                             output.itl.append(timestamp - most_recent_timestamp)

                         most_recent_timestamp = timestamp
-                        generated_text += data["choices"][0]["
+                        generated_text += data["choices"][0]["text"]

                 output.generated_text = generated_text
                 output.success = True
@@ -977,6 +977,7 @@ async def benchmark(
     profile: bool,
     pd_seperated: bool = False,
     flush_cache: bool = False,
+    warmup_requests: int = 1,
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -994,11 +995,11 @@
         return await request_func(request_func_input=request_func_input, pbar=pbar)

     # Warmup
-    print(f"Starting warmup with {
+    print(f"Starting warmup with {warmup_requests} sequences...")

     # Use the first request for all warmup iterations
     test_prompt, test_prompt_len, test_output_len = input_requests[0]
-    if lora_names
+    if lora_names is not None and len(lora_names) != 0:
         lora_name = lora_names[0]
     else:
         lora_name = None
@@ -1016,7 +1017,7 @@

     # Run warmup requests
     warmup_tasks = []
-    for _ in range(
+    for _ in range(warmup_requests):
         warmup_tasks.append(
             asyncio.create_task(request_func(request_func_input=test_input))
         )
@@ -1024,9 +1025,7 @@
     warmup_outputs = await asyncio.gather(*warmup_tasks)

     # Check if at least one warmup request succeeded
-    if
-        output.success for output in warmup_outputs
-    ):
+    if warmup_requests > 0 and not any(output.success for output in warmup_outputs):
         raise ValueError(
             "Warmup failed - Please make sure benchmark arguments "
             f"are correctly specified. Error: {warmup_outputs[0].error}"
@@ -1058,7 +1057,7 @@
     tasks: List[asyncio.Task] = []
     async for request in get_request(input_requests, request_rate):
         prompt, prompt_len, output_len = request
-        if lora_names
+        if lora_names is not None and len(lora_names) != 0:
             idx = random.randint(0, len(lora_names) - 1)
             lora_name = lora_names[idx]
         else:

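Note: the async_request_truss fix above reads the generated token from choices[0]["text"] in each streamed chunk. A minimal illustration with a made-up chunk payload (sample data, not from the package):

    # Made-up streaming chunk in the shape the fixed parser expects.
    chunk = {"choices": [{"text": "Hello"}]}
    generated_text = ""
    if chunk["choices"][0]["text"]:
        generated_text += chunk["choices"][0]["text"]
    print(generated_text)  # Hello
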
sglang/compile_deep_gemm.py
CHANGED
@@ -27,7 +27,11 @@ from sglang.srt.warmup import warmup
 multiprocessing.set_start_method("spawn", force=True)

 # Reduce warning
-os.environ["
+os.environ["SGL_IN_DEEPGEMM_PRECOMPILE_STAGE"] = "1"
+# Force enable deep gemm
+os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1"
+# Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case
+os.environ["SGL_CHUNKED_PREFIX_CACHE_THRESHOLD"] = "0"


 @dataclasses.dataclass
@@ -84,8 +88,36 @@ def launch_server_process_and_send_one_request(
             headers = {
                 "Content-Type": "application/json; charset=utf-8",
             }
-
+            if server_args.node_rank == 0:
+                response = requests.get(f"{base_url}/v1/models", headers=headers)
+            else:
+                # This http api is created by launch_dummy_health_check_server for none-rank0 node.
+                response = requests.get(f"{base_url}/health", headers=headers)
             if response.status_code == 200:
+                # Rank-0 node send a request to sync with other node and then return.
+                if server_args.node_rank == 0:
+                    response = requests.post(
+                        f"{base_url}/generate",
+                        json={
+                            "input_ids": [0, 1, 2, 3],
+                            "sampling_params": {
+                                "max_new_tokens": 8,
+                                "temperature": 0,
+                            },
+                        },
+                        timeout=600,
+                    )
+                    if response.status_code != 200:
+                        error = response.json()
+                        raise RuntimeError(f"Sync request failed: {error}")
+                # Other nodes should wait for the exit signal from Rank-0 node.
+                else:
+                    start_time_waiting = time.time()
+                    while proc.is_alive():
+                        if time.time() - start_time_waiting < timeout:
+                            time.sleep(10)
+                        else:
+                            raise TimeoutError("Waiting for main node timeout!")
                 return proc
         except requests.RequestException:
             pass
@@ -118,10 +150,19 @@ def run_compile(server_args: ServerArgs, compile_args: CompileArgs):

     proc = launch_server_process_and_send_one_request(server_args, compile_args)

-    kill_process_tree(proc.pid)
-
     print("\nDeepGEMM Kernels compilation finished successfully.")

+    # Sleep for safety
+    time.sleep(10)
+    if proc.is_alive():
+        # This is the rank0 node.
+        kill_process_tree(proc.pid)
+    else:
+        try:
+            kill_process_tree(proc.pid)
+        except Exception:
+            pass
+

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()

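Note: the multi-node handling above splits into two paths: the rank-0 node sends one tiny /generate request to trigger compilation and sync, while the other nodes wait for their local server process to exit. A condensed sketch under those assumptions (the wait_or_sync function and its arguments are illustrative, not part of the module):

    import time
    import requests

    # Condensed sketch of the rank-0 vs. non-rank-0 behavior added above.
    def wait_or_sync(node_rank, base_url, proc, timeout):
        if node_rank == 0:
            # Rank-0: one tiny generation request to sync, then return.
            r = requests.post(
                f"{base_url}/generate",
                json={
                    "input_ids": [0, 1, 2, 3],
                    "sampling_params": {"max_new_tokens": 8, "temperature": 0},
                },
                timeout=600,
            )
            r.raise_for_status()
        else:
            # Other nodes: wait for the local server process to exit
            # (which happens once rank-0 finishes), with a timeout.
            start = time.time()
            while proc.is_alive():
                if time.time() - start >= timeout:
                    raise TimeoutError("Waiting for main node timeout!")
                time.sleep(10)
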
sglang/srt/code_completion_parser.py
CHANGED
@@ -113,7 +113,7 @@ def completion_template_exists(template_name: str) -> bool:

 def is_completion_template_defined() -> bool:
     global completion_template_name
-    return completion_template_name
+    return completion_template_name is not None


 def generate_completion_prompt_from_request(request: ChatCompletionRequest) -> str:

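Note: this one-line fix matters because the function is annotated -> bool but previously returned the template name itself (a string or None). A quick illustration with a hypothetical template name:

    completion_template_name = "deepseek-coder"  # hypothetical value

    def is_completion_template_defined() -> bool:
        return completion_template_name is not None

    print(is_completion_template_defined())  # True (a real bool, not the string)
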
sglang/srt/configs/deepseekvl2.py
CHANGED
@@ -182,7 +182,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
         tokenized_str, images, seq_mask, spatial_crop = self.tokenize_with_images(
             messages,
             pil_images[image_index : image_index + image_token_cnt],
-            bos=
+            bos=True,
             eos=True,
             cropping=len(pil_images) <= 2,
             max_req_input_len=max_req_input_len,

sglang/srt/configs/model_config.py
CHANGED
@@ -73,10 +73,14 @@ class ModelConfig:
         )

         if enable_multimodal is None:
-
+            mm_disabled_models = [
+                "Gemma3ForConditionalGeneration",
+                "Llama4ForConditionalGeneration",
+            ]
+            if self.hf_config.architectures[0] in mm_disabled_models:
                 enable_multimodal = False
                 logger.info(
-                    "Multimodal is disabled for
+                    f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
                 )
             else:
                 enable_multimodal = True
@@ -158,7 +162,9 @@ class ModelConfig:
             self.attention_arch = AttentionArch.MLA
             self.kv_lora_rank = self.hf_config.kv_lora_rank
             self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
-        elif "DeepseekVL2ForCausalLM" in self.hf_config.architectures
+        elif "DeepseekVL2ForCausalLM" in self.hf_config.architectures and getattr(
+            self.hf_text_config, "use_mla", True
+        ):
             self.head_dim = 256
             self.attention_arch = AttentionArch.MLA
             self.kv_lora_rank = self.hf_text_config.kv_lora_rank

sglang/srt/constrained/llguidance_backend.py
CHANGED
@@ -14,49 +14,48 @@
 """Constrained decoding with llguidance backend."""

 import json
+import logging
 import os
 from typing import List, Optional, Tuple

-import llguidance
-import llguidance.hf
-import llguidance.torch
 import torch
-from llguidance
+from llguidance import LLMatcher, LLTokenizer, StructTag, grammar_from
+from llguidance.hf import from_tokenizer
+from llguidance.torch import (
+    allocate_token_bitmask,
+    apply_token_bitmask_inplace,
+    fill_next_token_bitmask,
+)

 from sglang.srt.constrained.base_grammar_backend import (
     BaseGrammarBackend,
     BaseGrammarObject,
 )

+logger = logging.getLogger(__name__)
+

 class GuidanceGrammar(BaseGrammarObject):
-
-
-    ):
+
+    def __init__(self, llguidance_tokenizer: LLTokenizer, serialized_grammar: str):
         super().__init__()
         self.llguidance_tokenizer = llguidance_tokenizer
         self.serialized_grammar = serialized_grammar

-
-        self.ll_interpreter = llguidance.LLInterpreter(
+        self.ll_matcher = LLMatcher(
             self.llguidance_tokenizer,
             self.serialized_grammar,
-            enable_backtrack=False,
-            enable_ff_tokens=False,
             log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
         )
-        self.pending_ff_tokens: list[int] = []
         self.finished = False
         self.bitmask = None

     def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
-
-
-        ff_tokens
-
-        return
-
-        return None
+        ff_tokens = self.ll_matcher.compute_ff_tokens()
+        if ff_tokens:
+            return ff_tokens, ""
+        else:
+            return None

     def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
         return "", -1
@@ -67,32 +66,22 @@ class GuidanceGrammar(BaseGrammarObject):
         pass

     def accept_token(self, token: int):
-
-
-
-            ff_tokens = ff_tokens[1:]
-        self.pending_ff_tokens.extend(ff_tokens)
+        if not self.ll_matcher.consume_token(token):
+            logger.warning(f"matcher error: {self.ll_matcher.get_error()}")
+            self.finished = True

     def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
-        if
-            # if we have pending fast-forward tokens,
-            # just return them immediately
-            ff_token = self.pending_ff_tokens.pop(0)
-            vocab_mask[idx, :] = 0
-            vocab_mask[idx, ff_token // 32] = 1 << (ff_token % 32)
-            return
-
-        if self.ll_interpreter.has_pending_stop():
+        if self.ll_matcher.is_stopped():
             self.finished = True

-
+        fill_next_token_bitmask(self.ll_matcher, vocab_mask, idx)

     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device
     ) -> torch.Tensor:
         if self.bitmask is None or self.bitmask.shape[0] < batch_size:
             # only create bitmask when batch gets larger
-            self.bitmask =
+            self.bitmask = allocate_token_bitmask(
                 batch_size, self.llguidance_tokenizer.vocab_size
             )
         bitmask = self.bitmask
@@ -107,7 +96,7 @@ class GuidanceGrammar(BaseGrammarObject):

     @staticmethod
     def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-
+        apply_token_bitmask_inplace(logits, vocab_mask)

     def copy(self):
         return GuidanceGrammar(
@@ -117,36 +106,64 @@ class GuidanceGrammar(BaseGrammarObject):


 class GuidanceBackend(BaseGrammarBackend):
-
+
+    def __init__(
+        self,
+        tokenizer,
+        whitespace_pattern: Optional[str] = None,
+        n_vocab: Optional[int] = None,
+    ):
         super().__init__()

         self.tokenizer = tokenizer
-        self.
-
-
-
-
-
-
-
-
+        self.whitespace_pattern = whitespace_pattern
+        self.llguidance_tokenizer = from_tokenizer(self.tokenizer, n_vocab)
+
+    def _from_serialized(self, serialized_grammar) -> Optional[GuidanceGrammar]:
+        try:
+            return GuidanceGrammar(
+                llguidance_tokenizer=self.llguidance_tokenizer,
+                serialized_grammar=serialized_grammar,
+            )
+        except Exception as e:
+            logger.warning(f"Skip invalid grammar: {serialized_grammar}, {e=}")
+            return None
+
+    def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]:
+        serialized_grammar = LLMatcher.grammar_from_json_schema(
+            key_string,
+            defaults={
+                "whitespace_pattern": self.whitespace_pattern,
+            },
         )
-
-    def dispatch_json(self, key_string: str) -> GuidanceGrammar:
-        json_schema = key_string
-        compiler = llguidance.JsonCompiler(whitespace_flexible=self.whitespace_flexible)
-        serialized_grammar = compiler.compile(json_schema)
-        return self._from_serialized(serialized_grammar)
-
-    def dispatch_regex(self, key_string: str) -> GuidanceGrammar:
-        compiler = llguidance.RegexCompiler()
-        serialized_grammar = compiler.compile(regex=key_string)
         return self._from_serialized(serialized_grammar)

-    def
-
-        serialized_grammar = compiler.compile(any_to_lark(key_string))
+    def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]:
+        serialized_grammar = grammar_from("regex", key_string)
         return self._from_serialized(serialized_grammar)

-    def
-
+    def dispatch_ebnf(self, key_string: str) -> Optional[GuidanceGrammar]:
+        try:
+            serialized_grammar = grammar_from("ebnf", key_string)
+            return self._from_serialized(serialized_grammar)
+        except ValueError as e:
+            logger.warning(f"Skip invalid ebnf: regex={key_string}, {e=}")
+            return None
+
+    def dispatch_structural_tag(self, key_string: str) -> Optional[GuidanceGrammar]:
+        try:
+            structural_tag = json.loads(key_string)
+            tags = [
+                StructTag(
+                    begin=structure["begin"],
+                    grammar=structure["schema"],
+                    end=structure["end"],
+                    trigger=structural_tag["triggers"][0],  # TODO?
+                )
+                for structure in structural_tag["structures"]
+            ]
+            g = StructTag.to_grammar(tags)
+            return self._from_serialized(g)
+        except Exception as e:
+            logging.warning(f"Skip invalid structural_tag: {key_string}, {e=}")
+            return None

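Note: the rewrite above moves from the old LLInterpreter API to LLMatcher plus the helpers in llguidance.torch. A rough standalone sketch of how those pieces fit together, outside of sglang (the gpt2 tokenizer and the zero logits tensor are placeholders, not anything the backend uses):

    import torch
    from llguidance import LLMatcher
    from llguidance.hf import from_tokenizer
    from llguidance.torch import (
        allocate_token_bitmask,
        apply_token_bitmask_inplace,
        fill_next_token_bitmask,
    )
    from transformers import AutoTokenizer

    # Build an llguidance tokenizer from any HF tokenizer (gpt2 is just an example).
    hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    ll_tokenizer = from_tokenizer(hf_tokenizer, None)

    # Compile a grammar (here from a JSON schema) and wrap it in a matcher.
    grammar = LLMatcher.grammar_from_json_schema(
        '{"type": "object"}', defaults={"whitespace_pattern": None}
    )
    matcher = LLMatcher(ll_tokenizer, grammar)

    # One bitmask row per request; fill row 0, then mask the logits in place.
    bitmask = allocate_token_bitmask(1, ll_tokenizer.vocab_size)
    fill_next_token_bitmask(matcher, bitmask, 0)
    logits = torch.zeros(1, ll_tokenizer.vocab_size)
    apply_token_bitmask_inplace(logits, bitmask)

    # Feed the chosen token back; consume_token returns False on a matcher error.
    token = int(logits.argmax(dim=-1))
    if not matcher.consume_token(token):
        print("matcher error:", matcher.get_error())
    print("stopped:", matcher.is_stopped())
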
sglang/srt/conversation.py
CHANGED
@@ -463,6 +463,30 @@ def generate_embedding_convs(
     return convs


+# Models in which system adds modality tokens at prompt start automatically
+# when media inputs exceed modality tokens in prompt (e.g. 3 images but 2 <image> tokens)
+_MODELS_REQUIRING_MODALITY_SUPPLEMENT = {"deepseek-vl2"}
+
+
+# adapted from https://github.com/vllm-project/vllm/blob/5124f5bf51b83e6f344c1bc6652e8c4d81313b34/vllm/entrypoints/chat_utils.py#L856
+def _get_full_multimodal_text_prompt(
+    modality_token: str, modality_count: int, text_prompt: str
+) -> str:
+    """Combine multimodal prompts for a multimodal language model."""
+
+    # For any existing placeholder in the text prompt, we leave it as is
+    left: int = modality_count - text_prompt.count(modality_token)
+    if left < 0:
+        raise ValueError(
+            f"Found more '{modality_token}' placeholders in input prompt than "
+            "actual multimodal data items."
+        )
+
+    # NOTE: For now we always add missing modality_token at the front of
+    # the prompt. This may change to be customizable in the future.
+    return "\n".join([modality_token] * left + [text_prompt])
+
+
 def generate_chat_conv(
     request: ChatCompletionRequest, template_name: str
 ) -> Conversation:
@@ -520,6 +544,12 @@ def generate_chat_conv(
                 if conv.name != "qwen2-vl"
                 else conv.image_token
             )
+            add_token_as_needed: bool = (
+                conv.name in _MODELS_REQUIRING_MODALITY_SUPPLEMENT
+            )
+            if add_token_as_needed:
+                image_token = ""
+
             audio_token = conv.audio_token
             for content in message.content:
                 if content.type == "text":
@@ -533,7 +563,10 @@ def generate_chat_conv(
                 elif content.type == "audio_url":
                     real_content += audio_token
                     conv.append_audio(content.audio_url.url)
-
+            if add_token_as_needed:
+                real_content = _get_full_multimodal_text_prompt(
+                    conv.image_token, num_image_url, real_content
+                )
             conv.append_message(conv.roles[0], real_content)
         elif msg_role == "assistant":
             parsed_content = ""
