sglang 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
- sglang/api.py +7 -1
- sglang/bench_latency.py +9 -6
- sglang/bench_serving.py +46 -22
- sglang/global_config.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +60 -49
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +4 -2
- sglang/lang/ir.py +16 -7
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/constrained/jump_forward.py +13 -2
- sglang/srt/layers/activation.py +32 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +9 -2
- sglang/srt/layers/fused_moe/__init__.py +1 -0
- sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
- sglang/srt/layers/fused_moe/layer.py +587 -0
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +7 -2
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +40 -16
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +115 -97
- sglang/srt/managers/tokenizer_manager.py +194 -112
- sglang/srt/managers/tp_worker.py +290 -359
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +71 -25
- sglang/srt/model_executor/forward_batch_info.py +293 -156
- sglang/srt/model_executor/model_runner.py +77 -57
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/deepseek.py +2 -2
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +11 -6
- sglang/srt/models/grok.py +50 -396
- sglang/srt/models/internlm2.py +2 -7
- sglang/srt/models/llama2.py +4 -4
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/mixtral.py +56 -254
- sglang/srt/models/mixtral_quant.py +1 -4
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_moe.py +2 -13
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +187 -48
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -8
- sglang/srt/server.py +91 -29
- sglang/srt/server_args.py +32 -19
- sglang/srt/utils.py +32 -15
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +81 -73
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +36 -7
- sglang/test/test_utils.py +24 -2
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/METADATA +33 -16
- sglang-0.2.13.dist-info/RECORD +112 -0
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang/srt/model_loader/model_loader.py +0 -292
- sglang/srt/model_loader/utils.py +0 -275
- sglang-0.2.11.dist-info/RECORD +0 -102
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
sglang/api.py
CHANGED
@@ -62,6 +62,7 @@ def gen(
     name: Optional[str] = None,
     max_tokens: Optional[int] = None,
     stop: Optional[Union[str, List[str]]] = None,
+    stop_token_ids: Optional[List[int]] = None,
     temperature: Optional[float] = None,
     top_p: Optional[float] = None,
     top_k: Optional[int] = None,
@@ -72,7 +73,7 @@ def gen(
     logprob_start_len: Optional[int] = None,
     top_logprobs_num: Optional[int] = None,
     return_text_in_logprobs: Optional[bool] = None,
-    dtype: Optional[type] = None,
+    dtype: Optional[Union[type, str]] = None,
     choices: Optional[List[str]] = None,
     choices_method: Optional[ChoicesSamplingMethod] = None,
     regex: Optional[str] = None,
@@ -98,6 +99,7 @@ def gen(
        name,
        max_tokens,
        stop,
+       stop_token_ids,
        temperature,
        top_p,
        top_k,
@@ -117,6 +119,7 @@ def gen_int(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
+   stop_token_ids: Optional[List[int]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
@@ -132,6 +135,7 @@ def gen_int(
        name,
        max_tokens,
        stop,
+       stop_token_ids,
        temperature,
        top_p,
        top_k,
@@ -151,6 +155,7 @@ def gen_string(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
+   stop_token_ids: Optional[List[int]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
@@ -166,6 +171,7 @@ def gen_string(
        name,
        max_tokens,
        stop,
+       stop_token_ids,
        temperature,
        top_p,
        top_k,
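Note: the new stop_token_ids parameter is threaded through gen, gen_int, and gen_string alongside the existing stop strings, and dtype now also accepts a string name. A minimal sketch of a caller, assuming a runtime backend is already configured (the token id 128009 is an illustrative EOS id, not taken from this diff):

    import sglang as sgl

    @sgl.function
    def qa(s, question):
        s += question
        # Stop on a stop string or on an explicit token id.
        s += sgl.gen("answer", max_tokens=64, stop=["\n"], stop_token_ids=[128009])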
sglang/bench_latency.py
CHANGED
@@ -64,7 +64,7 @@ class BenchArgs:
     run_name: str = "before"
     batch_size: Tuple[int] = (1,)
     input_len: Tuple[int] = (1024,)
-    output_len: Tuple[int] = (
+    output_len: Tuple[int] = (16,)
     result_filename: str = ""
     correctness_test: bool = False
     # This is only used for correctness test
@@ -152,7 +152,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer)
        req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
        req.prefix_indices = []
        req.sampling_params = sampling_params
-       req.
+       req.fill_ids = req.origin_input_ids
        reqs.append(req)

    return input_ids, reqs
@@ -163,7 +163,7 @@ def prepare_extend_inputs_for_correctness_test(
 ):
    for i in range(len(reqs)):
        req = reqs[i]
-       req.
+       req.fill_ids += input_ids[i][bench_args.cut_len :]
        req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
            i, : bench_args.cut_len
        ]
@@ -182,7 +182,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
        req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
        req.prefix_indices = []
        req.sampling_params = sampling_params
-       req.
+       req.fill_ids = req.origin_input_ids
        reqs.append(req)

    return reqs
@@ -195,7 +195,7 @@ def extend(reqs, model_runner):
        token_to_kv_pool=model_runner.token_to_kv_pool,
        tree_cache=None,
    )
-   batch.prepare_for_extend(model_runner.model_config.vocab_size
+   batch.prepare_for_extend(model_runner.model_config.vocab_size)
    output = model_runner.forward(batch, ForwardMode.EXTEND)
    next_token_ids = batch.sample(output.next_token_logits)
    return next_token_ids, output.next_token_logits, batch
@@ -221,6 +221,7 @@ def correctness_test(

    # Prepare inputs
    input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
+   rank_print(f"{input_ids=}")

    if bench_args.cut_len > 0:
        # Prefill
@@ -238,7 +239,7 @@

    # Decode
    output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
-   for _ in range(bench_args.output_len):
+   for _ in range(bench_args.output_len[0]):
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        for i in range(len(reqs)):
            output_ids[i].append(next_token_ids[i])
@@ -332,6 +333,7 @@ def latency_test(
    )

    # Warm up
+   rank_print("Warmup ...")
    latency_test_run_once(
        bench_args.run_name,
        model_runner,
@@ -341,6 +343,7 @@
        bench_args.input_len[0],
        4,  # shorter decoding to speed up the warmup
    )
+   rank_print("Benchmark ...")

    # Run the sweep
    result_list = []
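Note: BenchArgs.output_len is now a proper one-element tuple, matching batch_size and input_len, and scalar use sites index it, as the decode loop above does. A toy restatement of the convention:

    from dataclasses import dataclass
    from typing import Tuple

    @dataclass
    class BenchArgs:  # abbreviated copy of the fields above
        batch_size: Tuple[int] = (1,)
        input_len: Tuple[int] = (1024,)
        output_len: Tuple[int] = (16,)

    args = BenchArgs()
    for _ in range(args.output_len[0]):  # 16 decode iterations
        ...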
sglang/bench_serving.py
CHANGED
@@ -24,7 +24,7 @@ import warnings
 from argparse import ArgumentParser
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import AsyncGenerator, List, Optional, Tuple, Union
+from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

 import aiohttp
 import numpy as np
@@ -39,6 +39,8 @@ from transformers import (

 AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)

+global args
+

 @dataclass
 class RequestFuncInput:
@@ -47,6 +49,7 @@ class RequestFuncInput:
     prompt_len: int
     output_len: int
     model: str
+    extra_request_body: Dict[str, Any]


 @dataclass
@@ -84,6 +87,7 @@ async def async_request_trt_llm(
        "stream": True,
        "min_length": request_func_input.output_len,
        "end_id": 1048576,
+       **request_func_input.extra_request_body,
    }
    if args.disable_ignore_eos:
        del payload["min_length"]
@@ -154,6 +158,7 @@ async def async_request_openai_completions(
        "max_tokens": request_func_input.output_len,
        "stream": not args.disable_stream,
        "ignore_eos": not args.disable_ignore_eos,
+       **request_func_input.extra_request_body,
    }
    headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}

@@ -192,7 +197,8 @@
                            output.ttft = ttft

                        # Decoding phase
-
+                        else:
+                            output.itl.append(timestamp - most_recent_timestamp)

                        most_recent_timestamp = timestamp
                        generated_text += data["choices"][0]["text"]
@@ -542,6 +548,7 @@ async def benchmark(
    request_rate: float,
    disable_tqdm: bool,
    enable_multi: bool,
+   extra_request_body: Dict[str, Any],
 ):
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -556,6 +563,7 @@
        api_url=api_url,
        prompt_len=test_prompt_len,
        output_len=test_output_len,
+       extra_request_body=extra_request_body,
    )
    test_output = await request_func(request_func_input=test_input)
    if not test_output.success:
@@ -578,6 +586,7 @@
            api_url=api_url,
            prompt_len=prompt_len,
            output_len=output_len,
+           extra_request_body=extra_request_body,
        )
        tasks.append(
            asyncio.create_task(
@@ -660,19 +669,20 @@
            "backend": args.backend,
            "dataset_name": args.dataset_name,
            "request_rate": request_rate,
-           "
-           "
-           "
-           "
-           "
-           "
-           "
-           "
+           "total_input_tokens": metrics.total_input,
+           "total_output_tokens": metrics.total_output,
+           "total_output_tokens_retokenized": metrics.total_output_retokenized,
+           "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
+           "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
+           "median_ttft_ms": metrics.median_ttft_ms,
+           "median_itl_ms": metrics.median_itl_ms,
+           "output_throughput": metrics.output_throughput,
            "sharegpt_output_len": args.sharegpt_output_len,
            "random_input_len": args.random_input_len,
            "random_output_len": args.random_output_len,
            "random_range_ratio": args.random_range_ratio,
-           "
+           "duration": benchmark_duration,
+           "completed": metrics.completed,
        }
    else:
        print(f"Error running benchmark for request rate: {request_rate}")
@@ -742,10 +752,18 @@ def check_chat_template(model_path):
        return False


-def
+def run_benchmark(args_: argparse.Namespace):
+   global args
+   args = args_
+
+   set_ulimit()
    random.seed(args.seed)
    np.random.seed(args.seed)

+   extra_request_body = {}
+   if args.extra_request_body:
+       extra_request_body = json.loads(args.extra_request_body)
+
    if args.port is None:
        args.port = {
            "sglang": 30000,
@@ -838,10 +856,11 @@ def fire(args: argparse.Namespace):
                    request_rate=rate,
                    disable_tqdm=args.disable_tqdm,
                    enable_multi=args.multi,
+                   extra_request_body=extra_request_body,
                )
            )
    else:
-       asyncio.run(
+       return asyncio.run(
            benchmark(
                backend=backend,
                api_url=api_url,
@@ -851,6 +870,7 @@
                request_rate=args.request_rate,
                disable_tqdm=args.disable_tqdm,
                enable_multi=args.multi,
+               extra_request_body=extra_request_body,
            )
        )

@@ -949,11 +969,6 @@ if __name__ == "__main__":
        "Otherwise, we use Poisson process to synthesize the request arrival times. Default is 128.0.",
    )
    parser.add_argument("--seed", type=int, default=0, help="Default is 0.")
-   parser.add_argument(
-       "--disable-tqdm",
-       action="store_true",
-       help="Specify to disable tqdm progress bar.",
-   )
    parser.add_argument(
        "--multi",
        action="store_true",
@@ -966,6 +981,11 @@
        help="Range of request rates in the format start,stop,step. Default is 2,34,2. It also supports a list of request rates, requiring the parameters to not equal three.",
    )
    parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
+   parser.add_argument(
+       "--disable-tqdm",
+       action="store_true",
+       help="Specify to disable tqdm progress bar.",
+   )
    parser.add_argument(
        "--disable-stream",
        action="store_true",
@@ -976,8 +996,12 @@
        action="store_true",
        help="Disable ignoring EOS.",
    )
-
-
-
+   parser.add_argument(
+       "--extra-request-body",
+       metavar='{"key1": "value1", "key2": "value2"}',
+       type=str,
+       help="Append given JSON object to the request payload. You can use this to specify"
+       "additional generate params like sampling params.",
+   )
    args = parser.parse_args()
-
+   run_benchmark(args)
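Note: --extra-request-body takes a JSON object that is parsed once with json.loads and spliced into every request payload via **request_func_input.extra_request_body, and run_benchmark is a new programmatic entry point returning the result of asyncio.run(benchmark(...)). A sketch of the merge semantics (the sampling keys are illustrative, not from this diff):

    import json

    extra_request_body = json.loads('{"temperature": 0.0, "top_p": 0.9}')
    payload = {
        "prompt": "Hello",
        "max_tokens": 32,
        **extra_request_body,  # later keys win, so extras can override defaults
    }

From the command line this would look like, for example:

    python3 -m sglang.bench_serving --backend sglang --extra-request-body '{"top_p": 0.9}'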
sglang/global_config.py
CHANGED
@@ -27,7 +27,7 @@ class GlobalConfig:
        # Runtime constants: others
        self.num_continue_decode_steps = 10
        self.retract_decode_steps = 20
-       self.flashinfer_workspace_size =
+       self.flashinfer_workspace_size = 384 * 1024 * 1024

        # Output tokenization configs
        self.skip_special_tokens_in_output = True
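Note: the flashinfer workspace size is now spelled out as 384 * 1024 * 1024 bytes (384 MiB); the truncated removed line hides the previous value. A sketch of how such a constant is typically consumed when allocating the workspace buffer (the allocation itself is not part of this diff):

    import torch

    workspace_size = 384 * 1024 * 1024  # bytes, == 402,653,184
    workspace_buffer = torch.empty(workspace_size, dtype=torch.uint8, device="cuda")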
sglang/lang/backend/runtime_endpoint.py
CHANGED
@@ -1,21 +1,23 @@
 import json
+import warnings
 from typing import List, Optional

 from sglang.global_config import global_config
 from sglang.lang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template_by_model_path
-from sglang.lang.choices import (
-    ChoicesDecision,
-    ChoicesSamplingMethod,
-    token_length_normalized,
-)
+from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
 from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import
+from sglang.lang.ir import (
+    REGEX_BOOL,
+    REGEX_FLOAT,
+    REGEX_INT,
+    REGEX_STR,
+    SglSamplingParams,
+)
 from sglang.utils import http_request


 class RuntimeEndpoint(BaseBackend):
-
     def __init__(
         self,
         base_url: str,
@@ -95,32 +97,52 @@ class RuntimeEndpoint(BaseBackend):
        )
        self._assert_success(res)

+   def _handle_dtype_to_regex(self, sampling_params: SglSamplingParams):
+       if sampling_params.dtype is None:
+           return
+
+       if sampling_params.stop == ():
+           sampling_params.stop = []
+
+       dtype_regex = None
+       if sampling_params.dtype in ["int", int]:
+
+           dtype_regex = REGEX_INT
+           sampling_params.stop.extend([" ", "\n"])
+       elif sampling_params.dtype in ["float", float]:
+
+           dtype_regex = REGEX_FLOAT
+           sampling_params.stop.extend([" ", "\n"])
+       elif sampling_params.dtype in ["str", str]:
+
+           dtype_regex = REGEX_STR
+       elif sampling_params.dtype in ["bool", bool]:
+
+           dtype_regex = REGEX_BOOL
+       else:
+           raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
+
+       if dtype_regex is not None and sampling_params.regex is not None:
+           warnings.warn(
+               f"Both dtype and regex are set. Only dtype will be used. dtype: {sampling_params.dtype}, regex: {sampling_params.regex}"
+           )
+
+       sampling_params.regex = dtype_regex
+
    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
-
-
-
-
-
-
-
-
-
-       elif sampling_params.dtype in [int, "int"]:
-           data = {
-               "text": s.text_,
-               "sampling_params": {
-                   "skip_special_tokens": global_config.skip_special_tokens_in_output,
-                   "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
-                   "dtype": "int",
-                   **sampling_params.to_srt_kwargs(),
-               },
-           }
-       else:
-           raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
+       self._handle_dtype_to_regex(sampling_params)
+       data = {
+           "text": s.text_,
+           "sampling_params": {
+               "skip_special_tokens": global_config.skip_special_tokens_in_output,
+               "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
+               **sampling_params.to_srt_kwargs(),
+           },
+       }

        for item in [
            "return_logprob",
@@ -151,27 +173,16 @@
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
-
-
-
-
-
-
-
-
-           }
-
-       data = {
-           "text": s.text_,
-           "sampling_params": {
-               "skip_special_tokens": global_config.skip_special_tokens_in_output,
-               "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
-               "dtype": "int",
-               **sampling_params.to_srt_kwargs(),
-           },
-       }
-       else:
-           raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
+       self._handle_dtype_to_regex(sampling_params)
+
+       data = {
+           "text": s.text_,
+           "sampling_params": {
+               "skip_special_tokens": global_config.skip_special_tokens_in_output,
+               "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
+               **sampling_params.to_srt_kwargs(),
+           },
+       }

        for item in [
            "return_logprob",
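Note: the per-dtype branches in generate and its streaming counterpart collapse into the shared _handle_dtype_to_regex helper, which rewrites sampling_params.regex from the REGEX_* constants and appends " "/"\n" stop strings for numeric dtypes. A quick check of what those constants, as defined in sglang/lang/ir.py in this release, accept:

    import re

    REGEX_INT = r"[-+]?[0-9]+[ \n]*"
    REGEX_STR = r"\"[\w\d\s]*\""

    assert re.fullmatch(REGEX_INT, "-42 ")  # the trailing delimiter is now legal
    assert re.fullmatch(REGEX_STR, '"hello world"')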
sglang/lang/compiler.py
CHANGED
@@ -125,7 +125,7 @@ class CompiledFunction:
    def run(
        self,
        *,
-       max_new_tokens: int =
+       max_new_tokens: int = 128,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
@@ -155,7 +155,7 @@ class CompiledFunction:
        self,
        batch_kwargs,
        *,
-       max_new_tokens: int =
+       max_new_tokens: int = 128,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
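Note: CompiledFunction.run and run_batch now default max_new_tokens to 128 (the truncated removed lines hide the previous default). A hypothetical call site relying only on the new default, assuming a program compiled via SglFunction.compile():

    import sglang as sgl

    @sgl.function
    def add(s):
        s += "1+1="
        s += sgl.gen("ans")  # no max_new_tokens given

    compiled = add.compile()
    # compiled.run() now generates up to 128 new tokens by default.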
sglang/lang/interpreter.py
CHANGED
@@ -20,7 +20,6 @@ from sglang.lang.ir import (
     SglConstantText,
     SglExpr,
     SglExprList,
-    SglFunction,
     SglGen,
     SglImage,
     SglRoleBegin,
@@ -181,8 +180,10 @@ class StreamExecutor:
        num_api_spec_tokens=None,
        use_thread=True,
    ):
+       from sglang.lang.backend.base_backend import BaseBackend
+
        self.sid = uuid.uuid4().hex
-       self.backend = backend
+       self.backend: BaseBackend = backend
        self.arguments: Dict[str, Any] = arguments
        self.default_sampling_para = default_sampling_para
        self.stream = stream
@@ -658,6 +659,7 @@
        for item in [
            "max_new_tokens",
            "stop",
+           "stop_token_ids",
            "temperature",
            "top_p",
            "top_k",
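Note: importing BaseBackend inside __init__ keeps the type annotation on self.backend without creating an import cycle between sglang.lang.interpreter and the backend module. The pattern in isolation:

    class StreamExecutor:  # reduced sketch of the class above
        def __init__(self, backend):
            # Deferred import: the backend module itself imports this module.
            from sglang.lang.backend.base_backend import BaseBackend

            self.backend: BaseBackend = backend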
sglang/lang/ir.py
CHANGED
@@ -8,16 +8,17 @@ from typing import List, Optional, Union
 from sglang.global_config import global_config
 from sglang.lang.choices import ChoicesSamplingMethod

-REGEX_INT = r"[-+]?[0-9]+"
-REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+"
+REGEX_INT = r"[-+]?[0-9]+[ \n]*"
+REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+[ \n]*"
 REGEX_BOOL = r"(True|False)"
-
+REGEX_STR = r"\"[\w\d\s]*\""  # bugs with regex r"\".*\"" in interegular pkg


 @dataclasses.dataclass
 class SglSamplingParams:
-    max_new_tokens: int =
+    max_new_tokens: int = 128
     stop: Union[str, List[str]] = ()
+    stop_token_ids: Optional[List[int]] = ()
     temperature: float = 1.0
     top_p: float = 1.0
     top_k: int = -1  # -1 means disable
@@ -37,6 +38,7 @@
        return SglSamplingParams(
            self.max_new_tokens,
            self.stop,
+           self.stop_token_ids,
            self.temperature,
            self.top_p,
            self.top_k,
@@ -108,6 +110,7 @@
        return {
            "max_new_tokens": self.max_new_tokens,
            "stop": self.stop,
+           "stop_token_ids": self.stop_token_ids,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
@@ -140,8 +143,9 @@ class SglFunction:
    def run(
        self,
        *args,
-       max_new_tokens: int =
-       stop: Union[str, List[str]] =
+       max_new_tokens: int = 128,
+       stop: Union[str, List[str]] = [],
+       stop_token_ids: Optional[List[int]] = [],
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
@@ -161,6 +165,7 @@
        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
+           stop_token_ids=stop_token_ids,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
@@ -179,8 +184,9 @@
        self,
        batch_kwargs,
        *,
-       max_new_tokens: int =
+       max_new_tokens: int = 128,
        stop: Union[str, List[str]] = (),
+       stop_token_ids: Optional[List[int]] = [],
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
@@ -218,6 +224,7 @@
        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
+           stop_token_ids=stop_token_ids,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
@@ -397,6 +404,7 @@ class SglGen(SglExpr):
        name: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
+       stop_token_ids: Optional[List[int]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
@@ -416,6 +424,7 @@
        self.sampling_params = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
+           stop_token_ids=stop_token_ids,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
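Note: REGEX_INT and REGEX_FLOAT now end in [ \n]*. A plausible reading: under dtype-constrained decoding the model must be able to emit the space or newline that the new " "/"\n" stop strings wait for, which the old patterns forbade, and REGEX_STR avoids r"\".*\"" because of the interegular bug noted in the comment. A small check:

    import re

    old_int = r"[-+]?[0-9]+"
    new_int = r"[-+]?[0-9]+[ \n]*"

    assert re.fullmatch(old_int, "123\n") is None  # the delimiter was unreachable
    assert re.fullmatch(new_int, "123\n")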
sglang/srt/constrained/fsm_cache.py
CHANGED
@@ -20,10 +20,20 @@ from sglang.srt.constrained.base_tool_cache import BaseToolCache


 class FSMCache(BaseToolCache):
-    def __init__(
+    def __init__(
+        self,
+        tokenizer_path,
+        tokenizer_args_dict,
+        enable=True,
+        skip_tokenizer_init=False,
+    ):
        super().__init__(enable=enable)

-       if
+       if (
+           skip_tokenizer_init
+           or tokenizer_path.endswith(".json")
+           or tokenizer_path.endswith(".model")
+       ):
            # Do not support TiktokenTokenizer or SentencePieceTokenizer
            return

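Note: FSMCache now also bails out when the server runs without a tokenizer (skip_tokenizer_init), in addition to raw .json/.model tokenizer files. A hypothetical construction showing the expanded signature (the model path is illustrative):

    from sglang.srt.constrained.fsm_cache import FSMCache

    cache = FSMCache(
        tokenizer_path="meta-llama/Meta-Llama-3-8B",
        tokenizer_args_dict={},
        enable=True,
        skip_tokenizer_init=False,
    )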
sglang/srt/constrained/jump_forward.py
CHANGED
@@ -62,16 +62,22 @@ class JumpForwardMap:
            id_to_symbol.setdefault(id_, []).append(symbol)

        transitions = fsm_info.transitions
+
        outgoings_ct = defaultdict(int)
-
+       # NOTE(lsyin): Final states can lead to terminate, so they have one outgoing edge naturally
+       for s in fsm_info.finals:
+           outgoings_ct[s] = 1

+       state_to_jump_forward = {}
        for (state, id_), next_state in transitions.items():
            if id_ == fsm_info.alphabet_anything_value:
+               # Arbitrarily symbol cannot be recognized as jump forward
                continue
+
            symbols = id_to_symbol[id_]
            for c in symbols:
                if len(c) > 1:
-                   # Skip byte level transitions
+                   # Skip byte level transitions like c = "5E"
                    continue

                outgoings_ct[state] += 1
@@ -87,6 +93,9 @@ class JumpForwardMap:

        # Process the byte level jump forward
        outgoings_ct = defaultdict(int)
+       for s in fsm_info.finals:
+           outgoings_ct[s] = 1
+
        for (state, id_), next_state in transitions.items():
            if id_ == fsm_info.alphabet_anything_value:
                continue
@@ -177,3 +186,5 @@ if __name__ == "__main__":
    test_main(r"霍格沃茨特快列车|霍比特人比尔博")
    # 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
    # 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...
+
+   test_main(r"[-+]?[0-9]+[ ]*")
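Note: counting an implicit outgoing edge for every final state keeps jump-forward from treating a legal stopping point as a forced transition; the new test_main(r"[-+]?[0-9]+[ ]*") call exercises exactly that shape, since every state after the first digit is final. A toy rerun of the counting logic from the hunk above (state and symbol ids are hypothetical):

    from collections import defaultdict

    finals = {2}
    transitions = {(2, 7): 2}  # state 2 --digit--> state 2

    outgoings_ct = defaultdict(int)
    for s in finals:
        outgoings_ct[s] = 1  # the implicit "terminate" edge
    for (state, _id), _next in transitions.items():
        outgoings_ct[state] += 1

    assert outgoings_ct[2] == 2  # two ways out, so no forced jump forward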