sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +59 -2
- sglang/api.py +40 -11
- sglang/backend/anthropic.py +17 -3
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +160 -12
- sglang/backend/runtime_endpoint.py +62 -27
- sglang/backend/vertexai.py +1 -0
- sglang/bench_latency.py +320 -0
- sglang/global_config.py +24 -3
- sglang/lang/chat_template.py +122 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +206 -98
- sglang/lang/ir.py +98 -34
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +32 -0
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +9 -2
- sglang/srt/constrained/jump_forward.py +113 -24
- sglang/srt/conversation.py +4 -2
- sglang/srt/flush_cache.py +18 -0
- sglang/srt/hf_transformers_utils.py +144 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +20 -1
- sglang/srt/layers/fused_moe.py +596 -0
- sglang/srt/layers/logits_processor.py +190 -61
- sglang/srt/layers/radix_attention.py +62 -53
- sglang/srt/layers/token_attention.py +21 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/controller/infer_batch.py +908 -0
- sglang/srt/managers/controller/manager_multi.py +195 -0
- sglang/srt/managers/controller/manager_single.py +177 -0
- sglang/srt/managers/controller/model_runner.py +359 -0
- sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
- sglang/srt/managers/controller/schedule_heuristic.py +65 -0
- sglang/srt/managers/controller/tp_worker.py +813 -0
- sglang/srt/managers/detokenizer_manager.py +42 -40
- sglang/srt/managers/io_struct.py +44 -10
- sglang/srt/managers/tokenizer_manager.py +224 -82
- sglang/srt/memory_pool.py +52 -59
- sglang/srt/model_config.py +97 -2
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +369 -0
- sglang/srt/models/dbrx.py +406 -0
- sglang/srt/models/gemma.py +34 -38
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +47 -37
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +92 -27
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/minicpm.py +366 -0
- sglang/srt/models/mixtral.py +302 -127
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +40 -35
- sglang/srt/models/qwen2.py +33 -36
- sglang/srt/models/qwen2_moe.py +473 -0
- sglang/srt/models/stablelm.py +33 -39
- sglang/srt/models/yivl.py +19 -26
- sglang/srt/openai_api_adapter.py +411 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +197 -481
- sglang/srt/server_args.py +190 -74
- sglang/srt/utils.py +460 -95
- sglang/test/test_programs.py +73 -10
- sglang/test/test_utils.py +226 -7
- sglang/utils.py +97 -27
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
- sglang-0.1.21.dist-info/RECORD +82 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/managers/router/infer_batch.py +0 -503
- sglang/srt/managers/router/manager.py +0 -79
- sglang/srt/managers/router/model_rpc.py +0 -686
- sglang/srt/managers/router/model_runner.py +0 -514
- sglang/srt/managers/router/scheduler.py +0 -70
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/lang/ir.py
CHANGED
@@ -23,6 +23,10 @@ class SglSamplingParams:
     frequency_penalty: float = 0.0
     presence_penalty: float = 0.0
     ignore_eos: bool = False
+    return_logprob: Optional[bool] = None
+    logprob_start_len: Optional[int] = (None,)
+    top_logprobs_num: Optional[int] = (None,)
+    return_text_in_logprobs: Optional[bool] = (None,)
 
     # for constrained generation, not included in to_xxx_kwargs
     dtype: Optional[str] = None
@@ -37,6 +41,11 @@ class SglSamplingParams:
             self.top_k,
             self.frequency_penalty,
             self.presence_penalty,
+            self.ignore_eos,
+            self.return_logprob,
+            self.logprob_start_len,
+            self.top_logprobs_num,
+            self.return_text_in_logprobs,
         )
 
     def to_openai_kwargs(self):
@@ -82,6 +91,19 @@ class SglSamplingParams:
             "top_k": self.top_k,
         }
 
+    def to_litellm_kwargs(self):
+        if self.regex is not None:
+            warnings.warn("Regular expression is not supported in the LiteLLM backend.")
+        return {
+            "max_tokens": self.max_new_tokens,
+            "stop": self.stop or None,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k": self.top_k,
+            "frequency_penalty": self.frequency_penalty,
+            "presence_penalty": self.presence_penalty,
+        }
+
     def to_srt_kwargs(self):
         return {
             "max_new_tokens": self.max_new_tokens,
@@ -97,9 +119,9 @@ class SglSamplingParams:
 
 
 class SglFunction:
-    def __init__(self, func,
+    def __init__(self, func, num_api_spec_tokens=None, bind_arguments=None):
         self.func = func
-        self.
+        self.num_api_spec_tokens = num_api_spec_tokens
         self.bind_arguments = bind_arguments or {}
         self.pin_prefix_rid = None
 
@@ -107,6 +129,7 @@ class SglFunction:
         argspec = inspect.getfullargspec(func)
         assert argspec.args[0] == "s", 'The first argument must be "s"'
         self.arg_names = argspec.args[1:]
+        self.arg_defaults = argspec.defaults if argspec.defaults is not None else []
 
     def bind(self, **kwargs):
         assert all(key in self.arg_names for key in kwargs)
@@ -125,6 +148,10 @@ class SglFunction:
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         ignore_eos: bool = False,
+        return_logprob: Optional[bool] = None,
+        logprob_start_len: Optional[int] = None,
+        top_logprobs_num: Optional[int] = None,
+        return_text_in_logprobs: Optional[bool] = None,
         stream: bool = False,
         backend=None,
         **kwargs,
@@ -140,6 +167,10 @@ class SglFunction:
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             ignore_eos=ignore_eos,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            return_text_in_logprobs=return_text_in_logprobs,
         )
         backend = backend or global_config.default_backend
         return run_program(self, backend, args, kwargs, default_sampling_para, stream)
@@ -156,6 +187,10 @@ class SglFunction:
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         ignore_eos: bool = False,
+        return_logprob: Optional[bool] = None,
+        logprob_start_len: Optional[int] = None,
+        top_logprobs_num: Optional[int] = None,
+        return_text_in_logprobs: Optional[bool] = None,
         backend=None,
         num_threads: Union[str, int] = "auto",
         progress_bar: bool = False,
@@ -165,7 +200,20 @@ class SglFunction:
         assert isinstance(batch_kwargs, (list, tuple))
         if len(batch_kwargs) == 0:
             return []
-
+        if not isinstance(batch_kwargs[0], dict):
+            num_programs = len(batch_kwargs)
+            # change the list of argument values to dict of arg_name -> arg_value
+            batch_kwargs = [
+                {self.arg_names[i]: v for i, v in enumerate(arg_values)}
+                for arg_values in batch_kwargs
+                if isinstance(arg_values, (list, tuple))
+                and len(self.arg_names) - len(self.arg_defaults)
+                <= len(arg_values)
+                <= len(self.arg_names)
+            ]
+            # Ensure to raise an exception if the number of arguments mismatch
+            if len(batch_kwargs) != num_programs:
+                raise Exception("Given arguments mismatch the SGL function signature")
 
         default_sampling_para = SglSamplingParams(
             max_new_tokens=max_new_tokens,
@@ -176,6 +224,10 @@ class SglFunction:
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             ignore_eos=ignore_eos,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            return_text_in_logprobs=return_text_in_logprobs,
         )
         backend = backend or global_config.default_backend
         return run_program_batch(
@@ -193,17 +245,11 @@ class SglFunction:
         backend = backend or global_config.default_backend
         return trace_program(self, kwargs, backend)
 
-    def
-        from sglang.lang.interpreter import
+    def cache(self, backend=None):
+        from sglang.lang.interpreter import cache_program
 
         backend = backend or global_config.default_backend
-        return
-
-    def unpin(self, backend=None):
-        from sglang.lang.interpreter import unpin_program
-
-        backend = backend or global_config.default_backend
-        return unpin_program(self, backend)
+        return cache_program(self, backend)
 
     def compile(self, *, backend=None):
         from sglang.lang.compiler import compile_func
@@ -329,28 +375,42 @@ class SglArgument(SglExpr):
 
 
 class SglImage(SglExpr):
-    def __init__(self, path):
+    def __init__(self, path: str):
         self.path = path
 
     def __repr__(self) -> str:
         return f"SglImage({self.path})"
 
 
+class SglVideo(SglExpr):
+    def __init__(self, path: str, num_frames: int):
+        self.path = path
+        self.num_frames = num_frames
+
+    def __repr__(self) -> str:
+        return f"SglVideo({self.path}, {self.num_frames})"
+
+
 class SglGen(SglExpr):
     def __init__(
         self,
-        name,
-        max_new_tokens,
-        stop,
-        temperature,
-        top_p,
-        top_k,
-        frequency_penalty,
-        presence_penalty,
-        ignore_eos,
-
-
+        name: Optional[str] = None,
+        max_new_tokens: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        frequency_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        ignore_eos: Optional[bool] = None,
+        return_logprob: Optional[bool] = None,
+        logprob_start_len: Optional[int] = None,
+        top_logprobs_num: Optional[int] = None,
+        return_text_in_logprobs: Optional[bool] = None,
+        dtype: Optional[type] = None,
+        regex: Optional[str] = None,
     ):
+        """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
         super().__init__()
         self.name = name
         self.sampling_params = SglSamplingParams(
@@ -362,6 +422,10 @@ class SglGen(SglExpr):
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             ignore_eos=ignore_eos,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            return_text_in_logprobs=return_text_in_logprobs,
             dtype=dtype,
             regex=regex,
         )
@@ -371,7 +435,7 @@ class SglGen(SglExpr):
 
 
 class SglConstantText(SglExpr):
-    def __init__(self, value):
+    def __init__(self, value: str):
         super().__init__()
         self.value = value
 
@@ -380,7 +444,7 @@ class SglConstantText(SglExpr):
 
 
 class SglRoleBegin(SglExpr):
-    def __init__(self, role):
+    def __init__(self, role: str):
         super().__init__()
         self.role = role
 
@@ -389,7 +453,7 @@ class SglRoleBegin(SglExpr):
 
 
 class SglRoleEnd(SglExpr):
-    def __init__(self, role):
+    def __init__(self, role: str):
         super().__init__()
         self.role = role
 
@@ -398,7 +462,7 @@ class SglRoleEnd(SglExpr):
 
 
 class SglSelect(SglExpr):
-    def __init__(self, name, choices, temperature):
+    def __init__(self, name: str, choices: List[str], temperature: float):
         super().__init__()
         self.name = name
         self.choices = choices
@@ -409,7 +473,7 @@ class SglSelect(SglExpr):
 
 
 class SglFork(SglExpr):
-    def __init__(self, number, position_ids_offset=None):
+    def __init__(self, number: int, position_ids_offset=None):
         super().__init__()
         self.number = number
         self.position_ids_offset = position_ids_offset
@@ -422,7 +486,7 @@ class SglFork(SglExpr):
 
 
 class SglGetForkItem(SglExpr):
-    def __init__(self, index):
+    def __init__(self, index: int):
         super().__init__()
         self.index = index
 
@@ -431,7 +495,7 @@ class SglGetForkItem(SglExpr):
 
 
 class SglVariable(SglExpr):
-    def __init__(self, name, source):
+    def __init__(self, name: str, source):
         super().__init__()
         self.name = name
         self.source = source
@@ -441,7 +505,7 @@ class SglVariable(SglExpr):
 
 
 class SglVarScopeBegin(SglExpr):
-    def __init__(self, name):
+    def __init__(self, name: str):
         super().__init__()
         self.name = name
 
@@ -450,7 +514,7 @@ class SglVarScopeBegin(SglExpr):
 
 
 class SglVarScopeEnd(SglExpr):
-    def __init__(self, name):
+    def __init__(self, name: str):
         super().__init__()
         self.name = name
 
@@ -472,4 +536,4 @@ class SglCommitLazy(SglExpr):
         super().__init__()
 
     def __repr__(self):
-        return
+        return "CommitLazy()"
sglang/lang/tracer.py
CHANGED
@@ -109,19 +109,21 @@ class TracerProgramState(ProgramState):
     ########### Public API ###########
     ##################################
 
-    def fork(self,
+    def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
+        assert size >= 1
+
         if self.only_trace_prefix:
             raise StopTracing()
 
-        fork_node = SglFork(
+        fork_node = SglFork(size)
         fork_node.prev_node = self.last_node
 
         states = [
             TracerProgramState(self.backend, self.arguments, self.only_trace_prefix)
-            for _ in range(
+            for _ in range(size)
         ]
 
-        for i in range(
+        for i in range(size):
             node = SglGetForkItem(i)
             node.prev_node = fork_node
             states[i].last_node = node
sglang/launch_server.py
CHANGED
@@ -1,6 +1,9 @@
+"""Launch the inference server."""
+
 import argparse
 
-from sglang.srt.server import
+from sglang.srt.server import launch_server
+from sglang.srt.server_args import ServerArgs
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
sglang/launch_server_llavavid.py
ADDED
@@ -0,0 +1,32 @@
+"""Launch the inference server for Llava-video model."""
+
+import argparse
+import multiprocessing as mp
+
+from sglang.srt.server import ServerArgs, launch_server
+
+if __name__ == "__main__":
+    model_overide_args = {}
+
+    model_overide_args["mm_spatial_pool_stride"] = 2
+    model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
+    model_overide_args["num_frames"] = 16
+    model_overide_args["model_type"] = "llavavid"
+    if model_overide_args["num_frames"] == 32:
+        model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
+        model_overide_args["max_sequence_length"] = 4096 * 2
+        model_overide_args["tokenizer_model_max_length"] = 4096 * 2
+        model_overide_args["model_max_length"] = 4096 * 2
+
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    args = parser.parse_args()
+
+    if "34b" in args.model_path.lower():
+        model_overide_args["image_token_index"] = 64002
+
+    server_args = ServerArgs.from_cli_args(args)
+
+    pipe_reader, pipe_writer = mp.Pipe(duplex=False)
+
+    launch_server(server_args, pipe_writer, model_overide_args)
sglang/srt/constrained/__init__.py
CHANGED
@@ -1,13 +1,20 @@
 import json
 from typing import Dict, Optional, Union
 
-from outlines.caching import cache as disk_cache
-from outlines.caching import disable_cache
-from outlines.fsm.fsm import RegexFSM
-from outlines.fsm.regex import FSMInfo, make_deterministic_fsm
-from outlines.models.transformers import TransformerTokenizer
 from pydantic import BaseModel
 
+try:
+    from outlines.caching import cache as disk_cache
+    from outlines.caching import disable_cache
+    from outlines.fsm.guide import RegexGuide
+    from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
+    from outlines.models.transformers import TransformerTokenizer
+except ImportError as e:
+    print(
+        f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n'
+    )
+    raise
+
 try:
     from outlines.fsm.json_schema import build_regex_from_object
 except ImportError:
@@ -28,11 +35,12 @@ except ImportError:
 
 
 __all__ = [
-    "
+    "RegexGuide",
     "FSMInfo",
     "make_deterministic_fsm",
     "build_regex_from_object",
     "TransformerTokenizer",
     "disk_cache",
     "disable_cache",
+    "make_byte_level_fsm",
 ]
sglang/srt/constrained/fsm_cache.py
CHANGED
@@ -1,4 +1,6 @@
-
+"""Cache for the compressed finite state machine."""
+
+from sglang.srt.constrained import RegexGuide, TransformerTokenizer
 from sglang.srt.constrained.base_cache import BaseCache
 
 
@@ -6,7 +8,12 @@ class FSMCache(BaseCache):
     def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
         super().__init__(enable=enable)
 
+        if tokenizer_path.endswith(".json") or tokenizer_path.endswith(".model"):
+            # Do not support TiktokenTokenizer or SentencePieceTokenizer
+            return
+
         from importlib.metadata import version
+
         if version("outlines") >= "0.0.35":
             from transformers import AutoTokenizer
 
@@ -21,4 +28,4 @@ class FSMCache(BaseCache):
         )
 
     def init_value(self, regex):
-        return
+        return RegexGuide(regex, self.outlines_tokenizer)
sglang/srt/constrained/jump_forward.py
CHANGED
@@ -1,16 +1,43 @@
+"""
+Faster constrained decoding.
+Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
+"""
+
+import dataclasses
+from collections import defaultdict
+
 import interegular
-
+import outlines.caching
+
+from sglang.srt.constrained import (
+    FSMInfo,
+    disk_cache,
+    make_byte_level_fsm,
+    make_deterministic_fsm,
+)
 from sglang.srt.constrained.base_cache import BaseCache
 
 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
 
 
+@dataclasses.dataclass
+class JumpEdge:
+    symbol: str = None
+    symbol_next_state: int = None
+    byte: int = None
+    byte_next_state: int = None
+
+
 class JumpForwardMap:
     def __init__(self, regex_string):
         @disk_cache()
         def _init_state_to_jump_forward(regex_string):
             regex_pattern = interegular.parse_pattern(regex_string)
-
+
+            byte_fsm = make_byte_level_fsm(
+                regex_pattern.to_fsm().reduce(), keep_utf8=True
+            )
+            regex_fsm, _ = make_deterministic_fsm(byte_fsm)
 
             fsm_info: FSMInfo = regex_fsm.fsm_info
 
@@ -20,40 +47,93 @@ class JumpForwardMap:
                 id_to_symbol.setdefault(id_, []).append(symbol)
 
             transitions = fsm_info.transitions
-
+            outgoings_ct = defaultdict(int)
             state_to_jump_forward = {}
 
             for (state, id_), next_state in transitions.items():
-                if
+                if id_ == fsm_info.alphabet_anything_value:
                     continue
-
-
-
-
-
-
+                symbols = id_to_symbol[id_]
+                for c in symbols:
+                    if len(c) > 1:
+                        # Skip byte level transitions
+                        continue
+
+                    outgoings_ct[state] += 1
+                    if outgoings_ct[state] > 1:
+                        if state in state_to_jump_forward:
+                            del state_to_jump_forward[state]
+                        break
+
+                    state_to_jump_forward[state] = JumpEdge(
+                        symbol=c,
+                        symbol_next_state=next_state,
+                    )
+
+            # Process the byte level jump forward
+            outgoings_ct = defaultdict(int)
+            for (state, id_), next_state in transitions.items():
+                if id_ == fsm_info.alphabet_anything_value:
                     continue
-
-
+                symbols = id_to_symbol[id_]
+                for c in symbols:
+                    byte_ = None
+                    if len(c) == 1 and ord(c) < 0x80:
+                        # ASCII character
+                        byte_ = ord(c)
+                    elif len(c) > 1:
+                        # FIXME: This logic is due to the leading \x00
+                        # https://github.com/outlines-dev/outlines/pull/930
+                        byte_ = int(symbols[0][1:], 16)
+
+                    if byte_ is not None:
+                        outgoings_ct[state] += 1
+                        if outgoings_ct[state] > 1:
+                            if state in state_to_jump_forward:
+                                del state_to_jump_forward[state]
+                            break
+                        e = state_to_jump_forward.get(state, JumpEdge())
+                        e.byte = byte_
+                        e.byte_next_state = next_state
+                        state_to_jump_forward[state] = e
 
             return state_to_jump_forward
 
         self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)
 
-    def
-
+    def jump_forward_symbol(self, state):
+        jump_forward_str = ""
+        next_state = state
+        while state in self.state_to_jump_forward:
+            e = self.state_to_jump_forward[state]
+            if e.symbol is None:
+                break
+            jump_forward_str += e.symbol
+            next_state = e.symbol_next_state
+            state = next_state
+
+        return jump_forward_str, next_state
 
-    def
+    def jump_forward_byte(self, state):
         if state not in self.state_to_jump_forward:
             return None
 
-
+        jump_forward_bytes = []
         next_state = None
         while state in self.state_to_jump_forward:
-
-
+            e = self.state_to_jump_forward[state]
+            assert e.byte is not None and e.byte_next_state is not None
+            jump_forward_bytes.append((e.byte, e.byte_next_state))
+            next_state = e.byte_next_state
             state = next_state
-
+
+        return jump_forward_bytes
+
+    def is_jump_forward_symbol_state(self, state):
+        return (
+            state in self.state_to_jump_forward
+            and self.state_to_jump_forward[state].symbol is not None
+        )
 
 
 class JumpForwardCache(BaseCache):
@@ -64,12 +144,21 @@ class JumpForwardCache(BaseCache):
         return JumpForwardMap(regex)
 
 
-def test_main():
-    regex_string = r"The google's DNS sever address is " + IP_REGEX
+def test_main(regex_string):
     jump_forward_map = JumpForwardMap(regex_string)
-    for state in jump_forward_map.
-
+    for state, e in jump_forward_map.state_to_jump_forward.items():
+        if e.symbol is not None:
+            jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
+            print(f"{state} -> {next_state}", jump_forward_str)
+        bytes_ = jump_forward_map.jump_forward_byte(state)
+        print(f"{state} -> {bytes_[-1][1]}", [hex(b) for b, _ in bytes_])
 
 
 if __name__ == "__main__":
-
+    import outlines
+
+    outlines.caching.clear_cache()
+    test_main(r"The google's DNS sever address is " + IP_REGEX)
+    test_main(r"霍格沃茨特快列车|霍比特人比尔博")
+    # 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
+    # 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...
sglang/srt/conversation.py
CHANGED
@@ -1,10 +1,12 @@
+"""Conversation templates."""
+
 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
 import dataclasses
 from enum import IntEnum, auto
 from typing import Dict, List, Optional, Tuple, Union
 
-from sglang.srt.
+from sglang.srt.openai_protocol import ChatCompletionRequest
 
 
 class SeparatorStyle(IntEnum):
@@ -400,7 +402,7 @@ register_conv_template(
     Conversation(
         name="chatml",
         system_template="<|im_start|>system\n{system_message}",
-        system_message="You are
+        system_message="You are a helpful assistant.",
         roles=("<|im_start|>user", "<|im_start|>assistant"),
         sep_style=SeparatorStyle.CHATML,
         sep="<|im_end|>",
sglang/srt/flush_cache.py
ADDED
@@ -0,0 +1,18 @@
+"""
+Flush the KV cache.
+
+Usage:
+python3 -m sglang.srt.flush_cache --url http://localhost:30000
+"""
+
+import argparse
+
+import requests
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--url", type=str, default="http://localhost:30000")
+    args = parser.parse_args()
+
+    response = requests.get(args.url + "/flush_cache")
+    assert response.status_code == 200