sglang 0.3.5__py3-none-any.whl → 0.3.5.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +309 -0
- sglang/bench_serving.py +148 -24
- sglang/srt/configs/model_config.py +5 -2
- sglang/srt/constrained/__init__.py +2 -66
- sglang/srt/constrained/base_grammar_backend.py +73 -0
- sglang/srt/constrained/outlines_backend.py +165 -0
- sglang/srt/constrained/outlines_jump_forward.py +182 -0
- sglang/srt/constrained/xgrammar_backend.py +150 -0
- sglang/srt/layers/attention/triton_ops/decode_attention.py +7 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +6 -0
- sglang/srt/layers/fused_moe/fused_moe.py +23 -7
- sglang/srt/layers/fused_moe/patch.py +4 -2
- sglang/srt/layers/quantization/base_config.py +4 -6
- sglang/srt/layers/vocab_parallel_embedding.py +216 -150
- sglang/srt/managers/detokenizer_manager.py +0 -14
- sglang/srt/managers/io_struct.py +5 -3
- sglang/srt/managers/schedule_batch.py +14 -20
- sglang/srt/managers/scheduler.py +159 -96
- sglang/srt/managers/tokenizer_manager.py +81 -17
- sglang/srt/metrics/collector.py +211 -0
- sglang/srt/metrics/func_timer.py +108 -0
- sglang/srt/mm_utils.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +2 -2
- sglang/srt/model_executor/forward_batch_info.py +7 -3
- sglang/srt/model_executor/model_runner.py +6 -2
- sglang/srt/models/gemma2_reward.py +69 -0
- sglang/srt/models/gpt2.py +31 -37
- sglang/srt/models/internlm2_reward.py +62 -0
- sglang/srt/models/llama.py +11 -6
- sglang/srt/models/llama_reward.py +5 -26
- sglang/srt/models/qwen2_vl.py +5 -7
- sglang/srt/openai_api/adapter.py +11 -4
- sglang/srt/openai_api/protocol.py +29 -26
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/sampling/sampling_params.py +2 -16
- sglang/srt/server.py +60 -17
- sglang/srt/server_args.py +66 -25
- sglang/srt/utils.py +120 -0
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_mgsm.py +2 -2
- sglang/test/test_utils.py +21 -7
- sglang/utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/METADATA +12 -8
- {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/RECORD +49 -45
- {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/WHEEL +1 -1
- sglang/srt/constrained/base_tool_cache.py +0 -65
- sglang/srt/constrained/bnf_cache.py +0 -61
- sglang/srt/constrained/fsm_cache.py +0 -95
- sglang/srt/constrained/grammar.py +0 -190
- sglang/srt/constrained/jump_forward.py +0 -203
- {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/LICENSE +0 -0
- {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,73 @@
|
|
1
|
+
"""
|
2
|
+
Copyright 2023-2024 SGLang Team
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
"""
|
15
|
+
|
16
|
+
"""The baseclass of a backend for grammar-guided constrained decoding."""
|
17
|
+
|
18
|
+
from concurrent.futures import Future, ThreadPoolExecutor
|
19
|
+
from dataclasses import dataclass
|
20
|
+
from threading import Event, Lock
|
21
|
+
from typing import Any, Optional, Tuple
|
22
|
+
|
23
|
+
|
24
|
+
@dataclass
class CacheEntry:
    """One slot of the grammar cache kept by ``BaseGrammarBackend``.

    The thread that creates the slot fills in ``value`` and then sets
    ``event``; other threads wait on ``event`` before reading ``value``.
    ``value`` stays ``None`` when the backend reported the key as invalid.
    """

    # Compiled grammar object (or None).
    value: Any
    # Signalled once ``value`` has been written.
    event: Event
|
28
|
+
|
29
|
+
|
30
|
+
class BaseGrammarObject:
    """Common base class for the per-request grammar objects of each backend.

    It defines no behaviour of its own; concrete backends (e.g. the outlines
    and xgrammar grammars) subclass it.
    """
|
32
|
+
|
33
|
+
|
34
|
+
class BaseGrammarBackend:
    """Base class for grammar backends with a thread-safe, de-duplicating cache.

    Subclasses implement :meth:`init_value_impl` to compile a grammar for a
    ``(key_type, key_string)`` key. The first caller for a key compiles it;
    concurrent callers for the same key wait for that compilation to finish.
    Every caller receives its own ``copy()`` of the cached object.
    """

    def __init__(self):
        self.executor = ThreadPoolExecutor()
        # Maps (key_type, key_string) -> CacheEntry.
        self.cache = {}
        self.cache_lock = Lock()

    @staticmethod
    def _copy_or_none(value):
        """Return ``value.copy()``, or ``None`` when ``value`` is ``None``."""
        return None if value is None else value.copy()

    def init_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
        """Return a fresh copy of the grammar for ``key``, compiling it if needed.

        Args:
            key: ``(key_type, key_string)`` passed to :meth:`init_value_impl`.

        Returns:
            A copy of the compiled grammar, or ``None`` if the backend reported
            the key as invalid.

        Raises:
            Whatever :meth:`init_value_impl` raises. In that case the cache slot
            is removed so a later call can retry, and threads already waiting
            on the same key are released (they receive ``None``).
        """
        with self.cache_lock:
            entry = self.cache.get(key)
            cache_hit = entry is not None
            if not cache_hit:
                entry = CacheEntry(None, Event())
                self.cache[key] = entry

        if cache_hit:
            entry.event.wait()
        else:
            try:
                entry.value = self.init_value_impl(key)
            except BaseException:
                # Leaving the failed slot behind would make every later caller
                # for this key return None forever; drop it so it can be retried.
                with self.cache_lock:
                    if self.cache.get(key) is entry:
                        del self.cache[key]
                raise
            finally:
                # Always wake the waiters, even on failure; otherwise they block
                # forever in entry.event.wait().
                entry.event.set()
        return self._copy_or_none(entry.value)

    def init_value_impl(self, key: Tuple[str, str]) -> BaseGrammarObject:
        """Compile the grammar for ``key``; implemented by subclasses."""
        raise NotImplementedError()

    def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
        """Return a copy of the cached grammar for ``key`` without compiling.

        Returns ``None`` if ``key`` is not cached, is still being compiled, or
        was reported as invalid.
        """
        with self.cache_lock:
            entry = self.cache.get(key)
            if entry is None or not entry.event.is_set():
                return None
            return self._copy_or_none(entry.value)

    def get_future_value(self, key: Tuple[str, str]) -> Future:
        """Schedule :meth:`init_value` for ``key`` on the thread pool."""
        return self.executor.submit(self.init_value, key)

    def reset(self):
        """Drop every cached entry."""
        with self.cache_lock:
            self.cache.clear()
|
@@ -0,0 +1,165 @@
|
|
1
|
+
"""
|
2
|
+
Copyright 2023-2024 SGLang Team
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
"""
|
15
|
+
|
16
|
+
"""Constrained decoding with outlines backend."""
|
17
|
+
|
18
|
+
import json
|
19
|
+
import logging
|
20
|
+
from typing import Dict, List, Optional, Tuple, Union
|
21
|
+
|
22
|
+
import interegular
|
23
|
+
import torch
|
24
|
+
from outlines.fsm.guide import RegexGuide
|
25
|
+
from outlines.fsm.json_schema import build_regex_from_schema
|
26
|
+
from outlines.models.transformers import TransformerTokenizer
|
27
|
+
from pydantic import BaseModel
|
28
|
+
|
29
|
+
from sglang.srt.constrained.base_grammar_backend import (
|
30
|
+
BaseGrammarBackend,
|
31
|
+
BaseGrammarObject,
|
32
|
+
)
|
33
|
+
from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap
|
34
|
+
|
35
|
+
# Module-level logger; the outlines backend uses it to warn about invalid
# JSON schemas and regexes.
logger = logging.getLogger(__name__)
|
36
|
+
|
37
|
+
|
38
|
+
class OutlinesGrammar(BaseGrammarObject):
    """Per-request constrained-decoding state for the outlines backend.

    Wraps a compiled ``RegexGuide`` and tracks the guide's current FSM state,
    which starts at 0. ``jump_forward_map`` is ``None`` when jump-forward
    decoding is disabled.
    """

    def __init__(
        self,
        guide: RegexGuide,
        jump_forward_map: Union[OutlinesJumpForwardMap, None],
    ) -> None:
        self.guide = guide
        self.jump_forward_map = jump_forward_map
        self.state = 0

    def accept_token(self, token: int):
        """Move to the FSM state reached from the current one via ``token``."""
        self.state = self.guide.get_next_state(self.state, token)

    def try_jump_forward(self, tokenizer) -> Optional[Tuple]:
        """Look for a forced continuation from the current state.

        Returns ``None`` when jump forward is disabled, or when the forced byte
        path is missing or at most one byte long. Otherwise it returns
        ``(suffix_ids, cur_state)``. ``suffix_ids`` holds the token ids of the
        path's leading UTF-8 continuation bytes, looked up as ``<0xNN>`` byte
        tokens. ``cur_state`` is the FSM state after consuming them.
        """
        jump_map = self.jump_forward_map
        if not jump_map:
            return None

        byte_path = jump_map.jump_forward_byte(self.state)
        if byte_path is None or len(byte_path) <= 1:
            return None

        # Walk the leading UTF-8 continuation bytes (0x80-0xBF) of the path.
        cur_state = self.state
        suffix_bytes = []
        for byte_value, after_state in byte_path:
            if not 0x80 <= byte_value < 0xC0:
                break
            suffix_bytes.append(byte_value)
            cur_state = after_state

        suffix_ids = tokenizer.convert_tokens_to_ids(
            [f"<0x{b:X}>" for b in suffix_bytes]
        )
        return suffix_ids, cur_state

    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        """Return ``(jump_forward_str, next_state)`` by following symbol edges.

        ``helper`` is the pair returned by :meth:`try_jump_forward`; only its
        second element (an FSM state) is used.
        """
        _, start_state = helper
        return self.jump_forward_map.jump_forward_symbol(start_state)

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        """Set the FSM state to ``next_state``; the id lists are not needed here."""
        self.state = next_state

    def fill_vocab_mask(self, vocab_mask: torch.Tensor):
        """Fill ``vocab_mask`` in place with 1 = disallowed and 0 = allowed."""
        vocab_mask.fill_(1)
        allowed_tokens = self.guide.get_next_instruction(self.state).tokens
        vocab_mask[allowed_tokens] = 0

    def copy(self):
        """Return a new grammar sharing this guide and jump map, at state 0."""
        return OutlinesGrammar(self.guide, self.jump_forward_map)
|
90
|
+
|
91
|
+
|
92
|
+
class OutlinesGrammarBackend(BaseGrammarBackend):
    """Grammar backend that compiles JSON schemas and regexes with outlines.

    Each successfully compiled key yields an :class:`OutlinesGrammar` wrapping
    a ``RegexGuide``. When ``allow_jump_forward`` is true, it also carries an
    ``OutlinesJumpForwardMap``.
    """

    def __init__(
        self,
        tokenizer,
        whitespace_pattern: Optional[str],
        allow_jump_forward: bool,
    ):
        """
        Args:
            tokenizer: Tokenizer to wrap with outlines' ``TransformerTokenizer``.
            whitespace_pattern: Forwarded as ``whitespace_pattern`` to
                ``build_regex_from_object`` (declared ``Optional[str]`` there)
                when compiling JSON schemas.
            allow_jump_forward: Whether to build an ``OutlinesJumpForwardMap``
                for every compiled regex.
        """
        super().__init__()

        try:
            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
        except AttributeError:
            # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
            # Wrapping failed with AttributeError. Presumably TransformerTokenizer
            # assigns ``pad_token_id`` on a tokenizer whose property has no
            # setter. Install a property that keeps the original getter and adds
            # a setter (which stores into ``_value`` on the tokenizer instance),
            # retry, then patch the outlines-side attributes by hand.
            # NOTE(review): this replaces the property on the tokenizer *class*,
            # so it affects every instance of that class in the process.
            origin_pad_token_id = tokenizer.pad_token_id

            def fset(self, value):
                self._value = value

            type(tokenizer).pad_token_id = property(
                fget=type(tokenizer).pad_token_id.fget, fset=fset
            )
            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
            self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
            self.outlines_tokenizer.pad_token_id = origin_pad_token_id
            self.outlines_tokenizer.pad_token = (
                self.outlines_tokenizer.tokenizer.pad_token
            )
            self.outlines_tokenizer.vocabulary = (
                self.outlines_tokenizer.tokenizer.get_vocab()
            )
        self.allow_jump_forward = allow_jump_forward
        self.whitespace_pattern = whitespace_pattern

    def init_value_impl(self, key: Tuple[str, str]) -> Optional[OutlinesGrammar]:
        """Compile ``key`` into an :class:`OutlinesGrammar`.

        Args:
            key: ``(key_type, key_string)``. ``key_type`` is ``"json"``, in which
                case ``key_string`` is a JSON schema, or ``"regex"``.

        Returns:
            The grammar. Returns ``None`` (after logging a warning) when the
            schema cannot be converted to a regex or the regex is invalid.

        Raises:
            ValueError: for any other ``key_type``.
        """
        key_type, key_string = key
        if key_type == "json":
            try:
                regex = build_regex_from_object(
                    key_string,
                    whitespace_pattern=self.whitespace_pattern,
                )
            except (NotImplementedError, json.decoder.JSONDecodeError) as e:
                logger.warning(
                    f"Skip invalid json_schema: json_schema={key_string}, {e=}"
                )
                return None
        elif key_type == "regex":
            regex = key_string
        else:
            raise ValueError(f"Invalid key_type: {key_type}")

        try:
            guide = RegexGuide(regex, self.outlines_tokenizer)
        except interegular.patterns.InvalidSyntax as e:
            logger.warning(f"skip invalid regex schema: {regex=}, {e=}")
            return None

        # Precompute forced-continuation edges only when jump forward is enabled.
        if self.allow_jump_forward:
            jump_forward_map = OutlinesJumpForwardMap(regex)
        else:
            jump_forward_map = None
        return OutlinesGrammar(guide, jump_forward_map)
|
154
|
+
|
155
|
+
|
156
|
+
def build_regex_from_object(
    object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
):
    """Build an outlines regex from a schema-like object.

    Args:
        object: A pydantic model class, a ``dict`` JSON schema, or anything
            else, which is passed to outlines unchanged (typically a JSON
            schema string).
        whitespace_pattern: Forwarded to ``build_regex_from_schema``.

    Returns:
        The regex produced by outlines' ``build_regex_from_schema``.
    """
    # ``type(BaseModel)`` is pydantic's model metaclass, so this branch matches
    # model *classes* (not instances).
    if isinstance(object, type(BaseModel)):
        return build_regex_from_schema(
            json.dumps(object.model_json_schema()), whitespace_pattern
        )
    if isinstance(object, dict):
        return build_regex_from_schema(json.dumps(object), whitespace_pattern)
    return build_regex_from_schema(object, whitespace_pattern)
|
@@ -0,0 +1,182 @@
|
|
1
|
+
"""
|
2
|
+
Copyright 2023-2024 SGLang Team
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
"""
|
15
|
+
|
16
|
+
"""
|
17
|
+
Faster constrained decoding with jump forward decoding / compressed finite state machine.
|
18
|
+
Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
|
19
|
+
"""
|
20
|
+
|
21
|
+
import dataclasses
|
22
|
+
import logging
|
23
|
+
from collections import defaultdict
|
24
|
+
|
25
|
+
import interegular
|
26
|
+
from interegular import InvalidSyntax
|
27
|
+
from outlines.caching import cache as disk_cache
|
28
|
+
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
|
29
|
+
|
30
|
+
# Regex for a dotted-quad IPv4 address (each octet 0-255). It is used by the
# ``__main__`` self-test in this module.
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

logger = logging.getLogger(__name__)
|
33
|
+
|
34
|
+
|
35
|
+
@dataclasses.dataclass
class JumpEdge:
    """The single forced outgoing edge of an FSM state, at two granularities.

    ``symbol``/``symbol_next_state`` describe the edge as one character (the
    symbol-level pass). ``byte``/``byte_next_state`` describe it as one byte
    (the byte-level pass). Either pair may be left as ``None``.
    """

    # Character label of the edge and the state it leads to.
    symbol: str = dataclasses.field(default=None)
    symbol_next_state: int = dataclasses.field(default=None)
    # Byte value of the edge and the state it leads to.
    byte: int = dataclasses.field(default=None)
    byte_next_state: int = dataclasses.field(default=None)
|
41
|
+
|
42
|
+
|
43
|
+
@disk_cache()
def init_state_to_jump_forward(regex_string):
    """Compute the forced (jump-forward) edges of the byte-level FSM of a regex.

    A state gets an entry only when it has exactly one counted outgoing edge
    and is not a final state, i.e. the next character/byte is forced.

    Args:
        regex_string: The regex to analyse.

    Returns:
        A dict mapping FSM state -> :class:`JumpEdge`. Returns ``None`` (after
        logging a warning) if ``regex_string`` is not valid interegular syntax.
        Results are memoized by outlines' ``cache`` decorator (imported here as
        ``disk_cache``).
    """
    try:
        regex_pattern = interegular.parse_pattern(regex_string)
    except InvalidSyntax as e:
        logger.warning(f"skip invalid regex: {regex_string}, {e=}")
        return

    byte_fsm = make_byte_level_fsm(regex_pattern.to_fsm().reduce(), keep_utf8=True)
    regex_fsm, _ = make_deterministic_fsm(byte_fsm)

    fsm_info: FSMInfo = regex_fsm.fsm_info

    # Invert the alphabet mapping: transition id -> all symbols sharing that id.
    symbol_to_id = fsm_info.alphabet_symbol_mapping
    id_to_symbol = {}
    for symbol, id_ in symbol_to_id.items():
        id_to_symbol.setdefault(id_, []).append(symbol)

    # Keys are (state, transition id); values are the destination state.
    transitions = fsm_info.transitions

    outgoings_ct = defaultdict(int)
    # NOTE(lsyin): Final states can lead to terminate, so they have one outgoing edge naturally
    for s in fsm_info.finals:
        outgoings_ct[s] = 1

    # Pass 1: symbol (single character) level.
    state_to_jump_forward = {}
    for (state, id_), next_state in transitions.items():
        if id_ == fsm_info.alphabet_anything_value:
            # Arbitrarily symbol cannot be recognized as jump forward
            # NOTE(review): this edge is skipped without being counted, so a
            # state with an "anything" edge plus one explicit edge can still be
            # treated as forced; confirm that is intended.
            continue

        symbols = id_to_symbol[id_]
        for c in symbols:
            if len(c) > 1:
                # Skip byte level transitions like c = "5E"
                continue

            outgoings_ct[state] += 1
            if outgoings_ct[state] > 1:
                # More than one way out: this state's next symbol is not forced.
                if state in state_to_jump_forward:
                    del state_to_jump_forward[state]
                break

            state_to_jump_forward[state] = JumpEdge(
                symbol=c,
                symbol_next_state=next_state,
            )

    # Process the byte level jump forward
    # Pass 2 restarts the edge counts. It adds byte info to the pass-1 entries,
    # or creates byte-only entries. Deleting here also removes any pass-1
    # symbol info for the state.
    outgoings_ct = defaultdict(int)
    for s in fsm_info.finals:
        outgoings_ct[s] = 1

    for (state, id_), next_state in transitions.items():
        if id_ == fsm_info.alphabet_anything_value:
            continue
        symbols = id_to_symbol[id_]
        for c in symbols:
            byte_ = None
            if len(c) == 1 and ord(c) < 0x80:
                # ASCII character
                byte_ = ord(c)
            elif len(c) > 1:
                # FIXME: This logic is due to the leading \x00
                # https://github.com/outlines-dev/outlines/pull/930
                # NOTE(review): parses ``symbols[0]`` rather than ``c``; if one id
                # maps to several multi-char symbols, only the first is used.
                byte_ = int(symbols[0][1:], 16)
            # Single non-ASCII characters leave byte_ as None and are not counted.

            if byte_ is not None:
                outgoings_ct[state] += 1
                if outgoings_ct[state] > 1:
                    if state in state_to_jump_forward:
                        del state_to_jump_forward[state]
                    break
                e = state_to_jump_forward.get(state, JumpEdge())
                e.byte = byte_
                e.byte_next_state = next_state
                state_to_jump_forward[state] = e

    return state_to_jump_forward
|
122
|
+
|
123
|
+
|
124
|
+
class OutlinesJumpForwardMap:
    """Precomputed jump-forward edges for one regex.

    ``state_to_jump_forward`` maps each FSM state whose next character/byte is
    forced to its :class:`JumpEdge`. Following these edges yields the text or
    bytes that the regex forces from a given state.
    """

    def __init__(self, regex_string):
        # init_state_to_jump_forward returns None for a regex it cannot parse.
        # Fall back to an empty map so the query methods report "nothing to
        # jump" instead of raising TypeError on ``state in None``.
        self.state_to_jump_forward = init_state_to_jump_forward(regex_string) or {}

    def jump_forward_symbol(self, state):
        """Follow symbol-level edges from ``state`` for as long as they exist.

        Returns:
            ``(jump_forward_str, next_state)``: the concatenated forced
            characters and the state reached. The result is ``("", state)``
            when nothing is forced.
        """
        pieces = []
        next_state = state
        while next_state in self.state_to_jump_forward:
            edge = self.state_to_jump_forward[next_state]
            if edge.symbol is None:
                break
            pieces.append(edge.symbol)
            next_state = edge.symbol_next_state

        return "".join(pieces), next_state

    def jump_forward_byte(self, state):
        """Follow byte-level edges from ``state`` for as long as they exist.

        Returns:
            A list of ``(byte, next_state)`` pairs along the forced path, or
            ``None`` when ``state`` has no forced edge.

        Raises:
            AssertionError: if an entry on the path lacks byte information.
        """
        if state not in self.state_to_jump_forward:
            return None

        jump_forward_bytes = []
        while state in self.state_to_jump_forward:
            edge = self.state_to_jump_forward[state]
            assert edge.byte is not None and edge.byte_next_state is not None
            jump_forward_bytes.append((edge.byte, edge.byte_next_state))
            state = edge.byte_next_state

        return jump_forward_bytes

    def is_jump_forward_symbol_state(self, state):
        """Return True if ``state`` has a forced symbol-level edge."""
        return (
            state in self.state_to_jump_forward
            and self.state_to_jump_forward[state].symbol is not None
        )
|
161
|
+
|
162
|
+
|
163
|
+
def test_main(regex_string):
    """Debug helper: print every jump-forward path computed for ``regex_string``.

    For each state with a forced edge, it prints the forced text (when the
    state has a symbol edge) and then the forced byte path in hex.
    """
    jf_map = OutlinesJumpForwardMap(regex_string)
    for start, edge in jf_map.state_to_jump_forward.items():
        if edge.symbol is not None:
            text, end_state = jf_map.jump_forward_symbol(start)
            print(f"{start} -> {end_state}", text)
        path = jf_map.jump_forward_byte(start)
        final_state = path[-1][1]
        print(f"{start} -> {final_state}", [hex(byte) for byte, _ in path])
|
171
|
+
|
172
|
+
|
173
|
+
if __name__ == "__main__":
    import outlines

    # Clear outlines' cache, presumably so init_state_to_jump_forward (wrapped in
    # outlines' cache decorator) is recomputed instead of reusing an old result.
    outlines.caching.clear_cache()
    test_main(r"The google's DNS sever address is " + IP_REGEX)
    test_main(r"霍格沃茨特快列车|霍比特人比尔博")
    # 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
    # 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...
    # The two alternatives share the bytes e9 9c 8d e6 and first differ in the
    # middle of the second character (a0 vs af), which exercises byte-level
    # jump forward.

    test_main(r"[-+]?[0-9]+[ ]*")
|
@@ -0,0 +1,150 @@
|
|
1
|
+
"""
|
2
|
+
Copyright 2023-2024 SGLang Team
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
"""
|
15
|
+
|
16
|
+
"""Constrained decoding with xgrammar backend."""
|
17
|
+
|
18
|
+
import logging
|
19
|
+
from typing import List, Tuple
|
20
|
+
|
21
|
+
import torch
|
22
|
+
|
23
|
+
try:
    from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher

    import_error = None
except ImportError as e:
    # xgrammar is optional. Bind the names to a placeholder so this module still
    # imports, and keep the error so XGrammarGrammarBackend can warn in
    # __init__ and re-raise it from init_value_impl.
    # NOTE(review): TokenizerInfo is bound here but is neither imported in the
    # success branch nor used in this file; it looks like a leftover.
    CachedGrammarCompiler = CompiledGrammar = GrammarMatcher = TokenizerInfo = (
        ImportError
    )
    import_error = e
|
32
|
+
|
33
|
+
from sglang.srt.constrained.base_grammar_backend import (
|
34
|
+
BaseGrammarBackend,
|
35
|
+
BaseGrammarObject,
|
36
|
+
)
|
37
|
+
|
38
|
+
logger = logging.getLogger(__name__)


# Passed as ``max_rollback_tokens`` to every GrammarMatcher created here. It
# presumably bounds how many tokens ``matcher.rollback()`` can undo, which
# XGrammarGrammar.jump_and_retokenize relies on; confirm against xgrammar docs.
MAX_ROLLBACK_TOKENS = 10
|
42
|
+
|
43
|
+
|
44
|
+
class XGrammarGrammar(BaseGrammarObject):
    """Per-request constrained-decoding state backed by an xgrammar matcher."""

    def __init__(
        self, matcher: GrammarMatcher, vocab_size: int, ctx: CompiledGrammar
    ) -> None:
        self.matcher = matcher
        self.vocab_size = vocab_size
        # Compiled grammar, kept so copy() can build a fresh matcher.
        self.ctx = ctx

    def accept_token(self, token: int):
        """Feed ``token`` to the matcher.

        Raises:
            AssertionError: if the matcher rejects the token (when asserts are
                enabled).
        """
        # The call must stay outside the assert. Under ``python -O`` asserts are
        # stripped, and the token would otherwise never reach the matcher.
        accepted = self.matcher.accept_token(token)
        assert accepted, f"xgrammar matcher rejected token {token}"

    def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]:
        """Return ``([], s)`` where ``s`` is the matcher's forced string.

        Returns ``None`` when the forced string is empty.
        """
        s = self.matcher.find_jump_forward_string()
        if s:
            return [], s
        return None

    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        """Return ``(forced_string, -1)``.

        The state is a dummy ``-1`` because :meth:`jump_and_retokenize` ignores
        ``next_state`` for this backend.
        """
        _, data = helper
        return data, -1

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        """Resynchronize the matcher after the output was re-tokenized.

        Rolls the matcher back to the longest common prefix of the old and new
        token ids, then feeds the remaining new ids. ``next_state`` is unused.
        """
        # Length of the common prefix. zip() stops at the shorter list, so a
        # re-tokenization that shortened the output no longer raises IndexError.
        k = 0
        for old_id, new_id in zip(old_output_ids, new_output_ids):
            if old_id != new_id:
                break
            k += 1

        # rollback to the last token that is the same
        if k < len(old_output_ids):
            self.matcher.rollback(len(old_output_ids) - k)

        for token in new_output_ids[k:]:
            self.accept_token(token)

    def fill_vocab_mask(self, vocab_mask: torch.Tensor):
        """Set ``vocab_mask`` to 1 at every token the grammar currently rejects.

        NOTE(review): unlike the outlines grammar, this does not reset the mask
        first; entries for allowed tokens are left untouched, so the caller
        presumably passes a zeroed mask.
        """
        # Note that this bitmask is a bitset, not bool
        bitmask = self.matcher.get_next_token_bitmask()
        # Mask the tokens that are not allowed
        vocab_mask[
            self.matcher.get_rejected_tokens_from_bitmask(bitmask, self.vocab_size)
        ] = 1

    def copy(self):
        """Return a new grammar with a fresh matcher over the same compiled grammar."""
        matcher = GrammarMatcher(
            self.ctx,
            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
            mask_vocab_size=self.vocab_size,
        )
        return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
|
98
|
+
|
99
|
+
|
100
|
+
class XGrammarGrammarBackend(BaseGrammarBackend):
    """Grammar backend that compiles JSON schemas with xgrammar.

    Regex keys are not supported and yield ``None``. If xgrammar failed to
    import, construction only logs a warning, and every compile attempt
    re-raises the original ImportError.
    """

    def __init__(
        self,
        tokenizer,
        vocab_size: int,
    ):
        super().__init__()
        # Set before the early return so the attribute exists on every path.
        self.vocab_size = vocab_size

        if import_error:
            logger.warning(
                f"Ignore import error for the grammar backend: {import_error}"
            )
            self.grammar_cache = None
            return

        self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)

    def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
        """Compile ``key`` into an :class:`XGrammarGrammar`.

        Args:
            key: ``(key_type, key_string)``. ``key_type`` is ``"json"`` (with a
                JSON schema) or ``"regex"`` (unsupported here).

        Returns:
            The grammar, or ``None`` when the schema is invalid or the key is a
            regex (a warning is logged in both cases).

        Raises:
            ImportError: if xgrammar could not be imported.
            ValueError: for any other ``key_type``.
        """
        if import_error:
            raise import_error

        key_type, key_string = key
        if key_type == "json":
            try:
                ctx = self.grammar_cache.get_compiled_grammar_for_json_schema(
                    key_string
                )
            except RuntimeError as e:
                # Use the module logger rather than the root ``logging`` logger.
                logger.warning(
                    f"Skip invalid json_schema: json_schema={key_string}, {e=}"
                )
                return None
        elif key_type == "regex":
            logger.warning(
                "regex hasn't been supported by xgrammar yet. This is skipped."
            )
            return None
        else:
            raise ValueError(f"Invalid key_type: {key_type}")

        matcher = GrammarMatcher(
            ctx,
            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
            mask_vocab_size=self.vocab_size,
        )
        return XGrammarGrammar(matcher, self.vocab_size, ctx)

    def reset(self):
        """Clear xgrammar's compiled-grammar cache and the base object cache."""
        if self.grammar_cache is not None:
            self.grammar_cache.clear()
        # The base class keeps its own cache of grammar objects. Clear it too, as
        # the outlines backend does via the inherited reset(), so a reset does
        # not keep serving objects compiled earlier.
        super().reset()
|
@@ -507,6 +507,12 @@ def _decode_grouped_att_m_fwd(
|
|
507
507
|
|
508
508
|
num_warps = 4
|
509
509
|
|
510
|
+
extra_kargs = {}
|
511
|
+
if is_hip():
|
512
|
+
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
|
513
|
+
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
|
514
|
+
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
|
515
|
+
|
510
516
|
_fwd_grouped_kernel_stage1[grid](
|
511
517
|
q,
|
512
518
|
k_buffer,
|
@@ -532,6 +538,7 @@ def _decode_grouped_att_m_fwd(
|
|
532
538
|
num_warps=num_warps,
|
533
539
|
num_stages=1,
|
534
540
|
Lk=Lk,
|
541
|
+
**extra_kargs,
|
535
542
|
)
|
536
543
|
|
537
544
|
|
@@ -25,6 +25,7 @@ import triton.language as tl
|
|
25
25
|
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
|
26
26
|
context_attention_fwd,
|
27
27
|
)
|
28
|
+
from sglang.srt.utils import is_hip
|
28
29
|
|
29
30
|
is_cuda_available = torch.cuda.is_available()
|
30
31
|
if is_cuda_available:
|
@@ -311,6 +312,10 @@ def extend_attention_fwd(
|
|
311
312
|
num_warps = 4 if Lk <= 64 else 8
|
312
313
|
num_stages = 1
|
313
314
|
|
315
|
+
extra_kargs = {}
|
316
|
+
if is_hip():
|
317
|
+
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
|
318
|
+
|
314
319
|
_fwd_kernel[grid](
|
315
320
|
q_extend,
|
316
321
|
k_extend,
|
@@ -348,6 +353,7 @@ def extend_attention_fwd(
|
|
348
353
|
Lv=Lv,
|
349
354
|
num_warps=num_warps,
|
350
355
|
num_stages=num_stages,
|
356
|
+
**extra_kargs,
|
351
357
|
)
|
352
358
|
|
353
359
|
|
@@ -54,6 +54,7 @@ def fused_moe_kernel(
|
|
54
54
|
top_k: tl.constexpr,
|
55
55
|
compute_type: tl.constexpr,
|
56
56
|
use_fp8: tl.constexpr,
|
57
|
+
even_Ks: tl.constexpr,
|
57
58
|
):
|
58
59
|
"""
|
59
60
|
Implements the fused computation for a Mixture of Experts (MOE) using
|
@@ -130,16 +131,24 @@ def fused_moe_kernel(
|
|
130
131
|
# of fp32 values for higher accuracy.
|
131
132
|
# `accumulator` will be converted back to fp16 after the loop.
|
132
133
|
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
133
|
-
|
134
134
|
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
135
135
|
# Load the next block of A and B, generate a mask by checking the
|
136
136
|
# K dimension.
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
137
|
+
if even_Ks:
|
138
|
+
a = tl.load(
|
139
|
+
a_ptrs,
|
140
|
+
mask=token_mask[:, None],
|
141
|
+
other=0.0,
|
142
|
+
)
|
143
|
+
b = tl.load(b_ptrs)
|
144
|
+
else:
|
145
|
+
a = tl.load(
|
146
|
+
a_ptrs,
|
147
|
+
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
148
|
+
other=0.0,
|
149
|
+
)
|
150
|
+
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
151
|
+
|
143
152
|
# We accumulate along the K dimension.
|
144
153
|
if use_fp8:
|
145
154
|
accumulator = tl.dot(a, b, acc=accumulator)
|
@@ -253,6 +262,12 @@ def invoke_fused_moe_kernel(
|
|
253
262
|
* triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
|
254
263
|
)
|
255
264
|
|
265
|
+
K = B.shape[2] - padding_size
|
266
|
+
if K % config["BLOCK_SIZE_K"] == 0:
|
267
|
+
even_ks = True
|
268
|
+
else:
|
269
|
+
even_ks = False
|
270
|
+
|
256
271
|
fused_moe_kernel[grid](
|
257
272
|
A,
|
258
273
|
B,
|
@@ -278,6 +293,7 @@ def invoke_fused_moe_kernel(
|
|
278
293
|
top_k=top_k,
|
279
294
|
compute_type=compute_type,
|
280
295
|
use_fp8=use_fp8,
|
296
|
+
even_Ks=even_ks,
|
281
297
|
**config,
|
282
298
|
)
|
283
299
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Optional
|
1
|
+
from typing import Callable, Optional
|
2
2
|
|
3
3
|
import torch
|
4
4
|
from torch.nn import functional as F
|
@@ -98,7 +98,9 @@ def fused_moe_forward_native(
|
|
98
98
|
renormalize: bool,
|
99
99
|
topk_group: Optional[int] = None,
|
100
100
|
num_expert_group: Optional[int] = None,
|
101
|
+
custom_routing_function: Optional[Callable] = None,
|
101
102
|
) -> torch.Tensor:
|
103
|
+
assert custom_routing_function is None
|
102
104
|
topk_weights, topk_ids = select_experts_native(
|
103
105
|
hidden_states=x,
|
104
106
|
router_logits=router_logits,
|
@@ -114,4 +116,4 @@ def fused_moe_forward_native(
|
|
114
116
|
x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
|
115
117
|
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
|
116
118
|
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
|
117
|
-
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights)
|
119
|
+
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
|