sglang 0.3.5.post2__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +2 -2
- sglang/bench_latency.py +1 -553
- sglang/bench_offline_throughput.py +48 -20
- sglang/bench_one_batch.py +472 -0
- sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
- sglang/bench_serving.py +125 -6
- sglang/check_env.py +3 -6
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +2 -2
- sglang/srt/configs/model_config.py +13 -14
- sglang/srt/constrained/__init__.py +13 -14
- sglang/srt/constrained/base_grammar_backend.py +13 -15
- sglang/srt/constrained/outlines_backend.py +28 -17
- sglang/srt/constrained/outlines_jump_forward.py +13 -15
- sglang/srt/constrained/xgrammar_backend.py +47 -58
- sglang/srt/conversation.py +13 -15
- sglang/srt/hf_transformers_utils.py +13 -15
- sglang/srt/layers/activation.py +16 -13
- sglang/srt/layers/attention/flashinfer_backend.py +106 -54
- sglang/srt/layers/attention/triton_backend.py +9 -7
- sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
- sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
- sglang/srt/layers/custom_op_util.py +25 -0
- sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
- sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +11 -4
- sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
- sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
- sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
- sglang/srt/layers/fused_moe_triton/layer.py +633 -0
- sglang/srt/layers/layernorm.py +17 -15
- sglang/srt/layers/logits_processor.py +23 -25
- sglang/srt/layers/quantization/__init__.py +77 -17
- sglang/srt/layers/radix_attention.py +13 -15
- sglang/srt/layers/rotary_embedding.py +13 -13
- sglang/srt/layers/sampler.py +4 -8
- sglang/srt/layers/torchao_utils.py +2 -0
- sglang/srt/lora/lora.py +13 -14
- sglang/srt/lora/lora_config.py +13 -14
- sglang/srt/lora/lora_manager.py +22 -24
- sglang/srt/managers/data_parallel_controller.py +98 -27
- sglang/srt/managers/detokenizer_manager.py +13 -15
- sglang/srt/managers/io_struct.py +63 -21
- sglang/srt/managers/schedule_batch.py +154 -59
- sglang/srt/managers/schedule_policy.py +18 -16
- sglang/srt/managers/scheduler.py +278 -109
- sglang/srt/managers/session_controller.py +61 -0
- sglang/srt/managers/tokenizer_manager.py +63 -18
- sglang/srt/managers/tp_worker.py +25 -16
- sglang/srt/managers/tp_worker_overlap_thread.py +62 -67
- sglang/srt/metrics/collector.py +13 -15
- sglang/srt/metrics/func_timer.py +13 -15
- sglang/srt/mm_utils.py +13 -14
- sglang/srt/model_executor/cuda_graph_runner.py +63 -25
- sglang/srt/model_executor/forward_batch_info.py +128 -32
- sglang/srt/model_executor/model_runner.py +132 -64
- sglang/srt/model_parallel.py +98 -0
- sglang/srt/models/chatglm.py +15 -16
- sglang/srt/models/commandr.py +15 -16
- sglang/srt/models/dbrx.py +15 -16
- sglang/srt/models/deepseek.py +15 -15
- sglang/srt/models/deepseek_v2.py +162 -59
- sglang/srt/models/exaone.py +14 -15
- sglang/srt/models/gemma.py +14 -14
- sglang/srt/models/gemma2.py +31 -25
- sglang/srt/models/gemma2_reward.py +13 -14
- sglang/srt/models/gpt_bigcode.py +14 -14
- sglang/srt/models/grok.py +15 -15
- sglang/srt/models/internlm2.py +13 -15
- sglang/srt/models/internlm2_reward.py +13 -14
- sglang/srt/models/llama.py +21 -21
- sglang/srt/models/llama_classification.py +13 -14
- sglang/srt/models/llama_reward.py +13 -14
- sglang/srt/models/llava.py +14 -16
- sglang/srt/models/llavavid.py +14 -16
- sglang/srt/models/minicpm.py +13 -15
- sglang/srt/models/minicpm3.py +13 -15
- sglang/srt/models/mistral.py +13 -15
- sglang/srt/models/mixtral.py +15 -15
- sglang/srt/models/mixtral_quant.py +14 -14
- sglang/srt/models/olmo.py +22 -20
- sglang/srt/models/olmoe.py +23 -20
- sglang/srt/models/phi3_small.py +447 -0
- sglang/srt/models/qwen.py +14 -14
- sglang/srt/models/qwen2.py +22 -19
- sglang/srt/models/qwen2_moe.py +17 -18
- sglang/srt/models/qwen2_vl.py +13 -6
- sglang/srt/models/stablelm.py +18 -16
- sglang/srt/models/torch_native_llama.py +107 -93
- sglang/srt/models/xverse.py +13 -14
- sglang/srt/models/xverse_moe.py +15 -16
- sglang/srt/models/yivl.py +13 -15
- sglang/srt/openai_api/adapter.py +19 -17
- sglang/srt/openai_api/protocol.py +14 -16
- sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
- sglang/srt/sampling/sampling_batch_info.py +61 -57
- sglang/srt/sampling/sampling_params.py +14 -16
- sglang/srt/server.py +86 -35
- sglang/srt/server_args.py +96 -80
- sglang/srt/utils.py +266 -68
- sglang/test/few_shot_gsm8k.py +8 -4
- sglang/test/runners.py +38 -20
- sglang/test/srt/sampling/penaltylib/utils.py +23 -21
- sglang/test/test_utils.py +31 -20
- sglang/version.py +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +66 -57
- sglang-0.3.6.post1.dist-info/RECORD +164 -0
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +1 -1
- sglang/srt/layers/fused_moe/__init__.py +0 -1
- sglang-0.3.5.post2.dist-info/RECORD +0 -156
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
sglang/bench_serving.py
CHANGED
@@ -15,6 +15,7 @@ import argparse
 import asyncio
 import json
 import os
+import pickle
 import random
 import resource
 import sys
@@ -24,6 +25,7 @@ import warnings
 from argparse import ArgumentParser
 from dataclasses import dataclass, field
 from datetime import datetime
+from pathlib import Path
 from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

 import aiohttp
@@ -387,8 +389,26 @@ async def async_request_gserver(
     raise NotImplementedError()


+async def async_request_profile(api_url: str) -> RequestFuncOutput:
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+        output = RequestFuncOutput()
+        try:
+            async with session.post(url=api_url) as response:
+                if response.status == 200:
+                    output.success = True
+                else:
+                    output.error = response.reason or ""
+                    output.success = False
+        except Exception:
+            output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
+
+    return output
+
+
 def get_model(pretrained_model_name_or_path: str) -> str:
-    if os.getenv("SGLANG_USE_MODELSCOPE", "
+    if os.getenv("SGLANG_USE_MODELSCOPE", "false").lower() == "true":
         import huggingface_hub.constants
         from modelscope import snapshot_download

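The `async_request_profile` helper added here is a bare POST that only records success or failure; `benchmark()` later points it at the server's `/start_profile` and `/stop_profile` routes. A minimal standalone sketch of the same call pattern, assuming an sglang server is already listening on the default local port (the URL and port are illustrative):

import asyncio

import aiohttp


async def toggle_profiler(base_url: str = "http://127.0.0.1:30000") -> None:
    # An empty POST is enough; the server drives the Torch profiler itself.
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{base_url}/start_profile") as resp:
            print("start_profile:", resp.status)
        # ... issue benchmark requests here while the profiler records ...
        async with session.post(f"{base_url}/stop_profile") as resp:
            print("stop_profile:", resp.status)


asyncio.run(toggle_profiler())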
@@ -674,6 +694,19 @@ def gen_prompt(tokenizer, token_num):
     return tokenizer.decode(selected_tokens)


+def get_gen_prefix_cache_path(args, tokenizer):
+    """Create cache directory under ~/.cache/sglang/benchmark"""
+    cache_dir = Path.home() / ".cache" / "sglang" / "benchmark"
+
+    # Create a unique cache filename based on the generation parameters
+    cache_key = (
+        f"gen_prefix_{args.gen_num_groups}_{args.gen_prompts_per_group}_"
+        f"{args.gen_system_prompt_len}_{args.gen_question_len}_{args.gen_output_len}_"
+        f"{tokenizer.__class__.__name__}.pkl"
+    )
+    return cache_dir / cache_key
+
+
 def sample_generated_shared_prefix_requests(
     num_groups: int,
     prompts_per_group: int,
@@ -682,7 +715,17 @@ def sample_generated_shared_prefix_requests(
     output_len: int,
     tokenizer: PreTrainedTokenizerBase,
 ) -> List[Tuple[str, int, int]]:
-    """Generate benchmark requests with shared system prompts using random tokens."""
+    """Generate benchmark requests with shared system prompts using random tokens and caching."""
+    cache_path = get_gen_prefix_cache_path(args, tokenizer)
+
+    # Try to load from cache first
+    if cache_path.exists():
+        print(f"\nLoading cached generated input data from {cache_path}")
+        with open(cache_path, "rb") as f:
+            return pickle.load(f)
+
+    print("\nGenerating new input data...")
+
     # Generate system prompts for each group
     system_prompts = []
     for _ in range(num_groups):
@@ -700,9 +743,11 @@
     total_input_tokens = 0
     total_output_tokens = 0

-    for group_idx in range(num_groups):
+    for group_idx in tqdm(range(num_groups), desc="Generating system prompt"):
         system_prompt = system_prompts[group_idx]
-        for prompt_idx in
+        for prompt_idx in tqdm(
+            range(prompts_per_group), desc="Generating questions", leave=False
+        ):
             question = questions[group_idx * prompts_per_group + prompt_idx]
             full_prompt = f"{system_prompt}\n\n{question}"
             prompt_len = len(tokenizer.encode(full_prompt))
@@ -711,6 +756,10 @@
             total_input_tokens += prompt_len
             total_output_tokens += output_len

+    # Shuffle questions
+    random.shuffle(input_requests)
+
+    # Print statistics
     print(f"\nGenerated shared prefix dataset statistics:")
     print(f"Number of groups: {num_groups}")
     print(f"Prompts per group: {prompts_per_group}")
@@ -724,6 +773,12 @@
         f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n"
     )

+    # Save to cache
+    cache_path.parent.mkdir(parents=True, exist_ok=True)
+    print(f"Caching generated input data to {cache_path}")
+    with open(cache_path, "wb") as f:
+        pickle.dump(input_requests, f)
+
     return input_requests

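Taken together, the caching added above is a plain pickle round-trip keyed on the generation parameters. A condensed sketch of the pattern, with a hypothetical `load_or_build` helper and a shortened cache key standing in for the real one:

import pickle
from pathlib import Path
from typing import Callable, List, Tuple


def load_or_build(
    cache_path: Path, build: Callable[[], List[Tuple[str, int, int]]]
) -> List[Tuple[str, int, int]]:
    # Reuse the previous result if a file with the same parameter key exists.
    if cache_path.exists():
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    data = build()
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    with open(cache_path, "wb") as f:
        pickle.dump(data, f)
    return data


# Hypothetical key: only parameters that change the generated data belong in it.
path = Path.home() / ".cache" / "sglang" / "benchmark" / "gen_prefix_64_16.pkl"
requests = load_or_build(path, lambda: [("system prompt\n\nquestion", 128, 256)])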
@@ -822,18 +877,30 @@ def calculate_metrics(
 async def benchmark(
     backend: str,
     api_url: str,
+    base_url: str,
     model_id: str,
     tokenizer: PreTrainedTokenizerBase,
     input_requests: List[Tuple[str, int, int]],
     request_rate: float,
+    max_concurrency: Optional[int],
     disable_tqdm: bool,
     extra_request_body: Dict[str, Any],
+    profile: bool,
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
     else:
         raise ValueError(f"Unknown backend: {backend}")

+    # From https://github.com/vllm-project/vllm/pull/9390
+    semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
+
+    async def limited_request_func(request_func_input, pbar):
+        if semaphore is None:
+            return await request_func(request_func_input=request_func_input, pbar=pbar)
+        async with semaphore:
+            return await request_func(request_func_input=request_func_input, pbar=pbar)
+
     print("Starting initial single prompt test run...")
     test_prompt, test_prompt_len, test_output_len = input_requests[0]
     test_input = RequestFuncInput(
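The concurrency cap borrowed from vllm PR #9390 is the standard asyncio idiom: every request coroutine is wrapped in a shared semaphore, so at most `max_concurrency` bodies run at once while `--request-rate` still controls how quickly tasks are launched. A self-contained sketch with a sleep standing in for the HTTP call:

import asyncio


async def fake_request(i: int) -> int:
    await asyncio.sleep(0.1)  # stand-in for one HTTP benchmark request
    return i


async def main(max_concurrency: int = 4) -> None:
    semaphore = asyncio.Semaphore(max_concurrency)

    async def limited(i: int) -> int:
        # Tasks are still created eagerly; only max_concurrency bodies
        # execute at a time, the rest queue on the semaphore.
        async with semaphore:
            return await fake_request(i)

    results = await asyncio.gather(*(limited(i) for i in range(20)))
    print(results)


asyncio.run(main())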
@@ -855,6 +922,14 @@

     time.sleep(1.5)

+    if profile:
+        print("Starting profiler...")
+        profile_output = await async_request_profile(
+            api_url=base_url + "/start_profile"
+        )
+        if profile_output.success:
+            print("Profiler started")
+
     pbar = None if disable_tqdm else tqdm(total=len(input_requests))

     benchmark_start_time = time.perf_counter()
@@ -871,11 +946,17 @@
         )
         tasks.append(
             asyncio.create_task(
-
+                limited_request_func(request_func_input=request_func_input, pbar=pbar)
             )
         )
     outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)

+    if profile:
+        print("Stopping profiler...")
+        profile_output = await async_request_profile(api_url=base_url + "/stop_profile")
+        if profile_output.success:
+            print("Profiler stopped")
+
     if pbar is not None:
         pbar.close()

@@ -892,6 +973,12 @@
     print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
     print("{:<40} {:<10}".format("Backend:", backend))
     print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
+    print(
+        "{:<40} {:<10}".format(
+            "Max reqeuest concurrency:",
+            max_concurrency if max_concurrency else "not set",
+        )
+    )
     print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
     print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
     print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
@@ -955,6 +1042,7 @@
         "backend": args.backend,
         "dataset_name": args.dataset_name,
         "request_rate": request_rate,
+        "max_concurrency": max_concurrency,
         "total_input_tokens": metrics.total_input,
         "total_output_tokens": metrics.total_output,
         "total_output_tokens_retokenized": metrics.total_output_retokenized,
@@ -1042,6 +1130,10 @@ def run_benchmark(args_: argparse.Namespace):
     global args
     args = args_

+    # Set default value for max_concurrency if not present
+    if not hasattr(args, "max_concurrency"):
+        args.max_concurrency = None
+
     # Set global environments
     set_ulimit()
     random.seed(args.seed)
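The `hasattr` guard exists because `run_benchmark` also accepts a `Namespace` built programmatically by callers that predate the new flag. A small illustration of the same backfill; the fields shown are placeholders, not the full set the benchmark actually needs:

from argparse import Namespace

# An externally constructed args object from before --max-concurrency existed.
args = Namespace(backend="sglang", host="127.0.0.1", port=30000)

if not hasattr(args, "max_concurrency"):
    args.max_concurrency = None  # the same default run_benchmark fills in

print(args.max_concurrency)  # None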
@@ -1100,6 +1192,9 @@
         if args.base_url
         else f"http://{args.host}:{args.port}/v1/models/model:predict"
     )
+    base_url = (
+        f"http://{args.host}:{args.port}" if args.base_url is None else args.base_url
+    )

     # Get model name
     if args.model is None:
@@ -1145,12 +1240,15 @@
             benchmark(
                 backend=backend,
                 api_url=api_url,
+                base_url=base_url,
                 model_id=model_id,
                 tokenizer=tokenizer,
                 input_requests=input_requests,
                 request_rate=args.request_rate,
+                max_concurrency=args.max_concurrency,
                 disable_tqdm=args.disable_tqdm,
                 extra_request_body=extra_request_body,
+                profile=args.profile,
             )
         )
     else:
@@ -1162,12 +1260,15 @@
                 benchmark(
                     backend=backend,
                     api_url=api_url,
+                    base_url=base_url,
                     model_id=model_id,
                     tokenizer=tokenizer,
                     input_requests=input_requests,
                     request_rate=rate,
+                    max_concurrency=args.max_concurrency,
                     disable_tqdm=args.disable_tqdm,
                     extra_request_body=extra_request_body,
+                    profile=args.profile,
                 )
             )

@@ -1264,6 +1365,19 @@ if __name__ == "__main__":
         help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
         "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.",
     )
+    parser.add_argument(
+        "--max-concurrency",
+        type=int,
+        default=None,
+        help="Maximum number of concurrent requests. This can be used "
+        "to help simulate an environment where a higher level component "
+        "is enforcing a maximum number of concurrent requests. While the "
+        "--request-rate argument controls the rate at which requests are "
+        "initiated, this argument will control how many are actually allowed "
+        "to execute at a time. This means that when used in combination, the "
+        "actual request rate may be lower than specified with --request-rate, "
+        "if the server is not processing requests fast enough to keep up.",
+    )
     parser.add_argument("--seed", type=int, default=1, help="The random seed.")
     parser.add_argument(
         "--multi",
@@ -1331,6 +1445,11 @@
         default=256,
         help="Target length in tokens for outputs in generated-shared-prefix dataset",
     )
-
+    parser.add_argument(
+        "--profile",
+        action="store_true",
+        help="Use Torch Profiler. The endpoint must be launched with "
+        "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
+    )
     args = parser.parse_args()
     run_benchmark(args)
sglang/check_env.py
CHANGED
@@ -15,24 +15,21 @@ PACKAGE_LIST = [
     "flashinfer",
     "triton",
     "transformers",
-    "
-    "tqdm",
+    "torchao",
     "numpy",
     "aiohttp",
     "fastapi",
     "hf_transfer",
     "huggingface_hub",
     "interegular",
-    "packaging",
-    "PIL",
     "psutil",
     "pydantic",
+    "multipart",
+    "zmq",
     "uvicorn",
     "uvloop",
-    "zmq",
     "vllm",
     "outlines",
-    "multipart",
     "openai",
     "tiktoken",
     "anthropic",
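The entries in PACKAGE_LIST are import names rather than PyPI distribution names, which is why the list says `multipart` and `zmq` instead of python-multipart and pyzmq. A rough sketch of a version probe keyed on import names (an assumption about the approach, not check_env's actual implementation):

import importlib


def get_package_version(name: str) -> str:
    # Import by module name and read the conventional __version__ attribute.
    try:
        module = importlib.import_module(name)
        return getattr(module, "__version__", "unknown")
    except ImportError:
        return "not installed"


print(get_package_version("numpy"))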
sglang/lang/backend/runtime_endpoint.py
CHANGED
@@ -58,9 +58,9 @@ class RuntimeEndpoint(BaseBackend):
         )
         self._assert_success(res)

-    def
+    def get_server_info(self):
         res = http_request(
-            self.base_url + "/
+            self.base_url + "/get_server_info",
             api_key=self.api_key,
             verify=self.verify,
         )
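Caller code stays a one-line method call. A hedged usage sketch, assuming a local server on the default port and that `http_request` returns a response whose body holds the server info:

from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint

backend = RuntimeEndpoint("http://127.0.0.1:30000")
info = backend.get_server_info()  # queries the /get_server_info route
print(info)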
sglang/srt/configs/model_config.py
CHANGED
@@ -1,17 +1,16 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================

 import json
 import logging
sglang/srt/constrained/__init__.py
CHANGED
@@ -1,17 +1,16 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================

 # TODO(lmzheng): make this an optional dependency
 from sglang.srt.constrained.outlines_backend import build_regex_from_object
sglang/srt/constrained/base_grammar_backend.py
CHANGED
@@ -1,18 +1,16 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """The baseclass of a backend for grammar-guided constrained decoding."""

 from concurrent.futures import Future, ThreadPoolExecutor
sglang/srt/constrained/outlines_backend.py
CHANGED
@@ -1,18 +1,16 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Constrained decoding with outlines backend."""

 import json
@@ -81,9 +79,22 @@ class OutlinesGrammar(BaseGrammarObject):
     ):
         self.state = next_state

-    def
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        tokens = torch.tensor(
+            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
+        ).to(vocab_mask.device, non_blocking=True)
+        vocab_mask = vocab_mask[idx]
         vocab_mask.fill_(1)
-        vocab_mask
+        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
+        logits.masked_fill_(vocab_mask, float("-inf"))

     def copy(self):
         return OutlinesGrammar(self.guide, self.jump_forward_map)
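The three new methods split grammar enforcement into allocate / fill / apply phases: allocate a boolean [batch_size, vocab_size] mask, mark every token banned except the ones the outlines guide currently allows, then push banned logits to -inf. A minimal sketch of that flow, with a fixed allowed set standing in for the guide lookup:

import torch

batch_size, vocab_size = 2, 8

# allocate_vocab_mask: one boolean row per request, True meaning "banned".
vocab_mask = torch.zeros(batch_size, vocab_size, dtype=torch.bool)

# fill_vocab_mask for request 0: ban everything, then re-allow the tokens
# the grammar permits (hard-coded here in place of guide.get_next_instruction).
allowed = torch.tensor([1, 4, 5], dtype=torch.int64)
row = vocab_mask[0]
row.fill_(True)
row.scatter_(0, allowed, torch.zeros_like(allowed, dtype=torch.bool))

# apply_vocab_mask: banned positions become -inf and can never be sampled.
logits = torch.randn(batch_size, vocab_size)
logits[0].masked_fill_(row, float("-inf"))
print(logits[0])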
sglang/srt/constrained/outlines_jump_forward.py
CHANGED
@@ -1,18 +1,16 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """
 Faster constrained decoding with jump forward decoding / compressed finite state machine.
 Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/