sglang 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +57 -2
- sglang/api.py +8 -5
- sglang/backend/anthropic.py +18 -4
- sglang/backend/openai.py +2 -1
- sglang/backend/runtime_endpoint.py +18 -5
- sglang/backend/vertexai.py +1 -0
- sglang/global_config.py +5 -1
- sglang/lang/chat_template.py +83 -2
- sglang/lang/interpreter.py +92 -35
- sglang/lang/ir.py +12 -9
- sglang/lang/tracer.py +6 -4
- sglang/launch_server_llavavid.py +31 -0
- sglang/srt/constrained/fsm_cache.py +1 -0
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/conversation.py +2 -2
- sglang/srt/flush_cache.py +16 -0
- sglang/srt/hf_transformers_utils.py +10 -2
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +1 -0
- sglang/srt/layers/logits_processor.py +114 -54
- sglang/srt/layers/radix_attention.py +2 -1
- sglang/srt/layers/token_attention.py +1 -0
- sglang/srt/managers/detokenizer_manager.py +5 -1
- sglang/srt/managers/io_struct.py +27 -3
- sglang/srt/managers/router/infer_batch.py +97 -48
- sglang/srt/managers/router/manager.py +11 -8
- sglang/srt/managers/router/model_rpc.py +169 -90
- sglang/srt/managers/router/model_runner.py +110 -166
- sglang/srt/managers/router/radix_cache.py +89 -51
- sglang/srt/managers/router/scheduler.py +17 -28
- sglang/srt/managers/tokenizer_manager.py +110 -33
- sglang/srt/memory_pool.py +5 -14
- sglang/srt/model_config.py +11 -0
- sglang/srt/models/commandr.py +372 -0
- sglang/srt/models/dbrx.py +412 -0
- sglang/srt/models/dbrx_config.py +281 -0
- sglang/srt/models/gemma.py +24 -25
- sglang/srt/models/llama2.py +25 -26
- sglang/srt/models/llava.py +8 -10
- sglang/srt/models/llavavid.py +307 -0
- sglang/srt/models/mixtral.py +29 -33
- sglang/srt/models/qwen.py +34 -25
- sglang/srt/models/qwen2.py +25 -26
- sglang/srt/models/stablelm.py +26 -26
- sglang/srt/models/yivl.py +3 -5
- sglang/srt/openai_api_adapter.py +356 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +91 -456
- sglang/srt/server_args.py +79 -49
- sglang/srt/utils.py +212 -47
- sglang/srt/weight_utils.py +417 -0
- sglang/test/test_programs.py +8 -7
- sglang/test/test_utils.py +195 -7
- sglang/utils.py +77 -26
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/METADATA +20 -18
- sglang-0.1.16.dist-info/RECORD +72 -0
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/WHEEL +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/top_level.txt +0 -0
sglang/lang/ir.py
CHANGED
@@ -193,17 +193,11 @@ class SglFunction:
         backend = backend or global_config.default_backend
         return trace_program(self, kwargs, backend)

-    def pin(self, backend=None):
-        from sglang.lang.interpreter import pin_program
+    def cache(self, backend=None):
+        from sglang.lang.interpreter import cache_program

         backend = backend or global_config.default_backend
-        return pin_program(self, backend)
-
-    def unpin(self, backend=None):
-        from sglang.lang.interpreter import unpin_program
-
-        backend = backend or global_config.default_backend
-        return unpin_program(self, backend)
+        return cache_program(self, backend)

     def compile(self, *, backend=None):
         from sglang.lang.compiler import compile_func
@@ -336,6 +330,15 @@ class SglImage(SglExpr):
         return f"SglImage({self.path})"


+class SglVideo(SglExpr):
+    def __init__(self, path, num_frames):
+        self.path = path
+        self.num_frames = num_frames
+
+    def __repr__(self) -> str:
+        return f"SglVideo({self.path}, {self.num_frames})"
+
+
 class SglGen(SglExpr):
     def __init__(
         self,
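A minimal usage sketch of the renamed caching API and the new SglVideo expression, assuming the frontend exposes a video() helper analogous to image() (api.py also changed in this release) and that a default backend is configured; the helper name and program body are illustrative, not taken from the diff:

import sglang as sgl

@sgl.function
def describe_clip(s, clip_path):
    # sgl.video(...) is assumed to build an SglVideo(path, num_frames) node.
    s += sgl.video(clip_path, num_frames=16)
    s += "Describe the clip briefly." + sgl.gen("answer", max_tokens=64)

# cache() replaces the removed pin()/unpin() pair: it precomputes and caches
# the function's shared prefix on the default backend.
describe_clip.cache()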
sglang/lang/tracer.py
CHANGED
@@ -109,19 +109,21 @@ class TracerProgramState(ProgramState):
     ########### Public API ###########
     ##################################

-    def fork(self,
+    def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
+        assert size >= 1
+
         if self.only_trace_prefix:
             raise StopTracing()

-        fork_node = SglFork(
+        fork_node = SglFork(size)
         fork_node.prev_node = self.last_node

         states = [
             TracerProgramState(self.backend, self.arguments, self.only_trace_prefix)
-            for _ in range(
+            for _ in range(size)
         ]

-        for i in range(
+        for i in range(size):
             node = SglGetForkItem(i)
             node.prev_node = fork_node
             states[i].last_node = node
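The traced fork now mirrors the interpreter's signature, taking an explicit size (default 1, must be >= 1). A short illustrative frontend program that exercises it (program content is made up):

import sglang as sgl

@sgl.function
def two_tips(s, topic):
    s += f"Give two short tips about {topic}.\n"
    forks = s.fork(2)  # creates two parallel branches
    for i, f in enumerate(forks):
        f += f"Tip {i + 1}:" + sgl.gen("tip", max_tokens=32, stop="\n")
    s += "1." + forks[0]["tip"] + "\n2." + forks[1]["tip"]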
sglang/launch_server_llavavid.py
ADDED
@@ -0,0 +1,31 @@
+import argparse
+import multiprocessing as mp
+
+from sglang.srt.server import ServerArgs, launch_server
+
+if __name__ == "__main__":
+
+    model_overide_args = {}
+
+    model_overide_args["mm_spatial_pool_stride"] = 2
+    model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
+    model_overide_args["num_frames"] = 16
+    model_overide_args["model_type"] = "llavavid"
+    if model_overide_args["num_frames"] == 32:
+        model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
+        model_overide_args["max_sequence_length"] = 4096 * 2
+        model_overide_args["tokenizer_model_max_length"] = 4096 * 2
+        model_overide_args["model_max_length"] = 4096 * 2
+
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    args = parser.parse_args()
+
+    if "34b" in args.model_path.lower():
+        model_overide_args["image_token_index"] = 64002
+
+    server_args = ServerArgs.from_cli_args(args)
+
+    pipe_reader, pipe_writer = mp.Pipe(duplex=False)
+
+    launch_server(server_args, pipe_writer, model_overide_args)
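For reference, the new entry point is launched like the stock server module; the flags come from ServerArgs.add_cli_args and the checkpoint path is a placeholder:

python3 -m sglang.launch_server_llavavid --model-path <llava-video-checkpoint> --port 30000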
sglang/srt/conversation.py
CHANGED
@@ -4,7 +4,7 @@ import dataclasses
 from enum import IntEnum, auto
 from typing import Dict, List, Optional, Tuple, Union

-from sglang.srt.managers.openai_protocol import ChatCompletionRequest
+from sglang.srt.openai_protocol import ChatCompletionRequest


 class SeparatorStyle(IntEnum):
@@ -400,7 +400,7 @@ register_conv_template(
     Conversation(
         name="chatml",
         system_template="<|im_start|>system\n{system_message}",
-        system_message="You are
+        system_message="You are a helpful assistant.",
         roles=("<|im_start|>user", "<|im_start|>assistant"),
         sep_style=SeparatorStyle.CHATML,
         sep="<|im_end|>",
sglang/srt/flush_cache.py
ADDED
@@ -0,0 +1,16 @@
+"""
+Usage:
+python3 -m sglang.srt.flush_cache --url http://localhost:30000
+"""
+
+import argparse
+
+import requests
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--url", type=str, default="http://localhost:30000")
+    args = parser.parse_args()
+
+    response = requests.get(args.url + "/flush_cache")
+    assert response.status_code == 200
sglang/srt/hf_transformers_utils.py
CHANGED
@@ -6,7 +6,6 @@ import warnings
 from typing import List, Optional, Tuple, Union

 from huggingface_hub import snapshot_download
-from sglang.srt.utils import is_multimodal_model
 from transformers import (
     AutoConfig,
     AutoProcessor,
@@ -15,6 +14,8 @@ from transformers import (
     PreTrainedTokenizerFast,
 )

+from sglang.srt.utils import is_multimodal_model
+

 def download_from_hf(model_path: str):
     if os.path.exists(model_path):
@@ -29,10 +30,17 @@ def get_config_json(model_path: str):
     return config


-def get_config(model: str, trust_remote_code: bool, revision: Optional[str] = None):
+def get_config(
+    model: str,
+    trust_remote_code: bool,
+    revision: Optional[str] = None,
+    model_overide_args: Optional[dict] = None,
+):
     config = AutoConfig.from_pretrained(
         model, trust_remote_code=trust_remote_code, revision=revision
     )
+    if model_overide_args:
+        config.update(model_overide_args)
     return config

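A small sketch of how the new model_overide_args parameter flows into the Hugging Face config (it mirrors what launch_server_llavavid.py passes down; the model path is a placeholder):

from sglang.srt.hf_transformers_utils import get_config

# Keys given here are written onto the AutoConfig object via config.update().
config = get_config(
    "liuhaotian/llava-v1.5-7b",  # placeholder model path
    trust_remote_code=True,
    model_overide_args={"model_type": "llavavid", "num_frames": 16},
)
print(config.model_type, config.num_frames)  # overridden / injected values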
sglang/srt/layers/logits_processor.py
CHANGED
@@ -1,11 +1,12 @@
 import torch
-from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
 from torch import nn
-from vllm.
+from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )

+from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
+

 class LogitsProcessor(nn.Module):
     def __init__(self, config):
@@ -13,76 +14,136 @@ class LogitsProcessor(nn.Module):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()

-    def
-
+    def _get_normalized_prompt_logprobs(
+        self, prefill_token_logprobs, input_metadata: InputMetadata
+    ):
+        logprobs_cumsum = torch.cumsum(
+            prefill_token_logprobs, dim=0, dtype=torch.float32
+        )

-
-
+        start = input_metadata.extend_start_loc.clone()
+        end = start + input_metadata.extend_seq_lens - 2
+        start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
+        end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
+        sum_logp = (
+            logprobs_cumsum[end]
+            - logprobs_cumsum[start]
+            + prefill_token_logprobs[start]
+        )
+        normalized_prompt_logprobs = sum_logp / (
+            (input_metadata.extend_seq_lens - 1).clamp(min=1)
+        )
+
+        return normalized_prompt_logprobs
+
+    def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
+        if input_metadata.forward_mode == ForwardMode.DECODE:
+            decode_top_logprobs = []
+            for i in range(all_logprobs.shape[0]):
+                k = input_metadata.top_logprobs_nums[i]
+                t = all_logprobs[i].topk(k)
+                v_cpu = t.values.tolist()
+                p_cpu = t.indices.tolist()
+                decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
+            return None, decode_top_logprobs
+        else:
+            prefill_top_logprobs, decode_top_logprobs = [], []
+            pt = 0
+            # NOTE: the GPU-CPU overhead can be reduced
+            extend_seq_lens_cpu = input_metadata.extend_seq_lens.cpu().numpy()
+            for i in range(len(extend_seq_lens_cpu)):
+                if extend_seq_lens_cpu[i] == 0:
+                    prefill_top_logprobs.append([])
+                    decode_top_logprobs.append([])
+                    continue
+                k = input_metadata.top_logprobs_nums[i]
+                t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
+                vs_cpu = t.values.tolist()
+                ps_cpu = t.indices.tolist()
+                prefill_top_logprobs.append(
+                    [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
+                )
+                decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
+                pt += extend_seq_lens_cpu[i]
+            return prefill_top_logprobs, decode_top_logprobs
+
+    def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
+        # Get last index for next token prediction, except for DECODE mode.
+        last_index = None
         if input_metadata.forward_mode != ForwardMode.DECODE:
             last_index = (
-                torch.cumsum(
-                    input_metadata.seq_lens - input_metadata.prefix_lens,
-                    dim=0,
-                    dtype=torch.long,
-                )
+                torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
                 - 1
             )

+        # Get the last hidden states and last logits
+        if input_metadata.forward_mode == ForwardMode.DECODE:
+            last_hidden = hidden_states
+        else:
+            last_hidden = hidden_states[last_index]
+
+        last_logits = torch.matmul(last_hidden, weight.T)
+        if self.tp_size > 1:
+            last_logits = tensor_model_parallel_all_gather(last_logits)
+        last_logits = last_logits[:, : self.config.vocab_size]
+
+        # Return only last_logits if logprob is not requested
         if not input_metadata.return_logprob:
-
-
-                last_hidden = hidden_states
-            else:
-                last_hidden = hidden_states[last_index]
-            hidden_states = None
-
-            last_logits = torch.matmul(last_hidden, weight.T)
-            if self.tp_size > 1:
-                last_logits = tensor_model_parallel_all_gather(last_logits)
-            last_logits = last_logits[:, : self.config.vocab_size]
-            return last_logits, (None, None, None)
+            hidden_states = None
+            return last_logits, (None, None, None, None, None)
         else:
             # When logprob is requested, compute the logits for all tokens.
-
-
-
-
-
+            if input_metadata.forward_mode == ForwardMode.DECODE:
+                all_logits = last_logits
+            else:
+                all_logits = torch.matmul(hidden_states, weight.T)
+                if self.tp_size > 1:
+                    all_logits = tensor_model_parallel_all_gather(all_logits)
+                all_logits = all_logits[:, : self.config.vocab_size]
+
+            all_logprobs = all_logits.float()
+            del all_logits
+            all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
+
+            return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
+            if return_top_logprob:
+                prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
+                    all_logprobs, input_metadata
+                )
+            else:
+                prefill_top_logprobs = decode_top_logprobs = None

             if input_metadata.forward_mode == ForwardMode.DECODE:
-                last_logits = logits
                 last_logprobs = all_logprobs
-
+                return last_logits, (
+                    None,
+                    None,
+                    None,
+                    decode_top_logprobs,
+                    last_logprobs,
+                )
             else:
                 # Compute the logprobs for the last token of each request.
-                last_logits = logits[last_index]
                 last_logprobs = all_logprobs[last_index]

                 # Compute the logprobs and normalized logprobs for the prefill tokens.
                 # Note that we pad a zero at the end of each sequence for easy computation.
-
+                prefill_token_logprobs = all_logprobs[
                     torch.arange(all_logprobs.shape[0], device="cuda"),
                     torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
                 ]
-                logprobs_cumsum = torch.cumsum(
-                    prefill_logprobs, dim=0, dtype=torch.float32
-                )

-
-
-                start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
-                end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
-                sum_logp = (
-                    logprobs_cumsum[end]
-                    - logprobs_cumsum[start]
-                    + prefill_logprobs[start]
+                normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
+                    prefill_token_logprobs, input_metadata
                 )
-
-
+                return last_logits, (
+                    prefill_token_logprobs,
+                    normalized_prompt_logprobs,
+                    prefill_top_logprobs,
+                    decode_top_logprobs,
+                    last_logprobs,
                 )

-            return last_logits, (prefill_logprobs, normalized_logprobs, last_logprobs)
-

 if __name__ == "__main__":
     all_logprobs = torch.tensor(
@@ -93,23 +154,22 @@ if __name__ == "__main__":
     )
     seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda")
     input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
-    logprobs = torch.zeros(5, dtype=torch.float32, device="cuda")

-
+    token_logprobs = all_logprobs[
         torch.arange(all_logprobs.shape[0], device="cuda"),
         torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
     ]
-    logprobs_cumsum = torch.cumsum(
+    logprobs_cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)

     len_cumsum = torch.cumsum(seq_lens, dim=0)
     start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
     end = start + seq_lens - 2
-    start.clamp_(min=0, max=
-    end.clamp_(min=0, max=
-    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] +
+    start.clamp_(min=0, max=token_logprobs.shape[0] - 1)
+    end.clamp_(min=0, max=token_logprobs.shape[0] - 1)
+    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + token_logprobs[start]

     # assert logprobs == [2, _, 2, 4, _]
-    print("logprobs",
+    print("token logprobs", token_logprobs)
     print("start", start)
     print("end", end)
     print("sum_logp", sum_logp)
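The normalized prompt logprob computation is hard to read in diff form; the standalone sketch below reproduces the cumulative-sum trick from _get_normalized_prompt_logprobs on CPU tensors with made-up numbers (two requests of three prompt tokens each):

import torch

# Per-position logprobs of the next prompt token, concatenated over requests;
# the last position of each request is padding, as in the code above.
token_logprobs = torch.tensor([-0.5, -1.0, -2.0, -0.25, -0.75, -1.5])
extend_start_loc = torch.tensor([0, 3])  # where each request starts
extend_seq_lens = torch.tensor([3, 3])   # prompt length of each request

cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)
start = extend_start_loc.clone()
end = (start + extend_seq_lens - 2).clamp(min=0, max=token_logprobs.shape[0] - 1)
sum_logp = cumsum[end] - cumsum[start] + token_logprobs[start]
normalized = sum_logp / (extend_seq_lens - 1).clamp(min=1)
print(normalized)  # tensor([-0.7500, -0.5000]): mean logprob over tokens 2..n of each prompt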
sglang/srt/layers/radix_attention.py
CHANGED
@@ -1,9 +1,10 @@
 import torch
+from torch import nn
+
 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
 from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
-from torch import nn


 class RadixAttention(nn.Module):
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -3,6 +3,7 @@ import asyncio
 import uvloop
 import zmq
 import zmq.asyncio
+
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -37,10 +38,13 @@ class DetokenizerManager:
             if isinstance(recv_obj, BatchTokenIDOut):
                 output_tokens = recv_obj.output_tokens

-                # TODO(lmzheng): handle skip_special_tokens per request
+                # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
                 output_strs = self.tokenizer.batch_decode(
                     output_tokens,
                     skip_special_tokens=recv_obj.skip_special_tokens[0],
+                    spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
+                        0
+                    ],
                 )

                 # Trim stop str
sglang/srt/managers/io_struct.py
CHANGED
@@ -8,7 +8,9 @@ from sglang.srt.sampling_params import SamplingParams
 @dataclass
 class GenerateReqInput:
     # The input prompt
-    text: Union[List[str], str]
+    text: Optional[Union[List[str], str]] = None
+    # The token ids for text; one can either specify text or input_ids
+    input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The image input
     image_data: Optional[Union[List[str], str]] = None
     # The sampling_params
@@ -19,13 +21,26 @@ class GenerateReqInput:
     return_logprob: Optional[Union[List[bool], bool]] = None
     # The start location of the prompt for return_logprob
     logprob_start_len: Optional[Union[List[int], int]] = None
+    # The number of top logprobs to return
+    top_logprobs_num: Optional[Union[List[int], int]] = None
     # Whether to detokenize tokens in logprobs
     return_text_in_logprobs: bool = False
     # Whether to stream output
     stream: bool = False
+    # TODO: make all parameters a Union[List[T], T] to allow for batched requests

     def post_init(self):
-
+
+        if self.text is None:
+            assert self.input_ids is not None, "Either text or input_ids should be provided"
+        else:
+            assert self.input_ids is None, "Either text or input_ids should be provided"
+
+        if self.text is not None:
+            is_single = isinstance(self.text, str)
+        else:
+            is_single = isinstance(self.input_ids[0], int)
+        self.is_single = is_single

         if is_single:
             if self.sampling_params is None:
@@ -36,8 +51,10 @@ class GenerateReqInput:
                 self.return_logprob = False
             if self.logprob_start_len is None:
                 self.logprob_start_len = 0
+            if self.top_logprobs_num is None:
+                self.top_logprobs_num = 0
         else:
-            num = len(self.text)
+            num = len(self.text) if self.text is not None else len(self.input_ids)

             if self.image_data is None:
                 self.image_data = [None] * num
@@ -64,6 +81,11 @@ class GenerateReqInput:
             elif not isinstance(self.logprob_start_len, list):
                 self.logprob_start_len = [self.logprob_start_len] * num

+            if self.top_logprobs_num is None:
+                self.top_logprobs_num = [0] * num
+            elif not isinstance(self.top_logprobs_num, list):
+                self.top_logprobs_num = [self.top_logprobs_num] * num
+

 @dataclass
 class TokenizedGenerateReqInput:
@@ -76,6 +98,7 @@ class TokenizedGenerateReqInput:
     sampling_params: SamplingParams
     return_logprob: bool
     logprob_start_len: int
+    top_logprobs_num: int
     stream: bool


@@ -86,6 +109,7 @@ class BatchTokenIDOut:
     output_and_jump_forward_strs: List[str]
     hit_stop_str: List[Optional[str]]
     skip_special_tokens: List[bool]
+    spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
     finished: List[bool]

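A quick sketch of the new request fields on the dataclass itself; in the server this normalization runs when a request arrives, but calling post_init() directly shows the text/input_ids exclusivity check and the broadcasting of top_logprobs_num (all values below are made up):

from sglang.srt.managers.io_struct import GenerateReqInput

req = GenerateReqInput(
    input_ids=[[1, 2, 3, 4]],  # token ids instead of text
    sampling_params=[{"max_new_tokens": 16}],
    return_logprob=True,
    top_logprobs_num=5,  # new field: ask for the top-5 logprobs
)
req.post_init()
print(req.is_single, req.top_logprobs_num)  # False [5]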