sglang 0.3.5__py3-none-any.whl → 0.3.5.post2__py3-none-any.whl
This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +309 -0
- sglang/bench_serving.py +148 -24
- sglang/srt/configs/model_config.py +5 -2
- sglang/srt/constrained/__init__.py +2 -66
- sglang/srt/constrained/base_grammar_backend.py +73 -0
- sglang/srt/constrained/outlines_backend.py +165 -0
- sglang/srt/constrained/outlines_jump_forward.py +182 -0
- sglang/srt/constrained/xgrammar_backend.py +150 -0
- sglang/srt/layers/attention/triton_ops/decode_attention.py +7 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +6 -0
- sglang/srt/layers/fused_moe/fused_moe.py +23 -7
- sglang/srt/layers/fused_moe/patch.py +4 -2
- sglang/srt/layers/quantization/base_config.py +4 -6
- sglang/srt/layers/vocab_parallel_embedding.py +216 -150
- sglang/srt/managers/detokenizer_manager.py +0 -14
- sglang/srt/managers/io_struct.py +5 -3
- sglang/srt/managers/schedule_batch.py +14 -20
- sglang/srt/managers/scheduler.py +159 -96
- sglang/srt/managers/tokenizer_manager.py +81 -17
- sglang/srt/metrics/collector.py +211 -0
- sglang/srt/metrics/func_timer.py +108 -0
- sglang/srt/mm_utils.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +2 -2
- sglang/srt/model_executor/forward_batch_info.py +7 -3
- sglang/srt/model_executor/model_runner.py +6 -2
- sglang/srt/models/gemma2_reward.py +69 -0
- sglang/srt/models/gpt2.py +31 -37
- sglang/srt/models/internlm2_reward.py +62 -0
- sglang/srt/models/llama.py +11 -6
- sglang/srt/models/llama_reward.py +5 -26
- sglang/srt/models/qwen2_vl.py +5 -7
- sglang/srt/openai_api/adapter.py +11 -4
- sglang/srt/openai_api/protocol.py +29 -26
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/sampling/sampling_params.py +2 -16
- sglang/srt/server.py +60 -17
- sglang/srt/server_args.py +66 -25
- sglang/srt/utils.py +120 -0
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_mgsm.py +2 -2
- sglang/test/test_utils.py +21 -7
- sglang/utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/METADATA +12 -8
- {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/RECORD +49 -45
- {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/WHEEL +1 -1
- sglang/srt/constrained/base_tool_cache.py +0 -65
- sglang/srt/constrained/bnf_cache.py +0 -61
- sglang/srt/constrained/fsm_cache.py +0 -95
- sglang/srt/constrained/grammar.py +0 -190
- sglang/srt/constrained/jump_forward.py +0 -203
- {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/LICENSE +0 -0
- {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/constrained/grammar.py

@@ -1,190 +0,0 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
-"""Cache for the compressed finite state machine."""
-import logging
-from typing import List, Optional, Tuple, Union
-
-import torch
-
-from sglang.srt.constrained import GrammarMatcher, RegexGuide
-from sglang.srt.constrained.bnf_cache import BNFCache
-from sglang.srt.constrained.fsm_cache import FSMCache
-from sglang.srt.constrained.jump_forward import JumpForwardCache, JumpForwardMap
-
-# from sglang.srt.managers.schedule_batch import Req
-
-logger = logging.getLogger(__name__)
-
-INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
-
-
-class XGrammarJump:
-    pass
-
-
-class JumpHelper:
-    data: Union[List, str]
-    state: int
-    suffix_ids: List[int]
-
-    def __init__(
-        self, data: Union[List, str] = "", state: int = -1, suffix_ids=[]
-    ) -> None:
-        self.data = data
-        self.state = state
-        self.suffix_ids = suffix_ids
-
-    def can_jump(self):
-        return len(self.data) > 0
-
-
-class Grammar:
-    grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]]
-    jump_map: Union[XGrammarJump, JumpForwardMap, None]
-
-    def __init__(
-        self,
-        grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]],
-        jump_map: Union[XGrammarJump, JumpForwardMap, None],
-    ) -> None:
-        self.grammar = grammar
-        self.jump_map = jump_map
-
-    def accept_token(self, token: int):
-        if isinstance(self.grammar, GrammarMatcher):
-            assert self.grammar.accept_token(token)
-        else:
-            guide, state = self.grammar
-            self.grammar = guide, guide.get_next_state(state, token)
-
-    def try_jump(self, tokenizer) -> JumpHelper:
-        if isinstance(self.jump_map, XGrammarJump):
-            assert isinstance(self.grammar, GrammarMatcher)
-            return JumpHelper(self.grammar.find_jump_forward_string())
-        elif isinstance(self.jump_map, JumpForwardMap):
-            assert isinstance(self.grammar, Tuple)
-
-            _, state = self.grammar
-            jump_forward_bytes = self.jump_map.jump_forward_byte(state)
-            if jump_forward_bytes is None or len(jump_forward_bytes) == 0:
-                return JumpHelper()  # can't jump
-
-            # preprocess the jump forward string
-            suffix_bytes = []
-            continuation_range = range(0x80, 0xC0)
-            cur_state = state
-            while (
-                len(jump_forward_bytes)
-                and jump_forward_bytes[0][0] in continuation_range
-            ):
-                # continuation bytes
-                byte_edge = jump_forward_bytes.pop(0)
-                suffix_bytes.append(byte_edge[0])
-                cur_state = byte_edge[1]
-
-            suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
-            suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens)
-            return JumpHelper(suffix_ids, cur_state, suffix_bytes)
-        else:
-            return JumpHelper()  # can't jump
-
-    def jump_forward_str_state(self, helper: JumpHelper) -> Tuple[str, int]:
-        if isinstance(helper.data, str):
-            return helper.data, -1
-        else:
-            assert isinstance(self.jump_map, JumpForwardMap)
-            return self.jump_map.jump_forward_symbol(helper.state)
-
-    def jump_and_retokenize(
-        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
-    ):
-        if isinstance(self.grammar, GrammarMatcher):
-            k = 0
-            for i, old_id in enumerate(old_output_ids):
-                if old_id == new_output_ids[i]:
-                    k = i + 1
-                else:
-                    break
-
-            # rollback to the last token that is the same
-            if k < len(old_output_ids):
-                self.grammar.rollback(len(old_output_ids) - k)
-
-            for i in range(k, len(new_output_ids)):
-                assert self.grammar.accept_token(new_output_ids[i])
-        else:
-            self.grammar = self.grammar[0], next_state
-
-    def fill_vocab_mask(self, vocab_mask: torch.Tensor, vocab_size: int):
-        if isinstance(self.grammar, GrammarMatcher):
-            # Note that this bitmask is a bitset, not bool
-            bitmask = self.grammar.find_next_token_bitmask()
-            # Mask the tokens that are not allowed
-            vocab_mask[
-                self.grammar.get_rejected_tokens_from_bitmask(bitmask, vocab_size)
-            ] = 1
-        else:
-            guide, state = self.grammar
-            vocab_mask.fill_(1)
-            vocab_mask[guide.get_next_instruction(state).tokens] = 0
-
-
-class GrammarCache:
-    grammar_cache: Union[BNFCache, FSMCache]
-    jump_cache: Union[XGrammarJump, JumpForwardCache, None]
-
-    def __init__(
-        self,
-        tokenizer_path,
-        tokenizer_args_dict,
-        skip_tokenizer_init=False,
-        whitespace_patterns=None,
-        backend=None,
-        allow_jump=False,
-    ):
-        if backend == "xgrammar":
-            self.grammar_cache = BNFCache(
-                tokenizer_path=tokenizer_path,
-                tokenizer_args_dict=tokenizer_args_dict,
-                skip_tokenizer_init=skip_tokenizer_init,
-                whitespace_patterns=whitespace_patterns,
-            )
-            self.jump_cache = XGrammarJump() if allow_jump else None
-        else:
-            assert backend == "outlines"
-            self.grammar_cache = FSMCache(
-                tokenizer_path=tokenizer_path,
-                tokenizer_args_dict=tokenizer_args_dict,
-                skip_tokenizer_init=skip_tokenizer_init,
-                constrained_json_whitespace_pattern=whitespace_patterns,
-                enable=True,
-            )
-            self.jump_cache = JumpForwardCache() if allow_jump else None
-
-    def query(self, key: Tuple[str, str], vocab_size: int) -> Grammar:
-        if isinstance(self.grammar_cache, BNFCache):
-            assert not isinstance(self.jump_cache, JumpForwardCache)
-            return Grammar(self.grammar_cache.query(key, vocab_size), self.jump_cache)
-        else:
-            jump_map = None
-            guide, regex = self.grammar_cache.query(key)
-            if isinstance(self.jump_cache, JumpForwardCache):
-                jump_map = self.jump_cache.query(regex)
-            return Grammar((guide, 0), jump_map)
-
-    def reset(self):
-        if isinstance(self.grammar_cache, FSMCache):
-            self.grammar_cache.reset()
-        if isinstance(self.jump_cache, JumpForwardCache):
-            self.jump_cache.reset()
sglang/srt/constrained/jump_forward.py

@@ -1,203 +0,0 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
-"""
-Faster constrained decoding.
-Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
-"""
-
-import dataclasses
-import logging
-from collections import defaultdict
-
-import interegular
-import outlines.caching
-from interegular import InvalidSyntax
-
-from sglang.srt.constrained import (
-    FSMInfo,
-    disk_cache,
-    make_byte_level_fsm,
-    make_deterministic_fsm,
-)
-from sglang.srt.constrained.base_tool_cache import BaseToolCache
-
-IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
-
-logger = logging.getLogger(__name__)
-
-
-@dataclasses.dataclass
-class JumpEdge:
-    symbol: str = None
-    symbol_next_state: int = None
-    byte: int = None
-    byte_next_state: int = None
-
-
-class JumpForwardMap:
-    def __init__(self, regex_string):
-        @disk_cache()
-        def _init_state_to_jump_forward(regex_string):
-            try:
-                regex_pattern = interegular.parse_pattern(regex_string)
-            except InvalidSyntax as e:
-                logger.warning(f"skip invalid regex: {regex_string}, {e=}")
-                self.state_to_jump_forward = None
-                return
-
-            byte_fsm = make_byte_level_fsm(
-                regex_pattern.to_fsm().reduce(), keep_utf8=True
-            )
-            regex_fsm, _ = make_deterministic_fsm(byte_fsm)
-
-            fsm_info: FSMInfo = regex_fsm.fsm_info
-
-            symbol_to_id = fsm_info.alphabet_symbol_mapping
-            id_to_symbol = {}
-            for symbol, id_ in symbol_to_id.items():
-                id_to_symbol.setdefault(id_, []).append(symbol)
-
-            transitions = fsm_info.transitions
-
-            outgoings_ct = defaultdict(int)
-            # NOTE(lsyin): Final states can lead to terminate, so they have one outgoing edge naturally
-            for s in fsm_info.finals:
-                outgoings_ct[s] = 1
-
-            state_to_jump_forward = {}
-            for (state, id_), next_state in transitions.items():
-                if id_ == fsm_info.alphabet_anything_value:
-                    # Arbitrarily symbol cannot be recognized as jump forward
-                    continue
-
-                symbols = id_to_symbol[id_]
-                for c in symbols:
-                    if len(c) > 1:
-                        # Skip byte level transitions like c = "5E"
-                        continue
-
-                    outgoings_ct[state] += 1
-                    if outgoings_ct[state] > 1:
-                        if state in state_to_jump_forward:
-                            del state_to_jump_forward[state]
-                        break
-
-                    state_to_jump_forward[state] = JumpEdge(
-                        symbol=c,
-                        symbol_next_state=next_state,
-                    )
-
-            # Process the byte level jump forward
-            outgoings_ct = defaultdict(int)
-            for s in fsm_info.finals:
-                outgoings_ct[s] = 1
-
-            for (state, id_), next_state in transitions.items():
-                if id_ == fsm_info.alphabet_anything_value:
-                    continue
-                symbols = id_to_symbol[id_]
-                for c in symbols:
-                    byte_ = None
-                    if len(c) == 1 and ord(c) < 0x80:
-                        # ASCII character
-                        byte_ = ord(c)
-                    elif len(c) > 1:
-                        # FIXME: This logic is due to the leading \x00
-                        # https://github.com/outlines-dev/outlines/pull/930
-                        byte_ = int(symbols[0][1:], 16)
-
-                    if byte_ is not None:
-                        outgoings_ct[state] += 1
-                        if outgoings_ct[state] > 1:
-                            if state in state_to_jump_forward:
-                                del state_to_jump_forward[state]
-                            break
-                        e = state_to_jump_forward.get(state, JumpEdge())
-                        e.byte = byte_
-                        e.byte_next_state = next_state
-                        state_to_jump_forward[state] = e
-
-            return state_to_jump_forward
-
-        self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)
-
-    def jump_forward_symbol(self, state):
-        jump_forward_str = ""
-        next_state = state
-        while state in self.state_to_jump_forward:
-            e = self.state_to_jump_forward[state]
-            if e.symbol is None:
-                break
-            jump_forward_str += e.symbol
-            next_state = e.symbol_next_state
-            state = next_state
-
-        return jump_forward_str, next_state
-
-    def jump_forward_byte(self, state):
-        if state not in self.state_to_jump_forward:
-            return None
-
-        jump_forward_bytes = []
-        next_state = None
-        while state in self.state_to_jump_forward:
-            e = self.state_to_jump_forward[state]
-            assert e.byte is not None and e.byte_next_state is not None
-            jump_forward_bytes.append((e.byte, e.byte_next_state))
-            next_state = e.byte_next_state
-            state = next_state
-
-        return jump_forward_bytes
-
-    def is_jump_forward_symbol_state(self, state):
-        return (
-            state in self.state_to_jump_forward
-            and self.state_to_jump_forward[state].symbol is not None
-        )
-
-
-class JumpForwardCache(BaseToolCache):
-    def __init__(self):
-        super().__init__()
-
-    def init_value(self, regex):
-        forward_map = JumpForwardMap(regex)
-        if forward_map.state_to_jump_forward:
-            return forward_map
-        else:
-            return None
-
-
-def test_main(regex_string):
-    jump_forward_map = JumpForwardMap(regex_string)
-    for state, e in jump_forward_map.state_to_jump_forward.items():
-        if e.symbol is not None:
-            jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
-            print(f"{state} -> {next_state}", jump_forward_str)
-        bytes_ = jump_forward_map.jump_forward_byte(state)
-        print(f"{state} -> {bytes_[-1][1]}", [hex(b) for b, _ in bytes_])
-
-
-if __name__ == "__main__":
-    import outlines
-
-    outlines.caching.clear_cache()
-    test_main(r"The google's DNS sever address is " + IP_REGEX)
-    test_main(r"霍格沃茨特快列车|霍比特人比尔博")
-    # 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
-    # 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...
-
-    test_main(r"[-+]?[0-9]+[ ]*")
{sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/LICENSE: File without changes

{sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/top_level.txt: File without changes