sglang 0.3.1.post3__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff shows the contents of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +23 -1
- sglang/bench_latency.py +48 -33
- sglang/bench_server_latency.py +0 -6
- sglang/bench_serving.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +14 -1
- sglang/lang/interpreter.py +16 -6
- sglang/lang/ir.py +20 -4
- sglang/srt/configs/model_config.py +11 -9
- sglang/srt/constrained/fsm_cache.py +9 -1
- sglang/srt/constrained/jump_forward.py +15 -2
- sglang/srt/hf_transformers_utils.py +1 -0
- sglang/srt/layers/activation.py +4 -4
- sglang/srt/layers/attention/__init__.py +49 -0
- sglang/srt/layers/attention/flashinfer_backend.py +277 -0
- sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
- sglang/srt/layers/attention/triton_backend.py +161 -0
- sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
- sglang/srt/layers/fused_moe/patch.py +117 -0
- sglang/srt/layers/layernorm.py +4 -4
- sglang/srt/layers/logits_processor.py +19 -15
- sglang/srt/layers/pooler.py +3 -3
- sglang/srt/layers/quantization/__init__.py +0 -2
- sglang/srt/layers/radix_attention.py +6 -4
- sglang/srt/layers/sampler.py +6 -4
- sglang/srt/layers/torchao_utils.py +18 -0
- sglang/srt/lora/lora.py +20 -21
- sglang/srt/lora/lora_manager.py +97 -25
- sglang/srt/managers/detokenizer_manager.py +31 -18
- sglang/srt/managers/image_processor.py +187 -0
- sglang/srt/managers/io_struct.py +99 -75
- sglang/srt/managers/schedule_batch.py +187 -68
- sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
- sglang/srt/managers/scheduler.py +1021 -0
- sglang/srt/managers/tokenizer_manager.py +120 -247
- sglang/srt/managers/tp_worker.py +28 -925
- sglang/srt/mem_cache/memory_pool.py +34 -52
- sglang/srt/mem_cache/radix_cache.py +5 -5
- sglang/srt/model_executor/cuda_graph_runner.py +25 -25
- sglang/srt/model_executor/forward_batch_info.py +94 -97
- sglang/srt/model_executor/model_runner.py +76 -78
- sglang/srt/models/baichuan.py +10 -10
- sglang/srt/models/chatglm.py +12 -12
- sglang/srt/models/commandr.py +10 -10
- sglang/srt/models/dbrx.py +12 -12
- sglang/srt/models/deepseek.py +10 -10
- sglang/srt/models/deepseek_v2.py +14 -15
- sglang/srt/models/exaone.py +10 -10
- sglang/srt/models/gemma.py +10 -10
- sglang/srt/models/gemma2.py +11 -11
- sglang/srt/models/gpt_bigcode.py +10 -10
- sglang/srt/models/grok.py +10 -10
- sglang/srt/models/internlm2.py +10 -10
- sglang/srt/models/llama.py +22 -10
- sglang/srt/models/llama_classification.py +5 -5
- sglang/srt/models/llama_embedding.py +4 -4
- sglang/srt/models/llama_reward.py +142 -0
- sglang/srt/models/llava.py +39 -33
- sglang/srt/models/llavavid.py +31 -28
- sglang/srt/models/minicpm.py +10 -10
- sglang/srt/models/minicpm3.py +14 -15
- sglang/srt/models/mixtral.py +10 -10
- sglang/srt/models/mixtral_quant.py +10 -10
- sglang/srt/models/olmoe.py +10 -10
- sglang/srt/models/qwen.py +10 -10
- sglang/srt/models/qwen2.py +11 -11
- sglang/srt/models/qwen2_moe.py +10 -10
- sglang/srt/models/stablelm.py +10 -10
- sglang/srt/models/torch_native_llama.py +506 -0
- sglang/srt/models/xverse.py +10 -10
- sglang/srt/models/xverse_moe.py +10 -10
- sglang/srt/openai_api/adapter.py +7 -0
- sglang/srt/sampling/sampling_batch_info.py +36 -27
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +170 -119
- sglang/srt/server_args.py +54 -27
- sglang/srt/utils.py +101 -128
- sglang/test/runners.py +76 -33
- sglang/test/test_programs.py +38 -5
- sglang/test/test_utils.py +53 -9
- sglang/version.py +1 -1
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/METADATA +42 -23
- sglang-0.3.3.dist-info/RECORD +139 -0
- sglang/srt/layers/attention_backend.py +0 -482
- sglang/srt/managers/controller_multi.py +0 -207
- sglang/srt/managers/controller_single.py +0 -164
- sglang-0.3.1.post3.dist-info/RECORD +0 -134
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
sglang/__init__.py
CHANGED
@@ -1,6 +1,7 @@
 # SGL API Components
 
 from sglang.api import (
+    Engine,
     Runtime,
     assistant,
     assistant_begin,
@@ -31,6 +32,7 @@ from sglang.lang.choices import (
 # SGLang DSL APIs
 __all__ = [
     "Runtime",
+    "Engine",
     "assistant",
     "assistant_begin",
     "assistant_end",
sglang/api.py
CHANGED
@@ -33,13 +33,23 @@ def function(
 
 
 def Runtime(*args, **kwargs):
-    # Avoid importing unnecessary dependency
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+
+    # Avoid importing unnecessary dependency
     from sglang.srt.server import Runtime
 
     return Runtime(*args, **kwargs)
 
 
+def Engine(*args, **kwargs):
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+
+    # Avoid importing unnecessary dependency
+    from sglang.srt.server import Engine
+
+    return Engine(*args, **kwargs)
+
+
 def set_default_backend(backend: BaseBackend):
     global_config.default_backend = backend
 
@@ -48,6 +58,10 @@ def flush_cache(backend: Optional[BaseBackend] = None):
     backend = backend or global_config.default_backend
     if backend is None:
         return False
+
+    # If backend is Runtime
+    if hasattr(backend, "endpoint"):
+        backend = backend.endpoint
     return backend.flush_cache()
 
 
@@ -55,12 +69,17 @@ def get_server_args(backend: Optional[BaseBackend] = None):
     backend = backend or global_config.default_backend
     if backend is None:
         return None
+
+    # If backend is Runtime
+    if hasattr(backend, "endpoint"):
+        backend = backend.endpoint
     return backend.get_server_args()
 
 
 def gen(
     name: Optional[str] = None,
     max_tokens: Optional[int] = None,
+    min_tokens: Optional[int] = None,
     stop: Optional[Union[str, List[str]]] = None,
     stop_token_ids: Optional[List[int]] = None,
     temperature: Optional[float] = None,
@@ -100,6 +119,7 @@ def gen(
     return SglGen(
         name,
         max_tokens,
+        min_tokens,
         stop,
         stop_token_ids,
         temperature,
@@ -139,6 +159,7 @@ def gen_int(
     return SglGen(
         name,
         max_tokens,
+        None,
         stop,
         stop_token_ids,
         temperature,
@@ -177,6 +198,7 @@ def gen_string(
     return SglGen(
         name,
         max_tokens,
+        None,
         stop,
         stop_token_ids,
         temperature,
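Editor's note: the api.py change above re-exports the new `Engine` entrypoint from `sglang.srt.server` alongside `Runtime`, and threads a new `min_tokens` argument through `gen()`. A minimal usage sketch, assuming `Engine` accepts the usual server arguments such as `model_path` and exposes a `generate()` method; the model path and sampling dict below are placeholders, not taken from this diff:

```python
# Hedged sketch of the new offline Engine entrypoint (model path is a placeholder).
import sglang as sgl

engine = sgl.Engine(model_path="meta-llama/Meta-Llama-3-8B-Instruct")
output = engine.generate(
    "The capital of France is",
    {"temperature": 0.0, "max_new_tokens": 16},
)
print(output)
```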
sglang/bench_latency.py
CHANGED
@@ -47,6 +47,7 @@ I'm going to the park
 import argparse
 import dataclasses
 import itertools
+import json
 import logging
 import multiprocessing
 import os
@@ -62,10 +63,11 @@ import torch.distributed as dist
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server import _set_envs_and_config
-from sglang.srt.server_args import ServerArgs
+from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     configure_logger,
     kill_child_process,
@@ -121,7 +123,7 @@ class BenchArgs:
     )
 
 
-def load_model(server_args, tp_rank):
+def load_model(server_args, port_args, tp_rank):
     suppress_other_loggers()
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
 
@@ -129,6 +131,7 @@ def load_model(server_args, tp_rank):
         server_args.model_path,
         server_args.trust_remote_code,
         context_length=server_args.context_length,
+        model_override_args=json.loads(server_args.json_model_override_args),
     )
     model_runner = ModelRunner(
         model_config=model_config,
@@ -136,7 +139,7 @@ def load_model(server_args, tp_rank):
         gpu_id=tp_rank,
         tp_rank=tp_rank,
         tp_size=server_args.tp_size,
-        nccl_port=
+        nccl_port=port_args.nccl_ports[0],
         server_args=server_args,
     )
     rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
@@ -167,9 +170,13 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
         assert len(input_ids[i]) > bench_args.cut_len
 
         tmp_input_ids = input_ids[i][: bench_args.cut_len]
-        req = Req(
+        req = Req(
+            rid=i,
+            origin_input_text=prompts[i],
+            origin_input_ids=tmp_input_ids,
+            sampling_params=sampling_params,
+        )
         req.prefix_indices = []
-        req.sampling_params = sampling_params
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
         reqs.append(req)
@@ -199,9 +206,13 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
 
     reqs = []
     for i in range(len(input_ids)):
-        req = Req(
+        req = Req(
+            rid=i,
+            origin_input_text="",
+            origin_input_ids=list(input_ids[i]),
+            sampling_params=sampling_params,
+        )
         req.prefix_indices = []
-        req.sampling_params = sampling_params
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
         reqs.append(req)
@@ -217,28 +228,33 @@ def extend(reqs, model_runner):
         tree_cache=None,
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size)
-
-
+    model_worker_batch = batch.get_model_worker_batch()
+    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
+    logits_output = model_runner.forward(forward_batch)
+    next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
     return next_token_ids, logits_output.next_token_logits, batch
 
 
 def decode(input_token_ids, batch, model_runner):
     batch.prepare_for_decode(input_token_ids)
-
-
+    model_worker_batch = batch.get_model_worker_batch()
+    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
+    logits_output = model_runner.forward(forward_batch)
+    next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
     return next_token_ids, logits_output.next_token_logits
 
 
 @torch.inference_mode()
 def correctness_test(
     server_args,
+    port_args,
     bench_args,
     tp_rank,
 ):
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
 
     # Load the model
-    model_runner, tokenizer = load_model(server_args, tp_rank)
+    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
 
     # Prepare inputs
     input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
@@ -260,7 +276,7 @@ def correctness_test(
 
     # Decode
     output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
-    for _ in range(bench_args.output_len[0]):
+    for _ in range(bench_args.output_len[0] - 1):
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
         for i in range(len(reqs)):
             output_ids[i].append(next_token_ids[i])
@@ -311,7 +327,7 @@ def latency_test_run_once(
 
     # Decode
     decode_latencies = []
-    for i in range(output_len):
+    for i in range(output_len - 1):
         torch.cuda.synchronize()
         tic = time.time()
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
@@ -324,13 +340,16 @@ def latency_test_run_once(
             rank_print(
                 f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
             )
-
-
-
-
-
-
-
+
+    # record decode timing from 2nd output
+    if output_len > 1:
+        med_decode_latency = np.median(decode_latencies)
+        med_decode_throughput = batch_size / med_decode_latency
+        rank_print(
+            f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
+        )
+        measurement_results["median_decode_latency"] = med_decode_latency
+        measurement_results["median_decode_throughput"] = med_decode_throughput
 
     throughput = (input_len + output_len) * batch_size / tot_latency
     rank_print(
@@ -343,15 +362,15 @@ def latency_test_run_once(
 
 def latency_test(
     server_args,
+    port_args,
     bench_args,
     tp_rank,
 ):
     configure_logger(server_args, prefix=f" TP{tp_rank}")
-    _set_envs_and_config(server_args)
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
 
     # Load the model
-    model_runner, tokenizer = load_model(server_args, tp_rank)
+    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
 
     # Prepare inputs for warm up
     reqs = prepare_synthetic_inputs_for_latency_test(
@@ -367,7 +386,7 @@ def latency_test(
         reqs,
         bench_args.batch_size[0],
         bench_args.input_len[0],
-
+        8,  # shorter decoding to speed up the warmup
     )
     rank_print("Benchmark ...")
 
@@ -453,6 +472,7 @@ def plot_latency_test(
 
 
 def main(server_args, bench_args):
+    _set_envs_and_config(server_args)
 
     if server_args.model_path:
         if bench_args.correctness_test:
@@ -468,8 +488,10 @@ def main(server_args, bench_args):
             "provide --result-filename for plotting the results"
         )
 
+    port_args = PortArgs.init_new(server_args)
+
     if server_args.tp_size == 1:
-        work_func(server_args, bench_args, 0)
+        work_func(server_args, port_args, bench_args, 0)
     else:
         workers = []
         for tp_rank in range(server_args.tp_size):
@@ -477,6 +499,7 @@ def main(server_args, bench_args):
                 target=work_func,
                 args=(
                     server_args,
+                    port_args,
                     bench_args,
                     tp_rank,
                 ),
@@ -491,18 +514,10 @@ def main(server_args, bench_args):
 
 
 if __name__ == "__main__":
-    multiprocessing.set_start_method("spawn", force=True)
-
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
     BenchArgs.add_cli_args(parser)
-    # For this script, model-path is not required
-    assert (
-        parser._actions[1].option_strings[0] == "--model-path"
-    ), "options changed, this code need to be updated"
-    parser._actions[1].required = False
     args = parser.parse_args()
-
     server_args = ServerArgs.from_cli_args(args)
     bench_args = BenchArgs.from_cli_args(args)
 
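Editor's note: both benchmark loops above now run `output_len - 1` decode steps because the extend (prefill) pass already produces the first output token, and the median decode stats are recorded only from the second output onward. A tiny standalone check of that token accounting, with made-up lengths:

```python
# Token accounting behind the `output_len - 1` change: the extend step emits
# output token #1, so only output_len - 1 further decode steps are needed.
input_len, output_len = 128, 8

tokens_from_extend = 1
decode_steps = output_len - 1
assert tokens_from_extend + decode_steps == output_len

total_tokens = input_len + output_len  # per request, as used for overall throughput
print(decode_steps, total_tokens)  # 7 136
```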
sglang/bench_server_latency.py
CHANGED
@@ -174,13 +174,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
     BenchArgs.add_cli_args(parser)
-    # For this script, model-path is not required
-    assert (
-        parser._actions[1].option_strings[0] == "--model-path"
-    ), "options changed, this code need to be updated"
-    parser._actions[1].required = False
     args = parser.parse_args()
-
     server_args = ServerArgs.from_cli_args(args)
     bench_args = BenchArgs.from_cli_args(args)
 
sglang/bench_serving.py
CHANGED
@@ -845,6 +845,7 @@ def run_benchmark(args_: argparse.Namespace):
     tokenizer = get_tokenizer(tokenizer_id)
 
     if args.dataset_name == "sharegpt":
+        assert args.random_input_len is None and args.random_output_len is None
         input_requests = sample_sharegpt_requests(
             dataset_path=args.dataset_path,
             num_requests=args.num_prompts,
@@ -852,6 +853,7 @@ def run_benchmark(args_: argparse.Namespace):
             fixed_output_len=args.sharegpt_output_len,
         )
     elif args.dataset_name == "random":
+        assert args.random_input_len is not None and args.random_output_len is not None
         input_requests = sample_random_requests(
             input_len=args.random_input_len,
             output_len=args.random_output_len,
@@ -964,13 +966,11 @@ if __name__ == "__main__":
     parser.add_argument(
         "--random-input-len",
         type=int,
-        default=1024,
         help="Number of input tokens per request, used only for random dataset.",
     )
     parser.add_argument(
         "--random-output-len",
         type=int,
-        default=128,
         help="Number of output tokens per request, used only for random dataset.",
     )
     parser.add_argument(
sglang/lang/backend/runtime_endpoint.py
CHANGED
@@ -235,6 +235,7 @@ class RuntimeEndpoint(BaseBackend):
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
         obj = self._generate_http_request(s, data)
         prompt_len = obj["meta_info"]["prompt_tokens"]
+        logprob_start_len = max(prompt_len - 2, 0)  # For token healing
 
         # Compute logprob
         data = {
@@ -244,7 +245,8 @@ class RuntimeEndpoint(BaseBackend):
                 "temperature": 0,
             },
             "return_logprob": True,
-            "
+            "return_text_in_logprobs": True,
+            "logprob_start_len": logprob_start_len,
         }
         obj = self._generate_http_request(s, data)
 
@@ -254,6 +256,17 @@ class RuntimeEndpoint(BaseBackend):
         input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
         output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
 
+        # Remove extra token if no token healing occurred
+        for i in range(len(input_token_logprobs)):
+            healed_token_str = input_token_logprobs[i][0][-1]
+            if s.text_.endswith(healed_token_str):
+                healed_token_logprob = input_token_logprobs[i][0][0]
+                normalized_prompt_logprobs[i] = (
+                    normalized_prompt_logprobs[i] * len(input_token_logprobs[i])
+                    - healed_token_logprob
+                ) / (len(input_token_logprobs[i]) - 1)
+                input_token_logprobs[i] = input_token_logprobs[i][1:]
+
         # Compute unconditional logprobs if required
         if choices_method.requires_unconditional_logprobs:
             input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]
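Editor's note: the RuntimeEndpoint change starts the prompt-logprob computation two tokens early (for token healing) and then drops the extra leading token when no healing occurred, renormalizing the per-token average. A small standalone sketch of that renormalization arithmetic, with made-up numbers; entries follow the same (logprob, token_id, token_text) layout the diff indexes into:

```python
# The normalized prompt logprob is a mean over tokens: to drop the first (healed)
# token, undo the mean, subtract that token's logprob, and re-average over n - 1.
input_token_logprobs = [(-0.9, 101, "He"), (-1.2, 102, "llo"), (-0.3, 103, " world")]
n = len(input_token_logprobs)
normalized = sum(lp for lp, _, _ in input_token_logprobs) / n  # -0.8

healed_token_logprob = input_token_logprobs[0][0]
renormalized = (normalized * n - healed_token_logprob) / (n - 1)
input_token_logprobs = input_token_logprobs[1:]

assert abs(renormalized - sum(lp for lp, _, _ in input_token_logprobs) / (n - 1)) < 1e-9
print(renormalized)  # -0.75
```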
sglang/lang/interpreter.py
CHANGED
@@ -2,6 +2,7 @@
 
 import asyncio
 import contextvars
+import copy
 import multiprocessing
 import queue
 import threading
@@ -652,9 +653,22 @@ class StreamExecutor:
             self._init_var_event(e)
 
     def _resolve_sampling_params(self, sampling_params):
-
+        """
+        Construct sampling param based on default + override values
+
+        The default values of sampling are populated in `default_sampling_para` via sgl.function.run(...sampling_args)
+        , and `sampling_params` contains the override values from sgl.gen().
+
+        Here we use default_sampling_para as the base and override the values if they exist in `sampling_params`.
+        It also extends the stop tokens based on the chat template.
+        """
+
+        # deepcopy is required because the dict has lists inside
+        clone = copy.deepcopy(self.default_sampling_para)
+
         for item in [
             "max_new_tokens",
+            "min_new_tokens",
             "stop",
             "stop_token_ids",
             "temperature",
@@ -674,20 +688,16 @@ class StreamExecutor:
         ]:
             value = getattr(sampling_params, item, None)
             if value is not None:
-                if clone is None:
-                    clone = self.default_sampling_para.clone()
                 setattr(clone, item, value)
 
         if self.chat_template.stop_str:
-            if not clone:
-                clone = self.default_sampling_para.clone()
             if clone.stop == ():
                 clone.stop = []
             elif isinstance(clone.stop, str):
                 clone.stop = [clone.stop]
             clone.stop += self.chat_template.stop_str
 
-        return clone
+        return clone
 
     def __del__(self):
         self.end()
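Editor's note: the interpreter now always deep-copies the run-level defaults before applying per-`gen` overrides, instead of lazily cloning. A small standalone sketch of the same pattern (illustrative classes, not the sglang ones), showing why the deep copy matters when the defaults hold mutable fields such as a stop list:

```python
import copy
import dataclasses
from typing import List, Optional


@dataclasses.dataclass
class Params:
    max_new_tokens: int = 128
    stop: List[str] = dataclasses.field(default_factory=list)


def resolve(defaults: Params, overrides: dict, extra_stop: Optional[List[str]] = None) -> Params:
    # Deep copy so appending to `stop` below never mutates the shared defaults.
    clone = copy.deepcopy(defaults)
    for key, value in overrides.items():
        if value is not None:
            setattr(clone, key, value)
    if extra_stop:
        clone.stop += extra_stop
    return clone


defaults = Params(stop=["</s>"])
call = resolve(defaults, {"max_new_tokens": 32}, extra_stop=["\n\n"])
print(call.stop)      # ['</s>', '\n\n']
print(defaults.stop)  # ['</s>']  -- untouched thanks to the deep copy
```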
sglang/lang/ir.py
CHANGED
@@ -17,6 +17,7 @@ REGEX_STR = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
 @dataclasses.dataclass
 class SglSamplingParams:
     max_new_tokens: int = 128
+    min_new_tokens: int = 0
     stop: Union[str, List[str]] = ()
     stop_token_ids: Optional[List[int]] = ()
     temperature: float = 1.0
@@ -39,6 +40,7 @@ class SglSamplingParams:
     def clone(self):
         return SglSamplingParams(
             self.max_new_tokens,
+            self.min_new_tokens,
             self.stop,
             self.stop_token_ids,
             self.temperature,
@@ -113,6 +115,7 @@ class SglSamplingParams:
     def to_srt_kwargs(self):
         return {
             "max_new_tokens": self.max_new_tokens,
+            "min_new_tokens": self.min_new_tokens,
             "stop": self.stop,
             "stop_token_ids": self.stop_token_ids,
             "temperature": self.temperature,
@@ -150,8 +153,8 @@ class SglFunction:
         self,
         *args,
         max_new_tokens: int = 128,
-        stop: Union[str, List[str]] =
-        stop_token_ids: Optional[List[int]] =
+        stop: Optional[Union[str, List[str]]] = None,
+        stop_token_ids: Optional[List[int]] = None,
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
@@ -169,6 +172,12 @@ class SglFunction:
     ):
         from sglang.lang.interpreter import run_program
 
+        # avoid using [] as the default arg: https://nikos7am.com/posts/mutable-default-arguments/
+        if stop is None:
+            stop = []
+        if stop_token_ids is None:
+            stop_token_ids = []
+
         default_sampling_para = SglSamplingParams(
             max_new_tokens=max_new_tokens,
             stop=stop,
@@ -193,8 +202,8 @@ class SglFunction:
         batch_kwargs,
         *,
         max_new_tokens: int = 128,
-        stop: Union[str, List[str]] =
-        stop_token_ids: Optional[List[int]] =
+        stop: Optional[Union[str, List[str]]] = None,
+        stop_token_ids: Optional[List[int]] = None,
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
@@ -212,6 +221,11 @@ class SglFunction:
     ):
         from sglang.lang.interpreter import run_program_batch
 
+        if stop is None:
+            stop = []
+        if stop_token_ids is None:
+            stop_token_ids = []
+
         assert isinstance(batch_kwargs, (list, tuple))
         if len(batch_kwargs) == 0:
             return []
@@ -413,6 +427,7 @@ class SglGen(SglExpr):
         self,
         name: Optional[str] = None,
         max_new_tokens: Optional[int] = None,
+        min_new_tokens: Optional[int] = None,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
         temperature: Optional[float] = None,
@@ -435,6 +450,7 @@ class SglGen(SglExpr):
         self.name = name
         self.sampling_params = SglSamplingParams(
             max_new_tokens=max_new_tokens,
+            min_new_tokens=min_new_tokens,
            stop=stop,
            stop_token_ids=stop_token_ids,
            temperature=temperature,
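Editor's note: the `run`/`run_batch` signature change above replaces mutable `stop`/`stop_token_ids` defaults with `None` plus an in-body fallback, the standard fix for Python's mutable-default-argument pitfall referenced in the comment added by the diff. A tiny standalone illustration of the bug being avoided:

```python
def bad(stop=[]):
    # The same list object is reused across calls, so state leaks between them.
    stop.append("</s>")
    return stop


def good(stop=None):
    if stop is None:
        stop = []
    stop.append("</s>")
    return stop


print(bad())   # ['</s>']
print(bad())   # ['</s>', '</s>']  -- the shared default list accumulated
print(good())  # ['</s>']
print(good())  # ['</s>']
```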
sglang/srt/configs/model_config.py
CHANGED
@@ -49,13 +49,13 @@ class ModelConfig:
         if context_length is not None:
             self.context_len = context_length
         else:
-            self.context_len = get_context_length(self.
+            self.context_len = get_context_length(self.hf_text_config)
 
-        # Unify the config keys for
+        # Unify the config keys for hf_text_config
         self.head_dim = getattr(
-            self.
+            self.hf_text_config,
             "head_dim",
-            self.
+            self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
         )
 
         # FIXME: temporary special judge for deepseek v2 MLA architecture
@@ -72,8 +72,10 @@ class ModelConfig:
         else:
             self.attention_arch = AttentionArch.MHA
 
-        self.num_attention_heads = self.
-        self.num_key_value_heads = getattr(
+        self.num_attention_heads = self.hf_text_config.num_attention_heads
+        self.num_key_value_heads = getattr(
+            self.hf_text_config, "num_key_value_heads", None
+        )
 
         # for Dbrx and MPT models
         if self.hf_config.model_type in ["dbrx", "mpt"]:
@@ -83,9 +85,9 @@ class ModelConfig:
 
         if self.num_key_value_heads is None:
             self.num_key_value_heads = self.num_attention_heads
-        self.hidden_size = self.
-        self.num_hidden_layers = self.
-        self.vocab_size = self.
+        self.hidden_size = self.hf_text_config.hidden_size
+        self.num_hidden_layers = self.hf_text_config.num_hidden_layers
+        self.vocab_size = self.hf_text_config.vocab_size
 
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
sglang/srt/constrained/fsm_cache.py
CHANGED
@@ -14,13 +14,17 @@ limitations under the License.
 """
 
 """Cache for the compressed finite state machine."""
+import logging
 
+from interegular import InvalidSyntax, parse_pattern
 from outlines.fsm.json_schema import build_regex_from_schema
 from transformers import AutoTokenizer
 
 from sglang.srt.constrained import RegexGuide, TransformerTokenizer
 from sglang.srt.constrained.base_tool_cache import BaseToolCache
 
+logger = logging.getLogger(__name__)
+
 
 class FSMCache(BaseToolCache):
     def __init__(
@@ -76,5 +80,9 @@ class FSMCache(BaseToolCache):
             regex = key_string
         else:
             raise ValueError(f"Invalid key_type: {key_type}")
-
+        try:
+            parse_pattern(regex)
+        except InvalidSyntax as e:
+            logger.warning(f"skip invalid regex guide: {regex=}, {e=}")
+            return None, regex
         return RegexGuide(regex, self.outlines_tokenizer), regex
sglang/srt/constrained/jump_forward.py
CHANGED
@@ -19,10 +19,12 @@ Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
 """
 
 import dataclasses
+import logging
 from collections import defaultdict
 
 import interegular
 import outlines.caching
+from interegular import InvalidSyntax
 
 from sglang.srt.constrained import (
     FSMInfo,
@@ -34,6 +36,8 @@ from sglang.srt.constrained.base_tool_cache import BaseToolCache
 
 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
 
+logger = logging.getLogger(__name__)
+
 
 @dataclasses.dataclass
 class JumpEdge:
@@ -47,7 +51,12 @@ class JumpForwardMap:
     def __init__(self, regex_string):
         @disk_cache()
         def _init_state_to_jump_forward(regex_string):
-
+            try:
+                regex_pattern = interegular.parse_pattern(regex_string)
+            except InvalidSyntax as e:
+                logger.warning(f"skip invalid regex: {regex_string}, {e=}")
+                self.state_to_jump_forward = None
+                return
 
             byte_fsm = make_byte_level_fsm(
                 regex_pattern.to_fsm().reduce(), keep_utf8=True
@@ -165,7 +174,11 @@ class JumpForwardCache(BaseToolCache):
         super().__init__()
 
     def init_value(self, regex):
-
+        forward_map = JumpForwardMap(regex)
+        if forward_map.state_to_jump_forward:
+            return forward_map
+        else:
+            return None
 
 
 def test_main(regex_string):
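Editor's note: both constrained-decoding caches above now validate the pattern with `interegular.parse_pattern` and skip guides whose regexes interegular rejects (raising `InvalidSyntax`), instead of crashing the cache. A small standalone helper in the same spirit; the example pattern is the IP regex already present in jump_forward.py:

```python
import logging

from interegular import InvalidSyntax, parse_pattern

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


def is_supported_regex(regex: str) -> bool:
    # Mirrors the guard added in the diff: patterns interegular cannot parse
    # raise InvalidSyntax and are skipped rather than propagated.
    try:
        parse_pattern(regex)
        return True
    except InvalidSyntax as e:
        logger.warning(f"skip invalid regex guide: {regex=}, {e=}")
        return False


ip_regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
print(is_supported_regex(ip_regex))  # True for patterns interegular understands
```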