sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
@@ -50,21 +50,6 @@ class GuidanceGrammar(BaseGrammarObject):
50
50
  self.finished = False
51
51
  self.bitmask = None
52
52
 
53
- def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
54
- ff_tokens = self.ll_matcher.compute_ff_tokens()
55
- if ff_tokens:
56
- return ff_tokens, ""
57
- else:
58
- return None
59
-
60
- def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
61
- return "", -1
62
-
63
- def jump_and_retokenize(
64
- self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
65
- ):
66
- pass
67
-
68
53
  def accept_token(self, token: int):
69
54
  if not self.ll_matcher.consume_token(token):
70
55
  logger.warning(f"matcher error: {self.ll_matcher.get_error()}")
@@ -104,6 +89,21 @@ class GuidanceGrammar(BaseGrammarObject):
104
89
  serialized_grammar=self.serialized_grammar,
105
90
  )
106
91
 
92
+ def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
93
+ ff_tokens = self.ll_matcher.compute_ff_tokens()
94
+ if ff_tokens:
95
+ return ff_tokens, ""
96
+ else:
97
+ return None
98
+
99
+ def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
100
+ return "", -1
101
+
102
+ def jump_and_retokenize(
103
+ self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
104
+ ):
105
+ pass
106
+
107
107
 
108
108
  class GuidanceBackend(BaseGrammarBackend):
109
109
 
@@ -130,12 +130,16 @@ class GuidanceBackend(BaseGrammarBackend):
130
130
  return None
131
131
 
132
132
  def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]:
133
- serialized_grammar = LLMatcher.grammar_from_json_schema(
134
- key_string,
135
- defaults={
136
- "whitespace_pattern": self.whitespace_pattern,
137
- },
138
- )
133
+ try:
134
+ serialized_grammar = LLMatcher.grammar_from_json_schema(
135
+ key_string,
136
+ defaults={
137
+ "whitespace_pattern": self.whitespace_pattern,
138
+ },
139
+ )
140
+ except Exception as e:
141
+ logger.warning(f"Skip invalid grammar: {key_string=}, {e=}")
142
+ return None
139
143
  return self._from_serialized(serialized_grammar)
140
144
 
141
145
  def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]:
@@ -53,6 +53,30 @@ class OutlinesGrammar(BaseGrammarObject):
53
53
  def accept_token(self, token: int):
54
54
  self.state = self.guide.get_next_state(self.state, token)
55
55
 
56
+ def allocate_vocab_mask(
57
+ self, vocab_size: int, batch_size: int, device
58
+ ) -> torch.Tensor:
59
+ return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
60
+
61
+ @staticmethod
62
+ def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
63
+ return vocab_mask
64
+
65
+ def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
66
+ tokens = torch.tensor(
67
+ self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
68
+ ).to(vocab_mask.device, non_blocking=True)
69
+ vocab_mask = vocab_mask[idx]
70
+ vocab_mask.fill_(1)
71
+ vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
72
+
73
+ @staticmethod
74
+ def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
75
+ logits.masked_fill_(vocab_mask, float("-inf"))
76
+
77
+ def copy(self):
78
+ return OutlinesGrammar(self.guide, self.jump_forward_map)
79
+
56
80
  def try_jump_forward(self, tokenizer) -> Optional[Tuple]:
57
81
  if not self.jump_forward_map:
58
82
  return None
@@ -86,30 +110,6 @@ class OutlinesGrammar(BaseGrammarObject):
86
110
  ):
87
111
  self.state = next_state
88
112
 
89
- def allocate_vocab_mask(
90
- self, vocab_size: int, batch_size: int, device
91
- ) -> torch.Tensor:
92
- return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
93
-
94
- @staticmethod
95
- def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
96
- return vocab_mask
97
-
98
- def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
99
- tokens = torch.tensor(
100
- self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
101
- ).to(vocab_mask.device, non_blocking=True)
102
- vocab_mask = vocab_mask[idx]
103
- vocab_mask.fill_(1)
104
- vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
105
-
106
- @staticmethod
107
- def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
108
- logits.masked_fill_(vocab_mask, float("-inf"))
109
-
110
- def copy(self):
111
- return OutlinesGrammar(self.guide, self.jump_forward_map)
112
-
113
113
 
114
114
  class OutlinesGrammarBackend(BaseGrammarBackend):
115
115
  def __init__(
@@ -169,8 +169,9 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
169
169
  key_string,
170
170
  whitespace_pattern=self.whitespace_pattern,
171
171
  )
172
- except (NotImplementedError, json.decoder.JSONDecodeError) as e:
173
- logger.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
172
+ except (NotImplementedError, json.decoder.JSONDecodeError, ValueError) as e:
173
+ logger.warning(f"Skip invalid json_schema: {key_string=}, {e=}")
174
+ return None
174
175
  return self._compile_regex(regex)
175
176
 
176
177
  def dispatch_regex(self, key_string: str):
@@ -13,7 +13,6 @@
13
13
  # ==============================================================================
14
14
  """The baseclass of a backend for reasoner grammar-guided constrained decoding."""
15
15
 
16
- from concurrent.futures import Future
17
16
  from typing import List, Optional, Tuple
18
17
 
19
18
  import torch
@@ -28,13 +27,12 @@ class ReasonerGrammarObject(BaseGrammarObject):
28
27
  self.think_end_id = think_end_id
29
28
  self.is_in_reasoning = True
30
29
 
31
- @property
32
- def finished(self):
33
- return self.grammar.finished
30
+ def accept_token(self, token: int):
31
+ if token == self.think_end_id:
32
+ self.is_in_reasoning = False
34
33
 
35
- @finished.setter
36
- def finished(self, finished):
37
- self.grammar.finished = finished
34
+ if not self.is_in_reasoning and token != self.think_end_id:
35
+ self.grammar.accept_token(token)
38
36
 
39
37
  def allocate_vocab_mask(
40
38
  self, vocab_size: int, batch_size: int, device
@@ -52,12 +50,16 @@ class ReasonerGrammarObject(BaseGrammarObject):
52
50
  def apply_vocab_mask(self):
53
51
  return self.grammar.apply_vocab_mask
54
52
 
55
- def accept_token(self, token: int):
56
- if token == self.think_end_id:
57
- self.is_in_reasoning = False
53
+ def copy(self) -> BaseGrammarObject:
54
+ return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)
58
55
 
59
- if not self.is_in_reasoning and token != self.think_end_id:
60
- self.grammar.accept_token(token)
56
+ @property
57
+ def finished(self):
58
+ return self.grammar.finished
59
+
60
+ @finished.setter
61
+ def finished(self, finished):
62
+ self.grammar.finished = finished
61
63
 
62
64
  def try_jump_forward(self, tokenizer):
63
65
  return self.grammar.try_jump_forward(tokenizer)
@@ -72,30 +74,17 @@ class ReasonerGrammarObject(BaseGrammarObject):
72
74
  old_output_ids, new_output_ids, next_state
73
75
  )
74
76
 
75
- def copy(self) -> BaseGrammarObject:
76
- return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)
77
-
78
77
 
79
78
  class ReasonerGrammarBackend(BaseGrammarBackend):
80
79
  def __init__(self, grammar_backend: BaseGrammarBackend, think_end_id):
80
+ super().__init__()
81
81
  self.grammar_backend = grammar_backend
82
82
  self.think_end_id = think_end_id
83
83
 
84
- def get_cached_value(self, key: Tuple[str, str]) -> Optional[ReasonerGrammarObject]:
85
- grammar = self.grammar_backend.get_cached_value(key)
86
- return ReasonerGrammarObject(grammar, self.think_end_id) if grammar else None
87
-
88
- def get_future_value(self, key: Tuple[str, str]) -> Future:
89
- grammar = Future()
90
-
91
- def callback(f: Future):
92
- if result := f.result():
93
- grammar.set_result(ReasonerGrammarObject(result, self.think_end_id))
94
- else:
95
- grammar.set_result(None)
96
-
97
- self.grammar_backend.get_future_value(key).add_done_callback(callback)
98
- return grammar
99
-
100
- def reset(self):
101
- self.grammar_backend.reset()
84
+ def _init_value_dispatch(
85
+ self, key: Tuple[str, str]
86
+ ) -> Optional[ReasonerGrammarObject]:
87
+ ret = self.grammar_backend._init_value_dispatch(key)
88
+ if ret is None:
89
+ return None
90
+ return ReasonerGrammarObject(ret, self.think_end_id)
@@ -18,7 +18,6 @@ import logging
18
18
  from typing import List, Optional, Tuple, Union
19
19
 
20
20
  import torch
21
- import xgrammar
22
21
  from xgrammar import (
23
22
  CompiledGrammar,
24
23
  GrammarCompiler,
@@ -35,7 +34,6 @@ from sglang.srt.constrained.base_grammar_backend import (
35
34
  from sglang.srt.constrained.triton_ops.bitmask_ops import (
36
35
  apply_token_bitmask_inplace_triton,
37
36
  )
38
- from sglang.srt.utils import get_bool_env_var
39
37
 
40
38
  logger = logging.getLogger(__name__)
41
39
 
@@ -51,49 +49,35 @@ class XGrammarGrammar(BaseGrammarObject):
51
49
  vocab_size: int,
52
50
  ctx: CompiledGrammar,
53
51
  override_stop_tokens: Optional[Union[List[int], int]],
52
+ key_string: Optional[str] = None, # TODO (sk): for debugging, remove later
54
53
  ) -> None:
55
- super().__init__()
56
54
  self.matcher = matcher
57
55
  self.vocab_size = vocab_size
58
56
  self.ctx = ctx
59
57
  self.override_stop_tokens = override_stop_tokens
60
58
  self.finished = False
61
-
62
- from xgrammar.kernels.apply_token_bitmask_inplace_cpu import (
63
- apply_token_bitmask_inplace_cpu,
64
- )
65
-
66
- self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_cpu
59
+ self.accepted_tokens = []
60
+ self.key_string = key_string
67
61
 
68
62
  def accept_token(self, token: int):
69
- assert self.matcher.accept_token(token)
70
-
71
- def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
72
- s = self.matcher.find_jump_forward_string()
73
- if s:
74
- return [], s
75
- return None
76
-
77
- def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
78
- _, data = helper
79
- return data, -1
80
-
81
- def jump_and_retokenize(
82
- self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
83
- ):
84
- k = 0
85
- for i, old_id in enumerate(old_output_ids):
86
- if old_id == new_output_ids[i]:
87
- k = i + 1
63
+ if not self.is_terminated():
64
+ accepted = self.matcher.accept_token(token)
65
+ if not accepted:
66
+ # log for debugging
67
+ raise ValueError(
68
+ f"Tokens not accepted: {token}\n"
69
+ f"Accepted tokens: {self.accepted_tokens}\n"
70
+ f"Key string: {self.key_string}"
71
+ )
88
72
  else:
89
- break
73
+ self.accepted_tokens.append(token)
90
74
 
91
- # rollback to the last token that is the same
92
- if k < len(old_output_ids):
93
- self.matcher.rollback(len(old_output_ids) - k)
75
+ def rollback(self, k: int):
76
+ self.matcher.rollback(k)
77
+ self.accepted_tokens = self.accepted_tokens[:-k]
94
78
 
95
- for i in range(k, len(new_output_ids)):
96
- assert self.matcher.accept_token(new_output_ids[i])
79
+ def is_terminated(self):
80
+ return self.matcher.is_terminated()
97
81
 
98
82
  def allocate_vocab_mask(
99
83
  self, vocab_size: int, batch_size: int, device
@@ -122,9 +106,43 @@ class XGrammarGrammar(BaseGrammarObject):
122
106
  override_stop_tokens=self.override_stop_tokens,
123
107
  )
124
108
  return XGrammarGrammar(
125
- matcher, self.vocab_size, self.ctx, self.override_stop_tokens
109
+ matcher,
110
+ self.vocab_size,
111
+ self.ctx,
112
+ self.override_stop_tokens,
113
+ self.key_string,
126
114
  )
127
115
 
116
+ def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
117
+ s = self.matcher.find_jump_forward_string()
118
+ if s:
119
+ return [], s
120
+ return None
121
+
122
+ def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
123
+ _, data = helper
124
+ return data, -1
125
+
126
+ def jump_and_retokenize(
127
+ self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
128
+ ):
129
+ k = 0
130
+ for i, old_id in enumerate(old_output_ids):
131
+ if old_id == new_output_ids[i]:
132
+ k = i + 1
133
+ else:
134
+ break
135
+
136
+ # rollback to the last token that is the same
137
+ if k < len(old_output_ids):
138
+ self.matcher.rollback(len(old_output_ids) - k)
139
+
140
+ for i in range(k, len(new_output_ids)):
141
+ assert self.matcher.accept_token(new_output_ids[i])
142
+
143
+ def __repr__(self):
144
+ return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=})"
145
+
128
146
 
129
147
  class XGrammarGrammarBackend(BaseGrammarBackend):
130
148
  def __init__(
@@ -143,9 +161,15 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
143
161
  self.vocab_size = vocab_size
144
162
  self.override_stop_tokens = override_stop_tokens
145
163
 
146
- def _from_context(self, ctx: CompiledGrammar) -> XGrammarGrammar:
147
- matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
148
- return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens)
164
+ def _from_context(self, ctx: CompiledGrammar, key_string: str) -> XGrammarGrammar:
165
+ matcher = GrammarMatcher(
166
+ ctx,
167
+ max_rollback_tokens=MAX_ROLLBACK_TOKENS,
168
+ override_stop_tokens=self.override_stop_tokens,
169
+ )
170
+ return XGrammarGrammar(
171
+ matcher, self.vocab_size, ctx, self.override_stop_tokens, key_string
172
+ )
149
173
 
150
174
  def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
151
175
  try:
@@ -157,7 +181,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
157
181
  except RuntimeError as e:
158
182
  logging.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
159
183
  return None
160
- return self._from_context(ctx)
184
+ return self._from_context(ctx, key_string)
161
185
 
162
186
  def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
163
187
  try:
@@ -165,7 +189,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
165
189
  except RuntimeError as e:
166
190
  logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
167
191
  return None
168
- return self._from_context(ctx)
192
+ return self._from_context(ctx, key_string)
169
193
 
170
194
  def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
171
195
  try:
@@ -173,7 +197,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
173
197
  except RuntimeError as e:
174
198
  logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
175
199
  return None
176
- return self._from_context(ctx)
200
+ return self._from_context(ctx, key_string)
177
201
 
178
202
  def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
179
203
  try:
@@ -190,9 +214,11 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
190
214
  tags, structural_tag["triggers"]
191
215
  )
192
216
  except RuntimeError as e:
193
- logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
217
+ logging.warning(
218
+ f"Skip invalid structural_tag: structural_tag={key_string}, {e=}"
219
+ )
194
220
  return None
195
- return self._from_context(ctx)
221
+ return self._from_context(ctx, key_string)
196
222
 
197
223
  def reset(self):
198
224
  if self.grammar_compiler:
@@ -16,6 +16,7 @@
16
16
  # Adapted from
17
17
  # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
18
18
  import dataclasses
19
+ import re
19
20
  from enum import IntEnum, auto
20
21
  from typing import Callable, Dict, List, Optional, Tuple, Union
21
22
 
@@ -633,6 +634,20 @@ register_conv_template(
633
634
  )
634
635
  )
635
636
 
637
+ # reference: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/chat_template.json
638
+ register_conv_template(
639
+ Conversation(
640
+ name="mistral",
641
+ system_template="[SYSTEM_PROMPT]\n{system_message}\n[/SYSTEM_PROMPT]\n\n",
642
+ roles=("[INST]", "[/INST]"),
643
+ sep_style=SeparatorStyle.LLAMA2,
644
+ sep=" ",
645
+ sep2=" </s><s>",
646
+ stop_str=["[INST]", "[/INST]", "[SYSTEM_PROMPT]", "[/SYSTEM_PROMPT]"],
647
+ image_token="[IMG]",
648
+ )
649
+ )
650
+
636
651
  # reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
637
652
  register_conv_template(
638
653
  Conversation(
@@ -766,7 +781,7 @@ register_conv_template(
766
781
  Conversation(
767
782
  name="gemma-it",
768
783
  system_message="You are a helpful assistant.",
769
- system_template="<start_of_turn>user{system_message}\n\n",
784
+ system_template="<start_of_turn>user\n{system_message}\n\n",
770
785
  roles=("<start_of_turn>user\n", "<start_of_turn>model\n"),
771
786
  sep="<end_of_turn>\n",
772
787
  sep_style=SeparatorStyle.GEMMA3,
@@ -852,91 +867,81 @@ register_conv_template(
852
867
  )
853
868
 
854
869
 
870
+ @register_conv_template_matching_function
871
+ def match_internvl(model_path: str):
872
+ if re.search(r"internvl2_5", model_path, re.IGNORECASE):
873
+ return "internvl-2-5"
874
+
875
+
855
876
  @register_conv_template_matching_function
856
877
  def match_llama_3_vision(model_path: str):
857
- if (
858
- "llama" in model_path.lower()
859
- and "3.2" in model_path.lower()
860
- and "vision" in model_path.lower()
861
- ):
878
+ if re.search(r"llama.*3\.2.*vision", model_path, re.IGNORECASE):
862
879
  return "llama_3_vision"
863
880
 
864
881
 
865
882
  @register_conv_template_matching_function
866
883
  def match_deepseek_janus_pro(model_path: str):
867
- if "janus" in model_path.lower():
884
+ if re.search(r"janus", model_path, re.IGNORECASE):
868
885
  return "janus-pro"
869
886
 
870
887
 
871
888
  @register_conv_template_matching_function
872
889
  def match_vicuna(model_path: str):
873
- if "vicuna" in model_path.lower():
874
- return "vicuna_v1.1"
875
- if "llava-v1.5" in model_path.lower():
876
- return "vicuna_v1.1"
877
- if "llava-next-video-7b" in model_path.lower():
890
+ if re.search(r"vicuna|llava-v1\.5|llava-next-video-7b", model_path, re.IGNORECASE):
878
891
  return "vicuna_v1.1"
879
892
 
880
893
 
881
894
  @register_conv_template_matching_function
882
895
  def match_llama2_chat(model_path: str):
883
- model_path = model_path.lower()
884
- if "llama-2" in model_path and "chat" in model_path:
885
- return "llama-2"
886
- if (
887
- "mistral" in model_path or "mixtral" in model_path
888
- ) and "instruct" in model_path:
889
- return "llama-2"
890
- if "codellama" in model_path and "instruct" in model_path:
896
+ if re.search(
897
+ r"llama-2.*chat|codellama.*instruct",
898
+ model_path,
899
+ re.IGNORECASE,
900
+ ):
891
901
  return "llama-2"
892
902
 
893
903
 
904
+ @register_conv_template_matching_function
905
+ def match_mistral(model_path: str):
906
+ if re.search(r"pixtral|(mistral|mixtral).*instruct", model_path, re.IGNORECASE):
907
+ return "mistral"
908
+
909
+
894
910
  @register_conv_template_matching_function
895
911
  def match_deepseek_vl(model_path: str):
896
- model_path = model_path.lower()
897
- if "deepseek" in model_path and "vl2" in model_path:
912
+ if re.search(r"deepseek.*vl2", model_path, re.IGNORECASE):
898
913
  return "deepseek-vl2"
899
914
 
900
915
 
901
916
  @register_conv_template_matching_function
902
- def match_chat_ml(model_path: str):
903
- # import pdb;pdb.set_trace()
904
- model_path = model_path.lower()
905
- # Now the suffix for qwen2 chat model is "instruct"
906
- if "gme" in model_path and "qwen" in model_path and "vl" in model_path:
917
+ def match_qwen_chat_ml(model_path: str):
918
+ if re.search(r"gme.*qwen.*vl", model_path, re.IGNORECASE):
907
919
  return "gme-qwen2-vl"
908
- if "qwen" in model_path and "vl" in model_path:
920
+ if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
909
921
  return "qwen2-vl"
910
- if (
911
- "llava-v1.6-34b" in model_path
912
- or "llava-v1.6-yi-34b" in model_path
913
- or "llava-next-video-34b" in model_path
914
- or "llava-onevision-qwen2" in model_path
922
+ if re.search(
923
+ r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2",
924
+ model_path,
925
+ re.IGNORECASE,
915
926
  ):
916
927
  return "chatml-llava"
917
928
 
918
929
 
919
930
  @register_conv_template_matching_function
920
- def match_gemma_it(model_path: str):
921
- model_path = model_path.lower()
922
- if "gemma" in model_path and "it" in model_path:
923
- return "gemma-it"
924
- if "gemma-3" in model_path and "1b" not in model_path:
925
- # gemma-3-1b-it is completion model
931
+ def match_gemma3_instruct(model_path: str):
932
+ if re.search(r"gemma-3.*it", model_path, re.IGNORECASE):
926
933
  return "gemma-it"
927
934
 
928
935
 
929
936
  @register_conv_template_matching_function
930
937
  def match_openbmb_minicpm(model_path: str):
931
- model_path = model_path.lower()
932
- if "minicpm-v" in model_path:
938
+ if re.search(r"minicpm-v", model_path, re.IGNORECASE):
933
939
  return "minicpmv"
934
- elif "minicpm-o" in model_path:
940
+ elif re.search(r"minicpm-o", model_path, re.IGNORECASE):
935
941
  return "minicpmo"
936
942
 
937
943
 
938
944
  @register_conv_template_matching_function
939
945
  def match_moonshot_kimivl(model_path: str):
940
- model_path = model_path.lower()
941
- if "kimi" in model_path and "vl" in model_path:
946
+ if re.search(r"kimi.*vl", model_path, re.IGNORECASE):
942
947
  return "kimi-vl"
@@ -37,6 +37,7 @@ class BaseKVManager(ABC):
37
37
  args: KVArgs,
38
38
  disaggregation_mode: DisaggregationMode,
39
39
  server_args: ServerArgs,
40
+ is_mla_backend: Optional[bool] = False,
40
41
  ): ...
41
42
 
42
43