sglang 0.1.16__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (65)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +3 -3
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +148 -12
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/global_config.py +8 -1
  8. sglang/lang/interpreter.py +114 -67
  9. sglang/lang/ir.py +17 -2
  10. sglang/srt/constrained/fsm_cache.py +3 -0
  11. sglang/srt/flush_cache.py +1 -1
  12. sglang/srt/hf_transformers_utils.py +75 -1
  13. sglang/srt/layers/extend_attention.py +17 -0
  14. sglang/srt/layers/fused_moe.py +485 -0
  15. sglang/srt/layers/logits_processor.py +12 -7
  16. sglang/srt/layers/radix_attention.py +10 -3
  17. sglang/srt/layers/token_attention.py +16 -1
  18. sglang/srt/managers/controller/dp_worker.py +110 -0
  19. sglang/srt/managers/controller/infer_batch.py +619 -0
  20. sglang/srt/managers/controller/manager_multi.py +191 -0
  21. sglang/srt/managers/controller/manager_single.py +97 -0
  22. sglang/srt/managers/controller/model_runner.py +462 -0
  23. sglang/srt/managers/controller/radix_cache.py +267 -0
  24. sglang/srt/managers/controller/schedule_heuristic.py +59 -0
  25. sglang/srt/managers/controller/tp_worker.py +791 -0
  26. sglang/srt/managers/detokenizer_manager.py +45 -45
  27. sglang/srt/managers/io_struct.py +15 -11
  28. sglang/srt/managers/router/infer_batch.py +103 -59
  29. sglang/srt/managers/router/manager.py +1 -1
  30. sglang/srt/managers/router/model_rpc.py +175 -122
  31. sglang/srt/managers/router/model_runner.py +91 -104
  32. sglang/srt/managers/router/radix_cache.py +7 -1
  33. sglang/srt/managers/router/scheduler.py +6 -6
  34. sglang/srt/managers/tokenizer_manager.py +152 -89
  35. sglang/srt/model_config.py +4 -5
  36. sglang/srt/models/commandr.py +10 -13
  37. sglang/srt/models/dbrx.py +9 -15
  38. sglang/srt/models/gemma.py +8 -15
  39. sglang/srt/models/grok.py +671 -0
  40. sglang/srt/models/llama2.py +19 -15
  41. sglang/srt/models/llava.py +84 -20
  42. sglang/srt/models/llavavid.py +11 -20
  43. sglang/srt/models/mixtral.py +248 -118
  44. sglang/srt/models/mixtral_quant.py +373 -0
  45. sglang/srt/models/qwen.py +9 -13
  46. sglang/srt/models/qwen2.py +11 -13
  47. sglang/srt/models/stablelm.py +9 -15
  48. sglang/srt/models/yivl.py +17 -22
  49. sglang/srt/openai_api_adapter.py +140 -95
  50. sglang/srt/openai_protocol.py +10 -1
  51. sglang/srt/server.py +77 -42
  52. sglang/srt/server_args.py +51 -6
  53. sglang/srt/utils.py +124 -66
  54. sglang/test/test_programs.py +44 -0
  55. sglang/test/test_utils.py +32 -1
  56. sglang/utils.py +22 -4
  57. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/METADATA +15 -9
  58. sglang-0.1.17.dist-info/RECORD +81 -0
  59. sglang/srt/backend_config.py +0 -13
  60. sglang/srt/models/dbrx_config.py +0 -281
  61. sglang/srt/weight_utils.py +0 -417
  62. sglang-0.1.16.dist-info/RECORD +0 -72
  63. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
  64. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
  65. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/radix_cache.py (new file)
@@ -0,0 +1,267 @@
+ import heapq
+ import time
+ from collections import defaultdict
+
+ import torch
+
+
+ class TreeNode:
+     def __init__(self):
+         self.children = defaultdict(TreeNode)
+         self.parent = None
+         self.key = None
+         self.value = None
+         self.lock_ref = 0
+         self.last_access_time = time.time()
+
+     def __lt__(self, other: "TreeNode"):
+         return self.last_access_time < other.last_access_time
+
+
+ def _key_match(key0, key1):
+     i = 0
+     for k0, k1 in zip(key0, key1):
+         if k0 != k1:
+             break
+         i += 1
+     return i
+
+
+ class RadixCache:
+     def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
+         self.req_to_token_pool = req_to_token_pool
+         self.token_to_kv_pool = token_to_kv_pool
+         self.disable = disable
+         self.reset()
+
+     ##### Public API #####
+
+     def reset(self):
+         self.root_node = TreeNode()
+         self.root_node.key = []
+         self.root_node.value = []
+         self.root_node.lock_ref = 1
+         self.evictable_size_ = 0
+
+     def match_prefix(self, key):
+         if self.disable:
+             return [], self.root_node
+
+         value = []
+         last_node = [self.root_node]
+         self._match_prefix_helper(self.root_node, key, value, last_node)
+         if value:
+             value = torch.concat(value)
+         else:
+             value = torch.tensor([], dtype=torch.int64)
+         return value, last_node[0]
+
+     def insert(self, key, value=None):
+         if self.disable:
+             return 0
+
+         if value is None:
+             value = [x for x in key]
+         return self._insert_helper(self.root_node, key, value)
+
+     def cache_req(
+         self,
+         token_ids,
+         last_uncached_pos,
+         req_pool_idx,
+         del_in_memory_pool=True,
+         old_last_node=None,
+     ):
+         # Insert the request into radix cache
+         indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
+         new_prefix_len = self.insert(token_ids, indices.clone())
+
+         if self.disable:
+             if del_in_memory_pool:
+                 self.token_to_kv_pool.dec_refs(indices)
+             else:
+                 return torch.tensor([], dtype=torch.int64), self.root_node
+
+         # Radix Cache takes one ref in memory pool
+         self.token_to_kv_pool.dec_refs(indices[last_uncached_pos:new_prefix_len])
+
+         if del_in_memory_pool:
+             self.req_to_token_pool.free(req_pool_idx)
+         else:
+             cached_indices, new_last_node = self.match_prefix(token_ids)
+             assert len(cached_indices) == len(token_ids)
+
+             self.req_to_token_pool.req_to_token[
+                 req_pool_idx, last_uncached_pos : len(cached_indices)
+             ] = cached_indices[last_uncached_pos:]
+             self.dec_lock_ref(old_last_node)
+             self.inc_lock_ref(new_last_node)
+             return cached_indices, new_last_node
+
+     def pretty_print(self):
+         self._print_helper(self.root_node, 0)
+         print(f"#tokens: {self.total_size()}")
+
+     def total_size(self):
+         return self._total_size_helper(self.root_node)
+
+     def evict(self, num_tokens, evict_callback):
+         if self.disable:
+             return
+
+         leaves = self._collect_leaves()
+         heapq.heapify(leaves)
+
+         num_evicted = 0
+         while num_evicted < num_tokens and len(leaves):
+             x = heapq.heappop(leaves)
+
+             if x == self.root_node:
+                 break
+             if x.lock_ref > 0:
+                 continue
+
+             num_evicted += evict_callback(x.value)
+             self._delete_leaf(x)
+
+             if len(x.parent.children) == 0:
+                 heapq.heappush(leaves, x.parent)
+
+     def inc_lock_ref(self, node: TreeNode):
+         delta = 0
+         while node != self.root_node:
+             if node.lock_ref == 0:
+                 self.evictable_size_ -= len(node.value)
+                 delta -= len(node.value)
+             node.lock_ref += 1
+             node = node.parent
+         return delta
+
+     def dec_lock_ref(self, node: TreeNode):
+         delta = 0
+         while node != self.root_node:
+             if node.lock_ref == 1:
+                 self.evictable_size_ += len(node.value)
+                 delta += len(node.value)
+             node.lock_ref -= 1
+             node = node.parent
+         return delta
+
+     def evictable_size(self):
+         return self.evictable_size_
+
+     ##### Internal Helper Functions #####
+
+     def _match_prefix_helper(self, node, key, value, last_node):
+         node.last_access_time = time.time()
+         if len(key) == 0:
+             return
+
+         if key[0] in node.children.keys():
+             child = node.children[key[0]]
+             prefix_len = _key_match(child.key, key)
+             if prefix_len < len(child.key):
+                 new_node = self._split_node(child.key, child, prefix_len)
+                 value.append(new_node.value)
+                 last_node[0] = new_node
+             else:
+                 value.append(child.value)
+                 last_node[0] = child
+                 self._match_prefix_helper(child, key[prefix_len:], value, last_node)
+
+     def _split_node(self, key, child: TreeNode, split_len):
+         # new_node -> child
+         new_node = TreeNode()
+         new_node.children = {key[split_len:][0]: child}
+         new_node.parent = child.parent
+         new_node.lock_ref = child.lock_ref
+         new_node.key = child.key[:split_len]
+         new_node.value = child.value[:split_len]
+         child.parent = new_node
+         child.key = child.key[split_len:]
+         child.value = child.value[split_len:]
+         new_node.parent.children[key[:split_len][0]] = new_node
+         return new_node
+
+     def _insert_helper(self, node, key, value):
+         node.last_access_time = time.time()
+         if len(key) == 0:
+             return 0
+
+         if key[0] in node.children.keys():
+             child = node.children[key[0]]
+             prefix_len = _key_match(child.key, key)
+
+             if prefix_len == len(child.key):
+                 if prefix_len == len(key):
+                     return prefix_len
+                 else:
+                     key = key[prefix_len:]
+                     value = value[prefix_len:]
+                     return prefix_len + self._insert_helper(child, key, value)
+
+             new_node = self._split_node(child.key, child, prefix_len)
+             return prefix_len + self._insert_helper(
+                 new_node, key[prefix_len:], value[prefix_len:]
+             )
+
+         if len(key):
+             new_node = TreeNode()
+             new_node.parent = node
+             new_node.key = key
+             new_node.value = value
+             node.children[key[0]] = new_node
+             self.evictable_size_ += len(value)
+         return 0
+
+     def _print_helper(self, node: TreeNode, indent):
+         for _, child in node.children.items():
+             print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
+             self._print_helper(child, indent=indent + 2)
+
+     def _delete_leaf(self, node):
+         for k, v in node.parent.children.items():
+             if v == node:
+                 break
+         del node.parent.children[k]
+         self.evictable_size_ -= len(node.key)
+
+     def _total_size_helper(self, node):
+         x = len(node.value)
+         for child in node.children.values():
+             x += self._total_size_helper(child)
+         return x
+
+     def _collect_leaves(self):
+         ret_list = []
+
+         def dfs_(cur_node):
+             if len(cur_node.children) == 0:
+                 ret_list.append(cur_node)
+
+             for x in cur_node.children.values():
+                 dfs_(x)
+
+         dfs_(self.root_node)
+         return ret_list
+
+
+ if __name__ == "__main__":
+     tree = RadixCache(None, None, False)
+
+     tree.insert("Hello")
+     tree.insert("Hello")
+     tree.insert("Hello_L.A.!")
+     # tree.insert("Hello_world! Happy")
+     # tree.insert("I love you!")
+     tree.pretty_print()
+
+     # print(tree.match_prefix("I love you! aha"))
+
+     # def evict_callback(x):
+     #     print("evict", x)
+     #     return len(x)
+
+     # tree.evict(5, evict_callback)
+     # tree.evict(10, evict_callback)
+     # tree.pretty_print()
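
A quick sketch of how the new RadixCache behaves when driven directly, in the same spirit as its __main__ block above. This is illustrative only and not part of the diff; it assumes the hunk above is sglang/srt/managers/controller/radix_cache.py (per the +267 entry in the file list), passes the two memory pools as None so only the prefix-tree logic runs, and uses small token-ID lists with tensor values standing in for the KV-cache indices that cache_req would normally supply.

import torch
from sglang.srt.managers.controller.radix_cache import RadixCache

# Illustrative only -- not part of the packaged code.
tree = RadixCache(None, None, disable=False)

# Cache two "requests"; the second shares the [1, 2, 3] prefix, so the tree
# splits that edge into a common node.
tree.insert([1, 2, 3, 4], torch.tensor([1, 2, 3, 4]))
tree.insert([1, 2, 3, 9, 10], torch.tensor([1, 2, 3, 9, 10]))

# Longest cached prefix of a new request.
indices, last_node = tree.match_prefix([1, 2, 3, 9, 99])
print(indices.tolist())        # [1, 2, 3, 9]
print(tree.evictable_size())   # 6 tokens evictable (no lock refs held)

# LRU eviction pops unlocked leaves; the callback reports how many tokens it freed.
tree.evict(2, lambda value: len(value))
tree.pretty_print()
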
sglang/srt/managers/controller/schedule_heuristic.py (new file)
@@ -0,0 +1,59 @@
+ import random
+ from collections import defaultdict
+
+
+ class ScheduleHeuristic:
+     def __init__(
+         self,
+         schedule_heuristic,
+         max_running_seqs,
+         max_prefill_num_tokens,
+         max_total_num_tokens,
+         tree_cache,
+     ):
+         self.schedule_heuristic = schedule_heuristic
+         self.max_running_seqs = max_running_seqs
+         self.max_prefill_num_tokens = max_prefill_num_tokens
+         self.max_total_num_tokens = max_total_num_tokens
+         self.tree_cache = tree_cache
+
+     def get_priority_queue(self, forward_queue):
+         if self.schedule_heuristic == "lpm":
+             # longest prefix match
+             forward_queue.sort(key=lambda x: -len(x.prefix_indices))
+             return forward_queue
+         elif self.schedule_heuristic == "random":
+             random.shuffle(forward_queue)
+             return forward_queue
+         elif self.schedule_heuristic == "fcfs":
+             return forward_queue
+         elif self.schedule_heuristic == "dfs-weight":
+             last_node_to_reqs = defaultdict(list)
+             for req in forward_queue:
+                 last_node_to_reqs[req.last_node].append(req)
+
+             node_to_weight = defaultdict(int)
+             for node in last_node_to_reqs:
+                 node_to_weight[node] = len(last_node_to_reqs[node])
+             self.calc_weight(self.tree_cache.root_node, node_to_weight)
+
+             q = []
+             self.get_dfs_priority(
+                 self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
+             )
+             assert len(q) == len(forward_queue)
+             return q
+         else:
+             raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")
+
+     def calc_weight(self, cur_node, node_to_weight):
+         for child in cur_node.children.values():
+             self.calc_weight(child, node_to_weight)
+             node_to_weight[cur_node] += node_to_weight[child]
+
+     def get_dfs_priority(self, cur_node, node_to_priority, last_node_to_reqs, q):
+         childs = [child for child in cur_node.children.values()]
+         childs.sort(key=lambda x: -node_to_priority[x])
+         for child in childs:
+             self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
+         q.extend(last_node_to_reqs[cur_node])
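
A similar sketch for the scheduler policy module, again illustrative only and not part of the diff. It assumes the hunk above is sglang/srt/managers/controller/schedule_heuristic.py (the +59 entry in the file list). DummyReq is invented here purely for illustration and models only the two attributes the heuristics read (prefix_indices and last_node); the real queue holds the controller's request objects. Only the "lpm" (longest-prefix-match) policy is exercised; tree_cache is consulted only by the "dfs-weight" policy.

from dataclasses import dataclass, field

from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic

# Illustrative only -- a stand-in for the scheduler's real request objects.
@dataclass
class DummyReq:
    rid: str
    prefix_indices: list = field(default_factory=list)
    last_node: object = None

tree = RadixCache(None, None, disable=False)
scheduler = ScheduleHeuristic(
    "lpm",                       # rank waiting requests by longest cached prefix
    max_running_seqs=16,
    max_prefill_num_tokens=4096,
    max_total_num_tokens=65536,
    tree_cache=tree,             # only used by the "dfs-weight" policy
)

reqs = [
    DummyReq("a", prefix_indices=[1, 2]),
    DummyReq("b", prefix_indices=[1, 2, 3, 4]),
    DummyReq("c"),
]
print([r.rid for r in scheduler.get_priority_queue(reqs)])  # ['b', 'a', 'c']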