sglang 0.3.5.post1__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_latency.py +1 -553
- sglang/bench_offline_throughput.py +337 -0
- sglang/bench_one_batch.py +474 -0
- sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
- sglang/bench_serving.py +115 -31
- sglang/check_env.py +3 -6
- sglang/srt/constrained/base_grammar_backend.py +4 -3
- sglang/srt/constrained/outlines_backend.py +39 -26
- sglang/srt/constrained/xgrammar_backend.py +58 -14
- sglang/srt/layers/activation.py +3 -0
- sglang/srt/layers/attention/flashinfer_backend.py +93 -48
- sglang/srt/layers/attention/triton_backend.py +9 -7
- sglang/srt/layers/custom_op_util.py +26 -0
- sglang/srt/layers/fused_moe/fused_moe.py +11 -4
- sglang/srt/layers/fused_moe/patch.py +4 -2
- sglang/srt/layers/layernorm.py +4 -0
- sglang/srt/layers/logits_processor.py +10 -10
- sglang/srt/layers/sampler.py +4 -8
- sglang/srt/layers/torchao_utils.py +2 -0
- sglang/srt/managers/data_parallel_controller.py +74 -9
- sglang/srt/managers/detokenizer_manager.py +1 -14
- sglang/srt/managers/io_struct.py +27 -0
- sglang/srt/managers/schedule_batch.py +104 -38
- sglang/srt/managers/schedule_policy.py +5 -1
- sglang/srt/managers/scheduler.py +210 -56
- sglang/srt/managers/session_controller.py +62 -0
- sglang/srt/managers/tokenizer_manager.py +38 -0
- sglang/srt/managers/tp_worker.py +12 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
- sglang/srt/model_executor/cuda_graph_runner.py +43 -6
- sglang/srt/model_executor/forward_batch_info.py +109 -15
- sglang/srt/model_executor/model_runner.py +102 -43
- sglang/srt/model_parallel.py +98 -0
- sglang/srt/models/deepseek_v2.py +147 -44
- sglang/srt/models/gemma2.py +9 -8
- sglang/srt/models/llava.py +1 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/phi3_small.py +447 -0
- sglang/srt/models/qwen2_vl.py +13 -6
- sglang/srt/models/torch_native_llama.py +94 -78
- sglang/srt/openai_api/adapter.py +11 -4
- sglang/srt/openai_api/protocol.py +30 -27
- sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
- sglang/srt/sampling/sampling_batch_info.py +58 -57
- sglang/srt/sampling/sampling_params.py +3 -3
- sglang/srt/server.py +29 -2
- sglang/srt/server_args.py +97 -60
- sglang/srt/utils.py +103 -51
- sglang/test/runners.py +25 -6
- sglang/test/srt/sampling/penaltylib/utils.py +23 -21
- sglang/test/test_utils.py +33 -22
- sglang/version.py +1 -1
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/RECORD +62 -56
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
sglang/bench_serving.py
CHANGED
@@ -15,6 +15,7 @@ import argparse
 import asyncio
 import json
 import os
+import pickle
 import random
 import resource
 import sys
@@ -387,6 +388,24 @@ async def async_request_gserver(
     raise NotImplementedError()
 
 
+async def async_request_profile(api_url: str) -> RequestFuncOutput:
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+        output = RequestFuncOutput()
+        try:
+            async with session.post(url=api_url) as response:
+                if response.status == 200:
+                    output.success = True
+                else:
+                    output.error = response.reason or ""
+                    output.success = False
+        except Exception:
+            output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
+
+    return output
+
+
 def get_model(pretrained_model_name_or_path: str) -> str:
     if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true":
         import huggingface_hub.constants
@@ -421,6 +440,37 @@ def get_tokenizer(
     )
 
 
+def get_dataset(args, tokenizer):
+    if args.dataset_name == "sharegpt":
+        input_requests = sample_sharegpt_requests(
+            dataset_path=args.dataset_path,
+            num_requests=args.num_prompts,
+            tokenizer=tokenizer,
+            fixed_output_len=args.sharegpt_output_len,
+        )
+    elif args.dataset_name == "random":
+        input_requests = sample_random_requests(
+            input_len=args.random_input_len,
+            output_len=args.random_output_len,
+            num_prompts=args.num_prompts,
+            range_ratio=args.random_range_ratio,
+            tokenizer=tokenizer,
+            dataset_path=args.dataset_path,
+        )
+    elif args.dataset_name == "generated-shared-prefix":
+        input_requests = sample_generated_shared_prefix_requests(
+            num_groups=args.gen_num_groups,
+            prompts_per_group=args.gen_prompts_per_group,
+            system_prompt_len=args.gen_system_prompt_len,
+            question_len=args.gen_question_len,
+            output_len=args.gen_output_len,
+            tokenizer=tokenizer,
+        )
+    else:
+        raise ValueError(f"Unknown dataset: {args.dataset_name}")
+    return input_requests
+
+
 ASYNC_REQUEST_FUNCS = {
     "sglang": async_request_sglang_generate,
     "sglang-native": async_request_sglang_generate,
@@ -443,6 +493,8 @@ class BenchmarkMetrics:
     input_throughput: float
     output_throughput: float
     output_throughput_retokenized: float
+    total_throughput: float
+    total_throughput_retokenized: float
     mean_ttft_ms: float
     median_ttft_ms: float
     std_ttft_ms: float
@@ -590,7 +642,6 @@ def sample_random_requests(
         (data["conversations"][0]["value"], data["conversations"][1]["value"])
         for data in dataset
     ]
-
     # Shuffle the dataset.
     random.shuffle(dataset)
 
@@ -650,6 +701,11 @@ def sample_generated_shared_prefix_requests(
     output_len: int,
     tokenizer: PreTrainedTokenizerBase,
 ) -> List[Tuple[str, int, int]]:
+    if args.generated_input_path and os.path.exists(args.generated_input_path):
+        print(f"\nloading generated input data from {args.generated_input_path}")
+        with open(args.generated_input_path, "rb") as f:
+            return pickle.load(f)
+
     """Generate benchmark requests with shared system prompts using random tokens."""
     # Generate system prompts for each group
     system_prompts = []
@@ -663,6 +719,9 @@ def sample_generated_shared_prefix_requests(
         question = gen_prompt(tokenizer, question_len)
         questions.append(question)
 
+    # Shuffle questions
+    random.shuffle(questions)
+
     # Combine system prompts with questions
     input_requests = []
     total_input_tokens = 0
@@ -691,6 +750,11 @@ def sample_generated_shared_prefix_requests(
     print(
         f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n"
     )
+    if args.generated_input_save_path:
+        print(f"Saving generated input data to {args.generated_input_save_path}")
+        os.makedirs(os.path.dirname(args.generated_input_save_path), exist_ok=True)
+        with open(args.generated_input_save_path, "wb") as f:
+            pickle.dump(input_requests, f)
 
     return input_requests
 
@@ -764,6 +828,9 @@ def calculate_metrics(
         input_throughput=total_input / dur_s,
         output_throughput=sum(output_lens) / dur_s,
         output_throughput_retokenized=sum(retokenized_output_lens) / dur_s,
+        total_throughput=(total_input + sum(output_lens)) / dur_s,
+        total_throughput_retokenized=(total_input + sum(retokenized_output_lens))
+        / dur_s,
         mean_ttft_ms=np.mean(ttfts or 0)
         * 1000,  # ttfts is empty if streaming is not supported by backend
         median_ttft_ms=np.median(ttfts or 0) * 1000,
@@ -787,12 +854,14 @@ def calculate_metrics(
 async def benchmark(
     backend: str,
     api_url: str,
+    base_url: str,
     model_id: str,
     tokenizer: PreTrainedTokenizerBase,
     input_requests: List[Tuple[str, int, int]],
     request_rate: float,
     disable_tqdm: bool,
     extra_request_body: Dict[str, Any],
+    profile: bool,
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -820,6 +889,14 @@ async def benchmark(
 
     time.sleep(1.5)
 
+    if profile:
+        print("Starting profiler...")
+        profile_output = await async_request_profile(
+            api_url=base_url + "/start_profile"
+        )
+        if profile_output.success:
+            print("Profiler started")
+
     pbar = None if disable_tqdm else tqdm(total=len(input_requests))
 
     benchmark_start_time = time.perf_counter()
@@ -841,6 +918,12 @@ async def benchmark(
     )
     outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
 
+    if profile:
+        print("Stopping profiler...")
+        profile_output = await async_request_profile(api_url=base_url + "/stop_profile")
+        if profile_output.success:
+            print("Profiler stopped")
+
     if pbar is not None:
         pbar.close()
 
@@ -881,6 +964,11 @@ async def benchmark(
             "Output token throughput (tok/s):", metrics.output_throughput
         )
     )
+    print(
+        "{:<40} {:<10.2f}".format(
+            "Total token throughput (tok/s):", metrics.total_throughput
+        )
+    )
     print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
     print(
         "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
@@ -1060,6 +1148,9 @@ def run_benchmark(args_: argparse.Namespace):
         if args.base_url
        else f"http://{args.host}:{args.port}/v1/models/model:predict"
    )
+    base_url = (
+        f"http://{args.host}:{args.port}" if args.base_url is None else args.base_url
+    )
 
     # Get model name
     if args.model is None:
@@ -1098,47 +1189,21 @@ def run_benchmark(args_: argparse.Namespace):
 
     tokenizer = get_tokenizer(tokenizer_id)
 
-
-        assert args.random_input_len is None and args.random_output_len is None
-        input_requests = sample_sharegpt_requests(
-            dataset_path=args.dataset_path,
-            num_requests=args.num_prompts,
-            tokenizer=tokenizer,
-            fixed_output_len=args.sharegpt_output_len,
-        )
-    elif args.dataset_name == "random":
-        assert args.random_input_len is not None and args.random_output_len is not None
-        input_requests = sample_random_requests(
-            input_len=args.random_input_len,
-            output_len=args.random_output_len,
-            num_prompts=args.num_prompts,
-            range_ratio=args.random_range_ratio,
-            tokenizer=tokenizer,
-            dataset_path=args.dataset_path,
-        )
-    elif args.dataset_name == "generated-shared-prefix":
-        input_requests = sample_generated_shared_prefix_requests(
-            num_groups=args.gen_num_groups,
-            prompts_per_group=args.gen_prompts_per_group,
-            system_prompt_len=args.gen_system_prompt_len,
-            question_len=args.gen_question_len,
-            output_len=args.gen_output_len,
-            tokenizer=tokenizer,
-        )
-    else:
-        raise ValueError(f"Unknown dataset: {args.dataset_name}")
+    input_requests = get_dataset(args, tokenizer)
 
     if not args.multi:
         return asyncio.run(
             benchmark(
                 backend=backend,
                 api_url=api_url,
+                base_url=base_url,
                 model_id=model_id,
                 tokenizer=tokenizer,
                 input_requests=input_requests,
                 request_rate=args.request_rate,
                 disable_tqdm=args.disable_tqdm,
                 extra_request_body=extra_request_body,
+                profile=args.profile,
             )
         )
     else:
@@ -1150,12 +1215,14 @@ def run_benchmark(args_: argparse.Namespace):
                 benchmark(
                     backend=backend,
                     api_url=api_url,
+                    base_url=base_url,
                     model_id=model_id,
                     tokenizer=tokenizer,
                     input_requests=input_requests,
                     request_rate=rate,
                     disable_tqdm=args.disable_tqdm,
                     extra_request_body=extra_request_body,
+                    profile=args.profile,
                 )
             )
 
@@ -1229,10 +1296,12 @@ if __name__ == "__main__":
     parser.add_argument(
         "--random-input-len",
         type=int,
+        default=1024,
         help="Number of input tokens per request, used only for random dataset.",
     )
     parser.add_argument(
         "--random-output-len",
+        default=1024,
         type=int,
         help="Number of output tokens per request, used only for random dataset.",
     )
@@ -1317,6 +1386,21 @@ if __name__ == "__main__":
         default=256,
         help="Target length in tokens for outputs in generated-shared-prefix dataset",
     )
-
+    parser.add_argument(
+        "--generated-input-save-path",
+        type=str,
+        help="Path to save generated input data",
+    )
+    parser.add_argument(
+        "--generated-input-path",
+        type=str,
+        help="Path to load previously generated input data",
+    )
+    parser.add_argument(
+        "--profile",
+        action="store_true",
+        help="Use Torch Profiler. The endpoint must be launched with "
+        "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
+    )
     args = parser.parse_args()
     run_benchmark(args)
sglang/check_env.py
CHANGED
@@ -15,24 +15,21 @@ PACKAGE_LIST = [
|
|
15
15
|
"flashinfer",
|
16
16
|
"triton",
|
17
17
|
"transformers",
|
18
|
-
"
|
19
|
-
"tqdm",
|
18
|
+
"torchao",
|
20
19
|
"numpy",
|
21
20
|
"aiohttp",
|
22
21
|
"fastapi",
|
23
22
|
"hf_transfer",
|
24
23
|
"huggingface_hub",
|
25
24
|
"interegular",
|
26
|
-
"packaging",
|
27
|
-
"PIL",
|
28
25
|
"psutil",
|
29
26
|
"pydantic",
|
27
|
+
"multipart",
|
28
|
+
"zmq",
|
30
29
|
"uvicorn",
|
31
30
|
"uvloop",
|
32
|
-
"zmq",
|
33
31
|
"vllm",
|
34
32
|
"outlines",
|
35
|
-
"multipart",
|
36
33
|
"openai",
|
37
34
|
"tiktoken",
|
38
35
|
"anthropic",
|
sglang/srt/constrained/base_grammar_backend.py
CHANGED
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""The baseclass of
+"""The baseclass of a backend for grammar-guided constrained decoding."""
 
 from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
@@ -52,7 +52,7 @@ class BaseGrammarBackend:
         else:
             entry.value = self.init_value_impl(key)
             entry.event.set()
-        return entry.value.copy()
+        return entry.value.copy() if entry.value else None
 
     def init_value_impl(self, key: Tuple[str, str]) -> BaseGrammarObject:
         raise NotImplementedError()
@@ -62,7 +62,8 @@ class BaseGrammarBackend:
         entry = self.cache.get(key)
         if not entry or not entry.event.is_set():
             return None
-
+        val = self.cache[key].value
+        return val.copy() if val else None
 
     def get_future_value(self, key: Tuple[str, str]) -> Future:
         return self.executor.submit(self.init_value, key)
sglang/srt/constrained/outlines_backend.py
CHANGED
@@ -19,9 +19,12 @@ import json
 import logging
 from typing import Dict, List, Optional, Tuple, Union
 
+import interegular
 import torch
 from outlines.fsm.guide import RegexGuide
+from outlines.fsm.json_schema import build_regex_from_schema
 from outlines.models.transformers import TransformerTokenizer
+from pydantic import BaseModel
 
 from sglang.srt.constrained.base_grammar_backend import (
     BaseGrammarBackend,
@@ -32,26 +35,6 @@ from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap
 logger = logging.getLogger(__name__)
 
 
-try:
-    from outlines.fsm.json_schema import build_regex_from_object
-except ImportError:
-    # Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema,
-    # which only accepts string schema as input.
-    from outlines.fsm.json_schema import build_regex_from_schema
-    from pydantic import BaseModel
-
-    def build_regex_from_object(
-        object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
-    ):
-        if isinstance(object, type(BaseModel)):
-            schema = json.dumps(object.model_json_schema())
-        elif isinstance(object, Dict):
-            schema = json.dumps(object)
-        else:
-            schema = object
-        return build_regex_from_schema(schema, whitespace_pattern)
-
-
 class OutlinesGrammar(BaseGrammarObject):
     def __init__(
         self,
@@ -98,9 +81,22 @@ class OutlinesGrammar(BaseGrammarObject):
     ):
         self.state = next_state
 
-    def
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        tokens = torch.tensor(
+            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
+        ).to(vocab_mask.device, non_blocking=True)
+        vocab_mask = vocab_mask[idx]
         vocab_mask.fill_(1)
-        vocab_mask
+        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
+        logits.masked_fill_(vocab_mask, float("-inf"))
 
     def copy(self):
         return OutlinesGrammar(self.guide, self.jump_forward_map)
@@ -147,19 +143,36 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
                    key_string,
                    whitespace_pattern=self.whitespace_pattern,
                )
-            except NotImplementedError as e:
+            except (NotImplementedError, json.decoder.JSONDecodeError) as e:
                 logger.warning(
-                    f"
+                    f"Skip invalid json_schema: json_schema={key_string}, {e=}"
                 )
-                return None
+                return None
         elif key_type == "regex":
             regex = key_string
         else:
             raise ValueError(f"Invalid key_type: {key_type}")
 
-
+        try:
+            guide = RegexGuide(regex, self.outlines_tokenizer)
+        except interegular.patterns.InvalidSyntax as e:
+            logger.warning(f"skip invalid regex schema: {regex=}, {e=}")
+            return None
+
         if self.allow_jump_forward:
             jump_forward_map = OutlinesJumpForwardMap(regex)
         else:
             jump_forward_map = None
         return OutlinesGrammar(guide, jump_forward_map)
+
+
+def build_regex_from_object(
+    object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
+):
+    if isinstance(object, type(BaseModel)):
+        schema = json.dumps(object.model_json_schema())
+    elif isinstance(object, Dict):
+        schema = json.dumps(object)
+    else:
+        schema = object
+    return build_regex_from_schema(schema, whitespace_pattern)
sglang/srt/constrained/xgrammar_backend.py
CHANGED
@@ -15,16 +15,34 @@ limitations under the License.
 
 """Constrained decoding with xgrammar backend."""
 
+import logging
 from typing import List, Tuple
 
 import torch
-
+
+try:
+    from xgrammar import (
+        CachedGrammarCompiler,
+        CompiledGrammar,
+        GrammarMatcher,
+        TokenizerInfo,
+    )
+
+    import_error = None
+except ImportError as e:
+    CachedGrammarCompiler = CompiledGrammar = GrammarMatcher = TokenizerInfo = (
+        ImportError
+    )
+    import_error = e
 
 from sglang.srt.constrained.base_grammar_backend import (
     BaseGrammarBackend,
     BaseGrammarObject,
 )
 
+logger = logging.getLogger(__name__)
+
+
 MAX_ROLLBACK_TOKENS = 10
 
 
@@ -67,19 +85,23 @@ class XGrammarGrammar(BaseGrammarObject):
         for i in range(k, len(new_output_ids)):
             assert self.matcher.accept_token(new_output_ids[i])
 
-    def
-
-
-
-
-
-
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return self.matcher.allocate_token_bitmask(vocab_size, batch_size)
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        self.matcher.fill_next_token_bitmask(vocab_mask, idx)
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
+        GrammarMatcher.apply_token_bitmask_inplace(logits, vocab_mask)
 
     def copy(self):
         matcher = GrammarMatcher(
             self.ctx,
             max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-
+            vocab_size=self.vocab_size,
         )
         return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
 
@@ -91,24 +113,46 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         vocab_size: int,
     ):
         super().__init__()
-
+
+        if import_error:
+            logger.warning(
+                f"Ignore import error for the grammar backend: {import_error}"
+            )
+            self.grammar_cache = None
+            return
+
+        tokenizer_info = TokenizerInfo.from_huggingface(tokenizer)
+        self.grammar_cache = CachedGrammarCompiler(tokenizer_info=tokenizer_info)
         self.vocab_size = vocab_size
 
     def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
+        if import_error:
+            raise import_error
+
         key_type, key_string = key
         if key_type == "json":
-
+            try:
+                ctx = self.grammar_cache.compile_json_schema_grammar(schema=key_string)
+            except RuntimeError as e:
+                logging.warning(
+                    f"Skip invalid json_schema: json_schema={key_string}, {e=}"
+                )
+                return None
         elif key_type == "regex":
-
+            logger.warning(
+                "regex hasn't been supported by xgrammar yet. This is skipped."
+            )
+            return None
         else:
             raise ValueError(f"Invalid key_type: {key_type}")
 
         matcher = GrammarMatcher(
             ctx,
             max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-
+            vocab_size=self.vocab_size,
         )
         return XGrammarGrammar(matcher, self.vocab_size, ctx)
 
     def reset(self):
-        self.grammar_cache
+        if self.grammar_cache:
+            self.grammar_cache.clear()
sglang/srt/layers/activation.py
CHANGED
@@ -32,12 +32,14 @@ from vllm.distributed import (
 )
 from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.utils import set_weight_attrs
 
 logger = logging.getLogger(__name__)
 
 
+@register_custom_op("sglang_silu_and_mul")
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
@@ -51,6 +53,7 @@ class SiluAndMul(CustomOp):
         return out
 
 
+@register_custom_op("sglang_gelu_and_mul")
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
         super().__init__()