sglang 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- sglang/__init__.py +57 -2
- sglang/api.py +8 -5
- sglang/backend/anthropic.py +18 -4
- sglang/backend/openai.py +2 -1
- sglang/backend/runtime_endpoint.py +18 -5
- sglang/backend/vertexai.py +1 -0
- sglang/global_config.py +5 -1
- sglang/lang/chat_template.py +83 -2
- sglang/lang/interpreter.py +92 -35
- sglang/lang/ir.py +12 -9
- sglang/lang/tracer.py +6 -4
- sglang/launch_server_llavavid.py +31 -0
- sglang/srt/constrained/fsm_cache.py +1 -0
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/conversation.py +2 -2
- sglang/srt/flush_cache.py +16 -0
- sglang/srt/hf_transformers_utils.py +10 -2
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +1 -0
- sglang/srt/layers/logits_processor.py +114 -54
- sglang/srt/layers/radix_attention.py +2 -1
- sglang/srt/layers/token_attention.py +1 -0
- sglang/srt/managers/detokenizer_manager.py +5 -1
- sglang/srt/managers/io_struct.py +27 -3
- sglang/srt/managers/router/infer_batch.py +97 -48
- sglang/srt/managers/router/manager.py +11 -8
- sglang/srt/managers/router/model_rpc.py +169 -90
- sglang/srt/managers/router/model_runner.py +110 -166
- sglang/srt/managers/router/radix_cache.py +89 -51
- sglang/srt/managers/router/scheduler.py +17 -28
- sglang/srt/managers/tokenizer_manager.py +110 -33
- sglang/srt/memory_pool.py +5 -14
- sglang/srt/model_config.py +11 -0
- sglang/srt/models/commandr.py +372 -0
- sglang/srt/models/dbrx.py +412 -0
- sglang/srt/models/dbrx_config.py +281 -0
- sglang/srt/models/gemma.py +24 -25
- sglang/srt/models/llama2.py +25 -26
- sglang/srt/models/llava.py +8 -10
- sglang/srt/models/llavavid.py +307 -0
- sglang/srt/models/mixtral.py +29 -33
- sglang/srt/models/qwen.py +34 -25
- sglang/srt/models/qwen2.py +25 -26
- sglang/srt/models/stablelm.py +26 -26
- sglang/srt/models/yivl.py +3 -5
- sglang/srt/openai_api_adapter.py +356 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +91 -456
- sglang/srt/server_args.py +79 -49
- sglang/srt/utils.py +212 -47
- sglang/srt/weight_utils.py +417 -0
- sglang/test/test_programs.py +8 -7
- sglang/test/test_utils.py +195 -7
- sglang/utils.py +77 -26
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/METADATA +20 -18
- sglang-0.1.16.dist-info/RECORD +72 -0
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/WHEEL +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/top_level.txt +0 -0
sglang/srt/managers/router/infer_batch.py:

```diff
@@ -1,24 +1,36 @@
 from dataclasses import dataclass
-from enum import
+from enum import IntEnum, auto
 from typing import List

 import numpy as np
 import torch
+
 from sglang.srt.managers.router.radix_cache import RadixCache
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool


-class ForwardMode(
+class ForwardMode(IntEnum):
     PREFILL = auto()
     EXTEND = auto()
     DECODE = auto()


-class FinishReason(
-    LENGTH = auto()
+class FinishReason(IntEnum):
     EOS_TOKEN = auto()
+    LENGTH = auto()
     STOP_STR = auto()

+    @staticmethod
+    def to_str(reason):
+        if reason == FinishReason.EOS_TOKEN:
+            return None
+        elif reason == FinishReason.LENGTH:
+            return "length"
+        elif reason == FinishReason.STOP_STR:
+            return "stop"
+        else:
+            return None
+

 class Req:
     def __init__(self, rid, input_text, input_ids):
```
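The new `FinishReason.to_str` helper maps the internal finish state to the finish-reason strings surfaced in API responses: reaching EOS reports no reason, a length cutoff reports `"length"`, and a matched stop string reports `"stop"`. A standalone sketch of that mapping (just the enum and helper from the hunk above plus a couple of usage checks, not the package's own module):

```python
from enum import IntEnum, auto


class FinishReason(IntEnum):
    EOS_TOKEN = auto()
    LENGTH = auto()
    STOP_STR = auto()

    @staticmethod
    def to_str(reason):
        # EOS is reported as no finish reason; the others map to API-style strings.
        if reason == FinishReason.EOS_TOKEN:
            return None
        elif reason == FinishReason.LENGTH:
            return "length"
        elif reason == FinishReason.STOP_STR:
            return "stop"
        return None


assert FinishReason.to_str(FinishReason.EOS_TOKEN) is None
assert FinishReason.to_str(FinishReason.LENGTH) == "length"
assert FinishReason.to_str(FinishReason.STOP_STR) == "stop"
```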
```diff
@@ -30,6 +42,7 @@ class Req:
         # Since jump forward may retokenize the prompt with partial outputs,
         # we maintain the original prompt length to report the correct usage.
         self.prompt_tokens = len(input_ids)
+
         # The number of decoded tokens for token usage report. Note that
         # this does not include the jump forward tokens.
         self.completion_tokens_wo_jump_forward = 0
@@ -40,11 +53,11 @@ class Req:
         self.image_offset = 0
         self.pad_value = None

+        # Sampling parameters
         self.sampling_params = None
-        self.return_logprob = False
-        self.logprob_start_len = 0
         self.stream = False

+        # Check finish
         self.tokenizer = None
         self.finished = False
         self.finish_reason = None
@@ -54,11 +67,17 @@ class Req:
         self.prefix_indices = []
         self.last_node = None

-
-        self.
-        self.
-
-
+        # Logprobs
+        self.return_logprob = False
+        self.logprob_start_len = 0
+        self.top_logprobs_num = 0
+        self.normalized_prompt_logprob = None
+        self.prefill_token_logprobs = None
+        self.decode_token_logprobs = None
+        self.prefill_top_logprobs = None
+        self.decode_top_logprobs = None
+
+        # Constrained decoding
         self.regex_fsm = None
         self.regex_fsm_state = 0
         self.jump_forward_map = None
@@ -77,6 +96,9 @@ class Req:
         )
         if first_token.startswith("▁"):
             old_output_str = " " + old_output_str
+        if self.input_text is None:
+            # TODO(lmzheng): This can be wrong. Check with Liangsheng.
+            self.input_text = self.tokenizer.decode(self.input_ids)
         new_input_string = (
             self.input_text
             + self.output_and_jump_forward_str
@@ -159,7 +181,10 @@ class Batch:
     out_cache_loc: torch.Tensor = None
     out_cache_cont_start: torch.Tensor = None
     out_cache_cont_end: torch.Tensor = None
+
+    # for processing logprobs
     return_logprob: bool = False
+    top_logprobs_nums: List[int] = None

     # for multimodal
     pixel_values: List[torch.Tensor] = None
@@ -229,12 +254,11 @@ class Batch:
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
         if out_cache_loc is None:
-
-
-            out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
+            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.dec_refs)
+            out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)

             if out_cache_loc is None:
-                print("Prefill out of memory. This should
+                print("Prefill out of memory. This should never happen.")
                 self.tree_cache.pretty_print()
                 exit()

```
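The prefill allocation path above now evicts unreferenced entries from the radix tree cache and retries once, instead of retrying blindly. A minimal sketch of that evict-then-retry pattern; `SimplePool` and `SimpleCache` are hypothetical stand-ins, not sglang's real `TokenToKVPool`/`RadixCache`:

```python
class SimplePool:
    """Hypothetical stand-in for a KV-token pool with a fixed capacity."""

    def __init__(self, size):
        self.free = size

    def alloc(self, n):
        if self.free < n:
            return None  # mirrors alloc() returning None when the pool is exhausted
        self.free -= n
        return list(range(n))

    def dec_refs(self, indices):
        self.free += len(indices)


class SimpleCache:
    """Hypothetical cache that frees cached-but-unreferenced tokens on evict()."""

    def __init__(self, cached):
        self.cached = cached

    def evict(self, num_tokens, free_fn):
        n = min(num_tokens, len(self.cached))
        free_fn([self.cached.pop() for _ in range(n)])


pool = SimplePool(size=8)
cache = SimpleCache(cached=pool.alloc(6))  # 6 slots held only by the cache

out_cache_loc = pool.alloc(4)
if out_cache_loc is None:
    # Same shape as the new prefill path: evict enough tokens, then retry once.
    cache.evict(4, pool.dec_refs)
    out_cache_loc = pool.alloc(4)

assert out_cache_loc is not None
```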
```diff
@@ -245,10 +269,14 @@ class Batch:
             ] = out_cache_loc[pt : pt + extend_lens[i]]
             pt += extend_lens[i]

-        # Handle logit bias
-        logit_bias =
+        # Handle logit bias but only allocate when needed
+        logit_bias = None
         for i in range(bs):
             if reqs[i].sampling_params.dtype == "int":
+                if logit_bias is None:
+                    logit_bias = torch.zeros(
+                        (bs, vocab_size), dtype=torch.float32, device=device
+                    )
                 logit_bias[i] = int_token_logit_bias

         # Set fields
```
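With this change the dense `logit_bias` tensor is only materialized when some request in the batch actually constrains sampling (the `dtype == "int"` case); ordinary batches skip the `(batch_size, vocab_size)` allocation entirely. A rough sketch of the pattern, with hypothetical placeholder values and CPU tensors for illustration:

```python
import torch

bs, vocab_size = 4, 32000
needs_int_bias = [False, True, False, False]    # stand-in for the dtype == "int" check
int_token_logit_bias = torch.zeros(vocab_size)  # placeholder bias row

logit_bias = None
for i in range(bs):
    if needs_int_bias[i]:
        if logit_bias is None:
            # Materialize the dense (bs, vocab_size) tensor only on first use.
            logit_bias = torch.zeros((bs, vocab_size), dtype=torch.float32)
        logit_bias[i] = int_token_logit_bias

# Sampling later applies the bias only if it was ever created.
logits = torch.randn(bs, vocab_size)
if logit_bias is not None:
    logits.add_(logit_bias)
```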
```diff
@@ -266,6 +294,7 @@ class Batch:
         self.position_ids_offsets = position_ids_offsets
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
+        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]

         self.temperatures = torch.tensor(
             [r.sampling_params.temperature for r in reqs],
@@ -295,8 +324,8 @@ class Batch:
         if self.token_to_kv_pool.available_size() >= bs:
             return True

-
-
+        self.tree_cache.evict(bs, self.token_to_kv_pool.dec_refs)
+
         if self.token_to_kv_pool.available_size() >= bs:
             return True

@@ -310,27 +339,27 @@ class Batch:
         )

         retracted_reqs = []
-
-
+        seq_lens_cpu = self.seq_lens.cpu().numpy()
+        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
         while self.token_to_kv_pool.available_size() < len(self.reqs):
             idx = sorted_indices.pop()
             req = self.reqs[idx]
             retracted_reqs.append(req)

-
+            # TODO: apply more fine-grained retraction
+            last_uncached_pos = len(req.prefix_indices)
+            token_indices = self.req_to_token_pool.req_to_token[
+                req_pool_indices_cpu[idx]
+            ][last_uncached_pos : seq_lens_cpu[idx]]
+            self.token_to_kv_pool.dec_refs(token_indices)
+
+            self.tree_cache.dec_lock_ref(req.last_node)
             req.prefix_indices = None
             req.last_node = None
             req.extend_input_len = 0
             req.output_ids = []
             req.regex_fsm_state = 0

-            # TODO: apply more fine-grained retraction
-
-            token_indices = self.req_to_token_pool.req_to_token[
-                req_pool_indices_np[idx]
-            ][: seq_lens_np[idx]]
-            self.token_to_kv_pool.free(token_indices)
-
         self.filter_batch(sorted_indices)

         return retracted_reqs
```
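Retraction now returns to the KV pool only the slots that the radix cache does not already account for, i.e. the slice from `last_uncached_pos` (the end of the cached prefix) to the current sequence length, and then releases the lock on the request's cache node. A toy illustration of just the slicing, with made-up sizes:

```python
import torch

# Hypothetical request-to-token table: one row of KV-pool slot indices per request.
req_to_token = torch.arange(32).reshape(4, 8)

req_idx = 2            # row of the retracted request
seq_len = 7            # tokens currently held by that request
last_uncached_pos = 3  # a prefix of length 3 is still referenced by the radix cache

# Only the uncached suffix goes back to the KV pool; the cached prefix stays
# alive because the radix tree still holds a reference to it.
token_indices = req_to_token[req_idx][last_uncached_pos:seq_len]
assert token_indices.tolist() == [19, 20, 21, 22]
```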
```diff
@@ -349,20 +378,18 @@ class Batch:
                 if len(jump_forward_str) <= 1:
                     continue

-                # insert the old request into tree_cache
-                token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1]
                 if req_pool_indices_cpu is None:
-                    req_pool_indices_cpu = self.req_pool_indices.
-
-
-
-
-
-
+                    req_pool_indices_cpu = self.req_pool_indices.tolist()
+
+                # insert the old request into tree_cache
+                self.tree_cache.cache_req(
+                    token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                    last_uncached_pos=len(req.prefix_indices),
+                    req_pool_idx=req_pool_indices_cpu[i],
                 )
-
-
-                self.tree_cache.
+
+                # unlock the last node
+                self.tree_cache.dec_lock_ref(req.last_node)

                 # jump-forward
                 req.jump_forward_and_retokenize(jump_forward_str, next_state)
@@ -391,7 +418,7 @@ class Batch:
         self.out_cache_loc = self.token_to_kv_pool.alloc(bs)

         if self.out_cache_loc is None:
-            print("Decode out of memory. This should
+            print("Decode out of memory. This should never happen.")
             self.tree_cache.pretty_print()
             exit()

@@ -415,6 +442,7 @@ class Batch:
         self.prefix_lens = None
         self.position_ids_offsets = self.position_ids_offsets[new_indices]
         self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)

         for item in [
@@ -425,9 +453,12 @@ class Batch:
             "presence_penalties",
             "logit_bias",
         ]:
-
+            self_val = getattr(self, item, None)
+            # logit_bias can be None
+            if self_val is not None:
+                setattr(self, item, self_val[new_indices])

-    def merge(self, other):
+    def merge(self, other: "Batch"):
         self.reqs.extend(other.reqs)

         self.req_pool_indices = torch.concat(
@@ -439,6 +470,7 @@ class Batch:
             [self.position_ids_offsets, other.position_ids_offsets]
         )
         self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)

         for item in [
@@ -447,17 +479,34 @@ class Batch:
             "top_ks",
             "frequency_penalties",
             "presence_penalties",
-            "logit_bias",
         ]:
-
-
+            self_val = getattr(self, item, None)
+            other_val = getattr(other, item, None)
+            setattr(self, item, torch.concat([self_val, other_val]))
+
+        # logit_bias can be None
+        if self.logit_bias is not None or other.logit_bias is not None:
+            vocab_size = (
+                self.logit_bias.shape[1]
+                if self.logit_bias is not None
+                else other.logit_bias.shape[1]
            )
+            if self.logit_bias is None:
+                self.logit_bias = torch.zeros(
+                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            if other.logit_bias is None:
+                other.logit_bias = torch.zeros(
+                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])

     def sample(self, logits: torch.Tensor):
         # Post process logits
         logits = logits.contiguous()
         logits.div_(self.temperatures)
-
+        if self.logit_bias is not None:
+            logits.add_(self.logit_bias)

         has_regex = any(req.regex_fsm is not None for req in self.reqs)
         if has_regex:
```
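When two batches are merged and only one side carries a `logit_bias`, the other side is padded with zeros, the additive identity, so the two tensors can be concatenated row by row and later applied with a single `logits.add_`. A small sketch of that case with hypothetical shapes:

```python
import torch

vocab_size = 10
bias_a = None                                  # batch A never needed a logit bias
bias_b = torch.full((2, vocab_size), -100.0)   # batch B constrains 2 requests

if bias_a is not None or bias_b is not None:
    if bias_a is None:
        # Zero rows are a no-op under logits.add_(bias), so batch A is unaffected.
        bias_a = torch.zeros((3, vocab_size))  # batch A holds 3 requests
    merged = torch.concat([bias_a, bias_b])
    assert merged.shape == (5, vocab_size)
```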
sglang/srt/managers/router/manager.py:

```diff
@@ -4,7 +4,8 @@ import logging
 import uvloop
 import zmq
 import zmq.asyncio
-
+
+from sglang.global_config import global_config
 from sglang.srt.managers.router.model_rpc import ModelRpcClient
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import get_exception_traceback
@@ -29,7 +30,7 @@ class RouterManager:
         self.recv_reqs = []

         # Init some configs
-        self.
+        self.request_dependency_time = global_config.request_dependency_time

     async def loop_for_forward(self):
         while True:
@@ -41,12 +42,16 @@ class RouterManager:
                 self.send_to_detokenizer.send_pyobj(obj)

             # async sleep for receiving the subsequent request and avoiding cache miss
+            slept = False
             if len(out_pyobjs) != 0:
                 has_finished = any([obj.finished for obj in out_pyobjs])
                 if has_finished:
-
+                    if self.request_dependency_time > 0:
+                        slept = True
+                        await asyncio.sleep(self.request_dependency_time)

-
+            if not slept:
+                await asyncio.sleep(0.0006)

     async def loop_for_recv_requests(self):
         while True:
```
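After this change the forward loop sleeps for `global_config.request_dependency_time` only when some request just finished, giving dependent follow-up requests time to arrive and reuse the radix cache, and otherwise yields for a tiny fixed interval. A stripped-down sketch of that pacing logic, with a hypothetical `step` callable standing in for the model client:

```python
import asyncio

REQUEST_DEPENDENCY_TIME = 0.02  # stand-in for global_config.request_dependency_time


async def loop_for_forward(step, max_iters=3):
    for _ in range(max_iters):
        out_pyobjs = await step()
        slept = False
        if out_pyobjs:
            if any(obj["finished"] for obj in out_pyobjs):
                if REQUEST_DEPENDENCY_TIME > 0:
                    slept = True
                    # Give follow-up requests that depend on the finished one
                    # time to arrive before the next scheduling step.
                    await asyncio.sleep(REQUEST_DEPENDENCY_TIME)
        if not slept:
            await asyncio.sleep(0.0006)  # tiny yield so the loop stays responsive


async def fake_step():
    return [{"finished": True}]


asyncio.run(loop_for_forward(fake_step))
```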
```diff
@@ -55,9 +60,7 @@ class RouterManager:


 def start_router_process(
-    server_args: ServerArgs,
-    port_args: PortArgs,
-    pipe_writer,
+    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
 ):
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
@@ -65,7 +68,7 @@ def start_router_process(
     )

     try:
-        model_client = ModelRpcClient(server_args, port_args)
+        model_client = ModelRpcClient(server_args, port_args, model_overide_args)
         router = RouterManager(model_client, port_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
```