sglang 0.3.5__py3-none-any.whl → 0.3.5.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +113 -3
- sglang/srt/configs/model_config.py +5 -2
- sglang/srt/constrained/__init__.py +2 -66
- sglang/srt/constrained/base_grammar_backend.py +72 -0
- sglang/srt/constrained/outlines_backend.py +165 -0
- sglang/srt/constrained/outlines_jump_forward.py +182 -0
- sglang/srt/constrained/xgrammar_backend.py +114 -0
- sglang/srt/layers/attention/triton_ops/decode_attention.py +7 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +6 -0
- sglang/srt/layers/fused_moe/fused_moe.py +23 -7
- sglang/srt/layers/quantization/base_config.py +4 -6
- sglang/srt/layers/vocab_parallel_embedding.py +216 -150
- sglang/srt/managers/io_struct.py +5 -3
- sglang/srt/managers/schedule_batch.py +14 -20
- sglang/srt/managers/scheduler.py +153 -94
- sglang/srt/managers/tokenizer_manager.py +81 -17
- sglang/srt/metrics/collector.py +211 -0
- sglang/srt/metrics/func_timer.py +108 -0
- sglang/srt/mm_utils.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +2 -2
- sglang/srt/model_executor/forward_batch_info.py +7 -3
- sglang/srt/model_executor/model_runner.py +2 -1
- sglang/srt/models/gemma2_reward.py +69 -0
- sglang/srt/models/gpt2.py +31 -37
- sglang/srt/models/internlm2_reward.py +62 -0
- sglang/srt/models/llama.py +11 -6
- sglang/srt/models/llama_reward.py +5 -26
- sglang/srt/models/qwen2_vl.py +5 -7
- sglang/srt/openai_api/adapter.py +6 -2
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/sampling/sampling_params.py +0 -14
- sglang/srt/server.py +58 -16
- sglang/srt/server_args.py +42 -22
- sglang/srt/utils.py +87 -0
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_mgsm.py +2 -2
- sglang/test/test_utils.py +18 -4
- sglang/utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/METADATA +11 -7
- {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/RECORD +45 -42
- {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/WHEEL +1 -1
- sglang/srt/constrained/base_tool_cache.py +0 -65
- sglang/srt/constrained/bnf_cache.py +0 -61
- sglang/srt/constrained/fsm_cache.py +0 -95
- sglang/srt/constrained/grammar.py +0 -190
- sglang/srt/constrained/jump_forward.py +0 -203
- {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/top_level.txt +0 -0
sglang/bench_serving.py
CHANGED
@@ -596,12 +596,20 @@ def sample_random_requests(
|
|
596
596
|
|
597
597
|
# Filter out sequences that are too long or too short
|
598
598
|
input_requests: List[Tuple[str, int, int]] = []
|
599
|
-
for i in range(num_prompts):
|
599
|
+
for data in dataset:
|
600
|
+
i = len(input_requests)
|
601
|
+
if i == num_prompts:
|
602
|
+
break
|
603
|
+
|
600
604
|
# Tokenize the prompts and completions.
|
601
|
-
prompt = dataset[i][0]
|
605
|
+
prompt = data[0]
|
602
606
|
prompt_token_ids = tokenizer.encode(prompt)
|
603
607
|
prompt_len = len(prompt_token_ids)
|
604
608
|
|
609
|
+
# Skip empty prompt
|
610
|
+
if prompt_len == 0:
|
611
|
+
continue
|
612
|
+
|
605
613
|
if prompt_len > input_lens[i]:
|
606
614
|
input_ids = prompt_token_ids[: input_lens[i]]
|
607
615
|
else:
|
@@ -627,6 +635,66 @@ def sample_random_requests(
|
|
627
635
|
return input_requests
|
628
636
|
|
629
637
|
|
638
|
+
def gen_prompt(tokenizer, token_num):
|
639
|
+
"""Generate a random prompt of specified token length using tokenizer vocabulary."""
|
640
|
+
all_available_tokens = list(tokenizer.get_vocab().values())
|
641
|
+
selected_tokens = random.choices(all_available_tokens, k=token_num)
|
642
|
+
return tokenizer.decode(selected_tokens)
|
643
|
+
|
644
|
+
|
645
|
+
def sample_generated_shared_prefix_requests(
|
646
|
+
num_groups: int,
|
647
|
+
prompts_per_group: int,
|
648
|
+
system_prompt_len: int,
|
649
|
+
question_len: int,
|
650
|
+
output_len: int,
|
651
|
+
tokenizer: PreTrainedTokenizerBase,
|
652
|
+
) -> List[Tuple[str, int, int]]:
|
653
|
+
"""Generate benchmark requests with shared system prompts using random tokens."""
|
654
|
+
# Generate system prompts for each group
|
655
|
+
system_prompts = []
|
656
|
+
for _ in range(num_groups):
|
657
|
+
system_prompt = gen_prompt(tokenizer, system_prompt_len)
|
658
|
+
system_prompts.append(system_prompt)
|
659
|
+
|
660
|
+
# Generate questions
|
661
|
+
questions = []
|
662
|
+
for _ in range(num_groups * prompts_per_group):
|
663
|
+
question = gen_prompt(tokenizer, question_len)
|
664
|
+
questions.append(question)
|
665
|
+
|
666
|
+
# Combine system prompts with questions
|
667
|
+
input_requests = []
|
668
|
+
total_input_tokens = 0
|
669
|
+
total_output_tokens = 0
|
670
|
+
|
671
|
+
for group_idx in range(num_groups):
|
672
|
+
system_prompt = system_prompts[group_idx]
|
673
|
+
for prompt_idx in range(prompts_per_group):
|
674
|
+
question = questions[group_idx * prompts_per_group + prompt_idx]
|
675
|
+
full_prompt = f"{system_prompt}\n\n{question}"
|
676
|
+
prompt_len = len(tokenizer.encode(full_prompt))
|
677
|
+
|
678
|
+
input_requests.append((full_prompt, prompt_len, output_len))
|
679
|
+
total_input_tokens += prompt_len
|
680
|
+
total_output_tokens += output_len
|
681
|
+
|
682
|
+
print(f"\nGenerated shared prefix dataset statistics:")
|
683
|
+
print(f"Number of groups: {num_groups}")
|
684
|
+
print(f"Prompts per group: {prompts_per_group}")
|
685
|
+
print(f"Total prompts: {len(input_requests)}")
|
686
|
+
print(f"Total input tokens: {total_input_tokens}")
|
687
|
+
print(f"Total output tokens: {total_output_tokens}")
|
688
|
+
print(
|
689
|
+
f"Average system prompt length: {sum(len(tokenizer.encode(sp)) for sp in system_prompts) / len(system_prompts):.1f} tokens"
|
690
|
+
)
|
691
|
+
print(
|
692
|
+
f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n"
|
693
|
+
)
|
694
|
+
|
695
|
+
return input_requests
|
696
|
+
|
697
|
+
|
630
698
|
async def get_request(
|
631
699
|
input_requests: List[Tuple[str, int, int]],
|
632
700
|
request_rate: float,
|
@@ -1048,6 +1116,15 @@ def run_benchmark(args_: argparse.Namespace):
|
|
1048
1116
|
tokenizer=tokenizer,
|
1049
1117
|
dataset_path=args.dataset_path,
|
1050
1118
|
)
|
1119
|
+
elif args.dataset_name == "generated-shared-prefix":
|
1120
|
+
input_requests = sample_generated_shared_prefix_requests(
|
1121
|
+
num_groups=args.gen_num_groups,
|
1122
|
+
prompts_per_group=args.gen_prompts_per_group,
|
1123
|
+
system_prompt_len=args.gen_system_prompt_len,
|
1124
|
+
question_len=args.gen_question_len,
|
1125
|
+
output_len=args.gen_output_len,
|
1126
|
+
tokenizer=tokenizer,
|
1127
|
+
)
|
1051
1128
|
else:
|
1052
1129
|
raise ValueError(f"Unknown dataset: {args.dataset_name}")
|
1053
1130
|
|
@@ -1121,7 +1198,7 @@ if __name__ == "__main__":
|
|
1121
1198
|
"--dataset-name",
|
1122
1199
|
type=str,
|
1123
1200
|
default="sharegpt",
|
1124
|
-
choices=["sharegpt", "random"],
|
1201
|
+
choices=["sharegpt", "random", "generated-shared-prefix"],
|
1125
1202
|
help="Name of the dataset to benchmark on.",
|
1126
1203
|
)
|
1127
1204
|
parser.add_argument(
|
@@ -1208,5 +1285,38 @@ if __name__ == "__main__":
|
|
1208
1285
|
help="Append given JSON object to the request payload. You can use this to specify"
|
1209
1286
|
"additional generate params like sampling params.",
|
1210
1287
|
)
|
1288
|
+
|
1289
|
+
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
|
1290
|
+
group.add_argument(
|
1291
|
+
"--gen-num-groups",
|
1292
|
+
type=int,
|
1293
|
+
default=64,
|
1294
|
+
help="Number of system prompt groups for generated-shared-prefix dataset",
|
1295
|
+
)
|
1296
|
+
group.add_argument(
|
1297
|
+
"--gen-prompts-per-group",
|
1298
|
+
type=int,
|
1299
|
+
default=16,
|
1300
|
+
help="Number of prompts per system prompt group for generated-shared-prefix dataset",
|
1301
|
+
)
|
1302
|
+
group.add_argument(
|
1303
|
+
"--gen-system-prompt-len",
|
1304
|
+
type=int,
|
1305
|
+
default=2048,
|
1306
|
+
help="Target length in tokens for system prompts in generated-shared-prefix dataset",
|
1307
|
+
)
|
1308
|
+
group.add_argument(
|
1309
|
+
"--gen-question-len",
|
1310
|
+
type=int,
|
1311
|
+
default=128,
|
1312
|
+
help="Target length in tokens for questions in generated-shared-prefix dataset",
|
1313
|
+
)
|
1314
|
+
group.add_argument(
|
1315
|
+
"--gen-output-len",
|
1316
|
+
type=int,
|
1317
|
+
default=256,
|
1318
|
+
help="Target length in tokens for outputs in generated-shared-prefix dataset",
|
1319
|
+
)
|
1320
|
+
|
1211
1321
|
args = parser.parse_args()
|
1212
1322
|
run_benchmark(args)
|
@@ -39,7 +39,7 @@ class ModelConfig:
|
|
39
39
|
revision: Optional[str] = None,
|
40
40
|
context_length: Optional[int] = None,
|
41
41
|
model_override_args: Optional[dict] = None,
|
42
|
-
is_embedding: Optional[bool] = None
|
42
|
+
is_embedding: Optional[bool] = None,
|
43
43
|
) -> None:
|
44
44
|
# Parse args
|
45
45
|
self.model_override_args = json.loads(model_override_args)
|
@@ -52,7 +52,9 @@ class ModelConfig:
|
|
52
52
|
self.hf_text_config = get_hf_text_config(self.hf_config)
|
53
53
|
|
54
54
|
# Check model type
|
55
|
-
self.is_generation = is_generation_model(
|
55
|
+
self.is_generation = is_generation_model(
|
56
|
+
self.hf_config.architectures, is_embedding
|
57
|
+
)
|
56
58
|
self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
|
57
59
|
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
|
58
60
|
|
@@ -208,6 +210,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
|
|
208
210
|
or "MistralModel" in model_architectures
|
209
211
|
or "LlamaForSequenceClassification" in model_architectures
|
210
212
|
or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
|
213
|
+
or "InternLM2ForRewardModel" in model_architectures
|
211
214
|
):
|
212
215
|
return False
|
213
216
|
else:
|
@@ -13,69 +13,5 @@ See the License for the specific language governing permissions and
|
|
13
13
|
limitations under the License.
|
14
14
|
"""
|
15
15
|
|
16
|
-
|
17
|
-
|
18
|
-
import json
|
19
|
-
from typing import Dict, Optional, Union
|
20
|
-
|
21
|
-
from pydantic import BaseModel
|
22
|
-
|
23
|
-
try:
|
24
|
-
from outlines.caching import cache as disk_cache
|
25
|
-
from outlines.caching import disable_cache
|
26
|
-
from outlines.fsm.guide import RegexGuide
|
27
|
-
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
|
28
|
-
from outlines.models.transformers import TransformerTokenizer
|
29
|
-
except ImportError as e:
|
30
|
-
print(
|
31
|
-
f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n'
|
32
|
-
)
|
33
|
-
raise
|
34
|
-
|
35
|
-
try:
|
36
|
-
from outlines.fsm.json_schema import build_regex_from_object
|
37
|
-
except ImportError:
|
38
|
-
# Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema,
|
39
|
-
# which only accepts string schema as input.
|
40
|
-
from outlines.fsm.json_schema import build_regex_from_schema
|
41
|
-
|
42
|
-
def build_regex_from_object(
|
43
|
-
object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
|
44
|
-
):
|
45
|
-
if isinstance(object, type(BaseModel)):
|
46
|
-
schema = json.dumps(object.model_json_schema())
|
47
|
-
elif isinstance(object, Dict):
|
48
|
-
schema = json.dumps(object)
|
49
|
-
else:
|
50
|
-
schema = object
|
51
|
-
return build_regex_from_schema(schema, whitespace_pattern)
|
52
|
-
|
53
|
-
|
54
|
-
try:
|
55
|
-
from xgrammar import (
|
56
|
-
GrammarMatcher,
|
57
|
-
GrammarMatcherInitContext,
|
58
|
-
GrammarMatcherInitContextCache,
|
59
|
-
)
|
60
|
-
except ImportError as e:
|
61
|
-
|
62
|
-
class Dummy:
|
63
|
-
pass
|
64
|
-
|
65
|
-
GrammarMatcher = Dummy
|
66
|
-
GrammarMatcherInitContext = Dummy
|
67
|
-
GrammarMatcherInitContextCache = Dummy
|
68
|
-
|
69
|
-
__all__ = [
|
70
|
-
"RegexGuide",
|
71
|
-
"FSMInfo",
|
72
|
-
"make_deterministic_fsm",
|
73
|
-
"build_regex_from_object",
|
74
|
-
"TransformerTokenizer",
|
75
|
-
"disk_cache",
|
76
|
-
"disable_cache",
|
77
|
-
"make_byte_level_fsm",
|
78
|
-
"GrammarMatcher",
|
79
|
-
"GrammarMatcherInitContext",
|
80
|
-
"GrammarMatcherInitContextCache",
|
81
|
-
]
|
16
|
+
# TODO(lmzheng): make this an optional dependency
|
17
|
+
from sglang.srt.constrained.outlines_backend import build_regex_from_object
|
@@ -0,0 +1,72 @@
|
|
1
|
+
"""
|
2
|
+
Copyright 2023-2024 SGLang Team
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
"""
|
15
|
+
|
16
|
+
"""The baseclass of backends for grammar-guided constrained decoding."""
|
17
|
+
|
18
|
+
from concurrent.futures import Future, ThreadPoolExecutor
|
19
|
+
from dataclasses import dataclass
|
20
|
+
from threading import Event, Lock
|
21
|
+
from typing import Any, Optional, Tuple
|
22
|
+
|
23
|
+
|
24
|
+
@dataclass
|
25
|
+
class CacheEntry:
|
26
|
+
value: Any
|
27
|
+
event: Event
|
28
|
+
|
29
|
+
|
30
|
+
class BaseGrammarObject:
|
31
|
+
pass
|
32
|
+
|
33
|
+
|
34
|
+
class BaseGrammarBackend:
|
35
|
+
def __init__(self):
|
36
|
+
self.executor = ThreadPoolExecutor()
|
37
|
+
self.cache = {}
|
38
|
+
self.cache_lock = Lock()
|
39
|
+
|
40
|
+
def init_value(self, key: Tuple[str, str]) -> BaseGrammarObject:
|
41
|
+
with self.cache_lock:
|
42
|
+
if key in self.cache:
|
43
|
+
cache_hit = True
|
44
|
+
entry = self.cache[key]
|
45
|
+
else:
|
46
|
+
cache_hit = False
|
47
|
+
entry = CacheEntry(None, Event())
|
48
|
+
self.cache[key] = entry
|
49
|
+
|
50
|
+
if cache_hit:
|
51
|
+
entry.event.wait()
|
52
|
+
else:
|
53
|
+
entry.value = self.init_value_impl(key)
|
54
|
+
entry.event.set()
|
55
|
+
return entry.value.copy()
|
56
|
+
|
57
|
+
def init_value_impl(self, key: Tuple[str, str]) -> BaseGrammarObject:
|
58
|
+
raise NotImplementedError()
|
59
|
+
|
60
|
+
def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
|
61
|
+
with self.cache_lock:
|
62
|
+
entry = self.cache.get(key)
|
63
|
+
if not entry or not entry.event.is_set():
|
64
|
+
return None
|
65
|
+
return self.cache[key].value.copy()
|
66
|
+
|
67
|
+
def get_future_value(self, key: Tuple[str, str]) -> Future:
|
68
|
+
return self.executor.submit(self.init_value, key)
|
69
|
+
|
70
|
+
def reset(self):
|
71
|
+
with self.cache_lock:
|
72
|
+
self.cache.clear()
|
@@ -0,0 +1,165 @@
|
|
1
|
+
"""
|
2
|
+
Copyright 2023-2024 SGLang Team
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
"""
|
15
|
+
|
16
|
+
"""Constrained decoding with outlines backend."""
|
17
|
+
|
18
|
+
import json
|
19
|
+
import logging
|
20
|
+
from typing import Dict, List, Optional, Tuple, Union
|
21
|
+
|
22
|
+
import torch
|
23
|
+
from outlines.fsm.guide import RegexGuide
|
24
|
+
from outlines.models.transformers import TransformerTokenizer
|
25
|
+
|
26
|
+
from sglang.srt.constrained.base_grammar_backend import (
|
27
|
+
BaseGrammarBackend,
|
28
|
+
BaseGrammarObject,
|
29
|
+
)
|
30
|
+
from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap
|
31
|
+
|
32
|
+
logger = logging.getLogger(__name__)
|
33
|
+
|
34
|
+
|
35
|
+
try:
|
36
|
+
from outlines.fsm.json_schema import build_regex_from_object
|
37
|
+
except ImportError:
|
38
|
+
# Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema,
|
39
|
+
# which only accepts string schema as input.
|
40
|
+
from outlines.fsm.json_schema import build_regex_from_schema
|
41
|
+
from pydantic import BaseModel
|
42
|
+
|
43
|
+
def build_regex_from_object(
|
44
|
+
object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
|
45
|
+
):
|
46
|
+
if isinstance(object, type(BaseModel)):
|
47
|
+
schema = json.dumps(object.model_json_schema())
|
48
|
+
elif isinstance(object, Dict):
|
49
|
+
schema = json.dumps(object)
|
50
|
+
else:
|
51
|
+
schema = object
|
52
|
+
return build_regex_from_schema(schema, whitespace_pattern)
|
53
|
+
|
54
|
+
|
55
|
+
class OutlinesGrammar(BaseGrammarObject):
|
56
|
+
def __init__(
|
57
|
+
self,
|
58
|
+
guide: RegexGuide,
|
59
|
+
jump_forward_map: Union[OutlinesJumpForwardMap, None],
|
60
|
+
) -> None:
|
61
|
+
self.guide = guide
|
62
|
+
self.jump_forward_map = jump_forward_map
|
63
|
+
self.state = 0
|
64
|
+
|
65
|
+
def accept_token(self, token: int):
|
66
|
+
self.state = self.guide.get_next_state(self.state, token)
|
67
|
+
|
68
|
+
def try_jump_forward(self, tokenizer) -> Optional[Tuple]:
|
69
|
+
if not self.jump_forward_map:
|
70
|
+
return None
|
71
|
+
|
72
|
+
jump_forward_bytes = self.jump_forward_map.jump_forward_byte(self.state)
|
73
|
+
if jump_forward_bytes is None or len(jump_forward_bytes) <= 1:
|
74
|
+
return None
|
75
|
+
|
76
|
+
# preprocess the jump forward string
|
77
|
+
suffix_bytes = []
|
78
|
+
continuation_range = range(0x80, 0xC0)
|
79
|
+
cur_state = self.state
|
80
|
+
while (
|
81
|
+
len(jump_forward_bytes) and jump_forward_bytes[0][0] in continuation_range
|
82
|
+
):
|
83
|
+
# continuation bytes
|
84
|
+
byte_edge = jump_forward_bytes.pop(0)
|
85
|
+
suffix_bytes.append(byte_edge[0])
|
86
|
+
cur_state = byte_edge[1]
|
87
|
+
|
88
|
+
suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
|
89
|
+
suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens)
|
90
|
+
return suffix_ids, cur_state
|
91
|
+
|
92
|
+
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
|
93
|
+
_, cur_state = helper
|
94
|
+
return self.jump_forward_map.jump_forward_symbol(cur_state)
|
95
|
+
|
96
|
+
def jump_and_retokenize(
|
97
|
+
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
|
98
|
+
):
|
99
|
+
self.state = next_state
|
100
|
+
|
101
|
+
def fill_vocab_mask(self, vocab_mask: torch.Tensor):
|
102
|
+
vocab_mask.fill_(1)
|
103
|
+
vocab_mask[self.guide.get_next_instruction(self.state).tokens] = 0
|
104
|
+
|
105
|
+
def copy(self):
|
106
|
+
return OutlinesGrammar(self.guide, self.jump_forward_map)
|
107
|
+
|
108
|
+
|
109
|
+
class OutlinesGrammarBackend(BaseGrammarBackend):
|
110
|
+
def __init__(
|
111
|
+
self,
|
112
|
+
tokenizer,
|
113
|
+
whitespace_pattern: bool,
|
114
|
+
allow_jump_forward: bool,
|
115
|
+
):
|
116
|
+
super().__init__()
|
117
|
+
|
118
|
+
try:
|
119
|
+
self.outlines_tokenizer = TransformerTokenizer(tokenizer)
|
120
|
+
except AttributeError:
|
121
|
+
# FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
|
122
|
+
origin_pad_token_id = tokenizer.pad_token_id
|
123
|
+
|
124
|
+
def fset(self, value):
|
125
|
+
self._value = value
|
126
|
+
|
127
|
+
type(tokenizer).pad_token_id = property(
|
128
|
+
fget=type(tokenizer).pad_token_id.fget, fset=fset
|
129
|
+
)
|
130
|
+
self.outlines_tokenizer = TransformerTokenizer(tokenizer)
|
131
|
+
self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
|
132
|
+
self.outlines_tokenizer.pad_token_id = origin_pad_token_id
|
133
|
+
self.outlines_tokenizer.pad_token = (
|
134
|
+
self.outlines_tokenizer.tokenizer.pad_token
|
135
|
+
)
|
136
|
+
self.outlines_tokenizer.vocabulary = (
|
137
|
+
self.outlines_tokenizer.tokenizer.get_vocab()
|
138
|
+
)
|
139
|
+
self.allow_jump_forward = allow_jump_forward
|
140
|
+
self.whitespace_pattern = whitespace_pattern
|
141
|
+
|
142
|
+
def init_value_impl(self, key: Tuple[str, str]) -> OutlinesGrammar:
|
143
|
+
key_type, key_string = key
|
144
|
+
if key_type == "json":
|
145
|
+
try:
|
146
|
+
regex = build_regex_from_object(
|
147
|
+
key_string,
|
148
|
+
whitespace_pattern=self.whitespace_pattern,
|
149
|
+
)
|
150
|
+
except NotImplementedError as e:
|
151
|
+
logger.warning(
|
152
|
+
f"skip invalid json schema: json_schema={key_string}, {e=}"
|
153
|
+
)
|
154
|
+
return None, key_string
|
155
|
+
elif key_type == "regex":
|
156
|
+
regex = key_string
|
157
|
+
else:
|
158
|
+
raise ValueError(f"Invalid key_type: {key_type}")
|
159
|
+
|
160
|
+
guide = RegexGuide(regex, self.outlines_tokenizer)
|
161
|
+
if self.allow_jump_forward:
|
162
|
+
jump_forward_map = OutlinesJumpForwardMap(regex)
|
163
|
+
else:
|
164
|
+
jump_forward_map = None
|
165
|
+
return OutlinesGrammar(guide, jump_forward_map)
|
@@ -0,0 +1,182 @@
|
|
1
|
+
"""
|
2
|
+
Copyright 2023-2024 SGLang Team
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
"""
|
15
|
+
|
16
|
+
"""
|
17
|
+
Faster constrained decoding with jump forward decoding / compressed finite state machine.
|
18
|
+
Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
|
19
|
+
"""
|
20
|
+
|
21
|
+
import dataclasses
|
22
|
+
import logging
|
23
|
+
from collections import defaultdict
|
24
|
+
|
25
|
+
import interegular
|
26
|
+
from interegular import InvalidSyntax
|
27
|
+
from outlines.caching import cache as disk_cache
|
28
|
+
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
|
29
|
+
|
30
|
+
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
31
|
+
|
32
|
+
logger = logging.getLogger(__name__)
|
33
|
+
|
34
|
+
|
35
|
+
@dataclasses.dataclass
|
36
|
+
class JumpEdge:
|
37
|
+
symbol: str = None
|
38
|
+
symbol_next_state: int = None
|
39
|
+
byte: int = None
|
40
|
+
byte_next_state: int = None
|
41
|
+
|
42
|
+
|
43
|
+
@disk_cache()
|
44
|
+
def init_state_to_jump_forward(regex_string):
|
45
|
+
try:
|
46
|
+
regex_pattern = interegular.parse_pattern(regex_string)
|
47
|
+
except InvalidSyntax as e:
|
48
|
+
logger.warning(f"skip invalid regex: {regex_string}, {e=}")
|
49
|
+
return
|
50
|
+
|
51
|
+
byte_fsm = make_byte_level_fsm(regex_pattern.to_fsm().reduce(), keep_utf8=True)
|
52
|
+
regex_fsm, _ = make_deterministic_fsm(byte_fsm)
|
53
|
+
|
54
|
+
fsm_info: FSMInfo = regex_fsm.fsm_info
|
55
|
+
|
56
|
+
symbol_to_id = fsm_info.alphabet_symbol_mapping
|
57
|
+
id_to_symbol = {}
|
58
|
+
for symbol, id_ in symbol_to_id.items():
|
59
|
+
id_to_symbol.setdefault(id_, []).append(symbol)
|
60
|
+
|
61
|
+
transitions = fsm_info.transitions
|
62
|
+
|
63
|
+
outgoings_ct = defaultdict(int)
|
64
|
+
# NOTE(lsyin): Final states can lead to terminate, so they have one outgoing edge naturally
|
65
|
+
for s in fsm_info.finals:
|
66
|
+
outgoings_ct[s] = 1
|
67
|
+
|
68
|
+
state_to_jump_forward = {}
|
69
|
+
for (state, id_), next_state in transitions.items():
|
70
|
+
if id_ == fsm_info.alphabet_anything_value:
|
71
|
+
# Arbitrarily symbol cannot be recognized as jump forward
|
72
|
+
continue
|
73
|
+
|
74
|
+
symbols = id_to_symbol[id_]
|
75
|
+
for c in symbols:
|
76
|
+
if len(c) > 1:
|
77
|
+
# Skip byte level transitions like c = "5E"
|
78
|
+
continue
|
79
|
+
|
80
|
+
outgoings_ct[state] += 1
|
81
|
+
if outgoings_ct[state] > 1:
|
82
|
+
if state in state_to_jump_forward:
|
83
|
+
del state_to_jump_forward[state]
|
84
|
+
break
|
85
|
+
|
86
|
+
state_to_jump_forward[state] = JumpEdge(
|
87
|
+
symbol=c,
|
88
|
+
symbol_next_state=next_state,
|
89
|
+
)
|
90
|
+
|
91
|
+
# Process the byte level jump forward
|
92
|
+
outgoings_ct = defaultdict(int)
|
93
|
+
for s in fsm_info.finals:
|
94
|
+
outgoings_ct[s] = 1
|
95
|
+
|
96
|
+
for (state, id_), next_state in transitions.items():
|
97
|
+
if id_ == fsm_info.alphabet_anything_value:
|
98
|
+
continue
|
99
|
+
symbols = id_to_symbol[id_]
|
100
|
+
for c in symbols:
|
101
|
+
byte_ = None
|
102
|
+
if len(c) == 1 and ord(c) < 0x80:
|
103
|
+
# ASCII character
|
104
|
+
byte_ = ord(c)
|
105
|
+
elif len(c) > 1:
|
106
|
+
# FIXME: This logic is due to the leading \x00
|
107
|
+
# https://github.com/outlines-dev/outlines/pull/930
|
108
|
+
byte_ = int(symbols[0][1:], 16)
|
109
|
+
|
110
|
+
if byte_ is not None:
|
111
|
+
outgoings_ct[state] += 1
|
112
|
+
if outgoings_ct[state] > 1:
|
113
|
+
if state in state_to_jump_forward:
|
114
|
+
del state_to_jump_forward[state]
|
115
|
+
break
|
116
|
+
e = state_to_jump_forward.get(state, JumpEdge())
|
117
|
+
e.byte = byte_
|
118
|
+
e.byte_next_state = next_state
|
119
|
+
state_to_jump_forward[state] = e
|
120
|
+
|
121
|
+
return state_to_jump_forward
|
122
|
+
|
123
|
+
|
124
|
+
class OutlinesJumpForwardMap:
|
125
|
+
def __init__(self, regex_string):
|
126
|
+
self.state_to_jump_forward = init_state_to_jump_forward(regex_string)
|
127
|
+
|
128
|
+
def jump_forward_symbol(self, state):
|
129
|
+
jump_forward_str = ""
|
130
|
+
next_state = state
|
131
|
+
while state in self.state_to_jump_forward:
|
132
|
+
e = self.state_to_jump_forward[state]
|
133
|
+
if e.symbol is None:
|
134
|
+
break
|
135
|
+
jump_forward_str += e.symbol
|
136
|
+
next_state = e.symbol_next_state
|
137
|
+
state = next_state
|
138
|
+
|
139
|
+
return jump_forward_str, next_state
|
140
|
+
|
141
|
+
def jump_forward_byte(self, state):
|
142
|
+
if state not in self.state_to_jump_forward:
|
143
|
+
return None
|
144
|
+
|
145
|
+
jump_forward_bytes = []
|
146
|
+
next_state = None
|
147
|
+
while state in self.state_to_jump_forward:
|
148
|
+
e = self.state_to_jump_forward[state]
|
149
|
+
assert e.byte is not None and e.byte_next_state is not None
|
150
|
+
jump_forward_bytes.append((e.byte, e.byte_next_state))
|
151
|
+
next_state = e.byte_next_state
|
152
|
+
state = next_state
|
153
|
+
|
154
|
+
return jump_forward_bytes
|
155
|
+
|
156
|
+
def is_jump_forward_symbol_state(self, state):
|
157
|
+
return (
|
158
|
+
state in self.state_to_jump_forward
|
159
|
+
and self.state_to_jump_forward[state].symbol is not None
|
160
|
+
)
|
161
|
+
|
162
|
+
|
163
|
+
def test_main(regex_string):
|
164
|
+
jump_forward_map = OutlinesJumpForwardMap(regex_string)
|
165
|
+
for state, e in jump_forward_map.state_to_jump_forward.items():
|
166
|
+
if e.symbol is not None:
|
167
|
+
jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
|
168
|
+
print(f"{state} -> {next_state}", jump_forward_str)
|
169
|
+
bytes_ = jump_forward_map.jump_forward_byte(state)
|
170
|
+
print(f"{state} -> {bytes_[-1][1]}", [hex(b) for b, _ in bytes_])
|
171
|
+
|
172
|
+
|
173
|
+
if __name__ == "__main__":
|
174
|
+
import outlines
|
175
|
+
|
176
|
+
outlines.caching.clear_cache()
|
177
|
+
test_main(r"The google's DNS sever address is " + IP_REGEX)
|
178
|
+
test_main(r"霍格沃茨特快列车|霍比特人比尔博")
|
179
|
+
# 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
|
180
|
+
# 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...
|
181
|
+
|
182
|
+
test_main(r"[-+]?[0-9]+[ ]*")
|