sglang 0.4.4.post4__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registries.
- sglang/bench_one_batch.py +21 -0
- sglang/bench_serving.py +10 -4
- sglang/lang/chat_template.py +24 -0
- sglang/srt/configs/model_config.py +40 -4
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/conversation.py +29 -4
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +18 -5
- sglang/srt/disaggregation/mini_lb.py +53 -122
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +615 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
- sglang/srt/disaggregation/prefill.py +43 -19
- sglang/srt/disaggregation/utils.py +31 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +37 -10
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/attention/flashattention_backend.py +609 -202
- sglang/srt/layers/attention/flashinfer_backend.py +13 -7
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +5 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +51 -24
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +37 -16
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
- sglang/srt/layers/quantization/fp8.py +28 -14
- sglang/srt/layers/quantization/fp8_kernel.py +130 -4
- sglang/srt/layers/quantization/fp8_utils.py +34 -6
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
- sglang/srt/layers/quantization/w8a8_int8.py +3 -0
- sglang/srt/layers/radix_attention.py +14 -0
- sglang/srt/layers/rotary_embedding.py +75 -1
- sglang/srt/managers/io_struct.py +254 -97
- sglang/srt/managers/mm_utils.py +3 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +146 -0
- sglang/srt/managers/schedule_batch.py +62 -21
- sglang/srt/managers/scheduler.py +71 -14
- sglang/srt/managers/tokenizer_manager.py +17 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +14 -1
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +7 -4
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +49 -9
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +1 -0
- sglang/srt/models/deepseek_v2.py +248 -61
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +13 -4
- sglang/srt/models/llama4.py +487 -0
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +2 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +227 -0
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +1 -0
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +1 -0
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/server_args.py +34 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +6 -2
- sglang/srt/utils.py +120 -9
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/test_block_fp8.py +57 -0
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +133 -109
- sglang/srt/disaggregation/conn.py +0 -81
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py
CHANGED
@@ -60,6 +60,7 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.entrypoints.engine import _set_envs_and_config
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.managers.scheduler import Scheduler
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -135,6 +136,7 @@ def load_model(server_args, port_args, tp_rank):
         context_length=server_args.context_length,
         model_override_args=server_args.json_model_override_args,
         is_embedding=server_args.is_embedding,
+        enable_multimodal=server_args.enable_multimodal,
         dtype=server_args.dtype,
         quantization=server_args.quantization,
     )
@@ -184,6 +186,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
         req.prefix_indices = []
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        req.logprob_start_len = len(req.origin_input_ids) - 1
         reqs.append(req)

     return input_ids, reqs
@@ -199,6 +202,7 @@ def prepare_extend_inputs_for_correctness_test(
             i, : bench_args.cut_len
         ]
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        req.logprob_start_len = len(req.origin_input_ids) - 1
     return reqs


@@ -220,6 +224,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
         req.prefix_indices = []
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        req.logprob_start_len = len(req.origin_input_ids) - 1
         reqs.append(req)

     return reqs
@@ -238,6 +243,7 @@ def extend(reqs, model_runner):
         enable_custom_logit_processor=False,
     )
     batch.prepare_for_extend()
+    _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
@@ -249,6 +255,7 @@ def extend(reqs, model_runner):
 def decode(input_token_ids, batch, model_runner):
     batch.output_ids = input_token_ids
     batch.prepare_for_decode()
+    _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
@@ -256,6 +263,20 @@ def decode(input_token_ids, batch, model_runner):
     return next_token_ids, logits_output.next_token_logits


+def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
+    if model_runner.server_args.enable_dp_attention:
+        Scheduler.prepare_dp_attn_batch_raw(
+            batch,
+            dp_size=model_runner.server_args.dp_size,
+            attn_tp_size=1,
+            tp_cpu_group=model_runner.tp_group.cpu_group,
+            get_idle_batch=None,
+            disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
+            spec_algorithm=SpeculativeAlgorithm.NONE,
+            speculative_num_draft_tokens=None,
+        )
+
+
 def correctness_test(
     server_args,
     port_args,
sglang/bench_serving.py
CHANGED
@@ -490,7 +490,7 @@ def get_dataset(args, tokenizer):
             prompt_suffix=args.prompt_suffix,
             apply_chat_template=args.apply_chat_template,
         )
-    elif args.dataset_name == "random":
+    elif args.dataset_name.startswith("random"):
         input_requests = sample_random_requests(
             input_len=args.random_input_len,
             output_len=args.random_output_len,
@@ -498,6 +498,7 @@ def get_dataset(args, tokenizer):
             range_ratio=args.random_range_ratio,
             tokenizer=tokenizer,
             dataset_path=args.dataset_path,
+            random_sample=args.dataset_name == "random",
         )
     elif args.dataset_name == "generated-shared-prefix":
         input_requests = sample_generated_shared_prefix_requests(
@@ -687,6 +688,7 @@ def sample_random_requests(
     range_ratio: float,
     tokenizer: PreTrainedTokenizerBase,
     dataset_path: str,
+    random_sample: bool = True,
 ) -> List[Tuple[str, int, int]]:

     input_lens = np.random.randint(
@@ -700,11 +702,15 @@ def sample_random_requests(
         size=num_prompts,
     )

-    if True:
+    if random_sample:
         # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens

         # Download sharegpt if necessary
         if not os.path.isfile(dataset_path):
+            print(
+                "If you do not want to randomly sample from a dataset,"
+                " please use --dataset-name random-ids."
+            )
             dataset_path = download_and_cache_file(SHAREGPT_URL)

         # Load the dataset.
@@ -1223,7 +1229,7 @@ async def benchmark(
         output_file_name = args.output_file
     else:
         now = datetime.now().strftime("%m%d")
-        if args.dataset_name == "random":
+        if args.dataset_name.startswith("random"):
             output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
         else:
             output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl"
@@ -1442,7 +1448,7 @@ if __name__ == "__main__":
         "--dataset-name",
         type=str,
         default="sharegpt",
-        choices=["sharegpt", "random", "generated-shared-prefix"],
+        choices=["sharegpt", "random", "random-ids", "generated-shared-prefix"],
         help="Name of the dataset to benchmark on.",
     )
     parser.add_argument(
sglang/lang/chat_template.py
CHANGED
@@ -294,6 +294,30 @@ register_chat_template(
     )
 )

+# Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
+register_chat_template(
+    ChatTemplate(
+        name="llama-4",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": (
+                "<|header_start|>system<|header_end|>\n\n",
+                "<|eot|>",
+            ),
+            "user": (
+                "<|header_start|>user<|header_end|>\n\n",
+                "<|eot|>",
+            ),
+            "assistant": (
+                "<|header_start|>assistant<|header_end|>\n\n",
+                "<|eot|>",
+            ),
+        },
+        stop_str=("<|eot|>",),
+        image_token="<|image|>",
+    )
+)
+
 # Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
 register_chat_template(
     ChatTemplate(
sglang/srt/configs/model_config.py
CHANGED
@@ -15,6 +15,7 @@
 import json
 import logging
 import math
+import os
 from enum import IntEnum, auto
 from typing import List, Optional, Set, Union

@@ -42,10 +43,12 @@ class ModelConfig:
         context_length: Optional[int] = None,
         model_override_args: Optional[str] = None,
         is_embedding: Optional[bool] = None,
+        enable_multimodal: Optional[bool] = None,
         dtype: str = "auto",
         quantization: Optional[str] = None,
         override_config_file: Optional[str] = None,
     ) -> None:
+
         self.model_path = model_path
         self.revision = revision
         self.quantization = quantization
@@ -65,15 +68,32 @@ class ModelConfig:
             **kwargs,
         )
         self.hf_text_config = get_hf_text_config(self.hf_config)
+        self.attention_chunk_size = getattr(
+            self.hf_text_config, "attention_chunk_size", None
+        )
+
+        if enable_multimodal is None:
+            if self.hf_config.architectures == "Llama4ForConditionalGeneration":
+                enable_multimodal = False
+            else:
+                enable_multimodal = True

         # Check model type
         self.is_generation = is_generation_model(
             self.hf_config.architectures, is_embedding
         )
-        self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
-        self.is_multimodal_gen = is_multimodal_gen_model(self.hf_config.architectures)
-        self.is_image_gen = is_image_gen_model(self.hf_config.architectures)
-        self.is_audio_model = is_audio_model(self.hf_config.architectures)
+        self.is_multimodal = enable_multimodal and is_multimodal_model(
+            self.hf_config.architectures
+        )
+        self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
+            self.hf_config.architectures
+        )
+        self.is_image_gen = enable_multimodal and is_image_gen_model(
+            self.hf_config.architectures
+        )
+        self.is_audio_model = enable_multimodal and is_audio_model(
+            self.hf_config.architectures
+        )
         self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)

@@ -231,6 +251,20 @@ class ModelConfig:
         if quant_cfg is None:
             # compressed-tensors uses a "compression_config" key
             quant_cfg = getattr(self.hf_config, "compression_config", None)
+        if quant_cfg is None:
+            # check if is modelopt model -- modelopt doesn't have corresponding field
+            # in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory
+            # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
+            is_local = os.path.exists(self.model_path)
+            modelopt_quant_config = {"quant_method": "modelopt"}
+            if not is_local:
+                from huggingface_hub import HfApi
+
+                hf_api = HfApi()
+                if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
+                    quant_cfg = modelopt_quant_config
+            elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
+                quant_cfg = modelopt_quant_config
         return quant_cfg

     # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
@@ -261,6 +295,7 @@ class ModelConfig:
             "moe_wna16",
         ]
         compatible_quantization_methods = {
+            "modelopt_fp4": ["modelopt"],
             "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
             "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
         }
@@ -468,6 +503,7 @@ multimodal_model_archs = [
     "Grok1VForCausalLM",
     "Grok1AForCausalLM",
     "LlavaLlamaForCausalLM",
+    "Llama4ForConditionalGeneration",
     "LlavaMistralForCausalLM",
     "LlavaQwenForCausalLM",
     "LlavaVidForCausalLM",
sglang/srt/constrained/base_grammar_backend.py
CHANGED
@@ -28,6 +28,18 @@ logger = logging.getLogger(__name__)


 class BaseGrammarObject(ABC):
+
+    def __init__(self):
+        self._finished = False
+
+    @property
+    def finished(self):
+        return self._finished
+
+    @finished.setter
+    def finished(self, finished):
+        self._finished = finished
+
     @abstractmethod
     def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
         """
@@ -59,6 +71,13 @@ class BaseGrammarObject(ABC):
         """
         raise NotImplementedError

+    @abstractmethod
+    def accept_token(self, token: int) -> None:
+        """
+        Accept a token in the grammar.
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device
@@ -90,7 +109,7 @@ class CacheEntry:
     event: Event


-class BaseGrammarBackend(ABC):
+class BaseGrammarBackend:
     def __init__(self):
         self.executor = ThreadPoolExecutor()
         self.cache: Dict[Tuple[str, str], CacheEntry] = {}
@@ -107,19 +126,15 @@ class BaseGrammarBackend(ABC):
         """
         raise ValueError(f"Invalid key_type: {key_type}={key_string}")

-    @abstractmethod
     def dispatch_json(self, key_string: str) -> Optional[BaseGrammarObject]:
         return self._not_supported("json", key_string)

-    @abstractmethod
     def dispatch_regex(self, key_string: str) -> Optional[BaseGrammarObject]:
         return self._not_supported("regex", key_string)

-    @abstractmethod
     def dispatch_ebnf(self, key_string: str) -> Optional[BaseGrammarObject]:
         return self._not_supported("ebnf", key_string)

-    @abstractmethod
     def dispatch_structural_tag(self, key_string: str) -> Optional[BaseGrammarObject]:
         return self._not_supported("structural_tag", key_string)

@@ -195,4 +210,10 @@ def create_grammar_backend(
     else:
         raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")

+    if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
+        from .reasoner_grammar_backend import ReasonerGrammarBackend
+
+        grammar_backend = ReasonerGrammarBackend(
+            grammar_backend, tokenizer.think_end_id
+        )
     return grammar_backend
sglang/srt/constrained/llguidance_backend.py
CHANGED
@@ -33,6 +33,7 @@ class GuidanceGrammar(BaseGrammarObject):
     def __init__(
         self, llguidance_tokenizer: llguidance.LLTokenizer, serialized_grammar: str
     ):
+        super().__init__()
         self.llguidance_tokenizer = llguidance_tokenizer
         self.serialized_grammar = serialized_grammar

sglang/srt/constrained/reasoner_grammar_backend.py
ADDED
@@ -0,0 +1,101 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The baseclass of a backend for reasoner grammar-guided constrained decoding."""
+
+from concurrent.futures import Future
+from typing import List, Optional, Tuple
+
+import torch
+
+from .base_grammar_backend import BaseGrammarBackend, BaseGrammarObject
+
+
+class ReasonerGrammarObject(BaseGrammarObject):
+    def __init__(self, grammar: BaseGrammarObject, think_end_id):
+        super().__init__()
+        self.grammar = grammar
+        self.think_end_id = think_end_id
+        self.is_in_reasoning = True
+
+    @property
+    def finished(self):
+        return self.grammar.finished
+
+    @finished.setter
+    def finished(self, finished):
+        self.grammar.finished = finished
+
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return self.grammar.allocate_vocab_mask(vocab_size, batch_size, device)
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        if not self.is_in_reasoning:
+            self.grammar.fill_vocab_mask(vocab_mask, idx)
+
+    def move_vocab_mask(self, vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return self.grammar.move_vocab_mask(vocab_mask, device)
+
+    @property
+    def apply_vocab_mask(self):
+        return self.grammar.apply_vocab_mask
+
+    def accept_token(self, token: int):
+        if token == self.think_end_id:
+            self.is_in_reasoning = False
+
+        if not self.is_in_reasoning and token != self.think_end_id:
+            self.grammar.accept_token(token)
+
+    def try_jump_forward(self, tokenizer):
+        return self.grammar.try_jump_forward(tokenizer)
+
+    def jump_forward_str_state(self, helper):
+        return self.grammar.jump_forward_str_state(helper)
+
+    def jump_and_retokenize(
+        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
+    ):
+        return self.grammar.jump_and_retokenize(
+            old_output_ids, new_output_ids, next_state
+        )
+
+    def copy(self) -> BaseGrammarObject:
+        return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)
+
+
+class ReasonerGrammarBackend(BaseGrammarBackend):
+    def __init__(self, grammar_backend: BaseGrammarBackend, think_end_id):
+        self.grammar_backend = grammar_backend
+        self.think_end_id = think_end_id
+
+    def get_cached_value(self, key: Tuple[str, str]) -> Optional[ReasonerGrammarObject]:
+        grammar = self.grammar_backend.get_cached_value(key)
+        return ReasonerGrammarObject(grammar, self.think_end_id) if grammar else None
+
+    def get_future_value(self, key: Tuple[str, str]) -> Future:
+        grammar = Future()
+
+        def callback(f: Future):
+            if result := f.result():
+                grammar.set_result(ReasonerGrammarObject(result, self.think_end_id))
+            else:
+                grammar.set_result(None)
+
+        self.grammar_backend.get_future_value(key).add_done_callback(callback)
+        return grammar
+
+    def reset(self):
+        self.grammar_backend.reset()
sglang/srt/conversation.py
CHANGED
@@ -33,6 +33,7 @@ class SeparatorStyle(IntEnum):
     ADD_NEW_LINE_SINGLE = auto()
     LLAMA2 = auto()
     LLAMA3 = auto()
+    LLAMA4 = auto()
     CHATGLM = auto()
     CHATML = auto()
     CHATINTERN = auto()
@@ -156,19 +157,30 @@ class Conversation:
             else:
                 ret += role + ":"
             return ret
+        elif self.sep_style == SeparatorStyle.LLAMA4:
+            # begin_of_text is added by default
+            if self.system_message:
+                ret = system_prompt
+            else:
+                ret = ""
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += f"<|header_start|>{role}<|header_end|>\n\n"
+                    ret += f"{message.strip()}<|eot|>"
+                else:
+                    ret += f"<|header_start|>{role}<|header_end|>\n\n"
+            return ret
         elif self.sep_style == SeparatorStyle.LLAMA3:
-            ret = "<|begin_of_text|>"
             if self.system_message:
-                ret += system_prompt
+                ret = system_prompt
             else:
-                ret += ""
+                ret = ""
             for i, (role, message) in enumerate(self.messages):
                 if message:
                     ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
                     ret += f"{message.strip()}<|eot_id|>"
                 else:
                     ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
-            # print(ret)
             return ret
         elif self.sep_style == SeparatorStyle.LLAMA2:
             seps = [self.sep, self.sep2]
@@ -561,6 +573,19 @@ register_conv_template(
     )
 )

+# reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
+register_conv_template(
+    Conversation(
+        name="llama-4",
+        system_template="<|header_start|>system<|header_end|>\n\n{system_message}<|eot|>",
+        roles=("user", "assistant"),
+        sep_style=SeparatorStyle.LLAMA4,
+        sep="",
+        stop_str=["<|end_of_text|>", "<|eot|>", "<|eom|>"],
+        image_token="<|image|>",
+    )
+)
+
 register_conv_template(
     Conversation(
         name="chatml",
sglang/srt/disaggregation/base/conn.py
ADDED
@@ -0,0 +1,113 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import numpy as np
+import numpy.typing as npt
+
+from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.server_args import ServerArgs
+
+
+class KVArgs:
+    engine_rank: int
+    kv_data_ptrs: list[int]
+    kv_data_lens: list[int]
+    kv_item_lens: list[int]
+    aux_data_ptrs: list[int]
+    aux_data_lens: list[int]
+    aux_item_lens: list[int]
+    ib_device: str
+    gpu_id: int
+
+
+class KVPoll:
+    Failed = 0
+    Bootstrapping = 1
+    WaitingForInput = 2
+    Transferring = 3
+    Success = 4
+
+
+class BaseKVManager(ABC):
+    """Base class for managing transfers states"""
+
+    @abstractmethod
+    def __init__(
+        self,
+        args: KVArgs,
+        disaggregation_mode: DisaggregationMode,
+        server_args: ServerArgs,
+    ): ...
+
+
+class BaseKVSender(ABC):
+
+    @abstractmethod
+    def __init__(
+        self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int
+    ): ...
+
+    @abstractmethod
+    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
+        """
+        Notify the decoder server about the kv indices length and aux index
+        """
+        ...
+
+    @abstractmethod
+    def send(self, kv_indices: npt.NDArray[np.int64]):
+        """
+        Send the kv cache at the given kv indices to the decoder server
+        """
+        ...
+
+    @abstractmethod
+    def poll(self) -> KVPoll:
+        """
+        Check the status of the kv cache transfer
+        """
+        ...
+
+    @abstractmethod
+    def failure_exception(self):
+        """
+        Raise an exception if the kv cache transfer fails
+        """
+        ...
+
+
+class BaseKVReceiver(ABC):
+
+    @abstractmethod
+    def __init__(
+        self,
+        mgr: BaseKVManager,
+        bootstrap_addr: str,
+        bootstrap_room: Optional[int] = None,
+    ): ...
+
+    @abstractmethod
+    def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
+        """
+        Notify the prefill server about the kv indices and aux index
+        """
+        ...
+
+    @abstractmethod
+    def poll(self) -> KVPoll:
+        """
+        Check the status of the kv cache transfer
+        """
+        ...
+
+    @abstractmethod
+    def failure_exception(self):
+        """
+        Raise an exception if the kv cache transfer fails
+        """
+        ...
+
+
+class BaseKVBootstrapServer(ABC):
+    @abstractmethod
+    def __init__(self, port: int): ...
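
These abstract classes are the contract a transfer backend (for example, the new mooncake implementation) has to fill in: a sender on the prefill side and a receiver on the decode side that advance through the KVPoll states. A purely illustrative in-memory sender, not a real backend, assuming only the interface above:

from typing import Optional

import numpy as np
import numpy.typing as npt

from sglang.srt.disaggregation.base.conn import BaseKVManager, BaseKVSender, KVPoll


class InMemoryKVSender(BaseKVSender):
    # Hypothetical sender that "transfers" by stashing indices in memory.

    def __init__(self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int):
        self.mgr = mgr
        self.bootstrap_addr = bootstrap_addr
        self.bootstrap_room = bootstrap_room
        self.state = KVPoll.Bootstrapping
        self.sent_indices: Optional[npt.NDArray[np.int64]] = None

    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
        # A real backend would tell the decode side how many indices to expect.
        self.num_kv_indices = num_kv_indices
        self.state = KVPoll.WaitingForInput

    def send(self, kv_indices: npt.NDArray[np.int64]):
        # A real backend would move the KV pages through its transfer engine here.
        self.sent_indices = kv_indices
        self.state = KVPoll.Success

    def poll(self) -> KVPoll:
        return self.state

    def failure_exception(self):
        raise RuntimeError("kv cache transfer failed")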
|