sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
@@ -14,8 +14,9 @@
|
|
14
14
|
"""The baseclass of a backend for grammar-guided constrained decoding."""
|
15
15
|
|
16
16
|
import logging
|
17
|
+
import time
|
17
18
|
from concurrent.futures import ThreadPoolExecutor
|
18
|
-
from dataclasses import dataclass
|
19
|
+
from dataclasses import dataclass, field
|
19
20
|
from threading import Event
|
20
21
|
from typing import Dict, List, Optional, Tuple
|
21
22
|
|
@@ -26,10 +27,22 @@ from sglang.srt.server_args import ServerArgs
|
|
26
27
|
logger = logging.getLogger(__name__)
|
27
28
|
|
28
29
|
|
30
|
+
@dataclass
|
31
|
+
class GrammarStats:
|
32
|
+
compilation_time: Optional[float] = None
|
33
|
+
schema_count: Optional[int] = None
|
34
|
+
ebnf_size: Optional[int] = None
|
35
|
+
is_cache_hit: bool = False
|
36
|
+
is_grammar_aborted: bool = False
|
37
|
+
tree_traversal_time: List[float] = field(default_factory=list)
|
38
|
+
|
39
|
+
|
29
40
|
class BaseGrammarObject:
|
30
41
|
|
31
42
|
def __init__(self):
|
32
43
|
self._finished = False
|
44
|
+
self.grammar_stats = None
|
45
|
+
self.current_token = None
|
33
46
|
|
34
47
|
def accept_token(self, token: int) -> None:
|
35
48
|
"""
|
@@ -137,19 +150,26 @@ class BaseGrammarBackend:
|
|
137
150
|
return self._not_supported("structural_tag", key_string)
|
138
151
|
|
139
152
|
def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
|
153
|
+
s = time.perf_counter()
|
140
154
|
key_type, key_string = key
|
141
155
|
if key_type == "json":
|
142
|
-
|
156
|
+
grammar = self.dispatch_json(key_string)
|
143
157
|
elif key_type == "regex":
|
144
|
-
|
158
|
+
grammar = self.dispatch_regex(key_string)
|
145
159
|
elif key_type == "ebnf":
|
146
|
-
|
160
|
+
grammar = self.dispatch_ebnf(key_string)
|
147
161
|
elif key_type == "structural_tag":
|
148
|
-
|
162
|
+
grammar = self.dispatch_structural_tag(key_string)
|
149
163
|
elif key_type == "structural_pattern":
|
150
|
-
|
164
|
+
grammar = self.dispatch_structural_pattern(key_string)
|
165
|
+
elif key_type == "structural_pattern_v2":
|
166
|
+
grammar = self.dispatch_structural_pattern_v2(key_string)
|
151
167
|
else:
|
152
|
-
|
168
|
+
grammar = self.dispatch_fallback(key_type, key_string)
|
169
|
+
|
170
|
+
if grammar is not None and grammar.grammar_stats is not None:
|
171
|
+
grammar.grammar_stats.compilation_time = time.perf_counter() - s
|
172
|
+
return grammar
|
153
173
|
|
154
174
|
def get_cached_or_future_value(
|
155
175
|
self, key: Tuple[str, str]
|
@@ -167,20 +187,36 @@ class BaseGrammarBackend:
|
|
167
187
|
self.cache.clear()
|
168
188
|
|
169
189
|
|
190
|
+
GRAMMAR_BACKEND_REGISTRY = {}
|
191
|
+
|
192
|
+
|
193
|
+
def register_grammar_backend(name, init_func):
|
194
|
+
GRAMMAR_BACKEND_REGISTRY[name] = init_func
|
195
|
+
|
196
|
+
|
170
197
|
def create_grammar_backend(
|
171
198
|
server_args: ServerArgs,
|
172
199
|
tokenizer,
|
173
200
|
vocab_size: int,
|
174
201
|
eos_token_ids: Optional[set] = None,
|
175
202
|
) -> Optional[BaseGrammarBackend]:
|
176
|
-
|
203
|
+
name = server_args.grammar_backend
|
204
|
+
|
205
|
+
# Custom grammar backend has the highest priority
|
206
|
+
if name in GRAMMAR_BACKEND_REGISTRY:
|
207
|
+
return GRAMMAR_BACKEND_REGISTRY[name](
|
208
|
+
server_args, tokenizer, vocab_size, eos_token_ids
|
209
|
+
)
|
210
|
+
|
211
|
+
# Default grammar backends
|
212
|
+
if name == "outlines":
|
177
213
|
from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend
|
178
214
|
|
179
215
|
grammar_backend = OutlinesGrammarBackend(
|
180
216
|
tokenizer,
|
181
217
|
whitespace_pattern=server_args.constrained_json_whitespace_pattern,
|
182
218
|
)
|
183
|
-
elif
|
219
|
+
elif name == "xgrammar":
|
184
220
|
from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend
|
185
221
|
|
186
222
|
# Convert Set[int] to List[int] if needed
|
@@ -189,17 +225,17 @@ def create_grammar_backend(
|
|
189
225
|
grammar_backend = XGrammarGrammarBackend(
|
190
226
|
tokenizer, vocab_size=vocab_size, model_eos_token_ids=eos_list
|
191
227
|
)
|
192
|
-
elif
|
228
|
+
elif name == "llguidance":
|
193
229
|
from sglang.srt.constrained.llguidance_backend import GuidanceBackend
|
194
230
|
|
195
231
|
grammar_backend = GuidanceBackend(
|
196
232
|
tokenizer=tokenizer,
|
197
233
|
whitespace_pattern=server_args.constrained_json_whitespace_pattern,
|
198
234
|
)
|
199
|
-
elif
|
235
|
+
elif name == "none":
|
200
236
|
return None
|
201
237
|
else:
|
202
|
-
raise ValueError(f"Invalid grammar backend: {
|
238
|
+
raise ValueError(f"Invalid grammar backend: {name}")
|
203
239
|
|
204
240
|
if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
|
205
241
|
from sglang.srt.constrained.reasoner_grammar_backend import (
|
@@ -13,6 +13,7 @@
|
|
13
13
|
# ==============================================================================
|
14
14
|
"""Constrained decoding with xgrammar backend."""
|
15
15
|
|
16
|
+
import dataclasses
|
16
17
|
import json
|
17
18
|
import logging
|
18
19
|
from typing import List, Optional, Tuple, Union
|
@@ -31,6 +32,7 @@ from sglang.srt.constrained.base_grammar_backend import (
|
|
31
32
|
INVALID_GRAMMAR_OBJ,
|
32
33
|
BaseGrammarBackend,
|
33
34
|
BaseGrammarObject,
|
35
|
+
GrammarStats,
|
34
36
|
)
|
35
37
|
from sglang.srt.utils import is_hip
|
36
38
|
|
@@ -41,9 +43,9 @@ else:
|
|
41
43
|
from sglang.srt.constrained.triton_ops.bitmask_ops import (
|
42
44
|
apply_token_bitmask_inplace_triton,
|
43
45
|
)
|
44
|
-
logger = logging.getLogger(__name__)
|
45
46
|
|
46
47
|
|
48
|
+
logger = logging.getLogger(__name__)
|
47
49
|
MAX_ROLLBACK_TOKENS = 200
|
48
50
|
|
49
51
|
|
@@ -56,17 +58,20 @@ class XGrammarGrammar(BaseGrammarObject):
|
|
56
58
|
ctx: CompiledGrammar,
|
57
59
|
override_stop_tokens: Optional[Union[List[int], int]],
|
58
60
|
key_string: Optional[str] = None, # TODO (sk): for debugging, remove later
|
61
|
+
grammar_stats: Optional[GrammarStats] = GrammarStats(),
|
59
62
|
) -> None:
|
63
|
+
super().__init__()
|
60
64
|
self.matcher = matcher
|
61
65
|
self.vocab_size = vocab_size
|
62
66
|
self.ctx = ctx
|
63
67
|
self.override_stop_tokens = override_stop_tokens
|
64
|
-
self.finished = False
|
65
68
|
self.accepted_tokens = []
|
66
69
|
self.key_string = key_string
|
70
|
+
self.grammar_stats = grammar_stats
|
67
71
|
|
68
72
|
def accept_token(self, token: int):
|
69
73
|
if not self.is_terminated():
|
74
|
+
self.current_token = token
|
70
75
|
accepted = self.matcher.accept_token(token)
|
71
76
|
if not accepted:
|
72
77
|
# log for debugging
|
@@ -120,6 +125,9 @@ class XGrammarGrammar(BaseGrammarObject):
|
|
120
125
|
self.ctx,
|
121
126
|
self.override_stop_tokens,
|
122
127
|
self.key_string,
|
128
|
+
dataclasses.replace(
|
129
|
+
self.grammar_stats, is_cache_hit=True, tree_traversal_time=[]
|
130
|
+
),
|
123
131
|
)
|
124
132
|
|
125
133
|
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
|
@@ -150,7 +158,7 @@ class XGrammarGrammar(BaseGrammarObject):
|
|
150
158
|
assert self.matcher.accept_token(new_output_ids[i])
|
151
159
|
|
152
160
|
def __repr__(self):
|
153
|
-
return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=})"
|
161
|
+
return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=}, {self.current_token=})"
|
154
162
|
|
155
163
|
|
156
164
|
class XGrammarGrammarBackend(BaseGrammarBackend):
|
@@ -165,6 +173,10 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|
165
173
|
if hasattr(tokenizer, "init_xgrammar"):
|
166
174
|
# For special tokenizer
|
167
175
|
tokenizer_info, override_stop_tokens = tokenizer.init_xgrammar()
|
176
|
+
|
177
|
+
if tokenizer_info is None:
|
178
|
+
# Not supported tokenizer
|
179
|
+
return
|
168
180
|
else:
|
169
181
|
# Create TokenizerInfo with model's EOS tokens as the authoritative stop tokens
|
170
182
|
# This ensures consistency between what the model considers EOS and what XGrammar uses
|
@@ -177,14 +189,21 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|
177
189
|
self.vocab_size = vocab_size
|
178
190
|
self.override_stop_tokens = override_stop_tokens
|
179
191
|
|
180
|
-
def _from_context(
|
192
|
+
def _from_context(
|
193
|
+
self, ctx: CompiledGrammar, key_string: str, grammar_stats: GrammarStats
|
194
|
+
) -> XGrammarGrammar:
|
181
195
|
matcher = GrammarMatcher(
|
182
196
|
ctx,
|
183
197
|
max_rollback_tokens=MAX_ROLLBACK_TOKENS,
|
184
198
|
override_stop_tokens=self.override_stop_tokens,
|
185
199
|
)
|
186
200
|
return XGrammarGrammar(
|
187
|
-
matcher,
|
201
|
+
matcher,
|
202
|
+
self.vocab_size,
|
203
|
+
ctx,
|
204
|
+
self.override_stop_tokens,
|
205
|
+
key_string,
|
206
|
+
grammar_stats,
|
188
207
|
)
|
189
208
|
|
190
209
|
def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
|
@@ -198,7 +217,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|
198
217
|
except (RuntimeError, json.decoder.JSONDecodeError) as e:
|
199
218
|
logging.error(f"Hit invalid json_schema: {key_string=}, {e=}")
|
200
219
|
return INVALID_GRAMMAR_OBJ
|
201
|
-
return self._from_context(ctx, key_string)
|
220
|
+
return self._from_context(ctx, key_string, GrammarStats())
|
202
221
|
|
203
222
|
def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
|
204
223
|
try:
|
@@ -206,7 +225,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|
206
225
|
except RuntimeError as e:
|
207
226
|
logging.error(f"Hit invalid ebnf: {key_string=}, {e=}")
|
208
227
|
return INVALID_GRAMMAR_OBJ
|
209
|
-
return self._from_context(ctx, key_string)
|
228
|
+
return self._from_context(ctx, key_string, GrammarStats())
|
210
229
|
|
211
230
|
def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
|
212
231
|
try:
|
@@ -214,7 +233,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|
214
233
|
except RuntimeError as e:
|
215
234
|
logging.error(f"Hit invalid regex: {key_string=}, {e=}")
|
216
235
|
return INVALID_GRAMMAR_OBJ
|
217
|
-
return self._from_context(ctx, key_string)
|
236
|
+
return self._from_context(ctx, key_string, GrammarStats())
|
218
237
|
|
219
238
|
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
|
220
239
|
try:
|
@@ -233,7 +252,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|
233
252
|
except (RuntimeError, json.decoder.JSONDecodeError) as e:
|
234
253
|
logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
|
235
254
|
return INVALID_GRAMMAR_OBJ
|
236
|
-
return self._from_context(ctx, key_string)
|
255
|
+
return self._from_context(ctx, key_string, GrammarStats())
|
237
256
|
|
238
257
|
def reset(self):
|
239
258
|
self.grammar_compiler.clear_cache()
|
sglang/srt/custom_op.py
CHANGED
@@ -1,12 +1,20 @@
|
|
1
1
|
from torch import nn
|
2
2
|
|
3
|
-
from sglang.srt.utils import
|
3
|
+
from sglang.srt.utils import (
|
4
|
+
cpu_has_amx_support,
|
5
|
+
is_cpu,
|
6
|
+
is_cuda,
|
7
|
+
is_hip,
|
8
|
+
is_npu,
|
9
|
+
is_xpu,
|
10
|
+
)
|
4
11
|
|
5
12
|
_is_cuda = is_cuda()
|
6
13
|
_is_hip = is_hip()
|
7
14
|
_is_cpu = is_cpu()
|
8
15
|
_is_cpu_amx_available = cpu_has_amx_support()
|
9
16
|
_is_npu = is_npu()
|
17
|
+
_is_xpu = is_xpu()
|
10
18
|
|
11
19
|
|
12
20
|
class CustomOp(nn.Module):
|
@@ -88,5 +96,7 @@ class CustomOp(nn.Module):
|
|
88
96
|
return self.forward_cpu
|
89
97
|
elif _is_npu:
|
90
98
|
return self.forward_npu
|
99
|
+
elif _is_xpu:
|
100
|
+
return self.forward_xpu
|
91
101
|
else:
|
92
102
|
return self.forward_native
|
@@ -1,11 +1,11 @@
|
|
1
1
|
import argparse
|
2
2
|
import functools
|
3
|
-
import re
|
4
3
|
from pathlib import Path
|
5
4
|
|
6
5
|
import polars as pl
|
7
6
|
import torch
|
8
7
|
|
8
|
+
from sglang.srt.debug_utils.dump_loader import find_row, read_meta
|
9
9
|
from sglang.srt.debug_utils.dumper import get_truncated_value
|
10
10
|
|
11
11
|
|
@@ -26,66 +26,77 @@ def main(args):
|
|
26
26
|
print("df_baseline", df_baseline)
|
27
27
|
|
28
28
|
for row in df_target.iter_rows(named=True):
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
29
|
+
path_target = Path(args.target_path) / row["filename"]
|
30
|
+
|
31
|
+
row_baseline = find_row(
|
32
|
+
df_baseline,
|
33
|
+
conditions=dict(
|
34
|
+
forward_pass_id=row["forward_pass_id"]
|
35
|
+
- args.start_id
|
36
|
+
+ args.baseline_start_id,
|
37
|
+
**{
|
38
|
+
k: v
|
39
|
+
for k, v in row.items()
|
40
|
+
if k not in ["forward_pass_id", "dump_index", "filename"]
|
41
|
+
},
|
42
|
+
),
|
42
43
|
)
|
43
|
-
|
44
|
-
row_baseline
|
44
|
+
|
45
|
+
if row_baseline is None:
|
46
|
+
print(f"Skip: target={str(path_target)} since no baseline")
|
47
|
+
x_target = _load_object(path_target)
|
48
|
+
if x_target is not None:
|
49
|
+
print(f"x_target(sample)={get_truncated_value(x_target)}")
|
50
|
+
continue
|
45
51
|
|
46
52
|
path_baseline = Path(args.baseline_path) / row_baseline["filename"]
|
47
|
-
path_target = Path(args.target_path) / row["filename"]
|
48
53
|
print(f"Check: target={str(path_target)} baseline={str(path_baseline)}")
|
49
|
-
check_tensor_pair(
|
54
|
+
check_tensor_pair(
|
55
|
+
path_baseline=path_baseline, path_target=path_target, name=row["name"]
|
56
|
+
)
|
50
57
|
print()
|
51
58
|
|
52
59
|
|
53
|
-
def
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
rows = []
|
58
|
-
for p in directory.glob("*.pt"):
|
59
|
-
full_kwargs = {}
|
60
|
-
for kv in p.stem.split("___"):
|
61
|
-
k, v = kv.split("=")
|
62
|
-
full_kwargs[k] = v
|
63
|
-
rows.append(
|
64
|
-
{
|
65
|
-
"filename": str(p.name),
|
66
|
-
**full_kwargs,
|
67
|
-
}
|
68
|
-
)
|
60
|
+
def check_tensor_pair(path_baseline, path_target, name=""):
|
61
|
+
x_baseline = _load_object(path_baseline)
|
62
|
+
x_target = _load_object(path_target)
|
69
63
|
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
64
|
+
print(
|
65
|
+
f"Raw "
|
66
|
+
f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
|
67
|
+
f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
|
74
68
|
)
|
75
|
-
return df
|
76
|
-
|
77
69
|
|
78
|
-
|
79
|
-
x_baseline =
|
80
|
-
x_target = torch.load(path_target, weights_only=True)
|
70
|
+
x_baseline, x_target = _comparison_preprocessor(x_baseline, x_target, name=name)
|
71
|
+
x_baseline = _try_unify_shape(x_baseline, target_shape=x_target.shape)
|
81
72
|
|
82
73
|
print(
|
74
|
+
f"After preprocessor "
|
83
75
|
f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
|
84
76
|
f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
|
85
77
|
)
|
86
78
|
|
79
|
+
x_target = x_target.float()
|
80
|
+
x_baseline = x_baseline.float()
|
81
|
+
|
82
|
+
for name, fn in (
|
83
|
+
("mean", torch.mean),
|
84
|
+
("std", torch.std),
|
85
|
+
("min", torch.min),
|
86
|
+
("max", torch.max),
|
87
|
+
("p1", functools.partial(torch.quantile, q=0.01)),
|
88
|
+
("p5", functools.partial(torch.quantile, q=0.05)),
|
89
|
+
("p95", functools.partial(torch.quantile, q=0.95)),
|
90
|
+
("p99", functools.partial(torch.quantile, q=0.99)),
|
91
|
+
):
|
92
|
+
value_baseline = fn(x_baseline).item()
|
93
|
+
value_target = fn(x_target).item()
|
94
|
+
print(
|
95
|
+
f"[{name}] {value_baseline :.4f} vs {value_target:.4f} (diff: {value_target - value_baseline:.4f})"
|
96
|
+
)
|
97
|
+
|
87
98
|
if x_baseline.shape != x_target.shape:
|
88
|
-
print(f"
|
99
|
+
print(f"⚠️ Shape mismatch")
|
89
100
|
return
|
90
101
|
|
91
102
|
raw_abs_diff = (x_target - x_baseline).abs()
|
@@ -112,6 +123,19 @@ def check_tensor_pair(path_baseline, path_target):
|
|
112
123
|
print(f"x_target(sample)={get_truncated_value(x_target)}")
|
113
124
|
|
114
125
|
|
126
|
+
def _try_unify_shape(x: torch.Tensor, target_shape):
|
127
|
+
x_shape = x.shape
|
128
|
+
num_dim_to_remove = len(x_shape) - len(target_shape)
|
129
|
+
if (x_shape[num_dim_to_remove:] == target_shape) and all(
|
130
|
+
val == 1 for val in x_shape[:num_dim_to_remove]
|
131
|
+
):
|
132
|
+
out = functools.reduce(lambda a, _: a.squeeze(0), range(num_dim_to_remove), x)
|
133
|
+
print(f"Unify shape: {x_shape} -> {out.shape} (to match {target_shape})")
|
134
|
+
return out
|
135
|
+
|
136
|
+
return x
|
137
|
+
|
138
|
+
|
115
139
|
# Copied from DeepGEMM
|
116
140
|
def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
|
117
141
|
x, y = x.double(), y.double()
|
@@ -120,6 +144,19 @@ def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
|
|
120
144
|
return 1 - sim
|
121
145
|
|
122
146
|
|
147
|
+
def _comparison_preprocessor(x_baseline, x_target, name):
|
148
|
+
# can insert arbitrary adhoc postprocessing logic here
|
149
|
+
return x_baseline, x_target
|
150
|
+
|
151
|
+
|
152
|
+
def _load_object(path):
|
153
|
+
x = torch.load(path, weights_only=False)
|
154
|
+
if not isinstance(x, torch.Tensor):
|
155
|
+
print(f"Skip load {path} since {type(x)=} is not a Tensor")
|
156
|
+
return None
|
157
|
+
return x.cuda()
|
158
|
+
|
159
|
+
|
123
160
|
if __name__ == "__main__":
|
124
161
|
parser = argparse.ArgumentParser()
|
125
162
|
parser.add_argument("--baseline-path", type=str)
|
@@ -0,0 +1,97 @@
|
|
1
|
+
import functools
|
2
|
+
import os
|
3
|
+
from pathlib import Path
|
4
|
+
from typing import Any, Dict
|
5
|
+
|
6
|
+
import polars as pl
|
7
|
+
import torch
|
8
|
+
|
9
|
+
|
10
|
+
class DumpLoader:
|
11
|
+
def __init__(self):
|
12
|
+
directory = os.environ.get("SGLANG_DUMP_LOADER_DIR")
|
13
|
+
|
14
|
+
self._enable = directory is not None
|
15
|
+
if self._enable:
|
16
|
+
self._directory = Path(directory)
|
17
|
+
self._df = read_meta(directory)
|
18
|
+
|
19
|
+
@property
|
20
|
+
def enable(self):
|
21
|
+
return self._enable
|
22
|
+
|
23
|
+
def load(self, name, **kwargs):
|
24
|
+
assert self._enable, "Please call DumpLoader.load only when it is enabled"
|
25
|
+
|
26
|
+
from sglang.srt.debug_utils.dumper import dumper
|
27
|
+
|
28
|
+
forward_pass_id = dumper._forward_pass_id
|
29
|
+
conditions = dict(name=name, forward_pass_id=forward_pass_id, **kwargs)
|
30
|
+
row = find_row(self._df, conditions=conditions)
|
31
|
+
assert (
|
32
|
+
row is not None
|
33
|
+
), f"DumpLoader cannot find row given query {name=} {kwargs=} {self._directory=}"
|
34
|
+
|
35
|
+
path = self._directory / row["filename"]
|
36
|
+
output = torch.load(path, weights_only=False)
|
37
|
+
|
38
|
+
print(
|
39
|
+
f"[DumpLoader] load from {path=} (query: {name=} {kwargs=}, output: {type(output)})"
|
40
|
+
)
|
41
|
+
return output
|
42
|
+
|
43
|
+
|
44
|
+
def read_meta(directory):
|
45
|
+
directory = Path(directory)
|
46
|
+
assert directory.is_dir(), f"{directory=} should be a directory"
|
47
|
+
|
48
|
+
rows = []
|
49
|
+
for p in directory.glob("*.pt"):
|
50
|
+
full_kwargs = {}
|
51
|
+
for kv in p.stem.split("___"):
|
52
|
+
k, v = kv.split("=")
|
53
|
+
full_kwargs[k] = v
|
54
|
+
rows.append(
|
55
|
+
{
|
56
|
+
"filename": str(p.name),
|
57
|
+
**full_kwargs,
|
58
|
+
}
|
59
|
+
)
|
60
|
+
|
61
|
+
df = pl.DataFrame(rows)
|
62
|
+
df = df.with_columns(
|
63
|
+
pl.col("forward_pass_id").cast(int),
|
64
|
+
pl.col("rank").cast(int),
|
65
|
+
pl.col("dump_index").cast(int),
|
66
|
+
)
|
67
|
+
return df
|
68
|
+
|
69
|
+
|
70
|
+
def find_row(df, conditions: Dict[str, Any]):
|
71
|
+
df_sub = df.filter(
|
72
|
+
functools.reduce(
|
73
|
+
lambda a, b: a & b,
|
74
|
+
[
|
75
|
+
pl.col(col) == _cast_to_polars_dtype(conditions[col], df.schema[col])
|
76
|
+
for col in conditions.keys()
|
77
|
+
],
|
78
|
+
)
|
79
|
+
)
|
80
|
+
assert len(df_sub) <= 1
|
81
|
+
return df_sub.to_dicts()[0] if len(df_sub) > 0 else None
|
82
|
+
|
83
|
+
|
84
|
+
def _cast_to_polars_dtype(value, target_dtype):
|
85
|
+
if target_dtype in (pl.Int64, pl.Int32, pl.UInt64, pl.UInt32):
|
86
|
+
return int(value)
|
87
|
+
elif target_dtype in (pl.Float64, pl.Float32):
|
88
|
+
return float(value)
|
89
|
+
elif target_dtype == pl.Boolean:
|
90
|
+
return bool(value)
|
91
|
+
elif target_dtype == pl.String:
|
92
|
+
return str(value)
|
93
|
+
else:
|
94
|
+
return value
|
95
|
+
|
96
|
+
|
97
|
+
dump_loader = DumpLoader()
|
sglang/srt/debug_utils/dumper.py
CHANGED
@@ -53,7 +53,7 @@ class _Dumper:
|
|
53
53
|
if self._partial_name is None:
|
54
54
|
self._partial_name = _get_partial_name()
|
55
55
|
|
56
|
-
rank =
|
56
|
+
rank = _get_rank()
|
57
57
|
full_kwargs = dict(
|
58
58
|
forward_pass_id=self._forward_pass_id,
|
59
59
|
rank=rank,
|
@@ -80,12 +80,20 @@ class _Dumper:
|
|
80
80
|
|
81
81
|
|
82
82
|
def _get_partial_name():
|
83
|
-
rank =
|
83
|
+
rank = _get_rank()
|
84
84
|
object_list = [str(time.time()) if rank == 0 else None]
|
85
|
-
dist.
|
85
|
+
if dist.is_initialized():
|
86
|
+
dist.broadcast_object_list(object_list, device="cuda")
|
86
87
|
return object_list[0]
|
87
88
|
|
88
89
|
|
90
|
+
def _get_rank():
|
91
|
+
if dist.is_initialized():
|
92
|
+
return dist.get_rank()
|
93
|
+
else:
|
94
|
+
return 0
|
95
|
+
|
96
|
+
|
89
97
|
def get_truncated_value(value):
|
90
98
|
if value is None:
|
91
99
|
return None
|