sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +3 -13
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +158 -8
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +119 -75
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +5 -2
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +18 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +71 -53
- sglang/srt/conversation.py +78 -46
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +11 -3
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +236 -138
- sglang/srt/disaggregation/nixl/conn.py +242 -71
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +51 -2
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +31 -4
- sglang/srt/entrypoints/http_server.py +45 -3
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +147 -51
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
- sglang/srt/layers/moe/ep_moe/layer.py +121 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +77 -71
- sglang/srt/layers/quantization/fp8.py +110 -97
- sglang/srt/layers/quantization/fp8_kernel.py +81 -62
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +11 -14
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +13 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +93 -23
- sglang/srt/managers/schedule_policy.py +11 -8
- sglang/srt/managers/scheduler.py +140 -100
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +157 -47
- sglang/srt/managers/tp_worker.py +21 -21
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +4 -2
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +57 -41
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +3 -3
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +77 -39
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +3 -1
- sglang/srt/models/llama4.py +58 -13
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +52 -42
- sglang/srt/openai_api/protocol.py +20 -16
- sglang/srt/reasoning_parser.py +1 -1
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +2 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +64 -10
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +41 -6
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +92 -15
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/configs/model_config.py
CHANGED
@@ -24,6 +24,7 @@ from transformers import PretrainedConfig
 
 from sglang.srt.hf_transformers_utils import get_config, get_context_length
 from sglang.srt.layers.quantization import QUANTIZATION_METHODS
+from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import get_bool_env_var, is_hip
 
 logger = logging.getLogger(__name__)
@@ -210,6 +211,21 @@ class ModelConfig:
         self.hf_eos_token_id = self.get_hf_eos_token_id()
         self.image_token_id = getattr(self.hf_config, "image_token_id", None)
 
+    @staticmethod
+    def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
+        return ModelConfig(
+            model_path=model_path or server_args.model_path,
+            trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
+            context_length=server_args.context_length,
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
+            enable_multimodal=server_args.enable_multimodal,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
+            **kwargs,
+        )
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
@@ -529,6 +545,7 @@ multimodal_model_archs = [
     "Llama4ForConditionalGeneration",
     "LlavaMistralForCausalLM",
     "LlavaQwenForCausalLM",
+    "LlavaForConditionalGeneration",
     "LlavaVidForCausalLM",
     "MiniCPMO",
     "MiniCPMV",
@@ -538,6 +555,7 @@ multimodal_model_archs = [
     "Qwen2_5_VLForConditionalGeneration",
     "CLIPModel",
     "KimiVLForConditionalGeneration",
+    "InternVLChatModel",
 ]
 
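The new `ModelConfig.from_server_args` factory centralizes construction that callers previously did field by field from `ServerArgs`. A minimal usage sketch, assuming an sglang install and a reachable model path (the paths below are placeholders, not from the diff):

    from sglang.srt.configs.model_config import ModelConfig
    from sglang.srt.server_args import ServerArgs

    server_args = ServerArgs(model_path="meta-llama/Llama-3.1-8B-Instruct")

    # One call replaces the nine keyword arguments copied out of ServerArgs;
    # per-call overrides still pass through model_path and **kwargs.
    model_config = ModelConfig.from_server_args(server_args)
    draft_config = ModelConfig.from_server_args(
        server_args, model_path="/path/to/draft-model"  # hypothetical override
    )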
sglang/srt/constrained/base_grammar_backend.py
CHANGED
@@ -14,10 +14,9 @@
 """The baseclass of a backend for grammar-guided constrained decoding."""
 
 import logging
-from abc import ABC, abstractmethod
-from concurrent.futures import Future, ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
-from threading import Event, Lock
+from threading import Event
 from typing import Dict, List, Optional, Tuple
 
 import torch
@@ -27,11 +26,42 @@ from sglang.srt.server_args import ServerArgs
 logger = logging.getLogger(__name__)
 
 
-class BaseGrammarObject(ABC):
+class BaseGrammarObject:
 
     def __init__(self):
         self._finished = False
 
+    def accept_token(self, token: int) -> None:
+        """
+        Accept a token in the grammar.
+        """
+        raise NotImplementedError()
+
+    def rollback(self, k: int):
+        raise NotImplementedError()
+
+    def is_terminated(self):
+        return False
+
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        raise NotImplementedError()
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        raise NotImplementedError()
+
+    @staticmethod
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
+        raise NotImplementedError()
+
+    def copy(self) -> "BaseGrammarObject":
+        raise NotImplementedError()
+
     @property
     def finished(self):
         return self._finished
@@ -40,7 +70,6 @@ class BaseGrammarObject(ABC):
     def finished(self, finished):
         self._finished = finished
 
-    @abstractmethod
     def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
         """
         Try to jump forward in the grammar.
@@ -49,9 +78,8 @@ class BaseGrammarObject(ABC):
         A jump forward helper which may be used in `jump_forward_str_state`.
         None if the jump forward is not possible.
         """
-        raise NotImplementedError
+        raise NotImplementedError()
 
-    @abstractmethod
     def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
         """
         Jump forward for the grammar.
@@ -60,47 +88,15 @@ class BaseGrammarObject(ABC):
         A tuple of the jump forward string and the next state of the grammar
         (which can be used in `jump_and_retokenize` if needed).
         """
-        raise NotImplementedError
+        raise NotImplementedError()
 
-    @abstractmethod
     def jump_and_retokenize(
         self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
     ) -> None:
        """
         Jump forward occurs, and update the grammar state if needed.
         """
-        raise NotImplementedError
-
-    @abstractmethod
-    def accept_token(self, token: int) -> None:
-        """
-        Accept a token in the grammar.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def allocate_vocab_mask(
-        self, vocab_size: int, batch_size: int, device
-    ) -> torch.Tensor:
-        raise NotImplementedError
-
-    @abstractmethod
-    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
-        raise NotImplementedError
-
-    @staticmethod
-    @abstractmethod
-    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
-        raise NotImplementedError
-
-    @staticmethod
-    @abstractmethod
-    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def copy(self) -> "BaseGrammarObject":
-        raise NotImplementedError
+        raise NotImplementedError()
 
 
 @dataclass
@@ -113,10 +109,9 @@ class BaseGrammarBackend:
     def __init__(self):
         self.executor = ThreadPoolExecutor()
         self.cache: Dict[Tuple[str, str], CacheEntry] = {}
-        self.cache_lock = Lock()
 
     def _not_supported(self, key_type: str, key_string: str) -> None:
-        logger.warning(f"Skip unsupported {key_type}
+        logger.warning(f"Skip unsupported {key_type=}, {key_string=}")
 
     def dispatch_fallback(
         self, key_type: str, key_string: str
@@ -148,40 +143,25 @@ class BaseGrammarBackend:
             return self.dispatch_ebnf(key_string)
         elif key_type == "structural_tag":
             return self.dispatch_structural_tag(key_string)
+        elif key_type == "structural_pattern":
+            return self.dispatch_structural_pattern(key_string)
         else:
             return self.dispatch_fallback(key_type, key_string)
 
-    def _init_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
-        with self.cache_lock:
-            if key in self.cache:
-                cache_hit = True
-                entry = self.cache[key]
-            else:
-                cache_hit = False
-                entry = CacheEntry(None, Event())
-                self.cache[key] = entry
-
-        if cache_hit:
-            entry.event.wait()
-        else:
-            entry.value = self._init_value_dispatch(key)
-            entry.event.set()
-        return entry.value.copy() if entry.value else None
-
-    def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
-        with self.cache_lock:
-            entry = self.cache.get(key)
-            if not entry or not entry.event.is_set():
-                return None
-            val = self.cache[key].value
-            return val.copy() if val else None
+    def get_cached_or_future_value(
+        self, key: Tuple[str, str]
+    ) -> Optional[BaseGrammarObject]:
+        value = self.cache.get(key)
+        if value:
+            return value.copy(), True
+        value = self.executor.submit(self._init_value_dispatch, key)
+        return value, False
 
-    def get_future_value(self, key: Tuple[str, str]) -> Future:
-        return self.executor.submit(self._init_value, key)
+    def set_cache(self, key: Tuple[str, str], value: BaseGrammarObject):
+        self.cache[key] = value
 
     def reset(self):
-        with self.cache_lock:
-            self.cache.clear()
+        self.cache.clear()
 
 
 def create_grammar_backend(
@@ -211,9 +191,12 @@ def create_grammar_backend(
         raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
 
     if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
-        from .reasoner_grammar_backend import ReasonerGrammarBackend
+        from sglang.srt.constrained.reasoner_grammar_backend import (
+            ReasonerGrammarBackend,
+        )
 
         grammar_backend = ReasonerGrammarBackend(
             grammar_backend, tokenizer.think_end_id
         )
+
    return grammar_backend
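The rewrite above replaces the lock-and-Event cache (`_init_value`, `get_cached_value`, `get_future_value`) with a non-blocking pair: `get_cached_or_future_value` returns either a finished grammar copy plus a cache-hit flag, or a `concurrent.futures.Future` scheduled on the backend's executor, and the caller publishes the compiled result with `set_cache`. A sketch of how a caller can drive the new protocol (the `request_grammar` helper is illustrative, not part of sglang):

    from concurrent.futures import Future

    def request_grammar(backend, key):
        value, cache_hit = backend.get_cached_or_future_value(key)
        if cache_hit:
            return value  # already a private copy of the cached grammar
        assert isinstance(value, Future)
        grammar = value.result()  # a real scheduler would poll instead of blocking
        if grammar is None:
            return None  # compilation failed (invalid json_schema/regex/ebnf)
        backend.set_cache(key, grammar)  # publish for later requests
        return grammar.copy()

Note that despite the `Optional[BaseGrammarObject]` annotation, `get_cached_or_future_value` returns a 2-tuple, so callers must unpack it as shown.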
sglang/srt/constrained/llguidance_backend.py
CHANGED
@@ -50,21 +50,6 @@ class GuidanceGrammar(BaseGrammarObject):
         self.finished = False
         self.bitmask = None
 
-    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
-        ff_tokens = self.ll_matcher.compute_ff_tokens()
-        if ff_tokens:
-            return ff_tokens, ""
-        else:
-            return None
-
-    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
-        return "", -1
-
-    def jump_and_retokenize(
-        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
-    ):
-        pass
-
     def accept_token(self, token: int):
         if not self.ll_matcher.consume_token(token):
             logger.warning(f"matcher error: {self.ll_matcher.get_error()}")
@@ -104,6 +89,21 @@ class GuidanceGrammar(BaseGrammarObject):
             serialized_grammar=self.serialized_grammar,
         )
 
+    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
+        ff_tokens = self.ll_matcher.compute_ff_tokens()
+        if ff_tokens:
+            return ff_tokens, ""
+        else:
+            return None
+
+    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
+        return "", -1
+
+    def jump_and_retokenize(
+        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
+    ):
+        pass
+
 
 class GuidanceBackend(BaseGrammarBackend):
 
@@ -130,12 +130,16 @@ class GuidanceBackend(BaseGrammarBackend):
         return None
 
     def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]:
-        serialized_grammar = LLMatcher.grammar_from_json_schema(
-            key_string,
-            defaults={
-                "whitespace_pattern": self.whitespace_pattern,
-            },
-        )
+        try:
+            serialized_grammar = LLMatcher.grammar_from_json_schema(
+                key_string,
+                defaults={
+                    "whitespace_pattern": self.whitespace_pattern,
+                },
+            )
+        except Exception as e:
+            logger.warning(f"Skip invalid grammar: {key_string=}, {e=}")
+            return None
         return self._from_serialized(serialized_grammar)
 
     def dispatch_regex(self, key_string: str):
sglang/srt/constrained/outlines_backend.py
CHANGED
@@ -53,6 +53,30 @@ class OutlinesGrammar(BaseGrammarObject):
     def accept_token(self, token: int):
         self.state = self.guide.get_next_state(self.state, token)
 
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
+
+    @staticmethod
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        tokens = torch.tensor(
+            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
+        ).to(vocab_mask.device, non_blocking=True)
+        vocab_mask = vocab_mask[idx]
+        vocab_mask.fill_(1)
+        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
+        logits.masked_fill_(vocab_mask, float("-inf"))
+
+    def copy(self):
+        return OutlinesGrammar(self.guide, self.jump_forward_map)
+
     def try_jump_forward(self, tokenizer) -> Optional[Tuple]:
         if not self.jump_forward_map:
             return None
@@ -86,30 +110,6 @@ class OutlinesGrammar(BaseGrammarObject):
     ):
         self.state = next_state
 
-    def allocate_vocab_mask(
-        self, vocab_size: int, batch_size: int, device
-    ) -> torch.Tensor:
-        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
-
-    @staticmethod
-    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
-        return vocab_mask
-
-    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
-        tokens = torch.tensor(
-            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
-        ).to(vocab_mask.device, non_blocking=True)
-        vocab_mask = vocab_mask[idx]
-        vocab_mask.fill_(1)
-        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
-
-    @staticmethod
-    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
-        logits.masked_fill_(vocab_mask, float("-inf"))
-
-    def copy(self):
-        return OutlinesGrammar(self.guide, self.jump_forward_map)
-
 
 class OutlinesGrammarBackend(BaseGrammarBackend):
     def __init__(
@@ -169,8 +169,9 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
                 key_string,
                 whitespace_pattern=self.whitespace_pattern,
             )
-        except (NotImplementedError, json.decoder.JSONDecodeError) as e:
-            logger.warning(f"Skip invalid json_schema:
+        except (NotImplementedError, json.decoder.JSONDecodeError, ValueError) as e:
+            logger.warning(f"Skip invalid json_schema: {key_string=}, {e=}")
+            return None
         return self._compile_regex(regex)
 
     def dispatch_regex(self, key_string: str):
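The relocated Outlines methods make the shared mask convention easy to see: `allocate_vocab_mask` creates a boolean tensor where True means "banned", `fill_vocab_mask` bans everything and then clears the currently legal tokens, and `apply_vocab_mask` pushes banned logits to -inf. A self-contained sketch of the same convention, independent of Outlines (the token ids are made up for illustration):

    import torch

    batch_size, vocab_size = 2, 8
    allowed = [torch.tensor([1, 3]), torch.tensor([0, 2, 5])]  # legal ids per row

    # allocate_vocab_mask: True = banned.
    mask = torch.zeros(batch_size, vocab_size, dtype=torch.bool)

    # fill_vocab_mask: ban all, then clear the allowed ids for each row.
    for idx in range(batch_size):
        row = mask[idx]
        row.fill_(1)
        row.scatter_(0, allowed[idx], torch.zeros_like(allowed[idx], dtype=torch.bool))

    # apply_vocab_mask: banned positions can never be sampled.
    logits = torch.randn(batch_size, vocab_size)
    logits.masked_fill_(mask, float("-inf"))
    assert torch.isinf(logits[0, 0]) and not torch.isinf(logits[0, 1])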
sglang/srt/constrained/reasoner_grammar_backend.py
CHANGED
@@ -13,7 +13,6 @@
 # ==============================================================================
 """The baseclass of a backend for reasoner grammar-guided constrained decoding."""
 
-from concurrent.futures import Future
 from typing import List, Optional, Tuple
 
 import torch
@@ -28,13 +27,12 @@ class ReasonerGrammarObject(BaseGrammarObject):
         self.think_end_id = think_end_id
         self.is_in_reasoning = True
 
-    @property
-    def finished(self):
-        return self.grammar.finished
+    def accept_token(self, token: int):
+        if token == self.think_end_id:
+            self.is_in_reasoning = False
 
-    @finished.setter
-    def finished(self, finished):
-        self.grammar.finished = finished
+        if not self.is_in_reasoning and token != self.think_end_id:
+            self.grammar.accept_token(token)
 
     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device
@@ -52,12 +50,16 @@ class ReasonerGrammarObject(BaseGrammarObject):
     def apply_vocab_mask(self):
         return self.grammar.apply_vocab_mask
 
-    def accept_token(self, token: int):
-        if token == self.think_end_id:
-            self.is_in_reasoning = False
+    def copy(self) -> BaseGrammarObject:
+        return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)
 
-        if not self.is_in_reasoning and token != self.think_end_id:
-            self.grammar.accept_token(token)
+    @property
+    def finished(self):
+        return self.grammar.finished
+
+    @finished.setter
+    def finished(self, finished):
+        self.grammar.finished = finished
 
     def try_jump_forward(self, tokenizer):
         return self.grammar.try_jump_forward(tokenizer)
@@ -72,30 +74,17 @@ class ReasonerGrammarObject(BaseGrammarObject):
             old_output_ids, new_output_ids, next_state
         )
 
-    def copy(self) -> BaseGrammarObject:
-        return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)
-
 
 class ReasonerGrammarBackend(BaseGrammarBackend):
     def __init__(self, grammar_backend: BaseGrammarBackend, think_end_id):
+        super().__init__()
         self.grammar_backend = grammar_backend
         self.think_end_id = think_end_id
 
-    def get_cached_value(self, key: Tuple[str, str]) -> Optional[ReasonerGrammarObject]:
-        grammar = self.grammar_backend.get_cached_value(key)
-        return ReasonerGrammarObject(grammar, self.think_end_id) if grammar else None
-
-    def get_future_value(self, key: Tuple[str, str]) -> Future:
-        grammar = Future()
-
-        def callback(f: Future):
-            if result := f.result():
-                grammar.set_result(ReasonerGrammarObject(result, self.think_end_id))
-            else:
-                grammar.set_result(None)
-
-        self.grammar_backend.get_future_value(key).add_done_callback(callback)
-        return grammar
-
-    def reset(self):
-        self.grammar_backend.reset()
+    def _init_value_dispatch(
+        self, key: Tuple[str, str]
+    ) -> Optional[ReasonerGrammarObject]:
+        ret = self.grammar_backend._init_value_dispatch(key)
+        if ret is None:
+            return None
+        return ReasonerGrammarObject(ret, self.think_end_id)
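By overriding `_init_value_dispatch`, the reasoner backend now reuses the base class's executor and cache instead of duplicating the future/callback plumbing it deletes above; the inner backend compiles the grammar and the wrapper is applied once, at creation time. A small behavioral sketch of the wrapper's `accept_token` (the stand-in grammar class and token ids are illustrative, not part of sglang):

    from sglang.srt.constrained.reasoner_grammar_backend import ReasonerGrammarObject

    class RecordingGrammar:  # stand-in for a real backend grammar object
        def __init__(self):
            self.seen = []

        def accept_token(self, token):
            self.seen.append(token)

    THINK_END_ID = 42  # hypothetical </think> token id

    g = ReasonerGrammarObject(RecordingGrammar(), THINK_END_ID)
    for tok in [7, 8, THINK_END_ID, 9, 10]:
        g.accept_token(tok)

    # Tokens emitted inside the reasoning section bypass the wrapped grammar.
    assert g.grammar.seen == [9, 10]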
sglang/srt/constrained/xgrammar_backend.py
CHANGED
@@ -34,7 +34,6 @@ from sglang.srt.constrained.base_grammar_backend import (
 from sglang.srt.constrained.triton_ops.bitmask_ops import (
     apply_token_bitmask_inplace_triton,
 )
-from sglang.srt.utils import get_bool_env_var
 
 logger = logging.getLogger(__name__)
 
@@ -50,28 +49,69 @@ class XGrammarGrammar(BaseGrammarObject):
         vocab_size: int,
         ctx: CompiledGrammar,
         override_stop_tokens: Optional[Union[List[int], int]],
+        key_string: Optional[str] = None,  # TODO (sk): for debugging, remove later
     ) -> None:
-        super().__init__()
         self.matcher = matcher
         self.vocab_size = vocab_size
         self.ctx = ctx
         self.override_stop_tokens = override_stop_tokens
         self.finished = False
+        self.accepted_tokens = []
+        self.key_string = key_string
+
+    def accept_token(self, token: int):
+        if not self.is_terminated():
+            accepted = self.matcher.accept_token(token)
+            if not accepted:
+                # log for debugging
+                raise ValueError(
+                    f"Tokens not accepted: {token}\n"
+                    f"Accepted tokens: {self.accepted_tokens}\n"
+                    f"Key string: {self.key_string}"
+                )
+            else:
+                self.accepted_tokens.append(token)
+
+    def rollback(self, k: int):
+        self.matcher.rollback(k)
+        self.accepted_tokens = self.accepted_tokens[:-k]
+
+    def is_terminated(self):
+        return self.matcher.is_terminated()
+
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return allocate_token_bitmask(batch_size, vocab_size)
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        self.matcher.fill_next_token_bitmask(vocab_mask, idx)
 
-        # Postpone the import of apply_token_bitmask_inplace_kernels to the
-        # class init site to avoid re-initializing CUDA in forked subprocesses.
-        from xgrammar.kernels import apply_token_bitmask_inplace_kernels
+    @staticmethod
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask.to(device, non_blocking=True)
+
+    def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
+        if logits.device.type == "cuda":
+            apply_token_bitmask_inplace_triton(logits, vocab_mask)
+        elif logits.device.type == "cpu" and self.apply_vocab_mask_cpu:
+            self.apply_vocab_mask_cpu(logits, vocab_mask)
+        else:
+            raise RuntimeError(f"Unsupported device: {logits.device.type}")
 
-        self.use_token_bitmask_triton = get_bool_env_var(
-            "SGLANG_TOKEN_BITMASK_TRITON", "false"
+    def copy(self):
+        matcher = GrammarMatcher(
+            self.ctx,
+            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
+            override_stop_tokens=self.override_stop_tokens,
         )
-        self.apply_vocab_mask_cuda = apply_token_bitmask_inplace_kernels.get(
-            "cuda", None
+        return XGrammarGrammar(
+            matcher,
+            self.vocab_size,
+            self.ctx,
+            self.override_stop_tokens,
+            self.key_string,
        )
-        self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_kernels.get("cpu", None)
-
-    def accept_token(self, token: int):
-        assert self.matcher.accept_token(token)
 
     def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
         s = self.matcher.find_jump_forward_string()
@@ -100,38 +140,8 @@ class XGrammarGrammar(BaseGrammarObject):
         for i in range(k, len(new_output_ids)):
             assert self.matcher.accept_token(new_output_ids[i])
 
-    def allocate_vocab_mask(
-        self, vocab_size: int, batch_size: int, device
-    ) -> torch.Tensor:
-        return allocate_token_bitmask(batch_size, vocab_size)
-
-    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
-        self.matcher.fill_next_token_bitmask(vocab_mask, idx)
-
-    @staticmethod
-    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
-        return vocab_mask.to(device, non_blocking=True)
-
-    def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        if (
-            not self.use_token_bitmask_triton
-            and logits.device.type == "cuda"
-            and self.apply_vocab_mask_cuda
-        ):
-            return self.apply_vocab_mask_cuda(logits, vocab_mask)
-        if logits.device.type == "cpu" and self.apply_vocab_mask_cpu:
-            return self.apply_vocab_mask_cpu(logits, vocab_mask)
-        apply_token_bitmask_inplace_triton(logits, vocab_mask)
-
-    def copy(self):
-        matcher = GrammarMatcher(
-            self.ctx,
-            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-            override_stop_tokens=self.override_stop_tokens,
-        )
-        return XGrammarGrammar(
-            matcher, self.vocab_size, self.ctx, self.override_stop_tokens
-        )
+    def __repr__(self):
+        return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=})"
 
 
 class XGrammarGrammarBackend(BaseGrammarBackend):
@@ -151,9 +161,15 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         self.vocab_size = vocab_size
         self.override_stop_tokens = override_stop_tokens
 
-    def _from_context(self, ctx: CompiledGrammar) -> XGrammarGrammar:
-        matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
-        return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens)
+    def _from_context(self, ctx: CompiledGrammar, key_string: str) -> XGrammarGrammar:
+        matcher = GrammarMatcher(
+            ctx,
+            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
+            override_stop_tokens=self.override_stop_tokens,
+        )
+        return XGrammarGrammar(
+            matcher, self.vocab_size, ctx, self.override_stop_tokens, key_string
+        )
 
     def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:
@@ -165,7 +181,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         except RuntimeError as e:
             logging.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
             return None
-        return self._from_context(ctx)
+        return self._from_context(ctx, key_string)
 
     def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:
@@ -173,7 +189,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         except RuntimeError as e:
             logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
             return None
-        return self._from_context(ctx)
+        return self._from_context(ctx, key_string)
 
     def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:
@@ -181,7 +197,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         except RuntimeError as e:
             logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
             return None
-        return self._from_context(ctx)
+        return self._from_context(ctx, key_string)
 
     def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:
@@ -198,9 +214,11 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
                 tags, structural_tag["triggers"]
             )
         except RuntimeError as e:
-            logging.warning(f"Skip invalid structural_tag: structural_tag={key_string}, {e=}")
+            logging.warning(
+                f"Skip invalid structural_tag: structural_tag={key_string}, {e=}"
+            )
             return None
-        return self._from_context(ctx)
+        return self._from_context(ctx, key_string)
 
     def reset(self):
         if self.grammar_compiler:
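The `accept_token`/`rollback`/`is_terminated` trio added to `XGrammarGrammar` (and declared on `BaseGrammarObject` earlier in this diff) is the hook that lets grammar state track speculative decoding: draft tokens advance the matcher tentatively, and the unverified suffix is rolled back. A minimal sketch of that pattern against this interface; the draft/verify loop itself is illustrative, not sglang's scheduler code:

    def spec_step(grammar, draft_tokens, num_verified):
        """Advance the matcher over draft tokens, then undo the rejected suffix."""
        applied = 0
        for tok in draft_tokens:
            if grammar.is_terminated():
                break
            grammar.accept_token(tok)  # raises ValueError on an illegal token
            applied += 1
        overshoot = applied - num_verified
        if overshoot > 0:  # guard: with this implementation, rollback(0) would clear accepted_tokens
            grammar.rollback(overshoot)
        return min(applied, num_verified)  # tokens that remain accepted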