sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,82 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
|
3
|
+
import logging
|
4
|
+
from typing import Generator, List, Optional, Tuple
|
5
|
+
from urllib.parse import urlparse
|
6
|
+
|
7
|
+
import torch
|
8
|
+
import torch.distributed as dist
|
9
|
+
|
10
|
+
from sglang.srt.connector import BaseConnector
|
11
|
+
from sglang.srt.utils import init_custom_process_group
|
12
|
+
|
13
|
+
logger = logging.getLogger(__name__)
|
14
|
+
|
15
|
+
|
16
|
+
class RemoteInstanceConnector(BaseConnector):
    """Connector that joins a custom NCCL process group with a remote instance.

    Instead of pulling files, ``build_group`` rendezvouses over TCP at the
    host/port parsed from ``url`` (the group name is prefixed with
    ``send_weights_``). The file-oriented ``BaseConnector`` hooks
    (``pull_files`` / ``weight_iterator``) are no-ops. Only CUDA devices are
    supported.
    """

    def __init__(self, url: str, device: torch.device = "cpu"):
        """
        Args:
            url: Rendezvous URL; its hostname and port are used by build_group.
            device: A ``torch.device`` or a device string such as ``"cuda:0"``.
                Must be a CUDA device; note the default ``"cpu"`` is rejected.

        Raises:
            AssertionError: if ``device`` is not a CUDA device.
        """
        # Normalize first so a plain string (including the "cpu" default) hits
        # the intended assertion below instead of an AttributeError on `.type`.
        device = torch.device(device)
        assert (
            device.type == "cuda"
        ), "RemoteInstanceConnector only supports cuda device."
        super().__init__(url)
        self.url = url
        self.device = device

    def build_group(
        self,
        gpu_id: int = -1,
        tp_rank: int = -1,
        instance_ip: Optional[str] = None,
        group_rank: int = 1,
        world_size: int = 2,
    ) -> Tuple[bool, str]:
        """Create the custom NCCL process group and wait on a barrier in it.

        Args:
            gpu_id: CUDA device index used as this process's ``device_id``;
                must be provided (not -1).
            tp_rank: Embedded in the group name; must be provided (not -1).
            instance_ip: Embedded in the group name.
            group_rank: This process's rank within the group.
            world_size: Total number of ranks in the group.

        Returns:
            ``(True, message)`` once the group is created and the barrier
            completes; ``(False, message)`` if the URL lacks a host or port or
            if group initialization / the barrier raised.

        Raises:
            AssertionError: if the device is not CUDA or gpu_id/tp_rank are unset.
        """
        assert (
            self.device.type == "cuda"
        ), "RemoteInstanceConnector only supports cuda device."
        assert (
            gpu_id != -1 and tp_rank != -1
        ), "gpu_id and tp_rank must be specified for RemoteInstanceConnector. "

        self.device_id = torch.device(self.device.type, gpu_id)

        parsed_url = urlparse(self.url)
        master_address = parsed_url.hostname
        master_port = parsed_url.port
        if master_address is None or master_port is None:
            # Without both parts the rendezvous address would be rendered as
            # e.g. "tcp://None:None"; report it via the normal error return.
            message = (
                "Failed to initialize custom process group: "
                f"url {self.url!r} must include a host and a port."
            )
            logger.error(message)
            return False, message

        group_name = f"send_weights_{instance_ip}_{master_port}_{tp_rank}"
        backend = "nccl"

        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
            f"rank_offset={group_rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        try:
            self._model_update_group = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=group_rank,
                group_name=group_name,
                device_id=self.device_id,
            )
            dist.barrier(group=self._model_update_group)
            return True, "Succeeded to initialize custom process group."
        except Exception as e:
            # Boundary handler: the caller receives (False, message) instead
            # of an exception.
            message = f"Failed to initialize custom process group: {e}."
            logger.error(message)
            return False, message

    # Implemented as a no-op to make BaseConnector interface consistent.
    def pull_files(
        self,
        allow_pattern: Optional[list[str]] = None,
        ignore_pattern: Optional[list[str]] = None,
    ) -> None:
        return

    # Implemented as a no-op to make BaseConnector interface consistent.
    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        # Yield nothing: a bare `return` in a non-generator function would hand
        # callers `None`, and iterating that raises TypeError. An empty
        # generator honours the annotated return type.
        yield from ()
|
@@ -14,8 +14,9 @@
|
|
14
14
|
"""The baseclass of a backend for grammar-guided constrained decoding."""
|
15
15
|
|
16
16
|
import logging
|
17
|
+
import time
|
17
18
|
from concurrent.futures import ThreadPoolExecutor
|
18
|
-
from dataclasses import dataclass
|
19
|
+
from dataclasses import dataclass, field
|
19
20
|
from threading import Event
|
20
21
|
from typing import Dict, List, Optional, Tuple
|
21
22
|
|
@@ -26,10 +27,23 @@ from sglang.srt.server_args import ServerArgs
|
|
26
27
|
logger = logging.getLogger(__name__)
|
27
28
|
|
28
29
|
|
30
|
+
@dataclass
class GrammarStats:
    """Statistics attached to a grammar object as its ``grammar_stats``.

    Every field has a default, so ``GrammarStats()`` is valid; the generated
    ``__init__`` takes the fields in the declaration order below.
    """

    # Seconds spent building the grammar, measured with time.perf_counter()
    # by BaseGrammarBackend._init_value_dispatch when the grammar has stats.
    compilation_time: Optional[float] = None
    # NOTE(review): presumably the number of schemas involved; no writer is
    # visible in this view — confirm.
    schema_count: Optional[int] = None
    # NOTE(review): presumably the size of the EBNF grammar; no writer is
    # visible in this view — confirm.
    ebnf_size: Optional[int] = None
    # The xgrammar backend sets this to True on the stats of a cloned grammar
    # object (built with dataclasses.replace).
    is_cache_hit: bool = False
    # NOTE(review): no writer visible here; presumably marks a grammar whose
    # constrained decoding was aborted — confirm.
    is_grammar_aborted: bool = False
    # Per-instance list (default_factory avoids sharing one list across
    # instances); reset to [] on cloned grammars by the xgrammar backend.
    tree_traversal_time: List[float] = field(default_factory=list)
    # Which dispatch produced the grammar, e.g. "json" from
    # XGrammarGrammarBackend.dispatch_json.
    dispatch_type: Optional[str] = None
|
39
|
+
|
40
|
+
|
29
41
|
class BaseGrammarObject:
|
30
42
|
|
31
43
|
def __init__(self):
|
32
44
|
self._finished = False
|
45
|
+
self.grammar_stats = None
|
46
|
+
self.current_token = None
|
33
47
|
|
34
48
|
def accept_token(self, token: int) -> None:
|
35
49
|
"""
|
@@ -137,19 +151,26 @@ class BaseGrammarBackend:
|
|
137
151
|
return self._not_supported("structural_tag", key_string)
|
138
152
|
|
139
153
|
def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
    """Build the grammar object for ``key`` and record how long it took.

    ``key`` is a ``(key_type, key_string)`` pair. Known key types are routed
    to the matching ``dispatch_*`` method with ``key_string``; any other type
    goes to ``dispatch_fallback(key_type, key_string)``. If the result is not
    ``None`` and carries a ``grammar_stats`` object, the elapsed
    ``time.perf_counter()`` delta is stored in its ``compilation_time``.
    """
    start_time = time.perf_counter()
    key_type, key_string = key

    # Map key types to method *names* and resolve with getattr only for the
    # selected one, so exactly one dispatch_* attribute is looked up per call.
    handler_name = {
        "json": "dispatch_json",
        "regex": "dispatch_regex",
        "ebnf": "dispatch_ebnf",
        "structural_tag": "dispatch_structural_tag",
        "structural_pattern": "dispatch_structural_pattern",
        "structural_pattern_v2": "dispatch_structural_pattern_v2",
    }.get(key_type)

    if handler_name is None:
        grammar = self.dispatch_fallback(key_type, key_string)
    else:
        grammar = getattr(self, handler_name)(key_string)

    if grammar is None or grammar.grammar_stats is None:
        return grammar

    grammar.grammar_stats.compilation_time = time.perf_counter() - start_time
    return grammar
|
153
174
|
|
154
175
|
def get_cached_or_future_value(
|
155
176
|
self, key: Tuple[str, str]
|
@@ -167,20 +188,36 @@ class BaseGrammarBackend:
|
|
167
188
|
self.cache.clear()
|
168
189
|
|
169
190
|
|
191
|
+
# Custom grammar backends registered at runtime, keyed by backend name.
# create_grammar_backend checks this mapping before the built-in backends and
# calls the stored factory as
# factory(server_args, tokenizer, vocab_size, eos_token_ids).
GRAMMAR_BACKEND_REGISTRY: Dict[str, object] = {}


def register_grammar_backend(name, init_func):
    """Register ``init_func`` as the factory for grammar backend ``name``.

    Registering the same name again replaces the previous factory. Returns
    ``None``.
    """
    GRAMMAR_BACKEND_REGISTRY.update({name: init_func})
|
196
|
+
|
197
|
+
|
170
198
|
def create_grammar_backend(
|
171
199
|
server_args: ServerArgs,
|
172
200
|
tokenizer,
|
173
201
|
vocab_size: int,
|
174
202
|
eos_token_ids: Optional[set] = None,
|
175
203
|
) -> Optional[BaseGrammarBackend]:
|
176
|
-
|
204
|
+
name = server_args.grammar_backend
|
205
|
+
|
206
|
+
# Custom grammar backend has the highest priority
|
207
|
+
if name in GRAMMAR_BACKEND_REGISTRY:
|
208
|
+
return GRAMMAR_BACKEND_REGISTRY[name](
|
209
|
+
server_args, tokenizer, vocab_size, eos_token_ids
|
210
|
+
)
|
211
|
+
|
212
|
+
# Default grammar backends
|
213
|
+
if name == "outlines":
|
177
214
|
from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend
|
178
215
|
|
179
216
|
grammar_backend = OutlinesGrammarBackend(
|
180
217
|
tokenizer,
|
181
218
|
whitespace_pattern=server_args.constrained_json_whitespace_pattern,
|
182
219
|
)
|
183
|
-
elif
|
220
|
+
elif name == "xgrammar":
|
184
221
|
from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend
|
185
222
|
|
186
223
|
# Convert Set[int] to List[int] if needed
|
@@ -189,17 +226,17 @@ def create_grammar_backend(
|
|
189
226
|
grammar_backend = XGrammarGrammarBackend(
|
190
227
|
tokenizer, vocab_size=vocab_size, model_eos_token_ids=eos_list
|
191
228
|
)
|
192
|
-
elif
|
229
|
+
elif name == "llguidance":
|
193
230
|
from sglang.srt.constrained.llguidance_backend import GuidanceBackend
|
194
231
|
|
195
232
|
grammar_backend = GuidanceBackend(
|
196
233
|
tokenizer=tokenizer,
|
197
234
|
whitespace_pattern=server_args.constrained_json_whitespace_pattern,
|
198
235
|
)
|
199
|
-
elif
|
236
|
+
elif name == "none":
|
200
237
|
return None
|
201
238
|
else:
|
202
|
-
raise ValueError(f"Invalid grammar backend: {
|
239
|
+
raise ValueError(f"Invalid grammar backend: {name}")
|
203
240
|
|
204
241
|
if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
|
205
242
|
from sglang.srt.constrained.reasoner_grammar_backend import (
|
@@ -37,7 +37,7 @@ except ImportError:
|
|
37
37
|
|
38
38
|
# Dotted-quad IPv4 pattern: each of the four octets is 0-255 (leading zeros
# such as "099" are accepted). The pattern itself has no ^/$ anchors, so
# whether it must match a whole string depends on how callers apply it.
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

# Env var was set in sglang.srt.server_args.ServerArgs.__post_init__
# NOTE(review): "true" is presumably the fallback used when the env var is
# unset (Outlines disk cache disabled by default) — confirm against
# get_bool_env_var.
DISABLE_DISK_CACHE = get_bool_env_var("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true")
|
42
42
|
|
43
43
|
logger = logging.getLogger(__name__)
|
@@ -13,6 +13,7 @@
|
|
13
13
|
# ==============================================================================
|
14
14
|
"""Constrained decoding with xgrammar backend."""
|
15
15
|
|
16
|
+
import dataclasses
|
16
17
|
import json
|
17
18
|
import logging
|
18
19
|
from typing import List, Optional, Tuple, Union
|
@@ -31,6 +32,7 @@ from sglang.srt.constrained.base_grammar_backend import (
|
|
31
32
|
INVALID_GRAMMAR_OBJ,
|
32
33
|
BaseGrammarBackend,
|
33
34
|
BaseGrammarObject,
|
35
|
+
GrammarStats,
|
34
36
|
)
|
35
37
|
from sglang.srt.utils import is_hip
|
36
38
|
|
@@ -41,9 +43,9 @@ else:
|
|
41
43
|
from sglang.srt.constrained.triton_ops.bitmask_ops import (
|
42
44
|
apply_token_bitmask_inplace_triton,
|
43
45
|
)
|
44
|
-
logger = logging.getLogger(__name__)
|
45
46
|
|
46
47
|
|
48
|
+
logger = logging.getLogger(__name__)
|
47
49
|
MAX_ROLLBACK_TOKENS = 200
|
48
50
|
|
49
51
|
|
@@ -56,17 +58,20 @@ class XGrammarGrammar(BaseGrammarObject):
|
|
56
58
|
ctx: CompiledGrammar,
|
57
59
|
override_stop_tokens: Optional[Union[List[int], int]],
|
58
60
|
key_string: Optional[str] = None, # TODO (sk): for debugging, remove later
|
61
|
+
grammar_stats: Optional[GrammarStats] = GrammarStats(),
|
59
62
|
) -> None:
|
63
|
+
super().__init__()
|
60
64
|
self.matcher = matcher
|
61
65
|
self.vocab_size = vocab_size
|
62
66
|
self.ctx = ctx
|
63
67
|
self.override_stop_tokens = override_stop_tokens
|
64
|
-
self.finished = False
|
65
68
|
self.accepted_tokens = []
|
66
69
|
self.key_string = key_string
|
70
|
+
self.grammar_stats = grammar_stats
|
67
71
|
|
68
72
|
def accept_token(self, token: int):
|
69
73
|
if not self.is_terminated():
|
74
|
+
self.current_token = token
|
70
75
|
accepted = self.matcher.accept_token(token)
|
71
76
|
if not accepted:
|
72
77
|
# log for debugging
|
@@ -120,6 +125,9 @@ class XGrammarGrammar(BaseGrammarObject):
|
|
120
125
|
self.ctx,
|
121
126
|
self.override_stop_tokens,
|
122
127
|
self.key_string,
|
128
|
+
dataclasses.replace(
|
129
|
+
self.grammar_stats, is_cache_hit=True, tree_traversal_time=[]
|
130
|
+
),
|
123
131
|
)
|
124
132
|
|
125
133
|
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
|
@@ -150,7 +158,7 @@ class XGrammarGrammar(BaseGrammarObject):
|
|
150
158
|
assert self.matcher.accept_token(new_output_ids[i])
|
151
159
|
|
152
160
|
def __repr__(self):
    """Debug view: key string, accepted tokens and ``current_token``.

    ``current_token`` is the last token handed to ``accept_token`` while the
    grammar was not terminated.
    """
    return (
        "XGrammarGrammar("
        f"{self.key_string=}, "
        f"{self.accepted_tokens=}, "
        f"{self.current_token=})"
    )
|
154
162
|
|
155
163
|
|
156
164
|
class XGrammarGrammarBackend(BaseGrammarBackend):
|
@@ -165,6 +173,10 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|
165
173
|
if hasattr(tokenizer, "init_xgrammar"):
|
166
174
|
# For special tokenizer
|
167
175
|
tokenizer_info, override_stop_tokens = tokenizer.init_xgrammar()
|
176
|
+
|
177
|
+
if tokenizer_info is None:
|
178
|
+
# Not supported tokenizer
|
179
|
+
return
|
168
180
|
else:
|
169
181
|
# Create TokenizerInfo with model's EOS tokens as the authoritative stop tokens
|
170
182
|
# This ensures consistency between what the model considers EOS and what XGrammar uses
|
@@ -177,14 +189,21 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|
177
189
|
self.vocab_size = vocab_size
|
178
190
|
self.override_stop_tokens = override_stop_tokens
|
179
191
|
|
180
|
-
def _from_context(
|
192
|
+
def _from_context(
|
193
|
+
self, ctx: CompiledGrammar, key_string: str, grammar_stats: GrammarStats
|
194
|
+
) -> XGrammarGrammar:
|
181
195
|
matcher = GrammarMatcher(
|
182
196
|
ctx,
|
183
197
|
max_rollback_tokens=MAX_ROLLBACK_TOKENS,
|
184
198
|
override_stop_tokens=self.override_stop_tokens,
|
185
199
|
)
|
186
200
|
return XGrammarGrammar(
|
187
|
-
matcher,
|
201
|
+
matcher,
|
202
|
+
self.vocab_size,
|
203
|
+
ctx,
|
204
|
+
self.override_stop_tokens,
|
205
|
+
key_string,
|
206
|
+
grammar_stats,
|
188
207
|
)
|
189
208
|
|
190
209
|
def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
|
@@ -198,7 +217,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|
198
217
|
except (RuntimeError, json.decoder.JSONDecodeError) as e:
|
199
218
|
logging.error(f"Hit invalid json_schema: {key_string=}, {e=}")
|
200
219
|
return INVALID_GRAMMAR_OBJ
|
201
|
-
return self._from_context(ctx, key_string)
|
220
|
+
return self._from_context(ctx, key_string, GrammarStats(dispatch_type="json"))
|
202
221
|
|
203
222
|
def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
|
204
223
|
try:
|
@@ -206,7 +225,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|
206
225
|
except RuntimeError as e:
|
207
226
|
logging.error(f"Hit invalid ebnf: {key_string=}, {e=}")
|
208
227
|
return INVALID_GRAMMAR_OBJ
|
209
|
-
return self._from_context(ctx, key_string)
|
228
|
+
return self._from_context(ctx, key_string, GrammarStats(dispatch_type="ebnf"))
|
210
229
|
|
211
230
|
def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
|
212
231
|
try:
|
@@ -214,7 +233,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|
214
233
|
except RuntimeError as e:
|
215
234
|
logging.error(f"Hit invalid regex: {key_string=}, {e=}")
|
216
235
|
return INVALID_GRAMMAR_OBJ
|
217
|
-
return self._from_context(ctx, key_string)
|
236
|
+
return self._from_context(ctx, key_string, GrammarStats(dispatch_type="regex"))
|
218
237
|
|
219
238
|
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
|
220
239
|
try:
|
@@ -233,7 +252,9 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|
233
252
|
except (RuntimeError, json.decoder.JSONDecodeError) as e:
|
234
253
|
logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
|
235
254
|
return INVALID_GRAMMAR_OBJ
|
236
|
-
return self._from_context(
|
255
|
+
return self._from_context(
|
256
|
+
ctx, key_string, GrammarStats(dispatch_type="structural_tag")
|
257
|
+
)
|
237
258
|
|
238
259
|
def reset(self):
|
239
260
|
self.grammar_compiler.clear_cache()
|
sglang/srt/custom_op.py
CHANGED
@@ -1,12 +1,20 @@
|
|
1
1
|
from torch import nn
|
2
2
|
|
3
|
-
from sglang.srt.utils import
|
3
|
+
from sglang.srt.utils import (
|
4
|
+
cpu_has_amx_support,
|
5
|
+
is_cpu,
|
6
|
+
is_cuda,
|
7
|
+
is_hip,
|
8
|
+
is_npu,
|
9
|
+
is_xpu,
|
10
|
+
)
|
4
11
|
|
5
12
|
_is_cuda = is_cuda()
|
6
13
|
_is_hip = is_hip()
|
7
14
|
_is_cpu = is_cpu()
|
8
15
|
_is_cpu_amx_available = cpu_has_amx_support()
|
9
16
|
_is_npu = is_npu()
|
17
|
+
_is_xpu = is_xpu()
|
10
18
|
|
11
19
|
|
12
20
|
class CustomOp(nn.Module):
|
@@ -88,5 +96,7 @@ class CustomOp(nn.Module):
|
|
88
96
|
return self.forward_cpu
|
89
97
|
elif _is_npu:
|
90
98
|
return self.forward_npu
|
99
|
+
elif _is_xpu:
|
100
|
+
return self.forward_xpu
|
91
101
|
else:
|
92
102
|
return self.forward_native
|
@@ -1,11 +1,11 @@
|
|
1
1
|
import argparse
|
2
2
|
import functools
|
3
|
-
import re
|
4
3
|
from pathlib import Path
|
5
4
|
|
6
5
|
import polars as pl
|
7
6
|
import torch
|
8
7
|
|
8
|
+
from sglang.srt.debug_utils.dump_loader import find_row, read_meta
|
9
9
|
from sglang.srt.debug_utils.dumper import get_truncated_value
|
10
10
|
|
11
11
|
|
@@ -26,66 +26,77 @@ def main(args):
|
|
26
26
|
print("df_baseline", df_baseline)
|
27
27
|
|
28
28
|
for row in df_target.iter_rows(named=True):
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
29
|
+
path_target = Path(args.target_path) / row["filename"]
|
30
|
+
|
31
|
+
row_baseline = find_row(
|
32
|
+
df_baseline,
|
33
|
+
conditions=dict(
|
34
|
+
forward_pass_id=row["forward_pass_id"]
|
35
|
+
- args.start_id
|
36
|
+
+ args.baseline_start_id,
|
37
|
+
**{
|
38
|
+
k: v
|
39
|
+
for k, v in row.items()
|
40
|
+
if k not in ["forward_pass_id", "dump_index", "filename"]
|
41
|
+
},
|
42
|
+
),
|
42
43
|
)
|
43
|
-
|
44
|
-
row_baseline
|
44
|
+
|
45
|
+
if row_baseline is None:
|
46
|
+
print(f"Skip: target={str(path_target)} since no baseline")
|
47
|
+
x_target = _load_object(path_target)
|
48
|
+
if x_target is not None:
|
49
|
+
print(f"x_target(sample)={get_truncated_value(x_target)}")
|
50
|
+
continue
|
45
51
|
|
46
52
|
path_baseline = Path(args.baseline_path) / row_baseline["filename"]
|
47
|
-
path_target = Path(args.target_path) / row["filename"]
|
48
53
|
print(f"Check: target={str(path_target)} baseline={str(path_baseline)}")
|
49
|
-
check_tensor_pair(
|
54
|
+
check_tensor_pair(
|
55
|
+
path_baseline=path_baseline, path_target=path_target, name=row["name"]
|
56
|
+
)
|
50
57
|
print()
|
51
58
|
|
52
59
|
|
53
|
-
def
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
rows = []
|
58
|
-
for p in directory.glob("*.pt"):
|
59
|
-
full_kwargs = {}
|
60
|
-
for kv in p.stem.split("___"):
|
61
|
-
k, v = kv.split("=")
|
62
|
-
full_kwargs[k] = v
|
63
|
-
rows.append(
|
64
|
-
{
|
65
|
-
"filename": str(p.name),
|
66
|
-
**full_kwargs,
|
67
|
-
}
|
68
|
-
)
|
60
|
+
def check_tensor_pair(path_baseline, path_target, name=""):
|
61
|
+
x_baseline = _load_object(path_baseline)
|
62
|
+
x_target = _load_object(path_target)
|
69
63
|
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
64
|
+
print(
|
65
|
+
f"Raw "
|
66
|
+
f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
|
67
|
+
f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
|
74
68
|
)
|
75
|
-
return df
|
76
|
-
|
77
69
|
|
78
|
-
|
79
|
-
x_baseline =
|
80
|
-
x_target = torch.load(path_target, weights_only=True)
|
70
|
+
x_baseline, x_target = _comparison_preprocessor(x_baseline, x_target, name=name)
|
71
|
+
x_baseline = _try_unify_shape(x_baseline, target_shape=x_target.shape)
|
81
72
|
|
82
73
|
print(
|
74
|
+
f"After preprocessor "
|
83
75
|
f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
|
84
76
|
f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
|
85
77
|
)
|
86
78
|
|
79
|
+
x_target = x_target.float()
|
80
|
+
x_baseline = x_baseline.float()
|
81
|
+
|
82
|
+
for name, fn in (
|
83
|
+
("mean", torch.mean),
|
84
|
+
("std", torch.std),
|
85
|
+
("min", torch.min),
|
86
|
+
("max", torch.max),
|
87
|
+
("p1", functools.partial(torch.quantile, q=0.01)),
|
88
|
+
("p5", functools.partial(torch.quantile, q=0.05)),
|
89
|
+
("p95", functools.partial(torch.quantile, q=0.95)),
|
90
|
+
("p99", functools.partial(torch.quantile, q=0.99)),
|
91
|
+
):
|
92
|
+
value_baseline = fn(x_baseline).item()
|
93
|
+
value_target = fn(x_target).item()
|
94
|
+
print(
|
95
|
+
f"[{name}] {value_baseline :.4f} vs {value_target:.4f} (diff: {value_target - value_baseline:.4f})"
|
96
|
+
)
|
97
|
+
|
87
98
|
if x_baseline.shape != x_target.shape:
|
88
|
-
print(f"
|
99
|
+
print(f"⚠️ Shape mismatch")
|
89
100
|
return
|
90
101
|
|
91
102
|
raw_abs_diff = (x_target - x_baseline).abs()
|
@@ -112,6 +123,19 @@ def check_tensor_pair(path_baseline, path_target):
|
|
112
123
|
print(f"x_target(sample)={get_truncated_value(x_target)}")
|
113
124
|
|
114
125
|
|
126
|
+
def _try_unify_shape(x: torch.Tensor, target_shape):
|
127
|
+
x_shape = x.shape
|
128
|
+
num_dim_to_remove = len(x_shape) - len(target_shape)
|
129
|
+
if (x_shape[num_dim_to_remove:] == target_shape) and all(
|
130
|
+
val == 1 for val in x_shape[:num_dim_to_remove]
|
131
|
+
):
|
132
|
+
out = functools.reduce(lambda a, _: a.squeeze(0), range(num_dim_to_remove), x)
|
133
|
+
print(f"Unify shape: {x_shape} -> {out.shape} (to match {target_shape})")
|
134
|
+
return out
|
135
|
+
|
136
|
+
return x
|
137
|
+
|
138
|
+
|
115
139
|
# Copied from DeepGEMM
|
116
140
|
def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
|
117
141
|
x, y = x.double(), y.double()
|
@@ -120,6 +144,19 @@ def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
|
|
120
144
|
return 1 - sim
|
121
145
|
|
122
146
|
|
147
|
+
def _comparison_preprocessor(x_baseline, x_target, name):
|
148
|
+
# can insert arbitrary adhoc postprocessing logic here
|
149
|
+
return x_baseline, x_target
|
150
|
+
|
151
|
+
|
152
|
+
def _load_object(path):
|
153
|
+
x = torch.load(path, weights_only=False)
|
154
|
+
if not isinstance(x, torch.Tensor):
|
155
|
+
print(f"Skip load {path} since {type(x)=} is not a Tensor")
|
156
|
+
return None
|
157
|
+
return x.cuda()
|
158
|
+
|
159
|
+
|
123
160
|
if __name__ == "__main__":
|
124
161
|
parser = argparse.ArgumentParser()
|
125
162
|
parser.add_argument("--baseline-path", type=str)
|
@@ -0,0 +1,97 @@
|
|
1
|
+
import functools
|
2
|
+
import os
|
3
|
+
from pathlib import Path
|
4
|
+
from typing import Any, Dict
|
5
|
+
|
6
|
+
import polars as pl
|
7
|
+
import torch
|
8
|
+
|
9
|
+
|
10
|
+
class DumpLoader:
|
11
|
+
def __init__(self):
|
12
|
+
directory = os.environ.get("SGLANG_DUMP_LOADER_DIR")
|
13
|
+
|
14
|
+
self._enable = directory is not None
|
15
|
+
if self._enable:
|
16
|
+
self._directory = Path(directory)
|
17
|
+
self._df = read_meta(directory)
|
18
|
+
|
19
|
+
@property
|
20
|
+
def enable(self):
|
21
|
+
return self._enable
|
22
|
+
|
23
|
+
def load(self, name, **kwargs):
|
24
|
+
assert self._enable, "Please call DumpLoader.load only when it is enabled"
|
25
|
+
|
26
|
+
from sglang.srt.debug_utils.dumper import dumper
|
27
|
+
|
28
|
+
forward_pass_id = dumper._forward_pass_id
|
29
|
+
conditions = dict(name=name, forward_pass_id=forward_pass_id, **kwargs)
|
30
|
+
row = find_row(self._df, conditions=conditions)
|
31
|
+
assert (
|
32
|
+
row is not None
|
33
|
+
), f"DumpLoader cannot find row given query {name=} {kwargs=} {self._directory=}"
|
34
|
+
|
35
|
+
path = self._directory / row["filename"]
|
36
|
+
output = torch.load(path, weights_only=False)
|
37
|
+
|
38
|
+
print(
|
39
|
+
f"[DumpLoader] load from {path=} (query: {name=} {kwargs=}, output: {type(output)})"
|
40
|
+
)
|
41
|
+
return output
|
42
|
+
|
43
|
+
|
44
|
+
def read_meta(directory):
|
45
|
+
directory = Path(directory)
|
46
|
+
assert directory.is_dir(), f"{directory=} should be a directory"
|
47
|
+
|
48
|
+
rows = []
|
49
|
+
for p in directory.glob("*.pt"):
|
50
|
+
full_kwargs = {}
|
51
|
+
for kv in p.stem.split("___"):
|
52
|
+
k, v = kv.split("=")
|
53
|
+
full_kwargs[k] = v
|
54
|
+
rows.append(
|
55
|
+
{
|
56
|
+
"filename": str(p.name),
|
57
|
+
**full_kwargs,
|
58
|
+
}
|
59
|
+
)
|
60
|
+
|
61
|
+
df = pl.DataFrame(rows)
|
62
|
+
df = df.with_columns(
|
63
|
+
pl.col("forward_pass_id").cast(int),
|
64
|
+
pl.col("rank").cast(int),
|
65
|
+
pl.col("dump_index").cast(int),
|
66
|
+
)
|
67
|
+
return df
|
68
|
+
|
69
|
+
|
70
|
+
def find_row(df, conditions: Dict[str, Any]):
|
71
|
+
df_sub = df.filter(
|
72
|
+
functools.reduce(
|
73
|
+
lambda a, b: a & b,
|
74
|
+
[
|
75
|
+
pl.col(col) == _cast_to_polars_dtype(conditions[col], df.schema[col])
|
76
|
+
for col in conditions.keys()
|
77
|
+
],
|
78
|
+
)
|
79
|
+
)
|
80
|
+
assert len(df_sub) <= 1
|
81
|
+
return df_sub.to_dicts()[0] if len(df_sub) > 0 else None
|
82
|
+
|
83
|
+
|
84
|
+
def _cast_to_polars_dtype(value, target_dtype):
|
85
|
+
if target_dtype in (pl.Int64, pl.Int32, pl.UInt64, pl.UInt32):
|
86
|
+
return int(value)
|
87
|
+
elif target_dtype in (pl.Float64, pl.Float32):
|
88
|
+
return float(value)
|
89
|
+
elif target_dtype == pl.Boolean:
|
90
|
+
return bool(value)
|
91
|
+
elif target_dtype == pl.String:
|
92
|
+
return str(value)
|
93
|
+
else:
|
94
|
+
return value
|
95
|
+
|
96
|
+
|
97
|
+
dump_loader = DumpLoader()
|