sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
@@ -14,8 +14,9 @@
14
14
  """The baseclass of a backend for grammar-guided constrained decoding."""
15
15
 
16
16
  import logging
17
+ import time
17
18
  from concurrent.futures import ThreadPoolExecutor
18
- from dataclasses import dataclass
19
+ from dataclasses import dataclass, field
19
20
  from threading import Event
20
21
  from typing import Dict, List, Optional, Tuple
21
22
 
@@ -26,10 +27,22 @@ from sglang.srt.server_args import ServerArgs
26
27
  logger = logging.getLogger(__name__)
27
28
 
28
29
 
30
@dataclass
class GrammarStats:
    """Bookkeeping attached to a grammar object for metrics/debugging.

    Every numeric field defaults to ``None``, meaning "not recorded".
    """

    # Seconds spent building the grammar; filled in by the backend dispatcher
    # from a perf_counter() delta around the dispatch_* call.
    compilation_time: Optional[float] = field(default=None)
    schema_count: Optional[int] = field(default=None)
    ebnf_size: Optional[int] = field(default=None)
    # Set to True on grammar objects re-created from an existing one
    # (presumably the cache-reuse path — confirm against callers).
    is_cache_hit: bool = field(default=False)
    is_grammar_aborted: bool = field(default=False)
    # Fresh list per instance; presumably per-step timing samples — TODO confirm.
    tree_traversal_time: List[float] = field(default_factory=list)
38
+
39
+
29
40
  class BaseGrammarObject:
30
41
 
31
42
  def __init__(self):
32
43
  self._finished = False
44
+ self.grammar_stats = None
45
+ self.current_token = None
33
46
 
34
47
  def accept_token(self, token: int) -> None:
35
48
  """
@@ -137,19 +150,26 @@ class BaseGrammarBackend:
137
150
  return self._not_supported("structural_tag", key_string)
138
151
 
139
152
  def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
153
+ s = time.perf_counter()
140
154
  key_type, key_string = key
141
155
  if key_type == "json":
142
- return self.dispatch_json(key_string)
156
+ grammar = self.dispatch_json(key_string)
143
157
  elif key_type == "regex":
144
- return self.dispatch_regex(key_string)
158
+ grammar = self.dispatch_regex(key_string)
145
159
  elif key_type == "ebnf":
146
- return self.dispatch_ebnf(key_string)
160
+ grammar = self.dispatch_ebnf(key_string)
147
161
  elif key_type == "structural_tag":
148
- return self.dispatch_structural_tag(key_string)
162
+ grammar = self.dispatch_structural_tag(key_string)
149
163
  elif key_type == "structural_pattern":
150
- return self.dispatch_structural_pattern(key_string)
164
+ grammar = self.dispatch_structural_pattern(key_string)
165
+ elif key_type == "structural_pattern_v2":
166
+ grammar = self.dispatch_structural_pattern_v2(key_string)
151
167
  else:
152
- return self.dispatch_fallback(key_type, key_string)
168
+ grammar = self.dispatch_fallback(key_type, key_string)
169
+
170
+ if grammar is not None and grammar.grammar_stats is not None:
171
+ grammar.grammar_stats.compilation_time = time.perf_counter() - s
172
+ return grammar
153
173
 
154
174
  def get_cached_or_future_value(
155
175
  self, key: Tuple[str, str]
@@ -167,20 +187,36 @@ class BaseGrammarBackend:
167
187
  self.cache.clear()
168
188
 
169
189
 
190
# Custom grammar-backend factories keyed by backend name. create_grammar_backend
# consults this mapping before the built-in backends and calls the factory as
# ``factory(server_args, tokenizer, vocab_size, eos_token_ids)``.
GRAMMAR_BACKEND_REGISTRY = {}


def register_grammar_backend(name, init_func):
    """Register ``init_func`` as the factory for grammar backend ``name``.

    Registering the same name again silently replaces the earlier factory.
    """
    GRAMMAR_BACKEND_REGISTRY.update({name: init_func})
195
+
196
+
170
197
  def create_grammar_backend(
171
198
  server_args: ServerArgs,
172
199
  tokenizer,
173
200
  vocab_size: int,
174
201
  eos_token_ids: Optional[set] = None,
175
202
  ) -> Optional[BaseGrammarBackend]:
176
- if server_args.grammar_backend == "outlines":
203
+ name = server_args.grammar_backend
204
+
205
+ # Custom grammar backend has the highest priority
206
+ if name in GRAMMAR_BACKEND_REGISTRY:
207
+ return GRAMMAR_BACKEND_REGISTRY[name](
208
+ server_args, tokenizer, vocab_size, eos_token_ids
209
+ )
210
+
211
+ # Default grammar backends
212
+ if name == "outlines":
177
213
  from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend
178
214
 
179
215
  grammar_backend = OutlinesGrammarBackend(
180
216
  tokenizer,
181
217
  whitespace_pattern=server_args.constrained_json_whitespace_pattern,
182
218
  )
183
- elif server_args.grammar_backend == "xgrammar":
219
+ elif name == "xgrammar":
184
220
  from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend
185
221
 
186
222
  # Convert Set[int] to List[int] if needed
@@ -189,17 +225,17 @@ def create_grammar_backend(
189
225
  grammar_backend = XGrammarGrammarBackend(
190
226
  tokenizer, vocab_size=vocab_size, model_eos_token_ids=eos_list
191
227
  )
192
- elif server_args.grammar_backend == "llguidance":
228
+ elif name == "llguidance":
193
229
  from sglang.srt.constrained.llguidance_backend import GuidanceBackend
194
230
 
195
231
  grammar_backend = GuidanceBackend(
196
232
  tokenizer=tokenizer,
197
233
  whitespace_pattern=server_args.constrained_json_whitespace_pattern,
198
234
  )
199
- elif server_args.grammar_backend == "none":
235
+ elif name == "none":
200
236
  return None
201
237
  else:
202
- raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
238
+ raise ValueError(f"Invalid grammar backend: {name}")
203
239
 
204
240
  if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
205
241
  from sglang.srt.constrained.reasoner_grammar_backend import (
@@ -48,7 +48,6 @@ class GuidanceGrammar(BaseGrammarObject):
48
48
  self.serialized_grammar,
49
49
  log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
50
50
  )
51
- self.finished = False
52
51
  self.bitmask = None
53
52
 
54
53
  def accept_token(self, token: int):
@@ -49,7 +49,6 @@ class OutlinesGrammar(BaseGrammarObject):
49
49
  self.guide = guide
50
50
  self.jump_forward_map = jump_forward_map
51
51
  self.state = 0
52
- self.finished = False
53
52
 
54
53
  def accept_token(self, token: int):
55
54
  self.state = self.guide.get_next_state(self.state, token)
@@ -13,6 +13,7 @@
13
13
  # ==============================================================================
14
14
  """Constrained decoding with xgrammar backend."""
15
15
 
16
+ import dataclasses
16
17
  import json
17
18
  import logging
18
19
  from typing import List, Optional, Tuple, Union
@@ -31,6 +32,7 @@ from sglang.srt.constrained.base_grammar_backend import (
31
32
  INVALID_GRAMMAR_OBJ,
32
33
  BaseGrammarBackend,
33
34
  BaseGrammarObject,
35
+ GrammarStats,
34
36
  )
35
37
  from sglang.srt.utils import is_hip
36
38
 
@@ -41,9 +43,9 @@ else:
41
43
  from sglang.srt.constrained.triton_ops.bitmask_ops import (
42
44
  apply_token_bitmask_inplace_triton,
43
45
  )
44
- logger = logging.getLogger(__name__)
45
46
 
46
47
 
48
+ logger = logging.getLogger(__name__)
47
49
  MAX_ROLLBACK_TOKENS = 200
48
50
 
49
51
 
@@ -56,17 +58,20 @@ class XGrammarGrammar(BaseGrammarObject):
56
58
  ctx: CompiledGrammar,
57
59
  override_stop_tokens: Optional[Union[List[int], int]],
58
60
  key_string: Optional[str] = None, # TODO (sk): for debugging, remove later
61
+ grammar_stats: Optional[GrammarStats] = GrammarStats(),
59
62
  ) -> None:
63
+ super().__init__()
60
64
  self.matcher = matcher
61
65
  self.vocab_size = vocab_size
62
66
  self.ctx = ctx
63
67
  self.override_stop_tokens = override_stop_tokens
64
- self.finished = False
65
68
  self.accepted_tokens = []
66
69
  self.key_string = key_string
70
+ self.grammar_stats = grammar_stats
67
71
 
68
72
  def accept_token(self, token: int):
69
73
  if not self.is_terminated():
74
+ self.current_token = token
70
75
  accepted = self.matcher.accept_token(token)
71
76
  if not accepted:
72
77
  # log for debugging
@@ -120,6 +125,9 @@ class XGrammarGrammar(BaseGrammarObject):
120
125
  self.ctx,
121
126
  self.override_stop_tokens,
122
127
  self.key_string,
128
+ dataclasses.replace(
129
+ self.grammar_stats, is_cache_hit=True, tree_traversal_time=[]
130
+ ),
123
131
  )
124
132
 
125
133
  def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
@@ -150,7 +158,7 @@ class XGrammarGrammar(BaseGrammarObject):
150
158
  assert self.matcher.accept_token(new_output_ids[i])
151
159
 
152
160
  def __repr__(self):
153
- return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=})"
161
+ return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=}, {self.current_token=})"
154
162
 
155
163
 
156
164
  class XGrammarGrammarBackend(BaseGrammarBackend):
@@ -165,6 +173,10 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
165
173
  if hasattr(tokenizer, "init_xgrammar"):
166
174
  # For special tokenizer
167
175
  tokenizer_info, override_stop_tokens = tokenizer.init_xgrammar()
176
+
177
+ if tokenizer_info is None:
178
+ # Not supported tokenizer
179
+ return
168
180
  else:
169
181
  # Create TokenizerInfo with model's EOS tokens as the authoritative stop tokens
170
182
  # This ensures consistency between what the model considers EOS and what XGrammar uses
@@ -177,14 +189,21 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
177
189
  self.vocab_size = vocab_size
178
190
  self.override_stop_tokens = override_stop_tokens
179
191
 
180
- def _from_context(self, ctx: CompiledGrammar, key_string: str) -> XGrammarGrammar:
192
+ def _from_context(
193
+ self, ctx: CompiledGrammar, key_string: str, grammar_stats: GrammarStats
194
+ ) -> XGrammarGrammar:
181
195
  matcher = GrammarMatcher(
182
196
  ctx,
183
197
  max_rollback_tokens=MAX_ROLLBACK_TOKENS,
184
198
  override_stop_tokens=self.override_stop_tokens,
185
199
  )
186
200
  return XGrammarGrammar(
187
- matcher, self.vocab_size, ctx, self.override_stop_tokens, key_string
201
+ matcher,
202
+ self.vocab_size,
203
+ ctx,
204
+ self.override_stop_tokens,
205
+ key_string,
206
+ grammar_stats,
188
207
  )
189
208
 
190
209
  def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
@@ -198,7 +217,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
198
217
  except (RuntimeError, json.decoder.JSONDecodeError) as e:
199
218
  logging.error(f"Hit invalid json_schema: {key_string=}, {e=}")
200
219
  return INVALID_GRAMMAR_OBJ
201
- return self._from_context(ctx, key_string)
220
+ return self._from_context(ctx, key_string, GrammarStats())
202
221
 
203
222
  def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
204
223
  try:
@@ -206,7 +225,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
206
225
  except RuntimeError as e:
207
226
  logging.error(f"Hit invalid ebnf: {key_string=}, {e=}")
208
227
  return INVALID_GRAMMAR_OBJ
209
- return self._from_context(ctx, key_string)
228
+ return self._from_context(ctx, key_string, GrammarStats())
210
229
 
211
230
  def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
212
231
  try:
@@ -214,7 +233,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
214
233
  except RuntimeError as e:
215
234
  logging.error(f"Hit invalid regex: {key_string=}, {e=}")
216
235
  return INVALID_GRAMMAR_OBJ
217
- return self._from_context(ctx, key_string)
236
+ return self._from_context(ctx, key_string, GrammarStats())
218
237
 
219
238
  def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
220
239
  try:
@@ -233,7 +252,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
233
252
  except (RuntimeError, json.decoder.JSONDecodeError) as e:
234
253
  logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
235
254
  return INVALID_GRAMMAR_OBJ
236
- return self._from_context(ctx, key_string)
255
+ return self._from_context(ctx, key_string, GrammarStats())
237
256
 
238
257
  def reset(self):
239
258
  self.grammar_compiler.clear_cache()
sglang/srt/custom_op.py CHANGED
@@ -1,12 +1,20 @@
1
1
  from torch import nn
2
2
 
3
- from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
3
+ from sglang.srt.utils import (
4
+ cpu_has_amx_support,
5
+ is_cpu,
6
+ is_cuda,
7
+ is_hip,
8
+ is_npu,
9
+ is_xpu,
10
+ )
4
11
 
5
12
  _is_cuda = is_cuda()
6
13
  _is_hip = is_hip()
7
14
  _is_cpu = is_cpu()
8
15
  _is_cpu_amx_available = cpu_has_amx_support()
9
16
  _is_npu = is_npu()
17
+ _is_xpu = is_xpu()
10
18
 
11
19
 
12
20
  class CustomOp(nn.Module):
@@ -88,5 +96,7 @@ class CustomOp(nn.Module):
88
96
  return self.forward_cpu
89
97
  elif _is_npu:
90
98
  return self.forward_npu
99
+ elif _is_xpu:
100
+ return self.forward_xpu
91
101
  else:
92
102
  return self.forward_native
@@ -1,11 +1,11 @@
1
1
  import argparse
2
2
  import functools
3
- import re
4
3
  from pathlib import Path
5
4
 
6
5
  import polars as pl
7
6
  import torch
8
7
 
8
+ from sglang.srt.debug_utils.dump_loader import find_row, read_meta
9
9
  from sglang.srt.debug_utils.dumper import get_truncated_value
10
10
 
11
11
 
@@ -26,66 +26,77 @@ def main(args):
26
26
  print("df_baseline", df_baseline)
27
27
 
28
28
  for row in df_target.iter_rows(named=True):
29
- rows_baseline = df_baseline.filter(
30
- (
31
- pl.col("forward_pass_id")
32
- == row["forward_pass_id"] - args.start_id + args.baseline_start_id
33
- )
34
- & functools.reduce(
35
- lambda a, b: a & b,
36
- [
37
- pl.col(col) == row[col]
38
- for col in row.keys()
39
- if col not in ["forward_pass_id", "dump_index", "filename"]
40
- ],
41
- )
29
+ path_target = Path(args.target_path) / row["filename"]
30
+
31
+ row_baseline = find_row(
32
+ df_baseline,
33
+ conditions=dict(
34
+ forward_pass_id=row["forward_pass_id"]
35
+ - args.start_id
36
+ + args.baseline_start_id,
37
+ **{
38
+ k: v
39
+ for k, v in row.items()
40
+ if k not in ["forward_pass_id", "dump_index", "filename"]
41
+ },
42
+ ),
42
43
  )
43
- assert len(rows_baseline) == 1, f"{rows_baseline=}"
44
- row_baseline = rows_baseline.to_dicts()[0]
44
+
45
+ if row_baseline is None:
46
+ print(f"Skip: target={str(path_target)} since no baseline")
47
+ x_target = _load_object(path_target)
48
+ if x_target is not None:
49
+ print(f"x_target(sample)={get_truncated_value(x_target)}")
50
+ continue
45
51
 
46
52
  path_baseline = Path(args.baseline_path) / row_baseline["filename"]
47
- path_target = Path(args.target_path) / row["filename"]
48
53
  print(f"Check: target={str(path_target)} baseline={str(path_baseline)}")
49
- check_tensor_pair(path_baseline=path_baseline, path_target=path_target)
54
+ check_tensor_pair(
55
+ path_baseline=path_baseline, path_target=path_target, name=row["name"]
56
+ )
50
57
  print()
51
58
 
52
59
 
53
- def read_meta(directory):
54
- directory = Path(directory)
55
- assert directory.is_dir(), f"{directory=} should be a directory"
56
-
57
- rows = []
58
- for p in directory.glob("*.pt"):
59
- full_kwargs = {}
60
- for kv in p.stem.split("___"):
61
- k, v = kv.split("=")
62
- full_kwargs[k] = v
63
- rows.append(
64
- {
65
- "filename": str(p.name),
66
- **full_kwargs,
67
- }
68
- )
60
+ def check_tensor_pair(path_baseline, path_target, name=""):
61
+ x_baseline = _load_object(path_baseline)
62
+ x_target = _load_object(path_target)
69
63
 
70
- df = pl.DataFrame(rows)
71
- df = df.with_columns(
72
- pl.col("forward_pass_id").cast(int),
73
- pl.col("rank").cast(int),
64
+ print(
65
+ f"Raw "
66
+ f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
67
+ f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
74
68
  )
75
- return df
76
-
77
69
 
78
- def check_tensor_pair(path_baseline, path_target):
79
- x_baseline = torch.load(path_baseline, weights_only=True)
80
- x_target = torch.load(path_target, weights_only=True)
70
+ x_baseline, x_target = _comparison_preprocessor(x_baseline, x_target, name=name)
71
+ x_baseline = _try_unify_shape(x_baseline, target_shape=x_target.shape)
81
72
 
82
73
  print(
74
+ f"After preprocessor "
83
75
  f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
84
76
  f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
85
77
  )
86
78
 
79
+ x_target = x_target.float()
80
+ x_baseline = x_baseline.float()
81
+
82
+ for name, fn in (
83
+ ("mean", torch.mean),
84
+ ("std", torch.std),
85
+ ("min", torch.min),
86
+ ("max", torch.max),
87
+ ("p1", functools.partial(torch.quantile, q=0.01)),
88
+ ("p5", functools.partial(torch.quantile, q=0.05)),
89
+ ("p95", functools.partial(torch.quantile, q=0.95)),
90
+ ("p99", functools.partial(torch.quantile, q=0.99)),
91
+ ):
92
+ value_baseline = fn(x_baseline).item()
93
+ value_target = fn(x_target).item()
94
+ print(
95
+ f"[{name}] {value_baseline :.4f} vs {value_target:.4f} (diff: {value_target - value_baseline:.4f})"
96
+ )
97
+
87
98
  if x_baseline.shape != x_target.shape:
88
- print(f" Shape mismatch")
99
+ print(f"⚠️ Shape mismatch")
89
100
  return
90
101
 
91
102
  raw_abs_diff = (x_target - x_baseline).abs()
@@ -112,6 +123,19 @@ def check_tensor_pair(path_baseline, path_target):
112
123
  print(f"x_target(sample)={get_truncated_value(x_target)}")
113
124
 
114
125
 
126
def _try_unify_shape(x: torch.Tensor, target_shape):
    """Drop leading size-1 dimensions of ``x`` so that its shape equals ``target_shape``.

    The reshape happens only when ``x`` has more dims than ``target_shape``, every
    extra leading dim is 1, and the remaining dims equal ``target_shape``.
    Otherwise ``x`` is returned unchanged and the caller handles any mismatch.

    Args:
        x: tensor to adjust.
        target_shape: desired shape (``torch.Size``, tuple or list).

    Returns:
        A view of ``x`` with the leading unit dims removed, or ``x`` itself.
    """
    x_shape = x.shape
    num_dim_to_remove = len(x_shape) - len(target_shape)
    # Nothing to drop (already matching, or x has fewer dims): return early so we
    # do not log a misleading no-op "Unify shape: S -> S" line.
    if num_dim_to_remove <= 0:
        return x

    leading = x_shape[:num_dim_to_remove]
    trailing = x_shape[num_dim_to_remove:]
    # Compare as tuples so list-valued target shapes are accepted too
    # (a torch.Size never compares equal to a list).
    if tuple(trailing) == tuple(target_shape) and all(v == 1 for v in leading):
        # Only size-1 leading dims are removed, so reshape is a pure view change.
        out = x.reshape(trailing)
        print(f"Unify shape: {x_shape} -> {out.shape} (to match {target_shape})")
        return out

    return x
137
+
138
+
115
139
  # Copied from DeepGEMM
116
140
  def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
117
141
  x, y = x.double(), y.double()
@@ -120,6 +144,19 @@ def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
120
144
  return 1 - sim
121
145
 
122
146
 
147
def _comparison_preprocessor(x_baseline, x_target, name):
    """Hook for ad-hoc preprocessing of a tensor pair before it is compared.

    ``name`` is the dumped tensor's name, so custom logic can special-case
    individual tensors. The default implementation is the identity.
    """
    preprocessed = (x_baseline, x_target)
    return preprocessed
150
+
151
+
152
def _load_object(path):
    """Load a dumped object, returning it only if it is a tensor.

    The tensor is always deserialized onto the CPU first, so dumps written on a GPU
    can be read on any machine. It is then moved to the current CUDA device when
    CUDA is available; otherwise it stays on the CPU instead of crashing.

    Returns:
        The loaded tensor, or ``None`` (after printing a note) for non-tensor dumps.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects (dumps may hold
    # non-tensor values); only point this tool at trusted dump directories.
    x = torch.load(path, weights_only=False, map_location="cpu")
    if not isinstance(x, torch.Tensor):
        print(f"Skip load {path} since {type(x)=} is not a Tensor")
        return None
    return x.cuda() if torch.cuda.is_available() else x
158
+
159
+
123
160
  if __name__ == "__main__":
124
161
  parser = argparse.ArgumentParser()
125
162
  parser.add_argument("--baseline-path", type=str)
@@ -0,0 +1,97 @@
1
+ import functools
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Any, Dict
5
+
6
+ import polars as pl
7
+ import torch
8
+
9
+
10
class DumpLoader:
    """Read back tensors/objects previously written by the dumper.

    The loader is active only when ``SGLANG_DUMP_LOADER_DIR`` is set. In that case
    the ``*.pt`` files in that directory are indexed once, at construction time.
    """

    def __init__(self):
        env_dir = os.environ.get("SGLANG_DUMP_LOADER_DIR")
        self._enable = env_dir is not None
        if not self._enable:
            return
        self._directory = Path(env_dir)
        self._df = read_meta(env_dir)

    @property
    def enable(self):
        """True when a dump directory was configured via the environment."""
        return self._enable

    def load(self, name, **kwargs):
        """Load the object dumped as ``name`` at the dumper's current forward pass.

        Extra keyword arguments are matched against further metadata columns
        parsed from the dump filenames (for example ``rank``). Files are read with
        ``torch.load(weights_only=False)``, so only use trusted dump directories.

        Raises:
            AssertionError: if the loader is disabled or no matching dump exists.
        """
        assert self._enable, "Please call DumpLoader.load only when it is enabled"

        # Deferred import of the global dumper, whose counter identifies the
        # forward pass being replayed.
        from sglang.srt.debug_utils.dumper import dumper

        row = find_row(
            self._df,
            conditions=dict(
                name=name, forward_pass_id=dumper._forward_pass_id, **kwargs
            ),
        )
        assert (
            row is not None
        ), f"DumpLoader cannot find row given query {name=} {kwargs=} {self._directory=}"

        path = self._directory / row["filename"]
        output = torch.load(path, weights_only=False)
        print(
            f"[DumpLoader] load from {path=} (query: {name=} {kwargs=}, output: {type(output)})"
        )
        return output
42
+
43
+
44
def read_meta(directory):
    """Index the ``*.pt`` dump files in ``directory`` into a polars DataFrame.

    Each filename stem is a ``___``-separated list of ``key=value`` pairs. Every key
    becomes a (string) column next to ``filename``. The ``forward_pass_id``,
    ``rank`` and ``dump_index`` columns are then cast to integers.

    Raises:
        AssertionError: if ``directory`` is not a directory.
        ValueError: if a stem segment has no ``=``.
    """
    directory = Path(directory)
    assert directory.is_dir(), f"{directory=} should be a directory"

    rows = []
    # Sorted so the row order is deterministic rather than filesystem-dependent.
    for p in sorted(directory.glob("*.pt")):
        # maxsplit=1: a value that itself contains "=" must not break parsing.
        full_kwargs = dict(kv.split("=", 1) for kv in p.stem.split("___"))
        rows.append(
            {
                "filename": str(p.name),
                **full_kwargs,
            }
        )

    df = pl.DataFrame(rows)
    df = df.with_columns(
        pl.col("forward_pass_id").cast(int),
        pl.col("rank").cast(int),
        pl.col("dump_index").cast(int),
    )
    return df
68
+
69
+
70
def find_row(df, conditions: Dict[str, Any]):
    """Return the single row of ``df`` where every ``column == value`` condition holds.

    Each condition value is first coerced to its column's polars dtype, so e.g. a
    string value can match an integer column.

    Returns:
        The matching row as a dict, or ``None`` when nothing matches. This includes
        the case where a condition names a column ``df`` does not have. An empty
        ``conditions`` mapping matches every row.

    Raises:
        AssertionError: if more than one row matches (ambiguous query).
    """
    schema = df.schema
    if any(col not in schema for col in conditions):
        # A nonexistent column can never match; report "not found", not KeyError.
        return None

    predicate = functools.reduce(
        lambda a, b: a & b,
        [
            pl.col(col) == _cast_to_polars_dtype(value, schema[col])
            for col, value in conditions.items()
        ],
        pl.lit(True),  # identity for `&`; keeps reduce valid for empty conditions
    )
    df_sub = df.filter(predicate)
    assert len(df_sub) <= 1, f"find_row: {len(df_sub)} rows match {conditions=}"
    return df_sub.to_dicts()[0] if len(df_sub) > 0 else None
82
+
83
+
84
def _cast_to_polars_dtype(value, target_dtype):
    """Coerce ``value`` to the Python type matching the polars dtype ``target_dtype``.

    Handles 32/64-bit signed/unsigned ints, 32/64-bit floats, Boolean and String.
    Any other dtype returns ``value`` unchanged.
    NOTE(review): ``bool()`` of any non-empty string is True, so a string like
    "False" maps to True for Boolean columns — pass real bools.
    """
    if target_dtype in (pl.Int64, pl.Int32, pl.UInt64, pl.UInt32):
        converter = int
    elif target_dtype in (pl.Float64, pl.Float32):
        converter = float
    elif target_dtype == pl.Boolean:
        converter = bool
    elif target_dtype == pl.String:
        converter = str
    else:
        return value
    return converter(value)


# Module-level loader; reads SGLANG_DUMP_LOADER_DIR once, when this module is imported.
dump_loader = DumpLoader()
@@ -53,7 +53,7 @@ class _Dumper:
53
53
  if self._partial_name is None:
54
54
  self._partial_name = _get_partial_name()
55
55
 
56
- rank = dist.get_rank()
56
+ rank = _get_rank()
57
57
  full_kwargs = dict(
58
58
  forward_pass_id=self._forward_pass_id,
59
59
  rank=rank,
@@ -80,12 +80,20 @@ class _Dumper:
80
80
 
81
81
 
82
82
  def _get_partial_name():
83
- rank = dist.get_rank()
83
+ rank = _get_rank()
84
84
  object_list = [str(time.time()) if rank == 0 else None]
85
- dist.broadcast_object_list(object_list, device="cuda")
85
+ if dist.is_initialized():
86
+ dist.broadcast_object_list(object_list, device="cuda")
86
87
  return object_list[0]
87
88
 
88
89
 
90
def _get_rank():
    """Return this process's rank in the default torch.distributed group.

    Falls back to 0 when torch.distributed has not been initialised
    (e.g. single-process debugging).
    """
    return dist.get_rank() if dist.is_initialized() else 0
95
+
96
+
89
97
  def get_truncated_value(value):
90
98
  if value is None:
91
99
  return None