sglang 0.1.18__py3-none-any.whl → 0.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/api.py +26 -0
- sglang/backend/runtime_endpoint.py +18 -14
- sglang/bench_latency.py +34 -16
- sglang/global_config.py +1 -0
- sglang/lang/chat_template.py +41 -6
- sglang/lang/interpreter.py +5 -1
- sglang/lang/ir.py +61 -25
- sglang/srt/constrained/__init__.py +3 -2
- sglang/srt/hf_transformers_utils.py +7 -3
- sglang/srt/layers/extend_attention.py +2 -1
- sglang/srt/layers/fused_moe.py +181 -167
- sglang/srt/layers/logits_processor.py +55 -19
- sglang/srt/layers/radix_attention.py +24 -27
- sglang/srt/layers/token_attention.py +4 -1
- sglang/srt/managers/controller/infer_batch.py +2 -2
- sglang/srt/managers/controller/manager_single.py +1 -1
- sglang/srt/managers/controller/model_runner.py +27 -15
- sglang/srt/managers/controller/tp_worker.py +31 -14
- sglang/srt/managers/detokenizer_manager.py +4 -2
- sglang/srt/managers/io_struct.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +14 -13
- sglang/srt/model_config.py +6 -0
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/llama2.py +3 -3
- sglang/srt/models/llama_classification.py +10 -7
- sglang/srt/models/minicpm.py +373 -0
- sglang/srt/models/qwen2_moe.py +454 -0
- sglang/srt/openai_api_adapter.py +2 -2
- sglang/srt/openai_protocol.py +1 -1
- sglang/srt/server.py +17 -8
- sglang/srt/server_args.py +14 -16
- sglang/srt/utils.py +68 -35
- {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/METADATA +19 -13
- {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/RECORD +38 -35
- {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
- {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/WHEEL +0 -0
- {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
sglang/__init__.py
CHANGED
sglang/api.py
CHANGED
@@ -67,10 +67,16 @@ def gen(
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     ignore_eos: Optional[bool] = None,
+    return_logprob: Optional[bool] = None,
+    logprob_start_len: Optional[int] = None,
+    top_logprobs_num: Optional[int] = None,
+    return_text_in_logprobs: Optional[bool] = None,
     dtype: Optional[type] = None,
     choices: Optional[List[str]] = None,
     regex: Optional[str] = None,
 ):
+    """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
+
     if choices:
         return SglSelect(name, choices, 0.0 if temperature is None else temperature)
 
@@ -91,6 +97,10 @@ def gen(
         frequency_penalty,
         presence_penalty,
         ignore_eos,
+        return_logprob,
+        logprob_start_len,
+        top_logprobs_num,
+        return_text_in_logprobs,
         dtype,
         regex,
     )
@@ -106,6 +116,10 @@ def gen_int(
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     ignore_eos: Optional[bool] = None,
+    return_logprob: Optional[bool] = None,
+    logprob_start_len: Optional[int] = None,
+    top_logprobs_num: Optional[int] = None,
+    return_text_in_logprobs: Optional[bool] = None,
 ):
     return SglGen(
         name,
@@ -117,6 +131,10 @@ def gen_int(
         frequency_penalty,
         presence_penalty,
         ignore_eos,
+        return_logprob,
+        logprob_start_len,
+        top_logprobs_num,
+        return_text_in_logprobs,
         int,
         None,
     )
@@ -132,6 +150,10 @@ def gen_string(
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     ignore_eos: Optional[bool] = None,
+    return_logprob: Optional[bool] = None,
+    logprob_start_len: Optional[int] = None,
+    top_logprobs_num: Optional[int] = None,
+    return_text_in_logprobs: Optional[bool] = None,
 ):
     return SglGen(
         name,
@@ -143,6 +165,10 @@ def gen_string(
         frequency_penalty,
         presence_penalty,
         ignore_eos,
+        return_logprob,
+        logprob_start_len,
+        top_logprobs_num,
+        return_text_in_logprobs,
         str,
         None,
     )
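The api.py change adds four per-call logprob options to `gen`, `gen_int`, and `gen_string`. A minimal usage sketch, assuming a locally running sglang runtime at a hypothetical URL; the exact shape of the returned logprob data lives in the generation's meta info and may differ by version:

```python
import sglang as sgl

# Assumption: an sglang runtime is already serving at this (hypothetical) address.
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

@sgl.function
def answer(s, question):
    s += "Q: " + question + "\n"
    s += "A: " + sgl.gen(
        "answer",
        max_new_tokens=32,
        return_logprob=True,           # new in 0.1.19: ask for token logprobs
        logprob_start_len=0,           # new: where logprob accounting starts
        top_logprobs_num=5,            # new: number of top alternatives per token
        return_text_in_logprobs=True,  # new: include decoded text alongside token ids
    )

state = answer.run(question="What is the capital of France?")
print(state["answer"])
print(state.get_meta_info("answer"))  # logprob fields are reported here by the runtime
```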
sglang/backend/runtime_endpoint.py
CHANGED
@@ -1,18 +1,18 @@
 import json
-from typing import
+from typing import List, Optional
 
 import numpy as np
-import requests
 
 from sglang.backend.base_backend import BaseBackend
 from sglang.global_config import global_config
 from sglang.lang.chat_template import get_chat_template_by_model_path
 from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import
-from sglang.utils import
+from sglang.lang.ir import SglSamplingParams
+from sglang.utils import http_request
 
 
 class RuntimeEndpoint(BaseBackend):
+
     def __init__(
         self,
         base_url: str,
@@ -38,8 +38,7 @@ class RuntimeEndpoint(BaseBackend):
         self.model_info = res.json()
 
         self.chat_template = get_chat_template_by_model_path(
-            self.model_info["model_path"]
-        )
+            self.model_info["model_path"])
 
     def get_model_name(self):
         return self.model_info["model_path"]
@@ -125,6 +124,11 @@ class RuntimeEndpoint(BaseBackend):
         else:
             raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
 
+        for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
+            value = getattr(sampling_params, item, None)
+            if value is not None:
+                data[item] = value
+
         self._add_images(s, data)
 
         res = http_request(
@@ -167,6 +171,11 @@ class RuntimeEndpoint(BaseBackend):
         else:
             raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
 
+        for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
+            value = getattr(sampling_params, item, None)
+            if value is not None:
+                data[item] = value
+
         data["stream"] = True
         self._add_images(s, data)
 
@@ -181,21 +190,16 @@ class RuntimeEndpoint(BaseBackend):
         self._assert_success(res)
         pos = 0
 
-        incomplete_text = ""
         for chunk in res.iter_lines(decode_unicode=False):
             chunk = chunk.decode("utf-8")
             if chunk and chunk.startswith("data:"):
                 if chunk == "data: [DONE]":
                     break
                 data = json.loads(chunk[5:].strip("\n"))
-
+                chunk_text = data["text"][pos:]
                 meta_info = data["meta_info"]
-                pos += len(
-
-                yield text, meta_info
-
-        if len(incomplete_text) > 0:
-            yield incomplete_text, meta_info
+                pos += len(chunk_text)
+                yield chunk_text, meta_info
 
     def select(
         self,
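The runtime endpoint now copies these options into the request body only when they are set, so unset fields keep the server-side defaults. The same pattern in isolation, as a self-contained sketch (not the class itself):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Params:
    return_logprob: Optional[bool] = None
    logprob_start_len: Optional[int] = None
    top_logprobs_num: Optional[int] = None
    return_text_in_logprobs: Optional[bool] = None

def build_request(params: Params, data: dict) -> dict:
    # Copy only the fields the caller actually set; None means "use server default".
    for item in ["return_logprob", "logprob_start_len",
                 "top_logprobs_num", "return_text_in_logprobs"]:
        value = getattr(params, item, None)
        if value is not None:
            data[item] = value
    return data

print(build_request(Params(return_logprob=True, top_logprobs_num=3), {"text": "Hello"}))
# {'text': 'Hello', 'return_logprob': True, 'top_logprobs_num': 3}
```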
sglang/bench_latency.py
CHANGED
@@ -108,7 +108,7 @@ def prepare_inputs(bench_args, tokenizer):
     for i in range(len(prompts)):
         assert len(input_ids[i]) > bench_args.cut_len
 
-        tmp_input_ids = input_ids[i][:bench_args.cut_len]
+        tmp_input_ids = input_ids[i][: bench_args.cut_len]
         req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
         req.prefix_indices = []
         req.sampling_params = sampling_params
@@ -121,9 +121,9 @@ def prepare_inputs(bench_args, tokenizer):
 def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
     for i in range(len(reqs)):
         req = reqs[i]
-        req.input_ids += input_ids[i][bench_args.cut_len:]
+        req.input_ids += input_ids[i][bench_args.cut_len :]
         req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
-            i, :bench_args.cut_len
+            i, : bench_args.cut_len
         ]
     return reqs
 
@@ -151,7 +151,8 @@ def extend(reqs, model_runner):
         reqs=reqs,
         req_to_token_pool=model_runner.req_to_token_pool,
         token_to_kv_pool=model_runner.token_to_kv_pool,
-        tree_cache=None
+        tree_cache=None,
+    )
     batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
     output = model_runner.forward(batch, ForwardMode.EXTEND)
     next_token_ids, _ = batch.sample(output.next_token_logits)
@@ -165,6 +166,7 @@ def decode(input_token_ids, batch, model_runner):
     return next_token_ids, output.next_token_logits
 
 
+@torch.inference_mode()
 def correctness_test(
     server_args,
     bench_args,
@@ -178,9 +180,10 @@ def correctness_test(
     # Prepare inputs
     input_ids, reqs = prepare_inputs(bench_args, tokenizer)
 
-
-
-
+    if bench_args.cut_len > 0:
+        # Prefill
+        next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
+        rank_print("prefill logits (first half)", next_token_logits)
 
     # Prepare extend inputs
     reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner)
@@ -190,7 +193,7 @@ def correctness_test(
     rank_print("prefill logits (final)", next_token_logits)
 
     # Decode
-    output_ids = [
+    output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
     for _ in range(bench_args.output_len):
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
         for i in range(len(reqs)):
@@ -210,7 +213,9 @@ def latency_test(
 
     # Load the model
     model_runner, tokenizer = load_model(server_args, tp_rank)
-    print(
+    print(
+        f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
+    )
 
     # Prepare inputs
     reqs = prepare_synthetic_inputs(bench_args, tokenizer)
@@ -230,7 +235,9 @@ def latency_test(
         prefill_latency = time.time() - tic
         tot_latency += prefill_latency
         throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
-        rank_print(
+        rank_print(
+            f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+        )
 
         # Decode
         for i in range(output_len):
@@ -241,13 +248,24 @@ def latency_test(
             latency = time.time() - tic
             tot_latency += latency
             throughput = bench_args.batch_size / latency
-            if i < 5:
+            if i < 5:
+                rank_print(
+                    f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+                )
         avg_decode_latency = (tot_latency - prefill_latency) / output_len
         avg_decode_throughput = bench_args.batch_size / avg_decode_latency
-        rank_print(
-
-
-
+        rank_print(
+            f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
+        )
+
+        throughput = (
+            (bench_args.input_len + bench_args.output_len)
+            * bench_args.batch_size
+            / tot_latency
+        )
+        rank_print(
+            f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
+        )
 
     # Warm up
     run_once(4)
@@ -296,4 +314,4 @@ if __name__ == "__main__":
         format="%(message)s",
     )
 
-    main(server_args, bench_args)
+    main(server_args, bench_args)
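The restored reporting in `latency_test` prints per-phase and total throughput. The arithmetic it uses, checked standalone with made-up latency numbers:

```python
# Hypothetical numbers, plugged into the same formulas latency_test uses above.
input_len, output_len, batch_size = 128, 8, 16
prefill_latency = 0.050                 # seconds spent in the prefill step
decode_latencies = [0.010] * output_len  # seconds per decode step

tot_latency = prefill_latency + sum(decode_latencies)
prefill_throughput = input_len * batch_size / prefill_latency
avg_decode_latency = (tot_latency - prefill_latency) / output_len
avg_decode_throughput = batch_size / avg_decode_latency
total_throughput = (input_len + output_len) * batch_size / tot_latency

print(f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {prefill_throughput:9.2f} token/s")
print(f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s")
print(f"Total. latency: {tot_latency:6.3f} s, throughput: {total_throughput:9.2f} token/s")
```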
sglang/global_config.py
CHANGED
sglang/lang/chat_template.py
CHANGED
@@ -84,7 +84,7 @@ register_chat_template(
             "system": ("SYSTEM:", "\n"),
             "user": ("USER:", "\n"),
             "assistant": ("ASSISTANT:", "\n"),
-        }
+        }
     )
 )
 
@@ -116,6 +116,23 @@ register_chat_template(
     )
 )
 
+# There is default system prompt for qwen
+# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
+# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+register_chat_template(
+    ChatTemplate(
+        name="qwen",
+        default_system_prompt="You are a helpful assistant.",
+        role_prefix_and_suffix={
+            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
+            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
+            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
+        },
+        style=ChatTemplateStyle.PLAIN,
+        stop_str=("<|im_end|>",),
+    )
+)
+
 
 register_chat_template(
     ChatTemplate(
@@ -132,6 +149,7 @@ register_chat_template(
     )
 )
 
+# Reference: https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
 register_chat_template(
     ChatTemplate(
         name="vicuna_v1.1",
@@ -148,6 +166,20 @@ register_chat_template(
     )
 )
 
+# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
+register_chat_template(
+    ChatTemplate(
+        name="yi-1.5",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": ("", ""),
+            "user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
+            "assistant": ("", "<|im_end|>\n"),
+        },
+        style=ChatTemplateStyle.PLAIN,
+        stop_str=("<|im_end|>",)
+    )
+)
 
 register_chat_template(
     ChatTemplate(
@@ -187,7 +219,7 @@ register_chat_template(
 # Reference: https://github.com/01-ai/Yi/tree/main/VL#major-difference-with-llava
 register_chat_template(
     ChatTemplate(
-        name="yi",
+        name="yi-vl",
         default_system_prompt=(
             "This is a chat between an inquisitive human and an AI assistant. Assume the role of the AI assistant. Read all the images carefully, and respond to the human's questions with informative, helpful, detailed and polite answers."
             "这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。"
@@ -289,8 +321,9 @@ def match_chat_ml(model_path: str):
     model_path = model_path.lower()
     if "tinyllama" in model_path:
         return get_chat_template("chatml")
-
-
+    # Now the suffix for qwen2 chat model is "instruct"
+    if "qwen" in model_path and ("chat" in model_path or "instruct" in model_path):
+        return get_chat_template("qwen")
     if (
         "llava-v1.6-34b" in model_path
         or "llava-v1.6-yi-34b" in model_path
@@ -302,8 +335,10 @@ def match_chat_ml(model_path: str):
 @register_chat_template_matching_function
 def match_chat_yi(model_path: str):
     model_path = model_path.lower()
-    if "yi" in model_path and "llava" not in model_path:
-        return get_chat_template("yi")
+    if "yi-vl" in model_path and "llava" not in model_path:
+        return get_chat_template("yi-vl")
+    elif "yi-1.5" in model_path and "chat" in model_path:
+        return get_chat_template("yi-1.5")
 
 
 @register_chat_template_matching_function
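For reference, the prompt shape the new `qwen` template yields follows directly from the role prefixes and suffixes registered above. A simplified sketch (the real rendering goes through `ChatTemplate` and also injects the default system prompt when none is given):

```python
# Role prefixes/suffixes copied from the "qwen" template registered above.
roles = {
    "system": ("<|im_start|>system\n", "<|im_end|>\n"),
    "user": ("<|im_start|>user\n", "<|im_end|>\n"),
    "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
}

def render(messages):
    # Simplified rendering: prefix + content + suffix per message.
    return "".join(roles[m["role"]][0] + m["content"] + roles[m["role"]][1] for m in messages)

prompt = render([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]) + "<|im_start|>assistant\n"   # generation prompt
print(prompt)
```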
sglang/lang/interpreter.py
CHANGED
@@ -523,9 +523,9 @@ class StreamExecutor:
             self, sampling_params=sampling_params
         )
 
+        self.variables[name] = ""
         self.stream_var_event[name].set()
 
-        self.variables[name] = ""
         for comp, meta_info in generator:
             self.text_ += comp
             self.variables[name] += comp
@@ -668,6 +668,10 @@ class StreamExecutor:
             "frequency_penalty",
             "presence_penalty",
             "ignore_eos",
+            "return_logprob",
+            "logprob_start_len",
+            "top_logprobs_num",
+            "return_text_in_logprobs",
             "dtype",
             "regex",
         ]:
sglang/lang/ir.py
CHANGED
@@ -23,6 +23,10 @@ class SglSamplingParams:
     frequency_penalty: float = 0.0
     presence_penalty: float = 0.0
     ignore_eos: bool = False
+    return_logprob: Optional[bool] = None
+    logprob_start_len: Optional[int] = None,
+    top_logprobs_num: Optional[int] = None,
+    return_text_in_logprobs: Optional[bool] = None,
 
     # for constrained generation, not included in to_xxx_kwargs
     dtype: Optional[str] = None
@@ -37,6 +41,11 @@ class SglSamplingParams:
             self.top_k,
             self.frequency_penalty,
             self.presence_penalty,
+            self.ignore_eos,
+            self.return_logprob,
+            self.logprob_start_len,
+            self.top_logprobs_num,
+            self.return_text_in_logprobs,
         )
 
     def to_openai_kwargs(self):
@@ -139,6 +148,10 @@ class SglFunction:
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         ignore_eos: bool = False,
+        return_logprob: Optional[bool] = None,
+        logprob_start_len: Optional[int] = None,
+        top_logprobs_num: Optional[int] = None,
+        return_text_in_logprobs: Optional[bool] = None,
         stream: bool = False,
         backend=None,
         **kwargs,
@@ -154,6 +167,10 @@ class SglFunction:
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             ignore_eos=ignore_eos,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            return_text_in_logprobs=return_text_in_logprobs,
         )
         backend = backend or global_config.default_backend
         return run_program(self, backend, args, kwargs, default_sampling_para, stream)
@@ -170,6 +187,10 @@ class SglFunction:
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         ignore_eos: bool = False,
+        return_logprob: Optional[bool] = None,
+        logprob_start_len: Optional[int] = None,
+        top_logprobs_num: Optional[int] = None,
+        return_text_in_logprobs: Optional[bool] = None,
         backend=None,
         num_threads: Union[str, int] = "auto",
         progress_bar: bool = False,
@@ -185,8 +206,10 @@ class SglFunction:
         batch_kwargs = [
             {self.arg_names[i]: v for i, v in enumerate(arg_values)}
             for arg_values in batch_kwargs
-            if isinstance(arg_values, (list, tuple))
-
+            if isinstance(arg_values, (list, tuple))
+            and len(self.arg_names) - len(self.arg_defaults)
+            <= len(arg_values)
+            <= len(self.arg_names)
         ]
         # Ensure to raise an exception if the number of arguments mismatch
         if len(batch_kwargs) != num_programs:
@@ -201,6 +224,10 @@ class SglFunction:
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             ignore_eos=ignore_eos,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            return_text_in_logprobs=return_text_in_logprobs,
         )
         backend = backend or global_config.default_backend
         return run_program_batch(
@@ -348,7 +375,7 @@ class SglArgument(SglExpr):
 
 
 class SglImage(SglExpr):
-    def __init__(self, path):
+    def __init__(self, path: str):
         self.path = path
 
     def __repr__(self) -> str:
@@ -356,7 +383,7 @@ class SglImage(SglExpr):
 
 
 class SglVideo(SglExpr):
-    def __init__(self, path, num_frames):
+    def __init__(self, path: str, num_frames: int):
         self.path = path
         self.num_frames = num_frames
 
@@ -367,18 +394,23 @@ class SglVideo(SglExpr):
 class SglGen(SglExpr):
     def __init__(
         self,
-        name,
-        max_new_tokens,
-        stop,
-        temperature,
-        top_p,
-        top_k,
-        frequency_penalty,
-        presence_penalty,
-        ignore_eos,
-
-
+        name: Optional[str] = None,
+        max_new_tokens: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        frequency_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        ignore_eos: Optional[bool] = None,
+        return_logprob: Optional[bool] = None,
+        logprob_start_len: Optional[int] = None,
+        top_logprobs_num: Optional[int] = None,
+        return_text_in_logprobs: Optional[bool] = None,
+        dtype: Optional[type] = None,
+        regex: Optional[str] = None,
     ):
+        """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
         super().__init__()
         self.name = name
         self.sampling_params = SglSamplingParams(
@@ -390,6 +422,10 @@ class SglGen(SglExpr):
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             ignore_eos=ignore_eos,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            return_text_in_logprobs=return_text_in_logprobs,
             dtype=dtype,
             regex=regex,
         )
@@ -399,7 +435,7 @@ class SglGen(SglExpr):
 
 
 class SglConstantText(SglExpr):
-    def __init__(self, value):
+    def __init__(self, value: str):
         super().__init__()
         self.value = value
 
@@ -408,7 +444,7 @@ class SglConstantText(SglExpr):
 
 
 class SglRoleBegin(SglExpr):
-    def __init__(self, role):
+    def __init__(self, role: str):
         super().__init__()
         self.role = role
 
@@ -417,7 +453,7 @@ class SglRoleBegin(SglExpr):
 
 
 class SglRoleEnd(SglExpr):
-    def __init__(self, role):
+    def __init__(self, role: str):
         super().__init__()
         self.role = role
 
@@ -426,7 +462,7 @@ class SglRoleEnd(SglExpr):
 
 
 class SglSelect(SglExpr):
-    def __init__(self, name, choices, temperature):
+    def __init__(self, name: str, choices: List[str], temperature: float):
         super().__init__()
         self.name = name
         self.choices = choices
@@ -437,7 +473,7 @@ class SglSelect(SglExpr):
 
 
 class SglFork(SglExpr):
-    def __init__(self, number, position_ids_offset=None):
+    def __init__(self, number: int, position_ids_offset=None):
         super().__init__()
         self.number = number
         self.position_ids_offset = position_ids_offset
@@ -450,7 +486,7 @@ class SglFork(SglExpr):
 
 
 class SglGetForkItem(SglExpr):
-    def __init__(self, index):
+    def __init__(self, index: int):
         super().__init__()
         self.index = index
 
@@ -459,7 +495,7 @@ class SglGetForkItem(SglExpr):
 
 
 class SglVariable(SglExpr):
-    def __init__(self, name, source):
+    def __init__(self, name: str, source):
         super().__init__()
         self.name = name
         self.source = source
@@ -469,7 +505,7 @@ class SglVariable(SglExpr):
 
 
 class SglVarScopeBegin(SglExpr):
-    def __init__(self, name):
+    def __init__(self, name: str):
         super().__init__()
         self.name = name
 
@@ -478,7 +514,7 @@ class SglVarScopeBegin(SglExpr):
 
 
 class SglVarScopeEnd(SglExpr):
-    def __init__(self, name):
+    def __init__(self, name: str):
         super().__init__()
         self.name = name
 
@@ -500,4 +536,4 @@ class SglCommitLazy(SglExpr):
         super().__init__()
 
     def __repr__(self):
-        return
+        return "CommitLazy()"
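The `run_batch` filter now accepts positional tuples that omit trailing arguments with defaults, instead of effectively requiring every argument. A standalone sketch of the new bounds check, using a hypothetical argument list:

```python
# Hypothetical program signature: two required args, one with a default.
arg_names = ["question", "hint", "max_items"]
arg_defaults = [10]  # default exists only for max_items

def accepts(arg_values):
    # New check: allow anything between "all required args" and "all args".
    return (
        isinstance(arg_values, (list, tuple))
        and len(arg_names) - len(arg_defaults) <= len(arg_values) <= len(arg_names)
    )

print(accepts(("What is SGLang?", "a serving framework")))      # True: default is used
print(accepts(("What is SGLang?", "a serving framework", 5)))   # True: all args provided
print(accepts(("What is SGLang?",)))                            # False: a required arg is missing
```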
sglang/srt/constrained/__init__.py
CHANGED
@@ -5,13 +5,14 @@ from pydantic import BaseModel
 
 try:
     from outlines.caching import cache as disk_cache
-    from outlines.fsm.guide import RegexGuide
     from outlines.caching import disable_cache
     from outlines.fsm.guide import RegexGuide
     from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
     from outlines.models.transformers import TransformerTokenizer
 except ImportError as e:
-    print(
+    print(
+        f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n'
+    )
     raise
 
 try:
sglang/srt/hf_transformers_utils.py
CHANGED
@@ -264,7 +264,9 @@ class TiktokenTokenizer:
         return self.tokenizer.decode_batch(batch)
 
     def apply_chat_template(self, messages, tokenize, add_generation_prompt):
-        ret = self.chat_template.render(
+        ret = self.chat_template.render(
+            messages=messages, add_generation_prompt=add_generation_prompt
+        )
         return self.encode(ret) if tokenize else ret
 
 
@@ -297,5 +299,7 @@ class SentencePieceTokenizer:
         return self.tokenizer.decode(batch)
 
     def apply_chat_template(self, messages, tokenize, add_generation_prompt):
-        ret = self.chat_template.render(
-
+        ret = self.chat_template.render(
+            messages=messages, add_generation_prompt=add_generation_prompt
+        )
+        return self.encode(ret) if tokenize else ret
sglang/srt/layers/extend_attention.py
CHANGED
@@ -191,6 +191,7 @@ def extend_attention_fwd(
     b_seq_len_extend,
     max_len_in_batch,
     max_len_extend,
+    sm_scale=None,
     logit_cap=-1,
 ):
     """
@@ -213,7 +214,7 @@ def extend_attention_fwd(
     else:
         BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
 
-    sm_scale = 1.0 / (Lq**0.5)
+    sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
     batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
     kv_group_num = q_extend.shape[1] // k_extend.shape[1]
 
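`extend_attention_fwd` now takes an optional `sm_scale` and only falls back to 1/sqrt(head_dim) when it is left as `None`, which lets callers pass a model-specific attention scale (plausibly for newly added models such as Gemma-2). A tiny sketch of that fallback:

```python
import math

def resolve_sm_scale(Lq: int, sm_scale=None) -> float:
    # Mirrors the new default in extend_attention_fwd:
    # use 1/sqrt(head_dim) unless the caller supplies an explicit scale.
    return 1.0 / math.sqrt(Lq) if sm_scale is None else sm_scale

print(resolve_sm_scale(128))         # ~0.0884, the default scaling
print(resolve_sm_scale(128, 0.05))   # explicit override wins
```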