sglang 0.1.16__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +3 -1
- sglang/api.py +3 -3
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +148 -12
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/global_config.py +8 -1
- sglang/lang/interpreter.py +114 -67
- sglang/lang/ir.py +17 -2
- sglang/srt/constrained/fsm_cache.py +3 -0
- sglang/srt/flush_cache.py +1 -1
- sglang/srt/hf_transformers_utils.py +75 -1
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +485 -0
- sglang/srt/layers/logits_processor.py +12 -7
- sglang/srt/layers/radix_attention.py +10 -3
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +110 -0
- sglang/srt/managers/controller/infer_batch.py +619 -0
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/controller/manager_single.py +97 -0
- sglang/srt/managers/controller/model_runner.py +462 -0
- sglang/srt/managers/controller/radix_cache.py +267 -0
- sglang/srt/managers/controller/schedule_heuristic.py +59 -0
- sglang/srt/managers/controller/tp_worker.py +791 -0
- sglang/srt/managers/detokenizer_manager.py +45 -45
- sglang/srt/managers/io_struct.py +15 -11
- sglang/srt/managers/router/infer_batch.py +103 -59
- sglang/srt/managers/router/manager.py +1 -1
- sglang/srt/managers/router/model_rpc.py +175 -122
- sglang/srt/managers/router/model_runner.py +91 -104
- sglang/srt/managers/router/radix_cache.py +7 -1
- sglang/srt/managers/router/scheduler.py +6 -6
- sglang/srt/managers/tokenizer_manager.py +152 -89
- sglang/srt/model_config.py +4 -5
- sglang/srt/models/commandr.py +10 -13
- sglang/srt/models/dbrx.py +9 -15
- sglang/srt/models/gemma.py +8 -15
- sglang/srt/models/grok.py +671 -0
- sglang/srt/models/llama2.py +19 -15
- sglang/srt/models/llava.py +84 -20
- sglang/srt/models/llavavid.py +11 -20
- sglang/srt/models/mixtral.py +248 -118
- sglang/srt/models/mixtral_quant.py +373 -0
- sglang/srt/models/qwen.py +9 -13
- sglang/srt/models/qwen2.py +11 -13
- sglang/srt/models/stablelm.py +9 -15
- sglang/srt/models/yivl.py +17 -22
- sglang/srt/openai_api_adapter.py +140 -95
- sglang/srt/openai_protocol.py +10 -1
- sglang/srt/server.py +77 -42
- sglang/srt/server_args.py +51 -6
- sglang/srt/utils.py +124 -66
- sglang/test/test_programs.py +44 -0
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +22 -4
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/METADATA +15 -9
- sglang-0.1.17.dist-info/RECORD +81 -0
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -417
- sglang-0.1.16.dist-info/RECORD +0 -72
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,267 @@
|
|
1
|
+
import heapq
|
2
|
+
import time
|
3
|
+
from collections import defaultdict
|
4
|
+
|
5
|
+
import torch
|
6
|
+
|
7
|
+
|
8
|
+
class TreeNode:
    """A single node of the radix tree.

    Attributes:
        children: ``defaultdict(TreeNode)`` holding the child nodes.
        parent: The parent node, ``None`` until the node is linked.
        key: The key segment stored on this node (``None`` until set).
        value: The value segment stored on this node (``None`` until set).
        lock_ref: Reference counter, starts at 0.
        last_access_time: ``time.time()`` stamp taken at construction.
    """

    def __init__(self):
        # A freshly created node counts as "just accessed".
        self.last_access_time = time.time()
        self.lock_ref = 0
        self.key = None
        self.value = None
        self.parent = None
        self.children = defaultdict(TreeNode)

    def __lt__(self, other: "TreeNode"):
        """Order nodes by access time: the older node compares as smaller."""
        return other.last_access_time > self.last_access_time
|
19
|
+
|
20
|
+
|
21
|
+
def _key_match(key0, key1):
|
22
|
+
i = 0
|
23
|
+
for k0, k1 in zip(key0, key1):
|
24
|
+
if k0 != k1:
|
25
|
+
break
|
26
|
+
i += 1
|
27
|
+
return i
|
28
|
+
|
29
|
+
|
30
|
+
class RadixCache:
    """Radix tree that maps key prefixes to cached values.

    Keys are sequences (the demo at the bottom of the file uses strings;
    ``cache_req`` passes ``token_ids``), and every node stores a key segment
    together with an equally sliced value segment. Nodes with ``lock_ref > 0``
    are skipped by ``evict``; ``evictable_size_`` tracks the number of stored
    elements that are currently not locked.

    All tree walks are iterative so that a deep tree cannot hit Python's
    recursion limit.
    """

    def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
        """Store the two pools and build an empty tree.

        Args:
            req_to_token_pool: Object exposing ``req_to_token`` (indexed as
                ``[req_pool_idx, start:stop]``) and ``free(req_pool_idx)``.
            token_to_kv_pool: Object exposing ``dec_refs(indices)``.
            disable: When true, ``match_prefix``, ``insert`` and ``evict``
                become no-ops and nothing is stored in the tree.
        """
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool = token_to_kv_pool
        self.disable = disable
        self.reset()

    ##### Public API #####

    def reset(self):
        """Drop the whole tree and start again from an empty, locked root."""
        self.root_node = TreeNode()
        self.root_node.key = []
        self.root_node.value = []
        # The root carries a permanent lock so that it is never evicted.
        self.root_node.lock_ref = 1
        self.evictable_size_ = 0

    def match_prefix(self, key):
        """Find the longest cached prefix of ``key``.

        Returns:
            ``(value, last_node)``: the concatenation of the values along the
            matched path (an empty int64 tensor when nothing matched) and the
            deepest matched node. When the cache is disabled the result is
            ``([], root_node)``.

        Note: a partially matched node is split in place, so this lookup can
        modify the tree.
        """
        if self.disable:
            return [], self.root_node

        value = []
        last_node = [self.root_node]
        self._match_prefix_helper(self.root_node, key, value, last_node)
        if value:
            value = torch.concat(value)
        else:
            value = torch.tensor([], dtype=torch.int64)
        return value, last_node[0]

    def insert(self, key, value=None):
        """Insert ``key`` (with ``value``) and return the already-cached prefix length.

        When ``value`` is omitted, a list copy of ``key`` is stored as the
        value. Returns 0 without touching the tree when the cache is disabled.
        """
        if self.disable:
            return 0

        if value is None:
            value = list(key)
        return self._insert_helper(self.root_node, key, value)

    def cache_req(
        self,
        token_ids,
        last_uncached_pos,
        req_pool_idx,
        del_in_memory_pool=True,
        old_last_node=None,
    ):
        """Insert a request's tokens into the cache and fix up pool references.

        Args:
            token_ids: Key to insert.
            last_uncached_pos: Start of the slice whose references are
                released once the tree holds them.
            req_pool_idx: Row of ``req_to_token_pool.req_to_token`` that holds
                this request's indices.
            del_in_memory_pool: If true, the request's row is freed afterwards
                and nothing is returned. If false, the row is rewritten with
                the indices stored in the tree and the lock is moved from
                ``old_last_node`` to the new last node.
            old_last_node: Node whose lock is released when
                ``del_in_memory_pool`` is false.

        Returns:
            ``(cached_indices, new_last_node)`` when ``del_in_memory_pool`` is
            false, otherwise ``None``.
        """
        # Insert the request into radix cache
        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
        new_prefix_len = self.insert(token_ids, indices.clone())

        if self.disable:
            if del_in_memory_pool:
                # Nothing was cached, so every index of the request is released.
                # Execution continues below, where ``new_prefix_len`` is 0.
                self.token_to_kv_pool.dec_refs(indices)
            else:
                return torch.tensor([], dtype=torch.int64), self.root_node

        # Radix Cache takes one ref in memory pool
        self.token_to_kv_pool.dec_refs(indices[last_uncached_pos:new_prefix_len])

        if del_in_memory_pool:
            self.req_to_token_pool.free(req_pool_idx)
        else:
            cached_indices, new_last_node = self.match_prefix(token_ids)
            assert len(cached_indices) == len(token_ids)

            self.req_to_token_pool.req_to_token[
                req_pool_idx, last_uncached_pos : len(cached_indices)
            ] = cached_indices[last_uncached_pos:]
            self.dec_lock_ref(old_last_node)
            self.inc_lock_ref(new_last_node)
            return cached_indices, new_last_node

    def pretty_print(self):
        """Print every node below the root, then the total stored size."""
        self._print_helper(self.root_node, 0)
        print(f"#tokens: {self.total_size()}")

    def total_size(self):
        """Return the summed ``len(value)`` over all nodes of the tree."""
        return self._total_size_helper(self.root_node)

    def evict(self, num_tokens, evict_callback):
        """Evict unlocked leaves, least recently accessed first.

        Eviction stops once the sum of ``evict_callback(node.value)`` return
        values reaches ``num_tokens`` or no candidate is left. A parent that
        becomes a leaf through eviction becomes a candidate itself.
        """
        if self.disable:
            return

        leaves = self._collect_leaves()
        heapq.heapify(leaves)

        num_evicted = 0
        while num_evicted < num_tokens and leaves:
            x = heapq.heappop(leaves)

            if x is self.root_node:
                break
            if x.lock_ref > 0:
                continue

            num_evicted += evict_callback(x.value)
            self._delete_leaf(x)

            if len(x.parent.children) == 0:
                heapq.heappush(leaves, x.parent)

    def inc_lock_ref(self, node: TreeNode):
        """Lock ``node`` and all its ancestors below the root.

        Returns:
            The change in evictable size (zero or negative).
        """
        delta = 0
        while node != self.root_node:
            if node.lock_ref == 0:
                # First lock on this node: its elements stop being evictable.
                self.evictable_size_ -= len(node.value)
                delta -= len(node.value)
            node.lock_ref += 1
            node = node.parent
        return delta

    def dec_lock_ref(self, node: TreeNode):
        """Unlock ``node`` and all its ancestors below the root.

        Returns:
            The change in evictable size (zero or positive).
        """
        delta = 0
        while node != self.root_node:
            if node.lock_ref == 1:
                # Last lock released: the node's elements become evictable.
                self.evictable_size_ += len(node.value)
                delta += len(node.value)
            node.lock_ref -= 1
            node = node.parent
        return delta

    def evictable_size(self):
        """Return the number of stored elements that are not locked."""
        return self.evictable_size_

    ##### Internal Helper Functions #####

    def _match_prefix_helper(self, node, key, value, last_node):
        """Walk down from ``node`` along ``key``.

        Matched value segments are appended to ``value`` and the deepest
        matched node is written to ``last_node[0]``. Every visited node gets
        its ``last_access_time`` refreshed.
        """
        while True:
            node.last_access_time = time.time()
            if len(key) == 0:
                return
            if key[0] not in node.children:
                return

            child = node.children[key[0]]
            prefix_len = _key_match(child.key, key)
            if prefix_len < len(child.key):
                # Partial match: split the child so the matched part becomes
                # its own node, and stop there.
                new_node = self._split_node(child.key, child, prefix_len)
                value.append(new_node.value)
                last_node[0] = new_node
                return

            value.append(child.value)
            last_node[0] = child
            node = child
            key = key[prefix_len:]

    def _split_node(self, key, child: TreeNode, split_len):
        """Split ``child`` after ``split_len`` elements and return the new upper node."""
        # new_node -> child
        new_node = TreeNode()
        new_node.children = {key[split_len:][0]: child}
        new_node.parent = child.parent
        new_node.lock_ref = child.lock_ref
        new_node.key = child.key[:split_len]
        new_node.value = child.value[:split_len]
        child.parent = new_node
        child.key = child.key[split_len:]
        child.value = child.value[split_len:]
        new_node.parent.children[key[:split_len][0]] = new_node
        return new_node

    def _insert_helper(self, node, key, value):
        """Insert ``key``/``value`` below ``node``.

        Returns:
            The number of leading elements of ``key`` that were already
            present in the tree.
        """
        matched_total = 0
        while True:
            node.last_access_time = time.time()
            if len(key) == 0:
                return matched_total
            if key[0] not in node.children:
                break

            child = node.children[key[0]]
            prefix_len = _key_match(child.key, key)
            matched_total += prefix_len

            if prefix_len == len(child.key):
                if prefix_len == len(key):
                    # The key ends exactly on an existing node.
                    return matched_total
                node = child
            else:
                # The key diverges inside the child: split it and continue
                # from the new upper node.
                node = self._split_node(child.key, child, prefix_len)
            key = key[prefix_len:]
            value = value[prefix_len:]

        # The remaining, non-empty suffix becomes a new leaf.
        new_node = TreeNode()
        new_node.parent = node
        new_node.key = key
        new_node.value = value
        node.children[key[0]] = new_node
        self.evictable_size_ += len(value)
        return matched_total

    def _print_helper(self, node: TreeNode, indent):
        """Print the subtree below ``node`` in pre-order, indenting by depth."""
        # Children are pushed in reverse so they are popped in dict order.
        stack = [(child, indent) for child in reversed(list(node.children.values()))]
        while stack:
            child, depth = stack.pop()
            print(" " * depth, len(child.key), child.key[:10], f"r={child.lock_ref}")
            stack.extend(
                (grandchild, depth + 2)
                for grandchild in reversed(list(child.children.values()))
            )

    def _delete_leaf(self, node):
        """Unlink leaf ``node`` from its parent and update the evictable size."""
        for k, v in node.parent.children.items():
            if v == node:
                break
        del node.parent.children[k]
        self.evictable_size_ -= len(node.key)

    def _total_size_helper(self, node):
        """Return the summed ``len(value)`` of ``node`` and all its descendants."""
        total = 0
        stack = [node]
        while stack:
            cur_node = stack.pop()
            total += len(cur_node.value)
            stack.extend(cur_node.children.values())
        return total

    def _collect_leaves(self):
        """Return all childless nodes in depth-first, dict-insertion order."""
        ret_list = []
        stack = [self.root_node]
        while stack:
            cur_node = stack.pop()
            if len(cur_node.children) == 0:
                ret_list.append(cur_node)
            # Reverse so that the first child is visited first.
            stack.extend(reversed(list(cur_node.children.values())))
        return ret_list
|
247
|
+
|
248
|
+
|
249
|
+
if __name__ == "__main__":
    # Small manual demo: build a cache without backing pools, insert a few
    # overlapping string keys and print the resulting tree.
    tree = RadixCache(None, None, False)

    for text in ("Hello", "Hello", "Hello_L.A.!"):
        tree.insert(text)
    # tree.insert("Hello_world! Happy")
    # tree.insert("I love you!")
    tree.pretty_print()

    # print(tree.match_prefix("I love you! aha"))

    # def evict_callback(x):
    #    print("evict", x)
    #    return len(x)

    # tree.evict(5, evict_callback)
    # tree.evict(10, evict_callback)
    # tree.pretty_print()
|
@@ -0,0 +1,59 @@
|
|
1
|
+
import random
|
2
|
+
from collections import defaultdict
|
3
|
+
|
4
|
+
|
5
|
+
class ScheduleHeuristic:
    """Orders the queue of waiting requests according to a named policy.

    Supported policies (``schedule_heuristic``):
        "lpm":        longest ``prefix_indices`` first.
        "random":     shuffled order.
        "fcfs":       queue order is kept.
        "dfs-weight": depth-first walk over ``tree_cache``, visiting the
                      subtree with the most waiting requests first.
    """

    def __init__(
        self,
        schedule_heuristic,
        max_running_seqs,
        max_prefill_num_tokens,
        max_total_num_tokens,
        tree_cache,
    ):
        """Store the policy name, the limits and the tree cache unchanged."""
        self.tree_cache = tree_cache
        self.max_total_num_tokens = max_total_num_tokens
        self.max_prefill_num_tokens = max_prefill_num_tokens
        self.max_running_seqs = max_running_seqs
        self.schedule_heuristic = schedule_heuristic

    def get_priority_queue(self, forward_queue):
        """Return the requests of ``forward_queue`` in scheduling order.

        "lpm" and "random" reorder ``forward_queue`` in place and return it,
        "fcfs" returns it untouched, "dfs-weight" returns a new list.

        Raises:
            ValueError: If the configured policy name is not known.
        """
        policy = self.schedule_heuristic

        if policy == "fcfs":
            return forward_queue

        if policy == "lpm":
            # longest prefix match
            forward_queue.sort(key=lambda req: len(req.prefix_indices), reverse=True)
            return forward_queue

        if policy == "random":
            random.shuffle(forward_queue)
            return forward_queue

        if policy != "dfs-weight":
            raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")

        # Group the waiting requests by the tree node they are attached to.
        last_node_to_reqs = defaultdict(list)
        for req in forward_queue:
            last_node_to_reqs[req.last_node].append(req)

        # A node's own weight is its number of requests; calc_weight then
        # adds the weights of all descendants.
        node_to_weight = defaultdict(int)
        for node, reqs in last_node_to_reqs.items():
            node_to_weight[node] = len(reqs)
        root = self.tree_cache.root_node
        self.calc_weight(root, node_to_weight)

        ordered = []
        self.get_dfs_priority(root, node_to_weight, last_node_to_reqs, ordered)
        assert len(ordered) == len(forward_queue)
        return ordered

    def calc_weight(self, cur_node, node_to_weight):
        """Add the weight of every descendant to ``node_to_weight[cur_node]``."""
        for child in cur_node.children.values():
            self.calc_weight(child, node_to_weight)
            node_to_weight[cur_node] = node_to_weight[cur_node] + node_to_weight[child]

    def get_dfs_priority(self, cur_node, node_to_priority, last_node_to_reqs, q):
        """Append requests to ``q``: heaviest child subtree first, own requests last."""
        ranked_children = sorted(
            cur_node.children.values(),
            key=lambda child: node_to_priority[child],
            reverse=True,
        )
        for child in ranked_children:
            self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
        q.extend(last_node_to_reqs[cur_node])
|