sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +23 -2
- sglang/bench_serving.py +6 -4
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +37 -5
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +80 -11
- sglang/srt/disaggregation/mini_lb.py +58 -123
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +585 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
- sglang/srt/disaggregation/prefill.py +82 -22
- sglang/srt/disaggregation/utils.py +46 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +42 -13
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +430 -257
- sglang/srt/layers/attention/flashinfer_backend.py +18 -9
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +18 -3
- sglang/srt/layers/moe/ep_moe/layer.py +15 -29
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +63 -45
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/fp8.py +131 -136
- sglang/srt/layers/quantization/fp8_kernel.py +328 -46
- sglang/srt/layers/quantization/fp8_utils.py +206 -253
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
- sglang/srt/layers/quantization/w8a8_int8.py +8 -7
- sglang/srt/layers/radix_attention.py +28 -1
- sglang/srt/layers/rotary_embedding.py +15 -3
- sglang/srt/layers/sampler.py +5 -10
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +255 -97
- sglang/srt/managers/mm_utils.py +7 -5
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
- sglang/srt/managers/schedule_batch.py +64 -25
- sglang/srt/managers/scheduler.py +80 -82
- sglang/srt/managers/tokenizer_manager.py +18 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +21 -3
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +9 -6
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +67 -35
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +494 -366
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +6 -5
- sglang/srt/models/llama4.py +101 -34
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +30 -200
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +102 -29
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +5 -1
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +15 -13
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +55 -19
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +10 -9
- sglang/srt/utils.py +136 -10
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +224 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/disaggregation/conn.py +0 -81
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/__init__.py
CHANGED
@@ -24,6 +24,7 @@ from sglang.api import (
     user_end,
     video,
 )
+from sglang.global_config import global_config
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.lang.choices import (
     greedy_token_selection,
@@ -31,6 +32,7 @@ from sglang.lang.choices import (
     unconditional_likelihood_normalized,
 )
 from sglang.utils import LazyImport
+from sglang.version import __version__
 
 ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
 Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
@@ -38,10 +40,6 @@ LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
 OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
 VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
 
-# Other configs
-from sglang.global_config import global_config
-from sglang.version import __version__
-
 __all__ = [
     "Engine",
     "Runtime",
sglang/bench_one_batch.py
CHANGED
@@ -60,6 +60,7 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.entrypoints.engine import _set_envs_and_config
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.managers.scheduler import Scheduler
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -135,6 +136,7 @@ def load_model(server_args, port_args, tp_rank):
         context_length=server_args.context_length,
         model_override_args=server_args.json_model_override_args,
         is_embedding=server_args.is_embedding,
+        enable_multimodal=server_args.enable_multimodal,
         dtype=server_args.dtype,
         quantization=server_args.quantization,
     )
@@ -184,6 +186,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
         req.prefix_indices = []
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        req.logprob_start_len = len(req.origin_input_ids) - 1
         reqs.append(req)
 
     return input_ids, reqs
@@ -199,11 +202,12 @@ def prepare_extend_inputs_for_correctness_test(
             i, : bench_args.cut_len
         ]
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        req.logprob_start_len = len(req.origin_input_ids) - 1
     return reqs
 
 
 def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
-    input_ids = np.
+    input_ids = np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
     sampling_params = SamplingParams(
         temperature=0,
         max_new_tokens=BenchArgs.output_len,
@@ -220,6 +224,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
         req.prefix_indices = []
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        req.logprob_start_len = len(req.origin_input_ids) - 1
         reqs.append(req)
 
     return reqs
@@ -238,6 +243,7 @@ def extend(reqs, model_runner):
         enable_custom_logit_processor=False,
     )
     batch.prepare_for_extend()
+    _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
@@ -249,6 +255,7 @@ def extend(reqs, model_runner):
 def decode(input_token_ids, batch, model_runner):
     batch.output_ids = input_token_ids
     batch.prepare_for_decode()
+    _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
@@ -256,6 +263,20 @@ def decode(input_token_ids, batch, model_runner):
     return next_token_ids, logits_output.next_token_logits
 
 
+def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
+    if model_runner.server_args.enable_dp_attention:
+        Scheduler.prepare_dp_attn_batch_raw(
+            batch,
+            dp_size=model_runner.server_args.dp_size,
+            attn_tp_size=1,
+            tp_cpu_group=model_runner.tp_group.cpu_group,
+            get_idle_batch=None,
+            disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
+            spec_algorithm=SpeculativeAlgorithm.NONE,
+            speculative_num_draft_tokens=None,
+        )
+
+
 def correctness_test(
     server_args,
     port_args,
@@ -375,7 +396,7 @@ def latency_test_run_once(
         decode_latencies.append(latency)
         if i < 5:
             rank_print(
-                f"Decode.
+                f"Decode. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
             )
 
     if profile:
sglang/bench_serving.py
CHANGED
@@ -490,7 +490,7 @@ def get_dataset(args, tokenizer):
             prompt_suffix=args.prompt_suffix,
             apply_chat_template=args.apply_chat_template,
         )
-    elif args.dataset_name
+    elif args.dataset_name.startswith("random"):
         input_requests = sample_random_requests(
             input_len=args.random_input_len,
             output_len=args.random_output_len,
@@ -498,6 +498,7 @@ def get_dataset(args, tokenizer):
             range_ratio=args.random_range_ratio,
             tokenizer=tokenizer,
             dataset_path=args.dataset_path,
+            random_sample=args.dataset_name == "random",
         )
     elif args.dataset_name == "generated-shared-prefix":
         input_requests = sample_generated_shared_prefix_requests(
@@ -687,6 +688,7 @@ def sample_random_requests(
     range_ratio: float,
     tokenizer: PreTrainedTokenizerBase,
     dataset_path: str,
+    random_sample: bool = True,
 ) -> List[Tuple[str, int, int]]:
 
     input_lens = np.random.randint(
@@ -700,7 +702,7 @@ def sample_random_requests(
         size=num_prompts,
     )
 
-    if
+    if random_sample:
         # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
 
         # Download sharegpt if necessary
@@ -1223,7 +1225,7 @@ async def benchmark(
         output_file_name = args.output_file
     else:
         now = datetime.now().strftime("%m%d")
-        if args.dataset_name
+        if args.dataset_name.startswith("random"):
            output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
         else:
            output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl"
@@ -1442,7 +1444,7 @@ if __name__ == "__main__":
         "--dataset-name",
         type=str,
         default="sharegpt",
-        choices=["sharegpt", "random", "generated-shared-prefix"],
+        choices=["sharegpt", "random", "random-ids", "generated-shared-prefix"],
         help="Name of the dataset to benchmark on.",
     )
     parser.add_argument(
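The bench_serving.py hunks above add a "random-ids" dataset choice alongside "random". Both are routed through sample_random_requests(); the only difference visible in this diff is the new random_sample flag, which stays True for "random" (token ids are still drawn from ShareGPT text, as before) and becomes False for "random-ids". A minimal illustration of that mapping in Python, assuming nothing beyond what the hunks show:

# Illustration only: the flag derived in get_dataset() above.
for dataset_name in ("random", "random-ids"):
    random_sample = dataset_name == "random"
    print(f"--dataset-name {dataset_name} -> random_sample={random_sample}")

The output-file naming in benchmark() uses the same startswith("random") check, so both choices produce the random-style jsonl filename.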
sglang/lang/backend/anthropic.py
CHANGED
sglang/lang/backend/openai.py
CHANGED
sglang/lang/backend/vertexai.py
CHANGED
sglang/lang/compiler.py
CHANGED
@@ -5,13 +5,7 @@ from typing import List, Union
 
 from sglang.global_config import global_config
 from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
-from sglang.lang.ir import (
-    SglArgument,
-    SglConstantText,
-    SglExpr,
-    SglSamplingParams,
-    SglVariable,
-)
+from sglang.lang.ir import SglArgument, SglExpr, SglSamplingParams, SglVariable
 
 
 def compile_func(function, backend):
sglang/lang/tracer.py
CHANGED
@@ -1,20 +1,16 @@
 """Tracing a program."""
 
 import uuid
-from typing import Any,
+from typing import Any, Dict, List, Optional
 
-from sglang.global_config import global_config
 from sglang.lang.backend.base_backend import BaseBackend
 from sglang.lang.interpreter import ProgramState, ProgramStateGroup
 from sglang.lang.ir import (
     SglArgument,
-    SglCommitLazy,
-    SglConcateAndAppend,
     SglConstantText,
     SglExpr,
     SglExprList,
     SglFork,
-    SglFunction,
     SglGen,
     SglGetForkItem,
     SglRoleBegin,
@@ -230,8 +226,8 @@ class TracerProgramState(ProgramState):
         self.cur_role = None
 
     def _execute_var_scope_end(self, expr: SglVarScopeEnd):
-        new_node = SglVariable(name, source=self.last_node)
-        self.variables[name] = new_node
+        new_node = SglVariable(expr.name, source=self.last_node)
+        self.variables[expr.name] = new_node
 
     def get_var(self, name):
         ret = self.arguments.get(name, None)
sglang/srt/_custom_ops.py
CHANGED
sglang/srt/configs/model_config.py
CHANGED
@@ -15,6 +15,7 @@
 import json
 import logging
 import math
+import os
 from enum import IntEnum, auto
 from typing import List, Optional, Set, Union
 
@@ -42,10 +43,12 @@ class ModelConfig:
         context_length: Optional[int] = None,
         model_override_args: Optional[str] = None,
         is_embedding: Optional[bool] = None,
+        enable_multimodal: Optional[bool] = None,
         dtype: str = "auto",
         quantization: Optional[str] = None,
         override_config_file: Optional[str] = None,
     ) -> None:
+
         self.model_path = model_path
         self.revision = revision
         self.quantization = quantization
@@ -69,14 +72,28 @@ class ModelConfig:
             self.hf_text_config, "attention_chunk_size", None
         )
 
+        if enable_multimodal is None:
+            if self.hf_config.architectures == "Llama4ForConditionalGeneration":
+                enable_multimodal = False
+            else:
+                enable_multimodal = True
+
         # Check model type
         self.is_generation = is_generation_model(
             self.hf_config.architectures, is_embedding
         )
-        self.is_multimodal = is_multimodal_model(
-
-
-        self.
+        self.is_multimodal = enable_multimodal and is_multimodal_model(
+            self.hf_config.architectures
+        )
+        self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
+            self.hf_config.architectures
+        )
+        self.is_image_gen = enable_multimodal and is_image_gen_model(
+            self.hf_config.architectures
+        )
+        self.is_audio_model = enable_multimodal and is_audio_model(
+            self.hf_config.architectures
+        )
         self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
 
@@ -234,6 +251,20 @@ class ModelConfig:
         if quant_cfg is None:
             # compressed-tensors uses a "compression_config" key
             quant_cfg = getattr(self.hf_config, "compression_config", None)
+        if quant_cfg is None:
+            # check if is modelopt model -- modelopt doesn't have corresponding field
+            # in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory
+            # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
+            is_local = os.path.exists(self.model_path)
+            modelopt_quant_config = {"quant_method": "modelopt"}
+            if not is_local:
+                from huggingface_hub import HfApi
+
+                hf_api = HfApi()
+                if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
+                    quant_cfg = modelopt_quant_config
+            elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
+                quant_cfg = modelopt_quant_config
         return quant_cfg
 
     # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
@@ -264,6 +295,7 @@ class ModelConfig:
             "moe_wna16",
         ]
         compatible_quantization_methods = {
+            "modelopt_fp4": ["modelopt"],
             "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
             "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
         }
@@ -470,8 +502,8 @@ multimodal_model_archs = [
     "Gemma3ForConditionalGeneration",
     "Grok1VForCausalLM",
     "Grok1AForCausalLM",
-    # TODO: add multimodal support for "Llama4ForConditionalGeneration",
     "LlavaLlamaForCausalLM",
+    "Llama4ForConditionalGeneration",
     "LlavaMistralForCausalLM",
     "LlavaQwenForCausalLM",
     "LlavaVidForCausalLM",
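The quantization hunk above adds a fallback for NVIDIA ModelOpt checkpoints, which carry no quantization_config in config.json but ship a root-level hf_quant_config.json. A standalone sketch of that detection rule, assuming only what the hunk shows; the helper name below is illustrative and not part of sglang:

import os

from huggingface_hub import HfApi


def detect_modelopt_quant(model_path: str):
    # Mirrors the fallback added to ModelConfig above: treat the checkpoint as
    # ModelOpt-quantized when hf_quant_config.json exists at the repo root.
    if os.path.exists(model_path):
        # Local checkout: look for the file next to config.json.
        if os.path.exists(os.path.join(model_path, "hf_quant_config.json")):
            return {"quant_method": "modelopt"}
        return None
    # Remote repo id: ask the Hub whether the file exists.
    if HfApi().file_exists(model_path, "hf_quant_config.json"):
        return {"quant_method": "modelopt"}
    return None

For example, the nvidia/Llama-3.1-8B-Instruct-FP8 checkpoint linked in the hunk's comment would be picked up by the remote branch.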
sglang/srt/constrained/base_grammar_backend.py
CHANGED
@@ -28,6 +28,18 @@ logger = logging.getLogger(__name__)
 
 
 class BaseGrammarObject(ABC):
+
+    def __init__(self):
+        self._finished = False
+
+    @property
+    def finished(self):
+        return self._finished
+
+    @finished.setter
+    def finished(self, finished):
+        self._finished = finished
+
     @abstractmethod
     def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
         """
@@ -59,6 +71,13 @@ class BaseGrammarObject(ABC):
         """
         raise NotImplementedError
 
+    @abstractmethod
+    def accept_token(self, token: int) -> None:
+        """
+        Accept a token in the grammar.
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device
@@ -90,7 +109,7 @@ class CacheEntry:
     event: Event
 
 
-class BaseGrammarBackend(ABC):
+class BaseGrammarBackend:
     def __init__(self):
         self.executor = ThreadPoolExecutor()
         self.cache: Dict[Tuple[str, str], CacheEntry] = {}
@@ -107,19 +126,15 @@ class BaseGrammarBackend(ABC):
         """
         raise ValueError(f"Invalid key_type: {key_type}={key_string}")
 
-    @abstractmethod
     def dispatch_json(self, key_string: str) -> Optional[BaseGrammarObject]:
         return self._not_supported("json", key_string)
 
-    @abstractmethod
     def dispatch_regex(self, key_string: str) -> Optional[BaseGrammarObject]:
         return self._not_supported("regex", key_string)
 
-    @abstractmethod
     def dispatch_ebnf(self, key_string: str) -> Optional[BaseGrammarObject]:
         return self._not_supported("ebnf", key_string)
 
-    @abstractmethod
     def dispatch_structural_tag(self, key_string: str) -> Optional[BaseGrammarObject]:
         return self._not_supported("structural_tag", key_string)
 
@@ -195,4 +210,10 @@ def create_grammar_backend(
     else:
         raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
 
+    if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
+        from .reasoner_grammar_backend import ReasonerGrammarBackend
+
+        grammar_backend = ReasonerGrammarBackend(
+            grammar_backend, tokenizer.think_end_id
+        )
     return grammar_backend
sglang/srt/constrained/llguidance_backend.py
CHANGED
@@ -33,6 +33,7 @@ class GuidanceGrammar(BaseGrammarObject):
     def __init__(
         self, llguidance_tokenizer: llguidance.LLTokenizer, serialized_grammar: str
     ):
+        super().__init__()
         self.llguidance_tokenizer = llguidance_tokenizer
         self.serialized_grammar = serialized_grammar
 
sglang/srt/constrained/outlines_jump_forward.py
CHANGED
@@ -19,10 +19,13 @@ Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
 import dataclasses
 import logging
 from collections import defaultdict
+from typing import Optional
 
 import interegular
 from interegular import InvalidSyntax
-from outlines.caching import cache
+from outlines.caching import cache
+
+from sglang.srt.utils import get_bool_env_var
 
 try:
     # outlines >= 0.1.0
@@ -34,6 +37,9 @@ except ImportError:
 
 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
 
+# Env var was set in sglang.srt.server_args.ServerArgs.__post__init__
+DISABLE_DISK_CACHE = get_bool_env_var("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true")
+
 logger = logging.getLogger(__name__)
 
 
@@ -45,6 +51,13 @@ class JumpEdge:
     byte_next_state: int = None
 
 
+def disk_cache(expire: Optional[float] = None, typed=False, ignore=()):
+    if not DISABLE_DISK_CACHE:
+        return cache(expire, typed, ignore)
+    else:
+        return lambda fn: None
+
+
 @disk_cache()
 def init_state_to_jump_forward(regex_string):
     try:
sglang/srt/constrained/reasoner_grammar_backend.py
ADDED
@@ -0,0 +1,101 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The baseclass of a backend for reasoner grammar-guided constrained decoding."""
+
+from concurrent.futures import Future
+from typing import List, Optional, Tuple
+
+import torch
+
+from .base_grammar_backend import BaseGrammarBackend, BaseGrammarObject
+
+
+class ReasonerGrammarObject(BaseGrammarObject):
+    def __init__(self, grammar: BaseGrammarObject, think_end_id):
+        super().__init__()
+        self.grammar = grammar
+        self.think_end_id = think_end_id
+        self.is_in_reasoning = True
+
+    @property
+    def finished(self):
+        return self.grammar.finished
+
+    @finished.setter
+    def finished(self, finished):
+        self.grammar.finished = finished
+
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return self.grammar.allocate_vocab_mask(vocab_size, batch_size, device)
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        if not self.is_in_reasoning:
+            self.grammar.fill_vocab_mask(vocab_mask, idx)
+
+    def move_vocab_mask(self, vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return self.grammar.move_vocab_mask(vocab_mask, device)
+
+    @property
+    def apply_vocab_mask(self):
+        return self.grammar.apply_vocab_mask
+
+    def accept_token(self, token: int):
+        if token == self.think_end_id:
+            self.is_in_reasoning = False
+
+        if not self.is_in_reasoning and token != self.think_end_id:
+            self.grammar.accept_token(token)
+
+    def try_jump_forward(self, tokenizer):
+        return self.grammar.try_jump_forward(tokenizer)
+
+    def jump_forward_str_state(self, helper):
+        return self.grammar.jump_forward_str_state(helper)
+
+    def jump_and_retokenize(
+        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
+    ):
+        return self.grammar.jump_and_retokenize(
+            old_output_ids, new_output_ids, next_state
+        )
+
+    def copy(self) -> BaseGrammarObject:
+        return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)
+
+
+class ReasonerGrammarBackend(BaseGrammarBackend):
+    def __init__(self, grammar_backend: BaseGrammarBackend, think_end_id):
+        self.grammar_backend = grammar_backend
+        self.think_end_id = think_end_id
+
+    def get_cached_value(self, key: Tuple[str, str]) -> Optional[ReasonerGrammarObject]:
+        grammar = self.grammar_backend.get_cached_value(key)
+        return ReasonerGrammarObject(grammar, self.think_end_id) if grammar else None
+
+    def get_future_value(self, key: Tuple[str, str]) -> Future:
+        grammar = Future()
+
+        def callback(f: Future):
+            if result := f.result():
+                grammar.set_result(ReasonerGrammarObject(result, self.think_end_id))
+            else:
+                grammar.set_result(None)
+
+        self.grammar_backend.get_future_value(key).add_done_callback(callback)
+        return grammar
+
+    def reset(self):
+        self.grammar_backend.reset()