sglang 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +30 -4
- sglang/backend/litellm.py +2 -2
- sglang/backend/openai.py +26 -15
- sglang/backend/runtime_endpoint.py +18 -14
- sglang/bench_latency.py +317 -0
- sglang/global_config.py +5 -1
- sglang/lang/chat_template.py +41 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +6 -2
- sglang/lang/ir.py +74 -28
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +6 -3
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +2 -0
- sglang/srt/hf_transformers_utils.py +68 -9
- sglang/srt/layers/extend_attention.py +2 -1
- sglang/srt/layers/fused_moe.py +280 -169
- sglang/srt/layers/logits_processor.py +106 -42
- sglang/srt/layers/radix_attention.py +53 -29
- sglang/srt/layers/token_attention.py +4 -1
- sglang/srt/managers/controller/dp_worker.py +6 -3
- sglang/srt/managers/controller/infer_batch.py +144 -69
- sglang/srt/managers/controller/manager_multi.py +5 -5
- sglang/srt/managers/controller/manager_single.py +9 -4
- sglang/srt/managers/controller/model_runner.py +167 -55
- sglang/srt/managers/controller/radix_cache.py +4 -0
- sglang/srt/managers/controller/schedule_heuristic.py +2 -0
- sglang/srt/managers/controller/tp_worker.py +156 -134
- sglang/srt/managers/detokenizer_manager.py +19 -21
- sglang/srt/managers/io_struct.py +11 -5
- sglang/srt/managers/tokenizer_manager.py +16 -14
- sglang/srt/model_config.py +89 -4
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +204 -137
- sglang/srt/models/llama2.py +12 -5
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +11 -8
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +373 -0
- sglang/srt/models/mixtral.py +164 -115
- sglang/srt/models/mixtral_quant.py +0 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +454 -0
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/models/yivl.py +2 -2
- sglang/srt/openai_api_adapter.py +35 -25
- sglang/srt/openai_protocol.py +2 -2
- sglang/srt/server.py +69 -19
- sglang/srt/server_args.py +76 -43
- sglang/srt/utils.py +177 -35
- sglang/test/test_programs.py +28 -10
- sglang/utils.py +4 -3
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/METADATA +44 -31
- sglang-0.1.19.dist-info/RECORD +81 -0
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/WHEEL +1 -1
- sglang/srt/managers/router/infer_batch.py +0 -596
- sglang/srt/managers/router/manager.py +0 -82
- sglang/srt/managers/router/model_rpc.py +0 -818
- sglang/srt/managers/router/model_runner.py +0 -445
- sglang/srt/managers/router/radix_cache.py +0 -267
- sglang/srt/managers/router/scheduler.py +0 -59
- sglang-0.1.17.dist-info/RECORD +0 -81
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
sglang/lang/chat_template.py
CHANGED
@@ -84,7 +84,7 @@ register_chat_template(
|
|
84
84
|
"system": ("SYSTEM:", "\n"),
|
85
85
|
"user": ("USER:", "\n"),
|
86
86
|
"assistant": ("ASSISTANT:", "\n"),
|
87
|
-
}
|
87
|
+
}
|
88
88
|
)
|
89
89
|
)
|
90
90
|
|
@@ -116,6 +116,23 @@ register_chat_template(
|
|
116
116
|
)
|
117
117
|
)
|
118
118
|
|
119
|
+
# There is default system prompt for qwen
|
120
|
+
# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
|
121
|
+
# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
|
122
|
+
register_chat_template(
|
123
|
+
ChatTemplate(
|
124
|
+
name="qwen",
|
125
|
+
default_system_prompt="You are a helpful assistant.",
|
126
|
+
role_prefix_and_suffix={
|
127
|
+
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
|
128
|
+
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
|
129
|
+
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
|
130
|
+
},
|
131
|
+
style=ChatTemplateStyle.PLAIN,
|
132
|
+
stop_str=("<|im_end|>",),
|
133
|
+
)
|
134
|
+
)
|
135
|
+
|
119
136
|
|
120
137
|
register_chat_template(
|
121
138
|
ChatTemplate(
|
@@ -132,6 +149,7 @@ register_chat_template(
|
|
132
149
|
)
|
133
150
|
)
|
134
151
|
|
152
|
+
# Reference: https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
|
135
153
|
register_chat_template(
|
136
154
|
ChatTemplate(
|
137
155
|
name="vicuna_v1.1",
|
@@ -148,6 +166,20 @@ register_chat_template(
|
|
148
166
|
)
|
149
167
|
)
|
150
168
|
|
169
|
+
# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
|
170
|
+
register_chat_template(
|
171
|
+
ChatTemplate(
|
172
|
+
name="yi-1.5",
|
173
|
+
default_system_prompt=None,
|
174
|
+
role_prefix_and_suffix={
|
175
|
+
"system": ("", ""),
|
176
|
+
"user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
|
177
|
+
"assistant": ("", "<|im_end|>\n"),
|
178
|
+
},
|
179
|
+
style=ChatTemplateStyle.PLAIN,
|
180
|
+
stop_str=("<|im_end|>",)
|
181
|
+
)
|
182
|
+
)
|
151
183
|
|
152
184
|
register_chat_template(
|
153
185
|
ChatTemplate(
|
@@ -187,7 +219,7 @@ register_chat_template(
|
|
187
219
|
# Reference: https://github.com/01-ai/Yi/tree/main/VL#major-difference-with-llava
|
188
220
|
register_chat_template(
|
189
221
|
ChatTemplate(
|
190
|
-
name="yi",
|
222
|
+
name="yi-vl",
|
191
223
|
default_system_prompt=(
|
192
224
|
"This is a chat between an inquisitive human and an AI assistant. Assume the role of the AI assistant. Read all the images carefully, and respond to the human's questions with informative, helpful, detailed and polite answers."
|
193
225
|
"这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。"
|
@@ -289,8 +321,9 @@ def match_chat_ml(model_path: str):
|
|
289
321
|
model_path = model_path.lower()
|
290
322
|
if "tinyllama" in model_path:
|
291
323
|
return get_chat_template("chatml")
|
292
|
-
|
293
|
-
|
324
|
+
# Now the suffix for qwen2 chat model is "instruct"
|
325
|
+
if "qwen" in model_path and ("chat" in model_path or "instruct" in model_path):
|
326
|
+
return get_chat_template("qwen")
|
294
327
|
if (
|
295
328
|
"llava-v1.6-34b" in model_path
|
296
329
|
or "llava-v1.6-yi-34b" in model_path
|
@@ -302,8 +335,10 @@ def match_chat_ml(model_path: str):
|
|
302
335
|
@register_chat_template_matching_function
|
303
336
|
def match_chat_yi(model_path: str):
|
304
337
|
model_path = model_path.lower()
|
305
|
-
if "yi" in model_path and "llava" not in model_path:
|
306
|
-
return get_chat_template("yi")
|
338
|
+
if "yi-vl" in model_path and "llava" not in model_path:
|
339
|
+
return get_chat_template("yi-vl")
|
340
|
+
elif "yi-1.5" in model_path and "chat" in model_path:
|
341
|
+
return get_chat_template("yi-1.5")
|
307
342
|
|
308
343
|
|
309
344
|
@register_chat_template_matching_function
|
sglang/lang/compiler.py
CHANGED
@@ -4,7 +4,7 @@ from queue import Queue
|
|
4
4
|
from typing import List, Union
|
5
5
|
|
6
6
|
from sglang.global_config import global_config
|
7
|
-
from sglang.lang.interpreter import ProgramState, StreamExecutor,
|
7
|
+
from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
|
8
8
|
from sglang.lang.ir import (
|
9
9
|
SglArgument,
|
10
10
|
SglConstantText,
|
@@ -184,7 +184,7 @@ class CompiledFunction:
|
|
184
184
|
|
185
185
|
# Extract prefix by tracing and cache it
|
186
186
|
if len(batch_kwargs) > 1:
|
187
|
-
|
187
|
+
cache_program(self.function, backend)
|
188
188
|
|
189
189
|
# Run all programs
|
190
190
|
if num_threads == "auto":
|
sglang/lang/interpreter.py
CHANGED
@@ -507,7 +507,7 @@ class StreamExecutor:
|
|
507
507
|
)
|
508
508
|
return
|
509
509
|
|
510
|
-
else:
|
510
|
+
else: # Speculative execution on models with completion interface
|
511
511
|
comp, meta_info = self._spec_gen(sampling_params)
|
512
512
|
|
513
513
|
self.text_ += comp
|
@@ -523,9 +523,9 @@ class StreamExecutor:
|
|
523
523
|
self, sampling_params=sampling_params
|
524
524
|
)
|
525
525
|
|
526
|
+
self.variables[name] = ""
|
526
527
|
self.stream_var_event[name].set()
|
527
528
|
|
528
|
-
self.variables[name] = ""
|
529
529
|
for comp, meta_info in generator:
|
530
530
|
self.text_ += comp
|
531
531
|
self.variables[name] += comp
|
@@ -668,6 +668,10 @@ class StreamExecutor:
|
|
668
668
|
"frequency_penalty",
|
669
669
|
"presence_penalty",
|
670
670
|
"ignore_eos",
|
671
|
+
"return_logprob",
|
672
|
+
"logprob_start_len",
|
673
|
+
"top_logprobs_num",
|
674
|
+
"return_text_in_logprobs",
|
671
675
|
"dtype",
|
672
676
|
"regex",
|
673
677
|
]:
|
sglang/lang/ir.py
CHANGED
@@ -23,6 +23,10 @@ class SglSamplingParams:
|
|
23
23
|
frequency_penalty: float = 0.0
|
24
24
|
presence_penalty: float = 0.0
|
25
25
|
ignore_eos: bool = False
|
26
|
+
return_logprob: Optional[bool] = None
|
27
|
+
logprob_start_len: Optional[int] = None,
|
28
|
+
top_logprobs_num: Optional[int] = None,
|
29
|
+
return_text_in_logprobs: Optional[bool] = None,
|
26
30
|
|
27
31
|
# for constrained generation, not included in to_xxx_kwargs
|
28
32
|
dtype: Optional[str] = None
|
@@ -37,6 +41,11 @@ class SglSamplingParams:
|
|
37
41
|
self.top_k,
|
38
42
|
self.frequency_penalty,
|
39
43
|
self.presence_penalty,
|
44
|
+
self.ignore_eos,
|
45
|
+
self.return_logprob,
|
46
|
+
self.logprob_start_len,
|
47
|
+
self.top_logprobs_num,
|
48
|
+
self.return_text_in_logprobs,
|
40
49
|
)
|
41
50
|
|
42
51
|
def to_openai_kwargs(self):
|
@@ -81,12 +90,10 @@ class SglSamplingParams:
|
|
81
90
|
"top_p": self.top_p,
|
82
91
|
"top_k": self.top_k,
|
83
92
|
}
|
84
|
-
|
93
|
+
|
85
94
|
def to_litellm_kwargs(self):
|
86
95
|
if self.regex is not None:
|
87
|
-
warnings.warn(
|
88
|
-
"Regular expression is not supported in the LiteLLM backend."
|
89
|
-
)
|
96
|
+
warnings.warn("Regular expression is not supported in the LiteLLM backend.")
|
90
97
|
return {
|
91
98
|
"max_tokens": self.max_new_tokens,
|
92
99
|
"stop": self.stop or None,
|
@@ -122,6 +129,7 @@ class SglFunction:
|
|
122
129
|
argspec = inspect.getfullargspec(func)
|
123
130
|
assert argspec.args[0] == "s", 'The first argument must be "s"'
|
124
131
|
self.arg_names = argspec.args[1:]
|
132
|
+
self.arg_defaults = argspec.defaults if argspec.defaults is not None else []
|
125
133
|
|
126
134
|
def bind(self, **kwargs):
|
127
135
|
assert all(key in self.arg_names for key in kwargs)
|
@@ -140,6 +148,10 @@ class SglFunction:
|
|
140
148
|
frequency_penalty: float = 0.0,
|
141
149
|
presence_penalty: float = 0.0,
|
142
150
|
ignore_eos: bool = False,
|
151
|
+
return_logprob: Optional[bool] = None,
|
152
|
+
logprob_start_len: Optional[int] = None,
|
153
|
+
top_logprobs_num: Optional[int] = None,
|
154
|
+
return_text_in_logprobs: Optional[bool] = None,
|
143
155
|
stream: bool = False,
|
144
156
|
backend=None,
|
145
157
|
**kwargs,
|
@@ -155,6 +167,10 @@ class SglFunction:
|
|
155
167
|
frequency_penalty=frequency_penalty,
|
156
168
|
presence_penalty=presence_penalty,
|
157
169
|
ignore_eos=ignore_eos,
|
170
|
+
return_logprob=return_logprob,
|
171
|
+
logprob_start_len=logprob_start_len,
|
172
|
+
top_logprobs_num=top_logprobs_num,
|
173
|
+
return_text_in_logprobs=return_text_in_logprobs,
|
158
174
|
)
|
159
175
|
backend = backend or global_config.default_backend
|
160
176
|
return run_program(self, backend, args, kwargs, default_sampling_para, stream)
|
@@ -171,6 +187,10 @@ class SglFunction:
|
|
171
187
|
frequency_penalty: float = 0.0,
|
172
188
|
presence_penalty: float = 0.0,
|
173
189
|
ignore_eos: bool = False,
|
190
|
+
return_logprob: Optional[bool] = None,
|
191
|
+
logprob_start_len: Optional[int] = None,
|
192
|
+
top_logprobs_num: Optional[int] = None,
|
193
|
+
return_text_in_logprobs: Optional[bool] = None,
|
174
194
|
backend=None,
|
175
195
|
num_threads: Union[str, int] = "auto",
|
176
196
|
progress_bar: bool = False,
|
@@ -180,7 +200,20 @@ class SglFunction:
|
|
180
200
|
assert isinstance(batch_kwargs, (list, tuple))
|
181
201
|
if len(batch_kwargs) == 0:
|
182
202
|
return []
|
183
|
-
|
203
|
+
if not isinstance(batch_kwargs[0], dict):
|
204
|
+
num_programs = len(batch_kwargs)
|
205
|
+
# change the list of argument values to dict of arg_name -> arg_value
|
206
|
+
batch_kwargs = [
|
207
|
+
{self.arg_names[i]: v for i, v in enumerate(arg_values)}
|
208
|
+
for arg_values in batch_kwargs
|
209
|
+
if isinstance(arg_values, (list, tuple))
|
210
|
+
and len(self.arg_names) - len(self.arg_defaults)
|
211
|
+
<= len(arg_values)
|
212
|
+
<= len(self.arg_names)
|
213
|
+
]
|
214
|
+
# Ensure to raise an exception if the number of arguments mismatch
|
215
|
+
if len(batch_kwargs) != num_programs:
|
216
|
+
raise Exception("Given arguments mismatch the SGL function signature")
|
184
217
|
|
185
218
|
default_sampling_para = SglSamplingParams(
|
186
219
|
max_new_tokens=max_new_tokens,
|
@@ -191,6 +224,10 @@ class SglFunction:
|
|
191
224
|
frequency_penalty=frequency_penalty,
|
192
225
|
presence_penalty=presence_penalty,
|
193
226
|
ignore_eos=ignore_eos,
|
227
|
+
return_logprob=return_logprob,
|
228
|
+
logprob_start_len=logprob_start_len,
|
229
|
+
top_logprobs_num=top_logprobs_num,
|
230
|
+
return_text_in_logprobs=return_text_in_logprobs,
|
194
231
|
)
|
195
232
|
backend = backend or global_config.default_backend
|
196
233
|
return run_program_batch(
|
@@ -338,7 +375,7 @@ class SglArgument(SglExpr):
|
|
338
375
|
|
339
376
|
|
340
377
|
class SglImage(SglExpr):
|
341
|
-
def __init__(self, path):
|
378
|
+
def __init__(self, path: str):
|
342
379
|
self.path = path
|
343
380
|
|
344
381
|
def __repr__(self) -> str:
|
@@ -346,7 +383,7 @@ class SglImage(SglExpr):
|
|
346
383
|
|
347
384
|
|
348
385
|
class SglVideo(SglExpr):
|
349
|
-
def __init__(self, path, num_frames):
|
386
|
+
def __init__(self, path: str, num_frames: int):
|
350
387
|
self.path = path
|
351
388
|
self.num_frames = num_frames
|
352
389
|
|
@@ -357,18 +394,23 @@ class SglVideo(SglExpr):
|
|
357
394
|
class SglGen(SglExpr):
|
358
395
|
def __init__(
|
359
396
|
self,
|
360
|
-
name,
|
361
|
-
max_new_tokens,
|
362
|
-
stop,
|
363
|
-
temperature,
|
364
|
-
top_p,
|
365
|
-
top_k,
|
366
|
-
frequency_penalty,
|
367
|
-
presence_penalty,
|
368
|
-
ignore_eos,
|
369
|
-
|
370
|
-
|
397
|
+
name: Optional[str] = None,
|
398
|
+
max_new_tokens: Optional[int] = None,
|
399
|
+
stop: Optional[Union[str, List[str]]] = None,
|
400
|
+
temperature: Optional[float] = None,
|
401
|
+
top_p: Optional[float] = None,
|
402
|
+
top_k: Optional[int] = None,
|
403
|
+
frequency_penalty: Optional[float] = None,
|
404
|
+
presence_penalty: Optional[float] = None,
|
405
|
+
ignore_eos: Optional[bool] = None,
|
406
|
+
return_logprob: Optional[bool] = None,
|
407
|
+
logprob_start_len: Optional[int] = None,
|
408
|
+
top_logprobs_num: Optional[int] = None,
|
409
|
+
return_text_in_logprobs: Optional[bool] = None,
|
410
|
+
dtype: Optional[type] = None,
|
411
|
+
regex: Optional[str] = None,
|
371
412
|
):
|
413
|
+
"""Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
|
372
414
|
super().__init__()
|
373
415
|
self.name = name
|
374
416
|
self.sampling_params = SglSamplingParams(
|
@@ -380,6 +422,10 @@ class SglGen(SglExpr):
|
|
380
422
|
frequency_penalty=frequency_penalty,
|
381
423
|
presence_penalty=presence_penalty,
|
382
424
|
ignore_eos=ignore_eos,
|
425
|
+
return_logprob=return_logprob,
|
426
|
+
logprob_start_len=logprob_start_len,
|
427
|
+
top_logprobs_num=top_logprobs_num,
|
428
|
+
return_text_in_logprobs=return_text_in_logprobs,
|
383
429
|
dtype=dtype,
|
384
430
|
regex=regex,
|
385
431
|
)
|
@@ -389,7 +435,7 @@ class SglGen(SglExpr):
|
|
389
435
|
|
390
436
|
|
391
437
|
class SglConstantText(SglExpr):
|
392
|
-
def __init__(self, value):
|
438
|
+
def __init__(self, value: str):
|
393
439
|
super().__init__()
|
394
440
|
self.value = value
|
395
441
|
|
@@ -398,7 +444,7 @@ class SglConstantText(SglExpr):
|
|
398
444
|
|
399
445
|
|
400
446
|
class SglRoleBegin(SglExpr):
|
401
|
-
def __init__(self, role):
|
447
|
+
def __init__(self, role: str):
|
402
448
|
super().__init__()
|
403
449
|
self.role = role
|
404
450
|
|
@@ -407,7 +453,7 @@ class SglRoleBegin(SglExpr):
|
|
407
453
|
|
408
454
|
|
409
455
|
class SglRoleEnd(SglExpr):
|
410
|
-
def __init__(self, role):
|
456
|
+
def __init__(self, role: str):
|
411
457
|
super().__init__()
|
412
458
|
self.role = role
|
413
459
|
|
@@ -416,7 +462,7 @@ class SglRoleEnd(SglExpr):
|
|
416
462
|
|
417
463
|
|
418
464
|
class SglSelect(SglExpr):
|
419
|
-
def __init__(self, name, choices, temperature):
|
465
|
+
def __init__(self, name: str, choices: List[str], temperature: float):
|
420
466
|
super().__init__()
|
421
467
|
self.name = name
|
422
468
|
self.choices = choices
|
@@ -427,7 +473,7 @@ class SglSelect(SglExpr):
|
|
427
473
|
|
428
474
|
|
429
475
|
class SglFork(SglExpr):
|
430
|
-
def __init__(self, number, position_ids_offset=None):
|
476
|
+
def __init__(self, number: int, position_ids_offset=None):
|
431
477
|
super().__init__()
|
432
478
|
self.number = number
|
433
479
|
self.position_ids_offset = position_ids_offset
|
@@ -440,7 +486,7 @@ class SglFork(SglExpr):
|
|
440
486
|
|
441
487
|
|
442
488
|
class SglGetForkItem(SglExpr):
|
443
|
-
def __init__(self, index):
|
489
|
+
def __init__(self, index: int):
|
444
490
|
super().__init__()
|
445
491
|
self.index = index
|
446
492
|
|
@@ -449,7 +495,7 @@ class SglGetForkItem(SglExpr):
|
|
449
495
|
|
450
496
|
|
451
497
|
class SglVariable(SglExpr):
|
452
|
-
def __init__(self, name, source):
|
498
|
+
def __init__(self, name: str, source):
|
453
499
|
super().__init__()
|
454
500
|
self.name = name
|
455
501
|
self.source = source
|
@@ -459,7 +505,7 @@ class SglVariable(SglExpr):
|
|
459
505
|
|
460
506
|
|
461
507
|
class SglVarScopeBegin(SglExpr):
|
462
|
-
def __init__(self, name):
|
508
|
+
def __init__(self, name: str):
|
463
509
|
super().__init__()
|
464
510
|
self.name = name
|
465
511
|
|
@@ -468,7 +514,7 @@ class SglVarScopeBegin(SglExpr):
|
|
468
514
|
|
469
515
|
|
470
516
|
class SglVarScopeEnd(SglExpr):
|
471
|
-
def __init__(self, name):
|
517
|
+
def __init__(self, name: str):
|
472
518
|
super().__init__()
|
473
519
|
self.name = name
|
474
520
|
|
@@ -490,4 +536,4 @@ class SglCommitLazy(SglExpr):
|
|
490
536
|
super().__init__()
|
491
537
|
|
492
538
|
def __repr__(self):
|
493
|
-
return
|
539
|
+
return "CommitLazy()"
|
sglang/launch_server.py
CHANGED
@@ -1,6 +1,9 @@
|
|
1
|
+
"""Launch the inference server."""
|
2
|
+
|
1
3
|
import argparse
|
2
4
|
|
3
|
-
from sglang.srt.server import
|
5
|
+
from sglang.srt.server import launch_server
|
6
|
+
from sglang.srt.server_args import ServerArgs
|
4
7
|
|
5
8
|
if __name__ == "__main__":
|
6
9
|
parser = argparse.ArgumentParser()
|
sglang/launch_server_llavavid.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1
|
+
"""Launch the inference server for Llava-video model."""
|
2
|
+
|
1
3
|
import argparse
|
2
4
|
import multiprocessing as mp
|
3
5
|
|
4
6
|
from sglang.srt.server import ServerArgs, launch_server
|
5
7
|
|
6
8
|
if __name__ == "__main__":
|
7
|
-
|
8
9
|
model_overide_args = {}
|
9
10
|
|
10
11
|
model_overide_args["mm_spatial_pool_stride"] = 2
|
@@ -1,13 +1,20 @@
|
|
1
1
|
import json
|
2
2
|
from typing import Dict, Optional, Union
|
3
3
|
|
4
|
-
from outlines.caching import cache as disk_cache
|
5
|
-
from outlines.caching import disable_cache
|
6
|
-
from outlines.fsm.fsm import RegexFSM
|
7
|
-
from outlines.fsm.regex import FSMInfo, make_deterministic_fsm
|
8
|
-
from outlines.models.transformers import TransformerTokenizer
|
9
4
|
from pydantic import BaseModel
|
10
5
|
|
6
|
+
try:
|
7
|
+
from outlines.caching import cache as disk_cache
|
8
|
+
from outlines.caching import disable_cache
|
9
|
+
from outlines.fsm.guide import RegexGuide
|
10
|
+
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
|
11
|
+
from outlines.models.transformers import TransformerTokenizer
|
12
|
+
except ImportError as e:
|
13
|
+
print(
|
14
|
+
f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n'
|
15
|
+
)
|
16
|
+
raise
|
17
|
+
|
11
18
|
try:
|
12
19
|
from outlines.fsm.json_schema import build_regex_from_object
|
13
20
|
except ImportError:
|
@@ -28,11 +35,12 @@ except ImportError:
|
|
28
35
|
|
29
36
|
|
30
37
|
__all__ = [
|
31
|
-
"
|
38
|
+
"RegexGuide",
|
32
39
|
"FSMInfo",
|
33
40
|
"make_deterministic_fsm",
|
34
41
|
"build_regex_from_object",
|
35
42
|
"TransformerTokenizer",
|
36
43
|
"disk_cache",
|
37
44
|
"disable_cache",
|
45
|
+
"make_byte_level_fsm",
|
38
46
|
]
|
@@ -1,4 +1,6 @@
|
|
1
|
-
|
1
|
+
"""Cache for the compressed finite state machine."""
|
2
|
+
|
3
|
+
from sglang.srt.constrained import RegexGuide, TransformerTokenizer
|
2
4
|
from sglang.srt.constrained.base_cache import BaseCache
|
3
5
|
|
4
6
|
|
@@ -6,7 +8,8 @@ class FSMCache(BaseCache):
|
|
6
8
|
def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
|
7
9
|
super().__init__(enable=enable)
|
8
10
|
|
9
|
-
if tokenizer_path.endswith(".json"):
|
11
|
+
if tokenizer_path.endswith(".json") or tokenizer_path.endswith(".model"):
|
12
|
+
# Do not support TiktokenTokenizer or SentencePieceTokenizer
|
10
13
|
return
|
11
14
|
|
12
15
|
from importlib.metadata import version
|
@@ -25,4 +28,4 @@ class FSMCache(BaseCache):
|
|
25
28
|
)
|
26
29
|
|
27
30
|
def init_value(self, regex):
|
28
|
-
return
|
31
|
+
return RegexGuide(regex, self.outlines_tokenizer)
|
@@ -1,17 +1,43 @@
|
|
1
|
-
|
1
|
+
"""
|
2
|
+
Faster constrained decoding.
|
3
|
+
Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
|
4
|
+
"""
|
5
|
+
|
6
|
+
import dataclasses
|
7
|
+
from collections import defaultdict
|
2
8
|
|
3
|
-
|
9
|
+
import interegular
|
10
|
+
import outlines.caching
|
11
|
+
|
12
|
+
from sglang.srt.constrained import (
|
13
|
+
FSMInfo,
|
14
|
+
disk_cache,
|
15
|
+
make_byte_level_fsm,
|
16
|
+
make_deterministic_fsm,
|
17
|
+
)
|
4
18
|
from sglang.srt.constrained.base_cache import BaseCache
|
5
19
|
|
6
20
|
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
7
21
|
|
8
22
|
|
23
|
+
@dataclasses.dataclass
|
24
|
+
class JumpEdge:
|
25
|
+
symbol: str = None
|
26
|
+
symbol_next_state: int = None
|
27
|
+
byte: int = None
|
28
|
+
byte_next_state: int = None
|
29
|
+
|
30
|
+
|
9
31
|
class JumpForwardMap:
|
10
32
|
def __init__(self, regex_string):
|
11
33
|
@disk_cache()
|
12
34
|
def _init_state_to_jump_forward(regex_string):
|
13
35
|
regex_pattern = interegular.parse_pattern(regex_string)
|
14
|
-
|
36
|
+
|
37
|
+
byte_fsm = make_byte_level_fsm(
|
38
|
+
regex_pattern.to_fsm().reduce(), keep_utf8=True
|
39
|
+
)
|
40
|
+
regex_fsm, _ = make_deterministic_fsm(byte_fsm)
|
15
41
|
|
16
42
|
fsm_info: FSMInfo = regex_fsm.fsm_info
|
17
43
|
|
@@ -21,40 +47,93 @@ class JumpForwardMap:
|
|
21
47
|
id_to_symbol.setdefault(id_, []).append(symbol)
|
22
48
|
|
23
49
|
transitions = fsm_info.transitions
|
24
|
-
|
50
|
+
outgoings_ct = defaultdict(int)
|
25
51
|
state_to_jump_forward = {}
|
26
52
|
|
27
53
|
for (state, id_), next_state in transitions.items():
|
28
|
-
if
|
29
|
-
continue
|
30
|
-
if state in state_to_jump_forward:
|
31
|
-
dirty_states.add(state)
|
32
|
-
del state_to_jump_forward[state]
|
54
|
+
if id_ == fsm_info.alphabet_anything_value:
|
33
55
|
continue
|
34
|
-
|
35
|
-
|
56
|
+
symbols = id_to_symbol[id_]
|
57
|
+
for c in symbols:
|
58
|
+
if len(c) > 1:
|
59
|
+
# Skip byte level transitions
|
60
|
+
continue
|
61
|
+
|
62
|
+
outgoings_ct[state] += 1
|
63
|
+
if outgoings_ct[state] > 1:
|
64
|
+
if state in state_to_jump_forward:
|
65
|
+
del state_to_jump_forward[state]
|
66
|
+
break
|
67
|
+
|
68
|
+
state_to_jump_forward[state] = JumpEdge(
|
69
|
+
symbol=c,
|
70
|
+
symbol_next_state=next_state,
|
71
|
+
)
|
72
|
+
|
73
|
+
# Process the byte level jump forward
|
74
|
+
outgoings_ct = defaultdict(int)
|
75
|
+
for (state, id_), next_state in transitions.items():
|
76
|
+
if id_ == fsm_info.alphabet_anything_value:
|
36
77
|
continue
|
37
|
-
|
38
|
-
|
78
|
+
symbols = id_to_symbol[id_]
|
79
|
+
for c in symbols:
|
80
|
+
byte_ = None
|
81
|
+
if len(c) == 1 and ord(c) < 0x80:
|
82
|
+
# ASCII character
|
83
|
+
byte_ = ord(c)
|
84
|
+
elif len(c) > 1:
|
85
|
+
# FIXME: This logic is due to the leading \x00
|
86
|
+
# https://github.com/outlines-dev/outlines/pull/930
|
87
|
+
byte_ = int(symbols[0][1:], 16)
|
88
|
+
|
89
|
+
if byte_ is not None:
|
90
|
+
outgoings_ct[state] += 1
|
91
|
+
if outgoings_ct[state] > 1:
|
92
|
+
if state in state_to_jump_forward:
|
93
|
+
del state_to_jump_forward[state]
|
94
|
+
break
|
95
|
+
e = state_to_jump_forward.get(state, JumpEdge())
|
96
|
+
e.byte = byte_
|
97
|
+
e.byte_next_state = next_state
|
98
|
+
state_to_jump_forward[state] = e
|
39
99
|
|
40
100
|
return state_to_jump_forward
|
41
101
|
|
42
102
|
self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)
|
43
103
|
|
44
|
-
def
|
45
|
-
|
104
|
+
def jump_forward_symbol(self, state):
|
105
|
+
jump_forward_str = ""
|
106
|
+
next_state = state
|
107
|
+
while state in self.state_to_jump_forward:
|
108
|
+
e = self.state_to_jump_forward[state]
|
109
|
+
if e.symbol is None:
|
110
|
+
break
|
111
|
+
jump_forward_str += e.symbol
|
112
|
+
next_state = e.symbol_next_state
|
113
|
+
state = next_state
|
46
114
|
|
47
|
-
|
115
|
+
return jump_forward_str, next_state
|
116
|
+
|
117
|
+
def jump_forward_byte(self, state):
|
48
118
|
if state not in self.state_to_jump_forward:
|
49
119
|
return None
|
50
120
|
|
51
|
-
|
121
|
+
jump_forward_bytes = []
|
52
122
|
next_state = None
|
53
123
|
while state in self.state_to_jump_forward:
|
54
|
-
|
55
|
-
|
124
|
+
e = self.state_to_jump_forward[state]
|
125
|
+
assert e.byte is not None and e.byte_next_state is not None
|
126
|
+
jump_forward_bytes.append((e.byte, e.byte_next_state))
|
127
|
+
next_state = e.byte_next_state
|
56
128
|
state = next_state
|
57
|
-
|
129
|
+
|
130
|
+
return jump_forward_bytes
|
131
|
+
|
132
|
+
def is_jump_forward_symbol_state(self, state):
|
133
|
+
return (
|
134
|
+
state in self.state_to_jump_forward
|
135
|
+
and self.state_to_jump_forward[state].symbol is not None
|
136
|
+
)
|
58
137
|
|
59
138
|
|
60
139
|
class JumpForwardCache(BaseCache):
|
@@ -65,12 +144,21 @@ class JumpForwardCache(BaseCache):
|
|
65
144
|
return JumpForwardMap(regex)
|
66
145
|
|
67
146
|
|
68
|
-
def test_main():
|
69
|
-
regex_string = r"The google's DNS sever address is " + IP_REGEX
|
147
|
+
def test_main(regex_string):
|
70
148
|
jump_forward_map = JumpForwardMap(regex_string)
|
71
|
-
for state in jump_forward_map.
|
72
|
-
|
149
|
+
for state, e in jump_forward_map.state_to_jump_forward.items():
|
150
|
+
if e.symbol is not None:
|
151
|
+
jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
|
152
|
+
print(f"{state} -> {next_state}", jump_forward_str)
|
153
|
+
bytes_ = jump_forward_map.jump_forward_byte(state)
|
154
|
+
print(f"{state} -> {bytes_[-1][1]}", [hex(b) for b, _ in bytes_])
|
73
155
|
|
74
156
|
|
75
157
|
if __name__ == "__main__":
|
76
|
-
|
158
|
+
import outlines
|
159
|
+
|
160
|
+
outlines.caching.clear_cache()
|
161
|
+
test_main(r"The google's DNS sever address is " + IP_REGEX)
|
162
|
+
test_main(r"霍格沃茨特快列车|霍比特人比尔博")
|
163
|
+
# 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
|
164
|
+
# 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...
|