sglang 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +57 -2
- sglang/api.py +8 -5
- sglang/backend/anthropic.py +18 -4
- sglang/backend/openai.py +2 -1
- sglang/backend/runtime_endpoint.py +18 -5
- sglang/backend/vertexai.py +1 -0
- sglang/global_config.py +5 -1
- sglang/lang/chat_template.py +83 -2
- sglang/lang/interpreter.py +92 -35
- sglang/lang/ir.py +12 -9
- sglang/lang/tracer.py +6 -4
- sglang/launch_server_llavavid.py +31 -0
- sglang/srt/constrained/fsm_cache.py +1 -0
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/conversation.py +2 -2
- sglang/srt/flush_cache.py +16 -0
- sglang/srt/hf_transformers_utils.py +10 -2
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +1 -0
- sglang/srt/layers/logits_processor.py +114 -54
- sglang/srt/layers/radix_attention.py +2 -1
- sglang/srt/layers/token_attention.py +1 -0
- sglang/srt/managers/detokenizer_manager.py +5 -1
- sglang/srt/managers/io_struct.py +27 -3
- sglang/srt/managers/router/infer_batch.py +97 -48
- sglang/srt/managers/router/manager.py +11 -8
- sglang/srt/managers/router/model_rpc.py +169 -90
- sglang/srt/managers/router/model_runner.py +110 -166
- sglang/srt/managers/router/radix_cache.py +89 -51
- sglang/srt/managers/router/scheduler.py +17 -28
- sglang/srt/managers/tokenizer_manager.py +110 -33
- sglang/srt/memory_pool.py +5 -14
- sglang/srt/model_config.py +11 -0
- sglang/srt/models/commandr.py +372 -0
- sglang/srt/models/dbrx.py +412 -0
- sglang/srt/models/dbrx_config.py +281 -0
- sglang/srt/models/gemma.py +24 -25
- sglang/srt/models/llama2.py +25 -26
- sglang/srt/models/llava.py +8 -10
- sglang/srt/models/llavavid.py +307 -0
- sglang/srt/models/mixtral.py +29 -33
- sglang/srt/models/qwen.py +34 -25
- sglang/srt/models/qwen2.py +25 -26
- sglang/srt/models/stablelm.py +26 -26
- sglang/srt/models/yivl.py +3 -5
- sglang/srt/openai_api_adapter.py +356 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +91 -456
- sglang/srt/server_args.py +79 -49
- sglang/srt/utils.py +212 -47
- sglang/srt/weight_utils.py +417 -0
- sglang/test/test_programs.py +8 -7
- sglang/test/test_utils.py +195 -7
- sglang/utils.py +77 -26
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/METADATA +20 -18
- sglang-0.1.16.dist-info/RECORD +72 -0
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/WHEEL +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/top_level.txt +0 -0
@@ -27,44 +27,33 @@ class Scheduler:
             return forward_queue
         elif self.schedule_heuristic == "fcfs":
             return forward_queue
-        elif self.schedule_heuristic == "weight":
+        elif self.schedule_heuristic == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
             for req in forward_queue:
                 last_node_to_reqs[req.last_node].append(req)
-            for node in last_node_to_reqs:
-                last_node_to_reqs[node].sort(key=lambda x: -len(x.prefix_indices))
 
             node_to_weight = defaultdict(int)
-
-
-            )
+            for node in last_node_to_reqs:
+                node_to_weight[node] = len(last_node_to_reqs[node])
+            self.calc_weight(self.tree_cache.root_node, node_to_weight)
 
-
-            self.
-                self.tree_cache.root_node, node_to_weight, last_node_to_reqs,
+            q = []
+            self.get_dfs_priority(
+                self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
             )
-            assert len(
-            return
+            assert len(q) == len(forward_queue)
+            return q
         else:
             raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")
 
-    def
-        node_to_weight[cur_node] = 1
-        if cur_node in last_node_to_reqs:
-            node_to_weight[cur_node] += len(last_node_to_reqs[cur_node])
+    def calc_weight(self, cur_node, node_to_weight):
         for child in cur_node.children.values():
-            self.
+            self.calc_weight(child, node_to_weight)
             node_to_weight[cur_node] += node_to_weight[child]
 
-    def
-
-
-
-
-
-        # print(f"{node_to_wight[node]} {len(node.value) if node.value is not None else 0}")
-        for child in visit_list:
-            self._get_weight_priority_recursive(
-                child, node_to_wight, last_node_to_reqs, tmp_queue
-            )
-        tmp_queue.extend(last_node_to_reqs[cur_node])
+    def get_dfs_priority(self, cur_node, node_to_priority, last_node_to_reqs, q):
+        childs = [child for child in cur_node.children.values()]
+        childs.sort(key=lambda x: -node_to_priority[x])
+        for child in childs:
+            self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
+        q.extend(last_node_to_reqs[cur_node])
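The old "weight" heuristic is replaced by "dfs-weight": every radix-tree node is weighted by the number of queued requests whose cached prefix ends at it, weights are accumulated bottom-up with calc_weight, and get_dfs_priority emits requests in a depth-first order that visits heavier subtrees first. A minimal standalone sketch of the same ordering (Node and Req here are illustrative stand-ins, not sglang's real radix-cache classes):

from collections import defaultdict

class Node:
    def __init__(self):
        self.children = {}            # key -> child Node

class Req:
    def __init__(self, last_node):
        self.last_node = last_node    # deepest cached-prefix node for this request

def dfs_weight_order(forward_queue, root):
    # Group requests by the tree node their cached prefix ends at.
    last_node_to_reqs = defaultdict(list)
    for req in forward_queue:
        last_node_to_reqs[req.last_node].append(req)

    # Weight of a node = number of requests attached at or below it.
    node_to_weight = defaultdict(int)
    for node, reqs in last_node_to_reqs.items():
        node_to_weight[node] = len(reqs)

    def calc_weight(cur):
        for child in cur.children.values():
            calc_weight(child)
            node_to_weight[cur] += node_to_weight[child]

    def dfs(cur, q):
        # Heavier subtrees first, then this node's own requests.
        for child in sorted(cur.children.values(), key=lambda x: -node_to_weight[x]):
            dfs(child, q)
        q.extend(last_node_to_reqs[cur])

    calc_weight(root)
    q = []
    dfs(root, q)
    assert len(q) == len(forward_queue)
    return q

root, a, b = Node(), Node(), Node()
root.children = {"a": a, "b": b}
queue = [Req(a), Req(b), Req(b), Req(root)]
ordered = dfs_weight_order(queue, root)   # b's two requests, then a's, then root's

Requests that share a deep cached prefix end up adjacent in the resulting order, which keeps prefix-sharing work together in a batch.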
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -1,6 +1,7 @@
 import asyncio
 import concurrent.futures
 import dataclasses
+import logging
 import multiprocessing as mp
 import os
 from typing import List
@@ -10,6 +11,7 @@ import transformers
 import uvloop
 import zmq
 import zmq.asyncio
+
 from sglang.srt.hf_transformers_utils import (
     get_config,
     get_context_length,
@@ -30,13 +32,14 @@ from sglang.srt.utils import get_exception_traceback, is_multimodal_model, load_
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
+logger = logging.getLogger(__name__)
+
 
 @dataclasses.dataclass
 class ReqState:
     out_list: List
     finished: bool
     event: asyncio.Event
-    lock: asyncio.Lock
 
 
 global global_processor
@@ -57,21 +60,29 @@ def get_pixel_values(
 ):
     try:
         processor = processor or global_processor
-        image = load_image(image_data)
-
-
-
-
-
-            pixel_values =
-
-            pixel_values = process_anyres_image(
-                image, processor.image_processor, image_grid_pinpoints
-            )
+        image, image_size = load_image(image_data)
+        if image_size != None:
+            image_hash = hash(image_data)
+            pixel_values = processor.image_processor(image)["pixel_values"]
+            for _ in range(len(pixel_values)):
+                pixel_values[_] = pixel_values[_].astype(np.float16)
+            pixel_values = np.stack(pixel_values, axis=0)
+            return pixel_values, image_hash, image_size
         else:
-
-
-
+            image_hash = hash(image_data)
+            if image_aspect_ratio == "pad":
+                image = expand2square(
+                    image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
+                )
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            elif image_aspect_ratio == "anyres":
+                pixel_values = process_anyres_image(
+                    image, processor.image_processor, image_grid_pinpoints
+                )
+            else:
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            pixel_values = pixel_values.astype(np.float16)
+            return pixel_values, image_hash, image.size
     except Exception:
         print("Exception in TokenizerManager:\n" + get_exception_traceback())
 
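get_pixel_values now branches on image_aspect_ratio: "pad" squares the image with expand2square before running the HuggingFace image processor, "anyres" goes through process_anyres_image, and anything else is processed directly. expand2square itself is defined elsewhere in sglang and is not part of this diff; a typical LLaVA-style implementation of such a helper looks roughly like the sketch below (an assumption for illustration, not sglang's code):

from PIL import Image

def expand2square(pil_img, background_color):
    # Pad the image onto a square canvas filled with background_color,
    # keeping the original content centered.
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    result = Image.new(pil_img.mode, (side, side), background_color)
    if width > height:
        result.paste(pil_img, (0, (side - height) // 2))
    else:
        result.paste(pil_img, ((side - width) // 2, 0))
    return result

# Matching the call in the diff: the fill color is the processor's per-channel image mean.
# expand2square(image, tuple(int(x * 255) for x in processor.image_processor.image_mean))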
@@ -81,6 +92,7 @@ class TokenizerManager:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
+        model_overide_args: dict = None,
     ):
         self.server_args = server_args
 
@@ -93,9 +105,10 @@ class TokenizerManager:
 
         self.model_path = server_args.model_path
         self.hf_config = get_config(
-            self.model_path,
+            self.model_path,
+            trust_remote_code=server_args.trust_remote_code,
+            model_overide_args=model_overide_args,
         )
-
         self.context_len = get_context_length(self.hf_config)
 
         if is_multimodal_model(self.model_path):
@@ -144,11 +157,21 @@ class TokenizerManager:
         if self.to_create_loop:
             await self.create_handle_loop()
 
-        is_single =
-
+        is_single = obj.is_single
         if is_single:
             rid = obj.rid
-
+
+            if obj.input_ids is None:
+                input_ids = self.tokenizer.encode(obj.text)
+            else:
+                input_ids = obj.input_ids
+
+            if len(input_ids) >= self.context_len:
+                raise ValueError(
+                    f"The input ({len(input_ids)} tokens) is longer than the "
+                    f"model's context length ({self.context_len} tokens)"
+                )
+
             sampling_params = SamplingParams(**obj.sampling_params)
             if sampling_params.max_new_tokens != 0:
                 sampling_params.normalize(self.tokenizer)
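Together with the GenerateReqInput changes in io_struct.py, requests can now carry pre-tokenized input_ids instead of text, and prompts at or beyond the model's context length are rejected up front. A hypothetical client call against a locally running server (the default port 30000 and the token ids are illustrative):

import requests

resp = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "input_ids": [1, 15043, 3186],                 # pre-tokenized prompt (illustrative ids)
        "sampling_params": {"max_new_tokens": 16},
    },
)
print(resp.json())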
@@ -174,18 +197,26 @@ class TokenizerManager:
                 sampling_params=sampling_params,
                 return_logprob=obj.return_logprob,
                 logprob_start_len=obj.logprob_start_len,
+                top_logprobs_num=obj.top_logprobs_num,
                 stream=obj.stream,
             )
             self.send_to_router.send_pyobj(tokenized_obj)
 
-            lock = asyncio.Lock()
             event = asyncio.Event()
-            state = ReqState([], False, event
+            state = ReqState([], False, event)
             self.rid_to_state[rid] = state
 
             while True:
                 await event.wait()
-
+                out = self.convert_logprob_style(state.out_list[-1],
+                                                 obj.return_logprob,
+                                                 obj.top_logprobs_num,
+                                                 obj.return_text_in_logprobs)
+
+                if self.server_args.log_requests and state.finished:
+                    logger.info(f"in={obj.text}, out={out}")
+
+                yield out
                 state.out_list = []
                 if state.finished:
                     del self.rid_to_state[rid]
@@ -193,10 +224,22 @@ class TokenizerManager:
                 event.clear()
         else:
             assert obj.stream is False
-
+
+            if obj.input_ids is None:
+                bs = len(obj.text)
+            else:
+                bs = len(obj.input_ids)
+
             for i in range(bs):
                 rid = obj.rid[i]
-
+
+                if obj.input_ids is None:
+                    input_text = obj.text[i]
+                    input_ids = self.tokenizer.encode(obj.text[i])
+                else:
+                    input_text = None
+                    input_ids = obj.input_ids[i]
+
                 sampling_params = SamplingParams(**obj.sampling_params[i])
                 if sampling_params.max_new_tokens != 0:
                     sampling_params.normalize(self.tokenizer)
@@ -209,7 +252,7 @@ class TokenizerManager:
                 )
                 tokenized_obj = TokenizedGenerateReqInput(
                     rid=rid,
-                    input_text=
+                    input_text=input_text,
                     input_ids=input_ids,
                     pixel_values=pixel_values,
                     image_hash=image_hash,
@@ -217,13 +260,13 @@ class TokenizerManager:
                     sampling_params=sampling_params,
                     return_logprob=obj.return_logprob[i],
                     logprob_start_len=obj.logprob_start_len[i],
+                    top_logprobs_num=obj.top_logprobs_num[i],
                     stream=obj.stream,
                 )
                 self.send_to_router.send_pyobj(tokenized_obj)
 
-                lock = asyncio.Lock()
                 event = asyncio.Event()
-                state = ReqState([], False, event
+                state = ReqState([], False, event)
                 self.rid_to_state[rid] = state
 
             output_list = []
@@ -231,16 +274,16 @@ class TokenizerManager:
                 rid = obj.rid[i]
                 state = self.rid_to_state[rid]
                 await state.event.wait()
-                output_list.append(
+                output_list.append(
+                    self.convert_logprob_style(state.out_list[-1],
+                                               obj.return_logprob[i],
+                                               obj.top_logprobs_num[i],
+                                               obj.return_text_in_logprobs))
                 assert state.finished
                 del self.rid_to_state[rid]
 
             yield output_list
 
-    async def detokenize(self, obj: DetokenizeReqInput):
-        token_texts = self.tokenizer.convert_ids_to_tokens(obj.input_ids)
-        return [t.decode() if isinstance(t, bytes) else t for t in token_texts]
-
     async def flush_cache(self):
         flush_cache_req = FlushCacheReq()
         self.send_to_router.send_pyobj(flush_cache_req)
@@ -267,3 +310,37 @@ class TokenizerManager:
             state.event.set()
         else:
             raise ValueError(f"Invalid object: {recv_obj}")
+
+    def convert_logprob_style(self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs):
+        if return_logprob:
+            ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
+                ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
+            )
+            ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
+                ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
+            )
+            if top_logprobs_num > 0:
+                ret["meta_info"]["prefill_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+                )
+                ret["meta_info"]["decode_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                )
+        return ret
+
+    def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
+        if not decode_to_text:
+            return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
+
+        token_ids = [tid for _, tid in token_logprobs]
+        token_texts = self.tokenizer.batch_decode(token_ids)
+        return [
+            (logprob, token_id, token_text)
+            for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
+        ]
+
+    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text):
+        for i, t in enumerate(top_logprobs):
+            if t:
+                top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
+        return top_logprobs
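When return_text_in_logprobs is set, the new helpers turn each (logprob, token_id) pair in meta_info into a (logprob, token_id, token_text) triple via the tokenizer's batch_decode. The same conversion, sketched outside the class with a standalone HuggingFace tokenizer (the gpt2 model id and token ids are only examples):

from transformers import AutoTokenizer

def detokenize_logprob_tokens(tokenizer, token_logprobs, decode_to_text):
    # (logprob, token_id) -> (logprob, token_id, token_text), mirroring the helper above.
    if not decode_to_text:
        return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
    token_ids = [tid for _, tid in token_logprobs]
    token_texts = tokenizer.batch_decode(token_ids)
    return [
        (logprob, token_id, text)
        for (logprob, token_id), text in zip(token_logprobs, token_texts)
    ]

tokenizer = AutoTokenizer.from_pretrained("gpt2")
print(detokenize_logprob_tokens(tokenizer, [(-0.1, 15496), (-2.3, 995)], True))
# expected: [(-0.1, 15496, 'Hello'), (-2.3, 995, ' world')]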
sglang/srt/memory_pool.py
CHANGED
@@ -31,9 +31,6 @@ class ReqToTokenPool:
         self.can_use_mem_size += free_index.shape[0]
         self.mem_state[free_index] = 1
 
-        # if self.can_use_mem_size == len(self.mem_state):
-        #     print(f"ReqToTokenPool: freed all. size = {self.can_use_mem_size}.")
-
     def clear(self):
         self.mem_state.fill_(1)
         self.can_use_mem_size = len(self.mem_state)
@@ -42,7 +39,7 @@ class ReqToTokenPool:
 class TokenToKVPool:
     def __init__(self, size, dtype, head_num, head_dim, layer_num):
         self.mem_state = torch.zeros((size,), dtype=torch.int16, device="cuda")
-        self.
+        self.total_ref_ct = 0
 
         # [size, key/value, head_num, head_dim] for each layer
         self.kv_data = [
@@ -83,9 +80,6 @@ class TokenToKVPool:
         self.add_refs(select_index)
         return select_index.to(torch.int32), start_loc, start_loc + need_size
 
-    def free(self, free_index):
-        return self.decrease_refs(free_index)
-
     def used_size(self):
         return len(torch.nonzero(self.mem_state).squeeze(1))
 
@@ -93,20 +87,17 @@ class TokenToKVPool:
         return torch.sum(self.mem_state == 0).item()
 
     def add_refs(self, token_index: torch.Tensor):
-        self.
+        self.total_ref_ct += len(token_index)
         self.mem_state[token_index] += 1
 
-    def
-        self.
+    def dec_refs(self, token_index: torch.Tensor):
+        self.total_ref_ct -= len(token_index)
         self.mem_state[token_index] -= 1
 
         num_freed = torch.sum(self.mem_state[token_index] == 0)
 
-        # if self.alloc_ct == 0:
-        #     print(f"TokenToKVPool: freed all. size = {len(self.mem_state)}.")
-
         return num_freed
 
     def clear(self):
         self.mem_state.fill_(0)
-        self.
+        self.total_ref_ct = 0
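TokenToKVPool now exposes add_refs/dec_refs directly (the free wrapper is gone) and keeps a running total_ref_ct. The reference-counting idea in isolation, with the per-layer KV tensors and CUDA placement left out for brevity:

import torch

class RefCountedSlots:
    # Stripped-down illustration of TokenToKVPool's ref counting only.
    def __init__(self, size):
        self.mem_state = torch.zeros((size,), dtype=torch.int16)
        self.total_ref_ct = 0

    def add_refs(self, token_index: torch.Tensor):
        self.total_ref_ct += len(token_index)
        self.mem_state[token_index] += 1

    def dec_refs(self, token_index: torch.Tensor):
        self.total_ref_ct -= len(token_index)
        self.mem_state[token_index] -= 1
        # How many of these slots just dropped to zero references (now reusable).
        return torch.sum(self.mem_state[token_index] == 0).item()

pool = RefCountedSlots(8)
idx = torch.tensor([0, 1, 2])
pool.add_refs(idx)
pool.add_refs(torch.tensor([2]))   # slot 2 is shared by a second request
print(pool.dec_refs(idx))          # 2: slots 0 and 1 freed, slot 2 still referenced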
sglang/srt/model_config.py
CHANGED
@@ -10,12 +10,16 @@ class ModelConfig:
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
+        model_overide_args: Optional[dict] = None,
     ) -> None:
         self.path = path
         self.trust_remote_code = trust_remote_code
         self.revision = revision
         self.hf_config = get_config(self.path, trust_remote_code, revision)
 
+        if model_overide_args is not None:
+            self.hf_config.update(model_overide_args)
+
         if context_length is not None:
             self.context_len = context_length
         else:
@@ -29,6 +33,13 @@ class ModelConfig:
         )
         self.num_attention_heads = self.hf_config.num_attention_heads
         self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
+
+        # for Dbrx and MPT models
+        if self.hf_config.model_type in ["dbrx", "mpt"]:
+            self.num_key_value_heads = getattr(
+                self.hf_config.attn_config, "kv_n_heads", None
+            )
+
         if self.num_key_value_heads is None:
             self.num_key_value_heads = self.num_attention_heads
         self.hidden_size = self.hf_config.hidden_size