sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +54 -37
- sglang/bench_one_batch_server.py +340 -34
- sglang/bench_serving.py +340 -159
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +9 -2
- sglang/profiler.py +20 -3
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +309 -0
- sglang/srt/configs/load_config.py +33 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +284 -118
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +576 -0
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +6 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +26 -15
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +268 -98
- sglang/srt/disaggregation/decode.py +172 -39
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +203 -555
- sglang/srt/disaggregation/nixl/conn.py +217 -63
- sglang/srt/disaggregation/prefill.py +113 -270
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +203 -97
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +85 -65
- sglang/srt/entrypoints/grpc_server.py +632 -305
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +169 -17
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +327 -34
- sglang/srt/entrypoints/openai/serving_base.py +74 -8
- sglang/srt/entrypoints/openai/serving_chat.py +202 -118
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +20 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +47 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +323 -0
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +21 -16
- sglang/srt/function_call/glm4_moe_detector.py +4 -8
- sglang/srt/function_call/gpt_oss_detector.py +24 -1
- sglang/srt/function_call/json_array_parser.py +61 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +98 -7
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/grpc_request_manager.py +915 -0
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
- sglang/srt/layers/activation.py +11 -7
- sglang/srt/layers/attention/aiter_backend.py +17 -18
- sglang/srt/layers/attention/ascend_backend.py +125 -10
- sglang/srt/layers/attention/attention_registry.py +226 -0
- sglang/srt/layers/attention/base_attn_backend.py +32 -4
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +52 -15
- sglang/srt/layers/attention/flashinfer_backend.py +357 -212
- sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
- sglang/srt/layers/attention/flashmla_backend.py +9 -7
- sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
- sglang/srt/layers/attention/mamba/mamba.py +514 -1
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +23 -0
- sglang/srt/layers/attention/nsa_backend.py +1201 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +249 -42
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
- sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +61 -3
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +19 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +28 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +47 -15
- sglang/srt/layers/linear.py +30 -5
- sglang/srt/layers/logits_processor.py +161 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
- sglang/srt/layers/moe/ep_moe/layer.py +243 -448
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
- sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +27 -1
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +86 -20
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +43 -15
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +141 -81
- sglang/srt/layers/quantization/mxfp4.py +17 -34
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -24
- sglang/srt/layers/quantization/w8a8_int8.py +45 -27
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +750 -46
- sglang/srt/layers/sampler.py +84 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +23 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +9 -4
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +33 -7
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +41 -17
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +83 -152
- sglang/srt/managers/data_parallel_controller.py +156 -87
- sglang/srt/managers/detokenizer_manager.py +51 -24
- sglang/srt/managers/io_struct.py +223 -129
- sglang/srt/managers/mm_utils.py +49 -10
- sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +130 -0
- sglang/srt/managers/schedule_batch.py +340 -529
- sglang/srt/managers/schedule_policy.py +158 -18
- sglang/srt/managers/scheduler.py +665 -620
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
- sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
- sglang/srt/managers/tokenizer_manager.py +462 -226
- sglang/srt/managers/tp_worker.py +217 -156
- sglang/srt/managers/utils.py +79 -47
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +42 -28
- sglang/srt/mem_cache/base_prefix_cache.py +3 -3
- sglang/srt/mem_cache/chunk_cache.py +20 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +38 -0
- sglang/srt/mem_cache/hicache_storage.py +44 -2
- sglang/srt/mem_cache/hiradix_cache.py +134 -34
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +602 -208
- sglang/srt/mem_cache/memory_pool_host.py +134 -183
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +263 -78
- sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +115 -58
- sglang/srt/metrics/collector.py +113 -120
- sglang/srt/metrics/func_timer.py +3 -8
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +81 -36
- sglang/srt/model_executor/forward_batch_info.py +40 -50
- sglang/srt/model_executor/model_runner.py +507 -319
- sglang/srt/model_executor/npu_graph_runner.py +11 -5
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +438 -37
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +200 -27
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +40 -56
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +25 -4
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +793 -235
- sglang/srt/models/dots_ocr.py +171 -0
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +570 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -3
- sglang/srt/models/glm4_moe.py +17 -40
- sglang/srt/models/glm4_moe_nextn.py +4 -4
- sglang/srt/models/glm4v.py +3 -2
- sglang/srt/models/glm4v_moe.py +6 -6
- sglang/srt/models/gpt_oss.py +12 -35
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +4 -2
- sglang/srt/models/llama.py +6 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +6 -23
- sglang/srt/models/longcat_flash_nextn.py +4 -15
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +27 -6
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +5 -5
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +70 -4
- sglang/srt/models/qwen2_vl.py +6 -3
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +50 -38
- sglang/srt/models/qwen3_next.py +43 -21
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +791 -0
- sglang/srt/models/qwen3_vl_moe.py +343 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +268 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +61 -0
- sglang/srt/multimodal/processors/base_processor.py +21 -9
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +2 -4
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +20 -10
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +83 -17
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +36 -23
- sglang/srt/sampling/sampling_params.py +75 -0
- sglang/srt/server_args.py +1300 -338
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +161 -0
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
- sglang/srt/speculative/eagle_info.py +786 -0
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +113 -1270
- sglang/srt/speculative/eagle_worker.py +120 -285
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/ngram_info.py +433 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +49 -0
- sglang/srt/speculative/spec_utils.py +641 -0
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +35 -18
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/{utils.py → utils/common.py} +583 -113
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +120 -11
- sglang/test/runners.py +3 -1
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +8 -2
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +3 -4
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +430 -0
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +93 -1
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +432 -16
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
- sglang/srt/entrypoints/grpc_request_manager.py +0 -580
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,430 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Send the same prompt(s) in varying batch sizes and test whether the results are consistent across trials.
|
|
3
|
+
|
|
4
|
+
Usage:
|
|
5
|
+
# Single mode: test determinism with varying batch sizes
|
|
6
|
+
python3 -m sglang.test.test_deterministic --n-trials 50 --test-mode single
|
|
7
|
+
|
|
8
|
+
# Prefix mode: test with shared prefixes
|
|
9
|
+
python3 -m sglang.test.test_deterministic --n-start 1 --n-trials 50 --test-mode prefix
|
|
10
|
+
|
|
11
|
+
# Radix Cache Consistency mode: test radix cache determinism (cached vs uncached prefill)
|
|
12
|
+
python3 -m sglang.test.test_deterministic --test-mode radix_cache
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import argparse
|
|
16
|
+
import dataclasses
|
|
17
|
+
import json
|
|
18
|
+
import os
|
|
19
|
+
import random
|
|
20
|
+
from typing import List
|
|
21
|
+
|
|
22
|
+
import requests
|
|
23
|
+
|
|
24
|
+
from sglang.profiler import run_profile
|
|
25
|
+
|
|
26
|
+
PROMPT_1 = "Tell me about Richard Feynman: "
PROMPT_2 = "Generate 1000 random numbers. Go directly into it, don't say Sure and don't say here are numbers. Just start with a number."
dirpath = os.path.dirname(__file__)
# long_prompt.txt lives next to this module and is read once at import time.
# Decode it explicitly as UTF-8 so the result does not depend on the
# platform's locale encoding (e.g. cp1252 on Windows).
with open(os.path.join(dirpath, "long_prompt.txt"), "r", encoding="utf-8") as f:
    LONG_PROMPT = f.read()
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclasses.dataclass
class BenchArgs:
    """Client-side settings for the determinism test, mirrored as CLI flags."""

    host: str = "localhost"
    port: int = 30000
    # No CLI flag defines this field; from_cli_args keeps this default.
    batch_size: int = 1
    temperature: float = 0.0
    sampling_seed: int = 42
    max_new_tokens: int = 100
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    return_logprob: bool = False
    stream: bool = False
    profile: bool = False
    profile_steps: int = 3
    profile_by_stage: bool = False
    test_mode: str = "single"
    n_trials: int = 50
    n_start: int = 1

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register this dataclass's CLI flags on ``parser``."""
        parser.add_argument("--host", type=str, default=BenchArgs.host)
        parser.add_argument("--port", type=int, default=BenchArgs.port)
        parser.add_argument("--n-trials", type=int, default=BenchArgs.n_trials)
        parser.add_argument("--n-start", type=int, default=BenchArgs.n_start)
        parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
        parser.add_argument(
            "--sampling-seed", type=int, default=BenchArgs.sampling_seed
        )
        parser.add_argument(
            "--max-new-tokens", type=int, default=BenchArgs.max_new_tokens
        )
        parser.add_argument(
            "--frequency-penalty", type=float, default=BenchArgs.frequency_penalty
        )
        parser.add_argument(
            "--presence-penalty", type=float, default=BenchArgs.presence_penalty
        )
        parser.add_argument("--return-logprob", action="store_true")
        parser.add_argument("--stream", action="store_true")
        parser.add_argument(
            "--test-mode",
            type=str,
            default=BenchArgs.test_mode,
            choices=[
                "single",
                "prefix",
                "radix_cache",
            ],
        )
        parser.add_argument("--profile", action="store_true")
        parser.add_argument(
            "--profile-steps", type=int, default=BenchArgs.profile_steps
        )
        parser.add_argument("--profile-by-stage", action="store_true")

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a BenchArgs from a parsed argparse namespace.

        Any field that ``args`` does not carry falls back to the dataclass
        default instead of raising AttributeError. For example, ``batch_size``
        has no CLI flag.
        """
        return cls(
            **{
                field.name: getattr(args, field.name)
                for field in dataclasses.fields(cls)
                if hasattr(args, field.name)
            }
        )
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def send_single(
    args,
    batch_size: int = 1,
    profile: bool = False,
    profile_steps: int = 3,
    profile_by_stage: bool = False,
    return_full_response: bool = False,
    input_ids: List[int] = None,
    max_new_tokens: int = None,
):
    """Send one /generate request and return the first result of the batch.

    Args:
        args: Parsed CLI namespace. Reads host, port, temperature,
            max_new_tokens, frequency_penalty, presence_penalty,
            return_logprob, stream and sampling_seed.
        batch_size: Number of copies of PROMPT_1 to send in one request.
            Ignored when ``input_ids`` is given.
        profile: When true, call ``run_profile`` before sending the request.
        profile_steps: Forwarded to ``run_profile``.
        profile_by_stage: Forwarded to ``run_profile``.
        return_full_response: Return the whole result dict instead of only
            its ``"text"`` field.
        input_ids: Optional pre-tokenized prompt, sent instead of text
            (may be None).
        max_new_tokens: Per-call override of ``args.max_new_tokens``
            (None means use ``args.max_new_tokens``).

    Returns:
        The text (or the full result dict) of the first result. Returns None
        if the server answered with a non-200 status, or if a streaming
        response contained no ``data:`` chunk.
    """
    base_url = f"http://{args.host}:{args.port}"

    sampling_params = {
        "temperature": args.temperature,
        "max_new_tokens": (
            max_new_tokens if max_new_tokens is not None else args.max_new_tokens
        ),
        "frequency_penalty": args.frequency_penalty,
        "presence_penalty": args.presence_penalty,
    }
    if args.sampling_seed is not None:
        # sglang server cannot parse None value for sampling_seed
        sampling_params["sampling_seed"] = args.sampling_seed

    # Use input_ids if provided, otherwise use text prompts
    if input_ids is not None:
        json_data = {"input_ids": input_ids}
    else:
        json_data = {"text": [PROMPT_1] * batch_size}
    json_data.update(
        {
            "sampling_params": sampling_params,
            "return_logprob": args.return_logprob,
            "stream": args.stream,
        }
    )

    if profile:
        run_profile(
            base_url, profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
        )

    response = requests.post(
        f"{base_url}/generate",
        json=json_data,
        stream=args.stream,
    )

    if response.status_code != 200:
        # Prefer the JSON error payload. Fall back to the raw body so that a
        # non-JSON error page does not raise and hide the real failure.
        try:
            ret = response.json()
        except ValueError:
            ret = response.text
        print(f"Error: {ret}")
        return None

    if args.stream:
        ret = None
        for chunk in response.iter_lines(decode_unicode=False):
            chunk = chunk.decode("utf-8")
            if chunk and chunk.startswith("data:"):
                if chunk == "data: [DONE]":
                    break
                # Each data chunk overwrites ret, so the last one wins.
                ret = json.loads(chunk[5:].strip("\n"))
        if ret is None:
            print("Error: stream ended without any data chunk")
            return None
    else:
        ret = response.json()

    # Batched requests return a list; only the first result is reported.
    ret = ret[0] if isinstance(ret, list) else ret

    return ret if return_full_response else ret["text"]
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def send_prefix(args, batch_size: int, prompts: List[str]):
    """Send one batch of prompts sampled with replacement from ``prompts``.

    The server cache is flushed first. Then ``batch_size`` prompt indices are
    drawn uniformly with the global ``random`` module, and the matching prompts
    are submitted in a single ``/generate`` request.

    Args:
        args: Benchmark arguments. Reads ``host``, ``port``, ``temperature``,
            ``max_new_tokens``, ``frequency_penalty``, ``presence_penalty``,
            ``return_logprob``, ``stream`` and ``sampling_seed``.
        batch_size: Number of prompts in the batch.
        prompts: Candidate prompts to sample from.

    Returns:
        A dict mapping every prompt index in ``range(len(prompts))`` to the list
        of generated texts for the batch entries that used that prompt. The
        list is empty if the prompt was never sampled.

    Raises:
        RuntimeError: if ``/generate`` answers with a non-200 status.
    """
    base_url = f"http://{args.host}:{args.port}"
    # NOTE(review): the radix_cache mode of test_deterministic observes that
    # flushing before any request has been served can hang; this flush is
    # unconditional, so the first prefix-mode call may hit that case.
    requests.post(f"{base_url}/flush_cache")

    # randrange(n) consumes the RNG exactly like randint(0, n - 1), so a given
    # seed reproduces the same sample as before.
    sampled_indices = [random.randrange(len(prompts)) for _ in range(batch_size)]
    batch_data = [prompts[idx] for idx in sampled_indices]

    json_data = {
        "text": batch_data,
        "sampling_params": {
            "temperature": args.temperature,
            "max_new_tokens": args.max_new_tokens,
            "frequency_penalty": args.frequency_penalty,
            "presence_penalty": args.presence_penalty,
        },
        "return_logprob": args.return_logprob,
        "stream": args.stream,
    }

    if args.sampling_seed is not None:
        # Only send the key when set; a null sampling_seed is not accepted.
        json_data["sampling_params"]["sampling_seed"] = args.sampling_seed

    response = requests.post(
        f"{base_url}/generate",
        json=json_data,
        stream=args.stream,
    )
    if response.status_code != 200:
        # Callers index the result as a dict, so returning a sentinel would
        # only defer the failure to a confusing TypeError. Fail loudly here.
        raise RuntimeError(
            f"/generate failed with status {response.status_code}: {response.text}"
        )

    # NOTE(review): the body is parsed as one JSON document, so this
    # presumably only works with args.stream disabled.
    ret = response.json()

    ret_dict = {idx: [] for idx in range(len(prompts))}
    for batch_pos, prompt_idx in enumerate(sampled_indices):
        ret_dict[prompt_idx].append(ret[batch_pos]["text"])

    return ret_dict
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def test_deterministic(args):
    """Run the deterministic-inference check selected by ``args.test_mode``.

    Modes:
        * ``"single"``: send the same prompt at batch sizes 1..``args.n_trials``
          and count distinct outputs. ``args.n_start`` is not used.
        * ``"prefix"``: for batch sizes ``args.n_start`` ..
          ``args.n_start + args.n_trials - 1``, send batches sampled from
          prompts that share prefixes of different lengths, and count distinct
          outputs per prompt.
        * ``"radix_cache"``: compare one prefill token's logprob computed with
          and without a radix-cache hit. Sets ``args.return_logprob = True``.

    Returns:
        ``"single"``: ``[number of unique outputs]``.
        ``"prefix"``: one unique-output count per prefix length.
        ``"radix_cache"``: ``[1]`` if cached and uncached prefill agree exactly,
        otherwise ``[0]``.
        A fully deterministic server yields a list of ones.

    Raises:
        ValueError: if ``args.test_mode`` is not one of the modes above.
        RuntimeError: if a required request fails.
    """
    if args.test_mode == "single":
        return _run_single_mode(args)
    if args.test_mode == "prefix":
        return _run_prefix_mode(args)
    if args.test_mode == "radix_cache":
        return _run_radix_cache_mode(args)
    raise ValueError(f"Invalid test mode: {args.test_mode}")


def _run_single_mode(args):
    """Send one prompt at batch sizes 1..n_trials; return [unique output count]."""
    texts = []
    for batch_size in range(1, args.n_trials + 1):
        text = send_single(args, batch_size, args.profile)
        if text is None:
            # send_single prints the server error and returns None on failure.
            raise RuntimeError(f"Request with batch size {batch_size} failed")
        text = text.replace("\n", " ")
        print(f"Trial {batch_size} with batch size {batch_size}: {text}")
        texts.append(text)
    print(f"Total samples: {len(texts)}, Unique samples: {len(set(texts))}")
    return [len(set(texts))]


def _run_prefix_mode(args):
    """Send batches of prompts sharing prefixes; return unique counts per prompt."""
    # All prompts are prefixes of the same long prompt, so they share common
    # prefixes of different lengths.
    len_prefix = [1, 511, 2048, 4097]
    num_prompts = len(len_prefix)
    prompts = [LONG_PROMPT[:length] for length in len_prefix]
    outputs = {idx: [] for idx in range(num_prompts)}

    for batch_size in range(args.n_start, args.n_start + args.n_trials):
        ret_dict = send_prefix(args, batch_size, prompts)
        msg = f"Testing Trial {batch_size} with batch size {batch_size},"
        for idx in range(num_prompts):
            msg += f" # prefix length {len_prefix[idx]}: {len(ret_dict[idx])},"
        print(msg)
        for idx in range(num_prompts):
            outputs[idx].extend(ret_dict[idx])

    for idx in range(num_prompts):
        print(
            f"Prompt {idx} with prefix length {len_prefix[idx]}: total samples: {len(outputs[idx])}, Unique samples: {len(set(outputs[idx]))}"
        )

    return [len(set(outputs[idx])) for idx in range(num_prompts)]


def _run_radix_cache_mode(args):
    """Check that a prefill gives identical logprobs with and without cache hits.

    Returns ``[1]`` when the cached and uncached prefill produce the same token
    ID and the exact same logprob, otherwise ``[0]``.
    """
    # Radix mode requires logprobs to compare results.
    args.return_logprob = True
    num_initial_tokens = 64
    num_decode_tokens = 100

    def _send(input_ids, max_new_tokens):
        """Send one full-response request, raising if it failed."""
        response = send_single(
            args,
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            return_full_response=True,
        )
        if response is None:
            raise RuntimeError("/generate request failed; see the error printed above")
        return response

    print("\n=== Prefill Cache Consistency Test ===")
    print(
        "This test verifies prefill request produces consistent logprobs w/ and w/o cache.\n"
    )

    # We noticed that we cannot call flush cache before any request, otherwise
    # it will hang. The warm-up result is deliberately ignored.
    send_single(
        args, input_ids=[1] * 64, max_new_tokens=65, return_full_response=True
    )

    # Flush cache first to make sure there is no cache hit from previous tests.
    requests.post(f"http://{args.host}:{args.port}/flush_cache")

    print(f"Step 1: Generating random {num_initial_tokens} token IDs...")
    # Draw IDs from 100-50000: low IDs are skipped because many tokenizers
    # reserve them for special tokens (e.g. padding/BOS/EOS), and the upper
    # bound is presumably inside most vocabularies.
    # Fixed seed so every run uses the same prompt (reseeds the global RNG).
    random.seed(42)
    initial_token_ids = [
        random.randint(100, 50000) for _ in range(num_initial_tokens)
    ]

    print(f"✓ Using {len(initial_token_ids)} initial tokens")
    print(f" Initial token IDs: {initial_token_ids}")

    print(
        f"\nStep 2: Generating up to {num_decode_tokens} tokens from {len(initial_token_ids)} token prefix..."
    )
    first_response = _send(initial_token_ids, num_decode_tokens)
    first_output_text = first_response["text"]
    first_output_token_ids = first_response["output_ids"]
    first_output_logprobs = first_response["meta_info"]["output_token_logprobs"]

    # The last decoded token and its logprob are the reference for request 2.
    expected_token_id = first_output_token_ids[-1]
    expected_logprob = first_output_logprobs[-1][0]

    print(f"✓ Generated {len(first_output_token_ids)} tokens")
    print(f' Output text: "{first_output_text}"')

    # Prefill everything up to (but excluding) the last decoded token, so the
    # next predicted token corresponds to expected_token_id.
    prefix_token_ids = initial_token_ids + first_output_token_ids[:-1]
    print(
        f"\nStep 3: Generating with radix cache ({len(prefix_token_ids)} tokens prefill, should hit > 128 tokens cache, based on page size)..."
    )
    print(
        f" Prefix: {len(initial_token_ids)} initial + {len(first_output_token_ids) - 1} generated = {len(prefix_token_ids)} tokens"
    )
    print(f"Using Prompt: {prefix_token_ids}")
    cached_response = _send(prefix_token_ids, 1)
    # Each output_token_logprobs entry is indexed as (logprob, token_id, ...).
    cached_token_data = cached_response["meta_info"]["output_token_logprobs"][0]
    cached_logprob = cached_token_data[0]
    cached_token_id = cached_token_data[1]

    print("✓ Generated with cache:")
    print(f" Token ID: {cached_token_id}")
    print(f" Logprob: {cached_logprob:.10f}")

    print("\nStep 4: Flushing cache...")
    requests.post(f"http://{args.host}:{args.port}/flush_cache")

    print(
        f"\nStep 5: Generating without cache (same {len(prefix_token_ids)} tokens prefill, no cache)..."
    )
    print(f"Using Prompt: {prefix_token_ids}")

    uncached_response = _send(prefix_token_ids, 1)
    uncached_token_data = uncached_response["meta_info"]["output_token_logprobs"][0]
    uncached_logprob = uncached_token_data[0]
    uncached_token_id = uncached_token_data[1]

    print("✓ Generated without cache:")
    print(f" Token ID: {uncached_token_id}")
    print(f" Logprob: {uncached_logprob:.10f}")

    # Step 6: Compare results
    print(f"\n{'='*60}")
    print("Comparison 1: Decode (Request 1) vs Prefill with Cache (Request 2)")
    print("=" * 60)

    # Decode vs prefill-with-cache: expected to differ (different kernels).
    # Reported for information only; it does not affect the result.
    decode_vs_prefill_token_match = expected_token_id == cached_token_id
    decode_vs_prefill_logprob_match = expected_logprob == cached_logprob

    print(
        f" Decode token (Request 1): ID={expected_token_id}, logprob={expected_logprob:.10f}"
    )
    print(
        f" Prefill w/ cache token (Request 2): ID={cached_token_id}, logprob={cached_logprob:.10f}"
    )
    print(
        f" Token ID match: {'✓ YES' if decode_vs_prefill_token_match else '✗ NO'}"
    )
    print(
        f" Logprob match: {'✓ YES' if decode_vs_prefill_logprob_match else '✗ NO'}"
    )
    if not decode_vs_prefill_logprob_match:
        diff = abs(expected_logprob - cached_logprob)
        print(f" Logprob difference: {diff:.10e}")
    print(" Note: We expect these to be DIFFERENT (decode vs prefill kernels)")

    print(f"\n{'='*60}")
    print(
        "Comparison 2: Cached Prefill (Request 2) vs Uncached Prefill (Request 3)"
    )
    print("=" * 60)

    # Main test: cached vs uncached prefill must be bit-identical.
    token_match = cached_token_id == uncached_token_id
    logprob_match = cached_logprob == uncached_logprob

    print(
        f" Cached prefill token (Request 2): ID={cached_token_id}, logprob={cached_logprob:.10f}"
    )
    print(
        f" Uncached prefill token (Request 3): ID={uncached_token_id}, logprob={uncached_logprob:.10f}"
    )
    print(f" Token ID match: {'✓ YES' if token_match else '✗ NO'}")
    if not token_match:
        print(f" Cached: {cached_token_id}")
        print(f" Uncached: {uncached_token_id}")

    print(f" Logprob match: {'✓ YES' if logprob_match else '✗ NO'}")
    if not logprob_match:
        print(f" Cached: {cached_logprob:.10f}")
        print(f" Uncached: {uncached_logprob:.10f}")
        diff = abs(cached_logprob - uncached_logprob)
        print(f" Difference: {diff:.10e}")
    print(" Note: We expect these to be IDENTICAL (both prefill kernels)")

    print(f"\n{'='*60}")
    if token_match and logprob_match:
        print("✓✓✓ TEST PASSED - Radix cache is consistent! ✓✓✓")
        return [1]
    print("✗✗✗ TEST FAILED - Radix cache produces different results! ✗✗✗")
    return [0]
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
if __name__ == "__main__":
    # CLI entry point: parse BenchArgs flags and run the selected check.
    cli_parser = argparse.ArgumentParser()
    BenchArgs.add_cli_args(cli_parser)
    args = cli_parser.parse_args()

    # Always run with an explicit sampling seed: 42 unless one was given.
    args.sampling_seed = 42 if args.sampling_seed is None else args.sampling_seed

    test_deterministic(args)
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
import unittest
|
|
2
|
+
|
|
3
|
+
from sglang.srt.utils import kill_process_tree
|
|
4
|
+
from sglang.test.test_deterministic import BenchArgs, test_deterministic
|
|
5
|
+
from sglang.test.test_utils import (
|
|
6
|
+
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
7
|
+
DEFAULT_URL_FOR_TEST,
|
|
8
|
+
CustomTestCase,
|
|
9
|
+
popen_launch_server,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
# Model served by the deterministic-inference tests unless a subclass
# overrides get_model().
DEFAULT_MODEL = "Qwen/Qwen3-8B"

# Server flags shared by every deterministic-inference launch. The base test
# class skips itself unless a subclass adds an "--attention-backend" flag.
COMMON_SERVER_ARGS = [
    "--trust-remote-code",
    "--cuda-graph-max-bs", "32",
    "--enable-deterministic-inference",
]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class TestDeterministicBase(CustomTestCase):
    """Launch a deterministic-inference server and check its outputs repeat.

    This base class skips itself. Concrete subclasses override
    ``get_server_args`` to include ``--attention-backend`` and may override
    ``get_model``.
    """

    @classmethod
    def get_server_args(cls):
        """Return the server CLI flags (subclasses add an attention backend)."""
        return COMMON_SERVER_ARGS

    @classmethod
    def get_model(cls):
        """Return the model path to serve."""
        return DEFAULT_MODEL

    @classmethod
    def setUpClass(cls):
        cls.model = cls.get_model()
        cls.base_url = DEFAULT_URL_FOR_TEST
        # Only subclasses that select an attention backend actually run.
        if "--attention-backend" not in cls.get_server_args():
            raise unittest.SkipTest("Skip the base test class")

        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=cls.get_server_args(),
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def _extract_host_and_port(self, url):
        """Return ``(host, port)`` parsed from ``[scheme://]host:port[/path]``.

        Unlike a plain ``split(":")``, this tolerates a trailing path and keeps
        bracketed IPv6 hosts such as ``[::1]`` intact.

        Raises:
            ValueError: if the authority part does not end in an integer port.
        """
        authority = url.split("://")[-1].split("/")[0]
        host, _, port = authority.rpartition(":")
        return host, int(port)

    def _make_bench_args(self, test_mode):
        """Build BenchArgs pointed at the launched server for ``test_mode``."""
        args = BenchArgs()
        args.host, args.port = self._extract_host_and_port(self.base_url)
        args.test_mode = test_mode
        return args

    def test_single(self):
        # "single" mode always sweeps batch sizes 1..n_trials; n_start is unused.
        args = self._make_bench_args("single")
        args.n_trials = 20
        # Must be set before running, otherwise sampling is never exercised.
        args.temperature = 0.5  # test for deterministic sampling
        results = test_deterministic(args)
        for result in results:
            self.assertEqual(result, 1)

    def test_prefix(self):
        args = self._make_bench_args("prefix")
        args.n_start = 10
        args.n_trials = 10
        args.temperature = 0.5  # test for deterministic sampling
        results = test_deterministic(args)
        for result in results:
            self.assertEqual(result, 1)
|
|
@@ -1,20 +1,56 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
1
3
|
import time
|
|
4
|
+
import warnings
|
|
5
|
+
from urllib.parse import urlparse
|
|
2
6
|
|
|
3
7
|
import requests
|
|
4
8
|
|
|
9
|
+
from sglang.srt.environ import envs
|
|
5
10
|
from sglang.srt.utils import kill_process_tree
|
|
6
11
|
from sglang.test.test_utils import (
|
|
7
12
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
13
|
+
DEFAULT_URL_FOR_TEST,
|
|
8
14
|
CustomTestCase,
|
|
15
|
+
is_in_ci,
|
|
9
16
|
popen_with_error_check,
|
|
10
17
|
)
|
|
11
18
|
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
12
21
|
|
|
13
22
|
class TestDisaggregationBase(CustomTestCase):
|
|
14
23
|
@classmethod
def setUpClass(cls):
    """Compute endpoint URLs and disaggregation transfer flags for the test.

    The load balancer uses the host and port of ``DEFAULT_URL_FOR_TEST``. The
    prefill and decode servers use the same host with port + 100 and
    port + 200.

    In CI, the transfer backend is ``mooncake`` and the RDMA devices come from
    ``get_rdma_devices_args()``. Otherwise both come from
    ``envs.SGLANG_TEST_PD_DISAGG_BACKEND`` / ``envs.SGLANG_TEST_PD_DISAGG_DEVICES``.
    When no devices are configured, the device flag is omitted and a warning
    is emitted.
    """
    parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
    cls.base_host = parsed_url.hostname
    base_port = str(parsed_url.port)
    cls.lb_port = base_port
    cls.prefill_port = f"{int(base_port) + 100}"
    cls.decode_port = f"{int(base_port) + 200}"
    cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
    cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
    cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
    print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
    cls.process_lb, cls.process_decode, cls.process_prefill = None, None, None

    # config transfer backend and rdma devices
    if is_in_ci():
        cls.transfer_backend = ["--disaggregation-transfer-backend", "mooncake"]
        cls.rdma_devices = ["--disaggregation-ib-device", get_rdma_devices_args()]
    else:
        cls.transfer_backend = [
            "--disaggregation-transfer-backend",
            envs.SGLANG_TEST_PD_DISAGG_BACKEND.get(),
        ]
        devices = envs.SGLANG_TEST_PD_DISAGG_DEVICES.get()
        if devices:
            cls.rdma_devices = ["--disaggregation-ib-device", devices]
        else:
            # Treat an empty value like an unset one. Otherwise an empty
            # device name would be passed to --disaggregation-ib-device.
            cls.rdma_devices = []
            msg = "No RDMA devices specified for disaggregation test, using default settings."
            warnings.warn(msg)
|
|
18
54
|
|
|
19
55
|
@classmethod
|
|
20
56
|
def launch_lb(cls):
|
|
@@ -64,3 +100,59 @@ class TestDisaggregationBase(CustomTestCase):
|
|
|
64
100
|
|
|
65
101
|
# wait for 5 seconds
|
|
66
102
|
time.sleep(5)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def get_rdma_devices_args():
    """Return comma-separated RDMA device names for the GPUs in CUDA_VISIBLE_DEVICES.

    Candidate devices come from the comma-separated ``SGLANG_CI_RDMA_ALL_DEVICES``
    environment variable, defaulting to ``mlx5_roce0`` .. ``mlx5_roce7``. GPU
    ``g`` maps to device ``g // max(1, 8 // n_devices)``. The hard-coded 8
    presumably reflects an 8-GPU node shared evenly by the devices.

    The default pair (the first device and the one halfway through the list)
    is returned instead when CUDA_VISIBLE_DEVICES:

    * is unset or empty,
    * cannot be parsed as integers,
    * lists no GPUs or more than 4 GPUs, or
    * contains a GPU whose computed device index falls outside the device list.
    """

    def _parse_list_env(var_name: str):
        """Split a comma-separated env var into stripped items, or None if empty."""
        val = os.getenv(var_name)
        if not val:
            return None
        items = [x.strip() for x in val.split(",") if x.strip()]
        return items or None

    def _pick_default_pair(rdma_all_devices):
        """Return the first device and the one halfway through the list."""
        return [rdma_all_devices[0], rdma_all_devices[len(rdma_all_devices) // 2]]

    rdma_all_devices = _parse_list_env("SGLANG_CI_RDMA_ALL_DEVICES") or [
        f"mlx5_roce{i}" for i in range(8)
    ]
    logger.info("Resolved rdma_all_devices=%s", rdma_all_devices)
    # rdma_all_devices is never empty here, so the default pair is always valid.
    default_devices = ",".join(_pick_default_pair(rdma_all_devices))

    n_rdma = len(rdma_all_devices)

    # 1. Get visible GPU indices
    cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
    if not cuda_visible_devices:
        warnings.warn("CUDA_VISIBLE_DEVICES is not set. Using default RDMA devices.")
        return default_devices

    try:
        # Convert to list of integers (handling possible spaces and empty strings)
        gpu_indices = [
            int(idx.strip()) for idx in cuda_visible_devices.split(",") if idx.strip()
        ]
    except ValueError:
        warnings.warn(f"Invalid CUDA_VISIBLE_DEVICES format: {cuda_visible_devices}")
        return default_devices

    if not gpu_indices or len(gpu_indices) > 4:
        return default_devices

    # 2. Calculate base RDMA index group (each group of 4 GPUs uses consecutive devices)
    base_rdma_group = (min(gpu_indices) // 4) * 4
    for gpu_idx in gpu_indices:
        if not (base_rdma_group <= gpu_idx < base_rdma_group + 4):
            warnings.warn(
                f"GPU index {gpu_idx} is outside expected group "
                f"{base_rdma_group}-{base_rdma_group+3}"
            )

    # 3. Generate RDMA device names.
    # Clamp to >= 1: with more than 8 devices, 8 // n_rdma would be 0 and the
    # division below would raise ZeroDivisionError.
    gpus_per_device = max(1, 8 // n_rdma)
    rdma_devices = []
    for gpu_idx in gpu_indices:
        nic_index = gpu_idx // gpus_per_device
        # Guard against IndexError (e.g. 3 devices with GPU 6) and silent
        # negative-index wraparound.
        if not 0 <= nic_index < n_rdma:
            warnings.warn(
                f"GPU index {gpu_idx} maps to RDMA device index {nic_index}, "
                f"outside the {n_rdma} configured devices. Using default RDMA devices."
            )
            return default_devices
        rdma_devices.append(rdma_all_devices[nic_index])

    return ",".join(rdma_devices)
|
sglang/test/test_marlin_moe.py
CHANGED
sglang/test/test_programs.py
CHANGED
|
@@ -551,7 +551,7 @@ def test_gen_min_new_tokens():
|
|
|
551
551
|
We verify that the number of tokens in the answer is >= the min_tokens threshold.
|
|
552
552
|
"""
|
|
553
553
|
import sglang as sgl
|
|
554
|
-
from sglang.srt.hf_transformers_utils import get_tokenizer
|
|
554
|
+
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
|
|
555
555
|
|
|
556
556
|
model_path = sgl.global_config.default_backend.endpoint.get_model_name()
|
|
557
557
|
MIN_TOKENS, MAX_TOKENS = 64, 128
|