sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/storage/nixl/hicache_nixl.py
CHANGED
@@ -1,13 +1,12 @@
-import hashlib
 import logging
 import os
 import time
 import uuid
-from typing import Any,
+from typing import Any, List, Optional, Union
 
 import torch
 
-from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
 
 from .nixl_utils import NixlBackendSelection, NixlFileManager, NixlRegistration
 
@@ -26,7 +25,12 @@ logger = logging.getLogger(__name__)
 class HiCacheNixl(HiCacheStorage):
     """HiCacheNixl provides high-performance storage using NIXL plugins."""
 
-    def __init__(
+    def __init__(
+        self,
+        storage_config: HiCacheStorageConfig,
+        file_path: str = "/tmp/hicache_storage",
+        plugin: str = "auto",
+    ):
         """Initialize NIXL storage connector."""
         # Might be better to be unified across HiCache backends and moved to HiCacheController
         file_path = os.getenv("SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR", file_path)
@@ -36,6 +40,19 @@ class HiCacheNixl(HiCacheStorage):
             else None
         )
 
+        # Initialize suffix based on storage config
+        tp_rank, tp_size, model_name, is_mla_model = (
+            storage_config.tp_rank,
+            storage_config.tp_size,
+            storage_config.model_name,
+            storage_config.is_mla_model,
+        )
+        model_name = "-".join(model_name.split("/")) if model_name else ""
+        if is_mla_model:
+            self.config_suffix = f"_{model_name}"
+        else:
+            self.config_suffix = f"_{model_name}_{tp_rank}_{tp_size}"
+
         agent_config = nixl_agent_config(backends=[])
         self.agent_name = f"hicache_nixl_{str(uuid.uuid4())}"
         self.agent = nixl_agent(self.agent_name, agent_config)
@@ -46,6 +63,9 @@ class HiCacheNixl(HiCacheStorage):
 
         self.registration = NixlRegistration(self.agent)
 
+    def _get_suffixed_key(self, key: str) -> str:
+        return key + self.config_suffix
+
     def register_buffers(
         self, buffers: Union[torch.Tensor, List[torch.Tensor], List[tuple]]
     ) -> Optional[Any]:
@@ -194,11 +214,14 @@ class HiCacheNixl(HiCacheStorage):
         else:
             dest = target_locations
 
+        # Add suffix to keys
+        suffixed_keys = [self._get_suffixed_key(key) for key in keys]
+
         if self.backend_selector.mem_type == "FILE":
-            file_paths = [self.file_manager.get_file_path(key) for key in
+            file_paths = [self.file_manager.get_file_path(key) for key in suffixed_keys]
             success = self._execute_transfer(dest, file_paths, "READ")
         else:
-            success = self._execute_transfer(dest,
+            success = self._execute_transfer(dest, suffixed_keys, "READ")
         return target_locations if success and not target_sizes else [None] * len(keys)
 
     def set(
@@ -227,9 +250,12 @@ class HiCacheNixl(HiCacheStorage):
         if not values:
             values = list(zip(target_locations, target_sizes))
 
+        # Add suffix to keys
+        suffixed_keys = [self._get_suffixed_key(key) for key in keys]
+
         if self.backend_selector.mem_type == "FILE":
             file_paths = []
-            for key in
+            for key in suffixed_keys:
                 file_path = self.file_manager.get_file_path(key)
                 # New file per set, to be updated when partial writes is added to HiCache
                 if not self.file_manager.create_file(file_path):
@@ -238,11 +264,14 @@ class HiCacheNixl(HiCacheStorage):
                 file_paths.append(file_path)
             return self._execute_transfer(values, file_paths, "WRITE")
         else:  # mem_type == "OBJ"
-            return self._execute_transfer(values,
+            return self._execute_transfer(values, suffixed_keys, "WRITE")
 
     def exists(self, key: str) -> bool:
+        # Add suffix to key
+        suffixed_key = self._get_suffixed_key(key)
+
         tuples = self.registration.create_query_tuples(
-
+            suffixed_key,
             self.backend_selector.mem_type,
             self.file_manager if self.backend_selector.mem_type == "FILE" else None,
         )
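Note: the net effect of the `config_suffix` change above is that every stored key becomes specific to the serving configuration. For MLA models the suffix omits the TP rank and size (apparently because the KV entries are rank-independent and can be shared), while other models encode both. A minimal standalone sketch of the scheme; `make_config_suffix` is an illustrative helper, not an sglang API:

```python
# Illustrative sketch of the key-suffix scheme from the diff above.
# `make_config_suffix` is a hypothetical helper, not part of sglang.
def make_config_suffix(
    model_name: str, tp_rank: int, tp_size: int, is_mla_model: bool
) -> str:
    model_name = "-".join(model_name.split("/")) if model_name else ""
    if is_mla_model:
        # MLA entries carry no per-rank shard, so one key serves all TP ranks.
        return f"_{model_name}"
    # Otherwise KV is sharded per rank, so rank/size must be part of the key.
    return f"_{model_name}_{tp_rank}_{tp_size}"

print("page0" + make_config_suffix("meta-llama/Llama-3-8B", 1, 2, False))
# -> page0_meta-llama-Llama-3-8B_1_2
```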
sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py
CHANGED
@@ -2,11 +2,12 @@
 
 import os
 import unittest
-from typing import List
+from typing import List
 from unittest.mock import MagicMock
 
 import torch
 
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
 from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
 from sglang.srt.mem_cache.storage.nixl.nixl_utils import (
     NixlFileManager,
@@ -31,8 +32,22 @@ class TestNixlUnified(unittest.TestCase):
         # Create instances
         self.file_manager = NixlFileManager(self.test_dir)
         self.registration = NixlRegistration(self.mock_agent)
+
+        # Create storage config for testing
+        self.storage_config = HiCacheStorageConfig(
+            tp_rank=0,
+            tp_size=2,
+            is_mla_model=False,
+            is_page_first_layout=False,
+            model_name="test_model",
+        )
+
         try:
-            self.hicache = HiCacheNixl(
+            self.hicache = HiCacheNixl(
+                storage_config=self.storage_config,
+                file_path=self.test_dir,
+                plugin="POSIX",
+            )
         except ImportError:
             self.skipTest("NIXL not available, skipping NIXL storage tests")
 
sglang/srt/mem_cache/swa_radix_cache.py
CHANGED
@@ -32,6 +32,7 @@ from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.mem_cache.radix_cache import (
     RadixKey,
+    _convert_to_bigram_key,
     _key_match_page_size1,
     _key_match_paged,
     get_child_key,
@@ -327,12 +328,14 @@ class SWARadixCache(BasePrefixCache):
         sliding_window_size: int,
         page_size: int,
         disable: bool = False,
+        is_eagle: bool = False,
     ):
         assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.page_size = page_size
         self.disable = disable
+        self.is_eagle = is_eagle
 
         if self.token_to_kv_pool_allocator:
             self.device = self.token_to_kv_pool_allocator.device
@@ -346,6 +349,11 @@ class SWARadixCache(BasePrefixCache):
             self.key_match_fn = partial(_key_match_paged, page_size=page_size)
             self.get_child_key_fn = partial(get_child_key, page_size=page_size)
 
+        if is_eagle:
+            self.key_convert_fn = _convert_to_bigram_key
+        else:
+            self.key_convert_fn = lambda key: key
+
         self.sliding_window_size = sliding_window_size
         self.reset()
 
@@ -376,6 +384,8 @@ class SWARadixCache(BasePrefixCache):
         The last node create a new child if the prefix is shorter
         than the last node's value.
         """
+        key.token_ids = self.key_convert_fn(key.token_ids)
+
         if self.disable or len(key) == 0:
             return MatchResult(
                 device_indices=torch.empty(
@@ -406,42 +416,73 @@
         if self.disable:
             return 0
 
+        key.token_ids = self.key_convert_fn(key.token_ids)
+
         if value is None:
             value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)
+
+        if self.is_eagle:
+            # Make sure the value len equal to the EAGLE bigram key len
+            value = value[: len(key)]
+
         return self._insert_helper(self.root_node, key, value, prev_prefix_len)
 
-    def cache_finished_req(self, req: Req) -> None:
+    def cache_finished_req(self, req: Req, is_insert: bool = True) -> None:
         """Cache request when it finishes."""
+        all_token_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
         if self.disable:
             kv_indices = self.req_to_token_pool.req_to_token[
-                req.req_pool_idx,
-                : len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0),
+                req.req_pool_idx, :all_token_len
             ]
             self.token_to_kv_pool_allocator.free(kv_indices)
             self.req_to_token_pool.free(req.req_pool_idx)
             return
 
-        token_ids = (req.origin_input_ids + req.output_ids)[
+        token_ids = (req.origin_input_ids + req.output_ids)[:all_token_len]
+        # For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
+        # So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
+        actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, :
+            req.req_pool_idx, :all_token_len
         ]
 
         if self.page_size != 1:
-            page_aligned_len =
-            page_aligned_kv_indices = kv_indices[:page_aligned_len].
-
+            page_aligned_len = actual_kv_len // self.page_size * self.page_size
+            page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
+                dtype=torch.int64, copy=True
+            )
         else:
-            page_aligned_len =
-            page_aligned_kv_indices = kv_indices.
+            page_aligned_len = actual_kv_len
+            page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
+        if self.is_eagle:
+            self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
+
+        page_aligned_token_len = (
+            page_aligned_len + 1 if self.is_eagle else page_aligned_len
+        )
+
+        old_prefix_len = len(req.prefix_indices)
+        if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
+            # In EAGLE chunked prefill case, the prefix_indices included one unmatched token (kv_indices[actual_kv_len:])
+            # Here we -1 to make sure the kv of the unmatched token can be freed correctly to avoid memory leak
+            old_prefix_len -= 1
 
         # Radix Cache takes one ref in memory pool
         # insert the token_ids and kv_indices into the radix tree
         # Note: the insert function already frees the overlapped kv_indices
-
-
-
-
-
+        if is_insert:
+            new_prefix_len = self.insert(
+                RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
+                page_aligned_kv_indices,
+                old_prefix_len,
+            )
+        else:
+            self.token_to_kv_pool_allocator.free(
+                kv_indices[old_prefix_len:page_aligned_len]
+            )
+
+        # free the unaligned tail
+        self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
 
         # Remove req slot release the cache lock
         self.req_to_token_pool.free(req.req_pool_idx)
@@ -459,39 +500,58 @@
             return
 
         token_ids = req.fill_ids
+        all_token_len = len(token_ids)
+        # For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
+        # So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
+        actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, :
+            req.req_pool_idx, :all_token_len
        ]
 
         if self.page_size != 1:
-            page_aligned_len =
-            page_aligned_kv_indices = kv_indices[:page_aligned_len].
+            page_aligned_len = actual_kv_len // self.page_size * self.page_size
+            page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
+                dtype=torch.int64, copy=True
+            )
         else:
-            page_aligned_len =
-            page_aligned_kv_indices = kv_indices.
-
+            page_aligned_len = actual_kv_len
+            page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
+
+        # For EAGLE, the page_aligned_len is for the bigram key, the normal key len should +1
+        page_aligned_token_len = (
+            page_aligned_len + 1 if self.is_eagle else page_aligned_len
+        )
+        page_aligned_token_ids = token_ids[:page_aligned_token_len]
+
+        old_prefix_len = len(req.prefix_indices)
+        if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
+            # In EAGLE chunked prefill case, the prefix_indices included one unmatched token (kv_indices[actual_kv_len:])
+            # Here we -1 to make sure the kv of the unmatched token can be freed correctly to avoid memory leak
+            old_prefix_len -= 1
 
         # Radix Cache takes one ref in memory pool
         # Note: the insert function already frees the overlapped kv_indices
         new_prefix_len = self.insert(
             RadixKey(page_aligned_token_ids, req.extra_key),
             page_aligned_kv_indices,
-
+            old_prefix_len,
         )
 
         # The prefix indices could be updated, reuse it
         new_indices, new_last_node, _, _ = self.match_prefix(
             RadixKey(page_aligned_token_ids, req.extra_key)
         )
-        assert
+        assert old_prefix_len <= len(
             new_indices
         ), f"{req.prefix_indices=}, {new_indices=}"
         assert new_prefix_len <= len(new_indices), f"{new_prefix_len=}, {new_indices=}"
         self.req_to_token_pool.write(
-            (req.req_pool_idx, slice(
-            new_indices[
+            (req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
+            new_indices[old_prefix_len:],
         )
 
+        req.last_matched_prefix_len = len(new_indices)
+
         self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
         swa_uuid_for_lock = self.inc_lock_ref(new_last_node)
 
@@ -501,7 +561,13 @@
                 [new_indices, kv_indices[len(new_indices) :]]
             )
         else:
-
+            if self.is_eagle:
+                # Attach the kv index of the last token for EAGLE, it can be used in chunked prefill
+                req.prefix_indices = torch.cat(
+                    [new_indices, kv_indices[actual_kv_len:]]
+                )
+            else:
+                req.prefix_indices = new_indices
         req.last_node = new_last_node
         req.swa_uuid_for_lock = swa_uuid_for_lock
 
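Note: the `is_eagle` changes above hinge on the bigram key conversion described in the inline comments. A self-contained sketch of that conversion follows; the real `_convert_to_bigram_key` is imported from `sglang/srt/mem_cache/radix_cache.py`, and this stand-in only illustrates the length bookkeeping:

```python
# Stand-in for _convert_to_bigram_key to illustrate the length bookkeeping
# used above; the real implementation lives in sglang/srt/mem_cache/radix_cache.py.
def convert_to_bigram_key(token_ids: list[int]) -> list[tuple[int, int]]:
    # [1, 2, 3, 4] -> [(1, 2), (2, 3), (3, 4)]
    return list(zip(token_ids, token_ids[1:]))

key = convert_to_bigram_key([1, 2, 3, 4])
assert key == [(1, 2), (2, 3), (3, 4)]
# The bigram key is one shorter than the token list, which is why the diff
# derives actual_kv_len = all_token_len - 1 when is_eagle is set, and why
# insert() trims with value = value[: len(key)] to match.
assert len(key) == 4 - 1
```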
sglang/srt/metrics/collector.py
CHANGED
@@ -118,6 +118,7 @@ class SchedulerStats:
     num_running_reqs: int = 0
     num_used_tokens: int = 0
     token_usage: float = 0.0
+    pending_prealloc_token_usage: float = 0.0
     swa_token_usage: float = 0.0
     gen_throughput: float = 0.0
     num_queue_reqs: int = 0
@@ -127,6 +128,7 @@
 
     # Speculative decoding
     spec_accept_length: float = 0.0
+    spec_accept_rate: float = 0.0
 
     # Retract
     num_retracted_reqs: int = 0
@@ -148,6 +150,9 @@
     engine_startup_time: float = 0.0
     engine_load_weights_time: float = 0.0
 
+    # CUDA graph
+    is_cuda_graph: float = 0.0
+
 
 class SchedulerMetricsCollector:
 
@@ -176,6 +181,12 @@
             labelnames=labels.keys(),
             multiprocess_mode="mostrecent",
         )
+        self.pending_prealloc_token_usage = Gauge(
+            name="sglang:pending_prealloc_token_usage",
+            documentation="The token usage for pending preallocated tokens (not preallocated yet).",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
         self.swa_token_usage = Gauge(
             name="sglang:swa_token_usage",
             documentation="The token usage for SWA layers.",
@@ -220,6 +231,12 @@
             labelnames=labels.keys(),
             multiprocess_mode="mostrecent",
         )
+        self.spec_accept_rate = Gauge(
+            name="sglang:spec_accept_rate",
+            documentation="The average acceptance rate of speculative decoding (`accepted tokens / total draft tokens` in batch).",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
 
         # Retract
         self.num_retracted_reqs = Gauge(
@@ -485,6 +502,13 @@
             labelnames=list(labels.keys()) + ["stage"],
         )
 
+        self.is_cuda_graph = Gauge(
+            name="sglang:is_cuda_graph",
+            documentation="Whether the batch is using CUDA graph.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+
     def _log_gauge(self, gauge, data: Union[int, float]) -> None:
         # Convenience function for logging to gauge.
         gauge.labels(**self.labels).set(data)
@@ -509,6 +533,9 @@
         self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
         self._log_gauge(self.num_used_tokens, stats.num_used_tokens)
         self._log_gauge(self.token_usage, stats.token_usage)
+        self._log_gauge(
+            self.pending_prealloc_token_usage, stats.pending_prealloc_token_usage
+        )
         self._log_gauge(self.swa_token_usage, stats.swa_token_usage)
         self._log_gauge(self.gen_throughput, stats.gen_throughput)
         self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
@@ -520,6 +547,7 @@
 
         # Speculative decoding
         self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
+        self._log_gauge(self.spec_accept_rate, stats.spec_accept_rate)
 
         # PD disaggregation
         self._log_gauge(
@@ -556,6 +584,9 @@
             self.engine_load_weights_time, stats.engine_load_weights_time
         )
 
+        # CUDA graph
+        self._log_gauge(self.is_cuda_graph, stats.is_cuda_graph)
+
         self.last_log_time = time.perf_counter()
 
     def log_grammar_stats(self, grammar_stats) -> None:
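Note: `sglang:spec_accept_rate` complements the existing `sglang:spec_accept_length` gauge. Per its documentation string, it is the batch-level ratio of accepted draft tokens to proposed draft tokens. A toy computation, with counter names that are illustrative rather than actual scheduler fields:

```python
# Toy computation behind the new sglang:spec_accept_rate gauge, following its
# documentation string ("accepted tokens / total draft tokens" in batch).
# These counter names are illustrative, not actual scheduler fields.
accepted_draft_tokens = 180  # draft tokens the target model verified and kept
total_draft_tokens = 300     # all draft tokens proposed for the batch

spec_accept_rate = accepted_draft_tokens / total_draft_tokens
print(f"spec_accept_rate = {spec_accept_rate:.2f}")  # 0.60
```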
sglang/srt/metrics/func_timer.py
CHANGED
@@ -18,7 +18,7 @@ Records the latency of some functions
 import asyncio
 import time
 from functools import wraps
-from typing import Any, Callable,
+from typing import Any, Callable, Optional
 
 from sglang.srt.metrics.utils import exponential_buckets
 
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -38,8 +38,11 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     set_dp_buffer_len,
+    set_is_extend_in_batch,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPBuffer
+from sglang.srt.layers.moe.utils import get_deepep_mode, get_moe_a2a_backend
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
@@ -53,7 +56,6 @@ from sglang.srt.utils import (
     empty_context,
     get_available_gpu_memory,
     get_bool_env_var,
-    get_device_memory_capacity,
     is_hip,
     log_info_on_rank0,
     require_attn_tp_gather,
@@ -63,6 +65,13 @@ from sglang.srt.utils import (
 )
 from sglang.srt.utils.patch_torch import monkey_patch_torch_compile
 
+try:
+    from kt_kernel import AMXMoEWrapper
+
+    KTRANSFORMERS_AVAILABLE = True
+except ImportError:
+    KTRANSFORMERS_AVAILABLE = False
+
 _is_hip = is_hip()
 
 logger = logging.getLogger(__name__)
@@ -241,9 +250,13 @@ class CudaGraphRunner:
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
 
+        self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()
+
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         log_info_on_rank0(logger, f"Capture cuda graph bs {self.capture_bs}")
+        if KTRANSFORMERS_AVAILABLE:
+            AMXMoEWrapper.set_capture_batch_sizes(self.capture_bs)
         self.capture_forward_mode = ForwardMode.DECODE
         self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1
@@ -274,7 +287,6 @@ class CudaGraphRunner:
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
 
-        # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         self.encoder_len_fill_value = 0
         self.seq_lens_cpu = torch.full(
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
@@ -637,6 +649,7 @@ class CudaGraphRunner:
         # Clean intermediate result cache for DP attention
         forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
         set_dp_buffer_len(global_dp_buffer_len, num_tokens)
+        set_is_extend_in_batch(False)
 
         kwargs = {}
         if (
@@ -655,6 +668,8 @@ class CudaGraphRunner:
             )
             return logits_output_or_pp_proxy_tensors
 
+        self.deepep_adapter.capture(is_extend_in_batch=False)
+
         for _ in range(2):
             self.device_module.synchronize()
             self.model_runner.tp_group.barrier()
@@ -678,8 +693,9 @@ class CudaGraphRunner:
         capture_hidden_mode_required_by_forward_batch = (
             forward_batch.capture_hidden_mode
         )
-        capture_hidden_mode_required_by_spec_info =
-            forward_batch.spec_info, "capture_hidden_mode",
+        capture_hidden_mode_required_by_spec_info = (
+            getattr(forward_batch.spec_info, "capture_hidden_mode", None)
+            or CaptureHiddenMode.NULL
         )
         capture_hidden_mode_required_for_returning_hidden_states = (
             CaptureHiddenMode.FULL
@@ -797,6 +813,8 @@
         skip_attn_backend_init: bool = False,
         pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+        self.deepep_adapter.replay()
+
         if not skip_attn_backend_init:
             self.replay_prepare(forward_batch, pp_proxy_tensors)
         else:
@@ -849,7 +867,7 @@
             )
 
         elif self.model_runner.spec_algorithm.is_ngram():
-            from sglang.srt.speculative.
+            from sglang.srt.speculative.ngram_info import NgramVerifyInput
 
             spec_info = NgramVerifyInput(
                 draft_token=None,
@@ -873,3 +891,23 @@ CUDA_GRAPH_CAPTURE_FAILED_MSG = (
     "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
     "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
 )
+
+
+class DeepEPCudaGraphRunnerAdapter:
+    def __init__(self):
+        # Record DeepEP mode used during capture to ensure replay consistency
+        self._captured_deepep_mode = None
+
+    def capture(self, is_extend_in_batch: bool):
+        if not get_moe_a2a_backend().is_deepep():
+            return
+        self._captured_deepep_mode = get_deepep_mode().resolve(
+            is_extend_in_batch=is_extend_in_batch
+        )
+        DeepEPBuffer.set_dispatch_mode(self._captured_deepep_mode)
+
+    def replay(self):
+        if not get_moe_a2a_backend().is_deepep():
+            return
+        assert self._captured_deepep_mode is not None
+        DeepEPBuffer.set_dispatch_mode(self._captured_deepep_mode)