sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/test/test_deterministic.py
CHANGED

```diff
@@ -2,7 +2,14 @@
 Batch the same prompt in random batch sizes, and test if the results are consistent across different trials.
 
 Usage:
-
+# Single mode: test determinism with varying batch sizes
+python3 -m sglang.test.test_deterministic --n-trials 50 --test-mode single
+
+# Prefix mode: test with shared prefixes
+python3 -m sglang.test.test_deterministic --n-start 1 --n-trials 50 --test-mode prefix
+
+# Radix Cache Consistency mode: test radix cache determinism (cached vs uncached prefill)
+python3 -m sglang.test.test_deterministic --test-mode radix_cache
 """
 
 import argparse
@@ -39,12 +46,15 @@ class BenchArgs:
     profile_steps: int = 3
     profile_by_stage: bool = False
     test_mode: str = "single"
+    n_trials: int = 50
+    n_start: int = 1
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         parser.add_argument("--host", type=str, default=BenchArgs.host)
         parser.add_argument("--port", type=int, default=BenchArgs.port)
-        parser.add_argument("--n-trials", type=int, default=50)
+        parser.add_argument("--n-trials", type=int, default=BenchArgs.n_trials)
+        parser.add_argument("--n-start", type=int, default=BenchArgs.n_start)
         parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
         parser.add_argument(
             "--sampling-seed", type=int, default=BenchArgs.sampling_seed
@@ -64,7 +74,11 @@ class BenchArgs:
             "--test-mode",
             type=str,
             default=BenchArgs.test_mode,
-            choices=["single", "mixed", "prefix"],
+            choices=[
+                "single",
+                "prefix",
+                "radix_cache",
+            ],
         )
         parser.add_argument("--profile", action="store_true")
         parser.add_argument(
@@ -80,26 +94,50 @@ class BenchArgs:
 
 def send_single(
     args,
-    batch_size: int,
+    batch_size: int = 1,
     profile: bool = False,
     profile_steps: int = 3,
     profile_by_stage: bool = False,
+    return_full_response: bool = False,
+    input_ids: List[int] = None,
+    max_new_tokens: int = None,
 ):
-
     base_url = f"http://{args.host}:{args.port}"
-    prompt = [PROMPT_1] * batch_size
 
-    json_data = {
-        "text": prompt,
-        "sampling_params": {
-            "temperature": args.temperature,
-            "max_new_tokens": args.max_new_tokens,
-            "frequency_penalty": args.frequency_penalty,
-            "presence_penalty": args.presence_penalty,
-        },
-        "return_logprob": args.return_logprob,
-        "stream": args.stream,
-    }
+    # Use input_ids if provided, otherwise use text prompts
+    if input_ids is not None:
+        json_data = {
+            "input_ids": input_ids,
+            "sampling_params": {
+                "temperature": args.temperature,
+                "max_new_tokens": (
+                    max_new_tokens
+                    if max_new_tokens is not None
+                    else args.max_new_tokens
+                ),
+                "frequency_penalty": args.frequency_penalty,
+                "presence_penalty": args.presence_penalty,
+            },
+            "return_logprob": args.return_logprob,
+            "stream": args.stream,
+        }
+    else:
+        prompt = [PROMPT_1] * batch_size
+        json_data = {
+            "text": prompt,
+            "sampling_params": {
+                "temperature": args.temperature,
+                "max_new_tokens": (
+                    max_new_tokens
+                    if max_new_tokens is not None
+                    else args.max_new_tokens
+                ),
+                "frequency_penalty": args.frequency_penalty,
+                "presence_penalty": args.presence_penalty,
+            },
+            "return_logprob": args.return_logprob,
+            "stream": args.stream,
+        }
 
     if args.sampling_seed is not None:
         # sglang server cannot parse None value for sampling_seed
@@ -116,6 +154,11 @@ def send_single(
         stream=args.stream,
     )
 
+    if response.status_code != 200:
+        ret = response.json()
+        print(f"Error: {ret}")
+        return None
+
     if args.stream:
         for chunk in response.iter_lines(decode_unicode=False):
             chunk = chunk.decode("utf-8")
@@ -125,59 +168,13 @@ def send_single(
             ret = json.loads(chunk[5:].strip("\n"))
     else:
         ret = response.json()
-    ret = ret[0]
-
-    if response.status_code != 200:
-        print(ret)
-        return -1
-
-    return ret["text"]
 
+    ret = ret[0] if isinstance(ret, list) else ret
 
-
-
-
-
-
-    json_data = {
-        "text": [PROMPT_1] * num_prompt_1
-        + [PROMPT_2] * num_prompt_2
-        + [LONG_PROMPT] * num_long_prompt,
-        "sampling_params": {
-            "temperature": args.temperature,
-            "max_new_tokens": args.max_new_tokens,
-            "frequency_penalty": args.frequency_penalty,
-            "presence_penalty": args.presence_penalty,
-        },
-        "return_logprob": args.return_logprob,
-        "stream": args.stream,
-    }
-
-    if args.sampling_seed is not None:
-        json_data["sampling_params"]["sampling_seed"] = args.sampling_seed
-
-    response = requests.post(
-        f"http://{args.host}:{args.port}/generate",
-        json=json_data,
-        stream=args.stream,
-    )
-    ret = response.json()
-    if response.status_code != 200:
-        print(ret)
-        return -1, -1, -1
-
-    prompt_1_ret = [ret[i]["text"] for i in range(num_prompt_1)]
-    prompt_2_ret = [
-        ret[i]["text"] for i in range(num_prompt_1, num_prompt_1 + num_prompt_2)
-    ]
-    long_prompt_ret = [
-        ret[i]["text"]
-        for i in range(
-            num_prompt_1 + num_prompt_2, num_prompt_1 + num_prompt_2 + num_long_prompt
-        )
-    ]
-
-    return prompt_1_ret, prompt_2_ret, long_prompt_ret
+    if return_full_response:
+        return ret
+    else:
+        return ret["text"]
 
 
 def send_prefix(args, batch_size: int, prompts: List[str]):
@@ -223,10 +220,6 @@ def send_prefix(args, batch_size: int, prompts: List[str]):
 
 
 def test_deterministic(args):
-    # First do some warmups
-    for i in range(3):
-        send_single(args, 16, args.profile)
-
     if args.test_mode == "single":
         # In single mode, we test the deterministic behavior by sending the same prompt in batch sizes ranging from 1 to n_trials.
         texts = []
@@ -236,33 +229,8 @@
             text = text.replace("\n", " ")
             print(f"Trial {i} with batch size {batch_size}: {text}")
            texts.append(text)
-
         print(f"Total samples: {len(texts)}, Unique samples: {len(set(texts))}")
-
-        # In mixed mode, we send a mixture of two short prompts and one long prompt in the same batch with batch size ranging from 1 to n_trials.
-        output_prompt_1 = []
-        output_prompt_2 = []
-        output_long_prompt = []
-        for i in range(1, args.n_trials + 1):
-            batch_size = i
-            ret_prompt_1, ret_prompt_2, ret_long_prompt = send_mixed(args, batch_size)
-            output_prompt_1.extend(ret_prompt_1)
-            output_prompt_2.extend(ret_prompt_2)
-            output_long_prompt.extend(ret_long_prompt)
-
-            print(
-                f"Testing Trial {i} with batch size {batch_size}, number of prompt 1: {len(ret_prompt_1)}, number of prompt 2: {len(ret_prompt_2)}, number of long prompt: {len(ret_long_prompt)}"
-            )
-
-        print(
-            f"Prompt 1: total samples: {len(output_prompt_1)}, Unique samples: {len(set(output_prompt_1))}"
-        )
-        print(
-            f"Prompt 2: total samples: {len(output_prompt_2)}, Unique samples: {len(set(output_prompt_2))}"
-        )
-        print(
-            f"Long prompt: total samples: {len(output_long_prompt)}, Unique samples: {len(set(output_long_prompt))}"
-        )
+        return [len(set(texts))]
 
     elif args.test_mode == "prefix":
         # In prefix mode, we create prompts from the same long prompt, with different lengths of common prefix.
@@ -270,7 +238,7 @@ def test_deterministic(args):
         num_prompts = len(len_prefix)
         outputs = {i: [] for i in range(4)}
         prompts = [LONG_PROMPT[: len_prefix[i]] for i in range(4)]
-        for i in range(1, args.n_trials + 1):
+        for i in range(args.n_start, args.n_start + args.n_trials):
             batch_size = i
             ret_dict = send_prefix(args, batch_size, prompts)
             msg = f"Testing Trial {i} with batch size {batch_size},"
@@ -285,6 +253,168 @@
                 f"Prompt {i} with prefix length {len_prefix[i]}: total samples: {len(outputs[i])}, Unique samples: {len(set(outputs[i]))}"
             )
 
+        results = []
+        for i in range(num_prompts):
+            results.append(len(set(outputs[i])))
+        return results
+
+    elif args.test_mode == "radix_cache":
+        # Radix mode requires logprobs to compare results
+        args.return_logprob = True
+
+        print("\n=== Prefill Cache Consistency Test ===")
+        print(
+            "This test verifies prefill request produces consistent logprobs w/ and w/o cache.\n"
+        )
+
+        # We noticed that we cannot call flush cache before any request, otherwise it will hang.
+        warmup_response = send_single(
+            args, input_ids=[1] * 64, max_new_tokens=65, return_full_response=True
+        )
+
+        # Flush cache first to make sure there is no cache hit from previous tests
+        flush_response = requests.post(f"http://{args.host}:{args.port}/flush_cache")
+
+        print(f"Step 1: Generating random 64 token IDs...")
+        # Use a reasonable token ID range (e.g., 1-50000 for most tokenizers)
+        # Avoid special tokens like 0 (padding), 1 (BOS), 2 (EOS)
+        # set seed for random.randint
+        random.seed(42)
+        initial_token_ids = [random.randint(100, 50000) for _ in range(64)]
+
+        print(f"✓ Using {len(initial_token_ids)} initial tokens")
+        print(f"  Initial token IDs: {initial_token_ids}")
+
+        print(
+            f"\nStep 2: Generating 2 tokens from {len(initial_token_ids)} token prefix..."
+        )
+        first_response = send_single(
+            args,
+            input_ids=initial_token_ids,
+            max_new_tokens=100,
+            return_full_response=True,
+        )
+        first_output_text = first_response["text"]
+        first_output_token_ids = first_response["output_ids"]
+        first_output_logprobs = first_response["meta_info"]["output_token_logprobs"]
+
+        expected_token_id = first_output_token_ids[-1]
+        expected_logprob = first_output_logprobs[-1][0]
+
+        print(f"✓ Generated {len(first_output_token_ids)} tokens")
+        print(f'  Output text: "{first_output_text}"')
+
+        print(
+            f"\nStep 3: Generating with radix cache (164 tokens prefill, should hit > 128 tokens cache, based on page size)..."
+        )
+        prefix_token_ids = initial_token_ids + first_output_token_ids[:-1]
+        print(
+            f"  Prefix: {len(initial_token_ids)} initial + 64 generated = {len(prefix_token_ids)} tokens"
+        )
+        print(f"Using Prompt: {prefix_token_ids}")
+        cached_response = send_single(
+            args,
+            input_ids=prefix_token_ids,
+            max_new_tokens=1,
+            return_full_response=True,
+        )
+        cached_logprobs = cached_response["meta_info"]["output_token_logprobs"]
+        cached_token_data = cached_logprobs[0]
+        cached_logprob = cached_token_data[0]
+        cached_token_id = cached_token_data[1]
+
+        print(f"✓ Generated with cache:")
+        print(f"  Token ID: {cached_token_id}")
+        print(f"  Logprob: {cached_logprob:.10f}")
+
+        print(f"\nStep 4: Flushing cache...")
+        flush_response = requests.post(f"http://{args.host}:{args.port}/flush_cache")
+
+        print(
+            f"\nStep 5: Generating without cache (same 164 tokens prefill, no cache)..."
+        )
+        print(f"Using Prompt: {prefix_token_ids}")
+
+        uncached_response = send_single(
+            args,
+            input_ids=prefix_token_ids,
+            max_new_tokens=1,
+            return_full_response=True,
+        )
+
+        uncached_logprobs = uncached_response["meta_info"]["output_token_logprobs"]
+        uncached_token_data = uncached_logprobs[0]
+        uncached_logprob = uncached_token_data[0]
+        uncached_token_id = uncached_token_data[1]
+
+        print(f"✓ Generated without cache:")
+        print(f"  Token ID: {uncached_token_id}")
+        print(f"  Logprob: {uncached_logprob:.10f}")
+
+        # Step 6: Compare results
+        print(f"\n{'='*60}")
+        print("Comparison 1: Decode (Request 1) vs Prefill with Cache (Request 2)")
+        print("=" * 60)
+
+        # Compare first request (decode) vs second request (prefill with cache)
+        # We expect them to be different (different kernels)
+        decode_vs_prefill_token_match = expected_token_id == cached_token_id
+        decode_vs_prefill_logprob_match = expected_logprob == cached_logprob
+
+        print(
+            f"  Decode token (Request 1): ID={expected_token_id}, logprob={expected_logprob:.10f}"
+        )
+        print(
+            f"  Prefill w/ cache token (Request 2): ID={cached_token_id}, logprob={cached_logprob:.10f}"
+        )
+        print(
+            f"  Token ID match: {'✓ YES' if decode_vs_prefill_token_match else '✗ NO'}"
+        )
+        print(
+            f"  Logprob match: {'✓ YES' if decode_vs_prefill_logprob_match else '✗ NO'}"
+        )
+        if not decode_vs_prefill_logprob_match:
+            diff = abs(expected_logprob - cached_logprob)
+            print(f"  Logprob difference: {diff:.10e}")
+            print(f"  Note: We expect these to be DIFFERENT (decode vs prefill kernels)")
+
+        print(f"\n{'='*60}")
+        print(
+            "Comparison 2: Cached Prefill (Request 2) vs Uncached Prefill (Request 3)"
+        )
+        print("=" * 60)
+
+        # Main test: compare cached vs uncached prefill (should be identical)
+        token_match = cached_token_id == uncached_token_id
+        logprob_match = cached_logprob == uncached_logprob
+
+        print(
+            f"  Cached prefill token (Request 2): ID={cached_token_id}, logprob={cached_logprob:.10f}"
+        )
+        print(
+            f"  Uncached prefill token (Request 3): ID={uncached_token_id}, logprob={uncached_logprob:.10f}"
+        )
+        print(f"  Token ID match: {'✓ YES' if token_match else '✗ NO'}")
+        if not token_match:
+            print(f"    Cached: {cached_token_id}")
+            print(f"    Uncached: {uncached_token_id}")
+
+        print(f"  Logprob match: {'✓ YES' if logprob_match else '✗ NO'}")
+        if not logprob_match:
+            print(f"    Cached: {cached_logprob:.10f}")
+            print(f"    Uncached: {uncached_logprob:.10f}")
+            diff = abs(cached_logprob - uncached_logprob)
+            print(f"    Difference: {diff:.10e}")
+            print(f"    Note: We expect these to be IDENTICAL (both prefill kernels)")
+
+        print(f"\n{'='*60}")
+        if token_match and logprob_match:
+            print("✓✓✓ TEST PASSED - Radix cache is consistent! ✓✓✓")
+            return [1]
+        else:
+            print("✗✗✗ TEST FAILED - Radix cache produces different results! ✗✗✗")
+            return [0]
+
     else:
         raise ValueError(f"Invalid test mode: {args.test_mode}")
 
@@ -294,4 +424,7 @@ if __name__ == "__main__":
     BenchArgs.add_cli_args(parser)
     args = parser.parse_args()
 
+    if args.sampling_seed is None:
+        args.sampling_seed = 42
+
     test_deterministic(args)
```
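The reworked `send_single` above accepts raw `input_ids` and can return the full `/generate` response instead of just the generated text. As a minimal sketch, the new path can be driven by hand once a server is up; this assumes a server already listening on the `BenchArgs` default host/port, and uses only names from the diff above:

```python
import random

from sglang.test.test_deterministic import BenchArgs, send_single

# Sketch only: assumes an sglang server is already listening on
# BenchArgs.host:BenchArgs.port (the CLI above launches one separately).
args = BenchArgs()
args.return_logprob = True  # the radix_cache mode sets this to compare logprobs

random.seed(42)
token_ids = [random.randint(100, 50000) for _ in range(64)]

# return_full_response=True yields the raw payload, including "output_ids" and
# meta_info["output_token_logprobs"], which the cached-vs-uncached check reads.
resp = send_single(args, input_ids=token_ids, max_new_tokens=1, return_full_response=True)
print(resp["meta_info"]["output_token_logprobs"][0])
```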
sglang/test/test_deterministic_utils.py
ADDED

```diff
@@ -0,0 +1,73 @@
+import unittest
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.test_deterministic import BenchArgs, test_deterministic
+from sglang.test.test_utils import (
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    popen_launch_server,
+)
+
+DEFAULT_MODEL = "Qwen/Qwen3-8B"
+COMMON_SERVER_ARGS = [
+    "--trust-remote-code",
+    "--cuda-graph-max-bs",
+    "32",
+    "--enable-deterministic-inference",
+]
+
+
+class TestDeterministicBase(CustomTestCase):
+    @classmethod
+    def get_server_args(cls):
+        return COMMON_SERVER_ARGS
+
+    @classmethod
+    def get_model(cls):
+        return DEFAULT_MODEL
+
+    @classmethod
+    def setUpClass(cls):
+        cls.model = cls.get_model()
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        if "--attention-backend" not in cls.get_server_args():
+            raise unittest.SkipTest("Skip the base test class")
+
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=cls.get_server_args(),
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def _extract_host_and_port(self, url):
+        return url.split("://")[-1].split(":")[0], int(url.split(":")[-1])
+
+    def test_single(self):
+        args = BenchArgs()
+        url = DEFAULT_URL_FOR_TEST
+        args.host, args.port = self._extract_host_and_port(url)
+        args.test_mode = "single"
+        args.n_start = 10
+        args.n_trials = 20
+        results = test_deterministic(args)
+        args.temperature = 0.5  # test for deterministic sampling
+        for result in results:
+            assert result == 1
+
+    def test_prefix(self):
+        args = BenchArgs()
+        url = DEFAULT_URL_FOR_TEST
+        args.host, args.port = self._extract_host_and_port(url)
+        args.test_mode = "prefix"
+        args.n_start = 10
+        args.n_trials = 10
+        args.temperature = 0.5  # test for deterministic sampling
+        results = test_deterministic(args)
+        for result in results:
+            assert result == 1
```
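Note that `setUpClass` raises `unittest.SkipTest` unless the server args pin an attention backend, so this base class only runs through concrete subclasses. A hypothetical subclass might look like the following; the class name and backend choice here are illustrative, not part of the diff:

```python
import unittest

from sglang.test.test_deterministic_utils import (
    COMMON_SERVER_ARGS,
    TestDeterministicBase,
)


class TestDeterministicTriton(TestDeterministicBase):
    # Pin an attention backend so the base class's setUpClass does not
    # skip the suite; "triton" is just one valid backend choice.
    @classmethod
    def get_server_args(cls):
        return COMMON_SERVER_ARGS + ["--attention-backend", "triton"]


if __name__ == "__main__":
    unittest.main()
```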
sglang/test/test_disaggregation_utils.py
CHANGED

```diff
@@ -1,16 +1,23 @@
+import logging
+import os
 import time
+import warnings
 from urllib.parse import urlparse
 
 import requests
 
+from sglang.srt.environ import envs
 from sglang.srt.utils import kill_process_tree
 from sglang.test.test_utils import (
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     CustomTestCase,
+    is_in_ci,
     popen_with_error_check,
 )
 
+logger = logging.getLogger(__name__)
+
 
 class TestDisaggregationBase(CustomTestCase):
     @classmethod
@@ -27,6 +34,24 @@ class TestDisaggregationBase(CustomTestCase):
         print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
         cls.process_lb, cls.process_decode, cls.process_prefill = None, None, None
 
+        # config transfer backend and rdma devices
+        if is_in_ci():
+            cls.transfer_backend = ["--disaggregation-transfer-backend", "mooncake"]
+            cls.rdma_devices = ["--disaggregation-ib-device", get_rdma_devices_args()]
+        else:
+            cls.transfer_backend = [
+                "--disaggregation-transfer-backend",
+                envs.SGLANG_TEST_PD_DISAGG_BACKEND.get(),
+            ]
+            cls.rdma_devices = [
+                "--disaggregation-ib-device",
+                envs.SGLANG_TEST_PD_DISAGG_DEVICES.get(),
+            ]
+            if cls.rdma_devices[1] is None:
+                cls.rdma_devices = []
+                msg = "No RDMA devices specified for disaggregation test, using default settings."
+                warnings.warn(msg)
+
     @classmethod
     def launch_lb(cls):
         lb_command = [
@@ -75,3 +100,59 @@ class TestDisaggregationBase(CustomTestCase):
 
         # wait for 5 seconds
         time.sleep(5)
+
+
+def get_rdma_devices_args():
+    def _parse_list_env(var_name: str):
+        val = os.getenv(var_name)
+        if not val:
+            return None
+        items = [x.strip() for x in val.split(",") if x.strip()]
+        return items or None
+
+    def _pick_default_pair(rdma_all_devices):
+        return [rdma_all_devices[0], rdma_all_devices[len(rdma_all_devices) // 2]]
+
+    rdma_all_devices = _parse_list_env("SGLANG_CI_RDMA_ALL_DEVICES") or [
+        f"mlx5_roce{i}" for i in range(8)
+    ]
+    logger.info("Resolved rdma_all_devices=%s", rdma_all_devices)
+
+    n_rdma = len(rdma_all_devices)
+
+    # 1. Get visible GPU indices
+    cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
+    if not cuda_visible_devices:
+        warnings.warn("CUDA_VISIBLE_DEVICES is not set. Using default RDMA devices.")
+        return ",".join(_pick_default_pair(rdma_all_devices))
+
+    try:
+        # Convert to list of integers (handling possible spaces and empty strings)
+        gpu_indices = [
+            int(idx.strip()) for idx in cuda_visible_devices.split(",") if idx.strip()
+        ]
+        if not gpu_indices or len(gpu_indices) > 4:
+            return ",".join(_pick_default_pair(rdma_all_devices))
+    except ValueError:
+        warnings.warn(f"Invalid CUDA_VISIBLE_DEVICES format: {cuda_visible_devices}")
+        return ",".join(_pick_default_pair(rdma_all_devices))
+
+    # 2. Calculate base RDMA index group (each group of 4 GPUs uses consecutive devices)
+    base_rdma_group = (min(gpu_indices) // 4) * 4
+    for gpu_idx in gpu_indices:
+        if not (base_rdma_group <= gpu_idx < base_rdma_group + 4):
+            warnings.warn(
+                f"GPU index {gpu_idx} is outside expected group "
+                f"{base_rdma_group}-{base_rdma_group+3}"
+            )
+
+    # 3. Generate RDMA device names
+    rdma_devices = []
+    for gpu_idx in gpu_indices:
+        nic_index = gpu_idx // (8 // n_rdma)
+        rdma_devices.append(rdma_all_devices[nic_index])
+
+    if not rdma_devices:
+        return ",".join(_pick_default_pair(rdma_all_devices))
+
+    return ",".join(rdma_devices)
```
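For context on the NIC selection in `get_rdma_devices_args`: with the default eight-device list, `nic_index = gpu_idx // (8 // n_rdma)` reduces to a 1:1 GPU-to-NIC mapping, while a four-device list makes two consecutive GPUs share one NIC. A small self-contained sketch of just that mapping rule (the helper name is ours, not from the diff):

```python
def pick_nics(gpu_indices, rdma_all_devices):
    # Mirrors step 3 above: each NIC covers 8 // len(rdma_all_devices)
    # consecutive GPU indices.
    n_rdma = len(rdma_all_devices)
    return [rdma_all_devices[gpu_idx // (8 // n_rdma)] for gpu_idx in gpu_indices]


# 8 NICs: one NIC per GPU.
assert pick_nics([0, 1], [f"mlx5_roce{i}" for i in range(8)]) == ["mlx5_roce0", "mlx5_roce1"]
# 4 NICs: GPUs 0-1 share mlx5_roce0, GPUs 2-3 share mlx5_roce1.
assert pick_nics([2, 3], [f"mlx5_roce{i}" for i in range(4)]) == ["mlx5_roce1", "mlx5_roce1"]
```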
sglang/test/test_marlin_moe.py
CHANGED