sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +220 -378
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +9 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +143 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +681 -259
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +224 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +44 -18
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +94 -36
- sglang/srt/model_executor/cuda_graph_runner.py +55 -24
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +208 -28
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +136 -52
- sglang/srt/speculative/build_eagle_tree.py +2 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
- sglang/srt/speculative/eagle_utils.py +92 -58
- sglang/srt/speculative/eagle_worker.py +186 -94
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,151 @@
|
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
14
|
+
"""Constrained decoding with llguidance backend."""
|
15
|
+
|
16
|
+
import json
|
17
|
+
import os
|
18
|
+
from typing import List, Optional, Tuple
|
19
|
+
|
20
|
+
import llguidance
|
21
|
+
import llguidance.hf
|
22
|
+
import llguidance.torch
|
23
|
+
import torch
|
24
|
+
from llguidance.gbnf_to_lark import any_to_lark
|
25
|
+
|
26
|
+
from sglang.srt.constrained.base_grammar_backend import (
|
27
|
+
BaseGrammarBackend,
|
28
|
+
BaseGrammarObject,
|
29
|
+
)
|
30
|
+
|
31
|
+
|
32
|
+
class GuidanceGrammar(BaseGrammarObject):
    """Per-request grammar state backed by an llguidance ``LLInterpreter``.

    One instance tracks the constrained-decoding state for a single request:
    it commits generated tokens, buffers fast-forward tokens produced by the
    interpreter, and fills per-row vocabulary bitmasks for the sampler.
    """

    def __init__(
        self, llguidance_tokenizer: llguidance.LLTokenizer, serialized_grammar: str
    ):
        self.llguidance_tokenizer = llguidance_tokenizer
        self.serialized_grammar = serialized_grammar

        # TODO: add support for fast-forward tokens in the future
        self.ll_interpreter = llguidance.LLInterpreter(
            self.llguidance_tokenizer,
            self.serialized_grammar,
            enable_backtrack=False,
            enable_ff_tokens=False,
            log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
        )
        # Fast-forward tokens the interpreter has forced but the scheduler
        # has not consumed yet.
        self.pending_ff_tokens: list[int] = []
        self.finished = False
        # Lazily-allocated token bitmask, reused across calls (see
        # allocate_vocab_mask).
        self.bitmask = None

    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
        """Return (tokens, decoded_string) of pending fast-forward tokens, or None."""
        if len(self.pending_ff_tokens) > 0:
            s = self.llguidance_tokenizer.decode_str(self.pending_ff_tokens)
            ff_tokens = self.pending_ff_tokens
            self.pending_ff_tokens = []
            return (ff_tokens, s)

        return None

    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        # No string-level jump-forward state is tracked for this backend.
        return "", -1

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        # Nothing to do: backtracking is disabled, so the committed token
        # stream never needs re-tokenization.
        pass

    def accept_token(self, token: int):
        """Commit a generated token and queue any resulting fast-forward tokens."""
        backtrack, ff_tokens = self.ll_interpreter.commit_token(token)
        if len(ff_tokens) > 0 and backtrack == 0:
            # first token is last generated token
            ff_tokens = ff_tokens[1:]
        self.pending_ff_tokens.extend(ff_tokens)

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        """Fill row ``idx`` of the 32-bit-packed allowed-token bitmask."""
        if len(self.pending_ff_tokens) > 0:
            # If we have pending fast-forward tokens, force the next one by
            # allowing only that single token in the mask.
            ff_token = self.pending_ff_tokens.pop(0)
            vocab_mask[idx, :] = 0
            vocab_mask[idx, ff_token // 32] = 1 << (ff_token % 32)
            return

        if self.ll_interpreter.has_pending_stop():
            self.finished = True

        llguidance.torch.fill_next_token_bitmask(self.ll_interpreter, vocab_mask, idx)

    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        # NOTE(review): `vocab_size` and `device` are unused — the mask is
        # sized from the llguidance tokenizer's vocab and allocated wherever
        # llguidance puts it; confirm the caller moves it via move_vocab_mask.
        if self.bitmask is None or self.bitmask.shape[0] < batch_size:
            # Only reallocate when the batch gets larger; otherwise reuse a
            # slice of the cached bitmask.
            self.bitmask = llguidance.torch.allocate_token_bitmask(
                batch_size, self.llguidance_tokenizer.vocab_size
            )
            bitmask = self.bitmask
        else:
            bitmask = self.bitmask[:batch_size]

        return bitmask

    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        """Asynchronously move the mask to the compute device."""
        return vocab_mask.to(device, non_blocking=True)

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        """Mask disallowed tokens' logits in place."""
        llguidance.torch.apply_token_bitmask_inplace(logits, vocab_mask)

    def copy(self):
        """Create a fresh grammar state over the same serialized grammar."""
        return GuidanceGrammar(
            llguidance_tokenizer=self.llguidance_tokenizer,
            serialized_grammar=self.serialized_grammar,
        )
|
116
|
+
|
117
|
+
|
118
|
+
class GuidanceBackend(BaseGrammarBackend):
    """Grammar backend that compiles keys with the llguidance library.

    Supports JSON-schema, regex, and EBNF keys (EBNF is translated to Lark
    via ``any_to_lark``); structural tags fall through to the base class.
    """

    def __init__(self, tokenizer, whitespace_pattern: Optional[str] = None):
        super().__init__()

        self.tokenizer = tokenizer
        # Only the exact value "whitespace_flexible" enables flexible
        # whitespace in compiled JSON grammars.
        # (Idiom fix: was `True if ... else False`.)
        self.whitespace_flexible = whitespace_pattern == "whitespace_flexible"
        self.llguidance_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None)

    def _from_serialized(self, serialized_grammar) -> GuidanceGrammar:
        # Shared constructor for all dispatch_* paths.
        return GuidanceGrammar(
            llguidance_tokenizer=self.llguidance_tokenizer,
            serialized_grammar=serialized_grammar,
        )

    def dispatch_json(self, key_string: str) -> GuidanceGrammar:
        """Compile a JSON-schema key into a grammar."""
        json_schema = key_string
        compiler = llguidance.JsonCompiler(whitespace_flexible=self.whitespace_flexible)
        serialized_grammar = compiler.compile(json_schema)
        return self._from_serialized(serialized_grammar)

    def dispatch_regex(self, key_string: str) -> GuidanceGrammar:
        """Compile a regex key into a grammar."""
        compiler = llguidance.RegexCompiler()
        serialized_grammar = compiler.compile(regex=key_string)
        return self._from_serialized(serialized_grammar)

    def dispatch_ebnf(self, key_string: str) -> GuidanceGrammar:
        """Compile an EBNF key by first converting it to Lark syntax."""
        compiler = llguidance.LarkCompiler()
        serialized_grammar = compiler.compile(any_to_lark(key_string))
        return self._from_serialized(serialized_grammar)

    def dispatch_structural_tag(self, key_string: str):
        # Structural tags are not supported by llguidance here; defer to the
        # base class's default handling.
        return super().dispatch_structural_tag(key_string)
|
@@ -28,17 +28,11 @@ from sglang.srt.constrained.base_grammar_backend import (
|
|
28
28
|
BaseGrammarObject,
|
29
29
|
)
|
30
30
|
from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap
|
31
|
-
from sglang.srt.utils import is_hip
|
32
31
|
|
33
|
-
|
34
|
-
|
35
|
-
|
32
|
+
try:
|
33
|
+
from outlines.fsm.json_schema import build_regex_from_schema
|
34
|
+
except ImportError:
|
36
35
|
from outlines_core.fsm.json_schema import build_regex_from_schema
|
37
|
-
else:
|
38
|
-
try:
|
39
|
-
from outlines.fsm.json_schema import build_regex_from_schema
|
40
|
-
except ImportError:
|
41
|
-
from outlines_core.fsm.json_schema import build_regex_from_schema
|
42
36
|
|
43
37
|
|
44
38
|
logger = logging.getLogger(__name__)
|
@@ -121,7 +115,6 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
|
|
121
115
|
self,
|
122
116
|
tokenizer,
|
123
117
|
whitespace_pattern: bool,
|
124
|
-
allow_jump_forward: bool,
|
125
118
|
):
|
126
119
|
super().__init__()
|
127
120
|
|
@@ -146,27 +139,9 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
|
|
146
139
|
self.outlines_tokenizer.vocabulary = (
|
147
140
|
self.outlines_tokenizer.tokenizer.get_vocab()
|
148
141
|
)
|
149
|
-
self.allow_jump_forward = allow_jump_forward
|
150
142
|
self.whitespace_pattern = whitespace_pattern
|
151
143
|
|
152
|
-
def
|
153
|
-
key_type, key_string = key
|
154
|
-
if key_type == "json":
|
155
|
-
try:
|
156
|
-
regex = build_regex_from_object(
|
157
|
-
key_string,
|
158
|
-
whitespace_pattern=self.whitespace_pattern,
|
159
|
-
)
|
160
|
-
except (NotImplementedError, json.decoder.JSONDecodeError) as e:
|
161
|
-
logger.warning(
|
162
|
-
f"Skip invalid json_schema: json_schema={key_string}, {e=}"
|
163
|
-
)
|
164
|
-
return None
|
165
|
-
elif key_type == "regex":
|
166
|
-
regex = key_string
|
167
|
-
else:
|
168
|
-
raise ValueError(f"Invalid key_type: {key_type}")
|
169
|
-
|
144
|
+
def _compile_regex(self, regex: str) -> Optional[OutlinesGrammar]:
|
170
145
|
try:
|
171
146
|
if hasattr(RegexGuide, "from_regex"):
|
172
147
|
# outlines >= 0.1.1
|
@@ -178,12 +153,28 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
|
|
178
153
|
logger.warning(f"skip invalid regex schema: {regex=}, {e=}")
|
179
154
|
return None
|
180
155
|
|
181
|
-
|
182
|
-
jump_forward_map = OutlinesJumpForwardMap(regex)
|
183
|
-
else:
|
184
|
-
jump_forward_map = None
|
156
|
+
jump_forward_map = None
|
185
157
|
return OutlinesGrammar(guide, jump_forward_map)
|
186
158
|
|
159
|
+
def dispatch_ebnf(self, key_string: str):
    # EBNF is not handled by the outlines backend; defer to the base class's
    # default (unsupported-key) handling.
    return super().dispatch_ebnf(key_string)
|
161
|
+
|
162
|
+
def dispatch_structural_tag(self, key_string: str):
    # Structural tags are not handled by the outlines backend; defer to the
    # base class's default (unsupported-key) handling.
    return super().dispatch_structural_tag(key_string)
|
164
|
+
|
165
|
+
def dispatch_json(self, key_string: str):
    """Compile a JSON-schema key into a grammar via an equivalent regex.

    Returns None (skipping the request) when the schema cannot be converted.
    """
    try:
        regex = build_regex_from_object(
            key_string,
            whitespace_pattern=self.whitespace_pattern,
        )
    except (NotImplementedError, json.decoder.JSONDecodeError) as e:
        logger.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
        # Bug fix: without this early return, `regex` is unbound below and the
        # failure surfaced as UnboundLocalError instead of a skipped request.
        return None
    return self._compile_regex(regex)
|
174
|
+
|
175
|
+
def dispatch_regex(self, key_string: str):
    # Regex keys are compiled directly; _compile_regex returns None for
    # patterns the guide cannot handle.
    return self._compile_regex(key_string)
|
177
|
+
|
187
178
|
|
188
179
|
def build_regex_from_object(
|
189
180
|
object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
|
@@ -13,15 +13,16 @@
|
|
13
13
|
# ==============================================================================
|
14
14
|
"""Constrained decoding with xgrammar backend."""
|
15
15
|
|
16
|
+
import json
|
16
17
|
import logging
|
17
|
-
from typing import List, Tuple
|
18
|
+
from typing import List, Optional, Tuple, Union
|
18
19
|
|
19
20
|
import torch
|
20
21
|
from xgrammar import (
|
21
22
|
CompiledGrammar,
|
22
|
-
Grammar,
|
23
23
|
GrammarCompiler,
|
24
24
|
GrammarMatcher,
|
25
|
+
StructuralTagItem,
|
25
26
|
TokenizerInfo,
|
26
27
|
allocate_token_bitmask,
|
27
28
|
apply_token_bitmask_inplace,
|
@@ -41,17 +42,22 @@ MAX_ROLLBACK_TOKENS = 200
|
|
41
42
|
class XGrammarGrammar(BaseGrammarObject):
|
42
43
|
|
43
44
|
def __init__(
    self,
    matcher: GrammarMatcher,
    vocab_size: int,
    ctx: CompiledGrammar,
    override_stop_tokens: Optional[Union[List[int], int]],
) -> None:
    """Wrap a grammar matcher with its compiled grammar and vocab size.

    `ctx` and `override_stop_tokens` are retained so copy() can build an
    independent matcher over the same compiled grammar.
    """
    self.matcher = matcher
    self.vocab_size = vocab_size
    self.ctx = ctx
    self.override_stop_tokens = override_stop_tokens
    # Set when the matcher reports termination; read by the scheduler.
    self.finished = False
|
50
56
|
|
51
57
|
def accept_token(self, token: int):
    """Advance the matcher by one generated token.

    NOTE(review): this relies on `assert` for the rejected-token case, which
    is stripped under `python -O`; an explicit raise would be safer — confirm
    callers do not depend on AssertionError before changing.
    """
    assert self.matcher.accept_token(token)
|
53
59
|
|
54
|
-
def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]:
|
60
|
+
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
|
55
61
|
s = self.matcher.find_jump_forward_string()
|
56
62
|
if s:
|
57
63
|
return [], s
|
@@ -95,8 +101,14 @@ class XGrammarGrammar(BaseGrammarObject):
|
|
95
101
|
apply_token_bitmask_inplace(logits, vocab_mask)
|
96
102
|
|
97
103
|
def copy(self):
    """Return an independent grammar state over the same compiled grammar."""
    return XGrammarGrammar(
        GrammarMatcher(
            self.ctx,
            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
            override_stop_tokens=self.override_stop_tokens,
        ),
        self.vocab_size,
        self.ctx,
        self.override_stop_tokens,
    )
|
100
112
|
|
101
113
|
|
102
114
|
class XGrammarGrammarBackend(BaseGrammarBackend):
|
@@ -110,42 +122,61 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|
110
122
|
tokenizer_info = TokenizerInfo.from_huggingface(
|
111
123
|
tokenizer, vocab_size=vocab_size
|
112
124
|
)
|
125
|
+
override_stop_tokens = None
|
126
|
+
|
113
127
|
self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
|
114
128
|
self.vocab_size = vocab_size
|
129
|
+
self.override_stop_tokens = override_stop_tokens
|
115
130
|
|
116
|
-
def
|
117
|
-
|
118
|
-
key_type, key_string = key
|
119
|
-
if key_type == "json":
|
120
|
-
try:
|
121
|
-
if key_string == "$$ANY$$":
|
122
|
-
ctx = self.grammar_compiler.compile_builtin_json_grammar()
|
123
|
-
else:
|
124
|
-
ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
|
125
|
-
except RuntimeError as e:
|
126
|
-
logging.warning(
|
127
|
-
f"Skip invalid json_schema: json_schema={key_string}, {e=}"
|
128
|
-
)
|
129
|
-
return None
|
130
|
-
elif key_type == "ebnf":
|
131
|
-
try:
|
132
|
-
ctx = self.grammar_compiler.compile_grammar(key_string)
|
133
|
-
except RuntimeError as e:
|
134
|
-
logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
|
135
|
-
return None
|
136
|
-
elif key_type == "regex":
|
137
|
-
try:
|
138
|
-
ctx = self.grammar_compiler.compile_grammar(
|
139
|
-
Grammar.from_regex(key_string)
|
140
|
-
)
|
141
|
-
except RuntimeError as e:
|
142
|
-
logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
|
143
|
-
return None
|
144
|
-
else:
|
145
|
-
raise ValueError(f"Invalid key_type: {key_type}")
|
146
|
-
|
131
|
+
def _from_context(self, ctx: CompiledGrammar) -> XGrammarGrammar:
    """Wrap a compiled grammar in a ready-to-use XGrammarGrammar."""
    # Consistency fix: XGrammarGrammar.copy() builds its matcher with
    # override_stop_tokens, but this factory dropped it. The backend's
    # override_stop_tokens is currently always None, so behavior is
    # unchanged today, but the two construction paths now agree.
    matcher = GrammarMatcher(
        ctx,
        max_rollback_tokens=MAX_ROLLBACK_TOKENS,
        override_stop_tokens=self.override_stop_tokens,
    )
    return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens)
|
134
|
+
|
135
|
+
def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
    """Compile a JSON-schema key; the sentinel "$$ANY$$" selects the
    builtin any-JSON grammar. Returns None if compilation fails."""
    try:
        if key_string == "$$ANY$$":
            compiled = self.grammar_compiler.compile_builtin_json_grammar()
        else:
            compiled = self.grammar_compiler.compile_json_schema(schema=key_string)
    except RuntimeError as e:
        logging.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
        return None
    return self._from_context(compiled)
|
145
|
+
|
146
|
+
def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
    """Compile an EBNF grammar key, or return None if it is invalid."""
    try:
        compiled = self.grammar_compiler.compile_grammar(key_string)
    except RuntimeError as e:
        logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
        return None
    return self._from_context(compiled)
|
153
|
+
|
154
|
+
def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
    """Compile a regex key, or return None if it is invalid."""
    try:
        compiled = self.grammar_compiler.compile_regex(key_string)
    except RuntimeError as e:
        logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
        return None
    return self._from_context(compiled)
|
161
|
+
|
162
|
+
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
    """Compile a structural-tag key.

    `key_string` is JSON of the form
    {"structures": [{"begin":..., "schema":..., "end":...}, ...],
     "triggers": [...]}.
    Returns None (skipping the request) when the spec is invalid.
    """
    try:
        structural_tag = json.loads(key_string)
        tags = [
            StructuralTagItem(
                begin=structure["begin"],
                schema=json.dumps(structure["schema"]),
                end=structure["end"],
            )
            for structure in structural_tag["structures"]
        ]
        ctx = self.grammar_compiler.compile_structural_tag(
            tags, structural_tag["triggers"]
        )
    except (RuntimeError, json.JSONDecodeError, KeyError, TypeError) as e:
        # Bug fix: the warning was copy-pasted from dispatch_regex and said
        # "Skip invalid regex". Also catch malformed JSON / missing keys so a
        # bad spec is skipped like other invalid keys instead of crashing.
        logging.warning(
            f"Skip invalid structural_tag: structural_tag={key_string}, {e=}"
        )
        return None
    return self._from_context(ctx)
|
149
180
|
|
150
181
|
def reset(self):
|
151
182
|
if self.grammar_compiler:
|