sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +208 -295
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  64. sglang/srt/layers/moe/topk.py +13 -4
  65. sglang/srt/layers/quantization/__init__.py +111 -7
  66. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  69. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  71. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  72. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  73. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  75. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  80. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  82. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  83. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  86. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  87. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  89. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  90. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  91. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  93. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  94. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  95. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  96. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  97. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  98. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  99. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  100. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  101. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  102. sglang/srt/layers/quantization/fp8.py +69 -28
  103. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  104. sglang/srt/layers/quantization/gptq.py +416 -0
  105. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  106. sglang/srt/layers/quantization/int8_utils.py +73 -0
  107. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  108. sglang/srt/layers/radix_attention.py +1 -0
  109. sglang/srt/layers/rotary_embedding.py +0 -1
  110. sglang/srt/layers/sampler.py +76 -31
  111. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  112. sglang/srt/lora/lora.py +17 -1
  113. sglang/srt/lora/lora_config.py +5 -0
  114. sglang/srt/lora/lora_manager.py +1 -3
  115. sglang/srt/managers/cache_controller.py +193 -62
  116. sglang/srt/managers/configure_logging.py +2 -1
  117. sglang/srt/managers/data_parallel_controller.py +6 -2
  118. sglang/srt/managers/detokenizer_manager.py +124 -102
  119. sglang/srt/managers/image_processor.py +2 -1
  120. sglang/srt/managers/io_struct.py +143 -6
  121. sglang/srt/managers/schedule_batch.py +238 -197
  122. sglang/srt/managers/schedule_policy.py +29 -29
  123. sglang/srt/managers/scheduler.py +681 -259
  124. sglang/srt/managers/session_controller.py +6 -2
  125. sglang/srt/managers/tokenizer_manager.py +224 -68
  126. sglang/srt/managers/tp_worker.py +15 -4
  127. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  128. sglang/srt/mem_cache/chunk_cache.py +18 -11
  129. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  130. sglang/srt/mem_cache/memory_pool.py +44 -18
  131. sglang/srt/mem_cache/radix_cache.py +58 -47
  132. sglang/srt/metrics/collector.py +94 -36
  133. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  134. sglang/srt/model_executor/forward_batch_info.py +49 -16
  135. sglang/srt/model_executor/model_runner.py +209 -28
  136. sglang/srt/model_loader/loader.py +3 -3
  137. sglang/srt/model_loader/weight_utils.py +36 -14
  138. sglang/srt/models/baichuan.py +31 -6
  139. sglang/srt/models/chatglm.py +39 -7
  140. sglang/srt/models/commandr.py +29 -5
  141. sglang/srt/models/dbrx.py +31 -5
  142. sglang/srt/models/deepseek.py +43 -6
  143. sglang/srt/models/deepseek_nextn.py +32 -19
  144. sglang/srt/models/deepseek_v2.py +265 -29
  145. sglang/srt/models/exaone.py +19 -9
  146. sglang/srt/models/gemma.py +22 -8
  147. sglang/srt/models/gemma2.py +25 -12
  148. sglang/srt/models/gemma2_reward.py +5 -1
  149. sglang/srt/models/gpt2.py +28 -13
  150. sglang/srt/models/gpt_bigcode.py +27 -5
  151. sglang/srt/models/granite.py +21 -9
  152. sglang/srt/models/grok.py +21 -4
  153. sglang/srt/models/internlm2.py +36 -6
  154. sglang/srt/models/internlm2_reward.py +5 -1
  155. sglang/srt/models/llama.py +26 -9
  156. sglang/srt/models/llama_classification.py +5 -1
  157. sglang/srt/models/llama_eagle.py +17 -4
  158. sglang/srt/models/llama_embedding.py +5 -1
  159. sglang/srt/models/llama_reward.py +7 -2
  160. sglang/srt/models/llava.py +19 -3
  161. sglang/srt/models/llavavid.py +10 -1
  162. sglang/srt/models/minicpm.py +26 -2
  163. sglang/srt/models/minicpm3.py +39 -3
  164. sglang/srt/models/minicpmv.py +45 -14
  165. sglang/srt/models/mixtral.py +20 -9
  166. sglang/srt/models/mixtral_quant.py +50 -8
  167. sglang/srt/models/mllama.py +57 -11
  168. sglang/srt/models/olmo.py +34 -6
  169. sglang/srt/models/olmo2.py +34 -13
  170. sglang/srt/models/olmoe.py +26 -4
  171. sglang/srt/models/phi3_small.py +29 -10
  172. sglang/srt/models/qwen.py +26 -3
  173. sglang/srt/models/qwen2.py +26 -4
  174. sglang/srt/models/qwen2_5_vl.py +46 -8
  175. sglang/srt/models/qwen2_eagle.py +17 -5
  176. sglang/srt/models/qwen2_moe.py +44 -6
  177. sglang/srt/models/qwen2_rm.py +78 -0
  178. sglang/srt/models/qwen2_vl.py +39 -8
  179. sglang/srt/models/stablelm.py +32 -5
  180. sglang/srt/models/torch_native_llama.py +5 -2
  181. sglang/srt/models/xverse.py +21 -9
  182. sglang/srt/models/xverse_moe.py +45 -7
  183. sglang/srt/models/yivl.py +2 -1
  184. sglang/srt/openai_api/adapter.py +109 -24
  185. sglang/srt/openai_api/protocol.py +17 -1
  186. sglang/srt/reasoning_parser.py +154 -0
  187. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  188. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  189. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  190. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  191. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  192. sglang/srt/sampling/sampling_batch_info.py +79 -157
  193. sglang/srt/sampling/sampling_params.py +16 -13
  194. sglang/srt/server_args.py +136 -52
  195. sglang/srt/speculative/build_eagle_tree.py +2 -8
  196. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  197. sglang/srt/speculative/eagle_utils.py +92 -58
  198. sglang/srt/speculative/eagle_worker.py +186 -94
  199. sglang/srt/speculative/spec_info.py +1 -13
  200. sglang/srt/utils.py +43 -17
  201. sglang/srt/warmup.py +47 -0
  202. sglang/test/few_shot_gsm8k.py +4 -1
  203. sglang/test/runners.py +389 -126
  204. sglang/test/send_one.py +88 -0
  205. sglang/test/test_block_fp8_ep.py +361 -0
  206. sglang/test/test_programs.py +1 -1
  207. sglang/test/test_utils.py +138 -84
  208. sglang/utils.py +50 -60
  209. sglang/version.py +1 -1
  210. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  211. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
  212. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  213. sglang/bench_latency.py +0 -1
  214. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  215. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  216. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  217. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  218. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  219. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,151 @@
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+ """Constrained decoding with llguidance backend."""
15
+
16
+ import json
17
+ import os
18
+ from typing import List, Optional, Tuple
19
+
20
+ import llguidance
21
+ import llguidance.hf
22
+ import llguidance.torch
23
+ import torch
24
+ from llguidance.gbnf_to_lark import any_to_lark
25
+
26
+ from sglang.srt.constrained.base_grammar_backend import (
27
+ BaseGrammarBackend,
28
+ BaseGrammarObject,
29
+ )
30
+
31
+
32
+ class GuidanceGrammar(BaseGrammarObject):
33
+ def __init__(
34
+ self, llguidance_tokenizer: llguidance.LLTokenizer, serialized_grammar: str
35
+ ):
36
+ self.llguidance_tokenizer = llguidance_tokenizer
37
+ self.serialized_grammar = serialized_grammar
38
+
39
+ # TODO: add support for fast-forward tokens in the future
40
+ self.ll_interpreter = llguidance.LLInterpreter(
41
+ self.llguidance_tokenizer,
42
+ self.serialized_grammar,
43
+ enable_backtrack=False,
44
+ enable_ff_tokens=False,
45
+ log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
46
+ )
47
+ self.pending_ff_tokens: list[int] = []
48
+ self.finished = False
49
+ self.bitmask = None
50
+
51
+ def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
52
+ if len(self.pending_ff_tokens) > 0:
53
+ s = self.llguidance_tokenizer.decode_str(self.pending_ff_tokens)
54
+ ff_tokens = self.pending_ff_tokens
55
+ self.pending_ff_tokens = []
56
+ return (ff_tokens, s)
57
+
58
+ return None
59
+
60
+ def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
61
+ return "", -1
62
+
63
+ def jump_and_retokenize(
64
+ self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
65
+ ):
66
+ pass
67
+
68
+ def accept_token(self, token: int):
69
+ backtrack, ff_tokens = self.ll_interpreter.commit_token(token)
70
+ if len(ff_tokens) > 0 and backtrack == 0:
71
+ # first token is last generated token
72
+ ff_tokens = ff_tokens[1:]
73
+ self.pending_ff_tokens.extend(ff_tokens)
74
+
75
+ def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
76
+ if len(self.pending_ff_tokens) > 0:
77
+ # if we have pending fast-forward tokens,
78
+ # just return them immediately
79
+ ff_token = self.pending_ff_tokens.pop(0)
80
+ vocab_mask[idx, :] = 0
81
+ vocab_mask[idx, ff_token // 32] = 1 << (ff_token % 32)
82
+ return
83
+
84
+ if self.ll_interpreter.has_pending_stop():
85
+ self.finished = True
86
+
87
+ llguidance.torch.fill_next_token_bitmask(self.ll_interpreter, vocab_mask, idx)
88
+
89
+ def allocate_vocab_mask(
90
+ self, vocab_size: int, batch_size: int, device
91
+ ) -> torch.Tensor:
92
+ if self.bitmask is None or self.bitmask.shape[0] < batch_size:
93
+ # only create bitmask when batch gets larger
94
+ self.bitmask = llguidance.torch.allocate_token_bitmask(
95
+ batch_size, self.llguidance_tokenizer.vocab_size
96
+ )
97
+ bitmask = self.bitmask
98
+ else:
99
+ bitmask = self.bitmask[:batch_size]
100
+
101
+ return bitmask
102
+
103
+ @staticmethod
104
+ def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
105
+ return vocab_mask.to(device, non_blocking=True)
106
+
107
+ @staticmethod
108
+ def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
109
+ llguidance.torch.apply_token_bitmask_inplace(logits, vocab_mask)
110
+
111
+ def copy(self):
112
+ return GuidanceGrammar(
113
+ llguidance_tokenizer=self.llguidance_tokenizer,
114
+ serialized_grammar=self.serialized_grammar,
115
+ )
116
+
117
+
118
+ class GuidanceBackend(BaseGrammarBackend):
119
+ def __init__(self, tokenizer, whitespace_pattern: Optional[str] = None):
120
+ super().__init__()
121
+
122
+ self.tokenizer = tokenizer
123
+ self.whitespace_flexible = (
124
+ True if whitespace_pattern == "whitespace_flexible" else False
125
+ )
126
+ self.llguidance_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None)
127
+
128
+ def _from_serialized(self, serialized_grammar) -> GuidanceGrammar:
129
+ return GuidanceGrammar(
130
+ llguidance_tokenizer=self.llguidance_tokenizer,
131
+ serialized_grammar=serialized_grammar,
132
+ )
133
+
134
+ def dispatch_json(self, key_string: str) -> GuidanceGrammar:
135
+ json_schema = key_string
136
+ compiler = llguidance.JsonCompiler(whitespace_flexible=self.whitespace_flexible)
137
+ serialized_grammar = compiler.compile(json_schema)
138
+ return self._from_serialized(serialized_grammar)
139
+
140
+ def dispatch_regex(self, key_string: str) -> GuidanceGrammar:
141
+ compiler = llguidance.RegexCompiler()
142
+ serialized_grammar = compiler.compile(regex=key_string)
143
+ return self._from_serialized(serialized_grammar)
144
+
145
+ def dispatch_ebnf(self, key_string: str) -> GuidanceGrammar:
146
+ compiler = llguidance.LarkCompiler()
147
+ serialized_grammar = compiler.compile(any_to_lark(key_string))
148
+ return self._from_serialized(serialized_grammar)
149
+
150
+ def dispatch_structural_tag(self, key_string: str):
151
+ return super().dispatch_structural_tag(key_string)
@@ -28,17 +28,11 @@ from sglang.srt.constrained.base_grammar_backend import (
28
28
  BaseGrammarObject,
29
29
  )
30
30
  from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap
31
- from sglang.srt.utils import is_hip
32
31
 
33
- is_hip_ = is_hip()
34
-
35
- if is_hip_:
32
+ try:
33
+ from outlines.fsm.json_schema import build_regex_from_schema
34
+ except ImportError:
36
35
  from outlines_core.fsm.json_schema import build_regex_from_schema
37
- else:
38
- try:
39
- from outlines.fsm.json_schema import build_regex_from_schema
40
- except ImportError:
41
- from outlines_core.fsm.json_schema import build_regex_from_schema
42
36
 
43
37
 
44
38
  logger = logging.getLogger(__name__)
@@ -121,7 +115,6 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
121
115
  self,
122
116
  tokenizer,
123
117
  whitespace_pattern: bool,
124
- allow_jump_forward: bool,
125
118
  ):
126
119
  super().__init__()
127
120
 
@@ -146,27 +139,9 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
146
139
  self.outlines_tokenizer.vocabulary = (
147
140
  self.outlines_tokenizer.tokenizer.get_vocab()
148
141
  )
149
- self.allow_jump_forward = allow_jump_forward
150
142
  self.whitespace_pattern = whitespace_pattern
151
143
 
152
- def init_value_impl(self, key: Tuple[str, str]) -> OutlinesGrammar:
153
- key_type, key_string = key
154
- if key_type == "json":
155
- try:
156
- regex = build_regex_from_object(
157
- key_string,
158
- whitespace_pattern=self.whitespace_pattern,
159
- )
160
- except (NotImplementedError, json.decoder.JSONDecodeError) as e:
161
- logger.warning(
162
- f"Skip invalid json_schema: json_schema={key_string}, {e=}"
163
- )
164
- return None
165
- elif key_type == "regex":
166
- regex = key_string
167
- else:
168
- raise ValueError(f"Invalid key_type: {key_type}")
169
-
144
+ def _compile_regex(self, regex: str) -> Optional[OutlinesGrammar]:
170
145
  try:
171
146
  if hasattr(RegexGuide, "from_regex"):
172
147
  # outlines >= 0.1.1
@@ -178,12 +153,28 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
178
153
  logger.warning(f"skip invalid regex schema: {regex=}, {e=}")
179
154
  return None
180
155
 
181
- if self.allow_jump_forward:
182
- jump_forward_map = OutlinesJumpForwardMap(regex)
183
- else:
184
- jump_forward_map = None
156
+ jump_forward_map = None
185
157
  return OutlinesGrammar(guide, jump_forward_map)
186
158
 
159
+ def dispatch_ebnf(self, key_string: str):
160
+ return super().dispatch_ebnf(key_string)
161
+
162
+ def dispatch_structural_tag(self, key_string: str):
163
+ return super().dispatch_structural_tag(key_string)
164
+
165
+ def dispatch_json(self, key_string: str):
166
+ try:
167
+ regex = build_regex_from_object(
168
+ key_string,
169
+ whitespace_pattern=self.whitespace_pattern,
170
+ )
171
+ except (NotImplementedError, json.decoder.JSONDecodeError) as e:
172
+ logger.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
173
+ return self._compile_regex(regex)
174
+
175
+ def dispatch_regex(self, key_string: str):
176
+ return self._compile_regex(key_string)
177
+
187
178
 
188
179
  def build_regex_from_object(
189
180
  object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
@@ -13,15 +13,16 @@
13
13
  # ==============================================================================
14
14
  """Constrained decoding with xgrammar backend."""
15
15
 
16
+ import json
16
17
  import logging
17
- from typing import List, Tuple
18
+ from typing import List, Optional, Tuple, Union
18
19
 
19
20
  import torch
20
21
  from xgrammar import (
21
22
  CompiledGrammar,
22
- Grammar,
23
23
  GrammarCompiler,
24
24
  GrammarMatcher,
25
+ StructuralTagItem,
25
26
  TokenizerInfo,
26
27
  allocate_token_bitmask,
27
28
  apply_token_bitmask_inplace,
@@ -41,17 +42,22 @@ MAX_ROLLBACK_TOKENS = 200
41
42
  class XGrammarGrammar(BaseGrammarObject):
42
43
 
43
44
  def __init__(
44
- self, matcher: GrammarMatcher, vocab_size: int, ctx: CompiledGrammar
45
+ self,
46
+ matcher: GrammarMatcher,
47
+ vocab_size: int,
48
+ ctx: CompiledGrammar,
49
+ override_stop_tokens: Optional[Union[List[int], int]],
45
50
  ) -> None:
46
51
  self.matcher = matcher
47
52
  self.vocab_size = vocab_size
48
53
  self.ctx = ctx
54
+ self.override_stop_tokens = override_stop_tokens
49
55
  self.finished = False
50
56
 
51
57
  def accept_token(self, token: int):
52
58
  assert self.matcher.accept_token(token)
53
59
 
54
- def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]:
60
+ def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
55
61
  s = self.matcher.find_jump_forward_string()
56
62
  if s:
57
63
  return [], s
@@ -95,8 +101,14 @@ class XGrammarGrammar(BaseGrammarObject):
95
101
  apply_token_bitmask_inplace(logits, vocab_mask)
96
102
 
97
103
  def copy(self):
98
- matcher = GrammarMatcher(self.ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
99
- return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
104
+ matcher = GrammarMatcher(
105
+ self.ctx,
106
+ max_rollback_tokens=MAX_ROLLBACK_TOKENS,
107
+ override_stop_tokens=self.override_stop_tokens,
108
+ )
109
+ return XGrammarGrammar(
110
+ matcher, self.vocab_size, self.ctx, self.override_stop_tokens
111
+ )
100
112
 
101
113
 
102
114
  class XGrammarGrammarBackend(BaseGrammarBackend):
@@ -110,42 +122,61 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
110
122
  tokenizer_info = TokenizerInfo.from_huggingface(
111
123
  tokenizer, vocab_size=vocab_size
112
124
  )
125
+ override_stop_tokens = None
126
+
113
127
  self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
114
128
  self.vocab_size = vocab_size
129
+ self.override_stop_tokens = override_stop_tokens
115
130
 
116
- def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
117
-
118
- key_type, key_string = key
119
- if key_type == "json":
120
- try:
121
- if key_string == "$$ANY$$":
122
- ctx = self.grammar_compiler.compile_builtin_json_grammar()
123
- else:
124
- ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
125
- except RuntimeError as e:
126
- logging.warning(
127
- f"Skip invalid json_schema: json_schema={key_string}, {e=}"
128
- )
129
- return None
130
- elif key_type == "ebnf":
131
- try:
132
- ctx = self.grammar_compiler.compile_grammar(key_string)
133
- except RuntimeError as e:
134
- logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
135
- return None
136
- elif key_type == "regex":
137
- try:
138
- ctx = self.grammar_compiler.compile_grammar(
139
- Grammar.from_regex(key_string)
140
- )
141
- except RuntimeError as e:
142
- logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
143
- return None
144
- else:
145
- raise ValueError(f"Invalid key_type: {key_type}")
146
-
131
+ def _from_context(self, ctx: CompiledGrammar) -> XGrammarGrammar:
147
132
  matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
148
- return XGrammarGrammar(matcher, self.vocab_size, ctx)
133
+ return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens)
134
+
135
+ def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
136
+ try:
137
+ if key_string == "$$ANY$$":
138
+ ctx = self.grammar_compiler.compile_builtin_json_grammar()
139
+ else:
140
+ ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
141
+ except RuntimeError as e:
142
+ logging.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
143
+ return None
144
+ return self._from_context(ctx)
145
+
146
+ def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
147
+ try:
148
+ ctx = self.grammar_compiler.compile_grammar(key_string)
149
+ except RuntimeError as e:
150
+ logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
151
+ return None
152
+ return self._from_context(ctx)
153
+
154
+ def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
155
+ try:
156
+ ctx = self.grammar_compiler.compile_regex(key_string)
157
+ except RuntimeError as e:
158
+ logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
159
+ return None
160
+ return self._from_context(ctx)
161
+
162
+ def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
163
+ try:
164
+ structural_tag = json.loads(key_string)
165
+ tags = [
166
+ StructuralTagItem(
167
+ begin=structure["begin"],
168
+ schema=json.dumps(structure["schema"]),
169
+ end=structure["end"],
170
+ )
171
+ for structure in structural_tag["structures"]
172
+ ]
173
+ ctx = self.grammar_compiler.compile_structural_tag(
174
+ tags, structural_tag["triggers"]
175
+ )
176
+ except RuntimeError as e:
177
+ logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
178
+ return None
179
+ return self._from_context(ctx)
149
180
 
150
181
  def reset(self):
151
182
  if self.grammar_compiler: