sglang 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +55 -2
- sglang/api.py +3 -5
- sglang/backend/anthropic.py +33 -13
- sglang/backend/openai.py +2 -1
- sglang/backend/runtime_endpoint.py +18 -5
- sglang/backend/vertexai.py +1 -0
- sglang/global_config.py +1 -0
- sglang/lang/chat_template.py +74 -0
- sglang/lang/interpreter.py +40 -16
- sglang/lang/ir.py +1 -1
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +2 -1
- sglang/srt/constrained/fsm_cache.py +15 -3
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/conversation.py +2 -2
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +1 -0
- sglang/srt/layers/logits_processor.py +114 -54
- sglang/srt/layers/radix_attention.py +2 -1
- sglang/srt/layers/token_attention.py +1 -0
- sglang/srt/managers/detokenizer_manager.py +5 -1
- sglang/srt/managers/io_struct.py +12 -0
- sglang/srt/managers/router/infer_batch.py +70 -33
- sglang/srt/managers/router/manager.py +7 -2
- sglang/srt/managers/router/model_rpc.py +116 -73
- sglang/srt/managers/router/model_runner.py +121 -155
- sglang/srt/managers/router/radix_cache.py +46 -38
- sglang/srt/managers/tokenizer_manager.py +56 -11
- sglang/srt/memory_pool.py +5 -14
- sglang/srt/model_config.py +7 -0
- sglang/srt/models/commandr.py +376 -0
- sglang/srt/models/dbrx.py +413 -0
- sglang/srt/models/dbrx_config.py +281 -0
- sglang/srt/models/gemma.py +22 -20
- sglang/srt/models/llama2.py +23 -21
- sglang/srt/models/llava.py +12 -10
- sglang/srt/models/mixtral.py +27 -25
- sglang/srt/models/qwen.py +23 -21
- sglang/srt/models/qwen2.py +23 -21
- sglang/srt/models/stablelm.py +292 -0
- sglang/srt/models/yivl.py +6 -5
- sglang/srt/openai_api_adapter.py +356 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +68 -439
- sglang/srt/server_args.py +76 -49
- sglang/srt/utils.py +88 -32
- sglang/srt/weight_utils.py +402 -0
- sglang/test/test_programs.py +8 -7
- sglang/test/test_utils.py +196 -8
- {sglang-0.1.13.dist-info → sglang-0.1.15.dist-info}/METADATA +13 -15
- sglang-0.1.15.dist-info/RECORD +69 -0
- {sglang-0.1.13.dist-info → sglang-0.1.15.dist-info}/WHEEL +1 -1
- sglang-0.1.13.dist-info/RECORD +0 -63
- {sglang-0.1.13.dist-info → sglang-0.1.15.dist-info}/LICENSE +0 -0
- {sglang-0.1.13.dist-info → sglang-0.1.15.dist-info}/top_level.txt +0 -0

sglang/srt/layers/logits_processor.py
CHANGED
@@ -1,11 +1,12 @@
 import torch
-from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
 from torch import nn
-from vllm.
+from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )
 
+from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
+
 
 class LogitsProcessor(nn.Module):
     def __init__(self, config):
@@ -13,76 +14,136 @@ class LogitsProcessor(nn.Module):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
 
-    def 
-
+    def _get_normalized_prompt_logprobs(
+        self, prefill_token_logprobs, input_metadata: InputMetadata
+    ):
+        logprobs_cumsum = torch.cumsum(
+            prefill_token_logprobs, dim=0, dtype=torch.float32
+        )
 
-
-
+        start = input_metadata.extend_start_loc.clone()
+        end = start + input_metadata.extend_seq_lens - 2
+        start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
+        end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
+        sum_logp = (
+            logprobs_cumsum[end]
+            - logprobs_cumsum[start]
+            + prefill_token_logprobs[start]
+        )
+        normalized_prompt_logprobs = sum_logp / (
+            (input_metadata.extend_seq_lens - 1).clamp(min=1)
+        )
+
+        return normalized_prompt_logprobs
+
+    def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
+        if input_metadata.forward_mode == ForwardMode.DECODE:
+            decode_top_logprobs = []
+            for i in range(all_logprobs.shape[0]):
+                k = input_metadata.top_logprobs_nums[i]
+                t = all_logprobs[i].topk(k)
+                v_cpu = t.values.tolist()
+                p_cpu = t.indices.tolist()
+                decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
+            return None, decode_top_logprobs
+        else:
+            prefill_top_logprobs, decode_top_logprobs = [], []
+            pt = 0
+            # NOTE: the GPU-CPU overhead can be reduced
+            extend_seq_lens_cpu = input_metadata.extend_seq_lens.cpu().numpy()
+            for i in range(len(extend_seq_lens_cpu)):
+                if extend_seq_lens_cpu[i] == 0:
+                    prefill_top_logprobs.append([])
+                    decode_top_logprobs.append([])
+                    continue
+                k = input_metadata.top_logprobs_nums[i]
+                t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
+                vs_cpu = t.values.tolist()
+                ps_cpu = t.indices.tolist()
+                prefill_top_logprobs.append(
+                    [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
+                )
+                decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
+                pt += extend_seq_lens_cpu[i]
+            return prefill_top_logprobs, decode_top_logprobs
+
+    def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
+        # Get last index for next token prediction, except for DECODE mode.
+        last_index = None
         if input_metadata.forward_mode != ForwardMode.DECODE:
             last_index = (
-                torch.cumsum(
-                    input_metadata.seq_lens - input_metadata.prefix_lens,
-                    dim=0,
-                    dtype=torch.long,
-                )
+                torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
                 - 1
             )
 
+        # Get the last hidden states and last logits
+        if input_metadata.forward_mode == ForwardMode.DECODE:
+            last_hidden = hidden_states
+        else:
+            last_hidden = hidden_states[last_index]
+
+        last_logits = torch.matmul(last_hidden, weight.T)
+        if self.tp_size > 1:
+            last_logits = tensor_model_parallel_all_gather(last_logits)
+        last_logits = last_logits[:, : self.config.vocab_size]
+
+        # Return only last_logits if logprob is not requested
         if not input_metadata.return_logprob:
-
-
-                last_hidden = hidden_states
-            else:
-                last_hidden = hidden_states[last_index]
-            hidden_states = None
-
-            last_logits = torch.matmul(last_hidden, weight.T)
-            if self.tp_size > 1:
-                last_logits = tensor_model_parallel_all_gather(last_logits)
-            last_logits = last_logits[:, : self.config.vocab_size]
-            return last_logits, (None, None, None)
+            hidden_states = None
+            return last_logits, (None, None, None, None, None)
         else:
             # When logprob is requested, compute the logits for all tokens.
-
-
-
-
-
+            if input_metadata.forward_mode == ForwardMode.DECODE:
+                all_logits = last_logits
+            else:
+                all_logits = torch.matmul(hidden_states, weight.T)
+                if self.tp_size > 1:
+                    all_logits = tensor_model_parallel_all_gather(all_logits)
+                all_logits = all_logits[:, : self.config.vocab_size]
+
+            all_logprobs = all_logits.float()
+            del all_logits
+            all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
+
+            return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
+            if return_top_logprob:
+                prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
+                    all_logprobs, input_metadata
+                )
+            else:
+                prefill_top_logprobs = decode_top_logprobs = None
 
             if input_metadata.forward_mode == ForwardMode.DECODE:
-                last_logits = logits
                 last_logprobs = all_logprobs
-
+                return last_logits, (
+                    None,
+                    None,
+                    None,
+                    decode_top_logprobs,
+                    last_logprobs,
+                )
             else:
                 # Compute the logprobs for the last token of each request.
-                last_logits = logits[last_index]
                 last_logprobs = all_logprobs[last_index]
 
                 # Compute the logprobs and normalized logprobs for the prefill tokens.
                 # Note that we pad a zero at the end of each sequence for easy computation.
-
+                prefill_token_logprobs = all_logprobs[
                     torch.arange(all_logprobs.shape[0], device="cuda"),
                     torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
                 ]
-                logprobs_cumsum = torch.cumsum(
-                    prefill_logprobs, dim=0, dtype=torch.float32
-                )
 
-
-
-                start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
-                end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
-                sum_logp = (
-                    logprobs_cumsum[end]
-                    - logprobs_cumsum[start]
-                    + prefill_logprobs[start]
+                normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
+                    prefill_token_logprobs, input_metadata
                 )
-
-
+                return last_logits, (
+                    prefill_token_logprobs,
+                    normalized_prompt_logprobs,
+                    prefill_top_logprobs,
+                    decode_top_logprobs,
+                    last_logprobs,
                 )
 
-        return last_logits, (prefill_logprobs, normalized_logprobs, last_logprobs)
-
 
 if __name__ == "__main__":
     all_logprobs = torch.tensor(
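The new _get_normalized_prompt_logprobs helper replaces the cumulative-sum code that was previously inlined in forward. A minimal standalone sketch of the same trick, using toy values instead of sglang's InputMetadata, shows how per-request sums of token logprobs are read off a single cumsum without a Python loop:

import torch

token_logprobs = torch.tensor([-0.5, -1.0, -0.2, -0.3, -0.7])  # flattened over 2 requests
extend_seq_lens = torch.tensor([2, 3])   # prompt tokens per request
start = torch.tensor([0, 2])             # toy analogue of extend_start_loc

cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)
end = (start + extend_seq_lens - 2).clamp(0, token_logprobs.numel() - 1)
start = start.clamp(0, token_logprobs.numel() - 1)

# per-request sum of the first (len - 1) token logprobs, read from the cumsum
sum_logp = cumsum[end] - cumsum[start] + token_logprobs[start]
normalized = sum_logp / (extend_seq_lens - 1).clamp(min=1)
print(normalized)  # tensor([-0.5000, -0.2500])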
@@ -93,23 +154,22 @@ if __name__ == "__main__":
|
|
93
154
|
)
|
94
155
|
seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda")
|
95
156
|
input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
|
96
|
-
logprobs = torch.zeros(5, dtype=torch.float32, device="cuda")
|
97
157
|
|
98
|
-
|
158
|
+
token_logprobs = all_logprobs[
|
99
159
|
torch.arange(all_logprobs.shape[0], device="cuda"),
|
100
160
|
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
|
101
161
|
]
|
102
|
-
logprobs_cumsum = torch.cumsum(
|
162
|
+
logprobs_cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)
|
103
163
|
|
104
164
|
len_cumsum = torch.cumsum(seq_lens, dim=0)
|
105
165
|
start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
|
106
166
|
end = start + seq_lens - 2
|
107
|
-
start.clamp_(min=0, max=
|
108
|
-
end.clamp_(min=0, max=
|
109
|
-
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] +
|
167
|
+
start.clamp_(min=0, max=token_logprobs.shape[0] - 1)
|
168
|
+
end.clamp_(min=0, max=token_logprobs.shape[0] - 1)
|
169
|
+
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + token_logprobs[start]
|
110
170
|
|
111
171
|
# assert logprobs == [2, _, 2, 4, _]
|
112
|
-
print("logprobs",
|
172
|
+
print("token logprobs", token_logprobs)
|
113
173
|
print("start", start)
|
114
174
|
print("end", end)
|
115
175
|
print("sum_logp", sum_logp)
|
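The _get_top_logprobs helper added above is essentially torch.topk plus zip, applied per request. A small sketch with toy tensors (the k values are assumed to be at least 1 here, and the vocabulary size is made up):

import torch

logprobs = torch.log_softmax(torch.randn(4, 100), dim=-1)  # 4 decode rows, toy vocab of 100
top_logprobs_nums = [2, 1, 3, 1]  # k requested per row

decode_top_logprobs = []
for i in range(logprobs.shape[0]):
    t = logprobs[i].topk(top_logprobs_nums[i])
    # pair each top logprob with its token id, mirroring list(zip(v_cpu, p_cpu)) above
    decode_top_logprobs.append(list(zip(t.values.tolist(), t.indices.tolist())))
print(decode_top_logprobs[0])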

sglang/srt/layers/radix_attention.py
CHANGED
@@ -1,9 +1,10 @@
 import torch
+from torch import nn
+
 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
 from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
-from torch import nn
 
 
 class RadixAttention(nn.Module):

sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -3,6 +3,7 @@ import asyncio
 import uvloop
 import zmq
 import zmq.asyncio
+
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -37,10 +38,13 @@ class DetokenizerManager:
             if isinstance(recv_obj, BatchTokenIDOut):
                 output_tokens = recv_obj.output_tokens
 
-                # TODO(lmzheng): handle skip_special_tokens per request
+                # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
                 output_strs = self.tokenizer.batch_decode(
                     output_tokens,
                     skip_special_tokens=recv_obj.skip_special_tokens[0],
+                    spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
+                        0
+                    ],
                 )
 
                 # Trim stop str
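For reference, a hedged usage sketch of the flag being threaded through here. spaces_between_special_tokens is simply forwarded to the tokenizer's batch_decode, so it only takes effect for tokenizers that honor that keyword (e.g., Llama-style sentencepiece tokenizers); the model id below is purely illustrative:

from transformers import AutoTokenizer

# Illustrative only: requires access to the named checkpoint.
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
ids = [tok("hello world")["input_ids"]]
print(tok.batch_decode(ids, skip_special_tokens=False, spaces_between_special_tokens=False))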
sglang/srt/managers/io_struct.py
CHANGED
@@ -19,10 +19,13 @@ class GenerateReqInput:
     return_logprob: Optional[Union[List[bool], bool]] = None
     # The start location of the prompt for return_logprob
     logprob_start_len: Optional[Union[List[int], int]] = None
+    # The number of top logprobs to return
+    top_logprobs_num: Optional[Union[List[int], int]] = None
     # Whether to detokenize tokens in logprobs
     return_text_in_logprobs: bool = False
     # Whether to stream output
     stream: bool = False
+    # TODO: make all parameters a Union[List[T], T] to allow for batched requests
 
     def post_init(self):
         is_single = isinstance(self.text, str)
@@ -36,6 +39,8 @@ class GenerateReqInput:
                 self.return_logprob = False
             if self.logprob_start_len is None:
                 self.logprob_start_len = 0
+            if self.top_logprobs_num is None:
+                self.top_logprobs_num = 0
         else:
             num = len(self.text)
 
@@ -64,6 +69,11 @@ class GenerateReqInput:
             elif not isinstance(self.logprob_start_len, list):
                 self.logprob_start_len = [self.logprob_start_len] * num
 
+            if self.top_logprobs_num is None:
+                self.top_logprobs_num = [0] * num
+            elif not isinstance(self.top_logprobs_num, list):
+                self.top_logprobs_num = [self.top_logprobs_num] * num
+
 
 @dataclass
 class TokenizedGenerateReqInput:
@@ -76,6 +86,7 @@ class TokenizedGenerateReqInput:
     sampling_params: SamplingParams
     return_logprob: bool
     logprob_start_len: int
+    top_logprobs_num: int
     stream: bool
 
 
@@ -86,6 +97,7 @@ class BatchTokenIDOut:
     output_and_jump_forward_strs: List[str]
     hit_stop_str: List[Optional[str]]
     skip_special_tokens: List[bool]
+    spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
     finished: List[bool]
 
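The post_init changes broadcast a scalar top_logprobs_num to one entry per request, the same pattern already used for return_logprob and logprob_start_len. A standalone sketch of that normalization (the helper name is hypothetical, not part of sglang):

from typing import List, Optional, Union

def normalize(value: Optional[Union[List[int], int]], num: int, default: int = 0) -> List[int]:
    # None -> default for every request; scalar -> repeated; list -> kept as-is
    if value is None:
        return [default] * num
    if not isinstance(value, list):
        return [value] * num
    return value

print(normalize(None, 3))        # [0, 0, 0]
print(normalize(5, 3))           # [5, 5, 5]
print(normalize([1, 2, 3], 3))   # [1, 2, 3]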

sglang/srt/managers/router/infer_batch.py
CHANGED
@@ -1,22 +1,23 @@
 from dataclasses import dataclass
-from enum import 
+from enum import IntEnum, auto
 from typing import List
 
 import numpy as np
 import torch
+
 from sglang.srt.managers.router.radix_cache import RadixCache
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 
 
-class ForwardMode(
+class ForwardMode(IntEnum):
     PREFILL = auto()
     EXTEND = auto()
     DECODE = auto()
 
 
-class FinishReason(
-    LENGTH = auto()
+class FinishReason(IntEnum):
     EOS_TOKEN = auto()
+    LENGTH = auto()
     STOP_STR = auto()
 
 
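The diff does not state why ForwardMode and FinishReason move to IntEnum, but one plausible motivation (an assumption, not documented here) is that IntEnum members behave like plain ints, so they compare and serialize without special handling:

from enum import Enum, IntEnum, auto

class ModeAsEnum(Enum):
    PREFILL = auto()

class ModeAsIntEnum(IntEnum):
    PREFILL = auto()

# IntEnum members compare equal to the underlying int; plain Enum members do not.
print(ModeAsIntEnum.PREFILL == 1)  # True
print(ModeAsEnum.PREFILL == 1)     # False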
@@ -30,6 +31,7 @@ class Req:
         # Since jump forward may retokenize the prompt with partial outputs,
         # we maintain the original prompt length to report the correct usage.
         self.prompt_tokens = len(input_ids)
+
         # The number of decoded tokens for token usage report. Note that
         # this does not include the jump forward tokens.
         self.completion_tokens_wo_jump_forward = 0
@@ -40,11 +42,11 @@ class Req:
         self.image_offset = 0
         self.pad_value = None
 
+        # Sampling parameters
         self.sampling_params = None
-        self.return_logprob = False
-        self.logprob_start_len = 0
         self.stream = False
 
+        # Check finish
         self.tokenizer = None
         self.finished = False
         self.finish_reason = None
@@ -54,11 +56,17 @@ class Req:
         self.prefix_indices = []
         self.last_node = None
 
-
-        self.
-        self.
-
-
+        # Logprobs
+        self.return_logprob = False
+        self.logprob_start_len = 0
+        self.top_logprobs_num = 0
+        self.normalized_prompt_logprob = None
+        self.prefill_token_logprobs = None
+        self.decode_token_logprobs = None
+        self.prefill_top_logprobs = None
+        self.decode_top_logprobs = None
+
+        # Constrained decoding
         self.regex_fsm = None
         self.regex_fsm_state = 0
         self.jump_forward_map = None
@@ -159,7 +167,10 @@ class Batch:
     out_cache_loc: torch.Tensor = None
     out_cache_cont_start: torch.Tensor = None
     out_cache_cont_end: torch.Tensor = None
+
+    # for processing logprobs
     return_logprob: bool = False
+    top_logprobs_nums: List[int] = None
 
     # for multimodal
     pixel_values: List[torch.Tensor] = None
@@ -229,12 +240,11 @@ class Batch:
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
         if out_cache_loc is None:
-
-
-            out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
+            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.dec_refs)
+            out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
 
             if out_cache_loc is None:
-                print("Prefill out of memory. This should 
+                print("Prefill out of memory. This should never happen.")
                 self.tree_cache.pretty_print()
                 exit()
 
@@ -245,10 +255,14 @@ class Batch:
             ] = out_cache_loc[pt : pt + extend_lens[i]]
             pt += extend_lens[i]
 
-        # Handle logit bias
-        logit_bias = 
+        # Handle logit bias but only allocate when needed
+        logit_bias = None
         for i in range(bs):
             if reqs[i].sampling_params.dtype == "int":
+                if logit_bias is None:
+                    logit_bias = torch.zeros(
+                        (bs, vocab_size), dtype=torch.float32, device=device
+                    )
                 logit_bias[i] = int_token_logit_bias
 
         # Set fields
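The logit_bias change above allocates the (batch, vocab) tensor only when at least one request actually needs a bias, instead of always materializing it. A self-contained sketch of that lazy-allocation pattern with toy sizes (names and values here are illustrative, not sglang's):

import torch

def build_logit_bias(needs_bias, vocab_size=128, device="cpu"):
    logit_bias = None
    for i, need in enumerate(needs_bias):
        if need:
            if logit_bias is None:
                # allocate lazily, on first request that needs it
                logit_bias = torch.zeros(
                    (len(needs_bias), vocab_size), dtype=torch.float32, device=device
                )
            logit_bias[i, :10] = -1e4  # e.g., forbid some token ids for this request
    return logit_bias

print(build_logit_bias([False, False]))        # None: nothing allocated
print(build_logit_bias([False, True]).shape)   # torch.Size([2, 128])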
@@ -266,6 +280,7 @@ class Batch:
         self.position_ids_offsets = position_ids_offsets
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
+        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
 
         self.temperatures = torch.tensor(
             [r.sampling_params.temperature for r in reqs],
@@ -295,8 +310,8 @@ class Batch:
         if self.token_to_kv_pool.available_size() >= bs:
             return True
 
-
-
+        self.tree_cache.evict(bs, self.token_to_kv_pool.dec_refs)
+
         if self.token_to_kv_pool.available_size() >= bs:
             return True
 
@@ -310,8 +325,8 @@ class Batch:
         )
 
         retracted_reqs = []
-
-
+        seq_lens_cpu = self.seq_lens.cpu().numpy()
+        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
         while self.token_to_kv_pool.available_size() < len(self.reqs):
             idx = sorted_indices.pop()
             req = self.reqs[idx]
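check_decode_mem and prepare_for_extend now follow an evict-then-retry pattern: if the KV pool cannot serve an allocation, cached prefixes are evicted from the radix tree cache (their slots returned via dec_refs) and the allocation is retried once. A toy, self-contained sketch of the pattern; the classes below are stand-ins, not sglang's pool or cache:

class ToyPool:
    def __init__(self, size):
        self.free_slots = size
    def alloc(self, need):
        if self.free_slots < need:
            return None
        self.free_slots -= need
        return list(range(need))
    def dec_refs(self, slots):
        self.free_slots += len(slots)

class ToyTreeCache:
    def __init__(self, cached_slots):
        self.cached_slots = cached_slots
    def evict(self, need, free_fn):
        evicted, self.cached_slots = self.cached_slots[:need], self.cached_slots[need:]
        free_fn(evicted)

def alloc_with_eviction(pool, cache, need):
    loc = pool.alloc(need)
    if loc is None:
        cache.evict(need, pool.dec_refs)  # give cached slots back to the pool
        loc = pool.alloc(need)
    return loc

pool, cache = ToyPool(2), ToyTreeCache(list(range(100, 108)))
print(alloc_with_eviction(pool, cache, 6))  # succeeds only after eviction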
@@ -327,9 +342,9 @@ class Batch:
             # TODO: apply more fine-grained retraction
 
             token_indices = self.req_to_token_pool.req_to_token[
-
-            ][:
-            self.token_to_kv_pool.
+                req_pool_indices_cpu[idx]
+            ][: seq_lens_cpu[idx]]
+            self.token_to_kv_pool.dec_refs(token_indices)
 
         self.filter_batch(sorted_indices)
 
@@ -352,7 +367,7 @@ class Batch:
             # insert the old request into tree_cache
             token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1]
             if req_pool_indices_cpu is None:
-                req_pool_indices_cpu = self.req_pool_indices.
+                req_pool_indices_cpu = self.req_pool_indices.tolist()
             req_pool_idx = req_pool_indices_cpu[i]
             indices = self.req_to_token_pool.req_to_token[
                 req_pool_idx, : len(token_ids_in_memory)
@@ -360,7 +375,7 @@ class Batch:
             prefix_len = self.tree_cache.insert(
                 token_ids_in_memory, indices.clone()
             )
-            self.token_to_kv_pool.
+            self.token_to_kv_pool.dec_refs(indices[:prefix_len])
             self.req_to_token_pool.free(req_pool_idx)
             self.tree_cache.dec_ref_counter(req.last_node)
 
@@ -391,7 +406,7 @@ class Batch:
         self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
 
         if self.out_cache_loc is None:
-            print("Decode out of memory. This should 
+            print("Decode out of memory. This should never happen.")
            self.tree_cache.pretty_print()
             exit()
 
@@ -415,6 +430,7 @@ class Batch:
         self.prefix_lens = None
         self.position_ids_offsets = self.position_ids_offsets[new_indices]
         self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
         for item in [
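retract_decode now pulls seq_lens and req_pool_indices to the host once (.cpu().numpy() / .tolist()) before entering the Python loop, instead of indexing GPU tensors element by element inside it. A small illustration with toy tensors:

import torch

seq_lens = torch.tensor([5, 3, 7])          # stand-in for self.seq_lens
req_pool_indices = torch.tensor([2, 0, 1])  # stand-in for self.req_pool_indices

seq_lens_cpu = seq_lens.cpu().numpy()
req_pool_indices_cpu = req_pool_indices.tolist()
for idx in range(len(seq_lens_cpu)):
    # each iteration now works on plain host-side values, not per-element tensor reads
    print(req_pool_indices_cpu[idx], int(seq_lens_cpu[idx]))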
@@ -425,9 +441,12 @@ class Batch:
             "presence_penalties",
             "logit_bias",
         ]:
-
+            self_val = getattr(self, item, None)
+            # logit_bias can be None
+            if self_val is not None:
+                setattr(self, item, self_val[new_indices])
 
-    def merge(self, other):
+    def merge(self, other: "Batch"):
         self.reqs.extend(other.reqs)
 
         self.req_pool_indices = torch.concat(
@@ -439,6 +458,7 @@ class Batch:
             [self.position_ids_offsets, other.position_ids_offsets]
         )
         self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
         for item in [
@@ -447,17 +467,34 @@ class Batch:
             "top_ks",
             "frequency_penalties",
             "presence_penalties",
-            "logit_bias",
         ]:
-
-
+            self_val = getattr(self, item, None)
+            other_val = getattr(other, item, None)
+            setattr(self, item, torch.concat([self_val, other_val]))
+
+        # logit_bias can be None
+        if self.logit_bias is not None or other.logit_bias is not None:
+            vocab_size = (
+                self.logit_bias.shape[1]
+                if self.logit_bias is not None
+                else other.logit_bias.shape[1]
             )
+            if self.logit_bias is None:
+                self.logit_bias = torch.zeros(
+                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            if other.logit_bias is None:
+                other.logit_bias = torch.zeros(
+                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
 
     def sample(self, logits: torch.Tensor):
         # Post process logits
         logits = logits.contiguous()
         logits.div_(self.temperatures)
-
+        if self.logit_bias is not None:
+            logits.add_(self.logit_bias)
 
         has_regex = any(req.regex_fsm is not None for req in self.reqs)
         if has_regex:
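Batch.merge now has to cope with logit_bias being None on either side. A minimal sketch of that merge logic using CPU tensors (the function name is just for illustration):

import torch

def merge_logit_bias(a_bias, a_bs, b_bias, b_bs):
    # mirror the None handling added to Batch.merge: keep None only if both are None
    if a_bias is None and b_bias is None:
        return None
    vocab = a_bias.shape[1] if a_bias is not None else b_bias.shape[1]
    if a_bias is None:
        a_bias = torch.zeros((a_bs, vocab))
    if b_bias is None:
        b_bias = torch.zeros((b_bs, vocab))
    return torch.concat([a_bias, b_bias])

print(merge_logit_bias(None, 2, torch.ones(3, 8), 3).shape)  # torch.Size([5, 8])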

sglang/srt/managers/router/manager.py
CHANGED
@@ -4,6 +4,7 @@ import logging
 import uvloop
 import zmq
 import zmq.asyncio
+
 from sglang.srt.backend_config import GLOBAL_BACKEND_CONFIG
 from sglang.srt.managers.router.model_rpc import ModelRpcClient
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -41,12 +42,16 @@ class RouterManager:
             self.send_to_detokenizer.send_pyobj(obj)
 
         # async sleep for receiving the subsequent request and avoiding cache miss
+        slept = False
         if len(out_pyobjs) != 0:
             has_finished = any([obj.finished for obj in out_pyobjs])
             if has_finished:
-
+                if self.extend_dependency_time > 0:
+                    slept = True
+                    await asyncio.sleep(self.extend_dependency_time)
 
-
+        if not slept:
+            await asyncio.sleep(0.0006)
 
     async def loop_for_recv_requests(self):
         while True: