sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +3 -13
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +158 -8
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +119 -75
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +5 -2
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/internvl.py +696 -0
  13. sglang/srt/configs/janus_pro.py +3 -0
  14. sglang/srt/configs/model_config.py +18 -0
  15. sglang/srt/constrained/base_grammar_backend.py +55 -72
  16. sglang/srt/constrained/llguidance_backend.py +25 -21
  17. sglang/srt/constrained/outlines_backend.py +27 -26
  18. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  19. sglang/srt/constrained/xgrammar_backend.py +71 -53
  20. sglang/srt/conversation.py +78 -46
  21. sglang/srt/disaggregation/base/conn.py +1 -0
  22. sglang/srt/disaggregation/decode.py +11 -3
  23. sglang/srt/disaggregation/fake/conn.py +1 -1
  24. sglang/srt/disaggregation/mini_lb.py +74 -23
  25. sglang/srt/disaggregation/mooncake/conn.py +236 -138
  26. sglang/srt/disaggregation/nixl/conn.py +242 -71
  27. sglang/srt/disaggregation/prefill.py +7 -4
  28. sglang/srt/disaggregation/utils.py +51 -2
  29. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  30. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  31. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  32. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  33. sglang/srt/distributed/parallel_state.py +22 -1
  34. sglang/srt/entrypoints/engine.py +31 -4
  35. sglang/srt/entrypoints/http_server.py +45 -3
  36. sglang/srt/entrypoints/verl_engine.py +3 -2
  37. sglang/srt/function_call_parser.py +2 -2
  38. sglang/srt/hf_transformers_utils.py +20 -1
  39. sglang/srt/layers/attention/flashattention_backend.py +147 -51
  40. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  41. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  42. sglang/srt/layers/attention/merge_state.py +46 -0
  43. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  44. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  45. sglang/srt/layers/attention/utils.py +4 -2
  46. sglang/srt/layers/attention/vision.py +290 -163
  47. sglang/srt/layers/dp_attention.py +71 -21
  48. sglang/srt/layers/layernorm.py +1 -1
  49. sglang/srt/layers/logits_processor.py +46 -11
  50. sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
  51. sglang/srt/layers/moe/ep_moe/layer.py +121 -2
  52. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  56. sglang/srt/layers/moe/topk.py +1 -1
  57. sglang/srt/layers/quantization/__init__.py +1 -1
  58. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  59. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  60. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  61. sglang/srt/layers/quantization/deep_gemm.py +77 -71
  62. sglang/srt/layers/quantization/fp8.py +110 -97
  63. sglang/srt/layers/quantization/fp8_kernel.py +81 -62
  64. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  65. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  66. sglang/srt/layers/quantization/kv_cache.py +3 -10
  67. sglang/srt/layers/quantization/utils.py +0 -5
  68. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  69. sglang/srt/layers/sampler.py +0 -4
  70. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  71. sglang/srt/lora/lora_manager.py +11 -14
  72. sglang/srt/lora/mem_pool.py +4 -4
  73. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  74. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  75. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  76. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  77. sglang/srt/lora/utils.py +1 -1
  78. sglang/srt/managers/cache_controller.py +115 -119
  79. sglang/srt/managers/data_parallel_controller.py +3 -3
  80. sglang/srt/managers/detokenizer_manager.py +21 -8
  81. sglang/srt/managers/io_struct.py +13 -1
  82. sglang/srt/managers/mm_utils.py +1 -1
  83. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  84. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  85. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  86. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  87. sglang/srt/managers/schedule_batch.py +93 -23
  88. sglang/srt/managers/schedule_policy.py +11 -8
  89. sglang/srt/managers/scheduler.py +140 -100
  90. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  91. sglang/srt/managers/tokenizer_manager.py +157 -47
  92. sglang/srt/managers/tp_worker.py +21 -21
  93. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  94. sglang/srt/mem_cache/chunk_cache.py +2 -0
  95. sglang/srt/mem_cache/memory_pool.py +4 -2
  96. sglang/srt/metrics/collector.py +312 -37
  97. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  98. sglang/srt/model_executor/forward_batch_info.py +1 -1
  99. sglang/srt/model_executor/model_runner.py +57 -41
  100. sglang/srt/model_loader/loader.py +18 -11
  101. sglang/srt/models/clip.py +4 -4
  102. sglang/srt/models/deepseek_janus_pro.py +3 -3
  103. sglang/srt/models/deepseek_nextn.py +1 -20
  104. sglang/srt/models/deepseek_v2.py +77 -39
  105. sglang/srt/models/gemma3_mm.py +1 -1
  106. sglang/srt/models/internlm2.py +3 -0
  107. sglang/srt/models/internvl.py +670 -0
  108. sglang/srt/models/llama.py +3 -1
  109. sglang/srt/models/llama4.py +58 -13
  110. sglang/srt/models/llava.py +248 -5
  111. sglang/srt/models/minicpmv.py +1 -1
  112. sglang/srt/models/mixtral.py +98 -34
  113. sglang/srt/models/mllama.py +1 -1
  114. sglang/srt/models/phi3_small.py +16 -2
  115. sglang/srt/models/pixtral.py +467 -0
  116. sglang/srt/models/qwen2_5_vl.py +8 -4
  117. sglang/srt/models/qwen2_vl.py +4 -4
  118. sglang/srt/models/roberta.py +1 -1
  119. sglang/srt/models/torch_native_llama.py +1 -1
  120. sglang/srt/models/xiaomi_mimo.py +171 -0
  121. sglang/srt/openai_api/adapter.py +52 -42
  122. sglang/srt/openai_api/protocol.py +20 -16
  123. sglang/srt/reasoning_parser.py +1 -1
  124. sglang/srt/sampling/custom_logit_processor.py +18 -3
  125. sglang/srt/sampling/sampling_batch_info.py +2 -2
  126. sglang/srt/sampling/sampling_params.py +2 -0
  127. sglang/srt/server_args.py +64 -10
  128. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  129. sglang/srt/speculative/eagle_utils.py +7 -7
  130. sglang/srt/speculative/eagle_worker.py +22 -19
  131. sglang/srt/utils.py +41 -6
  132. sglang/test/few_shot_gsm8k.py +2 -2
  133. sglang/test/few_shot_gsm8k_engine.py +2 -2
  134. sglang/test/run_eval.py +2 -2
  135. sglang/test/runners.py +8 -1
  136. sglang/test/send_one.py +13 -3
  137. sglang/test/simple_eval_common.py +1 -1
  138. sglang/test/simple_eval_humaneval.py +1 -1
  139. sglang/test/test_block_fp8.py +2 -2
  140. sglang/test/test_deepep_utils.py +219 -0
  141. sglang/test/test_programs.py +5 -5
  142. sglang/test/test_utils.py +92 -15
  143. sglang/utils.py +1 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
  146. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
  147. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
  148. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  149. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
@@ -290,6 +290,9 @@ class DictOutput(object):
     def __getitem__(self, item):
         return self.__dict__[item]
 
+    def __contains__(self, key):
+        return key in self.__dict__
+
     def __setitem__(self, key, value):
         self.__dict__[key] = value
 
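Note: the new `__contains__` makes membership tests on DictOutput work; previously `key in obj` fell back to integer-indexed iteration via `__getitem__` and raised KeyError. A minimal standalone sketch, re-declaring the class as it reads after this hunk:

class DictOutput(object):
    def __getitem__(self, item):
        return self.__dict__[item]

    def __contains__(self, key):
        return key in self.__dict__

    def __setitem__(self, key, value):
        self.__dict__[key] = value

out = DictOutput()
out["input_ids"] = [1, 2, 3]
assert "input_ids" in out         # resolved by __contains__
assert "pixel_values" not in out  # no KeyError from the iteration fallback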
@@ -24,6 +24,7 @@ from transformers import PretrainedConfig
 
 from sglang.srt.hf_transformers_utils import get_config, get_context_length
 from sglang.srt.layers.quantization import QUANTIZATION_METHODS
+from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import get_bool_env_var, is_hip
 
 logger = logging.getLogger(__name__)
@@ -210,6 +211,21 @@ class ModelConfig:
         self.hf_eos_token_id = self.get_hf_eos_token_id()
         self.image_token_id = getattr(self.hf_config, "image_token_id", None)
 
+    @staticmethod
+    def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
+        return ModelConfig(
+            model_path=model_path or server_args.model_path,
+            trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
+            context_length=server_args.context_length,
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
+            enable_multimodal=server_args.enable_multimodal,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
+            **kwargs,
+        )
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
@@ -529,6 +545,7 @@ multimodal_model_archs = [
     "Llama4ForConditionalGeneration",
     "LlavaMistralForCausalLM",
     "LlavaQwenForCausalLM",
+    "LlavaForConditionalGeneration",
     "LlavaVidForCausalLM",
     "MiniCPMO",
     "MiniCPMV",
@@ -538,6 +555,7 @@ multimodal_model_archs = [
     "Qwen2_5_VLForConditionalGeneration",
     "CLIPModel",
     "KimiVLForConditionalGeneration",
+    "InternVLChatModel",
 ]
 
 
@@ -14,10 +14,9 @@
 """The baseclass of a backend for grammar-guided constrained decoding."""
 
 import logging
-from abc import ABC, abstractmethod
-from concurrent.futures import Future, ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
-from threading import Event, Lock
+from threading import Event
 from typing import Dict, List, Optional, Tuple
 
 import torch
@@ -27,11 +26,42 @@ from sglang.srt.server_args import ServerArgs
 logger = logging.getLogger(__name__)
 
 
-class BaseGrammarObject(ABC):
+class BaseGrammarObject:
 
     def __init__(self):
         self._finished = False
 
+    def accept_token(self, token: int) -> None:
+        """
+        Accept a token in the grammar.
+        """
+        raise NotImplementedError()
+
+    def rollback(self, k: int):
+        raise NotImplementedError()
+
+    def is_terminated(self):
+        return False
+
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        raise NotImplementedError()
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        raise NotImplementedError()
+
+    @staticmethod
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
+        raise NotImplementedError()
+
+    def copy(self) -> "BaseGrammarObject":
+        raise NotImplementedError()
+
     @property
     def finished(self):
         return self._finished
@@ -40,7 +70,6 @@ class BaseGrammarObject(ABC):
     def finished(self, finished):
         self._finished = finished
 
-    @abstractmethod
     def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
         """
         Try to jump forward in the grammar.
@@ -49,9 +78,8 @@ class BaseGrammarObject(ABC):
         A jump forward helper which may be used in `jump_forward_str_state`.
         None if the jump forward is not possible.
         """
-        raise NotImplementedError
+        raise NotImplementedError()
 
-    @abstractmethod
     def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
         """
         Jump forward for the grammar.
@@ -60,47 +88,15 @@ class BaseGrammarObject(ABC):
         A tuple of the jump forward string and the next state of the grammar
         (which can be used in `jump_and_retokenize` if needed).
         """
-        raise NotImplementedError
+        raise NotImplementedError()
 
-    @abstractmethod
     def jump_and_retokenize(
         self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
     ) -> None:
         """
         Jump forward occurs, and update the grammar state if needed.
         """
-        raise NotImplementedError
-
-    @abstractmethod
-    def accept_token(self, token: int) -> None:
-        """
-        Accept a token in the grammar.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def allocate_vocab_mask(
-        self, vocab_size: int, batch_size: int, device
-    ) -> torch.Tensor:
-        raise NotImplementedError
-
-    @abstractmethod
-    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
-        raise NotImplementedError
-
-    @staticmethod
-    @abstractmethod
-    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
-        raise NotImplementedError
-
-    @staticmethod
-    @abstractmethod
-    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def copy(self) -> "BaseGrammarObject":
-        raise NotImplementedError
+        raise NotImplementedError()
 
 
 @dataclass
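Note: the practical effect of dropping ABC/abstractmethod is that unimplemented hooks now fail only when called, so partial implementations (like GuidanceGrammar below, which never needed the jump-forward trio) stay instantiable. A minimal illustration against the base class above:

import torch  # BaseGrammarObject's signatures reference torch.Tensor

class MyGrammar(BaseGrammarObject):
    def accept_token(self, token: int) -> None:
        pass  # a real implementation would advance matcher state here

g = MyGrammar()     # ok now; with ABC this raised TypeError until every
                    # abstract method was overridden
g.accept_token(42)  # overridden hook works
try:
    g.rollback(1)   # non-overridden hooks fail only if actually called
except NotImplementedError:
    pass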
@@ -113,10 +109,9 @@ class BaseGrammarBackend:
     def __init__(self):
         self.executor = ThreadPoolExecutor()
         self.cache: Dict[Tuple[str, str], CacheEntry] = {}
-        self.cache_lock = Lock()
 
     def _not_supported(self, key_type: str, key_string: str) -> None:
-        logger.warning(f"Skip unsupported {key_type}: {key_type}={key_string}")
+        logger.warning(f"Skip unsupported {key_type=}, {key_string=}")
 
     def dispatch_fallback(
         self, key_type: str, key_string: str
@@ -148,40 +143,25 @@ class BaseGrammarBackend:
             return self.dispatch_ebnf(key_string)
         elif key_type == "structural_tag":
             return self.dispatch_structural_tag(key_string)
+        elif key_type == "structural_pattern":
+            return self.dispatch_structural_pattern(key_string)
         else:
             return self.dispatch_fallback(key_type, key_string)
 
-    def _init_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
-        with self.cache_lock:
-            if key in self.cache:
-                cache_hit = True
-                entry = self.cache[key]
-            else:
-                cache_hit = False
-                entry = CacheEntry(None, Event())
-                self.cache[key] = entry
-
-        if cache_hit:
-            entry.event.wait()
-        else:
-            entry.value = self._init_value_dispatch(key)
-            entry.event.set()
-        return entry.value.copy() if entry.value else None
-
-    def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
-        with self.cache_lock:
-            entry = self.cache.get(key)
-            if not entry or not entry.event.is_set():
-                return None
-            val = self.cache[key].value
-            return val.copy() if val else None
+    def get_cached_or_future_value(
+        self, key: Tuple[str, str]
+    ) -> Optional[BaseGrammarObject]:
+        value = self.cache.get(key)
+        if value:
+            return value.copy(), True
+        value = self.executor.submit(self._init_value_dispatch, key)
+        return value, False
 
-    def get_future_value(self, key: Tuple[str, str]) -> Future:
-        return self.executor.submit(self._init_value, key)
+    def set_cache(self, key: Tuple[str, str], value: BaseGrammarObject):
+        self.cache[key] = value
 
     def reset(self):
-        with self.cache_lock:
-            self.cache.clear()
+        self.cache.clear()
 
 
 def create_grammar_backend(
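Note on the hunk above: the per-entry Lock/Event scheme is replaced by a two-step protocol. `get_cached_or_future_value` returns either a ready copy plus a cache-hit flag, or a Future from the thread pool (despite the annotation, it returns a tuple); the caller resolves the Future and writes the result back with `set_cache`. A caller-side sketch with illustrative names, not code from the diff:

key = ("json", '{"type": "object"}')
value, cache_hit = backend.get_cached_or_future_value(key)
if cache_hit:
    grammar = value               # already a private copy
else:
    grammar = value.result()      # block on the compile Future
    if grammar is not None:
        backend.set_cache(key, grammar)
        grammar = grammar.copy()  # hand out a copy, keep the cached one pristine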
@@ -211,9 +191,12 @@ def create_grammar_backend(
         raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
 
     if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
-        from .reasoner_grammar_backend import ReasonerGrammarBackend
+        from sglang.srt.constrained.reasoner_grammar_backend import (
+            ReasonerGrammarBackend,
+        )
 
         grammar_backend = ReasonerGrammarBackend(
             grammar_backend, tokenizer.think_end_id
         )
+
     return grammar_backend
@@ -50,21 +50,6 @@ class GuidanceGrammar(BaseGrammarObject):
         self.finished = False
         self.bitmask = None
 
-    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
-        ff_tokens = self.ll_matcher.compute_ff_tokens()
-        if ff_tokens:
-            return ff_tokens, ""
-        else:
-            return None
-
-    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
-        return "", -1
-
-    def jump_and_retokenize(
-        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
-    ):
-        pass
-
     def accept_token(self, token: int):
         if not self.ll_matcher.consume_token(token):
             logger.warning(f"matcher error: {self.ll_matcher.get_error()}")
@@ -104,6 +89,21 @@ class GuidanceGrammar(BaseGrammarObject):
             serialized_grammar=self.serialized_grammar,
         )
 
+    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
+        ff_tokens = self.ll_matcher.compute_ff_tokens()
+        if ff_tokens:
+            return ff_tokens, ""
+        else:
+            return None
+
+    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
+        return "", -1
+
+    def jump_and_retokenize(
+        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
+    ):
+        pass
+
 
 class GuidanceBackend(BaseGrammarBackend):
 
@@ -130,12 +130,16 @@ class GuidanceBackend(BaseGrammarBackend):
         return None
 
     def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]:
-        serialized_grammar = LLMatcher.grammar_from_json_schema(
-            key_string,
-            defaults={
-                "whitespace_pattern": self.whitespace_pattern,
-            },
-        )
+        try:
+            serialized_grammar = LLMatcher.grammar_from_json_schema(
+                key_string,
+                defaults={
+                    "whitespace_pattern": self.whitespace_pattern,
+                },
+            )
+        except Exception as e:
+            logger.warning(f"Skip invalid grammar: {key_string=}, {e=}")
+            return None
         return self._from_serialized(serialized_grammar)
 
     def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]:
@@ -53,6 +53,30 @@ class OutlinesGrammar(BaseGrammarObject):
     def accept_token(self, token: int):
         self.state = self.guide.get_next_state(self.state, token)
 
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
+
+    @staticmethod
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        tokens = torch.tensor(
+            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
+        ).to(vocab_mask.device, non_blocking=True)
+        vocab_mask = vocab_mask[idx]
+        vocab_mask.fill_(1)
+        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
+        logits.masked_fill_(vocab_mask, float("-inf"))
+
+    def copy(self):
+        return OutlinesGrammar(self.guide, self.jump_forward_map)
+
     def try_jump_forward(self, tokenizer) -> Optional[Tuple]:
         if not self.jump_forward_map:
             return None
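Note: the mask convention here is "True = banned". `fill_vocab_mask` sets the whole row, then clears the positions the guide allows, and `apply_vocab_mask` pushes banned logits to -inf. A self-contained demonstration of that convention with toy sizes:

import torch

batch_size, vocab_size = 1, 8
mask = torch.zeros(batch_size, vocab_size, dtype=torch.bool)  # allocate_vocab_mask
allowed = torch.tensor([2, 5], dtype=torch.int64)             # guide-approved token ids

row = mask[0]
row.fill_(1)                                                  # ban everything...
row.scatter_(0, allowed, torch.zeros_like(allowed, dtype=torch.bool))  # ...then re-allow

logits = torch.randn(batch_size, vocab_size)
logits.masked_fill_(mask, float("-inf"))                      # apply_vocab_mask
assert torch.isfinite(logits[0, allowed]).all()
assert torch.isinf(logits[0, [0, 1, 3, 4, 6, 7]]).all()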
@@ -86,30 +110,6 @@ class OutlinesGrammar(BaseGrammarObject):
     ):
         self.state = next_state
 
-    def allocate_vocab_mask(
-        self, vocab_size: int, batch_size: int, device
-    ) -> torch.Tensor:
-        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
-
-    @staticmethod
-    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
-        return vocab_mask
-
-    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
-        tokens = torch.tensor(
-            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
-        ).to(vocab_mask.device, non_blocking=True)
-        vocab_mask = vocab_mask[idx]
-        vocab_mask.fill_(1)
-        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
-
-    @staticmethod
-    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
-        logits.masked_fill_(vocab_mask, float("-inf"))
-
-    def copy(self):
-        return OutlinesGrammar(self.guide, self.jump_forward_map)
-
 
 class OutlinesGrammarBackend(BaseGrammarBackend):
     def __init__(
@@ -169,8 +169,9 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
                 key_string,
                 whitespace_pattern=self.whitespace_pattern,
             )
-        except (NotImplementedError, json.decoder.JSONDecodeError) as e:
-            logger.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
+        except (NotImplementedError, json.decoder.JSONDecodeError, ValueError) as e:
+            logger.warning(f"Skip invalid json_schema: {key_string=}, {e=}")
+            return None
         return self._compile_regex(regex)
 
     def dispatch_regex(self, key_string: str):
@@ -13,7 +13,6 @@
 # ==============================================================================
 """The baseclass of a backend for reasoner grammar-guided constrained decoding."""
 
-from concurrent.futures import Future
 from typing import List, Optional, Tuple
 
 import torch
@@ -28,13 +27,12 @@ class ReasonerGrammarObject(BaseGrammarObject):
         self.think_end_id = think_end_id
         self.is_in_reasoning = True
 
-    @property
-    def finished(self):
-        return self.grammar.finished
+    def accept_token(self, token: int):
+        if token == self.think_end_id:
+            self.is_in_reasoning = False
 
-    @finished.setter
-    def finished(self, finished):
-        self.grammar.finished = finished
+        if not self.is_in_reasoning and token != self.think_end_id:
+            self.grammar.accept_token(token)
 
     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device
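Note: with accept_token moved up front, the gating logic is easier to see: tokens before think_end_id never reach the wrapped grammar, the think_end_id token itself is swallowed, and everything after is forwarded. A runnable trace with a stub grammar; token id 7 stands in for think_end_id:

class _Recorder(BaseGrammarObject):
    def __init__(self):
        super().__init__()
        self.seen = []

    def accept_token(self, token: int) -> None:
        self.seen.append(token)

r = ReasonerGrammarObject(_Recorder(), think_end_id=7)
for t in (3, 5, 7, 9, 11):        # 3 and 5 fall inside the reasoning section
    r.accept_token(t)
assert r.grammar.seen == [9, 11]  # only post-think tokens reach the grammar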
@@ -52,12 +50,16 @@ class ReasonerGrammarObject(BaseGrammarObject):
     def apply_vocab_mask(self):
         return self.grammar.apply_vocab_mask
 
-    def accept_token(self, token: int):
-        if token == self.think_end_id:
-            self.is_in_reasoning = False
+    def copy(self) -> BaseGrammarObject:
+        return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)
 
-        if not self.is_in_reasoning and token != self.think_end_id:
-            self.grammar.accept_token(token)
+    @property
+    def finished(self):
+        return self.grammar.finished
+
+    @finished.setter
+    def finished(self, finished):
+        self.grammar.finished = finished
 
     def try_jump_forward(self, tokenizer):
         return self.grammar.try_jump_forward(tokenizer)
@@ -72,30 +74,17 @@ class ReasonerGrammarObject(BaseGrammarObject):
             old_output_ids, new_output_ids, next_state
         )
 
-    def copy(self) -> BaseGrammarObject:
-        return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)
-
 
 class ReasonerGrammarBackend(BaseGrammarBackend):
     def __init__(self, grammar_backend: BaseGrammarBackend, think_end_id):
+        super().__init__()
         self.grammar_backend = grammar_backend
        self.think_end_id = think_end_id
 
-    def get_cached_value(self, key: Tuple[str, str]) -> Optional[ReasonerGrammarObject]:
-        grammar = self.grammar_backend.get_cached_value(key)
-        return ReasonerGrammarObject(grammar, self.think_end_id) if grammar else None
-
-    def get_future_value(self, key: Tuple[str, str]) -> Future:
-        grammar = Future()
-
-        def callback(f: Future):
-            if result := f.result():
-                grammar.set_result(ReasonerGrammarObject(result, self.think_end_id))
-            else:
-                grammar.set_result(None)
-
-        self.grammar_backend.get_future_value(key).add_done_callback(callback)
-        return grammar
-
-    def reset(self):
-        self.grammar_backend.reset()
+    def _init_value_dispatch(
+        self, key: Tuple[str, str]
+    ) -> Optional[ReasonerGrammarObject]:
+        ret = self.grammar_backend._init_value_dispatch(key)
+        if ret is None:
+            return None
+        return ReasonerGrammarObject(ret, self.think_end_id)
@@ -34,7 +34,6 @@ from sglang.srt.constrained.base_grammar_backend import (
 from sglang.srt.constrained.triton_ops.bitmask_ops import (
     apply_token_bitmask_inplace_triton,
 )
-from sglang.srt.utils import get_bool_env_var
 
 logger = logging.getLogger(__name__)
 
@@ -50,28 +49,69 @@ class XGrammarGrammar(BaseGrammarObject):
         vocab_size: int,
         ctx: CompiledGrammar,
         override_stop_tokens: Optional[Union[List[int], int]],
+        key_string: Optional[str] = None,  # TODO (sk): for debugging, remove later
     ) -> None:
-        super().__init__()
         self.matcher = matcher
         self.vocab_size = vocab_size
         self.ctx = ctx
         self.override_stop_tokens = override_stop_tokens
         self.finished = False
+        self.accepted_tokens = []
+        self.key_string = key_string
+
+    def accept_token(self, token: int):
+        if not self.is_terminated():
+            accepted = self.matcher.accept_token(token)
+            if not accepted:
+                # log for debugging
+                raise ValueError(
+                    f"Tokens not accepted: {token}\n"
+                    f"Accepted tokens: {self.accepted_tokens}\n"
+                    f"Key string: {self.key_string}"
+                )
+            else:
+                self.accepted_tokens.append(token)
+
+    def rollback(self, k: int):
+        self.matcher.rollback(k)
+        self.accepted_tokens = self.accepted_tokens[:-k]
+
+    def is_terminated(self):
+        return self.matcher.is_terminated()
+
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return allocate_token_bitmask(batch_size, vocab_size)
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        self.matcher.fill_next_token_bitmask(vocab_mask, idx)
 
-        # Fix (from vLLM team): postpone the import of apply_token_bitmask_inplace_kernels to the
-        # class init site to avoid re-initializing CUDA in forked subprocess.
-        from xgrammar.kernels import apply_token_bitmask_inplace_kernels
+    @staticmethod
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask.to(device, non_blocking=True)
+
+    def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
+        if logits.device.type == "cuda":
+            apply_token_bitmask_inplace_triton(logits, vocab_mask)
+        elif logits.device.type == "cpu" and self.apply_vocab_mask_cpu:
+            self.apply_vocab_mask_cpu(logits, vocab_mask)
+        else:
+            raise RuntimeError(f"Unsupported device: {logits.device.type}")
 
-        self.use_token_bitmask_triton = get_bool_env_var(
-            "SGLANG_TOKEN_BITMASK_TRITON", "false"
+    def copy(self):
+        matcher = GrammarMatcher(
+            self.ctx,
+            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
+            override_stop_tokens=self.override_stop_tokens,
         )
-        self.apply_vocab_mask_cuda = apply_token_bitmask_inplace_kernels.get(
-            "cuda", None
+        return XGrammarGrammar(
+            matcher,
+            self.vocab_size,
+            self.ctx,
+            self.override_stop_tokens,
+            self.key_string,
         )
-        self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_kernels.get("cpu", None)
-
-    def accept_token(self, token: int):
-        assert self.matcher.accept_token(token)
 
     def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
         s = self.matcher.find_jump_forward_string()
@@ -100,38 +140,8 @@ class XGrammarGrammar(BaseGrammarObject):
         for i in range(k, len(new_output_ids)):
             assert self.matcher.accept_token(new_output_ids[i])
 
-    def allocate_vocab_mask(
-        self, vocab_size: int, batch_size: int, device
-    ) -> torch.Tensor:
-        return allocate_token_bitmask(batch_size, vocab_size)
-
-    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
-        self.matcher.fill_next_token_bitmask(vocab_mask, idx)
-
-    @staticmethod
-    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
-        return vocab_mask.to(device, non_blocking=True)
-
-    def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        if (
-            not self.use_token_bitmask_triton
-            and logits.device.type == "cuda"
-            and self.apply_vocab_mask_cuda
-        ):
-            return self.apply_vocab_mask_cuda(logits, vocab_mask)
-        if logits.device.type == "cpu" and self.apply_vocab_mask_cpu:
-            return self.apply_vocab_mask_cpu(logits, vocab_mask)
-        apply_token_bitmask_inplace_triton(logits, vocab_mask)
-
-    def copy(self):
-        matcher = GrammarMatcher(
-            self.ctx,
-            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-            override_stop_tokens=self.override_stop_tokens,
-        )
-        return XGrammarGrammar(
-            matcher, self.vocab_size, self.ctx, self.override_stop_tokens
-        )
+    def __repr__(self):
+        return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=})"
 
 
 class XGrammarGrammarBackend(BaseGrammarBackend):
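Note: the new accept_token/rollback pair lines up with speculative decoding: draft tokens are accepted optimistically and rewound when the verifier rejects a suffix, with accepted_tokens tracked for the ValueError message and __repr__. A usage sketch; `g` stands for any constructed XGrammarGrammar and the token ids are illustrative:

draft_tokens = [11, 23, 42]
for t in draft_tokens:
    g.accept_token(t)     # raises ValueError with key_string context on refusal

num_rejected = 2          # verifier kept only draft_tokens[0]
g.rollback(num_rejected)  # rewinds both the matcher and accepted_tokens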
@@ -151,9 +161,15 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         self.vocab_size = vocab_size
         self.override_stop_tokens = override_stop_tokens
 
-    def _from_context(self, ctx: CompiledGrammar) -> XGrammarGrammar:
-        matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
-        return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens)
+    def _from_context(self, ctx: CompiledGrammar, key_string: str) -> XGrammarGrammar:
+        matcher = GrammarMatcher(
+            ctx,
+            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
+            override_stop_tokens=self.override_stop_tokens,
+        )
+        return XGrammarGrammar(
+            matcher, self.vocab_size, ctx, self.override_stop_tokens, key_string
+        )
 
     def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:
@@ -165,7 +181,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         except RuntimeError as e:
             logging.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
             return None
-        return self._from_context(ctx)
+        return self._from_context(ctx, key_string)
 
     def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:
@@ -173,7 +189,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         except RuntimeError as e:
             logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
             return None
-        return self._from_context(ctx)
+        return self._from_context(ctx, key_string)
 
     def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:
@@ -181,7 +197,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         except RuntimeError as e:
             logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
             return None
-        return self._from_context(ctx)
+        return self._from_context(ctx, key_string)
 
     def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:
@@ -198,9 +214,11 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
                 tags, structural_tag["triggers"]
             )
         except RuntimeError as e:
-            logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
+            logging.warning(
+                f"Skip invalid structural_tag: structural_tag={key_string}, {e=}"
+            )
             return None
-        return self._from_context(ctx)
+        return self._from_context(ctx, key_string)
 
     def reset(self):
         if self.grammar_compiler: