sglang 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +3 -1
- sglang/api.py +7 -7
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +158 -11
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/bench_latency.py +299 -0
- sglang/global_config.py +12 -2
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +114 -67
- sglang/lang/ir.py +28 -3
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +13 -6
- sglang/srt/constrained/fsm_cache.py +8 -2
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +3 -1
- sglang/srt/hf_transformers_utils.py +130 -1
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +582 -0
- sglang/srt/layers/logits_processor.py +65 -32
- sglang/srt/layers/radix_attention.py +41 -7
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/{router → controller}/infer_batch.py +242 -100
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/{router/manager.py → controller/manager_single.py} +34 -14
- sglang/srt/managers/{router → controller}/model_runner.py +262 -158
- sglang/srt/managers/{router → controller}/radix_cache.py +11 -1
- sglang/srt/managers/{router/scheduler.py → controller/schedule_heuristic.py} +9 -7
- sglang/srt/managers/{router/model_rpc.py → controller/tp_worker.py} +298 -267
- sglang/srt/managers/detokenizer_manager.py +42 -46
- sglang/srt/managers/io_struct.py +22 -12
- sglang/srt/managers/tokenizer_manager.py +151 -87
- sglang/srt/model_config.py +83 -5
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +10 -13
- sglang/srt/models/dbrx.py +9 -15
- sglang/srt/models/gemma.py +12 -15
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +26 -15
- sglang/srt/models/llama_classification.py +104 -0
- sglang/srt/models/llava.py +86 -19
- sglang/srt/models/llavavid.py +11 -20
- sglang/srt/models/mixtral.py +282 -103
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +9 -13
- sglang/srt/models/qwen2.py +11 -13
- sglang/srt/models/stablelm.py +9 -15
- sglang/srt/models/yivl.py +17 -22
- sglang/srt/openai_api_adapter.py +150 -95
- sglang/srt/openai_protocol.py +11 -2
- sglang/srt/server.py +124 -48
- sglang/srt/server_args.py +128 -48
- sglang/srt/utils.py +234 -67
- sglang/test/test_programs.py +65 -3
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +23 -4
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/METADATA +40 -27
- sglang-0.1.18.dist-info/RECORD +78 -0
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -417
- sglang-0.1.16.dist-info/RECORD +0 -72
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/lang/ir.py
CHANGED
@@ -82,6 +82,19 @@ class SglSamplingParams:
             "top_k": self.top_k,
         }
 
+    def to_litellm_kwargs(self):
+        if self.regex is not None:
+            warnings.warn("Regular expression is not supported in the LiteLLM backend.")
+        return {
+            "max_tokens": self.max_new_tokens,
+            "stop": self.stop or None,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k": self.top_k,
+            "frequency_penalty": self.frequency_penalty,
+            "presence_penalty": self.presence_penalty,
+        }
+
     def to_srt_kwargs(self):
         return {
             "max_new_tokens": self.max_new_tokens,
@@ -97,9 +110,9 @@ class SglSamplingParams:
 
 
 class SglFunction:
-    def __init__(self, func,
+    def __init__(self, func, num_api_spec_tokens=None, bind_arguments=None):
         self.func = func
-        self.
+        self.num_api_spec_tokens = num_api_spec_tokens
         self.bind_arguments = bind_arguments or {}
         self.pin_prefix_rid = None
 
@@ -107,6 +120,7 @@ class SglFunction:
         argspec = inspect.getfullargspec(func)
         assert argspec.args[0] == "s", 'The first argument must be "s"'
         self.arg_names = argspec.args[1:]
+        self.arg_defaults = argspec.defaults if argspec.defaults is not None else []
 
     def bind(self, **kwargs):
         assert all(key in self.arg_names for key in kwargs)
@@ -165,7 +179,18 @@ class SglFunction:
         assert isinstance(batch_kwargs, (list, tuple))
         if len(batch_kwargs) == 0:
             return []
-
+        if not isinstance(batch_kwargs[0], dict):
+            num_programs = len(batch_kwargs)
+            # change the list of argument values to dict of arg_name -> arg_value
+            batch_kwargs = [
+                {self.arg_names[i]: v for i, v in enumerate(arg_values)}
+                for arg_values in batch_kwargs
+                if isinstance(arg_values, (list, tuple)) and
+                len(self.arg_names) - len(self.arg_defaults) <= len(arg_values) <= len(self.arg_names)
+            ]
+            # Ensure to raise an exception if the number of arguments mismatch
+            if len(batch_kwargs) != num_programs:
+                raise Exception("Given arguments mismatch the SGL function signature")
 
         default_sampling_para = SglSamplingParams(
             max_new_tokens=max_new_tokens,
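Two changes in this file are worth calling out: to_litellm_kwargs maps SGLang sampling parameters onto LiteLLM's completion keywords for the new LiteLLM backend, and run_batch now also accepts lists of positional argument values. A standalone sketch of the new positional-to-keyword conversion; the argument names and values here are made up for illustration:

    # Mirrors the logic added to SglFunction.run_batch above;
    # arg_names/arg_defaults values are hypothetical.
    arg_names = ["question", "style"]
    arg_defaults = ["formal"]  # only `style` has a default
    batch_args = [("What is 2+2?",), ("Capital of France?", "casual")]

    num_programs = len(batch_args)
    batch_kwargs = [
        {arg_names[i]: v for i, v in enumerate(arg_values)}
        for arg_values in batch_args
        if isinstance(arg_values, (list, tuple))
        and len(arg_names) - len(arg_defaults) <= len(arg_values) <= len(arg_names)
    ]
    if len(batch_kwargs) != num_programs:
        raise Exception("Given arguments mismatch the SGL function signature")

    print(batch_kwargs)
    # [{'question': 'What is 2+2?'},
    #  {'question': 'Capital of France?', 'style': 'casual'}]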
sglang/launch_server.py
CHANGED
@@ -1,6 +1,9 @@
+"""Launch the inference server."""
+
 import argparse
 
-from sglang.srt.server import
+from sglang.srt.server import launch_server
+from sglang.srt.server_args import ServerArgs
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
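ServerArgs is now imported from its own module, sglang.srt.server_args. For context, the complete entry point in this release reads roughly as follows; a sketch assuming ServerArgs keeps the add_cli_args/from_cli_args helpers it exposed in 0.1.16:

    # Hedged sketch of the full launcher in the 0.1.18 layout.
    import argparse

    from sglang.srt.server import launch_server
    from sglang.srt.server_args import ServerArgs

    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        ServerArgs.add_cli_args(parser)
        args = parser.parse_args()
        server_args = ServerArgs.from_cli_args(args)
        launch_server(server_args)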
sglang/launch_server_llavavid.py
CHANGED
@@ -1,10 +1,11 @@
+"""Launch the inference server for Llava-video model."""
+
 import argparse
 import multiprocessing as mp
 
 from sglang.srt.server import ServerArgs, launch_server
 
 if __name__ == "__main__":
-
     model_overide_args = {}
 
     model_overide_args["mm_spatial_pool_stride"] = 2
sglang/srt/constrained/__init__.py
CHANGED
@@ -1,13 +1,19 @@
 import json
 from typing import Dict, Optional, Union
 
-from outlines.caching import cache as disk_cache
-from outlines.caching import disable_cache
-from outlines.fsm.fsm import RegexFSM
-from outlines.fsm.regex import FSMInfo, make_deterministic_fsm
-from outlines.models.transformers import TransformerTokenizer
 from pydantic import BaseModel
 
+try:
+    from outlines.caching import cache as disk_cache
+    from outlines.fsm.guide import RegexGuide
+    from outlines.caching import disable_cache
+    from outlines.fsm.guide import RegexGuide
+    from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
+    from outlines.models.transformers import TransformerTokenizer
+except ImportError as e:
+    print(f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n')
+    raise
+
 try:
     from outlines.fsm.json_schema import build_regex_from_object
 except ImportError:
@@ -28,11 +34,12 @@ except ImportError:
 
 
 __all__ = [
-    "
+    "RegexGuide",
     "FSMInfo",
     "make_deterministic_fsm",
     "build_regex_from_object",
     "TransformerTokenizer",
     "disk_cache",
     "disable_cache",
+    "make_byte_level_fsm",
 ]
sglang/srt/constrained/fsm_cache.py
CHANGED
@@ -1,4 +1,6 @@
-
+"""Cache for the compressed finite state machine."""
+
+from sglang.srt.constrained import RegexGuide, TransformerTokenizer
 from sglang.srt.constrained.base_cache import BaseCache
 
 
@@ -6,6 +8,10 @@ class FSMCache(BaseCache):
     def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
         super().__init__(enable=enable)
 
+        if tokenizer_path.endswith(".json") or tokenizer_path.endswith(".model"):
+            # Do not support TiktokenTokenizer or SentencePieceTokenizer
+            return
+
         from importlib.metadata import version
 
         if version("outlines") >= "0.0.35":
@@ -22,4 +28,4 @@ class FSMCache(BaseCache):
         )
 
     def init_value(self, regex):
-        return
+        return RegexGuide(regex, self.outlines_tokenizer)
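init_value now builds an outlines RegexGuide (outlines >= 0.0.44) instead of the removed RegexFSM, and tokenizer files served by the new TiktokenTokenizer/SentencePieceTokenizer paths simply bypass the cache. An illustrative sketch of how a cached guide is consumed during constrained decoding; it assumes BaseCache exposes a query(key) helper that memoizes init_value, and the outlines Guide interface (get_next_instruction/get_next_state):

    from sglang.srt.constrained.fsm_cache import FSMCache

    cache = FSMCache("meta-llama/Llama-2-7b-hf", {"trust_remote_code": False})
    guide = cache.query(r"(yes|no)")  # a RegexGuide, built once and reused

    state = 0                                  # RegexGuide starts at state 0
    inst = guide.get_next_instruction(state)   # which token ids may come next
    # state = guide.get_next_state(state, token_id)  # advance after sampling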
sglang/srt/constrained/jump_forward.py
CHANGED
@@ -1,17 +1,43 @@
-
+"""
+Faster constrained decoding.
+Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
+"""
+
+import dataclasses
+from collections import defaultdict
 
-
+import interegular
+import outlines.caching
+
+from sglang.srt.constrained import (
+    FSMInfo,
+    disk_cache,
+    make_byte_level_fsm,
+    make_deterministic_fsm,
+)
 from sglang.srt.constrained.base_cache import BaseCache
 
 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
 
 
+@dataclasses.dataclass
+class JumpEdge:
+    symbol: str = None
+    symbol_next_state: int = None
+    byte: int = None
+    byte_next_state: int = None
+
+
 class JumpForwardMap:
     def __init__(self, regex_string):
         @disk_cache()
         def _init_state_to_jump_forward(regex_string):
             regex_pattern = interegular.parse_pattern(regex_string)
-
+
+            byte_fsm = make_byte_level_fsm(
+                regex_pattern.to_fsm().reduce(), keep_utf8=True
+            )
+            regex_fsm, _ = make_deterministic_fsm(byte_fsm)
 
             fsm_info: FSMInfo = regex_fsm.fsm_info
 
@@ -21,40 +47,93 @@ class JumpForwardMap:
                 id_to_symbol.setdefault(id_, []).append(symbol)
 
             transitions = fsm_info.transitions
-
+            outgoings_ct = defaultdict(int)
             state_to_jump_forward = {}
 
             for (state, id_), next_state in transitions.items():
-                if
-                    continue
-                if state in state_to_jump_forward:
-                    dirty_states.add(state)
-                    del state_to_jump_forward[state]
+                if id_ == fsm_info.alphabet_anything_value:
                     continue
-
-
+                symbols = id_to_symbol[id_]
+                for c in symbols:
+                    if len(c) > 1:
+                        # Skip byte level transitions
+                        continue
+
+                    outgoings_ct[state] += 1
+                    if outgoings_ct[state] > 1:
+                        if state in state_to_jump_forward:
+                            del state_to_jump_forward[state]
+                        break
+
+                    state_to_jump_forward[state] = JumpEdge(
+                        symbol=c,
+                        symbol_next_state=next_state,
+                    )
+
+            # Process the byte level jump forward
+            outgoings_ct = defaultdict(int)
+            for (state, id_), next_state in transitions.items():
+                if id_ == fsm_info.alphabet_anything_value:
                     continue
-
-
+                symbols = id_to_symbol[id_]
+                for c in symbols:
+                    byte_ = None
+                    if len(c) == 1 and ord(c) < 0x80:
+                        # ASCII character
+                        byte_ = ord(c)
+                    elif len(c) > 1:
+                        # FIXME: This logic is due to the leading \x00
+                        # https://github.com/outlines-dev/outlines/pull/930
+                        byte_ = int(symbols[0][1:], 16)
+
+                    if byte_ is not None:
+                        outgoings_ct[state] += 1
+                        if outgoings_ct[state] > 1:
+                            if state in state_to_jump_forward:
+                                del state_to_jump_forward[state]
+                            break
+                        e = state_to_jump_forward.get(state, JumpEdge())
+                        e.byte = byte_
+                        e.byte_next_state = next_state
+                        state_to_jump_forward[state] = e
 
             return state_to_jump_forward
 
         self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)
 
-    def
-
+    def jump_forward_symbol(self, state):
+        jump_forward_str = ""
+        next_state = state
+        while state in self.state_to_jump_forward:
+            e = self.state_to_jump_forward[state]
+            if e.symbol is None:
+                break
+            jump_forward_str += e.symbol
+            next_state = e.symbol_next_state
+            state = next_state
 
-
+        return jump_forward_str, next_state
+
+    def jump_forward_byte(self, state):
         if state not in self.state_to_jump_forward:
             return None
 
-
+        jump_forward_bytes = []
         next_state = None
         while state in self.state_to_jump_forward:
-
-
+            e = self.state_to_jump_forward[state]
+            assert e.byte is not None and e.byte_next_state is not None
+            jump_forward_bytes.append((e.byte, e.byte_next_state))
+            next_state = e.byte_next_state
             state = next_state
-
+
+        return jump_forward_bytes
+
+    def is_jump_forward_symbol_state(self, state):
+        return (
+            state in self.state_to_jump_forward
+            and self.state_to_jump_forward[state].symbol is not None
+        )
 
 
 class JumpForwardCache(BaseCache):
@@ -65,12 +144,21 @@ class JumpForwardCache(BaseCache):
         return JumpForwardMap(regex)
 
 
-def test_main():
-    regex_string = r"The google's DNS sever address is " + IP_REGEX
+def test_main(regex_string):
     jump_forward_map = JumpForwardMap(regex_string)
-    for state in jump_forward_map.
-
+    for state, e in jump_forward_map.state_to_jump_forward.items():
+        if e.symbol is not None:
+            jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
+            print(f"{state} -> {next_state}", jump_forward_str)
+        bytes_ = jump_forward_map.jump_forward_byte(state)
+        print(f"{state} -> {bytes_[-1][1]}", [hex(b) for b, _ in bytes_])
 
 
 if __name__ == "__main__":
-
+    import outlines
+
+    outlines.caching.clear_cache()
+    test_main(r"The google's DNS sever address is " + IP_REGEX)
+    test_main(r"霍格沃茨特快列车|霍比特人比尔博")
+    # 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
+    # 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...
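The map above enables "jump forward" decoding: whenever the current FSM state has exactly one admissible next symbol, that text is appended directly instead of being decoded token by token, and the new byte-level variant keeps this working across multi-byte UTF-8 characters (hence the Chinese test strings). A toy walk-through mirroring test_main; the regex is made up:

    # For a regex with a literal prefix, the states along the prefix each
    # have a single outgoing symbol edge, so whole spans can be emitted
    # without any model forward passes.
    from sglang.srt.constrained.jump_forward import JumpForwardMap

    jmap = JumpForwardMap(r"name: [A-Z][a-z]+")
    for state, edge in jmap.state_to_jump_forward.items():
        if edge.symbol is not None:
            text, next_state = jmap.jump_forward_symbol(state)
            print(f"{state} -> {next_state}: {text!r}")  # chunks of "name: "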
sglang/srt/conversation.py
CHANGED
sglang/srt/flush_cache.py
CHANGED
@@ -1,4 +1,6 @@
 """
+Flush the KV cache.
+
 Usage:
 python3 -m sglang.srt.flush_cache --url http://localhost:30000
 """
@@ -13,4 +15,4 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     response = requests.get(args.url + "/flush_cache")
-    assert response.status_code == 200
+    assert response.status_code == 200
sglang/srt/hf_transformers_utils.py
CHANGED
@@ -1,9 +1,10 @@
 """Utilities for Huggingface Transformers."""
 
+import functools
 import json
 import os
 import warnings
-from typing import
+from typing import AbstractSet, Collection, Literal, Optional, Union
 
 from huggingface_hub import snapshot_download
 from transformers import (
@@ -84,6 +85,12 @@ def get_tokenizer(
     tokenizer_revision: Optional[str] = None,
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+    if tokenizer_name.endswith(".json"):
+        return TiktokenTokenizer(tokenizer_name)
+
+    if tokenizer_name.endswith(".model"):
+        return SentencePieceTokenizer(tokenizer_name)
+
     """Gets a tokenizer for the given model name via Huggingface."""
     if is_multimodal_model(tokenizer_name):
         processor = get_processor(
@@ -170,3 +177,125 @@ def get_processor(
         **kwargs,
     )
     return processor
+
+
+class TiktokenTokenizer:
+    def __init__(self, tokenizer_path):
+        import tiktoken
+        from jinja2 import Template
+
+        PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+
+        # Read JSON
+        name = "tmp-json"
+        with open(tokenizer_path, "rb") as fin:
+            tok_dict = json.load(fin)
+
+        mergeable_ranks = {
+            bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
+        }
+        special_tokens = {
+            bytes(item["bytes"]).decode(): item["token"]
+            for item in tok_dict["special_tokens"]
+        }
+        assert tok_dict["word_split"] == "V1"
+
+        kwargs = {
+            "name": name,
+            "pat_str": tok_dict.get("pat_str", PAT_STR_B),
+            "mergeable_ranks": mergeable_ranks,
+            "special_tokens": special_tokens,
+        }
+        if "default_allowed_special" in tok_dict:
+            default_allowed_special = set(
+                [
+                    bytes(bytes_list).decode()
+                    for bytes_list in tok_dict["default_allowed_special"]
+                ]
+            )
+        else:
+            default_allowed_special = None
+        if "vocab_size" in tok_dict:
+            kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]
+
+        tokenizer = tiktoken.Encoding(**kwargs)
+        tokenizer._default_allowed_special = default_allowed_special or set()
+        tokenizer._default_allowed_special |= {"<|separator|>"}
+
+        def encode_patched(
+            self,
+            text: str,
+            *,
+            allowed_special: Union[
+                Literal["all"], AbstractSet[str]
+            ] = set(),  # noqa: B006
+            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+        ) -> list[int]:
+            if isinstance(allowed_special, set):
+                allowed_special |= self._default_allowed_special
+            return tiktoken.Encoding.encode(
+                self,
+                text,
+                allowed_special=allowed_special,
+                disallowed_special=disallowed_special,
+            )
+
+        tokenizer.encode = functools.partial(encode_patched, tokenizer)
+
+        # Convert to HF interface
+        self.tokenizer = tokenizer
+        self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
+        self.vocab_size = tokenizer.n_vocab
+        self.chat_template = Template(
+            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
+        )
+
+    def encode(self, x, add_special_tokens=False):
+        return self.tokenizer.encode(x)
+
+    def decode(self, x):
+        return self.tokenizer.decode(x)
+
+    def batch_decode(
+        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
+    ):
+        if isinstance(batch[0], int):
+            batch = [[x] for x in batch]
+        return self.tokenizer.decode_batch(batch)
+
+    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
+        ret = self.chat_template.render(messages=messages, add_generation_prompt=add_generation_prompt)
+        return self.encode(ret) if tokenize else ret
+
+
+class SentencePieceTokenizer:
+    def __init__(self, tokenizer_path):
+        import sentencepiece as spm
+        from jinja2 import Template
+
+        tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path)
+
+        # Convert to HF interface
+        self.tokenizer = tokenizer
+        self.eos_token_id = tokenizer.eos_id()
+        self.vocab_size = tokenizer.vocab_size()
+        self.chat_template = Template(
+            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
+        )
+
+    def encode(self, x, add_special_tokens=False):
+        return self.tokenizer.encode(x)
+
+    def decode(self, x):
+        return self.tokenizer.decode(x)
+
+    def batch_decode(
+        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
+    ):
+        if isinstance(batch[0], int):
+            batch = [[x] for x in batch]
+        return self.tokenizer.decode(batch)
+
+    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
+        ret = self.chat_template.render(messages=messages, add_generation_prompt=add_generation_prompt)
+        return self.encode(ret) if tokenize else ret
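With this change, get_tokenizer dispatches on the path suffix before falling back to the usual Huggingface route, so a raw tokenizer file can be passed wherever a model name was accepted. A hedged usage sketch; the file paths are placeholders:

    # Placeholder paths; the suffix selects the implementation per the diff above.
    from sglang.srt.hf_transformers_utils import get_tokenizer

    tok = get_tokenizer("my-tokenizer.model")          # SentencePieceTokenizer
    # tok = get_tokenizer("my-tokenizer.json")         # TiktokenTokenizer
    # tok = get_tokenizer("meta-llama/Llama-2-7b-hf")  # regular HF tokenizer

    ids = tok.encode("Hello world")
    print(tok.decode(ids))
    prompt = tok.apply_chat_template(
        [{"role": "user", "content": "Hi"}],
        tokenize=False,
        add_generation_prompt=True,
    )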
sglang/srt/layers/extend_attention.py
CHANGED
@@ -8,6 +8,12 @@ from sglang.srt.utils import wrap_kernel_launcher
 CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 
+@triton.jit
+def tanh(x):
+    # Tanh is just a scaled sigmoid
+    return 2 * tl.sigmoid(2 * x) - 1
+
+
 @triton.jit
 def _fwd_kernel(
     Q_Extend,
@@ -39,6 +45,7 @@ def _fwd_kernel(
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    logit_cap: tl.constexpr,
 ):
     cur_seq = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -90,6 +97,10 @@ def _fwd_kernel(
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
         qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
         qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
 
         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
@@ -126,6 +137,10 @@ def _fwd_kernel(
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
         qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
         mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
             start_n + offs_n[None, :]
         )
@@ -176,6 +191,7 @@ def extend_attention_fwd(
     b_seq_len_extend,
     max_len_in_batch,
     max_len_extend,
+    logit_cap=-1,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -271,6 +287,7 @@ def extend_attention_fwd(
         BLOCK_N=BLOCK_N,
         num_warps=num_warps,
         num_stages=num_stages,
+        logit_cap=logit_cap,
     )
     cached_kernel = wrap_kernel_launcher(_fwd_kernel)
 
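The new logit_cap argument implements attention logit soft-capping: when logit_cap > 0, raw scores are squashed through logit_cap * tanh(qk / logit_cap), which is near-identity for small scores and saturates at ±logit_cap (this release also adds grok.py, and Grok-style models use such a cap). The helper relies on the identity tanh(x) = 2*sigmoid(2x) - 1, since Triton exposes tl.sigmoid. A quick numeric check of both claims in plain Python:

    # Check of the soft cap and the tanh/sigmoid identity used by the
    # Triton helper above.
    import math

    def soft_cap(qk, logit_cap=30.0):
        return logit_cap * math.tanh(qk / logit_cap)

    def scaled_sigmoid_tanh(x):
        return 2 / (1 + math.exp(-2 * x)) - 1

    print(soft_cap(1.0))    # ~0.9996: small scores pass almost unchanged
    print(soft_cap(100.0))  # ~29.9: large scores saturate below the cap
    assert abs(scaled_sigmoid_tanh(1.3) - math.tanh(1.3)) < 1e-12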