sglang 0.4.5.post2__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff represents the changes between two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +19 -3
- sglang/bench_serving.py +8 -8
- sglang/compile_deep_gemm.py +177 -0
- sglang/lang/backend/openai.py +5 -1
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +1 -1
- sglang/srt/configs/model_config.py +11 -2
- sglang/srt/constrained/llguidance_backend.py +78 -61
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/conversation.py +34 -1
- sglang/srt/disaggregation/decode.py +96 -5
- sglang/srt/disaggregation/mini_lb.py +113 -15
- sglang/srt/disaggregation/mooncake/conn.py +199 -32
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +119 -20
- sglang/srt/disaggregation/utils.py +17 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +11 -9
- sglang/srt/function_call_parser.py +132 -0
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/attention/base_attn_backend.py +3 -0
- sglang/srt/layers/attention/flashattention_backend.py +809 -160
- sglang/srt/layers/attention/flashmla_backend.py +8 -11
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/attention/vision.py +2 -0
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +42 -5
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +2 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
- sglang/srt/layers/pooler.py +6 -0
- sglang/srt/layers/quantization/awq.py +5 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/deep_gemm.py +385 -0
- sglang/srt/layers/quantization/fp8_kernel.py +7 -38
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/int8_kernel.py +32 -1
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/w8a8_int8.py +3 -3
- sglang/srt/layers/radix_attention.py +13 -3
- sglang/srt/layers/rotary_embedding.py +176 -132
- sglang/srt/layers/sampler.py +2 -2
- sglang/srt/managers/data_parallel_controller.py +17 -4
- sglang/srt/managers/io_struct.py +21 -3
- sglang/srt/managers/mm_utils.py +85 -28
- sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
- sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
- sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
- sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
- sglang/srt/managers/schedule_batch.py +42 -12
- sglang/srt/managers/scheduler.py +47 -26
- sglang/srt/managers/tokenizer_manager.py +120 -30
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +40 -32
- sglang/srt/mem_cache/memory_pool.py +118 -13
- sglang/srt/model_executor/cuda_graph_runner.py +16 -10
- sglang/srt/model_executor/forward_batch_info.py +51 -95
- sglang/srt/model_executor/model_runner.py +29 -27
- sglang/srt/models/deepseek.py +12 -2
- sglang/srt/models/deepseek_nextn.py +101 -6
- sglang/srt/models/deepseek_v2.py +153 -76
- sglang/srt/models/deepseek_vl2.py +9 -4
- sglang/srt/models/gemma3_causal.py +1 -1
- sglang/srt/models/llama4.py +0 -1
- sglang/srt/models/minicpm3.py +2 -2
- sglang/srt/models/minicpmo.py +22 -7
- sglang/srt/models/mllama4.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +3 -6
- sglang/srt/models/qwen2_vl.py +3 -7
- sglang/srt/models/roberta.py +178 -0
- sglang/srt/openai_api/adapter.py +87 -10
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/server_args.py +65 -60
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +2 -7
- sglang/srt/torch_memory_saver_adapter.py +10 -1
- sglang/srt/utils.py +48 -6
- sglang/test/runners.py +6 -13
- sglang/test/test_utils.py +39 -19
- sglang/version.py +1 -1
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/METADATA +6 -7
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/RECORD +99 -92
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py
CHANGED
@@ -57,6 +57,7 @@ import torch
 import torch.distributed as dist
 
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.distributed.parallel_state import destroy_distributed_environment
 from sglang.srt.entrypoints.engine import _set_envs_and_config
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
@@ -85,6 +86,7 @@ class BenchArgs:
     correctness_test: bool = False
     # This is only used for correctness test
     cut_len: int = 4
+    log_decode_step: int = 0
     profile: bool = False
     profile_filename_prefix: str = "profile"
 
@@ -105,6 +107,12 @@ class BenchArgs:
         )
         parser.add_argument("--correctness-test", action="store_true")
         parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
+        parser.add_argument(
+            "--log-decode-step",
+            type=int,
+            default=BenchArgs.log_decode_step,
+            help="Log decode latency by step, default is set to zero to disable.",
+        )
         parser.add_argument(
             "--profile", action="store_true", help="Use Torch Profiler."
         )
@@ -335,6 +343,7 @@ def latency_test_run_once(
     input_len,
     output_len,
     device,
+    log_decode_step,
     profile,
     profile_filename_prefix,
 ):
@@ -394,9 +403,9 @@ def latency_test_run_once(
         tot_latency += latency
         throughput = batch_size / latency
         decode_latencies.append(latency)
-        if i < 5:
+        if i < 5 or (log_decode_step > 0 and i % log_decode_step == 0):
             rank_print(
-                f"Decode. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+                f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
             )
 
     if profile:
@@ -457,8 +466,9 @@ def latency_test(
         reqs,
         bench_args.batch_size[0],
         bench_args.input_len[0],
-
+        min(32, bench_args.output_len[0]), # shorter decoding to speed up the warmup
        server_args.device,
+        log_decode_step=0,
        profile=False,
        profile_filename_prefix="", # not used
    )
@@ -480,6 +490,7 @@ def latency_test(
             il,
             ol,
             server_args.device,
+            bench_args.log_decode_step,
             bench_args.profile if tp_rank == 0 else None,
             bench_args.profile_filename_prefix,
         )
@@ -492,8 +503,13 @@ def latency_test(
             for result in result_list:
                 fout.write(json.dumps(result) + "\n")
 
+    if server_args.tp_size > 1:
+        destroy_distributed_environment()
+
 
 def main(server_args, bench_args):
+    server_args.cuda_graph_max_bs = max(bench_args.batch_size)
+
     _set_envs_and_config(server_args)
 
     if server_args.model_path:
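
The new --log-decode-step flag changes which decode iterations print a latency line: the first five steps are always printed, and a positive value N additionally prints every N-th step. A minimal standalone sketch of that predicate (the should_log helper is hypothetical, not part of the script):

def should_log(i: int, log_decode_step: int) -> bool:
    # Mirrors the condition added in latency_test_run_once above.
    return i < 5 or (log_decode_step > 0 and i % log_decode_step == 0)

# With --log-decode-step 8, steps 0-4 plus every 8th step are logged.
assert [i for i in range(20) if should_log(i, 8)] == [0, 1, 2, 3, 4, 8, 16]
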
sglang/bench_serving.py
CHANGED
@@ -295,7 +295,7 @@ async def async_request_truss(
                 # NOTE: Some completion API might have a last
                 # usage summary response without a token so we
                 # want to check a token was generated
-                if data["choices"][0]["
+                if data["choices"][0]["text"]:
                     timestamp = time.perf_counter()
                     # First token
                     if ttft == 0.0:
@@ -307,7 +307,7 @@ async def async_request_truss(
                         output.itl.append(timestamp - most_recent_timestamp)
 
                     most_recent_timestamp = timestamp
-                    generated_text += data["choices"][0]["
+                    generated_text += data["choices"][0]["text"]
 
             output.generated_text = generated_text
             output.success = True
@@ -690,7 +690,6 @@ def sample_random_requests(
     dataset_path: str,
     random_sample: bool = True,
 ) -> List[Tuple[str, int, int]]:
-
     input_lens = np.random.randint(
         max(int(input_len * range_ratio), 1),
         input_len + 1,
@@ -978,6 +977,7 @@ async def benchmark(
     profile: bool,
     pd_seperated: bool = False,
     flush_cache: bool = False,
+    warmup_requests: int = 1,
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -995,11 +995,11 @@ async def benchmark(
         return await request_func(request_func_input=request_func_input, pbar=pbar)
 
     # Warmup
-    print(f"Starting warmup with {
+    print(f"Starting warmup with {warmup_requests} sequences...")
 
     # Use the first request for all warmup iterations
     test_prompt, test_prompt_len, test_output_len = input_requests[0]
-    if lora_names
+    if lora_names is not None and len(lora_names) != 0:
         lora_name = lora_names[0]
     else:
         lora_name = None
@@ -1017,7 +1017,7 @@ async def benchmark(
 
     # Run warmup requests
     warmup_tasks = []
-    for _ in range(
+    for _ in range(warmup_requests):
         warmup_tasks.append(
             asyncio.create_task(request_func(request_func_input=test_input))
         )
@@ -1025,7 +1025,7 @@ async def benchmark(
     warmup_outputs = await asyncio.gather(*warmup_tasks)
 
     # Check if at least one warmup request succeeded
-    if not any(output.success for output in warmup_outputs):
+    if warmup_requests > 0 and not any(output.success for output in warmup_outputs):
         raise ValueError(
             "Warmup failed - Please make sure benchmark arguments "
             f"are correctly specified. Error: {warmup_outputs[0].error}"
@@ -1057,7 +1057,7 @@ async def benchmark(
     tasks: List[asyncio.Task] = []
     async for request in get_request(input_requests, request_rate):
         prompt, prompt_len, output_len = request
-        if lora_names
+        if lora_names is not None and len(lora_names) != 0:
             idx = random.randint(0, len(lora_names) - 1)
             lora_name = lora_names[idx]
         else:
sglang/compile_deep_gemm.py
ADDED
@@ -0,0 +1,177 @@
+"""
+Compile DeepGEMM Kernels for a model with specify server arguments
+
+This script launches a server for capturing DeepGEMM calls and then compiles the kernels.
+It accepts server arguments (the same as launch_server.py).
+
+Usage:
+python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code
+
+"""
+
+import argparse
+import dataclasses
+import multiprocessing
+import os
+import time
+
+import requests
+
+from sglang.srt.entrypoints.http_server import launch_server
+from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import kill_process_tree
+from sglang.srt.warmup import warmup
+
+multiprocessing.set_start_method("spawn", force=True)
+
+# Reduce warning
+os.environ["SGL_IN_DEEPGEMM_PRECOMPILE_STAGE"] = "1"
+# Force enable deep gemm
+os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1"
+# Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case
+os.environ["SGL_CHUNKED_PREFIX_CACHE_THRESHOLD"] = "0"
+
+
+@dataclasses.dataclass
+class CompileArgs:
+    timeout: int = 3600
+
+    @staticmethod
+    def add_cli_args(parser: argparse.ArgumentParser):
+        parser.add_argument("--timeout", type=int, default=CompileArgs.timeout)
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace):
+        # use the default value's type to cast the args into correct types.
+        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
+        return cls(
+            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
+        )
+
+
+@warmup("compile-deep-gemm")
+async def warm_up_compile(tokenizer_manager: TokenizerManager):
+    print("\nGenerate warm up request for compiling DeepGEMM...\n")
+    generate_req_input = GenerateReqInput(
+        input_ids=[0, 1, 2, 3],
+        sampling_params={
+            "temperature": 0.0,
+            "max_new_tokens": 8,
+            "ignore_eos": True,
+        },
+    )
+    await tokenizer_manager.generate_request(generate_req_input, None).__anext__()
+
+
+def launch_server_internal(server_args):
+    try:
+        launch_server(server_args)
+    except Exception as e:
+        raise e
+    finally:
+        kill_process_tree(os.getpid(), include_parent=False)
+
+
+def launch_server_process_and_send_one_request(
+    server_args: ServerArgs, compile_args: CompileArgs
+):
+    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
+    proc.start()
+    base_url = f"http://{server_args.host}:{server_args.port}"
+    timeout = compile_args.timeout
+
+    start_time = time.time()
+    while time.time() - start_time < timeout:
+        try:
+            headers = {
+                "Content-Type": "application/json; charset=utf-8",
+            }
+            if server_args.node_rank == 0:
+                response = requests.get(f"{base_url}/v1/models", headers=headers)
+            else:
+                # This http api is created by launch_dummy_health_check_server for none-rank0 node.
+                response = requests.get(f"{base_url}/health", headers=headers)
+            if response.status_code == 200:
+                # Rank-0 node send a request to sync with other node and then return.
+                if server_args.node_rank == 0:
+                    response = requests.post(
+                        f"{base_url}/generate",
+                        json={
+                            "input_ids": [0, 1, 2, 3],
+                            "sampling_params": {
+                                "max_new_tokens": 8,
+                                "temperature": 0,
+                            },
+                        },
+                        timeout=600,
+                    )
+                    if response.status_code != 200:
+                        error = response.json()
+                        raise RuntimeError(f"Sync request failed: {error}")
+                # Other nodes should wait for the exit signal from Rank-0 node.
+                else:
+                    start_time_waiting = time.time()
+                    while proc.is_alive():
+                        if time.time() - start_time_waiting < timeout:
+                            time.sleep(10)
+                        else:
+                            raise TimeoutError("Waiting for main node timeout!")
+                return proc
+        except requests.RequestException:
+            pass
+        time.sleep(10)
+    raise TimeoutError(
+        "DeepGEMM Kernels compilation timeout."
+        "\n\nFeel free and please restart the command."
+    )
+
+
+def refine_server_args(server_args: ServerArgs, compile_args: CompileArgs):
+    # Disbale cuda graph and torch compile to save time
+    server_args.disable_cuda_graph = True
+    server_args.enable_torch_compile = False
+    print(f"Disable CUDA Graph and Torch Compile to save time...")
+
+    # Set watchdog timeout to compile_args.timeout because compilation will take a long time
+    server_args.watchdog_timeout = compile_args.timeout
+    server_args.warmups = "compile-deep-gemm"
+
+
+def run_compile(server_args: ServerArgs, compile_args: CompileArgs):
+    print(
+        "Begin DeepGEMM Kernels compilation...\n"
+        "It may take a long time and timeout maybe raised "
+        "while the compilation is still in progress.\n"
+        "Just feel free to restart the command "
+        "until the compilation is fully finished.\n"
+    )
+
+    proc = launch_server_process_and_send_one_request(server_args, compile_args)
+
+    print("\nDeepGEMM Kernels compilation finished successfully.")
+
+    # Sleep for safety
+    time.sleep(10)
+    if proc.is_alive():
+        # This is the rank0 node.
+        kill_process_tree(proc.pid)
+    else:
+        try:
+            kill_process_tree(proc.pid)
+        except Exception:
+            pass
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    CompileArgs.add_cli_args(parser)
+    args = parser.parse_args()
+    server_args = ServerArgs.from_cli_args(args)
+    compile_args = CompileArgs.from_cli_args(args)
+
+    refine_server_args(server_args, compile_args)
+
+    run_compile(server_args, compile_args)
sglang/lang/backend/openai.py
CHANGED
@@ -161,7 +161,11 @@ class OpenAI(BaseBackend):
         prompt = s.text_
 
         kwargs = sampling_params.to_openai_kwargs()
-        if
+        if (
+            self.model_name.startswith("o1")
+            or self.model_name.startswith("o3")
+            or "o1" in self.model_name
+        ):
             kwargs.pop("max_tokens", None)
         else:
             kwargs.pop("max_completion_tokens", None)
sglang/lang/backend/runtime_endpoint.py
CHANGED
@@ -324,7 +324,11 @@ class RuntimeEndpoint(BaseBackend):
 
     def _assert_success(self, res):
         if res.status_code != 200:
-
+            try:
+                content = res.json()
+            except json.JSONDecodeError:
+                content = res.text
+            raise RuntimeError(content)
 
 
 def compute_normalized_prompt_logprobs(input_logprobs):
sglang/srt/code_completion_parser.py
CHANGED
@@ -113,7 +113,7 @@ def completion_template_exists(template_name: str) -> bool:
 
 def is_completion_template_defined() -> bool:
     global completion_template_name
-    return completion_template_name
+    return completion_template_name is not None
 
 
 def generate_completion_prompt_from_request(request: ChatCompletionRequest) -> str:
sglang/srt/configs/deepseekvl2.py
CHANGED
@@ -182,7 +182,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
         tokenized_str, images, seq_mask, spatial_crop = self.tokenize_with_images(
             messages,
             pil_images[image_index : image_index + image_token_cnt],
-            bos=
+            bos=True,
             eos=True,
             cropping=len(pil_images) <= 2,
             max_req_input_len=max_req_input_len,
sglang/srt/configs/model_config.py
CHANGED
@@ -73,8 +73,15 @@ class ModelConfig:
         )
 
         if enable_multimodal is None:
-
+            mm_disabled_models = [
+                "Gemma3ForConditionalGeneration",
+                "Llama4ForConditionalGeneration",
+            ]
+            if self.hf_config.architectures[0] in mm_disabled_models:
                 enable_multimodal = False
+                logger.info(
+                    f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
+                )
             else:
                 enable_multimodal = True
 
@@ -155,7 +162,9 @@ class ModelConfig:
             self.attention_arch = AttentionArch.MLA
             self.kv_lora_rank = self.hf_config.kv_lora_rank
             self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
-        elif "DeepseekVL2ForCausalLM" in self.hf_config.architectures
+        elif "DeepseekVL2ForCausalLM" in self.hf_config.architectures and getattr(
+            self.hf_text_config, "use_mla", True
+        ):
             self.head_dim = 256
             self.attention_arch = AttentionArch.MLA
             self.kv_lora_rank = self.hf_text_config.kv_lora_rank
sglang/srt/constrained/llguidance_backend.py
CHANGED
@@ -14,49 +14,48 @@
 """Constrained decoding with llguidance backend."""
 
 import json
+import logging
 import os
 from typing import List, Optional, Tuple
 
-import llguidance
-import llguidance.hf
-import llguidance.torch
 import torch
-from llguidance
+from llguidance import LLMatcher, LLTokenizer, StructTag, grammar_from
+from llguidance.hf import from_tokenizer
+from llguidance.torch import (
+    allocate_token_bitmask,
+    apply_token_bitmask_inplace,
+    fill_next_token_bitmask,
+)
 
 from sglang.srt.constrained.base_grammar_backend import (
     BaseGrammarBackend,
     BaseGrammarObject,
 )
 
+logger = logging.getLogger(__name__)
+
 
 class GuidanceGrammar(BaseGrammarObject):
-
-
-    ):
+
+    def __init__(self, llguidance_tokenizer: LLTokenizer, serialized_grammar: str):
         super().__init__()
         self.llguidance_tokenizer = llguidance_tokenizer
         self.serialized_grammar = serialized_grammar
 
-
-        self.ll_interpreter = llguidance.LLInterpreter(
+        self.ll_matcher = LLMatcher(
             self.llguidance_tokenizer,
             self.serialized_grammar,
-            enable_backtrack=False,
-            enable_ff_tokens=False,
             log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
         )
-        self.pending_ff_tokens: list[int] = []
         self.finished = False
         self.bitmask = None
 
     def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
-
-
-        ff_tokens
-
-        return
-
-        return None
+        ff_tokens = self.ll_matcher.compute_ff_tokens()
+        if ff_tokens:
+            return ff_tokens, ""
+        else:
+            return None
 
     def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
         return "", -1
@@ -67,32 +66,22 @@ class GuidanceGrammar(BaseGrammarObject):
         pass
 
     def accept_token(self, token: int):
-
-
-
-        ff_tokens = ff_tokens[1:]
-        self.pending_ff_tokens.extend(ff_tokens)
+        if not self.ll_matcher.consume_token(token):
+            logger.warning(f"matcher error: {self.ll_matcher.get_error()}")
+            self.finished = True
 
     def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
-        if
-        # if we have pending fast-forward tokens,
-        # just return them immediately
-        ff_token = self.pending_ff_tokens.pop(0)
-        vocab_mask[idx, :] = 0
-        vocab_mask[idx, ff_token // 32] = 1 << (ff_token % 32)
-        return
-
-        if self.ll_interpreter.has_pending_stop():
+        if self.ll_matcher.is_stopped():
             self.finished = True
 
-
+        fill_next_token_bitmask(self.ll_matcher, vocab_mask, idx)
 
     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device
     ) -> torch.Tensor:
         if self.bitmask is None or self.bitmask.shape[0] < batch_size:
             # only create bitmask when batch gets larger
-            self.bitmask =
+            self.bitmask = allocate_token_bitmask(
                 batch_size, self.llguidance_tokenizer.vocab_size
             )
         bitmask = self.bitmask
@@ -107,7 +96,7 @@ class GuidanceGrammar(BaseGrammarObject):
 
     @staticmethod
     def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-
+        apply_token_bitmask_inplace(logits, vocab_mask)
 
     def copy(self):
         return GuidanceGrammar(
@@ -117,36 +106,64 @@ class GuidanceGrammar(BaseGrammarObject):
 
 
 class GuidanceBackend(BaseGrammarBackend):
-
+
+    def __init__(
+        self,
+        tokenizer,
+        whitespace_pattern: Optional[str] = None,
+        n_vocab: Optional[int] = None,
+    ):
         super().__init__()
 
         self.tokenizer = tokenizer
-        self.
-
-
-
-
-
-
-
-
+        self.whitespace_pattern = whitespace_pattern
+        self.llguidance_tokenizer = from_tokenizer(self.tokenizer, n_vocab)
+
+    def _from_serialized(self, serialized_grammar) -> Optional[GuidanceGrammar]:
+        try:
+            return GuidanceGrammar(
+                llguidance_tokenizer=self.llguidance_tokenizer,
+                serialized_grammar=serialized_grammar,
+            )
+        except Exception as e:
+            logger.warning(f"Skip invalid grammar: {serialized_grammar}, {e=}")
+            return None
+
+    def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]:
+        serialized_grammar = LLMatcher.grammar_from_json_schema(
+            key_string,
+            defaults={
+                "whitespace_pattern": self.whitespace_pattern,
+            },
         )
-
-    def dispatch_json(self, key_string: str) -> GuidanceGrammar:
-        json_schema = key_string
-        compiler = llguidance.JsonCompiler(whitespace_flexible=self.whitespace_flexible)
-        serialized_grammar = compiler.compile(json_schema)
-        return self._from_serialized(serialized_grammar)
-
-    def dispatch_regex(self, key_string: str) -> GuidanceGrammar:
-        compiler = llguidance.RegexCompiler()
-        serialized_grammar = compiler.compile(regex=key_string)
         return self._from_serialized(serialized_grammar)
 
-    def
-
-        serialized_grammar = compiler.compile(any_to_lark(key_string))
+    def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]:
+        serialized_grammar = grammar_from("regex", key_string)
         return self._from_serialized(serialized_grammar)
 
-    def
-
+    def dispatch_ebnf(self, key_string: str) -> Optional[GuidanceGrammar]:
+        try:
+            serialized_grammar = grammar_from("ebnf", key_string)
+            return self._from_serialized(serialized_grammar)
+        except ValueError as e:
+            logger.warning(f"Skip invalid ebnf: regex={key_string}, {e=}")
+            return None
+
+    def dispatch_structural_tag(self, key_string: str) -> Optional[GuidanceGrammar]:
+        try:
+            structural_tag = json.loads(key_string)
+            tags = [
+                StructTag(
+                    begin=structure["begin"],
+                    grammar=structure["schema"],
+                    end=structure["end"],
+                    trigger=structural_tag["triggers"][0],  # TODO?
+                )
+                for structure in structural_tag["structures"]
+            ]
+            g = StructTag.to_grammar(tags)
+            return self._from_serialized(g)
+        except Exception as e:
+            logging.warning(f"Skip invalid structural_tag: {key_string}, {e=}")
+            return None
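
For orientation, here is a rough standalone sketch of one constrained decoding step using the same llguidance calls as the rewritten backend above. It is not the backend's actual control flow; the "gpt2" tokenizer, the example JSON schema, and the random logits are stand-in assumptions, and error handling is omitted:

import torch
from llguidance import LLMatcher
from llguidance.hf import from_tokenizer
from llguidance.torch import (
    allocate_token_bitmask,
    apply_token_bitmask_inplace,
    fill_next_token_bitmask,
)
from transformers import AutoTokenizer

# Stand-ins: any HF tokenizer works; random logits stand in for real model output.
hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
ll_tokenizer = from_tokenizer(hf_tokenizer, None)
grammar = LLMatcher.grammar_from_json_schema(
    '{"type": "object"}', defaults={"whitespace_pattern": None}
)
matcher = LLMatcher(ll_tokenizer, grammar)

logits = torch.randn(1, ll_tokenizer.vocab_size)
vocab_mask = allocate_token_bitmask(1, ll_tokenizer.vocab_size)

fill_next_token_bitmask(matcher, vocab_mask, 0)  # compute the mask for batch row 0
apply_token_bitmask_inplace(logits, vocab_mask)  # disallowed tokens become -inf
next_token = int(torch.argmax(logits, dim=-1))

if not matcher.consume_token(next_token):  # advance the matcher (cf. accept_token)
    print("matcher error:", matcher.get_error())
if matcher.is_stopped():  # grammar complete (cf. fill_vocab_mask)
    print("generation can stop")
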
sglang/srt/constrained/xgrammar_backend.py
CHANGED
@@ -158,6 +158,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
     def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:
             if key_string == "$$ANY$$":
+                # Note: This builtin JSON grammar includes *all* valid JSON (including, for example, arrays at the root)
                 ctx = self.grammar_compiler.compile_builtin_json_grammar()
             else:
                 ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
sglang/srt/conversation.py
CHANGED
@@ -463,6 +463,30 @@ def generate_embedding_convs(
     return convs
 
 
+# Models in which system adds modality tokens at prompt start automatically
+# when media inputs exceed modality tokens in prompt (e.g. 3 images but 2 <image> tokens)
+_MODELS_REQUIRING_MODALITY_SUPPLEMENT = {"deepseek-vl2"}
+
+
+# adapted from https://github.com/vllm-project/vllm/blob/5124f5bf51b83e6f344c1bc6652e8c4d81313b34/vllm/entrypoints/chat_utils.py#L856
+def _get_full_multimodal_text_prompt(
+    modality_token: str, modality_count: int, text_prompt: str
+) -> str:
+    """Combine multimodal prompts for a multimodal language model."""
+
+    # For any existing placeholder in the text prompt, we leave it as is
+    left: int = modality_count - text_prompt.count(modality_token)
+    if left < 0:
+        raise ValueError(
+            f"Found more '{modality_token}' placeholders in input prompt than "
+            "actual multimodal data items."
+        )
+
+    # NOTE: For now we always add missing modality_token at the front of
+    # the prompt. This may change to be customizable in the future.
+    return "\n".join([modality_token] * left + [text_prompt])
+
+
 def generate_chat_conv(
     request: ChatCompletionRequest, template_name: str
 ) -> Conversation:
@@ -520,6 +544,12 @@ def generate_chat_conv(
                 if conv.name != "qwen2-vl"
                 else conv.image_token
             )
+            add_token_as_needed: bool = (
+                conv.name in _MODELS_REQUIRING_MODALITY_SUPPLEMENT
+            )
+            if add_token_as_needed:
+                image_token = ""
+
             audio_token = conv.audio_token
             for content in message.content:
                 if content.type == "text":
@@ -533,7 +563,10 @@ def generate_chat_conv(
                 elif content.type == "audio_url":
                     real_content += audio_token
                     conv.append_audio(content.audio_url.url)
-
+            if add_token_as_needed:
+                real_content = _get_full_multimodal_text_prompt(
+                    conv.image_token, num_image_url, real_content
+                )
             conv.append_message(conv.roles[0], real_content)
         elif msg_role == "assistant":
             parsed_content = ""
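
A small worked example of the _get_full_multimodal_text_prompt helper added above (the inputs are illustrative): three images with only one <image> placeholder in the text means two placeholders are prepended, joined by newlines.

from sglang.srt.conversation import _get_full_multimodal_text_prompt

# left = 3 - 1 = 2 missing placeholders, prepended ahead of the original text.
prompt = _get_full_multimodal_text_prompt("<image>", 3, "Compare these: <image>")
assert prompt == "<image>\n<image>\nCompare these: <image>"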